Example #1
def main(unused_argv):
    tf.logging.set_verbosity(tf.logging.INFO)
    tf.set_random_seed(FLAGS.seed)
    # Set up deployment (i.e., multi-GPUs and/or multi-replicas).
    config = model_deploy.DeploymentConfig(num_clones=FLAGS.num_clones,
                                           clone_on_cpu=FLAGS.clone_on_cpu,
                                           replica_id=FLAGS.task,
                                           num_replicas=FLAGS.num_replicas,
                                           num_ps_tasks=FLAGS.num_ps_tasks)

    # Split the batch across GPUs.
    assert FLAGS.train_batch_size % config.num_clones == 0, (
        'Training batch size not divisible by number of clones (GPUs).')
    clone_batch_size = FLAGS.train_batch_size // config.num_clones

    tf.gfile.MakeDirs(FLAGS.train_logdir)
    tf.logging.info('Training segmentation and self-attention on %s set',
                    FLAGS.train_split)
    if FLAGS.weakly:
        tf.logging.info('Training classification on %s set',
                        FLAGS.train_split_cls)
    else:
        tf.logging.info('Training classification on %s set', FLAGS.train_split)
    tf.logging.info('Enforcing consistency constraint on %s set',
                    FLAGS.train_split_cls)

    with tf.Graph().as_default() as graph:
        with tf.device(config.inputs_device()):
            dataset = data_generator.Dataset(
                dataset_name=FLAGS.dataset,
                split_name=FLAGS.train_split,
                dataset_dir=FLAGS.dataset_dir,
                batch_size=clone_batch_size,
                crop_size=[int(sz) for sz in FLAGS.train_crop_size],
                min_resize_value=FLAGS.min_resize_value,
                max_resize_value=FLAGS.max_resize_value,
                resize_factor=FLAGS.resize_factor,
                min_scale_factor=FLAGS.min_scale_factor,
                max_scale_factor=FLAGS.max_scale_factor,
                scale_factor_step_size=FLAGS.scale_factor_step_size,
                model_variant=FLAGS.model_variant,
                num_readers=4,
                is_training=True,
                should_shuffle=True,
                should_repeat=True,
                output_valid=True,
                with_cls=True,
                cls_only=False)

            dataset_cls = data_generator.Dataset(
                dataset_name=FLAGS.dataset,
                split_name=FLAGS.train_split_cls,
                dataset_dir=FLAGS.dataset_dir,
                batch_size=clone_batch_size,
                crop_size=[int(sz) for sz in FLAGS.train_crop_size],
                min_resize_value=FLAGS.min_resize_value,
                max_resize_value=FLAGS.max_resize_value,
                resize_factor=FLAGS.resize_factor,
                min_scale_factor=FLAGS.min_scale_factor,
                max_scale_factor=FLAGS.max_scale_factor,
                scale_factor_step_size=FLAGS.scale_factor_step_size,
                model_variant=FLAGS.model_variant,
                num_readers=4,
                is_training=True,
                should_shuffle=True,
                should_repeat=True,
                with_cls=FLAGS.weakly,
                cls_only=False,
                strong_weak=True)

        # Create the global step on the device storing the variables.
        with tf.device(config.variables_device()):
            global_step = tf.train.get_or_create_global_step()

            # Define the model and create clones.
            model_fn = _build_pseudo_seg
            model_args = (dataset.get_one_shot_iterator(),
                          dataset_cls.get_one_shot_iterator(), {
                              common.OUTPUT_TYPE: dataset.num_of_classes
                          }, dataset.ignore_label, clone_batch_size)
            clones = model_deploy.create_clones(config,
                                                model_fn,
                                                args=model_args)

            # Gather update_ops from the first clone. These contain, for example,
            # the updates for the batch_norm variables created by model_fn.
            first_clone_scope = config.clone_scope(0)
            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS,
                                           first_clone_scope)

        # Gather initial summaries.
        summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))

        # Add summaries for model variables.
        for model_var in tf.model_variables():
            summaries.add(tf.summary.histogram(model_var.op.name, model_var))

        if FLAGS.use_attention:
            summary = graph.get_tensor_by_name(
                ('%s/%s:0' % (first_clone_scope, 'max_prob_weak')).strip('/'))
            summaries.add(tf.summary.histogram('max_prob_weak', summary))

            summary = graph.get_tensor_by_name(
                ('%s/%s:0' %
                 (first_clone_scope, 'max_att_prob_weak')).strip('/'))
            summaries.add(tf.summary.histogram('max_att_prob_weak', summary))

            summary = graph.get_tensor_by_name(
                ('%s/%s:0' % (first_clone_scope, 'max_prob_avg')).strip('/'))
            summaries.add(tf.summary.histogram('max_prob_avg', summary))

            if FLAGS.soft_pseudo_label and FLAGS.temperature != 1.0:
                summary = graph.get_tensor_by_name(
                    ('%s/%s:0' %
                     (first_clone_scope, 'max_prob_avg_t')).strip('/'))
                summaries.add(tf.summary.histogram('max_prob_avg_t', summary))

        # Add summaries for images, labels, semantic predictions
        # Visualize seg image and predictions
        if FLAGS.save_summaries_images:
            summary_image = graph.get_tensor_by_name(
                ('%s/%s:0' %
                 (first_clone_scope, common.IMAGE + '_seg')).strip('/'))
            summaries.add(
                tf.summary.image('samples/%s' % common.IMAGE + '_seg',
                                 summary_image))

            first_clone_label = graph.get_tensor_by_name(
                ('%s/%s:0' %
                 (first_clone_scope, common.LABEL + '_seg')).strip('/'))
            # Scale up summary image pixel values for better visualization.
            pixel_scaling = max(1, 255 // dataset.num_of_classes)
            summary_label = tf.cast(first_clone_label * pixel_scaling,
                                    tf.uint8)
            summaries.add(
                tf.summary.image('samples/%s' % common.LABEL + '_seg',
                                 summary_label))

            first_clone_output = graph.get_tensor_by_name(
                ('%s/%s:0' %
                 (first_clone_scope, common.OUTPUT_TYPE + '_seg')).strip('/'))
            predictions = tf.expand_dims(tf.argmax(first_clone_output, 3), -1)
            summary_predictions = tf.cast(predictions * pixel_scaling,
                                          tf.uint8)
            summaries.add(
                tf.summary.image('samples/%s' % common.OUTPUT_TYPE + '_seg',
                                 summary_predictions))

            # For unlabeled image
            summary_image = graph.get_tensor_by_name(
                ('%s/%s:0' % (first_clone_scope, 'valid_mask')).strip('/'))
            summaries.add(
                tf.summary.image('sanity_check/valid_mask', summary_image))

            summary_image = graph.get_tensor_by_name(
                ('%s/%s:0' % (first_clone_scope, 'weak')).strip('/'))
            summaries.add(tf.summary.image('unlabeled/weak', summary_image))

            summary_image = graph.get_tensor_by_name(
                ('%s/%s:0' % (first_clone_scope, 'strong')).strip('/'))
            summaries.add(tf.summary.image('unlabeled/strong', summary_image))

            first_clone_label = graph.get_tensor_by_name(
                ('%s/%s:0' % (first_clone_scope, 'unlabeled')).strip('/'))
            pixel_scaling = max(1, 255 // dataset.num_of_classes)
            summary_label = tf.cast(first_clone_label * pixel_scaling,
                                    tf.uint8)
            summaries.add(
                tf.summary.image('unlabeled/%s' % common.LABEL, summary_label))

            first_clone_output = graph.get_tensor_by_name(
                ('%s/%s:0' % (first_clone_scope, 'logits_weak')).strip('/'))
            predictions = tf.expand_dims(tf.argmax(first_clone_output, 3), -1)
            summary_predictions = tf.cast(predictions * pixel_scaling,
                                          tf.uint8)
            summaries.add(
                tf.summary.image('unlabeled/logits_weak', summary_predictions))

            first_clone_output = graph.get_tensor_by_name(
                ('%s/%s:0' % (first_clone_scope, 'logits_strong')).strip('/'))
            predictions = tf.expand_dims(tf.argmax(first_clone_output, 3), -1)
            summary_predictions = tf.cast(predictions * pixel_scaling,
                                          tf.uint8)
            summaries.add(
                tf.summary.image('unlabeled/logits_strong',
                                 summary_predictions))

            if FLAGS.use_attention:
                first_clone_output = graph.get_tensor_by_name(
                    ('%s/%s:0' %
                     (first_clone_scope, 'att_logits_weak')).strip('/'))
                predictions = tf.expand_dims(tf.argmax(first_clone_output, 3),
                                             -1)
                predictions = tf.compat.v1.image.resize_bilinear(
                    predictions, [int(sz) for sz in FLAGS.train_crop_size],
                    align_corners=True)
                summary_predictions = tf.cast(predictions * pixel_scaling,
                                              tf.uint8)
                summaries.add(
                    tf.summary.image('att/att_logits_weak',
                                     summary_predictions))

                first_clone_output = graph.get_tensor_by_name(
                    ('%s/%s:0' % (first_clone_scope, 'cam_weak')).strip('/'))
                predictions = tf.expand_dims(tf.argmax(first_clone_output, 3),
                                             -1)
                predictions = tf.compat.v1.image.resize_bilinear(
                    predictions, [int(sz) for sz in FLAGS.train_crop_size],
                    align_corners=True)
                summary_predictions = tf.cast(predictions * pixel_scaling,
                                              tf.uint8)
                summaries.add(
                    tf.summary.image('att/cam_weak', summary_predictions))

                first_clone_output = graph.get_tensor_by_name(
                    ('%s/%s:0' %
                     (first_clone_scope, 'merged_logits')).strip('/'))
                predictions = tf.expand_dims(tf.argmax(first_clone_output, 3),
                                             -1)
                predictions = tf.compat.v1.image.resize_bilinear(
                    predictions, [int(sz) for sz in FLAGS.train_crop_size],
                    align_corners=True)
                summary_predictions = tf.cast(predictions * pixel_scaling,
                                              tf.uint8)
                summaries.add(
                    tf.summary.image('att/merged_logits', summary_predictions))

                first_clone_output = graph.get_tensor_by_name(
                    ('%s/%s:0' %
                     (first_clone_scope, 'att_logits_labeled')).strip('/'))
                predictions = tf.expand_dims(tf.argmax(first_clone_output, 3),
                                             -1)
                predictions = tf.compat.v1.image.resize_bilinear(
                    predictions, [int(sz) for sz in FLAGS.train_crop_size],
                    align_corners=True)
                summary_predictions = tf.cast(predictions * pixel_scaling,
                                              tf.uint8)
                summaries.add(
                    tf.summary.image('att/att_logits_labeled',
                                     summary_predictions))

        # Add summaries for losses.
        for loss in tf.get_collection(tf.GraphKeys.LOSSES, first_clone_scope):
            summaries.add(tf.summary.scalar('losses/%s' % loss.op.name, loss))

        # Monitor pseudo label quality
        summary = graph.get_tensor_by_name(
            ('%s/%s:0' % (first_clone_scope, 'acc_seg')).strip('/'))
        summaries.add(tf.summary.scalar('sanity_check/acc_seg', summary))

        summary = graph.get_tensor_by_name(
            ('%s/%s:0' % (first_clone_scope, 'acc_weak')).strip('/'))
        summaries.add(tf.summary.scalar('sanity_check/acc_weak', summary))

        summary = graph.get_tensor_by_name(
            ('%s/%s:0' % (first_clone_scope, 'acc_strong')).strip('/'))
        summaries.add(tf.summary.scalar('sanity_check/acc_strong', summary))

        if FLAGS.pseudo_label_threshold > 0:
            summary = graph.get_tensor_by_name(
                ('%s/%s:0' % (first_clone_scope, 'acc_pseudo')).strip('/'))
            summaries.add(tf.summary.scalar('sanity_check/acc_pseudo',
                                            summary))

            summary = graph.get_tensor_by_name(
                ('%s/%s:0' %
                 (first_clone_scope, 'acc_strong_confident')).strip('/'))
            summaries.add(
                tf.summary.scalar('sanity_check/acc_strong_confident',
                                  summary))

            summary = graph.get_tensor_by_name(
                ('%s/%s:0' % (first_clone_scope, 'valid_ratio')).strip('/'))
            summaries.add(
                tf.summary.scalar('sanity_check/valid_ratio', summary))

        # Build the optimizer based on the device specification.
        with tf.device(config.optimizer_device()):
            learning_rate = train_utils.get_model_learning_rate(
                FLAGS.learning_policy,
                FLAGS.base_learning_rate,
                FLAGS.learning_rate_decay_step,
                FLAGS.learning_rate_decay_factor,
                FLAGS.training_number_of_steps,
                FLAGS.learning_power,
                FLAGS.slow_start_step,
                FLAGS.slow_start_learning_rate,
                decay_steps=FLAGS.decay_steps,
                end_learning_rate=FLAGS.end_learning_rate)

            summaries.add(tf.summary.scalar('learning_rate', learning_rate))

            if FLAGS.optimizer == 'momentum':
                optimizer = tf.train.MomentumOptimizer(learning_rate,
                                                       FLAGS.momentum)
            elif FLAGS.optimizer == 'adam':
                optimizer = tf.train.AdamOptimizer(
                    learning_rate=FLAGS.adam_learning_rate,
                    epsilon=FLAGS.adam_epsilon)
            else:
                raise ValueError('Unknown optimizer')

        startup_delay_steps = FLAGS.task * FLAGS.startup_delay_steps
        with tf.device(config.variables_device()):
            total_loss, grads_and_vars = model_deploy.optimize_clones(
                clones, optimizer)
            total_loss = tf.check_numerics(total_loss, 'Loss is inf or nan.')
            summaries.add(tf.summary.scalar('total_loss', total_loss))

            # Modify the gradients for biases and last layer variables.
            last_layers = model.get_extra_layer_scopes(
                FLAGS.last_layers_contain_logits_only)
            grad_mult = train_utils.get_model_gradient_multipliers(
                last_layers, FLAGS.last_layer_gradient_multiplier)
            if grad_mult:
                grads_and_vars = slim.learning.multiply_gradients(
                    grads_and_vars, grad_mult)
            # NOTE: Neither last cls nor last seg layer loads pre-trained weights
            last_layers += [
                '{}/logits'.format(FLAGS.model_variant).replace('_beta', '')
            ]

            # Create gradient update op.
            grad_updates = optimizer.apply_gradients(grads_and_vars,
                                                     global_step=global_step)
            update_ops.append(grad_updates)
            update_op = tf.group(*update_ops)
            with tf.control_dependencies([update_op]):
                train_tensor = tf.identity(total_loss, name='train_op')

        # Add the summaries from the first clone. These contain the summaries
        # created by model_fn and either optimize_clones() or _gather_clone_loss().
        summaries |= set(
            tf.get_collection(tf.GraphKeys.SUMMARIES, first_clone_scope))

        # Merge all summaries together.
        summary_op = tf.summary.merge(list(summaries))

        # Soft placement allows ops without a GPU implementation to be placed on the CPU.
        session_config = tf.ConfigProto(allow_soft_placement=True,
                                        log_device_placement=False)
        session_config.gpu_options.allow_growth = True

        # Start the training.
        profile_dir = FLAGS.profile_logdir
        if profile_dir is not None:
            tf.gfile.MakeDirs(profile_dir)

        with contrib_tfprof.ProfileContext(enabled=profile_dir is not None,
                                           profile_dir=profile_dir):
            init_fn = None
            if FLAGS.tf_initial_checkpoint:
                init_fn = train_utils.get_model_init_fn(
                    FLAGS.train_logdir,
                    FLAGS.tf_initial_checkpoint,
                    FLAGS.initialize_last_layer,
                    last_layers,
                    ignore_missing_vars=True)

            slim.learning.train(train_tensor,
                                logdir=FLAGS.train_logdir,
                                log_every_n_steps=FLAGS.log_steps,
                                master=FLAGS.master,
                                number_of_steps=FLAGS.training_number_of_steps,
                                is_chief=(FLAGS.task == 0),
                                session_config=session_config,
                                startup_delay_steps=startup_delay_steps,
                                init_fn=init_fn,
                                summary_op=summary_op,
                                save_summaries_secs=FLAGS.save_summaries_secs,
                                save_interval_secs=FLAGS.save_interval_secs)
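The assertion near the top of this example guards the batch split across clones. A minimal, self-contained sketch of that arithmetic (the helper name is hypothetical, not part of the script):

def per_clone_batch_size(train_batch_size, num_clones):
    # Mirrors the assert above: the global batch must divide evenly across clones (GPUs).
    if train_batch_size % num_clones != 0:
        raise ValueError(
            'Training batch size %d is not divisible by number of clones %d.' %
            (train_batch_size, num_clones))
    return train_batch_size // num_clones

# e.g. a global batch of 16 on 4 GPUs gives 4 images per clone
assert per_clone_batch_size(16, 4) == 4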
Example #2
def main(unused_argv):
    tf.logging.set_verbosity(tf.logging.INFO)
    # Set up deployment (i.e., multi-GPUs and/or multi-replicas).
    config = model_deploy.DeploymentConfig(num_clones=FLAGS.num_clones,
                                           clone_on_cpu=FLAGS.clone_on_cpu,
                                           replica_id=FLAGS.task,
                                           num_replicas=FLAGS.num_replicas,
                                           num_ps_tasks=FLAGS.num_ps_tasks)

    # Split the batch across GPUs.
    assert FLAGS.train_batch_size % config.num_clones == 0, (
        'Training batch size not divisible by number of clones (GPUs).')

    clone_batch_size = FLAGS.train_batch_size // config.num_clones

    tf.gfile.MakeDirs(FLAGS.train_logdir)
    common.outputlogMessage('Training on %s set' % FLAGS.train_split)
    common.outputlogMessage('Dataset: %s' % FLAGS.dataset)
    common.outputlogMessage('train_crop_size: %s' % str(FLAGS.train_crop_size))
    common.outputlogMessage(str(FLAGS.train_crop_size))
    common.outputlogMessage('atrous_rates: %s' % str(FLAGS.atrous_rates))
    common.outputlogMessage('number of classes: %s' % str(FLAGS.num_classes))
    common.outputlogMessage('Ignore label value: %s' % str(FLAGS.ignore_label))
    pid = os.getpid()
    with open('train_py_pid.txt', 'w') as f_obj:
        f_obj.write('%d' % pid)

    with tf.Graph().as_default() as graph:
        with tf.device(config.inputs_device()):
            dataset = data_generator.Dataset(
                dataset_name=FLAGS.dataset,
                split_name=FLAGS.train_split,
                dataset_dir=FLAGS.dataset_dir,
                batch_size=clone_batch_size,
                crop_size=[int(sz) for sz in FLAGS.train_crop_size],
                min_resize_value=FLAGS.min_resize_value,
                max_resize_value=FLAGS.max_resize_value,
                resize_factor=FLAGS.resize_factor,
                min_scale_factor=FLAGS.min_scale_factor,
                max_scale_factor=FLAGS.max_scale_factor,
                scale_factor_step_size=FLAGS.scale_factor_step_size,
                model_variant=FLAGS.model_variant,
                num_readers=4,
                is_training=True,
                should_shuffle=True,
                should_repeat=True,
                num_classes=FLAGS.num_classes,
                ignore_label=FLAGS.ignore_label)

        # Create the global step on the device storing the variables.
        with tf.device(config.variables_device()):
            global_step = tf.train.get_or_create_global_step()

            # Define the model and create clones.
            model_fn = _build_deeplab
            model_args = (dataset.get_one_shot_iterator(), {
                common.OUTPUT_TYPE: dataset.num_of_classes
            }, dataset.ignore_label)
            clones = model_deploy.create_clones(config,
                                                model_fn,
                                                args=model_args)

            # Gather update_ops from the first clone. These contain, for example,
            # the updates for the batch_norm variables created by model_fn.
            first_clone_scope = config.clone_scope(0)
            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS,
                                           first_clone_scope)

        # Gather initial summaries.
        summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))

        # Add summaries for model variables.
        for model_var in tf.model_variables():
            summaries.add(tf.summary.histogram(model_var.op.name, model_var))

        # Add summaries for images, labels, semantic predictions
        if FLAGS.save_summaries_images:
            summary_image = graph.get_tensor_by_name(
                ('%s/%s:0' % (first_clone_scope, common.IMAGE)).strip('/'))
            summaries.add(
                tf.summary.image('samples/%s' % common.IMAGE, summary_image))

            first_clone_label = graph.get_tensor_by_name(
                ('%s/%s:0' % (first_clone_scope, common.LABEL)).strip('/'))
            # Scale up summary image pixel values for better visualization.
            pixel_scaling = max(1, 255 // dataset.num_of_classes)
            summary_label = tf.cast(first_clone_label * pixel_scaling,
                                    tf.uint8)
            summaries.add(
                tf.summary.image('samples/%s' % common.LABEL, summary_label))

            first_clone_output = graph.get_tensor_by_name(
                ('%s/%s:0' %
                 (first_clone_scope, common.OUTPUT_TYPE)).strip('/'))
            predictions = tf.expand_dims(tf.argmax(first_clone_output, 3), -1)

            summary_predictions = tf.cast(predictions * pixel_scaling,
                                          tf.uint8)
            summaries.add(
                tf.summary.image('samples/%s' % common.OUTPUT_TYPE,
                                 summary_predictions))

        # Add summaries for losses.
        for loss in tf.get_collection(tf.GraphKeys.LOSSES, first_clone_scope):
            summaries.add(tf.summary.scalar('losses/%s' % loss.op.name, loss))

        # Build the optimizer based on the device specification.
        with tf.device(config.optimizer_device()):
            learning_rate = train_utils.get_model_learning_rate(
                FLAGS.learning_policy,
                FLAGS.base_learning_rate,
                FLAGS.learning_rate_decay_step,
                FLAGS.learning_rate_decay_factor,
                FLAGS.training_number_of_steps,
                FLAGS.learning_power,
                FLAGS.slow_start_step,
                FLAGS.slow_start_learning_rate,
                decay_steps=FLAGS.decay_steps,
                end_learning_rate=FLAGS.end_learning_rate)

            summaries.add(tf.summary.scalar('learning_rate', learning_rate))

            if FLAGS.optimizer == 'momentum':
                optimizer = tf.train.MomentumOptimizer(learning_rate,
                                                       FLAGS.momentum)
            elif FLAGS.optimizer == 'adam':
                optimizer = tf.train.AdamOptimizer(
                    learning_rate=FLAGS.adam_learning_rate,
                    epsilon=FLAGS.adam_epsilon)
            else:
                raise ValueError('Unknown optimizer')

        if FLAGS.quantize_delay_step >= 0:
            if FLAGS.num_clones > 1:
                raise ValueError(
                    'Quantization doesn\'t support multi-clone yet.')
            contrib_quantize.create_training_graph(
                quant_delay=FLAGS.quantize_delay_step)

        startup_delay_steps = FLAGS.task * FLAGS.startup_delay_steps

        with tf.device(config.variables_device()):
            total_loss, grads_and_vars = model_deploy.optimize_clones(
                clones, optimizer)
            total_loss = tf.check_numerics(total_loss, 'Loss is inf or nan.')
            summaries.add(tf.summary.scalar('total_loss', total_loss))

            # Modify the gradients for biases and last layer variables.
            last_layers = model.get_extra_layer_scopes(
                FLAGS.last_layers_contain_logits_only)
            grad_mult = train_utils.get_model_gradient_multipliers(
                last_layers, FLAGS.last_layer_gradient_multiplier)
            if grad_mult:
                grads_and_vars = slim.learning.multiply_gradients(
                    grads_and_vars, grad_mult)

            # Create gradient update op.
            grad_updates = optimizer.apply_gradients(grads_and_vars,
                                                     global_step=global_step)
            update_ops.append(grad_updates)
            update_op = tf.group(*update_ops)
            with tf.control_dependencies([update_op]):
                train_tensor = tf.identity(total_loss, name='train_op')

        # Add the summaries from the first clone. These contain the summaries
        # created by model_fn and either optimize_clones() or _gather_clone_loss().
        summaries |= set(
            tf.get_collection(tf.GraphKeys.SUMMARIES, first_clone_scope))

        # Merge all summaries together.
        summary_op = tf.summary.merge(list(summaries))

        # Soft placement allows ops without a GPU implementation to be placed on the CPU.
        session_config = tf.ConfigProto(allow_soft_placement=True,
                                        log_device_placement=False)

        # Start the training.
        profile_dir = FLAGS.profile_logdir
        if profile_dir is not None:
            tf.gfile.MakeDirs(profile_dir)

        with contrib_tfprof.ProfileContext(enabled=profile_dir is not None,
                                           profile_dir=profile_dir):
            init_fn = None
            if FLAGS.tf_initial_checkpoint:
                init_fn = train_utils.get_model_init_fn(
                    FLAGS.train_logdir,
                    FLAGS.tf_initial_checkpoint,
                    FLAGS.initialize_last_layer,
                    last_layers,
                    ignore_missing_vars=True)

            slim.learning.train(train_tensor,
                                logdir=FLAGS.train_logdir,
                                log_every_n_steps=FLAGS.log_steps,
                                master=FLAGS.master,
                                number_of_steps=FLAGS.training_number_of_steps,
                                is_chief=(FLAGS.task == 0),
                                session_config=session_config,
                                startup_delay_steps=startup_delay_steps,
                                init_fn=init_fn,
                                summary_op=summary_op,
                                save_summaries_secs=FLAGS.save_summaries_secs,
                                save_interval_secs=FLAGS.save_interval_secs)
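Both examples fetch tensors for summaries by formatting a name and calling graph.get_tensor_by_name. A minimal sketch of that pattern, assuming TF 1.x-style graph mode; the scope and tensor names below are illustrative only:

import tensorflow as tf

graph = tf.Graph()
with graph.as_default():
    with tf.name_scope('clone_0'):
        # A named op inside the first clone's scope, e.g. a loss created by model_fn.
        tf.identity(tf.zeros([]), name='total_loss')

first_clone_scope = 'clone_0'  # config.clone_scope(0) in the examples above
tensor = graph.get_tensor_by_name(
    ('%s/%s:0' % (first_clone_scope, 'total_loss')).strip('/'))
print(tensor.name)  # clone_0/total_loss:0

The .strip('/') keeps the name valid when the clone scope is an empty string (e.g. a single-clone run), where the formatted name would otherwise start with a spurious '/'.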
Example #3
def main(unused_argv):
  tf.logging.set_verbosity(tf.logging.INFO)
  # Set up deployment (i.e., multi-GPUs and/or multi-replicas).
  # Set up the parameters for multi-GPU training.
  config = model_deploy.DeploymentConfig(
      num_clones=FLAGS.num_clones,  # number of GPUs
      clone_on_cpu=FLAGS.clone_on_cpu,  # defaults to False
      replica_id=FLAGS.task,    # task id
      num_replicas=FLAGS.num_replicas,  # defaults to 1
      num_ps_tasks=FLAGS.num_ps_tasks)  # defaults to 0

  # Split the batch across GPUs.
  assert FLAGS.train_batch_size % config.num_clones == 0, (
      'Training batch size not divisible by number of clones (GPUs).')

  clone_batch_size = FLAGS.train_batch_size // config.num_clones    # split the batch evenly across GPUs

  tf.gfile.MakeDirs(FLAGS.train_logdir)     # create the directory for training logs
  tf.logging.info('Training on %s set', FLAGS.train_split)

  with tf.Graph().as_default() as graph:
    with tf.device(config.inputs_device()):
      dataset = data_generator.Dataset(     # define the dataset parameters
          dataset_name=FLAGS.dataset,   # dataset name, e.g. cityscapes
          split_name=FLAGS.train_split,  # tfrecord split used for training, defaults to "train"
          dataset_dir=FLAGS.dataset_dir,   # directory holding the tfrecord files
          batch_size=clone_batch_size,  # per-GPU batch size after the split
          crop_size=[int(sz) for sz in FLAGS.train_crop_size],  # training crop size, e.g. 513,513
          min_resize_value=FLAGS.min_resize_value,  # defaults to None
          max_resize_value=FLAGS.max_resize_value,  # defaults to None
          resize_factor=FLAGS.resize_factor,    # defaults to None
          min_scale_factor=FLAGS.min_scale_factor,   # minimum image scale for data augmentation, defaults to 0.5
          max_scale_factor=FLAGS.max_scale_factor,   # maximum image scale for data augmentation, defaults to 2
          scale_factor_step_size=FLAGS.scale_factor_step_size,      # step between scales, defaults to 0.25 (from 0.5 to 2)
          model_variant=FLAGS.model_variant,    # model variant, e.g. xception_65
          num_readers=4,    # number of data readers; increase with more GPUs to speed up input
          is_training=True,
          should_shuffle=True,
          should_repeat=True)

    # Create the global step on the device storing the variables.
    with tf.device(config.variables_device()):
      # Counter: global_step is incremented by one for every trained batch.
      global_step = tf.train.get_or_create_global_step()

      # Define the model and create clones.
      model_fn = _build_deeplab  # define the deeplab model
      model_args = (dataset.get_one_shot_iterator(), {
          common.OUTPUT_TYPE: dataset.num_of_classes
      }, dataset.ignore_label)  # model arguments
      clones = model_deploy.create_clones(config, model_fn, args=model_args)

      # Gather update_ops from the first clone. These contain, for example,
      # the updates for the batch_norm variables created by model_fn.
      first_clone_scope = config.clone_scope(0)
      update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, first_clone_scope)

    # Gather initial summaries.
    summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))

    # Add summaries for model variables.
    for model_var in tf.model_variables():
      summaries.add(tf.summary.histogram(model_var.op.name, model_var))

    # Add summaries for images, labels, semantic predictions
    if FLAGS.save_summaries_images:      # defaults to False
      summary_image = graph.get_tensor_by_name(
          ('%s/%s:0' % (first_clone_scope, common.IMAGE)).strip('/'))
      summaries.add(
          tf.summary.image('samples/%s' % common.IMAGE, summary_image))

      first_clone_label = graph.get_tensor_by_name(
          ('%s/%s:0' % (first_clone_scope, common.LABEL)).strip('/'))
      # Scale up summary image pixel values for better visualization.
      pixel_scaling = max(1, 255 // dataset.num_of_classes)
      summary_label = tf.cast(first_clone_label * pixel_scaling, tf.uint8)
      summaries.add(
          tf.summary.image('samples/%s' % common.LABEL, summary_label))

      first_clone_output = graph.get_tensor_by_name(
          ('%s/%s:0' % (first_clone_scope, common.OUTPUT_TYPE)).strip('/'))
      predictions = tf.expand_dims(tf.argmax(first_clone_output, 3), -1)

      summary_predictions = tf.cast(predictions * pixel_scaling, tf.uint8)
      summaries.add(
          tf.summary.image(
              'samples/%s' % common.OUTPUT_TYPE, summary_predictions))

    # Add summaries for losses.
    for loss in tf.get_collection(tf.GraphKeys.LOSSES, first_clone_scope):
      summaries.add(tf.summary.scalar('losses/%s' % loss.op.name, loss))

    # Build the optimizer based on the device specification.
    with tf.device(config.optimizer_device()):
      learning_rate = train_utils.get_model_learning_rate(  # get the model learning rate
          FLAGS.learning_policy,    # 'poly' learning rate policy
          FLAGS.base_learning_rate,     # 0.0001
          FLAGS.learning_rate_decay_step,   # decay the learning rate every 2000 steps
          FLAGS.learning_rate_decay_factor,     # 0.1
          FLAGS.training_number_of_steps,   # total number of training steps, 20000
          FLAGS.learning_power,     # poly power, 0.9
          FLAGS.slow_start_step,    # 0
          FLAGS.slow_start_learning_rate,   # 1e-4, learning rate during the slow start
          decay_steps=FLAGS.decay_steps,    # 0.0
          end_learning_rate=FLAGS.end_learning_rate)     # 0.0

      summaries.add(tf.summary.scalar('learning_rate', learning_rate))
      # Optimizer for model training.
      if FLAGS.optimizer == 'momentum':
        optimizer = tf.train.MomentumOptimizer(learning_rate, FLAGS.momentum)
      elif FLAGS.optimizer == 'adam':   # Adam optimizer: adaptive updates using squared-gradient (second-moment) correction
        optimizer = tf.train.AdamOptimizer(
            learning_rate=FLAGS.adam_learning_rate, epsilon=FLAGS.adam_epsilon)
      else:
        raise ValueError('Unknown optimizer')

    if FLAGS.quantize_delay_step >= 0:  # defaults to -1, so this is skipped
      if FLAGS.num_clones > 1:
        raise ValueError('Quantization doesn\'t support multi-clone yet.')
      contrib_quantize.create_training_graph(
          quant_delay=FLAGS.quantize_delay_step)

    startup_delay_steps = FLAGS.task * FLAGS.startup_delay_steps    # FLAGS.startup_delay_steps defaults to 15

    with tf.device(config.variables_device()):
      total_loss, grads_and_vars = model_deploy.optimize_clones(
          clones, optimizer)    # compute total_loss
      total_loss = tf.check_numerics(total_loss, 'Loss is inf or nan.')
      summaries.add(tf.summary.scalar('total_loss', total_loss))

      # Modify the gradients for biases and last layer variables.
      last_layers = model.get_extra_layer_scopes(
          FLAGS.last_layers_contain_logits_only)
      # Get the gradient multipliers.
      grad_mult = train_utils.get_model_gradient_multipliers(
          last_layers, FLAGS.last_layer_gradient_multiplier)
      # grad_mult : {'logits/semantic/biases': 2.0, 'logits/semantic/weights': 1.0}
      if grad_mult:
        grads_and_vars = slim.learning.multiply_gradients(
            grads_and_vars, grad_mult)

      # Create gradient update op.
      grad_updates = optimizer.apply_gradients(     # apply the computed gradients to the variables; returns the update op
          grads_and_vars, global_step=global_step)  # also increments global_step
      update_ops.append(grad_updates)
      update_op = tf.group(*update_ops)
      with tf.control_dependencies([update_op]):
        train_tensor = tf.identity(total_loss, name='train_op')

    # Add the summaries from the first clone. These contain the summaries
    # created by model_fn and either optimize_clones() or _gather_clone_loss().
    summaries |= set(
        tf.get_collection(tf.GraphKeys.SUMMARIES, first_clone_scope))

    # Merge all summaries together.
    summary_op = tf.summary.merge(list(summaries))

    # Soft placement allows ops without a GPU implementation to be placed on the CPU.
    session_config = tf.ConfigProto(
        allow_soft_placement=True, log_device_placement=False)

    # Start the training.
    profile_dir = FLAGS.profile_logdir   # defaults to None
    if profile_dir is not None:
      tf.gfile.MakeDirs(profile_dir)

    with contrib_tfprof.ProfileContext(
        enabled=profile_dir is not None, profile_dir=profile_dir):
      init_fn = None
      if FLAGS.tf_initial_checkpoint:   # load the pre-trained weights
        init_fn = train_utils.get_model_init_fn(
            FLAGS.train_logdir,
            FLAGS.tf_initial_checkpoint,
            FLAGS.initialize_last_layer,
            last_layers,
            ignore_missing_vars=True)

      slim.learning.train(
          train_tensor,
          logdir=FLAGS.train_logdir,
          log_every_n_steps=FLAGS.log_steps,
          master=FLAGS.master,
          number_of_steps=FLAGS.training_number_of_steps,
          is_chief=(FLAGS.task == 0),
          session_config=session_config,
          startup_delay_steps=startup_delay_steps,
          init_fn=init_fn,
          summary_op=summary_op,
          save_summaries_secs=FLAGS.save_summaries_secs,
          save_interval_secs=FLAGS.save_interval_secs)
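For reference, the label and prediction visualizations in all three examples stretch integer class ids toward the 0-255 range so that classes are distinguishable in TensorBoard image summaries. A tiny worked example, assuming 21 classes as in PASCAL VOC:

num_of_classes = 21
pixel_scaling = max(1, 255 // num_of_classes)  # 255 // 21 == 12
label_value = 15                               # some class id
print(label_value * pixel_scaling)             # 180, rendered as a mid-gray pixel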