def testGetVariablesDontReturnsTransients(self):
    """Check that `get_variables` does not report local variables.

    One local variable is created under each of the scopes 'A' and 'B';
    querying either scope is expected to yield an empty list.
    """
    with self.test_session():
        with variable_scope.variable_scope('A'):
            variables_lib2.local_variable(0)
        with variable_scope.variable_scope('B'):
            variables_lib2.local_variable(0)
        # assertEqual is the supported spelling; the assertEquals alias is
        # deprecated and no longer exists in Python 3.12+.
        self.assertEqual([], variables_lib2.get_variables('A'))
        self.assertEqual([], variables_lib2.get_variables('B'))
def testGetVariablesSuffix(self):
    """Check that `get_variables(suffix=...)` filters by name suffix.

    Variables 'a' and 'b' are both created under scope 'A' (entered twice);
    each suffix query is expected to return exactly the matching variable.
    """
    with self.test_session():
        with variable_scope.variable_scope('A'):
            a = variables_lib2.variable('a', [5])
        with variable_scope.variable_scope('A'):
            b = variables_lib2.variable('b', [5])
        # assertEqual is the supported spelling; the assertEquals alias is
        # deprecated and no longer exists in Python 3.12+.
        self.assertEqual([a], variables_lib2.get_variables(suffix='a'))
        self.assertEqual([b], variables_lib2.get_variables(suffix='b'))
def testGetVariablesReturns(self):
    """Check that `get_variables(scope)` returns the scope's model variables.

    A model variable named 'a' is created under each of the scopes 'A' and
    'B'; querying a scope is expected to return only that scope's variable.
    """
    with self.test_session():
        with variable_scope.variable_scope('A'):
            a = variables_lib2.model_variable('a', [5])
        with variable_scope.variable_scope('B'):
            b = variables_lib2.model_variable('a', [5])
        # assertEqual is the supported spelling; the assertEquals alias is
        # deprecated and no longer exists in Python 3.12+.
        self.assertEqual([a], variables_lib2.get_variables('A'))
        self.assertEqual([b], variables_lib2.get_variables('B'))
def testGetVariablesWithScope(self):
    """Check that `get_variables` accepts a captured scope object.

    Two variables are created under scope 'A'; passing the scope object
    yielded by the `with` statement is expected to return both of them
    (compared as sets, so order is not checked).
    """
    with self.test_session():
        with variable_scope.variable_scope('A') as scope_a:
            a = variables_lib2.variable('a', [5])
            b = variables_lib2.variable('b', [5])
        expected = {a, b}
        found = set(variables_lib2.get_variables(scope_a))
        self.assertSetEqual(expected, found)
def testTrainingSubsetsOfVariablesOnlyUpdatesThoseVariables(self):
    """Check that `variables_to_train` restricts what a train op updates.

    Three train ops are built on the same loss and optimizer: one over all
    variables, one over the weights only and one over the biases only. Each
    is run once, in that order, and the change in the weights and biases is
    inspected after every step.
    """
    with ops.Graph().as_default():
        random_seed.set_random_seed(0)
        total_loss = self.ModelLoss()
        optimizer = gradient_descent.GradientDescentOptimizer(
            learning_rate=1.0)
        weights, biases = variables_lib.get_variables()

        train_all = training.create_train_op(total_loss, optimizer)
        train_weights_only = training.create_train_op(
            total_loss, optimizer, variables_to_train=[weights])
        train_biases_only = training.create_train_op(
            total_loss, optimizer, variables_to_train=[biases])

        def change(before, after):
            # Magnitude of the difference between two fetched values.
            return np.linalg.norm(before - after)

        with self.test_session() as session:

            def step(train_op):
                # Run one training step, check the loss, fetch new values.
                loss = session.run(train_op)
                self.assertGreater(loss, .5)
                return session.run([weights, biases])

            session.run(variables_lib2.global_variables_initializer())

            # Initial state: non-zero weights, zero biases.
            old_weights, old_biases = session.run([weights, biases])
            self.assertGreater(np.linalg.norm(old_weights), 0)
            self.assertAlmostEqual(np.linalg.norm(old_biases), 0)

            # Training everything moves both weights and biases.
            cur_weights, cur_biases = step(train_all)
            self.assertGreater(change(old_weights, cur_weights), 0)
            self.assertGreater(change(old_biases, cur_biases), 0)
            old_weights, old_biases = cur_weights, cur_biases

            # Training the weights only leaves the biases untouched.
            cur_weights, cur_biases = step(train_weights_only)
            self.assertGreater(change(old_weights, cur_weights), 0)
            self.assertAlmostEqual(change(old_biases, cur_biases), 0)
            old_weights = cur_weights

            # Training the biases only leaves the weights untouched.
            cur_weights, cur_biases = step(train_biases_only)
            self.assertAlmostEqual(change(old_weights, cur_weights), 0)
            self.assertGreater(change(old_biases, cur_biases), 0)
def testReuseVariable(self):
    """Check that re-entering a scope with `reuse=True` returns the same variable.

    Variable 'a' is requested twice under scope 'A', the second time with
    reuse enabled; both requests are expected to compare equal and only one
    variable is expected to be listed by `get_variables`.
    """
    with self.test_session():
        with variable_scope.variable_scope('A'):
            a = variables_lib2.variable('a', [])
        with variable_scope.variable_scope('A', reuse=True):
            b = variables_lib2.variable('a', [])
        # assertEqual is the supported spelling; the assertEquals alias is
        # deprecated and no longer exists in Python 3.12+.
        self.assertEqual(a, b)
        self.assertListEqual([a], variables_lib2.get_variables())
def testWrongIncludeGetVariablesToRestore(self):
    """Check that a non-matching include pattern selects no variables.

    Variables named 'a' are created under scopes 'A' and 'B'. Passing
    ['a'] as the first argument of `get_variables_to_restore` is expected
    to return an empty list.
    """
    with self.test_session():
        with variable_scope.variable_scope('A'):
            a = variables_lib2.variable('a', [5])
        with variable_scope.variable_scope('B'):
            b = variables_lib2.variable('a', [5])
        # assertEqual is the supported spelling; the assertEquals alias is
        # deprecated and no longer exists in Python 3.12+.
        self.assertEqual([a, b], variables_lib2.get_variables())
        self.assertEqual([], variables_lib2.get_variables_to_restore(['a']))
def testExcludeGetMixedVariablesToRestore(self):
    """Check that `exclude` can drop individual variables from several scopes.

    Scopes 'A' and 'B' each hold two variables. Excluding 'A/a' and 'B/c'
    is expected to leave exactly the other variable of each scope.
    """
    with self.test_session():
        with variable_scope.variable_scope('A'):
            a = variables_lib2.variable('a', [5])
            b = variables_lib2.variable('b', [5])
        with variable_scope.variable_scope('B'):
            c = variables_lib2.variable('c', [5])
            d = variables_lib2.variable('d', [5])
        # assertEqual is the supported spelling; the assertEquals alias is
        # deprecated and no longer exists in Python 3.12+.
        self.assertEqual([a, b, c, d], variables_lib2.get_variables())
        self.assertEqual(
            [b, d],
            variables_lib2.get_variables_to_restore(exclude=['A/a', 'B/c']))
def testTrainingSubsetsOfVariablesOnlyUpdatesThoseVariables(self):
    """Check that `variables_to_train` restricts what a train op updates.

    Three train ops are built on the same loss and optimizer: one over all
    variables, one over the weights only and one over the biases only. Each
    is run once, in that order, and the change in the weights and biases is
    inspected after every step.
    """
    with ops.Graph().as_default():
        random_seed.set_random_seed(0)
        total_loss = self.ModelLoss()
        optimizer = gradient_descent.GradientDescentOptimizer(
            learning_rate=1.0)
        weights, biases = variables_lib.get_variables()

        train_all = training.create_train_op(total_loss, optimizer)
        train_weights_only = training.create_train_op(
            total_loss, optimizer, variables_to_train=[weights])
        train_biases_only = training.create_train_op(
            total_loss, optimizer, variables_to_train=[biases])

        def change(before, after):
            # Magnitude of the difference between two fetched values.
            return np.linalg.norm(before - after)

        with session_lib.Session() as sess:

            def step(train_op):
                # Run one training step, check the loss, fetch new values.
                loss = sess.run(train_op)
                self.assertGreater(loss, .5)
                return sess.run([weights, biases])

            sess.run(variables_lib2.global_variables_initializer())

            # Initial state: non-zero weights, zero biases.
            old_weights, old_biases = sess.run([weights, biases])
            self.assertGreater(np.linalg.norm(old_weights), 0)
            self.assertAlmostEqual(np.linalg.norm(old_biases), 0)

            # Training everything moves both weights and biases.
            cur_weights, cur_biases = step(train_all)
            self.assertGreater(change(old_weights, cur_weights), 0)
            self.assertGreater(change(old_biases, cur_biases), 0)
            old_weights, old_biases = cur_weights, cur_biases

            # Training the weights only leaves the biases untouched.
            cur_weights, cur_biases = step(train_weights_only)
            self.assertGreater(change(old_weights, cur_weights), 0)
            self.assertAlmostEqual(change(old_biases, cur_biases), 0)
            old_weights = cur_weights

            # Training the biases only leaves the weights untouched.
            cur_weights, cur_biases = step(train_biases_only)
            self.assertAlmostEqual(change(old_weights, cur_weights), 0)
            self.assertGreater(change(old_biases, cur_biases), 0)
def test_variable_reuse(self):
    """Test that variable scopes work and inference on a real-ish case."""
    tensor1_ref = array_ops.zeros([6, 5, 7, 3, 3])
    tensor1_examples = array_ops.zeros([4, 5, 7, 3, 3])
    tensor2_ref = array_ops.zeros([4, 2, 3])
    tensor2_examples = array_ops.zeros([2, 2, 3])

    # Constructing a VBN under a reuse-only scope must fail, because no
    # variables exist yet.
    expected_error = ('does not exist, or was not created with '
                      'tf.get_variable()')
    with variable_scope.variable_scope('dummy_scope', reuse=True):
        with self.assertRaisesRegexp(ValueError, expected_error):
            virtual_batchnorm.VBN(tensor1_ref)

    vbn1 = virtual_batchnorm.VBN(tensor1_ref, name='vbn1')
    vbn2 = virtual_batchnorm.VBN(tensor2_ref, name='vbn2')

    def build_fetches():
        # Normalized reference batches followed by normalized examples.
        return [
            vbn1.reference_batch_normalization(),
            vbn2.reference_batch_normalization(),
            vbn1(tensor1_examples),
            vbn2(tensor2_examples),
        ]

    # Build the outputs once normally and once more with reuse switched on.
    to_fetch = build_fetches()
    variable_scope.get_variable_scope().reuse_variables()
    to_fetch.extend(build_fetches())

    # The second pass must not have created additional variables.
    self.assertEqual(4, len(contrib_variables_lib.get_variables()))

    with self.session(use_gpu=True) as sess:
        variables_lib.global_variables_initializer().run()
        sess.run(to_fetch)
def test_variable_reuse(self):
    """Test that variable scopes work and inference on a real-ish case."""
    tensor1_ref = array_ops.zeros([6, 5, 7, 3, 3])
    tensor1_examples = array_ops.zeros([4, 5, 7, 3, 3])
    tensor2_ref = array_ops.zeros([4, 2, 3])
    tensor2_examples = array_ops.zeros([2, 2, 3])

    # Constructing a VBN under a reuse-only scope must fail, because no
    # variables exist yet.
    expected_error = ('does not exist, or was not created with '
                      'tf.get_variable()')
    with variable_scope.variable_scope('dummy_scope', reuse=True):
        with self.assertRaisesRegexp(ValueError, expected_error):
            virtual_batchnorm.VBN(tensor1_ref)

    vbn1 = virtual_batchnorm.VBN(tensor1_ref, name='vbn1')
    vbn2 = virtual_batchnorm.VBN(tensor2_ref, name='vbn2')

    def build_fetches():
        # Normalized reference batches followed by normalized examples.
        return [
            vbn1.reference_batch_normalization(),
            vbn2.reference_batch_normalization(),
            vbn1(tensor1_examples),
            vbn2(tensor2_examples),
        ]

    # Build the outputs once normally and once more with reuse switched on.
    to_fetch = build_fetches()
    variable_scope.get_variable_scope().reuse_variables()
    to_fetch.extend(build_fetches())

    # The second pass must not have created additional variables.
    self.assertEqual(4, len(contrib_variables_lib.get_variables()))

    with self.test_session(use_gpu=True) as sess:
        variables_lib.global_variables_initializer().run()
        sess.run(to_fetch)
def build_dqn(self):
    """Build the DQN graph: prediction net, target net, loss and summaries.

    Fills ``self.w`` and ``self.t_w`` with the parameters that ``conv2d`` and
    ``linear`` return for the prediction and target networks, and creates
    the input placeholders, the Q-value tensors, the prediction-to-target
    assign ops (``self.t_w_assign_op``), the clipped-gradient training op
    (``self.optim``) and the summary placeholders/ops.

    Variables found under the 'target' and 'optimizer' scopes are moved from
    the GLOBAL_VARIABLES collection to LOCAL_VARIABLES.

    Raises:
        ValueError: if ``self.cnn_format`` is neither 'NHWC' nor 'NCHW'.
    """
    self.w = {}
    self.t_w = {}

    initializer = tf.truncated_normal_initializer(0, 0.02)
    activation_fn = tf.nn.relu

    # Both networks take the same input layout. The format is validated up
    # front so an unsupported value fails with a clear error instead of a
    # NameError on an undefined variable.
    if self.cnn_format == 'NHWC':
        state_shape = [
            None, self.screen_width, self.screen_height, self.history_length
        ]
    elif self.cnn_format == 'NCHW':
        state_shape = [
            None, self.history_length, self.screen_width, self.screen_height
        ]
    else:
        raise ValueError('Unknown cnn_format: %s' % self.cnn_format)

    def flat_size(layer):
        # Number of features per example once all non-batch dims are merged.
        shape = layer.get_shape().as_list()
        return functools.reduce(lambda x, y: x * y, shape[1:])

    def move_to_local_variables(scope):
        # Re-home every variable under `scope` from GLOBAL_VARIABLES to
        # LOCAL_VARIABLES.
        global_collection = tf.get_collection_ref(
            tf.GraphKeys.GLOBAL_VARIABLES)
        for var in variables.get_variables(scope=scope):
            tf.add_to_collection(tf.GraphKeys.LOCAL_VARIABLES, var)
            global_collection.remove(var)

    # training network
    with tf.variable_scope('prediction'):
        self.s_t = tf.placeholder('float32', state_shape, name='s_t')

        self.l1, self.w['l1_w'], self.w['l1_b'] = conv2d(
            self.s_t / 255., 16, [8, 8], [4, 4], initializer, activation_fn,
            self.cnn_format, name='l1')
        self.l2, self.w['l2_w'], self.w['l2_b'] = conv2d(
            self.l1, 32, [4, 4], [2, 2], initializer, activation_fn,
            self.cnn_format, name='l2')

        self.l2_flat = tf.reshape(self.l2, [-1, flat_size(self.l2)])

        if self.dueling:
            self.value_hid, self.w['l3_val_w'], self.w['l3_val_b'] = linear(
                self.l2_flat, 256, activation_fn=activation_fn,
                name='value_hid')
            self.adv_hid, self.w['l3_adv_w'], self.w['l3_adv_b'] = linear(
                self.l2_flat, 256, activation_fn=activation_fn,
                name='adv_hid')

            self.value, self.w['val_w_out'], self.w['val_w_b'] = linear(
                self.value_hid, 1, name='value_out')
            self.advantage, self.w['adv_w_out'], self.w['adv_w_b'] = linear(
                self.adv_hid, self.env.action_size, name='adv_out')

            # Average Dueling
            self.q = self.value + (self.advantage - tf.reduce_mean(
                self.advantage, reduction_indices=1, keep_dims=True))
        else:
            self.l3, self.w['l3_w'], self.w['l3_b'] = linear(
                self.l2_flat, 256, activation_fn=activation_fn, name='l3')
            self.q, self.w['q_w'], self.w['q_b'] = linear(
                self.l3, self.env.action_size, name='q')

        self.q_action = tf.argmax(self.q, dimension=1)

    # target network
    with tf.variable_scope('target'):
        self.target_s_t = tf.placeholder(
            'float32', state_shape, name='target_s_t')

        self.target_l1, self.t_w['l1_w'], self.t_w['l1_b'] = conv2d(
            self.target_s_t / 255., 16, [8, 8], [4, 4], initializer,
            activation_fn, self.cnn_format, name='target_l1')
        self.target_l2, self.t_w['l2_w'], self.t_w['l2_b'] = conv2d(
            self.target_l1, 32, [4, 4], [2, 2], initializer, activation_fn,
            self.cnn_format, name='target_l2')

        self.target_l2_flat = tf.reshape(
            self.target_l2, [-1, flat_size(self.target_l2)])

        if self.dueling:
            self.t_value_hid, self.t_w['l3_val_w'], self.t_w['l3_val_b'] = \
                linear(self.target_l2_flat, 256,
                       activation_fn=activation_fn, name='target_value_hid')
            self.t_adv_hid, self.t_w['l3_adv_w'], self.t_w['l3_adv_b'] = \
                linear(self.target_l2_flat, 256,
                       activation_fn=activation_fn, name='target_adv_hid')

            self.t_value, self.t_w['val_w_out'], self.t_w['val_w_b'] = \
                linear(self.t_value_hid, 1, name='target_value_out')
            self.t_advantage, self.t_w['adv_w_out'], self.t_w['adv_w_b'] = \
                linear(self.t_adv_hid, self.env.action_size,
                       name='target_adv_out')

            # Average Dueling
            self.target_q = self.t_value + (
                self.t_advantage - tf.reduce_mean(
                    self.t_advantage, reduction_indices=1, keep_dims=True))
        else:
            self.target_l3, self.t_w['l3_w'], self.t_w['l3_b'] = linear(
                self.target_l2_flat, 256, activation_fn=activation_fn,
                name='target_l3')
            self.target_q, self.t_w['q_w'], self.t_w['q_b'] = linear(
                self.target_l3, self.env.action_size, name='target_q')

        self.target_q_idx = tf.placeholder(
            'int32', [None, None], 'outputs_idx')
        self.target_q_with_idx = tf.gather_nd(
            self.target_q, self.target_q_idx)

    move_to_local_variables('target')

    # Ops copying each prediction parameter onto its target counterpart.
    with tf.variable_scope('pred_to_target'):
        self.t_w_input = {}
        self.t_w_assign_op = {}

        for name in self.w:
            self.t_w_assign_op[name] = self.t_w[name].assign(self.w[name])

    # optimizer
    with tf.variable_scope('optimizer'):
        self.target_q_t = tf.placeholder(
            'float32', [None], name='target_q_t')
        self.action = tf.placeholder('int64', [None], name='action')

        action_one_hot = tf.one_hot(
            self.action, self.env.action_size, 1.0, 0.0,
            name='action_one_hot')
        q_acted = tf.reduce_sum(
            self.q * action_one_hot, reduction_indices=1, name='q_acted')

        self.delta = self.target_q_t - q_acted
        self.loss = tf.reduce_mean(tf.square(self.delta), name='loss')

        # Clip every gradient to norm 40 before applying it.
        grads_and_vars = self.optimizer.compute_gradients(
            self.loss, list(self.w.values()))
        new_grads_and_vars = [(tf.clip_by_norm(grad, 40), var)
                              for grad, var in grads_and_vars]
        self.optim = self.optimizer.apply_gradients(new_grads_and_vars)

    move_to_local_variables('optimizer')

    with tf.variable_scope('summary'):
        scalar_summary_tags = [
            'average.reward', 'average.loss', 'average.q',
            'episode.max reward', 'episode.min reward', 'episode.avg reward',
            'episode.num of game', 'training.learning_rate'
        ]

        self.summary_placeholders = {}
        self.summary_ops = {}

        for tag in scalar_summary_tags:
            self.summary_placeholders[tag] = tf.placeholder(
                'float32', None, name=tag.replace(' ', '_'))
            self.summary_ops[tag] = tf.summary.scalar(
                "%s-%s/%s" % (self.env_name, self.env_type, tag),
                self.summary_placeholders[tag])

        # Merged before the histogram summaries are registered, so it only
        # covers the scalar summaries above.
        self.summary_op = tf.summary.merge(
            list(self.summary_ops.values()), name='total_summary')

        histogram_summary_tags = ['episode.rewards', 'episode.actions']

        for tag in histogram_summary_tags:
            self.summary_placeholders[tag] = tf.placeholder(
                'float32', None, name=tag.replace(' ', '_'))
            self.summary_ops[tag] = tf.summary.histogram(
                tag, self.summary_placeholders[tag])
def build_a3c(self):
    """Build the actor-critic graph: conv trunk, policy/value heads and loss.

    Fills ``self.w`` with the parameters returned by ``conv2d`` and
    ``linear``, creates the state/reward/action placeholders, the policy and
    value outputs, the combined loss and the clipped-gradient training op
    (``self.optim``). Variables found under the 'optimizer' scope are moved
    from the GLOBAL_VARIABLES collection to LOCAL_VARIABLES.

    Raises:
        ValueError: if ``self.cnn_format`` is neither 'NHWC' nor 'NCHW', or
            if the hard-coded network type is not 'nature' or 'nips'.
    """
    self.w = {}
    self.t_w = {}

    initializer = tf.truncated_normal_initializer(0, 0.02)
    activation_fn = tf.nn.relu
    network_type = 'nature'
    data_format = self.cnn_format
    beta = 0.1  # weight of the entropy term in the policy loss

    # The data format decides both the input layout and the device the
    # network is placed on.
    if data_format == 'NHWC':
        state_shape = [
            None, self.screen_width, self.screen_height, self.history_length
        ]
        device = '/cpu:0'
    elif data_format == 'NCHW':
        state_shape = [
            None, self.history_length, self.screen_width, self.screen_height
        ]
        device = '/gpu:0'
    else:
        raise ValueError('Unknown data_format: %s' % data_format)

    self.s_t = tf.placeholder('float32', state_shape, name='s_t')

    def flat(layer):
        # Merge all non-batch dimensions of `layer` into one.
        shape = layer.get_shape().as_list()
        return tf.reshape(
            layer, [-1, functools.reduce(lambda x, y: x * y, shape[1:])])

    if network_type.lower() == 'nature':
        with tf.variable_scope('Nature_DQN'), tf.device(device):
            self.l0 = tf.div(self.s_t, 255.)
            self.l1, self.w['l1_w'], self.w['l1_b'] = conv2d(
                self.l0, 32, [8, 8], [4, 4], initializer, activation_fn,
                data_format, name='l1_conv')
            self.l2, self.w['l2_w'], self.w['l2_b'] = conv2d(
                self.l1, 64, [4, 4], [2, 2], initializer, activation_fn,
                data_format, name='l2_conv')
            self.l3, self.w['l3_w'], self.w['l3_b'] = conv2d(
                self.l2, 64, [3, 3], [1, 1], initializer, activation_fn,
                data_format, name='l3_conv')
            self.l3_flat = flat(self.l3)
            self.l4, self.w['l4_w'], self.w['l4_b'] = linear(
                self.l3_flat, 512, activation_fn=activation_fn,
                name='l4_linear')
    elif network_type.lower() == 'nips':
        with tf.variable_scope('Nips_DQN'), tf.device(device):
            self.l0 = tf.div(self.s_t, 255.)
            self.l1, self.w['l1_w'], self.w['l1_b'] = conv2d(
                self.l0, 16, [8, 8], [4, 4], initializer, activation_fn,
                data_format, name='l1_conv')
            self.l2, self.w['l2_w'], self.w['l2_b'] = conv2d(
                self.l1, 32, [4, 4], [2, 2], initializer, activation_fn,
                data_format, name='l2_conv')
            self.l2_flat = flat(self.l2)
            self.l4, self.w['l4_w'], self.w['l4_b'] = linear(
                self.l2_flat, 256, activation_fn=activation_fn,
                name='l4_linear')
    else:
        raise ValueError('Wrong DQN type: %s' % network_type)

    # Policy head.
    with tf.variable_scope('policy'):
        # l4 features -> action_size logits
        self.policy_logits, self.w['p_w'], self.w['p_b'] = linear(
            self.l4, self.env.action_size, name='linear')

        with tf.variable_scope('policy'):
            self.policy = tf.nn.softmax(self.policy_logits, name='pi')

        with tf.variable_scope('log_policy'):
            self.log_policy = tf.log(self.policy)

        with tf.variable_scope('policy_entropy'):
            self.policy_entropy = -tf.reduce_sum(
                self.policy * self.log_policy, 1)

    # Value head.
    with tf.variable_scope('value'):
        # l4 features -> 1
        self.value, self.w['q_w'], self.w['q_b'] = linear(
            self.l4, 1, name='linear')

    with tf.variable_scope('optimizer'):
        self.R = tf.placeholder('float32', [None], name='target_reward')
        self.action = tf.placeholder('int64', [None], name='action')

        # NOTE(review): sparse_softmax_cross_entropy_with_logits yields
        # -log pi(action), so despite its name this attribute appears to
        # hold the *negative* log-probability — confirm the sign of the
        # policy loss below.
        self.true_log_policy = \
            tf.nn.sparse_softmax_cross_entropy_with_logits(
                labels=self.action, logits=self.policy_logits,
                name='true_action')

        # TODO: equation on paper and codes of other implementations are
        # different
        # NOTE(review): self.R is declared with shape [None] while
        # self.value comes from a 1-unit linear layer; if that output is
        # (batch, 1) the subtraction broadcasts to (batch, batch) — verify.
        with tf.variable_scope('policy_loss'):
            self.policy_loss = -(
                self.true_log_policy * (self.R - self.value)
            ) - beta * self.policy_entropy

        with tf.variable_scope('value_loss'):
            self.value_loss = tf.pow(self.R - self.value, 2) / 2

        with tf.variable_scope('total_loss'):
            self.loss = tf.reduce_mean(self.policy_loss + self.value_loss)

        # Clip every gradient to norm 40 before applying it.
        grads_and_vars = self.optimizer.compute_gradients(
            self.loss, list(self.w.values()))
        new_grads_and_vars = [(tf.clip_by_norm(grad, 40), var)
                              for grad, var in grads_and_vars]
        self.optim = self.optimizer.apply_gradients(new_grads_and_vars)

    # Re-home the variables under the 'optimizer' scope from
    # GLOBAL_VARIABLES to LOCAL_VARIABLES.
    global_collection = tf.get_collection_ref(tf.GraphKeys.GLOBAL_VARIABLES)
    for var in variables.get_variables(scope="optimizer"):
        tf.add_to_collection(tf.GraphKeys.LOCAL_VARIABLES, var)
        global_collection.remove(var)

    # TODO: copying weights from a global network is not wired up; the
    # previous placeholder for it was unreachable and referenced an
    # undefined `global_network`.
def add_fc_weights_summary(name, path):
    """Add an image summary built from a layer's weights and biases.

    The variables matching ``path + '/weights'`` and ``path + '/biases'``
    are looked up, the biases gain an extra axis at position 1, both are
    concatenated along axis 1, and the transposed result is written as an
    image summary.

    Args:
        name: name given to the image summary.
        path: scope prefix of the layer; presumably each lookup matches
            exactly one variable — TODO confirm against callers.
    """
    bias_vars = variables.get_variables(path + '/biases')
    weight_vars = variables.get_variables(path + '/weights')
    bias_column = tf.expand_dims(bias_vars, 1)
    stacked = tf.concat([weight_vars, bias_column], 1)
    tf.summary.image(name, [tf.transpose(stacked)])
def main(unused_argv=None):
    """Build the style-transfer fine-tuning graph and run training.

    Reads its configuration from ``FLAGS``: builds the content and style
    input pipelines, the transformer model and its losses, restores VGG16
    and previously trained transformer weights from checkpoints, and trains
    only the variables whose name contains 'InstanceNorm'.

    Args:
        unused_argv: ignored.

    Raises:
        ValueError: if the number of style coefficients differs from
            ``FLAGS.num_styles``.
    """
    with tf.Graph().as_default():
        # Force all input processing onto CPU in order to reserve the GPU for
        # the forward inference and back-propagation.
        device = '/cpu:0' if not FLAGS.ps_tasks else '/job:worker/cpu:0'
        with tf.device(
                tf.train.replica_device_setter(FLAGS.ps_tasks,
                                               worker_device=device)):
            inputs, _ = image_utils.imagenet_inputs(FLAGS.batch_size,
                                                    FLAGS.image_size)
            # Load style images and select one at random (for each graph
            # execution, a new random selection occurs)
            _, style_labels, style_gram_matrices = \
                image_utils.style_image_inputs(
                    os.path.expanduser(FLAGS.style_dataset_file),
                    batch_size=FLAGS.batch_size,
                    image_size=FLAGS.image_size,
                    square_crop=True,
                    shuffle=True)

        with tf.device(tf.train.replica_device_setter(FLAGS.ps_tasks)):
            # Process style and weight flags
            num_styles = FLAGS.num_styles
            if FLAGS.style_coefficients is None:
                style_coefficients = [1.0] * num_styles
            else:
                style_coefficients = ast.literal_eval(
                    FLAGS.style_coefficients)
            if len(style_coefficients) != num_styles:
                raise ValueError(
                    'number of style coefficients differs from number of styles'
                )
            content_weights = ast.literal_eval(FLAGS.content_weights)
            style_weights = ast.literal_eval(FLAGS.style_weights)

            # Rescale style weights dynamically based on the current style
            # image
            style_coefficient = tf.gather(tf.constant(style_coefficients),
                                          style_labels)
            style_weights = {
                key: style_coefficient * value
                for key, value in style_weights.items()
            }

            # Define the model
            stylized_inputs = model.transform(inputs,
                                              alpha=FLAGS.alpha,
                                              normalizer_params={
                                                  'labels': style_labels,
                                                  'num_categories': num_styles,
                                                  'center': True,
                                                  'scale': True
                                              })

            # Compute losses.
            total_loss, loss_dict = learning.total_loss(
                inputs, stylized_inputs, style_gram_matrices, content_weights,
                style_weights)
            for key, value in loss_dict.items():
                tf.summary.scalar(key, value)

            # Split the transformer variables: instance-norm parameters are
            # the ones being trained, everything else is restored as-is.
            transformer_vars = slim.get_variables('transformer')
            instance_norm_vars = [
                var for var in transformer_vars if 'InstanceNorm' in var.name
            ]
            other_vars = [
                var for var in transformer_vars
                if 'InstanceNorm' not in var.name
            ]

            # Function to restore VGG16 parameters.
            init_fn_vgg = slim.assign_from_checkpoint_fn(
                vgg.checkpoint_file(), slim.get_variables('vgg_16'))

            checkpoint = os.path.expanduser(FLAGS.checkpoint)
            if tf.gfile.IsDirectory(checkpoint):
                checkpoint = tf.train.latest_checkpoint(checkpoint)
                tf.logging.info(
                    'loading latest checkpoint file: {}'.format(checkpoint))

            # Function to restore N-styles parameters.
            vars_to_restore = (transformer_vars
                               if FLAGS.restore_all_weights else other_vars)
            init_fn_n_styles = slim.assign_from_checkpoint_fn(
                checkpoint, vars_to_restore)

            def init_fn(session):
                init_fn_vgg(session)
                init_fn_n_styles(session)

            # Set up training.
            optimizer = tf.train.AdamOptimizer(FLAGS.learning_rate)
            train_op = slim.learning.create_train_op(
                total_loss,
                optimizer,
                clip_gradient_norm=FLAGS.clip_gradient_norm,
                variables_to_train=instance_norm_vars,
                summarize_gradients=False)

            # Queried again (not reusing transformer_vars) because it runs
            # after the train op has been created.
            savertransformer = tf.train.Saver(
                variables.get_variables("transformer"),
                save_relative_paths=True)

            # Run training.
            slim.learning.train(train_op=train_op,
                                logdir=os.path.expanduser(FLAGS.train_dir),
                                log_every_n_steps=FLAGS.log_every_n_steps,
                                master=FLAGS.master,
                                is_chief=FLAGS.task == 0,
                                number_of_steps=FLAGS.train_steps,
                                init_fn=init_fn,
                                saver=savertransformer,
                                save_summaries_secs=FLAGS.save_summaries_secs,
                                save_interval_secs=FLAGS.save_interval_secs)