Source code for learned_optimization.learned_optimizers.base

# coding=utf-8
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Base class for learned optimizers plus learnable hparam variants."""
import abc
import collections
from typing import Any, Callable

import gin
import haiku as hk
import jax.numpy as jnp
from learned_optimization import summary
from learned_optimization.optimizers import base as opt_base

MetaParamOpt = collections.namedtuple("MetaParamOpt", ["init", "opt_fn"])

PRNGKey = jnp.ndarray
Params = Any
MetaParams = Any


[docs]class LearnedOptimizer(abc.ABC): """Base class for learned optimizers.""" @abc.abstractmethod def init(self, rng: PRNGKey) -> MetaParams: raise NotImplementedError() @abc.abstractmethod def opt_fn(self, theta: MetaParams, is_training: bool = False) -> opt_base.Optimizer: raise NotImplementedError()
Invertable = collections.namedtuple("Invertable", ["forward", "inverse"]) one_minus_log = Invertable( forward=lambda x: jnp.log(1 - x), inverse=lambda x: 1 - jnp.exp(x))
[docs]@gin.configurable class LearnableSGD(LearnedOptimizer): """SGD with learnable hparams.""" def __init__(self, initial_lr=0.01): self.initial_lr = initial_lr def init(self, rng: PRNGKey) -> MetaParams: return hk.data_structures.to_haiku_dict( {"log_lr": jnp.log(jnp.asarray(self.initial_lr))}) def opt_fn(self, theta, is_training=False) -> opt_base.Optimizer: lr = jnp.exp(theta["log_lr"]) summary.summary(theta["log_lr"], "learnable_sgd/pre_lr") summary.summary(lr, "learnable_sgd/lr") return opt_base.SGD(lr)
[docs]@gin.configurable class LearnableSGDM(LearnedOptimizer): """SGDM with learnable hparams.""" def __init__(self, initial_lr=0.01, initial_momentum=0.9): self.initial_lr = initial_lr self.initial_momentum = initial_momentum def init(self, rng: PRNGKey) -> MetaParams: return hk.data_structures.to_haiku_dict({ "log_lr": jnp.log(jnp.asarray(self.initial_lr)), "one_minus_momentum": one_minus_log.forward(self.initial_momentum) }) def opt_fn(self, theta: MetaParams, is_training: bool = False) -> opt_base.Optimizer: lr = jnp.exp(theta["log_lr"]) mom = one_minus_log.inverse(theta["one_minus_momentum"]) summary.summary(theta["log_lr"], "learnable_sgdm/pre_lr") summary.summary(lr, "learnable_sgdm/lr") summary.summary(theta["one_minus_momentum"], "learnable_sgdm/pre_mom") summary.summary(mom, "learnable_sgdm/mom") return opt_base.SGDM(lr, mom)
[docs]@gin.configurable class LearnableAdam(LearnedOptimizer): """Adam with learnable hparams.""" def __init__(self, initial_lr=0.001, initial_beta1=0.9, initial_beta2=0.999, initial_epsilon=1e-8, use_summary=True): self.initial_lr = initial_lr self.initial_beta1 = initial_beta1 self.initial_beta2 = initial_beta2 self.initial_epsilon = initial_epsilon self.use_summary = use_summary def init(self, rng: PRNGKey) -> MetaParams: return hk.data_structures.to_haiku_dict({ "log_lr": jnp.log(jnp.asarray(self.initial_lr)), "one_minus_beta1": one_minus_log.forward(self.initial_beta1), "one_minus_beta2": one_minus_log.forward(self.initial_beta2), "log_epsilon": jnp.log(self.initial_epsilon), }) def opt_fn(self, theta: MetaParams, is_training: bool = False) -> opt_base.Optimizer: lr = jnp.exp(theta["log_lr"]) beta1 = one_minus_log.inverse(theta["one_minus_beta1"]) beta2 = one_minus_log.inverse(theta["one_minus_beta2"]) eps = jnp.exp(theta["log_epsilon"]) if self.use_summary: summary.summary(theta["log_lr"], "learnable_adam/pre_lr") summary.summary(lr, "learnable_adam/lr") summary.summary(theta["one_minus_beta1"], "learnable_adam/pre_beta1") summary.summary(beta1, "learnable_adam/beta1") summary.summary(theta["one_minus_beta2"], "learnable_adam/pre_beta2") summary.summary(beta2, "learnable_adam/beta2") summary.summary(theta["log_epsilon"], "learnable_adam/pre_epsilon") summary.summary(eps, "learnable_adam/epsilon") return opt_base.Adam(lr, beta1, beta2, eps)
def learned_optimizer_from_opt(opt: opt_base.Optimizer) -> LearnedOptimizer: """Create a learned optimizer out of a baseline optimizer. Note this does not have any learnable parameters. Args: opt: Optimizer to turn into the LearnedOptimizer interface. Returns: The wrapped learned optimizer. """ class LOpt(LearnedOptimizer): def init(self, rng): return None def opt_fn(self, theta, is_training=False): return opt return LOpt() @gin.configurable def wrap_learned_opt( learned_opt: LearnedOptimizer, opt_wrapper: Callable[[opt_base.Optimizer], opt_base.Optimizer] ) -> LearnedOptimizer: """Wrap a learned optimizer with a wrapper for to Optimizers.""" class LOpt(LearnedOptimizer): def init(self, rng): return learned_opt.init(rng) def opt_fn(self, theta, is_training=False): return opt_wrapper(learned_opt.opt_fn(theta)) return LOpt()