Jax
Define a problem in Jax
This example does not intend to cover all the features of Jax.
For more details and tutorials on Jax, please refer to Jax’s documentation.
For more details on the JaxProblem class, please see the API Reference.
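In this example, we solve a constrained problem given by
\[
\min_{x_1,\, x_2} \; x_1^2 + x_2^2 \quad \text{subject to} \quad x_1 \geq 0, \quad x_1 + x_2 = 1, \quad x_1 - x_2 \geq 1.
\]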
The solution of this problem is known to be \(x_1=1\) and \(x_2=0\). For the purposes of this tutorial, however, we start from the initial guess \(x_1=500.0\) and \(x_2=5.0\).
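As a brief check of the stated solution, the following derivation eliminates \(x_2\) using the equality constraint. It is only a sketch for this particular problem (objective \(x_1^2 + x_2^2\) with constraints \(x_1 \geq 0\), \(x_1 + x_2 = 1\), \(x_1 - x_2 \geq 1\)), not a general solution method.
\[
\begin{aligned}
x_1 + x_2 = 1 \;&\Rightarrow\; x_2 = 1 - x_1, \\
x_1 - x_2 \geq 1 \;&\Rightarrow\; 2x_1 - 1 \geq 1 \;\Rightarrow\; x_1 \geq 1 \quad (\text{so } x_1 \geq 0 \text{ holds automatically}), \\
f(x_1) &= x_1^2 + (1 - x_1)^2 = 2x_1^2 - 2x_1 + 1, \qquad f'(x_1) = 4x_1 - 2 > 0 \;\text{ for } x_1 \geq 1 .
\end{aligned}
\]
Since \(f\) is increasing on the feasible set \(x_1 \geq 1\), the minimum is attained at \(x_1 = 1\), which gives \(x_2 = 0\) and an objective value of \(1\).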
The problem functions are written using Jax functions as follows:
import jax
import jax.numpy as jnp
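# Enable double precision (Jax uses 32-bit floats by default)
jax.config.update("jax_enable_x64", True)
# minimize x1^2 + x2^2 subject to x1 >= 0, x1 + x2 = 1, x1 - x2 >= 1,
# where x[0] and x[1] in the code correspond to x1 and x2 in the text.
jax_obj = lambda x: jnp.sum(x ** 2)  # objective: x1^2 + x2^2
jax_con = lambda x: jnp.array([x[0] + x[1], x[0] - x[1]])  # constraints: [x1 + x2, x1 - x2]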
modOpt will auto-generate and jit-compile the objective gradient, the constraint Jacobian,
and the objective Hessian, as well as the Lagrangian, its gradient, and its Hessian.
Users do not need to generate or jit-compile these functions
or their derivatives with Jax themselves and then wrap them.
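For readers curious about what such derivative functions look like, the sketch below builds them by hand with standard Jax transformations (`jax.grad`, `jax.jacfwd`, `jax.hessian`, `jax.jit`). This is not modOpt's internal implementation, and it is not needed when using JaxProblem; the Lagrangian sign convention and the multiplier values `lam` are assumptions made only for illustration.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

# Same problem functions as in the tutorial
jax_obj = lambda x: jnp.sum(x ** 2)
jax_con = lambda x: jnp.array([x[0] + x[1], x[0] - x[1]])

# Hand-built, jit-compiled derivatives (illustration only)
obj_grad = jax.jit(jax.grad(jax_obj))      # gradient of the objective
con_jac = jax.jit(jax.jacfwd(jax_con))     # Jacobian of the constraints
obj_hess = jax.jit(jax.hessian(jax_obj))   # Hessian of the objective

# Lagrangian with an assumed sign convention: L(x, lam) = f(x) + lam . c(x)
lagrangian = lambda x, lam: jax_obj(x) + jnp.dot(lam, jax_con(x))
lag_grad = jax.jit(jax.grad(lagrangian, argnums=0))

x0 = jnp.array([500.0, 5.0])
lam = jnp.array([1.0, 1.0])  # hypothetical multipliers

g0 = obj_grad(x0)        # [1000., 10.]
J0 = con_jac(x0)         # [[1., 1.], [1., -1.]]
H0 = obj_hess(x0)        # 2 * identity
gL0 = lag_grad(x0, lam)  # g0 + J0.T @ lam = [1002., 10.]
print(g0, J0, H0, gL0, sep="\n")
```

The gradient norm at the initial guess, \(\sqrt{1000^2 + 10^2} \approx 1000.05\), can be compared against the values reported by a derivative check.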
Once the problem functions are defined as Jax functions,
create a JaxProblem object for modOpt by passing the above functions
along with other problem constants, such as initial guesses,
variable bounds, and constraint bounds.
import numpy as np
import modopt as mo
prob = mo.JaxProblem(x0=np.array([500., 5.]), nc=2, jax_obj=jax_obj, jax_con=jax_con,
                     xl=np.array([0., -np.inf]), xu=np.array([np.inf, np.inf]),  # x1 >= 0, x2 unbounded
                     cl=np.array([1., 1.]), cu=np.array([1., np.inf]),  # x1 + x2 = 1, x1 - x2 >= 1
                     name='quadratic_jax', order=1)
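The bound arrays passed above encode both kinds of constraints: equal lower and upper bounds (`cl[0] == cu[0] == 1`) give an equality, while an infinite upper bound gives a one-sided inequality. The NumPy sketch below makes this concrete by testing points against the same bounds. The helper `is_feasible` and its tolerance are hypothetical and are not part of modOpt.

```python
import numpy as np

# Same bounds as passed to JaxProblem
xl = np.array([0.0, -np.inf])
xu = np.array([np.inf, np.inf])
cl = np.array([1.0, 1.0])
cu = np.array([1.0, np.inf])

def con(x):
    """Constraint functions [x1 + x2, x1 - x2]."""
    return np.array([x[0] + x[1], x[0] - x[1]])

def is_feasible(x, tol=1e-8):
    """Hypothetical helper: True if x satisfies xl <= x <= xu and cl <= con(x) <= cu."""
    c = con(x)
    in_var_bounds = np.all(x >= xl - tol) and np.all(x <= xu + tol)
    in_con_bounds = np.all(c >= cl - tol) and np.all(c <= cu + tol)
    return bool(in_var_bounds and in_con_bounds)

print(is_feasible(np.array([1.0, 0.0])))    # True: the known solution
print(is_feasible(np.array([500.0, 5.0])))  # False: x1 + x2 = 505 violates the equality
```

The second call shows that the initial guess is infeasible, so the optimizer has to restore feasibility as well as reduce the objective.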
Solve your problem using an optimizer
Once your problem model is wrapped for modOpt, import your preferred optimizer
from modOpt and solve it, following the standard procedure.
Here we will use the SLSQP optimizer from the SciPy library.
# Set up your preferred optimizer (SLSQP) with the Problem object
# Pass in the options for your chosen optimizer
optimizer = mo.SLSQP(prob, solver_options={'maxiter': 20})
# Check first derivatives at the initial guess, if needed
optimizer.check_first_derivatives(prob.x0)
# Solve your optimization problem
optimizer.solve()
# Print results of optimization
optimizer.print_results()
----------------------------------------------------------------------------
Derivative type | Calc norm  | FD norm    | Abs error norm | Rel error norm
----------------------------------------------------------------------------
Gradient        | 1.0000e+03 | 1.0000e+03 | 1.5473e-05     | 1.5472e-08
Jacobian        | 2.0000e+00 | 2.0000e+00 | 5.0495e-09     | 2.5248e-09
----------------------------------------------------------------------------
Solution from Scipy SLSQP:
----------------------------------------------------------------------------------------------------
Problem : quadratic_jax
Solver : scipy-slsqp
Success : True
Message : Optimization terminated successfully
Status : 0
Total time : 0.004587888717651367
Objective : 1.0000000068019972
Gradient norm : 2.000000006801997
Total function evals : 2
Total gradient evals : 2
Major iterations : 2
Total callbacks : 17
Reused callbacks : 0
obj callbacks : 5
grad callbacks : 3
hess callbacks : 0
con callbacks : 6
jac callbacks : 3
----------------------------------------------------------------------------------------------------
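The derivative check reported above compares the computed derivatives against finite-difference (FD) estimates. The NumPy sketch below shows the general idea using a forward difference. The step size `h`, the helper `fd_jacobian`, and the choice of norms are assumptions for illustration; modOpt's own check may use a different scheme, so the error values will not match the table exactly.

```python
import numpy as np

def obj(x):
    return np.sum(x ** 2)

def grad(x):
    return 2.0 * x  # analytic gradient of x1^2 + x2^2

def con(x):
    return np.array([x[0] + x[1], x[0] - x[1]])

def jac(x):
    return np.array([[1.0, 1.0], [1.0, -1.0]])  # analytic constraint Jacobian

def fd_jacobian(func, x, h=1e-6):
    """Hypothetical helper: forward-difference estimate of d func / d x."""
    f0 = np.atleast_1d(func(x))
    J = np.zeros((f0.size, x.size))
    for i in range(x.size):
        xp = x.copy()
        xp[i] += h  # perturb one variable at a time
        J[:, i] = (np.atleast_1d(func(xp)) - f0) / h
    return J

x0 = np.array([500.0, 5.0])
fd_grad = fd_jacobian(obj, x0).ravel()
fd_jac = fd_jacobian(con, x0)

# Absolute and relative error norms between analytic and FD derivatives
grad_abs_err = np.linalg.norm(grad(x0) - fd_grad)
grad_rel_err = grad_abs_err / np.linalg.norm(fd_grad)
jac_abs_err = np.linalg.norm(jac(x0) - fd_jac)
jac_rel_err = jac_abs_err / np.linalg.norm(fd_jac)

print(f"Gradient norm: {np.linalg.norm(grad(x0)):.4e}, rel. error: {grad_rel_err:.1e}")
print(f"Jacobian norm: {np.linalg.norm(jac(x0)):.4e}, rel. error: {jac_rel_err:.1e}")
```

The analytic norms here (about 1.0000e+03 for the gradient and 2.0000e+00 for the Jacobian in the Frobenius norm) agree with the "Calc norm" column, and small relative errors indicate that the derivatives are consistent with the functions.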
Scaling API
The code snippet below shows how to scale the design variables, objective, and constraints independently of their definitions.
Warning
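The results reported by the optimizer are always scaled, whereas the values computed by the models remain unscaled. For instance, with o_scaler=5. the objective printed in the results below is approximately 5 rather than the unscaled optimum of 1.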
prob = mo.JaxProblem(x0=np.array([500., 5.]), nc=2, jax_obj=jax_obj, jax_con=jax_con,
                     xl=np.array([0., -np.inf]), xu=np.array([np.inf, np.inf]),
                     cl=np.array([1., 1.]), cu=np.array([1., np.inf]),
                     x_scaler=2.,  # constant to scale all variables
                     # x_scaler=np.array([1., 2.]),  # scaler to scale each variable differently
                     o_scaler=5.,  # objective function scaler
                     # c_scaler=10.,  # constant to scale all constraints
                     c_scaler=np.array([10., 100.]),  # scaler to scale each constraint differently
                     name='quadratic_jax_scaled', order=1)
optimizer = mo.SLSQP(prob, solver_options={'maxiter': 20})
optimizer.solve()
optimizer.print_results()
Solution from Scipy SLSQP:
----------------------------------------------------------------------------------------------------
Problem : quadratic_jax_scaled
Solver : scipy-slsqp
Success : True
Message : Optimization terminated successfully
Status : 0
Total time : 0.0016129016876220703
Objective : 4.999999999999996
Gradient norm : 4.999999878155281
Total function evals : 3
Total gradient evals : 2
Major iterations : 2
Total callbacks : 11
Reused callbacks : 0
obj callbacks : 3
grad callbacks : 2
hess callbacks : 0
con callbacks : 4
jac callbacks : 2
----------------------------------------------------------------------------------------------------
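To relate the scaled results above to the unscaled problem, the NumPy sketch below assumes the convention "scaled value = scaler × unscaled value" for the variables, objective, and constraints. This convention is an assumption, although it is consistent with the printed objective (about 5) and gradient norm (about 5); please confirm the exact definition in the API Reference.

```python
import numpy as np

# Scalers used in the scaled problem above
x_scaler = 2.0
o_scaler = 5.0
c_scaler = np.array([10.0, 100.0])

# Unscaled (model) quantities at the known solution x = (1, 0)
x_opt = np.array([1.0, 0.0])
obj = np.sum(x_opt ** 2)                                      # 1.0
grad = 2.0 * x_opt                                            # [2., 0.]
con = np.array([x_opt[0] + x_opt[1], x_opt[0] - x_opt[1]])    # [1., 1.]

# Assumed convention: scaled = scaler * unscaled
x_scaled = x_scaler * x_opt            # [2., 0.]
obj_scaled = o_scaler * obj            # 5.0
con_scaled = c_scaler * con            # [10., 100.]
# Chain rule: d(o_scaler * f) / d(x_scaler * x) = (o_scaler / x_scaler) * df/dx
grad_scaled = o_scaler * grad / x_scaler   # [5., 0.]

# Recover the unscaled variables from the scaled ones
x_recovered = x_scaled / x_scaler

print(obj_scaled, np.linalg.norm(grad_scaled), x_recovered)
```

Under this assumption, dividing a reported objective by `o_scaler` and the reported variables by `x_scaler` gives the corresponding values of the original, unscaled model.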