def model_fn(features, labels, mode, params):
    """Defines the model compatible with tf.estimator."""
    del labels, params
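    # labels and params are unused here. Note that only TRAIN mode is handled
    # below; for EVAL or PREDICT this function falls through and returns None.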
    if mode == tf.estimator.ModeKeys.TRAIN:
        _build_deeplab(
            features, model.get_output_to_num_classes(FLAGS),
            model_input.dataset_descriptors[FLAGS.dataset].ignore_label)

        # Print the objective loss and the regularization loss separately, to
        # help track down NaN-loss issues.
        objective_losses = tf.losses.get_losses()
        objective_losses = tf.Print(objective_losses, [objective_losses],
                                    message='Objective Losses: ',
                                    summarize=100)
        objective_loss = tf.reduce_sum(objective_losses)
        tf.summary.scalar('objective_loss', objective_loss)

        reg_losses = tf.losses.get_regularization_losses()
        reg_losses = tf.Print(reg_losses, [reg_losses],
                              message='Reg Losses: ',
                              summarize=100)
        reg_loss = tf.reduce_sum(reg_losses)
        tf.summary.scalar('regularization_loss', reg_loss)

        loss = objective_loss + reg_loss

        learning_rate = train_utils.get_model_learning_rate(
            FLAGS.learning_policy, FLAGS.base_learning_rate,
            FLAGS.learning_rate_decay_step, FLAGS.learning_rate_decay_factor,
            FLAGS.training_number_of_steps, FLAGS.learning_power,
            FLAGS.slow_start_step, FLAGS.slow_start_learning_rate)
        optimizer = tf.train.AdamOptimizer(learning_rate)
        tf.summary.scalar('learning_rate', learning_rate)

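        # Group the gradient step with any ops registered in UPDATE_OPS (e.g.
        # batch-norm moving-average updates) so they run on every training
        # step; train_op then simply forwards the loss once the grouped
        # updates have executed.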
        grads_and_vars = optimizer.compute_gradients(loss)
        grad_updates = optimizer.apply_gradients(grads_and_vars,
                                                 tf.train.get_global_step())
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        update_ops.append(grad_updates)
        update_op = tf.group(*update_ops)
        with tf.control_dependencies([update_op]):
            train_op = tf.identity(loss, name='train_op')

        return tf.estimator.EstimatorSpec(
            mode=mode,
            loss=loss,
            train_op=train_op,
        )
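

# A minimal usage sketch (not part of the original example): wiring the
# model_fn above into tf.estimator. FLAGS.train_logdir is an assumed flag
# name; substitute whichever flag your pipeline defines for the training
# directory.
def _train_sketch():
    estimator = tf.estimator.Estimator(
        model_fn=model_fn, model_dir=FLAGS.train_logdir)
    estimator.train(
        input_fn=model_input.get_input_fn(FLAGS),
        max_steps=FLAGS.training_number_of_steps)
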
Example #2
def main(unused_argv):
    FLAGS.comb_dropout_keep_prob = 1.0
    FLAGS.image_keep_prob = 1.0
    FLAGS.elements_keep_prob = 1.0

    # Get dataset-dependent information.

    tf.gfile.MakeDirs(FLAGS.eval_logdir)
    tf.logging.info('Evaluating on %s set', FLAGS.split)

    with tf.Graph().as_default():
        samples = model_input.get_input_fn(FLAGS)()

        # Get model segmentation predictions.
        num_classes = model_input.dataset_descriptors[
            FLAGS.dataset].num_classes
        output_to_num_classes = model.get_output_to_num_classes(FLAGS)

        if tuple(FLAGS.eval_scales) == (1.0, ):
            tf.logging.info('Performing single-scale test.')
            predictions, probs = model.predict_labels(
                samples['image'],
                samples,
                FLAGS,
                outputs_to_num_classes=output_to_num_classes,
                image_pyramid=FLAGS.image_pyramid,
                merge_method=FLAGS.merge_method,
                atrous_rates=FLAGS.atrous_rates,
                add_image_level_feature=FLAGS.add_image_level_feature,
                aspp_with_batch_norm=FLAGS.aspp_with_batch_norm,
                aspp_with_separable_conv=FLAGS.aspp_with_separable_conv,
                multi_grid=FLAGS.multi_grid,
                depth_multiplier=FLAGS.depth_multiplier,
                output_stride=FLAGS.output_stride,
                decoder_output_stride=FLAGS.decoder_output_stride,
                decoder_use_separable_conv=FLAGS.decoder_use_separable_conv,
                crop_size=[FLAGS.image_size, FLAGS.image_size],
                logits_kernel_size=FLAGS.logits_kernel_size,
                model_variant=FLAGS.model_variant)
        else:
            tf.logging.info('Performing multi-scale test.')
            predictions, probs = model.predict_labels_multi_scale(
                samples['image'],
                samples,
                FLAGS,
                outputs_to_num_classes=output_to_num_classes,
                eval_scales=FLAGS.eval_scales,
                add_flipped_images=FLAGS.add_flipped_images,
                merge_method=FLAGS.merge_method,
                atrous_rates=FLAGS.atrous_rates,
                add_image_level_feature=FLAGS.add_image_level_feature,
                aspp_with_batch_norm=FLAGS.aspp_with_batch_norm,
                aspp_with_separable_conv=FLAGS.aspp_with_separable_conv,
                multi_grid=FLAGS.multi_grid,
                depth_multiplier=FLAGS.depth_multiplier,
                output_stride=FLAGS.output_stride,
                decoder_output_stride=FLAGS.decoder_output_stride,
                decoder_use_separable_conv=FLAGS.decoder_use_separable_conv,
                crop_size=[FLAGS.image_size, FLAGS.image_size],
                logits_kernel_size=FLAGS.logits_kernel_size,
                model_variant=FLAGS.model_variant)

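        # Build streaming metrics (value_op, update_op pairs) for each model
        # output head; they are aggregated and driven by the evaluation loop
        # below.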
        metric_map = {}
        for output in output_to_num_classes:
            output_predictions = predictions[output]
            output_probs = probs[output]
            if output == 'segment':
                output_predictions = tf.expand_dims(output_predictions, 3)
                if num_classes == 2:
                    labels = samples['label']

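                    # Foreground IoU on the hard predictions, plus a "soft"
                    # variant computed from the foreground probability channel
                    # (output_probs[..., 1:2]) instead of the argmax output.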
                    iou, weights = model.foreground_iou(
                        labels, output_predictions, FLAGS)
                    soft_iou, _ = model.foreground_iou(
                        labels, output_probs[:, :, :, 1:2], FLAGS)

                    metric_map['mIOU'] = tf.metrics.mean(iou)
                    metric_map['soft_mIOU'] = tf.metrics.mean(soft_iou)

                    high_prob_overlaps = calc_high_prob_overlaps(
                        labels, output_probs, weights)
                    metric_map['highestOverlaps'] = tf.metrics.mean(
                        high_prob_overlaps)

                    output_probs *= weights

                else:
                    output_predictions = tf.reshape(output_predictions,
                                                    shape=[-1])
                    labels = tf.reshape(samples['label'], shape=[-1])
                    weights = tf.to_float(
                        tf.not_equal(
                            labels, model_input.dataset_descriptors[
                                FLAGS.dataset].ignore_label))

                    # Set ignore_label regions to label 0, because
                    # metrics.mean_iou requires labels in the range
                    # [0, num_classes). The ignore_label regions still do not
                    # affect the metric, since their weights are 0.
                    labels = tf.where(
                        tf.equal(
                            labels, model_input.dataset_descriptors[
                                FLAGS.dataset].ignore_label),
                        tf.zeros_like(labels), labels)

                    predictions_tag = 'mIOU'
                    for eval_scale in FLAGS.eval_scales:
                        predictions_tag += '_' + str(eval_scale)
                    if FLAGS.add_flipped_images:
                        predictions_tag += '_flipped'

                    # Define the evaluation metric.
                    metric_map[predictions_tag] = slim.metrics.mean_iou(
                        output_predictions,
                        labels,
                        num_classes,
                        weights=weights)

                def label_summary(labels, weights, name):
                    tf.summary.image(
                        name,
                        tf.reshape(
                            tf.cast(
                                tf.to_float(labels * 255) /
                                tf.to_float(num_classes), tf.uint8) *
                            tf.cast(weights, tf.uint8),
                            [-1, FLAGS.image_size, FLAGS.image_size, 1]), 8)

                label_summary(labels, weights, 'label')
                label_summary(output_predictions, weights,
                              'output_predictions')
                # Note: despite the tag, this visualizes probabilities
                # (channel 1 of output_probs), not raw logits.
                tf.summary.image('logits',
                                 tf.expand_dims(output_probs[:, :, :, 1], 3))

            elif output == 'regression':
                labels = samples['label']
                ignore_mask = model.get_ignore_mask(labels, FLAGS)

                accurate = calc_accuracy_in_box(labels, output_probs,
                                                ignore_mask)
                metric_map['inBoxAccuracy'] = tf.metrics.mean(accurate)

        tf.summary.image('image', samples['image'], 8)

        metrics_to_values, metrics_to_updates = slim.metrics.aggregate_metric_map(
            metric_map)

        for metric_name, metric_value in metrics_to_values.items():
            metric_value = tf.Print(metric_value, [metric_value], metric_name)
            tf.summary.scalar(metric_name, metric_value)

        num_batches = int(
            math.ceil(FLAGS.num_samples / float(FLAGS.batch_size)))

        tf.logging.info('Eval num images %d', FLAGS.num_samples)
        tf.logging.info('Eval batch size %d and num batch %d',
                        FLAGS.batch_size, num_batches)

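        # evaluation_loop blocks indefinitely: every eval_interval_secs it
        # picks up the newest checkpoint in checkpoint_dir, runs the update
        # ops for num_evals batches, and writes the merged summaries.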
        slim.evaluation.evaluation_loop(
            master='',
            checkpoint_dir=FLAGS.checkpoint_dir,
            logdir=FLAGS.eval_logdir,
            num_evals=num_batches,
            eval_op=list(metrics_to_updates.values()),
            summary_op=tf.summary.merge_all(),
            max_number_of_evaluations=None,
            eval_interval_secs=FLAGS.eval_interval_secs)
Example #3
def main(unused_argv):
    # Get dataset-dependent information.
    # Prepare for visualization.
    tf.gfile.MakeDirs(FLAGS.vis_logdir)
    save_dir = os.path.join(FLAGS.vis_logdir, _SEMANTIC_PREDICTION_SAVE_FOLDER)
    tf.gfile.MakeDirs(save_dir)
    raw_save_dir = os.path.join(FLAGS.vis_logdir,
                                _RAW_SEMANTIC_PREDICTION_SAVE_FOLDER)
    tf.gfile.MakeDirs(raw_save_dir)
    num_vis_examples = FLAGS.num_vis_examples

    print('Visualizing on set', FLAGS.split)

    g = tf.Graph()
    with g.as_default():
        samples = model_input.get_input_fn(FLAGS)()
        outputs_to_num_classes = model.get_output_to_num_classes(FLAGS)

        # Get model segmentation predictions.
        if tuple(FLAGS.eval_scales) == (1.0, ):
            tf.logging.info('Performing single-scale test.')
            predictions, probs = model.predict_labels(
                samples['image'],
                samples,
                FLAGS,
                outputs_to_num_classes=outputs_to_num_classes,
                image_pyramid=FLAGS.image_pyramid,
                merge_method=FLAGS.merge_method,
                atrous_rates=FLAGS.atrous_rates,
                add_image_level_feature=FLAGS.add_image_level_feature,
                aspp_with_batch_norm=FLAGS.aspp_with_batch_norm,
                aspp_with_separable_conv=FLAGS.aspp_with_separable_conv,
                multi_grid=FLAGS.multi_grid,
                depth_multiplier=FLAGS.depth_multiplier,
                output_stride=FLAGS.output_stride,
                decoder_output_stride=FLAGS.decoder_output_stride,
                decoder_use_separable_conv=FLAGS.decoder_use_separable_conv,
                crop_size=[FLAGS.image_size, FLAGS.image_size],
                logits_kernel_size=FLAGS.logits_kernel_size,
                model_variant=FLAGS.model_variant)
        else:
            tf.logging.info('Performing multi-scale test.')
            predictions, probs = model.predict_labels_multi_scale(
                samples['image'],
                samples,
                FLAGS,
                outputs_to_num_classes=outputs_to_num_classes,
                eval_scales=FLAGS.eval_scales,
                add_flipped_images=FLAGS.add_flipped_images,
                merge_method=FLAGS.merge_method,
                atrous_rates=FLAGS.atrous_rates,
                add_image_level_feature=FLAGS.add_image_level_feature,
                aspp_with_batch_norm=FLAGS.aspp_with_batch_norm,
                aspp_with_separable_conv=FLAGS.aspp_with_separable_conv,
                multi_grid=FLAGS.multi_grid,
                depth_multiplier=FLAGS.depth_multiplier,
                output_stride=FLAGS.output_stride,
                decoder_output_stride=FLAGS.decoder_output_stride,
                decoder_use_separable_conv=FLAGS.decoder_use_separable_conv,
                crop_size=[FLAGS.image_size, FLAGS.image_size],
                logits_kernel_size=FLAGS.logits_kernel_size,
                model_variant=FLAGS.model_variant)

        if FLAGS.output_mode == 'segment':
            predictions = tf.squeeze(
                tf.cast(predictions[FLAGS.output_mode], tf.int32))
            probs = probs[FLAGS.output_mode]

            labels = tf.squeeze(tf.cast(samples['label'], tf.int32))
            weights = tf.cast(
                tf.not_equal(
                    labels, model_input.dataset_descriptors[
                        FLAGS.dataset].ignore_label), tf.int32)

            labels *= weights
            predictions *= weights

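            # Restore-only setup: the Supervisor manages session creation;
            # summaries and global-step tracking are disabled here.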
            tf.train.get_or_create_global_step()
            saver = tf.train.Saver(contrib_slim.get_variables_to_restore())
            sv = tf.train.Supervisor(graph=g,
                                     logdir=FLAGS.vis_logdir,
                                     init_op=tf.global_variables_initializer(),
                                     summary_op=None,
                                     summary_writer=None,
                                     global_step=None,
                                     saver=saver)
            num_batches = int(
                math.ceil(num_vis_examples / float(FLAGS.batch_size)))
            last_checkpoint = None

            # Loop forever, re-running visualization whenever a new checkpoint
            # appears.
            while True:
                last_checkpoint = contrib_slim.evaluation.wait_for_new_checkpoint(
                    FLAGS.checkpoint_dir, last_checkpoint)
                start = time.time()
                print('Starting visualization at ' +
                      time.strftime('%Y-%m-%d-%H:%M:%S', time.gmtime()))
                print('Visualizing with model', last_checkpoint)

                with sv.managed_session(FLAGS.master,
                                        start_standard_services=False) as sess:
                    # sv.start_queue_runners(sess)
                    sv.saver.restore(sess, last_checkpoint)

                    image_id_offset = 0
                    refs = []
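                    # NOTE: `refs` is accumulated here but not consumed within
                    # this snippet.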
                    for batch in range(num_batches):
                        print('Visualizing batch', batch + 1, num_batches)
                        refs.extend(
                            _process_batch(sess=sess,
                                           samples=samples,
                                           semantic_predictions=predictions,
                                           labels=labels,
                                           image_id_offset=image_id_offset,
                                           save_dir=save_dir))
                        image_id_offset += FLAGS.batch_size

                print('Finished visualization at ' +
                      time.strftime('%Y-%m-%d-%H:%M:%S', time.gmtime()))
                # Sleep until the next eval interval so we do not busy-wait
                # for a new checkpoint.
                time_to_next_eval = start + FLAGS.eval_interval_secs - time.time()
                if time_to_next_eval > 0:
                    time.sleep(time_to_next_eval)