Example #1
def scaffold_fn():
    """Create Scaffold for initialization, etc."""
    if params['init_backbone_only']:
        print('initialization - init backbone only')
        tf.train.init_from_checkpoint(params['init_checkpoint'],
                                      {var_scope: var_scope})
        return tf.train.Scaffold(saver=tf.train.Saver(max_to_keep=8))
    else:
        print('initialization - init full model')
        init_fn = train_utils.get_model_init_fn(
            params['model_dir'],
            params['init_checkpoint'],
            True, [],
            ignore_missing_vars=False)
        return tf.train.Scaffold(init_fn=init_fn,
                                 saver=tf.train.Saver(max_to_keep=8))
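
A scaffold_fn written against a params dict like this is normally returned from an Estimator-style model_fn rather than called directly; the framework invokes it lazily when it builds the session. One common wiring, assuming the TF 1.x TPUEstimator API (model_fn, build_model, mode, loss and train_op below are illustrative assumptions, not taken from the code above):

import tensorflow as tf


def model_fn(features, labels, mode, params):
    """Skeleton model_fn showing where a scaffold_fn like the one above plugs in."""
    loss, train_op = build_model(features, labels, params)  # hypothetical helper

    def scaffold_fn():
        # ... checkpoint-restore logic as in the example above ...
        return tf.train.Scaffold()

    return tf.contrib.tpu.TPUEstimatorSpec(
        mode=mode, loss=loss, train_op=train_op, scaffold_fn=scaffold_fn)
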
def main(unused_argv):
    tf.logging.set_verbosity(tf.logging.INFO)

    tf.gfile.MakeDirs(FLAGS.train_logdir)
    tf.logging.info('Training on %s set', FLAGS.train_split)

    graph = tf.Graph()
    with graph.as_default():
        with tf.device(
                tf.train.replica_device_setter(ps_tasks=FLAGS.num_ps_tasks)):
            assert FLAGS.train_batch_size % FLAGS.num_clones == 0, (
                'Training batch size not divisible by number of clones (GPUs).')
            clone_batch_size = FLAGS.train_batch_size // FLAGS.num_clones

            dataset = data_generator.Dataset(
                dataset_name=FLAGS.dataset,
                split_name=FLAGS.train_split,
                dataset_dir=FLAGS.dataset_dir,
                batch_size=clone_batch_size,
                crop_size=FLAGS.train_crop_size,
                min_resize_value=FLAGS.min_resize_value,
                max_resize_value=FLAGS.max_resize_value,
                resize_factor=FLAGS.resize_factor,
                min_scale_factor=FLAGS.min_scale_factor,
                max_scale_factor=FLAGS.max_scale_factor,
                scale_factor_step_size=FLAGS.scale_factor_step_size,
                model_variant=FLAGS.model_variant,
                num_readers=2,
                is_training=True,
                should_shuffle=True,
                should_repeat=True)

            train_tensor, summary_op = _train_deeplab_model(
                dataset.get_one_shot_iterator(), dataset.num_of_classes,
                dataset.ignore_label)

            # Soft placement allows ops without a GPU implementation to be placed on the CPU.
            session_config = tf.ConfigProto(allow_soft_placement=True,
                                            log_device_placement=False)

            last_layers = model.get_extra_layer_scopes(
                FLAGS.last_layers_contain_logits_only)
            init_fn = None
            if FLAGS.tf_initial_checkpoint:
                init_fn = train_utils.get_model_init_fn(
                    FLAGS.train_logdir,
                    FLAGS.tf_initial_checkpoint,
                    FLAGS.initialize_last_layer,
                    last_layers,
                    ignore_missing_vars=True)

            scaffold = tf.train.Scaffold(
                init_fn=init_fn,
                summary_op=summary_op,
            )

            stop_hook = tf.train.StopAtStepHook(
                last_step=FLAGS.training_number_of_steps)

            profile_dir = FLAGS.profile_logdir
            if profile_dir is not None:
                tf.gfile.MakeDirs(profile_dir)

            with tf.contrib.tfprof.ProfileContext(
                    enabled=profile_dir is not None,
                    profile_dir=profile_dir):
                with tf.train.MonitoredTrainingSession(
                        master=FLAGS.master,
                        is_chief=(FLAGS.task == 0),
                        config=session_config,
                        scaffold=scaffold,
                        checkpoint_dir=FLAGS.train_logdir,
                        log_step_count_steps=FLAGS.log_steps,
                        save_summaries_steps=FLAGS.save_summaries_secs,
                        save_checkpoint_secs=FLAGS.save_interval_secs,
                        hooks=[stop_hook]) as sess:
                    while not sess.should_stop():
                        sess.run([train_tensor])
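
Example #1 follows the newer DeepLab training path: build train_tensor and summary_op, wrap any warm-start logic in a tf.train.Scaffold, and drive the loop with tf.train.MonitoredTrainingSession plus a StopAtStepHook. The same skeleton, reduced to a self-contained toy (a minimal sketch with a one-variable model; none of the names below come from the DeepLab code):

import tensorflow as tf  # TF 1.x API, as in the examples on this page


def toy_monitored_training(logdir, total_steps=100):
    """Minimal MonitoredTrainingSession loop mirroring the pattern above."""
    with tf.Graph().as_default():
        global_step = tf.train.get_or_create_global_step()
        # Stand-in "model": drive a single variable toward zero.
        w = tf.get_variable('w', shape=[], initializer=tf.ones_initializer())
        loss = tf.square(w)
        train_op = tf.train.GradientDescentOptimizer(0.1).minimize(
            loss, global_step=global_step)

        scaffold = tf.train.Scaffold()  # init_fn / summary_op would go here
        stop_hook = tf.train.StopAtStepHook(last_step=total_steps)

        with tf.train.MonitoredTrainingSession(checkpoint_dir=logdir,
                                               scaffold=scaffold,
                                               hooks=[stop_hook]) as sess:
            while not sess.should_stop():
                sess.run(train_op)
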
Example #3
def main(unused_argv):
    tf.logging.set_verbosity(tf.logging.INFO)
    # Set up deployment (i.e., multi-GPUs and/or multi-replicas).
    config = model_deploy.DeploymentConfig(num_clones=FLAGS.num_clones,
                                           clone_on_cpu=FLAGS.clone_on_cpu,
                                           replica_id=FLAGS.task,
                                           num_replicas=FLAGS.num_replicas,
                                           num_ps_tasks=FLAGS.num_ps_tasks)

    # Split the batch across GPUs.
    assert FLAGS.train_batch_size % config.num_clones == 0, (
        'Training batch size not divisible by number of clones (GPUs).')

    clone_batch_size = FLAGS.train_batch_size // config.num_clones

    tf.gfile.MakeDirs(FLAGS.train_logdir)
    common.outputlogMessage('Training on %s set' % FLAGS.train_split)
    common.outputlogMessage('Dataset: %s' % FLAGS.dataset)
    common.outputlogMessage('train_crop_size: %s' % str(FLAGS.train_crop_size))
    common.outputlogMessage(str(FLAGS.train_crop_size))
    common.outputlogMessage('atrous_rates: %s' % str(FLAGS.atrous_rates))
    common.outputlogMessage('number of classes: %s' % str(FLAGS.num_classes))
    common.outputlogMessage('Ignore label value: %s' % str(FLAGS.ignore_label))
    pid = os.getpid()
    with open('train_py_pid.txt', 'w') as f_obj:
        f_obj.write('%d' % pid)

    with tf.Graph().as_default() as graph:
        with tf.device(config.inputs_device()):
            dataset = data_generator.Dataset(
                dataset_name=FLAGS.dataset,
                split_name=FLAGS.train_split,
                dataset_dir=FLAGS.dataset_dir,
                batch_size=clone_batch_size,
                crop_size=[int(sz) for sz in FLAGS.train_crop_size],
                min_resize_value=FLAGS.min_resize_value,
                max_resize_value=FLAGS.max_resize_value,
                resize_factor=FLAGS.resize_factor,
                min_scale_factor=FLAGS.min_scale_factor,
                max_scale_factor=FLAGS.max_scale_factor,
                scale_factor_step_size=FLAGS.scale_factor_step_size,
                model_variant=FLAGS.model_variant,
                num_readers=4,
                is_training=True,
                should_shuffle=True,
                should_repeat=True,
                num_classes=FLAGS.num_classes,
                ignore_label=FLAGS.ignore_label)

        # Create the global step on the device storing the variables.
        with tf.device(config.variables_device()):
            global_step = tf.train.get_or_create_global_step()

            # Define the model and create clones.
            model_fn = _build_deeplab
            model_args = (dataset.get_one_shot_iterator(), {
                common.OUTPUT_TYPE: dataset.num_of_classes
            }, dataset.ignore_label)
            clones = model_deploy.create_clones(config,
                                                model_fn,
                                                args=model_args)

            # Gather update_ops from the first clone. These contain, for example,
            # the updates for the batch_norm variables created by model_fn.
            first_clone_scope = config.clone_scope(0)
            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS,
                                           first_clone_scope)

        # Gather initial summaries.
        summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))

        # Add summaries for model variables.
        for model_var in tf.model_variables():
            summaries.add(tf.summary.histogram(model_var.op.name, model_var))

        # Add summaries for images, labels, semantic predictions
        if FLAGS.save_summaries_images:
            summary_image = graph.get_tensor_by_name(
                ('%s/%s:0' % (first_clone_scope, common.IMAGE)).strip('/'))
            summaries.add(
                tf.summary.image('samples/%s' % common.IMAGE, summary_image))

            first_clone_label = graph.get_tensor_by_name(
                ('%s/%s:0' % (first_clone_scope, common.LABEL)).strip('/'))
            # Scale up summary image pixel values for better visualization.
            pixel_scaling = max(1, 255 // dataset.num_of_classes)
            summary_label = tf.cast(first_clone_label * pixel_scaling,
                                    tf.uint8)
            summaries.add(
                tf.summary.image('samples/%s' % common.LABEL, summary_label))

            first_clone_output = graph.get_tensor_by_name(
                ('%s/%s:0' %
                 (first_clone_scope, common.OUTPUT_TYPE)).strip('/'))
            predictions = tf.expand_dims(tf.argmax(first_clone_output, 3), -1)

            summary_predictions = tf.cast(predictions * pixel_scaling,
                                          tf.uint8)
            summaries.add(
                tf.summary.image('samples/%s' % common.OUTPUT_TYPE,
                                 summary_predictions))

        # Add summaries for losses.
        for loss in tf.get_collection(tf.GraphKeys.LOSSES, first_clone_scope):
            summaries.add(tf.summary.scalar('losses/%s' % loss.op.name, loss))

        # Build the optimizer based on the device specification.
        with tf.device(config.optimizer_device()):
            learning_rate = train_utils.get_model_learning_rate(
                FLAGS.learning_policy,
                FLAGS.base_learning_rate,
                FLAGS.learning_rate_decay_step,
                FLAGS.learning_rate_decay_factor,
                FLAGS.training_number_of_steps,
                FLAGS.learning_power,
                FLAGS.slow_start_step,
                FLAGS.slow_start_learning_rate,
                decay_steps=FLAGS.decay_steps,
                end_learning_rate=FLAGS.end_learning_rate)

            summaries.add(tf.summary.scalar('learning_rate', learning_rate))

            if FLAGS.optimizer == 'momentum':
                optimizer = tf.train.MomentumOptimizer(learning_rate,
                                                       FLAGS.momentum)
            elif FLAGS.optimizer == 'adam':
                optimizer = tf.train.AdamOptimizer(
                    learning_rate=FLAGS.adam_learning_rate,
                    epsilon=FLAGS.adam_epsilon)
            else:
                raise ValueError('Unknown optimizer')

        if FLAGS.quantize_delay_step >= 0:
            if FLAGS.num_clones > 1:
                raise ValueError(
                    'Quantization doesn\'t support multi-clone yet.')
            contrib_quantize.create_training_graph(
                quant_delay=FLAGS.quantize_delay_step)

        startup_delay_steps = FLAGS.task * FLAGS.startup_delay_steps

        with tf.device(config.variables_device()):
            total_loss, grads_and_vars = model_deploy.optimize_clones(
                clones, optimizer)
            total_loss = tf.check_numerics(total_loss, 'Loss is inf or nan.')
            summaries.add(tf.summary.scalar('total_loss', total_loss))

            # Modify the gradients for biases and last layer variables.
            last_layers = model.get_extra_layer_scopes(
                FLAGS.last_layers_contain_logits_only)
            grad_mult = train_utils.get_model_gradient_multipliers(
                last_layers, FLAGS.last_layer_gradient_multiplier)
            if grad_mult:
                grads_and_vars = slim.learning.multiply_gradients(
                    grads_and_vars, grad_mult)

            # Create gradient update op.
            grad_updates = optimizer.apply_gradients(grads_and_vars,
                                                     global_step=global_step)
            update_ops.append(grad_updates)
            update_op = tf.group(*update_ops)
            with tf.control_dependencies([update_op]):
                train_tensor = tf.identity(total_loss, name='train_op')

        # Add the summaries from the first clone. These contain the summaries
        # created by model_fn and either optimize_clones() or _gather_clone_loss().
        summaries |= set(
            tf.get_collection(tf.GraphKeys.SUMMARIES, first_clone_scope))

        # Merge all summaries together.
        summary_op = tf.summary.merge(list(summaries))

        # Soft placement allows ops without a GPU implementation to be placed on the CPU.
        session_config = tf.ConfigProto(allow_soft_placement=True,
                                        log_device_placement=False)

        # Start the training.
        profile_dir = FLAGS.profile_logdir
        if profile_dir is not None:
            tf.gfile.MakeDirs(profile_dir)

        with contrib_tfprof.ProfileContext(enabled=profile_dir is not None,
                                           profile_dir=profile_dir):
            init_fn = None
            if FLAGS.tf_initial_checkpoint:
                init_fn = train_utils.get_model_init_fn(
                    FLAGS.train_logdir,
                    FLAGS.tf_initial_checkpoint,
                    FLAGS.initialize_last_layer,
                    last_layers,
                    ignore_missing_vars=True)

            slim.learning.train(train_tensor,
                                logdir=FLAGS.train_logdir,
                                log_every_n_steps=FLAGS.log_steps,
                                master=FLAGS.master,
                                number_of_steps=FLAGS.training_number_of_steps,
                                is_chief=(FLAGS.task == 0),
                                session_config=session_config,
                                startup_delay_steps=startup_delay_steps,
                                init_fn=init_fn,
                                summary_op=summary_op,
                                save_summaries_secs=FLAGS.save_summaries_secs,
                                save_interval_secs=FLAGS.save_interval_secs)
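
Both code paths on this page hand train_utils.get_model_init_fn(...) the train_logdir, an initial checkpoint, and the list of last-layer scopes. The idea behind such a helper is warm starting: restore the backbone from tf_initial_checkpoint unless the log directory already holds its own checkpoint, optionally excluding the logits layers. A rough tf-slim sketch of that idea (an approximation for orientation, not the actual DeepLab implementation):

import tensorflow as tf
from tensorflow.contrib import slim  # TF 1.x contrib slim


def sketch_warm_start_fn(train_logdir, initial_checkpoint,
                         initialize_last_layer, last_layer_scopes):
    """Rough sketch of a warm-start init_fn; not the real DeepLab helper."""
    if initial_checkpoint is None or tf.train.latest_checkpoint(train_logdir):
        # Nothing to warm-start from, or training already has checkpoints.
        return None
    exclude = [] if initialize_last_layer else list(last_layer_scopes)
    variables_to_restore = slim.get_variables_to_restore(exclude=exclude)
    return slim.assign_from_checkpoint_fn(
        initial_checkpoint, variables_to_restore, ignore_missing_vars=True)
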
Example #4
def main(unused_argv):

  datasetDescriptor = None
  if FLAGS.config and os.path.isfile(FLAGS.config):
    with open(FLAGS.config) as f:
      trainingConfig = json.load(f)
      for key in trainingConfig:
        if key in FLAGS:
          FLAGS[key].value = trainingConfig[key]
        elif key == 'DatasetDescriptor':
          datasetDescriptor = segmentation_dataset.DatasetDescriptor(
                                name=trainingConfig[key]['name'],
                                splits_to_sizes=trainingConfig[key]['splits_to_sizes'],
                                num_classes=trainingConfig[key]['num_classes'],
                                ignore_label=trainingConfig[key]['ignore_label'],
                              )

  assert FLAGS.dataset_dir, (
      'flag --dataset_dir=None: Flag --dataset_dir must be specified.')

  assert FLAGS.train_logdir, (
      'flag --train_logdir=None: Flag --train_logdir must be specified.')

  tf.logging.set_verbosity(tf.logging.INFO)
  # Set up deployment (i.e., multi-GPUs and/or multi-replicas).
  config = model_deploy.DeploymentConfig(
      num_clones=FLAGS.num_clones,
      clone_on_cpu=FLAGS.clone_on_cpu,
      replica_id=FLAGS.task,
      num_replicas=FLAGS.num_replicas,
      num_ps_tasks=FLAGS.num_ps_tasks)

  # Split the batch across GPUs.
  assert FLAGS.train_batch_size % config.num_clones == 0, (
      'Training batch size not divisible by number of clones (GPUs).')

  clone_batch_size = FLAGS.train_batch_size // config.num_clones

  if datasetDescriptor is None:
    datasetDescriptor = FLAGS.dataset

  # Get dataset-dependent information.
  dataset = segmentation_dataset.get_dataset(
      datasetDescriptor, FLAGS.train_split, dataset_dir=FLAGS.dataset_dir)

  tf.gfile.MakeDirs(FLAGS.train_logdir)
  tf.logging.info('Training on %s set', FLAGS.train_split)

  with tf.Graph().as_default() as graph:
    with tf.device(config.inputs_device()):
      samples = input_generator.get(
          dataset,
          FLAGS.train_crop_size,
          clone_batch_size,
          min_resize_value=FLAGS.min_resize_value,
          max_resize_value=FLAGS.max_resize_value,
          resize_factor=FLAGS.resize_factor,
          min_scale_factor=FLAGS.min_scale_factor,
          max_scale_factor=FLAGS.max_scale_factor,
          scale_factor_step_size=FLAGS.scale_factor_step_size,
          dataset_split=FLAGS.train_split,
          is_training=True,
          model_variant=FLAGS.model_variant)
      inputs_queue = prefetch_queue.prefetch_queue(
          samples, capacity=128 * config.num_clones)

    # Create the global step on the device storing the variables.
    with tf.device(config.variables_device()):
      global_step = tf.train.get_or_create_global_step()

      # Define the model and create clones.
      model_fn = _build_deeplab
      model_args = (inputs_queue, {
          common.OUTPUT_TYPE: dataset.num_classes
      }, dataset.ignore_label)
      clones = model_deploy.create_clones(config, model_fn, args=model_args)

      # Gather update_ops from the first clone. These contain, for example,
      # the updates for the batch_norm variables created by model_fn.
      first_clone_scope = config.clone_scope(0)
      update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, first_clone_scope)

    # Gather initial summaries.
    summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))

    # Add summaries for model variables.
    for model_var in slim.get_model_variables():
      summaries.add(tf.summary.histogram(model_var.op.name, model_var))

    # Add summaries for images, labels, semantic predictions
    if FLAGS.save_summaries_images:
      summary_image = graph.get_tensor_by_name(
          ('%s/%s:0' % (first_clone_scope, common.IMAGE)).strip('/'))
      summaries.add(
          tf.summary.image('samples/%s' % common.IMAGE, summary_image))

      first_clone_label = graph.get_tensor_by_name(
          ('%s/%s:0' % (first_clone_scope, common.LABEL)).strip('/'))
      # Scale up summary image pixel values for better visualization.
      pixel_scaling = max(1, 255 // dataset.num_classes)
      summary_label = tf.cast(first_clone_label * pixel_scaling, tf.uint8)
      summaries.add(
          tf.summary.image('samples/%s' % common.LABEL, summary_label))

      first_clone_output = graph.get_tensor_by_name(
          ('%s/%s:0' % (first_clone_scope, common.OUTPUT_TYPE)).strip('/'))
      predictions = tf.expand_dims(tf.argmax(first_clone_output, 3), -1)

      summary_predictions = tf.cast(predictions * pixel_scaling, tf.uint8)
      summaries.add(
          tf.summary.image(
              'samples/%s' % common.OUTPUT_TYPE, summary_predictions))

    # Add summaries for losses.
    for loss in tf.get_collection(tf.GraphKeys.LOSSES, first_clone_scope):
      summaries.add(tf.summary.scalar('losses/%s' % loss.op.name, loss))

    # Build the optimizer based on the device specification.
    with tf.device(config.optimizer_device()):
      learning_rate = train_utils.get_model_learning_rate(
          FLAGS.learning_policy, FLAGS.base_learning_rate,
          FLAGS.learning_rate_decay_step, FLAGS.learning_rate_decay_factor,
          FLAGS.training_number_of_steps, FLAGS.learning_power,
          FLAGS.slow_start_step, FLAGS.slow_start_learning_rate)
      optimizer = tf.train.MomentumOptimizer(learning_rate, FLAGS.momentum)
      summaries.add(tf.summary.scalar('learning_rate', learning_rate))

    startup_delay_steps = FLAGS.task * FLAGS.startup_delay_steps
    for variable in slim.get_model_variables():
      summaries.add(tf.summary.histogram(variable.op.name, variable))

    with tf.device(config.variables_device()):
      total_loss, grads_and_vars = model_deploy.optimize_clones(
          clones, optimizer)
      total_loss = tf.check_numerics(total_loss, 'Loss is inf or nan.')
      summaries.add(tf.summary.scalar('total_loss', total_loss))

      # Modify the gradients for biases and last layer variables.
      last_layers = model.get_extra_layer_scopes(
          FLAGS.last_layers_contain_logits_only)
      grad_mult = train_utils.get_model_gradient_multipliers(
          last_layers, FLAGS.last_layer_gradient_multiplier)
      if grad_mult:
        grads_and_vars = slim.learning.multiply_gradients(
            grads_and_vars, grad_mult)

      # Create gradient update op.
      grad_updates = optimizer.apply_gradients(
          grads_and_vars, global_step=global_step)
      update_ops.append(grad_updates)
      update_op = tf.group(*update_ops)
      with tf.control_dependencies([update_op]):
        train_tensor = tf.identity(total_loss, name='train_op')

    # Add the summaries from the first clone. These contain the summaries
    # created by model_fn and either optimize_clones() or _gather_clone_loss().
    summaries |= set(
        tf.get_collection(tf.GraphKeys.SUMMARIES, first_clone_scope))

    # Merge all summaries together.
    summary_op = tf.summary.merge(list(summaries))

    # Soft placement allows ops without a GPU implementation to be placed on the CPU.
    session_config = tf.ConfigProto(
        allow_soft_placement=True, log_device_placement=False)

    # Start the training.
    slim.learning.train(
        train_tensor,
        logdir=FLAGS.train_logdir,
        log_every_n_steps=FLAGS.log_steps,
        master=FLAGS.master,
        number_of_steps=FLAGS.training_number_of_steps,
        is_chief=(FLAGS.task == 0),
        session_config=session_config,
        startup_delay_steps=startup_delay_steps,
        init_fn=train_utils.get_model_init_fn(
            FLAGS.train_logdir,
            FLAGS.tf_initial_checkpoint,
            FLAGS.initialize_last_layer,
            last_layers,
            ignore_missing_vars=True),
        summary_op=summary_op,
        save_summaries_secs=FLAGS.save_summaries_secs,
        save_interval_secs=FLAGS.save_interval_secs)
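
The top of this example shows a handy pattern: a JSON config file can override any already-defined absl flag via FLAGS[key].value, with an extra branch for a custom DatasetDescriptor. Factored out on its own (a small illustrative sketch; apply_json_config is not a name from the original code):

import json

import tensorflow as tf

FLAGS = tf.app.flags.FLAGS  # assumes the training flags are defined elsewhere


def apply_json_config(config_path):
    """Override already-defined flags from a JSON file, as Example #4 does."""
    with tf.gfile.GFile(config_path) as f:
        overrides = json.load(f)
    for key, value in overrides.items():
        if key in FLAGS:
            FLAGS[key].value = value
        else:
            tf.logging.warning('Ignoring unknown config key: %s', key)
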
def main(unused_argv):
  tf.logging.set_verbosity(tf.logging.INFO)

  tf.gfile.MakeDirs(FLAGS.train_logdir)
  tf.logging.info('Training on %s set', FLAGS.train_split)
  train_rel_map = {"images": "train", "labels": "label"}
  base_dir = "/mnt/sda/deep_learning/CSE527_FinalProject-master/images"

  graph = tf.Graph()
  with graph.as_default():
    with tf.device(tf.train.replica_device_setter(ps_tasks=FLAGS.num_ps_tasks)):
      assert FLAGS.train_batch_size % FLAGS.num_clones == 0, (
          'Training batch size not divisible by number of clones (GPUs).')
      clone_batch_size = FLAGS.train_batch_size // FLAGS.num_clones
      input_pipeline = ImageInputPipeline(train_rel_map, ".tif", base_dir)
      dataset = input_pipeline._input_fn(size=(256, 256), batch_size=FLAGS.train_batch_size, augment=False)

      train_tensor, summary_op = _train_deeplab_model(dataset.make_one_shot_iterator(), 3, 255)

      # Soft placement allows ops without a GPU implementation to be placed on the CPU.
      session_config = tf.ConfigProto(
          allow_soft_placement=True, log_device_placement=False)

      last_layers = model.get_extra_layer_scopes(
          FLAGS.last_layers_contain_logits_only)
      init_fn = None
      if FLAGS.tf_initial_checkpoint:
        init_fn = train_utils.get_model_init_fn(
            FLAGS.train_logdir,
            FLAGS.tf_initial_checkpoint,
            FLAGS.initialize_last_layer,
            last_layers,
            ignore_missing_vars=True)

      scaffold = tf.train.Scaffold(
          init_fn=init_fn,
          summary_op=summary_op,
      )

      stop_hook = tf.train.StopAtStepHook(
          last_step=FLAGS.training_number_of_steps)

      profile_dir = FLAGS.profile_logdir
      if profile_dir is not None:
        tf.gfile.MakeDirs(profile_dir)

      with tf.contrib.tfprof.ProfileContext(
          enabled=profile_dir is not None, profile_dir=profile_dir):
        with tf.train.MonitoredTrainingSession(
            master=FLAGS.master,
            is_chief=(FLAGS.task == 0),
            config=session_config,
            scaffold=scaffold,
            checkpoint_dir=FLAGS.train_logdir,
            summary_dir=FLAGS.train_logdir,
            log_step_count_steps=FLAGS.log_steps,
            save_summaries_steps=FLAGS.save_summaries_secs,
            save_checkpoint_secs=FLAGS.save_interval_secs,
            hooks=[stop_hook]) as sess:
          while not sess.should_stop():
            sess.run([train_tensor])
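
This variant swaps DeepLab's data_generator.Dataset for a project-specific ImageInputPipeline whose _input_fn returns a tf.data.Dataset, so main() only needs make_one_shot_iterator(). A tiny sketch of what such an input function can look like (illustrative only, and assuming PNG files for simplicity; the example above actually reads .tif):

import tensorflow as tf


def toy_segmentation_input_fn(image_files, label_files,
                              size=(256, 256), batch_size=8):
    """Minimal tf.data pipeline yielding batched (image, label) pairs."""
    def _parse(image_path, label_path):
        image = tf.image.decode_png(tf.read_file(image_path), channels=3)
        label = tf.image.decode_png(tf.read_file(label_path), channels=1)
        image = tf.image.resize_images(image, size)
        label = tf.image.resize_images(
            label, size, method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
        return image, label

    dataset = tf.data.Dataset.from_tensor_slices((image_files, label_files))
    return dataset.shuffle(256).map(_parse).batch(batch_size).repeat()
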
Example #6
def main(unused_argv):
  tf.logging.set_verbosity(tf.logging.INFO)
  # Set up deployment (i.e., multi-GPUs and/or multi-replicas).
  config = model_deploy.DeploymentConfig(
      num_clones=FLAGS.num_clones,
      clone_on_cpu=FLAGS.clone_on_cpu,
      replica_id=FLAGS.task,
      num_replicas=FLAGS.num_replicas,
      num_ps_tasks=FLAGS.num_ps_tasks)

  # Split the batch across GPUs.
  assert FLAGS.train_batch_size % config.num_clones == 0, (
      'Training batch size not divisible by number of clones (GPUs).')

  clone_batch_size = FLAGS.train_batch_size // config.num_clones

  # Get dataset-dependent information.
  dataset = segmentation_dataset.get_dataset(
      FLAGS.dataset, FLAGS.train_split, dataset_dir=FLAGS.dataset_dir)

  tf.gfile.MakeDirs(FLAGS.train_logdir)
  tf.logging.info('Training on %s set', FLAGS.train_split)

  with tf.Graph().as_default() as graph:
    with tf.device(config.inputs_device()):
      samples = input_generator.get(
          dataset,
          FLAGS.train_crop_size,
          clone_batch_size,
          min_resize_value=FLAGS.min_resize_value,
          max_resize_value=FLAGS.max_resize_value,
          resize_factor=FLAGS.resize_factor,
          min_scale_factor=FLAGS.min_scale_factor,
          max_scale_factor=FLAGS.max_scale_factor,
          scale_factor_step_size=FLAGS.scale_factor_step_size,
          dataset_split=FLAGS.train_split,
          is_training=True,
          model_variant=FLAGS.model_variant)
      inputs_queue = prefetch_queue.prefetch_queue(
          samples, capacity=128 * config.num_clones)

    # Create the global step on the device storing the variables.
    with tf.device(config.variables_device()):
      global_step = tf.train.get_or_create_global_step()

      # Define the model and create clones.
      model_fn = _build_deeplab
      model_args = (inputs_queue, {
          common.OUTPUT_TYPE: dataset.num_classes
      }, dataset.ignore_label)
      clones = model_deploy.create_clones(config, model_fn, args=model_args)

      # Gather update_ops from the first clone. These contain, for example,
      # the updates for the batch_norm variables created by model_fn.
      first_clone_scope = config.clone_scope(0)
      update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, first_clone_scope)

    # Gather initial summaries.
    summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))

    # Add summaries for model variables.
    for model_var in slim.get_model_variables():
      summaries.add(tf.summary.histogram(model_var.op.name, model_var))

    # Add summaries for images, labels, semantic predictions
    if FLAGS.save_summaries_images:
      summary_image = graph.get_tensor_by_name(
          ('%s/%s:0' % (first_clone_scope, common.IMAGE)).strip('/'))
      summaries.add(
          tf.summary.image('samples/%s' % common.IMAGE, summary_image))

      first_clone_label = graph.get_tensor_by_name(
          ('%s/%s:0' % (first_clone_scope, common.LABEL)).strip('/'))
      # Scale up summary image pixel values for better visualization.
      pixel_scaling = max(1, 255 // dataset.num_classes)
      summary_label = tf.cast(first_clone_label * pixel_scaling, tf.uint8)
      summaries.add(
          tf.summary.image('samples/%s' % common.LABEL, summary_label))

      first_clone_output = graph.get_tensor_by_name(
          ('%s/%s:0' % (first_clone_scope, common.OUTPUT_TYPE)).strip('/'))
      predictions = tf.expand_dims(tf.argmax(first_clone_output, 3), -1)

      summary_predictions = tf.cast(predictions * pixel_scaling, tf.uint8)
      summaries.add(
          tf.summary.image(
              'samples/%s' % common.OUTPUT_TYPE, summary_predictions))

    # Add summaries for losses.
    for loss in tf.get_collection(tf.GraphKeys.LOSSES, first_clone_scope):
      summaries.add(tf.summary.scalar('losses/%s' % loss.op.name, loss))

    # Build the optimizer based on the device specification.
    with tf.device(config.optimizer_device()):
      learning_rate = train_utils.get_model_learning_rate(
          FLAGS.learning_policy, FLAGS.base_learning_rate,
          FLAGS.learning_rate_decay_step, FLAGS.learning_rate_decay_factor,
          FLAGS.training_number_of_steps, FLAGS.learning_power,
          FLAGS.slow_start_step, FLAGS.slow_start_learning_rate)
      optimizer = tf.train.MomentumOptimizer(learning_rate, FLAGS.momentum)
      summaries.add(tf.summary.scalar('learning_rate', learning_rate))

    startup_delay_steps = FLAGS.task * FLAGS.startup_delay_steps
    for variable in slim.get_model_variables():
      summaries.add(tf.summary.histogram(variable.op.name, variable))

    with tf.device(config.variables_device()):
      total_loss, grads_and_vars = model_deploy.optimize_clones(
          clones, optimizer)
      total_loss = tf.check_numerics(total_loss, 'Loss is inf or nan.')
      summaries.add(tf.summary.scalar('total_loss', total_loss))

      # Modify the gradients for biases and last layer variables.
      last_layers = model.get_extra_layer_scopes(
          FLAGS.last_layers_contain_logits_only)
      grad_mult = train_utils.get_model_gradient_multipliers(
          last_layers, FLAGS.last_layer_gradient_multiplier)
      if grad_mult:
        grads_and_vars = slim.learning.multiply_gradients(
            grads_and_vars, grad_mult)

      # Create gradient update op.
      grad_updates = optimizer.apply_gradients(
          grads_and_vars, global_step=global_step)
      update_ops.append(grad_updates)
      update_op = tf.group(*update_ops)
      with tf.control_dependencies([update_op]):
        train_tensor = tf.identity(total_loss, name='train_op')

    # Add the summaries from the first clone. These contain the summaries
    # created by model_fn and either optimize_clones() or _gather_clone_loss().
    summaries |= set(
        tf.get_collection(tf.GraphKeys.SUMMARIES, first_clone_scope))

    # Merge all summaries together.
    summary_op = tf.summary.merge(list(summaries))

    # Soft placement allows ops without a GPU implementation to be placed on the CPU.
    session_config = tf.ConfigProto(
        allow_soft_placement=True, log_device_placement=False)

    # Start the training.
    slim.learning.train(
        train_tensor,
        logdir=FLAGS.train_logdir,
        log_every_n_steps=FLAGS.log_steps,
        master=FLAGS.master,
        number_of_steps=FLAGS.training_number_of_steps,
        is_chief=(FLAGS.task == 0),
        session_config=session_config,
        startup_delay_steps=startup_delay_steps,
        init_fn=train_utils.get_model_init_fn(
            FLAGS.train_logdir,
            FLAGS.tf_initial_checkpoint,
            FLAGS.initialize_last_layer,
            last_layers,
            ignore_missing_vars=True),
        summary_op=summary_op,
        save_summaries_secs=FLAGS.save_summaries_secs,
        save_interval_secs=FLAGS.save_interval_secs)
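
All of the image summaries in these examples rely on the same visualization trick: integer class ids are multiplied by max(1, 255 // num_classes) so that label and prediction maps span the uint8 range in TensorBoard. As a standalone helper (illustrative, not part of the DeepLab code):

import tensorflow as tf


def label_to_summary_image(label, num_classes):
    """Spread class ids across [0, 255] so label maps show up in TensorBoard."""
    pixel_scaling = max(1, 255 // num_classes)
    return tf.cast(label * pixel_scaling, tf.uint8)

# e.g. summaries.add(tf.summary.image(
#     'samples/label',
#     label_to_summary_image(first_clone_label, dataset.num_classes)))
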
Example #7
def main(unused_argv):
    tf.logging.set_verbosity(tf.logging.INFO)

    if FLAGS.batch_iter < 1:
        FLAGS.batch_iter = 1
    if FLAGS.batch_iter != 1:
        if not (FLAGS.num_clones == 1 and FLAGS.num_replicas == 1):
            raise NotImplementedError(
                "train.py: **NOTE** -- train_utils.train_step_custom may not work with parallel GPUs / clones > 1! Be sure you are only using one GPU."
            )

    print('\ntrain.py: Accumulating gradients over {} iterations\n'.format(
        FLAGS.batch_iter))

    # Set up deployment (i.e., multi-GPUs and/or multi-replicas).
    config = model_deploy.DeploymentConfig(num_clones=FLAGS.num_clones,
                                           clone_on_cpu=FLAGS.clone_on_cpu,
                                           replica_id=FLAGS.task,
                                           num_replicas=FLAGS.num_replicas,
                                           num_ps_tasks=FLAGS.num_ps_tasks)

    # Split the batch across GPUs.
    assert FLAGS.train_batch_size % config.num_clones == 0, (
        'Training batch size not divisible by number of clones (GPUs).')

    clone_batch_size = FLAGS.train_batch_size // config.num_clones

    # Get dataset-dependent information.
    dataset = segmentation_dataset.get_dataset(
        FLAGS.dataset,
        FLAGS.train_split,
        dataset_dir=FLAGS.dataset_dir,
    )

    tf.gfile.MakeDirs(FLAGS.train_logdir)
    tf.logging.info('Training on %s set', FLAGS.train_split)

    with tf.Graph().as_default() as graph:
        with tf.device(config.inputs_device()):
            samples = input_generator.get(
                dataset,
                FLAGS.train_crop_size,
                clone_batch_size,
                min_resize_value=FLAGS.min_resize_value,
                max_resize_value=FLAGS.max_resize_value,
                resize_factor=FLAGS.resize_factor,
                min_scale_factor=FLAGS.min_scale_factor,
                max_scale_factor=FLAGS.max_scale_factor,
                scale_factor_step_size=FLAGS.scale_factor_step_size,
                dataset_split=FLAGS.train_split,
                is_training=True,
                model_variant=FLAGS.model_variant)
            inputs_queue = prefetch_queue.prefetch_queue(
                samples, capacity=128 * config.num_clones)

        # Create the global step on the device storing the variables.
        with tf.device(config.variables_device()):
            global_step = tf.train.get_or_create_global_step()

            # Define the model and create clones.
            model_fn = _build_deeplab
            if FLAGS.class_balanced_loss:
                print(
                    'train.py: class_balanced_loss=True. Reading loss weights from segmentation_dataset.py'
                )
            else:
                print(
                    'train.py: class_balanced_loss=False. Setting loss weights to 1.0 for every class.'
                )
                dataset.loss_weight = 1.0

            #_build_deeplab has model args:
            #(inputs_queue, outputs_to_num_classes, ignore_label, loss_weight):

            outputs_to_num_classes = {common.OUTPUT_TYPE: dataset.num_classes}

            model_args = (inputs_queue,
                          outputs_to_num_classes,
                          dataset.ignore_label, dataset.loss_weight)
            clones = model_deploy.create_clones(config,
                                                model_fn,
                                                args=model_args)

            # Gather update_ops from the first clone. These contain, for example,
            # the updates for the batch_norm variables created by model_fn.
            first_clone_scope = config.clone_scope(0)
            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS,
                                           first_clone_scope)

        # Gather initial summaries.
        summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))

        # Add summaries for model variables.
        for model_var in slim.get_model_variables():
            summaries.add(tf.summary.histogram(model_var.op.name, model_var))

        # Add summaries for images, labels, semantic predictions
        if FLAGS.save_summaries_images:
            summary_image = graph.get_tensor_by_name(
                ('%s/%s:0' % (first_clone_scope, common.IMAGE)).strip('/'))
            summaries.add(
                tf.summary.image('samples/%s' % common.IMAGE, summary_image))

            first_clone_label = graph.get_tensor_by_name(
                ('%s/%s:0' % (first_clone_scope, common.LABEL)).strip('/'))
            # Scale up summary image pixel values for better visualization.
            pixel_scaling = max(1, 255 // dataset.num_classes)
            summary_label = tf.cast(first_clone_label * pixel_scaling,
                                    tf.uint8)
            summaries.add(
                tf.summary.image('samples/%s' % common.LABEL, summary_label))

            first_clone_output = graph.get_tensor_by_name(
                ('%s/%s:0' %
                 (first_clone_scope, common.OUTPUT_TYPE)).strip('/'))
            predictions = tf.expand_dims(tf.argmax(first_clone_output, 3), -1)

            summary_predictions = tf.cast(predictions * pixel_scaling,
                                          tf.uint8)
            summaries.add(
                tf.summary.image('samples/%s' % common.OUTPUT_TYPE,
                                 summary_predictions))

        # Add summaries for losses.
        for loss in tf.get_collection(tf.GraphKeys.LOSSES, first_clone_scope):
            summaries.add(tf.summary.scalar('losses/%s' % loss.op.name, loss))

        # Build the optimizer based on the device specification.
        with tf.device(config.optimizer_device()):
            learning_rate = train_utils.get_model_learning_rate(
                FLAGS.learning_policy, FLAGS.base_learning_rate,
                FLAGS.learning_rate_decay_step,
                FLAGS.learning_rate_decay_factor,
                FLAGS.training_number_of_steps, FLAGS.learning_power,
                FLAGS.slow_start_step, FLAGS.slow_start_learning_rate)
            optimizer = tf.train.MomentumOptimizer(learning_rate,
                                                   FLAGS.momentum)
            summaries.add(tf.summary.scalar('learning_rate', learning_rate))

        startup_delay_steps = FLAGS.task * FLAGS.startup_delay_steps
        for variable in slim.get_model_variables():
            summaries.add(tf.summary.histogram(variable.op.name, variable))

        with tf.device(config.variables_device()):

            total_loss, grads_and_vars = model_deploy.optimize_clones(
                clones, optimizer)
            total_loss = tf.check_numerics(total_loss, 'Loss is inf or nan.')

            # Modify the gradients for biases and last layer variables.
            last_layers = model.get_extra_layer_scopes(
                FLAGS.last_layers_contain_logits_only)
            grad_mult = train_utils.get_model_gradient_multipliers(
                last_layers, FLAGS.last_layer_gradient_multiplier)
            if grad_mult:
                grads_and_vars = slim.learning.multiply_gradients(
                    grads_and_vars, grad_mult)

            if FLAGS.batch_iter <= 1:
                FLAGS.batch_iter = 0
                summaries.add(tf.summary.scalar('total_loss', total_loss))
                grad_updates = optimizer.apply_gradients(
                    grads_and_vars, global_step=global_step)
                update_ops.append(grad_updates)
                update_op = tf.group(*update_ops)
                with tf.control_dependencies([update_op]):
                    train_tensor = tf.identity(total_loss, name='train_op')
                accum_tensor = None
            else:

                ############ Accumulate grads_and_vars op. ####################
                accum_update_ops = list(update_ops)  #.copy()
                # Create (grad, var) list to accumulate gradients in. Inititalize to 0.
                accum_grads_and_vars = [
                    (tf.Variable(tf.zeros_like(gv[0]),
                                 trainable=False,
                                 name=gv[0].name.split(':')[0] + "_accum"),
                     gv[1]) for gv in grads_and_vars
                ]
                assert len(accum_grads_and_vars) == len(grads_and_vars)

                total_loss_accum = tf.Variable(0.0,
                                               dtype=tf.float32,
                                               trainable=False)
                accum_loss_update_op = [
                    total_loss_accum.assign_add(total_loss)
                ]
                accum_update_ops.append(accum_loss_update_op)

                ## Accumulate gradients: accum_grad[i] += (grad[i] / FLAGS.batch_iter)  # scaled gradients.
                accum_ops = [
                    accum_grads_and_vars[i][0].assign_add(
                        tf.div(gv[0], 1.0 * FLAGS.batch_iter))
                    for i, gv in enumerate(grads_and_vars)
                ]
                accum_update_ops.append(accum_ops)

                accum_update_op = tf.group(*accum_update_ops)
                with tf.control_dependencies([accum_update_op]):
                    accum_print_ops = []
                    if FLAGS.batch_iter_verbose:
                        accum_print_ops.extend([
                            tf.Print(
                                tf.constant(0), [tf.add(global_step, 1)],
                                message=
                                'train.py: accumulating gradients for step: '),
                            #tf.Print(total_loss, [total_loss], message='    step total_loss: ')
                            #tf.Print(tf.constant(0), [accum_grads_and_vars[0][0]], message='    '),
                        ])
                    accum_update_ops.append(accum_print_ops)
                    with tf.control_dependencies([tf.group(*accum_print_ops)]):
                        accum_tensor = tf.identity(total_loss_accum,
                                                   name='accum_op')

                ##################### Train op (apply [accumulated] grads and vars) ###############################
                train_update_ops = list(update_ops)  #.copy()
                ## Create gradient update op.
                # Apply gradients from accumulated gradients
                grad_updates = optimizer.apply_gradients(
                    accum_grads_and_vars, global_step=global_step)
                train_update_ops.append(grad_updates)

                grad_print_ops = []
                if FLAGS.batch_iter_verbose:
                    grad_print_ops.extend([
                        #                tf.Print(tf.constant(0), [grads_and_vars[0][0], grads_and_vars[0][1]], message='---grads[0] and vars[0]---------\n'),
                        #tf.Print(tf.constant(0), [], message=grads_and_vars[0][1].name),
                        tf.Print(tf.constant(0), [accum_grads_and_vars[0][0]],
                                 message='GRADS  BEFORE ZERO: ')
                    ])
                train_update_ops.append(grad_print_ops)

                total_loss_accum_average = tf.div(total_loss_accum,
                                                  FLAGS.batch_iter)
                summaries.add(
                    tf.summary.scalar('total_loss', total_loss_accum_average))

                train_update_op = tf.group(*train_update_ops)
                with tf.control_dependencies([train_update_op]):
                    zero_ops = []

                    zero_accum_ops = [
                        agv[0].assign(tf.zeros_like(agv[0]))
                        for agv in accum_grads_and_vars
                    ]
                    zero_ops.append(zero_accum_ops)

                    zero_accum_total_loss_op = [total_loss_accum.assign(0)]
                    zero_ops.append(zero_accum_total_loss_op)

                    zero_op = tf.group(*zero_ops)
                    with tf.control_dependencies([zero_op]):
                        grad_print_ops = []
                        if FLAGS.batch_iter_verbose:
                            grad_print_ops.extend([
                                #tf.Print(tf.constant(0), [accum_grads_and_vars[0][0]], message='GRADS AFTER ZERO ')
                            ])
                        with tf.control_dependencies(
                            [tf.group(*grad_print_ops)]):
                            train_tensor = tf.identity(
                                total_loss_accum_average, name='train_op')

        # Add the summaries from the first clone. These contain the summaries
        # created by model_fn and either optimize_clones() or
        # _gather_clone_loss().
        summaries |= set(
            tf.get_collection(tf.GraphKeys.SUMMARIES, first_clone_scope))

        # Merge all summaries together.
        summary_op = tf.summary.merge(list(summaries))

        # Soft placement allows ops without a GPU implementation to be placed on the CPU.
        session_config = tf.ConfigProto(allow_soft_placement=True,
                                        log_device_placement=False)

        session_config.gpu_options.allow_growth = True

        #train_step_exit = train_utils.train_step_exit
        train_step_custom = train_utils.train_step_custom
        if FLAGS.validation_interval <= 0:
            FLAGS.validation_interval = FLAGS.training_number_of_steps
        else:
            print("*** Validation interval: {} ***".format(
                FLAGS.validation_interval))

        # Start the training.
        slim.learning.train(train_tensor,
                            logdir=FLAGS.train_logdir,
                            train_step_fn=train_step_custom(
                                VALIDATION_N=FLAGS.validation_interval,
                                ACCUM_OP=accum_tensor,
                                ACCUM_STEPS=FLAGS.batch_iter),
                            log_every_n_steps=FLAGS.log_steps,
                            master=FLAGS.master,
                            number_of_steps=FLAGS.training_number_of_steps,
                            is_chief=(FLAGS.task == 0),
                            session_config=session_config,
                            startup_delay_steps=startup_delay_steps,
                            init_fn=train_utils.get_model_init_fn(
                                FLAGS.train_logdir,
                                FLAGS.tf_initial_checkpoint,
                                FLAGS.initialize_last_layer,
                                last_layers,
                                ignore_missing_vars=True),
                            summary_op=summary_op,
                            save_summaries_secs=FLAGS.save_summaries_secs,
                            save_interval_secs=FLAGS.save_interval_secs)
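
Example #7 is the most involved variant: with batch_iter > 1 it emulates a larger batch by accumulating scaled gradients in non-trainable accumulator variables over several steps, applying them once, and then zeroing the accumulators, with a custom train_step_fn sequencing the accum and train ops. Stripped of the printing and summary plumbing, the core pattern looks roughly like this (a sketch of the technique, not the exact code above):

import tensorflow as tf


def build_accumulating_train_ops(optimizer, loss, accum_steps, global_step):
    """Sketch of the gradient-accumulation pattern used in Example #7."""
    grads_and_vars = [(g, v) for g, v in optimizer.compute_gradients(loss)
                      if g is not None]
    # One non-trainable accumulator per variable, initialized to zeros.
    accums = [tf.Variable(tf.zeros_like(v), trainable=False)
              for _, v in grads_and_vars]
    # Run this op accum_steps times: add the scaled gradient of each batch.
    accum_op = tf.group(*[a.assign_add(g / accum_steps)
                          for a, (g, _) in zip(accums, grads_and_vars)])
    # Then run this once: apply the accumulated gradients and reset them.
    apply_op = optimizer.apply_gradients(
        [(a, v) for a, (_, v) in zip(accums, grads_and_vars)],
        global_step=global_step)
    with tf.control_dependencies([apply_op]):
        train_op = tf.group(*[a.assign(tf.zeros_like(a)) for a in accums])
    return accum_op, train_op
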
def main(unused_argv):

  print("logging params")
  print("Learning rate: ", FLAGS.base_learning_rate)
  print("Momentum: ", FLAGS.momentum)
  print("Weight decay: ", FLAGS.weight_decay)
  print("training steps: ", FLAGS.training_number_of_steps)
  print("Dataset name: ",FLAGS.dataset)
  print("Using dataset for training: ",FLAGS.train_split)
  print("Dataset directory: ",FLAGS.dataset_dir)
  print("batch size: ", FLAGS.train_batch_size)
  print("crop size: ", FLAGS.train_crop_size)
  print("Model variant used: ",FLAGS.model_variant)
  print("Train log directory: ", FLAGS.train_logdir)
  train_list = []
  val_list = []
  count= 0
  best_val_mean_iou = 0.718
  dir_path='deeplab/best_ckpt/'

  tf.logging.set_verbosity(tf.logging.INFO)

  tf.gfile.MakeDirs(FLAGS.train_logdir)
  tf.logging.info('Training on %s set', FLAGS.train_split)

  graph = tf.Graph()
  with graph.as_default():
    with tf.device(tf.train.replica_device_setter(ps_tasks=FLAGS.num_ps_tasks)):
      assert FLAGS.train_batch_size % FLAGS.num_clones == 0, (
          'Training batch size not divisible by number of clones (GPUs).')
      clone_batch_size = FLAGS.train_batch_size // FLAGS.num_clones # will be equivalent to train_batch_size

      dataset = data_generator.Dataset(
          dataset_name=FLAGS.dataset,
          split_name=FLAGS.train_split,
          dataset_dir=FLAGS.dataset_dir,
          batch_size=clone_batch_size,
          crop_size=[int(sz) for sz in FLAGS.train_crop_size],
          min_resize_value=FLAGS.min_resize_value,
          max_resize_value=FLAGS.max_resize_value,
          resize_factor=FLAGS.resize_factor,
          min_scale_factor=FLAGS.min_scale_factor,
          max_scale_factor=FLAGS.max_scale_factor,
          scale_factor_step_size=FLAGS.scale_factor_step_size,
          model_variant=FLAGS.model_variant,
          num_readers=1, #check??
          is_training=True,
          should_shuffle=True,
          should_repeat=True)

      train_tensor, summary_op = _train_deeplab_model(
          dataset.get_one_shot_iterator(), dataset.num_of_classes,
          dataset.ignore_label)

      # Soft placement allows ops without a GPU implementation to be placed on the CPU.
      session_config = tf.ConfigProto(
          allow_soft_placement=True, log_device_placement=False)

      last_layers = model.get_extra_layer_scopes(
          FLAGS.last_layers_contain_logits_only)
      init_fn = None
      if FLAGS.tf_initial_checkpoint:
        init_fn = train_utils.get_model_init_fn(
            FLAGS.train_logdir,
            FLAGS.tf_initial_checkpoint,
            FLAGS.initialize_last_layer,
            last_layers,
            ignore_missing_vars=True)

      scaffold = tf.train.Scaffold(
          init_fn=init_fn,
          summary_op=summary_op,
      )

      stop_hook = tf.train.StopAtStepHook(
          last_step=FLAGS.training_number_of_steps)

      profile_dir = FLAGS.profile_logdir
      if profile_dir is not None:
        tf.gfile.MakeDirs(profile_dir)

      with tf.contrib.tfprof.ProfileContext(
          enabled=profile_dir is not None, profile_dir=profile_dir):
        with tf.train.MonitoredTrainingSession(
            master=FLAGS.master,
            is_chief=(FLAGS.task == 0),
            config=session_config,
            scaffold=scaffold,
            checkpoint_dir=FLAGS.train_logdir,
            summary_dir=FLAGS.train_logdir,
            log_step_count_steps=FLAGS.log_steps,
            save_summaries_steps=FLAGS.save_summaries_secs,
            save_checkpoint_secs=FLAGS.save_interval_secs,
            hooks=[stop_hook]) as sess:
          while not sess.should_stop():
            count += 1
            training_loss = sess.run(train_tensor)
            if np.isnan(training_loss):
                print("learning rate too high. exiting!")
                exit()

            try:
              if count>5000 and count%200==0:
                train_iou = subprocess.check_output([sys.executable, "deeplab/vistrain.py"])
                val_iou = subprocess.check_output([sys.executable, "deeplab/vis.py"])
                val_mean_iou = float(val_iou.decode("utf-8").split('\n')[-2])
                val_list.append(val_mean_iou*100)
                train_mean_iou=float(train_iou.decode("utf-8").split('\n')[-2])*100
                train_list.append(train_mean_iou)

                
                print("Mean IoU on training dataset: ", train_mean_iou)
                print("Mean IoU on validation dataset: ", val_mean_iou)
                sys.stdout.flush()

                if val_mean_iou > best_val_mean_iou:
                  if os.path.isdir(dir_path):
                    shutil.rmtree(dir_path)
                  print("Validation Mean IoU: ", val_mean_iou)
                  shutil.copytree(FLAGS.train_logdir, dir_path)
                  best_val_mean_iou = val_mean_iou
            except:
              print("Validation script returned non-zero status.")
Example #9
    def train(self):
        FLAGS = self.flags
        dataset_split = 'train'
        data_config = edict()
        data_config.edge_width = 20
        data_config.ignore_label = DATASETS_IGNORE_LABEL[FLAGS.dataset]
        data_config.edge_class_num = FLAGS.edge_class_num
        img_files, label_files = get_dataset_files(FLAGS.dataset,
                                                   dataset_split)

        dataset = edict()
        dataset_pp = dataset_pipeline(data_config,
                                      img_files,
                                      label_files,
                                      is_train=True)
        dataset.num_classes = DATASETS_CLASS_NUM[FLAGS.dataset]
        dataset.ignore_label = DATASETS_IGNORE_LABEL[FLAGS.dataset]
        dataset.num_samples = len(dataset_pp)

        tf.logging.set_verbosity(tf.logging.INFO)
        # Set up deployment (i.e., multi-GPUs and/or multi-replicas).
        config = model_deploy.DeploymentConfig(num_clones=FLAGS.num_clones,
                                               clone_on_cpu=FLAGS.clone_on_cpu,
                                               replica_id=FLAGS.task,
                                               num_replicas=FLAGS.num_replicas,
                                               num_ps_tasks=FLAGS.num_ps_tasks)

        # Split the batch across GPUs.
        assert FLAGS.train_batch_size % config.num_clones == 0, (
            'Training batch size not divisible by number of clones (GPUs).')

        clone_batch_size = FLAGS.train_batch_size // config.num_clones

        # Get dataset-dependent information.
        #        dataset = segmentation_dataset.get_dataset(
        #            FLAGS.dataset, FLAGS.train_split, dataset_dir=FLAGS.dataset_dir)

        tf.gfile.MakeDirs(FLAGS.train_logdir)
        tf.logging.info('Training on %s set', FLAGS.train_split)

        with tf.Graph().as_default() as graph:
            with tf.device(config.inputs_device()):
                data_list = dataset_pp.iterator()
                samples = input_generator.get(
                    (data_list, dataset.ignore_label),
                    FLAGS.train_crop_size,
                    clone_batch_size,
                    min_resize_value=FLAGS.min_resize_value,
                    max_resize_value=FLAGS.max_resize_value,
                    resize_factor=FLAGS.resize_factor,
                    min_scale_factor=FLAGS.min_scale_factor,
                    max_scale_factor=FLAGS.max_scale_factor,
                    scale_factor_step_size=FLAGS.scale_factor_step_size,
                    dataset_split=FLAGS.train_split,
                    is_training=True,
                    model_variant=FLAGS.model_variant)
                inputs_queue = prefetch_queue.prefetch_queue(
                    samples, capacity=128 * config.num_clones)

            # Create the global step on the device storing the variables.
            with tf.device(config.variables_device()):
                global_step = tf.train.get_or_create_global_step()

                # Define the model and create clones.
                model_fn = self._build_deeplab
                model_args = (inputs_queue, {
                    common.OUTPUT_TYPE: dataset.num_classes
                }, dataset.ignore_label)
                clones = model_deploy.create_clones(config,
                                                    model_fn,
                                                    args=model_args)

                # Gather update_ops from the first clone. These contain, for example,
                # the updates for the batch_norm variables created by model_fn.
                first_clone_scope = config.clone_scope(0)
                update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS,
                                               first_clone_scope)

            # Gather initial summaries.
            summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))

            # Add summaries for model variables.
            for model_var in slim.get_model_variables():
                summaries.add(
                    tf.summary.histogram(model_var.op.name, model_var))

            label_name = ('%s/%s:0' %
                          (first_clone_scope, common.LABEL)).strip('/')
            print('first clone label name is:', label_name)

            # Add summaries for images, labels, semantic predictions
            if FLAGS.save_summaries_images:
                summary_image = graph.get_tensor_by_name(
                    ('%s/%s:0' % (first_clone_scope, common.IMAGE)).strip('/'))
                summaries.add(
                    tf.summary.image('samples/%s' % common.IMAGE,
                                     summary_image))

                first_clone_label = graph.get_tensor_by_name(
                    ('%s/%s:0' % (first_clone_scope, common.LABEL)).strip('/'))

                # Scale up summary image pixel values for better visualization.
                pixel_scaling = max(1, 255 // dataset.num_classes)
                summary_label = tf.cast(first_clone_label * pixel_scaling,
                                        tf.uint8)
                summaries.add(
                    tf.summary.image('samples/%s' % common.LABEL,
                                     summary_label))

                first_clone_output = graph.get_tensor_by_name(
                    ('%s/%s:0' %
                     (first_clone_scope, common.OUTPUT_TYPE)).strip('/'))
                predictions = tf.expand_dims(tf.argmax(first_clone_output, 3),
                                             -1)

                summary_predictions = tf.cast(predictions * pixel_scaling,
                                              tf.uint8)
                summaries.add(
                    tf.summary.image('samples/%s' % common.OUTPUT_TYPE,
                                     summary_predictions))

            # Add summaries for mIoU and accuracy metrics.
            labels = graph.get_tensor_by_name(
                ('%s/%s:0' % (first_clone_scope, common.LABEL)).strip('/'))
            predictions = graph.get_tensor_by_name(
                ('%s/%s:0' %
                 (first_clone_scope, common.OUTPUT_TYPE)).strip('/'))
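            # Resize the logits to the label resolution before the argmax so the metrics compare every pixel.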
            predictions = tf.image.resize_bilinear(predictions,
                                                   tf.shape(labels)[1:3],
                                                   align_corners=True)

            labels = tf.reshape(labels, shape=[-1])
            predictions = tf.reshape(tf.argmax(predictions, 3), shape=[-1])
            weights = tf.to_float(tf.not_equal(labels, dataset.ignore_label))

            # Set ignore_label regions to label 0, because metrics.mean_iou requires
            # range of labels = [0, dataset.num_classes). Note the ignore_label regions
            # are not evaluated since the corresponding regions contain weights = 0.
            labels = tf.where(tf.equal(labels, dataset.ignore_label),
                              tf.zeros_like(labels), labels)

            # Define the evaluation metric.
            metric_map = {}
            metric_map['miou'], _ = tf.metrics.mean_iou(predictions,
                                                        labels,
                                                        dataset.num_classes,
                                                        weights=weights)
            metric_map['acc'], _ = tf.metrics.accuracy(
                labels=labels,
                predictions=predictions,
                weights=tf.reshape(weights, shape=[-1]))

            for x in ['miou', 'acc']:
                summaries.add(
                    tf.summary.scalar('metrics/%s' % x, metric_map[x]))

            # Add summaries for losses.
            for loss in tf.get_collection(tf.GraphKeys.LOSSES,
                                          first_clone_scope):
                summaries.add(
                    tf.summary.scalar('losses/%s' % loss.op.name, loss))

            # Build the optimizer based on the device specification.
            with tf.device(config.optimizer_device()):
                learning_rate = train_utils.get_model_learning_rate(
                    FLAGS.learning_policy, FLAGS.base_learning_rate,
                    FLAGS.learning_rate_decay_step,
                    FLAGS.learning_rate_decay_factor,
                    FLAGS.training_number_of_steps, FLAGS.learning_power,
                    FLAGS.slow_start_step, FLAGS.slow_start_learning_rate)
                optimizer = tf.train.MomentumOptimizer(learning_rate,
                                                       FLAGS.momentum)
                summaries.add(tf.summary.scalar('learning_rate',
                                                learning_rate))

            startup_delay_steps = FLAGS.task * FLAGS.startup_delay_steps
            for variable in slim.get_model_variables():
                summaries.add(tf.summary.histogram(variable.op.name, variable))

            with tf.device(config.variables_device()):
                total_loss, grads_and_vars = model_deploy.optimize_clones(
                    clones, optimizer)
                total_loss = tf.check_numerics(total_loss,
                                               'Loss is inf or nan.')
                summaries.add(tf.summary.scalar('total_loss', total_loss))

                # Modify the gradients for biases and last layer variables.
                last_layers = model.get_extra_layer_scopes(
                    FLAGS.last_layers_contain_logits_only)
                grad_mult = train_utils.get_model_gradient_multipliers(
                    last_layers, FLAGS.last_layer_gradient_multiplier)
                if grad_mult:
                    grads_and_vars = slim.learning.multiply_gradients(
                        grads_and_vars, grad_mult)

                # Create gradient update op.
                grad_updates = optimizer.apply_gradients(
                    grads_and_vars, global_step=global_step)
                update_ops.append(grad_updates)
                update_op = tf.group(*update_ops)
                with tf.control_dependencies([update_op]):
                    train_tensor = tf.identity(total_loss, name='train_op')

            # Add the summaries from the first clone. These contain the summaries
            # created by model_fn and either optimize_clones() or _gather_clone_loss().
            summaries |= set(
                tf.get_collection(tf.GraphKeys.SUMMARIES, first_clone_scope))

            # Merge all summaries together.
            summary_op = tf.summary.merge(list(summaries))

            # Soft placement allows placing on CPU ops without GPU implementation.
            session_config = tf.ConfigProto(allow_soft_placement=True,
                                            log_device_placement=False)

            # Start the training.
            slim.learning.train(train_tensor,
                                logdir=FLAGS.train_logdir,
                                log_every_n_steps=FLAGS.log_steps,
                                master=FLAGS.master,
                                number_of_steps=FLAGS.training_number_of_steps,
                                is_chief=(FLAGS.task == 0),
                                session_config=session_config,
                                startup_delay_steps=startup_delay_steps,
                                init_fn=train_utils.get_model_init_fn(
                                    FLAGS.train_logdir,
                                    FLAGS.tf_initial_checkpoint,
                                    FLAGS.initialize_last_layer,
                                    last_layers,
                                    ignore_missing_vars=True),
                                summary_op=summary_op,
                                save_summaries_secs=FLAGS.save_summaries_secs,
                                save_interval_secs=FLAGS.save_interval_secs)
def main(unused_argv):
  tf.logging.set_verbosity(tf.logging.INFO)
  # Set up deployment (i.e., multi-GPUs and/or multi-replicas).
  # Configure the multi-GPU training setup.
  config = model_deploy.DeploymentConfig(
      num_clones=FLAGS.num_clones,  # number of GPUs
      clone_on_cpu=FLAGS.clone_on_cpu,  # defaults to False
      replica_id=FLAGS.task,    # task id
      num_replicas=FLAGS.num_replicas,  # defaults to 1
      num_ps_tasks=FLAGS.num_ps_tasks)  # defaults to 0

  # Split the batch across GPUs.
  assert FLAGS.train_batch_size % config.num_clones == 0, (
      'Training batch size not divisible by number of clones (GPUs).')

  clone_batch_size = FLAGS.train_batch_size // config.num_clones    # split the batch evenly across GPUs

  tf.gfile.MakeDirs(FLAGS.train_logdir)     # create the directory for training logs
  tf.logging.info('Training on %s set', FLAGS.train_split)

  with tf.Graph().as_default() as graph:
    with tf.device(config.inputs_device()):
      dataset = data_generator.Dataset(     # define the dataset
          dataset_name=FLAGS.dataset,   # dataset name, e.g. cityscapes
          split_name=FLAGS.train_split,  # TFRecord split used for training, defaults to "train"
          dataset_dir=FLAGS.dataset_dir,   # directory containing the TFRecord files
          batch_size=clone_batch_size,  # per-GPU batch size after splitting
          crop_size=[int(sz) for sz in FLAGS.train_crop_size],  # training crop size, e.g. 513,513
          min_resize_value=FLAGS.min_resize_value,  # defaults to None
          max_resize_value=FLAGS.max_resize_value,  # defaults to None
          resize_factor=FLAGS.resize_factor,    # defaults to None
          min_scale_factor=FLAGS.min_scale_factor,   # minimum augmentation scale, defaults to 0.5
          max_scale_factor=FLAGS.max_scale_factor,   # maximum augmentation scale, defaults to 2
          scale_factor_step_size=FLAGS.scale_factor_step_size,      # step between augmentation scales, defaults to 0.25 (from 0.5 to 2)
          model_variant=FLAGS.model_variant,    # model variant, e.g. xception_65
          num_readers=4,    # number of data readers; increase with more GPUs to speed up input
          is_training=True,
          should_shuffle=True,
          should_repeat=True)

    # Create the global step on the device storing the variables.
    with tf.device(config.variables_device()):
      # Step counter: incremented by one after every trained batch.
      global_step = tf.train.get_or_create_global_step()

      # Define the model and create clones.
      model_fn = _build_deeplab  # define the DeepLab model
      model_args = (dataset.get_one_shot_iterator(), {
          common.OUTPUT_TYPE: dataset.num_of_classes
      }, dataset.ignore_label)  # model arguments
      clones = model_deploy.create_clones(config, model_fn, args=model_args)

      # Gather update_ops from the first clone. These contain, for example,
      # the updates for the batch_norm variables created by model_fn.
      first_clone_scope = config.clone_scope(0)
      update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, first_clone_scope)

    # Gather initial summaries.
    summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))

    # Add summaries for model variables.
    for model_var in tf.model_variables():
      summaries.add(tf.summary.histogram(model_var.op.name, model_var))

    # Add summaries for images, labels, semantic predictions
    if FLAGS.save_summaries_images:      # defaults to False
      summary_image = graph.get_tensor_by_name(
          ('%s/%s:0' % (first_clone_scope, common.IMAGE)).strip('/'))
      summaries.add(
          tf.summary.image('samples/%s' % common.IMAGE, summary_image))

      first_clone_label = graph.get_tensor_by_name(
          ('%s/%s:0' % (first_clone_scope, common.LABEL)).strip('/'))
      # Scale up summary image pixel values for better visualization.
      pixel_scaling = max(1, 255 // dataset.num_of_classes)
      summary_label = tf.cast(first_clone_label * pixel_scaling, tf.uint8)
      summaries.add(
          tf.summary.image('samples/%s' % common.LABEL, summary_label))

      first_clone_output = graph.get_tensor_by_name(
          ('%s/%s:0' % (first_clone_scope, common.OUTPUT_TYPE)).strip('/'))
      predictions = tf.expand_dims(tf.argmax(first_clone_output, 3), -1)

      summary_predictions = tf.cast(predictions * pixel_scaling, tf.uint8)
      summaries.add(
          tf.summary.image(
              'samples/%s' % common.OUTPUT_TYPE, summary_predictions))

    # Add summaries for losses.
    for loss in tf.get_collection(tf.GraphKeys.LOSSES, first_clone_scope):
      summaries.add(tf.summary.scalar('losses/%s' % loss.op.name, loss))

    # Build the optimizer based on the device specification.
    with tf.device(config.optimizer_device()):
      learning_rate = train_utils.get_model_learning_rate(  # build the learning-rate schedule
          FLAGS.learning_policy,    # 'poly' learning policy
          FLAGS.base_learning_rate,     # 0.0001
          FLAGS.learning_rate_decay_step,   # decay the learning rate every 2000 steps
          FLAGS.learning_rate_decay_factor,     # 0.1
          FLAGS.training_number_of_steps,   # total training steps, 20000
          FLAGS.learning_power,     # poly power 0.9
          FLAGS.slow_start_step,    # 0
          FLAGS.slow_start_learning_rate,   # 1e-4, learning rate used during the slow start
          decay_steps=FLAGS.decay_steps,    # 0.0
          end_learning_rate=FLAGS.end_learning_rate)     # 0.0

      summaries.add(tf.summary.scalar('learning_rate', learning_rate))
      # Training optimizer.
      if FLAGS.optimizer == 'momentum':
        optimizer = tf.train.MomentumOptimizer(learning_rate, FLAGS.momentum)
      elif FLAGS.optimizer == 'adam':   # Adam optimizer: adaptive updates with second-moment gradient correction
        optimizer = tf.train.AdamOptimizer(
            learning_rate=FLAGS.adam_learning_rate, epsilon=FLAGS.adam_epsilon)
      else:
        raise ValueError('Unknown optimizer')

    if FLAGS.quantize_delay_step >= 0:  # defaults to -1 (disabled)
      if FLAGS.num_clones > 1:
        raise ValueError('Quantization doesn\'t support multi-clone yet.')
      contrib_quantize.create_training_graph(
          quant_delay=FLAGS.quantize_delay_step)

    startup_delay_steps = FLAGS.task * FLAGS.startup_delay_steps    # FLAGS.startup_delay_steps defaults to 15

    with tf.device(config.variables_device()):
      total_loss, grads_and_vars = model_deploy.optimize_clones(
          clones, optimizer)    # compute total_loss
      total_loss = tf.check_numerics(total_loss, 'Loss is inf or nan.')
      summaries.add(tf.summary.scalar('total_loss', total_loss))

      # Modify the gradients for biases and last layer variables.
      last_layers = model.get_extra_layer_scopes(
          FLAGS.last_layers_contain_logits_only)
      # Get the gradient multipliers.
      grad_mult = train_utils.get_model_gradient_multipliers(
          last_layers, FLAGS.last_layer_gradient_multiplier)
      # grad_mult : {'logits/semantic/biases': 2.0, 'logits/semantic/weights': 1.0}
      if grad_mult:
        grads_and_vars = slim.learning.multiply_gradients(
            grads_and_vars, grad_mult)

      # Create gradient update op.
      grad_updates = optimizer.apply_gradients(     # apply the computed gradients to the variables
          grads_and_vars, global_step=global_step)  # also increments global_step
      update_ops.append(grad_updates)
      update_op = tf.group(*update_ops)
      with tf.control_dependencies([update_op]):
        train_tensor = tf.identity(total_loss, name='train_op')

    # Add the summaries from the first clone. These contain the summaries
    # created by model_fn and either optimize_clones() or _gather_clone_loss().
    summaries |= set(
        tf.get_collection(tf.GraphKeys.SUMMARIES, first_clone_scope))

    # Merge all summaries together.
    summary_op = tf.summary.merge(list(summaries))

    # Soft placement allows placing on CPU ops without GPU implementation.
    session_config = tf.ConfigProto(
        allow_soft_placement=True, log_device_placement=False)

    # Start the training.
    profile_dir = FLAGS.profile_logdir   # defaults to None
    if profile_dir is not None:
      tf.gfile.MakeDirs(profile_dir)

    with contrib_tfprof.ProfileContext(
        enabled=profile_dir is not None, profile_dir=profile_dir):
      init_fn = None
      if FLAGS.tf_initial_checkpoint:   # load pretrained weights
        init_fn = train_utils.get_model_init_fn(
            FLAGS.train_logdir,
            FLAGS.tf_initial_checkpoint,
            FLAGS.initialize_last_layer,
            last_layers,
            ignore_missing_vars=True)

      slim.learning.train(
          train_tensor,
          logdir=FLAGS.train_logdir,
          log_every_n_steps=FLAGS.log_steps,
          master=FLAGS.master,
          number_of_steps=FLAGS.training_number_of_steps,
          is_chief=(FLAGS.task == 0),
          session_config=session_config,
          startup_delay_steps=startup_delay_steps,
          init_fn=init_fn,
          summary_op=summary_op,
          save_summaries_secs=FLAGS.save_summaries_secs,
          save_interval_secs=FLAGS.save_interval_secs)
Beispiel #11
0
def main(unused_argv):
  print("DEEPLABv3+")
  print("SAVE TO "+FLAGS.train_logdir)
  tf.logging.set_verbosity(tf.logging.INFO)
  # Set up deployment (i.e., multi-GPUs and/or multi-replicas).
  config = model_deploy.DeploymentConfig(
      num_clones=FLAGS.num_clones,
      clone_on_cpu=FLAGS.clone_on_cpu,
      replica_id=FLAGS.task,
      num_replicas=FLAGS.num_replicas,
      num_ps_tasks=FLAGS.num_ps_tasks)
  print("batch_norm: "+str(FLAGS.fine_tune_batch_norm))
  print("initialize_last_layer: "+str(FLAGS.initialize_last_layer))
  # Split the batch across GPUs.
  assert FLAGS.train_batch_size % config.num_clones == 0, (
      'Training batch size not divisible by number of clones (GPUs).')

  clone_batch_size = FLAGS.train_batch_size // config.num_clones

  # Get dataset-dependent information.
  dataset = segmentation_dataset.get_dataset(
      FLAGS.dataset, FLAGS.train_split, dataset_dir=FLAGS.dataset_dir)
  dataset_val = segmentation_dataset.get_dataset(
      FLAGS.dataset, FLAGS.val_split, dataset_dir=FLAGS.dataset_dir)

  tf.gfile.MakeDirs(FLAGS.train_logdir)
  tf.logging.info('Training on %s set', FLAGS.train_split)
  tf.logging.info('Validating on %s set', FLAGS.val_split)

  with tf.Graph().as_default() as graph:
    with tf.device(config.inputs_device()):
      samples = input_generator.get(
          dataset,
          FLAGS.train_crop_size,
          clone_batch_size,
          min_resize_value=FLAGS.min_resize_value,
          max_resize_value=FLAGS.max_resize_value,
          resize_factor=FLAGS.resize_factor,
          min_scale_factor=FLAGS.min_scale_factor,
          max_scale_factor=FLAGS.max_scale_factor,
          scale_factor_step_size=FLAGS.scale_factor_step_size,
          dataset_split=FLAGS.train_split,
          is_training=True,
          model_variant=FLAGS.model_variant)
      inputs_queue = prefetch_queue.prefetch_queue(
          samples, capacity=128 * config.num_clones)

    # Validation inputs.

    samples_val = input_generator.get(
        dataset_val,
        FLAGS.train_crop_size,
        FLAGS.train_batch_size,
        min_resize_value=FLAGS.min_resize_value,
        max_resize_value=FLAGS.max_resize_value,
        resize_factor=FLAGS.resize_factor,
        dataset_split=FLAGS.val_split,
        is_training=False,
        model_variant=FLAGS.model_variant)

    # Create the global step on the device storing the variables.
    with tf.device(config.variables_device()):
      global_step = tf.train.get_or_create_global_step()

      # Define the model and create clones.
      model_fn = _build_deeplab
      model_args = (inputs_queue, {
          common.OUTPUT_TYPE: dataset.num_classes
      }, dataset.ignore_label)
      clones = model_deploy.create_clones(config, model_fn, args=model_args)

      # Gather update_ops from the first clone. These contain, for example,
      # the updates for the batch_norm variables created by model_fn.
      first_clone_scope = config.clone_scope(0)
      update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, first_clone_scope)

    # Gather initial summaries.
    summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))

    # Add summaries for model variables.
    for model_var in slim.get_model_variables():
      summaries.add(tf.summary.histogram(model_var.op.name, model_var))

    # Add summaries for images, labels, semantic predictions
    if FLAGS.save_summaries_images:
      summary_image = graph.get_tensor_by_name(
          ('%s/%s:0' % (first_clone_scope, common.IMAGE)).strip('/'))
      summaries.add(
          tf.summary.image('samples/%s' % common.IMAGE, summary_image))

      first_clone_label = graph.get_tensor_by_name(
          ('%s/%s:0' % (first_clone_scope, common.LABEL)).strip('/'))
      # Scale up summary image pixel values for better visualization.
      pixel_scaling = max(1, 255 // dataset.num_classes)
      summary_label = tf.cast(first_clone_label * pixel_scaling, tf.uint8)
      summaries.add(
          tf.summary.image('samples/%s' % common.LABEL, summary_label))

      first_clone_output = graph.get_tensor_by_name(
          ('%s/%s:0' % (first_clone_scope, common.OUTPUT_TYPE)).strip('/'))
      predictions = tf.expand_dims(tf.argmax(first_clone_output, 3), -1)

      summary_predictions = tf.cast(predictions * pixel_scaling, tf.uint8)
      summaries.add(
          tf.summary.image(
              'samples/%s' % common.OUTPUT_TYPE, summary_predictions))

    # Add summaries for losses.
    for loss in tf.get_collection(tf.GraphKeys.LOSSES, first_clone_scope):
      summaries.add(tf.summary.scalar('losses/%s' % loss.op.name, loss))

    # Build the optimizer based on the device specification.
    with tf.device(config.optimizer_device()):
      learning_rate = train_utils.get_model_learning_rate(
          FLAGS.learning_policy, FLAGS.base_learning_rate,
          FLAGS.learning_rate_decay_step, FLAGS.learning_rate_decay_factor,
          FLAGS.training_number_of_steps, FLAGS.learning_power,
          FLAGS.slow_start_step, FLAGS.slow_start_learning_rate)
      optimizer = tf.train.MomentumOptimizer(learning_rate, FLAGS.momentum)
      summaries.add(tf.summary.scalar('learning_rate', learning_rate))

    startup_delay_steps = FLAGS.task * FLAGS.startup_delay_steps
    for variable in slim.get_model_variables():
      summaries.add(tf.summary.histogram(variable.op.name, variable))

    with tf.device(config.variables_device()):
      total_loss, grads_and_vars = model_deploy.optimize_clones(
          clones, optimizer)
      total_loss = tf.check_numerics(total_loss, 'Loss is inf or nan.')
      summaries.add(tf.summary.scalar('total_loss', total_loss))

      # Modify the gradients for biases and last layer variables.
      last_layers = model.get_extra_layer_scopes(
          FLAGS.last_layers_contain_logits_only)
      grad_mult = train_utils.get_model_gradient_multipliers(
          last_layers, FLAGS.last_layer_gradient_multiplier)
      if grad_mult:
        grads_and_vars = slim.learning.multiply_gradients(
            grads_and_vars, grad_mult)

      # Create gradient update op.
      grad_updates = optimizer.apply_gradients(
          grads_and_vars, global_step=global_step)
      update_ops.append(grad_updates)
      update_op = tf.group(*update_ops)
      with tf.control_dependencies([update_op]):
        train_tensor = tf.identity(total_loss, name='train_op')

    # Add the summaries from the first clone. These contain the summaries
    # created by model_fn and either optimize_clones() or _gather_clone_loss().
    summaries |= set(
        tf.get_collection(tf.GraphKeys.SUMMARIES, first_clone_scope))

    # Merge all summaries together.
    summary_op = tf.summary.merge(list(summaries))

    # Soft placement allows placing on CPU ops without GPU implementation.
    session_config = tf.ConfigProto(
        allow_soft_placement=True, log_device_placement=False)

    # Validation graph.
    model_options = common.ModelOptions(
        outputs_to_num_classes={common.OUTPUT_TYPE: dataset.num_classes},
        crop_size=FLAGS.train_crop_size,
        atrous_rates=FLAGS.atrous_rates,
        output_stride=FLAGS.output_stride)
    predictions_val = model.predict_labels(samples_val[common.IMAGE], model_options,
                                         image_pyramid=FLAGS.image_pyramid)
    predictions_val = predictions_val[common.OUTPUT_TYPE]
    predictions_val = tf.reshape(predictions_val, shape=[-1])
    labels_val = tf.reshape(samples_val[common.LABEL], shape=[-1])

    # Set ignore_label regions to label 0, because metrics.mean_iou requires
    # range of labels = [0, dataset.num_classes). Note the ignore_label regions
    # are not evaluated since the corresponding regions contain weights = 0.
    #labels = tf.where(
    #    tf.equal(labels, dataset.ignore_label), tf.zeros_like(labels), labels)
    accuracy_validation = slim.metrics.accuracy(tf.to_int32(predictions_val),
                                                tf.to_int32(labels_val))
    iou, conf_mat = tf.metrics.mean_iou(labels_val, predictions_val, num_classes=6)
    #sess.run(tf.local_variables_initializer())


    def train_step_fn(session, *args, **kwargs):
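        # Wraps slim's train_step so a validation-accuracy check can run every
        # FLAGS.validation_check steps (currently disabled below).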
        total_loss, should_stop = train_step(session, *args, **kwargs)

        if train_step_fn.step % FLAGS.validation_check == 0:
            pass
            # throws OutOfRange error after some time
         #   accuracy = session.run(train_step_fn.accuracy_validation)
          #  print('Step %s - Loss: %.2f Accuracy: %.2f%%' % (
          #  str(train_step_fn.step).rjust(6, '0'), total_loss, accuracy * 100))

     #   if train_step_fn.step == (FLAGS.max_steps - 1):
     #       accuracy = session.run(accuracy_test)
     #       print('%s - Loss: %.2f Accuracy: %.2f%%' % ('FINAL TEST', total_loss, accuracy * 100))

        train_step_fn.step += 1
        return [total_loss, should_stop]

    train_step_fn.step = 0
    train_step_fn.accuracy_validation = accuracy_validation

    # Start the training.
    slim.learning.train(
        train_tensor,
        train_step_fn=train_step_fn,
        logdir=FLAGS.train_logdir,
        log_every_n_steps=FLAGS.log_steps,
        master=FLAGS.master,
        number_of_steps=FLAGS.training_number_of_steps,
        is_chief=(FLAGS.task == 0),
        session_config=session_config,
        startup_delay_steps=startup_delay_steps,
        init_fn=train_utils.get_model_init_fn(
            FLAGS.train_logdir,
            FLAGS.tf_initial_checkpoint,
            FLAGS.initialize_last_layer,
            last_layers,
            ignore_missing_vars=True),
        summary_op=summary_op,
        save_summaries_secs=FLAGS.save_summaries_secs,
        save_interval_secs=FLAGS.save_interval_secs)
def main(unused_argv):
    FLAGS.train_logdir = FLAGS.base_logdir + '/' + FLAGS.task_name
    if FLAGS.restore_name is None:
        FLAGS.restore_logdir = FLAGS.train_logdir
    else:
        FLAGS.restore_logdir = FLAGS.base_logdir + '/' + FLAGS.restore_name

    tf.logging.set_verbosity(tf.logging.INFO)

    # Get logging dir ready.
    if not (os.path.isdir(FLAGS.train_logdir)):
        tf.gfile.MakeDirs(FLAGS.train_logdir)
    elif len(os.listdir(FLAGS.train_logdir)) != 0:
        if not (FLAGS.if_restore):
            if_delete_all = raw_input(
                '#### The log folder %s exists and non-empty; delete all logs? [y/n] '
                % FLAGS.train_logdir)
            if if_delete_all == 'y':
                os.system('rm -rf %s/*' % FLAGS.train_logdir)
                print '==== Log folder emptied.'
        else:
            print '==== Log folder exists; not emptying it because we need to restore from it.'
    tf.logging.info('==== Logging in dir:%s; Training on %s set',
                    FLAGS.train_logdir, FLAGS.train_split)

    # Set up deployment (i.e., multi-GPUs and/or multi-replicas).
    config = model_deploy.DeploymentConfig(
        num_clones=FLAGS.num_clones,
        clone_on_cpu=FLAGS.clone_on_cpu,
        replica_id=FLAGS.task,
        num_replicas=FLAGS.num_replicas,
        num_ps_tasks=FLAGS.num_ps_tasks)  # /device:CPU:0

    # Split the batch across GPUs.
    assert FLAGS.train_batch_size % config.num_clones == 0, (
        'Training batch size not divisible by number of clones (GPUs).')
    clone_batch_size = FLAGS.train_batch_size // config.num_clones

    # Get dataset-dependent information.
    dataset = regression_dataset.get_dataset(FLAGS.dataset,
                                             FLAGS.train_split,
                                             dataset_dir=FLAGS.dataset_dir)
    dataset_val = regression_dataset.get_dataset(FLAGS.dataset,
                                                 FLAGS.val_split,
                                                 dataset_dir=FLAGS.dataset_dir)
    print '#### The data has size:', dataset.num_samples, dataset_val.num_samples

    codes = np.load(
        '/ssd2/public/zhurui/Documents/mesh-voxelization/models/cars_64/codes.npy'
    )

    with tf.Graph().as_default() as graph:
        with tf.device(config.inputs_device()):
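            # Derive per-dimension ranges for the shape codes and discretize the
            # pose/shape ranges into bins.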
            codes_max = np.amax(codes, axis=1).reshape((-1, 1))
            codes_min = np.amin(codes, axis=1).reshape((-1, 1))
            shape_range = np.hstack(
                (codes_max + (codes_max - codes_min) /
                 (dataset.SHAPE_BINS - 1.), codes_min -
                 (codes_max - codes_min) / (dataset.SHAPE_BINS - 1.)))
            bin_range = [
                np.linspace(r[0], r[1], num=b).tolist()
                for r, b in zip(np.vstack((dataset.pose_range,
                                           shape_range)), dataset.bin_nums)
            ]
            # print np.vstack((dataset.pose_range, shape_range))
            # print bin_range[0]
            # print bin_range[-1]
            outputs_to_num_classes = {}
            outputs_to_indices = {}
            for output, bin_num, idx in zip(dataset.output_names,
                                            dataset.bin_nums,
                                            range(len(dataset.output_names))):
                if FLAGS.if_discrete_loss:
                    outputs_to_num_classes[output] = bin_num
                else:
                    outputs_to_num_classes[output] = 1
                outputs_to_indices[output] = idx
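            # One constant tensor of bin values per output, passed to the model builder.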
            bin_vals = [tf.constant(value=[bin_range[i]], dtype=tf.float32, shape=[1, dataset.bin_nums[i]], name=name) \
                    for i, name in enumerate(dataset.output_names)]
            # print outputs_to_num_classes
            # print spaces_to_indices

            samples = input_generator.get(dataset,
                                          codes,
                                          clone_batch_size,
                                          dataset_split=FLAGS.train_split,
                                          is_training=True,
                                          model_variant=FLAGS.model_variant)
            inputs_queue = prefetch_queue.prefetch_queue(samples,
                                                         capacity=128 *
                                                         config.num_clones)

            samples_val = input_generator.get(
                dataset_val,
                codes,
                clone_batch_size,
                dataset_split=FLAGS.val_split,
                is_training=False,
                model_variant=FLAGS.model_variant)
            inputs_queue_val = prefetch_queue.prefetch_queue(samples_val,
                                                             capacity=128)

        # Create the global step on the device storing the variables.
        with tf.device(config.variables_device()):
            global_step = tf.train.get_or_create_global_step()

            # Define the model and create clones.
            model_fn = _build_deeplab
            model_args = (FLAGS, inputs_queue.dequeue(),
                          outputs_to_num_classes, outputs_to_indices, bin_vals,
                          bin_range, dataset, codes, True, False)
            clones = model_deploy.create_clones(config,
                                                model_fn,
                                                args=model_args)

            # Gather update_ops from the first clone. These contain, for example,
            # the updates for the batch_norm variables created by model_fn.
            first_clone_scope = config.clone_scope(0)  # clone_0
            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS,
                                           first_clone_scope)

        with tf.device('/device:GPU:3'):
            if FLAGS.if_val:
                ## Construct the validation graph; takes one GPU.
                _build_deeplab(FLAGS,
                               inputs_queue_val.dequeue(),
                               outputs_to_num_classes,
                               outputs_to_indices,
                               bin_vals,
                               bin_range,
                               dataset_val,
                               codes,
                               is_training=False,
                               reuse=True)

        # Gather initial summaries.
        summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))

        # Add summaries for images, labels, semantic predictions
        summary_loss_dict = {}
        if FLAGS.save_summaries_images:
            if FLAGS.num_clones > 1:
                pattern_train = first_clone_scope + '/%s:0'
            else:
                pattern_train = '%s:0'
            pattern_val = 'val-%s:0'
            pattern = pattern_val if FLAGS.if_val else pattern_train

            gather_list = [0] if FLAGS.num_clones < 3 else [0, 1, 2]
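            # gather_list picks which batch examples are visualized in the image summaries below.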

            summary_mask = graph.get_tensor_by_name(pattern %
                                                    'not_ignore_mask_in_loss')
            summary_mask = tf.reshape(summary_mask,
                                      [-1, dataset.height, dataset.width, 1])
            summary_mask_float = tf.to_float(summary_mask)
            summaries.add(
                tf.summary.image(
                    'gt/%s' % 'not_ignore_mask',
                    tf.gather(tf.cast(summary_mask_float * 255., tf.uint8),
                              gather_list)))

            summary_image = graph.get_tensor_by_name(pattern % common.IMAGE)
            summaries.add(
                tf.summary.image('gt/%s' % common.IMAGE,
                                 tf.gather(summary_image, gather_list)))

            summary_image_name = graph.get_tensor_by_name(pattern %
                                                          common.IMAGE_NAME)
            summaries.add(
                tf.summary.text('gt/%s' % common.IMAGE_NAME,
                                tf.gather(summary_image_name, gather_list)))

            summary_image_name = graph.get_tensor_by_name(pattern_train %
                                                          common.IMAGE_NAME)
            summaries.add(
                tf.summary.text('gt/%s_train' % common.IMAGE_NAME,
                                tf.gather(summary_image_name, gather_list)))

            summary_vis = graph.get_tensor_by_name(pattern % 'vis')
            summaries.add(
                tf.summary.image('gt/%s' % 'vis',
                                 tf.gather(summary_vis, gather_list)))

            def scale_to_255(tensor, pixel_scaling=None):
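                # Rescale a float tensor to uint8 [0, 255] for image summaries; pass
                # pixel_scaling=(offset, scale) to reuse a previously computed scaling.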
                tensor = tf.to_float(tensor)
                if pixel_scaling is None:
                    offset_to_zero = tf.reduce_min(tensor)
                    scale_to_255 = tf.div(
                        255., tf.reduce_max(tensor - offset_to_zero))
                else:
                    offset_to_zero, scale_to_255 = pixel_scaling
                summary_tensor_float = tensor - offset_to_zero
                summary_tensor_float = summary_tensor_float * scale_to_255
                summary_tensor_float = tf.clip_by_value(
                    summary_tensor_float, 0., 255.)
                summary_tensor_uint8 = tf.cast(summary_tensor_float, tf.uint8)
                return summary_tensor_uint8, (offset_to_zero, scale_to_255)

            label_outputs = graph.get_tensor_by_name(pattern %
                                                     'label_pose_shape_map')
            label_id_outputs = graph.get_tensor_by_name(
                pattern % 'pose_shape_label_id_map')
            logit_outputs = graph.get_tensor_by_name(
                pattern % 'scaled_prob_logits_pose_shape_map')

            summary_rot_diffs = graph.get_tensor_by_name(pattern %
                                                         'rot_error_map')
            summary_rot_diffs = tf.where(summary_mask, summary_rot_diffs,
                                         tf.zeros_like(summary_rot_diffs))
            summary_rot_diffs_uint8, _ = scale_to_255(summary_rot_diffs)
            summaries.add(
                tf.summary.image(
                    'metrics_map/%s' % 'rot_diffs',
                    tf.gather(summary_rot_diffs_uint8, gather_list)))

            summary_trans_diffs = graph.get_tensor_by_name(pattern %
                                                           'trans_error_map')
            summary_trans_diffs = tf.where(summary_mask, summary_trans_diffs,
                                           tf.zeros_like(summary_trans_diffs))
            summary_trans_diffs_uint8, _ = scale_to_255(summary_trans_diffs)
            summaries.add(
                tf.summary.image('metrics_map/%s' % 'trans_diffs',
                                 tf.gather(summary_trans_diffs, gather_list)))

            shape_id_outputs = graph.get_tensor_by_name(pattern %
                                                        'shape_id_map')
            shape_id_outputs = tf.where(summary_mask, shape_id_outputs + 1,
                                        tf.zeros_like(shape_id_outputs))
            summary_shape_id_output_uint8, _ = scale_to_255(shape_id_outputs)
            summaries.add(
                tf.summary.image(
                    'shape/shape_id_map',
                    tf.gather(summary_shape_id_output_uint8, gather_list)))

            shape_id_outputs_gt = graph.get_tensor_by_name(pattern %
                                                           'shape_id_map_gt')
            shape_id_outputs_gt = tf.where(summary_mask,
                                           shape_id_outputs_gt + 1,
                                           tf.zeros_like(shape_id_outputs))
            summary_shape_id_output_uint8_gt, _ = scale_to_255(
                shape_id_outputs_gt)
            summaries.add(
                tf.summary.image(
                    'shape/shape_id_map_gt',
                    tf.gather(summary_shape_id_output_uint8_gt, gather_list)))

            if FLAGS.if_summary_metrics:
                shape_id_outputs = graph.get_tensor_by_name(
                    pattern % 'shape_id_map_predict')
                summary_shape_id_output = tf.where(
                    summary_mask, shape_id_outputs,
                    tf.zeros_like(shape_id_outputs))
                summary_shape_id_output_uint8, _ = scale_to_255(
                    summary_shape_id_output)
                summaries.add(
                    tf.summary.image(
                        'shape/shape_id_map_predict',
                        tf.gather(summary_shape_id_output_uint8, gather_list)))

                shape_id_sim_map_train = graph.get_tensor_by_name(
                    pattern_train % 'shape_id_sim_map')
                # shape_id_sim_map_train = tf.where(summary_mask, shape_id_sim_map_train, tf.zeros_like(shape_id_sim_map_train))
                shape_id_sim_map_uint8_train, _ = scale_to_255(
                    shape_id_sim_map_train, pixel_scaling=(0., 255.))
                summaries.add(
                    tf.summary.image(
                        'metrics_map/shape_id_sim_map-trainInv',
                        tf.gather(shape_id_sim_map_uint8_train, gather_list)))

                shape_id_sim_map = graph.get_tensor_by_name(pattern %
                                                            'shape_id_sim_map')
                # shape_id_sim_map = tf.where(summary_mask, shape_id_sim_map, tf.zeros_like(shape_id_sim_map))
                shape_id_sim_map_uint8, _ = scale_to_255(shape_id_sim_map,
                                                         pixel_scaling=(0.,
                                                                        255.))
                summaries.add(
                    tf.summary.image(
                        'metrics_map/shape_id_sim_map-valInv',
                        tf.gather(shape_id_sim_map_uint8, gather_list)))

            for output_idx, output in enumerate(dataset.output_names):
                # # Scale up summary image pixel values for better visualization.
                summary_label_output = tf.gather(label_outputs, [output_idx],
                                                 axis=3)
                summary_label_output = tf.where(
                    summary_mask, summary_label_output,
                    tf.zeros_like(summary_label_output))
                summary_label_output_uint8, pixel_scaling = scale_to_255(
                    summary_label_output)
                summaries.add(
                    tf.summary.image(
                        'output/%s_label' % output,
                        tf.gather(summary_label_output_uint8, gather_list)))

                summary_logit_output = tf.gather(logit_outputs, [output_idx],
                                                 axis=3)
                summary_logit_output = tf.where(
                    summary_mask, summary_logit_output,
                    tf.zeros_like(summary_logit_output))
                summary_logit_output_uint8, _ = scale_to_255(
                    summary_logit_output, pixel_scaling)
                summaries.add(
                    tf.summary.image(
                        'output/%s_logit' % output,
                        tf.gather(summary_logit_output_uint8, gather_list)))

                # summary_label_id_output = tf.to_float(tf.gather(label_id_outputs, [output_idx], axis=3))
                # summary_label_id_output = tf.where(summary_mask, summary_label_id_output+1, tf.zeros_like(summary_label_id_output))
                # summary_label_id_output_uint8, _ = scale_to_255(summary_label_id_output)
                # summary_label_id_output_uint8 = tf.identity(summary_label_id_output_uint8, 'tttt'+output)
                # summaries.add(tf.summary.image(
                #     'test/%s_label_id' % output, tf.gather(summary_label_id_output_uint8, gather_list)))

                summary_diff = tf.abs(
                    tf.to_float(summary_label_output_uint8) -
                    tf.to_float(summary_logit_output_uint8))
                summary_diff = tf.where(summary_mask, summary_diff,
                                        tf.zeros_like(summary_diff))
                summaries.add(
                    tf.summary.image(
                        'diff_map/%s_ldiff' % output,
                        tf.gather(tf.cast(summary_diff, tf.uint8),
                                  gather_list)))

                summary_loss = graph.get_tensor_by_name(
                    (pattern % 'loss_slice_reg_').replace(':0', '') + output +
                    ':0')
                summaries.add(
                    tf.summary.scalar(
                        'slice_loss/' + (pattern % 'reg_').replace(':0', '') +
                        output, summary_loss))

                summary_loss = graph.get_tensor_by_name(
                    (pattern % 'loss_slice_cls_').replace(':0', '') + output +
                    ':0')
                summaries.add(
                    tf.summary.scalar(
                        'slice_loss/' + (pattern % 'cls_').replace(':0', '') +
                        output, summary_loss))

            for pattern in [pattern_train, pattern_val
                            ] if FLAGS.if_val else [pattern_train]:
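                # For the training clones (averaged over clones) and, when validation is
                # enabled, the validation graph, add scalar summaries for the key losses.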
                add_metrics = ['loss_all_shape_id_cls_metric'
                               ] if FLAGS.if_summary_metrics else []
                for loss_name in [
                        'loss_reg_rot_quat_metric', 'loss_reg_rot_quat',
                        'loss_reg_trans_metric', 'loss_reg_trans',
                        'loss_cls_ALL', 'loss_reg_shape'
                ] + add_metrics:
                    if pattern == pattern_val:
                        summary_loss_avg = graph.get_tensor_by_name(pattern %
                                                                    loss_name)
                        # summary_loss_dict['val-'+loss_name] = summary_loss_avg
                    else:
                        summary_loss_avg = train_utils.get_avg_tensor_from_scopes(
                            FLAGS.num_clones, '%s:0', graph, config, loss_name)
                        # summary_loss_dict['train-'+loss_name] = summary_loss_avg
                    summaries.add(
                        tf.summary.scalar(
                            ('total_loss/' + pattern % loss_name).replace(
                                ':0', ''), summary_loss_avg))

        # Build the optimizer based on the device specification.
        with tf.device(config.optimizer_device()):
            learning_rate = train_utils.get_model_learning_rate(
                FLAGS.learning_policy, FLAGS.base_learning_rate,
                FLAGS.learning_rate_decay_step,
                FLAGS.learning_rate_decay_factor,
                FLAGS.training_number_of_steps, FLAGS.learning_power,
                FLAGS.slow_start_step, FLAGS.slow_start_learning_rate)
            optimizer = tf.train.MomentumOptimizer(learning_rate,
                                                   FLAGS.momentum)
            # optimizer = tf.train.AdamOptimizer(learning_rate)
            summaries.add(tf.summary.scalar('learning_rate', learning_rate))

        startup_delay_steps = FLAGS.task * FLAGS.startup_delay_steps

        with tf.device(config.variables_device()):
            total_loss, grads_and_vars = model_deploy.optimize_clones(
                clones, optimizer)
            print '------ total_loss', total_loss, tf.get_collection(
                tf.GraphKeys.LOSSES, first_clone_scope)
            total_loss = tf.check_numerics(total_loss, 'Loss is inf or nan.')
            summaries.add(tf.summary.scalar('total_loss/train', total_loss))

            # Modify the gradients for biases and last layer variables.
            last_layers = model.get_extra_layer_scopes(
                FLAGS.last_layers_contain_logits_only)
            print '////last layers', last_layers

            # Filter trainable variables for last layers ONLY.
            # grads_and_vars = train_utils.filter_gradients(last_layers, grads_and_vars)

            grad_mult = train_utils.get_model_gradient_multipliers(
                last_layers, FLAGS.last_layer_gradient_multiplier)
            if grad_mult:
                grads_and_vars = slim.learning.multiply_gradients(
                    grads_and_vars, grad_mult)

            # Create gradient update op.
            grad_updates = optimizer.apply_gradients(grads_and_vars,
                                                     global_step=global_step)
            update_ops.append(grad_updates)
            update_op = tf.group(*update_ops)
            with tf.control_dependencies([update_op]):
                train_tensor = tf.identity(total_loss, name='train_op')

        # Add the summaries from the first clone. These contain the summaries
        # created by model_fn and either optimize_clones() or _gather_clone_loss().
        summaries |= set(
            tf.get_collection(tf.GraphKeys.SUMMARIES, first_clone_scope))

        # Merge all summaries together.
        summary_op = tf.summary.merge(list(summaries))

        # Soft placement allows placing on CPU ops without GPU implementation.
        session_config = tf.ConfigProto(allow_soft_placement=True,
                                        log_device_placement=False)
        session_config.gpu_options.allow_growth = False

        def train_step_fn(sess, train_op, global_step, train_step_kwargs):
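            # Custom step function: runs one optimization step and, every
            # FLAGS.val_interval_steps steps, prints statistics of a few validation
            # tensors for debugging.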
            train_step_fn.step += 1  # or use global_step.eval(session=sess)

            # calc training losses
            loss, should_stop = slim.learning.train_step(
                sess, train_op, global_step, train_step_kwargs)
            print loss
            # print 'loss: ', loss
            # first_clone_test = graph.get_tensor_by_name(
            #         ('%s/%s:0' % (first_clone_scope, 'shape_map')).strip('/'))
            # test = sess.run(first_clone_test)
            # # print test
            # print 'test: ', test.shape, np.max(test), np.min(test), np.mean(test), test.dtype
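            # Never request a stop from the step function; training length is
            # controlled by number_of_steps.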
            should_stop = 0

            if FLAGS.if_val and train_step_fn.step % FLAGS.val_interval_steps == 0:
                # first_clone_test = graph.get_tensor_by_name('val-loss_all:0')
                # test = sess.run(first_clone_test)
                print '-- Validating...'
                first_clone_test = graph.get_tensor_by_name(
                    ('%s/%s:0' %
                     (first_clone_scope, 'shape_id_map')).strip('/'))
                first_clone_test2 = graph.get_tensor_by_name(
                    ('%s/%s:0' %
                     (first_clone_scope, 'shape_id_sim_map')).strip('/'))
                # 'ttttrow:0')
                first_clone_test3 = graph.get_tensor_by_name((
                    '%s/%s:0' %
                    (first_clone_scope, 'not_ignore_mask_in_loss')).strip('/'))
                # 'ttttrow:0')
                test_out, test_out2, test_out3 = sess.run(
                    [first_clone_test, first_clone_test2, first_clone_test3])
                # test_out = test[:, :, :, 3]
                test_out = test_out[test_out3]
                # test_out2 = test2[:, :, :, 3]
                test_out2 = test_out2[test_out3]
                # print test_out
                print 'shape_id_map: ', test_out.shape, np.max(
                    test_out), np.min(test_out), np.mean(test_out), np.median(
                        test_out), test_out.dtype
                print 'shape_id_sim_map: ', test_out2.shape, np.max(
                    test_out2), np.min(test_out2), np.mean(
                        test_out2), np.median(test_out2), test_out2.dtype
                print 'masks sum: ', test_out3.dtype, np.sum(
                    test_out3.astype(float))
                # assert np.max(test_out) == np.max(test_out2), 'MAtch1!!!'
                # assert np.min(test_out) == np.min(test_out2), 'MAtch2!!!'

            # first_clone_label = graph.get_tensor_by_name(
            #         ('%s/%s:0' % (first_clone_scope, 'pose_map')).strip('/')) # clone_0/val-loss:0
            # # first_clone_pose_dict = graph.get_tensor_by_name(
            # #         ('%s/%s:0' % (first_clone_scope, 'pose_dict')).strip('/'))
            # first_clone_logit = graph.get_tensor_by_name(
            #         ('%s/%s:0' % (first_clone_scope, 'scaled_regression')).strip('/'))
            # not_ignore_mask = graph.get_tensor_by_name(
            #         ('%s/%s:0' % (first_clone_scope, 'not_ignore_mask_in_loss')).strip('/'))
            # label, logits, mask = sess.run([first_clone_label, first_clone_logit, not_ignore_mask])
            # mask = np.reshape(mask, (-1, FLAGS.train_crop_size[0], FLAGS.train_crop_size[1], dataset.num_classes))

            # print '... shapes, types, loss', label.shape, label.dtype, logits.shape, logits.dtype, loss
            # print 'mask', mask.shape, np.mean(mask)
            # logits[mask==0.] = 0.
            # print 'logits', logits.shape, np.max(logits), np.min(logits), np.mean(logits), logits.dtype
            # for idx in range(6):
            #     print idx, np.max(label[:, :, :, idx]), np.min(label[:, :, :, idx])
            # label = label[:, :, :, 5]
            # print 'label', label.shape, np.max(label), np.min(label), np.mean(label), label.dtype
            # print pose_dict, pose_dict.shape
            # # print 'training....... logits stats: ', np.max(logits), np.min(logits), np.mean(logits)
            # # label_one_piece = label[0, :, :, 0]
            # # print 'training....... label stats', np.max(label_one_piece), np.min(label_one_piece), np.sum(label_one_piece[label_one_piece!=255.])
            return [loss, should_stop]

        train_step_fn.step = 0

        # trainables = [v.name for v in tf.trainable_variables()]
        # alls =[v.name for v in tf.all_variables()]
        # print '----- Trainables %d: '%len(trainables), trainables
        # print '----- All %d: '%len(alls), alls[:10]
        # print '===== ', len(list(set(trainables) - set(alls)))
        # print '===== ', len(list(set(alls) - set(trainables)))

        if FLAGS.if_print_tensors:
            for op in tf.get_default_graph().get_operations():
                print str(op.name)

        # Start the training.
        slim.learning.train(train_tensor,
                            train_step_fn=train_step_fn,
                            logdir=FLAGS.train_logdir,
                            log_every_n_steps=FLAGS.log_steps,
                            master=FLAGS.master,
                            number_of_steps=FLAGS.training_number_of_steps,
                            is_chief=(FLAGS.task == 0),
                            session_config=session_config,
                            startup_delay_steps=startup_delay_steps,
                            init_fn=train_utils.get_model_init_fn(
                                FLAGS.restore_logdir,
                                FLAGS.tf_initial_checkpoint,
                                FLAGS.if_restore,
                                FLAGS.initialize_last_layer,
                                last_layers,
                                ignore_missing_vars=True),
                            summary_op=summary_op,
                            save_summaries_secs=FLAGS.save_summaries_secs,
                            save_interval_secs=FLAGS.save_interval_secs)
Beispiel #13
0
def main(unused_argv):
  tf.logging.set_verbosity(tf.logging.INFO)

  tf.gfile.MakeDirs(FLAGS.train_logdir)
  tf.logging.info('Training on %s set', FLAGS.train_split)

  graph = tf.Graph()
  with graph.as_default():
    with tf.device(tf.train.replica_device_setter(ps_tasks=FLAGS.num_ps_tasks)):
      assert FLAGS.train_batch_size % FLAGS.num_clones == 0, (
          'Training batch size not divisible by number of clones (GPUs).')
      clone_batch_size = FLAGS.train_batch_size // FLAGS.num_clones

      dataset = data_generator.Dataset(
          dataset_name=FLAGS.dataset,
          split_name=FLAGS.train_split,
          dataset_dir=FLAGS.dataset_dir,
          batch_size=clone_batch_size,
          crop_size=FLAGS.train_crop_size,
          min_resize_value=FLAGS.min_resize_value,
          max_resize_value=FLAGS.max_resize_value,
          resize_factor=FLAGS.resize_factor,
          min_scale_factor=FLAGS.min_scale_factor,
          max_scale_factor=FLAGS.max_scale_factor,
          scale_factor_step_size=FLAGS.scale_factor_step_size,
          model_variant=FLAGS.model_variant,
          num_readers=2,
          is_training=True,
          should_shuffle=True,
          should_repeat=True)

      train_tensor, summary_op = _train_deeplab_model(
          dataset.get_one_shot_iterator(), dataset.num_of_classes,
          dataset.ignore_label)

      # Soft placement allows placing on CPU ops without GPU implementation.
      session_config = tf.ConfigProto(
          allow_soft_placement=True, log_device_placement=False)

      last_layers = model.get_extra_layer_scopes(
          FLAGS.last_layers_contain_logits_only)
      init_fn = None
      if FLAGS.tf_initial_checkpoint:
        init_fn = train_utils.get_model_init_fn(
            FLAGS.train_logdir,
            FLAGS.tf_initial_checkpoint,
            FLAGS.initialize_last_layer,
            last_layers,
            ignore_missing_vars=True)

      scaffold = tf.train.Scaffold(
          init_fn=init_fn,
          summary_op=summary_op,
      )

      stop_hook = tf.train.StopAtStepHook(FLAGS.training_number_of_steps)

      profile_dir = FLAGS.profile_logdir
      if profile_dir is not None:
        tf.gfile.MakeDirs(profile_dir)

      with tf.contrib.tfprof.ProfileContext(
          enabled=profile_dir is not None, profile_dir=profile_dir):
        with tf.train.MonitoredTrainingSession(
            master=FLAGS.master,
            is_chief=(FLAGS.task == 0),
            config=session_config,
            scaffold=scaffold,
            checkpoint_dir=FLAGS.train_logdir,
            summary_dir=FLAGS.train_logdir,
            log_step_count_steps=FLAGS.log_steps,
            save_summaries_steps=FLAGS.save_summaries_secs,
            save_checkpoint_secs=FLAGS.save_interval_secs,
            hooks=[stop_hook]) as sess:
          while not sess.should_stop():
            sess.run([train_tensor])
def main(unused_argv):
  tf.logging.set_verbosity(tf.logging.INFO)
  # Set up deployment (i.e., multi-GPUs and/or multi-replicas).
  config = model_deploy.DeploymentConfig(
      num_clones=FLAGS.num_clones,
      clone_on_cpu=FLAGS.clone_on_cpu,
      replica_id=FLAGS.task,
      num_replicas=FLAGS.num_replicas,
      num_ps_tasks=FLAGS.num_ps_tasks)

  # Split the batch across GPUs.
  assert FLAGS.train_batch_size % config.num_clones == 0, (
      'Training batch size not divisible by number of clones (GPUs).')

  clone_batch_size = FLAGS.train_batch_size // config.num_clones  # integer division keeps the per-clone batch size an int

  # Get dataset-dependent information.
  dataset = segmentation_dataset.get_dataset(
      FLAGS.dataset, FLAGS.train_split, dataset_dir=FLAGS.dataset_dir)

  tf.gfile.MakeDirs(FLAGS.train_logdir)
  tf.logging.info('Training on %s set', FLAGS.train_split)

  with tf.Graph().as_default():
    with tf.device(config.inputs_device()):
      samples = input_generator.get(
          dataset,
          FLAGS.train_crop_size,
          clone_batch_size,
          min_resize_value=FLAGS.min_resize_value,
          max_resize_value=FLAGS.max_resize_value,
          resize_factor=FLAGS.resize_factor,
          min_scale_factor=FLAGS.min_scale_factor,
          max_scale_factor=FLAGS.max_scale_factor,
          scale_factor_step_size=FLAGS.scale_factor_step_size,
          dataset_split=FLAGS.train_split,
          is_training=True,
          model_variant=FLAGS.model_variant)
      inputs_queue = prefetch_queue.prefetch_queue(
          samples, capacity=128 * config.num_clones)

    # Create the global step on the device storing the variables.
    with tf.device(config.variables_device()):
      global_step = tf.train.get_or_create_global_step()

      # Define the model and create clones.
      model_fn = _build_deeplab
      model_args = (inputs_queue, {
          common.OUTPUT_TYPE: dataset.num_classes
      }, dataset.ignore_label)
      clones = model_deploy.create_clones(config, model_fn, args=model_args)

      # Gather update_ops from the first clone. These contain, for example,
      # the updates for the batch_norm variables created by model_fn.
      first_clone_scope = config.clone_scope(0)
      update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, first_clone_scope)

    # Gather initial summaries.
    summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))

    # Add summaries for model variables.
    for model_var in slim.get_model_variables():
      summaries.add(tf.summary.histogram(model_var.op.name, model_var))

    # Add summaries for losses.
    for loss in tf.get_collection(tf.GraphKeys.LOSSES, first_clone_scope):
      summaries.add(tf.summary.scalar('losses/%s' % loss.op.name, loss))

    # Build the optimizer based on the device specification.
    with tf.device(config.optimizer_device()):
      learning_rate = train_utils.get_model_learning_rate(
          FLAGS.learning_policy, FLAGS.base_learning_rate,
          FLAGS.learning_rate_decay_step, FLAGS.learning_rate_decay_factor,
          FLAGS.training_number_of_steps, FLAGS.learning_power,
          FLAGS.slow_start_step, FLAGS.slow_start_learning_rate)
      optimizer = tf.train.MomentumOptimizer(learning_rate, FLAGS.momentum)
      summaries.add(tf.summary.scalar('learning_rate', learning_rate))

    startup_delay_steps = FLAGS.task * FLAGS.startup_delay_steps

    with tf.device(config.variables_device()):
      total_loss, grads_and_vars = model_deploy.optimize_clones(
          clones, optimizer)
      total_loss = tf.check_numerics(total_loss, 'Loss is inf or nan.')
      summaries.add(tf.summary.scalar('total_loss', total_loss))

      # Modify the gradients for biases and last layer variables.
      last_layers = model.get_extra_layer_scopes()
      grad_mult = train_utils.get_model_gradient_multipliers(
          last_layers, FLAGS.last_layer_gradient_multiplier)
      if grad_mult:
        grads_and_vars = slim.learning.multiply_gradients(
            grads_and_vars, grad_mult)

      # Create gradient update op.
      grad_updates = optimizer.apply_gradients(
          grads_and_vars, global_step=global_step)
      update_ops.append(grad_updates)
      update_op = tf.group(*update_ops)
      with tf.control_dependencies([update_op]):
        train_tensor = tf.identity(total_loss, name='train_op')

    # Add the summaries from the first clone. These contain the summaries
    # created by model_fn and either optimize_clones() or _gather_clone_loss().
    summaries |= set(
        tf.get_collection(tf.GraphKeys.SUMMARIES, first_clone_scope))

    # Merge all summaries together.
    summary_op = tf.summary.merge(list(summaries))

    # Soft placement allows placing on CPU ops without GPU implementation.
    session_config = tf.ConfigProto(
        allow_soft_placement=True, log_device_placement=False)
    session_config.gpu_options.allow_growth = FLAGS.gpu_allow_growth

    # Save checkpoints regularly.
    saver = tf.train.Saver(max_to_keep=FLAGS.max_to_keep)

    # Start the training.
    slim.learning.train(
        train_tensor,
        logdir=FLAGS.train_logdir,
        log_every_n_steps=FLAGS.log_steps,
        master=FLAGS.master,
        number_of_steps=FLAGS.training_number_of_steps,
        is_chief=(FLAGS.task == 0),
        session_config=session_config,
        startup_delay_steps=startup_delay_steps,
        init_fn=train_utils.get_model_init_fn(
            FLAGS.train_logdir,
            FLAGS.tf_initial_checkpoint,
            FLAGS.initialize_last_layer,
            last_layers,
            ignore_missing_vars=True),
        summary_op=summary_op,
        saver=saver,
        save_summaries_secs=FLAGS.save_summaries_secs,
        save_interval_secs=FLAGS.save_interval_secs)
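
The gradient-multiplier block above (train_utils.get_model_gradient_multipliers followed by slim.learning.multiply_gradients) trains the freshly initialized last layers with a larger effective learning rate than the pretrained backbone. Below is a minimal sketch of that idea using plain compute_gradients/apply_gradients; the 'logits' scope name and the 10x factor are illustrative assumptions, not DeepLab's actual settings.

# Sketch of the "boost the last-layer gradients" idea (illustration only).
def scale_last_layer_gradients(grads_and_vars, last_layer_scopes,
                               multiplier=10.0):
    scaled = []
    for grad, var in grads_and_vars:
        if grad is not None and any(scope in var.op.name
                                    for scope in last_layer_scopes):
            grad = grad * multiplier
        scaled.append((grad, var))
    return scaled

# Hypothetical wiring with a plain optimizer:
#   grads_and_vars = optimizer.compute_gradients(total_loss)
#   grads_and_vars = scale_last_layer_gradients(grads_and_vars, ['logits'])
#   train_op = optimizer.apply_gradients(grads_and_vars, global_step=global_step)
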
Example #15
0
def main(unused_argv):
    tf.logging.set_verbosity(tf.logging.INFO)

    tf.gfile.MakeDirs(FLAGS.train_logdir)
    tf.logging.info('Training on %s set', FLAGS.train_split)

    graph = tf.Graph()
    with graph.as_default():
        with tf.device(
                tf.train.replica_device_setter(ps_tasks=FLAGS.num_ps_tasks)):
            assert FLAGS.train_batch_size % FLAGS.num_clones == 0, (
                'Training batch size not divisible by number of clones (GPUs).')
            clone_batch_size = FLAGS.train_batch_size // FLAGS.num_clones

            dataset = data_generator.Dataset(
                dataset_name=FLAGS.dataset,
                split_name=FLAGS.train_split,
                dataset_dir=FLAGS.dataset_dir,
                batch_size=clone_batch_size,
                crop_size=FLAGS.train_crop_size,
                min_resize_value=FLAGS.min_resize_value,
                max_resize_value=FLAGS.max_resize_value,
                resize_factor=FLAGS.resize_factor,
                min_scale_factor=FLAGS.min_scale_factor,
                max_scale_factor=FLAGS.max_scale_factor,
                scale_factor_step_size=FLAGS.scale_factor_step_size,
                model_variant=FLAGS.model_variant,
                num_readers=2,
                is_training=True,
                should_shuffle=True,
                should_repeat=True)

            # Separate validation split: not shuffled and not repeated, so the
            # loop below can make exactly one pass over it per epoch.
            vdataset = data_generator.Dataset(
                dataset_name=FLAGS.dataset,
                split_name=FLAGS.trainval_split,
                dataset_dir=FLAGS.dataset_dir,
                batch_size=FLAGS.trainval_batch_size,
                crop_size=FLAGS.train_crop_size,
                min_resize_value=FLAGS.min_resize_value,
                max_resize_value=FLAGS.max_resize_value,
                resize_factor=FLAGS.resize_factor,
                min_scale_factor=FLAGS.min_scale_factor,
                max_scale_factor=FLAGS.max_scale_factor,
                scale_factor_step_size=FLAGS.scale_factor_step_size,
                model_variant=FLAGS.model_variant,
                num_readers=2,
                is_training=True,
                should_shuffle=False,
                should_repeat=False)

            viterator = vdataset.get_initializable_iterator()
            next_element = viterator.get_next()

            # Placeholders used to feed validation batches into the
            # validation-loss graph built by _val_loss below.
            val_image = tf.placeholder(tf.float32,
                                       shape=(None, FLAGS.train_crop_size[0],
                                              FLAGS.train_crop_size[1], 3))
            val_label = tf.placeholder(tf.int32,
                                       shape=(None, FLAGS.train_crop_size[0],
                                              FLAGS.train_crop_size[1], 1))

            train_tensor, summary_op = _train_deeplab_model(
                dataset.get_one_shot_iterator(), dataset.num_of_classes,
                dataset.ignore_label)

            val_tensor = _val_loss(dataset=vdataset,
                                   image=val_image,
                                   label=val_label,
                                   num_of_classes=vdataset.num_of_classes,
                                   ignore_label=vdataset.ignore_label)

            # Soft placement allows placing on CPU ops without GPU implementation.
            session_config = tf.ConfigProto(allow_soft_placement=True,
                                            log_device_placement=False)

            last_layers = model.get_extra_layer_scopes(
                FLAGS.last_layers_contain_logits_only)
            init_fn = None
            if FLAGS.tf_initial_checkpoint:
                init_fn = train_utils.get_model_init_fn(
                    FLAGS.train_logdir,
                    FLAGS.tf_initial_checkpoint,
                    FLAGS.initialize_last_layer,
                    last_layers,
                    ignore_missing_vars=True)

            scaffold = tf.train.Scaffold(
                init_fn=init_fn,
                summary_op=summary_op,
            )

            stop_hook = tf.train.StopAtStepHook(FLAGS.training_number_of_steps)

            # Validation set variables
            epoch = 0
            val_loss_per_epoch = []
            steps_per_epoch = int(dataset.num_samples / FLAGS.train_batch_size)
            saver = tf.train.Saver(max_to_keep=1)

            profile_dir = FLAGS.profile_logdir
            if profile_dir is not None:
                tf.gfile.MakeDirs(profile_dir)

            with tf.contrib.tfprof.ProfileContext(
                    enabled=profile_dir is not None,
                    profile_dir=profile_dir):
                with tf.train.MonitoredTrainingSession(
                        master=FLAGS.master,
                        is_chief=(FLAGS.task == 0),
                        config=session_config,
                        scaffold=scaffold,
                        checkpoint_dir=FLAGS.train_logdir,
                        log_step_count_steps=FLAGS.log_steps,
                        save_summaries_steps=FLAGS.save_summaries_secs,
                        save_checkpoint_secs=FLAGS.save_interval_secs,
                        # stop_hook is not passed here; training ends via the
                        # early-stopping break in the loop below.
                        hooks=[]) as sess:
                    while not sess.should_stop():
                        step = sess.run(tf.train.get_global_step())
                        sess.run([train_tensor])
                        # Run a full pass over the validation split at every
                        # epoch boundary.
                        if step % steps_per_epoch == 0:
                            count_validation = 0
                            stop_training = False
                            val_losses = []
                            sess.run(viterator.initializer)
                            while True:
                                try:
                                    val_element = sess.run(next_element)
                                    val_loss, val_summary = sess.run(
                                        val_tensor,
                                        feed_dict={
                                            val_image:
                                            val_element[common.IMAGE],
                                            val_label:
                                            val_element[common.LABEL]
                                        })
                                    val_losses.append(val_loss)
                                    count_validation += 1
                                    #print('  {} [validation] {} {}'.format(count_validation, val_loss, val_element[common.IMAGE_NAME]))
                                except tf.errors.OutOfRangeError:
                                    total_val_loss = sum(val_losses) / len(
                                        val_losses)
                                    val_loss_per_epoch.append(total_val_loss)
                                    print('  {} [validation loss] {}'.format(
                                        count_validation *
                                        FLAGS.train_batch_size,
                                        total_val_loss))
                                    print('  {} [current epoch]   {}'.format(
                                        step, epoch))
                                    break
                            if epoch > 0:
                                min_delta = 0.01
                                patience = 8
                                stop_training = early_stopping(
                                    epoch, val_loss_per_epoch, min_delta,
                                    patience, sess, saver, total_val_loss)
                            # Stop training once early_stopping reports that
                            # the validation loss is no longer improving.
                            if stop_training:
                                break
                            epoch += 1
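
The loop above depends on an early_stopping helper that is not part of this snippet. The sketch below is one plausible implementation, assuming the helper checkpoints the best model seen so far and returns True once the validation loss has gone patience epochs without coming within min_delta of the best value; the checkpoint path and the exact bookkeeping are assumptions, not the author's code.

# Hypothetical sketch of the early_stopping helper called above; its real
# semantics are not shown in the snippet, so they are assumed here.
def early_stopping(epoch, val_loss_per_epoch, min_delta, patience,
                   sess, saver, current_val_loss):
    # val_loss_per_epoch already contains current_val_loss as its last entry;
    # the epoch argument is not needed in this sketch.
    best_loss = min(val_loss_per_epoch)
    if current_val_loss <= best_loss:
        # New best validation loss: checkpoint it (path is illustrative).
        # Saving through a MonitoredSession also runs its hooks; unwrap to the
        # raw session first if that is undesirable.
        saver.save(sess, FLAGS.train_logdir + '/best_model.ckpt')
    # Count trailing epochs whose loss stayed more than min_delta above the
    # best loss observed so far; stop once there are `patience` of them.
    bad_epochs = 0
    for loss in reversed(val_loss_per_epoch):
        if loss > best_loss + min_delta:
            bad_epochs += 1
        else:
            break
    return bad_epochs >= patience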