예제 #1
0
def main():
    """Evaluate a saved checkpoint on the meta-validation and meta-test splits.

    Parses the command-line options, builds the episodic val/test loaders for
    ``opt.dataset``, restores the weights stored under the ``'model'`` key of
    the checkpoint at ``opt.model_path`` and prints the two values returned by
    ``meta_test`` (reported as accuracy and std) together with the wall time
    for four runs: val and test, each with the default arguments and with
    ``use_logit=False``.

    Raises:
        NotImplementedError: if ``opt.dataset`` is not one of
            ``'miniImageNet'``, ``'tieredImageNet'``, ``'CIFAR-FS'`` or
            ``'FC100'``.
    """
    opt = parse_option()

    # Select the episodic dataset class, the transform key and the class count
    # handed to create_model (it must match the checkpoint's classifier head).
    if opt.dataset == 'miniImageNet':
        dataset_cls = MetaImageNet
        transform_key = opt.transform
        n_cls = 80 if opt.use_trainval else 64
    elif opt.dataset == 'tieredImageNet':
        dataset_cls = MetaTieredImageNet
        transform_key = opt.transform
        n_cls = 448 if opt.use_trainval else 351
    elif opt.dataset in ('CIFAR-FS', 'FC100'):
        dataset_cls = MetaCIFAR100
        # The CIFAR derivatives always use transform set 'D', regardless of
        # opt.transform.
        transform_key = 'D'
        if opt.use_trainval:
            n_cls = 80
        else:
            n_cls = 64 if opt.dataset == 'CIFAR-FS' else 60
    else:
        raise NotImplementedError(opt.dataset)

    train_trans, test_trans = transforms_test_options[transform_key]

    def build_loader(partition):
        """Return the episodic DataLoader for the given partition name."""
        episodes = dataset_cls(args=opt,
                               partition=partition,
                               train_transform=train_trans,
                               test_transform=test_trans,
                               fix_seed=False)
        return DataLoader(episodes,
                          batch_size=opt.test_batch_size,
                          shuffle=False,
                          drop_last=False,
                          num_workers=opt.num_workers)

    meta_testloader = build_loader('test')
    meta_valloader = build_loader('val')

    # load model
    model = create_model(opt.model, n_cls, opt.dataset)
    # Checkpoints written on a GPU cannot be deserialised on a CPU-only host
    # unless the tensors are remapped; with CUDA present keep the default.
    map_location = None if torch.cuda.is_available() else 'cpu'
    ckpt = torch.load(opt.model_path, map_location=map_location)
    model.load_state_dict(ckpt['model'])

    if torch.cuda.is_available():
        model = model.cuda()
        cudnn.benchmark = True

    def timed_eval(acc_label, std_label, loader, **meta_test_kwargs):
        """Run meta_test on ``loader`` and print its result and duration."""
        start = time.time()
        acc, std = meta_test(model, loader, **meta_test_kwargs)
        elapsed = time.time() - start
        print('{}: {:.4f}, {}: {:.4f}, time: {:.1f}'.format(
            acc_label, acc, std_label, std, elapsed))

    # evaluation
    # The "*_feat" runs reuse the plain "val_std"/"test_std" label in their
    # output; this is kept so the printed lines stay unchanged.
    timed_eval('val_acc', 'val_std', meta_valloader)
    timed_eval('val_acc_feat', 'val_std', meta_valloader, use_logit=False)
    timed_eval('test_acc', 'test_std', meta_testloader)
    timed_eval('test_acc_feat', 'test_std', meta_testloader, use_logit=False)
예제 #2
0
def main():
    """Run supervised pre-training and write periodic and final checkpoints.

    Parses the command-line options, builds the training and validation
    loaders for ``opt.dataset``, creates the model (or restores it from
    ``<opt.save_folder>/latest.pth`` when ``opt.load_latest`` is set), trains
    for ``opt.epochs`` epochs while logging train/validation metrics, saves a
    checkpoint every ``opt.save_freq`` epochs (``latest.pth`` is re-pointed at
    the newest one) and finally saves ``<opt.model>_last.pth``.

    Raises:
        NotImplementedError: if ``opt.dataset`` is not one of
            ``'miniImageNet'``, ``'tieredImageNet'``, ``'CIFAR-FS'`` or
            ``'FC100'``.
    """
    opt = parse_option()

    # datasets
    train_partition = 'trainval' if opt.use_trainval else 'train'
    if opt.dataset == 'miniImageNet':
        train_trans, test_trans = transforms_options[opt.transform]
        train_set = ImageNet(args=opt,
                             partition=train_partition,
                             transform=train_trans)
        val_set = ImageNet(args=opt, partition='val', transform=test_trans)
        n_cls = 80 if opt.use_trainval else 64
    elif opt.dataset == 'tieredImageNet':
        train_trans, test_trans = transforms_options[opt.transform]
        train_set = TieredImageNet(args=opt,
                                   partition=train_partition,
                                   transform=train_trans)
        val_set = TieredImageNet(args=opt,
                                 partition='train_phase_val',
                                 transform=test_trans)
        n_cls = 448 if opt.use_trainval else 351
    elif opt.dataset in ('CIFAR-FS', 'FC100'):
        # The CIFAR derivatives always use transform set 'D'.
        train_trans, test_trans = transforms_options['D']
        train_set = CIFAR100(args=opt,
                             partition=train_partition,
                             transform=train_trans)
        # NOTE(review): validation reads the 'train' partition (with the test
        # transform), exactly as before - confirm this is intended.
        val_set = CIFAR100(args=opt, partition='train', transform=test_trans)
        if opt.use_trainval:
            n_cls = 80
        else:
            n_cls = 64 if opt.dataset == 'CIFAR-FS' else 60
    else:
        raise NotImplementedError(opt.dataset)

    # dataloaders (only the loaders the training loop actually consumes)
    train_loader = DataLoader(train_set,
                              batch_size=opt.batch_size,
                              shuffle=True,
                              drop_last=True,
                              num_workers=opt.num_workers)
    val_loader = DataLoader(val_set,
                            batch_size=opt.batch_size // 2,
                            shuffle=False,
                            drop_last=False,
                            num_workers=opt.num_workers // 2)

    # model
    if not opt.load_latest:
        model = create_model(opt.model, n_cls, opt.dataset)
    else:
        latest_file = os.path.join(opt.save_folder, 'latest.pth')
        model = load_teacher(latest_file, n_cls, opt.dataset)

    # optimizer
    if opt.adam:
        optimizer = torch.optim.Adam(model.parameters(),
                                     lr=opt.learning_rate,
                                     weight_decay=0.0005)
    else:
        optimizer = optim.SGD(model.parameters(),
                              lr=opt.learning_rate,
                              momentum=opt.momentum,
                              weight_decay=opt.weight_decay)

    criterion = nn.CrossEntropyLoss()

    if torch.cuda.is_available():
        if opt.n_gpu > 1:
            model = nn.DataParallel(model)
        model = model.cuda()
        criterion = criterion.cuda()
        cudnn.benchmark = True

    def current_state_dict():
        """Return the weights of the bare model, unwrapping DataParallel."""
        # Decide by the actual wrapper rather than by opt.n_gpu: the model is
        # only wrapped when CUDA is available, so n_gpu > 1 alone does not
        # guarantee that ``model.module`` exists.
        net = model.module if isinstance(model, nn.DataParallel) else model
        return net.state_dict()

    # tensorboard
    logger = tb_logger.Logger(logdir=opt.tb_folder, flush_secs=2)

    # set cosine annealing scheduler
    if opt.cosine:
        eta_min = opt.learning_rate * (opt.lr_decay_rate**3)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, opt.epochs, eta_min, -1)

    # routine: supervised pre-training
    for epoch in range(1, opt.epochs + 1):

        # NOTE(review): the scheduler is stepped before the epoch's optimizer
        # updates, whereas PyTorch >= 1.1 documents the opposite order. Left
        # unchanged so the learning-rate schedule stays the same - confirm.
        if opt.cosine:
            scheduler.step()
        else:
            adjust_learning_rate(epoch, opt, optimizer)
        print("==> training...")

        epoch_start = time.time()
        train_acc, train_loss = train(epoch, train_loader, model, criterion,
                                      optimizer, opt)
        epoch_end = time.time()
        print('epoch {}, total time {:.2f}'.format(epoch,
                                                   epoch_end - epoch_start))

        logger.log_value('train_acc', train_acc, epoch)
        logger.log_value('train_loss', train_loss, epoch)

        test_acc, test_acc_top5, test_loss = validate(val_loader, model,
                                                      criterion, opt)

        logger.log_value('test_acc', test_acc, epoch)
        logger.log_value('test_acc_top5', test_acc_top5, epoch)
        logger.log_value('test_loss', test_loss, epoch)

        # regular saving
        if epoch % opt.save_freq == 0:
            print('==> Saving...')
            state = {
                'epoch': epoch,
                'model': current_state_dict(),
            }
            save_file = os.path.join(
                opt.save_folder, 'ckpt_epoch_{epoch}.pth'.format(epoch=epoch))
            torch.save(state, save_file)

            # Re-point latest.pth at the checkpoint just written.
            # os.symlink refuses to overwrite, so drop the previous link
            # first; lexists also catches a dangling link.
            latest_file = os.path.join(opt.save_folder, 'latest.pth')
            if os.path.lexists(latest_file):
                os.remove(latest_file)
            # Both files live in opt.save_folder and a relative link target is
            # resolved against the link's own directory, so link to the bare
            # file name; this also works when save_folder is a relative path.
            os.symlink(os.path.basename(save_file), latest_file)

    # save the last model
    state = {
        'opt': opt,
        'model': current_state_dict(),
    }
    save_file = os.path.join(opt.save_folder, '{}_last.pth'.format(opt.model))
    torch.save(state, save_file)
예제 #3
0
def main():
    """Run supervised pre-training with an optional auxiliary loss.

    Parses the command-line options, dumps them to
    ``<opt.tb_folder>/config.json``, builds the training and validation
    loaders for ``opt.dataset``, creates the model, optimizer, main criterion
    (label smoothing, guided complement entropy or cross entropy) and the
    optional auxiliary loss (``opt.opl`` / ``opt.popl``), then trains for
    ``opt.epochs`` epochs while logging train/validation metrics. A checkpoint
    is written every ``opt.save_freq`` epochs and ``<opt.model>_last.pth`` is
    written at the end.

    Raises:
        NotImplementedError: if ``opt.dataset`` is not one of
            ``'miniImageNet'``, ``'tieredImageNet'``, ``'CIFAR-FS'``,
            ``'FC100'`` or ``'imagenet'``.
    """
    opt = parse_option()

    with open(f"{opt.tb_folder}/config.json", "w") as fo:
        fo.write(json.dumps(vars(opt), indent=4))

    # datasets
    train_partition = 'trainval' if opt.use_trainval else 'train'
    if opt.dataset == 'miniImageNet':
        train_trans, test_trans = transforms_options[opt.transform]
        train_dataset = ImageNet(args=opt,
                                 partition=train_partition,
                                 transform=train_trans)
        val_dataset = ImageNet(args=opt,
                               partition='val',
                               transform=test_trans)
        n_cls = 80 if opt.use_trainval else 64
    elif opt.dataset == 'tieredImageNet':
        train_trans, test_trans = transforms_options[opt.transform]
        train_dataset = TieredImageNet(args=opt,
                                       partition=train_partition,
                                       transform=train_trans)
        val_dataset = TieredImageNet(args=opt,
                                     partition='train_phase_val',
                                     transform=test_trans)
        n_cls = 448 if opt.use_trainval else 351
    elif opt.dataset in ('CIFAR-FS', 'FC100'):
        # The CIFAR derivatives always use transform set 'D'.
        train_trans, test_trans = transforms_options['D']
        train_dataset = CIFAR100(args=opt,
                                 partition=train_partition,
                                 transform=train_trans)
        # NOTE(review): validation reads the 'train' partition (with the test
        # transform), exactly as before - confirm this is intended.
        val_dataset = CIFAR100(args=opt,
                               partition='train',
                               transform=test_trans)
        if opt.use_trainval:
            n_cls = 80
        else:
            n_cls = 64 if opt.dataset == 'CIFAR-FS' else 60
    elif opt.dataset == "imagenet":
        train_trans, test_trans = transforms_options["A"]
        train_dataset = ImagenetFolder(root=os.path.join(
            opt.data_root, "train"),
                                       transform=train_trans)
        val_dataset = ImagenetFolder(root=os.path.join(opt.data_root, "val"),
                                     transform=test_trans)
        n_cls = 1000
    else:
        raise NotImplementedError(opt.dataset)

    # dataloaders (only the loaders the training loop actually consumes)
    train_loader = DataLoader(train_dataset,
                              batch_size=opt.batch_size,
                              shuffle=True,
                              drop_last=True,
                              num_workers=opt.num_workers)
    val_loader = DataLoader(val_dataset,
                            batch_size=opt.batch_size // 2,
                            shuffle=False,
                            drop_last=False,
                            num_workers=opt.num_workers // 2)

    # model
    model = create_model(opt.model, n_cls, opt.dataset, use_srl=opt.srl)

    # optimizer
    if opt.adam:
        optimizer = torch.optim.Adam(model.parameters(),
                                     lr=opt.learning_rate,
                                     weight_decay=0.0005)
    else:
        optimizer = optim.SGD(model.parameters(),
                              lr=opt.learning_rate,
                              momentum=opt.momentum,
                              weight_decay=opt.weight_decay)

    # main criterion
    if opt.label_smoothing:
        criterion = LabelSmoothing(smoothing=opt.smoothing_ratio)
    elif opt.gce:
        criterion = GuidedComplementEntropy(alpha=opt.gce_alpha, classes=n_cls)
    else:
        criterion = nn.CrossEntropyLoss()

    # optional auxiliary loss
    if opt.opl:
        auxiliary_loss = OrthogonalProjectionLoss(use_attention=True)
    elif opt.popl:
        auxiliary_loss = PerpetualOrthogonalProjectionLoss(feat_dim=640)
    else:
        auxiliary_loss = None
    use_auxiliary = auxiliary_loss is not None

    if torch.cuda.is_available():
        if opt.n_gpu > 1:
            model = nn.DataParallel(model)
        model = model.cuda()
        criterion = criterion.cuda()
        if use_auxiliary:
            auxiliary_loss = auxiliary_loss.cuda()
        cudnn.benchmark = True

    def current_state_dict():
        """Return the weights of the bare model, unwrapping DataParallel."""
        # Decide by the actual wrapper rather than by opt.n_gpu: the model is
        # only wrapped when CUDA is available, so n_gpu > 1 alone does not
        # guarantee that ``model.module`` exists.
        net = model.module if isinstance(model, nn.DataParallel) else model
        return net.state_dict()

    # tensorboard
    logger = tb_logger.Logger(logdir=opt.tb_folder, flush_secs=2)

    # set cosine annealing scheduler
    if opt.cosine:
        eta_min = opt.learning_rate * (opt.lr_decay_rate**3)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, opt.epochs, eta_min, -1)
    else:
        scheduler = None

    # routine: supervised pre-training
    for epoch in range(1, opt.epochs + 1):

        # NOTE(review): the scheduler is stepped before the epoch's optimizer
        # updates, whereas PyTorch >= 1.1 documents the opposite order. Left
        # unchanged so the learning-rate schedule stays the same - confirm.
        if opt.cosine:
            scheduler.step()
        else:
            adjust_learning_rate(epoch, opt, optimizer)
        print("==> training...")

        # train() returns an extra [cel, opl] pair only when it is given an
        # auxiliary loss; without one, the total loss is logged as the "cel"
        # component too.
        epoch_start = time.time()
        if use_auxiliary:
            train_acc, train_loss, (train_cel, train_opl) = train(
                epoch=epoch,
                train_loader=train_loader,
                model=model,
                criterion=criterion,
                optimizer=optimizer,
                opt=opt,
                auxiliary=auxiliary_loss)
        else:
            train_acc, train_loss = train(epoch=epoch,
                                          train_loader=train_loader,
                                          model=model,
                                          criterion=criterion,
                                          optimizer=optimizer,
                                          opt=opt)
            train_cel = train_loss
        epoch_end = time.time()
        print('epoch {}, total time {:.2f}'.format(epoch,
                                                   epoch_end - epoch_start))

        logger.log_value('accuracy/train_acc', train_acc, epoch)
        logger.log_value('train_losses/loss', train_loss, epoch)
        logger.log_value('train_losses/cel', train_cel, epoch)
        if use_auxiliary:
            logger.log_value('train_losses/opl', train_opl, epoch)

        # validate() mirrors train(): the [cel, opl] pair is only returned
        # when an auxiliary loss is passed.
        if use_auxiliary:
            test_acc, test_acc_top5, test_loss, (test_cel, test_opl) = \
                validate(val_loader, model, criterion, opt,
                         auxiliary=auxiliary_loss)
        else:
            test_acc, test_acc_top5, test_loss = validate(
                val_loader, model, criterion, opt)
            test_cel = test_loss

        logger.log_value('accuracy/test_acc', test_acc, epoch)
        logger.log_value('accuracy/test_acc_top5', test_acc_top5, epoch)
        logger.log_value('test_losses/loss', test_loss, epoch)
        logger.log_value('test_losses/cel', test_cel, epoch)
        if use_auxiliary:
            logger.log_value('test_losses/opl', test_opl, epoch)

        # regular saving
        if epoch % opt.save_freq == 0:
            print('==> Saving...')
            state = {
                'epoch': epoch,
                'model': current_state_dict(),
            }
            save_file = os.path.join(
                opt.save_folder, 'ckpt_epoch_{epoch}.pth'.format(epoch=epoch))
            torch.save(state, save_file)

    # save the last model
    state = {
        'opt': opt,
        'model': current_state_dict(),
    }
    save_file = os.path.join(opt.save_folder, '{}_last.pth'.format(opt.model))
    torch.save(state, save_file)
예제 #4
0
파일: loaders.py 프로젝트: yassouali/SCL
def get_eval_loaders(opt):
    """
    Build the episodic evaluation dataloaders for ``opt.dataset``.

    Returns a ``(meta_testloader, meta_valloader, n_cls)`` tuple. ``n_cls`` is
    a per-dataset constant (presumably the number of training classes -
    confirm against the caller). Each dataset family asserts the transform key
    it expects in ``opt.transform``; an unknown dataset name raises
    ``NotImplementedError``.
    """
    train_trans, test_trans = transforms_options[opt.transform]

    def _episodic_loader(dataset_cls, partition):
        # Every evaluation loader shares the same settings; only the dataset
        # class and the partition name vary.
        episodes = dataset_cls(args=opt,
                               partition=partition,
                               train_transform=train_trans,
                               test_transform=test_trans,
                               fix_seed=False)
        return DataLoader(episodes,
                          batch_size=opt.test_batch_size,
                          shuffle=False,
                          drop_last=False,
                          num_workers=opt.num_workers)

    # ImageNet derivatives - miniImageNet
    if opt.dataset == 'miniImageNet':
        assert opt.transform == "A"
        meta_testloader = _episodic_loader(MetaImageNet, 'test')
        meta_valloader = _episodic_loader(MetaImageNet, 'val')
        n_cls = 80 if opt.use_trainval else 64

    # ImageNet derivatives - tieredImageNet
    elif opt.dataset == 'tieredImageNet':
        assert opt.transform == "A"
        meta_testloader = _episodic_loader(MetaTieredImageNet, 'test')
        meta_valloader = _episodic_loader(MetaTieredImageNet, 'val')
        n_cls = 448 if opt.use_trainval else 351

    # CIFAR-100 derivatives - both CIFAR-FS & FC100
    elif opt.dataset in ('CIFAR-FS', 'FC100'):
        assert opt.transform == "D"
        meta_testloader = _episodic_loader(MetaCIFAR100, 'test')
        meta_valloader = _episodic_loader(MetaCIFAR100, 'val')
        if opt.use_trainval:
            n_cls = 80
        elif opt.dataset == 'CIFAR-FS':
            n_cls = 64
        elif opt.dataset == 'FC100':
            n_cls = 60
        else:
            raise NotImplementedError('dataset not supported: {}'.format(opt.dataset))

    # Cross-domain - evaluation happens on a new dataset / domain
    elif opt.dataset in ['cub', 'cars', 'places', 'plantae']:
        train_classes = {'cub': 100, 'cars': 98, 'places': 183, 'plantae': 100}
        assert opt.transform == "C"
        assert not opt.use_trainval, f"Train val option not possible for dataset {opt.dataset}"

        # The test split of these datasets is the 'novel' partition.
        meta_testloader = _episodic_loader(MetaCUB, 'novel')
        meta_valloader = _episodic_loader(MetaCUB, 'val')
        n_cls = train_classes[opt.dataset]

    else:
        raise NotImplementedError(opt.dataset)

    return meta_testloader, meta_valloader, n_cls
예제 #5
0
def main():
    """Distill a pre-trained teacher network into a student network.

    Steps, in order:

    1. Parse options and open a tensorboard logger.
    2. Build the train / val / meta-test / meta-val loaders for
       ``opt.dataset``.
    3. Load the teacher (``opt.path_t``) and create the student
       (``opt.model_s``).
    4. Assemble the classification, KL-divergence and ``opt.distill``
       specific criteria, plus the SGD optimizer over the trainable modules.
    5. Validate the teacher once, then train for ``opt.epochs`` epochs,
       logging metrics and checkpointing the student every
       ``opt.save_freq`` epochs.
    6. Save the final student weights as ``<opt.model_s>_last.pth`` in
       ``opt.save_folder``.

    Raises:
        NotImplementedError: if ``opt.dataset`` or ``opt.distill`` is not
            one of the values handled below.
    """
    opt = parse_option()

    # tensorboard logger
    logger = tb_logger.Logger(logdir=opt.tb_folder, flush_secs=2)

    # dataloader
    # Every supported dataset is wired up with the same recipe; only the
    # dataset classes, the transform key, the partition used for validation
    # and the number of classes differ. Class names are resolved inside the
    # selected branch only.
    train_partition = 'trainval' if opt.use_trainval else 'train'
    if opt.dataset == 'miniImageNet':
        dataset_cls, meta_cls = ImageNet, MetaImageNet
        transform_key, val_partition = opt.transform, 'val'
        n_cls = 80 if opt.use_trainval else 64
    elif opt.dataset == 'tieredImageNet':
        dataset_cls, meta_cls = TieredImageNet, MetaTieredImageNet
        transform_key, val_partition = opt.transform, 'train_phase_val'
        n_cls = 448 if opt.use_trainval else 351
    elif opt.dataset in ('CIFAR-FS', 'FC100'):
        dataset_cls, meta_cls = CIFAR100, MetaCIFAR100
        # These two datasets always use transform 'D' and validate on the
        # 'train' partition.
        transform_key, val_partition = 'D', 'train'
        if opt.use_trainval:
            n_cls = 80
        else:
            n_cls = 64 if opt.dataset == 'CIFAR-FS' else 60
    else:
        raise NotImplementedError(opt.dataset)

    train_trans, test_trans = transforms_options[transform_key]

    # Only the contrastive objective asks the training set for the extra
    # ``is_sample`` / ``k`` arguments.
    if opt.distill == 'contrast':
        sample_kwargs = {'is_sample': True, 'k': opt.nce_k}
    else:
        sample_kwargs = {}
    train_set = dataset_cls(args=opt,
                            partition=train_partition,
                            transform=train_trans,
                            **sample_kwargs)
    n_data = len(train_set)  # consumed by NCELoss below
    train_loader = DataLoader(train_set,
                              batch_size=opt.batch_size,
                              shuffle=True,
                              drop_last=True,
                              num_workers=opt.num_workers)
    val_loader = DataLoader(dataset_cls(args=opt,
                                        partition=val_partition,
                                        transform=test_trans),
                            batch_size=opt.batch_size // 2,
                            shuffle=False,
                            drop_last=False,
                            num_workers=opt.num_workers // 2)
    # NOTE(review): the two meta loaders are built but not used further down
    # in this function; they are kept so dataset construction is unchanged.
    meta_testloader = DataLoader(meta_cls(args=opt,
                                          partition='test',
                                          train_transform=train_trans,
                                          test_transform=test_trans),
                                 batch_size=opt.test_batch_size,
                                 shuffle=False,
                                 drop_last=False,
                                 num_workers=opt.num_workers)
    meta_valloader = DataLoader(meta_cls(args=opt,
                                         partition='val',
                                         train_transform=train_trans,
                                         test_transform=test_trans),
                                batch_size=opt.test_batch_size,
                                shuffle=False,
                                drop_last=False,
                                num_workers=opt.num_workers)

    # model
    model_t = load_teacher(opt.path_t, n_cls, opt.dataset)
    model_s = create_model(opt.model_s, n_cls, opt.dataset)

    # Push a dummy batch through both networks to read the feature shapes
    # needed by the embedding heads. No gradients are needed for this probe.
    data = torch.randn(2, 3, 84, 84)
    model_t.eval()
    model_s.eval()
    with torch.no_grad():
        feat_t, _ = model_t(data, is_feat=True)
        feat_s, _ = model_s(data, is_feat=True)

    # module_list: everything moved to the GPU and handed to train();
    # trainable_list: the subset whose parameters the optimizer updates.
    module_list = nn.ModuleList([])
    module_list.append(model_s)
    trainable_list = nn.ModuleList([])
    trainable_list.append(model_s)

    criterion_cls = nn.CrossEntropyLoss()
    criterion_div = DistillKL(opt.kd_T)
    if opt.distill == 'kd':
        criterion_kd = DistillKL(opt.kd_T)
    elif opt.distill == 'contrast':
        criterion_kd = NCELoss(opt, n_data)
        # Project the last student / teacher features to a common width.
        embed_s = Embed(feat_s[-1].shape[1], opt.feat_dim)
        embed_t = Embed(feat_t[-1].shape[1], opt.feat_dim)
        module_list.append(embed_s)
        module_list.append(embed_t)
        trainable_list.append(embed_s)
        trainable_list.append(embed_t)
    elif opt.distill == 'attention':
        criterion_kd = Attention()
    elif opt.distill == 'hint':
        criterion_kd = HintLoss()
    else:
        raise NotImplementedError(opt.distill)

    criterion_list = nn.ModuleList([])
    criterion_list.append(criterion_cls)  # classification loss
    criterion_list.append(
        criterion_div)  # KL divergence loss, original knowledge distillation
    criterion_list.append(criterion_kd)  # other knowledge distillation loss

    # optimizer
    optimizer = optim.SGD(trainable_list.parameters(),
                          lr=opt.learning_rate,
                          momentum=opt.momentum,
                          weight_decay=opt.weight_decay)

    # append teacher after optimizer to avoid weight_decay
    module_list.append(model_t)

    if torch.cuda.is_available():
        module_list.cuda()
        criterion_list.cuda()
        cudnn.benchmark = True

    # validate teacher accuracy
    teacher_acc, _, _ = validate(val_loader, model_t, criterion_cls, opt)
    print('teacher accuracy: ', teacher_acc)

    # set cosine annealing scheduler
    if opt.cosine:
        eta_min = opt.learning_rate * (opt.lr_decay_rate**3)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, opt.epochs, eta_min, -1)

    # routine: supervised model distillation
    for epoch in range(1, opt.epochs + 1):

        # NOTE(review): the scheduler is stepped before this epoch's
        # optimizer updates; recent PyTorch documents the opposite order.
        # Left unchanged to keep the learning-rate schedule identical -
        # confirm against the PyTorch version this script targets.
        if opt.cosine:
            scheduler.step()
        else:
            adjust_learning_rate(epoch, opt, optimizer)
        print("==> training...")

        time1 = time.time()
        train_acc, train_loss = train(epoch, train_loader, module_list,
                                      criterion_list, optimizer, opt)
        time2 = time.time()
        print('epoch {}, total time {:.2f}'.format(epoch, time2 - time1))

        logger.log_value('train_acc', train_acc, epoch)
        logger.log_value('train_loss', train_loss, epoch)

        test_acc, test_acc_top5, test_loss = validate(val_loader, model_s,
                                                      criterion_cls, opt)

        logger.log_value('test_acc', test_acc, epoch)
        logger.log_value('test_acc_top5', test_acc_top5, epoch)
        logger.log_value('test_loss', test_loss, epoch)

        # regular saving
        if epoch % opt.save_freq == 0:
            print('==> Saving...')
            state = {
                'epoch': epoch,
                'model': model_s.state_dict(),
            }
            save_file = os.path.join(
                opt.save_folder, 'ckpt_epoch_{epoch}.pth'.format(epoch=epoch))
            torch.save(state, save_file)

    # save the last model
    state = {
        'opt': opt,
        'model': model_s.state_dict(),
    }
    save_file = os.path.join(opt.save_folder,
                             '{}_last.pth'.format(opt.model_s))
    torch.save(state, save_file)
예제 #6
0
                                    num_workers=opt.num_workers)
        if opt.use_trainval:
            n_cls = 448
        else:
            n_cls = 351
    elif opt.dataset == 'CIFAR-FS' or opt.dataset == 'FC100':
        train_trans, test_trans = transforms_options['D']

        train_loader = DataLoader(CIFAR100(args=opt, partition=train_partition, transform=train_trans),
                                  batch_size=opt.batch_size, shuffle=True, drop_last=True,
                                  num_workers=opt.num_workers)
        val_loader = DataLoader(CIFAR100(args=opt, partition='train', transform=test_trans),
                                batch_size=opt.batch_size // 2, shuffle=False, drop_last=False,
                                num_workers=opt.num_workers // 2)
        meta_testloader = DataLoader(MetaCIFAR100(args=opt, partition='test',
                                                  train_transform=train_trans,
                                                  test_transform=test_trans),
                                     batch_size=opt.test_batch_size, shuffle=False, drop_last=False,
                                     num_workers=opt.num_workers)
        meta_valloader = DataLoader(MetaCIFAR100(args=opt, partition='val',
                                                 train_transform=train_trans,
                                                 test_transform=test_trans),
                                    batch_size=opt.test_batch_size, shuffle=False, drop_last=False,
                                    num_workers=opt.num_workers)
        if opt.use_trainval:
            n_cls = 80
        else:
            if opt.dataset == 'CIFAR-FS':
                n_cls = 64
            elif opt.dataset == 'FC100':
                n_cls = 60
예제 #7
0
def main():
    """Run supervised pre-training (optionally with a language model) and
    meta-test the result, reporting everything to Weights & Biases.

    Steps, in order:

    1. Parse options and start a wandb run (named ``opt.name`` if given).
    2. Build the train / val / meta-test / meta-val loaders for
       ``opt.dataset``.
    3. Create the model, the optimizer (Adam if ``opt.adam`` else SGD) and,
       when ``opt.lsl`` is set, the embedding + ``TextProposal`` language
       model.
    4. Train for ``opt.epochs`` epochs, validating and logging each epoch and
       tracking the best validation accuracy in the wandb run summary.
    5. Save the final weights as ``<opt.model>_last.pth`` in the wandb run
       directory.
    6. Meta-test twice (with and without the logit layer as the embedding)
       and store the results in the wandb run summary.

    Raises:
        NotImplementedError: if ``opt.dataset`` is unsupported, if
            ``opt.use_trainval`` is requested for ``CUB_200_2011``, or if
            ``opt.lsl`` is requested for a dataset that loads no vocabulary.
    """
    opt = parse_option()

    if opt.name is not None:
        wandb.init(name=opt.name)
    else:
        wandb.init()
    wandb.config.update(opt)

    # dataloader
    # Every supported dataset is wired up with the same recipe; only the
    # dataset classes, the transform key, the validation partition, the class
    # count and (for CUB) an extra ``vocab`` argument differ. Class names are
    # resolved inside the selected branch only.
    train_partition = 'trainval' if opt.use_trainval else 'train'
    vocab = None    # token -> index mapping; only loaded for CUB with lsl
    devocab = None  # inverse of ``vocab``
    dataset_kwargs = {}
    if opt.dataset == 'miniImageNet':
        dataset_cls, meta_cls = ImageNet, MetaImageNet
        transform_key, val_partition = opt.transform, 'val'
        n_cls = 80 if opt.use_trainval else 64
    elif opt.dataset == 'tieredImageNet':
        dataset_cls, meta_cls = TieredImageNet, MetaTieredImageNet
        transform_key, val_partition = opt.transform, 'train_phase_val'
        n_cls = 448 if opt.use_trainval else 351
    elif opt.dataset in ('CIFAR-FS', 'FC100'):
        dataset_cls, meta_cls = CIFAR100, MetaCIFAR100
        # These two datasets always use transform 'D' and validate on the
        # 'train' partition.
        transform_key, val_partition = 'D', 'train'
        if opt.use_trainval:
            n_cls = 80
        else:
            n_cls = 64 if opt.dataset == 'CIFAR-FS' else 60
    elif opt.dataset == 'CUB_200_2011':
        if opt.use_trainval:
            # Fail before loading any data.
            raise NotImplementedError(opt.dataset)  # no trainval supported yet
        dataset_cls, meta_cls = CUB2011, MetaCUB2011
        transform_key, val_partition = 'C', 'val'
        n_cls = 100
        if opt.lsl:
            vocab = lang_utils.load_vocab(opt.lang_dir)
            devocab = {v: k for k, v in vocab.items()}
        # The CUB dataset classes take the vocabulary (None without lsl).
        dataset_kwargs = {'vocab': vocab}
    else:
        raise NotImplementedError(opt.dataset)

    if opt.lsl and vocab is None:
        # Previously this surfaced much later as a NameError on ``vocab``.
        raise NotImplementedError(
            'lsl needs a vocabulary, which is only loaded for CUB_200_2011 '
            '(got dataset: {})'.format(opt.dataset))

    train_trans, test_trans = transforms_options[transform_key]
    train_loader = DataLoader(dataset_cls(args=opt,
                                          partition=train_partition,
                                          transform=train_trans,
                                          **dataset_kwargs),
                              batch_size=opt.batch_size, shuffle=True,
                              drop_last=True,
                              num_workers=opt.num_workers)
    val_loader = DataLoader(dataset_cls(args=opt,
                                        partition=val_partition,
                                        transform=test_trans,
                                        **dataset_kwargs),
                            batch_size=opt.batch_size // 2, shuffle=False,
                            drop_last=False,
                            num_workers=opt.num_workers // 2)
    meta_testloader = DataLoader(meta_cls(args=opt, partition='test',
                                          train_transform=train_trans,
                                          test_transform=test_trans,
                                          **dataset_kwargs),
                                 batch_size=opt.test_batch_size,
                                 shuffle=False, drop_last=False,
                                 num_workers=opt.num_workers)
    # NOTE(review): meta_valloader is built but not used further down in this
    # function; it is kept so dataset construction is unchanged.
    meta_valloader = DataLoader(meta_cls(args=opt, partition='val',
                                         train_transform=train_trans,
                                         test_transform=test_trans,
                                         **dataset_kwargs),
                                batch_size=opt.test_batch_size,
                                shuffle=False, drop_last=False,
                                num_workers=opt.num_workers)

    print('Amount training data: {}'.format(len(train_loader.dataset)))
    print('Amount val data:      {}'.format(len(val_loader.dataset)))

    # model
    model = create_model(opt.model, n_cls, opt.dataset)

    # optimizer
    if opt.adam:
        optimizer = torch.optim.Adam(model.parameters(),
                                     lr=opt.learning_rate,
                                     weight_decay=0.0005)
    else:
        optimizer = optim.SGD(model.parameters(),
                              lr=opt.learning_rate,
                              momentum=opt.momentum,
                              weight_decay=opt.weight_decay)

    criterion = nn.CrossEntropyLoss()

    # lsl
    lang_model = None
    if opt.lsl:
        # Optional GloVe initialisation of the embedding table.
        if opt.glove_init:
            vecs = lang_utils.glove_init(vocab, emb_size=opt.lang_emb_size)
        else:
            vecs = None
        embedding_model = nn.Embedding(
            len(vocab), opt.lang_emb_size, _weight=vecs
        )
        if opt.freeze_emb:
            embedding_model.weight.requires_grad = False

        lang_input_size = n_cls if opt.use_logit else 640 # 640 for resnet12
        lang_model = TextProposal(
            embedding_model,
            input_size=lang_input_size,
            hidden_size=opt.lang_hidden_size,
            project_input=lang_input_size != opt.lang_hidden_size,
            rnn=opt.rnn_type,
            num_layers=opt.rnn_num_layers,
            dropout=opt.rnn_dropout,
            vocab=vocab,
            **lang_utils.get_special_indices(vocab)
        )

    if torch.cuda.is_available():
        if opt.n_gpu > 1:
            model = nn.DataParallel(model)
        model = model.cuda()
        criterion = criterion.cuda()
        cudnn.benchmark = True
        if opt.lsl:
            embedding_model = embedding_model.cuda()
            lang_model = lang_model.cuda()

    # tensorboard
    #logger = tb_logger.Logger(logdir=opt.tb_folder, flush_secs=2)

    # set cosine annealing scheduler
    if opt.cosine:
        eta_min = opt.learning_rate * (opt.lr_decay_rate ** 3)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, opt.epochs, eta_min, -1)

    # routine: supervised pre-training
    best_val_acc = 0
    for epoch in range(1, opt.epochs + 1):

        # NOTE(review): the scheduler is stepped before this epoch's
        # optimizer updates; recent PyTorch documents the opposite order.
        # Left unchanged to keep the learning-rate schedule identical -
        # confirm against the PyTorch version this script targets.
        if opt.cosine:
            scheduler.step()
        else:
            adjust_learning_rate(epoch, opt, optimizer)
        print("==> training...")

        time1 = time.time()
        # devocab is None unless lsl is enabled (see dataset setup above).
        train_acc, train_loss, train_lang_loss = train(
            epoch, train_loader, model, criterion, optimizer, opt, lang_model,
            devocab=devocab
        )
        time2 = time.time()
        print('epoch {}, total time {:.2f}'.format(epoch, time2 - time1))

        print("==> validating...")
        test_acc, test_acc_top5, test_loss, test_lang_loss = validate(
            val_loader, model, criterion, opt, lang_model
        )

        # wandb
        log_metrics = {
            'train_acc':     train_acc,
            'train_loss':    train_loss,
            'val_acc':      test_acc,
            'val_acc_top5': test_acc_top5,
            'val_loss':     test_loss
        }
        if opt.lsl:
            log_metrics['train_lang_loss'] = train_lang_loss
            log_metrics['val_lang_loss']  = test_lang_loss
        wandb.log(log_metrics, step=epoch)

        # # regular saving
        # if epoch % opt.save_freq == 0 and not opt.dryrun:
        #     print('==> Saving...')
        #     state = {
        #         'epoch': epoch,
        #         'model': model.state_dict() if opt.n_gpu <= 1 else model.module.state_dict(),
        #     }
        #     save_file = os.path.join(opt.save_folder, 'ckpt_epoch_{epoch}.pth'.format(epoch=epoch))
        #     torch.save(state, save_file)

        if test_acc > best_val_acc:
            # Remember the new best so later, worse epochs cannot overwrite
            # the summary.
            best_val_acc = test_acc
            wandb.run.summary['best_val_acc'] = test_acc
            wandb.run.summary['best_val_acc_epoch'] = epoch

    # save the last model
    # Unwrap DataParallel based on what the model actually is: wrapping only
    # happens when CUDA is available, regardless of opt.n_gpu.
    if isinstance(model, nn.DataParallel):
        model_state = model.module.state_dict()
    else:
        model_state = model.state_dict()
    state = {
        'opt': opt,
        'model': model_state,
    }
    save_file = os.path.join(wandb.run.dir, '{}_last.pth'.format(opt.model))
    torch.save(state, save_file)

    # evaluate on test set
    print("==> testing...")
    start = time.time()
    (test_acc, test_std), (test_acc5, test_std5) = meta_test(model, meta_testloader)
    test_time = time.time() - start
    print('Using logit layer for embedding')
    print('test_acc: {:.4f}, test_std: {:.4f}, time: {:.1f}'.format(test_acc, test_std, test_time))
    print('test_acc top 5: {:.4f}, test_std top 5: {:.4f}, time: {:.1f}'.format(test_acc5, test_std5, test_time))

    start = time.time()
    (test_acc_feat, test_std_feat), (test_acc5_feat, test_std5_feat)  = meta_test(model, meta_testloader,
                                                                              use_logit=False)
    test_time = time.time() - start
    print('Using layer before logits for embedding')
    print('test_acc_feat: {:.4f}, test_std: {:.4f}, time: {:.1f}'.format(
        test_acc_feat, test_std_feat, test_time))
    print('test_acc_feat top 5: {:.4f}, test_std top 5: {:.4f}, time: {:.1f}'.format(
        test_acc5_feat, test_std5_feat, test_time))

    wandb.run.summary['test_acc'] = test_acc
    wandb.run.summary['test_std'] = test_std
    wandb.run.summary['test_acc5'] = test_acc5
    wandb.run.summary['test_std5'] = test_std5
    wandb.run.summary['test_acc_feat'] = test_acc_feat
    wandb.run.summary['test_std_feat'] = test_std_feat
    wandb.run.summary['test_acc5_feat'] = test_acc5_feat
    wandb.run.summary['test_std5_feat'] = test_std5_feat
def get_dataloaders(opt):
    """Build the data loaders for ``opt.dataset``.

    Args:
        opt: parsed options. Reads ``dataset``, ``use_trainval``,
            ``transform``, ``batch_size``, ``test_batch_size`` and
            ``num_workers``; the whole object is also forwarded to the
            dataset constructors as ``args``.

    Returns:
        A 6-tuple ``(train_loader, val_loader, meta_testloader,
        meta_valloader, n_cls, no_sample)`` where ``n_cls`` is the number of
        training classes and ``no_sample`` is the length of the training
        dataset.

    Raises:
        NotImplementedError: if ``opt.dataset`` is not one of
            ``'miniImageNet'``, ``'tieredImageNet'``, ``'CIFAR-FS'`` or
            ``'FC100'``.
    """
    # dataloader
    train_partition = 'trainval' if opt.use_trainval else 'train'

    # Every supported dataset is wired up with the same recipe; only the
    # dataset classes, the training transform key, the validation partition
    # and the class count differ. Class names are resolved inside the
    # selected branch only.
    build_meta_train = False
    if opt.dataset == 'miniImageNet':
        dataset_cls, meta_cls = ImageNet, MetaImageNet
        transform_key, val_partition = opt.transform, 'val'
        n_cls = 80 if opt.use_trainval else 64
    elif opt.dataset == 'tieredImageNet':
        dataset_cls, meta_cls = TieredImageNet, MetaTieredImageNet
        transform_key, val_partition = opt.transform, 'train_phase_val'
        n_cls = 448 if opt.use_trainval else 351
    elif opt.dataset in ('CIFAR-FS', 'FC100'):
        dataset_cls, meta_cls = CIFAR100, MetaCIFAR100
        # These two datasets always train with transform 'D' and validate on
        # the 'train' partition.
        transform_key, val_partition = 'D', 'train'
        build_meta_train = True
        if opt.use_trainval:
            n_cls = 80
        else:
            n_cls = 64 if opt.dataset == 'CIFAR-FS' else 60
    else:
        raise NotImplementedError(opt.dataset)

    # Supervised loaders use the training-time transforms.
    train_trans, test_trans = transforms_options[transform_key]
    train_loader = DataLoader(dataset_cls(args=opt,
                                          partition=train_partition,
                                          transform=train_trans),
                              batch_size=opt.batch_size,
                              shuffle=True,
                              drop_last=True,
                              num_workers=opt.num_workers)
    val_loader = DataLoader(dataset_cls(args=opt,
                                        partition=val_partition,
                                        transform=test_trans),
                            batch_size=opt.batch_size // 2,
                            shuffle=False,
                            drop_last=False,
                            num_workers=opt.num_workers // 2)

    # Meta (episodic) loaders take their transforms from
    # transforms_test_options, keyed by opt.transform for every dataset.
    train_trans, test_trans = transforms_test_options[opt.transform]

    if build_meta_train:
        # NOTE(review): this loader is neither returned nor used in this
        # function; it is kept so dataset construction is unchanged -
        # consider removing it.
        meta_trainloader = DataLoader(meta_cls(args=opt,
                                               partition='train',
                                               train_transform=train_trans,
                                               test_transform=test_trans),
                                      batch_size=1,
                                      shuffle=True,
                                      drop_last=False,
                                      num_workers=opt.num_workers)

    meta_testloader = DataLoader(meta_cls(args=opt,
                                          partition='test',
                                          train_transform=train_trans,
                                          test_transform=test_trans),
                                 batch_size=opt.test_batch_size,
                                 shuffle=False,
                                 drop_last=False,
                                 num_workers=opt.num_workers)
    meta_valloader = DataLoader(meta_cls(args=opt,
                                         partition='val',
                                         train_transform=train_trans,
                                         test_transform=test_trans),
                                batch_size=opt.test_batch_size,
                                shuffle=False,
                                drop_last=False,
                                num_workers=opt.num_workers)

    # Reuse the training dataset that is already loaded instead of building a
    # second copy just to count its samples (assumes the length does not
    # depend on the transform).
    no_sample = len(train_loader.dataset)

    return train_loader, val_loader, meta_testloader, meta_valloader, n_cls, no_sample
예제 #9
0
def main():
    """Run supervised pre-training and write checkpoints.

    Options come from ``parse_option()``. The routine builds the training and
    validation loaders for ``opt.dataset``, creates the model via
    ``create_model``, picks Adam or SGD, and trains for ``opt.epochs`` epochs
    with cross-entropy loss. Metrics go to ``tb_logger`` and to the Comet
    ``Experiment`` (disabled unless ``opt.logcomet`` is set). A checkpoint is
    written every ``opt.save_freq`` epochs, and ``<opt.model>_last.pth`` is
    written once training finishes.

    Raises:
        NotImplementedError: if ``opt.dataset`` is not one of
            ``miniImageNet``, ``tieredImageNet``, ``CIFAR-FS`` or ``FC100``.
    """
    opt = parse_option()

    # dataloader
    train_partition = "trainval" if opt.use_trainval else "train"
    if opt.dataset == "miniImageNet":
        train_trans, test_trans = transforms_options[opt.transform]
        train_loader = DataLoader(
            ImageNet(args=opt, partition=train_partition, transform=train_trans),
            batch_size=opt.batch_size,
            shuffle=True,
            drop_last=True,
            num_workers=opt.num_workers,
        )
        val_loader = DataLoader(
            ImageNet(args=opt, partition="val", transform=test_trans),
            batch_size=opt.batch_size // 2,
            shuffle=False,
            drop_last=False,
            num_workers=opt.num_workers // 2,
        )
        # No episodic (meta) loaders are built for miniImageNet; nothing below
        # in this routine reads them.
        n_cls = 80 if opt.use_trainval else 64
    elif opt.dataset == "tieredImageNet":
        train_trans, test_trans = transforms_options[opt.transform]
        train_loader = DataLoader(
            TieredImageNet(args=opt, partition=train_partition, transform=train_trans),
            batch_size=opt.batch_size,
            shuffle=True,
            drop_last=True,
            num_workers=opt.num_workers,
        )
        val_loader = DataLoader(
            TieredImageNet(args=opt, partition="train_phase_val", transform=test_trans),
            batch_size=opt.batch_size // 2,
            shuffle=False,
            drop_last=False,
            num_workers=opt.num_workers // 2,
        )
        # NOTE(review): the two meta loaders below are constructed but never
        # read by this routine; kept so dataset construction is unchanged.
        meta_testloader = DataLoader(
            MetaTieredImageNet(
                args=opt,
                partition="test",
                train_transform=train_trans,
                test_transform=test_trans,
            ),
            batch_size=opt.test_batch_size,
            shuffle=False,
            drop_last=False,
            num_workers=opt.num_workers,
        )
        meta_valloader = DataLoader(
            MetaTieredImageNet(
                args=opt,
                partition="val",
                train_transform=train_trans,
                test_transform=test_trans,
            ),
            batch_size=opt.test_batch_size,
            shuffle=False,
            drop_last=False,
            num_workers=opt.num_workers,
        )
        n_cls = 448 if opt.use_trainval else 351
    elif opt.dataset == "CIFAR-FS" or opt.dataset == "FC100":
        # The CIFAR-derived datasets always use transform set "D",
        # independent of opt.transform.
        train_trans, test_trans = transforms_options["D"]

        train_loader = DataLoader(
            CIFAR100(args=opt, partition=train_partition, transform=train_trans),
            batch_size=opt.batch_size,
            shuffle=True,
            drop_last=True,
            num_workers=opt.num_workers,
        )
        # NOTE(review): validation reads the "train" partition (with the
        # test-time transform) -- confirm this is intended.
        val_loader = DataLoader(
            CIFAR100(args=opt, partition="train", transform=test_trans),
            batch_size=opt.batch_size // 2,
            shuffle=False,
            drop_last=False,
            num_workers=opt.num_workers // 2,
        )
        # NOTE(review): the two meta loaders below are constructed but never
        # read by this routine; kept so dataset construction is unchanged.
        meta_testloader = DataLoader(
            MetaCIFAR100(
                args=opt,
                partition="test",
                train_transform=train_trans,
                test_transform=test_trans,
            ),
            batch_size=opt.test_batch_size,
            shuffle=False,
            drop_last=False,
            num_workers=opt.num_workers,
        )
        meta_valloader = DataLoader(
            MetaCIFAR100(
                args=opt,
                partition="val",
                train_transform=train_trans,
                test_transform=test_trans,
            ),
            batch_size=opt.test_batch_size,
            shuffle=False,
            drop_last=False,
            num_workers=opt.num_workers,
        )
        if opt.use_trainval:
            n_cls = 80
        else:
            if opt.dataset == "CIFAR-FS":
                n_cls = 64
            elif opt.dataset == "FC100":
                n_cls = 60
            else:
                raise NotImplementedError(
                    "dataset not supported: {}".format(opt.dataset)
                )
    else:
        raise NotImplementedError(opt.dataset)

    # model
    model = create_model(opt.model, n_cls, opt.dataset, opt.drop_rate, opt.dropblock)

    # optimizer
    if opt.adam:
        optimizer = torch.optim.Adam(
            model.parameters(), lr=opt.learning_rate, weight_decay=0.0005
        )
    else:
        optimizer = optim.SGD(
            model.parameters(),
            lr=opt.learning_rate,
            momentum=opt.momentum,
            weight_decay=opt.weight_decay,
        )

    criterion = nn.CrossEntropyLoss()

    if torch.cuda.is_available():
        if opt.n_gpu > 1:
            model = nn.DataParallel(model)
        model = model.cuda()
        criterion = criterion.cuda()
        cudnn.benchmark = True

    # Checkpoints always store the un-wrapped module's weights. Decide by
    # inspecting the model rather than opt.n_gpu: the DataParallel wrapper is
    # only applied when CUDA is available, so on a CPU-only host with
    # n_gpu > 1 there is no `.module` attribute to read.
    bare_model = model.module if isinstance(model, nn.DataParallel) else model

    # tensorboard
    logger = tb_logger.Logger(logdir=opt.tb_folder, flush_secs=2)
    comet_logger = Experiment(
        # .get() yields None when the variable is unset, so a run with Comet
        # logging disabled does not die on a KeyError here.
        api_key=os.environ.get("COMET_API_KEY"),
        project_name=opt.comet_project_name,
        workspace=opt.comet_workspace,
        disabled=not opt.logcomet,
    )
    comet_logger.set_name(opt.model_name)
    comet_logger.log_parameters(vars(opt))

    # set cosine annealing scheduler
    if opt.cosine:
        eta_min = opt.learning_rate * (opt.lr_decay_rate ** opt.cosine_factor)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, opt.epochs, eta_min, -1
        )

    # routine: supervised pre-training
    for epoch in range(1, opt.epochs + 1):

        # The manual schedule is keyed on the epoch number, so it is applied
        # before training; the cosine scheduler is stepped at the end of the
        # epoch instead (see below).
        if not opt.cosine:
            adjust_learning_rate(epoch, opt, optimizer)
        print("==> training...")

        time1 = time.time()
        with comet_logger.train():
            train_acc, train_loss = train(
                epoch, train_loader, model, criterion, optimizer, opt
            )
            comet_logger.log_metrics(
                {"acc": train_acc.cpu(), "loss_epoch": train_loss}, epoch=epoch
            )
        time2 = time.time()
        print("epoch {}, total time {:.2f}".format(epoch, time2 - time1))

        logger.log_value("train_acc", train_acc, epoch)
        logger.log_value("train_loss", train_loss, epoch)

        with comet_logger.validate():
            test_acc, test_acc_top5, test_loss = validate(
                val_loader, model, criterion, opt
            )
            comet_logger.log_metrics(
                {"acc": test_acc.cpu(), "acc_top5": test_acc_top5.cpu(), "loss": test_loss,},
                epoch=epoch,
            )

        logger.log_value("test_acc", test_acc, epoch)
        logger.log_value("test_acc_top5", test_acc_top5, epoch)
        logger.log_value("test_loss", test_loss, epoch)

        # regular saving
        if epoch % opt.save_freq == 0:
            print("==> Saving...")
            state = {
                "epoch": epoch,
                "model": bare_model.state_dict(),
            }
            save_file = os.path.join(
                opt.save_folder, "ckpt_epoch_{epoch}.pth".format(epoch=epoch)
            )
            torch.save(state, save_file)

        # Step the scheduler only after this epoch's optimizer updates.
        # Stepping before them skips the first value of the schedule on
        # PyTorch >= 1.1.
        if opt.cosine:
            scheduler.step()

    # save the last model
    state = {
        "opt": opt,
        "model": bare_model.state_dict(),
    }
    save_file = os.path.join(opt.save_folder, "{}_last.pth".format(opt.model))
    torch.save(state, save_file)
예제 #10
0
파일: dataloader.py 프로젝트: yyht/SKD
def get_dataloaders(opt):
    """Build the supervised and episodic data loaders for ``opt.dataset``.

    Args:
        opt: option object; this function reads ``dataset``, ``use_trainval``,
            ``transform``, ``batch_size``, ``test_batch_size`` and
            ``num_workers`` from it and forwards it to every dataset
            constructor as ``args``.

    Returns:
        The tuple ``(train_loader, val_loader, meta_testloader,
        meta_valloader, n_cls)``. For the ``'toy'`` dataset the two
        meta-loader slots hold the integer ``5`` rather than loaders.

    Raises:
        NotImplementedError: if ``opt.dataset`` is not handled.
    """
    split = 'trainval' if opt.use_trainval else 'train'
    name = opt.dataset

    def supervised(dataset_cls, eval_partition, transform_pair):
        # Shuffled training loader on `split`, plus an ordered evaluation
        # loader that uses half the batch size and half the workers.
        augment, plain = transform_pair
        fit_set = dataset_cls(args=opt, partition=split, transform=augment)
        fit_loader = DataLoader(fit_set,
                                batch_size=opt.batch_size,
                                shuffle=True,
                                drop_last=True,
                                num_workers=opt.num_workers)
        eval_set = dataset_cls(args=opt,
                               partition=eval_partition,
                               transform=plain)
        eval_loader = DataLoader(eval_set,
                                 batch_size=opt.batch_size // 2,
                                 shuffle=False,
                                 drop_last=False,
                                 num_workers=opt.num_workers // 2)
        return fit_loader, eval_loader

    def episodic(meta_cls, partition, transform_pair, batch_size,
                 shuffle=False):
        # Loader over one partition of a meta (episodic) dataset.
        augment, plain = transform_pair
        episodes = meta_cls(args=opt,
                            partition=partition,
                            train_transform=augment,
                            test_transform=plain)
        return DataLoader(episodes,
                          batch_size=batch_size,
                          shuffle=shuffle,
                          drop_last=False,
                          num_workers=opt.num_workers)

    if name == 'toy':
        train_loader, val_loader = supervised(CIFAR100_toy, 'train',
                                              transforms_options['D'])
        # No episodic loaders exist for the toy setup: the literal 5 fills
        # both meta-loader slots, and n_cls is 5 as well.
        n_cls = 5
        return train_loader, val_loader, 5, 5, n_cls

    if name == 'miniImageNet':
        train_loader, val_loader = supervised(
            ImageNet, 'val', transforms_options[opt.transform])

        meta_transforms = transforms_test_options[opt.transform]
        meta_testloader = episodic(MetaImageNet, 'test', meta_transforms,
                                   opt.test_batch_size)
        meta_valloader = episodic(MetaImageNet, 'val', meta_transforms,
                                  opt.test_batch_size)

        n_cls = 80 if opt.use_trainval else 64
    elif name == 'tieredImageNet':
        train_loader, val_loader = supervised(
            TieredImageNet, 'train_phase_val',
            transforms_options[opt.transform])

        meta_transforms = transforms_test_options[opt.transform]
        meta_testloader = episodic(MetaTieredImageNet, 'test',
                                   meta_transforms, opt.test_batch_size)
        meta_valloader = episodic(MetaTieredImageNet, 'val', meta_transforms,
                                  opt.test_batch_size)

        n_cls = 448 if opt.use_trainval else 351
    elif name in ('CIFAR-FS', 'FC100'):
        # The supervised loaders always use transform set 'D' here; only the
        # episodic loaders honour opt.transform.
        train_loader, val_loader = supervised(CIFAR100, 'train',
                                              transforms_options['D'])

        meta_transforms = transforms_test_options[opt.transform]
        # NOTE(review): this episodic training loader is constructed but is
        # not part of the returned tuple.
        meta_trainloader = episodic(MetaCIFAR100, 'train', meta_transforms,
                                    1, shuffle=True)
        meta_testloader = episodic(MetaCIFAR100, 'test', meta_transforms,
                                   opt.test_batch_size)
        meta_valloader = episodic(MetaCIFAR100, 'val', meta_transforms,
                                  opt.test_batch_size)

        if opt.use_trainval:
            n_cls = 80
        elif name == 'CIFAR-FS':
            n_cls = 64
        elif name == 'FC100':
            n_cls = 60
        else:
            raise NotImplementedError('dataset not supported: {}'.format(
                opt.dataset))
    else:
        raise NotImplementedError(opt.dataset)

    return train_loader, val_loader, meta_testloader, meta_valloader, n_cls