# coding=utf-8
# Copyright 2021 Google LLC
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Train a learned optimizer with gradients."""
import abc
import functools
import time
from typing import Any, Mapping, Optional, Sequence, Tuple, Union

from absl import logging
import chex
import flax
import gin
import jax
from jax import core
import jax.numpy as jnp
from learned_optimization import checkpoints
from learned_optimization import profile
from learned_optimization import summary
from learned_optimization import tree_utils
from learned_optimization.learned_optimizers import base as lopt_base
from learned_optimization.optimizers import base as opt_base
from learned_optimization.outer_trainers import truncated_step as truncated_step_mod
import numpy as onp
from typing_extensions import Protocol

# Alias used in signatures for a jax random key (typed as a plain jnp array).
PRNGKey = jnp.ndarray
# Meta-parameters (theta) of the learned optimizer; no structure is imposed.
MetaParams = Any
# Model state tied to theta, as returned by the outer optimizer's `get_state`.
ThetaModelState = Any

@flax.struct.dataclass
class GradientLearnerState:
  """State carried by `GradientLearner` from one outer-update to the next.

  Declared as a flax struct dataclass because this file builds instances from
  their field (`GradientLearnerState(theta_opt_state)`), which a bare
  annotated class does not support.
  """
  # State of the outer optimizer; `GradientLearner` reads the meta-parameters
  # and model state back out of it via `theta_opt.get_params` / `get_state`.
  theta_opt_state: Any

@flax.struct.dataclass
class OuterState:
  """Outer-training progress information that is shipped to workers.

  Stored in `WorkerWeights.outer_state`; `gradient_worker_compute` copies
  `outer_iteration` into the event info it records.
  """
  # Current outer-training iteration.
  outer_iteration: jnp.ndarray

@flax.struct.dataclass
class WorkerWeights:
  """Everything a worker needs from the learner to compute gradients.

  Produced by `GradientLearner.get_state_for_worker` and consumed by
  `gradient_worker_compute` and the gradient estimators. Declared as a flax
  struct dataclass so it can be constructed from its fields.
  """
  # Current meta-parameters of the learned optimizer.
  theta: MetaParams
  # Model state associated with theta.
  theta_model_state: Any
  # Outer-training progress; `gradient_worker_compute` reads
  # `outer_state.outer_iteration` when it records event info.
  outer_state: Optional[OuterState]

@flax.struct.dataclass
class AggregatedGradient:
  """Gradient signal a worker sends back to the learner.

  Built by `gradient_worker_compute` and consumed (as a list) by
  `GradientLearner.update`. It must be a flax struct dataclass: instances are
  constructed from keyword fields and mapped over with
  `jax.tree_util.tree_map`, which requires a registered pytree.
  """
  # Meta-gradient, averaged over the worker's gradient estimators.
  theta_grads: Any
  # Model state associated with theta, passed through from the worker weights.
  theta_model_state: Any
  # Mean loss over the worker's gradient estimators.
  mean_loss: jnp.ndarray

@flax.struct.dataclass
class WorkerComputeOut:
  """Full result of one `gradient_worker_compute` call.

  Declared as a flax struct dataclass so it can be constructed from keyword
  fields, as `gradient_worker_compute` does.
  """
  # The part that is handed to `GradientLearner.update`.
  to_put: AggregatedGradient
  # Next state of each gradient estimator, in estimator order.
  unroll_states: Any
  # Aggregated metrics from all estimators.
  metrics: Mapping[str, float]
  # One sampled record per estimator that returned unroll info.
  event_info: Any

class GradientEstimatorState:
  """Base type for the per-worker state of a `GradientEstimator`.

  Carries no fields of its own; it is the type returned by
  `GradientEstimator.init_worker_state` and threaded through
  `compute_gradient_estimate`.
  """
  # NOTE(review): sibling records in this file are flax struct dataclasses;
  # confirm whether this base class should be decorated as well.

@flax.struct.dataclass
class UnrollInfo:
  """Per-step information about an unroll, returned by gradient estimators.

  `gradient_worker_compute` samples one index along the leading axis of these
  fields to build its event info. Declared as a flax struct dataclass for
  consistency with the other records in this file.
  """
  # Losses along the unroll; indexed as `loss[idx, :]` by the worker.
  loss: jnp.ndarray
  # Inner iteration for each entry (may be None).
  iteration: jnp.ndarray
  # Task parameters for each entry (a pytree; leaves may be None).
  task_param: jnp.ndarray
  # NOTE(review): presumably marks entries where the inner problem finished;
  # it is not read in this file -- confirm against the estimators.
  is_done: jnp.ndarray

@flax.struct.dataclass
class GradientEstimatorOut:
  """Result of `GradientEstimator.compute_gradient_estimate`.

  Declared as a flax struct dataclass for consistency with the other records
  in this file.
  """
  # Mean loss of the estimate.
  mean_loss: jnp.ndarray
  # Gradient estimate with respect to the meta-parameters.
  grad: Any
  # Next state of the estimator, fed into its following call.
  unroll_state: GradientEstimatorState
  # Optional per-step information; when falsy no event info is recorded.
  unroll_info: Optional[UnrollInfo]

@flax.struct.dataclass
class ParameterCheckpoint:
  """State that we write out to disk for using the optimizer.

  Declared as a flax struct dataclass: `GradientLearner.init` builds a
  placeholder instance from keyword fields and passes it to
  `checkpoints.load_state` as the target structure.
  """
  # Meta-parameters of the learned optimizer.
  params: lopt_base.MetaParams
  # Identifier string stored alongside the parameters.
  gen_id: str
  # Step stored alongside the parameters.
  step: int

@flax.struct.dataclass
class OptCheckpoint:
  """State that we write out to disk for training the optimizer.

  Declared as a flax struct dataclass: `GradientLearner.init` builds a
  placeholder instance and passes it to `checkpoints.load_state` as the target
  structure.
  """
  # Full learner state (outer optimizer state).
  gradient_learner_state: GradientLearnerState
  # Elapsed training time stored with the checkpoint.
  elapsed_time: Union[float, jnp.ndarray]
  # Count of inner steps stored with the checkpoint.
  total_inner_steps: int

class MetaInitializer(Protocol):
  """Protocol for objects which contain a jax init function."""

  def init(self, key: chex.PRNGKey) -> MetaParams:
    """Build initial meta-parameters from the jax rng `key`."""
    ...

def _tree_mean(stack):
  return jax.tree_util.tree_map(lambda x: jnp.mean(x, axis=0), stack)

def _tree_mean_onp(stack):
  return jax.tree_util.tree_map(lambda x: onp.mean(x, axis=0), stack)

@functools.lru_cache(maxsize=None)
def _get_theta_update_fn(theta_opt: opt_base.Optimizer):
  """Return the jitted, summary-aware update function for `theta_opt`.

  The result is memoised per optimizer instance. `GradientLearner.update`
  calls this on every outer step; without the cache each call would wrap a
  brand-new closure in `jax.jit`, so no previously jitted function could be
  reused.

  Args:
    theta_opt: The outer optimizer. It must be hashable to be used as a cache
      key.

  Returns:
    A jitted function wrapping `theta_opt.update` via
    `summary.add_with_summary`, with `with_summary` as a static argument.
  """

  def update(theta_opt_state, grads, loss, key, model_state):
    with summary.summary_scope("outer_opt"):
      return theta_opt.update(
          theta_opt_state, grads, loss=loss, key=key, model_state=model_state)

  fn = summary.add_with_summary(update)
  return jax.jit(fn, static_argnames=("with_summary",))

class GradientLearner:
  """Learner is responsible for training the weights of the learned opt."""

  def __init__(
      self,
      meta_init: MetaInitializer,
      theta_opt: opt_base.Optimizer,
      init_theta_from_path: Optional[str] = None,
      init_outer_state_from_path: Optional[str] = None,
      reset_outer_iteration: bool = False,
      num_steps: Optional[int] = None,
      init_seed: Optional[int] = None,
  ):
    """Initializer.

    Args:
      meta_init: Object whose `init(key)` builds the initial meta-parameters.
      theta_opt: Optimizer applied to the meta-parameters (theta).
      init_theta_from_path: If set, `init` loads the meta-parameters from a
        `ParameterCheckpoint` at this path instead of using the random init.
      init_outer_state_from_path: If set, `init` loads the outer optimizer
        state from an `OptCheckpoint` at this path.
      reset_outer_iteration: When outer state is loaded from a path, reset its
        `iteration` field to 0.
      num_steps: Forwarded to `theta_opt.init` as `num_steps`.
      init_seed: If not None, `init` ignores the key it is given and uses
        `jax.random.PRNGKey(init_seed)` instead.
    """
    self._theta_opt = theta_opt
    self._meta_init = meta_init
    self._init_theta_from_path = init_theta_from_path
    self._init_outer_state_from_path = init_outer_state_from_path
    self._reset_outer_iteration = reset_outer_iteration
    self._num_steps = num_steps
    self._init_seed = init_seed

  def get_meta_params(self, state: GradientLearnerState) -> MetaParams:
    """Extract the meta-parameters from the outer optimizer state."""
    return self._theta_opt.get_params(state.theta_opt_state)

  def get_meta_model_state(self,
                           state: GradientLearnerState) -> ThetaModelState:
    """Extract the model state from the outer optimizer state."""
    return self._theta_opt.get_state(state.theta_opt_state)

  def get_state_for_worker(self, state: GradientLearnerState) -> WorkerWeights:
    """Package what a worker needs to compute gradients."""
    # NOTE(review): the outer iteration is taken from the optimizer state's
    # `iteration` field (the same field `init` resets) -- confirm.
    return WorkerWeights(
        theta=self.get_meta_params(state),
        theta_model_state=self.get_meta_model_state(state),
        outer_state=OuterState(
            outer_iteration=state.theta_opt_state.iteration))

  def init(self, key: PRNGKey) -> GradientLearnerState:
    """Initial state of the GradientLearner.

    This can be constructed from a random distribution, or loaded from a path.
    `meta_init.init` is always called, even when loading, because the loaded
    checkpoint needs a target structure to be restored into.

    Args:
      key: jax rng key

    Returns:
      gradient_learner_state: A new initial state of the gradient learner.
    """
    if self._init_seed is not None:
      key = jax.random.PRNGKey(self._init_seed)

    theta_init = self._meta_init.init(key)
    # TODO(lmetz) hook up model state for learned optimizers
    model_state = None

    if self._init_theta_from_path:
      logging.info(  # pylint: disable=logging-fstring-interpolation
          f"Got a init from params path {self._init_theta_from_path}."
          " Using this instead of random initialization.")

      # To load a checkpoint, the state of the target object must be specified,
      # so we pass fake values here.
      fake_param_checkpoint = ParameterCheckpoint(
          params=theta_init, gen_id="", step=0)
      real_param_checkpoint = checkpoints.load_state(self._init_theta_from_path,
                                                     fake_param_checkpoint)
      theta_init = real_param_checkpoint.params

    theta_opt_state = self._theta_opt.init(
        theta_init, model_state, num_steps=self._num_steps)

    if self._init_outer_state_from_path:
      logging.info(  # pylint: disable=logging-fstring-interpolation
          f"Got a init from outer state path {self._init_outer_state_from_path}."
          " Using this instead of randomly initializing.")
      # As above: a placeholder with the right structure to load into.
      fake_checkpoint = OptCheckpoint(
          gradient_learner_state=GradientLearnerState(theta_opt_state),
          elapsed_time=0.,
          total_inner_steps=0)
      real_checkpoint = checkpoints.load_state(self._init_outer_state_from_path,
                                               fake_checkpoint)
      theta_opt_state = real_checkpoint.gradient_learner_state.theta_opt_state
      if self._reset_outer_iteration:
        theta_opt_state = theta_opt_state.replace(iteration=0)

    return GradientLearnerState(theta_opt_state)

  def update(
      self,
      state: GradientLearnerState,
      grads_list: Sequence[AggregatedGradient],
      with_metrics: bool = False,
      key: Optional[PRNGKey] = None
  ) -> Tuple[GradientLearnerState, Mapping[str, float]]:
    """Update the state of the outer-trainer using grads_list.

    This performs one outer weight update by aggregating the gradients in
    `grads_list` (mean over the list) and applying them with the outer
    optimizer.

    Args:
      state: The state of the outer-trainer.
      grads_list: A list of gradients to be aggregated and applied.
      with_metrics: To compute metrics, or not.
      key: Jax PRNGKey. Despite the `None` default a key must be supplied: it
        is passed to `jax.random.split` unconditionally.

    Returns:
      next_state: The next outer-training state.
      metrics: The computed metrics from this update.
    """
    metrics = {}
    theta_opt_state = state.theta_opt_state

    with profile.Profile("stack_grad"):
      grads_stack = tree_utils.tree_zip_onp([t.theta_grads for t in grads_list])
    with profile.Profile("mean_grad"):
      grads = _tree_mean_onp(grads_stack)

    with profile.Profile("stack_state"):
      model_state_stack = tree_utils.tree_zip_onp(
          [t.theta_model_state for t in grads_list])
      next_model_state = _tree_mean_onp(model_state_stack)

    with profile.Profile("stack_loss"):
      losses = jnp.asarray([t.mean_loss for t in grads_list])
      mean_loss = jnp.mean(losses)
      min_loss = jnp.min(losses)

    fn = _get_theta_update_fn(self._theta_opt)
    # key1 drives the optimizer update itself; key2 is handed to the summary
    # wrapper. NOTE(review): `sample_rng_key` is assumed to be the keyword the
    # `summary.add_with_summary` wrapper accepts -- confirm.
    key1, key2 = jax.random.split(key)
    theta_opt_state, theta_update_metrics = fn(
        theta_opt_state,
        grads,
        mean_loss,
        key=key1,
        model_state=next_model_state,
        with_summary=with_metrics,
        sample_rng_key=key2)
    metrics = summary.aggregate_metric_list([metrics, theta_update_metrics])

    # Create fast summaries for all steps, and slower summaries occasionally
    metrics["none||mean_loss"] = mean_loss
    metrics["none||best_of_mean_loss"] = min_loss

    if with_metrics:
      metrics["none||theta_grad_norm"] = tree_utils.tree_norm(grads)
      metrics["none||theta_grad_abs_mean"] = tree_utils.tree_mean_abs(grads)

    return GradientLearnerState(theta_opt_state), metrics  # pytype: disable=bad-return-type  # jax-ndarray

class GradientEstimator(abc.ABC):
  """Base class for classes which estimate grads (via ES, PES, or backprop).

  Subclasses override `init_worker_state` and `compute_gradient_estimate`;
  the base implementations raise `NotImplementedError`.
  """

  truncated_step: truncated_step_mod.TruncatedStep

  def init_worker_state(self, worker_weights: WorkerWeights,
                        key: PRNGKey) -> GradientEstimatorState:
    """Build the initial estimator state for the given worker weights."""
    raise NotImplementedError()

  def compute_gradient_estimate(
      self,
      worker_weights: WorkerWeights,
      key: PRNGKey,
      state: GradientEstimatorState,
      with_summary: Optional[bool],
      datas_list: Any = None,
  ) -> Tuple[GradientEstimatorOut, Mapping[str, jnp.ndarray]]:
    """Compute one gradient estimate.

    Returns:
      A `GradientEstimatorOut` and a mapping of metrics.
    """
    raise NotImplementedError()

  def task_name(self):
    """Name used by `gradient_worker_compute` to group per-task metrics."""
    return "default_task"

  def cfg_name(self):
    """Name used by `gradient_worker_compute` to group per-config metrics."""
    return "default_cfg"

  def get_datas(self) -> Any:
    """Return data for this estimator; not implemented in the base class."""
    raise NotImplementedError()
@functools.partial(jax.jit, donate_argnums=(0,)) def _jit_nan_to_num(vals, replace): return jax.tree_util.tree_map( functools.partial( jnp.nan_to_num, nan=replace, posinf=replace, neginf=replace), vals) def _nan_to_num(vals, replace, use_jnp=False): if use_jnp: return _jit_nan_to_num(vals, replace) else: return jax.tree_util.tree_map(onp.nan_to_num, vals) def _tree_zeros_on_device(shapes, device): leaves, treedef = jax.tree_util.tree_flatten(shapes) return jax.tree_util.tree_unflatten( treedef, _tree_zeros_on_device_inner(tuple(leaves), device)) @functools.partial(jax.jit, static_argnums=(0, 1)) def _tree_zeros_on_device_inner(shapes, device): zero_val = lambda x: jax.device_put(jnp.asarray(0, dtype=x.dtype), device) return jax.tree_util.tree_map(lambda x: jnp.full(x.shape, zero_val(x)), shapes)
@gin.configurable
@profile.wrap()
def gradient_worker_compute(
    worker_weights: WorkerWeights,
    gradient_estimators: Sequence[GradientEstimator],
    unroll_states: Sequence[GradientEstimatorState],
    key: PRNGKey,
    with_metrics: bool,
    clip_nan_loss_to_value: Optional[float] = 20.0,
    extra_metrics: bool = True,
    device: Optional[jax.Device] = None,
) -> WorkerComputeOut:
  """Compute a gradient signal to meta-train with.

  This function performs unrolls for each of the unroll_states with the
  corresponding gradient_estimator. The results from each of the gradient
  estimators get merged into a single gradient. This aggregation is done to
  save bandwidth when collecting gradients from workers.

  Args:
    worker_weights: Weights created by the GradientLearner and represent the
      current parameters and model state of the learned optimizer.
    gradient_estimators: The gradient estimators used to update the unroll
      state
    unroll_states: state of the gradient estimator (e.g. inner problem weights)
    key: jax rng
    with_metrics: compute with summary metrics or not
    clip_nan_loss_to_value: float, value to set nan losses to. A falsy value
      (None or 0.0) disables the replacement.
    extra_metrics: log out additional metrics.
    device: The jax device to run the computation on

  Returns:
    worker_compute_out: The results of the computation.
      This contains a gradient estimate, the next unroll states, metrics.
      A subset of which get passed to the GradientLearner.
  """
  if device is None:
    device = jax.local_devices(0)[0]

  theta = worker_weights.theta
  theta_model_state = worker_weights.theta_model_state

  theta_shape = jax.tree_util.tree_map(
      lambda x: core.ShapedArray(x.shape, x.dtype), theta)
  grads_accum = _tree_zeros_on_device(theta_shape, device)

  metrics_list = []
  unroll_states_out = []
  losses = []
  event_info = []

  assert len(gradient_estimators) == len(unroll_states)

  for si, (estimator, unroll_state) in enumerate(
      zip(gradient_estimators, unroll_states)):
    with profile.Profile(f"estimator{si}"):
      stime = time.time()
      key, rng = jax.random.split(key)

      logging.info(
          "compute_gradient_estimate for estimator name %s and cfg name %s",
          estimator.task_name(), estimator.cfg_name())
      with profile.Profile(f"unroll__metrics{with_metrics}"):
        estimator_out, metrics = estimator.compute_gradient_estimate(
            worker_weights, rng, unroll_state, with_summary=with_metrics)

      unroll_states_out.append(estimator_out.unroll_state)
      losses.append(estimator_out.mean_loss)
      with profile.Profile("tree_add"):
        grads_accum = tree_utils.tree_add(grads_accum, estimator_out.grad)

      # grab a random iteration from the trajectory
      if estimator_out.unroll_info:
        idx = onp.random.randint(0, len(estimator_out.unroll_info.loss))

        def extract_one(idx, x):
          return x[idx] if x is not None else None

        fn = functools.partial(extract_one, idx)
        onp_task_params = jax.tree_util.tree_map(
            onp.asarray, estimator_out.unroll_info.task_param)
        iteration = estimator_out.unroll_info.iteration[
            idx] if estimator_out.unroll_info.iteration is not None else None
        event_info.append({
            "loss": estimator_out.unroll_info.loss[idx, :],
            "task_param": jax.tree_util.tree_map(fn, onp_task_params),
            "iteration": iteration,
            "outer_iteration": worker_weights.outer_state.outer_iteration,
        })
      else:
        logging.warning("No out specified by learner. "
                        "Not logging any events data.")

      # Work on a copy so the estimator's own mapping is never mutated.
      metrics = dict(metrics)
      if extra_metrics:
        family_name = estimator.task_name()
        cfg_name = estimator.cfg_name()
        if with_metrics:
          # Metrics don't take into account which task they are coming from.
          # Let's add additional metrics with the task name pulled out.
          with profile.Profile("metric_computation"):
            keys = list(metrics.keys())
            for k in keys:
              v = metrics[k]
              assert "||" in k, f"bad metric format? Got: {k}"
              agg, name = k.split("||")
              metrics[f"{agg}||{family_name}/{name}"] = v
              metrics[f"{agg}||{cfg_name}/{name}"] = v

            mean_abs = tree_utils.tree_mean_abs(estimator_out.grad)
            metrics[f"mean||{family_name}/grad_mean_abs"] = mean_abs
            metrics[f"mean||{cfg_name}/grad_mean_abs"] = mean_abs

            norm = tree_utils.tree_norm(estimator_out.grad)
            metrics[f"mean||{family_name}/grad_norm"] = norm
            metrics[f"mean||{cfg_name}/grad_norm"] = norm

        metrics[f"mean||{family_name}/mean_loss"] = estimator_out.mean_loss
        metrics[f"mean||{cfg_name}/mean_loss"] = estimator_out.mean_loss

        elapsed = time.time() - stime
        metrics[f"sample||{family_name}/time"] = elapsed
        metrics[f"sample||{cfg_name}/time"] = elapsed

      metrics_list.append(metrics)

  with profile.Profile("mean_grads"):
    grads_accum = tree_utils.tree_div(grads_accum, len(gradient_estimators))
    mean_loss = jnp.mean(jnp.asarray(losses))

  # block here to better account for costs with profile profiling.
  with profile.Profile("blocking"):
    stime = time.time()
    mean_loss.block_until_ready()
    block_time = time.time() - stime

  with profile.Profile("summary_aggregation"):
    metrics = summary.aggregate_metric_list(metrics_list)
  metrics["mean||block_time"] = block_time

  with profile.Profile("strip_nan"):
    # this should ideally never be NAN
    # TODO(lmetz) check if we need these checks.
    grads_accum = _nan_to_num(grads_accum, 0.0, use_jnp=True)
    if clip_nan_loss_to_value:
      mean_loss = _nan_to_num(mean_loss, clip_nan_loss_to_value, use_jnp=True)

  with profile.Profile("grads_to_onp"):
    to_put = AggregatedGradient(
        theta_grads=grads_accum,
        theta_model_state=theta_model_state,
        mean_loss=mean_loss)

    return WorkerComputeOut(
        to_put=jax.tree_util.tree_map(onp.asarray, to_put),
        unroll_states=unroll_states_out,
        metrics=metrics,
        event_info=event_info)
@flax.struct.dataclass
class SingleMachineState:
  """Combined state used by `SingleMachineGradientLearner`."""
  # State of the underlying `GradientLearner`.
  gradient_learner_state: GradientLearnerState
  # One state per gradient estimator, in estimator order.
  gradient_estimator_states: Sequence[GradientEstimatorState]
class SingleMachineGradientLearner:
  """Train with gradient estimators on a single machine.

  This is a convenience wrapper calling the multi-worker interface -- namely
  both `GradientLearner` and `gradient_worker_compute`.
  """

  def __init__(self,
               meta_init: MetaInitializer,
               gradient_estimators: Sequence[GradientEstimator],
               theta_opt: opt_base.Optimizer,
               num_steps: Optional[int] = None):
    """Initializer.

    Args:
      meta_init: Class containing an init function to construct outer params.
      gradient_estimators: Sequence of gradient estimators used to calculate
        gradients.
      theta_opt: The optimizer used to train the weights of the learned opt.
      num_steps: Number of meta-training steps used by optimizer for schedules.
    """
    self.gradient_learner = GradientLearner(
        meta_init, theta_opt, num_steps=num_steps)
    self.gradient_estimators = gradient_estimators

  def init(self, key: PRNGKey) -> SingleMachineState:
    """Initial state.

    This initializes the learned optimizer weights randomly, and sets up
    optimizer variables for these weights. Additionally the first state of the
    gradient estimators is also initialized.

    Args:
      key: jax rng

    Returns:
      The initial state
    """
    learner_key, estimators_key = jax.random.split(key)
    theta_state = self.gradient_learner.init(learner_key)
    worker_weights = self.gradient_learner.get_state_for_worker(theta_state)

    # One independent key per gradient estimator.
    per_estimator_keys = jax.random.split(estimators_key,
                                          len(self.gradient_estimators))
    unroll_states = []
    for est_key, grad_est in zip(per_estimator_keys, self.gradient_estimators):
      unroll_states.append(grad_est.init_worker_state(worker_weights, est_key))

    return SingleMachineState(
        gradient_learner_state=theta_state,
        gradient_estimator_states=unroll_states)

  def update(
      self,
      state,
      key: PRNGKey,
      with_metrics: Optional[bool] = False
  ) -> Tuple[SingleMachineState, jnp.ndarray, Mapping[str, jnp.ndarray]]:
    """Perform one outer-update to train the learned optimizer.

    Args:
      state: State of this class
      key: jax rng
      with_metrics: To compute metrics or not

    Returns:
      state: The next state from this class
      loss: loss from the current iteration
      metrics: dictionary of metrics computed
    """
    compute_key, learner_key = jax.random.split(key)

    worker_weights = self.gradient_learner.get_state_for_worker(
        state.gradient_learner_state)
    worker_out = gradient_worker_compute(
        worker_weights,
        self.gradient_estimators,
        state.gradient_estimator_states,
        key=compute_key,
        with_metrics=with_metrics)

    next_theta_state, learner_metrics = self.gradient_learner.update(
        state.gradient_learner_state, [worker_out.to_put],
        key=learner_key,
        with_metrics=with_metrics)

    merged_metrics = summary.aggregate_metric_list(
        [worker_out.metrics, learner_metrics])

    next_state = SingleMachineState(
        gradient_learner_state=next_theta_state,
        gradient_estimator_states=worker_out.unroll_states)
    return (next_state, worker_out.to_put.mean_loss, merged_metrics)

  def get_meta_params(self, state: SingleMachineState) -> lopt_base.MetaParams:
    """Get the weights of the learned optimizer."""
    return self.gradient_learner.get_meta_params(state.gradient_learner_state)