Beispiel #1
0
 def testVocabVariableUpdate(self):
   """update_variable_and_slots remaps a variable and its "m"/"v" slots.

   `mapping[i]` is the reference index feeding new entry i, or -1 when the
   entry has no counterpart. After the update, the new variable and each of
   its "m" and "v" optimizer slots must equal `expected`: mapped entries carry
   the reference value and entries mapped to -1 are 0.
   """
   ref_variable, ref_optimizer = _create_variable_and_slots([1, 2, 3, 4, 5, 6, 7])
   new_variable, new_optimizer = _create_variable_and_slots([0] * 6)
   mapping = [0, -1, -1, 4, -1, 2]
   expected = [1, 0, 0, 5, 0, 3]
   vocab_lib.update_variable_and_slots(
       ref_variable, new_variable, ref_optimizer, new_optimizer, mapping)

   def assert_matches_expected(tensor):
     self.assertAllEqual(self.evaluate(tensor), expected)

   assert_matches_expected(new_variable)
   for slot_name in ("m", "v"):
     assert_matches_expected(new_optimizer.get_slot(new_variable, slot_name))
Beispiel #2
0
 def _map_variables(inputter_fn, vars_fn):
     """Remaps vocabulary-dependent variables of `self` onto `new_model`.

     A mapping between the vocabulary of `self` and that of `new_model` is
     computed, then every (old, new) variable pair is updated through it.
     Each result is appended to the enclosing `updated_variables` list.

     Args:
       inputter_fn: Callable taking a model and returning an object that
         exposes a `vocabulary_file` attribute.
       vars_fn: Callable taking a model and returning a
         `(variables, vocab_axes)` pair.

     Returns:
       The variables returned by `vars_fn(new_model)`.
     """
     old_vocab_file = inputter_fn(self).vocabulary_file
     new_vocab_file = inputter_fn(new_model).vocabulary_file
     mapping, _ = vocab.get_mapping(old_vocab_file, new_vocab_file)

     src_vars, vocab_axes = vars_fn(self)
     # Only the vocabulary axes reported for `self` are used.
     dst_vars, _ = vars_fn(new_model)
     # NOTE(review): zip() stops at the shortest sequence, so unmatched extra
     # variables on either side are silently skipped; presumably both models
     # expose aligned lists -- confirm.
     for src_var, dst_var, axis in zip(src_vars, dst_vars, vocab_axes):
         if new_optimizer is None or optimizer is None:
             # Unless both optimizers are given, only the variable is remapped.
             remapped = [
                 vocab.update_variable(src_var, dst_var, mapping, vocab_axis=axis)
             ]
         else:
             remapped = vocab.update_variable_and_slots(
                 src_var,
                 dst_var,
                 optimizer,
                 new_optimizer,
                 mapping,
                 vocab_axis=axis,
             )
         updated_variables.extend(remapped)
     return dst_vars
Beispiel #3
0
 def testVocabVariableUpdate(self):
   """update_variable_and_slots copies mapped entries into variable and slots.

   `mapping[i]` is the reference index feeding new entry i, or -1 when the
   entry has no counterpart. Only entries with `mapping[i] >= 0` are checked:
   at those positions the new variable and its "m" and "v" optimizer slots
   must hold `expected[i]`. Entries mapped to -1 are not asserted on.
   """
   ref_variable, ref_optimizer = _create_variable_and_slots([1, 2, 3, 4, 5, 6, 7])
   new_variable, new_optimizer = _create_variable_and_slots([0] * 6)
   mapping = [0, -1, -1, 4, -1, 2]
   expected = [1, 0, 0, 5, 0, 3]
   vocab_lib.update_variable_and_slots(
       ref_variable, new_variable, ref_optimizer, new_optimizer, mapping)
   slots = [new_optimizer.get_slot(new_variable, name) for name in ("m", "v")]
   values = [self.evaluate(tensor) for tensor in [new_variable] + slots]
   mapped_positions = [pos for pos, ref_index in enumerate(mapping) if ref_index >= 0]
   for pos in mapped_positions:
     for value in values:
       self.assertAllEqual(value[pos], expected[pos])
Beispiel #4
0
 def _map_variable(mapping, var_a, var_b, axis=0):
     """Remaps one old-model variable onto its new-model counterpart.

     When both `optimizer` and `new_optimizer` (from the enclosing scope) are
     set, `update_variable_and_slots` is used. Otherwise only the variable is
     updated through `update_variable`. The result is appended to the
     enclosing `updated_variables` list.

     Args:
       mapping: Vocabulary mapping forwarded unchanged to the update helper.
       var_a: Variable of the original model.
       var_b: Variable of the new model.
       axis: Vocabulary axis forwarded as `vocab_axis` (defaults to 0).
     """
     if new_optimizer is None or optimizer is None:
         remapped = [update_variable(var_a, var_b, mapping, vocab_axis=axis)]
     else:
         remapped = update_variable_and_slots(
             var_a, var_b, optimizer, new_optimizer, mapping, vocab_axis=axis
         )
     updated_variables.extend(remapped)