示例#1
0
    def __init__(self, dsets, layer_types, net_spec, preps, xforms, augs):
        """
        Build one SampleDataset per dataset entry plus a data augmentor.

        Args:
            dsets:       Iterable of dataset entries; each entry is handed to
                         SampleDataset as its first argument.
            layer_types: Forwarded unchanged to every SampleDataset.
            net_spec:    Net specification, forwarded to every SampleDataset.
            preps:       Forwarded unchanged to every SampleDataset.
            xforms:      Forwarded unchanged to every SampleDataset.
            augs:        Passed to DataAugmentor to build ``self._data_aug``.
        """
        # Datasets are kept in the same order as the entries in ``dsets``.
        self.datasets = [
            SampleDataset(entry, layer_types, net_spec, preps, xforms)
            for entry in dsets
        ]

        self._data_aug = DataAugmentor(augs)
示例#2
0
def load_data(args):
    """Build the test dataset and a DataLoader over it.

    Args:
        args: Namespace-like object. Only ``args.data_path`` is read; it is
            passed to ``SampleDataset``.

    Returns:
        Tuple ``(test_dataset, test_loader)``. The loader is created with the
        ``BATCH_SIZE`` name defined outside this function and 4 workers.
    """
    # Only the data path is needed here; class count and image size are not
    # used when building the loader, so they are not read from ``args``.
    test_dataset = SampleDataset(args.data_path)

    test_loader = DataLoader(
        test_dataset,
        batch_size=BATCH_SIZE,
        num_workers=4,
    )

    print('test_dataset : {}, test_loader : {}'.format(len(test_dataset),
                                                       len(test_loader)))

    return test_dataset, test_loader
def pruning():
    """Train a MattingBase model while applying pruning helpers each step.

    The function reads its configuration from the ``args`` and ``DATA_PATH``
    names defined outside this function. It:

    1. builds training and validation datasets by zipping ('pha', 'fgr')
       image pairs with background images;
    2. creates ``MattingBase`` on CUDA and optionally loads a checkpoint or a
       pretrained initialization;
    3. runs the training loop with Adam, ``autocast`` and ``GradScaler``,
       calling ``NCS_C`` / ``PreTPUnst`` on the first batch of every epoch and
       ``PreDNSUnst`` on the others, followed by ``Pruned`` after each
       optimizer step;
    4. logs to a ``SummaryWriter`` under ``log/<model_name>`` and writes
       checkpoints under ``checkpoint/<model_name>``.

    Returns:
        None.
    """
    # Training DataLoader
    dataset_train = ZipDataset([
        ZipDataset([
            ImagesDataset(DATA_PATH[args.dataset_name]['train']['pha'],
                          mode='L'),
            ImagesDataset(DATA_PATH[args.dataset_name]['train']['fgr'],
                          mode='RGB'),
        ],
                   transforms=A.PairCompose([
                       A.PairRandomAffineAndResize((512, 512),
                                                   degrees=(-5, 5),
                                                   translate=(0.1, 0.1),
                                                   scale=(0.4, 1),
                                                   shear=(-5, 5)),
                       A.PairRandomHorizontalFlip(),
                       A.PairRandomBoxBlur(0.1, 5),
                       A.PairRandomSharpen(0.1),
                       A.PairApplyOnlyAtIndices([1],
                                                T.ColorJitter(
                                                    0.15, 0.15, 0.15, 0.05)),
                       A.PairApply(T.ToTensor())
                   ]),
                   assert_equal_length=True),
        ImagesDataset(DATA_PATH['backgrounds']['train'],
                      mode='RGB',
                      transforms=T.Compose([
                          A.RandomAffineAndResize((512, 512),
                                                  degrees=(-5, 5),
                                                  translate=(0.1, 0.1),
                                                  scale=(1, 2),
                                                  shear=(-5, 5)),
                          T.RandomHorizontalFlip(),
                          A.RandomBoxBlur(0.1, 5),
                          A.RandomSharpen(0.1),
                          T.ColorJitter(0.15, 0.15, 0.15, 0.05),
                          T.ToTensor()
                      ])),
    ])
    dataloader_train = DataLoader(dataset_train,
                                  shuffle=True,
                                  batch_size=args.batch_size,
                                  num_workers=args.num_workers,
                                  pin_memory=True)

    # Validation DataLoader
    dataset_valid = ZipDataset([
        ZipDataset([
            ImagesDataset(DATA_PATH[args.dataset_name]['valid']['pha'],
                          mode='L'),
            ImagesDataset(DATA_PATH[args.dataset_name]['valid']['fgr'],
                          mode='RGB')
        ],
                   transforms=A.PairCompose([
                       A.PairRandomAffineAndResize((512, 512),
                                                   degrees=(-5, 5),
                                                   translate=(0.1, 0.1),
                                                   scale=(0.3, 1),
                                                   shear=(-5, 5)),
                       A.PairApply(T.ToTensor())
                   ]),
                   assert_equal_length=True),
        ImagesDataset(DATA_PATH['backgrounds']['valid'],
                      mode='RGB',
                      transforms=T.Compose([
                          A.RandomAffineAndResize((512, 512),
                                                  degrees=(-5, 5),
                                                  translate=(0.1, 0.1),
                                                  scale=(1, 1.2),
                                                  shear=(-5, 5)),
                          T.ToTensor()
                      ])),
    ])
    # Wrap the validation set in SampleDataset with the argument 50
    # (presumably the number of samples to keep -- TODO confirm).
    dataset_valid = SampleDataset(dataset_valid, 50)
    dataloader_valid = DataLoader(dataset_valid,
                                  pin_memory=True,
                                  batch_size=args.batch_size,
                                  num_workers=args.num_workers)

    # Model
    model = MattingBase(args.model_backbone).cuda()

    # A full checkpoint takes precedence over a pretrained initialization.
    if args.model_last_checkpoint is not None:
        load_matched_state_dict(model, torch.load(args.model_last_checkpoint))
    elif args.model_pretrain_initialization is not None:
        model.load_pretrained_deeplabv3_state_dict(
            torch.load(args.model_pretrain_initialization)['model_state'])

    # Print the initial sparsity
    # for name, module in model.named_modules():
    #     # prune 10% of connections in all 2D-conv layers
    #     if isinstance(module, torch.nn.Conv2d):
    #         # DNSUnst(module, name='weight')
    #         prune.l1_unstructured(module, name='weight', amount=0.4)
    #         prune.remove(module, 'weight')
    print("the original sparsity: ", get_sparsity(model))

    # One parameter group per sub-module, each with its own learning rate.
    optimizer = Adam([{
        'params': model.backbone.parameters(),
        'lr': 1e-4
    }, {
        'params': model.aspp.parameters(),
        'lr': 5e-4
    }, {
        'params': model.decoder.parameters(),
        'lr': 5e-4
    }])
    scaler = GradScaler()

    # Logging and checkpoints
    if not os.path.exists(f'checkpoint/{args.model_name}'):
        os.makedirs(f'checkpoint/{args.model_name}')
    # NOTE(review): the writer is never closed in this function.
    writer = SummaryWriter(f'log/{args.model_name}')

    # Run loop
    for epoch in range(args.epoch_start, args.epoch_end):
        for i, ((true_pha, true_fgr),
                true_bgr) in enumerate(tqdm(dataloader_train)):

            # Global step counter across epochs.
            step = epoch * len(dataloader_train) + i

            true_pha = true_pha.cuda(non_blocking=True)
            true_fgr = true_fgr.cuda(non_blocking=True)
            true_bgr = true_bgr.cuda(non_blocking=True)
            true_pha, true_fgr, true_bgr = random_crop(true_pha, true_fgr,
                                                       true_bgr)

            # Start the source image from a copy of the background.
            true_src = true_bgr.clone()

            # Augment with shadow
            # Per-sample boolean mask: True with probability 0.3.
            aug_shadow_idx = torch.rand(len(true_src)) < 0.3
            if aug_shadow_idx.any():
                aug_shadow = true_pha[aug_shadow_idx].mul(0.3 *
                                                          random.random())
                aug_shadow = T.RandomAffine(degrees=(-5, 5),
                                            translate=(0.2, 0.2),
                                            scale=(0.5, 1.5),
                                            shear=(-5, 5))(aug_shadow)
                aug_shadow = kornia.filters.box_blur(
                    aug_shadow, (random.choice(range(20, 40)), ) * 2)
                true_src[aug_shadow_idx] = true_src[aug_shadow_idx].sub_(
                    aug_shadow).clamp_(0, 1)
                del aug_shadow
            del aug_shadow_idx

            # Composite foreground onto source
            true_src = true_fgr * true_pha + true_src * (1 - true_pha)

            # Augment with noise
            # The same sample mask is used for source and background, but
            # each gets its own noise draw and its own random scale.
            aug_noise_idx = torch.rand(len(true_src)) < 0.4
            if aug_noise_idx.any():
                true_src[aug_noise_idx] = true_src[aug_noise_idx].add_(
                    torch.randn_like(true_src[aug_noise_idx]).mul_(
                        0.03 * random.random())).clamp_(0, 1)
                true_bgr[aug_noise_idx] = true_bgr[aug_noise_idx].add_(
                    torch.randn_like(true_bgr[aug_noise_idx]).mul_(
                        0.03 * random.random())).clamp_(0, 1)
            del aug_noise_idx

            # Augment background with jitter
            aug_jitter_idx = torch.rand(len(true_src)) < 0.8
            if aug_jitter_idx.any():
                true_bgr[aug_jitter_idx] = kornia.augmentation.ColorJitter(
                    0.18, 0.18, 0.18, 0.1)(true_bgr[aug_jitter_idx])
            del aug_jitter_idx

            # Augment background with affine
            aug_affine_idx = torch.rand(len(true_bgr)) < 0.3
            if aug_affine_idx.any():
                true_bgr[aug_affine_idx] = T.RandomAffine(
                    degrees=(-1, 1),
                    translate=(0.01, 0.01))(true_bgr[aug_affine_idx])
            del aug_affine_idx

            # Forward pass and loss under mixed precision; only the first
            # three model outputs are used.
            with autocast():
                pred_pha, pred_fgr, pred_err = model(true_src, true_bgr)[:3]
                loss = compute_loss(pred_pha, pred_fgr, pred_err, true_pha,
                                    true_fgr)

            scaler.scale(loss).backward()

            # Pruning
            # NOTE(review): best_c is re-created as zeros on every iteration
            # (and deleted at the end of it), so the value computed by
            # ncs.run() at i == 0 never reaches PreDNSUnst in the else
            # branch, which always receives all zeros. Confirm whether best_c
            # was meant to persist across the batches of an epoch.
            best_c = np.zeros(187)
            if i == 0:
                ncs = NCS_C(model, true_src, true_bgr, true_pha, true_fgr)
                best_c = ncs.run(model, true_src, true_bgr, true_pha, true_fgr)
                PreTPUnst(model, best_c)
            else:
                # Adjustment
                PreDNSUnst(model, best_c)

            scaler.step(optimizer)
            Pruned(model)

            scaler.update()
            optimizer.zero_grad()

            if (i + 1) % args.log_train_loss_interval == 0:
                writer.add_scalar('loss', loss, step)

            if (i + 1) % args.log_train_images_interval == 0:
                writer.add_image('train_pred_pha', make_grid(pred_pha, nrow=5),
                                 step)
                writer.add_image('train_pred_fgr', make_grid(pred_fgr, nrow=5),
                                 step)
                writer.add_image('train_pred_com',
                                 make_grid(pred_fgr * pred_pha, nrow=5), step)
                writer.add_image('train_pred_err', make_grid(pred_err, nrow=5),
                                 step)
                writer.add_image('train_true_src', make_grid(true_src, nrow=5),
                                 step)
                writer.add_image('train_true_bgr', make_grid(true_bgr, nrow=5),
                                 step)

            # Drop references to this batch's tensors before validation and
            # checkpointing run.
            del true_pha, true_fgr, true_bgr, true_src
            del pred_pha, pred_fgr, pred_err
            del loss
            del best_c

            if (i + 1) % args.log_valid_interval == 0:
                valid(model, dataloader_valid, writer, step)

            # Intermediate checkpoint keyed on the global step.
            if (step + 1) % args.checkpoint_interval == 0:
                torch.save(
                    model.state_dict(),
                    f'checkpoint/{args.model_name}/epoch-{epoch}-iter-{step}.pth'
                )

        # End of epoch: report sparsity and save an epoch checkpoint.
        print("the sparsity of epoch {} : {}".format(epoch,
                                                     get_sparsity(model)))
        torch.save(model.state_dict(),
                   f'checkpoint/{args.model_name}/epoch-{epoch}.pth')

    # Print the final sparsity
    print("the final sparsity: ", get_sparsity(model))
示例#4
0
文件: eval.py 项目: frankieeder/DCGVF
import numpy as np
import torch
from torch.utils.data import SubsetRandomSampler, DataLoader
from pytoflow.Network import TOFlow
import pytorch_ssim


def save_results(sample, toflow, name):
    """Run ``toflow`` on a sample and store input, target and prediction.

    Args:
        sample: Pair ``(x, y)``; ``x`` must provide a ``.cuda()`` method.
        toflow: Callable applied to ``x.cuda()``; its result is moved to the
            CPU, detached and converted with ``.numpy()``.
        name: Destination handed to ``np.savez_compressed``. The archive
            holds the arrays under the keys ``x``, ``y`` and ``y_hat``.
    """
    inputs, target = sample
    prediction = toflow(inputs.cuda()).cpu().detach().numpy()
    np.savez_compressed(name, x=inputs, y=target, y_hat=prediction)


# Dataset built from the processed-data directory.
vd = SampleDataset("../dcgvf_data/processed_decomp")

# Train/Val/Test Splits: contiguous index ranges cut at 60% and 80% of the
# dataset length, each wrapped in its own random sampler.
train_ind, val_ind, test_ind = np.split(
    np.arange(len(vd)),
    [int(fraction * len(vd)) for fraction in (.6, .8)],
)
train_sampler, val_sampler, test_sampler = (
    SubsetRandomSampler(indices)
    for indices in (train_ind, val_ind, test_ind)
)

# Run settings.
h, w = 256, 448
task = 'denoising'
cuda_flag = True
model_path = './toflow_models/denoise_decomp_best_params.pkl'
# Load Pretrained Model
示例#5
0
def _train_one_epoch(model, optimizer, train_dataloader, writer, epoch,
                     teacher_forcing, device):
    """Run one optimisation pass over ``train_dataloader``.

    Args:
        model: Model called as ``model(x, x_len, y, len_oovs, ...)``; it must
            return the loss and expose ``encoder``, ``decoder`` and
            ``attention`` sub-modules.
        optimizer: Optimizer stepped once per batch.
        train_dataloader: Yields ``(x, y, x_len, y_len, oov, len_oovs)``.
        writer: Summary writer that receives the running average loss.
        epoch (int): Epoch number, used for progress/log labels only.
        teacher_forcing (bool): Forwarded to the model.
        device: Target device, used when ``config.is_cuda`` is set.

    Returns:
        Mean of the per-batch losses of this epoch.
    """
    batch_losses = []  # Loss of every batch of this epoch.
    num_batches = len(train_dataloader)

    # A single progress bar per epoch, advanced once per batch.
    with tqdm(total=num_batches) as batch_progress:
        for batch, data in enumerate(train_dataloader):
            x, y, x_len, _y_len, _oov, len_oovs = data
            assert not np.any(np.isnan(x.numpy()))
            if config.is_cuda:  # Training with GPUs.
                x = x.to(device)
                y = y.to(device)
                x_len = x_len.to(device)
                len_oovs = len_oovs.to(device)

            model.train()  # Sets the module in training mode.
            optimizer.zero_grad()  # Clear gradients.
            # Calculate loss. Call model forward propagation.
            loss = model(x,
                         x_len,
                         y,
                         len_oovs,
                         batch=batch,
                         num_batches=num_batches,
                         teacher_forcing=teacher_forcing)
            batch_losses.append(loss.item())
            loss.backward()  # Backpropagation.

            # Do gradient clipping to prevent gradient explosion.
            clip_grad_norm_(model.encoder.parameters(),
                            config.max_grad_norm)
            clip_grad_norm_(model.decoder.parameters(),
                            config.max_grad_norm)
            clip_grad_norm_(model.attention.parameters(),
                            config.max_grad_norm)
            optimizer.step()  # Update weights.

            batch_progress.update()

            # Refresh the bar labels and record the running loss every
            # 32 batches.
            if (batch % 32) == 0:
                batch_progress.set_description(f'Epoch {epoch}')
                batch_progress.set_postfix(Batch=batch,
                                           Loss=loss.item())
                # Write loss for tensorboard.
                writer.add_scalar(f'Average loss for epoch {epoch}',
                                  np.mean(batch_losses),
                                  global_step=batch)

    # Average loss over all batches in the epoch.
    return np.mean(batch_losses)


def train(dataset, val_dataset, v, start_epoch=0):
    """Train the model, evaluate it and store it.

    After every epoch the model is evaluated with ``evaluate``; whenever the
    validation loss improves on the best value so far, the encoder, decoder,
    attention and reduce_state modules are saved. The best validation loss
    is persisted to ``config.losses_path`` after every epoch and reloaded
    from there on the next call.

    Args:
        dataset (dataset.PairDataset): The training dataset.
        val_dataset (dataset.PairDataset): The evaluation dataset.
        v (vocab.Vocab): The vocabulary built from the training dataset.
        start_epoch (int, optional): The starting epoch number. Defaults to 0.
    """

    DEVICE = torch.device("cuda" if config.is_cuda else "cpu")

    model = PGN(v)
    model.load_model()
    model.to(DEVICE)
    if config.fine_tune:
        # In fine-tuning mode, we fix the weights of all parameters except
        # attention.wc.
        print('Fine-tuning mode.')
        for name, params in model.named_parameters():
            if name != 'attention.wc.weight':
                params.requires_grad = False

    print("loading data")
    train_data = SampleDataset(dataset.pairs, v)
    val_data = SampleDataset(val_dataset.pairs, v)

    print("initializing optimizer")

    # Define the optimizer.
    optimizer = optim.Adam(model.parameters(), lr=config.learning_rate)
    train_dataloader = DataLoader(dataset=train_data,
                                  batch_size=config.batch_size,
                                  shuffle=True,
                                  collate_fn=collate_fn)

    # Best validation loss so far; restored from the file written at the end
    # of every epoch below. pickle must only be used on trusted files.
    val_losses = np.inf
    if os.path.exists(config.losses_path):
        with open(config.losses_path, 'rb') as f:
            val_losses = pickle.load(f)

    # SummaryWriter: Log writer used for TensorboardX visualization.
    writer = SummaryWriter(config.log_path)
    # scheduled_sampler: decides per epoch whether to use teacher forcing.
    num_epochs = len(range(start_epoch, config.epochs))
    scheduled_sampler = ScheduledSampler(num_epochs)
    if config.scheduled_sampling:
        print('scheduled_sampling mode.')

    try:
        # Start the epoch bar at start_epoch so a resumed run shows the
        # right position out of config.epochs.
        with tqdm(total=config.epochs,
                  initial=start_epoch) as epoch_progress:
            for epoch in range(start_epoch, config.epochs):
                print(config_info(config))
                # Set the teacher_forcing signal for this epoch.
                if config.scheduled_sampling:
                    teacher_forcing = scheduled_sampler.teacher_forcing(
                        epoch - start_epoch)
                else:
                    teacher_forcing = True
                print('teacher_forcing = {}'.format(teacher_forcing))

                epoch_loss = _train_one_epoch(model, optimizer,
                                              train_dataloader, writer,
                                              epoch, teacher_forcing, DEVICE)

                epoch_progress.set_description(f'Epoch {epoch}')
                epoch_progress.set_postfix(Loss=epoch_loss)
                epoch_progress.update()

                avg_val_loss = evaluate(model, val_data, epoch)

                print('training loss:{}'.format(epoch_loss),
                      'validation loss:{}'.format(avg_val_loss))

                # Update minimum evaluating loss.
                if avg_val_loss < val_losses:
                    torch.save(model.encoder, config.encoder_save_name)
                    torch.save(model.decoder, config.decoder_save_name)
                    torch.save(model.attention, config.attention_save_name)
                    torch.save(model.reduce_state,
                               config.reduce_state_save_name)
                    val_losses = avg_val_loss
                with open(config.losses_path, 'wb') as f:
                    pickle.dump(val_losses, f)
    finally:
        # Flush and close the log writer even if training is interrupted.
        writer.close()
示例#6
0
def _l2_penalty(model, strength):
    """Return ``strength`` times the summed L2 norms of the parameters.

    The value is computed without gradient tracking, so it only shifts the
    reported loss and does not contribute gradients.
    """
    with torch.no_grad():
        return strength * sum(
            torch.norm(param) for param in model.parameters())


def train_net(model, args):
    """Train ``model`` to reconstruct clean inputs from noisy copies.

    Gaussian noise is added to every training batch, the result is clamped
    to [0, 1] and the model is trained with an MSE loss against the clean
    batch. 20% of the dataset is held out for validation; the state dict
    with the lowest validation loss is written to
    ``./checkpoints/autoencoder<model_number>.pth``.

    Args:
        model: The ``nn.Module`` to train.
        args: Namespace-like object providing ``data_path``, ``epochs``,
            ``gpu`` (``'gpu'`` selects CUDA, anything else the CPU),
            ``device1`` .. ``device4`` (``-1`` marks an unused extra device)
            and ``model_number`` (1 or 2).

    Returns:
        List with the validation loss of every epoch, in order.

    Raises:
        ValueError: If ``args.model_number`` is neither 1 nor 2.
        SystemExit: If the GPU is requested but CUDA is unavailable.
    """
    data_path = args.data_path
    num_epochs = args.epochs
    gpu = args.gpu

    # hyper parameter for training
    learning_rate = 1e-3
    v_noise = 0.1
    reg_strength = 1e-9

    # Validate the checkpoint choice before any work is done, so an
    # unsupported model number fails immediately rather than when the first
    # checkpoint is about to be saved.
    checkpoint_names = {1: 'autoencoder1.pth', 2: 'autoencoder2.pth'}
    if args.model_number not in checkpoint_names:
        raise ValueError(
            'model_number must be 1 or 2, got {!r}'.format(args.model_number))

    # set device configuration
    device_ids = []

    if gpu == 'gpu':

        if not torch.cuda.is_available():
            print("No cuda available")
            raise SystemExit

        device = torch.device(args.device1)
        device_ids.append(args.device1)

        # Optional extra devices; -1 marks an unused slot.
        for extra_device in (args.device2, args.device3, args.device4):
            if extra_device != -1:
                device_ids.append(extra_device)

    else:
        device = torch.device("cpu")

    if len(device_ids) > 1:
        model = nn.DataParallel(model, device_ids=device_ids)

    model = model.to(device)

    # set image into training and validation dataset
    train_dataset = SampleDataset(data_path)

    print('total image : {}'.format(len(train_dataset)))

    train_indices, val_indices = train_test_split(
        np.arange(len(train_dataset)), test_size=0.2)

    train_sampler = SubsetRandomSampler(train_indices)
    valid_sampler = SubsetRandomSampler(val_indices)

    train_loader = DataLoader(train_dataset,
                              batch_size=50,
                              num_workers=4,
                              sampler=train_sampler)

    val_loader = DataLoader(train_dataset,
                            batch_size=30,
                            num_workers=4,
                            sampler=valid_sampler)

    model_folder = os.path.abspath('./checkpoints')
    os.makedirs(model_folder, exist_ok=True)
    model_path = os.path.join(model_folder,
                              checkpoint_names[args.model_number])

    # loss function and optimizer
    criterion = nn.MSELoss()
    # lr is spelled out; 1e-3 is also Adam's default.
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    display_steps = 10
    best_loss = 1e10
    loss_history = []

    for epoch in range(num_epochs):  # loop over the dataset multiple times

        model.train()

        for batch_idx, (data, _) in enumerate(train_loader):

            noise = v_noise * np.random.normal(size=np.shape(data))
            noise = torch.from_numpy(noise)

            # Tensor.clamp is out-of-place: its result must be kept,
            # otherwise the noisy input is never limited to [0, 1].
            noisy_train_data = (data.double() + noise).clamp(0.0, 1.0)
            noisy_train_data = noisy_train_data.to(device).float()

            data = data.to(device).float()

            optimizer.zero_grad()

            output = model(noisy_train_data)

            # The penalty is recomputed for every batch instead of being
            # summed up over the epoch, so it does not grow with batch_idx.
            # NOTE(review): the penalty is detached (as before), so it only
            # affects the reported loss, not the gradients -- confirm that
            # this is intended.
            loss = criterion(output, data) + _l2_penalty(model, reg_strength)

            loss.backward()
            optimizer.step()

            if batch_idx % display_steps == 0:
                print('    ', end='')
                print('batch {:>3}/{:>3}, loss {:.4f}\r'\
                      .format(batch_idx+1, len(train_loader),loss.item()))

        # evaluate
        print('Finished epoch {}, starting evaluation'.format(epoch + 1))

        model.eval()
        running_loss = 0

        with torch.no_grad():
            # Parameters are fixed during evaluation, so the penalty is the
            # same for every validation batch.
            penalty = _l2_penalty(model, reg_strength)

            for data, _ in val_loader:

                data = data.to(device).float()

                output = model(data)

                # The clean input is its own reconstruction target.
                loss = criterion(output, data) + penalty

                running_loss += loss.item()

        validate_loss = running_loss / len(val_loader)
        loss_history.append(validate_loss)

        if validate_loss < best_loss:

            print('best validation loss : {:.4f}'.format(validate_loss))

            best_loss = validate_loss

            print("saving best model")

            # Save from a CPU copy so the training model stays on its device.
            model_copy = copy.deepcopy(model).cpu()
            torch.save(model_copy.state_dict(), model_path)

    return loss_history
示例#7
0
def train_worker(rank, addr, port):
    """Per-process entry point for distributed training of MattingRefine.

    Joins an NCCL process group, trains on this rank's contiguous slice of
    the training set, and, on rank 0 only, logs to a ``SummaryWriter``, runs
    validation and writes checkpoints. Configuration is read from the
    ``args``, ``DATA_PATH`` and ``distributed_num_gpus`` names defined
    outside this function.

    Args:
        rank: Index of this process in the group; also used as the device
            index for the model and tensors.
        addr: Value written to the ``MASTER_ADDR`` environment variable.
        port: Value written to the ``MASTER_PORT`` environment variable
            (must be a string, since it is assigned to ``os.environ``).

    Returns:
        None.
    """

    # Distributed Setup
    os.environ['MASTER_ADDR'] = addr
    os.environ['MASTER_PORT'] = port
    dist.init_process_group("nccl", rank=rank, world_size=distributed_num_gpus)

    # Training DataLoader
    dataset_train = ZipDataset([
        ZipDataset([
            ImagesDataset(DATA_PATH[args.dataset_name]
                          ['train']['pha'], mode='L'),
            ImagesDataset(DATA_PATH[args.dataset_name]
                          ['train']['fgr'], mode='RGB'),
        ], transforms=A.PairCompose([
            A.PairRandomAffineAndResize(
                (2048, 2048), degrees=(-5, 5), translate=(0.1, 0.1), scale=(0.3, 1), shear=(-5, 5)),
            A.PairRandomHorizontalFlip(),
            A.PairRandomBoxBlur(0.1, 5),
            A.PairRandomSharpen(0.1),
            A.PairApplyOnlyAtIndices(
                [1], T.ColorJitter(0.15, 0.15, 0.15, 0.05)),
            A.PairApply(T.ToTensor())
        ]), assert_equal_length=True),
        ImagesDataset(DATA_PATH['backgrounds']['train'], mode='RGB', transforms=T.Compose([
            A.RandomAffineAndResize(
                (2048, 2048), degrees=(-5, 5), translate=(0.1, 0.1), scale=(1, 2), shear=(-5, 5)),
            T.RandomHorizontalFlip(),
            A.RandomBoxBlur(0.1, 5),
            A.RandomSharpen(0.1),
            T.ColorJitter(0.15, 0.15, 0.15, 0.05),
            T.ToTensor()
        ])),
    ])
    # Give every rank an equally sized, contiguous slice of the training
    # set; samples left over after the integer division are not used.
    dataset_train_len_per_gpu_worker = int(
        len(dataset_train) / distributed_num_gpus)
    dataset_train = Subset(dataset_train, range(
        rank * dataset_train_len_per_gpu_worker, (rank + 1) * dataset_train_len_per_gpu_worker))
    # Batch size and worker count from args are divided across the ranks.
    dataloader_train = DataLoader(dataset_train,
                                  shuffle=True,
                                  pin_memory=True,
                                  drop_last=True,
                                  batch_size=args.batch_size // distributed_num_gpus,
                                  num_workers=args.num_workers // distributed_num_gpus)

    # Validation DataLoader
    # Only rank 0 builds it; dataloader_valid is undefined on other ranks
    # and is only used under `rank == 0` below.
    if rank == 0:
        dataset_valid = ZipDataset([
            ZipDataset([
                ImagesDataset(DATA_PATH[args.dataset_name]
                              ['valid']['pha'], mode='L'),
                ImagesDataset(DATA_PATH[args.dataset_name]
                              ['valid']['fgr'], mode='RGB')
            ], transforms=A.PairCompose([
                A.PairRandomAffineAndResize(
                    (2048, 2048), degrees=(-5, 5), translate=(0.1, 0.1), scale=(0.3, 1), shear=(-5, 5)),
                A.PairApply(T.ToTensor())
            ]), assert_equal_length=True),
            ImagesDataset(DATA_PATH['backgrounds']['valid'], mode='RGB', transforms=T.Compose([
                A.RandomAffineAndResize(
                    (2048, 2048), degrees=(-5, 5), translate=(0.1, 0.1), scale=(1, 1.2), shear=(-5, 5)),
                T.ToTensor()
            ])),
        ])
        # Wrap the validation set in SampleDataset with the argument 50
        # (presumably the number of samples to keep -- TODO confirm).
        dataset_valid = SampleDataset(dataset_valid, 50)
        dataloader_valid = DataLoader(dataset_valid,
                                      pin_memory=True,
                                      drop_last=True,
                                      batch_size=args.batch_size // distributed_num_gpus,
                                      num_workers=args.num_workers // distributed_num_gpus)

    # Model
    model = MattingRefine(args.model_backbone,
                          args.model_backbone_scale,
                          args.model_refine_mode,
                          args.model_refine_sample_pixels,
                          args.model_refine_thresholding,
                          args.model_refine_kernel_size).to(rank)
    model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
    # The DDP wrapper is used for the forward pass; the bare `model` is used
    # for checkpoint loading, optimizer groups, validation and saving.
    model_distributed = nn.parallel.DistributedDataParallel(
        model, device_ids=[rank])

    if args.model_last_checkpoint is not None:
        load_matched_state_dict(model, torch.load(args.model_last_checkpoint))

    # One parameter group per sub-module, each with its own learning rate.
    optimizer = Adam([
        {'params': model.backbone.parameters(), 'lr': 5e-5},
        {'params': model.aspp.parameters(), 'lr': 5e-5},
        {'params': model.decoder.parameters(), 'lr': 1e-4},
        {'params': model.refiner.parameters(), 'lr': 3e-4},
    ])
    scaler = GradScaler()

    # Logging and checkpoints
    # `writer` exists on rank 0 only and is never closed in this function.
    if rank == 0:
        if not os.path.exists(f'checkpoint/{args.model_name}'):
            os.makedirs(f'checkpoint/{args.model_name}')
        writer = SummaryWriter(f'log/{args.model_name}')

    # Run loop
    for epoch in range(args.epoch_start, args.epoch_end):
        for i, ((true_pha, true_fgr), true_bgr) in enumerate(tqdm(dataloader_train)):
            # Global step counter across epochs.
            step = epoch * len(dataloader_train) + i

            true_pha = true_pha.to(rank, non_blocking=True)
            true_fgr = true_fgr.to(rank, non_blocking=True)
            true_bgr = true_bgr.to(rank, non_blocking=True)
            true_pha, true_fgr, true_bgr = random_crop(
                true_pha, true_fgr, true_bgr)

            # Start the source image from a copy of the background.
            true_src = true_bgr.clone()

            # Augment with shadow
            # Per-sample boolean mask: True with probability 0.3.
            aug_shadow_idx = torch.rand(len(true_src)) < 0.3
            if aug_shadow_idx.any():
                aug_shadow = true_pha[aug_shadow_idx].mul(
                    0.3 * random.random())
                aug_shadow = T.RandomAffine(
                    degrees=(-5, 5), translate=(0.2, 0.2), scale=(0.5, 1.5), shear=(-5, 5))(aug_shadow)
                aug_shadow = kornia.filters.box_blur(
                    aug_shadow, (random.choice(range(20, 40)),) * 2)
                true_src[aug_shadow_idx] = true_src[aug_shadow_idx].sub_(
                    aug_shadow).clamp_(0, 1)
                del aug_shadow
            del aug_shadow_idx

            # Composite foreground onto source
            true_src = true_fgr * true_pha + true_src * (1 - true_pha)

            # Augment with noise
            # The same sample mask is used for source and background, but
            # each gets its own noise draw and its own random scale.
            aug_noise_idx = torch.rand(len(true_src)) < 0.4
            if aug_noise_idx.any():
                true_src[aug_noise_idx] = true_src[aug_noise_idx].add_(torch.randn_like(
                    true_src[aug_noise_idx]).mul_(0.03 * random.random())).clamp_(0, 1)
                true_bgr[aug_noise_idx] = true_bgr[aug_noise_idx].add_(torch.randn_like(
                    true_bgr[aug_noise_idx]).mul_(0.03 * random.random())).clamp_(0, 1)
            del aug_noise_idx

            # Augment background with jitter
            aug_jitter_idx = torch.rand(len(true_src)) < 0.8
            if aug_jitter_idx.any():
                true_bgr[aug_jitter_idx] = kornia.augmentation.ColorJitter(
                    0.18, 0.18, 0.18, 0.1)(true_bgr[aug_jitter_idx])
            del aug_jitter_idx

            # Augment background with affine
            aug_affine_idx = torch.rand(len(true_bgr)) < 0.3
            if aug_affine_idx.any():
                true_bgr[aug_affine_idx] = T.RandomAffine(
                    degrees=(-1, 1), translate=(0.01, 0.01))(true_bgr[aug_affine_idx])
            del aug_affine_idx

            # Forward pass and loss under mixed precision; the sixth model
            # output is discarded.
            with autocast():
                pred_pha, pred_fgr, pred_pha_sm, pred_fgr_sm, pred_err_sm, _ = model_distributed(
                    true_src, true_bgr)
                loss = compute_loss(
                    pred_pha, pred_fgr, pred_pha_sm, pred_fgr_sm, pred_err_sm, true_pha, true_fgr)

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()

            # Logging, validation and intermediate checkpoints: rank 0 only.
            if rank == 0:
                if (i + 1) % args.log_train_loss_interval == 0:
                    writer.add_scalar('loss', loss, step)

                if (i + 1) % args.log_train_images_interval == 0:
                    writer.add_image('train_pred_pha',
                                     make_grid(pred_pha, nrow=5), step)
                    writer.add_image('train_pred_fgr',
                                     make_grid(pred_fgr, nrow=5), step)
                    writer.add_image('train_pred_com', make_grid(
                        pred_fgr * pred_pha, nrow=5), step)
                    writer.add_image('train_pred_err', make_grid(
                        pred_err_sm, nrow=5), step)
                    writer.add_image('train_true_src',
                                     make_grid(true_src, nrow=5), step)

                # NOTE(review): these references are dropped on rank 0 only;
                # other ranks keep them until the next iteration rebinds
                # the names.
                del true_pha, true_fgr, true_src, true_bgr
                del pred_pha, pred_fgr, pred_pha_sm, pred_fgr_sm, pred_err_sm

                if (i + 1) % args.log_valid_interval == 0:
                    valid(model, dataloader_valid, writer, step)

                # Intermediate checkpoint keyed on the global step.
                if (step + 1) % args.checkpoint_interval == 0:
                    torch.save(model.state_dict(
                    ), f'checkpoint/{args.model_name}/epoch-{epoch}-iter-{step}.pth')

        # End-of-epoch checkpoint, written by rank 0 only.
        if rank == 0:
            torch.save(model.state_dict(),
                       f'checkpoint/{args.model_name}/epoch-{epoch}.pth')

    # Clean up
    dist.destroy_process_group()
# Log directory name encodes the architecture, the per-class sample counts
# and the `_timenow` value defined earlier in the script.
args.logdir = 'adversarial-detect-%s/train_num-%d-test_num-%d_%s' %\
              (args.arch, args.train_num_per_class, args.test_num_per_class, _timenow)
misc.prepare_logging(args)

os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)

# Data loading code
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])

# transforms.Resize is the supported name for the deprecated
# transforms.Scale, which newer torchvision releases no longer provide.
train_loader = torch.utils.data.DataLoader(SampleDataset(
    './data/train_images_list.pkl',
    start_class=args.start_class,
    end_class=args.end_class,
    num_per_class=args.train_num_per_class,
    random_order=True,
    transform=transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(), normalize
    ])),
                                           batch_size=1,
                                           shuffle=False,
                                           num_workers=4,
                                           pin_memory=True)

val_loader = torch.utils.data.DataLoader(SampleDataset(
    './data/val_images_list.pkl',
    start_class=args.start_class,
    end_class=args.end_class,
    num_per_class=args.test_num_per_class,
示例#9
0
def test(model, args):
    """Evaluate ``model`` on the dataset at ``args.data_path`` and print the mean dice score.

    The score is averaged over batches (not over samples), matching the
    per-batch accumulation below.

    Args:
        model: Network to evaluate; it is called as ``model(inputs)``.
        args: Options object. Reads ``data_path``, ``gpu`` and ``classes``;
            when ``args.gpu == 'gpu'`` it also reads ``device1`` (primary
            device) and ``device2``..``device4`` (``-1`` marks an unused slot).

    Raises:
        SystemExit: If a GPU run is requested but CUDA is not available.
    """
    data_path = args.data_path
    n_classes = args.classes

    # set device configuration
    device_ids = []

    if args.gpu == 'gpu':
        if not torch.cuda.is_available():
            print("No cuda available")
            raise SystemExit

        device = torch.device(args.device1)
        device_ids.append(args.device1)
        # Optional extra devices; -1 means "slot not used".
        device_ids.extend(
            extra
            for extra in (args.device2, args.device3, args.device4)
            if extra != -1)
    else:
        device = torch.device("cpu")

    # Only wrap in DataParallel when more than one device was requested.
    if len(device_ids) > 1:
        model = nn.DataParallel(model, device_ids=device_ids)

    model = model.to(device)

    # set test dataset
    test_dataset = SampleDataset(data_path)
    test_loader = DataLoader(
        test_dataset,
        batch_size=10,
        num_workers=4,
    )

    print('test_dataset : {}, test_loader : {}'.format(len(test_dataset),
                                                       len(test_loader)))

    # An empty loader would otherwise end in a ZeroDivisionError below.
    if len(test_loader) == 0:
        print('test_dataset is empty, nothing to evaluate')
        return

    avg_score = 0.0

    model.eval()  # Set model to evaluate mode

    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs = inputs.to(device).float()
            labels = labels.to(device).long()

            # Index 0 of dim 1 is taken from the 4-D label batch before
            # one-hot encoding.
            target = make_one_hot(labels[:, 0, :, :], n_classes, device)

            pred = model(inputs)
            score = dice_score(pred, target)
            avg_score += score.item()

            # Drop references early so the batch tensors can be freed.
            del inputs, labels, target, pred, score

    avg_score /= len(test_loader)

    print('dice_score : {:.4f}'.format(avg_score))
示例#10
0
def test(model, args):
    """Evaluate ``model`` behind a MagNet-style detector/reformer defence.

    For each detection threshold, every test batch is passed through the
    reformer + classifier, the detector selects the samples that pass the
    filter, and the dice score of the surviving predictions is averaged over
    the batches that had at least one surviving sample. Results are printed.

    Args:
        model: Segmentation network handed to ``Classifier``.
        args: Options object. Reads ``data_path``, ``channels``, ``classes``,
            ``width``, ``height``, ``gpu``, ``reformer``, ``detector``,
            ``detector_path``, ``reformer_path``, ``model_path`` and, for GPU
            runs, ``model_device1``..``model_device4`` (``-1`` marks an unused
            slot) and ``defense_model_device``.

    Raises:
        SystemExit: On an unknown reformer/detector name, or when a GPU run
            is requested but CUDA is not available.
    """
    data_path = args.data_path
    n_channels = args.channels
    n_classes = args.classes
    data_width = args.width
    data_height = args.height
    gpu = args.gpu

    # Hyper parameter for MagNet
    thresholds = [0.01, 0.05, 0.001, 0.005]

    known_autoencoders = ('autoencoder1', 'autoencoder2')
    summary_shape = (n_channels, data_height, data_width)

    # Reformer network.
    if args.reformer not in known_autoencoders:
        print("wrong reformer model : must be autoencoder1 or autoencoder2")
        raise SystemExit
    if args.reformer == 'autoencoder1':
        reformer_model = autoencoder(n_channels)
    else:
        reformer_model = autoencoder2(n_channels)

    print('reformer model')
    summary(reformer_model, input_size=summary_shape, device='cpu')

    # Detector network.
    if args.detector not in known_autoencoders:
        print("wrong detector model : must be autoencoder1 or autoencoder2")
        raise SystemExit
    if args.detector == 'autoencoder1':
        detector_model = autoencoder(n_channels)
    else:
        detector_model = autoencoder2(n_channels)

    print('detector model')
    summary(detector_model, input_size=summary_shape, device='cpu')

    # set device configuration
    device_ids = []

    if gpu != 'gpu':
        device = torch.device("cpu")
        device_defense = torch.device("cpu")
    else:
        if not torch.cuda.is_available():
            print("No cuda available")
            raise SystemExit

        device = torch.device(args.model_device1)
        device_defense = torch.device(args.defense_model_device)

        device_ids.append(args.model_device1)
        # Optional extra classifier devices; -1 means "slot not used".
        for extra_device in (args.model_device2, args.model_device3,
                             args.model_device4):
            if extra_device != -1:
                device_ids.append(extra_device)

    detector = AEDetector(detector_model,
                          device_defense,
                          args.detector_path,
                          p=2)
    reformer = SimpleReformer(reformer_model, device_defense,
                              args.reformer_path)
    classifier = Classifier(model, device, args.model_path, device_ids)

    # set test dataset
    test_dataset = SampleDataset(data_path)
    test_loader = DataLoader(
        test_dataset,
        batch_size=10,
        num_workers=4,
    )

    print('test_dataset : {}, test_loader : {}'.format(len(test_dataset),
                                                       len(test_loader)))

    # Defense with MagNet
    print('test start')

    for raw_threshold in thresholds:

        print('----------------------------------------')

        passed_batches = 0
        score_sum = 0.0
        threshold = torch.tensor(raw_threshold)

        with torch.no_grad():
            for inputs, labels in test_loader:

                inputs = inputs.float()
                labels = labels.to(device).long()
                target = make_one_hot(labels[:, 0, :, :], n_classes, device)

                operate_results = operate(reformer, classifier, inputs)

                # Indices of the samples the detector lets through.
                all_pass, _ = filters(detector, inputs, threshold)

                # Nothing survived the filter in this batch: skip it entirely.
                if len(all_pass) == 0:
                    continue

                pred = operate_results[all_pass].to(device).float()
                target = target[all_pass]

                loss = dice_score(pred, target)
                score_sum += loss.data.cpu().numpy()

                # statistics
                passed_batches += 1

                del inputs, labels, pred, target, loss

        if not passed_batches:
            print(
                'threshold : {:.4f} , no images pass from filter'.format(
                    threshold))
            continue

        score_sum = score_sum / passed_batches
        print('threshold : {:.4f}, avg_score : {:.4f}'.format(
            threshold, score_sum))
示例#11
0
文件: train.py 项目: fbjwying2/PNG
def train(dataset, val_dataset, v, start_epoch=0):
    """Train the model, evaluate it and store it.

    After every epoch the model is evaluated; whenever the validation loss
    improves on the best value seen so far (persisted at
    ``config.losses_path``), the encoder, decoder, attention and reduce_state
    sub-modules are saved to their configured paths.

    Args:
        dataset (dataset.PairDataset): The training dataset.
        val_dataset (dataset.PairDataset): The evaluation dataset.
        v (vocab.Vocab): The vocabulary built from the training dataset.
        start_epoch (int, optional): The starting epoch number. Defaults to 0.
    """
    print('loading model')
    DEVICE = torch.device("cuda" if config.is_cuda else "cpu")

    model = Seq2seq(v)
    model.load_model()
    model.to(DEVICE)

    print("loading data")
    train_data = SampleDataset(dataset.pairs, v)
    val_data = SampleDataset(val_dataset.pairs, v)

    print("initializing optimizer")

    # Define the optimizer.
    optimizer = optim.Adam(model.parameters(), lr=config.learning_rate)

    train_dataloader = DataLoader(dataset=train_data,
                                  batch_size=config.batch_size,
                                  shuffle=True,
                                  collate_fn=collate_fn)

    # Best validation loss so far; restored from disk when resuming.
    val_losses = np.inf
    if os.path.exists(config.losses_path):
        # NOTE: pickle.load is only acceptable here because this file is
        # written by this function itself; never point it at untrusted data.
        with open(config.losses_path, 'rb') as f:
            val_losses = pickle.load(f)

    # TODO: module 3 task 3 - log losses for TensorboardX visualization
    # (SummaryWriter(config.log_path)).

    # tqdm: A tool for drawing progress bars during training.
    # `initial` makes a resumed run start the bar at the right epoch.
    with tqdm(total=config.epochs, initial=start_epoch) as epoch_progress:
        # Loop for epochs.
        for epoch in range(start_epoch, config.epochs):
            batch_losses = []  # Loss of each batch in this epoch.

            # Hoisted out of the batch loop: train mode only needs to be set
            # once before the epoch's batches are processed.
            model.train()

            # len(train_dataloader) is already the number of batches, so it
            # must not be divided by the batch size again.
            with tqdm(total=len(train_dataloader)) as batch_progress:
                batch_progress.set_description(f'Epoch {epoch}')
                # Loop for batches.
                for batch, data in enumerate(train_dataloader):
                    x, y, x_len, _y_len, _oov, len_oovs = data
                    # Sanity check performed while x is still on the CPU.
                    assert not np.any(np.isnan(x.numpy()))
                    if config.is_cuda:  # Training with GPUs.
                        x = x.to(DEVICE)
                        y = y.to(DEVICE)
                        x_len = x_len.to(DEVICE)
                        len_oovs = len_oovs.to(DEVICE)

                    optimizer.zero_grad()
                    loss = model(x, x_len, y, len_oovs, batch)
                    loss_value = loss.item()
                    batch_losses.append(loss_value)
                    loss.backward()

                    # Do gradient clipping to prevent gradient explosion.
                    # One call is enough; repeating it only recomputes the
                    # norm of already-clipped gradients.
                    torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                   config.max_grad_norm)

                    # Update weights.
                    optimizer.step()

                    batch_progress.update()
                    # Refresh the displayed loss every 100 batches.
                    if (batch % 100) == 0:
                        batch_progress.set_postfix(Batch=batch,
                                                   Loss=loss_value)

            # Calculate average loss over all batches in an epoch.
            epoch_loss = np.mean(batch_losses)

            epoch_progress.set_description(f'Epoch {epoch}')
            epoch_progress.set_postfix(Loss=epoch_loss)
            epoch_progress.update()

            # Calculate evaluation loss.
            avg_val_loss = evaluate(model, val_data, epoch)

            print('training loss:{}'.format(epoch_loss),
                  'validation loss:{}'.format(avg_val_loss))

            # Update minimum evaluating loss.
            if avg_val_loss < val_losses:
                torch.save(model.encoder, config.encoder_save_name)
                torch.save(model.decoder, config.decoder_save_name)
                torch.save(model.attention, config.attention_save_name)
                torch.save(model.reduce_state, config.reduce_state_save_name)
                val_losses = avg_val_loss

            # Persist the best validation loss after every epoch so a resumed
            # run compares against it.
            with open(config.losses_path, 'wb') as f:
                pickle.dump(val_losses, f)
示例#12
0
def train_net(model, args):
    """Train ``model`` with an 80/20 train/validation split and early stopping.

    The best model (lowest validation loss) is saved as a CPU state dict to
    ``./checkpoints/<args.model>.pth``. Training stops early once the
    validation loss has not improved for 7 consecutive epochs.

    Args:
        model: Network to train; it is called as ``model(images)``.
        args: Options object. Reads ``data_path``, ``epochs``, ``gpu``,
            ``classes`` and ``model``; when ``args.gpu == 'gpu'`` it also
            reads ``device1`` (primary device) and ``device2``..``device4``
            (``-1`` marks an unused slot).

    Returns:
        list: The best validation loss seen so far, recorded after each epoch.

    Raises:
        SystemExit: If a GPU run is requested but CUDA is not available.
    """
    data_path = args.data_path
    num_epochs = args.epochs
    n_classes = args.classes

    # set device configuration
    device_ids = []

    if args.gpu == 'gpu':
        if not torch.cuda.is_available():
            print("No cuda available")
            raise SystemExit

        device = torch.device(args.device1)
        device_ids.append(args.device1)
        # Optional extra devices; -1 means "slot not used".
        device_ids.extend(
            extra
            for extra in (args.device2, args.device3, args.device4)
            if extra != -1)
    else:
        device = torch.device("cpu")

    multi_gpu = len(device_ids) > 1
    if multi_gpu:
        model = nn.DataParallel(model, device_ids=device_ids)

    model = model.to(device)

    # split images into training and validation subsets
    train_dataset = SampleDataset(data_path)

    print('total image : {}'.format(len(train_dataset)))

    train_indices, val_indices = train_test_split(
        np.arange(len(train_dataset)), test_size=0.2, random_state=42)

    train_loader = DataLoader(train_dataset,
                              batch_size=20,
                              num_workers=4,
                              sampler=SubsetRandomSampler(train_indices))

    val_loader = DataLoader(train_dataset,
                            batch_size=10,
                            num_workers=4,
                            sampler=SubsetRandomSampler(val_indices))

    model_folder = os.path.abspath('./checkpoints')
    os.makedirs(model_folder, exist_ok=True)

    # Same file names as before for 'UNet' / 'SegNet' / 'DenseNet', but any
    # other model name now gets a path too instead of failing with an
    # unbound `model_path` at the first checkpoint.
    model_path = os.path.join(model_folder, '{}.pth'.format(args.model))

    # set optimizer
    optimizer = torch.optim.Adam(model.parameters())

    # main train
    display_steps = 30
    best_loss = 1e10
    loss_history = []

    # for early stopping
    patience = 7
    counter = 0

    for epoch in range(num_epochs):
        print('Starting epoch {}/{}'.format(epoch + 1, num_epochs))

        # ---- train ----
        model.train()

        train_metrics = defaultdict(float)
        train_size = 0

        for batch_idx, (images, masks) in enumerate(train_loader):
            images = images.to(device).float()
            masks = masks.to(device).long()

            optimizer.zero_grad()
            outputs = model(images)

            loss, cross, dice = combined_loss(outputs, masks.squeeze(1),
                                              device, n_classes)

            save_metrics(train_metrics, images.size(0), loss, cross, dice)

            loss.backward()
            optimizer.step()

            # statistics
            train_size += images.size(0)

            if batch_idx % display_steps == 0:
                print('    ', end='')
                print('batch {:>3}/{:>3} cross: {:.4f} , dice {:.4f} , combined_loss {:.4f}\r'\
                      .format(batch_idx+1, len(train_loader), cross.item(), dice.item(),loss.item()))

            del images, masks, outputs, loss, cross, dice

        print_metrics(train_metrics, train_size, 'train')

        # ---- evaluate ----
        print('Finished epoch {}, starting evaluation'.format(epoch + 1))
        model.eval()

        # Validation gets its own accumulators: reusing the training ones
        # would mix training loss into the "validation" loss that drives
        # checkpointing and early stopping.
        val_metrics = defaultdict(float)
        val_size = 0

        # No gradients are needed for validation.
        with torch.no_grad():
            for images, masks in val_loader:
                images = images.to(device).float()
                masks = masks.to(device).long()

                outputs = model(images)

                loss, cross, dice = combined_loss(outputs, masks.squeeze(1),
                                                  device, n_classes)

                save_metrics(val_metrics, images.size(0), loss, cross, dice)

                # statistics
                val_size += images.size(0)

                del images, masks, outputs, loss, cross, dice

        print_metrics(val_metrics, val_size, 'val')

        epoch_loss = val_metrics['loss'] / val_size

        # save model if best validation loss
        if epoch_loss < best_loss:
            print("saving best model")
            best_loss = epoch_loss

            # Save a CPU copy so the checkpoint loads without a GPU; unwrap
            # DataParallel so keys carry no 'module.' prefix.
            model_copy = copy.deepcopy(model).cpu()
            model_state_dict = (model_copy.module.state_dict()
                                if multi_gpu else model_copy.state_dict())
            torch.save(model_state_dict, model_path)

            del model_copy

            counter = 0
        else:
            counter += 1
            print('EarlyStopping counter : {:>3} / {:>3}'.format(
                counter, patience))

        loss_history.append(best_loss)
        print('Best val loss: {:.4f}'.format(best_loss))

        if counter >= patience:
            print('Early Stopping')
            break

    return loss_history