Source code for modopt.external_libraries.jax.jax_problem

import numpy as np
from modopt import ProblemLite

[docs]class JaxProblem(ProblemLite): ''' Class that wraps **jittable** Jax functions for objective and constraints. Depending on the ``order`` specified, this class will automatically generate the functions for the objective gradient, constraint Jacobian, objective Hessian, Lagrangian, Lagrangian gradient, and Lagrangian Hessian. All functions will be turned into jitted functions and then wrapped for use with ``Optimizer`` subclasses. Vector products (HVP, JVP, VJP) are not supported. '''
[docs] def __init__(self, x0, nc=None, name='unnamed_problem', jax_obj=None, jax_con=None, xl=None, xu=None, cl=None, cu=None, x_scaler=1., o_scaler=1., c_scaler=1., grad_free=False, order=1): ''' Initialize the optimization problem with the given design variables, objective, and constraints. Derivatives are automatically generated using Jax. Parameters ---------- name : str, default='unnamed_problem' Problem name assigned by the user. x0 : np.ndarray Initial guess for design variables. nc : int Number of constraints. Used for determining jacfwd or jacrev. If None, jacrev is used. jax_obj : callable Objective function written in Jax. (must be jittable) Signature: jax_obj(x: jnp.array) -> jnp.float jax_con : callable Constraint function written in Jax. (must be jittable) Signature: jax_con(x: jnp.array) -> jnp.array xl : float or np.ndarray Lower bounds on design variables. xu : float or np.ndarray Upper bounds on design variables. cl : float or np.ndarray Lower bounds on constraints. cu : float or np.ndarray Upper bounds on constraints. x_scaler : float or np.ndarray Scaling factor for design variables. o_scaler : float Scaling factor for the objective function. c_scaler : float or np.ndarray Scaling factor for constraints. grad_free : bool, default=False Flag to indicate if the problem is gradient-free. If True, JaxProblem will not generate any derivatives. order : {1, 2}, default=1 Order of the problem if ``grad_free=False``. Used for determining up to which order of derivatives need to be generated. ''' nx = x0.size if x0.shape != (nx,): raise ValueError(f"Initial guess 'x0' must be a numpy 1d-array.") try: import jax except ImportError: raise ImportError("'jax' could not be imported. Install 'jax' using `pip install jax[cpu]` for using JaxProblem.") import jax.numpy as jnp jax.config.update("jax_enable_x64", True) # Determine if jacrev or jacfwd should be used jacrev = True if nc is not None: if not isinstance(nc, int) or nc < 0: raise ValueError("'nc' must be an integer greater than or equal to 0.") if nc >= nx: jacrev = False obj = None grad = None obj_hess = None # obj_hvp = None if jax_obj is not None: _obj = jax.jit(jax_obj) obj = lambda x: np.float64(_obj(x)) if not grad_free: if order == 1: _grad = jax.jit(jax.grad(jax_obj)) grad = lambda x: np.array(_grad(x)) elif order == 2: _grad = jax.jit(jax.grad(jax_obj)) grad = lambda x: np.array(_grad(x)) _obj_hess = jax.jit(jax.hessian(jax_obj)) obj_hess = lambda x: np.array(_obj_hess(x)) # def jax_obj_hvp(x, v): # return jax.grad(lambda x: jnp.vdot(jax.grad(jax_obj)(x), v))(x) # _obj_hvp = jax.jit(jax_obj_hvp) # obj_hvp = lambda x, v: np.array(_obj_hvp(x, v)) # # obj_hvp = jax.grad(lambda x: jnp.vdot(jax.grad(jax_obj)(x), v))(x) else: raise ValueError(f"Higher order derivatives are not supported. 'order' must be 1 or 2.") con = None jac = None lag = None lag_grad = None lag_hess = None # lag_hvp = None if jax_con is not None: _con = jax.jit(jax_con) con = lambda x: np.array(_con(x)) if not grad_free: if jacrev: _jac = jax.jit(jax.jacrev(jax_con)) # Note: jax.jacobian = jax.jacrev else: _jac = jax.jit(jax.jacfwd(jax_con)) jac = lambda x: np.array(_jac(x)) # Lagrangian functions if jax_obj is not None: jax_lag = lambda x, lam: jax_obj(x) + jnp.dot(jax_con(x), lam) else: jax_lag = lambda x, lam: jax.dot(jax_con(x), lam) _lag = jax.jit(jax_lag) lag = lambda x, lam: np.float64(_lag(x, lam)) if not grad_free: if order==1: _lag_grad = jax.jit(jax.grad(jax_lag)) lag_grad = lambda x, lam: np.array(_lag_grad(x, lam)) elif order==2: _lag_grad = jax.jit(jax.grad(jax_lag)) lag_grad = lambda x, lam: np.array(_lag_grad(x, lam)) _lag_hess = jax.jit(jax.hessian(jax_lag)) lag_hess = lambda x, lam: np.array(_lag_hess(x, lam)) # def jax_lag_hvp(x, lam, v): # return jax.grad(lambda x, lam: jnp.vdot(jax.grad(jax_lag)(x, lam), v))(x, lam) # _lag_hvp = jax.jit(jax_lag_hvp) # lag_hvp = lambda x, lam, v: np.array(_lag_hvp(x, lam, v)) else: raise ValueError(f"Higher order derivatives are not supported. 'order' must be 1 or 2.") super().__init__(x0, name=name, obj=obj, grad=grad, obj_hess=obj_hess, con=con, jac=jac, lag=lag, lag_grad=lag_grad, lag_hess=lag_hess, xl=xl, xu=xu, cl=cl, cu=cu, x_scaler=x_scaler, o_scaler=o_scaler, c_scaler=c_scaler, grad_free=grad_free)
# obj_hvp=obj_hvp, lag_hvp=lag_hvp)