def from_config(cls, config, custom_objects=None):
    """Creates an instance of `cls` from a serialized optimizer config.

    Two config layouts are accepted, told apart by the presence of a
    'loss_scale' key:

    * With 'loss_scale': the config is assumed to come from a
      LossScaleOptimizer serialized by TF 2.3 or below. The loss scale and
      the 'optimizer' entry are deserialized and passed on as-is.
    * Without 'loss_scale': the config is assumed to come from a
      LossScaleOptimizer serialized by TF 2.4 or above (as generated by
      LossScaleOptimizer.get_config). Its 'dynamic', 'initial_scale' and
      'dynamic_growth_steps' entries are folded into a single 'loss_scale'
      object and 'inner_optimizer' is renamed to 'optimizer', so the result
      can be passed to LossScaleOptimizerV1.__init__.

    Args:
      config: Dict of serialized constructor arguments. It is copied before
        being changed, so the caller's dict is left untouched.
      custom_objects: Forwarded to `optimizers.deserialize`.

    Returns:
      The result of `cls(**converted_config)`.

    Raises:
      ValueError: If a config containing 'loss_scale' deserializes to a
        DynamicLossScale whose multiplier is not 2.
    """
    kwargs = config.copy()  # Work on a copy, since entries are rewritten.

    if 'loss_scale' in kwargs:
        # Layout from TF 2.3 or below.
        legacy_scale = keras_loss_scale_module.deserialize(
            kwargs['loss_scale'])
        kwargs['loss_scale'] = legacy_scale
        is_dynamic = isinstance(
            legacy_scale, loss_scale_module.DynamicLossScale)
        if is_dynamic and legacy_scale.multiplier != 2:
            raise ValueError('Cannot deserialize LossScaleOptimizer with a '
                             'DynamicLossScale whose multiplier is not 2. Got '
                             'DynamicLossScale: %s' % (legacy_scale,))
        kwargs['optimizer'] = optimizers.deserialize(
            kwargs['optimizer'], custom_objects=custom_objects)
        return cls(**kwargs)

    # Layout from TF 2.4 or above: rebuild the loss scale object from the
    # flat dynamic/initial_scale/dynamic_growth_steps entries.
    if kwargs['dynamic']:
        rebuilt_scale = loss_scale_module.DynamicLossScale(
            kwargs['initial_scale'],
            kwargs['dynamic_growth_steps'],
            multiplier=2)
    else:
        rebuilt_scale = loss_scale_module.FixedLossScale(
            kwargs['initial_scale'])
    kwargs['loss_scale'] = rebuilt_scale

    # Drop the flat entries now that they are folded into 'loss_scale'.
    for flat_key in ('dynamic', 'initial_scale', 'dynamic_growth_steps'):
        del kwargs[flat_key]

    inner_config = kwargs.pop('inner_optimizer')
    kwargs['optimizer'] = optimizers.deserialize(
        inner_config, custom_objects=custom_objects)
    return cls(**kwargs)
def from_config(cls, config, custom_objects=None):
    """Creates an instance of `cls` from a serialized optimizer config.

    If the config contains a 'loss_scale' key it is assumed to come from a
    LossScaleOptimizer serialized by TF 2.3 or below, and it is converted to
    the current layout first: the loss scale object is replaced by the flat
    'dynamic' / 'initial_scale' / 'dynamic_growth_steps' entries and the
    'optimizer' entry is renamed to 'inner_optimizer'.

    Args:
      config: Dict of serialized constructor arguments. It is copied before
        being changed, so the caller's dict is left untouched.
      custom_objects: Forwarded to `optimizers.deserialize`.

    Returns:
      The result of `cls(**converted_config)`.

    Raises:
      ValueError: If the legacy loss scale is a DynamicLossScale whose
        multiplier is not 2, or is neither a FixedLossScale nor a
        DynamicLossScale.
    """
    kwargs = config.copy()  # Work on a copy, since entries are rewritten.

    if 'loss_scale' in kwargs:
        legacy_scale = keras_loss_scale_module.deserialize(
            kwargs.pop('loss_scale'))

        if isinstance(legacy_scale, loss_scale_module.FixedLossScale):
            kwargs['dynamic'] = False
            kwargs['initial_scale'] = legacy_scale._loss_scale_value  # pylint: disable=protected-access
        elif isinstance(legacy_scale, loss_scale_module.DynamicLossScale):
            kwargs['dynamic'] = True
            kwargs['initial_scale'] = legacy_scale.initial_loss_scale
            kwargs['dynamic_growth_steps'] = legacy_scale.increment_period
            # The flat layout has no entry for a multiplier, so only the
            # value 2 is accepted.
            if legacy_scale.multiplier != 2:
                raise ValueError(
                    'Cannot deserialize LossScaleOptimizer with a '
                    'DynamicLossScale whose multiplier is not 2. Got '
                    'DynamicLossScale: %s' % (legacy_scale,))
        else:
            raise ValueError(
                'Serialized LossScaleOptimizers with a LossScale that is '
                'neither a '
                'FixedLossScale nor a DynamicLossScale can no longer be '
                'deserialized')

        # Legacy configs keep the wrapped optimizer under 'optimizer'.
        kwargs['inner_optimizer'] = kwargs.pop('optimizer')

    serialized_inner = kwargs['inner_optimizer']
    kwargs['inner_optimizer'] = optimizers.deserialize(
        serialized_inner, custom_objects=custom_objects)
    return cls(**kwargs)
def from_config(cls, config, custom_objects=None):
    """Creates an instance of `cls` from a config dict.

    A 'loss_scale' entry that is itself a dict is treated as a serialized
    loss scale and is deserialized before construction. Any other config is
    passed to `cls` unchanged.

    Args:
      config: Dict of constructor arguments. It is only copied (never
        modified in place) when the 'loss_scale' entry must be replaced.
      custom_objects: Forwarded to `keras_loss_scale_module.deserialize`.

    Returns:
      The result of `cls(**config)`.
    """
    has_serialized_scale = (
        'loss_scale' in config and isinstance(config['loss_scale'], dict))
    if not has_serialized_scale:
        # Nothing to convert; hand the config over as-is.
        return cls(**config)

    converted = config.copy()  # Keep the caller's dict untouched.
    converted['loss_scale'] = keras_loss_scale_module.deserialize(
        converted['loss_scale'], custom_objects=custom_objects)
    return cls(**converted)
def __init__(self, optimizer, loss_scale):
    """Initializes the deprecated experimental LossScaleOptimizer.

    Translates the `loss_scale` argument into the keyword arguments of the
    parent constructor and logs a deprecation warning showing the
    equivalent call to tf.keras.mixed_precision.LossScaleOptimizer.

    Args:
      optimizer: Forwarded as the first argument of the parent constructor.
      loss_scale: One of
        * an int or float: fixed loss scale with that value;
        * a FixedLossScale: fixed loss scale with its stored value;
        * the string "dynamic": parent constructor defaults are used;
        * a DynamicLossScale: its initial loss scale and increment period
          are forwarded when they differ from the module defaults;
        * a dict: deserialized with `keras_loss_scale_module.deserialize`
          and then handled as one of the above.

    Raises:
      ValueError: If a DynamicLossScale has a multiplier other than 2, or
        if `loss_scale` is none of the accepted kinds.
      TypeError: If `loss_scale` is a LossScale that is neither a
        FixedLossScale nor a DynamicLossScale.
    """
    warn_msg_prefix = (
        'tf.keras.mixed_precision.experimental.LossScaleOptimizer is '
        'deprecated. Please use tf.keras.mixed_precision.LossScaleOptimizer '
        'instead. ')
    # Shared example text for both fixed-loss-scale spellings.
    fixed_example = ('For example:\n'
                     ' opt = tf.keras.mixed_precision.LossScaleOptimizer('
                     'opt, dynamic=False, initial_scale={})')

    if isinstance(loss_scale, dict):
        loss_scale = keras_loss_scale_module.deserialize(loss_scale)

    # Plain number: fixed loss scale.
    if isinstance(loss_scale, (int, float)):
        tf_logging.warn(warn_msg_prefix + fixed_example.format(loss_scale))
        super(LossScaleOptimizerV1, self).__init__(
            optimizer, dynamic=False, initial_scale=loss_scale)
        return

    # FixedLossScale object: unwrap its value.
    if isinstance(loss_scale, loss_scale_module.FixedLossScale):
        fixed_value = loss_scale._loss_scale_value  # pylint: disable=protected-access
        tf_logging.warn(warn_msg_prefix + fixed_example.format(fixed_value))
        super(LossScaleOptimizerV1, self).__init__(
            optimizer, dynamic=False, initial_scale=fixed_value)
        return

    # The string "dynamic": rely on the parent constructor's defaults.
    if loss_scale == 'dynamic':
        tf_logging.warn(
            warn_msg_prefix + 'For example:\n'
            ' opt = tf.keras.mixed_precision.LossScaleOptimizer('
            'opt)')
        super(LossScaleOptimizerV1, self).__init__(optimizer)
        return

    # DynamicLossScale object: forward only the non-default settings.
    if isinstance(loss_scale, loss_scale_module.DynamicLossScale):
        parent_kwargs = {}
        example_parts = []
        if loss_scale.initial_loss_scale != _DEFAULT_INITIAL_SCALE:
            parent_kwargs['initial_scale'] = loss_scale.initial_loss_scale
            example_parts.append(
                ', initial_scale=%s' % loss_scale.initial_loss_scale)
        if loss_scale.increment_period != _DEFAULT_GROWTH_STEPS:
            parent_kwargs['dynamic_growth_steps'] = (
                loss_scale.increment_period)
            example_parts.append(
                ', dynamic_growth_steps=%s' % loss_scale.increment_period)
        if loss_scale.multiplier != 2:
            raise ValueError(
                'When passing a DynamicLossScale to "loss_scale", '
                'DynamicLossScale.multiplier must be 2. Got: %s'
                % (loss_scale,))
        dynamic_example = (
            'Note that the non-experimental LossScaleOptimizer does not '
            'take a '
            'DynamicLossScale but instead takes the dynamic configuration '
            'directly in the constructor. For example:\n'
            ' opt = tf.keras.mixed_precision.LossScaleOptimizer('
            'opt{})\n').format(''.join(example_parts))
        tf_logging.warn(warn_msg_prefix + dynamic_example)
        super(LossScaleOptimizerV1, self).__init__(optimizer, **parent_kwargs)
        return

    # Any other LossScale subclass is rejected with a TypeError.
    if isinstance(loss_scale, loss_scale_module.LossScale):
        raise TypeError(
            'Passing a LossScale that is not a FixedLossScale or a '
            'DynamicLossScale is no longer supported. Got: {}'
            .format(loss_scale))

    raise ValueError(
        'Invalid value passed to loss_scale. loss_scale '
        'must be the string "dynamic" (recommended), an int, '
        'a float, a FixedLossScale, or a DynamicLossScale. Got '
        'value: {}'.format(loss_scale))