Ejemplo n.º 1
0
  def test_serialization(self):
    """Checks that mixed precision policies survive serialize/deserialize.

    Four families of policies are exercised:
      * policies equivalent to a single dtype, which must serialize to the
        bare dtype name;
      * the internal "_infer" policy, which must serialize to None;
      * policies (including a user-defined subclass) that keep the default
        loss scale, which must serialize to a class_name/config dict holding
        only the policy name;
      * policies that explicitly pass `loss_scale`, whose serialized config
        must also carry the serialized loss scale (or None).
    In every case the deserialized policy must have the same `str()` as the
    original.
    """

    def check_round_trip(original, serialized, **deserialize_kwargs):
      # Rebuild the policy from its serialized form and compare string forms.
      rebuilt = mp_policy.deserialize(serialized, **deserialize_kwargs)
      self.assertEqual(str(original), str(rebuilt))

    # Single-dtype policies serialize to just their name.
    for dtype_name in ('float16', 'float32', 'int8', 'string', 'bool'):
      single_dtype_policy = mp_policy.Policy(dtype_name)
      serialized = mp_policy.serialize(single_dtype_policy)
      self.assertEqual(serialized, dtype_name)
      check_round_trip(single_dtype_policy, serialized)

    # The "_infer" policy serializes to None.
    infer_policy = mp_policy.Policy('_infer')
    serialized = mp_policy.serialize(infer_policy)
    self.assertIsNone(serialized)
    check_round_trip(infer_policy, serialized)

    class MyPolicy(mp_policy.Policy):
      pass

    # Lets deserialize() resolve the user-defined subclass by name.
    custom_objects = {'MyPolicy': MyPolicy}

    # Policies that leave the loss scale untouched only record their name.
    default_loss_scale_policies = (
        mp_policy.Policy('mixed_float16'),
        mp_policy.Policy('mixed_bfloat16'),
        MyPolicy('float32'),
    )
    for candidate in default_loss_scale_policies:
      serialized = mp_policy.serialize(candidate)
      expected = {
          'class_name': type(candidate).__name__,
          'config': {'name': candidate.name},
      }
      self.assertEqual(serialized, expected)
      check_round_trip(candidate, serialized, custom_objects=custom_objects)

    # Policies that pass loss_scale explicitly also record the loss scale.
    explicit_loss_scale_policies = (
        mp_policy.Policy('float32', loss_scale=2.),
        mp_policy.Policy('float32', loss_scale=None),
        mp_policy.Policy('mixed_float16', loss_scale=2.),
        mp_policy.Policy('mixed_float16', loss_scale=None),
        mp_policy.Policy('mixed_bfloat16', loss_scale=2.),
        mp_policy.Policy('mixed_bfloat16', loss_scale=None),
    )
    for candidate in explicit_loss_scale_policies:
      serialized = mp_policy.serialize(candidate)
      # Every non-None loss scale above is the fixed value 2.
      expected_loss_scale = (
          {'class_name': 'FixedLossScale', 'config': {'loss_scale_value': 2.}}
          if candidate.loss_scale else None)
      expected = {
          'class_name': type(candidate).__name__,
          'config': {
              'name': candidate.name,
              'loss_scale': expected_loss_scale,
          },
      }
      self.assertEqual(serialized, expected)
      check_round_trip(candidate, serialized, custom_objects=custom_objects)
    def _python_properties_internal(self):
        """Returns a dictionary of the wrapped layer's Python-side properties.

        The dictionary starts with the layer's class name, name, `trainable`
        flag, `_expects_training_arg`, serialized dtype policy,
        `_batch_input_shape` (None when the attribute is absent) and
        `stateful` flag. It is then updated with the mapping returned by
        `get_config(self.obj)`; keys in that mapping overwrite the ones above.
        It also gets `input_spec`, `activity_regularizer` and
        `build_input_shape` entries when those values are available.
        """
        # TODO(kathywu): Add support for metrics serialization.
        # TODO(kathywu): Synchronize with the keras spec (go/keras-json-spec) once
        # the python config serialization has caught up.
        layer = self.obj
        metadata = {
            'class_name': type(layer).__name__,
            'name': layer.name,
            'trainable': layer.trainable,
            'expects_training_arg': layer._expects_training_arg,  # pylint: disable=protected-access
            'dtype': policy.serialize(layer._dtype_policy),  # pylint: disable=protected-access
            'batch_input_shape': getattr(layer, '_batch_input_shape', None),
            'stateful': layer.stateful,
        }

        # NOTE(review): get_config is a helper defined outside this view;
        # presumably it returns a dict of the layer's config — confirm.
        metadata.update(get_config(layer))

        input_spec = layer.input_spec
        if input_spec is not None:
            # Layer's input_spec has already been type-checked in the property setter.
            def _serialize_spec(spec):
                # Falsy leaves (e.g. None) are recorded as None.
                if not spec:
                    return None
                return generic_utils.serialize_keras_object(spec)

            metadata['input_spec'] = nest.map_structure(_serialize_spec,
                                                        input_spec)

        regularizer = layer.activity_regularizer
        # Only regularizers that can describe themselves are serialized.
        if regularizer is not None and hasattr(regularizer, 'get_config'):
            metadata['activity_regularizer'] = (
                generic_utils.serialize_keras_object(regularizer))

        build_input_shape = layer._build_input_shape  # pylint: disable=protected-access
        if build_input_shape is not None:
            metadata['build_input_shape'] = build_input_shape
        return metadata
Ejemplo n.º 3
0
    def _python_properties_internal(self):
        """Returns a dictionary of the wrapped layer's Python-side properties.

        The dictionary always holds the layer's class name, name, `trainable`
        flag, `_expects_training_arg`, serialized dtype policy and
        `_batch_input_shape` (None when the attribute is absent). It may
        also hold:
          * `config` — when serialization succeeds and yields a non-None
            config;
          * `input_spec`;
          * `activity_regularizer`.
        """
        # TODO(kathywu): Add support for metrics serialization.
        # TODO(kathywu): Synchronize with the keras spec (go/keras-json-spec) once
        # the python config serialization has caught up.
        layer = self.obj
        metadata = {
            'class_name': type(layer).__name__,
            'name': layer.name,
            'trainable': layer.trainable,
            'expects_training_arg': layer._expects_training_arg,  # pylint: disable=protected-access
            'dtype': policy.serialize(layer._dtype_policy),  # pylint: disable=protected-access
            'batch_input_shape': getattr(layer, '_batch_input_shape', None),
        }

        # NOTE(review): skip_failed_serialization presumably suppresses
        # serialization errors so the remaining metadata is still produced —
        # confirm against generic_utils.
        with generic_utils.skip_failed_serialization():
            # Store the config dictionary, which may be used when reviving the object.
            # When loading, the program will attempt to revive the object from config,
            # and if that fails, the object will be revived from the SavedModel.
            serialized_layer = generic_utils.serialize_keras_object(layer)
            layer_config = serialized_layer['config']
            if layer_config is not None:
                metadata['config'] = layer_config

        input_spec = layer.input_spec
        if input_spec is not None:
            # Layer's input_spec has already been type-checked in the property setter.
            def _serialize_spec(spec):
                # Falsy leaves (e.g. None) are recorded as None.
                if not spec:
                    return None
                return generic_utils.serialize_keras_object(spec)

            metadata['input_spec'] = nest.map_structure(_serialize_spec,
                                                        input_spec)

        regularizer = layer.activity_regularizer
        # Only regularizers that can describe themselves are serialized.
        if regularizer is not None and hasattr(regularizer, 'get_config'):
            metadata['activity_regularizer'] = (
                generic_utils.serialize_keras_object(regularizer))
        return metadata
Ejemplo n.º 4
0
 def _python_properties_internal(self):
     """Returns a dictionary of the wrapped layer's Python-side properties.

     The dictionary always holds the layer's class name, name, `trainable`
     flag, `_expects_training_arg`, serialized dtype policy and
     `_batch_input_shape` (None when the attribute is absent). It may also
     hold:
       * `config` — whatever `get_config()` returns, skipped when it raises
         NotImplementedError;
       * `input_spec`;
       * `activity_regularizer`.
     """
     # TODO(kathywu): Add support for metrics serialization.
     # TODO(kathywu): Synchronize with the keras spec (go/keras-json-spec) once
     # the python config serialization has caught up.
     layer = self.obj
     metadata = {
         'class_name': type(layer).__name__,
         'name': layer.name,
         'trainable': layer.trainable,
         'expects_training_arg': layer._expects_training_arg,  # pylint: disable=protected-access
         'dtype': policy.serialize(layer._dtype_policy),  # pylint: disable=protected-access
         'batch_input_shape': getattr(layer, '_batch_input_shape', None),
     }

     # Store the config dictionary, which is only used by the revived object
     # to return the original config when revived_obj.get_config() is called.
     # It is not important for recreating the revived object.
     try:
         layer_config = layer.get_config()
     except NotImplementedError:
         # In the case of a subclassed model, the get_config() method will
         # throw a NotImplementedError; the config is simply omitted.
         pass
     else:
         metadata['config'] = layer_config

     input_spec = layer.input_spec
     if input_spec is not None:
         # Layer's input_spec has already been type-checked in the property setter.
         def _serialize_spec(spec):
             # None leaves stay None; everything else is serialized.
             if spec is None:
                 return None
             return serialize_keras_object(spec)

         metadata['input_spec'] = nest.map_structure(_serialize_spec,
                                                     input_spec)

     regularizer = layer.activity_regularizer
     # Only regularizers that can describe themselves are serialized.
     if regularizer is not None and hasattr(regularizer, 'get_config'):
         metadata['activity_regularizer'] = serialize_keras_object(regularizer)
     return metadata