Example #1
    def testTrainAllVarsHasLowerLossThanTrainSubsetOfVars(self):
        logdir1 = os.path.join(tempfile.mkdtemp(prefix=self.get_temp_dir()),
                               'tmp_logs1')

        # First, train only the weights of the model.
        with ops.Graph().as_default():
            random_seed.set_random_seed(0)
            total_loss = self.ModelLoss()
            optimizer = gradient_descent.GradientDescentOptimizer(
                learning_rate=1.0)
            weights = variables_lib2.get_variables_by_name('weights')

            train_op = learning.create_train_op(total_loss,
                                                optimizer,
                                                variables_to_train=weights)

            loss = learning.train(train_op,
                                  logdir1,
                                  number_of_steps=200,
                                  log_every_n_steps=10)
            self.assertGreater(loss, .015)
            self.assertLess(loss, .05)

        # Next, train the biases of the model.
        with ops.Graph().as_default():
            random_seed.set_random_seed(1)
            total_loss = self.ModelLoss()
            optimizer = gradient_descent.GradientDescentOptimizer(
                learning_rate=1.0)
            biases = variables_lib2.get_variables_by_name('biases')

            train_op = learning.create_train_op(total_loss,
                                                optimizer,
                                                variables_to_train=biases)

            loss = learning.train(train_op,
                                  logdir1,
                                  number_of_steps=300,
                                  log_every_n_steps=10)
            self.assertGreater(loss, .015)
            self.assertLess(loss, .05)

        # Finally, train both weights and biases to get a lower loss.
        with ops.Graph().as_default():
            random_seed.set_random_seed(2)
            total_loss = self.ModelLoss()
            optimizer = gradient_descent.GradientDescentOptimizer(
                learning_rate=1.0)

            train_op = learning.create_train_op(total_loss, optimizer)
            loss = learning.train(train_op,
                                  logdir1,
                                  number_of_steps=400,
                                  log_every_n_steps=10)

            self.assertIsNotNone(loss)
            self.assertLess(loss, .015)
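
A minimal sketch of the same variables_to_train pattern through the public slim API; the tiny model, the log directory, and the step count are illustrative assumptions, not code from the test above.

import numpy as np
import tensorflow as tf

slim = tf.contrib.slim

inputs = tf.constant(np.random.rand(16, 4), dtype=tf.float32)
labels = tf.constant(np.random.randint(0, 2, (16, 1)), dtype=tf.float32)

# A one-layer logistic model; slim names its variables '.../weights' and '.../biases'.
predictions = slim.fully_connected(inputs, 1, activation_fn=tf.sigmoid)
slim.losses.log_loss(predictions, labels)
total_loss = slim.losses.get_total_loss()

optimizer = tf.train.GradientDescentOptimizer(learning_rate=1.0)

# Only the weights are updated; the biases keep their initial values.
weights = slim.get_variables_by_name('weights')
train_op = slim.learning.create_train_op(
    total_loss, optimizer, variables_to_train=weights)

final_loss = slim.learning.train(train_op, '/tmp/slim_logs', number_of_steps=200)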
Example #2
    def testTrainWithInitFromCheckpoint(self):
        logdir1 = os.path.join(tempfile.mkdtemp(prefix=self.get_temp_dir()),
                               'tmp_logs1')
        logdir2 = os.path.join(tempfile.mkdtemp(prefix=self.get_temp_dir()),
                               'tmp_logs2')

        # First, train the model one step (make sure the error is high).
        with ops.Graph().as_default():
            random_seed.set_random_seed(0)
            train_op = self.create_train_op()
            loss = learning.train(train_op, logdir1, number_of_steps=1)
            self.assertGreater(loss, .5)

        # Next, train the model to convergence.
        with ops.Graph().as_default():
            random_seed.set_random_seed(1)
            train_op = self.create_train_op()
            loss = learning.train(train_op,
                                  logdir1,
                                  number_of_steps=300,
                                  log_every_n_steps=10)
            self.assertIsNotNone(loss)
            self.assertLess(loss, .02)

        # Finally, advance the model a single step and validate that the loss is
        # still low.
        with ops.Graph().as_default():
            random_seed.set_random_seed(2)
            train_op = self.create_train_op()

            model_variables = variables_lib.global_variables()
            model_path = os.path.join(logdir1, 'model.ckpt-300')

            init_op = variables_lib.global_variables_initializer()
            op, init_feed_dict = variables_lib2.assign_from_checkpoint(
                model_path, model_variables)

            def InitAssignFn(sess):
                sess.run(op, init_feed_dict)

            loss = learning.train(train_op,
                                  logdir2,
                                  number_of_steps=1,
                                  init_op=init_op,
                                  init_fn=InitAssignFn)

            self.assertIsNotNone(loss)
            self.assertLess(loss, .02)
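
The test builds its init_fn by hand from assign_from_checkpoint. Below is a minimal sketch, assuming the public slim API and an illustrative checkpoint path, of the more compact slim.assign_from_checkpoint_fn helper, which returns an equivalent callable.

import numpy as np
import tensorflow as tf

slim = tf.contrib.slim

inputs = tf.constant(np.random.rand(16, 4), dtype=tf.float32)
labels = tf.constant(np.random.randint(0, 2, (16, 1)), dtype=tf.float32)
predictions = slim.fully_connected(inputs, 1, activation_fn=tf.sigmoid)
slim.losses.log_loss(predictions, labels)
train_op = slim.learning.create_train_op(
    slim.losses.get_total_loss(),
    tf.train.GradientDescentOptimizer(learning_rate=1.0))

# assign_from_checkpoint_fn returns a callable(sess) that restores the listed
# variables from the checkpoint; learning.train runs it once at startup.
init_fn = slim.assign_from_checkpoint_fn(
    '/tmp/slim_logs1/model.ckpt-300', slim.get_variables_to_restore())

loss = slim.learning.train(
    train_op, '/tmp/slim_logs2', number_of_steps=1, init_fn=init_fn)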
Example #3
    def testTrainWithLocalVariable(self):
        logdir = os.path.join(tempfile.mkdtemp(prefix=self.get_temp_dir()),
                              'tmp_logs')
        with ops.Graph().as_default():
            random_seed.set_random_seed(0)
            tf_inputs = constant_op.constant(self._inputs,
                                             dtype=dtypes.float32)
            tf_labels = constant_op.constant(self._labels,
                                             dtype=dtypes.float32)

            local_multiplier = variables_lib2.local_variable(1.0)

            tf_predictions = LogisticClassifier(tf_inputs) * local_multiplier
            loss_ops.log_loss(tf_predictions, tf_labels)
            total_loss = loss_ops.get_total_loss()

            optimizer = gradient_descent.GradientDescentOptimizer(
                learning_rate=1.0)

            train_op = learning.create_train_op(total_loss, optimizer)

            loss = learning.train(train_op,
                                  logdir,
                                  number_of_steps=300,
                                  log_every_n_steps=10)
            self.assertIsNotNone(loss)
            self.assertLess(loss, .015)
Example #4
    def testResumeTrainAchievesRoughlyTheSameLoss(self):
        logdir = os.path.join(tempfile.mkdtemp(prefix=self.get_temp_dir()),
                              'tmp_logs')
        number_of_steps = [300, 301, 305]

        for i in range(len(number_of_steps)):
            with ops.Graph().as_default():
                random_seed.set_random_seed(i)
                tf_inputs = constant_op.constant(self._inputs,
                                                 dtype=dtypes.float32)
                tf_labels = constant_op.constant(self._labels,
                                                 dtype=dtypes.float32)

                tf_predictions = LogisticClassifier(tf_inputs)
                loss_ops.log_loss(tf_predictions, tf_labels)
                total_loss = loss_ops.get_total_loss()

                optimizer = gradient_descent.GradientDescentOptimizer(
                    learning_rate=1.0)

                train_op = learning.create_train_op(total_loss, optimizer)

                loss = learning.train(train_op,
                                      logdir,
                                      number_of_steps=number_of_steps[i],
                                      log_every_n_steps=10)
                self.assertIsNotNone(loss)
                self.assertLess(loss, .015)
Example #5
    def testTrainWithSessionConfig(self):
        with ops.Graph().as_default():
            random_seed.set_random_seed(0)
            tf_inputs = constant_op.constant(self._inputs,
                                             dtype=dtypes.float32)
            tf_labels = constant_op.constant(self._labels,
                                             dtype=dtypes.float32)

            tf_predictions = LogisticClassifier(tf_inputs)
            loss_ops.log_loss(tf_predictions, tf_labels)
            total_loss = loss_ops.get_total_loss()

            optimizer = gradient_descent.GradientDescentOptimizer(
                learning_rate=1.0)

            train_op = learning.create_train_op(total_loss, optimizer)

            session_config = config_pb2.ConfigProto(allow_soft_placement=True)
            loss = learning.train(train_op,
                                  None,
                                  number_of_steps=300,
                                  log_every_n_steps=10,
                                  session_config=session_config)
        self.assertIsNotNone(loss)
        self.assertLess(loss, .015)
Example #6
    def testTrainWithTrace(self):
        logdir = os.path.join(tempfile.mkdtemp(prefix=self.get_temp_dir()),
                              'tmp_logs')
        with ops.Graph().as_default():
            random_seed.set_random_seed(0)
            tf_inputs = constant_op.constant(self._inputs,
                                             dtype=dtypes.float32)
            tf_labels = constant_op.constant(self._labels,
                                             dtype=dtypes.float32)

            tf_predictions = LogisticClassifier(tf_inputs)
            loss_ops.log_loss(tf_predictions, tf_labels)
            total_loss = loss_ops.get_total_loss()
            summary.scalar('total_loss', total_loss)

            optimizer = gradient_descent.GradientDescentOptimizer(
                learning_rate=1.0)

            train_op = learning.create_train_op(total_loss, optimizer)

            loss = learning.train(train_op,
                                  logdir,
                                  number_of_steps=300,
                                  log_every_n_steps=10,
                                  trace_every_n_steps=100)
        self.assertIsNotNone(loss)
        for trace_step in [1, 101, 201]:
            trace_filename = 'tf_trace-%d.json' % trace_step
            self.assertTrue(
                os.path.isfile(os.path.join(logdir, trace_filename)))
Example #7
    def testTrainWithNoInitAssignCanAchieveZeroLoss(self):
        logdir = os.path.join(tempfile.mkdtemp(prefix=self.get_temp_dir()),
                              'tmp_logs')
        g = ops.Graph()
        with g.as_default():
            random_seed.set_random_seed(0)
            tf_inputs = constant_op.constant(self._inputs,
                                             dtype=dtypes.float32)
            tf_labels = constant_op.constant(self._labels,
                                             dtype=dtypes.float32)

            tf_predictions = BatchNormClassifier(tf_inputs)
            loss_ops.log_loss(tf_predictions, tf_labels)
            total_loss = loss_ops.get_total_loss()

            optimizer = gradient_descent.GradientDescentOptimizer(
                learning_rate=1.0)

            train_op = learning.create_train_op(total_loss, optimizer)

            loss = learning.train(train_op,
                                  logdir,
                                  number_of_steps=300,
                                  log_every_n_steps=10)
            self.assertLess(loss, .1)
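
The batch-norm test above never wires up the moving-average updates explicitly because create_train_op, when update_ops is not passed, uses the ops in the tf.GraphKeys.UPDATE_OPS collection. The following sketch, assuming the public slim API and a toy model, passes the same collection explicitly to make that dependency visible.

import numpy as np
import tensorflow as tf

slim = tf.contrib.slim

inputs = tf.constant(np.random.rand(16, 4), dtype=tf.float32)
labels = tf.constant(np.random.randint(0, 2, (16, 1)), dtype=tf.float32)

# A tiny batch-norm model; its moving-average updates land in UPDATE_OPS.
net = slim.fully_connected(inputs, 8, normalizer_fn=slim.batch_norm)
predictions = slim.fully_connected(net, 1, activation_fn=tf.sigmoid)
slim.losses.log_loss(predictions, labels)
total_loss = slim.losses.get_total_loss()
optimizer = tf.train.GradientDescentOptimizer(learning_rate=1.0)

# Passing UPDATE_OPS explicitly is equivalent to the default: when update_ops
# is None, create_train_op reads the contents of this same collection.
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
train_op = slim.learning.create_train_op(
    total_loss, optimizer, update_ops=update_ops)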
Example #8
    def testTrainWithEpochLimit(self):
        logdir = os.path.join(tempfile.mkdtemp(prefix=self.get_temp_dir()),
                              'tmp_logs')
        with ops.Graph().as_default():
            random_seed.set_random_seed(0)
            tf_inputs = constant_op.constant(self._inputs,
                                             dtype=dtypes.float32)
            tf_labels = constant_op.constant(self._labels,
                                             dtype=dtypes.float32)
            tf_inputs_limited = input_lib.limit_epochs(tf_inputs,
                                                       num_epochs=300)
            tf_labels_limited = input_lib.limit_epochs(tf_labels,
                                                       num_epochs=300)

            tf_predictions = LogisticClassifier(tf_inputs_limited)
            loss_ops.log_loss(tf_predictions, tf_labels_limited)
            total_loss = loss_ops.get_total_loss()

            optimizer = gradient_descent.GradientDescentOptimizer(
                learning_rate=1.0)

            train_op = learning.create_train_op(total_loss, optimizer)

            loss = learning.train(train_op, logdir, log_every_n_steps=10)
        self.assertIsNotNone(loss)
        self.assertLess(loss, .015)
        self.assertTrue(
            os.path.isfile('{}/model.ckpt-300.index'.format(logdir)))
        self.assertTrue(
            os.path.isfile(
                '{}/model.ckpt-300.data-00000-of-00001'.format(logdir)))
Example #9
    def testTrainWithAlteredGradients(self):
        # Use the same learning rate but different gradient multipliers
        # to train two models. The model with the equivalently larger
        # learning rate (i.e., learning_rate * gradient_multiplier) should
        # reach a smaller training loss.
        logdir1 = os.path.join(tempfile.mkdtemp(prefix=self.get_temp_dir()),
                               'tmp_logs1')
        logdir2 = os.path.join(tempfile.mkdtemp(prefix=self.get_temp_dir()),
                               'tmp_logs2')

        multipliers = [1., 1000.]
        number_of_steps = 10
        losses = []
        learning_rate = 0.001

        # First, train the model with the equivalently smaller learning rate.
        with ops.Graph().as_default():
            random_seed.set_random_seed(0)
            train_op = self.create_train_op(learning_rate=learning_rate,
                                            gradient_multiplier=multipliers[0])
            loss = learning.train(train_op,
                                  logdir1,
                                  number_of_steps=number_of_steps)
            losses.append(loss)
            self.assertGreater(loss, .5)

        # Second, train the model with the equivalently larger learning rate.
        with ops.Graph().as_default():
            random_seed.set_random_seed(0)
            train_op = self.create_train_op(learning_rate=learning_rate,
                                            gradient_multiplier=multipliers[1])
            loss = learning.train(train_op,
                                  logdir2,
                                  number_of_steps=number_of_steps)
            losses.append(loss)
            self.assertIsNotNone(loss)
            self.assertLess(loss, .5)

        # The loss of the model trained with the larger effective learning
        # rate should be smaller.
        self.assertGreater(losses[0], losses[1])
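
The helper self.create_train_op(learning_rate=..., gradient_multiplier=...) is not shown in this listing. Below is a hypothetical sketch of how such a helper could be written with slim's gradient_multipliers argument (a dict mapping variables or variable op names to scalar multipliers); it is an assumption about the helper, not the test's actual implementation.

import tensorflow as tf

slim = tf.contrib.slim

def create_train_op(total_loss, learning_rate=1.0, gradient_multiplier=1.0):
    """Hypothetical helper: scales every trainable variable's gradient."""
    optimizer = tf.train.GradientDescentOptimizer(learning_rate=learning_rate)
    # Apply the same scalar multiplier to every trainable variable's gradient,
    # which is equivalent to scaling the learning rate by that factor.
    multipliers = {var: gradient_multiplier
                   for var in tf.trainable_variables()}
    return slim.learning.create_train_op(
        total_loss, optimizer, gradient_multipliers=multipliers)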
Example #10
    def testTrainWithNoneAsLogdirWhenUsingTraceRaisesError(self):
        with ops.Graph().as_default():
            random_seed.set_random_seed(0)
            tf_inputs = constant_op.constant(self._inputs,
                                             dtype=dtypes.float32)
            tf_labels = constant_op.constant(self._labels,
                                             dtype=dtypes.float32)

            tf_predictions = LogisticClassifier(tf_inputs)
            loss_ops.log_loss(tf_predictions, tf_labels)
            total_loss = loss_ops.get_total_loss()

            optimizer = gradient_descent.GradientDescentOptimizer(
                learning_rate=1.0)

            train_op = learning.create_train_op(total_loss, optimizer)

            with self.assertRaises(ValueError):
                learning.train(train_op,
                               None,
                               number_of_steps=300,
                               trace_every_n_steps=10)
Example #11
    def testTrainWithNoneAsInitWhenUsingVarsRaisesError(self):
        logdir = os.path.join(tempfile.mkdtemp(prefix=self.get_temp_dir()),
                              'tmp_logs')
        with ops.Graph().as_default():
            random_seed.set_random_seed(0)
            tf_inputs = constant_op.constant(self._inputs,
                                             dtype=dtypes.float32)
            tf_labels = constant_op.constant(self._labels,
                                             dtype=dtypes.float32)

            tf_predictions = LogisticClassifier(tf_inputs)
            loss_ops.log_loss(tf_predictions, tf_labels)
            total_loss = loss_ops.get_total_loss()

            optimizer = gradient_descent.GradientDescentOptimizer(
                learning_rate=1.0)

            train_op = learning.create_train_op(total_loss, optimizer)

            with self.assertRaises(RuntimeError):
                learning.train(train_op,
                               logdir,
                               init_op=None,
                               number_of_steps=300)
Example #12
    def testTrainWithSessionWrapper(self):
        """Test that slim.learning.train can take `session_wrapper` args.

    One of the applications of `session_wrapper` is the wrappers of TensorFlow
    Debugger (tfdbg), which intercept methods calls to `tf.Session` (e.g., run)
    to achieve debugging. `DumpingDebugWrapperSession` is used here for testing
    purpose.
    """
        dump_root = tempfile.mkdtemp()

        def dumping_wrapper(sess):  # pylint: disable=invalid-name
            return dumping_wrapper_lib.DumpingDebugWrapperSession(
                sess, dump_root)

        with ops.Graph().as_default():
            random_seed.set_random_seed(0)
            tf_inputs = constant_op.constant(self._inputs,
                                             dtype=dtypes.float32)
            tf_labels = constant_op.constant(self._labels,
                                             dtype=dtypes.float32)

            tf_predictions = LogisticClassifier(tf_inputs)
            loss_ops.log_loss(tf_predictions, tf_labels)
            total_loss = loss_ops.get_total_loss()

            optimizer = gradient_descent.GradientDescentOptimizer(
                learning_rate=1.0)

            train_op = learning.create_train_op(total_loss, optimizer)

            loss = learning.train(train_op,
                                  None,
                                  number_of_steps=1,
                                  session_wrapper=dumping_wrapper)
        self.assertIsNotNone(loss)

        run_root = glob.glob(os.path.join(dump_root, 'run_*'))[-1]
        dump = debug_data.DebugDumpDir(run_root)
        self.assertAllEqual(
            0,
            dump.get_tensors('global_step', 0, 'DebugIdentity')[0])
Example #13
def train(create_tensor_dict_fn,
          create_model_fn,
          train_config,
          master,
          task,
          num_clones,
          worker_replicas,
          clone_on_cpu,
          ps_tasks,
          worker_job_name,
          is_chief,
          train_dir,
          num_examples,
          total_configs,
          model_config,
          is_first_training=True):
    """Training function for detection models.

  Args:
    create_tensor_dict_fn: a function to create a tensor input dictionary.
    create_model_fn: a function that creates a DetectionModel and generates
                     losses.
    train_config: a train_pb2.TrainConfig protobuf.
    master: BNS name of the TensorFlow master to use.
    task: The task id of this training instance.
    num_clones: The number of clones to run per machine.
    worker_replicas: The number of work replicas to train with.
    clone_on_cpu: True if clones should be forced to run on CPU.
    ps_tasks: Number of parameter server tasks.
    worker_job_name: Name of the worker job.
    is_chief: Whether this replica is the chief replica.
    train_dir: Directory to write checkpoints and training summaries to.
    num_examples: The number of examples in dataset for training.
    total_configs: config list
  """

    detection_model = create_model_fn()
    data_augmentation_options = [
        preprocessor_builder.build(step)
        for step in train_config.data_augmentation_options
    ]

    with tf.Graph().as_default():
        # Build a configuration specifying multi-GPU and multi-replicas.
        deploy_config = model_deploy.DeploymentConfig(
            num_clones=num_clones,
            clone_on_cpu=clone_on_cpu,
            replica_id=task,
            num_replicas=worker_replicas,
            num_ps_tasks=ps_tasks,
            worker_job_name=worker_job_name)

        # Place the global step on the device storing the variables.
        with tf.device(deploy_config.variables_device()):
            if is_first_training:
                global_step = slim.create_global_step()
            else:
                prev_global_step = int(
                    train_config.fine_tune_checkpoint.split('-')[-1])
                global_step = variable_scope.get_variable(
                    ops.GraphKeys.GLOBAL_STEP,
                    dtype=dtypes.int64,
                    initializer=tf.constant(prev_global_step,
                                            dtype=dtypes.int64),
                    trainable=False,
                    collections=[
                        ops.GraphKeys.GLOBAL_VARIABLES,
                        ops.GraphKeys.GLOBAL_STEP
                    ])

        with tf.device(deploy_config.inputs_device()):
            input_queue = _create_input_queue(
                train_config.batch_size // num_clones,
                create_tensor_dict_fn,
                train_config.batch_queue_capacity,
                train_config.num_batch_queue_threads,
                train_config.prefetch_queue_capacity,
                data_augmentation_options,
                ignore_options=train_config.ignore_options,
                mtl_window=model_config.mtl.window,
                mtl_edgemask=model_config.mtl.edgemask)

        # Gather initial summaries.
        summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))
        global_summaries = set([])

        kwargs = {}
        kwargs['mtl'] = model_config.mtl

        update_schedule = None
        model_fn = functools.partial(
            _create_losses,
            create_model_fn=create_model_fn,
            show_image_summary=train_config.show_image_summary,
            update_schedule=update_schedule,
            **kwargs)
        clones = model_deploy.create_clones(deploy_config, model_fn,
                                            [input_queue])
        first_clone_scope = clones[0].scope
        with tf.device(deploy_config.optimizer_device()):
            training_optimizer = optimizer_builder.build(
                train_config.optimizer, global_summaries)

        sync_optimizer = None
        if train_config.sync_replicas:
            # TODO: support synchronous update for manual loss update
            training_optimizer = tf.SyncReplicasOptimizer(
                training_optimizer,
                replicas_to_aggregate=train_config.replicas_to_aggregate,
                total_num_replicas=train_config.worker_replicas)
            sync_optimizer = training_optimizer

        # Create ops required to initialize the model from a given checkpoint.
        init_fn = None
        if train_config.fine_tune_checkpoint:
            var_map = detection_model.restore_map(
                from_detection_checkpoint=(
                    train_config.from_detection_checkpoint),
                restore_box_predictor=train_config.restore_box_predictor,
                restore_window=train_config.restore_window,
                restore_edgemask=train_config.restore_edgemask,
                restore_closeness=train_config.restore_closeness,
                restore_mtl_refine=train_config.restore_mtl_refine,
            )
            available_var_map = (
                variables_helper.get_variables_available_in_checkpoint(
                    var_map, train_config.fine_tune_checkpoint))
            init_saver = tf.train.Saver(available_var_map)

            mtl = model_config.mtl
            mtl_init_saver_list = []

            def _get_mtl_init_saver(scope_name):
                _var_map = detection_model._feature_extractor.mtl_restore_from_classification_checkpoint_fn(
                    scope_name)
                if train_config.from_detection_checkpoint:
                    _var_map_new = dict()
                    # Prefix each restored variable name with the second-stage
                    # feature-extractor scope.
                    for name, val in _var_map.items():
                        scoped_name = (
                            detection_model.second_stage_feature_extractor_scope
                            + '/' + name)
                        _var_map_new[scoped_name] = val
                    _var_map = _var_map_new
                _available_var_map = (
                    variables_helper.get_variables_available_in_checkpoint(
                        _var_map, train_config.fine_tune_checkpoint))
                if _available_var_map:
                    return tf.train.Saver(_available_var_map)
                else:
                    return None

            # if mtl.share_second_stage_init and mtl.shared_feature == 'proposal_feature_maps':
            if mtl.share_second_stage_init and not train_config.from_detection_checkpoint:
                if mtl.window:
                    mtl_init_saver_list.append(
                        _get_mtl_init_saver(
                            detection_model.window_box_predictor_scope))
                if mtl.closeness:
                    mtl_init_saver_list.append(
                        _get_mtl_init_saver(
                            detection_model.closeness_box_predictor_scope))
                if mtl.edgemask:
                    mtl_init_saver_list.append(
                        _get_mtl_init_saver(
                            detection_model.edgemask_predictor_scope))

            def initializer_fn(sess):
                init_saver.restore(sess, train_config.fine_tune_checkpoint)
                for mtl_init_saver in mtl_init_saver_list:
                    if mtl_init_saver is not None:
                        mtl_init_saver.restore(
                            sess, train_config.fine_tune_checkpoint)

            init_fn = initializer_fn

        def _get_trainable_variables(except_scopes=None):
            trainable_variables = tf.trainable_variables()
            if except_scopes is None:
                return trainable_variables
            for var in tf.trainable_variables():
                if any([scope in var.name for scope in except_scopes]):
                    trainable_variables.remove(var)
            return trainable_variables

        def _get_update_ops(except_scopes=None):
            # Gather update_ops from the first clone. These contain, for example,
            # the updates for the batch_norm variables created by model_fn.
            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS,
                                           first_clone_scope)
            if except_scopes is None:
                return update_ops
            for var in tf.get_collection(tf.GraphKeys.UPDATE_OPS,
                                         first_clone_scope):
                if any([scope in var.name for scope in except_scopes]):
                    update_ops.remove(var)
            return update_ops

        with tf.device(deploy_config.optimizer_device()):

            def _single_update():
                kwargs = {}
                _training_optimizer = training_optimizer
                kwargs['var_list'] = None
                update_ops = _get_update_ops()
                total_loss, grads_and_vars = model_deploy.optimize_clones(
                    clones,
                    _training_optimizer,
                    regularization_losses=None,
                    **kwargs)

                # Optionally multiply gradients by train_config.{grad_multiplier,
                # divide_grad_by_batch}.
                if train_config.grad_multiplier or train_config.divide_grad_by_batch:
                    base_multiplier = train_config.grad_multiplier \
                        if train_config.grad_multiplier else 1.0
                    batch_divider = float(train_config.batch_size) \
                        if train_config.divide_grad_by_batch else 1.0
                    total_multiplier = base_multiplier / batch_divider
                    grads_and_vars = variables_helper.multiply_gradients_by_scalar_multiplier(
                        grads_and_vars, multiplier=total_multiplier)

                # Optionally multiply bias gradients by train_config.bias_grad_multiplier.
                if train_config.bias_grad_multiplier:
                    biases_regex_list = ['.*/biases']
                    grads_and_vars = variables_helper.multiply_gradients_matching_regex(
                        grads_and_vars,
                        biases_regex_list,
                        multiplier=train_config.bias_grad_multiplier)

                # Optionally freeze some layers by setting their gradients to be zero.
                if train_config.freeze_variables:
                    grads_and_vars = variables_helper.freeze_gradients_matching_regex(
                        grads_and_vars, train_config.freeze_variables)

                # Optionally clip gradients
                if train_config.gradient_clipping_by_norm > 0:
                    with tf.name_scope('clip_grads'):
                        grads_and_vars = slim.learning.clip_gradient_norms(
                            grads_and_vars,
                            train_config.gradient_clipping_by_norm)

                # Create gradient updates.
                grad_updates = _training_optimizer.apply_gradients(
                    grads_and_vars, global_step=global_step)
                # update_ops.append(grad_updates)
                total_update_ops = update_ops + [grad_updates]

                update_op = tf.group(*total_update_ops)
                with tf.control_dependencies([update_op]):
                    train_tensor = tf.identity(total_loss, name='train_op')
                return train_tensor

            train_tensor = _single_update()

        # Add summaries.
        def _get_total_loss_with_collection(collection,
                                            add_regularization_losses=True,
                                            name="total_loss"):
            losses = tf.losses.get_losses(loss_collection=collection)
            if add_regularization_losses:
                losses += tf.losses.get_regularization_losses()
            return math_ops.add_n(losses, name=name)

        for model_var in slim.get_model_variables():
            global_summaries.add(
                tf.summary.histogram(model_var.op.name, model_var))
        for loss_tensor in tf.losses.get_losses():
            global_summaries.add(
                tf.summary.scalar(loss_tensor.op.name, loss_tensor))
        global_summaries.add(
            tf.summary.scalar('TotalLoss', tf.losses.get_total_loss()))

        # Add the summaries from the first clone. These contain the summaries
        # created by model_fn and either optimize_clones() or _gather_clone_loss().
        summaries |= set(
            tf.get_collection(tf.GraphKeys.SUMMARIES, first_clone_scope))
        summaries |= global_summaries

        # Merge all summaries together.
        summary_op = tf.summary.merge(list(summaries), name='summary_op')

        # not contained in global_summaries
        config_summary_list = select_config_summary_list(total_configs,
                                                         as_matrix=False)

        # Soft placement allows ops without a GPU implementation to be placed on the CPU.
        session_config = tf.ConfigProto(allow_soft_placement=True,
                                        log_device_placement=False)

        # Save checkpoints regularly.
        keep_checkpoint_every_n_hours = train_config.keep_checkpoint_every_n_hours
        saver = tf.train.Saver(
            keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours)

        custom_learning.train(
            train_tensor,
            logdir=train_dir,
            master=master,
            is_chief=is_chief,
            global_step=(None if is_first_training else global_step),
            session_config=session_config,
            startup_delay_steps=train_config.startup_delay_steps,
            init_fn=init_fn,
            summary_op=summary_op,
            number_of_steps=(train_config.num_steps
                             if train_config.num_steps else None),
            log_every_n_steps=(train_config.log_every_n_steps
                               if train_config.log_every_n_steps else None),
            save_summaries_secs=train_config.save_summaries_secs,
            save_interval_secs=train_config.save_interval_secs,
            sync_optimizer=sync_optimizer,
            saver=saver,
            batch_size=train_config.batch_size,
            num_examples=num_examples,
            config_summary_list=config_summary_list)
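
A hypothetical driver sketch showing how a training binary might wire up and call this train() function. The builder modules, config objects, and flag values below (model_builder, get_next, input_config, and so on) are placeholders assumed in the style of the TensorFlow Object Detection API; none of them are defined in the example above.

import functools

# All names on the right-hand sides below are placeholders, not part of train().
model_fn = functools.partial(model_builder.build,
                             model_config=model_config,
                             is_training=True)
create_input_dict_fn = functools.partial(get_next, input_config)

train(create_input_dict_fn,
      model_fn,
      train_config,
      master='',
      task=0,
      num_clones=1,
      worker_replicas=1,
      clone_on_cpu=False,
      ps_tasks=0,
      worker_job_name='worker',
      is_chief=True,
      train_dir='/tmp/train_dir',
      num_examples=num_examples,
      total_configs=total_configs,
      model_config=model_config)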