示例#1
0
def compile_args_from_training_config(training_config, custom_objects=None):
  """Builds the keyword arguments for `model.compile` from a training config.

  Args:
    training_config: Mapping describing how a model was compiled. The keys
      'optimizer_config', 'loss', 'metrics', 'sample_weight_mode' and
      'loss_weights' are read with `[]` and therefore required;
      'weighted_metrics' is optional and defaults to `None`.
    custom_objects: Optional dict mapping names to custom classes or
      functions. It is forwarded to `optimizers.deserialize`, and every leaf
      of the loss / metrics structures that matches one of its keys is
      replaced by the mapped object.

  Returns:
    Dict with the keys `optimizer`, `loss`, `metrics`, `weighted_metrics`,
    `loss_weights` and `sample_weight_mode`.
  """
  custom_objects = {} if custom_objects is None else custom_objects

  def swap_custom(leaf):
    # Leaves that name a registered custom object become that object;
    # anything else passes through untouched.
    return custom_objects.get(leaf, leaf)

  optimizer = optimizers.deserialize(
      training_config['optimizer_config'], custom_objects=custom_objects)

  # A loss stored as a dict with a 'class_name' entry is a serialized loss
  # class; rebuild it via `losses.get` before the custom-object substitution.
  raw_loss = training_config['loss']
  if isinstance(raw_loss, dict) and 'class_name' in raw_loss:
    raw_loss = losses.get(raw_loss)

  resolved_loss = nest.map_structure(swap_custom, raw_loss)
  resolved_metrics = nest.map_structure(swap_custom, training_config['metrics'])
  resolved_weighted = nest.map_structure(
      swap_custom, training_config.get('weighted_metrics', None))

  sample_weight_mode = training_config['sample_weight_mode']
  loss_weights = training_config['loss_weights']

  return {
      'optimizer': optimizer,
      'loss': resolved_loss,
      'metrics': resolved_metrics,
      'weighted_metrics': resolved_weighted,
      'loss_weights': loss_weights,
      'sample_weight_mode': sample_weight_mode,
  }
示例#2
0
def get_loss_function(loss):
  """Resolves `loss` into something usable as a loss function.

  Args:
    loss: `None`, a `losses.Loss` instance, or any identifier accepted by
      `losses.get` (for example a string name).

  Returns:
    `loss` itself when it is `None` or already a `losses.Loss`; a new
    `losses.MeanSquaredError()` for the MSE aliases; otherwise the result of
    `losses.get(loss)`.
  """
  passthrough = loss is None or isinstance(loss, losses.Loss)
  if passthrough:
    return loss

  # TODO(psv): After we have added all V2 losses, update this function.
  mse_aliases = ('mse', 'MSE', 'mean_squared_error')
  return losses.MeanSquaredError() if loss in mse_aliases else losses.get(loss)
示例#3
0
def get_loss_function(loss):
    """Maps a loss specification onto a loss object or callable.

    Args:
        loss: `None`, a `losses.Loss` instance, or an identifier understood
            by `losses.get` (e.g. a string name).

    Returns:
        The input unchanged if it is `None` or already a `losses.Loss`;
        `losses.MeanSquaredError()` for 'mse' / 'MSE' / 'mean_squared_error';
        otherwise whatever `losses.get(loss)` returns.
    """
    if loss is not None and not isinstance(loss, losses.Loss):
        # TODO(psv): After we have added all V2 losses, update this function.
        is_mse_alias = loss in ['mse', 'MSE', 'mean_squared_error']
        if is_mse_alias:
            return losses.MeanSquaredError()
        return losses.get(loss)
    return loss
示例#4
0
    def _get_loss_object(self, loss):
        """Returns a `Loss` object.

    Converts the user-supplied loss to a `Loss` object. Also allows
    `SUM_OVER_BATCH_SIZE` reduction to be used for this loss.

    Arguments:
      loss: A string, function, or `Loss` object.

    Returns:
      A `Loss` object.
    """
        if loss is None:
            return None  # Ok to have no loss for an output.

        loss = losses_mod.get(loss)
        if not isinstance(loss, losses_mod.Loss):
            loss = losses_mod.LossFunctionWrapper(loss, name=loss.__name__)
        loss._allow_sum_over_batch_size = True  # pylint: disable=protected-access
        return loss
示例#5
0
  def _get_loss_object(self, loss):
    """Returns a `Loss` object.

    Converts the user-supplied loss to a `Loss` object. Also allows
    `SUM_OVER_BATCH_SIZE` reduction to be used for this loss.

    Args:
      loss: A string, function, or `Loss` object.

    Returns:
      A `Loss` object.
    """
    if loss is None:
      return None  # Ok to have no loss for an output.

    loss = losses_mod.get(loss)
    if not isinstance(loss, losses_mod.Loss):
      loss_name = get_custom_object_name(loss)
      if loss_name is None:
        raise ValueError('Loss should be a callable, found: {}'.format(loss))
      loss = losses_mod.LossFunctionWrapper(loss, name=loss_name)
    loss._allow_sum_over_batch_size = True  # pylint: disable=protected-access
    return loss
示例#6
0
    def _get_loss_object(self, loss):
        """Returns a `Loss` object.

    Converts the user-supplied loss to a `Loss` object. Also allows
    `SUM_OVER_BATCH_SIZE` reduction to be used for this loss.

    Arguments:
      loss: A string, function, or `Loss` object.

    Returns:
      A `Loss` object.
    """
        if loss is None:
            return None  # Ok to have no loss for an output.

        # TODO(omalleyt): Handle special casing for crossentropy.
        loss = losses_mod.get(loss)
        if not isinstance(loss, losses_mod.Loss):
            loss = losses_mod.LossFunctionWrapper(loss)
        # Allow AUTO and SUM_OVER_BATCH_SIZE reductions.
        # TODO(omalleyt): Can we reconcile CTL and built-in loss reductions?
        loss._allow_sum_over_batch_size = True  # pylint: disable=protected-access
        return loss
示例#7
0
def load_model_from_hdf5(filepath, custom_objects=None, compile=True):  # pylint: disable=redefined-builtin
    """Loads a model saved via `save_model_to_hdf5`.

    Arguments:
        filepath: One of the following:
            - String, path to the saved model
            - `h5py.File` object from which to load the model
        custom_objects: Optional dictionary mapping names
            (strings) to custom classes or functions to be
            considered during deserialization.
        compile: Boolean, whether to compile the model
            after loading.

    Returns:
        A Keras model instance. If an optimizer was found
        as part of the saved model, the model is already
        compiled. Otherwise, the model is uncompiled and
        a warning will be displayed. When `compile` is set
        to False, the compilation is omitted without any
        warning.

    Raises:
        ImportError: if h5py is not available.
        ValueError: In case of an invalid savefile.
    """
    if h5py is None:
        raise ImportError('`load_model` requires h5py.')

    if not custom_objects:
        custom_objects = {}

    def convert_custom_objects(obj):
        """Handles custom object lookup.

        Arguments:
            obj: object, dict, or list.

        Returns:
            The same structure, where occurrences
            of a custom object name have been replaced
            with the custom object.
        """
        if isinstance(obj, list):
            return [convert_custom_objects(value) for value in obj]
        if isinstance(obj, dict):
            return {key: convert_custom_objects(value)
                    for key, value in obj.items()}
        if obj in custom_objects:
            return custom_objects[obj]
        return obj

    def as_text(attr_value):
        """Returns an HDF5 string attribute as `str`.

        h5py < 3 returns these attributes as bytes, while h5py >= 3 returns
        `str`, which has no `decode` method. Decode only when needed so both
        versions work.
        """
        if hasattr(attr_value, 'decode'):
            return attr_value.decode('utf-8')
        return attr_value

    opened_new_file = not isinstance(filepath, h5py.File)
    if opened_new_file:
        f = h5py.File(filepath, mode='r')
    else:
        f = filepath

    model = None
    try:
        # instantiate model
        model_config = f.attrs.get('model_config')
        if model_config is None:
            raise ValueError('No model found in config file.')
        model_config = json.loads(as_text(model_config))
        model = model_config_lib.model_from_config(
            model_config, custom_objects=custom_objects)

        # set weights
        load_weights_from_hdf5_group(f['model_weights'], model.layers)

        if compile:
            # instantiate optimizer
            training_config = f.attrs.get('training_config')
            if training_config is None:
                logging.warning(
                    'No training configuration found in save file: '
                    'the model was *not* compiled. Compile it manually.')
                return model
            training_config = json.loads(as_text(training_config))
            optimizer_config = training_config['optimizer_config']
            optimizer = optimizers.deserialize(optimizer_config,
                                               custom_objects=custom_objects)

            # Recover loss functions and metrics.
            loss_config = training_config['loss']  # Deserialize loss class.
            if isinstance(loss_config, dict) and 'class_name' in loss_config:
                loss_config = losses.get(loss_config)
            loss = convert_custom_objects(loss_config)
            metrics = convert_custom_objects(training_config['metrics'])
            weighted_metrics = convert_custom_objects(
                training_config.get('weighted_metrics', None))
            sample_weight_mode = training_config['sample_weight_mode']
            loss_weights = training_config['loss_weights']

            # Compile model.
            model.compile(optimizer=optimizer,
                          loss=loss,
                          metrics=metrics,
                          weighted_metrics=weighted_metrics,
                          loss_weights=loss_weights,
                          sample_weight_mode=sample_weight_mode)

            # Set optimizer weights.
            if 'optimizer_weights' in f:
                # Build train function (to get weight updates).
                # Models that aren't graph networks must wait until they are called
                # with data to _make_train_function() and so can't load optimizer
                # weights.
                if model._is_graph_network:  # pylint: disable=protected-access
                    model._make_train_function()
                    optimizer_weight_values = load_optimizer_weights_from_hdf5_group(
                        f)
                    try:
                        model.optimizer.set_weights(optimizer_weight_values)
                    except ValueError:
                        logging.warning('Error in loading the saved optimizer '
                                        'state. As a result, your model is '
                                        'starting with a freshly initialized '
                                        'optimizer.')
                else:
                    logging.warning(
                        'Sequential models without an `input_shape` '
                        'passed to the first layer cannot reload their '
                        'optimizer state. As a result, your model is '
                        'starting with a freshly initialized optimizer.')

    finally:
        # Only close files this function opened itself; a caller-supplied
        # `h5py.File` is left open for the caller to manage.
        if opened_new_file:
            f.close()
    return model