Example #1
def main(unused_argv):
    tf.logging.set_verbosity(tf.logging.INFO)
    # Set up deployment (i.e., multi-GPUs and/or multi-replicas).
    config = model_deploy.DeploymentConfig(num_clones=FLAGS.num_clones,
                                           clone_on_cpu=FLAGS.clone_on_cpu,
                                           replica_id=FLAGS.task,
                                           num_replicas=FLAGS.num_replicas,
                                           num_ps_tasks=FLAGS.num_ps_tasks)

    # Split the batch across GPUs.
    assert FLAGS.train_batch_size % config.num_clones == 0, (
        'Training batch size not divisible by number of clones (GPUs).')

    clone_batch_size = FLAGS.train_batch_size // config.num_clones
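    # For example, with --train_batch_size=16 and --num_clones=4, each clone
    # (GPU) receives a per-clone batch of 4 samples.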

    tf.gfile.MakeDirs(FLAGS.train_logdir)
    tf.logging.info('Training on %s set', FLAGS.train_split)

    with tf.Graph().as_default() as graph:
        with tf.device(config.inputs_device()):
            dataset = data_generator.Dataset(
                dataset_name=FLAGS.dataset,
                split_name=FLAGS.train_split,
                dataset_dir=FLAGS.dataset_dir,
                batch_size=clone_batch_size,
                crop_size=[int(sz) for sz in FLAGS.train_crop_size],
                min_resize_value=FLAGS.min_resize_value,
                max_resize_value=FLAGS.max_resize_value,
                resize_factor=FLAGS.resize_factor,
                min_scale_factor=FLAGS.min_scale_factor,
                max_scale_factor=FLAGS.max_scale_factor,
                scale_factor_step_size=FLAGS.scale_factor_step_size,
                model_variant=FLAGS.model_variant,
                num_readers=4,
                is_training=True,
                should_shuffle=False,
                should_repeat=True)

        # Create the global step on the device storing the variables.
        with tf.device(config.variables_device()):
            global_step = tf.train.get_or_create_global_step()

            # Define the model and create clones.

            model_args = (dataset.get_one_shot_iterator(), {
                common.OUTPUT_TYPE: dataset.num_of_classes,
                common.INSTANCE: 1,
                common.OFFSET: 2
            }, dataset.ignore_label)
            clones = model_deploy.create_clones(config,
                                                _build_deeplab,
                                                args=model_args)

            # Gather update_ops from the first clone. These contain, for example,
            # the updates for the batch_norm variables created by model_fn.
            first_clone_scope = config.clone_scope(0)
            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS,
                                           first_clone_scope)

        # Gather initial summaries.
        summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))

        # Add summaries for model variables.
        for model_var in tf.model_variables():
            summaries.add(tf.summary.histogram(model_var.op.name, model_var))

        # Add summaries for images, labels, and predictions (semantic,
        # instance center, offset).
        if FLAGS.save_summaries_images:
            summary_image = graph.get_tensor_by_name(
                ('%s/%s:0' % (first_clone_scope, common.IMAGE)).strip('/'))
            summaries.add(
                tf.summary.image('samples/%s' % common.IMAGE, summary_image))

            ############ SEG LABEL ###############

            first_clone_label = graph.get_tensor_by_name(
                ('%s/%s:0' % (first_clone_scope, common.LABEL)).strip('/'))
            # Scale up summary image pixel values for better visualization.
            pixel_scaling = max(1, 255 // dataset.num_of_classes)
            summary_label = tf.cast(first_clone_label * pixel_scaling,
                                    tf.uint8)
            summaries.add(
                tf.summary.image('samples/%s' % common.LABEL, summary_label))

            first_clone_output = graph.get_tensor_by_name(
                ('%s/%s:0' %
                 (first_clone_scope, common.OUTPUT_TYPE)).strip('/'))
            predictions = tf.expand_dims(tf.argmax(first_clone_output, 3), -1)

            summary_predictions = tf.cast(predictions * pixel_scaling,
                                          tf.uint8)
            summaries.add(
                tf.summary.image('samples/%s' % common.OUTPUT_TYPE,
                                 summary_predictions))

            ########### INSTANCE CENTER ###########

            first_clone_label = graph.get_tensor_by_name(
                ('%s/%s:0' %
                 (first_clone_scope, common.LABEL_INSTANCE)).strip('/'))
            # Cast the center heatmap labels to uint8 for the image summary.
            summary_label = tf.cast(first_clone_label, tf.uint8)
            summaries.add(
                tf.summary.image('samples/%s' % common.LABEL_INSTANCE,
                                 summary_label))

            first_clone_label = graph.get_tensor_by_name(
                ('%s/%s:0' % (first_clone_scope, common.INSTANCE)).strip('/'))
            # Normalize the center heatmap predictions to [0, 255] for better
            # visualization (this assumes a positive maximum activation).
            label_max_x = tf.reduce_max(first_clone_label)
            first_clone_label = tf.multiply(
                tf.divide(first_clone_label, label_max_x), 255.0)
            summary_label = tf.cast(first_clone_label, tf.uint8)

            summaries.add(
                tf.summary.image('samples/%s' % common.INSTANCE,
                                 summary_label))

            ########## INSTANCE OFFSET ###########

            first_clone_label = graph.get_tensor_by_name(
                ('%s/%s:0' %
                 (first_clone_scope, common.LABEL_OFFSET)).strip('/'))

            # Normalize the ground-truth offset channels to [0, 255] for
            # better visualization.

            x_offset, y_offset, extra = tf.split(first_clone_label, 3, 3)
            label_max_x = tf.reduce_max(x_offset)
            x_offset = tf.multiply(tf.divide(x_offset, label_max_x), 255.0)
            label_max_y = tf.reduce_max(y_offset)
            y_offset = tf.multiply(tf.divide(y_offset, label_max_y), 255.0)
            channel_padding = tf.zeros_like(x_offset)
            first_clone_label = tf.concat(
                [x_offset, y_offset, channel_padding], 3)

            summary_label = tf.image.convert_image_dtype(first_clone_label,
                                                         dtype=tf.uint8,
                                                         saturate=True)
            summaries.add(
                tf.summary.image('samples/%s' % common.LABEL_OFFSET,
                                 summary_label))

            ########## OFFSET PREDICTIONS ##########

            first_clone_label = graph.get_tensor_by_name(
                ('%s/%s:0' % (first_clone_scope, common.OFFSET)).strip('/'))
            # Normalize the predicted offset channels to [0, 255] for better
            # visualization.
            x_offset, y_offset = tf.split(first_clone_label, 2, 3)
            label_max_x = tf.reduce_max(x_offset)
            x_offset = tf.multiply(tf.divide(x_offset, label_max_x), 255.0)
            label_max_y = tf.reduce_max(y_offset)
            y_offset = tf.multiply(tf.divide(y_offset, label_max_y), 255.0)
            channel_padding = tf.zeros_like(x_offset)
            first_clone_label = tf.concat(
                [x_offset, y_offset, channel_padding], 3)

            summary_label = tf.cast(first_clone_label, tf.uint8)

            summaries.add(
                tf.summary.image('samples/%s' % common.OFFSET, summary_label))

        # Add summaries for losses.
        for loss in tf.get_collection(tf.GraphKeys.LOSSES, first_clone_scope):
            summaries.add(tf.summary.scalar('losses/%s' % loss.op.name, loss))

        # Build the optimizer based on the device specification.
        with tf.device(config.optimizer_device()):
            learning_rate = train_utils.get_model_learning_rate(
                FLAGS.learning_policy,
                FLAGS.base_learning_rate,
                FLAGS.learning_rate_decay_step,
                FLAGS.learning_rate_decay_factor,
                FLAGS.training_number_of_steps,
                FLAGS.learning_power,
                FLAGS.slow_start_step,
                FLAGS.slow_start_learning_rate,
                decay_steps=FLAGS.decay_steps,
                end_learning_rate=FLAGS.end_learning_rate)

            summaries.add(tf.summary.scalar('learning_rate', learning_rate))

            if FLAGS.optimizer == 'momentum':
                optimizer = tf.train.MomentumOptimizer(learning_rate,
                                                       FLAGS.momentum)
            elif FLAGS.optimizer == 'adam':
                optimizer = tf.train.AdamOptimizer(
                    learning_rate=FLAGS.adam_learning_rate,
                    epsilon=FLAGS.adam_epsilon)
            else:
                raise ValueError('Unknown optimizer')

        if FLAGS.quantize_delay_step >= 0:
            if FLAGS.num_clones > 1:
                raise ValueError(
                    'Quantization doesn\'t support multi-clone yet.')
            contrib_quantize.create_training_graph(
                quant_delay=FLAGS.quantize_delay_step)

        startup_delay_steps = FLAGS.task * FLAGS.startup_delay_steps

        with tf.device(config.variables_device()):
            total_loss, grads_and_vars = model_deploy.optimize_clones(
                clones, optimizer)
            total_loss = tf.check_numerics(total_loss, 'Loss is inf or nan.')
            summaries.add(tf.summary.scalar('total_loss', total_loss))

            # Modify the gradients for biases and last layer variables.
            last_layers = model.get_extra_layer_scopes(
                FLAGS.last_layers_contain_logits_only)
            grad_mult = train_utils.get_model_gradient_multipliers(
                last_layers, FLAGS.last_layer_gradient_multiplier)
            if grad_mult:
                grads_and_vars = slim.learning.multiply_gradients(
                    grads_and_vars, grad_mult)

            # Create gradient update op.
            grad_updates = optimizer.apply_gradients(grads_and_vars,
                                                     global_step=global_step)
            update_ops.append(grad_updates)
            update_op = tf.group(*update_ops)
            with tf.control_dependencies([update_op]):
                train_tensor = tf.identity(total_loss, name='train_op')

        # Add the summaries from the first clone. These contain the summaries
        # created by model_fn and either optimize_clones() or _gather_clone_loss().
        summaries |= set(
            tf.get_collection(tf.GraphKeys.SUMMARIES, first_clone_scope))

        # Merge all summaries together.
        summary_op = tf.summary.merge(list(summaries))

        # Soft placement allows placing on CPU ops without GPU implementation.
        session_config = tf.ConfigProto(allow_soft_placement=True,
                                        log_device_placement=False)

        # Start the training.
        profile_dir = FLAGS.profile_logdir
        if profile_dir is not None:
            tf.gfile.MakeDirs(profile_dir)

        with contrib_tfprof.ProfileContext(enabled=profile_dir is not None,
                                           profile_dir=profile_dir):
            init_fn = None
            if FLAGS.tf_initial_checkpoint:
                init_fn = train_utils.get_model_init_fn(
                    FLAGS.train_logdir,
                    FLAGS.tf_initial_checkpoint,
                    FLAGS.initialize_last_layer,
                    last_layers,
                    ignore_missing_vars=True)

            slim.learning.train(train_tensor,
                                logdir=FLAGS.train_logdir,
                                log_every_n_steps=FLAGS.log_steps,
                                master=FLAGS.master,
                                number_of_steps=FLAGS.training_number_of_steps,
                                is_chief=(FLAGS.task == 0),
                                session_config=session_config,
                                startup_delay_steps=startup_delay_steps,
                                init_fn=init_fn,
                                summary_op=summary_op,
                                save_summaries_secs=FLAGS.save_summaries_secs,
                                save_interval_secs=FLAGS.save_interval_secs)
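
Note: `_build_deeplab`, which is passed to `model_deploy.create_clones` above, is not included in this excerpt. Below is a minimal sketch of a model function compatible with the `model_args` tuple built above (iterator, outputs-to-num-classes map, ignore label); the loss construction and the `weight_decay` / `fine_tune_batch_norm` flags are assumptions for illustration, not the repository's actual code.

def _build_deeplab(iterator, outputs_to_num_classes, ignore_label):
    """Builds one model clone and adds its losses to the graph (sketch).

    Args:
        iterator: tf.data iterator yielding one per-clone batch of samples.
        outputs_to_num_classes: map from output head (common.OUTPUT_TYPE,
            common.INSTANCE, common.OFFSET) to its number of channels.
        ignore_label: label value to exclude from the losses.
    """
    samples = iterator.get_next()
    model_options = common.ModelOptions(
        outputs_to_num_classes=outputs_to_num_classes,
        crop_size=[int(sz) for sz in FLAGS.train_crop_size],
        atrous_rates=FLAGS.atrous_rates,
        output_stride=FLAGS.output_stride)
    # Build the multi-scale logits for every output head.
    outputs_to_scales_to_logits = model.multi_scale_logits(
        samples[common.IMAGE],
        model_options=model_options,
        image_pyramid=FLAGS.image_pyramid,
        weight_decay=FLAGS.weight_decay,  # assumed flag
        is_training=True,
        fine_tune_batch_norm=FLAGS.fine_tune_batch_norm)  # assumed flag
    # Per-head losses (softmax cross-entropy for the semantic head, regression
    # losses for the center and offset heads) would be added to
    # tf.GraphKeys.LOSSES here, honoring ignore_label.
    return outputs_to_scales_to_logits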
Example #2
def main(unused_argv):
    tf.logging.set_verbosity(tf.logging.INFO)

    dataset = data_generator.Dataset(
        dataset_name=FLAGS.dataset,
        split_name=FLAGS.eval_split,
        dataset_dir=FLAGS.dataset_dir,
        batch_size=FLAGS.eval_batch_size,
        crop_size=[int(sz) for sz in FLAGS.eval_crop_size],
        min_resize_value=FLAGS.min_resize_value,
        max_resize_value=FLAGS.max_resize_value,
        resize_factor=FLAGS.resize_factor,
        model_variant=FLAGS.model_variant,
        num_readers=2,
        is_training=False,
        should_shuffle=False,
        should_repeat=False)

    tf.gfile.MakeDirs(FLAGS.eval_logdir)
    tf.logging.info('Evaluating on %s set', FLAGS.eval_split)

    with tf.Graph().as_default():
        samples = dataset.get_one_shot_iterator().get_next()

        model_options = common.ModelOptions(
            outputs_to_num_classes={
                common.OUTPUT_TYPE: dataset.num_of_classes,
                common.INSTANCE: 1,
                common.OFFSET: 2
            },
            crop_size=[int(sz) for sz in FLAGS.eval_crop_size],
            atrous_rates=FLAGS.atrous_rates,
            output_stride=FLAGS.output_stride)

        # Set shape in order for tf.contrib.tfprof.model_analyzer to work properly.
        samples[common.IMAGE].set_shape([
            FLAGS.eval_batch_size,
            int(FLAGS.eval_crop_size[0]),
            int(FLAGS.eval_crop_size[1]), 3
        ])

        if tuple(FLAGS.eval_scales) == (1.0, ):
            tf.logging.info('Performing single-scale test.')
            predictions = model.predict_labels(
                samples[common.IMAGE],
                model_options,
                image_pyramid=FLAGS.image_pyramid)
        else:
            tf.logging.info('Performing multi-scale test.')
            if FLAGS.quantize_delay_step >= 0:
                raise ValueError(
                    'Quantize mode is not supported with multi-scale test.')
            predictions = model.predict_labels_multi_scale(
                samples[common.IMAGE],
                model_options=model_options,
                eval_scales=FLAGS.eval_scales,
                add_flipped_images=FLAGS.add_flipped_images)

        predictions_semantic = predictions[common.OUTPUT_TYPE]
        predictions_center_points = predictions[common.INSTANCE]
        predictions_offset_vectors = predictions[common.OFFSET]

        # Group the predictions into instances: a pooling-based non-maximum
        # suppression (NMS) picks instance centers from the center heatmap,
        # and centers scoring below 0.1 are filtered out.
        instance_prediction = generate_instance_segmentation(
            predictions_semantic, predictions_center_points,
            predictions_offset_vectors)

        category_prediction = tf.squeeze(predictions_semantic)

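        # Mask out pixels equal to the ignore label (255) in both the labels
        # and the predictions so they do not contribute to the metrics.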
        category_label = tf.squeeze(samples[common.LABEL][0])
        not_ignore_mask = tf.not_equal(category_label, 255)
        category_label = tf.cast(
            category_label * tf.cast(not_ignore_mask, tf.int32), tf.int32)
        instance_label = tf.squeeze(samples[common.LABEL_INSTANCE_IDS][0])
        category_prediction = category_prediction * tf.cast(
            not_ignore_mask, tf.int64)
        instance_prediction = instance_prediction * tf.cast(
            not_ignore_mask, tf.int64)

        # Define the evaluation metric.
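        # Note: `offset` must exceed the largest possible panoptic id;
        # num_classes * max_instances_per_category = 19 * 256 = 4864, which
        # offset = 256 * 256 = 65536 comfortably satisfies.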
        metric_map = {}
        metric_map[
            'panoptic_quality'] = streaming_metrics.streaming_panoptic_quality(
                category_label,
                instance_label,
                category_prediction,
                instance_prediction,
                num_classes=19,
                max_instances_per_category=256,
                ignored_label=255,
                offset=256 * 256)
        metric_map[
            'parsing_covering'] = streaming_metrics.streaming_parsing_covering(
                category_label,
                instance_label,
                category_prediction,
                instance_prediction,
                num_classes=19,
                max_instances_per_category=256,
                ignored_label=255,
                offset=256 * 256,
                normalize_by_image_size=True)
        metrics_to_values, metrics_to_updates = slim.metrics.aggregate_metric_map(
            metric_map)

        summary_ops = []
        for metric_name, metric_value in metrics_to_values.items():
            if metric_name == 'panoptic_quality':
                [pq, sq, rq, total_tp, total_fn,
                 total_fp] = tf.unstack(metric_value, 6, axis=0)
                panoptic_metrics = {
                    # Panoptic quality.
                    'pq': pq,
                    # Segmentation quality.
                    'sq': sq,
                    # Recognition quality.
                    'rq': rq,
                    # Total true positives.
                    'total_tp': total_tp,
                    # Total false negatives.
                    'total_fn': total_fn,
                    # Total false positives.
                    'total_fp': total_fp,
                }
                # Find the valid classes that will be used for evaluation. We will
                # ignore the `ignore_label` class and other classes which have (tp + fn
                # + fp) equal to 0.
                valid_classes = tf.logical_and(
                    tf.not_equal(tf.range(0, dataset.num_of_classes),
                                 dataset.ignore_label),
                    tf.not_equal(total_tp + total_fn + total_fp, 0))
                for target_metric, target_value in panoptic_metrics.items():
                    output_metric_name = '{}_{}'.format(
                        metric_name, target_metric)
                    op = tf.summary.scalar(
                        output_metric_name,
                        tf.reduce_mean(
                            tf.boolean_mask(target_value, valid_classes)))
                    op = tf.Print(op, [target_value],
                                  output_metric_name + '_classwise: ',
                                  summarize=dataset.num_of_classes)
                    op = tf.Print(op, [
                        tf.reduce_mean(
                            tf.boolean_mask(target_value, valid_classes))
                    ],
                                  output_metric_name + '_mean: ',
                                  summarize=1)
                    summary_ops.append(op)
            elif metric_name == 'parsing_covering':
                [
                    per_class_covering, total_per_class_weighted_ious,
                    total_per_class_gt_areas
                ] = tf.unstack(metric_value, 3, axis=0)
                # Find the valid classes that will be used for evaluation. We will
                # ignore the `void_label` class and other classes which have
                # total_per_class_weighted_ious + total_per_class_gt_areas equal to 0.
                valid_classes = tf.logical_and(
                    tf.not_equal(tf.range(0, dataset.num_of_classes),
                                 dataset.ignore_label),
                    tf.not_equal(
                        total_per_class_weighted_ious +
                        total_per_class_gt_areas, 0))
                op = tf.summary.scalar(
                    metric_name,
                    tf.reduce_mean(
                        tf.boolean_mask(per_class_covering, valid_classes)))
                op = tf.Print(op, [per_class_covering],
                              metric_name + '_classwise: ',
                              summarize=dataset.num_of_classes)
                op = tf.Print(op, [
                    tf.reduce_mean(
                        tf.boolean_mask(per_class_covering, valid_classes))
                ],
                              metric_name + '_mean: ',
                              summarize=1)
                summary_ops.append(op)
            else:
                raise ValueError('The metric_name "%s" is not supported.' %
                                 metric_name)

        num_eval_iters = None
        if FLAGS.max_number_of_evaluations > 0:
            num_eval_iters = FLAGS.max_number_of_evaluations

        if FLAGS.quantize_delay_step >= 0:
            contrib_quantize.create_eval_graph()

        contrib_tfprof.model_analyzer.print_model_analysis(
            tf.get_default_graph(),
            tfprof_options=contrib_tfprof.model_analyzer.
            TRAINABLE_VARS_PARAMS_STAT_OPTIONS)
        contrib_tfprof.model_analyzer.print_model_analysis(
            tf.get_default_graph(),
            tfprof_options=contrib_tfprof.model_analyzer.FLOAT_OPS_OPTIONS)

        metric_values = slim.evaluation.evaluation_loop(
            master=FLAGS.master,
            checkpoint_dir=FLAGS.checkpoint_dir,
            logdir=FLAGS.eval_logdir,
            num_evals=20,
            eval_op=list(metrics_to_updates.values()),
            final_op=list(metrics_to_values.values()),
            summary_op=tf.summary.merge(summary_ops),
            max_number_of_evaluations=num_eval_iters,
            eval_interval_secs=FLAGS.eval_interval_secs)
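
Note: `generate_instance_segmentation`, used in Example #2, is not included in this excerpt. Below is a minimal sketch of the pooling-based NMS step described in the comments above (instance centers are local maxima of the center heatmap, and scores below 0.1 are discarded); the function name, window size, and top-k limit are assumptions for illustration, not the actual implementation.

import tensorflow as tf

def _pooling_nms_sketch(center_heatmap, pool_size=7, threshold=0.1, top_k=200):
    """Selects instance centers as thresholded local maxima (sketch).

    Args:
        center_heatmap: float32 tensor of shape [1, height, width, 1].
        pool_size: side length of the max-pooling window.
        threshold: centers scoring below this value are discarded.
        top_k: maximum number of centers to keep.

    Returns:
        (scores, flat_indices) of the selected centers.
    """
    # A pixel is a local maximum iff max-pooling leaves its value unchanged.
    pooled = tf.nn.max_pool(
        center_heatmap,
        ksize=[1, pool_size, pool_size, 1],
        strides=[1, 1, 1, 1],
        padding='SAME')
    is_peak = tf.equal(center_heatmap, pooled)
    # Zero out non-peaks, then peaks whose score is below the threshold.
    peaks = tf.where(is_peak, center_heatmap, tf.zeros_like(center_heatmap))
    peaks = tf.where(peaks >= threshold, peaks, tf.zeros_like(peaks))
    # Keep only the top_k strongest centers.
    scores, flat_indices = tf.nn.top_k(tf.reshape(peaks, [-1]), k=top_k)
    return scores, flat_indices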
Example #3
def main(unused_argv):
    tf.logging.set_verbosity(tf.logging.INFO)

    # Get dataset-dependent information.
    dataset = data_generator.Dataset(
        dataset_name=FLAGS.dataset,
        split_name=FLAGS.vis_split,
        dataset_dir=FLAGS.dataset_dir,
        batch_size=FLAGS.vis_batch_size,
        crop_size=[int(sz) for sz in FLAGS.vis_crop_size],
        min_resize_value=FLAGS.min_resize_value,
        max_resize_value=FLAGS.max_resize_value,
        resize_factor=FLAGS.resize_factor,
        model_variant=FLAGS.model_variant,
        is_training=False,
        should_shuffle=False,
        should_repeat=False)

    train_id_to_eval_id = None
    if dataset.dataset_name == data_generator.get_cityscapes_dataset_name():
        tf.logging.info('Cityscapes requires converting train_id to eval_id.')
        train_id_to_eval_id = _CITYSCAPES_TRAIN_ID_TO_EVAL_ID

    # Prepare for visualization.
    tf.gfile.MakeDirs(FLAGS.vis_logdir)

    save_dir = os.path.join(FLAGS.vis_logdir, _SEMANTIC_PREDICTION_SAVE_FOLDER)
    tf.gfile.MakeDirs(save_dir)

    instance_save_dir = os.path.join(FLAGS.vis_logdir,
                                     _INSTANCE_PREDICTION_SAVE_FOLDER)
    tf.gfile.MakeDirs(instance_save_dir)

    regression_save_dir = os.path.join(FLAGS.vis_logdir,
                                       _OFFSET_PREDICTION_SAVE_FOLDER)
    tf.gfile.MakeDirs(regression_save_dir)

    panoptic_save_dir = os.path.join(FLAGS.vis_logdir,
                                     _PANOPTIC_PREDICTION_SAVE_FOLDER)
    tf.gfile.MakeDirs(panoptic_save_dir)

    raw_save_dir = os.path.join(FLAGS.vis_logdir,
                                _RAW_SEMANTIC_PREDICTION_SAVE_FOLDER)
    tf.gfile.MakeDirs(raw_save_dir)

    tf.logging.info('Visualizing on %s set', FLAGS.vis_split)

    with tf.Graph().as_default():
        samples = dataset.get_one_shot_iterator().get_next()

        model_options = common.ModelOptions(
            outputs_to_num_classes={
                common.OUTPUT_TYPE: dataset.num_of_classes,
                common.INSTANCE: 1,
                common.OFFSET: 2
            },
            crop_size=[int(sz) for sz in FLAGS.vis_crop_size],
            atrous_rates=FLAGS.atrous_rates,
            output_stride=FLAGS.output_stride)

        if tuple(FLAGS.eval_scales) == (1.0, ):
            tf.logging.info('Performing single-scale test.')
            predictions = model.predict_labels(
                samples[common.IMAGE],
                model_options=model_options,
                image_pyramid=FLAGS.image_pyramid)
        else:
            tf.logging.info('Performing multi-scale test.')
            if FLAGS.quantize_delay_step >= 0:
                raise ValueError(
                    'Quantize mode is not supported with multi-scale test.')
            predictions = model.predict_labels_multi_scale(
                samples[common.IMAGE],
                model_options=model_options,
                eval_scales=FLAGS.eval_scales,
                add_flipped_images=FLAGS.add_flipped_images)

        predictions_semantic = predictions[common.OUTPUT_TYPE]
        predictions_instance = predictions[common.INSTANCE]
        predictions_regression = predictions[common.OFFSET]

        if FLAGS.min_resize_value and FLAGS.max_resize_value:
            # Only support batch_size = 1, since we assume the dimensions of original
            # image after tf.squeeze is [height, width, 3].
            assert FLAGS.vis_batch_size == 1

            # Reverse the resizing and padding operations performed in preprocessing.
            # First, we slice the valid regions (i.e., remove padded region) and then
            # we resize the predictions back.
            original_image = tf.squeeze(samples[common.ORIGINAL_IMAGE])
            original_image_shape = tf.shape(original_image)
            predictions_semantic = tf.slice(
                predictions_semantic, [0, 0, 0],
                [1, original_image_shape[0], original_image_shape[1]])
            resized_shape = tf.to_int32([
                tf.squeeze(samples[common.HEIGHT]),
                tf.squeeze(samples[common.WIDTH])
            ])
            predictions_semantic = tf.squeeze(
                tf.image.resize_images(
                    tf.expand_dims(predictions_semantic, 3),
                    resized_shape,
                    method=tf.image.ResizeMethod.NEAREST_NEIGHBOR,
                    align_corners=True), 3)
            ############################### POST PROCESSING LOGITS FROM INSTANCE CENTER #####################
            predictions_instance = tf.slice(
                predictions_instance, [0, 0, 0],
                [1, original_image_shape[0], original_image_shape[1]])
            resized_shape = tf.to_int32([
                tf.squeeze(samples[common.HEIGHT]),
                tf.squeeze(samples[common.WIDTH])
            ])
            predictions_instance = tf.squeeze(
                tf.image.resize_images(
                    tf.expand_dims(predictions_instance, 3),
                    resized_shape,
                    method=tf.image.ResizeMethod.NEAREST_NEIGHBOR,
                    align_corners=True), 3)

            ############################### POST PROCESSING LOGITS FROM INSTANCE REGRESSION #####################
            # Keep both offset channels, then resize back to the original
            # resolution.
            predictions_regression = tf.slice(
                predictions_regression, [0, 0, 0, 0],
                [1, original_image_shape[0], original_image_shape[1], 2])
            resized_shape = tf.to_int32([
                tf.squeeze(samples[common.HEIGHT]),
                tf.squeeze(samples[common.WIDTH])
            ])
            predictions_regression = tf.image.resize_images(
                predictions_regression,
                resized_shape,
                method=tf.image.ResizeMethod.NEAREST_NEIGHBOR,
                align_corners=True)

        tf.train.get_or_create_global_step()
        if FLAGS.quantize_delay_step >= 0:
            contrib_quantize.create_eval_graph()

        num_iteration = 0
        max_num_iteration = FLAGS.max_number_of_iterations

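    # checkpoints_iterator blocks until a new checkpoint appears in
    # FLAGS.checkpoint_dir, so each loop iteration visualizes one newly
    # written checkpoint.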
        checkpoints_iterator = contrib_training.checkpoints_iterator(
            FLAGS.checkpoint_dir, min_interval_secs=FLAGS.eval_interval_secs)
        for checkpoint_path in checkpoints_iterator:
            num_iteration += 1
            tf.logging.info('Starting visualization at ' +
                            time.strftime('%Y-%m-%d-%H:%M:%S', time.gmtime()))
            tf.logging.info('Visualizing with model %s', checkpoint_path)

            scaffold = tf.train.Scaffold(
                init_op=tf.global_variables_initializer())
            session_creator = tf.train.ChiefSessionCreator(
                scaffold=scaffold,
                master=FLAGS.master,
                checkpoint_filename_with_path=checkpoint_path)
            with tf.train.MonitoredSession(session_creator=session_creator,
                                           hooks=None) as sess:
                batch = 0
                image_id_offset = 0

                while not sess.should_stop():
                    tf.logging.info('Visualizing batch %d', batch + 1)
                    _process_batch(
                        sess=sess,
                        original_images=samples[common.ORIGINAL_IMAGE],
                        semantic_predictions=predictions_semantic,
                        instance_predictions=predictions_instance,
                        regression_predictions=predictions_regression,
                        image_names=samples[common.IMAGE_NAME],
                        image_heights=samples[common.HEIGHT],
                        image_widths=samples[common.WIDTH],
                        image_id_offset=image_id_offset,
                        save_dir=save_dir,
                        instance_save_dir=instance_save_dir,
                        regression_save_dir=regression_save_dir,
                        panoptic_save_dir=panoptic_save_dir,
                        raw_save_dir=raw_save_dir,
                        train_id_to_eval_id=train_id_to_eval_id)
                    image_id_offset += FLAGS.vis_batch_size
                    batch += 1

            tf.logging.info('Finished visualization at ' +
                            time.strftime('%Y-%m-%d-%H:%M:%S', time.gmtime()))
            if max_num_iteration > 0 and num_iteration >= max_num_iteration:
                break
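
Note: `_process_batch`, called in the visualization loop above, is not included in this excerpt. Below is a minimal sketch, assuming it evaluates the prediction tensors and writes one color-mapped PNG per image via deeplab's `save_annotation` utility; `_PREDICTION_FORMAT` and the omission of the instance/offset/panoptic outputs are simplifications for illustration.

import numpy as np

from deeplab.utils import save_annotation

def _process_batch(sess, original_images, semantic_predictions,
                   instance_predictions, regression_predictions, image_names,
                   image_heights, image_widths, image_id_offset, save_dir,
                   instance_save_dir, regression_save_dir, panoptic_save_dir,
                   raw_save_dir, train_id_to_eval_id=None):
    """Runs one batch and writes the semantic predictions to disk (sketch)."""
    (images, semantic, heights, widths) = sess.run([
        original_images, semantic_predictions, image_heights, image_widths])

    for i in range(semantic.shape[0]):
        height = np.squeeze(heights[i])
        width = np.squeeze(widths[i])
        # Crop away the padded region added during preprocessing.
        semantic_prediction = np.squeeze(semantic[i])[:height, :width]
        if train_id_to_eval_id is not None:
            # Map Cityscapes train ids back to eval ids before saving.
            semantic_prediction = np.asarray(
                train_id_to_eval_id)[semantic_prediction]
        # Save the prediction as a color-mapped PNG; saving the instance,
        # offset, and panoptic outputs is omitted in this sketch.
        save_annotation.save_annotation(
            semantic_prediction, save_dir,
            _PREDICTION_FORMAT % (image_id_offset + i),  # assumed name
            add_colormap=True)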