def pruning():
    """Train a ``MattingBase`` model while pruning it.

    The function reads its configuration from the module-level ``args`` and
    ``DATA_PATH`` objects and takes no parameters.  It:

    * builds the training and validation data loaders,
    * restores weights from ``args.model_last_checkpoint`` or, failing that,
      from ``args.model_pretrain_initialization``,
    * runs the augmentation + mixed-precision training loop, calling the
      pruning helpers (``NCS_C``, ``PreTPUnst``, ``PreDNSUnst``, ``Pruned``)
      around every optimizer step,
    * logs to TensorBoard under ``log/<model_name>`` and writes checkpoints
      under ``checkpoint/<model_name>``,
    * prints the value of ``get_sparsity(model)`` before training, after every
      epoch and once more at the end.

    Returns:
        None.
    """
    # Training DataLoader
    dataset_train = ZipDataset([
        ZipDataset([
            ImagesDataset(DATA_PATH[args.dataset_name]['train']['pha'], mode='L'),
            ImagesDataset(DATA_PATH[args.dataset_name]['train']['fgr'], mode='RGB'),
        ], transforms=A.PairCompose([
            A.PairRandomAffineAndResize((512, 512), degrees=(-5, 5), translate=(0.1, 0.1),
                                        scale=(0.4, 1), shear=(-5, 5)),
            A.PairRandomHorizontalFlip(),
            A.PairRandomBoxBlur(0.1, 5),
            A.PairRandomSharpen(0.1),
            # Colour jitter is applied to index 1 of the pair only (the 'fgr' dataset).
            A.PairApplyOnlyAtIndices([1], T.ColorJitter(0.15, 0.15, 0.15, 0.05)),
            A.PairApply(T.ToTensor())
        ]), assert_equal_length=True),
        ImagesDataset(DATA_PATH['backgrounds']['train'], mode='RGB', transforms=T.Compose([
            A.RandomAffineAndResize((512, 512), degrees=(-5, 5), translate=(0.1, 0.1),
                                    scale=(1, 2), shear=(-5, 5)),
            T.RandomHorizontalFlip(),
            A.RandomBoxBlur(0.1, 5),
            A.RandomSharpen(0.1),
            T.ColorJitter(0.15, 0.15, 0.15, 0.05),
            T.ToTensor()
        ])),
    ])
    dataloader_train = DataLoader(dataset_train,
                                  shuffle=True,
                                  batch_size=args.batch_size,
                                  num_workers=args.num_workers,
                                  pin_memory=True)

    # Validation DataLoader
    dataset_valid = ZipDataset([
        ZipDataset([
            ImagesDataset(DATA_PATH[args.dataset_name]['valid']['pha'], mode='L'),
            ImagesDataset(DATA_PATH[args.dataset_name]['valid']['fgr'], mode='RGB')
        ], transforms=A.PairCompose([
            A.PairRandomAffineAndResize((512, 512), degrees=(-5, 5), translate=(0.1, 0.1),
                                        scale=(0.3, 1), shear=(-5, 5)),
            A.PairApply(T.ToTensor())
        ]), assert_equal_length=True),
        ImagesDataset(DATA_PATH['backgrounds']['valid'], mode='RGB', transforms=T.Compose([
            A.RandomAffineAndResize((512, 512), degrees=(-5, 5), translate=(0.1, 0.1),
                                    scale=(1, 1.2), shear=(-5, 5)),
            T.ToTensor()
        ])),
    ])
    dataset_valid = SampleDataset(dataset_valid, 50)
    dataloader_valid = DataLoader(dataset_valid,
                                  pin_memory=True,
                                  batch_size=args.batch_size,
                                  num_workers=args.num_workers)

    # Model
    model = MattingBase(args.model_backbone).cuda()

    # A previous checkpoint takes precedence over the pretrained initialization.
    if args.model_last_checkpoint is not None:
        load_matched_state_dict(model, torch.load(args.model_last_checkpoint))
    elif args.model_pretrain_initialization is not None:
        model.load_pretrained_deeplabv3_state_dict(
            torch.load(args.model_pretrain_initialization)['model_state'])

    # Print the initial sparsity
    print("the original sparsity: ", get_sparsity(model))

    optimizer = Adam([
        {'params': model.backbone.parameters(), 'lr': 1e-4},
        {'params': model.aspp.parameters(), 'lr': 5e-4},
        {'params': model.decoder.parameters(), 'lr': 5e-4},
    ])
    scaler = GradScaler()

    # Logging and checkpoints
    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(f'checkpoint/{args.model_name}', exist_ok=True)
    writer = SummaryWriter(f'log/{args.model_name}')

    try:
        # Run loop
        for epoch in range(args.epoch_start, args.epoch_end):
            for i, ((true_pha, true_fgr), true_bgr) in enumerate(tqdm(dataloader_train)):
                # Global step counter across epochs.
                step = epoch * len(dataloader_train) + i

                true_pha = true_pha.cuda(non_blocking=True)
                true_fgr = true_fgr.cuda(non_blocking=True)
                true_bgr = true_bgr.cuda(non_blocking=True)
                true_pha, true_fgr, true_bgr = random_crop(true_pha, true_fgr, true_bgr)

                true_src = true_bgr.clone()

                # Augment with shadow (each sample is selected with probability 0.3)
                aug_shadow_idx = torch.rand(len(true_src)) < 0.3
                if aug_shadow_idx.any():
                    aug_shadow = true_pha[aug_shadow_idx].mul(0.3 * random.random())
                    aug_shadow = T.RandomAffine(degrees=(-5, 5), translate=(0.2, 0.2),
                                                scale=(0.5, 1.5), shear=(-5, 5))(aug_shadow)
                    aug_shadow = kornia.filters.box_blur(
                        aug_shadow, (random.choice(range(20, 40)),) * 2)
                    true_src[aug_shadow_idx] = true_src[aug_shadow_idx].sub_(
                        aug_shadow).clamp_(0, 1)
                    del aug_shadow
                del aug_shadow_idx

                # Composite foreground onto source
                true_src = true_fgr * true_pha + true_src * (1 - true_pha)

                # Augment with noise (the same samples are perturbed in src and bgr,
                # each with its own noise draw and noise scale)
                aug_noise_idx = torch.rand(len(true_src)) < 0.4
                if aug_noise_idx.any():
                    true_src[aug_noise_idx] = true_src[aug_noise_idx].add_(
                        torch.randn_like(true_src[aug_noise_idx]).mul_(
                            0.03 * random.random())).clamp_(0, 1)
                    true_bgr[aug_noise_idx] = true_bgr[aug_noise_idx].add_(
                        torch.randn_like(true_bgr[aug_noise_idx]).mul_(
                            0.03 * random.random())).clamp_(0, 1)
                del aug_noise_idx

                # Augment background with jitter
                aug_jitter_idx = torch.rand(len(true_src)) < 0.8
                if aug_jitter_idx.any():
                    true_bgr[aug_jitter_idx] = kornia.augmentation.ColorJitter(
                        0.18, 0.18, 0.18, 0.1)(true_bgr[aug_jitter_idx])
                del aug_jitter_idx

                # Augment background with affine
                aug_affine_idx = torch.rand(len(true_bgr)) < 0.3
                if aug_affine_idx.any():
                    true_bgr[aug_affine_idx] = T.RandomAffine(
                        degrees=(-1, 1), translate=(0.01, 0.01))(true_bgr[aug_affine_idx])
                del aug_affine_idx

                with autocast():
                    # Only the first three model outputs are used here.
                    pred_pha, pred_fgr, pred_err = model(true_src, true_bgr)[:3]
                    loss = compute_loss(pred_pha, pred_fgr, pred_err, true_pha, true_fgr)

                scaler.scale(loss).backward()

                # Pruning
                # NOTE(review): best_c is re-created as zeros on every batch and
                # deleted at the end of it, so the value returned by ncs.run() on
                # batch 0 never reaches PreDNSUnst - that call always receives the
                # zero vector.  Confirm this is intended; if not, best_c should be
                # kept alive across batches.
                # NOTE(review): 187 is presumably the number of prunable
                # layers/entries expected by the pruning helpers - TODO confirm.
                best_c = np.zeros(187)
                if i == 0:
                    # First batch of every epoch: run NCS_C and prune with its result.
                    ncs = NCS_C(model, true_src, true_bgr, true_pha, true_fgr)
                    best_c = ncs.run(model, true_src, true_bgr, true_pha, true_fgr)
                    PreTPUnst(model, best_c)
                else:
                    # Adjust
                    PreDNSUnst(model, best_c)

                scaler.step(optimizer)
                Pruned(model)
                scaler.update()
                optimizer.zero_grad()

                if (i + 1) % args.log_train_loss_interval == 0:
                    writer.add_scalar('loss', loss, step)

                if (i + 1) % args.log_train_images_interval == 0:
                    writer.add_image('train_pred_pha', make_grid(pred_pha, nrow=5), step)
                    writer.add_image('train_pred_fgr', make_grid(pred_fgr, nrow=5), step)
                    writer.add_image('train_pred_com',
                                     make_grid(pred_fgr * pred_pha, nrow=5), step)
                    writer.add_image('train_pred_err', make_grid(pred_err, nrow=5), step)
                    writer.add_image('train_true_src', make_grid(true_src, nrow=5), step)
                    writer.add_image('train_true_bgr', make_grid(true_bgr, nrow=5), step)

                # Drop references to this batch before validation / checkpointing.
                del true_pha, true_fgr, true_bgr, true_src
                del pred_pha, pred_fgr, pred_err
                del loss
                del best_c

                if (i + 1) % args.log_valid_interval == 0:
                    valid(model, dataloader_valid, writer, step)

                if (step + 1) % args.checkpoint_interval == 0:
                    torch.save(
                        model.state_dict(),
                        f'checkpoint/{args.model_name}/epoch-{epoch}-iter-{step}.pth'
                    )

            print("the sparsity of epoch {} : {}".format(epoch, get_sparsity(model)))
            torch.save(model.state_dict(),
                       f'checkpoint/{args.model_name}/epoch-{epoch}.pth')
    finally:
        # Flush and release the TensorBoard event file even if training aborts.
        writer.close()

    # Print the final sparsity
    print("the final sparsity: ", get_sparsity(model))
args.model_refine_sample_pixels, args.model_refine_threshold, args.model_refine_kernel_size) model = model.to(device).eval() model.load_state_dict(torch.load(args.model_checkpoint, map_location=device), strict=False) # Load video and background vid = VideoDataset(args.video_src) bgr = [Image.open(args.video_bgr).convert('RGB')] dataset = ZipDataset([vid, bgr], transforms=A.PairCompose([ A.PairApply( T.Resize(args.video_resize[::-1]) if args. video_resize else nn.Identity()), HomographicAlignment() if args.preprocess_alignment else A.PairApply(nn.Identity()), A.PairApply(T.ToTensor()) ])) if args.video_target_bgr: dataset = ZipDataset([ dataset, VideoDataset(args.video_target_bgr, transforms=T.ToTensor()) ]) # Create output directory if os.path.exists(args.output_dir): if input(f'Directory {args.output_dir} already exists. Override? [Y/N]: ' ).lower() == 'y': shutil.rmtree(args.output_dir)
# Build the refinement matting network in inference mode on the target
# device, then restore its weights (non-strict: mismatched keys are tolerated).
model = MattingRefine(
    args.model_backbone,
    args.model_backbone_scale,
    args.model_refine_mode,
    args.model_refine_sample_pixels,
    args.model_refine_threshold,
    args.model_refine_kernel_size,
).to(device).eval()
model.load_state_dict(
    torch.load(args.model_checkpoint, map_location=device),
    strict=False,
)

# Load images: source images are paired with their background images, and
# both datasets must have the same length.
dataset = ZipDataset(
    [
        ImagesDataset(args.images_src),
        ImagesDataset(args.images_bgr),
    ],
    assert_equal_length=True,
    transforms=A.PairCompose([
        # Optional homographic alignment step; identity when disabled.
        HomographicAlignment() if args.preprocess_alignment else A.PairApply(nn.Identity()),
        A.PairApply(T.ToTensor()),
    ]),
)
dataloader = DataLoader(
    dataset,
    batch_size=1,
    num_workers=args.num_workers,
    pin_memory=True,
)

# Create output directory
# An existing directory is removed only when args.y is set or the user
# answers 'y' to the prompt.
if os.path.exists(args.output_dir):
    if args.y or input(
            f'Directory {args.output_dir} already exists. Override? [Y/N]: '
    ).lower() == 'y':
        shutil.rmtree(args.output_dir)
strict=False) # Validation DataLoader dataset_valid = ZipDataset([ ZipDataset([ ImagesDataset(DATA_PATH[args.dataset_name]['valid']['pha'], mode='L'), ImagesDataset(DATA_PATH[args.dataset_name]['valid']['fgr'], mode='RGB') ], transforms=A.PairCompose([ A.PairRandomAffineAndResize((512, 512), degrees=(-5, 5), translate=(0.1, 0.1), scale=(0.3, 1), shear=(-5, 5)), A.PairApply(T.ToTensor()) ]), assert_equal_length=True), ImagesDataset(DATA_PATH['backgrounds']['valid'], mode='RGB', transforms=T.Compose([ A.RandomAffineAndResize((512, 512), degrees=(-5, 5), translate=(0.1, 0.1), scale=(1, 1.2), shear=(-5, 5)), T.ToTensor() ])), ]) dataloader_valid = DataLoader(dataset_valid, pin_memory=True,
def train_worker(rank, addr, port):
    """Run one distributed training worker for a ``MattingRefine`` model.

    Configuration is read from the module-level ``args``, ``DATA_PATH`` and
    ``distributed_num_gpus``.  Each worker trains on its own contiguous slice
    of the training set; only rank 0 builds the validation loader, writes
    TensorBoard logs, runs validation and saves checkpoints.

    Args:
        rank: Rank of this worker in the process group.  It is also used as
            the CUDA device index for the model and the batches.
        addr: Value exported as ``MASTER_ADDR`` (must be a string, since it is
            stored in ``os.environ``).
        port: Value exported as ``MASTER_PORT`` (must be a string, since it is
            stored in ``os.environ``).

    Returns:
        None.
    """
    # Distributed Setup
    os.environ['MASTER_ADDR'] = addr
    os.environ['MASTER_PORT'] = port
    dist.init_process_group("nccl", rank=rank, world_size=distributed_num_gpus)

    # Training DataLoader
    dataset_train = ZipDataset([
        ZipDataset([
            ImagesDataset(DATA_PATH[args.dataset_name]['train']['pha'], mode='L'),
            ImagesDataset(DATA_PATH[args.dataset_name]['train']['fgr'], mode='RGB'),
        ], transforms=A.PairCompose([
            A.PairRandomAffineAndResize((2048, 2048), degrees=(-5, 5), translate=(0.1, 0.1),
                                        scale=(0.3, 1), shear=(-5, 5)),
            A.PairRandomHorizontalFlip(),
            A.PairRandomBoxBlur(0.1, 5),
            A.PairRandomSharpen(0.1),
            # Colour jitter is applied to index 1 of the pair only (the 'fgr' dataset).
            A.PairApplyOnlyAtIndices([1], T.ColorJitter(0.15, 0.15, 0.15, 0.05)),
            A.PairApply(T.ToTensor())
        ]), assert_equal_length=True),
        ImagesDataset(DATA_PATH['backgrounds']['train'], mode='RGB', transforms=T.Compose([
            A.RandomAffineAndResize((2048, 2048), degrees=(-5, 5), translate=(0.1, 0.1),
                                    scale=(1, 2), shear=(-5, 5)),
            T.RandomHorizontalFlip(),
            A.RandomBoxBlur(0.1, 5),
            A.RandomSharpen(0.1),
            T.ColorJitter(0.15, 0.15, 0.15, 0.05),
            T.ToTensor()
        ])),
    ])
    # Give every worker an equally sized, contiguous slice of the dataset;
    # any remainder samples at the end are left unused.
    dataset_train_len_per_gpu_worker = len(dataset_train) // distributed_num_gpus
    dataset_train = Subset(dataset_train, range(
        rank * dataset_train_len_per_gpu_worker,
        (rank + 1) * dataset_train_len_per_gpu_worker))
    # args.batch_size and args.num_workers are totals, split across workers.
    dataloader_train = DataLoader(dataset_train,
                                  shuffle=True,
                                  pin_memory=True,
                                  drop_last=True,
                                  batch_size=args.batch_size // distributed_num_gpus,
                                  num_workers=args.num_workers // distributed_num_gpus)

    # Validation DataLoader
    if rank == 0:
        dataset_valid = ZipDataset([
            ZipDataset([
                ImagesDataset(DATA_PATH[args.dataset_name]['valid']['pha'], mode='L'),
                ImagesDataset(DATA_PATH[args.dataset_name]['valid']['fgr'], mode='RGB')
            ], transforms=A.PairCompose([
                A.PairRandomAffineAndResize((2048, 2048), degrees=(-5, 5),
                                            translate=(0.1, 0.1), scale=(0.3, 1),
                                            shear=(-5, 5)),
                A.PairApply(T.ToTensor())
            ]), assert_equal_length=True),
            ImagesDataset(DATA_PATH['backgrounds']['valid'], mode='RGB', transforms=T.Compose([
                A.RandomAffineAndResize((2048, 2048), degrees=(-5, 5), translate=(0.1, 0.1),
                                        scale=(1, 1.2), shear=(-5, 5)),
                T.ToTensor()
            ])),
        ])
        dataset_valid = SampleDataset(dataset_valid, 50)
        dataloader_valid = DataLoader(dataset_valid,
                                      pin_memory=True,
                                      drop_last=True,
                                      batch_size=args.batch_size // distributed_num_gpus,
                                      num_workers=args.num_workers // distributed_num_gpus)

    # Model
    model = MattingRefine(args.model_backbone,
                          args.model_backbone_scale,
                          args.model_refine_mode,
                          args.model_refine_sample_pixels,
                          args.model_refine_thresholding,
                          args.model_refine_kernel_size).to(rank)
    model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
    # The DDP wrapper is used for the forward pass; the bare `model` is used
    # for the optimizer, validation and state_dict so keys carry no wrapper prefix.
    model_distributed = nn.parallel.DistributedDataParallel(model, device_ids=[rank])

    if args.model_last_checkpoint is not None:
        # Map the checkpoint onto this worker's own GPU.  Without map_location
        # torch.load restores tensors to the device they were saved from, so
        # every rank would allocate the checkpoint on that single GPU.
        load_matched_state_dict(
            model,
            torch.load(args.model_last_checkpoint, map_location=f'cuda:{rank}'))

    optimizer = Adam([
        {'params': model.backbone.parameters(), 'lr': 5e-5},
        {'params': model.aspp.parameters(), 'lr': 5e-5},
        {'params': model.decoder.parameters(), 'lr': 1e-4},
        {'params': model.refiner.parameters(), 'lr': 3e-4},
    ])
    scaler = GradScaler()

    # Logging and checkpoints (rank 0 only)
    writer = None
    if rank == 0:
        # exist_ok avoids the check-then-create race of exists() + makedirs().
        os.makedirs(f'checkpoint/{args.model_name}', exist_ok=True)
        writer = SummaryWriter(f'log/{args.model_name}')

    try:
        # Run loop
        for epoch in range(args.epoch_start, args.epoch_end):
            for i, ((true_pha, true_fgr), true_bgr) in enumerate(tqdm(dataloader_train)):
                # Global step counter across epochs.
                step = epoch * len(dataloader_train) + i

                true_pha = true_pha.to(rank, non_blocking=True)
                true_fgr = true_fgr.to(rank, non_blocking=True)
                true_bgr = true_bgr.to(rank, non_blocking=True)
                true_pha, true_fgr, true_bgr = random_crop(true_pha, true_fgr, true_bgr)

                true_src = true_bgr.clone()

                # Augment with shadow (each sample is selected with probability 0.3)
                aug_shadow_idx = torch.rand(len(true_src)) < 0.3
                if aug_shadow_idx.any():
                    aug_shadow = true_pha[aug_shadow_idx].mul(0.3 * random.random())
                    aug_shadow = T.RandomAffine(degrees=(-5, 5), translate=(0.2, 0.2),
                                                scale=(0.5, 1.5), shear=(-5, 5))(aug_shadow)
                    aug_shadow = kornia.filters.box_blur(
                        aug_shadow, (random.choice(range(20, 40)),) * 2)
                    true_src[aug_shadow_idx] = true_src[aug_shadow_idx].sub_(
                        aug_shadow).clamp_(0, 1)
                    del aug_shadow
                del aug_shadow_idx

                # Composite foreground onto source
                true_src = true_fgr * true_pha + true_src * (1 - true_pha)

                # Augment with noise (the same samples are perturbed in src and bgr,
                # each with its own noise draw and noise scale)
                aug_noise_idx = torch.rand(len(true_src)) < 0.4
                if aug_noise_idx.any():
                    true_src[aug_noise_idx] = true_src[aug_noise_idx].add_(
                        torch.randn_like(true_src[aug_noise_idx]).mul_(
                            0.03 * random.random())).clamp_(0, 1)
                    true_bgr[aug_noise_idx] = true_bgr[aug_noise_idx].add_(
                        torch.randn_like(true_bgr[aug_noise_idx]).mul_(
                            0.03 * random.random())).clamp_(0, 1)
                del aug_noise_idx

                # Augment background with jitter
                aug_jitter_idx = torch.rand(len(true_src)) < 0.8
                if aug_jitter_idx.any():
                    true_bgr[aug_jitter_idx] = kornia.augmentation.ColorJitter(
                        0.18, 0.18, 0.18, 0.1)(true_bgr[aug_jitter_idx])
                del aug_jitter_idx

                # Augment background with affine
                aug_affine_idx = torch.rand(len(true_bgr)) < 0.3
                if aug_affine_idx.any():
                    true_bgr[aug_affine_idx] = T.RandomAffine(
                        degrees=(-1, 1), translate=(0.01, 0.01))(true_bgr[aug_affine_idx])
                del aug_affine_idx

                with autocast():
                    # The sixth model output is not used during training.
                    pred_pha, pred_fgr, pred_pha_sm, pred_fgr_sm, pred_err_sm, _ = \
                        model_distributed(true_src, true_bgr)
                    loss = compute_loss(pred_pha, pred_fgr, pred_pha_sm, pred_fgr_sm,
                                        pred_err_sm, true_pha, true_fgr)

                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad()

                if rank == 0:
                    if (i + 1) % args.log_train_loss_interval == 0:
                        writer.add_scalar('loss', loss, step)

                    if (i + 1) % args.log_train_images_interval == 0:
                        writer.add_image('train_pred_pha',
                                         make_grid(pred_pha, nrow=5), step)
                        writer.add_image('train_pred_fgr',
                                         make_grid(pred_fgr, nrow=5), step)
                        writer.add_image('train_pred_com',
                                         make_grid(pred_fgr * pred_pha, nrow=5), step)
                        writer.add_image('train_pred_err',
                                         make_grid(pred_err_sm, nrow=5), step)
                        writer.add_image('train_true_src',
                                         make_grid(true_src, nrow=5), step)

                    # Drop references to this batch before validation / checkpointing.
                    del true_pha, true_fgr, true_src, true_bgr
                    del pred_pha, pred_fgr, pred_pha_sm, pred_fgr_sm, pred_err_sm

                    if (i + 1) % args.log_valid_interval == 0:
                        valid(model, dataloader_valid, writer, step)

                    if (step + 1) % args.checkpoint_interval == 0:
                        torch.save(
                            model.state_dict(),
                            f'checkpoint/{args.model_name}/epoch-{epoch}-iter-{step}.pth')

            if rank == 0:
                torch.save(model.state_dict(),
                           f'checkpoint/{args.model_name}/epoch-{epoch}.pth')
    finally:
        # Flush and release the TensorBoard event file even if training aborts.
        if writer is not None:
            writer.close()

    # Clean up
    dist.destroy_process_group()