'''Cantilever beam optimization with Jax'''
import numpy as np
from modopt import JaxProblem, SLSQP
import time
import jax
import jax.numpy as jnp
# Switch JAX to 64-bit floats (JAX uses 32-bit precision unless this flag is set).
jax.config.update("jax_enable_x64", True)
# Beam constants read by get_problem():
#   E0   - scale factor of the element stiffness matrix (used there as E)
#   L0   - total beam length; divided by n_el to get the element length
#   b0   - factor b in the moment of inertia b * x**3 / 12 and in the volume term
#   vol0 - value the volume expression L_el * b * sum(x) is constrained to equal
#   F0   - load written into the second-to-last entry of the force vector
E0, L0, b0, vol0, F0 = 1., 1., 0.1, 0.01, -1.
# METHOD 1: Use JAX functions directly in modopt.JaxProblem.
# ModOpt will auto-generate the gradient, Jacobian, and objective Hessian.
# ModOpt will also auto-generate the Lagrangian, its gradient, and its Hessian.
# There is no need to manually derive or jit the functions and their derivatives and then wrap them.
def get_problem(n_el, order=1):
    '''Build the cantilever-beam compliance problem as a ``JaxProblem``.

    The beam is divided into ``n_el`` elements of equal length and the
    design vector ``x`` carries one entry per element.  The objective is
    the compliance ``F . u`` obtained from the assembled stiffness system,
    and the single constraint is ``L_el * b * sum(x) - vol0`` with both
    constraint bounds set to zero.

    Parameters
    ----------
    n_el : int
        Number of beam elements (and of design variables).
    order : int, optional
        Passed through unchanged to ``JaxProblem``.  Defaults to 1.

    Returns
    -------
    JaxProblem
        Problem named ``cantilever_{n_el}_jax`` with ``x0`` all ones,
        ``xl=1e-2`` and ``cl = cu = 0``.
    '''
    modulus, beam_len, width, target_vol = E0, L0, b0, vol0
    elem_len = beam_len / n_el
    # Two degrees of freedom per node, n_el + 1 nodes.
    n_dof = (n_el + 1) * 2

    def jax_obj(x):
        # Load vector: F0 acts on the second-to-last degree of freedom.
        load = np.zeros((n_dof,))
        load[-2] = F0

        # 4x4 element stiffness matrix for a unit moment of inertia.
        k_unit = modulus / elem_len**3 * np.array([
            [12, 6*elem_len, -12, 6*elem_len],
            [6*elem_len, 4*elem_len**2, -6*elem_len, 2*elem_len**2],
            [-12, -6*elem_len, 12, -6*elem_len],
            [6*elem_len, 2*elem_len**2, -6*elem_len, 4*elem_len**2],
        ])

        # Moment of inertia of every element.
        inertia = width * x**3 / 12

        # Assemble the global stiffness matrix.  JAX arrays are immutable,
        # so each element contribution is accumulated through .at[...].add().
        stiffness = jnp.zeros((n_dof, n_dof))
        for e in range(n_el):
            dofs = slice(2*e, 2*e + 4)
            stiffness = stiffness.at[dofs, dofs].add(k_unit * inertia[e])

        # Boundary conditions: the first two displacements are fixed at
        # zero, so only the reduced system K[2:, 2:] u[2:] = F[2:] is solved.
        free_disp = jnp.linalg.solve(stiffness[2:, 2:], load[2:])
        disp = jnp.concatenate((np.array([0., 0.]), free_disp))

        # Compliance
        return jnp.dot(load, disp)

    def jax_con(x):
        # Volume expression minus its target; held at zero via cl = cu = 0.
        used_vol = elem_len * width * jnp.sum(x)
        return jnp.array([used_vol - target_vol])

    return JaxProblem(
        x0=np.ones(n_el),
        nc=1,
        jax_obj=jax_obj,
        jax_con=jax_con,
        name=f'cantilever_{n_el}_jax',
        order=order,
        xl=1e-2,
        cl=0.,
        cu=0.,
    )
if __name__ == '__main__':
    # Optional sanity check of the problem definition (left disabled):
    #   prob = get_problem(50)
    #   print(prob._compute_objective(np.ones(50)))    # 39.99999999905752
    #   print(prob._compute_constraints(np.ones(50)))  # [0.09]
    #   exit()

    # Solve the 50-element problem with SLSQP and time the solve call.
    print('\tSLSQP \n\t-----')
    n_el = 50
    optimizer = SLSQP(
        get_problem(n_el),
        solver_options={'maxiter': 1000, 'ftol': 1e-9},
    )
    t_start = time.time()
    optimizer.solve()
    elapsed = time.time() - t_start

    # Report the outcome.
    results = optimizer.results
    print('\tTime:', elapsed)
    print('\tSuccess:', results['success'])
    print('\tOptimized vars:', results['x'])
    print('\tOptimized obj:', results['fun'])
    optimizer.print_results()

    # Plot the optimized design variables along the beam.
    import matplotlib.pyplot as plt
    _, ax = plt.subplots()
    ax.plot(results['x'])
    ax.set_xlabel('Lengthwise location')
    ax.set_ylabel('Optimized thickness')
    plt.show()

    # Regression check against the stored reference solution
    # (absolute tolerance only).
    expected_x = [
        0.14915754, 0.14764328, 0.14611321, 0.14456715, 0.14300421, 0.14142417,
        0.13982611, 0.13820976, 0.13657406, 0.13491866, 0.13324268, 0.13154528,
        0.12982575, 0.12808305, 0.12631658, 0.12452477, 0.12270701, 0.12086183,
        0.11898809, 0.11708424, 0.11514904, 0.11318072, 0.11117762, 0.10913764,
        0.10705891, 0.10493903, 0.10277539, 0.10056526, 0.09830546, 0.09599246,
        0.09362243, 0.09119084, 0.08869265, 0.08612198, 0.08347229, 0.08073573,
        0.07790323, 0.07496382, 0.07190453, 0.06870925, 0.0653583, 0.06182632,
        0.05808044, 0.05407658, 0.04975295, 0.0450185, 0.03972912, 0.03363155,
        0.02620192, 0.01610863,
    ]
    assert np.allclose(results['x'], expected_x, rtol=0, atol=1e-5)
# # METHOD 2: Create jitted Jax functions and derivatives, and
# # wrap them manually before passing to ProblemLite.
# from modopt import ProblemLite, SLSQP
# def get_problem(n_el): # 16 lines excluding comments, and returns.
# E, L, b, vol = E0, L0, b0, vol0
# L_el = L / n_el
# n_nodes = n_el + 1
# def jax_obj(x):
# # Moment of inertia
# I = b * x**3 / 12
# # Force vector
# F = np.zeros((n_nodes*2,))
# F[-2] = F0
# # Stiffness matrix
# c_el = E / L_el**3 * np.array([[12, 6*L_el, -12, 6*L_el],
# [6*L_el, 4*L_el**2, -6*L_el, 2*L_el**2],
# [-12, -6*L_el, 12, -6*L_el],
# [6*L_el, 2*L_el**2, -6*L_el, 4*L_el**2]])
# K = jnp.zeros((n_nodes*2, n_nodes*2))
# for i in range(n_el):
# K = K.at[2*i:2*i+4, 2*i:2*i+4].add(c_el * I[i])
# # K[2*i:2*i+4, 2*i:2*i+4] += c_el * I[i]
# # Displacement vector - solve for u in Ku = F
# # Apply boundary conditions: u[0] = u[1] = 0,
# # F[0:2] are unknown reaction forces at the left end. F[0:2] = K[0:2,2:].dot(u[2:])
# u = jnp.concatenate((np.array([0., 0.]), jnp.linalg.solve(K[2:,2:], F[2:])))
# # Compliance
# c = jnp.dot(F, u)
# return c
# def jax_con(x):
# return jnp.array([L_el * b * jnp.sum(x) - vol])
# _obj = jax.jit(jax_obj)
# _con = jax.jit(jax_con)
# _grad = jax.jit(jax.grad(jax_obj))
# _jac = jax.jit(jax.jacfwd(jax_con))
# obj = lambda x: np.float64(_obj(x))
# grad = lambda x: np.array(_grad(x))
# con = lambda x: np.array(_con(x))
# jac = lambda x: np.array(_jac(x))
# return ProblemLite(x0=np.ones(n_el), obj=obj, grad=grad, con=con, jac=jac,
# name=f'Cantilever beam {n_el} elements Jax',
# xl=1e-2, cl=0., cu=0.)
# if __name__ == '__main__':
# # SLSQP
# print('\tSLSQP \n\t-----')
# optimizer = SLSQP(get_problem(50), solver_options={'maxiter': 1000, 'ftol': 1e-9})
# start_time = time.time()
# optimizer.solve()
# opt_time = time.time() - start_time
# success = optimizer.results['success']
# print('\tTime:', opt_time)
# print('\tSuccess:', success)
# print('\tOptimized vars:', optimizer.results['x'])
# print('\tOptimized obj:', optimizer.results['fun'])
# optimizer.print_results()
# assert np.allclose(optimizer.results['x'],
# [0.14915754, 0.14764328, 0.14611321, 0.14456715, 0.14300421, 0.14142417,
# 0.13982611, 0.13820976, 0.13657406, 0.13491866, 0.13324268, 0.13154528,
# 0.12982575, 0.12808305, 0.12631658, 0.12452477, 0.12270701, 0.12086183,
# 0.11898809, 0.11708424, 0.11514904, 0.11318072, 0.11117762, 0.10913764,
# 0.10705891, 0.10493903, 0.10277539, 0.10056526, 0.09830546, 0.09599246,
# 0.09362243, 0.09119084, 0.08869265, 0.08612198, 0.08347229, 0.08073573,
# 0.07790323, 0.07496382, 0.07190453, 0.06870925, 0.0653583, 0.06182632,
# 0.05808044, 0.05407658, 0.04975295, 0.0450185, 0.03972912, 0.03363155,
# 0.02620192, 0.01610863], rtol=0, atol=1e-5)