Example #1
  def __init__(self, saved_model_dir, model_dir=None):
    """Initialize a SavedModelEstimator.

    The SavedModelEstimator loads its model function and variable values from
    the graphs defined in the SavedModel. There is no option to pass in
    `RunConfig` or `params` arguments, because the model function graph is
    defined statically in the SavedModel.

    Args:
      saved_model_dir: Directory containing SavedModel protobuf and subfolders.
      model_dir: Directory to save new checkpoints during training.

    Raises:
      NotImplementedError: If a DistributionStrategy is defined in the config.
        Unless the SavedModelEstimator is subclassed, this shouldn't happen.
    """
    checkpoint = estimator_lib._get_saved_model_ckpt(saved_model_dir)  # pylint: disable=protected-access
    vars_to_warm_start = [name for name, _ in
                          checkpoint_utils.list_variables(checkpoint)]
    warm_start_settings = estimator_lib.WarmStartSettings(
        ckpt_to_initialize_from=checkpoint,
        vars_to_warm_start=vars_to_warm_start)

    super(SavedModelEstimator, self).__init__(
        model_fn=self._model_fn_from_saved_model, model_dir=model_dir,
        warm_start_from=warm_start_settings)
    if self._train_distribution or self._eval_distribution:
      raise NotImplementedError(
          'SavedModelEstimator currently does not support '
          'DistributionStrategy.')
    self.saved_model_dir = saved_model_dir
    self.saved_model_loader = loader_impl.SavedModelLoader(saved_model_dir)
    self._available_modes = self._extract_available_modes()
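The warm-start pattern used above can be reproduced on its own with the public aliases of these utilities. A minimal sketch, assuming TF 1.x-style `tf.estimator`/`tf.train` APIs; the checkpoint path is hypothetical:

import tensorflow as tf

# Hypothetical checkpoint prefix; point it at an existing checkpoint.
ckpt = '/tmp/exported_model/variables/variables'

# List every variable stored in the checkpoint and ask the Estimator to
# initialize all of them from it, mirroring SavedModelEstimator.__init__ above.
vars_to_warm_start = [name for name, _ in tf.train.list_variables(ckpt)]
warm_start = tf.estimator.WarmStartSettings(
    ckpt_to_initialize_from=ckpt,
    vars_to_warm_start=vars_to_warm_start)
# Pass warm_start_from=warm_start when constructing an Estimator.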
Example #2
    def _assert_checkpoint(self,
                           expected_global_step,
                           expected_weights=None,
                           expected_bias=None):
        """Assert the values and shapes of the variables saved in the checkpoint."""
        shapes = {
            name: shape
            for (name,
                 shape) in checkpoint_utils.list_variables(self._model_dir)
        }

        reader = checkpoint_utils.load_checkpoint(self._model_dir)

        self.assertEqual([], shapes[ops.GraphKeys.GLOBAL_STEP])
        self.assertEqual(expected_global_step,
                         reader.get_tensor(ops.GraphKeys.GLOBAL_STEP))

        self.assertEqual([3, 2], shapes[WEIGHTS_NAME])
        if expected_weights is not None:
            self.assertAllClose(expected_weights,
                                reader.get_tensor(WEIGHTS_NAME))

        self.assertEqual([2], shapes[BIAS_NAME])
        if expected_bias is not None:
            self.assertAllClose(expected_bias, reader.get_tensor(BIAS_NAME))
Example #3
 def testGetAllVariables(self):
     checkpoint_dir = self.get_temp_dir()
     with self.test_session() as session:
         _create_checkpoints(session, checkpoint_dir)
     self.assertEqual(checkpoint_utils.list_variables(checkpoint_dir),
                      [("useful_scope/var4", [9, 9]), ("var1", [1, 10]),
                       ("var2", [10, 10]), ("var3", [100, 100])])
Example #4
  def _assert_checkpoint(
      self, n_classes, input_units, cell_units, expected_global_step):

    shapes = {
        name: shape for (name, shape) in
        checkpoint_utils.list_variables(self._model_dir)
    }

    self.assertEqual([], shapes[ops.GraphKeys.GLOBAL_STEP])
    self.assertEqual(
        expected_global_step,
        checkpoint_utils.load_variable(
            self._model_dir, ops.GraphKeys.GLOBAL_STEP))

    # RNN Cell variables.
    if len(cell_units) > 1:
      for i, cell_unit in enumerate(cell_units):
        self.assertEqual([input_units + cell_unit, cell_unit],
                         shapes[MULTI_CELL_WEIGHTS_NAME_PATTERN % i])
        self.assertEqual([cell_unit],
                         shapes[MULTI_CELL_BIAS_NAME_PATTERN % i])
        input_units = cell_unit
    elif len(cell_units) == 1:
      self.assertEqual([input_units + cell_units[0], cell_units[0]],
                       shapes[CELL_WEIGHTS_NAME])
      self.assertEqual([cell_units[0]], shapes[CELL_BIAS_NAME])

    # Logits variables.
    logits_dimension = n_classes if n_classes > 2 else 1
    self.assertEqual([cell_units[-1], logits_dimension],
                     shapes[LOGITS_WEIGHTS_NAME])
    self.assertEqual([logits_dimension], shapes[LOGITS_BIAS_NAME])
Example #5
  def _assert_checkpoint(
      self, expected_global_step, expected_age_weight=None, expected_bias=None):
    logits_dimension = self._logits_dimensions

    shapes = {
        name: shape for (name, shape) in
        checkpoint_utils.list_variables(self._model_dir)
    }

    self.assertEqual([], shapes[ops.GraphKeys.GLOBAL_STEP])
    self.assertEqual(
        expected_global_step,
        checkpoint_utils.load_variable(
            self._model_dir, ops.GraphKeys.GLOBAL_STEP))

    self.assertEqual([1, logits_dimension], shapes[_AGE_WEIGHT_NAME])
    if expected_age_weight is not None:
      self.assertAllEqual(
          expected_age_weight,
          checkpoint_utils.load_variable(self._model_dir, _AGE_WEIGHT_NAME))

    self.assertEqual([logits_dimension], shapes[_BIAS_NAME])
    if expected_bias is not None:
      self.assertAllEqual(
          expected_bias,
          checkpoint_utils.load_variable(self._model_dir, _BIAS_NAME))
Example #6
    def _assert_checkpoint(self, n_classes, input_units, cell_units,
                           expected_global_step):

        shapes = {
            name: shape
            for (name,
                 shape) in checkpoint_utils.list_variables(self._model_dir)
        }

        self.assertEqual([], shapes[ops.GraphKeys.GLOBAL_STEP])
        self.assertEqual(
            expected_global_step,
            checkpoint_utils.load_variable(self._model_dir,
                                           ops.GraphKeys.GLOBAL_STEP))

        # RNN Cell variables.
        if len(cell_units) > 1:
            for i, cell_unit in enumerate(cell_units):
                self.assertEqual([input_units + cell_unit, cell_unit],
                                 shapes[MULTI_CELL_WEIGHTS_NAME_PATTERN % i])
                self.assertEqual([cell_unit],
                                 shapes[MULTI_CELL_BIAS_NAME_PATTERN % i])
                input_units = cell_unit
        elif len(cell_units) == 1:
            self.assertEqual([input_units + cell_units[0], cell_units[0]],
                             shapes[CELL_WEIGHTS_NAME])
            self.assertEqual([cell_units[0]], shapes[CELL_BIAS_NAME])

        # Logits variables.
        logits_dimension = n_classes if n_classes > 2 else 1
        self.assertEqual([cell_units[-1], logits_dimension],
                         shapes[LOGITS_WEIGHTS_NAME])
        self.assertEqual([logits_dimension], shapes[LOGITS_BIAS_NAME])
Example #7
  def __init__(self, saved_model_dir, model_dir=None):
    """Initialize a SavedModelEstimator.

    The SavedModelEstimator loads its model function and variable values from
    the graphs defined in the SavedModel. There is no option to pass in
    `RunConfig` or `params` arguments, because the model function graph is
    defined statically in the SavedModel.

    Args:
      saved_model_dir: Directory containing SavedModel protobuf and subfolders.
      model_dir: Directory to save new checkpoints during training.

    Raises:
      NotImplementedError: If a DistributionStrategy is defined in the config.
        Unless the SavedModelEstimator is subclassed, this shouldn't happen.
    """
    checkpoint = estimator_lib._get_saved_model_ckpt(saved_model_dir)  # pylint: disable=protected-access
    vars_to_warm_start = [name for name, _ in
                          checkpoint_utils.list_variables(checkpoint)]
    warm_start_settings = estimator_lib.WarmStartSettings(
        ckpt_to_initialize_from=checkpoint,
        vars_to_warm_start=vars_to_warm_start)

    super(SavedModelEstimator, self).__init__(
        model_fn=self._model_fn_from_saved_model, model_dir=model_dir,
        warm_start_from=warm_start_settings)
    if self._distribution is not None:
      raise NotImplementedError(
          'SavedModelEstimator currently does not support '
          'DistributionStrategy.')
    self.saved_model_dir = saved_model_dir
    self.saved_model_loader = loader_impl.SavedModelLoader(saved_model_dir)
    self._available_modes = self._extract_available_modes()
Example #8
    def _assert_checkpoint(self,
                           expected_global_step,
                           expected_age_weight=None,
                           expected_bias=None):
        logits_dimension = self._logits_dimensions

        shapes = {
            name: shape
            for (name,
                 shape) in checkpoint_utils.list_variables(self._model_dir)
        }

        self.assertEqual([], shapes[ops.GraphKeys.GLOBAL_STEP])
        self.assertEqual(
            expected_global_step,
            checkpoint_utils.load_variable(self._model_dir,
                                           ops.GraphKeys.GLOBAL_STEP))

        self.assertEqual([1, logits_dimension],
                         shapes[linear_testing_utils.AGE_WEIGHT_NAME])
        if expected_age_weight is not None:
            self.assertAllEqual(
                expected_age_weight,
                checkpoint_utils.load_variable(
                    self._model_dir, linear_testing_utils.AGE_WEIGHT_NAME))

        self.assertEqual([logits_dimension],
                         shapes[linear_testing_utils.BIAS_NAME])
        if expected_bias is not None:
            self.assertAllEqual(
                expected_bias,
                checkpoint_utils.load_variable(self._model_dir,
                                               linear_testing_utils.BIAS_NAME))
Example #9
    def _assert_checkpoint(self,
                           expected_global_step,
                           expected_age_weight=None,
                           expected_bias=None):
        shapes = {
            name: shape
            for (name,
                 shape) in checkpoint_utils.list_variables(self._model_dir)
        }

        self.assertEqual([], shapes[ops.GraphKeys.GLOBAL_STEP])
        self.assertEqual(
            expected_global_step,
            checkpoint_utils.load_variable(self._model_dir,
                                           ops.GraphKeys.GLOBAL_STEP))

        self.assertEqual([1, 1], shapes[AGE_WEIGHT_NAME])
        if expected_age_weight is not None:
            self.assertEqual(
                expected_age_weight,
                checkpoint_utils.load_variable(self._model_dir,
                                               AGE_WEIGHT_NAME))

        self.assertEqual([1], shapes[BIAS_NAME])
        if expected_bias is not None:
            self.assertEqual(
                expected_bias,
                checkpoint_utils.load_variable(self._model_dir, BIAS_NAME))
Example #10
 def load_tensorflow_model_from_ckpt(self, tensorflow_model_path):
     self.tensor_map = checkpoint_utils.list_variables(tensorflow_model_path)
     reader = pywrap_tensorflow.NewCheckpointReader(tensorflow_model_path)
     variable_shape_map = reader.get_variable_to_shape_map()
     weight_map = {}
     for key in variable_shape_map:
         weight_map[key] = reader.get_tensor(key)
     return weight_map
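The same weight map can be built with the public checkpoint-reading API instead of pywrap_tensorflow. A sketch assuming tf.train.load_checkpoint / tf.train.list_variables; the checkpoint prefix is hypothetical:

import tensorflow as tf

def load_weight_map(checkpoint_path):
    # Read every tensor stored in the checkpoint into a {name: numpy array} dict.
    reader = tf.train.load_checkpoint(checkpoint_path)
    return {name: reader.get_tensor(name)
            for name, _ in tf.train.list_variables(checkpoint_path)}

weight_map = load_weight_map('/tmp/model.ckpt')  # hypothetical checkpoint prefix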
Example #11
 def testGetAllVariables(self):
   checkpoint_dir = self.get_temp_dir()
   with self.test_session() as session:
     _create_checkpoints(session, checkpoint_dir)
   self.assertEqual(
       checkpoint_utils.list_variables(checkpoint_dir),
       [("useful_scope/var4", [9, 9]), ("var1", [1, 10]), ("var2", [10, 10]),
        ("var3", [100, 100])])
Example #12
  def test_check_checkpoint_variable_names_are_same_on_cpu_and_tpu(
      self, optimizer):
    # Reinitialize the TPU so that we can re-initialize the embeddings with the
    # given optimizer.
    if optimizer != tpu_embedding_v2_utils.SGD:
      self.skip_if_oss()
    strategy = self._get_strategy()
    num_rows = strategy.num_replicas_in_sync

    with strategy.scope():
      first_mid_level_contents = np.ones((num_rows, 4))
      first_mid_level_optimizer = optimizer(learning_rate=0.1)
      initializer = init_ops_v2.Constant(first_mid_level_contents)

      table = tpu_embedding_v2_utils.TableConfig(
          vocabulary_size=num_rows,
          dim=4,
          initializer=initializer,
          combiner='sum',
          name='table')
      feature_config = (tpu_embedding_v2_utils.FeatureConfig(
          table=table, name='feature'),)

      first_mid_level = tpu_embedding_v2.TPUEmbedding(
          feature_config, first_mid_level_optimizer)

      first_mid_level.build(64)

    cpu_mid_level_optimizer = optimizer(learning_rate=0.1)
    cpu_mid_level = tpu_embedding_v2.TPUEmbedding(feature_config,
                                                  cpu_mid_level_optimizer)
    cpu_mid_level.build(64)

    tpu_checkpoint = util.Checkpoint(model=first_mid_level)
    tpu_checkpoint.save(self._get_tmpdir('save-tpu', 'save'))
    tpu_variables = checkpoint_utils.list_variables(
        self._get_tmpdir('save-tpu'))

    cpu_checkpoint = util.Checkpoint(model=cpu_mid_level)
    cpu_checkpoint.save(self._get_tmpdir('save-cpu', 'save'))
    cpu_variables = checkpoint_utils.list_variables(
        self._get_tmpdir('save-cpu'))

    self.assertAllEqual(tpu_variables, cpu_variables)
Example #13
def restore_variables_on_create(save_path):
    """ContextManager that restores variables on creation.

    When save_path is None (e.g. No checkpoint), does nothing.
    Otherwise, it preloads all values from checkpoint. When the
    corresponding variable is first created, it assigns the checkpoint
    value to the variable.

    ```python
    with restore_variables_on_create(
        tf.train.latest_checkpoint(checkpoint_dir)):
    ```

  Args:
    save_path: The checkpoint file prefix.

  Yields:
    Nothing.

  Raises:
    NotFoundError: If the variable is not found in checkpoint.
    ValueError: If not used in eager mode.
  """
    if context.in_graph_mode():
        raise ValueError(
            "Currently, restore_variables_on_create can only be used with "
            "eager execution enabled.")
    if save_path:
        ckpt_var_cache = dict()
        reader = checkpoint_utils.load_checkpoint(save_path)
        for k, _ in checkpoint_utils.list_variables(save_path):
            ckpt_var_cache[k] = reader.get_tensor(k)

        old_init = getattr(resource_variable_ops.ResourceVariable,
                           "_init_from_args", None)
        assert old_init, "ResourceVariable misses _init_from_args method."
        setattr(resource_variable_ops.ResourceVariable, "_init_from_args",
                _init_from_checkpoint)
        setattr(resource_variable_ops.ResourceVariable, "old_init", old_init)
        setattr(resource_variable_ops.ResourceVariable, "ckpt_var_cache",
                ckpt_var_cache)
    try:
        yield
    except Exception as e:
        raise e
    finally:
        if save_path:
            setattr(resource_variable_ops.ResourceVariable, "_init_from_args",
                    old_init)
            setattr(resource_variable_ops.ResourceVariable, "old_init", None)
            setattr(resource_variable_ops.ResourceVariable, "ckpt_var_cache",
                    None)
Example #14
def restore_variables_on_create(save_path):
  """ContextManager that restores variables on creation.

    When save_path is None (e.g. No checkpoint), does nothing.
    Otherwise, it preloads all values from checkpoint. When the
    corresponding variable is first created, it assigns the checkpoint
    value to the variable.

    ```python
    with restore_variables_on_create(
        tf.train.latest_checkpoint(checkpoint_dir)):
    ```

  Args:
    save_path: The checkpoint file prefix.

  Yields:
    Nothing.

  Raises:
    NotFoundError: If the variable is not found in checkpoint.
    ValueError: If not used in eager mode.
  """
  if context.in_graph_mode():
    raise ValueError(
        "Currently, restore_variables_on_create can only be used with "
        "eager execution enabled.")
  if save_path:
    ckpt_var_cache = dict()
    reader = checkpoint_utils.load_checkpoint(save_path)
    for k, _ in checkpoint_utils.list_variables(save_path):
      ckpt_var_cache[k] = reader.get_tensor(k)

    old_init = getattr(
        resource_variable_ops.ResourceVariable, "_init_from_args", None)
    assert old_init, "ResourceVariable misses _init_from_args method."
    setattr(resource_variable_ops.ResourceVariable, "_init_from_args",
            _init_from_checkpoint)
    setattr(resource_variable_ops.ResourceVariable, "old_init", old_init)
    setattr(resource_variable_ops.ResourceVariable, "ckpt_var_cache",
            ckpt_var_cache)
  try:
    yield
  except Exception as e:
    raise e
  finally:
    if save_path:
      setattr(resource_variable_ops.ResourceVariable, "_init_from_args",
              old_init)
      setattr(resource_variable_ops.ResourceVariable, "old_init", None)
      setattr(resource_variable_ops.ResourceVariable, "ckpt_var_cache", None)
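A usage sketch for the context manager above. It assumes eager execution, the @contextlib.contextmanager decorator and the _init_from_checkpoint helper from the original module, and a hypothetical checkpoint_dir whose checkpoint stores a variable under the created name:

checkpoint_dir = '/tmp/my_model'  # hypothetical directory containing a checkpoint
with restore_variables_on_create(
    tf.train.latest_checkpoint(checkpoint_dir)):
  # A variable created here is initialized from the checkpoint entry with the
  # same name ('kernel' in this sketch) instead of from its initializer.
  kernel = tf.get_variable('kernel', shape=[10, 10])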
Example #15
  def test_check_checkpoint_variable_names_are_same_on_cpu_and_tpu(self,
                                                                   optimizer):
    # Reinitialize the TPU so that we can re-initialize the embeddings with the
    # given optimizer.
    tpu_strategy_util.initialize_tpu_system(self.resolver)
    optimizer = optimizer(learning_rate=0.1)

    with self.strategy.scope():
      tpu_mid_level = self.build_mid_level(
          self.first_mid_level_contents, optimizer)

    tpu_checkpoint = util.Checkpoint(model=tpu_mid_level)
    tpu_checkpoint.save(_get_tmpdir('save-tpu', 'save'))
    tpu_variables = checkpoint_utils.list_variables(_get_tmpdir('save-tpu'))

    cpu_mid_level = self.build_mid_level(
        self.first_mid_level_contents, optimizer)

    cpu_checkpoint = util.Checkpoint(model=cpu_mid_level)
    cpu_checkpoint.save(_get_tmpdir('save-cpu', 'save'))
    cpu_variables = checkpoint_utils.list_variables(_get_tmpdir('save-cpu'))

    self.assertAllEqual(tpu_variables, cpu_variables)
Example #16
    def testFSPath(self):
        checkpoint_dir = pathlib.Path(self.get_temp_dir())
        with self.cached_session() as session:
            v1, v2, v3, v4 = _create_checkpoints(session, checkpoint_dir)  # pylint: disable=unused-variable

        reader = checkpoint_utils.load_checkpoint(checkpoint_dir)
        self.assertAllEqual(reader.get_tensor("var1"), v1)

        self.assertAllEqual(
            checkpoint_utils.load_variable(checkpoint_dir, "var1"), v1)

        self.assertEqual(checkpoint_utils.list_variables(checkpoint_dir),
                         [("useful_scope/var4", [9, 9]), ("var1", [1, 10]),
                          ("var2", [10, 10]), ("var3", [100, 100])])
Example #17
  def make_checkpoint_and_get_embedding(self, name, model):
    """Saves model to checkpoint name, retrieves embedding variables."""
    checkpoint = util.Checkpoint(model=model)
    checkpoint.save(_get_tmpdir(name, 'save'))

    # Get the name of the parameters variable which should be the only
    # [self.num_rows, 4] shaped tensor in the checkpoint. Note that we do this
    # as the key can change.
    variables = checkpoint_utils.list_variables(_get_tmpdir(name))
    variables = [name for name, size in variables if size == [self.num_rows, 4]]
    if len(variables) != 1:
      raise RuntimeError('Found {} copies of the parameter variable in the '
                         'checkpoint. Exactly one copy was expected.'.format(
                             len(variables)))
    return checkpoint_utils.load_variable(_get_tmpdir(name), variables[0])
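The lookup-by-shape trick above can be factored into a small helper. A sketch using the public tf.train aliases; the shape is passed in explicitly:

import tensorflow as tf

def find_variables_with_shape(ckpt_dir, shape):
  # Return the checkpoint keys whose stored shape matches `shape`; useful when
  # object-based checkpoint key names are not stable across versions.
  shape = list(shape)
  return [name for name, s in tf.train.list_variables(ckpt_dir) if s == shape]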
Example #18
 def __init__(self, tensorflow_model_path, max_seq_length, encoder_layers, num_attentions, caffe_model_path_prefix, calc=False):
     self.scopes = ["" for i in range(10)]
     self.scopes[0] = "bert"
     self.tensorflow_model = pywrap_tensorflow.NewCheckpointReader(tensorflow_model_path)
     self.caffe_model = caffe_net.CaffeModel('')
     self.caffe_model.net.name = "bert"
     self.caffe_model_path_prefix = caffe_model_path_prefix
     self.max_seq_length = max_seq_length
     self.encoder_layers = encoder_layers
     self.num_attentions = num_attentions
     self.tensor_map = checkpoint_utils.list_variables(tensorflow_model_path)
     self.batch = 1
     self.check = True
     self.name_dict = {}
     self.calculate = True
     self.data_dict = {}
     Operators.set_calculate(calc)
Example #19
    def testTrainSpinn(self):
        """Test with fake toy SNLI data and GloVe vectors."""

        # 1. Create and load a fake SNLI data file and a fake GloVe embedding file.
        snli_1_0_dir = os.path.join(self._temp_data_dir, "snli/snli_1.0")
        fake_train_file = self._create_test_data(snli_1_0_dir)

        vocab = data.load_vocabulary(self._temp_data_dir)
        word2index, embed = data.load_word_vectors(self._temp_data_dir, vocab)

        train_data = data.SnliData(fake_train_file, word2index)
        dev_data = data.SnliData(fake_train_file, word2index)
        test_data = data.SnliData(fake_train_file, word2index)

        # 2. Create a fake config.
        config = _test_spinn_config(data.WORD_VECTOR_LEN,
                                    4,
                                    logdir=os.path.join(
                                        self._temp_data_dir, "logdir"))

        # 3. Test training of a SPINN model.
        trainer = spinn.train_or_infer_spinn(embed, word2index, train_data,
                                             dev_data, test_data, config)

        # 4. Load train loss values from the summary files and verify that they
        #    decrease with training.
        summary_file = glob.glob(os.path.join(config.logdir,
                                              "events.out.*"))[0]
        events = summary_test_util.events_from_file(summary_file)
        train_losses = [
            event.summary.value[0].simple_value for event in events if
            event.summary.value and event.summary.value[0].tag == "train/loss"
        ]
        self.assertEqual(config.epochs, len(train_losses))
        self.assertLess(train_losses[-1], train_losses[0])

        # 5. Verify that checkpoints exist and contain all the expected variables.
        self.assertTrue(glob.glob(os.path.join(config.logdir, "ckpt*")))
        ckpt_variable_names = [
            item[0] for item in checkpoint_utils.list_variables(config.logdir)
        ]
        self.assertIn("global_step", ckpt_variable_names)
        for v in trainer.variables:
            variable_name = (v.name[:v.name.index(":")]
                             if ":" in v.name else v.name)
            self.assertIn(variable_name, ckpt_variable_names)
Example #20
    def maybe_restore_on_create(self, save_path):
        """ContextManager that restores variables on creation.

      When save_path is None (e.g. No checkpoint), does nothing.
      Otherwise, it preloads all values from checkpoint. When the
      corresponding variable is first created, it assigns the checkpoint
      value to the variable.

    Args:
      save_path: Same as save_path of restore. If None, do not restore.

    Yields:
      Nothing.

    Raises:
      NotFoundError: If the variable is not found in checkpoint.
    """
        if save_path:
            ckpt_var_cache = dict()
            reader = checkpoint_utils.load_checkpoint(save_path)
            for k, _ in checkpoint_utils.list_variables(save_path):
                ckpt_var_cache[k] = reader.get_tensor(k)

            old_init = getattr(resource_variable_ops.ResourceVariable,
                               "_init_from_args", None)
            assert old_init, "ResourceVariable misses _init_from_args method."
            setattr(resource_variable_ops.ResourceVariable, "_init_from_args",
                    _init_from_checkpoint)
            setattr(resource_variable_ops.ResourceVariable, "old_init",
                    old_init)
            setattr(resource_variable_ops.ResourceVariable, "ckpt_var_cache",
                    ckpt_var_cache)
        try:
            yield
        except Exception as e:
            raise e
        finally:
            if save_path:
                setattr(resource_variable_ops.ResourceVariable,
                        "_init_from_args", old_init)
                setattr(resource_variable_ops.ResourceVariable, "old_init",
                        None)
                setattr(resource_variable_ops.ResourceVariable,
                        "ckpt_var_cache", None)
Example #21
def get_shape(filename):
    a = cp.list_variables(filename)
    if a[0][0] == 'Variable':
        ni = a[0][1][0] # number of inputs
        no = a[-1][1][0]  # number of outputs (a[-1] is the last bias; a[-2] would be the last weight)
    else:
        ai = []
        bi = []
        for i in range(len(a)):
            if 'w' in a[i][0]:
                ai.append(a[i])
            elif 'b' in a[i][0]:
                bi.append(a[i])
            else:
                print('Names assigned to the variables cannot be recognized')
        ni = ai[0][1][1] 
        no = ai[-1][1][0]
    nl = int(len(a)/2) # number of layers
    return ni,no,nl
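A usage sketch for get_shape, assuming cp is an alias for tensorflow.python.training.checkpoint_utils (or tf.train) and a hypothetical checkpoint whose variables follow one of the two naming conventions the function expects ('Variable' pairs, or per-layer names containing 'w'/'b'):

ni, no, nl = get_shape('model.ckpt')  # hypothetical checkpoint prefix
print('inputs:', ni, 'outputs:', no, 'layers:', nl)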
Example #22
  def _assert_checkpoint(self,
                         label_dimension,
                         expected_global_step,
                         expected_bias=None):
    shapes = {
        name: shape
        for (name, shape) in checkpoint_utils.list_variables(self._model_dir)
    }

    self.assertEqual([], shapes[ops.GraphKeys.GLOBAL_STEP])
    self.assertEqual(expected_global_step,
                     checkpoint_utils.load_variable(self._model_dir,
                                                    ops.GraphKeys.GLOBAL_STEP))

    self.assertEqual([label_dimension], shapes[BIAS_NAME])
    if expected_bias is not None:
      self.assertEqual(expected_bias,
                       checkpoint_utils.load_variable(self._model_dir,
                                                      BIAS_NAME))
Example #23
  def _assert_checkpoint(self,
                         label_dimension,
                         expected_global_step,
                         expected_bias=None):
    shapes = {
        name: shape
        for (name, shape) in checkpoint_utils.list_variables(self._model_dir)
    }

    self.assertEqual([], shapes[ops.GraphKeys.GLOBAL_STEP])
    self.assertEqual(expected_global_step,
                     checkpoint_utils.load_variable(self._model_dir,
                                                    ops.GraphKeys.GLOBAL_STEP))

    self.assertEqual([label_dimension], shapes[BIAS_NAME])
    if expected_bias is not None:
      self.assertEqual(expected_bias,
                       checkpoint_utils.load_variable(self._model_dir,
                                                      BIAS_NAME))
Example #24
 def __init__(self,
              tensorflow_model_path,
              caffe_model_path_prefix,
              caffe_model_name,
              check=False,
              calc=False):
     self.scopes = ["" for i in range(100)]
     self.tensorflow_model = pywrap_tensorflow.NewCheckpointReader(
         tensorflow_model_path)
     self.caffe_model = caffe_net.CaffeModel('')
     self.caffe_model.net.name = caffe_model_name
     self.caffe_model_path_prefix = caffe_model_path_prefix
     self.tensor_map = checkpoint_utils.list_variables(
         tensorflow_model_path)
     self.batch = 1
     self.check = check
     self.name_dict = {}
     self.calculate = True
     self.data_dict = {}
     Operators.set_calculate(calc)
Example #25
  def testTrainSpinn(self):
    """Test with fake toy SNLI data and GloVe vectors."""

    # 1. Create and load a fake SNLI data file and a fake GloVe embedding file.
    snli_1_0_dir = os.path.join(self._temp_data_dir, "snli/snli_1.0")
    fake_train_file = self._create_test_data(snli_1_0_dir)

    vocab = data.load_vocabulary(self._temp_data_dir)
    word2index, embed = data.load_word_vectors(self._temp_data_dir, vocab)

    train_data = data.SnliData(fake_train_file, word2index)
    dev_data = data.SnliData(fake_train_file, word2index)
    test_data = data.SnliData(fake_train_file, word2index)

    # 2. Create a fake config.
    config = _test_spinn_config(
        data.WORD_VECTOR_LEN, 4,
        logdir=os.path.join(self._temp_data_dir, "logdir"))

    # 3. Test training of a SPINN model.
    trainer = spinn.train_or_infer_spinn(
        embed, word2index, train_data, dev_data, test_data, config)

    # 4. Load train loss values from the summary files and verify that they
    #    decrease with training.
    summary_file = glob.glob(os.path.join(config.logdir, "events.out.*"))[0]
    events = summary_test_util.events_from_file(summary_file)
    train_losses = [event.summary.value[0].simple_value for event in events
                    if event.summary.value
                    and event.summary.value[0].tag == "train/loss"]
    self.assertEqual(config.epochs, len(train_losses))
    self.assertLess(train_losses[-1], train_losses[0])

    # 5. Verify that checkpoints exist and contain all the expected variables.
    self.assertTrue(glob.glob(os.path.join(config.logdir, "ckpt*")))
    ckpt_variable_names = [
        item[0] for item in checkpoint_utils.list_variables(config.logdir)]
    self.assertIn("global_step", ckpt_variable_names)
    for v in trainer.variables:
      variable_name = v.name[:v.name.index(":")] if ":" in v.name else v.name
      self.assertIn(variable_name, ckpt_variable_names)
Example #26
def _assert_checkpoint(
    testcase, global_step, input_units, hidden_units, output_units, model_dir):
  """Asserts checkpoint contains expected variables with proper shapes.

  Args:
    testcase: A TestCase instance.
    global_step: Expected global step value.
    input_units: The dimension of input layer.
    hidden_units: Iterable of integer sizes for the hidden layers.
    output_units: The dimension of output layer (logits).
    model_dir: The model directory.
  """
  shapes = {
      name: shape
      for (name, shape) in checkpoint_utils.list_variables(model_dir)
  }

  # Global step.
  testcase.assertEqual([], shapes[ops.GraphKeys.GLOBAL_STEP])
  testcase.assertEqual(
      global_step,
      checkpoint_utils.load_variable(
          model_dir, ops.GraphKeys.GLOBAL_STEP))

  # Hidden layer weights.
  prev_layer_units = input_units
  for i in range(len(hidden_units)):
    layer_units = hidden_units[i]
    testcase.assertAllEqual(
        (prev_layer_units, layer_units),
        shapes[HIDDEN_WEIGHTS_NAME_PATTERN % i])
    testcase.assertAllEqual(
        (layer_units,),
        shapes[HIDDEN_BIASES_NAME_PATTERN % i])
    prev_layer_units = layer_units

  # Output layer weights.
  testcase.assertAllEqual((prev_layer_units, output_units),
                          shapes[LOGITS_WEIGHTS_NAME])
  testcase.assertAllEqual((output_units,),
                          shapes[LOGITS_BIASES_NAME])
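Outside a tf.test.TestCase, the same shape checks can be written with plain asserts. A sketch that assumes the canned DNN Estimator variable layout; the dnn/hiddenlayer_%d/... and dnn/logits/... names are an assumption and should be adjusted to the inspected model:

import tensorflow as tf

def check_dnn_checkpoint(model_dir, input_units, hidden_units, output_units):
  # Map variable names to shapes, then walk the layers in order.
  shapes = dict(tf.train.list_variables(model_dir))
  prev_units = input_units
  for i, units in enumerate(hidden_units):
    assert shapes['dnn/hiddenlayer_%d/kernel' % i] == [prev_units, units]
    assert shapes['dnn/hiddenlayer_%d/bias' % i] == [units]
    prev_units = units
  assert shapes['dnn/logits/kernel'] == [prev_units, output_units]
  assert shapes['dnn/logits/bias'] == [output_units]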
Example #27
def _assert_checkpoint(
    testcase, global_step, input_units, hidden_units, output_units, model_dir):
  """Asserts checkpoint contains expected variables with proper shapes.

  Args:
    testcase: A TestCase instance.
    global_step: Expected global step value.
    input_units: The dimension of input layer.
    hidden_units: Iterable of integer sizes for the hidden layers.
    output_units: The dimension of output layer (logits).
    model_dir: The model directory.
  """
  shapes = {
      name: shape
      for (name, shape) in checkpoint_utils.list_variables(model_dir)
  }

  # Global step.
  testcase.assertEqual([], shapes[ops.GraphKeys.GLOBAL_STEP])
  testcase.assertEqual(
      global_step,
      checkpoint_utils.load_variable(
          model_dir, ops.GraphKeys.GLOBAL_STEP))

  # Hidden layer weights.
  prev_layer_units = input_units
  for i in range(len(hidden_units)):
    layer_units = hidden_units[i]
    testcase.assertAllEqual(
        (prev_layer_units, layer_units),
        shapes[dnn_testing_utils.HIDDEN_WEIGHTS_NAME_PATTERN % i])
    testcase.assertAllEqual(
        (layer_units,),
        shapes[dnn_testing_utils.HIDDEN_BIASES_NAME_PATTERN % i])
    prev_layer_units = layer_units

  # Output layer weights.
  testcase.assertAllEqual((prev_layer_units, output_units),
                          shapes[dnn_testing_utils.LOGITS_WEIGHTS_NAME])
  testcase.assertAllEqual((output_units,),
                          shapes[dnn_testing_utils.LOGITS_BIASES_NAME])
Example #28
  def maybe_restore_on_create(self, save_path):
    """ContextManager that restores variables on creation.

      When save_path is None (e.g. No checkpoint), does nothing.
      Otherwise, it preloads all values from checkpoint. When the
      corresponding variable is first created, it assigns the checkpoint
      value to the variable.

    Args:
      save_path: Same as save_path of restore. If None, do not restore.

    Yields:
      Nothing.

    Raises:
      NotFoundError: If the variable is not found in checkpoint.
    """
    if save_path:
      ckpt_var_cache = dict()
      reader = checkpoint_utils.load_checkpoint(save_path)
      for k, _ in checkpoint_utils.list_variables(save_path):
        ckpt_var_cache[k] = reader.get_tensor(k)

      old_init = getattr(
          resource_variable_ops.ResourceVariable, "_init_from_args", None)
      assert old_init, "ResourceVariable misses _init_from_args method."
      setattr(resource_variable_ops.ResourceVariable, "_init_from_args",
              _init_from_checkpoint)
      setattr(resource_variable_ops.ResourceVariable, "old_init", old_init)
      setattr(resource_variable_ops.ResourceVariable, "ckpt_var_cache",
              ckpt_var_cache)
    try:
      yield
    except Exception as e:
      raise e
    finally:
      if save_path:
        setattr(resource_variable_ops.ResourceVariable, "_init_from_args",
                old_init)
        setattr(resource_variable_ops.ResourceVariable, "old_init", None)
        setattr(resource_variable_ops.ResourceVariable, "ckpt_var_cache", None)
Example #29
 def _set_restore_on_create(self, save_path, map_func, user_map_func,
                            existing_variables_by_checkpoint_name):
     """If necessary, request deferred restorations of variables."""
     checkpoint_reader = checkpoint_utils.load_checkpoint(save_path)
     checkpointed_variables_to_restore = {}
     for checkpoint_name, _ in checkpoint_utils.list_variables(save_path):
         if checkpoint_name in existing_variables_by_checkpoint_name:
             # This variable was already created and restored.
             continue
         # Save the variable for later restoration in a custom getter.
         checkpointed_variables_to_restore[checkpoint_name] = (
             checkpoint_reader.get_tensor(checkpoint_name))
     # Only set a deferred restoration if there are checkpoint variables which
     # have not been assigned to existing variables. Note that this loses out on
     # some opportunity for error checking, but avoids creating
     # _DeferredRestoration objects once a Network has been built (so that
     # restoring in a loop does not take increasing amounts of memory).
     if checkpointed_variables_to_restore:
         if context.in_eager_mode():
             sess = None
         else:
             sess = ops.get_default_session()
         # We need a name for error messages. If we haven't been added to another
         # Network yet, we're top-level.
         self._finalize_name(False)
         self._set_scope()
         # Save a record of this restoration for use in the custom getter.
         deferred_restoration = _DeferredRestoration(
             map_func=map_func,
             map_func_is_user=(user_map_func is not None),
             checkpointed_variables_to_restore=
             checkpointed_variables_to_restore,
             restored_variables={},
             session=sess,
             network_name=self.name,
             network_scope_name=self.scope_name)
         self._deferred_restorations.append(deferred_restoration)
         # Add the deferred restoration to non-Network children, and request that
         # Networks propagate the request to their children.
         self._add_deferred_restoration(deferred_restoration)
Example #30
 def _set_restore_on_create(self, save_path, map_func, user_map_func,
                            existing_variables_by_checkpoint_name):
   """If necessary, request deferred restorations of variables."""
   checkpoint_reader = checkpoint_utils.load_checkpoint(save_path)
   checkpointed_variables_to_restore = {}
   for checkpoint_name, _ in checkpoint_utils.list_variables(save_path):
     if checkpoint_name in existing_variables_by_checkpoint_name:
       # This variable was already created and restored.
       continue
     # Save the variable for later restoration in a custom getter.
     checkpointed_variables_to_restore[checkpoint_name] = (
         checkpoint_reader.get_tensor(checkpoint_name))
   # Only set a deferred restoration if there are checkpoint variables which
   # have not been assigned to existing variables. Note that this loses out on
   # some opportunity for error checking, but avoids creating
   # _DeferredRestoration objects once a Network has been built (so that
   # restoring in a loop does not take increasing amounts of memory).
   if checkpointed_variables_to_restore:
     if context.in_eager_mode():
       sess = None
     else:
       sess = ops.get_default_session()
     # We need a name for error messages. If we haven't been added to another
     # Network yet, we're top-level.
     self._finalize_name(False)
     self._set_scope()
     # Save a record of this restoration for use in the custom getter.
     deferred_restoration = _DeferredRestoration(
         map_func=map_func,
         map_func_is_user=(user_map_func is not None),
         checkpointed_variables_to_restore=checkpointed_variables_to_restore,
         restored_variables={},
         session=sess,
         network_name=self.name,
         network_scope_name=self.scope_name)
     self._deferred_restorations.append(deferred_restoration)
      # Add the deferred restoration to non-Network children, and request that
     # Networks propagate the request to their children.
     self._add_deferred_restoration(deferred_restoration)
Example #31
    def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
        """The `model_fn` for TPUEstimator."""

        tf.logging.info("*** Features ***")
        for name in sorted(features.keys()):
            tf.logging.info("  name = %s, shape = %s" %
                            (name, features[name].shape))

        input_ids = features["input_ids"]
        input_mask = features["input_mask"]
        segment_ids = features["segment_ids"]
        masked_lm_positions = features["masked_lm_positions"]
        masked_lm_ids = features["masked_lm_ids"]
        masked_lm_weights = features["masked_lm_weights"]
        next_sentence_labels = features["next_sentence_labels"]

        is_training = (mode == tf.estimator.ModeKeys.TRAIN)

        model = modeling.BertModel(
            config=bert_config,
            is_training=is_training,
            input_ids=input_ids,
            input_mask=input_mask,
            token_type_ids=segment_ids,
            use_one_hot_embeddings=use_one_hot_embeddings)

        (masked_lm_loss,
         masked_lm_example_loss, masked_lm_log_probs, output_bias) = \
          get_masked_lm_output(
             bert_config, model.get_sequence_output(), model.get_embedding_table(),
             masked_lm_positions, masked_lm_ids, masked_lm_weights)

        (next_sentence_loss, next_sentence_example_loss,
         next_sentence_log_probs) = get_next_sentence_output(
             bert_config, model.get_pooled_output(), next_sentence_labels)

        total_loss = masked_lm_loss + next_sentence_loss

        tvars = tf.trainable_variables()

        initialized_variable_names = {}
        scaffold_fn = None
        if init_checkpoint:
            # Replace the embedding weights.
            if full_vocab_file:

                def pick_sub_set(embeddings, full_vocab_file, new_vocab_file):
                    old_vocab = [
                        line.strip()
                        for line in open(full_vocab_file, encoding='utf-8')
                    ]
                    old_vocab_inv = {
                        old_vocab[i]: i
                        for i in range(len(old_vocab))
                    }
                    _new_vocab = [
                        line.strip()
                        for line in open(new_vocab_file, encoding='utf-8')
                    ]
                    if len(_new_vocab) != len(set(_new_vocab)):
                        tf.logging.ERROR("Dupllicated entries in %s." %
                                         new_vocab_file)
                    new_vocab = [x for x in _new_vocab if x in old_vocab_inv]
                    if len(new_vocab) != len(_new_vocab):
                        tf.logging.error(
                            "There are OOVs in %s that are not in %s." %
                            (new_vocab_file, full_vocab_file))
                    ids = [old_vocab_inv[c] for c in new_vocab]
                    return np.asarray([embeddings[x] for x in ids])

                from tensorflow.python.training import checkpoint_utils
                ckpt_variables = dict(
                    checkpoint_utils.list_variables(init_checkpoint))
                v1 = model.embedding_table
                v2 = output_bias

                tvars = [x for x in tvars if x.name not in [v1.name, v2.name]]
                for v in [v1, v2]:
                    with tf.device(v.device), tf.device("/cpu:0"):
                        constant_op = tf.constant(
                            pick_sub_set(
                                checkpoint_utils.load_variable(
                                    init_checkpoint, v.op.name),
                                full_vocab_file, subset_vocab_file))
                        v._initializer_op = v.assign(constant_op)
                        v._initial_value = constant_op

            (assignment_map, initialized_variable_names
             ) = modeling.get_assignment_map_from_checkpoint(
                 tvars, init_checkpoint)
            if use_tpu:

                def tpu_scaffold():
                    tf.train.init_from_checkpoint(init_checkpoint,
                                                  assignment_map)
                    return tf.train.Scaffold()

                scaffold_fn = tpu_scaffold
            else:
                tf.train.init_from_checkpoint(init_checkpoint, assignment_map)

        tf.logging.info("**** Trainable Variables ****")
        for var in tvars:
            init_string = ""
            if var.name in initialized_variable_names:
                init_string = ", *INIT_FROM_CKPT*"
            tf.logging.info("  name = %s, shape = %s%s", var.name, var.shape,
                            init_string)

        output_spec = None
        if mode == tf.estimator.ModeKeys.TRAIN:
            train_op = optimization.create_optimizer(total_loss, learning_rate,
                                                     num_train_steps,
                                                     num_warmup_steps, use_tpu)

            output_spec = tf.contrib.tpu.TPUEstimatorSpec(
                mode=mode,
                loss=total_loss,
                train_op=train_op,
                scaffold_fn=scaffold_fn)
        elif mode == tf.estimator.ModeKeys.EVAL:

            def metric_fn(masked_lm_example_loss, masked_lm_log_probs,
                          masked_lm_ids, masked_lm_weights,
                          next_sentence_example_loss, next_sentence_log_probs,
                          next_sentence_labels):
                """Computes the loss and accuracy of the model."""
                with tf.name_scope("metric"):
                    masked_lm_log_probs = tf.reshape(
                        masked_lm_log_probs,
                        [-1, masked_lm_log_probs.shape[-1]])
                    masked_lm_predictions = tf.argmax(masked_lm_log_probs,
                                                      axis=-1,
                                                      output_type=tf.int32)
                    masked_lm_example_loss = tf.reshape(
                        masked_lm_example_loss, [-1])
                    masked_lm_ids = tf.reshape(masked_lm_ids, [-1])
                    masked_lm_weights = tf.reshape(masked_lm_weights, [-1])
                    masked_lm_accuracy = tf.metrics.accuracy(
                        labels=masked_lm_ids,
                        predictions=masked_lm_predictions,
                        weights=masked_lm_weights)
                    masked_lm_mean_loss = tf.metrics.mean(
                        values=masked_lm_example_loss,
                        weights=masked_lm_weights)

                    next_sentence_log_probs = tf.reshape(
                        next_sentence_log_probs,
                        [-1, next_sentence_log_probs.shape[-1]])
                    next_sentence_predictions = tf.argmax(
                        next_sentence_log_probs, axis=-1, output_type=tf.int32)
                    next_sentence_labels = tf.reshape(next_sentence_labels,
                                                      [-1])
                    next_sentence_accuracy = tf.metrics.accuracy(
                        labels=next_sentence_labels,
                        predictions=next_sentence_predictions)
                    next_sentence_mean_loss = tf.metrics.mean(
                        values=next_sentence_example_loss)

                return {
                    "masked_lm_accuracy": masked_lm_accuracy,
                    "masked_lm_loss": masked_lm_mean_loss,
                    "next_sentence_accuracy": next_sentence_accuracy,
                    "next_sentence_loss": next_sentence_mean_loss,
                }

            eval_metrics = (metric_fn, [
                masked_lm_example_loss, masked_lm_log_probs, masked_lm_ids,
                masked_lm_weights, next_sentence_example_loss,
                next_sentence_log_probs, next_sentence_labels
            ])
            output_spec = tf.contrib.tpu.TPUEstimatorSpec(
                mode=mode,
                loss=total_loss,
                eval_metrics=eval_metrics,
                scaffold_fn=scaffold_fn)

        elif mode == tf.estimator.ModeKeys.PREDICT:
            """Computes the loss and accuracy of the model."""
            with tf.name_scope("predict"):
                masked_lm_log_probs = tf.reshape(
                    masked_lm_log_probs, (-1, FLAGS.max_predictions_per_seq,
                                          masked_lm_log_probs.shape[-1]))
                masked_lm_predictions = tf.argmax(masked_lm_log_probs,
                                                  axis=-1,
                                                  output_type=tf.int32)

                next_sentence_predictions = tf.argmax(next_sentence_log_probs,
                                                      axis=-1,
                                                      output_type=tf.int32)

            predictions = {
                "input_ids": input_ids,
                "input_mask": input_mask,
                "masked_lm_positions": masked_lm_positions,
                "masked_lm_ids": masked_lm_ids,
                "masked_lm_weights": masked_lm_weights,
                "masked_lm_log_probs": masked_lm_log_probs,
                "masked_lm_predictions": masked_lm_predictions,
                "next_sentence_predictions": next_sentence_predictions,
            }

            output_spec = tf.estimator.EstimatorSpec(mode=mode,
                                                     loss=total_loss,
                                                     predictions=predictions)
        else:
            raise ValueError("Only TRAIN and EVAL modes are supported: %s" %
                             (mode))

        return output_spec
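The row-selection idea inside pick_sub_set can be illustrated standalone with NumPy. A sketch with made-up tokens and a made-up embedding matrix:

import numpy as np

def pick_vocab_subset(embeddings, full_vocab, new_vocab):
    # Select the embedding rows for new_vocab, assuming every token also appears
    # in full_vocab (which indexes the rows of embeddings).
    index = {tok: i for i, tok in enumerate(full_vocab)}
    return np.asarray([embeddings[index[tok]] for tok in new_vocab])

emb = np.arange(12, dtype=np.float32).reshape(4, 3)  # 4 tokens, embedding dim 3
print(pick_vocab_subset(emb, ['[PAD]', 'a', 'b', 'c'], ['b', 'a']))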
Example #32
def restore_variables_on_create(save_path, map_func=None):
  """ContextManager that restores variables on creation.

    When save_path is None (e.g. No checkpoint), does nothing.
    Otherwise, it preloads all values from checkpoint. When the
    corresponding variable is first created, it assigns the checkpoint
    value to the variable.

    ```python
    with restore_variables_on_create(
        tf.train.latest_checkpoint(checkpoint_dir)):
    ```

  Args:
    save_path: The checkpoint file prefix.
    map_func: A function that given the variable name as argument
        and returns a variable name in checkpoint for restore. If
        None, use the variable with the same name in checkpoint to restore.
        It's an error that the mapped variable name doesn't exist in
        checkpoint.

  Yields:
    Nothing.

  Raises:
    NotFoundError: If the variable is not found in checkpoint.
    ValueError: If not used in eager mode or map_func is not callable.
  """
  if not context.executing_eagerly():
    raise ValueError(
        "Currently, restore_variables_on_create can only be used with "
        "eager execution enabled.")
  if save_path:
    if map_func is None:
      map_func_wrapper = lambda self, x: x
    else:
      if not callable(map_func):
        raise ValueError("map_func must be callable.")
      map_func_wrapper = lambda self, x: map_func(x)

    ckpt_var_cache = dict()
    reader = checkpoint_utils.load_checkpoint(save_path)
    for k, _ in checkpoint_utils.list_variables(save_path):
      ckpt_var_cache[k] = reader.get_tensor(k)

    old_init = getattr(resource_variable_ops.ResourceVariable,
                       "_init_from_args", None)
    assert old_init, "ResourceVariable misses _init_from_args method."
    setattr(resource_variable_ops.ResourceVariable, "_init_from_args",
            _init_from_checkpoint)
    setattr(resource_variable_ops.ResourceVariable, "_old_init", old_init)
    setattr(resource_variable_ops.ResourceVariable, "_map_func",
            map_func_wrapper)
    setattr(resource_variable_ops.ResourceVariable, "_ckpt_var_cache",
            ckpt_var_cache)
  try:
    yield
  except Exception as e:
    raise e
  finally:
    if save_path:
      setattr(resource_variable_ops.ResourceVariable, "_init_from_args",
              old_init)
      setattr(resource_variable_ops.ResourceVariable, "_old_init", None)
      setattr(resource_variable_ops.ResourceVariable, "_map_func", None)
      setattr(resource_variable_ops.ResourceVariable, "_ckpt_var_cache", None)
Example #33
def tf2caffe():
    checkpoint_path = "./VGGnet_fast_rcnn_iter_70000.ckpt"
    tensorName = cp.list_variables(checkpoint_path)

    cf_prototxt = "./vgg14faster-rcnn.prototxt"
    cf_model = "./vgg16faster0814.caffemodel"
    net = caffe.Net(cf_prototxt, caffe.TRAIN)
    
    for key_value in tensorName:
        key_i = key_value[0]
        nddary_data = cp.load_variable(checkpoint_path, key_i)
        try:

            if 'data' in key_i:
                pass
            elif 'weights' in key_i:
                a = key_i.split('/')
                if (len(a) == 2):
                    key_caffe = a[0]
                if (len(a) == 3):
                    key_caffe = "rpn_conv_3x3"
                if key_caffe == 'cls_score':
                    weights = tensor2d_transform(nddary_data)  # 2dim
                if key_caffe == 'bbox_pred':
                    weights = tensor2d_transform(nddary_data)  # 2dim
                if key_caffe == 'fc7':
                    weights = tensor2d_transform(nddary_data)  # 2dim
                if key_caffe == 'fc6':
                    weights = tensor2d_transform(nddary_data)  # 2dim

                if (nddary_data.ndim == 4):
                    if key_caffe == 'rpn_cls_score':
                        a = np.squeeze(nddary_data[0][0])
                        weights = tensor2d_transform(a)  # 2dim
                    elif key_caffe == 'rpn_bbox_pred':
                        a = np.squeeze(nddary_data[0][0])
                        weights = tensor2d_transform(a)  # 2dim
                    else:
                        weights = tensor4d_transform(nddary_data)
                net.params[key_caffe][0].data.flat = weights.flat
            elif 'biases' in key_i:
                a = key_i.split('/')
                if (len(a) == 2):
                    key_caffe = a[0]
                if (len(a) == 3):
                    key_caffe = "rpn_conv_3x3"
                net.params[key_caffe][1].data.flat = nddary_data.flat
            elif 'bn_gamma' in key_i:
                a = key_i.split('/')
                if (len(a) == 3):
                    key_caffe = a[1]
                else:
                    key_caffe = a[2]
                net.params[key_caffe][0].data.flat = nddary_data.flat
            elif '_gamma' in key_i:  # for prelu
                a = key_i.split('/')
                if (len(a) == 3):
                    key_caffe = a[1]
                else:
                    key_caffe = a[2]
                assert (len(net.params[key_caffe]) == 1)
                net.params[key_caffe][0].data.flat = nddary_data.flat
            elif 'mean_rgb' in key_i:
                pass
            elif 'global' in key_i:
                pass
            else:
                sys.exit("Warning!  Unknown tf:{}".format(key_i))

        except KeyError:
            print("\nWarning!  key error tf:{}".format(key_i))

    net.save(cf_model)
    print("\n- Finished.\n")
Example #34
def load_svhn_netparams_tf(ckpt_path, trainable=False):
    #data_dict = np.load(ckpt_path, encoding='latin1').item()
    data_dict = cp.list_variables(ckpt_path)
    #cp.load_variable(ckpt_path,'digit1/dense/bias')

    weights = {}  # kernel
    biases = {}  # bias
    mean = {}  # moving_mean
    variance = {}  # moving_variance
    scale = {}  # filled from 'beta' parameters below
    offset = {}  # filled from 'gamma' parameters below

    netparams = {}
    layer_names = []

    # get layers names
    layers_names = []
    layers_dict = {}
    for each in data_dict:
        words = each[0].split("/")
        if words[0] not in layers_names:
            layers_names.append(words[0])  # save all unique layers names
            layers_dict[words[0]] = []  # initialize a list for each layer

    for each in data_dict:
        words = each[0].split("/")
        layers_dict[words[0]].append(each[0])

    #print('layers_dict')
    #print(layers_dict)

    for layer_name in layers_dict:
        with tf.variable_scope(layer_name):
            for each in layers_dict[layer_name]:
                words = each.split("/")
                param_name = words[-1]

                data = cp.load_variable(ckpt_path, each)

                #print('#################: DEBUG')
                #print(each)
                #print(param_name)

                if param_name == 'kernel':
                    weights[layer_name] = tf.get_variable(
                        name=param_name,
                        initializer=tf.constant(data),
                        trainable=True)
                elif param_name == 'bias':
                    biases[layer_name] = tf.get_variable(
                        name=param_name,
                        initializer=tf.constant(data),
                        trainable=True)
                elif param_name == 'moving_mean':
                    mean[layer_name] = tf.get_variable(
                        name=param_name,
                        initializer=tf.constant(data),
                        trainable=True)
                elif param_name == 'moving_variance':
                    variance[layer_name] = tf.get_variable(
                        name=param_name,
                        initializer=tf.constant(data),
                        trainable=True)
                elif param_name == 'beta':
                    scale[layer_name] = tf.get_variable(
                        name=param_name,
                        initializer=tf.constant(data),
                        trainable=True)
                elif param_name == 'gamma':
                    offset[layer_name] = tf.get_variable(
                        name=param_name,
                        initializer=tf.constant(data),
                        trainable=True)

    #print(len(layer_names))
    netparams['weights'] = weights
    #print(len(weights))
    netparams['biases'] = biases
    netparams['mean'] = mean
    netparams['variance'] = variance
    netparams['scale'] = scale
    netparams['offset'] = offset
    #print(netparams['weights']['hidden1'])
    return netparams
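The grouping-by-scope step above generalizes to any checkpoint. A small sketch using the public tf.train aliases; the path is hypothetical:

import collections
import tensorflow as tf

def group_checkpoint_variables_by_scope(ckpt_path):
    # Map each top-level scope (the part of the name before the first '/') to
    # the (name, shape) entries stored under it.
    groups = collections.defaultdict(list)
    for name, shape in tf.train.list_variables(ckpt_path):
        groups[name.split('/')[0]].append((name, shape))
    return dict(groups)

print(group_checkpoint_variables_by_scope('/tmp/model.ckpt'))  # hypothetical prefix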