class MattingRefine_TorchScriptWrapper(nn.Module):
    """
    The purpose of this wrapper is to hoist all the configurable attributes
    to the top level, so that the user can easily change them after loading
    the saved TorchScript model.

    Example:
        model = torch.jit.load('torchscript.pth')
        model.backbone_scale = 0.25
        model.refine_mode = 'sampling'
        model.refine_sample_pixels = 80_000

        pha, fgr = model(src, bgr)[:2]
    """

    def __init__(self, *args, **kwargs):
        super().__init__()
        self.model = MattingRefine(*args, **kwargs)

        # Hoist the attributes to the top level.
        self.backbone_scale = self.model.backbone_scale
        self.refine_mode = self.model.refiner.mode
        self.refine_sample_pixels = self.model.refiner.sample_pixels
        self.refine_threshold = self.model.refiner.threshold
        self.refine_prevent_oversampling = self.model.refiner.prevent_oversampling

    def forward(self, src, bgr):
        # Reset the attributes.
        self.model.backbone_scale = self.backbone_scale
        self.model.refiner.mode = self.refine_mode
        self.model.refiner.sample_pixels = self.refine_sample_pixels
        self.model.refiner.threshold = self.refine_threshold
        self.model.refiner.prevent_oversampling = self.refine_prevent_oversampling
        return self.model(src, bgr)

    def load_state_dict(self, *args, **kwargs):
        return self.model.load_state_dict(*args, **kwargs)
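# --- Hedged example (not part of the original file) ---
# A minimal sketch of how the wrapper above is typically scripted and saved;
# the backbone name and both file paths are placeholders.
wrapper = MattingRefine_TorchScriptWrapper(backbone='resnet50').eval()
wrapper.load_state_dict(torch.load('pytorch_resnet50.pth', map_location='cpu'))
scripted = torch.jit.script(wrapper)
scripted.save('torchscript.pth')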
        frame.save(os.path.join(self.path, str(index + i).zfill(5) + '.' + self.extension))


# --------------- Main ---------------


device = torch.device(args.device)

# Load model
if args.model_type == 'mattingbase':
    model = MattingBase(args.model_backbone)
if args.model_type == 'mattingrefine':
    model = MattingRefine(
        args.model_backbone,
        args.model_backbone_scale,
        args.model_refine_mode,
        args.model_refine_sample_pixels,
        args.model_refine_threshold,
        args.model_refine_kernel_size)

model = model.to(device).eval()
model.load_state_dict(torch.load(args.model_checkpoint, map_location=device), strict=False)

# Load video and background
vid = VideoDataset(args.video_src)
bgr = [Image.open(args.video_bgr).convert('RGB')]
dataset = ZipDataset([vid, bgr], transforms=A.PairCompose([
    A.PairApply(T.Resize(args.video_resize[::-1]) if args.video_resize else nn.Identity()),
assert (not args.image_src) != (not args.resolution), 'Must provide either a resolution or an image, but not both.'


# --------------- Run Loop ---------------


device = torch.device(args.device)

# Load model
if args.model_type == 'mattingbase':
    model = MattingBase(args.model_backbone)
if args.model_type == 'mattingrefine':
    model = MattingRefine(
        args.model_backbone,
        args.model_backbone_scale,
        args.model_refine_mode,
        args.model_refine_sample_pixels,
        args.model_refine_threshold,
        args.model_refine_kernel_size,
        refine_prevent_oversampling=False)

if args.model_checkpoint:
    model.load_state_dict(torch.load(args.model_checkpoint), strict=False)

if args.precision == 'float32':
    precision = torch.float32
else:
    precision = torch.float16

if args.backend == 'torchscript':
    model = torch.jit.script(model)
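# --- Hedged example (not from the original script) ---
# A minimal sketch of the timed loop such a speed test typically runs after
# the setup above; the resolution, iteration counts, and the CUDA-only
# synchronize calls are assumptions.
import time

model = model.to(device).to(precision).eval()
src = torch.rand(1, 3, 1080, 1920, device=device, dtype=precision)
bgr = torch.rand(1, 3, 1080, 1920, device=device, dtype=precision)

with torch.no_grad():
    for _ in range(10):            # warm-up iterations
        model(src, bgr)
    torch.cuda.synchronize()
    t0 = time.time()
    for _ in range(100):           # timed iterations
        model(src, bgr)
    torch.cuda.synchronize()
    print(f'FPS: {100 / (time.time() - t0):.1f}')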
                    choices=['float32', 'float16'])
parser.add_argument('--validate', action='store_true')
parser.add_argument('--output', type=str, required=True)
args = parser.parse_args()


# --------------- Main ---------------


# Load model
if args.model_type == 'mattingbase':
    model = MattingBase(args.model_backbone)
if args.model_type == 'mattingrefine':
    model = MattingRefine(
        args.model_backbone,
        args.model_backbone_scale,
        args.model_refine_mode,
        args.model_refine_sample_pixels,
        args.model_refine_threshold,
        args.model_refine_kernel_size,
        refine_patch_crop_method='roi_align',
        refine_patch_replace_method='scatter_element')

model.load_state_dict(torch.load(args.model_checkpoint, map_location=args.device), strict=False)
precision = {'float32': torch.float32, 'float16': torch.float16}[args.precision]
model.eval().to(precision).to(args.device)

# Dummy Inputs
src = torch.randn(2, 3, 1080, 1920).to(precision).to(args.device)
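# --- Hedged example (not from the original script) ---
# A minimal sketch of the export call that typically follows the dummy inputs;
# the bgr tensor, opset version, and input/output names are assumptions. The
# 'roi_align' and 'scatter_element' options above are the ONNX-friendly patch
# operations, which is why this script hardcodes them.
bgr = torch.randn(2, 3, 1080, 1920).to(precision).to(args.device)
torch.onnx.export(
    model, (src, bgr), args.output,
    opset_version=12,
    do_constant_folding=True,
    input_names=['src', 'bgr'],
    output_names=['pha', 'fgr'])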
def train_worker(rank, addr, port):

    # Distributed Setup
    os.environ['MASTER_ADDR'] = addr
    os.environ['MASTER_PORT'] = port
    dist.init_process_group("nccl", rank=rank, world_size=distributed_num_gpus)

    # Training DataLoader
    dataset_train = ZipDataset([
        ZipDataset([
            ImagesDataset(DATA_PATH[args.dataset_name]['train']['pha'], mode='L'),
            ImagesDataset(DATA_PATH[args.dataset_name]['train']['fgr'], mode='RGB'),
        ], transforms=A.PairCompose([
            A.PairRandomAffineAndResize((2048, 2048), degrees=(-5, 5), translate=(0.1, 0.1), scale=(0.3, 1), shear=(-5, 5)),
            A.PairRandomHorizontalFlip(),
            A.PairRandomBoxBlur(0.1, 5),
            A.PairRandomSharpen(0.1),
            A.PairApplyOnlyAtIndices([1], T.ColorJitter(0.15, 0.15, 0.15, 0.05)),
            A.PairApply(T.ToTensor())
        ]), assert_equal_length=True),
        ImagesDataset(DATA_PATH['backgrounds']['train'], mode='RGB', transforms=T.Compose([
            A.RandomAffineAndResize((2048, 2048), degrees=(-5, 5), translate=(0.1, 0.1), scale=(1, 2), shear=(-5, 5)),
            T.RandomHorizontalFlip(),
            A.RandomBoxBlur(0.1, 5),
            A.RandomSharpen(0.1),
            T.ColorJitter(0.15, 0.15, 0.15, 0.05),
            T.ToTensor()
        ])),
    ])
    # Shard the dataset manually so each GPU worker trains on a distinct slice.
    dataset_train_len_per_gpu_worker = int(len(dataset_train) / distributed_num_gpus)
    dataset_train = Subset(dataset_train, range(rank * dataset_train_len_per_gpu_worker,
                                                (rank + 1) * dataset_train_len_per_gpu_worker))
    dataloader_train = DataLoader(dataset_train,
                                  shuffle=True,
                                  pin_memory=True,
                                  drop_last=True,
                                  batch_size=args.batch_size // distributed_num_gpus,
                                  num_workers=args.num_workers // distributed_num_gpus)

    # Validation DataLoader
    if rank == 0:
        dataset_valid = ZipDataset([
            ZipDataset([
                ImagesDataset(DATA_PATH[args.dataset_name]['valid']['pha'], mode='L'),
                ImagesDataset(DATA_PATH[args.dataset_name]['valid']['fgr'], mode='RGB')
            ], transforms=A.PairCompose([
                A.PairRandomAffineAndResize((2048, 2048), degrees=(-5, 5), translate=(0.1, 0.1), scale=(0.3, 1), shear=(-5, 5)),
                A.PairApply(T.ToTensor())
            ]), assert_equal_length=True),
            ImagesDataset(DATA_PATH['backgrounds']['valid'], mode='RGB', transforms=T.Compose([
                A.RandomAffineAndResize((2048, 2048), degrees=(-5, 5), translate=(0.1, 0.1), scale=(1, 1.2), shear=(-5, 5)),
                T.ToTensor()
            ])),
        ])
        dataset_valid = SampleDataset(dataset_valid, 50)
        dataloader_valid = DataLoader(dataset_valid,
                                      pin_memory=True,
                                      drop_last=True,
                                      batch_size=args.batch_size // distributed_num_gpus,
                                      num_workers=args.num_workers // distributed_num_gpus)

    # Model
    model = MattingRefine(args.model_backbone,
                          args.model_backbone_scale,
                          args.model_refine_mode,
                          args.model_refine_sample_pixels,
                          args.model_refine_thresholding,
                          args.model_refine_kernel_size).to(rank)
    model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
    model_distributed = nn.parallel.DistributedDataParallel(model, device_ids=[rank])

    if args.model_last_checkpoint is not None:
        load_matched_state_dict(model, torch.load(args.model_last_checkpoint))

    # Lower learning rates for the pretrained backbone and ASPP, higher for
    # the decoder and refiner.
    optimizer = Adam([
        {'params': model.backbone.parameters(), 'lr': 5e-5},
        {'params': model.aspp.parameters(), 'lr': 5e-5},
        {'params': model.decoder.parameters(), 'lr': 1e-4},
        {'params': model.refiner.parameters(), 'lr': 3e-4},
    ])
    scaler = GradScaler()

    # Logging and checkpoints
    if rank == 0:
        if not os.path.exists(f'checkpoint/{args.model_name}'):
            os.makedirs(f'checkpoint/{args.model_name}')
        writer = SummaryWriter(f'log/{args.model_name}')

    # Run loop
    for epoch in range(args.epoch_start, args.epoch_end):
        for i, ((true_pha, true_fgr), true_bgr) in enumerate(tqdm(dataloader_train)):
            step = epoch * len(dataloader_train) + i

            true_pha = true_pha.to(rank, non_blocking=True)
            true_fgr = true_fgr.to(rank, non_blocking=True)
            true_bgr = true_bgr.to(rank, non_blocking=True)
            true_pha, true_fgr, true_bgr = random_crop(true_pha, true_fgr, true_bgr)

            true_src = true_bgr.clone()

            # Augment with shadow
            aug_shadow_idx = torch.rand(len(true_src)) < 0.3
            if aug_shadow_idx.any():
                aug_shadow = true_pha[aug_shadow_idx].mul(0.3 * random.random())
                aug_shadow = T.RandomAffine(degrees=(-5, 5), translate=(0.2, 0.2), scale=(0.5, 1.5), shear=(-5, 5))(aug_shadow)
                aug_shadow = kornia.filters.box_blur(aug_shadow, (random.choice(range(20, 40)),) * 2)
                true_src[aug_shadow_idx] = true_src[aug_shadow_idx].sub_(aug_shadow).clamp_(0, 1)
                del aug_shadow
            del aug_shadow_idx

            # Composite foreground onto source
            true_src = true_fgr * true_pha + true_src * (1 - true_pha)

            # Augment with noise
            aug_noise_idx = torch.rand(len(true_src)) < 0.4
            if aug_noise_idx.any():
                true_src[aug_noise_idx] = true_src[aug_noise_idx].add_(torch.randn_like(true_src[aug_noise_idx]).mul_(0.03 * random.random())).clamp_(0, 1)
                true_bgr[aug_noise_idx] = true_bgr[aug_noise_idx].add_(torch.randn_like(true_bgr[aug_noise_idx]).mul_(0.03 * random.random())).clamp_(0, 1)
            del aug_noise_idx

            # Augment background with jitter
            aug_jitter_idx = torch.rand(len(true_src)) < 0.8
            if aug_jitter_idx.any():
                true_bgr[aug_jitter_idx] = kornia.augmentation.ColorJitter(0.18, 0.18, 0.18, 0.1)(true_bgr[aug_jitter_idx])
            del aug_jitter_idx

            # Augment background with affine
            aug_affine_idx = torch.rand(len(true_bgr)) < 0.3
            if aug_affine_idx.any():
                true_bgr[aug_affine_idx] = T.RandomAffine(degrees=(-1, 1), translate=(0.01, 0.01))(true_bgr[aug_affine_idx])
            del aug_affine_idx

            with autocast():
                pred_pha, pred_fgr, pred_pha_sm, pred_fgr_sm, pred_err_sm, _ = model_distributed(true_src, true_bgr)
                loss = compute_loss(pred_pha, pred_fgr, pred_pha_sm, pred_fgr_sm, pred_err_sm, true_pha, true_fgr)

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()

            if rank == 0:
                if (i + 1) % args.log_train_loss_interval == 0:
                    writer.add_scalar('loss', loss, step)

                if (i + 1) % args.log_train_images_interval == 0:
                    writer.add_image('train_pred_pha', make_grid(pred_pha, nrow=5), step)
                    writer.add_image('train_pred_fgr', make_grid(pred_fgr, nrow=5), step)
                    writer.add_image('train_pred_com', make_grid(pred_fgr * pred_pha, nrow=5), step)
                    writer.add_image('train_pred_err', make_grid(pred_err_sm, nrow=5), step)
                    writer.add_image('train_true_src', make_grid(true_src, nrow=5), step)

                del true_pha, true_fgr, true_src, true_bgr
                del pred_pha, pred_fgr, pred_pha_sm, pred_fgr_sm, pred_err_sm

                if (i + 1) % args.log_valid_interval == 0:
                    valid(model, dataloader_valid, writer, step)

                if (step + 1) % args.checkpoint_interval == 0:
                    torch.save(model.state_dict(), f'checkpoint/{args.model_name}/epoch-{epoch}-iter-{step}.pth')

        if rank == 0:
            torch.save(model.state_dict(), f'checkpoint/{args.model_name}/epoch-{epoch}.pth')

    # Clean up
    dist.destroy_process_group()
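# --- Hedged example (not part of the original file) ---
# A minimal sketch of how a worker like train_worker is typically launched,
# one process per GPU; the address and port values are assumptions.
import torch.multiprocessing as mp

if __name__ == '__main__':
    # mp.spawn passes the process index (used as the rank) as the first argument.
    mp.spawn(train_worker,
             nprocs=distributed_num_gpus,
             args=('localhost', '12355'))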
parser.add_argument('--validate', action='store_true')
parser.add_argument('--output', type=str, required=True)
args = parser.parse_args()


# --------------- Main ---------------


# Load model
if args.model_type == 'mattingbase':
    model = MattingBase(args.model_backbone)
if args.model_type == 'mattingrefine':
    model = MattingRefine(
        backbone=args.model_backbone,
        backbone_scale=args.model_backbone_scale,
        refine_mode=args.model_refine_mode,
        refine_sample_pixels=args.model_refine_sample_pixels,
        refine_threshold=args.model_refine_threshold,
        refine_kernel_size=args.model_refine_kernel_size,
        refine_patch_crop_method=args.model_refine_patch_crop_method,
        refine_patch_replace_method=args.model_refine_patch_replace_method)

model.load_state_dict(torch.load(args.model_checkpoint, map_location=args.device), strict=False)
precision = {'float32': torch.float32, 'float16': torch.float16}[args.precision]
model.eval().to(precision).to(args.device)

# Dummy Inputs
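# --- Hedged example (not from the original script) ---
# The file is truncated here; a minimal sketch of the dummy inputs that would
# follow, mirroring the earlier export script. The batch size and resolution
# are assumptions.
src = torch.randn(2, 3, 1080, 1920).to(precision).to(args.device)
bgr = torch.randn(2, 3, 1080, 1920).to(precision).to(args.device)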