Example #1
def main(
        # General arguments
        exp_dir,
        resume_dir=None,
        start_epoch=None,
        epochs=(90, ),
        iterations=None,
        resolutions=(128, 256),
        lr_gen=(1e-4, ),
        lr_dis=(1e-4, ),
        gpus=None,
        workers=4,
        batch_size=(64, ),
        seed=None,
        log_freq=20,

        # Data arguments
        train_dataset='opencv_video_seq_dataset.VideoSeqDataset',
        val_dataset=None,
        numpy_transforms=None,
        tensor_transforms=(
            'img_landmarks_transforms.ToTensor()',
            'transforms.Normalize(mean=[0.5,0.5,0.5],std=[0.5,0.5,0.5])'),

        # Training arguments
        optimizer='optim.SGD(momentum=0.9,weight_decay=1e-4)',
        scheduler='lr_scheduler.StepLR(step_size=30,gamma=0.1)',
        pretrained=False,
        criterion_pixelwise='nn.L1Loss',
        criterion_id='vgg_loss.VGGLoss',
        criterion_attr='vgg_loss.VGGLoss',
        criterion_gan='gan_loss.GANLoss(use_lsgan=True)',
        generator='res_unet.MultiScaleResUNet(in_nc=4,out_nc=3)',
        discriminator='discriminators_pix2pix.MultiscaleDiscriminator',
        reenactment_model=None,
        seg_model=None,
        lms_model=None,
        pix_weight=0.1,
        rec_weight=1.0,
        gan_weight=0.001,
        background_value=-1.0):
    def process_epoch(dataset_loader, train=True):
        stage = 'TRAINING' if train else 'VALIDATION'
        total_iter = len(dataset_loader) * dataset_loader.batch_size * epoch
        pbar = tqdm(dataset_loader, unit='batches')

        # Set networks training mode
        Gc.train(train)
        D.train(train)
        Gr.train(False)
        S.train(False)
        L.train(False)

        # Reset logger
        logger.reset(prefix='{} {}X{}: Epoch: {} / {}; LR: {:.0e}; '.format(
            stage, res, res, epoch + 1, res_epochs,
            scheduler_G.get_lr()[0]))

        # For each batch in the training data
        for i, (img, target) in enumerate(pbar):
            # Prepare input
            with torch.no_grad():
                # For each view images
                for j in range(len(img)):
                    # For each pyramid image: push to device
                    for p in range(len(img[j])):
                        img[j][p] = img[j][p].to(device)

                # Compute context
                context = L(img[1][0].sub(context_mean).div(context_std))
                context = landmarks_utils.filter_landmarks(context)

                # Normalize each of the pyramid images
                for j in range(len(img)):
                    for p in range(len(img[j])):
                        img[j][p].sub_(img_mean).div_(img_std)

                # # Compute segmentation
                # seg = []
                # for j in range(len(img)):
                #     curr_seg = S(img[j][0])
                #     if curr_seg.shape[2:] != (res, res):
                #         curr_seg = F.interpolate(curr_seg, (res, res), mode='bicubic', align_corners=False)
                #     seg.append(curr_seg)

                # Compute segmentation
                target_seg = S(img[1][0])
                if target_seg.shape[2:] != (res, res):
                    target_seg = F.interpolate(target_seg, (res, res),
                                               mode='bicubic',
                                               align_corners=False)

                # Concatenate pyramid images with context to derive the final input
                input = []
                for p in range(len(img[0]) - 1, -1, -1):
                    context = F.interpolate(context,
                                            size=img[0][p].shape[2:],
                                            mode='bicubic',
                                            align_corners=False)
                    input.insert(0, torch.cat((img[0][p], context), dim=1))

                # Reenactment
                reenactment_img = Gr(input)
                reenactment_seg = S(reenactment_img)
                if reenactment_img.shape[2:] != (res, res):
                    reenactment_img = F.interpolate(reenactment_img,
                                                    (res, res),
                                                    mode='bilinear',
                                                    align_corners=False)
                    reenactment_seg = F.interpolate(reenactment_seg,
                                                    (res, res),
                                                    mode='bilinear',
                                                    align_corners=False)

                # Remove unnecessary pyramids
                for j in range(len(img)):
                    img[j] = img[j][-ri - 1:]

                # Source face
                reenactment_face_mask = reenactment_seg.argmax(1) == 1
                inpainting_mask = seg_utils.random_hair_inpainting_mask_tensor(
                    reenactment_face_mask).to(device)
                reenactment_face_mask = reenactment_face_mask * (
                    inpainting_mask == 0)
                reenactment_img_with_hole = reenactment_img.masked_fill(
                    ~reenactment_face_mask.unsqueeze(1), background_value)

                # Target face
                target_face_mask = (target_seg.argmax(1) == 1).unsqueeze(1)
                inpainting_target = img[1][0]
                inpainting_target.masked_fill_(~target_face_mask,
                                               background_value)

                # Inpainting input
                inpainting_input = torch.cat(
                    (reenactment_img_with_hole, target_face_mask.float()),
                    dim=1)
                inpainting_input_pyd = img_utils.create_pyramid(
                    inpainting_input, len(img[0]))

            # Face inpainting
            inpainting_pred = Gc(inpainting_input_pyd)

            # Fake Detection and Loss
            inpainting_pred_pyd = img_utils.create_pyramid(
                inpainting_pred, len(img[0]))
            pred_fake_pool = D([x.detach() for x in inpainting_pred_pyd])
            loss_D_fake = criterion_gan(pred_fake_pool, False)

            # Real Detection and Loss
            inpainting_target_pyd = img_utils.create_pyramid(
                inpainting_target, len(img[0]))
            pred_real = D(inpainting_target_pyd)
            loss_D_real = criterion_gan(pred_real, True)

            loss_D_total = (loss_D_fake + loss_D_real) * 0.5

            # GAN loss (Fake Passability Loss)
            pred_fake = D(inpainting_pred_pyd)
            loss_G_GAN = criterion_gan(pred_fake, True)

            # Reconstruction
            loss_pixelwise = criterion_pixelwise(inpainting_pred,
                                                 inpainting_target)
            loss_id = criterion_id(inpainting_pred, inpainting_target)
            loss_attr = criterion_attr(inpainting_pred, inpainting_target)
            loss_rec = pix_weight * loss_pixelwise + 0.5 * loss_id + 0.5 * loss_attr

            loss_G_total = rec_weight * loss_rec + gan_weight * loss_G_GAN

            if train:
                # Update generator weights
                optimizer_G.zero_grad()
                loss_G_total.backward()
                optimizer_G.step()

                # Update discriminator weights
                optimizer_D.zero_grad()
                loss_D_total.backward()
                optimizer_D.step()

            logger.update('losses',
                          pixelwise=loss_pixelwise,
                          id=loss_id,
                          attr=loss_attr,
                          rec=loss_rec,
                          g_gan=loss_G_GAN,
                          d_gan=loss_D_total)
            total_iter += dataset_loader.batch_size

            # Batch logs
            pbar.set_description(str(logger))
            if train and i % log_freq == 0:
                logger.log_scalars_val('%dx%d/batch' % (res, res), total_iter)

        # Epoch logs
        logger.log_scalars_avg(
            '%dx%d/epoch/%s' % (res, res, 'train' if train else 'val'), epoch)
        if not train:
            # Log images
            grid = img_utils.make_grid(img[0][0], reenactment_img,
                                       reenactment_img_with_hole,
                                       inpainting_pred, inpainting_target)
            logger.log_image('%dx%d/vis' % (res, res), grid, epoch)

        return logger.log_dict['losses']['rec'].avg

    #################
    # Main pipeline #
    #################

    # Validation
    resolutions = resolutions if isinstance(resolutions,
                                            (list, tuple)) else [resolutions]
    lr_gen = lr_gen if isinstance(lr_gen, (list, tuple)) else [lr_gen]
    lr_dis = lr_dis if isinstance(lr_dis, (list, tuple)) else [lr_dis]
    epochs = epochs if isinstance(epochs, (list, tuple)) else [epochs]
    batch_size = batch_size if isinstance(batch_size,
                                          (list, tuple)) else [batch_size]
    iterations = iterations if iterations is None or isinstance(
        iterations, (list, tuple)) else [iterations]

    lr_gen = lr_gen * len(resolutions) if len(lr_gen) == 1 else lr_gen
    lr_dis = lr_dis * len(resolutions) if len(lr_dis) == 1 else lr_dis
    epochs = epochs * len(resolutions) if len(epochs) == 1 else epochs
    batch_size = batch_size * len(resolutions) if len(
        batch_size) == 1 else batch_size
    if iterations is not None:
        iterations = iterations * len(resolutions) if len(
            iterations) == 1 else iterations
        iterations = utils.str2int(iterations)

    if not os.path.isdir(exp_dir):
        raise RuntimeError('Experiment directory was not found: \'' + exp_dir +
                           '\'')
    assert len(lr_gen) == len(resolutions)
    assert len(lr_dis) == len(resolutions)
    assert len(epochs) == len(resolutions)
    assert len(batch_size) == len(resolutions)
    assert iterations is None or len(iterations) == len(resolutions)

    # Seed
    utils.set_seed(seed)

    # Check CUDA device availability
    device, gpus = utils.set_device(gpus)

    # Initialize loggers
    logger = TensorBoardLogger(log_dir=exp_dir)

    # Initialize datasets
    numpy_transforms = obj_factory(
        numpy_transforms) if numpy_transforms is not None else []
    tensor_transforms = obj_factory(
        tensor_transforms) if tensor_transforms is not None else []
    img_transforms = img_landmarks_transforms.Compose(numpy_transforms +
                                                      tensor_transforms)

    train_dataset = obj_factory(train_dataset, transform=img_transforms)
    if val_dataset is not None:
        val_dataset = obj_factory(val_dataset, transform=img_transforms)

    # Create networks
    Gc = obj_factory(generator).to(device)
    D = obj_factory(discriminator).to(device)

    # Resume from a checkpoint or initialize the networks weights randomly
    checkpoint_dir = exp_dir if resume_dir is None else resume_dir
    Gc_path = os.path.join(checkpoint_dir, 'Gc_latest.pth')
    D_path = os.path.join(checkpoint_dir, 'D_latest.pth')
    best_loss = 1000000.
    curr_res = resolutions[0]
    optimizer_G_state, optimizer_D_state = None, None
    if os.path.isfile(Gc_path) and os.path.isfile(D_path):
        print("=> loading checkpoint from '{}'".format(checkpoint_dir))
        # Gc
        checkpoint = torch.load(Gc_path)
        if 'resolution' in checkpoint:
            curr_res = checkpoint['resolution']
            start_epoch = checkpoint[
                'epoch'] if start_epoch is None else start_epoch
        else:
            curr_res = resolutions[1] if len(
                resolutions) > 1 else resolutions[0]
        best_loss = checkpoint['best_loss']
        Gc.apply(utils.init_weights)
        Gc.load_state_dict(checkpoint['state_dict'], strict=False)
        optimizer_G_state = checkpoint['optimizer']

        # D
        D.apply(utils.init_weights)
        if os.path.isfile(D_path):
            checkpoint = torch.load(D_path)
            D.load_state_dict(checkpoint['state_dict'], strict=False)
            optimizer_D_state = checkpoint['optimizer']
    else:
        print("=> no checkpoint found at '{}'".format(checkpoint_dir))
        if not pretrained:
            print("=> randomly initializing networks...")
            Gc.apply(utils.init_weights)
            D.apply(utils.init_weights)

    # Load reenactment model
    if reenactment_model is None:
        raise RuntimeError('Reenactment model must be specified!')
    if not os.path.exists(reenactment_model):
        raise RuntimeError('Couldn\'t find reenactment model in path: ' +
                           reenactment_model)
    print('=> Loading face reenactment model: "' +
          os.path.basename(reenactment_model) + '"...')
    checkpoint = torch.load(reenactment_model)
    Gr = obj_factory(checkpoint['arch']).to(device)
    Gr.load_state_dict(checkpoint['state_dict'])

    # Load segmentation model
    if seg_model is None:
        raise RuntimeError('Segmentation model must be specified!')
    if not os.path.exists(seg_model):
        raise RuntimeError('Couldn\'t find segmentation model in path: ' +
                           seg_model)
    print('=> Loading face segmentation model: "' +
          os.path.basename(seg_model) + '"...')
    checkpoint = torch.load(seg_model)
    S = obj_factory(checkpoint['arch']).to(device)
    S.load_state_dict(checkpoint['state_dict'])

    # Load face landmarks model
    print('=> Loading face landmarks model: "' + os.path.basename(lms_model) +
          '"...')
    assert os.path.isfile(
        lms_model), 'The model path "%s" does not exist' % lms_model
    L = hrnet_wlfw().to(device)
    state_dict = torch.load(lms_model)
    L.load_state_dict(state_dict)

    # Initialize normalization tensors
    # Note: this is necessary because of the landmarks model
    img_mean = torch.as_tensor([0.5, 0.5, 0.5], device=device).view(1, 3, 1, 1)
    img_std = torch.as_tensor([0.5, 0.5, 0.5], device=device).view(1, 3, 1, 1)
    context_mean = torch.as_tensor([0.485, 0.456, 0.406],
                                   device=device).view(1, 3, 1, 1)
    context_std = torch.as_tensor([0.229, 0.224, 0.225],
                                  device=device).view(1, 3, 1, 1)

    # Losses
    criterion_pixelwise = obj_factory(criterion_pixelwise).to(device)
    criterion_id = obj_factory(criterion_id).to(device)
    criterion_attr = obj_factory(criterion_attr).to(device)
    criterion_gan = obj_factory(criterion_gan).to(device)

    # Support multiple GPUs
    if gpus and len(gpus) > 1:
        Gc = nn.DataParallel(Gc, gpus)
        Gr = nn.DataParallel(Gr, gpus)
        D = nn.DataParallel(D, gpus)
        S = nn.DataParallel(S, gpus)
        L = nn.DataParallel(L, gpus)
        criterion_id.vgg = nn.DataParallel(criterion_id.vgg, gpus)
        criterion_attr.vgg = nn.DataParallel(criterion_attr.vgg, gpus)

    # For each resolution
    start_res_ind = int(np.log2(curr_res)) - int(np.log2(resolutions[0]))
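    # Note: this assumes the resolutions are consecutive powers of two, so the
    # log2 gap between the checkpoint resolution and resolutions[0] gives the
    # index of the resolution to resume from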
    start_epoch = 0 if start_epoch is None else start_epoch
    for ri in range(start_res_ind, len(resolutions)):
        res = resolutions[ri]
        res_lr_gen = lr_gen[ri]
        res_lr_dis = lr_dis[ri]
        res_epochs = epochs[ri]
        res_iterations = iterations[ri] if iterations is not None else None
        res_batch_size = batch_size[ri]

        # Optimizer and scheduler
        optimizer_G = obj_factory(optimizer, Gc.parameters(), lr=res_lr_gen)
        optimizer_D = obj_factory(optimizer, D.parameters(), lr=res_lr_dis)
        scheduler_G = obj_factory(scheduler, optimizer_G)
        scheduler_D = obj_factory(scheduler, optimizer_D)
        if optimizer_G_state is not None:
            optimizer_G.load_state_dict(optimizer_G_state)
            optimizer_G_state = None
        if optimizer_D_state is not None:
            optimizer_D.load_state_dict(optimizer_D_state)
            optimizer_D_state = None

        # Initialize data loaders
        if res_iterations is None:
            train_sampler = tutils.data.sampler.WeightedRandomSampler(
                train_dataset.weights, len(train_dataset))
        else:
            train_sampler = tutils.data.sampler.WeightedRandomSampler(
                train_dataset.weights, res_iterations)
        train_loader = tutils.data.DataLoader(train_dataset,
                                              batch_size=res_batch_size,
                                              sampler=train_sampler,
                                              num_workers=workers,
                                              pin_memory=True,
                                              drop_last=True,
                                              shuffle=False)
        if val_dataset is not None:
            if res_iterations is None:
                val_sampler = tutils.data.sampler.WeightedRandomSampler(
                    val_dataset.weights, len(val_dataset))
            else:
                val_iterations = (res_iterations * len(
                    val_dataset.classes)) // len(train_dataset.classes)
                val_sampler = tutils.data.sampler.WeightedRandomSampler(
                    val_dataset.weights, val_iterations)
            val_loader = tutils.data.DataLoader(val_dataset,
                                                batch_size=res_batch_size,
                                                sampler=val_sampler,
                                                num_workers=workers,
                                                pin_memory=True,
                                                drop_last=True,
                                                shuffle=False)
        else:
            val_loader = None

        # For each epoch
        for epoch in range(start_epoch, res_epochs):
            total_loss = process_epoch(train_loader, train=True)
            if val_loader is not None:
                with torch.no_grad():
                    total_loss = process_epoch(val_loader, train=False)

            # Schedulers step (in PyTorch 1.1.0+ it must follow after the epoch training and validation steps)
            if isinstance(scheduler,
                          torch.optim.lr_scheduler.ReduceLROnPlateau):
                scheduler_G.step(total_loss)
                scheduler_D.step(total_loss)
            else:
                scheduler_G.step()
                scheduler_D.step()

            # Save models checkpoints
            is_best = total_loss < best_loss
            best_loss = min(best_loss, total_loss)
            utils.save_checkpoint(
                exp_dir, 'Gc', {
                    'resolution':
                    res,
                    'epoch':
                    epoch + 1,
                    'state_dict':
                    Gc.module.state_dict()
                    if gpus and len(gpus) > 1 else Gc.state_dict(),
                    'optimizer':
                    optimizer_G.state_dict(),
                    'best_loss':
                    best_loss,
                }, is_best)
            utils.save_checkpoint(
                exp_dir, 'D', {
                    'resolution':
                    res,
                    'epoch':
                    epoch + 1,
                    'state_dict':
                    D.module.state_dict()
                    if gpus and len(gpus) > 1 else D.state_dict(),
                    'optimizer':
                    optimizer_D.state_dict(),
                    'best_loss':
                    best_loss,
                }, is_best)

        # Reset start epoch to 0 because it should only affect the first training resolution
        start_epoch = 0
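
A note on the argument handling above: scalar hyperparameters (learning rates, epochs, batch sizes) are broadcast so that every training resolution gets its own value. A minimal standalone sketch of that pattern, with illustrative values rather than the fsgan utilities:

# Standalone sketch of the per-resolution broadcasting used in the validation block above.
def broadcast_per_resolution(value, resolutions):
    value = value if isinstance(value, (list, tuple)) else [value]
    value = list(value) * len(resolutions) if len(value) == 1 else list(value)
    assert len(value) == len(resolutions)
    return value

resolutions = (128, 256)
lr_gen = broadcast_per_resolution(1e-4, resolutions)      # -> [0.0001, 0.0001]
epochs = broadcast_per_resolution((24, 50), resolutions)  # -> [24, 50]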
Example #2
def main(
        source_path,
        target_path,
        arch='res_unet_split.MultiScaleResUNet(in_nc=71,out_nc=(3,3),flat_layers=(2,0,2,3),ngf=128)',
        reenactment_model_path='../weights/ijbc_msrunet_256_2_0_reenactment_v1.pth',
        seg_model_path='../weights/lfw_figaro_unet_256_2_0_segmentation_v1.pth',
        inpainting_model_path='../weights/ijbc_msrunet_256_2_0_inpainting_v1.pth',
        blend_model_path='../weights/ijbc_msrunet_256_2_0_blending_v1.pth',
        pose_model_path='../weights/hopenet_robust_alpha1.pth',
        pil_transforms1=('landmark_transforms.FaceAlignCrop',
                         'landmark_transforms.Resize(256)',
                         'landmark_transforms.Pyramids(2)'),
        pil_transforms2=('landmark_transforms.FaceAlignCrop',
                         'landmark_transforms.Resize(256)',
                         'landmark_transforms.Pyramids(2)',
                         'landmark_transforms.LandmarksToHeatmaps'),
        tensor_transforms1=(
            'landmark_transforms.ToTensor()',
            'transforms.Normalize(mean=[0.5,0.5,0.5],std=[0.5,0.5,0.5])'),
        tensor_transforms2=(
            'landmark_transforms.ToTensor()',
            'transforms.Normalize(mean=[0.5,0.5,0.5],std=[0.5,0.5,0.5])'),
        output_path=None,
        min_radius=2.0,
        crop_size=256,
        reverse_output=False,
        verbose=0,
        output_crop=False,
        display=False):
    torch.set_grad_enabled(False)

    fa = face_alignment.FaceAlignment(face_alignment.LandmarksType._2D,
                                      flip_input=False)
    device, gpus = utils.set_device()

    Gr = obj_factory(arch).to(device)
    checkpoint = torch.load(reenactment_model_path)
    Gr.load_state_dict(checkpoint['state_dict'])
    Gr.train(False)

    if seg_model_path is not None:
        print('Loading face segmentation model: "' +
              os.path.basename(seg_model_path) + '"...')
        if seg_model_path.endswith('.pth'):
            checkpoint = torch.load(seg_model_path)
            Gs = obj_factory(checkpoint['arch']).to(device)
            Gs.load_state_dict(checkpoint['state_dict'])
        else:
            Gs = torch.jit.load(seg_model_path, map_location=device)
        if Gs is None:
            raise RuntimeError('Failed to load face segmentation model!')
        Gs.eval()
    else:
        Gs = None

    if inpainting_model_path is not None:
        print('Loading face inpainting model: "' +
              os.path.basename(inpainting_model_path) + '"...')
        if inpainting_model_path.endswith('.pth'):
            checkpoint = torch.load(inpainting_model_path)
            Gi = obj_factory(checkpoint['arch']).to(device)
            Gi.load_state_dict(checkpoint['state_dict'])
        else:
            Gi = torch.jit.load(inpainting_model_path, map_location=device)
        if Gi is None:
            raise RuntimeError('Failed to load face inpainting model!')
        Gi.eval()
    else:
        Gi = None

    checkpoint = torch.load(blend_model_path)
    Gb = obj_factory(checkpoint['arch']).to(device)
    Gb.load_state_dict(checkpoint['state_dict'])
    Gb.train(False)

    Gp = Hopenet().to(device)
    checkpoint = torch.load(pose_model_path)
    Gp.load_state_dict(checkpoint['state_dict'])
    Gp.train(False)

    pil_transforms1 = obj_factory(
        pil_transforms1) if pil_transforms1 is not None else []
    pil_transforms2 = obj_factory(
        pil_transforms2) if pil_transforms2 is not None else []
    tensor_transforms1 = obj_factory(
        tensor_transforms1) if tensor_transforms1 is not None else []
    tensor_transforms2 = obj_factory(
        tensor_transforms2) if tensor_transforms2 is not None else []
    img_transforms1 = landmark_transforms.ComposePyramids(pil_transforms1 +
                                                          tensor_transforms1)
    img_transforms2 = landmark_transforms.ComposePyramids(pil_transforms2 +
                                                          tensor_transforms2)

    source_frame_indices, source_landmarks, source_bboxes, source_eulers = \
        extract_landmarks_bboxes_euler_from_images(source_path, Gp, fa, device=device)
    if source_frame_indices.size == 0:
        raise RuntimeError(
            'No faces were detected in the source image directory: ' +
            source_path)

    target_frame_indices, target_landmarks, target_bboxes, target_eulers = \
        extract_landmarks_bboxes_euler_from_images(target_path, Gp, fa, device=device)
    if target_frame_indices.size == 0:
        raise RuntimeError(
            'No faces were detected in the target image directory: ' +
            target_path)

    source_img_paths = glob(os.path.join(source_path, '*.jpg'))
    target_img_paths = glob(os.path.join(target_path, '*.jpg'))

    source_valid_frame_ind = 0
    for k, source_img_path in tqdm(enumerate(source_img_paths),
                                   unit='images',
                                   total=len(source_img_paths)):
        if k not in source_frame_indices:
            continue
        source_img_bgr = cv2.imread(source_img_path)
        if source_img_bgr is None:
            continue
        source_img_rgb = source_img_bgr[:, :, ::-1]
        curr_source_tensor, curr_source_landmarks, curr_source_bbox = img_transforms1(
            source_img_rgb, source_landmarks[source_valid_frame_ind],
            source_bboxes[source_valid_frame_ind])
        source_valid_frame_ind += 1

        for j in range(len(curr_source_tensor)):
            curr_source_tensor[j] = curr_source_tensor[j].to(device)

        target_valid_frame_ind = 0
        for i, target_img_path in enumerate(target_img_paths):
            curr_output_name = '_'.join([
                os.path.splitext(os.path.basename(source_img_path))[0],
                os.path.splitext(os.path.basename(target_img_path))[0]
            ]) + '.jpg'
            curr_output_path = os.path.join(output_path, curr_output_name)
            if os.path.isfile(curr_output_path):
                target_valid_frame_ind += 1
                continue
            target_img_bgr = cv2.imread(target_img_path)
            if target_img_bgr is None:
                continue
            if i not in target_frame_indices:
                continue
            target_img_rgb = target_img_bgr[:, :, ::-1]

            curr_target_tensor, curr_target_landmarks, curr_target_bbox = img_transforms2(
                target_img_rgb, target_landmarks[target_valid_frame_ind],
                target_bboxes[target_valid_frame_ind])
            curr_target_euler = target_eulers[target_valid_frame_ind]
            target_valid_frame_ind += 1

            reenactment_input_tensor = []
            for j in range(len(curr_source_tensor)):
                curr_target_landmarks[j] = curr_target_landmarks[j].to(device)
                reenactment_input_tensor.append(
                    torch.cat(
                        (curr_source_tensor[j], curr_target_landmarks[j]),
                        dim=0).unsqueeze(0))
            reenactment_img_tensor, reenactment_seg_tensor = Gr(
                reenactment_input_tensor)

            target_img_tensor = curr_target_tensor[0].unsqueeze(0).to(device)
            target_seg_pred_tensor = Gs(target_img_tensor)
            target_mask_tensor = target_seg_pred_tensor.argmax(1) == 1

            aligned_face_mask_tensor = reenactment_seg_tensor.argmax(1) == 1
            aligned_background_mask_tensor = ~aligned_face_mask_tensor
            aligned_img_no_background_tensor = reenactment_img_tensor.clone()
            aligned_img_no_background_tensor.masked_fill_(
                aligned_background_mask_tensor.unsqueeze(1), -1.0)

            inpainting_input_tensor = torch.cat(
                (aligned_img_no_background_tensor,
                 target_mask_tensor.unsqueeze(1).float()),
                dim=1)
            inpainting_input_tensor_pyd = create_pyramid(
                inpainting_input_tensor, len(curr_target_tensor))
            completion_tensor = Gi(inpainting_input_tensor_pyd)

            transfer_tensor = transfer_mask(completion_tensor,
                                            target_img_tensor,
                                            target_mask_tensor)
            blend_input_tensor = torch.cat(
                (transfer_tensor, target_img_tensor,
                 target_mask_tensor.unsqueeze(1).float()),
                dim=1)
            blend_input_tensor_pyd = create_pyramid(blend_input_tensor,
                                                    len(curr_target_tensor))
            blend_tensor = Gb(blend_input_tensor_pyd)

            blend_img = tensor2bgr(blend_tensor)

            if verbose == 0:
                render_img = blend_img if output_crop else crop2img(
                    target_img_bgr, blend_img, curr_target_bbox[0].numpy())
            elif verbose == 1:
                reenactment_only_tensor = transfer_mask(
                    reenactment_img_tensor, target_img_tensor,
                    aligned_face_mask_tensor & target_mask_tensor)
                reenactment_only_img = tensor2bgr(reenactment_only_tensor)

                completion_only_img = tensor2bgr(transfer_tensor)

                transfer_tensor = transfer_mask(
                    aligned_img_no_background_tensor, target_img_tensor,
                    target_mask_tensor)
                blend_input_tensor = torch.cat(
                    (transfer_tensor, target_img_tensor,
                     target_mask_tensor.unsqueeze(1).float()),
                    dim=1)
                blend_input_tensor_pyd = create_pyramid(
                    blend_input_tensor, len(curr_target_tensor))
                blend_tensor = Gb(blend_input_tensor_pyd)
                blend_only_img = tensor2bgr(blend_tensor)

                render_img = np.concatenate(
                    (reenactment_only_img, completion_only_img, blend_only_img,
                     blend_img),
                    axis=1)
            elif verbose == 2:
                reenactment_img_bgr = tensor2bgr(reenactment_img_tensor)
                reenactment_seg_bgr = tensor2bgr(
                    blend_seg_pred(reenactment_img_tensor,
                                   reenactment_seg_tensor))
                target_seg_bgr = tensor2bgr(
                    blend_seg_pred(target_img_tensor, target_seg_pred_tensor))
                aligned_img_no_background_bgr = tensor2bgr(
                    aligned_img_no_background_tensor)
                completion_bgr = tensor2bgr(completion_tensor)
                transfer_bgr = tensor2bgr(transfer_tensor)
                target_cropped_bgr = tensor2bgr(target_img_tensor)

                pose_axis_bgr = draw_axis(np.zeros_like(target_cropped_bgr),
                                          curr_target_euler[0],
                                          curr_target_euler[1],
                                          curr_target_euler[2])
                render_img1 = np.concatenate(
                    (reenactment_img_bgr, reenactment_seg_bgr, target_seg_bgr),
                    axis=1)
                render_img2 = np.concatenate((aligned_img_no_background_bgr,
                                              completion_bgr, transfer_bgr),
                                             axis=1)
                render_img3 = np.concatenate(
                    (pose_axis_bgr, blend_img, target_cropped_bgr), axis=1)
                render_img = np.concatenate(
                    (render_img1, render_img2, render_img3), axis=0)
            elif verbose == 3:
                source_cropped_bgr = tensor2bgr(
                    curr_source_tensor[0].unsqueeze(0))
                target_cropped_bgr = tensor2bgr(target_img_tensor)
                render_img = np.concatenate(
                    (source_cropped_bgr, target_cropped_bgr, blend_img),
                    axis=1)
            cv2.imwrite(curr_output_path, render_img)
            if display:
                cv2.imshow('render_img', render_img)
                if cv2.waitKey(1) & 0xFF == ord('q'):
                    break
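
The swap pipeline above composites the inpainted face onto the target frame with transfer_mask, whose implementation is not included in this listing. As a rough sketch only (an assumption about its behavior, not the actual fsgan helper), it can be read as a per-pixel masked select:

import torch

# Hypothetical transfer_mask-style helper: keep src_img inside the mask, dst_img elsewhere.
# src_img, dst_img: (B, 3, H, W) float tensors; mask: (B, H, W) boolean tensor.
def transfer_mask_sketch(src_img, dst_img, mask):
    return torch.where(mask.unsqueeze(1), src_img, dst_img)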
Example #3
    def __init__(
        self,
        resolution=d('resolution'),
        crop_scale=d('crop_scale'),
        gpus=d('gpus'),
        cpu_only=d('cpu_only'),
        display=d('display'),
        verbose=d('verbose'),
        # Detection arguments:
        detection_model=d('detection_model'),
        det_batch_size=d('det_batch_size'),
        det_postfix=d('det_postfix'),
        # Sequence arguments:
        iou_thresh=d('iou_thresh'),
        min_length=d('min_length'),
        min_size=d('min_size'),
        center_kernel=d('center_kernel'),
        size_kernel=d('size_kernel'),
        smooth_det=d('smooth_det'),
        seq_postfix=d('seq_postfix'),
        write_empty=d('write_empty'),
        # Pose arguments:
        pose_model=d('pose_model'),
        pose_batch_size=d('pose_batch_size'),
        pose_postfix=d('pose_postfix'),
        cache_pose=d('cache_pose'),
        cache_frontal=d('cache_frontal'),
        smooth_poses=d('smooth_poses'),
        # Landmarks arguments:
        lms_model=d('lms_model'),
        lms_batch_size=d('lms_batch_size'),
        landmarks_postfix=d('landmarks_postfix'),
        cache_landmarks=d('cache_landmarks'),
        smooth_landmarks=d('smooth_landmarks'),
        # Segmentation arguments:
        seg_model=d('seg_model'),
        seg_batch_size=d('seg_batch_size'),
        segmentation_postfix=d('segmentation_postfix'),
        cache_segmentation=d('cache_segmentation'),
        smooth_segmentation=d('smooth_segmentation'),
        seg_remove_mouth=d('seg_remove_mouth')):
        # General
        self.resolution = resolution
        self.crop_scale = crop_scale
        self.display = display
        self.verbose = verbose

        # Detection
        self.face_detector = FaceDetector(det_postfix, detection_model, gpus,
                                          det_batch_size, display)
        self.det_postfix = det_postfix

        # Sequences
        self.iou_thresh = iou_thresh
        self.min_length = min_length
        self.min_size = min_size
        self.center_kernel = center_kernel
        self.size_kernel = size_kernel
        self.smooth_det = smooth_det
        self.seq_postfix = seq_postfix
        self.write_empty = write_empty

        # Pose
        self.pose_batch_size = pose_batch_size
        self.pose_postfix = pose_postfix
        self.cache_pose = cache_pose
        self.cache_frontal = cache_frontal
        self.smooth_poses = smooth_poses

        # Landmarks
        self.smooth_landmarks = smooth_landmarks
        self.landmarks_postfix = landmarks_postfix
        self.cache_landmarks = cache_landmarks
        self.lms_batch_size = lms_batch_size

        # Segmentation
        self.smooth_segmentation = smooth_segmentation
        self.segmentation_postfix = segmentation_postfix
        self.cache_segmentation = cache_segmentation
        self.seg_batch_size = seg_batch_size
        self.seg_remove_mouth = seg_remove_mouth and cache_landmarks

        # Initialize device
        torch.set_grad_enabled(False)
        self.device, self.gpus = set_device(gpus, not cpu_only)

        # Load models
        self.face_pose = load_model(pose_model, 'face pose',
                                    self.device) if cache_pose else None
        self.L = load_model(lms_model, 'face landmarks',
                            self.device) if cache_landmarks else None
        self.S = load_model(seg_model, 'face segmentation',
                            self.device) if cache_segmentation else None

        # Initialize heatmap encoder
        self.heatmap_encoder = LandmarksHeatMapEncoder().to(self.device)

        # Initialize normalization tensors
        # Note: this is necessary because of the landmarks model
        self.img_mean = torch.as_tensor([0.5, 0.5, 0.5],
                                        device=self.device).view(1, 3, 1, 1)
        self.img_std = torch.as_tensor([0.5, 0.5, 0.5],
                                       device=self.device).view(1, 3, 1, 1)
        self.context_mean = torch.as_tensor([0.485, 0.456, 0.406],
                                            device=self.device).view(
                                                1, 3, 1, 1)
        self.context_std = torch.as_tensor([0.229, 0.224, 0.225],
                                           device=self.device).view(
                                               1, 3, 1, 1)

        # Support multiple GPUs
        if self.gpus and len(self.gpus) > 1:
            self.face_pose = nn.DataParallel(
                self.face_pose,
                self.gpus) if self.face_pose is not None else None
            self.L = nn.DataParallel(self.L,
                                     self.gpus) if self.L is not None else None
            self.S = nn.DataParallel(self.S,
                                     self.gpus) if self.S is not None else None

        # Initialize temporal smoothing
        if smooth_segmentation > 0:
            self.smooth_seg = TemporalSmoothing(3, smooth_segmentation).to(
                self.device)
        else:
            self.smooth_seg = None

        # Initialize output videos format
        self.fourcc = cv2.VideoWriter_fourcc(*'avc1')
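
The normalization tensors above are shaped (1, 3, 1, 1) so they broadcast over a whole batch of images; a small self-contained sketch of how they are applied (dummy data, same convention as the training scripts):

import torch

# Per-channel normalization via broadcasting, as with img_mean / img_std above.
img = torch.rand(4, 3, 256, 256)                         # dummy batch in [0, 1]
mean = torch.as_tensor([0.5, 0.5, 0.5]).view(1, 3, 1, 1)
std = torch.as_tensor([0.5, 0.5, 0.5]).view(1, 3, 1, 1)
img_norm = img.sub(mean).div(std)                        # roughly in [-1, 1]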
Example #4
def main(
        # General arguments
        exp_dir,
        resume_dir=None,
        start_epoch=None,
        epochs=(90, ),
        iterations=None,
        resolutions=(128, 256),
        learning_rate=(1e-1, ),
        gpus=None,
        workers=4,
        batch_size=(64, ),
        seed=None,
        log_freq=20,

        # Data arguments
        train_dataset='fsgan.image_seg_dataset.ImageSegDataset',
        val_dataset=None,
        numpy_transforms=None,
        tensor_transforms=(
            'img_landmarks_transforms.ToTensor()',
            'transforms.Normalize(mean=[0.5,0.5,0.5],std=[0.5,0.5,0.5])'),

        # Training arguments
        optimizer='optim.SGD(momentum=0.9,weight_decay=1e-4)',
        scheduler='lr_scheduler.StepLR(step_size=30,gamma=0.1)',
        criterion='nn.CrossEntropyLoss',
        model='fsgan.models.simple_unet.UNet(n_classes=3,feature_scale=1)',
        pretrained=False,
        benchmark='fsgan.train_segmentation.IOUBenchmark(3)'):
    def process_epoch(dataset_loader, train=True):
        stage = 'TRAINING' if train else 'VALIDATION'
        total_iter = len(dataset_loader) * dataset_loader.batch_size * epoch
        pbar = tqdm(dataset_loader, unit='batches')

        # Set networks training mode
        model.train(train)

        # Reset logger
        logger.reset(prefix='{} {}X{}: Epoch: {} / {}; LR: {:.0e}; '.format(
            stage, res, res, epoch + 1, res_epochs,
            scheduler.get_lr()[0]))

        # For each batch in the training data
        for i, (input, target) in enumerate(pbar):
            # Prepare input
            input = input.to(device)
            target = target.to(device)
            with torch.no_grad():
                target = target.argmax(dim=1)

            # Execute model
            pred = model(input)

            # Calculate loss
            loss_total = criterion(pred, target)

            # Run benchmark
            benchmark_res = benchmark(pred,
                                      target) if benchmark is not None else {}

            if train:
                # Update generator weights
                optimizer.zero_grad()
                loss_total.backward()
                optimizer.step()

            logger.update('losses', total=loss_total)
            logger.update('bench', **benchmark_res)
            total_iter += dataset_loader.batch_size

            # Batch logs
            pbar.set_description(str(logger))
            if train and i % log_freq == 0:
                logger.log_scalars_val('%dx%d/batch' % (res, res), total_iter)

        # Epoch logs
        logger.log_scalars_avg(
            '%dx%d/epoch/%s' % (res, res, 'train' if train else 'val'), epoch)
        if not train:
            # Log images
            seg_pred = blend_seg_pred(input, pred)
            seg_gt = blend_seg_label(input, target)
            grid = img_utils.make_grid(input, seg_pred, seg_gt)
            logger.log_image('%dx%d/vis' % (res, res), grid, epoch)

        return logger.log_dict['losses']['total'].avg

    #################
    # Main pipeline #
    #################

    # Validation
    resolutions = resolutions if isinstance(resolutions,
                                            (list, tuple)) else [resolutions]
    learning_rate = learning_rate if isinstance(learning_rate,
                                                (list,
                                                 tuple)) else [learning_rate]
    epochs = epochs if isinstance(epochs, (list, tuple)) else [epochs]
    batch_size = batch_size if isinstance(batch_size,
                                          (list, tuple)) else [batch_size]
    iterations = iterations if iterations is None or isinstance(
        iterations, (list, tuple)) else [iterations]

    learning_rate = learning_rate * len(resolutions) if len(
        learning_rate) == 1 else learning_rate
    epochs = epochs * len(resolutions) if len(epochs) == 1 else epochs
    batch_size = batch_size * len(resolutions) if len(
        batch_size) == 1 else batch_size
    if iterations is not None:
        iterations = iterations * len(resolutions) if len(
            iterations) == 1 else iterations
        iterations = utils.str2int(iterations)

    if not os.path.isdir(exp_dir):
        raise RuntimeError('Experiment directory was not found: \'' + exp_dir +
                           '\'')
    assert len(learning_rate) == len(resolutions)
    assert len(epochs) == len(resolutions)
    assert len(batch_size) == len(resolutions)
    assert iterations is None or len(iterations) == len(resolutions)

    # Seed
    utils.set_seed(seed)

    # Check CUDA device availability
    device, gpus = utils.set_device(gpus)

    # Initialize loggers
    logger = TensorBoardLogger(log_dir=exp_dir)

    # Initialize datasets
    numpy_transforms = obj_factory(
        numpy_transforms) if numpy_transforms is not None else []
    tensor_transforms = obj_factory(
        tensor_transforms) if tensor_transforms is not None else []
    img_transforms = img_landmarks_transforms.Compose(numpy_transforms +
                                                      tensor_transforms)

    train_dataset = obj_factory(train_dataset, transform=img_transforms)
    if val_dataset is not None:
        val_dataset = obj_factory(val_dataset, transform=img_transforms)

    # Create networks
    arch = utils.get_arch(model, num_classes=len(train_dataset.classes))
    model = obj_factory(model,
                        num_classes=len(train_dataset.classes)).to(device)

    # Resume from a checkpoint or initialize the networks weights randomly
    checkpoint_dir = exp_dir if resume_dir is None else resume_dir
    model_path = os.path.join(checkpoint_dir, 'model_latest.pth')
    best_loss = 1e6
    curr_res = resolutions[0]
    optimizer_state = None
    if os.path.isfile(model_path):
        print("=> loading checkpoint from '{}'".format(checkpoint_dir))
        # model
        checkpoint = torch.load(model_path)
        if 'resolution' in checkpoint:
            curr_res = checkpoint['resolution']
            start_epoch = checkpoint[
                'epoch'] if start_epoch is None else start_epoch
        # else:
        #     curr_res = resolutions[1] if len(resolutions) > 1 else resolutions[0]
        best_loss_key = 'best_loss_%d' % curr_res
        best_loss = checkpoint[
            best_loss_key] if best_loss_key in checkpoint else best_loss
        model.apply(utils.init_weights)
        model.load_state_dict(checkpoint['state_dict'], strict=False)
        optimizer_state = checkpoint['optimizer']
    else:
        print("=> no checkpoint found at '{}'".format(checkpoint_dir))
        if not pretrained:
            print("=> randomly initializing networks...")
            model.apply(utils.init_weights)

    # Losses
    criterion = obj_factory(criterion).to(device)

    # Benchmark
    benchmark = obj_factory(benchmark).to(device)

    # Support multiple GPUs
    if gpus and len(gpus) > 1:
        model = nn.DataParallel(model, gpus)

    # For each resolution
    start_res_ind = int(np.log2(curr_res)) - int(np.log2(resolutions[0]))
    start_epoch = 0 if start_epoch is None else start_epoch
    for ri in range(start_res_ind, len(resolutions)):
        res = resolutions[ri]
        res_lr = learning_rate[ri]
        res_epochs = epochs[ri]
        res_iterations = iterations[ri] if iterations is not None else None
        res_batch_size = batch_size[ri]

        # Optimizer and scheduler
        optimizer = obj_factory(optimizer, model.parameters(), lr=res_lr)
        scheduler = obj_factory(scheduler, optimizer)
        if optimizer_state is not None:
            optimizer.load_state_dict(optimizer_state)

        # Initialize data loaders
        if res_iterations is None:
            train_sampler = tutils.data.sampler.WeightedRandomSampler(
                train_dataset.weights, len(train_dataset))
        else:
            train_sampler = tutils.data.sampler.WeightedRandomSampler(
                train_dataset.weights, res_iterations)
        train_loader = tutils.data.DataLoader(train_dataset,
                                              batch_size=res_batch_size,
                                              sampler=train_sampler,
                                              num_workers=workers,
                                              pin_memory=True,
                                              drop_last=True,
                                              shuffle=False)
        if val_dataset is not None:
            if res_iterations is None:
                val_sampler = tutils.data.sampler.WeightedRandomSampler(
                    val_dataset.weights, len(val_dataset))
            else:
                val_iterations = (res_iterations *
                                  len(val_dataset)) // len(train_dataset)
                val_sampler = tutils.data.sampler.WeightedRandomSampler(
                    val_dataset.weights, val_iterations)
            val_loader = tutils.data.DataLoader(val_dataset,
                                                batch_size=res_batch_size,
                                                sampler=val_sampler,
                                                num_workers=workers,
                                                pin_memory=True,
                                                drop_last=True,
                                                shuffle=False)
        else:
            val_loader = None

        # For each epoch
        for epoch in range(start_epoch, res_epochs):
            total_loss = process_epoch(train_loader, train=True)
            if val_loader is not None:
                with torch.no_grad():
                    total_loss = process_epoch(val_loader, train=False)
            if hasattr(benchmark, 'reset'):
                benchmark.reset()

            # Schedulers step (in PyTorch 1.1.0+ it must follow after the epoch training and validation steps)
            if isinstance(scheduler,
                          torch.optim.lr_scheduler.ReduceLROnPlateau):
                scheduler.step(total_loss)
            else:
                scheduler.step()

            # Save models checkpoints
            is_best = total_loss < best_loss
            best_loss = min(best_loss, total_loss)
            utils.save_checkpoint(
                exp_dir, 'model', {
                    'resolution':
                    res,
                    'epoch':
                    epoch + 1,
                    'state_dict':
                    model.module.state_dict()
                    if gpus and len(gpus) > 1 else model.state_dict(),
                    'optimizer':
                    optimizer.state_dict(),
                    'best_loss_%d' % res:
                    best_loss,
                    'arch':
                    arch,
                }, is_best)

        # Reset start epoch to 0 because it should only affect the first training resolution
        start_epoch = 0
        best_loss = 1e6
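
Both training scripts above draw a fixed number of samples per epoch with a WeightedRandomSampler. A minimal standalone version of that loader setup, using a dummy dataset and uniform weights instead of the fsgan datasets:

import torch
from torch.utils import data

# Dummy dataset with per-sample weights, mimicking the sampler/loader setup above.
dataset = data.TensorDataset(torch.randn(100, 3, 8, 8), torch.randint(0, 3, (100,)))
weights = torch.ones(len(dataset))                 # uniform weights, for the sketch only
sampler = data.WeightedRandomSampler(weights, num_samples=len(dataset))
loader = data.DataLoader(dataset, batch_size=16, sampler=sampler,
                         num_workers=0, drop_last=True, shuffle=False)
for batch, labels in loader:
    pass  # one epoch of weighted-sampled batches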
Example #5
def main(input_path,
         output_path=None,
         seq_postfix='_dsfd_seq.pkl',
         output_postfix='_dsfd_seq_lms_euler.pkl',
         pose_model_path='weights/hopenet_robust_alpha1.pkl',
         smooth_det=False,
         smooth_euler=False,
         gpus=None,
         cpu_only=False,
         batch_size=16):
    cache_path = os.path.splitext(input_path)[0] + seq_postfix
    output_path = os.path.splitext(
        input_path)[0] + output_postfix if output_path is None else output_path

    # Initialize device
    torch.set_grad_enabled(False)
    device, gpus = set_device(gpus, not cpu_only)

    # Load sequences from file
    with open(cache_path, "rb") as fp:  # Unpickling
        seq_list = pickle.load(fp)

    # Load pose model
    face_pose = Hopenet().to(device)
    checkpoint = torch.load(pose_model_path)
    face_pose.load_state_dict(checkpoint)
    face_pose.train(False)

    # Open input video file
    cap = cv2.VideoCapture(input_path)
    if not cap.isOpened():
        raise RuntimeError('Failed to read video: ' + input_path)
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    fps = cap.get(cv2.CAP_PROP_FPS)
    input_vid_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    input_vid_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

    # Smooth sequence bounding boxes
    if smooth_det:
        for seq in seq_list:
            seq.smooth()

    # For each sequence
    total_detections = sum([len(s) for s in seq_list])
    pbar = tqdm(range(total_detections), unit='detections')
    for seq in seq_list:
        euler = []
        frame_cropped_tensor_list = []
        cap.set(cv2.CAP_PROP_POS_FRAMES, seq.start_index)

        # For each detection bounding box in the current sequence
        for i, det in enumerate(seq.detections):
            ret, frame_bgr = cap.read()
            if frame_bgr is None:
                raise RuntimeError('Failed to read frame from video!')
            frame_rgb = frame_bgr[:, :, ::-1]

            # Crop frame
            bbox = np.concatenate((det[:2], det[2:] - det[:2]))
            bbox = scale_bbox(bbox, 1.2)
            frame_cropped_rgb = crop_img(frame_rgb, bbox)
            frame_cropped_rgb = cv2.resize(frame_cropped_rgb, (224, 224),
                                           interpolation=cv2.INTER_CUBIC)
            frame_cropped_tensor = rgb2tensor(frame_cropped_rgb).to(device)

            # Gather batches
            frame_cropped_tensor_list.append(frame_cropped_tensor)
            if len(frame_cropped_tensor_list) < batch_size and (i +
                                                                1) < len(seq):
                continue
            frame_cropped_tensor_batch = torch.cat(frame_cropped_tensor_list,
                                                   dim=0)

            # Calculate euler angles
            curr_euler_batch = face_pose(
                frame_cropped_tensor_batch)  # Yaw, Pitch, Roll
            curr_euler_batch = curr_euler_batch.cpu().numpy()

            # For each prediction in the batch
            for b, curr_euler in enumerate(curr_euler_batch):
                # Add euler to list
                euler.append(curr_euler)

                # Render
                # render_img = tensor2bgr(frame_cropped_tensor_batch[b]).copy()
                # cv2.putText(render_img, '(%.2f, %.2f, %.2f)' % (curr_euler[0], curr_euler[1], curr_euler[2]), (15, 15),
                #             cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 255), 1, cv2.LINE_AA)
                # cv2.imshow('render_img', render_img)
                # if cv2.waitKey(0) & 0xFF == ord('q'):
                #     break

            # Clear lists
            frame_cropped_tensor_list.clear()

            pbar.update(len(frame_cropped_tensor_batch))

        # Add euler angles to the sequence and optionally smooth them
        euler = np.array(euler)
        if smooth_euler:
            euler = smooth(euler)
        seq.euler = euler

    # Write final sequence list to file
    with open(output_path, "wb") as fp:  # Pickling
        pickle.dump(seq_list, fp)
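
The pose loop above relies on rgb2tensor to turn each cropped frame into a network input; that helper is not shown here, so the following is only a plausible sketch (assuming the same [-1, 1] normalization used elsewhere in these scripts):

import numpy as np
import torch

# Hypothetical rgb2tensor-style helper: HWC uint8 RGB image -> 1x3xHxW float tensor in [-1, 1].
def rgb2tensor_sketch(img_rgb):
    t = torch.from_numpy(img_rgb.astype(np.float32) / 255.0)  # HWC in [0, 1]
    t = t.permute(2, 0, 1).unsqueeze(0)                       # 1x3xHxW
    return t.sub_(0.5).div_(0.5)                              # [-1, 1]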
Example #6
def main(
        source_path,
        target_path,
        arch='res_unet_split.MultiScaleResUNet(in_nc=71,out_nc=(3,3),flat_layers=(2,0,2,3),ngf=128)',
        model_path='../weights/ijbc_msrunet_256_2_0_reenactment_v1.pth',
        pose_model_path='../weights/hopenet_robust_alpha1.pth',
        pil_transforms1=('landmark_transforms.FaceAlignCrop',
                         'landmark_transforms.Resize(256)',
                         'landmark_transforms.Pyramids(2)'),
        pil_transforms2=('landmark_transforms.FaceAlignCrop',
                         'landmark_transforms.Resize(256)',
                         'landmark_transforms.Pyramids(2)',
                         'landmark_transforms.LandmarksToHeatmaps'),
        tensor_transforms1=(
            'landmark_transforms.ToTensor()',
            'transforms.Normalize(mean=[0.5,0.5,0.5],std=[0.5,0.5,0.5])'),
        tensor_transforms2=(
            'landmark_transforms.ToTensor()',
            'transforms.Normalize(mean=[0.5,0.5,0.5],std=[0.5,0.5,0.5])'),
        output_path=None,
        crop_size=256,
        display=False):
    torch.set_grad_enabled(False)

    # Initialize models
    fa = face_alignment.FaceAlignment(face_alignment.LandmarksType._2D,
                                      flip_input=False)
    device, gpus = utils.set_device()
    G = obj_factory(arch).to(device)
    checkpoint = torch.load(model_path)
    G.load_state_dict(checkpoint['state_dict'])
    G.train(False)

    # Initialize pose
    Gp = Hopenet().to(device)
    checkpoint = torch.load(pose_model_path)
    Gp.load_state_dict(checkpoint['state_dict'])
    Gp.train(False)

    # Initialize transformations
    pil_transforms1 = obj_factory(
        pil_transforms1) if pil_transforms1 is not None else []
    pil_transforms2 = obj_factory(
        pil_transforms2) if pil_transforms2 is not None else []
    tensor_transforms1 = obj_factory(
        tensor_transforms1) if tensor_transforms1 is not None else []
    tensor_transforms2 = obj_factory(
        tensor_transforms2) if tensor_transforms2 is not None else []
    img_transforms1 = landmark_transforms.ComposePyramids(pil_transforms1 +
                                                          tensor_transforms1)
    img_transforms2 = landmark_transforms.ComposePyramids(pil_transforms2 +
                                                          tensor_transforms2)

    # Process source image
    source_bgr = cv2.imread(source_path)
    source_rgb = source_bgr[:, :, ::-1]
    source_landmarks, source_bbox = process_image(fa, source_rgb, crop_size)
    if source_bbox is None:
        raise RuntimeError("Couldn't detect a face in source image: " +
                           source_path)
    source_tensor, source_landmarks, source_bbox = img_transforms1(
        source_rgb, source_landmarks, source_bbox)
    source_cropped_bgr = tensor2bgr(
        source_tensor[0] if isinstance(source_tensor, list) else source_tensor)
    for i in range(len(source_tensor)):
        source_tensor[i] = source_tensor[i].to(device)

    # Extract landmarks and bounding boxes from target video
    frame_indices, landmarks, bboxes, eulers = extract_landmarks_bboxes_euler_from_video(
        target_path, Gp, device=device)
    if frame_indices.size == 0:
        raise RuntimeError('No faces were detected in the target video: ' +
                           target_path)

    # Open target video file
    cap = cv2.VideoCapture(target_path)
    if not cap.isOpened():
        raise RuntimeError('Failed to read video: ' + target_path)
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    fps = cap.get(cv2.CAP_PROP_FPS)

    # Initialize output video file
    if output_path is not None:
        if os.path.isdir(output_path):
            output_filename = os.path.splitext(os.path.basename(source_path))[0] + '_' + \
                              os.path.splitext(os.path.basename(target_path))[0] + '.mp4'
            output_path = os.path.join(output_path, output_filename)
            print(output_path)
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        out_vid = cv2.VideoWriter(
            output_path, fourcc, fps,
            (source_cropped_bgr.shape[1] * 3, source_cropped_bgr.shape[0]))
    else:
        out_vid = None

    # For each frame in the target video
    valid_frame_ind = 0
    for i in tqdm(range(total_frames)):
        ret, frame = cap.read()
        if not ret or frame is None:
            continue
        if i not in frame_indices:
            continue
        frame_rgb = frame[:, :, ::-1]
        frame_tensor, frame_landmarks, frame_bbox = img_transforms2(
            frame_rgb, landmarks[valid_frame_ind], bboxes[valid_frame_ind])
        valid_frame_ind += 1

        # frame_cropped_rgb, frame_landmarks = process_cached_frame(frame_rgb, landmarks[valid_frame_ind],
        #                                                           bboxes[valid_frame_ind], size)
        # frame_cropped_bgr = frame_cropped_rgb[:, :, ::-1].copy()
        # valid_frame_ind += 1

        # frame_tensor, frame_landmarks_tensor = prepare_generator_input(frame_cropped_rgb, frame_landmarks)
        # frame_landmarks_tensor.to(device)
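        # Concatenate the source image with the target frame's landmark channels at each
        # pyramid level to build the generator input (a batch of size 1 per level).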
        input_tensor = []
        for j in range(len(source_tensor)):
            frame_landmarks[j] = frame_landmarks[j].to(device)
            input_tensor.append(
                torch.cat((source_tensor[j], frame_landmarks[j]),
                          dim=0).unsqueeze(0).to(device))
        out_img_tensor, out_seg_tensor = G(input_tensor)

        # Transfer image1 mask to image2
        # face_mask_tensor = out_seg_tensor.argmax(1) == 1  # face
        # face_mask_tensor = out_seg_tensor.argmax(1) == 2    # hair
        # face_mask_tensor = out_seg_tensor.argmax(1) >= 1  # head

        # target_img_tensor = frame_tensor[0].view(1, frame_tensor[0].shape[0],
        #                                          frame_tensor[0].shape[1], frame_tensor[0].shape[2]).to(device)

        # Convert back to numpy images
        out_img_bgr = tensor2bgr(out_img_tensor)
        frame_cropped_bgr = tensor2bgr(frame_tensor[0])

        # Render
        # for point in np.round(frame_landmarks).astype(int):
        #     cv2.circle(frame_cropped_bgr, (point[0], point[1]), 2, (0, 0, 255), -1)
        render_img = np.concatenate(
            (source_cropped_bgr, out_img_bgr, frame_cropped_bgr), axis=1)
        if out_vid is not None:
            out_vid.write(render_img)
        if out_vid is None or display:
            cv2.imshow('render_img', render_img)
            if cv2.waitKey(1) & 0xFF == ord('q'):
                break
Example #7
0
def main(
        source_path,
        target_path,
        arch='res_unet_split.MultiScaleResUNet(in_nc=71,out_nc=(3,3),flat_layers=(2,0,2,3),ngf=128)',
        model_path='../weights/ijbc_msrunet_256_1_2_reenactment_stepwise_v1.pth',
        pose_model_path='../weights/hopenet_robust_alpha1.pth',
        pil_transforms1=('landmark_transforms.FaceAlignCrop(bbox_scale=1.2)',
                         'landmark_transforms.Resize(256)',
                         'landmark_transforms.Pyramids(2)'),
        pil_transforms2=('landmark_transforms.FaceAlignCrop(bbox_scale=1.2)',
                         'landmark_transforms.Resize(256)',
                         'landmark_transforms.Pyramids(2)'),
        tensor_transforms1=(
            'landmark_transforms.ToTensor()',
            'transforms.Normalize(mean=[0.5,0.5,0.5],std=[0.5,0.5,0.5])'),
        tensor_transforms2=(
            'landmark_transforms.ToTensor()',
            'transforms.Normalize(mean=[0.5,0.5,0.5],std=[0.5,0.5,0.5])'),
        output_path=None,
        crop_size=256,
        display=False):
    torch.set_grad_enabled(False)

    # Initialize models
    fa = face_alignment.FaceAlignment(face_alignment.LandmarksType._3D,
                                      flip_input=True)
    device, gpus = utils.set_device()
    G = obj_factory(arch).to(device)
    checkpoint = torch.load(model_path)
    G.load_state_dict(checkpoint['state_dict'])
    G.train(False)

    # Initialize pose
    Gp = Hopenet().to(device)
    checkpoint = torch.load(pose_model_path)
    Gp.load_state_dict(checkpoint['state_dict'])
    Gp.train(False)

    # Initialize landmarks to heatmaps
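    # One landmark-to-heatmap module per pyramid level (256x256 and 128x128)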
    landmarks2heatmaps = [
        LandmarkHeatmap(kernel_size=13, size=(256, 256)).to(device),
        LandmarkHeatmap(kernel_size=7, size=(128, 128)).to(device)
    ]

    # Initialize transformations
    pil_transforms1 = obj_factory(
        pil_transforms1) if pil_transforms1 is not None else []
    pil_transforms2 = obj_factory(
        pil_transforms2) if pil_transforms2 is not None else []
    tensor_transforms1 = obj_factory(
        tensor_transforms1) if tensor_transforms1 is not None else []
    tensor_transforms2 = obj_factory(
        tensor_transforms2) if tensor_transforms2 is not None else []
    img_transforms1 = landmark_transforms.ComposePyramids(pil_transforms1 +
                                                          tensor_transforms1)
    img_transforms2 = landmark_transforms.ComposePyramids(pil_transforms2 +
                                                          tensor_transforms2)

    # Process source image
    source_bgr = cv2.imread(source_path)
    source_rgb = source_bgr[:, :, ::-1]
    source_landmarks, source_bbox = process_image(fa, source_rgb, crop_size)
    if source_bbox is None:
        raise RuntimeError("Couldn't detect a face in source image: " +
                           source_path)
    source_tensor, source_landmarks, source_bbox = img_transforms1(
        source_rgb, source_landmarks, source_bbox)
    source_cropped_bgr = tensor2bgr(
        source_tensor[0] if isinstance(source_tensor, list) else source_tensor)
    for i in range(len(source_tensor)):
        source_tensor[i] = source_tensor[i].unsqueeze(0).to(device)

    # Extract landmarks, bounding boxes, euler angles, and 3D landmarks from target video
    frame_indices, landmarks, bboxes, eulers, landmarks_3d = \
        extract_landmarks_bboxes_euler_3d_from_video(target_path, Gp, fa, device=device)
    if frame_indices.size == 0:
        raise RuntimeError('No faces were detected in the target video: ' +
                           target_path)

    # Open target video file
    cap = cv2.VideoCapture(target_path)
    if not cap.isOpened():
        raise RuntimeError('Failed to read target video: ' + target_path)
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    fps = cap.get(cv2.CAP_PROP_FPS)

    # Initialize output video file
    if output_path is not None:
        if os.path.isdir(output_path):
            output_filename = os.path.splitext(os.path.basename(source_path))[0] + '_' + \
                              os.path.splitext(os.path.basename(target_path))[0] + '.mp4'
            output_path = os.path.join(output_path, output_filename)
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        out_vid = cv2.VideoWriter(
            output_path, fourcc, fps,
            (source_cropped_bgr.shape[1] * 3, source_cropped_bgr.shape[0]))
    else:
        out_vid = None

    # For each frame in the target video
    valid_frame_ind = 0
    for i in tqdm(range(total_frames)):
        ret, target_bgr = cap.read()
        if not ret or target_bgr is None:
            continue
        if i not in frame_indices:
            continue
        target_rgb = target_bgr[:, :, ::-1]
        target_tensor, target_landmarks, target_bbox = img_transforms2(
            target_rgb, landmarks_3d[valid_frame_ind], bboxes[valid_frame_ind])
        target_euler = eulers[valid_frame_ind]
        valid_frame_ind += 1

        # TODO: Calculate the number of required reenactment iterations
        reenactment_iterations = 2

        # Generate landmarks sequence
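        # Intermediate landmark sets are interpolated between the source and target landmarks
        # so that each reenactment pass only has to bridge part of the pose difference.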
        target_landmarks_sequence = []
        for ri in range(1, reenactment_iterations):
            interp_landmarks = []
            for j in range(len(source_tensor)):
                alpha = float(ri) / reenactment_iterations
                curr_interp_landmarks_np = interpolate_points(
                    source_landmarks[j].cpu().numpy(),
                    target_landmarks[j].cpu().numpy(),
                    alpha=alpha)
                interp_landmarks.append(
                    torch.from_numpy(curr_interp_landmarks_np))
            target_landmarks_sequence.append(interp_landmarks)
        target_landmarks_sequence.append(target_landmarks)

        # Iterative reenactment
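        # The output of each pass is re-pyramided and fed back as the input image for the next pass.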
        out_img_tensor = source_tensor
        for curr_target_landmarks in target_landmarks_sequence:
            out_img_tensor = create_pyramid(out_img_tensor, 2)
            input_tensor = []
            for j in range(len(out_img_tensor)):
                curr_target_landmarks[j] = curr_target_landmarks[j].unsqueeze(
                    0).to(device)
                curr_target_landmarks[j] = landmarks2heatmaps[j](
                    curr_target_landmarks[j])
                input_tensor.append(
                    torch.cat((out_img_tensor[j], curr_target_landmarks[j]),
                              dim=1))
            out_img_tensor, out_seg_tensor = G(input_tensor)

        # Convert back to numpy images
        out_img_bgr = tensor2bgr(out_img_tensor)
        frame_cropped_bgr = tensor2bgr(target_tensor[0])

        # Render
        # for point in np.round(frame_landmarks).astype(int):
        #     cv2.circle(frame_cropped_bgr, (point[0], point[1]), 2, (0, 0, 255), -1)
        render_img = np.concatenate(
            (source_cropped_bgr, out_img_bgr, frame_cropped_bgr), axis=1)
        if out_vid is not None:
            out_vid.write(render_img)
        if out_vid is None or display:
            cv2.imshow('render_img', render_img)
            if cv2.waitKey(1) & 0xFF == ord('q'):
                break
def main(
    # General arguments
    exp_dir, resume_dir=None, start_epoch=None, epochs=(90,), iterations=None, resolutions=(128, 256),
    lr_gen=(1e-4,), lr_dis=(1e-4,), gpus=None, workers=4, batch_size=(64,), seed=None, log_freq=20,

    # Data arguments
    train_dataset='opencv_video_seq_dataset.VideoSeqDataset', val_dataset=None, numpy_transforms=None,
    tensor_transforms=('img_lms_pose_transforms.ToTensor()', 'img_lms_pose_transforms.Normalize()'),

    # Training arguments
    optimizer='optim.SGD(momentum=0.9,weight_decay=1e-4)', scheduler='lr_scheduler.StepLR(step_size=30,gamma=0.1)',
    pretrained=False, criterion_pixelwise='nn.L1Loss', criterion_id='vgg_loss.VGGLoss',
    criterion_attr='vgg_loss.VGGLoss', criterion_gan='gan_loss.GANLoss(use_lsgan=True)',
    generator='res_unet.MultiScaleResUNet(in_nc=101,out_nc=3)',
    discriminator='discriminators_pix2pix.MultiscaleDiscriminator',
    rec_weight=1.0, gan_weight=0.001
):
    def proces_epoch(dataset_loader, train=True):
        stage = 'TRAINING' if train else 'VALIDATION'
        total_iter = len(dataset_loader) * dataset_loader.batch_size * epoch
        pbar = tqdm(dataset_loader, unit='batches')

        # Set networks training mode
        G.train(train)
        D.train(train)

        # Reset logger
        logger.reset(prefix='{} {}X{}: Epoch: {} / {}; LR: {:.0e}; '.format(
            stage, res, res, epoch + 1, res_epochs, optimizer_G.param_groups[0]['lr']))

        # For each batch in the training data
        for i, (img, landmarks, target) in enumerate(pbar):
            # Prepare input
            with torch.no_grad():
                # For each view images and landmarks
                landmarks[1] = landmarks[1].to(device)
                for j in range(len(img)):
                    # landmarks[j] = landmarks[j].to(device)

                    # For each pyramid image: push to device
                    for p in range(len(img[j])):
                        img[j][p] = img[j][p].to(device)

                # Remove unnecessary pyramids
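                # (keep only the last ri + 1 pyramid levels needed for the current resolution stage)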
                for j in range(len(img)):
                    img[j] = img[j][-ri - 1:]

                # Concatenate pyramid images with context to derive the final input
                input = []
                for p in range(len(img[0])):
                    context = res_landmarks_decoders[p](landmarks[1])
                    input.append(torch.cat((img[0][p], context), dim=1))

            # Reenactment
            img_pred = G(input)

            # Fake Detection and Loss
            img_pred_pyd = img_utils.create_pyramid(img_pred, len(img[0]))
            pred_fake_pool = D([x.detach() for x in img_pred_pyd])
            loss_D_fake = criterion_gan(pred_fake_pool, False)

            # Real Detection and Loss
            pred_real = D(img[1])
            loss_D_real = criterion_gan(pred_real, True)

            loss_D_total = (loss_D_fake + loss_D_real) * 0.5

            # GAN loss (Fake Passability Loss)
            pred_fake = D(img_pred_pyd)
            loss_G_GAN = criterion_gan(pred_fake, True)

            # Reconstruction loss (fixed blend: 0.1 pixelwise + 0.5 identity + 0.5 attribute)
            loss_pixelwise = criterion_pixelwise(img_pred, img[1][0])
            loss_id = criterion_id(img_pred, img[1][0])
            loss_attr = criterion_attr(img_pred, img[1][0])
            loss_rec = 0.1 * loss_pixelwise + 0.5 * loss_id + 0.5 * loss_attr

            loss_G_total = rec_weight * loss_rec + gan_weight * loss_G_GAN

            if train:
                # Update generator weights
                optimizer_G.zero_grad()
                loss_G_total.backward()
                optimizer_G.step()

                # Update discriminator weights
                optimizer_D.zero_grad()
                loss_D_total.backward()
                optimizer_D.step()

            logger.update('losses', pixelwise=loss_pixelwise, id=loss_id, attr=loss_attr, rec=loss_rec,
                          g_gan=loss_G_GAN, d_gan=loss_D_total)
            total_iter += dataset_loader.batch_size

            # Batch logs
            pbar.set_description(str(logger))
            if train and i % log_freq == 0:
                logger.log_scalars_val('%dx%d/batch' % (res, res), total_iter)

        # Epoch logs
        logger.log_scalars_avg('%dx%d/epoch/%s' % (res, res, 'train' if train else 'val'), epoch)
        if not train:
            # Log images
            grid = img_utils.make_grid(img[0][0], img_pred, img[1][0])
            logger.log_image('%dx%d/vis' % (res, res), grid, epoch)

        return logger.log_dict['losses']['rec'].avg

    #################
    # Main pipeline #
    #################

    # Validate and normalize the input arguments
    resolutions = resolutions if isinstance(resolutions, (list, tuple)) else [resolutions]
    lr_gen = lr_gen if isinstance(lr_gen, (list, tuple)) else [lr_gen]
    lr_dis = lr_dis if isinstance(lr_dis, (list, tuple)) else [lr_dis]
    epochs = epochs if isinstance(epochs, (list, tuple)) else [epochs]
    batch_size = batch_size if isinstance(batch_size, (list, tuple)) else [batch_size]
    iterations = iterations if iterations is None or isinstance(iterations, (list, tuple)) else [iterations]

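    # Broadcast single-element settings across all training resolutions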
    lr_gen = lr_gen * len(resolutions) if len(lr_gen) == 1 else lr_gen
    lr_dis = lr_dis * len(resolutions) if len(lr_dis) == 1 else lr_dis
    epochs = epochs * len(resolutions) if len(epochs) == 1 else epochs
    batch_size = batch_size * len(resolutions) if len(batch_size) == 1 else batch_size
    if iterations is not None:
        iterations = iterations * len(resolutions) if len(iterations) == 1 else iterations
        iterations = utils.str2int(iterations)

    if not os.path.isdir(exp_dir):
        raise RuntimeError('Experiment directory was not found: \'' + exp_dir + '\'')
    assert len(lr_gen) == len(resolutions)
    assert len(lr_dis) == len(resolutions)
    assert len(epochs) == len(resolutions)
    assert len(batch_size) == len(resolutions)
    assert iterations is None or len(iterations) == len(resolutions)

    # Seed
    utils.set_seed(seed)

    # Check CUDA device availability
    device, gpus = utils.set_device(gpus)

    # Initialize loggers
    logger = TensorBoardLogger(log_dir=exp_dir)

    # Initialize datasets
    numpy_transforms = obj_factory(numpy_transforms) if numpy_transforms is not None else []
    tensor_transforms = obj_factory(tensor_transforms) if tensor_transforms is not None else []
    img_transforms = img_lms_pose_transforms.Compose(numpy_transforms + tensor_transforms)

    train_dataset = obj_factory(train_dataset, transform=img_transforms)
    if val_dataset is not None:
        val_dataset = obj_factory(val_dataset, transform=img_transforms)

    # Create networks
    G_arch = utils.get_arch(generator)
    D_arch = utils.get_arch(discriminator)
    G = obj_factory(generator).to(device)
    D = obj_factory(discriminator).to(device)

    # Resume from a checkpoint or initialize the networks weights randomly
    checkpoint_dir = exp_dir if resume_dir is None else resume_dir
    G_path = os.path.join(checkpoint_dir, 'G_latest.pth')
    D_path = os.path.join(checkpoint_dir, 'D_latest.pth')
    best_loss = 1e6
    curr_res = resolutions[0]
    optimizer_G_state, optimizer_D_state = None, None
    if os.path.isfile(G_path) and os.path.isfile(D_path):
        print("=> loading checkpoint from '{}'".format(checkpoint_dir))
        # G
        checkpoint = torch.load(G_path)
        if 'resolution' in checkpoint:
            curr_res = checkpoint['resolution']
            start_epoch = checkpoint['epoch'] if start_epoch is None else start_epoch
        # else:
        #     curr_res = resolutions[1] if len(resolutions) > 1 else resolutions[0]
        best_loss_key = 'best_loss_%d' % curr_res
        best_loss = checkpoint[best_loss_key] if best_loss_key in checkpoint else best_loss
        G.apply(utils.init_weights)
        G.load_state_dict(checkpoint['state_dict'], strict=False)
        optimizer_G_state = checkpoint['optimizer']

        # D
        D.apply(utils.init_weights)
        if os.path.isfile(D_path):
            checkpoint = torch.load(D_path)
            D.load_state_dict(checkpoint['state_dict'], strict=False)
            optimizer_D_state = checkpoint['optimizer']
    else:
        print("=> no checkpoint found at '{}'".format(checkpoint_dir))
        if not pretrained:
            print("=> randomly initializing networks...")
            G.apply(utils.init_weights)
            D.apply(utils.init_weights)

    # Initialize landmarks decoders
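    # (inserted at index 0 so the list ends up ordered from the highest to the lowest resolution)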
    landmarks_decoders = []
    for res in resolutions:
        landmarks_decoders.insert(0, landmarks_utils.LandmarksHeatMapDecoder(res).to(device))

    # Lossess
    criterion_pixelwise = obj_factory(criterion_pixelwise).to(device)
    criterion_id = obj_factory(criterion_id).to(device)
    criterion_attr = obj_factory(criterion_attr).to(device)
    criterion_gan = obj_factory(criterion_gan).to(device)

    # Support multiple GPUs
    if gpus and len(gpus) > 1:
        G = nn.DataParallel(G, gpus)
        D = nn.DataParallel(D, gpus)
        criterion_id.vgg = nn.DataParallel(criterion_id.vgg, gpus)
        criterion_attr.vgg = nn.DataParallel(criterion_attr.vgg, gpus)
        landmarks_decoders = [nn.DataParallel(ld, gpus) for ld in landmarks_decoders]

    # For each resolution
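    # Resolutions are assumed to be consecutive powers of two, so the log2 difference below
    # gives the index of the resolution to resume training from.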
    start_res_ind = int(np.log2(curr_res)) - int(np.log2(resolutions[0]))
    start_epoch = 0 if start_epoch is None else start_epoch
    for ri in range(start_res_ind, len(resolutions)):
        res = resolutions[ri]
        res_lr_gen = lr_gen[ri]
        res_lr_dis = lr_dis[ri]
        res_epochs = epochs[ri]
        res_iterations = iterations[ri] if iterations is not None else None
        res_batch_size = batch_size[ri]
        res_landmarks_decoders = landmarks_decoders[-ri - 1:]

        # Optimizer and scheduler
        optimizer_G = obj_factory(optimizer, G.parameters(), lr=res_lr_gen)
        optimizer_D = obj_factory(optimizer, D.parameters(), lr=res_lr_dis)
        scheduler_G = obj_factory(scheduler, optimizer_G)
        scheduler_D = obj_factory(scheduler, optimizer_D)
        if optimizer_G_state is not None:
            optimizer_G.load_state_dict(optimizer_G_state)
            optimizer_G_state = None
        if optimizer_D_state is not None:
            optimizer_D.load_state_dict(optimizer_D_state)
            optimizer_D_state = None

        # Initialize data loaders
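        # The weighted random sampler draws samples according to the dataset's precomputed weights;
        # when res_iterations is given it also fixes the number of samples drawn per epoch.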
        if res_iterations is None:
            train_sampler = tutils.data.sampler.WeightedRandomSampler(train_dataset.weights, len(train_dataset))
        else:
            train_sampler = tutils.data.sampler.WeightedRandomSampler(train_dataset.weights, res_iterations)
        train_loader = tutils.data.DataLoader(train_dataset, batch_size=res_batch_size, sampler=train_sampler,
                                              num_workers=workers, pin_memory=True, drop_last=True, shuffle=False)
        if val_dataset is not None:
            if res_iterations is None:
                val_sampler = tutils.data.sampler.WeightedRandomSampler(val_dataset.weights, len(val_dataset))
            else:
                val_iterations = (res_iterations * len(val_dataset.classes)) // len(train_dataset.classes)
                val_sampler = tutils.data.sampler.WeightedRandomSampler(val_dataset.weights, val_iterations)
            val_loader = tutils.data.DataLoader(val_dataset, batch_size=res_batch_size, sampler=val_sampler,
                                                num_workers=workers, pin_memory=True, drop_last=True, shuffle=False)
        else:
            val_loader = None

        # For each epoch
        for epoch in range(start_epoch, res_epochs):
            total_loss = proces_epoch(train_loader, train=True)
            if val_loader is not None:
                with torch.no_grad():
                    total_loss = proces_epoch(val_loader, train=False)

            # Schedulers step (in PyTorch 1.1.0+ it must follow after the epoch training and validation steps)
            if isinstance(scheduler_G, torch.optim.lr_scheduler.ReduceLROnPlateau):
                scheduler_G.step(total_loss)
                scheduler_D.step(total_loss)
            else:
                scheduler_G.step()
                scheduler_D.step()

            # Save models checkpoints
            is_best = total_loss < best_loss
            best_loss = min(best_loss, total_loss)
            utils.save_checkpoint(exp_dir, 'G', {
                'resolution': res,
                'epoch': epoch + 1,
                'state_dict': G.module.state_dict() if gpus and len(gpus) > 1 else G.state_dict(),
                'optimizer': optimizer_G.state_dict(),
                'best_loss_%d' % res: best_loss,
                'arch': G_arch,
            }, is_best)
            utils.save_checkpoint(exp_dir, 'D', {
                'resolution': res,
                'epoch': epoch + 1,
                'state_dict': D.module.state_dict() if gpus and len(gpus) > 1 else D.state_dict(),
                'optimizer': optimizer_D.state_dict(),
                'best_loss_%d' % res: best_loss,
                'arch': D_arch,
            }, is_best)

        # Reset start_epoch to 0 because it should only affect the first training resolution
        start_epoch = 0
        best_loss = 1e6