Outer Trainers
Outer-training, or meta-training, is the process of learning the meta-parameters.
Base
- learned_optimization.outer_trainers.gradient_learner.GradientEstimator()[source]
Base class for classes which estimate grads (via ES, PES, or backprop).
- learned_optimization.outer_trainers.gradient_learner.gradient_worker_compute(worker_weights, gradient_estimators, unroll_states, key, with_metrics, device=None)[source]
Compute a gradient signal to meta-train with.
This function performs unrolls for each of the unroll_states with the corresponding gradient_estimator. The results from each of the gradient estimators get merged into a single gradient. This aggregation is done to save bandwidth when collecting gradients from workers.
- Parameters
  - worker_weights (WorkerWeights) – Weights created by the GradientLearner, representing the current parameters and model state of the learned optimizer.
  - gradient_estimators (Sequence[GradientEstimator]) – The gradient estimators used to update the unroll state.
  - unroll_states (Sequence[GradientEstimatorState]) – State of the gradient estimators (e.g. inner problem weights).
  - key (ndarray) – jax rng key.
  - with_metrics (bool) – Whether to compute summary metrics.
  - device (Optional[Device]) – The jax device to run the computation on.
- Returns
  - The results of the computation. This contains a gradient estimate, the next unroll states, and metrics, a subset of which gets passed to the GradientLearner.
- Return type
  - worker_compute_out
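The aggregation step can be illustrated with a small sketch. This is a conceptual, pure-NumPy illustration and not the library's API: each stand-in estimator produces its own meta-gradient estimate for the same meta-parameters, and the worker averages them into a single gradient before sending it on.

```python
import numpy as np

rng = np.random.default_rng(0)
theta = np.array([1.0, -2.0])  # meta-parameters

def toy_estimator(theta, rng):
    """Stand-in for a GradientEstimator: noisy gradient of |theta|^2."""
    return 2.0 * theta + 0.01 * rng.normal(size=theta.shape)

# Each estimator contributes one estimate; the worker merges by averaging,
# so only a single gradient (not one per estimator) leaves the worker.
estimates = np.stack([toy_estimator(theta, rng) for _ in range(4)])
merged_gradient = estimates.mean(axis=0)
```

Averaging is a natural merge because each estimate is (approximately) unbiased for the same quantity, so the mean stays unbiased while reducing variance and bandwidth.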
- learned_optimization.outer_trainers.gradient_learner.SingleMachineGradientLearner(learned_opt, gradient_estimators, theta_opt, num_steps=None)[source]
Train with gradient estimators on a single machine.
This is a convenience wrapper calling the multi-worker interface – namely both GradientLearner and gradient_worker_compute.
- Parameters
  - learned_opt (LearnedOptimizer) –
  - gradient_estimators (Sequence[GradientEstimator]) –
  - theta_opt (Optimizer) –
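The shape of a single-machine outer-training loop can be sketched as follows. This is a minimal pure-NumPy analogue, not the class's actual interface: the learned optimizer is reduced to one meta-parameter theta, the gradient estimator is antithetic ES on the meta-loss, and theta_opt is plain SGD.

```python
import numpy as np

rng = np.random.default_rng(0)

def meta_loss(theta):
    # Stand-in for "train a task with the learned optimizer and score it".
    return (theta - 0.7) ** 2

def es_gradient(theta, num_tasks=512, std=0.01):
    # Antithetic ES estimate of d(meta_loss)/d(theta).
    eps = rng.normal(scale=std, size=num_tasks)
    return np.mean(eps * (meta_loss(theta + eps) - meta_loss(theta - eps))) / (2 * std**2)

theta, lr = -1.0, 0.1  # initial meta-parameter and theta_opt step size
for _ in range(200):
    theta -= lr * es_gradient(theta)  # one outer-training step
```

The loop mirrors the division of labor described above: the estimator produces a meta-gradient, and a separate optimizer (here SGD) applies it to the meta-parameters.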
FullES
A vectorized, full-length unroll, ES-based gradient estimator.
- learned_optimization.outer_trainers.full_es.FullES(task_family, learned_opt, num_tasks, unroll_length=20, std=0.01, steps_per_jit=10, train_and_meta=False, loss_type='avg', recompute_samples=50, recompute_split='train', clip_loss_diff=None, stack_antithetic_samples=False)[source]
Gradient Estimator for computing ES gradients over some number of tasks.
This gradient estimator uses antithetic evolutionary strategies with num_tasks samples.
The loss for a given unroll is computed either with the average loss over a trajectory (loss_type="avg"), or with a computation at the end of training (loss_type="last").
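The estimate can be sketched in pure NumPy (this is an illustration of the antithetic full-unroll scheme described above, not the FullES class itself). Each "task" draws one perturbation of the meta-parameter theta (here an inner learning rate), unrolls the full inner problem twice with theta ± epsilon, and scores each unroll either by the mean loss along the trajectory ("avg") or by the final loss ("last").

```python
import numpy as np

rng = np.random.default_rng(0)

def unroll_loss(theta, loss_type="avg", w0=1.0, steps=10):
    # Full-length unroll: inner SGD on the loss w^2 with learning rate theta.
    w, losses = w0, []
    for _ in range(steps):
        w = w - theta * 2.0 * w
        losses.append(w ** 2)
    return np.mean(losses) if loss_type == "avg" else losses[-1]

def full_es_grad(theta, loss_type="avg", num_tasks=8192, std=0.01):
    # Antithetic ES over num_tasks perturbation samples.
    eps = rng.normal(scale=std, size=num_tasks)
    diffs = np.array([unroll_loss(theta + e, loss_type) -
                      unroll_loss(theta - e, loss_type) for e in eps])
    return np.mean(eps * diffs) / (2 * std ** 2)

g_avg = full_es_grad(0.3, "avg")    # gradient of the trajectory-average loss
g_last = full_es_grad(0.3, "last")  # gradient of the end-of-training loss
```

Because the unroll is run to completion for every sample, this estimator is unbiased for the chosen meta-loss; the two loss_type options simply change which meta-loss is being differentiated.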
- Parameters
TruncatedPES
A vectorized, truncated, PES-based gradient estimator.
- learned_optimization.outer_trainers.truncated_pes.TruncatedPES(task_family, learned_opt, trunc_sched, num_tasks, trunc_length=10, random_initial_iteration_offset=0, std=0.01, steps_per_jit=10, train_and_meta=False)[source]
GradientEstimator for computing PES gradient estimates.
Persistent Evolution Strategies (PES) is a gradient estimation technique for computing unbiased gradients in an unrolled computation graph. It does this by building off of Evolution Strategies but additionally keeping a running buffer of all the previously used perturbations. See the paper for more details (http://proceedings.mlr.press/v139/vicol21a.html).
In practice, PES has higher variance than pure truncated ES but lower bias.
- Parameters
  - task_family (TaskFamily) –
  - learned_opt (LearnedOptimizer) –
  - trunc_sched (TruncationSchedule) –
  - num_tasks (int) –
  - random_initial_iteration_offset (int) –
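The running perturbation buffer that distinguishes PES from truncated ES can be sketched as follows. This is a pure-NumPy illustration of the scheme, not the TruncatedPES class: a full unroll of length T is split into truncations of length K, and each antithetic pair keeps an accumulator xi of every perturbation it has used so far. Multiplying each truncation's antithetic loss difference by xi, rather than by the latest epsilon alone as truncated ES would, is what removes the truncation bias.

```python
import numpy as np

rng = np.random.default_rng(0)
T, K, N, std = 10, 5, 8192, 0.01  # unroll length, truncation length, pairs, ES std
theta = 0.3                        # meta-parameter: an inner learning rate

def inner_step(w, theta):
    return w - theta * 2.0 * w     # inner SGD step on the loss w^2

w_pos = np.ones(N)                 # positive antithetic particles
w_neg = np.ones(N)                 # negative antithetic particles
xi = np.zeros(N)                   # running sum of this pair's perturbations
pes_total = 0.0                    # summed PES estimates over all truncations

for _ in range(T // K):
    eps = rng.normal(scale=std, size=N)
    xi += eps                      # the PES buffer: all perturbations so far
    loss_pos = np.zeros(N)
    loss_neg = np.zeros(N)
    for _ in range(K):             # unroll one truncation with theta +/- eps
        w_pos = inner_step(w_pos, theta + eps)
        w_neg = inner_step(w_neg, theta - eps)
        loss_pos += w_pos ** 2
        loss_neg += w_neg ** 2
    pes_total += np.mean(xi * (loss_pos - loss_neg)) / (2 * std ** 2)
```

Summed over all truncations, the estimate matches the gradient of the total loss of the full unroll in expectation; the extra variance relative to plain truncated ES comes from the accumulated xi term.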