def calc_colors():
    opt = TestOptions().parse()
    save_img_path = opt.results_img_dir
    if not os.path.isdir(save_img_path):
        print('Creating path: {0}'.format(save_img_path))
        os.makedirs(save_img_path)
    opt.batch_size = 1
    dataset = Fusion_Testing_Dataset(opt)
    dataset_loader = torch.utils.data.DataLoader(dataset, batch_size=1, num_workers=2)

    dataset_size = len(dataset)
    print('#Testing images = %d' % dataset_size)

    model = create_model(opt)
    # model.setup_to_test('coco_finetuned_mask_256')
    model.setup_to_test('coco_finetuned_mask_256_ffs')

    count_empty = 0
    for data_raw in tqdm(dataset_loader, dynamic_ncols=True):
        # if os.path.isfile(join(save_img_path, data_raw['file_id'][0] + '.png')) is True:
        #     continue
        data_raw['full_img'][0] = data_raw['full_img'][0].cuda()
        if data_raw['empty_box'][0] == 0:
            data_raw['cropped_img'][0] = data_raw['cropped_img'][0].cuda()
            box_info = data_raw['box_info'][0]
            box_info_2x = data_raw['box_info_2x'][0]
            box_info_4x = data_raw['box_info_4x'][0]
            box_info_8x = data_raw['box_info_8x'][0]
            cropped_data = util.get_colorization_data(data_raw['cropped_img'], opt, ab_thresh=0, p=opt.sample_p)
            full_img_data = util.get_colorization_data(data_raw['full_img'], opt, ab_thresh=0, p=opt.sample_p)
            model.set_input(cropped_data)
            model.set_fusion_input(full_img_data, [box_info, box_info_2x, box_info_4x, box_info_8x])
            model.forward()
        else:
            count_empty += 1
            full_img_data = util.get_colorization_data(data_raw['full_img'], opt, ab_thresh=0, p=opt.sample_p)
            model.set_forward_without_box(full_img_data)
        specific_file_save_path = join(save_img_path, data_raw['file_id'][0] + '.png')
        model.save_current_imgs(specific_file_save_path)
    print('{0} images without bounding boxes'.format(count_empty))
    return specific_file_save_path  # path of the last image written
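A minimal sketch of driving this function (script and flag names are illustrative; calc_colors() parses TestOptions from the CLI itself):

# hypothetical driver:  python test_script.py --results_img_dir ./results
if __name__ == '__main__':
    last_path = calc_colors()
    print('last image written to', last_path)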
Example #2
    # './datasets/val/100APPLE/IMG_0791.png']
    for i, data_raw in enumerate(dataset_loader):
        data_raw = data_raw.cuda()
        data_raw = util.crop_mult(data_raw, mult=8)

        # pdb.set_trace();
        # (Pdb) pp data_raw.size()
        # torch.Size([1, 3, 256, 256])

        # with no points
        for pp, sample_p in enumerate(sample_ps):
            xxx = '%08d_%.3f' % (i, sample_p)
            img_path = [xxx.replace('.', 'p')]

            data = util.get_colorization_data(data_raw,
                                              opt,
                                              ab_thresh=0.,
                                              p=sample_p)

            # (Pdb) pp data.keys()
            # dict_keys(['A', 'B', 'hint_B', 'mask_B'])

            # (Pdb) pp data['hint_B'].size()
            # torch.Size([1, 2, 256, 256])
            # (Pdb) pp data['hint_B'].max()
            # tensor(0., device='cuda:0')
            # (Pdb) pp data['hint_B'].min()
            # tensor(0., device='cuda:0')

            # (Pdb) pp data['mask_B'].size()
            # torch.Size([1, 1, 256, 256])
            # (Pdb) pp data['mask_B'].min()
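The pdb transcript above documents the contract of util.get_colorization_data: a dict with 'A' (the L channel), 'B' (the ab target), 'hint_B' (sparse ab hints, all zero here since no points were sampled), and 'mask_B' (where hints were given). The usual next step, visible in Examples #7 and #8 below, is to stack these into one network input; a minimal sketch:

import torch

data = util.get_colorization_data(data_raw, opt, ab_thresh=0., p=sample_p)
# L channel (1) + ab hints (2) + hint mask (1) -> 4-channel input
net_input = torch.cat((data['A'], data['hint_B'], data['mask_B']), dim=1)
assert net_input.shape[1] == 4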
Example #3
    psnrs = np.zeros((opt.how_many, S))
    entrs = np.zeros((opt.how_many, S))
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    for i, data_raw in enumerate(dataset_loader):
        data_raw[0] = data_raw[0].to(device)
        data_raw[0] = util.crop_mult(data_raw[0], mult=8)
        data = {}
        A = []
        B = []

        HintB = []
        MaskB = []
        for f in range(0, data_raw[0].shape[1] * opt.n_frames, 3):
            # colorize one 3-channel frame at a time
            d = util.get_colorization_data(data_raw[0][:, f:f + 3], opt, p=opt.sample_p)
            A.append(d["A"])
            B.append(d["B"])
            HintB.append(d["hint_B"])
            MaskB.append(d["mask_B"])
        data["A"] = torch.cat(A, 1)
        data["B"] = torch.cat(B, 1)
        data["hint_B"] = torch.cat(HintB, 1)
        data["mask_B"] = torch.cat(MaskB, 1)

        # with no points
        for pp, sample_p in enumerate(sample_ps):
            img_path = [('%08d_%.3f' % (i, sample_p)).replace('.', 'p')]
            # data = util.get_colorization_data(data_raw[0], opt, ab_thresh=0., p=sample_p)

            model.set_input(data)
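Example #3 cuts off right after model.set_input(data). A plausible continuation, mirroring the evaluation loop of Example #5 below, would run the model and record the PSNR into the psnrs array allocated at the top (a hedged sketch, not this example's actual source):

            model.test()
            visuals = model.get_current_visuals()
            psnrs[i, pp] = util.calculate_psnr_np(
                util.tensor2im(visuals['real']),
                util.tensor2im(visuals['fake_reg']))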
Example #5
                                      time.minute)

    shutil.copyfile('./checkpoints/%s/latest_net_G.pth' % opt.name,
                    './checkpoints/%s/%s_net_G.pth' % (opt.name, str_now))

    psnrs = np.zeros((opt.how_many, N))

    bar = pb.ProgressBar(maxval=opt.how_many).start()
    for i, data_raw in enumerate(dataset_loader):
        data_raw[0] = data_raw[0].cuda()
        data_raw[0] = util.crop_mult(data_raw[0], mult=8)

        for nn in range(N):
            # embed()
            data = util.get_colorization_data(data_raw,
                                              opt,
                                              ab_thresh=0.,
                                              num_points=num_points[nn])

            model.set_input(data)
            model.test()
            visuals = model.get_current_visuals()

            psnrs[i, nn] = util.calculate_psnr_np(
                util.tensor2im(visuals['real']),
                util.tensor2im(visuals['fake_reg']))

        if i == opt.how_many - 1:
            break

        bar.update(i)
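For reference, util.calculate_psnr_np compares two uint8 images. A typical NumPy implementation (a sketch assuming 8-bit inputs, not necessarily the repo's exact code):

import numpy as np

def calculate_psnr_np_sketch(img1, img2):
    # PSNR for 8-bit images: 20*log10(255) - 10*log10(MSE)
    mse = np.mean((img1.astype(np.float64) - img2.astype(np.float64)) ** 2)
    return 20 * np.log10(255.0) - 10 * np.log10(mse)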
Example #6
    opt.display_port = 8098
    visualizer = Visualizer(opt)
    total_steps = 0

    if opt.stage in ('full', 'instance'):
        for epoch in trange(opt.epoch_count, opt.niter + opt.niter_decay, desc='epoch', dynamic_ncols=True):
            epoch_iter = 0

            for data_raw in tqdm(dataset_loader, desc='batch', dynamic_ncols=True, leave=False):
                total_steps += opt.batch_size
                epoch_iter += opt.batch_size

                data_raw['rgb_img'] = [data_raw['rgb_img']]
                data_raw['gray_img'] = [data_raw['gray_img']]

                input_data = util.get_colorization_data(data_raw['gray_img'], opt, p=1.0, ab_thresh=0)
                gt_data = util.get_colorization_data(data_raw['rgb_img'], opt, p=1.0, ab_thresh=10.0)
                if gt_data is None:
                    continue
                if gt_data['B'].shape[0] < opt.batch_size:
                    continue
                input_data['B'] = gt_data['B']
                input_data['hint_B'] = gt_data['hint_B']
                input_data['mask_B'] = gt_data['mask_B']

                visualizer.reset()
                model.set_input(input_data)
                model.optimize_parameters()

                if total_steps % opt.display_freq == 0:
                    save_result = total_steps % opt.update_html_freq == 0
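The loop is truncated at the display step. In pix2pix-style training scripts, the body of that check usually hands the current visuals to the Visualizer; a hedged sketch of the conventional continuation (display_current_results is the usual Visualizer entry point, assumed here):

                if total_steps % opt.display_freq == 0:
                    save_result = total_steps % opt.update_html_freq == 0
                    visualizer.display_current_results(model.get_current_visuals(), epoch, save_result)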
Example #7
            transforms.RandomChoice([
                transforms.RandomResizedCrop(opt.fineSize, interpolation=1),
                transforms.RandomResizedCrop(opt.fineSize, interpolation=2),
                transforms.RandomResizedCrop(opt.fineSize, interpolation=3)
            ]),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor()
        ]))
    lens = len(train_datasets)
    print('train dataset size: [{}]'.format(lens))
    dataloader = torch.utils.data.DataLoader(train_datasets,
                                             batch_size=batch_size,
                                             shuffle=True)
    for epoch in range(epoches):
        for index, data in enumerate(dataloader):
            data = util.get_colorization_data(data, opt, p=opt.sample_p)
            if data is None:
                continue
            input = torch.cat((data['A'], data['hint_B'], data['mask_B']),
                              dim=1)
            input = input.to(device)
            outputclass, outputreg = net(input)
            realclass = util.encode_ab_ind(data['B'][:, :, ::4, ::4],
                                           opt).to(device)
            lossreg = L1oss(outputreg, data['B'].to(device))
            # print(outputclass.dtype, realclass.dtype)  # debug leftover
            lossclass = CEloss(
                outputclass.type(torch.cuda.FloatTensor),
                realclass[:, 0, :, :].type(torch.cuda.LongTensor))

            if record:
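The classification branch needs integer bin targets, which is what util.encode_ab_ind supplies: it quantizes the ab plane into a grid and returns one bin index per pixel. A sketch of that quantization, assuming the usual ab_norm/ab_max/ab_quant/A options from the SIGGRAPH colorization code (not necessarily the exact implementation):

import torch

def encode_ab_ind_sketch(data_ab, opt):
    # map each ab value to a bin coordinate in [0, opt.A), then flatten to a single index
    ab_rs = torch.round((data_ab * opt.ab_norm + opt.ab_max) / opt.ab_quant)
    return ab_rs[:, [0], :, :] * opt.A + ab_rs[:, [1], :, :]  # a_bin * A + b_bin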
Example #8
    # net = siggraph.SIGGRAPHGenerator(4, 2)
    net.load_state_dict(torch.load(state_dict_name))
    net.eval()
    net.to(device)
    if len(gpu_ids) > 1:
        net = nn.DataParallel(net, device_ids=[int(item) for item in gpu_ids.split(',')])  # list of int

    train_datasets = ImageFolder(os.path.join(root, phase),
                                 transform=transforms.Compose([
                                     transforms.Resize((opt.loadSize, opt.loadSize)),
                                     transforms.ToTensor()]))
    lens = len(train_datasets)
    print('test dataset size: [{}]'.format(lens))
    dataloader = torch.utils.data.DataLoader(train_datasets, batch_size=batch_size, shuffle=False)
    for index, data in enumerate(dataloader):
        data[0] = util.crop_mult(data[0], mult=8)
        data = util.get_colorization_data(data, opt, ab_thresh=0, p=opt.sample_p)
        if data is None:
            continue
        input = torch.cat((data['A'], data['hint_B'], data['mask_B']), dim=1)
        input = input.to(device)
        outputclass, outputreg = net(input)
        realclass = util.encode_ab_ind(data['B'][:, :, ::4, ::4], opt).to(device)
        lossreg = L1oss(outputreg, data['B'].to(device))
        lossclass = CEloss(outputclass.type(torch.cuda.FloatTensor),
                           realclass[:, 0, :, :].type(torch.cuda.LongTensor))

        image_fake = util.lab2rgb(torch.cat([data['A'].type(torch.cuda.FloatTensor),
                                             outputreg.type(torch.cuda.FloatTensor)], dim=1), opt)
        print('  images [{}/{}] loss is [reg: {:.5}]/[class: {:.5}]'
              .format((index + 1) * data['A'].shape[0], lens, lossreg.item() * 10, lossclass.item()))
        # torchvision.utils.save_image(image_fake, 'hemin.png')
        if index % loss_fre == 0:
            writer.add_scalars('train/loss:', {'reg': lossreg.item() * 10,
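Both test loops call util.crop_mult(data[0], mult=8) before building hints because the generator downsamples spatially, so height and width must be divisible by 8. A minimal sketch of such a center crop (the real helper also takes an HWmax cap, as the call in Example #10 shows):

def crop_mult_sketch(data, mult=8):
    # center-crop an NCHW tensor so H and W are divisible by mult
    H, W = data.shape[2], data.shape[3]
    Hnew, Wnew = (H // mult) * mult, (W // mult) * mult
    h, w = (H - Hnew) // 2, (W - Wnew) // 2
    return data[:, :, h:h + Hnew, w:w + Wnew]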
Example #9
def select_photos():
    to_visualize = ['gray', 'real', 'fake_reg', 'mask', 'hint']
    opt.load_model = True
    opt.num_threads = 1
    opt.batch_size = 1
    opt.display_id = -1
    opt.phase = 'val'

    opt.serial_batches = True
    opt.aspect_ratio = 1.

    model = create_model(opt)
    model.setup(opt)

    # tensor_image = tensor_image.cuda()
    sample_ps = 0.03125
    data_raw = [None] * 4

    # loader = transforms.Compose(transforms.ToTensor())
    # raw_image = Image.open(args.input)
    # # tensor_image = TF.to_tensor(raw_image)
    # tensor_image = loader(raw_image.float())
    # tensor_image = raw_image.unsqueeze(0)
    # data_raw[0] = tensor_image.cuda()

    # print(opt.input)
    tensor_image = Image.open(opt.input).convert('RGB')
    tensor_image = tensor_image.resize((opt.loadSize, opt.loadSize))
    tensor_image = ToTensor()(tensor_image).unsqueeze(0)
    data_raw[0] = tensor_image.cuda()
    # data_raw[0] = util.crop_mult(data_raw[0], mult=8)

    data = util.get_colorization_data(data_raw, opt, ab_thresh=0., p=sample_ps)

    model.set_input(data)

    # model.eval()
    model.test(True)

    # gets the visuals from the model
    global visuals
    visuals = util.get_subset_dict(model.get_current_visuals(), to_visualize)

    # output images
    raw_image = Image.fromarray(util.tensor2im(visuals['real']))
    image = raw_image.resize((512, 512), Image.LANCZOS)  # ANTIALIAS was removed in Pillow 10
    image = ImageTk.PhotoImage(image)
    label = tk.Label(text="Original Image", compound='top', image=image)
    label.photo = image  # keep a reference so Tkinter does not garbage-collect the PhotoImage
    label.grid(row=1, column=1)
    all_labels.append(label)

    raw_image = Image.fromarray(util.tensor2im(visuals['hint']))
    # lab_raw_image = util.rgb2lab(raw_image, opt)

    image = raw_image.resize((512, 512), Image.LANCZOS)
    image = ImageTk.PhotoImage(image)
    label = tk.Label(image=image, text="Hint Image", compound='top')
    label.bind("<Button-1>", lambda e: choose_colour())
    label.photo = image
    label.grid(row=1, column=3)
    all_labels.append(label)

    raw_image = Image.fromarray(util.tensor2im(visuals['fake_reg']))
    image = raw_image.resize((512, 512), Image.LANCZOS)
    image = ImageTk.PhotoImage(image)
    label = tk.Label(image=image, text="Colourised Image", compound='top')
    label.bind("<Button-1>", lambda e: choose_colour())
    label.photo = image
    label.grid(row=1, column=5)
    all_labels.append(label)

    original = Image.open(opt.input)
    original = original.resize((opt.loadSize, opt.loadSize))
    original = np.asarray(original)
    label = tk.Label(text="PSNR Numpy: " + str(
        util.calculate_psnr_np(original, util.tensor2im(visuals['fake_reg']))))
    label.grid(row=4, column=1)
    all_labels.append(label)

    label = tk.Label(
        text="MSE : " +
        str(util.calculate_mse(original, util.tensor2im(visuals['fake_reg']))))
    label.grid(row=4, column=3)
    all_labels.append(label)

    label = tk.Label(
        text="MAE : " +
        str(util.calculate_mae(original, util.tensor2im(visuals['fake_reg']))))
    label.grid(row=4, column=5)

    all_labels.append(label)
Example #10
        self.model.forward()
        fake_reg = torch.cat((self.model.real_A, self.model.fake_B_reg), dim=1)
        return fake_reg


image_path = "./large.JPG"
image = Image.open(image_path)
image = transforms.Compose([
    transforms.Resize(512),
    transforms.ToTensor(),
])(image)
image = image.view(1, *image.shape)
image = util.crop_mult(image, mult=8, HWmax=[4032, 4032])
transforms.ToPILImage()(image[0]).show(command='fim')

data = util.get_colorization_data([image], opt, ab_thresh=0., p=0.125)
img = torch.cat((data["A"], data["B"]), dim=1)
hint = torch.cat((data["hint_B"], data["mask_B"]), dim=1)

# print(data["mask_B"], data["hint_B"])
# data["hint_B"] = torch.zeros_like(data["hint_B"])
# data["mask_B"] = torch.zeros_like(data["mask_B"])
# model = Colorization()
with torch.no_grad():
    model = Colorization()
    model.eval()
    for param in model.parameters():
        param.requires_grad = False
    model.model.set_requires_grad(model.model.netG)

# model(data)
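The commented-out model(data) call hints at the intended usage. Assuming Colorization.forward accepts the hint dict and returns L concatenated with the predicted ab, as the fragment at the top of this example shows, converting the output back to RGB would look like this sketch:

# hedged usage sketch; assumes forward() takes the dict built above
with torch.no_grad():
    fake_lab = model(data)             # (1, 3, H, W): real_A plus fake_B_reg
    rgb = util.lab2rgb(fake_lab, opt)  # same conversion as in Example #8
transforms.ToPILImage()(rgb[0].cpu()).save('colorized.png')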
Example #11
def open(fname):  # NOTE: shadows the built-in open()
    # copy to input
    input_img = 'example/input.jpg'
    if fname.startswith('http'):
        urlretrieve(fname, input_img)
    else:
        shutil.copyfile(fname, input_img)  # safer than shelling out with os.system
    # find bbox
    img = cv2.imread(input_img)
    lab_image = cv2.cvtColor(img, cv2.COLOR_BGR2LAB)
    l_channel, _, _ = cv2.split(lab_image)  # only L is needed here
    l_stack = np.stack([l_channel, l_channel, l_channel], axis=2)
    outputs = predictor(l_stack)
    save_path = 'example_bbox/input'
    pred_bbox = outputs["instances"].pred_boxes.to(
        torch.device('cpu')).tensor.numpy()
    pred_scores = outputs["instances"].scores.cpu().data.numpy()
    np.savez(save_path, bbox=pred_bbox, scores=pred_scores)
    # load data
    dataset = Fusion_Testing_Dataset(opt, -1)
    dataset_loader = torch.utils.data.DataLoader(dataset,
                                                 batch_size=opt.batch_size)
    data_raw = next(iter(dataset_loader))
    # predict colors
    data_raw['full_img'][0] = data_raw['full_img'][0].cuda()
    if data_raw['empty_box'][0] == 0:
        data_raw['cropped_img'][0] = data_raw['cropped_img'][0].cuda()
        box_info = data_raw['box_info'][0]
        box_info_2x = data_raw['box_info_2x'][0]
        box_info_4x = data_raw['box_info_4x'][0]
        box_info_8x = data_raw['box_info_8x'][0]
        cropped_data = util.get_colorization_data(data_raw['cropped_img'],
                                                  opt,
                                                  ab_thresh=0,
                                                  p=opt.sample_p)
        full_img_data = util.get_colorization_data(data_raw['full_img'],
                                                   opt,
                                                   ab_thresh=0,
                                                   p=opt.sample_p)
        model.set_input(cropped_data)
        model.set_fusion_input(
            full_img_data, [box_info, box_info_2x, box_info_4x, box_info_8x])
        model.forward()
    else:
        full_img_data = util.get_colorization_data(data_raw['full_img'],
                                                   opt,
                                                   ab_thresh=0,
                                                   p=opt.sample_p)
        model.set_forward_without_box(full_img_data)
    model.save_current_imgs(
        join(save_img_path, data_raw['file_id'][0] + '.png'))
    # combine grey & color
    img = cv2.imread('example/input.jpg')
    lab_image = cv2.cvtColor(img, cv2.COLOR_BGR2LAB)
    l_channel, _, _ = cv2.split(lab_image)

    img = cv2.imread('results/input.png')
    lab_image = cv2.cvtColor(img, cv2.COLOR_BGR2LAB)
    _, a_pred, b_pred = cv2.split(lab_image)
    a_pred = cv2.resize(a_pred, (l_channel.shape[1], l_channel.shape[0]))
    b_pred = cv2.resize(b_pred, (l_channel.shape[1], l_channel.shape[0]))

    color_image = cv2.cvtColor(np.stack([l_channel, a_pred, b_pred], 2),
                               cv2.COLOR_LAB2BGR)
    return PIL.Image.fromarray(color_image[:, :, ::-1])  # BGR -> RGB
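Usage is then a single call. A hypothetical invocation (the path is illustrative, and note that this open() shadows the built-in):

colorized = open('example/photo.jpg')  # a local path or an http(s) URL
colorized.save('example/colorized.jpg')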