Example #1
    def __init__(self, *args, **kwargs):
        super().__init__()
        self.model = MattingRefine(*args, **kwargs)

        # Hoist the attributes to the top level.
        self.backbone_scale = self.model.backbone_scale
        self.refine_mode = self.model.refiner.mode
        self.refine_sample_pixels = self.model.refiner.sample_pixels
        self.refine_threshold = self.model.refiner.threshold
        self.refine_prevent_oversampling = self.model.refiner.prevent_oversampling
Example #2
class MattingRefine_TorchScriptWrapper(nn.Module):
    """
    The purpose of this wrapper is to hoist all the configurable attributes to the top level.
    So that the user can easily change them after loading the saved TorchScript model.

    Example:
        model = torch.jit.load('torchscript.pth')
        model.backbone_scale = 0.25
        model.refine_mode = 'sampling'
        model.refine_sample_pixels = 80_000
        pha, fgr = model(src, bgr)[:2]
    """

    def __init__(self, *args, **kwargs):
        super().__init__()
        self.model = MattingRefine(*args, **kwargs)

        # Hoist the attributes to the top level.
        self.backbone_scale = self.model.backbone_scale
        self.refine_mode = self.model.refiner.mode
        self.refine_sample_pixels = self.model.refiner.sample_pixels
        self.refine_threshold = self.model.refiner.threshold
        self.refine_prevent_oversampling = self.model.refiner.prevent_oversampling

    def forward(self, src, bgr):
        # Reset the attributes.
        self.model.backbone_scale = self.backbone_scale
        self.model.refiner.mode = self.refine_mode
        self.model.refiner.sample_pixels = self.refine_sample_pixels
        self.model.refiner.threshold = self.refine_threshold
        self.model.refiner.prevent_oversampling = self.refine_prevent_oversampling

        return self.model(src, bgr)

    def load_state_dict(self, *args, **kwargs):
        return self.model.load_state_dict(*args, **kwargs)
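The docstring above shows the loading side; here is a matching export-side sketch, assuming the wrapper class above. The backbone name, constructor defaults, and file name are illustrative:

model = MattingRefine_TorchScriptWrapper('resnet50', backbone_scale=0.25)
scripted = torch.jit.script(model)  # compile the wrapper to TorchScript
scripted.save('torchscript.pth')    # the hoisted attributes remain settable after torch.jit.load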
Example #3
            frame.save(
                os.path.join(self.path,
                             str(index + i).zfill(5) + '.' + self.extension))


# --------------- Main ---------------

device = torch.device(args.device)

# Load model
if args.model_type == 'mattingbase':
    model = MattingBase(args.model_backbone)
if args.model_type == 'mattingrefine':
    model = MattingRefine(args.model_backbone, args.model_backbone_scale,
                          args.model_refine_mode,
                          args.model_refine_sample_pixels,
                          args.model_refine_threshold,
                          args.model_refine_kernel_size)

model = model.to(device).eval()
model.load_state_dict(torch.load(args.model_checkpoint, map_location=device),
                      strict=False)

# Load video and background
vid = VideoDataset(args.video_src)
bgr = [Image.open(args.video_bgr).convert('RGB')]
dataset = ZipDataset([vid, bgr],
                     transforms=A.PairCompose([
                         A.PairApply(T.Resize(args.video_resize[::-1]) if args.video_resize else nn.Identity()),
                         A.PairApply(T.ToTensor())  # assumed closing transform; the source listing is cut off here
                     ]))
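The listing stops at the dataset; a minimal consumption sketch, assuming torch.utils.data.DataLoader and the model loaded above (batch size and output unpacking are illustrative):

from torch.utils.data import DataLoader

dataloader = DataLoader(dataset, batch_size=1, pin_memory=True)
with torch.no_grad():
    for src, bgr in dataloader:
        src, bgr = src.to(device), bgr.to(device)
        pha, fgr = model(src, bgr)[:2]  # alpha matte and foreground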
Example #4
assert (not args.image_src) != (not args.resolution), 'Must provide either a resolution or an image and not both.'


# --------------- Run Loop ---------------


device = torch.device(args.device)

# Load model
if args.model_type == 'mattingbase':
    model = MattingBase(args.model_backbone)
if args.model_type == 'mattingrefine':
    model = MattingRefine(
        args.model_backbone,
        args.model_backbone_scale,
        args.model_refine_mode,
        args.model_refine_sample_pixels,
        args.model_refine_threshold,
        args.model_refine_kernel_size,
        refine_prevent_oversampling=False)

if args.model_checkpoint:
    model.load_state_dict(torch.load(args.model_checkpoint), strict=False)
    
if args.precision == 'float32':
    precision = torch.float32
else:
    precision = torch.float16
    
if args.backend == 'torchscript':
    model = torch.jit.script(model)
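The fragment ends after the backend choice; a minimal benchmark sketch, assuming 1080p dummy inputs on a CUDA device (shapes and iteration counts are illustrative):

import time

model = model.eval().to(device=device, dtype=precision)
src = torch.randn(1, 3, 1080, 1920, device=device, dtype=precision)
bgr = torch.randn(1, 3, 1080, 1920, device=device, dtype=precision)

with torch.no_grad():
    for _ in range(10):       # warm-up iterations
        model(src, bgr)
    torch.cuda.synchronize()
    t0 = time.time()
    for _ in range(100):
        model(src, bgr)
    torch.cuda.synchronize()  # CUDA launches are asynchronous; sync before stopping the clock
    print(f'{100 / (time.time() - t0):.1f} FPS')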
Example #5
parser.add_argument('--precision', type=str, choices=['float32', 'float16'])
parser.add_argument('--validate', action='store_true')
parser.add_argument('--output', type=str, required=True)

args = parser.parse_args()

# --------------- Main ---------------

# Load model
if args.model_type == 'mattingbase':
    model = MattingBase(args.model_backbone)
if args.model_type == 'mattingrefine':
    model = MattingRefine(args.model_backbone,
                          args.model_backbone_scale,
                          args.model_refine_mode,
                          args.model_refine_sample_pixels,
                          args.model_refine_threshold,
                          args.model_refine_kernel_size,
                          refine_patch_crop_method='roi_align',
                          refine_patch_replace_method='scatter_element')

model.load_state_dict(torch.load(args.model_checkpoint,
                                 map_location=args.device),
                      strict=False)
precision = {
    'float32': torch.float32,
    'float16': torch.float16
}[args.precision]
model.eval().to(precision).to(args.device)

# Dummy Inputs
src = torch.randn(2, 3, 1080, 1920).to(precision).to(args.device)
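The listing cuts off after the first dummy input; a minimal export sketch, assuming a matching bgr tensor (the opset and axis names are illustrative):

bgr = torch.randn(2, 3, 1080, 1920).to(precision).to(args.device)

torch.onnx.export(model, (src, bgr), args.output,
                  opset_version=12,            # illustrative; pick the opset your runtime supports
                  input_names=['src', 'bgr'],
                  dynamic_axes={'src': {0: 'batch'}, 'bgr': {0: 'batch'}})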
Example #6
def train_worker(rank, addr, port):

    # Distributed Setup
    os.environ['MASTER_ADDR'] = addr
    os.environ['MASTER_PORT'] = port
    dist.init_process_group("nccl", rank=rank, world_size=distributed_num_gpus)

    # Training DataLoader
    dataset_train = ZipDataset([
        ZipDataset([
            ImagesDataset(DATA_PATH[args.dataset_name]['train']['pha'], mode='L'),
            ImagesDataset(DATA_PATH[args.dataset_name]['train']['fgr'], mode='RGB'),
        ], transforms=A.PairCompose([
            A.PairRandomAffineAndResize(
                (2048, 2048), degrees=(-5, 5), translate=(0.1, 0.1), scale=(0.3, 1), shear=(-5, 5)),
            A.PairRandomHorizontalFlip(),
            A.PairRandomBoxBlur(0.1, 5),
            A.PairRandomSharpen(0.1),
            A.PairApplyOnlyAtIndices(
                [1], T.ColorJitter(0.15, 0.15, 0.15, 0.05)),
            A.PairApply(T.ToTensor())
        ]), assert_equal_length=True),
        ImagesDataset(DATA_PATH['backgrounds']['train'], mode='RGB', transforms=T.Compose([
            A.RandomAffineAndResize(
                (2048, 2048), degrees=(-5, 5), translate=(0.1, 0.1), scale=(1, 2), shear=(-5, 5)),
            T.RandomHorizontalFlip(),
            A.RandomBoxBlur(0.1, 5),
            A.RandomSharpen(0.1),
            T.ColorJitter(0.15, 0.15, 0.15, 0.05),
            T.ToTensor()
        ])),
    ])
    dataset_train_len_per_gpu_worker = int(len(dataset_train) / distributed_num_gpus)
    dataset_train = Subset(dataset_train, range(rank * dataset_train_len_per_gpu_worker,
                                                (rank + 1) * dataset_train_len_per_gpu_worker))
    dataloader_train = DataLoader(dataset_train,
                                  shuffle=True,
                                  pin_memory=True,
                                  drop_last=True,
                                  batch_size=args.batch_size // distributed_num_gpus,
                                  num_workers=args.num_workers // distributed_num_gpus)

    # Validation DataLoader
    if rank == 0:
        dataset_valid = ZipDataset([
            ZipDataset([
                ImagesDataset(DATA_PATH[args.dataset_name]['valid']['pha'], mode='L'),
                ImagesDataset(DATA_PATH[args.dataset_name]['valid']['fgr'], mode='RGB')
            ], transforms=A.PairCompose([
                A.PairRandomAffineAndResize(
                    (2048, 2048), degrees=(-5, 5), translate=(0.1, 0.1), scale=(0.3, 1), shear=(-5, 5)),
                A.PairApply(T.ToTensor())
            ]), assert_equal_length=True),
            ImagesDataset(DATA_PATH['backgrounds']['valid'], mode='RGB', transforms=T.Compose([
                A.RandomAffineAndResize(
                    (2048, 2048), degrees=(-5, 5), translate=(0.1, 0.1), scale=(1, 1.2), shear=(-5, 5)),
                T.ToTensor()
            ])),
        ])
        dataset_valid = SampleDataset(dataset_valid, 50)
        dataloader_valid = DataLoader(dataset_valid,
                                      pin_memory=True,
                                      drop_last=True,
                                      batch_size=args.batch_size // distributed_num_gpus,
                                      num_workers=args.num_workers // distributed_num_gpus)

    # Model
    model = MattingRefine(args.model_backbone,
                          args.model_backbone_scale,
                          args.model_refine_mode,
                          args.model_refine_sample_pixels,
                          args.model_refine_thresholding,
                          args.model_refine_kernel_size).to(rank)
    model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
    model_distributed = nn.parallel.DistributedDataParallel(
        model, device_ids=[rank])

    if args.model_last_checkpoint is not None:
        load_matched_state_dict(model, torch.load(args.model_last_checkpoint))

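    # Per-module learning rates: smaller for the pretrained backbone and ASPP, larger for the decoder and refiner.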
    optimizer = Adam([
        {'params': model.backbone.parameters(), 'lr': 5e-5},
        {'params': model.aspp.parameters(), 'lr': 5e-5},
        {'params': model.decoder.parameters(), 'lr': 1e-4},
        {'params': model.refiner.parameters(), 'lr': 3e-4},
    ])
    scaler = GradScaler()

    # Logging and checkpoints
    if rank == 0:
        if not os.path.exists(f'checkpoint/{args.model_name}'):
            os.makedirs(f'checkpoint/{args.model_name}')
        writer = SummaryWriter(f'log/{args.model_name}')

    # Run loop
    for epoch in range(args.epoch_start, args.epoch_end):
        for i, ((true_pha, true_fgr), true_bgr) in enumerate(tqdm(dataloader_train)):
            step = epoch * len(dataloader_train) + i

            true_pha = true_pha.to(rank, non_blocking=True)
            true_fgr = true_fgr.to(rank, non_blocking=True)
            true_bgr = true_bgr.to(rank, non_blocking=True)
            true_pha, true_fgr, true_bgr = random_crop(
                true_pha, true_fgr, true_bgr)

            true_src = true_bgr.clone()

            # Augment with shadow
            aug_shadow_idx = torch.rand(len(true_src)) < 0.3
            if aug_shadow_idx.any():
                aug_shadow = true_pha[aug_shadow_idx].mul(
                    0.3 * random.random())
                aug_shadow = T.RandomAffine(
                    degrees=(-5, 5), translate=(0.2, 0.2), scale=(0.5, 1.5), shear=(-5, 5))(aug_shadow)
                aug_shadow = kornia.filters.box_blur(
                    aug_shadow, (random.choice(range(20, 40)),) * 2)
                true_src[aug_shadow_idx] = true_src[aug_shadow_idx].sub_(
                    aug_shadow).clamp_(0, 1)
                del aug_shadow
            del aug_shadow_idx

            # Composite foreground onto source
            true_src = true_fgr * true_pha + true_src * (1 - true_pha)

            # Augment with noise
            aug_noise_idx = torch.rand(len(true_src)) < 0.4
            if aug_noise_idx.any():
                true_src[aug_noise_idx] = true_src[aug_noise_idx].add_(torch.randn_like(
                    true_src[aug_noise_idx]).mul_(0.03 * random.random())).clamp_(0, 1)
                true_bgr[aug_noise_idx] = true_bgr[aug_noise_idx].add_(torch.randn_like(
                    true_bgr[aug_noise_idx]).mul_(0.03 * random.random())).clamp_(0, 1)
            del aug_noise_idx

            # Augment background with jitter
            aug_jitter_idx = torch.rand(len(true_src)) < 0.8
            if aug_jitter_idx.any():
                true_bgr[aug_jitter_idx] = kornia.augmentation.ColorJitter(
                    0.18, 0.18, 0.18, 0.1)(true_bgr[aug_jitter_idx])
            del aug_jitter_idx

            # Augment background with affine
            aug_affine_idx = torch.rand(len(true_bgr)) < 0.3
            if aug_affine_idx.any():
                true_bgr[aug_affine_idx] = T.RandomAffine(
                    degrees=(-1, 1), translate=(0.01, 0.01))(true_bgr[aug_affine_idx])
            del aug_affine_idx

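            # Mixed precision: autocast runs the forward pass in reduced precision; GradScaler rescales the loss for a stable backward.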
            with autocast():
                pred_pha, pred_fgr, pred_pha_sm, pred_fgr_sm, pred_err_sm, _ = model_distributed(
                    true_src, true_bgr)
                loss = compute_loss(
                    pred_pha, pred_fgr, pred_pha_sm, pred_fgr_sm, pred_err_sm, true_pha, true_fgr)

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()

            if rank == 0:
                if (i + 1) % args.log_train_loss_interval == 0:
                    writer.add_scalar('loss', loss, step)

                if (i + 1) % args.log_train_images_interval == 0:
                    writer.add_image('train_pred_pha',
                                     make_grid(pred_pha, nrow=5), step)
                    writer.add_image('train_pred_fgr',
                                     make_grid(pred_fgr, nrow=5), step)
                    writer.add_image('train_pred_com', make_grid(
                        pred_fgr * pred_pha, nrow=5), step)
                    writer.add_image('train_pred_err', make_grid(
                        pred_err_sm, nrow=5), step)
                    writer.add_image('train_true_src',
                                     make_grid(true_src, nrow=5), step)

                del true_pha, true_fgr, true_src, true_bgr
                del pred_pha, pred_fgr, pred_pha_sm, pred_fgr_sm, pred_err_sm

                if (i + 1) % args.log_valid_interval == 0:
                    valid(model, dataloader_valid, writer, step)

                if (step + 1) % args.checkpoint_interval == 0:
                    torch.save(model.state_dict(),
                               f'checkpoint/{args.model_name}/epoch-{epoch}-iter-{step}.pth')

        if rank == 0:
            torch.save(model.state_dict(),
                       f'checkpoint/{args.model_name}/epoch-{epoch}.pth')

    # Clean up
    dist.destroy_process_group()
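train_worker runs once per GPU; a minimal launcher sketch, assuming torch.multiprocessing and hypothetical address/port values:

import torch.multiprocessing as mp

if __name__ == '__main__':
    addr, port = 'localhost', '12355'      # hypothetical rendezvous settings
    mp.spawn(train_worker,
             nprocs=distributed_num_gpus,  # one process per GPU; spawn passes rank as the first argument
             args=(addr, port))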
Example #7
parser.add_argument('--validate', action='store_true')
parser.add_argument('--output', type=str, required=True)

args = parser.parse_args()

# --------------- Main ---------------

# Load model
if args.model_type == 'mattingbase':
    model = MattingBase(args.model_backbone)
if args.model_type == 'mattingrefine':
    model = MattingRefine(
        backbone=args.model_backbone,
        backbone_scale=args.model_backbone_scale,
        refine_mode=args.model_refine_mode,
        refine_sample_pixels=args.model_refine_sample_pixels,
        refine_threshold=args.model_refine_threshold,
        refine_kernel_size=args.model_refine_kernel_size,
        refine_patch_crop_method=args.model_refine_patch_crop_method,
        refine_patch_replace_method=args.model_refine_patch_replace_method)

model.load_state_dict(torch.load(args.model_checkpoint,
                                 map_location=args.device),
                      strict=False)
precision = {
    'float32': torch.float32,
    'float16': torch.float16
}[args.precision]
model.eval().to(precision).to(args.device)

# Dummy Inputs