示例#1
0
def run_inference_graph(model, trained_checkpoint_prefix, input_images,
                        input_shape, pad_to_shape, label_color_map,
                        output_directory):
    """Run a segmentation model over image files and save PNG predictions.

    Builds the inference graph with ``deploy_segmentation_inference_graph``,
    restores the weights from ``trained_checkpoint_prefix`` and, for every
    path in ``input_images``, feeds the decoded image, prints the wall time
    of the ``sess.run`` call and saves the prediction as a PNG into
    ``output_directory`` under the input file's basename.

    Args:
        model: segmentation model forwarded to the graph builder.
        trained_checkpoint_prefix: checkpoint prefix restored by
            ``tf.train.Saver``.
        input_images: iterable of image file paths.
        input_shape: input shape forwarded to the graph builder.
        pad_to_shape: padding shape forwarded to the graph builder.
        label_color_map: colour table; ``len(label_color_map[0])`` is taken
            as the number of output channels of the prediction.
        output_directory: directory the PNG predictions are written to.
    """
    outputs, placeholder_tensor = deploy_segmentation_inference_graph(
        model=model,
        input_shape=input_shape,
        pad_to_shape=pad_to_shape,
        label_color_map=label_color_map)

    with tf.Session() as sess:
        saver = tf.train.Saver()
        saver.restore(sess, trained_checkpoint_prefix)

        for idx, image_path in enumerate(input_images):
            image_raw = np.array(Image.open(image_path))

            start_time = timeit.default_timer()
            predictions = sess.run(outputs,
                                   feed_dict={placeholder_tensor: image_raw})
            elapsed = timeit.default_timer() - start_time

            # Image index first, then the elapsed time (arguments were swapped).
            print('{}) wall time: {}'.format(idx + 1, elapsed))
            filename = os.path.basename(image_path)
            save_location = os.path.join(output_directory, filename)

            predictions = predictions.astype(np.uint8)
            output_channels = len(label_color_map[0])
            if output_channels == 1:
                # Drop the leading batch dim and the trailing singleton channel.
                predictions = np.squeeze(predictions[0], -1)
            else:
                # Drop the leading batch dim too, so PIL gets an (H, W, C)
                # array instead of a 4-D one it cannot convert.
                predictions = predictions[0]
            im = Image.fromarray(predictions)
            im.save(save_location, "PNG")
def export_inference_graph(pipeline_config,
                           trained_checkpoint_prefix,
                           output_directory,
                           input_shape=None,
                           pad_to_shape=None,
                           output_colours=False,
                           output_collection_name='predictions'):
    """Export a trained segmentation model as text and frozen graphs.

    Writes into ``output_directory``:
      * ``export_graph.pbtxt`` - text dump of the inference graph,
      * ``model.ckpt*`` - checkpoint re-saved via ``write_graph_and_checkpoint``,
      * ``frozen_inference_graph.pb`` - graph frozen at the ``outputs`` node.

    Args:
        pipeline_config: parsed pipeline config; ``pipeline_config.model`` is
            passed to ``model_builder.build``.
        trained_checkpoint_prefix: checkpoint prefix holding trained weights.
        output_directory: destination directory (created if missing).
        input_shape: optional input shape for the inference graph.
        pad_to_shape: optional padding shape for the inference graph.
        output_colours: if True, outputs are coloured with
            ``CITYSCAPES_LABEL_COLORS``; otherwise no colour map is used.
        output_collection_name: collection name for the output tensor.
    """
    _, segmentation_model = model_builder.build(pipeline_config.model,
                                                is_training=False)

    tf.gfile.MakeDirs(output_directory)
    frozen_graph_path = os.path.join(output_directory,
                                     'frozen_inference_graph.pb')
    eval_graphdef_path = os.path.join(output_directory, 'export_graph.pbtxt')
    model_path = os.path.join(output_directory, 'model.ckpt')

    outputs, placeholder_tensor = deploy_segmentation_inference_graph(
        model=segmentation_model,
        input_shape=input_shape,
        pad_to_shape=pad_to_shape,
        label_color_map=(CITYSCAPES_LABEL_COLORS if output_colours else None),
        output_collection_name=output_collection_name)

    profile_inference_graph(tf.get_default_graph())

    saver = tf.train.Saver()
    input_saver_def = saver.as_saver_def()

    # Close the handle deterministically so the text graph is fully flushed.
    graph_def = tf.get_default_graph().as_graph_def()
    with tf.gfile.GFile(eval_graphdef_path, "w") as f:
        f.write(str(graph_def))

    # A fresh graph_def is taken for each consumer so neither sees any
    # in-place modification the other might make to its copy.
    write_graph_and_checkpoint(
        inference_graph_def=tf.get_default_graph().as_graph_def(),
        model_path=model_path,
        input_saver_def=input_saver_def,
        trained_checkpoint_prefix=trained_checkpoint_prefix)

    # Node name = tensor name without the ":<output index>" suffix.
    output_node_names = outputs.name.split(":")[0]

    freeze_graph_with_def_protos(
        input_graph_def=tf.get_default_graph().as_graph_def(),
        input_saver_def=input_saver_def,
        input_checkpoint=trained_checkpoint_prefix,
        output_graph=frozen_graph_path,
        output_node_names=output_node_names,
        restore_op_name='save/restore_all',
        filename_tensor_name='save/Const:0',
        clear_devices=True,
        initializer_nodes='')

    print("Done!")
示例#3
0
    def setup(self):
        """Load the pipeline config, build the graph and restore weights.

        Reads the text-format ``PipelineConfig`` at ``self.config_path``,
        builds the model and its inference graph, opens a session limited to
        20% of GPU memory and restores ``self.trained_checkpoint_prefix``.
        """
        self.pipeline_config = pipeline_pb2.PipelineConfig()
        with tf.gfile.GFile(self.config_path, 'r') as config_file:
            config_text = config_file.read()
            text_format.Merge(config_text, self.pipeline_config)

        model_config = self.pipeline_config.model
        built = model_builder.build(model_config, is_training=False)
        self.num_classes, self.segmentation_model = built

        graph_handles = deploy_segmentation_inference_graph(
            model=self.segmentation_model,
            input_shape=self.input_shape,
            pad_to_shape=self.pad_to_shape,
            label_color_map=self.label_color_map)
        self.outputs, self.placeholder_tensor = graph_handles

        # Cap this session's share of GPU memory.
        gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.2)
        self.gpu_options = gpu_options
        session_config = tf.ConfigProto(gpu_options=gpu_options)
        self.sess = tf.Session(config=session_config)

        default_graph = tf.get_default_graph()
        self.input_graph_def = default_graph.as_graph_def()

        self.saver = tf.train.Saver()
        self.saver.restore(self.sess, self.trained_checkpoint_prefix)
def _save_exported_images(save_folder, save_name, out_img, out_map, out_pred,
                          out_annot):
    """Write one sample's image, score map, prediction and annotation.

    Files are written to ``<save_folder>/{image,map,pred,annot}/<save_name>``;
    the score map uses a ``.exr`` extension instead of ``.png``. Missing
    sub-directories are created.

    Returns:
        list of the paths for which ``cv2.imwrite`` reported failure.
    """
    targets = [
        ("image", save_name, out_img),
        ("map", save_name.replace(".png", ".exr"), out_map),
        ("pred", save_name, out_pred),
        ("annot", save_name, out_annot),
    ]
    failed = []
    for sub_dir, file_name, data in targets:
        out_dir = os.path.join(save_folder, sub_dir)
        os.makedirs(os.path.join(out_dir, os.path.dirname(save_name)),
                    exist_ok=True)
        path = os.path.join(out_dir, file_name)
        if not cv2.imwrite(path, data):
            failed.append(path)
    return failed


def run_inference_graph(model, trained_checkpoint_prefix, dataset, num_images,
                        ignore_label, pad_to_shape, num_classes,
                        processor_type, annot_type, num_gpu, export_folder,
                        **kwargs):
    """Run an OOD processor over a dataset and export per-image results.

    The processor class is looked up in ``processor_dict[processor_type]`` and
    the annotation handler in ``annot_dict[annot_type]``. For every step the
    next ``(name, image, annotation)`` batch is pulled from ``dataset``,
    optionally preprocessed by the processor, and the processor's fetches are
    run. ``processor.post_process`` supplies auroc / aupr / max_iou /
    fpr_at_tpr / detection_error. The input image, OOD score map (.exr),
    coloured prediction and coloured annotation are written under
    ``export_folder/<processor_type>/``, followed by a ``meta.csv`` holding
    the per-image metrics.

    Args:
        model: segmentation model passed to the graph builder and processor.
        trained_checkpoint_prefix: checkpoint to restore (variables reported
            by ``processor.get_vars_noload()`` are skipped).
        dataset: ``tf.data`` dataset yielding dicts keyed by the
            ``dataset_builder`` field names.
        num_images: number of images in ``dataset``.
        ignore_label, num_classes, num_gpu, **kwargs: forwarded to the
            processor constructor.
        pad_to_shape: padding shape for the inference graph.
        processor_type: key into ``processor_dict``; also the export sub-dir.
        annot_type: key into ``annot_dict``.
        export_folder: root directory for exported files.
    """
    batch = 1

    dataset = dataset.batch(batch, drop_remainder=True)
    dataset = dataset.apply(tf.data.experimental.ignore_errors())
    data_iter = dataset.make_one_shot_iterator()
    input_dict = data_iter.get_next()

    input_tensor = input_dict[dataset_builder._IMAGE_FIELD]
    annot_tensor = input_dict[dataset_builder._LABEL_FIELD]
    input_name = input_dict[dataset_builder._IMAGE_NAME_FIELD]

    # Leave the batch dimension unknown for the inference placeholder.
    input_shape = [None] + input_tensor.shape.as_list()[1:]

    name_pl = tf.placeholder(tf.string,
                             input_name.shape.as_list(),
                             name="name_pl")
    annot_pl = tf.placeholder(tf.float32,
                              annot_tensor.shape.as_list(),
                              name="annot_pl")
    outputs, placeholder_tensor = deploy_segmentation_inference_graph(
        model=model,
        input_shape=input_shape,
        pad_to_shape=pad_to_shape,
        input_type=tf.float32)

    process_annot = annot_dict[annot_type]
    processor_class = processor_dict[processor_type]

    processor = processor_class(model, outputs, num_classes, annot_pl,
                                placeholder_tensor, name_pl, ignore_label,
                                process_annot, num_gpu, batch, **kwargs)

    processor.name = processor_type
    processor.post_process_ops()

    preprocess_input = processor.get_preprocessed()

    input_fetch = [input_name, input_tensor, annot_tensor]

    # Confusion-matrix metric variables are reset before every image so the
    # metrics are per-image.
    metric_vars = [v for v in tf.local_variables() if "ConfMat" in v.name]
    reset_metric = tf.variables_initializer(metric_vars)

    fetch = processor.get_fetch_dict()
    ood_score = processor.get_output_image()

    feed = processor.get_feed_dict()
    prediction = processor.get_prediction()
    colour_prediction = _map_to_colored_labels(prediction, OOD_LABEL_COLORS)
    colour_annot = _map_to_colored_labels(annot_pl, OOD_LABEL_COLORS)

    save_folder = os.path.join(export_folder, processor.name)
    all_results = []

    config = tf.ConfigProto(allow_soft_placement=True)
    config.gpu_options.allow_growth = True
    run_options = tf.RunOptions(report_tensor_allocations_upon_oom=True)
    with tf.Session(config=config) as sess:
        init_feed = processor.get_init_feed()
        sess.run([
            tf.global_variables_initializer(),
            tf.local_variables_initializer()
        ], init_feed)

        vars_noload = set(processor.get_vars_noload())
        vars_toload = [
            v for v in tf.global_variables() if v not in vars_noload
        ]
        saver = tf.train.Saver(vars_toload)
        saver.restore(sess, trained_checkpoint_prefix)

        print("finalizing graph")
        sess.graph.finalize()

        # One step fewer than the image count: the original note says one
        # SUN image is bad (presumably dropped by ignore_errors(), so the
        # iterator yields one element less).
        num_step = num_images // batch - 1

        print("running for", num_step, "steps")
        for _ in range(num_step):
            image_path, img_raw, annot_raw = sess.run(input_fetch)

            cur_path = image_path[0].decode()
            save_name = cur_path.replace("/mnt/md0/Data/",
                                         "").replace(".jpg", ".png")
            if ".png" not in save_name:
                save_name += ".png"

            if preprocess_input is not None:
                processed_input = sess.run(preprocess_input,
                                           feed_dict={
                                               placeholder_tensor: img_raw,
                                               annot_pl: annot_raw,
                                               name_pl: image_path
                                           })
            else:
                processed_input = img_raw

            feed_dict = {
                placeholder_tensor: processed_input,
                annot_pl: annot_raw,
                name_pl: image_path
            }
            feed_dict.update(feed)

            sess.run(reset_metric)

            res = {}
            for f in fetch:
                res.update(sess.run(f, feed_dict, options=run_options))

            result = processor.post_process(res)

            all_results.append([
                save_name, result["auroc"], result["aupr"], result["max_iou"],
                result["fpr_at_tpr"], result["detection_error"]
            ])

            output_image, new_annot, colour_pred = sess.run(
                [ood_score, colour_annot, colour_prediction],
                feed_dict,
                options=run_options)

            if len(output_image.shape) == 3:
                output_image = np.expand_dims(output_image, -1)

            # [..., ::-1] reverses the channel axis (RGB <-> BGR) for cv2.
            out_img = img_raw[0][..., ::-1].astype(np.uint8)
            out_pred = colour_pred[0][..., ::-1].astype(np.uint8)
            out_map = output_image[0, ..., 0]
            out_annot = new_annot[0][..., ::-1].astype(np.uint8)

            print(save_name)

            failed = _save_exported_images(save_folder, save_name, out_img,
                                           out_map, out_pred, out_annot)
            if failed:
                # Report and keep going instead of blocking in a debugger,
                # so meta.csv is still produced for unattended runs.
                print("warning: cv2.imwrite failed for", ", ".join(failed))

        # The export directory may not exist yet if no step ran.
        os.makedirs(save_folder, exist_ok=True)
        meta = os.path.join(save_folder, "meta.csv")
        with open(meta, "w") as f:
            f.write("path,auroc,aupr,max_iou,fpr_at_tpr,detection_error\n")
            f.write("\n".join(",".join(map(str, row)) for row in all_results))
def run_inference_graph(model, trained_checkpoint_prefix, input_dict,
                        num_images, input_shape, pad_to_shape, label_color_map,
                        output_directory, num_classes, patch_size):
    """Accumulate per-class feature statistics over a dataset and save them.

    If ``<stats_dir>/class_mean.npz`` already exists, the stored class means
    are loaded and a covariance computer is run; the result goes to
    ``class_cov_inv.npz``. Otherwise a mean computer is run and its result
    goes to ``class_mean.npz``. ``stats_dir`` is ``output_directory/stats``,
    or ``output_directory/stats.dtform`` when ``FLAGS.use_dtform`` is set.
    Every batch is also fed horizontally flipped.

    Args:
        model: segmentation model; ``model._feature_extractor`` selects the
            batch size and ``model.final_logits_key`` selects the features.
        trained_checkpoint_prefix: checkpoint to restore.
        input_dict: input description passed to ``create_input``.
        num_images: images per epoch.
        input_shape: rank-3 per-image input shape.
        pad_to_shape, label_color_map: forwarded to the graph builder.
        output_directory: parent directory of the stats directory.
        num_classes: number of classes passed to ``process_annot``.
        patch_size: unused here; kept for interface compatibility.
    """
    assert len(input_shape) == 3, "input shape must be rank 3"
    effective_shape = [None] + list(input_shape)

    batch = 2 if isinstance(model._feature_extractor, resnet_ex_class) else 1

    dataset = create_input(input_dict, batch)
    dataset = dataset.apply(tf.data.experimental.ignore_errors())
    data_iterator = dataset.make_one_shot_iterator()
    input_dict = data_iterator.get_next()

    input_tensor = input_dict[dataset_builder._IMAGE_FIELD]
    annot_tensor = input_dict[dataset_builder._LABEL_FIELD]

    flip = True
    if flip:
        # Append the left-right mirrored batch to the original one.
        input_tensor = tf.concat(
            [input_tensor,
             tf.image.flip_left_right(input_tensor)], 0)
        annot_tensor = tf.concat([
            annot_tensor[..., 0],
            tf.image.flip_left_right(annot_tensor)[..., 0]
        ], 0)
    else:
        annot_tensor = annot_tensor[..., 0]

    outputs, placeholder_tensor = deploy_segmentation_inference_graph(
        model=model,
        input_shape=effective_shape,
        input=input_tensor,
        pad_to_shape=pad_to_shape,
        label_color_map=label_color_map)

    final_logits = outputs[model.final_logits_key]

    if FLAGS.use_dtform:
        stats_dir = os.path.join(output_directory, "stats.dtform")
    else:
        stats_dir = os.path.join(output_directory, "stats")
    class_mean_file = os.path.join(stats_dir, "class_mean.npz")
    class_cov_file = os.path.join(stats_dir, "class_cov_inv.npz")

    with tf.device("gpu:1"):
        avg_mask, sorted_feats = process_annot(annot_tensor, final_logits,
                                               num_classes)

        if os.path.exists(class_mean_file):
            # Means exist from a previous run: compute covariances now.
            with np.load(class_mean_file) as saved:
                class_mean_v = saved["arr_0"]
            if FLAGS.use_patch:
                comp = stats.PatchCovComputer
            else:
                comp = stats.CovComputer
            stat_computer = comp(sorted_feats, avg_mask, class_mean_v)
            output_file = class_cov_file
            print("second_pass")
        else:
            if FLAGS.use_patch:
                comp = stats.PatchMeanComputer
            else:
                comp = stats.MeanComputer
            stat_computer = comp(sorted_feats, avg_mask)
            output_file = class_mean_file
            print("first_pass")

        update_op = stat_computer.get_update_op()

    config = tf.ConfigProto(allow_soft_placement=True)
    config.gpu_options.allow_growth = True
    coord = tf.train.Coordinator()

    with tf.Session(config=config) as sess:
        saver = tf.train.Saver(tf.global_variables())

        sess.run([
            tf.global_variables_initializer(),
            tf.local_variables_initializer()
        ])
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        try:
            saver.restore(sess, trained_checkpoint_prefix)

            # NOTE(review): `epoch` is not defined in this function; it is
            # presumably a module-level setting - verify.
            num_step = epoch * num_images // batch
            print("running for", num_step * batch)

            # Initialised so the summary print below works when num_step == 0.
            elapsed = 0.0
            processed = 0
            for idx in range(num_step):
                start_time = timeit.default_timer()

                sess.run(update_op)

                elapsed = timeit.default_timer() - start_time
                processed = (idx + 1) * batch
                # Overwrite one console line, with a real newline every 50 steps.
                end = "\n" if idx % 50 == 0 else "\r"
                print('{0:.4f} wall time: {1}'.format(elapsed / batch,
                                                      processed),
                      end=end)
            print('{0:.4f} wall time: {1}'.format(elapsed / batch, processed))
            os.makedirs(stats_dir, exist_ok=True)

            stat_computer.save_variable(sess, output_file)
        finally:
            # Stop queue-runner threads even if restore/run raises.
            coord.request_stop()
            coord.join(threads)
def run_inference_graph(model, trained_checkpoint_prefix, input_images,
                        annot_filenames, input_shape, pad_to_shape,
                        label_color_map, output_directory, num_classes,
                        patch_size):
    """Compute per-class feature means and inverse covariances from images.

    Runs up to two passes over ``input_images`` (each image also fed flipped
    left-right, optionally split into patches of ``patch_size``) and updates
    the statistics with ``compute_stats``. The first pass writes
    ``<output_directory>/stats/class_mean.npz``; it is skipped when that file
    already exists, in which case the stored means are loaded. The second
    pass writes ``class_cov_inv.npz``.

    Args:
        model: segmentation model; ``main_class_predictions_key`` and
            ``final_logits_key`` select the graph outputs used.
        trained_checkpoint_prefix: checkpoint to restore.
        input_images: image file paths (read with ``cv2.imread``).
        annot_filenames: annotation file paths, parallel to ``input_images``.
        input_shape: per-image input shape; its first two entries are
            replaced by ``patch_size`` when patching.
        pad_to_shape, label_color_map: forwarded to the graph builder.
        output_directory: parent directory of ``stats``.
        num_classes: number of classes passed to ``process_annot``.
        patch_size: optional patch size; falsy means whole images.

    Raises:
        IOError: if an image or annotation file cannot be read.
    """
    effective_shape = copy.deepcopy(input_shape)
    if patch_size:
        effective_shape[:2] = patch_size
        patches, patch_place = img_to_patch(input_shape, patch_size)

    outputs, placeholder_tensor = deploy_segmentation_inference_graph(
        model=model,
        input_shape=effective_shape,
        pad_to_shape=pad_to_shape,
        label_color_map=label_color_map)

    pred_tensor = outputs[model.main_class_predictions_key]

    stats_dir = os.path.join(output_directory, "stats")
    class_mean_file = os.path.join(stats_dir, "class_mean.npz")
    class_cov_file = os.path.join(stats_dir, "class_cov_inv.npz")

    class_m_k = None
    class_v_k_inv = None
    first_pass = True

    if os.path.exists(class_mean_file):
        with np.load(class_mean_file) as saved:
            class_m_k = torch.tensor(saved["arr_0"])
        first_pass = False
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True

    with tf.Session(config=config) as sess:
        saver = tf.train.Saver(tf.global_variables())

        pred_shape = pred_tensor.get_shape().as_list()
        annot_place, sorted_feats, avg_mask = process_annot(
            pred_shape, outputs[model.final_logits_key], num_classes)
        fetch = [sorted_feats, avg_mask]
        saver.restore(sess, trained_checkpoint_prefix)

        # NOTE(review): class_k is not reset between the two passes, so the
        # second pass starts from the first pass's count when both run, but
        # from None when the means were loaded from disk - confirm this is
        # what compute_stats expects.
        class_k = None
        # Means loaded from disk -> only the covariance pass is needed.
        passes = [True, False] if first_pass else [False]

        for first_pass in passes:
            print("first pass" if first_pass else "second pass")
            for idx, image_path in enumerate(input_images):
                image_raw = cv2.imread(image_path)
                annot_raw = cv2.imread(annot_filenames[idx])
                if image_raw is None or annot_raw is None:
                    raise IOError("could not read {} or {}".format(
                        image_path, annot_filenames[idx]))

                start_time = timeit.default_timer()
                for flipped in (False, True):
                    if flipped:
                        image_raw = np.fliplr(image_raw)
                        annot_raw = np.fliplr(annot_raw)

                    if patch_size:
                        all_image_raw = sess.run(
                            patches, feed_dict={patch_place: image_raw})
                        all_annot_raw = sess.run(
                            patches, feed_dict={patch_place: annot_raw})
                    else:
                        all_image_raw = [image_raw]
                        all_annot_raw = [annot_raw]

                    # Index each element first: without patching these are
                    # Python lists, which do not support [i, ..., 0].
                    for image_patch, annot_patch in zip(all_image_raw,
                                                        all_annot_raw):
                        feed = {
                            placeholder_tensor: image_patch,
                            annot_place: annot_patch[..., 0],
                        }
                        sorted_logits, mask = sess.run(fetch, feed_dict=feed)
                        class_m_k, class_v_k_inv, class_k = compute_stats(
                            class_m_k, class_v_k_inv, sorted_logits, class_k,
                            first_pass, mask)

                elapsed = timeit.default_timer() - start_time
                # Image index first, then the elapsed time (arguments were swapped).
                print('{}) wall time: {}'.format(idx + 1, elapsed))

            os.makedirs(stats_dir, exist_ok=True)

            if first_pass:
                class_m_k_np = class_m_k.numpy()
                if np.isnan(class_m_k_np).any():
                    print("nan time")
                    import pdb
                    pdb.set_trace()
                np.savez(class_mean_file, class_m_k_np)
            else:
                class_v_k_inv_np = (class_v_k_inv / (class_k + 1)).numpy()
                if np.isnan(class_v_k_inv_np).any():
                    print("nan time")
                    import pdb
                    pdb.set_trace()

                np.savez(class_cov_file, class_v_k_inv_np)
def _rescale_to_uint8(values):
    """Linearly rescale an array to the 0..255 range as ``uint8``.

    Non-finite entries (NaN / +-inf) are first replaced by the smallest finite
    value. A constant (or entirely non-finite) input maps to zeros instead of
    dividing by zero.

    Args:
        values: array-like of numbers (left unmodified).

    Returns:
        A new ``uint8`` array with the same shape as ``values``.
    """
    scaled = np.array(values, dtype=np.float64)  # private copy
    finite = np.isfinite(scaled)
    if not finite.any():
        return np.zeros(scaled.shape, dtype=np.uint8)
    scaled[~finite] = scaled[finite].min()
    scaled -= scaled.min()  # min now at 0
    peak = scaled.max()
    if peak > 0:
        scaled = 255 * scaled / peak  # max now at 255
    return scaled.astype(np.uint8)


def _interactive_threshold_viewer(min_dist_out, img_out):
    """Show ``img_out`` masked by ``min_dist_out < threshold`` in an OpenCV window.

    The threshold starts at the median of ``min_dist_out`` and its step at
    1/300 of the value range. Keys: ``s``/``w`` raise/lower the threshold by
    one step, ``a``/``d`` shrink/grow the step by 5, ``p`` enters pdb and
    ``Esc`` returns.

    NOTE(review): the mask gets a trailing axis before multiplying with
    ``img_out``; this assumes ``min_dist_out`` is (H, W) and ``img_out`` is
    (H, W, C) -- confirm against ``process_logits``.
    """
    disp_thresh = np.median(min_dist_out)
    grain = (np.max(min_dist_out) - np.min(min_dist_out)) / 300
    print(disp_thresh, "  ", grain)
    while True:
        mask = np.expand_dims(min_dist_out < disp_thresh, -1)
        cv2.imshow("img", (img_out * mask).astype(np.uint8))
        key = cv2.waitKey(1)
        if key == 27:  # escape
            return
        if key == 115:  # s
            disp_thresh += grain
        elif key == 119:  # w
            disp_thresh -= grain
        elif key == 97:  # a
            grain -= 5
        elif key == 100:  # d
            grain += 5
        elif key == 112:  # p
            import pdb
            pdb.set_trace()
            continue
        else:
            continue
        print(disp_thresh, "  ", grain)


def run_inference_graph(model, trained_checkpoint_prefix, input_dict,
                        num_images, ignore_label, input_shape, pad_to_shape,
                        label_color_map, output_directory, num_classes,
                        eval_dir, min_dir, dist_dir, hist_dir, dump_dir):
    """Evaluate distance-based OOD scores against "prediction is wrong" labels.

    Loads the per-class means and inverse covariances saved under
    ``<eval_dir>/stats.dtform``. It then builds the segmentation graph and,
    when ``FLAGS.epsilon > 0``, shifts every input by one gradient-sign step
    of size epsilon that lowers the minimum distance. Finally it streams
    IoU / ROC / PR metrics over ``num_images // 1`` images. The binary
    target is 1 wherever the predicted class differs from the label.

    Args:
        model: segmentation model exposing ``main_class_predictions_key``,
            ``final_logits_key`` and ``unscaled_logits_key``.
        trained_checkpoint_prefix: checkpoint to restore weights from.
        input_dict: input spec handed to ``create_input``.
        num_images: number of images to process (batch size is 1).
        ignore_label: label value excluded from the metrics.
        input_shape: rank-3 per-image input shape (a list).
        pad_to_shape: forwarded to ``deploy_segmentation_inference_graph``.
        label_color_map: colour map used for the written visualisations.
        output_directory: folder for the colour predictions
            (``FLAGS.write_img``).
        num_classes: number of in-distribution classes.
        eval_dir: directory that contains ``stats.dtform``.
        min_dir: folder for the rescaled min-distance images
            (``FLAGS.write_img``).
        dist_dir: folder for the colour OOD maps (``FLAGS.write_img``).
        hist_dir: not used by this function.
        dump_dir: prefix for the ``FLAGS.write_out`` dumps; ``_<epsilon>``
            is appended.

    Raises:
        AssertionError: if ``input_shape`` is not rank 3.
    """
    assert len(input_shape) == 3, "input shape must be rank 3"
    batch = 1
    do_ood = FLAGS.do_ood
    epsilon = FLAGS.epsilon
    dump_dir += "_" + str(epsilon)
    # Normalisation statistics for the minimum distance (cf. normalise_data.py).
    mean_value = 508.7571
    std_value = 77.60572284853058
    if FLAGS.max_softmax:
        thresh = 0.0650887573964497  # dim from sun train
    else:
        thresh = 0.37583892617449666  # dim from sun train
    effective_shape = [batch] + input_shape

    input_queue = create_input(input_dict, batch, 15, 15, 15)
    input_dict = input_queue.dequeue()

    input_tensor = input_dict[dataset_builder._IMAGE_FIELD]
    annot_tensor = input_dict[dataset_builder._LABEL_FIELD]
    input_name = input_dict[dataset_builder._IMAGE_NAME_FIELD]

    annot_pl = tf.placeholder(tf.float32, annot_tensor.get_shape().as_list())
    outputs, placeholder_tensor = deploy_segmentation_inference_graph(
        model=model,
        input_shape=effective_shape,
        pad_to_shape=pad_to_shape,
        input_type=tf.float32)

    class_predictions = outputs[model.main_class_predictions_key]
    final_logits = outputs[model.final_logits_key]
    unscaled_logits = outputs[model.unscaled_logits_key]
    pred_shape = class_predictions.get_shape().as_list()

    stats_dir = os.path.join(eval_dir, "stats.dtform")
    class_mean_file = os.path.join(stats_dir, "class_mean.npz")
    class_cov_file = os.path.join(stats_dir, "class_cov_inv.npz")

    global_cov = FLAGS.global_cov
    global_mean = FLAGS.global_mean

    print("loading means and covs")
    # np.load on an .npz returns an archive object that keeps the file open;
    # the context manager closes it once the array has been read.
    with np.load(class_mean_file) as archive:
        mean = archive["arr_0"]
    with np.load(class_cov_file) as archive:
        var_inv = archive["arr_0"]
    print("done loading")
    depth = mean.shape[-1]

    if global_cov:
        # Sum over the first three axes and broadcast back to the full shape.
        var_inv = (np.sum(var_inv, axis=(0, 1, 2), keepdims=True) *
                   np.ones_like(var_inv))
    if global_mean:
        # Average over the first three axes and broadcast back.
        mean = (np.mean(mean, axis=(0, 1, 2), keepdims=True) *
                np.ones_like(mean))

    # Ops are pinned to gpu:1; allow_soft_placement (session config below)
    # lets TF place them elsewhere when that device is unavailable.
    with tf.device("gpu:1"):
        # Binary target: 1 where the predicted class disagrees with the label.
        not_correct = tf.to_float(
            tf.not_equal(annot_pl, tf.to_float(class_predictions)))
        (dist_class, img_dist, _, min_dist, mean_p, var_inv_p, vars_noload,
         _) = process_logits(final_logits, mean, var_inv, depth, pred_shape,
                             num_classes, global_cov, global_mean)
        # Colour maps use process_logits' class map and the real predictions.
        dist_colour = _map_to_colored_labels(dist_class, label_color_map)
        pred_colour = _map_to_colored_labels(class_predictions,
                                             label_color_map)

        if FLAGS.max_softmax:
            # Score = 1 - max temperature-scaled softmax probability.
            interp_logits = tf.image.resize_bilinear(unscaled_logits,
                                                     pred_shape[1:3])
            dist_pred = 1 - tf.reduce_max(
                tf.nn.softmax(interp_logits / FLAGS.t_value),
                -1,
                keepdims=True)
        else:
            dist_pred = tf.expand_dims(
                pred_to_ood(min_dist, mean_value, std_value, thresh), -1)
        dist_class = tf.to_float(dist_pred >= thresh)

    with tf.device("gpu:1"):
        neg_validity_mask = get_valid(annot_pl, ignore_label)
        with tf.variable_scope("DistIou"):
            (dist_miou, _, dist_update), _ = get_miou(
                not_correct, dist_class, num_classes, ignore_label, do_ood,
                neg_validity_mask)

        weights = tf.to_float(neg_validity_mask)
        num_thresholds = 200

        with tf.variable_scope("Roc"):
            roc_points, roc_update = tf.contrib.metrics.streaming_curve_points(
                not_correct, dist_pred, weights, num_thresholds, curve='ROC')
        with tf.variable_scope("Pr"):
            pr_points, pr_update = tf.contrib.metrics.streaming_curve_points(
                not_correct, dist_pred, weights, num_thresholds, curve='PR')

    update_ops = [dist_update]
    if not FLAGS.write_out:
        update_ops += [pr_update, roc_update]
    update_op = tf.group(update_ops)

    mean = np.reshape(mean, mean_p.get_shape().as_list())
    var_inv = np.reshape(var_inv, var_inv_p.get_shape().as_list())

    input_fetch = [input_name, input_tensor, annot_tensor]

    min_dist_first = min_dist[0]
    fetch = {"update": update_op}
    if FLAGS.train_kernel:
        # Real class predictions (not the all-ones OOD baseline): Canny needs
        # class boundaries to produce any edges.
        fetch["predictions"] = class_predictions
        fetch["min_dist_out"] = min_dist_first
    if FLAGS.write_img:
        fetch["prediction_colour"] = pred_colour
        fetch["dist_out"] = tf.cast(dist_colour[0], tf.uint8)
        fetch["min_dist_out"] = min_dist_first
    if FLAGS.write_out:
        fetch["img_dist_out"] = img_dist[0]
        fetch["unscaled_logits_out"] = unscaled_logits[0]
    if FLAGS.debug:
        fetch["min_dist_out"] = min_dist_first

    # tf.gradients returns a list, so sign(grads) carries an extra leading
    # axis; the no-perturbation branch adds one too and callers index [0].
    epsilon_pl = tf.placeholder(tf.float32, (), "epsilon")
    if epsilon > 0.0:
        grads = tf.gradients(min_dist, placeholder_tensor)
        adv_img = placeholder_tensor - epsilon_pl * tf.sign(grads)
    else:
        adv_img = tf.expand_dims(placeholder_tensor, 0)

    num_step = num_images // batch
    print("running for", num_step, "steps")

    # Both write_img and write_out enqueue (index, path, payload) tuples for
    # a pool of ParallelWriter workers.
    write_queue = None
    writers = []
    if FLAGS.write_out or FLAGS.write_img:
        write_queue = Queue(30)
        num_writers = 20
        writers = [ParallelWriter(write_queue) for _ in range(num_writers)]

    config = tf.ConfigProto(allow_soft_placement=True)
    config.gpu_options.allow_growth = True
    with tf.Session(config=config) as sess:
        sess.run([
            tf.global_variables_initializer(),
            tf.local_variables_initializer()
        ], {
            mean_p: mean,
            var_inv_p: var_inv
        })
        tf.train.start_queue_runners(sess)
        vars_toload = [
            v for v in tf.global_variables() if v not in vars_noload
        ]
        saver = tf.train.Saver(vars_toload)
        saver.restore(sess, trained_checkpoint_prefix)

        if FLAGS.train_kernel:
            kimg_pl, kedges_pl, kloss, ktrain_step, kfilter = kernel_model(
                (1, 1024, 2048, 1))
            init = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                                     scope="kmodel")
            sess.run(tf.variables_initializer(init))

        elapsed = 0.0
        dist_miou_v = float("nan")
        try:
            for idx in range(num_step):
                start_time = timeit.default_timer()

                name_raw, img_raw, annot_raw = sess.run(input_fetch)
                image_path = name_raw[0].decode("utf-8")
                filename = os.path.basename(image_path)
                dump_filename = os.path.join(dump_dir, filename + ".npy")

                adv_img_out = sess.run(adv_img,
                                       feed_dict={
                                           placeholder_tensor: img_raw,
                                           annot_pl: annot_raw,
                                           epsilon_pl: epsilon
                                       })[0]

                res = sess.run(fetch,
                               feed_dict={
                                   placeholder_tensor: adv_img_out,
                                   annot_pl: annot_raw
                               })

                roc = sess.run(roc_points)
                # Negated because the ROC points are assumed to come ordered
                # by decreasing false-positive rate.
                auc = -np.trapz(roc[:, 1], roc[:, 0])
                dist_miou_v = sess.run(dist_miou)

                if FLAGS.train_kernel:
                    predictions = res["predictions"]
                    min_dist_out = res["min_dist_out"]
                    edges = cv2.Canny(predictions[0].astype(np.uint8), 1, 1)
                    kernel = train_kernel(min_dist_out, edges, sess, kimg_pl,
                                          kedges_pl, kloss, ktrain_step,
                                          kfilter)
                    dilated = np.expand_dims(
                        cv2.filter2D(edges, -1, kernel[..., 0, 0]),
                        -1).astype(np.float32)
                    dilated_max = np.max(dilated)
                    if dilated_max > 0:  # no edges -> keep zeros, avoid 0/0
                        dilated = dilated / dilated_max

                    disp = cv2.resize(
                        np.concatenate(
                            [to_img(min_dist_out),
                             to_img(dilated)], 1), (1920, 1080))
                    cv2.imshow("test", disp)
                    cv2.waitKey(1)

                if FLAGS.write_img:
                    prediction_colour = res["prediction_colour"].astype(
                        np.uint8)
                    if len(label_color_map[0]) == 1:
                        prediction_colour = np.squeeze(prediction_colour[0],
                                                       -1)
                    else:
                        prediction_colour = prediction_colour[0]
                    min_dist_v = _rescale_to_uint8(res["min_dist_out"])

                    save_location = os.path.join(output_directory, filename)
                    dist_filename = os.path.join(dist_dir, filename)
                    min_filename = os.path.join(min_dir, filename)

                    write_queue.put((idx, save_location, prediction_colour))
                    write_queue.put((idx, min_filename, min_dist_v))
                    write_queue.put((idx, dist_filename, res["dist_out"]))

                if FLAGS.write_out:
                    write_queue.put((idx, dump_filename, {
                        "dist": res["img_dist_out"],
                        "unscaled_logits": res["unscaled_logits_out"]
                    }))

                if FLAGS.debug:
                    _interactive_threshold_viewer(res["min_dist_out"],
                                                  img_raw[0])

                elapsed = timeit.default_timer() - start_time
                # Overwrite the console line, with a newline every 50 steps.
                end = "\n" if idx % 50 == 0 else "\r"
                print('{0:.4f} iter: {1}, iou: {2:.6f}, auc: {3:.6f}'.format(
                    elapsed, idx + 1, dist_miou_v, auc),
                      end=end)

            if not FLAGS.write_out:
                roc = sess.run(roc_points)
                pr = sess.run(pr_points)
                make_plots(roc, pr, num_thresholds)
        finally:
            # Release the writer pool even if the loop fails part-way.
            for writer in writers:
                writer.close()

        if num_step > 0:
            print('{0:.4f} iter: {1}, iou: {2:.6f}'.format(
                elapsed, num_step, dist_miou_v))
示例#8
0
def run_inference_graph(model, trained_checkpoint_prefix, dataset, num_images,
                        ignore_label, pad_to_shape, num_classes,
                        processor_type, annot_type, num_gpu, **kwargs):
    """Evaluate ``model`` on ``dataset`` through a pluggable result processor.

    The processor class is looked up in ``processor_dict[processor_type]``
    and the annotation transform in ``annot_dict[annot_type]``. The processor
    supplies:

    - extra graph ops (``post_process_ops``);
    - an optional input preprocessing tensor;
    - the fetches and feeds run on every step;
    - the initial feed for variable initialisation;
    - the variables that must not be restored from the checkpoint;
    - the per-step ``post_process``.

    Args:
        model: segmentation model handed to the graph builder and processor.
        trained_checkpoint_prefix: checkpoint to restore weights from.
        dataset: a ``tf.data`` dataset (``batch`` and
            ``make_one_shot_iterator`` are called on it) whose elements are
            dicts keyed by the ``dataset_builder`` image/label/name fields.
        num_images: number of images; one fewer step than this is run.
        ignore_label: label value forwarded to the processor.
        pad_to_shape: forwarded to ``deploy_segmentation_inference_graph``.
        num_classes: number of in-distribution classes.
        processor_type: key into ``processor_dict``.
        annot_type: key into ``annot_dict``.
        num_gpu: forwarded to the processor.
        **kwargs: extra processor constructor arguments.

    Returns:
        The value of ``processor.post_process`` for the last step, or
        ``None`` when no step was run.

    Raises:
        KeyError: if ``processor_type`` or ``annot_type`` is unknown.
    """
    batch = 1

    dataset = dataset.batch(batch, drop_remainder=True)
    data_iter = dataset.make_one_shot_iterator()
    input_dict = data_iter.get_next()

    input_tensor = input_dict[dataset_builder._IMAGE_FIELD]
    annot_tensor = input_dict[dataset_builder._LABEL_FIELD]
    input_name = input_dict[dataset_builder._IMAGE_NAME_FIELD]

    # Leave the batch dimension open for the inference placeholder.
    input_shape = [None] + input_tensor.shape.as_list()[1:]

    name_pl = tf.placeholder(tf.string,
                             input_name.shape.as_list(),
                             name="name_pl")
    annot_pl = tf.placeholder(tf.float32,
                              annot_tensor.shape.as_list(),
                              name="annot_pl")
    outputs, placeholder_tensor = deploy_segmentation_inference_graph(
        model=model,
        input_shape=input_shape,
        pad_to_shape=pad_to_shape,
        input_type=tf.float32)

    process_annot = annot_dict[annot_type]
    processor_class = processor_dict[processor_type]

    processor = processor_class(model, outputs, num_classes, annot_pl,
                                placeholder_tensor, name_pl, ignore_label,
                                process_annot, num_gpu, batch, **kwargs)

    processor.post_process_ops()

    preprocess_input = processor.get_preprocessed()

    input_fetch = [input_name, input_tensor, annot_tensor]

    fetch = processor.get_fetch_dict()
    feed = processor.get_feed_dict()
    # NOTE(review): the returned tensor is unused here; the call is kept in
    # case it adds ops before the graph is finalized -- confirm and drop.
    processor.get_output_image()

    num_step = num_images // batch

    # Changes numpy's print options for the whole process (not restored).
    np.set_printoptions(threshold=2, edgeitems=1)
    print_exclude = {"tp", "fp", "tn", "fn", "pred", "new_pred"}
    print_every = 1  # steps between newline-terminated progress lines

    config = tf.ConfigProto(allow_soft_placement=True)
    config.gpu_options.per_process_gpu_memory_fraction = 1.
    run_options = tf.RunOptions(report_tensor_allocations_upon_oom=True)
    config.gpu_options.allow_growth = True
    with tf.Session(config=config) as sess:
        init_feed = processor.get_init_feed()
        sess.run([
            tf.global_variables_initializer(),
            tf.local_variables_initializer()
        ], init_feed)

        vars_noload = set(processor.get_vars_noload())
        vars_toload = [
            v for v in tf.global_variables() if v not in vars_noload
        ]
        saver = tf.train.Saver(vars_toload)
        saver.restore(sess, trained_checkpoint_prefix)

        print("finalizing graph")

        sess.graph.finalize()

        # One SUN image is bad, so one fewer step than num_images is run.
        num_step -= 1

        result = None
        to_print = {}
        elapsed = 0.0
        avg_time = 0

        print("running for", num_step, "steps")
        for idx in range(num_step):
            start_time = timeit.default_timer()

            image_path, img_raw, annot_raw = sess.run(input_fetch)
            if preprocess_input is not None:
                processed_input = sess.run(preprocess_input,
                                           feed_dict={
                                               placeholder_tensor: img_raw,
                                               annot_pl: annot_raw,
                                               name_pl: image_path
                                           })
            else:
                processed_input = img_raw

            feed_dict = {
                placeholder_tensor: processed_input,
                annot_pl: annot_raw,
                name_pl: image_path
            }
            feed_dict.update(feed)

            # Each fetch entry is run separately and merged into one dict.
            res = {}
            for f in fetch:
                res.update(sess.run(f, feed_dict, options=run_options))

            result = processor.post_process(res)

            elapsed = timeit.default_timer() - start_time
            end = "\n" if idx % print_every == 0 else "\r"

            to_print = {k: result[k] for k in result if k not in print_exclude}
            # Running mean of step time over steps 1..idx (step 0 excluded).
            if idx > 0:
                avg_time += (elapsed - avg_time) / idx
            print('{0:.4f}({1:.4f}): {2}, {3}'.format(elapsed, avg_time,
                                                      idx + 1, to_print),
                  end=end)

        if num_step > 0:
            print('{0:.4f}({1:.4f}): {2}, {3}'.format(elapsed, avg_time,
                                                      num_step, to_print))
        return result
def run_inference_graph(model, trained_checkpoint_prefix,
                        dataset, num_images, ignore_label, input_shape, pad_to_shape,
                        label_color_map, output_directory, num_classes, eval_dir,
                        min_dir, dist_dir, hist_dir, dump_dir):
    assert len(input_shape) == 3, "input shape must be rank 3"
    batch = 1
    do_ood = FLAGS.do_ood
    epsilon = FLAGS.epsilon
    dump_dir += "_" + str(epsilon)
    mean_value = 508.7571
    std_value = 77.60572284853058
    if FLAGS.max_softmax:
        thresh = 0.07100591715976332 #dim dist from sun train
        #thresh = 0.0650887573964497 #dim from sun train
    else:
        thresh = 0.37583892617449666 #dim from sun train
    effective_shape = [batch] + input_shape

    dataset = dataset.batch(batch, drop_remainder=True)
    dataset = dataset.apply(tf.data.experimental.ignore_errors())
    data_iter = dataset.make_one_shot_iterator()
    input_dict = data_iter.get_next()

    input_tensor = input_dict[dataset_builder._IMAGE_FIELD]
    annot_tensor = input_dict[dataset_builder._LABEL_FIELD]
    input_name = input_dict[dataset_builder._IMAGE_NAME_FIELD]

    annot_pl = tf.placeholder(tf.float32, annot_tensor.get_shape().as_list())
    outputs, placeholder_tensor = deploy_segmentation_inference_graph(
        model=model,
        input_shape=effective_shape,
        #input=input_tensor,
        pad_to_shape=pad_to_shape,
        input_type=tf.float32)

    pred_tensor = outputs[model.main_class_predictions_key]
    final_logits = outputs[model.final_logits_key]
    unscaled_logits = outputs[model.unscaled_logits_key]


    #mean = np.reshape(mean, [-1] + mean_dims)
    #var_inv = np.reshape(var_inv, [-1] + var_dims)
    with tf.device("gpu:1"):
        if not FLAGS.max_softmax:
            dist_class, img_dist, full_dist, min_dist, mean_p, var_inv_p, vars_noload, dbg  = process_logits(final_logits, mean, var_inv, depth, pred_tensor.get_shape().as_list(), num_classes, global_cov, global_mean)
            dist_colour = _map_to_colored_labels(dist_class, label_color_map)
            pred_colour = _map_to_colored_labels(pred_tensor, label_color_map)
            selected = min_dist

        if do_ood:
            if FLAGS.max_softmax:
                interp_logits = tf.image.resize_bilinear(unscaled_logits, pred_tensor.shape.as_list()[1:3])
                dist_pred = 1.0 - tf.reduce_max(tf.nn.softmax(interp_logits/FLAGS.t_value),-1, keepdims=True)
                dist_class = tf.to_float(dist_pred >= thresh)
                selected = dist_pred
                vars_noload = []
            else:
                #dist_pred = tf.reduce_min(tf.nn.softmax(full_dist), -1, keepdims=True)
                dist_pred = tf.expand_dims(pred_to_ood(min_dist, mean_value, std_value, thresh),-1)
                dist_class = tf.to_float(dist_pred >= thresh)
            
            #pred is the baseline of assuming all ood
            pred_tensor = tf.ones_like(pred_tensor)

    with tf.device("gpu:1"):
        neg_validity_mask = get_valid(annot_pl, ignore_label)
        with tf.variable_scope("PredIou"):
            (pred_miou, pred_conf_mat, pred_update), _ = get_miou(annot_pl, pred_tensor, num_classes, ignore_label, do_ood, neg_validity_mask)
        with tf.variable_scope("DistIou"):
            (dist_miou, dist_conf_mat, dist_update), _ = get_miou(annot_pl, dist_class, num_classes, ignore_label, do_ood, neg_validity_mask)
  
        weights = tf.to_float(neg_validity_mask)

        num_thresholds = 200

        ood_label = tf.to_float(annot_pl >= num_classes)

        with tf.variable_scope("Roc"):
            RocPoints, roc_update = tf.contrib.metrics.streaming_curve_points(ood_label,dist_pred,weights,num_thresholds,curve='ROC')
        with tf.variable_scope("Pr"):
            PrPoints, pr_update = tf.contrib.metrics.streaming_curve_points(ood_label,dist_pred,weights,num_thresholds,curve='PR')

    update_op = [pred_update, dist_update, pr_update, roc_update]
    update_op = tf.group(update_op)

    if not FLAGS.max_softmax:
        mean = np.reshape(mean, mean_p.get_shape().as_list())
        var_inv = np.reshape(var_inv, var_inv_p.get_shape().as_list())

    input_fetch = [input_name, input_tensor, annot_tensor]

    fetch = {"update": update_op,
            "selected": selected,
            "ood_label": ood_label,
        }

    dbg = []

    if FLAGS.train_kernel:
        fetch["predictions"] = pred_tensor
        fetch["min_dist_out"] = min_dist[0]

    if FLAGS.write_img:
        fetch["prediction_colour"] = pred_colour
        fetch["dist_out"] = tf.cast(dist_colour[0], tf.uint8)
        fetch["full_dist_out"] = full_dist[0]
        fetch["min_dist_out"] = min_dist[0]

    if FLAGS.write_out:
        fetch["img_dist_out"] = img_dist[0]
        fetch["unscaled_logits_out"] = unscaled_logits[0]

    grads = tf.gradients(selected, placeholder_tensor)
    if epsilon > 0.0:
        adv_img = placeholder_tensor - epsilon*tf.sign(grads)
    else:
        adv_img = tf.expand_dims(placeholder_tensor, 0)

    num_step = num_images // batch
    print("running for", num_step, "steps")
    #os.makedirs(dump_dir, exist_ok=True)

    if FLAGS.write_out:
        write_queue = Queue(30)
        num_writers = 20
        writers = [ParallelWriter(write_queue) for i in range(num_writers)]

    config = tf.ConfigProto(allow_soft_placement=True)
    #config.gpu_options.allow_growth = True
    with tf.Session(config=config) as sess:
        init_feed = {}
        if not FLAGS.max_softmax:
            init_feed = {mean_p: mean, var_inv_p: var_inv}
        sess.run([tf.global_variables_initializer(), tf.local_variables_initializer()],init_feed)
        vars_toload = [v for v in tf.global_variables() if v not in vars_noload]
        saver = tf.train.Saver(vars_toload)
        saver.restore(sess, trained_checkpoint_prefix)

        if FLAGS.train_kernel:
            kimg_pl, kedges_pl, kloss, ktrain_step, kfilter = kernel_model((1, 1024, 2048, 1))
            init = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope="kmodel")
            sess.run(tf.variables_initializer(init))
            #sess.run([tf.global_variables_initializer(), tf.local_variables_initializer()])

        for idx in range(num_step):

            start_time = timeit.default_timer()

            inputs = sess.run(input_fetch)

            annot_raw = inputs[2]
            img_raw = inputs[1]
            image_path = inputs[0][0].decode("utf-8")
            filename = os.path.basename(image_path)
            dump_filename = os.path.join(dump_dir, filename + ".npy")

            adv_img_out = sess.run(adv_img, feed_dict={placeholder_tensor: img_raw, annot_pl: annot_raw})
            adv_img_out = adv_img_out[0]

            res, dbg_v = sess.run([fetch, dbg], feed_dict={
                            placeholder_tensor: adv_img_out, annot_pl: annot_raw})

            roc = sess.run(RocPoints)
            auc = -np.trapz(roc[:,1], roc[:,0])

            pred_miou_v, dist_miou_v = sess.run([pred_miou, dist_miou])
            # if auc > 0.1:
            #     import pdb; pdb.set_trace()
            # if idx % 25 == 0 and idx != 0:
            #     roc = sess.run(RocPoints)
            #     plt.plot(roc[:,0], roc[:,1])
            #     plt.show()

            if FLAGS.train_kernel:
                predictions = res["predictions"]
                min_dist_out = res["min_dist_out"]
                edges = cv2.Canny(predictions[0].astype(np.uint8),1,1)
                #import pdb; pdb.set_trace()
                filter = train_kernel(min_dist_out, edges, sess, kimg_pl, kedges_pl, kloss, ktrain_step, kfilter)
                #all_filters.append(filter)
                # kernel = gkern(sigma=0.2)
                dilated = np.expand_dims(cv2.filter2D(edges,-1,filter[...,0,0]),-1).astype(np.float32)
                dilated = dilated/np.max(dilated)
                
                disp = cv2.resize(np.concatenate([to_img(min_dist_out), to_img(dilated)], 1), (int(1920), int(1080)))
                cv2.imshow("test", disp)
                cv2.waitKey(1)

            if FLAGS.write_img:
                prediction_colour = res["prediction_colour"]
                dist_out = res["dist_out"]
                full_dist_out = res["full_dist_out"]
                predictions = res["predictions"]
                min_dist_out = res["min_dist_out"]

                # annot_out = res[8][0]
                # n_values = np.max(annot_out) + 1
                # one_hot_out = np.eye(n_values)[annot_out][...,0,:num_classes]

                min_dist_v = min_dist_out# np.expand_dims(np.nanmin(full_dist_out, -1), -1)
                min_dist_v[np.logical_not(np.isfinite(min_dist_v))] = np.nanmin(min_dist_out)
                min_dist_v = min_dist_v - np.min(min_dist_v) #min now at 0
                min_dist_v = (255*min_dist_v/np.max(min_dist_v)).astype(np.uint8) #max now at 255
                
                save_location = os.path.join(output_directory, filename)
                dist_filename = os.path.join(dist_dir, filename)
                min_filename = os.path.join(min_dir, filename)

                #write_hist(min_dist_out, "Min Dist", os.path.join(hist_dir, filename))

                #all_mins.append(min_dist_out)

                # if idx == 30:
                #     write_hist(all_mins, "Combined Dists", os.path.join(hist_dir, "all"))

                prediction_colour = prediction_colour.astype(np.uint8)
                output_channels = len(label_color_map[0])
                if output_channels == 1:
                    prediction_colour = np.squeeze(prediction_colour[0],-1)
                else:
                    prediction_colour = prediction_colour[0]
                #import pdb; pdb.set_trace()
                write_queue.put((idx, save_location, prediction_colour))
                write_queue.put((idx, min_filename, min_dist_v))
                write_queue.put((idx, dist_filename, dist_out))
            
            if FLAGS.write_out:
                img_dist_out = res["img_dist_out"]
                unscaled_logits_out = res["unscaled_logits_out"]

                #if not os.path.exists(dump_filename):
                write_queue.put((idx, dump_filename, {"dist": img_dist_out, "unscaled_logits": unscaled_logits_out}))
                #else:
                #    print("skipping", filename, "                          ")
            
            if FLAGS.debug:
                dist_out = res[2][0].astype(np.uint8)
                full_dist_out = res[4][0]
                min_dist_out = res[5][0]

                min_dist_v = np.expand_dims(np.nanmin(full_dist_out, -1), -1)
                min_dist_v[np.logical_not(np.isfinite(min_dist_v))] = np.nanmin(full_dist_out)
                min_dist_v = min_dist_v - np.min(min_dist_v) #min now at 0
                min_dist_v = (255*min_dist_v/np.max(min_dist_v)).astype(np.uint8) #max now at 255
                
                final_out = res[7][0]
                annot_out = inputs[2][0]
                img_out = inputs[1][0]
                
                thresh = np.median(min_dist_out)
                grain = (np.max(min_dist_out) - np.min(min_dist_out))/300
                print(thresh, "  ", grain)
                while True:
                    mask = np.expand_dims(min_dist_out < thresh,-1)
                    #cv2.imshow("img", (255*mask).astype(np.uint8))
                    cv2.imshow("img", (img_out*mask).astype(np.uint8))
                    key = cv2.waitKey(1)
                    if key == 27: #escape
                        break
                    elif key == 115: #s
                        thresh += grain
                        print(thresh, "  ", grain)
                    elif key == 119: #w
                        thresh -= grain
                        print(thresh, "  ", grain)
                    elif key == 97: #a
                        grain -= 5
                        print(thresh, "  ", grain)
                    elif key == 100: #d
                        grain += 5
                        print(thresh, "  ", grain)
                    elif key == 112: #p
                        import pdb; pdb.set_trace()
            
            elapsed = timeit.default_timer() - start_time
            end = "\r"
            if idx % 50 == 0:
                #every now and then do regular print
                end = "\n"
            if FLAGS.write_out:
                qsize = write_queue.qsize()
            else:
                qsize = 0
            print('{0:.4f} iter: {1}, pred iou: {2:.6f}, dist iou: {3:.6f}, auc:{4:0.6f}'.format(elapsed, idx+1, pred_miou_v, dist_miou_v, auc))

        if not FLAGS.write_out:
            roc = sess.run(RocPoints)
            pr = sess.run(PrPoints)
            
            make_plots(roc,pr,num_thresholds)
        
        if FLAGS.write_out:
            for w in writers:
                w.close()
        print('{0:.4f} iter: {1}, pred iou: {2:.6f}, dist iou: {3:.6f}'.format(elapsed, idx+1, pred_miou_v, dist_miou_v))