Exemplo n.º 1
0
  def testWarmStartVarWithVocabBothVarsPartitioned(self):
    """Warm-starts a partitioned variable from a partitioned previous var.

    The previous run holds a [4, 1] "fruit_weights" variable created with a
    partitioner. The new graph declares a [6, 1] "fruit_weights" variable with
    the same partitioner, over a vocabulary that lists the old entries in
    reverse order followed by two entries the old vocabulary does not contain.
    After warm-starting, each old entry's weight is expected on that entry's
    row in the new vocabulary, and the two new entries are expected to keep
    the zero initial value.
    """
    # The same partitioner is used for the previous and the new variable.
    partitioner = lambda shape, dtype: [2, 1]

    prev_vocab_path = self._write_vocab(["apple", "banana", "guava", "orange"],
                                        "old_vocab")
    self._create_prev_run_var(
        "fruit_weights",
        shape=[4, 1],
        initializer=[[0.5], [1.], [1.5], [2.]],
        partitioner=partitioner)

    # New vocab with elements in reverse order and two new elements.
    new_vocab_path = self._write_vocab(
        ["orange", "guava", "banana", "apple", "raspberry",
         "blueberry"], "new_vocab")
    # New session and new graph.
    with ops.Graph().as_default() as g:
      with self.test_session(graph=g) as sess:
        fruit_weights = variable_scope.get_variable(
            "fruit_weights",
            shape=[6, 1],
            initializer=[[0.], [0.], [0.], [0.], [0.], [0.]],
            partitioner=partitioner)
        ws_util._warmstart_var_with_vocab(fruit_weights, new_vocab_path, 6,
                                          self.get_temp_dir(), prev_vocab_path)
        sess.run(variables.global_variables_initializer())
        # assertIsInstance reports the actual type on failure, unlike
        # assertTrue(isinstance(...)), which only reports "False is not true".
        self.assertIsInstance(fruit_weights, variables.PartitionedVariable)
        fruit_weights_vars = fruit_weights._get_variable_list()
        # First slice: rows for "orange", "guava", "banana".
        self.assertAllEqual([[2.], [1.5], [1.]],
                            fruit_weights_vars[0].eval(sess))
        # Second slice: row for "apple", then the two new (zero) entries.
        self.assertAllEqual([[0.5], [0.], [0.]],
                            fruit_weights_vars[1].eval(sess))
Exemplo n.º 2
0
  def testWarmStartVarWithVocabConstrainedOldVocabSize(self):
    """Warm-starts from a previous vocabulary restricted to its first entries.

    The previous run holds one weight per entry of a four-entry vocabulary.
    The warm-start call passes previous_vocab_size=2, so only the weights for
    "apple" and "banana" are expected to carry over; every other row of the
    new variable is expected to keep its zero initial value.
    """
    old_vocab_file = self._write_vocab(
        ["apple", "banana", "guava", "orange"], "old_vocab")
    old_weights = [[0.5], [1.], [1.5], [2.]]
    self._create_prev_run_var("fruit_weights", initializer=old_weights)

    # The new vocabulary reverses the old one and appends one unseen entry.
    new_vocab_file = self._write_vocab(
        ["orange", "guava", "banana", "apple", "raspberry"], "new_vocab")
    new_vocab_size = 5
    zero_weights = [[0.] for _ in range(new_vocab_size)]
    # Only "banana" (row 2) and "apple" (row 3) fall inside the truncated
    # old vocabulary.
    expected_weights = [[0.], [0.], [1.], [0.5], [0.]]

    # Build the new variable in a fresh graph with its own session.
    with ops.Graph().as_default() as graph:
      with self.test_session(graph=graph) as session:
        weights = variable_scope.get_variable(
            "fruit_weights", initializer=zero_weights)
        ws_util._warmstart_var_with_vocab(
            weights,
            new_vocab_file,
            new_vocab_size,
            self.get_temp_dir(),
            old_vocab_file,
            previous_vocab_size=2)
        session.run(variables.global_variables_initializer())
        self.assertAllEqual(expected_weights, weights.eval(session))