def _get_optimizer_preconditioner(optimizer: tf.keras.optimizers.Optimizer,
                                  model_weights: tff.learning.ModelWeights):
  """Get the preconditioner states from the optimizer.

  For the adaptive optimizers, each trainable variable's preconditioner is
  computed elementwise as `1 / (sqrt(v) + epsilon)`. Here `v` is the value
  of the optimizer slot `'accumulator'` (Adagrad) or `'v'` (Adam, Yogi), and
  `epsilon` is read from `optimizer.epsilon`. The division uses
  `tf.math.divide_no_nan`. For SGD the preconditioner is all ones.

  NOTE(review): this module's source also contains a second function named
  `_get_optimizer_preconditioner` that returns the raw slot values instead of
  their inverse square root. If both are defined at module scope, the later
  definition is the one bound at import time. Confirm which semantics callers
  expect.

  Args:
    optimizer: Client optimizer. `optimizer.get_config()['name']` selects the
      branch and must be one of 'SGD', 'Adagrad', 'Adam', 'Yogi'.
    model_weights: Weights whose `trainable` structure is mapped over.

  Returns:
    A structure matching `model_weights.trainable` holding one preconditioner
    tensor per trainable variable.

  Raises:
    TypeError: If the optimizer name is not one of the supported names.
  """
  config_name = optimizer.get_config()['name']
  if config_name == 'SGD':
    # Plain SGD applies no preconditioning.
    return tf.nest.map_structure(tf.ones_like, model_weights.trainable)

  # Slot holding each adaptive optimizer's second-moment statistic. Adagrad,
  # Adam and Yogi previously had duplicated branches differing only in this.
  second_moment_slot = {
      'Adagrad': 'accumulator',
      'Adam': 'v',
      'Yogi': 'v',
  }.get(config_name)
  if second_moment_slot is None:
    raise TypeError(
        'client optimizer should be one of these: SGD, Adagrad, Adam, Yogi.')

  eps = optimizer.epsilon

  def _inverse_sqrt_second_moment(var):
    """Computes 1 / (sqrt(slot) + eps) for one trainable variable."""
    second_moment = optimizer.get_slot(var, second_moment_slot)
    return tf.math.divide_no_nan(1.0, tf.math.sqrt(second_moment) + eps)

  return tf.nest.map_structure(_inverse_sqrt_second_moment,
                               model_weights.trainable)
def _get_optimizer_momentum_beta(optimizer: tf.keras.optimizers.Optimizer):
  """Return the momentum / first-moment coefficient configured on `optimizer`.

  The value is read from `optimizer.get_config()`:
    * SGD     -> `config['momentum']`
    * Adam    -> `config['beta_1']`
    * Yogi    -> `config['beta1']`  (note the different key spelling)
    * Adagrad -> 0.0 (no momentum term)

  Args:
    optimizer: Client optimizer whose config name is one of
      'SGD', 'Adagrad', 'Adam', 'Yogi'.

  Returns:
    The coefficient value as stored in the optimizer config (0.0 for Adagrad).

  Raises:
    TypeError: If the optimizer name is not one of the supported names.
  """
  hyperparams = optimizer.get_config()
  optimizer_name = hyperparams['name']
  if optimizer_name == 'Adagrad':
    return 0.0
  beta_key_by_name = {'SGD': 'momentum', 'Adam': 'beta_1', 'Yogi': 'beta1'}
  beta_key = beta_key_by_name.get(optimizer_name)
  if beta_key is None:
    raise TypeError(
        'client optimizer should be one of these: SGD, Adagrad, Adam, Yogi.')
  return hyperparams[beta_key]
def _get_optimizer_preconditioner(optimizer: tf.keras.optimizers.Optimizer,
                                  model_weights: tff.learning.ModelWeights):
  """Get the preconditioner states from the optimizer.

  Returns the optimizer's raw second-moment slot value for every trainable
  variable: slot `'accumulator'` for Adagrad and slot `'v'` for Adam and Yogi.
  For SGD, every entry is a tensor of ones shaped like the variable.

  NOTE(review): an earlier function in this module's source has the same
  name but returns `1 / (sqrt(v) + epsilon)` instead of the raw slots. At
  module scope this later definition replaces it. Confirm that callers
  expect the raw statistic.

  Args:
    optimizer: Client optimizer. `optimizer.get_config()['name']` must be one
      of 'SGD', 'Adagrad', 'Adam', 'Yogi'.
    model_weights: Weights whose `trainable` structure is mapped over.

  Returns:
    A structure matching `model_weights.trainable`.

  Raises:
    TypeError: If the optimizer name is not one of the supported names.
  """
  optimizer_name = optimizer.get_config()['name']
  if optimizer_name == 'SGD':
    return tf.nest.map_structure(tf.ones_like, model_weights.trainable)

  if optimizer_name == 'Adagrad':
    slot_name = 'accumulator'
  elif optimizer_name in ('Adam', 'Yogi'):
    slot_name = 'v'
  else:
    raise TypeError(
        'client optimizer should be one of these: SGD, Adagrad, Adam, Yogi.')

  def _read_slot(var):
    return optimizer.get_slot(var, slot_name)

  return tf.nest.map_structure(_read_slot, model_weights.trainable)
def _check_client_optimizer(optimizer: tf.keras.optimizers.Optimizer):
  """Return whether `optimizer`'s config name is in SUPPORTED_CLIENT_OPTIMIZERS.

  Args:
    optimizer: Optimizer whose `get_config()['name']` is looked up.

  Returns:
    True if the name is a member of `SUPPORTED_CLIENT_OPTIMIZERS`, else False.
  """
  optimizer_name = optimizer.get_config()['name']
  is_supported = optimizer_name in SUPPORTED_CLIENT_OPTIMIZERS
  return is_supported