Example #1
  def set_weights(self, weights):
    """Sets the weights of the optimizer, from Numpy arrays.

    Should only be called after computing the gradients
    (otherwise the optimizer has no weights).

    Arguments:
        weights: a list of Numpy arrays. The number
            of arrays and their shapes must match the
            number and shapes of the weights of the
            optimizer (i.e. it should match the
            output of `get_weights`).

    Raises:
        ValueError: in case of incompatible weight shapes.
    """
    params = self.weights
    if len(params) != len(weights):
      raise ValueError(
          'Length of the specified weight list (' + str(len(weights)) +
          ') does not match the number of weights '
          'of the optimizer (' + str(len(params)) + ')')
    weight_value_tuples = []
    param_values = K.batch_get_value(params)
    for pv, p, w in zip(param_values, params, weights):
      if pv.shape != w.shape:
        raise ValueError(
            'Optimizer weight shape ' + str(pv.shape) + ' not compatible with '
            'provided weight shape ' + str(w.shape))
      weight_value_tuples.append((p, w))
    K.batch_set_value(weight_value_tuples)
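A minimal round-trip sketch of the API above, assuming an older tf.keras where optimizers still expose `get_weights`/`set_weights` (the optimizer_v2 API, removed in Keras 2.11); the model here is illustrative:

import numpy as np
import tensorflow as tf

# One fit step so the optimizer creates its slot variables (momentum here).
model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(4,))])
model.compile(optimizer=tf.keras.optimizers.SGD(momentum=0.9), loss="mse")
model.fit(np.zeros((8, 4)), np.zeros((8, 1)), epochs=1, verbose=0)

saved = model.optimizer.get_weights()   # list of NumPy arrays
model.optimizer.set_weights(saved)      # must match count and shapes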
Example #2
def load_weights_from_hdf5_group_by_name(f, layers):
  """Implements name-based weight loading.

  (instead of topological weight loading).

  Layers that have no matching name are skipped.

  Arguments:
      f: A pointer to a HDF5 group.
      layers: a list of target layers.

  Raises:
      ValueError: in case of mismatch between provided layers
          and weights file.
  """
  if 'keras_version' in f.attrs:
    original_keras_version = f.attrs['keras_version'].decode('utf8')
  else:
    original_keras_version = '1'
  if 'backend' in f.attrs:
    original_backend = f.attrs['backend'].decode('utf8')
  else:
    original_backend = None

  # New file format.
  layer_names = load_attributes_from_hdf5_group(f, 'layer_names')

  # Reverse index of layer name to list of layers with name.
  index = {}
  for layer in layers:
    if layer.name:
      index.setdefault(layer.name, []).append(layer)

  # We batch weight value assignments in a single backend call
  # which provides a speedup in TensorFlow.
  weight_value_tuples = []
  for k, name in enumerate(layer_names):
    g = f[name]
    weight_names = load_attributes_from_hdf5_group(g, 'weight_names')
    weight_values = [np.asarray(g[weight_name]) for weight_name in weight_names]

    for layer in index.get(name, []):
      symbolic_weights = layer.weights
      weight_values = preprocess_weights_for_loading(
          layer, weight_values, original_keras_version, original_backend)
      if len(weight_values) != len(symbolic_weights):
        raise ValueError('Layer #' + str(k) + ' (named "' + layer.name +
                         '") expects ' + str(len(symbolic_weights)) +
                         ' weight(s), but the saved weights' + ' have ' +
                         str(len(weight_values)) + ' element(s).')
      # Set values.
      for i in range(len(weight_values)):
        weight_value_tuples.append((symbolic_weights[i], weight_values[i]))
  K.batch_set_value(weight_value_tuples)
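In practice this helper is reached through `Model.load_weights` with `by_name=True`; a short sketch (the checkpoint path and layer name are illustrative):

import tensorflow as tf

# Layers whose names match a layer stored in the HDF5 file get their weights
# restored; layers with no matching name are simply skipped.
model = tf.keras.Sequential([tf.keras.layers.Dense(8, name="dense_shared")])
model.build(input_shape=(None, 4))
model.load_weights("checkpoint.h5", by_name=True)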
Example #3
 def set_weights(self, weights):
   params = self.weights
   if len(params) != len(weights):
     raise ValueError(
         "You called `set_weights(weights)` on optimizer " + self._name +
         " with a  weight list of length " + str(len(weights)) +
         ", but the optimizer was expecting " + str(len(params)) +
         " weights. Provided weights: " + str(weights)[:50] + "...")
   if not params:
     return
   weight_value_tuples = []
   param_values = backend.batch_get_value(params)
   for pv, p, w in zip(param_values, params, weights):
     if pv.shape != w.shape:
       raise ValueError("Optimizer weight shape " + str(pv.shape) +
                        " not compatible with "
                        "provided weight shape " + str(w.shape))
     weight_value_tuples.append((p, w))
   backend.batch_set_value(weight_value_tuples)
Example #4
def load_weights_verbosely(
    save_file: str,
    model: GraphTaskModel,
    warn_about_initialisations: bool = True,
    warn_about_ignored: bool = True,
    weight_name_to_var_name: Optional[Callable[
        [str], str]] = backward_compat_weight_renaming_fn,
):
    var_name_to_variable = _get_name_to_variable_map(model)

    with open(get_model_file_path(save_file, "pkl"), "rb") as in_file:
        data_to_load = pickle.load(in_file)
    var_name_to_weights = data_to_load.get("model_weights")
    if var_name_to_weights is None:
        var_name_to_weights = _read_weights_from_hdf5(save_file)

    if weight_name_to_var_name is not None:
        remapped_var_name_to_weights = {}
        for weight_name, weight in var_name_to_weights.items():
            remapped_var_name_to_weights[weight_name_to_var_name(
                weight_name)] = weight
        var_name_to_weights = remapped_var_name_to_weights

    tfvar_weight_tuples = []
    used_var_names = set()
    for var_name, tfvar in var_name_to_variable.items():
        saved_weight = var_name_to_weights.get(var_name)
        if saved_weight is None:
            if warn_about_initialisations:
                print(f"I: Weights for {var_name} freshly initialised.")
        else:
            used_var_names.add(var_name)
            tfvar_weight_tuples.append((tfvar, saved_weight))

    if warn_about_ignored:
        for var_name in var_name_to_weights.keys():
            if var_name not in used_var_names:
                print(f"I: Model does not use saved weights for {var_name}.")

    K.batch_set_value(tfvar_weight_tuples)
Example #5
def convert_all_kernels_in_model(model):
  """Converts all convolution kernels in a model from Theano to TensorFlow.

  Also works from TensorFlow to Theano.

  Arguments:
      model: target model for the conversion.
  """
  # Note: SeparableConvolution not included
  # since only supported by TF.
  conv_classes = {
      'Conv1D',
      'Conv2D',
      'Conv3D',
      'Conv2DTranspose',
  }
  to_assign = []
  for layer in model.layers:
    if layer.__class__.__name__ in conv_classes:
      original_kernel = K.get_value(layer.kernel)
      converted_kernel = convert_kernel(original_kernel)
      to_assign.append((layer.kernel, converted_kernel))
  K.batch_set_value(to_assign)
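For intuition, `convert_kernel` flips a convolution kernel along its spatial axes (Theano performs a true convolution, TensorFlow a correlation). A sketch of the equivalent NumPy operation, assuming the TF Conv2D kernel layout:

import numpy as np

# A Conv2D kernel in TF layout: (rows, cols, in_channels, out_channels).
kernel = np.random.rand(3, 3, 16, 32)
converted = kernel[::-1, ::-1, :, :].copy()  # reverse both spatial axes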
Example #7
    def _load_state_dict(self, model, weight_dict):
        original_keras_version = keras_version
        original_backend = K.backend()

        weight_value_tuples = []
        for k, layer in enumerate(model.layers):
            weight_names = [l.name for l in layer.weights]
            if len(weight_names) == 0:
                continue
            weight_values = [
                np.asarray(weight_dict[weight_name]) for weight_name in weight_names
            ]

            symbolic_weights = layer.trainable_weights + layer.non_trainable_weights
            weight_values = preprocess_weights_for_loading(
                layer, weight_values, original_keras_version, original_backend
            )

            if len(weight_values) != len(symbolic_weights):
                raise ValueError(
                    "Layer #"
                    + str(k)
                    + ' (named "'
                    + layer.name
                    + '" in the current model) was found to '
                    "correspond to layer " + layer.name + " in the save file. "
                    "However the new layer "
                    + layer.name
                    + " expects "
                    + str(len(symbolic_weights))
                    + " weights, but the saved weights have "
                    + str(len(weight_values))
                    + " elements."
                )
            weight_value_tuples += zip(symbolic_weights, weight_values)
        K.batch_set_value(weight_value_tuples)
Example #8
def load_weights_from_hdf5_group(f, layers):
  """Implements topological (order-based) weight loading.

  Arguments:
      f: A pointer to a HDF5 group.
      layers: a list of target layers.

  Raises:
      ValueError: in case of mismatch between provided layers
          and weights file.
  """
  if 'keras_version' in f.attrs:
    original_keras_version = f.attrs['keras_version'].decode('utf8')
  else:
    original_keras_version = '1'
  if 'backend' in f.attrs:
    original_backend = f.attrs['backend'].decode('utf8')
  else:
    original_backend = None

  filtered_layers = []
  for layer in layers:
    weights = layer.weights
    if weights:
      filtered_layers.append(layer)

  layer_names = load_attributes_from_hdf5_group(f, 'layer_names')
  filtered_layer_names = []
  for name in layer_names:
    g = f[name]
    weight_names = load_attributes_from_hdf5_group(g, 'weight_names')
    if weight_names:
      filtered_layer_names.append(name)
  layer_names = filtered_layer_names
  if len(layer_names) != len(filtered_layers):
    raise ValueError('You are trying to load a weight file '
                     'containing ' + str(len(layer_names)) +
                     ' layers into a model with ' + str(len(filtered_layers)) +
                     ' layers.')

  # We batch weight value assignments in a single backend call
  # which provides a speedup in TensorFlow.
  weight_value_tuples = []
  for k, name in enumerate(layer_names):
    g = f[name]
    weight_names = load_attributes_from_hdf5_group(g, 'weight_names')
    weight_values = [np.asarray(g[weight_name]) for weight_name in weight_names]
    layer = filtered_layers[k]
    symbolic_weights = layer.weights
    weight_values = preprocess_weights_for_loading(
        layer, weight_values, original_keras_version, original_backend)
    if len(weight_values) != len(symbolic_weights):
      raise ValueError('Layer #' + str(k) + ' (named "' + layer.name +
                       '" in the current model) was found to '
                       'correspond to layer ' + name + ' in the save file. '
                       'However the new layer ' + layer.name + ' expects ' +
                       str(len(symbolic_weights)) +
                       ' weights, but the saved weights have ' +
                       str(len(weight_values)) + ' elements.')
    weight_value_tuples += zip(symbolic_weights, weight_values)
  K.batch_set_value(weight_value_tuples)
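This order-based path is what a plain `Model.load_weights` call (no `by_name`) uses for HDF5 files, so the model must match the file layer-for-layer. A short sketch (the path is illustrative):

import tensorflow as tf

# Topological loading: layers are matched by position, not name, so the
# architecture must mirror the saved file exactly.
model = tf.keras.Sequential([tf.keras.layers.Dense(8), tf.keras.layers.Dense(1)])
model.build(input_shape=(None, 4))
model.load_weights("weights.h5")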
Example #9
def load_weights_from_hdf5_group_by_name(f, layers, skip_mismatch=False):
    """Implements name-based weight loading.

  (instead of topological weight loading).

  Layers that have no matching name are skipped.

  Arguments:
      f: A pointer to a HDF5 group.
      layers: a list of target layers.
      skip_mismatch: Boolean, whether to skip loading of layers
          where there is a mismatch in the number of weights,
          or a mismatch in the shape of the weights.

  Raises:
      ValueError: in case of mismatch between provided layers
          and weights file and skip_match=False.
  """
    if 'keras_version' in f.attrs:
        original_keras_version = f.attrs['keras_version'].decode('utf8')
    else:
        original_keras_version = '1'
    if 'backend' in f.attrs:
        original_backend = f.attrs['backend'].decode('utf8')
    else:
        original_backend = None

    # New file format.
    layer_names = load_attributes_from_hdf5_group(f, 'layer_names')

    # Reverse index of layer name to list of layers with name.
    index = {}
    for layer in layers:
        if layer.name:
            index.setdefault(layer.name, []).append(layer)

    # We batch weight value assignments in a single backend call
    # which provides a speedup in TensorFlow.
    weight_value_tuples = []
    for k, name in enumerate(layer_names):
        g = f[name]
        weight_names = load_attributes_from_hdf5_group(g, 'weight_names')
        weight_values = [
            np.asarray(g[weight_name]) for weight_name in weight_names
        ]

        for layer in index.get(name, []):
            symbolic_weights = _legacy_weights(layer)
            weight_values = preprocess_weights_for_loading(
                layer, weight_values, original_keras_version, original_backend)
            if len(weight_values) != len(symbolic_weights):
                if skip_mismatch:
                    logging.warning(
                        'Skipping loading of weights for '
                        'layer {}'.format(layer.name) + ' due to mismatch '
                        'in number of weights ({} vs {}).'.format(
                            len(symbolic_weights), len(weight_values)))
                    continue
                raise ValueError('Layer #' + str(k) + ' (named "' +
                                 layer.name + '") expects ' +
                                 str(len(symbolic_weights)) +
                                 ' weight(s), but the saved weights' +
                                 ' have ' + str(len(weight_values)) +
                                 ' element(s).')
            # Set values.
            for i in range(len(weight_values)):
                if K.int_shape(symbolic_weights[i]) != weight_values[i].shape:
                    if skip_mismatch:
                        logging.warning('Skipping loading of weights for '
                                        'layer {}'.format(layer.name) +
                                        ' due to '
                                        'mismatch in shape ({} vs {}).'.format(
                                            symbolic_weights[i].shape,
                                            weight_values[i].shape))
                        continue
                    raise ValueError('Layer #' + str(k) + ' (named "' +
                                     layer.name + '"), weight ' +
                                     str(symbolic_weights[i]) +
                                     ' has shape {}'.format(
                                         K.int_shape(symbolic_weights[i])) +
                                     ', but the saved weight has shape ' +
                                     str(weight_values[i].shape) + '.')
                else:
                    weight_value_tuples.append(
                        (symbolic_weights[i], weight_values[i]))
    K.batch_set_value(weight_value_tuples)
Example #10
def load_weights_from_hdf5_group(f, layers):
    """Implements topological (order-based) weight loading.

  Arguments:
      f: A pointer to a HDF5 group.
      layers: a list of target layers.

  Raises:
      ValueError: in case of mismatch between provided layers
          and weights file.
  """
    if 'keras_version' in f.attrs:
        original_keras_version = f.attrs['keras_version'].decode('utf8')
    else:
        original_keras_version = '1'
    if 'backend' in f.attrs:
        original_backend = f.attrs['backend'].decode('utf8')
    else:
        original_backend = None

    filtered_layers = []
    for layer in layers:
        weights = _legacy_weights(layer)
        if weights:
            filtered_layers.append(layer)

    layer_names = load_attributes_from_hdf5_group(f, 'layer_names')
    filtered_layer_names = []
    for name in layer_names:
        g = f[name]
        weight_names = load_attributes_from_hdf5_group(g, 'weight_names')
        if weight_names:
            filtered_layer_names.append(name)
    layer_names = filtered_layer_names
    if len(layer_names) != len(filtered_layers):
        raise ValueError('You are trying to load a weight file '
                         'containing ' + str(len(layer_names)) +
                         ' layers into a model with ' +
                         str(len(filtered_layers)) + ' layers.')

    # We batch weight value assignments in a single backend call
    # which provides a speedup in TensorFlow.
    weight_value_tuples = []
    for k, name in enumerate(layer_names):
        g = f[name]
        weight_names = load_attributes_from_hdf5_group(g, 'weight_names')
        weight_values = [
            np.asarray(g[weight_name]) for weight_name in weight_names
        ]
        layer = filtered_layers[k]
        symbolic_weights = _legacy_weights(layer)
        weight_values = preprocess_weights_for_loading(layer, weight_values,
                                                       original_keras_version,
                                                       original_backend)
        if len(weight_values) != len(symbolic_weights):
            raise ValueError('Layer #' + str(k) + ' (named "' + layer.name +
                             '" in the current model) was found to '
                             'correspond to layer ' + name +
                             ' in the save file. '
                             'However the new layer ' + layer.name +
                             ' expects ' + str(len(symbolic_weights)) +
                             ' weights, but the saved weights have ' +
                             str(len(weight_values)) + ' elements.')
        weight_value_tuples += zip(symbolic_weights, weight_values)
    K.batch_set_value(weight_value_tuples)
Example #11
def load_pytorch_weights_in_tf2_model(tf_model, pt_state_dict, tf_inputs=None, allow_missing_keys=False):
    """ Load pytorch state_dict in a TF 2.0 model.
    """
    try:
        import torch  # noqa: F401
        import tensorflow as tf  # noqa: F401
        from tensorflow.python.keras import backend as K
    except ImportError:
        logger.error(
            "Loading a PyTorch model in TensorFlow, requires both PyTorch and TensorFlow to be installed. Please see "
            "https://pytorch.org/ and https://www.tensorflow.org/install/ for installation instructions."
        )
        raise

    if tf_inputs is None:
        tf_inputs = tf_model.dummy_inputs

    if tf_inputs is not None:
        tf_model(tf_inputs, training=False)  # Make sure model is built

    # Adapt state dict - TODO remove this and update the AWS weights files instead
    # Convert old format to new format if needed from a PyTorch state_dict
    old_keys = []
    new_keys = []
    for key in pt_state_dict.keys():
        new_key = None
        if "gamma" in key:
            new_key = key.replace("gamma", "weight")
        if "beta" in key:
            new_key = key.replace("beta", "bias")
        if new_key:
            old_keys.append(key)
            new_keys.append(new_key)
    for old_key, new_key in zip(old_keys, new_keys):
        pt_state_dict[new_key] = pt_state_dict.pop(old_key)

    # Make sure we are able to load PyTorch base models as well as derived models (with heads)
    # TF models always have a prefix, some of PyTorch models (base ones) don't
    start_prefix_to_remove = ""
    if not any(s.startswith(tf_model.base_model_prefix) for s in pt_state_dict.keys()):
        start_prefix_to_remove = tf_model.base_model_prefix + "."

    symbolic_weights = tf_model.trainable_weights + tf_model.non_trainable_weights
    tf_loaded_numel = 0
    weight_value_tuples = []
    all_pytorch_weights = set(list(pt_state_dict.keys()))
    unexpected_keys = []
    for symbolic_weight in symbolic_weights:
        sw_name = symbolic_weight.name
        name, transpose = convert_tf_weight_name_to_pt_weight_name(
            sw_name, start_prefix_to_remove=start_prefix_to_remove
        )

        # Find associated numpy array in pytorch model state dict
        if name not in pt_state_dict:
            if allow_missing_keys:
                unexpected_keys.append(name)
                continue

            raise AttributeError("{} not found in PyTorch model".format(name))

        array = pt_state_dict[name].numpy()

        if transpose:
            array = numpy.transpose(array)

        if len(symbolic_weight.shape) < len(array.shape):
            array = numpy.squeeze(array)
        elif len(symbolic_weight.shape) > len(array.shape):
            array = numpy.expand_dims(array, axis=0)

        try:
            assert list(symbolic_weight.shape) == list(array.shape)
        except AssertionError as e:
            e.args += (symbolic_weight.shape, array.shape)
            raise e

        tf_loaded_numel += array.size
        # logger.warning("Initialize TF weight {}".format(symbolic_weight.name))

        weight_value_tuples.append((symbolic_weight, array))
        all_pytorch_weights.discard(name)

    K.batch_set_value(weight_value_tuples)

    if tf_inputs is not None:
        tf_model(tf_inputs, training=False)  # Make sure restore ops are run

    logger.info("Loaded {:,} parameters in the TF 2.0 model.".format(tf_loaded_numel))

    missing_keys = list(all_pytorch_weights)

    if len(unexpected_keys) > 0:
        logger.warning(
            f"Some weights of the PyTorch model were not used when "
            f"initializing the TF 2.0 model {tf_model.__class__.__name__}: {unexpected_keys}\n"
            f"- This IS expected if you are initializing {tf_model.__class__.__name__} from a TF 2.0 model trained on another task "
            f"or with another architecture (e.g. initializing a BertForSequenceClassification model from a TFBertForPretraining model).\n"
            f"- This IS NOT expected if you are initializing {tf_model.__class__.__name__} from a TF 2.0 model that you expect "
            f"to be exactly identical (initializing a BertForSequenceClassification model from a TFBertForSequenceClassification model)."
        )
    else:
        logger.warning(f"All PyTorch model weights were used when initializing {tf_model.__class__.__name__}.\n")
    if len(missing_keys) > 0:
        logger.warning(
            f"Some weights or buffers of the PyTorch model {tf_model.__class__.__name__} were not initialized from the TF 2.0 model "
            f"and are newly initialized: {missing_keys}\n"
            f"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference."
        )
    else:
        logger.warning(
            f"All the weights of {tf_model.__class__.__name__} were initialized from the TF 2.0 model.\n"
            f"If your task is similar to the task the model of the ckeckpoint was trained on, "
            f"you can already use {tf_model.__class__.__name__} for predictions without further training."
        )

    return tf_model
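Downstream, this conversion usually runs behind `from_pretrained(..., from_pt=True)`; a sketch (the checkpoint name is illustrative):

from transformers import TFBertModel

# from_pt=True builds the TF 2.0 model, then routes the PyTorch state_dict
# through load_pytorch_weights_in_tf2_model.
model = TFBertModel.from_pretrained("bert-base-uncased", from_pt=True)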
Example #12
 def reset_state(self):
     backend.batch_set_value([(v, np.zeros(v.shape))
                              for v in self.variables])
Example #13
 def reset_states(self):
     K.batch_set_value([(self.variables[0], tf.zeros(self.n_outputs))])
     K.set_value(self.count, [0])
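The same batched reset pattern works for any custom metric. A minimal sketch, assuming the legacy tensorflow.python.keras backend import used throughout these examples; the metric itself is a toy:

import numpy as np
import tensorflow as tf
from tensorflow.python.keras import backend as K

class RunningTotal(tf.keras.metrics.Metric):
    """Toy metric illustrating the batched reset pattern (illustrative)."""

    def __init__(self, name="running_total", **kwargs):
        super().__init__(name=name, **kwargs)
        self.total = self.add_weight(name="total", initializer="zeros")

    def update_state(self, values, sample_weight=None):
        self.total.assign_add(tf.reduce_sum(tf.cast(values, self.dtype)))

    def result(self):
        return self.total

    def reset_states(self):
        # Zero every state variable in a single backend call.
        K.batch_set_value([(v, np.zeros(v.shape)) for v in self.variables])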
Example #14
def load_pytorch_weights_in_tf2_model(tf_model,
                                      pt_state_dict,
                                      tf_inputs=None,
                                      allow_missing_keys=False,
                                      output_loading_info=False):
    """Load pytorch state_dict in a TF 2.0 model."""
    try:
        import tensorflow as tf  # noqa: F401
        import torch  # noqa: F401
        from tensorflow.python.keras import backend as K
    except ImportError:
        logger.error(
            "Loading a PyTorch model in TensorFlow, requires both PyTorch and TensorFlow to be installed. Please see "
            "https://pytorch.org/ and https://www.tensorflow.org/install/ for installation instructions."
        )
        raise

    if tf_inputs is None:
        tf_inputs = tf_model.dummy_inputs

    if tf_inputs is not None:
        tf_model(tf_inputs, training=False)  # Make sure model is built
    # Adapt state dict - TODO remove this and update the AWS weights files instead
    # Convert old format to new format if needed from a PyTorch state_dict
    old_keys = []
    new_keys = []
    for key in pt_state_dict.keys():
        new_key = None
        if "gamma" in key:
            new_key = key.replace("gamma", "weight")
        if "beta" in key:
            new_key = key.replace("beta", "bias")
        if "running_var" in key:
            new_key = key.replace("running_var", "moving_variance")
        if "running_mean" in key:
            new_key = key.replace("running_mean", "moving_mean")
        if new_key:
            old_keys.append(key)
            new_keys.append(new_key)
    for old_key, new_key in zip(old_keys, new_keys):
        pt_state_dict[new_key] = pt_state_dict.pop(old_key)

    # Make sure we are able to load PyTorch base models as well as derived models (with heads)
    # TF models always have a prefix, some of PyTorch models (base ones) don't
    start_prefix_to_remove = ""
    if not any(
            s.startswith(tf_model.base_model_prefix)
            for s in pt_state_dict.keys()):
        start_prefix_to_remove = tf_model.base_model_prefix + "."

    symbolic_weights = tf_model.trainable_weights + tf_model.non_trainable_weights
    tf_loaded_numel = 0
    weight_value_tuples = []
    all_pytorch_weights = set(list(pt_state_dict.keys()))
    missing_keys = []
    for symbolic_weight in symbolic_weights:
        sw_name = symbolic_weight.name
        name, transpose = convert_tf_weight_name_to_pt_weight_name(
            sw_name,
            start_prefix_to_remove=start_prefix_to_remove,
            tf_weight_shape=symbolic_weight.shape)

        # Find associated numpy array in pytorch model state dict
        if name not in pt_state_dict:
            if allow_missing_keys:
                missing_keys.append(name)
                continue
            elif tf_model._keys_to_ignore_on_load_missing is not None:
                # authorized missing keys don't have to be loaded
                if any(
                        re.search(pat, name) is not None
                        for pat in tf_model._keys_to_ignore_on_load_missing):
                    continue
            raise AttributeError(f"{name} not found in PyTorch model")

        array = pt_state_dict[name].numpy()

        if transpose is TransposeType.CONV2D:
            # Conv2D weight:
            #    PT: (num_out_channel, num_in_channel, kernel[0], kernel[1])
            # -> TF: (kernel[0], kernel[1], num_in_channel, num_out_channel)
            array = numpy.transpose(array, axes=(2, 3, 1, 0))
        elif transpose is TransposeType.CONV1D:
            # Conv1D weight:
            #    PT: (num_out_channel, num_in_channel, kernel)
            # -> TF: (kernel, num_in_channel, num_out_channel)
            array = numpy.transpose(array, axes=(2, 1, 0))
        elif transpose is TransposeType.SIMPLE:
            array = numpy.transpose(array)

        if len(symbolic_weight.shape) < len(array.shape):
            array = numpy.squeeze(array)
        elif len(symbolic_weight.shape) > len(array.shape):
            array = numpy.expand_dims(array, axis=0)

        if list(symbolic_weight.shape) != list(array.shape):
            try:
                array = numpy.reshape(array, symbolic_weight.shape)
            except ValueError as e:  # numpy.reshape raises ValueError, not AssertionError
                e.args += (symbolic_weight.shape, array.shape)
                raise e

        try:
            assert list(symbolic_weight.shape) == list(array.shape)
        except AssertionError as e:
            e.args += (symbolic_weight.shape, array.shape)
            raise e

        tf_loaded_numel += array.size
        # logger.warning(f"Initialize TF weight {symbolic_weight.name}")

        weight_value_tuples.append((symbolic_weight, array))
        all_pytorch_weights.discard(name)

    K.batch_set_value(weight_value_tuples)

    if tf_inputs is not None:
        tf_model(tf_inputs, training=False)  # Make sure restore ops are run

    logger.info(f"Loaded {tf_loaded_numel:,} parameters in the TF 2.0 model.")

    unexpected_keys = list(all_pytorch_weights)

    if tf_model._keys_to_ignore_on_load_missing is not None:
        for pat in tf_model._keys_to_ignore_on_load_missing:
            missing_keys = [
                k for k in missing_keys if re.search(pat, k) is None
            ]
    if tf_model._keys_to_ignore_on_load_unexpected is not None:
        for pat in tf_model._keys_to_ignore_on_load_unexpected:
            unexpected_keys = [
                k for k in unexpected_keys if re.search(pat, k) is None
            ]

    if len(unexpected_keys) > 0:
        logger.warning(
            "Some weights of the PyTorch model were not used when initializing the TF 2.0 model"
            f" {tf_model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are initializing"
            f" {tf_model.__class__.__name__} from a PyTorch model trained on another task or with another architecture"
            " (e.g. initializing a TFBertForSequenceClassification model from a BertForPreTraining model).\n- This IS"
            f" NOT expected if you are initializing {tf_model.__class__.__name__} from a PyTorch model that you expect"
            " to be exactly identical (e.g. initializing a TFBertForSequenceClassification model from a"
            " BertForSequenceClassification model).")
    else:
        logger.warning(
            f"All PyTorch model weights were used when initializing {tf_model.__class__.__name__}.\n"
        )
    if len(missing_keys) > 0:
        logger.warning(
            f"Some weights or buffers of the TF 2.0 model {tf_model.__class__.__name__} were not initialized from the"
            f" PyTorch model and are newly initialized: {missing_keys}\nYou should probably TRAIN this model on a"
            " down-stream task to be able to use it for predictions and inference."
        )
    else:
        logger.warning(
            f"All the weights of {tf_model.__class__.__name__} were initialized from the PyTorch model.\n"
            "If your task is similar to the task the model of the checkpoint was trained on, "
            f"you can already use {tf_model.__class__.__name__} for predictions without further training."
        )

    if output_loading_info:
        loading_info = {
            "missing_keys": missing_keys,
            "unexpected_keys": unexpected_keys
        }
        return tf_model, loading_info

    return tf_model
Example #15
    def fit(self,
            x=None,
            y=None,
            batch_size=None,
            epochs=1,
            verbose=1,
            dynamic_switch=True,
            callbacks=None,
            validation_split=0.,
            validation_data=None,
            shuffle=True,
            class_weight=None,
            sample_weight=None,
            initial_epoch=0,
            steps_per_epoch=None,
            validation_steps=None,
            validation_batch_size=None,
            validation_freq=1,
            max_queue_size=10,
            workers=1,
            use_multiprocessing=False):

        training._keras_api_gauge.get_cell('fit').set(True)
        # Legacy graph support is contained in `training_v1.Model`.
        version_utils.disallow_legacy_graph('Model', 'fit')
        self._assert_compile_was_called()
        self._check_call_args('fit')

        if validation_split:
            # Create the validation data using the training data. Only supported for
            # `Tensor` and `NumPy` input.
            (x, y, sample_weight), validation_data = (
                data_adapter.train_validation_split(
                    (x, y, sample_weight),
                    validation_split=validation_split,
                    shuffle=False))

        with self.distribute_strategy.scope(
        ), training_utils.RespectCompiledTrainableState(self):
            # Creates a `tf.data.Dataset` and handles batch and epoch iteration.
            data_handler = WindowedDataHandler(
                x=x,
                y=y,
                sample_weight=sample_weight,
                batch_size=batch_size,
                steps_per_epoch=steps_per_epoch,
                initial_epoch=initial_epoch,
                epochs=epochs,
                shuffle=shuffle,
                class_weight=class_weight,
                max_queue_size=max_queue_size,
                workers=workers,
                use_multiprocessing=use_multiprocessing,
                model=self)

            # Container that configures and calls `tf.keras.Callback`s.
            if not isinstance(callbacks, callbacks_module.CallbackList):
                callbacks = callbacks_module.CallbackList(
                    callbacks,
                    add_history=True,
                    add_progbar=bool(verbose & Verbosity.Progress),
                    model=self,
                    verbose=verbose,
                    epochs=epochs,
                    steps=data_handler.inferred_steps)

            self.stop_training = False
            train_function = self.make_train_function()
            callbacks.on_train_begin()
            # Handle fault-tolerance for multi-worker.
            # TODO(omalleyt): Fix the ordering issues that mean this has to
            # happen after `callbacks.on_train_begin`.
            data_handler._initial_epoch = (
                self._maybe_load_initial_epoch_from_ckpt(initial_epoch))
            for epoch, window_iterator in data_handler.enumerate_epochs():
                self.reset_metrics()
                callbacks.on_epoch_begin(epoch)
                dataset = tf.data.Dataset.zip(next(window_iterator))
                switched = True
                weights = backend.batch_get_value(self.trainable_variables)
                while switched:
                    self.initialize_epoch(epoch)
                    iterator = iter(dataset)
                    with data_handler.catch_stop_iteration():
                        for step in data_handler.steps():
                            with traceme.TraceMe('TraceContext',
                                                 graph_type='train',
                                                 epoch_num=epoch,
                                                 step_num=step,
                                                 batch_size=batch_size):
                                callbacks.on_train_batch_begin(step)
                                tmp_logs = train_function(iterator)
                                # Catch OutOfRangeError for Datasets of unknown size.
                                # This blocks until the batch has finished executing.
                                # TODO(b/150292341): Allow multiple async steps here.
                                if not data_handler.inferred_steps:
                                    context.async_wait()
                                logs = tmp_logs  # No error, now safe to assign to logs.
                                callbacks.on_train_batch_end(step, logs)

                        switched = not self.update_and_switch(
                            epoch, dynamic_switch, verbose)
                        # If a switch occurred, we need to restore the weights
                        if switched:
                            backend.batch_set_value(
                                zip(self.trainable_variables, weights))
                            self.reset_metrics()

                epoch_logs = copy.copy(logs)

                if self.accumulate_gradients:
                    self.optimizer.apply_gradients(
                        zip(self.accumulated_gradients,
                            self.trainable_variables))

                # Run validation.
                if validation_data and self._should_eval(
                        epoch, validation_freq):
                    val_x, val_y, val_sample_weight = (
                        data_adapter.unpack_x_y_sample_weight(validation_data))
                    val_logs = self.evaluate(
                        x=val_x,
                        y=val_y,
                        sample_weight=val_sample_weight,
                        batch_size=validation_batch_size or batch_size,
                        steps=validation_steps,
                        callbacks=callbacks,
                        max_queue_size=max_queue_size,
                        workers=workers,
                        use_multiprocessing=use_multiprocessing,
                        return_dict=True)
                    val_logs = {
                        'val_' + name: val
                        for name, val in val_logs.items()
                    }
                    epoch_logs.update(val_logs)

                callbacks.on_epoch_end(epoch, epoch_logs)
                if self.stop_training:
                    break

            callbacks.on_train_end()
            return self.history
Example #16
 def reset_states(self):
     K.batch_set_value([(v, np.zeros(
         (self.num_thresholds, self.num_classes))) for v in self.variables])
Example #17
    def __init__(self, graph):
        import tensorflow
        from tensorflow.python.keras import backend as K
        import h5py
        self.graph = graph
        self.layers = []
        for layer in graph.layer_list:
            self.layers.append(to_real_tf_layer(layer))

        # Construct the keras graph.
        # Input
        topo_node_list = self.graph.topological_order
        output_id = topo_node_list[-1]
        input_id = topo_node_list[0]
        input_tensor = tensorflow.keras.layers.Input(
            shape=graph.node_list[input_id].shape)

        node_list = deepcopy(self.graph.node_list)
        node_list[input_id] = input_tensor

        # Output
        for v in topo_node_list:
            for u, layer_id in self.graph.reverse_adj_list[v]:
                layer = self.graph.layer_list[layer_id]
                tf_layer = self.layers[layer_id]

                if isinstance(layer, (StubAdd, StubConcatenate)):
                    edge_input_tensor = list(
                        map(
                            lambda x: node_list[x],
                            self.graph.layer_id_to_input_node_ids[layer_id],
                        ))
                else:
                    edge_input_tensor = node_list[u]

                temp_tensor = tf_layer(edge_input_tensor)
                node_list[v] = temp_tensor

        output_tensor = node_list[output_id]
        output_tensor = tensorflow.keras.layers.Activation(
            "softmax", name="activation_add")(output_tensor)
        self.model = tensorflow.keras.models.Model(inputs=input_tensor,
                                                   outputs=output_tensor)

        self.count = 0
        self.loadh5 = 0
        try:
            with h5py.File(
                    "/userhome/resnet50_weights_tf_dim_ordering_tf_kernels.h5",
                    'r') as f:
                layer_names = self.load_attributes_from_hdf5_group(
                    f, 'layer_names')
                filtered_layer_names = []
                for name in layer_names:
                    g = f[name]
                    weight_names = self.load_attributes_from_hdf5_group(
                        g, 'weight_names')
                    if weight_names:
                        filtered_layer_names.append(name)
                layer_names = filtered_layer_names  # 107 layers
                try:
                    for k, name in enumerate(layer_names):
                        g = f[name]
                        weight_names = self.load_attributes_from_hdf5_group(
                            g, 'weight_names')
                        weight_values = [
                            np.asarray(g[weight_name])
                            for weight_name in weight_names
                        ]
                        while not self.legacy_weights():
                            self.count += 1
                        symbolic_weights = self.legacy_weights()
                        weight_value_tuples = zip(symbolic_weights,
                                                  weight_values)
                        try:
                            K.batch_set_value(weight_value_tuples)
                            self.loadh5 += 1
                        except Exception as E:
                            continue
                        self.count += 1
                except Exception as E:
                    self.loadh5 += 0
            print("############## Loading initial weights for " +
                  str(self.loadh5) + " layers.")
        except Exception as E:
            print(E)
Example #18
 def reset_states(self):
     num_thresholds = len(to_list(self.thresholds))
     K.batch_set_value([(v, np.zeros((num_thresholds, )))
                        for v in self.variables])
Example #19
def load_pytorch_weights_in_tf2_model(tf_model,
                                      pt_state_dict,
                                      tf_inputs=None,
                                      allow_missing_keys=False):
    """ Load pytorch state_dict in a TF 2.0 model.
    """
    try:
        import torch  # noqa: F401
        import tensorflow as tf  # noqa: F401
        from tensorflow.python.keras import backend as K
    except ImportError as e:
        logger.error(
            "Loading a PyTorch model in TensorFlow, requires both PyTorch and TensorFlow to be installed. Please see "
            "https://pytorch.org/ and https://www.tensorflow.org/install/ for installation instructions."
        )
        raise e

    if tf_inputs is None:
        tf_inputs = tf_model.dummy_inputs

    if tf_inputs is not None:
        tfo = tf_model(tf_inputs, training=False)  # Make sure model is built

    # Adapt state dict - TODO remove this and update the AWS weights files instead
    # Convert old format to new format if needed from a PyTorch state_dict
    old_keys = []
    new_keys = []
    for key in pt_state_dict.keys():
        new_key = None
        if "gamma" in key:
            new_key = key.replace("gamma", "weight")
        if "beta" in key:
            new_key = key.replace("beta", "bias")
        if new_key:
            old_keys.append(key)
            new_keys.append(new_key)
    for old_key, new_key in zip(old_keys, new_keys):
        pt_state_dict[new_key] = pt_state_dict.pop(old_key)

    # Make sure we are able to load PyTorch base models as well as derived models (with heads)
    # TF models always have a prefix, some of PyTorch models (base ones) don't
    start_prefix_to_remove = ""
    if not any(
            s.startswith(tf_model.base_model_prefix)
            for s in pt_state_dict.keys()):
        start_prefix_to_remove = tf_model.base_model_prefix + "."

    symbolic_weights = tf_model.trainable_weights + tf_model.non_trainable_weights
    tf_loaded_numel = 0
    weight_value_tuples = []
    all_pytorch_weights = set(list(pt_state_dict.keys()))
    for symbolic_weight in symbolic_weights:
        sw_name = symbolic_weight.name
        name, transpose = convert_tf_weight_name_to_pt_weight_name(
            sw_name, start_prefix_to_remove=start_prefix_to_remove)

        # Find associated numpy array in pytorch model state dict
        if name not in pt_state_dict:
            if allow_missing_keys:
                continue
            raise AttributeError("{} not found in PyTorch model".format(name))

        array = pt_state_dict[name].numpy()

        if transpose:
            array = numpy.transpose(array)

        if len(symbolic_weight.shape) < len(array.shape):
            array = numpy.squeeze(array)
        elif len(symbolic_weight.shape) > len(array.shape):
            array = numpy.expand_dims(array, axis=0)

        try:
            assert list(symbolic_weight.shape) == list(array.shape)
        except AssertionError as e:
            e.args += (symbolic_weight.shape, array.shape)
            raise e

        tf_loaded_numel += array.size
        # logger.warning("Initialize TF weight {}".format(symbolic_weight.name))

        weight_value_tuples.append((symbolic_weight, array))
        all_pytorch_weights.discard(name)

    K.batch_set_value(weight_value_tuples)

    if tf_inputs is not None:
        tfo = tf_model(tf_inputs,
                       training=False)  # Make sure restore ops are run

    logger.info(
        "Loaded {:,} parameters in the TF 2.0 model.".format(tf_loaded_numel))

    logger.info("Weights or buffers not loaded from PyTorch model: {}".format(
        all_pytorch_weights))

    return tf_model
Example #20
 def reset_states(self):
     K.batch_set_value([(v, np.zeros(shape=v.get_shape()))
                        for v in self.variables])
Example #21
 def reset_states(self):
     K.batch_set_value([(self.variables[0],
                         tf.zeros((self.n_classes, self.n_classes)))])
Example #22
    def fit(self,
            x=None,
            y=None,
            batch_size=None,
            epochs=1,
            verbose=1,
            auto_switch=True,
            retry_fit=True,
            absorb=True,
            train_after_switch=True,
            callbacks=None,
            validation_split=0.,
            validation_data=None,
            shuffle=True,
            class_weight=None,
            sample_weight=None,
            initial_epoch=0,
            steps_per_epoch=None,
            validation_steps=None,
            validation_batch_size=None,
            validation_freq=1,
            max_queue_size=10,
            workers=1,
            use_multiprocessing=False,
            revert_after_fit=False):
        """
        Custom fit function for the context model

        auto_switch:        Enable/disable autonomous context switching
        train_after_switch:
        retry_fit:          Locate the next fitting context by re-performing fit.
        absorb:             Reset the switch sequence counter upon successful training.
                            This is mainly used to maintain switch sequencing for temporally-extended tasks
        revert_after_fit    This is a debug parameter to revert weights after performing a fit. This is used
                            to calculate the context deltas without incorrectly learning while auto switching
                            is disabled
        """

        training._keras_api_gauge.get_cell('fit').set(True)
        # Legacy graph support is contained in `training_v1.Model`.
        version_utils.disallow_legacy_graph('Model', 'fit')
        self._assert_compile_was_called()
        self._check_call_args('fit')

        if validation_split:
            # Create the validation data using the training data. Only supported for
            # `Tensor` and `NumPy` input.
            (x, y, sample_weight), validation_data = (
                data_adapter.train_validation_split(
                    (x, y, sample_weight),
                    validation_split=validation_split,
                    shuffle=False))

        with self.distribute_strategy.scope(
        ), training_utils.RespectCompiledTrainableState(self):
            # Creates a `tf.data.Dataset` and handles batch and epoch iteration.
            data_handler = WindowedDataHandler(
                x=x,
                y=y,
                sample_weight=sample_weight,
                batch_size=batch_size,
                steps_per_epoch=steps_per_epoch,
                initial_epoch=initial_epoch,
                epochs=epochs,
                shuffle=shuffle,
                class_weight=class_weight,
                max_queue_size=max_queue_size,
                workers=workers,
                use_multiprocessing=use_multiprocessing,
                model=self)

            # Container that configures and calls `tf.keras.Callback`s.
            if not isinstance(callbacks, callbacks_module.CallbackList):
                callbacks = callbacks_module.CallbackList(
                    callbacks,
                    add_history=True,
                    add_progbar=bool(verbose & Verbosity.Progress),
                    model=self,
                    verbose=verbose,
                    epochs=epochs,
                    steps=data_handler.inferred_steps)

            self.stop_training = False
            train_function = self.make_train_function()
            callbacks.on_train_begin()
            self.initialize_fit()
            # Handle fault-tolerance for multi-worker.
            # TODO(omalleyt): Fix the ordering issues that mean this has to
            # happen after `callbacks.on_train_begin`.
            data_handler._initial_epoch = (
                self._maybe_load_initial_epoch_from_ckpt(initial_epoch))
            for epoch, window_iterator in data_handler.enumerate_epochs():
                self.reset_metrics()
                callbacks.on_epoch_begin(epoch)
                dataset = tf.data.Dataset.zip(next(window_iterator))
                switched_during_epoch = False  # Indicate if the model has attempted at least one switch during this epoch
                switched = True  # Indicate if the model switched on the most recent fit iteration
                weights = backend.batch_get_value(self.trainable_variables)
                # Perform a 'fit call'. Assuming retry_fit, this call is re-attempted after each switch until a context fits
                while switched and (retry_fit or not switched_during_epoch):
                    self.initialize_epoch(epoch)
                    iterator = iter(dataset)

                    # Perform a fit call
                    with data_handler.catch_stop_iteration():
                        for step in data_handler.steps():
                            with traceme.TraceMe('TraceContext',
                                                 graph_type='train',
                                                 epoch_num=epoch,
                                                 step_num=step,
                                                 batch_size=batch_size):
                                callbacks.on_train_batch_begin(step)
                                tmp_logs = train_function(iterator)
                                # Catch OutOfRangeError for Datasets of unknown size.
                                # This blocks until the batch has finished executing.
                                # TODO(b/150292341): Allow multiple async steps here.
                                if not data_handler.inferred_steps:
                                    context.async_wait()
                                logs = tmp_logs  # No error, now safe to assign to logs.
                                callbacks.on_train_batch_end(step, logs)

                        switched = not self.update_and_switch(
                            epoch, auto_switch, absorb, retry_fit, verbose)
                        switched_during_epoch |= switched

                        # If a switch occurred, we need to restore the weights
                        if switched or (switched_during_epoch
                                        and not train_after_switch
                                        ) or revert_after_fit:
                            backend.batch_set_value(
                                zip(self.trainable_variables, weights))
                            self.reset_metrics()

                epoch_logs = copy.copy(logs)

                # Run validation.
                if validation_data and self._should_eval(
                        epoch, validation_freq):
                    val_x, val_y, val_sample_weight = (
                        data_adapter.unpack_x_y_sample_weight(validation_data))
                    val_logs = self.evaluate(
                        x=val_x,
                        y=val_y,
                        sample_weight=val_sample_weight,
                        batch_size=validation_batch_size or batch_size,
                        steps=validation_steps,
                        callbacks=callbacks,
                        max_queue_size=max_queue_size,
                        workers=workers,
                        use_multiprocessing=use_multiprocessing,
                        return_dict=True)
                    val_logs = {
                        'val_' + name: val
                        for name, val in val_logs.items()
                    }
                    epoch_logs.update(val_logs)

                callbacks.on_epoch_end(epoch, epoch_logs)
                if self.stop_training:
                    break

            callbacks.on_train_end()
            return self.history
Example #23
def load_weights_from_hdf5_group(f, fh_dict, layers):
    """Implements topological (order-based) weight loading.
    This is revised version. We split the attributes of HDF5 group into another
    JSON file to avoid the heading memory excessing problem. Compared to original
    Keras API, we need to load an extra file IO handle, fh_dict.
    In the same time, the keras_version and backend infomation should be provided
    by fh_dict directly.
    Arguments:
        f:       a pointer to a HDF5 group.
        fh_dict: JSON config dictionary.
        layers:  a list of target layers.
    Raises:
        ValueError: in case of mismatch between provided layers
            and weights file.
    """
    original_keras_version = fh_dict.get('keras_version', '1')
    original_backend = fh_dict.get('backend', None)

    filtered_layers = []
    for layer in layers:
        weights = layer.weights
        if weights:
            filtered_layers.append(layer)

    layer_names = load_attributes_from_hdf5_group(fh_dict, f.name,
                                                  'layer_names')
    filtered_layer_names = []
    for name in layer_names:
        g = f[name]
        weight_names = load_attributes_from_hdf5_group(fh_dict, g.name,
                                                       'weight_names')
        if weight_names:
            filtered_layer_names.append(name)
    layer_names = filtered_layer_names
    if len(layer_names) != len(filtered_layers):
        raise ValueError('You are trying to load a weight file '
                         'containing ' + str(len(layer_names)) +
                         ' layers into a model with ' +
                         str(len(filtered_layers)) + ' layers.')

    # We batch weight value assignments in a single backend call
    # which provides a speedup in TensorFlow.
    weight_value_tuples = []
    for k, name in enumerate(layer_names):
        g = f[name]
        weight_names = load_attributes_from_hdf5_group(fh_dict, g.name,
                                                       'weight_names')
        weight_values = [
            np.asarray(g[weight_name]) for weight_name in weight_names
        ]
        layer = filtered_layers[k]
        symbolic_weights = layer.weights
        weight_values = preprocess_weights_for_loading(layer, weight_values,
                                                       original_keras_version,
                                                       original_backend)
        if len(weight_values) != len(symbolic_weights):
            raise ValueError('Layer #' + str(k) + ' (named "' + layer.name +
                             '" in the current model) was found to '
                             'correspond to layer ' + name +
                             ' in the save file. '
                             'However the new layer ' + layer.name +
                             ' expects ' + str(len(symbolic_weights)) +
                             ' weights, but the saved weights have ' +
                             str(len(weight_values)) + ' elements.')
        weight_value_tuples += zip(symbolic_weights, weight_values)
    K.batch_set_value(weight_value_tuples)