def pruning():
    """Train ``MattingBase`` with NCS-guided unstructured pruning.

    Builds the composited-matting training/validation pipelines, then runs a
    mixed-precision training loop.  Once per epoch (at the first iteration)
    an NCS search (``NCS_C``) selects a pruning configuration ``best_c``
    which ``PreTPUnst`` applies; every later iteration re-applies/adjusts it
    with ``PreDNSUnst``, and ``Pruned(model)`` re-zeroes pruned weights after
    each optimizer step.  Losses and images go to TensorBoard and checkpoints
    are saved under ``checkpoint/<model_name>``.

    Relies on module-level ``args``, ``DATA_PATH`` and the project helpers
    (``ZipDataset``, ``ImagesDataset``, ``A``, ``T``, ``NCS_C``, ...).
    Takes no arguments and returns nothing.
    """
    # Training DataLoader: yields ((alpha, foreground), background) triples.
    # Alpha/foreground share one paired geometric augmentation; backgrounds
    # get an independent pipeline.
    dataset_train = ZipDataset([
        ZipDataset([
            ImagesDataset(DATA_PATH[args.dataset_name]['train']['pha'],
                          mode='L'),
            ImagesDataset(DATA_PATH[args.dataset_name]['train']['fgr'],
                          mode='RGB'),
        ],
                   transforms=A.PairCompose([
                       A.PairRandomAffineAndResize((512, 512),
                                                   degrees=(-5, 5),
                                                   translate=(0.1, 0.1),
                                                   scale=(0.4, 1),
                                                   shear=(-5, 5)),
                       A.PairRandomHorizontalFlip(),
                       A.PairRandomBoxBlur(0.1, 5),
                       A.PairRandomSharpen(0.1),
                       # Color jitter only on the foreground (index 1),
                       # never on the alpha matte.
                       A.PairApplyOnlyAtIndices([1],
                                                T.ColorJitter(
                                                    0.15, 0.15, 0.15, 0.05)),
                       A.PairApply(T.ToTensor())
                   ]),
                   assert_equal_length=True),
        ImagesDataset(DATA_PATH['backgrounds']['train'],
                      mode='RGB',
                      transforms=T.Compose([
                          A.RandomAffineAndResize((512, 512),
                                                  degrees=(-5, 5),
                                                  translate=(0.1, 0.1),
                                                  scale=(1, 2),
                                                  shear=(-5, 5)),
                          T.RandomHorizontalFlip(),
                          A.RandomBoxBlur(0.1, 5),
                          A.RandomSharpen(0.1),
                          T.ColorJitter(0.15, 0.15, 0.15, 0.05),
                          T.ToTensor()
                      ])),
    ])
    dataloader_train = DataLoader(dataset_train,
                                  shuffle=True,
                                  batch_size=args.batch_size,
                                  num_workers=args.num_workers,
                                  pin_memory=True)

    # Validation DataLoader: geometric resize only, no photometric jitter.
    dataset_valid = ZipDataset([
        ZipDataset([
            ImagesDataset(DATA_PATH[args.dataset_name]['valid']['pha'],
                          mode='L'),
            ImagesDataset(DATA_PATH[args.dataset_name]['valid']['fgr'],
                          mode='RGB')
        ],
                   transforms=A.PairCompose([
                       A.PairRandomAffineAndResize((512, 512),
                                                   degrees=(-5, 5),
                                                   translate=(0.1, 0.1),
                                                   scale=(0.3, 1),
                                                   shear=(-5, 5)),
                       A.PairApply(T.ToTensor())
                   ]),
                   assert_equal_length=True),
        ImagesDataset(DATA_PATH['backgrounds']['valid'],
                      mode='RGB',
                      transforms=T.Compose([
                          A.RandomAffineAndResize((512, 512),
                                                  degrees=(-5, 5),
                                                  translate=(0.1, 0.1),
                                                  scale=(1, 1.2),
                                                  shear=(-5, 5)),
                          T.ToTensor()
                      ])),
    ])
    # Validate on a fixed-size random sample to keep validation cheap.
    dataset_valid = SampleDataset(dataset_valid, 50)
    dataloader_valid = DataLoader(dataset_valid,
                                  pin_memory=True,
                                  batch_size=args.batch_size,
                                  num_workers=args.num_workers)

    # Model
    model = MattingBase(args.model_backbone).cuda()

    # Resume from a full checkpoint if given; otherwise optionally initialize
    # from pretrained DeepLabV3 weights.
    if args.model_last_checkpoint is not None:
        load_matched_state_dict(model, torch.load(args.model_last_checkpoint))
    elif args.model_pretrain_initialization is not None:
        model.load_pretrained_deeplabv3_state_dict(
            torch.load(args.model_pretrain_initialization)['model_state'])

    # Print the initial sparsity before any pruning is applied.
    print("the original sparsity: ", get_sparsity(model))

    # Lower learning rate for the pretrained backbone than for ASPP/decoder.
    optimizer = Adam([{
        'params': model.backbone.parameters(),
        'lr': 1e-4
    }, {
        'params': model.aspp.parameters(),
        'lr': 5e-4
    }, {
        'params': model.decoder.parameters(),
        'lr': 5e-4
    }])
    scaler = GradScaler()

    # Logging and checkpoints
    if not os.path.exists(f'checkpoint/{args.model_name}'):
        os.makedirs(f'checkpoint/{args.model_name}')
    writer = SummaryWriter(f'log/{args.model_name}')

    # Run loop
    for epoch in range(args.epoch_start, args.epoch_end):
        # Pruning configuration for this epoch.  BUG FIX: the original code
        # rebuilt `best_c` as zeros on *every* iteration and `del`-ed it at
        # the end of the loop body, so iterations with i > 0 always passed an
        # all-zero vector to PreDNSUnst instead of the NCS result computed at
        # i == 0.  Initialize once per epoch and keep it alive instead.
        best_c = np.zeros(187)

        for i, ((true_pha, true_fgr),
                true_bgr) in enumerate(tqdm(dataloader_train)):

            step = epoch * len(dataloader_train) + i

            true_pha = true_pha.cuda(non_blocking=True)
            true_fgr = true_fgr.cuda(non_blocking=True)
            true_bgr = true_bgr.cuda(non_blocking=True)
            true_pha, true_fgr, true_bgr = random_crop(true_pha, true_fgr,
                                                       true_bgr)

            true_src = true_bgr.clone()

            # Augment with shadow: darken ~30% of samples with a blurred,
            # randomly transformed copy of the alpha matte.
            aug_shadow_idx = torch.rand(len(true_src)) < 0.3
            if aug_shadow_idx.any():
                aug_shadow = true_pha[aug_shadow_idx].mul(0.3 *
                                                          random.random())
                aug_shadow = T.RandomAffine(degrees=(-5, 5),
                                            translate=(0.2, 0.2),
                                            scale=(0.5, 1.5),
                                            shear=(-5, 5))(aug_shadow)
                aug_shadow = kornia.filters.box_blur(
                    aug_shadow, (random.choice(range(20, 40)), ) * 2)
                true_src[aug_shadow_idx] = true_src[aug_shadow_idx].sub_(
                    aug_shadow).clamp_(0, 1)
                del aug_shadow
            del aug_shadow_idx

            # Composite foreground onto source
            true_src = true_fgr * true_pha + true_src * (1 - true_pha)

            # Augment with noise: add gaussian noise to ~40% of samples.
            aug_noise_idx = torch.rand(len(true_src)) < 0.4
            if aug_noise_idx.any():
                true_src[aug_noise_idx] = true_src[aug_noise_idx].add_(
                    torch.randn_like(true_src[aug_noise_idx]).mul_(
                        0.03 * random.random())).clamp_(0, 1)
                true_bgr[aug_noise_idx] = true_bgr[aug_noise_idx].add_(
                    torch.randn_like(true_bgr[aug_noise_idx]).mul_(
                        0.03 * random.random())).clamp_(0, 1)
            del aug_noise_idx

            # Augment background with jitter (~80% of samples) so the model
            # does not rely on an exact background match.
            aug_jitter_idx = torch.rand(len(true_src)) < 0.8
            if aug_jitter_idx.any():
                true_bgr[aug_jitter_idx] = kornia.augmentation.ColorJitter(
                    0.18, 0.18, 0.18, 0.1)(true_bgr[aug_jitter_idx])
            del aug_jitter_idx

            # Augment background with a slight affine misalignment (~30%).
            aug_affine_idx = torch.rand(len(true_bgr)) < 0.3
            if aug_affine_idx.any():
                true_bgr[aug_affine_idx] = T.RandomAffine(
                    degrees=(-1, 1),
                    translate=(0.01, 0.01))(true_bgr[aug_affine_idx])
            del aug_affine_idx

            # Forward + loss under autocast (mixed precision).
            with autocast():
                pred_pha, pred_fgr, pred_err = model(true_src, true_bgr)[:3]
                loss = compute_loss(pred_pha, pred_fgr, pred_err, true_pha,
                                    true_fgr)

            scaler.scale(loss).backward()

            # Pruning: search for a configuration at the first iteration of
            # the epoch, then keep re-applying/adjusting it afterwards.
            if i == 0:
                ncs = NCS_C(model, true_src, true_bgr, true_pha, true_fgr)
                best_c = ncs.run(model, true_src, true_bgr, true_pha, true_fgr)
                PreTPUnst(model, best_c)
            else:
                PreDNSUnst(model, best_c)

            scaler.step(optimizer)
            # Re-zero pruned weights after the optimizer update.
            Pruned(model)

            scaler.update()
            optimizer.zero_grad()

            if (i + 1) % args.log_train_loss_interval == 0:
                writer.add_scalar('loss', loss, step)

            if (i + 1) % args.log_train_images_interval == 0:
                writer.add_image('train_pred_pha', make_grid(pred_pha, nrow=5),
                                 step)
                writer.add_image('train_pred_fgr', make_grid(pred_fgr, nrow=5),
                                 step)
                writer.add_image('train_pred_com',
                                 make_grid(pred_fgr * pred_pha, nrow=5), step)
                writer.add_image('train_pred_err', make_grid(pred_err, nrow=5),
                                 step)
                writer.add_image('train_true_src', make_grid(true_src, nrow=5),
                                 step)
                writer.add_image('train_true_bgr', make_grid(true_bgr, nrow=5),
                                 step)

            # Free per-iteration tensors early to reduce peak GPU memory.
            del true_pha, true_fgr, true_bgr, true_src
            del pred_pha, pred_fgr, pred_err
            del loss

            if (i + 1) % args.log_valid_interval == 0:
                valid(model, dataloader_valid, writer, step)

            if (step + 1) % args.checkpoint_interval == 0:
                torch.save(
                    model.state_dict(),
                    f'checkpoint/{args.model_name}/epoch-{epoch}-iter-{step}.pth'
                )

        print("the sparsity of epoch {} : {}".format(epoch,
                                                     get_sparsity(model)))
        torch.save(model.state_dict(),
                   f'checkpoint/{args.model_name}/epoch-{epoch}.pth')

    # Print the final sparsity.
    print("the final sparsity: ", get_sparsity(model))
Exemple #2
0
        frames = frames.cpu()
        for i in range(frames.shape[0]):
            frame = frames[i]
            frame = to_pil_image(frame)
            frame.save(
                os.path.join(self.path,
                             str(index + i).zfill(5) + '.' + self.extension))


# --------------- Main ---------------

# Inference-script fragment: device + model selection + video/background
# loading.  NOTE(review): the snippet is truncated by extraction — `args`,
# the model classes, `VideoDataset` and `Image` are defined outside this span.
device = torch.device(args.device)

# Load model
# NOTE(review): two independent `if`s, not `if`/`elif`.  If model_type matches
# neither string, `model` is never bound and `.to(device)` below would raise
# NameError — confirm the argparse `choices` restrict it upstream.
if args.model_type == 'mattingbase':
    model = MattingBase(args.model_backbone)
if args.model_type == 'mattingrefine':
    model = MattingRefine(args.model_backbone, args.model_backbone_scale,
                          args.model_refine_mode,
                          args.model_refine_sample_pixels,
                          args.model_refine_threshold,
                          args.model_refine_kernel_size)

# strict=False tolerates missing/unexpected keys in the checkpoint.
model = model.to(device).eval()
model.load_state_dict(torch.load(args.model_checkpoint, map_location=device),
                      strict=False)

# Load video and background (single background frame, wrapped in a list).
vid = VideoDataset(args.video_src)
bgr = [Image.open(args.video_bgr).convert('RGB')]
dataset = ZipDataset([vid, bgr],
# Parse CLI arguments (the `parser` definition lies above the visible span).
args = parser.parse_args()


# NOTE(review): `assert` for CLI validation is stripped under `python -O`;
# raising SystemExit/ValueError would be more robust.  The type() comparison
# checks "both given or both omitted" — flagged, behavior left unchanged.
assert type(args.image_src) == type(args.image_bgr),  'Image source and background must be provided together.'
assert (not args.image_src) != (not args.resolution), 'Must provide either a resolution or an image and not both.'


# --------------- Run Loop ---------------


device = torch.device(args.device)

# Load model
# NOTE(review): independent `if`s — an unmatched model_type leaves `model`
# unbound; presumably argparse `choices` prevent that. Verify upstream.
if args.model_type == 'mattingbase':
    model = MattingBase(args.model_backbone)
if args.model_type == 'mattingrefine':
    model = MattingRefine(
        args.model_backbone,
        args.model_backbone_scale,
        args.model_refine_mode,
        args.model_refine_sample_pixels,
        args.model_refine_threshold,
        args.model_refine_kernel_size,
        refine_prevent_oversampling=False)

# Optional checkpoint; strict=False tolerates partial key matches.
if args.model_checkpoint:
    model.load_state_dict(torch.load(args.model_checkpoint), strict=False)
    
# Map the CLI precision string to a torch dtype (float16 branch is outside
# the visible span).
if args.precision == 'float32':
    precision = torch.float32
Exemple #4
0
# --------------- Main ---------------

# Limit intra-op CPU parallelism to the user-requested thread count.
torch.set_num_threads(args.num_threads)

device = "cuda" if torch.cuda.is_available() else "cpu"

# NOTE(review): `precision` is printed here but defined outside this span —
# confirm it is assigned before this point in the full script.
print(f"Using {args.model_type}")
print(f"Using {args.num_threads} threads")
print(f"Using {precision} precision")
print(f"Using {device} device")
print(f"Loading {args.model_checkpoint}")

# Load model
# NOTE(review): unlike sibling snippets, this MattingRefine call omits
# `model_refine_kernel_size` — presumably relying on a default; verify.
if args.model_type == "mattingbase":
    model = MattingBase(args.model_backbone)
if args.model_type == "mattingrefine":
    model = MattingRefine(
        args.model_backbone,
        args.model_backbone_scale,
        args.model_refine_mode,
        args.model_refine_sample_pixels,
        args.model_refine_threshold,
    )

if args.model_type == "jit":
    model = torch.jit.load(args.model_checkpoint)
elif args.model_type == "trt":
    # Not working yet
    # NOTE(review): `load_state_dict` returns a key-report, not the module,
    # so `model` would not be the TRTModule here — likely part of why this
    # branch is marked not working.
    model = TRTModule().load_state_dict(torch.load(args.model_checkpoint))
else:
Exemple #5
0
        self.index += frames.shape[0]
            
    def _add_batch(self, frames, index):
        """Write every frame of a batch to disk as a numbered image file.

        Each frame is moved to CPU, converted to a PIL image, and saved as
        ``<index + offset>`` zero-padded to five digits with the writer's
        configured ``self.extension``, inside ``self.path``.
        """
        batch = frames.cpu()
        for offset, tensor in enumerate(batch):
            image = to_pil_image(tensor)
            name = str(index + offset).zfill(5) + '.' + self.extension
            image.save(os.path.join(self.path, name))


# --------------- Main ---------------


# Load model
# NOTE(review): independent `if`s — an unmatched model_type leaves `model`
# unbound and the `.cuda()` call below would raise NameError.
if args.model_type == 'mattingbase':
    model = MattingBase(args.model_backbone)
if args.model_type == 'mattingrefine':
    model = MattingRefine(
        args.model_backbone,
        args.model_backbone_scale,
        args.model_refine_mode,
        args.model_refine_sample_pixels,
        args.model_refine_threshold,
        args.model_refine_kernel_size)

# CUDA-only variant (no map_location); strict=False tolerates key mismatches.
model = model.cuda().eval()
model.load_state_dict(torch.load(args.model_checkpoint), strict=False)


# Load video and background
vid = VideoDataset(args.video_src)
# CLI options (the `parser` construction lies above the visible span):
# inference device, numeric precision, optional validation pass, and the
# required output path.
parser.add_argument('--device', type=str, default='cpu')
parser.add_argument('--precision',
                    type=str,
                    default='float32',
                    choices=['float32', 'float16'])
parser.add_argument('--validate', action='store_true')
parser.add_argument('--output', type=str, required=True)

args = parser.parse_args()

# --------------- Main ---------------

# Load model
# NOTE(review): independent `if`s — an unmatched model_type leaves `model`
# unbound; presumably argparse `choices` prevent that upstream.
if args.model_type == 'mattingbase':
    model = MattingBase(args.model_backbone)
if args.model_type == 'mattingrefine':
    # Export-friendly variant: roi_align cropping and scatter_element
    # replacement are selected explicitly here, unlike the sibling snippets.
    model = MattingRefine(args.model_backbone,
                          args.model_backbone_scale,
                          args.model_refine_mode,
                          args.model_refine_sample_pixels,
                          args.model_refine_threshold,
                          args.model_refine_kernel_size,
                          refine_patch_crop_method='roi_align',
                          refine_patch_replace_method='scatter_element')

# strict=False tolerates missing/unexpected checkpoint keys.
model.load_state_dict(torch.load(args.model_checkpoint,
                                 map_location=args.device),
                      strict=False)
precision = {
    'float32': torch.float32,