Example #1
def main(unused_argv):
    tf.logging.set_verbosity(tf.logging.INFO)
    # Set up deployment (i.e., multi-GPUs and/or multi-replicas).
    config = model_deploy.DeploymentConfig(num_clones=FLAGS.num_clones,
                                           clone_on_cpu=FLAGS.clone_on_cpu,
                                           replica_id=FLAGS.task,
                                           num_replicas=FLAGS.num_replicas,
                                           num_ps_tasks=FLAGS.num_ps_tasks)

    # Split the batch across GPUs.
    assert FLAGS.train_batch_size % config.num_clones == 0, (
        'Training batch size not divisible by number of clones (GPUs).')

    clone_batch_size = FLAGS.train_batch_size // config.num_clones

    tf.gfile.MakeDirs(FLAGS.train_logdir)
    common.outputlogMessage('Training on %s set' % FLAGS.train_split)
    common.outputlogMessage('Dataset: %s' % FLAGS.dataset)
    common.outputlogMessage('train_crop_size: %s' % str(FLAGS.train_crop_size))
    common.outputlogMessage('atrous_rates: %s' % str(FLAGS.atrous_rates))
    common.outputlogMessage('number of classes: %s' % str(FLAGS.num_classes))
    common.outputlogMessage('Ignore label value: %s' % str(FLAGS.ignore_label))
    # Record this process id so external scripts can find or stop training.
    pid = os.getpid()
    with open('train_py_pid.txt', 'w') as f_obj:
        f_obj.write('%d' % pid)

    with tf.Graph().as_default() as graph:
        with tf.device(config.inputs_device()):
            dataset = data_generator.Dataset(
                dataset_name=FLAGS.dataset,
                split_name=FLAGS.train_split,
                dataset_dir=FLAGS.dataset_dir,
                batch_size=clone_batch_size,
                crop_size=[int(sz) for sz in FLAGS.train_crop_size],
                min_resize_value=FLAGS.min_resize_value,
                max_resize_value=FLAGS.max_resize_value,
                resize_factor=FLAGS.resize_factor,
                min_scale_factor=FLAGS.min_scale_factor,
                max_scale_factor=FLAGS.max_scale_factor,
                scale_factor_step_size=FLAGS.scale_factor_step_size,
                model_variant=FLAGS.model_variant,
                num_readers=4,
                is_training=True,
                should_shuffle=True,
                should_repeat=True,
                num_classes=FLAGS.num_classes,
                ignore_label=FLAGS.ignore_label)

        # Create the global step on the device storing the variables.
        with tf.device(config.variables_device()):
            global_step = tf.train.get_or_create_global_step()

            # Define the model and create clones.
            model_fn = _build_deeplab
            model_args = (dataset.get_one_shot_iterator(), {
                common.OUTPUT_TYPE: dataset.num_of_classes
            }, dataset.ignore_label)
            clones = model_deploy.create_clones(config,
                                                model_fn,
                                                args=model_args)

            # Gather update_ops from the first clone. These contain, for example,
            # the updates for the batch_norm variables created by model_fn.
            first_clone_scope = config.clone_scope(0)
            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS,
                                           first_clone_scope)

        # Gather initial summaries.
        summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))

        # Add summaries for model variables.
        for model_var in tf.model_variables():
            summaries.add(tf.summary.histogram(model_var.op.name, model_var))

        # Add summaries for images, labels, semantic predictions
        if FLAGS.save_summaries_images:
            summary_image = graph.get_tensor_by_name(
                ('%s/%s:0' % (first_clone_scope, common.IMAGE)).strip('/'))
            summaries.add(
                tf.summary.image('samples/%s' % common.IMAGE, summary_image))

            first_clone_label = graph.get_tensor_by_name(
                ('%s/%s:0' % (first_clone_scope, common.LABEL)).strip('/'))
            # Scale up summary image pixel values for better visualization.
            pixel_scaling = max(1, 255 // dataset.num_of_classes)
            summary_label = tf.cast(first_clone_label * pixel_scaling,
                                    tf.uint8)
            summaries.add(
                tf.summary.image('samples/%s' % common.LABEL, summary_label))

            first_clone_output = graph.get_tensor_by_name(
                ('%s/%s:0' %
                 (first_clone_scope, common.OUTPUT_TYPE)).strip('/'))
            predictions = tf.expand_dims(tf.argmax(first_clone_output, 3), -1)

            summary_predictions = tf.cast(predictions * pixel_scaling,
                                          tf.uint8)
            summaries.add(
                tf.summary.image('samples/%s' % common.OUTPUT_TYPE,
                                 summary_predictions))

        # Add summaries for losses.
        for loss in tf.get_collection(tf.GraphKeys.LOSSES, first_clone_scope):
            summaries.add(tf.summary.scalar('losses/%s' % loss.op.name, loss))

        # Build the optimizer based on the device specification.
        with tf.device(config.optimizer_device()):
            learning_rate = train_utils.get_model_learning_rate(
                FLAGS.learning_policy,
                FLAGS.base_learning_rate,
                FLAGS.learning_rate_decay_step,
                FLAGS.learning_rate_decay_factor,
                FLAGS.training_number_of_steps,
                FLAGS.learning_power,
                FLAGS.slow_start_step,
                FLAGS.slow_start_learning_rate,
                decay_steps=FLAGS.decay_steps,
                end_learning_rate=FLAGS.end_learning_rate)

            summaries.add(tf.summary.scalar('learning_rate', learning_rate))

            if FLAGS.optimizer == 'momentum':
                optimizer = tf.train.MomentumOptimizer(learning_rate,
                                                       FLAGS.momentum)
            elif FLAGS.optimizer == 'adam':
                optimizer = tf.train.AdamOptimizer(
                    learning_rate=FLAGS.adam_learning_rate,
                    epsilon=FLAGS.adam_epsilon)
            else:
                raise ValueError('Unknown optimizer')

        if FLAGS.quantize_delay_step >= 0:
            if FLAGS.num_clones > 1:
                raise ValueError(
                    'Quantization doesn\'t support multi-clone yet.')
            contrib_quantize.create_training_graph(
                quant_delay=FLAGS.quantize_delay_step)

        startup_delay_steps = FLAGS.task * FLAGS.startup_delay_steps

        with tf.device(config.variables_device()):
            total_loss, grads_and_vars = model_deploy.optimize_clones(
                clones, optimizer)
            total_loss = tf.check_numerics(total_loss, 'Loss is inf or nan.')
            summaries.add(tf.summary.scalar('total_loss', total_loss))

            # Modify the gradients for biases and last layer variables.
            last_layers = model.get_extra_layer_scopes(
                FLAGS.last_layers_contain_logits_only)
            grad_mult = train_utils.get_model_gradient_multipliers(
                last_layers, FLAGS.last_layer_gradient_multiplier)
            if grad_mult:
                grads_and_vars = slim.learning.multiply_gradients(
                    grads_and_vars, grad_mult)

            # Create gradient update op.
            grad_updates = optimizer.apply_gradients(grads_and_vars,
                                                     global_step=global_step)
            update_ops.append(grad_updates)
            update_op = tf.group(*update_ops)
            with tf.control_dependencies([update_op]):
                train_tensor = tf.identity(total_loss, name='train_op')

        # Add the summaries from the first clone. These contain the summaries
        # created by model_fn and either optimize_clones() or _gather_clone_loss().
        summaries |= set(
            tf.get_collection(tf.GraphKeys.SUMMARIES, first_clone_scope))

        # Merge all summaries together.
        summary_op = tf.summary.merge(list(summaries))

        # Soft placement allows ops without a GPU implementation to be placed on the CPU.
        session_config = tf.ConfigProto(allow_soft_placement=True,
                                        log_device_placement=False)

        # Start the training.
        profile_dir = FLAGS.profile_logdir
        if profile_dir is not None:
            tf.gfile.MakeDirs(profile_dir)

        with contrib_tfprof.ProfileContext(enabled=profile_dir is not None,
                                           profile_dir=profile_dir):
            init_fn = None
            if FLAGS.tf_initial_checkpoint:
                init_fn = train_utils.get_model_init_fn(
                    FLAGS.train_logdir,
                    FLAGS.tf_initial_checkpoint,
                    FLAGS.initialize_last_layer,
                    last_layers,
                    ignore_missing_vars=True)

            slim.learning.train(train_tensor,
                                logdir=FLAGS.train_logdir,
                                log_every_n_steps=FLAGS.log_steps,
                                master=FLAGS.master,
                                number_of_steps=FLAGS.training_number_of_steps,
                                is_chief=(FLAGS.task == 0),
                                session_config=session_config,
                                startup_delay_steps=startup_delay_steps,
                                init_fn=init_fn,
                                summary_op=summary_op,
                                save_summaries_secs=FLAGS.save_summaries_secs,
                                save_interval_secs=FLAGS.save_interval_secs)
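
A note on the batch split at the top of main(): every clone (GPU) must receive the same number of examples, which is why the assert requires the global batch size to divide evenly. A minimal, framework-free sketch of that logic (split_batch is a hypothetical helper, not part of this code):

def split_batch(train_batch_size, num_clones):
    # Mirrors the assert in main(): the global batch must divide evenly
    # so that every clone (GPU) sees the same per-step batch size.
    if train_batch_size % num_clones != 0:
        raise ValueError(
            'Training batch size not divisible by number of clones (GPUs).')
    return train_batch_size // num_clones

print(split_batch(16, 4))  # -> 4 examples per clone
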
Example #2
def _train_deeplab_model(iterator, num_of_classes, ignore_label):
    """Trains the deeplab model.

  Args:
    iterator: An iterator of type tf.data.Iterator for images and labels.
    num_of_classes: Number of classes for the dataset.
    ignore_label: Ignore label for the dataset.

  Returns:
    train_tensor: A tensor to update the model variables.
    summary_op: An operation to log the summaries.
  """
    global_step = tf.train.get_or_create_global_step()
    summaries = []

    learning_rate = train_utils.get_model_learning_rate(
        FLAGS.learning_policy, FLAGS.base_learning_rate,
        FLAGS.learning_rate_decay_step, FLAGS.learning_rate_decay_factor,
        FLAGS.training_number_of_steps, FLAGS.learning_power,
        FLAGS.slow_start_step, FLAGS.slow_start_learning_rate)
    summaries.append(tf.summary.scalar('learning_rate', learning_rate))

    optimizer = tf.train.MomentumOptimizer(learning_rate, FLAGS.momentum)

    tower_grads = []
    tower_summaries = None
    for i in range(FLAGS.num_clones):
        with tf.device('/gpu:%d' % i):
            with tf.name_scope('clone_%d' % i) as scope:
                loss = _tower_loss(iterator=iterator,
                                   num_of_classes=num_of_classes,
                                   ignore_label=ignore_label,
                                   scope=scope,
                                   reuse_variable=(i != 0))
                grads = optimizer.compute_gradients(loss)
                tower_grads.append(grads)

                # Retain the summaries from the first tower.
                if not i:
                    tower_summaries = tf.summary.merge_all(scope=scope)

    with tf.device('/cpu:0'):
        grads_and_vars = _average_gradients(tower_grads)
        if tower_summaries is not None:
            summaries.append(tower_summaries)

        # Modify the gradients for biases and last layer variables.
        last_layers = model.get_extra_layer_scopes(
            FLAGS.last_layers_contain_logits_only)
        grad_mult = train_utils.get_model_gradient_multipliers(
            last_layers, FLAGS.last_layer_gradient_multiplier)
        if grad_mult:
            grads_and_vars = tf.contrib.training.multiply_gradients(
                grads_and_vars, grad_mult)

        # Create gradient update op.
        grad_updates = optimizer.apply_gradients(grads_and_vars,
                                                 global_step=global_step)

        # Gather update_ops. These contain, for example,
        # the updates for the batch_norm variables created by model_fn.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        update_ops.append(grad_updates)
        update_op = tf.group(*update_ops)

        total_loss = tf.losses.get_total_loss(add_regularization_losses=True)

        # Print total loss to the terminal.
        # This implementation is mirrored from tf.slim.summaries.
        should_log = math_ops.equal(math_ops.mod(global_step, FLAGS.log_steps),
                                    0)
        total_loss = tf.cond(
            should_log,
            lambda: tf.Print(total_loss, [total_loss], 'Total loss is: '),
            lambda: total_loss)

        summaries.append(tf.summary.scalar('total_loss', total_loss))

        with tf.control_dependencies([update_op]):
            train_tensor = tf.identity(total_loss, name='train_op')
        summary_op = tf.summary.merge(summaries)

    return train_tensor, summary_op
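
This example averages per-tower gradients with _average_gradients, which is not shown in the snippet. A minimal TF 1.x-style sketch of such a helper, assuming every tower returns its (gradient, variable) pairs in the same order and that no gradient is None:

import tensorflow as tf  # assumes TF 1.x

def average_gradients(tower_grads):
    # tower_grads: one list of (grad, var) pairs per tower.
    average_grads = []
    # zip(*tower_grads) groups the i-th (grad, var) pair of every tower.
    for grad_and_vars in zip(*tower_grads):
        grads = [tf.expand_dims(g, 0) for g, _ in grad_and_vars]
        grad = tf.reduce_mean(tf.concat(grads, axis=0), axis=0)
        # Variables are shared across towers, so the first tower's var suffices.
        average_grads.append((grad, grad_and_vars[0][1]))
    return average_grads
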
Example #3
def main(unused_argv):
  tf.logging.set_verbosity(tf.logging.INFO)
  # Set up deployment (i.e., multi-GPUs and/or multi-replicas).
  config = model_deploy.DeploymentConfig(
      num_clones=FLAGS.num_clones,
      clone_on_cpu=FLAGS.clone_on_cpu,
      replica_id=FLAGS.task,
      num_replicas=FLAGS.num_replicas,
      num_ps_tasks=FLAGS.num_ps_tasks)

  # Split the batch across GPUs.
  assert FLAGS.train_batch_size % config.num_clones == 0, (
      'Training batch size not divisible by number of clones (GPUs).')

  clone_batch_size = FLAGS.train_batch_size // config.num_clones

  # Get dataset-dependent information.
  dataset = segmentation_dataset.get_dataset(
      FLAGS.dataset, FLAGS.train_split, dataset_dir=FLAGS.dataset_dir)

  tf.gfile.MakeDirs(FLAGS.train_logdir)
  tf.logging.info('Training on %s set', FLAGS.train_split)

  with tf.Graph().as_default() as graph:
    with tf.device(config.inputs_device()):
      samples = input_generator.get(
          dataset,
          FLAGS.train_crop_size,
          clone_batch_size,
          min_resize_value=FLAGS.min_resize_value,
          max_resize_value=FLAGS.max_resize_value,
          resize_factor=FLAGS.resize_factor,
          min_scale_factor=FLAGS.min_scale_factor,
          max_scale_factor=FLAGS.max_scale_factor,
          scale_factor_step_size=FLAGS.scale_factor_step_size,
          dataset_split=FLAGS.train_split,
          is_training=True,
          model_variant=FLAGS.model_variant)
      inputs_queue = prefetch_queue.prefetch_queue(
          samples, capacity=128 * config.num_clones)

    # Create the global step on the device storing the variables.
    with tf.device(config.variables_device()):
      global_step = tf.train.get_or_create_global_step()

      # Define the model and create clones.
      model_fn = _build_deeplab
      model_args = (inputs_queue, {
          common.OUTPUT_TYPE: dataset.num_classes
      }, dataset.ignore_label)
      clones = model_deploy.create_clones(config, model_fn, args=model_args)

      # Gather update_ops from the first clone. These contain, for example,
      # the updates for the batch_norm variables created by model_fn.
      first_clone_scope = config.clone_scope(0)
      update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, first_clone_scope)

    # Gather initial summaries.
    summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))

    # Add summaries for model variables.
    for model_var in slim.get_model_variables():
      summaries.add(tf.summary.histogram(model_var.op.name, model_var))

    # Add summaries for images, labels, semantic predictions
    if FLAGS.save_summaries_images:
      summary_image = graph.get_tensor_by_name(
          ('%s/%s:0' % (first_clone_scope, common.IMAGE)).strip('/'))
      summaries.add(
          tf.summary.image('samples/%s' % common.IMAGE, summary_image))

      first_clone_label = graph.get_tensor_by_name(
          ('%s/%s:0' % (first_clone_scope, common.LABEL)).strip('/'))
      # Scale up summary image pixel values for better visualization.
      pixel_scaling = max(1, 255 // dataset.num_classes)
      summary_label = tf.cast(first_clone_label * pixel_scaling, tf.uint8)
      summaries.add(
          tf.summary.image('samples/%s' % common.LABEL, summary_label))

      first_clone_output = graph.get_tensor_by_name(
          ('%s/%s:0' % (first_clone_scope, common.OUTPUT_TYPE)).strip('/'))
      predictions = tf.expand_dims(tf.argmax(first_clone_output, 3), -1)

      summary_predictions = tf.cast(predictions * pixel_scaling, tf.uint8)
      summaries.add(
          tf.summary.image(
              'samples/%s' % common.OUTPUT_TYPE, summary_predictions))

    # Add summaries for losses.
    for loss in tf.get_collection(tf.GraphKeys.LOSSES, first_clone_scope):
      summaries.add(tf.summary.scalar('losses/%s' % loss.op.name, loss))

    # Build the optimizer based on the device specification.
    with tf.device(config.optimizer_device()):
      learning_rate = train_utils.get_model_learning_rate(
          FLAGS.learning_policy, FLAGS.base_learning_rate,
          FLAGS.learning_rate_decay_step, FLAGS.learning_rate_decay_factor,
          FLAGS.training_number_of_steps, FLAGS.learning_power,
          FLAGS.slow_start_step, FLAGS.slow_start_learning_rate)
      optimizer = tf.train.MomentumOptimizer(learning_rate, FLAGS.momentum)
      summaries.add(tf.summary.scalar('learning_rate', learning_rate))

    startup_delay_steps = FLAGS.task * FLAGS.startup_delay_steps
    for variable in slim.get_model_variables():
      summaries.add(tf.summary.histogram(variable.op.name, variable))

    with tf.device(config.variables_device()):
      total_loss, grads_and_vars = model_deploy.optimize_clones(
          clones, optimizer)
      total_loss = tf.check_numerics(total_loss, 'Loss is inf or nan.')
      summaries.add(tf.summary.scalar('total_loss', total_loss))

      # Modify the gradients for biases and last layer variables.
      last_layers = model.get_extra_layer_scopes(
          FLAGS.last_layers_contain_logits_only)
      grad_mult = train_utils.get_model_gradient_multipliers(
          last_layers, FLAGS.last_layer_gradient_multiplier)
      if grad_mult:
        grads_and_vars = slim.learning.multiply_gradients(
            grads_and_vars, grad_mult)

      # Create gradient update op.
      grad_updates = optimizer.apply_gradients(
          grads_and_vars, global_step=global_step)
      update_ops.append(grad_updates)
      update_op = tf.group(*update_ops)
      with tf.control_dependencies([update_op]):
        train_tensor = tf.identity(total_loss, name='train_op')

    # Add the summaries from the first clone. These contain the summaries
    # created by model_fn and either optimize_clones() or _gather_clone_loss().
    summaries |= set(
        tf.get_collection(tf.GraphKeys.SUMMARIES, first_clone_scope))

    # Merge all summaries together.
    summary_op = tf.summary.merge(list(summaries))

    # Soft placement allows ops without a GPU implementation to be placed on the CPU.
    session_config = tf.ConfigProto(
        allow_soft_placement=True, log_device_placement=False)

    # Start the training.
    slim.learning.train(
        train_tensor,
        logdir=FLAGS.train_logdir,
        log_every_n_steps=FLAGS.log_steps,
        master=FLAGS.master,
        number_of_steps=FLAGS.training_number_of_steps,
        is_chief=(FLAGS.task == 0),
        session_config=session_config,
        startup_delay_steps=startup_delay_steps,
        init_fn=train_utils.get_model_init_fn(
            FLAGS.train_logdir,
            FLAGS.tf_initial_checkpoint,
            FLAGS.initialize_last_layer,
            last_layers,
            ignore_missing_vars=True),
        summary_op=summary_op,
        save_summaries_secs=FLAGS.save_summaries_secs,
        save_interval_secs=FLAGS.save_interval_secs)
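
The image summaries above multiply integer class ids by pixel_scaling so that small label values become visible gray levels in TensorBoard. A small numpy sketch of the same arithmetic:

import numpy as np

num_classes = 21  # e.g. PASCAL VOC
pixel_scaling = max(1, 255 // num_classes)  # 12 when there are 21 classes

labels = np.array([[0, 1], [20, 5]], dtype=np.int64)
summary_label = (labels * pixel_scaling).astype(np.uint8)
print(summary_label)  # [[  0  12] [240  60]] -- spread across the 0-255 range
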
Example #4
def main(unused_argv):

  datasetDescriptor = None
  if FLAGS.config and os.path.isfile(FLAGS.config):
    with open(FLAGS.config) as f:
      trainingConfig = json.load(f)
      for key in trainingConfig:
        if key in FLAGS:
          FLAGS[key].value = trainingConfig[key]
        elif key == 'DatasetDescriptor':
          datasetDescriptor = segmentation_dataset.DatasetDescriptor(
                                name=trainingConfig[key]['name'],
                                splits_to_sizes=trainingConfig[key]['splits_to_sizes'],
                                num_classes=trainingConfig[key]['num_classes'],
                                ignore_label=trainingConfig[key]['ignore_label'],
                              )

  assert FLAGS.dataset_dir, (
      'flag --dataset_dir=None: Flag --dataset_dir must be specified.')

  assert FLAGS.train_logdir, (
      'flag --train_logdir=None: Flag --train_logdir must be specified.')

  tf.logging.set_verbosity(tf.logging.INFO)
  # Set up deployment (i.e., multi-GPUs and/or multi-replicas).
  config = model_deploy.DeploymentConfig(
      num_clones=FLAGS.num_clones,
      clone_on_cpu=FLAGS.clone_on_cpu,
      replica_id=FLAGS.task,
      num_replicas=FLAGS.num_replicas,
      num_ps_tasks=FLAGS.num_ps_tasks)

  # Split the batch across GPUs.
  assert FLAGS.train_batch_size % config.num_clones == 0, (
      'Training batch size not divisible by number of clones (GPUs).')

  clone_batch_size = FLAGS.train_batch_size // config.num_clones

  if datasetDescriptor is None:
    datasetDescriptor = FLAGS.dataset

  # Get dataset-dependent information.
  dataset = segmentation_dataset.get_dataset(
      datasetDescriptor, FLAGS.train_split, dataset_dir=FLAGS.dataset_dir)

  tf.gfile.MakeDirs(FLAGS.train_logdir)
  tf.logging.info('Training on %s set', FLAGS.train_split)

  with tf.Graph().as_default() as graph:
    with tf.device(config.inputs_device()):
      samples = input_generator.get(
          dataset,
          FLAGS.train_crop_size,
          clone_batch_size,
          min_resize_value=FLAGS.min_resize_value,
          max_resize_value=FLAGS.max_resize_value,
          resize_factor=FLAGS.resize_factor,
          min_scale_factor=FLAGS.min_scale_factor,
          max_scale_factor=FLAGS.max_scale_factor,
          scale_factor_step_size=FLAGS.scale_factor_step_size,
          dataset_split=FLAGS.train_split,
          is_training=True,
          model_variant=FLAGS.model_variant)
      inputs_queue = prefetch_queue.prefetch_queue(
          samples, capacity=128 * config.num_clones)

    # Create the global step on the device storing the variables.
    with tf.device(config.variables_device()):
      global_step = tf.train.get_or_create_global_step()

      # Define the model and create clones.
      model_fn = _build_deeplab
      model_args = (inputs_queue, {
          common.OUTPUT_TYPE: dataset.num_classes
      }, dataset.ignore_label)
      clones = model_deploy.create_clones(config, model_fn, args=model_args)

      # Gather update_ops from the first clone. These contain, for example,
      # the updates for the batch_norm variables created by model_fn.
      first_clone_scope = config.clone_scope(0)
      update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, first_clone_scope)

    # Gather initial summaries.
    summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))

    # Add summaries for model variables.
    for model_var in slim.get_model_variables():
      summaries.add(tf.summary.histogram(model_var.op.name, model_var))

    # Add summaries for images, labels, semantic predictions
    if FLAGS.save_summaries_images:
      summary_image = graph.get_tensor_by_name(
          ('%s/%s:0' % (first_clone_scope, common.IMAGE)).strip('/'))
      summaries.add(
          tf.summary.image('samples/%s' % common.IMAGE, summary_image))

      first_clone_label = graph.get_tensor_by_name(
          ('%s/%s:0' % (first_clone_scope, common.LABEL)).strip('/'))
      # Scale up summary image pixel values for better visualization.
      pixel_scaling = max(1, 255 // dataset.num_classes)
      summary_label = tf.cast(first_clone_label * pixel_scaling, tf.uint8)
      summaries.add(
          tf.summary.image('samples/%s' % common.LABEL, summary_label))

      first_clone_output = graph.get_tensor_by_name(
          ('%s/%s:0' % (first_clone_scope, common.OUTPUT_TYPE)).strip('/'))
      predictions = tf.expand_dims(tf.argmax(first_clone_output, 3), -1)

      summary_predictions = tf.cast(predictions * pixel_scaling, tf.uint8)
      summaries.add(
          tf.summary.image(
              'samples/%s' % common.OUTPUT_TYPE, summary_predictions))

    # Add summaries for losses.
    for loss in tf.get_collection(tf.GraphKeys.LOSSES, first_clone_scope):
      summaries.add(tf.summary.scalar('losses/%s' % loss.op.name, loss))

    # Build the optimizer based on the device specification.
    with tf.device(config.optimizer_device()):
      learning_rate = train_utils.get_model_learning_rate(
          FLAGS.learning_policy, FLAGS.base_learning_rate,
          FLAGS.learning_rate_decay_step, FLAGS.learning_rate_decay_factor,
          FLAGS.training_number_of_steps, FLAGS.learning_power,
          FLAGS.slow_start_step, FLAGS.slow_start_learning_rate)
      optimizer = tf.train.MomentumOptimizer(learning_rate, FLAGS.momentum)
      summaries.add(tf.summary.scalar('learning_rate', learning_rate))

    startup_delay_steps = FLAGS.task * FLAGS.startup_delay_steps
    for variable in slim.get_model_variables():
      summaries.add(tf.summary.histogram(variable.op.name, variable))

    with tf.device(config.variables_device()):
      total_loss, grads_and_vars = model_deploy.optimize_clones(
          clones, optimizer)
      total_loss = tf.check_numerics(total_loss, 'Loss is inf or nan.')
      summaries.add(tf.summary.scalar('total_loss', total_loss))

      # Modify the gradients for biases and last layer variables.
      last_layers = model.get_extra_layer_scopes(
          FLAGS.last_layers_contain_logits_only)
      grad_mult = train_utils.get_model_gradient_multipliers(
          last_layers, FLAGS.last_layer_gradient_multiplier)
      if grad_mult:
        grads_and_vars = slim.learning.multiply_gradients(
            grads_and_vars, grad_mult)

      # Create gradient update op.
      grad_updates = optimizer.apply_gradients(
          grads_and_vars, global_step=global_step)
      update_ops.append(grad_updates)
      update_op = tf.group(*update_ops)
      with tf.control_dependencies([update_op]):
        train_tensor = tf.identity(total_loss, name='train_op')

    # Add the summaries from the first clone. These contain the summaries
    # created by model_fn and either optimize_clones() or _gather_clone_loss().
    summaries |= set(
        tf.get_collection(tf.GraphKeys.SUMMARIES, first_clone_scope))

    # Merge all summaries together.
    summary_op = tf.summary.merge(list(summaries))

    # Soft placement allows ops without a GPU implementation to be placed on the CPU.
    session_config = tf.ConfigProto(
        allow_soft_placement=True, log_device_placement=False)

    # Start the training.
    slim.learning.train(
        train_tensor,
        logdir=FLAGS.train_logdir,
        log_every_n_steps=FLAGS.log_steps,
        master=FLAGS.master,
        number_of_steps=FLAGS.training_number_of_steps,
        is_chief=(FLAGS.task == 0),
        session_config=session_config,
        startup_delay_steps=startup_delay_steps,
        init_fn=train_utils.get_model_init_fn(
            FLAGS.train_logdir,
            FLAGS.tf_initial_checkpoint,
            FLAGS.initialize_last_layer,
            last_layers,
            ignore_missing_vars=True),
        summary_op=summary_op,
        save_summaries_secs=FLAGS.save_summaries_secs,
        save_interval_secs=FLAGS.save_interval_secs)
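
Example #4 differs from #3 mainly in letting a JSON file override flag values and optionally supply a DatasetDescriptor. A self-contained sketch of the same override pattern, using a plain dict in place of real absl flags (the file name and default values here are hypothetical):

import json

FLAGS = {'train_batch_size': 8, 'base_learning_rate': 1e-4}  # stand-in for parsed flags

def apply_training_config(path, flags):
    with open(path) as f:
        training_config = json.load(f)
    for key, value in training_config.items():
        if key in flags:  # only keys matching known flags are overridden
            flags[key] = value
    return flags

# apply_training_config('config.json', FLAGS) would update matching keys
# and skip unknown ones, like the FLAGS[key].value loop above.
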
Example #5
def main(unused_argv):
    tf.logging.set_verbosity(tf.logging.INFO)

    if FLAGS.batch_iter < 1:
        FLAGS.batch_iter = 1
    if FLAGS.batch_iter != 1:
        if not (FLAGS.num_clones == 1 and FLAGS.num_replicas == 1):
            raise NotImplementedError(
                "train.py: **NOTE** -- train_utils.train_step_custom may not work with parallel GPUs / clones > 1! Be sure you are only using one GPU."
            )

    print('\ntrain.py: Accumulating gradients over {} iterations\n'.format(
        FLAGS.batch_iter))

    # Set up deployment (i.e., multi-GPUs and/or multi-replicas).
    config = model_deploy.DeploymentConfig(num_clones=FLAGS.num_clones,
                                           clone_on_cpu=FLAGS.clone_on_cpu,
                                           replica_id=FLAGS.task,
                                           num_replicas=FLAGS.num_replicas,
                                           num_ps_tasks=FLAGS.num_ps_tasks)

    # Split the batch across GPUs.
    assert FLAGS.train_batch_size % config.num_clones == 0, (
        'Training batch size not divisible by number of clones (GPUs).')

    clone_batch_size = FLAGS.train_batch_size // config.num_clones

    # Get dataset-dependent information.
    dataset = segmentation_dataset.get_dataset(
        FLAGS.dataset,
        FLAGS.train_split,
        dataset_dir=FLAGS.dataset_dir,
    )

    tf.gfile.MakeDirs(FLAGS.train_logdir)
    tf.logging.info('Training on %s set', FLAGS.train_split)

    with tf.Graph().as_default() as graph:
        with tf.device(config.inputs_device()):
            samples = input_generator.get(
                dataset,
                FLAGS.train_crop_size,
                clone_batch_size,
                min_resize_value=FLAGS.min_resize_value,
                max_resize_value=FLAGS.max_resize_value,
                resize_factor=FLAGS.resize_factor,
                min_scale_factor=FLAGS.min_scale_factor,
                max_scale_factor=FLAGS.max_scale_factor,
                scale_factor_step_size=FLAGS.scale_factor_step_size,
                dataset_split=FLAGS.train_split,
                is_training=True,
                model_variant=FLAGS.model_variant)
            inputs_queue = prefetch_queue.prefetch_queue(samples,
                                                         capacity=128 *
                                                         config.num_clones)

        # Create the global step on the device storing the variables.
        with tf.device(config.variables_device()):
            global_step = tf.train.get_or_create_global_step()

            # Define the model and create clones.
            model_fn = _build_deeplab
            if FLAGS.class_balanced_loss:
                print(
                    'train.py: class_balanced_loss=True. Reading loss weights from segmentation_dataset.py'
                )
            else:
                print(
                    'train.py: class_balanced_loss=False. Setting loss weights to 1.0 for every class.'
                )
                dataset.loss_weight = 1.0

            # _build_deeplab has model args:
            # (inputs_queue, outputs_to_num_classes, ignore_label, loss_weight)

            outputs_to_num_classes = {common.OUTPUT_TYPE: dataset.num_classes}

            model_args = (inputs_queue,
                          outputs_to_num_classes,
                          dataset.ignore_label, dataset.loss_weight)
            clones = model_deploy.create_clones(config,
                                                model_fn,
                                                args=model_args)

            # Gather update_ops from the first clone. These contain, for example,
            # the updates for the batch_norm variables created by model_fn.
            first_clone_scope = config.clone_scope(0)
            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS,
                                           first_clone_scope)

        # Gather initial summaries.
        summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))

        # Add summaries for model variables.
        for model_var in slim.get_model_variables():
            summaries.add(tf.summary.histogram(model_var.op.name, model_var))

        # Add summaries for images, labels, semantic predictions
        if FLAGS.save_summaries_images:
            summary_image = graph.get_tensor_by_name(
                ('%s/%s:0' % (first_clone_scope, common.IMAGE)).strip('/'))
            summaries.add(
                tf.summary.image('samples/%s' % common.IMAGE, summary_image))

            first_clone_label = graph.get_tensor_by_name(
                ('%s/%s:0' % (first_clone_scope, common.LABEL)).strip('/'))
            # Scale up summary image pixel values for better visualization.
            pixel_scaling = max(1, 255 // dataset.num_classes)
            summary_label = tf.cast(first_clone_label * pixel_scaling,
                                    tf.uint8)
            summaries.add(
                tf.summary.image('samples/%s' % common.LABEL, summary_label))

            first_clone_output = graph.get_tensor_by_name(
                ('%s/%s:0' %
                 (first_clone_scope, common.OUTPUT_TYPE)).strip('/'))
            predictions = tf.expand_dims(tf.argmax(first_clone_output, 3), -1)

            summary_predictions = tf.cast(predictions * pixel_scaling,
                                          tf.uint8)
            summaries.add(
                tf.summary.image('samples/%s' % common.OUTPUT_TYPE,
                                 summary_predictions))

        # Add summaries for losses.
        for loss in tf.get_collection(tf.GraphKeys.LOSSES, first_clone_scope):
            summaries.add(tf.summary.scalar('losses/%s' % loss.op.name, loss))

        # Build the optimizer based on the device specification.
        with tf.device(config.optimizer_device()):
            learning_rate = train_utils.get_model_learning_rate(
                FLAGS.learning_policy, FLAGS.base_learning_rate,
                FLAGS.learning_rate_decay_step,
                FLAGS.learning_rate_decay_factor,
                FLAGS.training_number_of_steps, FLAGS.learning_power,
                FLAGS.slow_start_step, FLAGS.slow_start_learning_rate)
            optimizer = tf.train.MomentumOptimizer(learning_rate,
                                                   FLAGS.momentum)
            summaries.add(tf.summary.scalar('learning_rate', learning_rate))

        startup_delay_steps = FLAGS.task * FLAGS.startup_delay_steps
        for variable in slim.get_model_variables():
            summaries.add(tf.summary.histogram(variable.op.name, variable))

        with tf.device(config.variables_device()):

            total_loss, grads_and_vars = model_deploy.optimize_clones(
                clones, optimizer)
            total_loss = tf.check_numerics(total_loss, 'Loss is inf or nan.')

            # Modify the gradients for biases and last layer variables.
            last_layers = model.get_extra_layer_scopes(
                FLAGS.last_layers_contain_logits_only)
            grad_mult = train_utils.get_model_gradient_multipliers(
                last_layers, FLAGS.last_layer_gradient_multiplier)
            if grad_mult:
                grads_and_vars = slim.learning.multiply_gradients(
                    grads_and_vars, grad_mult)

            if FLAGS.batch_iter <= 1:
                FLAGS.batch_iter = 0
                summaries.add(tf.summary.scalar('total_loss', total_loss))
                grad_updates = optimizer.apply_gradients(
                    grads_and_vars, global_step=global_step)
                update_ops.append(grad_updates)
                update_op = tf.group(*update_ops)
                with tf.control_dependencies([update_op]):
                    train_tensor = tf.identity(total_loss, name='train_op')
                accum_tensor = None
            else:

                ############ Accumulate grads_and_vars op. ####################
                accum_update_ops = list(update_ops)  # copy, so appends below don't mutate update_ops
                # Create a (grad, var) list to accumulate gradients in. Initialize to 0.
                accum_grads_and_vars = [
                    (tf.Variable(tf.zeros_like(gv[0]),
                                 trainable=False,
                                 # split(':')[0] drops the tensor suffix; strip(':0')
                                 # would also eat trailing '0' characters in names.
                                 name=gv[0].name.split(':')[0] + "_accum"),
                     gv[1]) for gv in grads_and_vars
                ]
                assert len(accum_grads_and_vars) == len(grads_and_vars)

                total_loss_accum = tf.Variable(0.0,
                                               dtype=tf.float32,
                                               trainable=False)
                accum_loss_update_op = [
                    total_loss_accum.assign_add(total_loss)
                ]
                accum_update_ops.append(accum_loss_update_op)

                ## Accumulate gradients: accum_grad[i] += (grad[i] / FLAGS.batch_iter)  # scaled gradients.
                accum_ops = [
                    accum_grads_and_vars[i][0].assign_add(
                        tf.div(gv[0], 1.0 * FLAGS.batch_iter))
                    for i, gv in enumerate(grads_and_vars)
                ]
                accum_update_ops.append(accum_ops)

                accum_update_op = tf.group(*accum_update_ops)
                with tf.control_dependencies([accum_update_op]):
                    accum_print_ops = []
                    if FLAGS.batch_iter_verbose:
                        accum_print_ops.extend([
                            tf.Print(
                                tf.constant(0), [tf.add(global_step, 1)],
                                message='train.py: accumulating gradients for step: ')
                        ])
                    with tf.control_dependencies([tf.group(*accum_print_ops)]):
                        accum_tensor = tf.identity(total_loss_accum,
                                                   name='accum_op')

                ##################### Train op (apply [accumulated] grads and vars) ###############################
                train_update_ops = list(update_ops)  # copy, so appends below don't mutate update_ops
                ## Create gradient update op.
                # Apply gradients from accumulated gradients
                grad_updates = optimizer.apply_gradients(
                    accum_grads_and_vars, global_step=global_step)
                train_update_ops.append(grad_updates)

                grad_print_ops = []
                if FLAGS.batch_iter_verbose:
                    grad_print_ops.extend([
                        tf.Print(tf.constant(0), [accum_grads_and_vars[0][0]],
                                 message='GRADS BEFORE ZERO: ')
                    ])
                train_update_ops.append(grad_print_ops)

                total_loss_accum_average = tf.div(total_loss_accum,
                                                  FLAGS.batch_iter)
                summaries.add(
                    tf.summary.scalar('total_loss', total_loss_accum_average))

                train_update_op = tf.group(*train_update_ops)
                with tf.control_dependencies([train_update_op]):
                    zero_ops = []

                    zero_accum_ops = [
                        agv[0].assign(tf.zeros_like(agv[0]))
                        for agv in accum_grads_and_vars
                    ]
                    zero_ops.append(zero_accum_ops)

                    zero_accum_total_loss_op = [total_loss_accum.assign(0)]
                    zero_ops.append(zero_accum_total_loss_op)

                    zero_op = tf.group(*zero_ops)
                    with tf.control_dependencies([zero_op]):
                        train_tensor = tf.identity(
                            total_loss_accum_average, name='train_op')

        # Add the summaries from the first clone. These contain the summaries
        # created by model_fn and either optimize_clones() or
        # _gather_clone_loss().
        summaries |= set(
            tf.get_collection(tf.GraphKeys.SUMMARIES, first_clone_scope))

        # Merge all summaries together.
        summary_op = tf.summary.merge(list(summaries))

        # Soft placement allows ops without a GPU implementation to be placed on the CPU.
        session_config = tf.ConfigProto(allow_soft_placement=True,
                                        log_device_placement=False)

        session_config.gpu_options.allow_growth = True

        train_step_custom = train_utils.train_step_custom
        if FLAGS.validation_interval <= 0:
            FLAGS.validation_interval = FLAGS.training_number_of_steps
        else:
            print("*** Validation interval: {} ***".format(
                FLAGS.validation_interval))

        # Start the training.
        slim.learning.train(train_tensor,
                            logdir=FLAGS.train_logdir,
                            train_step_fn=train_step_custom(
                                VALIDATION_N=FLAGS.validation_interval,
                                ACCUM_OP=accum_tensor,
                                ACCUM_STEPS=FLAGS.batch_iter),
                            log_every_n_steps=FLAGS.log_steps,
                            master=FLAGS.master,
                            number_of_steps=FLAGS.training_number_of_steps,
                            is_chief=(FLAGS.task == 0),
                            session_config=session_config,
                            startup_delay_steps=startup_delay_steps,
                            init_fn=train_utils.get_model_init_fn(
                                FLAGS.train_logdir,
                                FLAGS.tf_initial_checkpoint,
                                FLAGS.initialize_last_layer,
                                last_layers,
                                ignore_missing_vars=True),
                            summary_op=summary_op,
                            save_summaries_secs=FLAGS.save_summaries_secs,
                            save_interval_secs=FLAGS.save_interval_secs)
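
The accumulation branch above adds grad / batch_iter into persistent accumulator variables on every step and applies them once per batch_iter steps, which is arithmetically the same as averaging the micro-batch gradients. A numpy sketch of that equivalence (values are illustrative):

import numpy as np

batch_iter = 4
micro_grads = [np.array([1.0, 2.0]), np.array([3.0, 0.0]),
               np.array([2.0, 2.0]), np.array([0.0, 4.0])]

accum = np.zeros(2)
for g in micro_grads:
    accum += g / batch_iter  # mirrors assign_add(tf.div(grad, batch_iter))

assert np.allclose(accum, np.mean(micro_grads, axis=0))
print(accum)  # [1.5 2. ] -- the average gradient one larger batch would see
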
Example #6
def _train_deeplab_model(iterator, num_of_classes, ignore_label):
    """Trains the deeplab model.

  Args:
    iterator: An iterator of type tf.data.Iterator for images and labels.
    num_of_classes: Number of classes for the dataset.
    ignore_label: Ignore label for the dataset.

  Returns:
    train_tensor: A tensor to update the model variables.
    summary_op: An operation to log the summaries.
  """
    global_step = tf.train.get_or_create_global_step()

    learning_rate = train_utils.get_model_learning_rate(
        FLAGS.learning_policy, FLAGS.base_learning_rate,
        FLAGS.learning_rate_decay_step, FLAGS.learning_rate_decay_factor,
        FLAGS.training_number_of_steps, FLAGS.learning_power,
        FLAGS.slow_start_step, FLAGS.slow_start_learning_rate)

    tf.summary.scalar('learning_rate', learning_rate)

    optimizer = tf.train.MomentumOptimizer(learning_rate, FLAGS.momentum)

    tower_losses = []
    tower_grads = []
    for i in range(FLAGS.num_clones):
        with tf.device('/gpu:%d' % i):
            print("using gpu")
            # First tower has default name scope.
            name_scope = ('clone_%d' % i) if i else ''
            with tf.name_scope(name_scope) as scope:
                loss = _tower_loss(iterator=iterator,
                                   num_of_classes=num_of_classes,
                                   ignore_label=ignore_label,
                                   scope=scope,
                                   reuse_variable=(i != 0))
                tower_losses.append(loss)

    if FLAGS.quantize_delay_step >= 0:
        if FLAGS.num_clones > 1:
            raise ValueError('Quantization doesn\'t support multi-clone yet.')
        tf.contrib.quantize.create_training_graph(
            quant_delay=FLAGS.quantize_delay_step)

    for i in range(FLAGS.num_clones):
        with tf.device('/gpu:%d' % i):
            name_scope = ('clone_%d' % i) if i else ''
            with tf.name_scope(name_scope) as scope:
                grads = optimizer.compute_gradients(tower_losses[i])
                tower_grads.append(grads)

    with tf.device('/cpu:0'):
        grads_and_vars = _average_gradients(tower_grads)

        # Modify the gradients for biases and last layer variables.
        last_layers = model.get_extra_layer_scopes(
            FLAGS.last_layers_contain_logits_only)
        grad_mult = train_utils.get_model_gradient_multipliers(
            last_layers, FLAGS.last_layer_gradient_multiplier)
        if grad_mult:
            grads_and_vars = tf.contrib.training.multiply_gradients(
                grads_and_vars, grad_mult)

        # Create gradient update op.
        grad_updates = optimizer.apply_gradients(grads_and_vars,
                                                 global_step=global_step)

        # Gather update_ops. These contain, for example,
        # the updates for the batch_norm variables created by model_fn.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        update_ops.append(grad_updates)
        update_op = tf.group(*update_ops)

        total_loss = tf.losses.get_total_loss(add_regularization_losses=True)

        # Print total loss to the terminal.
        # This implementation is mirrored from tf.slim.summaries.
        should_log = math_ops.equal(math_ops.mod(global_step, FLAGS.log_steps),
                                    0)
        total_loss = tf.cond(
            should_log,
            lambda: tf.Print(total_loss, [total_loss], 'Total loss is: '),
            lambda: total_loss)
        tf.summary.scalar('total_loss', total_loss)
        with tf.control_dependencies([update_op]):
            train_tensor = tf.identity(total_loss, name='train_op')

        # Excludes summaries from towers other than the first one.
        summary_op = tf.summary.merge_all(scope='(?!clone_)')
    return train_tensor, summary_op
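
The merge_all(scope='(?!clone_)') call works because the scope argument is matched as a regular expression against summary names, so the negative lookahead keeps only summaries created outside the clone_* name scopes. A minimal sketch of the filter, assuming TF 1.x:

import tensorflow as tf  # assumes TF 1.x

tf.summary.scalar('learning_rate', tf.constant(0.1))
with tf.name_scope('clone_1'):
    tf.summary.scalar('loss', tf.constant(2.0))  # filtered out below

# get_collection applies the scope pattern with re.match, so
# '(?!clone_)' rejects names starting with 'clone_'.
summary_op = tf.summary.merge_all(scope='(?!clone_)')
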
Example #7
def main(unused_argv):
  tf.logging.set_verbosity(tf.logging.INFO)
  # Set up deployment (i.e., multi-GPUs and/or multi-replicas).
  # Configure the parameters for multi-GPU training.
  config = model_deploy.DeploymentConfig(
      num_clones=FLAGS.num_clones,  # number of GPUs
      clone_on_cpu=FLAGS.clone_on_cpu,  # defaults to False
      replica_id=FLAGS.task,    # task id
      num_replicas=FLAGS.num_replicas,  # defaults to 1
      num_ps_tasks=FLAGS.num_ps_tasks)  # defaults to 0

  # Split the batch across GPUs.
  assert FLAGS.train_batch_size % config.num_clones == 0, (
      'Training batch size not divisible by number of clones (GPUs).')

  clone_batch_size = FLAGS.train_batch_size // config.num_clones    # each GPU gets an equal share of the batch

  tf.gfile.MakeDirs(FLAGS.train_logdir)     # create the directory for training logs
  tf.logging.info('Training on %s set', FLAGS.train_split)

  with tf.Graph().as_default() as graph:
    with tf.device(config.inputs_device()):
      dataset = data_generator.Dataset(     # define the dataset parameters
          dataset_name=FLAGS.dataset,   # dataset name, e.g. cityscapes
          split_name=FLAGS.train_split,  # the tfrecord split to train on, defaults to "train"
          dataset_dir=FLAGS.dataset_dir,   # directory holding the tfrecord files
          batch_size=clone_batch_size,  # per-GPU batch size after the split
          crop_size=[int(sz) for sz in FLAGS.train_crop_size],  # training crop size, e.g. 513x513
          min_resize_value=FLAGS.min_resize_value,  # defaults to None
          max_resize_value=FLAGS.max_resize_value,  # defaults to None
          resize_factor=FLAGS.resize_factor,    # defaults to None
          min_scale_factor=FLAGS.min_scale_factor,   # minimum scale for data augmentation, defaults to 0.5
          max_scale_factor=FLAGS.max_scale_factor,   # maximum scale for data augmentation, defaults to 2
          scale_factor_step_size=FLAGS.scale_factor_step_size,      # step between scales, defaults to 0.25 (from 0.5 to 2)
          model_variant=FLAGS.model_variant,    # model variant, e.g. xception_65
          num_readers=4,    # number of parallel readers; increase for multi-GPU training
          is_training=True,
          should_shuffle=True,
          should_repeat=True)

    # Create the global step on the device storing the variables.
    with tf.device(config.variables_device()):
      # Counter: global_step is incremented by 1 for each trained batch.
      global_step = tf.train.get_or_create_global_step()

      # Define the model and create clones.
      model_fn = _build_deeplab  # define the DeepLab model
      model_args = (dataset.get_one_shot_iterator(), {
          common.OUTPUT_TYPE: dataset.num_of_classes
      }, dataset.ignore_label)  # model arguments
      clones = model_deploy.create_clones(config, model_fn, args=model_args)

      # Gather update_ops from the first clone. These contain, for example,
      # the updates for the batch_norm variables created by model_fn.
      first_clone_scope = config.clone_scope(0)
      update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, first_clone_scope)

    # Gather initial summaries.
    summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))

    # Add summaries for model variables.
    for model_var in tf.model_variables():
      summaries.add(tf.summary.histogram(model_var.op.name, model_var))

    # Add summaries for images, labels, semantic predictions
    if FLAGS.save_summaries_images:      # defaults to False
      summary_image = graph.get_tensor_by_name(
          ('%s/%s:0' % (first_clone_scope, common.IMAGE)).strip('/'))
      summaries.add(
          tf.summary.image('samples/%s' % common.IMAGE, summary_image))

      first_clone_label = graph.get_tensor_by_name(
          ('%s/%s:0' % (first_clone_scope, common.LABEL)).strip('/'))
      # Scale up summary image pixel values for better visualization.
      pixel_scaling = max(1, 255 // dataset.num_of_classes)
      summary_label = tf.cast(first_clone_label * pixel_scaling, tf.uint8)
      summaries.add(
          tf.summary.image('samples/%s' % common.LABEL, summary_label))

      first_clone_output = graph.get_tensor_by_name(
          ('%s/%s:0' % (first_clone_scope, common.OUTPUT_TYPE)).strip('/'))
      predictions = tf.expand_dims(tf.argmax(first_clone_output, 3), -1)

      summary_predictions = tf.cast(predictions * pixel_scaling, tf.uint8)
      summaries.add(
          tf.summary.image(
              'samples/%s' % common.OUTPUT_TYPE, summary_predictions))

    # Add summaries for losses.
    for loss in tf.get_collection(tf.GraphKeys.LOSSES, first_clone_scope):
      summaries.add(tf.summary.scalar('losses/%s' % loss.op.name, loss))

    # Build the optimizer based on the device specification.
    with tf.device(config.optimizer_device()):
      learning_rate = train_utils.get_model_learning_rate(  # get the model learning rate
          FLAGS.learning_policy,    # 'poly' learning policy
          FLAGS.base_learning_rate,     # 0.0001
          FLAGS.learning_rate_decay_step,   # decay the learning rate every 2000 steps
          FLAGS.learning_rate_decay_factor,     # 0.1
          FLAGS.training_number_of_steps,   # number of training steps, 20000
          FLAGS.learning_power,     # poly power 0.9
          FLAGS.slow_start_step,    # 0
          FLAGS.slow_start_learning_rate,   # 1e-4, learning rate for the slow start
          decay_steps=FLAGS.decay_steps,    # 0.0
          end_learning_rate=FLAGS.end_learning_rate)     # 0.0

      summaries.add(tf.summary.scalar('learning_rate', learning_rate))
      # Optimizer used for model training.
      if FLAGS.optimizer == 'momentum':
        optimizer = tf.train.MomentumOptimizer(learning_rate, FLAGS.momentum)
      elif FLAGS.optimizer == 'adam':   # Adam: seeks the optimum using second-moment gradient correction
        optimizer = tf.train.AdamOptimizer(
            learning_rate=FLAGS.adam_learning_rate, epsilon=FLAGS.adam_epsilon)
      else:
        raise ValueError('Unknown optimizer')

    if FLAGS.quantize_delay_step >= 0:  # defaults to -1, so this is skipped
      if FLAGS.num_clones > 1:
        raise ValueError('Quantization doesn\'t support multi-clone yet.')
      contrib_quantize.create_training_graph(
          quant_delay=FLAGS.quantize_delay_step)

    startup_delay_steps = FLAGS.task * FLAGS.startup_delay_steps    # FLAGS.startup_delay_steps defaults to 15

    with tf.device(config.variables_device()):
      total_loss, grads_and_vars = model_deploy.optimize_clones(
          clones, optimizer)    # compute total_loss
      total_loss = tf.check_numerics(total_loss, 'Loss is inf or nan.')
      summaries.add(tf.summary.scalar('total_loss', total_loss))

      # Modify the gradients for biases and last layer variables.
      last_layers = model.get_extra_layer_scopes(
          FLAGS.last_layers_contain_logits_only)
      # Get the gradient multipliers.
      grad_mult = train_utils.get_model_gradient_multipliers(
          last_layers, FLAGS.last_layer_gradient_multiplier)
      # grad_mult : {'logits/semantic/biases': 2.0, 'logits/semantic/weights': 1.0}
      if grad_mult:
        grads_and_vars = slim.learning.multiply_gradients(
            grads_and_vars, grad_mult)

      # Create gradient update op.
      grad_updates = optimizer.apply_gradients(     # apply the computed gradients to the variables; returns the update operation
          grads_and_vars, global_step=global_step)  # also increments global_step
      update_ops.append(grad_updates)
      update_op = tf.group(*update_ops)
      with tf.control_dependencies([update_op]):
        train_tensor = tf.identity(total_loss, name='train_op')

    # Add the summaries from the first clone. These contain the summaries
    # created by model_fn and either optimize_clones() or _gather_clone_loss().
    summaries |= set(
        tf.get_collection(tf.GraphKeys.SUMMARIES, first_clone_scope))

    # Merge all summaries together.
    summary_op = tf.summary.merge(list(summaries))

    # Soft placement allows ops without a GPU implementation to be placed on the CPU.
    session_config = tf.ConfigProto(
        allow_soft_placement=True, log_device_placement=False)

    # Start the training.
    profile_dir = FLAGS.profile_logdir   # defaults to None
    if profile_dir is not None:
      tf.gfile.MakeDirs(profile_dir)

    with contrib_tfprof.ProfileContext(
        enabled=profile_dir is not None, profile_dir=profile_dir):
      init_fn = None
      if FLAGS.tf_initial_checkpoint:   # load pre-trained weights
        init_fn = train_utils.get_model_init_fn(
            FLAGS.train_logdir,
            FLAGS.tf_initial_checkpoint,
            FLAGS.initialize_last_layer,
            last_layers,
            ignore_missing_vars=True)

      slim.learning.train(
          train_tensor,
          logdir=FLAGS.train_logdir,
          log_every_n_steps=FLAGS.log_steps,
          master=FLAGS.master,
          number_of_steps=FLAGS.training_number_of_steps,
          is_chief=(FLAGS.task == 0),
          session_config=session_config,
          startup_delay_steps=startup_delay_steps,
          init_fn=init_fn,
          summary_op=summary_op,
          save_summaries_secs=FLAGS.save_summaries_secs,
          save_interval_secs=FLAGS.save_interval_secs)
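
The comments in this example describe the 'poly' learning-rate policy (base rate 0.0001, power 0.9, 20000 training steps). A short sketch of the usual poly formula those comments refer to:

def poly_learning_rate(step, base_lr=1e-4, power=0.9, max_steps=20000):
    # lr = base_lr * (1 - step / max_steps) ** power
    return base_lr * (1.0 - float(step) / max_steps) ** power

for step in (0, 10000, 19999):
    print(step, poly_learning_rate(step))
# Decays smoothly from base_lr at step 0 toward 0 at max_steps.
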
Example #8
    def train(self):
        FLAGS = self.flags
        dataset_split = 'train'
        data_config = edict()
        data_config.edge_width = 20
        data_config.ignore_label = DATASETS_IGNORE_LABEL[FLAGS.dataset]
        data_config.edge_class_num = FLAGS.edge_class_num
        img_files, label_files = get_dataset_files(FLAGS.dataset,
                                                   dataset_split)

        dataset = edict()
        dataset_pp = dataset_pipeline(data_config,
                                      img_files,
                                      label_files,
                                      is_train=True)
        dataset.num_classes = DATASETS_CLASS_NUM[FLAGS.dataset]
        dataset.ignore_label = DATASETS_IGNORE_LABEL[FLAGS.dataset]
        dataset.num_samples = len(dataset_pp)

        tf.logging.set_verbosity(tf.logging.INFO)
        # Set up deployment (i.e., multi-GPUs and/or multi-replicas).
        config = model_deploy.DeploymentConfig(num_clones=FLAGS.num_clones,
                                               clone_on_cpu=FLAGS.clone_on_cpu,
                                               replica_id=FLAGS.task,
                                               num_replicas=FLAGS.num_replicas,
                                               num_ps_tasks=FLAGS.num_ps_tasks)

        # Split the batch across GPUs.
        assert FLAGS.train_batch_size % config.num_clones == 0, (
            'Training batch size not divisible by number of clones (GPUs).')

        clone_batch_size = FLAGS.train_batch_size // config.num_clones

        # Get dataset-dependent information.
        #        dataset = segmentation_dataset.get_dataset(
        #            FLAGS.dataset, FLAGS.train_split, dataset_dir=FLAGS.dataset_dir)

        tf.gfile.MakeDirs(FLAGS.train_logdir)
        tf.logging.info('Training on %s set', FLAGS.train_split)

        with tf.Graph().as_default() as graph:
            with tf.device(config.inputs_device()):
                data_list = dataset_pp.iterator()
                samples = input_generator.get(
                    (data_list, dataset.ignore_label),
                    FLAGS.train_crop_size,
                    clone_batch_size,
                    min_resize_value=FLAGS.min_resize_value,
                    max_resize_value=FLAGS.max_resize_value,
                    resize_factor=FLAGS.resize_factor,
                    min_scale_factor=FLAGS.min_scale_factor,
                    max_scale_factor=FLAGS.max_scale_factor,
                    scale_factor_step_size=FLAGS.scale_factor_step_size,
                    dataset_split=FLAGS.train_split,
                    is_training=True,
                    model_variant=FLAGS.model_variant)
                inputs_queue = prefetch_queue.prefetch_queue(samples,
                                                             capacity=128 *
                                                             config.num_clones)

            # Create the global step on the device storing the variables.
            with tf.device(config.variables_device()):
                global_step = tf.train.get_or_create_global_step()

                # Define the model and create clones.
                model_fn = self._build_deeplab
                model_args = (inputs_queue, {
                    common.OUTPUT_TYPE: dataset.num_classes
                }, dataset.ignore_label)
                clones = model_deploy.create_clones(config,
                                                    model_fn,
                                                    args=model_args)

                # Gather update_ops from the first clone. These contain, for example,
                # the updates for the batch_norm variables created by model_fn.
                first_clone_scope = config.clone_scope(0)
                update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS,
                                               first_clone_scope)

            # Gather initial summaries.
            summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))

            # Add summaries for model variables.
            for model_var in slim.get_model_variables():
                summaries.add(
                    tf.summary.histogram(model_var.op.name, model_var))

            label_name = ('%s/%s:0' %
                          (first_clone_scope, common.LABEL)).strip('/')
            print('first clone label name is:', label_name)

            # Add summaries for images, labels, semantic predictions
            if FLAGS.save_summaries_images:
                summary_image = graph.get_tensor_by_name(
                    ('%s/%s:0' % (first_clone_scope, common.IMAGE)).strip('/'))
                summaries.add(
                    tf.summary.image('samples/%s' % common.IMAGE,
                                     summary_image))

                first_clone_label = graph.get_tensor_by_name(
                    ('%s/%s:0' % (first_clone_scope, common.LABEL)).strip('/'))

                # Scale up summary image pixel values for better visualization.
                pixel_scaling = max(1, 255 // dataset.num_classes)
                summary_label = tf.cast(first_clone_label * pixel_scaling,
                                        tf.uint8)
                summaries.add(
                    tf.summary.image('samples/%s' % common.LABEL,
                                     summary_label))

                first_clone_output = graph.get_tensor_by_name(
                    ('%s/%s:0' %
                     (first_clone_scope, common.OUTPUT_TYPE)).strip('/'))
                predictions = tf.expand_dims(tf.argmax(first_clone_output, 3),
                                             -1)

                summary_predictions = tf.cast(predictions * pixel_scaling,
                                              tf.uint8)
                summaries.add(
                    tf.summary.image('samples/%s' % common.OUTPUT_TYPE,
                                     summary_predictions))

            # Add summaries for miou,acc
            labels = graph.get_tensor_by_name(
                ('%s/%s:0' % (first_clone_scope, common.LABEL)).strip('/'))
            predictions = graph.get_tensor_by_name(
                ('%s/%s:0' %
                 (first_clone_scope, common.OUTPUT_TYPE)).strip('/'))
            predictions = tf.image.resize_bilinear(predictions,
                                                   tf.shape(labels)[1:3],
                                                   align_corners=True)

            labels = tf.reshape(labels, shape=[-1])
            predictions = tf.reshape(tf.argmax(predictions, 3), shape=[-1])
            weights = tf.to_float(tf.not_equal(labels, dataset.ignore_label))

            # Set ignore_label regions to label 0, because metrics.mean_iou requires
            # range of labels = [0, dataset.num_classes). Note the ignore_label regions
            # are not evaluated since the corresponding regions contain weights = 0.
            labels = tf.where(tf.equal(labels, dataset.ignore_label),
                              tf.zeros_like(labels), labels)

            # Define the evaluation metric.
            metric_map = {}
            metric_map['miou'], _ = tf.metrics.mean_iou(predictions,
                                                        labels,
                                                        dataset.num_classes,
                                                        weights=weights)
            metric_map['acc'], _ = tf.metrics.accuracy(
                labels=labels,
                predictions=predictions,
                weights=tf.reshape(weights, shape=[-1]))

            for x in ['miou', 'acc']:
                summaries.add(
                    tf.summary.scalar('metrics/%s' % x, metric_map[x]))

            # Add summaries for losses.
            for loss in tf.get_collection(tf.GraphKeys.LOSSES,
                                          first_clone_scope):
                summaries.add(
                    tf.summary.scalar('losses/%s' % loss.op.name, loss))

            # Build the optimizer based on the device specification.
            with tf.device(config.optimizer_device()):
                learning_rate = train_utils.get_model_learning_rate(
                    FLAGS.learning_policy, FLAGS.base_learning_rate,
                    FLAGS.learning_rate_decay_step,
                    FLAGS.learning_rate_decay_factor,
                    FLAGS.training_number_of_steps, FLAGS.learning_power,
                    FLAGS.slow_start_step, FLAGS.slow_start_learning_rate)
                optimizer = tf.train.MomentumOptimizer(learning_rate,
                                                       FLAGS.momentum)
                summaries.add(tf.summary.scalar('learning_rate',
                                                learning_rate))

            startup_delay_steps = FLAGS.task * FLAGS.startup_delay_steps
            for variable in slim.get_model_variables():
                summaries.add(tf.summary.histogram(variable.op.name, variable))

            with tf.device(config.variables_device()):
                total_loss, grads_and_vars = model_deploy.optimize_clones(
                    clones, optimizer)
                total_loss = tf.check_numerics(total_loss,
                                               'Loss is inf or nan.')
                summaries.add(tf.summary.scalar('total_loss', total_loss))

                # Modify the gradients for biases and last layer variables.
                last_layers = model.get_extra_layer_scopes(
                    FLAGS.last_layers_contain_logits_only)
                grad_mult = train_utils.get_model_gradient_multipliers(
                    last_layers, FLAGS.last_layer_gradient_multiplier)
                if grad_mult:
                    grads_and_vars = slim.learning.multiply_gradients(
                        grads_and_vars, grad_mult)

                # Create gradient update op.
                grad_updates = optimizer.apply_gradients(
                    grads_and_vars, global_step=global_step)
                update_ops.append(grad_updates)
                update_op = tf.group(*update_ops)
                with tf.control_dependencies([update_op]):
                    train_tensor = tf.identity(total_loss, name='train_op')

            # Add the summaries from the first clone. These contain the summaries
            # created by model_fn and either optimize_clones() or _gather_clone_loss().
            summaries |= set(
                tf.get_collection(tf.GraphKeys.SUMMARIES, first_clone_scope))

            # Merge all summaries together.
            summary_op = tf.summary.merge(list(summaries))

            # Soft placement allows placing on CPU ops without GPU implementation.
            session_config = tf.ConfigProto(allow_soft_placement=True,
                                            log_device_placement=False)

            # Start the training.
            slim.learning.train(train_tensor,
                                logdir=FLAGS.train_logdir,
                                log_every_n_steps=FLAGS.log_steps,
                                master=FLAGS.master,
                                number_of_steps=FLAGS.training_number_of_steps,
                                is_chief=(FLAGS.task == 0),
                                session_config=session_config,
                                startup_delay_steps=startup_delay_steps,
                                init_fn=train_utils.get_model_init_fn(
                                    FLAGS.train_logdir,
                                    FLAGS.tf_initial_checkpoint,
                                    FLAGS.initialize_last_layer,
                                    last_layers,
                                    ignore_missing_vars=True),
                                summary_op=summary_op,
                                save_summaries_secs=FLAGS.save_summaries_secs,
                                save_interval_secs=FLAGS.save_interval_secs)
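
Example #8 differs from the first script mainly in its streaming miou/acc summaries. The key detail is the ignore-label handling: pixels equal to ignore_label get weight 0 and are remapped to class 0, so every label stays inside [0, num_classes) without contributing to the metric. A self-contained sketch of that masking (TF 1.x; masked_mean_iou is an illustrative name):

import tensorflow as tf

def masked_mean_iou(labels, predictions, num_classes, ignore_label):
    labels = tf.reshape(labels, [-1])
    predictions = tf.reshape(predictions, [-1])
    # Zero weight for ignored pixels, so they never enter the confusion matrix.
    weights = tf.to_float(tf.not_equal(labels, ignore_label))
    # Remap ignored pixels to a valid class id; their weight is already 0.
    labels = tf.where(tf.equal(labels, ignore_label),
                      tf.zeros_like(labels), labels)
    return tf.metrics.mean_iou(labels, predictions, num_classes,
                               weights=weights)
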
Example #9
0
def train(data_dicts,
          class_num,
          input_size,
          lr,
          n_epochs,
          num_clones,
          iters_cnt,
          val_every,
          model_init_fn,
          save_cback,
          atrous_rates=[6, 12, 18],
          fine_tune_batch_norm=True,
          output_stride=16):
    tf.logging.set_verbosity(tf.logging.INFO)
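    # NOTE: clone_on_cpu, task, num_replicas, num_ps_tasks, model_variant and
    # last_layer_gradient_multiplier are presumably module-level settings in
    # the original file; they are not defined in this snippet.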

    # Set up deployment (i.e., multi-GPUs and/or multi-replicas).
    config = model_deploy.DeploymentConfig(num_clones=num_clones,
                                           clone_on_cpu=clone_on_cpu,
                                           replica_id=task,
                                           num_replicas=num_replicas,
                                           num_ps_tasks=num_ps_tasks)

    with tf.Graph().as_default():
        with tf.device(config.inputs_device()):
            samples = get(data_dicts['train'],
                          input_size,
                          is_training=True,
                          model_variant=model_variant)
            samples_val = get(data_dicts['val'],
                              input_size,
                              is_training=True,
                              model_variant=model_variant)

        inputs_queue = prefetch_queue.prefetch_queue(samples,
                                                     capacity=128 *
                                                     config.num_clones,
                                                     dynamic_pad=True)
        inputs_queue_val = prefetch_queue.prefetch_queue(samples_val,
                                                         capacity=128 *
                                                         config.num_clones,
                                                         dynamic_pad=True)
        coord = tf.train.Coordinator()

        # Create the global step on the device storing the variables.
        with tf.device(config.variables_device()):
            global_step = tf.train.create_global_step()

            # Define the model and create clones.
            model_fn = _build_deeplab
            model_args = (inputs_queue, {
                'semantic': class_num
            }, input_size, atrous_rates, output_stride, fine_tune_batch_norm)
            clones = model_deploy.create_clones(config,
                                                model_fn,
                                                args=model_args)

            # Gather update_ops from the first clone. These contain, for example,
            # the updates for the batch_norm variables created by model_fn.
            first_clone_scope = config.clone_scope(0)
            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS,
                                           first_clone_scope)

        # Build the optimizer based on the device specification.
        with tf.device(config.optimizer_device()):
            learning_rate = lr
            optimizer = tf.train.AdamOptimizer(learning_rate)

        with tf.device(config.variables_device()):
            total_loss, grads_and_vars = model_deploy.optimize_clones(
                clones, optimizer)
            total_loss = tf.check_numerics(total_loss, 'Loss is inf or nan.')

            model_fn_val = _build_deeplab_val
            model_args_val = (inputs_queue_val, {
                'semantic': class_num
            }, input_size, atrous_rates, output_stride)
            val_clones, val_losses = create_val_clones(num_clones,
                                                       config,
                                                       model_fn_val,
                                                       args=model_args_val)
            val_total_loss = get_clones_val_losses(val_clones, None,
                                                   val_losses)
            # Modify the gradients for biases and last layer variables.
            last_layers = model.get_extra_layer_scopes()
            grad_mult = train_utils.get_model_gradient_multipliers(
                last_layers, last_layer_gradient_multiplier)
            if grad_mult:
                grads_and_vars = slim.learning.multiply_gradients(
                    grads_and_vars, grad_mult)

            # Create gradient update op.
            grad_updates = optimizer.apply_gradients(grads_and_vars,
                                                     global_step=global_step)
            update_ops.append(grad_updates)
            update_op = tf.group(*update_ops)
            with tf.control_dependencies([update_op]):
                train_tensor = tf.identity(total_loss, name='train_op')

        config = tf.ConfigProto(allow_soft_placement=True,
                                log_device_placement=False)
        coord.clear_stop()
        sess = tf.Session(config=config)

        graph = ops.get_default_graph()
        with graph.as_default():
            with ops.name_scope('init_ops'):
                init_op = variables.global_variables_initializer()
                ready_op = variables.report_uninitialized_variables()
                local_init_op = control_flow_ops.group(
                    variables.local_variables_initializer(),
                    lookup_ops.tables_initializer())

        # graph.finalize()
        sess.run([init_op, ready_op, local_init_op])
        queue_runners = graph.get_collection(ops.GraphKeys.QUEUE_RUNNERS)
        threads = []
        for qr in queue_runners:
            threads.extend(
                qr.create_threads(sess, coord=coord, daemon=True, start=True))

        # # for i in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES):
        # #     print(i)
        # vary_23 = [v for v in tf.global_variables() if v.name == 'xception_65/middle_flow/block1/unit_8/xception_module/separable_conv3_depthwise/BatchNorm/moving_mean:0'][0]
        #
        # beta_23 = [v for v in tf.global_variables() if v.name == 'xception_65/middle_flow/block1/unit_8/xception_module/separable_conv3_depthwise/BatchNorm/gamma:0'][0]
        # for i in range(1000):
        #     train_loss = sess.run(train_tensor)
        #     print(train_loss)
        #     vary, beta = sess.run([vary_23, beta_23])
        #     print('mean', vary[0:3])
        #     print('beta', beta[0:3])
        #     if (i + 1) % 10 == 0:
        #         for i in range(10):
        #             val_loss = sess.run(val_total_loss)
        #             vary, beta = sess.run([vary_23, beta_23])
        #             print('mean val', vary[0:3])
        #             print('beta', beta[0:3])
        #             print('VAl_loss', val_loss)

        model_init_fn(sess)
        saver = tf.train.Saver()
        eval_planner = EvalPlanner(n_epochs, val_every)
        progress = sly.progress_counter_train(n_epochs, iters_cnt['train'])
        best_val_loss = float('inf')
        epoch_flt = 0

        for epoch in range(n_epochs):
            logger.info("Before new epoch", extra={'epoch': epoch_flt})
            for train_it in range(iters_cnt['train']):
                total_loss = sess.run(train_tensor)

                metrics_values_train = {
                    'loss': total_loss,
                }

                progress.iter_done_report()
                epoch_flt = epoch_float(epoch, train_it + 1,
                                        iters_cnt['train'])
                sly.report_metrics_training(epoch_flt, metrics_values_train)

                if eval_planner.need_validation(epoch_flt):
                    logger.info("Before validation",
                                extra={'epoch': epoch_flt})

                    overall_val_loss = 0
                    for val_it in range(iters_cnt['val']):
                        overall_val_loss += sess.run(val_total_loss)

                        logger.info("Validation in progress",
                                    extra={
                                        'epoch': epoch_flt,
                                        'val_iter': val_it,
                                        'val_iters': iters_cnt['val']
                                    })

                    metrics_values_val = {
                        'loss': overall_val_loss / iters_cnt['val'],
                    }
                    sly.report_metrics_validation(epoch_flt,
                                                  metrics_values_val)
                    logger.info("Validation has been finished",
                                extra={'epoch': epoch_flt})

                    eval_planner.validation_performed()

                    val_loss = metrics_values_val['loss']
                    model_is_best = val_loss < best_val_loss
                    if model_is_best:
                        best_val_loss = val_loss
                        logger.info(
                            'Current model is the best so far.')

                    save_cback(saver,
                               sess,
                               model_is_best,
                               opt_data={
                                   'epoch': epoch_flt,
                                   'val_metrics': metrics_values_val,
                               })

            logger.info("Epoch was finished", extra={'epoch': epoch_flt})
Example #10
0
def main(unused_argv):
  print("DEEPLABv3+")
  print("SAVE TO "+FLAGS.train_logdir)
  tf.logging.set_verbosity(tf.logging.INFO)
  # Set up deployment (i.e., multi-GPUs and/or multi-replicas).
  config = model_deploy.DeploymentConfig(
      num_clones=FLAGS.num_clones,
      clone_on_cpu=FLAGS.clone_on_cpu,
      replica_id=FLAGS.task,
      num_replicas=FLAGS.num_replicas,
      num_ps_tasks=FLAGS.num_ps_tasks)
  print("batch_norm: "+str(FLAGS.fine_tune_batch_norm))
  print("initialize_last_layer: "+str(FLAGS.initialize_last_layer))
  # Split the batch across GPUs.
  assert FLAGS.train_batch_size % config.num_clones == 0, (
      'Training batch size not divisible by number of clones (GPUs).')

  clone_batch_size = FLAGS.train_batch_size // config.num_clones

  # Get dataset-dependent information.
  dataset = segmentation_dataset.get_dataset(
      FLAGS.dataset, FLAGS.train_split, dataset_dir=FLAGS.dataset_dir)
  dataset_val = segmentation_dataset.get_dataset(
      FLAGS.dataset, FLAGS.val_split, dataset_dir=FLAGS.dataset_dir)

  tf.gfile.MakeDirs(FLAGS.train_logdir)
  tf.logging.info('Training on %s set', FLAGS.train_split)
  tf.logging.info('Validating on %s set', FLAGS.val_split)

  with tf.Graph().as_default() as graph:
    with tf.device(config.inputs_device()):
      samples = input_generator.get(
          dataset,
          FLAGS.train_crop_size,
          clone_batch_size,
          min_resize_value=FLAGS.min_resize_value,
          max_resize_value=FLAGS.max_resize_value,
          resize_factor=FLAGS.resize_factor,
          min_scale_factor=FLAGS.min_scale_factor,
          max_scale_factor=FLAGS.max_scale_factor,
          scale_factor_step_size=FLAGS.scale_factor_step_size,
          dataset_split=FLAGS.train_split,
          is_training=True,
          model_variant=FLAGS.model_variant)
      inputs_queue = prefetch_queue.prefetch_queue(
          samples, capacity=128 * config.num_clones)

    # For validation.

    samples_val = input_generator.get(
        dataset_val,
        FLAGS.train_crop_size,
        FLAGS.train_batch_size,
        min_resize_value=FLAGS.min_resize_value,
        max_resize_value=FLAGS.max_resize_value,
        resize_factor=FLAGS.resize_factor,
        dataset_split=FLAGS.val_split,
        is_training=False,
        model_variant=FLAGS.model_variant)

    # Create the global step on the device storing the variables.
    with tf.device(config.variables_device()):
      global_step = tf.train.get_or_create_global_step()

      # Define the model and create clones.
      model_fn = _build_deeplab
      model_args = (inputs_queue, {
          common.OUTPUT_TYPE: dataset.num_classes
      }, dataset.ignore_label)
      clones = model_deploy.create_clones(config, model_fn, args=model_args)

      # Gather update_ops from the first clone. These contain, for example,
      # the updates for the batch_norm variables created by model_fn.
      first_clone_scope = config.clone_scope(0)
      update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, first_clone_scope)

    # Gather initial summaries.
    summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))

    # Add summaries for model variables.
    for model_var in slim.get_model_variables():
      summaries.add(tf.summary.histogram(model_var.op.name, model_var))

    # Add summaries for images, labels, semantic predictions
    if FLAGS.save_summaries_images:
      summary_image = graph.get_tensor_by_name(
          ('%s/%s:0' % (first_clone_scope, common.IMAGE)).strip('/'))
      summaries.add(
          tf.summary.image('samples/%s' % common.IMAGE, summary_image))

      first_clone_label = graph.get_tensor_by_name(
          ('%s/%s:0' % (first_clone_scope, common.LABEL)).strip('/'))
      # Scale up summary image pixel values for better visualization.
      pixel_scaling = max(1, 255 // dataset.num_classes)
      summary_label = tf.cast(first_clone_label * pixel_scaling, tf.uint8)
      summaries.add(
          tf.summary.image('samples/%s' % common.LABEL, summary_label))

      first_clone_output = graph.get_tensor_by_name(
          ('%s/%s:0' % (first_clone_scope, common.OUTPUT_TYPE)).strip('/'))
      predictions = tf.expand_dims(tf.argmax(first_clone_output, 3), -1)

      summary_predictions = tf.cast(predictions * pixel_scaling, tf.uint8)
      summaries.add(
          tf.summary.image(
              'samples/%s' % common.OUTPUT_TYPE, summary_predictions))

    # Add summaries for losses.
    for loss in tf.get_collection(tf.GraphKeys.LOSSES, first_clone_scope):
      summaries.add(tf.summary.scalar('losses/%s' % loss.op.name, loss))

    # Build the optimizer based on the device specification.
    with tf.device(config.optimizer_device()):
      learning_rate = train_utils.get_model_learning_rate(
          FLAGS.learning_policy, FLAGS.base_learning_rate,
          FLAGS.learning_rate_decay_step, FLAGS.learning_rate_decay_factor,
          FLAGS.training_number_of_steps, FLAGS.learning_power,
          FLAGS.slow_start_step, FLAGS.slow_start_learning_rate)
      optimizer = tf.train.MomentumOptimizer(learning_rate, FLAGS.momentum)
      summaries.add(tf.summary.scalar('learning_rate', learning_rate))

    startup_delay_steps = FLAGS.task * FLAGS.startup_delay_steps
    for variable in slim.get_model_variables():
      summaries.add(tf.summary.histogram(variable.op.name, variable))

    with tf.device(config.variables_device()):
      total_loss, grads_and_vars = model_deploy.optimize_clones(
          clones, optimizer)
      total_loss = tf.check_numerics(total_loss, 'Loss is inf or nan.')
      summaries.add(tf.summary.scalar('total_loss', total_loss))

      # Modify the gradients for biases and last layer variables.
      last_layers = model.get_extra_layer_scopes(
          FLAGS.last_layers_contain_logits_only)
      grad_mult = train_utils.get_model_gradient_multipliers(
          last_layers, FLAGS.last_layer_gradient_multiplier)
      if grad_mult:
        grads_and_vars = slim.learning.multiply_gradients(
            grads_and_vars, grad_mult)

      # Create gradient update op.
      grad_updates = optimizer.apply_gradients(
          grads_and_vars, global_step=global_step)
      update_ops.append(grad_updates)
      update_op = tf.group(*update_ops)
      with tf.control_dependencies([update_op]):
        train_tensor = tf.identity(total_loss, name='train_op')

    # Add the summaries from the first clone. These contain the summaries
    # created by model_fn and either optimize_clones() or _gather_clone_loss().
    summaries |= set(
        tf.get_collection(tf.GraphKeys.SUMMARIES, first_clone_scope))

    # Merge all summaries together.
    summary_op = tf.summary.merge(list(summaries))

    # Soft placement allows placing on CPU ops without GPU implementation.
    session_config = tf.ConfigProto(
        allow_soft_placement=True, log_device_placement=False)

    # For validation.
    model_options = common.ModelOptions(
        outputs_to_num_classes={common.OUTPUT_TYPE: dataset.num_classes},
        crop_size=FLAGS.train_crop_size,
        atrous_rates=FLAGS.atrous_rates,
        output_stride=FLAGS.output_stride)
    predictions_val = model.predict_labels(samples_val[common.IMAGE], model_options,
                                         image_pyramid=FLAGS.image_pyramid)
    predictions_val = predictions_val[common.OUTPUT_TYPE]
    predictions_val = tf.reshape(predictions_val, shape=[-1])
    labels_val = tf.reshape(samples_val[common.LABEL], shape=[-1])

    # Set ignore_label regions to label 0, because metrics.mean_iou requires
    # range of labels = [0, dataset.num_classes). Note the ignore_label regions
    # are not evaluated since the corresponding regions contain weights = 0.
    #labels = tf.where(
    #    tf.equal(labels, dataset.ignore_label), tf.zeros_like(labels), labels)
    accuracy_validation = slim.metrics.accuracy(tf.to_int32(predictions_val),
                                                tf.to_int32(labels_val))
    iou, conf_mat = tf.metrics.mean_iou(labels_val, predictions_val, num_classes=6)
    #sess.run(tf.local_variables_initializer())


    def train_step_fn(session, *args, **kwargs):
        total_loss, should_stop = train_step(session, *args, **kwargs)

        if train_step_fn.step % FLAGS.validation_check == 0:
            pass
            # throws OutOfRange error after some time
         #   accuracy = session.run(train_step_fn.accuracy_validation)
          #  print('Step %s - Loss: %.2f Accuracy: %.2f%%' % (
          #  str(train_step_fn.step).rjust(6, '0'), total_loss, accuracy * 100))

     #   if train_step_fn.step == (FLAGS.max_steps - 1):
     #       accuracy = session.run(accuracy_test)
     #       print('%s - Loss: %.2f Accuracy: %.2f%%' % ('FINAL TEST', total_loss, accuracy * 100))

        train_step_fn.step += 1
        return [total_loss, should_stop]

    train_step_fn.step = 0
    train_step_fn.accuracy_validation = accuracy_validation

    # Start the training.
    slim.learning.train(
        train_tensor,
        train_step_fn=train_step_fn,
        logdir=FLAGS.train_logdir,
        log_every_n_steps=FLAGS.log_steps,
        master=FLAGS.master,
        number_of_steps=FLAGS.training_number_of_steps,
        is_chief=(FLAGS.task == 0),
        session_config=session_config,
        startup_delay_steps=startup_delay_steps,
        init_fn=train_utils.get_model_init_fn(
            FLAGS.train_logdir,
            FLAGS.tf_initial_checkpoint,
            FLAGS.initialize_last_layer,
            last_layers,
            ignore_missing_vars=True),
        summary_op=summary_op,
        save_summaries_secs=FLAGS.save_summaries_secs,
        save_interval_secs=FLAGS.save_interval_secs)
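
The train_step_fn hook above follows slim's contract: slim.learning.train calls it as train_step_fn(sess, train_op, global_step, train_step_kwargs), it should delegate to slim.learning.train_step for the actual update, and it must return [loss, should_stop]. A minimal sketch that logs every N steps (the interval and message are illustrative, not from the original):

import tensorflow as tf
from tensorflow.contrib import slim

def make_logging_step_fn(log_every=100):
    def step_fn(sess, train_op, global_step, train_step_kwargs):
        # Run the regular slim training step.
        loss, should_stop = slim.learning.train_step(
            sess, train_op, global_step, train_step_kwargs)
        step_fn.step += 1
        if step_fn.step % log_every == 0:
            tf.logging.info('step %d: loss = %.4f', step_fn.step, loss)
        return [loss, should_stop]
    step_fn.step = 0
    return step_fn
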
def main(unused_argv):
    FLAGS.train_logdir = FLAGS.base_logdir + '/' + FLAGS.task_name
    if FLAGS.restore_name is None:
        FLAGS.restore_logdir = FLAGS.train_logdir
    else:
        FLAGS.restore_logdir = FLAGS.base_logdir + '/' + FLAGS.restore_name

    tf.logging.set_verbosity(tf.logging.INFO)

    # Get logging dir ready.
    if not (os.path.isdir(FLAGS.train_logdir)):
        tf.gfile.MakeDirs(FLAGS.train_logdir)
    elif len(os.listdir(FLAGS.train_logdir)) != 0:
        if not (FLAGS.if_restore):
            if_delete_all = input(
                '#### The log folder %s exists and non-empty; delete all logs? [y/n] '
                % FLAGS.train_logdir)
            if if_delete_all == 'y':
                os.system('rm -rf %s/*' % FLAGS.train_logdir)
                print('==== Log folder emptied.')
        else:
            print('==== Log folder exists; not emptying it because we need to restore from it.')
    tf.logging.info('==== Logging in dir:%s; Training on %s set',
                    FLAGS.train_logdir, FLAGS.train_split)

    # Set up deployment (i.e., multi-GPUs and/or multi-replicas).
    config = model_deploy.DeploymentConfig(
        num_clones=FLAGS.num_clones,
        clone_on_cpu=FLAGS.clone_on_cpu,
        replica_id=FLAGS.task,
        num_replicas=FLAGS.num_replicas,
        num_ps_tasks=FLAGS.num_ps_tasks)  # /device:CPU:0

    # Split the batch across GPUs.
    assert FLAGS.train_batch_size % config.num_clones == 0, (
        'Training batch size not divisible by number of clones (GPUs).')
    clone_batch_size = FLAGS.train_batch_size // config.num_clones

    # Get dataset-dependent information.
    dataset = regression_dataset.get_dataset(FLAGS.dataset,
                                             FLAGS.train_split,
                                             dataset_dir=FLAGS.dataset_dir)
    dataset_val = regression_dataset.get_dataset(FLAGS.dataset,
                                                 FLAGS.val_split,
                                                 dataset_dir=FLAGS.dataset_dir)
    print('#### The data has size:', dataset.num_samples, dataset_val.num_samples)

    codes = np.load(
        '/ssd2/public/zhurui/Documents/mesh-voxelization/models/cars_64/codes.npy'
    )

    with tf.Graph().as_default() as graph:
        with tf.device(config.inputs_device()):
            codes_max = np.amax(codes, axis=1).reshape((-1, 1))
            codes_min = np.amin(codes, axis=1).reshape((-1, 1))
            shape_range = np.hstack(
                (codes_max + (codes_max - codes_min) /
                 (dataset.SHAPE_BINS - 1.), codes_min -
                 (codes_max - codes_min) / (dataset.SHAPE_BINS - 1.)))
            bin_range = [
                np.linspace(r[0], r[1], num=b).tolist()
                for r, b in zip(np.vstack((dataset.pose_range,
                                           shape_range)), dataset.bin_nums)
            ]
            # print np.vstack((dataset.pose_range, shape_range))
            # print bin_range[0]
            # print bin_range[-1]
            outputs_to_num_classes = {}
            outputs_to_indices = {}
            for output, bin_num, idx in zip(dataset.output_names,
                                            dataset.bin_nums,
                                            range(len(dataset.output_names))):
                if FLAGS.if_discrete_loss:
                    outputs_to_num_classes[output] = bin_num
                else:
                    outputs_to_num_classes[output] = 1
                outputs_to_indices[output] = idx
            bin_vals = [tf.constant(value=[bin_range[i]], dtype=tf.float32, shape=[1, dataset.bin_nums[i]], name=name) \
                    for i, name in enumerate(dataset.output_names)]
            # print outputs_to_num_classes
            # print spaces_to_indices

            samples = input_generator.get(dataset,
                                          codes,
                                          clone_batch_size,
                                          dataset_split=FLAGS.train_split,
                                          is_training=True,
                                          model_variant=FLAGS.model_variant)
            inputs_queue = prefetch_queue.prefetch_queue(samples,
                                                         capacity=128 *
                                                         config.num_clones)

            samples_val = input_generator.get(
                dataset_val,
                codes,
                clone_batch_size,
                dataset_split=FLAGS.val_split,
                is_training=False,
                model_variant=FLAGS.model_variant)
            inputs_queue_val = prefetch_queue.prefetch_queue(samples_val,
                                                             capacity=128)

        # Create the global step on the device storing the variables.
        with tf.device(config.variables_device()):
            global_step = tf.train.get_or_create_global_step()

            # Define the model and create clones.
            model_fn = _build_deeplab
            model_args = (FLAGS, inputs_queue.dequeue(),
                          outputs_to_num_classes, outputs_to_indices, bin_vals,
                          bin_range, dataset, codes, True, False)
            clones = model_deploy.create_clones(config,
                                                model_fn,
                                                args=model_args)

            # Gather update_ops from the first clone. These contain, for example,
            # the updates for the batch_norm variables created by model_fn.
            first_clone_scope = config.clone_scope(0)  # clone_0
            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS,
                                           first_clone_scope)

        with tf.device('/device:GPU:3'):
            if FLAGS.if_val:
                ## Construct the validation graph; takes one GPU.
                _build_deeplab(FLAGS,
                               inputs_queue_val.dequeue(),
                               outputs_to_num_classes,
                               outputs_to_indices,
                               bin_vals,
                               bin_range,
                               dataset_val,
                               codes,
                               is_training=False,
                               reuse=True)

        # Gather initial summaries.
        summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))

        # Add summaries for images, labels, semantic predictions
        summary_loss_dict = {}
        if FLAGS.save_summaries_images:
            if FLAGS.num_clones > 1:
                pattern_train = first_clone_scope + '/%s:0'
            else:
                pattern_train = '%s:0'
            pattern_val = 'val-%s:0'
            pattern = pattern_val if FLAGS.if_val else pattern_train

            gather_list = [0] if FLAGS.num_clones < 3 else [0, 1, 2]

            summary_mask = graph.get_tensor_by_name(pattern %
                                                    'not_ignore_mask_in_loss')
            summary_mask = tf.reshape(summary_mask,
                                      [-1, dataset.height, dataset.width, 1])
            summary_mask_float = tf.to_float(summary_mask)
            summaries.add(
                tf.summary.image(
                    'gt/%s' % 'not_ignore_mask',
                    tf.gather(tf.cast(summary_mask_float * 255., tf.uint8),
                              gather_list)))

            summary_image = graph.get_tensor_by_name(pattern % common.IMAGE)
            summaries.add(
                tf.summary.image('gt/%s' % common.IMAGE,
                                 tf.gather(summary_image, gather_list)))

            summary_image_name = graph.get_tensor_by_name(pattern %
                                                          common.IMAGE_NAME)
            summaries.add(
                tf.summary.text('gt/%s' % common.IMAGE_NAME,
                                tf.gather(summary_image_name, gather_list)))

            summary_image_name = graph.get_tensor_by_name(pattern_train %
                                                          common.IMAGE_NAME)
            summaries.add(
                tf.summary.text('gt/%s_train' % common.IMAGE_NAME,
                                tf.gather(summary_image_name, gather_list)))

            summary_vis = graph.get_tensor_by_name(pattern % 'vis')
            summaries.add(
                tf.summary.image('gt/%s' % 'vis',
                                 tf.gather(summary_vis, gather_list)))

            def scale_to_255(tensor, pixel_scaling=None):
                tensor = tf.to_float(tensor)
                if pixel_scaling is None:
                    offset_to_zero = tf.reduce_min(tensor)
                    scale_to_255 = tf.div(
                        255., tf.reduce_max(tensor - offset_to_zero))
                else:
                    offset_to_zero, scale_to_255 = pixel_scaling
                summary_tensor_float = tensor - offset_to_zero
                summary_tensor_float = summary_tensor_float * scale_to_255
                summary_tensor_float = tf.clip_by_value(
                    summary_tensor_float, 0., 255.)
                summary_tensor_uint8 = tf.cast(summary_tensor_float, tf.uint8)
                return summary_tensor_uint8, (offset_to_zero, scale_to_255)

            label_outputs = graph.get_tensor_by_name(pattern %
                                                     'label_pose_shape_map')
            label_id_outputs = graph.get_tensor_by_name(
                pattern % 'pose_shape_label_id_map')
            logit_outputs = graph.get_tensor_by_name(
                pattern % 'scaled_prob_logits_pose_shape_map')

            summary_rot_diffs = graph.get_tensor_by_name(pattern %
                                                         'rot_error_map')
            summary_rot_diffs = tf.where(summary_mask, summary_rot_diffs,
                                         tf.zeros_like(summary_rot_diffs))
            summary_rot_diffs_uint8, _ = scale_to_255(summary_rot_diffs)
            summaries.add(
                tf.summary.image(
                    'metrics_map/%s' % 'rot_diffs',
                    tf.gather(summary_rot_diffs_uint8, gather_list)))

            summary_trans_diffs = graph.get_tensor_by_name(pattern %
                                                           'trans_error_map')
            summary_trans_diffs = tf.where(summary_mask, summary_trans_diffs,
                                           tf.zeros_like(summary_trans_diffs))
            summary_trans_diffs_uint8, _ = scale_to_255(summary_trans_diffs)
            summaries.add(
                tf.summary.image('metrics_map/%s' % 'trans_diffs',
                                 tf.gather(summary_trans_diffs, gather_list)))

            shape_id_outputs = graph.get_tensor_by_name(pattern %
                                                        'shape_id_map')
            shape_id_outputs = tf.where(summary_mask, shape_id_outputs + 1,
                                        tf.zeros_like(shape_id_outputs))
            summary_shape_id_output_uint8, _ = scale_to_255(shape_id_outputs)
            summaries.add(
                tf.summary.image(
                    'shape/shape_id_map',
                    tf.gather(summary_shape_id_output_uint8, gather_list)))

            shape_id_outputs_gt = graph.get_tensor_by_name(pattern %
                                                           'shape_id_map_gt')
            shape_id_outputs_gt = tf.where(summary_mask,
                                           shape_id_outputs_gt + 1,
                                           tf.zeros_like(shape_id_outputs))
            summary_shape_id_output_uint8_gt, _ = scale_to_255(
                shape_id_outputs_gt)
            summaries.add(
                tf.summary.image(
                    'shape/shape_id_map_gt',
                    tf.gather(summary_shape_id_output_uint8_gt, gather_list)))

            if FLAGS.if_summary_metrics:
                shape_id_outputs = graph.get_tensor_by_name(
                    pattern % 'shape_id_map_predict')
                summary_shape_id_output = tf.where(
                    summary_mask, shape_id_outputs,
                    tf.zeros_like(shape_id_outputs))
                summary_shape_id_output_uint8, _ = scale_to_255(
                    summary_shape_id_output)
                summaries.add(
                    tf.summary.image(
                        'shape/shape_id_map_predict',
                        tf.gather(summary_shape_id_output_uint8, gather_list)))

                shape_id_sim_map_train = graph.get_tensor_by_name(
                    pattern_train % 'shape_id_sim_map')
                # shape_id_sim_map_train = tf.where(summary_mask, shape_id_sim_map_train, tf.zeros_like(shape_id_sim_map_train))
                shape_id_sim_map_uint8_train, _ = scale_to_255(
                    shape_id_sim_map_train, pixel_scaling=(0., 255.))
                summaries.add(
                    tf.summary.image(
                        'metrics_map/shape_id_sim_map-trainInv',
                        tf.gather(shape_id_sim_map_uint8_train, gather_list)))

                shape_id_sim_map = graph.get_tensor_by_name(pattern %
                                                            'shape_id_sim_map')
                # shape_id_sim_map = tf.where(summary_mask, shape_id_sim_map, tf.zeros_like(shape_id_sim_map))
                shape_id_sim_map_uint8, _ = scale_to_255(shape_id_sim_map,
                                                         pixel_scaling=(0.,
                                                                        255.))
                summaries.add(
                    tf.summary.image(
                        'metrics_map/shape_id_sim_map-valInv',
                        tf.gather(shape_id_sim_map_uint8, gather_list)))

            for output_idx, output in enumerate(dataset.output_names):
                # # Scale up summary image pixel values for better visualization.
                summary_label_output = tf.gather(label_outputs, [output_idx],
                                                 axis=3)
                summary_label_output = tf.where(
                    summary_mask, summary_label_output,
                    tf.zeros_like(summary_label_output))
                summary_label_output_uint8, pixel_scaling = scale_to_255(
                    summary_label_output)
                summaries.add(
                    tf.summary.image(
                        'output/%s_label' % output,
                        tf.gather(summary_label_output_uint8, gather_list)))

                summary_logit_output = tf.gather(logit_outputs, [output_idx],
                                                 axis=3)
                summary_logit_output = tf.where(
                    summary_mask, summary_logit_output,
                    tf.zeros_like(summary_logit_output))
                summary_logit_output_uint8, _ = scale_to_255(
                    summary_logit_output, pixel_scaling)
                summaries.add(
                    tf.summary.image(
                        'output/%s_logit' % output,
                        tf.gather(summary_logit_output_uint8, gather_list)))

                # summary_label_id_output = tf.to_float(tf.gather(label_id_outputs, [output_idx], axis=3))
                # summary_label_id_output = tf.where(summary_mask, summary_label_id_output+1, tf.zeros_like(summary_label_id_output))
                # summary_label_id_output_uint8, _ = scale_to_255(summary_label_id_output)
                # summary_label_id_output_uint8 = tf.identity(summary_label_id_output_uint8, 'tttt'+output)
                # summaries.add(tf.summary.image(
                #     'test/%s_label_id' % output, tf.gather(summary_label_id_output_uint8, gather_list)))

                summary_diff = tf.abs(
                    tf.to_float(summary_label_output_uint8) -
                    tf.to_float(summary_logit_output_uint8))
                summary_diff = tf.where(summary_mask, summary_diff,
                                        tf.zeros_like(summary_diff))
                summaries.add(
                    tf.summary.image(
                        'diff_map/%s_ldiff' % output,
                        tf.gather(tf.cast(summary_diff, tf.uint8),
                                  gather_list)))

                summary_loss = graph.get_tensor_by_name(
                    (pattern % 'loss_slice_reg_').replace(':0', '') + output +
                    ':0')
                summaries.add(
                    tf.summary.scalar(
                        'slice_loss/' + (pattern % 'reg_').replace(':0', '') +
                        output, summary_loss))

                summary_loss = graph.get_tensor_by_name(
                    (pattern % 'loss_slice_cls_').replace(':0', '') + output +
                    ':0')
                summaries.add(
                    tf.summary.scalar(
                        'slice_loss/' + (pattern % 'cls_').replace(':0', '') +
                        output, summary_loss))

            for pattern in [pattern_train, pattern_val
                            ] if FLAGS.if_val else [pattern_train]:
                add_metrics = ['loss_all_shape_id_cls_metric'
                               ] if FLAGS.if_summary_metrics else []
                for loss_name in [
                        'loss_reg_rot_quat_metric', 'loss_reg_rot_quat',
                        'loss_reg_trans_metric', 'loss_reg_trans',
                        'loss_cls_ALL', 'loss_reg_shape'
                ] + add_metrics:
                    if pattern == pattern_val:
                        summary_loss_avg = graph.get_tensor_by_name(pattern %
                                                                    loss_name)
                        # summary_loss_dict['val-'+loss_name] = summary_loss_avg
                    else:
                        summary_loss_avg = train_utils.get_avg_tensor_from_scopes(
                            FLAGS.num_clones, '%s:0', graph, config, loss_name)
                        # summary_loss_dict['train-'+loss_name] = summary_loss_avg
                    summaries.add(
                        tf.summary.scalar(
                            ('total_loss/' + pattern % loss_name).replace(
                                ':0', ''), summary_loss_avg))

        # Build the optimizer based on the device specification.
        with tf.device(config.optimizer_device()):
            learning_rate = train_utils.get_model_learning_rate(
                FLAGS.learning_policy, FLAGS.base_learning_rate,
                FLAGS.learning_rate_decay_step,
                FLAGS.learning_rate_decay_factor,
                FLAGS.training_number_of_steps, FLAGS.learning_power,
                FLAGS.slow_start_step, FLAGS.slow_start_learning_rate)
            optimizer = tf.train.MomentumOptimizer(learning_rate,
                                                   FLAGS.momentum)
            # optimizer = tf.train.AdamOptimizer(learning_rate)
            summaries.add(tf.summary.scalar('learning_rate', learning_rate))

        startup_delay_steps = FLAGS.task * FLAGS.startup_delay_steps

        with tf.device(config.variables_device()):
            total_loss, grads_and_vars = model_deploy.optimize_clones(
                clones, optimizer)
            print('------ total_loss', total_loss,
                  tf.get_collection(tf.GraphKeys.LOSSES, first_clone_scope))
            total_loss = tf.check_numerics(total_loss, 'Loss is inf or nan.')
            summaries.add(tf.summary.scalar('total_loss/train', total_loss))

            # Modify the gradients for biases and last layer variables.
            last_layers = model.get_extra_layer_scopes(
                FLAGS.last_layers_contain_logits_only)
            print('////last layers', last_layers)

            # Filter trainable variables for last layers ONLY.
            # grads_and_vars = train_utils.filter_gradients(last_layers, grads_and_vars)

            grad_mult = train_utils.get_model_gradient_multipliers(
                last_layers, FLAGS.last_layer_gradient_multiplier)
            if grad_mult:
                grads_and_vars = slim.learning.multiply_gradients(
                    grads_and_vars, grad_mult)

            # Create gradient update op.
            grad_updates = optimizer.apply_gradients(grads_and_vars,
                                                     global_step=global_step)
            update_ops.append(grad_updates)
            update_op = tf.group(*update_ops)
            with tf.control_dependencies([update_op]):
                train_tensor = tf.identity(total_loss, name='train_op')

        # Add the summaries from the first clone. These contain the summaries
        # created by model_fn and either optimize_clones() or _gather_clone_loss().
        summaries |= set(
            tf.get_collection(tf.GraphKeys.SUMMARIES, first_clone_scope))

        # Merge all summaries together.
        summary_op = tf.summary.merge(list(summaries))

        # Soft placement allows placing on CPU ops without GPU implementation.
        session_config = tf.ConfigProto(allow_soft_placement=True,
                                        log_device_placement=False)
        session_config.gpu_options.allow_growth = False

        def train_step_fn(sess, train_op, global_step, train_step_kwargs):
            train_step_fn.step += 1  # or use global_step.eval(session=sess)

            # calc training losses
            loss, should_stop = slim.learning.train_step(
                sess, train_op, global_step, train_step_kwargs)
            print(loss)
            # print('loss: ', loss)
            # first_clone_test = graph.get_tensor_by_name(
            #         ('%s/%s:0' % (first_clone_scope, 'shape_map')).strip('/'))
            # test = sess.run(first_clone_test)
            # # print test
            # print 'test: ', test.shape, np.max(test), np.min(test), np.mean(test), test.dtype
            should_stop = 0
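            # Note: this overrides the stop signal returned by train_step above.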

            if FLAGS.if_val and train_step_fn.step % FLAGS.val_interval_steps == 0:
                # first_clone_test = graph.get_tensor_by_name('val-loss_all:0')
                # test = sess.run(first_clone_test)
                print('-- Validating...')
                first_clone_test = graph.get_tensor_by_name(
                    ('%s/%s:0' %
                     (first_clone_scope, 'shape_id_map')).strip('/'))
                first_clone_test2 = graph.get_tensor_by_name(
                    ('%s/%s:0' %
                     (first_clone_scope, 'shape_id_sim_map')).strip('/'))
                # 'ttttrow:0')
                first_clone_test3 = graph.get_tensor_by_name((
                    '%s/%s:0' %
                    (first_clone_scope, 'not_ignore_mask_in_loss')).strip('/'))
                # 'ttttrow:0')
                test_out, test_out2, test_out3 = sess.run(
                    [first_clone_test, first_clone_test2, first_clone_test3])
                # test_out = test[:, :, :, 3]
                test_out = test_out[test_out3]
                # test_out2 = test2[:, :, :, 3]
                test_out2 = test_out2[test_out3]
                # print test_out
                print('shape_id_map: ', test_out.shape, np.max(test_out),
                      np.min(test_out), np.mean(test_out),
                      np.median(test_out), test_out.dtype)
                print('shape_id_sim_map: ', test_out2.shape,
                      np.max(test_out2), np.min(test_out2),
                      np.mean(test_out2), np.median(test_out2),
                      test_out2.dtype)
                print('masks sum: ', test_out3.dtype,
                      np.sum(test_out3.astype(float)))
                # assert np.max(test_out) == np.max(test_out2), 'MAtch1!!!'
                # assert np.min(test_out) == np.min(test_out2), 'MAtch2!!!'

            # first_clone_label = graph.get_tensor_by_name(
            #         ('%s/%s:0' % (first_clone_scope, 'pose_map')).strip('/')) # clone_0/val-loss:0
            # # first_clone_pose_dict = graph.get_tensor_by_name(
            # #         ('%s/%s:0' % (first_clone_scope, 'pose_dict')).strip('/'))
            # first_clone_logit = graph.get_tensor_by_name(
            #         ('%s/%s:0' % (first_clone_scope, 'scaled_regression')).strip('/'))
            # not_ignore_mask = graph.get_tensor_by_name(
            #         ('%s/%s:0' % (first_clone_scope, 'not_ignore_mask_in_loss')).strip('/'))
            # label, logits, mask = sess.run([first_clone_label, first_clone_logit, not_ignore_mask])
            # mask = np.reshape(mask, (-1, FLAGS.train_crop_size[0], FLAGS.train_crop_size[1], dataset.num_classes))

            # print '... shapes, types, loss', label.shape, label.dtype, logits.shape, logits.dtype, loss
            # print 'mask', mask.shape, np.mean(mask)
            # logits[mask==0.] = 0.
            # print 'logits', logits.shape, np.max(logits), np.min(logits), np.mean(logits), logits.dtype
            # for idx in range(6):
            #     print idx, np.max(label[:, :, :, idx]), np.min(label[:, :, :, idx])
            # label = label[:, :, :, 5]
            # print 'label', label.shape, np.max(label), np.min(label), np.mean(label), label.dtype
            # print pose_dict, pose_dict.shape
            # # print 'training....... logits stats: ', np.max(logits), np.min(logits), np.mean(logits)
            # # label_one_piece = label[0, :, :, 0]
            # # print 'training....... label stats', np.max(label_one_piece), np.min(label_one_piece), np.sum(label_one_piece[label_one_piece!=255.])
            return [loss, should_stop]

        train_step_fn.step = 0

        # trainables = [v.name for v in tf.trainable_variables()]
        # alls =[v.name for v in tf.all_variables()]
        # print '----- Trainables %d: '%len(trainables), trainables
        # print '----- All %d: '%len(alls), alls[:10]
        # print '===== ', len(list(set(trainables) - set(alls)))
        # print '===== ', len(list(set(alls) - set(trainables)))

        if FLAGS.if_print_tensors:
            for op in tf.get_default_graph().get_operations():
                print(str(op.name))

        # Start the training.
        slim.learning.train(train_tensor,
                            train_step_fn=train_step_fn,
                            logdir=FLAGS.train_logdir,
                            log_every_n_steps=FLAGS.log_steps,
                            master=FLAGS.master,
                            number_of_steps=FLAGS.training_number_of_steps,
                            is_chief=(FLAGS.task == 0),
                            session_config=session_config,
                            startup_delay_steps=startup_delay_steps,
                            init_fn=train_utils.get_model_init_fn(
                                FLAGS.restore_logdir,
                                FLAGS.tf_initial_checkpoint,
                                FLAGS.if_restore,
                                FLAGS.initialize_last_layer,
                                last_layers,
                                ignore_missing_vars=True),
                            summary_op=summary_op,
                            save_summaries_secs=FLAGS.save_summaries_secs,
                            save_interval_secs=FLAGS.save_interval_secs)
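
Side note on the pattern above: `slim.learning.train` accepts a custom `train_step_fn`, which is how this example wires its debug hooks into the training loop. A minimal sketch of such a hook (a hypothetical `my_train_step_fn`, assuming TF 1.x contrib slim) delegates the real step to `slim.learning.train_step` and keeps its counter as a function attribute, just like `train_step_fn.step` above:

def my_train_step_fn(sess, train_op, global_step, train_step_kwargs):
    # Run the standard slim training step; it returns (total_loss, should_stop).
    total_loss, should_stop = slim.learning.train_step(
        sess, train_op, global_step, train_step_kwargs)
    # Function attributes persist across calls, so they can carry loop state.
    my_train_step_fn.step += 1
    return total_loss, should_stop

my_train_step_fn.step = 0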
Exemple #12
0
    def train(self):
        FLAGS = self.flags
        dataset_split = 'train'
        data_config = edict()
        data_config.edge_width = 20
        data_config.ignore_label = DATASETS_IGNORE_LABEL[FLAGS.dataset]
        data_config.edge_class_num = FLAGS.edge_class_num
        img_files, label_files = get_dataset_files(FLAGS.dataset,
                                                   dataset_split)

        dataset = edict()
        dataset_pp = dataset_pipeline(data_config,
                                      img_files,
                                      label_files,
                                      is_train=True)
        dataset.num_classes = DATASETS_CLASS_NUM[FLAGS.dataset]
        dataset.ignore_label = DATASETS_IGNORE_LABEL[FLAGS.dataset]
        dataset.num_samples = len(dataset_pp)

        tf.logging.set_verbosity(tf.logging.INFO)
        # Set up deployment (i.e., multi-GPUs and/or multi-replicas).
        config = model_deploy.DeploymentConfig(num_clones=FLAGS.num_clones,
                                               clone_on_cpu=FLAGS.clone_on_cpu,
                                               replica_id=FLAGS.task,
                                               num_replicas=FLAGS.num_replicas,
                                               num_ps_tasks=FLAGS.num_ps_tasks)

        # Split the batch across GPUs.
        assert FLAGS.train_batch_size % config.num_clones == 0, (
            'Training batch size not divisible by number of clones (GPUs).')

        clone_batch_size = FLAGS.train_batch_size // config.num_clones

        # Get dataset-dependent information.
        #        dataset = segmentation_dataset.get_dataset(
        #            FLAGS.dataset, FLAGS.train_split, dataset_dir=FLAGS.dataset_dir)

        tf.gfile.MakeDirs(FLAGS.train_logdir)
        tf.logging.info('Training on %s set', FLAGS.train_split)

        with tf.Graph().as_default() as graph:
            with tf.device(config.inputs_device()):
                data_list = dataset_pp.iterator()
                samples = input_generator.get(
                    (data_list, dataset.ignore_label),
                    FLAGS.train_crop_size,
                    clone_batch_size,
                    min_resize_value=FLAGS.min_resize_value,
                    max_resize_value=FLAGS.max_resize_value,
                    resize_factor=FLAGS.resize_factor,
                    min_scale_factor=FLAGS.min_scale_factor,
                    max_scale_factor=FLAGS.max_scale_factor,
                    scale_factor_step_size=FLAGS.scale_factor_step_size,
                    dataset_split=FLAGS.train_split,
                    is_training=True,
                    model_variant=FLAGS.model_variant)
                inputs_queue = prefetch_queue.prefetch_queue(
                    samples, capacity=128 * config.num_clones)

            # Create the global step on the device storing the variables.
            with tf.device(config.variables_device()):
                global_step = tf.train.get_or_create_global_step()

                # Define the model and create clones.
                model_fn = self._build_deeplab
                model_args = (inputs_queue, {
                    common.OUTPUT_TYPE: dataset.num_classes
                }, dataset.ignore_label)
                clones = model_deploy.create_clones(config,
                                                    model_fn,
                                                    args=model_args)

                # Gather update_ops from the first clone. These contain, for example,
                # the updates for the batch_norm variables created by model_fn.
                first_clone_scope = config.clone_scope(0)
                update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS,
                                               first_clone_scope)

            # Gather initial summaries.
            summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))

            # Add summaries for model variables.
            for model_var in slim.get_model_variables():
                summaries.add(
                    tf.summary.histogram(model_var.op.name, model_var))

            label_name = ('%s/%s:0' %
                          (first_clone_scope, common.LABEL)).strip('/')
            print('first clone label name is:', label_name)

            # Add summaries for images, labels, semantic predictions
            if FLAGS.save_summaries_images:
                summary_image = graph.get_tensor_by_name(
                    ('%s/%s:0' % (first_clone_scope, common.IMAGE)).strip('/'))
                summaries.add(
                    tf.summary.image('samples/%s' % common.IMAGE,
                                     summary_image))

                first_clone_label = graph.get_tensor_by_name(
                    ('%s/%s:0' % (first_clone_scope, common.LABEL)).strip('/'))

                # Scale up summary image pixel values for better visualization.
                pixel_scaling = max(1, 255 // dataset.num_classes)
                summary_label = tf.cast(first_clone_label * pixel_scaling,
                                        tf.uint8)
                summaries.add(
                    tf.summary.image('samples/%s' % common.LABEL,
                                     summary_label))

                first_clone_output = graph.get_tensor_by_name(
                    ('%s/%s:0' %
                     (first_clone_scope, common.OUTPUT_TYPE)).strip('/'))
                predictions = tf.expand_dims(tf.argmax(first_clone_output, 3),
                                             -1)

                summary_predictions = tf.cast(predictions * pixel_scaling,
                                              tf.uint8)
                summaries.add(
                    tf.summary.image('samples/%s' % common.OUTPUT_TYPE,
                                     summary_predictions))

            # Add summaries for mIoU and accuracy.
            labels = graph.get_tensor_by_name(
                ('%s/%s:0' % (first_clone_scope, common.LABEL)).strip('/'))
            predictions = graph.get_tensor_by_name(
                ('%s/%s:0' %
                 (first_clone_scope, common.OUTPUT_TYPE)).strip('/'))
            predictions = tf.image.resize_bilinear(predictions,
                                                   tf.shape(labels)[1:3],
                                                   align_corners=True)
            # predictions shape (2, 513, 513, 19/21)
            print('predictions shape', predictions.shape)
            self.get_metric(labels, predictions, 'train')

            # Add summaries for losses.
            for loss in tf.get_collection(tf.GraphKeys.LOSSES,
                                          first_clone_scope):
                summaries.add(
                    tf.summary.scalar('losses/%s' % loss.op.name, loss))

#            losses = {}
#            for key in [common.OUTPUT_TYPE,common.EDGE]:
#                losses[key]=graph.get_tensor_by_name(name='losses/%s:0'%key)
#                summaries.add(tf.summary.scalar('losses/'+key,losses[key]))

            # Build the optimizer based on the device specification.
            with tf.device(config.optimizer_device()):
                learning_rate = train_utils.get_model_learning_rate(
                    FLAGS.learning_policy, FLAGS.base_learning_rate,
                    FLAGS.learning_rate_decay_step,
                    FLAGS.learning_rate_decay_factor,
                    FLAGS.training_number_of_steps, FLAGS.learning_power,
                    FLAGS.slow_start_step, FLAGS.slow_start_learning_rate)
                optimizer = tf.train.MomentumOptimizer(learning_rate,
                                                       FLAGS.momentum)
                summaries.add(tf.summary.scalar('learning_rate',
                                                learning_rate))

            startup_delay_steps = FLAGS.task * FLAGS.startup_delay_steps
            for variable in slim.get_model_variables():
                summaries.add(tf.summary.histogram(variable.op.name, variable))

            with tf.device(config.variables_device()):
                total_loss, grads_and_vars = model_deploy.optimize_clones(
                    clones, optimizer)
                total_loss = tf.check_numerics(total_loss,
                                               'Loss is inf or nan.')
                summaries.add(
                    tf.summary.scalar('losses/total_loss', total_loss))

                # Modify the gradients for biases and last layer variables.
                last_layers = model.get_extra_layer_scopes(
                    FLAGS.last_layers_contain_logits_only)
                grad_mult = train_utils.get_model_gradient_multipliers(
                    last_layers, FLAGS.last_layer_gradient_multiplier)
                if grad_mult:
                    grads_and_vars = slim.learning.multiply_gradients(
                        grads_and_vars, grad_mult)

                # Create gradient update op.
                grad_updates = optimizer.apply_gradients(
                    grads_and_vars, global_step=global_step)
                update_ops.append(grad_updates)
                update_op = tf.group(*update_ops)
                with tf.control_dependencies([update_op]):
                    train_tensor = tf.identity(total_loss, name='train_op')

            # Add the summaries from the first clone. These contain the summaries
            # created by model_fn and either optimize_clones() or _gather_clone_loss().
            summaries |= set(
                tf.get_collection(tf.GraphKeys.SUMMARIES, first_clone_scope))

            # Merge all summaries together.
            summary_op = tf.summary.merge(list(summaries))

            # Soft placement allows placing on CPU ops without GPU implementation.
            session_config = tf.ConfigProto(allow_soft_placement=True,
                                            log_device_placement=False)
            session_config.gpu_options.allow_growth = True

            #            init_fn=train_utils.get_model_init_fn(
            #                    FLAGS.train_logdir,
            #                    FLAGS.tf_initial_checkpoint,
            #                    FLAGS.initialize_last_layer,
            #                    last_layers,
            #                    ignore_missing_vars=True)

            exclude_list = ['global_step']
            if not FLAGS.initialize_last_layer:
                exclude_list.extend(last_layers)
            variables_to_restore = slim.get_variables_to_restore(
                exclude=exclude_list)
            init_fn = slim.assign_from_checkpoint_fn(
                model_path=FLAGS.tf_initial_checkpoint,
                var_list=variables_to_restore,
                ignore_missing_vars=True)
            #            saver = tf.train.Saver()
            #            train_writer = tf.summary.FileWriter(FLAGS.train_logdir)
            #            sess=tf.Session(config=session_config)
            #            init_fn(sess)
            #            sess.run(tf.global_variables_initializer())
            #            sess.run(tf.local_variables_initializer())
            #            tf.train.start_queue_runners(sess)
            #
            #            for i in trange(FLAGS.training_number_of_steps):
            #                loss,summary,n_step=sess.run([train_tensor,summary_op,global_step])
            #                train_writer.add_summary(summary,i)
            #                if i%100==1:
            #                    tqdm.write('%d/%d global_step=%0.2f, loss=%0.5f'%(i,FLAGS.training_number_of_steps,n_step,loss))
            #
            #            saver.save(sess,os.path.join(FLAGS.train_logdir,'model'),global_step=FLAGS.training_number_of_steps)
            #            train_writer.close()
            # Start the training.
            slim.learning.train(train_tensor,
                                logdir=FLAGS.train_logdir,
                                log_every_n_steps=FLAGS.log_steps,
                                master=FLAGS.master,
                                number_of_steps=FLAGS.training_number_of_steps,
                                is_chief=(FLAGS.task == 0),
                                session_config=session_config,
                                startup_delay_steps=startup_delay_steps,
                                init_fn=init_fn,
                                summary_op=summary_op,
                                save_summaries_secs=FLAGS.save_summaries_secs,
                                save_interval_secs=FLAGS.save_interval_secs)
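
Unlike the other examples, the `init_fn` here is built by hand from `slim.get_variables_to_restore` and `slim.assign_from_checkpoint_fn` instead of `train_utils.get_model_init_fn`. A self-contained sketch of that idiom, assuming TF 1.x and a hypothetical checkpoint path and scope name:

import tensorflow as tf

slim = tf.contrib.slim

# Restore everything except the step counter and the re-initialized last layer.
variables_to_restore = slim.get_variables_to_restore(
    exclude=['global_step', 'logits'])  # 'logits' is a placeholder scope name
init_fn = slim.assign_from_checkpoint_fn(
    model_path='/path/to/model.ckpt',  # hypothetical checkpoint path
    var_list=variables_to_restore,
    ignore_missing_vars=True)  # tolerate variables absent from the checkpoint

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    init_fn(sess)  # checkpoint values overwrite the fresh initialization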
Exemple #13
0
def _train_deeplab_model(iterator, num_of_classes, ignore_label):
  """Trains the deeplab model.

  Args:
    iterator: An iterator of type tf.data.Iterator for images and labels.
    num_of_classes: Number of classes for the dataset.
    ignore_label: Ignore label for the dataset.

  Returns:
    train_tensor: A tensor to update the model variables.
    summary_op: An operation to log the summaries.
  """
  global_step = tf.train.get_or_create_global_step()
  summaries = []

  learning_rate = train_utils.get_model_learning_rate(
      FLAGS.learning_policy, FLAGS.base_learning_rate,
      FLAGS.learning_rate_decay_step, FLAGS.learning_rate_decay_factor,
      FLAGS.training_number_of_steps, FLAGS.learning_power,
      FLAGS.slow_start_step, FLAGS.slow_start_learning_rate)
  summaries.append(tf.summary.scalar('learning_rate', learning_rate))

  optimizer = tf.train.MomentumOptimizer(learning_rate, FLAGS.momentum)

  tower_grads = []
  tower_summaries = None
  for i in range(FLAGS.num_clones):
    with tf.device('/gpu:%d' % i):
      with tf.name_scope('clone_%d' % i) as scope:
        loss = _tower_loss(
            iterator=iterator,
            num_of_classes=num_of_classes,
            ignore_label=ignore_label,
            scope=scope,
            reuse_variable=(i != 0))
        grads = optimizer.compute_gradients(loss)
        tower_grads.append(grads)

        # Retain the summaries from the first tower.
        if not i:
          tower_summaries = tf.summary.merge_all(scope=scope)

  with tf.device('/cpu:0'):
    grads_and_vars = _average_gradients(tower_grads)
    if tower_summaries is not None:
      summaries.append(tower_summaries)

    # Modify the gradients for biases and last layer variables.
    last_layers = model.get_extra_layer_scopes(
        FLAGS.last_layers_contain_logits_only)
    grad_mult = train_utils.get_model_gradient_multipliers(
        last_layers, FLAGS.last_layer_gradient_multiplier)
    if grad_mult:
      grads_and_vars = tf.contrib.training.multiply_gradients(
          grads_and_vars, grad_mult)

    # Create gradient update op.
    grad_updates = optimizer.apply_gradients(
        grads_and_vars, global_step=global_step)

    # Gather update_ops. These contain, for example,
    # the updates for the batch_norm variables created by model_fn.
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    update_ops.append(grad_updates)
    update_op = tf.group(*update_ops)

    total_loss = tf.losses.get_total_loss(add_regularization_losses=True)

    # Print total loss to the terminal.
    # This implementation is mirrored from tf.slim.summaries.
    should_log = math_ops.equal(math_ops.mod(global_step, FLAGS.log_steps), 0)
    total_loss = tf.cond(
        should_log,
        lambda: tf.Print(total_loss, [total_loss], 'Total loss is :'),
        lambda: total_loss)

    summaries.append(tf.summary.scalar('total_loss', total_loss))

    with tf.control_dependencies([update_op]):
      train_tensor = tf.identity(total_loss, name='train_op')
    summary_op = tf.summary.merge(summaries)

  return train_tensor, summary_op
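
`_average_gradients` is called above but not shown in this excerpt. A sketch of the conventional helper (in the style of the TensorFlow multi-GPU tutorials), assuming each entry of `tower_grads` lists its (gradient, variable) pairs in the same variable order:

def _average_gradients(tower_grads):
  """Averages gradients across towers (sketch, not the excerpt's own code)."""
  average_grads = []
  for grad_and_vars in zip(*tower_grads):
    # grad_and_vars is ((grad_gpu0, var), (grad_gpu1, var), ...) for one variable.
    grads = [tf.expand_dims(g, 0) for g, _ in grad_and_vars if g is not None]
    if not grads:
      continue
    grad = tf.reduce_mean(tf.concat(grads, axis=0), axis=0)
    # Variables are shared across towers, so the first tower's handle suffices.
    average_grads.append((grad, grad_and_vars[0][1]))
  return average_grads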
def main(unused_argv):
  tf.logging.set_verbosity(tf.logging.INFO)
  # Set up deployment (i.e., multi-GPUs and/or multi-replicas).
  config = model_deploy.DeploymentConfig(
      num_clones=FLAGS.num_clones,
      clone_on_cpu=FLAGS.clone_on_cpu,
      replica_id=FLAGS.task,
      num_replicas=FLAGS.num_replicas,
      num_ps_tasks=FLAGS.num_ps_tasks)

  # Split the batch across GPUs.
  assert FLAGS.train_batch_size % config.num_clones == 0, (
      'Training batch size not divisible by number of clones (GPUs).')

  clone_batch_size = FLAGS.train_batch_size // config.num_clones

  # Get dataset-dependent information.
  dataset = segmentation_dataset.get_dataset(
      FLAGS.dataset, FLAGS.train_split, dataset_dir=FLAGS.dataset_dir)

  tf.gfile.MakeDirs(FLAGS.train_logdir)
  tf.logging.info('Training on %s set', FLAGS.train_split)

  with tf.Graph().as_default():
    with tf.device(config.inputs_device()):
      samples = input_generator.get(
          dataset,
          FLAGS.train_crop_size,
          clone_batch_size,
          min_resize_value=FLAGS.min_resize_value,
          max_resize_value=FLAGS.max_resize_value,
          resize_factor=FLAGS.resize_factor,
          min_scale_factor=FLAGS.min_scale_factor,
          max_scale_factor=FLAGS.max_scale_factor,
          scale_factor_step_size=FLAGS.scale_factor_step_size,
          dataset_split=FLAGS.train_split,
          is_training=True,
          model_variant=FLAGS.model_variant)
      inputs_queue = prefetch_queue.prefetch_queue(
          samples, capacity=128 * config.num_clones)

    # Create the global step on the device storing the variables.
    with tf.device(config.variables_device()):
      global_step = tf.train.get_or_create_global_step()

      # Define the model and create clones.
      model_fn = _build_deeplab
      model_args = (inputs_queue, {
          common.OUTPUT_TYPE: dataset.num_classes
      }, dataset.ignore_label)
      clones = model_deploy.create_clones(config, model_fn, args=model_args)

      # Gather update_ops from the first clone. These contain, for example,
      # the updates for the batch_norm variables created by model_fn.
      first_clone_scope = config.clone_scope(0)
      update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, first_clone_scope)

    # Gather initial summaries.
    summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))

    # Add summaries for model variables.
    for model_var in slim.get_model_variables():
      summaries.add(tf.summary.histogram(model_var.op.name, model_var))

    # Add summaries for losses.
    for loss in tf.get_collection(tf.GraphKeys.LOSSES, first_clone_scope):
      summaries.add(tf.summary.scalar('losses/%s' % loss.op.name, loss))

    # Build the optimizer based on the device specification.
    with tf.device(config.optimizer_device()):
      learning_rate = train_utils.get_model_learning_rate(
          FLAGS.learning_policy, FLAGS.base_learning_rate,
          FLAGS.learning_rate_decay_step, FLAGS.learning_rate_decay_factor,
          FLAGS.training_number_of_steps, FLAGS.learning_power,
          FLAGS.slow_start_step, FLAGS.slow_start_learning_rate)
      optimizer = tf.train.MomentumOptimizer(learning_rate, FLAGS.momentum)
      summaries.add(tf.summary.scalar('learning_rate', learning_rate))

    startup_delay_steps = FLAGS.task * FLAGS.startup_delay_steps
    for variable in slim.get_model_variables():
      summaries.add(tf.summary.histogram(variable.op.name, variable))

    with tf.device(config.variables_device()):
      total_loss, grads_and_vars = model_deploy.optimize_clones(
          clones, optimizer)
      total_loss = tf.check_numerics(total_loss, 'Loss is inf or nan.')
      summaries.add(tf.summary.scalar('total_loss', total_loss))

      # Modify the gradients for biases and last layer variables.
      last_layers = model.get_extra_layer_scopes()
      grad_mult = train_utils.get_model_gradient_multipliers(
          last_layers, FLAGS.last_layer_gradient_multiplier)
      if grad_mult:
        grads_and_vars = slim.learning.multiply_gradients(
            grads_and_vars, grad_mult)

      # Create gradient update op.
      grad_updates = optimizer.apply_gradients(
          grads_and_vars, global_step=global_step)
      update_ops.append(grad_updates)
      update_op = tf.group(*update_ops)
      with tf.control_dependencies([update_op]):
        train_tensor = tf.identity(total_loss, name='train_op')

    # Add the summaries from the first clone. These contain the summaries
    # created by model_fn and either optimize_clones() or _gather_clone_loss().
    summaries |= set(
        tf.get_collection(tf.GraphKeys.SUMMARIES, first_clone_scope))

    # Merge all summaries together.
    summary_op = tf.summary.merge(list(summaries))

    # Soft placement allows placing on CPU ops without GPU implementation.
    session_config = tf.ConfigProto(
        allow_soft_placement=True, log_device_placement=False)
    session_config.gpu_options.allow_growth = FLAGS.gpu_allow_growth

    # Save checkpoints regularly.
    saver = tf.train.Saver(max_to_keep=FLAGS.max_to_keep)

    # Start the training.
    slim.learning.train(
        train_tensor,
        logdir=FLAGS.train_logdir,
        log_every_n_steps=FLAGS.log_steps,
        master=FLAGS.master,
        number_of_steps=FLAGS.training_number_of_steps,
        is_chief=(FLAGS.task == 0),
        session_config=session_config,
        startup_delay_steps=startup_delay_steps,
        init_fn=train_utils.get_model_init_fn(
            FLAGS.train_logdir,
            FLAGS.tf_initial_checkpoint,
            FLAGS.initialize_last_layer,
            last_layers,
            ignore_missing_vars=True),
        summary_op=summary_op,
        saver=saver,
        save_summaries_secs=FLAGS.save_summaries_secs,
        save_interval_secs=FLAGS.save_interval_secs)
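
All of these examples boost the effective learning rate of the last layers through gradient multipliers rather than a second optimizer. A minimal sketch of the mechanism (with a hypothetical 'logits' scope; `slim.learning.multiply_gradients` is the same call the snippets use):

# Map variables (or variable op names) to scaling coefficients, e.g. 10x
# for the freshly initialized logits layer.
grad_mult = {v: 10.0 for v in tf.trainable_variables()
             if 'logits' in v.op.name}  # 'logits' is a placeholder scope
if grad_mult:
    grads_and_vars = slim.learning.multiply_gradients(grads_and_vars, grad_mult)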
Exemple #15
0
def main(unused_argv):
    # syaru: sets the threshold (gate) for which messages will be logged; this
    # line is required for the training-process log to be printed.
    tf.logging.set_verbosity(tf.logging.INFO)
    # Set up deployment (i.e., multi-GPUs and/or multi-replicas).
    # syaru: models/research/slim/deployment/model_deploy.DeploymentConfig(object)
    config = model_deploy.DeploymentConfig(num_clones=FLAGS.num_clones,
                                           clone_on_cpu=FLAGS.clone_on_cpu,
                                           replica_id=FLAGS.task,
                                           num_replicas=FLAGS.num_replicas,
                                           num_ps_tasks=FLAGS.num_ps_tasks)

    # Split the batch across GPUs.
    assert FLAGS.train_batch_size % config.num_clones == 0, (
        'Training batch size not divisible by number of clones (GPUs).')

    clone_batch_size = int(FLAGS.train_batch_size / config.num_clones)

    # Get dataset-dependent information.
    """
  syaru: deeplab/datasets/segmentation_dataset.get_dataset()
  Gets an instance of slim Dataset.
  Args:
    dataset_name: Dataset name.
    split_name: A train/val Split name.
    dataset_dir: The directory of the dataset sources.
  """
    dataset = segmentation_dataset.get_dataset(FLAGS.dataset,
                                               FLAGS.train_split,
                                               dataset_dir=FLAGS.dataset_dir)

    # syaru: FLAGS.train_logdir = "pascal_voc_seg/exp/train_on_trainval_set/train"
    tf.gfile.MakeDirs(FLAGS.train_logdir)
    # syaru: FLAGS.train_split = "trainval"
    tf.logging.info('Training on %s set', FLAGS.train_split)

    with tf.Graph().as_default() as graph:
        # syaru: deeplab/utils/input_generator.get() gets the dataset split for
        # semantic segmentation and returns a dictionary of batched Tensors.
        with tf.device(config.inputs_device()):
            samples = input_generator.get(
                dataset,  # an instance of slim Dataset
                # train_crop_size: if crop_size is set, images larger than it
                # are randomly cropped during training.
                FLAGS.train_crop_size,
                clone_batch_size,
                min_resize_value=FLAGS.min_resize_value,
                max_resize_value=FLAGS.max_resize_value,
                resize_factor=FLAGS.resize_factor,
                # syaru: minimum/maximum scale factor and step size for data
                # augmentation (scales sweep from minimum to maximum).
                min_scale_factor=FLAGS.min_scale_factor,
                max_scale_factor=FLAGS.max_scale_factor,
                scale_factor_step_size=FLAGS.scale_factor_step_size,
                dataset_split=FLAGS.train_split,
                is_training=True,
                model_variant=FLAGS.model_variant)
            # syaru: tensorflow/contrib/slim/python/slim/data/prefetch_queue.py
            # tensors: a list or dictionary of `Tensors` to enqueue in the buffer;
            # capacity: an integer, the maximum number of elements in the queue.
            inputs_queue = prefetch_queue.prefetch_queue(
                samples, capacity=128 * config.num_clones)

        # Create the global step on the device storing the variables.
        with tf.device(config.variables_device()):
            global_step = tf.train.get_or_create_global_step()

            # Define the model and create clones.
            """
      syaru: 
      models/research/slim/deployment/model_deploy.create_clones():
      The `model_fn(*args, **kwargs)` function is called `config.num_clones` times to create the model clones.
      (and one or several clones are deployed on different GPUs and one or several replicas of such clones.)
      Then it return the scope and device in a namedtuple `Clone(outputs, scope, device)`.

      Args:
      config: A DeploymentConfig object.
      model_fn: A callable. Called as `model_fn(*args, **kwargs)`
      args: Optional list of arguments to pass to `model_fn`.
      kwargs: Optional list of keyword arguments to pass to `model_fn`.
      Returns:
      A list of namedtuples `Clone`.

      Note: it is assumed that any loss created by `model_fn` is collected at
      the tf.GraphKeys.LOSSES collection.

      To recover the losses, summaries or update_ops created by the clone use:
      ```python
      losses = tf.get_collection(tf.GraphKeys.LOSSES, clone.scope)
      summaries = tf.get_collection(tf.GraphKeys.SUMMARIES, clone.scope)
      update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, clone.scope)
      ```
      """
            model_fn = _build_deeplab
            model_args = (inputs_queue, {
                common.OUTPUT_TYPE: dataset.num_classes
            }, dataset.ignore_label)
            clones = model_deploy.create_clones(config,
                                                model_fn,
                                                args=model_args)

            # Gather update_ops from the first clone. These contain, for example,
            # the updates for the batch_norm variables created by model_fn.
            first_clone_scope = config.clone_scope(0)
            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS,
                                           first_clone_scope)

        # Gather initial summaries.
        summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))

        # Add summaries for model variables.
        for model_var in slim.get_model_variables():
            summaries.add(tf.summary.histogram(model_var.op.name, model_var))

        # Add summaries for images, labels, semantic predictions
        if FLAGS.save_summaries_images:
            # syaru: get_tensor_by_name(name) returns the tensor with the given
            # name; str.strip() removes the specified characters from both ends
            # of the string (whitespace by default).
            summary_image = graph.get_tensor_by_name(
                ('%s/%s:0' % (first_clone_scope, common.IMAGE)).strip('/'))
            summaries.add(
                tf.summary.image('samples/%s' % common.IMAGE, summary_image))

            summary_label = tf.cast(
                graph.get_tensor_by_name(
                    ('%s/%s:0' %
                     (first_clone_scope, common.LABEL)).strip('/')), tf.uint8)
            summaries.add(
                tf.summary.image('samples/%s' % common.LABEL, summary_label))

            predictions = tf.cast(
                tf.expand_dims(
                    tf.argmax(
                        graph.get_tensor_by_name(  # syaru: tf.argmax(axis=3)
                            ('%s/%s:0' % (first_clone_scope,
                                          common.OUTPUT_TYPE)).strip('/')),
                        3),
                    -1),
                tf.uint8)
            summaries.add(
                tf.summary.image('samples/%s' % common.OUTPUT_TYPE,
                                 predictions))

        # Add summaries for losses.
        for loss in tf.get_collection(tf.GraphKeys.LOSSES, first_clone_scope):
            summaries.add(tf.summary.scalar('losses/%s' % loss.op.name, loss))

        # Build the optimizer based on the device specification.
        with tf.device(config.optimizer_device()):
            # syaru: train_utils.get_model_learning_rate():
            #        Computes the model's learning rate for different learning policy("step" and "poly").
            learning_rate = train_utils.get_model_learning_rate(
                FLAGS.learning_policy, FLAGS.base_learning_rate,
                FLAGS.learning_rate_decay_step,
                FLAGS.learning_rate_decay_factor,
                FLAGS.training_number_of_steps, FLAGS.learning_power,
                FLAGS.slow_start_step, FLAGS.slow_start_learning_rate)
            optimizer = tf.train.MomentumOptimizer(learning_rate,
                                                   FLAGS.momentum)
            summaries.add(tf.summary.scalar('learning_rate', learning_rate))

        startup_delay_steps = FLAGS.task * FLAGS.startup_delay_steps
        for variable in slim.get_model_variables():
            summaries.add(tf.summary.histogram(variable.op.name, variable))

        with tf.device(config.variables_device()):
            # syaru: Compute clone losses and gradients for the given list of `Clones`.
            total_loss, grads_and_vars = model_deploy.optimize_clones(
                clones, optimizer)
            total_loss = tf.check_numerics(total_loss, 'Loss is inf or nan.')
            summaries.add(tf.summary.scalar('total_loss', total_loss))

            # Modify the gradients for biases and last layer variables.
            """
      syaru: 
      For the task of semantic segmentation, the models are
      usually fine-tuned from the models trained on the task of image
      classification. To fine-tune the models, we usually set a larger (e.g.,
      10 times larger) learning rate for the parameters of the last layer.
      
      deeplab/model/model.get_extra_layer_scopes():
      Returns: A list of scopes for extra layers. 

      deeplab/utils/train_utils.get_model_gradient_multipliers():
      Returns: The gradient multiplier map with variables as key, and multipliers as value.
      """
            last_layers = model.get_extra_layer_scopes(
                FLAGS.last_layers_contain_logits_only)
            grad_mult = train_utils.get_model_gradient_multipliers(
                last_layers, FLAGS.last_layer_gradient_multiplier)
            if grad_mult:
                grads_and_vars = slim.learning.multiply_gradients(
                    grads_and_vars, grad_mult)

            # Create gradient update op.
            grad_updates = optimizer.apply_gradients(grads_and_vars,
                                                     global_step=global_step)
            update_ops.append(grad_updates)
            # syaru: both tf.identity() and tf.group() turn statements into ops.
            #        We want `total_loss` (as 'train_op') computed only after
            #        `optimizer.apply_gradients`, and tf.control_dependencies()
            #        works on tf ops, whereas `update_ops = tf.get_collection(..)`
            #        only returns a list of variables.
            update_op = tf.group(*update_ops)
            with tf.control_dependencies([update_op]):
                train_tensor = tf.identity(total_loss, name='train_op')

        # Add the summaries from the first clone. These contain the summaries
        # created by model_fn and either optimize_clones() or _gather_clone_loss().
        summaries |= set(
            tf.get_collection(tf.GraphKeys.SUMMARIES, first_clone_scope))

        # Merge all summaries together.
        summary_op = tf.summary.merge(list(summaries))

        # Soft placement allows placing on CPU ops without GPU implementation.
        # syaru: set gpu_options
        gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.9)
        session_config = tf.ConfigProto(allow_soft_placement=True,
                                        log_device_placement=False,
                                        gpu_options=gpu_options)

        # Start the training.
        # syaru: /tensorflow/contrib/slim/python/slim/learning.py
        # train_utils.get_model_init_fn(): Gets the function initializing model variables from a checkpoint.
        slim.learning.train(
            train_tensor,
            logdir=FLAGS.train_logdir,
            log_every_n_steps=FLAGS.log_steps,
            master=FLAGS.master,
            number_of_steps=FLAGS.training_number_of_steps,
            is_chief=(FLAGS.task == 0),
            session_config=session_config,
            startup_delay_steps=startup_delay_steps,
            # syaru: `init_fn` is an optional callable executed after `init_op`;
            # it must accept one argument, the session being initialized.
            init_fn=train_utils.get_model_init_fn(
                FLAGS.train_logdir,
                FLAGS.tf_initial_checkpoint,
                FLAGS.initialize_last_layer,
                last_layers,
                ignore_missing_vars=True),
            summary_op=summary_op,
            save_summaries_secs=FLAGS.save_summaries_secs,
            save_interval_secs=FLAGS.save_interval_secs)
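
For reference, every example above obtains its schedule from `train_utils.get_model_learning_rate`. With the 'poly' policy that DeepLab commonly uses, the base rate decays polynomially toward zero over `training_number_of_steps`; a back-of-the-envelope sketch (ignoring the slow-start warm-up, whose authoritative handling lives in deeplab/utils/train_utils.py):

def poly_learning_rate(step, base_lr, training_number_of_steps, power):
    # lr = base_lr * (1 - step / max_steps) ** power
    return base_lr * (1.0 - float(step) / training_number_of_steps) ** power

# Example: with base_lr=1e-4 and learning_power=0.9, halfway through training
# poly_learning_rate(5000, 1e-4, 10000, 0.9) is roughly 5.4e-5.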