def pruning():
    # Training DataLoader
    dataset_train = ZipDataset([
        ZipDataset([
            ImagesDataset(DATA_PATH[args.dataset_name]['train']['pha'], mode='L'),
            ImagesDataset(DATA_PATH[args.dataset_name]['train']['fgr'], mode='RGB'),
        ], transforms=A.PairCompose([
            A.PairRandomAffineAndResize((512, 512), degrees=(-5, 5), translate=(0.1, 0.1), scale=(0.4, 1), shear=(-5, 5)),
            A.PairRandomHorizontalFlip(),
            A.PairRandomBoxBlur(0.1, 5),
            A.PairRandomSharpen(0.1),
            A.PairApplyOnlyAtIndices([1], T.ColorJitter(0.15, 0.15, 0.15, 0.05)),
            A.PairApply(T.ToTensor())
        ]), assert_equal_length=True),
        ImagesDataset(DATA_PATH['backgrounds']['train'], mode='RGB', transforms=T.Compose([
            A.RandomAffineAndResize((512, 512), degrees=(-5, 5), translate=(0.1, 0.1), scale=(1, 2), shear=(-5, 5)),
            T.RandomHorizontalFlip(),
            A.RandomBoxBlur(0.1, 5),
            A.RandomSharpen(0.1),
            T.ColorJitter(0.15, 0.15, 0.15, 0.05),
            T.ToTensor()
        ])),
    ])
    dataloader_train = DataLoader(dataset_train,
                                  shuffle=True,
                                  batch_size=args.batch_size,
                                  num_workers=args.num_workers,
                                  pin_memory=True)

    # Validation DataLoader
    dataset_valid = ZipDataset([
        ZipDataset([
            ImagesDataset(DATA_PATH[args.dataset_name]['valid']['pha'], mode='L'),
            ImagesDataset(DATA_PATH[args.dataset_name]['valid']['fgr'], mode='RGB')
        ], transforms=A.PairCompose([
            A.PairRandomAffineAndResize((512, 512), degrees=(-5, 5), translate=(0.1, 0.1), scale=(0.3, 1), shear=(-5, 5)),
            A.PairApply(T.ToTensor())
        ]), assert_equal_length=True),
        ImagesDataset(DATA_PATH['backgrounds']['valid'], mode='RGB', transforms=T.Compose([
            A.RandomAffineAndResize((512, 512), degrees=(-5, 5), translate=(0.1, 0.1), scale=(1, 1.2), shear=(-5, 5)),
            T.ToTensor()
        ])),
    ])
    dataset_valid = SampleDataset(dataset_valid, 50)
    dataloader_valid = DataLoader(dataset_valid,
                                  pin_memory=True,
                                  batch_size=args.batch_size,
                                  num_workers=args.num_workers)

    # Model
    model = MattingBase(args.model_backbone).cuda()

    if args.model_last_checkpoint is not None:
        load_matched_state_dict(model, torch.load(args.model_last_checkpoint))
    elif args.model_pretrain_initialization is not None:
        model.load_pretrained_deeplabv3_state_dict(
            torch.load(args.model_pretrain_initialization)['model_state'])

    # Print the initial sparsity
    # for name, module in model.named_modules():
    #     # prune 40% of connections in all 2D-conv layers
    #     if isinstance(module, torch.nn.Conv2d):
    #         # DNSUnst(module, name='weight')
    #         prune.l1_unstructured(module, name='weight', amount=0.4)
    #         prune.remove(module, 'weight')
    print("the original sparsity: ", get_sparsity(model))

    optimizer = Adam([{
        'params': model.backbone.parameters(),
        'lr': 1e-4
    }, {
        'params': model.aspp.parameters(),
        'lr': 5e-4
    }, {
        'params': model.decoder.parameters(),
        'lr': 5e-4
    }])
    scaler = GradScaler()

    # Logging and checkpoints
    if not os.path.exists(f'checkpoint/{args.model_name}'):
        os.makedirs(f'checkpoint/{args.model_name}')
    writer = SummaryWriter(f'log/{args.model_name}')

    # Run loop
    for epoch in range(args.epoch_start, args.epoch_end):
        # best_c holds the pruning configuration found by NCS at the first
        # iteration of each epoch and is reused for the remaining steps.
        best_c = np.zeros(187)
        for i, ((true_pha, true_fgr), true_bgr) in enumerate(tqdm(dataloader_train)):
            step = epoch * len(dataloader_train) + i

            true_pha = true_pha.cuda(non_blocking=True)
            true_fgr = true_fgr.cuda(non_blocking=True)
            true_bgr = true_bgr.cuda(non_blocking=True)
            true_pha, true_fgr, true_bgr = random_crop(true_pha, true_fgr, true_bgr)

            true_src = true_bgr.clone()

            # Augment with shadow
            aug_shadow_idx = torch.rand(len(true_src)) < 0.3
            if aug_shadow_idx.any():
                aug_shadow = true_pha[aug_shadow_idx].mul(0.3 * random.random())
                aug_shadow = T.RandomAffine(degrees=(-5, 5),
                                            translate=(0.2, 0.2),
                                            scale=(0.5, 1.5),
                                            shear=(-5, 5))(aug_shadow)
                aug_shadow = kornia.filters.box_blur(
                    aug_shadow, (random.choice(range(20, 40)),) * 2)
                true_src[aug_shadow_idx] = true_src[aug_shadow_idx].sub_(aug_shadow).clamp_(0, 1)
                del aug_shadow
            del aug_shadow_idx

            # Composite foreground onto source
            true_src = true_fgr * true_pha + true_src * (1 - true_pha)

            # Augment with noise
            aug_noise_idx = torch.rand(len(true_src)) < 0.4
            if aug_noise_idx.any():
                true_src[aug_noise_idx] = true_src[aug_noise_idx].add_(
                    torch.randn_like(true_src[aug_noise_idx]).mul_(0.03 * random.random())).clamp_(0, 1)
                true_bgr[aug_noise_idx] = true_bgr[aug_noise_idx].add_(
                    torch.randn_like(true_bgr[aug_noise_idx]).mul_(0.03 * random.random())).clamp_(0, 1)
            del aug_noise_idx

            # Augment background with jitter
            aug_jitter_idx = torch.rand(len(true_src)) < 0.8
            if aug_jitter_idx.any():
                true_bgr[aug_jitter_idx] = kornia.augmentation.ColorJitter(
                    0.18, 0.18, 0.18, 0.1)(true_bgr[aug_jitter_idx])
            del aug_jitter_idx

            # Augment background with affine
            aug_affine_idx = torch.rand(len(true_bgr)) < 0.3
            if aug_affine_idx.any():
                true_bgr[aug_affine_idx] = T.RandomAffine(
                    degrees=(-1, 1), translate=(0.01, 0.01))(true_bgr[aug_affine_idx])
            del aug_affine_idx

            with autocast():
                pred_pha, pred_fgr, pred_err = model(true_src, true_bgr)[:3]
                loss = compute_loss(pred_pha, pred_fgr, pred_err, true_pha, true_fgr)
            scaler.scale(loss).backward()

            # Pruning
            if i == 0:
                # Search for the best pruning configuration at the start of the epoch
                ncs = NCS_C(model, true_src, true_bgr, true_pha, true_fgr)
                best_c = ncs.run(model, true_src, true_bgr, true_pha, true_fgr)
                PreTPUnst(model, best_c)
            else:
                # Adjust the masks using the configuration found above
                PreDNSUnst(model, best_c)
            scaler.step(optimizer)
            Pruned(model)
            scaler.update()
            optimizer.zero_grad()

            if (i + 1) % args.log_train_loss_interval == 0:
                writer.add_scalar('loss', loss, step)

            if (i + 1) % args.log_train_images_interval == 0:
                writer.add_image('train_pred_pha', make_grid(pred_pha, nrow=5), step)
                writer.add_image('train_pred_fgr', make_grid(pred_fgr, nrow=5), step)
                writer.add_image('train_pred_com', make_grid(pred_fgr * pred_pha, nrow=5), step)
                writer.add_image('train_pred_err', make_grid(pred_err, nrow=5), step)
                writer.add_image('train_true_src', make_grid(true_src, nrow=5), step)
                writer.add_image('train_true_bgr', make_grid(true_bgr, nrow=5), step)

            del true_pha, true_fgr, true_bgr, true_src
            del pred_pha, pred_fgr, pred_err
            del loss

            if (i + 1) % args.log_valid_interval == 0:
                valid(model, dataloader_valid, writer, step)

            if (step + 1) % args.checkpoint_interval == 0:
                torch.save(
                    model.state_dict(),
                    f'checkpoint/{args.model_name}/epoch-{epoch}-iter-{step}.pth')

        print("the sparsity of epoch {}: {}".format(epoch, get_sparsity(model)))
        torch.save(model.state_dict(), f'checkpoint/{args.model_name}/epoch-{epoch}.pth')

    # Print the final sparsity
    print("the final sparsity: ", get_sparsity(model))
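# NOTE: get_sparsity, NCS_C, PreTPUnst, PreDNSUnst and Pruned are pruning
# helpers defined elsewhere in this repo. For reference, a minimal sketch of
# what get_sparsity is assumed to report -- the fraction of zero-valued
# weights across all Conv2d layers -- could look like the hypothetical
# implementation below (not the repo's actual definition):
def get_sparsity_sketch(model):
    """Fraction of zero-valued weights over all Conv2d layers of `model`."""
    zero, total = 0, 0
    for module in model.modules():
        if isinstance(module, torch.nn.Conv2d):
            weight = module.weight.detach()
            zero += int((weight == 0).sum().item())
            total += weight.numel()
    return zero / max(total, 1)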
        frames = frames.cpu()
        for i in range(frames.shape[0]):
            frame = frames[i]
            frame = to_pil_image(frame)
            frame.save(os.path.join(self.path, str(index + i).zfill(5) + '.' + self.extension))


# --------------- Main ---------------

device = torch.device(args.device)

# Load model
if args.model_type == 'mattingbase':
    model = MattingBase(args.model_backbone)
if args.model_type == 'mattingrefine':
    model = MattingRefine(args.model_backbone,
                          args.model_backbone_scale,
                          args.model_refine_mode,
                          args.model_refine_sample_pixels,
                          args.model_refine_threshold,
                          args.model_refine_kernel_size)

model = model.to(device).eval()
model.load_state_dict(torch.load(args.model_checkpoint, map_location=device), strict=False)

# Load video and background
vid = VideoDataset(args.video_src)
bgr = [Image.open(args.video_bgr).convert('RGB')]
dataset = ZipDataset([vid, bgr],
args = parser.parse_args()

assert type(args.image_src) == type(args.image_bgr), 'Image source and background must be provided together.'
assert (not args.image_src) != (not args.resolution), 'Must provide either a resolution or an image and not both.'


# --------------- Run Loop ---------------

device = torch.device(args.device)

# Load model
if args.model_type == 'mattingbase':
    model = MattingBase(args.model_backbone)
if args.model_type == 'mattingrefine':
    model = MattingRefine(
        args.model_backbone,
        args.model_backbone_scale,
        args.model_refine_mode,
        args.model_refine_sample_pixels,
        args.model_refine_threshold,
        args.model_refine_kernel_size,
        refine_prevent_oversampling=False)

if args.model_checkpoint:
    model.load_state_dict(torch.load(args.model_checkpoint), strict=False)

if args.precision == 'float32':
    precision = torch.float32
# --------------- Main ---------------

torch.set_num_threads(args.num_threads)
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using {args.model_type}")
print(f"Using {args.num_threads} threads")
print(f"Using {precision} precision")
print(f"Using {device} device")
print(f"Loading {args.model_checkpoint}")

# Load model
if args.model_type == "mattingbase":
    model = MattingBase(args.model_backbone)
if args.model_type == "mattingrefine":
    model = MattingRefine(
        args.model_backbone,
        args.model_backbone_scale,
        args.model_refine_mode,
        args.model_refine_sample_pixels,
        args.model_refine_threshold,
    )

if args.model_type == "jit":
    model = torch.jit.load(args.model_checkpoint)
elif args.model_type == "trt":
    # Not working yet
    model = TRTModule()
    model.load_state_dict(torch.load(args.model_checkpoint))
else:
        self.index += frames.shape[0]

    def _add_batch(self, frames, index):
        frames = frames.cpu()
        for i in range(frames.shape[0]):
            frame = frames[i]
            frame = to_pil_image(frame)
            frame.save(os.path.join(self.path, str(index + i).zfill(5) + '.' + self.extension))


# --------------- Main ---------------

# Load model
if args.model_type == 'mattingbase':
    model = MattingBase(args.model_backbone)
if args.model_type == 'mattingrefine':
    model = MattingRefine(
        args.model_backbone,
        args.model_backbone_scale,
        args.model_refine_mode,
        args.model_refine_sample_pixels,
        args.model_refine_threshold,
        args.model_refine_kernel_size)

model = model.cuda().eval()
model.load_state_dict(torch.load(args.model_checkpoint), strict=False)

# Load video and background
vid = VideoDataset(args.video_src)
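# A minimal sketch of a single inference step with the model loaded above;
# `src` and `bgr` are assumed to be (B, 3, H, W) float tensors on the GPU
# (illustrative names, not variables from this script):
def infer_one_batch(src, bgr):
    with torch.no_grad():
        pha, fgr = model(src, bgr)[:2]       # alpha matte and foreground
        return fgr * pha + bgr * (1 - pha)   # composite over the background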
parser.add_argument('--device', type=str, default='cpu')
parser.add_argument('--precision', type=str, default='float32', choices=['float32', 'float16'])
parser.add_argument('--validate', action='store_true')
parser.add_argument('--output', type=str, required=True)
args = parser.parse_args()


# --------------- Main ---------------

# Load model
if args.model_type == 'mattingbase':
    model = MattingBase(args.model_backbone)
if args.model_type == 'mattingrefine':
    model = MattingRefine(args.model_backbone,
                          args.model_backbone_scale,
                          args.model_refine_mode,
                          args.model_refine_sample_pixels,
                          args.model_refine_threshold,
                          args.model_refine_kernel_size,
                          refine_patch_crop_method='roi_align',
                          refine_patch_replace_method='scatter_element')

model.load_state_dict(torch.load(args.model_checkpoint, map_location=args.device), strict=False)

precision = {
    'float32': torch.float32,