def load_model():
    comp = CompletionNet(norm=nn.BatchNorm2d, nf=64)
    comp = nn.DataParallel(comp).cuda()
    comp.load_state_dict(
        torch.load(os.path.join(assets_file_dir, "unfiller_rgb.pth")))
    model = comp.module
    model.eval()
    return model
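# Minimal usage sketch for load_model() (assumes CompletionNet, nn, torch, os,
# and assets_file_dir are in scope as above; this code predates torch.no_grad(),
# hence the volatile Variables used elsewhere in the renderer). The (1, 3, H, W)
# image and (1, 2, H, W) mask shapes mirror self.imgv / self.maskv below:
#
#   model = load_model()
#   rgb = Variable(torch.zeros(1, 3, 256, 256), volatile=True).cuda()
#   mask = Variable(torch.zeros(1, 2, 256, 256), volatile=True).cuda()
#   out = model(rgb, mask)  # filled RGB prediction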
def __init__(self,
             port,
             imgs,
             depths,
             target,
             target_poses,
             scale_up,
             semantics=None,
             gui=True,
             use_filler=True,
             gpu_idx=0,
             window_width=256,
             window_height=None,
             env=None):
    self.env = env
    self.roll, self.pitch, self.yaw = 0, 0, 0
    self.quat = [1, 0, 0, 0]
    self.x, self.y, self.z = 0, 0, 0
    self.fps = 0
    self.mousex, self.mousey = 0.5, 0.5
    self.org_pitch, self.org_yaw, self.org_roll = 0, 0, 0
    self.org_x, self.org_y, self.org_z = 0, 0, 0
    self.clickstart = (0, 0)
    self.mousedown = False
    self.overlay = False
    self.show_depth = False
    self.port = port

    self._context_phys = zmq.Context()
    self._context_mist = zmq.Context()
    self._context_dept = zmq.Context()  # Channel for smoothed depth
    self._context_norm = zmq.Context()  # Channel for surface normals
    self._context_semt = zmq.Context()

    # configs.View.SEMANTICS in configs.ViewComponent.getComponents()
    self._require_semantics = 'semantics' in self.env.config["output"]
    # configs.View.NORMAL in configs.ViewComponent.getComponents()
    self._require_normal = 'normal' in self.env.config["output"]

    self.socket_mist = self._context_mist.socket(zmq.REQ)
    self.socket_mist.connect("tcp://localhost:{}".format(self.port - 1))
    # self.socket_dept = self._context_dept.socket(zmq.REQ)
    # self.socket_dept.connect("tcp://localhost:{}".format(5555 - 1))
    if self._require_normal:
        self.socket_norm = self._context_norm.socket(zmq.REQ)
        self.socket_norm.connect("tcp://localhost:{}".format(self.port - 2))
    if self._require_semantics:
        self.socket_semt = self._context_semt.socket(zmq.REQ)
        self.socket_semt.connect("tcp://localhost:{}".format(self.port - 3))

    self.target_poses = target_poses
    self.pose_locations = np.array(
        [tp[:3, -1] for tp in self.target_poses])
    self.relative_poses = [
        np.dot(np.linalg.inv(tg), self.target_poses[0])
        for tg in target_poses
    ]

    self.imgs = imgs
    self.depths = depths
    self.target = target
    self.semantics = semantics
    self.model = None
    self.old_topk = set([])
    self.k = 5
    self.use_filler = use_filler

    self.window_width = window_width
    if window_height is None:
        self.window_height = window_width
    else:
        self.window_height = window_height
    self.capture_count = 0

    self.show = np.zeros((self.window_width, self.window_height, 3),
                         dtype='uint8')
    self.show_rgb = np.zeros((self.window_width, self.window_height, 3),
                             dtype='uint8')
    self.show_semantics = np.zeros(
        (self.window_width, self.window_height, 3), dtype='uint8')
    self.show_prefilled = np.zeros(
        (self.window_width, self.window_height, 3), dtype='uint8')
    self.surface_normal = np.zeros(
        (self.window_width, self.window_height, 3), dtype='uint8')

    self.semtimg_count = 0

    if "fast_lq_render" in self.env.config and self.env.config["fast_lq_render"]:
        comp = CompletionNet(norm=nn.BatchNorm2d, nf=24, skip_first_bn=True)
    else:
        comp = CompletionNet(norm=nn.BatchNorm2d, nf=64)
    comp = torch.nn.DataParallel(comp).cuda()

    # Bucket the configured resolution up to the nearest trained model size.
    if self.env.config["resolution"] <= 64:
        res = 64
    elif self.env.config["resolution"] <= 128:
        res = 128
    elif self.env.config["resolution"] <= 256:
        res = 256
    else:
        res = 512
    if "fast_lq_render" in self.env.config and self.env.config["fast_lq_render"]:
        comp.load_state_dict(
            torch.load(
                os.path.join(assets_file_dir,
                             "model_small_{}.pth".format(res))))
    else:
        comp.load_state_dict(
            torch.load(
                os.path.join(assets_file_dir,
                             "model_{}.pth".format(res))))
    self.model = comp.module
    self.model.eval()

    if not self.env.config["use_filler"]:
        self.model = None

    self.imgs_topk = None
    self.depths_topk = None
    self.relative_poses_topk = None
    self.old_topk = None

    self.imgv = Variable(torch.zeros(1, 3, self.window_height,
                                     self.window_width),
                         volatile=True).cuda()
    self.maskv = Variable(torch.zeros(1, 2, self.window_height,
                                      self.window_width),
                          volatile=True).cuda()
    self.mean = torch.from_numpy(
        np.array([0.57441127, 0.54226291, 0.50356019]).astype(np.float32))
    self.mean = self.mean.view(3, 1, 1).repeat(1, self.window_height,
                                               self.window_width)

    if gui and not self.env.config["display_ui"]:
        self.renderToScreenSetup()
for k, v in uuids:
    # print(k, v)
    data = d[v]
    source = data[0][0]
    target = data[1]
    target_depth = data[3]
    source_depth = data[2][0]
    pose = data[-1][0].numpy()
    targets.append(target)
    poses.append(pose)
    # Note: the target view and its depth are appended as the render sources
    # here, so `source` and `source_depth` are currently unused.
    sources.append(target)
    source_depths.append(target_depth)

model = None
if opt.model != '':
    comp = CompletionNet()
    comp = torch.nn.DataParallel(comp).cuda()
    comp.load_state_dict(torch.load(opt.model))
    model = comp.module
    model.eval()
print(model)
print('target', poses, poses[0])
# print('no.1 pose', poses, poses[1])
# print(source_depth)
print(sources[0].shape, source_depths[0].shape)

show_target(target)

renderer = PCRenderer(5556, sources, source_depths, target, rts)
# renderer.renderToScreen(sources, source_depths, poses, models, target, target_depth, rts)
renderer.renderOffScreenSetup()
def __init__(self, port, imgs, depths, target, target_poses, scale_up,
             semantics=None, gui=True, use_filler=True, gpu_count=0,
             windowsz=256, env=None):
    self.env = env
    self.roll, self.pitch, self.yaw = 0, 0, 0
    self.quat = [1, 0, 0, 0]
    self.x, self.y, self.z = 0, 0, 0
    self.fps = 0
    self.mousex, self.mousey = 0.5, 0.5
    self.org_pitch, self.org_yaw, self.org_roll = 0, 0, 0
    self.org_x, self.org_y, self.org_z = 0, 0, 0
    self.clickstart = (0, 0)
    self.mousedown = False
    self.overlay = False
    self.show_depth = False

    self._context_phys = zmq.Context()
    self._context_mist = zmq.Context()
    self._context_dept = zmq.Context()  # Channel for smoothed depth
    self._context_norm = zmq.Context()  # Channel for surface normals
    self._context_semt = zmq.Context()

    # configs.View.SEMANTICS in configs.ViewComponent.getComponents()
    self._require_semantics = 'semantics' in self.env.config["output"]
    # configs.View.NORMAL in configs.ViewComponent.getComponents()
    self._require_normal = 'normal' in self.env.config["output"]

    self.socket_mist = self._context_mist.socket(zmq.REQ)
    self.socket_mist.connect("tcp://localhost:{}".format(5555 + gpu_count))
    self.socket_dept = self._context_dept.socket(zmq.REQ)
    self.socket_dept.connect("tcp://localhost:{}".format(5555 - 1))
    if self._require_normal:
        self.socket_norm = self._context_norm.socket(zmq.REQ)
        self.socket_norm.connect("tcp://localhost:{}".format(5555 - 2))
    if self._require_semantics:
        self.socket_semt = self._context_semt.socket(zmq.REQ)
        self.socket_semt.connect("tcp://localhost:{}".format(5555 - 3))

    self.target_poses = target_poses
    self.imgs = imgs
    self.depths = depths
    self.target = target
    self.semantics = semantics
    self.model = None
    self.old_topk = set([])
    self.k = 5
    self.use_filler = use_filler

    self.showsz = windowsz
    self.capture_count = 0

    self.show = np.zeros((self.showsz, self.showsz, 3), dtype='uint8')
    self.show_rgb = np.zeros((self.showsz, self.showsz, 3), dtype='uint8')
    self.show_semantics = np.zeros((self.showsz, self.showsz, 3),
                                   dtype='uint8')
    # self.show_unfilled = None
    # if configs.MAKE_VIDEO or configs.HIST_MATCHING:
    self.show_unfilled = np.zeros((self.showsz, self.showsz, 3),
                                  dtype='uint8')
    self.surface_normal = np.zeros((self.showsz, self.showsz, 3),
                                   dtype='uint8')

    self.semtimg_count = 0

    # if configs.USE_SMALL_FILLER:
    #     comp = CompletionNet(norm=nn.BatchNorm2d, nf=24)
    #     comp = torch.nn.DataParallel(comp).cuda()
    #     comp.load_state_dict(torch.load(os.path.join(assets_file_dir, "model.pth")))
    # else:
    comp = CompletionNet(norm=nn.BatchNorm2d, nf=64)
    comp = torch.nn.DataParallel(comp).cuda()
    comp.load_state_dict(
        torch.load(
            os.path.join(
                assets_file_dir,
                "model_{}.pth".format(self.env.config["resolution"]))))
    self.model = comp.module
    self.model.eval()

    if not self.env.config["use_filler"]:
        self.model = None

    self.imgv = Variable(torch.zeros(1, 3, self.showsz, self.showsz),
                         volatile=True).cuda()
    self.maskv = Variable(torch.zeros(1, 2, self.showsz, self.showsz),
                          volatile=True).cuda()
    self.mean = torch.from_numpy(
        np.array([0.57441127, 0.54226291, 0.50356019]).astype(np.float32))

    if gui and not self.env.config["display_ui"]:  # configs.DISPLAY_UI
        self.renderToScreenSetup()
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataroot', required=True, help='path to dataset')
    parser.add_argument('--debug', action='store_true', help='debug mode')
    parser.add_argument('--imgsize', type=int, default=256, help='image size')
    parser.add_argument('--batchsize', type=int, default=20, help='batch size')
    parser.add_argument('--workers', type=int, default=9,
                        help='number of data loading workers')
    parser.add_argument('--nepoch', type=int, default=50,
                        help='number of epochs')
    parser.add_argument('--lr', type=float, default=2e-5,
                        help='learning rate, default=2e-5')
    parser.add_argument('--beta1', type=float, default=0.5,
                        help='beta1 for adam, default=0.5')
    parser.add_argument('--outf', type=str, default="filler_pano_pc_full",
                        help='output folder')
    parser.add_argument('--model', type=str, default="", help='model path')
    parser.add_argument('--cepoch', type=int, default=0, help='current epoch')
    parser.add_argument('--loss', type=str, default="perceptual",
                        help='loss type: train_init | l1 | perceptual | color_stable | color_correction')
    parser.add_argument('--init', type=str, default="iden", help='init method')
    parser.add_argument('--l1', type=float, default=0,
                        help='weight of the additional l1 term')
    parser.add_argument('--color_coeff', type=float, default=0,
                        help='weight of the color match loss')
    parser.add_argument('--unfiller', action='store_true',
                        help='also train an unfiller network')
    parser.add_argument('--joint', action='store_true',
                        help='train filler and unfiller jointly')
    parser.add_argument('--use_depth', action='store_true', default=False,
                        help='use the depth channel')
    parser.add_argument('--zoom', type=int, default=1, help='zoom factor')
    parser.add_argument('--patchsize', type=int, default=256,
                        help='patch size')

    mean = torch.from_numpy(
        np.array([0.57441127, 0.54226291, 0.50356019]).astype(
            np.float32)).clone()

    opt = parser.parse_args()
    print(opt)

    writer = SummaryWriter(opt.outf + '/runs/' +
                           datetime.now().strftime('%B%d %H:%M:%S'))

    try:
        os.makedirs(opt.outf)
    except OSError:
        pass

    zoom = opt.zoom
    patchsize = opt.patchsize

    tf = transforms.Compose([
        transforms.ToTensor(),
    ])
    mist_tf = transforms.Compose([
        transforms.ToTensor(),
    ])
    d = PairDataset(root=opt.dataroot, transform=tf, mist_transform=mist_tf)
    d_test = PairDataset(root=opt.dataroot, transform=tf,
                         mist_transform=mist_tf, train=False)

    cudnn.benchmark = True

    dataloader = torch.utils.data.DataLoader(d,
                                             batch_size=opt.batchsize,
                                             shuffle=True,
                                             num_workers=int(opt.workers),
                                             drop_last=True,
                                             pin_memory=False)
    dataloader_test = torch.utils.data.DataLoader(d_test,
                                                  batch_size=opt.batchsize,
                                                  shuffle=True,
                                                  num_workers=int(opt.workers),
                                                  drop_last=True,
                                                  pin_memory=False)

    img = Variable(torch.zeros(opt.batchsize, 3, 1024, 2048)).cuda()
    maskv = Variable(torch.zeros(opt.batchsize, 2, 1024, 2048)).cuda()
    img_original = Variable(torch.zeros(opt.batchsize, 3, 1024, 2048)).cuda()
    label = Variable(torch.LongTensor(opt.batchsize * 4)).cuda()

    comp = CompletionNet(norm=nn.BatchNorm2d, nf=64)
    current_epoch = opt.cepoch
    comp = torch.nn.DataParallel(comp).cuda()
    if opt.init == 'iden':
        comp.apply(identity_init)
    else:
        comp.apply(weights_init)

    if opt.model != '':
        comp.load_state_dict(torch.load(opt.model))
        # dis.load_state_dict(torch.load(opt.model.replace("G", "D")))
        current_epoch = opt.cepoch

    if opt.unfiller:
        comp2 = CompletionNet(norm=nn.BatchNorm2d, nf=64)
        comp2 = torch.nn.DataParallel(comp2).cuda()
        if opt.model != '':
            comp2.load_state_dict(torch.load(opt.model.replace('G', 'G2')))
        optimizerG2 = torch.optim.Adam(comp2.parameters(), lr=opt.lr,
                                       betas=(opt.beta1, 0.999))

    l2 = nn.MSELoss()
    # if opt.loss == 'train_init':
    #     params = list(comp.parameters())
    #     sel = np.random.choice(len(params), len(params)/2, replace=False)
    #     params_sel = [params[i] for i in sel]
    #     optimizerG = torch.optim.Adam(params_sel, lr=opt.lr, betas=(opt.beta1, 0.999))
    # else:
    optimizerG = torch.optim.Adam(comp.parameters(), lr=opt.lr,
                                  betas=(opt.beta1, 0.999))

    # step to start D training and G training, slightly different from the paper
    curriculum = (200000, 300000)
    alpha = 0.004

    errG_data = 0
    errD_data = 0

    vgg16 = models.vgg16(pretrained=False)
    vgg16.load_state_dict(torch.load('vgg16-397923af.pth'))
    feat = vgg16.features
    p = torch.nn.DataParallel(Perceptual(feat)).cuda()
    for param in p.parameters():
        param.requires_grad = False

    imgnet_mean = torch.from_numpy(
        np.array([0.485, 0.456, 0.406]).astype(np.float32)).clone()
    imgnet_std = torch.from_numpy(
        np.array([0.229, 0.224, 0.225]).astype(np.float32)).clone()
    imgnet_mean_img = Variable(
        imgnet_mean.view(1, 3, 1, 1).repeat(opt.batchsize * 4, 1, patchsize,
                                            patchsize)).cuda()
    imgnet_std_img = Variable(
        imgnet_std.view(1, 3, 1, 1).repeat(opt.batchsize * 4, 1, patchsize,
                                           patchsize)).cuda()

    test_loader_enum = enumerate(dataloader_test)
    for epoch in range(current_epoch, opt.nepoch):
        for i, data in enumerate(dataloader, 0):
            optimizerG.zero_grad()
            source = data[0]
            source_depth = data[1]
            target = data[2]
            step = i + epoch * len(dataloader)

            # Fill the holes (zero pixels) with the dataset mean color.
            mask = (torch.sum(source[:, :3, :, :], 1) > 0).float().unsqueeze(1)
            # img_mean = torch.sum(torch.sum(source[:,:3,:,:], 2), 2) / torch.sum(torch.sum(mask, 2), 2).view(opt.batchsize, 1)
            source[:, :3, :, :] += (1 - mask.repeat(1, 3, 1, 1)) * mean.view(
                1, 3, 1, 1).repeat(opt.batchsize, 1, 1024, 2048)

            source_depth = source_depth[:, :, :, 0].unsqueeze(1)
            source_depth = torch.cat([source_depth, mask], 1)

            img.data.copy_(source)
            maskv.data.copy_(source_depth)
            img_original.data.copy_(target)
            imgc, maskvc, img_originalc = crop(img, maskv, img_original,
                                               zoom, patchsize)

            recon = comp(imgc, maskvc)

            if opt.loss == "train_init":
                loss = l2(recon, imgc[:, :3, :, :])
            elif opt.loss == 'l1':
                loss = l2(recon, img_originalc)
            elif opt.loss == 'perceptual':
                loss = l2(p(recon), p(img_originalc).detach()) + \
                       opt.l1 * l2(recon, img_originalc)
            elif opt.loss == 'color_stable':
                loss = l2(
                    p(recon.view(recon.size(0) * 3, 1, patchsize,
                                 patchsize).repeat(1, 3, 1, 1)),
                    p(img_originalc.view(img_originalc.size(0) * 3, 1,
                                         patchsize, patchsize).repeat(
                                             1, 3, 1, 1)).detach())
            elif opt.loss == 'color_correction':
                recon_percept = p((recon - imgnet_mean_img) / imgnet_std_img)
                org_percept = p((img_originalc - imgnet_mean_img) /
                                imgnet_std_img).detach()
                loss = l2(recon_percept, org_percept)
                for scale in [32]:
                    img_originalc_patch = img_originalc.view(
                        opt.batchsize * 4, 3, patchsize // scale, scale,
                        patchsize // scale, scale).transpose(
                            4, 3).contiguous().view(opt.batchsize * 4, 3,
                                                    patchsize // scale,
                                                    patchsize // scale, -1)
                    recon_patch = recon.view(
                        opt.batchsize * 4, 3, patchsize // scale, scale,
                        patchsize // scale, scale).transpose(
                            4, 3).contiguous().view(opt.batchsize * 4, 3,
                                                    patchsize // scale,
                                                    patchsize // scale, -1)
                    img_originalc_patch_mean = img_originalc_patch.mean(dim=-1)
                    recon_patch_mean = recon_patch.mean(dim=-1)
                    # recon_patch_cov = []
                    # img_originalc_patch_cov = []
                    # for j in range(3):
                    #     recon_patch_cov.append((recon_patch * recon_patch[:, j:j+1].repeat(1, 3, 1, 1, 1)).mean(dim=-1))
                    #     img_originalc_patch_cov.append((img_originalc_patch * img_originalc_patch[:, j:j+1].repeat(1, 3, 1, 1, 1)).mean(dim=-1))
                    # recon_patch_cov_cat = torch.cat(recon_patch_cov, 1)
                    # img_originalc_patch_cov_cat = torch.cat(img_originalc_patch_cov, 1)
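                    # How the reshape above computes per-patch channel means
                    # (sketch; B = opt.batchsize * 4, P = patchsize, s = scale):
                    #   x.view(B, 3, P//s, s, P//s, s)   splits H and W into P//s blocks of s
                    #    .transpose(4, 3)                brings the two block indices together
                    #    .contiguous().view(B, 3, P//s, P//s, -1)  flattens each s*s patch
                    #    .mean(dim=-1)                   mean color per patch -> (B, 3, P//s, P//s)
                    # Matching these patch means keeps the fill's average color
                    # locally consistent with the ground truth.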
                    color_loss = l2(recon_patch_mean, img_originalc_patch_mean)
                    # + l2(recon_patch_cov_cat, img_originalc_patch_cov_cat.detach())
                    loss += opt.color_coeff * color_loss
                    print("color loss %f" % color_loss.data[0])

            loss.backward(retain_graph=True)

            if opt.unfiller:
                optimizerG2.zero_grad()
                recon2 = comp2(img_originalc, maskvc)
                if not opt.joint:
                    recon2_percept = p((recon2 - imgnet_mean_img) / imgnet_std_img)
                    recon_percept = p((recon - imgnet_mean_img) / imgnet_std_img)
                    loss2 = l2(recon2_percept, recon_percept.detach())
                else:
                    recon_percept = p((recon - imgnet_mean_img) / imgnet_std_img)
                    z = Variable(torch.zeros(recon_percept.size()).cuda())
                    recon2_percept = p((recon2 - imgnet_mean_img) / imgnet_std_img)
                    loss2 = l2(recon2_percept - recon_percept, z)
                    loss2 += 0.2 * l2(recon2_percept, org_percept)

                for scale in [32]:
                    # Note: integer division (//) is required here under Python 3;
                    # the original used /, which yields float sizes for view().
                    img_originalc_patch = recon.detach().view(
                        opt.batchsize * 4, 3, patchsize // scale, scale,
                        patchsize // scale, scale).transpose(
                            4, 3).contiguous().view(opt.batchsize * 4, 3,
                                                    patchsize // scale,
                                                    patchsize // scale, -1)
                    recon2_patch = recon2.view(
                        opt.batchsize * 4, 3, patchsize // scale, scale,
                        patchsize // scale, scale).transpose(
                            4, 3).contiguous().view(opt.batchsize * 4, 3,
                                                    patchsize // scale,
                                                    patchsize // scale, -1)
                    img_originalc_patch_mean = img_originalc_patch.mean(dim=-1)
                    recon2_patch_mean = recon2_patch.mean(dim=-1)
                    recon2_patch_cov = []
                    img_originalc_patch_cov = []
                    for j in range(3):
                        recon2_patch_cov.append(
                            (recon2_patch *
                             recon2_patch[:, j:j + 1].repeat(1, 3, 1, 1, 1)).mean(dim=-1))
                        img_originalc_patch_cov.append(
                            (img_originalc_patch *
                             img_originalc_patch[:, j:j + 1].repeat(1, 3, 1, 1, 1)).mean(dim=-1))
                    recon2_patch_cov_cat = torch.cat(recon2_patch_cov, 1)
                    img_originalc_patch_cov_cat = torch.cat(img_originalc_patch_cov, 1)
                    z = Variable(torch.zeros(img_originalc_patch_mean.size()).cuda())
                    if opt.joint:
                        color_loss = l2(recon2_patch_mean - img_originalc_patch_mean, z)
                    else:
                        color_loss = l2(recon2_patch_mean, img_originalc_patch_mean)
                    loss2 += opt.color_coeff * color_loss
                    print("color loss %f" % color_loss.data[0])

                loss2 = loss2 * 0.3
                loss2.backward(retain_graph=True)
                print("loss2 %f" % loss2.data[0])
                optimizerG2.step()
                if i % 10 == 0:
                    writer.add_scalar('MSEloss2', loss2.data[0], step)

            if opt.loss == "train_init":
                # Zero the gradients of the outer rows of large conv kernels.
                for param in comp.parameters():
                    if len(param.size()) == 4:
                        nk = param.size()[2] // 2
                        if nk > 5:
                            param.grad[:nk, :, :, :] = 0

            optimizerG.step()

            print('[%d/%d][%d/%d] %d MSEloss: %f G_loss %f D_loss %f' %
                  (epoch, opt.nepoch, i, len(dataloader), step,
                   loss.data[0], errG_data, errD_data))

            if i % 200 == 0:
                test_i, test_data = next(test_loader_enum)
                if test_i > len(dataloader_test) - 5:
                    test_loader_enum = enumerate(dataloader_test)
                source = test_data[0]
                source_depth = test_data[1]
                target = test_data[2]
                mask = (torch.sum(source[:, :3, :, :], 1) > 0).float().unsqueeze(1)
                source[:, :3, :, :] += (1 - mask.repeat(1, 3, 1, 1)) * mean.view(
                    1, 3, 1, 1).repeat(opt.batchsize, 1, 1024, 2048)
                source_depth = source_depth[:, :, :, 0].unsqueeze(1)
                source_depth = torch.cat([source_depth, mask], 1)
                img.data.copy_(source)
                maskv.data.copy_(source_depth)
                img_original.data.copy_(target)
                imgc, maskvc, img_originalc = crop(img, maskv, img_original,
                                                   zoom, patchsize)
                comp.eval()
                recon = comp(imgc, maskvc)
                comp.train()

                if opt.unfiller:
                    comp2.eval()
                    # maskvc.data.fill_(0)
                    recon2 = comp2(img_originalc, maskvc)
                    comp2.train()
                    visual = torch.cat([imgc.data[:, :3, :, :], recon.data,
                                        recon2.data, img_originalc.data], 3)
                else:
                    visual = torch.cat([imgc.data[:, :3, :, :], recon.data,
                                        img_originalc.data], 3)

                visual = vutils.make_grid(visual, normalize=True)
                writer.add_image('image', visual, step)
                vutils.save_image(visual,
                                  '%s/compare%d_%d.png' % (opt.outf, epoch, i),
                                  nrow=1)

            if i % 10 == 0:
                writer.add_scalar('MSEloss', loss.data[0], step)
                writer.add_scalar('G_loss', errG_data, step)
                writer.add_scalar('D_loss', errD_data, step)

            if i % 2000 == 0:
                torch.save(comp.state_dict(),
                           '%s/compG_epoch%d_%d.pth' % (opt.outf, epoch, i))
                if opt.unfiller:
                    torch.save(comp2.state_dict(),
                               '%s/compG2_epoch%d_%d.pth' % (opt.outf, epoch, i))
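# Assumed entry point; the excerpt defines main() but never invokes it.
if __name__ == '__main__':
    main()

# Example invocation (script name and paths are placeholders):
#   python train_filler.py --dataroot /path/to/dataset --outf filler_pano_pc_full \
#       --loss perceptual --batchsize 20 --nepoch 50 --color_coeff 0.5 --unfiller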