Example 1
def local_pretrained_cifar10_lossvar():
    # laptop
    args['data_dir'] = '../data/'
    args['loss_func'] = F.cross_entropy
    # args['learning_func_name'] = 'loss_var'
    args['learning_func_name'] = 'grad_var'
    args['stats_samplesize'] = 3
    args['num_eigens_hessian_approx'] = 1
    args['lr'] = 1e-9
    args['log_interval'] = 1
    args['batch_size'] = 128

    train_loader, test_loader = data.cifar10(args)

    batches = list(lib.iter_sample_fast(train_loader, args['stats_samplesize']))
    batch_loader = dataloader.get_subset_batch_loader(batches, args)
    args['subset_batches'] = True 
    print(f'\nTraining only on {args["stats_samplesize"]} batches of size {args["batch_size"]}!\n')

    # https://github.com/huyvnphan/PyTorch_CIFAR10
    model = resnet18(pretrained=True)

    trainer.main(model, batch_loader, test_loader, args)
    return
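Example 1 relies on a `lib.iter_sample_fast` helper to pull a fixed number of batches out of the loader without materializing it. The helper itself is not shown; the sketch below is an assumed reservoir-sampling implementation with that name and signature, not necessarily the original.

import random

def iter_sample_fast(iterable, samplesize):
    # Assumed implementation of lib.iter_sample_fast: reservoir sampling,
    # which keeps `samplesize` items from a single pass over an iterator.
    results = []
    iterator = iter(iterable)
    for _ in range(samplesize):
        results.append(next(iterator))  # fill the reservoir first
    for i, item in enumerate(iterator, samplesize):
        j = random.randint(0, i)
        if j < samplesize:
            results[j] = item  # replace with probability samplesize/(i+1)
    return results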
Example 2
def cifar10():
    nb_per_class = 100
    (X_train, Y_train), (X_test, Y_test) = data.cifar10()
    (X_train, Y_train) = extract(X_train, Y_train, nb_per_class)
    return (X_train, Y_train), (X_test, Y_test)
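`extract` is not defined in this snippet; from its use it selects `nb_per_class` samples of each class. A minimal NumPy sketch under that assumption (it presumes integer class labels in `Y`):

import numpy as np

def extract(X, Y, nb_per_class):
    # Assumed behaviour of the `extract` helper: keep the first
    # nb_per_class samples of every class present in Y.
    idx = np.concatenate(
        [np.flatnonzero(Y == c)[:nb_per_class] for c in np.unique(Y)])
    return X[idx], Y[idx]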
Example 3
def main():

    start_epoch = 0
    best_prec1 = 0.0
    best_prec5 = 0.0

    # Data loading
    print('=> Preparing data..')
    loader = cifar10(args)

    # Create model
    print('=> Building model...')
    model_t = import_module(f'model.{args.arch}').__dict__[args.teacher_model]().to(device)

    # Load teacher model
    ckpt_t = torch.load(args.teacher_dir, map_location=device)

    if args.arch == 'densenet':
        state_dict_t = {}
        for k, v in ckpt_t['state_dict'].items():
            new_key = '.'.join(k.split('.')[1:])
            if new_key == 'linear.weight':
                new_key = 'fc.weight'
            elif new_key == 'linear.bias':
                new_key = 'fc.bias'
            state_dict_t[new_key] = v
    else:
        state_dict_t = ckpt_t['state_dict']

    model_t.load_state_dict(state_dict_t)
    model_t = model_t.to(device)

    for para in list(model_t.parameters())[:-2]:
        para.requires_grad = False

    model_s = import_module(f'model.{args.arch}').__dict__[args.student_model]().to(device)

    model_dict_s = model_s.state_dict()
    model_dict_s.update(state_dict_t)
    model_s.load_state_dict(model_dict_s)

    if len(args.gpus) != 1:
        model_s = nn.DataParallel(model_s, device_ids=args.gpus)

    model_d = Discriminator().to(device) 

    models = [model_t, model_s, model_d]

    optimizer_d = optim.SGD(model_d.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)

    param_s = [param for name, param in model_s.named_parameters() if 'mask' not in name]
    param_m = [param for name, param in model_s.named_parameters() if 'mask' in name]

    optimizer_s = optim.SGD(param_s, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
    optimizer_m = FISTA(param_m, lr=args.lr, gamma=args.sparse_lambda)

    scheduler_d = StepLR(optimizer_d, step_size=args.lr_decay_step, gamma=0.1)
    scheduler_s = StepLR(optimizer_s, step_size=args.lr_decay_step, gamma=0.1)
    scheduler_m = StepLR(optimizer_m, step_size=args.lr_decay_step, gamma=0.1)

    resume = args.resume
    if resume:
        print('=> Resuming from ckpt {}'.format(resume))
        ckpt = torch.load(resume, map_location=device)
        best_prec1 = ckpt['best_prec1']
        start_epoch = ckpt['epoch']

        model_s.load_state_dict(ckpt['state_dict_s'])
        model_d.load_state_dict(ckpt['state_dict_d'])
        optimizer_d.load_state_dict(ckpt['optimizer_d'])
        optimizer_s.load_state_dict(ckpt['optimizer_s'])
        optimizer_m.load_state_dict(ckpt['optimizer_m'])
        scheduler_d.load_state_dict(ckpt['scheduler_d'])
        scheduler_s.load_state_dict(ckpt['scheduler_s'])
        scheduler_m.load_state_dict(ckpt['scheduler_m'])
        print('=> Continue from epoch {}...'.format(start_epoch))


    if args.test_only:
        test_prec1, test_prec5 = test(args, loader.loader_test, model_s)
        print('=> Test Prec@1: {:.2f}'.format(test_prec1))
        return

    optimizers = [optimizer_d, optimizer_s, optimizer_m]
    schedulers = [scheduler_d, scheduler_s, scheduler_m]
    for epoch in range(start_epoch, args.num_epochs):
        for s in schedulers:
            s.step(epoch)

        train(args, loader.loader_train, models, optimizers, epoch)
        test_prec1, test_prec5 = test(args, loader.loader_test, model_s)

        is_best = best_prec1 < test_prec1
        best_prec1 = max(test_prec1, best_prec1)
        best_prec5 = max(test_prec5, best_prec5)

        model_state_dict = model_s.module.state_dict() if len(args.gpus) > 1 else model_s.state_dict()

        state = {
            'state_dict_s': model_state_dict,
            'state_dict_d': model_d.state_dict(),
            'best_prec1': best_prec1,
            'best_prec5': best_prec5,
            'optimizer_d': optimizer_d.state_dict(),
            'optimizer_s': optimizer_s.state_dict(),
            'optimizer_m': optimizer_m.state_dict(),
            'scheduler_d': scheduler_d.state_dict(),
            'scheduler_s': scheduler_s.state_dict(),
            'scheduler_m': scheduler_m.state_dict(),
            'epoch': epoch + 1
        }
        checkpoint.save_model(state, epoch + 1, is_best)

    print_logger.info(f"Best @prec1: {best_prec1:.3f} @prec5: {best_prec5:.3f}")

    best_model = torch.load(f'{args.job_dir}/checkpoint/model_best.pt', map_location=device)

    model = import_module('utils.preprocess').__dict__[f'{args.arch}'](args, best_model['state_dict_s'])
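The mask parameters in this example are optimized with `FISTA(..., gamma=args.sparse_lambda)`, i.e. gradient steps followed by an l1 proximal (soft-thresholding) step that drives mask entries to exact zero. The FISTA class itself is not shown; the snippet below sketches only the proximal update it is assumed to apply:

import torch

def soft_threshold_(param, lr, gamma):
    # Proximal operator of gamma * ||x||_1 applied after a gradient step of
    # size lr: shrink every entry toward zero and clip at zero. This is the
    # core of a FISTA-style sparse update (a sketch, not the class used above).
    with torch.no_grad():
        param.copy_(param.sign() * (param.abs() - lr * gamma).clamp(min=0.0))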
Example 4
def main():
    checkpoint = utils.checkpoint(args)
    writer_train = SummaryWriter(args.job_dir + '/run/train')
    writer_test = SummaryWriter(args.job_dir + '/run/test')

    start_epoch = 0
    best_prec1 = 0.0
    best_prec5 = 0.0

    # Data loading
    print('=> Preparing data..')
    loader = cifar10(args)

    # Create model
    print('=> Building model...')
    model_t = resnet_56().to(args.gpus[0])

    # Load teacher model
    ckpt_t = torch.load(args.teacher_dir,
                        map_location=torch.device(f"cuda:{args.gpus[0]}"))
    state_dict_t = ckpt_t['state_dict']
    model_t.load_state_dict(state_dict_t)
    model_t = model_t.to(args.gpus[0])

    for para in list(model_t.parameters())[:-2]:
        para.requires_grad = False

    model_s = resnet_56_sparse().to(args.gpus[0])

    model_dict_s = model_s.state_dict()
    model_dict_s.update(state_dict_t)
    model_s.load_state_dict(model_dict_s)

    if len(args.gpus) != 1:
        model_s = nn.DataParallel(model_s, device_ids=args.gpus)

    model_d = Discriminator().to(args.gpus[0])

    models = [model_t, model_s, model_d]

    optimizer_d = optim.SGD(model_d.parameters(),
                            lr=args.lr,
                            momentum=args.momentum,
                            weight_decay=args.weight_decay)

    param_s = [
        param for name, param in model_s.named_parameters()
        if 'mask' not in name
    ]
    param_m = [
        param for name, param in model_s.named_parameters() if 'mask' in name
    ]

    optimizer_s = optim.SGD(param_s,
                            lr=args.lr,
                            momentum=args.momentum,
                            weight_decay=args.weight_decay)
    optimizer_m = FISTA(param_m, lr=args.lr, gamma=args.sparse_lambda)

    scheduler_d = StepLR(optimizer_d, step_size=args.lr_decay_step, gamma=0.1)
    scheduler_s = StepLR(optimizer_s, step_size=args.lr_decay_step, gamma=0.1)
    scheduler_m = StepLR(optimizer_m, step_size=args.lr_decay_step, gamma=0.1)

    resume = args.resume
    if resume:
        print('=> Resuming from ckpt {}'.format(resume))
        ckpt = torch.load(resume,
                          map_location=torch.device(f"cuda:{args.gpus[0]}"))
        best_prec1 = ckpt['best_prec1']
        start_epoch = ckpt['epoch']
        model_s.load_state_dict(ckpt['state_dict_s'])
        model_d.load_state_dict(ckpt['state_dict_d'])
        optimizer_d.load_state_dict(ckpt['optimizer_d'])
        optimizer_s.load_state_dict(ckpt['optimizer_s'])
        optimizer_m.load_state_dict(ckpt['optimizer_m'])
        scheduler_d.load_state_dict(ckpt['scheduler_d'])
        scheduler_s.load_state_dict(ckpt['scheduler_s'])
        scheduler_m.load_state_dict(ckpt['scheduler_m'])
        print('=> Continue from epoch {}...'.format(start_epoch))

    optimizers = [optimizer_d, optimizer_s, optimizer_m]
    schedulers = [scheduler_d, scheduler_s, scheduler_m]

    if args.test_only:
        test_prec1, test_prec5 = test(args, loader.loader_test, model_s)
        print('=> Test Prec@1: {:.2f}'.format(test_prec1))
        return

    for epoch in range(start_epoch, args.num_epochs):
        for s in schedulers:
            s.step(epoch)

        train(args, loader.loader_train, models, optimizers, epoch,
              writer_train)
        test_prec1, test_prec5 = test(args, loader.loader_test, model_s)

        is_best = best_prec1 < test_prec1
        best_prec1 = max(test_prec1, best_prec1)
        best_prec5 = max(test_prec5, best_prec5)

        model_state_dict = model_s.module.state_dict() if len(
            args.gpus) > 1 else model_s.state_dict()

        state = {
            'state_dict_s': model_state_dict,
            'state_dict_d': model_d.state_dict(),
            'best_prec1': best_prec1,
            'best_prec5': best_prec5,
            'optimizer_d': optimizer_d.state_dict(),
            'optimizer_s': optimizer_s.state_dict(),
            'optimizer_m': optimizer_m.state_dict(),
            'scheduler_d': scheduler_d.state_dict(),
            'scheduler_s': scheduler_s.state_dict(),
            'scheduler_m': scheduler_m.state_dict(),
            'epoch': epoch + 1
        }
        checkpoint.save_model(state, epoch + 1, is_best)

    print(f"=> Best @prec1: {best_prec1:.3f} @prec5: {best_prec5:.3f}")

    best_model = torch.load(f'{args.job_dir}/checkpoint/model_best.pt',
                            map_location=torch.device(f"cuda:{args.gpus[0]}"))

    model = prune_resnet(args, best_model['state_dict_s'])
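`checkpoint.save_model` persists the full training state each epoch and keeps a best copy, which is what the final `torch.load(f'{args.job_dir}/checkpoint/model_best.pt', ...)` reads back. A hypothetical sketch consistent with that path layout:

import os
import torch

def save_model(state, epoch, is_best, job_dir='.'):
    # Assumed sketch of checkpoint.save_model: write the per-epoch state
    # and refresh model_best.pt whenever prec@1 improves.
    ckpt_dir = os.path.join(job_dir, 'checkpoint')
    os.makedirs(ckpt_dir, exist_ok=True)
    torch.save(state, os.path.join(ckpt_dir, f'model_{epoch}.pt'))
    if is_best:
        torch.save(state, os.path.join(ckpt_dir, 'model_best.pt'))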
Example 5
def main():
    checkpoint = utils.checkpoint(args)
    writer_train = SummaryWriter(args.job_dir + '/run/train')
    writer_test = SummaryWriter(args.job_dir + '/run/test')

    start_epoch = 0
    best_prec1 = 0.0
    best_prec5 = 0.0

    # Data loading
    print('=> Preparing data..')
    loader = cifar10(args)

    # Create model
    print('=> Building model...')
    # model_t = resnet_56().to(args.gpus[0])
    #model_t = MobileNetV2()
    model_t = ResNet18()
    model_kd = ResNet101()

    print(model_kd)
    # Load teacher model
    ckpt_t = torch.load(args.teacher_dir,
                        map_location=torch.device(f"cuda:{args.gpus[0]}"))
    state_dict_t = ckpt_t['net']
    # The checkpoint keys already match the teacher's layout, so no remapping
    # is applied here.
    new_state_dict_t = state_dict_t

    # model_t.load_state_dict(new_state_dict_t)
    model_t = model_t.to(args.gpus[0])

    for para in list(model_t.parameters())[:-2]:
        para.requires_grad = False

    #model_s = SpraseMobileNetV2().to(args.gpus[0])
    model_s = ResNet18_sprase().to(args.gpus[0])
    print(model_s)
    model_dict_s = model_s.state_dict()
    model_dict_s.update(new_state_dict_t)
    model_s.load_state_dict(model_dict_s)

    ckpt_kd = torch.load('resnet101.t7',
                         map_location=torch.device(f"cuda:{args.gpus[0]}"))
    state_dict_kd = ckpt_kd['net']
    new_state_dict_kd = OrderedDict()
    for k, v in state_dict_kd.items():
        name = k[7:]
        new_state_dict_kd[name] = v
    model_kd.load_state_dict(new_state_dict_kd)
    model_kd = model_kd.to(args.gpus[0])

    for para in list(model_kd.parameters())[:-2]:
        para.requires_grad = False

    if len(args.gpus) != 1:
        model_s = nn.DataParallel(model_s, device_ids=args.gpus)

    model_d = Discriminator().to(args.gpus[0])

    models = [model_t, model_s, model_d, model_kd]

    optimizer_d = optim.SGD(model_d.parameters(),
                            lr=args.lr,
                            momentum=args.momentum,
                            weight_decay=args.weight_decay)

    param_s = [
        param for name, param in model_s.named_parameters()
        if 'mask' not in name
    ]
    param_m = [
        param for name, param in model_s.named_parameters() if 'mask' in name
    ]

    optimizer_s = optim.SGD(param_s,
                            lr=args.lr,
                            momentum=args.momentum,
                            weight_decay=args.weight_decay)
    optimizer_m = FISTA(param_m, lr=args.lr, gamma=args.sparse_lambda)

    scheduler_d = StepLR(optimizer_d, step_size=args.lr_decay_step, gamma=0.1)
    scheduler_s = StepLR(optimizer_s, step_size=args.lr_decay_step, gamma=0.1)
    scheduler_m = StepLR(optimizer_m, step_size=args.lr_decay_step, gamma=0.1)

    resume = args.resume
    if resume:
        print('=> Resuming from ckpt {}'.format(resume))
        ckpt = torch.load(resume,
                          map_location=torch.device(f"cuda:{args.gpus[0]}"))
        best_prec1 = ckpt['best_prec1']
        start_epoch = ckpt['epoch']
        model_s.load_state_dict(ckpt['state_dict_s'])
        model_d.load_state_dict(ckpt['state_dict_d'])
        optimizer_d.load_state_dict(ckpt['optimizer_d'])
        optimizer_s.load_state_dict(ckpt['optimizer_s'])
        optimizer_m.load_state_dict(ckpt['optimizer_m'])
        scheduler_d.load_state_dict(ckpt['scheduler_d'])
        scheduler_s.load_state_dict(ckpt['scheduler_s'])
        scheduler_m.load_state_dict(ckpt['scheduler_m'])
        print('=> Continue from epoch {}...'.format(ckpt['epoch']))

    optimizers = [optimizer_d, optimizer_s, optimizer_m]
    schedulers = [scheduler_d, scheduler_s, scheduler_m]

    if args.test_only:
        test_prec1, test_prec5 = test(args, loader.loader_test, model_s)
        print('=> Test Prec@1: {:.2f}'.format(test_prec1))
        return

    for epoch in range(start_epoch, args.num_epochs):
        for s in schedulers:
            s.step(epoch)

        global g_e
        g_e = epoch
        gl.set_value('epoch', g_e)

        #train(args, loader.loader_train, models, optimizers, epoch, writer_train)
        test_prec1, test_prec5 = test(args, loader.loader_test, model_s)

        is_best = best_prec1 < test_prec1
        best_prec1 = max(test_prec1, best_prec1)
        best_prec5 = max(test_prec5, best_prec5)

        model_state_dict = model_s.module.state_dict() if len(
            args.gpus) > 1 else model_s.state_dict()

        state = {
            'state_dict_s': model_state_dict,
            'state_dict_d': model_d.state_dict(),
            'best_prec1': best_prec1,
            'best_prec5': best_prec5,
            'optimizer_d': optimizer_d.state_dict(),
            'optimizer_s': optimizer_s.state_dict(),
            'optimizer_m': optimizer_m.state_dict(),
            'scheduler_d': scheduler_d.state_dict(),
            'scheduler_s': scheduler_s.state_dict(),
            'scheduler_m': scheduler_m.state_dict(),
            'epoch': epoch + 1
        }
        if is_best:
            checkpoint.save_model(state, epoch + 1, is_best)

    print(f"=> Best @prec1: {best_prec1:.3f} @prec5: {best_prec5:.3f}")
Example 6
            transforms.Normalize(CIFAR_MEAN, CIFAR_STD)
        ])

    # configuration
    currentTime = datetime.datetime.now()
    currentTime = currentTime.strftime('%m%d%H%M%S')
    writer = SummaryWriter()
    model = Backbone()
    model = model.cuda()
    optimizer = torch.optim.SGD(model.parameters(),
                                lr=0.01,
                                momentum=0.9,
                                weight_decay=1e-4)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 40, gamma=0.1)
    if not args.eval:
        train_dataset = cifar10(transform=augmentation, eta=args.eta)
        train_loader = DataLoader(train_dataset,
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  num_workers=args.workers,
                                  pin_memory=True)
    else:
        test_dataset = cifar10(transform=augmentation, if_test=True)
        test_loader = DataLoader(test_dataset,
                                 batch_size=args.batch_size,
                                 shuffle=True,
                                 num_workers=args.workers,
                                 pin_memory=True)
    prefix = f'baseline_{currentTime}' if args.baseline else f'{args.alpha}_{args.beta}_{args.eta}_{currentTime}'
    if args.baseline:
        criterion = nn.CrossEntropyLoss().cuda()
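The snippet opens mid-`transforms.Compose`, so `augmentation` and the `CIFAR_MEAN`/`CIFAR_STD` constants are only partly visible. A typical CIFAR-10 pipeline that ends in the `Normalize` shown above looks like this (the statistics are the standard published values, assumed rather than taken from this example):

import torchvision.transforms as transforms

# Standard CIFAR-10 per-channel statistics (assumed values).
CIFAR_MEAN = (0.4914, 0.4822, 0.4465)
CIFAR_STD = (0.2470, 0.2435, 0.2616)

augmentation = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(CIFAR_MEAN, CIFAR_STD),
])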
Example 7
def train_process_representation(config):
    # set up random seed
    torch.manual_seed(config.random_seed)
    # set up device
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    # set up data
    train_data = cifar10(config, 'train')
    train_loader = DataLoader(train_data,
                              batch_size=config.train_batch_size,
                              shuffle=True,
                              num_workers=2)
    val_data = cifar10(config, 'val')
    val_loader = DataLoader(val_data,
                            batch_size=config.eval_batch_size,
                            shuffle=False,
                            num_workers=2)
    test_data = cifar10(config, 'test')
    test_loader = DataLoader(test_data,
                             batch_size=config.eval_batch_size,
                             shuffle=False,
                             num_workers=2)
    # set up nn, loss and optimizer
    net = ResNet18()
    if torch.cuda.device_count() > 1:
        net = torch.nn.DataParallel(net)
    net.to(device)
    lr, momentum, weight_decay = config.init_lr, config.momentum, config.weight_decay
    optimizer = torch.optim.SGD(net.parameters(),
                                lr=lr,
                                momentum=momentum,
                                weight_decay=weight_decay)
    # training process
    mb = MemoryBank(config.nr_train, config.dim_embed, config.mb_momentum,
                    config.mb_init_mode)
    opt_val_acc = 0
    for epoch in range(1, config.nr_epoch + 1):
        if epoch in config.boundary:
            adjust_lr(optimizer, config.lr_decay_rate)
        train_prob, train_sep = 0, 0
        for i, data in enumerate(train_loader, 1):
            # calc embedding of imgs, sample positive embedding and negative embedding from mb
            imgs, idx = data[0].to(device), data[2]
            embed = net(imgs)
            sz = embed.size(0)
            embed = embed.view(sz, 1, -1)
            embed_p = mb.sample(idx).to(device).view(sz, 1, -1)
            embed_n = mb.random_sample(sz * config.nr_sample).to(device).view(
                sz, config.nr_sample, -1)
            # calc logits and extract logits of the most hard examples
            logits_p, logits_n = (embed * embed_p).sum(
                dim=2) / config.tau, (embed * embed_n).sum(dim=2) / config.tau
            logits_n, _ = logits_n.sort(dim=1, descending=True)
            if epoch <= 10:
                idx_l, idx_r = 0, 1
            else:
                beta = (epoch - 10) / (config.nr_epoch - 10)
                idx_l, idx_r = config.final_l * beta, 1 - (
                    1 - config.final_r) * beta
            if config.adv_mode == 'fixed':
                idx_l, idx_r = config.final_l, config.final_r
            idx_l, idx_r = int(config.nr_sample * idx_l), int(
                config.nr_sample * idx_r)
            logits = torch.cat((logits_p, logits_n[:, idx_l:idx_r]), dim=1)
            embed = embed.view(sz, -1)
            # calc loss
            probs = F.softmax(logits, dim=1)[:, 0]
            alpha = (config.nr_train - 1) / (idx_r - idx_l) * config.alpha
            probs = probs / ((1 - alpha) * probs + alpha)
            loss = -probs.log().mean()
            # calc output metrics
            prob = probs.mean()
            train_prob = 0.5 * train_prob + 0.5 * prob.item()
            sep = torch.mm(embed, embed.permute(1, 0)).mean()
            train_sep = 0.5 * train_sep + 0.5 * sep.item()
            # update parameters and MemoryBank
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            mb.update(embed, idx)
            # output metrics
            if i % config.show_interval == 0:
                print('_______training_______')
                print('epoch: ', epoch, 'step: ', i, 'prob: ',
                      '%.3f' % train_prob, 'sep: ', '%.3f' % train_sep)
        if epoch % config.val_interval == 0:
            # validate by KNN
            net.eval()
            with torch.no_grad():
                x_train, y_train = mb.content.to(device), torch.tensor(
                    train_data.labels).to(device)
                knn = KNN(x_train, y_train, config.nr_label, config.K,
                          config.tau)
                val_acc, val_sep = 0, 0
                for i, data in enumerate(val_loader, 1):
                    # calc embedding
                    imgs, labels = data[0].to(device), data[1].to(device)
                    embed = net(imgs)
                    # calc output metrics
                    labels_pre = knn.predict(embed)
                    acc = (labels_pre == labels.long()).float().mean()
                    val_acc = (i - 1) / i * val_acc + acc.item() / i
                    sep = torch.mm(embed, embed.permute(1, 0)).mean()
                    val_sep = (i - 1) / i * val_sep + sep.item() / i
            print('_______validation_______')
            print('epoch: ', epoch, 'acc: ', '%.3f' % val_acc, 'sep: ',
                  '%.3f' % val_sep)
            if opt_val_acc <= val_acc:
                opt_val_acc = val_acc
                torch.save(net.state_dict(), config.save_dir + '/opt')
            net.train()
    print('finish training. optimal val acc: ', '%.3f' % opt_val_acc)
    ret = {}
    np.save('./tmp/mb.npy', mb.content.numpy())
    net.eval()
    net.load_state_dict(torch.load(config.save_dir + '/opt'))
    # testing embedding in MemoryBank
    x_train, y_train = mb.content.to(device), torch.tensor(
        train_data.labels).to(device)
    knn = KNN(x_train, y_train, config.nr_label, config.K, config.tau)
    test_acc = 0
    for i, data in enumerate(test_loader, 1):
        # calc embedding
        imgs, labels = data[0].to(device), data[1].to(device)
        embed = net(imgs)
        # calc output metric
        labels_pre = knn.predict(embed)
        acc = (labels_pre == labels.long()).float().mean()
        test_acc = (i - 1) / i * test_acc + acc.item() / i
    print('_______testing_______')
    print('acc: ', '%.3f' % test_acc)
    ret['test_by_mb'] = test_acc
    # testing embeddings generated by the net on the train set
    x_train, y_train = [], []
    with torch.no_grad():
        for data in train_loader:
            imgs, labels = data[0].to(device), data[1]
            embed = net(imgs)
            x_train.extend(embed.cpu().tolist())
            y_train.extend(labels.tolist())
    knn = KNN(
        torch.tensor(x_train).to(device),
        torch.tensor(y_train).to(device), config.nr_label, config.K,
        config.tau)
    test_acc = 0
    for i, data in enumerate(test_loader, 1):
        # calc embedding
        imgs, labels = data[0].to(device), data[1].to(device)
        embed = net(imgs)
        # calc output metric
        labels_pre = knn.predict(embed)
        acc = (labels_pre == labels.long()).float().mean()
        test_acc = (i - 1) / i * test_acc + acc.item() / i
    print('_______testing_______')
    print('acc: ', '%.3f' % test_acc)
    ret['test_by_train'] = test_acc
    # testing embeddings generated by the net on the val set
    x_train, y_train = [], []
    with torch.no_grad():
        for data in val_loader:
            imgs, labels = data[0].to(device), data[1]
            embed = net(imgs)
            x_train.extend(embed.cpu().tolist())
            y_train.extend(labels.tolist())
    knn = KNN(
        torch.tensor(x_train).to(device),
        torch.tensor(y_train).to(device), config.nr_label, config.K,
        config.tau)
    test_acc = 0
    for i, data in enumerate(test_loader, 1):
        # calc embed
        imgs, labels = data[0].to(device), data[1].to(device)
        embed = net(imgs)
        # calc output metric
        labels_pre = knn.predict(embed)
        acc = (labels_pre == labels.long()).float().mean()
        test_acc = (i - 1) / i * test_acc + acc.item() / i
    print('_______testing_______')
    print('acc: ', '%.3f' % test_acc)
    ret['test_by_val'] = test_acc
    return ret
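The `KNN` classifier used for validation and testing is not shown. Given how it is constructed, `KNN(x_train, y_train, config.nr_label, config.K, config.tau)`, and that embeddings are compared by dot product, it is presumably the temperature-weighted k-nearest-neighbour classifier common in instance-discrimination work; a minimal sketch under that assumption:

import torch

class KNN:
    # Assumed sketch of the KNN classifier above: weighted kNN over
    # similarities with temperature tau, as in instance discrimination.
    def __init__(self, x_train, y_train, nr_label, K, tau):
        self.x_train, self.y_train = x_train, y_train.long()
        self.nr_label, self.K, self.tau = nr_label, K, tau

    def predict(self, embed):
        sim = embed @ self.x_train.t()               # (batch, nr_train)
        w, idx = sim.topk(self.K, dim=1)             # K nearest neighbours
        w = (w / self.tau).exp()                     # temperature weighting
        votes = torch.zeros(embed.size(0), self.nr_label,
                            device=embed.device)
        votes.scatter_add_(1, self.y_train[idx], w)  # accumulate class votes
        return votes.argmax(dim=1)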