modopt.JaxProblem

class modopt.JaxProblem(x0, nc=None, name='unnamed_problem', jax_obj=None, jax_con=None, xl=None, xu=None, cl=None, cu=None, x_scaler=1.0, o_scaler=1.0, c_scaler=1.0, grad_free=False, order=1)

Class that wraps jittable JAX functions for the objective and constraints. Depending on the specified order, this class automatically generates functions for the objective gradient, constraint Jacobian, objective Hessian, Lagrangian, Lagrangian gradient, and Lagrangian Hessian. All functions are jitted and then wrapped for use with Optimizer subclasses. Vector products (Hessian-vector, Jacobian-vector, and vector-Jacobian products: HVP, JVP, VJP) are not supported.
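The following is a minimal sketch of the kind of derivative generation described above, written with plain JAX transforms (`jax.grad`, `jax.jacrev`, `jax.hessian`, `jax.jit`). It is not modopt's implementation. The example functions and the Lagrangian sign convention L(x, z) = f(x) + z^T c(x) are assumptions chosen for illustration.

```python
import jax
import jax.numpy as jnp

def obj(x):
    # Example objective (illustrative)
    return jnp.sum(x ** 4)

def con(x):
    # Two linear constraints (illustrative)
    return jnp.array([x[0] + x[1], x[0] - x[1]])

def lag(x, z):
    # Lagrangian with multipliers z; the sign convention f + z^T c is an assumption
    return obj(x) + jnp.dot(z, con(x))

# Each derivative comes from a JAX transform and is then jitted
obj_grad = jax.jit(jax.grad(obj))                 # objective gradient
con_jac = jax.jit(jax.jacrev(con))                # constraint Jacobian
obj_hess = jax.jit(jax.hessian(obj))              # objective Hessian
lag_grad = jax.jit(jax.grad(lag, argnums=0))      # gradient of L with respect to x
lag_hess = jax.jit(jax.hessian(lag, argnums=0))   # Hessian of L with respect to x

x = jnp.array([1.0, 2.0])
z = jnp.array([0.5, -1.0])
print(lag_grad(x, z))  # equals obj_grad(x) + con_jac(x).T @ z
```

Because the constraints here are linear, the Lagrangian Hessian coincides with the objective Hessian.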

__init__(x0, nc=None, name='unnamed_problem', jax_obj=None, jax_con=None, xl=None, xu=None, cl=None, cu=None, x_scaler=1.0, o_scaler=1.0, c_scaler=1.0, grad_free=False, order=1)

Initialize the optimization problem with the given design variables, objective, and constraints. Derivatives are generated automatically using JAX.

Parameters
name : str, default='unnamed_problem'

Problem name assigned by the user.

x0 : np.ndarray

Initial guess for design variables.

nc : int, optional

Number of constraints. Used to choose between jax.jacfwd and jax.jacrev when computing derivatives. If None, jacrev is used.
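The number of constraints matters for the following reason. For c: R^n to R^nc, `jax.jacfwd` builds the Jacobian from roughly n forward-mode passes (one per input). `jax.jacrev` uses roughly nc reverse-mode passes (one per output). Forward mode therefore tends to win for tall Jacobians (nc >= n), and reverse mode for wide ones. The sketch below encodes that heuristic in a hypothetical helper, `choose_jacobian`. Only the fallback to jacrev when nc is None comes from the text above; the exact threshold JaxProblem uses is an assumption.

```python
import jax
import jax.numpy as jnp

def choose_jacobian(fun, nx, nc=None):
    """Hypothetical helper: pick a Jacobian transform from problem dimensions.

    Mirrors the documented default (jacrev when nc is None); the
    nc >= nx threshold for jacfwd is an assumed heuristic.
    """
    if nc is not None and nc >= nx:
        return jax.jacfwd(fun)  # tall Jacobian: one forward pass per input
    return jax.jacrev(fun)      # wide Jacobian or nc unknown: one reverse pass per output

def con(x):
    # 3 constraints on 2 variables, giving a tall 3 x 2 Jacobian
    return jnp.array([x[0] * x[1], x[0] ** 2, jnp.sin(x[1])])

x = jnp.array([1.0, 0.0])
jac_fwd = choose_jacobian(con, nx=2, nc=3)
jac_rev = choose_jacobian(con, nx=2)  # nc=None falls back to jacrev
print(jac_fwd(x).shape)  # (3, 2)
```

Both transforms return the same matrix. The choice affects only cost, not the result.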

jax_obj : callable

Objective function written in JAX; must be jittable. Signature: jax_obj(x: jnp.ndarray) -> jnp.ndarray (a scalar, i.e., a 0-d array).

jax_con : callable

Constraint function written in JAX; must be jittable. Signature: jax_con(x: jnp.ndarray) -> jnp.ndarray (a 1-D array with one entry per constraint).
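Here is a minimal sketch of functions that satisfy these signatures. Each uses only `jax.numpy` operations on a 1-D array. Neither uses a Python `if` on array values, so `jax.jit` can trace both. The specific functions are illustrative, and nc = 2 is assumed for the constraint function.

```python
import jax
import jax.numpy as jnp

def jax_obj(x):
    # Scalar objective: returns a 0-d array
    return jnp.sum((x - 1.0) ** 2)

def jax_con(x):
    # Vector of constraints, one entry per constraint (here nc = 2).
    # jnp.where replaces a Python `if` on x, which would fail to trace under jit.
    return jnp.array([x[0] + x[1], jnp.where(x[0] > 0.0, x[0] ** 2, 0.0)])

# Check that both functions are jittable and return the expected shapes
x = jnp.array([2.0, 3.0])
f = jax.jit(jax_obj)(x)
c = jax.jit(jax_con)(x)
print(f.shape, c.shape)  # () (2,)
```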

xl : float or np.ndarray, optional

Lower bounds on design variables.

xu : float or np.ndarray, optional

Upper bounds on design variables.

cl : float or np.ndarray, optional

Lower bounds on constraints.

cu : float or np.ndarray, optional

Upper bounds on constraints.
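The NumPy sketch below sets up bound arrays for a 3-variable, 2-constraint problem. Three conventions in it are assumptions not stated above:

- `np.inf` or `-np.inf` marks an unbounded side.
- Setting `cl[i] == cu[i]` expresses an equality constraint.
- A float bound applies to every entry. The sketch makes this explicit with `np.full`.

```python
import numpy as np

nx, nc = 3, 2

# Design-variable bounds: x[0] >= 0, x[1] in [-1, 1], x[2] unbounded (assumed np.inf convention)
xl = np.array([0.0, -1.0, -np.inf])
xu = np.array([np.inf, 1.0, np.inf])

# Constraint bounds: c[0] == 1 (equality via cl == cu, assumed), c[1] <= 5
cl = np.array([1.0, -np.inf])
cu = np.array([1.0, 5.0])

# A scalar bound such as xl=0.0 is assumed to mean the same value for every variable
xl_scalar = np.full(nx, 0.0)

eq_mask = cl == cu  # equality constraints under the assumed convention
print(eq_mask)  # [ True False]
```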

x_scaler : float or np.ndarray, default=1.0

Scaling factor for design variables.

o_scaler : float, default=1.0

Scaling factor for the objective function.

c_scaler : float or np.ndarray, default=1.0

Scaling factor for constraints.
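The sketch below shows how multiplicative scaling interacts with derivatives. It assumes the convention x_tilde = x_scaler * x and f_tilde = o_scaler * f, with constraints scaled the same way by c_scaler. The direction of scaling is an assumption; the text above only calls these "scaling factors". By the chain rule, the gradient with respect to the scaled variables is o_scaler * grad f(x) / x_scaler, elementwise. The JAX code checks this numerically.

```python
import jax
import jax.numpy as jnp

x_scaler = jnp.array([10.0, 0.5])  # per-variable factors (illustrative values)
o_scaler = 2.0

def f(x):
    # Unscaled objective
    return x[0] ** 2 + 3.0 * x[1]

def f_scaled(x_tilde):
    # Scaled objective in scaled variables x_tilde = x_scaler * x (assumed convention)
    return o_scaler * f(x_tilde / x_scaler)

x = jnp.array([1.0, 2.0])
x_tilde = x_scaler * x

g_scaled = jax.grad(f_scaled)(x_tilde)
g_chain = o_scaler * jax.grad(f)(x) / x_scaler  # chain-rule prediction
print(bool(jnp.allclose(g_scaled, g_chain)))  # True
```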

grad_free : bool, default=False

Flag to indicate if the problem is gradient-free. If True, JaxProblem will not generate any derivatives.

order : {1, 2}, default=1

Highest order of derivatives to generate when grad_free=False. For example, order=2 additionally generates the Hessians.
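The sketch below shows one way `grad_free` and `order` could gate which functions are built, using a hypothetical helper, `build_functions`. The split (first derivatives for order=1, Hessians added for order=2) is an assumed reading of the parameter descriptions, not modopt's code. Lagrangian functions are omitted for brevity.

```python
import jax
import jax.numpy as jnp

def build_functions(obj, con, grad_free=False, order=1):
    """Hypothetical helper: generate jitted functions according to grad_free/order.

    The order=1 / order=2 split is an assumption, not modopt's implementation.
    """
    funcs = {"obj": jax.jit(obj), "con": jax.jit(con)}
    if grad_free:
        return funcs  # no derivatives at all
    funcs["grad"] = jax.jit(jax.grad(obj))
    funcs["jac"] = jax.jit(jax.jacrev(con))
    if order >= 2:
        # jax.hessian composes forward-over-reverse differentiation
        funcs["obj_hess"] = jax.jit(jax.hessian(obj))
    return funcs

obj = lambda x: jnp.sum(x ** 3)
con = lambda x: jnp.array([x[0] * x[1]])

print(sorted(build_functions(obj, con, grad_free=True)))  # ['con', 'obj']
print(sorted(build_functions(obj, con, order=2)))         # ['con', 'grad', 'jac', 'obj', 'obj_hess']
```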

__str__()

Return a string describing the details of the UNSCALED optimization problem. This string is what appears when the problem is passed to print().