Part 4: GradientEstimators

import numpy as np
import jax.numpy as jnp
import jax
import functools
from matplotlib import pylab as plt
from typing import Optional, Tuple, Mapping

from learned_optimization.outer_trainers import full_es
from learned_optimization.outer_trainers import truncated_pes
from learned_optimization.outer_trainers import truncated_grad
from learned_optimization.outer_trainers import gradient_learner
from learned_optimization.outer_trainers import truncation_schedule
from learned_optimization.outer_trainers import common
from learned_optimization.outer_trainers import lopt_truncated_step
from learned_optimization.outer_trainers import truncated_step as truncated_step_mod
from learned_optimization.outer_trainers.gradient_learner import WorkerWeights, GradientEstimatorState, GradientEstimatorOut

from learned_optimization.tasks import quadratics
from learned_optimization.tasks.fixed import image_mlp
from learned_optimization.tasks import base as tasks_base

from learned_optimization.learned_optimizers import base as lopt_base
from learned_optimization.learned_optimizers import mlp_lopt
from learned_optimization.optimizers import base as opt_base

from learned_optimization import optimizers
from learned_optimization import training
from learned_optimization import eval_training

import haiku as hk
import tqdm

Gradient estimators provide an interface to estimate gradients of some loss with respect to the parameters of some meta-learned system. GradientEstimators are not specific to learned optimizers and can be applied to any unrolled system defined by a TruncatedStep (see the previous colab).

learned_optimization supports a handful of estimators, each with different strengths and weaknesses. Understanding which estimator is right for which situation remains an open research question. After introducing the GradientEstimator class, we give a quick tour of the different estimators implemented here.

The GradientEstimator base class signature is below.

PRNGKey = jnp.ndarray


class GradientEstimator:
  truncated_step: truncated_step_mod.TruncatedStep

  def init_worker_state(self, worker_weights: WorkerWeights,
                        key: PRNGKey) -> GradientEstimatorState:
    raise NotImplementedError()

  def compute_gradient_estimate(
      self, worker_weights: WorkerWeights, key: PRNGKey,
      state: GradientEstimatorState, with_summary: Optional[bool]
  ) -> Tuple[GradientEstimatorOut, Mapping[str, jnp.ndarray]]:
    raise NotImplementedError()

A gradient estimator must have an instance of a TruncatedStep – which defines the unrolled system whose gradients are being estimated – an init_worker_state function, which initializes the current state of the gradient estimator, and a compute_gradient_estimate function, which takes that state and computes a collection of outputs (a GradientEstimatorOut) containing the estimated gradients with respect to the learned optimizer, meta-loss values, and various other information about the unroll. Additionally, a mapping containing various metrics is returned.

Both of these methods take in a WorkerWeights instance. This represents all the weights needed to compute gradients. In every case this contains the weights of the meta-learned algorithm (e.g. the learned optimizer), called theta. It can also contain non-learnable state, such as the running averages a learned optimizer with batch norm would track.
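In other words, the expected call pattern looks like the following sketch, where estimator and worker_weights stand in for the concrete instances constructed below:

state = estimator.init_worker_state(worker_weights, key=key)
out, metrics = estimator.compute_gradient_estimate(
    worker_weights, key=key, state=state, with_summary=False)
state = out.unroll_state  # carried into the next estimate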

In the following examples, we will show gradient estimation on learned optimizers using the VectorizedLOptTruncatedStep.

task_family = quadratics.FixedDimQuadraticFamily(10)
lopt = lopt_base.LearnableAdam()
# With FullES, there are no truncations, so we set trunc_sched to never ending.
trunc_sched = truncation_schedule.NeverEndingTruncationSchedule()
truncated_step = lopt_truncated_step.VectorizedLOptTruncatedStep(
    task_family,
    lopt,
    trunc_sched,
    num_tasks=3,
)

FullES

The FullES estimator is one of the simplest and most reliable estimators, but it can be slow in practice as it does not make use of truncations. Instead, it uses antithetic sampling to estimate the gradient via ES of an entire optimization run (hence the "Full" in the name).

First we define a meta-objective, $f(\theta)$, which could be the loss at the end of training or the average loss over training. Next, we compute a gradient estimate via antithetic ES with perturbations $\epsilon \sim \mathcal{N}(0, \sigma^2 I)$:

$\nabla_\theta f \approx \dfrac{\epsilon}{2\sigma^2} (f(\theta + \epsilon) - f(\theta - \epsilon))$
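To make this concrete, here is a minimal sketch of this antithetic estimator applied to a toy meta-objective. Everything here (toy_meta_loss, the sample count, the value of std) is illustrative rather than part of the library:

def toy_meta_loss(theta):
  # Stand-in for f(theta): e.g. the loss at the end of an inner training run.
  return jnp.sum(theta**2)

def es_grad_estimate(theta, key, std=0.01):
  # One antithetic sample with eps ~ N(0, std^2 I).
  eps = std * jax.random.normal(key, theta.shape)
  return eps / (2 * std**2) * (
      toy_meta_loss(theta + eps) - toy_meta_loss(theta - eps))

# Averaging over many samples reduces variance; in expectation this
# recovers the true gradient, here 2 * theta.
keys = jax.random.split(jax.random.PRNGKey(0), 128)
grads = jax.vmap(lambda k: es_grad_estimate(jnp.ones(4), k))(keys)
print(grads.mean(axis=0))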

We can instantiate one of these as follows:

es_trunc_sched = truncation_schedule.ConstantTruncationSchedule(10)
gradient_estimator = full_es.FullES(
    truncated_step, truncation_schedule=es_trunc_sched)
key = jax.random.PRNGKey(0)
theta = truncated_step.outer_init(key)
worker_weights = gradient_learner.WorkerWeights(
    theta=theta,
    theta_model_state=None,
    outer_state=gradient_learner.OuterState(0))

Because we are working with full-length unrolls, this gradient estimator has no state – there is nothing to carry over from truncation to truncation.

gradient_estimator_state = gradient_estimator.init_worker_state(
    worker_weights, key=key)
gradient_estimator_state
UnrollState()

Gradients can be computed with the compute_gradient_estimate method.

out, metrics = gradient_estimator.compute_gradient_estimate(
    worker_weights, key=key, state=gradient_estimator_state, with_summary=False)
out.grad
{'log_epsilon': DeviceArray(-0.0173279, dtype=float32),
 'log_lr': DeviceArray(-0.00474211, dtype=float32),
 'one_minus_beta1': DeviceArray(-0.02331395, dtype=float32),
 'one_minus_beta2': DeviceArray(0.00497994, dtype=float32)}
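These estimates can then be fed to any outer optimizer operating on theta. As a sketch, using optax (learned_optimization also ships its own optimizer wrappers in learned_optimization.optimizers):

import optax

outer_opt = optax.adam(1e-3)
outer_opt_state = outer_opt.init(theta)

# One outer update applied to the learned optimizer's weights.
updates, outer_opt_state = outer_opt.update(out.grad, outer_opt_state, theta)
theta = optax.apply_updates(theta, updates)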

TruncatedPES

Truncated Persistent Evolution Strategies (PES) is an unbiased truncation method based on ES. It was proposed in Unbiased Gradient Estimation in Unrolled Computation Graphs with Persistent Evolution Strategies and has been a promising tool for training learned optimizers.
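The core idea can be sketched on a toy unrolled system (the scalar inner system, STD, and truncation length below are illustrative, not the library implementation): each truncation draws a fresh perturbation, adds it to a running accumulator, unrolls a positive and a negative copy of the system, and weights the loss difference by the accumulated perturbation.

STD = 0.01

def unroll(state, theta, steps=10):
  # Toy inner system: repeatedly scale `state` by a learnable multiplier.
  losses = []
  for _ in range(steps):
    state = state * theta
    losses.append(state**2)
  return state, jnp.mean(jnp.asarray(losses))

def pes_truncation(theta, pos_state, neg_state, accum, key):
  eps = STD * jax.random.normal(key, theta.shape)
  accum = accum + eps  # running sum of perturbations over this episode
  pos_state, pos_loss = unroll(pos_state, theta + eps)
  neg_state, neg_loss = unroll(neg_state, theta - eps)
  # Weighting by the *accumulated* perturbation, rather than just eps,
  # is what removes truncation bias. Reset accum when the inner problem resets.
  grad = accum * (pos_loss - neg_loss) / (2 * STD**2)
  return grad, pos_state, neg_state, accum

The library implementation below additionally handles batches of tasks, task sampling, and resets.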

trunc_sched = truncation_schedule.ConstantTruncationSchedule(10)
truncated_step = lopt_truncated_step.VectorizedLOptTruncatedStep(
    task_family,
    lopt,
    trunc_sched,
    num_tasks=3,
    random_initial_iteration_offset=10)

gradient_estimator = truncated_pes.TruncatedPES(
    truncated_step=truncated_step, trunc_length=10)
key = jax.random.PRNGKey(1)
theta = truncated_step.outer_init(key)
worker_weights = gradient_learner.WorkerWeights(
    theta=theta,
    theta_model_state=None,
    outer_state=gradient_learner.OuterState(0))
gradient_estimator_state = gradient_estimator.init_worker_state(
    worker_weights, key=key)

Now let’s look at what this state contains.

jax.tree_util.tree_map(lambda x: x.shape, gradient_estimator_state)
PESWorkerState(pos_state=TruncatedUnrollState(inner_opt_state=OptaxState(params=(3, 10), state=None, optax_opt_state=(ScaleByAdamState(count=(3,), mu=(3, 10), nu=(3, 10)), EmptyState()), iteration=(3,)), inner_step=(3,), truncation_state=ConstantTruncationState(length=(3,)), task_param=(3, 10), is_done=(3,)), neg_state=TruncatedUnrollState(inner_opt_state=OptaxState(params=(3, 10), state=None, optax_opt_state=(ScaleByAdamState(count=(3,), mu=(3, 10), nu=(3, 10)), EmptyState()), iteration=(3,)), inner_step=(3,), truncation_state=ConstantTruncationState(length=(3,)), task_param=(3, 10), is_done=(3,)), accumulator={'log_epsilon': (3,), 'log_lr': (3,), 'one_minus_beta1': (3,), 'one_minus_beta2': (3,)})

First, this contains two instances of TruncatedUnrollState – one for the positive perturbation and one for the negative perturbation. Each contains all the state required to keep track of a training run: the inner optimizer state, details from the truncation schedule, the task parameters (sampled from the task family), the inner step, and a boolean indicating whether the unroll is done. Finally, there is an accumulator holding the running sum of perturbations that PES needs to remain unbiased across truncations.

We can compute one gradient estimate as follows.

out, metrics = gradient_estimator.compute_gradient_estimate(
    worker_weights, key=key, state=gradient_estimator_state, with_summary=False)

This out object contains various outputs from the gradient estimator, including gradients with respect to the learned optimizer as well as the next state of the inner training runs.

out.grad
{'log_epsilon': DeviceArray(0.00452795, dtype=float32),
 'log_lr': DeviceArray(-0.0123316, dtype=float32),
 'one_minus_beta1': DeviceArray(0.00704127, dtype=float32),
 'one_minus_beta2': DeviceArray(0.00493074, dtype=float32)}
jax.tree_util.tree_map(lambda x: x.shape, out.unroll_state)
PESWorkerState(pos_state=TruncatedUnrollState(inner_opt_state=OptaxState(params=(3, 10), state=None, optax_opt_state=(ScaleByAdamState(count=(3,), mu=(3, 10), nu=(3, 10)), EmptyState()), iteration=(3,)), inner_step=(3,), truncation_state=ConstantTruncationState(length=(3,)), task_param=(3, 10), is_done=(3,)), neg_state=TruncatedUnrollState(inner_opt_state=OptaxState(params=(3, 10), state=None, optax_opt_state=(ScaleByAdamState(count=(3,), mu=(3, 10), nu=(3, 10)), EmptyState()), iteration=(3,)), inner_step=(3,), truncation_state=ConstantTruncationState(length=(3,)), task_param=(3, 10), is_done=(3,)), accumulator={'log_epsilon': (3,), 'log_lr': (3,), 'one_minus_beta1': (3,), 'one_minus_beta2': (3,)})

One could simply use these gradients to meta-train, passing out.unroll_state back in as the state argument of the next compute_gradient_estimate call. For example:

print("Progress on inner problem before", out.unroll_state.pos_state.inner_step)
out, metrics = gradient_estimator.compute_gradient_estimate(
    worker_weights, key=key, state=out.unroll_state, with_summary=False)
print("Progress on inner problem after", out.unroll_state.pos_state.inner_step)
Progress on inner problem before [1 8 3]
Progress on inner problem after [0 7 2]
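Putting these pieces together, a bare-bones meta-training loop might look like the following sketch. The use of optax here is illustrative; in practice the helpers in gradient_learner manage this bookkeeping:

import optax

outer_opt = optax.adam(1e-2)
outer_opt_state = outer_opt.init(theta)

for i in range(4):
  key, subkey = jax.random.split(key)
  worker_weights = gradient_learner.WorkerWeights(
      theta=theta,
      theta_model_state=None,
      outer_state=gradient_learner.OuterState(i))
  out, metrics = gradient_estimator.compute_gradient_estimate(
      worker_weights, key=subkey, state=gradient_estimator_state,
      with_summary=False)
  updates, outer_opt_state = outer_opt.update(out.grad, outer_opt_state, theta)
  theta = optax.apply_updates(theta, updates)
  gradient_estimator_state = out.unroll_state  # carry across truncations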

TruncatedGrad

TruncatedGrad performs truncated backpropagation through time. This works well for short unrolls, but can run into memory issues and/or exploding gradients for longer unrolls.
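The mechanics can be sketched in plain JAX on a toy inner system (illustrative only; TruncatedGrad does the equivalent over TruncatedStep unrolls). The key operation is stopping gradients at the truncation boundary so that backprop only sees the last few inner steps:

def truncation_loss(theta, state, steps=5):
  # Gradients flow through only these `steps` inner updates; the state
  # carried in from earlier truncations is treated as a constant.
  state = jax.lax.stop_gradient(state)
  for _ in range(steps):
    state = state - theta * state  # toy inner update
  return jnp.sum(state**2), state

grad_fn = jax.value_and_grad(truncation_loss, has_aux=True)
(loss, state), theta_grad = grad_fn(jnp.asarray(0.1), jnp.ones(3))

The library estimator applies the same idea to learned optimizer unrolls: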

truncated_step = lopt_truncated_step.VectorizedLOptTruncatedStep(
    task_family,
    lopt,
    trunc_sched,
    num_tasks=3,
    random_initial_iteration_offset=10)

gradient_estimator = truncated_grad.TruncatedGrad(
    truncated_step=truncated_step, unroll_length=5, steps_per_jit=5)
key = jax.random.PRNGKey(1)
theta = truncated_step.outer_init(key)
worker_weights = gradient_learner.WorkerWeights(
    theta=theta,
    theta_model_state=None,
    outer_state=gradient_learner.OuterState(0))
gradient_estimator_state = gradient_estimator.init_worker_state(
    worker_weights, key=key)
jax.tree_util.tree_map(lambda x: x.shape, gradient_estimator_state)
TruncatedUnrollState(inner_opt_state=OptaxState(params=(3, 10), state=None, optax_opt_state=(ScaleByAdamState(count=(3,), mu=(3, 10), nu=(3, 10)), EmptyState()), iteration=(3,)), inner_step=(3,), truncation_state=ConstantTruncationState(length=(3,)), task_param=(3, 10), is_done=(3,))
out, metrics = gradient_estimator.compute_gradient_estimate(
    worker_weights, key=key, state=gradient_estimator_state, with_summary=False)
out.grad
{'log_epsilon': DeviceArray(1.5270639e-10, dtype=float32),
 'log_lr': DeviceArray(-0.03582412, dtype=float32),
 'one_minus_beta1': DeviceArray(1.0147129e-06, dtype=float32),
 'one_minus_beta2': DeviceArray(-3.5097173e-08, dtype=float32)}