Example #1
    def finetune(self, source_path, save_checkpoint=True):
        checkpoint_path = os.path.splitext(source_path)[0] + '_Gr.pth'
        if os.path.isfile(checkpoint_path):
            print('=> Loading the reenactment generator finetuned on: "%s"...' % os.path.basename(source_path))
            checkpoint = torch.load(checkpoint_path)
            if self.gpus and len(self.gpus) > 1:
                self.Gr.module.load_state_dict(checkpoint['state_dict'])
            else:
                self.Gr.load_state_dict(checkpoint['state_dict'])
            return

        print('=> Finetuning the reenactment generator on: "%s"...' % os.path.basename(source_path))
        torch.set_grad_enabled(True)
        self.Gr.train(True)
        img_transforms = img_lms_pose_transforms.Compose([Pyramids(2), ToTensor(), Normalize()])
        train_dataset = SingleSeqRandomPairDataset(source_path, transform=img_transforms, postfixes=('_lms.npz',))
        train_sampler = RandomSampler(train_dataset, replacement=True, num_samples=self.finetune_iterations)
        train_loader = DataLoader(train_dataset, batch_size=self.finetune_batch_size, sampler=train_sampler,
                                  num_workers=self.finetune_workers, pin_memory=True, drop_last=True, shuffle=False)
        optimizer = optim.Adam(self.Gr.parameters(), lr=self.finetune_lr, betas=(0.5, 0.999))

        # For each batch in the training data
        for i, (img, landmarks) in enumerate(tqdm(train_loader, unit='batches', file=sys.stdout)):
            # Prepare input
            with torch.no_grad():
                # For each view: push images and landmarks to the device
                landmarks[1] = landmarks[1].to(self.device)
                for j in range(len(img)):
                    # For each pyramid image: push to device
                    for p in range(len(img[j])):
                        img[j][p] = img[j][p].to(self.device)

                # Concatenate pyramid images with context to derive the final input
                input = []
                for p in range(len(img[0])):
                    context = self.landmarks_decoders[p](landmarks[1])
                    input.append(torch.cat((img[0][p], context), dim=1))

            # Reenactment
            img_pred = self.Gr(input)

            # Reconstruction loss
            loss_pixelwise = self.criterion_pixelwise(img_pred, img[1][0])
            loss_id = self.criterion_id(img_pred, img[1][0])
            loss_rec = 0.1 * loss_pixelwise + loss_id

            # Update generator weights
            optimizer.zero_grad()
            loss_rec.backward()
            optimizer.step()

        # Save finetuned weights to file
        if save_checkpoint:
            arch = self.Gr.module.arch if self.gpus and len(self.gpus) > 1 else self.Gr.arch
            state_dict = self.Gr.module.state_dict() if self.gpus and len(self.gpus) > 1 else self.Gr.state_dict()
            torch.save({'state_dict': state_dict, 'arch': arch}, checkpoint_path)

        torch.set_grad_enabled(False)
        self.Gr.train(False)
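At each pyramid level the generator input is the source image concatenated channel-wise with a landmark-heatmap context decoded from the driving landmarks. Below is a minimal sketch of that input construction with dummy decoders and illustrative shapes; the 98 context channels are an assumption chosen so that the concatenated input matches the 101 input channels of the res_unet.MultiScaleResUNet(in_nc=101) generator configured in the training example further down.

# Minimal sketch of the pyramid + landmark-context input construction (dummy decoders, illustrative shapes).
import torch
import torch.nn.functional as F

def build_reenactment_input(img_pyd, landmarks, decoders):
    # img_pyd: per-level source images [B, 3, H/2**p, W/2**p]
    # decoders: one context decoder per level, returning [B, C, H/2**p, W/2**p]
    inputs = []
    for p, img in enumerate(img_pyd):
        context = decoders[p](landmarks)                  # landmark heatmap context for this level
        inputs.append(torch.cat((img, context), dim=1))   # channel-wise concat: 3 + C channels
    return inputs

# Example with stand-in decoders that simply resize a precomputed heatmap to each level:
img_pyd = [torch.randn(2, 3, 256, 256), torch.randn(2, 3, 128, 128)]
heatmaps = torch.rand(2, 98, 64, 64)                      # 98 channels is an assumed count
decoders = [lambda h, s=s: F.interpolate(h, size=(s, s), mode='bilinear', align_corners=False)
            for s in (256, 128)]
inputs = build_reenactment_input(img_pyd, heatmaps, decoders)
print([x.shape for x in inputs])                          # [2, 101, 256, 256], [2, 101, 128, 128]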
Example #2
    def __call__(self,
                 source_path,
                 target_path,
                 output_path=None,
                 select_source='longest',
                 select_target='longest',
                 finetune=None):
        is_vid = os.path.splitext(source_path)[1] == '.mp4'
        finetune = self.finetune_enabled and is_vid if finetune is None else finetune and is_vid

        # Validation
        assert os.path.isfile(
            source_path), 'Source path "%s" does not exist' % source_path
        assert os.path.isfile(
            target_path), 'Target path "%s" does not exist' % target_path

        # Cache input
        source_cache_dir, source_seq_file_path, _ = self.cache(source_path)
        target_cache_dir, target_seq_file_path, _ = self.cache(target_path)

        # Load sequences from file
        with open(source_seq_file_path, "rb") as fp:  # Unpickling
            source_seq_list = pickle.load(fp)
        with open(target_seq_file_path, "rb") as fp:  # Unpickling
            target_seq_list = pickle.load(fp)

        # Select source and target sequence
        source_seq = select_seq(source_seq_list, select_source)
        target_seq = select_seq(target_seq_list, select_target)

        # Set source and target sequence videos paths
        src_path_no_ext, src_ext = os.path.splitext(source_path)
        src_vid_seq_name = os.path.basename(
            src_path_no_ext) + '_seq%02d%s' % (source_seq.id, src_ext)
        src_vid_seq_path = os.path.join(source_cache_dir, src_vid_seq_name)
        tgt_path_no_ext, tgt_ext = os.path.splitext(target_path)
        tgt_vid_seq_name = os.path.basename(
            tgt_path_no_ext) + '_seq%02d%s' % (target_seq.id, tgt_ext)
        tgt_vid_seq_path = os.path.join(target_cache_dir, tgt_vid_seq_name)

        # Set output path
        if output_path is not None:
            if os.path.isdir(output_path):
                output_filename = f'{os.path.basename(src_path_no_ext)}_{os.path.basename(tgt_path_no_ext)}.mp4'
                output_path = os.path.join(output_path, output_filename)

        # Initialize appearance map
        src_transform = img_lms_pose_transforms.Compose(
            [Rotate(), Pyramids(2),
             ToTensor(), Normalize()])
        tgt_transform = img_lms_pose_transforms.Compose(
            [ToTensor(), Normalize()])
        appearance_map = AppearanceMapDataset(
            src_vid_seq_path, tgt_vid_seq_path, src_transform, tgt_transform,
            self.landmarks_postfix, self.pose_postfix,
            self.segmentation_postfix, self.min_radius)
        appearance_map_loader = DataLoader(appearance_map,
                                           batch_size=self.batch_size,
                                           num_workers=1,
                                           pin_memory=True,
                                           drop_last=False,
                                           shuffle=False)

        # Initialize video writer
        self.video_renderer.init(target_path,
                                 target_seq,
                                 output_path,
                                 _appearance_map=appearance_map)

        # Finetune reenactment model on source sequences
        if finetune:
            self.finetune(src_vid_seq_path, self.finetune_save)

        print(
            f'=> Face swapping: "{src_vid_seq_name}" -> "{tgt_vid_seq_name}"...'
        )

        # For each batch of frames in the target video
        for i, (src_frame, src_landmarks, src_poses, bw, tgt_frame, tgt_landmarks, tgt_pose, tgt_mask) \
                in enumerate(tqdm(appearance_map_loader, unit='batches', file=sys.stdout)):
            # Prepare input
            for p in range(len(src_frame)):
                src_frame[p] = src_frame[p].to(self.device)
            tgt_frame = tgt_frame.to(self.device)
            tgt_landmarks = tgt_landmarks.to(self.device)
            # tgt_mask = tgt_mask.unsqueeze(1).to(self.device)
            # TODO: check if the boolean tensor bug is fixed
            tgt_mask = tgt_mask.unsqueeze(1).int().to(self.device).bool()
            bw = bw.to(self.device)
            bw_indices = torch.nonzero(torch.any(bw > 0, dim=0),
                                       as_tuple=True)[0]
            bw = bw[:, bw_indices]

            # For each source frame perform reenactment
            reenactment_triplet = []
            for j in bw_indices:
                input = []
                for p in range(len(src_frame)):
                    context = self.landmarks_decoders[p](tgt_landmarks)
                    input.append(
                        torch.cat((src_frame[p][:, j], context), dim=1))

                # Reenactment
                reenactment_triplet.append(self.Gr(input).unsqueeze(1))
            reenactment_tensor = torch.cat(reenactment_triplet, dim=1)

            # Barycentric interpolation of reenacted frames
            reenactment_tensor = (reenactment_tensor *
                                  bw.view(*bw.shape, 1, 1, 1)).sum(dim=1)

            # Compute reenactment segmentation
            reenactment_seg = self.S(reenactment_tensor)
            reenactment_background_mask_tensor = (reenactment_seg.argmax(1) != 1).unsqueeze(1)

            # Remove the background of the aligned face
            reenactment_tensor.masked_fill_(reenactment_background_mask_tensor,
                                            -1.0)

            # Soften target mask
            soft_tgt_mask, eroded_tgt_mask = self.smooth_mask(tgt_mask)

            # Complete face
            inpainting_input_tensor = torch.cat(
                (reenactment_tensor, eroded_tgt_mask.float()), dim=1)
            inpainting_input_tensor_pyd = create_pyramid(
                inpainting_input_tensor, 2)
            completion_tensor = self.Gc(inpainting_input_tensor_pyd)

            # Blend faces
            transfer_tensor = transfer_mask(completion_tensor, tgt_frame,
                                            eroded_tgt_mask)
            blend_input_tensor = torch.cat(
                (transfer_tensor, tgt_frame, eroded_tgt_mask.float()), dim=1)
            blend_input_tensor_pyd = create_pyramid(blend_input_tensor, 2)
            blend_tensor = self.Gb(blend_input_tensor_pyd)

            result_tensor = blend_tensor * soft_tgt_mask + tgt_frame * (1 - soft_tgt_mask)

            # Write output
            if self.verbose == 0:
                self.video_renderer.write(result_tensor)
            elif self.verbose == 1:
                curr_src_frames = [
                    src_frame[0][:, i] for i in range(src_frame[0].shape[1])
                ]
                self.video_renderer.write(*curr_src_frames, result_tensor,
                                          tgt_frame)
            else:
                curr_src_frames = [
                    src_frame[0][:, i] for i in range(src_frame[0].shape[1])
                ]
                tgt_seg_blend = blend_seg_label(tgt_frame,
                                                tgt_mask.squeeze(1),
                                                alpha=0.2)
                soft_tgt_mask = soft_tgt_mask.mul(2.).sub(1.).repeat(
                    1, 3, 1, 1)
                self.video_renderer.write(*curr_src_frames, result_tensor,
                                          tgt_frame, reenactment_tensor,
                                          completion_tensor, transfer_tensor,
                                          soft_tgt_mask, tgt_seg_blend,
                                          tgt_pose)

        # Load original reenactment weights
        if finetune:
            if self.gpus and len(self.gpus) > 1:
                self.Gr.module.load_state_dict(self.reenactment_state_dict)
            else:
                self.Gr.load_state_dict(self.reenactment_state_dict)

        # Finalize video and wait for the video writer to finish writing
        self.video_renderer.finalize()
        self.video_renderer.wait_until_finished()
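Two tensor operations carry most of the compositing logic above: the barycentric interpolation that fuses the reenactment results of the selected appearance-map neighbours, and the final soft-mask alpha blend of the generated face into the target frame. A small self-contained sketch of both steps, with made-up shapes, follows.

# Sketch of the barycentric interpolation and soft-mask blending steps (random tensors, illustrative shapes).
import torch

def barycentric_interpolate(reenactment, bw):
    # reenactment: [B, K, 3, H, W] reenacted frames from K appearance-map neighbours
    # bw: [B, K] barycentric weights (each row sums to 1)
    return (reenactment * bw.view(*bw.shape, 1, 1, 1)).sum(dim=1)   # -> [B, 3, H, W]

def soft_blend(generated, target, soft_mask):
    # Alpha-blend the generated face into the target frame using a soft mask in [0, 1]
    return generated * soft_mask + target * (1 - soft_mask)

B, K, H, W = 2, 3, 256, 256
reenactment = torch.randn(B, K, 3, H, W)
bw = torch.softmax(torch.randn(B, K), dim=1)          # weights that sum to 1 per sample
fused = barycentric_interpolate(reenactment, bw)      # [B, 3, H, W]
result = soft_blend(fused, torch.randn(B, 3, H, W), torch.rand(B, 1, H, W))
print(fused.shape, result.shape)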
Example #3
    def __call__(self,
                 source_path,
                 target_path,
                 output_path=None,
                 select_source='longest',
                 select_target='longest',
                 finetune=None):
        is_vid = os.path.splitext(source_path)[1] == '.mp4'
        finetune = self.finetune_enabled and is_vid if finetune is None else finetune and is_vid

        # Validation
        assert os.path.isfile(
            source_path), 'Source path "%s" does not exist' % source_path
        assert os.path.isfile(
            target_path), 'Target path "%s" does not exist' % target_path

        # Cache input
        source_cache_dir, source_seq_file_path, _ = self.cache(source_path)
        target_cache_dir, target_seq_file_path, _ = self.cache(target_path)

        # Load sequences from file
        with open(source_seq_file_path, "rb") as fp:  # Unpickling
            source_seq_list = pickle.load(fp)
        with open(target_seq_file_path, "rb") as fp:  # Unpickling
            target_seq_list = pickle.load(fp)

        # Select source and target sequence
        source_seq = select_seq(source_seq_list, select_source)
        target_seq = select_seq(target_seq_list, select_target)

        # Set source and target sequence videos paths
        src_path_no_ext, src_ext = os.path.splitext(source_path)
        src_vid_seq_name = os.path.basename(
            src_path_no_ext) + '_seq%02d%s' % (source_seq.id, src_ext)
        src_vid_seq_path = os.path.join(source_cache_dir, src_vid_seq_name)
        tgt_path_no_ext, tgt_ext = os.path.splitext(target_path)
        tgt_vid_seq_name = os.path.basename(
            tgt_path_no_ext) + '_seq%02d%s' % (target_seq.id, tgt_ext)
        tgt_vid_seq_path = os.path.join(target_cache_dir, tgt_vid_seq_name)

        # Set output path
        if output_path is not None:
            if os.path.isdir(output_path):
                output_filename = f'{os.path.basename(src_path_no_ext)}_{os.path.basename(tgt_path_no_ext)}.mp4'
                output_path = os.path.join(output_path, output_filename)

        # Initialize appearance map
        src_transform = img_lms_pose_transforms.Compose(
            [Rotate(), Pyramids(2),
             ToTensor(), Normalize()])
        tgt_transform = img_lms_pose_transforms.Compose(
            [ToTensor(), Normalize()])
        appearance_map = AppearanceMapDataset(
            src_vid_seq_path, tgt_vid_seq_path, src_transform, tgt_transform,
            self.landmarks_postfix, self.pose_postfix,
            self.segmentation_postfix, self.min_radius)
        appearance_map_loader = DataLoader(appearance_map,
                                           batch_size=self.batch_size,
                                           num_workers=1,
                                           pin_memory=True,
                                           drop_last=False,
                                           shuffle=False)

        # Initialize video renderer
        self.video_renderer.init(target_path,
                                 target_seq,
                                 output_path,
                                 _appearance_map=appearance_map)

        # Finetune reenactment model on source sequences
        if finetune:
            self.finetune(src_vid_seq_path, self.finetune_save)

        print(
            f'=> Face reenactment: "{src_vid_seq_name}" -> "{tgt_vid_seq_name}"...'
        )

        # For each batch of frames in the target video
        for i, (src_frame, src_landmarks, src_poses, bw, tgt_frame, tgt_landmarks, tgt_pose, tgt_mask) \
                in enumerate(tqdm(appearance_map_loader, unit='batches', file=sys.stdout)):
            # Prepare input
            for p in range(len(src_frame)):
                src_frame[p] = src_frame[p].to(self.device)
            tgt_frame = tgt_frame.to(self.device)  # needed below, where it is concatenated with device tensors
            tgt_landmarks = tgt_landmarks.to(self.device)
            bw = bw.to(self.device)
            bw_indices = torch.nonzero(torch.any(bw > 0, dim=0),
                                       as_tuple=True)[0]
            bw = bw[:, bw_indices]

            # For each source frame perform reenactment
            reenactment_triplet = []
            for j in bw_indices:
                input = []
                for p in range(len(src_frame)):
                    context = self.landmarks_decoders[p](tgt_landmarks)
                    input.append(
                        torch.cat((src_frame[p][:, j], context), dim=1))

                # Reenactment
                reenactment_triplet.append(self.Gr(input).unsqueeze(1))
            reenactment_tensor = torch.cat(reenactment_triplet, dim=1)

            # Barycentric interpolation of reenacted frames
            reenactment_tensor = (reenactment_tensor *
                                  bw.view(*bw.shape, 1, 1, 1)).sum(dim=1)

            # Write output
            if self.verbose == 0:
                self.video_renderer.write(reenactment_tensor)
            elif self.verbose == 1:
                print((src_frame[0][:, 0][0], reenactment_tensor[0], tgt_frame[0]))
                write_bgr = tensor2bgr(torch.cat(
                    (src_frame[0][:, 0][0], reenactment_tensor[0], tgt_frame[0]), dim=2))
                cv2.imwrite(fr'{output_path}.jpg', write_bgr)
                self.video_renderer.write(src_frame[0][:, 0], reenactment_tensor, tgt_frame)
            else:
                self.video_renderer.write(src_frame[0][:, 0], src_frame[0][:, 1], src_frame[0][:, 2],
                                          reenactment_tensor, tgt_frame, tgt_pose)

        # Load original reenactment weights
        if finetune:
            if self.gpus and len(self.gpus) > 1:
                self.Gr.module.load_state_dict(self.reenactment_state_dict)
            else:
                self.Gr.load_state_dict(self.reenactment_state_dict)

        # Wait for the video render to finish rendering
        self.video_renderer.finalize()
        self.video_renderer.wait_until_finished()
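The verbose branch above concatenates a source frame, the reenactment result and the target frame along the width axis and dumps them as a JPEG via tensor2bgr. Below is a rough stand-in for that conversion, assuming CHW RGB tensors normalized to [-1, 1] (the usual effect of the Normalize() transform); the real fsgan.utils.img_utils.tensor2bgr may differ in its exact scaling.

# Rough stand-in for a tensor2bgr-style conversion and side-by-side debug dump
# (assumes CHW RGB tensors in [-1, 1]; the real fsgan.utils.img_utils.tensor2bgr may differ).
import cv2
import torch

def tensor_to_bgr_uint8(t):
    # [3, H, W] RGB tensor in [-1, 1] -> HxWx3 uint8 BGR image
    img = t.detach().cpu().clamp(-1, 1).add(1).mul(127.5).byte()   # [-1, 1] -> [0, 255]
    img = img.permute(1, 2, 0).numpy()                             # CHW -> HWC
    return img[:, :, ::-1].copy()                                  # RGB -> BGR (contiguous for OpenCV)

def dump_debug_strip(path, *tensors):
    # Concatenate CHW tensors along the width axis and write a single JPEG
    strip = torch.cat(tensors, dim=2)
    cv2.imwrite(path, tensor_to_bgr_uint8(strip))

# Example: dump_debug_strip('debug.jpg', src_frame[0][:, 0][0], reenactment_tensor[0], tgt_frame[0])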
Example #4
def main(
    # General arguments
    exp_dir, resume_dir=None, start_epoch=None, epochs=(90,), iterations=None, resolutions=(128, 256),
    lr_gen=(1e-4,), lr_dis=(1e-4,), gpus=None, workers=4, batch_size=(64,), seed=None, log_freq=20,

    # Data arguments
    train_dataset='opencv_video_seq_dataset.VideoSeqDataset', val_dataset=None, numpy_transforms=None,
    tensor_transforms=('img_lms_pose_transforms.ToTensor()', 'img_lms_pose_transforms.Normalize()'),

    # Training arguments
    optimizer='optim.SGD(momentum=0.9,weight_decay=1e-4)', scheduler='lr_scheduler.StepLR(step_size=30,gamma=0.1)',
    pretrained=False, criterion_pixelwise='nn.L1Loss', criterion_id='vgg_loss.VGGLoss',
    criterion_attr='vgg_loss.VGGLoss', criterion_gan='gan_loss.GANLoss(use_lsgan=True)',
    generator='res_unet.MultiScaleResUNet(in_nc=101,out_nc=3)',
    discriminator='discriminators_pix2pix.MultiscaleDiscriminator',
    rec_weight=1.0, gan_weight=0.001
):
    def proces_epoch(dataset_loader, train=True):
        stage = 'TRAINING' if train else 'VALIDATION'
        total_iter = len(dataset_loader) * dataset_loader.batch_size * epoch
        pbar = tqdm(dataset_loader, unit='batches')

        # Set networks training mode
        G.train(train)
        D.train(train)

        # Reset logger
        logger.reset(prefix='{} {}X{}: Epoch: {} / {}; LR: {:.0e}; '.format(
            stage, res, res, epoch + 1, res_epochs,  optimizer_G.param_groups[0]['lr']))

        # For each batch in the training data
        for i, (img, landmarks, target) in enumerate(pbar):
            # Prepare input
            with torch.no_grad():
                # For each view: push images and landmarks to the device
                landmarks[1] = landmarks[1].to(device)
                for j in range(len(img)):
                    # landmarks[j] = landmarks[j].to(device)

                    # For each pyramid image: push to device
                    for p in range(len(img[j])):
                        img[j][p] = img[j][p].to(device)

                # Remove unnecessary pyramids
                for j in range(len(img)):
                    img[j] = img[j][-ri - 1:]

                # Concatenate pyramid images with context to derive the final input
                input = []
                for p in range(len(img[0])):
                    context = res_landmarks_decoders[p](landmarks[1])
                    input.append(torch.cat((img[0][p], context), dim=1))

            # Reenactment
            img_pred = G(input)

            # Fake Detection and Loss
            img_pred_pyd = img_utils.create_pyramid(img_pred, len(img[0]))
            pred_fake_pool = D([x.detach() for x in img_pred_pyd])
            loss_D_fake = criterion_gan(pred_fake_pool, False)

            # Real Detection and Loss
            pred_real = D(img[1])
            loss_D_real = criterion_gan(pred_real, True)

            loss_D_total = (loss_D_fake + loss_D_real) * 0.5

            # GAN loss (Fake Passability Loss)
            pred_fake = D(img_pred_pyd)
            loss_G_GAN = criterion_gan(pred_fake, True)

            # Reconstruction and segmentation loss
            loss_pixelwise = criterion_pixelwise(img_pred, img[1][0])
            loss_id = criterion_id(img_pred, img[1][0])
            loss_attr = criterion_attr(img_pred, img[1][0])
            loss_rec = 0.1 * loss_pixelwise + 0.5 * loss_id + 0.5 * loss_attr

            loss_G_total = rec_weight * loss_rec + gan_weight * loss_G_GAN

            if train:
                # Update generator weights
                optimizer_G.zero_grad()
                loss_G_total.backward()
                optimizer_G.step()

                # Update discriminator weights
                optimizer_D.zero_grad()
                loss_D_total.backward()
                optimizer_D.step()

            logger.update('losses', pixelwise=loss_pixelwise, id=loss_id, attr=loss_attr, rec=loss_rec,
                          g_gan=loss_G_GAN, d_gan=loss_D_total)
            total_iter += dataset_loader.batch_size

            # Batch logs
            pbar.set_description(str(logger))
            if train and i % log_freq == 0:
                logger.log_scalars_val('%dx%d/batch' % (res, res), total_iter)

        # Epoch logs
        logger.log_scalars_avg('%dx%d/epoch/%s' % (res, res, 'train' if train else 'val'), epoch)
        if not train:
            # Log images
            grid = img_utils.make_grid(img[0][0], img_pred, img[1][0])
            logger.log_image('%dx%d/vis' % (res, res), grid, epoch)

        return logger.log_dict['losses']['rec'].avg

    #################
    # Main pipeline #
    #################

    # Validation
    resolutions = resolutions if isinstance(resolutions, (list, tuple)) else [resolutions]
    lr_gen = lr_gen if isinstance(lr_gen, (list, tuple)) else [lr_gen]
    lr_dis = lr_dis if isinstance(lr_dis, (list, tuple)) else [lr_dis]
    epochs = epochs if isinstance(epochs, (list, tuple)) else [epochs]
    batch_size = batch_size if isinstance(batch_size, (list, tuple)) else [batch_size]
    iterations = iterations if iterations is None or isinstance(iterations, (list, tuple)) else [iterations]

    lr_gen = lr_gen * len(resolutions) if len(lr_gen) == 1 else lr_gen
    lr_dis = lr_dis * len(resolutions) if len(lr_dis) == 1 else lr_dis
    epochs = epochs * len(resolutions) if len(epochs) == 1 else epochs
    batch_size = batch_size * len(resolutions) if len(batch_size) == 1 else batch_size
    if iterations is not None:
        iterations = iterations * len(resolutions) if len(iterations) == 1 else iterations
        iterations = utils.str2int(iterations)

    if not os.path.isdir(exp_dir):
        raise RuntimeError('Experiment directory was not found: \'' + exp_dir + '\'')
    assert len(lr_gen) == len(resolutions)
    assert len(lr_dis) == len(resolutions)
    assert len(epochs) == len(resolutions)
    assert len(batch_size) == len(resolutions)
    assert iterations is None or len(iterations) == len(resolutions)

    # Seed
    utils.set_seed(seed)

    # Check CUDA device availability
    device, gpus = utils.set_device(gpus)

    # Initialize loggers
    logger = TensorBoardLogger(log_dir=exp_dir)

    # Initialize datasets
    numpy_transforms = obj_factory(numpy_transforms) if numpy_transforms is not None else []
    tensor_transforms = obj_factory(tensor_transforms) if tensor_transforms is not None else []
    img_transforms = img_lms_pose_transforms.Compose(numpy_transforms + tensor_transforms)

    train_dataset = obj_factory(train_dataset, transform=img_transforms)
    if val_dataset is not None:
        val_dataset = obj_factory(val_dataset, transform=img_transforms)

    # Create networks
    G_arch = utils.get_arch(generator)
    D_arch = utils.get_arch(discriminator)
    G = obj_factory(generator).to(device)
    D = obj_factory(discriminator).to(device)

    # Resume from a checkpoint or initialize the networks weights randomly
    checkpoint_dir = exp_dir if resume_dir is None else resume_dir
    G_path = os.path.join(checkpoint_dir, 'G_latest.pth')
    D_path = os.path.join(checkpoint_dir, 'D_latest.pth')
    best_loss = 1e6
    curr_res = resolutions[0]
    optimizer_G_state, optimizer_D_state = None, None
    if os.path.isfile(G_path) and os.path.isfile(D_path):
        print("=> loading checkpoint from '{}'".format(checkpoint_dir))
        # G
        checkpoint = torch.load(G_path)
        if 'resolution' in checkpoint:
            curr_res = checkpoint['resolution']
            start_epoch = checkpoint['epoch'] if start_epoch is None else start_epoch
        # else:
        #     curr_res = resolutions[1] if len(resolutions) > 1 else resolutions[0]
        best_loss_key = 'best_loss_%d' % curr_res
        best_loss = checkpoint[best_loss_key] if best_loss_key in checkpoint else best_loss
        G.apply(utils.init_weights)
        G.load_state_dict(checkpoint['state_dict'], strict=False)
        optimizer_G_state = checkpoint['optimizer']

        # D
        D.apply(utils.init_weights)
        if os.path.isfile(D_path):
            checkpoint = torch.load(D_path)
            D.load_state_dict(checkpoint['state_dict'], strict=False)
            optimizer_D_state = checkpoint['optimizer']
    else:
        print("=> no checkpoint found at '{}'".format(checkpoint_dir))
        if not pretrained:
            print("=> randomly initializing networks...")
            G.apply(utils.init_weights)
            D.apply(utils.init_weights)

    # Initialize landmarks decoders
    landmarks_decoders = []
    for res in resolutions:
        landmarks_decoders.insert(0, landmarks_utils.LandmarksHeatMapDecoder(res).to(device))

    # Losses
    criterion_pixelwise = obj_factory(criterion_pixelwise).to(device)
    criterion_id = obj_factory(criterion_id).to(device)
    criterion_attr = obj_factory(criterion_attr).to(device)
    criterion_gan = obj_factory(criterion_gan).to(device)

    # Support multiple GPUs
    if gpus and len(gpus) > 1:
        G = nn.DataParallel(G, gpus)
        D = nn.DataParallel(D, gpus)
        criterion_id.vgg = nn.DataParallel(criterion_id.vgg, gpus)
        criterion_attr.vgg = nn.DataParallel(criterion_attr.vgg, gpus)
        landmarks_decoders = [nn.DataParallel(ld, gpus) for ld in landmarks_decoders]

    # For each resolution
    start_res_ind = int(np.log2(curr_res)) - int(np.log2(resolutions[0]))
    start_epoch = 0 if start_epoch is None else start_epoch
    for ri in range(start_res_ind, len(resolutions)):
        res = resolutions[ri]
        res_lr_gen = lr_gen[ri]
        res_lr_dis = lr_dis[ri]
        res_epochs = epochs[ri]
        res_iterations = iterations[ri] if iterations is not None else None
        res_batch_size = batch_size[ri]
        res_landmarks_decoders = landmarks_decoders[-ri - 1:]

        # Optimizer and scheduler
        optimizer_G = obj_factory(optimizer, G.parameters(), lr=res_lr_gen)
        optimizer_D = obj_factory(optimizer, D.parameters(), lr=res_lr_dis)
        scheduler_G = obj_factory(scheduler, optimizer_G)
        scheduler_D = obj_factory(scheduler, optimizer_D)
        if optimizer_G_state is not None:
            optimizer_G.load_state_dict(optimizer_G_state)
            optimizer_G_state = None
        if optimizer_D_state is not None:
            optimizer_D.load_state_dict(optimizer_D_state)
            optimizer_D_state = None

        # Initialize data loaders
        if res_iterations is None:
            train_sampler = tutils.data.sampler.WeightedRandomSampler(train_dataset.weights, len(train_dataset))
        else:
            train_sampler = tutils.data.sampler.WeightedRandomSampler(train_dataset.weights, res_iterations)
        train_loader = tutils.data.DataLoader(train_dataset, batch_size=res_batch_size, sampler=train_sampler,
                                              num_workers=workers, pin_memory=True, drop_last=True, shuffle=False)
        if val_dataset is not None:
            if res_iterations is None:
                val_sampler = tutils.data.sampler.WeightedRandomSampler(val_dataset.weights, len(val_dataset))
            else:
                val_iterations = (res_iterations * len(val_dataset.classes)) // len(train_dataset.classes)
                val_sampler = tutils.data.sampler.WeightedRandomSampler(val_dataset.weights, val_iterations)
            val_loader = tutils.data.DataLoader(val_dataset, batch_size=res_batch_size, sampler=val_sampler,
                                                num_workers=workers, pin_memory=True, drop_last=True, shuffle=False)
        else:
            val_loader = None

        # For each epoch
        for epoch in range(start_epoch, res_epochs):
            total_loss = proces_epoch(train_loader, train=True)
            if val_loader is not None:
                with torch.no_grad():
                    total_loss = proces_epoch(val_loader, train=False)

            # Scheduler step (in PyTorch 1.1.0+ it must come after the training and validation steps)
            if isinstance(scheduler_G, torch.optim.lr_scheduler.ReduceLROnPlateau):
                scheduler_G.step(total_loss)
                scheduler_D.step(total_loss)
            else:
                scheduler_G.step()
                scheduler_D.step()

            # Save models checkpoints
            is_best = total_loss < best_loss
            best_loss = min(best_loss, total_loss)
            utils.save_checkpoint(exp_dir, 'G', {
                'resolution': res,
                'epoch': epoch + 1,
                'state_dict': G.module.state_dict() if gpus and len(gpus) > 1 else G.state_dict(),
                'optimizer': optimizer_G.state_dict(),
                'best_loss_%d' % res: best_loss,
                'arch': G_arch,
            }, is_best)
            utils.save_checkpoint(exp_dir, 'D', {
                'resolution': res,
                'epoch': epoch + 1,
                'state_dict': D.module.state_dict() if gpus and len(gpus) > 1 else D.state_dict(),
                'optimizer': optimizer_D.state_dict(),
                'best_loss_%d' % res: best_loss,
                'arch': D_arch,
            }, is_best)

        # Reset start epoch to 0 because it should only affect the first training resolution
        start_epoch = 0
        best_loss = 1e6
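The inner training step follows the usual alternating GAN update: the generated images are detached for the discriminator pass so that only the discriminator receives gradients there, the discriminator loss averages its fake and real terms, and the generator combines a reconstruction loss with the adversarial term, using the rec_weight / gan_weight defaults from the signature above (1.0 and 0.001). Here is a stripped-down sketch of that pattern with toy convolutional modules and a plain MSE (LSGAN-style) criterion standing in for GANLoss and the multiscale discriminator.

# Stripped-down sketch of the alternating G/D update (toy modules; MSE stands in for the LSGAN criterion).
import torch
import torch.nn as nn
import torch.optim as optim

G = nn.Sequential(nn.Conv2d(3, 3, 3, padding=1))     # placeholder generator
D = nn.Sequential(nn.Conv2d(3, 1, 3, padding=1))     # placeholder discriminator
criterion_gan = nn.MSELoss()                         # LSGAN-style criterion
criterion_rec = nn.L1Loss()
optimizer_G = optim.Adam(G.parameters(), lr=1e-4, betas=(0.5, 0.999))
optimizer_D = optim.Adam(D.parameters(), lr=1e-4, betas=(0.5, 0.999))
rec_weight, gan_weight = 1.0, 0.001

src, tgt = torch.randn(4, 3, 64, 64), torch.randn(4, 3, 64, 64)
pred = G(src)

# Discriminator terms: fakes are detached so no generator gradients flow through this pass
d_fake = D(pred.detach())
d_real = D(tgt)
loss_D = 0.5 * (criterion_gan(d_fake, torch.zeros_like(d_fake)) +
                criterion_gan(d_real, torch.ones_like(d_real)))

# Generator terms: reconstruct the target and fool the discriminator
g_fake = D(pred)
loss_G = rec_weight * criterion_rec(pred, tgt) + gan_weight * criterion_gan(g_fake, torch.ones_like(g_fake))

# Update the generator first, then the discriminator (mirroring the order used above)
optimizer_G.zero_grad()
loss_G.backward()
optimizer_G.step()
optimizer_D.zero_grad()
loss_D.backward()
optimizer_D.step()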
Example #5
def main(dataset='fsgan.datasets.seq_dataset.SeqDataset', np_transforms=None,
         tensor_transforms=('img_lms_pose_transforms.ToTensor()', 'img_lms_pose_transforms.Normalize()'),
         workers=4, batch_size=4):
    import time
    import fsgan
    from fsgan.utils.obj_factory import obj_factory
    from fsgan.utils.img_utils import tensor2bgr

    np_transforms = obj_factory(np_transforms) if np_transforms is not None else []
    tensor_transforms = obj_factory(tensor_transforms) if tensor_transforms is not None else []
    img_transforms = img_lms_pose_transforms.Compose(np_transforms + tensor_transforms)
    dataset = obj_factory(dataset, transform=img_transforms)
    # dataset = VideoSeqDataset(root_path, img_list_path, transform=img_transforms, frame_window=frame_window)
    dataloader = data.DataLoader(dataset, batch_size=batch_size, num_workers=workers, pin_memory=True, drop_last=True,
                                 shuffle=True)

    start = time.time()
    if isinstance(dataset, fsgan.datasets.seq_dataset.SeqPairDataset):
        for frame, landmarks, pose, target in dataloader:
            pass
    elif isinstance(dataset, fsgan.datasets.seq_dataset.SeqDataset):
        for frame, landmarks, pose in dataloader:
            # For each batch
            for b in range(frame.shape[0]):
                # Render
                render_img = tensor2bgr(frame[b]).copy()
                curr_landmarks = landmarks[b].numpy() * render_img.shape[0]
                curr_pose = pose[b].numpy() * 99.

                for point in curr_landmarks:
                    # Round to int: OpenCV expects integer point coordinates
                    cv2.circle(render_img, (int(round(point[0])), int(round(point[1]))), 2, (0, 0, 255), -1)
                msg = 'Pose: %.1f, %.1f, %.1f' % (curr_pose[0], curr_pose[1], curr_pose[2])
                cv2.putText(render_img, msg, (10, 20), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 255), 1, cv2.LINE_AA)
                cv2.imshow('render_img', render_img)
                if cv2.waitKey(0) & 0xFF == ord('q'):
                    break


        # print(frame_window.shape)

        # if isinstance(frame_window, (list, tuple)):
        #     # For each batch
        #     for b in range(frame_window[0].shape[0]):
        #         # For each frame window in the list
        #         for p in range(len(frame_window)):
        #             # For each frame in the window
        #             for f in range(frame_window[p].shape[2]):
        #                 print(frame_window[p][b, :, f, :, :].shape)
        #                 # Render
        #                 render_img = tensor2bgr(frame_window[p][b, :, f, :, :]).copy()
        #                 landmarks = landmarks_window[p][b, f, :, :].numpy()
        #                 # for point in np.round(landmarks).astype(int):
        #                 for point in landmarks:
        #                     cv2.circle(render_img, (point[0], point[1]), 2, (0, 0, 255), -1)
        #                 cv2.imshow('render_img', render_img)
        #                 if cv2.waitKey(0) & 0xFF == ord('q'):
        #                     break
        # else:
        #     # For each batch
        #     for b in range(frame_window.shape[0]):
        #         # For each frame in the window
        #         for f in range(frame_window.shape[2]):
        #             print(frame_window[b, :, f, :, :].shape)
        #             # Render
        #             render_img = tensor2bgr(frame_window[b, :, f, :, :]).copy()
        #             landmarks = landmarks_window[b, f, :, :].numpy()
        #             # for point in np.round(landmarks).astype(int):
        #             for point in landmarks:
        #                 cv2.circle(render_img, (point[0], point[1]), 2, (0, 0, 255), -1)
        #             cv2.imshow('render_img', render_img)
        #             if cv2.waitKey(0) & 0xFF == ord('q'):
        #                 break
    end = time.time()
    print('elapsed time: %f[s]' % (end - start))
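The visualization loop scales normalized landmark coordinates back to pixel space and draws them with cv2.circle. A small self-contained variant is sketched below; coordinates are rounded explicitly since recent OpenCV releases reject float point coordinates, and scaling x and y by width and height separately is an assumption (the loop above multiplies both coordinates by render_img.shape[0], which implies square crops).

# Self-contained sketch: draw normalized landmarks on a BGR image (coordinates rounded to int).
import cv2
import numpy as np

def draw_landmarks(image_bgr, landmarks_norm, color=(0, 0, 255)):
    # image_bgr: HxWx3 uint8 image; landmarks_norm: Nx2 array of (x, y) in [0, 1]
    h, w = image_bgr.shape[:2]
    points = np.round(landmarks_norm * np.array([w, h])).astype(int)
    for x, y in points:
        cv2.circle(image_bgr, (int(x), int(y)), 2, color, -1)
    return image_bgr

# Example usage on a blank canvas:
canvas = np.zeros((256, 256, 3), dtype=np.uint8)
draw_landmarks(canvas, np.random.rand(68, 2))
cv2.imwrite('landmarks_debug.jpg', canvas)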