def __init__(self, sigma=0.5, temperature=0.5, gradclip=1, npts=10, option='incremental', size=128,
             path_to_check='checkpoint_fansoft/fan_109.pth', args=None):
    """Build FAN, the GeoDistill bottleneck, the generator and the loss modules on CUDA.

    Args:
        sigma: ``sigma`` forwarded to ``GeoDistill``.
        temperature: ``temperature`` forwarded to ``GeoDistill``.
        gradclip: stored as ``self.gradclip``.
        npts: number of landmarks (FAN ``n_points`` and generator ``c_dim``).
        option: ``'scratch'`` skips loading ``path_to_check``; ``'incremental'``
            additionally applies ``convertLayer`` to every FAN submodule.
        size: input resolution; the bottleneck gets ``out_res=int(size / 4)``.
        path_to_check: FAN checkpoint whose shape-compatible entries are loaded
            unless ``option == 'scratch'``.
        args: optional namespace. If it carries a truthy ``resume_folder``,
            FAN/GEN weights are restored from
            ``{resume_folder}/model_{resume_epoch}.fan.pth`` / ``.gen.pth``.
    """
    self.npoints = npts
    self.gradclip = gradclip
    # - define FAN
    self.FAN = FAN(1, n_points=self.npoints)
    if option != 'scratch':
        net_dict = self.FAN.state_dict()
        pretrained_dict = torch.load(path_to_check, map_location='cuda')
        # Keep only entries this FAN has with an identical shape, so a checkpoint
        # with a different head size can still seed the shared layers.
        pretrained_dict = {k: v for k, v in pretrained_dict.items()
                           if k in net_dict and v.shape == net_dict[k].shape}
        net_dict.update(pretrained_dict)
        self.FAN.load_state_dict(net_dict, strict=True)
    if option == 'incremental':
        print('Option is incremental')
        self.FAN.apply(convertLayer)
    # - define Bottleneck
    self.BOT = GeoDistill(sigma=sigma, temperature=temperature, out_res=int(size / 4))
    # - define GEN
    self.GEN = Generator(conv_dim=32, c_dim=self.npoints)
    # - optionally resume FAN/GEN weights (before the DataParallel wrap below).
    # ``args`` defaults to None, so its attributes must not be read unguarded.
    if getattr(args, 'resume_folder', None):
        path_fan = '{}/model_{}.fan.pth'.format(args.resume_folder, args.resume_epoch)
        path_gen = '{}/model_{}.gen.pth'.format(args.resume_folder, args.resume_epoch)
        self._resume(path_fan, path_gen)
    # - multiple GPUs
    multi_gpu = torch.cuda.device_count() > 1
    if multi_gpu:
        self.FAN = torch.nn.DataParallel(self.FAN)
        self.BOT = torch.nn.DataParallel(self.BOT)
        self.GEN = torch.nn.DataParallel(self.GEN)
    self.FAN.to('cuda').train()
    self.BOT.to('cuda').train()
    self.GEN.to('cuda').train()
    # - VGG for perceptual loss (kept in eval mode; it only provides features)
    vgg = vgg16(pretrained=True)
    self.loss_network = LossNetwork(torch.nn.DataParallel(vgg)) if multi_gpu else LossNetwork(vgg)
    self.loss_network.eval()
    self.loss_network.to('cuda')
    self.loss = dict.fromkeys(['all_loss', 'rec', 'perceptual'])
    self.A = None
    # - define losses for reconstruction
    self.SelfLoss = torch.nn.MSELoss().to('cuda')
    self.PerceptualLoss = torch.nn.MSELoss().to('cuda')
    self.heatmap = HeatMap(32, 0.5).cuda()
def extractdata(folder, epoch, path_to_core, db, tight, npoints, data_path, sigma):
    """Run a saved FAN/GEN checkpoint over a dataset and cache the predictions.

    The networks are loaded from ``{folder}/model_{epoch}.fan.pth`` and
    ``{folder}/model_{epoch}.gen.pth``. The ``(Ptr, Gtr)`` pair returned by
    ``getdata`` is stored in ``data_{folder}.pkl`` under
    ``data[db][str(epoch)] = {'Ptr': Ptr, 'Gtr': Gtr}``. Any existing content
    of that pickle is preserved and merged.

    Args:
        folder: directory holding the checkpoints; also used in the output file name.
        epoch: checkpoint epoch to load.
        path_to_core: forwarded to ``loadnet``.
        db: dataset name passed to ``SuperDB`` and used as the top-level pickle key.
        tight: forwarded to ``SuperDB``; falsy values fall back to 64.
        npoints: number of landmarks.
        data_path: dataset root passed to ``SuperDB``.
        sigma: ``sigma`` for the ``GeoDistill`` bottleneck.
    """
    # output pickle
    outname = 'data_{}.pkl'.format(folder)
    # create model
    path_to_fan = '{}/model_{}.fan.pth'.format(folder, epoch)
    path_to_gen = '{}/model_{}.gen.pth'.format(folder, epoch)
    net_f, net_g = loadnet(npoints, path_to_fan, path_to_gen, path_to_core)
    BOT = GeoDistill(sigma=sigma, temperature=0.1).to('cuda')
    # create data
    database = SuperDB(path=data_path, size=128, flip=False, angle=0.0, tight=tight or 64,
                       db=db, affine=True, npts=npoints)
    num_workers = 12
    dbloader = DataLoader(database, batch_size=30, shuffle=False,
                          num_workers=num_workers, pin_memory=False)
    # extract data
    print('Extracting data from {:s}, with {:d} imgs'.format(db, len(database)))
    Ptr, Gtr = getdata(dbloader, BOT, net_f, net_g)
    # dump data, merging with results already cached for other dbs/epochs.
    # NOTE: pickle.load can execute arbitrary code; only read trusted files here.
    data = {}
    if os.path.exists(outname):
        with open(outname, 'rb') as f:
            data = pickle.load(f)
    data.setdefault(db, {})[str(epoch)] = {'Ptr': Ptr, 'Gtr': Gtr}
    # Context manager guarantees the pickle is flushed and closed.
    with open(outname, 'wb') as f:
        pickle.dump(data, f)
def evalaffine(facenet, db, npts=10):
    """Measure how consistently ``facenet`` localises landmarks under known warps.

    For each sample, landmarks are predicted on ``sample['Im']`` and on
    ``sample['ImP']``. The ``Im`` predictions, in homogeneous form
    ``[x, y, 1]``, are multiplied by ``sample['M']`` transposed. The result is
    compared with the ``ImP`` predictions using the per-landmark Euclidean
    distance. Both sets of points are scaled by 4 first; this is presumably
    the bottleneck-to-image resolution ratio (TODO confirm).

    Args:
        facenet: landmark network whose output feeds ``GeoDistill``.
        db: dataset yielding dicts with ``'Im'``, ``'ImP'`` and ``'M'``.
        npts: number of landmarks predicted per image.

    Returns:
        ``numpy`` array of shape ``(len(db), npts)`` with per-landmark errors.
    """
    errors = np.zeros((len(db), npts))
    trainloader = DataLoader(db, batch_size=30, shuffle=False, num_workers=12, pin_memory=False)
    BOT = GeoDistill(sigma=0.5, temperature=0.1).to('cuda')
    i = 0
    # Evaluation only: skip building autograd graphs.
    with torch.no_grad():
        for sample in trainloader:
            a, b, c = sample['Im'], sample['ImP'], sample['M']
            _, preda = BOT(facenet(a.cuda()))
            _, predb = BOT(facenet(b.cuda()))
            for m in range(preda.shape[0]):
                # Append a column of ones -> homogeneous coordinates, then apply M^T.
                homog = torch.cat((4 * preda[m].cpu(), torch.ones(npts, 1)), dim=1)
                warped = homog @ c[m].permute(1, 0)
                diff = warped.numpy() - 4 * predb[m].cpu().numpy()
                errors[i, :] = np.sqrt(np.sum(diff ** 2, axis=-1))
                i += 1
    return errors
class model():
    """Landmark model trained by reconstructing ``Im`` from ``ImP`` plus FAN heatmaps.

    Components: ``FAN`` (heatmap network), ``BOT`` (``GeoDistill`` bottleneck
    returning heatmaps and points), ``GEN`` (``Generator``), and a VGG16-based
    ``LossNetwork`` for the perceptual loss. Everything is moved to CUDA.
    """

    def __init__(self, sigma=0.5, temperature=0.5, gradclip=1, npts=10, option='incremental',
                 size=128, path_to_check='checkpoint_fansoft/fan_109.pth'):
        """Build all sub-networks and losses.

        Args:
            sigma: ``sigma`` forwarded to ``GeoDistill``.
            temperature: ``temperature`` forwarded to ``GeoDistill``.
            gradclip: truthy enables gradient-norm clipping in ``forward``.
            npts: number of landmarks (FAN ``n_points`` and generator ``c_dim``).
            option: ``'scratch'`` skips loading ``path_to_check``; ``'incremental'``
                additionally applies ``convertLayer`` to every FAN submodule.
            size: input resolution; the bottleneck gets ``out_res=int(size / 4)``.
            path_to_check: FAN checkpoint whose shape-compatible entries are loaded
                unless ``option == 'scratch'``.
        """
        self.npoints = npts
        self.gradclip = gradclip
        # - define FAN
        self.FAN = FAN(1, n_points=self.npoints)
        if option != 'scratch':
            net_dict = self.FAN.state_dict()
            pretrained_dict = torch.load(path_to_check, map_location='cuda')
            # Keep only entries this FAN has with an identical shape.
            pretrained_dict = {k: v for k, v in pretrained_dict.items()
                               if k in net_dict and v.shape == net_dict[k].shape}
            net_dict.update(pretrained_dict)
            self.FAN.load_state_dict(net_dict, strict=True)
        if option == 'incremental':
            print('Option is incremental')
            self.FAN.apply(convertLayer)
        # - define Bottleneck
        self.BOT = GeoDistill(sigma=sigma, temperature=temperature, out_res=int(size / 4))
        # - define GEN
        self.GEN = Generator(conv_dim=32, c_dim=self.npoints)
        # - multiple GPUs
        multi_gpu = torch.cuda.device_count() > 1
        if multi_gpu:
            self.FAN = torch.nn.DataParallel(self.FAN)
            self.BOT = torch.nn.DataParallel(self.BOT)
            self.GEN = torch.nn.DataParallel(self.GEN)
        self.FAN.to('cuda').train()
        self.BOT.to('cuda').train()
        self.GEN.to('cuda').train()
        # - VGG for perceptual loss (eval mode; used only as a feature extractor)
        vgg = vgg16(pretrained=True)
        self.loss_network = LossNetwork(torch.nn.DataParallel(vgg)) if multi_gpu else LossNetwork(vgg)
        self.loss_network.eval()
        self.loss_network.to('cuda')
        self.loss = dict.fromkeys(['rec'])
        self.A = None
        # - define losses for reconstruction
        self.SelfLoss = torch.nn.MSELoss().to('cuda')
        self.PerceptualLoss = torch.nn.MSELoss().to('cuda')

    def _perceptual_loss(self, fake_im, real_im):
        """Return the summed MSE over the first four ``loss_network`` outputs.

        Target features are detached, so gradients flow only through ``fake_im``.
        """
        vgg_fake = self.loss_network(fake_im)
        vgg_target = self.loss_network(real_im)
        perceptualLoss = 0
        for vgg_idx in range(4):
            perceptualLoss += self.PerceptualLoss(vgg_fake[vgg_idx], vgg_target[vgg_idx].detach())
        return perceptualLoss

    def _resume(self, path_fan, path_gen):
        """Restore FAN and GEN weights from the given checkpoint files."""
        self._load_weights(self.FAN, path_fan)
        self._load_weights(self.GEN, path_gen)

    @staticmethod
    def _load_weights(net, path):
        """Load the state dict at ``path`` into ``net``, whether or not either side is DataParallel.

        ``_save`` stores ``net.state_dict()`` verbatim. Checkpoints written while
        ``net`` was wrapped in ``torch.nn.DataParallel`` therefore prefix every
        key with ``'module.'``. This loader accepts both layouts.
        """
        from collections import OrderedDict

        state = torch.load(path)
        prefix = 'module.'
        saved_wrapped = len(state) > 0 and all(k.startswith(prefix) for k in state)
        if isinstance(net, torch.nn.DataParallel):
            # Prefixed keys match the wrapper; plain keys match the inner module.
            (net if saved_wrapped else net.module).load_state_dict(state)
            return
        if saved_wrapped:
            stripped = OrderedDict((k[len(prefix):], v) for k, v in state.items())
            metadata = getattr(state, '_metadata', None)
            if metadata is not None:
                # Re-key the per-module metadata too: drop the wrapper's own ''
                # entry, and 'module' becomes the new root ''.
                new_meta = OrderedDict()
                for k, v in metadata.items():
                    if k == 'module':
                        new_meta[''] = v
                    elif k.startswith(prefix):
                        new_meta[k[len(prefix):]] = v
                stripped._metadata = new_meta
            state = stripped
        net.load_state_dict(state)

    def _save(self, path_to_models, epoch):
        """Write FAN/GEN state dicts to ``{path_to_models}{epoch}.fan.pth`` / ``.gen.pth``."""
        torch.save(self.FAN.state_dict(), path_to_models + str(epoch) + '.fan.pth')
        torch.save(self.GEN.state_dict(), path_to_models + str(epoch) + '.gen.pth')

    def _set_batch(self, data):
        """Store every tensor in ``data`` on CUDA as ``self.A``; non-tensor entries are dropped.

        Inputs are not marked ``requires_grad``. The moved copy would be a
        non-leaf, so its gradient would never be retained anyway. Requesting
        one only adds a backward pass into the images, and it fails for
        non-float tensors.
        """
        self.A = {k: v.to('cuda') for k, v in data.items() if torch.is_tensor(v)}

    def forward(self):
        """Run one reconstruction step and backpropagate (optimizers are stepped elsewhere).

        Heatmaps/points are detected on ``A['Im']``. GEN renders from
        ``A['ImP']`` conditioned on those heatmaps. The loss is perceptual
        plus pixel MSE against ``A['Im']``.

        Returns:
            dict with ``'Heatmap'``, ``'Reconstructed'`` and ``'Points'``.
        """
        self.GEN.zero_grad()
        self.FAN.zero_grad()
        H = self.FAN(self.A['Im'])
        H, Pts = self.BOT(H)
        # Maps [-1, 1] to [0, 1]; assumes GEN outputs in [-1, 1] (TODO confirm).
        X = 0.5 * (self.GEN(self.A['ImP'], H) + 1)
        self.loss['rec'] = self._perceptual_loss(X, self.A['Im']) + self.SelfLoss(X, self.A['Im'])
        self.loss['rec'].backward()
        if self.gradclip:
            # NOTE: max_norm is fixed at 1 regardless of the value of gradclip.
            torch.nn.utils.clip_grad_norm_(self.FAN.parameters(), 1, norm_type=2)
            torch.nn.utils.clip_grad_norm_(self.GEN.parameters(), 1, norm_type=2)
        return {'Heatmap': H, 'Reconstructed': X, 'Points': Pts}
class model():
    """Landmark model trained by cross-reconstruction between ``Im``, ``ImP`` and a reordered batch.

    Components: ``FAN`` (heatmap network), ``BOT`` (``GeoDistill`` bottleneck
    returning heatmaps and points), ``GEN`` (``Generator``), and a VGG16-based
    ``LossNetwork`` for the perceptual loss. Everything is moved to CUDA.
    """

    def __init__(self, sigma=0.5, temperature=0.5, gradclip=1, npts=10, option='incremental',
                 size=128, path_to_check='checkpoint_fansoft/fan_109.pth', args=None):
        """Build all sub-networks and losses, optionally resuming from a checkpoint.

        Args:
            sigma: ``sigma`` forwarded to ``GeoDistill``.
            temperature: ``temperature`` forwarded to ``GeoDistill``.
            gradclip: truthy enables gradient-norm clipping in ``forward``.
            npts: number of landmarks (FAN ``n_points`` and generator ``c_dim``).
            option: ``'scratch'`` skips loading ``path_to_check``; ``'incremental'``
                additionally applies ``convertLayer`` to every FAN submodule.
            size: input resolution; the bottleneck gets ``out_res=int(size / 4)``.
            path_to_check: FAN checkpoint whose shape-compatible entries are loaded
                unless ``option == 'scratch'``.
            args: optional namespace. If it carries a truthy ``resume_folder``,
                FAN/GEN weights are restored from
                ``{resume_folder}/model_{resume_epoch}.fan.pth`` / ``.gen.pth``.
        """
        self.npoints = npts
        self.gradclip = gradclip
        # - define FAN
        self.FAN = FAN(1, n_points=self.npoints)
        if option != 'scratch':
            net_dict = self.FAN.state_dict()
            pretrained_dict = torch.load(path_to_check, map_location='cuda')
            # Keep only entries this FAN has with an identical shape.
            pretrained_dict = {k: v for k, v in pretrained_dict.items()
                               if k in net_dict and v.shape == net_dict[k].shape}
            net_dict.update(pretrained_dict)
            self.FAN.load_state_dict(net_dict, strict=True)
        if option == 'incremental':
            print('Option is incremental')
            self.FAN.apply(convertLayer)
        # - define Bottleneck
        self.BOT = GeoDistill(sigma=sigma, temperature=temperature, out_res=int(size / 4))
        # - define GEN
        self.GEN = Generator(conv_dim=32, c_dim=self.npoints)
        # - optionally resume FAN/GEN weights (before the DataParallel wrap below).
        # ``args`` defaults to None, so its attributes must not be read unguarded.
        if getattr(args, 'resume_folder', None):
            path_fan = '{}/model_{}.fan.pth'.format(args.resume_folder, args.resume_epoch)
            path_gen = '{}/model_{}.gen.pth'.format(args.resume_folder, args.resume_epoch)
            self._resume(path_fan, path_gen)
        # - multiple GPUs
        multi_gpu = torch.cuda.device_count() > 1
        if multi_gpu:
            self.FAN = torch.nn.DataParallel(self.FAN)
            self.BOT = torch.nn.DataParallel(self.BOT)
            self.GEN = torch.nn.DataParallel(self.GEN)
        self.FAN.to('cuda').train()
        self.BOT.to('cuda').train()
        self.GEN.to('cuda').train()
        # - VGG for perceptual loss (eval mode; used only as a feature extractor)
        vgg = vgg16(pretrained=True)
        self.loss_network = LossNetwork(torch.nn.DataParallel(vgg)) if multi_gpu else LossNetwork(vgg)
        self.loss_network.eval()
        self.loss_network.to('cuda')
        self.loss = dict.fromkeys(['all_loss', 'rec', 'perceptual'])
        self.A = None
        # - define losses for reconstruction
        self.SelfLoss = torch.nn.MSELoss().to('cuda')
        self.PerceptualLoss = torch.nn.MSELoss().to('cuda')
        self.heatmap = HeatMap(32, 0.5).cuda()

    def _perceptual_loss(self, fake_im, real_im):
        """Return the summed MSE over the first four ``loss_network`` outputs.

        Target features are detached, so gradients flow only through ``fake_im``.
        """
        vgg_fake = self.loss_network(fake_im)
        vgg_target = self.loss_network(real_im)
        perceptualLoss = 0
        for vgg_idx in range(4):
            perceptualLoss += self.PerceptualLoss(vgg_fake[vgg_idx], vgg_target[vgg_idx].detach())
        return perceptualLoss

    def _resume(self, path_fan, path_gen):
        """Restore FAN and GEN weights from the given checkpoint files."""
        self._load_weights(self.FAN, path_fan)
        self._load_weights(self.GEN, path_gen)

    @staticmethod
    def _load_weights(net, path):
        """Load the state dict at ``path`` into ``net``, whether or not either side is DataParallel.

        ``_save`` stores ``net.state_dict()`` verbatim. Checkpoints written while
        ``net`` was wrapped in ``torch.nn.DataParallel`` therefore prefix every
        key with ``'module.'``. ``__init__`` calls ``_resume`` before wrapping,
        so this loader must accept both layouts.
        """
        from collections import OrderedDict

        state = torch.load(path)
        prefix = 'module.'
        saved_wrapped = len(state) > 0 and all(k.startswith(prefix) for k in state)
        if isinstance(net, torch.nn.DataParallel):
            # Prefixed keys match the wrapper; plain keys match the inner module.
            (net if saved_wrapped else net.module).load_state_dict(state)
            return
        if saved_wrapped:
            stripped = OrderedDict((k[len(prefix):], v) for k, v in state.items())
            metadata = getattr(state, '_metadata', None)
            if metadata is not None:
                # Re-key the per-module metadata too: drop the wrapper's own ''
                # entry, and 'module' becomes the new root ''.
                new_meta = OrderedDict()
                for k, v in metadata.items():
                    if k == 'module':
                        new_meta[''] = v
                    elif k.startswith(prefix):
                        new_meta[k[len(prefix):]] = v
                stripped._metadata = new_meta
            state = stripped
        net.load_state_dict(state)

    def _save(self, path_to_models, epoch):
        """Write FAN/GEN state dicts to ``{path_to_models}{epoch}.fan.pth`` / ``.gen.pth``."""
        torch.save(self.FAN.state_dict(), path_to_models + str(epoch) + '.fan.pth')
        torch.save(self.GEN.state_dict(), path_to_models + str(epoch) + '.gen.pth')

    def _set_batch(self, data):
        """Store every tensor in ``data`` on CUDA as ``self.A``; non-tensor entries are dropped."""
        self.A = {k: v.to('cuda') for k, v in data.items() if torch.is_tensor(v)}

    def forward(self, myoptimizers, order_idx):
        """Run one cross-reconstruction training step and backpropagate.

        ``X_3`` and ``X_P_3`` are rendered from ``A['Im']`` / ``A['ImP']`` using
        heatmaps of the reordered batch ``A['Im'][order_idx]``. They are then
        mapped back: ``X_P`` from ``X_3`` with the ``ImP`` heatmaps, and ``X``
        from ``X_P_3`` with the ``Im`` heatmaps. ``X`` and ``X_P`` are compared
        with ``A['Im']`` and ``A['ImP']`` through perceptual and pixel MSE.

        Args:
            myoptimizers: not used here; returned unchanged under ``'myoptimizers'``.
            order_idx: index applied to the ``A['Im']`` batch (presumably a
                permutation; confirm with the caller).

        Returns:
            dict with lists under ``'Heatmap'``, ``'Reconstructed'``, ``'Points'``
            plus the passed-through ``'myoptimizers'``.
        """
        self.GEN.zero_grad()
        self.FAN.zero_grad()
        # NOTE: FAN runs three separate times (order matters if FAN has
        # train-mode state such as BatchNorm running statistics).
        Im_3 = self.A['Im'][order_idx]
        H_3 = self.FAN(Im_3)  # [B, N, H, W] (N landmarks)
        H_3, Pts_3 = self.BOT(H_3)  # Change it to landmark locations
        H = self.FAN(self.A['Im'])  # [B, N, H, W] (N landmarks)
        H, Pts = self.BOT(H)  # Change it to landmark locations
        H_P = self.FAN(self.A['ImP'])  # [B, N, H, W] (N landmarks)
        H_P, Pts_P = self.BOT(H_P)  # Change it to landmark locations
        # 0.5 * (g + 1) maps [-1, 1] to [0, 1]; assumes GEN outputs in [-1, 1].
        X_3 = 0.5 * (self.GEN(self.A['Im'], H_3) + 1)
        X_P_3 = 0.5 * (self.GEN(self.A['ImP'], H_3) + 1)
        X_P = 0.5 * (self.GEN(X_3, H_P) + 1)  # [B, 3, H_ori, W_ori]
        X = 0.5 * (self.GEN(X_P_3, H) + 1)  # [B, 3, H_ori, W_ori]
        self.loss['perceptual'] = (self._perceptual_loss(X, self.A['Im'])
                                   + self._perceptual_loss(X_P, self.A['ImP']))
        self.loss['rec'] = self.SelfLoss(X, self.A['Im']) + self.SelfLoss(X_P, self.A['ImP'])
        self.loss['all_loss'] = self.loss['perceptual'] + self.loss['rec']
        self.loss['all_loss'].backward()
        if self.gradclip:
            # NOTE: max_norm is fixed at 1 regardless of the value of gradclip.
            torch.nn.utils.clip_grad_norm_(self.FAN.parameters(), 1, norm_type=2)
            torch.nn.utils.clip_grad_norm_(self.GEN.parameters(), 1, norm_type=2)
        return {'Heatmap': [H, H_P, H_3], 'Reconstructed': [X, X_P, X_3, X_P_3],
                'Points': [Pts, Pts_P, Pts_3], 'myoptimizers': myoptimizers}