Exemplo n.º 1
0
    def init_model(self, session, checkpoint_file):
        """
        Load pretrained parameters, then initialize the top layers.

        If ``checkpoint_file`` ends with ``.t7`` (case-insensitive) the
        parameters are read through ``soundnet.soundnet5_model_params`` and
        assigned by value; otherwise they are restored from a checkpoint,
        skipping every variable matching conv5, conv6, fc1 or fc2.
        In both cases the conv5, conv6, fc1 and fc2 variables under
        ``self.scope`` are run through ``tf.variables_initializer`` afterwards.

        Args:
            session: session used to run the assignment/initialization ops.
            checkpoint_file: path to a ``.t7`` file or to a checkpoint.
        """
        loads_from_t7 = checkpoint_file.lower().endswith('.t7')

        if not loads_from_t7:
            # Checkpoint: restore everything in scope except the top layers
            restorable = slim.filter_variables(
                slim.get_variables(self.scope),
                exclude_patterns=['conv5', 'conv6', 'fc1', 'fc2'])
            restore_fn = slim.assign_from_checkpoint_fn(checkpoint_file,
                                                        restorable)
        else:
            # t7: read the parameters into a dict and assign them by value
            t7_values = soundnet.soundnet5_model_params(
                model_filename=checkpoint_file,
                num_classes=self.num_classes,
                scope=self.scope)
            restore_fn = slim.assign_from_values_fn(t7_values)
        restore_fn(session)

        # Collect conv5/conv6 (all variables) and fc1/fc2 (model variables)
        # and run their initializers.
        top_layer_vars = []
        for layer in ('conv5', 'conv6'):
            top_layer_vars.extend(slim.get_variables(self.scope + '/' + layer))
        for layer in ('fc1', 'fc2'):
            top_layer_vars.extend(
                slim.get_model_variables(self.scope + '/' + layer))
        session.run(tf.variables_initializer(top_layer_vars))
Exemplo n.º 2
0
def get_init_fn(npy_path, return_op=False):
    """Build a slim initializer from a pickled dict of pretrained weights.

    Every model variable whose name starts with ``'layer'`` or ``'seg'`` is
    translated to a key of the pickled dict through a fixed, ordered list of
    string replacements. Variables without a matching key, or whose shape
    fails the sanity check, are reported on stdout and skipped.

    Args:
        npy_path: path to a file containing a pickled mapping from parameter
            name to array-like value (read with ``pickle.load``).
        return_op: if True return ``slim.assign_from_values(load_dict)``;
            otherwise return ``slim.assign_from_values_fn(load_dict)``.

    Returns:
        The result of the slim call selected by ``return_op``.
    """
    # Imported at this level because both the nested loader and the loop
    # below use them; previously `np` was only in scope inside the loader.
    import pickle

    import numpy as np

    def load_tf_model_from_npy(save_path):
        """Unpickle the parameter dict at ``save_path`` and print its shapes."""
        # NOTE(security): pickle.load can execute arbitrary code; only use
        # this with weight files from a trusted source.
        with open(save_path, 'rb') as f:
            key_value = pickle.load(f)
            for key in sorted(key_value.keys()):
                print(key, np.shape(key_value[key]))
            return key_value

    pre_trained_model_paras = load_tf_model_from_npy(npy_path)

    # Applied strictly in this order: later rules rely on earlier ones
    # (e.g. 'weights' -> 'weight' runs before '/' -> '.').
    name_rewrites = (
        ('Block', ''),
        ('layer', 'base.'),
        (':0', ''),
        ('weights', 'weight'),
        ('gamma', 'weight'),
        ('beta', 'bias'),
        ('moving_mean', 'running_mean'),
        ('moving_variance', 'running_var'),
        ('/', '.'),
        ('CONV', ''),
        ('BN', ''),
        ('seg.', 'seg'),
        ('biases', 'bias'),
    )

    load_dict = {}
    vars_list = tf.model_variables()
    for var in vars_list:
        print(var)
    for var in vars_list:
        vname = str(var.name)
        if not vname.startswith(('layer', 'seg')):
            print('ignore ', vname)
            continue
        from_name = vname
        for old, new in name_rewrites:
            from_name = from_name.replace(old, new)

        # Only a missing key is an expected, skippable condition; any other
        # error is a real bug and must not be swallowed.
        try:
            var_value = pre_trained_model_paras[from_name]
        except KeyError:
            print('Skip, ', vname, from_name)
            continue

        var_shape = var.get_shape().as_list()
        from_shape = np.shape(var_value)
        # Loose check: compares the sum of the dimensions, not the exact
        # shape, so permuted layouts pass.
        # NOTE(review): presumably intentional to tolerate a different axis
        # order in the source weights - confirm.
        if np.sum(var_shape) != np.sum(from_shape):
            print('Shape not equal! ', vname, var_shape, '<---', from_name,
                  from_shape)
            continue
        print(vname, '<---', from_name)
        load_dict[vname] = var_value

    if return_op:
        return slim.assign_from_values(load_dict)
    return slim.assign_from_values_fn(load_dict)
Exemplo n.º 3
0
def get_special_assigns(special_assign_vars):
  """Build an init function assigning variables from HDF5 files.

  Args:
    special_assign_vars: comma-separated string alternating variable names
      and file paths, i.e. ``'var_a,path_a,var_b,path_b'``. Each file must
      contain a dataset called ``'feat'``. A trailing name without a path
      is ignored.

  Returns:
    The function returned by ``slim.assign_from_values_fn`` for the loaded
    ``{variable name: array}`` mapping.
  """
  init_wts = {}
  tokens = special_assign_vars.split(',')
  # Pair up (name, path); slicing avoids the float produced by `len / 2`
  # on Python 3, which range() rejects.
  for var_name, file_path in zip(tokens[0::2], tokens[1::2]):
    with h5py.File(file_path, 'r') as fin:
      # `[()]` reads the whole dataset; `Dataset.value` was removed in
      # h5py 3.0.
      init_wts[var_name] = fin['feat'][()]
    logging.info('Special Assign: %s with a %s array' % (
      var_name, init_wts[var_name].shape))
  return slim.assign_from_values_fn(init_wts)
Exemplo n.º 4
0
def get_special_assigns(special_assign_vars):
  """Build an init function assigning variables from HDF5 files.

  Args:
    special_assign_vars: comma-separated string alternating variable names
      and file paths, i.e. ``'var_a,path_a,var_b,path_b'``. Each file must
      contain a dataset called ``'feat'``. A trailing name without a path
      is ignored.

  Returns:
    The function returned by ``slim.assign_from_values_fn`` for the loaded
    ``{variable name: array}`` mapping.
  """
  init_wts = {}
  tokens = special_assign_vars.split(',')
  # Pair up (name, path); slicing avoids the float produced by `len / 2`
  # on Python 3, which range() rejects.
  for var_name, file_path in zip(tokens[0::2], tokens[1::2]):
    with h5py.File(file_path, 'r') as fin:
      # `[()]` reads the whole dataset; `Dataset.value` was removed in
      # h5py 3.0.
      init_wts[var_name] = fin['feat'][()]
    logging.info('Special Assign: %s with a %s array' % (
      var_name, init_wts[var_name].shape))
  return slim.assign_from_values_fn(init_wts)
def load_caffe_weights(weights_path):
    """Initialize the network parameters from a .npy caffe weights file.

    The file must hold a pickled dict whose keys are ``'<layer>_w'`` and
    ``'<layer>_b'``; each pair is mapped to the ``osvos/.../weights`` and
    ``osvos/.../biases`` variables.

    Args:
        weights_path: path to the .npy file containing the value of the
            network parameters.

    Returns:
        Function that takes a session and initializes the network.

    Raises:
        KeyError: if an expected layer entry is missing from the file.
    """
    # The dict is stored as a pickled 0-d object array, which NumPy
    # (>= 1.16.3) refuses to load unless allow_pickle is set explicitly.
    # NOTE(security): unpickling runs arbitrary code; only load trusted files.
    osvos_weights = dict(np.load(weights_path, allow_pickle=True).item())

    # (variable scope below 'osvos/', key prefix in the weights file)
    layer_map = []
    # Backbone: conv1 and conv2 have two convolutions, conv3-conv5 have three.
    for block, num_convs in enumerate((2, 2, 3, 3, 3), start=1):
        for conv in range(1, num_convs + 1):
            layer = 'conv%d_%d' % (block, conv)
            layer_map.append(('conv%d/%s' % (block, layer), layer))
    # These layers use the same name in the graph and in the file.
    for layer in ('conv2_2_16', 'conv3_3_16', 'conv4_3_16', 'conv5_3_16',
                  'score-dsn_2', 'score-dsn_3', 'score-dsn_4', 'score-dsn_5'):
        layer_map.append((layer, layer))
    # The fusion layer is stored under a different name in the weights file.
    layer_map.append(('upscore-fuse', 'new-score-weighting'))

    vars_corresp = dict()
    for scope, key in layer_map:
        vars_corresp['osvos/%s/weights' % scope] = osvos_weights[key + '_w']
        vars_corresp['osvos/%s/biases' % scope] = osvos_weights[key + '_b']
    return slim.assign_from_values_fn(vars_corresp)
Exemplo n.º 6
0
def restore_model(checkpoint_paths,
                  variables_to_restore,
                  ignore_missing_vars=False,
                  num_streams=1,
                  checkpoint_style=None,
                  special_assign_vars=None):
    """Build a function that restores a (multi-stream) model's variables.

    Args:
        checkpoint_paths: sequence of checkpoint paths. With a single path
            and ``num_streams > 1`` it is treated as a checkpoint of the
            whole multi-stream network; otherwise it is indexed by stream id.
            Paths ending in ``.npy`` are loaded as a pickled nested dict
            ``{key: {subkey: array}}``.
        variables_to_restore: variables that may be restored.
        ignore_missing_vars: forwarded to ``slim.assign_from_checkpoint_fn``.
        num_streams: number of streams; stream ``i`` owns the model variables
            under the ``'stream<i>/'`` scope.
        checkpoint_style: optional comma-separated string, one entry per
            stream. ``'v2_withStream'`` means the checkpoint already stores
            its variables under the ``'stream0/'`` prefix.
        special_assign_vars: optional string forwarded to
            ``get_special_assigns``.

    Returns:
        A function taking a session that runs every restore step in order.
    """
    all_ops = []
    if len(checkpoint_paths) == 1 and num_streams > 1:
        logging.info('Provided one checkpoint for multi-stream '
                     'network. Will use this as a saved model '
                     'with this exact multi stream network.')
        all_ops.append(slim.assign_from_checkpoint_fn(
            checkpoint_paths[0],
            variables_to_restore,
            ignore_missing_vars=ignore_missing_vars))
    else:
        for sid in range(num_streams):
            this_checkpoint_style = (
                checkpoint_style.split(',')[sid]
                if checkpoint_style is not None else None)
            checkpoint_path = checkpoint_paths[sid]
            this_stream_name = 'stream%d/' % sid
            # Looked up once per stream instead of once per candidate variable.
            stream_model_vars = slim.get_model_variables(this_stream_name)
            this_checkpoint_variables = [
                var for var in variables_to_restore
                if var in stream_model_vars]

            if checkpoint_path.endswith('.npy'):
                vars_to_restore_names = [
                    el.name for el in this_checkpoint_variables]
                wanted_names = set(vars_to_restore_names)
                key_name_mapper = var_name_mapper.map()
                # The nested dict is pickled inside the .npy file; NumPy
                # >= 1.16.3 needs allow_pickle=True to read it.
                # NOTE(security): only load trusted files.
                init_weights = np.load(
                    checkpoint_path, allow_pickle=True).item()
                # Any model trained with a stream prefix has that stream
                # stored as stream0.
                # NOTE(review): for sid > 0 this prefix cannot match the
                # 'stream<sid>/' names in wanted_names - confirm intent.
                prefix = ('stream0/'
                          if this_checkpoint_style == 'v2_withStream'
                          else this_stream_name)
                init_weights_final = {}
                for key, sub_weights in init_weights.items():
                    for subkey, pretrained_wts in sub_weights.items():
                        final_key_name = prefix + key_name_mapper(
                            key + '/' + subkey)
                        if final_key_name not in wanted_names:
                            logging.error(
                                'Not using %s from npy' % final_key_name)
                            continue

                        target_shape = slim.get_model_variables(
                            final_key_name)[0].get_shape().as_list()
                        # Compare shapes ignoring singleton dimensions. Plain
                        # lists are used so differing ranks compare safely,
                        # and ANY differing dimension counts as a mismatch.
                        target_shape_squeezed = [
                            d for d in target_shape if d != 1]
                        pretrained_shape_squeezed = [
                            d for d in pretrained_wts.shape if d != 1]
                        if target_shape_squeezed != pretrained_shape_squeezed:
                            logging.error(
                                'Shape mismatch var: %s from npy [%s vs %s]'
                                % (final_key_name, target_shape,
                                   pretrained_wts.shape))

                        # As before, the weights are still handed to slim
                        # after a mismatch; only the diagnostic is fixed.
                        init_weights_final[final_key_name] = pretrained_wts
                for v in vars_to_restore_names:
                    if v not in init_weights_final:
                        logging.fatal('No weights found for %s' % v)
                all_ops.append(
                    slim.assign_from_values_fn(init_weights_final))
            elif this_checkpoint_style != 'v2_withStream':
                # Strip the stream name so graph variables map onto the
                # un-prefixed names stored in the checkpoint.
                all_ops.append(slim.assign_from_checkpoint_fn(
                    checkpoint_path,
                    {'/'.join(el.name.split('/')[1:]).split(':')[0]: el
                     for el in this_checkpoint_variables},
                    ignore_missing_vars=ignore_missing_vars))
            else:
                # v2_withStream checkpoints were trained with the stream0/
                # prefix, so replace this stream's prefix with stream0.
                all_ops.append(slim.assign_from_checkpoint_fn(
                    checkpoint_path,
                    {'/'.join(['stream0'] +
                              el.name.split('/')[1:]).split(':')[0]: el
                     for el in this_checkpoint_variables},
                    ignore_missing_vars=ignore_missing_vars))
    if special_assign_vars is not None:
        all_ops.append(get_special_assigns(special_assign_vars))

    def combined(sess):
        """Run every collected restore function on ``sess``."""
        for op in all_ops:
            op(sess)
    return combined
Exemplo n.º 7
0
def assign_from_checkpoint_fn(model_path, var_list, ignore_missing_vars=False,
                              reshape_variables=False, resize_variables=False):
  """Returns a function that assigns specific variables from a checkpoint.

  Modified function from
  https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/framework/python/ops/variables.py
  Mod by rgirdhar to allow for repeating the channels dimension in case a layer
  does not match. It's useful for setting the first layer in flow models for
  videos. Does this only when resize_variables is True.

  Args:
    model_path: The full path to the model checkpoint. To get latest checkpoint
        use `model_path = tf.train.latest_checkpoint(checkpoint_dir)`
    var_list: A list of `Variable` objects or a dictionary mapping names in the
        checkpoint to the corresponding variables to initialize.
    ignore_missing_vars: Boolean, if True it would ignore variables missing in
        the checkpoint with a warning instead of failing.
    reshape_variables: Unused by this implementation; kept so the signature
        stays compatible with the original slim function.
    resize_variables: Boolean, if True a checkpoint tensor whose shape differs
        from the target is averaged over its second-to-last axis and repeated
        along that axis to the target's size.

  Returns:
    The function returned by `slim.assign_from_values_fn` for every variable
    whose checkpoint tensor matched the target shape, directly or after
    resizing. It takes a single argument, a `tf.Session`. Variables with an
    unresolved shape mismatch are logged and left out.

  Raises:
    ValueError: If var_list is empty, or if a variable is missing from the
        checkpoint and `ignore_missing_vars` is False.
  """
  if not var_list:
    raise ValueError('var_list cannot be empty')
  reader = pywrap_tensorflow.NewCheckpointReader(model_path)
  if isinstance(var_list, dict):
    var_dict = var_list
  else:
    var_dict = {var.op.name: var for var in var_list}

  available_vars = {}
  for ckpt_name, variable in var_dict.items():
    if not reader.has_tensor(ckpt_name):
      logging.warning(
          'Variable %s missing in checkpoint %s', ckpt_name, model_path)
      if not ignore_missing_vars:
        raise ValueError('Variable %s missing in checkpoint %s' %
                         (ckpt_name, model_path))
      continue

    value = reader.get_tensor(ckpt_name)
    target_shape = variable.get_shape().as_list()
    if list(value.shape) != target_shape:
      if not resize_variables:
        logging.error('Found a shape mismatch. Set resize_var to true to '
                      'do a hacky shape copy.')
        continue
      logging.warning('Resizing to assign to variable {} to {} from {}'.format(
          ckpt_name, target_shape, value.shape))
      # The channel trick needs a second-to-last axis on both sides; without
      # this guard a mismatching bias would raise instead of being skipped.
      if value.ndim >= 2 and len(target_shape) >= 2:
        value = np.repeat(
            np.mean(value, axis=-2, keepdims=True),
            repeats=target_shape[-2],
            axis=-2)
      if list(value.shape) != target_shape:
        logging.error('Was not able to match shape, not restoring it!!!')
        continue
      logging.info('Was able to match shape, so restoring the var :-)')

    # slim.assign_from_values_fn looks variables up by their graph name, which
    # differs from the checkpoint name when var_list is a renaming dict.
    available_vars[variable.op.name] = value
  return slim.assign_from_values_fn(available_vars)
Exemplo n.º 8
0
def restore_model(checkpoint_path,
                  variables_to_restore,
                  ignore_missing_vars=False,
                  var_name_mapper_type=None):
  """Build a function that restores variables from a .npy file or checkpoint.

  For a path ending in ``.npy`` the file is read as a pickled nested dict
  ``{key: {subkey: array}}``; ``key + '/' + subkey`` is translated with
  ``var_name_mapper.map(var_name_mapper_type)`` and matched against the names
  of ``variables_to_restore``. Shapes are compared ignoring singleton
  dimensions; on a mismatch in the second-to-last dimension the weights are
  averaged over that axis and repeated to the target size. Any other path is
  delegated to ``assign_from_checkpoint_fn`` with ``resize_variables=True``.

  Args:
    checkpoint_path: path to a ``.npy`` weights file or a checkpoint.
    variables_to_restore: variables to restore.
    ignore_missing_vars: if False, a variable without usable weights in the
      ``.npy`` file raises ``ValueError``; also forwarded to
      ``assign_from_checkpoint_fn``.
    var_name_mapper_type: forwarded to ``var_name_mapper.map``.

  Returns:
    A function taking a session that performs the restore.

  Raises:
    ValueError: if weights are missing and ``ignore_missing_vars`` is False.
  """
  def squeezed(shape):
    """Return ``shape`` as a list without its singleton dimensions."""
    return [dim for dim in shape if dim != 1]

  all_ops = []
  checkpoint_variables = variables_to_restore
  if checkpoint_path.endswith('.npy'):
    vars_to_restore_names = [el.name for el in checkpoint_variables]
    wanted_names = set(vars_to_restore_names)
    key_name_mapper = var_name_mapper.map(var_name_mapper_type)
    # The nested dict is pickled inside the .npy file; NumPy >= 1.16.3 needs
    # allow_pickle=True to read it. NOTE(security): only load trusted files.
    init_weights = np.load(checkpoint_path, allow_pickle=True).item()
    init_weights_final = {}
    for key, sub_weights in init_weights.items():
      for subkey, stored_wts in sub_weights.items():
        final_key_name = key_name_mapper(key + '/' + subkey)
        if final_key_name not in wanted_names:
          logging.info('Not using %s from npy' % final_key_name)
          continue
        target_shape = slim.get_model_variables(
          final_key_name)[0].get_shape().as_list()
        pretrained_wts = stored_wts.copy()
        # Plain lists so shapes of different rank compare safely.
        target_shape_squeezed = squeezed(target_shape)
        pretrained_shape_squeezed = squeezed(pretrained_wts.shape)

        # whether or not these weights can be copied
        go_ahead = target_shape_squeezed == pretrained_shape_squeezed
        if not go_ahead:
          logging.info('Shape mismatch var: %s from npy [%s vs %s]. ' % (
                       final_key_name, target_shape,
                       pretrained_wts.shape))
          # The channel trick needs a second-to-last dimension on both sides;
          # without the length guard a mismatching 1-D tensor would raise.
          if (len(target_shape_squeezed) >= 2 and
              len(pretrained_shape_squeezed) >= 2 and
              pretrained_shape_squeezed[-2] != target_shape_squeezed[-2]):
            logging.info('Trying repeating channels to make it similar.')
            pretrained_wts = np.repeat(
              np.mean(pretrained_wts, axis=-2, keepdims=True),
              repeats=target_shape_squeezed[-2],
              axis=-2)
            # Compare squeezed against squeezed, like the first check.
            if squeezed(pretrained_wts.shape) == target_shape_squeezed:
              logging.info('Success! Copying the hacked weights.')
              go_ahead = True
            else:
              logging.info('Couldnot match the weights still.')
        if go_ahead:
          init_weights_final[final_key_name] = pretrained_wts
    for v in vars_to_restore_names:
      if v not in init_weights_final:
        logging.fatal('No weights found for %s' % v)
        if not ignore_missing_vars:
          raise ValueError('No weights found for %s' % v)
    all_ops.append(slim.assign_from_values_fn(init_weights_final))
  else:
    all_ops.append(assign_from_checkpoint_fn(
      checkpoint_path, checkpoint_variables,
      ignore_missing_vars=ignore_missing_vars,
      resize_variables=True))

  def combined(sess):
    """Run every collected restore function on ``sess``."""
    for op in all_ops:
      op(sess)
  return combined
Exemplo n.º 9
0
def restore_model(checkpoint_paths,
                  variables_to_restore,
                  ignore_missing_vars=False,
                  num_streams=1,
                  checkpoint_style=None,
                  special_assign_vars=None):
    """Build a function that restores a (multi-stream) model's variables.

    Args:
        checkpoint_paths: sequence of checkpoint paths. With a single path
            and ``num_streams > 1`` it is treated as a checkpoint of the
            whole multi-stream network; otherwise it is indexed by stream id.
            Paths ending in ``.npy`` are loaded as a pickled nested dict
            ``{key: {subkey: array}}``.
        variables_to_restore: variables that may be restored.
        ignore_missing_vars: forwarded to ``slim.assign_from_checkpoint_fn``.
        num_streams: number of streams; stream ``i`` owns the model variables
            under the ``'stream<i>/'`` scope.
        checkpoint_style: optional comma-separated string, one entry per
            stream. ``'v2_withStream'`` means the checkpoint already stores
            its variables under the ``'stream0/'`` prefix.
        special_assign_vars: optional string forwarded to
            ``get_special_assigns``.

    Returns:
        A function taking a session that runs every restore step in order.
    """
    all_ops = []
    if len(checkpoint_paths) == 1 and num_streams > 1:
        logging.info('Provided one checkpoint for multi-stream '
                     'network. Will use this as a saved model '
                     'with this exact multi stream network.')
        all_ops.append(
            slim.assign_from_checkpoint_fn(
                checkpoint_paths[0],
                variables_to_restore,
                ignore_missing_vars=ignore_missing_vars))
    else:
        for sid in range(num_streams):
            this_checkpoint_style = (
                checkpoint_style.split(',')[sid]
                if checkpoint_style is not None else None)
            checkpoint_path = checkpoint_paths[sid]
            this_stream_name = 'stream%d/' % sid
            # Looked up once per stream instead of once per candidate variable.
            stream_model_vars = slim.get_model_variables(this_stream_name)
            this_checkpoint_variables = [
                var for var in variables_to_restore
                if var in stream_model_vars
            ]

            if checkpoint_path.endswith('.npy'):
                vars_to_restore_names = [
                    el.name for el in this_checkpoint_variables
                ]
                wanted_names = set(vars_to_restore_names)
                key_name_mapper = var_name_mapper.map()
                # The nested dict is pickled inside the .npy file; NumPy
                # >= 1.16.3 needs allow_pickle=True to read it.
                # NOTE(security): only load trusted files.
                init_weights = np.load(
                    checkpoint_path, allow_pickle=True).item()
                # Any model trained with a stream prefix has that stream
                # stored as stream0.
                # NOTE(review): for sid > 0 this prefix cannot match the
                # 'stream<sid>/' names in wanted_names - confirm intent.
                prefix = ('stream0/'
                          if this_checkpoint_style == 'v2_withStream'
                          else this_stream_name)
                init_weights_final = {}
                for key, sub_weights in init_weights.items():
                    for subkey, pretrained_wts in sub_weights.items():
                        final_key_name = prefix + key_name_mapper(
                            key + '/' + subkey)
                        if final_key_name not in wanted_names:
                            logging.error('Not using %s from npy' %
                                          final_key_name)
                            continue

                        target_shape = slim.get_model_variables(
                            final_key_name)[0].get_shape().as_list()
                        # Compare shapes ignoring singleton dimensions. Plain
                        # lists are used so differing ranks compare safely,
                        # and ANY differing dimension counts as a mismatch.
                        target_shape_squeezed = [
                            d for d in target_shape if d != 1]
                        pretrained_shape_squeezed = [
                            d for d in pretrained_wts.shape if d != 1]
                        if target_shape_squeezed != pretrained_shape_squeezed:
                            logging.error(
                                'Shape mismatch var: %s from npy [%s vs %s]' %
                                (final_key_name, target_shape,
                                 pretrained_wts.shape))

                        # As before, the weights are still handed to slim
                        # after a mismatch; only the diagnostic is fixed.
                        init_weights_final[final_key_name] = pretrained_wts
                for v in vars_to_restore_names:
                    if v not in init_weights_final:
                        logging.fatal('No weights found for %s' % v)
                all_ops.append(
                    slim.assign_from_values_fn(init_weights_final))
            elif this_checkpoint_style != 'v2_withStream':
                # Strip the stream name so graph variables map onto the
                # un-prefixed names stored in the checkpoint.
                all_ops.append(
                    slim.assign_from_checkpoint_fn(
                        checkpoint_path,
                        {'/'.join(el.name.split('/')[1:]).split(':')[0]: el
                         for el in this_checkpoint_variables},
                        ignore_missing_vars=ignore_missing_vars))
            else:
                # v2_withStream checkpoints were trained with the stream0/
                # prefix, so replace this stream's prefix with stream0.
                all_ops.append(
                    slim.assign_from_checkpoint_fn(
                        checkpoint_path,
                        {'/'.join(['stream0'] +
                                  el.name.split('/')[1:]).split(':')[0]: el
                         for el in this_checkpoint_variables},
                        ignore_missing_vars=ignore_missing_vars))
    if special_assign_vars is not None:
        all_ops.append(get_special_assigns(special_assign_vars))

    def combined(sess):
        """Run every collected restore function on ``sess``."""
        for op in all_ops:
            op(sess)

    return combined