def expert(acq_num, model, init_weights, strategies, train_dataset, pool_subset,
           valid_dataset, test_dataset, device):
    """Try every candidate strategy on this acquisition and return the best one.

    Each strategy selects ``prop.ACQ_SIZE`` points from ``pool_subset``; the
    model is re-trained from ``init_weights`` on the augmented training set,
    and the resulting test accuracy decides the winning strategy.

    Args:
        acq_num: index of the current acquisition round (used for logging only).
        model: task model; its weights are restored to their pre-call state
            before this function returns.
        init_weights: state_dict the model is reset to before each re-training,
            so all strategies are compared from identical starting weights.
        strategies: candidate acquisition strategies, each exposing ``query``.
        train_dataset: currently labelled training data.
        pool_subset: candidate pool the strategies select from.
        valid_dataset: validation split passed to train_validate_model.
        test_dataset: test split passed to train_validate_model.
        device: torch device used for training.

    Returns:
        Tuple ``(best_indices, all_accuracies, all_query_indices)`` — the
        pool-subset indices chosen by the winning strategy, plus per-strategy
        accuracies and selections for bookkeeping.
    """
    strategy_queries = []
    strategy_acc = []
    # Snapshot the model's current weights so the throw-away re-trainings
    # below leave the caller's model untouched.
    trained_weights = deepcopy(model.state_dict())
    for strategy in strategies:
        # The remaining-index part of query()'s result is not needed here.
        sel_ind, _ = strategy.query(prop.ACQ_SIZE, model, train_dataset, pool_subset)
        sel_dataset = make_tensordataset(pool_subset, sel_ind)
        curr_train_dataset = concat_datasets(train_dataset, sel_dataset)
        # Reset to the shared initial weights so strategies are comparable.
        model.load_state_dict(init_weights)
        test_acc = train_validate_model(model, device, curr_train_dataset,
                                        valid_dataset, test_dataset)
        model.load_state_dict(trained_weights)  # undo the training side effect
        strategy_acc.append(test_acc)
        strategy_queries.append(sel_ind)
    sel_strategy = strategy_acc.index(max(strategy_acc))
    # The full per-strategy accuracy list is printed so all candidates can be
    # compared at a glance. ("acccuracy" typo in the original message fixed.)
    print("Expert for {} acquisition is {} sampling with model accuracy {}".
          format(acq_num, strategies[sel_strategy].name, strategy_acc))
    return strategy_queries[sel_strategy], strategy_acc, strategy_queries
def transform_data(data):
    # Add a channel dimension and scale raw uint8 pixels to floats in [0, 1].
    data = data.unsqueeze(1).float().div(255)
    return data


################ train dataset ################################
train_dataset = EMNIST(DATA_PATH, split='letters', train=True,
                       download=True)  # alternatives: letters, balanced
# EMNIST 'letters' targets are 1-based; shift them to 0-based class ids.
trainX, trainy = transform_data(train_dataset.data), (train_dataset.targets-1)
# Rebind the name to a plain TensorDataset of the preprocessed tensors.
train_dataset = TensorDataset(trainX, trainy)

################ test dataset ################################
test_dataset = EMNIST(DATA_PATH, split='letters', train=False,
                      download=True)  # alternatives: letters, balanced
testX, testy = transform_data(test_dataset.data), (test_dataset.targets-1)
test_dataset = TensorDataset(testX, testy)

# Union of the preprocessed train and test sets; used only by the
# policy-training split below.
full_dataset = concat_datasets(train_dataset, test_dataset)


def get_data_splits():
    """Return (train, validation, test) splits for task-model training.

    Carves a validation set of size prop.VAL_SIZE out of the module-level
    train_dataset; the test set is returned unchanged.
    """
    validation_dataset, split_train_dataset = split_dataset(train_dataset, prop.VAL_SIZE)
    return split_train_dataset, validation_dataset, test_dataset


def get_policy_training_splits():
    """Return (train, validation, test) splits for policy training.

    Splits are carved out of full_dataset (train + test combined): first a
    policy test set of size prop.POLICY_TEST_SIZE, then a validation set of
    size prop.VAL_SIZE from the remainder. The local names deliberately
    shadow the module-level train_dataset/test_dataset; the globals are
    not modified.
    """
    test_dataset, train_dataset = split_dataset(full_dataset, prop.POLICY_TEST_SIZE)
    validation_dataset, split_train_dataset = split_dataset(train_dataset, prop.VAL_SIZE)
    return split_train_dataset, validation_dataset, test_dataset
def run_episode(strategies, policy, beta, device, num_worker):
    """Roll out one DAgger-style episode of active-learning acquisitions.

    With probability ``1 - beta`` the learned ``policy`` chooses the points to
    acquire; otherwise the expert does. The expert is consulted on every step
    regardless, so each acquisition yields a (state, action) training pair for
    the policy.

    Args:
        strategies: strategy classes — a flat list for SINGLE_HEAD /
            CLUSTERING_AUX_LOSS_HEAD, or an (uncertainty, diversity) pair of
            lists for CLUSTER_EXPERT_HEAD.
        policy: acquisition policy network (moved to ``device`` only if used).
        beta: probability of following the expert instead of the policy.
        device: torch device for model/policy computation.
        num_worker: unused here; kept for interface compatibility.

    Returns:
        ``(states, actions)`` collected over prop.NUM_ACQS acquisitions.
    """
    states, actions = [], []

    # All episodes start from the same seed so every run sees identical
    # initial training data and model weights.
    reinit_seed(prop.RANDOM_SEED)
    if prop.MODEL == "MLP":
        model = MLP().apply(weights_init).to(device)
    if prop.MODEL == "CNN":
        model = CNN().apply(weights_init).to(device)
    if prop.MODEL == "RESNET18":
        model = models.resnet.ResNet18().to(device)
    init_weights = deepcopy(model.state_dict())  # re-init seed was here before

    # DAgger mixing: act with the learner with probability (1 - beta).
    use_learner = True if np.random.rand(1) > beta else False
    if use_learner:
        policy = policy.to(device)  # load policy only when learner is used

    dataset_pool, valid_dataset, test_dataset = get_policy_training_splits()
    train_dataset, pool_dataset = stratified_split_dataset(
        dataset_pool, prop.INIT_SIZE, prop.NUM_CLASSES)  # initial sampling

    if prop.SINGLE_HEAD:
        my_strategies = []
        for StrategyClass in strategies:
            my_strategies.append(
                StrategyClass(dataset_pool, valid_dataset, test_dataset))
    if prop.CLUSTER_EXPERT_HEAD:
        UncertaintyStrategieClasses, DiversityStrategieClasses = strategies
        un_strategies = []
        di_strategies = []
        for StrategyClass in UncertaintyStrategieClasses:
            un_strategies.append(
                StrategyClass(dataset_pool, valid_dataset, test_dataset))
        for StrategyClass in DiversityStrategieClasses:
            di_strategies.append(
                StrategyClass(dataset_pool, valid_dataset, test_dataset))
    if prop.CLUSTERING_AUX_LOSS_HEAD:
        my_strategies = []
        for StrategyClass in strategies:
            my_strategies.append(
                StrategyClass(dataset_pool, valid_dataset, test_dataset))

    # Train on the initial labelled pool; the accuracy itself is unused, the
    # training side effect on `model` is what matters here.
    init_acc = train_validate_model(model, device, train_dataset,
                                    valid_dataset, test_dataset)

    t = trange(1, prop.NUM_ACQS + 1,
               desc="Aquisitions (size {})".format(prop.ACQ_SIZE), leave=True)
    for acq_num in t:
        # Each acquisition works on a random K-sized window of the pool.
        subset_ind = np.random.choice(a=len(pool_dataset), size=prop.K, replace=False)
        pool_subset = make_tensordataset(pool_dataset, subset_ind)

        if prop.CLUSTER_EXPERT_HEAD:
            # expert() returns (best_indices, accs, queries); only the indices
            # are needed here (the original bound the whole tuple, which broke
            # the numpy indexing into subset_ind further down).
            un_sel_ind, _, _ = expert(acq_num, model, init_weights, un_strategies,
                                      train_dataset, pool_subset,
                                      valid_dataset, test_dataset, device)
            # BUGFIX: the diversity expert must query di_strategies — the
            # original passed un_strategies twice and never used di_strategies.
            di_sel_ind, _, _ = expert(acq_num, model, init_weights, di_strategies,
                                      train_dataset, pool_subset,
                                      valid_dataset, test_dataset, device)
            state, action = get_state_action(model, train_dataset, pool_subset,
                                             un_sel_ind=un_sel_ind,
                                             di_sel_ind=di_sel_ind)
            # BUGFIX: define sel_ind for the expert (non-learner) path; the
            # original left it unassigned, raising NameError at sel_ind[0].
            sel_ind = (un_sel_ind, di_sel_ind)
        if prop.SINGLE_HEAD:
            sel_ind, _, _ = expert(acq_num, model, init_weights, my_strategies,
                                   train_dataset, pool_subset,
                                   valid_dataset, test_dataset, device)
            state, action = get_state_action(model, train_dataset, pool_subset,
                                             sel_ind=sel_ind)
        if prop.CLUSTERING_AUX_LOSS_HEAD:
            sel_ind, _, _ = expert(acq_num, model, init_weights, my_strategies,
                                   train_dataset, pool_subset,
                                   valid_dataset, test_dataset, device)
            state, action = get_state_action(model, train_dataset, pool_subset,
                                             sel_ind=sel_ind,
                                             clustering=None)  # not implemented
        states.append(state)
        actions.append(action)

        if use_learner:
            # The policy overrides the expert's selection for the actual
            # acquisition; the expert pair above is still kept for training.
            with torch.no_grad():
                if prop.SINGLE_HEAD:
                    policy_outputs = policy(state.to(device)).flatten()
                    sel_ind = torch.topk(policy_outputs, prop.ACQ_SIZE)[1].cpu().numpy()
                if prop.CLUSTER_EXPERT_HEAD:
                    policy_output_uncertainty, policy_output_diversity = policy(
                        state.to(device))
                    # one topk for uncertainty, one topk for diversity
                    diversity_selection = torch.topk(
                        policy_output_diversity.reshape(prop.K),
                        int(prop.ACQ_SIZE / 2.0))[1].cpu().numpy()
                    uncertainty_selection = torch.topk(
                        policy_output_uncertainty.reshape(prop.K),
                        int(prop.ACQ_SIZE / 2.0))[1].cpu().numpy()
                    sel_ind = (uncertainty_selection, diversity_selection)
                if prop.CLUSTERING_AUX_LOSS_HEAD:  # not implemented
                    policy_outputs = policy(state.to(device)).flatten()
                    sel_ind = torch.topk(policy_outputs, prop.ACQ_SIZE)[1].cpu().numpy()

        if prop.SINGLE_HEAD:
            q_idxs = subset_ind[sel_ind]  # from subset to full pool
        if prop.CLUSTER_EXPERT_HEAD:
            unified_sel_ind = np.concatenate((sel_ind[0], sel_ind[1]))
            q_idxs = subset_ind[unified_sel_ind]  # from subset to full pool
        # NOTE(review): CLUSTERING_AUX_LOSS_HEAD never sets q_idxs — that head
        # is marked "not implemented" above and would fail here if enabled.

        remaining_ind = list(set(np.arange(len(pool_dataset))) - set(q_idxs))
        sel_dataset = make_tensordataset(pool_dataset, q_idxs)
        train_dataset = concat_datasets(train_dataset, sel_dataset)
        pool_dataset = make_tensordataset(pool_dataset, remaining_ind)

        test_acc = train_validate_model(model, device, train_dataset,
                                        valid_dataset, test_dataset)

    return states, actions
def active_learn(exp_num, StrategyClass, subsample):
    """Run one active-learning experiment with a single acquisition strategy.

    Trains the task model from scratch before every acquisition, then lets the
    strategy pick prop.ACQ_SIZE new points (optionally from a random K-sized
    subsample of the pool) and moves them from the pool into the training set.

    Args:
        exp_num: experiment number; seeds the data split (exp_num * 10) so
            repeated experiments use different initial data.
        StrategyClass: the acquisition strategy class to instantiate and run.
        subsample: if True, each acquisition queries a random prop.K-sized
            window of the pool instead of the whole pool.

    Returns:
        List of test accuracies, one per acquisition round (measured before
        each acquisition). Relies on a module-level ``device``.
    """
    # All strategies use the same initial training data and model weights.
    reinit_seed(prop.RANDOM_SEED)
    test_acc_list = []
    if prop.MODEL.lower() == "mlp":
        model = MLP().apply(weights_init).to(device)
    if prop.MODEL.lower() == "cnn":
        model = CNN().apply(weights_init).to(device)
    if prop.MODEL.lower() == "resnet18":
        model = models.resnet.ResNet18().to(device)
    init_weights = copy.deepcopy(model.state_dict())

    reinit_seed(exp_num * 10)
    dataset_pool, valid_dataset, test_dataset = get_data_splits()
    train_dataset, pool_dataset = stratified_split_dataset(
        dataset_pool, 2 * prop.NUM_CLASSES, prop.NUM_CLASSES)  # initial data

    strategy = StrategyClass(dataset_pool, valid_dataset, test_dataset, device)

    # Calculate the overlap of the strategy with other strategies.
    # BUGFIX: the loop variable previously shadowed the StrategyClass
    # parameter, clobbering the caller's strategy class for the rest of the
    # function body; it now uses a distinct name.
    strategies = [
        MCDropoutSampling, EnsembleSampling, EntropySampling,
        LeastConfidenceSampling, CoreSetAltSampling, BadgeSampling
    ]
    overlapping_strategies = []
    for OtherStrategyClass in strategies:
        overlapping_strategies.append(
            OtherStrategyClass(dataset_pool, valid_dataset, test_dataset))

    t = trange(1, prop.NUM_ACQS + 1,
               desc="Aquisitions (size {})".format(prop.ACQ_SIZE), leave=True)
    for acq_num in t:
        # Retrain from identical initial weights each round so accuracy
        # reflects the data, not accumulated optimization state.
        model.load_state_dict(init_weights)
        test_acc = train_validate_model(model, device, train_dataset,
                                        valid_dataset, test_dataset)
        test_acc_list.append(test_acc)

        if subsample:
            # Query a random K-sized window of the pool.
            subset_ind = np.random.choice(a=len(pool_dataset), size=prop.K,
                                          replace=False)
            pool_subset = make_tensordataset(pool_dataset, subset_ind)
            sel_ind, _ = strategy.query(prop.ACQ_SIZE, model,
                                        train_dataset, pool_subset)
            q_idxs = subset_ind[sel_ind]  # from subset to full pool
            remaining_ind = list(
                set(np.arange(len(pool_dataset))) - set(q_idxs))
            sel_dataset = make_tensordataset(pool_dataset, q_idxs)
            train_dataset = concat_datasets(train_dataset, sel_dataset)
            pool_dataset = make_tensordataset(pool_dataset, remaining_ind)
        else:
            # All strategies work on k-sized windows in semi-batch setting.
            sel_ind, remaining_ind = strategy.query(prop.ACQ_SIZE, model,
                                                    train_dataset, pool_dataset)
            sel_dataset = make_tensordataset(pool_dataset, sel_ind)
            pool_dataset = make_tensordataset(pool_dataset, remaining_ind)
            train_dataset = concat_datasets(train_dataset, sel_dataset)

        logging.info(
            "Accuracy for {} sampling and {} acquisition is {}".format(
                strategy.name, acq_num, test_acc))
    return test_acc_list