def testInitFromCheckpointWithScopes(self):
  """Restores checkpoint tensors into variables created under other scopes.

  A checkpoint holding 'layer0/v0' and 'layer1/v1' is written to a fresh temp
  directory. Its values are then assigned, through
  `variables_lib2.assign_from_checkpoint`, to variables created under the
  'my_model/my_layer0' and 'my_model/my_layer1' variable scopes. The restored
  values must match what was saved.
  """
  temp_prefix = os.path.join(self.get_temp_dir(),
                             'init_from_checkpoint_with_scopes')
  checkpoint_dir = tempfile.mkdtemp(prefix=temp_prefix)

  expected0 = np.asarray([1.0, 3.0, 9.0], dtype=np.float32).reshape((1, 3, 1))
  expected1 = np.asarray(
      [2.0, 4.0, 6.0, 8.0], dtype=np.float32).reshape((2, 1, 2))
  checkpoint_values = {'layer0/v0': expected0, 'layer1/v1': expected1}

  with self.test_session() as session:
    checkpoint_path = self.create_checkpoint_from_values(checkpoint_values,
                                                         checkpoint_dir)

    # The graph variables live under different scopes than the checkpoint
    # names; the restore map below bridges the two naming schemes.
    with variable_scope.variable_scope('my_model/my_layer0'):
      scoped_var0 = variables_lib2.variable('my_var0', shape=expected0.shape)
    with variable_scope.variable_scope('my_model/my_layer1'):
      scoped_var1 = variables_lib2.variable('my_var1', shape=expected1.shape)

    restore_map = {'layer0/v0': scoped_var0, 'layer1/v1': scoped_var1}
    assign_op, assign_feed = variables_lib2.assign_from_checkpoint(
        checkpoint_path, restore_map)

    # Run the default initializers first, then apply the checkpoint assignment.
    session.run(variables_lib.global_variables_initializer())
    session.run(assign_op, assign_feed)

    for expected, restored in ((expected0, scoped_var0),
                               (expected1, scoped_var1)):
      self.assertAllEqual(expected, restored.eval())
def testLoadExistingVariables(self):
  """Assigns scalar checkpoint entries 'v0'/'v1' to pre-created variables.

  The checkpoint stores 10.0 under 'v0' and 20.0 under 'v1'. After the
  initializer runs and the assignment op is applied, 'my_var0' and 'my_var1'
  must hold exactly those values.
  """
  checkpoint_dir = tempfile.mkdtemp(
      prefix=os.path.join(self.get_temp_dir(), 'load_existing_variables'))
  first_value, second_value = 10.0, 20.0
  saved = {'v0': first_value, 'v1': second_value}

  with self.test_session() as session:
    checkpoint_path = self.create_checkpoint_from_values(saved, checkpoint_dir)
    target0 = variables_lib2.variable('my_var0', shape=[])
    target1 = variables_lib2.variable('my_var1', shape=[])

    assign_op, assign_feed = variables_lib2.assign_from_checkpoint(
        checkpoint_path, {'v0': target0, 'v1': target1})

    # Default initialization happens first; the assignment then writes the
    # checkpoint values into the variables.
    session.run(variables_lib.global_variables_initializer())
    session.run(assign_op, assign_feed)

    self.assertEqual(first_value, target0.eval())
    self.assertEqual(second_value, target1.eval())
def testRaisesValueErrorIfAVariableIsntFound(self):
  """Expects ValueError when the restore map names an unknown checkpoint key.

  The checkpoint is built from only 'v0' and 'v1'. Requesting 'v0_fake' must
  make `variables_lib2.assign_from_checkpoint` raise ValueError.
  """
  checkpoint_dir = tempfile.mkdtemp(prefix=os.path.join(
      self.get_temp_dir(), 'raises_value_error_if_var_isnt_found'))
  saved = {'v0': 10.0, 'v1': 20.0}

  with self.test_session():
    checkpoint_path = self.create_checkpoint_from_values(saved, checkpoint_dir)
    target0 = variables_lib2.variable('my_var0', shape=[])
    target1 = variables_lib2.variable('my_var1', shape=[])

    # 'v0_fake' is deliberately absent from the checkpoint written above.
    bad_restore_map = {'v0_fake': target0, 'v1': target1}

    with self.assertRaises(ValueError):
      variables_lib2.assign_from_checkpoint(checkpoint_path, bad_restore_map)
def testTrainWithInitFromCheckpoint(self):
  """Trains in one logdir, then seeds a new logdir's run from its checkpoint.

  The test runs in three phases, each in a fresh graph with its own seed:
    1. One training step in the first logdir; the loss must exceed 0.5.
    2. Training in the same logdir with number_of_steps=300; the loss must
       drop below 0.02.
    3. A new train op in a second logdir. Its variables are initialized with
       `global_variables_initializer` and then overwritten from
       'model.ckpt-300' in the first logdir via an init_fn. One step must
       keep the loss below 0.02.

  NOTE(review): another method with this exact name is defined right after
  this one in the same chunk. If both live in one class, this definition is
  shadowed and never runs. Confirm, then rename or drop one.
  """
  first_logdir = os.path.join(
      tempfile.mkdtemp(prefix=self.get_temp_dir()), 'tmp_logs1')
  second_logdir = os.path.join(
      tempfile.mkdtemp(prefix=self.get_temp_dir()), 'tmp_logs2')
  # Shared by phase 2's step count and phase 3's checkpoint filename so the
  # two cannot drift apart.
  converged_steps = 300

  # Phase 1: a single step; the loss should still be high.
  with ops.Graph().as_default():
    random_seed.set_random_seed(0)
    loss = learning.train(self.create_train_op(), first_logdir,
                          number_of_steps=1)
    self.assertGreater(loss, .5)

  # Phase 2: keep training in the same logdir until the loss is low.
  with ops.Graph().as_default():
    random_seed.set_random_seed(1)
    loss = learning.train(self.create_train_op(), first_logdir,
                          number_of_steps=converged_steps,
                          log_every_n_steps=10)
    self.assertIsNotNone(loss)
    self.assertLess(loss, .02)

  # Phase 3: take one more step in a new logdir, starting from the phase-2
  # checkpoint, and confirm the loss stays low.
  with ops.Graph().as_default():
    random_seed.set_random_seed(2)
    train_op = self.create_train_op()

    restore_targets = variables_lib.global_variables()
    checkpoint_path = os.path.join(first_logdir,
                                   'model.ckpt-%d' % converged_steps)

    init_op = variables_lib.global_variables_initializer()
    restore_op, restore_feed = variables_lib2.assign_from_checkpoint(
        checkpoint_path, restore_targets)

    def _restore_from_checkpoint(sess):
      sess.run(restore_op, restore_feed)

    loss = learning.train(train_op, second_logdir, number_of_steps=1,
                          init_op=init_op, init_fn=_restore_from_checkpoint)
    self.assertIsNotNone(loss)
    self.assertLess(loss, .02)
def testTrainWithInitFromCheckpoint(self):
  """Checks that a run initialized from a saved checkpoint keeps a low loss.

  Phases (each in a fresh graph with a distinct random seed):
    * seed 0: one step into the first logdir; the loss must exceed 0.5.
    * seed 1: number_of_steps=300 in the same logdir; the loss must drop
      below 0.02.
    * seed 2: one step into a second logdir. The init_fn assigns values from
      'model.ckpt-300' in the first logdir after `global_variables_initializer`
      runs. The loss must stay below 0.02.

  NOTE(review): a method with this exact name is defined right before this
  one in the same chunk. If both live in one class, this later definition is
  the one that takes effect. Confirm, then rename or drop the duplicate.
  """
  logdir_a, logdir_b = [
      os.path.join(tempfile.mkdtemp(prefix=self.get_temp_dir()), suffix)
      for suffix in ('tmp_logs1', 'tmp_logs2')
  ]

  with ops.Graph().as_default():
    random_seed.set_random_seed(0)
    op_to_train = self.create_train_op()
    initial_loss = learning.train(op_to_train, logdir_a, number_of_steps=1)
    self.assertGreater(initial_loss, .5)

  with ops.Graph().as_default():
    random_seed.set_random_seed(1)
    op_to_train = self.create_train_op()
    converged_loss = learning.train(
        op_to_train, logdir_a, number_of_steps=300, log_every_n_steps=10)
    self.assertIsNotNone(converged_loss)
    self.assertLess(converged_loss, .02)

  with ops.Graph().as_default():
    random_seed.set_random_seed(2)
    op_to_train = self.create_train_op()

    all_variables = variables_lib.global_variables()
    saved_model = os.path.join(logdir_a, 'model.ckpt-300')

    initializer = variables_lib.global_variables_initializer()
    assign_op, assign_feed = variables_lib2.assign_from_checkpoint(
        saved_model, all_variables)

    # Handed to learning.train as init_fn; it receives the session and
    # applies the checkpoint assignment.
    def restore_fn(sess):
      sess.run(assign_op, assign_feed)

    resumed_loss = learning.train(
        op_to_train,
        logdir_b,
        number_of_steps=1,
        init_op=initializer,
        init_fn=restore_fn)
    self.assertIsNotNone(resumed_loss)
    self.assertLess(resumed_loss, .02)