def testLoadExistingVariablesDifferentShapeAllowReshape(self):
  """With reshape_variables=True, mismatched-shape checkpoint values load."""
  # Scratch directory for the checkpoint; clear any leftover from a prior run.
  model_dir = tempfile.mkdtemp(prefix=os.path.join(
      self.get_temp_dir(),
      'load_existing_variables_different_shape_allow_reshape'))
  if gfile.Exists(model_dir):
    gfile.DeleteRecursively(model_dir)

  # The checkpoint stores a [1, 2] matrix and a scalar.
  matrix_value = [[10.0, 11.0]]
  scalar_value = 20.0
  names_to_values = {'v0': matrix_value, 'v1': scalar_value}

  with self.test_session() as session:
    model_path = self.create_checkpoint_from_values(names_to_values,
                                                    model_dir)
    # The target variable for v0 uses the transposed shape [2, 1].
    var0 = variables_lib2.variable('my_var0', shape=[2, 1])
    var1 = variables_lib2.variable('my_var1', shape=[])
    init_fn = variables_lib2.assign_from_checkpoint_fn(
        model_path, {'v0': var0, 'v1': var1}, reshape_variables=True)

    # Initialize the variables, then overwrite them from the checkpoint.
    session.run(variables_lib.global_variables_initializer())
    init_fn(session)

    # The [1, 2] value must come back reshaped into the [2, 1] variable
    # (identical element order, so it equals the transpose here).
    self.assertAllEqual(np.transpose(np.array(matrix_value)), var0.eval())
    self.assertEqual(scalar_value, var1.eval())
def testMissingVariablesDict(self):
  """ignore_missing_vars=True restores present vars and skips absent ones."""
  # Scratch directory for the checkpoint; clear any leftover from a prior run.
  model_dir = tempfile.mkdtemp(prefix=os.path.join(self.get_temp_dir(),
                                                   'missing_variables_dict'))
  if gfile.Exists(model_dir):
    gfile.DeleteRecursively(model_dir)

  # The checkpoint contains only v0 and v1 — no v2.
  value0 = 10.0
  value1 = 20.0
  names_to_values = {'v0': value0, 'v1': value1}

  with self.test_session() as session:
    model_path = self.create_checkpoint_from_values(names_to_values,
                                                    model_dir)
    var0 = variables_lib2.variable('my_var0', shape=[])
    var1 = variables_lib2.variable('my_var1', shape=[])
    var2 = variables_lib2.variable('my_var2', shape=[])

    # Request v2 as well; ignore_missing_vars makes its absence non-fatal.
    restore_map = {'v0': var0, 'v1': var1, 'v2': var2}
    init_fn = variables_lib2.assign_from_checkpoint_fn(
        model_path, restore_map, ignore_missing_vars=True)

    # Initialize the variables, then overwrite them from the checkpoint.
    session.run(variables_lib.global_variables_initializer())
    init_fn(session)

    # Only the variables present in the checkpoint are verified.
    self.assertEqual(value0, var0.eval())
    self.assertEqual(value1, var1.eval())
def testLoadExistingVariablesDifferentShapeDefaultDoesNotAllowReshape(self):
  """By default, a shape mismatch on restore raises InvalidArgumentError."""
  # Scratch directory for the checkpoint; clear any leftover from a prior run.
  model_dir = tempfile.mkdtemp(prefix=os.path.join(
      self.get_temp_dir(), 'load_existing_vars_no_reshape'))
  if gfile.Exists(model_dir):
    gfile.DeleteRecursively(model_dir)

  # The checkpoint stores a [1, 2] matrix and a scalar.
  matrix_value = [[10.0, 11.0]]
  scalar_value = 20.0
  names_to_values = {'v0': matrix_value, 'v1': scalar_value}

  with self.test_session() as session:
    model_path = self.create_checkpoint_from_values(names_to_values,
                                                    model_dir)
    # The target shape [2, 1] disagrees with the checkpoint's [1, 2], and
    # reshape_variables is left at its default (no reshaping).
    var0 = variables_lib2.variable('my_var0', shape=[2, 1])
    var1 = variables_lib2.variable('my_var1', shape=[])
    init_fn = variables_lib2.assign_from_checkpoint_fn(
        model_path, {'v0': var0, 'v1': var1})

    # Initialize the variables.
    session.run(variables_lib.global_variables_initializer())

    # The assignment must fail on the shape mismatch.
    with self.assertRaises(errors_impl.InvalidArgumentError):
      init_fn(session)
def testNotFoundError(self):
  """Restoring a variable absent from the checkpoint raises NotFoundError."""
  # Scratch directory for the checkpoint; clear any leftover from a prior run.
  model_dir = tempfile.mkdtemp(prefix=os.path.join(self.get_temp_dir(),
                                                   'not_found_error'))
  if gfile.Exists(model_dir):
    gfile.DeleteRecursively(model_dir)

  # The checkpoint contains only v0 and v1 — no v2.
  value0 = 10.0
  value1 = 20.0
  names_to_values = {'v0': value0, 'v1': value1}

  with self.test_session() as session:
    model_path = self.create_checkpoint_from_values(names_to_values,
                                                    model_dir)
    var0 = variables_lib2.variable('my_var0', shape=[])
    var1 = variables_lib2.variable('my_var1', shape=[])
    var2 = variables_lib2.variable('my_var2', shape=[])

    # v2 is requested but missing, and ignore_missing_vars defaults to off.
    restore_map = {'v0': var0, 'v1': var1, 'v2': var2}
    init_fn = variables_lib2.assign_from_checkpoint_fn(model_path,
                                                       restore_map)

    # Initialize the variables.
    session.run(variables_lib.global_variables_initializer())

    # The assignment must fail on the missing checkpoint entry.
    with self.assertRaises(errors_impl.NotFoundError):
      init_fn(session)
def testTrainWithInitFromCheckpoint(self):
  """Training warm-started from a converged checkpoint keeps the loss low."""
  logdir1 = os.path.join(self.get_temp_dir(), 'tmp_logs1/')
  logdir2 = os.path.join(self.get_temp_dir(), 'tmp_logs2/')
  # For running on jenkins: clear leftovers from previous runs.
  for stale_dir in (logdir1, logdir2):
    if gfile.Exists(stale_dir):
      gfile.DeleteRecursively(stale_dir)

  # First, train the model one step (make sure the error is high).
  with ops.Graph().as_default():
    random_seed.set_random_seed(0)
    train_op = self.create_train_op()
    saver = saver_lib.Saver()
    loss = training.train(
        train_op,
        logdir1,
        hooks=[
            basic_session_run_hooks.CheckpointSaverHook(
                logdir1, save_steps=1, saver=saver),
            basic_session_run_hooks.StopAtStepHook(num_steps=1),
        ],
        save_checkpoint_secs=None)
    self.assertGreater(loss, .5)

  # Next, train the model to convergence, checkpointing into the same dir.
  with ops.Graph().as_default():
    random_seed.set_random_seed(1)
    train_op = self.create_train_op()
    saver = saver_lib.Saver()
    loss = training.train(
        train_op,
        logdir1,
        hooks=[
            basic_session_run_hooks.CheckpointSaverHook(
                logdir1, save_steps=1, saver=saver),
            basic_session_run_hooks.StopAtStepHook(num_steps=300),
        ],
        save_checkpoint_secs=None)
    self.assertIsNotNone(loss)
    self.assertLess(loss, .02)

  # Finally, advance the model a single step in a fresh graph, warm-started
  # from the step-300 checkpoint, and validate that the loss is still low.
  with ops.Graph().as_default():
    random_seed.set_random_seed(2)
    train_op = self.create_train_op()

    # NOTE(review): the variables_lib / variables_lib2 roles here are the
    # reverse of the checkpoint tests above (assign_from_checkpoint_fn on
    # variables_lib, global_variables on variables_lib2) — confirm this
    # matches the file's import aliases.
    model_variables = variables_lib2.global_variables()
    model_path = os.path.join(logdir1, 'model.ckpt-300')
    assign_fn = variables_lib.assign_from_checkpoint_fn(model_path,
                                                        model_variables)

    # Scaffold calls init_fn(scaffold, session); only the session is needed.
    def init_fn(_, session):
      assign_fn(session)

    loss = training.train(
        train_op,
        logdir2,
        scaffold=monitored_session.Scaffold(init_fn=init_fn),
        hooks=[basic_session_run_hooks.StopAtStepHook(num_steps=1)])
    self.assertIsNotNone(loss)
    self.assertLess(loss, .02)