Example #1
    def start(self, notify_func=None, args=None):
        input_config = self.config
        if input_config is None:
            tf.logging.error('There is no input configuration.')
            return

        try:
            training_config_file = os.path.join(self.local_path,
                                                'training_configs.json')
            tf.logging.info('Loading training configs from %s',
                            training_config_file)
            with open(training_config_file) as f:
                self.training_configs = json.load(f)
            training_configs = self.training_configs
            training_configs['dataset_params']['dataset_dir'] = input_config[
                'data_dir']
            training_configs['fine_tuning_params'][
                'tf_initial_checkpoint'] = input_config['fine_tune_ckpt']
            # Mirror the shared model settings from the config onto common.FLAGS.
            for flag_name in ('min_resize_value', 'max_resize_value',
                              'resize_factor', 'logits_kernel_size',
                              'model_variant', 'image_pyramid',
                              'add_image_level_feature', 'aspp_with_batch_norm',
                              'aspp_with_separable_conv', 'multi_grid',
                              'depth_multiplier', 'decoder_output_stride',
                              'decoder_use_separable_conv', 'merge_method'):
                setattr(common.FLAGS, flag_name,
                        training_configs['common'][flag_name])

            # Set up deployment (i.e., multi-GPUs and/or multi-replicas).
            config = model_deploy.DeploymentConfig(
                num_clones=training_configs['tf_configs']['num_clones'],
                clone_on_cpu=training_configs['tf_configs']['clone_on_cpu'],
                replica_id=training_configs['tf_configs']['task'],
                num_replicas=training_configs['tf_configs']['num_replicas'],
                num_ps_tasks=training_configs['tf_configs']['num_ps_tasks'])

            # Split the batch across GPUs.
            assert training_configs['learning_params'][
                'train_batch_size'] % config.num_clones == 0, (
                    'Training batch size not divisible by number of clones (GPUs).'
                )

            clone_batch_size = int(
                training_configs['learning_params']['train_batch_size'] /
                config.num_clones)

            # Get dataset-dependent information.
            dataset = self._get_dataset(
                training_configs['dataset_params']['train_split'],
                dataset_dir=training_configs['dataset_params']['dataset_dir'])

            train_dir = self.local_path
            training_configs['logging_configs']['train_logdir'] = train_dir

            with tf.Graph().as_default():
                with tf.device(config.inputs_device()):
                    samples = input_generator.get(
                        dataset,
                        training_configs['learning_params']['train_crop_size'],
                        clone_batch_size,
                        min_resize_value=training_configs['common']
                        ['min_resize_value'],
                        max_resize_value=training_configs['common']
                        ['max_resize_value'],
                        resize_factor=training_configs['common']
                        ['resize_factor'],
                        min_scale_factor=training_configs['fine_tuning_params']
                        ['min_scale_factor'],
                        max_scale_factor=training_configs['fine_tuning_params']
                        ['max_scale_factor'],
                        scale_factor_step_size=training_configs[
                            'fine_tuning_params']['scale_factor_step_size'],
                        dataset_split=training_configs['dataset_params']
                        ['train_split'],
                        is_training=True,
                        model_variant=training_configs['common']
                        ['model_variant'])
                    inputs_queue = prefetch_queue.prefetch_queue(
                        samples, capacity=128 * config.num_clones)

                # Create the global step on the device storing the variables.
                with tf.device(config.variables_device()):
                    global_step = tf.train.get_or_create_global_step()

                    # Define the model and create clones.
                    model_fn = self._build_deeplab
                    model_args = (inputs_queue, {
                        common.OUTPUT_TYPE: dataset.num_classes
                    }, dataset.ignore_label)
                    clones = model_deploy.create_clones(config,
                                                        model_fn,
                                                        args=model_args)

                    # Gather update_ops from the first clone. These contain, for example,
                    # the updates for the batch_norm variables created by model_fn.
                    first_clone_scope = config.clone_scope(0)
                    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS,
                                                   first_clone_scope)

                # Gather initial summaries.
                summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))

                # Add summaries for model variables.
                for model_var in slim.get_model_variables():
                    summaries.add(
                        tf.summary.histogram(model_var.op.name, model_var))

                # Add summaries for losses.
                for loss in tf.get_collection(tf.GraphKeys.LOSSES,
                                              first_clone_scope):
                    summaries.add(
                        tf.summary.scalar('losses/%s' % loss.op.name, loss))

                # Build the optimizer based on the device specification.
                with tf.device(config.optimizer_device()):
                    learning_rate = train_utils.get_model_learning_rate(
                        training_configs['learning_params']['learning_policy'],
                        training_configs['learning_params']
                        ['base_learning_rate'],
                        training_configs['learning_params']
                        ['learning_rate_decay_step'],
                        training_configs['learning_params']
                        ['learning_rate_decay_factor'],
                        training_configs['learning_params']
                        ['training_number_of_steps'],
                        training_configs['learning_params']['learning_power'],
                        training_configs['fine_tuning_params']
                        ['slow_start_step'],
                        training_configs['fine_tuning_params']
                        ['slow_start_learning_rate'])
                    optimizer = tf.train.MomentumOptimizer(
                        learning_rate,
                        training_configs['learning_params']['momentum'])
                    summaries.add(
                        tf.summary.scalar('learning_rate', learning_rate))

                startup_delay_steps = training_configs['tf_configs'][
                    'task'] * training_configs['tf_configs'][
                        'startup_delay_steps']

                with tf.device(config.variables_device()):
                    total_loss, grads_and_vars = model_deploy.optimize_clones(
                        clones, optimizer)
                    total_loss = tf.check_numerics(total_loss,
                                                   'Loss is inf or nan.')
                    summaries.add(tf.summary.scalar('total_loss', total_loss))

                    # Modify the gradients for biases and last layer variables.
                    last_layers = model.get_extra_layer_scopes(
                        training_configs['fine_tuning_params']
                        ['last_layers_contain_logits_only'])
                    grad_mult = train_utils.get_model_gradient_multipliers(
                        last_layers, training_configs['learning_params']
                        ['last_layer_gradient_multiplier'])
                    if grad_mult:
                        grads_and_vars = slim.learning.multiply_gradients(
                            grads_and_vars, grad_mult)

                    # Create gradient update op.
                    grad_updates = optimizer.apply_gradients(
                        grads_and_vars, global_step=global_step)
                    update_ops.append(grad_updates)
                    update_op = tf.group(*update_ops)
                    with tf.control_dependencies([update_op]):
                        train_tensor = tf.identity(total_loss, name='train_op')

                # Add the summaries from the first clone. These contain the summaries
                # created by model_fn and either optimize_clones() or _gather_clone_loss().
                summaries |= set(
                    tf.get_collection(tf.GraphKeys.SUMMARIES,
                                      first_clone_scope))

                # Merge all summaries together.
                summary_op = tf.summary.merge(list(summaries))

                # Soft placement allows placing on CPU ops without GPU implementation.
                session_config = tf.ConfigProto(allow_soft_placement=True,
                                                log_device_placement=False)
                session_config.gpu_options.allow_growth = True

                weblog_dir = input_config['weblog_dir']
                if not os.path.exists(weblog_dir):
                    os.makedirs(weblog_dir)
                logger = Logger('Training Monitor')

                init_fn = train_utils.get_model_init_fn(
                    training_configs['logging_configs']['train_logdir'],
                    training_configs['fine_tuning_params']
                    ['tf_initial_checkpoint'],
                    training_configs['fine_tuning_params']
                    ['initialize_last_layer'],
                    last_layers,
                    ignore_missing_vars=True)

                # Start the training.
                learning.train(
                    train_tensor,
                    logdir=train_dir,
                    log_every_n_steps=training_configs['logging_configs']
                    ['log_steps'],
                    master=training_configs['tf_configs']['master'],
                    number_of_steps=training_configs['learning_params']
                    ['training_number_of_steps'],
                    is_chief=(training_configs['tf_configs']['task'] == 0),
                    session_config=session_config,
                    startup_delay_steps=startup_delay_steps,
                    init_fn=init_fn,
                    summary_op=summary_op,
                    save_summaries_secs=training_configs['logging_configs']
                    ['save_summaries_secs'],
                    save_interval_secs=training_configs['logging_configs']
                    ['save_interval_secs'],
                    logger=logger,
                    weblog_dir=weblog_dir,
                    notify_func=notify_func,
                    args=args)
        except Exception as e:
            tf.logging.error('Unexpected error: %s', e)
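
For reference, here is a minimal sketch of the training_configs.json layout that start() reads. Only the keys accessed above are included, and every value (as well as the TRAINING_CONFIGS_SKETCH name itself) is an illustrative placeholder rather than a recommended setting.

# Sketch of training_configs.json, shown as the Python dict json.load() would
# return. Keys mirror the lookups in start() above; values are illustrative.
TRAINING_CONFIGS_SKETCH = {
    'dataset_params': {
        'dataset_dir': '/path/to/tfrecords',   # overwritten from input_config['data_dir']
        'train_split': 'train',
    },
    'fine_tuning_params': {
        'tf_initial_checkpoint': '/path/to/ckpt',  # overwritten from input_config['fine_tune_ckpt']
        'min_scale_factor': 0.5,
        'max_scale_factor': 2.0,
        'scale_factor_step_size': 0.25,
        'slow_start_step': 0,
        'slow_start_learning_rate': 1e-4,
        'last_layers_contain_logits_only': False,
        'initialize_last_layer': True,
    },
    'common': {
        'min_resize_value': None,
        'max_resize_value': None,
        'resize_factor': None,
        'logits_kernel_size': 1,
        'model_variant': 'xception_65',
        'image_pyramid': None,
        'add_image_level_feature': True,
        'aspp_with_batch_norm': True,
        'aspp_with_separable_conv': True,
        'multi_grid': None,
        'depth_multiplier': 1.0,
        'decoder_output_stride': 4,
        'decoder_use_separable_conv': True,
        'merge_method': 'max',
    },
    'tf_configs': {
        'num_clones': 1,
        'clone_on_cpu': False,
        'task': 0,
        'num_replicas': 1,
        'num_ps_tasks': 0,
        'master': '',
        'startup_delay_steps': 15,
    },
    'learning_params': {
        'train_batch_size': 8,
        'train_crop_size': [513, 513],
        'learning_policy': 'poly',
        'base_learning_rate': 1e-4,
        'learning_rate_decay_step': 2000,
        'learning_rate_decay_factor': 0.1,
        'training_number_of_steps': 30000,
        'learning_power': 0.9,
        'momentum': 0.9,
        'last_layer_gradient_multiplier': 1.0,
    },
    'logging_configs': {
        # 'train_logdir' is filled in by start() from self.local_path.
        'log_steps': 10,
        'save_summaries_secs': 600,
        'save_interval_secs': 1200,
    },
}
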
Example #2
  def train(self):
    config = self.config
    if config is None:
      tf.logging.error('There is no input configuration.')
      return

    try:
      with open(config['training_configs']) as f:
        training_configs = json.load(f)
      training_configs['tf_configs']['train_dir'] = config['train_dir']
      training_configs['tf_configs']['log_every_n_steps'] = int(config['log_every_n_steps'])
      training_configs['optimization_params']['optimizer'] = config['optimizer']
      training_configs['learning_rate_params']['learning_rate'] = float(config['learning_rate'])
      training_configs['dataset_params']['batch_size'] = int(config['batch_size'])
      training_configs['dataset_params']['model_name'] = config['model_name']
      training_configs['dataset_params']['dataset_dir'] = config['data_dir']
      training_configs['fine_tuning_params']['checkpoint_path'] = config['fine_tuning_ckpt_path']
      if training_configs['fine_tuning_params']['checkpoint_path'] is not None:
        # Build the scope strings for the selected model.
        model_name = training_configs['dataset_params']['model_name']
        training_configs['fine_tuning_params']['checkpoint_exclude_scopes'] = (
            exclude_scopes_map[model_name].format(scope_map[model_name],
                                                  scope_map[model_name]))
        training_configs['fine_tuning_params']['trainable_scopes'] = (
            exclude_scopes_map[model_name].format(scope_map[model_name],
                                                  scope_map[model_name]))
      self.training_configs = training_configs

      with tf.Graph().as_default():
        # Use only one GPU.
        os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
        os.environ["CUDA_VISIBLE_DEVICES"] = "1"
        # from tensorflow.python.client import device_lib 
        # local_device_protos = device_lib.list_local_devices()

        # create tf_record data
        # self.create_tf_data()
        self.num_classes = 5
        self.splits_to_sizes = {'train': 3320, 'val': 350}
        self.items_to_descriptions = {'image': 'A color image of varying size.',
                                      'label': 'A single integer between 0 and 4'}        

        #######################
        # Config model_deploy #
        #######################
        deploy_config = model_deploy.DeploymentConfig(
            num_clones=training_configs['tf_configs']['num_clones'],
            clone_on_cpu=training_configs['tf_configs']['clone_on_cpu'],
            replica_id=training_configs['tf_configs']['task'],
            num_replicas=training_configs['tf_configs']['worker_replicas'],
            num_ps_tasks=training_configs['tf_configs']['num_ps_tasks'])

        # Create global_step
        with tf.device(deploy_config.variables_device()):
          global_step = slim.create_global_step()

        ######################
        # Select the dataset #
        ######################
        dataset = self.get_dataset('train', training_configs['dataset_params']['dataset_dir'])

        ######################
        # Select the network #
        ######################
        network_fn = nets_factory.get_network_fn(
            training_configs['dataset_params']['model_name'],
            num_classes=(dataset.num_classes - training_configs['dataset_params']['label_offset']),
            weight_decay=training_configs['optimization_params']['weight_decay'],
            is_training=True)

        #####################################
        # Select the preprocessing function #
        #####################################
        preprocessing_name = (training_configs['dataset_params']['preprocessing_name']
                              or training_configs['dataset_params']['model_name'])
        image_preprocessing_fn = preprocessing_factory.get_preprocessing(
            preprocessing_name,
            is_training=True)

        ##############################################################
        # Create a dataset provider that loads data from the dataset #
        ##############################################################
        with tf.device(deploy_config.inputs_device()):
          provider = slim.dataset_data_provider.DatasetDataProvider(
              dataset,
              num_readers=training_configs['tf_configs']['num_readers'],
              common_queue_capacity=20 * training_configs['dataset_params']['batch_size'],
              common_queue_min=10 * training_configs['dataset_params']['batch_size'])
          [image, label] = provider.get(['image', 'label'])
          label -= training_configs['dataset_params']['label_offset']

          train_image_size = (training_configs['dataset_params']['train_image_size']
                              or network_fn.default_image_size)

          image = image_preprocessing_fn(image, train_image_size, train_image_size)

          images, labels = tf.train.batch(
              [image, label],
              batch_size=training_configs['dataset_params']['batch_size'],
              num_threads=training_configs['tf_configs']['num_preprocessing_threads'],
              capacity=5 * training_configs['dataset_params']['batch_size'])
          labels = slim.one_hot_encoding(
              labels, dataset.num_classes - training_configs['dataset_params']['label_offset'])
          batch_queue = slim.prefetch_queue.prefetch_queue(
              [images, labels], capacity=2 * deploy_config.num_clones)

        ####################
        # Define the model #
        ####################
        def clone_fn(batch_queue):
          """Allows data parallelism by creating multiple clones of network_fn."""
          images, labels = batch_queue.dequeue()
          logits, end_points = network_fn(images)

          #############################
          # Specify the loss function #
          #############################
          if 'AuxLogits' in end_points:
            slim.losses.softmax_cross_entropy(
                end_points['AuxLogits'], labels,
                label_smoothing=training_configs['learning_rate_params']['label_smoothing'], weights=0.4,
                scope='aux_loss')
          slim.losses.softmax_cross_entropy(
              logits, labels, label_smoothing=training_configs['learning_rate_params']['label_smoothing'], weights=1.0)
          return end_points

        # Gather initial summaries.
        summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))

        clones = model_deploy.create_clones(deploy_config, clone_fn, [batch_queue])
        first_clone_scope = deploy_config.clone_scope(0)
        # Gather update_ops from the first clone. These contain, for example,
        # the updates for the batch_norm variables created by network_fn.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, first_clone_scope)

        # Add summaries for end_points.
        end_points = clones[0].outputs
        for end_point in end_points:
          x = end_points[end_point]
          summaries.add(tf.summary.histogram('activations/' + end_point, x))
          summaries.add(tf.summary.scalar('sparsity/' + end_point,
                                          tf.nn.zero_fraction(x)))

        # Add summaries for losses.
        for loss in tf.get_collection(tf.GraphKeys.LOSSES, first_clone_scope):
          summaries.add(tf.summary.scalar('losses/%s' % loss.op.name, loss))

        # Add summaries for variables.
        for variable in slim.get_model_variables():
          summaries.add(tf.summary.histogram(variable.op.name, variable))

        #################################
        # Configure the moving averages #
        #################################
        if training_configs['learning_rate_params']['moving_average_decay']:
          moving_average_variables = slim.get_model_variables()
          variable_averages = tf.train.ExponentialMovingAverage(
              training_configs['learning_rate_params']['moving_average_decay'], global_step)
        else:
          moving_average_variables, variable_averages = None, None

        #########################################
        # Configure the optimization procedure. #
        #########################################
        with tf.device(deploy_config.optimizer_device()):
          learning_rate = self._configure_learning_rate(dataset.num_samples, global_step)
          optimizer = self._configure_optimizer(learning_rate)
          summaries.add(tf.summary.scalar('learning_rate', learning_rate))

        if training_configs['learning_rate_params']['sync_replicas']:
          # If sync_replicas is enabled, the averaging will be done in the chief
          # queue runner.
          optimizer = tf.train.SyncReplicasOptimizer(
              opt=optimizer,
              replicas_to_aggregate=training_configs['learning_rate_params']['replicas_to_aggregate'],
              total_num_replicas=training_configs['tf_configs']['worker_replicas'],
              variable_averages=variable_averages,
              variables_to_average=moving_average_variables)
        elif training_configs['learning_rate_params']['moving_average_decay']:
          # Update ops executed locally by trainer.
          update_ops.append(variable_averages.apply(moving_average_variables))

        # Variables to train.
        variables_to_train = self._get_variables_to_train()

        # Compute the total loss and the gradients for all clones.
        total_loss, clones_gradients = model_deploy.optimize_clones(
            clones,
            optimizer,
            var_list=variables_to_train)
        # Add total_loss to summary.
        summaries.add(tf.summary.scalar('total_loss', total_loss))

        # Create gradient updates.
        grad_updates = optimizer.apply_gradients(clones_gradients,
                                                 global_step=global_step)
        update_ops.append(grad_updates)

        update_op = tf.group(*update_ops)
        with tf.control_dependencies([update_op]):
          train_tensor = tf.identity(total_loss, name='train_op')

        # Add the summaries from the first clone. These contain the summaries
        # created by model_fn and either optimize_clones() or _gather_clone_loss().
        summaries |= set(tf.get_collection(tf.GraphKeys.SUMMARIES,
                                           first_clone_scope))

        # Merge all summaries together.
        summary_op = tf.summary.merge(list(summaries), name='summary_op')

        train_dir = training_configs['tf_configs']['train_dir']
        if not os.path.exists(train_dir):
          os.makedirs(train_dir)
        # Keep a copy of the label map next to the checkpoints.
        copy(os.path.join(training_configs['dataset_params']['dataset_dir'],
                          'label_map.txt'),
             training_configs['tf_configs']['train_dir'])
        weblog_dir = config['weblog_dir']
        if not os.path.exists(weblog_dir):
          os.makedirs(weblog_dir)

        logger = Logger('Training Monitor')      

        ###########################
        # Kicks off the training. #
        ###########################
        learning.train(
            train_tensor,
            logdir=train_dir,
            master=training_configs['tf_configs']['master'],
            is_chief=(training_configs['tf_configs']['task'] == 0),
            init_fn=self._get_init_fn(),
            summary_op=summary_op,
            log_every_n_steps=training_configs['tf_configs']['log_every_n_steps'],
            save_summaries_secs=training_configs['tf_configs']['save_summaries_secs'],
            save_interval_secs=training_configs['tf_configs']['save_interval_secs'],
            sync_optimizer=optimizer if training_configs['learning_rate_params']['sync_replicas'] else None,
            logger=logger,
            weblog_dir=weblog_dir)
    except Exception as e:
      tf.logging.error('Unexpected error: %s', e)
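
A minimal sketch of the self.config dictionary this train() method expects, assuming the snippet lives on a trainer class whose constructor stores the dict as self.config. Only the keys read above are listed; all values are illustrative placeholders.

# Sketch of the input config consumed by train() above; values are
# illustrative placeholders, not recommendations.
config = {
    'training_configs': '/path/to/training_configs.json',  # JSON with tf_configs,
                                                            # dataset_params, etc.
    'train_dir': '/path/to/train_dir',
    'weblog_dir': '/path/to/weblog',
    'log_every_n_steps': '10',          # cast to int by train()
    'optimizer': 'rmsprop',
    'learning_rate': '0.01',            # cast to float by train()
    'batch_size': '32',                 # cast to int by train()
    'model_name': 'inception_v3',
    'data_dir': '/path/to/tfrecords/',
    'fine_tuning_ckpt_path': '/path/to/pretrained.ckpt',
}
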
Example #3
def train(create_tensor_dict_fn, create_model_fn, train_config, master, task,
          num_clones, worker_replicas, clone_on_cpu, ps_tasks, worker_job_name,
          is_chief, config, graph_hook_fn=None):
  """Training function for detection models.

  Args:
    create_tensor_dict_fn: a function to create a tensor input dictionary.
    create_model_fn: a function that creates a DetectionModel and generates
                     losses.
    train_config: a train_pb2.TrainConfig protobuf.
    master: BNS name of the TensorFlow master to use.
    task: The task id of this training instance.
    num_clones: The number of clones to run per machine.
    worker_replicas: The number of work replicas to train with.
    clone_on_cpu: True if clones should be forced to run on CPU.
    ps_tasks: Number of parameter server tasks.
    worker_job_name: Name of the worker job.
    is_chief: Whether this replica is the chief replica.
    config: Dictionary with 'train_dir', 'weblog_dir' and 'log_every_n_steps'
      entries giving the checkpoint/summary directory, the web-log directory
      and the logging frequency.
    graph_hook_fn: Optional function that is called after the training graph is
      completely built. This is helpful to perform additional changes to the
      training graph such as optimizing batchnorm. The function should modify
      the default graph.
  """

  detection_model = create_model_fn()
  data_augmentation_options = [
      preprocessor_builder.build(step)
      for step in train_config.data_augmentation_options]

  with tf.Graph().as_default():
    # Build a configuration specifying multi-GPU and multi-replicas.
    deploy_config = model_deploy.DeploymentConfig(
        num_clones=num_clones,
        clone_on_cpu=clone_on_cpu,
        replica_id=task,
        num_replicas=worker_replicas,
        num_ps_tasks=ps_tasks,
        worker_job_name=worker_job_name)

    # Place the global step on the device storing the variables.
    with tf.device(deploy_config.variables_device()):
      global_step = slim.create_global_step()

    with tf.device(deploy_config.inputs_device()):
      input_queue = create_input_queue(
          train_config.batch_size // num_clones, create_tensor_dict_fn,
          train_config.batch_queue_capacity,
          train_config.num_batch_queue_threads,
          train_config.prefetch_queue_capacity, data_augmentation_options)

    # Gather initial summaries.
    # TODO(rathodv): See if summaries can be added/extracted from global tf
    # collections so that they don't have to be passed around.
    summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))
    global_summaries = set([])

    model_fn = functools.partial(_create_losses,
                                 create_model_fn=create_model_fn,
                                 train_config=train_config)
    clones = model_deploy.create_clones(deploy_config, model_fn, [input_queue])
    first_clone_scope = clones[0].scope

    # Gather update_ops from the first clone. These contain, for example,
    # the updates for the batch_norm variables created by model_fn.
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, first_clone_scope)

    with tf.device(deploy_config.optimizer_device()):
      training_optimizer, optimizer_summary_vars = optimizer_builder.build(
          train_config.optimizer)
      for var in optimizer_summary_vars:
        tf.summary.scalar(var.op.name, var, family='LearningRate')

    sync_optimizer = None
    if train_config.sync_replicas:
      training_optimizer = tf.train.SyncReplicasOptimizer(
          training_optimizer,
          replicas_to_aggregate=train_config.replicas_to_aggregate,
          total_num_replicas=worker_replicas)
      sync_optimizer = training_optimizer

    # Create ops required to initialize the model from a given checkpoint.
    init_fn = None
    if train_config.fine_tune_checkpoint:
      if not train_config.fine_tune_checkpoint_type:
        # train_config.from_detection_checkpoint field is deprecated. For
        # backward compatibility, fine_tune_checkpoint_type is set based on
        # from_detection_checkpoint.
        if train_config.from_detection_checkpoint:
          train_config.fine_tune_checkpoint_type = 'detection'
        else:
          train_config.fine_tune_checkpoint_type = 'classification'
      var_map = detection_model.restore_map(
          fine_tune_checkpoint_type=train_config.fine_tune_checkpoint_type,
          load_all_detection_checkpoint_vars=(
              train_config.load_all_detection_checkpoint_vars))
      available_var_map = (variables_helper.
                           get_variables_available_in_checkpoint(
                               var_map, train_config.fine_tune_checkpoint))
      init_saver = tf.train.Saver(available_var_map)
      def initializer_fn(sess):
        init_saver.restore(sess, train_config.fine_tune_checkpoint)
      init_fn = initializer_fn

    with tf.device(deploy_config.optimizer_device()):
      regularization_losses = (None if train_config.add_regularization_loss
                               else [])
      total_loss, grads_and_vars = model_deploy.optimize_clones(
          clones, training_optimizer,
          regularization_losses=regularization_losses)
      total_loss = tf.check_numerics(total_loss, 'LossTensor is inf or nan.')

      # Optionally multiply bias gradients by train_config.bias_grad_multiplier.
      if train_config.bias_grad_multiplier:
        biases_regex_list = ['.*/biases']
        grads_and_vars = variables_helper.multiply_gradients_matching_regex(
            grads_and_vars,
            biases_regex_list,
            multiplier=train_config.bias_grad_multiplier)

      # Optionally freeze some layers by setting their gradients to be zero.
      if train_config.freeze_variables:
        grads_and_vars = variables_helper.freeze_gradients_matching_regex(
            grads_and_vars, train_config.freeze_variables)

      # Optionally clip gradients
      if train_config.gradient_clipping_by_norm > 0:
        with tf.name_scope('clip_grads'):
          grads_and_vars = learning.clip_gradient_norms(
              grads_and_vars, train_config.gradient_clipping_by_norm)

      # Create gradient updates.
      grad_updates = training_optimizer.apply_gradients(grads_and_vars,
                                                        global_step=global_step)
      update_ops.append(grad_updates)
      update_op = tf.group(*update_ops, name='update_barrier')
      with tf.control_dependencies([update_op]):
        train_tensor = tf.identity(total_loss, name='train_op')

    if graph_hook_fn:
      with tf.device(deploy_config.variables_device()):
        graph_hook_fn()

    # Add summaries.
    for model_var in slim.get_model_variables():
      global_summaries.add(tf.summary.histogram('ModelVars/' +
                                                model_var.op.name, model_var))
    for loss_tensor in tf.losses.get_losses():
      global_summaries.add(tf.summary.scalar('Losses/' + loss_tensor.op.name,
                                             loss_tensor))
    global_summaries.add(
        tf.summary.scalar('Losses/TotalLoss', tf.losses.get_total_loss()))

    # Add the summaries from the first clone. These contain the summaries
    # created by model_fn and either optimize_clones() or _gather_clone_loss().
    summaries |= set(tf.get_collection(tf.GraphKeys.SUMMARIES,
                                       first_clone_scope))
    summaries |= global_summaries

    # Merge all summaries together.
    summary_op = tf.summary.merge(list(summaries), name='summary_op')

    # Soft placement allows placing on CPU ops without GPU implementation.
    session_config = tf.ConfigProto(allow_soft_placement=True,
                                    log_device_placement=False)

    # Save checkpoints regularly.
    keep_checkpoint_every_n_hours = train_config.keep_checkpoint_every_n_hours
    saver = tf.train.Saver(
        keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours)

    train_dir = config['train_dir']
    if not os.path.exists(train_dir):
      os.makedirs(train_dir)
    weblog_dir = config['weblog_dir']
    if not os.path.exists(weblog_dir):
      os.makedirs(weblog_dir)
    log_every_n_steps = int(config['log_every_n_steps'])

    logger = Logger('Training Monitor')
    learning.train(
        train_tensor,
        logdir=train_dir,
        log_every_n_steps=log_every_n_steps,
        master=master,
        is_chief=is_chief,
        session_config=session_config,
        startup_delay_steps=train_config.startup_delay_steps,
        init_fn=init_fn,
        summary_op=summary_op,
        number_of_steps=(
            train_config.num_steps if train_config.num_steps else None),
        save_summaries_secs=600,
        sync_optimizer=sync_optimizer,
        saver=saver,
        logger=logger,
        weblog_dir=weblog_dir)
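
Finally, a hedged sketch of a possible call site for this train() function. The builder helpers and the config dictionary below are hypothetical placeholders; in the TensorFlow Object Detection API the corresponding objects are normally assembled from a pipeline config by the training binary.

# Hypothetical invocation; create_tensor_dict_fn and create_model_fn stand in
# for the dataset and model builders supplied by the surrounding project.
train(
    create_tensor_dict_fn=create_tensor_dict_fn,  # () -> dict of input tensors
    create_model_fn=create_model_fn,              # () -> DetectionModel with losses
    train_config=train_config,                    # train_pb2.TrainConfig proto
    master='',
    task=0,
    num_clones=1,
    worker_replicas=1,
    clone_on_cpu=False,
    ps_tasks=0,
    worker_job_name='worker',
    is_chief=True,
    config={
        'train_dir': '/path/to/train_dir',
        'weblog_dir': '/path/to/weblog',
        'log_every_n_steps': '10',  # cast to int inside train()
    },
)
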