Esempio n. 1
0
def _tile_first_channel(tensor):
    """Return the first of three dim-1 chunks of ``tensor``, repeated three times along dim 1."""
    first = torch.chunk(tensor, 3, dim=1)[0]
    return torch.cat((first, first, first), 1)


def visualize_training(netG, val_loader, input_stack, target_img, target_texture, segment, vis, loss_graph, args):
    """Run ``netG`` over ``val_loader`` and send images and loss curves to ``vis``.

    The generator is evaluated on every batch of ``val_loader``, but only the
    tensors of the last batch are turned into images.  For each sample of that
    batch eight panels are emitted, in this order: sketch, texture, input with
    patches, generator output, eroded segmentation, target image, output
    texture patch, ground-truth texture patch.

    Args:
        netG: callable generator; called with the wrapped ``input_stack``.
        val_loader: iterable yielding ``(img, skg, seg, eroded_seg, txt)``.
        input_stack, target_img, target_texture, segment: caller-owned tensors
            that are resized and overwritten in place with the current batch.
        vis: object exposing ``images(...)`` and ``line(...)``
            (presumably a visdom client -- confirm against the caller).
        loss_graph: mapping with the keys "gs", "g", "gd", "gf", "gpl",
            "gpab", "d" and, when ``args.local_texture_size != -1``,
            "dl" and "gdl".
        args: options object; reads ``color_space``, ``input_texture_patch``,
            ``loss_texture``, ``use_segmentation_patch``,
            ``local_texture_size``, ``patch_size_min``, ``patch_size_max``
            and ``num_input_texture_patch``.

    Raises:
        ValueError: if ``args.color_space`` is not 'lab'/'rgb', if
            ``args.input_texture_patch`` is not 'original_image'/'dtd_texture',
            or if ``val_loader`` yields no batch.
    """
    color_space = args.color_space
    if color_space not in ('lab', 'rgb'):
        raise ValueError("unsupported color_space: %r" % (color_space,))
    if args.input_texture_patch not in ('original_image', 'dtd_texture'):
        raise ValueError("unsupported input_texture_patch: %r" % (args.input_texture_patch,))
    patches_from_image = args.input_texture_patch == 'original_image'

    patchsize = args.local_texture_size
    outputG = None
    for data in val_loader:
        img, skg, seg, eroded_seg, txt = data  # original note: LAB with negative values
        # Half of the time the image itself stands in for the texture.
        if random.random() < 0.5:
            txt = img
        img = custom_transforms.normalize_lab(img)
        skg = custom_transforms.normalize_lab(skg)
        txt = custom_transforms.normalize_lab(txt)
        seg = custom_transforms.normalize_seg(seg)
        eroded_seg = custom_transforms.normalize_seg(eroded_seg)

        # Expand both masks from (bs, w, h) to three identical channels.
        bs, w, h = seg.size()
        seg = seg.view(bs, 1, w, h)
        seg = torch.cat((seg, seg, seg), 1)
        eroded_seg = eroded_seg.view(bs, 1, w, h)
        eroded_seg = torch.cat((eroded_seg, eroded_seg, eroded_seg), 1)

        # Fill value for everything outside the mask: (1 - seg) in channel 0,
        # zero in channels 1 and 2.
        temp = torch.ones(seg.size()) * (1 - seg).float()
        temp[:, 1, :, :] = 0
        temp[:, 2, :, :] = 0
        txt = txt.float() * seg.float() + temp

        if not args.use_segmentation_patch:
            seg.fill_(1)
            eroded_seg.fill_(1)

        inp, texture_loc = gen_input_rand(img if patches_from_image else txt, skg,
                                          eroded_seg[:, 0, :, :] * 100,
                                          args.patch_size_min, args.patch_size_max,
                                          args.num_input_texture_patch)

        img = img.cuda()
        skg = skg.cuda()
        seg = seg.cuda()
        eroded_seg = eroded_seg.cuda()
        txt = txt.cuda()
        inp = inp.cuda()

        # Copy the batch into the caller's pre-allocated buffers.
        input_stack.resize_as_(inp.float()).copy_(inp)
        target_img.resize_as_(img.float()).copy_(img)
        segment.resize_as_(seg.float()).copy_(seg)
        target_texture.resize_as_(txt.float()).copy_(txt)

        outputG = netG(Variable(input_stack))
        if args.loss_texture == 'original_image':
            reference = Variable(target_img)
        else:
            reference = Variable(target_texture)

        if color_space == 'lab':
            # Only the first channel is compared; it is tiled to three channels.
            outputlll = _tile_first_channel(outputG)
            targetlll = _tile_first_channel(reference)
        else:
            outputlll = outputG
            targetlll = reference

        texture_patch = gen_local_patch(patchsize, bs, eroded_seg, seg, outputlll)
        gt_texture_patch = gen_local_patch(patchsize, bs, eroded_seg, seg, targetlll)

    if outputG is None:
        raise ValueError("val_loader yielded no batches; nothing to visualize")

    # Both colour spaces share one path; previously the 'rgb' branch left
    # skg_img, txt_img and the patch images undefined.
    if color_space == 'lab':
        denormalize = custom_transforms.denormalize_lab
    else:
        denormalize = custom_transforms.denormalize_rgb

    out_img = vis_image(denormalize(outputG.data.double().cpu()), color_space)
    patch_out = denormalize(texture_patch.data.double().cpu())
    patch_gt = denormalize(gt_texture_patch.data.double().cpu())
    if color_space == 'lab':
        # Channels 1 and 2 are zeroed so only channel 0 of the patch is shown.
        patch_out[:, 1:3, :, :] = 0
        patch_gt[:, 1:3, :, :] = 0
    temp_out = vis_image(patch_out, color_space)
    temp_gt = vis_image(patch_gt, color_space)

    # Preview the tensor the texture patches were actually sampled from.
    patch_source = img if patches_from_image else txt
    inp_img = vis_patch(denormalize(patch_source.cpu()),
                        denormalize(skg.cpu()),
                        texture_loc,
                        color_space)
    tar_img = vis_image(denormalize(img.cpu()), color_space)
    skg_img = vis_image(denormalize(skg.cpu()), color_space)
    txt_img = vis_image(denormalize(txt.cpu()), color_space)

    # Blank canvases shaped like the texture images; the local patches are
    # pasted into their top-left corner below.
    out_final = [x * 0 for x in txt_img]
    gt_final = [x * 0 for x in txt_img]
    out_img = [x * 255 for x in out_img]
    skg_img = [x * 255 for x in skg_img]
    out_patch = [x * 255 for x in temp_out]
    gt_patch = [x * 255 for x in temp_gt]
    paste = int(args.local_texture_size)
    for t_i in range(bs):
        out_final[t_i][:, 0:paste, 0:paste] = out_patch[t_i][:, :, :]
        gt_final[t_i][:, 0:paste, 0:paste] = gt_patch[t_i][:, :, :]

    txt_img = [x * 255 for x in txt_img]
    inp_img = [x * 255 for x in inp_img]
    tar_img = [x * 255 for x in tar_img]
    segment_img = [x * 255 for x in eroded_seg.cpu().numpy()]

    imgs = []
    for panels in zip(skg_img, txt_img, inp_img, out_img,
                      segment_img, tar_img, out_final, gt_final):
        imgs.extend(panels)

    vis.images(imgs, win='output', opts=dict(title='Output images'))

    # (loss_graph key / window name, plot title)
    loss_plots = [
        ("gs", 'G-Style Loss'),
        ("g", 'G Total Loss'),
        ("gd", 'G-Discriminator Loss'),
        ("gf", 'G-Feature Loss'),
        ("gpl", 'G-Pixel Loss-L'),
        ("gpab", 'G-Pixel Loss-AB'),
        ("d", 'D Loss'),
    ]
    if args.local_texture_size != -1:
        loss_plots.append(("dl", 'D Local Loss'))
        loss_plots.append(("gdl", 'G D Local Loss'))
    for key, title in loss_plots:
        vis.line(np.array(loss_graph[key]), win=key, opts=dict(title=title))
Esempio n. 2
0
# Notebook-style fragment: expects img, skg, txt, seg, eroded_seg, data, netG
# and color_space to be defined by earlier cells.

# Normalise the three LAB tensors, then the two segmentation masks.
img, skg, txt = (custom_transforms.normalize_lab(t) for t in (img, skg, txt))
seg, eroded_seg = (custom_transforms.normalize_seg(t) for t in (seg, eroded_seg))

# Build the generator input; the numeric arguments are passed through as-is.
inp, texture_loc = get_input(data, 100, 100, 50, 1)

# Turn the mask into a boolean "non-zero" mask.
seg = seg != 0

model = netG

# Forward pass on the GPU.
inpv = get_inputv(inp.cuda())
output = model(inpv.cuda())

# Convert output, input preview and target back to displayable images.
lab_to_display = custom_transforms.denormalize_lab
out_img = vis_image(lab_to_display(output.data.double().cpu()), color_space)
inp_img = vis_patch(
    lab_to_display(txt.cpu()),
    lab_to_display(skg.cpu()),
    texture_loc,
    color_space,
)
tar_img = vis_image(lab_to_display(img.cpu()), color_space)

# One figure each for the input preview and the generated image; the first
# batch item is shown with its axes reordered to (1, 2, 0).
for shown in (inp_img, out_img):
    plt.figure()
    plt.imshow(np.transpose(shown[0], (1, 2, 0)))
plt.show()
Esempio n. 3
0
def get_prediction(image_bytes, txt_bytes, pos_x, pos_y):
    """Run the pretrained TextureGAN generator on one image/texture pair.

    Args:
        image_bytes: source handed to ``pil_loader1``; used for both the
            target image and the sketch.
        txt_bytes: source handed to ``pil_loader1`` for the texture.
        pos_x, pos_y: forwarded unchanged to ``get_input``
            (presumably the texture patch position -- confirm against
            ``get_input``).

    Returns:
        The first generated image with its axes reordered to ``(1, 2, 0)``.
        If a step after the input preview was built fails, the preview
        (texture patch on sketch, same axis order) is returned instead.

    Raises:
        Exception: any error raised before the input preview exists is
            re-raised unchanged.  Previously the handler failed with an
            UnboundLocalError that hid the real cause.
    """
    import traceback  # only needed by the error handler below

    color_space = 'lab'
    preview = None  # fallback result; set once the input visualisation exists
    try:
        transform = get_transforms(args)

        # The segmentation mask and its eroded version come from the same
        # fixed file.
        seg_path = "./dataset/test/blank.jpg"
        img = pil_loader1(image_bytes)
        skg = pil_loader1(image_bytes)
        seg = pil_loader(seg_path)
        txt = pil_loader1(txt_bytes)
        eroded_seg = pil_loader(seg_path)

        img, skg, seg, eroded_seg, txt = transform([img, skg, seg, eroded_seg, txt])
        # Add a leading batch dimension of size one to every tensor.
        data = [t.unsqueeze(0) for t in (img, skg, seg, eroded_seg, txt)]
        img, skg, seg, eroded_seg, txt = data

        # NOTE(review): the network and its weights are reloaded on every call.
        netG = texturegan.TextureGAN(5, 3, 32)
        load_network(netG, './pretrained_models/G_net_texturegan_18_300.pth')
        netG.eval()

        warnings.filterwarnings('ignore')

        # NOTE(review): `data` still references the tensors from before these
        # rebinding assignments, and that is what get_input receives.  All
        # five calls are kept because it is not visible here whether the
        # normalisers modify their argument in place.
        img = custom_transforms.normalize_lab(img)
        skg = custom_transforms.normalize_lab(skg)
        txt = custom_transforms.normalize_lab(txt)
        seg = custom_transforms.normalize_seg(seg)
        eroded_seg = custom_transforms.normalize_seg(eroded_seg)

        inp, texture_loc = get_input(data, pos_x, pos_y, 50, 1)

        # Build the fallback before inference so the handler always has a
        # well-defined, consistently shaped value to return.
        inp_img = vis_patch(custom_transforms.denormalize_lab(txt.cpu()),
                            custom_transforms.denormalize_lab(skg.cpu()),
                            texture_loc,
                            color_space)
        preview = np.transpose(inp_img[0], (1, 2, 0))

        inpv = get_inputv(inp.cuda())
        output = netG(inpv.cuda())

        out_img = vis_image(custom_transforms.denormalize_lab(output.data.double().cpu()),
                            color_space)
        return np.transpose(out_img[0], (1, 2, 0))
    except Exception:
        print("error!!!!!!!!!!!!")
        traceback.print_exc()
        if preview is None:
            # Nothing to fall back to: surface the real error.
            raise
        return preview