# Cantilever beam optimization with Jax

```python
'''Cantilever beam optimization with Jax'''

import numpy as np
from modopt import JaxProblem, SLSQP
import time

import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

E0, L0, b0, vol0, F0 = 1., 1., 0.1, 0.01, -1.

# METHOD 1: Use Jax functions directly in mo.JaxProblem.
#           ModOpt will auto-generate the gradient, Jacobian, and objective Hessian.
#           ModOpt will also auto-generate the Lagrangian, its gradient, and Hessian.
#           No need to manually generate or jit functions or their derivatives and then wrap them.

def get_problem(n_el, order=1):
    """Build the cantilever-beam compliance problem as a ``JaxProblem``.

    The beam is split into ``n_el`` elements of equal length ``L0 / n_el``.
    The design vector ``x`` holds one value per element and enters the
    element moment of inertia as ``b0 * x**3 / 12``.  The objective is the
    compliance ``F . u``, where ``u`` solves ``K(x) u = F`` with the first
    two degrees of freedom fixed to zero.  The single equality constraint
    is ``L_el * b0 * sum(x) - vol0 == 0``.

    Parameters
    ----------
    n_el : int
        Number of beam elements (and of design variables).
    order : int, optional
        Forwarded unchanged to ``JaxProblem``.
        NOTE(review): presumably the highest derivative order ModOpt
        should generate -- confirm against the ModOpt documentation.

    Returns
    -------
    JaxProblem
        Problem named ``cantilever_{n_el}_jax`` with ``x0 = ones(n_el)``,
        lower bound ``1e-2`` on ``x`` and constraint bounds ``cl = cu = 0``.
    """
    youngs_modulus, beam_length, width, target_volume = E0, L0, b0, vol0
    elem_length = beam_length / n_el
    n_dof = 2 * (n_el + 1)  # two degrees of freedom per node

    # Element stiffness for a unit moment of inertia.  It does not depend
    # on x, so it is assembled once with NumPy and scaled per element later.
    unit_stiffness = youngs_modulus / elem_length**3 * np.array([
        [12, 6*elem_length, -12, 6*elem_length],
        [6*elem_length, 4*elem_length**2, -6*elem_length, 2*elem_length**2],
        [-12, -6*elem_length, 12, -6*elem_length],
        [6*elem_length, 2*elem_length**2, -6*elem_length, 4*elem_length**2],
    ])

    # Load vector: F0 on the second-to-last degree of freedom, zero elsewhere.
    load = np.zeros((n_dof,))
    load[-2] = F0

    def jax_obj(x):
        # Moment of inertia of every element.
        inertia = width * x**3 / 12

        # Assemble the global stiffness matrix; consecutive elements share
        # two degrees of freedom, hence the overlapping 4x4 windows.
        stiffness = jnp.zeros((n_dof, n_dof))
        for el in range(n_el):
            dofs = slice(2*el, 2*el + 4)
            stiffness = stiffness.at[dofs, dofs].add(unit_stiffness * inertia[el])

        # Boundary conditions: u[0] = u[1] = 0.  F[0:2] are the unknown
        # reaction forces at the left end, F[0:2] = K[0:2, 2:].dot(u[2:]),
        # so only the reduced system for the remaining unknowns is solved.
        free_disp = jnp.linalg.solve(stiffness[2:, 2:], load[2:])
        disp = jnp.concatenate((np.array([0., 0.]), free_disp))

        # Compliance
        return jnp.dot(load, disp)

    def jax_con(x):
        volume_residual = elem_length * width * jnp.sum(x) - target_volume
        return jnp.array([volume_residual])

    return JaxProblem(
        x0=np.ones(n_el),
        nc=1,
        jax_obj=jax_obj,
        jax_con=jax_con,
        name=f'cantilever_{n_el}_jax',
        order=order,
        xl=1e-2,
        cl=0.,
        cu=0.,
    )

if __name__ == '__main__':
    # Script entry point: solve the 50-element problem with SLSQP, report
    # and plot the result, then compare it against stored reference values.

    # # Test to see if the problem is correctly defined
    # prob = get_problem(50)
    # print(prob._compute_objective(np.ones(50))) # 39.99999999905752
    # print(prob._compute_constraints(np.ones(50))) # [0.09]
    # exit()

    # SLSQP
    print('\tSLSQP \n\t-----')
    n_el = 50
    optimizer = SLSQP(get_problem(n_el), solver_options={'maxiter': 1000, 'ftol': 1e-9})

    # perf_counter() is monotonic, so the measured duration cannot be
    # distorted by system clock adjustments the way time.time() can.
    start_time = time.perf_counter()
    optimizer.solve()
    opt_time = time.perf_counter() - start_time
    success = optimizer.results['success']

    print('\tTime:', opt_time)
    print('\tSuccess:', success)
    print('\tOptimized vars:', optimizer.results['x'])
    print('\tOptimized obj:', optimizer.results['fun'])

    optimizer.print_results()

    # matplotlib is imported here so that it is only required when the
    # file is run as a script.
    import matplotlib.pyplot as plt
    plt.figure()
    plt.plot(optimizer.results['x'])
    plt.xlabel('Lengthwise location')
    plt.ylabel('Optimized thickness')
    plt.show()

    # Regression check against the stored reference solution
    # (absolute tolerance only).
    assert np.allclose(optimizer.results['x'],
                       [0.14915754, 0.14764328, 0.14611321, 0.14456715, 0.14300421, 0.14142417,
                        0.13982611, 0.13820976, 0.13657406, 0.13491866, 0.13324268, 0.13154528,
                        0.12982575, 0.12808305, 0.12631658, 0.12452477, 0.12270701, 0.12086183,
                        0.11898809, 0.11708424, 0.11514904, 0.11318072, 0.11117762, 0.10913764,
                        0.10705891, 0.10493903, 0.10277539, 0.10056526, 0.09830546, 0.09599246,
                        0.09362243, 0.09119084, 0.08869265, 0.08612198, 0.08347229, 0.08073573,
                        0.07790323, 0.07496382, 0.07190453, 0.06870925, 0.0653583, 0.06182632,
                        0.05808044, 0.05407658, 0.04975295, 0.0450185, 0.03972912, 0.03363155,
                        0.02620192, 0.01610863], rtol=0, atol=1e-5)

# # METHOD 2: Create jitted Jax functions and derivatives, and
# #           wrap them manually before passing to ProblemLite.

# from modopt import ProblemLite, SLSQP

# def get_problem(n_el): # 16 lines excluding comments, and returns.
#     E, L, b, vol = E0, L0, b0, vol0
#     L_el = L / n_el
#     n_nodes = n_el + 1

#     def jax_obj(x):
#         # Moment of inertia
#         I = b * x**3 / 12
#         # Force vector
#         F = np.zeros((n_nodes*2,))
#         F[-2] = F0
#         # Stiffness matrix
#         c_el = E / L_el**3 * np.array([[12, 6*L_el, -12, 6*L_el],
#                                        [6*L_el, 4*L_el**2, -6*L_el, 2*L_el**2],
#                                        [-12, -6*L_el, 12, -6*L_el],
#                                        [6*L_el, 2*L_el**2, -6*L_el, 4*L_el**2]])
#         K = jnp.zeros((n_nodes*2, n_nodes*2))
#         for i in range(n_el):
#             K = K.at[2*i:2*i+4, 2*i:2*i+4].add(c_el * I[i])
#             # K[2*i:2*i+4, 2*i:2*i+4] += c_el * I[i]
#         # Displacement vector - solve for u in Ku = F
#         # Apply boundary conditions: u[0] = u[1] = 0,
#         # F[0:2] are unknown reaction forces at the left end. F[0:2] = K[0:2,2:].dot(u[2:])
#         u = jnp.concatenate((np.array([0., 0.]), jnp.linalg.solve(K[2:,2:], F[2:])))
#         # Compliance
#         c = jnp.dot(F, u)
#         return c

#     def jax_con(x):
#         return jnp.array([L_el * b * jnp.sum(x) - vol])

#     _obj = jax.jit(jax_obj)
#     _con = jax.jit(jax_con)
#     _grad = jax.jit(jax.grad(jax_obj))
#     _jac = jax.jit(jax.jacfwd(jax_con))

#     obj = lambda x: np.float64(_obj(x))
#     grad = lambda x: np.array(_grad(x))
#     con = lambda x: np.array(_con(x))
#     jac = lambda x: np.array(_jac(x))

#     return ProblemLite(x0=np.ones(n_el), obj=obj, grad=grad, con=con, jac=jac,
#                        name=f'Cantilever beam {n_el} elements Jax',
#                        xl=1e-2, cl=0., cu=0.)
# if __name__ == '__main__':

#     # SLSQP
#     print('\tSLSQP \n\t-----')
#     optimizer = SLSQP(get_problem(50), solver_options={'maxiter': 1000, 'ftol': 1e-9})
#     start_time = time.time()
#     optimizer.solve()
#     opt_time = time.time() - start_time
#     success = optimizer.results['success']

#     print('\tTime:', opt_time)
#     print('\tSuccess:', success)
#     print('\tOptimized vars:', optimizer.results['x'])
#     print('\tOptimized obj:', optimizer.results['fun'])

#     optimizer.print_results()

#     assert np.allclose(optimizer.results['x'],
#                        [0.14915754, 0.14764328, 0.14611321, 0.14456715, 0.14300421, 0.14142417,
#                         0.13982611, 0.13820976, 0.13657406, 0.13491866, 0.13324268, 0.13154528,
#                         0.12982575, 0.12808305, 0.12631658, 0.12452477, 0.12270701, 0.12086183,
#                         0.11898809, 0.11708424, 0.11514904, 0.11318072, 0.11117762, 0.10913764,
#                         0.10705891, 0.10493903, 0.10277539, 0.10056526, 0.09830546, 0.09599246,
#                         0.09362243, 0.09119084, 0.08869265, 0.08612198, 0.08347229, 0.08073573,
#                         0.07790323, 0.07496382, 0.07190453, 0.06870925, 0.0653583, 0.06182632,
#                         0.05808044, 0.05407658, 0.04975295, 0.0450185, 0.03972912, 0.03363155,
#                         0.02620192, 0.01610863], rtol=0, atol=1e-5)
```