Source code for learned_optimization.optimizers.optax_opts

# coding=utf-8
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Thin wrapper on top of optax optimizers.

For these optimizers, see optax's implementation / docs for more details.
"""

import functools
from typing import Any, Callable, Optional, Sequence, Union

import chex
from flax import struct
import gin
import jax
import jax.numpy as jnp
from learned_optimization.optimizers import base
import optax

ModelState = Any
Params = Any
Gradient = Params
OptState = Any


@struct.dataclass
class OptaxState:
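  """State carried by `OptaxOptimizer` between steps.

  Holds the current `params`, the (optional) inner model `state`, the wrapped
  optax optimizer state, and the `iteration` counter.
  """
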
  params: chex.ArrayTree
  state: chex.ArrayTree
  optax_opt_state: chex.ArrayTree
  iteration: jnp.ndarray


class OptaxOptimizer(base.Optimizer):
  """Wrapper to convert optax optimizers into `Optimizers`."""

  def __init__(self, opt: optax.GradientTransformation):
    super().__init__()
    self.opt = opt

  def init(self,
           params: Params,
           model_state: Optional[ModelState] = None,
           num_steps: Optional[int] = None,
           key: Optional[chex.PRNGKey] = None):
    return OptaxState(  # pytype: disable=wrong-arg-types  # jax-ndarray
        params=params,
        optax_opt_state=self.opt.init(params),
        state=model_state,
        iteration=0,
    )

  @functools.partial(jax.jit, static_argnums=(0,))
  def update(self,
             opt_state: OptaxState,
             grad: Gradient,
             loss: Optional[jnp.ndarray] = None,
             model_state: Optional[ModelState] = None,
             key: Optional[chex.PRNGKey] = None,
             **kwargs):
    del loss
    update, new_opt_state = self.opt.update(grad, opt_state.optax_opt_state,
                                            opt_state.params)
    return OptaxState(
        state=model_state,
        params=optax.apply_updates(opt_state.params, update),
        optax_opt_state=new_opt_state,
        iteration=opt_state.iteration + 1,
    )
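
# A minimal usage sketch (illustrative, not part of the original module): any
# `optax.GradientTransformation` can be wrapped. `params` and `grads` below
# are assumed to be matching pytrees of arrays.
#
#   opt = OptaxOptimizer(optax.sgd(1e-2))
#   opt_state = opt.init(params)
#   opt_state = opt.update(opt_state, grads)
#   new_params = opt.get_params(opt_state)
#
# `get_params` is assumed to come from `base.Optimizer` and to return
# `opt_state.params`; the SM3 class below overrides it to undo its scalar
# reshaping.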


@gin.configurable
class SGD(OptaxOptimizer):
  """Stochastic gradient descent."""

  def __init__(self, learning_rate=0.01):
    self.learning_rate = learning_rate
    opt = optax.sgd(learning_rate)
    super().__init__(opt)

  @property
  def name(self):
    return f"SGD_lr{self.learning_rate}"

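# Configuration sketch (illustrative): because these classes are decorated
# with `@gin.configurable`, hyperparameters can either be passed directly,
# e.g. `SGD(0.1)`, or bound in a gin config file, e.g.
# `SGD.learning_rate = 0.1`. The `name` property then reads "SGD_lr0.1",
# which is convenient for logging.
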
@gin.configurable
class SGDM(OptaxOptimizer):
  """Stochastic gradient descent with momentum."""

  def __init__(self, learning_rate=0.01, momentum=0.9):
    self.learning_rate = learning_rate
    self.momentum = momentum
    opt = optax.sgd(learning_rate, momentum)
    super().__init__(opt)

  @property
  def name(self):
    return f"SGDM_lr{self.learning_rate}_m{self.momentum}"

@gin.configurable
class Adam(OptaxOptimizer):
  """Adam optimizer."""

  def __init__(self,
               learning_rate=0.01,
               beta1=0.9,
               beta2=0.999,
               epsilon=1e-8,
               epsilon_root=1e-8):
    self.learning_rate = learning_rate
    self.beta1 = beta1
    self.beta2 = beta2
    self.epsilon = epsilon
    self.epsilon_root = epsilon_root

    opt = optax.adam(
        learning_rate=learning_rate,
        b1=beta1,
        b2=beta2,
        eps=epsilon,
        eps_root=epsilon_root)
    super().__init__(opt)

  @property
  def name(self):
    return (f"Adam_lr{self.learning_rate}_b1{self.beta1}_b2{self.beta2}"
            f"_eps{self.epsilon}_epsroot{self.epsilon_root}")

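# A small worked example of `piecewise_linear` below (illustrative, not part
# of the original module): with times=(10, 20) and vals=(1.0, 0.0), the
# returned schedule is 1.0 for x <= 10, decays linearly to 0.0 between 10 and
# 20 (e.g. fn(15) evaluates to 0.5), and stays at 0.0 for x >= 20.
# `PiecewiseLinearAdam` feeds such a schedule to `optax.scale_by_schedule`.
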
def piecewise_linear(times: Sequence[float],
                     vals: Sequence[float]) -> Callable[[float], float]:
  """Returns a function which interpolates piecewise values."""
  times = jnp.asarray(times)
  vals = jnp.asarray(vals)

  def fn(x):
    if len(times) <= 1:
      assert len(vals) == 1
      return vals[0]

    vs = []
    all_before = jnp.all(x <= times)
    all_after = jnp.all(x >= times)
    for i in range(len(times) - 1):
      x1 = times[i]
      x2 = times[i + 1]
      y1 = vals[i]
      y2 = vals[i + 1]
      m = (y2 - y1) / (x2 - x1)
      v = (x - x1) * m + y1
      vs.append(v)

    idx = jnp.sum(x > times) - 1
    mid = jnp.take(jnp.asarray(vs), idx)
    return all_before * vals[0] + all_after * vals[-1] + mid * (
        (1 - all_before) * (1 - all_after))

  return fn


@gin.configurable
class PiecewiseLinearAdam(OptaxOptimizer):
  """Adam with a piecewise linear learning rate schedule."""

  def __init__(self,
               times=(10000, 20000),
               lrs=(1e-4, 1e-5),
               beta1=0.9,
               beta2=0.999,
               epsilon=1e-8,
               epsilon_root=1e-8):
    opt = optax.chain(
        optax.scale_by_adam(
            b1=beta1, b2=beta2, eps=epsilon, eps_root=epsilon_root),
        optax.scale_by_schedule(piecewise_linear(times, vals=lrs)),
        optax.scale(-1),
    )
    super().__init__(opt)


@gin.configurable
class RMSProp(OptaxOptimizer):
  """RMSProp optimizer (including momentum)."""

  def __init__(
      self,
      learning_rate=0.01,
      decay=0.9,
      epsilon=1e-8,
      momentum=0.0,
      nesterov=False,
  ):
    self.learning_rate = learning_rate
    self.decay = decay
    self.epsilon = epsilon
    self.momentum = momentum
    self.nesterov = nesterov

    opt = optax.rmsprop(
        learning_rate=learning_rate,
        decay=decay,
        eps=epsilon,
        nesterov=nesterov,
        momentum=momentum,
    )
    super().__init__(opt)

  @property
  def name(self):
    return (f"RMSProp_lr{self.learning_rate}_d{self.decay}_eps{self.epsilon}"
            f"_m{self.momentum}_nesterov{self.nesterov}")


@gin.configurable
class AdaBelief(OptaxOptimizer):
  """AdaBelief optimizer."""

  def __init__(self,
               learning_rate=0.01,
               b1=0.9,
               b2=0.999,
               eps=1e-16,
               eps_root=1e-16):
    opt = optax.adabelief(
        learning_rate=learning_rate, b1=b1, b2=b2, eps=eps, eps_root=eps_root)
    super().__init__(opt)


@gin.configurable
class AdamW(OptaxOptimizer):
  """AdamW optimizer."""

  def __init__(
      self,
      learning_rate,
      b1: float = 0.9,
      b2: float = 0.999,
      eps: float = 1e-8,
      eps_root: float = 0.0,
      mu_dtype: Optional[Any] = None,
      weight_decay: float = 1e-4,
      mask: Optional[Union[Any, Callable[[Params], Any]]] = None,
  ):
    opt = optax.adamw(
        learning_rate=learning_rate,
        b1=b1,
        b2=b2,
        eps=eps,
        eps_root=eps_root,
        mu_dtype=mu_dtype,
        weight_decay=weight_decay,
        mask=mask)
    super().__init__(opt)


@gin.configurable
class Fromage(OptaxOptimizer):

  def __init__(self, learning_rate, min_norm: float = 1e-6):
    opt = optax.fromage(learning_rate, min_norm)
    super().__init__(opt)


@gin.configurable
class Lars(OptaxOptimizer):
  """Lars optimizer."""

  def __init__(self,
               learning_rate: float,
               weight_decay: float = 0.,
               weight_decay_mask=True,
               trust_coefficient: float = 0.001,
               eps: float = 0.,
               trust_ratio_mask=True,
               momentum: float = 0.9,
               nesterov: bool = False,
               min_norm: float = 1e-6):
    opt = optax.lars(
        learning_rate=learning_rate,
        weight_decay=weight_decay,
        weight_decay_mask=weight_decay_mask,
        trust_coefficient=trust_coefficient,
        eps=eps,
        trust_ratio_mask=trust_ratio_mask,
        momentum=momentum,
        nesterov=nesterov)
    super().__init__(opt)


@gin.configurable
class Lamb(OptaxOptimizer):
  """Lamb optimizer."""

  def __init__(self,
               learning_rate: float,
               b1: float = 0.9,
               b2: float = 0.999,
               eps: float = 1e-6,
               eps_root: float = 0.0,
               weight_decay: float = 0.,
               mask=None):
    opt = optax.lamb(
        learning_rate=learning_rate,
        b1=b1,
        b2=b2,
        eps=eps,
        eps_root=eps_root,
        weight_decay=weight_decay,
        mask=mask)
    super().__init__(opt)


@gin.configurable
class RAdam(OptaxOptimizer):
  """RAdam optimizer."""

  def __init__(self,
               learning_rate: float,
               b1: float = 0.9,
               b2: float = 0.999,
               eps: float = 1e-8,
               eps_root: float = 0.0,
               threshold: float = 5.0):
    opt = optax.radam(
        learning_rate=learning_rate,
        b1=b1,
        b2=b2,
        eps=eps,
        eps_root=eps_root,
        threshold=threshold)
    super().__init__(opt)


@struct.dataclass
class SM3OptState:
  params: Params
  state: ModelState
  optax_opt_state: Any
  iteration: int
  should_reshape: Any = struct.field(pytree_node=False)


def _expand_scalar(x, r):
  return jnp.expand_dims(x, 0) if r else x


def _sm3(
    learning_rate: float,
    momentum: float = 0.9,
    b2: float = 1.0,
):
  return optax.chain(
      optax.scale_by_sm3(momentum, b2=b2),
      optax.scale(-learning_rate),
  )


@gin.configurable
class SM3(OptaxOptimizer):
  """SM3 optimizer."""

  def __init__(self,
               learning_rate: float,
               momentum: float = 0.9,
               b2: float = 1.0):
    opt = _sm3(learning_rate=learning_rate, momentum=momentum, b2=b2)
    super().__init__(opt)

  # SM3 doesn't support scalars, so we have to reshape the params and grads.
  def init(
      self,
      params: Any,
      model_state: Optional[Any] = None,
      num_steps: Optional[int] = None,
      key: Optional[chex.PRNGKey] = None,
  ) -> SM3OptState:
    should_reshape = jax.tree_util.tree_map(lambda x: len(x.shape) == 0, params)  # pylint: disable=g-explicit-length-test
    params = jax.tree_util.tree_map(_expand_scalar, params, should_reshape)
    out = super().init(params, model_state, num_steps, key)
    return SM3OptState(  # pytype: disable=wrong-arg-types  # jax-ndarray
        params=out.params,
        state=out.state,
        optax_opt_state=out.optax_opt_state,
        iteration=out.iteration,
        should_reshape=should_reshape,
    )

  def update(self,
             opt_state: SM3OptState,
             grad: Any,
             loss: Optional[jnp.ndarray] = None,
             model_state: Optional[Any] = None,
             key: Optional[chex.PRNGKey] = None,
             **kwargs: Any) -> SM3OptState:
    grad = jax.tree_util.tree_map(_expand_scalar, grad,
                                  opt_state.should_reshape)
    out = super().update(opt_state, grad, loss, model_state, key, **kwargs)
    return SM3OptState(
        params=out.params,
        state=out.state,
        optax_opt_state=out.optax_opt_state,
        iteration=out.iteration,
        should_reshape=opt_state.should_reshape)

  def get_params(self, state: Any) -> Any:

    def _to_scalar(x, r):
      return jnp.squeeze(x, 0) if r else x

    return jax.tree_util.tree_map(_to_scalar, state.params,
                                  state.should_reshape)


@gin.configurable
class Yogi(OptaxOptimizer):
  """Yogi optimizer."""

  def __init__(self,
               learning_rate: float,
               b1: float = 0.9,
               b2: float = 0.999,
               eps: float = 1e-3):
    opt = optax.yogi(learning_rate=learning_rate, b1=b1, b2=b2, eps=eps)
    super().__init__(opt)


@gin.configurable
class Adafactor(OptaxOptimizer):
  """Adafactor optimizer."""

  def __init__(self,
               learning_rate: float,
               min_dim_size_to_factor: int = 128,
               decay_rate: float = 0.8,
               decay_offset: int = 0,
               multiply_by_parameter_scale: float = True,
               clipping_threshold: Optional[float] = 1.0,
               momentum: Optional[float] = None,
               dtype_momentum: Any = jnp.float32,
               weight_decay_rate: Optional[float] = None,
               eps: float = 1e-30,
               factored: bool = True,
               weight_decay_mask=None):
    opt = optax.adafactor(
        learning_rate=learning_rate,
        min_dim_size_to_factor=min_dim_size_to_factor,
        decay_rate=decay_rate,
        decay_offset=decay_offset,
        multiply_by_parameter_scale=multiply_by_parameter_scale,
        clipping_threshold=clipping_threshold,
        momentum=momentum,
        dtype_momentum=dtype_momentum,
        weight_decay_rate=weight_decay_rate,
        eps=eps,
        factored=factored,
        weight_decay_mask=weight_decay_mask)
    super().__init__(opt)


@gin.configurable
class AdaGrad(OptaxOptimizer):
  """AdaGrad optimizer."""

  def __init__(self,
               learning_rate: float,
               initial_accumulator_value: float = 0.1,
               eps: float = 1e-7):
    opt = optax.adagrad(
        learning_rate=learning_rate,
        initial_accumulator_value=initial_accumulator_value,
        eps=eps)
    super().__init__(opt)


# TODO(lmetz) deprecate and/or delete this!
# Put the basic optimizers in base namespace for compatibility.
base.SGD = SGD
base.Adam = Adam
base.RMSProp = RMSProp