modopt.JaxProblem
- class modopt.JaxProblem(x0, nc=None, name='unnamed_problem', jax_obj=None, jax_con=None, xl=None, xu=None, cl=None, cu=None, x_scaler=1.0, o_scaler=1.0, c_scaler=1.0, grad_free=False, order=1)
Class that wraps jittable JAX functions for the objective and constraints. Depending on the order argument, this class automatically generates functions for the objective gradient, constraint Jacobian, objective Hessian, Lagrangian, Lagrangian gradient, and Lagrangian Hessian. All functions are jitted and then wrapped for use with Optimizer subclasses. Vector products (HVP, JVP, VJP) are not supported.
- __init__(x0, nc=None, name='unnamed_problem', jax_obj=None, jax_con=None, xl=None, xu=None, cl=None, cu=None, x_scaler=1.0, o_scaler=1.0, c_scaler=1.0, grad_free=False, order=1)
Initialize the optimization problem with the given design variables, objective, and constraints. Derivatives are generated automatically using JAX.
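The derivatives named in the class description can all be produced with JAX's standard transformations (jax.grad, jax.jacrev, jax.hessian, jax.jit). The following is a minimal sketch of that idea, not JaxProblem's actual implementation: the helper make_derivative_funcs, its dictionary keys, the example objective and constraints, and the Lagrangian convention L(x, λ) = f(x) + λᵀc(x) are assumptions for illustration.

```python
import jax
import jax.numpy as jnp


def jax_obj(x):
    # Example objective: f(x) = x0^2 + x1^2 (returns a scalar).
    return jnp.sum(x ** 2)


def jax_con(x):
    # Example constraints: c(x) = [x0 + x1, x0 * x1] (returns a 1-D array).
    return jnp.stack([x[0] + x[1], x[0] * x[1]])


def make_derivative_funcs(obj, con):
    """Hypothetical helper: jitted derivative functions built with JAX autodiff."""

    # Assumed Lagrangian convention: L(x, lam) = f(x) + lam . c(x).
    def lag(x, lam):
        return obj(x) + jnp.dot(lam, con(x))

    return {
        "obj": jax.jit(obj),
        "obj_grad": jax.jit(jax.grad(obj)),
        "con_jac": jax.jit(jax.jacrev(con)),
        "obj_hess": jax.jit(jax.hessian(obj)),
        "lag": jax.jit(lag),
        # argnums=0: differentiate with respect to x only, not the multipliers.
        "lag_grad": jax.jit(jax.grad(lag, argnums=0)),
        "lag_hess": jax.jit(jax.hessian(lag, argnums=0)),
    }


funcs = make_derivative_funcs(jax_obj, jax_con)
x = jnp.array([1.0, 2.0])
lam = jnp.array([0.5, -1.0])
g = funcs["obj_grad"](x)           # [2., 4.]
J = funcs["con_jac"](x)            # [[1., 1.], [2., 1.]]
H_lag = funcs["lag_hess"](x, lam)  # [[2., -1.], [-1., 2.]]
```

Every function is traced and compiled once per input shape by jax.jit, which is why the wrapped functions must be jittable.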
- Parameters
- name : str, default='unnamed_problem'
Problem name assigned by the user.
- x0 : np.ndarray
Initial guess for design variables.
- nc : int, default=None
Number of constraints. Used to choose between jax.jacfwd and jax.jacrev when generating the constraint Jacobian. If None, jax.jacrev is used.
- jax_obj : callable
Objective function written in JAX; must be jittable. Signature: jax_obj(x: jnp.ndarray) -> jnp.ndarray, where the output is a scalar.
- jax_con : callable
Constraint function written in JAX; must be jittable. Signature: jax_con(x: jnp.ndarray) -> jnp.ndarray, where the output is the array of constraint values.
- xl : float or np.ndarray
Lower bounds on design variables.
- xu : float or np.ndarray
Upper bounds on design variables.
- cl : float or np.ndarray
Lower bounds on constraints.
- cu : float or np.ndarray
Upper bounds on constraints.
- x_scaler : float or np.ndarray, default=1.0
Scaling factor for design variables.
- o_scaler : float, default=1.0
Scaling factor for the objective function.
- c_scaler : float or np.ndarray, default=1.0
Scaling factor for constraints.
- grad_free : bool, default=False
Flag indicating whether the problem is gradient-free. If True, JaxProblem does not generate any derivatives.
- order : {1, 2}, default=1
Used only when grad_free=False. Sets the highest order of derivatives to generate: 1 for first derivatives only, 2 for first and second derivatives.
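The nc, grad_free, and order options can be pictured with the sketch below. It rests on the usual forward/reverse-mode trade-off: jax.jacrev needs about one reverse pass per output and jax.jacfwd about one forward pass per input, so reverse mode suits wide Jacobians and forward mode suits tall ones. The exact rule JaxProblem applies is not documented here; the helpers jacobian_mode and build_funcs, the nc > n threshold, and the choice of which functions order=2 adds are hypothetical. Here n stands for the number of design variables (the length of x0).

```python
import jax
import jax.numpy as jnp


def jacobian_mode(n, nc=None):
    """Hypothetical rule for choosing the autodiff mode of the constraint Jacobian.

    jacrev needs roughly one reverse pass per constraint (row) and jacfwd roughly
    one forward pass per design variable (column). Assumption: use jacfwd only for
    tall Jacobians (nc > n); otherwise, including nc=None, use jacrev.
    """
    return "jacfwd" if nc is not None and nc > n else "jacrev"


def build_funcs(obj, con, n, nc=None, grad_free=False, order=1):
    """Hypothetical helper mirroring the nc, grad_free and order options."""
    if order not in (1, 2):
        raise ValueError("order must be 1 or 2")
    funcs = {"obj": jax.jit(obj), "con": jax.jit(con)}
    if grad_free:
        return funcs  # gradient-free: no derivatives are generated
    jac = jax.jacfwd if jacobian_mode(n, nc) == "jacfwd" else jax.jacrev
    funcs["obj_grad"] = jax.jit(jax.grad(obj))
    funcs["con_jac"] = jax.jit(jac(con))
    if order == 2:
        funcs["obj_hess"] = jax.jit(jax.hessian(obj))
    return funcs


def obj(x):
    return jnp.sum(x ** 2)


def con(x):
    # 5 constraints on 2 design variables: a tall 5 x 2 Jacobian.
    return jnp.stack([x[0], x[1], x[0] * x[1], x[0] ** 2, jnp.sin(x[1])])


funcs = build_funcs(obj, con, n=2, nc=5, order=2)
x = jnp.array([1.0, 2.0])
J = funcs["con_jac"](x)   # shape (5, 2), built with jacfwd under the assumed rule
H = funcs["obj_hess"](x)  # 2 * identity for this objective
```

With nc left as None, the same Jacobian would be built with jacrev instead; both modes return identical values and differ only in cost.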
- __str__()
Return a description of the UNSCALED optimization problem, used when the problem is printed.