import os
import pickle
import warnings
from os.path import join as pjoin

import pytest

from baal.active import ActiveLearningDataset, heuristics
from baal.active.active_loop import ActiveLearningLoop


# NOTE: the original parametrization was not shown; these pairs are illustrative.
# `max_sample` caps how many pool items a single step may consider, so with
# query_size=10 a cap of 5 labels only 5 items.
@pytest.mark.parametrize("max_sample,expected", [(-1, 10), (5, 5)])
def test_sad(max_sample, expected):
    dataset = ActiveLearningDataset(MyDataset(), make_unlabelled=lambda x: -1)
    active_loop = ActiveLearningLoop(
        dataset,
        get_probs_iter,
        heuristics.Random(),
        max_sample=max_sample,
        query_size=10,
        dummy_param=1,
    )
    dataset.label_randomly(10)
    active_loop.step()
    # Ten items were labelled up front; `expected` more after one step.
    assert len(dataset) == 10 + expected
# NOTE: the original parametrization was not shown; these heuristics are illustrative.
@pytest.mark.parametrize("heur", [heuristics.Random(), heuristics.BALD()])
def test_should_stop_iter(heur):
    dataset = ActiveLearningDataset(MyDataset(), make_unlabelled=lambda x: -1)
    active_loop = ActiveLearningLoop(dataset, get_probs_iter, heur,
                                     query_size=10, dummy_param=1)
    dataset.label_randomly(10)
    step = 0
    for _ in range(15):
        flg = active_loop.step()
        step += 1
        if not flg:
            break
    # 90 pool items / 10 per step = 9 productive steps, plus one final step
    # that finds an empty pool and returns False.
    assert step == 10
def test_file_saving(tmpdir):
    tmpdir = str(tmpdir)
    heur = heuristics.BALD()
    ds = MyDataset()
    dataset = ActiveLearningDataset(ds, make_unlabelled=lambda x: -1)
    active_loop = ActiveLearningLoop(dataset, get_probs_iter, heur,
                                     uncertainty_folder=tmpdir,
                                     query_size=10, dummy_param=1)
    dataset.label_randomly(10)
    _ = active_loop.step()

    # One uncertainty file per step, named after the pool/labelled sizes.
    assert len(os.listdir(tmpdir)) == 1
    file = pjoin(tmpdir, os.listdir(tmpdir)[0])
    assert "pool=90" in file and "labelled=10" in file
    with open(file, 'rb') as f:
        data = pickle.load(f)
    assert len(data['uncertainty']) == 90
    # The diff between the current state and the saved state is the 10 newly
    # labelled items.
    assert (data['dataset']['labelled'] != dataset.labelled).sum() == 10
def test_deprecation():
    heur = heuristics.BALD()
    ds = MyDataset()
    dataset = ActiveLearningDataset(ds, make_unlabelled=lambda x: -1)
    with warnings.catch_warnings(record=True) as w:
        # Ensure DeprecationWarnings are not filtered out before we record them.
        warnings.simplefilter("always")
        active_loop = ActiveLearningLoop(dataset, get_probs_iter, heur,
                                         ndata_to_label=10, dummy_param=1)
        assert issubclass(w[-1].category, DeprecationWarning)
        assert "ndata_to_label" in str(w[-1].message)
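# The tests above assume a `MyDataset` and a `get_probs_iter` helper that are
# not shown. Below is a minimal sketch of what they could look like: MyDataset
# must hold 100 items so that labelling 10 leaves a pool of 90 (matching the
# "pool=90" assertion), and get_probs_iter stands in for a generator-style
# prediction function, yielding per-batch MC predictions shaped
# [batch, n_classes, n_iterations]. The class/shape details are assumptions.
import numpy as np
from torch.utils.data import Dataset


class MyDataset(Dataset):
    def __init__(self):
        self.x = list(range(100))

    def __len__(self):
        return len(self.x)

    def __getitem__(self, item):
        return self.x[item]


def get_probs_iter(dataset, dummy_param=None, **kwargs):
    # Random "predictions" are enough to exercise the loop's bookkeeping.
    # `dummy_param` absorbs the extra kwarg the loop forwards in the tests.
    for start in range(0, len(dataset), 16):
        batch_size = min(16, len(dataset) - start)
        yield np.random.rand(batch_size, 10, 20)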
# Experiment script: VGG-16 with MC-Dropout on CIFAR10.
import random
from copy import deepcopy

import torch
from torch import optim
from torch.hub import load_state_dict_from_url
from torch.nn import CrossEntropyLoss
from torchvision.models import vgg16
from tqdm import tqdm

from baal.active import get_heuristic
from baal.active.active_loop import ActiveLearningLoop
from baal.bayesian.dropout import patch_module
from baal.modelwrapper import ModelWrapper


def main():
    args = parse_args()
    use_cuda = torch.cuda.is_available()
    torch.backends.cudnn.benchmark = True
    random.seed(1337)
    torch.manual_seed(1337)
    if not use_cuda:
        print("Warning: this experiment will be very slow on CPU.")

    hyperparams = vars(args)

    active_set, test_set = get_datasets(hyperparams["initial_pool"])

    heuristic = get_heuristic(hyperparams["heuristic"], hyperparams["shuffle_prop"])
    criterion = CrossEntropyLoss()
    model = vgg16(pretrained=False, num_classes=10)
    weights = load_state_dict_from_url(
        "https://download.pytorch.org/models/vgg16-397923af.pth")
    # Drop the ImageNet classifier head (classifier.6): we train on 10 classes.
    weights = {k: v for k, v in weights.items() if "classifier.6" not in k}
    model.load_state_dict(weights, strict=False)

    # Change dropout layers to MCDropout.
    model = patch_module(model)

    if use_cuda:
        model.cuda()
    optimizer = optim.SGD(model.parameters(), lr=hyperparams["lr"], momentum=0.9)

    # Wraps the model into a usable API.
    model = ModelWrapper(model, criterion)

    logs = {}
    logs["epoch"] = 0

    # For prediction we use a smaller batch size since it is slower.
    active_loop = ActiveLearningLoop(
        active_set,
        model.predict_on_dataset,
        heuristic,
        hyperparams.get("query_size", 1),
        batch_size=10,
        iterations=hyperparams["iterations"],
        use_cuda=use_cuda,
    )
    # We will reset the weights at each active learning step.
    init_weights = deepcopy(model.state_dict())

    for epoch in tqdm(range(args.epoch)):
        # Load the initial weights.
        model.load_state_dict(init_weights)
        model.train_on_dataset(
            active_set,
            optimizer,
            hyperparams["batch_size"],
            hyperparams["learning_epoch"],
            use_cuda,
        )

        # Validation!
        model.test_on_dataset(test_set, hyperparams["batch_size"], use_cuda)
        metrics = model.metrics
        should_continue = active_loop.step()
        if not should_continue:
            break

        val_loss = metrics["test_loss"].value
        logs = {
            "val": val_loss,
            "epoch": epoch,
            "train": metrics["train_loss"].value,
            "labeled_data": active_set.labelled,
            "Next Training set size": len(active_set),
        }
        print(logs)
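# main() above relies on parse_args() and get_datasets(), which are not shown.
# Below is a minimal sketch of parse_args, inferred from the hyperparameter
# keys main() actually reads; the default values are illustrative, not the
# script's originals.
import argparse


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--epoch", default=100, type=int)
    parser.add_argument("--batch_size", default=32, type=int)
    parser.add_argument("--initial_pool", default=1000, type=int)
    parser.add_argument("--query_size", default=100, type=int)
    parser.add_argument("--lr", default=0.001, type=float)
    parser.add_argument("--heuristic", default="bald", type=str)
    parser.add_argument("--iterations", default=20, type=int)
    parser.add_argument("--shuffle_prop", default=0.05, type=float)
    parser.add_argument("--learning_epoch", default=20, type=int)
    return parser.parse_args()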
# Experiment script: BERT with MC-Dropout for text classification (HuggingFace).
import random
from copy import deepcopy

import torch
from tqdm import tqdm
from transformers import BertForSequenceClassification, BertTokenizer, TrainingArguments

from baal.active import get_heuristic
from baal.active.active_loop import ActiveLearningLoop
from baal.bayesian.dropout import patch_module
from baal.transformers_trainer_wrapper import BaalTransformersTrainer


def main():
    args = parse_args()
    use_cuda = torch.cuda.is_available()
    torch.backends.cudnn.benchmark = True
    random.seed(1337)
    torch.manual_seed(1337)
    if not use_cuda:
        print('Warning: this experiment will be very slow on CPU.')

    hyperparams = vars(args)

    heuristic = get_heuristic(hyperparams['heuristic'], hyperparams['shuffle_prop'])

    model = BertForSequenceClassification.from_pretrained(
        pretrained_model_name_or_path=hyperparams['model'])
    tokenizer = BertTokenizer.from_pretrained(
        pretrained_model_name_or_path=hyperparams['model'])

    # In this example we tokenize once up front, which makes the whole process
    # faster. It is also possible to pass the tokenizer to the trainer instead.
    active_set, test_set = get_datasets(hyperparams['initial_pool'], tokenizer)

    # Change dropout layers to MCDropout.
    model = patch_module(model)
    if use_cuda:
        model.cuda()
    init_weights = deepcopy(model.state_dict())

    training_args = TrainingArguments(
        output_dir='/app/baal/results',                   # output directory
        num_train_epochs=hyperparams['learning_epoch'],   # training epochs per step
        per_device_train_batch_size=16,                   # batch size per device during training
        per_device_eval_batch_size=64,                    # batch size for evaluation
        weight_decay=0.01,                                # strength of weight decay
        logging_dir='/app/baal/logs',                     # directory for storing logs
    )

    # We wrap the HuggingFace Trainer to create an active learning trainer.
    model = BaalTransformersTrainer(model=model,
                                    args=training_args,
                                    train_dataset=active_set,
                                    eval_dataset=test_set,
                                    tokenizer=None)

    logs = {}
    logs['epoch'] = 0

    # NLP data is fast to process here, so we do not need a smaller batch size
    # for prediction.
    active_loop = ActiveLearningLoop(active_set,
                                     model.predict_on_dataset,
                                     heuristic,
                                     hyperparams.get('n_data_to_label', 1),
                                     iterations=hyperparams['iterations'])

    for epoch in tqdm(range(args.epoch)):
        # We use the default HuggingFace training setup (e.g. epoch=1); it is
        # adjustable when BaalTransformersTrainer is defined.
        model.train()

        # Validation!
        eval_metrics = model.evaluate()

        # We re-score the unlabelled pool once per active learning step (i.e.
        # every `learning_epoch` training epochs). This helps with speed while
        # not changing the quality of the uncertainty estimation.
        should_continue = active_loop.step()

        # We reset the model weights to relearn from the new training set.
        model.load_state_dict(init_weights)
        model.lr_scheduler = None
        if not should_continue:
            break

        active_logs = {'epoch': epoch,
                       'labeled_data': active_set.labelled,
                       'Next Training set size': len(active_set)}

        logs = {**eval_metrics, **active_logs}
        print(logs)
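# get_datasets() is assumed above but not shown. A minimal sketch of what it
# could look like, assuming a torch-style text dataset (`MyTextDataset` is
# hypothetical): the training split is wrapped in an ActiveLearningDataset and
# seeded with `initial_pool` randomly labelled examples.
from baal.active import ActiveLearningDataset


def get_datasets(initial_pool, tokenizer):
    train_ds = MyTextDataset(split='train', tokenizer=tokenizer)  # hypothetical helper
    test_ds = MyTextDataset(split='test', tokenizer=tokenizer)    # hypothetical helper
    active_set = ActiveLearningDataset(train_ds)
    # Start the experiment with a small, randomly labelled seed set.
    active_set.label_randomly(initial_pool)
    return active_set, test_ds


if __name__ == '__main__':
    main()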