def test_invalid_dataset2():
    """Test dataset is invalid as different splits contain different columns"""
    train = (("Lorem ipsum dolor sit amet", 3, 4.5),
             ("Sed ut perspiciatis unde", 4, 5.5))
    val = (("ipsum quia dolor sit", 3.5),)

    with pytest.raises(ValueError):
        t = TabularDataset(train, val)

def autogen_dataset_with_test():
    """Dummy dataset from file with auto-generated val and given test"""
    return TabularDataset.autogen('tests/data/dummy_tabular/train.csv',
                                  test_path='tests/data/dummy_tabular_test/test.csv',
                                  seed=42,
                                  sep=',')

def test_dataset_transform_mixed_multiple_named_cols():
    train = (("Lorem ipsum dolor sit amet", "POSITIVE"),
             ("Sed ut perspiciatis unde", "NEGATIVE"))

    class DummyField(Field):
        def setup(self, *data: np.ndarray) -> None:
            pass

        def process(self, ex1, ex2):
            return torch.tensor(0)

    transform = {
        "text": {
            "field": DummyField(),
            "columns": ['text', 'label']
        },
        "other": {
            "field": DummyField(),
            "columns": [0, 1]
        },
        "other2": {
            "field": DummyField(),
            "columns": [0, 'label']
        }
    }

    t = TabularDataset(train, transform=transform, named_columns=['text', 'label'])
    assert t.train.cols() == 3

def test_incomplete_dataset():
    """Test dataset missing either val or test"""
    train = (("Lorem ipsum dolor sit amet", 3, 4.5),
             ("Sed ut perspiciatis unde", 4, 5.5))

    t = TabularDataset(train)
    assert len(t.val) == 0
    assert len(t.test) == 0

def autogen_val_test_dataset_dir_ratios():
    """Dummy dataset from directory with auto-generated val and test
    with different ratios
    """
    return TabularDataset.autogen_val_test('tests/data/dummy_tabular',
                                           seed=42,
                                           sep=',',
                                           test_ratio=0.5,
                                           val_ratio=0.5)

def test_dataset_transform_8():
    train = (("Lorem ipsum dolor sit amet", "POSITIVE"),
             ("Sed ut perspiciatis unde", "NEGATIVE"))

    transform = {"tx": {"field": LabelField(), "columns": [0, 1]}}

    with pytest.raises(TypeError):
        t = TabularDataset(train, transform=transform)
        t.train.cols()

def autogen_dataset_ratios():
    """Dummy dataset from file with auto-generated val and test
    with different ratios
    """
    return TabularDataset.autogen('tests/data/dummy_tabular/train.csv',
                                  seed=42,
                                  sep=',',
                                  test_ratio=0.5,
                                  val_ratio=0.5)

def test_dataset_transform_with_invalid_named_cols():
    train = (("Lorem ipsum dolor sit amet", "POSITIVE"),
             ("Sed ut perspiciatis unde", "NEGATIVE"))

    transform = {"tx": {"field": LabelField(), "columns": 'none_existent'}}

    with pytest.raises(ValueError):
        TabularDataset(train, transform=transform, named_columns=['text', 'label'])

def test_dataset_transform_with_named_cols():
    train = (("Lorem ipsum dolor sit amet", "POSITIVE"),
             ("Sed ut perspiciatis unde", "NEGATIVE"))

    transform = {"tx": {"field": LabelField(), "columns": 'label'}}

    t = TabularDataset(train, transform=transform, named_columns=['text', 'label'])
    assert len(t.train[0]) == 1

def autogen_dataset_ratios_with_test():
    """Dummy dataset from file with auto-generated val and given test
    with different ratios
    """
    return TabularDataset.autogen('tests/data/dummy_tabular/train.csv',
                                  test_path='tests/data/dummy_tabular_test/test.csv',
                                  seed=42,
                                  sep=',',
                                  test_ratio=0.5,  # no effect
                                  val_ratio=0.5)

def test_dataset_transform():
    train = (("Lorem ipsum dolor sit amet", "POSITIVE"),
             ("Sed ut perspiciatis unde", "NEGATIVE"))

    transform = {"text": TextField(), "label": LabelField()}
    t = TabularDataset(train, transform=transform)

    assert hasattr(t, "text")
    assert hasattr(t, "label")

    assert t.label.vocab_size == 2
    assert t.text.vocab_size == 11

def test_dataset_transform_3():
    train = (("Lorem ipsum dolor sit amet", "POSITIVE"),
             ("Sed ut perspiciatis unde", "NEGATIVE"))

    transform = {
        "text": {
            "columns": 0
        },
        "label": {
            "field": LabelField(),
            "columns": 1
        }
    }

    with pytest.raises(ValueError):
        TabularDataset(train, transform=transform)

def test_dataset_transform_5():
    train = (("Lorem ipsum dolor sit amet", "POSITIVE"),
             ("Sed ut perspiciatis unde", "NEGATIVE"))

    transform = {
        "t1": {
            "field": TextField(),
            "columns": 0
        },
        "t2": {
            "field": TextField(),
            "columns": 0
        }
    }

    t = TabularDataset(train, transform=transform)
    assert t.train.cols() == 2

def test_cache_dataset():
    """Test caching the dataset"""
    train = (("Lorem ipsum dolor sit amet", 3, 4.5),
             ("Sed ut perspiciatis unde", 5, 5.5),
             ("Lorem ipsum dolor sit amet", 3, 4.5),
             ("Sed ut perspiciatis unde", 5, 5.5),
             ("Lorem ipsum dolor sit amet", 3, 4.5),
             ("Sed ut perspiciatis unde", 5, 5.5),
             ("Lorem ipsum dolor sit amet", 3, 4.5),
             ("Sed ut perspiciatis unde", 5, 5.5),
             ("Lorem ipsum dolor sit amet", 3, 4.5),
             ("Sed ut perspiciatis unde", 5, 5.5))

    t = TabularDataset(train, cache=True)

    assert len(t.train.cached_data) == 0
    for i, _ in enumerate(t.train):
        assert len(t.train.cached_data) == i + 1

def test_valid_dataset():
    """Test trivial dataset build process"""
    train = (("Lorem ipsum dolor sit amet", 3, 4.5),
             ("Sed ut perspiciatis unde", 5, 5.5))
    val = (("ipsum quia dolor sit", 10, 3.5),)
    test = (("Ut enim ad minima veniam", 100, 35),)

    t = TabularDataset(train, val, test)

    assert len(t) == 4
    assert len(t.train) == 2
    assert len(t.val) == 1
    assert len(t.test) == 1

    def check(d, t):
        for i, tu in enumerate(d):
            v0, v1, v2 = tu
            assert t[i][0] == v0
            assert t[i][1] == v1
            assert t[i][2] == v2

    check(train, t.train)
    check(val, t.val)
    check(test, t.test)

def test_invalid_dataset():
    """Test dataset is invalid as rows have different numbers of columns"""
    train = (("Lorem ipsum dolor sit amet", 3, 4.5),
             ("Sed ut perspiciatis unde", 5.5))

    with pytest.raises(ValueError):
        TabularDataset(train)

def train_dataset():
    """Dummy dataset from file"""
    return TabularDataset.from_path('tests/data/dummy_tabular/train.csv', sep=',')

def train(args):
    """Run training."""
    global_step = 0
    best_metric = None
    best_model: Dict[str, torch.Tensor] = dict()

    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)

    writer = SummaryWriter(log_dir=args.output_dir)

    # We use flambe to do the data preprocessing
    # More info at https://flambe.ai
    print("Performing preprocessing (possibly downloading embeddings).")
    embeddings = args.embeddings if args.use_pretrained_embeddings else None
    text_field = TextField(lower=args.lowercase,
                           embeddings=embeddings,
                           embeddings_format='gensim')
    label_field = LabelField()
    transforms = {'text': text_field, 'label': label_field}
    dataset = TabularDataset.from_path(
        args.train_path,
        args.val_path,
        sep=',' if args.file_type == 'csv' else '\t',
        transform=transforms)

    # Create samplers
    train_sampler = EpisodicSampler(dataset.train,
                                    n_support=args.n_support,
                                    n_query=args.n_query,
                                    n_episodes=args.n_episodes,
                                    n_classes=args.n_classes)

    # The train_eval_sampler is used to compute prototypes over the full dataset
    train_eval_sampler = BaseSampler(dataset.train, batch_size=args.eval_batch_size)
    val_sampler = BaseSampler(dataset.val, batch_size=args.eval_batch_size)

    if args.device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
    else:
        device = args.device

    # Build model, criterion and optimizers
    model = PrototypicalTextClassifier(
        vocab_size=dataset.text.vocab_size,
        distance=args.distance,
        embedding_dim=args.embedding_dim,
        pretrained_embeddings=dataset.text.embedding_matrix,
        rnn_type='sru',
        n_layers=args.n_layers,
        hidden_dim=args.hidden_dim,
        freeze_pretrained_embeddings=True)
    # Move the model to the target device so it matches the batches below
    model = model.to(device)

    loss_fn = nn.CrossEntropyLoss()
    parameters = (p for p in model.parameters() if p.requires_grad)
    optimizer = torch.optim.Adam(parameters, lr=args.learning_rate)

    print("Beginning training.")
    for epoch in range(args.num_epochs):

        ######################
        #       TRAIN        #
        ######################
        print(f'Epoch: {epoch}')

        model.train()
        with torch.enable_grad():
            for batch in train_sampler:
                # Zero the gradients and clear the accumulated loss
                optimizer.zero_grad()

                # Move to device
                batch = tuple(t.to(device) for t in batch)
                query, query_label, support, support_label = batch

                # Compute loss
                pred = model(query, support, support_label)
                loss = loss_fn(pred, query_label)
                loss.backward()

                # Clip gradients if necessary
                if args.max_grad_norm is not None:
                    clip_grad_norm_(model.parameters(), args.max_grad_norm)

                writer.add_scalar('Training/Loss', loss.item(), global_step)

                # Optimize
                optimizer.step()
                global_step += 1

            # Zero the gradients when exiting a train step
            optimizer.zero_grad()

        #########################
        #       EVALUATE        #
        #########################
        model.eval()
        with torch.no_grad():
            # First compute prototypes over the training data
            encodings, labels = [], []
            for text, label in train_eval_sampler:
                text = text.to(device)
                padding_mask = (text != model.padding_idx).byte()
                text_embeddings = model.embedding_dropout(model.embedding(text))
                text_encoding = model.encoder(text_embeddings, padding_mask=padding_mask)
                labels.append(label.cpu())
                encodings.append(text_encoding.cpu())

            # Compute prototypes
            encodings = torch.cat(encodings, dim=0)
            labels = torch.cat(labels, dim=0)
            prototypes = model.compute_prototypes(encodings, labels).to(device)

            _preds, _targets = [], []
            for batch in val_sampler:
                # Move to device
                source, target = tuple(t.to(device) for t in batch)

                pred = model(source, prototypes=prototypes)
                _preds.append(pred.cpu())
                _targets.append(target.cpu())

            preds = torch.cat(_preds, dim=0)
            targets = torch.cat(_targets, dim=0)

            val_loss = loss_fn(preds, targets).item()
            # Accuracy over the full validation set, not just the last batch
            val_metric = (preds.argmax(dim=1) == targets).float().mean().item()

        # Update best model
        if best_metric is None or val_metric > best_metric:
            best_metric = val_metric
            best_model_state = model.state_dict()
            for k, t in best_model_state.items():
                best_model_state[k] = t.cpu().detach()
            best_model = best_model_state

        # Log metrics
        print(f'Validation loss: {val_loss}')
        print(f'Validation accuracy: {val_metric}')
        writer.add_scalar('Validation/Loss', val_loss, epoch)
        writer.add_scalar('Validation/Accuracy', val_metric, epoch)

    # Save the best model
    print("Finished training.")
    torch.save(best_model, os.path.join(args.output_dir, 'model.pt'))

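# The evaluation step in train() relies on model.compute_prototypes(encodings, labels)
# to collapse the pooled training encodings into one prototype vector per class.
# The sketch below is a minimal illustration of that idea, assuming (as in standard
# prototypical networks) that a prototype is simply the mean encoding of a class.
# The helper name compute_class_prototypes is hypothetical and is not part of the
# model's API.
import torch


def compute_class_prototypes(encodings: torch.Tensor,
                             labels: torch.Tensor) -> torch.Tensor:
    """Return a [n_classes, hidden_dim] tensor of per-class mean encodings.

    Assumes ``labels`` holds integer class indices aligned row-wise with ``encodings``.
    """
    classes = labels.unique(sorted=True)
    # Average the encodings belonging to each class to obtain its prototype
    return torch.stack([encodings[labels == c].mean(dim=0) for c in classes], dim=0)


# Example usage (mirrors the shapes produced in the evaluation loop above):
# prototypes = compute_class_prototypes(encodings, labels)
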
def autogen_dataset_dir_with_test():
    """Dummy dataset from dir with auto-generated val and given test"""
    return TabularDataset.autogen('tests/data/dummy_tabular',
                                  test_path='tests/data/dummy_tabular_test',
                                  seed=42,
                                  sep=',')

def autogen_dataset_dir():
    """Dummy dataset from directory with auto-generated val and test"""
    return TabularDataset.autogen('tests/data/dummy_tabular',
                                  seed=42,
                                  sep=',')

def autogen_dataset():
    """Dummy dataset from file with auto-generated val and test"""
    return TabularDataset.autogen('tests/data/dummy_tabular/train.csv',
                                  seed=42,
                                  sep=',')

def dir_dataset():
    """Dummy dataset from directory"""
    return TabularDataset.from_path('tests/data/dummy_tabular', sep=',')

def full_dataset():
    """Dummy dataset from file"""
    return TabularDataset.from_path(train_path='tests/data/dummy_tabular/train.csv',
                                    val_path='tests/data/dummy_tabular/val.csv',
                                    sep=',')

def test_invalid_columns():
    """Test dataset is invalid as the named columns don't match the data"""
    train = (("Lorem ipsum dolor sit amet", 3),
             ("Sed ut perspiciatis unde", 5.5))

    with pytest.raises(ValueError):
        TabularDataset(train, named_columns=['some_random_col'])

def train_dataset_reversed():
    """Dummy dataset from file"""
    return TabularDataset.from_path('tests/data/dummy_tabular/train.csv',
                                    sep=',',
                                    columns=['label', 'text'])

def test_named_columns():
    """Test dataset builds correctly with valid named columns"""
    train = (("Lorem ipsum dolor sit amet", 3),
             ("Sed ut perspiciatis unde", 5.5))

    TabularDataset(train, named_columns=['col1', 'col2'])

def train_dataset_no_header():
    """Dummy dataset from file without a header row"""
    return TabularDataset.from_path('tests/data/no_header_dataset.csv',
                                    sep=',',
                                    header=None)