Outer Trainers
Outer-training, or meta-training, is the process of learning the meta-parameters.
Base
- learned_optimization.outer_trainers.gradient_learner.GradientEstimator()[source]
Base class for classes which estimate gradients (via ES, PES, or backprop).
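Concrete estimators subclass this and implement two methods. The skeleton below is a minimal sketch: the method names init_worker_state and compute_gradient_estimate follow the library's documentation, but treat the exact signatures as assumptions to verify against your installed version.

```python
from learned_optimization.outer_trainers import gradient_learner

class MyEstimator(gradient_learner.GradientEstimator):
  """Sketch of a custom estimator (signatures assumed, not authoritative)."""

  def init_worker_state(self, worker_weights, key):
    # Build the initial unroll state for this estimator
    # (e.g. freshly initialized inner-problem weights).
    raise NotImplementedError()

  def compute_gradient_estimate(self, worker_weights, key, state, with_summary=False):
    # Run an unroll starting from `state` and return the gradient
    # estimate plus the carried-over state for the next unroll.
    raise NotImplementedError()
```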
- learned_optimization.outer_trainers.gradient_learner.gradient_worker_compute(worker_weights, gradient_estimators, unroll_states, key, with_metrics, clip_nan_loss_to_value=20.0, extra_metrics=True, device=None)[source]
Compute a gradient signal to meta-train with.
This function performs unrolls for each of the unroll_states with the corresponding gradient_estimator. The results from each of the gradient estimators are merged into a single gradient. This aggregation is done to save bandwidth when collecting gradients from workers. A usage sketch follows the parameter list below.
- Parameters:
  - worker_weights (WorkerWeights) – Weights created by the GradientLearner, representing the current parameters and model state of the learned optimizer.
  - gradient_estimators (Sequence[GradientEstimator]) – The gradient estimators used to update the unroll state.
  - unroll_states (Sequence[GradientEstimatorState]) – State of each gradient estimator (e.g. inner problem weights).
  - key (Array) – jax rng.
  - with_metrics (bool) – Whether to compute summary metrics.
  - clip_nan_loss_to_value (Optional[float]) – Value to set NaN losses to.
  - extra_metrics (bool) – Log out additional metrics.
  - device (Optional[Device]) – The jax device to run the computation on.
- Returns:
- The results of the computation. This contains a gradient estimate, the next unroll states, and metrics, a subset of which are passed to the GradientLearner.
- Return type:
worker_compute_out
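For context, here is a minimal single-machine sketch of the flow this function sits in, modeled on the library's tutorials. The constructors used (FixedDimQuadraticFamily, LearnableAdam, VectorizedLOptTruncatedStep, opt_base.Adam) and the GradientLearner methods init and get_state_for_worker are assumptions drawn from those tutorials; verify them against your installed version.

```python
import jax
from learned_optimization.learned_optimizers import base as lopt_base
from learned_optimization.optimizers import base as opt_base
from learned_optimization.outer_trainers import (
    gradient_learner, lopt_truncated_step, truncated_pes, truncation_schedule)
from learned_optimization.tasks import quadratics

key = jax.random.PRNGKey(0)
task_family = quadratics.FixedDimQuadraticFamily(10)  # assumed toy task family
lopt = lopt_base.LearnableAdam()                      # assumed learned optimizer
truncated_step = lopt_truncated_step.VectorizedLOptTruncatedStep(
    task_family, lopt, truncation_schedule.ConstantTruncationSchedule(10),
    num_tasks=4)
estimators = [truncated_pes.TruncatedPES(truncated_step=truncated_step)]

# The GradientLearner holds theta; a worker pulls weights, unrolls, returns grads.
central = gradient_learner.GradientLearner(lopt, opt_base.Adam(1e-3))
central_state = central.init(key)
worker_weights = central.get_state_for_worker(central_state)
unroll_states = [e.init_worker_state(worker_weights, key) for e in estimators]

out = gradient_learner.gradient_worker_compute(
    worker_weights, estimators, unroll_states, key=key, with_metrics=False)
# `out` bundles the aggregated gradient estimate, the next unroll states, and
# metrics; the gradient is what a worker ships back to the GradientLearner.
```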
- learned_optimization.outer_trainers.gradient_learner.SingleMachineGradientLearner(meta_init, gradient_estimators, theta_opt, num_steps=None)[source]
Train with gradient estimators on a single machine.
This is a convenience wrapper calling the multi-worker interface, namely both GradientLearner and gradient_worker_compute. A usage sketch follows the parameter list below.
- Parameters:
  - meta_init (MetaInitializer) –
  - gradient_estimators (Sequence[GradientEstimator]) –
  - theta_opt (Optimizer) –
FullES
A vectorized, full-length-unroll, ES-based gradient estimator.
- learned_optimization.outer_trainers.full_es.FullES(truncated_step, truncation_schedule, std=0.01, steps_per_jit=10, loss_type='avg', recompute_samples=50, clip_loss_diff=None, stack_antithetic_samples=False, sign_delta_loss_scalar=None)[source]
Gradient estimator for computing ES gradients over some number of tasks.
This gradient estimator uses antithetic evolutionary strategies with num_tasks samples.
The loss for a given unroll is computed either as the average loss over the trajectory (loss_type="avg"), the min loss (loss_type="min"), or with a computation at the end of training (loss_type="last_recompute").
This gradient estimator manages its own truncations, so ensure the passed-in truncated_step does not do any resetting of the unroll. If using VectorizedLOptTruncatedStep, this means passing an instance of NeverEndingTruncationSchedule as the trunc_sched (see the sketch below).
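Continuing the earlier sketch, a hedged construction example: the inner truncated step gets a NeverEndingTruncationSchedule (so it never resets), while FullES receives the schedule that defines the true unroll length. The unroll length of 2000 is an arbitrary illustrative choice.

```python
from learned_optimization.outer_trainers import full_es

es_step = lopt_truncated_step.VectorizedLOptTruncatedStep(
    task_family, lopt,
    truncation_schedule.NeverEndingTruncationSchedule(),  # FullES manages resets itself
    num_tasks=8)
es_estimator = full_es.FullES(
    es_step,
    truncation_schedule=truncation_schedule.ConstantTruncationSchedule(2000),
    loss_type="avg")
```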
TruncatedPES
A vectorized, truncated, PES-based gradient estimator.
- learned_optimization.outer_trainers.truncated_pes.TruncatedPES(truncated_step, trunc_length=10, std=0.01, steps_per_jit=10, stack_antithetic_samples=False, sign_delta_loss_scalar=None)[source]
GradientEstimator for computing PES gradient estimates.
Persistent Evolution Strategies (PES) is a gradient estimation technique for computing unbiased gradients in an unrolled computation graph. It does this by building off of Evolutionary Strategies, but additionally keeps a running buffer of all previously used perturbations. See the paper for more details (http://proceedings.mlr.press/v139/vicol21a.html).
In practice, PES has higher variance than pure truncated ES, but lower bias.
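A hedged construction sketch, reusing the truncated-step helpers from the earlier examples. Here trunc_length sets how many inner steps run between successive PES gradient estimates; the schedule length of 1000 and num_tasks of 8 are arbitrary illustrative values.

```python
pes_step = lopt_truncated_step.VectorizedLOptTruncatedStep(
    task_family, lopt, truncation_schedule.ConstantTruncationSchedule(1000),
    num_tasks=8)
# Shorter unrolls give cheaper, more frequent meta-updates; PES's running
# perturbation buffer is what keeps the estimate unbiased across truncations.
pes_estimator = truncated_pes.TruncatedPES(pes_step, trunc_length=50, std=0.01)
```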