{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "ryqPvTKI19zH" }, "source": [ "# Summary tutorial: Getting metrics out of your models\n", "\n", "The goal of the `learned_optimization.summary` module is to let researchers seamlessly annotate and extract data from *within* a JAX computation or machine learning model.\n", "This could be anything from the mean and standard deviation of an activation to the distribution of a model's outputs.\n", "\n", "Doing this in JAX can be challenging, as JAX code often makes heavy use of function transformations (such as `jit`, `vmap`, and `grad`), which makes it difficult to reach in and look at an intermediate value.\n", "\n", "This notebook introduces the `learned_optimization.summary` module, which provides one solution to this problem." ] }, { "cell_type": "markdown", "metadata": { "id": "aJfX-Mda59NC" }, "source": [ "## Deps" ] }, { "cell_type": "markdown", "metadata": { "id": "NGqJCT0MwlvI" }, "source": [ "In addition to `learned_optimization`, the summary module requires `oryx`. This can be a bit finicky to install at the moment, as it relies on particular versions of TensorFlow (even though we never use those pieces).\n", "All of `learned_optimization` runs without `oryx`, but `oryx` must be installed to extract summaries.\n", "In Colab this is even more cumbersome: we must first upgrade some preinstalled packages, restart the Colab runtime, and then run the remaining cells." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 1000 }, "id": "dLD5VE3EACSI", "outputId": "fe7f4bea-c9d1-4b0c-91e5-75d61a60a11f" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Collecting git+https://github.com/google/learned_optimization.git\n", " Cloning https://github.com/google/learned_optimization.git to /tmp/pip-req-build-vyc4pif3\n", " Running command git clone -q https://github.com/google/learned_optimization.git /tmp/pip-req-build-vyc4pif3\n", "Collecting oryx\n", " Downloading oryx-0.2.2-py3-none-any.whl (208 kB)\n", "\u001b[K |████████████████████████████████| 208 kB 5.2 MB/s \n", "\u001b[?25hCollecting tensorflow==2.8.0rc0\n", " Downloading tensorflow-2.8.0rc0-cp37-cp37m-manylinux2010_x86_64.whl (492.1 MB)\n", "\u001b[K |████████████████████████████████| 492.1 MB 15 kB/s \n", "\u001b[?25hRequirement already satisfied: numpy in /usr/local/lib/python3.7/dist-packages (1.19.5)\n", "Collecting numpy\n", " Downloading numpy-1.21.5-cp37-cp37m-manylinux_2_12_x86_64.manylinux2010_x86_64.whl (15.7 MB)\n", "\u001b[K |████████████████████████████████| 15.7 MB 40.3 MB/s \n", "\u001b[?25hRequirement already satisfied: absl-py==0.12.0 in /usr/local/lib/python3.7/dist-packages (from learned-optimization==0.0.1) (0.12.0)\n", "Requirement already satisfied: jax>=0.2.6 in /usr/local/lib/python3.7/dist-packages (from learned-optimization==0.0.1) (0.2.25)\n", "Requirement already satisfied: jaxlib>=0.1.68 in /usr/local/lib/python3.7/dist-packages (from learned-optimization==0.0.1) (0.1.71+cuda111)\n", "Collecting nose\n", " Downloading nose-1.3.7-py3-none-any.whl (154 kB)\n", "\u001b[K |████████████████████████████████| 154 kB 25.0 MB/s \n", "\u001b[?25hCollecting dm-launchpad-nightly==0.3.0.dev20211105\n", " Downloading dm_launchpad_nightly-0.3.0.dev20211105-cp37-cp37m-manylinux2010_x86_64.whl (3.8 MB)\n", "\u001b[K |████████████████████████████████| 3.8 MB 33.9 MB/s 
\n", "\u001b[?25hRequirement already satisfied: tqdm>=4.62.3 in /usr/local/lib/python3.7/dist-packages (from learned-optimization==0.0.1) (4.62.3)\n", "Collecting flax==0.3.3\n", " Downloading flax-0.3.3-py3-none-any.whl (179 kB)\n", "\u001b[K |████████████████████████████████| 179 kB 65.4 MB/s \n", "\u001b[?25hCollecting dm-haiku==0.0.5\n", " Downloading dm_haiku-0.0.5-py3-none-any.whl (287 kB)\n", "\u001b[K |████████████████████████████████| 287 kB 71.0 MB/s \n", "\u001b[?25hCollecting optax>=0.0.9\n", " Downloading optax-0.1.0-py3-none-any.whl (126 kB)\n", "\u001b[K |████████████████████████████████| 126 kB 73.8 MB/s \n", "\u001b[?25hCollecting tensorflow-datasets>=4.4.0\n", " Downloading tensorflow_datasets-4.4.0-py3-none-any.whl (4.0 MB)\n", "\u001b[K |████████████████████████████████| 4.0 MB 56.5 MB/s \n", "\u001b[?25hRequirement already satisfied: tensorflow-metadata==1.5.0 in /usr/local/lib/python3.7/dist-packages (from learned-optimization==0.0.1) (1.5.0)\n", "Requirement already satisfied: tensorboard>=2.7.0 in /usr/local/lib/python3.7/dist-packages (from learned-optimization==0.0.1) (2.7.0)\n", "Requirement already satisfied: gin-config>=0.5.0 in /usr/local/lib/python3.7/dist-packages (from learned-optimization==0.0.1) (0.5.0)\n", "Requirement already satisfied: astunparse>=1.6.0 in /usr/local/lib/python3.7/dist-packages (from tensorflow==2.8.0rc0) (1.6.3)\n", "Requirement already satisfied: opt-einsum>=2.3.2 in /usr/local/lib/python3.7/dist-packages (from tensorflow==2.8.0rc0) (3.3.0)\n", "Requirement already satisfied: typing-extensions>=3.6.6 in /usr/local/lib/python3.7/dist-packages (from tensorflow==2.8.0rc0) (3.10.0.2)\n", "Requirement already satisfied: google-pasta>=0.1.1 in /usr/local/lib/python3.7/dist-packages (from tensorflow==2.8.0rc0) (0.2.0)\n", "Requirement already satisfied: gast>=0.2.1 in /usr/local/lib/python3.7/dist-packages (from tensorflow==2.8.0rc0) (0.4.0)\n", "Requirement already satisfied: termcolor>=1.1.0 in 
/usr/local/lib/python3.7/dist-packages (from tensorflow==2.8.0rc0) (1.1.0)\n", "Requirement already satisfied: wrapt>=1.11.0 in /usr/local/lib/python3.7/dist-packages (from tensorflow==2.8.0rc0) (1.13.3)\n", "Requirement already satisfied: grpcio<2.0,>=1.24.3 in /usr/local/lib/python3.7/dist-packages (from tensorflow==2.8.0rc0) (1.43.0)\n", "Requirement already satisfied: setuptools<60 in /usr/local/lib/python3.7/dist-packages (from tensorflow==2.8.0rc0) (57.4.0)\n", "Collecting keras<2.9,>=2.8.0rc0\n", " Downloading keras-2.8.0rc1-py2.py3-none-any.whl (1.4 MB)\n", "\u001b[K |████████████████████████████████| 1.4 MB 46.6 MB/s \n", "\u001b[?25hRequirement already satisfied: six>=1.12.0 in /usr/local/lib/python3.7/dist-packages (from tensorflow==2.8.0rc0) (1.15.0)\n", "Requirement already satisfied: protobuf>=3.9.2 in /usr/local/lib/python3.7/dist-packages (from tensorflow==2.8.0rc0) (3.17.3)\n", "Requirement already satisfied: flatbuffers>=1.12 in /usr/local/lib/python3.7/dist-packages (from tensorflow==2.8.0rc0) (2.0)\n", "Requirement already satisfied: libclang>=9.0.1 in /usr/local/lib/python3.7/dist-packages (from tensorflow==2.8.0rc0) (12.0.0)\n", "Requirement already satisfied: h5py>=2.9.0 in /usr/local/lib/python3.7/dist-packages (from tensorflow==2.8.0rc0) (3.1.0)\n", "Collecting tf-estimator-nightly==2.8.0.dev2021122109\n", " Downloading tf_estimator_nightly-2.8.0.dev2021122109-py2.py3-none-any.whl (462 kB)\n", "\u001b[K |████████████████████████████████| 462 kB 73.4 MB/s \n", "\u001b[?25hRequirement already satisfied: keras-preprocessing>=1.1.1 in /usr/local/lib/python3.7/dist-packages (from tensorflow==2.8.0rc0) (1.1.2)\n", "Requirement already satisfied: tensorflow-io-gcs-filesystem>=0.23.1 in /usr/local/lib/python3.7/dist-packages (from tensorflow==2.8.0rc0) (0.23.1)\n", "Collecting jmp>=0.0.2\n", " Downloading jmp-0.0.2-py3-none-any.whl (16 kB)\n", "Requirement already satisfied: tabulate>=0.8.9 in /usr/local/lib/python3.7/dist-packages (from 
dm-haiku==0.0.5->learned-optimization==0.0.1) (0.8.9)\n", "Requirement already satisfied: dm-tree in /usr/local/lib/python3.7/dist-packages (from dm-launchpad-nightly==0.3.0.dev20211105->learned-optimization==0.0.1) (0.1.6)\n", "Collecting mock\n", " Downloading mock-4.0.3-py3-none-any.whl (28 kB)\n", "Requirement already satisfied: cloudpickle in /usr/local/lib/python3.7/dist-packages (from dm-launchpad-nightly==0.3.0.dev20211105->learned-optimization==0.0.1) (1.3.0)\n", "Requirement already satisfied: portpicker in /usr/local/lib/python3.7/dist-packages (from dm-launchpad-nightly==0.3.0.dev20211105->learned-optimization==0.0.1) (1.3.9)\n", "Requirement already satisfied: psutil in /usr/local/lib/python3.7/dist-packages (from dm-launchpad-nightly==0.3.0.dev20211105->learned-optimization==0.0.1) (5.4.8)\n", "Requirement already satisfied: msgpack in /usr/local/lib/python3.7/dist-packages (from flax==0.3.3->learned-optimization==0.0.1) (1.0.3)\n", "Requirement already satisfied: matplotlib in /usr/local/lib/python3.7/dist-packages (from flax==0.3.3->learned-optimization==0.0.1) (3.2.2)\n", "Requirement already satisfied: googleapis-common-protos<2,>=1.52.0 in /usr/local/lib/python3.7/dist-packages (from tensorflow-metadata==1.5.0->learned-optimization==0.0.1) (1.54.0)\n", "Requirement already satisfied: wheel<1.0,>=0.23.0 in /usr/local/lib/python3.7/dist-packages (from astunparse>=1.6.0->tensorflow==2.8.0rc0) (0.37.1)\n", "Requirement already satisfied: cached-property in /usr/local/lib/python3.7/dist-packages (from h5py>=2.9.0->tensorflow==2.8.0rc0) (1.5.2)\n", "Requirement already satisfied: scipy>=1.2.1 in /usr/local/lib/python3.7/dist-packages (from jax>=0.2.6->learned-optimization==0.0.1) (1.4.1)\n", "Collecting chex>=0.0.4\n", " Downloading chex-0.1.0-py3-none-any.whl (65 kB)\n", "\u001b[K |████████████████████████████████| 65 kB 3.0 MB/s \n", "\u001b[?25hRequirement already satisfied: toolz>=0.9.0 in /usr/local/lib/python3.7/dist-packages (from 
chex>=0.0.4->optax>=0.0.9->learned-optimization==0.0.1) (0.11.2)\n", "Requirement already satisfied: werkzeug>=0.11.15 in /usr/local/lib/python3.7/dist-packages (from tensorboard>=2.7.0->learned-optimization==0.0.1) (1.0.1)\n", "Requirement already satisfied: google-auth<3,>=1.6.3 in /usr/local/lib/python3.7/dist-packages (from tensorboard>=2.7.0->learned-optimization==0.0.1) (1.35.0)\n", "Requirement already satisfied: google-auth-oauthlib<0.5,>=0.4.1 in /usr/local/lib/python3.7/dist-packages (from tensorboard>=2.7.0->learned-optimization==0.0.1) (0.4.6)\n", "Requirement already satisfied: tensorboard-data-server<0.7.0,>=0.6.0 in /usr/local/lib/python3.7/dist-packages (from tensorboard>=2.7.0->learned-optimization==0.0.1) (0.6.1)\n", "Requirement already satisfied: tensorboard-plugin-wit>=1.6.0 in /usr/local/lib/python3.7/dist-packages (from tensorboard>=2.7.0->learned-optimization==0.0.1) (1.8.1)\n", "Requirement already satisfied: markdown>=2.6.8 in /usr/local/lib/python3.7/dist-packages (from tensorboard>=2.7.0->learned-optimization==0.0.1) (3.3.6)\n", "Requirement already satisfied: requests<3,>=2.21.0 in /usr/local/lib/python3.7/dist-packages (from tensorboard>=2.7.0->learned-optimization==0.0.1) (2.23.0)\n", "Requirement already satisfied: pyasn1-modules>=0.2.1 in /usr/local/lib/python3.7/dist-packages (from google-auth<3,>=1.6.3->tensorboard>=2.7.0->learned-optimization==0.0.1) (0.2.8)\n", "Requirement already satisfied: cachetools<5.0,>=2.0.0 in /usr/local/lib/python3.7/dist-packages (from google-auth<3,>=1.6.3->tensorboard>=2.7.0->learned-optimization==0.0.1) (4.2.4)\n", "Requirement already satisfied: rsa<5,>=3.1.4 in /usr/local/lib/python3.7/dist-packages (from google-auth<3,>=1.6.3->tensorboard>=2.7.0->learned-optimization==0.0.1) (4.8)\n", "Requirement already satisfied: requests-oauthlib>=0.7.0 in /usr/local/lib/python3.7/dist-packages (from google-auth-oauthlib<0.5,>=0.4.1->tensorboard>=2.7.0->learned-optimization==0.0.1) (1.3.0)\n", "Requirement 
already satisfied: importlib-metadata>=4.4 in /usr/local/lib/python3.7/dist-packages (from markdown>=2.6.8->tensorboard>=2.7.0->learned-optimization==0.0.1) (4.10.0)\n", "Requirement already satisfied: zipp>=0.5 in /usr/local/lib/python3.7/dist-packages (from importlib-metadata>=4.4->markdown>=2.6.8->tensorboard>=2.7.0->learned-optimization==0.0.1) (3.7.0)\n", "Requirement already satisfied: pyasn1<0.5.0,>=0.4.6 in /usr/local/lib/python3.7/dist-packages (from pyasn1-modules>=0.2.1->google-auth<3,>=1.6.3->tensorboard>=2.7.0->learned-optimization==0.0.1) (0.4.8)\n", "Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.7/dist-packages (from requests<3,>=2.21.0->tensorboard>=2.7.0->learned-optimization==0.0.1) (1.24.3)\n", "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.7/dist-packages (from requests<3,>=2.21.0->tensorboard>=2.7.0->learned-optimization==0.0.1) (2021.10.8)\n", "Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.7/dist-packages (from requests<3,>=2.21.0->tensorboard>=2.7.0->learned-optimization==0.0.1) (3.0.4)\n", "Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.7/dist-packages (from requests<3,>=2.21.0->tensorboard>=2.7.0->learned-optimization==0.0.1) (2.10)\n", "Requirement already satisfied: oauthlib>=3.0.0 in /usr/local/lib/python3.7/dist-packages (from requests-oauthlib>=0.7.0->google-auth-oauthlib<0.5,>=0.4.1->tensorboard>=2.7.0->learned-optimization==0.0.1) (3.1.1)\n", "Requirement already satisfied: importlib-resources in /usr/local/lib/python3.7/dist-packages (from tensorflow-datasets>=4.4.0->learned-optimization==0.0.1) (5.4.0)\n", "Requirement already satisfied: attrs>=18.1.0 in /usr/local/lib/python3.7/dist-packages (from tensorflow-datasets>=4.4.0->learned-optimization==0.0.1) (21.4.0)\n", "Requirement already satisfied: dill in /usr/local/lib/python3.7/dist-packages (from 
tensorflow-datasets>=4.4.0->learned-optimization==0.0.1) (0.3.4)\n", "Requirement already satisfied: future in /usr/local/lib/python3.7/dist-packages (from tensorflow-datasets>=4.4.0->learned-optimization==0.0.1) (0.16.0)\n", "Requirement already satisfied: promise in /usr/local/lib/python3.7/dist-packages (from tensorflow-datasets>=4.4.0->learned-optimization==0.0.1) (2.3)\n", "Collecting jax>=0.2.6\n", " Downloading jax-0.2.26.tar.gz (850 kB)\n", "\u001b[K |████████████████████████████████| 850 kB 58.4 MB/s \n", "\u001b[?25hCollecting tfp-nightly[jax]==0.16.0.dev20211216\n", " Downloading tfp_nightly-0.16.0.dev20211216-py2.py3-none-any.whl (6.0 MB)\n", "\u001b[K |████████████████████████████████| 6.0 MB 45.6 MB/s \n", "\u001b[?25hCollecting jaxlib>=0.1.68\n", " Downloading jaxlib-0.1.75-cp37-none-manylinux2010_x86_64.whl (62.2 MB)\n", "\u001b[K |████████████████████████████████| 62.2 MB 102 kB/s \n", "\u001b[?25hRequirement already satisfied: decorator in /usr/local/lib/python3.7/dist-packages (from tfp-nightly[jax]==0.16.0.dev20211216->oryx) (4.4.2)\n", "Requirement already satisfied: python-dateutil>=2.1 in /usr/local/lib/python3.7/dist-packages (from matplotlib->flax==0.3.3->learned-optimization==0.0.1) (2.8.2)\n", "Requirement already satisfied: kiwisolver>=1.0.1 in /usr/local/lib/python3.7/dist-packages (from matplotlib->flax==0.3.3->learned-optimization==0.0.1) (1.3.2)\n", "Requirement already satisfied: pyparsing!=2.0.4,!=2.1.2,!=2.1.6,>=2.0.1 in /usr/local/lib/python3.7/dist-packages (from matplotlib->flax==0.3.3->learned-optimization==0.0.1) (3.0.6)\n", "Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.7/dist-packages (from matplotlib->flax==0.3.3->learned-optimization==0.0.1) (0.11.0)\n", "Building wheels for collected packages: learned-optimization, jax\n", " Building wheel for learned-optimization (setup.py) ... 
\u001b[?25l\u001b[?25hdone\n", " Created wheel for learned-optimization: filename=learned_optimization-0.0.1-py3-none-any.whl size=104813 sha256=41b05ea72b8c369b2f415fec20622961dfafa85177d215a2394fb06f58357d56\n", " Stored in directory: /tmp/pip-ephem-wheel-cache-xjiazbys/wheels/56/23/36/7a65e02fddaf71574a41e4ed9c8bd596ff2ebc35ecef5a93c5\n", " Building wheel for jax (setup.py) ... \u001b[?25l\u001b[?25hdone\n", " Created wheel for jax: filename=jax-0.2.26-py3-none-any.whl size=985318 sha256=66375fe217775fc60a8684179be28a776160840df891965f784e60c1e8a874a1\n", " Stored in directory: /root/.cache/pip/wheels/2f/5a/a7/b792889b2a43b6a2bdc37060ec43961b2d8d607d0019223c99\n", "Successfully built learned-optimization jax\n", "Installing collected packages: numpy, jaxlib, jax, tfp-nightly, tf-estimator-nightly, mock, keras, jmp, chex, tensorflow-datasets, tensorflow, optax, nose, flax, dm-launchpad-nightly, dm-haiku, oryx, learned-optimization\n", " Attempting uninstall: numpy\n", " Found existing installation: numpy 1.19.5\n", " Uninstalling numpy-1.19.5:\n", " Successfully uninstalled numpy-1.19.5\n", " Attempting uninstall: jaxlib\n", " Found existing installation: jaxlib 0.1.71+cuda111\n", " Uninstalling jaxlib-0.1.71+cuda111:\n", " Successfully uninstalled jaxlib-0.1.71+cuda111\n", " Attempting uninstall: jax\n", " Found existing installation: jax 0.2.25\n", " Uninstalling jax-0.2.25:\n", " Successfully uninstalled jax-0.2.25\n", " Attempting uninstall: keras\n", " Found existing installation: keras 2.7.0\n", " Uninstalling keras-2.7.0:\n", " Successfully uninstalled keras-2.7.0\n", " Attempting uninstall: tensorflow-datasets\n", " Found existing installation: tensorflow-datasets 4.0.1\n", " Uninstalling tensorflow-datasets-4.0.1:\n", " Successfully uninstalled tensorflow-datasets-4.0.1\n", " Attempting uninstall: tensorflow\n", " Found existing installation: tensorflow 2.7.0\n", " Uninstalling tensorflow-2.7.0:\n", " Successfully uninstalled tensorflow-2.7.0\n", 
"\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n", "yellowbrick 1.3.post1 requires numpy<1.20,>=1.16.0, but you have numpy 1.21.5 which is incompatible.\n", "datascience 0.10.6 requires folium==0.2.1, but you have folium 0.8.3 which is incompatible.\n", "albumentations 0.1.12 requires imgaug<0.2.7,>=0.2.5, but you have imgaug 0.2.9 which is incompatible.\u001b[0m\n", "Successfully installed chex-0.1.0 dm-haiku-0.0.5 dm-launchpad-nightly-0.3.0.dev20211105 flax-0.3.3 jax-0.2.26 jaxlib-0.1.75 jmp-0.0.2 keras-2.8.0rc1 learned-optimization-0.0.1 mock-4.0.3 nose-1.3.7 numpy-1.21.5 optax-0.1.0 oryx-0.2.2 tensorflow-2.8.0rc0 tensorflow-datasets-4.4.0 tf-estimator-nightly-2.8.0.dev2021122109 tfp-nightly-0.16.0.dev20211216\n" ] }, { "data": { "application/vnd.colab-display-data+json": { "pip_warning": { "packages": [ "numpy" ] } } }, "metadata": {}, "output_type": "display_data" } ], "source": [ "!pip install --upgrade git+https://github.com/google/learned_optimization.git oryx tensorflow==2.8.0rc0 numpy" ] }, { "cell_type": "markdown", "metadata": { "id": "gzqk9f_t6f8J" }, "source": [ "To check that everything is operating as expected we can check that the imports succeed." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "kbWUpquq6XSu" }, "outputs": [], "source": [ "import oryx\n", "from learned_optimization import summary\n", "assert summary.ORYX_LOGGING" ] }, { "cell_type": "markdown", "metadata": { "id": "Lz6MTETQ4R11" }, "source": [ "## Basic Example\n", "Let's say that we have the following function and we wish to look at the `to_look_at` value." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Jugy6h9d4QT0" }, "outputs": [], "source": [ "import jax\n", "import jax.numpy as jnp" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "ozBppGEc4QT0", "outputId": "2b9b497f-5cf7-45e1-dd6f-0195e13b7a24" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)\n" ] }, { "data": { "text/plain": [ "(DeviceArray(1., dtype=float32),\n", " DeviceArray(2., dtype=float32, weak_type=True))" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "def forward(params):\n", "  # An intermediate value we would like to inspect.\n", "  to_look_at = jnp.mean(params) * 2.\n", "  return params\n", "\n", "\n", "def loss(parameters):\n", "  loss = jnp.mean(forward(parameters)**2)\n", "  return loss\n", "\n", "\n", "value_grad_fn = jax.jit(jax.value_and_grad(loss))\n", "value_grad_fn(1.0)" ] }, { "cell_type": "markdown", "metadata": { "id": "9WzMuzf44j0L" }, "source": [ "With the summary module, we first annotate this value with `summary.summary`:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "K7N39btX4jUe" }, "outputs": [], "source": [ "def forward(params):\n", "  to_look_at = jnp.mean(params) * 2.\n", "  # Tag the value with a name so it can be extracted later.\n", "  summary.summary(\"to_look_at\", to_look_at)\n", "  return params\n", "\n", "\n", "@jax.jit\n", "def loss(parameters):\n", "  loss = jnp.mean(forward(parameters)**2)\n", "  return loss" ] }, { "cell_type": "markdown", "metadata": { "id": "AL9_xgfR4yPS" }, "source": [ "Then we transform the `loss` function with the `summary.with_summary_output_reduced` function transformation.\n", "This transformation traces through the computation, extracts all the tagged values, and returns them by name in a dictionary.\n", "Under the hood, all the hard work is done by the wonderful `oryx` library (in particular 
[harvest](https://github.com/tensorflow/probability/blob/main/spinoffs/oryx/oryx/core/interpreters/harvest.py)).\n", "When we wrap a function this way, the wrapped function returns a tuple containing the original result and a dictionary of the requested metrics." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "hZQkB6Um8PI5", "outputId": "984e4f64-7562-48ae-ca68-ff4014037553" }, "outputs": [ { "data": { "text/plain": [ "(DeviceArray(1., dtype=float32),\n", " {'mean||to_look_at': DeviceArray(2., dtype=float32)})" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "result, metrics = summary.with_summary_output_reduced(loss)(1.)\n", "result, metrics" ] }, { "cell_type": "markdown", "metadata": { "id": "EJ7sPabc-zby" }, "source": [ "As currently returned, the dictionary contains extra information as well as potentially duplicate values. We can collapse these metrics into a single value per summary with `summary.aggregate_metric_list`:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "ZpZxTgsF-zJv", "outputId": "920de601-0c5e-4342-8bde-cf5297dc4635" }, "outputs": [ { "data": { "text/plain": [ "{'mean||to_look_at': 2.0}" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "summary.aggregate_metric_list([metrics])" ] }, { "cell_type": "markdown", "metadata": { "id": "tPZ-YM6I_Dhk" }, "source": [ "Each key first shows how the metric was aggregated, followed by `||` and then the summary name. In this case there is only a single value, so the aggregation has no effect." ] }, { "cell_type": "markdown", "metadata": { "id": "Pd9RDqbVAB-L" }, "source": [ "One benefit of this function transformation is that it can be composed with other JAX function transformations. 
For example, we can jit the transformed function:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "3f7vaGON_7Dc", "outputId": "7ad7c330-a0eb-4df7-96c2-0ae6d92ac252" }, "outputs": [ { "data": { "text/plain": [ "{'mean||to_look_at': 2.0}" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "result, metrics = jax.jit(summary.with_summary_output_reduced(loss))(1.)\n", "summary.aggregate_metric_list([metrics])" ] }, { "cell_type": "markdown", "metadata": { "id": "ZeriYYoTJvKd" }, "source": [ "At the moment, `aggregate_metric_list` cannot be jitted. In practice this is fine, as it performs very little computation." ] }, { "cell_type": "markdown", "metadata": { "id": "FfONK6la_QYr" }, "source": [ "## Aggregating summaries with the same name\n", "\n", "Consider the following toy function, which calls a `layer` function twice.\n", "Each call to `layer` creates a summary, so two summaries with the same name are created.\n", "When we run the transformed function, we see not one but two values returned." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "SDToo5ZH-8kP", "outputId": "1a6956f1-7968-4e36-e5bf-de05334ca58e" }, "outputs": [ { "data": { "text/plain": [ "(DeviceArray(16., dtype=float32),\n", " {'mean||to_look_at___2': DeviceArray(2., dtype=float32),\n", " 'mean||to_look_at___3': DeviceArray(4., dtype=float32)})" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "def layer(params):\n", "  to_look_at = jnp.mean(params) * 2.\n", "  summary.summary(\"to_look_at\", to_look_at)\n", "  return params * 2\n", "\n", "\n", "@jax.jit\n", "def loss(parameters):\n", "  # `layer` is called twice, so \"to_look_at\" is recorded twice.\n", "  loss = jnp.mean(layer(layer(parameters))**2)\n", "  return loss\n", "\n", "\n", "result, metrics = summary.with_summary_output_reduced(loss)(1.)\n", "result, metrics" ] }, { "cell_type": "markdown", "metadata": { "id": "JZsHOz6J_t3-" }, "source": [ "These values can be combined with `aggregate_metric_list` as before, but this time the aggregation matters: it takes the mean of the two values. The aggregation method is specified by the `aggregation` keyword argument of `summary.summary`, which defaults to `\"mean\"`." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "TW4wQ-kB-cCQ", "outputId": "824fd411-5cad-4368-9698-d1a04e309981" }, "outputs": [ { "data": { "text/plain": [ "{'mean||to_look_at': 3.0}" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "summary.aggregate_metric_list([metrics])" ] }, { "cell_type": "markdown", "metadata": { "id": "ZXtFMhmqANGJ" }, "source": [ "Another useful aggregation mode is `\"sample\"`, which keeps a randomly chosen value. As JAX's random numbers are stateless, an additional RNG key must be passed in (via `sample_rng_key`) for this to work." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "hvfgKVZhAeRW", "outputId": "fc44541a-82b5-4a2d-f103-e3b8f929aa56" }, "outputs": [ { "data": { "text/plain": [ "{'sample||to_look_at': 4.0}" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "def layer(params):\n", "  to_look_at = jnp.mean(params) * 2.\n", "  summary.summary(\"to_look_at\", to_look_at, aggregation=\"sample\")\n", "  return params * 2\n", "\n", "\n", "@jax.jit\n", "def loss(parameters):\n", "  loss = jnp.mean(layer(layer(parameters))**2)\n", "  return loss\n", "\n", "\n", "# The RNG key used to choose which value to keep.\n", "key = jax.random.PRNGKey(0)\n", "result, metrics = summary.with_summary_output_reduced(loss)(\n", "    1., sample_rng_key=key)\n", "summary.aggregate_metric_list([metrics])" ] }, { "cell_type": "markdown", "metadata": { "id": "0Q98mny6AyNH" }, "source": [ "Finally, there is `\"collect\"`, which ravels each value and concatenates them all into one long array. This is useful for extracting the distribution of a quantity." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "SzVbsqZIA2O9", "outputId": "9fec7c59-f8bf-4aa1-fa67-3e4bc75153d2" }, "outputs": [ { "data": { "text/plain": [ "{'collect||to_look_at': array([ 0., 2., 4., 6., 8., 10., 12., 14., 16., 18., 0., 4., 8.,\n", " 12., 16., 20., 24., 28., 32., 36.], dtype=float32)}" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "def layer(params):\n", "  to_look_at = jnp.mean(params) * 2.\n", "  # Each call records a length-10 vector; \"collect\" concatenates them.\n", "  summary.summary(\n", "      \"to_look_at\", jnp.arange(10) * to_look_at, aggregation=\"collect\")\n", "  return params * 2\n", "\n", "\n", "@jax.jit\n", "def loss(parameters):\n", "  loss = jnp.mean(layer(layer(parameters))**2)\n", "  return loss\n", "\n", "\n", "key = jax.random.PRNGKey(0)\n", "result, metrics = summary.with_summary_output_reduced(loss)(\n", "    1., sample_rng_key=key)\n", "summary.aggregate_metric_list([metrics])" ] }, { "cell_type": "markdown", "metadata": { "id": "nCN3ET9IQd7A" }, "source": [ "## Summary Scope\n", "Sometimes it is useful to group all summaries created inside a code block under a common name prefix. This can be done with the `summary_scope` context manager." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "1SM-Ab-CQgWw", "outputId": "b9dbd717-006d-465e-c4af-b10445b08d2c" }, "outputs": [ { "data": { "text/plain": [ "{'mean||nested/scope2/to_look_at': 1.0,\n", " 'mean||nested/summary2': 1.0,\n", " 'mean||scope1/to_look_at': 1.0}" ] }, "execution_count": 20, "metadata": {}, "output_type": "execute_result" } ], "source": [ "@jax.jit\n", "def loss(parameters):\n", "  with summary.summary_scope(\"scope1\"):\n", "    summary.summary(\"to_look_at\", parameters)\n", "\n", "  with summary.summary_scope(\"nested\"):\n", "    summary.summary(\"summary2\", parameters)\n", "\n", "    # Scopes nest: this summary is named \"nested/scope2/to_look_at\".\n", "    with summary.summary_scope(\"scope2\"):\n", "      summary.summary(\"to_look_at\", parameters)\n", "  return parameters\n", "\n", "\n", "key = jax.random.PRNGKey(0)\n", "result, metrics = summary.with_summary_output_reduced(loss)(\n", "    1., sample_rng_key=key)\n", "summary.aggregate_metric_list([metrics])" ] }, { "cell_type": "markdown", "metadata": { "id": "4pzFDb3FBfi-" }, "source": [ "## Usage with function transforms\n", "Thanks to `oryx`, all of this functionality works across a variety of function transformations.\n", "Here is an example with `scan`, `vmap`, and `jit`.\n", "The aggregation modes aggregate across all timesteps and all batch dimensions." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "FFW4l5RPBn7I", "outputId": "d621048f-921c-4d40-f6ae-9b516c12078a" }, "outputs": [ { "data": { "text/plain": [ "{'collect||collect_loop': array([ 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8, 8, 9,\n", " 9, 10, 10, 11, 11, 12, 12, 13, 13, 14, 14, 15, 15, 16, 16, 17, 17,\n", " 18, 18, 19, 19, 20, 20], dtype=int32),\n", " 'mean||mean_loop': 10.5,\n", " 'mean||other_val': 2.0}" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "@jax.jit\n", "def fn(a):\n", "  summary.summary(\"other_val\", a[2])\n", "\n", "  def update(state, _):\n", "    s = state + 1\n", "    # Summaries inside the scan body are recorded at every step.\n", "    summary.summary(\"mean_loop\", s[0])\n", "    summary.summary(\"collect_loop\", s[0], aggregation=\"collect\")\n", "    return s, s\n", "\n", "  a, _ = jax.lax.scan(update, a, jnp.arange(20))\n", "  return a * 2\n", "\n", "\n", "vmap_fn = jax.vmap(fn)\n", "\n", "# A batch of 2 inputs, each of shape (8,).\n", "result, metrics = jax.jit(summary.with_summary_output_reduced(vmap_fn))(\n", "    jnp.tile(jnp.arange(4), (2, 2)))\n", "summary.aggregate_metric_list([metrics])" ] }, { "cell_type": "markdown", "metadata": { "id": "gwpKu6neCjyN" }, "source": [ "## Optionally compute metrics: `@add_with_summary`\n", "\n", "It is often useful to define two \"versions\" of a function, one with metrics and one without, as computing the metrics can add overhead that we do not want to pay on every iteration.\n", "To create these two versions, simply wrap the function with the `add_with_summary` decorator.\n", "This adds a `with_summary` keyword argument, which switches between computing metrics or not, and an extra return value containing the metrics dictionary." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "Nlw9j1NUCuXD", "outputId": "421f5b94-5446-4fd9-d2aa-f944c54e4a97" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "No metrics {}\n", "With metrics {'mean||to_look_at': 13.5}\n" ] } ], "source": [ "import functools\n", "\n", "from learned_optimization import summary\n", "\n", "\n", "def layer(params):\n", "  to_look_at = jnp.mean(params) * 2.\n", "  summary.summary(\"to_look_at\", jnp.arange(10) * to_look_at)\n", "  return params * 2\n", "\n", "\n", "# `with_summary` is a Python bool that changes what gets traced, so it must be\n", "# marked static for jit.\n", "@functools.partial(jax.jit, static_argnames=\"with_summary\")\n", "@summary.add_with_summary\n", "def loss(parameters):\n", "  loss = jnp.mean(layer(layer(parameters))**2)\n", "  return loss\n", "\n", "\n", "res, metrics = loss(1., with_summary=False)\n", "print(\"No metrics\", summary.aggregate_metric_list([metrics]))\n", "\n", "res, metrics = loss(1., with_summary=True)\n", "print(\"With metrics\", summary.aggregate_metric_list([metrics]))" ] }, { "cell_type": "markdown", "metadata": { "id": "SnRKsVLbHA1t" }, "source": [ "## Limitations and Gotchas" ] }, { "cell_type": "markdown", "metadata": { "id": "LHCCoTCbLtpD" }, "source": [ "### Requires the summarized value to be a function of the input\n", "\n", "At the moment, `summary.summary` MUST be called with a value computed from (a descendant of) the input passed into the function wrapped by `with_summary_output_reduced`. In practice this is almost always the case, as we want to monitor changing values rather than constants.\n", "\n", "To demonstrate this, note how the constant value is NOT logged, but if we add it to `a * 0` it is." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "iIbrjrJ4HEd-", "outputId": "b1ccd1db-5615-45b9-b52c-0ae78ea9369f" }, "outputs": [ { "data": { "text/plain": [ "{'mean||constant_with_inp': 2.0, 'mean||with_input': 1.0}" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "def monitor(a):\n", " summary.summary(\"with_input\", a)\n", " summary.summary(\"constant\", 2.0)\n", " summary.summary(\"constant_with_inp\", 2.0 + (a * 0))\n", " return a\n", "\n", "\n", "result, metrics = summary.with_summary_output_reduced(monitor)(1.)\n", "summary.aggregate_metric_list([metrics])" ] }, { "cell_type": "markdown", "metadata": { "id": "yvEW6XfTH5_N" }, "source": [ "The rationale behind this is a bit of a rabbit hole: it relates to how tracing works in JAX and is beyond the scope of this notebook." ] }, { "cell_type": "markdown", "metadata": { "id": "wk9gKeH7LvHb" }, "source": [ "### No support for `jax.lax.cond`\n", "\n", "At the moment, summaries cannot be extracted from inside JAX conditionals (`jax.lax.cond`). Sorry! If this limitation affects you, let us know, as we have some ideas for making this work." ] }, { "cell_type": "markdown", "metadata": { "id": "xMUpBsUUK4un" }, "source": [ "### No dynamic names\n", "\n", "At the moment, the tag (the name of the summary) must be a string known at compile time. There is no support for dynamic summary names." ] }, { "cell_type": "markdown", "metadata": { "id": "6s4B2eUoz7nV" }, "source": [ "## Alternatives for extracting information\n", "Using this module is not the only way to extract information from a model. Below we discuss a couple of other approaches." ] }, { "cell_type": "markdown", "metadata": { "id": "ZoekEh4J1geQ" }, "source": [ "### \"Thread\" metrics through\n", "One way to extract data from a function is to simply return the things we want to look at. 
As functions become more complex and nested, this becomes tedious, as every intermediate function must pass its metric values back out. Passing data through a chain of functions like this is called \"threading\".\n", "\n", "Threading is also invasive: every piece of code between the value of interest and the caller must be modified to pass the metrics along." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "q78BOx1X1uiN", "outputId": "4bb8807f-4345-4e8a-c231-9bb482222352" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "123.0\n" ] } ], "source": [ "def lossb(p):\n", " to_look_at = jnp.mean(123.)\n", " return p * 2, to_look_at\n", "\n", "\n", "def loss(parameters):\n", " l = jnp.mean(parameters**2)\n", " l, to_look_at = lossb(l)\n", " return l, to_look_at\n", "\n", "\n", "# has_aux=True lets the threaded metric ride along with the loss.\n", "value_grad_fn = jax.jit(jax.value_and_grad(loss, has_aux=True))\n", "(loss_value, to_look_at), grad = value_grad_fn(1.0)\n", "print(to_look_at)" ] }, { "cell_type": "markdown", "metadata": { "id": "jNt9CNJf2HJN" }, "source": [ "### JAX external callbacks\n", "\n", "JAX has some support for sending data from an accelerator back to the host while a JAX program is running. This is documented at https://jax.readthedocs.io/en/latest/notebooks/external_callbacks.html.\n", "\n", "One use of this is printing, which is a quick way to get data out of a network."
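, "\n",
 "Printing is not the only option, because `jax.debug.callback` passes values to an arbitrary Python function on the host. The minimal sketch below appends each loss value to a host-side list. The names `collected`, `record`, and `loss_with_callback` are hypothetical and only for illustration. Debug callbacks may run asynchronously, so the sketch calls `jax.effects_barrier()` before reading the list.\n",
 "\n",
 "```python\n",
 "import jax\n",
 "import jax.numpy as jnp\n",
 "\n",
 "collected = []  # Host-side storage, filled while the compiled program runs.\n",
 "\n",
 "\n",
 "def record(x):\n",
 "  # Runs on the host; `x` arrives as a concrete array.\n",
 "  collected.append(float(x))\n",
 "\n",
 "\n",
 "@jax.jit\n",
 "def loss_with_callback(parameters):\n",
 "  loss = jnp.mean(parameters**2)\n",
 "  jax.debug.callback(record, loss)  # Sends `loss` back to the host.\n",
 "  return loss\n",
 "\n",
 "\n",
 "loss_with_callback(jnp.arange(3.))\n",
 "jax.effects_barrier()  # Wait for pending callbacks to finish.\n",
 "print(collected)\n",
 "```"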
] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "1Ih2LxP22MZD", "outputId": "0dd0b8ec-2c9e-414d-eadf-843122b7b8ab" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "to_look_at=123.0\n" ] } ], "source": [ "def loss(parameters):\n", " loss = jnp.mean(parameters**2)\n", " to_look_at = jnp.mean(123.)\n", " jax.debug.print(\"to_look_at={}\", to_look_at)\n", " return loss\n", "\n", "\n", "value_grad_fn = jax.jit(jax.value_and_grad(loss))\n", "_ = value_grad_fn(1.0)" ] }, { "cell_type": "markdown", "metadata": { "id": "06Dvp3xrEA-R" }, "source": [ "It is also possible to extract data with [host_callback.id_tap](https://jax.readthedocs.io/en/latest/jax.experimental.host_callback.html#jax.experimental.host_callback.id_tap). We briefly experimented with this for a summary library but ran into performance issues and added complexity around custom transforms." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "HNXn4_jCRLMY" }, "outputs": [], "source": [] } ], "metadata": { "colab": { "collapsed_sections": [], "name": "summary tutorial.ipynb", "provenance": [], "toc_visible": true }, "jupytext": { "formats": "ipynb,md:myst,py", "main_language": "python" }, "kernelspec": { "display_name": "Python 3", "name": "python3" }, "language_info": { "name": "python" } }, "nbformat": 4, "nbformat_minor": 0 }