# coding=utf-8
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Train a learned optimizer with gradients."""
import abc
import functools
from typing import Any, Mapping, Optional, Sequence, Tuple
from absl import logging
import flax
import gin
import jax
import jax.numpy as jnp
from learned_optimization import profile
from learned_optimization import summary
from learned_optimization import training
from learned_optimization import tree_utils
from learned_optimization.learned_optimizers import base as lopt_base
from learned_optimization.optimizers import base as opt_base
from learned_optimization.tasks import base as tasks_base
import numpy as onp
# Loose type aliases used only in signatures below; none is checked at runtime.
# PRNGKey is a jax random key array; ThetaParams / ThetaModelState are the
# learned optimizer's weights and model state, left untyped (`Any`).
PRNGKey, ThetaParams, ThetaModelState = jnp.ndarray, Any, Any
@flax.struct.dataclass
class GradientLearnerState:
  """Outer-training state owned by `GradientLearner`.

  Attributes:
    theta_opt_state: State of the optimizer (`theta_opt`) that trains the
      learned optimizer's weights. `GradientLearner` reads the weights, the
      model state and the `iteration` counter out of it.
  """
  theta_opt_state: Any
@flax.struct.dataclass
class OuterState:
  """Outer-training information forwarded to gradient workers.

  Attributes:
    outer_iteration: Current outer-training iteration. `GradientLearner`
      fills it from the theta optimizer state's `iteration` field.
  """
  outer_iteration: jnp.ndarray
@flax.struct.dataclass
class WorkerWeights:
  """Snapshot of the learned optimizer that is handed to gradient workers.

  Attributes:
    theta: Current learned-optimizer weights.
    theta_model_state: Current learned-optimizer model state.
    outer_state: Outer-training info (e.g. the current outer iteration).
  """
  theta: Any
  theta_model_state: Any
  outer_state: OuterState
@flax.struct.dataclass
class AggregatedGradient:
  """Gradient message a worker sends back to the `GradientLearner`.

  Attributes:
    theta_grads: Gradient w.r.t. the learned-optimizer weights, averaged over
      all gradient estimators run on one worker.
    theta_model_state: Learned-optimizer model state reported by the worker.
    mean_loss: Mean of the per-estimator losses on that worker.
  """
  theta_grads: Any
  theta_model_state: Any
  mean_loss: jnp.ndarray
@flax.struct.dataclass
class WorkerComputeOut:
  """Everything `gradient_worker_compute` returns.

  Attributes:
    to_put: The aggregated gradient to forward to the `GradientLearner`.
    unroll_states: Next state of every gradient estimator, in input order.
    metrics: Aggregated summary metrics (empty unless metrics were requested).
    event_info: List of dicts describing one randomly sampled step from each
      estimator's unroll, for event logging.
  """
  to_put: AggregatedGradient
  unroll_states: Any
  metrics: Mapping[str, float]
  event_info: Any
@flax.struct.dataclass
class GradientEstimatorState:
  """Base type for the per-estimator state carried between unrolls.

  Empty here; concrete gradient estimators presumably subclass it with their
  own fields (e.g. inner-problem weights) -- confirm against the estimators.
  """
  pass
@flax.struct.dataclass
class UnrollInfo:
  """Per-step information about the unroll(s) a gradient estimator ran.

  In this file only `gradient_worker_compute` reads it: it picks a random
  `idx` in `range(len(loss))` and slices the fields below at `idx` to build
  an event record.

  Attributes:
    loss: Losses from the unroll. Indexed as `loss[idx, :]`, so it is at least
      2-D; the leading axis presumably indexes unroll steps (TODO confirm).
    iteration: Inner-training iteration, indexed by the same leading `idx`.
    task_param: Pytree of task parameters; every leaf is indexed by `idx`.
    is_done: Not read in this file; presumably flags finished inner
      trajectories -- confirm against the estimators that fill it.
  """
  loss: jnp.ndarray
  iteration: jnp.ndarray
  task_param: jnp.ndarray
  is_done: jnp.ndarray
@flax.struct.dataclass
class GradientEstimatorOut:
  """Result of one `GradientEstimator.compute_gradient_estimate` call.

  Attributes:
    mean_loss: Mean loss of the unroll(s) the estimator ran.
    grad: Gradient estimate w.r.t. the learned-optimizer weights; it is added
      to a zero tree shaped like `theta`, so it must share that structure.
    unroll_state: Estimator state to pass into the next call.
    unroll_info: Per-step unroll details used for event logging. May be
      falsy/None, in which case `gradient_worker_compute` skips event logging.
  """
  mean_loss: jnp.ndarray
  grad: Any
  unroll_state: GradientEstimatorState
  unroll_info: Optional[UnrollInfo]
@jax.jit
def _tree_mean(stack):
  """Average every leaf of `stack` over its leading axis.

  Args:
    stack: Pytree whose leaves are stacked along axis 0 (e.g. one entry per
      received gradient).

  Returns:
    A pytree with the same structure, each leaf reduced with `jnp.mean` over
    axis 0.
  """
  # `jax.tree_util.tree_map` exists on every JAX release, whereas the
  # top-level `jax.tree_map` alias has been removed from newer releases.
  return jax.tree_util.tree_map(lambda leaf: jnp.mean(leaf, axis=0), stack)
@gin.configurable
class GradientLearner:
  """Learner responsible for training the weights of the learned optimizer.

  The learned optimizer's weights (theta) are trained by a separate optimizer,
  `theta_opt`. This class builds the initial `theta_opt` state, produces the
  `WorkerWeights` snapshot sent to gradient workers, and applies the
  `AggregatedGradient`s the workers send back.
  """

  def __init__(self,
               lopt: lopt_base.LearnedOptimizer,
               theta_opt: opt_base.Optimizer,
               init_theta_from_path: Optional[str] = None,
               init_outer_state_from_path: Optional[str] = None,
               num_steps: Optional[int] = None):
    """Initializer.

    Args:
      lopt: The learned optimizer whose weights are trained.
      theta_opt: Optimizer used to update the learned optimizer's weights.
      init_theta_from_path: Optional checkpoint path. When set, `init` loads
        theta and its model state from it instead of using `lopt.init`.
      init_outer_state_from_path: Optional checkpoint path. When set, `init`
        replaces the freshly built `theta_opt` state with the one loaded here.
      num_steps: Passed to `theta_opt.init` as `num_steps` (presumably the
        outer-training length, e.g. for schedules).
    """
    self._theta_opt = theta_opt
    # Jitted once up front because it runs on every outer update.
    self._theta_opt_update = jax.jit(self._theta_opt.update)
    self._lopt = lopt
    self._init_theta_from_path = init_theta_from_path
    self._init_outer_state_from_path = init_outer_state_from_path
    self._num_steps = num_steps

  @property
  def learned_optimizer(self):
    """The learned optimizer being trained."""
    return self._lopt

  def get_lopt_params(self, state: GradientLearnerState) -> ThetaParams:
    """Return the current learned-optimizer weights held in `state`."""
    return self._theta_opt.get_params(state.theta_opt_state)

  def get_lopt_model_state(self,
                           state: GradientLearnerState) -> ThetaModelState:
    """Return the current learned-optimizer model state held in `state`."""
    return self._theta_opt.get_state(state.theta_opt_state)

  def get_state_for_worker(self, state: GradientLearnerState) -> WorkerWeights:
    """Package weights, model state and outer iteration for the workers."""
    return WorkerWeights(
        theta=self.get_lopt_params(state),
        theta_model_state=self.get_lopt_model_state(state),
        outer_state=OuterState(state.theta_opt_state.iteration))

  def init(self, key: PRNGKey) -> GradientLearnerState:
    """Initial state of the GradientLearner.

    The weights come from `lopt.init(key)`, or from `init_theta_from_path`
    when it was given. The resulting `theta_opt` state can in turn be replaced
    by the one stored at `init_outer_state_from_path`.

    Args:
      key: jax rng key.

    Returns:
      gradient_learner_state: A new initial state of the gradient learner.
    """
    theta_init = self._lopt.init(key)
    # TODO(lmetz) hook up model state for learned optimizers
    model_state = None

    if self._init_theta_from_path:
      logging.info(
          "Got a init from params path %s."
          " Using this instead of random randomly initializing.",
          self._init_theta_from_path)
      # Parameter checkpoints are stored as tuples of lopt weights (theta),
      # lopt model state, a string generation id (used for population based
      # training), and the outer-training iteration / number of weight updates
      # performed. Loading needs a template of the target structure, so
      # placeholder values are used for the last two and then discarded.
      template = (theta_init, model_state, "", 0)
      theta_init, model_state, _, _ = training.load_state(
          self._init_theta_from_path, template)

    theta_opt_state = self._theta_opt.init(
        theta_init, model_state, num_steps=self._num_steps)

    if self._init_outer_state_from_path:
      logging.info(
          "Got a init from outer state path %s."
          " Using this instead of randomly initializing.",
          self._init_outer_state_from_path)
      theta_opt_state = training.load_state(self._init_outer_state_from_path,
                                            theta_opt_state)

    return GradientLearnerState(theta_opt_state)

  def update(
      self,
      state: GradientLearnerState,
      grads_list: Sequence[AggregatedGradient],
      with_metrics: bool = False,
      key: Optional[PRNGKey] = None
  ) -> Tuple[GradientLearnerState, Mapping[str, float]]:
    """Update the state of the outer-trainer using grads_list.

    This performs one outer weight update. The gradients and model states in
    `grads_list` are averaged elementwise, and the averaged gradient and the
    mean of the reported losses are passed to `theta_opt.update`.

    Args:
      state: The state of the outer-trainer.
      grads_list: A non-empty list of gradients to be aggregated and applied.
      with_metrics: To compute the slower gradient metrics, or not.
      key: Jax PRNGKey.

    Returns:
      next_state: The next outer-training state.
      metrics: The computed metrics from this update.

    Raises:
      ValueError: If `grads_list` is empty; there is nothing to average.
    """
    if not grads_list:
      raise ValueError(
          "grads_list must contain at least one AggregatedGradient.")

    with profile.Profile("stack_data"):
      grads_stack = tree_utils.tree_zip_onp([t.theta_grads for t in grads_list])
      grads = _tree_mean(grads_stack)
      model_state_stack = tree_utils.tree_zip_onp(
          [t.theta_model_state for t in grads_list])
      next_model_state = _tree_mean(model_state_stack)
      losses = jnp.asarray([t.mean_loss for t in grads_list])

    mean_loss = jnp.mean(losses)
    min_loss = jnp.min(losses)

    theta_opt_state = self._theta_opt_update(
        state.theta_opt_state,
        grads,
        mean_loss,
        key=key,
        model_state=next_model_state)

    # Cheap summaries are produced on every step; the slower gradient
    # statistics only when requested.
    metrics = {
        "none||mean_loss": mean_loss,
        "none||best_of_mean_loss": min_loss,
    }
    if with_metrics:
      metrics["none||theta_grad_norm"] = tree_utils.tree_norm(grads)
      metrics["none||theta_grad_abs_mean"] = tree_utils.tree_mean_abs(grads)

    return GradientLearnerState(theta_opt_state), metrics
class GradientEstimator(abc.ABC):
  """Base class for classes which estimate grads (via ES, PES, or backprop).

  Subclasses set `task_family` and override both methods. The methods raise
  `NotImplementedError` rather than being declared abstract, so a subclass
  can still be instantiated without overriding them.
  """
  task_family: tasks_base.TaskFamily

  def init_worker_state(self, worker_weights: WorkerWeights,
                        key: PRNGKey) -> GradientEstimatorState:
    """Build the first estimator state (e.g. inner problem weights).

    Args:
      worker_weights: Current learned-optimizer weights from the learner.
      key: jax rng.

    Returns:
      The initial `GradientEstimatorState` for this estimator.

    Raises:
      NotImplementedError: Always; subclasses must override this method.
    """
    raise NotImplementedError()

  def compute_gradient_estimate(
      self, worker_weights: WorkerWeights, key: PRNGKey,
      state: GradientEstimatorState, with_summary: Optional[bool]
  ) -> Tuple[GradientEstimatorOut, Mapping[str, jnp.ndarray]]:
    """Advance `state` and estimate the gradient w.r.t. the lopt weights.

    Args:
      worker_weights: Current learned-optimizer weights from the learner.
      key: jax rng.
      state: Estimator state returned by the previous call (or by
        `init_worker_state`).
      with_summary: Whether summary metrics should be computed.

    Returns:
      The estimator output and a metrics dict. `gradient_worker_compute`
      requires every metric key to have the form "aggregation||name".

    Raises:
      NotImplementedError: Always; subclasses must override this method.
    """
    raise NotImplementedError()
def _nan_to_num(vals, replace, use_jnp=False):
  """Replace NaN, +inf and -inf in every leaf of `vals` with `replace`.

  Previously the numpy path ignored `replace` and used numpy's defaults:
  NaN became 0, and +/-inf became the largest finite float.

  Args:
    vals: Pytree of arrays.
    replace: Value substituted for every non-finite entry.
    use_jnp: Use `jnp.nan_to_num` (returns jax arrays) instead of
      `numpy.nan_to_num` (returns numpy arrays).

  Returns:
    A pytree with the same structure and all non-finite entries replaced.
  """
  nan_to_num = jnp.nan_to_num if use_jnp else onp.nan_to_num
  fill = functools.partial(
      nan_to_num, nan=replace, posinf=replace, neginf=replace)
  return jax.tree_util.tree_map(fill, vals)
def _tree_zeros_on_device(shapes, device):
  """Build a pytree of zero arrays matching `shapes`, placed on `device`.

  Args:
    shapes: Pytree whose leaves expose `.shape` and `.dtype` and are hashable
      (e.g. `jax.ShapeDtypeStruct`); they become static jit arguments.
    device: jax device to allocate the zeros on.

  Returns:
    A pytree with the structure of `shapes` whose leaves are zero arrays.
  """
  leaves, treedef = jax.tree_util.tree_flatten(shapes)
  # A tuple of leaves is hashable, so it can serve as a static jit argument;
  # the tree structure is restored outside the jitted call.
  zeros = _tree_zeros_on_device_inner(tuple(leaves), device)
  return jax.tree_util.tree_unflatten(treedef, zeros)


@functools.partial(jax.jit, static_argnums=(0, 1))
def _tree_zeros_on_device_inner(shapes, device):
  """Jitted helper: one zero array per shape spec in the static `shapes`."""
  zero_val = lambda x: jax.device_put(jnp.asarray(0, dtype=x.dtype), device)
  return jax.tree_util.tree_map(lambda x: jnp.full(x.shape, zero_val(x)),
                                shapes)
def _sample_event_info(unroll_info, outer_iteration):
  """Pick one random step of `unroll_info` and package it for event logging.

  The step index comes from numpy's global RNG rather than a jax key, so the
  choice is not reproducible from the worker's jax rng.
  """
  idx = onp.random.randint(0, len(unroll_info.loss))
  onp_task_params = jax.tree_util.tree_map(onp.asarray, unroll_info.task_param)
  return {
      "loss": unroll_info.loss[idx, :],
      "task_param": jax.tree_util.tree_map(lambda x: x[idx], onp_task_params),
      "iteration": unroll_info.iteration[idx],
      "outer_iteration": outer_iteration,
  }


def _add_task_family_metrics(metrics, family_name, estimator_out):
  """Copy `metrics`, adding per-task-family duplicates and gradient stats.

  Metrics don't record which task they come from, so every "agg||name" key is
  also stored as "agg||{family_name}/name".
  """
  metrics = dict(metrics)
  for k in list(metrics):
    assert "||" in k, f"bad metric format? Got: {k}"
    agg, name = k.split("||")
    metrics[f"{agg}||{family_name}/{name}"] = metrics[k]
  grad = estimator_out.grad
  metrics[f"mean||{family_name}/grad_mean_abs"] = tree_utils.tree_mean_abs(grad)
  metrics[f"mean||{family_name}/grad_norm"] = tree_utils.tree_norm(grad)
  metrics[f"mean||{family_name}/mean_loss"] = estimator_out.mean_loss
  return metrics


@gin.configurable
@profile.wrap()
def gradient_worker_compute(
    worker_weights: WorkerWeights,
    gradient_estimators: Sequence[GradientEstimator],
    unroll_states: Sequence[GradientEstimatorState],
    key: PRNGKey,
    with_metrics: bool,
    device: Optional[jax.Device] = None) -> WorkerComputeOut:
  """Compute a gradient signal to meta-train with.

  Each unroll state is advanced with its corresponding gradient estimator.
  The per-estimator gradients are averaged into a single gradient, so a worker
  only sends one gradient back; this saves bandwidth.

  Args:
    worker_weights: Weights created by the GradientLearner: the current
      parameters and model state of the learned optimizer.
    gradient_estimators: The gradient estimators used to update the unroll
      states.
    unroll_states: State of each gradient estimator (e.g. inner problem
      weights), in the same order as `gradient_estimators`.
    key: jax rng.
    with_metrics: Compute summary metrics or not.
    device: jax device the gradient accumulator is allocated on. Defaults to
      the first local device of process 0.

  Returns:
    worker_compute_out: The results of the computation: the averaged gradient
      and mean loss (as numpy) to pass to the GradientLearner, the next unroll
      states, metrics, and sampled event info.

  Raises:
    ValueError: If no estimators are given, or the number of estimators and
      unroll states differ.
  """
  if len(gradient_estimators) != len(unroll_states):
    raise ValueError(
        f"Got {len(gradient_estimators)} gradient estimators but "
        f"{len(unroll_states)} unroll states.")
  if not gradient_estimators:
    # Averaging over zero estimators would divide by zero.
    raise ValueError("At least one gradient estimator is required.")

  if device is None:
    device = jax.local_devices(0)[0]

  theta = worker_weights.theta
  theta_model_state = worker_weights.theta_model_state
  # Hashable shape/dtype specs, used as static args by the zeros helper.
  theta_shape = jax.tree_util.tree_map(
      lambda x: jax.ShapeDtypeStruct(x.shape, x.dtype), theta)
  grads_accum = _tree_zeros_on_device(theta_shape, device)

  metrics_list = []
  unroll_states_out = []
  losses = []
  event_info = []

  for si, (estimator, unroll_state) in enumerate(
      zip(gradient_estimators, unroll_states)):
    with profile.Profile(f"estimator{si}"):
      key, rng = jax.random.split(key)

      with profile.Profile(f"unroll__metrics{with_metrics}"):
        estimator_out, metrics = estimator.compute_gradient_estimate(
            worker_weights, rng, unroll_state, with_summary=with_metrics)

      unroll_states_out.append(estimator_out.unroll_state)
      losses.append(estimator_out.mean_loss)
      with profile.Profile("tree_add"):
        grads_accum = tree_utils.tree_add(grads_accum, estimator_out.grad)

      # grab a random iteration from the trajectory
      if estimator_out.unroll_info:
        event_info.append(
            _sample_event_info(estimator_out.unroll_info,
                               worker_weights.outer_state.outer_iteration))
      else:
        logging.warning("No out specified by learner. "
                        "Not logging any events data.")

      if with_metrics:
        with profile.Profile("metric_computation"):
          metrics = _add_task_family_metrics(
              metrics, estimator.task_family.name, estimator_out)
        metrics_list.append(metrics)

  with profile.Profile("mean_grads"):
    grads_accum = tree_utils.tree_div(grads_accum, len(gradient_estimators))
    mean_loss = jnp.mean(jnp.asarray(losses))

  # block here to better account for costs with profile profiling.
  with profile.Profile("blocking"):
    mean_loss.block_until_ready()

  with profile.Profile("summary_aggregation"):
    metrics = summary.aggregate_metric_list(metrics_list)

  with profile.Profile("strip_nan"):
    # this should ideally never be NAN
    # TODO(lmetz) check if we need these checks.
    grads_accum = _nan_to_num(grads_accum, 0.0)
    # assume things are roughly scaled to 0-10. So 20 should be a big value.
    # this doesn't effect gradient calculations.
    mean_loss = _nan_to_num(mean_loss, 20.0, use_jnp=True)

  with profile.Profile("grads_to_onp"):
    to_put = jax.tree_util.tree_map(
        onp.asarray,
        AggregatedGradient(
            theta_grads=grads_accum,
            theta_model_state=theta_model_state,
            mean_loss=mean_loss))

  return WorkerComputeOut(
      to_put=to_put,
      unroll_states=unroll_states_out,
      metrics=metrics,
      event_info=event_info)
@flax.struct.dataclass
class SingleMachineState:
  """State of `SingleMachineGradientLearner`.

  Attributes:
    gradient_learner_state: State of the wrapped `GradientLearner`.
    gradient_estimator_states: One state per gradient estimator, in the same
      order as the learner's `gradient_estimators`.
  """
  gradient_learner_state: GradientLearnerState
  gradient_estimator_states: Sequence[GradientEstimatorState]
class SingleMachineGradientLearner:
  """Train with gradient estimators on a single machine.

  This is a convenience wrapper that calls the multi-worker interface --
  namely both `GradientLearner` and `gradient_worker_compute`.
  """

  def __init__(self,
               learned_opt: lopt_base.LearnedOptimizer,
               gradient_estimators: Sequence[GradientEstimator],
               theta_opt: opt_base.Optimizer,
               num_steps: Optional[int] = None):
    """Initializer.

    Args:
      learned_opt: Learned optimizer to train.
      gradient_estimators: Sequence of gradient estimators used to calculate
        gradients.
      theta_opt: The optimizer used to train the weights of the learned opt.
      num_steps: Number of meta-training steps used by optimizer for schedules.
    """
    self.gradient_learner = GradientLearner(
        learned_opt, theta_opt, num_steps=num_steps)
    self.gradient_estimators = gradient_estimators

  def init(self, key: PRNGKey) -> SingleMachineState:
    """Initial state.

    Initializes the learned optimizer weights (and the optimizer state for
    those weights) via `GradientLearner.init`. It then initializes the first
    state of every gradient estimator.

    Args:
      key: jax rng.

    Returns:
      The initial state.
    """
    learner_key, estimator_key = jax.random.split(key)
    theta_state = self.gradient_learner.init(learner_key)
    worker_weights = self.gradient_learner.get_state_for_worker(theta_state)
    estimator_keys = jax.random.split(estimator_key,
                                      len(self.gradient_estimators))
    unroll_states = [
        grad_est.init_worker_state(worker_weights, k)
        for k, grad_est in zip(estimator_keys, self.gradient_estimators)
    ]
    return SingleMachineState(
        gradient_learner_state=theta_state,
        gradient_estimator_states=unroll_states)

  def update(
      self,
      state: SingleMachineState,
      key: PRNGKey,
      with_metrics: Optional[bool] = False
  ) -> Tuple[SingleMachineState, jnp.ndarray, Mapping[str, jnp.ndarray]]:
    """Perform one outer-update to train the learned optimizer.

    Args:
      state: State of this class.
      key: jax rng.
      with_metrics: To compute metrics or not. Forwarded to both the worker
        computation and the learner update.

    Returns:
      state: The next state from this class.
      loss: Loss from the current iteration.
      metrics: Dictionary of metrics from the worker and the learner.
    """
    worker_key, learner_key = jax.random.split(key)
    worker_weights = self.gradient_learner.get_state_for_worker(
        state.gradient_learner_state)

    worker_compute_out = gradient_worker_compute(
        worker_weights,
        self.gradient_estimators,
        state.gradient_estimator_states,
        key=worker_key,
        with_metrics=with_metrics)

    next_theta_state, learner_metrics = self.gradient_learner.update(
        state.gradient_learner_state, [worker_compute_out.to_put],
        with_metrics=with_metrics,
        key=learner_key)

    # Per-task-family metrics come from the worker; outer-update metrics come
    # from the learner. Return both, letting learner keys win on collision.
    metrics = {**worker_compute_out.metrics, **learner_metrics}

    next_state = SingleMachineState(
        gradient_learner_state=next_theta_state,
        gradient_estimator_states=worker_compute_out.unroll_states)
    return next_state, worker_compute_out.to_put.mean_loss, metrics

  def get_lopt_params(self, state: SingleMachineState) -> lopt_base.MetaParams:
    """Get the weights of the learned optimizer."""
    return self.gradient_learner.get_lopt_params(state.gradient_learner_state)