Summary tutorial: Getting metrics out of your models
The goal of the learned_optimization.summary module is to let researchers seamlessly annotate and extract data from within a jax computation / machine learning model.
This could be anything from the mean and standard deviation of an activation to the full distribution of outputs.
Doing this in Jax can be challenging at times because Jax code often makes use of a large number of function transformations, which makes it difficult to reach in and look at a value.
This notebook discusses the learned_optimization.summary module, which provides one solution to this problem.
Deps
In addition to learned_optimization, the summary module requires oryx. This can be a bit finicky to install at the moment as it relies upon particular versions of tensorflow (even though we never use these pieces).
All of learned_optimization will run without oryx, but to get summaries oryx must be installed.
In a colab this is even more annoying as we must first upgrade the versions of some installed modules, restart the colab kernel, and then proceed to run the remainder of the cells.
!pip install --upgrade git+https://github.com/google/learned_optimization.git oryx tensorflow==2.8.0rc0 numpy
To check that everything is operating as expected we can check that the imports succeed.
import oryx
from learned_optimization import summary
assert summary.ORYX_LOGGING
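If the assertion above fails, it can help to first check which packages are visible to Python at all. The following is a minimal sketch using only the standard library; the helper name `is_installed` is hypothetical and is not part of learned_optimization or oryx.

```python
import importlib.util


def is_installed(module_name):
    """Return True if a top-level module can be found on the import path."""
    return importlib.util.find_spec(module_name) is not None


# Report the two packages this notebook depends on without importing them.
for name in ("oryx", "learned_optimization"):
    status = "found" if is_installed(name) else "missing"
    print(f"{name}: {status}")
```

This only checks that the packages can be located; version mismatches (for example with tensorflow) would still surface at import time.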
Basic Example
Let’s say that we have the following function and we wish to look at the to_look_at value.
import jax
import jax.numpy as jnp

def forward(params):
  to_look_at = jnp.mean(params) * 2.
  return params

def loss(parameters):
  loss = jnp.mean(forward(parameters)**2)
  return loss

value_grad_fn = jax.jit(jax.value_and_grad(loss))
value_grad_fn(1.0)
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
(DeviceArray(1., dtype=float32),
DeviceArray(2., dtype=float32, weak_type=True))
With the summary module we can first annotate this value with summary.summary:
def forward(params):
  to_look_at = jnp.mean(params) * 2.
  summary.summary("to_look_at", to_look_at)
  return params

@jax.jit
def loss(parameters):
  loss = jnp.mean(forward(parameters)**2)
  return loss
Then we can transform the loss function with the function transformation summary.with_summary_output_reduced.
This transformation goes through the computation, extracts all the tagged values, and returns them to us by name in a dictionary.
In implementation, all the hard work here is done by the wonderful oryx library (in particular harvest).
When we wrap a function this way, calling it returns a tuple containing the original result and a dictionary with the desired metrics.
result, metrics = summary.with_summary_output_reduced(loss)(1.)
result, metrics
(DeviceArray(1., dtype=float32),
{'mean||to_look_at': DeviceArray(2., dtype=float32)})
As currently returned, the dictionary contains extra information as well as potentially duplicate values. We can collapse these metrics into a single value with the following:
summary.aggregate_metric_list([metrics])
{'mean||to_look_at': 2.0}
The keys of this dictionary first show how the metric was aggregated. In this case there is only a single metric so the aggregation is ignored. This is followed by || and then the summary name.
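Given the key layout described above, the aggregation mode and the summary name can be pulled apart with a plain string split. This is a small sketch; the helper name `split_metric_key` is hypothetical and not part of the summary module.

```python
def split_metric_key(key):
    """Split an 'aggregation||name' key into (aggregation, name)."""
    aggregation, name = key.split("||", 1)
    return aggregation, name


metrics = {"mean||to_look_at": 2.0}
for key, value in metrics.items():
    aggregation, name = split_metric_key(key)
    print(aggregation, name, value)  # prints: mean to_look_at 2.0
```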
One benefit of this function transformation is it can be nested with other jax function transformations. For example we can jit the transformed function like so:
result, metrics = jax.jit(summary.with_summary_output_reduced(loss))(1.)
summary.aggregate_metric_list([metrics])
{'mean||to_look_at': 2.0}
At this point aggregate_metric_list cannot be jitted. In practice this is fine as it performs very little computation.
Aggregation of summaries with the same name
Consider the following fake function which calls a layer function twice.
This layer function creates a summary, and thus two summaries are created.
When running the transformed function we see not one, but two values returned.
def layer(params):
  to_look_at = jnp.mean(params) * 2.
  summary.summary("to_look_at", to_look_at)
  return params * 2

@jax.jit
def loss(parameters):
  loss = jnp.mean(layer(layer(parameters))**2)
  return loss
result, metrics = summary.with_summary_output_reduced(loss)(1.)
result, metrics
(DeviceArray(16., dtype=float32),
{'mean||to_look_at___2': DeviceArray(2., dtype=float32),
'mean||to_look_at___3': DeviceArray(4., dtype=float32)})
These values can be combined with aggregate_metric_list as before, but this time the aggregation takes the mean. This mean is specified by the aggregation keyword argument in summary.summary, which defaults to mean.
summary.aggregate_metric_list([metrics])
{'mean||to_look_at': 3.0}
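The arithmetic behind this result can be reproduced in plain Python. The sketch below is a conceptual model only, not the library's implementation: it assumes that duplicate summaries differ only by a trailing `___N` suffix, as in the output shown earlier, and the helper name `toy_aggregate_mean` is hypothetical.

```python
import re
from collections import defaultdict


def toy_aggregate_mean(metrics):
    """Average values whose keys differ only by a trailing '___N' suffix."""
    groups = defaultdict(list)
    for key, value in metrics.items():
        base = re.sub(r"___\d+$", "", key)  # drop the per-call suffix
        groups[base].append(float(value))
    return {base: sum(vals) / len(vals) for base, vals in groups.items()}


raw = {"mean||to_look_at___2": 2.0, "mean||to_look_at___3": 4.0}
print(toy_aggregate_mean(raw))  # {'mean||to_look_at': 3.0}
```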
Another useful aggregation mode is “sample”. As jax’s random numbers are stateless, an additional RNG key must be passed in for this to work.
def layer(params):
  to_look_at = jnp.mean(params) * 2.
  summary.summary("to_look_at", to_look_at, aggregation="sample")
  return params * 2

@jax.jit
def loss(parameters):
  loss = jnp.mean(layer(layer(parameters))**2)
  return loss

key = jax.random.PRNGKey(0)
result, metrics = summary.with_summary_output_reduced(loss)(
    1., sample_rng_key=key)
summary.aggregate_metric_list([metrics])
{'sample||to_look_at': 4.0}
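The idea of passing randomness in explicitly can be illustrated without jax. This is a rough analogy under stated assumptions, not how the summary module samples: it uses Python's `random.Random` with an explicit seed in place of a jax PRNG key, and the helper name `toy_sample` is hypothetical.

```python
import random


def toy_sample(values, seed):
    """Pick one value using an explicitly supplied seed (no hidden global state)."""
    rng = random.Random(seed)
    return rng.choice(values)


values = [2.0, 4.0]  # the two "to_look_at" values logged by the two layer calls
first = toy_sample(values, seed=0)
second = toy_sample(values, seed=0)
print(first == second)  # True: the same seed always selects the same value
```

Because the choice depends only on the seed, repeated runs are reproducible, which mirrors why a key has to be supplied by the caller.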
Finally, there is "collect", which concatenates all the values together into one long tensor after first raveling all the inputs. This is useful for extracting distributions of quantities.
def layer(params):
  to_look_at = jnp.mean(params) * 2.
  summary.summary(
      "to_look_at", jnp.arange(10) * to_look_at, aggregation="collect")
  return params * 2

@jax.jit
def loss(parameters):
  loss = jnp.mean(layer(layer(parameters))**2)
  return loss

key = jax.random.PRNGKey(0)
result, metrics = summary.with_summary_output_reduced(loss)(
    1., sample_rng_key=key)
summary.aggregate_metric_list([metrics])
{'collect||to_look_at': array([ 0., 2., 4., 6., 8., 10., 12., 14., 16., 18., 0., 4., 8.,
12., 16., 20., 24., 28., 32., 36.], dtype=float32)}
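The "ravel, then concatenate" behavior described for "collect" can be mimicked with nested Python lists. This is a conceptual sketch only, with hypothetical helper names `toy_ravel` and `toy_collect`; the real module operates on jax arrays.

```python
def toy_ravel(x):
    """Flatten an arbitrarily nested list of numbers into a flat list."""
    if isinstance(x, (list, tuple)):
        flat = []
        for item in x:
            flat.extend(toy_ravel(item))
        return flat
    return [x]


def toy_collect(values):
    """Ravel every logged value, then concatenate them in call order."""
    out = []
    for v in values:
        out.extend(toy_ravel(v))
    return out


first_call = [i * 2.0 for i in range(10)]   # jnp.arange(10) * 2 in the example
second_call = [i * 4.0 for i in range(10)]  # jnp.arange(10) * 4 in the example
collected = toy_collect([first_call, second_call])
print(len(collected))  # 20
```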
Summary Scope
Sometimes it is useful to group all summaries created inside a block of code under a common name. This can be done with the summary_scope context.
@jax.jit
def loss(parameters):
  with summary.summary_scope("scope1"):
    summary.summary("to_look_at", parameters)
  with summary.summary_scope("nested"):
    summary.summary("summary2", parameters)
    with summary.summary_scope("scope2"):
      summary.summary("to_look_at", parameters)
  return parameters

key = jax.random.PRNGKey(0)
result, metrics = summary.with_summary_output_reduced(loss)(
    1., sample_rng_key=key)
summary.aggregate_metric_list([metrics])
{'mean||nested/scope2/to_look_at': 1.0,
'mean||nested/summary2': 1.0,
'mean||scope1/to_look_at': 1.0}
Usage with function transforms.
Thanks to oryx, all of this functionality works well across a variety of function transformations.
Here is an example with a scan, vmap, and jit.
The aggregation modes will aggregate across all timesteps, and all batched dimensions.
@jax.jit
def fn(a):
  summary.summary("other_val", a[2])

  def update(state, _):
    s = state + 1
    summary.summary("mean_loop", s[0])
    summary.summary("collect_loop", s[0], aggregation="collect")
    return s, s

  a, _ = jax.lax.scan(update, a, jnp.arange(20))
  return a * 2

vmap_fn = jax.vmap(fn)
result, metrics = jax.jit(summary.with_summary_output_reduced(vmap_fn))(
    jnp.tile(jnp.arange(4), (2, 2)))
summary.aggregate_metric_list([metrics])
{'collect||collect_loop': array([ 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8, 8, 9,
9, 10, 10, 11, 11, 12, 12, 13, 13, 14, 14, 15, 15, 16, 16, 17, 17,
18, 18, 19, 19, 20, 20], dtype=int32),
'mean||mean_loop': 10.5,
'mean||other_val': 2.0}
Optionally compute metrics: @add_with_summary
Oftentimes it is useful to define two “versions” of a function, one with metrics and one without, since computing the metrics can add overhead that need not be incurred every iteration.
To create these two versions one can simply wrap the function with the add_with_summary decorator.
This adds both a with_summary keyword argument, which switches between computing metrics or not, and an extra return value to the wrapped function.
from learned_optimization import summary
import functools

def layer(params):
  to_look_at = jnp.mean(params) * 2.
  summary.summary("to_look_at", jnp.arange(10) * to_look_at)
  return params * 2

@functools.partial(jax.jit, static_argnames="with_summary")
@summary.add_with_summary
def loss(parameters):
  loss = jnp.mean(layer(layer(parameters))**2)
  return loss

res, metrics = loss(1., with_summary=False)
print("No metrics", summary.aggregate_metric_list([metrics]))

res, metrics = loss(1., with_summary=True)
print("With metrics", summary.aggregate_metric_list([metrics]))
No metrics {}
With metrics {'mean||to_look_at': 13.5}
Limitations and Gotchas
Requires the traced value to be a function of the input
At the moment summary.summary MUST be called with a descendant of whatever input is passed into the function wrapped by with_summary_output_reduced. In practice this is almost always the case as we seek to monitor changing values rather than constants.
To demonstrate this, note how the constant value is NOT logged out, but if we add a*0 to it, it does become logged out.
def monitor(a):
  summary.summary("with_input", a)
  summary.summary("constant", 2.0)
  summary.summary("constant_with_inp", 2.0 + (a * 0))
  return a
result, metrics = summary.with_summary_output_reduced(monitor)(1.)
summary.aggregate_metric_list([metrics])
{'mean||constant_with_inp': 2.0, 'mean||with_input': 1.0}
The rationale for this is a bit of a rabbit hole, but it is related to how tracing in jax works and is beyond the scope of this notebook.
No support for jax.lax.cond
At this point one cannot extract summaries out of jax conditionals. Sorry. If this is a limitation for you, let us know, as we have some ideas to make this work.
No dynamic names
At the moment, the tag, or the name of the summary, must be a string known at compile time. There is no support for dynamic summary names.
Alternatives for extracting information
Using this module is not the only way to extract information from a model. We discuss a couple other approaches.
“Thread” metrics through
One way to extract data from a function is to simply return the things we want to look at. As functions become more complex and nested this can become quite a pain as each one of these functions must pass out metric values. This process of spreading data throughout a bunch of functions is called “threading”.
Threading also requires all pieces of code to be involved – as one must thread these metrics everywhere.
def lossb(p):
  to_look_at = jnp.mean(123.)
  return p * 2, to_look_at

def loss(parameters):
  l = jnp.mean(parameters**2)
  l, to_look_at = lossb(l)
  return l, to_look_at

value_grad_fn = jax.jit(jax.value_and_grad(loss, has_aux=True))
(loss, to_look_at), g = value_grad_fn(1.0)
print(to_look_at)
123.0
jax.experimental.host_callback
Jax has some support for sending data from an accelerator back to the host while a jax program is running. This is exposed in jax.experimental.host_callback.
One can use this to print, which is a quick way to get data out of a network.
from jax.experimental import host_callback as hcb

def loss(parameters):
  loss = jnp.mean(parameters**2)
  to_look_at = jnp.mean(123.)
  hcb.id_print(to_look_at, name="to_look_at")
  return loss

value_grad_fn = jax.jit(jax.value_and_grad(loss))
_ = value_grad_fn(1.0)
name: to_look_at
123.
It is also possible to extract data out with host_callback.id_tap. We experimented with this briefly for a summary library but found both performance issues and increased complexity around custom transforms.