Example #1
    def forward(self, images, filter=None, gt=None):
        # Predict keypoints and heatmaps for the batch; `filter` optionally
        # selects the subset of keypoint channels shared across domains.
        pred_tgt = self.generator(images)
        kps = pred_tgt['value'] if filter is None else pred_tgt['value'][:, filter]
        generated_skeleton = self.kp_to_skl(kps).unsqueeze(1)
        generated_image = generated_skeleton
        masked_images = images
        heatmaps = pred_tgt['heatmaps'] if filter is None else pred_tgt['heatmaps'][:, filter]
        if self.do_inpaint:
            # Mask the input with the predicted skeleton and inpaint it back.
            generated_image, masked_images = self.inpaint(images, generated_skeleton)
        elif self.do_recover:
            # Re-render the predicted (or ground-truth) keypoints as Gaussian
            # heatmaps and reconstruct the input image from them.
            heatmaps_ = kp2gaussian2(gaussian2kp(heatmaps)['mean'], (122, 122), 0.5)
            if gt is not None:
                heatmaps_ = kp2gaussian2(gt, (122, 122), 0.5)
            recover_out = self.recover_transform(heatmaps_, images)
            generated_image = recover_out['reconstruction']
            masked_images = recover_out['transformed_input']
        # Adversarial feedback on the predicted heatmaps.
        fake_predictions = self.discriminator(heatmaps)
        gen_loss = []
        for map_generated in fake_predictions:
            gen_loss.append(
                generator_gan_loss(
                    discriminator_maps_generated=map_generated,
                    weight=self.train_params['loss_weights']['generator_gan']).mean())

        return {
            "kps": kps,
            "heatmaps": heatmaps,
            "generator_loss": sum(gen_loss) / len(gen_loss),
            "inpaint_loss": (self.perceptual_loss(images, generated_image)
                             if self.do_inpaint else 0),
            "recover_loss": recover_out['perceptual_loss'] if self.do_recover else 0,
            "recover_image": recover_out['reconstruction'] if self.do_recover else None,
            "image": generated_image,
            "masked_images": masked_images,
            "hint": recover_out['hint'] if self.do_recover else None,
        }
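
All three examples rasterize keypoints into heatmaps with kp2gaussian2(kp, spatial_size, kp_variance), whose implementation is not shown here. The following is a minimal sketch of what such a helper typically computes, assuming keypoints in normalized [-1, 1] coordinates and one isotropic Gaussian per keypoint; the repository's exact conventions may differ.

import torch

def kp2gaussian2_sketch(kp, spatial_size, kp_variance):
    # Sketch only (assumed shapes): kp is (B, K, 2) in [-1, 1] coordinates;
    # returns heatmaps of shape (B, K, H, W), one Gaussian per keypoint.
    # kp_variance is treated here as the Gaussian's variance; the real helper
    # may interpret the third argument differently.
    h, w = spatial_size
    ys = torch.linspace(-1.0, 1.0, h, device=kp.device, dtype=kp.dtype)
    xs = torch.linspace(-1.0, 1.0, w, device=kp.device, dtype=kp.dtype)
    grid_y, grid_x = torch.meshgrid(ys, xs, indexing='ij')
    grid = torch.stack([grid_x, grid_y], dim=-1)            # (H, W, 2)
    diff = kp[:, :, None, None, :] - grid[None, None]       # (B, K, H, W, 2)
    return torch.exp(-0.5 * (diff ** 2).sum(-1) / kp_variance)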
Example #2
def eval_model(model, tgt_batch, heatmap_res=122, heatmap_var=0.5):
    # Run the detector on a single target batch without gradient tracking and
    # return its outputs. `heatmap_var` is accepted so the call site in
    # train_kpdetector below, which passes it, matches this signature.
    model.eval()
    images = tgt_batch['imgs']
    annots = tgt_batch['annots']
    gt_heatmaps = kp2gaussian2(annots, (heatmap_res, heatmap_res), heatmap_var)
    mask = None if 'kp_mask' not in tgt_batch.keys() else tgt_batch['kp_mask']
    with torch.no_grad():
        out = model(images, gt_heatmaps, mask)
    return out
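
A minimal usage sketch, assuming loader_tgt is a torch DataLoader whose batches are dicts with 'imgs' and 'annots' keys (and optionally 'kp_mask'), as in the trainers below; the variable names are illustrative.

# Sketch: evaluate one target batch and read back the mean L2 loss.
# `kp_detector` and `loader_tgt` are assumed to exist as in train_kpdetector.
tgt_batch = next(iter(loader_tgt))
eval_out = eval_model(kp_detector, tgt_batch, heatmap_res=122, heatmap_var=0.5)
print('eval l2 loss:', eval_out['l2_loss'].mean().item())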
Example #3
def train_generator_geo(model_generator,
                        discriminator,
                        conditional_generator,
                        model_kp_to_skl,
                        loader_src,
                        loader_tgt,
                        loader_test,
                        train_params,
                        checkpoint,
                        logger,
                        device_ids,
                        kp_map=None):
    log_params = train_params['log_params']
    optimizer_generator = torch.optim.Adam(model_generator.parameters(),
                                           lr=train_params['lr'],
                                           betas=train_params['betas'])
    optimizer_discriminator = torch.optim.Adam(discriminator.parameters(),
                                               lr=train_params['lr'],
                                               betas=train_params['betas'])
    if conditional_generator is not None:
        optimizer_conditional_generator = torch.optim.Adam(
            conditional_generator.parameters(),
            lr=train_params['lr'],
            betas=train_params['betas'])

    resume_epoch = 0
    resume_iteration = 0
    if checkpoint is not None:
        print('Loading Checkpoint: %s' % checkpoint)
        # TODO: implement load/resume for the kp_detector as well
        resume_epoch, resume_iteration = logger.checkpoint.load_checkpoint(
            checkpoint,
            model_generator=model_generator,
            optimizer_generator=optimizer_generator,
            optimizer_discriminator=optimizer_discriminator,
            model_discriminator=discriminator)
        logger.epoch = resume_epoch
        logger.iterations = resume_iteration
    scheduler_generator = MultiStepLR(optimizer_generator,
                                      train_params['epoch_milestones'],
                                      gamma=0.1,
                                      last_epoch=logger.epoch - 1)
    scheduler_discriminator = MultiStepLR(optimizer_discriminator,
                                          train_params['epoch_milestones'],
                                          gamma=0.1,
                                          last_epoch=logger.epoch - 1)
    if conditional_generator is not None:
        scheduler_conditional_generator = MultiStepLR(
            optimizer_conditional_generator,
            train_params['epoch_milestones'],
            gamma=0.1,
            last_epoch=logger.epoch - 1)
    img_generator = GeneratorTrainer(
        model_generator,
        discriminator,
        model_kp_to_skl,
        train_params,
        conditional_generator,
        do_inpaint=train_params['do_inpaint'],
        do_recover=train_params['do_recover']).cuda()

    discriminatorModel = DiscriminatorTrainer(discriminator,
                                              train_params).cuda()
    k = 0
    iterator_source = iter(loader_src)
    source_model = copy.deepcopy(model_generator)
    for epoch in range(logger.epoch, train_params['num_epochs']):
        results = evaluate(model_generator, loader_test,
                           train_params['dataset'], kp_map)
        print('Epoch ' + str(epoch) + ' MSE: ' + str(results['MSE']))
        logger.add_scalar('MSE test', results['MSE'], epoch)
        logger.add_scalar('PCK test', results['PCK'], epoch)
        # Hard-coded schedule: dump qualitative results from epoch 9 on and
        # stop training after epoch 11.
        if epoch >= 9:
            save_qualy(model_generator, source_model, loader_tgt, epoch,
                       train_params['dataset'], kp_map)
        if epoch > 11:
            return

        for i, tgt_batch in enumerate(tqdm(loader_tgt)):
            # Cycle through the source loader indefinitely.
            try:
                src_batch = next(iterator_source)
            except StopIteration:
                iterator_source = iter(loader_src)
                src_batch = next(iterator_source)

            angle = random.randint(1, 359)
            # Source-domain GT keypoints rendered as heatmaps, plus a rotated
            # copy; these serve as the discriminator's "real" samples.
            src_annots = src_batch['annots'].cuda()
            src_images = kp2gaussian2(src_annots, (122, 122), 0.5)[:, kp_map]
            geo_src_images = kp2gaussian2(batch_kp_rotation(src_annots, angle),
                                          (122, 122), 0.5)[:, kp_map]

            # Target images and a rotated view of the same batch.
            tgt_images = tgt_batch['imgs'].cuda()
            tgt_gt = tgt_batch['annots'].cuda()
            tgt_gt_rot = batch_kp_rotation(tgt_gt, angle)
            geo_tgt_imgs = geo_transform(tgt_images, angle)

            # Forward pass on both views; kp_map selects the shared keypoints.
            pred_tgt = img_generator(tgt_images, kp_map)
            pred_rot_tgt = img_generator(geo_tgt_imgs, kp_map)

            # Equivariance: heatmaps of the rotated image should match the
            # rotated heatmaps of the original image, and vice versa.
            geo_loss = geo_consistency(pred_tgt['heatmaps'],
                                       pred_rot_tgt['heatmaps'], geo_transform,
                                       geo_transform_inverse, angle)
            geo_term = geo_loss['t'] + geo_loss['t_inv']
            generator_term = pred_tgt['generator_loss'] + pred_rot_tgt['generator_loss']
            geo_weight = train_params['loss_weights']['geometric']
            recover_loss = train_params['loss_weights']['recover'] * (
                pred_tgt['recover_loss'] + pred_rot_tgt['recover_loss'])
            inpaint_loss = train_params['loss_weights']['inpaint'] * pred_tgt['inpaint_loss']
            loss = geo_weight * geo_term + generator_term + inpaint_loss + recover_loss
            loss.backward()

            # Generator (and, if present, conditional generator) update.
            optimizer_generator.step()
            if conditional_generator is not None:
                optimizer_conditional_generator.step()
                optimizer_conditional_generator.zero_grad()
            optimizer_generator.zero_grad()
            optimizer_discriminator.zero_grad()

            # Discriminator update: source-domain GT heatmaps are treated as
            # real, detached predicted heatmaps as fake.
            discriminator_no_rot_out = discriminatorModel(
                gt_image=src_images,
                generated_image=pred_tgt['heatmaps'].detach())
            discriminator_rot_out = discriminatorModel(
                gt_image=geo_src_images,
                generated_image=pred_rot_tgt['heatmaps'].detach())

            loss_disc = discriminator_no_rot_out['loss'].mean() + \
                discriminator_rot_out['loss'].mean()

            loss_disc.backward()
            optimizer_discriminator.step()
            optimizer_discriminator.zero_grad()
            optimizer_generator.zero_grad()
            if conditional_generator is not None:
                optimizer_conditional_generator.zero_grad()

            logger.add_scalar("Losses", loss.item(), logger.iterations)
            logger.add_scalar("Disc Loss", loss_disc.item(), logger.iterations)
            logger.add_scalar("Gen Loss", pred_tgt['generator_loss'].item(),
                              logger.iterations)
            logger.add_scalar("Geo Loss", (geo_weight * geo_term).item(),
                              logger.iterations)
            logger.add_scalar("Inpaint Loss", inpaint_loss, logger.iterations)
            logger.add_scalar("Recover Loss", recover_loss, logger.iterations)
            # Pad to three scales so the logging below never indexes past the
            # discriminator's actual number of scales.
            scales = discriminator_no_rot_out['scales']
            while len(scales) < 3:
                scales.append(torch.Tensor([0.]))
            logger.add_scalar("disc_scales/s1", scales[0].item(),
                              logger.iterations)
            logger.add_scalar("disc_scales/s2", scales[1].item(),
                              logger.iterations)
            logger.add_scalar("disc_scales/s3", scales[2].item(),
                              logger.iterations)

            ####### LOG VALIDATION
            if i % log_params['eval_frequency'] == 0:
                # Visualize the k-th sample: predicted keypoints on the
                # original and rotated target image, side by side.
                concat_img = np.concatenate(
                    (draw_kp(tensor_to_image(tgt_batch['imgs'][k]),
                             pred_tgt['kps'][k]),
                     draw_kp(tensor_to_image(geo_tgt_imgs[k]),
                             pred_rot_tgt['kps'][k])),
                    axis=2)
                skeletons_img = np.concatenate(
                    (tensor_to_image(pred_tgt['image'][k]),
                     tensor_to_image(pred_rot_tgt['image'][k])),
                    axis=2)
                masked_img = np.concatenate(
                    (tensor_to_image(pred_tgt['masked_images'][k]),
                     tensor_to_image(pred_rot_tgt['masked_images'][k])),
                    axis=2)
                heatmap_img_0 = np.concatenate(
                    (tensor_to_image(pred_tgt['heatmaps'][k, 0].unsqueeze(0),
                                     True),
                     tensor_to_image(
                         pred_rot_tgt['heatmaps'][k, 0].unsqueeze(0), True)),
                    axis=2)
                heatmap_img_1 = np.concatenate(
                    (tensor_to_image(pred_tgt['heatmaps'][k, 5].unsqueeze(0),
                                     True),
                     tensor_to_image(
                         pred_rot_tgt['heatmaps'][k, 5].unsqueeze(0), True)),
                    axis=2)
                src_heatmap_0 = np.concatenate(
                    (tensor_to_image(src_images[k, 0].unsqueeze(0), True),
                     tensor_to_image(geo_src_images[k, 0].unsqueeze(0), True)),
                    axis=2)
                src_heatmap_1 = np.concatenate(
                    (tensor_to_image(src_images[k, 5].unsqueeze(0), True),
                     tensor_to_image(geo_src_images[k, 5].unsqueeze(0), True)),
                    axis=2)


                image = np.concatenate((concat_img, skeletons_img, masked_img),
                                       axis=1)
                heatmaps_img = np.concatenate((heatmap_img_0, heatmap_img_1),
                                              axis=1)
                src_heatmaps = np.concatenate((src_heatmap_0, src_heatmap_1),
                                              axis=1)
                logger.add_image('Pose', image, epoch)
                logger.add_image('heatmaps', heatmaps_img, epoch)
                logger.add_image('src heatmaps', src_heatmaps, epoch)
                # Cycle through the samples selected for logging.
                k += 1
                k = k % len(log_params['log_imgs'])

            ####### LOG
            logger.step_it()

        scheduler_generator.step()
        scheduler_discriminator.step()
        if conditional_generator is not None:
            scheduler_conditional_generator.step()
        # Save with the same keys that the checkpoint loader above expects.
        logger.step_epoch(
            models={
                'model_generator': model_generator,
                'optimizer_generator': optimizer_generator,
                'model_discriminator': discriminator,
                'optimizer_discriminator': optimizer_discriminator
            })
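

# The loop above relies on three geometric helpers that are not shown here:
# batch_kp_rotation, geo_transform / geo_transform_inverse, and
# geo_consistency. The sketch below is an assumption about what they plausibly
# compute (normalized [-1, 1] keypoints, (B, C, H, W) heatmaps, an L1
# equivariance penalty); the repository's actual conventions may differ.
import math

import torch
import torch.nn.functional as F
import torchvision.transforms.functional as TF


def batch_kp_rotation_sketch(kp, angle):
    # Rotate (B, K, 2) keypoints by `angle` degrees about the image centre.
    # The sign convention depends on the image coordinate system.
    rad = math.radians(angle)
    cos, sin = math.cos(rad), math.sin(rad)
    rot = kp.new_tensor([[cos, -sin], [sin, cos]])
    return kp @ rot.T


def geo_transform_sketch(maps, angle):
    # Rotate a (B, C, H, W) image/heatmap tensor by `angle` degrees.
    return TF.rotate(maps, angle)


def geo_transform_inverse_sketch(maps, angle):
    return TF.rotate(maps, -angle)


def geo_consistency_sketch(heatmaps, heatmaps_rot, transform, transform_inv, angle):
    # Equivariance terms: rotating the prediction on the original image should
    # match the prediction on the rotated image, and vice versa.
    return {
        't': F.l1_loss(transform(heatmaps, angle), heatmaps_rot),
        't_inv': F.l1_loss(heatmaps, transform_inv(heatmaps_rot, angle)),
    }
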
def train_kpdetector(model_kp_detector,
                     loader,
                     loader_tgt,
                     train_params,
                     checkpoint,
                     logger,
                     device_ids,
                     tgt_batch=None,
                     kp_map=None):
    log_params = train_params['log_params']
    optimizer_kp_detector = torch.optim.Adam(model_kp_detector.parameters(),
                                             lr=train_params['lr'],
                                             betas=train_params['betas'])
    resume_epoch = 0
    resume_iteration = 0
    if checkpoint is not None:
        print('Loading Checkpoint: %s' % checkpoint)
        # TODO: implement load/resume for the kp_detector as well
        if not train_params['test']:
            resume_epoch, resume_iteration = logger.checkpoint.load_checkpoint(
                checkpoint,
                model_kp_detector=model_kp_detector,
                optimizer_kp_detector=optimizer_kp_detector)
        else:
            # Test mode: load only the pretrained weights whose names and
            # shapes match the current model, then convert the layers.
            net_dict = model_kp_detector.state_dict()
            pretrained_dict = torch.load(checkpoint)
            pretrained_dict = {
                k: v
                for k, v in pretrained_dict.items()
                if k in net_dict and v.shape == net_dict[k].shape
            }
            net_dict.update(pretrained_dict)
            model_kp_detector.load_state_dict(net_dict, strict=True)
            model_kp_detector.apply(convertLayer)

        logger.epoch = resume_epoch
        logger.iterations = resume_iteration
    scheduler_kp_detector = MultiStepLR(optimizer_kp_detector,
                                        train_params['epoch_milestones'],
                                        gamma=0.1,
                                        last_epoch=logger.epoch - 1)

    kp_detector = KPDetectorTrainer(model_kp_detector)
    kp_detector = DataParallelWithCallback(kp_detector, device_ids=device_ids)
    k = 0
    if train_params['test']:
        # Evaluation-only mode: report metrics on the target set and exit.
        results = evaluate(model_kp_detector, loader_tgt, dset=train_params['dataset'])
        print(' MSE: ' + str(results['MSE']) + ' PCK: ' + str(results['PCK']))
        return

    heatmap_var = train_params['heatmap_var']

    for epoch in range(logger.epoch, train_params['num_epochs']):
        # Per-epoch metrics on both the target (test) and source (train) sets.
        results = evaluate(model_kp_detector, loader_tgt, dset=train_params['dataset'])
        results_train = evaluate(model_kp_detector, loader, dset=train_params['dataset'])
        print('Epoch ' + str(epoch) + ' MSE: ' + str(results['MSE']))
        logger.add_scalar('MSE test', results['MSE'], epoch)
        logger.add_scalar('PCK test', results['PCK'], epoch)
        logger.add_scalar('MSE train', results_train['MSE'], epoch)
        logger.add_scalar('PCK train', results_train['PCK'], epoch)

        for i, batch in enumerate(tqdm(loader)):
            images = batch['imgs']
            if (images != images).sum() > 0:
                print('Images have NaN')
                break
            annots = batch['annots']
            gt_heatmaps = kp2gaussian2(annots, (model_kp_detector.heatmap_res,
                                                model_kp_detector.heatmap_res),
                                       heatmap_var).detach()
            if (annots != annots).sum() > 0 or (annots.abs() == float("Inf")).sum() > 0:
                print('Annotations have NaN or Inf')
                break
            mask = None if 'kp_mask' not in batch.keys() else batch['kp_mask']
            # Supervised training: regress the ground-truth Gaussian heatmaps.
            kp_detector_out = kp_detector(images, gt_heatmaps, mask)

            loss = kp_detector_out['l2_loss'].mean()
            loss.backward()

            optimizer_kp_detector.step()
            optimizer_kp_detector.zero_grad()
            ####### LOG VALIDATION
            if i % log_params['eval_frequency'] == 0:
                tgt_batch = next(iter(loader_tgt))
                eval_out = eval_model(kp_detector, tgt_batch, model_kp_detector.heatmap_res, heatmap_var)
                eval_sz = int(len(loader)/log_params['eval_frequency'])
                it_number = epoch * eval_sz  + (logger.iterations/log_params['eval_frequency'])
                logger.add_scalar('Eval loss', eval_out['l2_loss'].mean(), it_number)
                concat_img = np.concatenate(
                    (draw_kp(tensor_to_image(tgt_batch['imgs'][k]),
                             unnorm_kp(tgt_batch['annots'][k])),
                     draw_kp(tensor_to_image(tgt_batch['imgs'][k]),
                             eval_out['keypoints'][k], color='red')),
                    axis=2)

                heatmap_img_0 = tensor_to_image(kp_detector_out['heatmaps'][k, 0].unsqueeze(0), True)
                heatmap_img_1 = tensor_to_image(kp_detector_out['heatmaps'][k, 5].unsqueeze(0), True)
                src_heatmap_0 = tensor_to_image(gt_heatmaps[k, 0].unsqueeze(0), True)
                src_heatmap_1 = tensor_to_image(gt_heatmaps[k, 5].unsqueeze(0), True)
                heatmaps_img = np.concatenate((heatmap_img_0, heatmap_img_1), axis=2)
                src_heatmaps = np.concatenate((src_heatmap_0, src_heatmap_1), axis=2)
 
                logger.add_image('Eval_', concat_img, logger.iterations)
                logger.add_image('heatmaps', heatmaps_img, logger.iterations)
                logger.add_image('src heatmaps', src_heatmaps, logger.iterations)
 
            ####### LOG
            logger.add_scalar('L2 loss', 
                               loss.item(), 
                               logger.iterations)
            if i in log_params['log_imgs']:
                concat_img_train = np.concatenate(
                    (draw_kp(tensor_to_image(images[k]), unnorm_kp(annots[k])),
                     draw_kp(tensor_to_image(images[k]),
                             kp_detector_out['keypoints'][k], color='red')),
                    axis=2)

                logger.add_image('Train_{%d}' % i, concat_img_train,
                                 logger.iterations)
                k += 1
                k = k % len(log_params['log_imgs'])
            logger.step_it()

        scheduler_kp_detector.step()
        logger.step_epoch(models={'model_kp_detector': model_kp_detector,
                                  'optimizer_kp_detector': optimizer_kp_detector})
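
Both trainers read their hyperparameters from a single train_params dict. The sketch below lists the keys they actually access in the code above; every value is an illustrative placeholder (an assumption), not the repository's configuration.

# Sketch of a train_params dict covering the keys read by the two trainers
# above. All values are placeholders, not the repo's actual settings.
train_params = {
    'lr': 2e-4,
    'betas': (0.5, 0.999),
    'epoch_milestones': [8, 10],
    'num_epochs': 12,
    'heatmap_var': 0.5,          # train_kpdetector: Gaussian size for GT heatmaps
    'dataset': 'placeholder',    # forwarded to evaluate() / save_qualy()
    'test': False,               # train_kpdetector: evaluation-only mode
    'do_inpaint': False,         # GeneratorTrainer options
    'do_recover': True,
    'loss_weights': {
        'generator_gan': 1.0,
        'geometric': 1.0,
        'recover': 1.0,
        'inpaint': 1.0,
    },
    'log_params': {
        'eval_frequency': 100,
        'log_imgs': [0, 1, 2, 3],   # iteration indices used for image logging
    },
}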