def forward(self, images, filter=None, gt=None):
    pred_tgt = self.generator(images)
    # Optionally keep only the keypoints/heatmaps selected by `filter`.
    kps = pred_tgt['value'] if filter is None else pred_tgt['value'][:, filter]
    heatmaps = pred_tgt['heatmaps'] if filter is None else pred_tgt['heatmaps'][:, filter]
    generated_skeleton = self.kp_to_skl(kps).unsqueeze(1)
    generated_image = generated_skeleton
    masked_images = images
    if self.do_inpaint:
        generated_image, masked_images = self.inpaint(images, generated_skeleton)
    elif self.do_recover:
        # Re-render Gaussian heatmaps from the predicted (or, if given, ground-truth) keypoints.
        heatmaps_ = kp2gaussian2(gaussian2kp(heatmaps)['mean'], (122, 122), 0.5)
        if gt is not None:
            heatmaps_ = kp2gaussian2(gt, (122, 122), 0.5)
        recover_out = self.recover_transform(heatmaps_, images)
        generated_image = recover_out['reconstruction']
        masked_images = recover_out['transformed_input']
    fake_predictions = self.discriminator(heatmaps)
    # Average the adversarial loss over the discriminator scales.
    gen_loss = []
    for map_generated in fake_predictions:
        gen_loss.append(generator_gan_loss(
            discriminator_maps_generated=map_generated,
            weight=self.train_params['loss_weights']['generator_gan']).mean())
    return {
        "kps": kps,
        "heatmaps": heatmaps,
        "generator_loss": sum(gen_loss) / len(gen_loss),
        "inpaint_loss": 0 if not self.do_inpaint else self.perceptual_loss(images, generated_image),
        "recover_loss": 0 if not self.do_recover else recover_out['perceptual_loss'],
        "recover_image": None if not self.do_recover else recover_out['reconstruction'],
        "image": generated_image,
        "masked_images": masked_images,
        "hint": None if not self.do_recover else recover_out['hint'],
    }
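# A minimal sketch of what `kp2gaussian2` is assumed to do throughout this file:
# render (B, K, 2) keypoint coordinates in [-1, 1] into (B, K, H, W) Gaussian
# heatmaps. The real helper lives elsewhere in the repo; the grid convention and
# exponential falloff below are assumptions, not the repo's implementation.
# Relies on the module-level `torch` import already used above.
def kp2gaussian2_sketch(kp, spatial_size, kp_variance):
    h, w = spatial_size
    ys = torch.linspace(-1., 1., h, device=kp.device)
    xs = torch.linspace(-1., 1., w, device=kp.device)
    # (H, W, 2) grid of (x, y) coordinates, matching keypoints stored as (x, y).
    grid = torch.stack(torch.meshgrid(ys, xs, indexing='ij'), dim=-1).flip(-1)
    kp = kp.view(kp.shape[0], kp.shape[1], 1, 1, 2)
    squared_dist = ((grid - kp) ** 2).sum(-1)  # (B, K, H, W)
    return torch.exp(-0.5 * squared_dist / kp_variance)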
def eval_model(model, tgt_batch, heatmap_res=122, heatmap_var=0.5):
    # `heatmap_var` is accepted here because train_kpdetector passes it as a
    # fourth positional argument; the previous signature silently dropped it.
    model.eval()
    images = tgt_batch['imgs']
    annots = tgt_batch['annots']
    gt_heatmaps = kp2gaussian2(annots, (heatmap_res, heatmap_res), heatmap_var)
    mask = None if 'kp_mask' not in tgt_batch.keys() else tgt_batch['kp_mask']
    with torch.no_grad():
        out = model(images, gt_heatmaps, mask)
    return out
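# Illustrative, hedged usage of `eval_model`, mirroring how the training loop
# in `train_kpdetector` below consumes it; the detector, loader, and variance
# are assumed to exist as they do there.
def _eval_model_example(kp_detector, loader_tgt, heatmap_res, heatmap_var):
    tgt_batch = next(iter(loader_tgt))
    out = eval_model(kp_detector, tgt_batch, heatmap_res, heatmap_var)
    return out['l2_loss'].mean().item()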
def train_generator_geo(model_generator, discriminator, conditional_generator,
                        model_kp_to_skl, loader_src, loader_tgt, loader_test,
                        train_params, checkpoint, logger, device_ids, kp_map=None):
    log_params = train_params['log_params']
    optimizer_generator = torch.optim.Adam(model_generator.parameters(),
                                           lr=train_params['lr'],
                                           betas=train_params['betas'])
    optimizer_discriminator = torch.optim.Adam(discriminator.parameters(),
                                               lr=train_params['lr'],
                                               betas=train_params['betas'])
    if conditional_generator is not None:
        optimizer_conditional_generator = torch.optim.Adam(
            conditional_generator.parameters(),
            lr=train_params['lr'],
            betas=train_params['betas'])
    resume_epoch = 0
    resume_iteration = 0
    if checkpoint is not None:
        print('Loading Checkpoint: %s' % checkpoint)
        # TODO: implement load/resume of kp_detector
        resume_epoch, resume_iteration = logger.checkpoint.load_checkpoint(
            checkpoint,
            model_generator=model_generator,
            optimizer_generator=optimizer_generator,
            optimizer_discriminator=optimizer_discriminator,
            model_discriminator=discriminator)
        logger.epoch = resume_epoch
        logger.iterations = resume_iteration
    scheduler_generator = MultiStepLR(optimizer_generator,
                                      train_params['epoch_milestones'],
                                      gamma=0.1, last_epoch=logger.epoch - 1)
    scheduler_discriminator = MultiStepLR(optimizer_discriminator,
                                          train_params['epoch_milestones'],
                                          gamma=0.1, last_epoch=logger.epoch - 1)
    if conditional_generator is not None:
        scheduler_conditional_generator = MultiStepLR(
            optimizer_conditional_generator,
            train_params['epoch_milestones'],
            gamma=0.1, last_epoch=logger.epoch - 1)
    img_generator = GeneratorTrainer(model_generator, discriminator,
                                     model_kp_to_skl, train_params,
                                     conditional_generator,
                                     do_inpaint=train_params['do_inpaint'],
                                     do_recover=train_params['do_recover']).cuda()
    discriminatorModel = DiscriminatorTrainer(discriminator, train_params).cuda()
    k = 0
    iterator_source = iter(loader_src)
    source_model = copy.deepcopy(model_generator)
    for epoch in range(logger.epoch, train_params['num_epochs']):
        results = evaluate(model_generator, loader_test,
                           train_params['dataset'], kp_map)
        print('Epoch ' + str(epoch) + ' MSE: ' + str(results['MSE']))
        logger.add_scalar('MSE test', results['MSE'], epoch)
        logger.add_scalar('PCK test', results['PCK'], epoch)
        if epoch >= 9:
            save_qualy(model_generator, source_model, loader_tgt, epoch,
                       train_params['dataset'], kp_map)
        if epoch > 11:
            # Early exit: qualitative results are only collected up to epoch 11.
            return
        for i, tgt_batch in enumerate(tqdm(loader_tgt)):
            # Cycle the source loader independently of the target loader.
            try:
                src_batch = next(iterator_source)
            except StopIteration:
                iterator_source = iter(loader_src)
                src_batch = next(iterator_source)
            angle = random.randint(1, 359)
            src_annots = src_batch['annots'].cuda()
            # Render source keypoints (and their rotated copies) as heatmaps.
            src_images = kp2gaussian2(src_annots, (122, 122), 0.5)[:, kp_map]
            geo_src_images = kp2gaussian2(batch_kp_rotation(src_annots, angle),
                                          (122, 122), 0.5)[:, kp_map]
            tgt_images = tgt_batch['imgs'].cuda()
            tgt_gt = tgt_batch['annots'].cuda()
            tgt_gt_rot = batch_kp_rotation(tgt_gt, angle)
            geo_tgt_imgs = geo_transform(tgt_images, angle)
            # Predict on the original and the rotated target images.
            pred_tgt = img_generator(tgt_images, kp_map)
            pred_rot_tgt = img_generator(geo_tgt_imgs, kp_map)
            geo_loss = geo_consistency(pred_tgt['heatmaps'],
                                       pred_rot_tgt['heatmaps'],
                                       geo_transform, geo_transform_inverse,
                                       angle)
            geo_term = geo_loss['t'] + geo_loss['t_inv']
            generator_term = pred_tgt['generator_loss'] + pred_rot_tgt['generator_loss']
            geo_weight = train_params['loss_weights']['geometric']
            recover_loss = train_params['loss_weights']['recover'] * (
                pred_tgt['recover_loss'] + pred_rot_tgt['recover_loss'])
            inpaint_loss = train_params['loss_weights']['inpaint'] * pred_tgt['inpaint_loss']
            loss = geo_weight * geo_term + generator_term + inpaint_loss + recover_loss
            loss.backward()
            optimizer_generator.step()
            if conditional_generator is not None:
                optimizer_conditional_generator.step()
                optimizer_conditional_generator.zero_grad()
            optimizer_generator.zero_grad()
            optimizer_discriminator.zero_grad()
            # Discriminator update: source heatmaps are "real",
            # detached predictions are "fake".
            discriminator_no_rot_out = discriminatorModel(
                gt_image=src_images,
                generated_image=pred_tgt['heatmaps'].detach())
            discriminator_rot_out = discriminatorModel(
                gt_image=geo_src_images,
                generated_image=pred_rot_tgt['heatmaps'].detach())
            loss_disc = discriminator_no_rot_out['loss'].mean() + \
                discriminator_rot_out['loss'].mean()
            loss_disc.backward()
            optimizer_discriminator.step()
            optimizer_discriminator.zero_grad()
            optimizer_generator.zero_grad()
            if conditional_generator is not None:
                optimizer_conditional_generator.zero_grad()
            logger.add_scalar("Losses", loss.item(), logger.iterations)
            logger.add_scalar("Disc Loss", loss_disc.item(), logger.iterations)
            logger.add_scalar("Gen Loss", pred_tgt['generator_loss'].item(),
                              logger.iterations)
            logger.add_scalar("Geo Loss", (geo_weight * geo_term).item(),
                              logger.iterations)
            logger.add_scalar("Inpaint Loss", inpaint_loss, logger.iterations)
            logger.add_scalar("Recover Loss", recover_loss, logger.iterations)
            # Pad the per-scale discriminator losses so three scales are always
            # logged (the previous `if len(scales) < 2` padding could still
            # leave scales[2] out of range).
            scales = discriminator_no_rot_out['scales']
            while len(scales) < 3:
                scales.append(torch.Tensor([0.]))
            logger.add_scalar("disc_scales/s1", scales[0].item(), logger.iterations)
            logger.add_scalar("disc_scales/s2", scales[1].item(), logger.iterations)
            logger.add_scalar("disc_scales/s3", scales[2].item(), logger.iterations)
            ####### LOG VALIDATION
            if i % log_params['eval_frequency'] == 0 or i == 0:
                concat_img = np.concatenate(
                    (draw_kp(tensor_to_image(tgt_batch['imgs'][k]), pred_tgt['kps'][k]),
                     draw_kp(tensor_to_image(geo_tgt_imgs[k]), pred_rot_tgt['kps'][k])),
                    axis=2)
                skeletons_img = np.concatenate(
                    (tensor_to_image(pred_tgt['image'][k]),
                     tensor_to_image(pred_rot_tgt['image'][k])), axis=2)
                masked_img = np.concatenate(
                    (tensor_to_image(pred_tgt['masked_images'][k]),
                     tensor_to_image(pred_rot_tgt['masked_images'][k])), axis=2)
                heatmap_img_0 = np.concatenate(
                    (tensor_to_image(pred_tgt['heatmaps'][k, 0].unsqueeze(0), True),
                     tensor_to_image(pred_rot_tgt['heatmaps'][k, 0].unsqueeze(0), True)),
                    axis=2)
                heatmap_img_1 = np.concatenate(
                    (tensor_to_image(pred_tgt['heatmaps'][k, 5].unsqueeze(0), True),
                     tensor_to_image(pred_rot_tgt['heatmaps'][k, 5].unsqueeze(0), True)),
                    axis=2)
                src_heatmap_0 = np.concatenate(
                    (tensor_to_image(src_images[k, 0].unsqueeze(0), True),
                     tensor_to_image(geo_src_images[k, 0].unsqueeze(0), True)),
                    axis=2)
                src_heatmap_1 = np.concatenate(
                    (tensor_to_image(src_images[k, 5].unsqueeze(0), True),
                     tensor_to_image(geo_src_images[k, 5].unsqueeze(0), True)),
                    axis=2)
                image = np.concatenate((concat_img, skeletons_img, masked_img), axis=1)
                heatmaps_img = np.concatenate((heatmap_img_0, heatmap_img_1), axis=1)
                src_heatmaps = np.concatenate((src_heatmap_0, src_heatmap_1), axis=1)
                logger.add_image('Pose', image, epoch)
                logger.add_image('heatmaps', heatmaps_img, epoch)
                logger.add_image('src heatmaps', src_heatmaps, epoch)
                k += 1
                k = k % len(log_params['log_imgs'])
            ####### LOG
            logger.step_it()
        scheduler_generator.step()
        # Keep the discriminator (and conditional generator) LR schedules in
        # sync; previously these schedulers were created but never stepped.
        scheduler_discriminator.step()
        if conditional_generator is not None:
            scheduler_conditional_generator.step()
        logger.step_epoch(models={
            'model_kp_detector': model_generator,
            'optimizer_kp_detector': optimizer_generator})
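# A minimal sketch of the rotation-consistency objective consumed above,
# assuming `transform(x, angle)` rotates a heatmap batch by `angle` degrees and
# `transform_inv` undoes it. The L1 form and the 't'/'t_inv' keys mirror how
# `geo_consistency` is used in the loop; the body itself is an assumption about
# the repo's implementation.
def geo_consistency_sketch(hm, hm_rot, transform, transform_inv, angle):
    # rotate(pred(img)) should match pred(rotate(img)), and vice versa.
    t = torch.abs(transform(hm, angle) - hm_rot).mean()
    t_inv = torch.abs(transform_inv(hm_rot, angle) - hm).mean()
    return {'t': t, 't_inv': t_inv}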
def train_kpdetector(model_kp_detector, loader, loader_tgt, train_params,
                     checkpoint, logger, device_ids, tgt_batch=None, kp_map=None):
    log_params = train_params['log_params']
    optimizer_kp_detector = torch.optim.Adam(model_kp_detector.parameters(),
                                             lr=train_params['lr'],
                                             betas=train_params['betas'])
    resume_epoch = 0
    resume_iteration = 0
    if checkpoint is not None:
        print('Loading Checkpoint: %s' % checkpoint)
        # TODO: implement load/resume of kp_detector
        if not train_params['test']:
            resume_epoch, resume_iteration = logger.checkpoint.load_checkpoint(
                checkpoint,
                model_kp_detector=model_kp_detector,
                optimizer_kp_detector=optimizer_kp_detector)
        else:
            # Load only the pretrained weights whose names and shapes match
            # the current model.
            net_dict = model_kp_detector.state_dict()
            pretrained_dict = torch.load(checkpoint)
            pretrained_dict = {k: v for k, v in pretrained_dict.items()
                               if k in net_dict}
            pretrained_dict = {k: v for k, v in pretrained_dict.items()
                               if v.shape == net_dict[k].shape}
            net_dict.update(pretrained_dict)
            model_kp_detector.load_state_dict(net_dict, strict=True)
            model_kp_detector.apply(convertLayer)
        logger.epoch = resume_epoch
        logger.iterations = resume_iteration
    scheduler_kp_detector = MultiStepLR(optimizer_kp_detector,
                                        train_params['epoch_milestones'],
                                        gamma=0.1, last_epoch=logger.epoch - 1)
    kp_detector = KPDetectorTrainer(model_kp_detector)
    kp_detector = DataParallelWithCallback(kp_detector, device_ids=device_ids)
    k = 0
    if train_params['test']:
        results = evaluate(model_kp_detector, loader_tgt,
                           dset=train_params['dataset'])
        print(' MSE: ' + str(results['MSE']) + ' PCK: ' + str(results['PCK']))
        return
    heatmap_var = train_params['heatmap_var']
    for epoch in range(logger.epoch, train_params['num_epochs']):
        results = evaluate(model_kp_detector, loader_tgt,
                           dset=train_params['dataset'])
        results_train = evaluate(model_kp_detector, loader,
                                 dset=train_params['dataset'])
        print('Epoch ' + str(epoch) + ' MSE: ' + str(results['MSE']))
        logger.add_scalar('MSE test', results['MSE'], epoch)
        logger.add_scalar('PCK test', results['PCK'], epoch)
        logger.add_scalar('MSE train', results_train['MSE'], epoch)
        logger.add_scalar('PCK train', results_train['PCK'], epoch)
        for i, batch in enumerate(tqdm(loader)):
            images = batch['imgs']
            if torch.isnan(images).any():
                print('Images contain NaN')
                break
            annots = batch['annots']
            gt_heatmaps = kp2gaussian2(annots,
                                       (model_kp_detector.heatmap_res,
                                        model_kp_detector.heatmap_res),
                                       heatmap_var).detach()
            if torch.isnan(annots).any() or torch.isinf(annots).any():
                print('Annotations contain NaN or Inf')
                break
            mask = None if 'kp_mask' not in batch.keys() else batch['kp_mask']
            kp_detector_out = kp_detector(images, gt_heatmaps, mask)
            loss = kp_detector_out['l2_loss'].mean()
            loss.backward()
            optimizer_kp_detector.step()
            optimizer_kp_detector.zero_grad()
            ####### LOG VALIDATION
            if i % log_params['eval_frequency'] == 0:
                tgt_batch = next(iter(loader_tgt))
                eval_out = eval_model(kp_detector, tgt_batch,
                                      model_kp_detector.heatmap_res,
                                      heatmap_var)
                eval_sz = int(len(loader) / log_params['eval_frequency'])
                it_number = epoch * eval_sz + \
                    (logger.iterations / log_params['eval_frequency'])
                logger.add_scalar('Eval loss', eval_out['l2_loss'].mean(),
                                  it_number)
                concat_img = np.concatenate(
                    (draw_kp(tensor_to_image(tgt_batch['imgs'][k]),
                             unnorm_kp(tgt_batch['annots'][k])),
                     draw_kp(tensor_to_image(tgt_batch['imgs'][k]),
                             eval_out['keypoints'][k], color='red')),
                    axis=2)
                heatmap_img_0 = tensor_to_image(
                    kp_detector_out['heatmaps'][k, 0].unsqueeze(0), True)
                heatmap_img_1 = tensor_to_image(
                    kp_detector_out['heatmaps'][k, 5].unsqueeze(0), True)
                src_heatmap_0 = tensor_to_image(gt_heatmaps[k, 0].unsqueeze(0), True)
                src_heatmap_1 = tensor_to_image(gt_heatmaps[k, 5].unsqueeze(0), True)
                heatmaps_img = np.concatenate((heatmap_img_0, heatmap_img_1), axis=2)
                src_heatmaps = np.concatenate((src_heatmap_0, src_heatmap_1), axis=2)
                logger.add_image('Eval_', concat_img, logger.iterations)
                logger.add_image('heatmaps', heatmaps_img, logger.iterations)
                logger.add_image('src heatmaps', src_heatmaps, logger.iterations)
            ####### LOG
            logger.add_scalar('L2 loss', loss.item(), logger.iterations)
            if i in log_params['log_imgs']:
                concat_img_train = np.concatenate(
                    (draw_kp(tensor_to_image(images[k]), unnorm_kp(annots[k])),
                     draw_kp(tensor_to_image(images[k]),
                             kp_detector_out['keypoints'][k], color='red')),
                    axis=2)
                logger.add_image('Train_{%d}' % i, concat_img_train,
                                 logger.iterations)
                k += 1
                k = k % len(log_params['log_imgs'])
            logger.step_it()
        scheduler_kp_detector.step()
        logger.step_epoch(models={
            'model_kp_detector': model_kp_detector,
            'optimizer_kp_detector': optimizer_kp_detector})
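# An LSGAN-style sketch of `generator_gan_loss` consistent with how it is
# called in `forward` above (per-scale discriminator maps, a scalar weight,
# `.mean()` applied by the caller). The squared-error-to-one target is an
# assumption about the repo's GAN loss, not its actual implementation.
def generator_gan_loss_sketch(discriminator_maps_generated, weight):
    # Push the discriminator's response on fakes toward the "real" label (= 1).
    score = (1 - discriminator_maps_generated) ** 2
    return weight * score.reshape(score.shape[0], -1).mean(dim=1)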