Example #1
def run_benchmark(sess, init_op, add_op):
  """Returns MB/s rate of addition."""


  logdir=FLAGS.logdir_prefix+'/'+FLAGS.name
  os.system('mkdir -p '+logdir)
  
  # TODO: make events follow same format as eager writer
  writer = pywrap_tensorflow.EventsWriter(compat.as_bytes(logdir+'/events'))
  filename = compat.as_text(writer.FileName())
  training_util.get_or_create_global_step()

  sess.run(init_op)

  for step in range(FLAGS.iters):
    start_time = time.time()
    for i in range(FLAGS.iters_per_step):
      sess.run(add_op.op)

    elapsed_time = time.time() - start_time
    # MB/s for this step: iters_per_step additions of data_mb MB each.
    rate = float(FLAGS.iters_per_step) * FLAGS.data_mb / elapsed_time
    event = make_event('rate', rate, step)
    writer.WriteEvent(event)
    writer.Flush()
  writer.Close()
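The snippet above calls a `make_event` helper that is not shown here; a minimal sketch of it, consistent with the version defined in Example #35 further down, would be:

import time

from tensorflow.core.framework import summary_pb2
from tensorflow.core.util import event_pb2


def make_event(tag, value, step):
  """Wraps a scalar value into an Event proto for EventsWriter.WriteEvent."""
  return event_pb2.Event(
      wall_time=time.time(),
      step=step,
      summary=summary_pb2.Summary(
          value=[summary_pb2.Summary.Value(tag=tag, simple_value=value)]))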
 def _test_logits_helper(self, mode):
   """Tests that the expected logits are passed to mock head."""
   with ops.Graph().as_default():
     training_util.get_or_create_global_step()
     generator_inputs = {'x': array_ops.zeros([5, 4])}
     real_data = (None if mode == model_fn_lib.ModeKeys.PREDICT else
                  array_ops.zeros([5, 4]))
     generator_scope_name = 'generator'
     head = mock_head(self,
                      expected_generator_inputs=generator_inputs,
                      expected_real_data=real_data,
                      generator_scope_name=generator_scope_name)
     estimator_spec = estimator._gan_model_fn(
         features=generator_inputs,
         labels=real_data,
         mode=mode,
         generator_fn=generator_fn,
         discriminator_fn=discriminator_fn,
         generator_scope_name=generator_scope_name,
         head=head)
     with monitored_session.MonitoredTrainingSession(
         checkpoint_dir=self._model_dir) as sess:
       if mode == model_fn_lib.ModeKeys.TRAIN:
         sess.run(estimator_spec.train_op)
       elif mode == model_fn_lib.ModeKeys.EVAL:
         sess.run(estimator_spec.loss)
       elif mode == model_fn_lib.ModeKeys.PREDICT:
         sess.run(estimator_spec.predictions)
       else:
         self.fail('Invalid mode: {}'.format(mode))
 def testGraphSummary(self):
   training_util.get_or_create_global_step()
   name = 'hi'
   graph = graph_pb2.GraphDef(node=(node_def_pb2.NodeDef(name=name),))
   with self.test_session():
     with self.create_db_writer().as_default():
       summary_ops.initialize(graph=graph)
   six.assertCountEqual(self, [name],
                        get_all(self.db, 'SELECT node_name FROM Nodes'))
Example #4
 def testEagerMemory(self):
   training_util.get_or_create_global_step()
   logdir = self.get_temp_dir()
   with summary_ops.create_file_writer(
       logdir, max_queue=0,
       name='t0').as_default(), summary_ops.always_record_summaries():
     summary_ops.generic('tensor', 1, '')
     summary_ops.scalar('scalar', 2.0)
     summary_ops.histogram('histogram', [1.0])
     summary_ops.image('image', [[[[1.0]]]])
     summary_ops.audio('audio', [[1.0]], 1.0, 1)
Example #5
  def testWriteSummaries(self):
    e = SimpleEvaluator(IdentityModel())
    e(3.0)
    e([5.0, 7.0, 9.0])
    training_util.get_or_create_global_step()
    logdir = tempfile.mkdtemp()

    e.all_metric_results(logdir)

    events = summary_test_util.events_from_file(logdir)
    self.assertEqual(len(events), 2)
    self.assertEqual(events[1].summary.value[0].simple_value, 6.0)
Example #6
  def testSummaryName(self):
    training_util.get_or_create_global_step()
    logdir = tempfile.mkdtemp()
    with summary_ops.create_file_writer(
        logdir, max_queue=0,
        name='t2').as_default(), summary_ops.always_record_summaries():

      summary_ops.scalar('scalar', 2.0)

      events = summary_test_util.events_from_logdir(logdir)
      self.assertEqual(len(events), 2)
      self.assertEqual(events[1].summary.value[0].tag, 'scalar')
 def testSummaryOps(self):
   training_util.get_or_create_global_step()
   logdir = tempfile.mkdtemp()
   summary_ops.create_summary_file_writer(logdir, max_queue=0, name='t0')
   summary_ops.always_record_summaries()
   summary_ops.generic('tensor', 1, '')
   summary_ops.scalar('scalar', 2.0)
   summary_ops.histogram('histogram', [1.0])
   summary_ops.image('image', [[[[1.0]]]])
   summary_ops.audio('audio', [[1.0]], 1.0, 1)
   # The working condition of the ops is tested in the C++ test so we just
   # test here that we're calling them correctly.
   self.assertTrue(gfile.Exists(logdir))
  def testWriteSummaries(self):
    m = metrics.Mean()
    m([1, 10, 100])
    training_util.get_or_create_global_step()
    logdir = tempfile.mkdtemp()
    with summary_ops.create_file_writer(
        logdir, max_queue=0,
        name="t0").as_default(), summary_ops.always_record_summaries():
      m.result()  # As a side-effect will write summaries.

    events = summary_test_util.events_from_logdir(logdir)
    self.assertEqual(len(events), 2)
    self.assertEqual(events[1].summary.value[0].simple_value, 37.0)
 def testSummaryGlobalStep(self):
   training_util.get_or_create_global_step()
   logdir = self.get_temp_dir()
   writer = summary_ops.create_file_writer(logdir, max_queue=0)
   with writer.as_default(), summary_ops.always_record_summaries():
     summary_ops.scalar('scalar', 2.0)
   with self.cached_session() as sess:
     sess.run(variables.global_variables_initializer())
     sess.run(summary_ops.summary_writer_initializer_op())
     step, _ = sess.run(
         [training_util.get_global_step(), summary_ops.all_summary_ops()])
   events = summary_test_util.events_from_logdir(logdir)
   self.assertEqual(2, len(events))
   self.assertEqual(step, events[1].step)
Example #10
  def testWriteSummariesGraph(self):
    with context.graph_mode(), ops.Graph().as_default(), self.test_session():
      e = SimpleEvaluator(IdentityModel())
      ds = dataset_ops.Dataset.from_tensor_slices([3.0, 5.0, 7.0, 9.0])
      training_util.get_or_create_global_step()
      logdir = tempfile.mkdtemp()
      init_op, call_op, results_op = e.evaluate_on_dataset(
          ds, summary_logdir=logdir)
      variables.global_variables_initializer().run()
      e.run_evaluation(init_op, call_op, results_op)

    events = summary_test_util.events_from_file(logdir)
    self.assertEqual(len(events), 2)
    self.assertEqual(events[1].summary.value[0].simple_value, 6.0)
  def testDefunSummarys(self):
    training_util.get_or_create_global_step()
    logdir = tempfile.mkdtemp()
    with summary_ops.create_summary_file_writer(
        logdir, max_queue=0,
        name='t1').as_default(), summary_ops.always_record_summaries():

      @function.defun
      def write():
        summary_ops.scalar('scalar', 2.0)

      write()
      events = summary_test_util.events_from_logdir(logdir)
      self.assertEqual(len(events), 2)
      self.assertEqual(events[1].summary.value[0].simple_value, 2.0)
Example #12
  def test_inv_update_thunks(self):
    """Ensures inverse update ops run once per global_step."""
    with self._graph.as_default(), self.test_session() as sess:
      fisher_estimator = estimator.FisherEstimator(
          damping_fn=lambda: 0.2,
          variables=[self.weights],
          layer_collection=self.layer_collection,
          cov_ema_decay=0.0)

      # Construct op that updates one inverse per global step.
      global_step = training_util.get_or_create_global_step()
      inv_matrices = [
          matrix
          for fisher_factor in self.layer_collection.get_factors()
          for matrix in fisher_factor._inverses_by_damping.values()
      ]
      inv_update_op_thunks = fisher_estimator.inv_update_thunks
      inv_update_op = control_flow_ops.case(
          [(math_ops.equal(global_step, i), thunk)
           for i, thunk in enumerate(inv_update_op_thunks)])
      increment_global_step = global_step.assign_add(1)

      sess.run(variables.global_variables_initializer())
      initial_inv_values = sess.run(inv_matrices)

      # Ensure there's one update per inverse matrix. This is true as long as
      # there's no fan-in/fan-out or parameter re-use.
      self.assertEqual(len(inv_matrices), len(inv_update_op_thunks))

      # Test is a no-op if there is only 1 inverse matrix.
      assert len(inv_matrices) > 1

      # Assign each covariance matrix a value other than the identity. This
      # ensures that the inverse matrices are updated to something different as
      # well.
      cov_matrices = [
          fisher_factor.get_cov()
          for fisher_factor in self.layer_collection.get_factors()
      ]
      sess.run([
          cov_matrix.assign(2 * linalg_ops.eye(int(cov_matrix.shape[0])))
          for cov_matrix in cov_matrices
      ])

      for i in range(len(inv_matrices)):
        # Compare new and old inverse values
        new_inv_values = sess.run(inv_matrices)
        is_inv_equal = [
            np.allclose(initial_inv_value, new_inv_value)
            for (initial_inv_value,
                 new_inv_value) in zip(initial_inv_values, new_inv_values)
        ]
        num_inv_equal = sum(is_inv_equal)

        # Ensure exactly one inverse matrix changes per step.
        self.assertEqual(num_inv_equal, len(inv_matrices) - i)

        # Run all inverse update ops.
        sess.run(inv_update_op)
        sess.run(increment_global_step)
 def testAgnosticUsage(self):
   """Graph/eager agnostic usage."""
   # Does create garbage when executing eagerly due to ops.Graph() creation.
   num_training_steps = 10
   checkpoint_directory = self.get_temp_dir()
   checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
   for training_continuation in range(3):
     with ops.Graph().as_default(), self.test_session(
         graph=ops.get_default_graph()), test_util.device(use_gpu=True):
       model = MyModel()
       optimizer = adam.AdamOptimizer(0.001)
       root = util.Checkpoint(
           optimizer=optimizer, model=model,
           global_step=training_util.get_or_create_global_step())
       checkpoint_path = checkpoint_management.latest_checkpoint(
           checkpoint_directory)
       status = root.restore(save_path=checkpoint_path)
       input_value = constant_op.constant([[3.]])
       train_fn = functools.partial(
           optimizer.minimize,
           functools.partial(model, input_value),
           global_step=root.global_step)
       if not context.executing_eagerly():
         train_fn = functools.partial(self.evaluate, train_fn())
       status.initialize_or_restore()
       for _ in range(num_training_steps):
         train_fn()
       root.save(file_prefix=checkpoint_prefix)
       self.assertEqual((training_continuation + 1) * num_training_steps,
                        self.evaluate(root.global_step))
       self.assertEqual(training_continuation + 1,
                        self.evaluate(root.save_counter))
  def testEagerTPUDistributionStrategy(self):
    self.skipTest("b/121387144")
    num_training_steps = 10
    checkpoint_directory = self.get_temp_dir()
    checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")

    def _train_fn(optimizer, model):
      input_value = constant_op.constant([[3.]])
      optimizer.minimize(
          functools.partial(model, input_value),
          global_step=root.optimizer_step)

    for training_continuation in range(3):
      strategy = tpu_strategy.TPUStrategy()
      with strategy.scope():
        model = Subclassed()
        optimizer = adam_v1.AdamOptimizer(0.001)
        root = checkpointable_utils.Checkpoint(
            optimizer=optimizer, model=model,
            optimizer_step=training_util.get_or_create_global_step())
        root.restore(checkpoint_management.latest_checkpoint(
            checkpoint_directory))

        for _ in range(num_training_steps):
          strategy.extended.call_for_each_replica(
              functools.partial(_train_fn, optimizer, model))
        root.save(file_prefix=checkpoint_prefix)
        self.assertEqual((training_continuation + 1) * num_training_steps,
                         root.optimizer_step.numpy())
Example #15
 def setUp(self):
   self.model_dir = tempfile.mkdtemp()
   self.graph = ops.Graph()
   with self.graph.as_default():
     self.scaffold = monitored_session.Scaffold()
     self.global_step = training_util.get_or_create_global_step()
     self.train_op = state_ops.assign_add(self.global_step, 1)
 def testUsageGraph(self):
   """Expected usage when graph building."""
   with context.graph_mode():
     num_training_steps = 10
     checkpoint_directory = self.get_temp_dir()
     checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
     for training_continuation in range(3):
       with ops.Graph().as_default():
         model = MyModel()
         optimizer = adam.AdamOptimizer(0.001)
         root = util.Checkpoint(
             optimizer=optimizer, model=model,
             global_step=training_util.get_or_create_global_step())
         input_value = constant_op.constant([[3.]])
         train_op = optimizer.minimize(
             model(input_value),
             global_step=root.global_step)
         checkpoint_path = checkpoint_management.latest_checkpoint(
             checkpoint_directory)
         with self.session(graph=ops.get_default_graph()) as session:
           status = root.restore(save_path=checkpoint_path)
           status.initialize_or_restore(session=session)
           if checkpoint_path is None:
             self.assertEqual(0, training_continuation)
             with self.assertRaises(AssertionError):
               status.assert_consumed()
           else:
             status.assert_consumed()
           for _ in range(num_training_steps):
             session.run(train_op)
           root.save(file_prefix=checkpoint_prefix, session=session)
           self.assertEqual((training_continuation + 1) * num_training_steps,
                            session.run(root.global_step))
           self.assertEqual(training_continuation + 1,
                            session.run(root.save_counter))
  def testScalarSummary(self):
    """Test record_summaries_every_n_global_steps and all_summaries()."""
    with ops.Graph().as_default(), self.test_session() as sess:
      global_step = training_util.get_or_create_global_step()
      global_step.initializer.run()
      with ops.device('/cpu:0'):
        step_increment = state_ops.assign_add(global_step, 1)
      sess.run(step_increment)  # Increment global step from 0 to 1

      logdir = tempfile.mkdtemp()
      with summary_ops.create_file_writer(logdir, max_queue=0,
                                          name='t2').as_default():
        with summary_ops.record_summaries_every_n_global_steps(2):
          summary_ops.initialize()
          summary_op = summary_ops.scalar('my_scalar', 2.0)

          # Neither of these should produce a summary because
          # global_step is 1 and "1 % 2 != 0"
          sess.run(summary_ops.all_summary_ops())
          sess.run(summary_op)
          events = summary_test_util.events_from_logdir(logdir)
          self.assertEqual(len(events), 1)

          # Increment global step from 1 to 2 and check that the summary
          # is now written
          sess.run(step_increment)
          sess.run(summary_ops.all_summary_ops())
          events = summary_test_util.events_from_logdir(logdir)
          self.assertEqual(len(events), 2)
          self.assertEqual(events[1].summary.value[0].tag, 'my_scalar')
 def testAgnosticUsage(self):
   """Graph/eager agnostic usage."""
   # Does create garbage when executing eagerly due to ops.Graph() creation.
   num_training_steps = 10
   checkpoint_directory = self.get_temp_dir()
   checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
   for training_continuation in range(3):
     with ops.Graph().as_default(), self.test_session(
         graph=ops.get_default_graph()):
       network = MyNetwork()
       optimizer = CheckpointableAdam(0.001)
       root = Checkpoint(
           optimizer=optimizer, network=network,
           global_step=training_util.get_or_create_global_step())
       checkpoint_path = core_saver.latest_checkpoint(checkpoint_directory)
       status = root.restore(save_path=checkpoint_path)
       input_value = constant_op.constant([[3.]])
       train_fn = functools.partial(
           optimizer.minimize,
           functools.partial(network, input_value),
           global_step=root.global_step)
       if context.in_graph_mode():
         train_fn = functools.partial(self.evaluate, train_fn())
       status.initialize_or_restore()
       for _ in range(num_training_steps):
         train_fn()
       root.save(file_prefix=checkpoint_prefix)
       self.assertEqual((training_continuation + 1) * num_training_steps,
                        self.evaluate(root.global_step))
       self.assertEqual(training_continuation + 1,
                        self.evaluate(root.save_counter))
  def _test_summary_for_replica_zero_only(self, d):
    logdir = tempfile.mkdtemp()

    def run_fn():
      """Function executed for each replica."""
      with summary_writer.as_default():
        replica_id = ds_context.get_replica_context().replica_id_in_sync_group
        return summary_ops.write("a", replica_id)

    with self.cached_session() as sess, d.scope(), \
        summary_ops.always_record_summaries():
      # We need global_step because summary writing op *always* has global_step
      # as input, even when we always record summary or never record summary.
      global_step = training_util.get_or_create_global_step()
      if not context.executing_eagerly():
        # When executing eagerly, variables are initialized immediately after
        # creation, and their initializers will be None.
        global_step.initializer.run()
      summary_ops.set_step(0)
      summary_writer = summary_ops.create_file_writer(logdir)
      output = d.extended.call_for_each_replica(run_fn)
      unwrapped = d.unwrap(output)
      if not context.executing_eagerly():
        sess.run(summary_writer.init())
        sess.run(unwrapped)
        sess.run(summary_writer.close())

      events = _events_from_logdir(self, logdir)
      # There will be 2 entries: 1 summary file header entry, and 1 entry
      # written by replica 0.
      self.assertLen(events, 2)
      self.assertEqual(events[1].summary.value[0].tag, "a")
      self.assertEqual(events[1].summary.value[0].simple_value, 0.0)
Example #20
def _clone_and_build_model(mode,
                           keras_model,
                           custom_objects,
                           features=None,
                           labels=None):
  """Clone and build the given keras_model.

  Args:
    mode: training mode.
    keras_model: an instance of compiled keras model.
    custom_objects: Dictionary for custom objects.
    features: Dict of tensors.
    labels: Dict of tensors, or single tensor instance.

  Returns:
    The newly built model.
  """
  # Set to True during training, False for inference.
  K.set_learning_phase(mode == model_fn_lib.ModeKeys.TRAIN)

  # Clone keras model.
  input_tensors = None if features is None else _create_ordered_io(
      keras_model, features)
  if custom_objects:
    with CustomObjectScope(custom_objects):
      model = models.clone_model(keras_model, input_tensors=input_tensors)
  else:
    model = models.clone_model(keras_model, input_tensors=input_tensors)

  # Compile/Build model
  if mode is model_fn_lib.ModeKeys.PREDICT and not model.built:
    model.build()
  else:
    optimizer_config = keras_model.optimizer.get_config()
    optimizer = keras_model.optimizer.__class__.from_config(optimizer_config)
    optimizer.iterations = training_util.get_or_create_global_step()

    # Get list of outputs.
    if labels is None:
      target_tensors = None
    elif isinstance(labels, dict):
      target_tensors = _create_ordered_io(keras_model, labels, is_input=False)
    else:
      target_tensors = [
          _cast_tensor_to_floatx(
              sparse_tensor_lib.convert_to_tensor_or_sparse_tensor(labels))
      ]

    model.compile(
        optimizer,
        keras_model.loss,
        metrics=keras_model.metrics,
        loss_weights=keras_model.loss_weights,
        sample_weight_mode=keras_model.sample_weight_mode,
        weighted_metrics=keras_model.weighted_metrics,
        target_tensors=target_tensors)

  if isinstance(model, models.Sequential):
    model = model.model
  return model
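A hedged call-site sketch for the helper above: the `compiled_keras_model`, `features`, and `labels` names below are placeholders, and the call mirrors how a Keras-to-Estimator model_fn would rebuild the model for a given mode.

# Hypothetical call site (placeholder names); rebuilds the model for training.
train_model = _clone_and_build_model(
    mode=model_fn_lib.ModeKeys.TRAIN,
    keras_model=compiled_keras_model,
    custom_objects=None,
    features=features,
    labels=labels)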
  def testGraphDistributionStrategy(self):
    self.skipTest("b/121381184")
    num_training_steps = 10
    checkpoint_directory = self.get_temp_dir()
    checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")

    def _train_fn(optimizer, model):
      input_value = constant_op.constant([[3.]])
      return optimizer.minimize(
          functools.partial(model, input_value),
          global_step=root.optimizer_step)

    for training_continuation in range(3):
      with ops.Graph().as_default():
        strategy = mirrored_strategy.MirroredStrategy()
        with strategy.scope():
          model = MyModel()
          optimizer = adam.AdamOptimizer(0.001)
          root = checkpointable_utils.Checkpoint(
              optimizer=optimizer, model=model,
              optimizer_step=training_util.get_or_create_global_step())
          status = root.restore(checkpoint_management.latest_checkpoint(
              checkpoint_directory))
          train_op = strategy.extended.call_for_each_replica(
              functools.partial(_train_fn, optimizer, model))
          with self.session() as session:
            if training_continuation > 0:
              status.assert_consumed()
            status.initialize_or_restore()
            for _ in range(num_training_steps):
              session.run(train_op)
            root.save(file_prefix=checkpoint_prefix)
        self.assertEqual((training_continuation + 1) * num_training_steps,
                         root.optimizer_step.numpy())
 def testAgnosticUsage(self):
   """Graph/eager agnostic usage."""
   # Does create garbage when executing eagerly due to ops.Graph() creation.
   num_training_steps = 10
   checkpoint_directory = self.get_temp_dir()
   for training_continuation in range(3):
     with test_util.device(use_gpu=True):
       model = MyModel()
       optimizer = adam.AdamOptimizer(0.001)
       root = checkpointable_utils.Checkpoint(
           optimizer=optimizer, model=model,
           global_step=training_util.get_or_create_global_step())
       manager = checkpoint_management.CheckpointManager(
           root, checkpoint_directory, max_to_keep=1)
       status = root.restore(save_path=manager.latest_checkpoint)
       input_value = constant_op.constant([[3.]])
       train_fn = functools.partial(
           optimizer.minimize,
           functools.partial(model, input_value),
           global_step=root.global_step)
       if not context.executing_eagerly():
         train_fn = functools.partial(self.evaluate, train_fn())
       status.initialize_or_restore()
       for _ in range(num_training_steps):
         train_fn()
       manager.save()
       self.assertEqual((training_continuation + 1) * num_training_steps,
                        self.evaluate(root.global_step))
       self.assertEqual(training_continuation + 1,
                        self.evaluate(root.save_counter))
Example #23
  def testSummaryName(self):
    training_util.get_or_create_global_step()
    logdir = tempfile.mkdtemp()
    summary_ops.create_summary_file_writer(logdir, max_queue=0, name='t2')
    summary_ops.always_record_summaries()

    summary_ops.scalar('scalar', 2.0)

    self.assertTrue(gfile.Exists(logdir))
    files = gfile.ListDirectory(logdir)
    self.assertEqual(len(files), 1)
    records = list(tf_record.tf_record_iterator(os.path.join(logdir, files[0])))
    self.assertEqual(len(records), 2)
    event = event_pb2.Event()
    event.ParseFromString(records[1])
    self.assertEqual(event.summary.value[0].tag, 'scalar')
 def model_fn():
   """Mnist model with synthetic input."""
   data_format = 'channels_last'
   input_shape = [28, 28, 1]
   l = keras.layers
   max_pool = l.MaxPooling2D((2, 2), (2, 2),
                             padding='same',
                             data_format=data_format)
   model = keras.Sequential([
       l.Reshape(target_shape=input_shape, input_shape=(28 * 28,)),
       l.Conv2D(
           32,
           5,
           padding='same',
           data_format=data_format,
           activation=nn.relu), max_pool,
       l.Conv2D(
           64,
           5,
           padding='same',
           data_format=data_format,
           activation=nn.relu), max_pool,
       l.Flatten(),
       l.Dense(1024, activation=nn.relu),
       l.Dropout(0.4),
       l.Dense(10)
   ])
   image = random_ops.random_uniform([2, 28, 28])
   label = random_ops.random_uniform([2, 1], maxval=10, dtype=dtypes.int32)
   logits = model(image, training=True)
   loss = losses.sparse_softmax_cross_entropy(labels=label, logits=logits)
   optimizer = adam.AdamOptimizer(learning_rate=1e-4)
   train_op = optimizer.minimize(loss,
                                 training_util.get_or_create_global_step())
   return train_op
 def setUp(self):
   self.model_dir = tempfile.mkdtemp()
   self.graph = ops.Graph()
   with self.graph.as_default():
     self.scaffold = monitored_session.Scaffold()
     with variable_scope.variable_scope('foo', use_resource=True):
       self.global_step = training_util.get_or_create_global_step()
     self.train_op = training_util._increment_global_step(1)
Example #26
 def testSaveRestoreDefaultGlobalStep(self):
   net = MyNetwork(name="abcd")
   net(constant_op.constant([[2.0]]))
   self.evaluate(net.variables[0].assign([[3.]]))
   default_global_step = training_util.get_or_create_global_step()
   self.evaluate(default_global_step.assign(4242))
   save_path = network.save_network_checkpoint(net, self.get_temp_dir())
   self.assertIn("abcd-4242", save_path)
Example #27
  def test_sync_replicas(self, create_gan_model_fn, create_global_step):
    model = create_gan_model_fn()
    loss = train.gan_loss(model)
    num_trainable_vars = len(variables_lib.get_trainable_variables())

    if create_global_step:
      gstep = variable_scope.get_variable(
          'custom_gstep', dtype=dtypes.int32, initializer=0, trainable=False)
      ops.add_to_collection(ops.GraphKeys.GLOBAL_STEP, gstep)

    g_opt = get_sync_optimizer()
    d_opt = get_sync_optimizer()
    train_ops = train.gan_train_ops(
        model, loss, generator_optimizer=g_opt, discriminator_optimizer=d_opt)
    self.assertIsInstance(train_ops, namedtuples.GANTrainOps)
    # No new trainable variables should have been added.
    self.assertLen(variables_lib.get_trainable_variables(), num_trainable_vars)

    # Sync hooks should be populated in the GANTrainOps.
    self.assertLen(train_ops.train_hooks, 2)
    for hook in train_ops.train_hooks:
      self.assertIsInstance(
          hook, sync_replicas_optimizer._SyncReplicasOptimizerHook)
    sync_opts = [hook._sync_optimizer for hook in train_ops.train_hooks]
    self.assertSetEqual(frozenset(sync_opts), frozenset((g_opt, d_opt)))

    g_sync_init_op = g_opt.get_init_tokens_op(num_tokens=1)
    d_sync_init_op = d_opt.get_init_tokens_op(num_tokens=1)

    # Check that update op is run properly.
    global_step = training_util.get_or_create_global_step()
    with self.test_session(use_gpu=True) as sess:
      variables.global_variables_initializer().run()
      variables.local_variables_initializer().run()

      g_opt.chief_init_op.run()
      d_opt.chief_init_op.run()

      gstep_before = global_step.eval()

      # Start required queue runner for SyncReplicasOptimizer.
      coord = coordinator.Coordinator()
      g_threads = g_opt.get_chief_queue_runner().create_threads(sess, coord)
      d_threads = d_opt.get_chief_queue_runner().create_threads(sess, coord)

      g_sync_init_op.run()
      d_sync_init_op.run()

      train_ops.generator_train_op.eval()
      # Check that global step wasn't incremented.
      self.assertEqual(gstep_before, global_step.eval())

      train_ops.discriminator_train_op.eval()
      # Check that global step wasn't incremented.
      self.assertEqual(gstep_before, global_step.eval())

      coord.request_stop()
      coord.join(g_threads + d_threads)
Example #28
 def setUp(self):
   super(PruningHParamsTest, self).setUp()
   # Add global step variable to the graph
   self.global_step = training_util.get_or_create_global_step()
   # Add sparsity
   self.sparsity = variables.Variable(0.5, name="sparsity")
   # Parse hparams
   self.pruning_hparams = pruning.get_pruning_hparams().parse(
       self.TEST_HPARAMS)
Example #29
def record_summaries_every_n_global_steps(n, global_step=None):
  """Sets the should_record_summaries Tensor to true if global_step % n == 0."""
  if global_step is None:
    global_step = training_util.get_or_create_global_step()
  with ops.device("cpu:0"):
    should = lambda: math_ops.equal(global_step % n, 0)
    if not context.executing_eagerly():
      should = should()
  return _record_summaries(should)
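Usage follows the pattern in testScalarSummary above: the returned condition is entered as a context, and summary ops created inside it only record when `global_step % n == 0` (the `logdir` and `loss_value` names below are illustrative).

# Usage sketch for the helper above.
with summary_ops.create_file_writer(logdir, max_queue=0).as_default():
  with record_summaries_every_n_global_steps(10):
    summary_ops.scalar('loss', loss_value)  # recorded only on every 10th step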
Example #30
def record_summaries_every_n_global_steps(n, global_step=None):
  """Sets the should_record_summaries Tensor to true if global_step % n == 0."""
  if global_step is None:
    global_step = training_util.get_or_create_global_step()
  collection_ref = ops.get_collection_ref(_SHOULD_RECORD_SUMMARIES_NAME)
  old = collection_ref[:]
  with ops.device("cpu:0"):
    collection_ref[:] = [math_ops.equal(global_step % n, 0)]
  yield
  collection_ref[:] = old
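This graph-collection variant is written as a generator (note the bare `yield`), so in the library it is presumably wrapped as a context manager (e.g. with `contextlib.contextmanager`); under that assumption the call pattern is the same as for the previous variant.

# Usage sketch, assuming the generator above is decorated as a context manager;
# `loss` is an illustrative tensor.
with record_summaries_every_n_global_steps(100):
  summary_ops.scalar('loss', loss)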
Example #31
def train(train_op,
          logdir,
          train_step_fn=train_step,
          train_step_kwargs=_USE_DEFAULT,
          log_every_n_steps=1,
          graph=None,
          master='',
          is_chief=True,
          global_step=None,
          number_of_steps=None,
          init_op=_USE_DEFAULT,
          init_feed_dict=None,
          local_init_op=_USE_DEFAULT,
          init_fn=None,
          ready_op=_USE_DEFAULT,
          summary_op=_USE_DEFAULT,
          save_summaries_secs=600,
          summary_writer=_USE_DEFAULT,
          startup_delay_steps=0,
          saver=None,
          save_interval_secs=600,
          sync_optimizer=None,
          session_config=None,
          session_wrapper=None,
          trace_every_n_steps=None,
          ignore_live_threads=False):
    """Runs a training loop using a TensorFlow supervisor.

  When the sync_optimizer is supplied, gradient updates are applied
  synchronously. Otherwise, gradient updates are applied asynchronously.

  Args:
    train_op: A `Tensor` that, when executed, will apply the gradients and
      return the loss value.
    logdir: The directory where training logs are written to. If None, model
      checkpoints and summaries will not be written.
    train_step_fn: The function to call in order to execute a single gradient
      step. The function must take exactly four arguments: the current
      session, the `train_op` `Tensor`, a global step `Tensor` and a dictionary.
    train_step_kwargs: A dictionary which is passed to the `train_step_fn`. By
      default, two `Boolean`, scalar ops called "should_stop" and "should_log"
      are provided.
    log_every_n_steps: The frequency, in terms of global steps, that the loss
      and global step are logged.
    graph: The graph to pass to the supervisor. If no graph is supplied the
      default graph is used.
    master: The address of the tensorflow master.
    is_chief: Specifies whether or not the training is being run by the primary
      replica during replica training.
    global_step: The `Tensor` representing the global step. If left as `None`,
      then training_util.get_or_create_global_step(), that is,
      tf.contrib.framework.global_step() is used.
    number_of_steps: The max number of gradient steps to take during training,
      as measured by 'global_step': training will stop if global_step is
      greater than 'number_of_steps'. If the value is left as None, training
      proceeds indefinitely.
    init_op: The initialization operation. If left to its default value, then
      the session is initialized by calling `tf.global_variables_initializer()`.
    init_feed_dict: A feed dictionary to use when executing the `init_op`.
    local_init_op: The local initialization operation. If left to its default
      value, then the session is initialized by calling
      `tf.local_variables_initializer()` and `tf.tables_initializer()`.
    init_fn: An optional callable to be executed after `init_op` is called. The
      callable must accept one argument, the session being initialized.
    ready_op: Operation to check if the model is ready to use. If left to its
      default value, then the session checks for readiness by calling
      `tf.report_uninitialized_variables()`.
    summary_op: The summary operation.
    save_summaries_secs: How often, in seconds, to save summaries.
    summary_writer: `SummaryWriter` to use.  Can be `None`
      to indicate that no summaries should be written. If unset, we
      create a SummaryWriter.
    startup_delay_steps: The number of steps to wait for before beginning. Note
      that this must be 0 if a sync_optimizer is supplied.
    saver: Saver to save checkpoints. If None, a default one will be created
      and used.
    save_interval_secs: How often, in seconds, to save the model to `logdir`.
    sync_optimizer: an instance of tf.train.SyncReplicasOptimizer, or a list of
      them. If the argument is supplied, gradient updates will be synchronous.
      If left as `None`, gradient updates will be asynchronous.
    session_config: An instance of `tf.ConfigProto` that will be used to
      configure the `Session`. If left as `None`, the default will be used.
    session_wrapper: A function that takes a `tf.Session` object as the only
      argument and returns a wrapped session object that has the same methods
      that the original object has, or `None`. Iff not `None`, the wrapped
      object will be used for training.
    trace_every_n_steps: produce and save a `Timeline` in Chrome trace format
      and add it to the summaries every `trace_every_n_steps`. If None, no trace
      information will be produced or saved.
    ignore_live_threads: If `True` ignores threads that remain running after
      a grace period when stopping the supervisor, instead of raising a
      RuntimeError.

  Returns:
    the value of the loss function after training.

  Raises:
    ValueError: if `train_op` is empty or if `startup_delay_steps` is
      non-zero when `sync_optimizer` is supplied, if `number_of_steps` is
      negative, or if `trace_every_n_steps` is not `None` and no `logdir` is
      provided.
  """

    print("HALLOOOOO**********************************")

    if train_op is None:
        raise ValueError('train_op cannot be None.')

    if logdir is None:
        if summary_op != _USE_DEFAULT:
            raise ValueError('Cannot provide summary_op because logdir=None')
        if saver is not None:
            raise ValueError('Cannot provide saver because logdir=None')
        if trace_every_n_steps is not None:
            raise ValueError('Cannot provide trace_every_n_steps because '
                             'logdir=None')

    if isinstance(sync_optimizer,
                  sync_replicas_optimizer.SyncReplicasOptimizer):
        sync_optimizer = [sync_optimizer]
    if sync_optimizer is not None and startup_delay_steps > 0:
        raise ValueError(
            'startup_delay_steps must be zero when sync_optimizer is supplied.'
        )

    if number_of_steps is not None and number_of_steps <= 0:
        raise ValueError(
            '`number_of_steps` must be either None or a positive number.')

    graph = graph or ops.get_default_graph()
    with graph.as_default():
        if global_step is None:
            global_step = training_util.get_or_create_global_step()
        saver = saver or tf_saver.Saver()

        if sync_optimizer is not None:
            for opt in sync_optimizer:
                if not isinstance(
                        opt, sync_replicas_optimizer.SyncReplicasOptimizer):
                    raise ValueError(
                        '`sync_optimizer` must be a tf.train.SyncReplicasOptimizer.'
                    )

        with ops.name_scope('init_ops'):
            if init_op == _USE_DEFAULT:
                init_op = variables.global_variables_initializer()

            if ready_op == _USE_DEFAULT:
                ready_op = variables.report_uninitialized_variables()

            if local_init_op == _USE_DEFAULT:
                local_init_op = control_flow_ops.group(
                    variables.local_variables_initializer(),
                    lookup_ops.tables_initializer())

            if sync_optimizer is not None and isinstance(sync_optimizer, list):
                with ops.control_dependencies(
                    [local_init_op] if local_init_op is not None else []):
                    if is_chief:
                        local_init_op = control_flow_ops.group(
                            *[opt.chief_init_op for opt in sync_optimizer])
                    else:
                        local_init_op = control_flow_ops.group(
                            *[opt.local_step_init_op for opt in sync_optimizer])
                ready_for_local_init_op = control_flow_ops.group(
                    *[opt.ready_for_local_init_op for opt in sync_optimizer])
            else:
                ready_for_local_init_op = None

        if summary_op == _USE_DEFAULT:
            summary_op = summary.merge_all()

        if summary_writer == _USE_DEFAULT:
            summary_writer = supervisor.Supervisor.USE_DEFAULT

        if is_chief and sync_optimizer is not None:
            # Need to create these BEFORE the supervisor finalizes the graph:
            init_tokens_op = [
                opt.get_init_tokens_op() for opt in sync_optimizer
            ]
            chief_queue_runner = [
                opt.get_chief_queue_runner() for opt in sync_optimizer
            ]

        if train_step_kwargs == _USE_DEFAULT:
            with ops.name_scope('train_step'):
                train_step_kwargs = {}

                if number_of_steps:
                    should_stop_op = math_ops.greater_equal(
                        global_step, number_of_steps)
                else:
                    should_stop_op = constant_op.constant(False)
                train_step_kwargs['should_stop'] = should_stop_op
                if log_every_n_steps > 0:
                    train_step_kwargs['should_log'] = math_ops.equal(
                        math_ops.mod(global_step, log_every_n_steps), 0)
                if is_chief and trace_every_n_steps is not None:
                    train_step_kwargs['should_trace'] = math_ops.equal(
                        math_ops.mod(global_step, trace_every_n_steps), 0)
                    train_step_kwargs['logdir'] = logdir

    sv = supervisor.Supervisor(graph=graph,
                               is_chief=is_chief,
                               logdir=logdir,
                               init_op=init_op,
                               init_feed_dict=init_feed_dict,
                               local_init_op=local_init_op,
                               ready_for_local_init_op=ready_for_local_init_op,
                               ready_op=ready_op,
                               summary_op=summary_op,
                               summary_writer=summary_writer,
                               global_step=global_step,
                               saver=saver,
                               save_summaries_secs=save_summaries_secs,
                               save_model_secs=save_interval_secs,
                               init_fn=init_fn)

    if summary_writer is not None:
        train_step_kwargs['summary_writer'] = sv.summary_writer

    total_loss = None
    should_retry = True
    while should_retry:
        try:
            should_retry = False
            with sv.managed_session(master,
                                    start_standard_services=False,
                                    config=session_config) as sess:
                logging.info('Starting Session.')
                if session_wrapper is not None:
                    logging.info('Wrapping session with wrapper function: %s',
                                 session_wrapper)
                    sess = session_wrapper(sess)
                if is_chief:
                    if logdir:
                        sv.start_standard_services(sess)
                elif startup_delay_steps > 0:
                    # (use sys.maxsize because sys.maxint doesn't exist in Python 3)
                    _wait_for_step(
                        sess, global_step,
                        min(startup_delay_steps, number_of_steps
                            or sys.maxsize))
                threads = sv.start_queue_runners(sess)
                logging.info('Starting Queues.')
                if is_chief and sync_optimizer is not None:
                    sv.start_queue_runners(sess, chief_queue_runner)
                    sess.run(init_tokens_op)
                try:
                    while not sv.should_stop():
                        total_loss, should_stop = train_step_fn(
                            sess, train_op, global_step, train_step_kwargs)
                        if should_stop:
                            logging.info('Stopping Training.')
                            sv.request_stop()
                            break
                except errors.OutOfRangeError as e:
                    # OutOfRangeError is thrown when epoch limit per
                    # tf.train.limit_epochs is reached.
                    logging.info(
                        'Caught OutOfRangeError. Stopping Training. %s', e)
                if logdir and sv.is_chief:
                    logging.info('Finished training! Saving model to disk.')
                    sv.saver.save(sess,
                                  sv.save_path,
                                  global_step=sv.global_step)
                    sv.stop(threads,
                            close_summary_writer=True,
                            ignore_live_threads=ignore_live_threads)

        except errors.AbortedError:
            # Always re-run on AbortedError as it indicates a restart of one of the
            # distributed tensorflow servers.
            logging.info('Retrying training!')
            should_retry = True

    return total_loss
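A minimal driver sketch for the `train` loop above, assuming `slim` (tf.contrib.slim), a `total_loss` tensor, and an `optimizer` are already defined; `create_train_op` is the conventional way to build the `train_op` argument.

# Hypothetical usage of the training loop defined above.
train_op = slim.learning.create_train_op(total_loss, optimizer)
final_loss = train(
    train_op,
    logdir='/tmp/train_logs',
    number_of_steps=1000,
    log_every_n_steps=100,
    save_summaries_secs=300,
    save_interval_secs=600)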
Example #32
def _choose_step(step):
    if step is None:
        return training_util.get_or_create_global_step()
    if not isinstance(step, ops.Tensor):
        return ops.convert_to_tensor(step, dtypes.int64)
    return step
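A brief illustration of the fallback behavior (values are illustrative): passing `None` resolves to the global step, while a plain Python int is wrapped into an int64 tensor.

step = _choose_step(None)  # falls back to training_util.get_or_create_global_step()
step = _choose_step(42)    # converted to an int64 Tensor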
Example #33
 def _init_global_step(self):
     self.global_step = training_util.get_or_create_global_step()
     self._training_ops.update(
         {'increment_global_step': training_util._increment_global_step(1)})
Example #34
def slim_learning_create_train_op_with_manual_grads(
        total_loss,
        optimizers,  # list of optimizers 
        grads_and_vars,  # list of grads_and_vars from optimizer.compute_gradients()
        global_step=0,
        #                     update_ops=None,
        #                     variables_to_train=None,
        clip_gradient_norm=0,
        summarize_gradients=False,
        gate_gradients=1,  # tf.python.training.optimizer.Optimizer.GATE_OP,
        aggregation_method=None,
        colocate_gradients_with_ops=False,
        gradient_multipliers=None,
        check_numerics=True):
    """Runs the training loop
      
    modified from slim.learning.create_train_op() to work with
    a matched list of optimizers and grads_and_vars

    see:
      https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/slim/python/slim/learning.py
      https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/training/python/training/training.py
  
  Returns:
      train_ops - the value of the loss function after training.
  """
    from tensorflow.python.framework import ops
    from tensorflow.python.ops import array_ops
    from tensorflow.python.ops import control_flow_ops
    from tensorflow.python.training import training_util

    def _transform_grads_fn(grads):
        if gradient_multipliers:
            with ops.name_scope('multiply_grads'):
                grads = multiply_gradients(grads, gradient_multipliers)

        # Clip gradients.
        if clip_gradient_norm > 0:
            with ops.name_scope('clip_grads'):
                grads = clip_gradient_norms(grads, clip_gradient_norm)
        return grads

    if global_step is None:
        global_step = training_util.get_or_create_global_step()

    # We assume optimizers and grads_and_vars form a matched set, i.e. they
    # could be zipped into (optimizer, grads_and_vars) pairs.
    assert len(optimizers) == len(grads_and_vars)

    ### order of processing:
    # 0. grads = opt.compute_gradients()
    # 1. grads = _transform_grads_fn(grads)
    # 2. add_gradients_summaries(grads)
    # 3. grads = opt.apply_gradients(grads, global_step=global_step)

    grad_updates = []
    for i in range(len(optimizers)):
        grads = grads_and_vars[i]  # 0. from opt.compute_gradients(), passed in as grads_and_vars
        grads = _transform_grads_fn(grads)  # 1. _transform_grads_fn()
        if summarize_gradients:
            with ops.name_scope('summarize_grads'):
                slim.learning.add_gradients_summaries(
                    grads)  # 2. add_gradients_summaries()
        if i == 0:
            grad_update = optimizers[i].apply_gradients(
                grads,  # 3. optimizer.apply_gradients()
                global_step=global_step)  #    update global_step only once
        else:
            grad_update = optimizers[i].apply_gradients(grads)
        grad_updates.append(grad_update)

    with ops.name_scope('train_op'):
        total_loss = array_ops.check_numerics(total_loss,
                                              'LossTensor is inf or nan')
        train_op = control_flow_ops.with_dependencies(grad_updates, total_loss)

    # Add the operation used for training to the 'train_op' collection
    train_ops = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
    if train_op not in train_ops:
        train_ops.append(train_op)

    return train_op
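A hedged usage sketch under the "matched lists" assumption documented above; the optimizers, losses, and variable lists (`vars_a`, `vars_b`) are illustrative placeholders.

# Hypothetical usage: one optimizer per variable group, gradients computed
# manually and passed in as a matched list.
opt_a = tf.train.AdamOptimizer(1e-3)
opt_b = tf.train.AdamOptimizer(1e-4)
grads_a = opt_a.compute_gradients(total_loss, var_list=vars_a)
grads_b = opt_b.compute_gradients(total_loss, var_list=vars_b)
train_op = slim_learning_create_train_op_with_manual_grads(
    total_loss,
    optimizers=[opt_a, opt_b],
    grads_and_vars=[grads_a, grads_b],
    global_step=training_util.get_or_create_global_step())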
Example #35
def run_benchmark(master, direction=None):
    """Connect to master and run simple TF->Python transfer benchmark."""

    from tensorflow.python.summary import summary as summary_lib
    from tensorflow.python import pywrap_tensorflow
    from tensorflow.python.util import compat
    from tensorflow.core.util import event_pb2
    from tensorflow.core.framework import summary_pb2

    def make_event(tag, value, step):
        event = event_pb2.Event(
            wall_time=time.time(),
            step=step,
            summary=summary_pb2.Summary(
                value=[summary_pb2.Summary.Value(tag=tag, simple_value=value)
                       ]))
        return event

    if not direction:
        os.system('mkdir -p ' + FLAGS.logdir)

        # todo: unique filenames like with contrib.summary writer
        writer = pywrap_tensorflow.EventsWriter(
            compat.as_bytes(FLAGS.logdir + '/events'))
        filename = compat.as_text(writer.FileName())

        training_util.get_or_create_global_step()
        sess = tf.InteractiveSession()
        step = 0
        while True:
            p_to_t = run_benchmark(master, 'p->t')
            print("recoridng", p_to_t, "to", FLAGS.logdir)
            t_to_p = run_benchmark(master, 't->p')

            event = make_event('p->t', p_to_t, step)
            writer.WriteEvent(event)
            event = make_event('t->p', t_to_p, step)
            writer.WriteEvent(event)
            writer.Flush()
            step += 1

        writer.Close()
        return

    assert FLAGS.warmup_iters > 0
    gc.disable()

    dtype = tf.int32
    params_size = 250 * 1000 * FLAGS.data_mb  # 1MB is 250k integers
    #  params = tf.get_variable("params", [params_size], dtype,
    #                           initializer=tf.ones_initializer())
    params = tf.Variable(tf.ones([params_size], dtype=dtype), name='params')
    params_read = params.read_value()  # prevent caching
    params_holder = tf.placeholder(dtype)
    params_write = params.assign(params_holder)
    done_queue = create_done_queue(0)
    init_op = tf.global_variables_initializer()
    sess = tf.Session(master, config=session_config())
    sess.run(init_op)
    result = sess.run(params_read)

    total = 0
    for i in range(FLAGS.iters + FLAGS.warmup_iters):
        if i == FLAGS.warmup_iters:
            start_time = time.time()
        # fetch value into Python runtime
        if direction == "t->p":
            result = sess.run(params_read)
            if FLAGS.sanity_check:
                total += result.sum()
                print(float(total) / params_size)
        elif direction == "p->t":
            sess.run(params_write.op, feed_dict={params_holder: result})

    elapsed_time = time.time() - start_time
    rate = float(FLAGS.iters) * FLAGS.data_mb / elapsed_time
    print("%5s %.2f MB/second" % (direction, rate))
    sess.run(done_queue.enqueue(1))
    return rate
Example #36
    def test_initialize_if_not_restoring(self):
        checkpoint_directory = self.get_temp_dir()
        checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
        optimizer_only_prefix = os.path.join(checkpoint_directory, "opt")
        with test_util.device(use_gpu=True):
            model = MyModel()
            optimizer = adam.AdamOptimizer(0.001)
            root = trackable_utils.Checkpoint(
                model=model,  # Do not save the optimizer with the checkpoint.
                global_step=training_util.get_or_create_global_step())
            optimizer_checkpoint = trackable_utils.Checkpoint(
                optimizer=optimizer)

            checkpoint_path = checkpoint_management.latest_checkpoint(
                checkpoint_directory)
            status = root.restore(save_path=checkpoint_path)
            input_value = constant_op.constant([[3.]])
            train_fn = functools.partial(optimizer.minimize,
                                         functools.partial(model, input_value),
                                         global_step=root.global_step)
            if not context.executing_eagerly():
                train_fn = functools.partial(self.evaluate, train_fn())
            status.initialize_or_restore()
            self.evaluate([v.initializer for v in optimizer.variables()])
            train_fn()
            model_save_path = root.save(file_prefix=checkpoint_prefix)
            self.evaluate(optimizer.variables()[0].assign(42.))
            optimizer_save_path = optimizer_checkpoint.save(
                optimizer_only_prefix)

        # Restore into a graph with the optimizer
        with test_util.device(use_gpu=True):
            model = MyModel()
            optimizer = adam.AdamOptimizer(0.001)
            root = trackable_utils.Checkpoint(
                optimizer=optimizer,
                model=model,
                global_step=training_util.get_or_create_global_step())
            status = root.restore(save_path=model_save_path)
            input_value = constant_op.constant([[3.]])
            train_fn = functools.partial(optimizer.minimize,
                                         functools.partial(model, input_value),
                                         global_step=root.global_step)
            if not context.executing_eagerly():
                train_fn = functools.partial(self.evaluate, train_fn())
            status.initialize_or_restore()
            train_fn()
            with self.assertRaises(AssertionError):
                status.assert_existing_objects_matched()
            with self.assertRaises(AssertionError):
                status.assert_consumed()

        # Make sure initialization doesn't clobber later restores
        with test_util.device(use_gpu=True):
            model = MyModel()
            optimizer = adam.AdamOptimizer(0.001, beta1=1.0)
            root = trackable_utils.Checkpoint(
                optimizer=optimizer,
                model=model,
                global_step=training_util.get_or_create_global_step())
            opt_root = trackable_utils.Checkpoint(optimizer=optimizer)
            status = root.restore(save_path=model_save_path)
            init_only_optimizer_status = opt_root.restore(save_path=None)
            optimizer_status = opt_root.restore(save_path=optimizer_save_path)
            input_value = constant_op.constant([[3.]])
            train_fn = functools.partial(optimizer.minimize,
                                         functools.partial(model, input_value),
                                         global_step=root.global_step)
            if not context.executing_eagerly():
                train_fn = functools.partial(self.evaluate, train_fn())
            optimizer_status.run_restore_ops()
            status.initialize_or_restore()
            init_only_optimizer_status.initialize_or_restore()
            train_fn()
            self.assertEqual(42., self.evaluate(optimizer.variables()[0]))
Example #37
 def setUp(self):
     super(PruningTest, self).setUp()
     self.global_step = training_util.get_or_create_global_step()
Example #38
    def _model_fn(self, features, labels, mode, params):
        """Model function for the GAN estimator.

        :param features: dict of input tensors produced by the input_fn.
        :param labels: unused.
        :param mode: a `ModeKeys` value (TRAIN, EVAL or INFER).
        :param params: unused hyperparameter dict.
        :return: a `tf.estimator.EstimatorSpec`.
        """

        sample_image = None
        training_hooks = None

        # Create global step increment op.
        self.global_step = training_util.get_or_create_global_step()
        self.global_step_inc = self.global_step.assign_add(0)

        # Audio/Noise placeholder to the discriminator.
        z_placeholder = features[self._feature_type.AUDIO_OR_NOISE]
        tf.logging.info("=========> {}".format(z_placeholder))

        z_placeholder = tf.cast(z_placeholder, tf.float32)

        tf.logging.info("=========> {}".format(z_placeholder))

        if mode != ModeKeys.INFER:

            # Placeholder for input image vectors to the generator.
            x_placeholder = features[self._feature_type.IMAGE]
            tf.logging.info("=========> {}".format(x_placeholder))

            x_placeholder = tf.cast(x_placeholder, tf.float32)
            tf.logging.info("=========> {}".format(x_placeholder))

            channel = x_placeholder.get_shape()[-1]
            d_loss, g_loss, print_hooks = self.model_loss(
                x_placeholder, z_placeholder, channel, self.global_step)

            d_train_opt, g_train_opt = self.model_opt(
                d_loss, g_loss, self.gan_config.learning_rate,
                self.gan_config.beta1, self.global_step)
        else:
            sample_image = self.generator(z_placeholder,
                                          self.gan_config.num_image_channels)
            # Changes made to take image channels from the data iterator just for prediction.

        # Loss, training and eval operations are not needed during inference.
        loss = None
        train_op = None
        eval_metric_ops = {}

        if mode != ModeKeys.INFER:
            loss = g_loss + d_loss
            tf.summary.scalar(name="g_loss", tensor=g_loss)
            tf.summary.scalar(name="d_loss", tensor=d_loss)

            training_hooks = self.get_sequential_train_hooks(
                d_train_opt, g_train_opt)
            training_hooks.append(print_hooks)

        return tf.estimator.EstimatorSpec(mode=mode,
                                          predictions=sample_image,
                                          loss=loss,
                                          train_op=self.global_step_inc,
                                          eval_metric_ops=eval_metric_ops,
                                          training_hooks=training_hooks)
Example #39
 def begin(self):
     if self._replace_summary_op:
         self._summary_op = summary.merge_all()
     self._global_step = training_util.get_or_create_global_step()
Example #40
def _clone_and_build_model(mode,
                           keras_model,
                           custom_objects,
                           features=None,
                           labels=None):
    """Clone and build the given keras_model.

  Args:
    mode: training mode.
    keras_model: an instance of compiled keras model.
    custom_objects: Dictionary for custom objects.
    features: Dict of tensors.
    labels: Dict of tensors, or single tensor instance.

  Returns:
    The newly built model.
  """
    # Set to True during training, False for inference.
    K.set_learning_phase(mode == model_fn_lib.ModeKeys.TRAIN)

    # Get list of inputs.
    if features is None:
        input_tensors = None
    else:
        input_tensors = _create_ordered_io(keras_model,
                                           estimator_io=features,
                                           is_input=True)
    # Get list of outputs.
    if labels is None:
        target_tensors = None
    elif isinstance(labels, dict):
        target_tensors = _create_ordered_io(keras_model,
                                            estimator_io=labels,
                                            is_input=False)
    else:
        target_tensors = [
            _cast_tensor_to_floatx(
                sparse_tensor_lib.convert_to_tensor_or_sparse_tensor(labels))
        ]

    if keras_model._is_graph_network:
        if custom_objects:
            with CustomObjectScope(custom_objects):
                model = models.clone_model(keras_model,
                                           input_tensors=input_tensors)
        else:
            model = models.clone_model(keras_model,
                                       input_tensors=input_tensors)
    else:
        model = keras_model
        _in_place_subclassed_model_reset(model)
        if input_tensors is not None:
            model._set_inputs(input_tensors)

    # Compile/Build model
    if mode is model_fn_lib.ModeKeys.PREDICT:
        if isinstance(model, models.Sequential):
            model.build()
    else:
        if isinstance(keras_model.optimizer, optimizers.TFOptimizer):
            optimizer = keras_model.optimizer
        else:
            optimizer_config = keras_model.optimizer.get_config()
            optimizer = keras_model.optimizer.__class__.from_config(
                optimizer_config)
        optimizer.iterations = training_util.get_or_create_global_step()

        model.compile(optimizer,
                      keras_model.loss,
                      metrics=keras_model.metrics,
                      loss_weights=keras_model.loss_weights,
                      sample_weight_mode=keras_model.sample_weight_mode,
                      weighted_metrics=keras_model.weighted_metrics,
                      target_tensors=target_tensors)
    return model
Example #41
0
def _dnn_linear_combined_model_fn_v2(
        features,
        labels,
        mode,
        head,
        linear_feature_columns=None,
        linear_optimizer='Ftrl',
        dnn_feature_columns=None,
        dnn_optimizer='Adagrad',
        dnn_hidden_units=None,
        dnn_activation_fn=nn.relu,
        dnn_dropout=None,
        config=None,
        batch_norm=False,
        linear_sparse_combiner='sum',
        loss_reduction=losses_utils.ReductionV2.SUM_OVER_BATCH_SIZE):
    """Deep Neural Net and Linear combined model_fn.

  Args:
    features: dict of `Tensor`.
    labels: `Tensor` of shape [batch_size, 1] or [batch_size] labels of dtype
      `int32` or `int64` in the range `[0, n_classes)`.
    mode: Defines whether this is training, evaluation or prediction. See
      `ModeKeys`.
    head: A `Head` instance.
    linear_feature_columns: An iterable containing all the feature columns used
      by the Linear model.
    linear_optimizer: string, `Optimizer` object, or callable that defines the
      optimizer to use for training the Linear model. Defaults to the Ftrl
      optimizer.
    dnn_feature_columns: An iterable containing all the feature columns used by
      the DNN model.
    dnn_optimizer: string, `Optimizer` object, or callable that defines the
      optimizer to use for training the DNN model. Defaults to the Adagrad
      optimizer.
    dnn_hidden_units: List of hidden units per DNN layer.
    dnn_activation_fn: Activation function applied to each DNN layer. If `None`,
      will use `tf.nn.relu`.
    dnn_dropout: When not `None`, the probability we will drop out a given DNN
      coordinate.
    config: `RunConfig` object to configure the runtime settings.
    batch_norm: Whether to use batch normalization after each hidden layer.
    linear_sparse_combiner: A string specifying how to reduce the linear model
      if a categorical column is multivalent.  One of "mean", "sqrtn", and
      "sum".
    loss_reduction: One of `tf.keras.losses.Reduction` except `NONE`. Describes
      how to reduce training loss over batch. Defaults to `SUM_OVER_BATCH_SIZE`.

  Returns:
    An `EstimatorSpec` instance.

  Raises:
    ValueError: If both `linear_feature_columns` and `dnn_feature_columns`
      are empty at the same time, or `features` has the wrong type.
  """
    if not isinstance(features, dict):
        raise ValueError('features should be a dictionary of `Tensor`s. '
                         'Given type: {}'.format(type(features)))
    if not linear_feature_columns and not dnn_feature_columns:
        raise ValueError(
            'Either linear_feature_columns or dnn_feature_columns must be defined.'
        )

    del config

    # Build DNN Logits.
    if not dnn_feature_columns:
        dnn_logits = None
    else:
        if mode == ModeKeys.TRAIN:
            dnn_optimizer = optimizers.get_optimizer_instance_v2(
                dnn_optimizer, learning_rate=_DNN_LEARNING_RATE)
            _check_no_sync_replicas_optimizer(dnn_optimizer)

        if not dnn_hidden_units:
            raise ValueError(
                'dnn_hidden_units must be defined when dnn_feature_columns is '
                'specified.')
        dnn_logits, dnn_trainable_variables, dnn_update_ops = (
            dnn._dnn_model_fn_builder_v2(  # pylint: disable=protected-access
                units=head.logits_dimension,
                hidden_units=dnn_hidden_units,
                feature_columns=dnn_feature_columns,
                activation_fn=dnn_activation_fn,
                dropout=dnn_dropout,
                batch_norm=batch_norm,
                features=features,
                mode=mode))

    if not linear_feature_columns:
        linear_logits = None
    else:
        if mode == ModeKeys.TRAIN:
            linear_optimizer = optimizers.get_optimizer_instance_v2(
                linear_optimizer,
                learning_rate=_linear_learning_rate(
                    len(linear_feature_columns)))
            _check_no_sync_replicas_optimizer(linear_optimizer)

        linear_logits, linear_trainable_variables = (
            linear._linear_model_fn_builder_v2(  # pylint: disable=protected-access
                units=head.logits_dimension,
                feature_columns=linear_feature_columns,
                sparse_combiner=linear_sparse_combiner,
                features=features))
        _add_layer_summary(linear_logits, 'linear')

    # Combine logits and build full model.
    if dnn_logits is not None and linear_logits is not None:
        logits = dnn_logits + linear_logits
    elif dnn_logits is not None:
        logits = dnn_logits
    else:
        logits = linear_logits

    def _train_op_fn(loss):
        """Returns the op to optimize the loss."""
        train_ops = []
        # Scale loss by number of replicas.
        if loss_reduction == losses_utils.ReductionV2.SUM_OVER_BATCH_SIZE:
            loss = losses_utils.scale_loss_for_distribution(loss)

        if dnn_logits is not None:
            train_ops.extend(
                dnn_optimizer.get_updates(loss, dnn_trainable_variables))
            if dnn_update_ops is not None:
                train_ops.extend(dnn_update_ops)
        if linear_logits is not None:
            train_ops.extend(
                linear_optimizer.get_updates(loss, linear_trainable_variables))
        train_op = control_flow_ops.group(*train_ops)
        return train_op

    # In TRAIN mode, assign the global_step variable to optimizer.iterations so
    # that global_step is incremented correctly, since Hooks rely on the global
    # step as the step counter. Note that only one model's optimizer needs this
    # assignment.
    if mode == ModeKeys.TRAIN:
        if dnn_logits is not None:
            dnn_optimizer.iterations = training_util.get_or_create_global_step(
            )
        else:
            linear_optimizer.iterations = training_util.get_or_create_global_step(
            )

    return head.create_estimator_spec(features=features,
                                      mode=mode,
                                      labels=labels,
                                      train_op_fn=_train_op_fn,
                                      logits=logits)
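For reference, this model_fn is the kind of function that backs the canned wide-and-deep estimators, which can be used without touching the internals. A minimal sketch follows; the column names, toy data, and hidden-unit sizes are made up.

import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()

price = tf.feature_column.numeric_column('price')
color = tf.feature_column.categorical_column_with_vocabulary_list(
    'color', ['red', 'green', 'blue'])

estimator = tf.estimator.DNNLinearCombinedClassifier(
    linear_feature_columns=[color],  # wide part, trained with Ftrl by default
    dnn_feature_columns=[tf.feature_column.indicator_column(color), price],
    dnn_hidden_units=[16, 8])        # deep part, trained with Adagrad by default

def input_fn():
    features = {'price': [[1.0], [2.0], [3.0], [4.0]],
                'color': [['red'], ['blue'], ['green'], ['red']]}
    labels = [[1], [0], [1], [0]]
    return tf.data.Dataset.from_tensor_slices(
        (features, labels)).batch(2).repeat()

estimator.train(input_fn, steps=10)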
Example #42
0
    def build_controller(self):
        """RL optimization interface.

    Returns:
      ops: A dictionary holding handles of the model used for training.
    """

        self._global_step = training_util.get_or_create_global_step()
        ops = {}
        ops["loss"] = 0

        failing_signal = self.compute_reward(self.hparams.failing_signal)

        ctr = {}

        with tf_ops.name_scope("controller_{}".format(self.ctrl_id)):
            with variable_scope.variable_scope("controller_{}".format(
                    self.ctrl_id)):
                ctr["reward"] = {"value": [], "ph": [], "update": []}
                ctr["ready"] = {"value": [], "ph": [], "update": []}
                ctr["best_reward"] = {"value": [], "update": []}
                for i in range(self.hparams.num_children):
                    reward_value = variable_scope.get_local_variable(
                        "reward_{}".format(i),
                        initializer=0.0,
                        dtype=dtypes.float32,
                        trainable=False)
                    reward_ph = array_ops.placeholder(
                        dtypes.float32,
                        shape=(),
                        name="reward_ph_{}".format(i))
                    reward_update = state_ops.assign(reward_value,
                                                     reward_ph,
                                                     use_locking=True)
                    ctr["reward"]["value"].append(reward_value)
                    ctr["reward"]["ph"].append(reward_ph)
                    ctr["reward"]["update"].append(reward_update)
                    best_reward = variable_scope.get_local_variable(
                        "best_reward_{}".format(i),
                        initializer=failing_signal,
                        dtype=dtypes.float32,
                        trainable=False)
                    ctr["best_reward"]["value"].append(best_reward)
                    ctr["best_reward"]["update"].append(
                        state_ops.assign(
                            best_reward,
                            math_ops.minimum(best_reward, reward_update)))

                    ready_value = variable_scope.get_local_variable(
                        "ready_{}".format(i),
                        initializer=True,
                        dtype=dtypes.bool,
                        trainable=False)
                    ready_ph = array_ops.placeholder(
                        dtypes.bool, shape=(), name="ready_ph_{}".format(i))
                    ready_update = state_ops.assign(ready_value,
                                                    ready_ph,
                                                    use_locking=True)
                    ctr["ready"]["value"].append(ready_value)
                    ctr["ready"]["ph"].append(ready_ph)
                    ctr["ready"]["update"].append(ready_update)

            ctr["grouping_y_preds"], ctr[
                "grouping_log_probs"] = self.get_groupings()
            summary.histogram(
                "grouping_actions",
                array_ops.slice(ctr["grouping_y_preds"]["sample"], [0, 0],
                                [1, array_ops.shape(self.op_embeddings)[0]]))

            with variable_scope.variable_scope("controller_{}".format(
                    self.ctrl_id)):
                ctr["baseline"] = variable_scope.get_local_variable(
                    "baseline",
                    initializer=failing_signal
                    if self.hparams.start_with_failing_signal else 0.0,
                    dtype=dtypes.float32,
                    trainable=False)

            new_baseline = self.hparams.bl_dec * ctr["baseline"] + (
                1 - self.hparams.bl_dec) * math_ops.reduce_mean(
                    ctr["reward"]["value"])
            if not self.hparams.always_update_baseline:
                baseline_mask = math_ops.less(ctr["reward"]["value"],
                                              failing_signal)
                selected_reward = array_ops.boolean_mask(
                    ctr["reward"]["value"], baseline_mask)
                selected_baseline = control_flow_ops.cond(
                    math_ops.reduce_any(baseline_mask),
                    lambda: math_ops.reduce_mean(selected_reward),
                    lambda: constant_op.constant(0, dtype=dtypes.float32))
                ctr["pos_reward"] = selected_baseline
                pos_ = math_ops.less(
                    constant_op.constant(0, dtype=dtypes.float32),
                    selected_baseline)
                selected_baseline = self.hparams.bl_dec * ctr["baseline"] + (
                    1 - self.hparams.bl_dec) * selected_baseline
                selected_baseline = control_flow_ops.cond(
                    pos_, lambda: selected_baseline, lambda: ctr["baseline"])
                new_baseline = control_flow_ops.cond(
                    math_ops.less(self.global_step,
                                  self.hparams.stop_updating_after_steps),
                    lambda: new_baseline, lambda: selected_baseline)
            ctr["baseline_update"] = state_ops.assign(ctr["baseline"],
                                                      new_baseline,
                                                      use_locking=True)

            ctr["y_preds"], ctr["log_probs"] = self.get_placements()
            summary.histogram("actions", ctr["y_preds"]["sample"])
            mask = math_ops.less(ctr["reward"]["value"], failing_signal)
            ctr["loss"] = ctr["reward"]["value"] - ctr["baseline"]
            ctr["loss"] *= (ctr["log_probs"]["sample"] +
                            ctr["grouping_log_probs"]["sample"])

            selected_loss = array_ops.boolean_mask(ctr["loss"], mask)
            selected_loss = control_flow_ops.cond(
                math_ops.reduce_any(mask),
                lambda: math_ops.reduce_mean(-selected_loss),
                lambda: constant_op.constant(0, dtype=dtypes.float32))

            ctr["loss"] = control_flow_ops.cond(
                math_ops.less(self.global_step,
                              self.hparams.stop_updating_after_steps),
                lambda: math_ops.reduce_mean(-ctr["loss"]),
                lambda: selected_loss)

            ctr["reward_s"] = math_ops.reduce_mean(ctr["reward"]["value"])
            summary.scalar("loss", ctr["loss"])
            summary.scalar("avg_reward", ctr["reward_s"])
            summary.scalar("best_reward_so_far", best_reward)
            summary.scalar(
                "advantage",
                math_ops.reduce_mean(ctr["reward"]["value"] - ctr["baseline"]))

        with variable_scope.variable_scope("optimizer",
                                           reuse=variable_scope.AUTO_REUSE):
            (ctr["train_op"], ctr["lr"], ctr["grad_norm"],
             ctr["grad_norms"]) = self._get_train_ops(
                 ctr["loss"],
                 tf_ops.get_collection(tf_ops.GraphKeys.TRAINABLE_VARIABLES),
                 self.global_step,
                 grad_bound=self.hparams.grad_bound,
                 lr_init=self.hparams.lr,
                 lr_dec=self.hparams.lr_dec,
                 start_decay_step=self.hparams.start_decay_step,
                 decay_steps=self.hparams.decay_steps,
                 optimizer_type=self.hparams.optimizer_type)

        summary.scalar("gradnorm", ctr["grad_norm"])
        summary.scalar("lr", ctr["lr"])
        ctr["summary"] = summary.merge_all()
        ops["controller"] = ctr

        self.ops = ops
        return ops
Example #43
0
def clone_and_build_model(model,
                          input_tensors=None,
                          target_tensors=None,
                          custom_objects=None,
                          compile_clone=True,
                          in_place_reset=False):
    """Clone a `Model` and build/compile it with the same settings used before.

  This function should be run in the same graph as the model.

  Args:
    model: `tf.keras.Model` object. Can be Functional, Sequential, or
      sub-classed.
    input_tensors: Optional list of input tensors to build the model upon. If
      not provided, placeholders will be created.
    target_tensors: Optional list of target tensors for compiling the model. If
      not provided, placeholders will be created.
    custom_objects: Optional dictionary mapping string names to custom classes
      or functions.
    compile_clone: Boolean, whether to compile model clone (default `True`).
    in_place_reset: Boolean, whether to reset the model in place. Only used if
      the model is not a graph network. If the model is a subclassed model, then
      this argument must be set to `True` (default `False`). To restore the
      original model, use the function
      `in_place_subclassed_model_state_restoration(model)`.

  Returns:
    Clone of the model.

  Raises:
    ValueError: if trying to clone a subclassed model, and `in_place_reset` is
      set to False.
  """
    if model._is_graph_network:
        if custom_objects:
            with CustomObjectScope(custom_objects):
                clone = clone_model(model, input_tensors=input_tensors)
        else:
            clone = clone_model(model, input_tensors=input_tensors)
    else:
        if not in_place_reset:
            raise ValueError(
                'Model is not a graph network (usually means that it is a subclassed '
                'model). The model cannot be cloned, but there is a workaround where '
                'the model is reset in-place. To use this, please set the argument '
                '`in_place_reset` to `True`. This will reset the attributes in the '
                'original model. To restore the attributes, call '
                '`in_place_subclassed_model_state_restoration(model)`.')
        clone = model
        _in_place_subclassed_model_reset(clone)
        if input_tensors is not None:
            clone._set_inputs(input_tensors)

    # Compile/Build model
    if not compile_clone:
        if isinstance(clone, Sequential):
            clone.build()
    elif model.optimizer:
        if isinstance(model.optimizer, optimizers.TFOptimizer):
            optimizer = model.optimizer
            K.track_tf_optimizer(optimizer)
        else:
            optimizer_config = model.optimizer.get_config()
            optimizer = model.optimizer.__class__.from_config(optimizer_config)
        global_step = training_util.get_or_create_global_step()
        K.track_variable(global_step)
        optimizer.iterations = global_step

        clone.compile(optimizer,
                      model.loss,
                      metrics=model.metrics,
                      loss_weights=model.loss_weights,
                      sample_weight_mode=model.sample_weight_mode,
                      weighted_metrics=model.weighted_metrics,
                      target_tensors=target_tensors)

    return clone
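The public building block behind the graph-network branch above is tf.keras.models.clone_model; a rough sketch of cloning a compiled functional model and re-applying its compile settings by hand is shown below. The model, optimizer, and loss are illustrative, and clone_and_build_model automates the recompilation step.

import tensorflow as tf

inputs = tf.keras.Input(shape=(8,))
outputs = tf.keras.layers.Dense(1)(inputs)
model = tf.keras.Model(inputs, outputs)
model.compile(optimizer='adam', loss='mse')

# clone_model copies the architecture with fresh weights; the compile settings
# are re-applied manually here, mirroring what clone_and_build_model does.
clone = tf.keras.models.clone_model(model)
clone.compile(
    optimizer=model.optimizer.__class__.from_config(model.optimizer.get_config()),
    loss='mse')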
Example #44
0
 def testNamingWithOptimizer(self):
     input_value = constant_op.constant([[3.]])
     model = MyModel()
     # A nuisance Model using the same optimizer. Its slot variables should not
     # go in the checkpoint, since it is never depended on.
     other_model = MyModel()
     optimizer = adam.AdamOptimizer(0.001)
     optimizer_step = training_util.get_or_create_global_step()
     root_checkpointable = util.Checkpoint(optimizer=optimizer,
                                           model=model,
                                           optimizer_step=optimizer_step)
     if context.executing_eagerly():
         optimizer.minimize(lambda: model(input_value),
                            global_step=optimizer_step)
         optimizer.minimize(lambda: other_model(input_value),
                            global_step=optimizer_step)
     else:
         train_op = optimizer.minimize(model(input_value),
                                       global_step=optimizer_step)
         optimizer.minimize(other_model(input_value),
                            global_step=optimizer_step)
         self.evaluate(util.gather_initializers(root_checkpointable))
         self.evaluate(train_op)
     named_variables, serialized_graph, _ = (util._serialize_object_graph(
         root_checkpointable, saveables_cache=None))
     expected_checkpoint_names = (
         # Created in the root node, so no prefix.
         "optimizer_step",
         "model/_second/kernel",
         "model/_named_dense/kernel",
         "model/_named_dense/bias",
         # non-Layer dependency of the model
         "model/_non_layer/a_variable",
         # The optimizer creates two non-slot variables
         "optimizer/beta_1_power",
         "optimizer/beta_2_power",
         # Slot variables
         "model/_second/kernel/.OPTIMIZER_SLOT/optimizer/m",
         "model/_second/kernel/.OPTIMIZER_SLOT/optimizer/v",
         "model/_named_dense/kernel/.OPTIMIZER_SLOT/optimizer/m",
         "model/_named_dense/kernel/.OPTIMIZER_SLOT/optimizer/v",
         "model/_named_dense/bias/.OPTIMIZER_SLOT/optimizer/m",
         "model/_named_dense/bias/.OPTIMIZER_SLOT/optimizer/v",
     )
     suffix = "/.ATTRIBUTES/VARIABLE_VALUE"
     expected_checkpoint_names = [
         name + suffix for name in expected_checkpoint_names
     ]
     # The optimizer and Dense layers also save get_config() JSON
     expected_checkpoint_names.extend([
         "optimizer/.ATTRIBUTES/OBJECT_CONFIG_JSON",
         "model/_second/.ATTRIBUTES/OBJECT_CONFIG_JSON",
         "model/_named_dense/.ATTRIBUTES/OBJECT_CONFIG_JSON"
     ])
     named_variables = {v.name: v for v in named_variables}
     six.assertCountEqual(self, expected_checkpoint_names,
                          named_variables.keys())
     # Check that we've mapped to the right variable objects (not exhaustive)
     self.assertEqual("global_step",
                      named_variables["optimizer_step" + suffix].full_name)
     self.assertEqual(
         "my_model/dense_1/kernel",
         named_variables["model/_second/kernel" + suffix].full_name)
     self.assertEqual(
         "my_model/dense/kernel",
         named_variables["model/_named_dense/kernel" + suffix].full_name)
     self.assertEqual(
         "beta_1_power",
         named_variables["optimizer/beta_1_power" + suffix].full_name)
     self.assertEqual(
         "beta_2_power",
         named_variables["optimizer/beta_2_power" + suffix].full_name)
     # Spot check the generated protocol buffers.
     self.assertEqual("optimizer",
                      serialized_graph.nodes[0].children[1].local_name)
     optimizer_node = serialized_graph.nodes[
         serialized_graph.nodes[0].children[1].node_id]
     self.assertEqual("beta_1_power", optimizer_node.children[0].local_name)
     self.assertEqual(
         "beta_1_power", serialized_graph.nodes[
             optimizer_node.children[0].node_id].attributes[0].full_name)
     self.assertEqual(
         "my_model/dense/kernel",
         serialized_graph.nodes[optimizer_node.slot_variables[
             0].original_variable_node_id].attributes[0].full_name)
     # We strip off the :0 suffix, as variable.name-based saving does.
     self.assertEqual(
         "my_model/dense/kernel/Adam",
         serialized_graph.nodes[optimizer_node.slot_variables[
             0].slot_variable_node_id].attributes[0].full_name)
     self.assertEqual(
         "my_model/dense/kernel/Adam:0",
         optimizer.get_slot(var=model._named_dense.kernel, name="m").name)
     self.assertEqual(
         "model/_named_dense/kernel" + suffix,
         serialized_graph.nodes[optimizer_node.slot_variables[
             0].original_variable_node_id].attributes[0].checkpoint_key)
     self.assertEqual("m", optimizer_node.slot_variables[0].slot_name)
     self.assertEqual(
         "model/_named_dense/kernel/.OPTIMIZER_SLOT/optimizer/m" + suffix,
         serialized_graph.nodes[optimizer_node.slot_variables[
             0].slot_variable_node_id].attributes[0].checkpoint_key)
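The same object-graph-based naming can be exercised through the public tf.train.Checkpoint API; a minimal eager-mode sketch follows, with an illustrative layer, optimizer, and temporary checkpoint path rather than the MyModel fixture used by the test.

import os
import tempfile
import tensorflow as tf

model = tf.keras.layers.Dense(1)
optimizer = tf.keras.optimizers.Adam(0.001)
optimizer_step = tf.Variable(0, dtype=tf.int64, trainable=False)

# Checkpoint keys are derived from the object graph rooted here; for example,
# the step variable is stored under the 'optimizer_step' edge.
ckpt = tf.train.Checkpoint(
    optimizer=optimizer, model=model, optimizer_step=optimizer_step)

model(tf.constant([[3.0]]))  # create the layer's variables before saving
prefix = ckpt.save(os.path.join(tempfile.mkdtemp(), 'ckpt'))
print([name for name, _ in tf.train.list_variables(prefix)])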
Example #45
0
def _bt_model_fn(
    features,
    labels,
    mode,
    head,
    feature_columns,
    tree_hparams,
    n_batches_per_layer,
    config,
    closed_form_grad_and_hess_fn=None,
    example_id_column_name=None,
    # TODO(youngheek): replace this later using other options.
    train_in_memory=False,
    name='boosted_trees'):
  """Gradient Boosted Trees model_fn.

  Args:
    features: dict of `Tensor`.
    labels: `Tensor` of shape [batch_size, 1] or [batch_size] labels of
      dtype `int32` or `int64` in the range `[0, n_classes)`.
    mode: Defines whether this is training, evaluation or prediction.
      See `ModeKeys`.
    head: A `head_lib._Head` instance.
    feature_columns: Iterable of `feature_column._FeatureColumn` model inputs.
    tree_hparams: TODO. collections.namedtuple for hyper parameters.
    n_batches_per_layer: A `Tensor` of `int64`. Each layer is built after at
      least n_batches_per_layer accumulations.
    config: `RunConfig` object to configure the runtime settings.
    closed_form_grad_and_hess_fn: a function that accepts logits and labels
      and returns gradients and hessians. By default, they are created by
      tf.gradients() from the loss.
    example_id_column_name: Name of the feature for a unique ID per example.
      Currently experimental -- not exposed to public API.
    train_in_memory: `bool`, when true, it assumes the dataset is in memory,
      i.e., input_fn should return the entire dataset as a single batch, and
      also n_batches_per_layer should be set as 1.
    name: Name to use for the model.

  Returns:
      An `EstimatorSpec` instance.

  Raises:
    ValueError: mode or params are invalid, or features has the wrong type.
  """
  is_single_machine = (config.num_worker_replicas <= 1)
  sorted_feature_columns = sorted(feature_columns, key=lambda tc: tc.name)
  center_bias = tree_hparams.center_bias
  if train_in_memory:
    assert n_batches_per_layer == 1, (
        'When train_in_memory is enabled, input_fn should return the entire '
        'dataset as a single batch, and n_batches_per_layer should be set as '
        '1.')
    if (not config.is_chief or config.num_worker_replicas > 1 or
        config.num_ps_replicas > 0):
      raise ValueError('train_in_memory is supported only for '
                       'non-distributed training.')
  worker_device = control_flow_ops.no_op().device
  # Maximum number of splits (internal nodes) possible in the whole tree = 2**max_depth - 1.
  # TODO(youngheek): perhaps storage could be optimized by storing stats with
  # the dimension max_splits_per_layer, instead of max_splits (for the entire
  # tree).
  max_splits = (1 << tree_hparams.max_depth) - 1
  train_op = []
  with ops.name_scope(name) as name:
    # Prepare.
    global_step = training_util.get_or_create_global_step()
    bucket_size_list, feature_ids_list = _group_features_by_num_buckets(
        sorted_feature_columns)
    # Extract input features and set up cache for training.
    training_state_cache = None
    if mode == model_fn.ModeKeys.TRAIN and train_in_memory:
      # cache transformed features as well for in-memory training.
      batch_size = array_ops.shape(labels)[0]
      input_feature_list, input_cache_op = (
          _cache_transformed_features(features, sorted_feature_columns,
                                      batch_size))
      train_op.append(input_cache_op)
      training_state_cache = _CacheTrainingStatesUsingVariables(
          batch_size, head.logits_dimension)
    else:
      input_feature_list = _get_transformed_features(features,
                                                     sorted_feature_columns)
      if mode == model_fn.ModeKeys.TRAIN and example_id_column_name:
        example_ids = features[example_id_column_name]
        training_state_cache = _CacheTrainingStatesUsingHashTable(
            example_ids, head.logits_dimension)

    # Create Ensemble resources.
    tree_ensemble = boosted_trees_ops.TreeEnsemble(name=name)
    # Variable that determines whether bias centering is needed.
    center_bias_var = variable_scope.variable(
        initial_value=center_bias, name='center_bias_needed', trainable=False)
    # Create logits.
    if mode != model_fn.ModeKeys.TRAIN:
      logits = boosted_trees_ops.predict(
          # For non-TRAIN mode, ensemble doesn't change after initialization,
          # so no local copy is needed; using tree_ensemble directly.
          tree_ensemble_handle=tree_ensemble.resource_handle,
          bucketized_features=input_feature_list,
          logits_dimension=head.logits_dimension)
    else:
      if is_single_machine:
        local_tree_ensemble = tree_ensemble
        ensemble_reload = control_flow_ops.no_op()
      else:
        # Have a local copy of ensemble for the distributed setting.
        with ops.device(worker_device):
          local_tree_ensemble = boosted_trees_ops.TreeEnsemble(
              name=name + '_local', is_local=True)
        # TODO(soroush): Do partial updates if this becomes a bottleneck.
        ensemble_reload = local_tree_ensemble.deserialize(
            *tree_ensemble.serialize())

      if training_state_cache:
        cached_tree_ids, cached_node_ids, cached_logits = (
            training_state_cache.lookup())
      else:
        # Always start from the beginning when no cache is set up.
        batch_size = array_ops.shape(labels)[0]
        cached_tree_ids, cached_node_ids, cached_logits = (
            array_ops.zeros([batch_size], dtype=dtypes.int32),
            _DUMMY_NODE_ID * array_ops.ones([batch_size], dtype=dtypes.int32),
            array_ops.zeros(
                [batch_size, head.logits_dimension], dtype=dtypes.float32))

      with ops.control_dependencies([ensemble_reload]):
        (stamp_token, num_trees, num_finalized_trees, num_attempted_layers,
         last_layer_nodes_range) = local_tree_ensemble.get_states()
        summary.scalar('ensemble/num_trees', num_trees)
        summary.scalar('ensemble/num_finalized_trees', num_finalized_trees)
        summary.scalar('ensemble/num_attempted_layers', num_attempted_layers)

        partial_logits, tree_ids, node_ids = boosted_trees_ops.training_predict(
            tree_ensemble_handle=local_tree_ensemble.resource_handle,
            cached_tree_ids=cached_tree_ids,
            cached_node_ids=cached_node_ids,
            bucketized_features=input_feature_list,
            logits_dimension=head.logits_dimension)

      logits = cached_logits + partial_logits

    # Create training graph.
    def _train_op_fn(loss):
      """Run one training iteration."""
      if training_state_cache:
        # Cache logits only after center_bias is complete, if it's in progress.
        train_op.append(
            control_flow_ops.cond(
                center_bias_var, control_flow_ops.no_op,
                lambda: training_state_cache.insert(tree_ids, node_ids, logits))
        )

      if closed_form_grad_and_hess_fn:
        gradients, hessians = closed_form_grad_and_hess_fn(logits, labels)
      else:
        gradients = gradients_impl.gradients(loss, logits, name='Gradients')[0]
        hessians = gradients_impl.gradients(
            gradients, logits, name='Hessians')[0]

      stats_summaries_list = []
      for i, feature_ids in enumerate(feature_ids_list):
        num_buckets = bucket_size_list[i]
        summaries = [
            array_ops.squeeze(
                boosted_trees_ops.make_stats_summary(
                    node_ids=node_ids,
                    gradients=gradients,
                    hessians=hessians,
                    bucketized_features_list=[input_feature_list[f]],
                    max_splits=max_splits,
                    num_buckets=num_buckets),
                axis=0) for f in feature_ids
        ]
        stats_summaries_list.append(summaries)

      # ========= Helper methods for both in and not in memory. ==============
      def grow_tree_from_stats_summaries(stats_summaries_list,
                                         feature_ids_list):
        """Updates ensemble based on the best gains from stats summaries."""
        node_ids_per_feature = []
        gains_list = []
        thresholds_list = []
        left_node_contribs_list = []
        right_node_contribs_list = []
        all_feature_ids = []

        assert len(stats_summaries_list) == len(feature_ids_list)

        for i, feature_ids in enumerate(feature_ids_list):
          (numeric_node_ids_per_feature, numeric_gains_list,
           numeric_thresholds_list, numeric_left_node_contribs_list,
           numeric_right_node_contribs_list) = (
               boosted_trees_ops.calculate_best_gains_per_feature(
                   node_id_range=last_layer_nodes_range,
                   stats_summary_list=stats_summaries_list[i],
                   l1=tree_hparams.l1,
                   l2=tree_hparams.l2,
                   tree_complexity=tree_hparams.tree_complexity,
                   min_node_weight=tree_hparams.min_node_weight,
                   max_splits=max_splits))

          all_feature_ids += feature_ids
          node_ids_per_feature += numeric_node_ids_per_feature
          gains_list += numeric_gains_list
          thresholds_list += numeric_thresholds_list
          left_node_contribs_list += numeric_left_node_contribs_list
          right_node_contribs_list += numeric_right_node_contribs_list

        grow_op = boosted_trees_ops.update_ensemble(
            # Confirm if local_tree_ensemble or tree_ensemble should be used.
            tree_ensemble.resource_handle,
            feature_ids=all_feature_ids,
            node_ids=node_ids_per_feature,
            gains=gains_list,
            thresholds=thresholds_list,
            left_node_contribs=left_node_contribs_list,
            right_node_contribs=right_node_contribs_list,
            learning_rate=tree_hparams.learning_rate,
            max_depth=tree_hparams.max_depth,
            pruning_mode=boosted_trees_ops.PruningMode.NO_PRUNING)
        return grow_op

      def _center_bias_fn(mean_gradients, mean_hessians):
        """Updates the ensembles and cache (if needed) with logits prior."""
        continue_centering = boosted_trees_ops.center_bias(
            tree_ensemble.resource_handle,
            mean_gradients=mean_gradients,
            mean_hessians=mean_hessians,
            l1=tree_hparams.l1,
            l2=tree_hparams.l2
        )
        return center_bias_var.assign(continue_centering)

      # ========= End of helper methods. ==============

      if train_in_memory and is_single_machine:
        train_op.append(distribute_lib.increment_var(global_step))

        mean_gradients = array_ops.expand_dims(
            math_ops.reduce_mean(gradients, 0), 0)
        mean_hessians = array_ops.expand_dims(
            math_ops.reduce_mean(hessians, 0), 0)

        train_op.append(
            control_flow_ops.cond(
                center_bias_var,
                lambda: _center_bias_fn(mean_gradients, mean_hessians),
                functools.partial(grow_tree_from_stats_summaries,
                                  stats_summaries_list, feature_ids_list)))
      else:

        def center_bias_not_in_mem():
          """Accumulates the data and updates the logits bias, when ready."""
          bias_dependencies = []

          bias_accumulator = data_flow_ops.ConditionalAccumulator(
              dtype=dtypes.float32,
              # The stats consist of gradient and hessian means only.
              # TODO(nponomareva): this will change for the multiclass case.
              shape=[2, 1],
              shared_name='bias_accumulator')

          grads_and_hess = array_ops.stack([gradients, hessians], axis=0)
          grads_and_hess = math_ops.reduce_mean(grads_and_hess, axis=1)

          apply_grad = bias_accumulator.apply_grad(grads_and_hess, stamp_token)
          bias_dependencies.append(apply_grad)

          def center_bias_from_accumulator():
            accumulated = array_ops.unstack(
                bias_accumulator.take_grad(1), axis=0)
            return _center_bias_fn(
                array_ops.expand_dims(accumulated[0], 0),
                array_ops.expand_dims(accumulated[1], 0))

          with ops.control_dependencies(bias_dependencies):
            if config.is_chief:
              center_bias_op = control_flow_ops.cond(
                  math_ops.greater_equal(bias_accumulator.num_accumulated(),
                                         n_batches_per_layer),
                  center_bias_from_accumulator,
                  control_flow_ops.no_op,
                  name='wait_until_n_batches_for_bias_accumulated')

              return center_bias_op

        def grow_not_in_mem():
          """Accumulates the data and grows a layer when ready."""

          accumulators = []
          dependencies = []
          for i, feature_ids in enumerate(feature_ids_list):
            stats_summaries = stats_summaries_list[i]
            accumulator = data_flow_ops.ConditionalAccumulator(
                dtype=dtypes.float32,
                # The stats consist of grads and hessians (the last dimension).
                shape=[len(feature_ids), max_splits, bucket_size_list[i], 2],
                shared_name='numeric_stats_summary_accumulator_' + str(i))
            accumulators.append(accumulator)

            apply_grad = accumulator.apply_grad(
                array_ops.stack(stats_summaries, axis=0), stamp_token)
            dependencies.append(apply_grad)

          def grow_tree_from_accumulated_summaries_fn():
            """Updates tree with the best layer from accumulated summaries."""
            # Take out the accumulated summaries from the accumulator and grow.
            stats_summaries_list = []

            stats_summaries_list = [
                array_ops.unstack(accumulator.take_grad(1), axis=0)
                for accumulator in accumulators
            ]

            grow_op = grow_tree_from_stats_summaries(stats_summaries_list,
                                                     feature_ids_list)
            return grow_op

          with ops.control_dependencies(dependencies):
            if config.is_chief:
              min_accumulated = math_ops.reduce_min(
                  array_ops.stack(
                      [acc.num_accumulated() for acc in accumulators]))

              grow_model = control_flow_ops.cond(
                  math_ops.greater_equal(min_accumulated, n_batches_per_layer),
                  grow_tree_from_accumulated_summaries_fn,
                  control_flow_ops.no_op,
                  name='wait_until_n_batches_accumulated')

              return grow_model

        update_model = control_flow_ops.cond(
            center_bias_var, center_bias_not_in_mem, grow_not_in_mem)
        train_op.append(update_model)
        with ops.control_dependencies([update_model]):
          increment_global = distribute_lib.increment_var(global_step)
          train_op.append(increment_global)

      return control_flow_ops.group(train_op, name='train_op')

  estimator_spec = head.create_estimator_spec(
      features=features,
      mode=mode,
      labels=labels,
      train_op_fn=_train_op_fn,
      logits=logits)
  if mode == model_fn.ModeKeys.TRAIN:
    # Add an early stop hook.
    estimator_spec = estimator_spec._replace(
        training_hooks=estimator_spec.training_hooks +
        (_StopAtAttemptsHook(num_finalized_trees, num_attempted_layers,
                             tree_hparams.n_trees, tree_hparams.max_depth),))
  return estimator_spec
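The canned estimator built on a model_fn of this shape is tf.estimator.BoostedTreesClassifier; a minimal sketch with bucketized inputs follows (the column, toy data, and hyperparameters are made up).

import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()

# Boosted trees consume bucketized (or categorical/indicator) columns.
age = tf.feature_column.bucketized_column(
    tf.feature_column.numeric_column('age'),
    boundaries=[20.0, 30.0, 40.0, 50.0])

est = tf.estimator.BoostedTreesClassifier(
    feature_columns=[age],
    n_batches_per_layer=1,  # the whole input is a single batch per layer here
    n_trees=10,
    max_depth=3)

def input_fn():
    features = {'age': [[18.0], [25.0], [37.0], [52.0]]}
    labels = [[0], [1], [1], [0]]
    return tf.data.Dataset.from_tensor_slices(
        (features, labels)).batch(4).repeat()

est.train(input_fn, max_steps=20)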
Example #46
0
def gan_train_ops(
        model,
        loss,
        generator_optimizer,
        discriminator_optimizer,
        check_for_unused_update_ops=True,
        # Optional args to pass directly to the `create_train_op`.
        **kwargs):
    """Returns GAN train ops.

  The highest-level call in TFGAN. It is composed of functions that can also
  be called, should a user require more control over some part of the GAN
  training process.

  Args:
    model: A GANModel.
    loss: A GANLoss.
    generator_optimizer: The optimizer for generator updates.
    discriminator_optimizer: The optimizer for the discriminator updates.
    check_for_unused_update_ops: If `True`, throws an exception if there are
      update ops outside of the generator or discriminator scopes.
    **kwargs: Keyword args to pass directly to
      `training.create_train_op` for both the generator and
      discriminator train op.

  Returns:
    A GANTrainOps tuple of (generator_train_op, discriminator_train_op) that can
    be used to train a generator/discriminator pair.
  """
    if isinstance(model, namedtuples.CycleGANModel):
        # Get and store all arguments other than model and loss from locals.
        # The contents of locals() should not be modified (and modifying them may
        # not affect the underlying values), so make a copy instead.
        # https://docs.python.org/2/library/functions.html#locals.
        saved_params = dict(locals())
        saved_params.pop('model', None)
        saved_params.pop('loss', None)
        kwargs = saved_params.pop('kwargs', {})
        saved_params.update(kwargs)
        with ops.name_scope('cyclegan_x2y_train'):
            train_ops_x2y = gan_train_ops(model.model_x2y, loss.loss_x2y,
                                          **saved_params)
        with ops.name_scope('cyclegan_y2x_train'):
            train_ops_y2x = gan_train_ops(model.model_y2x, loss.loss_y2x,
                                          **saved_params)
        return namedtuples.GANTrainOps(
            (train_ops_x2y.generator_train_op,
             train_ops_y2x.generator_train_op),
            (train_ops_x2y.discriminator_train_op,
             train_ops_y2x.discriminator_train_op),
            training_util.get_or_create_global_step().assign_add(1))

    # Create global step increment op.
    global_step = training_util.get_or_create_global_step()
    global_step_inc = global_step.assign_add(1)

    # Get generator and discriminator update ops. We split them so that update
    # ops aren't accidentally run multiple times. For now, throw an error if
    # there are update ops that aren't associated with either the generator or
    # the discriminator. Might modify the `kwargs` dictionary.
    gen_update_ops, dis_update_ops = _get_update_ops(
        kwargs, model.generator_scope.name, model.discriminator_scope.name,
        check_for_unused_update_ops)

    generator_global_step = None
    if isinstance(generator_optimizer,
                  sync_replicas_optimizer.SyncReplicasOptimizer):
        # TODO(joelshor): Figure out a way to get this work without including the
        # dummy global step in the checkpoint.
        # WARNING: Making this variable a local variable causes sync replicas to
        # hang forever.
        generator_global_step = variable_scope.get_variable(
            'dummy_global_step_generator',
            shape=[],
            dtype=global_step.dtype.base_dtype,
            initializer=init_ops.zeros_initializer(),
            trainable=False,
            collections=[ops.GraphKeys.GLOBAL_VARIABLES])
        gen_update_ops += [generator_global_step.assign(global_step)]
    with ops.name_scope('generator_train'):
        gen_train_op = training.create_train_op(
            total_loss=loss.generator_loss,
            optimizer=generator_optimizer,
            variables_to_train=model.generator_variables,
            global_step=generator_global_step,
            update_ops=gen_update_ops,
            **kwargs)

    discriminator_global_step = None
    if isinstance(discriminator_optimizer,
                  sync_replicas_optimizer.SyncReplicasOptimizer):
        # See comment above `generator_global_step`.
        discriminator_global_step = variable_scope.get_variable(
            'dummy_global_step_discriminator',
            shape=[],
            dtype=global_step.dtype.base_dtype,
            initializer=init_ops.zeros_initializer(),
            trainable=False,
            collections=[ops.GraphKeys.GLOBAL_VARIABLES])
        dis_update_ops += [discriminator_global_step.assign(global_step)]
    with ops.name_scope('discriminator_train'):
        disc_train_op = training.create_train_op(
            total_loss=loss.discriminator_loss,
            optimizer=discriminator_optimizer,
            variables_to_train=model.discriminator_variables,
            global_step=discriminator_global_step,
            update_ops=dis_update_ops,
            **kwargs)

    return namedtuples.GANTrainOps(gen_train_op, disc_train_op,
                                   global_step_inc)
Example #47
0
 def _global_step(_):
   with variable_scope.variable_scope('', use_resource=True):
     return training_util.get_or_create_global_step()
Example #48
0
    def test_inv_update_thunks(self):
        """Ensures inverse update ops run once per global_step."""
        with self._graph.as_default(), self.test_session() as sess:
            fisher_estimator = estimator.FisherEstimatorRoundRobin(
                variables=[self.weights],
                layer_collection=self.layer_collection,
                damping=0.2,
                cov_ema_decay=0.0)

            # Construct op that updates one inverse per global step.
            global_step = training_util.get_or_create_global_step()
            (cov_variable_thunks, _, inv_variable_thunks, inv_update_op_thunks
             ) = fisher_estimator.create_ops_and_vars_thunks()
            for thunk in cov_variable_thunks:
                thunk()
            for thunk in inv_variable_thunks:
                thunk()
            inv_matrices = [
                matrix
                for fisher_factor in self.layer_collection.get_factors() for
                matrix in fisher_factor._matpower_by_exp_and_damping.values()
            ]
            inv_update_op = control_flow_ops.case([
                (math_ops.equal(global_step, i), thunk)
                for i, thunk in enumerate(inv_update_op_thunks)
            ])
            increment_global_step = global_step.assign_add(1)

            sess.run(variables.global_variables_initializer())
            initial_inv_values = sess.run(inv_matrices)

            # Ensure there's one update per inverse matrix. This is true as long as
            # there's no fan-in/fan-out or parameter re-use.
            self.assertEqual(len(inv_matrices), len(inv_update_op_thunks))

            # Test is a no-op if there is only 1 inverse matrix.
            assert len(inv_matrices) > 1

            # Assign each covariance matrix a value other than the identity. This
            # ensures that the inverse matrices are updated to something different as
            # well.
            cov_matrices = [
                fisher_factor.get_cov()
                for fisher_factor in self.layer_collection.get_factors()
            ]
            sess.run([
                cov_matrix.assign(2 * linalg_ops.eye(int(cov_matrix.shape[0])))
                for cov_matrix in cov_matrices
            ])

            for i in range(len(inv_matrices)):
                # Compare new and old inverse values
                new_inv_values = sess.run(inv_matrices)
                is_inv_equal = [
                    np.allclose(initial_inv_value, new_inv_value)
                    for (initial_inv_value, new_inv_value
                         ) in zip(initial_inv_values, new_inv_values)
                ]
                num_inv_equal = sum(is_inv_equal)

                # Ensure exactly one inverse matrix changes per step.
                self.assertEqual(num_inv_equal, len(inv_matrices) - i)

                # Run all inverse update ops.
                sess.run(inv_update_op)
                sess.run(increment_global_step)
Example #49
0
 def begin(self):
     if self._replace_summary_op:
         # This can still remain None if there are no summaries.
         self._summary_op = summary.merge_all()
     self._global_step = training_util.get_or_create_global_step()
Example #50
0
 def make_opt():
     gstep = training_util.get_or_create_global_step()
     lr = learning_rate_decay.exponential_decay(1.0, gstep, 10, 0.9)
     return training.GradientDescentOptimizer(lr)
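A runnable sketch of the decayed-learning-rate pattern in make_opt above: passing the global step to minimize() increments it on every update, which in turn drives the exponential decay. The toy quadratic loss and step counts are illustrative.

import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()

with tf.Graph().as_default():
    gstep = tf.train.get_or_create_global_step()
    # Learning rate decays by a factor of 0.9 every 10 global steps.
    lr = tf.train.exponential_decay(1.0, gstep, decay_steps=10, decay_rate=0.9)
    opt = tf.train.GradientDescentOptimizer(lr)

    w = tf.get_variable('w', initializer=5.0)
    train_op = opt.minimize(tf.square(w), global_step=gstep)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        for _ in range(30):
            sess.run(train_op)
        print(sess.run([gstep, lr]))  # step 30, lr = 1.0 * 0.9 ** (30 / 10)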
Example #51
0
def _run_epoch(sess,
               model,
               args,
               data,
               index=0,
               tb_summaries=None,
               id_to_word=None,
               train_op=None,
               verbose=False):
    """Runs the model on the given data and displays metrics to monitor
    progress.
    """
    epoch_start_time = time.time()
    # total cost and number of words evaluated in this epoch
    costs, total_words = 0.0, 0.0
    # epoch size is number of batches in each epoch
    epoch_size = (len(data[index]) - 1) // model.config['batch_size']
    state = sess.run(model.initial_state)

    # iterate through batches
    for step, (x, y) in enumerate(
            data_reader.batch_iterator(data[index],
                                       model.config['batch_size'])):
        # return these parameters after running TF session
        fetches = {
            'cost': model.cost[index],
            'final_state': model.final_state,
            'seq_len': model.seq_len
        }
        # only train model has optimizer operation
        if train_op is not None:
            fetches['train_op'] = train_op[index]

        # create dict to feed input, targets, and rnn into TF session
        feed_dict = utils.create_feed_dict(model, args, x, y, state)
        # run all parameters in fetches dict
        vals = sess.run(fetches, feed_dict)

        costs += vals['cost']
        # number of words evaluated
        total_words += np.sum(vals['seq_len'])
        # use perplexity to evaluate language models
        perplexity = np.exp(costs / total_words)

        if verbose and step % (epoch_size // 2) == 1:
            # display perplexity and top word predictions for sequence
            _display_epoch_metrics(step, epoch_size, perplexity, total_words,
                                   epoch_start_time, args, model, sess, index,
                                   feed_dict, vals, id_to_word, y)

    # generate sample text while training to monitor progress
    if args.display_text == 'True' and model.name == 'Train':
        generate.generate_text(sess, model, id_to_word, train_ind=index)

    # write TensorBoard summaries for Train/Valid
    if args.save_path != '' and model.name != 'Test':
        summary = sess.run(tb_summaries.summary_op,
                           {tb_summaries.ppl_summary: perplexity})
        model.file_writer.add_summary(summary,
                                      get_or_create_global_step().eval())

    return perplexity
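For reference, the perplexity bookkeeping above reduces to exp(total cost / total words) over everything seen so far in the epoch. A tiny illustrative calculation with made-up per-batch numbers:

import numpy as np

costs, total_words = 0.0, 0.0
# Pretend (summed cross-entropy cost, word count) pairs for three batches.
for batch_cost, batch_words in [(42.0, 20), (39.5, 20), (37.8, 20)]:
    costs += batch_cost
    total_words += batch_words
    perplexity = np.exp(costs / total_words)
print(round(perplexity, 2))  # exp(119.3 / 60) ~= 7.3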
Example #52
0
def create_train_op(total_loss,
                    optimizer,
                    global_step=_USE_GLOBAL_STEP,
                    update_ops=None,
                    variables_to_train=None,
                    transform_grads_fn=None,
                    summarize_gradients=False,
                    gate_gradients=tf_optimizer.Optimizer.GATE_OP,
                    aggregation_method=None,
                    colocate_gradients_with_ops=False,
                    check_numerics=True):
    """Creates an `Operation` that evaluates the gradients and returns the loss.

  Args:
    total_loss: A `Tensor` representing the total loss.
    optimizer: A tf.Optimizer to use for computing the gradients.
    global_step: A `Tensor` representing the global step variable. If left as
      `_USE_GLOBAL_STEP`, then `training_util.get_or_create_global_step()` is used.
    update_ops: An optional list of updates to execute. If `update_ops` is
      `None`, then the update ops are set to the contents of the
      `tf.GraphKeys.UPDATE_OPS` collection. If `update_ops` is not `None`, but
      it doesn't contain all of the update ops in `tf.GraphKeys.UPDATE_OPS`,
      a warning will be displayed.
    variables_to_train: an optional list of variables to train. If None, it will
      default to all tf.trainable_variables().
    transform_grads_fn: A function which takes a single argument, a list of
      gradient to variable pairs (tuples), performs any requested gradient
      updates, such as gradient clipping or multipliers, and returns the updated
      list.
    summarize_gradients: Whether or not to add summaries for each gradient.
    gate_gradients: How to gate the computation of gradients. See tf.Optimizer.
    aggregation_method: Specifies the method used to combine gradient terms.
      Valid values are defined in the class `AggregationMethod`.
    colocate_gradients_with_ops: Whether or not to try colocating the gradients
      with the ops that generated them.
    check_numerics: Whether or not we apply check_numerics.

  Returns:
    A `Tensor` that when evaluated, computes the gradients and returns the total
      loss value.
  """
    if global_step is _USE_GLOBAL_STEP:
        global_step = training_util.get_or_create_global_step()

    # Update ops use GraphKeys.UPDATE_OPS collection if update_ops is None.
    global_update_ops = set(ops.get_collection(ops.GraphKeys.UPDATE_OPS))
    if update_ops is None:
        update_ops = global_update_ops
    else:
        update_ops = set(update_ops)
    if not global_update_ops.issubset(update_ops):
        logging.warning(
            'update_ops in create_train_op does not contain all the '
            'update_ops in GraphKeys.UPDATE_OPS')

    # Make sure update_ops are computed before total_loss.
    if update_ops:
        with ops.control_dependencies(update_ops):
            barrier = control_flow_ops.no_op(name='update_barrier')
        total_loss = control_flow_ops.with_dependencies([barrier], total_loss)

    if variables_to_train is None:
        # Default to tf.trainable_variables()
        variables_to_train = tf_variables.trainable_variables()
    else:
        # Make sure that variables_to_train are in tf.trainable_variables()
        for v in variables_to_train:
            assert v.trainable or v in tf_variables.trainable_variables()

    assert variables_to_train

    # Create the gradients. Note that apply_gradients adds the gradient
    # computation to the current graph.
    grads = optimizer.compute_gradients(
        total_loss,
        variables_to_train,
        gate_gradients=gate_gradients,
        aggregation_method=aggregation_method,
        colocate_gradients_with_ops=colocate_gradients_with_ops)

    if transform_grads_fn:
        grads = transform_grads_fn(grads)

    # Summarize gradients.
    if summarize_gradients:
        with ops.name_scope('summarize_grads'):
            add_gradients_summaries(grads)

    # Create gradient updates.
    grad_updates = optimizer.apply_gradients(grads, global_step=global_step)

    with ops.name_scope('train_op'):
        # Make sure total_loss is valid.
        if check_numerics:
            total_loss = array_ops.check_numerics(total_loss,
                                                  'LossTensor is inf or nan')

        # Ensure the train_tensor computes grad_updates.
        train_op = control_flow_ops.with_dependencies([grad_updates],
                                                      total_loss)

    # Add the operation used for training to the 'train_op' collection
    train_ops = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
    if train_op not in train_ops:
        train_ops.append(train_op)

    return train_op
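A minimal driver for a create_train_op with this signature is sketched below; it assumes the create_train_op defined above (and its module-level imports) is in scope, and the toy regression problem is illustrative. Because the returned op is the loss with control dependencies on the gradient updates, sess.run(train_op) both applies an update and returns the current loss value.

import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()

with tf.Graph().as_default():
    x = tf.constant([[1.0], [2.0], [3.0]])
    y = tf.constant([[2.0], [4.0], [6.0]])
    w = tf.get_variable('w', initializer=0.0)
    total_loss = tf.reduce_mean(tf.square(y - w * x))
    optimizer = tf.train.GradientDescentOptimizer(0.05)

    # `create_train_op` is assumed to be the function defined above. With the
    # default global_step it also increments the variable returned by
    # get_or_create_global_step() on every update.
    train_op = create_train_op(total_loss, optimizer)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        for _ in range(200):
            loss_value = sess.run(train_op)
        print(loss_value, sess.run(tf.train.get_global_step()))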