Source code for learned_optimization.optimizers.base

# coding=utf-8
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Base class for Optimizer and a couple hand designed optimizer."""

import abc
from typing import Any, Optional, Tuple
import warnings

import chex
import flax
import gin
import jax
import jax.numpy as jnp
from learned_optimization import tree_utils


# pytree containing jax types
ModelState = Any
Params = Any
Gradient = Params
OptState = Any


@flax.struct.dataclass
class StatelessState:
  """Optimizer state for optimizers which keep no extra accumulators.

  Holds only the parameters being optimized and any model state
  (e.g. batch norm statistics).
  """
  params: chex.ArrayTree
  state: chex.ArrayTree


class Optimizer(abc.ABC):
  """Base class for the Optimizer interface."""

  def get_params(self, state: OptState) -> Params:
    return state.params

  def get_state(self, state: OptState) -> ModelState:
    return state.state

  def get_params_state(self, state: OptState) -> Tuple[Params, ModelState]:
    return self.get_params(state), self.get_state(state)

  def init(self,
           params: Params,
           state: Optional[ModelState] = None,
           num_steps: Optional[int] = None,
           key: Optional[chex.PRNGKey] = None,
           **kwargs) -> OptState:
    raise NotImplementedError()

  def set_params(self, state: OptState, params: Params) -> OptState:
    return state.replace(params=params)

  def update(
      self,
      opt_state: OptState,
      grad: Gradient,
      model_state: Optional[ModelState] = None,
      key: Optional[chex.PRNGKey] = None,
      **kwargs,
  ) -> OptState:
    raise NotImplementedError()

  @property
  def name(self) -> str:
    """Name of the optimizer.

    This property is used when serializing results / baselines. The name
    should lead with the class name, followed by all parameters used to
    create the object, e.g. "<ClassName>_<param1><value>_<param2><value>".
    """
    return "UnnamedOptimizer"
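
# Illustrative sketch (not part of the original module): a minimal concrete
# Optimizer implementing the interface above. It stores parameters and model
# state in StatelessState and applies plain gradient descent. The class name
# and learning-rate argument are hypothetical.
class _ExampleSGD(Optimizer):
  """Minimal gradient descent optimizer, for illustration only."""

  def __init__(self, learning_rate: float = 0.01):
    self.learning_rate = learning_rate

  def init(self, params, state=None, num_steps=None, key=None, **kwargs):
    return StatelessState(params=params, state=state)

  def update(self, opt_state, grad, model_state=None, key=None, **kwargs):
    # Step each leaf of the params pytree against the matching gradient leaf.
    new_params = jax.tree_util.tree_map(
        lambda p, g: p - self.learning_rate * g, opt_state.params, grad)
    return StatelessState(params=new_params, state=model_state)

  @property
  def name(self):
    return f"_ExampleSGD_lr{self.learning_rate}"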

@gin.configurable
class GradientClipOptimizer(Optimizer):
  """Clip gradients by value before passing them into an optimizer."""

  def __init__(self, opt: Optimizer, grad_clip: float = 3.0):
    if not isinstance(opt, Optimizer):
      raise ValueError("Must be an instance of Optimizer. Maybe you are"
                       f" passing the class and not an instance? Received {opt}.")
    self.opt = opt
    self.grad_clip = grad_clip

  def get_params(self, state):
    return self.opt.get_params(state)

  def get_state(self, state):
    return self.opt.get_state(state)

  def init(self, *args, **kwargs):
    return self.opt.init(*args, **kwargs)

  def update(self, opt_state, grad, *args, **kwargs):
    grad = jax.tree_util.tree_map(
        lambda x: jnp.clip(x, -self.grad_clip, self.grad_clip), grad)
    return self.opt.update(opt_state, grad, *args, **kwargs)


@flax.struct.dataclass
class GraftedOptimizerState:
  iteration: jnp.ndarray
  params: chex.ArrayTree
  state: chex.ArrayTree
  mag_opt_state: chex.ArrayTree
  dir_opt_state: chex.ArrayTree


@gin.configurable()
class GraftedOptimizer(Optimizer):
  """Implements learning rate grafting.

  Takes the step direction from `direction_opt` and rescales it to the norm
  of the step `magnitude_opt` would have taken.

  Reference: https://openreview.net/forum?id=FpKgG31Z_i9
  """

  def __init__(self, magnitude_opt: Optimizer, direction_opt: Optimizer):
    self.magnitude_opt = magnitude_opt
    self.direction_opt = direction_opt

  def init(self, params, model_state=None, num_steps=None, **kwargs):
    return GraftedOptimizerState(
        iteration=jnp.asarray(0, dtype=jnp.int32),
        params=params,
        state=model_state,
        mag_opt_state=self.magnitude_opt.init(
            params, model_state=model_state, num_steps=num_steps, **kwargs),
        dir_opt_state=self.direction_opt.init(
            params, model_state=model_state, num_steps=num_steps, **kwargs))

  def update(self, opt_state, grad, model_state=None, **kwargs):  # pytype: disable=signature-mismatch  # overriding-parameter-count-checks
    base_params = opt_state.params
    next_mag_opt_state = self.magnitude_opt.update(
        opt_state.mag_opt_state, grad, model_state=model_state, **kwargs)
    next_mag_params = self.magnitude_opt.get_params(next_mag_opt_state)

    next_dir_opt_state = self.direction_opt.update(
        opt_state.dir_opt_state, grad, model_state=model_state, **kwargs)
    next_dir_params = self.direction_opt.get_params(next_dir_opt_state)

    mag_step = tree_utils.tree_sub(next_mag_params, base_params)
    dir_step = tree_utils.tree_sub(next_dir_params, base_params)

    # Rescale the direction step to the norm of the magnitude step.
    step_size = tree_utils.tree_norm(mag_step) / tree_utils.tree_norm(dir_step)
    next_params = tree_utils.tree_add(base_params,
                                      tree_utils.tree_mul(dir_step, step_size))

    next_dir_opt_state = next_dir_opt_state.replace(params=next_params)
    next_mag_opt_state = next_mag_opt_state.replace(params=next_params)

    return GraftedOptimizerState(
        iteration=opt_state.iteration + 1,
        params=next_params,
        state=model_state,
        mag_opt_state=next_mag_opt_state,
        dir_opt_state=next_dir_opt_state,
    )
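
# Illustrative usage sketch (not part of the original module): graft the step
# magnitude of Adam onto the step direction of SGD, with gradient clipping on
# top. optax_opts.Adam / optax_opts.SGD are the optimizers referenced by the
# deprecation shims below; taking a learning rate as the first argument is an
# assumption here. Wrapped in a function so nothing runs at import time.
def _example_grafted_usage():  # Hypothetical name, for illustration only.
  from learned_optimization.optimizers import optax_opts  # pylint: disable=g-import-not-at-top

  opt = GradientClipOptimizer(
      GraftedOptimizer(
          magnitude_opt=optax_opts.Adam(1e-3),
          direction_opt=optax_opts.SGD(1.0)),
      grad_clip=3.0)

  params = {"w": jnp.ones((4, 2))}
  grads = {"w": jnp.ones((4, 2))}
  opt_state = opt.init(params)
  opt_state = opt.update(opt_state, grads)
  return opt.get_params(opt_state)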

# TODO(lmetz) remove these in May 2022.
def SGD(*args, **kwargs):  # pylint: disable=invalid-name
  from learned_optimization.optimizers import optax_opts  # pytype: disable=import-error  # pylint: disable=g-import-not-at-top
  warnings.warn("SGD has been moved to optax_opts!"
                " Calling it from base is deprecated!")
  return optax_opts.SGD(*args, **kwargs)


def SGDM(*args, **kwargs):  # pylint: disable=invalid-name
  from learned_optimization.optimizers import optax_opts  # pytype: disable=import-error  # pylint: disable=g-import-not-at-top
  warnings.warn("SGDM has been moved to optax_opts!"
                " Calling it from base is deprecated!")
  return optax_opts.SGDM(*args, **kwargs)

def RMSProp(*args, **kwargs):  # pylint: disable=invalid-name
  from learned_optimization.optimizers import optax_opts  # pytype: disable=import-error  # pylint: disable=g-import-not-at-top
  warnings.warn("RMSProp has been moved to optax_opts!"
                " Calling it from base is deprecated!")
  return optax_opts.RMSProp(*args, **kwargs)


def Adam(*args, **kwargs):  # pylint: disable=invalid-name
  from learned_optimization.optimizers import optax_opts  # pytype: disable=import-error  # pylint: disable=g-import-not-at-top
  warnings.warn("Adam has been moved to optax_opts!"
                " Calling it from base is deprecated!")
  return optax_opts.Adam(*args, **kwargs)