def _get_optimizer_preconditioner(optimizer: tf.keras.optimizers.Optimizer,
                                  model_weights: tff.learning.ModelWeights):
    """Build the per-variable preconditioner implied by the optimizer state.

    For Adagrad, Adam and Yogi the result is `1 / (sqrt(s) + epsilon)`
    computed for every entry of `model_weights.trainable`, where `s` is the
    optimizer slot of that variable (`accumulator` for Adagrad, `v` for Adam
    and Yogi) and `epsilon` is read from `optimizer.epsilon`. For SGD the
    result is a structure of ones shaped like the trainable weights.

    Args:
      optimizer: Optimizer whose config name selects the branch and whose
        slots are read.
      model_weights: Weights object; only its `trainable` attribute is used.

    Returns:
      A structure with the same nesting as `model_weights.trainable`.

    Raises:
      TypeError: If the optimizer's config name is not SGD, Adagrad, Adam or
        Yogi.
    """
    optimizer_name = optimizer.get_config()['name']

    if optimizer_name == 'SGD':
        # Plain SGD applies no per-coordinate scaling.
        return tf.nest.map_structure(tf.ones_like, model_weights.trainable)

    # Name of the slot that stores the second-moment style statistic.
    slot_by_optimizer = {'Adagrad': 'accumulator', 'Adam': 'v', 'Yogi': 'v'}
    if optimizer_name not in slot_by_optimizer:
        raise TypeError(
            'client optimizer should be one of these: SGD, Adagrad, Adam, Yogi.'
        )
    slot_name = slot_by_optimizer[optimizer_name]
    epsilon = optimizer.epsilon

    def _slot_of(variable):
        return optimizer.get_slot(variable, slot_name)

    def _inverse_root(slot_value):
        # divide_no_nan is used so a zero denominator does not produce inf/nan.
        return tf.math.divide_no_nan(1.0, tf.math.sqrt(slot_value) + epsilon)

    slots = tf.nest.map_structure(_slot_of, model_weights.trainable)
    return tf.nest.map_structure(_inverse_root, slots)
Пример #2
0
def _get_optimizer_momentum_beta(optimizer: tf.keras.optimizers.Optimizer):
  """Look up the momentum coefficient stored in the optimizer's config.

  Args:
    optimizer: Optimizer whose `get_config()` dictionary is inspected.

  Returns:
    The config entry `momentum` for SGD, `beta_1` for Adam, `beta1` for Yogi,
    or the constant `0.0` for Adagrad.

  Raises:
    TypeError: If the optimizer's config name is not SGD, Adagrad, Adam or
      Yogi.
  """
  optimizer_config = optimizer.get_config()
  optimizer_name = optimizer_config['name']

  if optimizer_name == 'Adagrad':
    # No config entry is consulted for Adagrad; the value is fixed.
    return 0.0

  # Config key holding the momentum value for each supported optimizer.
  beta_key_by_optimizer = {'SGD': 'momentum', 'Adam': 'beta_1', 'Yogi': 'beta1'}
  if optimizer_name not in beta_key_by_optimizer:
    raise TypeError(
        'client optimizer should be one of these: SGD, Adagrad, Adam, Yogi.')
  return optimizer_config[beta_key_by_optimizer[optimizer_name]]
Пример #3
0
def _get_optimizer_preconditioner(optimizer: tf.keras.optimizers.Optimizer,
                                  model_weights: tff.learning.ModelWeights):
  """Fetch the optimizer slot used as preconditioner for each trainable weight.

  Args:
    optimizer: Optimizer whose config name selects the slot to read.
    model_weights: Weights object; only its `trainable` attribute is used.

  Returns:
    A structure with the same nesting as `model_weights.trainable` holding
    the `accumulator` slot (Adagrad), the `v` slot (Adam, Yogi), or ones
    shaped like each weight (SGD).

  Raises:
    TypeError: If the optimizer's config name is not SGD, Adagrad, Adam or
      Yogi.
  """
  optimizer_name = optimizer.get_config()['name']

  if optimizer_name == 'SGD':
    # SGD keeps no preconditioner state, so ones are returned in its place.
    return tf.nest.map_structure(tf.ones_like, model_weights.trainable)

  slot_by_optimizer = {'Adagrad': 'accumulator', 'Adam': 'v', 'Yogi': 'v'}
  slot_name = slot_by_optimizer.get(optimizer_name)
  if slot_name is None:
    raise TypeError(
        'client optimizer should be one of these: SGD, Adagrad, Adam, Yogi.')

  def _slot_of(variable):
    return optimizer.get_slot(variable, slot_name)

  return tf.nest.map_structure(_slot_of, model_weights.trainable)
Пример #4
0
def _check_client_optimizer(optimizer: tf.keras.optimizers.Optimizer):
  """Report whether the optimizer is a supported client optimizer.

  Args:
    optimizer: Optimizer whose `get_config()` name is checked.

  Returns:
    True if the config name is contained in `SUPPORTED_CLIENT_OPTIMIZERS`,
    False otherwise.
  """
  optimizer_name = optimizer.get_config()['name']
  return optimizer_name in SUPPORTED_CLIENT_OPTIMIZERS