def pruning():
    # Training DataLoader
    dataset_train = ZipDataset([
        ZipDataset([
            ImagesDataset(DATA_PATH[args.dataset_name]['train']['pha'],
                          mode='L'),
            ImagesDataset(DATA_PATH[args.dataset_name]['train']['fgr'],
                          mode='RGB'),
        ],
                   transforms=A.PairCompose([
                       A.PairRandomAffineAndResize((512, 512),
                                                   degrees=(-5, 5),
                                                   translate=(0.1, 0.1),
                                                   scale=(0.4, 1),
                                                   shear=(-5, 5)),
                       A.PairRandomHorizontalFlip(),
                       A.PairRandomBoxBlur(0.1, 5),
                       A.PairRandomSharpen(0.1),
                       A.PairApplyOnlyAtIndices([1],
                                                T.ColorJitter(
                                                    0.15, 0.15, 0.15, 0.05)),
                       A.PairApply(T.ToTensor())
                   ]),
                   assert_equal_length=True),
        ImagesDataset(DATA_PATH['backgrounds']['train'],
                      mode='RGB',
                      transforms=T.Compose([
                          A.RandomAffineAndResize((512, 512),
                                                  degrees=(-5, 5),
                                                  translate=(0.1, 0.1),
                                                  scale=(1, 2),
                                                  shear=(-5, 5)),
                          T.RandomHorizontalFlip(),
                          A.RandomBoxBlur(0.1, 5),
                          A.RandomSharpen(0.1),
                          T.ColorJitter(0.15, 0.15, 0.15, 0.05),
                          T.ToTensor()
                      ])),
    ])
    dataloader_train = DataLoader(dataset_train,
                                  shuffle=True,
                                  batch_size=args.batch_size,
                                  num_workers=args.num_workers,
                                  pin_memory=True)

    # Validation DataLoader
    dataset_valid = ZipDataset([
        ZipDataset([
            ImagesDataset(DATA_PATH[args.dataset_name]['valid']['pha'],
                          mode='L'),
            ImagesDataset(DATA_PATH[args.dataset_name]['valid']['fgr'],
                          mode='RGB')
        ],
                   transforms=A.PairCompose([
                       A.PairRandomAffineAndResize((512, 512),
                                                   degrees=(-5, 5),
                                                   translate=(0.1, 0.1),
                                                   scale=(0.3, 1),
                                                   shear=(-5, 5)),
                       A.PairApply(T.ToTensor())
                   ]),
                   assert_equal_length=True),
        ImagesDataset(DATA_PATH['backgrounds']['valid'],
                      mode='RGB',
                      transforms=T.Compose([
                          A.RandomAffineAndResize((512, 512),
                                                  degrees=(-5, 5),
                                                  translate=(0.1, 0.1),
                                                  scale=(1, 1.2),
                                                  shear=(-5, 5)),
                          T.ToTensor()
                      ])),
    ])
    dataset_valid = SampleDataset(dataset_valid, 50)
    dataloader_valid = DataLoader(dataset_valid,
                                  pin_memory=True,
                                  batch_size=args.batch_size,
                                  num_workers=args.num_workers)

    # Model
    model = MattingBase(args.model_backbone).cuda()

    if args.model_last_checkpoint is not None:
        load_matched_state_dict(model, torch.load(args.model_last_checkpoint))
    elif args.model_pretrain_initialization is not None:
        model.load_pretrained_deeplabv3_state_dict(
            torch.load(args.model_pretrain_initialization)['model_state'])

    # Print the initial sparsity. The commented-out block below is the simple
    # one-shot baseline: L1 unstructured pruning of every 2D-conv layer at a
    # fixed ratio.
    # for name, module in model.named_modules():
    #     # prune 40% of connections in all 2D-conv layers
    #     if isinstance(module, torch.nn.Conv2d):
    #         # DNSUnst(module, name='weight')
    #         prune.l1_unstructured(module, name='weight', amount=0.4)
    #         prune.remove(module, 'weight')
    print("the original sparsity: ", get_sparsity(model))

    optimizer = Adam([{
        'params': model.backbone.parameters(),
        'lr': 1e-4
    }, {
        'params': model.aspp.parameters(),
        'lr': 5e-4
    }, {
        'params': model.decoder.parameters(),
        'lr': 5e-4
    }])
    scaler = GradScaler()

    # Logging and checkpoints
    if not os.path.exists(f'checkpoint/{args.model_name}'):
        os.makedirs(f'checkpoint/{args.model_name}')
    writer = SummaryWriter(f'log/{args.model_name}')

    # Run loop
    for epoch in range(args.epoch_start, args.epoch_end):
        for i, ((true_pha, true_fgr),
                true_bgr) in enumerate(tqdm(dataloader_train)):

            step = epoch * len(dataloader_train) + i

            true_pha = true_pha.cuda(non_blocking=True)
            true_fgr = true_fgr.cuda(non_blocking=True)
            true_bgr = true_bgr.cuda(non_blocking=True)
            true_pha, true_fgr, true_bgr = random_crop(true_pha, true_fgr,
                                                       true_bgr)

            true_src = true_bgr.clone()

            # Augment with shadow
            aug_shadow_idx = torch.rand(len(true_src)) < 0.3
            if aug_shadow_idx.any():
                aug_shadow = true_pha[aug_shadow_idx].mul(0.3 *
                                                          random.random())
                aug_shadow = T.RandomAffine(degrees=(-5, 5),
                                            translate=(0.2, 0.2),
                                            scale=(0.5, 1.5),
                                            shear=(-5, 5))(aug_shadow)
                aug_shadow = kornia.filters.box_blur(
                    aug_shadow, (random.choice(range(20, 40)), ) * 2)
                true_src[aug_shadow_idx] = true_src[aug_shadow_idx].sub_(
                    aug_shadow).clamp_(0, 1)
                del aug_shadow
            del aug_shadow_idx

            # Composite foreground onto source
            true_src = true_fgr * true_pha + true_src * (1 - true_pha)

            # Augment with noise
            aug_noise_idx = torch.rand(len(true_src)) < 0.4
            if aug_noise_idx.any():
                true_src[aug_noise_idx] = true_src[aug_noise_idx].add_(
                    torch.randn_like(true_src[aug_noise_idx]).mul_(
                        0.03 * random.random())).clamp_(0, 1)
                true_bgr[aug_noise_idx] = true_bgr[aug_noise_idx].add_(
                    torch.randn_like(true_bgr[aug_noise_idx]).mul_(
                        0.03 * random.random())).clamp_(0, 1)
            del aug_noise_idx

            # Augment background with jitter
            aug_jitter_idx = torch.rand(len(true_src)) < 0.8
            if aug_jitter_idx.any():
                true_bgr[aug_jitter_idx] = kornia.augmentation.ColorJitter(
                    0.18, 0.18, 0.18, 0.1)(true_bgr[aug_jitter_idx])
            del aug_jitter_idx

            # Augment background with affine
            aug_affine_idx = torch.rand(len(true_bgr)) < 0.3
            if aug_affine_idx.any():
                true_bgr[aug_affine_idx] = T.RandomAffine(
                    degrees=(-1, 1),
                    translate=(0.01, 0.01))(true_bgr[aug_affine_idx])
            del aug_affine_idx

            with autocast():
                pred_pha, pred_fgr, pred_err = model(true_src, true_bgr)[:3]
                loss = compute_loss(pred_pha, pred_fgr, pred_err, true_pha,
                                    true_fgr)

            scaler.scale(loss).backward()

            # Pruning: on the first batch of each epoch, search for a pruning
            # configuration with NCS and apply it; on later batches, fine-tune
            # under that configuration. NCS_C, PreTPUnst, PreDNSUnst and
            # Pruned are project-specific helpers. (The original code reset
            # best_c to zeros on every iteration, which discarded the searched
            # configuration; it is kept across iterations here.)
            if i == 0:
                ncs = NCS_C(model, true_src, true_bgr, true_pha, true_fgr)
                best_c = ncs.run(model, true_src, true_bgr, true_pha, true_fgr)
                PreTPUnst(model, best_c)
            else:
                # Adjust using the configuration found on the first batch
                PreDNSUnst(model, best_c)

            scaler.step(optimizer)
            Pruned(model)

            scaler.update()
            optimizer.zero_grad()

            if (i + 1) % args.log_train_loss_interval == 0:
                writer.add_scalar('loss', loss, step)

            if (i + 1) % args.log_train_images_interval == 0:
                writer.add_image('train_pred_pha', make_grid(pred_pha, nrow=5),
                                 step)
                writer.add_image('train_pred_fgr', make_grid(pred_fgr, nrow=5),
                                 step)
                writer.add_image('train_pred_com',
                                 make_grid(pred_fgr * pred_pha, nrow=5), step)
                writer.add_image('train_pred_err', make_grid(pred_err, nrow=5),
                                 step)
                writer.add_image('train_true_src', make_grid(true_src, nrow=5),
                                 step)
                writer.add_image('train_true_bgr', make_grid(true_bgr, nrow=5),
                                 step)

            del true_pha, true_fgr, true_bgr, true_src
            del pred_pha, pred_fgr, pred_err
            del loss

            if (i + 1) % args.log_valid_interval == 0:
                valid(model, dataloader_valid, writer, step)

            if (step + 1) % args.checkpoint_interval == 0:
                torch.save(
                    model.state_dict(),
                    f'checkpoint/{args.model_name}/epoch-{epoch}-iter-{step}.pth'
                )

        print("the sparsity of epoch {} : {}".format(epoch,
                                                     get_sparsity(model)))
        torch.save(model.state_dict(),
                   f'checkpoint/{args.model_name}/epoch-{epoch}.pth')

    # Print the final sparsity
    print("the final sparsity: ", get_sparsity(model))
Example #2

model = MattingRefine(args.model_backbone,
                      args.model_backbone_scale,
                      args.model_refine_mode,
                      args.model_refine_sample_pixels,
                      args.model_refine_threshold,
                      args.model_refine_kernel_size)

model = model.to(device).eval()
model.load_state_dict(torch.load(args.model_checkpoint, map_location=device),
                      strict=False)

# Load video and background
vid = VideoDataset(args.video_src)
bgr = [Image.open(args.video_bgr).convert('RGB')]
dataset = ZipDataset([vid, bgr],
                     transforms=A.PairCompose([
                         A.PairApply(
                             T.Resize(args.video_resize[::-1])
                             if args.video_resize else nn.Identity()),
                         HomographicAlignment() if args.preprocess_alignment
                         else A.PairApply(nn.Identity()),
                         A.PairApply(T.ToTensor())
                     ]))
if args.video_target_bgr:
    dataset = ZipDataset([
        dataset,
        VideoDataset(args.video_target_bgr, transforms=T.ToTensor())
    ])

# Create output directory
if os.path.exists(args.output_dir):
    if input(f'Directory {args.output_dir} already exists. Override? [Y/N]: '
             ).lower() == 'y':
        shutil.rmtree(args.output_dir)
    else:
        exit()

Example #3

model = MattingRefine(args.model_backbone,
                      args.model_backbone_scale,
                      args.model_refine_mode,
                      args.model_refine_sample_pixels,
                      args.model_refine_threshold,
                      args.model_refine_kernel_size)

model = model.to(device).eval()
model.load_state_dict(torch.load(args.model_checkpoint, map_location=device),
                      strict=False)

# Load images
dataset = ZipDataset([
    ImagesDataset(args.images_src),
    ImagesDataset(args.images_bgr),
],
                     assert_equal_length=True,
                     transforms=A.PairCompose([
                         HomographicAlignment() if args.preprocess_alignment
                         else A.PairApply(nn.Identity()),
                         A.PairApply(T.ToTensor())
                     ]))
dataloader = DataLoader(dataset,
                        batch_size=1,
                        num_workers=args.num_workers,
                        pin_memory=True)

# Create output directory
if os.path.exists(args.output_dir):
    if args.y or input(
            f'Directory {args.output_dir} already exists. Override? [Y/N]: '
    ).lower() == 'y':
        shutil.rmtree(args.output_dir)
    else:
        exit()

Example #4

model = MattingRefine(args.model_backbone,
                      args.model_backbone_scale,
                      args.model_refine_mode,
                      args.model_refine_sample_pixels,
                      args.model_refine_threshold,
                      args.model_refine_kernel_size)

model = model.to(device).eval()
model.load_state_dict(torch.load(args.model_checkpoint, map_location=device),
                      strict=False)

# Validation DataLoader
dataset_valid = ZipDataset([
    ZipDataset([
        ImagesDataset(DATA_PATH[args.dataset_name]['valid']['pha'], mode='L'),
        ImagesDataset(DATA_PATH[args.dataset_name]['valid']['fgr'], mode='RGB')
    ],
               transforms=A.PairCompose([
                   A.PairRandomAffineAndResize((512, 512),
                                               degrees=(-5, 5),
                                               translate=(0.1, 0.1),
                                               scale=(0.3, 1),
                                               shear=(-5, 5)),
                   A.PairApply(T.ToTensor())
               ]),
               assert_equal_length=True),
    ImagesDataset(DATA_PATH['backgrounds']['valid'],
                  mode='RGB',
                  transforms=T.Compose([
                      A.RandomAffineAndResize((512, 512),
                                              degrees=(-5, 5),
                                              translate=(0.1, 0.1),
                                              scale=(1, 1.2),
                                              shear=(-5, 5)),
                      T.ToTensor()
                  ])),
])
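
Note: the valid helper called in Examples #1 and #5 is not included in this
listing. A minimal sketch for the base model, assuming it averages the
training loss over the validation loader and logs it to TensorBoard:

def valid(model, dataloader, writer, step):
    model.eval()
    loss_total, loss_count = 0.0, 0
    with torch.no_grad():
        for (true_pha, true_fgr), true_bgr in dataloader:
            true_pha = true_pha.cuda(non_blocking=True)
            true_fgr = true_fgr.cuda(non_blocking=True)
            true_bgr = true_bgr.cuda(non_blocking=True)
            # Composite the foreground over the background, as in training
            true_src = true_fgr * true_pha + true_bgr * (1 - true_pha)
            pred_pha, pred_fgr, pred_err = model(true_src, true_bgr)[:3]
            loss = compute_loss(pred_pha, pred_fgr, pred_err,
                                true_pha, true_fgr)
            loss_total += loss.item() * len(true_src)
            loss_count += len(true_src)
    writer.add_scalar('valid_loss', loss_total / loss_count, step)
    model.train()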
Example #5
def train_worker(rank, addr, port):

    # Distributed Setup
    os.environ['MASTER_ADDR'] = addr
    os.environ['MASTER_PORT'] = port
    dist.init_process_group("nccl", rank=rank, world_size=distributed_num_gpus)

    # Training DataLoader
    dataset_train = ZipDataset([
        ZipDataset([
            ImagesDataset(DATA_PATH[args.dataset_name]
                          ['train']['pha'], mode='L'),
            ImagesDataset(DATA_PATH[args.dataset_name]
                          ['train']['fgr'], mode='RGB'),
        ], transforms=A.PairCompose([
            A.PairRandomAffineAndResize(
                (2048, 2048), degrees=(-5, 5), translate=(0.1, 0.1), scale=(0.3, 1), shear=(-5, 5)),
            A.PairRandomHorizontalFlip(),
            A.PairRandomBoxBlur(0.1, 5),
            A.PairRandomSharpen(0.1),
            A.PairApplyOnlyAtIndices(
                [1], T.ColorJitter(0.15, 0.15, 0.15, 0.05)),
            A.PairApply(T.ToTensor())
        ]), assert_equal_length=True),
        ImagesDataset(DATA_PATH['backgrounds']['train'], mode='RGB', transforms=T.Compose([
            A.RandomAffineAndResize(
                (2048, 2048), degrees=(-5, 5), translate=(0.1, 0.1), scale=(1, 2), shear=(-5, 5)),
            T.RandomHorizontalFlip(),
            A.RandomBoxBlur(0.1, 5),
            A.RandomSharpen(0.1),
            T.ColorJitter(0.15, 0.15, 0.15, 0.05),
            T.ToTensor()
        ])),
    ])
    dataset_train_len_per_gpu_worker = int(
        len(dataset_train) / distributed_num_gpus)
    dataset_train = Subset(dataset_train, range(
        rank * dataset_train_len_per_gpu_worker, (rank + 1) * dataset_train_len_per_gpu_worker))
    dataloader_train = DataLoader(dataset_train,
                                  shuffle=True,
                                  pin_memory=True,
                                  drop_last=True,
                                  batch_size=args.batch_size // distributed_num_gpus,
                                  num_workers=args.num_workers // distributed_num_gpus)

    # Validation DataLoader
    if rank == 0:
        dataset_valid = ZipDataset([
            ZipDataset([
                ImagesDataset(DATA_PATH[args.dataset_name]
                              ['valid']['pha'], mode='L'),
                ImagesDataset(DATA_PATH[args.dataset_name]
                              ['valid']['fgr'], mode='RGB')
            ], transforms=A.PairCompose([
                A.PairRandomAffineAndResize(
                    (2048, 2048), degrees=(-5, 5), translate=(0.1, 0.1), scale=(0.3, 1), shear=(-5, 5)),
                A.PairApply(T.ToTensor())
            ]), assert_equal_length=True),
            ImagesDataset(DATA_PATH['backgrounds']['valid'], mode='RGB', transforms=T.Compose([
                A.RandomAffineAndResize(
                    (2048, 2048), degrees=(-5, 5), translate=(0.1, 0.1), scale=(1, 1.2), shear=(-5, 5)),
                T.ToTensor()
            ])),
        ])
        dataset_valid = SampleDataset(dataset_valid, 50)
        dataloader_valid = DataLoader(dataset_valid,
                                      pin_memory=True,
                                      drop_last=True,
                                      batch_size=args.batch_size // distributed_num_gpus,
                                      num_workers=args.num_workers // distributed_num_gpus)

    # Model
    model = MattingRefine(args.model_backbone,
                          args.model_backbone_scale,
                          args.model_refine_mode,
                          args.model_refine_sample_pixels,
                          args.model_refine_thresholding,
                          args.model_refine_kernel_size).to(rank)
    model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
    model_distributed = nn.parallel.DistributedDataParallel(
        model, device_ids=[rank])

    if args.model_last_checkpoint is not None:
        load_matched_state_dict(model, torch.load(args.model_last_checkpoint))

    optimizer = Adam([
        {'params': model.backbone.parameters(), 'lr': 5e-5},
        {'params': model.aspp.parameters(), 'lr': 5e-5},
        {'params': model.decoder.parameters(), 'lr': 1e-4},
        {'params': model.refiner.parameters(), 'lr': 3e-4},
    ])
    scaler = GradScaler()

    # Logging and checkpoints
    if rank == 0:
        if not os.path.exists(f'checkpoint/{args.model_name}'):
            os.makedirs(f'checkpoint/{args.model_name}')
        writer = SummaryWriter(f'log/{args.model_name}')

    # Run loop
    for epoch in range(args.epoch_start, args.epoch_end):
        for i, ((true_pha, true_fgr), true_bgr) in enumerate(tqdm(dataloader_train)):
            step = epoch * len(dataloader_train) + i

            true_pha = true_pha.to(rank, non_blocking=True)
            true_fgr = true_fgr.to(rank, non_blocking=True)
            true_bgr = true_bgr.to(rank, non_blocking=True)
            true_pha, true_fgr, true_bgr = random_crop(
                true_pha, true_fgr, true_bgr)

            true_src = true_bgr.clone()

            # Augment with shadow
            aug_shadow_idx = torch.rand(len(true_src)) < 0.3
            if aug_shadow_idx.any():
                aug_shadow = true_pha[aug_shadow_idx].mul(
                    0.3 * random.random())
                aug_shadow = T.RandomAffine(
                    degrees=(-5, 5), translate=(0.2, 0.2), scale=(0.5, 1.5), shear=(-5, 5))(aug_shadow)
                aug_shadow = kornia.filters.box_blur(
                    aug_shadow, (random.choice(range(20, 40)),) * 2)
                true_src[aug_shadow_idx] = true_src[aug_shadow_idx].sub_(
                    aug_shadow).clamp_(0, 1)
                del aug_shadow
            del aug_shadow_idx

            # Composite foreground onto source
            true_src = true_fgr * true_pha + true_src * (1 - true_pha)

            # Augment with noise
            aug_noise_idx = torch.rand(len(true_src)) < 0.4
            if aug_noise_idx.any():
                true_src[aug_noise_idx] = true_src[aug_noise_idx].add_(torch.randn_like(
                    true_src[aug_noise_idx]).mul_(0.03 * random.random())).clamp_(0, 1)
                true_bgr[aug_noise_idx] = true_bgr[aug_noise_idx].add_(torch.randn_like(
                    true_bgr[aug_noise_idx]).mul_(0.03 * random.random())).clamp_(0, 1)
            del aug_noise_idx

            # Augment background with jitter
            aug_jitter_idx = torch.rand(len(true_src)) < 0.8
            if aug_jitter_idx.any():
                true_bgr[aug_jitter_idx] = kornia.augmentation.ColorJitter(
                    0.18, 0.18, 0.18, 0.1)(true_bgr[aug_jitter_idx])
            del aug_jitter_idx

            # Augment background with affine
            aug_affine_idx = torch.rand(len(true_bgr)) < 0.3
            if aug_affine_idx.any():
                true_bgr[aug_affine_idx] = T.RandomAffine(
                    degrees=(-1, 1), translate=(0.01, 0.01))(true_bgr[aug_affine_idx])
            del aug_affine_idx

            with autocast():
                pred_pha, pred_fgr, pred_pha_sm, pred_fgr_sm, pred_err_sm, _ = model_distributed(
                    true_src, true_bgr)
                loss = compute_loss(
                    pred_pha, pred_fgr, pred_pha_sm, pred_fgr_sm, pred_err_sm, true_pha, true_fgr)

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()

            if rank == 0:
                if (i + 1) % args.log_train_loss_interval == 0:
                    writer.add_scalar('loss', loss, step)

                if (i + 1) % args.log_train_images_interval == 0:
                    writer.add_image('train_pred_pha',
                                     make_grid(pred_pha, nrow=5), step)
                    writer.add_image('train_pred_fgr',
                                     make_grid(pred_fgr, nrow=5), step)
                    writer.add_image('train_pred_com', make_grid(
                        pred_fgr * pred_pha, nrow=5), step)
                    writer.add_image('train_pred_err', make_grid(
                        pred_err_sm, nrow=5), step)
                    writer.add_image('train_true_src',
                                     make_grid(true_src, nrow=5), step)

                del true_pha, true_fgr, true_src, true_bgr
                del pred_pha, pred_fgr, pred_pha_sm, pred_fgr_sm, pred_err_sm

                if (i + 1) % args.log_valid_interval == 0:
                    valid(model, dataloader_valid, writer, step)

                if (step + 1) % args.checkpoint_interval == 0:
                    torch.save(
                        model.state_dict(),
                        f'checkpoint/{args.model_name}/epoch-{epoch}-iter-{step}.pth')

        if rank == 0:
            torch.save(model.state_dict(),
                       f'checkpoint/{args.model_name}/epoch-{epoch}.pth')

    # Clean up
    dist.destroy_process_group()
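
Note: train_worker expects one spawned process per GPU. A minimal launcher
sketch, assuming distributed_num_gpus is the same module-level value used
inside the worker; the address and port choice is arbitrary:

import random
import torch.multiprocessing as mp

if __name__ == '__main__':
    addr = 'localhost'
    port = str(random.randint(10000, 30000))  # MASTER_PORT must be a string
    # mp.spawn calls train_worker(rank, addr, port) with rank in
    # range(distributed_num_gpus)
    mp.spawn(train_worker,
             nprocs=distributed_num_gpus,
             args=(addr, port))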