Code example #1
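A trainer-class validation helper: it times meta_test on the meta-val and meta-test loaders, once using the logit layer and once using penultimate features (use_logit=False); the corresponding meta-train evaluation is left commented out.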
    def validate_few_shot(self):
        # start = time.time()
        # train_acc, train_std = meta_test(self.model, self.meta_trainloader)
        # train_time = time.time() - start
        # print('train_acc: {:.4f}, train_std: {:.4f}, time: {:.1f}'.format(train_acc, train_std, train_time))
        #
        # start = time.time()
        # train_acc_feat, train_std_feat = meta_test(self.model, self.meta_trainloader, use_logit=False)
        # train_time = time.time() - start
        # print('train_acc_feat: {:.4f}, train_std: {:.4f}, time: {:.1f}'.format(
        #     train_acc_feat, train_std_feat, train_time))

        start = time.time()
        val_acc, val_std = meta_test(self.model, self.meta_valloader)
        val_time = time.time() - start
        print('val_acc: {:.4f}, val_std: {:.4f}, time: {:.1f}'.format(val_acc, val_std, val_time))

        start = time.time()
        val_acc_feat, val_std_feat = meta_test(self.model, self.meta_valloader, use_logit=False)
        val_time = time.time() - start
        print('val_acc_feat: {:.4f}, val_std: {:.4f}, time: {:.1f}'.format(val_acc_feat, val_std_feat, val_time))

        start = time.time()
        test_acc, test_std = meta_test(self.model, self.meta_testloader)
        test_time = time.time() - start
        print('test_acc: {:.4f}, test_std: {:.4f}, time: {:.1f}'.format(test_acc, test_std, test_time))

        start = time.time()
        test_acc_feat, test_std_feat = meta_test(self.model, self.meta_testloader, use_logit=False)
        test_time = time.time() - start
        print('test_acc_feat: {:.4f}, test_std: {:.4f}, time: {:.1f}'.format(
            test_acc_feat, test_std_feat, test_time))
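None of the snippets above include meta_test itself. As orientation only, here is a minimal sketch of what such a routine typically does in this codebase, assuming each episode yields (support_x, support_y, query_x, query_y) tensors, that model(x, is_feat=True) returns (features, logits) as in example #10, and that a logistic-regression head is fit per episode; the name meta_test_sketch and all details are hypothetical:

import numpy as np
import torch
from sklearn.linear_model import LogisticRegression

def meta_test_sketch(model, loader, use_logit=True):
    # Hypothetical sketch only; the real meta_test may differ in details.
    model.eval()
    accs = []
    with torch.no_grad():
        for support_x, support_y, query_x, query_y in loader:

            def embed(x):
                # flatten episode dims and move to the model's device
                x = x.view(-1, *x.shape[-3:]).to(next(model.parameters()).device)
                if use_logit:
                    return model(x).cpu().numpy()  # logit-layer embedding
                feats, _ = model(x, is_feat=True)
                return feats[-1].cpu().numpy()     # penultimate features

            clf = LogisticRegression(max_iter=1000)
            clf.fit(embed(support_x), support_y.view(-1).cpu().numpy())
            pred = clf.predict(embed(query_x))
            accs.append((pred == query_y.view(-1).cpu().numpy()).mean())
    accs = np.asarray(accs)
    # mean accuracy and 95% confidence half-width, matching the (acc, std) pair
    return accs.mean(), 1.96 * accs.std() / np.sqrt(len(accs))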
Code example #2
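A stand-alone evaluation entry point: it loads a checkpoint into the model, then runs 600 meta-test episodes with and without the logit-layer embedding.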
def main():

    opt = parse_option()

    opt.n_test_runs = 600
    train_loader, val_loader, meta_testloader, meta_valloader, n_cls = get_dataloaders(
        opt)

    # load model
    model = create_model(opt.model, n_cls, opt.dataset)
    ckpt = torch.load(opt.model_path)
    model.load_state_dict(ckpt["model"])

    if torch.cuda.is_available():
        model = model.cuda()
        cudnn.benchmark = True

    start = time.time()
    test_acc, test_std = meta_test(model, meta_testloader)
    test_time = time.time() - start
    print('test_acc: {:.4f}, test_std: {:.4f}, time: {:.1f}'.format(
        test_acc, test_std, test_time))

    start = time.time()
    test_acc_feat, test_std_feat = meta_test(model,
                                             meta_testloader,
                                             use_logit=False)
    test_time = time.time() - start
    print('test_acc_feat: {:.4f}, test_std: {:.4f}, time: {:.1f}'.format(
        test_acc_feat, test_std_feat, test_time))
Code example #3
File: eval.py  Project: guybuk/ANCOR
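An evaluation helper that runs meta_test with normalized features on a given loader and prints accuracy ± std for the chosen partition.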
def evaluate(meta_valloader, model, args, mode):
    start = time.time()
    val_acc1, val_std1 = meta_test(model,
                                   meta_valloader,
                                   only_base=args.only_base,
                                   classifier=args.cls,
                                   is_norm=True)
    val_time = time.time() - start
    print(f'Mode: {mode}')
    print(f'Partition: {args.partition} Accuracy: {round(val_acc1 * 100, 2)} '
          f'± {round(val_std1 * 100, 2)}, Time: {val_time:.1f}')
Code example #4
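The same evaluation flow as example #2, except that the checkpoint keys are rewritten to drop the "module." prefix that nn.DataParallel adds before loading.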
def main():

    opt = parse_option()

    opt.n_test_runs = 600
    train_loader, val_loader, meta_testloader, meta_valloader, n_cls, _ = get_dataloaders(
        opt)

    # load model
    model = create_model(opt.model, n_cls, opt.dataset)
    ckpt = torch.load(opt.model_path)["model"]

    # checkpoints saved from nn.DataParallel prefix parameter names with
    # "module."; strip it so the keys match a single-GPU model
    from collections import OrderedDict
    new_state_dict = OrderedDict()
    for k, v in ckpt.items():
        name = k.replace("module.", "")
        new_state_dict[name] = v

    model.load_state_dict(new_state_dict)

    if torch.cuda.is_available():
        model = model.cuda()
        cudnn.benchmark = True

    start = time.time()
    test_acc, test_std = meta_test(model, meta_testloader)
    test_time = time.time() - start
    print('test_acc: {:.4f}, test_std: {:.4f}, time: {:.1f}'.format(
        test_acc, test_std, test_time))

    start = time.time()
    test_acc_feat, test_std_feat = meta_test(model,
                                             meta_testloader,
                                             use_logit=False)
    test_time = time.time() - start
    print('test_acc_feat: {:.4f}, test_std: {:.4f}, time: {:.1f}'.format(
        test_acc_feat, test_std_feat, test_time))
Code example #5
File: my_test.py  Project: MichalisLazarou/rfs
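Seeds every RNG for reproducibility, builds the miniImageNet meta test/val loaders directly, and evaluates penultimate features with a custom classifier.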
def main():
    seed = 42
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.random.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    opt = parse_option()
    train_trans, test_trans = transforms_test_options[opt.transform]
    meta_testloader = DataLoader(MetaImageNet(args=opt,
                                              partition='test',
                                              train_transform=train_trans,
                                              test_transform=test_trans,
                                              fix_seed=False),
                                 batch_size=opt.test_batch_size,
                                 shuffle=False,
                                 drop_last=False,
                                 num_workers=opt.num_workers)
    meta_valloader = DataLoader(MetaImageNet(args=opt,
                                             partition='val',
                                             train_transform=train_trans,
                                             test_transform=test_trans,
                                             fix_seed=False),
                                batch_size=opt.test_batch_size,
                                shuffle=False,
                                drop_last=False,
                                num_workers=opt.num_workers)
    n_cls = 64
    model = create_model(opt.model, n_cls, opt.dataset)
    ckpt = torch.load(opt.model_path)
    model.load_state_dict(ckpt['model'])

    if torch.cuda.is_available():
        model = model.cuda()
        cudnn.benchmark = True

    start = time.time()
    test_acc, test_std = meta_test(model,
                                   meta_testloader,
                                   use_logit=False,
                                   classifier='original_avrithis')
    test_time = time.time() - start
    print('test_acc: {:.4f}, test_std: {:.4f}, time: {:.1f}'.format(
        test_acc, test_std, test_time))
Code example #6
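A full evaluation script: it builds meta val/test loaders for miniImageNet, tieredImageNet, CIFAR-FS, or FC100, then reports validation and test accuracy both with and without the logit embedding.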
def main():

    opt = parse_option()

    # test loader
    args = opt

    if opt.dataset == 'miniImageNet':
        train_trans, test_trans = transforms_test_options[opt.transform]
        meta_testloader = DataLoader(MetaImageNet(args=opt,
                                                  partition='test',
                                                  train_transform=train_trans,
                                                  test_transform=test_trans,
                                                  fix_seed=False),
                                     batch_size=opt.test_batch_size,
                                     shuffle=False,
                                     drop_last=False,
                                     num_workers=opt.num_workers)
        meta_valloader = DataLoader(MetaImageNet(args=opt,
                                                 partition='val',
                                                 train_transform=train_trans,
                                                 test_transform=test_trans,
                                                 fix_seed=False),
                                    batch_size=opt.test_batch_size,
                                    shuffle=False,
                                    drop_last=False,
                                    num_workers=opt.num_workers)
        if opt.use_trainval:
            n_cls = 80
        else:
            n_cls = 64
    elif opt.dataset == 'tieredImageNet':
        train_trans, test_trans = transforms_test_options[opt.transform]
        meta_testloader = DataLoader(MetaTieredImageNet(
            args=opt,
            partition='test',
            train_transform=train_trans,
            test_transform=test_trans,
            fix_seed=False),
                                     batch_size=opt.test_batch_size,
                                     shuffle=False,
                                     drop_last=False,
                                     num_workers=opt.num_workers)
        meta_valloader = DataLoader(MetaTieredImageNet(
            args=opt,
            partition='val',
            train_transform=train_trans,
            test_transform=test_trans,
            fix_seed=False),
                                    batch_size=opt.test_batch_size,
                                    shuffle=False,
                                    drop_last=False,
                                    num_workers=opt.num_workers)
        if opt.use_trainval:
            n_cls = 448
        else:
            n_cls = 351
    elif opt.dataset == 'CIFAR-FS' or opt.dataset == 'FC100':
        train_trans, test_trans = transforms_test_options['D']
        meta_testloader = DataLoader(MetaCIFAR100(args=opt,
                                                  partition='test',
                                                  train_transform=train_trans,
                                                  test_transform=test_trans,
                                                  fix_seed=False),
                                     batch_size=opt.test_batch_size,
                                     shuffle=False,
                                     drop_last=False,
                                     num_workers=opt.num_workers)
        meta_valloader = DataLoader(MetaCIFAR100(args=opt,
                                                 partition='val',
                                                 train_transform=train_trans,
                                                 test_transform=test_trans,
                                                 fix_seed=False),
                                    batch_size=opt.test_batch_size,
                                    shuffle=False,
                                    drop_last=False,
                                    num_workers=opt.num_workers)
        if opt.use_trainval:
            n_cls = 80
        else:
            if opt.dataset == 'CIFAR-FS':
                n_cls = 64
            elif opt.dataset == 'FC100':
                n_cls = 60
            else:
                raise NotImplementedError('dataset not supported: {}'.format(
                    opt.dataset))
    else:
        raise NotImplementedError(opt.dataset)

    # load model
    model = create_model(opt.model, n_cls, opt.dataset)
    ckpt = torch.load(opt.model_path)
    model.load_state_dict(ckpt['model'])

    if torch.cuda.is_available():
        model = model.cuda()
        cudnn.benchmark = True

    # evaluation
    start = time.time()
    val_acc, val_std = meta_test(model, meta_valloader)
    val_time = time.time() - start
    print('val_acc: {:.4f}, val_std: {:.4f}, time: {:.1f}'.format(
        val_acc, val_std, val_time))

    start = time.time()
    val_acc_feat, val_std_feat = meta_test(model,
                                           meta_valloader,
                                           use_logit=False)
    val_time = time.time() - start
    print('val_acc_feat: {:.4f}, val_std: {:.4f}, time: {:.1f}'.format(
        val_acc_feat, val_std_feat, val_time))

    start = time.time()
    test_acc, test_std = meta_test(model, meta_testloader)
    test_time = time.time() - start
    print('test_acc: {:.4f}, test_std: {:.4f}, time: {:.1f}'.format(
        test_acc, test_std, test_time))

    start = time.time()
    test_acc_feat, test_std_feat = meta_test(model,
                                             meta_testloader,
                                             use_logit=False)
    test_time = time.time() - start
    print('test_acc_feat: {:.4f}, test_std: {:.4f}, time: {:.1f}'.format(
        test_acc_feat, test_std_feat, test_time))
Code example #7
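An end-to-end supervised pre-training routine: per-dataset train/val and meta loaders, SGD or Adam with optional cosine annealing, TensorBoard logging, periodic checkpointing, and a final few-shot evaluation.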
def main():

    opt = parse_option()

    # dataloader
    train_partition = 'trainval' if opt.use_trainval else 'train'
    if opt.dataset == 'miniImageNet':
        train_trans, test_trans = transforms_options[opt.transform]
        train_loader = DataLoader(ImageNet(args=opt,
                                           partition=train_partition,
                                           transform=train_trans),
                                  batch_size=opt.batch_size,
                                  shuffle=True,
                                  drop_last=True,
                                  num_workers=opt.num_workers)
        val_loader = DataLoader(ImageNet(args=opt,
                                         partition='val',
                                         transform=test_trans),
                                batch_size=opt.batch_size // 2,
                                shuffle=False,
                                drop_last=False,
                                num_workers=opt.num_workers // 2)
        meta_testloader = DataLoader(MetaImageNet(args=opt,
                                                  partition='test',
                                                  train_transform=train_trans,
                                                  test_transform=test_trans),
                                     batch_size=opt.test_batch_size,
                                     shuffle=False,
                                     drop_last=False,
                                     num_workers=opt.num_workers)
        meta_valloader = DataLoader(MetaImageNet(args=opt,
                                                 partition='val',
                                                 train_transform=train_trans,
                                                 test_transform=test_trans),
                                    batch_size=opt.test_batch_size,
                                    shuffle=False,
                                    drop_last=False,
                                    num_workers=opt.num_workers)
        if opt.use_trainval:
            n_cls = 80
        else:
            n_cls = 64
    elif opt.dataset == 'tieredImageNet':
        train_trans, test_trans = transforms_options[opt.transform]
        train_loader = DataLoader(TieredImageNet(args=opt,
                                                 partition=train_partition,
                                                 transform=train_trans),
                                  batch_size=opt.batch_size,
                                  shuffle=True,
                                  drop_last=True,
                                  num_workers=opt.num_workers)
        val_loader = DataLoader(TieredImageNet(args=opt,
                                               partition='train_phase_val',
                                               transform=test_trans),
                                batch_size=opt.batch_size // 2,
                                shuffle=False,
                                drop_last=False,
                                num_workers=opt.num_workers // 2)
        meta_testloader = DataLoader(MetaTieredImageNet(
            args=opt,
            partition='test',
            train_transform=train_trans,
            test_transform=test_trans),
                                     batch_size=opt.test_batch_size,
                                     shuffle=False,
                                     drop_last=False,
                                     num_workers=opt.num_workers)
        meta_valloader = DataLoader(MetaTieredImageNet(
            args=opt,
            partition='val',
            train_transform=train_trans,
            test_transform=test_trans),
                                    batch_size=opt.test_batch_size,
                                    shuffle=False,
                                    drop_last=False,
                                    num_workers=opt.num_workers)
        if opt.use_trainval:
            n_cls = 448
        else:
            n_cls = 351
    elif opt.dataset == 'CIFAR-FS' or opt.dataset == 'FC100':
        train_trans, test_trans = transforms_options['D']

        train_loader = DataLoader(CIFAR100(args=opt,
                                           partition=train_partition,
                                           transform=train_trans),
                                  batch_size=opt.batch_size,
                                  shuffle=True,
                                  drop_last=True,
                                  num_workers=opt.num_workers)
        val_loader = DataLoader(CIFAR100(args=opt,
                                         partition='train',
                                         transform=test_trans),
                                batch_size=opt.batch_size // 2,
                                shuffle=False,
                                drop_last=False,
                                num_workers=opt.num_workers // 2)
        meta_testloader = DataLoader(MetaCIFAR100(args=opt,
                                                  partition='test',
                                                  train_transform=train_trans,
                                                  test_transform=test_trans),
                                     batch_size=opt.test_batch_size,
                                     shuffle=False,
                                     drop_last=False,
                                     num_workers=opt.num_workers)
        meta_valloader = DataLoader(MetaCIFAR100(args=opt,
                                                 partition='val',
                                                 train_transform=train_trans,
                                                 test_transform=test_trans),
                                    batch_size=opt.test_batch_size,
                                    shuffle=False,
                                    drop_last=False,
                                    num_workers=opt.num_workers)
        if opt.use_trainval:
            n_cls = 80
        else:
            if opt.dataset == 'CIFAR-FS':
                n_cls = 64
            elif opt.dataset == 'FC100':
                n_cls = 60
            else:
                raise NotImplementedError('dataset not supported: {}'.format(
                    opt.dataset))
    else:
        raise NotImplementedError(opt.dataset)

    # model
    model = create_model(opt.model, n_cls, opt.dataset)

    # optimizer
    if opt.adam:
        optimizer = torch.optim.Adam(model.parameters(),
                                     lr=opt.learning_rate,
                                     weight_decay=0.0005)
    else:
        optimizer = optim.SGD(model.parameters(),
                              lr=opt.learning_rate,
                              momentum=opt.momentum,
                              weight_decay=opt.weight_decay)

    criterion = nn.CrossEntropyLoss()

    if torch.cuda.is_available():
        if opt.n_gpu > 1:
            model = nn.DataParallel(model)
        model = model.cuda()
        criterion = criterion.cuda()
        cudnn.benchmark = True

    # tensorboard
    logger = tb_logger.Logger(logdir=opt.tb_folder, flush_secs=2)

    # set cosine annealing scheduler
    if opt.cosine:
        eta_min = opt.learning_rate * (opt.lr_decay_rate**3)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, opt.epochs, eta_min, -1)

    # routine: supervised pre-training
    for epoch in range(1, opt.epochs + 1):

        if opt.cosine:
            # stepping the scheduler at the start of the epoch follows the
            # legacy (pre-1.1) PyTorch ordering that this code assumes
            scheduler.step()
        else:
            adjust_learning_rate(epoch, opt, optimizer)
        print("==> training...")

        time1 = time.time()
        train_acc, train_loss = train(epoch, train_loader, model, criterion,
                                      optimizer, opt)
        time2 = time.time()
        print('epoch {}, total time {:.2f}'.format(epoch, time2 - time1))

        logger.log_value('train_acc', train_acc, epoch)
        logger.log_value('train_loss', train_loss, epoch)

        test_acc, test_acc_top5, test_loss = validate(val_loader, model,
                                                      criterion, opt)

        logger.log_value('test_acc', test_acc, epoch)
        logger.log_value('test_acc_top5', test_acc_top5, epoch)
        logger.log_value('test_loss', test_loss, epoch)

        # regular saving
        if epoch % opt.save_freq == 0:
            print('==> Saving...')
            state = {
                'epoch': epoch,
                'model': (model.state_dict()
                          if opt.n_gpu <= 1 else model.module.state_dict()),
            }
            save_file = os.path.join(
                opt.save_folder, 'ckpt_epoch_{epoch}.pth'.format(epoch=epoch))
            torch.save(state, save_file)

    # few-shot evaluation
    start = time.time()
    val_acc, val_std = meta_test(model, meta_valloader, use_logit=True)
    val_time = time.time() - start
    print('val_acc: {:.4f}, val_std: {:.4f}, time: {:.1f}'.format(
        val_acc, val_std, val_time))

    start = time.time()
    val_acc_feat, val_std_feat = meta_test(model,
                                           meta_valloader,
                                           use_logit=False)
    val_time = time.time() - start
    print('val_acc_feat: {:.4f}, val_std: {:.4f}, time: {:.1f}'.format(
        val_acc_feat, val_std_feat, val_time))

    start = time.time()
    test_acc, test_std = meta_test(model, meta_testloader, use_logit=True)
    test_time = time.time() - start
    print('test_acc: {:.4f}, test_std: {:.4f}, time: {:.1f}'.format(
        test_acc, test_std, test_time))

    start = time.time()
    test_acc_feat, test_std_feat = meta_test(model,
                                             meta_testloader,
                                             use_logit=False)
    test_time = time.time() - start
    print('test_acc_feat: {:.4f}, test_std: {:.4f}, time: {:.1f}'.format(
        test_acc_feat, test_std_feat, test_time))

    logger.log_value('meta_val_acc', val_acc, opt.epochs)
    logger.log_value('meta_val_acc_feat', val_acc_feat, opt.epochs)
    logger.log_value('meta_test_acc', test_acc, opt.epochs)
    logger.log_value('meta_test_acc_feat', test_acc_feat, opt.epochs)

    # save the last model
    state = {
        'opt': opt,
        'model': (model.state_dict()
                  if opt.n_gpu <= 1 else model.module.state_dict()),
    }
    save_file = os.path.join(opt.save_folder, '{}_last.pth'.format(opt.model))
    torch.save(state, save_file)
Code example #8
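A variant of the pre-training routine with wandb logging, an extra CUB_200_2011 branch, and an optional language-supervised (lsl) auxiliary model built on a GloVe-initialized embedding.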
def main():

    opt = parse_option()

    if opt.name is not None:
        wandb.init(name=opt.name)
    else:
        wandb.init()
    wandb.config.update(opt)

    # dataloader
    train_partition = 'trainval' if opt.use_trainval else 'train'
    if opt.dataset == 'miniImageNet':
        train_trans, test_trans = transforms_options[opt.transform]
        train_loader = DataLoader(ImageNet(args=opt, partition=train_partition, transform=train_trans),
                                  batch_size=opt.batch_size, shuffle=True, drop_last=True,
                                  num_workers=opt.num_workers)
        val_loader = DataLoader(ImageNet(args=opt, partition='val', transform=test_trans),
                                batch_size=opt.batch_size // 2, shuffle=False, drop_last=False,
                                num_workers=opt.num_workers // 2)
        meta_testloader = DataLoader(MetaImageNet(args=opt, partition='test',
                                                  train_transform=train_trans,
                                                  test_transform=test_trans),
                                     batch_size=opt.test_batch_size, shuffle=False, drop_last=False,
                                     num_workers=opt.num_workers)
        meta_valloader = DataLoader(MetaImageNet(args=opt, partition='val',
                                                 train_transform=train_trans,
                                                 test_transform=test_trans),
                                    batch_size=opt.test_batch_size, shuffle=False, drop_last=False,
                                    num_workers=opt.num_workers)
        if opt.use_trainval:
            n_cls = 80
        else:
            n_cls = 64
    elif opt.dataset == 'tieredImageNet':
        train_trans, test_trans = transforms_options[opt.transform]
        train_loader = DataLoader(TieredImageNet(args=opt, partition=train_partition, transform=train_trans),
                                  batch_size=opt.batch_size, shuffle=True, drop_last=True,
                                  num_workers=opt.num_workers)
        val_loader = DataLoader(TieredImageNet(args=opt, partition='train_phase_val', transform=test_trans),
                                batch_size=opt.batch_size // 2, shuffle=False, drop_last=False,
                                num_workers=opt.num_workers // 2)
        meta_testloader = DataLoader(MetaTieredImageNet(args=opt, partition='test',
                                                        train_transform=train_trans,
                                                        test_transform=test_trans),
                                     batch_size=opt.test_batch_size, shuffle=False, drop_last=False,
                                     num_workers=opt.num_workers)
        meta_valloader = DataLoader(MetaTieredImageNet(args=opt, partition='val',
                                                       train_transform=train_trans,
                                                       test_transform=test_trans),
                                    batch_size=opt.test_batch_size, shuffle=False, drop_last=False,
                                    num_workers=opt.num_workers)
        if opt.use_trainval:
            n_cls = 448
        else:
            n_cls = 351
    elif opt.dataset == 'CIFAR-FS' or opt.dataset == 'FC100':
        train_trans, test_trans = transforms_options['D']

        train_loader = DataLoader(CIFAR100(args=opt, partition=train_partition, transform=train_trans),
                                  batch_size=opt.batch_size, shuffle=True, drop_last=True,
                                  num_workers=opt.num_workers)
        val_loader = DataLoader(CIFAR100(args=opt, partition='train', transform=test_trans),
                                batch_size=opt.batch_size // 2, shuffle=False, drop_last=False,
                                num_workers=opt.num_workers // 2)
        meta_testloader = DataLoader(MetaCIFAR100(args=opt, partition='test',
                                                  train_transform=train_trans,
                                                  test_transform=test_trans),
                                     batch_size=opt.test_batch_size, shuffle=False, drop_last=False,
                                     num_workers=opt.num_workers)
        meta_valloader = DataLoader(MetaCIFAR100(args=opt, partition='val',
                                                 train_transform=train_trans,
                                                 test_transform=test_trans),
                                    batch_size=opt.test_batch_size, shuffle=False, drop_last=False,
                                    num_workers=opt.num_workers)
        if opt.use_trainval:
            n_cls = 80
        else:
            if opt.dataset == 'CIFAR-FS':
                n_cls = 64
            elif opt.dataset == 'FC100':
                n_cls = 60
            else:
                raise NotImplementedError('dataset not supported: {}'.format(opt.dataset))
    elif opt.dataset == 'CUB_200_2011':
        train_trans, test_trans = transforms_options['C']

        vocab = lang_utils.load_vocab(opt.lang_dir) if opt.lsl else None
        devocab = {v: k for k, v in vocab.items()} if opt.lsl else None

        train_loader = DataLoader(CUB2011(args=opt, partition=train_partition, transform=train_trans,
                                          vocab=vocab),
                                  batch_size=opt.batch_size, shuffle=True, drop_last=True,
                                  num_workers=opt.num_workers)
        val_loader = DataLoader(CUB2011(args=opt, partition='val', transform=test_trans, vocab=vocab),
                                batch_size=opt.batch_size // 2, shuffle=False, drop_last=False,
                                num_workers=opt.num_workers // 2)
        meta_testloader = DataLoader(MetaCUB2011(args=opt, partition='test',
                                                  train_transform=train_trans,
                                                  test_transform=test_trans, vocab=vocab),
                                     batch_size=opt.test_batch_size, shuffle=False, drop_last=False,
                                     num_workers=opt.num_workers)
        meta_valloader = DataLoader(MetaCUB2011(args=opt, partition='val',
                                                 train_transform=train_trans,
                                                 test_transform=test_trans, vocab=vocab),
                                    batch_size=opt.test_batch_size, shuffle=False, drop_last=False,
                                    num_workers=opt.num_workers)
        if opt.use_trainval:
            # trainval split not supported yet for CUB (would be n_cls = 150)
            raise NotImplementedError(opt.dataset)
        else:
            n_cls = 100
    else:
        raise NotImplementedError(opt.dataset)

    print('Amount training data: {}'.format(len(train_loader.dataset)))
    print('Amount val data:      {}'.format(len(val_loader.dataset)))

    # model
    model = create_model(opt.model, n_cls, opt.dataset)

    # optimizer
    if opt.adam:
        optimizer = torch.optim.Adam(model.parameters(),
                                     lr=opt.learning_rate,
                                     weight_decay=0.0005)
    else:
        optimizer = optim.SGD(model.parameters(),
                              lr=opt.learning_rate,
                              momentum=opt.momentum,
                              weight_decay=opt.weight_decay)

    criterion = nn.CrossEntropyLoss()


    # lsl
    lang_model = None
    if opt.lsl:
        if opt.glove_init:
            vecs = lang_utils.glove_init(vocab, emb_size=opt.lang_emb_size)
        embedding_model = nn.Embedding(
            len(vocab), opt.lang_emb_size, _weight=vecs if opt.glove_init else None
        )
        if opt.freeze_emb:
            embedding_model.weight.requires_grad = False

        lang_input_size = n_cls if opt.use_logit else 640 # 640 for resnet12
        lang_model = TextProposal(
            embedding_model,
            input_size=lang_input_size,
            hidden_size=opt.lang_hidden_size,
            project_input=lang_input_size != opt.lang_hidden_size,
            rnn=opt.rnn_type,
            num_layers=opt.rnn_num_layers,
            dropout=opt.rnn_dropout,
            vocab=vocab,
            **lang_utils.get_special_indices(vocab)
        )


    if torch.cuda.is_available():
        if opt.n_gpu > 1:
            model = nn.DataParallel(model)
        model = model.cuda()
        criterion = criterion.cuda()
        cudnn.benchmark = True
        if opt.lsl:
            embedding_model = embedding_model.cuda()
            lang_model = lang_model.cuda()

    # tensorboard
    #logger = tb_logger.Logger(logdir=opt.tb_folder, flush_secs=2)

    # set cosine annealing scheduler
    if opt.cosine:
        eta_min = opt.learning_rate * (opt.lr_decay_rate ** 3)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, opt.epochs, eta_min, -1)

    # routine: supervised pre-training
    best_val_acc = 0
    for epoch in range(1, opt.epochs + 1):

        if opt.cosine:
            scheduler.step()
        else:
            adjust_learning_rate(epoch, opt, optimizer)
        print("==> training...")

        time1 = time.time()
        train_acc, train_loss, train_lang_loss = train(
            epoch, train_loader, model, criterion, optimizer, opt, lang_model,
            devocab=devocab if opt.lsl else None
        )
        time2 = time.time()
        print('epoch {}, total time {:.2f}'.format(epoch, time2 - time1))


        print("==> validating...")
        test_acc, test_acc_top5, test_loss, test_lang_loss = validate(
            val_loader, model, criterion, opt, lang_model
        )

        # wandb
        log_metrics = {
            'train_acc':     train_acc,
            'train_loss':    train_loss,
            'val_acc':      test_acc,
            'val_acc_top5': test_acc_top5,
            'val_loss':     test_loss
        }
        if opt.lsl:
            log_metrics['train_lang_loss'] = train_lang_loss
            log_metrics['val_lang_loss']  = test_lang_loss
        wandb.log(log_metrics, step=epoch)

        # # regular saving
        # if epoch % opt.save_freq == 0 and not opt.dryrun:
        #     print('==> Saving...')
        #     state = {
        #         'epoch': epoch,
        #         'model': model.state_dict() if opt.n_gpu <= 1 else model.module.state_dict(),
        #     }
        #     save_file = os.path.join(opt.save_folder, 'ckpt_epoch_{epoch}.pth'.format(epoch=epoch))
        #     torch.save(state, save_file)


        if test_acc > best_val_acc:
            best_val_acc = test_acc
            wandb.run.summary['best_val_acc'] = test_acc
            wandb.run.summary['best_val_acc_epoch'] = epoch

    # save the last model
    state = {
        'opt': opt,
        'model': model.state_dict() if opt.n_gpu <= 1 else model.module.state_dict(),
    }
    save_file = os.path.join(wandb.run.dir, '{}_last.pth'.format(opt.model))
    torch.save(state, save_file)

    # evaluate on test set
    print("==> testing...")
    start = time.time()
    (test_acc, test_std), (test_acc5, test_std5) = meta_test(model, meta_testloader)
    test_time = time.time() - start
    print('Using logit layer for embedding')
    print('test_acc: {:.4f}, test_std: {:.4f}, time: {:.1f}'.format(test_acc, test_std, test_time))
    print('test_acc top 5: {:.4f}, test_std top 5: {:.4f}, time: {:.1f}'.format(test_acc5, test_std5, test_time))

    start = time.time()
    (test_acc_feat, test_std_feat), (test_acc5_feat, test_std5_feat) = meta_test(
        model, meta_testloader, use_logit=False)
    test_time = time.time() - start
    print('Using layer before logits for embedding')
    print('test_acc_feat: {:.4f}, test_std: {:.4f}, time: {:.1f}'.format(
        test_acc_feat, test_std_feat, test_time))
    print('test_acc_feat top 5: {:.4f}, test_std top 5: {:.4f}, time: {:.1f}'.format(
        test_acc5_feat, test_std5_feat, test_time))

    wandb.run.summary['test_acc'] = test_acc
    wandb.run.summary['test_std'] = test_std
    wandb.run.summary['test_acc5'] = test_acc5
    wandb.run.summary['test_std5'] = test_std5
    wandb.run.summary['test_acc_feat'] = test_acc_feat
    wandb.run.summary['test_std_feat'] = test_std_feat
    wandb.run.summary['test_acc5_feat'] = test_acc5_feat
    wandb.run.summary['test_std5_feat'] = test_std5_feat
Code example #9
File: train_mrcl.py  Project: sebamenabar/rfs
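A meta-training loop in the MAML style: episodic train loaders, an optionally learnable inner-loop learning rate, gradient accumulation over opt.apply_every sub-batches, and periodic validation logged to TensorBoard and Comet.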
def main():
    opt = parse_option()

    print(pp.pformat(vars(opt)))

    train_partition = "trainval" if opt.use_trainval else "train"
    if opt.dataset == "miniImageNet":
        train_trans, test_trans = transforms_options[opt.transform]

        if opt.augment == "none":
            train_train_trans = train_test_trans = test_trans
        elif opt.augment == "all":
            train_train_trans = train_test_trans = train_trans
        elif opt.augment == "spt":
            train_train_trans = train_trans
            train_test_trans = test_trans
        elif opt.augment == "qry":
            train_train_trans = test_trans
            train_test_trans = train_trans

        print("spt trans")
        print(train_train_trans)
        print("qry trans")
        print(train_test_trans)

        sub_batch_size, rmd = divmod(opt.batch_size, opt.apply_every)
        assert rmd == 0
        print("Train sub batch-size:", sub_batch_size)

        meta_train_dataset = MetaImageNet(
            args=opt,
            partition="train",
            train_transform=train_train_trans,
            test_transform=train_test_trans,
            fname="miniImageNet_category_split_train_phase_%s.pickle",
            fix_seed=False,
            n_test_runs=10000000,  # big number to never stop
            new_labels=False,
        )
        meta_trainloader = DataLoader(
            meta_train_dataset,
            batch_size=sub_batch_size,
            shuffle=True,
            drop_last=True,
            num_workers=opt.num_workers,
            pin_memory=True,
        )
        meta_train_dataset_qry = MetaImageNet(
            args=opt,
            partition="train",
            train_transform=train_train_trans,
            test_transform=train_test_trans,
            fname="miniImageNet_category_split_train_phase_%s.pickle",
            fix_seed=False,
            n_test_runs=10000000,  # big number to never stop
            new_labels=False,
            n_ways=opt.n_qry_way,
            n_shots=opt.n_qry_shot,
            n_queries=0,
        )
        meta_trainloader_qry = DataLoader(
            meta_train_dataset_qry,
            batch_size=sub_batch_size,
            shuffle=True,
            drop_last=True,
            num_workers=opt.num_workers,
            pin_memory=True,
        )
        meta_val_dataset = MetaImageNet(
            args=opt,
            partition="val",
            train_transform=test_trans,
            test_transform=test_trans,
            fix_seed=False,
            n_test_runs=200,
            n_ways=5,
            n_shots=5,
            n_queries=15,
        )
        meta_valloader = DataLoader(
            meta_val_dataset,
            batch_size=opt.test_batch_size,
            shuffle=False,
            drop_last=False,
            num_workers=opt.num_workers,
            pin_memory=True,
        )
        val_loader = DataLoader(
            ImageNet(args=opt, partition="val", transform=test_trans),
            batch_size=opt.sup_val_batch_size,
            shuffle=False,
            drop_last=False,
            num_workers=opt.num_workers,
            pin_memory=True,
        )
        # if opt.use_trainval:
        #     n_cls = 80
        # else:
        #     n_cls = 64
        n_cls = len(meta_train_dataset.classes)

    print(n_cls)

    # x_spt, y_spt, x_qry, y_qry = next(iter(meta_trainloader))

    # x_spt2, y_spt2, x_qry2, y_qry2 = next(iter(meta_trainloader_qry))

    # print(x_spt, y_spt, x_qry, y_qry)
    # print(x_spt2, y_spt2, x_qry2, y_qry2)
    # print(x_spt.shape, y_spt.shape, x_qry.shape, y_qry.shape)
    # print(x_spt2.shape, y_spt2.shape, x_qry2.shape, y_qry2.shape)

    model = create_model(
        opt.model,
        n_cls,
        opt.dataset,
        opt.drop_rate,
        opt.dropblock,
        opt.track_stats,
        opt.initializer,
        opt.weight_norm,
        activation=opt.activation,
        normalization=opt.normalization,
    )

    print(model)

    criterion = nn.CrossEntropyLoss()

    if torch.cuda.is_available():
        print(torch.cuda.get_device_name())
        device = torch.device("cuda")
        # if opt.n_gpu > 1:
        #     model = nn.DataParallel(model)
        model = model.to(device)
        criterion = criterion.to(device)
        cudnn.benchmark = True
    else:
        device = torch.device("cpu")

    print("Learning rate")
    print(opt.learning_rate)
    print("Inner Learning rate")
    print(opt.inner_lr)
    if opt.learn_lr:
        print("Optimizing learning rate")
    # the inner-loop learning rate is optionally learned jointly with the model
    inner_lr = nn.Parameter(torch.tensor(opt.inner_lr),
                            requires_grad=opt.learn_lr)
    optimizer = torch.optim.Adam(
        (list(model.parameters()) + [inner_lr]) if opt.learn_lr
        else model.parameters(),
        lr=opt.learning_rate,
    )
    # classifier = model.classifier()
    inner_opt = torch.optim.SGD(
        model.classifier.parameters(),
        lr=opt.inner_lr,
    )
    logger = SummaryWriter(logdir=opt.tb_folder,
                           flush_secs=10,
                           comment=opt.model_name)
    comet_logger = Experiment(
        api_key=os.environ["COMET_API_KEY"],
        project_name=opt.comet_project_name,
        workspace=opt.comet_workspace,
        disabled=not opt.logcomet,
        auto_metric_logging=False,
    )
    comet_logger.set_name(opt.model_name)
    comet_logger.log_parameters(vars(opt))
    comet_logger.set_model_graph(str(model))

    if opt.cosine:
        eta_min = opt.learning_rate * opt.cosine_factor
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, opt.num_steps, eta_min, -1)

    # routine: supervised pre-training
    data_sampler = iter(meta_trainloader)
    data_sampler_qry = iter(meta_trainloader_qry)
    pbar = tqdm(
        range(1, opt.num_steps + 1),
        miniters=opt.print_freq,
        mininterval=3,
        maxinterval=30,
        ncols=0,
    )
    best_val_acc = 0.0
    for step in pbar:

        if not opt.cosine:
            adjust_learning_rate(step, opt, optimizer)
        # print("==> training...")

        time1 = time.time()

        foa = 0.0
        fol = 0.0
        ioa = 0.0
        iil = 0.0
        fil = 0.0
        iia = 0.0
        fia = 0.0
        for j in range(opt.apply_every):

            x_spt, y_spt, x_qry, y_qry = [
                t.to(device) for t in next(data_sampler)
            ]
            x_qry2, y_qry2, _, _ = [
                t.to(device) for t in next(data_sampler_qry)
            ]
            y_spt = y_spt.flatten(1)
            y_qry2 = y_qry2.flatten(1)

            x_qry = torch.cat((x_spt, x_qry, x_qry2), 1)
            y_qry = torch.cat((y_spt, y_qry, y_qry2), 1)

            if step == 1 and j == 0:
                print(x_spt.size(), y_spt.size(), x_qry.size(), y_qry.size())

            info = train_step(
                model,
                model.classifier,
                None,
                # inner_opt,
                inner_lr,
                x_spt,
                y_spt,
                x_qry,
                y_qry,
                reset_head=opt.reset_head,
                num_steps=opt.num_inner_steps,
            )

            _foa = info["foa"] / opt.batch_size
            _fol = info["fol"] / opt.batch_size
            _ioa = info["ioa"] / opt.batch_size
            _iil = info["iil"] / opt.batch_size
            _fil = info["fil"] / opt.batch_size
            _iia = info["iia"] / opt.batch_size
            _fia = info["fia"] / opt.batch_size

            _fol.backward()

            foa += _foa.detach()
            fol += _fol.detach()
            ioa += _ioa.detach()
            iil += _iil.detach()
            fil += _fil.detach()
            iia += _iia.detach()
            fia += _fia.detach()

        optimizer.step()
        optimizer.zero_grad()
        inner_lr.data.clamp_(min=0.001)

        if opt.cosine:
            scheduler.step()
        if (step == 1) or (step % opt.eval_freq == 0):
            val_info = test_run(
                iter(meta_valloader),
                model,
                model.classifier,
                torch.optim.SGD(model.classifier.parameters(),
                                lr=inner_lr.item()),
                num_inner_steps=opt.num_inner_steps_test,
                device=device,
            )
            val_acc_feat, val_std_feat = meta_test(
                model,
                meta_valloader,
                use_logit=False,
            )

            val_acc = val_info["outer"]["acc"].cpu()
            val_loss = val_info["outer"]["loss"].cpu()

            sup_acc, sup_acc_top5, sup_loss = validate(
                val_loader,
                model,
                criterion,
                print_freq=100000000,
            )
            sup_acc = sup_acc.item()
            sup_acc_top5 = sup_acc_top5.item()

            print(f"\nValidation step {step}")
            print(f"MAML 5-way-5-shot accuracy: {val_acc.item()}")
            print(f"LR 5-way-5-shot accuracy: {val_acc_feat}+-{val_std_feat}")
            print(
                f"Supervised accuracy: Acc@1: {sup_acc} Acc@5: {sup_acc_top5} Loss: {sup_loss}"
            )

            if val_acc_feat > best_val_acc:
                best_val_acc = val_acc_feat
                print(
                    f"New best validation accuracy {best_val_acc.item()} saving checkpoints\n"
                )
                # print(val_acc.item())

                torch.save(
                    {
                        "opt": opt,
                        "model": (model.state_dict() if opt.n_gpu <= 1
                                  else model.module.state_dict()),
                        "optimizer": optimizer.state_dict(),
                        "step": step,
                        "val_acc": val_acc,
                        "val_loss": val_loss,
                        "val_acc_lr": val_acc_feat,
                        "sup_acc": sup_acc,
                        "sup_acc_top5": sup_acc_top5,
                        "sup_loss": sup_loss,
                    },
                    os.path.join(opt.save_folder,
                                 "{}_best.pth".format(opt.model)),
                )

            comet_logger.log_metrics(
                dict(
                    fol=val_loss,
                    foa=val_acc,
                    acc_lr=val_acc_feat,
                    sup_acc=sup_acc,
                    sup_acc_top5=sup_acc_top5,
                    sup_loss=sup_loss,
                ),
                step=step,
                prefix="val",
            )

            logger.add_scalar("val_acc", val_acc, step)
            logger.add_scalar("val_loss", val_loss, step)
            logger.add_scalar("val_acc_lr", val_acc_feat, step)
            logger.add_scalar("sup_acc", sup_acc, step)
            logger.add_scalar("sup_acc_top5", sup_acc_top5, step)
            logger.add_scalar("sup_loss", sup_loss, step)

        if (step == 1 or step % opt.eval_freq == 0
                or step % opt.print_freq == 0):

            tfol = fol.cpu()
            tfoa = foa.cpu()
            tioa = ioa.cpu()
            tiil = iil.cpu()
            tfil = fil.cpu()
            tiia = iia.cpu()
            tfia = fia.cpu()

            comet_logger.log_metrics(
                dict(
                    fol=tfol,
                    foa=tfoa,
                    ioa=tioa,
                    iil=tiil,
                    fil=tfil,
                    iia=tiia,
                    fia=tfia,
                ),
                step=step,
                prefix="train",
            )

            logger.add_scalar("train_acc", tfoa.item(), step)
            logger.add_scalar("train_loss", tfol.item(), step)
            logger.add_scalar("train_ioa", tioa, step)
            logger.add_scalar("train_iil", tiil, step)
            logger.add_scalar("train_fil", tfil, step)
            logger.add_scalar("train_iia", tiia, step)
            logger.add_scalar("train_fia", tfia, step)

            pbar.set_postfix(
                # iol=f"{info['iol'].item():.2f}",
                fol=f"{tfol.item():.2f}",
                # ioa=f"{info['ioa'].item():.2f}",
                foa=f"{tfoa.item():.2f}",
                ioa=f"{tioa.item():.2f}",
                iia=f"{tiia.item():.2f}",
                fia=f"{tfia.item():.2f}",
                vl=f"{val_loss.item():.2f}",
                va=f"{val_acc.item():.2f}",
                valr=f"{val_acc_feat:.2f}",
                lr=f"{inner_lr.item():.4f}",
                vsa=f"{sup_acc:.2f}",
                # iil=f"{info['iil'].item():.2f}",
                # fil=f"{info['fil'].item():.2f}",
                # iia=f"{info['iia'].item():.2f}",
                # fia=f"{info['fia'].item():.2f}",
                # counter=info["counter"],
                refresh=True,
            )

    # save the last model
    state = {
        "opt": opt,
        "model": (model.state_dict()
                  if opt.n_gpu <= 1 else model.module.state_dict()),
        "optimizer": optimizer.state_dict(),
        "step": step,
    }
    save_file = os.path.join(opt.save_folder, "{}_last.pth".format(opt.model))
    torch.save(state, save_file)
Code example #10
File: train_distillation.py  Project: lilujunai/rfs
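A knowledge-distillation routine: it loads a teacher, builds a student, picks a distillation criterion (kd, contrast, attention, or hint), and trains the student with periodic validation and checkpointing.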
def main():
    best_acc = 0

    opt = parse_option()

    # tensorboard logger
    logger = tb_logger.Logger(logdir=opt.tb_folder, flush_secs=2)

    # dataloader
    train_partition = 'trainval' if opt.use_trainval else 'train'
    if opt.dataset == 'miniImageNet':
        train_trans, test_trans = transforms_options[opt.transform]
        if opt.distill in ['contrast']:
            train_set = ImageNet(args=opt,
                                 partition=train_partition,
                                 transform=train_trans,
                                 is_sample=True,
                                 k=opt.nce_k)
        else:
            train_set = ImageNet(args=opt,
                                 partition=train_partition,
                                 transform=train_trans)
        n_data = len(train_set)
        train_loader = DataLoader(train_set,
                                  batch_size=opt.batch_size,
                                  shuffle=True,
                                  drop_last=True,
                                  num_workers=opt.num_workers)
        val_loader = DataLoader(ImageNet(args=opt,
                                         partition='val',
                                         transform=test_trans),
                                batch_size=opt.batch_size // 2,
                                shuffle=False,
                                drop_last=False,
                                num_workers=opt.num_workers // 2)
        meta_testloader = DataLoader(MetaImageNet(args=opt,
                                                  partition='test',
                                                  train_transform=train_trans,
                                                  test_transform=test_trans),
                                     batch_size=opt.test_batch_size,
                                     shuffle=False,
                                     drop_last=False,
                                     num_workers=opt.num_workers)
        meta_valloader = DataLoader(MetaImageNet(args=opt,
                                                 partition='val',
                                                 train_transform=train_trans,
                                                 test_transform=test_trans),
                                    batch_size=opt.test_batch_size,
                                    shuffle=False,
                                    drop_last=False,
                                    num_workers=opt.num_workers)
        if opt.use_trainval:
            n_cls = 80
        else:
            n_cls = 64
    elif opt.dataset == 'tieredImageNet':
        train_trans, test_trans = transforms_options[opt.transform]
        if opt.distill in ['contrast']:
            train_set = TieredImageNet(args=opt,
                                       partition=train_partition,
                                       transform=train_trans,
                                       is_sample=True,
                                       k=opt.nce_k)
        else:
            train_set = TieredImageNet(args=opt,
                                       partition=train_partition,
                                       transform=train_trans)
        n_data = len(train_set)
        train_loader = DataLoader(train_set,
                                  batch_size=opt.batch_size,
                                  shuffle=True,
                                  drop_last=True,
                                  num_workers=opt.num_workers)
        val_loader = DataLoader(TieredImageNet(args=opt,
                                               partition='train_phase_val',
                                               transform=test_trans),
                                batch_size=opt.batch_size // 2,
                                shuffle=False,
                                drop_last=False,
                                num_workers=opt.num_workers // 2)
        meta_testloader = DataLoader(MetaTieredImageNet(
            args=opt,
            partition='test',
            train_transform=train_trans,
            test_transform=test_trans),
                                     batch_size=opt.test_batch_size,
                                     shuffle=False,
                                     drop_last=False,
                                     num_workers=opt.num_workers)
        meta_valloader = DataLoader(MetaTieredImageNet(
            args=opt,
            partition='val',
            train_transform=train_trans,
            test_transform=test_trans),
                                    batch_size=opt.test_batch_size,
                                    shuffle=False,
                                    drop_last=False,
                                    num_workers=opt.num_workers)
        if opt.use_trainval:
            n_cls = 448
        else:
            n_cls = 351
    elif opt.dataset == 'CIFAR-FS' or opt.dataset == 'FC100':
        train_trans, test_trans = transforms_options['D']
        if opt.distill in ['contrast']:
            train_set = CIFAR100(args=opt,
                                 partition=train_partition,
                                 transform=train_trans,
                                 is_sample=True,
                                 k=opt.nce_k)
        else:
            train_set = CIFAR100(args=opt,
                                 partition=train_partition,
                                 transform=train_trans)
        n_data = len(train_set)
        train_loader = DataLoader(train_set,
                                  batch_size=opt.batch_size,
                                  shuffle=True,
                                  drop_last=True,
                                  num_workers=opt.num_workers)
        val_loader = DataLoader(CIFAR100(args=opt,
                                         partition='train',
                                         transform=test_trans),
                                batch_size=opt.batch_size // 2,
                                shuffle=False,
                                drop_last=False,
                                num_workers=opt.num_workers // 2)
        meta_testloader = DataLoader(MetaCIFAR100(args=opt,
                                                  partition='test',
                                                  train_transform=train_trans,
                                                  test_transform=test_trans),
                                     batch_size=opt.test_batch_size,
                                     shuffle=False,
                                     drop_last=False,
                                     num_workers=opt.num_workers)
        meta_valloader = DataLoader(MetaCIFAR100(args=opt,
                                                 partition='val',
                                                 train_transform=train_trans,
                                                 test_transform=test_trans),
                                    batch_size=opt.test_batch_size,
                                    shuffle=False,
                                    drop_last=False,
                                    num_workers=opt.num_workers)
        if opt.use_trainval:
            n_cls = 80
        else:
            if opt.dataset == 'CIFAR-FS':
                n_cls = 64
            elif opt.dataset == 'FC100':
                n_cls = 60
            else:
                raise NotImplementedError('dataset not supported: {}'.format(
                    opt.dataset))
    else:
        raise NotImplementedError(opt.dataset)

    # model
    model_t = load_teacher(opt.path_t, n_cls, opt.dataset)
    model_s = create_model(opt.model_s, n_cls, opt.dataset)

    data = torch.randn(2, 3, 84, 84)
    model_t.eval()
    model_s.eval()
    feat_t, _ = model_t(data, is_feat=True)
    feat_s, _ = model_s(data, is_feat=True)
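    # feat_t / feat_s are lists of intermediate features; their last entries
    # give the channel sizes used below to build the 'contrast' embed heads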

    module_list = nn.ModuleList([])
    module_list.append(model_s)
    trainable_list = nn.ModuleList([])
    trainable_list.append(model_s)

    criterion_cls = nn.CrossEntropyLoss()
    criterion_div = DistillKL(opt.kd_T)
    if opt.distill == 'kd':
        criterion_kd = DistillKL(opt.kd_T)
    elif opt.distill == 'contrast':
        criterion_kd = NCELoss(opt, n_data)
        embed_s = Embed(feat_s[-1].shape[1], opt.feat_dim)
        embed_t = Embed(feat_t[-1].shape[1], opt.feat_dim)
        module_list.append(embed_s)
        module_list.append(embed_t)
        trainable_list.append(embed_s)
        trainable_list.append(embed_t)
    elif opt.distill == 'attention':
        criterion_kd = Attention()
    elif opt.distill == 'hint':
        criterion_kd = HintLoss()
    else:
        raise NotImplementedError(opt.distill)

    criterion_list = nn.ModuleList([])
    criterion_list.append(criterion_cls)  # classification loss
    criterion_list.append(
        criterion_div)  # KL divergence loss, original knowledge distillation
    criterion_list.append(criterion_kd)  # other knowledge distillation loss

    # optimizer
    optimizer = optim.SGD(trainable_list.parameters(),
                          lr=opt.learning_rate,
                          momentum=opt.momentum,
                          weight_decay=opt.weight_decay)

    # append teacher after optimizer to avoid weight_decay
    module_list.append(model_t)

    if torch.cuda.is_available():
        module_list.cuda()
        criterion_list.cuda()
        cudnn.benchmark = True

    # validate teacher accuracy
    teacher_acc, _, _ = validate(val_loader, model_t, criterion_cls, opt)
    print('teacher accuracy: ', teacher_acc)

    # set cosine annealing scheduler
    if opt.cosine:
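        # anneal the LR smoothly down to lr * decay_rate**3, roughly the
        # floor the stepwise schedule would reach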
        eta_min = opt.learning_rate * (opt.lr_decay_rate**3)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, opt.epochs, eta_min, -1)

    # routine: supervised model distillation
    for epoch in range(1, opt.epochs + 1):

        if opt.cosine:
            scheduler.step()
        else:
            adjust_learning_rate(epoch, opt, optimizer)
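        # note: PyTorch >= 1.1 expects scheduler.step() after the optimizer
        # updates; stepping at the top of the epoch, as here, shifts the
        # cosine schedule by one epoch on newer versions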
        print("==> training...")

        time1 = time.time()
        train_acc, train_loss = train(epoch, train_loader, module_list,
                                      criterion_list, optimizer, opt)
        time2 = time.time()
        print('epoch {}, total time {:.2f}'.format(epoch, time2 - time1))

        logger.log_value('train_acc', train_acc, epoch)
        logger.log_value('train_loss', train_loss, epoch)

        test_acc, test_acc_top5, test_loss = validate(val_loader, model_s,
                                                      criterion_cls, opt)

        logger.log_value('test_acc', test_acc, epoch)
        logger.log_value('test_acc_top5', test_acc_top5, epoch)
        logger.log_value('test_loss', test_loss, epoch)

        # regular saving
        if epoch % opt.save_freq == 0:
            print('==> Saving...')
            state = {
                'epoch': epoch,
                'model': model_s.state_dict(),
            }
            save_file = os.path.join(
                opt.save_folder, 'ckpt_epoch_{epoch}.pth'.format(epoch=epoch))
            torch.save(state, save_file)

    # few-shot evaluation
    start = time.time()
    val_acc, val_std = meta_test(model_s, meta_valloader, use_logit=True)
    val_time = time.time() - start
    print('val_acc: {:.4f}, val_std: {:.4f}, time: {:.1f}'.format(
        val_acc, val_std, val_time))

    start = time.time()
    val_acc_feat, val_std_feat = meta_test(model_s,
                                           meta_valloader,
                                           use_logit=False)
    val_time = time.time() - start
    print('val_acc_feat: {:.4f}, val_std: {:.4f}, time: {:.1f}'.format(
        val_acc_feat, val_std_feat, val_time))

    start = time.time()
    test_acc, test_std = meta_test(model_s, meta_testloader, use_logit=True)
    test_time = time.time() - start
    print('test_acc: {:.4f}, test_std: {:.4f}, time: {:.1f}'.format(
        test_acc, test_std, test_time))

    start = time.time()
    test_acc_feat, test_std_feat = meta_test(model_s,
                                             meta_testloader,
                                             use_logit=False)
    test_time = time.time() - start
    print('test_acc_feat: {:.4f}, test_std: {:.4f}, time: {:.1f}'.format(
        test_acc_feat, test_std_feat, test_time))

    logger.log_value('meta_val_acc', val_acc, opt.epochs)
    logger.log_value('meta_val_acc_feat', val_acc_feat, opt.epochs)
    logger.log_value('meta_test_acc', test_acc, opt.epochs)
    logger.log_value('meta_test_acc_feat', test_acc_feat, opt.epochs)

    # save the last model
    state = {
        'opt': opt,
        'model': model_s.state_dict(),
    }
    save_file = os.path.join(opt.save_folder,
                             '{}_last.pth'.format(opt.model_s))
    torch.save(state, save_file)
Code example #11
File: eval_fewshot.py Project: ChengJiacheng/rfs
    # start = time.time()
    # val_acc_feat, val_std_feat = meta_test(model, meta_valloader, use_logit=False, classifier=opt.classifier)
    # val_time = time.time() - start
    # print('val_acc_feat: {:.4f}, val_std: {:.4f}, time: {:.1f}'.format(val_acc_feat, val_std_feat, val_time))

    # start = time.time()
    # test_acc, test_std = meta_test(model, meta_testloader, classifier=opt.classifier, model_list=model_list)
    # test_time = time.time() - start
    # print('test_acc: {:.4f}, test_std: {:.4f}, time: {:.1f}'.format(test_acc, test_std, test_time))

    # start = time.time()
    # test_acc_feat, test_std_feat = meta_test(model, meta_testloader, use_logit=False, classifier=opt.classifier, model_list=model_list)
    # test_time = time.time() - start
    # print('test_acc_feat: {:.4f}, test_std: {:.4f}, time: {:.1f}'.format(test_acc_feat, test_std_feat, test_time))

    # start = time.time()
    # test_acc, test_std = meta_test(model, meta_testloader, is_norm=False, classifier=opt.classifier, model_list=model_list)
    # test_time = time.time() - start
    # print('test_acc (no normalization): {:.4f}, test_std: {:.4f}, time: {:.1f}'.format(test_acc, test_std, test_time))

    start = time.time()
    test_acc_feat, test_std_feat = meta_test(model,
                                             meta_testloader,
                                             use_logit=False,
                                             is_norm=False,
                                             classifier=opt.classifier)
    test_time = time.time() - start
    print(
        'test_acc_feat (no normalization): {:.4f}, test_std: {:.4f}, time: {:.1f}'
        .format(test_acc_feat, test_std_feat, test_time))
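
A note on the flags exercised above: in these listings use_logit switches the
episode representation between classifier logits and backbone features, and
is_norm toggles feature normalization. The helper below is a minimal sketch,
not part of any project listed here: it assumes meta_test is importable from
eval.meta_eval (as in the SKD listing) with the keyword interface shown above,
and the name sweep_meta_test is hypothetical.

import itertools
import time

from eval.meta_eval import meta_test  # assumed import path, as used above


def sweep_meta_test(model, loader, classifier='LR'):
    # run every (use_logit, is_norm) combination once and collect the results
    results = {}
    for use_logit, is_norm in itertools.product([True, False], repeat=2):
        start = time.time()
        acc, std = meta_test(model, loader, use_logit=use_logit,
                             is_norm=is_norm, classifier=classifier)
        elapsed = time.time() - start
        print('use_logit={}, is_norm={}: {:.4f} +- {:.4f}, time: {:.1f}'.format(
            use_logit, is_norm, acc, std, elapsed))
        results[(use_logit, is_norm)] = (acc, std)
    return results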
Code example #12
File: util.py Project: yyht/SKD
def generate_final_report(model, opt, wandb):
    from eval.meta_eval import meta_test

    opt.n_shots = 1
    train_loader, val_loader, meta_testloader, meta_valloader, _ = get_dataloaders(
        opt)

    # validate
    meta_val_acc, meta_val_std = meta_test(model, meta_valloader)

    meta_val_acc_feat, meta_val_std_feat = meta_test(model,
                                                     meta_valloader,
                                                     use_logit=False)

    # evaluate
    meta_test_acc, meta_test_std = meta_test(model, meta_testloader)

    meta_test_acc_feat, meta_test_std_feat = meta_test(model,
                                                       meta_testloader,
                                                       use_logit=False)

    print('Meta Val Acc : {:.4f}, Meta Val std: {:.4f}'.format(
        meta_val_acc, meta_val_std))
    print('Meta Val Acc (feat): {:.4f}, Meta Val std (feat): {:.4f}'.format(
        meta_val_acc_feat, meta_val_std_feat))
    print('Meta Test Acc: {:.4f}, Meta Test std: {:.4f}'.format(
        meta_test_acc, meta_test_std))
    print('Meta Test Acc (feat): {:.4f}, Meta Test std (feat): {:.4f}'.format(
        meta_test_acc_feat, meta_test_std_feat))

    wandb.log({
        'Final Meta Test Acc @1': meta_test_acc,
        'Final Meta Test std @1': meta_test_std,
        'Final Meta Test Acc  (feat) @1': meta_test_acc_feat,
        'Final Meta Test std  (feat) @1': meta_test_std_feat,
        'Final Meta Val Acc @1': meta_val_acc,
        'Final Meta Val std @1': meta_val_std,
        'Final Meta Val Acc   (feat) @1': meta_val_acc_feat,
        'Final Meta Val std   (feat) @1': meta_val_std_feat
    })

    opt.n_shots = 5
    train_loader, val_loader, meta_testloader, meta_valloader, _ = get_dataloaders(
        opt)

    # validate
    meta_val_acc, meta_val_std = meta_test(model, meta_valloader)

    meta_val_acc_feat, meta_val_std_feat = meta_test(model,
                                                     meta_valloader,
                                                     use_logit=False)

    # evaluate
    meta_test_acc, meta_test_std = meta_test(model, meta_testloader)

    meta_test_acc_feat, meta_test_std_feat = meta_test(model,
                                                       meta_testloader,
                                                       use_logit=False)

    print('Meta Val Acc : {:.4f}, Meta Val std: {:.4f}'.format(
        meta_val_acc, meta_val_std))
    print('Meta Val Acc (feat): {:.4f}, Meta Val std (feat): {:.4f}'.format(
        meta_val_acc_feat, meta_val_std_feat))
    print('Meta Test Acc: {:.4f}, Meta Test std: {:.4f}'.format(
        meta_test_acc, meta_test_std))
    print('Meta Test Acc (feat): {:.4f}, Meta Test std (feat): {:.4f}'.format(
        meta_test_acc_feat, meta_test_std_feat))

    wandb.log({
        'Final Meta Test Acc @5': meta_test_acc,
        'Final Meta Test std @5': meta_test_std,
        'Final Meta Test Acc  (feat) @5': meta_test_acc_feat,
        'Final Meta Test std  (feat) @5': meta_test_std_feat,
        'Final Meta Val Acc @5': meta_val_acc,
        'Final Meta Val std @5': meta_val_std,
        'Final Meta Val Acc   (feat) @5': meta_val_acc_feat,
        'Final Meta Val std   (feat) @5': meta_val_std_feat
    })
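
generate_final_report above repeats an identical validate/evaluate block for
the 1-shot and 5-shot settings. A compact variant is sketched below under the
same assumptions (the get_dataloaders and meta_test interfaces shown in these
listings); the name generate_final_report_compact is hypothetical, and the
wandb keys are normalized rather than copied with their original spacing.

def generate_final_report_compact(model, opt, wandb):
    from eval.meta_eval import meta_test

    for n_shots in (1, 5):
        opt.n_shots = n_shots
        _, _, meta_testloader, meta_valloader, _ = get_dataloaders(opt)
        for split, loader in (('Val', meta_valloader),
                              ('Test', meta_testloader)):
            acc, std = meta_test(model, loader)
            acc_feat, std_feat = meta_test(model, loader, use_logit=False)
            print('Meta {} Acc: {:.4f}, std: {:.4f}'.format(split, acc, std))
            print('Meta {} Acc (feat): {:.4f}, std (feat): {:.4f}'.format(
                split, acc_feat, std_feat))
            wandb.log({
                'Final Meta {} Acc @{}'.format(split, n_shots): acc,
                'Final Meta {} std @{}'.format(split, n_shots): std,
                'Final Meta {} Acc (feat) @{}'.format(split, n_shots):
                    acc_feat,
                'Final Meta {} std (feat) @{}'.format(split, n_shots):
                    std_feat,
            })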
Code example #13
def main_worker(gpu, ngpus_per_node, args):
    args.gpu = gpu
    logger = get_logger(name='log', log_dir='.')

    # suppress printing if not master
    if args.multiprocessing_distributed and args.gpu != 0:

        def print_pass(*args):
            pass

        builtins.print = print_pass

    if args.gpu is not None:
        logger.info("Use GPU: {} for training".format(args.gpu))

    if args.distributed:
        if args.dist_url == "env://" and args.rank == -1:
            args.rank = int(os.environ["RANK"])
        if args.multiprocessing_distributed:
            # For multiprocessing distributed training, rank needs to be the
            # global rank among all the processes
            args.rank = args.rank * ngpus_per_node + gpu
        dist.init_process_group(backend=args.dist_backend,
                                init_method=args.dist_url,
                                world_size=args.world_size,
                                rank=args.rank)
    if args.dataset == 'tiered':
        train_dataset = TieredImageNet(
            root=args.data,
            partition='train',
            mode=args.mode,
            transform=ancor.loader.TwoCropsTransform(
                transforms.Compose(AUGS[f"train_{args.dataset}"])))
        val_dataset = MetaTieredImageNet(
            args=Box(data_root=args.data,
                     mode='fine',
                     n_ways=5,
                     n_shots=1,
                     n_queries=15,
                     n_test_runs=200,
                     n_aug_support_samples=5),
            partition='validation',
            train_transform=transforms.Compose(
                AUGS[f"meta_test_{args.dataset}"]),
            test_transform=transforms.Compose(AUGS[f"test_{args.dataset}"]))
        fg_val_dataset = MetaFGTieredImageNet(
            args=Box(data_root=args.data,
                     mode='fine',
                     n_ways=5,
                     n_shots=1,
                     n_queries=15,
                     n_test_runs=200,
                     n_aug_support_samples=5),
            partition='validation',
            train_transform=transforms.Compose(
                AUGS[f"meta_test_{args.dataset}"]),
            test_transform=transforms.Compose(AUGS[f"test_{args.dataset}"]))
    elif args.dataset == 'cifar100':
        train_transforms = transforms.Compose(
            AUGS[f"train_{args.dataset}"][1:])
        train_dataset = Cifar100(
            root=args.data,
            train=True,
            mode=args.mode,
            transform=ancor.loader.TwoCropsTransform(train_transforms))
        val_dataset = MetaCifar100(args=Box(
            data_root=args.data,
            mode='fine',
            n_ways=5,
            n_shots=1,
            n_queries=15,
            n_test_runs=200,
            n_aug_support_samples=5,
        ),
                                   partition='test',
                                   train_transform=transforms.Compose(
                                       AUGS[f"meta_test_{args.dataset}"]),
                                   test_transform=transforms.Compose(
                                       AUGS[f"test_{args.dataset}"]))
        fg_val_dataset = MetaFGCifar100(args=Box(
            data_root=args.data,
            mode='fine',
            n_ways=5,
            n_shots=1,
            n_queries=15,
            n_test_runs=200,
            n_aug_support_samples=5,
        ),
                                        partition='test',
                                        train_transform=transforms.Compose(
                                            AUGS[f"meta_test_{args.dataset}"]),
                                        test_transform=transforms.Compose(
                                            AUGS[f"test_{args.dataset}"]))
    elif args.dataset in ['living17', 'entity13', 'nonliving26', 'entity30']:
        breeds_factory = BREEDSFactory(
            info_dir=os.path.join(args.data, "BREEDS"),
            data_dir=os.path.join(args.data, "Data", "CLS-LOC"))
        train_dataset = breeds_factory.get_breeds(
            ds_name=args.dataset,
            partition='train',
            mode=args.mode,
            transforms=ancor.loader.TwoCropsTransform(
                transforms.Compose(AUGS[f"train_{args.dataset}"])),
            split=args.split)
        val_dataset = MetaDataset(
            args=Box(
                n_ways=5,
                n_shots=1,
                n_queries=15,
                n_test_runs=200,
                n_aug_support_samples=5,
            ),
            dataset=breeds_factory.get_breeds(ds_name=args.dataset,
                                              partition='val',
                                              mode='fine',
                                              transforms=None,
                                              split=args.split),
            train_transform=transforms.Compose(
                AUGS[f"meta_test_{args.dataset}"]),
            test_transform=transforms.Compose(AUGS[f"test_{args.dataset}"]))
        fg_val_dataset = MetaDataset(
            args=Box(
                n_ways=5,
                n_shots=1,
                n_queries=15,
                n_test_runs=200,
                n_aug_support_samples=5,
            ),
            dataset=breeds_factory.get_breeds(ds_name=args.dataset,
                                              partition='val',
                                              mode='fine',
                                              transforms=None,
                                              split=args.split),
            train_transform=transforms.Compose(
                AUGS[f"meta_test_{args.dataset}"]),
            test_transform=transforms.Compose(AUGS[f"test_{args.dataset}"]),
            fg=True)
    else:
        raise NotImplementedError
    # create model
    model, criterions = ANCORModelGenerator().generate_ancor_model(
        arch=args.arch,
        head_type=args.head,
        dim=args.cst_dim,
        K=args.queue_k,
        m=args.encoder_m,
        T=args.cst_t,
        mlp=args.mlp,
        num_classes=train_dataset.num_classes,
        queue_type=args.queue,
        metric=args.metric,
        calc_types=args.calc_types,
        loss_types=args.loss_types,
        gpu=args.gpu)
    log(args.rank, logger, "loaded model")

    if args.distributed:
        # For multiprocessing distributed, DistributedDataParallel constructor
        # should always set the single device scope, otherwise,
        # DistributedDataParallel will use all available devices.
        if args.gpu is not None:
            torch.cuda.set_device(args.gpu)
            model.cuda(args.gpu)
            # When using a single GPU per process and per
            # DistributedDataParallel, we need to divide the batch size
            # ourselves based on the total number of GPUs we have
            args.batch_size = int(args.batch_size / ngpus_per_node)
            args.workers = int(
                (args.workers + ngpus_per_node - 1) / ngpus_per_node)
            model = torch.nn.parallel.DistributedDataParallel(
                model, device_ids=[args.gpu], find_unused_parameters=True)
        else:
            model.cuda()
            # DistributedDataParallel will divide and allocate batch_size to all
            # available GPUs if device_ids are not set
            model = torch.nn.parallel.DistributedDataParallel(
                model, find_unused_parameters=True)
    elif args.gpu is not None:
        torch.cuda.set_device(args.gpu)
        model = model.cuda(args.gpu)
        # comment out the following line for debugging
        # raise NotImplementedError("Only DistributedDataParallel is supported.")
    else:
        # AllGather implementation (batch shuffle, queue update, etc.) in
        # this code only supports DistributedDataParallel.
        raise NotImplementedError("Only DistributedDataParallel is supported.")

    # define loss function (criterion) and optimizer
    optimizer = torch.optim.SGD(model.parameters(),
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)
    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            log(args.rank, logger,
                "=> loading checkpoint '{}'".format(args.resume))
            if args.gpu is None:
                checkpoint = torch.load(args.resume)
            else:
                # Map model to be loaded to specified single gpu.
                loc = 'cuda:{}'.format(args.gpu)
                checkpoint = torch.load(args.resume, map_location=loc)
            args.start_epoch = checkpoint['epoch']
            msg = model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            if 'best_accs' in checkpoint:
                best_accs = checkpoint['best_accs']
            else:
                best_accs = [0.]
                log(
                    args.rank, logger,
                    " WARNING: BACKWARDS COMPATIBLE RESUME. NO BEST MODEL CHECKPOINT"
                )
            log(
                args.rank, logger,
                "=> loaded checkpoint '{}' (epoch {})".format(
                    args.resume, checkpoint['epoch']))
        else:
            log(args.rank, logger,
                "=> no checkpoint found at '{}'".format(args.resume))
            raise ValueError()
    else:
        best_accs = [0.]

    cudnn.benchmark = True

    # Data loading code

    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_dataset)
    else:
        train_sampler = None

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=(train_sampler is None),
                                               num_workers=args.workers,
                                               pin_memory=True,
                                               sampler=train_sampler,
                                               drop_last=True)
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=1,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)
    fg_val_loader = torch.utils.data.DataLoader(fg_val_dataset,
                                                batch_size=1,
                                                shuffle=False,
                                                num_workers=args.workers,
                                                pin_memory=True)
    for epoch in range(args.start_epoch, args.epochs):
        best_flag = False
        if args.distributed:
            train_sampler.set_epoch(epoch)
        adjust_learning_rate(optimizer, epoch, args)

        # train for one epoch
        train(train_loader, model, criterions, optimizer, epoch, logger, args)
        if args.rank % ngpus_per_node == 0:
            if (epoch + 1) % args.save_freq == 0 and val_dataset is not None:
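                # few-shot validation: evaluate the query encoder's base
                # features with a cosine classifier on held-out episodes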
                val_acc, val_std = meta_test(model.module.encoder_q,
                                             val_loader,
                                             only_base=True,
                                             is_norm=True,
                                             classifier="Cosine")
                if best_accs[-1] < val_acc:
                    best_accs.append(val_acc)
                    with open("best_accs.log", 'a') as f:
                        f.write(
                            f"EPOCH {epoch}: Validation Accuracy: {round(val_acc * 100, 2)}+-{round(val_std * 100, 2)}\n"
                        )
                    best_flag = True
                log(
                    args.rank, logger,
                    f"EPOCH {epoch}: Validation Accuracy: {round(val_acc * 100, 2)}+-{round(val_std * 100, 2)}"
                )
            if (epoch +
                    1) % args.save_freq == 0 and fg_val_dataset is not None:
                val_acc, val_std = meta_test(model.module.encoder_q,
                                             fg_val_loader,
                                             only_base=True,
                                             is_norm=True,
                                             classifier="Cosine")
                log(
                    args.rank, logger,
                    f"EPOCH {epoch}: Validation FG - Accuracy: {round(val_acc * 100, 2)}+-{round(val_std * 100, 2)}"
                )

        if not args.multiprocessing_distributed or (
                args.multiprocessing_distributed
                and args.rank % ngpus_per_node == 0):
            if (epoch + 1) % args.save_freq == 0 or best_flag:
                save_checkpoint(
                    {
                        'epoch': epoch + 1,
                        'arch': args.arch,
                        'state_dict': model.state_dict(),
                        'optimizer': optimizer.state_dict(),
                        'best_accs': best_accs
                    },
                    is_best=best_flag,
                    filename='checkpoint_{:04d}.pth.tar'.format(epoch))
    if args.rank % ngpus_per_node == 0:
        remove_excess_epochs(args.keep_epochs)
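
main_worker above is written to run once per GPU. A minimal launcher sketch in
the style of the standard PyTorch distributed examples follows, assuming an
argparse namespace that carries the fields referenced above (world_size, gpu,
multiprocessing_distributed); the launch name is hypothetical.

import torch
import torch.multiprocessing as mp


def launch(args):
    ngpus_per_node = torch.cuda.device_count()
    if args.multiprocessing_distributed:
        # one process per local GPU; mp.spawn passes the process index as the
        # first argument, matching main_worker(gpu, ngpus_per_node, args)
        args.world_size = ngpus_per_node * args.world_size
        mp.spawn(main_worker,
                 nprocs=ngpus_per_node,
                 args=(ngpus_per_node, args))
    else:
        # single-process path; note the listing above still requires
        # args.distributed for its DistributedDataParallel branches
        main_worker(args.gpu, ngpus_per_node, args)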
Code example #14
File: train_distillation.py Project: yyht/SKD
def main():
    best_acc = 0

    opt = parse_option()
    wandb.init(project=opt.model_path.split("/")[-1], tags=opt.tags)
    wandb.config.update(opt)
    wandb.save('*.py')
    wandb.run.save()

    # dataloader
    train_loader, val_loader, meta_testloader, meta_valloader, n_cls = get_dataloaders(
        opt)

    # model
    model_t = []
    if ("," in opt.path_t):
        for path in opt.path_t.split(","):
            model_t.append(load_teacher(path, opt.model_t, n_cls, opt.dataset))
    else:
        model_t.append(
            load_teacher(opt.path_t, opt.model_t, n_cls, opt.dataset))


    # model_s = create_model(opt.model_s, n_cls, opt.dataset, dropout=0.4)
    # model_s = Wrapper(model_, opt)
    model_s = copy.deepcopy(model_t[0])
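    # self-distillation: the student starts as a copy of the first teacher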

    criterion_cls = nn.CrossEntropyLoss()
    criterion_div = DistillKL(opt.kd_T)
    criterion_kd = DistillKL(opt.kd_T)

    optimizer = optim.SGD(model_s.parameters(),
                          lr=opt.learning_rate,
                          momentum=opt.momentum,
                          weight_decay=opt.weight_decay)

    if torch.cuda.is_available():
        for m in model_t:
            m.cuda()
        model_s.cuda()
        criterion_cls = criterion_cls.cuda()
        criterion_div = criterion_div.cuda()
        criterion_kd = criterion_kd.cuda()
        cudnn.benchmark = True

    meta_test_acc = 0
    meta_test_std = 0
    # set cosine annealing scheduler for the scheduler.step() call below
    if opt.cosine:
        eta_min = opt.learning_rate * (opt.lr_decay_rate**3)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, opt.epochs, eta_min, -1)

    # routine: supervised model distillation
    for epoch in range(1, opt.epochs + 1):

        if opt.cosine:
            scheduler.step()
        else:
            adjust_learning_rate(epoch, opt, optimizer)
        print("==> training...")

        time1 = time.time()
        train_acc, train_loss = train(epoch, train_loader, model_s, model_t,
                                      criterion_cls, criterion_div,
                                      criterion_kd, optimizer, opt)
        time2 = time.time()
        print('epoch {}, total time {:.2f}'.format(epoch, time2 - time1))

        val_acc = 0
        val_loss = 0
        meta_val_acc = 0
        meta_val_std = 0
        #         val_acc, val_acc_top5, val_loss = validate(val_loader, model_s, criterion_cls, opt)

        #         #evaluate
        #         start = time.time()
        #         meta_val_acc, meta_val_std = meta_test(model_s, meta_valloader)
        #         test_time = time.time() - start
        #         print('Meta Val Acc: {:.4f}, Meta Val std: {:.4f}, Time: {:.1f}'.format(meta_val_acc, meta_val_std, test_time))

        # evaluate

        start = time.time()
        meta_test_acc, meta_test_std = meta_test(model_s,
                                                 meta_testloader,
                                                 use_logit=False)
        test_time = time.time() - start
        print('Meta Test Acc: {:.4f}, Meta Test std: {:.4f}, Time: {:.1f}'.
              format(meta_test_acc, meta_test_std, test_time))

        # regular saving
        if epoch % opt.save_freq == 0 or epoch == opt.epochs:
            print('==> Saving...')
            state = {
                'epoch': epoch,
                'model': model_s.state_dict(),
            }
            save_file = os.path.join(opt.save_folder,
                                     'model_' + str(wandb.run.name) + '.pth')
            torch.save(state, save_file)

            # wandb saving
            torch.save(state, os.path.join(wandb.run.dir, "model.pth"))

        wandb.log({
            'epoch': epoch,
            'Train Acc': train_acc,
            'Train Loss': train_loss,
            'Val Acc': val_acc,
            'Val Loss': val_loss,
            'Meta Test Acc': meta_test_acc,
            'Meta Test std': meta_test_std,
            'Meta Val Acc': meta_val_acc,
            'Meta Val std': meta_val_std
        })

    # final report
    generate_final_report(model_s, opt, wandb)

    # remove the output.log file
    output_log_file = os.path.join(wandb.run.dir, "output.log")
    if os.path.isfile(output_log_file):
        os.remove(output_log_file)
    else:  # show an error
        print("Error: %s file not found" % output_log_file)