示例#1
0
def main(_):
    """Train a model on the training split, then evaluate it on the test split.

    Reads batch size, steps per epoch and epoch count from ``FLAGS``, builds
    one-shot iterators for ``FLAGS.train_split`` and ``FLAGS.test_split``,
    builds the model, trains it, runs evaluation/prediction and finally
    prints the model summary.

    Args:
        _: Ignored.
    """
    batch_size = FLAGS.batch_size
    steps = FLAGS.training_number_of_steps
    epochs = FLAGS.epoch

    def _build_dataset(split_name, is_training):
        # Both splits share every setting except the split name and the
        # training flag (note: the test split also uses the train crop size,
        # shuffling and repetition).
        return data_generator.Dataset(
            dataset_name=FLAGS.dataset,
            split_name=split_name,
            dataset_dir=FLAGS.dataset_dir,
            batch_size=batch_size,
            crop_size=FLAGS.train_crop_size,
            min_scale_factor=FLAGS.min_scale_factor,
            max_scale_factor=FLAGS.max_scale_factor,
            num_readers=2,
            is_training=is_training,
            should_shuffle=True,
            should_repeat=True)

    train_data = _build_dataset(FLAGS.train_split, is_training=True)
    train_iterator = train_data.get_one_shot_iterator()

    test_data = _build_dataset(FLAGS.test_split, is_training=False)
    test_iterator = test_data.get_one_shot_iterator()

    # build_model() also returns the epoch at which training should start.
    network, start_epoch = build_model()

    train(network,
          train_iterator,
          steps,
          epochs,
          initial_epoch=start_epoch)

    eval_and_predict(network, test_iterator)

    network.summary()
示例#2
0
def main(_):
    """Run ``my_train`` over the training split with hard-coded settings.

    The batch size (8), steps per epoch (600 * 4) and epoch count (1) are
    fixed here rather than read from ``FLAGS``.  The dataset is neither
    shuffled nor repeated.

    Args:
        _: Ignored.
    """
    batch_size = 8
    steps_per_epoch = 600 * 4
    epochs = 1

    dataset_options = dict(
        dataset_name=FLAGS.dataset,
        split_name=FLAGS.train_split,
        dataset_dir=FLAGS.dataset_dir,
        batch_size=batch_size,
        crop_size=FLAGS.train_crop_size,
        min_scale_factor=FLAGS.min_scale_factor,
        max_scale_factor=FLAGS.max_scale_factor,
        num_readers=2,
        is_training=True,
        should_shuffle=False,
        should_repeat=False,
    )
    training_data = data_generator.Dataset(**dataset_options)
    batch_iterator = training_data.get_one_shot_iterator()

    my_train(epochs,
             steps_per_epoch,
             batch_handler=handle_batch,
             iterator=batch_iterator)
  def testPascalVocSegTestData(self):
    """Checks metadata and per-image attributes of the PASCAL VOC test data.

    Reads the 'val' split of the bundled 'pascal_voc_seg' test data one
    image at a time, compares height, width and image name against
    `_get_attributes_of_image`, and verifies the iterator is exhausted
    after three images.
    """
    voc_val = data_generator.Dataset(
        dataset_name='pascal_voc_seg',
        split_name='val',
        dataset_dir='deeplab/testing/pascal_voc_seg',
        batch_size=1,
        crop_size=[3, 3],  # Tiny crop keeps the test cheap.
        min_resize_value=3,
        max_resize_value=3,
        resize_factor=None,
        min_scale_factor=0.01,
        max_scale_factor=2.0,
        scale_factor_step_size=0.25,
        is_training=False,
        model_variant='mobilenet_v2')

    self.assertAllEqual(voc_val.num_of_classes, 21)
    self.assertAllEqual(voc_val.ignore_label, 255)

    expected_image_count = 3
    with self.test_session() as sess:
      iterator = voc_val.get_one_shot_iterator()

      for image_index in range(expected_image_count):
        (sample,) = sess.run([iterator.get_next()])
        expected = _get_attributes_of_image(image_index)
        # batch_size is 1, so element 0 is the only entry of each field.
        self.assertEqual(sample[common.HEIGHT][0], expected.height)
        self.assertEqual(sample[common.WIDTH][0], expected.width)
        self.assertEqual(sample[common.IMAGE_NAME][0],
                         expected.image_name.encode())

      # One more fetch must fail: every image has already been consumed.
      with self.assertRaisesRegexp(tf.errors.OutOfRangeError, ''):
        sess.run([iterator.get_next()])
def main(unused_argv):
  """Visualizes semantic predictions for each checkpoint that shows up.

  Builds the input pipeline for `FLAGS.vis_split`, creates the prediction
  graph (single- or multi-scale depending on `FLAGS.eval_scales`), and then,
  for every checkpoint yielded by the checkpoints iterator over
  `FLAGS.checkpoint_dir`, runs `_process_batch` until the session reports it
  should stop.  Stops after `FLAGS.max_number_of_iterations` checkpoints when
  that flag is positive.

  Args:
    unused_argv: Ignored.

  Raises:
    ValueError: If multi-scale inference is requested while
      `FLAGS.quantize_delay_step >= 0`.
  """
  tf.logging.set_verbosity(tf.logging.INFO)

  def _utc_stamp():
    # Timestamp used in the start/finish log lines (UTC).
    return time.strftime('%Y-%m-%d-%H:%M:%S', time.gmtime())

  # The dataset object also carries dataset-dependent metadata used below.
  vis_dataset = data_generator.Dataset(
      dataset_name=FLAGS.dataset,
      split_name=FLAGS.vis_split,
      dataset_dir=FLAGS.dataset_dir,
      batch_size=FLAGS.vis_batch_size,
      crop_size=list(map(int, FLAGS.vis_crop_size)),
      min_resize_value=FLAGS.min_resize_value,
      max_resize_value=FLAGS.max_resize_value,
      resize_factor=FLAGS.resize_factor,
      model_variant=FLAGS.model_variant,
      is_training=False,
      should_shuffle=False,
      should_repeat=False)

  if vis_dataset.dataset_name == data_generator.get_cityscapes_dataset_name():
    tf.logging.info('Cityscapes requires converting train_id to eval_id.')
    id_mapping = _CITYSCAPES_TRAIN_ID_TO_EVAL_ID
  else:
    id_mapping = None

  # Create the log directory and its two output sub-folders.
  tf.gfile.MakeDirs(FLAGS.vis_logdir)
  output_dirs = []
  for folder in (_SEMANTIC_PREDICTION_SAVE_FOLDER,
                 _RAW_SEMANTIC_PREDICTION_SAVE_FOLDER):
    folder_path = os.path.join(FLAGS.vis_logdir, folder)
    tf.gfile.MakeDirs(folder_path)
    output_dirs.append(folder_path)
  color_dir, raw_dir = output_dirs

  tf.logging.info('Visualizing on %s set', FLAGS.vis_split)

  with tf.Graph().as_default():
    batch_tensors = vis_dataset.get_one_shot_iterator().get_next()

    model_options = common.ModelOptions(
        outputs_to_num_classes={
            common.OUTPUT_TYPE: vis_dataset.num_of_classes},
        crop_size=list(map(int, FLAGS.vis_crop_size)),
        atrous_rates=FLAGS.atrous_rates,
        output_stride=FLAGS.output_stride)

    multi_scale = tuple(FLAGS.eval_scales) != (1.0,)
    if multi_scale:
      tf.logging.info('Performing multi-scale test.')
      if FLAGS.quantize_delay_step >= 0:
        raise ValueError(
            'Quantize mode is not supported with multi-scale test.')
      outputs = model.predict_labels_multi_scale(
          batch_tensors[common.IMAGE],
          model_options=model_options,
          eval_scales=FLAGS.eval_scales,
          add_flipped_images=FLAGS.add_flipped_images)
    else:
      tf.logging.info('Performing single-scale test.')
      outputs = model.predict_labels(
          batch_tensors[common.IMAGE],
          model_options=model_options,
          image_pyramid=FLAGS.image_pyramid)
    label_map = outputs[common.OUTPUT_TYPE]

    if FLAGS.min_resize_value and FLAGS.max_resize_value:
      # The squeeze below assumes a single image per batch, i.e. an original
      # image of shape [height, width, 3] once the batch axis is dropped.
      assert FLAGS.vis_batch_size == 1

      # Undo the preprocessing: first crop away the padded region, then
      # resize the prediction back to the recorded image height/width.
      source_shape = tf.shape(
          tf.squeeze(batch_tensors[common.ORIGINAL_IMAGE]))
      valid_height = source_shape[0]
      valid_width = source_shape[1]
      label_map = tf.slice(
          label_map, [0, 0, 0], [1, valid_height, valid_width])

      target_height = tf.squeeze(batch_tensors[common.HEIGHT])
      target_width = tf.squeeze(batch_tensors[common.WIDTH])
      target_size = tf.to_int32([target_height, target_width])

      with_channel = tf.expand_dims(label_map, 3)
      resized = tf.image.resize_images(
          with_channel,
          target_size,
          method=tf.image.ResizeMethod.NEAREST_NEIGHBOR,
          align_corners=True)
      label_map = tf.squeeze(resized, 3)

    tf.train.get_or_create_global_step()
    if FLAGS.quantize_delay_step >= 0:
      contrib_quantize.create_eval_graph()

    max_rounds = FLAGS.max_number_of_iterations
    checkpoint_stream = contrib_training.checkpoints_iterator(
        FLAGS.checkpoint_dir, min_interval_secs=FLAGS.eval_interval_secs)

    for round_index, checkpoint_path in enumerate(checkpoint_stream, 1):
      tf.logging.info('Starting visualization at ' + _utc_stamp())
      tf.logging.info('Visualizing with model %s', checkpoint_path)

      session_creator = tf.train.ChiefSessionCreator(
          scaffold=tf.train.Scaffold(
              init_op=tf.global_variables_initializer()),
          master=FLAGS.master,
          checkpoint_filename_with_path=checkpoint_path)

      with tf.train.MonitoredSession(
          session_creator=session_creator, hooks=None) as sess:
        batches_done = 0
        next_image_id = 0

        while not sess.should_stop():
          tf.logging.info('Visualizing batch %d', batches_done + 1)
          _process_batch(sess=sess,
                         original_images=batch_tensors[common.ORIGINAL_IMAGE],
                         semantic_predictions=label_map,
                         image_names=batch_tensors[common.IMAGE_NAME],
                         image_heights=batch_tensors[common.HEIGHT],
                         image_widths=batch_tensors[common.WIDTH],
                         image_id_offset=next_image_id,
                         save_dir=color_dir,
                         raw_save_dir=raw_dir,
                         train_id_to_eval_id=id_mapping)
          next_image_id += FLAGS.vis_batch_size
          batches_done += 1

      tf.logging.info('Finished visualization at ' + _utc_stamp())

      # A non-positive limit means "keep following new checkpoints".
      if 0 < max_rounds <= round_index:
        break
def main(unused_argv):
    """Repeatedly evaluate checkpoints, logging mean and per-class IoU.

    Builds the input pipeline for ``FLAGS.eval_split``, creates the
    prediction graph (single- or multi-scale depending on
    ``FLAGS.eval_scales``), defines the mean-IoU and per-class IoU metrics,
    and hands their update ops to
    ``tf.contrib.training.evaluate_repeatedly``.  Scalar summaries go to
    ``FLAGS.eval_logdir``; a confusion-matrix hook is attached as well.

    Args:
        unused_argv: Ignored.

    Raises:
        ValueError: If multi-scale evaluation is requested while
            ``FLAGS.quantize_delay_step >= 0``.
    """
    tf.logging.set_verbosity(tf.logging.INFO)

    dataset = data_generator.Dataset(
        dataset_name=FLAGS.dataset,
        split_name=FLAGS.eval_split,
        dataset_dir=FLAGS.dataset_dir,
        batch_size=FLAGS.eval_batch_size,
        crop_size=[int(sz) for sz in FLAGS.eval_crop_size],
        min_resize_value=FLAGS.min_resize_value,
        max_resize_value=FLAGS.max_resize_value,
        resize_factor=FLAGS.resize_factor,
        model_variant=FLAGS.model_variant,
        num_readers=2,
        is_training=False,
        should_shuffle=False,
        should_repeat=False)

    tf.gfile.MakeDirs(FLAGS.eval_logdir)
    tf.logging.info('Evaluating on %s set', FLAGS.eval_split)

    # Session configuration for the evaluation sessions.  It is passed to
    # evaluate_repeatedly() below; previously it was built but never used.
    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=1.0)
    session_config = tf.ConfigProto(allow_soft_placement=True,
                                    log_device_placement=False,
                                    gpu_options=gpu_options)

    with tf.Graph().as_default():
        samples = dataset.get_one_shot_iterator().get_next()

        model_options = common.ModelOptions(
            outputs_to_num_classes={
                common.OUTPUT_TYPE: dataset.num_of_classes
            },
            crop_size=[int(sz) for sz in FLAGS.eval_crop_size],
            atrous_rates=FLAGS.atrous_rates,
            output_stride=FLAGS.output_stride)

        # Set shape in order for tf.contrib.tfprof.model_analyzer to work
        # properly.
        samples[common.IMAGE].set_shape([
            FLAGS.eval_batch_size,
            int(FLAGS.eval_crop_size[0]),
            int(FLAGS.eval_crop_size[1]), 3
        ])

        if tuple(FLAGS.eval_scales) == (1.0, ):
            tf.logging.info('Performing single-scale test.')
            # NOTE(review): this project's predict_labels returns a pair;
            # the logits are not used here.
            predictions, logits = model.predict_labels(
                samples[common.IMAGE],
                model_options,
                image_pyramid=FLAGS.image_pyramid,
                skips=FLAGS.skips)
        else:
            tf.logging.info('Performing multi-scale test.')
            if FLAGS.quantize_delay_step >= 0:
                raise ValueError(
                    'Quantize mode is not supported with multi-scale test.')

            predictions = model.predict_labels_multi_scale(
                samples[common.IMAGE],
                model_options=model_options,
                skips=FLAGS.skips,
                eval_scales=FLAGS.eval_scales,
                add_flipped_images=FLAGS.add_flipped_images)

        # Flatten predictions and labels to 1-D so they line up per pixel.
        predictions = predictions[common.OUTPUT_TYPE]
        predictions = tf.reshape(predictions, shape=[-1])
        labels = tf.reshape(samples[common.LABEL], shape=[-1])
        weights = tf.to_float(tf.not_equal(labels, dataset.ignore_label))

        # Set ignore_label regions to label 0, because metrics.mean_iou
        # requires range of labels = [0, dataset.num_classes). Note the
        # ignore_label regions are not evaluated since the corresponding
        # regions contain weights = 0.
        labels = tf.where(tf.equal(labels, dataset.ignore_label),
                          tf.zeros_like(labels), labels)

        predictions_tag = 'miou'
        for eval_scale in FLAGS.eval_scales:
            predictions_tag += '_' + str(eval_scale)
        if FLAGS.add_flipped_images:
            predictions_tag += '_flipped'

        # NOTE(review): metric_map is filled below but not consumed in this
        # function; it is kept for parity with the previous graph.
        metric_map = {}

        # Keep only the pixels whose label is a valid class id.  The weights
        # must be filtered with the same indices; otherwise their length no
        # longer matches labels/predictions as soon as one pixel is dropped.
        indices = tf.squeeze(
            tf.where(tf.less_equal(labels, dataset.num_of_classes - 1)), 1)
        labels_ind = tf.cast(tf.gather(labels, indices), tf.int32)
        predictions_ind = tf.gather(predictions, indices)
        weights_ind = tf.gather(weights, indices)

        miou, update_miou = tf.metrics.mean_iou(labels_ind,
                                                predictions_ind,
                                                dataset.num_of_classes,
                                                weights=weights_ind,
                                                name="mean_iou")
        tf.summary.scalar(predictions_tag, miou)

        # IoU for the individual classes.
        iou_v, update_op = my_metrics.iou(labels_ind,
                                          predictions_ind,
                                          dataset.num_of_classes,
                                          weights=weights_ind)
        for index in range(0, dataset.num_of_classes):
            metric_map['class_' + str(index) + '_iou'] = (iou_v[index],
                                                          update_op[index])
            tf.summary.scalar('class_' + str(index) + '_iou', iou_v[index])

        # Confusion matrix save hook.  The tensor name refers to the
        # accumulator created by the "mean_iou" metric above.
        # NOTE(review): the class labels are hard-coded for a 5-class
        # dataset -- confirm they match dataset.num_of_classes.
        confusionMatrixSaveHook = confusion_matrix.SaverHook(
            labels=['BG', 'water', 'ice', 'snow', 'clutter'],
            confusion_matrix_tensor_name='mean_iou/total_confusion_matrix',
            summary_writer=tf.summary.FileWriterCache.get(
                str(FLAGS.eval_logdir)))

        summary_op = tf.summary.merge_all()

        summary_hook = tf.contrib.training.SummaryAtEndHook(
            log_dir=FLAGS.eval_logdir, summary_op=summary_op)
        hooks = [summary_hook, confusionMatrixSaveHook]

        # None means "evaluate forever"; a positive flag bounds the count.
        num_eval_iters = None
        if FLAGS.max_number_of_evaluations > 0:
            num_eval_iters = FLAGS.max_number_of_evaluations

        if FLAGS.quantize_delay_step >= 0:
            tf.contrib.quantize.create_eval_graph()

        tf.contrib.training.evaluate_repeatedly(
            master=FLAGS.master,
            checkpoint_dir=FLAGS.checkpoint_dir,
            eval_ops=[update_miou, update_op],
            max_number_of_evaluations=num_eval_iters,
            hooks=hooks,
            config=session_config,
            eval_interval_secs=FLAGS.eval_interval_secs)
示例#6
0
def main(unused_argv):
    """Train the DeepLab model inside a ``MonitoredTrainingSession``.

    Builds the training input pipeline for ``FLAGS.train_split``, creates
    the train op via ``_train_deeplab_model``, optionally restores weights
    from ``FLAGS.tf_initial_checkpoint``, and runs the train op until the
    stop hook fires at ``FLAGS.training_number_of_steps``.  Checkpoints and
    summaries are written to ``FLAGS.train_logdir``.

    Args:
        unused_argv: Ignored.

    Raises:
        AssertionError: If ``FLAGS.train_batch_size`` is not divisible by
            ``FLAGS.num_clones``.
    """
    tf.logging.set_verbosity(tf.logging.INFO)

    tf.gfile.MakeDirs(FLAGS.train_logdir)
    tf.logging.info('Training on %s set', FLAGS.train_split)

    graph = tf.Graph()
    with graph.as_default():
        with tf.device(
                tf.train.replica_device_setter(ps_tasks=FLAGS.num_ps_tasks)):
            # Each clone gets an equal share of the global batch.
            assert FLAGS.train_batch_size % FLAGS.num_clones == 0, (
                'Training batch size not divisble by number of clones (GPUs).')
            clone_batch_size = FLAGS.train_batch_size // FLAGS.num_clones

            dataset = data_generator.Dataset(
                dataset_name=FLAGS.dataset,
                split_name=FLAGS.train_split,
                dataset_dir=FLAGS.dataset_dir,
                batch_size=clone_batch_size,
                crop_size=[int(sz) for sz in FLAGS.train_crop_size],
                min_resize_value=FLAGS.min_resize_value,
                max_resize_value=FLAGS.max_resize_value,
                resize_factor=FLAGS.resize_factor,
                min_scale_factor=FLAGS.min_scale_factor,
                max_scale_factor=FLAGS.max_scale_factor,
                scale_factor_step_size=FLAGS.scale_factor_step_size,
                model_variant=FLAGS.model_variant,
                num_readers=2,
                is_training=True,
                should_shuffle=True,
                should_repeat=True)

            train_tensor, summary_op = _train_deeplab_model(
                dataset.get_one_shot_iterator(), dataset.num_of_classes,
                dataset.ignore_label)

            # Soft placement allows placing on CPU ops without GPU
            # implementation.
            session_config = tf.ConfigProto(allow_soft_placement=True,
                                            log_device_placement=False)

            last_layers = model.get_extra_layer_scopes(
                FLAGS.last_layers_contain_logits_only)
            # Only restore from an initial checkpoint when one is given.
            init_fn = None
            if FLAGS.tf_initial_checkpoint:
                init_fn = train_utils.get_model_init_fn(
                    FLAGS.train_logdir,
                    FLAGS.tf_initial_checkpoint,
                    FLAGS.initialize_last_layer,
                    last_layers,
                    ignore_missing_vars=True)

            scaffold = tf.train.Scaffold(
                init_fn=init_fn,
                summary_op=summary_op,
            )

            stop_hook = tf.train.StopAtStepHook(
                last_step=FLAGS.training_number_of_steps)

            # Profiling is enabled only when a profile directory is set.
            profile_dir = FLAGS.profile_logdir
            if profile_dir is not None:
                tf.gfile.MakeDirs(profile_dir)

            with tf.contrib.tfprof.ProfileContext(enabled=profile_dir
                                                  is not None,
                                                  profile_dir=profile_dir):
                # FLAGS.save_summaries_secs is a time interval, so it is
                # passed as save_summaries_secs.  It used to be passed as
                # save_summaries_steps, which made the session treat the
                # number of seconds as a number of global steps.
                with tf.train.MonitoredTrainingSession(
                        master=FLAGS.master,
                        is_chief=(FLAGS.task == 0),
                        config=session_config,
                        scaffold=scaffold,
                        checkpoint_dir=FLAGS.train_logdir,
                        summary_dir=FLAGS.train_logdir,
                        log_step_count_steps=FLAGS.log_steps,
                        save_summaries_secs=FLAGS.save_summaries_secs,
                        save_checkpoint_secs=FLAGS.save_interval_secs,
                        hooks=[stop_hook]) as sess:
                    while not sess.should_stop():
                        sess.run([train_tensor])
def main(unused_argv):
    """Repeatedly evaluate checkpoints with overall and per-class IoU.

    Builds the input pipeline for ``FLAGS.eval_split``, creates the
    prediction graph (single- or multi-scale depending on
    ``FLAGS.eval_scales``), defines the mean-IoU metric plus one IoU metric
    per class, prints the model analysis, and runs
    ``contrib_training.evaluate_repeatedly`` over ``FLAGS.checkpoint_dir``.

    Args:
        unused_argv: Ignored.

    Raises:
        ValueError: If multi-scale evaluation is requested while
            ``FLAGS.quantize_delay_step >= 0``.
    """
    tf.logging.set_verbosity(tf.logging.INFO)

    eval_dataset = data_generator.Dataset(
        dataset_name=FLAGS.dataset,
        split_name=FLAGS.eval_split,
        dataset_dir=FLAGS.dataset_dir,
        batch_size=FLAGS.eval_batch_size,
        crop_size=list(map(int, FLAGS.eval_crop_size)),
        min_resize_value=FLAGS.min_resize_value,
        max_resize_value=FLAGS.max_resize_value,
        resize_factor=FLAGS.resize_factor,
        model_variant=FLAGS.model_variant,
        num_readers=2,
        is_training=False,
        should_shuffle=False,
        should_repeat=False)

    tf.gfile.MakeDirs(FLAGS.eval_logdir)
    tf.logging.info('Evaluating on %s set', FLAGS.eval_split)

    with tf.Graph().as_default():
        batch_tensors = eval_dataset.get_one_shot_iterator().get_next()

        model_options = common.ModelOptions(
            outputs_to_num_classes={
                common.OUTPUT_TYPE: eval_dataset.num_of_classes
            },
            crop_size=list(map(int, FLAGS.eval_crop_size)),
            atrous_rates=FLAGS.atrous_rates,
            output_stride=FLAGS.output_stride)

        # A fully static image shape is needed for the tfprof model analyzer
        # used further down.
        crop_height = int(FLAGS.eval_crop_size[0])
        crop_width = int(FLAGS.eval_crop_size[1])
        batch_tensors[common.IMAGE].set_shape(
            [FLAGS.eval_batch_size, crop_height, crop_width, 3])

        multi_scale = tuple(FLAGS.eval_scales) != (1.0, )
        if multi_scale:
            tf.logging.info('Performing multi-scale test.')
            if FLAGS.quantize_delay_step >= 0:
                raise ValueError(
                    'Quantize mode is not supported with multi-scale test.')

            outputs = model_func.predict_labels_multi_scale(
                batch_tensors[common.IMAGE],
                model_options=model_options,
                eval_scales=FLAGS.eval_scales,
                add_flipped_images=FLAGS.add_flipped_images)
        else:
            tf.logging.info('Performing single-scale test.')
            outputs = model_func.predict_labels(
                batch_tensors[common.IMAGE],
                model_options,
                image_pyramid=FLAGS.image_pyramid)

        # Flatten everything to 1-D so predictions, labels and weights line
        # up element by element.
        flat_predictions = tf.reshape(outputs[common.OUTPUT_TYPE], shape=[-1])
        flat_labels = tf.reshape(batch_tensors[common.LABEL], shape=[-1])
        weights = tf.to_float(
            tf.not_equal(flat_labels, eval_dataset.ignore_label))

        # mean_iou needs labels in [0, num_classes); map the ignore label to
        # class 0.  Those positions carry weight 0, so they do not count.
        flat_labels = tf.where(
            tf.equal(flat_labels, eval_dataset.ignore_label),
            tf.zeros_like(flat_labels), flat_labels)

        # Tag encodes the evaluation scales and whether flipping is used.
        predictions_tag = 'miou' + ''.join(
            '_' + str(scale) for scale in FLAGS.eval_scales)
        if FLAGS.add_flipped_images:
            predictions_tag += '_flipped'

        # Overall mean IoU.
        num_classes = eval_dataset.num_of_classes
        metric_map = {
            'eval/%s_overall' % predictions_tag: tf.metrics.mean_iou(
                labels=flat_labels,
                predictions=flat_predictions,
                num_classes=num_classes,
                weights=weights),
        }

        # Per-class IoU, computed from one-hot columns.
        prediction_columns = tf.reshape(
            tf.one_hot(flat_predictions, num_classes), [-1, num_classes])
        label_columns = tf.reshape(
            tf.one_hot(flat_labels, num_classes), [-1, num_classes])

        counters = (tf.metrics.true_positives,
                    tf.metrics.false_positives,
                    tf.metrics.false_negatives)
        for class_id in range(num_classes):
            predictions_tag_c = '%s_class_%d' % (predictions_tag, class_id)

            values = []
            updates = []
            for counter in counters:
                value, update = counter(
                    labels=label_columns[:, class_id],
                    predictions=prediction_columns[:, class_id],
                    weights=weights)
                values.append(value)
                updates.append(update)
            tp, fp, fn = values
            update_counts = tf.group(*updates)

            # IoU is NaN for a class with no ground-truth pixels (tp + fn
            # is zero).
            has_ground_truth = tf.greater(tp + fn, 0.0)
            ratio = tp / (tp + fn + fp)
            class_iou = tf.where(has_ground_truth, ratio,
                                 tf.constant(np.NaN))
            metric_map['eval/%s' % predictions_tag_c] = (class_iou,
                                                         update_counts)

        (metrics_to_values,
         metrics_to_updates) = contrib_metrics.aggregate_metric_map(metric_map)

        def _printing_summary(name, value):
            # Scalar summary that also prints the metric when evaluated.
            scalar = tf.summary.scalar(name, value)
            return tf.Print(scalar, [value], name)

        summary_ops = [
            _printing_summary(name, value)
            for name, value in six.iteritems(metrics_to_values)
        ]

        summary_hook = contrib_training.SummaryAtEndHook(
            log_dir=FLAGS.eval_logdir,
            summary_op=tf.summary.merge(summary_ops))
        hooks = [summary_hook]

        # None means no limit on the number of evaluations.
        if FLAGS.max_number_of_evaluations > 0:
            num_eval_iters = FLAGS.max_number_of_evaluations
        else:
            num_eval_iters = None

        if FLAGS.quantize_delay_step >= 0:
            contrib_quantize.create_eval_graph()

        analyzer = contrib_tfprof.model_analyzer
        analyzer.print_model_analysis(
            tf.get_default_graph(),
            tfprof_options=analyzer.TRAINABLE_VARS_PARAMS_STAT_OPTIONS)
        analyzer.print_model_analysis(
            tf.get_default_graph(),
            tfprof_options=analyzer.FLOAT_OPS_OPTIONS)

        contrib_training.evaluate_repeatedly(
            checkpoint_dir=FLAGS.checkpoint_dir,
            master=FLAGS.master,
            eval_ops=list(metrics_to_updates.values()),
            max_number_of_evaluations=num_eval_iters,
            hooks=hooks,
            eval_interval_secs=FLAGS.eval_interval_secs)
示例#8
0
def main(unused_argv):
    """Train DeepLab with slim's model_deploy across clones/replicas.

    Builds the training input pipeline for ``FLAGS.train_split``, creates
    the model clones, collects summaries, builds the optimizer and the
    gradient update op, and finally runs ``slim.learning.train`` for
    ``FLAGS.training_number_of_steps`` steps, writing to
    ``FLAGS.train_logdir``.

    Args:
        unused_argv: Ignored.

    Raises:
        AssertionError: If ``FLAGS.train_batch_size`` is not divisible by
            the number of clones.
        ValueError: If ``FLAGS.optimizer`` is neither 'momentum' nor 'adam',
            or if quantization is requested with more than one clone.
    """
    tf.logging.set_verbosity(tf.logging.INFO)
    # Set up deployment (i.e., multi-GPUs and/or multi-replicas).
    config = model_deploy.DeploymentConfig(num_clones=FLAGS.num_clones,
                                           clone_on_cpu=FLAGS.clone_on_cpu,
                                           replica_id=FLAGS.task,
                                           num_replicas=FLAGS.num_replicas,
                                           num_ps_tasks=FLAGS.num_ps_tasks)

    # Split the batch across GPUs.
    assert FLAGS.train_batch_size % config.num_clones == 0, (
        'Training batch size not divisble by number of clones (GPUs).')

    clone_batch_size = FLAGS.train_batch_size // config.num_clones

    tf.gfile.MakeDirs(FLAGS.train_logdir)
    tf.logging.info('Training on %s set', FLAGS.train_split)

    with tf.Graph().as_default() as graph:
        # The input pipeline is created on the device chosen for inputs.
        with tf.device(config.inputs_device()):
            dataset = data_generator.Dataset(
                dataset_name=FLAGS.dataset,
                split_name=FLAGS.train_split,
                dataset_dir=FLAGS.dataset_dir,
                batch_size=clone_batch_size,
                crop_size=[int(sz) for sz in FLAGS.train_crop_size],
                min_resize_value=FLAGS.min_resize_value,
                max_resize_value=FLAGS.max_resize_value,
                resize_factor=FLAGS.resize_factor,
                min_scale_factor=FLAGS.min_scale_factor,
                max_scale_factor=FLAGS.max_scale_factor,
                scale_factor_step_size=FLAGS.scale_factor_step_size,
                model_variant=FLAGS.model_variant,
                num_readers=4,
                is_training=True,
                should_shuffle=True,
                should_repeat=True)

        # Create the global step on the device storing the variables.
        with tf.device(config.variables_device()):
            global_step = tf.train.get_or_create_global_step()

            # Define the model and create clones.
            model_fn = _build_deeplab
            model_args = (dataset.get_one_shot_iterator(), {
                common.OUTPUT_TYPE: dataset.num_of_classes
            }, dataset.ignore_label)
            clones = model_deploy.create_clones(config,
                                                model_fn,
                                                args=model_args)

            # Gather update_ops from the first clone. These contain, for example,
            # the updates for the batch_norm variables created by model_fn.
            first_clone_scope = config.clone_scope(0)
            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS,
                                           first_clone_scope)

        # Gather initial summaries.
        summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))

        # Add summaries for model variables.
        for model_var in tf.model_variables():
            summaries.add(tf.summary.histogram(model_var.op.name, model_var))

        # Add summaries for images, labels, semantic predictions
        if FLAGS.save_summaries_images:
            # Tensors are looked up by name inside the first clone's scope;
            # strip('/') handles an empty scope name.
            summary_image = graph.get_tensor_by_name(
                ('%s/%s:0' % (first_clone_scope, common.IMAGE)).strip('/'))
            summaries.add(
                tf.summary.image('samples/%s' % common.IMAGE, summary_image))

            first_clone_label = graph.get_tensor_by_name(
                ('%s/%s:0' % (first_clone_scope, common.LABEL)).strip('/'))
            # Scale up summary image pixel values for better visualization.
            pixel_scaling = max(1, 255 // dataset.num_of_classes)
            summary_label = tf.cast(first_clone_label * pixel_scaling,
                                    tf.uint8)
            summaries.add(
                tf.summary.image('samples/%s' % common.LABEL, summary_label))

            first_clone_output = graph.get_tensor_by_name(
                ('%s/%s:0' %
                 (first_clone_scope, common.OUTPUT_TYPE)).strip('/'))
            # Arg-max over axis 3 turns the output into a class-id map; the
            # trailing axis is re-added for tf.summary.image.
            predictions = tf.expand_dims(tf.argmax(first_clone_output, 3), -1)

            summary_predictions = tf.cast(predictions * pixel_scaling,
                                          tf.uint8)
            summaries.add(
                tf.summary.image('samples/%s' % common.OUTPUT_TYPE,
                                 summary_predictions))

        # Add summaries for losses.
        for loss in tf.get_collection(tf.GraphKeys.LOSSES, first_clone_scope):
            summaries.add(tf.summary.scalar('losses/%s' % loss.op.name, loss))

        # Build the optimizer based on the device specification.
        with tf.device(config.optimizer_device()):
            learning_rate = train_utils.get_model_learning_rate(
                FLAGS.learning_policy,
                FLAGS.base_learning_rate,
                FLAGS.learning_rate_decay_step,
                FLAGS.learning_rate_decay_factor,
                FLAGS.training_number_of_steps,
                FLAGS.learning_power,
                FLAGS.slow_start_step,
                FLAGS.slow_start_learning_rate,
                decay_steps=FLAGS.decay_steps,
                end_learning_rate=FLAGS.end_learning_rate)

            summaries.add(tf.summary.scalar('learning_rate', learning_rate))

            # NOTE: only the momentum optimizer uses the scheduled
            # learning_rate; Adam uses the fixed FLAGS.adam_learning_rate.
            if FLAGS.optimizer == 'momentum':
                optimizer = tf.train.MomentumOptimizer(learning_rate,
                                                       FLAGS.momentum)
            elif FLAGS.optimizer == 'adam':
                optimizer = tf.train.AdamOptimizer(
                    learning_rate=FLAGS.adam_learning_rate,
                    epsilon=FLAGS.adam_epsilon)
            else:
                raise ValueError('Unknown optimizer')

        # A non-negative quantize_delay_step enables quantization-aware
        # training, which is limited to a single clone.
        if FLAGS.quantize_delay_step >= 0:
            if FLAGS.num_clones > 1:
                raise ValueError(
                    'Quantization doesn\'t support multi-clone yet.')
            contrib_quantize.create_training_graph(
                quant_delay=FLAGS.quantize_delay_step)

        # The startup delay scales with the task index, so task 0 gets none.
        startup_delay_steps = FLAGS.task * FLAGS.startup_delay_steps

        with tf.device(config.variables_device()):
            total_loss, grads_and_vars = model_deploy.optimize_clones(
                clones, optimizer)
            total_loss = tf.check_numerics(total_loss, 'Loss is inf or nan.')
            summaries.add(tf.summary.scalar('total_loss', total_loss))

            # Modify the gradients for biases and last layer variables.
            last_layers = model.get_extra_layer_scopes(
                FLAGS.last_layers_contain_logits_only)
            grad_mult = train_utils.get_model_gradient_multipliers(
                last_layers, FLAGS.last_layer_gradient_multiplier)
            if grad_mult:
                grads_and_vars = slim.learning.multiply_gradients(
                    grads_and_vars, grad_mult)

            # Create gradient update op.
            grad_updates = optimizer.apply_gradients(grads_and_vars,
                                                     global_step=global_step)
            update_ops.append(grad_updates)
            update_op = tf.group(*update_ops)
            # Evaluating train_tensor runs every update op first and then
            # yields the total loss.
            with tf.control_dependencies([update_op]):
                train_tensor = tf.identity(total_loss, name='train_op')

        # Add the summaries from the first clone. These contain the summaries
        # created by model_fn and either optimize_clones() or _gather_clone_loss().
        summaries |= set(
            tf.get_collection(tf.GraphKeys.SUMMARIES, first_clone_scope))

        # Merge all summaries together.
        summary_op = tf.summary.merge(list(summaries))

        # Soft placement allows placing on CPU ops without GPU implementation.
        session_config = tf.ConfigProto(allow_soft_placement=True,
                                        log_device_placement=False)

        # Start the training.
        # Profiling is enabled only when a profile directory is configured.
        profile_dir = FLAGS.profile_logdir
        if profile_dir is not None:
            tf.gfile.MakeDirs(profile_dir)

        with contrib_tfprof.ProfileContext(enabled=profile_dir is not None,
                                           profile_dir=profile_dir):
            # Only restore from an initial checkpoint when one is given.
            init_fn = None
            if FLAGS.tf_initial_checkpoint:
                init_fn = train_utils.get_model_init_fn(
                    FLAGS.train_logdir,
                    FLAGS.tf_initial_checkpoint,
                    FLAGS.initialize_last_layer,
                    last_layers,
                    ignore_missing_vars=True)

            slim.learning.train(train_tensor,
                                logdir=FLAGS.train_logdir,
                                log_every_n_steps=FLAGS.log_steps,
                                master=FLAGS.master,
                                number_of_steps=FLAGS.training_number_of_steps,
                                is_chief=(FLAGS.task == 0),
                                session_config=session_config,
                                startup_delay_steps=startup_delay_steps,
                                init_fn=init_fn,
                                summary_op=summary_op,
                                save_summaries_secs=FLAGS.save_summaries_secs,
                                save_interval_secs=FLAGS.save_interval_secs)
示例#9
0
def _write_run_parameters(log_dir, parameters):
    """Record the run parameters and the start time under *log_dir*.

    Writes ``json.txt`` (a JSON dump of *parameters*) and starts a fresh
    ``logging.txt`` with one ``key: value`` line per parameter followed by
    a ``Start time`` stamp.

    Args:
        log_dir: Directory receiving both files; it must already exist.
        parameters: Mapping of parameter names to their values.
    """
    with open(os.path.join(log_dir, 'json.txt'), 'w', encoding='utf-8') as f:
        json.dump(parameters, f, indent=3)

    with open(os.path.join(log_dir, 'logging.txt'), 'w',
              encoding='utf-8') as f:
        for key, value in parameters.items():
            f.write("{}: {}".format(str(key), str(value)))
            f.write("\n")
        f.write("\nStart time: {}".format(
            time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())))
        f.write("\n")


def _save_validation_curve(val_steps, val_scores, out_path):
    """Redraw the validation curve from scratch and save it to *out_path*.

    Args:
        val_steps: Global steps at which validation ran (x axis).
        val_scores: Value recorded at each of those steps (y axis).
        out_path: Destination image path.
    """
    # pyplot keeps its current figure between calls; clear it so each save
    # holds exactly one curve instead of a growing stack of overlapping ones.
    plt.clf()
    plt.plot(val_steps, val_scores, "bo-")
    # The legend is requested after plotting so there is a line to label.
    # NOTE(review): the plotted values are the mean Dice scores from
    # validation; the "loss" labels are kept as they were.
    plt.legend(["validation loss"])
    plt.xlabel("global step")
    plt.ylabel("loss")
    plt.grid(True)
    plt.savefig(out_path)


def main(unused_argv):
    """Train the model and, unless training already uses "val", validate it.

    Writes the run parameters to ``json.txt`` / ``logging.txt`` in
    ``FLAGS.train_logdir``, builds the training (and optional validation)
    input pipelines and graphs, then runs the training loop inside a
    ``MonitoredTrainingSession``. Every ``FLAGS.validation_steps`` global
    steps the validation split is evaluated, the curve is saved to
    ``losses.png`` and, when the mean Dice score improves, the model is
    saved as ``model.ckpt-best``.

    Args:
        unused_argv: Ignored.
    """
    # Check model parameters
    check_model_conflict()

    dataset_info = data_generator._DATASETS_INFORMATION[FLAGS.dataset_name]
    tf.logging.set_verbosity(tf.logging.INFO)

    tf.gfile.MakeDirs(FLAGS.train_logdir)
    for split in FLAGS.train_split:
        tf.logging.info('Training on %s set', split)

    path = FLAGS.train_logdir
    _write_run_parameters(path, vars(FLAGS))

    # Validation only runs when the "val" split is not part of training.
    run_validation = "val" not in FLAGS.train_split

    graph = tf.Graph()
    with graph.as_default():
        with tf.device(
                tf.train.replica_device_setter(ps_tasks=FLAGS.num_ps_tasks)):
            assert FLAGS.batch_size % FLAGS.num_clones == 0, (
                'Training batch size not divisble by number of clones (GPUs).')
            clone_batch_size = FLAGS.batch_size // FLAGS.num_clones

            # NOTE(review): min_resize_value / max_resize_value are resolved
            # here (dataset height as the fallback, and forced for the CHAOS
            # MR datasets) but the generators below still receive the raw
            # FLAGS values. Confirm which is intended before wiring the
            # resolved values in.
            if FLAGS.dataset_name in ('2019_ISBI_CHAOS_MR_T1',
                                      '2019_ISBI_CHAOS_MR_T2'):
                min_resize_value = dataset_info.height
                max_resize_value = dataset_info.height
            else:
                min_resize_value = (FLAGS.min_resize_value
                                    if FLAGS.min_resize_value is not None
                                    else dataset_info.height)
                max_resize_value = (FLAGS.max_resize_value
                                    if FLAGS.max_resize_value is not None
                                    else dataset_info.height)

            train_generator = data_generator.Dataset(
                dataset_name=FLAGS.dataset_name,
                split_name=FLAGS.train_split,
                guidance_type=FLAGS.guidance_type,
                batch_size=clone_batch_size,
                pre_crop_flag=FLAGS.pre_crop_flag,
                mt_class=FLAGS.mt_output_node,
                crop_size=dataset_info.train["train_crop_size"],
                min_resize_value=FLAGS.min_resize_value,
                max_resize_value=FLAGS.max_resize_value,
                resize_factor=FLAGS.resize_factor,
                min_scale_factor=FLAGS.min_scale_factor,
                max_scale_factor=FLAGS.max_scale_factor,
                scale_factor_step_size=FLAGS.scale_factor_step_size,
                num_readers=2,
                is_training=True,
                shuffle_data=True,
                repeat_data=True,
                prior_num_slice=FLAGS.prior_num_slice,
                prior_num_subject=FLAGS.prior_num_subject,
                seq_length=FLAGS.seq_length,
                seq_type="bidirection",
                z_loss_name=FLAGS.z_loss_name,
            )

            if run_validation:
                val_generator = data_generator.Dataset(
                    dataset_name=FLAGS.dataset_name,
                    split_name=["val"],
                    guidance_type=FLAGS.guidance_type,
                    batch_size=1,
                    mt_class=FLAGS.mt_output_node,
                    crop_size=[dataset_info.height, dataset_info.width],
                    min_resize_value=FLAGS.min_resize_value,
                    max_resize_value=FLAGS.max_resize_value,
                    num_readers=2,
                    is_training=False,
                    shuffle_data=False,
                    repeat_data=True,
                    prior_num_slice=FLAGS.prior_num_slice,
                    prior_num_subject=FLAGS.prior_num_subject,
                    seq_length=FLAGS.seq_length,
                    seq_type="bidirection",
                    z_loss_name=FLAGS.z_loss_name,
                )

            model_options = common.ModelOptions(
                outputs_to_num_classes=train_generator.num_of_classes,
                crop_size=dataset_info.train["train_crop_size"],
                output_stride=FLAGS.output_stride)

            # Fed with the index of the validation sample being evaluated.
            steps = tf.compat.v1.placeholder(tf.int32, shape=[])

            train_samples = (train_generator.get_dataset()
                             .make_one_shot_iterator()
                             .get_next())

            train_tensor, summary_op = _train_pgn_model(
                train_samples, train_generator.num_of_classes, model_options,
                train_generator.ignore_label)

            if run_validation:
                val_samples = (val_generator.get_dataset()
                               .make_one_shot_iterator()
                               .get_next())

                val_tensor, _ = _val_pgn_model(val_samples,
                                               val_generator.num_of_classes,
                                               model_options,
                                               val_generator.ignore_label,
                                               steps)
                num_val_samples = val_generator.splits_to_sizes["val"]

            # Soft placement allows placing on CPU ops without GPU implementation.
            session_config = tf.ConfigProto(allow_soft_placement=True,
                                            log_device_placement=False)

            init_fn = None
            if FLAGS.tf_initial_checkpoint:
                init_fn = train_utils.get_model_init_fn(
                    train_logdir=FLAGS.train_logdir,
                    tf_initial_checkpoint=FLAGS.tf_initial_checkpoint,
                    initialize_first_layer=True,
                    initialize_last_layer=FLAGS.initialize_last_layer,
                    ignore_missing_vars=True)

            scaffold = tf.train.Scaffold(
                init_fn=init_fn,
                summary_op=summary_op,
            )

            stop_hook = tf.train.StopAtStepHook(FLAGS.training_number_of_steps)
            # Separate saver used only for the "best so far" checkpoint.
            saver = tf.train.Saver()
            # Look the global step tensor up once instead of on every
            # training iteration.
            global_step_tensor = tf.train.get_global_step()
            best_dice = 0
            with tf.train.MonitoredTrainingSession(
                    master=FLAGS.master,
                    is_chief=(FLAGS.task == 0),
                    config=session_config,
                    scaffold=scaffold,
                    checkpoint_dir=FLAGS.train_logdir,
                    log_step_count_steps=FLAGS.log_steps,
                    save_summaries_steps=20,
                    save_checkpoint_steps=FLAGS.save_checkpoint_steps,
                    hooks=[stop_hook]) as sess:

                total_val_loss, total_val_steps = [], []
                while not sess.should_stop():
                    _, global_step = sess.run(
                        [train_tensor, global_step_tensor])

                    if not run_validation:
                        continue
                    if global_step % FLAGS.validation_steps != 0:
                        continue

                    # Accumulate the per-sample results over the whole
                    # validation split before scoring.
                    cm_total = 0
                    for j in range(num_val_samples):
                        cm_total += sess.run(val_tensor,
                                             feed_dict={steps: j})

                    mean_dice_score, _ = metrics.compute_mean_dsc(
                        total_cm=cm_total)

                    total_val_loss.append(mean_dice_score)
                    total_val_steps.append(global_step)
                    _save_validation_curve(
                        total_val_steps, total_val_loss,
                        os.path.join(path, "losses.png"))

                    if mean_dice_score > best_dice:
                        best_dice = mean_dice_score
                        saver.save(
                            get_session(sess),
                            os.path.join(FLAGS.train_logdir,
                                         'model.ckpt-best'))
                        txt = 20 * ">" + " saving best mdoel model.ckpt-best-%d with DSC: %f" % (
                            global_step, best_dice)
                        print(txt)
                        with open(os.path.join(path, 'logging.txt'), 'a',
                                  encoding='utf-8') as f:
                            f.write(txt)
                            f.write("\n")

            with open(os.path.join(path, 'logging.txt'), 'a',
                      encoding='utf-8') as f:
                f.write("\nEnd time: {}".format(
                    time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())))
                f.write("\n")