# NOTE: standard-library and third-party imports needed by this module. The
# project-internal names (custom_transforms, texturegan, vis_image, vis_patch,
# gen_input_rand, gen_input_exact, gen_local_patch, load_network, save_network,
# get_transforms, get_inputv, renormalize, pil_loader, pil_loader1,
# rand_between, GramMatrix, args, ...) come from the TextureGAN codebase;
# their exact module paths depend on your checkout.
import random
import warnings

import numpy as np
import torch
import torch.nn as nn
from torch.autograd import Variable


def get_input(val_loader, xcenter, ycenter, patch_size, num_patch):
    img, skg, seg, eroded_seg, txt = val_loader

    img = custom_transforms.normalize_lab(img)
    skg = custom_transforms.normalize_lab(skg)
    txt = custom_transforms.normalize_lab(txt)
    seg = custom_transforms.normalize_seg(seg)
    eroded_seg = custom_transforms.normalize_seg(eroded_seg)

    bs, w, h = seg.size()

    seg = seg.view(bs, 1, w, h)
    seg = torch.cat((seg, seg, seg), 1)
    eroded_seg = eroded_seg.view(bs, 1, w, h)
    eroded_seg = torch.cat((eroded_seg, eroded_seg, eroded_seg), 1)

    # Outside the segmentation mask, set the L channel to white and zero ab.
    temp = torch.ones(seg.size()) * (1 - seg).float()
    temp[:, 1, :, :] = 0
    temp[:, 2, :, :] = 0
    txt = txt.float() * seg.float() + temp

    if xcenter < 0 or ycenter < 0:
        # Negative center coordinates request randomly placed texture patches.
        inp, texture_loc = gen_input_rand(txt, skg, eroded_seg[:, 0, :, :] * 100,
                                          patch_size, patch_size, num_patch)
    else:
        inp, texture_loc = gen_input_exact(txt, skg, eroded_seg[:, 0, :, :] * 100,
                                           xcenter, ycenter, patch_size, 1)
    return inp, texture_loc
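
# --- Illustrative usage (not in the original source) --------------------------
# A minimal sketch of how get_input is meant to be called, assuming `valLoader`
# is a DataLoader over the repo's five-tuple (img, skg, seg, eroded_seg, txt)
# dataset. Negative center coordinates select random patch placement; explicit
# coordinates pin the texture patch at that center.
def _demo_get_input(valLoader, patch_size=50):
    val_batch = next(iter(valLoader))  # (img, skg, seg, eroded_seg, txt)
    inp_rand, loc_rand = get_input(val_batch, -1, -1, patch_size, 1)     # random patch
    inp_fixed, loc_fixed = get_input(val_batch, 100, 100, patch_size, 1)  # centered at (100, 100)
    return inp_rand, inp_fixed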
# Inference snippet: load a pretrained generator and run one validation batch.
netG = texturegan.TextureGAN(5, 3, 32)
load_network(netG, model_location)
netG.eval()

warnings.filterwarnings('ignore')

# Start validation.
color_space = 'lab'
img, skg, seg, eroded_seg, txt = data
print(txt.size())

img = custom_transforms.normalize_lab(img)
skg = custom_transforms.normalize_lab(skg)
txt = custom_transforms.normalize_lab(txt)
seg = custom_transforms.normalize_seg(seg)
eroded_seg = custom_transforms.normalize_seg(eroded_seg)

inp, texture_loc = get_input(data, 100, 100, 50, 1)
seg = seg != 0

model = netG
inpv = get_inputv(inp.cuda())
output = model(inpv.cuda())

out_img = vis_image(custom_transforms.denormalize_lab(output.data.double().cpu()),
                    color_space)
inp_img = vis_patch(custom_transforms.denormalize_lab(txt.cpu()),
                    custom_transforms.denormalize_lab(skg.cpu()),
                    texture_loc, color_space)
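
# --- Illustrative display helper (not in the original source) -----------------
# vis_image/vis_patch return lists of CHW arrays; transposing to HWC lets
# matplotlib show them, mirroring the commented plotting code that originally
# accompanied get_prediction below.
def _show_result(out_img, inp_img):
    import matplotlib.pyplot as plt
    plt.figure(); plt.imshow(np.transpose(inp_img[0], (1, 2, 0))); plt.axis('off')
    plt.figure(); plt.imshow(np.transpose(out_img[0], (1, 2, 0))); plt.axis('off')
    plt.show()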
def visualize_training(netG, val_loader, input_stack, target_img, target_texture,
                       segment, vis, loss_graph, args):
    imgs = []
    for ii, data in enumerate(val_loader, 0):
        img, skg, seg, eroded_seg, txt = data  # LAB with negative values

        # With probability 0.5, use the ground-truth image itself as the texture.
        if random.random() < 0.5:
            txt = img

        # Inputs are in LAB ranges (L: 0-100, ab: -128-128); normalize them.
        img = custom_transforms.normalize_lab(img)
        skg = custom_transforms.normalize_lab(skg)
        txt = custom_transforms.normalize_lab(txt)
        seg = custom_transforms.normalize_seg(seg)
        eroded_seg = custom_transforms.normalize_seg(eroded_seg)

        bs, w, h = seg.size()

        seg = seg.view(bs, 1, w, h)
        seg = torch.cat((seg, seg, seg), 1)
        eroded_seg = eroded_seg.view(bs, 1, w, h)
        eroded_seg = torch.cat((eroded_seg, eroded_seg, eroded_seg), 1)

        # Outside the segmentation mask, set the L channel to white and zero ab.
        temp = torch.ones(seg.size()) * (1 - seg).float()
        temp[:, 1, :, :] = 0
        temp[:, 2, :, :] = 0
        txt = txt.float() * seg.float() + temp

        patchsize = args.local_texture_size
        batch_size = bs

        if not args.use_segmentation_patch:
            seg.fill_(1)
            eroded_seg.fill_(1)

        if args.input_texture_patch == 'original_image':
            inp, texture_loc = gen_input_rand(img, skg, eroded_seg[:, 0, :, :] * 100,
                                              args.patch_size_min, args.patch_size_max,
                                              args.num_input_texture_patch)
        elif args.input_texture_patch == 'dtd_texture':
            inp, texture_loc = gen_input_rand(txt, skg, eroded_seg[:, 0, :, :] * 100,
                                              args.patch_size_min, args.patch_size_max,
                                              args.num_input_texture_patch)

        img = img.cuda()
        skg = skg.cuda()
        seg = seg.cuda()
        eroded_seg = eroded_seg.cuda()
        txt = txt.cuda()
        inp = inp.cuda()

        input_stack.resize_as_(inp.float()).copy_(inp)
        target_img.resize_as_(img.float()).copy_(img)
        segment.resize_as_(seg.float()).copy_(seg)
        target_texture.resize_as_(txt.float()).copy_(txt)

        inputv = Variable(input_stack)
        gtimgv = Variable(target_img)
        segv = Variable(segment)
        txtv = Variable(target_texture)

        outputG = netG(inputv)
        outputl, outputa, outputb = torch.chunk(outputG, 3, dim=1)
        gtl, gta, gtb = torch.chunk(gtimgv, 3, dim=1)
        txtl, txta, txtb = torch.chunk(txtv, 3, dim=1)
        gtab = torch.cat((gta, gtb), 1)
        txtab = torch.cat((txta, txtb), 1)

        if args.color_space == 'lab':
            outputlll = torch.cat((outputl, outputl, outputl), 1)
            gtlll = torch.cat((gtl, gtl, gtl), 1)
            txtlll = torch.cat((txtl, txtl, txtl), 1)
        elif args.color_space == 'rgb':
            outputlll = outputG
            gtlll = gtimgv
            txtlll = txtv

        if args.loss_texture == 'original_image':
            targetl = gtl
            targetab = gtab
            targetlll = gtlll
        else:
            targetl = txtl
            targetab = txtab
            targetlll = txtlll

        texture_patch = gen_local_patch(patchsize, batch_size, eroded_seg, seg, outputlll)
        gt_texture_patch = gen_local_patch(patchsize, batch_size, eroded_seg, seg, targetlll)

        if args.color_space == 'lab':
            out_img = vis_image(custom_transforms.denormalize_lab(outputG.data.double().cpu()),
                                args.color_space)

            # Keep only the L channel of the local patches for display.
            temp_labout = custom_transforms.denormalize_lab(texture_patch.data.double().cpu())
            temp_labout[:, 1:3, :, :] = 0
            temp_labgt = custom_transforms.denormalize_lab(gt_texture_patch.data.double().cpu())
            temp_labgt[:, 1:3, :, :] = 0
            temp_out = vis_image(temp_labout, args.color_space)
            temp_gt = vis_image(temp_labgt, args.color_space)

            if args.input_texture_patch == 'original_image':
                inp_img = vis_patch(custom_transforms.denormalize_lab(img.cpu()),
                                    custom_transforms.denormalize_lab(skg.cpu()),
                                    texture_loc, args.color_space)
            elif args.input_texture_patch == 'dtd_texture':
                inp_img = vis_patch(custom_transforms.denormalize_lab(txt.cpu()),
                                    custom_transforms.denormalize_lab(skg.cpu()),
                                    texture_loc, args.color_space)
            tar_img = vis_image(custom_transforms.denormalize_lab(img.cpu()), args.color_space)
            skg_img = vis_image(custom_transforms.denormalize_lab(skg.cpu()), args.color_space)
            txt_img = vis_image(custom_transforms.denormalize_lab(txt.cpu()), args.color_space)
        elif args.color_space == 'rgb':
            out_img = vis_image(custom_transforms.denormalize_rgb(outputG.data.double().cpu()),
                                args.color_space)
            inp_img = vis_patch(custom_transforms.denormalize_rgb(img.cpu()),
                                custom_transforms.denormalize_rgb(skg.cpu()),
                                texture_loc, args.color_space)
            tar_img = vis_image(custom_transforms.denormalize_rgb(img.cpu()), args.color_space)

        out_final = [x * 0 for x in txt_img]
        gt_final = [x * 0 for x in txt_img]

        out_img = [x * 255 for x in out_img]
        skg_img = [x * 255 for x in skg_img]
        out_patch = [x * 255 for x in temp_out]
        gt_patch = [x * 255 for x in temp_gt]

        # Paste the local patches into the top-left corner of blank canvases.
        for t_i in range(bs):
            patchsize = int(args.local_texture_size)
            out_final[t_i][:, 0:patchsize, 0:patchsize] = out_patch[t_i][:, :, :]
            gt_final[t_i][:, 0:patchsize, 0:patchsize] = gt_patch[t_i][:, :, :]

        txt_img = [x * 255 for x in txt_img]
        inp_img = [x * 255 for x in inp_img]
        tar_img = [x * 255 for x in tar_img]
        segment_img = [x * 255 for x in eroded_seg.cpu().numpy()]

        for i_ in range(len(out_img)):
            imgs.append(skg_img[i_])
            imgs.append(txt_img[i_])
            imgs.append(inp_img[i_])
            imgs.append(out_img[i_])
            imgs.append(segment_img[i_])
            imgs.append(tar_img[i_])
            imgs.append(out_final[i_])
            imgs.append(gt_final[i_])

    vis.images(imgs, win='output', opts=dict(title='Output images'))
    vis.line(np.array(loss_graph["gs"]), win='gs', opts=dict(title='G-Style Loss'))
    vis.line(np.array(loss_graph["g"]), win='g', opts=dict(title='G Total Loss'))
    vis.line(np.array(loss_graph["gd"]), win='gd', opts=dict(title='G-Discriminator Loss'))
    vis.line(np.array(loss_graph["gf"]), win='gf', opts=dict(title='G-Feature Loss'))
    vis.line(np.array(loss_graph["gpl"]), win='gpl', opts=dict(title='G-Pixel Loss-L'))
    vis.line(np.array(loss_graph["gpab"]), win='gpab', opts=dict(title='G-Pixel Loss-AB'))
    vis.line(np.array(loss_graph["d"]), win='d', opts=dict(title='D Loss'))
    if args.local_texture_size != -1:
        vis.line(np.array(loss_graph["dl"]), win='dl', opts=dict(title='D Local Loss'))
        vis.line(np.array(loss_graph["gdl"]), win='gdl', opts=dict(title='G D Local Loss'))
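
# --- Illustrative setup (not in the original source) ---------------------------
# visualize_training expects `vis` to be a visdom client and `loss_graph` to be
# a dict of per-iteration loss lists keyed as used above. A minimal sketch,
# assuming a visdom server is running (python -m visdom.server):
def _make_vis_state():
    import visdom
    vis = visdom.Visdom()
    loss_graph = {k: [] for k in
                  ("g", "gd", "gf", "gs", "gpl", "gpab", "d", "dl", "gdl")}
    return vis, loss_graph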
def train(model, train_loader, val_loader, input_stack, target_img, target_texture,
          segment, label, label_local, extract_content, extract_style, loss_graph,
          vis, epoch, args):
    netG = model["netG"]
    netD = model["netD"]
    netD_local = model["netD_local"]
    criterion_gan = model["criterion_gan"]
    criterion_pixel_l = model["criterion_pixel_l"]
    criterion_pixel_ab = model["criterion_pixel_ab"]
    criterion_feat = model["criterion_feat"]
    criterion_style = model["criterion_style"]
    criterion_texturegan = model["criterion_texturegan"]
    real_label = model["real_label"]
    fake_label = model["fake_label"]
    optimizerD = model["optimizerD"]
    optimizerD_local = model["optimizerD_local"]
    optimizerG = model["optimizerG"]

    for i, data in enumerate(train_loader):
        print("Epoch: {0} Iteration: {1}".format(epoch, i))

        # detach() creates a new Variable with the reference to the previous
        # node cut off, so it should not affect the original graph. Still,
        # update G first so that detaching during the D update cannot
        # interfere with it.

        ############################
        # (1) Update G network: maximize log(D(G(z)))
        ############################
        netG.zero_grad()

        img, skg, seg, eroded_seg, txt = data  # LAB with negative values

        if random.random() < 0.5:
            txt = img

        # img/skg/seg arrive as RGB in [0, 1] or LAB in L: [0, 100], ab: [-128, 128].
        if args.color_space == 'lab':
            img = custom_transforms.normalize_lab(img)
            skg = custom_transforms.normalize_lab(skg)
            txt = custom_transforms.normalize_lab(txt)
            seg = custom_transforms.normalize_seg(seg)
            eroded_seg = custom_transforms.normalize_seg(eroded_seg)
        elif args.color_space == 'rgb':
            img = custom_transforms.normalize_rgb(img)
            skg = custom_transforms.normalize_rgb(skg)
            txt = custom_transforms.normalize_rgb(txt)

        if not args.use_segmentation_patch:
            seg.fill_(1)

        bs, w, h = seg.size()

        seg = seg.view(bs, 1, w, h)
        seg = torch.cat((seg, seg, seg), 1)
        eroded_seg = eroded_seg.view(bs, 1, w, h)

        # Outside the segmentation mask, set the L channel to white and zero ab.
        temp = torch.ones(seg.size()) * (1 - seg).float()
        temp[:, 1, :, :] = 0
        temp[:, 2, :, :] = 0
        txt = txt.float() * seg.float() + temp

        if args.input_texture_patch == 'original_image':
            inp, _ = gen_input_rand(img, skg, eroded_seg[:, 0, :, :],
                                    args.patch_size_min, args.patch_size_max,
                                    args.num_input_texture_patch)
        elif args.input_texture_patch == 'dtd_texture':
            inp, _ = gen_input_rand(txt, skg, eroded_seg[:, 0, :, :],
                                    args.patch_size_min, args.patch_size_max,
                                    args.num_input_texture_patch)

        batch_size, _, _, _ = img.size()

        img = img.cuda()
        skg = skg.cuda()
        seg = seg.cuda()
        eroded_seg = eroded_seg.cuda()
        txt = txt.cuda()
        inp = inp.cuda()

        input_stack.resize_as_(inp.float()).copy_(inp)
        target_img.resize_as_(img.float()).copy_(img)
        segment.resize_as_(seg.float()).copy_(seg)
        target_texture.resize_as_(txt.float()).copy_(txt)

        # Reverse the batch order to pair each image with a different texture.
        inv_idx = torch.arange(target_texture.size(0) - 1, -1, -1).long().cuda()
        target_texture_inv = target_texture.index_select(0, inv_idx)

        assert torch.max(seg) <= 1
        assert torch.max(eroded_seg) <= 1

        inputv = Variable(input_stack)
        gtimgv = Variable(target_img)
        segv = Variable(segment)
        txtv = Variable(target_texture)
        txtv_inv = Variable(target_texture_inv)

        outputG = netG(inputv)
        outputl, outputa, outputb = torch.chunk(outputG, 3, dim=1)
        gtl, gta, gtb = torch.chunk(gtimgv, 3, dim=1)
        txtl, txta, txtb = torch.chunk(txtv, 3, dim=1)
        txtl_inv, txta_inv, txtb_inv = torch.chunk(txtv_inv, 3, dim=1)

        outputab = torch.cat((outputa, outputb), 1)
        gtab = torch.cat((gta, gtb), 1)
        txtab = torch.cat((txta, txtb), 1)

        if args.color_space == 'lab':
            outputlll = torch.cat((outputl, outputl, outputl), 1)
            gtlll = torch.cat((gtl, gtl, gtl), 1)
            txtlll = torch.cat((txtl, txtl, txtl), 1)
        elif args.color_space == 'rgb':
            outputlll = outputG
            gtlll = gtimgv
            txtlll = txtv

        if args.loss_texture == 'original_image':
            targetl = gtl
            targetab = gtab
            targetlll = gtlll
        else:
            targetl = txtl
            targetab = txtab
            targetlll = txtlll

        ################## Global Pixel ab Loss ############################
        err_pixel_ab = args.pixel_weight_ab * criterion_pixel_ab(outputab, targetab)

        ################## Global Feature Loss ############################
        out_feat = extract_content(renormalize(outputlll))[0]
        gt_feat = extract_content(renormalize(gtlll))[0]
        err_feat = args.feature_weight * criterion_feat(out_feat, gt_feat.detach())

        ################## Global D Adversarial Loss ############################
        netD.zero_grad()
        if args.color_space == 'lab':
            outputD = netD(outputl)
        elif args.color_space == 'rgb':
            outputD = netD(outputG)
        label.resize_(outputD.data.size())
        labelv = Variable(label.fill_(real_label))
        err_gan = args.discriminator_weight * criterion_gan(outputD, labelv)

        ################## Global Pixel L Loss ############################
        err_pixel_l = args.global_pixel_weight_l * criterion_pixel_l(outputl, targetl)

        if args.local_texture_size == -1:  # global style loss, no local patches
            ################## Global Style Loss ############################
            output_style_feat = extract_style(outputlll)
            target_style_feat = extract_style(targetlll)

            gram = GramMatrix()
            err_style = 0
            for m in range(len(output_style_feat)):
                gram_y = gram(output_style_feat[m])
                gram_s = gram(target_style_feat[m])
                err_style += args.style_weight * criterion_style(gram_y, gram_s.detach())
            err_texturegan = 0
        else:  # local patch losses
            err_style = 0
            err_texturegan = 0  # accumulated across patches
            patchsize = args.local_texture_size
            netD_local.zero_grad()

            for p in range(args.num_local_texture_patch):
                texture_patch = gen_local_patch(patchsize, batch_size, eroded_seg, seg, outputlll)
                gt_texture_patch = gen_local_patch(patchsize, batch_size, eroded_seg, seg, targetlll)
                texture_patchl = gen_local_patch(patchsize, batch_size, eroded_seg, seg, outputl)
                gt_texture_patchl = gen_local_patch(patchsize, batch_size, eroded_seg, seg, targetl)

                ################## Local Style Loss ############################
                output_style_feat = extract_style(texture_patch)
                target_style_feat = extract_style(gt_texture_patch)

                gram = GramMatrix()
                for m in range(len(output_style_feat)):
                    gram_y = gram(output_style_feat[m])
                    gram_s = gram(target_style_feat[m])
                    err_style += args.style_weight * criterion_style(gram_y, gram_s.detach())

                ################## Local Pixel L Loss ############################
                err_pixel_l += args.local_pixel_weight_l * criterion_pixel_l(texture_patchl,
                                                                             gt_texture_patchl)

                ################## Local D Loss ############################
                outputD_local = netD_local(torch.cat((texture_patchl, gt_texture_patchl), 1))
                label_local.resize_(outputD_local.data.size())
                labelv_local = Variable(label_local.fill_(real_label))
                err_texturegan += args.discriminator_local_weight * criterion_texturegan(
                    outputD_local, labelv_local)
            loss_graph["gdl"].append(err_texturegan.data[0])

        err_G = err_pixel_l + err_pixel_ab + err_gan + err_feat + err_style + err_texturegan
        err_G.backward(retain_variables=True)
        optimizerG.step()

        loss_graph["g"].append(err_G.data[0])
        loss_graph["gpl"].append(err_pixel_l.data[0])
        loss_graph["gpab"].append(err_pixel_ab.data[0])
        loss_graph["gd"].append(err_gan.data[0])
        loss_graph["gf"].append(err_feat.data[0])
        loss_graph["gs"].append(err_style.data[0])
        print('G:', err_G.data[0])

        ############################
        # (2) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
        ############################
        # Train with real.
        netD.zero_grad()
        if args.color_space == 'lab':
            outputD = netD(gtl)
        elif args.color_space == 'rgb':
            outputD = netD(gtimgv)
        label.resize_(outputD.data.size())
        labelv = Variable(label.fill_(real_label))

        errD_real = criterion_gan(outputD, labelv)
        errD_real.backward()

        # Discriminator accuracy on real samples.
        score = Variable(torch.ones(batch_size))
        _, cd, wd, hd = outputD.size()
        D_output_size = cd * wd * hd
        clamped_output_D = torch.round(outputD.clamp(0, 1))
        for acc_i in range(batch_size):
            score[acc_i] = torch.sum(clamped_output_D[acc_i]) / D_output_size
        real_acc = torch.mean(score)

        # Train with fake.
        if args.color_space == 'lab':
            outputD = netD(outputl.detach())
        elif args.color_space == 'rgb':
            outputD = netD(outputG.detach())
        label.resize_(outputD.data.size())
        labelv = Variable(label.fill_(fake_label))

        errD_fake = criterion_gan(outputD, labelv)
        errD_fake.backward()

        # Discriminator accuracy on fake samples.
        score = Variable(torch.ones(batch_size))
        _, cd, wd, hd = outputD.size()
        D_output_size = cd * wd * hd
        clamped_output_D = torch.round(outputD.clamp(0, 1))
        for acc_i in range(batch_size):
            score[acc_i] = torch.sum(clamped_output_D[acc_i]) / D_output_size
        fake_acc = torch.mean(1 - score)

        D_acc = (real_acc + fake_acc) / 2

        # Only step D while its accuracy is below the threshold, so it does
        # not overpower the generator.
        if D_acc.data[0] < args.threshold_D_max:
            errD = errD_real + errD_fake
            loss_graph["d"].append(errD.data[0])
            optimizerD.step()
        else:
            loss_graph["d"].append(0)

        print('D:', 'real_acc', "%.2f" % real_acc.data[0],
              'fake_acc', "%.2f" % fake_acc.data[0], 'D_acc', D_acc.data[0])

        ############################
        # (3) Update local D network: maximize log(D(x)) + log(1 - D(G(z)))
        ############################
        if args.local_texture_size != -1:
            patchsize = args.local_texture_size

            # Train with a pair of real patches drawn at random locations.
            x1 = int(rand_between(patchsize, args.image_size - patchsize))
            y1 = int(rand_between(patchsize, args.image_size - patchsize))
            x2 = int(rand_between(patchsize, args.image_size - patchsize))
            y2 = int(rand_between(patchsize, args.image_size - patchsize))

            netD_local.zero_grad()
            if args.color_space == 'lab':
                outputD_local = netD_local(torch.cat(
                    (targetl[:, :, x1:(x1 + patchsize), y1:(y1 + patchsize)],
                     targetl[:, :, x2:(x2 + patchsize), y2:(y2 + patchsize)]), 1))
            elif args.color_space == 'rgb':
                outputD = netD(gtimgv)  # the local D is only wired up for lab
            label.resize_(outputD_local.data.size())
            labelv = Variable(label.fill_(real_label))

            errD_real_local = criterion_texturegan(outputD_local, labelv)
            errD_real_local.backward(retain_variables=True)

            score = Variable(torch.ones(batch_size))
            _, cd, wd, hd = outputD_local.size()
            D_output_size = cd * wd * hd
            clamped_output_D = torch.round(outputD_local.clamp(0, 1))
            for acc_i in range(batch_size):
                score[acc_i] = torch.sum(clamped_output_D[acc_i]) / D_output_size
            realreal_acc = torch.mean(score)

            # Train with a generated/ground-truth patch pair.
            x1 = int(rand_between(patchsize, args.image_size - patchsize))
            y1 = int(rand_between(patchsize, args.image_size - patchsize))
            x2 = int(rand_between(patchsize, args.image_size - patchsize))
            y2 = int(rand_between(patchsize, args.image_size - patchsize))

            if args.color_space == 'lab':
                # Reuses the last patches drawn during the G update above.
                outputD_local = netD_local(torch.cat((texture_patchl, gt_texture_patchl), 1))
            elif args.color_space == 'rgb':
                outputD = netD(outputG.detach())  # likewise lab-only
            label.resize_(outputD_local.data.size())
            labelv = Variable(label.fill_(fake_label))

            errD_fake_local = criterion_gan(outputD_local, labelv)
            errD_fake_local.backward()

            score = Variable(torch.ones(batch_size))
            _, cd, wd, hd = outputD_local.size()
            D_output_size = cd * wd * hd
            clamped_output_D = torch.round(outputD_local.clamp(0, 1))
            for acc_i in range(batch_size):
                score[acc_i] = torch.sum(clamped_output_D[acc_i]) / D_output_size
            fakefake_acc = torch.mean(1 - score)

            D_acc = (realreal_acc + fakefake_acc) / 2
            if D_acc.data[0] < args.threshold_D_max:
                errD_local = errD_real_local + errD_fake_local
                loss_graph["dl"].append(errD_local.data[0])
                optimizerD_local.step()
            else:
                loss_graph["dl"].append(0)

            print('D local:', 'real real_acc', "%.2f" % realreal_acc.data[0],
                  'fake fake_acc', "%.2f" % fakefake_acc.data[0],
                  'D_acc', D_acc.data[0])

        if i % args.save_every == 0:
            save_network(netG, 'G', epoch, i, args)
            save_network(netD, 'D', epoch, i, args)
            save_network(netD_local, 'D_local', epoch, i, args)

        if i % args.visualize_every == 0:
            visualize_training(netG, val_loader, input_stack, target_img,
                               target_texture, segment, vis, loss_graph, args)
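
# --- Illustrative helper definitions (not in the original source) --------------
# train() relies on GramMatrix and rand_between from the surrounding codebase.
# If your checkout lacks them, these are minimal sketches of the standard
# formulations they appear to implement (an assumption, not the repo's exact code):
class GramMatrix(nn.Module):
    """Batched Gram matrix of feature maps, normalized by layer size."""
    def forward(self, x):
        b, c, h, w = x.size()
        feat = x.view(b, c, h * w)
        gram = torch.bmm(feat, feat.transpose(1, 2))  # (b, c, c)
        return gram.div(c * h * w)


def rand_between(a, b):
    """Uniform random float in [a, b)."""
    return a + random.random() * (b - a)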
def get_prediction(image_bytes, txt_bytes, pos_x, pos_y):
    inp_img = None  # fallback return value if inference fails early
    try:
        transform = get_transforms(args)

        # The sketch input reuses the uploaded image; the segmentation masks
        # fall back to a blank (all-foreground) image.
        img_path = image_bytes
        skg_path = image_bytes
        seg_path = "./dataset/test/blank.jpg"
        eroded_seg_path = seg_path
        txt_path = txt_bytes

        img = pil_loader1(img_path)
        skg = pil_loader1(skg_path)
        seg = pil_loader(seg_path)
        txt = pil_loader1(txt_path)
        eroded_seg = pil_loader(eroded_seg_path)

        img, skg, seg, eroded_seg, txt = transform([img, skg, seg, eroded_seg, txt])
        img = img.unsqueeze(0)
        skg = skg.unsqueeze(0)
        txt = txt.unsqueeze(0)
        seg = seg.unsqueeze(0)
        eroded_seg = eroded_seg.unsqueeze(0)
        data = [img, skg, seg, eroded_seg, txt]

        # Load the pretrained generator.
        model_location = './pretrained_models/G_net_texturegan_18_300.pth'
        netG = texturegan.TextureGAN(5, 3, 32)
        load_network(netG, model_location)
        netG.eval()
        warnings.filterwarnings('ignore')

        # Run validation-style inference.
        color_space = 'lab'
        img, skg, seg, eroded_seg, txt = data

        img = custom_transforms.normalize_lab(img)
        skg = custom_transforms.normalize_lab(skg)
        txt = custom_transforms.normalize_lab(txt)
        seg = custom_transforms.normalize_seg(seg)
        eroded_seg = custom_transforms.normalize_seg(eroded_seg)

        inp, texture_loc = get_input(data, pos_x, pos_y, 50, 1)
        seg = seg != 0

        model = netG
        inpv = get_inputv(inp.cuda())
        output = model(inpv.cuda())

        out_img = vis_image(custom_transforms.denormalize_lab(output.data.double().cpu()),
                            color_space)
        inp_img = vis_patch(custom_transforms.denormalize_lab(txt.cpu()),
                            custom_transforms.denormalize_lab(skg.cpu()),
                            texture_loc, color_space)
        tar_img = vis_image(custom_transforms.denormalize_lab(img.cpu()), color_space)

        inp_img = np.transpose(inp_img[0], (1, 2, 0))
        out_img = np.transpose(out_img[0], (1, 2, 0))
        return out_img
    except Exception as e:
        print("error:", e)
        return inp_img
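
# --- Illustrative usage (not in the original source) ---------------------------
# get_prediction accepts anything pil_loader1/pil_loader can open (a path or a
# file-like object); (pos_x, pos_y) centers the 50x50 texture patch. Note that
# it also depends on the module-level `args` and a CUDA device being available.
if __name__ == '__main__':
    result = get_prediction("./dataset/test/img.jpg", "./dataset/test/txt.jpg", 100, 100)
    # `result` is an HWC array suitable for plt.imshow / plt.imsave.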