{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Jax\n", "\n", "## Define a problem in Jax\n", "\n", "This example is not intended to cover all the features of Jax.\n", "For more details and tutorials on Jax, please refer to **[Jax's documentation](https://jax.readthedocs.io/)**.\n", "For more details on the `JaxProblem` class, please see the **[API Reference](../api.md)**.\n", "In this example, we solve a constrained problem given by\n", "\n", "$$\n", "\\underset{x_1, x_2 \\in \\mathbb{R}}{\\text{minimize}} \\quad x_1^2 + x_2^2\n", "\\newline\n", "\\text{subject to} \\quad x_1 \\geq 0\n", "\\newline\n", "\\quad \\quad \\quad \\quad x_1 + x_2 = 1\n", "\\newline\n", "\\quad \\quad \\quad \\quad x_1 - x_2 \\geq 1\n", "$$\n", "\n", "The solution of this problem is $x_1=1$ and $x_2=0$.\n", "However, for the purposes of this tutorial, we start from an initial guess of $x_1=500.0$ and $x_2=5.0$.\n", "\n", "The problem functions are written using Jax functions as follows:" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "import jax\n", "import jax.numpy as jnp\n", "jax.config.update(\"jax_enable_x64\", True)\n", "\n", "# minimize x^2 + y^2 subject to x>=0, x+y=1, x-y>=1, where x = x[0] and y = x[1].\n", "# x>=0 is a variable bound (passed later through `xl`), so jax_con returns only x+y and x-y.\n", "\n", "jax_obj = lambda x: jnp.sum(x ** 2)\n", "jax_con = lambda x: jnp.array([x[0] + x[1], x[0] - x[1]])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "modOpt auto-generates and jit-compiles the objective gradient, constraint Jacobian,\n", "and objective Hessian, as well as the Lagrangian, its gradient, and its Hessian.\n", "Users do not need to manually generate or jit-compile these functions\n", "or their derivatives using Jax and then wrap them.\n", "Once the problem functions are defined as Jax functions,\n", "create a `JaxProblem` object for modOpt by passing the above functions\n", "along with other problem constants, such as the initial guess,\n", "variable bounds, and constraint bounds." 
] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "import modopt as mo\n", "\n", "prob = mo.JaxProblem(x0=np.array([500., 5.]), nc=2, jax_obj=jax_obj, jax_con=jax_con,\n", " xl=np.array([0., -np.inf]), xu=np.array([np.inf, np.inf]),\n", " cl=np.array([1., 1.]), cu=np.array([1., np.inf]), \n", " name='quadratic_jax', order=1)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Solve your problem using an optimizer\n", "\n", "Once your problem model is wrapped for modOpt, import your preferred optimizer\n", "from modOpt and solve it, following the standard procedure.\n", "Here we will use the `SLSQP` optimizer from the SciPy library." ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "----------------------------------------------------------------------------\n", "Derivative type | Calc norm | FD norm | Abs error norm | Rel error norm \n", "----------------------------------------------------------------------------\n", "\n", "Gradient | 1.0000e+03 | 1.0000e+03 | 1.5473e-05 | 1.5472e-08 \n", "Jacobian | 2.0000e+00 | 2.0000e+00 | 5.0495e-09 | 2.5248e-09 \n", "----------------------------------------------------------------------------\n", "\n", "\n", "\tSolution from Scipy SLSQP:\n", "\t----------------------------------------------------------------------------------------------------\n", "\tProblem : quadratic_jax\n", "\tSolver : scipy-slsqp\n", "\tSuccess : True\n", "\tMessage : Optimization terminated successfully\n", "\tStatus : 0\n", "\tTotal time : 0.004587888717651367\n", "\tObjective : 1.0000000068019972\n", "\tGradient norm : 2.000000006801997\n", "\tTotal function evals : 2\n", "\tTotal gradient evals : 2\n", "\tMajor iterations : 2\n", "\tTotal callbacks : 17\n", "\tReused callbacks : 0\n", "\tobj callbacks : 5\n", "\tgrad callbacks : 3\n", "\thess callbacks : 0\n", "\tcon callbacks : 
6\n", "\tjac callbacks : 3\n", "\t----------------------------------------------------------------------------------------------------\n" ] } ], "source": [ "# Set up your preferred optimizer (SLSQP) with the Problem object\n", "# Pass in the options for your chosen optimizer\n", "optimizer = mo.SLSQP(prob, solver_options={'maxiter':20})\n", "\n", "# Check first derivatives at the initial guess, if needed\n", "optimizer.check_first_derivatives(prob.x0)\n", "\n", "# Solve your optimization problem\n", "optimizer.solve()\n", "\n", "# Print results of optimization\n", "optimizer.print_results()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Scaling API\n", "\n", "Please refer to the code snippet below as a guide for scaling\n", "the design variables, objective, and constraints independently\n", "of their definitions.\n", "```{warning}\n", "The results provided by the optimizer will always be scaled,\n", "while the values from the models will remain unscaled.\n", "For example, with `o_scaler=5.`, the optimizer reports an objective of about 5\n", "for this problem, whose unscaled optimal objective is 1.\n", "```" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "\tSolution from Scipy SLSQP:\n", "\t----------------------------------------------------------------------------------------------------\n", "\tProblem : quadratic_jax_scaled\n", "\tSolver : scipy-slsqp\n", "\tSuccess : True\n", "\tMessage : Optimization terminated successfully\n", "\tStatus : 0\n", "\tTotal time : 0.0016129016876220703\n", "\tObjective : 4.999999999999996\n", "\tGradient norm : 4.999999878155281\n", "\tTotal function evals : 3\n", "\tTotal gradient evals : 2\n", "\tMajor iterations : 2\n", "\tTotal callbacks : 11\n", "\tReused callbacks : 0\n", "\tobj callbacks : 3\n", "\tgrad callbacks : 2\n", "\thess callbacks : 0\n", "\tcon callbacks : 4\n", "\tjac callbacks : 2\n", "\t----------------------------------------------------------------------------------------------------\n" ] } ], "source": [ "prob = 
mo.JaxProblem(x0=np.array([500., 5.]), nc=2, jax_obj=jax_obj, jax_con=jax_con,\n", " xl=np.array([0., -np.inf]), xu=np.array([np.inf, np.inf]),\n", " cl=np.array([1., 1.]), cu=np.array([1., np.inf]), \n", " x_scaler=2., # constant to scale all variables\n", " # x_scaler=np.array([1., 2.]), # scaler to scale each variable differently\n", " o_scaler=5., # objective function scaler\n", " # c_scaler=10., # constant to scale all constraints\n", " c_scaler=np.array([10., 100.]), # scaler to scale each constraint differently\n", " name='quadratic_jax_scaled', order=1)\n", "\n", "optimizer = mo.SLSQP(prob, solver_options={'maxiter':20})\n", "optimizer.solve()\n", "optimizer.print_results()" ] } ], "metadata": { "kernelspec": { "display_name": "venv", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.13" } }, "nbformat": 4, "nbformat_minor": 2 }