def load_checkpoints(config_path, checkpoint_path, device="cuda"):
    """Build the generator and keypoint detector from a YAML config and
    restore their weights from a checkpoint.

    Returns the two modules wrapped in DataParallelWithCallback and set
    to eval mode.
    """
    with open(config_path) as f:
        config = yaml.load(f, Loader=yaml.FullLoader)

    model_params = config["model_params"]
    common = model_params["common_params"]

    generator = OcclusionAwareGenerator(
        **model_params["generator_params"], **common)
    generator.to(device)
    kp_detector = KPDetector(
        **model_params["kp_detector_params"], **common)
    kp_detector.to(device)

    # Restore weights directly onto the requested device.
    state = torch.load(checkpoint_path, map_location=device)
    generator.load_state_dict(state["generator"])
    kp_detector.load_state_dict(state["kp_detector"])

    # Wrap for sync-BN-aware multi-GPU inference.
    generator = DataParallelWithCallback(generator)
    kp_detector = DataParallelWithCallback(kp_detector)
    generator.eval()
    kp_detector.eval()
    return generator, kp_detector
def transfer(config, generator, kp_detector, checkpoint, log_dir, dataset):
    """Run pairwise motion transfer over a dataset and save the results.

    For each (driving, source) pair the driving video's motion is applied to
    the first frame of the source video. Per pair, a horizontal strip of the
    predicted frames is written as a .png (for evaluation) and a visualization
    video is written in the configured format.

    Raises:
        AttributeError: if `checkpoint` is None — transfer requires weights.
    """
    log_dir = os.path.join(log_dir, 'transfer')
    png_dir = os.path.join(log_dir, 'png')
    transfer_params = config['transfer_params']
    # Build the requested number of (driving, source) pairs from the dataset.
    dataset = PairedDataset(initial_dataset=dataset,
                            number_of_pairs=transfer_params['num_pairs'])
    dataloader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=1)
    if checkpoint is not None:
        Logger.load_cpk(checkpoint, generator=generator, kp_detector=kp_detector)
    else:
        raise AttributeError(
            "Checkpoint should be specified for mode='transfer'.")
    if not os.path.exists(log_dir):
        os.makedirs(log_dir)
    if not os.path.exists(png_dir):
        os.makedirs(png_dir)
    generator = DataParallelWithCallback(generator)
    kp_detector = DataParallelWithCallback(kp_detector)
    generator.eval()
    kp_detector.eval()
    for it, x in tqdm(enumerate(dataloader)):
        with torch.no_grad():
            # Move every tensor-like value in the batch onto the GPU.
            x = {
                key: value if not hasattr(value, 'cuda') else value.cuda()
                for key, value in x.items()
            }
            driving_video = x['driving_video']
            # Source is only the first frame of the source video (time dim kept).
            source_image = x['source_video'][:, :, :1, :, :]
            out = transfer_one(generator, kp_detector, source_image,
                               driving_video, transfer_params)
            img_name = "-".join([x['driving_name'][0], x['source_name'][0]])
            # Store to .png for evaluation: frames concatenated along width.
            out_video_batch = out['video_prediction'].data.cpu().numpy()
            out_video_batch = np.concatenate(np.transpose(
                out_video_batch, [0, 2, 3, 4, 1])[0], axis=1)
            imageio.imsave(os.path.join(png_dir, img_name + '.png'),
                           (255 * out_video_batch).astype(np.uint8))
            # Side-by-side visualization video (format comes from config).
            image = Visualizer(
                **config['visualizer_params']).visualize_transfer(
                    driving_video=driving_video, source_image=source_image,
                    out=out)
            imageio.mimsave(
                os.path.join(log_dir, img_name + transfer_params['format']),
                image)
def load_checkpoints(config_path, checkpoint_path, cpu=False):
    """Build the generator and keypoint detector from a YAML config and
    restore their weights from a checkpoint.

    Args:
        config_path: path to the model YAML config.
        checkpoint_path: path to the .pth checkpoint file.
        cpu: when True, keep the models (and checkpoint load) on CPU.

    Returns:
        (generator, kp_detector), wrapped in DataParallelWithCallback and
        set to eval mode.
    """
    with open(config_path) as f:
        # Fix: yaml.load() without an explicit Loader is deprecated/unsafe
        # (and a TypeError on PyYAML >= 6); FullLoader matches the other
        # loaders in this file.
        config = yaml.load(f, Loader=yaml.FullLoader)
    generator = OcclusionAwareGenerator(
        **config["model_params"]["generator_params"],
        **config["model_params"]["common_params"],
    )
    if cpu:
        generator.cpu()
    else:
        generator.cuda()
    kp_detector = KPDetector(
        **config["model_params"]["kp_detector_params"],
        **config["model_params"]["common_params"],
    )
    if cpu:
        kp_detector.cpu()
    else:
        kp_detector.cuda()
    # map_location=None keeps the checkpoint's original device placement.
    checkpoint = torch.load(checkpoint_path, map_location="cpu" if cpu else None)
    generator.load_state_dict(checkpoint["generator"])
    kp_detector.load_state_dict(checkpoint["kp_detector"])
    generator = DataParallelWithCallback(generator)
    kp_detector = DataParallelWithCallback(kp_detector)
    generator.eval()
    kp_detector.eval()
    return generator, kp_detector
def load_checkpoints(config_path, checkpoint_path, device='cuda'):
    """Build the generator and keypoint detector from a YAML config and
    restore their weights from a checkpoint on the given device.

    Args:
        config_path: path to the model YAML config.
        checkpoint_path: path to the .pth checkpoint file.
        device: torch device string the models and weights are placed on.

    Returns:
        (generator, kp_detector), wrapped in DataParallelWithCallback and
        set to eval mode.
    """
    with open(config_path) as f:
        # Fix: yaml.load() without an explicit Loader is deprecated/unsafe
        # (and a TypeError on PyYAML >= 6); FullLoader matches the other
        # loaders in this file.
        config = yaml.load(f, Loader=yaml.FullLoader)
    generator = OcclusionAwareGenerator(
        **config['model_params']['generator_params'],
        **config['model_params']['common_params'])
    generator.to(device)
    kp_detector = KPDetector(**config['model_params']['kp_detector_params'],
                             **config['model_params']['common_params'])
    kp_detector.to(device)
    checkpoint = torch.load(checkpoint_path, map_location=device)
    generator.load_state_dict(checkpoint['generator'])
    kp_detector.load_state_dict(checkpoint['kp_detector'])
    generator = DataParallelWithCallback(generator)
    kp_detector = DataParallelWithCallback(kp_detector)
    generator.eval()
    kp_detector.eval()
    return generator, kp_detector
def load_checkpoints(config, checkpoint, blend_scale=0.125, first_order_motion_model=False, cpu=False):
    """Build the part-swap reconstruction and segmentation modules from a
    YAML config and restore their weights from a checkpoint.

    Args:
        config: path to the model YAML config (rebound to the parsed dict).
        checkpoint: path to the .pth checkpoint (rebound to the loaded dict).
        blend_scale: blending scale forwarded to PartSwapGenerator.
        first_order_motion_model: use the first-order-motion variant.
        cpu: when True, keep everything on CPU and skip DataParallel.

    Returns:
        (reconstruction_module, segmentation_module) in eval mode.
    """
    with open(config) as f:
        # Fix: yaml.load() without an explicit Loader is deprecated/unsafe
        # (and a TypeError on PyYAML >= 6); FullLoader matches the other
        # loaders in this file.
        config = yaml.load(f, Loader=yaml.FullLoader)
    reconstruction_module = PartSwapGenerator(
        blend_scale=blend_scale,
        first_order_motion_model=first_order_motion_model,
        **config['model_params']['reconstruction_module_params'],
        **config['model_params']['common_params'])
    if not cpu:
        reconstruction_module.cuda()
    segmentation_module = SegmentationModule(
        **config['model_params']['segmentation_module_params'],
        **config['model_params']['common_params'])
    if not cpu:
        segmentation_module.cuda()
    if cpu:
        checkpoint = torch.load(checkpoint, map_location=torch.device('cpu'))
    else:
        checkpoint = torch.load(checkpoint)
    load_reconstruction_module(reconstruction_module, checkpoint)
    load_segmentation_module(segmentation_module, checkpoint)
    if not cpu:
        reconstruction_module = DataParallelWithCallback(reconstruction_module)
        segmentation_module = DataParallelWithCallback(segmentation_module)
    reconstruction_module.eval()
    segmentation_module.eval()
    return reconstruction_module, segmentation_module
def load_checkpoints(config_path, checkpoint_path, cpu=False):
    """Build the generator and keypoint detector from a YAML config and
    restore their weights from a checkpoint.

    Args:
        config_path: path to the model YAML config.
        checkpoint_path: path to the .pth checkpoint file.
        cpu: when True, keep everything on CPU and skip DataParallel.

    Returns:
        (generator, kp_detector) in eval mode.
    """
    with open(config_path) as f:
        # Fix: yaml.load() without an explicit Loader is deprecated/unsafe
        # (and a TypeError on PyYAML >= 6); FullLoader matches the other
        # loaders in this file.
        config = yaml.load(f, Loader=yaml.FullLoader)
    generator = OcclusionAwareGenerator(
        **config['model_params']['generator_params'],
        **config['model_params']['common_params'])
    if not cpu:
        generator.cuda()
    kp_detector = KPDetector(**config['model_params']['kp_detector_params'],
                             **config['model_params']['common_params'])
    if not cpu:
        kp_detector.cuda()
    if cpu:
        checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'))
    else:
        checkpoint = torch.load(checkpoint_path)
    generator.load_state_dict(checkpoint['generator'])
    kp_detector.load_state_dict(checkpoint['kp_detector'])
    if not cpu:
        generator = DataParallelWithCallback(generator)
        kp_detector = DataParallelWithCallback(kp_detector)
    generator.eval()
    kp_detector.eval()
    return generator, kp_detector
def reconstruction(config, generator, kp_detector, checkpoint, log_dir, dataset):
    """Reconstruct every video from its own first frame and report L1 loss.

    For each video, keypoints of frame 0 drive the generator against every
    frame. Saves a horizontal strip of predicted frames as .png, a
    visualization video per input, and prints the mean absolute
    reconstruction error over all processed frames.

    Raises:
        AttributeError: if `checkpoint` is None — reconstruction requires
            pretrained weights.
    """
    png_dir = os.path.join(log_dir, 'reconstruction/png')
    log_dir = os.path.join(log_dir, 'reconstruction')
    if checkpoint is not None:
        Logger.load_cpk(checkpoint, generator=generator, kp_detector=kp_detector)
    else:
        raise AttributeError("Checkpoint should be specified for mode='reconstruction'.")
    dataloader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=1)
    if not os.path.exists(log_dir):
        os.makedirs(log_dir)
    if not os.path.exists(png_dir):
        os.makedirs(png_dir)
    loss_list = []
    if torch.cuda.is_available():
        generator = DataParallelWithCallback(generator)
        kp_detector = DataParallelWithCallback(kp_detector)
    generator.eval()
    kp_detector.eval()
    for it, x in tqdm(enumerate(dataloader)):
        # Optional cap on how many videos to process.
        if config['reconstruction_params']['num_videos'] is not None:
            if it > config['reconstruction_params']['num_videos']:
                break
        with torch.no_grad():
            predictions = []
            visualizations = []
            if torch.cuda.is_available():
                x['video'] = x['video'].cuda()
            # Keypoints of the first frame are reused for the whole video.
            kp_source = kp_detector(x['video'][:, :, 0])
            for frame_idx in range(x['video'].shape[2]):
                source = x['video'][:, :, 0]
                driving = x['video'][:, :, frame_idx]
                kp_driving = kp_detector(driving)
                out = generator(source, kp_source=kp_source, kp_driving=kp_driving)
                out['kp_source'] = kp_source
                out['kp_driving'] = kp_driving
                # Dropped before visualization (not needed downstream).
                del out['sparse_deformed']
                predictions.append(np.transpose(out['prediction'].data.cpu().numpy(), [0, 2, 3, 1])[0])
                visualization = Visualizer(**config['visualizer_params']).visualize(source=source,
                                                                                    driving=driving, out=out)
                visualizations.append(visualization)
                # Per-frame mean absolute error against the ground-truth frame.
                loss_list.append(torch.abs(out['prediction'] - driving).mean().cpu().numpy())
            # All predicted frames concatenated along width into one strip.
            predictions = np.concatenate(predictions, axis=1)
            imageio.imsave(os.path.join(png_dir, x['name'][0] + '.png'),
                           (255 * predictions).astype(np.uint8))
            image_name = x['name'][0] + config['reconstruction_params']['format']
            imageio.mimsave(os.path.join(log_dir, image_name), visualizations)
    print("Reconstruction loss: %s"
          % np.mean(loss_list))
def __init__(self, hidden_dim, lr, hard_or_full_trip, margin, num_workers,
             batch_size, restore_iter, total_iter, save_name, train_pid_num,
             frame_num, model_name, train_source, test_source, img_size=64):
    """Set up the gait-recognition model: SetNet encoder, triplet loss,
    Adam optimizer, and bookkeeping for training metrics.

    NOTE(review): `batch_size` is unpacked into (P, M) below, so callers
    must pass a 2-element pair (P identities x M samples) — confirm
    against the caller.
    """
    self.save_name = save_name
    self.train_pid_num = train_pid_num
    self.train_source = train_source
    self.test_source = test_source
    self.hidden_dim = hidden_dim
    self.lr = lr
    self.hard_or_full_trip = hard_or_full_trip
    self.margin = margin
    self.frame_num = frame_num
    self.num_workers = num_workers
    self.batch_size = batch_size
    self.model_name = model_name
    # batch_size is a (P, M) pair: P subjects, M sequences per subject.
    self.P, self.M = batch_size
    self.restore_iter = restore_iter
    self.total_iter = total_iter
    self.img_size = img_size
    # Encoder and loss are both wrapped for sync-BN-aware multi-GPU runs.
    self.encoder = SetNet(self.hidden_dim).float()
    self.encoder = DataParallelWithCallback(self.encoder)
    self.triplet_loss = TripletLoss(self.P * self.M, self.hard_or_full_trip,
                                    self.margin).float()
    self.triplet_loss = DataParallelWithCallback(self.triplet_loss)
    self.encoder.cuda()
    self.triplet_loss.cuda()
    # Only the encoder's parameters are optimized; the loss has none to train.
    self.optimizer = optim.Adam([
        {'params': self.encoder.parameters()},
    ], lr=self.lr)
    # Running metric buffers, reset/flushed by the training loop.
    self.hard_loss_metric = []
    self.full_loss_metric = []
    self.full_loss_num = []
    self.dist_list = []
    self.mean_dist = 0.01
    self.sample_type = 'all'
def cuda(self):
    """Move the model to GPU and wrap it for multi-GPU execution.

    Uses the sync-BN-aware wrapper when 'network.sync_bn' is enabled in
    the config, otherwise plain nn.DataParallel.
    """
    self.model.cuda()
    use_sync_bn = self.get_param('network.sync_bn', False)
    wrapper = DataParallelWithCallback if use_sync_bn else nn.DataParallel
    self.model = wrapper(self.model, dim=self.batch_axis)
def testSyncBatchNormSyncEval(self):
    """Synchronized BN must match vanilla BN in eval mode (affine=False)."""
    eps = 1e-5
    bn = nn.BatchNorm1d(10, eps=eps, affine=False)
    sync_bn = SynchronizedBatchNorm1d(10, eps=eps, affine=False)
    # Replicate the synchronized module across two devices.
    sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
    bn.cuda()
    sync_bn.cuda()
    # train=False: compare the two modules in eval mode on GPU.
    self._checkBatchNormResult(bn, sync_bn, torch.rand(16, 10), False,
                               cuda=True)
def testSyncBatchNorm2DSyncTrain(self):
    """Synchronized 2D BN must match vanilla BatchNorm2d in train mode."""
    bn = nn.BatchNorm2d(10)
    # Replicate the synchronized module across two devices.
    sync_bn = DataParallelWithCallback(SynchronizedBatchNorm2d(10),
                                       device_ids=[0, 1])
    bn.cuda()
    sync_bn.cuda()
    sample = torch.rand(16, 10, 16, 16)
    # train=True: compare the two modules in training mode on GPU.
    self._checkBatchNormResult(bn, sync_bn, sample, True, cuda=True)
def load_checkpoints(config_path, checkpoint_path, cpu=False):
    """Build the generator, region predictor, and AVD network from a YAML
    config and restore their weights from a checkpoint.

    Args:
        config_path: path to the model YAML config.
        checkpoint_path: path to the .pth checkpoint file.
        cpu: when True, keep everything on CPU and skip DataParallel.

    Returns:
        (generator, region_predictor, avd_network) in eval mode.
    """
    with open(config_path) as f:
        # Fix: yaml.load() without an explicit Loader is deprecated/unsafe
        # (and a TypeError on PyYAML >= 6); FullLoader matches the other
        # loaders in this file.
        config = yaml.load(f, Loader=yaml.FullLoader)
    generator = Generator(num_regions=config['model_params']['num_regions'],
                          num_channels=config['model_params']['num_channels'],
                          **config['model_params']['generator_params'])
    if not cpu:
        generator.cuda()
    region_predictor = RegionPredictor(
        num_regions=config['model_params']['num_regions'],
        num_channels=config['model_params']['num_channels'],
        estimate_affine=config['model_params']['estimate_affine'],
        **config['model_params']['region_predictor_params'])
    if not cpu:
        region_predictor.cuda()
    avd_network = AVDNetwork(num_regions=config['model_params']['num_regions'],
                             **config['model_params']['avd_network_params'])
    if not cpu:
        avd_network.cuda()
    if cpu:
        checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'))
    else:
        checkpoint = torch.load(checkpoint_path)
    generator.load_state_dict(checkpoint['generator'])
    region_predictor.load_state_dict(checkpoint['region_predictor'])
    # Older checkpoints may not contain AVD weights; keep random init then.
    if 'avd_network' in checkpoint:
        avd_network.load_state_dict(checkpoint['avd_network'])
    if not cpu:
        generator = DataParallelWithCallback(generator)
        region_predictor = DataParallelWithCallback(region_predictor)
        avd_network = DataParallelWithCallback(avd_network)
    generator.eval()
    region_predictor.eval()
    avd_network.eval()
    return generator, region_predictor, avd_network
def load_checkpoints(kp_detector_path, generator_path, cpu=False):
    """Load TorchScript generator and keypoint detector modules.

    Args:
        kp_detector_path: path to the scripted keypoint detector.
        generator_path: path to the scripted generator.
        cpu: when True, keep the modules on CPU and skip DataParallel.

    Returns:
        (generator, kp_detector) in eval mode.
    """
    generator = torch.jit.load(generator_path)
    kp_detector = torch.jit.load(kp_detector_path)
    if not cpu:
        generator.cuda()
        kp_detector.cuda()
        generator = DataParallelWithCallback(generator)
        kp_detector = DataParallelWithCallback(kp_detector)
    generator.eval()
    kp_detector.eval()
    return generator, kp_detector
def restore(self, checkpoint):
    """Build the GAN (generator, discriminator, their Adam optimizers),
    optionally restore state from a checkpoint, and create LR schedulers.

    Args:
        checkpoint: path to a saved state dict-of-state-dicts, or None to
            start from scratch at epoch 0.
    """
    self.epoch = 0
    self.generator = DCGenerator(**self.config['generator_params'])
    self.generator = DataParallelWithCallback(self.generator,
                                              device_ids=self.device_ids)
    self.optimizer_generator = torch.optim.Adam(
        params=self.generator.parameters(),
        lr=self.config['lr_generator'],
        betas=(self.config['b1_generator'], self.config['b2_generator']),
        weight_decay=0,
        eps=1e-8)
    self.discriminator = DCDiscriminator(
        **self.config['discriminator_params'])
    self.discriminator = DataParallelWithCallback(
        self.discriminator, device_ids=self.device_ids)
    self.optimizer_discriminator = torch.optim.Adam(
        params=self.discriminator.parameters(),
        lr=self.config['lr_discriminator'],
        betas=(self.config['b1_discriminator'],
               self.config['b2_discriminator']),
        weight_decay=0,
        eps=1e-8)
    if checkpoint is not None:
        data = torch.load(checkpoint)
        # Fix: iterating a dict directly yields only its keys; .items() is
        # required to unpack (key, value) pairs — the original raised
        # ValueError (or silently misbehaved) here.
        for key, value in data.items():
            if key == 'epoch':
                self.epoch = value
            else:
                # Keys name attributes of self (e.g. 'generator',
                # 'optimizer_generator') whose state is restored in place.
                self.__dict__[key].load_state_dict(value)
    # Linear LR decay down to zero over num_epochs.
    lr_lambda = lambda epoch: 1 - epoch / self.config['num_epochs']
    self.scheduler_generator = torch.optim.lr_scheduler.LambdaLR(
        self.optimizer_generator, lr_lambda, last_epoch=self.epoch - 1)
    self.scheduler_discriminator = torch.optim.lr_scheduler.LambdaLR(
        self.optimizer_discriminator, lr_lambda, last_epoch=self.epoch - 1)
def __init__(self, opt):
    """Set up the trainer: build the Pix2Pix model, wrap it for multi-GPU
    when GPUs are requested, and create the optimizers when training."""
    self.opt = opt
    model = Pix2PixModel(opt)
    if opt.gpu_ids:
        # Keep a handle on the unwrapped module for optimizer creation
        # and other single-GPU-only calls.
        self.pix2pix_model = DataParallelWithCallback(
            model, device_ids=opt.gpu_ids)
        self.pix2pix_model_on_one_gpu = self.pix2pix_model.module
    else:
        self.pix2pix_model = model
        self.pix2pix_model_on_one_gpu = model
    self.generated = None
    if opt.isTrain:
        optimizers = self.pix2pix_model_on_one_gpu.create_optimizers(opt)
        self.optimizer_G, self.optimizer_D = optimizers
        self.old_lr = opt.lr
def init_net(net, init_type='normal', init_gain=0.02, gpu_ids=[]):
    """Initialize a network: 1. register CPU/GPU device (with multi-GPU
    support); 2. initialize the network weights.

    Parameters:
        net (network)      -- the network to be initialized
        init_type (str)    -- initialization method: normal | xavier | kaiming | orthogonal
        init_gain (float)  -- scaling factor for normal, xavier and orthogonal
        gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2

    Return an initialized network.
    """
    if gpu_ids:
        assert torch.cuda.is_available()
        net.to(gpu_ids[0])
        # Parallelize across all requested GPUs.
        net = DataParallelWithCallback(net, gpu_ids)
    init_weights(net, init_type, init_gain=init_gain)
    return net
def main():
    """Train a lane-detection Baseline on CULane, optionally resuming from
    a checkpoint named in the module-level `args` dict.

    Relies on module globals: args, culane, ckpt_path, exp_name, log_path,
    check_mkdir, train.
    """
    net = Baseline(num_classes=culane.num_classes,
                   deep_base=args['deep_base']).cuda().train()
    net = DataParallelWithCallback(net)
    # Biases get twice the base learning rate (common segmentation trick).
    optimizer = optim.SGD([{
        'params': [
            param for name, param in net.named_parameters()
            if name[-4:] == 'bias'
        ],
        'lr': 2 * args['base_lr']
    }, {
        'params': [
            param for name, param in net.named_parameters()
            if name[-4:] != 'bias'
        ],
        'lr': args['base_lr']
    }],
                          momentum=args['momentum'])
    if len(args['checkpoint']) > 0:
        print('training resumes from \'%s\'' % args['checkpoint'])
        net.load_state_dict(
            torch.load(
                os.path.join(ckpt_path, exp_name,
                             args['checkpoint'] + '_checkpoint.pth')))
        optimizer.load_state_dict(
            torch.load(
                os.path.join(ckpt_path, exp_name,
                             args['checkpoint'] + '_checkpoint_optim.pth')))
        # Re-apply the current base LR to both groups after the optimizer
        # state (which carries the old LRs) was restored.
        optimizer.param_groups[0]['lr'] = 2 * args['base_lr']
        optimizer.param_groups[1]['lr'] = args['base_lr']
    check_mkdir(os.path.join(ckpt_path, exp_name))
    # Record the full training configuration at the top of the log.
    open(log_path, 'w').write(str(args) + '\n\n')
    train(net, optimizer)
def debug_generator(generator, kp_to_skl_gt, loader, train_params, logger,
                    device_ids, tgt_batch=None):
    """Debug-train the conditional generator on perceptual loss only,
    logging per-iteration loss and periodic image grids (train and,
    when `tgt_batch` is given, a held-out eval view).

    Args:
        generator: generator whose parameters are optimized.
        kp_to_skl_gt: keypoints -> skeleton-image renderer (run on 'cuda').
        loader: training DataLoader yielding batches for `split_data`.
        train_params: dict with 'lr', 'betas', 'epoch_milestones',
            'num_epochs', 'detach_kp_discriminator', 'log_params'.
        logger: scalar/image logger (TensorBoard-like interface).
        device_ids: GPUs for DataParallelWithCallback.
        tgt_batch: optional fixed batch used for eval visualizations.
    """
    log_params = train_params['log_params']
    genModel = ConditionalGenerator2D(generator, train_params)
    genModel = DataParallelWithCallback(genModel, device_ids=device_ids)
    optimizer_generator = torch.optim.Adam(generator.parameters(),
                                           lr=train_params['lr'],
                                           betas=train_params['betas'])
    scheduler_generator = MultiStepLR(optimizer_generator,
                                      train_params['epoch_milestones'],
                                      gamma=0.1, last_epoch=-1)
    k = 0
    train_views = [0, 1, 3]
    eval_views = [2]
    if tgt_batch is not None:
        tgt_batch_samples = split_data(tgt_batch, train_views=train_views,
                                       eval_views=eval_views)
        with torch.no_grad():
            tgt_batch_samples['gt_skl'] = kp_to_skl_gt(
                tgt_batch_samples['kps'].to('cuda')).unsqueeze(1)
            tgt_batch_samples['gt_skl_eval'] = kp_to_skl_gt(
                tgt_batch_samples['kps_eval'].to('cuda')).unsqueeze(1)
    for epoch in range(train_params['num_epochs']):
        for i, batch in enumerate(tqdm(loader)):
            # Fix: the original referenced undefined locals
            # (img, annots, ref_img); split the loaded batch exactly the
            # way tgt_batch is split above.
            batch_samples = split_data(batch, train_views=train_views,
                                       eval_views=eval_views)
            with torch.no_grad():
                # Render ground-truth skeleton images from the keypoints.
                batch_samples['gt_skl'] = kp_to_skl_gt(
                    batch_samples['kps'].to('cuda')).unsqueeze(1)
            generator_out = genModel(batch_samples['imgs'],
                                     batch_samples['ref_imgs'],
                                     batch_samples['gt_skl'])
            # Generator update on the perceptual loss only.
            loss_generator = generator_out['perceptual_loss']
            loss_generator = [x.mean() for x in loss_generator]
            loss_gen = sum(loss_generator)
            loss_gen.backward(
                retain_graph=not train_params['detach_kp_discriminator'])
            optimizer_generator.step()
            optimizer_generator.zero_grad()
            # Per-iteration scalar logging.
            logger.add_scalar("Generator Loss", loss_gen.item(),
                              epoch * len(loader) + i + 1)
            if i in log_params['log_imgs']:
                if tgt_batch is not None:
                    with torch.no_grad():
                        genModel.eval()
                        generator_out_eval = genModel(
                            tgt_batch_samples['imgs_eval'],
                            tgt_batch_samples['ref_imgs_eval'],
                            tgt_batch_samples['gt_skl_eval'])
                        concat_img_eval = np.concatenate(
                            (tensor_to_image(tgt_batch_samples['imgs_eval'][k]),
                             tensor_to_image(tgt_batch_samples['gt_skl_eval'][k]),
                             tensor_to_image(tgt_batch_samples['ref_imgs_eval'][k]),
                             tensor_to_image(
                                 generator_out_eval['reconstructred_image'][k])),
                            axis=2)  # concat along width
                        logger.add_image('Sample_{%d}_EVAL' % i,
                                         concat_img_eval, epoch)
                        genModel.train()
                # Cycle which sample of the batch gets visualized.
                k += 1
                k = k % 4
                concat_img = np.concatenate(
                    (tensor_to_image(batch_samples['imgs'][k]),
                     tensor_to_image(batch_samples['gt_skl'][k]),
                     tensor_to_image(batch_samples['ref_imgs'][k]),
                     tensor_to_image(
                         generator_out['reconstructred_image'][k])),
                    axis=2)  # concat along width
                logger.add_image('Sample_{%d}' % i, concat_img, epoch)
        scheduler_generator.step()
def debug_encoder(model_skeleton_to_keypoint, model_keypoint_to_skeleton,
                  loader, loader_tgt, train_params, checkpoint, logger,
                  device_ids, tgt_batch=None):
    """Debug-train the skeleton->keypoints encoder against skeletons
    rendered by a frozen keypoints->skeleton model, with periodic
    validation loss and image logging.

    NOTE(review): `tgt_batch` is dereferenced unconditionally below, so the
    `=None` default cannot actually be used — confirm with callers.
    """
    log_params = train_params['log_params']
    optimizer_encoder = torch.optim.Adam(
        model_skeleton_to_keypoint.parameters(),
        lr=train_params['lr'],
        betas=train_params['betas'])
    resume_epoch = 0
    resume_iteration = 0
    if checkpoint is not None:
        print('Loading Checkpoint: %s' % checkpoint)
        resume_epoch, resume_iteration = logger.checkpoint.load_checkpoint(
            checkpoint,
            skeleton_to_keypoints=model_skeleton_to_keypoint,
            optimizer_skeleton_to_keypoints=optimizer_encoder)
        logger.epoch = resume_epoch
        logger.iterations = resume_iteration
    scheduler_encoder = MultiStepLR(optimizer_encoder,
                                    train_params['epoch_milestones'],
                                    gamma=0.1,
                                    last_epoch=logger.epoch - 1)
    encoder = SkeletonToKeypoints(model_skeleton_to_keypoint,
                                  model_keypoint_to_skeleton)
    encoder = DataParallelWithCallback(encoder, device_ids=device_ids)
    # Pre-render the fixed validation batch's skeletons once.
    with torch.no_grad():
        skeletons = model_keypoint_to_skeleton(
            tgt_batch['annots'].to('cuda')).unsqueeze(1).detach()
    # NOTE(review): 'annots_unnormed' is filled with the same tensor as
    # 'annots' here — verify this is intended for the eval drawing below.
    tgt_batch = {
        'imgs': tgt_batch['imgs'],
        'annots': tgt_batch['annots'],
        'annots_unnormed': tgt_batch['annots'],
        'skeletons': skeletons
    }
    k = 0
    for epoch in range(logger.epoch, train_params['num_epochs']):
        for i, batch in enumerate(tqdm(loader)):
            annots = batch['annots']
            annots_gt = batch['annots_unnormed']  # unused below; kept as-is
            with torch.no_grad():
                # Target skeleton image rendered from the (frozen) model.
                gt_skl = model_keypoint_to_skeleton(
                    annots.to('cuda')).unsqueeze(1).detach()
            encoder_out = encoder(gt_skl, annots)
            optimizer_encoder.zero_grad()
            loss = encoder_out['loss'].mean()
            loss.backward()
            optimizer_encoder.step()
            optimizer_encoder.zero_grad()
            ####### LOG VALIDATION
            if i % log_params['eval_frequency'] == 0:
                eval_loss = eval_model(encoder, next(iter(loader_tgt)),
                                       model_keypoint_to_skeleton)
                eval_sz = int(len(loader) / log_params['eval_frequency'])
                it_number = epoch * eval_sz + (logger.iterations /
                                               log_params['eval_frequency'])
                logger.add_scalar('Eval loss', eval_loss, it_number)
            ####### LOG
            logger.add_scalar('L2 loss', loss.item(), logger.iterations)
            if i in log_params['log_imgs']:
                with torch.no_grad():
                    encoder.eval()
                    target_out = encoder(tgt_batch['skeletons'],
                                         tgt_batch['annots'])
                    encoder.train()
                skl_out = target_out['reconstruction']
                kps_out = target_out['keypoints']
                # Eval grid: input kps on image | target skl | reconstructed
                # skl | predicted kps on image (concat along width).
                concat_img = np.concatenate(
                    (draw_kp(tensor_to_image(tgt_batch['imgs'][k]),
                             tgt_batch['annots_unnormed'][k]),
                     tensor_to_image(tgt_batch['skeletons'][k]),
                     tensor_to_image(skl_out[k]),
                     draw_kp(tensor_to_image(tgt_batch['imgs'][k]),
                             kps_out[k], color='red')),
                    axis=2)
                # Train grid: target skeleton vs. its reconstruction.
                concat_img_train = np.concatenate(
                    (tensor_to_image(gt_skl[k]),
                     tensor_to_image(encoder_out['reconstruction'][k])),
                    axis=2)
                logger.add_image('Eval_{%d}' % i, concat_img, epoch)
                logger.add_image('Train_{%d}' % i, concat_img_train, epoch)
                # Cycle which sample gets visualized.
                k += 1
                k = k % len(log_params['log_imgs'])
            logger.step_it()
        scheduler_encoder.step()
        logger.step_epoch(
            models={
                'skeleton_to_keypoints': model_skeleton_to_keypoint,
                'optimizer_skeleton_to_keypoints': optimizer_encoder
            })
def train(config, generator, discriminator, kp_detector, checkpoint, log_dir,
          dataset, device_ids, load_weights_only=False, use_both=False,
          update_kp=True):
    """Adversarial training with a cycle: drive src_B with driving_A to get
    B', then drive src_A with B' back toward driving_A. Logs to TensorBoard
    and to the project Logger.

    Args:
        load_weights_only: restore weights but restart epoch/iteration at 0.
        use_both: compute_loss variant returning both approximate and GT
            keypoint joins.
        update_kp: whether the keypoint detector is stepped at all.
    """
    train_params = config['train_params']
    optimizer_generator = torch.optim.Adam(generator.parameters(),
                                           lr=train_params['lr'],
                                           betas=(0.5, 0.999))
    optimizer_discriminator = torch.optim.Adam(discriminator.parameters(),
                                               lr=train_params['lr'],
                                               betas=(0.5, 0.999))
    optimizer_kp_detector = torch.optim.Adam(kp_detector.parameters(),
                                             lr=train_params['lr'],
                                             betas=(0.5, 0.999))
    if checkpoint is not None:
        if load_weights_only:
            # Weights only: training restarts from scratch counters.
            Logger.load_cpk(checkpoint, generator, discriminator, kp_detector)
            start_epoch = 0
            it = 0
        else:
            # Full resume: also restores optimizer state and counters.
            saved_start_epoch, saved_it = Logger.load_cpk(
                checkpoint, generator, discriminator, kp_detector,
                optimizer_generator, optimizer_discriminator,
                optimizer_kp_detector)
            start_epoch = saved_start_epoch
            it = saved_it
    else:
        start_epoch = 0
        it = 0
    scheduler_generator = MultiStepLR(optimizer_generator,
                                      train_params['epoch_milestones'],
                                      gamma=0.1, last_epoch=start_epoch - 1)
    scheduler_discriminator = MultiStepLR(optimizer_discriminator,
                                          train_params['epoch_milestones'],
                                          gamma=0.1,
                                          last_epoch=start_epoch - 1)
    scheduler_kp_detector = MultiStepLR(optimizer_kp_detector,
                                        train_params['epoch_milestones'],
                                        gamma=0.1, last_epoch=start_epoch - 1)
    dataloader = DataLoader(dataset, batch_size=train_params['batch_size'],
                            shuffle=True, num_workers=4, drop_last=True)
    generator_full = GeneratorFullModel(kp_detector, generator, discriminator,
                                        train_params)
    discriminator_full = DiscriminatorFullModel(kp_detector, generator,
                                                discriminator, train_params)
    generator_full_par = DataParallelWithCallback(generator_full,
                                                  device_ids=device_ids)
    discriminator_full_par = DataParallelWithCallback(discriminator_full,
                                                      device_ids=device_ids)
    if not os.path.isdir(log_dir + 'tb_log/'):
        os.mkdir(log_dir + 'tb_log/')
    writer = SummaryWriter(log_dir=log_dir + 'tb_log/')
    with Logger(log_dir=log_dir,
                visualizer_params=config['visualizer_params'],
                **train_params['log_params']) as logger:
        for epoch in trange(start_epoch, train_params['num_epochs']):
            total_discriminator_loss = 0.
            total_generator_loss = 0.
            for i, x in enumerate(dataloader):
                # driving = driving_A, source = src_B --> driving_B'
                mn1_dict = {}
                mn1_dict['source'] = x['src_B']
                mn1_dict['video'] = x['driving_A']
                mn1_dict['gt_video'] = x['driving_B']
                '''This code is for the first model where we use ground truth key points for both'''
                if not use_both:
                    loss_values_B, loss_B, generated_B, gt_kp_joined_B = compute_loss(
                        generator_full_par, mn1_dict)
                else:
                    '''This code is for the second model where we give GT and approx kp respectively '''
                    out_B = compute_loss(generator_full_par, mn1_dict,
                                         use_both=True)
                    loss_values_B, loss_B, generated_B, approx_kp_joined_B, gt_kp_joined_B = out_B
                # driving = generated_B (driving_B'), source = src_A --> driving_A'
                mn2_dict = {}
                mn2_dict['source'] = x['src_A']
                mn2_dict['video'] = generated_B['video_prediction']
                mn2_dict['gt_video'] = x['driving_A']
                # First model - see above
                if not use_both:
                    loss_values_A, loss_A, generated_A, gt_kp_joined_A = compute_loss(
                        generator_full_par, mn2_dict)
                else:
                    # Second model - see above
                    out_A = compute_loss(generator_full_par, mn2_dict,
                                         use_both=True)
                    loss_values_A, loss_A, generated_A, approx_kp_joined_A, gt_kp_joined_A = out_A
                loss = loss_B + loss_A
                total_generator_loss += loss
                # NOTE(review): the two lines below discard the combined
                # loss and the accumulation above (only loss_B is
                # backpropagated, and the running total is overwritten) —
                # looks like leftover debug code; confirm intent.
                loss = loss_B
                total_generator_loss = loss
                writer.add_scalar('generator loss B', loss_B.item(), it)
                writer.add_scalar('generator loss A', loss_A.item(), it)
                writer.add_scalar('generator loss', loss.item(), it)
                loss.backward(
                    retain_graph=not train_params['detach_kp_discriminator'])
                optimizer_generator.step()
                optimizer_generator.zero_grad()
                optimizer_discriminator.zero_grad()
                if train_params['detach_kp_discriminator'] and update_kp:
                    # KP detector learns from the generator pass only.
                    optimizer_kp_detector.step()
                    optimizer_kp_detector.zero_grad()
                generator_loss_values = {}
                generator_loss_values['A'] = [
                    val.detach().cpu().numpy() for val in loss_values_A
                ]
                generator_loss_values['B'] = [
                    val.detach().cpu().numpy() for val in loss_values_B
                ]
                # Discriminator pass over both halves of the cycle.
                if not use_both:
                    loss_values_B = discriminator_full_par(
                        mn1_dict, gt_kp_joined_B, generated_B)
                    loss_values_A = discriminator_full_par(
                        mn2_dict, gt_kp_joined_A, generated_A)
                else:
                    loss_values_B = discriminator_full_par(
                        mn1_dict, approx_kp_joined_B, generated_B,
                        gt_kp_joined_B)
                    loss_values_A = discriminator_full_par(
                        mn2_dict, approx_kp_joined_A, generated_A,
                        gt_kp_joined_A)
                loss_values_B = [val.mean() for val in loss_values_B]
                loss_values_A = [val.mean() for val in loss_values_A]
                loss_B = sum(loss_values_B)
                loss_A = sum(loss_values_A)
                loss = loss_A + loss_B
                total_discriminator_loss += loss
                # NOTE(review): same pattern as above — combined loss and
                # accumulator are overwritten with loss_B; confirm intent.
                loss = loss_B
                total_discriminator_loss = loss
                writer.add_scalar('disc loss B', loss_B.item(), it)
                writer.add_scalar('disc loss A', loss_A.item(), it)
                writer.add_scalar('disc loss', loss.item(), it)
                loss.backward()
                optimizer_discriminator.step()
                optimizer_discriminator.zero_grad()
                if not train_params['detach_kp_discriminator'] and update_kp:
                    # KP detector also learns from the discriminator pass.
                    optimizer_kp_detector.step()
                    optimizer_kp_detector.zero_grad()
                discriminator_loss_values = {}
                discriminator_loss_values['A'] = [
                    val.detach().cpu().numpy() for val in loss_values_A
                ]
                discriminator_loss_values['B'] = [
                    val.detach().cpu().numpy() for val in loss_values_B
                ]
                values = {
                    'A': generator_loss_values['A'] + discriminator_loss_values['A'],
                    'B': generator_loss_values['B'] + discriminator_loss_values['B']
                }
                logger.log_iter(
                    it,
                    names=generator_loss_names(train_params['loss_weights']) +
                    discriminator_loss_names(),
                    values=values['B'],
                    inp=mn1_dict,
                    out=generated_B,
                    name='src_B_driving_A')
                logger.log_iter(
                    it,
                    names=generator_loss_names(train_params['loss_weights']) +
                    discriminator_loss_names(),
                    values=values['A'],
                    inp=mn2_dict,
                    out=generated_A,
                    name='src_A_driving_B')
                it += 1
            scheduler_generator.step()
            scheduler_discriminator.step()
            scheduler_kp_detector.step()
            writer.add_scalar('generator loss / train',
                              total_generator_loss / (i + 1), epoch)
            writer.add_scalar('discriminator loss / train',
                              total_discriminator_loss / (i + 1), epoch)
            logger.log_epoch(
                epoch, {
                    'generator': generator,
                    'discriminator': discriminator,
                    'kp_detector': kp_detector,
                    'optimizer_generator': optimizer_generator,
                    'optimizer_discriminator': optimizer_discriminator,
                    'optimizer_kp_detector': optimizer_kp_detector
                })
def train(config, generator, mask_generator, checkpoint, log_dir, dataset,
          device_ids):
    """Jointly train the generator and mask generator.

    Args:
        config: parsed YAML config with 'train_params' and
            'visualizer_params'.
        generator / mask_generator: modules to train.
        checkpoint: optional path to resume from.
        log_dir: directory for Logger output.
        dataset: training dataset (optionally repeated via num_repeats).
        device_ids: GPUs for DataParallelWithCallback.
    """
    train_params = config['train_params']
    optimizer_generator = torch.optim.Adam(generator.parameters(),
                                           lr=train_params['lr_generator'],
                                           betas=(0.5, 0.999))
    optimizer_mask_generator = torch.optim.Adam(
        mask_generator.parameters(),
        lr=train_params['lr_mask_generator'],
        betas=(0.5, 0.999))
    if checkpoint is not None:
        print('loading cpk')
        # With lr_mask_generator == 0 the mask optimizer state is not restored.
        start_epoch = Logger.load_cpk(
            checkpoint, generator, mask_generator, optimizer_generator,
            None if train_params['lr_mask_generator'] == 0 else
            optimizer_mask_generator)
    else:
        start_epoch = 0
    print(start_epoch)
    scheduler_generator = MultiStepLR(optimizer_generator,
                                      train_params['epoch_milestones'],
                                      gamma=0.1, last_epoch=start_epoch - 1)
    # When the mask generator is frozen (lr == 0) its scheduler restarts at -1.
    scheduler_mask_generator = MultiStepLR(
        optimizer_mask_generator,
        train_params['epoch_milestones'],
        gamma=0.1,
        last_epoch=-1 + start_epoch * (train_params['lr_mask_generator'] != 0))
    # Fix: the original used `or`, which wrapped the dataset even when
    # num_repeats == 1 and raised KeyError when the key was absent; `and`
    # is the intended guard.
    if 'num_repeats' in train_params and train_params['num_repeats'] != 1:
        dataset = DatasetRepeater(dataset, train_params['num_repeats'])
    dataloader = DataLoader(dataset, batch_size=train_params['batch_size'],
                            shuffle=True, num_workers=6, drop_last=True)
    generator_full = GeneratorFullModel(mask_generator, generator,
                                        train_params)
    if torch.cuda.is_available():
        generator_full = DataParallelWithCallback(generator_full,
                                                  device_ids=device_ids)
    with Logger(log_dir=log_dir,
                visualizer_params=config['visualizer_params'],
                checkpoint_freq=train_params['checkpoint_freq']) as logger:
        for epoch in trange(start_epoch, train_params['num_epochs']):
            for index, x in enumerate(dataloader):
                # Mask prediction is warmed up: disabled during epoch 0.
                predict_mask = epoch >= 1
                losses_generator, generated = generator_full(x, predict_mask)
                loss_values = [val.mean() for val in losses_generator.values()]
                loss = sum(loss_values)
                loss.backward()
                optimizer_generator.step()
                optimizer_generator.zero_grad()
                optimizer_mask_generator.step()
                optimizer_mask_generator.zero_grad()
                losses = {
                    key: value.mean().detach().data.cpu().numpy()
                    for key, value in losses_generator.items()
                }
                logger.log_iter(losses=losses)
            scheduler_generator.step()
            scheduler_mask_generator.step()
            logger.log_epoch(
                epoch, {
                    'generator': generator,
                    'mask_generator': mask_generator,
                    'optimizer_generator': optimizer_generator,
                    'optimizer_mask_generator': optimizer_mask_generator
                },
                inp=x,
                out=generated,
                save_w=True)
print(opt) torch.cuda.set_device(opt.device_id[0]) # ######################## Module ################################# print('Building model') model = actionModel(opt.class_num, batch_norm=True, dropout=opt.dropout, q=opt.q, image_size=opt.img_size, syn_bn=opt.syn_bn, test_scheme=2) print(model) if opt.syn_bn: model = DataParallelWithCallback(model, device_ids=opt.device_id).cuda() else: model = torch.nn.DataParallel(model, device_ids=opt.device_id).cuda() print("Channels: " + str(model.module.channels)) # ########################Optimizer######################### optimizer = torch.optim.SGD([{ 'params': model.module.RNN.parameters(), 'lr': opt.LR[0] }, { 'params': model.module.ShortCut.parameters(), 'lr': opt.LR[0] }, { 'params': model.module.classifier.parameters(), 'lr': opt.LR[1] }],
def train(config, generator, discriminator, kp_detector, checkpoint, log_dir,
          dataset, device_ids):
    """Adversarial training of generator + keypoint detector vs discriminator.

    Args:
        config: parsed YAML config; reads 'train_params' and 'visualizer_params'.
        generator: generator network.
        discriminator: discriminator network.
        kp_detector: keypoint detector network.
        checkpoint: checkpoint path to resume from, or None.
        log_dir: Logger output directory.
        dataset: training dataset.
        device_ids: GPU ids for DataParallelWithCallback.
    """
    train_params = config['train_params']

    # A single shared learning rate is used for all three sub-networks.
    optimizer_generator = torch.optim.Adam(generator.parameters(),
                                           lr=train_params['lr'],
                                           betas=(0.5, 0.999))
    optimizer_discriminator = torch.optim.Adam(discriminator.parameters(),
                                               lr=train_params['lr'],
                                               betas=(0.5, 0.999))
    optimizer_kp_detector = torch.optim.Adam(kp_detector.parameters(),
                                             lr=train_params['lr'],
                                             betas=(0.5, 0.999))

    if checkpoint is not None:
        start_epoch, it = Logger.load_cpk(checkpoint, generator, discriminator,
                                          kp_detector, optimizer_generator,
                                          optimizer_discriminator,
                                          optimizer_kp_detector)
    else:
        start_epoch = 0
        it = 0

    # BUGFIX: the discriminator and kp-detector schedulers were mistakenly
    # constructed over optimizer_generator, so the learning rates of the
    # discriminator and keypoint detector were never decayed (and the
    # generator's lr was decayed three times per milestone).
    scheduler_generator = MultiStepLR(optimizer_generator,
                                      train_params['epoch_milestones'],
                                      gamma=0.1,
                                      last_epoch=start_epoch - 1)
    scheduler_discriminator = MultiStepLR(optimizer_discriminator,
                                          train_params['epoch_milestones'],
                                          gamma=0.1,
                                          last_epoch=start_epoch - 1)
    scheduler_kp_detector = MultiStepLR(optimizer_kp_detector,
                                        train_params['epoch_milestones'],
                                        gamma=0.1,
                                        last_epoch=start_epoch - 1)

    dataloader = DataLoader(dataset,
                            batch_size=train_params['batch_size'],
                            shuffle=True,
                            num_workers=4,
                            drop_last=True)

    generator_full = GeneratorFullModel(kp_detector, generator, discriminator,
                                        train_params)
    discriminator_full = DiscriminatorFullModel(kp_detector, generator,
                                                discriminator, train_params)
    generator_full_par = DataParallelWithCallback(generator_full,
                                                  device_ids=device_ids)
    discriminator_full_par = DataParallelWithCallback(discriminator_full,
                                                      device_ids=device_ids)

    with Logger(log_dir=log_dir,
                visualizer_params=config['visualizer_params'],
                **train_params['log_params']) as logger:
        for epoch in trange(start_epoch, train_params['num_epochs']):
            for x in dataloader:
                # Generator pass: the parallel wrapper returns loss terms
                # followed by the generated output and the joined keypoints.
                out = generator_full_par(x)
                loss_values = out[:-2]
                generated = out[-2]
                kp_joined = out[-1]
                loss_values = [val.mean() for val in loss_values]
                loss = sum(loss_values)
                # Keep the graph alive when the discriminator pass must also
                # backprop through the keypoint detector.
                loss.backward(
                    retain_graph=not train_params['detach_kp_discriminator'])
                optimizer_generator.step()
                optimizer_generator.zero_grad()
                optimizer_discriminator.zero_grad()
                if train_params['detach_kp_discriminator']:
                    # kp detector updated on generator gradients only.
                    optimizer_kp_detector.step()
                    optimizer_kp_detector.zero_grad()
                generator_loss_values = [
                    val.detach().cpu().numpy() for val in loss_values
                ]

                # Discriminator pass.
                loss_values = discriminator_full_par(x, kp_joined, generated)
                loss_values = [val.mean() for val in loss_values]
                loss = sum(loss_values)
                loss.backward()
                optimizer_discriminator.step()
                optimizer_discriminator.zero_grad()
                if not train_params['detach_kp_discriminator']:
                    # kp detector also receives discriminator gradients.
                    optimizer_kp_detector.step()
                    optimizer_kp_detector.zero_grad()
                discriminator_loss_values = [
                    val.detach().cpu().numpy() for val in loss_values
                ]

                logger.log_iter(
                    it,
                    names=generator_loss_names(train_params['loss_weights']) +
                    discriminator_loss_names(),
                    values=generator_loss_values + discriminator_loss_values,
                    inp=x,
                    out=generated)
                it += 1

            scheduler_generator.step()
            scheduler_discriminator.step()
            scheduler_kp_detector.step()
            logger.log_epoch(
                epoch, {
                    'generator': generator,
                    'discriminator': discriminator,
                    'kp_detector': kp_detector,
                    'optimizer_generator': optimizer_generator,
                    'optimizer_discriminator': optimizer_discriminator,
                    'optimizer_kp_detector': optimizer_kp_detector
                })
def train(config, generator, discriminator, kp_detector, checkpoint, log_dir,
          dataset, device_ids):
    """Train generator + keypoint detector, optionally adversarially.

    Args:
        config: parsed YAML config; 'train_params' holds epoch counts, lrs,
            milestones and loss weights (see the *.yaml files).
        generator: generator network.
        discriminator: discriminator network (used only when the GAN loss
            weight is non-zero).
        kp_detector: keypoint detector network.
        checkpoint: checkpoint path to resume from, or None.
        log_dir: Logger output directory.
        dataset: training dataset (optionally repeated via 'num_repeats').
        device_ids: GPU ids for DataParallelWithCallback.
    """
    train_params = config['train_params']

    # One Adam optimizer per sub-network; see the Adam docs for (0.5, 0.999).
    optimizer_generator = torch.optim.Adam(generator.parameters(),
                                           lr=train_params['lr_generator'],
                                           betas=(0.5, 0.999))
    optimizer_discriminator = torch.optim.Adam(
        discriminator.parameters(),
        lr=train_params['lr_discriminator'],
        betas=(0.5, 0.999))
    optimizer_kp_detector = torch.optim.Adam(
        kp_detector.parameters(),
        lr=train_params['lr_kp_detector'],
        betas=(0.5, 0.999))

    if checkpoint is not None:
        # The models passed in are freshly initialized; their weights are
        # loaded here. A frozen kp detector (lr == 0) keeps no optimizer state.
        start_epoch = Logger.load_cpk(
            checkpoint, generator, discriminator, kp_detector,
            optimizer_generator, optimizer_discriminator,
            None if train_params['lr_kp_detector'] == 0 else
            optimizer_kp_detector)
    else:
        start_epoch = 0

    # Step-wise lr decay (gamma=0.1) at the configured epoch milestones.
    scheduler_generator = MultiStepLR(optimizer_generator,
                                      train_params['epoch_milestones'],
                                      gamma=0.1,
                                      last_epoch=start_epoch - 1)
    scheduler_discriminator = MultiStepLR(optimizer_discriminator,
                                          train_params['epoch_milestones'],
                                          gamma=0.1,
                                          last_epoch=start_epoch - 1)
    # A frozen kp detector's scheduler starts from scratch (last_epoch = -1).
    scheduler_kp_detector = MultiStepLR(
        optimizer_kp_detector,
        train_params['epoch_milestones'],
        gamma=0.1,
        last_epoch=-1 + start_epoch * (train_params['lr_kp_detector'] != 0))

    # BUGFIX: the original condition
    #   `'num_repeats' in train_params or train_params['num_repeats'] != 1`
    # raised a KeyError when the key was absent and wrapped the dataset
    # needlessly when num_repeats == 1.
    if train_params.get('num_repeats', 1) != 1:
        dataset = DatasetRepeater(dataset, train_params['num_repeats'])

    # The dataset is a torch Dataset subclass, hence usable with DataLoader.
    dataloader = DataLoader(dataset,
                            batch_size=train_params['batch_size'],
                            shuffle=True,
                            num_workers=2,
                            drop_last=True)

    # Full training wrappers: the generator wrapper computes all generator
    # losses (it also consults the discriminator); the discriminator wrapper
    # computes the discriminator-side losses.
    generator_full = GeneratorFullModel(kp_detector, generator, discriminator,
                                        train_params)
    discriminator_full = DiscriminatorFullModel(kp_detector, generator,
                                                discriminator, train_params)

    # Move models to GPU when available.
    if torch.cuda.is_available():
        generator_full = DataParallelWithCallback(generator_full,
                                                  device_ids=device_ids)
        discriminator_full = DataParallelWithCallback(discriminator_full,
                                                      device_ids=device_ids)

    with Logger(log_dir=log_dir,
                visualizer_params=config['visualizer_params'],
                checkpoint_freq=train_params['checkpoint_freq']) as logger:
        for epoch in trange(start_epoch, train_params['num_epochs']):
            for x in dataloader:
                # Forward pass: a dict of loss terms and the generated images.
                losses_generator, generated = generator_full(x)
                # Sum the batch-mean of every individual loss term.
                loss_values = [val.mean() for val in losses_generator.values()]
                loss = sum(loss_values)
                # Update generator and kp detector on the generator losses.
                loss.backward()
                optimizer_generator.step()
                optimizer_generator.zero_grad()
                optimizer_kp_detector.step()
                optimizer_kp_detector.zero_grad()

                # Adversarial phase, only when the GAN loss is enabled.
                if train_params['loss_weights']['generator_gan'] != 0:
                    optimizer_discriminator.zero_grad()
                    # Discriminate generated vs. real data.
                    losses_discriminator = discriminator_full(x, generated)
                    loss_values = [
                        val.mean() for val in losses_discriminator.values()
                    ]
                    loss = sum(loss_values)
                    # Update the discriminator.
                    loss.backward()
                    optimizer_discriminator.step()
                    optimizer_discriminator.zero_grad()
                else:
                    losses_discriminator = {}

                # Merge the two loss dicts (plain dict.update) for logging.
                losses_generator.update(losses_discriminator)
                losses = {
                    key: value.mean().detach().data.cpu().numpy()
                    for key, value in losses_generator.items()
                }
                logger.log_iter(losses=losses)

            # End of epoch: advance the lr schedulers and checkpoint.
            scheduler_generator.step()
            scheduler_discriminator.step()
            scheduler_kp_detector.step()
            logger.log_epoch(
                epoch, {
                    'generator': generator,
                    'discriminator': discriminator,
                    'kp_detector': kp_detector,
                    'optimizer_generator': optimizer_generator,
                    'optimizer_discriminator': optimizer_discriminator,
                    'optimizer_kp_detector': optimizer_kp_detector
                },
                inp=x,
                out=generated)
def prediction(config, generator, kp_detector, checkpoint, log_dir):
    """Train an RNN keypoint predictor on extracted keypoints, then render
    predicted videos with the (frozen) generator.

    Three phases: (1) extract keypoints for the training split with
    kp_detector; (2) fit a PredictionModule to forecast keypoints after the
    first `init_frames` frames; (3) on the test split, predict keypoints and
    synthesize frames, saving a .png strip and an animation per video.

    Args:
        config: parsed YAML config; reads 'dataset_params', 'model_params',
            'prediction_params' and 'visualizer_params'.
        generator / kp_detector: pretrained networks (weights loaded below).
        checkpoint: required checkpoint path; raises if None.
        log_dir: base output directory ('prediction' subdir is created).
    """
    dataset = FramesDataset(is_train=True, transform=VideoToTensor(),
                            **config['dataset_params'])
    log_dir = os.path.join(log_dir, 'prediction')
    png_dir = os.path.join(log_dir, 'png')

    if checkpoint is not None:
        Logger.load_cpk(checkpoint, generator=generator,
                        kp_detector=kp_detector)
    else:
        raise AttributeError(
            "Checkpoint should be specified for mode='prediction'.")

    dataloader = DataLoader(dataset, batch_size=1, shuffle=False,
                            num_workers=1)
    generator = DataParallelWithCallback(generator)
    kp_detector = DataParallelWithCallback(kp_detector)

    if not os.path.exists(log_dir):
        os.makedirs(log_dir)
    if not os.path.exists(png_dir):
        os.makedirs(png_dir)

    # ---- Phase 1: extract per-frame keypoints for the training videos. ----
    print("Extracting keypoints...")
    kp_detector.eval()
    generator.eval()
    keypoints_array = []
    prediction_params = config['prediction_params']
    for it, x in tqdm(enumerate(dataloader)):
        if prediction_params['train_size'] is not None:
            # NOTE(review): `>` means train_size + 1 videos are processed;
            # presumably intentional cap — confirm against upstream.
            if it > prediction_params['train_size']:
                break
        with torch.no_grad():
            keypoints = []
            # One detector call per frame; keep keypoints as numpy on CPU.
            for i in range(x['video'].shape[2]):
                kp = kp_detector(x['video'][:, :, i:(i + 1)])
                kp = {k: v.data.cpu().numpy() for k, v in kp.items()}
                keypoints.append(kp)
            keypoints_array.append(keypoints)

    # ---- Phase 2: fit the RNN keypoint predictor. ----
    predictor = PredictionModule(
        num_kp=config['model_params']['common_params']['num_kp'],
        kp_variance=config['model_params']['common_params']['kp_variance'],
        **prediction_params['rnn_params']).cuda()
    num_epochs = prediction_params['num_epochs']
    lr = prediction_params['lr']
    bs = prediction_params['batch_size']
    num_frames = prediction_params['num_frames']
    init_frames = prediction_params['init_frames']
    optimizer = torch.optim.Adam(predictor.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                           verbose=True,
                                                           patience=50)
    kp_dataset = KPDataset(keypoints_array, num_frames=num_frames)
    kp_dataloader = DataLoader(kp_dataset, batch_size=bs)

    print("Training prediction...")
    for _ in trange(num_epochs):
        loss_list = []
        for x in kp_dataloader:
            x = {k: v.cuda() for k, v in x.items()}
            # Ground truth is the untouched sequence; the input has every
            # frame after init_frames zeroed so the RNN must predict them.
            gt = {k: v.clone() for k, v in x.items()}
            for k in x:
                x[k][:, init_frames:] = 0
            prediction = predictor(x)
            # L1 loss only over the frames the predictor had to fill in.
            loss = sum([
                torch.abs(gt[k][:, init_frames:] -
                          prediction[k][:, init_frames:]).mean() for k in x
            ])
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            loss_list.append(loss.detach().data.cpu().numpy())
        loss = np.mean(loss_list)
        scheduler.step(loss)

    # ---- Phase 3: predict keypoints on the test split and render. ----
    dataset = FramesDataset(is_train=False, transform=VideoToTensor(),
                            **config['dataset_params'])
    dataloader = DataLoader(dataset, batch_size=1, shuffle=False,
                            num_workers=1)
    print("Make predictions...")
    for it, x in tqdm(enumerate(dataloader)):
        with torch.no_grad():
            x['video'] = x['video'][:, :, :num_frames]
            kp_init = kp_detector(x['video'])
            # Zero out everything past the seed frames, as during training.
            for k in kp_init:
                kp_init[k][:, init_frames:] = 0
            kp_source = kp_detector(x['video'][:, :, :1])
            kp_video = predictor(kp_init)
            # Keep the detector's keypoints for the seed frames.
            for k in kp_video:
                kp_video[k][:, :init_frames] = kp_init[k][:, :init_frames]
            if 'var' in kp_video and prediction_params['predict_variance']:
                # Repeat the last observed variance across predicted frames.
                kp_video['var'] = kp_init['var'][:, (init_frames -
                                                     1):init_frames].repeat(
                                                         1,
                                                         kp_video['var'].
                                                         shape[1], 1, 1, 1)
            out = generate(generator,
                           appearance_image=x['video'][:, :, :1],
                           kp_appearance=kp_source,
                           kp_video=kp_video)
            x['source'] = x['video'][:, :, :1]
            # Save a horizontal strip of all predicted frames as one .png.
            out_video_batch = out['video_prediction'].data.cpu().numpy()
            out_video_batch = np.concatenate(np.transpose(
                out_video_batch, [0, 2, 3, 4, 1])[0], axis=1)
            imageio.imsave(os.path.join(png_dir, x['name'][0] + '.png'),
                           (255 * out_video_batch).astype(np.uint8))
            image = Visualizer(
                **config['visualizer_params']).visualize_reconstruction(x, out)
            image_name = x['name'][0] + prediction_params['format']
            imageio.mimsave(os.path.join(log_dir, image_name), image)
            # Free GPU memory before the next video.
            del x, kp_video, kp_source, out
def animate(config, generator, kp_detector, checkpoint, log_dir, dataset):
    """Animate source images with driving videos and save the results.

    For every (driving, source) pair produced by PairedDataset, keypoints are
    detected per driving frame, normalized relative to the first driving
    frame, and fed to the generator. Saves a horizontal .png strip of all
    predicted frames plus an animation per pair.

    Args:
        config: parsed YAML config; reads 'animate_params' and
            'visualizer_params'.
        generator / kp_detector: pretrained networks (weights loaded below).
        checkpoint: required checkpoint path; raises AttributeError if None.
        log_dir: base output directory ('animation' subdir is created).
        dataset: dataset to draw source/driving pairs from.
    """
    log_dir = os.path.join(log_dir, 'animation')
    png_dir = os.path.join(log_dir, 'png')
    animate_params = config['animate_params']

    dataset = PairedDataset(initial_dataset=dataset,
                            number_of_pairs=animate_params['num_pairs'])
    dataloader = DataLoader(dataset, batch_size=1, shuffle=False,
                            num_workers=1)

    if checkpoint is not None:
        Logger.load_cpk(checkpoint, generator=generator,
                        kp_detector=kp_detector)
    else:
        raise AttributeError(
            "Checkpoint should be specified for mode='animate'.")

    if not os.path.exists(log_dir):
        os.makedirs(log_dir)
    if not os.path.exists(png_dir):
        os.makedirs(png_dir)

    if torch.cuda.is_available():
        generator = DataParallelWithCallback(generator)
        kp_detector = DataParallelWithCallback(kp_detector)
    generator.eval()
    kp_detector.eval()

    for it, x in tqdm(enumerate(dataloader)):
        with torch.no_grad():
            predictions = []
            visualizations = []
            driving_video = x['driving_video']
            source_frame = x['source_video'][:, :, 0, :, :]

            kp_source = kp_detector(source_frame)
            # Keypoints of the first driving frame anchor the normalization.
            kp_driving_initial = kp_detector(driving_video[:, :, 0])

            for frame_idx in range(driving_video.shape[2]):
                driving_frame = driving_video[:, :, frame_idx]
                kp_driving = kp_detector(driving_frame)
                kp_norm = normalize_kp(
                    kp_source=kp_source,
                    kp_driving=kp_driving,
                    kp_driving_initial=kp_driving_initial,
                    **animate_params['normalization_params'])
                out = generator(source_frame,
                                kp_source=kp_source,
                                kp_driving=kp_norm)
                # Attach keypoints for visualization; drop the heavy
                # intermediate tensor before accumulating.
                out['kp_driving'] = kp_driving
                out['kp_source'] = kp_source
                out['kp_norm'] = kp_norm
                del out['sparse_deformed']

                predictions.append(
                    np.transpose(out['prediction'].data.cpu().numpy(),
                                 [0, 2, 3, 1])[0])
                # BUGFIX: removed the dead no-op `visualization = visualization`
                # self-assignment that followed this call in the original.
                visualization = Visualizer(
                    **config['visualizer_params']).visualize(
                        source=source_frame, driving=driving_frame, out=out)
                visualizations.append(visualization)

            # One horizontal strip of all predicted frames per pair.
            predictions = np.concatenate(predictions, axis=1)
            result_name = "-".join(
                [x['driving_name'][0], x['source_name'][0]])
            imageio.imsave(os.path.join(png_dir, result_name + '.png'),
                           (255 * predictions).astype(np.uint8))
            image_name = result_name + animate_params['format']
            imageio.mimsave(os.path.join(log_dir, image_name), visualizations)
# NOTE(review): this chunk is a flattened fragment of a hyper-parameter grid
# search script, truncated at BOTH ends — it starts inside a call
# (`batch_size=max(...)`) and ends inside a dict literal in an
# `Adam([...])` parameter-group list. It also contains inline `#` seed
# comments that, once flattened onto one line, comment out the remainder of
# the line. Kept byte-identical; recover the original multi-line layout from
# the upstream source before editing.
batch_size=max(param_grid['batch_size'])) criterion = nn.MSELoss() best_valid_RMSE = np.full(1, np.inf) for grid in ParameterGrid(param_grid): print(f"===> Hyper-parameters = {grid}:") # random.seed(seed) # numpy.random.seed(seed) torch.manual_seed(seed) torch.cuda.manual_seed_all(seed) importlib.reload(model) net = getattr(model, model_name.upper())(input_channel=input_channel, input_size=input_size, output_size=output_size).to(device) if use_cuda: net = DataParallelWithCallback(net) print(f"===> Model:\n{list(net.modules())[0]}") print_param(net) if model_name == 'cnn0' or model_name == 'cnn1': optimizer = Adam(net.parameters(), lr=grid['lr'], l1=grid['l1'], weight_decay=grid['l2'], amsgrad=True) elif model_name == 'vgg19' or model_name == 'cnn2' or model_name == 'cnn3': optimizer = Adam([{ 'params': iter(param for name, param in net.named_parameters() if 'channel_mask' in name), 'l1': grid['l1_channel']
def train_kpdetector(model_kp_detector, loader, loader_tgt, train_params,
                     checkpoint, logger, device_ids, tgt_batch=None,
                     kp_map=None):
    """Train (or evaluate) a keypoint detector against gaussian GT heatmaps.

    In train mode, iterates `loader`, regresses heatmaps toward
    kp2gaussian2-rendered ground truth, and periodically logs evaluation
    images/scalars from `loader_tgt`. With train_params['test'] == True it
    only evaluates and returns.

    Args:
        model_kp_detector: keypoint detector module to optimize.
        loader: source-domain training DataLoader (batches with 'imgs',
            'annots', optionally 'kp_mask').
        loader_tgt: target-domain DataLoader used for periodic evaluation.
        train_params: dict with 'lr', 'betas', 'test', 'epoch_milestones',
            'num_epochs', 'heatmap_var', 'dataset' and 'log_params'.
        checkpoint: checkpoint path to resume/load from, or None.
        logger: project logger; its .epoch/.iterations fields are mutated here.
        device_ids: GPU ids for DataParallelWithCallback.
        tgt_batch: unused on entry; overwritten inside the eval branch.
        kp_map: unused here — presumably consumed by callers/other variants;
            TODO confirm.
    """
    log_params = train_params['log_params']
    optimizer_kp_detector = torch.optim.Adam(model_kp_detector.parameters(),
                                             lr=train_params['lr'],
                                             betas=train_params['betas'])
    resume_epoch = 0
    resume_iteration = 0
    if checkpoint is not None:
        print('Loading Checkpoint: %s' % checkpoint)
        # TODO: Implement Load/resumo kp_detector
        if train_params['test'] == False:
            # Resume training: restore model + optimizer and the position.
            resume_epoch, resume_iteration = logger.checkpoint.load_checkpoint(
                checkpoint,
                model_kp_detector=model_kp_detector,
                optimizer_kp_detector=optimizer_kp_detector)
        else:
            # Test mode: load only weights whose names AND shapes match the
            # current architecture, then apply convertLayer to the model.
            net_dict = model_kp_detector.state_dict()
            pretrained_dict = torch.load(checkpoint)
            pretrained_dict = {
                k: v for k, v in pretrained_dict.items() if (k in net_dict)
            }
            pretrained_dict = {
                k: v for k, v in pretrained_dict.items()
                if pretrained_dict[k].shape == net_dict[k].shape
            }
            net_dict.update(pretrained_dict)
            model_kp_detector.load_state_dict(net_dict, strict=True)
            model_kp_detector.apply(convertLayer)
        logger.epoch = resume_epoch
        logger.iterations = resume_iteration
    scheduler_kp_detector = MultiStepLR(optimizer_kp_detector,
                                        train_params['epoch_milestones'],
                                        gamma=0.1,
                                        last_epoch=logger.epoch - 1)
    kp_detector = KPDetectorTrainer(model_kp_detector)
    kp_detector = DataParallelWithCallback(kp_detector, device_ids=device_ids)
    # k cycles through which sample of a batch gets visualized.
    k = 0
    if train_params['test'] == True:
        results = evaluate(model_kp_detector, loader_tgt,
                           dset=train_params['dataset'])
        print(' MSE: ' + str(results['MSE']) + ' PCK: ' + str(results['PCK']))
        return
    heatmap_var = train_params['heatmap_var']
    for epoch in range(logger.epoch, train_params['num_epochs']):
        # Evaluate on both domains at the start of every epoch.
        results = evaluate(model_kp_detector, loader_tgt,
                           dset=train_params['dataset'])
        results_train = evaluate(model_kp_detector, loader,
                                 dset=train_params['dataset'])
        print('Epoch ' + str(epoch) + ' MSE: ' + str(results['MSE']))
        logger.add_scalar('MSE test', results['MSE'], epoch)
        logger.add_scalar('PCK test', results['PCK'], epoch)
        logger.add_scalar('MSE train', results_train['MSE'], epoch)
        logger.add_scalar('PCK train', results_train['PCK'], epoch)
        for i, batch in enumerate(tqdm(loader)):
            images = batch['imgs']
            # NaN guard: x != x is true only for NaN entries.
            if (images != images).sum() > 0:
                print('Images has NaN')
                break
            annots = batch['annots']
            # Render ground-truth gaussian heatmaps from the annotations.
            gt_heatmaps = kp2gaussian2(annots,
                                       (model_kp_detector.heatmap_res,
                                        model_kp_detector.heatmap_res),
                                       heatmap_var).detach()
            if (annots != annots).sum() > 0 or (
                    annots.abs() == float("Inf")).sum() > 0:
                print('Annotation with NaN')
                break
            mask = None if 'kp_mask' not in batch.keys() else batch['kp_mask']
            kp_detector_out = kp_detector(images, gt_heatmaps, mask)
            loss = kp_detector_out['l2_loss'].mean()
            loss.backward()
            optimizer_kp_detector.step()
            optimizer_kp_detector.zero_grad()
            ####### LOG VALIDATION
            if i % log_params['eval_frequency'] == 0:
                tgt_batch = next(iter(loader_tgt))
                eval_out = eval_model(kp_detector, tgt_batch,
                                      model_kp_detector.heatmap_res,
                                      heatmap_var)
                eval_sz = int(len(loader) / log_params['eval_frequency'])
                it_number = epoch * eval_sz + (
                    logger.iterations / log_params['eval_frequency'])
                logger.add_scalar('Eval loss', eval_out['l2_loss'].mean(),
                                  it_number)
                # Side-by-side image: GT keypoints vs predicted (red).
                concat_img = np.concatenate(
                    (draw_kp(tensor_to_image(tgt_batch['imgs'][k]),
                             unnorm_kp(tgt_batch['annots'][k])),
                     draw_kp(tensor_to_image(tgt_batch['imgs'][k]),
                             eval_out['keypoints'][k], color='red')),
                    axis=2)
                # Heatmap channels 0 and 5 are sampled for visualization.
                heatmap_img_0 = tensor_to_image(
                    kp_detector_out['heatmaps'][k, 0].unsqueeze(0), True)
                heatmap_img_1 = tensor_to_image(
                    kp_detector_out['heatmaps'][k, 5].unsqueeze(0), True)
                src_heatmap_0 = tensor_to_image(
                    gt_heatmaps[k, 0].unsqueeze(0), True)
                src_heatmap_1 = tensor_to_image(
                    gt_heatmaps[k, 5].unsqueeze(0), True)
                heatmaps_img = np.concatenate((heatmap_img_0, heatmap_img_1),
                                              axis=2)
                src_heatmaps = np.concatenate((src_heatmap_0, src_heatmap_1),
                                              axis=2)
                logger.add_image('Eval_', concat_img, logger.iterations)
                logger.add_image('heatmaps', heatmaps_img, logger.iterations)
                logger.add_image('src heatmaps', src_heatmaps,
                                 logger.iterations)
            ####### LOG
            logger.add_scalar('L2 loss', loss.item(), logger.iterations)
            if i in log_params['log_imgs']:
                concat_img_train = np.concatenate(
                    (draw_kp(tensor_to_image(images[k]), unnorm_kp(annots[k])),
                     draw_kp(tensor_to_image(images[k]),
                             kp_detector_out['keypoints'][k], color='red')),
                    axis=2)
                logger.add_image('Train_{%d}' % i, concat_img_train,
                                 logger.iterations)
                # Advance to the next sample index, wrapping around.
                k += 1
                k = k % len(log_params['log_imgs'])
            logger.step_it()
        scheduler_kp_detector.step()
        logger.step_epoch(models={
            'model_kp_detector': model_kp_detector,
            'optimizer_kp_detector': optimizer_kp_detector
        })
def train(config, generator, region_predictor, bg_predictor, checkpoint,
          log_dir, dataset, device_ids):
    """Train generator, region predictor and background predictor jointly.

    Args:
        config: parsed YAML config; reads 'train_params' and
            'visualizer_params'.
        generator: generator network.
        region_predictor: region (motion) predictor network.
        bg_predictor: background motion predictor network.
        checkpoint: checkpoint path to resume from, or None.
        log_dir: Logger output directory.
        dataset: training dataset (optionally repeated via 'num_repeats').
        device_ids: GPU ids for the data-parallel wrapper.
    """
    train_params = config['train_params']

    # A single Adam optimizer over all three modules.
    optimizer = torch.optim.Adam(
        list(generator.parameters()) + list(region_predictor.parameters()) +
        list(bg_predictor.parameters()),
        lr=train_params['lr'],
        betas=(0.5, 0.999))

    if checkpoint is not None:
        start_epoch = Logger.load_cpk(checkpoint, generator, region_predictor,
                                      bg_predictor, None, optimizer, None)
    else:
        start_epoch = 0

    scheduler = MultiStepLR(optimizer,
                            train_params['epoch_milestones'],
                            gamma=0.1,
                            last_epoch=start_epoch - 1)

    # BUGFIX: the original condition
    #   `'num_repeats' in train_params or train_params['num_repeats'] != 1`
    # raised a KeyError when 'num_repeats' was absent and wrapped the dataset
    # needlessly when num_repeats == 1.
    if train_params.get('num_repeats', 1) != 1:
        dataset = DatasetRepeater(dataset, train_params['num_repeats'])

    dataloader = DataLoader(dataset,
                            batch_size=train_params['batch_size'],
                            shuffle=True,
                            num_workers=train_params['dataloader_workers'],
                            drop_last=True)

    model = ReconstructionModel(region_predictor, bg_predictor, generator,
                                train_params)
    if torch.cuda.is_available():
        # Sync-BN wrapper only when explicitly requested by the config.
        if ('use_sync_bn' in train_params) and train_params['use_sync_bn']:
            model = DataParallelWithCallback(model, device_ids=device_ids)
        else:
            model = torch.nn.DataParallel(model, device_ids=device_ids)

    with Logger(log_dir=log_dir,
                visualizer_params=config['visualizer_params'],
                checkpoint_freq=train_params['checkpoint_freq']) as logger:
        for epoch in trange(start_epoch, train_params['num_epochs']):
            for x in dataloader:
                losses, generated = model(x)
                # Sum the batch-mean of every individual loss term.
                loss_values = [val.mean() for val in losses.values()]
                loss = sum(loss_values)
                loss.backward()
                optimizer.step()
                optimizer.zero_grad()

                losses = {
                    key: value.mean().detach().data.cpu().numpy()
                    for key, value in losses.items()
                }
                logger.log_iter(losses=losses)

            scheduler.step()
            logger.log_epoch(epoch, {
                'generator': generator,
                'bg_predictor': bg_predictor,
                'region_predictor': region_predictor,
                'optimizer_reconstruction': optimizer
            },
                             inp=x,
                             out=generated)
def reconstruction(config, generator, mask_generator, checkpoint, log_dir,
                   dataset):
    """Reconstruct every video frame-by-frame and report the mean L1 error.

    Each video's first frame acts as the source; every frame is re-generated
    from it. Per-frame .png files, a per-video animation and the running
    reconstruction loss are written to disk.

    Args:
        config: parsed YAML config; reads 'reconstruction_params' and
            'visualizer_params'.
        generator / mask_generator: pretrained networks (weights loaded
            below from `checkpoint`).
        checkpoint: required checkpoint path; raises AttributeError if None.
        log_dir: base output directory ('reconstruction' subdir is created).
        dataset: dataset of videos to reconstruct.
    """
    png_dir = os.path.join(log_dir, 'reconstruction/png')
    log_dir = os.path.join(log_dir, 'reconstruction')

    if checkpoint is not None:
        epoch = Logger.load_cpk(checkpoint, generator=generator,
                                mask_generator=mask_generator)
        print('checkpoint:' + str(epoch))
    else:
        raise AttributeError(
            "Checkpoint should be specified for mode='reconstruction'.")

    dataloader = DataLoader(dataset, batch_size=1, shuffle=False,
                            num_workers=1)

    if not os.path.exists(log_dir):
        os.makedirs(log_dir)
    if not os.path.exists(png_dir):
        os.makedirs(png_dir)

    loss_list = []
    if torch.cuda.is_available():
        generator = DataParallelWithCallback(generator)
        mask_generator = DataParallelWithCallback(mask_generator)
    generator.eval()
    mask_generator.eval()

    # NOTE(review): hard-coded output path, and exist_ok=False means a second
    # run aborts with FileExistsError — possibly intentional overwrite
    # protection; confirm before changing.
    recon_gen_dir = './log/recon_gen'
    os.makedirs(recon_gen_dir, exist_ok=False)

    for it, x in tqdm(enumerate(dataloader)):
        if config['reconstruction_params']['num_videos'] is not None:
            # NOTE(review): `>` processes num_videos + 1 items — confirm the
            # off-by-one is intended.
            if it > config['reconstruction_params']['num_videos']:
                break
        with torch.no_grad():
            predictions = []
            visualizations = []
            if torch.cuda.is_available():
                x['video'] = x['video'].cuda()
            # Mask of the first frame; reused as the source for every frame.
            mask_source = mask_generator(x['video'][:, :, 0])
            video_gen_dir = recon_gen_dir + '/' + x['name'][0]
            os.makedirs(video_gen_dir, exist_ok=False)
            for frame_idx in range(x['video'].shape[2]):
                source = x['video'][:, :, 0]
                driving = x['video'][:, :, frame_idx]
                mask_driving = mask_generator(driving)
                out = generator(source,
                                driving,
                                mask_source=mask_source,
                                mask_driving=mask_driving,
                                mask_driving2=None,
                                animate=False,
                                predict_mask=False)
                out['mask_source'] = mask_source
                out['mask_driving'] = mask_driving
                predictions.append(
                    np.transpose(
                        out['second_phase_prediction'].data.cpu().numpy(),
                        [0, 2, 3, 1])[0])
                visualization = Visualizer(
                    **config['visualizer_params']).visualize(source=source,
                                                             driving=driving,
                                                             target=None,
                                                             out=out,
                                                             driving2=None)
                visualizations.append(visualization)
                # Per-frame L1 reconstruction error against the real frame.
                loss_list.append(
                    torch.abs(out['second_phase_prediction'] -
                              driving).mean().cpu().numpy())
                # Save the reconstructed frame as a zero-padded .png.
                frame_name = str(frame_idx).zfill(7) + '.png'
                second_phase_prediction = out[
                    'second_phase_prediction'].data.cpu().numpy()
                second_phase_prediction = np.transpose(
                    second_phase_prediction, [0, 2, 3, 1])
                second_phase_prediction = (
                    255 * second_phase_prediction).astype(np.uint8)
                imageio.imsave(os.path.join(video_gen_dir, frame_name),
                               second_phase_prediction[0])
            predictions = np.concatenate(predictions, axis=1)
            image_name = x['name'][0] + config['reconstruction_params'][
                'format']
            imageio.mimsave(os.path.join(log_dir, image_name), visualizations)
    print("Reconstruction loss: %s" % np.mean(loss_list))