Example #1
def main(dataset='fake_detection.datasets.image_list_dataset.ImageListDataset',
         np_transforms=None,
         tensor_transforms=(
             'img_landmarks_transforms.ToTensor()',
             'transforms.Normalize(mean=[0.5,0.5,0.5],std=[0.5,0.5,0.5])'),
         workers=4,
         batch_size=4):
    import time
    from fsgan.utils.obj_factory import obj_factory
    from fsgan.utils.img_utils import tensor2bgr

    np_transforms = obj_factory(
        np_transforms) if np_transforms is not None else []
    tensor_transforms = obj_factory(
        tensor_transforms) if tensor_transforms is not None else []
    img_transforms = img_landmarks_transforms.Compose(np_transforms +
                                                      tensor_transforms)
    dataset = obj_factory(dataset, transform=img_transforms)
    dataloader = data.DataLoader(dataset,
                                 batch_size=batch_size,
                                 num_workers=workers,
                                 pin_memory=True,
                                 drop_last=True,
                                 shuffle=True)

    start = time.time()
    if isinstance(dataset, ImageListDataset):
        for img, target in dataloader:
            print(img.shape)
            print(target)

            # For each batch
            for b in range(img.shape[0]):
                render_img = tensor2bgr(img[b]).copy()
                cv2.imshow('render_img', render_img)
                if cv2.waitKey(0) & 0xFF == ord('q'):
                    break
    else:
        for img1, img2, target in dataloader:
            print(img1.shape)
            print(img2.shape)
            print(target)

            # For each batch
            for b in range(target.shape[0]):
                left_img = tensor2bgr(img1[b]).copy()
                right_img = tensor2bgr(img2[b]).copy()
                render_img = np.concatenate((left_img, right_img), axis=1)
                cv2.imshow('render_img', render_img)
                if cv2.waitKey(0) & 0xFF == ord('q'):
                    break
    end = time.time()
    print('elapsed time: %f[s]' % (end - start))
Example #2
def main(input, np_transforms=None, tensor_transforms=None, batch_size=4):
    import cv2
    import numpy as np
    from torchvision.transforms import Compose
    from fsgan.utils.obj_factory import obj_factory

    np_transforms = obj_factory(np_transforms) if np_transforms is not None else []
    tensor_transforms = obj_factory(tensor_transforms) if tensor_transforms is not None else []
    img_transforms = Compose(np_transforms + tensor_transforms)

    img = cv2.imread(input)
    pose = np.array([1., 2., 3.])

    x = img_transforms((img, pose))
    return x
Example #3
def main(dataset='fsgan.datasets.image_seg_dataset.ImageSegDataset',
         np_transforms1=None,
         np_transforms2=None,
         tensor_transforms1=(
             'img_landmarks_transforms.ToTensor()',
             'transforms.Normalize(mean=[0.5,0.5,0.5],std=[0.5,0.5,0.5])'),
         tensor_transforms2=('img_landmarks_transforms.ToTensor()', ),
         workers=4,
         batch_size=4):
    import time
    from fsgan.utils.obj_factory import obj_factory
    from fsgan.utils.seg_utils import blend_seg_pred, blend_seg_label
    from fsgan.utils.img_utils import tensor2bgr

    np_transforms1 = obj_factory(
        np_transforms1) if np_transforms1 is not None else []
    tensor_transforms1 = obj_factory(
        tensor_transforms1) if tensor_transforms1 is not None else []
    img_transforms1 = img_landmarks_transforms.Compose(np_transforms1 +
                                                       tensor_transforms1)
    np_transforms2 = obj_factory(
        np_transforms2) if np_transforms2 is not None else []
    tensor_transforms2 = obj_factory(
        tensor_transforms2) if tensor_transforms2 is not None else []
    img_transforms2 = img_landmarks_transforms.Compose(np_transforms2 +
                                                       tensor_transforms2)
    dataset = obj_factory(dataset,
                          transform=img_transforms1,
                          target_transform=img_transforms2)
    dataloader = data.DataLoader(dataset,
                                 batch_size=batch_size,
                                 num_workers=workers,
                                 pin_memory=True,
                                 drop_last=True,
                                 shuffle=True)

    start = time.time()
    for img, seg in dataloader:
        # For each batch
        for b in range(img.shape[0]):
            blend_tensor = blend_seg_pred(img, seg)
            render_img = tensor2bgr(blend_tensor[b])
            # render_img = tensor2bgr(img[b])
            cv2.imshow('render_img', render_img)
            if cv2.waitKey(0) & 0xFF == ord('q'):
                break
    end = time.time()
    print('elapsed time: %f[s]' % (end - start))
Example #4
def main(model='res_unet.ResUNet', res=(256, )):
    import torch
    from fsgan.utils.obj_factory import obj_factory
    model = obj_factory(model)
    if len(res) == 1:
        img = torch.rand(1, model.in_nc, res[0], res[0])
        pred = model(img)
        print(pred.shape)
    else:
        img = []
        for i in range(1, len(res) + 1):
            img.append(torch.rand(1, model.in_nc, res[-i], res[-i]))
        pred = model(img)
        print(pred.shape)
Example #5
File: utils.py Project: lilac/fsgan
def load_model(model_path,
               name='',
               device=None,
               arch=None,
               return_checkpoint=False,
               train=False):
    """ Load a model from checkpoint.

    This is a utility function that combines the model weights and architecture (string representation) to easily
    load any model without explicit knowledge of its class.

    Args:
        model_path (str): Path to the model's checkpoint (.pth)
        name (str): The name of the model (for printing and error management)
        device (torch.device): The device to load the model to
        arch (str): The model's architecture (string representation)
        return_checkpoint (bool): If True, the checkpoint will be returned as well
        train (bool): If True, the model will be set to train mode, else it will be set to test mode

    Returns:
        (nn.Module, dict (optional)): A tuple that contains:
            - model (nn.Module): The loaded model
            - checkpoint (dict, optional): The model's checkpoint (only if return_checkpoint is True)
    """
    assert model_path is not None, '%s model must be specified!' % name
    assert os.path.exists(
        model_path), 'Couldn\'t find %s model in path: %s' % (name, model_path)
    print('=> Loading %s model: "%s"...' %
          (name, os.path.basename(model_path)))
    checkpoint = torch.load(model_path, map_location=torch.device('cpu'))
    assert arch is not None or 'arch' in checkpoint, 'Couldn\'t determine %s model architecture!' % name
    arch = checkpoint['arch'] if arch is None else arch
    model = obj_factory(arch)
    if device is not None:
        model.to(device)
    model.load_state_dict(checkpoint['state_dict'])
    model.train(train)

    if return_checkpoint:
        return model, checkpoint
    else:
        return model
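A minimal usage sketch of load_model follows (the import path is inferred from the file reference above; the checkpoint path and model name are hypothetical placeholders):

import torch
# Import path assumed from the "File: utils.py Project: lilac/fsgan" reference above
from fsgan.utils.utils import load_model

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Hypothetical checkpoint path; the .pth file is expected to contain 'arch' and 'state_dict' entries
model, checkpoint = load_model('weights/face_segmentation.pth',
                               name='face segmentation',
                               device=device,
                               return_checkpoint=True)
print(checkpoint['arch'])  # architecture string that obj_factory uses to rebuild the model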
Example #6
    def __init__(
        self,
        resolution=d('resolution'),
        crop_scale=d('crop_scale'),
        gpus=d('gpus'),
        cpu_only=d('cpu_only'),
        display=d('display'),
        verbose=d('verbose'),
        encoder_codec=d('encoder_codec'),
        # Detection arguments:
        detection_model=d('detection_model'),
        det_batch_size=d('det_batch_size'),
        det_postfix=d('det_postfix'),
        # Sequence arguments:
        iou_thresh=d('iou_thresh'),
        min_length=d('min_length'),
        min_size=d('min_size'),
        center_kernel=d('center_kernel'),
        size_kernel=d('size_kernel'),
        smooth_det=d('smooth_det'),
        seq_postfix=d('seq_postfix'),
        write_empty=d('write_empty'),
        # Pose arguments:
        pose_model=d('pose_model'),
        pose_batch_size=d('pose_batch_size'),
        pose_postfix=d('pose_postfix'),
        cache_pose=d('cache_pose'),
        cache_frontal=d('cache_frontal'),
        smooth_poses=d('smooth_poses'),
        # Landmarks arguments:
        lms_model=d('lms_model'),
        lms_batch_size=d('lms_batch_size'),
        landmarks_postfix=d('landmarks_postfix'),
        cache_landmarks=d('cache_landmarks'),
        smooth_landmarks=d('smooth_landmarks'),
        # Segmentation arguments:
        seg_model=d('seg_model'),
        smooth_segmentation=d('smooth_segmentation'),
        segmentation_postfix=d('segmentation_postfix'),
        cache_segmentation=d('cache_segmentation'),
        seg_batch_size=d('seg_batch_size'),
        seg_remove_mouth=d('seg_remove_mouth'),
        # Finetune arguments:
        finetune=d('finetune'),
        finetune_iterations=d('finetune_iterations'),
        finetune_lr=d('finetune_lr'),
        finetune_batch_size=d('finetune_batch_size'),
        finetune_workers=d('finetune_workers'),
        finetune_save=d('finetune_save'),
        # Swapping arguments:
        batch_size=d('batch_size'),
        reenactment_model=d('reenactment_model'),
        completion_model=d('completion_model'),
        blending_model=d('blending_model'),
        criterion_id=d('criterion_id'),
        min_radius=d('min_radius'),
        output_crop=d('output_crop'),
        renderer_process=d('renderer_process')):
        super(FaceSwapping,
              self).__init__(resolution,
                             crop_scale,
                             gpus,
                             cpu_only,
                             display,
                             verbose,
                             encoder_codec,
                             detection_model=detection_model,
                             det_batch_size=det_batch_size,
                             det_postfix=det_postfix,
                             iou_thresh=iou_thresh,
                             min_length=min_length,
                             min_size=min_size,
                             center_kernel=center_kernel,
                             size_kernel=size_kernel,
                             smooth_det=smooth_det,
                             seq_postfix=seq_postfix,
                             write_empty=write_empty,
                             pose_model=pose_model,
                             pose_batch_size=pose_batch_size,
                             pose_postfix=pose_postfix,
                             cache_pose=True,
                             cache_frontal=cache_frontal,
                             smooth_poses=smooth_poses,
                             lms_model=lms_model,
                             lms_batch_size=lms_batch_size,
                             landmarks_postfix=landmarks_postfix,
                             cache_landmarks=True,
                             smooth_landmarks=smooth_landmarks,
                             seg_model=seg_model,
                             seg_batch_size=seg_batch_size,
                             segmentation_postfix=segmentation_postfix,
                             cache_segmentation=True,
                             smooth_segmentation=smooth_segmentation,
                             seg_remove_mouth=seg_remove_mouth)
        self.batch_size = batch_size
        self.min_radius = min_radius
        self.output_crop = output_crop
        self.finetune_enabled = finetune
        self.finetune_iterations = finetune_iterations
        self.finetune_lr = finetune_lr
        self.finetune_batch_size = finetune_batch_size
        self.finetune_workers = finetune_workers
        self.finetune_save = finetune_save

        # Load reenactment model
        self.Gr, checkpoint = load_model(reenactment_model,
                                         'face reenactment',
                                         self.device,
                                         return_checkpoint=True)
        self.Gr.arch = checkpoint['arch']
        self.reenactment_state_dict = checkpoint['state_dict']

        # Load all other models
        self.Gc = load_model(completion_model, 'face completion', self.device)
        self.Gb = load_model(blending_model, 'face blending', self.device)

        # Initialize landmarks decoders
        self.landmarks_decoders = []
        for res in (128, 256):
            self.landmarks_decoders.insert(
                0,
                LandmarksHeatMapDecoder(res).to(self.device))

        # Initialize losses
        self.criterion_pixelwise = nn.L1Loss().to(self.device)
        self.criterion_id = obj_factory(criterion_id).to(self.device)

        # Support multiple GPUs
        if self.gpus and len(self.gpus) > 1:
            self.Gr = nn.DataParallel(self.Gr, self.gpus)
            self.Gc = nn.DataParallel(self.Gc, self.gpus)
            self.Gb = nn.DataParallel(self.Gb, self.gpus)
            self.criterion_id.vgg = nn.DataParallel(self.criterion_id.vgg,
                                                    self.gpus)

        # Initialize soft erosion
        self.smooth_mask = SoftErosion(kernel_size=21,
                                       threshold=0.6).to(self.device)

        # Initialize video writer
        self.video_renderer = FaceSwappingRenderer(
            self.display, self.verbose, self.output_crop, self.resolution,
            self.crop_scale, encoder_codec, renderer_process)
        self.video_renderer.start()
Example #7
def main(
        # General arguments
        exp_dir,
        resume_dir=None,
        start_epoch=None,
        epochs=(90, ),
        iterations=None,
        resolutions=(128, 256),
        lr_gen=(1e-4, ),
        lr_dis=(1e-4, ),
        gpus=None,
        workers=4,
        batch_size=(64, ),
        seed=None,
        log_freq=20,

        # Data arguments
        train_dataset='opencv_video_seq_dataset.VideoSeqDataset',
        val_dataset=None,
        numpy_transforms=None,
        tensor_transforms=(
            'img_landmarks_transforms.ToTensor()',
            'transforms.Normalize(mean=[0.5,0.5,0.5],std=[0.5,0.5,0.5])'),

        # Training arguments
        optimizer='optim.SGD(momentum=0.9,weight_decay=1e-4)',
        scheduler='lr_scheduler.StepLR(step_size=30,gamma=0.1)',
        pretrained=False,
        criterion_pixelwise='nn.L1Loss',
        criterion_id='vgg_loss.VGGLoss',
        criterion_attr='vgg_loss.VGGLoss',
        criterion_gan='gan_loss.GANLoss(use_lsgan=True)',
        generator='res_unet.MultiScaleResUNet(in_nc=4,out_nc=3)',
        discriminator='discriminators_pix2pix.MultiscaleDiscriminator',
        reenactment_model=None,
        seg_model=None,
        lms_model=None,
        pix_weight=0.1,
        rec_weight=1.0,
        gan_weight=0.001,
        background_value=-1.0):
    def proces_epoch(dataset_loader, train=True):
        stage = 'TRAINING' if train else 'VALIDATION'
        total_iter = len(dataset_loader) * dataset_loader.batch_size * epoch
        pbar = tqdm(dataset_loader, unit='batches')

        # Set networks training mode
        Gc.train(train)
        D.train(train)
        Gr.train(False)
        S.train(False)
        L.train(False)

        # Reset logger
        logger.reset(prefix='{} {}X{}: Epoch: {} / {}; LR: {:.0e}; '.format(
            stage, res, res, epoch + 1, res_epochs,
            scheduler_G.get_lr()[0]))

        # For each batch in the training data
        for i, (img, target) in enumerate(pbar):
            # Prepare input
            with torch.no_grad():
                # For each view images
                for j in range(len(img)):
                    # For each pyramid image: push to device
                    for p in range(len(img[j])):
                        img[j][p] = img[j][p].to(device)

                # Compute context
                context = L(img[1][0].sub(context_mean).div(context_std))
                context = landmarks_utils.filter_landmarks(context)

                # Normalize each of the pyramid images
                for j in range(len(img)):
                    for p in range(len(img[j])):
                        img[j][p].sub_(img_mean).div_(img_std)

                # # Compute segmentation
                # seg = []
                # for j in range(len(img)):
                #     curr_seg = S(img[j][0])
                #     if curr_seg.shape[2:] != (res, res):
                #         curr_seg = F.interpolate(curr_seg, (res, res), mode='bicubic', align_corners=False)
                #     seg.append(curr_seg)

                # Compute segmentation
                target_seg = S(img[1][0])
                if target_seg.shape[2:] != (res, res):
                    target_seg = F.interpolate(target_seg, (res, res),
                                               mode='bicubic',
                                               align_corners=False)

                # Concatenate pyramid images with context to derive the final input
                input = []
                for p in range(len(img[0]) - 1, -1, -1):
                    context = F.interpolate(context,
                                            size=img[0][p].shape[2:],
                                            mode='bicubic',
                                            align_corners=False)
                    input.insert(0, torch.cat((img[0][p], context), dim=1))

                # Reenactment
                reenactment_img = Gr(input)
                reenactment_seg = S(reenactment_img)
                if reenactment_img.shape[2:] != (res, res):
                    reenactment_img = F.interpolate(reenactment_img,
                                                    (res, res),
                                                    mode='bilinear',
                                                    align_corners=False)
                    reenactment_seg = F.interpolate(reenactment_seg,
                                                    (res, res),
                                                    mode='bilinear',
                                                    align_corners=False)

                # Remove unnecessary pyramids
                for j in range(len(img)):
                    img[j] = img[j][-ri - 1:]

                # Source face
                reenactment_face_mask = reenactment_seg.argmax(1) == 1
                inpainting_mask = seg_utils.random_hair_inpainting_mask_tensor(
                    reenactment_face_mask).to(device)
                reenactment_face_mask = reenactment_face_mask * (
                    inpainting_mask == 0)
                reenactment_img_with_hole = reenactment_img.masked_fill(
                    ~reenactment_face_mask.unsqueeze(1), background_value)

                # Target face
                target_face_mask = (target_seg.argmax(1) == 1).unsqueeze(1)
                inpainting_target = img[1][0]
                inpainting_target.masked_fill_(~target_face_mask,
                                               background_value)

                # Inpainting input
                inpainting_input = torch.cat(
                    (reenactment_img_with_hole, target_face_mask.float()),
                    dim=1)
                inpainting_input_pyd = img_utils.create_pyramid(
                    inpainting_input, len(img[0]))

            # Face inpainting
            inpainting_pred = Gc(inpainting_input_pyd)

            # Fake Detection and Loss
            inpainting_pred_pyd = img_utils.create_pyramid(
                inpainting_pred, len(img[0]))
            pred_fake_pool = D([x.detach() for x in inpainting_pred_pyd])
            loss_D_fake = criterion_gan(pred_fake_pool, False)

            # Real Detection and Loss
            inpainting_target_pyd = img_utils.create_pyramid(
                inpainting_target, len(img[0]))
            pred_real = D(inpainting_target_pyd)
            loss_D_real = criterion_gan(pred_real, True)

            loss_D_total = (loss_D_fake + loss_D_real) * 0.5

            # GAN loss (Fake Passability Loss)
            pred_fake = D(inpainting_pred_pyd)
            loss_G_GAN = criterion_gan(pred_fake, True)

            # Reconstruction
            loss_pixelwise = criterion_pixelwise(inpainting_pred,
                                                 inpainting_target)
            loss_id = criterion_id(inpainting_pred, inpainting_target)
            loss_attr = criterion_attr(inpainting_pred, inpainting_target)
            loss_rec = pix_weight * loss_pixelwise + 0.5 * loss_id + 0.5 * loss_attr

            loss_G_total = rec_weight * loss_rec + gan_weight * loss_G_GAN

            if train:
                # Update generator weights
                optimizer_G.zero_grad()
                loss_G_total.backward()
                optimizer_G.step()

                # Update discriminator weights
                optimizer_D.zero_grad()
                loss_D_total.backward()
                optimizer_D.step()

            logger.update('losses',
                          pixelwise=loss_pixelwise,
                          id=loss_id,
                          attr=loss_attr,
                          rec=loss_rec,
                          g_gan=loss_G_GAN,
                          d_gan=loss_D_total)
            total_iter += dataset_loader.batch_size

            # Batch logs
            pbar.set_description(str(logger))
            if train and i % log_freq == 0:
                logger.log_scalars_val('%dx%d/batch' % (res, res), total_iter)

        # Epoch logs
        logger.log_scalars_avg(
            '%dx%d/epoch/%s' % (res, res, 'train' if train else 'val'), epoch)
        if not train:
            # Log images
            grid = img_utils.make_grid(img[0][0], reenactment_img,
                                       reenactment_img_with_hole,
                                       inpainting_pred, inpainting_target)
            logger.log_image('%dx%d/vis' % (res, res), grid, epoch)

        return logger.log_dict['losses']['rec'].avg

    #################
    # Main pipeline #
    #################

    # Validation
    resolutions = resolutions if isinstance(resolutions,
                                            (list, tuple)) else [resolutions]
    lr_gen = lr_gen if isinstance(lr_gen, (list, tuple)) else [lr_gen]
    lr_dis = lr_dis if isinstance(lr_dis, (list, tuple)) else [lr_dis]
    epochs = epochs if isinstance(epochs, (list, tuple)) else [epochs]
    batch_size = batch_size if isinstance(batch_size,
                                          (list, tuple)) else [batch_size]
    iterations = iterations if iterations is None or isinstance(
        iterations, (list, tuple)) else [iterations]

    lr_gen = lr_gen * len(resolutions) if len(lr_gen) == 1 else lr_gen
    lr_dis = lr_dis * len(resolutions) if len(lr_dis) == 1 else lr_dis
    epochs = epochs * len(resolutions) if len(epochs) == 1 else epochs
    batch_size = batch_size * len(resolutions) if len(
        batch_size) == 1 else batch_size
    if iterations is not None:
        iterations = iterations * len(resolutions) if len(
            iterations) == 1 else iterations
        iterations = utils.str2int(iterations)

    if not os.path.isdir(exp_dir):
        raise RuntimeError('Experiment directory was not found: \'' + exp_dir +
                           '\'')
    assert len(lr_gen) == len(resolutions)
    assert len(lr_dis) == len(resolutions)
    assert len(epochs) == len(resolutions)
    assert len(batch_size) == len(resolutions)
    assert iterations is None or len(iterations) == len(resolutions)

    # Seed
    utils.set_seed(seed)

    # Check CUDA device availability
    device, gpus = utils.set_device(gpus)

    # Initialize loggers
    logger = TensorBoardLogger(log_dir=exp_dir)

    # Initialize datasets
    numpy_transforms = obj_factory(
        numpy_transforms) if numpy_transforms is not None else []
    tensor_transforms = obj_factory(
        tensor_transforms) if tensor_transforms is not None else []
    img_transforms = img_landmarks_transforms.Compose(numpy_transforms +
                                                      tensor_transforms)

    train_dataset = obj_factory(train_dataset, transform=img_transforms)
    if val_dataset is not None:
        val_dataset = obj_factory(val_dataset, transform=img_transforms)

    # Create networks
    Gc = obj_factory(generator).to(device)
    D = obj_factory(discriminator).to(device)

    # Resume from a checkpoint or initialize the networks weights randomly
    checkpoint_dir = exp_dir if resume_dir is None else resume_dir
    Gc_path = os.path.join(checkpoint_dir, 'Gc_latest.pth')
    D_path = os.path.join(checkpoint_dir, 'D_latest.pth')
    best_loss = 1000000.
    curr_res = resolutions[0]
    optimizer_G_state, optimizer_D_state = None, None
    if os.path.isfile(Gc_path) and os.path.isfile(D_path):
        print("=> loading checkpoint from '{}'".format(checkpoint_dir))
        # Gc
        checkpoint = torch.load(Gc_path)
        if 'resolution' in checkpoint:
            curr_res = checkpoint['resolution']
            start_epoch = checkpoint[
                'epoch'] if start_epoch is None else start_epoch
        else:
            curr_res = resolutions[1] if len(
                resolutions) > 1 else resolutions[0]
        best_loss = checkpoint['best_loss']
        Gc.apply(utils.init_weights)
        Gc.load_state_dict(checkpoint['state_dict'], strict=False)
        optimizer_G_state = checkpoint['optimizer']

        # D
        D.apply(utils.init_weights)
        if os.path.isfile(D_path):
            checkpoint = torch.load(D_path)
            D.load_state_dict(checkpoint['state_dict'], strict=False)
            optimizer_D_state = checkpoint['optimizer']
    else:
        print("=> no checkpoint found at '{}'".format(checkpoint_dir))
        if not pretrained:
            print("=> randomly initializing networks...")
            Gc.apply(utils.init_weights)
            D.apply(utils.init_weights)

    # Load reenactment model
    if reenactment_model is None:
        raise RuntimeError('Reenactment model must be specified!')
    if not os.path.exists(reenactment_model):
        raise RuntimeError('Couldn\'t find reenactment model in path: ' +
                           reenactment_model)
    print('=> Loading face reenactment model: "' +
          os.path.basename(reenactment_model) + '"...')
    checkpoint = torch.load(reenactment_model)
    Gr = obj_factory(checkpoint['arch']).to(device)
    Gr.load_state_dict(checkpoint['state_dict'])

    # Load segmentation model
    if seg_model is None:
        raise RuntimeError('Segmentation model must be specified!')
    if not os.path.exists(seg_model):
        raise RuntimeError('Couldn\'t find segmentation model in path: ' +
                           seg_model)
    print('=> Loading face segmentation model: "' +
          os.path.basename(seg_model) + '"...')
    checkpoint = torch.load(seg_model)
    S = obj_factory(checkpoint['arch']).to(device)
    S.load_state_dict(checkpoint['state_dict'])

    # Load face landmarks model
    print('=> Loading face landmarks model: "' + os.path.basename(lms_model) +
          '"...')
    assert os.path.isfile(
        lms_model), 'The model path "%s" does not exist' % lms_model
    L = hrnet_wlfw().to(device)
    state_dict = torch.load(lms_model)
    L.load_state_dict(state_dict)

    # Initialize normalization tensors
    # Note: this is necessary because of the landmarks model
    img_mean = torch.as_tensor([0.5, 0.5, 0.5], device=device).view(1, 3, 1, 1)
    img_std = torch.as_tensor([0.5, 0.5, 0.5], device=device).view(1, 3, 1, 1)
    context_mean = torch.as_tensor([0.485, 0.456, 0.406],
                                   device=device).view(1, 3, 1, 1)
    context_std = torch.as_tensor([0.229, 0.224, 0.225],
                                  device=device).view(1, 3, 1, 1)

    # Losses
    criterion_pixelwise = obj_factory(criterion_pixelwise).to(device)
    criterion_id = obj_factory(criterion_id).to(device)
    criterion_attr = obj_factory(criterion_attr).to(device)
    criterion_gan = obj_factory(criterion_gan).to(device)

    # Support multiple GPUs
    if gpus and len(gpus) > 1:
        Gc = nn.DataParallel(Gc, gpus)
        Gr = nn.DataParallel(Gr, gpus)
        D = nn.DataParallel(D, gpus)
        S = nn.DataParallel(S, gpus)
        L = nn.DataParallel(L, gpus)
        criterion_id.vgg = nn.DataParallel(criterion_id.vgg, gpus)
        criterion_attr.vgg = nn.DataParallel(criterion_attr.vgg, gpus)

    # For each resolution
    start_res_ind = int(np.log2(curr_res)) - int(np.log2(resolutions[0]))
    start_epoch = 0 if start_epoch is None else start_epoch
    for ri in range(start_res_ind, len(resolutions)):
        res = resolutions[ri]
        res_lr_gen = lr_gen[ri]
        res_lr_dis = lr_dis[ri]
        res_epochs = epochs[ri]
        res_iterations = iterations[ri] if iterations is not None else None
        res_batch_size = batch_size[ri]

        # Optimizer and scheduler
        optimizer_G = obj_factory(optimizer, Gc.parameters(), lr=res_lr_gen)
        optimizer_D = obj_factory(optimizer, D.parameters(), lr=res_lr_dis)
        scheduler_G = obj_factory(scheduler, optimizer_G)
        scheduler_D = obj_factory(scheduler, optimizer_D)
        if optimizer_G_state is not None:
            optimizer_G.load_state_dict(optimizer_G_state)
            optimizer_G_state = None
        if optimizer_D_state is not None:
            optimizer_D.load_state_dict(optimizer_D_state)
            optimizer_D_state = None

        # Initialize data loaders
        if res_iterations is None:
            train_sampler = tutils.data.sampler.WeightedRandomSampler(
                train_dataset.weights, len(train_dataset))
        else:
            train_sampler = tutils.data.sampler.WeightedRandomSampler(
                train_dataset.weights, res_iterations)
        train_loader = tutils.data.DataLoader(train_dataset,
                                              batch_size=res_batch_size,
                                              sampler=train_sampler,
                                              num_workers=workers,
                                              pin_memory=True,
                                              drop_last=True,
                                              shuffle=False)
        if val_dataset is not None:
            if res_iterations is None:
                val_sampler = tutils.data.sampler.WeightedRandomSampler(
                    val_dataset.weights, len(val_dataset))
            else:
                val_iterations = (res_iterations * len(
                    val_dataset.classes)) // len(train_dataset.classes)
                val_sampler = tutils.data.sampler.WeightedRandomSampler(
                    val_dataset.weights, val_iterations)
            val_loader = tutils.data.DataLoader(val_dataset,
                                                batch_size=res_batch_size,
                                                sampler=val_sampler,
                                                num_workers=workers,
                                                pin_memory=True,
                                                drop_last=True,
                                                shuffle=False)
        else:
            val_loader = None

        # For each epoch
        for epoch in range(start_epoch, res_epochs):
            total_loss = proces_epoch(train_loader, train=True)
            if val_loader is not None:
                with torch.no_grad():
                    total_loss = proces_epoch(val_loader, train=False)

            # Schedulers step (in PyTorch 1.1.0+ it must follow after the epoch training and validation steps)
            if isinstance(scheduler_G,
                          torch.optim.lr_scheduler.ReduceLROnPlateau):
                scheduler_G.step(total_loss)
                scheduler_D.step(total_loss)
            else:
                scheduler_G.step()
                scheduler_D.step()

            # Save models checkpoints
            is_best = total_loss < best_loss
            best_loss = min(best_loss, total_loss)
            utils.save_checkpoint(
                exp_dir, 'Gc', {
                    'resolution':
                    res,
                    'epoch':
                    epoch + 1,
                    'state_dict':
                    Gc.module.state_dict()
                    if gpus and len(gpus) > 1 else Gc.state_dict(),
                    'optimizer':
                    optimizer_G.state_dict(),
                    'best_loss':
                    best_loss,
                }, is_best)
            utils.save_checkpoint(
                exp_dir, 'D', {
                    'resolution':
                    res,
                    'epoch':
                    epoch + 1,
                    'state_dict':
                    D.module.state_dict()
                    if gpus and len(gpus) > 1 else D.state_dict(),
                    'optimizer':
                    optimizer_D.state_dict(),
                    'best_loss':
                    best_loss,
                }, is_best)

        # Reset start epoch to 0 because it should only affect the first training resolution
        start_epoch = 0
Example #8
def main(
        source_path,
        target_path,
        arch='res_unet_split.MultiScaleResUNet(in_nc=71,out_nc=(3,3),flat_layers=(2,0,2,3),ngf=128)',
        reenactment_model_path='../weights/ijbc_msrunet_256_2_0_reenactment_v1.pth',
        seg_model_path='../weights/lfw_figaro_unet_256_2_0_segmentation_v1.pth',
        inpainting_model_path='../weights/ijbc_msrunet_256_2_0_inpainting_v1.pth',
        blend_model_path='../weights/ijbc_msrunet_256_2_0_blending_v1.pth',
        pose_model_path='../weights/hopenet_robust_alpha1.pth',
        pil_transforms1=('landmark_transforms.FaceAlignCrop',
                         'landmark_transforms.Resize(256)',
                         'landmark_transforms.Pyramids(2)'),
        pil_transforms2=('landmark_transforms.FaceAlignCrop',
                         'landmark_transforms.Resize(256)',
                         'landmark_transforms.Pyramids(2)',
                         'landmark_transforms.LandmarksToHeatmaps'),
        tensor_transforms1=(
            'landmark_transforms.ToTensor()',
            'transforms.Normalize(mean=[0.5,0.5,0.5],std=[0.5,0.5,0.5])'),
        tensor_transforms2=(
            'landmark_transforms.ToTensor()',
            'transforms.Normalize(mean=[0.5,0.5,0.5],std=[0.5,0.5,0.5])'),
        output_path=None,
        min_radius=2.0,
        crop_size=256,
        reverse_output=False,
        verbose=0,
        output_crop=False,
        display=False):
    torch.set_grad_enabled(False)

    fa = face_alignment.FaceAlignment(face_alignment.LandmarksType._2D,
                                      flip_input=False)
    device, gpus = utils.set_device()

    Gr = obj_factory(arch).to(device)
    checkpoint = torch.load(reenactment_model_path)
    Gr.load_state_dict(checkpoint['state_dict'])
    Gr.train(False)

    if seg_model_path is not None:
        print('Loading face segmentation model: "' +
              os.path.basename(seg_model_path) + '"...')
        if seg_model_path.endswith('.pth'):
            checkpoint = torch.load(seg_model_path)
            Gs = obj_factory(checkpoint['arch']).to(device)
            Gs.load_state_dict(checkpoint['state_dict'])
        else:
            Gs = torch.jit.load(seg_model_path, map_location=device)
        if Gs is None:
            raise RuntimeError('Failed to load face segmentation model!')
        Gs.eval()
    else:
        Gs = None

    if inpainting_model_path is not None:
        print('Loading face inpainting model: "' +
              os.path.basename(inpainting_model_path) + '"...')
        if inpainting_model_path.endswith('.pth'):
            checkpoint = torch.load(inpainting_model_path)
            Gi = obj_factory(checkpoint['arch']).to(device)
            Gi.load_state_dict(checkpoint['state_dict'])
        else:
            Gi = torch.jit.load(inpainting_model_path, map_location=device)
        if Gi is None:
            raise RuntimeError('Failed to load face inpainting model!')
        Gi.eval()
    else:
        Gi = None

    checkpoint = torch.load(blend_model_path)
    Gb = obj_factory(checkpoint['arch']).to(device)
    Gb.load_state_dict(checkpoint['state_dict'])
    Gb.train(False)

    Gp = Hopenet().to(device)
    checkpoint = torch.load(pose_model_path)
    Gp.load_state_dict(checkpoint['state_dict'])
    Gp.train(False)

    pil_transforms1 = obj_factory(
        pil_transforms1) if pil_transforms1 is not None else []
    pil_transforms2 = obj_factory(
        pil_transforms2) if pil_transforms2 is not None else []
    tensor_transforms1 = obj_factory(
        tensor_transforms1) if tensor_transforms1 is not None else []
    tensor_transforms2 = obj_factory(
        tensor_transforms2) if tensor_transforms2 is not None else []
    img_transforms1 = landmark_transforms.ComposePyramids(pil_transforms1 +
                                                          tensor_transforms1)
    img_transforms2 = landmark_transforms.ComposePyramids(pil_transforms2 +
                                                          tensor_transforms2)

    source_frame_indices, source_landmarks, source_bboxes, source_eulers = \
        extract_landmarks_bboxes_euler_from_images(source_path, Gp, fa, device=device)
    if source_frame_indices.size == 0:
        raise RuntimeError(
            'No faces were detected in the source image directory: ' +
            source_path)

    target_frame_indices, target_landmarks, target_bboxes, target_eulers = \
        extract_landmarks_bboxes_euler_from_images(target_path, Gp, fa, device=device)
    if target_frame_indices.size == 0:
        raise RuntimeError(
            'No faces were detected in the target image directory: ' +
            target_path)

    source_img_paths = glob(os.path.join(source_path, '*.jpg'))
    target_img_paths = glob(os.path.join(target_path, '*.jpg'))

    source_valid_frame_ind = 0
    for k, source_img_path in tqdm(enumerate(source_img_paths),
                                   unit='images',
                                   total=len(source_img_paths)):
        if k not in source_frame_indices:
            continue
        source_img_bgr = cv2.imread(source_img_path)
        if source_img_bgr is None:
            continue
        source_img_rgb = source_img_bgr[:, :, ::-1]
        curr_source_tensor, curr_source_landmarks, curr_source_bbox = img_transforms1(
            source_img_rgb, source_landmarks[source_valid_frame_ind],
            source_bboxes[source_valid_frame_ind])
        source_valid_frame_ind += 1

        for j in range(len(curr_source_tensor)):
            curr_source_tensor[j] = curr_source_tensor[j].to(device)

        target_valid_frame_ind = 0
        for i, target_img_path in enumerate(target_img_paths):
            curr_output_name = '_'.join([
                os.path.splitext(os.path.basename(source_img_path))[0],
                os.path.splitext(os.path.basename(target_img_path))[0]
            ]) + '.jpg'
            curr_output_path = os.path.join(output_path, curr_output_name)
            if os.path.isfile(curr_output_path):
                target_valid_frame_ind += 1
                continue
            target_img_bgr = cv2.imread(target_img_path)
            if target_img_bgr is None:
                continue
            if i not in target_frame_indices:
                continue
            target_img_rgb = target_img_bgr[:, :, ::-1]

            curr_target_tensor, curr_target_landmarks, curr_target_bbox = img_transforms2(
                target_img_rgb, target_landmarks[target_valid_frame_ind],
                target_bboxes[target_valid_frame_ind])
            curr_target_euler = target_eulers[target_valid_frame_ind]
            target_valid_frame_ind += 1

            reenactment_input_tensor = []
            for j in range(len(curr_source_tensor)):
                curr_target_landmarks[j] = curr_target_landmarks[j].to(device)
                reenactment_input_tensor.append(
                    torch.cat(
                        (curr_source_tensor[j], curr_target_landmarks[j]),
                        dim=0).unsqueeze(0))
            reenactment_img_tensor, reenactment_seg_tensor = Gr(
                reenactment_input_tensor)

            target_img_tensor = curr_target_tensor[0].unsqueeze(0).to(device)
            target_seg_pred_tensor = Gs(target_img_tensor)
            target_mask_tensor = target_seg_pred_tensor.argmax(1) == 1

            aligned_face_mask_tensor = reenactment_seg_tensor.argmax(1) == 1
            aligned_background_mask_tensor = ~aligned_face_mask_tensor
            aligned_img_no_background_tensor = reenactment_img_tensor.clone()
            aligned_img_no_background_tensor.masked_fill_(
                aligned_background_mask_tensor.unsqueeze(1), -1.0)

            inpainting_input_tensor = torch.cat(
                (aligned_img_no_background_tensor,
                 target_mask_tensor.unsqueeze(1).float()),
                dim=1)
            inpainting_input_tensor_pyd = create_pyramid(
                inpainting_input_tensor, len(curr_target_tensor))
            completion_tensor = Gi(inpainting_input_tensor_pyd)

            transfer_tensor = transfer_mask(completion_tensor,
                                            target_img_tensor,
                                            target_mask_tensor)
            blend_input_tensor = torch.cat(
                (transfer_tensor, target_img_tensor,
                 target_mask_tensor.unsqueeze(1).float()),
                dim=1)
            blend_input_tensor_pyd = create_pyramid(blend_input_tensor,
                                                    len(curr_target_tensor))
            blend_tensor = Gb(blend_input_tensor_pyd)

            blend_img = tensor2bgr(blend_tensor)

            if verbose == 0:
                render_img = blend_img if output_crop else crop2img(
                    target_img_bgr, blend_img, curr_target_bbox[0].numpy())
            elif verbose == 1:
                reenactment_only_tensor = transfer_mask(
                    reenactment_img_tensor, target_img_tensor,
                    aligned_face_mask_tensor & target_mask_tensor)
                reenactment_only_img = tensor2bgr(reenactment_only_tensor)

                completion_only_img = tensor2bgr(transfer_tensor)

                transfer_tensor = transfer_mask(
                    aligned_img_no_background_tensor, target_img_tensor,
                    target_mask_tensor)
                blend_input_tensor = torch.cat(
                    (transfer_tensor, target_img_tensor,
                     target_mask_tensor.unsqueeze(1).float()),
                    dim=1)
                blend_input_tensor_pyd = create_pyramid(
                    blend_input_tensor, len(curr_target_tensor))
                blend_tensor = Gb(blend_input_tensor_pyd)
                blend_only_img = tensor2bgr(blend_tensor)

                render_img = np.concatenate(
                    (reenactment_only_img, completion_only_img, blend_only_img,
                     blend_img),
                    axis=1)
            elif verbose == 2:
                reenactment_img_bgr = tensor2bgr(reenactment_img_tensor)
                reenactment_seg_bgr = tensor2bgr(
                    blend_seg_pred(reenactment_img_tensor,
                                   reenactment_seg_tensor))
                target_seg_bgr = tensor2bgr(
                    blend_seg_pred(target_img_tensor, target_seg_pred_tensor))
                aligned_img_no_background_bgr = tensor2bgr(
                    aligned_img_no_background_tensor)
                completion_bgr = tensor2bgr(completion_tensor)
                transfer_bgr = tensor2bgr(transfer_tensor)
                target_cropped_bgr = tensor2bgr(target_img_tensor)

                pose_axis_bgr = draw_axis(np.zeros_like(target_cropped_bgr),
                                          curr_target_euler[0],
                                          curr_target_euler[1],
                                          curr_target_euler[2])
                render_img1 = np.concatenate(
                    (reenactment_img_bgr, reenactment_seg_bgr, target_seg_bgr),
                    axis=1)
                render_img2 = np.concatenate((aligned_img_no_background_bgr,
                                              completion_bgr, transfer_bgr),
                                             axis=1)
                render_img3 = np.concatenate(
                    (pose_axis_bgr, blend_img, target_cropped_bgr), axis=1)
                render_img = np.concatenate(
                    (render_img1, render_img2, render_img3), axis=0)
            elif verbose == 3:
                source_cropped_bgr = tensor2bgr(
                    curr_source_tensor[0].unsqueeze(0))
                target_cropped_bgr = tensor2bgr(target_img_tensor)
                render_img = np.concatenate(
                    (source_cropped_bgr, target_cropped_bgr, blend_img),
                    axis=1)
            cv2.imwrite(curr_output_path, render_img)
            if display:
                cv2.imshow('render_img', render_img)
                if cv2.waitKey(1) & 0xFF == ord('q'):
                    break
Example #9
def main(
        # General arguments
        exp_dir,
        resume_dir=None,
        start_epoch=None,
        epochs=(90, ),
        iterations=None,
        resolutions=(128, 256),
        learning_rate=(1e-1, ),
        gpus=None,
        workers=4,
        batch_size=(64, ),
        seed=None,
        log_freq=20,

        # Data arguments
        train_dataset='fsgan.datasets.image_seg_dataset.ImageSegDataset',
        val_dataset=None,
        numpy_transforms=None,
        tensor_transforms=(
            'img_landmarks_transforms.ToTensor()',
            'transforms.Normalize(mean=[0.5,0.5,0.5],std=[0.5,0.5,0.5])'),

        # Training arguments
        optimizer='optim.SGD(momentum=0.9,weight_decay=1e-4)',
        scheduler='lr_scheduler.StepLR(step_size=30,gamma=0.1)',
        criterion='nn.CrossEntropyLoss',
        model='fsgan.models.simple_unet.UNet(n_classes=3,feature_scale=1)',
        pretrained=False,
        benchmark='fsgan.train_segmentation.IOUBenchmark(3)'):
    def proces_epoch(dataset_loader, train=True):
        stage = 'TRAINING' if train else 'VALIDATION'
        total_iter = len(dataset_loader) * dataset_loader.batch_size * epoch
        pbar = tqdm(dataset_loader, unit='batches')

        # Set networks training mode
        model.train(train)

        # Reset logger
        logger.reset(prefix='{} {}X{}: Epoch: {} / {}; LR: {:.0e}; '.format(
            stage, res, res, epoch + 1, res_epochs,
            scheduler.get_lr()[0]))

        # For each batch in the training data
        for i, (input, target) in enumerate(pbar):
            # Prepare input
            input = input.to(device)
            target = target.to(device)
            with torch.no_grad():
                target = target.argmax(dim=1)

            # Execute model
            pred = model(input)

            # Calculate loss
            loss_total = criterion(pred, target)

            # Run benchmark
            benchmark_res = benchmark(pred,
                                      target) if benchmark is not None else {}

            if train:
                # Update generator weights
                optimizer.zero_grad()
                loss_total.backward()
                optimizer.step()

            logger.update('losses', total=loss_total)
            logger.update('bench', **benchmark_res)
            total_iter += dataset_loader.batch_size

            # Batch logs
            pbar.set_description(str(logger))
            if train and i % log_freq == 0:
                logger.log_scalars_val('%dx%d/batch' % (res, res), total_iter)

        # Epoch logs
        logger.log_scalars_avg(
            '%dx%d/epoch/%s' % (res, res, 'train' if train else 'val'), epoch)
        if not train:
            # Log images
            seg_pred = blend_seg_pred(input, pred)
            seg_gt = blend_seg_label(input, target)
            grid = img_utils.make_grid(input, seg_pred, seg_gt)
            logger.log_image('%dx%d/vis' % (res, res), grid, epoch)

        return logger.log_dict['losses']['total'].avg

    #################
    # Main pipeline #
    #################

    # Validation
    resolutions = resolutions if isinstance(resolutions,
                                            (list, tuple)) else [resolutions]
    learning_rate = learning_rate if isinstance(learning_rate,
                                                (list,
                                                 tuple)) else [learning_rate]
    epochs = epochs if isinstance(epochs, (list, tuple)) else [epochs]
    batch_size = batch_size if isinstance(batch_size,
                                          (list, tuple)) else [batch_size]
    iterations = iterations if iterations is None or isinstance(
        iterations, (list, tuple)) else [iterations]

    learning_rate = learning_rate * len(resolutions) if len(
        learning_rate) == 1 else learning_rate
    epochs = epochs * len(resolutions) if len(epochs) == 1 else epochs
    batch_size = batch_size * len(resolutions) if len(
        batch_size) == 1 else batch_size
    if iterations is not None:
        iterations = iterations * len(resolutions) if len(
            iterations) == 1 else iterations
        iterations = utils.str2int(iterations)

    if not os.path.isdir(exp_dir):
        raise RuntimeError('Experiment directory was not found: \'' + exp_dir +
                           '\'')
    assert len(learning_rate) == len(resolutions)
    assert len(epochs) == len(resolutions)
    assert len(batch_size) == len(resolutions)
    assert iterations is None or len(iterations) == len(resolutions)

    # Seed
    utils.set_seed(seed)

    # Check CUDA device availability
    device, gpus = utils.set_device(gpus)

    # Initialize loggers
    logger = TensorBoardLogger(log_dir=exp_dir)

    # Initialize datasets
    numpy_transforms = obj_factory(
        numpy_transforms) if numpy_transforms is not None else []
    tensor_transforms = obj_factory(
        tensor_transforms) if tensor_transforms is not None else []
    img_transforms = img_landmarks_transforms.Compose(numpy_transforms +
                                                      tensor_transforms)

    train_dataset = obj_factory(train_dataset, transform=img_transforms)
    if val_dataset is not None:
        val_dataset = obj_factory(val_dataset, transform=img_transforms)

    # Create networks
    arch = utils.get_arch(model, num_classes=len(train_dataset.classes))
    model = obj_factory(model,
                        num_classes=len(train_dataset.classes)).to(device)

    # Resume from a checkpoint or initialize the networks weights randomly
    checkpoint_dir = exp_dir if resume_dir is None else resume_dir
    model_path = os.path.join(checkpoint_dir, 'model_latest.pth')
    best_loss = 1e6
    curr_res = resolutions[0]
    optimizer_state = None
    if os.path.isfile(model_path):
        print("=> loading checkpoint from '{}'".format(checkpoint_dir))
        # model
        checkpoint = torch.load(model_path)
        if 'resolution' in checkpoint:
            curr_res = checkpoint['resolution']
            start_epoch = checkpoint[
                'epoch'] if start_epoch is None else start_epoch
        # else:
        #     curr_res = resolutions[1] if len(resolutions) > 1 else resolutions[0]
        best_loss_key = 'best_loss_%d' % curr_res
        best_loss = checkpoint[
            best_loss_key] if best_loss_key in checkpoint else best_loss
        model.apply(utils.init_weights)
        model.load_state_dict(checkpoint['state_dict'], strict=False)
        optimizer_state = checkpoint['optimizer']
    else:
        print("=> no checkpoint found at '{}'".format(checkpoint_dir))
        if not pretrained:
            print("=> randomly initializing networks...")
            model.apply(utils.init_weights)

    # Losses
    criterion = obj_factory(criterion).to(device)

    # Benchmark
    benchmark = obj_factory(benchmark).to(device)

    # Support multiple GPUs
    if gpus and len(gpus) > 1:
        model = nn.DataParallel(model, gpus)

    # For each resolution
    start_res_ind = int(np.log2(curr_res)) - int(np.log2(resolutions[0]))
    start_epoch = 0 if start_epoch is None else start_epoch
    for ri in range(start_res_ind, len(resolutions)):
        res = resolutions[ri]
        res_lr = learning_rate[ri]
        res_epochs = epochs[ri]
        res_iterations = iterations[ri] if iterations is not None else None
        res_batch_size = batch_size[ri]

        # Optimizer and scheduler
        optimizer = obj_factory(optimizer, model.parameters(), lr=res_lr)
        scheduler = obj_factory(scheduler, optimizer)
        if optimizer_state is not None:
            optimizer.load_state_dict(optimizer_state)

        # Initialize data loaders
        if res_iterations is None:
            train_sampler = tutils.data.sampler.WeightedRandomSampler(
                train_dataset.weights, len(train_dataset))
        else:
            train_sampler = tutils.data.sampler.WeightedRandomSampler(
                train_dataset.weights, res_iterations)
        train_loader = tutils.data.DataLoader(train_dataset,
                                              batch_size=res_batch_size,
                                              sampler=train_sampler,
                                              num_workers=workers,
                                              pin_memory=True,
                                              drop_last=True,
                                              shuffle=False)
        if val_dataset is not None:
            if res_iterations is None:
                val_sampler = tutils.data.sampler.WeightedRandomSampler(
                    val_dataset.weights, len(val_dataset))
            else:
                val_iterations = (res_iterations *
                                  len(val_dataset)) // len(train_dataset)
                val_sampler = tutils.data.sampler.WeightedRandomSampler(
                    val_dataset.weights, val_iterations)
            val_loader = tutils.data.DataLoader(val_dataset,
                                                batch_size=res_batch_size,
                                                sampler=val_sampler,
                                                num_workers=workers,
                                                pin_memory=True,
                                                drop_last=True,
                                                shuffle=False)
        else:
            val_loader = None

        # For each epoch
        for epoch in range(start_epoch, res_epochs):
            total_loss = proces_epoch(train_loader, train=True)
            if val_loader is not None:
                with torch.no_grad():
                    total_loss = proces_epoch(val_loader, train=False)
            if hasattr(benchmark, 'reset'):
                benchmark.reset()

            # Schedulers step (in PyTorch 1.1.0+ it must follow after the epoch training and validation steps)
            if isinstance(scheduler,
                          torch.optim.lr_scheduler.ReduceLROnPlateau):
                scheduler.step(total_loss)
            else:
                scheduler.step()

            # Save models checkpoints
            is_best = total_loss < best_loss
            best_loss = min(best_loss, total_loss)
            utils.save_checkpoint(
                exp_dir, 'model', {
                    'resolution':
                    res,
                    'epoch':
                    epoch + 1,
                    'state_dict':
                    model.module.state_dict()
                    if gpus and len(gpus) > 1 else model.state_dict(),
                    'optimizer':
                    optimizer.state_dict(),
                    'best_loss_%d' % res:
                    best_loss,
                    'arch':
                    arch,
                }, is_best)

        # Reset start epoch to 0 because it should only affect the first training resolution
        start_epoch = 0
        best_loss = 1e6
Example #10
def main(
        source_path,
        target_path,
        arch='res_unet_split.MultiScaleResUNet(in_nc=71,out_nc=(3,3),flat_layers=(2,0,2,3),ngf=128)',
        model_path='../weights/ijbc_msrunet_256_2_0_reenactment_v1.pth',
        pose_model_path='../weights/hopenet_robust_alpha1.pth',
        pil_transforms1=('landmark_transforms.FaceAlignCrop',
                         'landmark_transforms.Resize(256)',
                         'landmark_transforms.Pyramids(2)'),
        pil_transforms2=('landmark_transforms.FaceAlignCrop',
                         'landmark_transforms.Resize(256)',
                         'landmark_transforms.Pyramids(2)',
                         'landmark_transforms.LandmarksToHeatmaps'),
        tensor_transforms1=(
            'landmark_transforms.ToTensor()',
            'transforms.Normalize(mean=[0.5,0.5,0.5],std=[0.5,0.5,0.5])'),
        tensor_transforms2=(
            'landmark_transforms.ToTensor()',
            'transforms.Normalize(mean=[0.5,0.5,0.5],std=[0.5,0.5,0.5])'),
        output_path=None,
        crop_size=256,
        display=False):
    torch.set_grad_enabled(False)

    # Initialize models
    fa = face_alignment.FaceAlignment(face_alignment.LandmarksType._2D,
                                      flip_input=False)
    device, gpus = utils.set_device()
    G = obj_factory(arch).to(device)
    checkpoint = torch.load(model_path)
    G.load_state_dict(checkpoint['state_dict'])
    G.train(False)

    # Initialize the head pose estimator (Hopenet)
    Gp = Hopenet().to(device)
    checkpoint = torch.load(pose_model_path)
    Gp.load_state_dict(checkpoint['state_dict'])
    Gp.train(False)

    # Initialize transformations
    pil_transforms1 = obj_factory(
        pil_transforms1) if pil_transforms1 is not None else []
    pil_transforms2 = obj_factory(
        pil_transforms2) if pil_transforms2 is not None else []
    tensor_transforms1 = obj_factory(
        tensor_transforms1) if tensor_transforms1 is not None else []
    tensor_transforms2 = obj_factory(
        tensor_transforms2) if tensor_transforms2 is not None else []
    img_transforms1 = landmark_transforms.ComposePyramids(pil_transforms1 +
                                                          tensor_transforms1)
    img_transforms2 = landmark_transforms.ComposePyramids(pil_transforms2 +
                                                          tensor_transforms2)

    # Process source image
    source_bgr = cv2.imread(source_path)
    source_rgb = source_bgr[:, :, ::-1]
    source_landmarks, source_bbox = process_image(fa, source_rgb, crop_size)
    if source_bbox is None:
        raise RuntimeError("Couldn't detect a face in source image: " +
                           source_path)
    source_tensor, source_landmarks, source_bbox = img_transforms1(
        source_rgb, source_landmarks, source_bbox)
    source_cropped_bgr = tensor2bgr(
        source_tensor[0] if isinstance(source_tensor, list) else source_tensor)
    for i in range(len(source_tensor)):
        source_tensor[i] = source_tensor[i].to(device)

    # Extract landmarks, bounding boxes, and Euler angles from the target video
    frame_indices, landmarks, bboxes, eulers = extract_landmarks_bboxes_euler_from_video(
        target_path, Gp, device=device)
    if frame_indices.size == 0:
        raise RuntimeError('No faces were detected in the target video: ' +
                           target_path)

    # Open target video file
    cap = cv2.VideoCapture(target_path)
    if not cap.isOpened():
        raise RuntimeError('Failed to read video: ' + target_path)
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    fps = cap.get(cv2.CAP_PROP_FPS)

    # Initialize output video file
    if output_path is not None:
        if os.path.isdir(output_path):
            output_filename = os.path.splitext(os.path.basename(source_path))[0] + '_' + \
                              os.path.splitext(os.path.basename(target_path))[0] + '.mp4'
            output_path = os.path.join(output_path, output_filename)
            print(output_path)
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        out_vid = cv2.VideoWriter(
            output_path, fourcc, fps,
            (source_cropped_bgr.shape[1] * 3, source_cropped_bgr.shape[0]))
    else:
        out_vid = None

    # For each frame in the target video
    valid_frame_ind = 0
    for i in tqdm(range(total_frames)):
        ret, frame = cap.read()
        if frame is None:
            continue
        if i not in frame_indices:
            continue
        frame_rgb = frame[:, :, ::-1]
        frame_tensor, frame_landmarks, frame_bbox = img_transforms2(
            frame_rgb, landmarks[valid_frame_ind], bboxes[valid_frame_ind])
        valid_frame_ind += 1

        # frame_cropped_rgb, frame_landmarks = process_cached_frame(frame_rgb, landmarks[valid_frame_ind],
        #                                                           bboxes[valid_frame_ind], size)
        # frame_cropped_bgr = frame_cropped_rgb[:, :, ::-1].copy()
        # valid_frame_ind += 1

        #
        # frame_tensor, frame_landmarks_tensor = prepare_generator_input(frame_cropped_rgb, frame_landmarks)
        # frame_landmarks_tensor.to(device)
        input_tensor = []
        for j in range(len(source_tensor)):
            frame_landmarks[j] = frame_landmarks[j].to(device)
            input_tensor.append(
                torch.cat((source_tensor[j], frame_landmarks[j]),
                          dim=0).unsqueeze(0).to(device))
        out_img_tensor, out_seg_tensor = G(input_tensor)

        # Transfer image1 mask to image2
        # face_mask_tensor = out_seg_tensor.argmax(1) == 1  # face
        # face_mask_tensor = out_seg_tensor.argmax(1) == 2    # hair
        # face_mask_tensor = out_seg_tensor.argmax(1) >= 1  # head

        # target_img_tensor = frame_tensor[0].view(1, frame_tensor[0].shape[0],
        #                                          frame_tensor[0].shape[1], frame_tensor[0].shape[2]).to(device)

        # Convert back to numpy images
        out_img_bgr = tensor2bgr(out_img_tensor)
        frame_cropped_bgr = tensor2bgr(frame_tensor[0])

        # Render
        # for point in np.round(frame_landmarks).astype(int):
        #     cv2.circle(frame_cropped_bgr, (point[0], point[1]), 2, (0, 0, 255), -1)
        render_img = np.concatenate(
            (source_cropped_bgr, out_img_bgr, frame_cropped_bgr), axis=1)
        if out_vid is not None:
            out_vid.write(render_img)
        if out_vid is None or display:
            cv2.imshow('render_img', render_img)
            if cv2.waitKey(1) & 0xFF == ord('q'):
                break
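# Note: process_image is not shown in this listing. A minimal sketch of the behaviour
# assumed above (detect facial landmarks with the face_alignment model and derive a rough
# (x, y, w, h) bounding box around them); not the actual fsgan implementation:
import numpy as np

def process_image_sketch(fa, img_rgb, crop_size=256):
    # crop_size is kept for signature compatibility; the real helper presumably uses it when cropping.
    detected = fa.get_landmarks(img_rgb)
    if detected is None or len(detected) == 0:
        return None, None
    landmarks = detected[0]
    top_left = landmarks[:, :2].min(axis=0)
    size = landmarks[:, :2].max(axis=0) - top_left
    bbox = np.concatenate((top_left, size))  # (x, y, w, h) enclosing the detected landmarks
    return landmarks, bbox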
Exemplo n.º 11
0
def main(
        source_path,
        target_path,
        arch='res_unet_split.MultiScaleResUNet(in_nc=71,out_nc=(3,3),flat_layers=(2,0,2,3),ngf=128)',
        model_path='../weights/ijbc_msrunet_256_1_2_reenactment_stepwise_v1.pth',
        pose_model_path='../weights/hopenet_robust_alpha1.pth',
        pil_transforms1=('landmark_transforms.FaceAlignCrop(bbox_scale=1.2)',
                         'landmark_transforms.Resize(256)',
                         'landmark_transforms.Pyramids(2)'),
        pil_transforms2=('landmark_transforms.FaceAlignCrop(bbox_scale=1.2)',
                         'landmark_transforms.Resize(256)',
                         'landmark_transforms.Pyramids(2)'),
        tensor_transforms1=(
            'landmark_transforms.ToTensor()',
            'transforms.Normalize(mean=[0.5,0.5,0.5],std=[0.5,0.5,0.5])'),
        tensor_transforms2=(
            'landmark_transforms.ToTensor()',
            'transforms.Normalize(mean=[0.5,0.5,0.5],std=[0.5,0.5,0.5])'),
        output_path=None,
        crop_size=256,
        display=False):
    torch.set_grad_enabled(False)

    # Initialize models
    fa = face_alignment.FaceAlignment(face_alignment.LandmarksType._3D,
                                      flip_input=True)
    device, gpus = utils.set_device()
    G = obj_factory(arch).to(device)
    checkpoint = torch.load(model_path)
    G.load_state_dict(checkpoint['state_dict'])
    G.train(False)

    # Initialize the head pose estimator (Hopenet)
    Gp = Hopenet().to(device)
    checkpoint = torch.load(pose_model_path)
    Gp.load_state_dict(checkpoint['state_dict'])
    Gp.train(False)

    # Initialize landmarks to heatmaps
    landmarks2heatmaps = [
        LandmarkHeatmap(kernel_size=13, size=(256, 256)).to(device),
        LandmarkHeatmap(kernel_size=7, size=(128, 128)).to(device)
    ]

    # Initialize transformations
    pil_transforms1 = obj_factory(
        pil_transforms1) if pil_transforms1 is not None else []
    pil_transforms2 = obj_factory(
        pil_transforms2) if pil_transforms2 is not None else []
    tensor_transforms1 = obj_factory(
        tensor_transforms1) if tensor_transforms1 is not None else []
    tensor_transforms2 = obj_factory(
        tensor_transforms2) if tensor_transforms2 is not None else []
    img_transforms1 = landmark_transforms.ComposePyramids(pil_transforms1 +
                                                          tensor_transforms1)
    img_transforms2 = landmark_transforms.ComposePyramids(pil_transforms2 +
                                                          tensor_transforms2)

    # Process source image
    source_bgr = cv2.imread(source_path)
    source_rgb = source_bgr[:, :, ::-1]
    source_landmarks, source_bbox = process_image(fa, source_rgb, crop_size)
    if source_bbox is None:
        raise RuntimeError("Couldn't detect a face in source image: " +
                           source_path)
    source_tensor, source_landmarks, source_bbox = img_transforms1(
        source_rgb, source_landmarks, source_bbox)
    source_cropped_bgr = tensor2bgr(
        source_tensor[0] if isinstance(source_tensor, list) else source_tensor)
    for i in range(len(source_tensor)):
        source_tensor[i] = source_tensor[i].unsqueeze(0).to(device)

    # Extract landmarks, bounding boxes, euler angles, and 3D landmarks from target video
    frame_indices, landmarks, bboxes, eulers, landmarks_3d = \
        extract_landmarks_bboxes_euler_3d_from_video(target_path, Gp, fa, device=device)
    if frame_indices.size == 0:
        raise RuntimeError('No faces were detected in the target video: ' +
                           target_path)

    # Open target video file
    cap = cv2.VideoCapture(target_path)
    if not cap.isOpened():
        raise RuntimeError('Failed to read target video: ' + target_path)
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    fps = cap.get(cv2.CAP_PROP_FPS)

    # Initialize output video file
    if output_path is not None:
        if os.path.isdir(output_path):
            output_filename = os.path.splitext(os.path.basename(source_path))[0] + '_' + \
                              os.path.splitext(os.path.basename(target_path))[0] + '.mp4'
            output_path = os.path.join(output_path, output_filename)
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        out_vid = cv2.VideoWriter(
            output_path, fourcc, fps,
            (source_cropped_bgr.shape[1] * 3, source_cropped_bgr.shape[0]))
    else:
        out_vid = None

    # For each frame in the target video
    valid_frame_ind = 0
    for i in tqdm(range(total_frames)):
        ret, target_bgr = cap.read()
        if target_bgr is None:
            continue
        if i not in frame_indices:
            continue
        target_rgb = target_bgr[:, :, ::-1]
        target_tensor, target_landmarks, target_bbox = img_transforms2(
            target_rgb, landmarks_3d[valid_frame_ind], bboxes[valid_frame_ind])
        target_euler = eulers[valid_frame_ind]
        valid_frame_ind += 1

        # TODO: Calculate the number of required reenactment iterations
        reenactment_iterations = 2

        # Generate landmarks sequence
        target_landmarks_sequence = []
        for ri in range(1, reenactment_iterations):
            interp_landmarks = []
            for j in range(len(source_tensor)):
                alpha = float(ri) / reenactment_iterations
                curr_interp_landmarks_np = interpolate_points(
                    source_landmarks[j].cpu().numpy(),
                    target_landmarks[j].cpu().numpy(),
                    alpha=alpha)
                interp_landmarks.append(
                    torch.from_numpy(curr_interp_landmarks_np))
            target_landmarks_sequence.append(interp_landmarks)
        target_landmarks_sequence.append(target_landmarks)

        # Iterative reenactment
        out_img_tensor = source_tensor
        for curr_target_landmarks in target_landmarks_sequence:
            out_img_tensor = create_pyramid(out_img_tensor, 2)
            input_tensor = []
            for j in range(len(out_img_tensor)):
                curr_target_landmarks[j] = curr_target_landmarks[j].unsqueeze(
                    0).to(device)
                curr_target_landmarks[j] = landmarks2heatmaps[j](
                    curr_target_landmarks[j])
                input_tensor.append(
                    torch.cat((out_img_tensor[j], curr_target_landmarks[j]),
                              dim=1))
            out_img_tensor, out_seg_tensor = G(input_tensor)

        # Convert back to numpy images
        out_img_bgr = tensor2bgr(out_img_tensor)
        frame_cropped_bgr = tensor2bgr(target_tensor[0])

        # Render
        # for point in np.round(frame_landmarks).astype(int):
        #     cv2.circle(frame_cropped_bgr, (point[0], point[1]), 2, (0, 0, 255), -1)
        render_img = np.concatenate(
            (source_cropped_bgr, out_img_bgr, frame_cropped_bgr), axis=1)
        if out_vid is not None:
            out_vid.write(render_img)
        if out_vid is None or display:
            cv2.imshow('render_img', render_img)
            if cv2.waitKey(1) & 0xFF == ord('q'):
                break
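# Note: interpolate_points is not shown in this listing. It is assumed to linearly blend two
# landmark sets; a minimal sketch (not the actual implementation):
import numpy as np

def interpolate_points_sketch(src_points, dst_points, alpha=0.5):
    # alpha = 0 reproduces the source landmarks, alpha = 1 the target landmarks.
    return (1.0 - alpha) * src_points + alpha * dst_points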
Exemplo n.º 12
0
def main(dataset='opencv_video_seq_dataset.VideoSeqDataset',
         np_transforms=None,
         tensor_transforms=(
             'img_landmarks_transforms.ToTensor()',
             'transforms.Normalize(mean=[0.5,0.5,0.5],std=[0.5,0.5,0.5])'),
         workers=4,
         batch_size=4):
    import time
    from fsgan.utils.obj_factory import obj_factory
    from fsgan.utils.img_utils import tensor2bgr

    np_transforms = obj_factory(
        np_transforms) if np_transforms is not None else []
    tensor_transforms = obj_factory(
        tensor_transforms) if tensor_transforms is not None else []
    img_transforms = img_landmarks_transforms.Compose(np_transforms +
                                                      tensor_transforms)
    dataset = obj_factory(dataset, transform=img_transforms)
    # dataset = VideoSeqDataset(root_path, img_list_path, transform=img_transforms, frame_window=frame_window)
    dataloader = data.DataLoader(dataset,
                                 batch_size=batch_size,
                                 num_workers=workers,
                                 pin_memory=True,
                                 drop_last=True,
                                 shuffle=True)

    start = time.time()
    for frame_window, landmarks_window in dataloader:
        # print(frame_window.shape)

        if isinstance(frame_window, (list, tuple)):
            # For each batch
            for b in range(frame_window[0].shape[0]):
                # For each frame window in the list
                for p in range(len(frame_window)):
                    # For each frame in the window
                    for f in range(frame_window[p].shape[2]):
                        print(frame_window[p][b, :, f, :, :].shape)
                        # Render
                        render_img = tensor2bgr(
                            frame_window[p][b, :, f, :, :]).copy()
                        landmarks = landmarks_window[p][b, f, :, :].numpy()
                        # for point in np.round(landmarks).astype(int):
                        for point in landmarks:
                            cv2.circle(render_img, (point[0], point[1]), 2,
                                       (0, 0, 255), -1)
                        cv2.imshow('render_img', render_img)
                        if cv2.waitKey(0) & 0xFF == ord('q'):
                            break
        else:
            # For each batch
            for b in range(frame_window.shape[0]):
                # For each frame in the window
                for f in range(frame_window.shape[2]):
                    print(frame_window[b, :, f, :, :].shape)
                    # Render
                    render_img = tensor2bgr(frame_window[b, :, f, :, :]).copy()
                    landmarks = landmarks_window[b, f, :, :].numpy()
                    # for point in np.round(landmarks).astype(int):
                    for point in landmarks:
                        cv2.circle(render_img, (point[0], point[1]), 2,
                                   (0, 0, 255), -1)
                    cv2.imshow('render_img', render_img)
                    if cv2.waitKey(0) & 0xFF == ord('q'):
                        break
    end = time.time()
    print('elapsed time: %f[s]' % (end - start))
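# Note: the obj_factory calls used throughout these examples turn string specs such as
# 'transforms.Normalize(mean=[0.5,0.5,0.5],std=[0.5,0.5,0.5])' into instantiated objects.
# A simplified, single-spec illustration of the pattern (the real fsgan.utils.obj_factory
# also accepts tuples of specs and extra positional/keyword arguments, as used above):
import importlib

def simple_obj_factory(spec):
    call = spec if spec.endswith(')') else spec + '()'
    qual_name, ctor_args = call[:call.index('(')], call[call.index('('):]
    module_path, class_name = qual_name.rsplit('.', 1)
    cls = getattr(importlib.import_module(module_path), class_name)
    # Evaluate only the constructor argument list, with the resolved class bound to 'cls'.
    return eval('cls' + ctor_args, {'cls': cls})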
def main(
    # General arguments
    exp_dir, resume_dir=None, start_epoch=None, epochs=(90,), iterations=None, resolutions=(128, 256),
    lr_gen=(1e-4,), lr_dis=(1e-4,), gpus=None, workers=4, batch_size=(64,), seed=None, log_freq=20,

    # Data arguments
    train_dataset='opencv_video_seq_dataset.VideoSeqDataset', val_dataset=None, numpy_transforms=None,
    tensor_transforms=('img_lms_pose_transforms.ToTensor()', 'img_lms_pose_transforms.Normalize()'),

    # Training arguments
    optimizer='optim.SGD(momentum=0.9,weight_decay=1e-4)', scheduler='lr_scheduler.StepLR(step_size=30,gamma=0.1)',
    pretrained=False, criterion_pixelwise='nn.L1Loss', criterion_id='vgg_loss.VGGLoss',
    criterion_attr='vgg_loss.VGGLoss', criterion_gan='gan_loss.GANLoss(use_lsgan=True)',
    generator='res_unet.MultiScaleResUNet(in_nc=101,out_nc=3)',
    discriminator='discriminators_pix2pix.MultiscaleDiscriminator',
    rec_weight=1.0, gan_weight=0.001
):
    def proces_epoch(dataset_loader, train=True):
        stage = 'TRAINING' if train else 'VALIDATION'
        total_iter = len(dataset_loader) * dataset_loader.batch_size * epoch
        pbar = tqdm(dataset_loader, unit='batches')

        # Set networks training mode
        G.train(train)
        D.train(train)

        # Reset logger
        logger.reset(prefix='{} {}X{}: Epoch: {} / {}; LR: {:.0e}; '.format(
            stage, res, res, epoch + 1, res_epochs,  optimizer_G.param_groups[0]['lr']))

        # For each batch in the data loader
        for i, (img, landmarks, target) in enumerate(pbar):
            # Prepare input
            with torch.no_grad():
                # For each view's images and landmarks
                landmarks[1] = landmarks[1].to(device)
                for j in range(len(img)):
                    # landmarks[j] = landmarks[j].to(device)

                    # For each pyramid image: push to device
                    for p in range(len(img[j])):
                        img[j][p] = img[j][p].to(device)

                # Remove unnecessary pyramids
                for j in range(len(img)):
                    img[j] = img[j][-ri - 1:]

                # Concatenate pyramid images with context to derive the final input
                input = []
                for p in range(len(img[0])):
                    context = res_landmarks_decoders[p](landmarks[1])
                    input.append(torch.cat((img[0][p], context), dim=1))

            # Reenactment
            img_pred = G(input)

            # Fake Detection and Loss
            img_pred_pyd = img_utils.create_pyramid(img_pred, len(img[0]))
            pred_fake_pool = D([x.detach() for x in img_pred_pyd])
            loss_D_fake = criterion_gan(pred_fake_pool, False)

            # Real Detection and Loss
            pred_real = D(img[1])
            loss_D_real = criterion_gan(pred_real, True)

            loss_D_total = (loss_D_fake + loss_D_real) * 0.5

            # GAN loss (Fake Passability Loss)
            pred_fake = D(img_pred_pyd)
            loss_G_GAN = criterion_gan(pred_fake, True)

            # Reconstruction and segmentation loss
            loss_pixelwise = criterion_pixelwise(img_pred, img[1][0])
            loss_id = criterion_id(img_pred, img[1][0])
            loss_attr = criterion_attr(img_pred, img[1][0])
            loss_rec = 0.1 * loss_pixelwise + 0.5 * loss_id + 0.5 * loss_attr

            loss_G_total = rec_weight * loss_rec + gan_weight * loss_G_GAN

            if train:
                # Update generator weights
                optimizer_G.zero_grad()
                loss_G_total.backward()
                optimizer_G.step()

                # Update discriminator weights
                optimizer_D.zero_grad()
                loss_D_total.backward()
                optimizer_D.step()

            logger.update('losses', pixelwise=loss_pixelwise, id=loss_id, attr=loss_attr, rec=loss_rec,
                          g_gan=loss_G_GAN, d_gan=loss_D_total)
            total_iter += dataset_loader.batch_size

            # Batch logs
            pbar.set_description(str(logger))
            if train and i % log_freq == 0:
                logger.log_scalars_val('%dx%d/batch' % (res, res), total_iter)

        # Epoch logs
        logger.log_scalars_avg('%dx%d/epoch/%s' % (res, res, 'train' if train else 'val'), epoch)
        if not train:
            # Log images
            grid = img_utils.make_grid(img[0][0], img_pred, img[1][0])
            logger.log_image('%dx%d/vis' % (res, res), grid, epoch)

        return logger.log_dict['losses']['rec'].avg

    #################
    # Main pipeline #
    #################

    # Validate and normalize the input arguments
    resolutions = resolutions if isinstance(resolutions, (list, tuple)) else [resolutions]
    lr_gen = lr_gen if isinstance(lr_gen, (list, tuple)) else [lr_gen]
    lr_dis = lr_dis if isinstance(lr_dis, (list, tuple)) else [lr_dis]
    epochs = epochs if isinstance(epochs, (list, tuple)) else [epochs]
    batch_size = batch_size if isinstance(batch_size, (list, tuple)) else [batch_size]
    iterations = iterations if iterations is None or isinstance(iterations, (list, tuple)) else [iterations]

    lr_gen = lr_gen * len(resolutions) if len(lr_gen) == 1 else lr_gen
    lr_dis = lr_dis * len(resolutions) if len(lr_dis) == 1 else lr_dis
    epochs = epochs * len(resolutions) if len(epochs) == 1 else epochs
    batch_size = batch_size * len(resolutions) if len(batch_size) == 1 else batch_size
    if iterations is not None:
        iterations = iterations * len(resolutions) if len(iterations) == 1 else iterations
        iterations = utils.str2int(iterations)

    if not os.path.isdir(exp_dir):
        raise RuntimeError('Experiment directory was not found: \'' + exp_dir + '\'')
    assert len(lr_gen) == len(resolutions)
    assert len(lr_dis) == len(resolutions)
    assert len(epochs) == len(resolutions)
    assert len(batch_size) == len(resolutions)
    assert iterations is None or len(iterations) == len(resolutions)

    # Seed
    utils.set_seed(seed)

    # Check CUDA device availability
    device, gpus = utils.set_device(gpus)

    # Initialize loggers
    logger = TensorBoardLogger(log_dir=exp_dir)

    # Initialize datasets
    numpy_transforms = obj_factory(numpy_transforms) if numpy_transforms is not None else []
    tensor_transforms = obj_factory(tensor_transforms) if tensor_transforms is not None else []
    img_transforms = img_lms_pose_transforms.Compose(numpy_transforms + tensor_transforms)

    train_dataset = obj_factory(train_dataset, transform=img_transforms)
    if val_dataset is not None:
        val_dataset = obj_factory(val_dataset, transform=img_transforms)

    # Create networks
    G_arch = utils.get_arch(generator)
    D_arch = utils.get_arch(discriminator)
    G = obj_factory(generator).to(device)
    D = obj_factory(discriminator).to(device)

    # Resume from a checkpoint or initialize the network weights randomly
    checkpoint_dir = exp_dir if resume_dir is None else resume_dir
    G_path = os.path.join(checkpoint_dir, 'G_latest.pth')
    D_path = os.path.join(checkpoint_dir, 'D_latest.pth')
    best_loss = 1e6
    curr_res = resolutions[0]
    optimizer_G_state, optimizer_D_state = None, None
    if os.path.isfile(G_path) and os.path.isfile(D_path):
        print("=> loading checkpoint from '{}'".format(checkpoint_dir))
        # G
        checkpoint = torch.load(G_path)
        if 'resolution' in checkpoint:
            curr_res = checkpoint['resolution']
            start_epoch = checkpoint['epoch'] if start_epoch is None else start_epoch
        # else:
        #     curr_res = resolutions[1] if len(resolutions) > 1 else resolutions[0]
        best_loss_key = 'best_loss_%d' % curr_res
        best_loss = checkpoint[best_loss_key] if best_loss_key in checkpoint else best_loss
        G.apply(utils.init_weights)
        G.load_state_dict(checkpoint['state_dict'], strict=False)
        optimizer_G_state = checkpoint['optimizer']

        # D
        D.apply(utils.init_weights)
        if os.path.isfile(D_path):
            checkpoint = torch.load(D_path)
            D.load_state_dict(checkpoint['state_dict'], strict=False)
            optimizer_D_state = checkpoint['optimizer']
    else:
        print("=> no checkpoint found at '{}'".format(checkpoint_dir))
        if not pretrained:
            print("=> randomly initializing networks...")
            G.apply(utils.init_weights)
            D.apply(utils.init_weights)

    # Initialize landmarks decoders
    landmarks_decoders = []
    for res in resolutions:
        landmarks_decoders.insert(0, landmarks_utils.LandmarksHeatMapDecoder(res).to(device))

    # Losses
    criterion_pixelwise = obj_factory(criterion_pixelwise).to(device)
    criterion_id = obj_factory(criterion_id).to(device)
    criterion_attr = obj_factory(criterion_attr).to(device)
    criterion_gan = obj_factory(criterion_gan).to(device)

    # Support multiple GPUs
    if gpus and len(gpus) > 1:
        G = nn.DataParallel(G, gpus)
        D = nn.DataParallel(D, gpus)
        criterion_id.vgg = nn.DataParallel(criterion_id.vgg, gpus)
        criterion_attr.vgg = nn.DataParallel(criterion_attr.vgg, gpus)
        landmarks_decoders = [nn.DataParallel(ld, gpus) for ld in landmarks_decoders]

    # For each resolution
    start_res_ind = int(np.log2(curr_res)) - int(np.log2(resolutions[0]))
    start_epoch = 0 if start_epoch is None else start_epoch
    for ri in range(start_res_ind, len(resolutions)):
        res = resolutions[ri]
        res_lr_gen = lr_gen[ri]
        res_lr_dis = lr_dis[ri]
        res_epochs = epochs[ri]
        res_iterations = iterations[ri] if iterations is not None else None
        res_batch_size = batch_size[ri]
        res_landmarks_decoders = landmarks_decoders[-ri - 1:]

        # Optimizers and schedulers
        optimizer_G = obj_factory(optimizer, G.parameters(), lr=res_lr_gen)
        optimizer_D = obj_factory(optimizer, D.parameters(), lr=res_lr_dis)
        scheduler_G = obj_factory(scheduler, optimizer_G)
        scheduler_D = obj_factory(scheduler, optimizer_D)
        if optimizer_G_state is not None:
            optimizer_G.load_state_dict(optimizer_G_state)
            optimizer_G_state = None
        if optimizer_D_state is not None:
            optimizer_D.load_state_dict(optimizer_D_state)
            optimizer_D_state = None

        # Initialize data loaders
        if res_iterations is None:
            train_sampler = tutils.data.sampler.WeightedRandomSampler(train_dataset.weights, len(train_dataset))
        else:
            train_sampler = tutils.data.sampler.WeightedRandomSampler(train_dataset.weights, res_iterations)
        train_loader = tutils.data.DataLoader(train_dataset, batch_size=res_batch_size, sampler=train_sampler,
                                              num_workers=workers, pin_memory=True, drop_last=True, shuffle=False)
        if val_dataset is not None:
            if res_iterations is None:
                val_sampler = tutils.data.sampler.WeightedRandomSampler(val_dataset.weights, len(val_dataset))
            else:
                val_iterations = (res_iterations * len(val_dataset.classes)) // len(train_dataset.classes)
                val_sampler = tutils.data.sampler.WeightedRandomSampler(val_dataset.weights, val_iterations)
            val_loader = tutils.data.DataLoader(val_dataset, batch_size=res_batch_size, sampler=val_sampler,
                                                num_workers=workers, pin_memory=True, drop_last=True, shuffle=False)
        else:
            val_loader = None

        # For each epoch
        for epoch in range(start_epoch, res_epochs):
            total_loss = proces_epoch(train_loader, train=True)
            if val_loader is not None:
                with torch.no_grad():
                    total_loss = proces_epoch(val_loader, train=False)

            # Schedulers step (in PyTorch 1.1.0+ this must come after the epoch's training and validation passes)
            if isinstance(scheduler_G, torch.optim.lr_scheduler.ReduceLROnPlateau):
                scheduler_G.step(total_loss)
                scheduler_D.step(total_loss)
            else:
                scheduler_G.step()
                scheduler_D.step()

            # Save model checkpoints
            is_best = total_loss < best_loss
            best_loss = min(best_loss, total_loss)
            utils.save_checkpoint(exp_dir, 'G', {
                'resolution': res,
                'epoch': epoch + 1,
                'state_dict': G.module.state_dict() if gpus and len(gpus) > 1 else G.state_dict(),
                'optimizer': optimizer_G.state_dict(),
                'best_loss_%d' % res: best_loss,
                'arch': G_arch,
            }, is_best)
            utils.save_checkpoint(exp_dir, 'D', {
                'resolution': res,
                'epoch': epoch + 1,
                'state_dict': D.module.state_dict() if gpus and len(gpus) > 1 else D.state_dict(),
                'optimizer': optimizer_D.state_dict(),
                'best_loss_%d' % res: best_loss,
                'arch': D_arch,
            }, is_best)

        # Reset the start epoch to 0 because it should only affect the first training resolution
        start_epoch = 0
        best_loss = 1e6
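# Note: a short illustration of the per-resolution hyperparameter broadcasting performed in
# the main pipeline above (hypothetical values): single-element tuples are repeated so that
# every training resolution gets its own entry.
demo_resolutions = (128, 256)
demo_epochs = (90,)
demo_epochs = demo_epochs * len(demo_resolutions) if len(demo_epochs) == 1 else demo_epochs
assert demo_epochs == (90, 90)  # one epoch budget per resolution in the progressive schedule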
Exemplo n.º 14
0
def main(dataset='fsgan.datasets.seq_dataset.SeqDataset', np_transforms=None,
         tensor_transforms=('img_lms_pose_transforms.ToTensor()', 'img_lms_pose_transforms.Normalize()'),
         workers=4, batch_size=4):
    import time
    import fsgan
    from fsgan.utils.obj_factory import obj_factory
    from fsgan.utils.img_utils import tensor2bgr

    np_transforms = obj_factory(np_transforms) if np_transforms is not None else []
    tensor_transforms = obj_factory(tensor_transforms) if tensor_transforms is not None else []
    img_transforms = img_lms_pose_transforms.Compose(np_transforms + tensor_transforms)
    dataset = obj_factory(dataset, transform=img_transforms)
    # dataset = VideoSeqDataset(root_path, img_list_path, transform=img_transforms, frame_window=frame_window)
    dataloader = data.DataLoader(dataset, batch_size=batch_size, num_workers=workers, pin_memory=True, drop_last=True,
                                 shuffle=True)

    start = time.time()
    if isinstance(dataset, fsgan.datasets.seq_dataset.SeqPairDataset):
        for frame, landmarks, pose, target in dataloader:
            pass
    elif isinstance(dataset, fsgan.datasets.seq_dataset.SeqDataset):
        for frame, landmarks, pose in dataloader:
            # For each batch
            for b in range(frame.shape[0]):
                # Render
                render_img = tensor2bgr(frame[b]).copy()
                curr_landmarks = landmarks[b].numpy() * render_img.shape[0]
                curr_pose = pose[b].numpy() * 99.

                for point in curr_landmarks:
                    cv2.circle(render_img, (point[0], point[1]), 2, (0, 0, 255), -1)
                msg = 'Pose: %.1f, %.1f, %.1f' % (curr_pose[0], curr_pose[1], curr_pose[2])
                cv2.putText(render_img, msg, (10, 20), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 255), 1, cv2.LINE_AA)
                cv2.imshow('render_img', render_img)
                if cv2.waitKey(0) & 0xFF == ord('q'):
                    break


    end = time.time()
    print('elapsed time: %f[s]' % (end - start))