示例#1
0
    def test_encode_decode_type_spec(self):
        """Round-trips a TensorSpec; rejects an unregistered TypeSpec."""
        spec = tf.TensorSpec((1, 5), tf.float32)
        string = json_utils.Encoder().encode(spec)
        loaded = json_utils.decode(string)
        self.assertEqual(spec, loaded)

        # A payload tagged as a TypeSpec whose type name is unknown is
        # expected to make the decoder raise.
        invalid_type_spec = {
            "class_name": "TypeSpec",
            "type_spec": "Invalid Type",
            "serialized": None,
        }
        string = json_utils.Encoder().encode(invalid_type_spec)
        # `assertRaisesRegexp` is a deprecated alias that was removed in
        # Python 3.12; `assertRaisesRegex` is the supported spelling.
        with self.assertRaisesRegex(
            ValueError, "No TypeSpec has been registered"
        ):
            json_utils.decode(string)
示例#2
0
    def test_encode_decode_type_spec(self):
        """Round-trips a TensorSpec; rejects an unregistered TypeSpec."""
        spec = tf.TensorSpec((1, 5), tf.float32)
        string = json_utils.Encoder().encode(spec)
        loaded = json_utils.decode(string)
        self.assertEqual(spec, loaded)

        # A payload tagged as a TypeSpec whose type name is unknown is
        # expected to make the decoder raise.
        invalid_type_spec = {
            'class_name': 'TypeSpec',
            'type_spec': 'Invalid Type',
            'serialized': None
        }
        string = json_utils.Encoder().encode(invalid_type_spec)
        # `assertRaisesRegexp` is a deprecated alias that was removed in
        # Python 3.12; `assertRaisesRegex` is the supported spelling.
        with self.assertRaisesRegex(ValueError,
                                    'No TypeSpec has been registered'):
            json_utils.decode(string)
示例#3
0
    def test_saved_module_paths_and_class_names(self):
        """Checks the module / class-name entries of a saved model config."""
        save_dir = os.path.join(self.get_temp_dir(), "my_model")
        model = self._get_subclassed_model()
        features = np.random.random((100, 32))
        targets = np.random.random((100, 1))
        model.fit(features, targets, epochs=1)
        model._save_new(save_dir)

        # Read back the config file written by the save call above.
        config_path = os.path.join(save_dir, saving_lib._CONFIG_FILE)
        with tf.io.gfile.GFile(config_path, "r") as config_file:
            raw_config = config_file.read()
        decoded = json_utils.decode(raw_config)

        self.assertEqual(
            decoded["registered_name"], "my_custom_package>CustomModelX"
        )

        optimizer_entry = decoded["config"]["optimizer"]
        self.assertIsNone(optimizer_entry["module"])
        self.assertEqual(
            optimizer_entry["class_name"], "keras.optimizers.Adam"
        )

        loss_entry = decoded["config"]["loss"]
        self.assertEqual(loss_entry["module"], "keras.engine.compile_utils")
        self.assertEqual(loss_entry["class_name"], "LossesContainer")
示例#4
0
def load(dirpath):
    """Load a saved python model.

    Reads the `_CONFIG_FILE` entry under `dirpath`, decodes its JSON text and
    hands the decoded config to `deserialize_keras_object`.

    Args:
        dirpath: Directory containing the saved config file.

    Returns:
        Whatever `deserialize_keras_object` returns for the decoded config.
    """
    config_path = os.path.join(dirpath, _CONFIG_FILE)
    with tf.io.gfile.GFile(config_path, "r") as config_file:
        serialized = config_file.read()
    decoded = json_utils.decode(serialized)
    return deserialize_keras_object(decoded)
示例#5
0
  def _load_layer(self, node_id, identifier, metadata):
    """Load a single layer from a SavedUserObject proto.

    Args:
      node_id: Key used to look the node up in `self.loaded_nodes`.
      identifier: Object identifier forwarded to the revive helpers.
      metadata: JSON-encoded metadata string for the node.

    Returns:
      An `(object, setter)` pair.
    """
    decoded = json_utils.decode(metadata)

    if node_id not in self.loaded_nodes:
      # First try to revive the object from its config; when that yields
      # nothing, revive it from the SavedModel instead.
      revived, revived_setter = self._revive_from_config(
          identifier, decoded, node_id)
      if revived is None:
        revived, revived_setter = revive_custom_object(identifier, decoded)

      # Store the extra functions/objects saved in the SavedModel on the
      # object. Most are ignored, but some are used later in the loading
      # process (e.g. the list of regularization losses, or the training
      # config of compiled models).
      _maybe_add_serialized_attributes(revived, decoded)
      return revived, revived_setter

    # The node was already created.
    existing, existing_setter = self.loaded_nodes[node_id]

    # The revive setter requires the object to have a
    # `_serialized_attributes` property, so add it here.
    _maybe_add_serialized_attributes(existing, decoded)

    layer_config = decoded.get('config')
    if _is_graph_network(existing) and generic_utils.validate_config(
        layer_config):
      dependent_ids = self._get_child_layer_node_ids(node_id)
      self.model_layer_dependencies[node_id] = (existing, dependent_ids)
      if not dependent_ids:
        self._models_to_reconstruct.append(node_id)
    return existing, existing_setter
示例#6
0
    def test_encode_decode_tuple(self):
        """Nested tuples compare equal after an encode/decode round trip."""
        original = {'key1': (3, 5), 'key2': [(1, (3, 4)), (1, )]}
        encoded = json_utils.Encoder().encode(original)
        decoded = json_utils.decode(encoded)

        self.assertEqual(set(decoded.keys()), {'key1', 'key2'})
        expectations = (
            ('key1', (3, 5)),
            ('key2', [(1, (3, 4)), (1, )]),
        )
        for key, expected in expectations:
            self.assertAllEqual(decoded[key], expected)
示例#7
0
    def test_encode_decode_tuple(self):
        """Nested tuples compare equal after an encode/decode round trip."""
        payload = {"key1": (3, 5), "key2": [(1, (3, 4)), (1, )]}
        round_tripped = json_utils.decode(json_utils.Encoder().encode(payload))

        self.assertEqual(set(round_tripped.keys()), {"key1", "key2"})
        first_value = round_tripped["key1"]
        self.assertAllEqual(first_value, (3, 5))
        second_value = round_tripped["key2"]
        self.assertAllEqual(second_value, [(1, (3, 4)), (1, )])
示例#8
0
    def test_encode_decode_enum(self):
        """Enum members decode back to their underlying values."""

        class Enum(enum.Enum):
            CLASS_A = "a"
            CLASS_B = "b"

        members = {"key": Enum.CLASS_A, "key2": Enum.CLASS_B}
        encoded = json_utils.Encoder().encode(members)
        decoded = json_utils.decode(encoded)

        expected_values = {"key": "a", "key2": "b"}
        self.assertAllEqual(expected_values, decoded)
示例#9
0
    def test_encode_decode_enum(self):
        """Enum members decode back to their underlying values."""

        class Enum(enum.Enum):
            CLASS_A = 'a'
            CLASS_B = 'b'

        enum_config = {'key': Enum.CLASS_A, 'key2': Enum.CLASS_B}
        round_tripped = json_utils.decode(
            json_utils.Encoder().encode(enum_config))
        self.assertAllEqual({'key': 'a', 'key2': 'b'}, round_tripped)
示例#10
0
    def test_encode_decode_extension_type_tensor(self):
        """An ExtensionType value compares equal after a round trip."""

        class MaskedTensor(tf.experimental.ExtensionType):
            __name__ = 'MaskedTensor'
            values: tf.Tensor
            mask: tf.Tensor

        masked = MaskedTensor(
            values=[[1, 2, 3], [4, 5, 6]],
            mask=[[True, True, False], [True, False, True]])
        encoded = json_utils.Encoder().encode(masked)
        decoded = json_utils.decode(encoded)
        self.assertAllEqual(decoded, masked)
示例#11
0
    def test_encode_decode_tensor_shape(self):
        """TensorShapes (unknown rank / partial dims) survive a round trip."""
        shapes = {
            "key1": tf.TensorShape(None),
            "key2": [tf.TensorShape([None]), tf.TensorShape([3, None, 5])],
        }
        encoded = json_utils.Encoder().encode(shapes)
        decoded = json_utils.decode(encoded)

        self.assertEqual(set(decoded.keys()), {"key1", "key2"})
        # Unknown-rank shape keeps rank None.
        self.assertAllEqual(decoded["key1"].rank, None)
        # Partially-known shapes keep their None dimensions.
        self.assertAllEqual(decoded["key2"][0].as_list(), [None])
        self.assertAllEqual(decoded["key2"][1].as_list(), [3, None, 5])
示例#12
0
def _update_to_current_version(metadata):
  """Applies version updates to the metadata proto for backwards compat.

  For every node with producer version 1 whose identifier is a model,
  sequential or network identifier, a `save_spec` entry in the node's JSON
  metadata is wrapped as `([save_spec], {})` and stored under
  `full_save_spec`; the node's metadata is then re-encoded in place.

  Args:
    metadata: Metadata proto whose `nodes` are inspected and possibly
      rewritten.

  Returns:
    The same `metadata` object that was passed in.
  """
  for node in metadata.nodes:
    if node.version.producer != 1:
      continue
    if node.identifier not in [
        constants.MODEL_IDENTIFIER, constants.SEQUENTIAL_IDENTIFIER,
        constants.NETWORK_IDENTIFIER]:
      continue

    decoded = json_utils.decode(node.metadata)
    legacy_spec = decoded.get('save_spec')
    if legacy_spec is None:
      continue

    decoded['full_save_spec'] = ([legacy_spec], {})
    node.metadata = json_utils.Encoder().encode(decoded)
  return metadata
示例#13
0
    def test_encode_decode_tensor_shape(self):
        """TensorShapes (unknown rank / partial dims) survive a round trip."""
        shape_metadata = {
            'key1': tf.TensorShape(None),
            'key2': [tf.TensorShape([None]),
                     tf.TensorShape([3, None, 5])]
        }
        round_tripped = json_utils.decode(
            json_utils.Encoder().encode(shape_metadata))

        self.assertEqual(set(round_tripped.keys()), {'key1', 'key2'})
        # Unknown-rank shape keeps rank None.
        self.assertAllEqual(round_tripped['key1'].rank, None)
        # Partially-known shapes keep their None dimensions.
        self.assertAllEqual(round_tripped['key2'][0].as_list(), [None])
        self.assertAllEqual(round_tripped['key2'][1].as_list(), [3, None, 5])
示例#14
0
  def testAddFullSaveSpec(self):
    """A producer-1 model node gains `full_save_spec` when updated."""
    legacy_spec = tf.TensorSpec([3, 5], dtype=tf.int32)
    encoded_metadata = json_utils.Encoder().encode({'save_spec': legacy_spec})

    # Build a metadata proto holding a single version-1 Keras model node.
    proto = saved_metadata_pb2.SavedMetadata()
    node_version = versions_pb2.VersionDef(
        producer=1, min_consumer=1, bad_consumers=[])
    proto.nodes.add(
        version=node_version,
        identifier='_tf_keras_model',
        metadata=encoded_metadata)

    updated = keras_load._update_to_current_version(proto)  # pylint: disable=protected-access
    decoded = json_utils.decode(updated.nodes[0].metadata)

    expected = ([tf.TensorSpec(shape=(3, 5), dtype=tf.int32)], {})
    self.assertAllEqual(expected, decoded.get('full_save_spec'))
示例#15
0
  def _reconstruct_model(self, model_id, model, layers):
    """Reconstructs the network structure.

    Re-runs `model.__init__` with the given layers, using the `config` entry
    of the JSON metadata stored on the model's saved node.

    Args:
      model_id: Key into `self._proto.nodes` for the model's saved node.
      model: Model object to re-initialize in place. It is handled as a
        Sequential model when it is a `models_lib.Sequential` instance and
        has no inputs yet; otherwise, when it has no inputs, as a functional
        model.
      layers: List of layer objects for the model. For Sequential models an
        `InputLayer` may be inserted at index 0 of this list.
    """
    # Only the `config` entry of the node metadata is needed here.
    config = json_utils.decode(
        self._proto.nodes[model_id].user_object.metadata)['config']

    # Set up model inputs
    if model.inputs:
      # Inputs may already be created if the model is instantiated in another
      # object's __init__.
      pass
    elif isinstance(model, models_lib.Sequential):
      # Make sure the layer list starts with an InputLayer when the config
      # provides enough information to create one.
      if not layers or not isinstance(layers[0], input_layer.InputLayer):
        if config['layers'][0]['class_name'] == 'InputLayer':
          layers.insert(0, input_layer.InputLayer.from_config(
              config['layers'][0]['config']))
        elif 'batch_input_shape' in config['layers'][0]['config']:
          # Element 0 of `batch_input_shape` is used as the batch size and
          # the remaining elements as the input shape.
          batch_input_shape = config['layers'][0]['config']['batch_input_shape']
          layers.insert(0, input_layer.InputLayer(
              input_shape=batch_input_shape[1:],
              batch_size=batch_input_shape[0],
              dtype=layers[0].dtype,
              name=layers[0].name + '_input'))
      model.__init__(layers, name=config['name'])
      if not model.inputs:
        # Still no inputs after re-initialization: infer them from the first
        # child layer node of this model.
        first_layer = self._get_child_layer_node_ids(model_id)[0]
        input_specs = self._infer_inputs(first_layer)
        input_shapes = self._infer_inputs(first_layer, convert_to_shapes=True)
        model._set_inputs(input_specs)  # pylint: disable=protected-access
        # Building is skipped when the inferred specs are a dict.
        if not model.built and not isinstance(input_specs, dict):
          model.build(input_shapes)
    else:  # Reconstruct functional model
      # The already-created layers are passed in keyed by layer name.
      (inputs, outputs,
       created_layers) = functional_lib.reconstruct_from_config(
           config, created_layers={layer.name: layer for layer in layers})
      model.__init__(inputs, outputs, name=config['name'])
      functional_lib.connect_ancillary_layers(model, created_layers)

    # Set model dtype.
    _set_network_attributes_from_metadata(model)

    # Unblock models that are dependent on this model.
    self._unblock_model_reconstruction(model_id, model)
示例#16
0
    def test_saved_module_paths_and_class_names(self):
        """Checks the module / class-name entries of a saved model config."""
        save_dir = os.path.join(self.get_temp_dir(), 'my_model')
        model = self._get_subclassed_model()
        features = np.random.random((100, 32))
        targets = np.random.random((100, 1))
        model.fit(features, targets, epochs=1)
        model._save_new(save_dir)

        # Read back the config file written by the save call above.
        config_path = os.path.join(save_dir, saving_lib._CONFIG_FILE)
        with tf.io.gfile.GFile(config_path, 'r') as config_file:
            raw_config = config_file.read()
        decoded = json_utils.decode(raw_config)

        self.assertEqual(decoded['registered_name'],
                         'my_custom_package>CustomModelX')

        optimizer_entry = decoded['config']['optimizer']
        self.assertIsNone(optimizer_entry['module'])
        self.assertEqual(optimizer_entry['class_name'],
                         'keras.optimizers.Adam')

        loss_entry = decoded['config']['loss']
        self.assertEqual(loss_entry['module'], 'keras.engine.compile_utils')
        self.assertEqual(loss_entry['class_name'], 'LossesContainer')
示例#17
0
def model_from_json(json_string, custom_objects=None):
    """Parses a JSON model configuration string and returns a model instance.

    Usage:

    >>> model = tf.keras.Sequential([
    ...     tf.keras.layers.Dense(5, input_shape=(3,)),
    ...     tf.keras.layers.Softmax()])
    >>> config = model.to_json()
    >>> loaded_model = tf.keras.models.model_from_json(config)

    Args:
        json_string: JSON string encoding a model configuration.
        custom_objects: Optional dictionary mapping names
            (strings) to custom classes or functions to be
            considered during deserialization.

    Returns:
        A Keras model instance (uncompiled).
    """
    model_config = json_utils.decode(json_string)
    # Imported at call time, after decoding, rather than at module import.
    from keras.layers import deserialize as deserialize_layer  # pylint: disable=g-import-not-at-top
    return deserialize_layer(model_config, custom_objects=custom_objects)
示例#18
0
  def _add_children_recreated_from_config(self, obj, proto, node_id):
    """Recursively records objects recreated from config.

    Walks the children of `obj` (direct dependencies plus any layer metrics
    found under the Keras attribute node), and for every trackable child not
    yet present in `self.loaded_nodes` records its path in
    `self._node_paths`, recurses into it, and stores `(child, setter)` in
    `self.loaded_nodes`.

    Args:
      obj: Object recreated from config whose children are recorded.
      proto: Saved node proto for `obj`; its `children` and
        `user_object.metadata` fields are read.
      node_id: Node id of `obj`; nodes already in
        `self._traversed_nodes_from_config` are skipped.
    """
    # pylint: disable=protected-access
    if node_id in self._traversed_nodes_from_config:
      return

    parent_path = self._node_paths[node_id]
    self._traversed_nodes_from_config.add(node_id)
    obj._maybe_initialize_trackable()
    if isinstance(obj, base_layer.Layer) and not obj.built:
      # Unbuilt layers: hand the saved `build_input_shape` (None when the
      # metadata has no such key) to `_try_build_layer`.
      metadata = json_utils.decode(proto.user_object.metadata)
      self._try_build_layer(obj, node_id, metadata.get('build_input_shape'))

    # Create list of all possible children
    # Each entry is (child object, child node id, child name/path suffix).
    children = []
    # Look for direct children
    for reference in proto.children:
      obj_child = obj._lookup_dependency(reference.local_name)
      children.append((obj_child, reference.node_id, reference.local_name))

    # Add metrics that may have been added to the layer._metrics list.
    # This is stored in the SavedModel as layer.keras_api.layer_metrics in
    # SavedModels created after Tf 2.2.
    metric_list_node_id = self._search_for_child_node(
        node_id, [constants.KERAS_ATTR, 'layer_metrics'])
    if metric_list_node_id is not None and hasattr(obj, '_metrics'):
      # Match saved metric references to live metric objects by name.
      obj_metrics = {m.name: m for m in obj._metrics}
      for reference in self._proto.nodes[metric_list_node_id].children:
        metric = obj_metrics.get(reference.local_name)
        if metric is not None:
          metric_path = '{}.layer_metrics.{}'.format(constants.KERAS_ATTR,
                                                     reference.local_name)
          children.append((metric, reference.node_id, metric_path))

    for (obj_child, child_id, child_name) in children:
      child_proto = self._proto.nodes[child_id]

      # Children that are not trackable objects are not recorded.
      if not isinstance(obj_child, tf.__internal__.tracking.Trackable):
        continue
      # Pick the setter: a registered revived-type setter, the Keras revive
      # setter for Keras object identifiers, or plain `setattr` otherwise.
      if (child_proto.user_object.identifier in
          revived_types.registered_identifiers()):
        setter = revived_types.get_setter(child_proto.user_object)
      elif obj_child._object_identifier in constants.KERAS_OBJECT_IDENTIFIERS:
        setter = _revive_setter
      else:
        setter = setattr
        # pylint: enable=protected-access

      # Already-loaded nodes are never re-recorded; only warn when the
      # recorded object differs from this child.
      if child_id in self.loaded_nodes:
        if self.loaded_nodes[child_id][0] is not obj_child:
          # This means that the same trackable object is referenced by two
          # different objects that were recreated from the config.
          logging.warn('Looks like there is an object (perhaps variable or '
                       'layer) that is shared between different layers/models. '
                       'This may cause issues when restoring the variable '
                       'values. Object: {}'.format(obj_child))
        continue

      # Overwrite variable names with the ones saved in the SavedModel.
      if (child_proto.WhichOneof('kind') == 'variable' and
          child_proto.variable.name):
        obj_child._handle_name = child_proto.variable.name + ':0'  # pylint: disable=protected-access

      # Trackable data structures get a no-op setter.
      if isinstance(obj_child, tf.__internal__.tracking.TrackableDataStructure):
        setter = lambda *args: None

      child_path = '{}.{}'.format(parent_path, child_name)
      self._node_paths[child_id] = child_path
      # Recurse into the child before recording it as loaded.
      self._add_children_recreated_from_config(
          obj_child, child_proto, child_id)
      self.loaded_nodes[child_id] = obj_child, setter
示例#19
0
 def test_encode_decode_ragged_tensor(self):
     """A ragged tensor compares equal after an encode/decode round trip."""
     ragged = tf.ragged.constant([[1.0, 2.0], [3.0]])
     encoded = json_utils.Encoder().encode(ragged)
     decoded = json_utils.decode(encoded)
     self.assertAllEqual(decoded, ragged)
示例#20
0
def load_model_from_hdf5(filepath, custom_objects=None, compile=True):  # pylint: disable=redefined-builtin
  """Loads a model saved via `save_model_to_hdf5`.

  Args:
      filepath: One of the following:
          - String, path to the saved model
          - `h5py.File` object from which to load the model
      custom_objects: Optional dictionary mapping names
          (strings) to custom classes or functions to be
          considered during deserialization.
      compile: Boolean, whether to compile the model
          after loading.

  Returns:
      A Keras model instance. If an optimizer was found
      as part of the saved model, the model is already
      compiled. Otherwise, the model is uncompiled and
      a warning will be displayed. When `compile` is set
      to False, the compilation is omitted without any
      warning.

  Raises:
      ImportError: if h5py is not available.
      ValueError: In case of an invalid savefile.
  """
  if h5py is None:
    raise ImportError('`load_model()` using h5 format requires h5py. Could not '
                      'import h5py.')

  if not custom_objects:
    custom_objects = {}

  # Only close the file in `finally` if it was opened here; a caller-supplied
  # `h5py.File` is left open.
  opened_new_file = not isinstance(filepath, h5py.File)
  if opened_new_file:
    f = h5py.File(filepath, mode='r')
  else:
    f = filepath

  model = None
  try:
    # instantiate model
    model_config = f.attrs.get('model_config')
    if model_config is None:
      raise ValueError(f'No model config found in the file at {filepath}.')
    # The attribute may be stored as bytes; normalize to str before decoding.
    if hasattr(model_config, 'decode'):
      model_config = model_config.decode('utf-8')
    model_config = json_utils.decode(model_config)
    model = model_config_lib.model_from_config(model_config,
                                               custom_objects=custom_objects)

    # set weights
    load_weights_from_hdf5_group(f['model_weights'], model)

    if compile:
      # instantiate optimizer
      training_config = f.attrs.get('training_config')
      if hasattr(training_config, 'decode'):
        training_config = training_config.decode('utf-8')
      if training_config is None:
        logging.warning('No training configuration found in the save file, so '
                        'the model was *not* compiled. Compile it manually.')
        return model
      training_config = json_utils.decode(training_config)

      # Compile model.
      model.compile(**saving_utils.compile_args_from_training_config(
          training_config, custom_objects), from_serialized=True)
      saving_utils.try_build_compiled_arguments(model)

      # Set optimizer weights.
      if 'optimizer_weights' in f:
        try:
          model.optimizer._create_all_weights(model.trainable_variables)
        except (NotImplementedError, AttributeError):
          # The message has a `{}` placeholder for the optimizer; it must be
          # formatted, otherwise the log shows a literal `{}`.
          logging.warning(
              'Error when creating the weights of optimizer {}, making it '
              'impossible to restore the saved optimizer state. As a result, '
              'your model is starting with a freshly initialized optimizer.'
              .format(model.optimizer))

        optimizer_weight_values = load_optimizer_weights_from_hdf5_group(f)
        try:
          model.optimizer.set_weights(optimizer_weight_values)
        except ValueError:
          logging.warning('Error in loading the saved optimizer '
                          'state. As a result, your model is '
                          'starting with a freshly initialized '
                          'optimizer.')
  finally:
    if opened_new_file:
      f.close()
  return model