Part 2: Custom Tasks, Task Families, and Performance Improvements

In this part, we will look at how to define custom tasks and datasets. We will also consider families of tasks, which are common specifications of meta-learning problems. Finally, we will look at how to efficiently parallelize over tasks during training.

Prerequisites

This document assumes knowledge of JAX, which is covered in depth in the JAX Docs. In particular, we recommend making your way through JAX tutorial 101. We also recommend that you have worked your way through Part 1.

!pip install git+https://github.com/google/learned_optimization.git
import numpy as np
import jax.numpy as jnp
import jax
from matplotlib import pylab as plt

from learned_optimization.outer_trainers import full_es
from learned_optimization.outer_trainers import truncated_pes
from learned_optimization.outer_trainers import gradient_learner
from learned_optimization.outer_trainers import truncation_schedule

from learned_optimization.tasks import quadratics
from learned_optimization.tasks.fixed import image_mlp
from learned_optimization.tasks import base as tasks_base
from learned_optimization.tasks.datasets import base as datasets_base

from learned_optimization.learned_optimizers import base as lopt_base
from learned_optimization.learned_optimizers import mlp_lopt
from learned_optimization.optimizers import base as opt_base

from learned_optimization import optimizers
from learned_optimization import eval_training

import haiku as hk
import tqdm

Defining a custom Dataset

The datasets in this library consist of iterators which yield batches of the corresponding data. For the provided tasks, these datasets have 4 splits of data rather than the traditional 3. We have “train”, which is the data used by the task to train a model; “inner_valid”, which contains validation data for use during inner training (training an instance of a task), e.g. for picking hyperparameters; “outer_valid”, which is used to meta-train with (this data is unseen during inner training and thus serves as a basis to train learned optimizers against); and “test”, which can be used to evaluate the learned optimizer.

To make a dataset, simply provide 4 iterators, one for each of these splits.

For performance reasons, creating these iterators must be fast. The existing datasets make extensive use of caching to share iterators across tasks which use the same underlying data. To account for this reuse, it is expected that these iterators always sample data randomly and use a large shuffle buffer so as to avoid sampling issues.

import numpy as np


def data_iterator():
  bs = 3
  while True:
    batch = {"data": np.zeros([bs, 5])}
    yield batch


@datasets_base.dataset_lru_cache
def get_datasets():
  return datasets_base.Datasets(
      train=data_iterator(),
      inner_valid=data_iterator(),
      outer_valid=data_iterator(),
      test=data_iterator())


ds = get_datasets()
next(ds.train)
{'data': array([[0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.]])}
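
In practice these iterators usually come from a real data pipeline. Below is a rough sketch, not a utility from this library, of wrapping a tensorflow_datasets split so that it satisfies the requirements above: repeated forever, randomly shuffled with a large buffer, and yielding numpy batches. The dataset name, batch size, and buffer size are illustrative, and for simplicity we reuse the "test" split for both validation splits.

import tensorflow_datasets as tfds


def tfds_iterator(split, batch_size=128, shuffle_buffer=10_000):
  # Infinite, shuffled stream of numpy batches (dicts with "image"/"label").
  ds = tfds.load("fashion_mnist", split=split)
  ds = ds.repeat().shuffle(shuffle_buffer).batch(batch_size)
  yield from tfds.as_numpy(ds)


@datasets_base.dataset_lru_cache
def get_fashion_mnist_datasets():
  # A real dataset would carve out distinct inner_valid / outer_valid splits.
  return datasets_base.Datasets(
      train=tfds_iterator("train"),
      inner_valid=tfds_iterator("test"),
      outer_valid=tfds_iterator("test"),
      test=tfds_iterator("test"))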

Defining a custom Task

To define a custom task, one simply needs to write a subclass of Task. Let's look at a simple quadratic task with noisy targets.

# First we construct data iterators.
def noise_datasets():

  def _fn():
    while True:
      yield np.random.normal(size=[4, 2]).astype(dtype=np.float32)

  return datasets_base.Datasets(
      train=_fn(), inner_valid=_fn(), outer_valid=_fn(), test=_fn())


class MyTask(tasks_base.Task):
  datasets = noise_datasets()

  def loss(self, params, rng, data):
    return jnp.sum(jnp.square(params - data))

  def init(self, key):
    return jax.random.normal(key, shape=(4, 2))


task = MyTask()
key = jax.random.PRNGKey(0)
key1, key = jax.random.split(key)
params = task.init(key)

task.loss(params, key1, next(task.datasets.train))
DeviceArray(10.503748, dtype=float32)

Meta-training on multiple tasks: TaskFamily

What we have shown previously was meta-training on a single task instance. While this is sometimes sufficient, in many situations we seek to meta-train a meta-learning algorithm, such as a learned optimizer, on a mixture of different tasks.

One path to do this is to simply run more than one meta-gradient computation, each with a different task, average the gradients, and perform one meta-update. This works well when the tasks are quite different, e.g. computing meta-gradients when training a convnet vs. an MLP. A big downside is that these meta-gradient calculations happen sequentially and thus make poor use of hardware accelerators like GPUs or TPUs.
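
For concreteness, here is a minimal sketch of this sequential approach. The meta_gradient_for_task function is hypothetical and stands in for one of the gradient estimators in learned_optimization.outer_trainers; the point is only that the per-task meta-gradients are computed one after another and then averaged.

def averaged_meta_gradient(tasks, theta, key):
  # Hypothetical: compute a meta-gradient for each task sequentially, then
  # average the resulting pytrees leaf by leaf.
  grads = []
  for task in tasks:
    key, sub_key = jax.random.split(key)
    grads.append(meta_gradient_for_task(task, theta, sub_key))  # hypothetical
  return jax.tree_util.tree_map(lambda *g: sum(g) / len(g), *grads)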

As a solution to this problem, we provide the TaskFamily abstraction to enable better use of hardware. A TaskFamily represents a distribution over a set of tasks and specifies particular samples from this distribution as a pytree of jax types.

The function to sample these configurations is called sample, and the function to get a task from the sampled config is task_fn. A TaskFamily can also optionally contain datasets which are shared across all the Task instances it creates.

As a simple example, let's consider a family of quadratics, each parameterized by the mean squared error to some target point which is itself sampled.

PRNGKey = jnp.ndarray
TaskParams = jnp.ndarray


class FixedDimQuadraticFamily(tasks_base.TaskFamily):
  """A simple TaskFamily with a fixed dimensionality but sampled target."""

  def __init__(self, dim: int):
    super().__init__()
    self._dim = dim
    self.datasets = None

  def sample(self, key: PRNGKey) -> TaskParams:
    # Sample the target for the quadratic task.
    return jax.random.normal(key, shape=(self._dim,))

  def task_fn(self, task_params: TaskParams) -> tasks_base.Task:
    dim = self._dim

    class _Task(tasks_base.Task):

      def loss(self, params, rng, _):
        # Compute MSE to the target task.
        return jnp.sum(jnp.square(task_params - params))

      def init(self, key):
        return jax.random.normal(key, shape=(dim,))

    return _Task()

With this task family defined, we can create instances by sampling a configuration and creating a task. This task acts like any other task in that it has an init and a loss function.

task_family = FixedDimQuadraticFamily(10)
key = jax.random.PRNGKey(0)
task_cfg = task_family.sample(key)
task = task_family.task_fn(task_cfg)

key1, key = jax.random.split(key)
params = task.init(key)
batch = None
task.loss(params, key, batch)
DeviceArray(13.190405, dtype=float32)

To achieve speedups, we can now leverage jax.vmap to train multiple task instances in parallel! Depending on the task, this can be considerably faster than executing them serially.

def train_task(cfg, key):
  # Build a concrete Task from the sampled config, inner-train it for a few
  # steps with Adam, and return the resulting training loss.
  task = task_family.task_fn(cfg)
  key1, key = jax.random.split(key)
  params = task.init(key1)
  opt = opt_base.Adam()

  opt_state = opt.init(params)

  for i in range(4):
    params = opt.get_params(opt_state)
    loss, grad = jax.value_and_grad(task.loss)(params, key, None)
    opt_state = opt.update(opt_state, grad, loss=loss)
  loss = task.loss(params, key, None)
  return loss


task_cfg = task_family.sample(key)
print("single loss", train_task(task_cfg, key))

keys = jax.random.split(key, 32)
task_cfgs = jax.vmap(task_family.sample)(keys)
losses = jax.vmap(train_task)(task_cfgs, keys)
print("multiple losses", losses)
single loss 10.973224
multiple losses [28.484756  15.884144  10.12129   17.281586  18.210754  17.650654
 31.202633  20.745605  21.301374  36.30536   22.189842  21.358437
 13.802605  16.462059  13.092703  25.175426  23.442476  13.078012
 20.773136  15.165912  23.114235  24.486801  31.850758  11.04059
  5.795575  26.002295  31.550493   2.9317625 10.598424  18.45548
 24.402779  20.770353 ]
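
As an optional tweak not in the example above, the vmapped function can additionally be wrapped in jax.jit. Since train_task only traces pure jax operations, this should compile the whole batch of inner-training runs into a single program; whether it actually helps depends on the task and optimizer, so profile before relying on it.

jit_losses = jax.jit(jax.vmap(train_task))(task_cfgs, keys)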

Because task families can be vmapped over, they are the main building block for a number of the higher-level libraries in this package. A single task can always be converted to a task family with:

single_task = image_mlp.ImageMLP_FashionMnist8_Relu32()
task_family = tasks_base.single_task_to_family(single_task)

This wrapper task family has no configurable values and always returns the base task.

cfg = task_family.sample(key)
print("config only contains a dummy value:", cfg)
task = task_family.task_fn(cfg)
# Tasks are the same
assert task == single_task
config only contains a dummy value: 0

Limitations of TaskFamily

Task families are designed for, and only work with, variation that results in a static computation graph. This is required for jax.vmap to work.

This means things like naively changing hidden sizes, numbers of layers, or activation functions are off the table.

In some cases, one can leverage other jax control flow, such as jax.lax.cond, to select between implementations. For example, one could make a TaskFamily that uses one of 2 activation functions. While this works, the resulting vectorized computation could be slow, so profiling is required to determine whether it is a good idea.
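
For illustration, here is a minimal sketch of such a family (it is not part of the library): the sampled configuration is a boolean flag that selects between relu and tanh via jax.lax.cond, so the shapes and computation graph stay fixed even though the behavior differs.

class TwoActivationQuadraticFamily(tasks_base.TaskFamily):
  """Sketch: the sampled config picks between two activation functions."""

  def __init__(self, dim: int):
    super().__init__()
    self._dim = dim
    self.datasets = None

  def sample(self, key: PRNGKey) -> TaskParams:
    # A boolean "which activation" flag, stored as a jax type so a batch of
    # configs can be vmapped over.
    return jax.random.bernoulli(key)

  def task_fn(self, task_params: TaskParams) -> tasks_base.Task:
    dim = self._dim

    class _Task(tasks_base.Task):

      def loss(self, params, rng, _):
        # Both branches produce identical shapes; under vmap, lax.cond will
        # typically evaluate both branches, which is why this can be slow.
        hidden = jax.lax.cond(task_params, jax.nn.relu, jnp.tanh, params)
        return jnp.sum(jnp.square(hidden))

      def init(self, key):
        return jax.random.normal(key, shape=(dim,))

    return _Task()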

In this code base, we mainly use TaskFamily to parameterize over different kinds of initializations.
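
As a final illustrative sketch of that pattern (again, not a family from the library), the sampled configuration below is simply a log10 scale applied to the parameter initialization, while the loss itself is fixed.

class ScaledInitQuadraticFamily(tasks_base.TaskFamily):
  """Sketch: the sampled config is a log10 scale applied to the init."""

  def __init__(self, dim: int):
    super().__init__()
    self._dim = dim
    self.datasets = None

  def sample(self, key: PRNGKey) -> TaskParams:
    # Log-uniform initialization scale in [0.1, 10].
    return jax.random.uniform(key, shape=(), minval=-1.0, maxval=1.0)

  def task_fn(self, task_params: TaskParams) -> tasks_base.Task:
    dim = self._dim

    class _Task(tasks_base.Task):

      def loss(self, params, rng, _):
        return jnp.sum(jnp.square(params))

      def init(self, key):
        # Only the initialization varies across the family.
        return 10.0**task_params * jax.random.normal(key, shape=(dim,))

    return _Task()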