示例#1
0
def _loader_to_tensors(loader):
    """Drain *loader* and return all of its batches stacked into two tensors.

    Batches are collected in lists and concatenated once at the end. Calling
    ``torch.cat`` onto a growing tensor for every batch would re-copy the
    accumulated data on each step (quadratic in the dataset size).

    Args:
        loader: an iterable yielding ``(inputs, targets)`` tensor pairs.

    Returns:
        ``(inputs, targets)``, each concatenated along dim 0.
    """
    input_batches = []
    target_batches = []
    for inputs, targets in loader:
        input_batches.append(inputs)
        target_batches.append(targets)
    return torch.cat(input_batches, dim=0), torch.cat(target_batches, dim=0)


def _load_cifar10(datadir, data_name):
    """Load CIFAR-10 and materialise train / validation / test data as tensors.

    10% of the full training set is held out for validation via
    ``random_split``, which draws from the global torch RNG (seeded by the
    caller).

    Returns:
        ``(x_trn, y_trn, x_val, y_val, x_tst, y_tst, num_cls)``. ``y_trn`` is
        converted to a float32 NumPy array; the other arrays stay as the torch
        tensors produced by the data loaders.
    """
    fullset, valset, testset, num_cls = load_dataset_pytorch(datadir, data_name)

    # Validation data set is 10% of the entire training set.
    validation_set_fraction = 0.1
    num_fulltrn = len(fullset)
    num_val = int(num_fulltrn * validation_set_fraction)
    num_trn = num_fulltrn - num_val
    trainset, validset = random_split(fullset, [num_trn, num_val])

    trn_batch_size = 128
    val_batch_size = 1000
    tst_batch_size = 1000

    trainloader = torch.utils.data.DataLoader(trainset,
                                              batch_size=trn_batch_size,
                                              shuffle=False,
                                              pin_memory=True)
    # NOTE(review): validation batches are drawn from ``valset`` using indices
    # obtained by splitting ``fullset``; this presumes both datasets index the
    # same underlying training images (e.g. differing only in transforms) --
    # confirm against load_dataset_pytorch.
    valloader = torch.utils.data.DataLoader(valset,
                                            batch_size=val_batch_size,
                                            shuffle=False,
                                            sampler=SubsetRandomSampler(
                                                validset.indices),
                                            pin_memory=True)
    testloader = torch.utils.data.DataLoader(testset,
                                             batch_size=tst_batch_size,
                                             shuffle=False,
                                             pin_memory=True)

    # Drain in the original order (train, val, test): the validation sampler
    # draws from the global torch RNG, so reordering could change results.
    x_trn, y_trn = _loader_to_tensors(trainloader)
    x_val, y_val = _loader_to_tensors(valloader)
    x_tst, y_tst = _loader_to_tensors(testloader)

    y_trn = y_trn.numpy().astype('float32')
    return x_trn, y_trn, x_val, y_val, x_tst, y_tst, num_cls


def _load_mnist_like(datadir, data_name, feature):
    """Load MNIST / Fashion-MNIST as flattened NumPy arrays.

    The loader provides no validation split, so 10% of the training data is
    held out with ``train_test_split(..., random_state=42)``.

    Returns:
        ``(x_trn, y_trn, x_val, y_val, x_tst, y_tst, num_cls)`` where every
        ``x_*`` is reshaped to ``(num_samples, -1)``.
    """
    fullset, testset, num_cls = load_dataset_numpy_old(datadir,
                                                       data_name,
                                                       feature=feature)
    write_knndata_old(datadir, data_name, feature=feature)

    x_trn = fullset.data.view(fullset.data.shape[0], -1).numpy()
    x_tst = testset.data.view(testset.data.shape[0], -1).numpy()
    y_trn = fullset.targets.numpy()
    y_tst = testset.targets.numpy()

    x_trn, x_val, y_trn, y_val = train_test_split(x_trn,
                                                  y_trn,
                                                  test_size=0.1,
                                                  random_state=42)
    return x_trn, y_trn, x_val, y_val, x_tst, y_tst, num_cls


def _accuracy(strategy, x, y):
    """Return the percentage of ``strategy.predict(x, y)`` that equals *y*.

    Returns NaN when *y* is empty (e.g. once every pool point is labeled)
    instead of raising ``ZeroDivisionError``.
    """
    if len(y) == 0:
        return float('nan')
    predictions = strategy.predict(x, y)
    # as_tensor avoids copying (and torch's copy-construct warning) when y is
    # already a tensor, as it is for the CIFAR-10 validation/test labels.
    return 100.0 * predictions.eq(torch.as_tensor(y)).sum().item() / len(y)


def return_accuracies(start_idxs, NUM_ROUND, NUM_QUERY, epoch, learning_rate,
                      datadir, data_name, feature):
    """Run an active-learning loop driven by ``BadgeSampling``.

    Seeds all RNGs with 42 and loads ``data_name`` from ``datadir``. It then
    trains on the initially labeled points ``start_idxs`` and, for each of
    ``NUM_ROUND`` rounds, queries ``NUM_QUERY`` new points, retrains, and
    records accuracies.

    Args:
        start_idxs: indices into the training pool that start out labeled.
        NUM_ROUND: number of query/retrain rounds.
        NUM_QUERY: points requested per round (also the training batch size
            for non-CIFAR datasets).
        epoch: ``n_epoch`` passed to the strategy.
        learning_rate: optimizer learning rate.
        datadir: dataset root directory.
        data_name: dataset identifier selecting the loader and model.
        feature: feature variant forwarded to the NumPy dataset loaders.

    Returns:
        ``(val_acc, tst_acc, unlabeled_acc, labeled_idxs)``. The three accuracy
        arrays have length ``NUM_ROUND + 1`` (index 0 is before any query) and
        hold percentages; an entry is NaN if its evaluation set is empty.
        ``labeled_idxs`` are the pool indices labeled at the end.

    Exits the process via ``sys.exit`` if, before a round that still has to
    query, fewer than ``NUM_QUERY`` unlabeled points remain.
    """
    torch.manual_seed(42)
    torch.cuda.manual_seed(42)
    np.random.seed(42)
    random.seed(42)
    torch.backends.cudnn.deterministic = True

    old_numpy_datasets = (
        'dna', 'sklearn-digits', 'satimage', 'svmguide1', 'letter',
        'shuttle', 'ijcnn1', 'sensorless', 'connect_4', 'sensit_seismic',
        'usps', 'adult',
    )

    if data_name in old_numpy_datasets:
        fullset, valset, testset, num_cls = load_dataset_numpy_old(
            datadir, data_name, feature=feature)
        write_knndata_old(datadir, data_name, feature=feature)
        (x_trn, y_trn), (x_val, y_val), (x_tst, y_tst) = fullset, valset, testset
    elif data_name == 'cifar10':
        (x_trn, y_trn, x_val, y_val, x_tst, y_tst,
         num_cls) = _load_cifar10(datadir, data_name)
    elif data_name in ('mnist', 'fashion-mnist'):
        (x_trn, y_trn, x_val, y_val, x_tst, y_tst,
         num_cls) = _load_mnist_like(datadir, data_name, feature)
    else:
        fullset, valset, testset, num_cls = load_dataset_numpy(datadir,
                                                               data_name,
                                                               feature=feature)
        write_knndata(datadir, data_name, feature=feature)
        (x_trn, y_trn), (x_val, y_val), (x_tst, y_tst) = fullset, valset, testset

    is_cifar = data_name == 'cifar10'
    handler = DataHandler3 if is_cifar else CustomDataset_WithId
    args = {
        'n_epoch': epoch,
        'transform': None,
        'loader_tr_args': {
            'batch_size': 128 if is_cifar else NUM_QUERY
        },
        'loader_te_args': {
            'batch_size': 1000
        },
        'optimizer_args': {
            'lr': learning_rate
        },
        'transformTest': None,
        'lr': learning_rate,
    }

    n_pool = len(y_trn)

    if is_cifar:
        net = resnet.ResNet18()
    else:
        net = mlpMod(x_trn.shape[1], num_cls, 100)

    idxs_lb = np.zeros(n_pool, dtype=bool)
    idxs_lb[start_idxs] = True

    strategy = BadgeSampling(x_trn, y_trn, idxs_lb, net, handler, args)
    strategy.train()

    unlabeled_acc = np.zeros(NUM_ROUND + 1)
    tst_acc = np.zeros(NUM_ROUND + 1)
    val_acc = np.zeros(NUM_ROUND + 1)

    tst_acc[0] = _accuracy(strategy, x_tst, y_tst)
    print('\ttesting accuracy {}'.format(tst_acc[0]), flush=True)
    val_acc[0] = _accuracy(strategy, x_val, y_val)
    unlabeled_mask = ~idxs_lb
    unlabeled_acc[0] = _accuracy(strategy, x_trn[unlabeled_mask],
                                 y_trn[unlabeled_mask])

    for rd in range(1, NUM_ROUND + 1):
        print('Round {}'.format(rd), flush=True)

        # Query new points and mark them as labeled.
        q_idxs = strategy.query(NUM_QUERY)
        idxs_lb[q_idxs] = True

        # Retrain on the enlarged labeled set.
        strategy.update(idxs_lb)
        strategy.train()

        tst_acc[rd] = _accuracy(strategy, x_tst, y_tst)
        print(rd, '\ttesting accuracy {}'.format(tst_acc[rd]), flush=True)
        val_acc[rd] = _accuracy(strategy, x_val, y_val)
        unlabeled_mask = ~idxs_lb
        unlabeled_acc[rd] = _accuracy(strategy, x_trn[unlabeled_mask],
                                      y_trn[unlabeled_mask])

        # Only a round that is followed by another query needs NUM_QUERY
        # points left; checking after the final round would abort the process
        # before the collected accuracies are returned.
        if rd < NUM_ROUND and sum(~strategy.idxs_lb) < NUM_QUERY:
            sys.exit('too few remaining points to query')

    return val_acc, tst_acc, unlabeled_acc, np.arange(n_pool)[idxs_lb]
示例#2
0
# Create the log directory (and any missing parents). The previous
# ``subprocess.run(["mkdir", "-p", ...])`` call never checked the exit status,
# so a failure went unnoticed until the open() below; os.makedirs is portable
# and raises immediately if the directory cannot be created.
os.makedirs(all_logs_dir, exist_ok=True)
path_logfile = os.path.join(all_logs_dir, data_name + '.txt')

# NOTE(review): this handle is not closed anywhere in this section; presumably
# it is used and closed further down the script -- confirm.
logfile = open(path_logfile, 'w')
exp_name = (data_name + '_fraction:' + str(fraction) + '_epochs:' +
            str(num_epochs) + '_selEvery:' + str(select_every) + '_variant' +
            str(warm_method) + '_runs' + str(num_runs))
print(exp_name)

# NOTE(review): unlike return_accuracies, 'adult' is absent from this list, so
# it is loaded through load_dataset_numpy below -- confirm that is intended.
if data_name in [
        'dna', 'sklearn-digits', 'satimage', 'svmguide1', 'letter', 'shuttle',
        'ijcnn1', 'sensorless', 'connect_4', 'sensit_seismic', 'usps'
]:
    fullset, valset, testset, num_cls = load_dataset_numpy_old(datadir,
                                                               data_name,
                                                               feature=feature)
elif data_name in ['mnist', "fashion-mnist"]:
    # These datasets come without a separate validation split.
    fullset, testset, num_cls = load_dataset_numpy_old(datadir,
                                                       data_name,
                                                       feature=feature)
else:
    fullset, valset, testset, num_cls = load_dataset_numpy(datadir,
                                                           data_name,
                                                           feature=feature)
'''if data_name == 'mnist' or data_name == "fashion-mnist":
    x_trn, y_trn = fullset.data, fullset.targets
    x_tst, y_tst = testset.data, testset.targets