Code example #1
File: cam_fomm.py  Project: dbandrews/avatarify
def load_checkpoints(config_path, checkpoint_path, device="cuda"):

    with open(config_path) as f:
        config = yaml.load(f, Loader=yaml.FullLoader)

    generator = OcclusionAwareGenerator(
        **config["model_params"]["generator_params"],
        **config["model_params"]["common_params"],
    )
    generator.to(device)

    kp_detector = KPDetector(
        **config["model_params"]["kp_detector_params"],
        **config["model_params"]["common_params"],
    )
    kp_detector.to(device)

    checkpoint = torch.load(checkpoint_path, map_location=device)
    generator.load_state_dict(checkpoint["generator"])
    kp_detector.load_state_dict(checkpoint["kp_detector"])

    generator = DataParallelWithCallback(generator)
    kp_detector = DataParallelWithCallback(kp_detector)

    generator.eval()
    kp_detector.eval()

    return generator, kp_detector
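
For reference, here is how a loader like this is typically invoked; a minimal sketch, assuming hypothetical config and checkpoint paths for a trained first-order-motion model:

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
generator, kp_detector = load_checkpoints(
    config_path="config/vox-256.yaml",              # hypothetical path
    checkpoint_path="checkpoints/vox-cpk.pth.tar",  # hypothetical path
    device=device,
)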
Code example #2
def transfer(config, generator, kp_detector, checkpoint, log_dir, dataset):
    log_dir = os.path.join(log_dir, 'transfer')
    png_dir = os.path.join(log_dir, 'png')
    transfer_params = config['transfer_params']

    dataset = PairedDataset(initial_dataset=dataset,
                            number_of_pairs=transfer_params['num_pairs'])
    dataloader = DataLoader(dataset,
                            batch_size=1,
                            shuffle=False,
                            num_workers=1)

    if checkpoint is not None:
        Logger.load_cpk(checkpoint,
                        generator=generator,
                        kp_detector=kp_detector)
    else:
        raise AttributeError(
            "Checkpoint should be specified for mode='transfer'.")

    if not os.path.exists(log_dir):
        os.makedirs(log_dir)

    if not os.path.exists(png_dir):
        os.makedirs(png_dir)

    generator = DataParallelWithCallback(generator)
    kp_detector = DataParallelWithCallback(kp_detector)

    generator.eval()
    kp_detector.eval()

    for it, x in tqdm(enumerate(dataloader)):
        with torch.no_grad():
            x = {
                key: value if not hasattr(value, 'cuda') else value.cuda()
                for key, value in x.items()
            }
            driving_video = x['driving_video']
            source_image = x['source_video'][:, :, :1, :, :]
            out = transfer_one(generator, kp_detector, source_image,
                               driving_video, transfer_params)
            img_name = "-".join([x['driving_name'][0], x['source_name'][0]])

            # Store to .png for evaluation
            out_video_batch = out['video_prediction'].data.cpu().numpy()
            out_video_batch = np.concatenate(np.transpose(
                out_video_batch, [0, 2, 3, 4, 1])[0],
                                             axis=1)
            imageio.imsave(os.path.join(png_dir, img_name + '.png'),
                           (255 * out_video_batch).astype(np.uint8))

            image = Visualizer(
                **config['visualizer_params']).visualize_transfer(
                    driving_video=driving_video,
                    source_image=source_image,
                    out=out)
            imageio.mimsave(
                os.path.join(log_dir, img_name + transfer_params['format']),
                image)
Code example #3
def load_checkpoints(config_path, checkpoint_path, cpu=False):
    with open(config_path) as f:
        config = yaml.load(f, Loader=yaml.FullLoader)

    generator = OcclusionAwareGenerator(
        **config["model_params"]["generator_params"],
        **config["model_params"]["common_params"],
    )
    if cpu:
        generator.cpu()
    else:
        generator.cuda()

    kp_detector = KPDetector(
        **config["model_params"]["kp_detector_params"],
        **config["model_params"]["common_params"],
    )
    if cpu:
        kp_detector.cpu()
    else:
        kp_detector.cuda()

    checkpoint = torch.load(checkpoint_path, map_location="cpu" if cpu else None)
    generator.load_state_dict(checkpoint["generator"])
    kp_detector.load_state_dict(checkpoint["kp_detector"])

    generator = DataParallelWithCallback(generator)
    kp_detector = DataParallelWithCallback(kp_detector)

    generator.eval()
    kp_detector.eval()

    return generator, kp_detector
Code example #4
def load_checkpoints(config_path, checkpoint_path, device='cuda'):

    with open(config_path) as f:
        config = yaml.load(f, Loader=yaml.FullLoader)

    generator = OcclusionAwareGenerator(
        **config['model_params']['generator_params'],
        **config['model_params']['common_params'])
    generator.to(device)

    kp_detector = KPDetector(**config['model_params']['kp_detector_params'],
                             **config['model_params']['common_params'])
    kp_detector.to(device)

    checkpoint = torch.load(checkpoint_path, map_location=device)
    generator.load_state_dict(checkpoint['generator'])
    kp_detector.load_state_dict(checkpoint['kp_detector'])

    generator = DataParallelWithCallback(generator)
    kp_detector = DataParallelWithCallback(kp_detector)

    generator.eval()
    kp_detector.eval()

    return generator, kp_detector
Code example #5
def load_checkpoints(config, checkpoint, blend_scale=0.125, first_order_motion_model=False, cpu=False):
    with open(config) as f:
        config = yaml.load(f, Loader=yaml.FullLoader)

    reconstruction_module = PartSwapGenerator(blend_scale=blend_scale,
                                              first_order_motion_model=first_order_motion_model,
                                              **config['model_params']['reconstruction_module_params'],
                                              **config['model_params']['common_params'])

    if not cpu:
        reconstruction_module.cuda()

    segmentation_module = SegmentationModule(**config['model_params']['segmentation_module_params'],
                                             **config['model_params']['common_params'])
    if not cpu:
        segmentation_module.cuda()

    if cpu:
        checkpoint = torch.load(checkpoint, map_location=torch.device('cpu'))
    else:
        checkpoint = torch.load(checkpoint)

    load_reconstruction_module(reconstruction_module, checkpoint)
    load_segmentation_module(segmentation_module, checkpoint)

    if not cpu:
        reconstruction_module = DataParallelWithCallback(reconstruction_module)
        segmentation_module = DataParallelWithCallback(segmentation_module)

    reconstruction_module.eval()
    segmentation_module.eval()

    return reconstruction_module, segmentation_module
Code example #6
def load_checkpoints(config_path, checkpoint_path, cpu=False):

    with open(config_path) as f:
        config = yaml.load(f, Loader=yaml.FullLoader)

    generator = OcclusionAwareGenerator(**config['model_params']['generator_params'],
                                        **config['model_params']['common_params'])
    if not cpu:
        generator.cuda()

    kp_detector = KPDetector(**config['model_params']['kp_detector_params'],
                             **config['model_params']['common_params'])
    if not cpu:
        kp_detector.cuda()
    
    if cpu:
        checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'))
    else:
        checkpoint = torch.load(checkpoint_path)
 
    generator.load_state_dict(checkpoint['generator'])
    kp_detector.load_state_dict(checkpoint['kp_detector'])
    
    if not cpu:
        generator = DataParallelWithCallback(generator)
        kp_detector = DataParallelWithCallback(kp_detector)

    generator.eval()
    kp_detector.eval()
    
    return generator, kp_detector
Code example #7
def reconstruction(config, generator, kp_detector, checkpoint, log_dir, dataset):
    png_dir = os.path.join(log_dir, 'reconstruction/png')
    log_dir = os.path.join(log_dir, 'reconstruction')

    if checkpoint is not None:
        Logger.load_cpk(checkpoint, generator=generator, kp_detector=kp_detector)
    else:
        raise AttributeError("Checkpoint should be specified for mode='reconstruction'.")
    dataloader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=1)

    if not os.path.exists(log_dir):
        os.makedirs(log_dir)

    if not os.path.exists(png_dir):
        os.makedirs(png_dir)

    loss_list = []
    if torch.cuda.is_available():
        generator = DataParallelWithCallback(generator)
        kp_detector = DataParallelWithCallback(kp_detector)

    generator.eval()
    kp_detector.eval()

    for it, x in tqdm(enumerate(dataloader)):
        if config['reconstruction_params']['num_videos'] is not None:
            if it > config['reconstruction_params']['num_videos']:
                break
        with torch.no_grad():
            predictions = []
            visualizations = []
            if torch.cuda.is_available():
                x['video'] = x['video'].cuda()
            kp_source = kp_detector(x['video'][:, :, 0])
            for frame_idx in range(x['video'].shape[2]):
                source = x['video'][:, :, 0]
                driving = x['video'][:, :, frame_idx]
                kp_driving = kp_detector(driving)
                out = generator(source, kp_source=kp_source, kp_driving=kp_driving)
                out['kp_source'] = kp_source
                out['kp_driving'] = kp_driving
                del out['sparse_deformed']
                predictions.append(np.transpose(out['prediction'].data.cpu().numpy(), [0, 2, 3, 1])[0])

                visualization = Visualizer(**config['visualizer_params']).visualize(source=source,
                                                                                    driving=driving, out=out)
                visualizations.append(visualization)

                loss_list.append(torch.abs(out['prediction'] - driving).mean().cpu().numpy())

            predictions = np.concatenate(predictions, axis=1)
            imageio.imsave(os.path.join(png_dir, x['name'][0] + '.png'), (255 * predictions).astype(np.uint8))

            image_name = x['name'][0] + config['reconstruction_params']['format']
            imageio.mimsave(os.path.join(log_dir, image_name), visualizations)

    print("Reconstruction loss: %s" % np.mean(loss_list))
Code example #8
    def __init__(self,
                 hidden_dim,
                 lr,
                 hard_or_full_trip,
                 margin,
                 num_workers,
                 batch_size,
                 restore_iter,
                 total_iter,
                 save_name,
                 train_pid_num,
                 frame_num,
                 model_name,
                 train_source,
                 test_source,
                 img_size=64):

        self.save_name = save_name
        self.train_pid_num = train_pid_num
        self.train_source = train_source
        self.test_source = test_source

        self.hidden_dim = hidden_dim
        self.lr = lr
        self.hard_or_full_trip = hard_or_full_trip
        self.margin = margin
        self.frame_num = frame_num
        self.num_workers = num_workers
        self.batch_size = batch_size
        self.model_name = model_name
        self.P, self.M = batch_size  # batch_size is a (P, M) pair: P identities, M samples per identity

        self.restore_iter = restore_iter
        self.total_iter = total_iter

        self.img_size = img_size

        self.encoder = SetNet(self.hidden_dim).float()
        self.encoder = DataParallelWithCallback(self.encoder)
        self.triplet_loss = TripletLoss(self.P * self.M, self.hard_or_full_trip, self.margin).float()
        self.triplet_loss = DataParallelWithCallback(self.triplet_loss)
        self.encoder.cuda()
        self.triplet_loss.cuda()

        self.optimizer = optim.Adam([
            {'params': self.encoder.parameters()},
        ], lr=self.lr)

        self.hard_loss_metric = []
        self.full_loss_metric = []
        self.full_loss_num = []
        self.dist_list = []
        self.mean_dist = 0.01

        self.sample_type = 'all'
Code example #9
 def cuda(self):
     self.model.cuda()
     if self.get_param('network.sync_bn', False):
         self.model = DataParallelWithCallback(self.model,
                                               dim=self.batch_axis)
     else:
         self.model = nn.DataParallel(self.model, dim=self.batch_axis)
Code example #10
    def testSyncBatchNormSyncEval(self):
        bn = nn.BatchNorm1d(10, eps=1e-5, affine=False)
        sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
        sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])

        bn.cuda()
        sync_bn.cuda()

        self._checkBatchNormResult(bn, sync_bn, torch.rand(16, 10), False, cuda=True)
Code example #11
    def testSyncBatchNorm2DSyncTrain(self):
        bn = nn.BatchNorm2d(10)
        sync_bn = SynchronizedBatchNorm2d(10)
        sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])

        bn.cuda()
        sync_bn.cuda()

        self._checkBatchNormResult(bn, sync_bn, torch.rand(16, 10, 16, 16), True, cuda=True)
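
Both tests above exercise the standard Synchronized-BatchNorm pattern: wrapping the module in DataParallelWithCallback makes the replicas exchange batch statistics instead of normalizing per GPU. A minimal sketch of the same pattern outside a test harness, assuming the sync_batchnorm package these projects vendor and at least two visible GPUs:

import torch
from sync_batchnorm import SynchronizedBatchNorm2d, DataParallelWithCallback

model = torch.nn.Sequential(
    torch.nn.Conv2d(3, 10, kernel_size=3, padding=1),
    SynchronizedBatchNorm2d(10),
)
model = DataParallelWithCallback(model, device_ids=[0, 1]).cuda()
# BatchNorm statistics are reduced across both replicas during the forward pass
out = model(torch.rand(16, 3, 32, 32).cuda())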
Code example #12
def load_checkpoints(config_path, checkpoint_path, cpu=False):
    with open(config_path) as f:
        config = yaml.load(f, Loader=yaml.FullLoader)

    generator = Generator(num_regions=config['model_params']['num_regions'],
                          num_channels=config['model_params']['num_channels'],
                          **config['model_params']['generator_params'])
    if not cpu:
        generator.cuda()

    region_predictor = RegionPredictor(num_regions=config['model_params']['num_regions'],
                                       num_channels=config['model_params']['num_channels'],
                                       estimate_affine=config['model_params']['estimate_affine'],
                                       **config['model_params']['region_predictor_params'])
    if not cpu:
        region_predictor.cuda()

    avd_network = AVDNetwork(num_regions=config['model_params']['num_regions'],
                             **config['model_params']['avd_network_params'])
    if not cpu:
        avd_network.cuda()

    if cpu:
        checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'))
    else:
        checkpoint = torch.load(checkpoint_path)

    generator.load_state_dict(checkpoint['generator'])
    region_predictor.load_state_dict(checkpoint['region_predictor'])
    if 'avd_network' in checkpoint:
        avd_network.load_state_dict(checkpoint['avd_network'])

    if not cpu:
        generator = DataParallelWithCallback(generator)
        region_predictor = DataParallelWithCallback(region_predictor)
        avd_network = DataParallelWithCallback(avd_network)

    generator.eval()
    region_predictor.eval()
    avd_network.eval()

    return generator, region_predictor, avd_network
Code example #13
def load_checkpoints(kp_detector_path, generator_path, cpu=False):

    generator = torch.jit.load(generator_path)

    kp_detector = torch.jit.load(kp_detector_path)

    if not cpu:
        generator.cuda()
        kp_detector.cuda()
        generator = DataParallelWithCallback(generator)
        kp_detector = DataParallelWithCallback(kp_detector)

    generator.eval()
    kp_detector.eval()

    return generator, kp_detector
Code example #14
    def restore(self, checkpoint):
        self.epoch = 0

        self.generator = DCGenerator(**self.config['generator_params'])
        self.generator = DataParallelWithCallback(self.generator,
                                                  device_ids=self.device_ids)
        self.optimizer_generator = torch.optim.Adam(
            params=self.generator.parameters(),
            lr=self.config['lr_generator'],
            betas=(self.config['b1_generator'], self.config['b2_generator']),
            weight_decay=0,
            eps=1e-8)

        self.discriminator = DCDiscriminator(
            **self.config['discriminator_params'])
        self.discriminator = DataParallelWithCallback(
            self.discriminator, device_ids=self.device_ids)
        self.optimizer_discriminator = torch.optim.Adam(
            params=self.discriminator.parameters(),
            lr=self.config['lr_discriminator'],
            betas=(self.config['b1_discriminator'],
                   self.config['b2_discriminator']),
            weight_decay=0,
            eps=1e-8)

        if checkpoint is not None:
            data = torch.load(checkpoint)
            for key, value in data.items():
                if key == 'epoch':
                    self.epoch = value
                else:
                    self.__dict__[key].load_state_dict(value)

        lr_lambda = lambda epoch: 1 - epoch / self.config['num_epochs']
        self.scheduler_generator = torch.optim.lr_scheduler.LambdaLR(
            self.optimizer_generator, lr_lambda, last_epoch=self.epoch - 1)
        self.scheduler_discriminator = torch.optim.lr_scheduler.LambdaLR(
            self.optimizer_discriminator, lr_lambda, last_epoch=self.epoch - 1)
Code example #15
    def __init__(self, opt):
        self.opt = opt
        self.pix2pix_model = Pix2PixModel(opt)
        if len(opt.gpu_ids) > 0:
            self.pix2pix_model = DataParallelWithCallback(
                self.pix2pix_model, device_ids=opt.gpu_ids)
            self.pix2pix_model_on_one_gpu = self.pix2pix_model.module
        else:
            self.pix2pix_model_on_one_gpu = self.pix2pix_model

        self.generated = None
        if opt.isTrain:
            self.optimizer_G, self.optimizer_D = \
                self.pix2pix_model_on_one_gpu.create_optimizers(opt)
            self.old_lr = opt.lr
Code example #16
def init_net(net, init_type='normal', init_gain=0.02, gpu_ids=[]):
    """Initialize a network: 1. register CPU/GPU device (with multi-GPU support); 2. initialize the network weights
    Parameters:
        net (network)      -- the network to be initialized
        init_type (str)    -- the name of an initialization method: normal | xavier | kaiming | orthogonal
        init_gain (float)  -- scaling factor for normal, xavier and orthogonal.
        gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2

    Return an initialized network.
    """
    if len(gpu_ids) > 0:
        assert (torch.cuda.is_available())
        net.to(gpu_ids[0])
        net = DataParallelWithCallback(net, gpu_ids)  # multi-GPUs
    init_weights(net, init_type, init_gain=init_gain)
    return net
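
A hedged usage sketch (the layer and GPU ids below are placeholders; init_weights is assumed to be defined alongside init_net, as the snippet above requires):

import torch.nn as nn

net = init_net(nn.Conv2d(3, 64, kernel_size=7),
               init_type="xavier", init_gain=0.02, gpu_ids=[0, 1])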
Code example #17
def main():
    net = Baseline(num_classes=culane.num_classes,
                   deep_base=args['deep_base']).cuda().train()
    net = DataParallelWithCallback(net)

    optimizer = optim.SGD([
        {'params': [param for name, param in net.named_parameters()
                    if name[-4:] == 'bias'],
         'lr': 2 * args['base_lr']},
        {'params': [param for name, param in net.named_parameters()
                    if name[-4:] != 'bias'],
         'lr': args['base_lr']},
    ], momentum=args['momentum'])

    if len(args['checkpoint']) > 0:
        print('training resumes from \'%s\'' % args['checkpoint'])
        net.load_state_dict(
            torch.load(
                os.path.join(ckpt_path, exp_name,
                             args['checkpoint'] + '_checkpoint.pth')))
        optimizer.load_state_dict(
            torch.load(
                os.path.join(ckpt_path, exp_name,
                             args['checkpoint'] + '_checkpoint_optim.pth')))
        optimizer.param_groups[0]['lr'] = 2 * args['base_lr']
        optimizer.param_groups[1]['lr'] = args['base_lr']

    check_mkdir(os.path.join(ckpt_path, exp_name))
    open(log_path, 'w').write(str(args) + '\n\n')

    train(net, optimizer)
Code example #18
def debug_generator(generator, kp_to_skl_gt, loader, train_params, 
                     logger, device_ids, tgt_batch=None):
    log_params = train_params['log_params']
    genModel = ConditionalGenerator2D(generator, train_params)
    genModel = DataParallelWithCallback(genModel, device_ids=device_ids)

    optimizer_generator = torch.optim.Adam(generator.parameters(),
                                            lr=train_params['lr'],
                                            betas=train_params['betas'])
    scheduler_generator = MultiStepLR(optimizer_generator, 
                                       train_params['epoch_milestones'], 
                                       gamma=0.1, last_epoch=-1)
 
    k=0
    train_views = [0,1,3]
    eval_views = [2]
    if tgt_batch is not None:
        tgt_batch_samples = split_data(tgt_batch, 
                                       train_views=train_views, 
                                       eval_views=eval_views)
        with torch.no_grad():
            tgt_batch_samples['gt_skl'] = kp_to_skl_gt(tgt_batch_samples['kps'].to('cuda')).unsqueeze(1)
            tgt_batch_samples['gt_skl_eval'] = kp_to_skl_gt(tgt_batch_samples['kps_eval'].to('cuda')).unsqueeze(1)
        
    for epoch in range(train_params['num_epochs']):
        for i, batch  in enumerate(tqdm(loader)):
            batch_samples = split_data(batch,
                                       train_views=train_views,
                                       eval_views=eval_views)
            #imgs = flatten_views(imgs)
            #ref_imgs = flatten_views(ref_imgs)
            #ref_imgs = torch.rand(*imgs.shape)
            with torch.no_grad():
                batch_samples['gt_skl'] = kp_to_skl_gt(batch_samples['kps'].to('cuda')).unsqueeze(1)
                #batch_samples['gt_skl_eval'] = kp_to_skl_gt(batch_samples['kps_eval'].to('cuda')).unsqueeze(1)
                #gt_skl = (kp_to_skl_gt(flatten_views(annots / (ref_img.shape[3] - 1)).to('cuda'))).unsqueeze(1)
            #gt_skl = torch.rand(imgs.shape[0], 1, *imgs.shape[2:])

            #generator_out = genModel(imgs, ref_imgs, gt_skl)
            generator_out = genModel(batch_samples['imgs'], batch_samples['ref_imgs'], batch_samples['gt_skl'])
            ##### Generator update
            #loss_generator = generator_out['loss']
            loss_generator = generator_out['perceptual_loss']
            loss_generator = [x.mean() for x in loss_generator]
            loss_gen = sum(loss_generator)
            loss_gen.backward(retain_graph=not train_params['detach_kp_discriminator'])
            optimizer_generator.step()
            optimizer_generator.zero_grad()

            ########### LOG
            logger.add_scalar("Generator Loss", 
                               loss_gen.item(), 
                               epoch * len(loader) + i + 1)
            if i in log_params['log_imgs']:
                if tgt_batch is not None:
                    with torch.no_grad():
                        genModel.eval()
                        generator_out_eval = genModel(tgt_batch_samples['imgs_eval'], 
                                                      tgt_batch_samples['ref_imgs_eval'],
                                                      tgt_batch_samples['gt_skl_eval'])
                        #generator_out_eval = genModel(batch_samples['imgs_eval'], 
                        #                              batch_samples['ref_imgs_eval'],
                        #                              batch_samples['gt_skl_eval'])
                        concat_img_eval = np.concatenate((tensor_to_image(tgt_batch_samples['imgs_eval'][k]), 
                                     tensor_to_image(tgt_batch_samples['gt_skl_eval'][k]), 
                                     tensor_to_image(tgt_batch_samples['ref_imgs_eval'][k]),
                                     tensor_to_image(generator_out_eval['reconstructred_image'][k])), axis=2)  # concat along width
                        logger.add_image('Sample_{%d}_EVAL' % i, concat_img_eval, epoch)
                        genModel.train()
                k += 1
                k = k % 4
                concat_img = np.concatenate((tensor_to_image(batch_samples['imgs'][k]), 
                             tensor_to_image(batch_samples['gt_skl'][k]), 
                             tensor_to_image(batch_samples['ref_imgs'][k]),
                             tensor_to_image(generator_out['reconstructred_image'][k])), axis=2)  # concat along width
                logger.add_image('Sample_{%d}' % i, concat_img, epoch)


        scheduler_generator.step()
Code example #19
def debug_encoder(model_skeleton_to_keypoint,
                  model_keypoint_to_skeleton,
                  loader,
                  loader_tgt,
                  train_params,
                  checkpoint,
                  logger,
                  device_ids,
                  tgt_batch=None):
    log_params = train_params['log_params']
    optimizer_encoder = torch.optim.Adam(
        model_skeleton_to_keypoint.parameters(),
        lr=train_params['lr'],
        betas=train_params['betas'])
    resume_epoch = 0
    resume_iteration = 0
    if checkpoint is not None:
        print('Loading Checkpoint: %s' % checkpoint)
        resume_epoch, resume_iteration = logger.checkpoint.load_checkpoint(
            checkpoint,
            skeleton_to_keypoints=model_skeleton_to_keypoint,
            optimizer_skeleton_to_keypoints=optimizer_encoder)
        logger.epoch = resume_epoch
        logger.iterations = resume_iteration

    scheduler_encoder = MultiStepLR(optimizer_encoder,
                                    train_params['epoch_milestones'],
                                    gamma=0.1,
                                    last_epoch=logger.epoch - 1)

    encoder = SkeletonToKeypoints(model_skeleton_to_keypoint,
                                  model_keypoint_to_skeleton)
    encoder = DataParallelWithCallback(encoder, device_ids=device_ids)

    with torch.no_grad():
        skeletons = model_keypoint_to_skeleton(
            tgt_batch['annots'].to('cuda')).unsqueeze(1).detach()
    tgt_batch = {
        'imgs': tgt_batch['imgs'],
        'annots': tgt_batch['annots'],
        'annots_unnormed': tgt_batch['annots'],
        'skeletons': skeletons
    }
    k = 0
    for epoch in range(logger.epoch, train_params['num_epochs']):
        for i, batch in enumerate(tqdm(loader)):
            annots = batch['annots']
            annots_gt = batch['annots_unnormed']

            with torch.no_grad():
                gt_skl = model_keypoint_to_skeleton(
                    annots.to('cuda')).unsqueeze(1).detach()
            encoder_out = encoder(gt_skl, annots)
            optimizer_encoder.zero_grad()
            loss = encoder_out['loss'].mean()
            loss.backward()

            optimizer_encoder.step()
            optimizer_encoder.zero_grad()
            ####### LOG VALIDATION
            if i % log_params['eval_frequency'] == 0:
                eval_loss = eval_model(encoder, next(iter(loader_tgt)),
                                       model_keypoint_to_skeleton)
                eval_sz = int(len(loader) / log_params['eval_frequency'])
                it_number = epoch * eval_sz + (logger.iterations /
                                               log_params['eval_frequency'])
                logger.add_scalar('Eval loss', eval_loss, it_number)

            ####### LOG
            logger.add_scalar('L2 loss', loss.item(), logger.iterations)
            if i in log_params['log_imgs']:
                with torch.no_grad():
                    encoder.eval()
                    target_out = encoder(tgt_batch['skeletons'],
                                         tgt_batch['annots'])
                    encoder.train()
                skl_out = target_out['reconstruction']
                kps_out = target_out['keypoints']

                concat_img = np.concatenate(
                    (draw_kp(tensor_to_image(tgt_batch['imgs'][k]),
                             tgt_batch['annots_unnormed'][k]),
                     tensor_to_image(tgt_batch['skeletons'][k]),
                     tensor_to_image(skl_out[k]),
                     draw_kp(tensor_to_image(tgt_batch['imgs'][k]),
                             kps_out[k],
                             color='red')),
                    axis=2)
                concat_img_train = np.concatenate(
                    (tensor_to_image(gt_skl[k]),
                     tensor_to_image(encoder_out['reconstruction'][k])),
                    axis=2)

                logger.add_image('Eval_{%d}' % i, concat_img, epoch)
                logger.add_image('Train_{%d}' % i, concat_img_train, epoch)
                k += 1
                k = k % len(log_params['log_imgs'])
            logger.step_it()

        scheduler_encoder.step()
        logger.step_epoch(
            models={
                'skeleton_to_keypoints': model_skeleton_to_keypoint,
                'optimizer_skeleton_to_keypoints': optimizer_encoder
            })
Code example #20
File: train.py  Project: mangomadhava/donkey-net
def train(config,
          generator,
          discriminator,
          kp_detector,
          checkpoint,
          log_dir,
          dataset,
          device_ids,
          load_weights_only=False,
          use_both=False,
          update_kp=True):
    train_params = config['train_params']

    optimizer_generator = torch.optim.Adam(generator.parameters(),
                                           lr=train_params['lr'],
                                           betas=(0.5, 0.999))
    optimizer_discriminator = torch.optim.Adam(discriminator.parameters(),
                                               lr=train_params['lr'],
                                               betas=(0.5, 0.999))
    optimizer_kp_detector = torch.optim.Adam(kp_detector.parameters(),
                                             lr=train_params['lr'],
                                             betas=(0.5, 0.999))

    if checkpoint is not None:

        if load_weights_only:
            Logger.load_cpk(checkpoint, generator, discriminator, kp_detector)
            start_epoch = 0
            it = 0
        else:
            saved_start_epoch, saved_it = Logger.load_cpk(
                checkpoint, generator, discriminator, kp_detector,
                optimizer_generator, optimizer_discriminator,
                optimizer_kp_detector)
            start_epoch = saved_start_epoch
            it = saved_it
    else:
        start_epoch = 0
        it = 0

    scheduler_generator = MultiStepLR(optimizer_generator,
                                      train_params['epoch_milestones'],
                                      gamma=0.1,
                                      last_epoch=start_epoch - 1)
    scheduler_discriminator = MultiStepLR(optimizer_discriminator,
                                          train_params['epoch_milestones'],
                                          gamma=0.1,
                                          last_epoch=start_epoch - 1)
    scheduler_kp_detector = MultiStepLR(optimizer_kp_detector,
                                        train_params['epoch_milestones'],
                                        gamma=0.1,
                                        last_epoch=start_epoch - 1)

    dataloader = DataLoader(dataset,
                            batch_size=train_params['batch_size'],
                            shuffle=True,
                            num_workers=4,
                            drop_last=True)

    generator_full = GeneratorFullModel(kp_detector, generator, discriminator,
                                        train_params)
    discriminator_full = DiscriminatorFullModel(kp_detector, generator,
                                                discriminator, train_params)

    generator_full_par = DataParallelWithCallback(generator_full,
                                                  device_ids=device_ids)
    discriminator_full_par = DataParallelWithCallback(discriminator_full,
                                                      device_ids=device_ids)

    tb_dir = os.path.join(log_dir, 'tb_log')
    if not os.path.isdir(tb_dir):
        os.mkdir(tb_dir)

    writer = SummaryWriter(log_dir=tb_dir)

    with Logger(log_dir=log_dir,
                visualizer_params=config['visualizer_params'],
                **train_params['log_params']) as logger:
        for epoch in trange(start_epoch, train_params['num_epochs']):

            total_discriminator_loss = 0.
            total_generator_loss = 0.

            for i, x in enumerate(dataloader):

                # driving = driving_A, source = src_B --> driving_B'
                mn1_dict = {}
                mn1_dict['source'] = x['src_B']
                mn1_dict['video'] = x['driving_A']
                mn1_dict['gt_video'] = x['driving_B']
                # First model: use ground-truth keypoints for both
                if not use_both:
                    loss_values_B, loss_B, generated_B, gt_kp_joined_B = compute_loss(
                        generator_full_par, mn1_dict)
                else:
                    # Second model: give ground-truth and approximated keypoints respectively
                    out_B = compute_loss(generator_full_par,
                                         mn1_dict,
                                         use_both=True)
                    loss_values_B, loss_B, generated_B, approx_kp_joined_B, gt_kp_joined_B = out_B

                # driving = generated_B (driving_B'), source = src_A --> driving_A'
                mn2_dict = {}
                mn2_dict['source'] = x['src_A']
                mn2_dict['video'] = generated_B['video_prediction']
                mn2_dict['gt_video'] = x['driving_A']

                # First model - see above
                if not use_both:
                    loss_values_A, loss_A, generated_A, gt_kp_joined_A = compute_loss(
                        generator_full_par, mn2_dict)
                else:
                    # Second model - see above
                    out_A = compute_loss(generator_full_par,
                                         mn2_dict,
                                         use_both=True)
                    loss_values_A, loss_A, generated_A, approx_kp_joined_A, gt_kp_joined_A = out_A

                # NOTE: the reassignment below overrides the combined A+B loss,
                # so only loss_B is backpropagated and the running total tracks
                # loss_B alone
                loss = loss_B + loss_A
                total_generator_loss += loss

                loss = loss_B
                total_generator_loss = loss

                writer.add_scalar('generator loss B', loss_B.item(), it)
                writer.add_scalar('generator loss A', loss_A.item(), it)
                writer.add_scalar('generator loss', loss.item(), it)

                loss.backward(
                    retain_graph=not train_params['detach_kp_discriminator'])
                optimizer_generator.step()
                optimizer_generator.zero_grad()
                optimizer_discriminator.zero_grad()

                if train_params['detach_kp_discriminator'] and update_kp:
                    optimizer_kp_detector.step()
                    optimizer_kp_detector.zero_grad()

                generator_loss_values = {}
                generator_loss_values['A'] = [
                    val.detach().cpu().numpy() for val in loss_values_A
                ]
                generator_loss_values['B'] = [
                    val.detach().cpu().numpy() for val in loss_values_B
                ]

                if not use_both:
                    loss_values_B = discriminator_full_par(
                        mn1_dict, gt_kp_joined_B, generated_B)
                    loss_values_A = discriminator_full_par(
                        mn2_dict, gt_kp_joined_A, generated_A)
                else:
                    loss_values_B = discriminator_full_par(
                        mn1_dict, approx_kp_joined_B, generated_B,
                        gt_kp_joined_B)
                    loss_values_A = discriminator_full_par(
                        mn2_dict, approx_kp_joined_A, generated_A,
                        gt_kp_joined_A)

                loss_values_B = [val.mean() for val in loss_values_B]
                loss_values_A = [val.mean() for val in loss_values_A]

                loss_B = sum(loss_values_B)
                loss_A = sum(loss_values_A)

                # NOTE: as with the generator, the reassignment below overrides
                # the combined A+B loss; only loss_B is backpropagated
                loss = loss_A + loss_B
                total_discriminator_loss += loss

                loss = loss_B
                total_discriminator_loss = loss

                writer.add_scalar('disc loss B', loss_B.item(), it)
                writer.add_scalar('disc loss A', loss_A.item(), it)
                writer.add_scalar('disc loss', loss.item(), it)

                loss.backward()
                optimizer_discriminator.step()
                optimizer_discriminator.zero_grad()
                if not train_params['detach_kp_discriminator'] and update_kp:
                    optimizer_kp_detector.step()
                    optimizer_kp_detector.zero_grad()

                discriminator_loss_values = {}
                discriminator_loss_values['A'] = [
                    val.detach().cpu().numpy() for val in loss_values_A
                ]
                discriminator_loss_values['B'] = [
                    val.detach().cpu().numpy() for val in loss_values_B
                ]

                values = {
                    'A':
                    generator_loss_values['A'] +
                    discriminator_loss_values['A'],
                    'B':
                    generator_loss_values['B'] + discriminator_loss_values['B']
                }

                logger.log_iter(
                    it,
                    names=generator_loss_names(train_params['loss_weights']) +
                    discriminator_loss_names(),
                    values=values['B'],
                    inp=mn1_dict,
                    out=generated_B,
                    name='src_B_driving_A')
                logger.log_iter(
                    it,
                    names=generator_loss_names(train_params['loss_weights']) +
                    discriminator_loss_names(),
                    values=values['A'],
                    inp=mn2_dict,
                    out=generated_A,
                    name='src_A_driving_B')
                it += 1

            scheduler_generator.step()
            scheduler_discriminator.step()
            scheduler_kp_detector.step()

            writer.add_scalar('generator loss / train',
                              total_generator_loss / (i + 1), epoch)
            writer.add_scalar('discriminator loss / train',
                              total_discriminator_loss / (i + 1), epoch)

            logger.log_epoch(
                epoch, {
                    'generator': generator,
                    'discriminator': discriminator,
                    'kp_detector': kp_detector,
                    'optimizer_generator': optimizer_generator,
                    'optimizer_discriminator': optimizer_discriminator,
                    'optimizer_kp_detector': optimizer_kp_detector
                })
Code example #21
def train(config, generator, mask_generator, checkpoint, log_dir, dataset,
          device_ids):
    train_params = config['train_params']

    optimizer_generator = torch.optim.Adam(generator.parameters(),
                                           lr=train_params['lr_generator'],
                                           betas=(0.5, 0.999))
    optimizer_mask_generator = torch.optim.Adam(
        mask_generator.parameters(),
        lr=train_params['lr_mask_generator'],
        betas=(0.5, 0.999))

    if checkpoint is not None:
        print('loading cpk')
        start_epoch = Logger.load_cpk(
            checkpoint, generator, mask_generator, optimizer_generator,
            None if train_params['lr_mask_generator'] == 0 else
            optimizer_mask_generator)
    else:
        start_epoch = 0

    print(start_epoch)
    scheduler_generator = MultiStepLR(optimizer_generator,
                                      train_params['epoch_milestones'],
                                      gamma=0.1,
                                      last_epoch=start_epoch - 1)
    scheduler_mask_generator = MultiStepLR(
        optimizer_mask_generator,
        train_params['epoch_milestones'],
        gamma=0.1,
        last_epoch=-1 + start_epoch * (train_params['lr_mask_generator'] != 0))

    if 'num_repeats' in train_params and train_params['num_repeats'] != 1:
        dataset = DatasetRepeater(dataset, train_params['num_repeats'])
    dataloader = DataLoader(dataset,
                            batch_size=train_params['batch_size'],
                            shuffle=True,
                            num_workers=6,
                            drop_last=True)

    generator_full = GeneratorFullModel(mask_generator, generator,
                                        train_params)

    if torch.cuda.is_available():
        generator_full = DataParallelWithCallback(generator_full,
                                                  device_ids=device_ids)

    with Logger(log_dir=log_dir,
                visualizer_params=config['visualizer_params'],
                checkpoint_freq=train_params['checkpoint_freq']) as logger:
        for epoch in trange(start_epoch, train_params['num_epochs']):
            for index, x in enumerate(dataloader):
                predict_mask = epoch >= 1
                losses_generator, generated = generator_full(x, predict_mask)

                loss_values = [val.mean() for val in losses_generator.values()]
                loss = sum(loss_values)

                loss.backward()
                optimizer_generator.step()
                optimizer_generator.zero_grad()
                optimizer_mask_generator.step()
                optimizer_mask_generator.zero_grad()

                losses = {
                    key: value.mean().detach().data.cpu().numpy()
                    for key, value in losses_generator.items()
                }
                logger.log_iter(losses=losses)

            scheduler_generator.step()
            scheduler_mask_generator.step()

            logger.log_epoch(
                epoch, {
                    'generator': generator,
                    'mask_generator': mask_generator,
                    'optimizer_generator': optimizer_generator,
                    'optimizer_mask_generator': optimizer_mask_generator
                },
                inp=x,
                out=generated,
                save_w=True)
Code example #22
print(opt)

torch.cuda.set_device(opt.device_id[0])

# ######################## Module #################################
print('Building model')
model = actionModel(opt.class_num,
                    batch_norm=True,
                    dropout=opt.dropout,
                    q=opt.q,
                    image_size=opt.img_size,
                    syn_bn=opt.syn_bn,
                    test_scheme=2)
print(model)
if opt.syn_bn:
    model = DataParallelWithCallback(model, device_ids=opt.device_id).cuda()
else:
    model = torch.nn.DataParallel(model, device_ids=opt.device_id).cuda()
print("Channels: " + str(model.module.channels))

# ########################Optimizer#########################
optimizer = torch.optim.SGD([{
    'params': model.module.RNN.parameters(),
    'lr': opt.LR[0]
}, {
    'params': model.module.ShortCut.parameters(),
    'lr': opt.LR[0]
}, {
    'params': model.module.classifier.parameters(),
    'lr': opt.LR[1]
}], momentum=opt.momentum)  # hypothetical closing argument; the original snippet is truncated here
Code example #23
def train(config, generator, discriminator, kp_detector, checkpoint, log_dir,
          dataset, device_ids):
    train_params = config['train_params']

    optimizer_generator = torch.optim.Adam(generator.parameters(),
                                           lr=train_params['lr'],
                                           betas=(0.5, 0.999))
    optimizer_discriminator = torch.optim.Adam(discriminator.parameters(),
                                               lr=train_params['lr'],
                                               betas=(0.5, 0.999))
    optimizer_kp_detector = torch.optim.Adam(kp_detector.parameters(),
                                             lr=train_params['lr'],
                                             betas=(0.5, 0.999))

    if checkpoint is not None:
        start_epoch, it = Logger.load_cpk(checkpoint, generator, discriminator,
                                          kp_detector, optimizer_generator,
                                          optimizer_discriminator,
                                          optimizer_kp_detector)
    else:
        start_epoch = 0
        it = 0

    scheduler_generator = MultiStepLR(optimizer_generator,
                                      train_params['epoch_milestones'],
                                      gamma=0.1,
                                      last_epoch=start_epoch - 1)
    scheduler_discriminator = MultiStepLR(optimizer_discriminator,
                                          train_params['epoch_milestones'],
                                          gamma=0.1,
                                          last_epoch=start_epoch - 1)
    scheduler_kp_detector = MultiStepLR(optimizer_kp_detector,
                                        train_params['epoch_milestones'],
                                        gamma=0.1,
                                        last_epoch=start_epoch - 1)

    dataloader = DataLoader(dataset,
                            batch_size=train_params['batch_size'],
                            shuffle=True,
                            num_workers=4,
                            drop_last=True)

    generator_full = GeneratorFullModel(kp_detector, generator, discriminator,
                                        train_params)
    discriminator_full = DiscriminatorFullModel(kp_detector, generator,
                                                discriminator, train_params)

    generator_full_par = DataParallelWithCallback(generator_full,
                                                  device_ids=device_ids)
    discriminator_full_par = DataParallelWithCallback(discriminator_full,
                                                      device_ids=device_ids)

    with Logger(log_dir=log_dir,
                visualizer_params=config['visualizer_params'],
                **train_params['log_params']) as logger:
        for epoch in trange(start_epoch, train_params['num_epochs']):
            for x in dataloader:
                out = generator_full_par(x)
                loss_values = out[:-2]
                generated = out[-2]
                kp_joined = out[-1]
                loss_values = [val.mean() for val in loss_values]
                loss = sum(loss_values)

                loss.backward(
                    retain_graph=not train_params['detach_kp_discriminator'])
                optimizer_generator.step()
                optimizer_generator.zero_grad()
                optimizer_discriminator.zero_grad()
                if train_params['detach_kp_discriminator']:
                    optimizer_kp_detector.step()
                    optimizer_kp_detector.zero_grad()

                generator_loss_values = [
                    val.detach().cpu().numpy() for val in loss_values
                ]

                loss_values = discriminator_full_par(x, kp_joined, generated)
                loss_values = [val.mean() for val in loss_values]
                loss = sum(loss_values)

                loss.backward()
                optimizer_discriminator.step()
                optimizer_discriminator.zero_grad()
                if not train_params['detach_kp_discriminator']:
                    optimizer_kp_detector.step()
                    optimizer_kp_detector.zero_grad()

                discriminator_loss_values = [
                    val.detach().cpu().numpy() for val in loss_values
                ]

                logger.log_iter(
                    it,
                    names=generator_loss_names(train_params['loss_weights']) +
                    discriminator_loss_names(),
                    values=generator_loss_values + discriminator_loss_values,
                    inp=x,
                    out=generated)
                it += 1

            scheduler_generator.step()
            scheduler_discriminator.step()
            scheduler_kp_detector.step()

            logger.log_epoch(
                epoch, {
                    'generator': generator,
                    'discriminator': discriminator,
                    'kp_detector': kp_detector,
                    'optimizer_generator': optimizer_generator,
                    'optimizer_discriminator': optimizer_discriminator,
                    'optimizer_kp_detector': optimizer_kp_detector
                })
Code example #24
def train(config, generator, discriminator, kp_detector, checkpoint, log_dir, dataset, device_ids):
    # Refer to *.yaml, "train_params" section.
    # This includes epoch numbers, etc.
    train_params = config['train_params']

    # Define the optimizers for three sub-networks
    # Refer to the Adam() documentation for details
    optimizer_generator = torch.optim.Adam(generator.parameters(), lr=train_params['lr_generator'], betas=(0.5, 0.999))
    optimizer_discriminator = torch.optim.Adam(discriminator.parameters(), lr=train_params['lr_discriminator'], betas=(0.5, 0.999))
    optimizer_kp_detector = torch.optim.Adam(kp_detector.parameters(), lr=train_params['lr_kp_detector'], betas=(0.5, 0.999))

    if checkpoint is not None:
        # Load pretrained models if a checkpoint is given
        # The models passed in start freshly initialized; their weights are loaded inside the call below
        start_epoch = Logger.load_cpk(checkpoint, generator, discriminator, kp_detector,
                                      optimizer_generator, optimizer_discriminator,
                                      None if train_params['lr_kp_detector'] == 0 else optimizer_kp_detector)
    else:
        start_epoch = 0

    # TODO: not sure about these; they seem to be schedulers controlling the learning-rate decay during training
    scheduler_generator = MultiStepLR(optimizer_generator, train_params['epoch_milestones'], gamma=0.1,
                                      last_epoch=start_epoch - 1)
    scheduler_discriminator = MultiStepLR(optimizer_discriminator, train_params['epoch_milestones'], gamma=0.1,
                                          last_epoch=start_epoch - 1)
    scheduler_kp_detector = MultiStepLR(optimizer_kp_detector, train_params['epoch_milestones'], gamma=0.1,
                                        last_epoch=-1 + start_epoch * (train_params['lr_kp_detector'] != 0))

    if 'num_repeats' in train_params and train_params['num_repeats'] != 1:
        # Repeat the dataset according to "num_repeats"
        dataset = DatasetRepeater(dataset, train_params['num_repeats'])
    # Load the data in a form the network can consume
    # Refer to the PyTorch DataLoader docs for details
    # Here dataset is a FramesDataset, a subclass of Dataset, so it can be wrapped as below
    dataloader = DataLoader(dataset, batch_size=train_params['batch_size'], shuffle=True, num_workers=2, drop_last=True)

    # Initialize two models for training
    # TODO: read the construction of the generator and discriminator; the keypoint-detector part should be contained in the generator
    generator_full = GeneratorFullModel(kp_detector, generator, discriminator, train_params)
    # TODO: read the discriminator; note that the full generator model above also contains a discriminator, so figure out the difference between the two
    discriminator_full = DiscriminatorFullModel(kp_detector, generator, discriminator, train_params)

    # Move the models to the GPU if available
    if torch.cuda.is_available():
        generator_full = DataParallelWithCallback(generator_full, device_ids=device_ids)
        discriminator_full = DataParallelWithCallback(discriminator_full, device_ids=device_ids)

    with Logger(log_dir=log_dir, visualizer_params=config['visualizer_params'], checkpoint_freq=train_params['checkpoint_freq']) as logger:
        for epoch in trange(start_epoch, train_params['num_epochs']):
            for x in dataloader:
                # Forward pass: the first return value holds the losses, the second the generated images
                losses_generator, generated = generator_full(x)

                # Several kinds of losses are computed; take the mean of each and sum them
                loss_values = [val.mean() for val in losses_generator.values()]
                loss = sum(loss_values)

                # Step each sub-network's optimizer separately
                loss.backward()
                optimizer_generator.step()
                optimizer_generator.zero_grad()
                optimizer_kp_detector.step()
                optimizer_kp_detector.zero_grad()

                # Decide whether to use the GAN training scheme
                if train_params['loss_weights']['generator_gan'] != 0:
                    # Bring the discriminator into play
                    optimizer_discriminator.zero_grad()
                    # Run the discriminator on the generated and source data
                    losses_discriminator = discriminator_full(x, generated)
                    loss_values = [val.mean() for val in losses_discriminator.values()]
                    loss = sum(loss_values)

                    # Update the discriminator
                    loss.backward()
                    optimizer_discriminator.step()
                    optimizer_discriminator.zero_grad()
                else:
                    losses_discriminator = {}

                # Note: update here is Python's built-in dict update method
                losses_generator.update(losses_discriminator)
                losses = {key: value.mean().detach().data.cpu().numpy() for key, value in losses_generator.items()}
                logger.log_iter(losses=losses)

            # Work for one epoch is done at this point
            # TODO: these are the objects whose type was unclear earlier; presumably the training schedulers being stepped
            scheduler_generator.step()
            scheduler_discriminator.step()
            scheduler_kp_detector.step()
            
            logger.log_epoch(epoch, {'generator': generator,
                                     'discriminator': discriminator,
                                     'kp_detector': kp_detector,
                                     'optimizer_generator': optimizer_generator,
                                     'optimizer_discriminator': optimizer_discriminator,
                                     'optimizer_kp_detector': optimizer_kp_detector}, inp=x, out=generated)
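
For reference, the train_params block this function consumes carries at least the keys accessed above; a hedged sketch as a Python dict, with illustrative values rather than ones taken from a real config:

train_params = {
    'num_epochs': 100,
    'num_repeats': 75,
    'epoch_milestones': [60, 90],
    'lr_generator': 2.0e-4,
    'lr_discriminator': 2.0e-4,
    'lr_kp_detector': 2.0e-4,
    'batch_size': 20,
    'checkpoint_freq': 50,
    'loss_weights': {'generator_gan': 1},
}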
Code example #25
def prediction(config, generator, kp_detector, checkpoint, log_dir):
    dataset = FramesDataset(is_train=True, transform=VideoToTensor(), **config['dataset_params'])
    log_dir = os.path.join(log_dir, 'prediction')
    png_dir = os.path.join(log_dir, 'png')

    if checkpoint is not None:
        Logger.load_cpk(checkpoint, generator=generator, kp_detector=kp_detector)
    else:
        raise AttributeError("Checkpoint should be specified for mode='prediction'.")
    dataloader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=1)

    generator = DataParallelWithCallback(generator)
    kp_detector = DataParallelWithCallback(kp_detector)

    if not os.path.exists(log_dir):
        os.makedirs(log_dir)

    if not os.path.exists(png_dir):
        os.makedirs(png_dir)

    print("Extracting keypoints...")

    kp_detector.eval()
    generator.eval()

    keypoints_array = []

    prediction_params = config['prediction_params']

    for it, x in tqdm(enumerate(dataloader)):
        if prediction_params['train_size'] is not None:
            if it > prediction_params['train_size']:
                break
        with torch.no_grad():
            keypoints = []
            for i in range(x['video'].shape[2]):
                kp = kp_detector(x['video'][:, :, i:(i + 1)])
                kp = {k: v.data.cpu().numpy() for k, v in kp.items()}
                keypoints.append(kp)
            keypoints_array.append(keypoints)

    predictor = PredictionModule(num_kp=config['model_params']['common_params']['num_kp'],
                                 kp_variance=config['model_params']['common_params']['kp_variance'],
                                 **prediction_params['rnn_params']).cuda()

    num_epochs = prediction_params['num_epochs']
    lr = prediction_params['lr']
    bs = prediction_params['batch_size']
    num_frames = prediction_params['num_frames']
    init_frames = prediction_params['init_frames']

    optimizer = torch.optim.Adam(predictor.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, verbose=True, patience=50)

    kp_dataset = KPDataset(keypoints_array, num_frames=num_frames)

    kp_dataloader = DataLoader(kp_dataset, batch_size=bs)

    print("Training prediction...")
    for _ in trange(num_epochs):
        loss_list = []
        for x in kp_dataloader:
            x = {k: v.cuda() for k, v in x.items()}
            gt = {k: v.clone() for k, v in x.items()}
            for k in x:
                x[k][:, init_frames:] = 0
            prediction = predictor(x)

            loss = sum([torch.abs(gt[k][:, init_frames:] - prediction[k][:, init_frames:]).mean() for k in x])

            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            loss_list.append(loss.detach().data.cpu().numpy())

        loss = np.mean(loss_list)
        scheduler.step(loss)

    dataset = FramesDataset(is_train=False, transform=VideoToTensor(), **config['dataset_params'])
    dataloader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=1)

    print("Make predictions...")
    for it, x in tqdm(enumerate(dataloader)):
        with torch.no_grad():
            x['video'] = x['video'][:, :, :num_frames]
            kp_init = kp_detector(x['video'])
            for k in kp_init:
                kp_init[k][:, init_frames:] = 0

            kp_source = kp_detector(x['video'][:, :, :1])

            kp_video = predictor(kp_init)
            for k in kp_video:
                kp_video[k][:, :init_frames] = kp_init[k][:, :init_frames]
            if 'var' in kp_video and prediction_params['predict_variance']:
                kp_video['var'] = kp_init['var'][:, (init_frames - 1):init_frames].repeat(1, kp_video['var'].shape[1],
                                                                                          1, 1, 1)
            out = generate(generator, appearance_image=x['video'][:, :, :1], kp_appearance=kp_source,
                           kp_video=kp_video)

            x['source'] = x['video'][:, :, :1]

            out_video_batch = out['video_prediction'].data.cpu().numpy()
            out_video_batch = np.concatenate(np.transpose(out_video_batch, [0, 2, 3, 4, 1])[0], axis=1)
            imageio.imsave(os.path.join(png_dir, x['name'][0] + '.png'), (255 * out_video_batch).astype(np.uint8))

            image = Visualizer(**config['visualizer_params']).visualize_reconstruction(x, out)
            image_name = x['name'][0] + prediction_params['format']
            imageio.mimsave(os.path.join(log_dir, image_name), image)

            del x, kp_video, kp_source, out
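
The predictor training above is essentially inpainting in keypoint space: the ground truth is cloned, every frame after init_frames is zeroed in the input, and the L1 loss only scores the frames the model had to fill in. A self-contained sketch of that masking; the key name 'mean' and all shapes are assumptions rather than the real KPDataset layout:

import torch

# Toy stand-in for one KPDataset batch: 8 clips, 16 frames, 10 2-D keypoints.
x = {'mean': torch.randn(8, 16, 10, 2)}
init_frames = 4

gt = {k: v.clone() for k, v in x.items()}  # keep the full trajectories
for k in x:
    x[k][:, init_frames:] = 0              # hide everything after the seed frames

# Any sequence model could stand in for PredictionModule here.
prediction = {k: torch.randn_like(v) for k, v in gt.items()}

# Only the predicted frames contribute to the loss.
loss = sum(torch.abs(gt[k][:, init_frames:] - prediction[k][:, init_frames:]).mean()
           for k in x)
print(loss.item())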
Code example #26
File: animate.py Project: abhinav-TB/Pic2Vedio
def animate(config, generator, kp_detector, checkpoint, log_dir, dataset):
    log_dir = os.path.join(log_dir, 'animation')
    png_dir = os.path.join(log_dir, 'png')
    animate_params = config['animate_params']

    dataset = PairedDataset(initial_dataset=dataset, number_of_pairs=animate_params['num_pairs'])
    dataloader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=1)

    if checkpoint is not None:
        Logger.load_cpk(checkpoint, generator=generator, kp_detector=kp_detector)
    else:
        raise AttributeError("Checkpoint should be specified for mode='animate'.")

    if not os.path.exists(log_dir):
        os.makedirs(log_dir)

    if not os.path.exists(png_dir):
        os.makedirs(png_dir)

    if torch.cuda.is_available():
        generator = DataParallelWithCallback(generator)
        kp_detector = DataParallelWithCallback(kp_detector)

    generator.eval()
    kp_detector.eval()

    for it, x in tqdm(enumerate(dataloader)):
        with torch.no_grad():
            predictions = []
            visualizations = []

            driving_video = x['driving_video']
            source_frame = x['source_video'][:, :, 0, :, :]

            kp_source = kp_detector(source_frame)
            kp_driving_initial = kp_detector(driving_video[:, :, 0])

            for frame_idx in range(driving_video.shape[2]):
                driving_frame = driving_video[:, :, frame_idx]
                kp_driving = kp_detector(driving_frame)
                kp_norm = normalize_kp(kp_source=kp_source, kp_driving=kp_driving,
                                       kp_driving_initial=kp_driving_initial, **animate_params['normalization_params'])
                out = generator(source_frame, kp_source=kp_source, kp_driving=kp_norm)

                out['kp_driving'] = kp_driving
                out['kp_source'] = kp_source
                out['kp_norm'] = kp_norm

                del out['sparse_deformed']

                predictions.append(np.transpose(out['prediction'].data.cpu().numpy(), [0, 2, 3, 1])[0])

                visualization = Visualizer(**config['visualizer_params']).visualize(source=source_frame,
                                                                                    driving=driving_frame, out=out)
                visualizations.append(visualization)

            predictions = np.concatenate(predictions, axis=1)
            result_name = "-".join([x['driving_name'][0], x['source_name'][0]])
            imageio.imsave(os.path.join(png_dir, result_name + '.png'), (255 * predictions).astype(np.uint8))

            image_name = result_name + animate_params['format']
            imageio.mimsave(os.path.join(log_dir, image_name), visualizations)
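
normalize_kp is what makes this animation relative: rather than copying the driving keypoints verbatim, it displaces the source keypoints by the driving motion measured against the first driving frame. A simplified sketch of the idea; the real FOMM function additionally handles Jacobians and automatic movement-scale adaptation:

import torch

def normalize_kp_relative(kp_source, kp_driving, kp_driving_initial, scale=1.0):
    # Move the source keypoints by the driving displacement since frame 0.
    out = dict(kp_driving)
    diff = (kp_driving['value'] - kp_driving_initial['value']) * scale
    out['value'] = kp_source['value'] + diff
    return out

kp_source = {'value': torch.rand(1, 10, 2) * 2 - 1}  # keypoints live in [-1, 1]
kp_initial = {'value': torch.rand(1, 10, 2) * 2 - 1}
kp_driving = {'value': kp_initial['value'] + 0.05}   # the driving face moved a bit

kp_norm = normalize_kp_relative(kp_source, kp_driving, kp_initial)
print(torch.allclose(kp_norm['value'], kp_source['value'] + 0.05))  # True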
Code example #27
File: main.py Project: wuziniu/Fixed_Kernel_CNN
                                  batch_size=max(param_grid['batch_size']))
    criterion = nn.MSELoss()
    best_valid_RMSE = np.full(1, np.inf)
    for grid in ParameterGrid(param_grid):
        print(f"===> Hyper-parameters = {grid}:")
        # random.seed(seed)
        # numpy.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        importlib.reload(model)
        net = getattr(model,
                      model_name.upper())(input_channel=input_channel,
                                          input_size=input_size,
                                          output_size=output_size).to(device)
        if use_cuda:
            net = DataParallelWithCallback(net)
        print(f"===> Model:\n{list(net.modules())[0]}")
        print_param(net)
        if model_name == 'cnn0' or model_name == 'cnn1':
            optimizer = Adam(net.parameters(),
                             lr=grid['lr'],
                             l1=grid['l1'],
                             weight_decay=grid['l2'],
                             amsgrad=True)
        elif model_name == 'vgg19' or model_name == 'cnn2' or model_name == 'cnn3':
            optimizer = Adam([{
                'params':
                iter(param for name, param in net.named_parameters()
                     if 'channel_mask' in name),
                'l1':
                grid['l1_channel']
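
The snippet above is truncated at both ends, but the search it implements is plain grid search: sklearn's ParameterGrid expands a dict of candidate values into the Cartesian product of settings, which is what the `for grid in ParameterGrid(param_grid)` loop iterates over. A runnable illustration with made-up grid values (note also that the l1 argument passed to Adam above implies a project-local Adam variant, since stock torch.optim.Adam accepts no such parameter):

from sklearn.model_selection import ParameterGrid

param_grid = {'lr': [1e-3, 1e-4],
              'l1': [0.0, 1e-5],
              'l2': [0.0, 1e-4],
              'batch_size': [32, 64]}

# 2 * 2 * 2 * 2 = 16 combinations, yielded in a deterministic order.
for grid in ParameterGrid(param_grid):
    print(grid)  # e.g. {'batch_size': 32, 'l1': 0.0, 'l2': 0.0, 'lr': 0.001}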
Code example #28
def train_kpdetector(model_kp_detector,
                       loader,
                       loader_tgt,
                       train_params,
                       checkpoint,
                       logger, device_ids, tgt_batch=None, kp_map=None):
    log_params = train_params['log_params']
    optimizer_kp_detector = torch.optim.Adam(model_kp_detector.parameters(),
                                            lr=train_params['lr'],
                                            betas=train_params['betas'])
    resume_epoch = 0
    resume_iteration = 0
    if checkpoint is not None:
        print('Loading Checkpoint: %s' % checkpoint)
        # TODO: implement load/resume for kp_detector
        if not train_params['test']:
            resume_epoch, resume_iteration = logger.checkpoint.load_checkpoint(checkpoint,
                                                  model_kp_detector=model_kp_detector,
                                                  optimizer_kp_detector=optimizer_kp_detector)
        else:
            net_dict = model_kp_detector.state_dict()
            pretrained_dict = torch.load(checkpoint)
            pretrained_dict = {k: v for k, v in pretrained_dict.items() if (k in net_dict)}
            pretrained_dict = {k: v for k, v in pretrained_dict.items() if pretrained_dict[k].shape == net_dict[k].shape}
            net_dict.update(pretrained_dict)
            model_kp_detector.load_state_dict(net_dict, strict=True)
            model_kp_detector.apply(convertLayer)
 
        logger.epoch = resume_epoch
        logger.iterations = resume_iteration
    scheduler_kp_detector = MultiStepLR(optimizer_kp_detector, 
                                       train_params['epoch_milestones'], 
                                       gamma=0.1, last_epoch=logger.epoch-1)
    
    kp_detector = KPDetectorTrainer(model_kp_detector)
    kp_detector = DataParallelWithCallback(kp_detector, device_ids=device_ids)
    k = 0
    if train_params['test']:
        results = evaluate(model_kp_detector, loader_tgt, dset=train_params['dataset'])
        print(' MSE: ' + str(results['MSE']) + ' PCK: ' + str(results['PCK'])) 
        return

    heatmap_var = train_params['heatmap_var']

    for epoch in range(logger.epoch, train_params['num_epochs']):
        results = evaluate(model_kp_detector, loader_tgt, dset=train_params['dataset'])
        results_train = evaluate(model_kp_detector, loader, dset=train_params['dataset']) 
        print('Epoch ' + str(epoch) + ' MSE: ' + str(results['MSE']))
        logger.add_scalar('MSE test', results['MSE'], epoch)
        logger.add_scalar('PCK test', results['PCK'], epoch)
        logger.add_scalar('MSE train', results_train['MSE'], epoch)
        logger.add_scalar('PCK train', results_train['PCK'], epoch)
 
        for i, batch  in enumerate(tqdm(loader)):
            images = batch['imgs']
            if torch.isnan(images).any():
                print('Images contain NaN')
                break
            annots = batch['annots']
            if torch.isnan(annots).any() or torch.isinf(annots).any():
                print('Annotations contain NaN or Inf')
                break
            gt_heatmaps = kp2gaussian2(annots, (model_kp_detector.heatmap_res,
                                                model_kp_detector.heatmap_res), heatmap_var).detach()
            mask = None if 'kp_mask' not in batch.keys() else batch['kp_mask']
            kp_detector_out = kp_detector(images, gt_heatmaps, mask)

            loss = kp_detector_out['l2_loss'].mean()
            loss.backward()

            optimizer_kp_detector.step()
            optimizer_kp_detector.zero_grad()
            ####### LOG VALIDATION
            if i % log_params['eval_frequency'] == 0:
                tgt_batch = next(iter(loader_tgt))
                eval_out = eval_model(kp_detector, tgt_batch, model_kp_detector.heatmap_res, heatmap_var)
                eval_sz = int(len(loader) / log_params['eval_frequency'])
                it_number = epoch * eval_sz + logger.iterations / log_params['eval_frequency']
                logger.add_scalar('Eval loss', eval_out['l2_loss'].mean(), it_number)
                concat_img = np.concatenate((draw_kp(tensor_to_image(tgt_batch['imgs'][k]),unnorm_kp(tgt_batch['annots'][k])),
                                            draw_kp(tensor_to_image(tgt_batch['imgs'][k]), eval_out['keypoints'][k], color='red')), axis=2)

                heatmap_img_0 = tensor_to_image(kp_detector_out['heatmaps'][k, 0].unsqueeze(0), True)
                heatmap_img_1 = tensor_to_image(kp_detector_out['heatmaps'][k, 5].unsqueeze(0), True)
                src_heatmap_0 = tensor_to_image(gt_heatmaps[k, 0].unsqueeze(0), True)
                src_heatmap_1 = tensor_to_image(gt_heatmaps[k, 5].unsqueeze(0), True)
                heatmaps_img = np.concatenate((heatmap_img_0, heatmap_img_1), axis = 2)
                src_heatmaps = np.concatenate((src_heatmap_0, src_heatmap_1), axis = 2)
 
                logger.add_image('Eval_', concat_img, logger.iterations)
                logger.add_image('heatmaps', heatmaps_img, logger.iterations)
                logger.add_image('src heatmaps', src_heatmaps, logger.iterations)
 
            ####### LOG
            logger.add_scalar('L2 loss', 
                               loss.item(), 
                               logger.iterations)
            if i in log_params['log_imgs']:
                concat_img_train = np.concatenate((draw_kp(tensor_to_image(images[k]), unnorm_kp(annots[k])),
                                                  draw_kp(tensor_to_image(images[k]), kp_detector_out['keypoints'][k], color='red')), axis=2)
 
                logger.add_image('Train_{%d}' % i, concat_img_train, logger.iterations)
                k += 1
                k = k % len(log_params['log_imgs']) 
            logger.step_it()

        scheduler_kp_detector.step()
        logger.step_epoch(models = {'model_kp_detector':model_kp_detector,
                                    'optimizer_kp_detector':optimizer_kp_detector})
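
kp2gaussian2 turns the annotated keypoint coordinates into the Gaussian target heatmaps the detector regresses against. A sketch of that rendering under the usual conventions, keypoints in [-1, 1] and one isotropic Gaussian per keypoint; the exact normalization of the original may differ:

import torch

def kp_to_gaussian(kp, spatial_size, variance):
    # kp: (batch, num_kp, 2) in (x, y) order, coordinates in [-1, 1].
    h, w = spatial_size
    ys = torch.linspace(-1, 1, h).reshape(1, 1, h, 1)
    xs = torch.linspace(-1, 1, w).reshape(1, 1, 1, w)
    kp_x = kp[..., 0].reshape(*kp.shape[:2], 1, 1)
    kp_y = kp[..., 1].reshape(*kp.shape[:2], 1, 1)
    dist = (xs - kp_x) ** 2 + (ys - kp_y) ** 2
    return torch.exp(-0.5 * dist / variance)  # (batch, num_kp, h, w)

heatmaps = kp_to_gaussian(torch.rand(2, 10, 2) * 2 - 1, (64, 64), 0.01)
print(heatmaps.shape)  # torch.Size([2, 10, 64, 64])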
Code example #29
def train(config, generator, region_predictor, bg_predictor, checkpoint,
          log_dir, dataset, device_ids):
    train_params = config['train_params']

    optimizer = torch.optim.Adam(list(generator.parameters()) +
                                 list(region_predictor.parameters()) +
                                 list(bg_predictor.parameters()),
                                 lr=train_params['lr'],
                                 betas=(0.5, 0.999))

    if checkpoint is not None:
        start_epoch = Logger.load_cpk(checkpoint, generator, region_predictor,
                                      bg_predictor, None, optimizer, None)
    else:
        start_epoch = 0

    scheduler = MultiStepLR(optimizer,
                            train_params['epoch_milestones'],
                            gamma=0.1,
                            last_epoch=start_epoch - 1)
    if 'num_repeats' in train_params and train_params['num_repeats'] != 1:
        dataset = DatasetRepeater(dataset, train_params['num_repeats'])

    dataloader = DataLoader(dataset,
                            batch_size=train_params['batch_size'],
                            shuffle=True,
                            num_workers=train_params['dataloader_workers'],
                            drop_last=True)

    model = ReconstructionModel(region_predictor, bg_predictor, generator,
                                train_params)

    if torch.cuda.is_available():
        if ('use_sync_bn' in train_params) and train_params['use_sync_bn']:
            model = DataParallelWithCallback(model, device_ids=device_ids)
        else:
            model = torch.nn.DataParallel(model, device_ids=device_ids)

    with Logger(log_dir=log_dir,
                visualizer_params=config['visualizer_params'],
                checkpoint_freq=train_params['checkpoint_freq']) as logger:
        for epoch in trange(start_epoch, train_params['num_epochs']):
            for x in dataloader:
                losses, generated = model(x)
                loss_values = [val.mean() for val in losses.values()]
                loss = sum(loss_values)
                loss.backward()
                optimizer.step()
                optimizer.zero_grad()
                losses = {
                    key: value.mean().detach().data.cpu().numpy()
                    for key, value in losses.items()
                }
                logger.log_iter(losses=losses)

            scheduler.step()
            logger.log_epoch(epoch, {
                'generator': generator,
                'bg_predictor': bg_predictor,
                'region_predictor': region_predictor,
                'optimizer_reconstruction': optimizer
            },
                             inp=x,
                             out=generated)
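
DatasetRepeater, used above to honor num_repeats, is just a length-multiplying wrapper: one DataLoader "epoch" then makes several passes over a small dataset without restarting workers or the progress bar. A sketch of the idea, not necessarily the original implementation:

from torch.utils.data import Dataset

class DatasetRepeater(Dataset):
    """Pretend the wrapped dataset is num_repeats times longer."""

    def __init__(self, dataset, num_repeats=100):
        self.dataset = dataset
        self.num_repeats = num_repeats

    def __len__(self):
        return self.num_repeats * len(self.dataset)

    def __getitem__(self, idx):
        return self.dataset[idx % len(self.dataset)]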
Code example #30
def reconstruction(config, generator, mask_generator, checkpoint, log_dir,
                   dataset):
    png_dir = os.path.join(log_dir, 'reconstruction/png')
    log_dir = os.path.join(log_dir, 'reconstruction')

    if checkpoint is not None:
        epoch = Logger.load_cpk(checkpoint,
                                generator=generator,
                                mask_generator=mask_generator)
        print('checkpoint:' + str(epoch))
    else:
        raise AttributeError(
            "Checkpoint should be specified for mode='reconstruction'.")
    dataloader = DataLoader(dataset,
                            batch_size=1,
                            shuffle=False,
                            num_workers=1)

    if not os.path.exists(log_dir):
        os.makedirs(log_dir)

    if not os.path.exists(png_dir):
        os.makedirs(png_dir)

    loss_list = []
    if torch.cuda.is_available():
        generator = DataParallelWithCallback(generator)
        mask_generator = DataParallelWithCallback(mask_generator)

    generator.eval()
    mask_generator.eval()

    recon_gen_dir = './log/recon_gen'
    os.makedirs(recon_gen_dir, exist_ok=False)

    for it, x in tqdm(enumerate(dataloader)):
        if config['reconstruction_params']['num_videos'] is not None:
            if it > config['reconstruction_params']['num_videos']:
                break
        with torch.no_grad():
            predictions = []
            visualizations = []
            if torch.cuda.is_available():
                x['video'] = x['video'].cuda()
            mask_source = mask_generator(x['video'][:, :, 0])

            video_gen_dir = recon_gen_dir + '/' + x['name'][0]
            os.makedirs(video_gen_dir, exist_ok=False)

            for frame_idx in range(x['video'].shape[2]):
                source = x['video'][:, :, 0]
                driving = x['video'][:, :, frame_idx]
                mask_driving = mask_generator(driving)
                out = generator(source,
                                driving,
                                mask_source=mask_source,
                                mask_driving=mask_driving,
                                mask_driving2=None,
                                animate=False,
                                predict_mask=False)
                out['mask_source'] = mask_source
                out['mask_driving'] = mask_driving

                predictions.append(
                    np.transpose(
                        out['second_phase_prediction'].data.cpu().numpy(),
                        [0, 2, 3, 1])[0])

                visualization = Visualizer(
                    **config['visualizer_params']).visualize(source=source,
                                                             driving=driving,
                                                             target=None,
                                                             out=out,
                                                             driving2=None)
                visualizations.append(visualization)

                loss_list.append(
                    torch.abs(out['second_phase_prediction'] -
                              driving).mean().cpu().numpy())

                frame_name = str(frame_idx).zfill(7) + '.png'
                second_phase_prediction = out[
                    'second_phase_prediction'].data.cpu().numpy()
                second_phase_prediction = np.transpose(second_phase_prediction,
                                                       [0, 2, 3, 1])
                second_phase_prediction = (255 *
                                           second_phase_prediction).astype(
                                               np.uint8)
                imageio.imsave(os.path.join(video_gen_dir, frame_name),
                               second_phase_prediction[0])

            predictions = np.concatenate(predictions, axis=1)

            image_name = x['name'][0] + config['reconstruction_params'][
                'format']
            imageio.mimsave(os.path.join(log_dir, image_name), visualizations)

    print("Reconstruction loss: %s" % np.mean(loss_list))