def testTrainingSubsetsOfVariablesOnlyUpdatesThoseVariables(self):
  """Restricting `variables_to_train` limits which variables a step changes.

  Three train ops share one graph: one trains every variable, one only the
  weights and one only the biases. After running each op once, the test checks
  which of the two variables actually moved.
  """
  with ops.Graph().as_default():
    random_seed.set_random_seed(0)
    total_loss = self.ModelLoss()
    optimizer = gradient_descent.GradientDescentOptimizer(learning_rate=1.0)

    weights, biases = variables_lib.get_variables()

    train_all = training.create_train_op(total_loss, optimizer)
    train_weights_only = training.create_train_op(
        total_loss, optimizer, variables_to_train=[weights])
    train_biases_only = training.create_train_op(
        total_loss, optimizer, variables_to_train=[biases])

    with self.test_session() as session:
      session.run(variables_lib2.global_variables_initializer())

      # Before training: weights must be non-zero, biases must be zero.
      old_w, old_b = session.run([weights, biases])
      self.assertGreater(np.linalg.norm(old_w), 0)
      self.assertAlmostEqual(np.linalg.norm(old_b), 0)

      # Step 1: the unrestricted op must move both variables.
      self.assertGreater(session.run(train_all), .5)
      cur_w, cur_b = session.run([weights, biases])
      self.assertGreater(np.linalg.norm(old_w - cur_w), 0)
      self.assertGreater(np.linalg.norm(old_b - cur_b), 0)
      old_w, old_b = cur_w, cur_b

      # Step 2: only the weights may move.
      self.assertGreater(session.run(train_weights_only), .5)
      cur_w, cur_b = session.run([weights, biases])
      self.assertGreater(np.linalg.norm(old_w - cur_w), 0)
      self.assertAlmostEqual(np.linalg.norm(old_b - cur_b), 0)
      old_w = cur_w

      # Step 3: only the biases may move.
      self.assertGreater(session.run(train_biases_only), .5)
      cur_w, cur_b = session.run([weights, biases])
      self.assertAlmostEqual(np.linalg.norm(old_w - cur_w), 0)
      self.assertGreater(np.linalg.norm(old_b - cur_b), 0)
def testTrainingSubsetsOfVariablesOnlyUpdatesThoseVariables(self):
  """Each train op only moves the variables listed in `variables_to_train`.

  Runs, in order, an op over all variables, one restricted to the weights and
  one restricted to the biases. After each step it checks which variables
  changed against an expectation table.
  """
  with ops.Graph().as_default():
    random_seed.set_random_seed(0)
    total_loss = self.ModelLoss()
    optimizer = gradient_descent.GradientDescentOptimizer(learning_rate=1.0)

    weights, biases = variables_lib.get_variables()

    train_op = training.create_train_op(total_loss, optimizer)
    train_weights = training.create_train_op(
        total_loss, optimizer, variables_to_train=[weights])
    train_biases = training.create_train_op(
        total_loss, optimizer, variables_to_train=[biases])

    # (op to run, weights expected to change, biases expected to change)
    schedule = (
        (train_op, True, True),
        (train_weights, True, False),
        (train_biases, False, True),
    )

    with session_lib.Session() as sess:
      sess.run(variables_lib2.global_variables_initializer())

      # Initial state: non-zero weights, all-zero biases.
      prev_w, prev_b = sess.run([weights, biases])
      self.assertGreater(np.linalg.norm(prev_w), 0)
      self.assertAlmostEqual(np.linalg.norm(prev_b), 0)

      for op, w_should_move, b_should_move in schedule:
        self.assertGreater(sess.run(op), .5)
        next_w, next_b = sess.run([weights, biases])
        for should_move, delta in ((w_should_move, prev_w - next_w),
                                   (b_should_move, prev_b - next_b)):
          if should_move:
            self.assertGreater(np.linalg.norm(delta), 0)
          else:
            self.assertAlmostEqual(np.linalg.norm(delta), 0)
        prev_w, prev_b = next_w, next_b
def testEmptyUpdateOps(self):
  """Passing `update_ops=[]` leaves the batch-norm moving statistics unchanged."""
  with ops.Graph().as_default():
    random_seed.set_random_seed(0)
    inputs = constant_op.constant(self._inputs, dtype=dtypes.float32)
    labels = constant_op.constant(self._labels, dtype=dtypes.float32)

    predictions = batchnorm_classifier(inputs)
    loss_ops.log_loss(predictions, labels)
    total_loss = loss_ops.get_total_loss()

    optimizer = gradient_descent.GradientDescentOptimizer(learning_rate=1.0)
    train_op = training.create_train_op(total_loss, optimizer, update_ops=[])

    moving_mean = variables_lib.get_variables_by_name('moving_mean')[0]
    moving_variance = variables_lib.get_variables_by_name('moving_variance')[0]

    with session_lib.Session() as sess:
      sess.run(variables_lib2.global_variables_initializer())

      # Freshly initialized statistics: mean 0 and variance 1 per feature.
      start_mean, start_var = sess.run([moving_mean, moving_variance])
      self.assertAllClose(start_mean, [0] * 4)
      self.assertAllClose(start_var, [1] * 4)

      for _ in range(10):
        sess.run([train_op])

      # No update ops were attached, so nothing touched the moving statistics.
      end_mean = moving_mean.eval()
      end_var = moving_variance.eval()
      self.assertAllClose(end_mean, [0] * 4)
      self.assertAllClose(end_var, [1] * 4)
def _train_model(self, checkpoint_dir, num_steps):
  """Trains a simple classification model.

  Note that the data has been configured such that after around 300 steps,
  the model has memorized the dataset (i.e. we can expect 100% accuracy).

  Args:
    checkpoint_dir: The directory where the checkpoint is written to.
    num_steps: The number of steps to train for.
  """
  with ops.Graph().as_default():
    random_seed.set_random_seed(0)
    tf_inputs = constant_op.constant(self._inputs, dtype=dtypes.float32)
    tf_labels = constant_op.constant(self._labels, dtype=dtypes.float32)

    tf_predictions = logistic_classifier(tf_inputs)
    loss = loss_ops.log_loss(tf_predictions, tf_labels)

    optimizer = gradient_descent.GradientDescentOptimizer(learning_rate=1.0)
    train_op = training.create_train_op(loss, optimizer)

    # This helper returns nothing, so the final loss value produced by
    # `training.train` is not kept (previously it overwrote the `loss` tensor).
    training.train(
        train_op,
        checkpoint_dir,
        hooks=[basic_session_run_hooks.StopAtStepHook(num_steps)])
def testNoneGlobalStep(self):
  """A train op built with `global_step=None` never advances the global step."""
  with ops.Graph().as_default():
    random_seed.set_random_seed(0)
    inputs = constant_op.constant(self._inputs, dtype=dtypes.float32)
    labels = constant_op.constant(self._labels, dtype=dtypes.float32)

    predictions = batchnorm_classifier(inputs)
    loss_ops.log_loss(predictions, labels)
    total_loss = loss_ops.get_total_loss()

    optimizer = gradient_descent.GradientDescentOptimizer(learning_rate=1.0)
    train_op = training.create_train_op(total_loss, optimizer, global_step=None)

    # Created after the train op, so the train op cannot reference it.
    global_step = variables_lib.get_or_create_global_step()

    with session_lib.Session() as sess:
      sess.run(variables_lib2.global_variables_initializer())
      for _ in range(10):
        sess.run([train_op])

      step_value = global_step.eval()
      # The train op does not use the global step, so it must still be 0.
      self.assertAllClose(step_value, 0)
def testGlobalStepIsIncrementedByDefault(self):
  """Without an explicit `global_step`, every train step increments it by one."""
  with ops.Graph().as_default():
    random_seed.set_random_seed(0)
    inputs = constant_op.constant(self._inputs, dtype=dtypes.float32)
    labels = constant_op.constant(self._labels, dtype=dtypes.float32)

    predictions = batchnorm_classifier(inputs)
    loss = losses.log_loss(labels, predictions)

    train_op = training.create_train_op(
        loss, gradient_descent.GradientDescentOptimizer(learning_rate=1.0))
    global_step = variables_lib.get_or_create_global_step()

    num_updates = 10
    with self.test_session() as session:
      session.run(variables_lib2.global_variables_initializer())
      for _ in range(num_updates):
        session.run(train_op)
      # One increment per executed train op.
      self.assertAllClose(global_step.eval(), num_updates)
def testResumeTrainAchievesRoughlyTheSameLoss(self):
  """Follow-up training runs in the same logdir keep the loss below 0.015.

  Trains the logistic classifier for 300 steps, then again for 1 and for 5
  steps. Each run uses a fresh graph with a different random seed and writes
  checkpoints to the same `logdir`. Every run's final loss must stay under
  0.015. This presumably relies on later runs resuming from the checkpoints in
  `logdir` rather than starting from scratch.
  """
  number_of_steps = [300, 1, 5]
  logdir = os.path.join(self.get_temp_dir(), 'resume_train_same_loss')

  for i, num_steps in enumerate(number_of_steps):
    with ops.Graph().as_default():
      random_seed.set_random_seed(i)
      tf_inputs = constant_op.constant(self._inputs, dtype=dtypes.float32)
      tf_labels = constant_op.constant(self._labels, dtype=dtypes.float32)

      tf_predictions = logistic_classifier(tf_inputs)
      loss_ops.log_loss(tf_predictions, tf_labels)
      total_loss = loss_ops.get_total_loss()

      optimizer = gradient_descent.GradientDescentOptimizer(learning_rate=1.0)
      train_op = training.create_train_op(total_loss, optimizer)

      saver = saver_lib.Saver()
      loss = training.train(
          train_op,
          logdir,
          hooks=[
              basic_session_run_hooks.StopAtStepHook(num_steps=num_steps),
              basic_session_run_hooks.CheckpointSaverHook(
                  logdir, save_steps=50, saver=saver),
          ])
      self.assertIsNotNone(loss)
      self.assertLess(loss, .015)
def dis_train_op():
  """Builds the discriminator train op under the 'discriminator_train' scope.

  `gan_loss`, `discriminator_optimizer`, `gan_model` and `dis_update_ops` are
  free variables resolved from an enclosing scope not shown here.
  """
  with ops.name_scope('discriminator_train'):
    train_op = training.create_train_op(
        total_loss=gan_loss.discriminator_loss,
        optimizer=discriminator_optimizer,
        variables_to_train=gan_model.discriminator_variables,
        update_ops=dis_update_ops)
  return train_op
def gen_train_op():
  """Builds the generator train op under the 'generator_train' scope.

  `gan_loss`, `generator_optimizer`, `gan_model` and `gen_update_ops` are free
  variables resolved from an enclosing scope not shown here.
  """
  with ops.name_scope('generator_train'):
    train_op = training.create_train_op(
        total_loss=gan_loss.generator_loss,
        optimizer=generator_optimizer,
        variables_to_train=gan_model.generator_variables,
        update_ops=gen_update_ops)
  return train_op
def create_train_op_v2(total_loss, optimizer, global_step=_USE_GLOBAL_STEP,
                       update_ops_before_loss=None, update_ops_after_loss=None,
                       variables_to_train=None, transform_grads_fn=None,
                       summarize_gradients=False,
                       gate_gradients=tf.train.Optimizer.GATE_OP,
                       aggregation_method=None,
                       colocate_gradients_with_ops=False,
                       check_numerics=True):
    """Like `create_train_op`, but also accepts update ops to run after it.

    Args:
      total_loss: The loss tensor to minimize.
      optimizer: The optimizer used to compute and apply gradients.
      global_step: Forwarded unchanged to `create_train_op`.
      update_ops_before_loss: Forwarded to `create_train_op` as `update_ops`.
      update_ops_after_loss: An op/tensor, or a list/tuple of them, that the
        returned train op additionally waits for. `None` or an empty sequence
        means "no extra ops".
      variables_to_train, transform_grads_fn, summarize_gradients,
      gate_gradients, aggregation_method, colocate_gradients_with_ops,
      check_numerics: Forwarded unchanged to `create_train_op`.

    Returns:
      A tensor that evaluates to the same value as the op returned by
      `create_train_op` (the loss). Previously the value of the update op was
      returned instead, and a list of update ops could not be handled.
    """
    train_op = create_train_op(total_loss=total_loss,
                               optimizer=optimizer,
                               global_step=global_step,
                               update_ops=update_ops_before_loss,
                               variables_to_train=variables_to_train,
                               transform_grads_fn=transform_grads_fn,
                               summarize_gradients=summarize_gradients,
                               gate_gradients=gate_gradients,
                               aggregation_method=aggregation_method,
                               colocate_gradients_with_ops=colocate_gradients_with_ops,
                               check_numerics=check_numerics)
    if not update_ops_after_loss:
        return train_op
    if not isinstance(update_ops_after_loss, (list, tuple)):
        update_ops_after_loss = [update_ops_after_loss]

    # NOTE(review): control dependencies only affect ops *created* inside the
    # context. The update ops were created before this call, so only the
    # grouping no-op below is guaranteed to wait for `train_op`; the update
    # ops themselves may still execute before or alongside it. To guarantee
    # strict ordering, the update ops must be built under
    # `tf.control_dependencies([train_op])`.
    with tf.control_dependencies([train_op]):
        after_loss = tf.group(*update_ops_after_loss)
    # Return the loss value, gated on the after-loss ops.
    return with_dependencies([after_loss], train_op)
def dis_train_op():
  """Returns the discriminator train op, built inside 'discriminator_train'.

  Relies on `gan_loss`, `discriminator_optimizer`, `gan_model` and
  `dis_update_ops` from an enclosing scope that is not visible here.
  """
  op_kwargs = dict(
      total_loss=gan_loss.discriminator_loss,
      optimizer=discriminator_optimizer,
      variables_to_train=gan_model.discriminator_variables,
      update_ops=dis_update_ops)
  with ops.name_scope('discriminator_train'):
    return training.create_train_op(**op_kwargs)
def gen_train_op():
  """Returns the generator train op, built inside 'generator_train'.

  Relies on `gan_loss`, `generator_optimizer`, `gan_model` and
  `gen_update_ops` from an enclosing scope that is not visible here.
  """
  op_kwargs = dict(
      total_loss=gan_loss.generator_loss,
      optimizer=generator_optimizer,
      variables_to_train=gan_model.generator_variables,
      update_ops=gen_update_ops)
  with ops.name_scope('generator_train'):
    return training.create_train_op(**op_kwargs)
def _train_model(self, checkpoint_dir, num_steps):
  """Trains a simple classification model.

  Note that the data has been configured such that after around 300 steps,
  the model has memorized the dataset (i.e. we can expect 100% accuracy).

  Args:
    checkpoint_dir: The directory where the checkpoint is written to.
    num_steps: The number of steps to train for.
  """
  with ops.Graph().as_default():
    random_seed.set_random_seed(0)
    tf_inputs = constant_op.constant(self._inputs, dtype=dtypes.float32)
    tf_labels = constant_op.constant(self._labels, dtype=dtypes.float32)

    tf_predictions = logistic_classifier(tf_inputs)
    loss = loss_ops.log_loss(tf_predictions, tf_labels)

    optimizer = gradient_descent.GradientDescentOptimizer(learning_rate=1.0)
    train_op = training.create_train_op(loss, optimizer)

    # This helper returns nothing, so the final loss value produced by
    # `training.train` is not kept (previously it overwrote the `loss` tensor).
    training.train(
        train_op,
        checkpoint_dir,
        hooks=[basic_session_run_hooks.StopAtStepHook(num_steps)])
def testCanAchieveZeroLoss(self):
  """300 SGD steps bring the logistic classifier's loss below 0.015."""
  logdir = os.path.join(self.get_temp_dir(), 'can_achieve_zero_loss')
  with ops.Graph().as_default():
    random_seed.set_random_seed(0)
    inputs = constant_op.constant(self._inputs, dtype=dtypes.float32)
    labels = constant_op.constant(self._labels, dtype=dtypes.float32)

    loss_ops.log_loss(logistic_classifier(inputs), labels)
    train_op = training.create_train_op(
        loss_ops.get_total_loss(),
        gradient_descent.GradientDescentOptimizer(learning_rate=1.0))

    stop_after_300 = basic_session_run_hooks.StopAtStepHook(num_steps=300)
    final_loss = training.train(train_op, logdir, hooks=[stop_after_300])

    self.assertIsNotNone(final_loss)
    self.assertLess(final_loss, .015)
def testTrainWithLocalVariable(self):
  """A graph containing a local variable still trains to a loss below 0.015.

  Runs without a log directory; summaries and checkpoints are disabled.
  """
  with ops.Graph().as_default():
    random_seed.set_random_seed(0)
    inputs = constant_op.constant(self._inputs, dtype=dtypes.float32)
    labels = constant_op.constant(self._labels, dtype=dtypes.float32)

    # Scales the predictions by a local (non-global) variable initialised to 1.
    multiplier = variables_lib.local_variable(1.0)
    predictions = logistic_classifier(inputs) * multiplier

    losses.log_loss(labels, predictions)
    total_loss = losses.get_total_loss()

    optimizer = gradient_descent.GradientDescentOptimizer(learning_rate=1.0)
    train_op = training.create_train_op(total_loss, optimizer)

    final_loss = training.train(
        train_op,
        None,
        hooks=[basic_session_run_hooks.StopAtStepHook(num_steps=300)],
        save_summaries_steps=None,
        save_checkpoint_secs=None)

    self.assertIsNotNone(final_loss)
    self.assertLess(final_loss, .015)
def create_train_op(self, learning_rate=1.0, gradient_multiplier=1.0):
  """Builds a train op for the logistic classifier on this test's data.

  Args:
    learning_rate: Learning rate of the gradient-descent optimizer.
    gradient_multiplier: Factor applied to the gradient of every trainable
      variable; exactly 1.0 leaves the gradients untouched.

  Returns:
    The op returned by `training.create_train_op`.
  """
  inputs = constant_op.constant(self._inputs, dtype=dtypes.float32)
  labels = constant_op.constant(self._labels, dtype=dtypes.float32)
  predictions = logistic_classifier(inputs)
  losses.log_loss(labels, predictions)
  total_loss = losses.get_total_loss()

  optimizer = gradient_descent.GradientDescentOptimizer(
      learning_rate=learning_rate)

  def transform_grads_fn(grads):
    # Fast path: no scaling requested.
    if gradient_multiplier == 1.0:
      return grads
    multipliers = dict.fromkeys(variables_lib2.trainable_variables(),
                                gradient_multiplier)
    with ops.name_scope('multiply_grads'):
      return training.multiply_gradients(grads, multipliers)

  return training.create_train_op(
      total_loss, optimizer, transform_grads_fn=transform_grads_fn)
def gan_train_ops(
        self,
        model,
        loss,
        generator_optimizer,
        discriminator_optimizer,
        check_for_unused_update_ops=True,
):
    """Builds train ops for the EG, Dz and Di sub-networks.

    Each train op is built from the UPDATE_OPS registered under that
    sub-network's scope (`model.EG_scope`, `model.Dz_scope`, `model.Di_scope`).
    The losses and optimizers are read from `self`.

    NOTE(review): `loss`, `generator_optimizer`, `discriminator_optimizer` and
    `check_for_unused_update_ops` are accepted but never read in this body.

    Returns:
      `namedtuples.GANTrainOps(gen_train_op, disc_train_op, global_step_inc)`.
      NOTE(review): `disc_train_op` is never assigned in this function, so the
      return raises NameError unless a global of that name exists. The Dz/Di
      train ops are built but then discarded, because `gen_train_op` is
      rebound each time.
    """
    # Create global step increment op.
    global_step = training_util.get_or_create_global_step()
    global_step_inc = global_step.assign_add(1)

    # Update ops under the encoder/generator scope. The intersection with the
    # full collection presumably equals the scoped subset already - confirm.
    update_ops = set(ops.get_collection(ops.GraphKeys.UPDATE_OPS))
    all_ops = set(ops.get_collection(ops.GraphKeys.UPDATE_OPS, model.EG_scope.name))
    update_ops = list(all_ops & update_ops)
    with ops.name_scope('EG_train'):
        gen_train_op = training.create_train_op(
            total_loss=self.loss_EG,
            optimizer=self.EG_optimizer,
            variables_to_train=self.E_variables + self.G_variables,
            global_step=self.EG_global_step,
            update_ops=update_ops)

    # Update ops under the Dz scope.
    update_ops = set(ops.get_collection(ops.GraphKeys.UPDATE_OPS))
    all_ops = set(ops.get_collection(ops.GraphKeys.UPDATE_OPS, model.Dz_scope.name))
    update_ops = list(all_ops & update_ops)
    with ops.name_scope('Dz_train'):
        # NOTE(review): looks copy-pasted from EG_train - this trains the
        # E+G variables with EG_global_step rather than the Dz variables, and
        # overwrites `gen_train_op`. Verify the intended variables/step/name.
        gen_train_op = training.create_train_op(
            total_loss=self.loss_Dz,
            optimizer=self.D_z_optimizer,
            variables_to_train=self.E_variables + self.G_variables,
            global_step=self.EG_global_step,
            update_ops=update_ops)

    # Update ops under the Di scope.
    update_ops = set(ops.get_collection(ops.GraphKeys.UPDATE_OPS))
    all_ops = set(ops.get_collection(ops.GraphKeys.UPDATE_OPS, model.Di_scope.name))
    update_ops = list(all_ops & update_ops)
    with ops.name_scope('Di_train'):
        # NOTE(review): same copy-paste concern as Dz_train above.
        gen_train_op = training.create_train_op(
            total_loss=self.loss_Di,
            optimizer=self.D_img_optimizer,
            variables_to_train=self.E_variables + self.G_variables,
            global_step=self.EG_global_step,
            update_ops=update_ops)
    # NOTE(review): `disc_train_op` is undefined here (see docstring).
    return namedtuples.GANTrainOps(gen_train_op, disc_train_op, global_step_inc)
def testTrainOpInCollection(self):
  """`training.create_train_op` records its result in the TRAIN_OP collection."""
  with ops.Graph().as_default():
    tf_inputs = constant_op.constant(self._inputs, dtype=dtypes.float32)
    tf_labels = constant_op.constant(self._labels, dtype=dtypes.float32)

    tf_predictions = batchnorm_classifier(tf_inputs)
    loss = losses.log_loss(tf_labels, tf_predictions)
    optimizer = gradient_descent.GradientDescentOptimizer(learning_rate=1.0)
    train_op = training.create_train_op(loss, optimizer)

    # assertIn shows the collection contents on failure, whereas
    # assertTrue(x in y) only reports "False is not true".
    self.assertIn(train_op, ops.get_collection(ops.GraphKeys.TRAIN_OP))
def testTrainOpInCollection(self):
  """`training.create_train_op` records its result in the TRAIN_OP collection."""
  with ops.Graph().as_default():
    tf_inputs = constant_op.constant(self._inputs, dtype=dtypes.float32)
    tf_labels = constant_op.constant(self._labels, dtype=dtypes.float32)

    tf_predictions = batchnorm_classifier(tf_inputs)
    loss = losses.log_loss(tf_labels, tf_predictions)
    optimizer = gradient_descent.GradientDescentOptimizer(learning_rate=1.0)
    train_op = training.create_train_op(loss, optimizer)

    # assertIn shows the collection contents on failure, whereas
    # assertTrue(x in y) only reports "False is not true".
    self.assertIn(train_op, ops.get_collection(ops.GraphKeys.TRAIN_OP))
def testCanAchieveZeroLoss(self):
  """300 SGD steps bring the logistic loss below 0.015 (no logdir written)."""
  with ops.Graph().as_default():
    random_seed.set_random_seed(0)
    inputs = constant_op.constant(self._inputs, dtype=dtypes.float32)
    labels = constant_op.constant(self._labels, dtype=dtypes.float32)

    losses.log_loss(labels, logistic_classifier(inputs))
    train_op = training.create_train_op(
        losses.get_total_loss(),
        gradient_descent.GradientDescentOptimizer(learning_rate=1.0))

    # No checkpoint dir: summaries and checkpoints are both switched off.
    final_loss = training.train(
        train_op,
        None,
        hooks=[basic_session_run_hooks.StopAtStepHook(num_steps=300)],
        save_summaries_steps=None,
        save_checkpoint_secs=None)

    self.assertIsNotNone(final_loss)
    self.assertLess(final_loss, .015)
def make_train_op(scope, optimizer, global_step, total_loss,
                  clip_gradient_norm=0, clip_gradient_value=0,
                  summarize_gradients=False):
    """Creates a train op limited to the trainable variables and UPDATE_OPS of `scope`.

    Args:
      scope: Scope used to select the trainable variables and update ops.
      optimizer: Optimizer used to compute and apply gradients.
      global_step: Forwarded to `create_train_op`.
      total_loss: Loss tensor to minimize.
      clip_gradient_norm: Forwarded to `make_transform_grads_fn`.
      clip_gradient_value: Forwarded to `make_transform_grads_fn`.
      summarize_gradients: Forwarded to `create_train_op`.

    Returns:
      Whatever `create_train_op` returns.
    """
    grads_fn = make_transform_grads_fn(
        clip_gradient_norm=clip_gradient_norm,
        clip_gradient_value=clip_gradient_value)
    scoped_variables = tf.trainable_variables(scope=scope)
    scoped_updates = tf.get_collection(key=tf.GraphKeys.UPDATE_OPS, scope=scope)
    return create_train_op(
        total_loss=total_loss,
        optimizer=optimizer,
        update_ops=scoped_updates,
        variables_to_train=scoped_variables,
        transform_grads_fn=grads_fn,
        summarize_gradients=summarize_gradients,
        global_step=global_step)
def testTrainWithNoInitAssignCanAchieveZeroLoss(self):
  """The batch-norm classifier trained for 300 steps ends with loss below 0.1."""
  with ops.Graph().as_default():
    random_seed.set_random_seed(0)
    inputs = constant_op.constant(self._inputs, dtype=dtypes.float32)
    labels = constant_op.constant(self._labels, dtype=dtypes.float32)

    predictions = batchnorm_classifier(inputs)
    loss_ops.log_loss(predictions, labels)
    total_loss = loss_ops.get_total_loss()

    sgd = gradient_descent.GradientDescentOptimizer(learning_rate=1.0)
    train_op = training.create_train_op(total_loss, sgd)

    stop_hook = basic_session_run_hooks.StopAtStepHook(num_steps=300)
    final_loss = training.train(train_op, self._logdir, hooks=[stop_hook])
    self.assertLess(final_loss, .1)
def testCanAchieveZeroLoss(self):
  """Training the logistic classifier for 300 steps yields a loss below 0.015."""
  logdir = os.path.join(self.get_temp_dir(), 'can_achieve_zero_loss')
  with ops.Graph().as_default():
    random_seed.set_random_seed(0)
    features = constant_op.constant(self._inputs, dtype=dtypes.float32)
    targets = constant_op.constant(self._labels, dtype=dtypes.float32)

    logits = logistic_classifier(features)
    loss_ops.log_loss(logits, targets)
    objective = loss_ops.get_total_loss()

    sgd = gradient_descent.GradientDescentOptimizer(learning_rate=1.0)
    train_op = training.create_train_op(objective, sgd)

    result = training.train(
        train_op, logdir,
        hooks=[basic_session_run_hooks.StopAtStepHook(num_steps=300)])

    self.assertIsNotNone(result)
    self.assertLess(result, .015)
def testGlobalStepIsIncrementedByDefault(self):
  """The default train op increments the global step once per run."""
  with ops.Graph().as_default():
    random_seed.set_random_seed(0)
    features = constant_op.constant(self._inputs, dtype=dtypes.float32)
    targets = constant_op.constant(self._labels, dtype=dtypes.float32)

    outputs = batchnorm_classifier(features)
    loss = losses.log_loss(targets, outputs)
    sgd = gradient_descent.GradientDescentOptimizer(learning_rate=1.0)
    train_op = training.create_train_op(loss, sgd)

    global_step = variables_lib.get_or_create_global_step()

    with self.test_session() as session:
      session.run(variables_lib2.global_variables_initializer())
      step = 0
      while step < 10:
        session.run(train_op)
        step += 1
      # Ten runs of the train op -> global step of 10.
      self.assertAllClose(global_step.eval(), 10)
def testUseUpdateOps(self):
  """Default update ops drive the batch-norm moving stats to the data stats."""
  with ops.Graph().as_default():
    random_seed.set_random_seed(0)
    inputs = constant_op.constant(self._inputs, dtype=dtypes.float32)
    labels = constant_op.constant(self._labels, dtype=dtypes.float32)

    # Per-feature statistics of the raw inputs that the moving stats should reach.
    expected_mean = np.mean(self._inputs, axis=0)
    expected_var = np.var(self._inputs, axis=0)

    predictions = batchnorm_classifier(inputs)
    loss = losses.log_loss(labels, predictions)
    optimizer = gradient_descent.GradientDescentOptimizer(learning_rate=1.0)
    train_op = training.create_train_op(loss, optimizer)

    moving_mean = variables_lib.get_variables_by_name('moving_mean')[0]
    moving_variance = variables_lib.get_variables_by_name('moving_variance')[0]

    with self.test_session() as session:
      session.run(variables_lib2.global_variables_initializer())

      initial_mean, initial_var = session.run([moving_mean, moving_variance])
      # Initial moving stats: mean 0, variance 1.
      self.assertAllClose(initial_mean, [0] * 4)
      self.assertAllClose(initial_var, [1] * 4)

      for _ in range(10):
        session.run(train_op)

      # Per the original test note, the classifier's batch-norm decay is 0.1,
      # so 10 updates suffice (decay is set in batchnorm_classifier, not here).
      final_mean = moving_mean.eval()
      final_var = moving_variance.eval()
      self.assertAllClose(final_mean, expected_mean)
      self.assertAllClose(final_var, expected_var)
def create_train_op(self, learning_rate=1.0, gradient_multiplier=1.0):
  """Returns a train op for the logistic classifier on the test data.

  Args:
    learning_rate: Learning rate for gradient descent.
    gradient_multiplier: Scale applied to every trainable variable's gradient;
      the value 1.0 disables scaling.
  """
  features = constant_op.constant(self._inputs, dtype=dtypes.float32)
  targets = constant_op.constant(self._labels, dtype=dtypes.float32)
  scores = logistic_classifier(features)
  loss_ops.log_loss(scores, targets)
  objective = loss_ops.get_total_loss()
  sgd = gradient_descent.GradientDescentOptimizer(learning_rate=learning_rate)

  def transform_grads_fn(grads):
    if gradient_multiplier == 1.0:
      return grads
    scale_by_var = {}
    for var in variables_lib2.trainable_variables():
      scale_by_var[var] = gradient_multiplier
    with ops.name_scope('multiply_grads'):
      return training.multiply_gradients(grads, scale_by_var)

  return training.create_train_op(
      objective, sgd, transform_grads_fn=transform_grads_fn)
def testUseUpdateOps(self):
  """After 10 steps the batch-norm moving stats equal the input statistics."""
  with ops.Graph().as_default():
    random_seed.set_random_seed(0)
    features = constant_op.constant(self._inputs, dtype=dtypes.float32)
    targets = constant_op.constant(self._labels, dtype=dtypes.float32)

    target_mean, target_var = (np.mean(self._inputs, axis=0),
                               np.var(self._inputs, axis=0))

    outputs = batchnorm_classifier(features)
    loss = losses.log_loss(targets, outputs)
    sgd = gradient_descent.GradientDescentOptimizer(learning_rate=1.0)
    train_op = training.create_train_op(loss, sgd)

    mean_var = variables_lib.get_variables_by_name('moving_mean')[0]
    variance_var = variables_lib.get_variables_by_name('moving_variance')[0]

    with self.test_session() as session:
      session.run(variables_lib2.global_variables_initializer())
      mean, variance = session.run([mean_var, variance_var])
      # Freshly initialized: zero mean, unit variance.
      self.assertAllClose(mean, [0] * 4)
      self.assertAllClose(variance, [1] * 4)

      for _ in range(10):
        session.run(train_op)

      # The original note attributes convergence in 10 steps to a batch-norm
      # decay of 0.1 configured inside batchnorm_classifier.
      mean = mean_var.eval()
      variance = variance_var.eval()
      self.assertAllClose(mean, target_mean)
      self.assertAllClose(variance, target_var)
def create_train_op(total_loss,
                    optimizer,
                    global_step=_USE_GLOBAL_STEP,
                    update_ops=None,
                    variables_to_train=None,
                    clip_gradient_norm=0,
                    summarize_gradients=False,
                    gate_gradients=tf_optimizer.Optimizer.GATE_OP,
                    aggregation_method=None,
                    colocate_gradients_with_ops=False,
                    gradient_multipliers=None,
                    check_numerics=True):
  """Creates an `Operation` that evaluates the gradients and returns the loss.

  Adapts the `clip_gradient_norm` / `gradient_multipliers` options to
  `training.create_train_op`'s `transform_grads_fn` hook. Every other argument
  is forwarded unchanged.

  Args:
    total_loss: A `Tensor` representing the total loss.
    optimizer: A tf.Optimizer to use for computing the gradients.
    global_step: A `Tensor` representing the global step variable. If left as
      `_USE_GLOBAL_STEP`, then slim.variables.global_step() is used.
    update_ops: An optional list of updates to execute. If `update_ops` is
      `None`, then the update ops are set to the contents of the
      `tf.GraphKeys.UPDATE_OPS` collection. If `update_ops` is not `None`, but
      it doesn't contain all of the update ops in `tf.GraphKeys.UPDATE_OPS`,
      a warning will be displayed.
    variables_to_train: an optional list of variables to train. If None, it
      will default to all tf.trainable_variables().
    clip_gradient_norm: If greater than 0 then the gradients would be clipped
      by it.
    summarize_gradients: Whether or not add summaries for each gradient.
    gate_gradients: How to gate the computation of gradients. See tf.Optimizer.
    aggregation_method: Specifies the method used to combine gradient terms.
      Valid values are defined in the class `AggregationMethod`.
    colocate_gradients_with_ops: Whether or not to try colocating the
      gradients with the ops that generated them.
    gradient_multipliers: A dictionary of either `Variables` or `Variable` op
      names to the coefficient by which the associated gradient should be
      scaled.
    check_numerics: Whether or not we apply check_numerics.

  Returns:
    A `Tensor` that when evaluated, computes the gradients and returns the
    total loss value.
  """

  def transform_grads_fn(grads):
    # Scaling happens before clipping, so the clip sees the scaled gradients.
    out = grads
    if gradient_multipliers:
      with ops.name_scope('multiply_grads'):
        out = multiply_gradients(out, gradient_multipliers)
    if clip_gradient_norm > 0:
      with ops.name_scope('clip_grads'):
        out = clip_gradient_norms(out, clip_gradient_norm)
    return out

  forwarded = dict(
      total_loss=total_loss,
      optimizer=optimizer,
      global_step=global_step,
      update_ops=update_ops,
      variables_to_train=variables_to_train,
      transform_grads_fn=transform_grads_fn,
      summarize_gradients=summarize_gradients,
      gate_gradients=gate_gradients,
      aggregation_method=aggregation_method,
      colocate_gradients_with_ops=colocate_gradients_with_ops,
      check_numerics=check_numerics)
  return training.create_train_op(**forwarded)
def gan_train_ops(
    model,  # GANModel
    loss,  # GANLoss
    generator_optimizer,
    discriminator_optimizer,
    # Optional check flags.
    check_for_unused_update_ops=True,
    # Optional args to pass directly to the `create_train_op`.
    **kwargs):
  """Returns GAN train ops.

  The highest-level call in TFGAN. It is composed of functions that can also
  be called, should a user require more control over some part of the GAN
  training process.

  Args:
    model: A GANModel.
    loss: A GANLoss.
    generator_optimizer: The optimizer for generator updates.
    discriminator_optimizer: The optimizer for the discriminator updates.
    check_for_unused_update_ops: If `True`, throws an exception if there are
      update ops outside of the generator or discriminator scopes.
    **kwargs: Keyword args to pass directly to
      `training.create_train_op` for both the generator and
      discriminator train op.

  Returns:
    A GANTrainOps tuple of (generator_train_op, discriminator_train_op,
    global_step_inc_op) that can be used to train a generator/discriminator
    pair.
  """
  # A single op that advances the shared global step by one.
  global_step = training_util.get_or_create_global_step()
  global_step_inc = global_step.assign_add(1)

  # Partition the update ops between generator and discriminator so none run
  # twice per step. Raises if some update op belongs to neither scope (when
  # checking is on), and may modify `kwargs`.
  gen_update_ops, dis_update_ops = _get_update_ops(
      kwargs, model.generator_scope.name, model.discriminator_scope.name,
      check_for_unused_update_ops)

  def _sync_replicas_step(optimizer, variable_name, update_ops):
    """Creates a dummy step variable for SyncReplicas optimizers, else None.

    When a variable is created, an op copying the real global step into it is
    appended to `update_ops` in place.
    """
    if not isinstance(optimizer,
                      sync_replicas_optimizer.SyncReplicasOptimizer):
      return None
    # TODO(joelshor): Figure out a way to get this work without including the
    # dummy global step in the checkpoint.
    # WARNING: Making this variable a local variable causes sync replicas to
    # hang forever.
    dummy_step = variable_scope.get_variable(
        variable_name,
        shape=[],
        dtype=global_step.dtype.base_dtype,
        initializer=init_ops.zeros_initializer(),
        trainable=False,
        collections=[ops.GraphKeys.GLOBAL_VARIABLES])
    update_ops += [dummy_step.assign(global_step)]
    return dummy_step

  generator_global_step = _sync_replicas_step(
      generator_optimizer, 'dummy_global_step_generator', gen_update_ops)
  with ops.name_scope('generator_train'):
    gen_train_op = training.create_train_op(
        total_loss=loss.generator_loss,
        optimizer=generator_optimizer,
        variables_to_train=model.generator_variables,
        global_step=generator_global_step,
        update_ops=gen_update_ops,
        **kwargs)

  discriminator_global_step = _sync_replicas_step(
      discriminator_optimizer, 'dummy_global_step_discriminator',
      dis_update_ops)
  with ops.name_scope('discriminator_train'):
    disc_train_op = training.create_train_op(
        total_loss=loss.discriminator_loss,
        optimizer=discriminator_optimizer,
        variables_to_train=model.discriminator_variables,
        global_step=discriminator_global_step,
        update_ops=dis_update_ops,
        **kwargs)

  return namedtuples.GANTrainOps(gen_train_op, disc_train_op, global_step_inc)
def create_train_op(total_loss,
                    optimizer,
                    global_step=_USE_GLOBAL_STEP,
                    update_ops=None,
                    variables_to_train=None,
                    clip_gradient_norm=0,
                    summarize_gradients=False,
                    gate_gradients=tf_optimizer.Optimizer.GATE_OP,
                    aggregation_method=None,
                    colocate_gradients_with_ops=False,
                    gradient_multipliers=None):
  """Creates an `Operation` that evaluates the gradients and returns the loss.

  Wraps `training.create_train_op`. The `gradient_multipliers` and
  `clip_gradient_norm` options are expressed as a `transform_grads_fn`; every
  other argument is passed through.

  Args:
    total_loss: A `Tensor` representing the total loss.
    optimizer: A tf.Optimizer to use for computing the gradients.
    global_step: A `Tensor` representing the global step variable. If left as
      `_USE_GLOBAL_STEP`, then slim.variables.global_step() is used.
    update_ops: An optional list of updates to execute. If `update_ops` is
      `None`, then the update ops are set to the contents of the
      `tf.GraphKeys.UPDATE_OPS` collection. If `update_ops` is not `None`, but
      it doesn't contain all of the update ops in `tf.GraphKeys.UPDATE_OPS`,
      a warning will be displayed.
    variables_to_train: an optional list of variables to train. If None, it
      will default to all tf.trainable_variables().
    clip_gradient_norm: If greater than 0 then the gradients would be clipped
      by it.
    summarize_gradients: Whether or not add summaries for each gradient.
    gate_gradients: How to gate the computation of gradients. See tf.Optimizer.
    aggregation_method: Specifies the method used to combine gradient terms.
      Valid values are defined in the class `AggregationMethod`.
    colocate_gradients_with_ops: Whether or not to try colocating the
      gradients with the ops that generated them.
    gradient_multipliers: A dictionary of either `Variables` or `Variable` op
      names to the coefficient by which the associated gradient should be
      scaled.

  Returns:
    A `Tensor` that when evaluated, computes the gradients and returns the
    total loss value.
  """

  def transform_grads_fn(grads):
    # Apply the optional multipliers first, then the optional norm clip.
    result = grads
    if gradient_multipliers:
      with ops.name_scope('multiply_grads'):
        result = multiply_gradients(result, gradient_multipliers)
    if clip_gradient_norm > 0:
      with ops.name_scope('clip_grads'):
        result = clip_gradient_norms(result, clip_gradient_norm)
    return result

  forwarded = dict(
      total_loss=total_loss,
      optimizer=optimizer,
      global_step=global_step,
      update_ops=update_ops,
      variables_to_train=variables_to_train,
      transform_grads_fn=transform_grads_fn,
      summarize_gradients=summarize_gradients,
      gate_gradients=gate_gradients,
      aggregation_method=aggregation_method,
      colocate_gradients_with_ops=colocate_gradients_with_ops)
  return training.create_train_op(**forwarded)
def gan_train_ops(
    models,
    generator_scope,
    discriminator_scopes,
    losses,
    generator_optimizer,
    discriminator_optimizers,
    check_for_unused_update_ops=True,
    # Optional args to pass directly to the `create_train_op`.
    **kwargs):
  """Returns train ops for one generator trained against several discriminators.

  The highest-level call in TFGAN. It is composed of functions that can also
  be called, should a user require more control over some part of the GAN
  training process.

  Args:
    models: A sequence of GAN models, one per discriminator. The generator
      variables are taken from `models[0]`; `models[i].discriminator_variables`
      are trained by discriminator train op `i`.
    generator_scope: Scope name of the generator. Used as the name scope of
      the generator train op and to select its update ops and summaries.
    discriminator_scopes: Sequence of discriminator scope names, aligned with
      `models`.
    losses: Object exposing `gen_nongan_losses` (the generator training loss)
      and `discriminator_losses` (indexable by discriminator position).
    generator_optimizer: The optimizer for generator updates.
    discriminator_optimizers: Sequence of optimizers for the discriminator
      updates, aligned with `models`.
    check_for_unused_update_ops: If `True`, throws an exception if there are
      update ops outside of the generator or discriminator scopes.
    **kwargs: Keyword args to pass directly to `training.create_train_op` for
      both the generator and discriminator train ops.

  Returns:
    A 3-tuple `(train_ops, gen_summaries, disc_summaries_dict)`:
      train_ops: `GANTrainOps(gen_train_op, disc_train_ops, global_step_inc)`,
        where `disc_train_ops` is a dict mapping the (int) discriminator index
        to its train op.
      gen_summaries: The `SUMMARIES` collection filtered by `generator_scope`.
      disc_summaries_dict: Dict mapping each discriminator scope name to the
        `SUMMARIES` collection filtered by that scope.
  """
  # Create global step increment op.
  global_step = training_util.get_or_create_global_step()
  global_step_inc = global_step.assign_add(1)

  # Get generator and discriminator update ops. We split them so that update
  # ops aren't accidentally run multiple times. For now, throw an error if
  # there are update ops that aren't associated with either the generator or
  # the discriminator. Might modify the `kwargs` dictionary.
  gen_update_ops, dis_update_ops_list = _get_update_ops(
      kwargs, generator_scope, discriminator_scopes,
      check_for_unused_update_ops)

  with tf.name_scope(generator_scope):
    gen_train_op = training.create_train_op(
        total_loss=losses.gen_nongan_losses,
        optimizer=generator_optimizer,
        # The generator variables are the same for every model.
        variables_to_train=models[0].generator_variables,
        global_step=None,
        update_ops=gen_update_ops,
        summarize_gradients=True,
        **kwargs)
  gen_summaries = ops.get_collection(ops.GraphKeys.SUMMARIES, generator_scope)

  disc_train_ops = {}
  disc_summaries_dict = {}
  # Plain `int` indices (not `np.int64`) so the returned dict keys are
  # ordinary Python ints.
  for i, model in enumerate(models):
    disc_scope = discriminator_scopes[i]
    with tf.name_scope(disc_scope):
      disc_train_ops[i] = training.create_train_op(
          total_loss=losses.discriminator_losses[i],
          optimizer=discriminator_optimizers[i],
          variables_to_train=model.discriminator_variables,
          global_step=None,
          update_ops=dis_update_ops_list[i],
          summarize_gradients=True,
          **kwargs)
    disc_summaries_dict[disc_scope] = ops.get_collection(
        ops.GraphKeys.SUMMARIES, disc_scope)

  return (GANTrainOps(gen_train_op, disc_train_ops, global_step_inc),
          gen_summaries, disc_summaries_dict)
def testTrainAllVarsHasLowerLossThanTrainSubsetOfVars(self):
  """Training all variables reaches a lower loss than training a subset.

  Three runs share one log directory: weights only (200 steps), biases only
  (300 steps), then all variables (400 steps). Each run writes a checkpoint
  every step; later runs presumably resume from the previous run's
  checkpoint — TODO confirm against `training.train`.
  """
  logdir = os.path.join(self.get_temp_dir(), 'tmp_logs3/')
  if gfile.Exists(logdir):  # For running on jenkins.
    gfile.DeleteRecursively(logdir)

  def run_training(seed, num_steps, var_name=None):
    """Builds a fresh graph, trains it for `num_steps`, returns the loss.

    When `var_name` is given, only the variables with that name are trained;
    otherwise `training.create_train_op` picks its default variable set.
    """
    with ops.Graph().as_default():
      random_seed.set_random_seed(seed)
      total_loss = self.ModelLoss()
      optimizer = gradient_descent.GradientDescentOptimizer(learning_rate=1.0)
      if var_name is None:
        train_op = training.create_train_op(total_loss, optimizer)
      else:
        subset = variables_lib.get_variables_by_name(var_name)
        train_op = training.create_train_op(
            total_loss, optimizer, variables_to_train=subset)
      saver = saver_lib.Saver()
      return training.train(
          train_op,
          logdir,
          hooks=[
              basic_session_run_hooks.CheckpointSaverHook(
                  logdir, save_steps=1, saver=saver),
              basic_session_run_hooks.StopAtStepHook(num_steps=num_steps),
          ])

  # First, train only the weights of the model.
  loss = run_training(seed=0, num_steps=200, var_name='weights')
  self.assertGreater(loss, .015)
  self.assertLess(loss, .05)

  # Next, train the biases of the model.
  loss = run_training(seed=1, num_steps=300, var_name='biases')
  self.assertGreater(loss, .015)
  self.assertLess(loss, .05)

  # Finally, train both weights and bias to get lower loss.
  loss = run_training(seed=2, num_steps=400)
  self.assertIsNotNone(loss)
  self.assertLess(loss, .015)
def gan_train_ops(
    model,
    loss,
    generator_optimizer,
    discriminator_optimizer,
    check_for_unused_update_ops=True,
    # Optional args to pass directly to the `create_train_op`.
    **kwargs):
  """Returns GAN train ops.

  The highest-level call in TFGAN. It is composed of functions that can also
  be called, should a user require more control over some part of the GAN
  training process.

  A `namedtuples.CycleGANModel` is handled by recursing once per direction
  (x2y, then y2x), each inside its own name scope, and pairing the results.

  Args:
    model: A GANModel, or a `namedtuples.CycleGANModel`.
    loss: A GANLoss (for a CycleGAN, an object exposing `loss_x2y` and
      `loss_y2x`).
    generator_optimizer: The optimizer for generator updates.
    discriminator_optimizer: The optimizer for the discriminator updates.
    check_for_unused_update_ops: If `True`, throws an exception if there are
      update ops outside of the generator or discriminator scopes.
    **kwargs: Keyword args to pass directly to `training.create_train_op` for
      both the generator and discriminator train op.

  Returns:
    A `namedtuples.GANTrainOps` of (generator_train_op,
    discriminator_train_op, global_step_inc_op). For a CycleGAN the first two
    fields are (x2y, y2x) pairs.
  """
  if isinstance(model, namedtuples.CycleGANModel):
    # Every argument except `model` and `loss` is forwarded unchanged to both
    # directional calls.
    forwarded = dict(
        generator_optimizer=generator_optimizer,
        discriminator_optimizer=discriminator_optimizer,
        check_for_unused_update_ops=check_for_unused_update_ops)
    forwarded.update(kwargs)
    with ops.name_scope('cyclegan_x2y_train'):
      x2y_ops = gan_train_ops(model.model_x2y, loss.loss_x2y, **forwarded)
    with ops.name_scope('cyclegan_y2x_train'):
      y2x_ops = gan_train_ops(model.model_y2x, loss.loss_y2x, **forwarded)
    return namedtuples.GANTrainOps(
        (x2y_ops.generator_train_op, y2x_ops.generator_train_op),
        (x2y_ops.discriminator_train_op, y2x_ops.discriminator_train_op),
        training_util.get_or_create_global_step().assign_add(1))

  # Create global step increment op.
  global_step = training_util.get_or_create_global_step()
  global_step_inc = global_step.assign_add(1)

  # Split update ops between generator and discriminator so none is run
  # twice; unassociated update ops are an error. May modify `kwargs`.
  gen_update_ops, dis_update_ops = _get_update_ops(
      kwargs, model.generator_scope.name, model.discriminator_scope.name,
      check_for_unused_update_ops)

  def dummy_step_for(optimizer, var_name):
    """Returns a dummy step variable for a SyncReplicasOptimizer, else None."""
    if not isinstance(optimizer,
                      sync_replicas_optimizer.SyncReplicasOptimizer):
      return None
    # TODO(joelshor): Figure out a way to get this work without including the
    # dummy global step in the checkpoint.
    # WARNING: Making this variable a local variable causes sync replicas to
    # hang forever.
    return variable_scope.get_variable(
        var_name,
        shape=[],
        dtype=global_step.dtype.base_dtype,
        initializer=init_ops.zeros_initializer(),
        trainable=False,
        collections=[ops.GraphKeys.GLOBAL_VARIABLES])

  generator_global_step = dummy_step_for(
      generator_optimizer, 'dummy_global_step_generator')
  if generator_global_step is not None:
    # Copy the real global step into the dummy one on every generator update.
    gen_update_ops += [generator_global_step.assign(global_step)]
  with ops.name_scope('generator_train'):
    gen_train_op = training.create_train_op(
        total_loss=loss.generator_loss,
        optimizer=generator_optimizer,
        variables_to_train=model.generator_variables,
        global_step=generator_global_step,
        update_ops=gen_update_ops,
        **kwargs)

  discriminator_global_step = dummy_step_for(
      discriminator_optimizer, 'dummy_global_step_discriminator')
  if discriminator_global_step is not None:
    dis_update_ops += [discriminator_global_step.assign(global_step)]
  with ops.name_scope('discriminator_train'):
    disc_train_op = training.create_train_op(
        total_loss=loss.discriminator_loss,
        optimizer=discriminator_optimizer,
        variables_to_train=model.discriminator_variables,
        global_step=discriminator_global_step,
        update_ops=dis_update_ops,
        **kwargs)

  return namedtuples.GANTrainOps(gen_train_op, disc_train_op, global_step_inc)
def ctc_estimator(tokens, token_lengths, logits, glogits, sequence_mask,
                  sequence_length_ctc, vocab, run_config, params, mode,
                  model_scope, training_hooks=None):
    """Builds an `EstimatorSpec` for a model trained with a CTC loss.

    Args:
        tokens: Token ids. They are shifted by +1 before being used as CTC
            labels (presumably to reserve id 0 for the CTC blank — TODO
            confirm against `ctc_loss_dense`).
        token_lengths: Per-example label lengths for the CTC loss; also
            averaged into the `token_lengths_eval` metric.
        logits: Logits fed to the CTC loss and decoded by the "Autoencoded"
            evaluation hook.
        glogits: Logits decoded by the "Generated" evaluation hook, or None
            to skip that hook.
        sequence_mask: Mask passed to `sparsify` with the shifted tokens to
            build the sparse reference labels.
        sequence_length_ctc: Per-example logit lengths.
        vocab: Vocabulary handed to the CTC decoding hooks.
        run_config: Run configuration; hook CSVs are written under
            `run_config.model_dir`.
        params: Hyper-parameters; `params.lr` is the Adam learning rate and
            the object is passed to `make_transform_grads_fn`.
        mode: Estimator mode, forwarded to the `EstimatorSpec`.
        model_scope: Scope used to collect losses, update ops and trainable
            variables.
        training_hooks: Optional list of training hooks. Defaults to a new
            empty list on every call.

    Returns:
        An `EstimatorSpec` carrying the total loss, train op, evaluation
        metrics and hooks.
    """
    # `None` default instead of `[]` so calls never share one mutable list.
    if training_hooks is None:
        training_hooks = []
    with tf.name_scope(model_scope + "/"):
        tok_1 = tokens + 1
        ctc_labels_sparse = sparsify(tf.cast(tok_1, tf.int32), sequence_mask)
        # Dense reference labels padded with -1, consumed by the CTC hooks.
        ctc_labels = tf.sparse_tensor_to_dense(ctc_labels_sparse, default_value=-1)
        print("Labels: {}".format(ctc_labels))
        print("logits: {}".format(logits))
        print("glogits: {}".format(glogits))
        print("CTC: {}, {}, {}".format(ctc_labels, logits, sequence_length_ctc))

        def _ctc_loss():
            return ctc_loss_dense(labels=tok_1,
                                  label_length=token_lengths,
                                  logits=logits,
                                  logit_length=sequence_length_ctc)

        if tf.flags.FLAGS.gpu_ctc:
            ctc_loss_raw = _ctc_loss()
        else:
            # Pin the loss computation to the CPU when GPU CTC is disabled.
            with tf.device("/cpu:0"):
                ctc_loss_raw = _ctc_loss()
        ctc_loss = tf.reduce_mean(ctc_loss_raw, name='ctc_loss')
        tf.losses.add_loss(ctc_loss)

        losses = tf.losses.get_losses(scope=model_scope)
        print("Estimator losses: {}".format(losses))
        losses += tf.losses.get_regularization_losses(scope=model_scope)
        total_loss = tf.add_n(losses)
        updates = tf.get_collection(key=tf.GraphKeys.UPDATE_OPS, scope=model_scope)

        evaluation_hooks = []
        if logits is not None:
            autoencode_hook = CTCHook(
                logits=logits,
                lengths=sequence_length_ctc,
                vocab=vocab,
                path=os.path.join(run_config.model_dir, "autoencoded",
                                  "autoencoded-{:08d}.csv"),
                true=ctc_labels,
                name="Autoencoded",
                merge_repeated=True)
            evaluation_hooks.append(autoencode_hook)
        if glogits is not None:
            generate_hook = CTCHook(
                logits=glogits,
                lengths=sequence_length_ctc,
                vocab=vocab,
                path=os.path.join(run_config.model_dir, "generated",
                                  "generated-{:08d}.csv"),
                true=ctc_labels,
                name="Generated",
                merge_repeated=True)
            evaluation_hooks.append(generate_hook)

        tf.summary.scalar('ctc_loss', ctc_loss)
        tf.summary.scalar('total_loss', total_loss)

        # Train
        optimizer = tf.train.AdamOptimizer(params.lr)
        variables = tf.trainable_variables(scope=model_scope)
        transform_grads_fn = make_transform_grads_fn(params=params)
        train_op = create_train_op(total_loss=total_loss,
                                   optimizer=optimizer,
                                   update_ops=updates,
                                   variables_to_train=variables,
                                   transform_grads_fn=transform_grads_fn,
                                   summarize_gradients=False,
                                   aggregation_method=None,
                                   check_numerics=True)
        eval_metric_ops = {
            'ctc_loss_eval': tf.metrics.mean(ctc_loss_raw),
            'token_lengths_eval': tf.metrics.mean(token_lengths)
        }
        return EstimatorSpec(mode=mode,
                             loss=total_loss,
                             eval_metric_ops=eval_metric_ops,
                             evaluation_hooks=evaluation_hooks,
                             training_hooks=training_hooks,
                             train_op=train_op)
def gan_train_ops(
    model,
    loss,
    generator_optimizer,
    discriminator_optimizer,
    check_for_unused_update_ops=True,
    # Optional args to pass directly to the `create_train_op`.
    **kwargs):
  """Returns GAN train ops.

  The highest-level call in TFGAN. It is composed of functions that can also
  be called, should a user require more control over some part of the GAN
  training process.

  A `namedtuples.CycleGANModel` is handled by recursing once per direction
  (x2y, then y2x), each inside its own name scope, and pairing the results.

  Args:
    model: A GANModel, or a `namedtuples.CycleGANModel`.
    loss: A GANLoss (for a CycleGAN, an object exposing `loss_x2y` and
      `loss_y2x`).
    generator_optimizer: The optimizer for generator updates.
    discriminator_optimizer: The optimizer for the discriminator updates.
    check_for_unused_update_ops: If `True`, throws an exception if there are
      update ops outside of the generator or discriminator scopes.
    **kwargs: Keyword args to pass directly to
      `training.create_train_op` for both the generator and
      discriminator train op.

  Returns:
    A `namedtuples.GANTrainOps` of (generator_train_op,
    discriminator_train_op, global_step_inc_op) that can be used to train a
    generator/discriminator pair. For a CycleGAN the first two fields are
    (x2y, y2x) pairs.
  """
  if isinstance(model, namedtuples.CycleGANModel):
    # Forward every argument except `model` and `loss`. Built explicitly
    # rather than by mutating `locals()`: before Python 3.13 that returns the
    # frame's live locals dict, which a tracer/debugger can re-sync (restoring
    # `model`, `loss`, ... and breaking the recursive call below).
    saved_params = dict(
        generator_optimizer=generator_optimizer,
        discriminator_optimizer=discriminator_optimizer,
        check_for_unused_update_ops=check_for_unused_update_ops)
    saved_params.update(kwargs)
    with ops.name_scope('cyclegan_x2y_train'):
      train_ops_x2y = gan_train_ops(model.model_x2y, loss.loss_x2y,
                                    **saved_params)
    with ops.name_scope('cyclegan_y2x_train'):
      train_ops_y2x = gan_train_ops(model.model_y2x, loss.loss_y2x,
                                    **saved_params)
    return namedtuples.GANTrainOps(
        (train_ops_x2y.generator_train_op, train_ops_y2x.generator_train_op),
        (train_ops_x2y.discriminator_train_op,
         train_ops_y2x.discriminator_train_op),
        training_util.get_or_create_global_step().assign_add(1))

  # Create global step increment op.
  global_step = training_util.get_or_create_global_step()
  global_step_inc = global_step.assign_add(1)

  # Get generator and discriminator update ops. We split them so that update
  # ops aren't accidentally run multiple times. For now, throw an error if
  # there are update ops that aren't associated with either the generator or
  # the discriminator. Might modify the `kwargs` dictionary.
  gen_update_ops, dis_update_ops = _get_update_ops(
      kwargs, model.generator_scope.name, model.discriminator_scope.name,
      check_for_unused_update_ops)

  generator_global_step = None
  if isinstance(generator_optimizer,
                sync_replicas_optimizer.SyncReplicasOptimizer):
    # TODO(joelshor): Figure out a way to get this work without including the
    # dummy global step in the checkpoint.
    # WARNING: Making this variable a local variable causes sync replicas to
    # hang forever.
    generator_global_step = variable_scope.get_variable(
        'dummy_global_step_generator',
        shape=[],
        dtype=global_step.dtype.base_dtype,
        initializer=init_ops.zeros_initializer(),
        trainable=False,
        collections=[ops.GraphKeys.GLOBAL_VARIABLES])
    # Copy the real global step into the dummy one on every generator update.
    gen_update_ops += [generator_global_step.assign(global_step)]
  with ops.name_scope('generator_train'):
    gen_train_op = training.create_train_op(
        total_loss=loss.generator_loss,
        optimizer=generator_optimizer,
        variables_to_train=model.generator_variables,
        global_step=generator_global_step,
        update_ops=gen_update_ops,
        **kwargs)

  discriminator_global_step = None
  if isinstance(discriminator_optimizer,
                sync_replicas_optimizer.SyncReplicasOptimizer):
    # See comment above `generator_global_step`.
    discriminator_global_step = variable_scope.get_variable(
        'dummy_global_step_discriminator',
        shape=[],
        dtype=global_step.dtype.base_dtype,
        initializer=init_ops.zeros_initializer(),
        trainable=False,
        collections=[ops.GraphKeys.GLOBAL_VARIABLES])
    dis_update_ops += [discriminator_global_step.assign(global_step)]
  with ops.name_scope('discriminator_train'):
    disc_train_op = training.create_train_op(
        total_loss=loss.discriminator_loss,
        optimizer=discriminator_optimizer,
        variables_to_train=model.discriminator_variables,
        global_step=discriminator_global_step,
        update_ops=dis_update_ops,
        **kwargs)

  return namedtuples.GANTrainOps(gen_train_op, disc_train_op, global_step_inc)