def return_accuracies(start_idxs, NUM_ROUND, NUM_QUERY, epoch, learning_rate,
                      datadir, data_name, feature):
    """Run BADGE active learning on one dataset and track accuracy per round.

    Starting from the seed labelled pool ``start_idxs``, performs ``NUM_ROUND``
    query rounds of ``NUM_QUERY`` points each, retraining the model after every
    round and recording test / validation / unlabelled-pool accuracy.

    Parameters
    ----------
    start_idxs : array-like of int
        Indices of the initially labelled training points.
    NUM_ROUND : int
        Number of active-learning query rounds.
    NUM_QUERY : int
        Number of points queried per round (also the train batch size for the
        non-cifar10 datasets, as in the original configuration).
    epoch : int
        Training epochs per round (forwarded to the strategy via ``args``).
    learning_rate : float
        Optimizer learning rate.
    datadir : str
        Directory the dataset loaders read from.
    data_name : str
        Dataset key; selects the loader, model architecture and data handler.
    feature : str
        Feature variant forwarded to the numpy dataset loaders.

    Returns
    -------
    tuple
        ``(val_acc, tst_acc, unlabeled_acc, labeled_idxs)`` — three accuracy
        arrays of length ``NUM_ROUND + 1`` (entry 0 is measured before any
        querying) and the finally-labelled training indices.
    """
    # Fix every RNG seed so repeated runs are comparable.
    torch.manual_seed(42)
    torch.cuda.manual_seed(42)
    np.random.seed(42)
    random.seed(42)
    torch.backends.cudnn.deterministic = True

    def _stack_loader(loader):
        # Materialize a DataLoader into a single (inputs, targets) pair.
        # Batches are collected in lists and concatenated once; the original
        # re-ran torch.cat per batch, which is accidentally quadratic.
        xs, ys = [], []
        for inputs, targets in loader:
            xs.append(inputs)
            ys.append(targets)
        return torch.cat(xs, dim=0), torch.cat(ys, dim=0)

    # ---- Load the requested dataset -------------------------------------
    numpy_old_sets = [
        'dna', 'sklearn-digits', 'satimage', 'svmguide1', 'letter', 'shuttle',
        'ijcnn1', 'sensorless', 'connect_4', 'sensit_seismic', 'usps', 'adult'
    ]
    if data_name in numpy_old_sets:
        fullset, valset, testset, num_cls = load_dataset_numpy_old(
            datadir, data_name, feature=feature)
        write_knndata_old(datadir, data_name, feature=feature)
    elif data_name == 'cifar10':
        fullset, valset, testset, num_cls = load_dataset_pytorch(
            datadir, data_name)
        # Validation split is 10% of the entire training set.
        validation_set_fraction = 0.1
        num_fulltrn = len(fullset)
        num_val = int(num_fulltrn * validation_set_fraction)
        num_trn = num_fulltrn - num_val
        # No explicit generator is passed, as in the original
        # (#,generator=torch.Generator().manual_seed(42)); the global
        # torch.manual_seed(42) above makes the split deterministic.
        trainset, validset = random_split(fullset, [num_trn, num_val])

        trainloader = torch.utils.data.DataLoader(trainset,
                                                  batch_size=128,
                                                  shuffle=False,
                                                  pin_memory=True)
        # NOTE(review): the sampler indices come from ``validset`` (a split of
        # ``fullset``) but are applied to ``valset``, a different dataset —
        # preserved as-is, but verify this is the intended behaviour.
        valloader = torch.utils.data.DataLoader(
            valset,
            batch_size=1000,
            shuffle=False,
            sampler=SubsetRandomSampler(validset.indices),
            pin_memory=True)
        testloader = torch.utils.data.DataLoader(testset,
                                                 batch_size=1000,
                                                 shuffle=False,
                                                 pin_memory=True)

        x_trn, y_trn = _stack_loader(trainloader)
        x_val, y_val = _stack_loader(valloader)
        x_tst, y_tst = _stack_loader(testloader)
        # Training labels are handed to the strategy as float32 numpy,
        # matching the non-pytorch dataset paths.
        y_trn = y_trn.numpy().astype('float32')
    elif data_name in ['mnist', "fashion-mnist"]:
        fullset, testset, num_cls = load_dataset_numpy_old(
            datadir, data_name, feature=feature)
        write_knndata_old(datadir, data_name, feature=feature)
    else:
        fullset, valset, testset, num_cls = load_dataset_numpy(
            datadir, data_name, feature=feature)
        write_knndata(datadir, data_name, feature=feature)

    # ---- Flatten datasets into plain arrays ------------------------------
    if data_name == 'mnist' or data_name == "fashion-mnist":
        x_trn, y_trn = fullset.data, fullset.targets
        x_tst, y_tst = testset.data, testset.targets
        x_trn = x_trn.view(x_trn.shape[0], -1).numpy()
        x_tst = x_tst.view(x_tst.shape[0], -1).numpy()
        y_trn = y_trn.numpy()
        y_tst = y_tst.numpy()
        # Validation data is 10% of the entire (full) training data.
        x_trn, x_val, y_trn, y_val = train_test_split(x_trn, y_trn,
                                                      test_size=0.1,
                                                      random_state=42)
    elif data_name != 'cifar10':
        # cifar10 arrays were already assembled from the loaders above.
        x_trn, y_trn = fullset
        x_val, y_val = valset
        x_tst, y_tst = testset

    # ---- Strategy configuration ------------------------------------------
    # The two original args dicts differed only in handler and train batch
    # size, so they are built from one template here.
    if data_name == 'cifar10':
        handler = DataHandler3
        trn_batch_size = 128
    else:
        handler = CustomDataset_WithId
        trn_batch_size = NUM_QUERY
    args = {
        'n_epoch': epoch,
        'transform': None,
        'loader_tr_args': {'batch_size': trn_batch_size},
        'loader_te_args': {'batch_size': 1000},
        'optimizer_args': {'lr': learning_rate},
        'transformTest': None,
        'lr': learning_rate,
    }

    n_pool = len(y_trn)

    if data_name == 'cifar10':
        net = resnet.ResNet18()
    else:
        # Alternative from the original: linMod(x_trn.shape[1], num_cls)
        net = mlpMod(x_trn.shape[1], num_cls, 100)

    idxs_lb = np.zeros(n_pool, dtype=bool)
    idxs_lb[start_idxs] = True
    strategy = BadgeSampling(x_trn, y_trn, idxs_lb, net, handler, args)
    strategy.train()

    def _accuracy(X, Y):
        # Percentage of correct predictions of the current model on (X, Y).
        # torch.as_tensor avoids the redundant copy (and warning) that
        # torch.tensor() performs when Y is already a tensor.
        P = strategy.predict(X, Y)
        return 100.0 * P.eq(torch.as_tensor(Y)).sum().item() / len(Y)

    unlabeled_acc = np.zeros(NUM_ROUND + 1)
    tst_acc = np.zeros(NUM_ROUND + 1)
    val_acc = np.zeros(NUM_ROUND + 1)

    # Round-0 accuracies: model trained on the seed pool only.
    tst_acc[0] = _accuracy(x_tst, y_tst)
    print('\ttesting accuracy {}'.format(tst_acc[0]), flush=True)
    val_acc[0] = _accuracy(x_val, y_val)
    unlabeled_acc[0] = _accuracy(x_trn[~idxs_lb], y_trn[~idxs_lb])

    for rd in range(1, NUM_ROUND + 1):
        print('Round {}'.format(rd), flush=True)

        # Query the next batch of points and mark them as labelled.
        q_idxs = strategy.query(NUM_QUERY)
        idxs_lb[q_idxs] = True
        strategy.update(idxs_lb)
        strategy.train()

        # Round accuracies on test / validation / remaining unlabelled pool.
        tst_acc[rd] = _accuracy(x_tst, y_tst)
        print(rd, '\ttesting accuracy {}'.format(tst_acc[rd]), flush=True)
        val_acc[rd] = _accuracy(x_val, y_val)
        unlabeled_acc[rd] = _accuracy(x_trn[~idxs_lb], y_trn[~idxs_lb])

        # BUG FIX: the original exited even after the final round, throwing
        # away all results; abort only while another query is still pending.
        if rd < NUM_ROUND and np.count_nonzero(~strategy.idxs_lb) < NUM_QUERY:
            sys.exit('too few remaining points to query')

    return val_acc, tst_acc, unlabeled_acc, np.arange(n_pool)[idxs_lb]
subprocess.run(["mkdir", "-p", all_logs_dir]) path_logfile = os.path.join(all_logs_dir, data_name + '.txt') logfile = open(path_logfile, 'w') exp_name = data_name + '_fraction:' + str(fraction) + '_epochs:' + str(num_epochs) + \ '_selEvery:' + str(select_every) + '_variant' + str(warm_method) + '_runs' + str(num_runs) print(exp_name) #print("=======================================", file=logfile) #print(exp_name, str(exp_start_time), file=logfile) if data_name in [ 'dna', 'sklearn-digits', 'satimage', 'svmguide1', 'letter', 'shuttle', 'ijcnn1', 'sensorless', 'connect_4', 'sensit_seismic', 'usps' ]: fullset, valset, testset, num_cls = load_dataset_numpy_old(datadir, data_name, feature=feature) #write_knndata_old(datadir, data_name,feature=feature) elif data_name in ['mnist', "fashion-mnist"]: fullset, testset, num_cls = load_dataset_numpy_old(datadir, data_name, feature=feature) #write_knndata_old(datadir, data_name,feature=feature) else: fullset, valset, testset, num_cls = load_dataset_numpy(datadir, data_name, feature=feature) #write_knndata(datadir, data_name,feature=feature) '''if data_name == 'mnist' or data_name == "fashion-mnist": x_trn, y_trn = fullset.data, fullset.targets x_tst, y_tst = testset.data, testset.targets