Example #1
    def __init__(self, opt):
        self.opt = opt
        self.n_cls = 64
        self.train_trans, self.test_trans = self.transforms_options()
        self.train_loader = DataLoader(ImageNet(args=self.opt, partition='train', transform=self.train_trans),
                                       batch_size=Config.batch_size, shuffle=True, drop_last=True,
                                       num_workers=Config.num_workers)
        self.val_loader = DataLoader(ImageNet(args=self.opt, partition='val', transform=self.test_trans),
                                     batch_size=Config.batch_size // 2, shuffle=False, drop_last=False,
                                     num_workers=Config.num_workers // 2)

        self.meta_trainloader = DataLoader(MetaImageNet(args=self.opt, partition='train_phase_test',
                                                        train_transform=self.train_trans,
                                                        test_transform=self.test_trans, fix_seed=False),
                                           batch_size=self.opt.test_batch_size, shuffle=False, drop_last=False,
                                           num_workers=Config.num_workers)
        self.meta_valloader = DataLoader(MetaImageNet(args=self.opt, partition='val', train_transform=self.train_trans,
                                                      test_transform=self.test_trans, fix_seed=False),
                                         batch_size=self.opt.test_batch_size, shuffle=False, drop_last=False,
                                         num_workers=Config.num_workers)
        self.meta_testloader = DataLoader(MetaImageNet(args=self.opt, partition='test', train_transform=self.train_trans,
                                                       test_transform=self.test_trans, fix_seed=False),
                                          batch_size=self.opt.test_batch_size, shuffle=False, drop_last=False,
                                          num_workers=Config.num_workers)

        # model
        self.model = resnet12(avg_pool=True, drop_rate=0.1, dropblock_size=5, num_classes=self.n_cls).cuda()

        self.optimizer = optim.SGD(self.model.parameters(), lr=Config.learning_rate, momentum=0.9, weight_decay=5e-4)
        self.criterion = nn.CrossEntropyLoss().cuda()
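A minimal usage sketch for this constructor. The enclosing class is not shown above, so `Trainer` is a hypothetical name, and `parse_option()` is assumed to return the options read here (as it does in the later examples):

if __name__ == '__main__':
    # Hypothetical driver: `Trainer` stands in for the class whose __init__
    # appears above; it builds the loaders, a ResNet-12, SGD, and the loss.
    opt = parse_option()
    trainer = Trainer(opt)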
Example #2
def main():
    opt = parse_option()

    with open(f"{opt.tb_folder}/config.json", "w") as fo:
        fo.write(json.dumps(vars(opt), indent=4))

    # dataloader
    train_partition = 'trainval' if opt.use_trainval else 'train'
    if opt.dataset == 'miniImageNet':
        train_trans, test_trans = transforms_options[opt.transform]
        train_loader = DataLoader(ImageNet(args=opt,
                                           partition=train_partition,
                                           transform=train_trans),
                                  batch_size=opt.batch_size,
                                  shuffle=True,
                                  drop_last=True,
                                  num_workers=opt.num_workers)
        val_loader = DataLoader(ImageNet(args=opt,
                                         partition='val',
                                         transform=test_trans),
                                batch_size=opt.batch_size // 2,
                                shuffle=False,
                                drop_last=False,
                                num_workers=opt.num_workers // 2)
        meta_testloader = DataLoader(MetaImageNet(args=opt,
                                                  partition='test',
                                                  train_transform=train_trans,
                                                  test_transform=test_trans),
                                     batch_size=opt.test_batch_size,
                                     shuffle=False,
                                     drop_last=False,
                                     num_workers=opt.num_workers)
        meta_valloader = DataLoader(MetaImageNet(args=opt,
                                                 partition='val',
                                                 train_transform=train_trans,
                                                 test_transform=test_trans),
                                    batch_size=opt.test_batch_size,
                                    shuffle=False,
                                    drop_last=False,
                                    num_workers=opt.num_workers)
        if opt.use_trainval:
            n_cls = 80
        else:
            n_cls = 64
    elif opt.dataset == 'tieredImageNet':
        train_trans, test_trans = transforms_options[opt.transform]
        train_loader = DataLoader(TieredImageNet(args=opt,
                                                 partition=train_partition,
                                                 transform=train_trans),
                                  batch_size=opt.batch_size,
                                  shuffle=True,
                                  drop_last=True,
                                  num_workers=opt.num_workers)
        val_loader = DataLoader(TieredImageNet(args=opt,
                                               partition='train_phase_val',
                                               transform=test_trans),
                                batch_size=opt.batch_size // 2,
                                shuffle=False,
                                drop_last=False,
                                num_workers=opt.num_workers // 2)
        meta_testloader = DataLoader(MetaTieredImageNet(
            args=opt,
            partition='test',
            train_transform=train_trans,
            test_transform=test_trans),
                                     batch_size=opt.test_batch_size,
                                     shuffle=False,
                                     drop_last=False,
                                     num_workers=opt.num_workers)
        meta_valloader = DataLoader(MetaTieredImageNet(
            args=opt,
            partition='val',
            train_transform=train_trans,
            test_transform=test_trans),
                                    batch_size=opt.test_batch_size,
                                    shuffle=False,
                                    drop_last=False,
                                    num_workers=opt.num_workers)
        if opt.use_trainval:
            n_cls = 448
        else:
            n_cls = 351
    elif opt.dataset == 'CIFAR-FS' or opt.dataset == 'FC100':
        train_trans, test_trans = transforms_options['D']

        train_loader = DataLoader(CIFAR100(args=opt,
                                           partition=train_partition,
                                           transform=train_trans),
                                  batch_size=opt.batch_size,
                                  shuffle=True,
                                  drop_last=True,
                                  num_workers=opt.num_workers)
        val_loader = DataLoader(CIFAR100(args=opt,
                                         partition='train',
                                         transform=test_trans),
                                batch_size=opt.batch_size // 2,
                                shuffle=False,
                                drop_last=False,
                                num_workers=opt.num_workers // 2)
        meta_testloader = DataLoader(MetaCIFAR100(args=opt,
                                                  partition='test',
                                                  train_transform=train_trans,
                                                  test_transform=test_trans),
                                     batch_size=opt.test_batch_size,
                                     shuffle=False,
                                     drop_last=False,
                                     num_workers=opt.num_workers)
        meta_valloader = DataLoader(MetaCIFAR100(args=opt,
                                                 partition='val',
                                                 train_transform=train_trans,
                                                 test_transform=test_trans),
                                    batch_size=opt.test_batch_size,
                                    shuffle=False,
                                    drop_last=False,
                                    num_workers=opt.num_workers)
        if opt.use_trainval:
            n_cls = 80
        else:
            if opt.dataset == 'CIFAR-FS':
                n_cls = 64
            elif opt.dataset == 'FC100':
                n_cls = 60
            else:
                raise NotImplementedError('dataset not supported: {}'.format(
                    opt.dataset))
    elif opt.dataset == "imagenet":
        train_trans, test_trans = transforms_options["A"]
        train_dataset = ImagenetFolder(root=os.path.join(opt.data_root, "train"),
                                       transform=train_trans)
        val_dataset = ImagenetFolder(root=os.path.join(opt.data_root, "val"),
                                     transform=test_trans)
        train_loader = DataLoader(train_dataset,
                                  batch_size=opt.batch_size,
                                  shuffle=True,
                                  drop_last=True,
                                  num_workers=opt.num_workers)
        val_loader = DataLoader(val_dataset,
                                batch_size=opt.batch_size // 2,
                                shuffle=False,
                                drop_last=False,
                                num_workers=opt.num_workers // 2)
        n_cls = 1000
    else:
        raise NotImplementedError(opt.dataset)

    # model
    model = create_model(opt.model, n_cls, opt.dataset, use_srl=opt.srl)

    # optimizer
    if opt.adam:
        optimizer = torch.optim.Adam(model.parameters(),
                                     lr=opt.learning_rate,
                                     weight_decay=0.0005)
    else:
        optimizer = optim.SGD(model.parameters(),
                              lr=opt.learning_rate,
                              momentum=opt.momentum,
                              weight_decay=opt.weight_decay)

    if opt.label_smoothing:
        criterion = LabelSmoothing(smoothing=opt.smoothing_ratio)
    elif opt.gce:
        criterion = GuidedComplementEntropy(alpha=opt.gce_alpha, classes=n_cls)
    else:
        criterion = nn.CrossEntropyLoss()
    if opt.opl:
        auxiliary_loss = OrthogonalProjectionLoss(use_attention=True)
    elif opt.popl:
        auxiliary_loss = PerpetualOrthogonalProjectionLoss(feat_dim=640)
    else:
        auxiliary_loss = None

    if torch.cuda.is_available():
        if opt.n_gpu > 1:
            model = nn.DataParallel(model)
        model = model.cuda()
        criterion = criterion.cuda()
        if auxiliary_loss is not None:
            auxiliary_loss = auxiliary_loss.cuda()
        cudnn.benchmark = True

    # tensorboard
    logger = tb_logger.Logger(logdir=opt.tb_folder, flush_secs=2)

    # set cosine annealing scheduler
    if opt.cosine:
        eta_min = opt.learning_rate * (opt.lr_decay_rate**3)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, opt.epochs, eta_min, -1)
    else:
        scheduler = None

    # routine: supervised pre-training
    for epoch in range(1, opt.epochs + 1):

        if opt.cosine:
            scheduler.step()
        else:
            adjust_learning_rate(epoch, opt, optimizer)
        print("==> training...")

        time1 = time.time()
        if auxiliary_loss is not None:
            train_acc, train_loss, [train_cel, train_opl] = train(
                epoch=epoch,
                train_loader=train_loader,
                model=model,
                criterion=criterion,
                optimizer=optimizer,
                opt=opt,
                auxiliary=auxiliary_loss)
        else:
            train_acc, train_loss = train(epoch=epoch,
                                          train_loader=train_loader,
                                          model=model,
                                          criterion=criterion,
                                          optimizer=optimizer,
                                          opt=opt)

        time2 = time.time()
        print('epoch {}, total time {:.2f}'.format(epoch, time2 - time1))

        logger.log_value('accuracy/train_acc', train_acc, epoch)
        logger.log_value('train_losses/loss', train_loss, epoch)
        if auxiliary_loss is not None:
            logger.log_value('train_losses/cel', train_cel, epoch)
            logger.log_value('train_losses/opl', train_opl, epoch)
        else:
            logger.log_value('train_losses/cel', train_loss, epoch)

        if auxiliary_loss is not None:
            test_acc, test_acc_top5, test_loss, [test_cel, test_opl] = \
                validate(val_loader, model, criterion, opt, auxiliary=auxiliary_loss)
        else:
            test_acc, test_acc_top5, test_loss = validate(
                val_loader, model, criterion, opt)

        logger.log_value('accuracy/test_acc', test_acc, epoch)
        logger.log_value('accuracy/test_acc_top5', test_acc_top5, epoch)
        logger.log_value('test_losses/loss', test_loss, epoch)
        if auxiliary_loss is not None:
            logger.log_value('test_losses/cel', test_cel, epoch)
            logger.log_value('test_losses/opl', test_opl, epoch)
        else:
            logger.log_value('test_losses/cel', test_loss, epoch)

        # regular saving
        if epoch % opt.save_freq == 0:
            print('==> Saving...')
            state = {
                'epoch': epoch,
                'model': model.state_dict() if opt.n_gpu <= 1 else model.module.state_dict(),
            }
            save_file = os.path.join(
                opt.save_folder, 'ckpt_epoch_{epoch}.pth'.format(epoch=epoch))
            torch.save(state, save_file)

    # save the last model
    state = {
        'opt': opt,
        'model': model.state_dict() if opt.n_gpu <= 1 else model.module.state_dict(),
    }
    save_file = os.path.join(opt.save_folder, '{}_last.pth'.format(opt.model))
    torch.save(state, save_file)
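`LabelSmoothing`, used above when `opt.label_smoothing` is set, is not defined in these examples. A common implementation, given as a sketch rather than the authors' exact code, mixes the one-hot target with a uniform distribution before the negative log-likelihood:

import torch.nn as nn
import torch.nn.functional as F

class LabelSmoothing(nn.Module):
    """NLL loss with label smoothing (sketch of a common implementation)."""
    def __init__(self, smoothing=0.1):
        super().__init__()
        self.confidence = 1.0 - smoothing
        self.smoothing = smoothing

    def forward(self, x, target):
        logprobs = F.log_softmax(x, dim=-1)
        # Loss on the true class, plus a uniform term over all classes.
        nll_loss = -logprobs.gather(dim=-1, index=target.unsqueeze(1)).squeeze(1)
        smooth_loss = -logprobs.mean(dim=-1)
        return (self.confidence * nll_loss + self.smoothing * smooth_loss).mean()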
Example #3
def main():

    opt = parse_option()

    # dataloader
    train_partition = 'trainval' if opt.use_trainval else 'train'
    if opt.dataset == 'miniImageNet':
        train_trans, test_trans = transforms_options[opt.transform]
        train_loader = DataLoader(ImageNet(args=opt,
                                           partition=train_partition,
                                           transform=train_trans),
                                  batch_size=opt.batch_size,
                                  shuffle=True,
                                  drop_last=True,
                                  num_workers=opt.num_workers)
        val_loader = DataLoader(ImageNet(args=opt,
                                         partition='val',
                                         transform=test_trans),
                                batch_size=opt.batch_size // 2,
                                shuffle=False,
                                drop_last=False,
                                num_workers=opt.num_workers // 2)
        meta_testloader = DataLoader(MetaImageNet(args=opt,
                                                  partition='test',
                                                  train_transform=train_trans,
                                                  test_transform=test_trans),
                                     batch_size=opt.test_batch_size,
                                     shuffle=False,
                                     drop_last=False,
                                     num_workers=opt.num_workers)
        meta_valloader = DataLoader(MetaImageNet(args=opt,
                                                 partition='val',
                                                 train_transform=train_trans,
                                                 test_transform=test_trans),
                                    batch_size=opt.test_batch_size,
                                    shuffle=False,
                                    drop_last=False,
                                    num_workers=opt.num_workers)
        if opt.use_trainval:
            n_cls = 80
        else:
            n_cls = 64
    elif opt.dataset == 'tieredImageNet':
        train_trans, test_trans = transforms_options[opt.transform]
        train_loader = DataLoader(TieredImageNet(args=opt,
                                                 partition=train_partition,
                                                 transform=train_trans),
                                  batch_size=opt.batch_size,
                                  shuffle=True,
                                  drop_last=True,
                                  num_workers=opt.num_workers)
        val_loader = DataLoader(TieredImageNet(args=opt,
                                               partition='train_phase_val',
                                               transform=test_trans),
                                batch_size=opt.batch_size // 2,
                                shuffle=False,
                                drop_last=False,
                                num_workers=opt.num_workers // 2)
        meta_testloader = DataLoader(MetaTieredImageNet(
            args=opt,
            partition='test',
            train_transform=train_trans,
            test_transform=test_trans),
                                     batch_size=opt.test_batch_size,
                                     shuffle=False,
                                     drop_last=False,
                                     num_workers=opt.num_workers)
        meta_valloader = DataLoader(MetaTieredImageNet(
            args=opt,
            partition='val',
            train_transform=train_trans,
            test_transform=test_trans),
                                    batch_size=opt.test_batch_size,
                                    shuffle=False,
                                    drop_last=False,
                                    num_workers=opt.num_workers)
        if opt.use_trainval:
            n_cls = 448
        else:
            n_cls = 351
    elif opt.dataset == 'CIFAR-FS' or opt.dataset == 'FC100':
        train_trans, test_trans = transforms_options['D']

        train_loader = DataLoader(CIFAR100(args=opt,
                                           partition=train_partition,
                                           transform=train_trans),
                                  batch_size=opt.batch_size,
                                  shuffle=True,
                                  drop_last=True,
                                  num_workers=opt.num_workers)
        val_loader = DataLoader(CIFAR100(args=opt,
                                         partition='train',
                                         transform=test_trans),
                                batch_size=opt.batch_size // 2,
                                shuffle=False,
                                drop_last=False,
                                num_workers=opt.num_workers // 2)
        meta_testloader = DataLoader(MetaCIFAR100(args=opt,
                                                  partition='test',
                                                  train_transform=train_trans,
                                                  test_transform=test_trans),
                                     batch_size=opt.test_batch_size,
                                     shuffle=False,
                                     drop_last=False,
                                     num_workers=opt.num_workers)
        meta_valloader = DataLoader(MetaCIFAR100(args=opt,
                                                 partition='val',
                                                 train_transform=train_trans,
                                                 test_transform=test_trans),
                                    batch_size=opt.test_batch_size,
                                    shuffle=False,
                                    drop_last=False,
                                    num_workers=opt.num_workers)
        if opt.use_trainval:
            n_cls = 80
        else:
            if opt.dataset == 'CIFAR-FS':
                n_cls = 64
            elif opt.dataset == 'FC100':
                n_cls = 60
            else:
                raise NotImplementedError('dataset not supported: {}'.format(
                    opt.dataset))
    else:
        raise NotImplementedError(opt.dataset)

    # model
    if not opt.load_latest:
        model = create_model(opt.model, n_cls, opt.dataset)
    else:
        latest_file = os.path.join(opt.save_folder, 'latest.pth')
        model = load_teacher(latest_file, n_cls, opt.dataset)

    # optimizer
    if opt.adam:
        optimizer = torch.optim.Adam(model.parameters(),
                                     lr=opt.learning_rate,
                                     weight_decay=0.0005)
    else:
        optimizer = optim.SGD(model.parameters(),
                              lr=opt.learning_rate,
                              momentum=opt.momentum,
                              weight_decay=opt.weight_decay)

    criterion = nn.CrossEntropyLoss()

    if torch.cuda.is_available():
        if opt.n_gpu > 1:
            model = nn.DataParallel(model)
        model = model.cuda()
        criterion = criterion.cuda()
        cudnn.benchmark = True

    # tensorboard
    logger = tb_logger.Logger(logdir=opt.tb_folder, flush_secs=2)

    # set cosine annealing scheduler
    if opt.cosine:
        eta_min = opt.learning_rate * (opt.lr_decay_rate**3)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, opt.epochs, eta_min, -1)

    # routine: supervised pre-training
    for epoch in range(1, opt.epochs + 1):

        if opt.cosine:
            scheduler.step()
        else:
            adjust_learning_rate(epoch, opt, optimizer)
        print("==> training...")

        time1 = time.time()
        train_acc, train_loss = train(epoch, train_loader, model, criterion,
                                      optimizer, opt)
        time2 = time.time()
        print('epoch {}, total time {:.2f}'.format(epoch, time2 - time1))

        logger.log_value('train_acc', train_acc, epoch)
        logger.log_value('train_loss', train_loss, epoch)

        test_acc, test_acc_top5, test_loss = validate(val_loader, model,
                                                      criterion, opt)

        logger.log_value('test_acc', test_acc, epoch)
        logger.log_value('test_acc_top5', test_acc_top5, epoch)
        logger.log_value('test_loss', test_loss, epoch)

        # regular saving
        if epoch % opt.save_freq == 0:
            print('==> Saving...')
            state = {
                'epoch': epoch,
                'model': model.state_dict() if opt.n_gpu <= 1 else model.module.state_dict(),
            }
            save_file = os.path.join(
                opt.save_folder, 'ckpt_epoch_{epoch}.pth'.format(epoch=epoch))
            torch.save(state, save_file)
            latest_file = os.path.join(opt.save_folder, 'latest.pth')
            # os.symlink raises FileExistsError if the link already exists,
            # so remove any stale link from a previous save first.
            if os.path.lexists(latest_file):
                os.remove(latest_file)
            os.symlink(save_file, latest_file)

    # save the last model
    state = {
        'opt': opt,
        'model': model.state_dict() if opt.n_gpu <= 1 else model.module.state_dict(),
    }
    save_file = os.path.join(opt.save_folder, '{}_last.pth'.format(opt.model))
    torch.save(state, save_file)
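`adjust_learning_rate`, called in the non-cosine branch of the training loop, is not shown in these examples. A sketch of the usual step-decay schedule, assuming an `opt.lr_decay_epochs` list of milestone epochs alongside the `opt.lr_decay_rate` already used above:

import numpy as np

def adjust_learning_rate(epoch, opt, optimizer):
    """Scale the base lr by lr_decay_rate once per milestone passed (sketch)."""
    steps = np.sum(epoch > np.asarray(opt.lr_decay_epochs))
    if steps > 0:
        new_lr = opt.learning_rate * (opt.lr_decay_rate ** steps)
        for param_group in optimizer.param_groups:
            param_group['lr'] = new_lr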
Example #4
def main():
    best_acc = 0

    opt = parse_option()

    # tensorboard logger
    logger = tb_logger.Logger(logdir=opt.tb_folder, flush_secs=2)

    # dataloader
    train_partition = 'trainval' if opt.use_trainval else 'train'
    if opt.dataset == 'miniImageNet':
        train_trans, test_trans = transforms_options[opt.transform]
        if opt.distill in ['contrast']:
            train_set = ImageNet(args=opt,
                                 partition=train_partition,
                                 transform=train_trans,
                                 is_sample=True,
                                 k=opt.nce_k)
        else:
            train_set = ImageNet(args=opt,
                                 partition=train_partition,
                                 transform=train_trans)
        n_data = len(train_set)
        train_loader = DataLoader(train_set,
                                  batch_size=opt.batch_size,
                                  shuffle=True,
                                  drop_last=True,
                                  num_workers=opt.num_workers)
        val_loader = DataLoader(ImageNet(args=opt,
                                         partition='val',
                                         transform=test_trans),
                                batch_size=opt.batch_size // 2,
                                shuffle=False,
                                drop_last=False,
                                num_workers=opt.num_workers // 2)
        meta_testloader = DataLoader(MetaImageNet(args=opt,
                                                  partition='test',
                                                  train_transform=train_trans,
                                                  test_transform=test_trans),
                                     batch_size=opt.test_batch_size,
                                     shuffle=False,
                                     drop_last=False,
                                     num_workers=opt.num_workers)
        meta_valloader = DataLoader(MetaImageNet(args=opt,
                                                 partition='val',
                                                 train_transform=train_trans,
                                                 test_transform=test_trans),
                                    batch_size=opt.test_batch_size,
                                    shuffle=False,
                                    drop_last=False,
                                    num_workers=opt.num_workers)
        if opt.use_trainval:
            n_cls = 80
        else:
            n_cls = 64
    elif opt.dataset == 'tieredImageNet':
        train_trans, test_trans = transforms_options[opt.transform]
        if opt.distill in ['contrast']:
            train_set = TieredImageNet(args=opt,
                                       partition=train_partition,
                                       transform=train_trans,
                                       is_sample=True,
                                       k=opt.nce_k)
        else:
            train_set = TieredImageNet(args=opt,
                                       partition=train_partition,
                                       transform=train_trans)
        n_data = len(train_set)
        train_loader = DataLoader(train_set,
                                  batch_size=opt.batch_size,
                                  shuffle=True,
                                  drop_last=True,
                                  num_workers=opt.num_workers)
        val_loader = DataLoader(TieredImageNet(args=opt,
                                               partition='train_phase_val',
                                               transform=test_trans),
                                batch_size=opt.batch_size // 2,
                                shuffle=False,
                                drop_last=False,
                                num_workers=opt.num_workers // 2)
        meta_testloader = DataLoader(MetaTieredImageNet(
            args=opt,
            partition='test',
            train_transform=train_trans,
            test_transform=test_trans),
                                     batch_size=opt.test_batch_size,
                                     shuffle=False,
                                     drop_last=False,
                                     num_workers=opt.num_workers)
        meta_valloader = DataLoader(MetaTieredImageNet(
            args=opt,
            partition='val',
            train_transform=train_trans,
            test_transform=test_trans),
                                    batch_size=opt.test_batch_size,
                                    shuffle=False,
                                    drop_last=False,
                                    num_workers=opt.num_workers)
        if opt.use_trainval:
            n_cls = 448
        else:
            n_cls = 351
    elif opt.dataset == 'CIFAR-FS' or opt.dataset == 'FC100':
        train_trans, test_trans = transforms_options['D']
        if opt.distill in ['contrast']:
            train_set = CIFAR100(args=opt,
                                 partition=train_partition,
                                 transform=train_trans,
                                 is_sample=True,
                                 k=opt.nce_k)
        else:
            train_set = CIFAR100(args=opt,
                                 partition=train_partition,
                                 transform=train_trans)
        n_data = len(train_set)
        train_loader = DataLoader(train_set,
                                  batch_size=opt.batch_size,
                                  shuffle=True,
                                  drop_last=True,
                                  num_workers=opt.num_workers)
        val_loader = DataLoader(CIFAR100(args=opt,
                                         partition='train',
                                         transform=test_trans),
                                batch_size=opt.batch_size // 2,
                                shuffle=False,
                                drop_last=False,
                                num_workers=opt.num_workers // 2)
        meta_testloader = DataLoader(MetaCIFAR100(args=opt,
                                                  partition='test',
                                                  train_transform=train_trans,
                                                  test_transform=test_trans),
                                     batch_size=opt.test_batch_size,
                                     shuffle=False,
                                     drop_last=False,
                                     num_workers=opt.num_workers)
        meta_valloader = DataLoader(MetaCIFAR100(args=opt,
                                                 partition='val',
                                                 train_transform=train_trans,
                                                 test_transform=test_trans),
                                    batch_size=opt.test_batch_size,
                                    shuffle=False,
                                    drop_last=False,
                                    num_workers=opt.num_workers)
        if opt.use_trainval:
            n_cls = 80
        else:
            if opt.dataset == 'CIFAR-FS':
                n_cls = 64
            elif opt.dataset == 'FC100':
                n_cls = 60
            else:
                raise NotImplementedError('dataset not supported: {}'.format(
                    opt.dataset))
    else:
        raise NotImplementedError(opt.dataset)

    # model
    model_t = load_teacher(opt.path_t, n_cls, opt.dataset)
    model_s = create_model(opt.model_s, n_cls, opt.dataset)

    data = torch.randn(2, 3, 84, 84)
    model_t.eval()
    model_s.eval()
    feat_t, _ = model_t(data, is_feat=True)
    feat_s, _ = model_s(data, is_feat=True)
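    # The dummy forward pass above probes teacher/student feature shapes on
    # 84x84 inputs; feat_t[-1].shape[1] and feat_s[-1].shape[1] size the Embed
    # projection heads for the contrastive distillation branch below.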

    module_list = nn.ModuleList([])
    module_list.append(model_s)
    trainable_list = nn.ModuleList([])
    trainable_list.append(model_s)

    criterion_cls = nn.CrossEntropyLoss()
    criterion_div = DistillKL(opt.kd_T)
    if opt.distill == 'kd':
        criterion_kd = DistillKL(opt.kd_T)
    elif opt.distill == 'contrast':
        criterion_kd = NCELoss(opt, n_data)
        embed_s = Embed(feat_s[-1].shape[1], opt.feat_dim)
        embed_t = Embed(feat_t[-1].shape[1], opt.feat_dim)
        module_list.append(embed_s)
        module_list.append(embed_t)
        trainable_list.append(embed_s)
        trainable_list.append(embed_t)
    elif opt.distill == 'attention':
        criterion_kd = Attention()
    elif opt.distill == 'hint':
        criterion_kd = HintLoss()
    else:
        raise NotImplementedError(opt.distill)

    criterion_list = nn.ModuleList([])
    criterion_list.append(criterion_cls)  # classification loss
    criterion_list.append(
        criterion_div)  # KL divergence loss, original knowledge distillation
    criterion_list.append(criterion_kd)  # other knowledge distillation loss

    # optimizer
    optimizer = optim.SGD(trainable_list.parameters(),
                          lr=opt.learning_rate,
                          momentum=opt.momentum,
                          weight_decay=opt.weight_decay)

    # append teacher after optimizer to avoid weight_decay
    module_list.append(model_t)

    if torch.cuda.is_available():
        module_list.cuda()
        criterion_list.cuda()
        cudnn.benchmark = True

    # validate teacher accuracy
    teacher_acc, _, _ = validate(val_loader, model_t, criterion_cls, opt)
    print('teacher accuracy: ', teacher_acc)

    # set cosine annealing scheduler
    if opt.cosine:
        eta_min = opt.learning_rate * (opt.lr_decay_rate**3)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, opt.epochs, eta_min, -1)

    # routine: supervised model distillation
    for epoch in range(1, opt.epochs + 1):

        if opt.cosine:
            scheduler.step()
        else:
            adjust_learning_rate(epoch, opt, optimizer)
        print("==> training...")

        time1 = time.time()
        train_acc, train_loss = train(epoch, train_loader, module_list,
                                      criterion_list, optimizer, opt)
        time2 = time.time()
        print('epoch {}, total time {:.2f}'.format(epoch, time2 - time1))

        logger.log_value('train_acc', train_acc, epoch)
        logger.log_value('train_loss', train_loss, epoch)

        test_acc, test_acc_top5, test_loss = validate(val_loader, model_s,
                                                      criterion_cls, opt)

        logger.log_value('test_acc', test_acc, epoch)
        logger.log_value('test_acc_top5', test_acc_top5, epoch)
        logger.log_value('test_loss', test_loss, epoch)

        # regular saving
        if epoch % opt.save_freq == 0:
            print('==> Saving...')
            state = {
                'epoch': epoch,
                'model': model_s.state_dict(),
            }
            save_file = os.path.join(
                opt.save_folder, 'ckpt_epoch_{epoch}.pth'.format(epoch=epoch))
            torch.save(state, save_file)

    # save the last model
    state = {
        'opt': opt,
        'model': model_s.state_dict(),
    }
    save_file = os.path.join(opt.save_folder,
                             '{}_last.pth'.format(opt.model_s))
    torch.save(state, save_file)
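`DistillKL`, used above for both `criterion_div` and the plain `kd` option, is the classic temperature-scaled distillation loss. It is not defined in these examples; a standard implementation, offered as a sketch:

import torch.nn as nn
import torch.nn.functional as F

class DistillKL(nn.Module):
    """KL divergence between softened student and teacher logits (sketch)."""
    def __init__(self, T):
        super().__init__()
        self.T = T

    def forward(self, y_s, y_t):
        p_s = F.log_softmax(y_s / self.T, dim=1)
        p_t = F.softmax(y_t / self.T, dim=1)
        # Scale by T^2 so gradient magnitudes stay comparable across temperatures.
        return F.kl_div(p_s, p_t, reduction='batchmean') * (self.T ** 2)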
Example #5
def get_train_loaders(opt, train_partition, worker_init_fn=None):
    """
    Create the training dataloaders
    """
    if opt.double_transform:
        train_trans_standard, test_trans = transforms_options[opt.transform]
        train_trans_contrast = get_contrastive_aug(dataset=opt.dataset, aug_type=opt.aug_type)
        train_trans = TwoCropTransform(train_trans_standard, train_trans_contrast)
    else:
        train_trans, test_trans = transforms_options[opt.transform]

    # ImageNet derivatives - miniImageNet
    if opt.dataset == 'miniImageNet':
        assert opt.transform == "A"
        train_loader = DataLoader(ImageNet(args=opt, partition=train_partition, transform=train_trans),
                                batch_size=opt.batch_size, shuffle=True, drop_last=True,
                                num_workers=opt.num_workers, worker_init_fn=worker_init_fn)
        val_loader = DataLoader(ImageNet(args=opt, partition='val', transform=test_trans),
                                batch_size=opt.batch_size // 2, shuffle=False, drop_last=False,
                                num_workers=opt.num_workers // 2, worker_init_fn=worker_init_fn)
        if opt.use_trainval:
            n_cls = 80
        else:
            n_cls = 64

    # ImageNet derivatives - tieredImageNet
    elif opt.dataset == 'tieredImageNet':
        assert opt.transform == "A"
        train_loader = DataLoader(TieredImageNet(args=opt, partition=train_partition, transform=train_trans),
                                  batch_size=opt.batch_size, shuffle=True, drop_last=True,
                                  num_workers=opt.num_workers, worker_init_fn=worker_init_fn)
        val_loader = DataLoader(TieredImageNet(args=opt, partition='train_phase_val', transform=test_trans),
                                batch_size=opt.batch_size // 2, shuffle=False, drop_last=False,
                                num_workers=opt.num_workers // 2, worker_init_fn=worker_init_fn)
        if opt.use_trainval:
            n_cls = 448
        else:
            n_cls = 351

    # CIFAR-100 derivatives - both CIFAR-FS & FC100
    elif opt.dataset == 'CIFAR-FS' or opt.dataset == 'FC100':
        assert opt.transform == "D" or opt.transform == "Dcontrast"
        train_loader = DataLoader(CIFAR100(args=opt, partition=train_partition, transform=train_trans),
                                batch_size=opt.batch_size, shuffle=True, drop_last=True,
                                num_workers=opt.num_workers, worker_init_fn=worker_init_fn)
        val_loader = DataLoader(CIFAR100(args=opt, partition='train', transform=test_trans),
                                batch_size=opt.batch_size // 2, shuffle=False, drop_last=False,
                                num_workers=opt.num_workers // 2, worker_init_fn=worker_init_fn)
        if opt.use_trainval:
            n_cls = 80
        else:
            if opt.dataset == 'CIFAR-FS':
                n_cls = 64
            elif opt.dataset == 'FC100':
                n_cls = 60
            else:
                raise NotImplementedError('dataset not supported: {}'.format(opt.dataset))
    
    # For cross-domain experiments we train on all of the sets (train, val and test)
    elif opt.dataset == 'cross':
        assert opt.transform == "A"
        
        train_dataset = ImageNet(args=opt, partition='train', transform=train_trans)
        val_dataset = ImageNet(args=opt, partition='val', transform=train_trans)
        test_dataset = ImageNet(args=opt, partition='test', transform=train_trans)
        
        all_datasets = ConcatDataset([train_dataset, val_dataset, test_dataset])

        train_loader = DataLoader(all_datasets, batch_size=opt.batch_size, shuffle=True, drop_last=True,
                                num_workers=opt.num_workers, worker_init_fn=worker_init_fn)
        val_loader = DataLoader(ImageNet(args=opt, partition='val', transform=test_trans),
                                batch_size=opt.batch_size // 2, shuffle=False, drop_last=False,
                                num_workers=opt.num_workers // 2, worker_init_fn=worker_init_fn)
        n_cls = 64 + 16 + 20  # train + val + test

    else:
        raise NotImplementedError(opt.dataset)


    return train_loader, val_loader, n_cls
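`TwoCropTransform` is not defined above. A sketch consistent with the two-argument call at the top of this example, producing one standard and one contrastive view of each image:

class TwoCropTransform:
    """Return two views of an image under two different transforms (sketch)."""
    def __init__(self, transform1, transform2):
        self.transform1 = transform1
        self.transform2 = transform2

    def __call__(self, x):
        return [self.transform1(x), self.transform2(x)]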
Example #6
import pprint
_utils_pp = pprint.PrettyPrinter()
def pprint(x):  # shadows the pprint module name with a preconfigured printer
    _utils_pp.pprint(x)


if __name__ == '__main__':
    opt = parse_option()
    pprint(vars(opt))

    # dataloader
    train_partition = 'trainval' if opt.use_trainval else 'train'
    if opt.dataset == 'miniImageNet':
        train_trans, test_trans = transforms_options[opt.transform]
        train_loader = DataLoader(ImageNet(args=opt, partition=train_partition, transform=train_trans),
                                  batch_size=opt.batch_size, shuffle=True, drop_last=True,
                                  num_workers=opt.num_workers)
        val_loader = DataLoader(ImageNet(args=opt, partition='val', transform=test_trans),
                                batch_size=opt.batch_size // 2, shuffle=False, drop_last=False,
                                num_workers=opt.num_workers // 2)
        meta_testloader = DataLoader(MetaImageNet(args=opt, partition='test',
                                                  train_transform=train_trans,
                                                  test_transform=test_trans),
                                     batch_size=opt.test_batch_size, shuffle=False, drop_last=False,
                                     num_workers=opt.num_workers)
        meta_valloader = DataLoader(MetaImageNet(args=opt, partition='val',
                                                 train_transform=train_trans,
                                                 test_transform=test_trans),
                                    batch_size=opt.test_batch_size, shuffle=False, drop_last=False,
                                    num_workers=opt.num_workers)
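Example #6 is cut off after the miniImageNet branch. To sanity-check the meta-loaders it builds, one episode can be drawn as below; the four-field layout is an assumption taken from the commented-out probe in Example #8:

# Each meta-loader batch is a few-shot episode: a support set plus a query set.
support_xs, support_ys, query_xs, query_ys = next(iter(meta_valloader))
print(support_xs.shape, query_xs.shape)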
Example #7
def main():

    opt = parse_option()

    if opt.name is not None:
        wandb.init(name=opt.name)
    else:
        wandb.init()
    wandb.config.update(opt)

    # dataloader
    train_partition = 'trainval' if opt.use_trainval else 'train'
    if opt.dataset == 'miniImageNet':
        train_trans, test_trans = transforms_options[opt.transform]
        train_loader = DataLoader(ImageNet(args=opt, partition=train_partition, transform=train_trans),
                                  batch_size=opt.batch_size, shuffle=True, drop_last=True,
                                  num_workers=opt.num_workers)
        val_loader = DataLoader(ImageNet(args=opt, partition='val', transform=test_trans),
                                batch_size=opt.batch_size // 2, shuffle=False, drop_last=False,
                                num_workers=opt.num_workers // 2)
        meta_testloader = DataLoader(MetaImageNet(args=opt, partition='test',
                                                  train_transform=train_trans,
                                                  test_transform=test_trans),
                                     batch_size=opt.test_batch_size, shuffle=False, drop_last=False,
                                     num_workers=opt.num_workers)
        meta_valloader = DataLoader(MetaImageNet(args=opt, partition='val',
                                                 train_transform=train_trans,
                                                 test_transform=test_trans),
                                    batch_size=opt.test_batch_size, shuffle=False, drop_last=False,
                                    num_workers=opt.num_workers)
        if opt.use_trainval:
            n_cls = 80
        else:
            n_cls = 64
    elif opt.dataset == 'tieredImageNet':
        train_trans, test_trans = transforms_options[opt.transform]
        train_loader = DataLoader(TieredImageNet(args=opt, partition=train_partition, transform=train_trans),
                                  batch_size=opt.batch_size, shuffle=True, drop_last=True,
                                  num_workers=opt.num_workers)
        val_loader = DataLoader(TieredImageNet(args=opt, partition='train_phase_val', transform=test_trans),
                                batch_size=opt.batch_size // 2, shuffle=False, drop_last=False,
                                num_workers=opt.num_workers // 2)
        meta_testloader = DataLoader(MetaTieredImageNet(args=opt, partition='test',
                                                        train_transform=train_trans,
                                                        test_transform=test_trans),
                                     batch_size=opt.test_batch_size, shuffle=False, drop_last=False,
                                     num_workers=opt.num_workers)
        meta_valloader = DataLoader(MetaTieredImageNet(args=opt, partition='val',
                                                       train_transform=train_trans,
                                                       test_transform=test_trans),
                                    batch_size=opt.test_batch_size, shuffle=False, drop_last=False,
                                    num_workers=opt.num_workers)
        if opt.use_trainval:
            n_cls = 448
        else:
            n_cls = 351
    elif opt.dataset == 'CIFAR-FS' or opt.dataset == 'FC100':
        train_trans, test_trans = transforms_options['D']

        train_loader = DataLoader(CIFAR100(args=opt, partition=train_partition, transform=train_trans),
                                  batch_size=opt.batch_size, shuffle=True, drop_last=True,
                                  num_workers=opt.num_workers)
        val_loader = DataLoader(CIFAR100(args=opt, partition='train', transform=test_trans),
                                batch_size=opt.batch_size // 2, shuffle=False, drop_last=False,
                                num_workers=opt.num_workers // 2)
        meta_testloader = DataLoader(MetaCIFAR100(args=opt, partition='test',
                                                  train_transform=train_trans,
                                                  test_transform=test_trans),
                                     batch_size=opt.test_batch_size, shuffle=False, drop_last=False,
                                     num_workers=opt.num_workers)
        meta_valloader = DataLoader(MetaCIFAR100(args=opt, partition='val',
                                                 train_transform=train_trans,
                                                 test_transform=test_trans),
                                    batch_size=opt.test_batch_size, shuffle=False, drop_last=False,
                                    num_workers=opt.num_workers)
        if opt.use_trainval:
            n_cls = 80
        else:
            if opt.dataset == 'CIFAR-FS':
                n_cls = 64
            elif opt.dataset == 'FC100':
                n_cls = 60
            else:
                raise NotImplementedError('dataset not supported: {}'.format(opt.dataset))
    elif opt.dataset == 'CUB_200_2011':
        train_trans, test_trans = transforms_options['C']

        vocab = lang_utils.load_vocab(opt.lang_dir) if opt.lsl else None
        devocab = {v:k for k,v in vocab.items()} if opt.lsl else None

        train_loader = DataLoader(CUB2011(args=opt, partition=train_partition, transform=train_trans,
                                          vocab=vocab),
                                  batch_size=opt.batch_size, shuffle=True, drop_last=True,
                                  num_workers=opt.num_workers)
        val_loader = DataLoader(CUB2011(args=opt, partition='val', transform=test_trans, vocab=vocab),
                                batch_size=opt.batch_size // 2, shuffle=False, drop_last=False,
                                num_workers=opt.num_workers // 2)
        meta_testloader = DataLoader(MetaCUB2011(args=opt, partition='test',
                                                  train_transform=train_trans,
                                                  test_transform=test_trans, vocab=vocab),
                                     batch_size=opt.test_batch_size, shuffle=False, drop_last=False,
                                     num_workers=opt.num_workers)
        meta_valloader = DataLoader(MetaCUB2011(args=opt, partition='val',
                                                 train_transform=train_trans,
                                                 test_transform=test_trans, vocab=vocab),
                                    batch_size=opt.test_batch_size, shuffle=False, drop_last=False,
                                    num_workers=opt.num_workers)
        if opt.use_trainval:
            # trainval not supported yet; it would give n_cls = 150
            raise NotImplementedError(opt.dataset)
        else:
            n_cls = 100
    else:
        raise NotImplementedError(opt.dataset)

    print('Amount of training data: {}'.format(len(train_loader.dataset)))
    print('Amount of val data:      {}'.format(len(val_loader.dataset)))

    # model
    model = create_model(opt.model, n_cls, opt.dataset)

    # optimizer
    if opt.adam:
        optimizer = torch.optim.Adam(model.parameters(),
                                     lr=opt.learning_rate,
                                     weight_decay=0.0005)
    else:
        optimizer = optim.SGD(model.parameters(),
                              lr=opt.learning_rate,
                              momentum=opt.momentum,
                              weight_decay=opt.weight_decay)

    criterion = nn.CrossEntropyLoss()


    # lsl
    lang_model = None
    if opt.lsl:
        if opt.glove_init:
            vecs = lang_utils.glove_init(vocab, emb_size=opt.lang_emb_size)
        embedding_model = nn.Embedding(
            len(vocab), opt.lang_emb_size, _weight=vecs if opt.glove_init else None
        )
        if opt.freeze_emb:
            embedding_model.weight.requires_grad = False

        lang_input_size = n_cls if opt.use_logit else 640 # 640 for resnet12
        lang_model = TextProposal(
            embedding_model,
            input_size=lang_input_size,
            hidden_size=opt.lang_hidden_size,
            project_input=lang_input_size != opt.lang_hidden_size,
            rnn=opt.rnn_type,
            num_layers=opt.rnn_num_layers,
            dropout=opt.rnn_dropout,
            vocab=vocab,
            **lang_utils.get_special_indices(vocab)
        )


    if torch.cuda.is_available():
        if opt.n_gpu > 1:
            model = nn.DataParallel(model)
        model = model.cuda()
        criterion = criterion.cuda()
        cudnn.benchmark = True
        if opt.lsl:
            embedding_model = embedding_model.cuda()
            lang_model = lang_model.cuda()

    # tensorboard
    #logger = tb_logger.Logger(logdir=opt.tb_folder, flush_secs=2)

    # set cosine annealing scheduler
    if opt.cosine:
        eta_min = opt.learning_rate * (opt.lr_decay_rate ** 3)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, opt.epochs, eta_min, -1)

    # routine: supervised pre-training
    best_val_acc = 0
    for epoch in range(1, opt.epochs + 1):

        if opt.cosine:
            scheduler.step()
        else:
            adjust_learning_rate(epoch, opt, optimizer)
        print("==> training...")

        time1 = time.time()
        train_acc, train_loss, train_lang_loss = train(
            epoch, train_loader, model, criterion, optimizer, opt, lang_model,
            devocab=devocab if opt.lsl else None
        )
        time2 = time.time()
        print('epoch {}, total time {:.2f}'.format(epoch, time2 - time1))


        print("==> validating...")
        test_acc, test_acc_top5, test_loss, test_lang_loss = validate(
            val_loader, model, criterion, opt, lang_model
        )

        # wandb
        log_metrics = {
            'train_acc': train_acc,
            'train_loss': train_loss,
            'val_acc': test_acc,
            'val_acc_top5': test_acc_top5,
            'val_loss': test_loss,
        }
        if opt.lsl:
            log_metrics['train_lang_loss'] = train_lang_loss
            log_metrics['val_lang_loss'] = test_lang_loss
        wandb.log(log_metrics, step=epoch)

        # # regular saving
        # if epoch % opt.save_freq == 0 and not opt.dryrun:
        #     print('==> Saving...')
        #     state = {
        #         'epoch': epoch,
        #         'model': model.state_dict() if opt.n_gpu <= 1 else model.module.state_dict(),
        #     }
        #     save_file = os.path.join(opt.save_folder, 'ckpt_epoch_{epoch}.pth'.format(epoch=epoch))
        #     torch.save(state, save_file)


        if test_acc > best_val_acc:
            best_val_acc = test_acc
            wandb.run.summary['best_val_acc'] = test_acc
            wandb.run.summary['best_val_acc_epoch'] = epoch

    # save the last model
    state = {
        'opt': opt,
        'model': model.state_dict() if opt.n_gpu <= 1 else model.module.state_dict(),
    }
    save_file = os.path.join(wandb.run.dir, '{}_last.pth'.format(opt.model))
    torch.save(state, save_file)

    # evaluate on test set
    print("==> testing...")
    start = time.time()
    (test_acc, test_std), (test_acc5, test_std5) = meta_test(model, meta_testloader)
    test_time = time.time() - start
    print('Using logit layer for embedding')
    print('test_acc: {:.4f}, test_std: {:.4f}, time: {:.1f}'.format(test_acc, test_std, test_time))
    print('test_acc top 5: {:.4f}, test_std top 5: {:.4f}, time: {:.1f}'.format(test_acc5, test_std5, test_time))

    start = time.time()
    (test_acc_feat, test_std_feat), (test_acc5_feat, test_std5_feat) = meta_test(
        model, meta_testloader, use_logit=False)
    test_time = time.time() - start
    print('Using layer before logits for embedding')
    print('test_acc_feat: {:.4f}, test_std: {:.4f}, time: {:.1f}'.format(
        test_acc_feat, test_std_feat, test_time))
    print('test_acc_feat top 5: {:.4f}, test_std top 5: {:.4f}, time: {:.1f}'.format(
        test_acc5_feat, test_std5_feat, test_time))

    wandb.run.summary['test_acc'] = test_acc
    wandb.run.summary['test_std'] = test_std
    wandb.run.summary['test_acc5'] = test_acc5
    wandb.run.summary['test_std5'] = test_std5
    wandb.run.summary['test_acc_feat'] = test_acc_feat
    wandb.run.summary['test_std_feat'] = test_std_feat
    wandb.run.summary['test_acc5_feat'] = test_acc5_feat
    wandb.run.summary['test_std5_feat'] = test_std5_feat
Example #8
def main():
    opt = parse_option()

    print(pp.pformat(vars(opt)))

    train_partition = "trainval" if opt.use_trainval else "train"
    if opt.dataset == "miniImageNet":
        train_trans, test_trans = transforms_options[opt.transform]

        if opt.augment == "none":
            train_train_trans = train_test_trans = test_trans
        elif opt.augment == "all":
            train_train_trans = train_test_trans = train_trans
        elif opt.augment == "spt":
            train_train_trans = train_trans
            train_test_trans = test_trans
        elif opt.augment == "qry":
            train_train_trans = test_trans
            train_test_trans = train_trans

        print("spt trans")
        print(train_train_trans)
        print("qry trans")
        print(train_test_trans)

        sub_batch_size, rmd = divmod(opt.batch_size, opt.apply_every)
        assert rmd == 0
        print("Train sub batch-size:", sub_batch_size)

        meta_train_dataset = MetaImageNet(
            args=opt,
            partition="train",
            train_transform=train_train_trans,
            test_transform=train_test_trans,
            fname="miniImageNet_category_split_train_phase_%s.pickle",
            fix_seed=False,
            n_test_runs=10000000,  # big number to never stop
            new_labels=False,
        )
        meta_trainloader = DataLoader(
            meta_train_dataset,
            batch_size=sub_batch_size,
            shuffle=True,
            drop_last=True,
            num_workers=opt.num_workers,
            pin_memory=True,
        )
        meta_train_dataset_qry = MetaImageNet(
            args=opt,
            partition="train",
            train_transform=train_train_trans,
            test_transform=train_test_trans,
            fname="miniImageNet_category_split_train_phase_%s.pickle",
            fix_seed=False,
            n_test_runs=10000000,  # big number to never stop
            new_labels=False,
            n_ways=opt.n_qry_way,
            n_shots=opt.n_qry_shot,
            n_queries=0,
        )
        meta_trainloader_qry = DataLoader(
            meta_train_dataset_qry,
            batch_size=sub_batch_size,
            shuffle=True,
            drop_last=True,
            num_workers=opt.num_workers,
            pin_memory=True,
        )
        meta_val_dataset = MetaImageNet(
            args=opt,
            partition="val",
            train_transform=test_trans,
            test_transform=test_trans,
            fix_seed=False,
            n_test_runs=200,
            n_ways=5,
            n_shots=5,
            n_queries=15,
        )
        meta_valloader = DataLoader(
            meta_val_dataset,
            batch_size=opt.test_batch_size,
            shuffle=False,
            drop_last=False,
            num_workers=opt.num_workers,
            pin_memory=True,
        )
        val_loader = DataLoader(
            ImageNet(args=opt, partition="val", transform=test_trans),
            batch_size=opt.sup_val_batch_size,
            shuffle=False,
            drop_last=False,
            num_workers=opt.num_workers,
            pin_memory=True,
        )
        # derive the class count from the dataset rather than hard-coding
        # 64 for 'train' / 80 for 'trainval'
        n_cls = len(meta_train_dataset.classes)

    print("Number of base classes:", n_cls)

    model = create_model(
        opt.model,
        n_cls,
        opt.dataset,
        opt.drop_rate,
        opt.dropblock,
        opt.track_stats,
        opt.initializer,
        opt.weight_norm,
        activation=opt.activation,
        normalization=opt.normalization,
    )

    print(model)

    criterion = nn.CrossEntropyLoss()

    if torch.cuda.is_available():
        print(torch.cuda.get_device_name())
        device = torch.device("cuda")
        # if opt.n_gpu > 1:
        #     model = nn.DataParallel(model)
        model = model.to(device)
        criterion = criterion.to(device)
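        # benchmark mode lets cuDNN auto-tune kernels for fixed input sizes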
        cudnn.benchmark = True
    else:
        device = torch.device("cpu")

    print("Learning rate")
    print(opt.learning_rate)
    print("Inner Learning rate")
    print(opt.inner_lr)
    if opt.learn_lr:
        print("Optimizing learning rate")
    inner_lr = nn.Parameter(torch.tensor(opt.inner_lr),
                            requires_grad=opt.learn_lr)
    # optimize the model parameters, plus inner_lr when it is learned
    outer_params = (list(model.parameters()) + [inner_lr]
                    if opt.learn_lr else model.parameters())
    optimizer = torch.optim.Adam(outer_params, lr=opt.learning_rate)
    # note: train_step below takes inner_lr directly, so this optimizer goes
    # unused in the training loop; a fresh SGD is built for validation runs
    inner_opt = torch.optim.SGD(
        model.classifier.parameters(),
        lr=opt.inner_lr,
    )
    logger = SummaryWriter(logdir=opt.tb_folder,
                           flush_secs=10,
                           comment=opt.model_name)
    comet_logger = Experiment(
        api_key=os.environ["COMET_API_KEY"],
        project_name=opt.comet_project_name,
        workspace=opt.comet_workspace,
        disabled=not opt.logcomet,
        auto_metric_logging=False,
    )
    comet_logger.set_name(opt.model_name)
    comet_logger.log_parameters(vars(opt))
    comet_logger.set_model_graph(str(model))

    if opt.cosine:
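        # anneal the outer LR from opt.learning_rate down to eta_min over
        # opt.num_steps optimizer steps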
        eta_min = opt.learning_rate * opt.cosine_factor
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, opt.num_steps, eta_min, -1)

    # routine: supervised pre-training
    data_sampler = iter(meta_trainloader)
    data_sampler_qry = iter(meta_trainloader_qry)
    pbar = tqdm(
        range(1, opt.num_steps + 1),
        miniters=opt.print_freq,
        mininterval=3,
        maxinterval=30,
        ncols=0,
    )
    best_val_acc = 0.0
    for step in pbar:

        if not opt.cosine:
            adjust_learning_rate(step, opt, optimizer)
        # print("==> training...")

        time1 = time.time()

        # metric accumulators; the abbreviations appear to follow the pattern
        # {i,f}{i,o}{l,a} = {initial,final} {inner,outer} {loss,accuracy},
        # e.g. fol = final outer loss, iia = initial inner accuracy
        foa = 0.0
        fol = 0.0
        ioa = 0.0
        iil = 0.0
        fil = 0.0
        iia = 0.0
        fia = 0.0
        for j in range(opt.apply_every):

            x_spt, y_spt, x_qry, y_qry = [
                t.to(device) for t in next(data_sampler)
            ]
            x_qry2, y_qry2, _, _ = [
                t.to(device) for t in next(data_sampler_qry)
            ]
            y_spt = y_spt.flatten(1)
            y_qry2 = y_qry2.flatten(1)

            x_qry = torch.cat((x_spt, x_qry, x_qry2), 1)
            y_qry = torch.cat((y_spt, y_qry, y_qry2), 1)
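            # The outer/query batch is the support set, the episode's own
            # queries, and the extra query-only episodes, concatenated along
            # the task dimension (dim 1).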

            if step == 1 and j == 0:
                print(x_spt.size(), y_spt.size(), x_qry.size(), y_qry.size())

            info = train_step(
                model,
                model.classifier,
                None,
                # inner_opt,
                inner_lr,
                x_spt,
                y_spt,
                x_qry,
                y_qry,
                reset_head=opt.reset_head,
                num_steps=opt.num_inner_steps,
            )

            _foa = info["foa"] / opt.batch_size
            _fol = info["fol"] / opt.batch_size
            _ioa = info["ioa"] / opt.batch_size
            _iil = info["iil"] / opt.batch_size
            _fil = info["fil"] / opt.batch_size
            _iia = info["iia"] / opt.batch_size
            _fia = info["fia"] / opt.batch_size

            # backpropagate only the final outer loss; gradients accumulate
            # across sub-batches until optimizer.step() below
            _fol.backward()

            foa += _foa.detach()
            fol += _fol.detach()
            ioa += _ioa.detach()
            iil += _iil.detach()
            fil += _fil.detach()
            iia += _iia.detach()
            fia += _fia.detach()

        optimizer.step()
        optimizer.zero_grad()
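        # keep the (possibly learned) inner-loop learning rate positive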
        inner_lr.data.clamp_(min=0.001)

        if opt.cosine:
            scheduler.step()
        if (step == 1) or (step % opt.eval_freq == 0):
            val_info = test_run(
                iter(meta_valloader),
                model,
                model.classifier,
                torch.optim.SGD(model.classifier.parameters(),
                                lr=inner_lr.item()),
                num_inner_steps=opt.num_inner_steps_test,
                device=device,
            )
            val_acc_feat, val_std_feat = meta_test(
                model,
                meta_valloader,
                use_logit=False,
            )

            val_acc = val_info["outer"]["acc"].cpu()
            val_loss = val_info["outer"]["loss"].cpu()

            sup_acc, sup_acc_top5, sup_loss = validate(
                val_loader,
                model,
                criterion,
                print_freq=100000000,  # effectively disables intermediate printing
            )
            sup_acc = sup_acc.item()
            sup_acc_top5 = sup_acc_top5.item()

            print(f"\nValidation step {step}")
            print(f"MAML 5-way-5-shot accuracy: {val_acc.item()}")
            print(f"LR 5-way-5-shot accuracy: {val_acc_feat}+-{val_std_feat}")
            print(
                f"Supervised accuracy: Acc@1: {sup_acc} Acc@5: {sup_acc_top5} Loss: {sup_loss}"
            )

            if val_acc_feat > best_val_acc:
                best_val_acc = val_acc_feat
                print(
                    f"New best validation accuracy {best_val_acc}, saving checkpoint\n"
                )

                torch.save(
                    {
                        "opt": opt,
                        "model": model.state_dict()
                        if opt.n_gpu <= 1 else model.module.state_dict(),
                        "optimizer": optimizer.state_dict(),
                        "step": step,
                        "val_acc": val_acc,
                        "val_loss": val_loss,
                        "val_acc_lr": val_acc_feat,
                        "sup_acc": sup_acc,
                        "sup_acc_top5": sup_acc_top5,
                        "sup_loss": sup_loss,
                    },
                    os.path.join(opt.save_folder,
                                 "{}_best.pth".format(opt.model)),
                )

            comet_logger.log_metrics(
                dict(
                    fol=val_loss,
                    foa=val_acc,
                    acc_lr=val_acc_feat,
                    sup_acc=sup_acc,
                    sup_acc_top5=sup_acc_top5,
                    sup_loss=sup_loss,
                ),
                step=step,
                prefix="val",
            )

            logger.add_scalar("val_acc", val_acc, step)
            logger.add_scalar("val_loss", val_loss, step)
            logger.add_scalar("val_acc_lr", val_acc_feat, step)
            logger.add_scalar("sup_acc", sup_acc, step)
            logger.add_scalar("sup_acc_top5", sup_acc_top5, step)
            logger.add_scalar("sup_loss", sup_loss, step)

        if (step == 1 or step % opt.eval_freq == 0
                or step % opt.print_freq == 0):

            tfol = fol.cpu()
            tfoa = foa.cpu()
            tioa = ioa.cpu()
            tiil = iil.cpu()
            tfil = fil.cpu()
            tiia = iia.cpu()
            tfia = fia.cpu()

            comet_logger.log_metrics(
                dict(
                    fol=tfol,
                    foa=tfoa,
                    ioa=tioa,
                    iil=tiil,
                    fil=tfil,
                    iia=tiia,
                    fia=tfia,
                ),
                step=step,
                prefix="train",
            )

            logger.add_scalar("train_acc", tfoa.item(), step)
            logger.add_scalar("train_loss", tfol.item(), step)
            logger.add_scalar("train_ioa", tioa, step)
            logger.add_scalar("train_iil", tiil, step)
            logger.add_scalar("train_fil", tfil, step)
            logger.add_scalar("train_iia", tiia, step)
            logger.add_scalar("train_fia", tfia, step)

            pbar.set_postfix(
                fol=f"{tfol.item():.2f}",
                foa=f"{tfoa.item():.2f}",
                ioa=f"{tioa.item():.2f}",
                iia=f"{tiia.item():.2f}",
                fia=f"{tfia.item():.2f}",
                vl=f"{val_loss.item():.2f}",
                va=f"{val_acc.item():.2f}",
                valr=f"{val_acc_feat:.2f}",
                lr=f"{inner_lr.item():.4f}",
                vsa=f"{sup_acc:.2f}",
                refresh=True,
            )

    # save the last model
    state = {
        "opt": opt,
        "model": model.state_dict() if opt.n_gpu <= 1 else model.module.state_dict(),
        "optimizer": optimizer.state_dict(),
        "step": step,
    }
    save_file = os.path.join(opt.save_folder, "{}_last.pth".format(opt.model))
    torch.save(state, save_file)
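
The train_step called in the loop above is not defined in this snippet. Below is a rough, per-task sketch of the MAML-style update it appears to perform; everything here is an assumption (the real function batches tasks and returns the full iil/fil/iia/fia metric set): adapt the classifier head on the support set with a few differentiable SGD steps at rate inner_lr, then evaluate the outer loss on the query set through the adapted head.

import torch
import torch.nn.functional as F


def train_step_sketch(backbone, head, inner_lr, x_spt, y_spt, x_qry, y_qry,
                      num_steps=5):
    """Hypothetical single-task inner/outer step (not the real train_step)."""
    w, b = head.weight, head.bias
    z_spt, z_qry = backbone(x_spt), backbone(x_qry)

    # initial outer accuracy before any inner-loop adaptation ("ioa")
    with torch.no_grad():
        ioa = (F.linear(z_qry, w, b).argmax(-1) == y_qry).float().mean()

    for _ in range(num_steps):
        inner_loss = F.cross_entropy(F.linear(z_spt, w, b), y_spt)
        # create_graph=True keeps the inner updates differentiable so the
        # outer loss can backpropagate through them (and into inner_lr)
        gw, gb = torch.autograd.grad(inner_loss, (w, b), create_graph=True)
        w, b = w - inner_lr * gw, b - inner_lr * gb

    logits = F.linear(z_qry, w, b)
    fol = F.cross_entropy(logits, y_qry)               # final outer loss
    foa = (logits.argmax(-1) == y_qry).float().mean()  # final outer accuracy
    return {"fol": fol, "foa": foa, "ioa": ioa}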

Example No. 9
def get_dataloaders(opt):
    # dataloader
    train_partition = 'trainval' if opt.use_trainval else 'train'

    if opt.dataset == 'miniImageNet':

        train_trans, test_trans = transforms_options[opt.transform]
        train_loader = DataLoader(ImageNet(args=opt,
                                           partition=train_partition,
                                           transform=train_trans),
                                  batch_size=opt.batch_size,
                                  shuffle=True,
                                  drop_last=True,
                                  num_workers=opt.num_workers)
        val_loader = DataLoader(ImageNet(args=opt,
                                         partition='val',
                                         transform=test_trans),
                                batch_size=opt.batch_size // 2,
                                shuffle=False,
                                drop_last=False,
                                num_workers=opt.num_workers // 2)

        train_trans, test_trans = transforms_test_options[opt.transform]
        meta_testloader = DataLoader(MetaImageNet(args=opt,
                                                  partition='test',
                                                  train_transform=train_trans,
                                                  test_transform=test_trans),
                                     batch_size=opt.test_batch_size,
                                     shuffle=False,
                                     drop_last=False,
                                     num_workers=opt.num_workers)
        meta_valloader = DataLoader(MetaImageNet(args=opt,
                                                 partition='val',
                                                 train_transform=train_trans,
                                                 test_transform=test_trans),
                                    batch_size=opt.test_batch_size,
                                    shuffle=False,
                                    drop_last=False,
                                    num_workers=opt.num_workers)

        if opt.use_trainval:
            n_cls = 80
        else:
            n_cls = 64

        no_sample = len(
            ImageNet(args=opt,
                     partition=train_partition,
                     transform=train_trans))

    elif opt.dataset == 'tieredImageNet':
        train_trans, test_trans = transforms_options[opt.transform]
        train_loader = DataLoader(TieredImageNet(args=opt,
                                                 partition=train_partition,
                                                 transform=train_trans),
                                  batch_size=opt.batch_size,
                                  shuffle=True,
                                  drop_last=True,
                                  num_workers=opt.num_workers)
        val_loader = DataLoader(TieredImageNet(args=opt,
                                               partition='train_phase_val',
                                               transform=test_trans),
                                batch_size=opt.batch_size // 2,
                                shuffle=False,
                                drop_last=False,
                                num_workers=opt.num_workers // 2)

        train_trans, test_trans = transforms_test_options[opt.transform]
        meta_testloader = DataLoader(MetaTieredImageNet(
            args=opt,
            partition='test',
            train_transform=train_trans,
            test_transform=test_trans),
                                     batch_size=opt.test_batch_size,
                                     shuffle=False,
                                     drop_last=False,
                                     num_workers=opt.num_workers)
        meta_valloader = DataLoader(MetaTieredImageNet(
            args=opt,
            partition='val',
            train_transform=train_trans,
            test_transform=test_trans),
                                    batch_size=opt.test_batch_size,
                                    shuffle=False,
                                    drop_last=False,
                                    num_workers=opt.num_workers)
        if opt.use_trainval:
            n_cls = 448
        else:
            n_cls = 351

        no_sample = len(
            TieredImageNet(args=opt,
                           partition=train_partition,
                           transform=train_trans))

    elif opt.dataset == 'CIFAR-FS' or opt.dataset == 'FC100':
        train_trans, test_trans = transforms_options['D']

        train_loader = DataLoader(CIFAR100(args=opt,
                                           partition=train_partition,
                                           transform=train_trans),
                                  batch_size=opt.batch_size,
                                  shuffle=True,
                                  drop_last=True,
                                  num_workers=opt.num_workers)
        val_loader = DataLoader(CIFAR100(args=opt,
                                         partition='train',
                                         transform=test_trans),
                                batch_size=opt.batch_size // 2,
                                shuffle=False,
                                drop_last=False,
                                num_workers=opt.num_workers // 2)

        train_trans, test_trans = transforms_test_options[opt.transform]

        meta_trainloader = DataLoader(MetaCIFAR100(args=opt,
                                                   partition='train',
                                                   train_transform=train_trans,
                                                   test_transform=test_trans),
                                      batch_size=1,
                                      shuffle=True,
                                      drop_last=False,
                                      num_workers=opt.num_workers)

        meta_testloader = DataLoader(MetaCIFAR100(args=opt,
                                                  partition='test',
                                                  train_transform=train_trans,
                                                  test_transform=test_trans),
                                     batch_size=opt.test_batch_size,
                                     shuffle=False,
                                     drop_last=False,
                                     num_workers=opt.num_workers)
        meta_valloader = DataLoader(MetaCIFAR100(args=opt,
                                                 partition='val',
                                                 train_transform=train_trans,
                                                 test_transform=test_trans),
                                    batch_size=opt.test_batch_size,
                                    shuffle=False,
                                    drop_last=False,
                                    num_workers=opt.num_workers)
        if opt.use_trainval:
            n_cls = 80
        else:
            if opt.dataset == 'CIFAR-FS':
                n_cls = 64
            elif opt.dataset == 'FC100':
                n_cls = 60
            else:
                raise NotImplementedError('dataset not supported: {}'.format(
                    opt.dataset))
        no_sample = len(
            CIFAR100(args=opt,
                     partition=train_partition,
                     transform=train_trans))
    else:
        raise NotImplementedError(opt.dataset)

    return train_loader, val_loader, meta_testloader, meta_valloader, n_cls, no_sample
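
A hypothetical call site for get_dataloaders, assuming an opt namespace produced by parse_option():

opt = parse_option()
(train_loader, val_loader, meta_testloader,
 meta_valloader, n_cls, no_sample) = get_dataloaders(opt)
print(f"{n_cls} base classes, {no_sample} training images, "
      f"{len(meta_valloader)} val episode batches")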
Example No. 10
def main():

    opt = parse_option()

    # dataloader
    train_partition = "trainval" if opt.use_trainval else "train"
    if opt.dataset == "miniImageNet":
        train_trans, test_trans = transforms_options[opt.transform]
        train_loader = DataLoader(
            ImageNet(args=opt, partition=train_partition, transform=train_trans),
            batch_size=opt.batch_size,
            shuffle=True,
            drop_last=True,
            num_workers=opt.num_workers,
        )
        val_loader = DataLoader(
            ImageNet(args=opt, partition="val", transform=test_trans),
            batch_size=opt.batch_size // 2,
            shuffle=False,
            drop_last=False,
            num_workers=opt.num_workers // 2,
        )
        # meta_testloader = DataLoader(
        #     MetaImageNet(
        #         args=opt,
        #         partition="test",
        #         train_transform=train_trans,
        #         test_transform=test_trans,
        #     ),
        #     batch_size=opt.test_batch_size,
        #     shuffle=False,
        #     drop_last=False,
        #     num_workers=opt.num_workers,
        # )
        # meta_valloader = DataLoader(
        #     MetaImageNet(
        #         args=opt,
        #         partition="val",
        #         train_transform=train_trans,
        #         test_transform=test_trans,
        #     ),
        #     batch_size=opt.test_batch_size,
        #     shuffle=False,
        #     drop_last=False,
        #     num_workers=opt.num_workers,
        # )
        if opt.use_trainval:
            n_cls = 80
        else:
            n_cls = 64
    elif opt.dataset == "tieredImageNet":
        train_trans, test_trans = transforms_options[opt.transform]
        train_loader = DataLoader(
            TieredImageNet(args=opt, partition=train_partition, transform=train_trans),
            batch_size=opt.batch_size,
            shuffle=True,
            drop_last=True,
            num_workers=opt.num_workers,
        )
        val_loader = DataLoader(
            TieredImageNet(args=opt, partition="train_phase_val", transform=test_trans),
            batch_size=opt.batch_size // 2,
            shuffle=False,
            drop_last=False,
            num_workers=opt.num_workers // 2,
        )
        meta_testloader = DataLoader(
            MetaTieredImageNet(
                args=opt,
                partition="test",
                train_transform=train_trans,
                test_transform=test_trans,
            ),
            batch_size=opt.test_batch_size,
            shuffle=False,
            drop_last=False,
            num_workers=opt.num_workers,
        )
        meta_valloader = DataLoader(
            MetaTieredImageNet(
                args=opt,
                partition="val",
                train_transform=train_trans,
                test_transform=test_trans,
            ),
            batch_size=opt.test_batch_size,
            shuffle=False,
            drop_last=False,
            num_workers=opt.num_workers,
        )
        if opt.use_trainval:
            n_cls = 448
        else:
            n_cls = 351
    elif opt.dataset == "CIFAR-FS" or opt.dataset == "FC100":
        train_trans, test_trans = transforms_options["D"]

        train_loader = DataLoader(
            CIFAR100(args=opt, partition=train_partition, transform=train_trans),
            batch_size=opt.batch_size,
            shuffle=True,
            drop_last=True,
            num_workers=opt.num_workers,
        )
        val_loader = DataLoader(
            CIFAR100(args=opt, partition="train", transform=test_trans),
            batch_size=opt.batch_size // 2,
            shuffle=False,
            drop_last=False,
            num_workers=opt.num_workers // 2,
        )
        meta_testloader = DataLoader(
            MetaCIFAR100(
                args=opt,
                partition="test",
                train_transform=train_trans,
                test_transform=test_trans,
            ),
            batch_size=opt.test_batch_size,
            shuffle=False,
            drop_last=False,
            num_workers=opt.num_workers,
        )
        meta_valloader = DataLoader(
            MetaCIFAR100(
                args=opt,
                partition="val",
                train_transform=train_trans,
                test_transform=test_trans,
            ),
            batch_size=opt.test_batch_size,
            shuffle=False,
            drop_last=False,
            num_workers=opt.num_workers,
        )
        if opt.use_trainval:
            n_cls = 80
        else:
            if opt.dataset == "CIFAR-FS":
                n_cls = 64
            elif opt.dataset == "FC100":
                n_cls = 60
            else:
                raise NotImplementedError(
                    "dataset not supported: {}".format(opt.dataset)
                )
    else:
        raise NotImplementedError(opt.dataset)

    # model
    model = create_model(opt.model, n_cls, opt.dataset, opt.drop_rate, opt.dropblock)

    # optimizer
    if opt.adam:
        optimizer = torch.optim.Adam(
            model.parameters(), lr=opt.learning_rate, weight_decay=0.0005
        )
    else:
        optimizer = optim.SGD(
            model.parameters(),
            lr=opt.learning_rate,
            momentum=opt.momentum,
            weight_decay=opt.weight_decay,
        )
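    # note: the Adam branch hard-codes weight_decay=0.0005, while the SGD
    # branch reads momentum and weight decay from the parsed options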

    criterion = nn.CrossEntropyLoss()

    if torch.cuda.is_available():
        if opt.n_gpu > 1:
            model = nn.DataParallel(model)
        model = model.cuda()
        criterion = criterion.cuda()
        cudnn.benchmark = True

    # tensorboard
    logger = tb_logger.Logger(logdir=opt.tb_folder, flush_secs=2)
    comet_logger = Experiment(
        api_key=os.environ["COMET_API_KEY"],
        project_name=opt.comet_project_name,
        workspace=opt.comet_workspace,
        disabled=not opt.logcomet,
    )
    comet_logger.set_name(opt.model_name)
    comet_logger.log_parameters(vars(opt))

    # set cosine annealing scheduler
    if opt.cosine:
        eta_min = opt.learning_rate * (opt.lr_decay_rate ** opt.cosine_factor)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, opt.epochs, eta_min, -1
        )

    # routine: supervised pre-training
    for epoch in range(1, opt.epochs + 1):

        if opt.cosine:
            scheduler.step()
        else:
            adjust_learning_rate(epoch, opt, optimizer)
        print("==> training...")

        time1 = time.time()
        with comet_logger.train():
            train_acc, train_loss = train(
                epoch, train_loader, model, criterion, optimizer, opt
            )
            comet_logger.log_metrics(
                {"acc": train_acc.cpu(), "loss_epoch": train_loss}, epoch=epoch
            )
        time2 = time.time()
        print("epoch {}, total time {:.2f}".format(epoch, time2 - time1))

        logger.log_value("train_acc", train_acc, epoch)
        logger.log_value("train_loss", train_loss, epoch)

        with comet_logger.validate():
            test_acc, test_acc_top5, test_loss = validate(
                val_loader, model, criterion, opt
            )
            comet_logger.log_metrics(
                {"acc": test_acc.cpu(), "acc_top5": test_acc_top5.cpu(), "loss": test_loss},
                epoch=epoch,
            )

        logger.log_value("test_acc", test_acc, epoch)
        logger.log_value("test_acc_top5", test_acc_top5, epoch)
        logger.log_value("test_loss", test_loss, epoch)

        # regular saving
        if epoch % opt.save_freq == 0:
            print("==> Saving...")
            state = {
                "epoch": epoch,
                "model": model.state_dict()
                if opt.n_gpu <= 1
                else model.module.state_dict(),
            }
            save_file = os.path.join(
                opt.save_folder, "ckpt_epoch_{epoch}.pth".format(epoch=epoch)
            )
            torch.save(state, save_file)

    # save the last model
    state = {
        "opt": opt,
        "model": model.state_dict() if opt.n_gpu <= 1 else model.module.state_dict(),
    }
    save_file = os.path.join(opt.save_folder, "{}_last.pth".format(opt.model))
    torch.save(state, save_file)
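
adjust_learning_rate is called when cosine annealing is disabled but is not defined in these snippets. A plausible step-decay sketch, assuming opt.lr_decay_epochs is an iterable of milestone epochs (a field not shown here):

import numpy as np


def adjust_learning_rate(epoch, opt, optimizer):
    """Hypothetical step decay: scale the base LR by lr_decay_rate once for
    every milestone in opt.lr_decay_epochs that has passed."""
    steps = int(np.sum(epoch > np.asarray(opt.lr_decay_epochs)))
    if steps > 0:
        new_lr = opt.learning_rate * (opt.lr_decay_rate ** steps)
        for param_group in optimizer.param_groups:
            param_group["lr"] = new_lr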
Example No. 11
def get_dataloaders(opt):
    # dataloader
    train_partition = 'trainval' if opt.use_trainval else 'train'

    if opt.dataset == 'toy':

        train_trans, test_trans = transforms_options['D']

        train_loader = DataLoader(CIFAR100_toy(args=opt,
                                               partition=train_partition,
                                               transform=train_trans),
                                  batch_size=opt.batch_size,
                                  shuffle=True,
                                  drop_last=True,
                                  num_workers=opt.num_workers)
        val_loader = DataLoader(CIFAR100_toy(args=opt,
                                             partition='train',
                                             transform=test_trans),
                                batch_size=opt.batch_size // 2,
                                shuffle=False,
                                drop_last=False,
                                num_workers=opt.num_workers // 2)

        #         train_trans, test_trans = transforms_test_options[opt.transform]

        #         meta_testloader = DataLoader(MetaCIFAR100(args=opt, partition='test',
        #                                                   train_transform=train_trans,
        #                                                   test_transform=test_trans),
        #                                      batch_size=opt.test_batch_size, shuffle=False, drop_last=False,
        #                                      num_workers=opt.num_workers)
        #         meta_valloader = DataLoader(MetaCIFAR100(args=opt, partition='val',
        #                                                  train_transform=train_trans,
        #                                                  test_transform=test_trans),
        #                                     batch_size=opt.test_batch_size, shuffle=False, drop_last=False,
        #                                     num_workers=opt.num_workers)
        n_cls = 5

        # the toy split has no episodic loaders; return placeholders in the
        # meta test/val loader positions to keep the return signature uniform
        return train_loader, val_loader, 5, 5, n_cls

    if opt.dataset == 'miniImageNet':

        train_trans, test_trans = transforms_options[opt.transform]
        train_loader = DataLoader(ImageNet(args=opt,
                                           partition=train_partition,
                                           transform=train_trans),
                                  batch_size=opt.batch_size,
                                  shuffle=True,
                                  drop_last=True,
                                  num_workers=opt.num_workers)
        val_loader = DataLoader(ImageNet(args=opt,
                                         partition='val',
                                         transform=test_trans),
                                batch_size=opt.batch_size // 2,
                                shuffle=False,
                                drop_last=False,
                                num_workers=opt.num_workers // 2)

        train_trans, test_trans = transforms_test_options[opt.transform]
        meta_testloader = DataLoader(MetaImageNet(args=opt,
                                                  partition='test',
                                                  train_transform=train_trans,
                                                  test_transform=test_trans),
                                     batch_size=opt.test_batch_size,
                                     shuffle=False,
                                     drop_last=False,
                                     num_workers=opt.num_workers)
        meta_valloader = DataLoader(MetaImageNet(args=opt,
                                                 partition='val',
                                                 train_transform=train_trans,
                                                 test_transform=test_trans),
                                    batch_size=opt.test_batch_size,
                                    shuffle=False,
                                    drop_last=False,
                                    num_workers=opt.num_workers)

        if opt.use_trainval:
            n_cls = 80
        else:
            n_cls = 64
    elif opt.dataset == 'tieredImageNet':
        train_trans, test_trans = transforms_options[opt.transform]
        train_loader = DataLoader(TieredImageNet(args=opt,
                                                 partition=train_partition,
                                                 transform=train_trans),
                                  batch_size=opt.batch_size,
                                  shuffle=True,
                                  drop_last=True,
                                  num_workers=opt.num_workers)
        val_loader = DataLoader(TieredImageNet(args=opt,
                                               partition='train_phase_val',
                                               transform=test_trans),
                                batch_size=opt.batch_size // 2,
                                shuffle=False,
                                drop_last=False,
                                num_workers=opt.num_workers // 2)

        train_trans, test_trans = transforms_test_options[opt.transform]
        meta_testloader = DataLoader(MetaTieredImageNet(
            args=opt,
            partition='test',
            train_transform=train_trans,
            test_transform=test_trans),
                                     batch_size=opt.test_batch_size,
                                     shuffle=False,
                                     drop_last=False,
                                     num_workers=opt.num_workers)
        meta_valloader = DataLoader(MetaTieredImageNet(
            args=opt,
            partition='val',
            train_transform=train_trans,
            test_transform=test_trans),
                                    batch_size=opt.test_batch_size,
                                    shuffle=False,
                                    drop_last=False,
                                    num_workers=opt.num_workers)
        if opt.use_trainval:
            n_cls = 448
        else:
            n_cls = 351
    elif opt.dataset == 'CIFAR-FS' or opt.dataset == 'FC100':
        train_trans, test_trans = transforms_options['D']

        train_loader = DataLoader(CIFAR100(args=opt,
                                           partition=train_partition,
                                           transform=train_trans),
                                  batch_size=opt.batch_size,
                                  shuffle=True,
                                  drop_last=True,
                                  num_workers=opt.num_workers)
        val_loader = DataLoader(CIFAR100(args=opt,
                                         partition='train',
                                         transform=test_trans),
                                batch_size=opt.batch_size // 2,
                                shuffle=False,
                                drop_last=False,
                                num_workers=opt.num_workers // 2)

        train_trans, test_trans = transforms_test_options[opt.transform]

        #         ns = [opt.n_shots].copy()
        #         opt.n_ways = 32
        #         opt.n_shots = 5
        #         opt.n_aug_support_samples = 2
        meta_trainloader = DataLoader(MetaCIFAR100(args=opt,
                                                   partition='train',
                                                   train_transform=train_trans,
                                                   test_transform=test_trans),
                                      batch_size=1,
                                      shuffle=True,
                                      drop_last=False,
                                      num_workers=opt.num_workers)

        #         opt.n_ways = 5
        #         opt.n_shots = ns[0]
        #         print(opt.n_shots)
        #         opt.n_aug_support_samples = 5
        meta_testloader = DataLoader(MetaCIFAR100(args=opt,
                                                  partition='test',
                                                  train_transform=train_trans,
                                                  test_transform=test_trans),
                                     batch_size=opt.test_batch_size,
                                     shuffle=False,
                                     drop_last=False,
                                     num_workers=opt.num_workers)
        meta_valloader = DataLoader(MetaCIFAR100(args=opt,
                                                 partition='val',
                                                 train_transform=train_trans,
                                                 test_transform=test_trans),
                                    batch_size=opt.test_batch_size,
                                    shuffle=False,
                                    drop_last=False,
                                    num_workers=opt.num_workers)
        if opt.use_trainval:
            n_cls = 80
        else:
            if opt.dataset == 'CIFAR-FS':
                n_cls = 64
            elif opt.dataset == 'FC100':
                n_cls = 60
            else:
                raise NotImplementedError('dataset not supported: {}'.format(
                    opt.dataset))
#         return train_loader, val_loader, meta_trainloader, meta_testloader, meta_valloader, n_cls
    else:
        raise NotImplementedError(opt.dataset)

    return train_loader, val_loader, meta_testloader, meta_valloader, n_cls