# coding=utf-8
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Train a learned optimizer with gradients."""
import abc
import functools
import time
from typing import Any, Mapping, Optional, Sequence, Tuple, Union
from absl import logging
import chex
import flax
import gin
import jax
from jax import core
import jax.numpy as jnp
from learned_optimization import checkpoints
from learned_optimization import profile
from learned_optimization import summary
from learned_optimization import tree_utils
from learned_optimization.learned_optimizers import base as lopt_base
from learned_optimization.optimizers import base as opt_base
from learned_optimization.outer_trainers import truncated_step as truncated_step_mod
import numpy as onp
from typing_extensions import Protocol
# Alias used in annotations for jax rng keys passed around in this module.
PRNGKey = jnp.ndarray
# Alias for the meta-parameters ("theta") being trained; intentionally untyped.
MetaParams = Any
# Alias for the model state that accompanies the meta-parameters; untyped.
ThetaModelState = Any
@flax.struct.dataclass
class GradientLearnerState:
  """State carried by `GradientLearner` from one outer update to the next."""
  # State of the optimizer (`theta_opt`) that is applied to the
  # meta-parameters; created by `theta_opt.init` or loaded from a checkpoint.
  theta_opt_state: Any
@flax.struct.dataclass
class OuterState:
  """Outer-training progress information handed to workers."""
  # Filled from the `iteration` field of the theta optimizer state.
  outer_iteration: jnp.ndarray
@flax.struct.dataclass
class WorkerWeights:
  """Snapshot of the learner that workers need to compute gradients."""
  # Current meta-parameters, as returned by `theta_opt.get_params`.
  theta: MetaParams
  # Current model state, as returned by `theta_opt.get_state`.
  theta_model_state: Any
  # Outer-training progress; may be None.
  outer_state: Optional[OuterState]
@flax.struct.dataclass
class AggregatedGradient:
  """Gradient signal produced by one worker for the `GradientLearner`."""
  # Gradient with respect to the meta-parameters, averaged over the worker's
  # gradient estimators.
  theta_grads: Any
  # Model state reported by the worker; the learner averages these across
  # workers.
  theta_model_state: Any
  # Mean of the per-estimator mean losses.
  mean_loss: jnp.ndarray
@flax.struct.dataclass
class WorkerComputeOut:
  """Everything `gradient_worker_compute` returns."""
  # The part that is forwarded to `GradientLearner.update`.
  to_put: AggregatedGradient
  # Next state of each gradient estimator, in estimator order.
  unroll_states: Any
  # Aggregated metrics keyed by "<aggregation>||<name>".
  metrics: Mapping[str, float]
  # Per-estimator event records sampled from the unroll (a list of dicts in
  # `gradient_worker_compute`).
  event_info: Any
@flax.struct.dataclass
class GradientEstimatorState:
  """Empty base type for the state a `GradientEstimator` carries."""
  # No fields are declared here; presumably concrete estimators define their
  # own state types -- TODO confirm against the estimator implementations.
  pass
@flax.struct.dataclass
class UnrollInfo:
  """Optional per-step information a gradient estimator reports per unroll."""
  # Indexed as `loss[idx, :]` by `gradient_worker_compute`, where `idx` ranges
  # over the leading axis.
  loss: jnp.ndarray
  # Indexed with the same leading `idx`; may be None.
  iteration: jnp.ndarray
  # Pytree whose non-None leaves are indexed with the same leading `idx`.
  task_param: jnp.ndarray
  # Not read anywhere in this file; presumably marks finished inner
  # problems -- TODO confirm.
  is_done: jnp.ndarray
@flax.struct.dataclass
class GradientEstimatorOut:
  """Result of `GradientEstimator.compute_gradient_estimate`."""
  # Loss value reported for this estimate; averaged across estimators.
  mean_loss: jnp.ndarray
  # Gradient estimate; accumulated into a tree shaped like the
  # meta-parameters.
  grad: Any
  # Estimator state to pass to the next call.
  unroll_state: GradientEstimatorState
  # Optional unroll details used for event logging; may be None.
  unroll_info: Optional[UnrollInfo]
@flax.struct.dataclass
class ParameterCheckpoint:
  """State that we write out to disk for using the optimizer."""
  # Meta-parameters of the learned optimizer.
  params: lopt_base.MetaParams
  # String identifier stored with the checkpoint; its meaning is not visible
  # in this file (an empty string is used as a load placeholder).
  gen_id: str
  # Step counter stored with the checkpoint (0 is used as a load placeholder).
  step: int
@flax.struct.dataclass
class OptCheckpoint:
  """State that we write out to disk for training the optimizer."""
  # Full learner state, including the theta optimizer state.
  gradient_learner_state: GradientLearnerState
  # Elapsed training time; units are not visible in this file -- presumably
  # seconds, TODO confirm.
  elapsed_time: Union[float, jnp.ndarray]
  # Count of inner steps performed so far.
  total_inner_steps: int
class MetaInitializer(Protocol):
  """Protocol for objects which contain a jax init function."""

  def init(self, key: chex.PRNGKey) -> MetaParams:
    """Create meta-parameters from the rng `key`."""
    pass
@jax.jit
def _tree_mean(stack):
  """Average every leaf of `stack` over its leading axis (jit-compiled)."""
  mean_over_leading_axis = functools.partial(jnp.mean, axis=0)
  return jax.tree_util.tree_map(mean_over_leading_axis, stack)
def _tree_mean_onp(stack):
  """Average every leaf of `stack` over its leading axis using NumPy."""
  mean_over_leading_axis = functools.partial(onp.mean, axis=0)
  return jax.tree_util.tree_map(mean_over_leading_axis, stack)
@functools.lru_cache(None)
def _get_theta_update_fn(theta_opt: opt_base.Optimizer):
  """Build the jitted outer-optimizer update for `theta_opt`.

  The result is memoized per optimizer object, so repeated calls with the same
  `theta_opt` reuse one jitted function.

  Args:
    theta_opt: Optimizer whose `update` method is wrapped.

  Returns:
    The `update` closure wrapped by `summary.add_with_summary` and `jax.jit`,
    with `with_summary` treated as a static argument.
  """

  def update(theta_opt_state, grads, loss, key, model_state):
    # Summaries created by the optimizer are grouped under "outer_opt".
    with summary.summary_scope("outer_opt"):
      next_opt_state = theta_opt.update(
          theta_opt_state, grads, loss=loss, key=key, model_state=model_state)
    return next_opt_state

  return jax.jit(
      summary.add_with_summary(update), static_argnames=("with_summary",))
@gin.configurable
class GradientLearner:
  """Learner is responsible for training the weights of the learned opt.

  It owns the outer optimizer (`theta_opt`), creates or restores its state,
  hands the current weights to workers, and applies aggregated gradients.
  """

  def __init__(
      self,
      meta_init: MetaInitializer,
      theta_opt: opt_base.Optimizer,
      init_theta_from_path: Optional[str] = None,
      init_outer_state_from_path: Optional[str] = None,
      reset_outer_iteration: bool = False,
      num_steps: Optional[int] = None,
      init_seed: Optional[int] = None,
  ):
    """Initializer.

    Args:
      meta_init: Object whose `init(key)` creates initial meta-parameters.
      theta_opt: Optimizer applied to the meta-parameters.
      init_theta_from_path: If set, path of a `ParameterCheckpoint` whose
        `params` replace the freshly initialized meta-parameters.
      init_outer_state_from_path: If set, path of an `OptCheckpoint` whose
        theta optimizer state replaces the freshly initialized one.
      reset_outer_iteration: If True, `init` sets the `iteration` field of the
        theta optimizer state to 0.
      num_steps: Forwarded to `theta_opt.init` as `num_steps`.
      init_seed: If not None, `init` ignores its `key` argument and builds the
        rng key from this integer instead.
    """
    self._theta_opt = theta_opt
    self._meta_init = meta_init
    self._init_theta_from_path = init_theta_from_path
    self._init_outer_state_from_path = init_outer_state_from_path
    self._reset_outer_iteration = reset_outer_iteration
    self._num_steps = num_steps
    self._init_seed = init_seed

  def get_meta_params(self, state: GradientLearnerState) -> MetaParams:
    """Return the meta-parameters stored in the theta optimizer state."""
    return self._theta_opt.get_params(state.theta_opt_state)

  def get_meta_model_state(self,
                           state: GradientLearnerState) -> ThetaModelState:
    """Return the model state stored in the theta optimizer state."""
    return self._theta_opt.get_state(state.theta_opt_state)

  def get_state_for_worker(self, state: GradientLearnerState) -> WorkerWeights:
    """Package the weights, model state and outer iteration for workers."""
    return WorkerWeights(
        theta=self.get_meta_params(state),
        theta_model_state=self.get_meta_model_state(state),
        outer_state=OuterState(state.theta_opt_state.iteration))

  def init(self, key: PRNGKey) -> GradientLearnerState:
    """Initial state of the GradientLearner.

    This can be constructed from a random distribution, or loaded from a path.

    Args:
      key: jax rng key. Ignored when `init_seed` was given to the constructor.

    Returns:
      gradient_learner_state: A new initial state of the gradient learner.
    """
    if self._init_seed is not None:
      key = jax.random.PRNGKey(self._init_seed)
    theta_init = self._meta_init.init(key)
    # TODO(lmetz) hook up model state for learned optimizers
    model_state = None

    if self._init_theta_from_path:
      logging.info(
          "Got a init from params path %s."
          " Using this instead of random initialization.",
          self._init_theta_from_path)
      # To load a checkpoint, the state of the target object must be specified,
      # so we pass fake values here.
      fake_param_checkpoint = ParameterCheckpoint(
          params=theta_init, gen_id="", step=0)
      real_param_checkpoint = checkpoints.load_state(
          self._init_theta_from_path, fake_param_checkpoint)
      theta_init = real_param_checkpoint.params

    theta_opt_state = self._theta_opt.init(
        theta_init, model_state, num_steps=self._num_steps)

    if self._init_outer_state_from_path:
      logging.info(
          "Got a init from outer state path %s."
          " Using this instead of randomly initializing.",
          self._init_outer_state_from_path)
      # As above, a template with the right structure is needed for loading.
      fake_checkpoint = OptCheckpoint(
          gradient_learner_state=GradientLearnerState(theta_opt_state),
          elapsed_time=0.0,
          total_inner_steps=1)
      real_checkpoint = checkpoints.load_state(
          self._init_outer_state_from_path, fake_checkpoint)
      theta_opt_state = real_checkpoint.gradient_learner_state.theta_opt_state

    if self._reset_outer_iteration:
      theta_opt_state = theta_opt_state.replace(iteration=0)

    return GradientLearnerState(theta_opt_state)

  def update(
      self,
      state: GradientLearnerState,
      grads_list: Sequence[AggregatedGradient],
      with_metrics: bool = False,
      key: Optional[PRNGKey] = None
  ) -> Tuple[GradientLearnerState, Mapping[str, float]]:
    """Update the state of the outer-trainer using grads_list.

    This performs one outer weight update by aggregating the gradients in
    `grads_list`.

    Args:
      state: The state of the outer-trainer.
      grads_list: A list of gradients to be aggregated and applied.
      with_metrics: To compute metrics, or not.
      key: Jax PRNGKey. If None, a deterministic key derived from the current
        outer iteration is used; pass an explicit key for fresh randomness.

    Returns:
      next_state: The next outer-training state.
      metrics: The computed metrics from this update.
    """
    metrics = {}
    theta_opt_state = state.theta_opt_state

    if key is None:
      # `key` is optional in the signature but is always split below; without
      # this fallback the default argument fails inside `jax.random.split`.
      # Folding in the iteration keeps the key different from step to step.
      key = jax.random.fold_in(
          jax.random.PRNGKey(0), theta_opt_state.iteration)

    # Average gradients and model state over all entries of `grads_list`.
    with profile.Profile("stack_grad"):
      grads_stack = tree_utils.tree_zip_onp([t.theta_grads for t in grads_list])

    with profile.Profile("mean_grad"):
      grads = _tree_mean_onp(grads_stack)

    with profile.Profile("stack_state"):
      model_state_stack = tree_utils.tree_zip_onp(
          [t.theta_model_state for t in grads_list])
      next_model_state = _tree_mean_onp(model_state_stack)

    with profile.Profile("stack_loss"):
      losses = jnp.asarray([t.mean_loss for t in grads_list])
      mean_loss = jnp.mean(losses)
      min_loss = jnp.min(losses)

    fn = _get_theta_update_fn(self._theta_opt)
    # One key goes to the optimizer, the other to the summary wrapper.
    key1, key2 = jax.random.split(key)
    theta_opt_state, theta_update_metrics = fn(
        theta_opt_state,
        grads,
        mean_loss,
        key1,
        next_model_state,
        sample_rng_key=key2,
        with_summary=with_metrics)
    metrics = summary.aggregate_metric_list([metrics, theta_update_metrics])

    # Create fast summaries for all steps, and slower summaries occasionally
    metrics["none||mean_loss"] = mean_loss
    metrics["none||best_of_mean_loss"] = min_loss
    if with_metrics:
      metrics["none||theta_grad_norm"] = tree_utils.tree_norm(grads)
      metrics["none||theta_grad_abs_mean"] = tree_utils.tree_mean_abs(grads)

    return GradientLearnerState(theta_opt_state), metrics  # pytype: disable=bad-return-type  # jax-ndarray
class GradientEstimator(abc.ABC):
  """Base class for classes which estimate grads (via ES, PES, or backprop)."""
  # Declared without a value here; presumably assigned by concrete
  # estimators -- TODO confirm.
  truncated_step: truncated_step_mod.TruncatedStep

  def init_worker_state(self, worker_weights: WorkerWeights,
                        key: PRNGKey) -> GradientEstimatorState:
    """Create the estimator's initial state; must be overridden.

    Args:
      worker_weights: Current weights provided by the `GradientLearner`.
      key: jax rng key.

    Returns:
      The initial estimator state.

    Raises:
      NotImplementedError: Always, in this base class.
    """
    raise NotImplementedError()

  def compute_gradient_estimate(
      self,
      worker_weights: WorkerWeights,
      key: PRNGKey,
      state: GradientEstimatorState,
      with_summary: Optional[bool],
      datas_list: Any = None,
  ) -> Tuple[GradientEstimatorOut, Mapping[str, jnp.ndarray]]:
    """Compute one gradient estimate; must be overridden.

    Args:
      worker_weights: Current weights provided by the `GradientLearner`.
      key: jax rng key.
      state: Estimator state from `init_worker_state` or a previous call.
      with_summary: Whether summary metrics should be computed.
      datas_list: Optional data for the estimate; its structure is not
        visible in this file (presumably what `get_datas` returns -- TODO
        confirm).

    Returns:
      A pair of the estimator output and a mapping of metrics.

    Raises:
      NotImplementedError: Always, in this base class.
    """
    raise NotImplementedError()

  def task_name(self):
    """Name of the task family, used as a metric prefix."""
    return "default_task"

  def cfg_name(self):
    """Name of the configuration, used as a metric prefix."""
    return "default_cfg"

  def get_datas(self) -> Any:
    """Return data for the estimator; must be overridden.

    Raises:
      NotImplementedError: Always, in this base class.
    """
    raise NotImplementedError()
@functools.partial(jax.jit, donate_argnums=(0,))
def _jit_nan_to_num(vals, replace):
  """Replace nan, +inf and -inf in every leaf of `vals` with `replace`.

  Jit-compiled; the buffers of `vals` are donated to the computation.
  """

  def scrub(leaf):
    return jnp.nan_to_num(leaf, nan=replace, posinf=replace, neginf=replace)

  return jax.tree_util.tree_map(scrub, vals)
def _nan_to_num(vals, replace, use_jnp=False):
  """Replace nan, +inf and -inf in every leaf of `vals` with `replace`.

  Args:
    vals: Pytree of arrays.
    replace: Value substituted for every non-finite entry.
    use_jnp: If True, use the jitted jax implementation; otherwise use NumPy.

  Returns:
    A pytree with the same structure as `vals` and only finite-or-`replace`
    entries.
  """
  if use_jnp:
    return _jit_nan_to_num(vals, replace)
  # Pass `replace` explicitly: the NumPy defaults would map nan to 0 and
  # +/-inf to very large finite numbers, disagreeing with the jax path.
  scrub = functools.partial(
      onp.nan_to_num, nan=replace, posinf=replace, neginf=replace)
  return jax.tree_util.tree_map(scrub, vals)
def _tree_zeros_on_device(shapes, device):
  """Build a pytree of zeros matching the shape/dtype specs in `shapes`.

  The tree is flattened so that its leaves can be passed as a tuple to the
  jitted helper, then rebuilt with the original structure.
  """
  flat_specs, structure = jax.tree_util.tree_flatten(shapes)
  flat_zeros = _tree_zeros_on_device_inner(tuple(flat_specs), device)
  return jax.tree_util.tree_unflatten(structure, flat_zeros)
@functools.partial(jax.jit, static_argnums=(0, 1))
def _tree_zeros_on_device_inner(shapes, device):
  """Create a zero array for every spec in `shapes`, targeting `device`.

  Both arguments are static for jit. Each spec only needs `.shape` and
  `.dtype` attributes.
  """

  def zeros_for(spec):
    zero = jax.device_put(jnp.asarray(0, dtype=spec.dtype), device)
    return jnp.full(spec.shape, zero)

  return jax.tree_util.tree_map(zeros_for, shapes)
@gin.configurable
@profile.wrap()
def gradient_worker_compute(
    worker_weights: WorkerWeights,
    gradient_estimators: Sequence[GradientEstimator],
    unroll_states: Sequence[GradientEstimatorState],
    key: PRNGKey,
    with_metrics: bool,
    clip_nan_loss_to_value: Optional[float] = 20.0,
    extra_metrics: bool = True,
    device: Optional[jax.Device] = None,
) -> WorkerComputeOut:
  """Compute a gradient signal to meta-train with.

  This function performs unrolls for each of the unroll_states with the
  corresponding gradient_estimator. The results from each of the gradient
  estimators get merged into a single gradient. This aggregation is done
  to save bandwidth when collecting gradients from workers.

  Args:
    worker_weights: Weights created by the GradientLearner and represent the
      current parameters and model state of the learned optimizer.
    gradient_estimators: The gradient estimators used to update the unroll
      state.
    unroll_states: State of each gradient estimator (e.g. inner problem
      weights). Must have the same length as `gradient_estimators`.
    key: jax rng
    with_metrics: compute with summary metrics or not
    clip_nan_loss_to_value: Value that replaces a non-finite mean loss. Use
      None to leave the mean loss untouched.
    extra_metrics: log out additional metrics.
    device: The jax device to run the computation on. Defaults to the first
      device local to the current process.

  Returns:
    worker_compute_out: The results of the computation.
      This contains a gradient estimate, the next unroll states, metrics.
      A subset of which get passed to the GradientLearner.
  """
  if device is None:
    # With no argument this returns the devices of the *current* process.
    # Passing 0 would select process 0's devices on every host.
    device = jax.local_devices()[0]

  theta = worker_weights.theta
  theta_model_state = worker_weights.theta_model_state

  theta_shape = jax.tree_util.tree_map(
      lambda x: core.ShapedArray(x.shape, x.dtype), theta)
  grads_accum = _tree_zeros_on_device(theta_shape, device)

  metrics_list = []
  unroll_states_out = []
  losses = []
  event_info = []

  def extract_one(idx, x):
    # Leaves may be None; only index real arrays.
    return x[idx] if x is not None else None

  assert len(gradient_estimators) == len(unroll_states)

  for si, (estimator,
           unroll_state) in enumerate(zip(gradient_estimators, unroll_states)):
    with profile.Profile(f"estimator{si}"):
      stime = time.time()
      key, rng = jax.random.split(key)

      family_name = estimator.task_name()
      cfg_name = estimator.cfg_name()

      logging.info(
          "compute_gradient_estimate for estimator name %s and cfg name %s",
          family_name, cfg_name)
      with profile.Profile(f"unroll__metrics{with_metrics}"):
        estimator_out, metrics = estimator.compute_gradient_estimate(
            worker_weights, rng, unroll_state, with_summary=with_metrics)

      unroll_states_out.append(estimator_out.unroll_state)
      losses.append(estimator_out.mean_loss)
      with profile.Profile("tree_add"):
        grads_accum = tree_utils.tree_add(grads_accum, estimator_out.grad)

      # grab a random iteration from the trajectory
      unroll_info = estimator_out.unroll_info
      if unroll_info is not None:
        idx = onp.random.randint(0, len(unroll_info.loss))
        onp_task_params = jax.tree_util.tree_map(onp.asarray,
                                                 unroll_info.task_param)
        iteration = (
            unroll_info.iteration[idx]
            if unroll_info.iteration is not None else None)
        event_info.append({
            "loss": unroll_info.loss[idx, :],
            "task_param": jax.tree_util.tree_map(
                functools.partial(extract_one, idx), onp_task_params),
            "iteration": iteration,
            "outer_iteration": worker_weights.outer_state.outer_iteration,
        })
      else:
        logging.warning("No out specified by learner. "
                        "Not logging any events data.")

      # Copy so that the estimator's own mapping is never mutated below.
      metrics = dict(metrics)
      if extra_metrics:
        if with_metrics:
          # Metrics don't take into account which task they are coming from.
          # Let's add additional metrics with the task name pulled out.
          with profile.Profile("metric_computation"):
            # Snapshot the keys: new entries are added while iterating.
            for k in list(metrics.keys()):
              v = metrics[k]
              assert "||" in k, f"bad metric format? Got: {k}"
              agg, name = k.split("||")
              metrics[f"{agg}||{family_name}/{name}"] = v
              metrics[f"{agg}||{cfg_name}/{name}"] = v

            mean_abs = tree_utils.tree_mean_abs(estimator_out.grad)
            metrics[f"mean||{family_name}/grad_mean_abs"] = mean_abs
            metrics[f"mean||{cfg_name}/grad_mean_abs"] = mean_abs

            norm = tree_utils.tree_norm(estimator_out.grad)
            metrics[f"mean||{family_name}/grad_norm"] = norm
            metrics[f"mean||{cfg_name}/grad_norm"] = norm

        metrics[f"mean||{family_name}/mean_loss"] = estimator_out.mean_loss
        metrics[f"mean||{cfg_name}/mean_loss"] = estimator_out.mean_loss

        elapsed = time.time() - stime
        metrics[f"sample||{family_name}/time"] = elapsed
        metrics[f"sample||{cfg_name}/time"] = elapsed

      metrics_list.append(metrics)

  with profile.Profile("mean_grads"):
    grads_accum = tree_utils.tree_div(grads_accum, len(gradient_estimators))
    mean_loss = jnp.mean(jnp.asarray(losses))

  # block here to better account for costs with profile profiling.
  with profile.Profile("blocking"):
    stime = time.time()
    mean_loss.block_until_ready()
    block_time = time.time() - stime

  with profile.Profile("summary_aggregation"):
    metrics = summary.aggregate_metric_list(metrics_list)
  metrics["mean||block_time"] = block_time

  with profile.Profile("strip_nan"):
    # this should ideally never be NAN
    # TODO(lmetz) check if we need these checks.
    grads_accum = _nan_to_num(grads_accum, 0.0, use_jnp=True)
    # Compare against None so that a clip value of 0.0 is still applied.
    if clip_nan_loss_to_value is not None:
      mean_loss = _nan_to_num(mean_loss, clip_nan_loss_to_value, use_jnp=True)

  with profile.Profile("grads_to_onp"):
    to_put = AggregatedGradient(
        theta_grads=grads_accum,
        theta_model_state=theta_model_state,
        mean_loss=mean_loss)

    # Convert to NumPy arrays before handing the result to the learner.
    return WorkerComputeOut(
        to_put=jax.tree_util.tree_map(onp.asarray, to_put),
        unroll_states=unroll_states_out,
        metrics=metrics,
        event_info=event_info)
@flax.struct.dataclass
class SingleMachineState:
  """Combined state used by `SingleMachineGradientLearner`."""
  # State of the wrapped `GradientLearner`.
  gradient_learner_state: GradientLearnerState
  # One state per gradient estimator, in estimator order.
  gradient_estimator_states: Sequence[GradientEstimatorState]
class SingleMachineGradientLearner:
  """Train with gradient estimators on a single machine.

  This is a convenience wrapper around the multi-worker interface -- namely
  both `GradientLearner` and `gradient_worker_compute`.
  """

  def __init__(self,
               meta_init: MetaInitializer,
               gradient_estimators: Sequence[GradientEstimator],
               theta_opt: opt_base.Optimizer,
               num_steps: Optional[int] = None):
    """Initializer.

    Args:
      meta_init: Class containing an init function to construct outer params.
      gradient_estimators: Sequence of gradient estimators used to calculate
        gradients.
      theta_opt: The optimizer used to train the weights of the learned opt.
      num_steps: Number of meta-training steps used by optimizer for schedules.
    """
    self.gradient_learner = GradientLearner(
        meta_init, theta_opt, num_steps=num_steps)
    self.gradient_estimators = gradient_estimators

  def init(self, key: PRNGKey) -> SingleMachineState:
    """Initial state.

    Initializes the learner state (learned optimizer weights plus the
    optimizer variables for them) and the first state of every gradient
    estimator.

    Args:
      key: jax rng

    Returns:
      The initial state
    """
    learner_key, estimators_key = jax.random.split(key)
    learner_state = self.gradient_learner.init(learner_key)
    weights = self.gradient_learner.get_state_for_worker(learner_state)

    # One independent key per estimator.
    per_estimator_keys = jax.random.split(estimators_key,
                                          len(self.gradient_estimators))
    estimator_states = []
    for est_key, estimator in zip(per_estimator_keys,
                                  self.gradient_estimators):
      estimator_states.append(estimator.init_worker_state(weights, est_key))

    return SingleMachineState(
        gradient_learner_state=learner_state,
        gradient_estimator_states=estimator_states)

  def update(
      self,
      state,
      key: PRNGKey,
      with_metrics: Optional[bool] = False
  ) -> Tuple[SingleMachineState, jnp.ndarray, Mapping[str, jnp.ndarray]]:
    """Perform one outer-update to train the learned optimizer.

    Args:
      state: State of this class
      key: jax rng
      with_metrics: To compute metrics or not

    Returns:
      state: The next state from this class
      loss: loss from the current iteration
      metrics: dictionary of metrics computed
    """
    worker_key, learner_key = jax.random.split(key)
    learner_state = state.gradient_learner_state

    # Worker side: run every estimator and merge their gradients.
    weights = self.gradient_learner.get_state_for_worker(learner_state)
    compute_out = gradient_worker_compute(
        weights,
        self.gradient_estimators,
        state.gradient_estimator_states,
        key=worker_key,
        with_metrics=with_metrics)
    aggregated = compute_out.to_put

    # Learner side: apply the single aggregated gradient.
    next_learner_state, learner_metrics = self.gradient_learner.update(
        learner_state, [aggregated],
        key=learner_key,
        with_metrics=with_metrics)

    merged_metrics = summary.aggregate_metric_list(
        [compute_out.metrics, learner_metrics])
    next_state = SingleMachineState(
        gradient_learner_state=next_learner_state,
        gradient_estimator_states=compute_out.unroll_states)
    return next_state, aggregated.mean_loss, merged_metrics

  def get_meta_params(self, state: SingleMachineState) -> lopt_base.MetaParams:
    """Get the weights of the learned optimizer."""
    return self.gradient_learner.get_meta_params(state.gradient_learner_state)