def test_warm_start_success(self):
   """Test if GANEstimator allows explicit warm start variable assignment.

   Builds a ``WarmStartSettings`` that warm-starts from ``self._model_dir``
   using a regex which excludes any variable whose name contains
   ``self.new_variable_name``. It then runs ``self._test_warm_start`` with
   those settings and checks that ``Generator/<new_variable_name>`` exists in
   the resulting estimator with value ``self.new_variable_value``.
   """
   import re

   # Negative lookahead: matches every variable name in the checkpoint except
   # those containing new_var. The name is escaped so regex metacharacters in
   # it are matched literally instead of altering the pattern.
   var_regex = '^(?!.*%s)' % re.escape(self.new_variable_name)
   warmstart = WarmStartSettings(ckpt_to_initialize_from=self._model_dir,
                                 vars_to_warm_start=var_regex)
   est_warm = self._test_warm_start(warm_start_from=warmstart)
   full_variable_name = 'Generator/%s' % self.new_variable_name
   self.assertIn(full_variable_name, est_warm.get_variable_names())
   actual_value = est_warm.get_variable_value(full_variable_name)
   equal_vals = np.array_equal(actual_value, self.new_variable_value)
   # Report both values on failure; a bare assertTrue gives no diagnostics.
   self.assertTrue(
       equal_vals,
       msg='%s = %r, expected %r' % (full_variable_name, actual_value,
                                     self.new_variable_value))
Example #2
0
def keras_warm_start(ckpt_file: str) -> WarmStartSettings:
    """Create warm-start settings that initialize from a checkpoint.

    Args:
        ckpt_file: Value passed as ``ckpt_to_initialize_from``; no other
            ``WarmStartSettings`` argument is set here.

    Returns:
        A ``WarmStartSettings`` instance built from ``ckpt_file``.
    """
    settings = WarmStartSettings(ckpt_to_initialize_from=ckpt_file)
    return settings