Example #1
def test_invalid_dataset2():
    """Test dataset is invalid as different splits contain different columns"""
    train = (("Lorem ipsum dolor sit amet", 3, 4.5),
             ("Sed ut perspiciatis unde", 4, 5.5))
    val = (("ipsum quia dolor sit", 3.5), )
    with pytest.raises(ValueError):
        t = TabularDataset(train, val)
Example #2
def autogen_dataset_with_test():
    """Dummy dataset from file with auto-generated val and given test"""
    return TabularDataset.autogen(
        'tests/data/dummy_tabular/train.csv',
        test_path='tests/data/dummy_tabular_test/test.csv',
        seed=42,
        sep=',')
Example #3
def test_dataset_transform_mixed_multiple_named_cols():
    train = (
            ("Lorem ipsum dolor sit amet", "POSITIVE"),
            ("Sed ut perspiciatis unde", "NEGATIVE"))

    class DummyField(Field):
        def setup(self, *data: np.ndarray) -> None:
            pass

        def process(self, ex1, ex2):
            return torch.tensor(0)

    transform = {
        "text": {
            "field": DummyField(),
            "columns": ['text', 'label']
        },
        "other": {
            "field": DummyField(),
            "columns": [0, 1]
        },
        "other2": {
            "field": DummyField(),
            "columns": [0, 'label']
        }
    }

    t = TabularDataset(train, transform=transform, named_columns=['text', 'label'])
    assert t.train.cols() == 3
Example #4
def test_incomplete_dataset():
    """Test dataset missing either val or test"""
    train = (("Lorem ipsum dolor sit amet", 3, 4.5),
             ("Sed ut perspiciatis unde", 4, 5.5))
    t = TabularDataset(train)

    assert len(t.val) == 0
    assert len(t.test) == 0
Example #5
def autogen_val_test_dataset_dir_ratios():
    """Dummy dataset from directory with auto-generated val and test with
    different ratios
    """
    return TabularDataset.autogen_val_test('tests/data/dummy_tabular',
                                           seed=42,
                                           sep=',',
                                           test_ratio=0.5,
                                           val_ratio=0.5)
Example #6
def test_dataset_transform_8():
    train = (("Lorem ipsum dolor sit amet", "POSITIVE"),
             ("Sed ut perspiciatis unde", "NEGATIVE"))

    transform = {"tx": {"field": LabelField(), "columns": [0, 1]}}

    with pytest.raises(TypeError):
        t = TabularDataset(train, transform=transform)
        t.train.cols()
Example #7
def autogen_dataset_ratios():
    """Dummy dataset from file with auto-generated val and test with
    different ratios
    """
    return TabularDataset.autogen('tests/data/dummy_tabular/train.csv',
                                  seed=42,
                                  sep=',',
                                  test_ratio=0.5,
                                  val_ratio=0.5)
Example #8
def test_dataset_transform_with_invalid_named_cols():
    train = (("Lorem ipsum dolor sit amet", "POSITIVE"),
             ("Sed ut perspiciatis unde", "NEGATIVE"))

    transform = {"tx": {"field": LabelField(), "columns": 'none_existent'}}

    with pytest.raises(ValueError):
        TabularDataset(train,
                       transform=transform,
                       named_columns=['text', 'label'])
Example #9
def test_dataset_transform_with_named_cols():
    train = (("Lorem ipsum dolor sit amet", "POSITIVE"),
             ("Sed ut perspiciatis unde", "NEGATIVE"))

    transform = {"tx": {"field": LabelField(), "columns": 'label'}}

    t = TabularDataset(train,
                       transform=transform,
                       named_columns=['text', 'label'])
    assert len(t.train[0]) == 1
Example #10
def autogen_dataset_ratios_with_test():
    """Dummy dataset from file with auto-generated val and given test
    with different ratios
    """
    return TabularDataset.autogen('tests/data/dummy_tabular/train.csv',
                                  test_path='tests/data/dummy_tabular_test/test.csv',
                                  seed=42,
                                  sep=',',
                                  test_ratio=0.5,  # no effect
                                  val_ratio=0.5)
Example #11
def test_dataset_transform():
    train = (("Lorem ipsum dolor sit amet", "POSITIVE"),
             ("Sed ut perspiciatis unde", "NEGATIVE"))

    transform = {"text": TextField(), "label": LabelField()}

    t = TabularDataset(train, transform=transform)

    assert hasattr(t, "text")
    assert hasattr(t, "label")

    assert t.label.vocab_size == 2
    assert t.text.vocab_size == 11
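A usage note (not part of the original example): once the dataset is built, each fitted field is exposed as an attribute and can also process raw values directly, via the same process() method that the custom field overrides in Example #3. A minimal sketch, assuming flambe's standard import paths and that TextField.process accepts a single raw string:

import torch
from flambe.dataset import TabularDataset
from flambe.field import TextField, LabelField

train = (("Lorem ipsum dolor sit amet", "POSITIVE"),
         ("Sed ut perspiciatis unde", "NEGATIVE"))
t = TabularDataset(train, transform={"text": TextField(), "label": LabelField()})

# Each transform entry becomes an attribute holding the fitted field
ids = t.text.process("Lorem ipsum")   # tensor of token indices
label = t.label.process("POSITIVE")   # tensor holding the class index
assert isinstance(ids, torch.Tensor) and isinstance(label, torch.Tensor)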
Example #12
def test_dataset_transform_3():
    train = (("Lorem ipsum dolor sit amet", "POSITIVE"),
             ("Sed ut perspiciatis unde", "NEGATIVE"))

    transform = {
        "text": {
            "columns": 0
        },
        "label": {
            "field": LabelField(),
            "columns": 1
        }
    }

    with pytest.raises(ValueError):
        TabularDataset(train, transform=transform)
Example #13
def test_dataset_transform_5():
    train = (("Lorem ipsum dolor sit amet", "POSITIVE"),
             ("Sed ut perspiciatis unde", "NEGATIVE"))

    transform = {
        "t1": {
            "field": TextField(),
            "columns": 0
        },
        "t2": {
            "field": TextField(),
            "columns": 0
        }
    }

    t = TabularDataset(train, transform=transform)
    assert t.train.cols() == 2
Example #14
def test_cache_dataset():
    """Test caching the dataset"""
    train = (("Lorem ipsum dolor sit amet", 3,
              4.5), ("Sed ut perspiciatis unde", 5,
                     5.5), ("Lorem ipsum dolor sit amet", 3,
                            4.5), ("Sed ut perspiciatis unde", 5, 5.5),
             ("Lorem ipsum dolor sit amet", 3,
              4.5), ("Sed ut perspiciatis unde", 5,
                     5.5), ("Lorem ipsum dolor sit amet", 3,
                            4.5), ("Sed ut perspiciatis unde", 5, 5.5),
             ("Lorem ipsum dolor sit amet", 3,
              4.5), ("Sed ut perspiciatis unde", 5, 5.5))

    t = TabularDataset(train, cache=True)

    assert len(t.train.cached_data) == 0
    for i, _ in enumerate(t.train):
        assert len(t.train.cached_data) == i + 1
Example #15
def test_valid_dataset():
    """Test trivial dataset build process"""
    train = (("Lorem ipsum dolor sit amet", 3, 4.5),
             ("Sed ut perspiciatis unde", 5, 5.5))
    val = (("ipsum quia dolor sit", 10, 3.5),)
    test = (("Ut enim ad minima veniam", 100, 35),)

    t = TabularDataset(train, val, test)

    assert len(t) == 4
    assert len(t.train) == 2
    assert len(t.val) == 1
    assert len(t.test) == 1

    def check(d, t):
        for i, tu in enumerate(d):
            v0, v1, v2 = tu
            assert t[i][0] == v0
            assert t[i][1] == v1
            assert t[i][2] == v2

    check(train, t.train)
    check(val, t.val)
    check(test, t.test)
Example #16
def test_invalid_dataset():
    """Test dataset is invalid as it has different columns"""
    train = (("Lorem ipsum dolor sit amet", 3, 4.5),
             ("Sed ut perspiciatis unde", 5.5))
    with pytest.raises(ValueError):
        TabularDataset(train)
Example #17
def train_dataset():
    """Dummy dataset from file"""
    return TabularDataset.from_path('tests/data/dummy_tabular/train.csv', sep=',')
Example #18
def train(args):
    """Run training."""

    global_step = 0
    best_metric = None
    best_model: Dict[str, torch.Tensor] = dict()
    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)

    writer = SummaryWriter(log_dir=args.output_dir)

    # We use flambe to do the data preprocessing
    # More info at https://flambe.ai
    print("Performing preprocessing (possibly download embeddings).")
    embeddings = args.embeddings if args.use_pretrained_embeddings else None
    text_field = TextField(lower=args.lowercase,
                           embeddings=embeddings,
                           embeddings_format='gensim')
    label_field = LabelField()
    transforms = {'text': text_field, 'label': label_field}
    dataset = TabularDataset.from_path(
        args.train_path,
        args.val_path,
        sep=',' if args.file_type == 'csv' else '\t',
        transform=transforms)

    # Create samplers
    train_sampler = EpisodicSampler(dataset.train,
                                    n_support=args.n_support,
                                    n_query=args.n_query,
                                    n_episodes=args.n_episodes,
                                    n_classes=args.n_classes)

    # The train_eval_sampler is used to compute prototypes over the full dataset
    train_eval_sampler = BaseSampler(dataset.train,
                                     batch_size=args.eval_batch_size)
    val_sampler = BaseSampler(dataset.val, batch_size=args.eval_batch_size)

    if args.device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
    else:
        device = args.device

    # Build model, criterion and optimizers
    model = PrototypicalTextClassifier(
        vocab_size=dataset.text.vocab_size,
        distance=args.distance,
        embedding_dim=args.embedding_dim,
        pretrained_embeddings=dataset.text.embedding_matrix,
        rnn_type='sru',
        n_layers=args.n_layers,
        hidden_dim=args.hidden_dim,
        freeze_pretrained_embeddings=True)

    loss_fn = nn.CrossEntropyLoss()

    parameters = (p for p in model.parameters() if p.requires_grad)
    optimizer = torch.optim.Adam(parameters, lr=args.learning_rate)

    print("Beginning training.")
    for epoch in range(args.num_epochs):

        ######################
        #       TRAIN        #
        ######################

        print(f'Epoch: {epoch}')

        model.train()

        with torch.enable_grad():
            for batch in train_sampler:
                # Zero the gradients and clear the accumulated loss
                optimizer.zero_grad()

                # Move to device
                batch = tuple(t.to(device) for t in batch)
                query, query_label, support, support_label = batch

                # Compute loss
                pred = model(query, support, support_label)
                loss = loss_fn(pred, query_label)
                loss.backward()

                # Clip gradients if necessary
                if args.max_grad_norm is not None:
                    clip_grad_norm_(model.parameters(), args.max_grad_norm)

                writer.add_scalar('Training/Loss', loss.item(), global_step)

                # Optimize
                optimizer.step()
                global_step += 1

            # Zero the gradients when exiting a train step
            optimizer.zero_grad()

        #########################
        #       EVALUATE        #
        #########################

        model.eval()

        with torch.no_grad():

            # First compute prototypes over the training data
            encodings, labels = [], []
            for text, label in train_eval_sampler:
                padding_mask = (text != model.padding_idx).byte()
                text_embeddings = model.embedding_dropout(
                    model.embedding(text))
                text_encoding = model.encoder(text_embeddings,
                                              padding_mask=padding_mask)
                labels.append(label.cpu())
                encodings.append(text_encoding.cpu())
            # Compute prototypes
            encodings = torch.cat(encodings, dim=0)
            labels = torch.cat(labels, dim=0)
            prototypes = model.compute_prototypes(encodings, labels).to(device)

            _preds, _targets = [], []
            for batch in val_sampler:
                # Move to device
                source, target = tuple(t.to(device) for t in batch)

                pred = model(source, prototypes=prototypes)
                _preds.append(pred.cpu())
                _targets.append(target.cpu())

            preds = torch.cat(_preds, dim=0)
            targets = torch.cat(_targets, dim=0)

            val_loss = loss_fn(preds, targets).item()
            val_metric = (preds.argmax(dim=1) == targets).float().mean().item()

        # Update best model
        if best_metric is None or val_metric > best_metric:
            best_metric = val_metric
            best_model_state = model.state_dict()
            for k, t in best_model_state.items():
                best_model_state[k] = t.cpu().detach()
            best_model = best_model_state

        # Log metrics
        print(f'Validation loss: {val_loss}')
        print(f'Validation accuracy: {val_metric}')
        writer.add_scalar('Validation/Loss', val_loss, epoch)
        writer.add_scalar('Validation/Accuracy', val_metric, epoch)

    # Save the best model
    print("Finisehd training.")
    torch.save(best_model, os.path.join(args.output_dir, 'model.pt'))
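train() above reads a number of attributes off args. As a reading aid (not from the original source), here is a minimal, hypothetical argparse entry point wired to exactly the attribute names the function uses; every default below is an illustrative assumption:

import argparse

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Prototypical text classifier')
    # Data and output locations
    parser.add_argument('--train_path', required=True)
    parser.add_argument('--val_path', required=True)
    parser.add_argument('--output_dir', default='outputs')
    parser.add_argument('--file_type', choices=['csv', 'tsv'], default='csv')
    # Preprocessing and embeddings
    parser.add_argument('--lowercase', action='store_true')
    parser.add_argument('--use_pretrained_embeddings', action='store_true')
    parser.add_argument('--embeddings', default=None)
    # Model hyperparameters
    parser.add_argument('--distance', default='euclidean')
    parser.add_argument('--embedding_dim', type=int, default=300)
    parser.add_argument('--n_layers', type=int, default=2)
    parser.add_argument('--hidden_dim', type=int, default=256)
    # Episodic sampling
    parser.add_argument('--n_support', type=int, default=5)
    parser.add_argument('--n_query', type=int, default=5)
    parser.add_argument('--n_episodes', type=int, default=100)
    parser.add_argument('--n_classes', type=int, default=5)
    # Optimization
    parser.add_argument('--eval_batch_size', type=int, default=32)
    parser.add_argument('--learning_rate', type=float, default=1e-3)
    parser.add_argument('--num_epochs', type=int, default=10)
    parser.add_argument('--max_grad_norm', type=float, default=None)
    parser.add_argument('--device', default=None)
    train(parser.parse_args())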
Example #19
def autogen_dataset_dir_with_test():
    """Dummy dataset from dir with auto-generated val and given test"""
    return TabularDataset.autogen('tests/data/dummy_tabular',
                                  test_path='tests/data/dummy_tabular_test',
                                  seed=42,
                                  sep=',')
Example #20
def autogen_dataset_dir():
    """Dummy dataset from directory with auto-generated val and test"""
    return TabularDataset.autogen('tests/data/dummy_tabular',
                                  seed=42,
                                  sep=',')
Example #21
def autogen_dataset():
    """Dummy dataset from file with auto-generated val and test"""
    return TabularDataset.autogen('tests/data/dummy_tabular/train.csv',
                                  seed=42,
                                  sep=',')
Example #22
def dir_dataset():
    """Dummy dataset from directory"""
    return TabularDataset.from_path('tests/data/dummy_tabular', sep=',')
Example #23
def full_dataset():
    """Dummy dataset from file"""
    return TabularDataset.from_path(train_path='tests/data/dummy_tabular/train.csv',
                                    val_path='tests/data/dummy_tabular/val.csv', sep=',')
Example #24
def test_invalid_columns():
    """Test dataset is invalid as it has different columns"""
    train = (("Lorem ipsum dolor sit amet", 3),
             ("Sed ut perspiciatis unde", 5.5))
    with pytest.raises(ValueError):
        TabularDataset(train, named_columns=['some_random_col'])
Example #25
def train_dataset_reversed():
    """Dummy dataset from file"""
    return TabularDataset.from_path('tests/data/dummy_tabular/train.csv', sep=',',
                                    columns=['label', 'text'])
Example #26
def test_named_columns():
    """Test dataset is invalid as it has different columns"""
    train = (("Lorem ipsum dolor sit amet", 3),
             ("Sed ut perspiciatis unde", 5.5))
    TabularDataset(train, named_columns=['col1', 'col2'])
Example #27
def train_dataset_no_header():
    """Dummy dataset from file"""
    return TabularDataset.from_path('tests/data/no_header_dataset.csv', sep=',',
                                    header=None)