Source code for learned_optimization.outer_trainers.truncated_pes

# coding=utf-8
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""A vetorized, truncated, PES based gradient estimator."""

import functools
from typing import Any, Mapping, Optional, Sequence, Tuple

import flax
from flax import jax_utils as flax_jax_utils
import gin
import haiku as hk
import jax
from jax import lax
import jax.numpy as jnp
from learned_optimization import jax_utils
from learned_optimization import profile
from learned_optimization import summary
from learned_optimization import tree_utils
from learned_optimization.outer_trainers import common
from learned_optimization.outer_trainers import gradient_learner
from learned_optimization.outer_trainers import truncated_step as truncated_step_mod

PRNGKey = jnp.ndarray
MetaParams = Any
TruncatedUnrollState = Any


@flax.struct.dataclass
class PESWorkerState(gradient_learner.GradientEstimatorState):
  """State carried between truncations by the PES estimator.

  Attributes:
    pos_state: Unroll state for the positively perturbed inner problems.
    neg_state: Unroll state for the negatively perturbed inner problems.
    accumulator: Running sum of the perturbations applied since each task's
      last episode boundary.
  """
  pos_state: TruncatedUnrollState
  neg_state: TruncatedUnrollState
  accumulator: MetaParams


@functools.partial(jax.jit, static_argnames=("std", "sign_delta_loss_scalar"))
def compute_pes_grad(
    p_yses: Sequence[truncated_step_mod.TruncatedUnrollOut],
    n_yses: Sequence[truncated_step_mod.TruncatedUnrollOut],
    accumulator: MetaParams,
    vec_pos: MetaParams,
    std: float,
    sign_delta_loss_scalar: Optional[float] = None,
) -> Tuple[float, MetaParams, MetaParams, truncated_step_mod.TruncatedUnrollOut,
           float]:
  """Compute the PES gradient estimate from the outputs of many unrolls.

  Args:
    p_yses: Sequence of PES outputs from the positive perturbation.
    n_yses: Sequence of PES outputs from the negative perturbation.
    accumulator: Current PES accumulator from the last iteration.
    vec_pos: Positive perturbations used to compute the current unroll.
    std: Standard deviation of the perturbations used.
    sign_delta_loss_scalar: Optional; if specified, the sign of the delta loss
      multiplied by this value is used in place of the actual delta loss.

  Returns:
    loss: the mean loss.
    es_grad: the gradient estimate.
    new_accumulator: the new accumulator value.
    p_ys: the stacked unroll outputs from the positive perturbation.
    delta_losses: the per-step differences between the positive and negative
      losses.

  """

  def flat_first(x):
    return x.reshape([x.shape[0] * x.shape[1]] + list(x.shape[2:]))

  p_ys = jax.tree_util.tree_map(flat_first, tree_utils.tree_zip_jnp(p_yses))
  n_ys = jax.tree_util.tree_map(flat_first, tree_utils.tree_zip_jnp(n_yses))

  delta_losses = p_ys.loss - n_ys.loss

  if sign_delta_loss_scalar:
    # With PES there is no single loss per truncation, so for this perturbation
    # we estimate the sign per task by first averaging the masked delta losses.
    sign_per_task = jnp.sign(jnp.mean(delta_losses * p_ys.mask, axis=0))
    delta_losses = jnp.ones_like(
        delta_losses) * sign_per_task * sign_delta_loss_scalar

  # Mark, per task, the steps at or after the first episode boundary (is_done)
  # within this truncation.
  has_finished = lax.cumsum(jnp.asarray(p_ys.is_done, dtype=jnp.int32)) > 0

  # p_ys is of the form [sequence, n_tasks]
  denom = jnp.sum(p_ys.mask, axis=0)

  # Losses from steps before the episode boundary pair with the accumulator of
  # all perturbations applied so far in the episode.
  last_unroll_loss = jnp.sum(
      delta_losses * (1.0 - has_finished) * p_ys.mask, axis=0) / denom

  # Losses from steps after the boundary belong to a fresh episode and pair
  # with only the newest perturbation.
  new_unroll_loss = jnp.sum(
      delta_losses * has_finished * p_ys.mask, axis=0) / denom

  factor = 1.0 / (2 * std**2)
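  # The factor above is the standard antithetic ES scaling: each per-task
  # gradient contribution below is delta_loss * perturbation / (2 * std**2).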

  # Fold the newest perturbation into the running PES accumulator.
  accumulator = tree_utils.tree_add(vec_pos, accumulator)

  num_tasks = last_unroll_loss.shape[0]

  def reshape_to(loss, p):
    return loss.reshape((num_tasks,) + (1,) * (len(p.shape) - 1)) * factor * p

  es_grad_from_accum = jax.tree_util.tree_map(
      functools.partial(reshape_to, last_unroll_loss), accumulator)

  es_grad_from_new_perturb = jax.tree_util.tree_map(
      functools.partial(reshape_to, new_unroll_loss), vec_pos)

  vec_es_grad = jax.tree_util.tree_map(lambda a, b: a + b, es_grad_from_accum,
                                       es_grad_from_new_perturb)

  es_grad = jax.tree_util.tree_map(lambda x: jnp.mean(x, axis=0), vec_es_grad)

  def _switch_one_accum(a, b):
    shape = [num_tasks] + [1] * (len(a.shape) - 1)
    return jnp.where(jnp.reshape(has_finished[-1], shape), a, b)

  # For tasks that hit an episode boundary during this truncation, reset the
  # accumulator to just the newest perturbation; otherwise keep accumulating.
  new_accumulator = jax.tree_util.tree_map(_switch_one_accum, vec_pos,
                                           accumulator)

  pos_loss = jnp.sum(p_ys.loss * p_ys.mask, axis=0) / jnp.sum(p_ys.mask, axis=0)
  neg_loss = jnp.sum(n_ys.loss * n_ys.mask, axis=0) / jnp.sum(n_ys.mask, axis=0)

  return (
      jnp.mean((pos_loss + neg_loss) / 2.0),
      es_grad,
      new_accumulator,
      p_ys,
      delta_losses,
  )  # pytype: disable=bad-return-type

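# A minimal per-task sketch of what compute_pes_grad does (illustrative only;
# the variable names below are assumptions, not library symbols). For one task,
# one perturbation eps drawn with scale `std`, and antithetic losses
# L_pos, L_neg over a truncation:
#
#   delta_loss  = L_pos - L_neg
#   accumulator = accumulator + eps        # running sum of past perturbations
#   grad        = delta_loss / (2 * std**2) * accumulator
#
# When the inner problem hits an episode boundary (is_done), loss terms after
# the boundary are paired with only the newest perturbation, and the
# accumulator is reset to that perturbation for the next truncation.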

@gin.configurable
class TruncatedPES(gradient_learner.GradientEstimator):
  """GradientEstimator for computing PES gradient estimates.

  Persistent Evolution Strategies (PES) is a gradient estimation technique
  for computing unbiased gradients in an unrolled computation graph. It does
  this by building off of Evolution Strategies but additionally keeping a
  running buffer of all the previously used perturbations. See the paper for
  more details (http://proceedings.mlr.press/v139/vicol21a.html).

  In practice, PES is higher variance than pure truncated ES but lower bias.
  """

  def __init__(
      self,
      truncated_step: truncated_step_mod.VectorizedTruncatedStep,
      trunc_length=10,
      std=0.01,
      steps_per_jit=10,
      stack_antithetic_samples: bool = False,
      sign_delta_loss_scalar: Optional[float] = None,
  ):
    self.truncated_step = truncated_step
    self.std = std
    self.trunc_length = trunc_length
    self.steps_per_jit = steps_per_jit
    self.stack_antithetic_samples = stack_antithetic_samples
    self.sign_delta_loss_scalar = sign_delta_loss_scalar

    if self.trunc_length % self.steps_per_jit != 0:
      raise ValueError("Pass a trunc_length that is a multiple of"
                       " steps_per_jit.")

  def task_name(self) -> str:
    return self.truncated_step.task_name()

  @profile.wrap()
  def init_worker_state(self, worker_weights: gradient_learner.WorkerWeights,
                        key: PRNGKey) -> PESWorkerState:
    theta = worker_weights.theta
    pos_unroll_state = self.truncated_step.init_step_state(
        theta, worker_weights.outer_state, key, theta_is_vector=False)
    neg_unroll_state = pos_unroll_state

    accumulator = jax.tree_util.tree_map(
        lambda x: jnp.zeros([self.truncated_step.num_tasks] + list(x.shape)),
        theta)

    return PESWorkerState(
        pos_state=pos_unroll_state,
        neg_state=neg_unroll_state,
        accumulator=accumulator)

  @profile.wrap()
  def get_datas(self):
    return [
        self.truncated_step.get_batch(self.steps_per_jit)
        for _ in range(self.trunc_length // self.steps_per_jit)
    ]

  @profile.wrap()
  def compute_gradient_estimate(  # pytype: disable=signature-mismatch  # overriding-parameter-type-checks
      self,
      worker_weights: gradient_learner.WorkerWeights,
      key: PRNGKey,
      state: PESWorkerState,
      with_summary: bool = False,
      datas_list: Optional[Sequence[Any]] = None,
  ) -> Tuple[gradient_learner.GradientEstimatorOut, Mapping[str, jnp.ndarray]]:
    p_state = state.pos_state
    n_state = state.neg_state
    accumulator = state.accumulator
    rng = hk.PRNGSequence(key)

    theta = worker_weights.theta

    vec_pos, vec_p_theta, vec_n_theta = common.vector_sample_perturbations(
        theta, next(rng), self.std, self.truncated_step.num_tasks)

    p_yses = []
    n_yses = []
    metrics = []

    # TODO(lmetz) consider switching this to be a jax.lax.scan when inside jit.
    for i in range(self.trunc_length // self.steps_per_jit):
      if datas_list is None:
        if jax_utils.in_jit():
          raise ValueError("Must pass data in when using a jit gradient est.")
        datas = self.truncated_step.get_batch(self.steps_per_jit)
      else:
        datas = datas_list[i]

      # Force all state leaves to be non-weak types. This is for cache hit
      # reasons.
      # TODO(lmetz) consider instead just setting the weak type flag?
      p_state = jax.tree_util.tree_map(
          lambda x: jnp.asarray(x, dtype=x.dtype), p_state)
      n_state = jax.tree_util.tree_map(
          lambda x: jnp.asarray(x, dtype=x.dtype), n_state)

      key = next(rng)

      p_state, n_state, p_ys, n_ys, m = common.maybe_stacked_es_unroll(
          self.truncated_step,
          self.steps_per_jit,
          self.stack_antithetic_samples,
          vec_p_theta,
          vec_n_theta,
          p_state,
          n_state,
          key,
          datas,
          worker_weights.outer_state,
          with_summary=with_summary,
          sample_rng_key=next(rng))

      metrics.append(m)
      p_yses.append(p_ys)
      n_yses.append(n_ys)

    loss, es_grad, new_accumulator, p_ys, delta_loss = compute_pes_grad(
        p_yses,
        n_yses,
        accumulator,
        vec_pos,
        self.std,
        sign_delta_loss_scalar=self.sign_delta_loss_scalar)

    unroll_info = gradient_learner.UnrollInfo(
        loss=p_ys.loss,
        iteration=p_ys.iteration,
        task_param=p_ys.task_param,
        is_done=p_ys.is_done)

    output = gradient_learner.GradientEstimatorOut(
        mean_loss=loss,
        grad=es_grad,
        unroll_state=PESWorkerState(p_state, n_state, new_accumulator),
        unroll_info=unroll_info)

    metrics = summary.aggregate_metric_list(
        metrics, use_jnp=jax_utils.in_jit(), key=next(rng))
    if with_summary:
      metrics["sample||delta_loss_sample"] = summary.sample_value(
          key, jnp.abs(delta_loss))
      metrics["mean||delta_loss_mean"] = jnp.abs(delta_loss)
      if hasattr(p_state, "inner_step"):
        metrics["sample||inner_step"] = p_state.inner_step[0]
        metrics["sample||end_inner_step"] = p_state.inner_step[0]

    return output, metrics
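
# A rough usage sketch for TruncatedPES (illustrative only; constructing a
# concrete VectorizedTruncatedStep is outside this module, and
# `my_truncated_step` / `worker_weights` below are placeholders):
#
#   estimator = TruncatedPES(
#       truncated_step=my_truncated_step, trunc_length=10, steps_per_jit=10)
#   key = jax.random.PRNGKey(0)
#   state = estimator.init_worker_state(worker_weights, key)
#   out, metrics = estimator.compute_gradient_estimate(
#       worker_weights, key, state)
#   # out.grad is the PES gradient estimate; out.unroll_state carries the
#   # positive/negative unroll states and the accumulator to the next
#   # truncation.
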
@functools.partial(jax.pmap, axis_name="dev")
def _pmap_reduce(vals):
  return jax.lax.pmean(vals, axis_name="dev")


@jax.pmap
def vec_key_split(key):
  key1, key2 = jax.random.split(key)
  return key1, key2


@gin.configurable
class TruncatedPESPMAP(TruncatedPES):
  """GradientEstimator for computing PES gradient estimates leveraging pmap.

  See TruncatedPES documentation for information on PES. This estimator
  additionally makes use of multiple TPU devices via jax's pmap.
  """

  def __init__(self,
               *args,
               num_devices=8,
               replicate_data_across_devices=False,
               **kwargs):
    super().__init__(*args, **kwargs)
    self.num_devices = num_devices
    self.replicate_data_across_devices = replicate_data_across_devices

    if len(jax.local_devices()) != self.num_devices:
      raise ValueError("Mismatch in device count!"
                       f" Found: {jax.local_devices()}."
                       f" Expected {num_devices} devices.")

    self.pmap_init_step_state = jax.pmap(
        self.truncated_step.init_step_state, in_axes=(None, None, 0))

    self.pmap_compute_pes_grad = jax.pmap(
        functools.partial(compute_pes_grad, std=self.std))

    self.pmap_vector_sample_perturbations = jax.pmap(
        functools.partial(
            common.vector_sample_perturbations,
            std=self.std,
            num_samples=self.truncated_step.num_tasks),
        in_axes=(None, 0),
    )

  @functools.partial(
      jax.pmap,
      in_axes=(None, 0, 0, 0, 0, None, None),
      static_broadcasted_argnums=(
          0,
          6,
      ))
  def pmap_unroll_next_state(self, vec_theta, key, state, datas, outer_state,
                             with_summary):
    theta_is_vector = True
    key1, key2 = jax.random.split(key)
    override_num_steps = None
    (p_state, p_ys), m = common.truncated_unroll(  # pylint: disable=unbalanced-tuple-unpacking
        self.truncated_step,
        self.steps_per_jit,
        theta_is_vector,
        vec_theta,
        key1,
        state,
        datas,
        outer_state,
        override_num_steps,
        with_summary=with_summary,
        sample_rng_key=key2)
    return (p_state, p_ys), m

  @profile.wrap()
  def init_worker_state(self, worker_weights: gradient_learner.WorkerWeights,
                        key: PRNGKey) -> PESWorkerState:
    theta = worker_weights.theta
    keys = jax.random.split(key, self.num_devices)
    # Note this doesn't use sampled theta for the first init.
    # I believe this is fine most of the time.
    # TODO(lmetz) consider init-ing at an is_done state instead.
    pos_unroll_state = self.pmap_init_step_state(worker_weights.theta,
                                                 worker_weights.outer_state,
                                                 keys)
    neg_unroll_state = pos_unroll_state

    accumulator = jax.tree_util.tree_map(
        lambda x: jnp.zeros([self.truncated_step.num_tasks] + list(x.shape)),
        theta)
    accumulator = flax_jax_utils.replicate(accumulator)

    return PESWorkerState(
        pos_state=pos_unroll_state,
        neg_state=neg_unroll_state,
        accumulator=accumulator)

  @profile.wrap()
  def compute_gradient_estimate(
      self,
      worker_weights: gradient_learner.WorkerWeights,
      key: PRNGKey,
      state: PESWorkerState,
      with_summary: bool = False
  ) -> Tuple[gradient_learner.GradientEstimatorOut, Mapping[str, jnp.ndarray]]:
    p_state = state.pos_state
    n_state = state.neg_state
    accumulator = state.accumulator
    vec_key = jax.random.split(key, self.num_devices)

    theta = worker_weights.theta

    vec_key1, vec_key = vec_key_split(vec_key)
    vec_pos, vec_p_theta, vec_n_theta = self.pmap_vector_sample_perturbations(
        theta, vec_key1)

    p_yses = []
    n_yses = []
    metrics = []

    def get_batch():
      """Get a batch with leading dims [num_devices, steps_per_jit, num_tasks]."""
      if self.replicate_data_across_devices:
        b = self.truncated_step.get_batch(self.steps_per_jit)
        return flax_jax_utils.replicate(b)
      else:
        # Use different data across the devices.
        batches = [
            self.truncated_step.get_batch(self.steps_per_jit)
            for _ in range(self.num_devices)
        ]
        return tree_utils.tree_zip_onp(batches)

    for _ in range(self.trunc_length // self.steps_per_jit):
      datas = get_batch()

      # Force all state leaves to be non-weak types. This is for cache hit
      # reasons.
      # TODO(lmetz) consider instead just setting the weak type flag?
      p_state = jax.tree_util.tree_map(
          lambda x: jnp.asarray(x, dtype=x.dtype), p_state)
      n_state = jax.tree_util.tree_map(
          lambda x: jnp.asarray(x, dtype=x.dtype), n_state)

      vec_key1, vec_key = vec_key_split(vec_key)
      (p_state, p_ys), m = self.pmap_unroll_next_state(  # pylint: disable=unbalanced-tuple-unpacking
          vec_p_theta, vec_key, p_state, datas, worker_weights.outer_state,
          with_summary)
      metrics.append(m)
      p_yses.append(p_ys)

      (n_state, n_ys), _ = self.pmap_unroll_next_state(  # pylint: disable=unbalanced-tuple-unpacking
          vec_n_theta, vec_key, n_state, datas, worker_weights.outer_state,
          False)
      n_yses.append(n_ys)

    loss, es_grad, new_accumulator, p_ys, delta_loss = self.pmap_compute_pes_grad(
        p_yses, n_yses, accumulator, vec_pos)

    es_grad, loss = flax_jax_utils.unreplicate(_pmap_reduce((es_grad, loss)))

    unroll_info = gradient_learner.UnrollInfo(
        loss=p_ys.loss,
        iteration=p_ys.iteration,
        task_param=p_ys.task_param,
        is_done=p_ys.is_done)

    output = gradient_learner.GradientEstimatorOut(
        mean_loss=loss,
        grad=es_grad,
        unroll_state=PESWorkerState(p_state, n_state, new_accumulator),
        unroll_info=unroll_info)

    metrics = summary.aggregate_metric_list(metrics)
    if with_summary:
      metrics["sample||delta_loss_sample"] = summary.sample_value(
          key, jnp.abs(delta_loss))
      metrics["mean||delta_loss_mean"] = jnp.abs(delta_loss)
      if hasattr(p_state, "inner_step"):
        metrics["sample||inner_step"] = p_state.inner_step[0]
        metrics["sample||end_inner_step"] = p_state.inner_step[0]

    return output, metrics
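
# TruncatedPESPMAP is constructed like TruncatedPES, plus a device count that
# must match jax.local_devices() (sketch only; `my_truncated_step` is a
# placeholder):
#
#   estimator = TruncatedPESPMAP(
#       truncated_step=my_truncated_step,
#       num_devices=len(jax.local_devices()),
#       replicate_data_across_devices=False)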