def train(model: nn.Module,
          loader: DataLoader,
          class_loss: nn.Module,
          optimizer: Optimizer,
          scheduler: _LRScheduler,
          epoch: int,
          callback: VisdomLogger,
          freq: int,
          ex: Experiment = None) -> None:
    """Train ``model`` for one pass over ``loader``.

    For every batch the classification loss is averaged, back-propagated and
    the optimizer is stepped; the scheduler is stepped once per batch as well.
    Per-batch loss and accuracy are accumulated in two ``AverageMeter``
    instances.

    Args:
        model: network returning a ``(logits, features)`` pair for a batch.
        loader: yields ``(batch, labels, indices)`` triples.
        class_loss: loss applied to ``(logits, labels)``; its output is
            reduced with ``.mean()``.
        optimizer: optimizer stepped after every batch.
        scheduler: learning-rate scheduler stepped after every batch.
        epoch: current epoch index, used for the progress-bar label and as
            the integer part of the logging step.
        callback: optional logger; when not ``None`` the meters'
            ``last_avg`` values are plotted every ``freq`` batches.
        freq: logging period, in batches, for ``callback``.
        ex: optional experiment object; when given, every recorded loss and
            accuracy value is reported through ``log_scalar`` after the loop.
    """
    model.train()
    device = next(model.parameters()).device
    n_batches = len(loader)

    def _move(tensor):
        # Copy onto the model's device (non_blocking=True is forwarded to
        # ``.to``).
        return tensor.to(device, non_blocking=True)

    loss_meter = AverageMeter(device=device, length=n_batches)
    acc_meter = AverageMeter(device=device, length=n_batches)

    progress = tqdm(loader, ncols=80, desc='Training   [{:03d}]'.format(epoch))
    for batch_idx, (batch, labels, indices) in enumerate(progress):
        batch = _move(batch)
        labels = _move(labels)
        indices = _move(indices)

        logits, _ = model(batch)
        loss = class_loss(logits, labels).mean()
        predictions = logits.detach().argmax(1)
        acc = (predictions == labels).float().mean()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step()

        loss_meter.append(loss)
        acc_meter.append(acc)

        if callback is None:
            continue
        if (batch_idx + 1) % freq == 0:
            # Fractional epoch so points from successive epochs line up on
            # one continuous x-axis.
            x_pos = epoch + batch_idx / n_batches
            callback.scalar('xent',
                            x_pos,
                            loss_meter.last_avg,
                            title='Train Losses')
            callback.scalar('train_acc',
                            x_pos,
                            acc_meter.last_avg,
                            title='Train Acc')

    if ex is None:
        return

    recorded = zip(loss_meter.values_list, acc_meter.values_list)
    for batch_idx, (loss_value, acc_value) in enumerate(recorded):
        x_pos = epoch + batch_idx / n_batches
        ex.log_scalar('train.loss', loss_value, step=x_pos)
        ex.log_scalar('train.acc', acc_value, step=x_pos)
def do_epoch(args: argparse.Namespace,
             train_loader: torch.utils.data.DataLoader, model: DDP,
             optimizer: torch.optim.Optimizer,
             scheduler: torch.optim.lr_scheduler, epoch: int,
             callback: VisdomLogger, iter_per_epoch: int,
             log_iter: int) -> Tuple[torch.Tensor, torch.Tensor]:
    """Run ``iter_per_epoch`` training iterations and collect periodic metrics.

    Each iteration pulls one ``(images, gt)`` batch, computes the loss through
    ``compute_loss``, and performs an optimizer step. Every ``args.log_freq``
    iterations the model is switched to eval mode, the same batch is forwarded
    again, and intersection/union statistics are computed (and all-reduced
    when ``args.distributed`` is set) to derive allAcc, mAcc and mIoU.

    Args:
        args: run configuration; reads ``num_classes_tr``, ``scheduler``,
            ``log_freq`` and ``distributed`` (and is passed to
            ``main_process`` / ``compute_loss``).
        train_loader: training data loader yielding ``(images, gt)`` pairs.
        model: the model being trained.
        optimizer: optimizer stepped once per iteration.
        scheduler: LR scheduler; stepped every iteration when
            ``args.scheduler == 'cosine'``, otherwise once after the loop.
        epoch: current epoch index, used to compute the global iteration.
        callback: optional logger used on the main process only.
        iter_per_epoch: number of iterations to run.
        log_iter: length of the returned metric tensors.

    Returns:
        ``(train_mIous, train_losses)``: two 1-D tensors of length
        ``log_iter`` placed on device ``dist.get_rank()``. They are only
        written on the main process (as decided by ``main_process(args)``);
        elsewhere they stay zero.

    Raises:
        StopIteration: if ``train_loader`` is exhausted before
            ``iter_per_epoch`` batches have been drawn.
    """
    loss_meter = AverageMeter()
    rank = dist.get_rank()
    train_losses = torch.zeros(log_iter).to(rank)
    train_mIous = torch.zeros(log_iter).to(rank)

    iterable_train_loader = iter(train_loader)

    # Only the main process shows a progress bar.
    if main_process(args):
        bar = tqdm(range(iter_per_epoch))
    else:
        bar = range(iter_per_epoch)

    for i in bar:
        # Re-enable train mode every iteration: the logging branch below
        # switches the model to eval mode.
        model.train()
        current_iter = epoch * len(train_loader) + i + 1

        # Use the built-in next(): the Python-2 style ``.next()`` alias is not
        # available on DataLoader iterators in newer PyTorch releases.
        images, gt = next(iterable_train_loader)
        images = images.to(rank, non_blocking=True)
        gt = gt.to(rank, non_blocking=True)

        loss = compute_loss(
            args=args,
            model=model,
            images=images,
            targets=gt.long(),
            num_classes=args.num_classes_tr,
        )

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if args.scheduler == 'cosine':
            scheduler.step()

        if i % args.log_freq == 0:
            model.eval()
            # Metrics only: no backward pass follows this forward, so skip
            # building an autograd graph.
            with torch.no_grad():
                logits = model(images)
            # NOTE(review): 255 is presumably the ignore label expected by
            # intersectionAndUnionGPU -- confirm against its definition.
            intersection, union, target = intersectionAndUnionGPU(
                logits.argmax(1), gt, args.num_classes_tr, 255)
            if args.distributed:
                dist.all_reduce(loss)
                dist.all_reduce(intersection)
                dist.all_reduce(union)
                dist.all_reduce(target)

            # The 1e-10 terms guard against division by zero.
            allAcc = (intersection.sum() / (target.sum() + 1e-10))  # scalar
            mAcc = (intersection / (target + 1e-10)).mean()
            mIoU = (intersection / (union + 1e-10)).mean()
            loss_meter.update(loss.item() / dist.get_world_size())

            if main_process(args):
                if callback is not None:
                    t = current_iter / len(train_loader)
                    callback.scalar('loss_train_batch',
                                    t,
                                    loss_meter.avg,
                                    title='Loss')
                    callback.scalars(['mIoU', 'mAcc', 'allAcc'],
                                     t, [mIoU, mAcc, allAcc],
                                     title='Training metrics')
                    # Only the first parameter group's learning rate is
                    # plotted.
                    if optimizer.param_groups:
                        lr = optimizer.param_groups[0]['lr']
                        callback.scalar('lr', t, lr, title='Learning rate')

                log_idx = int(i / args.log_freq)
                train_losses[log_idx] = loss_meter.avg
                train_mIous[log_idx] = mIoU

    # Non-cosine schedulers are stepped once per call rather than per
    # iteration.
    if args.scheduler != 'cosine':
        scheduler.step()

    return train_mIous, train_losses
# Example #3
# 0
def _log_val_postupdate_metrics(logger, step, mean_losses, mean_accs,
                                num_updates):
    """Plot the validation loss/accuracy of each update step at ``step``.

    ``mean_losses`` and ``mean_accs`` are indexed by update step
    (``0 .. num_updates - 1``).
    """
    for i in range(num_updates):
        logger.scalar('val_postupdate_loss_{}'.format(i), step, mean_losses[i])
        logger.scalar('val_postupdate_acc_{}'.format(i), step, mean_accs[i])


def main(args):
    """Train (optionally pre-train) a MAML model, or only test a checkpoint.

    Flow, as driven by ``args``:

    1. Resolve/create the log directory, pick the device and build the
       optional Visdom logger (only when ``args.port`` is set).
    2. Load the dataset and choose development / validation user ranges.
    3. Build the base model and the ``MAML`` wrapper; load ``args.checkpoint``
       when given.
    4. With ``args.test`` set: run ``test_and_save`` and return.
    5. Optionally pre-train for ``args.pretrain_epochs`` epochs.
    6. Meta-train for ``args.epochs`` epochs, validating after each epoch and
       saving the parameters whenever the last-step validation accuracy
       improves.
    7. Reload the best saved parameters (if any were saved) and run
       ``test_and_save``.

    Args:
        args: parsed command-line namespace holding all the options read
            below.
    """
    rng = np.random.RandomState(args.seed)

    if args.test:
        assert args.checkpoint is not None, 'Please inform the checkpoint (trained model)'

    if args.logdir is None:
        logdir = get_logdir(args)
    else:
        logdir = pathlib.Path(args.logdir)
    # parents=True: a nested log directory can be created in one go;
    # exist_ok=True: no race between an existence check and the creation.
    logdir.mkdir(parents=True, exist_ok=True)

    print('Writing logs to {}'.format(logdir))

    device = torch.device(
        'cuda',
        args.gpu_idx) if torch.cuda.is_available() else torch.device('cpu')

    if args.port is not None:
        logger = VisdomLogger(port=args.port)
    else:
        logger = None

    print('Loading Data')
    x, y, yforg, usermapping, filenames = load_dataset(args.dataset_path)

    dev_users = range(args.dev_users[0], args.dev_users[1])
    if args.devset_size is not None:
        # Randomly select users from the dev set
        dev_users = rng.choice(dev_users, args.devset_size, replace=False)

    if args.devset_sk_size is not None:
        assert args.devset_sk_size <= len(
            dev_users), 'devset-sk-size should be smaller than devset-size'

        # Randomly select users from the dev set to have skilled forgeries (others don't)
        dev_sk_users = set(
            rng.choice(dev_users, args.devset_sk_size, replace=False))
    else:
        dev_sk_users = set(dev_users)

    print('{} users in dev set; {} users with skilled forgeries'.format(
        len(dev_users), len(dev_sk_users)))

    if args.exp_users is not None:
        val_users = range(args.exp_users[0], args.exp_users[1])
        print('Testing with users from {} to {}'.format(
            args.exp_users[0], args.exp_users[1]))
    elif args.use_testset:
        val_users = range(0, 300)
        print('Testing with Exploitation set')
    else:
        val_users = range(300, 350)

    print('Initializing model')
    base_model = models.available_models[args.model]().to(device)
    weights = base_model.build_weights(device)
    # NOTE(review): args.num_updates is passed for two consecutive positional
    # parameters -- confirm against the MAML signature that this is intended.
    maml = MAML(base_model,
                args.num_updates,
                args.num_updates,
                args.train_lr,
                args.meta_lr,
                args.meta_min_lr,
                args.epochs,
                args.learn_task_lr,
                weights,
                device,
                logger,
                loss_function=balanced_binary_cross_entropy,
                is_classification=True)

    if args.checkpoint:
        # map_location: a checkpoint saved on another device (e.g. a GPU that
        # is not present here) is remapped onto the device in use.
        params = torch.load(args.checkpoint, map_location=device)
        maml.load(params)

    if args.test:
        test_and_save(args, device, logdir, maml, val_users, x, y, yforg)
        return

    # Pretraining
    if args.pretrain_epochs > 0:
        print('Pre-training')
        data = util.get_subset((x, y, yforg), subset=range(350, 881))

        wrapped_model = PretrainWrapper(base_model, weights)

        if not args.pretrain_forg:
            data = util.remove_forgeries(data, forg_idx=2)

        train_loader, val_loader = pretrain.setup_data_loaders(
            data, 32, args.input_size)
        n_classes = len(np.unique(y))

        classification_layer = nn.Linear(base_model.feature_space_size,
                                         n_classes).to(device)
        if args.pretrain_forg:
            forg_layer = nn.Linear(base_model.feature_space_size, 1).to(device)
        else:
            forg_layer = nn.Module()  # Stub module with no parameters

        pretrain_args = argparse.Namespace(lr=0.01,
                                           lr_decay=0.1,
                                           lr_decay_times=1,
                                           momentum=0.9,
                                           weight_decay=0.001,
                                           forg=args.pretrain_forg,
                                           lamb=args.pretrain_forg_lambda,
                                           epochs=args.pretrain_epochs)
        print(pretrain_args)
        pretrain.train(wrapped_model,
                       classification_layer,
                       forg_layer,
                       train_loader,
                       val_loader,
                       device,
                       logger,
                       pretrain_args,
                       logdir=None)

    # MAML training

    trainset = MAMLDataSet(data=(x, y, yforg),
                           subset=dev_users,
                           sk_subset=dev_sk_users,
                           num_gen_train=args.num_gen,
                           num_rf_train=args.num_rf,
                           num_gen_test=args.num_gen_test,
                           num_rf_test=args.num_rf_test,
                           num_sk_test=args.num_sk_test,
                           input_shape=args.input_size,
                           test=False,
                           rng=np.random.RandomState(args.seed))

    val_set = MAMLDataSet(data=(x, y, yforg),
                          subset=val_users,
                          num_gen_train=args.num_gen,
                          num_rf_train=args.num_rf,
                          num_gen_test=args.num_gen_test,
                          num_rf_test=args.num_rf_test,
                          num_sk_test=args.num_sk_test,
                          input_shape=args.input_size,
                          test=True,
                          rng=np.random.RandomState(args.seed))

    loader = DataLoader(trainset,
                        batch_size=args.meta_batch_size,
                        shuffle=True,
                        num_workers=2,
                        collate_fn=trainset.collate_fn)

    best_model_path = logdir / 'best_model.pth'

    print('Training')
    best_val_acc = 0
    with tqdm(initial=0, total=len(loader) * args.epochs) as pbar:
        if args.checkpoint is not None:
            # Starting from a checkpoint: plot its validation metrics at
            # x = 0 before any further training.
            postupdate_accs, postupdate_losses, preupdate_losses = test_one_epoch(
                maml, val_set, device, args.num_updates)

            if logger:
                _log_val_postupdate_metrics(
                    logger, 0,
                    np.mean(postupdate_losses, axis=0),
                    np.mean(postupdate_accs, axis=0),
                    args.num_updates)

        for epoch in range(args.epochs):
            loss_weights = get_per_step_loss_importance_vector(
                args.num_updates, args.msl_epochs, epoch)

            n_batches = len(loader)
            for step, item in enumerate(loader):
                item = move_to_gpu(*item, device=device)
                maml.meta_learning_step((item[0], item[1]), (item[2], item[3]),
                                        loss_weights, epoch + step / n_batches)
                pbar.update(1)

            maml.scheduler.step()

            postupdate_accs, postupdate_losses, preupdate_losses = test_one_epoch(
                maml, val_set, device, args.num_updates)

            # Average once over axis 0; the result is indexed by update step.
            mean_losses = np.mean(postupdate_losses, axis=0)
            mean_accs = np.mean(postupdate_accs, axis=0)

            if logger:
                _log_val_postupdate_metrics(logger, epoch + 1, mean_losses,
                                            mean_accs, args.num_updates)
                logger.save(logdir / 'train_curves.pickle')

            # Model selection uses the metrics after the last update step.
            this_val_loss = mean_losses[-1]
            this_val_acc = mean_accs[-1]

            if this_val_acc > best_val_acc:
                best_val_acc = this_val_acc
                torch.save(maml.parameters, best_model_path)
            print('Epoch {}. Val loss: {:.4f}. Val Acc: {:.2f}%'.format(
                epoch, this_val_loss, this_val_acc * 100))

    # Re-load best parameters and test with 10 folds
    if best_model_path.exists():
        params = torch.load(best_model_path)
        maml.load(params)
    else:
        # No epoch improved on the initial best accuracy (or no epoch ran),
        # so nothing was saved; evaluate the parameters currently held.
        print('No best model was saved; testing with the current parameters')

    test_and_save(args, device, logdir, maml, val_users, x, y, yforg)