Exemplo n.º 1
0
    def test_encode_decode_type_spec(self):
        """Round-trips a TensorSpec and rejects an unregistered TypeSpec name."""
        spec = tensor_spec.TensorSpec((1, 5), dtypes.float32)
        string = json_utils.Encoder().encode(spec)
        loaded = json_utils.decode(string)
        self.assertEqual(spec, loaded)

        # Hand-built payload naming a TypeSpec class that was never
        # registered; decoding it is expected to raise ValueError.
        invalid_type_spec = {
            'class_name': 'TypeSpec',
            'type_spec': 'Invalid Type',
            'serialized': None
        }
        string = json_utils.Encoder().encode(invalid_type_spec)
        # `assertRaisesRegexp` is a deprecated alias that was removed in
        # Python 3.12; `assertRaisesRegex` is the supported name.
        with self.assertRaisesRegex(ValueError,
                                    'No TypeSpec has been registered'):
            json_utils.decode(string)
Exemplo n.º 2
0
  def _load_layer(self, node_id, identifier, metadata):
    """Returns the `(object, setter)` pair for a single layer node.

    `metadata` is the serialized SavedUserObject metadata; it is decoded with
    `json_utils.decode` before use.

    If `node_id` already has an entry in `self.loaded_nodes`, that entry is
    reused: serialized attributes are attached to it and, for graph networks
    whose config validates, its child layer ids are recorded in
    `self.model_layer_dependencies` (and the node is queued in
    `self._models_to_reconstruct` when it has no children).

    Otherwise the layer is revived from its config, falling back to
    `revive_custom_object` when config revival yields None.
    """
    decoded_metadata = json_utils.decode(metadata)

    if node_id in self.loaded_nodes:
      existing_node, existing_setter = self.loaded_nodes[node_id]

      # The revive setter relies on a `_serialized_attributes` property being
      # present on the object, so make sure it exists.
      _maybe_add_serialized_attributes(existing_node, decoded_metadata)

      layer_config = decoded_metadata.get('config')
      rebuild_from_children = (
          _is_graph_network(existing_node) and
          generic_utils.validate_config(layer_config))
      if rebuild_from_children:
        child_ids = self._get_child_layer_node_ids(node_id)
        self.model_layer_dependencies[node_id] = (existing_node, child_ids)
        if not child_ids:
          self._models_to_reconstruct.append(node_id)
      return existing_node, existing_setter

    # Prefer reviving from the config; when that is not possible (None), fall
    # back to reviving the custom object from the SavedModel itself.
    revived, revived_setter = self._revive_from_config(
        identifier, decoded_metadata, node_id)
    if revived is None:
      revived, revived_setter = revive_custom_object(
          identifier, decoded_metadata)

    # Attach the extra functions/objects saved alongside the layer; most are
    # unused, but some are consulted later during loading (e.g. the
    # regularization losses or the training config of compiled models).
    _maybe_add_serialized_attributes(revived, decoded_metadata)
    return revived, revived_setter
Exemplo n.º 3
0
  def _reconstruct_model(self, model_id, model, layers):
    """Re-initializes `model` in place from its saved config and child layers.

    Args:
      model_id: Node id of the model in `self._proto.nodes`; its user_object
        metadata is decoded to obtain the model config.
      model: The revived model object; its `__init__` is re-run here.
      layers: The model's already-loaded child layers. For Sequential models
        an InputLayer may be inserted at index 0 (the list is mutated).
    """
    config = json_utils.decode(
        self._proto.nodes[model_id].user_object.metadata)['config']
    if isinstance(model, models_lib.Sequential):
      # `layers` can be empty (e.g. when the config only describes an
      # InputLayer), so test it before indexing `layers[0]`.
      if not layers or not isinstance(layers[0], input_layer.InputLayer):
        config_layers = config['layers']
        if config_layers:
          first_layer_config = config_layers[0]
          if first_layer_config['class_name'] == 'InputLayer':
            layers.insert(0, input_layer.InputLayer.from_config(
                first_layer_config['config']))
          elif layers and 'batch_input_shape' in first_layer_config['config']:
            # Synthesize an InputLayer from the first layer's saved batch
            # input shape, reusing that layer's dtype and name.
            batch_input_shape = first_layer_config['config'][
                'batch_input_shape']
            layers.insert(0, input_layer.InputLayer(
                input_shape=batch_input_shape[1:],
                batch_size=batch_input_shape[0],
                dtype=layers[0].dtype,
                name=layers[0].name + '_input'))
      model.__init__(layers, name=config['name'])
      if not model.inputs:
        # The model still has no inputs: infer them from the first child
        # layer node, if there is one.
        child_layer_ids = self._get_child_layer_node_ids(model_id, model.name)
        if child_layer_ids:
          first_layer = child_layer_ids[0]
          input_specs = self._infer_inputs(first_layer)
          input_shapes = self._infer_inputs(first_layer, convert_to_shapes=True)
          model._set_inputs(input_specs)  # pylint: disable=protected-access
          if not model.built and not isinstance(input_specs, dict):
            model.build(input_shapes)
    else:
      (inputs, outputs,
       created_layers) = functional_lib.reconstruct_from_config(
           config, created_layers={layer.name: layer for layer in layers})
      model.__init__(inputs, outputs, name=config['name'])
      functional_lib.connect_ancillary_layers(model, created_layers)

    # Set model dtype and trainable status.
    _set_network_attributes_from_metadata(model)

    # Unblock models that are dependent on this model.
    self._unblock_model_reconstruction(model_id, model)
Exemplo n.º 4
0
    def _add_children_recreated_from_config(self, obj, proto, node_id):
        """Recursively records objects recreated from config.

        Collects the children of `obj` (those listed in `proto.children`,
        plus any of `obj._metrics` found under the `layer_metrics` child
        node). For every child that is a `trackable.Trackable`, stores
        `(child, setter)` in `self._nodes_recreated_from_config` and recurses
        into it. Node ids already in `self._traversed_nodes_from_config` are
        skipped.

        Args:
            obj: Object recreated from config whose children are recorded.
            proto: SavedObject proto for `obj`; `children` and
                `user_object.metadata` are read.
            node_id: Node id of `obj` in `self._proto.nodes`.
        """
        # pylint: disable=protected-access
        if node_id in self._traversed_nodes_from_config:
            return
        self._traversed_nodes_from_config.append(node_id)
        obj._maybe_initialize_trackable()
        if isinstance(obj, base_layer.Layer) and not obj.built:
            # Build the layer with the `build_input_shape` recorded in its
            # saved metadata (None when the key is absent).
            metadata = json_utils.decode(proto.user_object.metadata)
            self._try_build_layer(obj, node_id,
                                  metadata.get('build_input_shape'))

        # Create list of all possible (child object, child node id) pairs.
        children = []
        # Look for direct children
        for reference in proto.children:
            obj_child = obj._lookup_dependency(reference.local_name)
            children.append((obj_child, reference.node_id))

        # Add metrics that may have been added to the layer._metrics list.
        # This is stored in the SavedModel as layer.keras_api.layer_metrics in
        # SavedModels created after Tf 2.2.
        metric_list_node_id = self._search_for_child_node(
            node_id, [constants.KERAS_ATTR, 'layer_metrics'],
            raise_error=False)
        if metric_list_node_id is not None and hasattr(obj, '_metrics'):
            obj_metrics = {m.name: m for m in obj._metrics}
            for reference in self._proto.nodes[metric_list_node_id].children:
                metric = obj_metrics.get(reference.local_name)
                if metric is not None:
                    children.append((metric, reference.node_id))

        for (obj_child, child_id) in children:
            child_proto = self._proto.nodes[child_id]

            if not isinstance(obj_child, trackable.Trackable):
                continue
            if (child_proto.user_object.identifier
                    in revived_types.registered_identifiers()):
                setter = revived_types.get_setter(child_proto.user_object)
            elif obj_child._object_identifier in KERAS_OBJECT_IDENTIFIERS:
                setter = _revive_setter
            else:
                setter = setattr

            if (child_id in self._nodes_recreated_from_config
                    and self._nodes_recreated_from_config[child_id][0]
                    is not obj_child):
                # This means that the same trackable object is referenced by two
                # different objects that were recreated from the config.
                logging.warning(
                    'Looks like there is an object (perhaps variable or layer)'
                    ' that is shared between different layers/models. This '
                    'may cause issues when restoring the variable values. '
                    'Object: %s', obj_child)
            self._nodes_recreated_from_config[child_id] = (
                obj_child, self._config_node_setter(setter))
            self._add_children_recreated_from_config(obj_child, child_proto,
                                                     child_id)
        # pylint: enable=protected-access
Exemplo n.º 5
0
    def test_encode_decode_tuple(self):
        """Tuples, including nested ones, survive an encode/decode round trip."""
        metadata = {'key1': (3, 5), 'key2': [(1, (3, 4)), (1, )]}
        string = json_utils.Encoder().encode(metadata)
        loaded = json_utils.decode(string)

        self.assertEqual(set(loaded.keys()), {'key1', 'key2'})
        # Plain equality is used on purpose: it distinguishes tuples from
        # lists (the property under test), and it does not try to pack the
        # ragged 'key2' value into a NumPy array the way assertAllEqual does.
        self.assertEqual(loaded['key1'], (3, 5))
        self.assertEqual(loaded['key2'], [(1, (3, 4)), (1, )])

    def test_encode_decode_enum(self):
        """Enum members are encoded as their values and decode to plain values."""
        class _Letter(enum.Enum):
            CLASS_A = 'a'
            CLASS_B = 'b'

        config = {'key': _Letter.CLASS_A, 'key2': _Letter.CLASS_B}
        string = json_utils.Encoder().encode(config)
        loaded = json_utils.decode(string)
        # A dict comparison: assertEqual gives a readable per-key diff,
        # unlike the array-oriented assertAllEqual.
        self.assertEqual({'key': 'a', 'key2': 'b'}, loaded)
Exemplo n.º 7
0
    def test_encode_decode_tensor_shape(self):
        """TensorShapes of unknown and partially known rank round-trip."""
        metadata = {
            'key1': tensor_shape.TensorShape(None),
            'key2': [
                tensor_shape.TensorShape([None]),
                tensor_shape.TensorShape([3, None, 5])
            ]
        }
        string = json_utils.Encoder().encode(metadata)
        loaded = json_utils.decode(string)

        self.assertEqual(set(loaded.keys()), {'key1', 'key2'})
        # The compared values are None and plain Python lists, so scalar
        # assertions state the intent directly (no array conversion needed).
        self.assertIsNone(loaded['key1'].rank)
        self.assertEqual(loaded['key2'][0].as_list(), [None])
        self.assertEqual(loaded['key2'][1].as_list(), [3, None, 5])
Exemplo n.º 8
0
def model_from_json(json_string, custom_objects=None):
    """Parses a JSON model configuration string and returns a model instance.

    Usage:

    >>> model = tf.keras.Sequential([
    ...     tf.keras.layers.Dense(5, input_shape=(3,)),
    ...     tf.keras.layers.Softmax()])
    >>> config = model.to_json()
    >>> loaded_model = tf.keras.models.model_from_json(config)

    Args:
        json_string: JSON string encoding a model configuration.
        custom_objects: Optional dictionary mapping names
            (strings) to custom classes or functions to be
            considered during deserialization.

    Returns:
        A Keras model instance (uncompiled).
    """
    parsed_config = json_utils.decode(json_string)
    # Imported inside the function rather than at module level, presumably to
    # avoid a circular import with the layers package -- confirm before
    # hoisting it.
    from tensorflow.python.keras.layers import deserialize as deserialize_layer  # pylint: disable=g-import-not-at-top
    return deserialize_layer(parsed_config, custom_objects=custom_objects)
Exemplo n.º 9
0
  def _add_children_recreated_from_config(self, obj, proto, node_id):
    """Recursively records objects recreated from config.

    Collects the children of `obj`: those listed in `proto.children`, plus
    any of `obj._metrics` found under the `<KERAS_ATTR>.layer_metrics` child
    node. Each child that is a `trackable.Trackable` and is not already in
    `self.loaded_nodes` goes through three steps:
      - its dotted path is stored in `self._node_paths`;
      - the method recurses into it;
      - `(child, setter)` is stored in `self.loaded_nodes`.

    Args:
      obj: Object recreated from config whose children are recorded.
      proto: SavedObject proto for `obj`; `children` and
        `user_object.metadata` are read.
      node_id: Node id of `obj`; it must already have an entry in
        `self._node_paths` (used as the parent path of its children).
    """
    # pylint: disable=protected-access
    if node_id in self._traversed_nodes_from_config:
      return

    parent_path = self._node_paths[node_id]
    self._traversed_nodes_from_config.add(node_id)
    obj._maybe_initialize_trackable()
    if isinstance(obj, base_layer.Layer) and not obj.built:
      # Build the layer with the `build_input_shape` recorded in its saved
      # metadata (None when the key is absent).
      metadata = json_utils.decode(proto.user_object.metadata)
      self._try_build_layer(obj, node_id, metadata.get('build_input_shape'))

    # Create list of all possible children as
    # (child object, child node id, path component).
    children = []
    # Look for direct children
    for reference in proto.children:
      obj_child = obj._lookup_dependency(reference.local_name)
      children.append((obj_child, reference.node_id, reference.local_name))

    # Add metrics that may have been added to the layer._metrics list.
    # This is stored in the SavedModel as layer.keras_api.layer_metrics in
    # SavedModels created after Tf 2.2.
    metric_list_node_id = self._search_for_child_node(
        node_id, [constants.KERAS_ATTR, 'layer_metrics'])
    if metric_list_node_id is not None and hasattr(obj, '_metrics'):
      obj_metrics = {m.name: m for m in obj._metrics}
      for reference in self._proto.nodes[metric_list_node_id].children:
        metric = obj_metrics.get(reference.local_name)
        if metric is not None:
          metric_path = '{}.layer_metrics.{}'.format(constants.KERAS_ATTR,
                                                     reference.local_name)
          children.append((metric, reference.node_id, metric_path))

    for (obj_child, child_id, child_name) in children:
      child_proto = self._proto.nodes[child_id]

      if not isinstance(obj_child, trackable.Trackable):
        continue
      if (child_proto.user_object.identifier in
          revived_types.registered_identifiers()):
        setter = revived_types.get_setter(child_proto.user_object)
      elif obj_child._object_identifier in constants.KERAS_OBJECT_IDENTIFIERS:
        setter = _revive_setter
      else:
        setter = setattr

      if child_id in self.loaded_nodes:
        if self.loaded_nodes[child_id][0] is not obj_child:
          # This means that the same trackable object is referenced by two
          # different objects that were recreated from the config.
          logging.warning('Looks like there is an object (perhaps variable or '
                          'layer) that is shared between different '
                          'layers/models. This may cause issues when '
                          'restoring the variable values. Object: %s',
                          obj_child)
        continue

      # Overwrite variable names with the ones saved in the SavedModel.
      if (child_proto.WhichOneof('kind') == 'variable' and
          child_proto.variable.name):
        obj_child._handle_name = child_proto.variable.name + ':0'

      if isinstance(obj_child, data_structures.TrackableDataStructure):
        # Trackable data structures are recorded with a no-op setter.
        setter = lambda *args: None

      child_path = '{}.{}'.format(parent_path, child_name)
      self._node_paths[child_id] = child_path
      self._add_children_recreated_from_config(
          obj_child, child_proto, child_id)
      self.loaded_nodes[child_id] = obj_child, setter
    # pylint: enable=protected-access
Exemplo n.º 10
0
def load_model_from_hdf5(filepath, custom_objects=None, compile=True):  # pylint: disable=redefined-builtin
    """Loads a model saved via `save_model_to_hdf5`.

    Arguments:
        filepath: One of the following:
            - String, path to the saved model
            - `h5py.File` object from which to load the model
        custom_objects: Optional dictionary mapping names
            (strings) to custom classes or functions to be
            considered during deserialization.
        compile: Boolean, whether to compile the model
            after loading.

    Returns:
        A Keras model instance. If an optimizer was found
        as part of the saved model, the model is already
        compiled. Otherwise, the model is uncompiled and
        a warning will be displayed. When `compile` is set
        to False, the compilation is omitted without any
        warning.

    Raises:
        ImportError: if h5py is not available.
        ValueError: In case of an invalid savefile.
    """
    if h5py is None:
        raise ImportError('`load_model` requires h5py.')

    if not custom_objects:
        custom_objects = {}

    # Only a file opened here is closed at the end; a caller-supplied
    # `h5py.File` is left open.
    opened_new_file = not isinstance(filepath, h5py.File)
    if opened_new_file:
        f = h5py.File(filepath, mode='r')
    else:
        f = filepath

    model = None
    try:
        # instantiate model
        model_config = f.attrs.get('model_config')
        if model_config is None:
            raise ValueError('No model found in config file.')
        # Depending on the h5py version, string attributes come back either
        # as bytes (which must be decoded) or already as `str` (which has no
        # `.decode`), so only decode when needed.
        if hasattr(model_config, 'decode'):
            model_config = model_config.decode('utf-8')
        model_config = json_utils.decode(model_config)
        model = model_config_lib.model_from_config(
            model_config, custom_objects=custom_objects)

        # set weights
        load_weights_from_hdf5_group(f['model_weights'], model.layers)

        if compile:
            # instantiate optimizer
            training_config = f.attrs.get('training_config')
            if training_config is None:
                logging.warning(
                    'No training configuration found in the save file, so '
                    'the model was *not* compiled. Compile it manually.')
                return model
            # Same bytes-vs-str handling as for `model_config` above.
            if hasattr(training_config, 'decode'):
                training_config = training_config.decode('utf-8')
            training_config = json_utils.decode(training_config)

            # Compile model.
            model.compile(**saving_utils.compile_args_from_training_config(
                training_config, custom_objects))
            saving_utils.try_build_compiled_arguments(model)

            # Set optimizer weights.
            if 'optimizer_weights' in f:
                # Optimizer slot variables must exist before saved values
                # can be assigned to them.
                try:
                    model.optimizer._create_all_weights(
                        model.trainable_variables)
                except (NotImplementedError, AttributeError):
                    logging.warning(
                        'Error when creating the weights of optimizer {}, making it '
                        'impossible to restore the saved optimizer state. As a result, '
                        'your model is starting with a freshly initialized optimizer.'
                    )

                optimizer_weight_values = load_optimizer_weights_from_hdf5_group(
                    f)
                try:
                    model.optimizer.set_weights(optimizer_weight_values)
                except ValueError:
                    logging.warning('Error in loading the saved optimizer '
                                    'state. As a result, your model is '
                                    'starting with a freshly initialized '
                                    'optimizer.')
    finally:
        if opened_new_file:
            f.close()
    return model