def main(list_fn="/data/unagi0/dataset/SUNCG-Seg/data_goodlist_v2.txt",
         gt_dir="/data/unagi0/dataset/SUNCG-Seg/category_v2/",
         out_fn="suncg_gt_distribution.json"):
    """Tally how often each label value occurs in the SUNCG-Seg ground truth.

    Reads the sample names listed in ``list_fn`` (one per line, blank lines
    skipped), loads ``<gt_dir>/<name>_category40.png`` for each, counts every
    value in the image, prints the tally and the total count, and writes the
    tally to ``out_fn`` via ``save_dic_to_json``.

    Args:
        list_fn: Text file listing the sample names to process.
        gt_dir: Directory containing the ``*_category40.png`` label images.
        out_fn: Path of the JSON file the label histogram is written to.

    Raises:
        AssertionError: If the number of counted values differs from the
            total width*height of all images, e.g. when a label image has
            more than one channel.
    """
    with open(list_fn) as list_file:
        fn_list = [line.strip() for line in list_file if line.strip()]

    counter = Counter()
    expected_n_pixel = 0

    for fn in tqdm(fn_list):
        gt_fn = os.path.join(gt_dir, fn + "_category40.png")
        # Close each image file instead of leaking one handle per sample.
        with Image.open(gt_fn) as gt_im:
            width, height = gt_im.size
            gt_arr = np.array(gt_im)
        # np.unique is much faster than a per-pixel Counter, and .tolist()
        # turns the keys into plain Python ints that JSON can serialise.
        labels, counts = np.unique(gt_arr, return_counts=True)
        counter.update(dict(zip(labels.tolist(), counts.tolist())))
        # Accumulate per image rather than assuming every image has the
        # same size as the last one.
        expected_n_pixel += width * height

    print(counter)

    got_n_pixel = sum(counter.values())

    # Sanity check: exactly one label value per pixel was counted.
    assert got_n_pixel == expected_n_pixel, (got_n_pixel, expected_n_pixel)

    print(got_n_pixel)

    save_dic_to_json(counter, out_fn)

if __name__ == '__main__':
    # Per-dataset locations of the label metadata and image directories.
    # Alternative ground-truth directories (e.g. prediction outputs) can be
    # supplied with --gt_dir instead of editing this table.
    dataset_dic = {
        "suncg": {
            "json_fn": "./dataset/nyu_info.json",
            "raw_rgb_dir": "/data/unagi0/dataset/SUNCG-Seg/mlt_v2",
            "raw_optional_img_dir": "/data/unagi0/dataset/SUNCG-Seg/hha_v2",
            "gt_dir": "/data/unagi0/dataset/SUNCG-Seg/category_v2",
        }
    }

    parser = argparse.ArgumentParser(description='')

    parser.add_argument("--gt_dir", type=str, default=None,
                        help="gt dir (overrides the dataset's default gt_dir)")
    parser.add_argument("--way", type=str, default="legend", help="legend or colorize",
                        choices=['legend', 'colorize'])
    parser.add_argument("--ext", type=str, default="pdf")
    # Restrict to known datasets so a typo fails with a usage message
    # instead of a KeyError below.
    parser.add_argument("--dataset", type=str, default="suncg",
                        choices=sorted(dataset_dic))
    parser.add_argument("--title_names", type=str, default=None, nargs='*')
    # Previously read as args.outdir without ever being declared.
    parser.add_argument("--outdir", type=str, default=None,
                        help="output dir (required for --way legend)")

    args = parser.parse_args()

    dataset_info = dataset_dic[args.dataset]
    raw_rgb_dir = dataset_info["raw_rgb_dir"]
    raw_optional_img_dir = dataset_info["raw_optional_img_dir"]
    gt_dir = args.gt_dir if args.gt_dir is not None else dataset_info["gt_dir"]

    with open(dataset_info["json_fn"], 'r') as f:
        info = json.load(f)
    # np.str was removed in NumPy 1.24; the builtin str yields the same
    # unicode dtype.
    label_list = np.array(info['label'] + ["background"], dtype=str)

    if args.way == "legend":
        if args.outdir is None:
            parser.error("--outdir is required when --way is 'legend'")

        vis_with_legend(indir_list=[], outdir=args.outdir, label_list=label_list, raw_rgb_dir=raw_rgb_dir,
                        raw_optional_img_dir=raw_optional_img_dir, gt_dir=gt_dir, ext=args.ext,
                        title_names=args.title_names)

    elif args.way == "colorize":  # TODO
        # Previously constructed but never raised, so it silently did nothing.
        raise NotImplementedError("--way colorize is not implemented yet")
# --- Example No. 2 ---
                              list(model_g2.parameters()) +
                              list(model_f1.parameters()),
                              opt=args.opt,
                              lr=args.lr,
                              momentum=args.momentum,
                              weight_decay=args.weight_decay)
    # Restore the g1/g2/f1 weights and the optimizer state stored in the
    # checkpoint under 'g1_state_dict', 'g2_state_dict', 'f1_state_dict'
    # and 'optimizer'.
    model_g1.load_state_dict(checkpoint['g1_state_dict'])
    model_g2.load_state_dict(checkpoint['g2_state_dict'])
    model_f1.load_state_dict(checkpoint['f1_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    print("=> loaded checkpoint '{}'".format(args.resume))

    # Record this resumed run's arguments in
    # <outdir>/param_<savename>_resume.json.
    # NOTE(review): check_if_done is defined elsewhere; presumably it aborts
    # when that file already exists — confirm.
    json_fn = os.path.join(args.outdir, "param_%s_resume.json" % args.savename)
    check_if_done(json_fn)
    # os.uname()[1] is the node (host) name; os.uname is POSIX-only.
    args.machine = os.uname()[1]
    save_dic_to_json(args.__dict__, json_fn)

    # Continue from the epoch number saved in the checkpoint.
    start_epoch = checkpoint['epoch']

else:
    # No checkpoint to resume from: build new model instances.
    # The unpacking shows get_multichannel_model returns ([g1, g2], f1).
    # NOTE(review): presumably one g-branch per entry of args.inch_list — confirm.
    [model_g1, model_g2], model_f1 = get_multichannel_model(
        net_name=args.net,
        input_ch_list=args.inch_list,
        n_class=args.n_class,
        method=detailed_method,
        res=args.res,
        is_data_parallel=args.is_data_parallel)

    # A single optimizer over the parameters of g1, g2 and f1 together.
    optimizer = get_optimizer(list(model_g1.parameters()) +
                              list(model_g2.parameters()) +
                              list(model_f1.parameters()),
from util import save_dic_to_json

# Tally how often each label value occurs in the SUNCG-Seg *_category40.png
# ground-truth images, sanity-check the total and save it as JSON.
with open("/data/unagi0/dataset/SUNCG-Seg/data_goodlist_v2.txt") as _list_file:
    # Skip blank lines so they don't turn into a bogus "_category40.png" path.
    fn_list = [fn.strip() for fn in _list_file if fn.strip()]

counter = Counter()
expected_n_pixel = 0  # total width*height over all images read

for fn in tqdm(fn_list):
    gt_fn = os.path.join("/data/unagi0/dataset/SUNCG-Seg/category_v2/",
                         fn + "_category40.png")
    # Close each image file instead of leaking one handle per sample.
    with Image.open(gt_fn) as gt_im:
        width, height = gt_im.size
        gt_arr = np.array(gt_im)
    # `long` does not exist in Python 3. np.unique + tolist() yields plain
    # Python int keys (JSON-serialisable) and avoids a per-pixel Counter.
    labels, counts = np.unique(gt_arr, return_counts=True)
    counter.update(dict(zip(labels.tolist(), counts.tolist())))
    # Accumulate per image instead of assuming every image matches the last.
    expected_n_pixel += width * height

print(counter)

got_n_pixel = sum(counter.values())

# Sanity check: exactly one label value per pixel was counted.
assert got_n_pixel == expected_n_pixel, (got_n_pixel, expected_n_pixel)

print(got_n_pixel)

save_dic_to_json(counter, "suncg_gt_distribution.json")