def testWarmStartVarWithVocabBothVarsPartitioned(self):
    """Checks vocab warm-starting when old and new variables are partitioned.

    A previous-run "fruit_weights" variable is created over a four-token
    vocabulary, then a six-row partitioned variable is warm-started against
    a new vocabulary that reverses the old tokens and appends two unseen
    ones. The test asserts the values found in each of the two resulting
    variable shards.
    """
    prev_vocab_path = self._write_vocab(
        ["apple", "banana", "guava", "orange"], "old_vocab")
    # NOTE(review): presumably this persists the variable where the
    # warm-start call below (given self.get_temp_dir()) can find it; the
    # helper is defined outside this block -- confirm.
    self._create_prev_run_var(
        "fruit_weights",
        shape=[4, 1],
        initializer=[[0.5], [1.], [1.5], [2.]],
        partitioner=lambda shape, dtype: [2, 1])

    # New vocab with elements in reverse order and two new elements.
    new_vocab_path = self._write_vocab(
        ["orange", "guava", "banana", "apple", "raspberry", "blueberry"],
        "new_vocab")

    # New session and new graph.
    with ops.Graph().as_default() as g:
        with self.test_session(graph=g) as sess:
            fruit_weights = variable_scope.get_variable(
                "fruit_weights",
                shape=[6, 1],
                initializer=[[0.], [0.], [0.], [0.], [0.], [0.]],
                partitioner=lambda shape, dtype: [2, 1])
            ws_util._warmstart_var_with_vocab(
                fruit_weights, new_vocab_path, 6, self.get_temp_dir(),
                prev_vocab_path)
            sess.run(variables.global_variables_initializer())

            # assertIsInstance reports the actual type on failure, unlike
            # assertTrue(isinstance(...)) which only says "False is not true".
            self.assertIsInstance(
                fruit_weights, variables.PartitionedVariable)
            fruit_weights_vars = fruit_weights._get_variable_list()

            # First shard: orange, guava, banana carry their old weights.
            self.assertAllEqual([[2.], [1.5], [1.]],
                                fruit_weights_vars[0].eval(sess))
            # Second shard: apple keeps its old weight; the two tokens absent
            # from the old vocab are expected to stay at zero.
            self.assertAllEqual([[0.5], [0.], [0.]],
                                fruit_weights_vars[1].eval(sess))
def testWarmStartVarWithVocabConstrainedOldVocabSize(self):
    """Checks vocab warm-starting when the old vocabulary size is capped.

    The previous run covers four tokens, but the warm-start call passes
    previous_vocab_size=2, so only the first two old tokens are expected to
    contribute weights to the new five-row variable.
    """
    old_tokens = ["apple", "banana", "guava", "orange"]
    old_vocab_file = self._write_vocab(old_tokens, "old_vocab")
    self._create_prev_run_var(
        "fruit_weights", initializer=[[0.5], [1.], [1.5], [2.]])

    # The new vocabulary reverses the old one and appends one unseen token.
    new_tokens = ["orange", "guava", "banana", "apple", "raspberry"]
    new_vocab_file = self._write_vocab(new_tokens, "new_vocab")
    new_vocab_size = len(new_tokens)

    # One zero row per new-vocabulary entry.
    zero_rows = [[0.] for _ in new_tokens]

    # Old vocabulary limited to ['apple', 'banana'].
    expected_weights = [[0.], [0.], [1.], [0.5], [0.]]

    # Everything below runs in a fresh graph with its own session.
    with ops.Graph().as_default() as graph, \
            self.test_session(graph=graph) as session:
        weights_var = variable_scope.get_variable(
            "fruit_weights", initializer=zero_rows)
        checkpoint_dir = self.get_temp_dir()
        ws_util._warmstart_var_with_vocab(
            weights_var,
            new_vocab_file,
            new_vocab_size,
            checkpoint_dir,
            old_vocab_file,
            previous_vocab_size=2)
        init_op = variables.global_variables_initializer()
        session.run(init_op)

        actual_weights = weights_var.eval(session)
        self.assertAllEqual(expected_weights, actual_weights)