Example 1
def main(argv):
    tf.logging.set_verbosity(tf.logging.INFO)
    common.print_args()

    dataset = generator.Dataset(
                dataset_dir=FLAGS.dataset_dir,
                dataset_name=FLAGS.dataset_name,
                split_name=FLAGS.split_name,
                batch_size=FLAGS.batch_size,
                crop_size=[int(val) for val in FLAGS.crop_size],
                min_resize_value=FLAGS.min_resize_value,
                max_resize_value=FLAGS.max_resize_value,
                resize_factor=FLAGS.resize_factor,
                min_scale_factor=FLAGS.min_scale_factor,
                max_scale_factor=FLAGS.max_scale_factor,
                scale_factor_step_size=FLAGS.scale_factor_step_size,
                is_training=False,
                model_variant=FLAGS.model_variant,
                should_shuffle=False,
                should_repeat=False)

    tf.logging.info('Evaluating on the %s set of %s',
                    ''.join(FLAGS.split_name), FLAGS.dataset_name)

    with tf.Graph().as_default():
        samples = dataset.get_one_shot_iterator().get_next()

        model = segModel.SegModel(
                num_classes=dataset.num_classes,
                model_variant=FLAGS.model_variant,
                output_stride=FLAGS.output_stride,
                backbone_atrous_rates=FLAGS.backbone_atrous_rates,
                is_training=False,
                ppm_rates=FLAGS.ppm_rates,
                ppm_pooling_type=FLAGS.ppm_pooling_type,
                atrous_rates=FLAGS.atrous_rates,
                module_order=FLAGS.module_order,
                decoder_output_stride=FLAGS.decoder_output_stride)

        if FLAGS.eval_scales == [1.0]:
            tf.logging.info('Evaluating at a single scale.')
            predictions, _ = model.predict_labels(
                images=samples[common.IMAGE],
                add_flipped_images=FLAGS.add_flipped_images)
        else:
            tf.logging.info('Evaluating at multiple scales.')
            predictions, _ = model.predict_labels_for_multiscale(
                images=samples[common.IMAGE],
                add_flipped_images=FLAGS.add_flipped_images,
                eval_scales=FLAGS.eval_scales)
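        # Flatten predictions and labels to 1-D for the streaming metric.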
        predictions = tf.reshape(predictions, shape=[-1])
        labels = tf.reshape(samples[common.LABEL], shape=[-1])
        weights = tf.to_float(tf.not_equal(labels, dataset.ignore_label))

        # Set ignore_label regions to label 0, because metrics.mean_iou requires
        # range of labels = [0, dataset.num_classes). Note the ignore_label regions
        # are not evaluated since the corresponding regions contain weights = 0.
        labels = tf.where(
            tf.equal(labels, dataset.ignore_label), tf.zeros_like(labels), labels)

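        # tf.metrics.mean_iou streams a confusion matrix via update_op and
        # averages per-class IoU = TP / (TP + FP + FN) over classes. Passing
        # (predictions, labels) instead of (labels, predictions) merely
        # transposes the confusion matrix, which leaves the IoU unchanged.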
        miou, update_op = tf.metrics.mean_iou(
            predictions, labels, dataset.num_classes, weights=weights)
        miou_text = '{}_miou_f_{}_multiscale_{}'.format(
                        ''.join(FLAGS.split_name),
                        int(FLAGS.add_flipped_images),
                        int(len(FLAGS.eval_scales) > 1))
        tf.summary.scalar(miou_text, miou)

        summary_op = tf.summary.merge_all()
        # Processing to run once per evaluation pass (epoch): write the
        # merged summaries at the end of evaluation.
        summary_hook = tf.contrib.training.SummaryAtEndHook(
            log_dir=FLAGS.eval_logdir, summary_op=summary_op)
        hooks = [summary_hook]

        num_eval = None
        if FLAGS.max_number_of_evaluations > 0:
            num_eval = FLAGS.max_number_of_evaluations
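        # With num_eval=None, evaluate_repeatedly keeps evaluating new
        # checkpoints indefinitely.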

        tf.contrib.tfprof.model_analyzer.print_model_analysis(
            tf.get_default_graph(),
            tfprof_options=tf.contrib.tfprof.model_analyzer.
            TRAINABLE_VARS_PARAMS_STAT_OPTIONS)
        tf.contrib.tfprof.model_analyzer.print_model_analysis(
            tf.get_default_graph(),
            tfprof_options=tf.contrib.tfprof.model_analyzer.FLOAT_OPS_OPTIONS)
        tf.contrib.training.evaluate_repeatedly(
            master='',
            checkpoint_dir=FLAGS.checkpoint_dir,
            eval_ops=[update_op],
            max_number_of_evaluations=num_eval,
            hooks=hooks,
            eval_interval_secs=FLAGS.eval_interval_secs)
Example 2
def main(unused_args):
    tf.logging.set_verbosity(tf.logging.INFO)
    tf.logging.info('Prepare to export model to %s', FLAGS.export_path)
    common.print_args()

    with tf.Graph().as_default():
        image, image_size, resized_image_size = _create_input_tensor()

        model = segModel.SegModel(
            num_classes=generator._DATASETS_INFORMATION[
                FLAGS.dataset_name]['num_classes'],
            model_variant=FLAGS.model_variant,
            output_stride=FLAGS.output_stride,
            fine_tune_batch_norm=False,
            backbone_atrous_rates=FLAGS.backbone_atrous_rates,
            is_training=False,
            ppm_rates=FLAGS.ppm_rates,
            ppm_pooling_type=FLAGS.ppm_pooling_type,
            atrous_rates=FLAGS.atrous_rates,
            module_order=FLAGS.module_order,
            decoder_output_stride=FLAGS.decoder_output_stride)

        if FLAGS.inference_scales == [1.0]:
            tf.logging.info('Evaluating at a single scale.')
            predictions, probabilities = model.predict_labels(
                images=image, add_flipped_images=FLAGS.add_flipped_images)
        else:
            tf.logging.info('Evaluating at multiple scales.')
            predictions, probabilities = model.predict_labels_for_multiscale(
                images=image,
                add_flipped_images=FLAGS.add_flipped_images,
                eval_scales=FLAGS.inference_scales)
        # predictions holds the per-pixel class ids (the output of argmax).
        raw_predictions = tf.identity(tf.cast(predictions, tf.float32),
                                      _RAW_OUTPUT_NAME)
        # probabilities holds the per-class scores (before the argmax).
        raw_probabilities = tf.identity(probabilities, _RAW_OUTPUT_PROB_NAME)

        # Crop the valid regions from the predictions.
        semantic_predictions = raw_predictions[
            :, :resized_image_size[0], :resized_image_size[1]]
        semantic_probabilities = raw_probabilities[
            :, :resized_image_size[0], :resized_image_size[1]]

        # Resize back the prediction to the original image size.
        def _resize_label(label, label_size):
            # Expand dimension of label to [1, height, width, 1] for resize operation.
            label = tf.expand_dims(label, 3)
            resized_label = tf.image.resize_images(
                label,
                label_size,
                method=tf.image.ResizeMethod.NEAREST_NEIGHBOR,
                align_corners=True)
            return tf.cast(tf.squeeze(resized_label, 3), tf.int32)

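        # _resize_label uses nearest neighbor so the class ids stay discrete;
        # the probability maps below are resized bilinearly instead.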
        semantic_predictions = _resize_label(semantic_predictions, image_size)
        semantic_predictions = tf.identity(semantic_predictions,
                                           name=_OUTPUT_NAME)

        semantic_probabilities = tf.image.resize_bilinear(
            semantic_probabilities,
            image_size,
            align_corners=True,
            name=_OUTPUT_PROB_NAME)

        saver = tf.train.Saver(tf.global_variables())

        dirname = os.path.dirname(FLAGS.export_path)
        tf.gfile.MakeDirs(dirname)
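        # Freeze the graph: fold the checkpoint weights into constants and
        # write a single self-contained GraphDef to FLAGS.export_path.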
        graph_def = tf.get_default_graph().as_graph_def(add_shapes=True)
        freeze_graph.freeze_graph_with_def_protos(
            graph_def,
            saver.as_saver_def(),
            FLAGS.checkpoint_path,
            _OUTPUT_NAME + ',' + _OUTPUT_PROB_NAME,
            restore_op_name=None,
            filename_tensor_name=None,
            output_graph=FLAGS.export_path,
            clear_devices=True,
            initializer_nodes=None)

        if FLAGS.save_inference_graph:
            tf.train.write_graph(graph_def, dirname, 'inference_graph.pbtxt')
Example 3
def main(argv):
    tf.logging.set_verbosity(tf.logging.INFO)
    common.print_args()

    tf.logging.info('Training on the %s set of %s',
                    ','.join(FLAGS.split_name), FLAGS.dataset_name)

    graph = tf.Graph()
    crop_size = [int(val) for val in FLAGS.crop_size]
    # Build the training graph.
    with graph.as_default():
        # create dataset generator
        dataset = generator.Dataset(
            dataset_dir=FLAGS.dataset_dir,
            dataset_name=FLAGS.dataset_name,
            split_name=FLAGS.split_name,
            batch_size=FLAGS.batch_size,
            crop_size=crop_size,
            min_resize_value=FLAGS.min_resize_value,
            max_resize_value=FLAGS.max_resize_value,
            resize_factor=FLAGS.resize_factor,
            min_scale_factor=FLAGS.min_scale_factor,
            max_scale_factor=FLAGS.max_scale_factor,
            scale_factor_step_size=FLAGS.scale_factor_step_size,
            is_training=True,
            model_variant=FLAGS.model_variant,
            should_shuffle=True,
            should_repeat=True)

        iterator = dataset.get_one_shot_iterator()
        global_step = tf.train.get_or_create_global_step()

        # create learning_rate
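        # The 'poly' schedule: lr = base_lr * (1 - step / train_steps)**0.9,
        # decaying to 0 at train_steps. E.g. with hypothetical values
        # base_learning_rate=0.007 and train_steps=90000, at step 45000:
        # lr = 0.007 * 0.5**0.9 ~= 0.00375.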
        learning_rate = tf.train.polynomial_decay(
            learning_rate=FLAGS.base_learning_rate,
            global_step=global_step,
            decay_steps=FLAGS.train_steps,
            end_learning_rate=0,
            power=0.9)

        # create optimizer
        optimizer = tf.train.MomentumOptimizer(learning_rate, FLAGS.momentum)

        # build models
        with tf.name_scope('clone') as scope:
            with tf.variable_scope(tf.get_variable_scope(), reuse=False):
                samples = iterator.get_next()
                inputs = tf.identity(samples[common.IMAGE], name='Input_Image')
                labels = tf.identity(samples[common.LABEL],
                                     name='Semantic_Label')

                global model
                model = segModel.SegModel(
                    num_classes=dataset.num_classes,
                    model_variant=FLAGS.model_variant,
                    output_stride=FLAGS.output_stride,
                    fine_tune_batch_norm=FLAGS.fine_tune_batch_norm,
                    weight_decay=FLAGS.weight_decay,
                    backbone_atrous_rates=FLAGS.backbone_atrous_rates,
                    is_training=True,
                    ppm_rates=FLAGS.ppm_rates,
                    ppm_pooling_type=FLAGS.ppm_pooling_type,
                    atrous_rates=FLAGS.atrous_rates,
                    module_order=FLAGS.module_order,
                    decoder_output_stride=FLAGS.decoder_output_stride)

                logits = model.build(images=inputs)
                logits = tf.identity(logits, name='dense_prediction')
                logits, labels = resize_logits_or_labels(logits, labels)

                add_softmax_cross_entropy_loss(logits,
                                               labels,
                                               dataset.num_classes,
                                               dataset.ignore_label,
                                               loss_weight=1.0)

                log_summaries(inputs, labels, dataset.num_classes, logits,
                              dataset.ignore_label)

            # Collect the losses from this scope and total them.
            losses = tf.losses.get_losses(scope=scope)
            total_loss = get_total_loss(losses, global_step, scope)

            grads = optimizer.compute_gradients(total_loss)

            grad_updates = optimizer.apply_gradients(grads,
                                                     global_step=global_step)

            # Gather update_ops. These contain, for example,
            # the updates for the batch_norm variables created by model_fn.
            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
            update_ops.append(grad_updates)
            update_op = tf.group(*update_ops)

            # Print the total loss to the terminal every log_steps steps.
            total_loss = tf.cond(
                tf.equal(tf.mod(global_step, FLAGS.log_steps), 0),
                lambda: _print_tensor('total loss :', total_loss),
                lambda: total_loss)

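            # Fetching train_tensor also runs update_op (the gradient step and
            # the batch-norm updates) through the control dependency.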
            with tf.control_dependencies([update_op]):
                train_tensor = tf.identity(total_loss, name='train_op')

        summary_op = tf.summary.merge_all(scope='clone')

        # Soft placement allows placing on CPU ops without GPU implementation.
        session_config = tf.ConfigProto(allow_soft_placement=True,
                                        log_device_placement=False)
        init_fn = None
        if FLAGS.tf_initial_checkpoint:
            init_fn = get_model_init_fn(
                train_logdir=FLAGS.train_logdir,
                tf_initial_checkpoint=FLAGS.tf_initial_checkpoint,
                ignore_missing_vars=True)

        scaffold = tf.train.Scaffold(init_fn=init_fn, summary_op=summary_op)

        stop_hook = tf.train.StopAtStepHook(last_step=FLAGS.train_steps)

        if FLAGS.save_summaries_secs <= 0:
            FLAGS.save_summaries_secs = None

        with tf.train.MonitoredTrainingSession(
                checkpoint_dir=FLAGS.train_logdir,
                hooks=[stop_hook],
                # config=session_config,
                scaffold=scaffold,
                summary_dir=FLAGS.train_logdir,
                log_step_count_steps=FLAGS.log_steps,
                save_summaries_secs=FLAGS.save_summaries_secs,
                save_checkpoint_secs=FLAGS.save_interval_secs) as sess:
            num_steps = 0
            while not sess.should_stop():
                sess.run([train_tensor])
                if num_steps % FLAGS.log_steps == 0:
                    sys.stdout.write('\n')
                    sys.stdout.flush()
                num_steps += 1
Example 4
def main(argv):
    tf.logging.set_verbosity(tf.logging.INFO)
    common.print_args()

    # Create dataset generator
    dataset = generator.Dataset(
        dataset_dir=FLAGS.dataset_dir,
        dataset_name=FLAGS.dataset_name,
        split_name=FLAGS.split_name,
        batch_size=1,
        crop_size=[int(val) for val in FLAGS.crop_size],
        min_resize_value=FLAGS.min_resize_value,
        max_resize_value=FLAGS.max_resize_value,
        resize_factor=FLAGS.resize_factor,
        min_scale_factor=FLAGS.min_scale_factor,
        max_scale_factor=FLAGS.max_scale_factor,
        scale_factor_step_size=FLAGS.scale_factor_step_size,
        is_training=False,
        model_variant=FLAGS.model_variant,
        should_shuffle=False,
        should_repeat=False)

    tf.gfile.MakeDirs(FLAGS.vis_logdir)
    miou_text = '_flip_{}_multiscale_{}'.format(
        int(FLAGS.add_flipped_images), int(len(FLAGS.eval_scales) > 1))
    save_dir = os.path.join(FLAGS.vis_logdir,
                            _SEMANTIC_PREDICTION_DIR + miou_text)
    tf.gfile.MakeDirs(save_dir)

    tf.logging.info('Visualizing on %s set', ''.join(FLAGS.split_name))

    with tf.Graph().as_default():
        iterator = dataset.get_one_shot_iterator()
        samples = iterator.get_next()

        model = segModel.SegModel(
            num_classes=dataset.num_classes,
            model_variant=FLAGS.model_variant,
            output_stride=FLAGS.output_stride,
            backbone_atrous_rates=FLAGS.backbone_atrous_rates,
            is_training=False,
            ppm_rates=FLAGS.ppm_rates,
            ppm_pooling_type=FLAGS.ppm_pooling_type,
            atrous_rates=FLAGS.atrous_rates,
            module_order=FLAGS.module_order,
            decoder_output_stride=FLAGS.decoder_output_stride)

        if FLAGS.eval_scales == [1.0]:
            tf.logging.info('Evaluating at a single scale.')
            predictions, _ = model.predict_labels(
                images=samples[common.IMAGE],
                add_flipped_images=FLAGS.add_flipped_images)
        else:
            tf.logging.info('Evaluating at multiple scales.')
            predictions, _ = model.predict_labels_for_multiscale(
                images=samples[common.IMAGE],
                add_flipped_images=FLAGS.add_flipped_images,
                eval_scales=FLAGS.eval_scales)

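        # checkpoints_iterator blocks until a new checkpoint appears in
        # checkpoint_dir, yielding each path at most once per min_interval_secs.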
        checkpoints_iterator = tf.contrib.training.checkpoints_iterator(
            FLAGS.checkpoint_dir, min_interval_secs=FLAGS.eval_interval_secs)

        if FLAGS.log_confusion:
            # Get the confusion matrix op.
            conf_mat_op, conf_mat, batch_conf = create_conf_mat_op(
                samples[common.LABEL], predictions, dataset.num_classes,
                dataset.ignore_label)
            # Initializer for initializing the conf_mat.
            init_fn = tf.variables_initializer([conf_mat])
            # restore from checkpoint excluding variable "conf_mat".
            variables_to_restore = tf.global_variables()[:-1]
        else:
            # dummy op.
            conf_mat_op, conf_mat, batch_conf = None, None, None
            init_fn = None
            variables_to_restore = tf.global_variables()

        num_iteration = 0
        max_num_iteration = FLAGS.max_number_of_iterations

        for checkpoint_path in checkpoints_iterator:
            num_iteration += 1
            tf.logging.info('Starting visualization at ' +
                            time.strftime('%Y-%m-%d-%H:%M:%S', time.gmtime()))
            tf.logging.info('Visualizing with model %s', checkpoint_path)
            restorer = tf.train.Saver(variables_to_restore)
            scaffold = tf.train.Scaffold(
                init_op=tf.global_variables_initializer(),
                saver=restorer,
                ready_for_local_init_op=init_fn)
            session_creator = tf.train.ChiefSessionCreator(
                scaffold=scaffold,
                master='',
                checkpoint_filename_with_path=checkpoint_path)
            with tf.train.MonitoredSession(session_creator=session_creator,
                                           hooks=None) as sess:
                batch = 0
                image_id_offset = 0

                while not sess.should_stop():
                    tf.logging.info('Visualizing batch %d', batch + 1)
                    accumulated_conf_mat = _process_batch(
                        sess=sess,
                        samples=samples,
                        predictions=predictions,
                        image_id_offset=image_id_offset,
                        save_dir=save_dir,
                        conf_mat_op=conf_mat_op,
                        conf_mat=conf_mat,
                        batch_conf=batch_conf)
                    image_id_offset += 1
                    batch += 1

            if FLAGS.log_confusion:
                summary_conf_mat(accumulated_conf_mat, save_dir)

            tf.logging.info('Finished visualization at ' +
                            time.strftime('%Y-%m-%d-%H:%M:%S', time.gmtime()))

            if max_num_iteration > 0 and num_iteration >= max_num_iteration:
                break
Example 5
def main(argv):
    common.print_args()
    tf.logging.set_verbosity(tf.logging.DEBUG)
    crop_size = [int(FLAGS.crop_size[0]), int(FLAGS.crop_size[1])]
    dataset = generator.Dataset(FLAGS.dataset_dir,
                                FLAGS.dataset_name,
                                FLAGS.split_name,
                                10,
                                crop_size,
                                min_resize_value=FLAGS.min_resize_value,
                                max_resize_value=FLAGS.max_resize_value,
                                resize_factor=FLAGS.resize_factor,
                                min_scale_factor=FLAGS.min_scale_factor,
                                max_scale_factor=FLAGS.max_scale_factor,
                                scale_factor_step_size=FLAGS.scale_factor_step_size,
                                is_training=FLAGS.is_training,
                                model_variant=FLAGS.model_variant,
                                should_shuffle=False)

    iterator = dataset.get_one_shot_iterator()
    samples = iterator.get_next()
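    # A one-shot iterator needs no explicit initialization; each
    # sess.run(samples) below pulls the next batch from the pipeline.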

    sess = tf.Session()

    print('=' * 10 + 'samples content' + '=' * 10)
    for key, val in samples.items():
        print('samples[{}]: {}'.format(key, val.shape))
    print('=' * 35)
    res = sess.run(samples)
    print(res.keys())

    image = res[common.IMAGE]
    ann = res[common.LABEL]
    original_height = res[common.IMAGE_HEIGHT]
    original_width  = res[common.IMAGE_WIDTH]
    print(original_height)
    print(original_width)

    print('image.shape => ', image.shape)
    print('ann.shape   => ', ann.shape)

    figure = plt.figure(figsize=(20, 10))
    num_imgs = 2
    gridspec_master = GridSpec(nrows=num_imgs, ncols=2)

    for i in range(num_imgs):
        grid_sub_1 = GridSpecFromSubplotSpec(nrows=1,
                                             ncols=1,
                                             subplot_spec=gridspec_master[i, 0])
        axes_1 = figure.add_subplot(grid_sub_1[:, :])
        axes_1.set_xticks([])
        axes_1.set_yticks([])
        axes_1.imshow(image[i].astype(np.uint8))
        axes_1.set_title('target')

        grid_sub_2 = GridSpecFromSubplotSpec(nrows=1,
                                             ncols=1,
                                             subplot_spec=gridspec_master[i, 1])
        axes_2 = figure.add_subplot(grid_sub_2[:, :])
        axes_2.set_xticks([])
        axes_2.set_yticks([])
        label = np.squeeze(ann[i], axis=2)
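        # Remap ignore pixels (255) to one above the highest class id so they
        # render as the brightest level in the grayscale plot.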
        label[label == 255] = np.max(label[label != 255]) + 1
        axes_2.imshow(label, cmap='gray')
        axes_2.set_title('label')
    plt.show()

    print(ann.shape)
Example 6
def main(argv):
    tf.logging.set_verbosity(tf.logging.INFO)
    common.print_args()

    EKF_VIS_DIR = 'EKF_RESULTS'
    ekf_vis_dir = os.path.join(FLAGS.ckpt_dir, EKF_VIS_DIR)

    if not os.path.exists(ekf_vis_dir):
        os.makedirs(ekf_vis_dir)

    with tf.Graph().as_default():
        dataset = generator.Dataset(
            dataset_dir=os.path.join('dataset', 'cityscapes', 'tfrecord'),
            dataset_name='cityscapes',
            split_name=['val_fine'],
            batch_size=1,
            crop_size=[HEIGHT, WIDTH],
            is_training=False,
            model_variant='resnet_v1_101_beta',
            min_scale_factor=0.50,
            max_scale_factor=2.0,
            should_shuffle=False,
            should_repeat=False)

        ite = dataset.get_one_shot_iterator()
        sample = ite.get_next()
        images, labels = sample[common.IMAGE], sample[common.LABEL]

        module_order = search_model()

        model = segModel.SegModel(
            num_classes=dataset.num_classes,
            model_variant='resnet_v1_101_beta',
            output_stride=16,
            backbone_atrous_rates=[1, 2, 4],
            is_training=False,
            ppm_rates=[1, 2, 3, 6],
            module_order=module_order,
            decoder_output_stride=4)

        logits = model.build(images=images)
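        # Upsample the logits to the input resolution so the mask and the
        # gradient map align pixel-for-pixel with the image.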
        logits = tf.image.resize_bilinear(
            logits, tf.shape(images)[1:3], align_corners=True)
        # logits = tf.nn.softmax(logits, axis=3)

        # Boolean masks selecting the rows/columns inside the target rectangle.
        height_ind = tf.range(HEIGHT, dtype=tf.int32)
        width_ind = tf.range(WIDTH, dtype=tf.int32)
        height_ind = tf.expand_dims(
            tf.logical_and(tf.greater_equal(height_ind, FLAGS.mask_heights[0]),
                           tf.less_equal(height_ind, FLAGS.mask_heights[1])), 1)
        width_ind = tf.expand_dims(
            tf.logical_and(tf.greater_equal(width_ind, FLAGS.mask_widths[0]),
                           tf.less_equal(width_ind, FLAGS.mask_widths[1])), 0)
        # height_ind = tf.expand_dims(tf.math.equal(height_ind, HEIGHT//4), 1)
        # width_ind = tf.expand_dims(tf.math.equal(width_ind, WIDTH//4), 0)

        # Broadcast the row/column masks over the image plane and across all
        # classes (tf.tile replaces the original concat-in-a-loop construction).
        height_map = tf.cast(tf.tile(height_ind, [1, WIDTH]), tf.float32)
        width_map = tf.cast(tf.tile(width_ind, [HEIGHT, 1]), tf.float32)
        mask = tf.expand_dims(tf.multiply(height_map, width_map), axis=2)
        mask = tf.tile(mask, [1, 1, dataset.num_classes])
        masked_logits = tf.multiply(mask, logits)
        grad = tf.gradients(masked_logits, [images])
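        # grad holds d(sum of masked logits)/d(images): a saliency map showing
        # which input pixels influence the masked output region (its effective
        # receptive field).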

        ########### SESSION CREATING PROCESS #############
        checkpoints_iterator = tf.contrib.training.checkpoints_iterator(
            FLAGS.ckpt_dir)

        for checkpoint_path in checkpoints_iterator:
            restorer = tf.train.Saver()
            scaffold = tf.train.Scaffold(init_op=tf.global_variables_initializer(),
                                         saver=restorer,
                                         ready_for_local_init_op=None)

            session_creator = tf.train.ChiefSessionCreator(
                scaffold=scaffold,
                master='',
                checkpoint_filename_with_path=checkpoint_path)

            with tf.train.MonitoredSession(
                    session_creator=session_creator, hooks=None) as sess:
                grad_list = []
                batch_num = 0
                while not sess.should_stop():
                    im, l, g = sess.run([images, logits, grad])
                    grad_list.append(g[0])
                    im = im[0]
                    g = np.abs(g[0][0])

                    if batch_num == 0:
                        sum_g = g
                    else:
                        sum_g += g
                    batch_num += 1

                    g = normalizeImg(g)
                    im = normalizeImg(im)

                    img_path = os.path.join(ekf_vis_dir, '{}_img.jpg'.format(batch_num))
                    grad_path = os.path.join(ekf_vis_dir, '{}_grad.jpg'.format(batch_num))

                    # cv2.imwrite expects 8-bit data, so cast the float images.
                    cv2.imwrite(img_path,
                                cv2.cvtColor((im * 255).astype(np.uint8),
                                             cv2.COLOR_RGB2BGR))
                    cv2.imwrite(grad_path,
                                cv2.cvtColor((g * 255).astype(np.uint8),
                                             cv2.COLOR_RGB2BGR))

                    print('Processing {} done!'.format(batch_num))
                    if batch_num % 100 == 0:
                        print('100 second sleep')
                        time.sleep(100)
                    if batch_num >= 20:
                        break

            break

        sum_g = sum_g / batch_num
        print('max: {:.3f}, min: {:.3f}'.format(np.max(sum_g), np.min(sum_g)))
        sum_g = normalizeImg(sum_g)
        binary_g = sum_g.copy()
        binary_g[sum_g > np.mean(sum_g)] = 255

        cv2.imwrite(os.path.join(ekf_vis_dir, 'average_grad.jpg'),
                    cv2.cvtColor((sum_g * 255).astype(np.uint8),
                                 cv2.COLOR_RGB2BGR))

        plt.title('average grad')
        plt.imshow(sum_g)
        plt.show()

        cv2.imwrite(
            os.path.join(ekf_vis_dir, 'sum_grad_{}.jpg'.format(batch_num)),
            cv2.cvtColor(binary_g.astype(np.uint8), cv2.COLOR_RGB2BGR))