Outer Trainers

Outer-training, or meta-training, is the process of learning the meta-parameters.

Base

learned_optimization.outer_trainers.gradient_learner.GradientEstimator()

Base class for gradient estimators, i.e. classes that estimate meta-gradients (via ES, PES, or backprop).

learned_optimization.outer_trainers.gradient_learner.gradient_worker_compute(worker_weights, gradient_estimators, unroll_states, key, with_metrics, clip_nan_loss_to_value=20.0, extra_metrics=True, device=None)

Compute a gradient signal to meta-train with.

This function performs unrolls for each of the unroll_states using the corresponding gradient estimator. The results from all of the gradient estimators are merged into a single gradient. This aggregation is done to save bandwidth when collecting gradients from workers.
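The merge step can be sketched as follows. This is a minimal illustration, not the library's code: it assumes each estimator's output is a dict of NumPy arrays standing in for a JAX pytree, and that estimates are combined with an unweighted mean (the actual weighting inside gradient_worker_compute is an assumption here). The helper name merge_gradient_estimates is hypothetical.

```python
import numpy as np

def merge_gradient_estimates(grads):
    """Average a list of gradient estimates into a single gradient.

    Each estimate is a dict mapping parameter names to arrays (a stand-in
    for a JAX pytree). Assumes all estimates share the same keys and shapes
    and that an unweighted mean is the desired aggregation.
    """
    if not grads:
        raise ValueError("need at least one gradient estimate")
    keys = grads[0].keys()
    return {k: np.mean([g[k] for g in grads], axis=0) for k in keys}

# Two hypothetical estimators (e.g. one PES, one FullES) produce estimates
# for the same meta-parameters theta.
est_a = {"w": np.array([1.0, 2.0]), "b": np.array(0.5)}
est_b = {"w": np.array([3.0, 0.0]), "b": np.array(-0.5)}
merged = merge_gradient_estimates([est_a, est_b])
print(merged["w"], merged["b"])  # [2. 1.] 0.0
```

However many estimators run on a worker, only one gradient with the shape of the meta-parameters is sent back, which is where the bandwidth saving comes from.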

Parameters:
  • worker_weights (WorkerWeights) – Weights created by the GradientLearner that represent the current parameters and model state of the learned optimizer.

  • gradient_estimators (Sequence[GradientEstimator]) – The gradient estimators used to update the unroll states.

  • unroll_states (Sequence[GradientEstimatorState]) – States of the gradient estimators (e.g. inner-problem weights).

  • key (Array) – JAX PRNG key.

  • with_metrics (bool) – Whether to compute summary metrics.

  • clip_nan_loss_to_value (Optional[float]) – Value to set NaN losses to.

  • extra_metrics (bool) – Whether to log additional metrics.

  • device (Optional[Device]) – The JAX device to run the computation on.
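A minimal sketch of what clip_nan_loss_to_value describes, assuming the replacement is applied elementwise to a vector of unroll losses before they enter the gradient estimate. The helper clip_nan_loss is hypothetical, and where exactly the library applies this clip is not specified here.

```python
import numpy as np

def clip_nan_loss(loss, clip_nan_loss_to_value=20.0):
    """Replace NaN entries of `loss` with a fixed value (hypothetical helper).

    Passing None leaves the losses untouched, matching the Optional[float]
    type of the parameter.
    """
    loss = np.asarray(loss, dtype=np.float64)
    if clip_nan_loss_to_value is None:
        return loss
    # Elementwise: keep finite losses, substitute the clip value for NaNs.
    return np.where(np.isnan(loss), clip_nan_loss_to_value, loss)

losses = np.array([1.3, np.nan, 0.7])
print(clip_nan_loss(losses))
```

Substituting a large finite value, rather than dropping the sample, keeps a diverged unroll in the estimate while marking it as a bad outcome.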

Returns:

The results of the computation.

This contains a gradient estimate, the next unroll states, and metrics, a subset of which is passed to the GradientLearner.

Return type:

worker_compute_out

learned_optimization.outer_trainers.gradient_learner.SingleMachineGradientLearner(meta_init, gradient_estimators, theta_opt, num_steps=None)

Train with gradient estimators on a single machine.

This is a convenience wrapper that calls the multi-worker interface, namely both GradientLearner and gradient_worker_compute.
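To illustrate the structure such a wrapper hides, here is a toy single-machine outer-training loop. Everything in it is hypothetical: ToyGradientLearner stands in for GradientLearner, worker_compute stands in for gradient_worker_compute, the meta-objective is a 1-D quadratic, and plain SGD plays the role of theta_opt. It shows only the division of labor (the learner owns the meta-parameters, the worker turns them into a gradient estimate), not the real API.

```python
import numpy as np

# Toy meta-objective: one meta-parameter theta, outer loss (theta - 3)^2.
# All names here are hypothetical stand-ins, not learned_optimization APIs.

def worker_compute(theta, rng):
    """Stand-in for gradient_worker_compute: return a noisy gradient estimate."""
    true_grad = 2.0 * (theta - 3.0)
    return true_grad + 0.01 * rng.standard_normal()

class ToyGradientLearner:
    """Stand-in for GradientLearner: owns theta and applies outer updates."""

    def __init__(self, theta_init, lr):
        self.theta = theta_init
        self.lr = lr

    def worker_weights(self):
        # In the multi-worker setting, this is what gets shipped to workers.
        return self.theta

    def update(self, grad):
        # Plain SGD plays the role of theta_opt here.
        self.theta = self.theta - self.lr * grad

def single_machine_train(num_steps, seed=0):
    """Alternate worker computation and learner updates in one process."""
    rng = np.random.default_rng(seed)
    learner = ToyGradientLearner(theta_init=0.0, lr=0.1)
    for _ in range(num_steps):
        grad = worker_compute(learner.worker_weights(), rng)
        learner.update(grad)
    return learner.theta

print(single_machine_train(num_steps=100))
```

On a cluster the two halves run in different processes and exchange worker weights and gradients; on one machine they can simply be called in sequence.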

Parameters:
  • meta_init –

  • gradient_estimators –

  • theta_opt –

  • num_steps –

FullES

A vectorized, ES-based gradient estimator that uses full-length unrolls.

learned_optimization.outer_trainers.full_es.FullES(truncated_step, truncation_schedule, std=0.01, steps_per_jit=10, loss_type='avg', recompute_samples=50, clip_loss_diff=None, stack_antithetic_samples=False, sign_delta_loss_scalar=None)

Gradient estimator for computing ES gradients over some number of tasks.

This gradient estimator uses antithetic evolutionary strategies with num_tasks samples.
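A minimal NumPy sketch of an antithetic ES gradient estimate, assuming a loss function over a flat parameter vector. In FullES the loss would come from a full inner-training unroll of the learned optimizer, vectorized over tasks; here a quadratic stands in so the estimate can be checked against the exact gradient. The function antithetic_es_grad is hypothetical, and its std corresponds only loosely to the std argument above.

```python
import numpy as np

def antithetic_es_grad(loss_fn, theta, std, num_samples, rng):
    """Antithetic ES estimate of the gradient of loss_fn at theta.

    Each sample draws eps ~ N(0, std^2 I), evaluates the loss at theta + eps
    and theta - eps, and weights eps by the loss difference. Since
    E[(L(theta + eps) - L(theta - eps)) * eps] ~= 2 * std^2 * grad L(theta),
    the accumulated sum is divided by 2 * std^2 * num_samples.
    """
    grad = np.zeros_like(theta)
    for _ in range(num_samples):
        eps = std * rng.standard_normal(theta.shape)
        delta = loss_fn(theta + eps) - loss_fn(theta - eps)
        grad += delta * eps
    return grad / (2.0 * std**2 * num_samples)

rng = np.random.default_rng(0)
theta = np.array([1.0, -2.0, 0.5])
g = antithetic_es_grad(lambda t: np.sum(t**2), theta, std=0.01,
                       num_samples=10_000, rng=rng)
print(g)  # approximately the true gradient 2 * theta = [2, -4, 1]
```

Evaluating the pair theta + eps and theta - eps cancels the even-order terms of the loss in each difference, which is why antithetic sampling typically lowers variance compared with one-sided perturbations.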

The loss for a given unroll is computed either with the average loss over a trajectory (loss_type="avg"), the min loss (loss_type="min"), or with a computation at the end of training (loss_type="last_recompute").
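The three loss_type options can be sketched as different reductions of one unroll. This is an illustration under assumptions: per-step losses are given as a list, final_loss_fn is a hypothetical callable that evaluates the final inner parameters on a freshly sampled batch, and recompute_samples is assumed to control how many such evaluations are averaged for "last_recompute".

```python
import numpy as np

def unroll_loss(trajectory_losses, loss_type, final_loss_fn=None,
                recompute_samples=50):
    """Reduce one unroll to the scalar loss that ES differences are taken on.

    trajectory_losses: per-step losses recorded during the unroll.
    final_loss_fn: hypothetical callable; final_loss_fn(i) returns the loss
        of the final inner parameters on the i-th freshly sampled batch.
    """
    trajectory_losses = np.asarray(trajectory_losses, dtype=np.float64)
    if loss_type == "avg":
        # Mean loss over the whole trajectory.
        return float(np.mean(trajectory_losses))
    if loss_type == "min":
        # Best loss reached at any point in the trajectory.
        return float(np.min(trajectory_losses))
    if loss_type == "last_recompute":
        # Ignore the trajectory and re-evaluate the end of training,
        # averaging over recompute_samples evaluations to reduce noise.
        if final_loss_fn is None:
            raise ValueError("last_recompute needs final_loss_fn")
        return float(np.mean([final_loss_fn(i) for i in range(recompute_samples)]))
    raise ValueError(f"unknown loss_type: {loss_type!r}")

losses = [3.0, 2.0, 2.5]
print(unroll_loss(losses, "avg"), unroll_loss(losses, "min"))  # 2.5 2.0
```

"avg" rewards fast progress throughout training, "min" rewards the best point reached, and "last_recompute" targets only final performance.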

This gradient estimator manages its own truncations, so you should ensure the passed-in truncated_step does not do any resetting of the unroll. If using VectorizedLOptTruncatedStep, this means passing an instance of NeverEndingTruncationSchedule as the trunc_sched.

Parameters:
  • truncated_step (VectorizedTruncatedStep) –

  • truncation_schedule (TruncationSchedule) –

  • std (float) –

  • steps_per_jit (int) –

  • loss_type (str) –

  • recompute_samples (int) –

  • clip_loss_diff (Optional[float]) –

  • stack_antithetic_samples (bool) –

  • sign_delta_loss_scalar (Optional[float]) –

TruncatedPES

A vectorized, truncated, PES-based gradient estimator.

learned_optimization.outer_trainers.truncated_pes.TruncatedPES(truncated_step, trunc_length=10, std=0.01, steps_per_jit=10, stack_antithetic_samples=False, sign_delta_loss_scalar=None)

GradientEstimator for computing PES gradient estimates.

Persistent Evolution Strategies (PES) is a gradient estimation technique for computing unbiased gradient estimates in an unrolled computation graph. It does this by building on Evolutionary Strategies while additionally keeping a running sum of all the previously used perturbations. See the paper for more details (http://proceedings.mlr.press/v139/vicol21a.html).
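A minimal NumPy sketch of antithetic PES on a toy unrolled system, written from the description in the paper rather than from TruncatedPES itself. Assumptions: the hypothetical inner problem updates a state by s = s + theta each step with per-step loss ||s - 1||^2, each antithetic pair carries its own two states across truncations, and toy_unroll and pes_total_grad are illustrative names. The key PES detail is that each truncation's loss difference multiplies xi, the running sum of every perturbation the particle has used so far, rather than only the current perturbation.

```python
import numpy as np

def toy_unroll(states, theta, num_steps, target):
    """Hypothetical inner problem: s <- s + theta, loss = ||s - target||^2.

    states and theta have shape (num_pairs, dim). Returns the new states and
    each particle's loss summed over the steps of this truncation.
    """
    total = np.zeros(states.shape[0])
    for _ in range(num_steps):
        states = states + theta
        total += np.sum((states - target) ** 2, axis=-1)
    return states, total

def pes_total_grad(theta, num_truncations, trunc_length, std, num_pairs,
                   target, rng):
    """Sum of antithetic PES estimates over all truncations of one unroll."""
    dim = theta.shape[0]
    s_pos = np.zeros((num_pairs, dim))  # states of the +eps particles
    s_neg = np.zeros((num_pairs, dim))  # states of the -eps particles
    xi = np.zeros((num_pairs, dim))     # running sum of perturbations used
    grad = np.zeros(dim)
    for _ in range(num_truncations):
        eps = std * rng.standard_normal((num_pairs, dim))
        xi += eps
        s_pos, loss_pos = toy_unroll(s_pos, theta + eps, trunc_length, target)
        s_neg, loss_neg = toy_unroll(s_neg, theta - eps, trunc_length, target)
        # Weight the accumulated perturbation xi, not just this truncation's eps.
        grad += np.mean(xi * (loss_pos - loss_neg)[:, None], axis=0) / (2 * std**2)
    return grad

rng = np.random.default_rng(0)
theta = np.array([0.1, -0.2])
g = pes_total_grad(theta, num_truncations=4, trunc_length=5, std=0.01,
                   num_pairs=100_000, target=1.0, rng=rng)
# Here s_t = t * theta, so over T = 20 steps the exact gradient of
# sum_t ||t * theta - 1||^2 is sum_t 2 t (t * theta - 1).
t = np.arange(1, 21)[:, None]
exact = np.sum(2 * t * (t * theta - 1.0), axis=0)
print(g, exact)
```

Replacing xi with eps in the update would give a truncated-ES-style estimate that ignores how perturbations in earlier truncations affect later losses, which is the source of truncation bias that PES removes.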

In practice, PES has higher variance than pure truncated ES but lower bias.

Parameters:
  • truncated_step (VectorizedTruncatedStep) –

  • stack_antithetic_samples (bool) –

  • sign_delta_loss_scalar (Optional[float]) –