Outer Trainers

Outer-training, or meta-training, is the process of learning the meta-parameters.

Base

learned_optimization.outer_trainers.gradient_learner.GradientEstimator()

Base class for classes that estimate gradients (via ES, PES, or backprop).

learned_optimization.outer_trainers.gradient_learner.gradient_worker_compute(worker_weights, gradient_estimators, unroll_states, key, with_metrics, device=None)

Compute a gradient signal to meta-train with.

This function performs unrolls for each of the unroll_states with the corresponding gradient_estimator. The results from the gradient estimators are merged into a single gradient. This aggregation is done to save bandwidth when collecting gradients from workers.

Parameters
  • worker_weights (WorkerWeights) – Weights created by the GradientLearner; they represent the current parameters and model state of the learned optimizer.

  • gradient_estimators (Sequence[GradientEstimator]) – The gradient estimators used to update the unroll state

  • unroll_states (Sequence[GradientEstimatorState]) – state of the gradient estimator (e.g. inner problem weights)

  • key (ndarray) – jax rng

  • with_metrics (bool) – compute with summary metrics or not

  • device (Optional[Device]) – The jax device to run the computation on

Returns

The results of the computation.

This contains a gradient estimate, the next unroll states, and metrics, a subset of which is passed to the GradientLearner.

Return type

worker_compute_out
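The aggregation described above can be sketched in plain Python. This is only an illustration, not the library's implementation: worker_compute_sketch, estimator_a, and estimator_b are hypothetical names, each estimator is modelled as a function from an unroll state to a (gradient, next state) pair, and merging by a simple mean is an assumption.

```python
from typing import Callable, List, Sequence, Tuple

Gradient = List[float]
# Hypothetical stand-in for a gradient estimator: state -> (gradient, next_state).
Estimator = Callable[[int], Tuple[Gradient, int]]


def worker_compute_sketch(
    estimators: Sequence[Estimator], unroll_states: Sequence[int]
) -> Tuple[Gradient, List[int]]:
    """Run each estimator on its own state, then merge the gradients."""
    grads = []
    next_states = []
    for estimator, state in zip(estimators, unroll_states):
        grad, next_state = estimator(state)
        grads.append(grad)
        next_states.append(next_state)
    # Assumption: gradients are merged with an element-wise mean, so only
    # one gradient has to be sent back regardless of the estimator count.
    merged = [sum(column) / len(grads) for column in zip(*grads)]
    return merged, next_states


def estimator_a(state: int) -> Tuple[Gradient, int]:
    return [1.0, 2.0], state + 1


def estimator_b(state: int) -> Tuple[Gradient, int]:
    return [3.0, 6.0], state + 10


merged_grad, new_states = worker_compute_sketch([estimator_a, estimator_b], [0, 0])
```

Each estimator keeps its own unroll state between calls, while the caller only ever receives one merged gradient.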

learned_optimization.outer_trainers.gradient_learner.SingleMachineGradientLearner(learned_opt, gradient_estimators, theta_opt, num_steps=None)

Train with gradient estimators on a single machine.

This is a convenience wrapper around the multi-worker interface, namely GradientLearner and gradient_worker_compute.

Parameters

FullES

A vectorized, full-length unroll, ES-based gradient estimator.

learned_optimization.outer_trainers.full_es.FullES(task_family, learned_opt, num_tasks, unroll_length=20, std=0.01, steps_per_jit=10, train_and_meta=False, loss_type='avg', recompute_samples=50, recompute_split='train', clip_loss_diff=None, stack_antithetic_samples=False)

Gradient estimator for computing ES gradients over some number of tasks.

This gradient estimator uses antithetic evolutionary strategies with num_tasks samples.

The loss for a given unroll is computed either with the average loss over a trajectory (loss_type="avg"), or with a computation at the end of training (loss_type="last").
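As a rough illustration of antithetic ES over a full unroll, the sketch below uses only the Python standard library. It is not the FullES implementation: the inner problem (gradient descent on f(x) = x^2), the single scalar meta-parameter (the step size), and the function names unroll_loss and antithetic_es_grad are all assumptions made for the example. Here num_tasks is reused as the number of antithetic pairs on one toy task.

```python
import random


def unroll_loss(lr: float, unroll_length: int = 20, loss_type: str = "avg") -> float:
    """Unroll gradient descent on f(x) = x^2 and summarize the trajectory."""
    x = 1.0
    losses = []
    for _ in range(unroll_length):
        x = x - lr * 2.0 * x  # the gradient of x^2 is 2x
        losses.append(x * x)
    if loss_type == "avg":
        return sum(losses) / len(losses)  # average loss over the trajectory
    if loss_type == "last":
        return losses[-1]  # loss at the end of training
    raise ValueError(f"unknown loss_type: {loss_type}")


def antithetic_es_grad(
    theta: float, num_tasks: int, std: float, loss_type: str, rng: random.Random
) -> float:
    """Average (L(theta + eps) - L(theta - eps)) / (2 std^2) * eps over samples."""
    total = 0.0
    for _ in range(num_tasks):
        eps = rng.gauss(0.0, std)
        pos = unroll_loss(theta + eps, loss_type=loss_type)
        neg = unroll_loss(theta - eps, loss_type=loss_type)
        total += (pos - neg) / (2.0 * std**2) * eps
    return total / num_tasks


rng = random.Random(0)
grad_avg = antithetic_es_grad(0.1, num_tasks=8, std=0.01, loss_type="avg", rng=rng)
grad_last = antithetic_es_grad(0.1, num_tasks=8, std=0.01, loss_type="last", rng=rng)
```

In this toy setting a larger step size lowers the loss, so both estimates come out negative, which is the direction a meta-optimizer minimizing the loss would follow to increase the step size.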

Parameters

TruncatedPES

A vectorized, truncated, PES-based gradient estimator.

learned_optimization.outer_trainers.truncated_pes.TruncatedPES(task_family, learned_opt, trunc_sched, num_tasks, trunc_length=10, random_initial_iteration_offset=0, std=0.01, steps_per_jit=10, train_and_meta=False)

GradientEstimator for computing PES gradient estimates.

Persistent Evolution Strategies (PES) is a gradient estimation technique for computing unbiased gradients in an unrolled computation graph. It does this by building on Evolution Strategies (ES) while additionally keeping a running sum of all the previously used perturbations. See the paper for more details (http://proceedings.mlr.press/v139/vicol21a.html).

In practice, PES has higher variance than pure truncated ES but lower bias.
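The idea of accumulating perturbations across truncations can be sketched as follows. This is a minimal, single-particle illustration in standard-library Python, not the TruncatedPES implementation: the toy inner problem (gradient descent on f(x) = x^2 with the step size as the meta-parameter) and the names truncated_unroll and pes_step are assumptions for the example.

```python
import random
from typing import List, Tuple

# (inner state of the positive particle, of the negative particle,
#  running sum of the perturbations applied to the positive particle)
PESState = Tuple[float, float, float]


def truncated_unroll(x: float, lr: float, trunc_length: int) -> Tuple[float, float]:
    """Run trunc_length inner steps; return the new state and the mean loss."""
    total = 0.0
    for _ in range(trunc_length):
        x = x - lr * 2.0 * x  # the gradient of x^2 is 2x
        total += x * x
    return x, total / trunc_length


def pes_step(
    theta: float, state: PESState, std: float, trunc_length: int, rng: random.Random
) -> Tuple[float, PESState]:
    """One truncation of antithetic PES for a scalar meta-parameter."""
    x_pos, x_neg, accum = state
    eps = rng.gauss(0.0, std)
    accum += eps  # unlike truncated ES, earlier perturbations are remembered
    x_pos, loss_pos = truncated_unroll(x_pos, theta + eps, trunc_length)
    x_neg, loss_neg = truncated_unroll(x_neg, theta - eps, trunc_length)
    # The loss difference is weighted by the accumulated perturbation
    # rather than by the current eps alone.
    grad = (loss_pos - loss_neg) / (2.0 * std**2) * accum
    return grad, (x_pos, x_neg, accum)


rng = random.Random(0)
state: PESState = (1.0, 1.0, 0.0)  # the sum is reset whenever the inner problem restarts
grads: List[float] = []
for _ in range(2):
    grad, state = pes_step(0.1, state, std=0.01, trunc_length=10, rng=rng)
    grads.append(grad)
```

Inner states persist from one truncation to the next, which is why the perturbation sum has to persist too. A practical estimator would average over many such particle pairs to reduce variance.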

Parameters
  • task_family (TaskFamily) –

  • learned_opt (LearnedOptimizer) –

  • trunc_sched (TruncationSchedule) –

  • num_tasks (int) –

  • random_initial_iteration_offset (int) –