예제 #1
0
def restore(sess, global_vars):
    """Initialise ``global_vars``, warm-starting weights from VGG-16.

    Every variable in ``global_vars`` whose name matches a ``*weights`` entry
    of ``model/vgg_16.ckpt`` (relative to the current working directory) is
    assigned the checkpoint value; all remaining variables are initialised
    with their own initializers.

    Args:
        sess: Session used to run the assignments and initializers.
        global_vars: The graph variables to restore or initialise.
    """
    print('from vgg_16 pretrained model')

    reader = NewCheckpointReader(os.path.join(os.getcwd(),
                                              'model/vgg_16.ckpt'))

    # no batchnorm in the vgg_16 pretrained model; conv biases are skipped.
    # A set keeps the membership test below O(1) per variable.
    restored_var_names = {
        name + ':0' for name in reader.get_variable_to_dtype_map()
        if re.match('^.*weights$', name)
    }

    restored_vars = [
        var for var in global_vars if var.name in restored_var_names
    ]

    # One feed placeholder is shared by every assignment.
    value_ph = tf.placeholder(dtype=tf.float32)

    for var in restored_vars:
        # var.name[:-2] strips the ':0' output suffix -> checkpoint key.
        sess.run(tf.assign(var, value_ph),
                 feed_dict={value_ph: reader.get_tensor(var.name[:-2])})

    # Compare by (graph-unique) name rather than relying on Variable.__eq__.
    restored_names = {var.name for var in restored_vars}
    initialized_vars = [
        var for var in global_vars if var.name not in restored_names
    ]

    sess.run(tf.variables_initializer(initialized_vars))
예제 #2
0
def restore(sess, global_vars):
    """Initialise ``global_vars``, warm-starting from ResNet-v2-50.

    Every variable in ``global_vars`` whose name matches a ``*weights`` or
    ``*biases`` entry of ``model/resnet_v2_50.ckpt`` (relative to the current
    working directory) is assigned the checkpoint value; all remaining
    variables are initialised with their own initializers.

    Args:
        sess: Session used to run the assignments and initializers.
        global_vars: The graph variables to restore or initialise.
    """
    print('from resnet_v2_50 pretrained model')

    reader = NewCheckpointReader(os.path.join(
        os.getcwd(), 'model/resnet_v2_50.ckpt'))

    # restore both weights and biases from conv and shortcut layers.
    # A set keeps the membership test below O(1) per variable.
    restored_var_names = {
        name + ':0'
        for name in reader.get_variable_to_dtype_map()
        if re.match('^.*weights$', name) or re.match('^.*biases$', name)
    }

    restored_vars = [var for var in global_vars
                     if var.name in restored_var_names]

    # One feed placeholder is shared by every assignment.
    value_ph = tf.placeholder(dtype=tf.float32)

    for var in restored_vars:
        # var.name[:-2] strips the ':0' output suffix -> checkpoint key.
        sess.run(tf.assign(var, value_ph),
                 feed_dict={value_ph: reader.get_tensor(var.name[:-2])})

    # Compare by (graph-unique) name rather than relying on Variable.__eq__.
    restored_names = {var.name for var in restored_vars}
    initialized_vars = [var for var in global_vars
                        if var.name not in restored_names]

    sess.run(tf.variables_initializer(initialized_vars))
예제 #3
0
파일: network.py 프로젝트: myclab/tf-yolov2
    def load_ckpt(self, pretrained):
        """Restore the latest checkpoint, falling back to fresh initialisation.

        The latest checkpoint in ``cfg.ckpt_dir`` is restored into
        ``self.sess`` with ``self.saver``. If that fails and
        ``self.is_training`` is set, all global variables are initialised
        instead. When ``pretrained`` is truthy and
        ``os.path.join(cfg.workspace, premodel['path'])`` exists, the
        variables whose checkpoint names match ``premodel['rptn']`` are first
        loaded from the tf-slim checkpoint ``premodel['ckpt']`` and excluded
        from initialisation.

        Args:
            pretrained: Whether to warm-start from the tf-slim model
                described by ``premodel``.
        """
        # restore model with ckpt/pretrain or init
        try:
            print('trying to restore last checkpoint')
            last_ckpt_path = tf.train.latest_checkpoint(
                checkpoint_dir=cfg.ckpt_dir)
            self.saver.restore(self.sess, save_path=last_ckpt_path)
            print('restored checkpoint from:', last_ckpt_path)
            return
        except Exception as err:
            # Previously a bare ``except:`` that also swallowed
            # KeyboardInterrupt/SystemExit and hid the reason for failure
            # (e.g. no checkpoint in cfg.ckpt_dir).
            print('could not restore last checkpoint:', err)

        if not self.is_training:
            # NOTE(review): nothing is initialised on this path, so the
            # variables stay uninitialised; confirm callers expect that.
            print('not training: variables left uninitialised')
            return

        print('init variables')
        global_vars = tf.global_variables()
        restored_vars = []

        if pretrained:  # restore from tf-slim model
            pretrained_path = os.path.join(cfg.workspace, premodel['path'])
            if os.path.exists(pretrained_path):
                print('from ' + premodel['endp'])

                import re
                from tensorflow.python.pywrap_tensorflow import NewCheckpointReader

                reader = NewCheckpointReader(
                    os.path.join(cfg.workspace, premodel['ckpt']))

                # only restore checkpoint entries matching premodel['rptn']
                # (presumably conv weights); a set gives O(1) lookups.
                restored_var_names = {
                    name + ':0'
                    for name in reader.get_variable_to_dtype_map()
                    if re.match(premodel['rptn'], name)
                }

                # update restored variables from pretrained model
                restored_vars = [
                    var for var in global_vars
                    if var.name in restored_var_names
                ]

                # assign checkpoint values; var.name[:-2] drops the ':0'
                value_ph = tf.placeholder(tf.float32, shape=None)
                for var in restored_vars:
                    self.sess.run(
                        tf.assign(var, value_ph),
                        feed_dict={
                            value_ph: reader.get_tensor(var.name[:-2])
                        })
            else:
                print('pretrained model not found:', pretrained_path)

        # Everything not taken from the pretrained model gets its own
        # initializer; compare by graph-unique name, not Variable equality.
        restored_names = {var.name for var in restored_vars}
        initialized_vars = [
            var for var in global_vars if var.name not in restored_names
        ]
        self.sess.run(tf.variables_initializer(initialized_vars))
예제 #4
0
    def _get_reader_for_run(self, run):
        """Return the checkpoint reader for ``run``, opening it on first use.

        Readers are cached in ``self.readers``; on a miss the checkpoint path
        comes from ``self.configs[run].model_checkpoint_path``.
        """
        if run not in self.readers:
            checkpoint_path = self.configs[run].model_checkpoint_path
            self.readers[run] = NewCheckpointReader(checkpoint_path)
        return self.readers[run]
예제 #5
0
    def _load_softmax(self, vocab: C2I, full_vocab_path: str,
                      ckpt_dir: str) -> None:
        """Copy softmax rows from the sharded checkpoint into the decoder.

        Row ``j`` of shard ``i`` (``ckpt-softmax{i}``, tensor ``softmax/W_{i}``)
        belongs to vocabulary line ``i + 8*j``; its bias comes from the same
        strided slice of ``softmax/b`` in ``ckpt-softmax8``. Rows whose word is
        in ``vocab`` are written to ``self.decoder_w``/``self.decoder_b``; the
        ``</S>`` and ``<UNK>`` rows are additionally written to the slots of
        ``vocab.eos_token`` and ``vocab.unk_token``.
        """
        with open(full_vocab_path) as vocab_file:
            full_vocab: List[str] = vocab_file.read().strip().split('\n')

        bias_path = os.path.join(ckpt_dir, 'ckpt-softmax8')
        full_bias = NewCheckpointReader(bias_path).get_tensor('softmax/b')
        n_rows = len(full_bias)

        def put(idx, weight_row, bias_value):
            self.decoder_w[idx] = weight_row
            self.decoder_b[idx] = bias_value

        # SoftMax is chunked into 8 arrays of size 100000x1024
        for shard in range(8):
            shard_path = os.path.join(ckpt_dir, f'ckpt-softmax{shard}')
            weights = NewCheckpointReader(shard_path).get_tensor(
                f'softmax/W_{shard}').astype(np.float32)
            # Vocab/bias rows are interleaved across shards with stride 8;
            # the vocab slice is capped at the bias length to stay aligned.
            biases = full_bias[shard:n_rows:8]
            words = full_vocab[shard:n_rows:8]

            for row, word in enumerate(words):
                weight_row = weights[row]
                bias_value = biases[row]

                if word in vocab:
                    put(vocab[word], weight_row, bias_value)

                if word == '</S>':
                    put(vocab[vocab.eos_token], weight_row, bias_value)
                elif word == '<UNK>':
                    put(vocab[vocab.unk_token], weight_row, bias_value)
예제 #6
0
    def _load_softmax(self, vocab: C2I, full_vocab_path: str,
                      ckpt_dir: str) -> None:
        """Copy softmax rows (as torch tensors) from the sharded checkpoint.

        Row ``j`` of shard ``i`` (``ckpt-softmax{i}``, tensor ``softmax/W_{i}``)
        belongs to vocabulary line ``i + 8*j``; its bias comes from the same
        strided slice of ``softmax/b`` in ``ckpt-softmax8``. Rows whose word is
        in ``vocab`` are written to ``self.decoder_w``/``self.decoder_b``; the
        ``</S>`` and ``<UNK>`` rows are additionally written to the slots of
        ``vocab.eos_token`` and ``vocab.unk_token``.
        """
        from tensorflow.python.pywrap_tensorflow import NewCheckpointReader

        with open(full_vocab_path) as vocab_file:
            full_vocab: List[str] = vocab_file.read().strip().split("\n")

        bias_path = os.path.join(ckpt_dir, "ckpt-softmax8")
        full_bias = torch.from_numpy(
            NewCheckpointReader(bias_path).get_tensor("softmax/b"))
        n_rows = full_bias.size(0)

        def put(idx, weight_row, bias_value):
            self.decoder_w[idx] = weight_row
            self.decoder_b[idx] = bias_value

        # SoftMax is chunked into 8 arrays of size 100000x1024
        for shard in range(8):
            shard_path = os.path.join(ckpt_dir, f"ckpt-softmax{shard}")
            weights = torch.from_numpy(
                NewCheckpointReader(shard_path).get_tensor(
                    f"softmax/W_{shard}"))
            # Vocab/bias rows are interleaved across shards with stride 8;
            # the vocab slice is capped at the bias length to stay aligned.
            biases = full_bias[shard:n_rows:8]
            words = full_vocab[shard:n_rows:8]

            for row, word in enumerate(words):
                weight_row = weights[row]
                bias_value = biases[row]

                if word in vocab:
                    put(vocab[word], weight_row, bias_value)

                if word == "</S>":
                    put(vocab[vocab.eos_token], weight_row, bias_value)
                elif word == "<UNK>":
                    put(vocab[vocab.unk_token], weight_row, bias_value)
예제 #7
0
  def _get_reader_for_run(self, run):
    """Return the cached checkpoint reader for `run`, opening it if needed.

    `None` is cached when the run's config has no checkpoint path, or when
    opening the checkpoint raises (a warning is logged in that case).
    """
    if run not in self.readers:
      checkpoint_path = self._configs[run].model_checkpoint_path
      opened = None
      if checkpoint_path:
        try:
          opened = NewCheckpointReader(checkpoint_path)
        except Exception:  # pylint: disable=broad-except
          logging.warning('Failed reading %s', checkpoint_path)
      self.readers[run] = opened
    return self.readers[run]
예제 #8
0
    def _load_lstm(self, ckpt_dir: str, device: str) -> None:
        """Load every LSTM layer's parameters from ``ckpt_dir/ckpt-lstm``.

        For each layer index in ``range(self.num_layers)`` this fills
        ``self.weight`` and ``self.weight_P`` (each the dim-0 concatenation of
        8 checkpoint chunks), ``self.bias``, and ``self.peepholes[layer, g]``
        for gates ``f``/``i``/``o``; each tensor is then converted to
        ``config.DTYPE`` and moved to ``device``.
        """
        from tensorflow.python.pywrap_tensorflow import NewCheckpointReader

        lstm_reader = NewCheckpointReader(os.path.join(ckpt_dir, "ckpt-lstm"))

        def fetch(name):
            return torch.from_numpy(lstm_reader.get_tensor(name))

        def fetch_chunks(stem):
            # The checkpoint stores these matrices split into 8 chunks.
            return torch.cat([fetch(f"{stem}_{i}") for i in range(8)], dim=0)

        def place(tensor):
            return tensor.to(config.DTYPE).to(device)

        gates = ["f", "i", "o"]
        for layer in range(self.num_layers):
            prefix = f"lstm/lstm_{layer}"

            self.weight[layer] = fetch_chunks(f"{prefix}/W")
            self.bias[layer] = fetch(f"{prefix}/B")
            self.weight_P[layer] = fetch_chunks(f"{prefix}/W_P")
            for gate in gates:
                # NOTE(review): original comment claimed (8192, 8192) here,
                # but the `_diag` name suggests a vector — confirm.
                self.peepholes[layer, gate] = fetch(
                    f"{prefix}/W_{gate.upper()}_diag")

            # Convert to the configured dtype (config.DTYPE), then move.
            self.weight[layer] = place(self.weight[layer])
            self.weight_P[layer] = place(self.weight_P[layer])
            self.bias[layer] = place(self.bias[layer])
            for gate in gates:
                self.peepholes[layer, gate] = place(self.peepholes[layer, gate])
예제 #9
0
    def _load_lstm(self, ckpt_dir: str) -> None:
        """Load both LSTM layers' parameters from ``ckpt_dir/ckpt-lstm``.

        For layers 0 and 1 this stores float32 NumPy arrays in
        ``self.weight`` (8 concatenated chunks, transposed), ``self.bias``,
        ``self.weight_P`` (8 concatenated chunks) and
        ``self.peepholes[layer, g]`` for gates ``f``/``i``/``o``.
        """
        lstm_reader = NewCheckpointReader(os.path.join(ckpt_dir, 'ckpt-lstm'))

        def read(name):
            return lstm_reader.get_tensor(name)

        def read_chunks(stem):
            # The checkpoint stores these matrices split into 8 chunks.
            return np.concatenate([read(f'{stem}_{i}') for i in range(8)])

        for layer in range(2):
            prefix = f'lstm/lstm_{layer}'

            self.weight[layer] = read_chunks(
                f'{prefix}/W').astype(np.float32).T
            self.bias[layer] = read(f'{prefix}/B').astype(np.float32)
            self.weight_P[layer] = read_chunks(
                f'{prefix}/W_P').astype(np.float32)

            for gate in 'FIO':
                self.peepholes[layer, gate.lower()] = read(
                    f'{prefix}/W_{gate}_diag').astype(np.float32)
예제 #10
0
def get_list_of_variables_from_ckpt(ckpt_file):
    """Return the names of all variables stored in the checkpoint ``ckpt_file``."""
    dtype_by_name = NewCheckpointReader(ckpt_file).get_variable_to_dtype_map()
    return [name for name in dtype_by_name]
예제 #11
0
def freeze_graph_with_def_protos(input_graph_def,
                                 input_saver_def,
                                 input_checkpoint,
                                 output_node_names,
                                 restore_op_name,
                                 filename_tensor_name,
                                 output_graph,
                                 clear_devices,
                                 initializer_nodes,
                                 variable_names_whitelist="",
                                 variable_names_blacklist="",
                                 input_meta_graph_def=None,
                                 input_saved_model_dir=None,
                                 saved_model_tags=None,
                                 checkpoint_version=saver_pb2.SaverDef.V2):
  """Converts all variables in a graph and checkpoint into constants.

  Args:
    input_graph_def: A `GraphDef`.
    input_saver_def: A `SaverDef` (optional).
    input_checkpoint: The prefix of a V1 or V2 checkpoint, with V2 taking
      priority.  Typically the result of `Saver.save()` or that of
      `tf.train.latest_checkpoint()`, regardless of sharded/non-sharded or
      V1/V2.
    output_node_names: The name(s) of the output nodes, comma separated.
    restore_op_name: Unused.
    filename_tensor_name: Unused.
    output_graph: String where to write the frozen `GraphDef`.
    clear_devices: A Bool whether to remove device specifications.
    initializer_nodes: Comma separated string of initializer nodes to run before
                       freezing (only used with `input_meta_graph_def` or a
                       plain checkpoint, not with `input_saver_def` or
                       `input_saved_model_dir`).
    variable_names_whitelist: Comma separated string of variable names to
                              convert (optional, by default, all variables are
                              converted).
    variable_names_blacklist: Comma separated string of variable names to omit
                              converting to constants (optional).
    input_meta_graph_def: A `MetaGraphDef` (optional).
    input_saved_model_dir: Path to the dir with TensorFlow 'SavedModel' file
                           and variables (optional).
    saved_model_tags: Group of comma separated tag(s) of the MetaGraphDef to
                      load, in string format (optional).
    checkpoint_version: Tensorflow variable file format (saver_pb2.SaverDef.V1
                        or saver_pb2.SaverDef.V2)

  Returns:
    The frozen `GraphDef`; it is also written to `output_graph` when that path
    is non-empty.

  Raises:
    ValueError: If the checkpoint does not exist, no output node names are
      given, or the variables cannot be restored from a plain checkpoint
      (partitioned variables, or a graph that was already frozen).
  """
  del restore_op_name, filename_tensor_name  # Unused by updated loading code.

  # 'input_checkpoint' may be a prefix if we're using Saver V2 format
  if (not input_saved_model_dir and
      not checkpoint_management.checkpoint_exists(input_checkpoint)):
    raise ValueError("Input checkpoint '" + input_checkpoint +
                     "' doesn't exist!")

  if not output_node_names:
    raise ValueError(
        "You need to supply the name of a node to --output_node_names.")

  # Remove all the explicit device specifications for this node. This helps to
  # make the graph more portable.
  if clear_devices:
    if input_meta_graph_def:
      for node in input_meta_graph_def.graph_def.node:
        node.device = ""
    elif input_graph_def:
      for node in input_graph_def.node:
        node.device = ""

  if input_graph_def:
    _ = importer.import_graph_def(input_graph_def, name="")
  with session.Session() as sess:
    if input_saver_def:
      saver = saver_lib.Saver(
          saver_def=input_saver_def, write_version=checkpoint_version)
      saver.restore(sess, input_checkpoint)
    elif input_meta_graph_def:
      restorer = saver_lib.import_meta_graph(
          input_meta_graph_def, clear_devices=True)
      restorer.restore(sess, input_checkpoint)
      if initializer_nodes:
        sess.run(initializer_nodes.replace(" ", "").split(","))
    elif input_saved_model_dir:
      if saved_model_tags is None:
        saved_model_tags = []
      loader.load(sess, saved_model_tags, input_saved_model_dir)
    else:
      var_list = {}
      reader = NewCheckpointReader(input_checkpoint)
      var_to_shape_map = reader.get_variable_to_shape_map()

      # List of all partition variables. Because the condition is heuristic
      # based, the list could include false positives.
      all_partition_variable_names = [
          tensor.name.split(":")[0]
          for op in sess.graph.get_operations()
          for tensor in op.values()
          if re.search(r"/part_\d+/", tensor.name)
      ]
      has_partition_var = False

      for key in var_to_shape_map:
        try:
          tensor = sess.graph.get_tensor_by_name(key + ":0")
          if any(key in name for name in all_partition_variable_names):
            has_partition_var = True
        except KeyError:
          # This tensor doesn't exist in the graph (for example it's
          # 'global_step' or a similar housekeeping element) so skip it.
          continue
        var_list[key] = tensor

      try:
        saver = saver_lib.Saver(
            var_list=var_list, write_version=checkpoint_version)
      except TypeError:
        # `var_list` is required to be a map of variable names to Variable
        # tensors. Partition variables are Identity tensors that cannot be
        # handled by Saver.
        if has_partition_var:
          raise ValueError(
              "Models containing partition variables cannot be converted "
              "from checkpoint files. Please pass in a SavedModel using "
              "the flag --input_saved_model_dir.")
        # Models that have been frozen previously do not contain Variables.
        if _has_no_variables(sess):
          raise ValueError(
              "No variables were found in this model. It is likely the model "
              "was frozen previously. You cannot freeze a graph twice.")
        # Unrecognised failure: re-raise with the original traceback.
        raise

      saver.restore(sess, input_checkpoint)
      if initializer_nodes:
        sess.run(initializer_nodes.replace(" ", "").split(","))

    variable_names_whitelist = (
        variable_names_whitelist.replace(" ", "").split(",")
        if variable_names_whitelist else None)
    variable_names_blacklist = (
        variable_names_blacklist.replace(" ", "").split(",")
        if variable_names_blacklist else None)

    # Freeze the graph the variables were loaded into: the meta graph's
    # GraphDef when one was given, otherwise `input_graph_def`.
    graph_def_to_freeze = (input_meta_graph_def.graph_def
                           if input_meta_graph_def else input_graph_def)
    output_graph_def = graph_util.convert_variables_to_constants(
        sess,
        graph_def_to_freeze,
        output_node_names.replace(" ", "").split(","),
        variable_names_whitelist=variable_names_whitelist,
        variable_names_blacklist=variable_names_blacklist)

  # Write GraphDef to file if output path has been given.
  if output_graph:
    with gfile.GFile(output_graph, "wb") as f:
      f.write(output_graph_def.SerializeToString())

  return output_graph_def
    tarfile.open(args.checkpoint).extractall(checkpoint_dir)
    files = [
        os.path.join(checkpoint_dir, d) for d in os.listdir(checkpoint_dir)
    ]
    checkpoint_dir = files[0] if os.path.isdir(files[0]) else checkpoint_dir
elif args.checkpoint.endswith('.zip'):
    checkpoint_dir = args.tmp
    zipfile.ZipFile(args.checkpoint).extractall(checkpoint_dir)
    files = [
        os.path.join(checkpoint_dir, d) for d in os.listdir(checkpoint_dir)
    ]
    checkpoint_dir = files[0] if os.path.isdir(files[0]) else checkpoint_dir
else:
    checkpoint_dir = args.checkpoint

# latest_checkpoint() returns None when no checkpoint is found in
# checkpoint_dir; fail with a clear message instead of handing None to the
# checkpoint reader.
latest_checkpoint_path = tensorflow.train.latest_checkpoint(checkpoint_dir)
if latest_checkpoint_path is None:
    raise ValueError('No checkpoint found in %r' % (checkpoint_dir,))
reader = NewCheckpointReader(latest_checkpoint_path)
# Load every variable stored in the checkpoint, keyed by variable name.
blobs = {k: reader.get_tensor(k) for k in reader.get_variable_to_shape_map()}

if args.output_path.endswith('.json'):
    with open(args.output_path, 'w') as f:
        json.dump({k: blob.tolist()
                   for k, blob in blobs.items()},
                  f,
                  sort_keys=True,
                  indent=2)
elif args.output_path.endswith('.h5'):
    import h5py
    with h5py.File(args.output_path, 'w') as h:
        h.update(**blobs)
elif args.output_path.endswith('.npy') or args.output_path.endswith('.npz'):
    (np.savez if args.output_path[-1] == 'z' else numpy.save)(args.output_path,