Example #1
def transfer(config, generator, kp_detector, checkpoint, log_dir, dataset):
    log_dir = os.path.join(log_dir, 'transfer')
    png_dir = os.path.join(log_dir, 'png')
    transfer_params = config['transfer_params']

    dataset = PairedDataset(initial_dataset=dataset,
                            number_of_pairs=transfer_params['num_pairs'])
    dataloader = DataLoader(dataset,
                            batch_size=1,
                            shuffle=False,
                            num_workers=1)

    if checkpoint is not None:
        Logger.load_cpk(checkpoint,
                        generator=generator,
                        kp_detector=kp_detector)
    else:
        raise AttributeError(
            "Checkpoint should be specified for mode='transfer'.")

    if not os.path.exists(log_dir):
        os.makedirs(log_dir)

    if not os.path.exists(png_dir):
        os.makedirs(png_dir)

    generator = DataParallelWithCallback(generator)
    kp_detector = DataParallelWithCallback(kp_detector)

    generator.eval()
    kp_detector.eval()

    for it, x in tqdm(enumerate(dataloader)):
        with torch.no_grad():
            x = {
                key: value if not hasattr(value, 'cuda') else value.cuda()
                for key, value in x.items()
            }
            driving_video = x['driving_video']
            source_image = x['source_video'][:, :, :1, :, :]
            out = transfer_one(generator, kp_detector, source_image,
                               driving_video, transfer_params)
            img_name = "-".join([x['driving_name'][0], x['source_name'][0]])

            # Store to .png for evaluation
            out_video_batch = out['video_prediction'].data.cpu().numpy()
            out_video_batch = np.concatenate(np.transpose(
                out_video_batch, [0, 2, 3, 4, 1])[0],
                                             axis=1)
            imageio.imsave(os.path.join(png_dir, img_name + '.png'),
                           (255 * out_video_batch).astype(np.uint8))

            image = Visualizer(
                **config['visualizer_params']).visualize_transfer(
                    driving_video=driving_video,
                    source_image=source_image,
                    out=out)
            imageio.mimsave(
                os.path.join(log_dir, img_name + transfer_params['format']),
                image)
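A hedged driver sketch for the transfer() example above: it shows one plausible way the config, models and dataset could be assembled before calling it. The Generator, KPDetector and FramesDataset constructors, the YAML layout and the file paths are assumptions, not part of the example.

import yaml
import torch

with open('config/dataset.yaml') as f:          # hypothetical config path
    config = yaml.safe_load(f)

generator = Generator(**config['model_params']['generator_params'])        # assumed constructor
kp_detector = KPDetector(**config['model_params']['kp_detector_params'])   # assumed constructor
if torch.cuda.is_available():
    generator.cuda()
    kp_detector.cuda()

dataset = FramesDataset(is_train=False, **config['dataset_params'])        # assumed dataset class
transfer(config, generator, kp_detector,
         checkpoint='log/checkpoint.pth.tar',   # transfer() requires a checkpoint
         log_dir='log', dataset=dataset)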
Example #2
def reconstruction(config, generator, kp_detector, checkpoint, log_dir, dataset):
    png_dir = os.path.join(log_dir, 'reconstruction/png')
    log_dir = os.path.join(log_dir, 'reconstruction')

    if checkpoint is not None:
        Logger.load_cpk(checkpoint, generator=generator, kp_detector=kp_detector)
    else:
        raise AttributeError("Checkpoint should be specified for mode='reconstruction'.")
    dataloader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=1)

    if not os.path.exists(log_dir):
        os.makedirs(log_dir)

    if not os.path.exists(png_dir):
        os.makedirs(png_dir)

    loss_list = []
    if torch.cuda.is_available():
        generator = DataParallelWithCallback(generator)
        kp_detector = DataParallelWithCallback(kp_detector)

    generator.eval()
    kp_detector.eval()

    for it, x in tqdm(enumerate(dataloader)):
        if config['reconstruction_params']['num_videos'] is not None:
            if it > config['reconstruction_params']['num_videos']:
                break
        with torch.no_grad():
            predictions = []
            visualizations = []
            if torch.cuda.is_available():
                x['video'] = x['video'].cuda()
            kp_source = kp_detector(x['video'][:, :, 0])
            for frame_idx in range(x['video'].shape[2]):
                source = x['video'][:, :, 0]
                driving = x['video'][:, :, frame_idx]
                kp_driving = kp_detector(driving)
                out = generator(source, kp_source=kp_source, kp_driving=kp_driving)
                out['kp_source'] = kp_source
                out['kp_driving'] = kp_driving
                del out['sparse_deformed']
                predictions.append(np.transpose(out['prediction'].data.cpu().numpy(), [0, 2, 3, 1])[0])

                visualization = Visualizer(**config['visualizer_params']).visualize(source=source,
                                                                                    driving=driving, out=out)
                visualizations.append(visualization)

                loss_list.append(torch.abs(out['prediction'] - driving).mean().cpu().numpy())

            predictions = np.concatenate(predictions, axis=1)
            imageio.imsave(os.path.join(png_dir, x['name'][0] + '.png'), (255 * predictions).astype(np.uint8))

            image_name = x['name'][0] + config['reconstruction_params']['format']
            imageio.mimsave(os.path.join(log_dir, image_name), visualizations)

    print("Reconstruction loss: %s" % np.mean(loss_list))
Example #3
def load_checkpoints(kp_detector_path, generator_path, cpu=False):
    generator = torch.jit.load(generator_path)
    kp_detector = torch.jit.load(kp_detector_path)

    if not cpu:
        generator.cuda()
        kp_detector.cuda()
        generator = DataParallelWithCallback(generator)
        kp_detector = DataParallelWithCallback(kp_detector)

    generator.eval()
    kp_detector.eval()

    return generator, kp_detector
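A minimal usage sketch for load_checkpoints() above, assuming TorchScript files exported by the same project; the file names, the 256x256 input size and the cpu=True call are placeholders.

import torch

generator, kp_detector = load_checkpoints('kp_detector.pt', 'generator.pt', cpu=True)
with torch.no_grad():
    source = torch.rand(1, 3, 256, 256)     # dummy source frame (assumed input shape)
    kp_source = kp_detector(source)         # keypoints for the source frame
    # generator(...) would then be driven with kp_source and per-frame driving
    # keypoints, as in the transfer/animate examples above.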
Example #4
def animate(config, generator, kp_detector, checkpoint, log_dir, dataset):
    log_dir = os.path.join(log_dir, 'animation')
    png_dir = os.path.join(log_dir, 'png')
    animate_params = config['animate_params']

    dataset = PairedDataset(initial_dataset=dataset, number_of_pairs=animate_params['num_pairs'])
    dataloader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=1)

    if checkpoint is not None:
        Logger.load_cpk(checkpoint, generator=generator, kp_detector=kp_detector)
    else:
        raise AttributeError("Checkpoint should be specified for mode='animate'.")

    if not os.path.exists(log_dir):
        os.makedirs(log_dir)

    if not os.path.exists(png_dir):
        os.makedirs(png_dir)

    if torch.cuda.is_available():
        generator = DataParallelWithCallback(generator)
        kp_detector = DataParallelWithCallback(kp_detector)

    generator.eval()
    kp_detector.eval()

    for it, x in tqdm(enumerate(dataloader)):
        with torch.no_grad():
            predictions = []
            visualizations = []

            driving_video = x['driving_video']
            source_frame = x['source_video'][:, :, 0, :, :]

            kp_source = kp_detector(source_frame)
            kp_driving_initial = kp_detector(driving_video[:, :, 0])

            for frame_idx in range(driving_video.shape[2]):
                driving_frame = driving_video[:, :, frame_idx]
                kp_driving = kp_detector(driving_frame)
                kp_norm = normalize_kp(kp_source=kp_source, kp_driving=kp_driving,
                                       kp_driving_initial=kp_driving_initial, **animate_params['normalization_params'])
                out = generator(source_frame, kp_source=kp_source, kp_driving=kp_norm)

                out['kp_driving'] = kp_driving
                out['kp_source'] = kp_source
                out['kp_norm'] = kp_norm

                del out['sparse_deformed']

                predictions.append(np.transpose(out['prediction'].data.cpu().numpy(), [0, 2, 3, 1])[0])

                visualization = Visualizer(**config['visualizer_params']).visualize(source=source_frame,
                                                                                    driving=driving_frame, out=out)
                visualizations.append(visualization)

            predictions = np.concatenate(predictions, axis=1)
            result_name = "-".join([x['driving_name'][0], x['source_name'][0]])
            imageio.imsave(os.path.join(png_dir, result_name + '.png'), (255 * predictions).astype(np.uint8))

            image_name = result_name + animate_params['format']
            imageio.mimsave(os.path.join(log_dir, image_name), visualizations)
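The normalize_kp call above re-expresses the driving keypoints relative to the first driving frame, so that only the relative motion is transferred to the source. Its implementation is not shown in this example; the following is only a rough sketch of that idea, assuming keypoint dicts with a 'mean' entry and ignoring normalization_params.

def normalize_kp_sketch(kp_source, kp_driving, kp_driving_initial):
    # assumed behaviour: shift the source keypoints by the driving keypoints'
    # displacement relative to the first driving frame
    kp_new = dict(kp_driving)
    kp_new['mean'] = kp_source['mean'] + (kp_driving['mean'] - kp_driving_initial['mean'])
    return kp_new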
def debug_generator(generator, kp_to_skl_gt, loader, train_params, 
                     logger, device_ids, tgt_batch=None):
    log_params = train_params['log_params']
    genModel = ConditionalGenerator2D(generator, train_params)
    genModel = DataParallelWithCallback(genModel, device_ids=device_ids)

    optimizer_generator = torch.optim.Adam(generator.parameters(),
                                            lr=train_params['lr'],
                                            betas=train_params['betas'])
    scheduler_generator = MultiStepLR(optimizer_generator, 
                                       train_params['epoch_milestones'], 
                                       gamma=0.1, last_epoch=-1)
 
    k = 0
    train_views = [0, 1, 3]
    eval_views = [2]
    if tgt_batch is not None:
        tgt_batch_samples = split_data(tgt_batch, 
                                       train_views=train_views, 
                                       eval_views=eval_views)
        with torch.no_grad():
            tgt_batch_samples['gt_skl'] = kp_to_skl_gt(tgt_batch_samples['kps'].to('cuda')).unsqueeze(1)
            tgt_batch_samples['gt_skl_eval'] = kp_to_skl_gt(tgt_batch_samples['kps_eval'].to('cuda')).unsqueeze(1)
        
    for epoch in range(train_params['num_epochs']):
        for i, batch in enumerate(tqdm(loader)):
            # batch is assumed to be an (img, annots, ref_img) tuple (see the
            # commented-out lines below); split_data takes the whole batch,
            # as in the tgt_batch call above.
            batch_samples = split_data(batch,
                                       train_views=train_views,
                                       eval_views=eval_views)
            #imgs = flatten_views(imgs)
            #ref_imgs = flatten_views(ref_imgs)
            #ref_imgs = torch.rand(*imgs.shape)
            with torch.no_grad():
                batch_samples['gt_skl'] = kp_to_skl_gt(batch_samples['kps'].to('cuda')).unsqueeze(1)
                #batch_samples['gt_skl_eval'] = kp_to_skl_gt(batch_samples['kps_eval'].to('cuda')).unsqueeze(1)
                #gt_skl = (kp_to_skl_gt(flatten_views(annots / (ref_img.shape[3] - 1)).to('cuda'))).unsqueeze(1)
            #gt_skl = torch.rand(imgs.shape[0], 1, *imgs.shape[2:])

            #generator_out = genModel(imgs, ref_imgs, gt_skl)
            generator_out = genModel(batch_samples['imgs'], batch_samples['ref_imgs'], batch_samples['gt_skl'])
            ##### Generator update
            #loss_generator = generator_out['loss']
            loss_generator = generator_out['perceptual_loss']
            loss_generator = [x.mean() for x in loss_generator]
            loss_gen = sum(loss_generator)
            loss_gen.backward(retain_graph=not train_params['detach_kp_discriminator'])
            optimizer_generator.step()
            optimizer_generator.zero_grad()

            ########### LOG
            logger.add_scalar("Generator Loss", 
                               loss_gen.item(), 
                               epoch * len(loader) + i + 1)
            if i in log_params['log_imgs']:
                if tgt_batch is not None:
                    with torch.no_grad():
                        genModel.eval()
                        generator_out_eval = genModel(tgt_batch_samples['imgs_eval'], 
                                                      tgt_batch_samples['ref_imgs_eval'],
                                                      tgt_batch_samples['gt_skl_eval'])
                        #generator_out_eval = genModel(batch_samples['imgs_eval'], 
                        #                              batch_samples['ref_imgs_eval'],
                        #                              batch_samples['gt_skl_eval'])
                        concat_img_eval = np.concatenate((tensor_to_image(tgt_batch_samples['imgs_eval'][k]), 
                                     tensor_to_image(tgt_batch_samples['gt_skl_eval'][k]), 
                                     tensor_to_image(tgt_batch_samples['ref_imgs_eval'][k]),
                                     tensor_to_image(generator_out_eval['reconstructred_image'][k])), axis=2)  # concat along width
                        logger.add_image('Sample_{%d}_EVAL' % i, concat_img_eval, epoch)
                        genModel.train()
                k += 1
                k = k % 4
                concat_img = np.concatenate((tensor_to_image(batch_samples['imgs'][k]), 
                             tensor_to_image(batch_samples['gt_skl'][k]), 
                             tensor_to_image(batch_samples['ref_imgs'][k]),
                             tensor_to_image(generator_out['reconstructred_image'][k])), axis=2)  # concat along width
                logger.add_image('Sample_{%d}' % i, concat_img, epoch)


        scheduler_generator.step()
            if step % 2 == 1:
                optimizer.step()
                optimizer.zero_grad()

            predict_for_mAP = []
            label_for_mAP = []

        # ################################### test ###############################
        if (step + resume_step) % 700 == 699:
            test = True

        if test:
            print('Start Test')
            TEST_LOSS = AverageMeter()
            with torch.no_grad():
                model.eval()
                predict_for_mAP = []
                label_for_mAP = []
                print("TESTING")

                for step_test, (x, _, _, action) in tqdm(
                        enumerate(Loader_test)):  # gives batch data
                    b_x = Variable(x).cuda()
                    b_action = Variable(action).cuda()

                    c = [
                        Variable(
                            torch.from_numpy(
                                np.zeros((len(b_x),
                                          model.module.channels[layer + 1],
                                          model.module.input_size[layer],
Example #7
class Model:
    def __init__(self,
                 hidden_dim,
                 lr,
                 hard_or_full_trip,
                 margin,
                 num_workers,
                 batch_size,
                 restore_iter,
                 total_iter,
                 save_name,
                 train_pid_num,
                 frame_num,
                 model_name,
                 train_source,
                 test_source,
                 img_size=64):

        self.save_name = save_name
        self.train_pid_num = train_pid_num
        self.train_source = train_source
        self.test_source = test_source

        self.hidden_dim = hidden_dim
        self.lr = lr
        self.hard_or_full_trip = hard_or_full_trip
        self.margin = margin
        self.frame_num = frame_num
        self.num_workers = num_workers
        self.batch_size = batch_size
        self.model_name = model_name
        self.P, self.M = batch_size

        self.restore_iter = restore_iter
        self.total_iter = total_iter

        self.img_size = img_size

        self.encoder = SetNet(self.hidden_dim).float()
        self.encoder = DataParallelWithCallback(self.encoder)
        self.triplet_loss = TripletLoss(self.P * self.M, self.hard_or_full_trip, self.margin).float()
        self.triplet_loss = DataParallelWithCallback(self.triplet_loss)
        self.encoder.cuda()
        self.triplet_loss.cuda()

        self.optimizer = optim.Adam([
            {'params': self.encoder.parameters()},
        ], lr=self.lr)

        self.hard_loss_metric = []
        self.full_loss_metric = []
        self.full_loss_num = []
        self.dist_list = []
        self.mean_dist = 0.01

        self.sample_type = 'all'

    def collate_fn(self, batch):
        batch_size = len(batch)
        feature_num = len(batch[0][0])
        seqs = [batch[i][0] for i in range(batch_size)]
        frame_sets = [batch[i][1] for i in range(batch_size)]
        view = [batch[i][2] for i in range(batch_size)]
        seq_type = [batch[i][3] for i in range(batch_size)]
        label = [batch[i][4] for i in range(batch_size)]
        batch = [seqs, view, seq_type, label, None]

        def select_frame(index):
            sample = seqs[index]
            frame_set = frame_sets[index]
            if self.sample_type == 'random':
                frame_id_list = random.choices(frame_set, k=self.frame_num)
                _ = [feature.loc[frame_id_list].values for feature in sample]
            else:
                _ = [feature.values for feature in sample]
            return _

        seqs = list(map(select_frame, range(len(seqs))))

        if self.sample_type == 'random':
            seqs = [np.asarray([seqs[i][j] for i in range(batch_size)]) for j in range(feature_num)]
        else:
            gpu_num = min(torch.cuda.device_count(), batch_size)
            batch_per_gpu = math.ceil(batch_size / gpu_num)
            batch_frames = [[
                                len(frame_sets[i])
                                for i in range(batch_per_gpu * _, batch_per_gpu * (_ + 1))
                                if i < batch_size
                                ] for _ in range(gpu_num)]
            if len(batch_frames[-1]) != batch_per_gpu:
                for _ in range(batch_per_gpu - len(batch_frames[-1])):
                    batch_frames[-1].append(0)
            max_sum_frame = np.max([np.sum(batch_frames[_]) for _ in range(gpu_num)])
            seqs = [[
                        np.concatenate([
                                           seqs[i][j]
                                           for i in range(batch_per_gpu * _, batch_per_gpu * (_ + 1))
                                           if i < batch_size
                                           ], 0) for _ in range(gpu_num)]
                    for j in range(feature_num)]
            seqs = [np.asarray([
                                   np.pad(seqs[j][_],
                                          ((0, max_sum_frame - seqs[j][_].shape[0]), (0, 0), (0, 0)),
                                          'constant',
                                          constant_values=0)
                                   for _ in range(gpu_num)])
                    for j in range(feature_num)]
            batch[4] = np.asarray(batch_frames)

        batch[0] = seqs
        return batch

    def fit(self):
        if self.restore_iter != 0:
            self.load(self.restore_iter)

        self.encoder.train()
        self.sample_type = 'random'
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = self.lr
        triplet_sampler = TripletSampler(self.train_source, self.batch_size)
        train_loader = tordata.DataLoader(
            dataset=self.train_source,
            batch_sampler=triplet_sampler,
            collate_fn=self.collate_fn,
            num_workers=self.num_workers)

        train_label_set = list(self.train_source.label_set)
        train_label_set.sort()

        _time1 = datetime.now()
        for seq, view, seq_type, label, batch_frame in train_loader:
            self.restore_iter += 1
            self.optimizer.zero_grad()

            for i in range(len(seq)):
                seq[i] = self.np2var(seq[i]).float()
            if batch_frame is not None:
                batch_frame = self.np2var(batch_frame).int()

            feature, label_prob = self.encoder(*seq, batch_frame)

            target_label = [train_label_set.index(l) for l in label]
            target_label = self.np2var(np.array(target_label)).long()

            triplet_feature = feature.permute(1, 0, 2).contiguous()
            triplet_label = target_label.unsqueeze(0).repeat(triplet_feature.size(0), 1)
            (full_loss_metric, hard_loss_metric, mean_dist, full_loss_num
             ) = self.triplet_loss(triplet_feature, triplet_label)
            if self.hard_or_full_trip == 'hard':
                loss = hard_loss_metric.mean()
            elif self.hard_or_full_trip == 'full':
                loss = full_loss_metric.mean()

            self.hard_loss_metric.append(hard_loss_metric.mean().data.cpu().numpy())
            self.full_loss_metric.append(full_loss_metric.mean().data.cpu().numpy())
            self.full_loss_num.append(full_loss_num.mean().data.cpu().numpy())
            self.dist_list.append(mean_dist.mean().data.cpu().numpy())

            if loss > 1e-9:
                loss.backward()
                self.optimizer.step()

            if self.restore_iter % 1000 == 0:
                print(datetime.now() - _time1)
                _time1 = datetime.now()

            if self.restore_iter % 100 == 0:
                self.save()
                print('iter {}:'.format(self.restore_iter), end='')
                print(', hard_loss_metric={0:.8f}'.format(np.mean(self.hard_loss_metric)), end='')
                print(', full_loss_metric={0:.8f}'.format(np.mean(self.full_loss_metric)), end='')
                print(', full_loss_num={0:.8f}'.format(np.mean(self.full_loss_num)), end='')
                self.mean_dist = np.mean(self.dist_list)
                print(', mean_dist={0:.8f}'.format(self.mean_dist), end='')
                print(', lr=%f' % self.optimizer.param_groups[0]['lr'], end='')
                print(', hard or full=%r' % self.hard_or_full_trip)
                sys.stdout.flush()
                self.hard_loss_metric = []
                self.full_loss_metric = []
                self.full_loss_num = []
                self.dist_list = []

            # Visualization using t-SNE
            # if self.restore_iter % 500 == 0:
            #     pca = TSNE(2)
            #     pca_feature = pca.fit_transform(feature.view(feature.size(0), -1).data.cpu().numpy())
            #     for i in range(self.P):
            #         plt.scatter(pca_feature[self.M * i:self.M * (i + 1), 0],
            #                     pca_feature[self.M * i:self.M * (i + 1), 1], label=label[self.M * i])
            #
            #     plt.show()

            if self.restore_iter == self.total_iter:
                break

    def ts2var(self, x):
        return autograd.Variable(x).cuda()

    def np2var(self, x):
        return self.ts2var(torch.from_numpy(x))

    def transform(self, flag, batch_size=1):
        self.encoder.eval()
        source = self.test_source if flag == 'test' else self.train_source
        self.sample_type = 'all'
        data_loader = tordata.DataLoader(
            dataset=source,
            batch_size=batch_size,
            sampler=tordata.sampler.SequentialSampler(source),
            collate_fn=self.collate_fn,
            num_workers=self.num_workers)

        feature_list = list()
        view_list = list()
        seq_type_list = list()
        label_list = list()

        for i, x in enumerate(data_loader):
            seq, view, seq_type, label, batch_frame = x
            for j in range(len(seq)):
                seq[j] = self.np2var(seq[j]).float()
            if batch_frame is not None:
                batch_frame = self.np2var(batch_frame).int()
            # print(batch_frame, np.sum(batch_frame))

            feature, _ = self.encoder(*seq, batch_frame)
            n, num_bin, _ = feature.size()
            feature_list.append(feature.view(n, -1).data.cpu().numpy())
            view_list += view
            seq_type_list += seq_type
            label_list += label

        return np.concatenate(feature_list, 0), view_list, seq_type_list, label_list

    def save(self):
        os.makedirs(osp.join('checkpoint', self.model_name), exist_ok=True)
        torch.save(self.encoder.state_dict(),
                   osp.join('checkpoint', self.model_name,
                            '{}-{:0>5}-encoder.ptm'.format(
                                self.save_name, self.restore_iter)))
        torch.save(self.optimizer.state_dict(),
                   osp.join('checkpoint', self.model_name,
                            '{}-{:0>5}-optimizer.ptm'.format(
                                self.save_name, self.restore_iter)))

    # restore_iter: iteration index of the checkpoint to load
    def load(self, restore_iter):
        self.encoder.load_state_dict(torch.load(osp.join(
            'checkpoint', self.model_name,
            '{}-{:0>5}-encoder.ptm'.format(self.save_name, restore_iter))))
        self.optimizer.load_state_dict(torch.load(osp.join(
            'checkpoint', self.model_name,
            '{}-{:0>5}-optimizer.ptm'.format(self.save_name, restore_iter))))
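A hedged sketch of how the Model class above might be driven; every hyperparameter value below is a placeholder, and train_source / test_source are assumed to be dataset objects compatible with TripletSampler and collate_fn.

model = Model(hidden_dim=256, lr=1e-4, hard_or_full_trip='full', margin=0.2,
              num_workers=4, batch_size=(8, 16), restore_iter=0, total_iter=80000,
              save_name='demo', train_pid_num=73, frame_num=30, model_name='gaitset',
              train_source=train_source, test_source=test_source, img_size=64)
model.fit()                                                     # train until total_iter
features, views, seq_types, labels = model.transform('test')   # extract test features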
Example #8
class Framework:
    batch_axis = 1

    def __init__(self, config):
        self.config = config
        self.build_dataset()
        self.build_network()
        self.build_optimizer()

    def get_param(self, keys, default=None):
        node = self.config
        for key in keys.split('.'):
            if key in node:
                node = node[key]
            else:
                return default
        return node

    def cuda(self):
        self.model.cuda()
        if self.get_param('network.sync_bn', False):
            self.model = DataParallelWithCallback(self.model,
                                                  dim=self.batch_axis)
        else:
            self.model = nn.DataParallel(self.model, dim=self.batch_axis)

    def build_optimizer(self):
        args = copy(self.config['optimizer'])
        if args['type'] == 'SGD':
            optim_class = torch.optim.SGD
        elif args['type'] == 'Adam':
            optim_class = torch.optim.Adam
        args.pop('type')
        if isinstance(args['lr'], list):
            args['lr'] = args['lr'][0][0]
        self.optimizer = optim_class(self.model.parameters(), **args)

    def set_learning_rate(self, epoch):
        lrs = self.config['optimizer']['lr']
        if isinstance(lrs, list):
            c = 0
            for lr, num_epochs in lrs:
                c += num_epochs
                if epoch <= c:
                    break
            for param_group in self.optimizer.param_groups:
                param_group['lr'] = lr
            print('setting learning rate {}'.format(lr))

    def build_dataset(self):
        length = self.config['data']['length']
        stride = 5
        self.train_data = dataset.build_ucf101_dataset(
            'traindev1',
            transforms=transforms.Compose([
                transforms.RandomResizedCrop((224, 224), (0.5, 1.0),
                                             ratio=(3 / 4, 4 / 3)),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor()
            ]),
            length=length,
            stride=stride,
            config=self.config)
        self.train_loader = torch.utils.data.DataLoader(
            self.train_data,
            sampler=sampler.RandomSampler(self.train_data,
                                          loop=self.config['data']['loop']),
            batch_size=self.config['batch_size'],
            num_workers=self.config['num_worker'])
        self.val_data, self.val_loader = self.build_test_dataset('val1')
        self.classes = self.train_data.classes

    def build_test_dataset(self, split):
        length = self.config['data']['length']
        stride = 5
        data = dataset.build_ucf101_dataset(split,
                                            transforms=transforms.Compose([
                                                transforms.CenterCrop(
                                                    (224, 224)),
                                                transforms.ToTensor()
                                            ]),
                                            length=length,
                                            stride=stride,
                                            config=self.config)
        loader = torch.utils.data.DataLoader(
            data,
            sampler=sampler.ValSampler(data, stride=length),
            batch_size=self.config['batch_size'],
            num_workers=self.config['num_worker'])
        return data, loader

    # Why non_blocking? With pinned-memory inputs, non_blocking=True lets the
    # host-to-device copy overlap with computation.

    def prepare_data(self, data):
        frames = torch.stack(data[0], dim=0).cuda(non_blocking=True)
        labels = data[1].cuda(non_blocking=True)
        vids = data[2].numpy()
        return (frames, labels), vids

    def train_epoch(self, epoch):
        self.model.train()
        self.set_learning_rate(epoch)
        end = time.time()
        metrics = defaultdict(AverageMeter)
        for i, data in enumerate(self.train_loader):
            # measure data loading time
            metrics['data_time'].update(time.time() - end)

            args, _ = self.prepare_data(data)
            batch_size = args[0].size(1)
            result = self.train_batch(*args)
            loss = result['loss']
            for k, v in result.items():
                metrics[k].update(v.item(), batch_size)

            # compute gradient and do SGD step
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            # measure elapsed time
            metrics['batch_time'].update(time.time() - end)
            end = time.time()

            if i % self.config['print_freq'] == 0:
                print('Epoch: [{0}][{1}/{2}]\t'.format(epoch, i,
                                                       len(self.train_loader)),
                      end='')

                for k, v in metrics.items():
                    print('{key} {val.avg:.3f}'.format(key=k, val=v), end='\t')
                print()
            # if i > 200:
            #     break

        metrics.pop('batch_time')
        metrics.pop('data_time')
        return {k: v.avg for k, v in metrics.items()}

    def predict(self, dataloader):
        self.model.eval()
        metrics = defaultdict(list)
        with torch.no_grad():
            for i, data in enumerate(dataloader):
                args, indices = self.prepare_data(data)
                result = self.predict_batch(*args)
                for k, v in result.items():
                    metrics[k].append(v.cpu().numpy())
                metrics['indices'].append(indices)
                if i % self.config['print_freq'] == 0:
                    print('Valid {}/{}'.format(i, len(dataloader)))
        for k in metrics:
            metrics[k] = np.concatenate(metrics[k], axis=0)
        return metrics

    def evaluate(self, dataloader):
        self.model.eval()
        metrics = defaultdict(AverageMeter)
        with torch.no_grad():
            for i, data in enumerate(dataloader):
                args, indices = self.prepare_data(data)
                batch_size = args[0][0].size(0)
                result = self.eval_batch(*args)
                for k, v in result.items():
                    metrics[k].update(v.item(), batch_size)
                if i % self.config['print_freq'] == 0:
                    print('Valid {}/{}'.format(i, len(dataloader)))
        return {k: v.avg for k, v in metrics.items()}
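Framework above relies on methods it does not define here (build_network, train_batch, eval_batch, predict_batch), so it is presumably subclassed. A minimal training-loop sketch, assuming such a subclass and a 'num_epochs' entry in the config:

framework = MyFramework(config)        # hypothetical subclass implementing the missing methods
framework.cuda()                       # wraps the model in DataParallelWithCallback when network.sync_bn is set
for epoch in range(config['num_epochs']):                   # 'num_epochs' key is an assumption
    train_metrics = framework.train_epoch(epoch)
    val_metrics = framework.evaluate(framework.val_loader)
    print(epoch, train_metrics, val_metrics)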
Example #9
def prediction(config, generator, kp_detector, checkpoint, log_dir):
    dataset = FramesDataset(is_train=True, transform=VideoToTensor(), **config['dataset_params'])
    log_dir = os.path.join(log_dir, 'prediction')
    png_dir = os.path.join(log_dir, 'png')

    if checkpoint is not None:
        Logger.load_cpk(checkpoint, generator=generator, kp_detector=kp_detector)
    else:
        raise AttributeError("Checkpoint should be specified for mode='prediction'.")
    dataloader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=1)

    generator = DataParallelWithCallback(generator)
    kp_detector = DataParallelWithCallback(kp_detector)

    if not os.path.exists(log_dir):
        os.makedirs(log_dir)

    if not os.path.exists(png_dir):
        os.makedirs(png_dir)

    print("Extracting keypoints...")

    kp_detector.eval()
    generator.eval()

    keypoints_array = []

    prediction_params = config['prediction_params']

    for it, x in tqdm(enumerate(dataloader)):
        if prediction_params['train_size'] is not None:
            if it > prediction_params['train_size']:
                break
        with torch.no_grad():
            keypoints = []
            for i in range(x['video'].shape[2]):
                kp = kp_detector(x['video'][:, :, i:(i + 1)])
                kp = {k: v.data.cpu().numpy() for k, v in kp.items()}
                keypoints.append(kp)
            keypoints_array.append(keypoints)

    predictor = PredictionModule(num_kp=config['model_params']['common_params']['num_kp'],
                                 kp_variance=config['model_params']['common_params']['kp_variance'],
                                 **prediction_params['rnn_params']).cuda()

    num_epochs = prediction_params['num_epochs']
    lr = prediction_params['lr']
    bs = prediction_params['batch_size']
    num_frames = prediction_params['num_frames']
    init_frames = prediction_params['init_frames']

    optimizer = torch.optim.Adam(predictor.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, verbose=True, patience=50)

    kp_dataset = KPDataset(keypoints_array, num_frames=num_frames)

    kp_dataloader = DataLoader(kp_dataset, batch_size=bs)

    print("Training prediction...")
    for _ in trange(num_epochs):
        loss_list = []
        for x in kp_dataloader:
            x = {k: v.cuda() for k, v in x.items()}
            gt = {k: v.clone() for k, v in x.items()}
            for k in x:
                x[k][:, init_frames:] = 0
            prediction = predictor(x)

            loss = sum([torch.abs(gt[k][:, init_frames:] - prediction[k][:, init_frames:]).mean() for k in x])

            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            loss_list.append(loss.detach().data.cpu().numpy())

        loss = np.mean(loss_list)
        scheduler.step(loss)

    dataset = FramesDataset(is_train=False, transform=VideoToTensor(), **config['dataset_params'])
    dataloader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=1)

    print("Make predictions...")
    for it, x in tqdm(enumerate(dataloader)):
        with torch.no_grad():
            x['video'] = x['video'][:, :, :num_frames]
            kp_init = kp_detector(x['video'])
            for k in kp_init:
                kp_init[k][:, init_frames:] = 0

            kp_source = kp_detector(x['video'][:, :, :1])

            kp_video = predictor(kp_init)
            for k in kp_video:
                kp_video[k][:, :init_frames] = kp_init[k][:, :init_frames]
            if 'var' in kp_video and prediction_params['predict_variance']:
                kp_video['var'] = kp_init['var'][:, (init_frames - 1):init_frames].repeat(1, kp_video['var'].shape[1],
                                                                                          1, 1, 1)
            out = generate(generator, appearance_image=x['video'][:, :, :1], kp_appearance=kp_source,
                           kp_video=kp_video)

            x['source'] = x['video'][:, :, :1]

            out_video_batch = out['video_prediction'].data.cpu().numpy()
            out_video_batch = np.concatenate(np.transpose(out_video_batch, [0, 2, 3, 4, 1])[0], axis=1)
            imageio.imsave(os.path.join(png_dir, x['name'][0] + '.png'), (255 * out_video_batch).astype(np.uint8))

            image = Visualizer(**config['visualizer_params']).visualize_reconstruction(x, out)
            image_name = x['name'][0] + prediction_params['format']
            imageio.mimsave(os.path.join(log_dir, image_name), image)

            del x, kp_video, kp_source, out
def reconstruction(config, generator, mask_generator, checkpoint, log_dir,
                   dataset):
    png_dir = os.path.join(log_dir, 'reconstruction/png')
    log_dir = os.path.join(log_dir, 'reconstruction')

    if checkpoint is not None:
        epoch = Logger.load_cpk(checkpoint,
                                generator=generator,
                                mask_generator=mask_generator)
        print('checkpoint:' + str(epoch))
    else:
        raise AttributeError(
            "Checkpoint should be specified for mode='reconstruction'.")
    dataloader = DataLoader(dataset,
                            batch_size=1,
                            shuffle=False,
                            num_workers=1)

    if not os.path.exists(log_dir):
        os.makedirs(log_dir)

    if not os.path.exists(png_dir):
        os.makedirs(png_dir)

    loss_list = []
    if torch.cuda.is_available():
        generator = DataParallelWithCallback(generator)
        mask_generator = DataParallelWithCallback(mask_generator)

    generator.eval()
    mask_generator.eval()

    recon_gen_dir = './log/recon_gen'
    os.makedirs(recon_gen_dir, exist_ok=False)

    for it, x in tqdm(enumerate(dataloader)):
        if config['reconstruction_params']['num_videos'] is not None:
            if it > config['reconstruction_params']['num_videos']:
                break
        with torch.no_grad():
            predictions = []
            visualizations = []
            if torch.cuda.is_available():
                x['video'] = x['video'].cuda()
            mask_source = mask_generator(x['video'][:, :, 0])

            video_gen_dir = recon_gen_dir + '/' + x['name'][0]
            os.makedirs(video_gen_dir, exist_ok=False)

            for frame_idx in range(x['video'].shape[2]):
                source = x['video'][:, :, 0]
                driving = x['video'][:, :, frame_idx]
                mask_driving = mask_generator(driving)
                out = generator(source,
                                driving,
                                mask_source=mask_source,
                                mask_driving=mask_driving,
                                mask_driving2=None,
                                animate=False,
                                predict_mask=False)
                out['mask_source'] = mask_source
                out['mask_driving'] = mask_driving

                predictions.append(
                    np.transpose(
                        out['second_phase_prediction'].data.cpu().numpy(),
                        [0, 2, 3, 1])[0])

                visualization = Visualizer(
                    **config['visualizer_params']).visualize(source=source,
                                                             driving=driving,
                                                             target=None,
                                                             out=out,
                                                             driving2=None)
                visualizations.append(visualization)

                loss_list.append(
                    torch.abs(out['second_phase_prediction'] -
                              driving).mean().cpu().numpy())

                frame_name = str(frame_idx).zfill(7) + '.png'
                second_phase_prediction = out[
                    'second_phase_prediction'].data.cpu().numpy()
                second_phase_prediction = np.transpose(second_phase_prediction,
                                                       [0, 2, 3, 1])
                second_phase_prediction = (255 *
                                           second_phase_prediction).astype(
                                               np.uint8)
                imageio.imsave(os.path.join(video_gen_dir, frame_name),
                               second_phase_prediction[0])

            predictions = np.concatenate(predictions, axis=1)

            image_name = x['name'][0] + config['reconstruction_params'][
                'format']
            imageio.mimsave(os.path.join(log_dir, image_name), visualizations)

    print("Reconstruction loss: %s" % np.mean(loss_list))
Example #11
class Trainer:
    def __init__(self, logger, checkpoint, device_ids, config):
        self.BtoA = config['cycle_loss_weight'] != 0
        self.config = config
        self.logger = logger
        self.device_ids = device_ids

        self.restore(checkpoint)

        print("Generator...")
        print(self.generatorB)

        print("Discriminator...")
        print(self.discriminatorB)

        transform = list()
        transform.append(T.Resize(config['load_size']))
        transform.append(T.RandomCrop(config['crop_size']))
        transform.append(T.ToTensor())
        transform.append(T.Normalize(mean=(0.5, 0.5, 0.5),
                                     std=(0.5, 0.5, 0.5)))
        transform = T.Compose(transform)

        self.dataset = ABDataset(config['root_dir'],
                                 partition='train',
                                 transform=transform)

    def restore(self, checkpoint):
        self.epoch = 0

        self.generatorB = Generator(**self.config['generator_params']).cuda()
        self.generatorB = DataParallelWithCallback(self.generatorB,
                                                   device_ids=self.device_ids)
        self.optimizer_generatorB = torch.optim.Adam(
            self.generatorB.parameters(),
            lr=self.config['lr_generator'],
            betas=(0.5, 0.999))

        self.discriminatorB = Discriminator(
            **self.config['discriminator_params']).cuda()
        self.discriminatorB = DataParallelWithCallback(
            self.discriminatorB, device_ids=self.device_ids)
        self.optimizer_discriminatorB = torch.optim.Adam(
            self.discriminatorB.parameters(),
            lr=self.config['lr_discriminator'],
            betas=(0.5, 0.999))

        if self.BtoA:
            self.generatorA = Generator(
                **self.config['generator_params']).cuda()
            self.generatorA = DataParallelWithCallback(
                self.generatorA, device_ids=self.device_ids)
            self.optimizer_generatorA = torch.optim.Adam(
                self.generatorA.parameters(),
                lr=self.config['lr_generator'],
                betas=(0.5, 0.999))

            self.discriminatorA = Discriminator(
                **self.config['discriminator_params']).cuda()
            self.discriminatorA = DataParallelWithCallback(
                self.discriminatorA, device_ids=self.device_ids)
            self.optimizer_discriminatorA = torch.optim.Adam(
                self.discriminatorA.parameters(),
                lr=self.config['lr_discriminator'],
                betas=(0.5, 0.999))

        if checkpoint is not None:
            data = torch.load(checkpoint)
            for key, value in data.items():
                if key == 'epoch':
                    self.epoch = value
                else:
                    self.__dict__[key].load_state_dict(value)

        lr_lambda = lambda epoch: min(
            1, 2 - 2 * epoch / self.config['num_epochs'])
        self.scheduler_generatorB = torch.optim.lr_scheduler.LambdaLR(
            self.optimizer_generatorB, lr_lambda, last_epoch=self.epoch - 1)
        self.scheduler_discriminatorB = torch.optim.lr_scheduler.LambdaLR(
            self.optimizer_discriminatorB,
            lr_lambda,
            last_epoch=self.epoch - 1)

        if self.BtoA:
            self.scheduler_generatorA = torch.optim.lr_scheduler.LambdaLR(
                self.optimizer_generatorA,
                lr_lambda,
                last_epoch=self.epoch - 1)
            self.scheduler_discriminatorA = torch.optim.lr_scheduler.LambdaLR(
                self.optimizer_discriminatorA,
                lr_lambda,
                last_epoch=self.epoch - 1)

    def save(self):
        state_dict = {
            'epoch': self.epoch,
            'generatorB': self.generatorB.state_dict(),
            'optimizer_generatorB': self.optimizer_generatorB.state_dict(),
            'discriminatorB': self.discriminatorB.state_dict(),
            'optimizer_discriminatorB':
            self.optimizer_discriminatorB.state_dict()
        }

        if self.BtoA:
            state_dict.update({
                'generatorA':
                self.generatorA.state_dict(),
                'optimizer_generatorA':
                self.optimizer_generatorA.state_dict(),
                'discriminatorA':
                self.discriminatorA.state_dict(),
                'optimizer_discriminatorA':
                self.optimizer_discriminatorA.state_dict()
            })

        torch.save(state_dict, os.path.join(self.logger.log_dir, 'cpk.pth'))

    def train(self):
        np.random.seed(0)
        loader = DataLoader(self.dataset,
                            batch_size=self.config['bs'],
                            shuffle=False,
                            drop_last=True,
                            num_workers=4)
        images_fixed = None

        for self.epoch in tqdm(range(self.epoch, self.config['num_epochs'])):
            loss_dict = defaultdict(lambda: 0.0)
            iteration_count = 1
            for inp in tqdm(loader):
                images_A = inp['A'].cuda()
                images_B = inp['B'].cuda()

                if images_fixed is None:
                    images_fixed = {'A': images_A, 'B': images_B}
                    transform_fixed = Transform(
                        images_A.shape[0], **self.config['transform_params'])

                if self.config['identity_loss_weight'] != 0:
                    images_trg = self.generatorB(images_B, source=False)
                    identity_loss = l1(images_trg, images_B)
                    identity_loss = self.config[
                        'identity_loss_weight'] * identity_loss
                    identity_loss.backward()

                    loss_dict['identity_loss_B'] += identity_loss.detach().cpu(
                    ).numpy()

                if self.config['identity_loss_weight'] != 0 and self.BtoA:
                    images_trg = self.generatorA(images_A, source=False)
                    identity_loss = l1(images_trg, images_A)
                    identity_loss = self.config[
                        'identity_loss_weight'] * identity_loss
                    identity_loss.backward()

                    loss_dict['identity_loss_A'] += identity_loss.detach().cpu(
                    ).numpy()

                generator_loss = 0
                images_generatedB = self.generatorB(images_A, source=True)
                logits = self.discriminatorB(images_generatedB)
                adversarial_loss = gan_loss_generator(
                    logits, self.config['gan_loss_type'])
                adversarial_loss = self.config[
                    'adversarial_loss_weight'] * adversarial_loss
                generator_loss += adversarial_loss
                loss_dict['adversarial_loss_B'] += adversarial_loss.detach(
                ).cpu().numpy()

                if self.BtoA:
                    images_generatedA = self.generatorA(images_B, source=True)
                    logits = self.discriminatorA(images_generatedA)
                    adversarial_loss = gan_loss_generator(
                        logits, self.config['gan_loss_type'])
                    adversarial_loss = self.config[
                        'adversarial_loss_weight'] * adversarial_loss
                    generator_loss += adversarial_loss
                    loss_dict['adversarial_loss_A'] += adversarial_loss.detach(
                    ).cpu().numpy()

                if self.config['equivariance_loss_weight_generator'] != 0:
                    transform = Transform(images_generatedB.shape[0],
                                          **self.config['transform_params'])
                    images_A_transformed = transform.transform_frame(images_A)
                    loss = corr(
                        self.generatorB(images_A_transformed, source=True),
                        transform.transform_frame(images_generatedB))
                    loss = self.config[
                        'equivariance_loss_weight_generator'] * loss
                    generator_loss += loss
                    loss_dict['equivariance_generator_B'] += loss.detach().cpu(
                    ).numpy()

                if self.config[
                        'equivariance_loss_weight_generator'] != 0 and self.BtoA:
                    transform = Transform(images_generatedA.shape[0],
                                          **self.config['transform_params'])
                    images_B_transformed = transform.transform_frame(images_B)
                    loss = corr(
                        self.generatorA(images_B_transformed, source=True),
                        transform.transform_frame(images_generatedA))
                    loss = self.config[
                        'equivariance_loss_weight_generator'] * loss
                    generator_loss += loss
                    loss_dict['equivariance_generator_A'] += loss.detach().cpu(
                    ).numpy()

                if self.BtoA and self.config['cycle_loss_weight'] != 0:
                    images_cycled = self.generatorA(images_generatedB,
                                                    source=True)
                    cycle_loss = torch.abs(images_cycled - images_A).mean()
                    cycle_loss = self.config['cycle_loss_weight'] * cycle_loss
                    generator_loss += cycle_loss
                    loss_dict['cycle_loss_B'] += cycle_loss.detach().cpu(
                    ).numpy()

                    images_cycled = self.generatorB(images_generatedA,
                                                    source=True)
                    cycle_loss = torch.abs(images_cycled - images_B).mean()
                    cycle_loss = self.config['cycle_loss_weight'] * cycle_loss
                    generator_loss += cycle_loss
                    loss_dict['cycle_loss_A'] += cycle_loss.detach().cpu(
                    ).numpy()

                generator_loss.backward()

                self.optimizer_generatorB.step()
                self.optimizer_generatorB.zero_grad()
                self.optimizer_discriminatorB.zero_grad()

                if self.BtoA:
                    self.optimizer_generatorA.step()
                    self.optimizer_generatorA.zero_grad()
                    self.optimizer_discriminatorA.zero_grad()

                logits_fake = self.discriminatorB(images_generatedB.detach())
                logits_real = self.discriminatorB(images_B)
                discriminator_loss = gan_loss_discriminator(
                    logits_real, logits_fake, self.config['gan_loss_type'])
                loss_dict['discriminator_loss_B'] += discriminator_loss.detach(
                ).cpu().numpy()

                if self.config['equivariance_loss_weight_discriminator'] != 0:
                    images_join = torch.cat(
                        [images_generatedB.detach(), images_B])
                    logits_join = torch.cat([logits_fake, logits_real])

                    transform = Transform(images_join.shape[0],
                                          **self.config['transform_params'])
                    images_transformed = transform.transform_frame(images_join)
                    loss = corr(self.discriminatorB(images_transformed),
                                transform.transform_frame(logits_join))

                    loss = self.config[
                        'equivariance_loss_weight_discriminator'] * loss
                    discriminator_loss += loss
                    loss_dict['equivariance_discriminator_B'] += loss.detach(
                    ).cpu().numpy()

                discriminator_loss.backward()

                self.optimizer_discriminatorB.step()
                self.optimizer_discriminatorB.zero_grad()
                self.optimizer_generatorB.zero_grad()

                if self.BtoA:
                    logits_fake = self.discriminatorA(
                        images_generatedA.detach())
                    logits_real = self.discriminatorA(images_A)
                    discriminator_loss = gan_loss_discriminator(
                        logits_real, logits_fake, self.config['gan_loss_type'])
                    loss_dict[
                        'discriminator_loss_A'] += discriminator_loss.detach(
                        ).cpu().numpy()

                    if self.config[
                            'equivariance_loss_weight_discriminator'] != 0:
                        images_join = torch.cat(
                            [images_generatedA.detach(), images_A])
                        logits_join = torch.cat([logits_fake, logits_real])

                        transform = Transform(
                            images_join.shape[0],
                            **self.config['transform_params'])
                        images_transformed = transform.transform_frame(
                            images_join)
                        loss = corr(self.discriminatorA(images_transformed),
                                    transform.transform_frame(logits_join))

                        loss = self.config[
                            'equivariance_loss_weight_discriminator'] * loss
                        discriminator_loss += loss
                        loss_dict[
                            'equivariance_discriminator_A'] += loss.detach(
                            ).cpu().numpy()

                    discriminator_loss.backward()
                    self.optimizer_discriminatorA.step()
                    self.optimizer_discriminatorA.zero_grad()
                    self.optimizer_generatorA.zero_grad()

                iteration_count += 1

            with torch.no_grad():
                if not self.BtoA:
                    self.generatorB.eval()
                    transformed = transform_fixed.transform_frame(
                        images_fixed['A'])
                    self.logger.save_images(
                        self.epoch, images_fixed['A'],
                        self.generatorB(images_fixed['A'], source=True),
                        transformed, self.generatorB(transformed, source=True))
                    self.generatorB.train()
                else:
                    self.generatorA.eval()
                    self.generatorB.eval()

                    images_generatedB = self.generatorB(images_fixed['A'],
                                                        source=True)
                    images_generatedA = self.generatorA(images_fixed['B'],
                                                        source=True)

                    transformed = transform_fixed.transform_frame(
                        images_fixed['A'])
                    self.logger.save_images(
                        self.epoch, images_fixed['A'], images_generatedB,
                        transformed, self.generatorB(transformed, source=True),
                        self.generatorA(images_generatedB, source=True),
                        images_fixed['B'], images_generatedA,
                        self.generatorB(images_generatedA, source=True))

                    self.generatorA.train()
                    self.generatorB.train()

            self.scheduler_generatorB.step()
            self.scheduler_discriminatorB.step()
            if self.BtoA:
                self.scheduler_generatorA.step()
                self.scheduler_discriminatorA.step()

            save_dict = {
                key: value / iteration_count
                for key, value in loss_dict.items()
            }
            save_dict['lr'] = self.optimizer_generatorB.param_groups[0]['lr']

            self.logger.log(self.epoch, save_dict)
            self.save()
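A hedged entry-point sketch for the Trainer above; the Logger object and the config file are assumptions based on how they are used in __init__, restore() and train().

import yaml

with open('config.yaml') as f:                      # hypothetical config path
    config = yaml.safe_load(f)
logger = Logger(log_dir='log')                      # assumed logger providing log_dir, log() and save_images()
trainer = Trainer(logger, checkpoint=None, device_ids=[0], config=config)
trainer.train()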
Example #12
def animate(config, generator, kp_detector, checkpoint, log_dir, dataset,
            kp_after_softmax):
    log_dir = os.path.join(log_dir, 'animation')
    png_dir = os.path.join(log_dir, 'png')
    animate_params = config['animate_params']
    frame_size = config['dataset_params']['frame_shape'][0]
    latent_size = int(config['model_params']['common_params']['scale_factor'] *
                      frame_size)
    dataset = PairedDataset(initial_dataset=dataset,
                            number_of_pairs=animate_params['num_pairs'])
    dataloader = DataLoader(dataset,
                            batch_size=1,
                            shuffle=False,
                            num_workers=1)

    if checkpoint is not None:
        Logger.load_cpk(checkpoint,
                        generator=generator,
                        kp_detector=kp_detector)
    else:
        raise AttributeError(
            "Checkpoint should be specified for mode='animate'.")

    if not os.path.exists(log_dir):
        os.makedirs(log_dir)

    if not os.path.exists(png_dir):
        os.makedirs(png_dir)

    if torch.cuda.is_available():
        generator = DataParallelWithCallback(generator)
        kp_detector = DataParallelWithCallback(kp_detector)

    generator.eval()
    kp_detector.eval()

    for it, x in tqdm(enumerate(dataloader)):
        with torch.no_grad():
            predictions = []
            visualizations = []

            driving_video = x['driving_video']
            source_frame = x['source_video'][:, :, 0, :, :]

            kp_source = kp_detector(source_frame)
            kp_driving_initial = kp_detector(driving_video[:, :, 0])

            for frame_idx in range(driving_video.shape[2]):
                driving_frame = driving_video[:, :, frame_idx]
                kp_driving = kp_detector(driving_frame)

                if kp_after_softmax:
                    kp_norm = normalize_kp(
                        kp_source=kp_source,
                        kp_driving=kp_driving,
                        kp_driving_initial=kp_driving_initial,
                        **animate_params['normalization_params'])
                    # Rasterize the keypoints into masks without overwriting
                    # kp_source, which is reused for every driving frame.
                    kp_source_mask = draw_kp([latent_size, latent_size],
                                             kp_source)
                    kp_norm_mask = draw_kp([latent_size, latent_size], kp_norm)
                    kp_source_mask = norm_mask(latent_size, kp_source_mask)
                    kp_norm_mask = norm_mask(latent_size, kp_norm_mask)
                    out = generator(source_frame,
                                    kp_source=kp_source_mask,
                                    kp_driving=kp_norm_mask)
                    kp_norm_int = F.interpolate(kp_norm_mask,
                                                size=source_frame.shape[2:],
                                                mode='bilinear',
                                                align_corners=False)
                    out['kp_norm_int'] = kp_norm_int.repeat(1, 3, 1, 1)
                else:
                    out = generator(source_frame,
                                    kp_source=kp_source,
                                    kp_driving=kp_driving)

                out['kp_driving'] = kp_driving
                out['kp_source'] = kp_source

                predictions.append(
                    np.transpose(out['low_res_prediction'].data.cpu().numpy(),
                                 [0, 2, 3, 1])[0])
                predictions.append(
                    np.transpose(out['upscaled_prediction'].data.cpu().numpy(),
                                 [0, 2, 3, 1])[0])

                visualization = Visualizer(
                    **config['visualizer_params']).visualize(
                        source=source_frame, driving=driving_frame, out=out)
                visualizations.append(visualization)

            predictions = np.concatenate(predictions, axis=1)
            result_name = "-".join([x['driving_name'][0], x['source_name'][0]])
            imageio.imsave(os.path.join(png_dir, result_name + '.png'),
                           (255 * predictions).astype(np.uint8))

            image_name = result_name + animate_params['format']
            imageio.mimsave(os.path.join(log_dir, image_name), visualizations)
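# Note: normalize_kp() is called above but not defined in this snippet. A
# minimal sketch of the usual relative-motion normalization, assuming keypoint
# dicts with a 'mean' entry; the real helper typically also supports adaptive
# movement scaling and Jacobian transfer.
def normalize_kp_sketch(kp_source, kp_driving, kp_driving_initial):
    # Shift the source keypoints by the displacement of the driving keypoints
    # relative to the first driving frame.
    kp_new = dict(kp_driving)
    kp_new['mean'] = kp_source['mean'] + (kp_driving['mean'] -
                                          kp_driving_initial['mean'])
    return kp_new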
Example #13
0
def reconstruction(config, generator, kp_detector, checkpoint, log_dir,
                   dataset):
    png_dir = os.path.join(log_dir, 'reconstruction/png')
    log_dir = os.path.join(log_dir, 'reconstruction')

    if checkpoint is not None:
        Logger.load_cpk(checkpoint,
                        generator=generator,
                        kp_detector=kp_detector)
    else:
        raise AttributeError("Checkpoint should be specified for mode='test'.")
    dataloader = DataLoader(dataset,
                            batch_size=1,
                            shuffle=False,
                            num_workers=1)

    if not os.path.exists(log_dir):
        os.makedirs(log_dir)

    if not os.path.exists(png_dir):
        os.makedirs(png_dir)

    loss_list = []
    generator = DataParallelWithCallback(generator)
    kp_detector = DataParallelWithCallback(kp_detector)

    generator.eval()
    kp_detector.eval()

    cat_dict = lambda l, dim: {
        k: torch.cat([v[k] for v in l], dim=dim)
        for k in l[0]
    }
    for it, x in tqdm(enumerate(dataloader)):
        if config['reconstruction_params']['num_videos'] is not None:
            if it > config['reconstruction_params']['num_videos']:
                break
        with torch.no_grad():
            kp_appearance = kp_detector(x['video'][:, :, :1])
            d = x['video'].shape[2]
            kp_video = cat_dict(
                [kp_detector(x['video'][:, :, i:(i + 1)]) for i in range(d)],
                dim=1)

            out = generate(generator,
                           appearance_image=x['video'][:, :, :1],
                           kp_appearance=kp_appearance,
                           kp_video=kp_video)
            x['source'] = x['video'][:, :, :1]

            # Store to .png for evaluation
            out_video_batch = out['video_prediction'].data.cpu().numpy()
            out_video_batch = np.concatenate(np.transpose(
                out_video_batch, [0, 2, 3, 4, 1])[0],
                                             axis=1)
            imageio.imsave(os.path.join(png_dir, x['name'][0] + '.png'),
                           (255 * out_video_batch).astype(np.uint8))

            image = Visualizer(
                **config['visualizer_params']).visualize_reconstruction(
                    x, out)
            image_name = x['name'][0] + config['reconstruction_params'][
                'format']
            imageio.mimsave(os.path.join(log_dir, image_name), image)

            loss = reconstruction_loss(out['video_prediction'].cpu(),
                                       x['video'].cpu(), 1)
            loss_list.append(loss.data.cpu().numpy())
            del x, kp_video, kp_appearance, out, loss
    print("Reconstruction loss: %s" % np.mean(loss_list))
Example #14
0
def animate(config, generator, region_predictor, avd_network, checkpoint,
            log_dir, dataset):
    animate_params = config['animate_params']
    log_dir = os.path.join(log_dir, 'animation')

    dataset = PairedDataset(initial_dataset=dataset,
                            number_of_pairs=animate_params['num_pairs'])
    dataloader = DataLoader(dataset,
                            batch_size=1,
                            shuffle=False,
                            num_workers=1)

    if checkpoint is not None:
        Logger.load_cpk(checkpoint,
                        generator=generator,
                        region_predictor=region_predictor,
                        avd_network=avd_network)
    else:
        raise AttributeError(
            "Checkpoint should be specified for mode='animate'.")

    if not os.path.exists(log_dir):
        os.makedirs(log_dir)

    if torch.cuda.is_available():
        generator = DataParallelWithCallback(generator)
        region_predictor = DataParallelWithCallback(region_predictor)
        avd_network = DataParallelWithCallback(avd_network)

    generator.eval()
    region_predictor.eval()
    avd_network.eval()

    for it, x in tqdm(enumerate(dataloader)):
        with torch.no_grad():
            visualizations = []

            driving_video = x['driving_video']
            source_frame = x['source_video'][:, :, 0, :, :]

            source_region_params = region_predictor(source_frame)
            driving_region_params_initial = region_predictor(
                driving_video[:, :, 0])

            for frame_idx in range(driving_video.shape[2]):
                driving_frame = driving_video[:, :, frame_idx]
                driving_region_params = region_predictor(driving_frame)
                new_region_params = get_animation_region_params(
                    source_region_params,
                    driving_region_params,
                    driving_region_params_initial,
                    mode=animate_params['mode'],
                    avd_network=avd_network)
                out = generator(source_frame,
                                source_region_params=source_region_params,
                                driving_region_params=new_region_params)

                out['driving_region_params'] = driving_region_params
                out['source_region_params'] = source_region_params
                out['new_region_params'] = new_region_params

                visualization = Visualizer(
                    **config['visualizer_params']).visualize(
                        source=source_frame, driving=driving_frame, out=out)
                visualizations.append(visualization)

            result_name = "-".join([x['driving_name'][0], x['source_name'][0]])
            image_name = result_name + animate_params['format']
            imageio.mimsave(os.path.join(log_dir, image_name), visualizations)
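# Note: get_animation_region_params() is called above but not defined in this
# snippet. A minimal sketch of its 'relative' mode, assuming region dicts with
# a 'shift' entry; the real helper also handles affine/covariance terms and the
# 'avd' mode via avd_network.
def relative_region_params_sketch(source_region_params, driving_region_params,
                                  driving_region_params_initial):
    # Move the source regions by the displacement of the driving regions
    # relative to the first driving frame.
    new_region_params = dict(source_region_params)
    new_region_params['shift'] = (source_region_params['shift'] +
                                  driving_region_params['shift'] -
                                  driving_region_params_initial['shift'])
    return new_region_params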
Example #15
0
def main():
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = True

    parser = argparse.ArgumentParser()
    parser.add_argument('--LR',
                        type=float,
                        nargs=2,
                        default=[1e-4, 1e-4],
                        help='learning rates')  # start from 1e-4
    parser.add_argument('--EPOCH', type=int, default=30, help='epoch')
    parser.add_argument('--slice_num',
                        type=int,
                        default=6,
                        help='how many slices to cut')
    parser.add_argument('--batch_size',
                        type=int,
                        default=40,
                        help='batch_size')
    parser.add_argument('--frame_num',
                        type=int,
                        default=5,
                        help='how many frames in a slice')
    parser.add_argument('--model_path',
                        type=str,
                        default='/Disk1/poli/models/DeepRNN/Kinetics_res18',
                        help='model_path')
    parser.add_argument('--model_name',
                        type=str,
                        default='checkpoint',
                        help='model name')
    parser.add_argument('--video_path',
                        type=str,
                        default='/home/poli/kinetics_scaled',
                        help='video path')
    parser.add_argument('--class_num', type=int, default=400, help='class num')
    parser.add_argument('--device_id',
                        type=int,
                        nargs='+',
                        default=[0, 1, 2, 3],
                        help='GPU device ids')
    parser.add_argument('--resume', action='store_true', help='whether resume')
    parser.add_argument('--dropout',
                        type=float,
                        nargs=2,
                        default=[0.2, 0.5],
                        help='dropout')
    parser.add_argument('--weight_decay',
                        type=float,
                        default=1e-4,
                        help='weight decay')
    parser.add_argument('--saveInter',
                        type=int,
                        default=1,
                        help='how many epoch to save once')
    parser.add_argument('--TD_rate',
                        type=float,
                        default=0.0,
                        help='probability of detachout')
    parser.add_argument('--img_size', type=int, default=224, help='image size')
    parser.add_argument('--syn_bn', action='store_true', help='use syn_bn')
    parser.add_argument('--logName',
                        type=str,
                        default='logs_res18',
                        help='log dir name')
    parser.add_argument('--train', action='store_true', help='train the model')
    parser.add_argument('--test', action='store_true', help='test the model')
    parser.add_argument(
        '--overlap_rate',
        type=float,
        default=0.25,
        help='the overlap rate of the overlap coherence training scheme')
    parser.add_argument('--lambdaa',
                        type=float,
                        default=0.0,
                        help='weight of the overlap coherence loss')

    opt = parser.parse_args()
    print(opt)

    torch.cuda.set_device(opt.device_id[0])

    # ######################## Module #################################
    print('Building model')
    model = actionModel(opt.class_num,
                        batch_norm=True,
                        dropout=opt.dropout,
                        TD_rate=opt.TD_rate,
                        image_size=opt.img_size,
                        syn_bn=opt.syn_bn,
                        test_scheme=3)
    print(model)
    if opt.syn_bn:
        model = DataParallelWithCallback(model,
                                         device_ids=opt.device_id).cuda()
    else:
        model = torch.nn.DataParallel(model, device_ids=opt.device_id).cuda()
    print("Channels: " + str(model.module.channels))

    # ########################Optimizer#########################
    optimizer = torch.optim.SGD([{
        'params': model.module.RNN.parameters(),
        'lr': opt.LR[0]
    }, {
        'params': model.module.ShortCut.parameters(),
        'lr': opt.LR[0]
    }, {
        'params': model.module.classifier.parameters(),
        'lr': opt.LR[1]
    }],
                                lr=opt.LR[1],
                                weight_decay=opt.weight_decay,
                                momentum=0.9)

    # ###################### Loss Function ####################################
    loss_classification_func = nn.NLLLoss(reduction='mean')

    def loss_overlap_coherence_func(pre, cur):
        loss = nn.MSELoss()
        return loss(cur, pre.detach())

    # ###################### Resume ##########################################
    resume_epoch = 0
    resume_step = 0
    max_test_acc = 0

    if opt.resume or opt.test:
        print("loading model")
        checkpoint = torch.load(opt.model_path + '/' + opt.model_name,
                                map_location={
                                    'cuda:0': 'cuda:' + str(opt.device_id[0]),
                                    'cuda:1': 'cuda:' + str(opt.device_id[0]),
                                    'cuda:2': 'cuda:' + str(opt.device_id[0]),
                                    'cuda:3': 'cuda:' + str(opt.device_id[0]),
                                    'cuda:4': 'cuda:' + str(opt.device_id[0]),
                                    'cuda:5': 'cuda:' + str(opt.device_id[0]),
                                    'cuda:6': 'cuda:' + str(opt.device_id[0]),
                                    'cuda:7': 'cuda:' + str(opt.device_id[0])
                                })

        model.load_state_dict(checkpoint['model'], strict=True)
        try:
            # Optimizer.load_state_dict() takes no 'strict' argument; skip the
            # optimizer state if it is missing or incompatible.
            optimizer.load_state_dict(checkpoint['opt'])
        except (KeyError, ValueError):
            pass
        for group_id, param_group in enumerate(optimizer.param_groups):
            if group_id == 0:
                param_group['lr'] = opt.LR[0]
            elif group_id == 1:
                param_group['lr'] = opt.LR[0]
            elif group_id == 2:
                param_group['lr'] = opt.LR[1]
        resume_epoch = checkpoint['epoch']
        if 'step' in checkpoint:
            resume_step = checkpoint['step'] + 1
        if 'max_acc' in checkpoint:
            max_test_acc = checkpoint['max_acc']
        print('Finish Loading')
        del checkpoint
    # ###########################################################################

    # training and testing
    model.train()
    predict_for_mAP = []
    label_for_mAP = []

    print("START")

    KineticsLoader = torch.utils.data.DataLoader(
        Kinetic_train_dataset.Kinetics(video_path=opt.video_path +
                                       '/train_frames',
                                       frame_num=opt.frame_num,
                                       batch_size=opt.batch_size,
                                       img_size=opt.img_size,
                                       slice_num=opt.slice_num,
                                       overlap_rate=opt.overlap_rate),
        batch_size=1,
        shuffle=True,
        num_workers=8)
    Loader_test = torch.utils.data.DataLoader(Kinetics_test_dataset.Kinetics(
        video_path=opt.video_path + '/val_frames',
        img_size=224,
        space=5,
        split_num=8,
        lenn=60,
        num_class=opt.class_num),
                                              batch_size=64,
                                              shuffle=True,
                                              num_workers=4)
    tensorboard_writer = SummaryWriter(
        opt.logName,
        purge_step=resume_epoch * len(KineticsLoader) * opt.slice_num +
        resume_step * opt.slice_num)
    test = opt.test
    for epoch in range(resume_epoch, opt.EPOCH):

        predict_for_mAP = []
        label_for_mAP = []

        for step, (x, _, overlap_frame_num,
                   action) in enumerate(KineticsLoader):  # gives batch data

            if opt.train:
                if step + resume_step >= len(KineticsLoader):
                    break
                x = x[0]
                action = action[0]
                overlap_frame_num = overlap_frame_num[0]

                c = [
                    Variable(
                        torch.from_numpy(
                            np.zeros(
                                (x.shape[1], model.module.channels[layer + 1],
                                 model.module.input_size[layer],
                                 model.module.input_size[layer]
                                 )))).cuda().float()
                    for layer in range(model.module.RNN_layer)
                ]
                for slice in range(x.shape[0]):
                    b_x = Variable(x[slice]).cuda()
                    b_action = Variable(action[slice]).cuda()

                    out, out_beforeMerge, c = model(b_x.float(),
                                                    c)  # rnn output
                    for batch in range(len(out)):
                        predict_for_mAP.append(out[batch].data.cpu().numpy())
                        label_for_mAP.append(
                            b_action[batch][-1].data.cpu().numpy())

                    # ###################### overlap coherence loss #######################################################################################
                    loss_coherence = torch.zeros(1).cuda()

                    # calculate the coherence loss between the previous and the current clip
                    if slice != 0:
                        for b in range(out.size()[0]):
                            loss_coherence += loss_overlap_coherence_func(
                                old_overlap[b],
                                torch.exp(out_beforeMerge[
                                    b, :overlap_frame_num[slice, b, 0].int()]))
                        loss_coherence = loss_coherence / out.size()[0]

                    # record the previous clips output
                    old_overlap = []
                    for b in range(out.size()[0]):
                        old_overlap.append(
                            torch.exp(
                                out_beforeMerge[b,
                                                -overlap_frame_num[slice, b,
                                                                   0].int():]))
                    #######################################################################################################################################

                    loss_classification = loss_classification_func(
                        out, b_action[:, -1].long())

                    loss = loss_classification + opt.lambdaa * loss_coherence
                    tensorboard_writer.add_scalar(
                        'train/loss', loss,
                        epoch * len(KineticsLoader) * opt.slice_num +
                        (step + resume_step) * opt.slice_num + slice)

                    loss.backward(retain_graph=False)

                predict_for_mAP = np.array(predict_for_mAP)
                label_for_mAP = np.array(label_for_mAP)
                mAPs = mAP(predict_for_mAP, label_for_mAP, 'Lsm')
                acc = accuracy(predict_for_mAP, label_for_mAP, 'Lsm')
                tensorboard_writer.add_scalar(
                    'train/mAP', mAPs,
                    epoch * len(KineticsLoader) * opt.slice_num +
                    (step + resume_step) * opt.slice_num + slice)
                tensorboard_writer.add_scalar(
                    'train/acc', acc,
                    epoch * len(KineticsLoader) * opt.slice_num +
                    (step + resume_step) * opt.slice_num + slice)

                print("Epoch: " + str(epoch) + " step: " +
                      str(step + resume_step) + " Loss: " +
                      str(loss.data.cpu().numpy()) + " Loss_coherence: " +
                      str(loss_coherence.data.cpu().numpy()) + " mAP: " +
                      str(mAPs)[0:7] + " acc: " + str(acc)[0:7])

                for p in model.module.parameters():
                    if p.grad is not None:
                        p.grad.data.clamp_(min=-5, max=5)

                if step % 2 == 1:
                    optimizer.step()
                    optimizer.zero_grad()

                predict_for_mAP = []
                label_for_mAP = []

            # ################################### test ###############################
            if (step + resume_step) % 700 == 699:
                test = True

            if test:
                print('Start Test')
                TEST_LOSS = AverageMeter()
                with torch.no_grad():
                    model.eval()
                    predict_for_mAP = []
                    label_for_mAP = []
                    print("TESTING")

                    for step_test, (x, _, _, action) in tqdm(
                            enumerate(Loader_test)):  # gives batch data
                        b_x = Variable(x).cuda()
                        b_action = Variable(action).cuda()

                        c = [
                            Variable(
                                torch.from_numpy(
                                    np.zeros((len(b_x),
                                              model.module.channels[layer + 1],
                                              model.module.input_size[layer],
                                              model.module.input_size[layer]
                                              )))).cuda().float()
                            for layer in range(model.module.RNN_layer)
                        ]
                        out, _, _ = model(b_x.float(), c)  # rnn output
                        loss = loss_classification_func(
                            out, b_action[:, -1].long())
                        TEST_LOSS.update(val=loss.data.cpu().numpy())

                        for batch in range(len(out)):
                            predict_for_mAP.append(
                                out[batch].data.cpu().numpy())
                            label_for_mAP.append(
                                b_action[batch][-1].data.cpu().numpy())

                        if step_test % 50 == 0:
                            MAP = mAP(np.array(predict_for_mAP),
                                      np.array(label_for_mAP), 'Lsm')
                            acc = accuracy(np.array(predict_for_mAP),
                                           np.array(label_for_mAP), 'Lsm')
                            print(" Loss: " + str(TEST_LOSS.avg)[0:5] + '  ' +
                                  'accuracy: ' + str(acc)[0:7])

                    predict_for_mAP = np.array(predict_for_mAP)
                    label_for_mAP = np.array(label_for_mAP)

                    MAP = mAP(predict_for_mAP, label_for_mAP, 'Lsm')
                    acc = accuracy(predict_for_mAP, label_for_mAP, 'Lsm')

                    print("mAP: " + str(MAP) + '  ' + 'accuracy: ' + str(acc))

                    if acc > max_test_acc:
                        print('Saving')
                        max_test_acc = acc
                        torch.save(
                            {
                                'model': model.state_dict(),
                                'max_acc': max_test_acc,
                                'epoch': epoch,
                                'step': 0,
                                'opt': optimizer.state_dict()
                            }, opt.model_path + '/' + opt.model_name + '_' +
                            str(epoch) + '_' + str(max_test_acc)[0:6])
                    model.train()

                    test = False
                    predict_for_mAP = []
                    label_for_mAP = []

                    if opt.test:
                        exit()

        if epoch % opt.saveInter == 0:
            print('Saving')
            torch.save(
                {
                    'model': model.state_dict(),
                    'max_acc': max_test_acc,
                    'epoch': epoch,
                    'step': 0,
                    'opt': optimizer.state_dict()
                }, opt.model_path + '/' + opt.model_name + '_' + str(epoch))

        resume_step = 0
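# Note: AverageMeter is used in the test loop above but not defined in this
# snippet. A minimal sketch of the usual running-average helper it appears to
# be (only update() and .avg are needed above).
class AverageMeterSketch:
    def __init__(self):
        self.sum = 0.0
        self.count = 0
        self.avg = 0.0

    def update(self, val, n=1):
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count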
Example #16
0
def train_motion_embedding(config, generator, motion_generator, kp_detector,
                           checkpoint, log_dir, dataset, valid_dataset,
                           device_ids):

    png_dir = os.path.join(log_dir, 'train_motion_embedding/png')
    log_dir = os.path.join(log_dir, 'train_motion_embedding')

    if not os.path.exists(log_dir):
        os.makedirs(log_dir)

    if not os.path.exists(png_dir):
        os.makedirs(png_dir)

    train_params = config['train_motion_embedding_params']
    optimizer_generator = torch.optim.Adam(motion_generator.parameters(),
                                           lr=train_params['lr'],
                                           betas=(0.5, 0.999))

    if checkpoint is not None:
        Logger.load_cpk(checkpoint,
                        generator=generator,
                        kp_detector=kp_detector)
    else:
        raise AttributeError(
            "kp_detector_checkpoint should be specified for mode='test'.")

    start_epoch = 0
    it = 0

    scheduler_generator = MultiStepLR(optimizer_generator,
                                      train_params['epoch_milestones'],
                                      gamma=0.1,
                                      last_epoch=start_epoch - 1)
    dataloader = DataLoader(dataset,
                            batch_size=train_params['batch_size'],
                            shuffle=True,
                            num_workers=4)
    valid_dataloader = DataLoader(valid_dataset,
                                  batch_size=1,
                                  shuffle=False,
                                  num_workers=1)

    loss_list = []
    motion_generator_full = MotionGeneratorFullModel(motion_generator,
                                                     train_params)
    motion_generator_full_par = DataParallelWithCallback(motion_generator_full,
                                                         device_ids=device_ids)

    kp_detector = DataParallelWithCallback(kp_detector)
    generator.eval()
    kp_detector.eval()
    cat_dict = lambda l, dim: {
        k: torch.cat([v[k] for v in l], dim=dim)
        for k in l[0]
    }

    with Logger(log_dir=log_dir,
                visualizer_params=config['visualizer_params'],
                **train_params['log_params']) as logger:
        #valid_motion_embedding(config, valid_dataloader, motion_generator, kp_detector, log_dir)
        for epoch in range(start_epoch, train_params['num_epochs']):
            print("Epoch {}".format(epoch))
            motion_generator.train()
            for it, x in tqdm(enumerate(dataloader)):

                with torch.no_grad():

                    # x['video']: [bz, ch, #frames, H, W]
                    # detect keypoint for first frame
                    kp_appearance = kp_detector(x['video'][:, :, :1])
                    # kp_appearance['mean']: [bz, frame idx, #kp, 2]
                    # kp_appearance['var']: [bz, frame idx, #kp, 2, 2]

                    d = x['video'].shape[2]
                    # kp_video['mean']: [bz, #frame, #kp, 2]
                    # kp_video['var']: [bz, #frame, #kp, 2, 2]
                    kp_video = cat_dict([
                        kp_detector(x['video'][:, :, i:(i + 1)])
                        for i in range(d)
                    ],
                                        dim=1)


                loss = motion_generator_full_par(d, kp_video, epoch,
                                                 it == len(dataloader) - 1)

                loss.backward()
                optimizer_generator.step()
                optimizer_generator.zero_grad()
                generator_loss_values = [loss.detach().cpu().numpy()]

                logger.log_iter(it,
                                names=generator_loss_names(
                                    train_params['loss_weights']),
                                values=generator_loss_values,
                                inp=x)

            valid_motion_embedding(config, valid_dataloader, motion_generator,
                                   kp_detector, log_dir)

            scheduler_generator.step()
            logger.log_epoch(epoch, {
                'generator': generator,
                'optimizer_generator': optimizer_generator
            })
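# Note: small illustration of the cat_dict helper used in Examples #13 and #16.
# It merges a list of per-frame keypoint dicts by concatenating each key along
# the frame dimension; the shapes below are illustrative only.
import torch

cat_dict = lambda l, dim: {k: torch.cat([v[k] for v in l], dim=dim) for k in l[0]}

frame_a = {'mean': torch.zeros(1, 1, 10, 2)}  # [bz, 1 frame, #kp, 2]
frame_b = {'mean': torch.ones(1, 1, 10, 2)}
merged = cat_dict([frame_a, frame_b], dim=1)
print(merged['mean'].shape)  # torch.Size([1, 2, 10, 2])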