Example #1
import os

import spacy
import torch
from torchtext.data import Field, Iterator
from torchtext.datasets import SNLI

# Assumed: the original snippet references spacy_en without defining it;
# the model name is a guess.
spacy_en = spacy.load('en_core_web_sm')


class SNLIData:  # hypothetical class name; the snippet begins mid-class

  def __init__(self, batch_size):

    self.text = Field(
        lower=True,
        tokenize=lambda x: [tok.text for tok in spacy_en.tokenizer(x)],
        batch_first=True)
    self.label = Field(sequential=False, unk_token=None, is_target=True)

    self.train, self.dev, self.test = SNLI.splits(self.text, self.label)
    self.sizes = {
        'train': len(self.train),
        'val': len(self.dev),
        'test': len(self.test)
    }
    self.text.build_vocab(self.train, self.dev)
    self.label.build_vocab(self.train)

    # Cache the GloVe vectors locally so later runs can skip rebuilding them
    vector_cache_loc = '.vector_cache/snli_vectors.pt'
    if os.path.isfile(vector_cache_loc):
      self.text.vocab.vectors = torch.load(vector_cache_loc)
    else:
      self.text.vocab.load_vectors('glove.840B.300d')
      torch.save(self.text.vocab.vectors, vector_cache_loc)

    # Batching
    self.train_iter, self.dev_iter, self.test_iter = Iterator.splits(
        (self.train, self.dev, self.test),
        batch_size=batch_size,
        device='cuda:0' if torch.cuda.is_available() else 'cpu')

    self.vocab_size = len(self.text.vocab)
    self.out_dim = len(self.label.vocab)
    self.labels = self.label.vocab.stoi
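A minimal usage sketch for Example #1, assuming the hypothetical SNLIData wrapper named above; premise and hypothesis come out batch-first because the text field sets batch_first=True.

# Sketch only: SNLIData is the assumed wrapper name from the comments above.
data = SNLIData(batch_size=32)
batch = next(iter(data.train_iter))
premise, hypothesis, label = batch.premise, batch.hypothesis, batch.label
print(premise.shape, data.vocab_size, data.out_dim)  # (32, seq_len), |V|, 3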
Example #2
from pathlib import Path

import torchtext
from torchtext.datasets import SNLI
from torchtext.vocab import GloVe


def get_dataloaders(batch_size: int, data_path: Path):
    data_path.mkdir(parents=True, exist_ok=True)
    TEXT = torchtext.data.Field(lower=True,
                                batch_first=True,
                                tokenize="spacy",
                                include_lengths=True)
    LABEL = torchtext.data.Field(sequential=False, unk_token=None)
    train_data, val_data, test_data = SNLI.splits(text_field=TEXT,
                                                  label_field=LABEL,
                                                  root=data_path)
    TEXT.build_vocab(train_data, vectors=GloVe(cache=data_path))
    LABEL.build_vocab(train_data)
    train_iter, val_iter, test_iter = torchtext.data.BucketIterator.splits(
        (train_data, val_data, test_data), batch_size=batch_size)
    return train_iter, val_iter, test_iter, TEXT.vocab
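A hedged usage sketch for Example #2: because TEXT sets include_lengths=True, each premise/hypothesis attribute unpacks into a (token_indices, lengths) pair.

# Sketch only; the .data directory is an assumption.
train_iter, val_iter, test_iter, vocab = get_dataloaders(
    batch_size=64, data_path=Path('.data'))
batch = next(iter(train_iter))
premise, premise_lengths = batch.premise  # include_lengths=True yields a pair
print(len(vocab), premise.shape)          # vocabulary size, (batch, seq_len)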
Example #3
import numpy as np
from torchtext.datasets import SNLI


def get_SNLI(text_field, label_field, percentage=None):
    """
    Returns the SNLI dataset in splits

    :param torchtext.data.Field text_field: the field that will be used for premise and hypothesis data
    :param torchtext.data.Field label_field: the field that will be used for label data
    :param float percentage: the percentage of the data to use
    :returns: the SNLI dataset in splits
    :rtype: tuple
    """
    train, dev, test = SNLI.splits(text_field, label_field)

    if percentage:
        # np.int was removed in NumPy 1.24; the built-in int does the same here
        train.examples = train.examples[:int(np.ceil(len(train) * percentage))]
        dev.examples = dev.examples[:int(np.ceil(len(dev) * percentage))]
        test.examples = test.examples[:int(np.ceil(len(test) * percentage))]

    return train, dev, test
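A usage sketch for Example #3; the field definitions are illustrative assumptions, not part of the snippet.

# Sketch only: fields are assumptions.
from torchtext.data import Field
TEXT = Field(lower=True, batch_first=True)
LABEL = Field(sequential=False, unk_token=None)
train, dev, test = get_SNLI(TEXT, LABEL, percentage=0.1)  # keep ~10% of each split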
Example #4
import nltk
import torch
from torchtext.data import Field, BucketIterator
from torchtext.datasets import SNLI

torch.backends.cudnn.deterministic = True


def tokenize(text):
    return nltk.tokenize.word_tokenize(text)


TEXT = Field(tokenize=tokenize,
             init_token='<sos>',
             eos_token='<eos>',
             include_lengths=True,
             lower=True)

LABEL = Field(tokenize=tokenize, lower=True)

train_data, valid_data, test_data = SNLI.splits(TEXT, LABEL)

TEXT.build_vocab(train_data, min_freq=2, vectors='glove.42B.300d')
LABEL.build_vocab(train_data, min_freq=2)

BATCH_SIZE = 32

# device = torch.device('cuda:1' if torch.cuda.is_available() else 'cpu')
device = torch.device('cpu')

train_iterator, valid_iterator, test_iterator = BucketIterator.splits(
    (train_data, valid_data, test_data),
    batch_size=BATCH_SIZE,
    sort_within_batch=True,
    sort_key=lambda x: x.label,  # unusual: buckets by label rather than sequence length
    device=device)
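A short iteration sketch for Example #4: include_lengths=True makes batch.premise and batch.hypothesis (tokens, lengths) pairs, and batch.label is a token tensor because LABEL is a sequential field here.

# Sketch only: shows how the iterators above are consumed.
for batch in train_iterator:
    premise, premise_lengths = batch.premise          # include_lengths=True
    hypothesis, hypothesis_lengths = batch.hypothesis
    labels = batch.label                              # sequential field, so token indices
    break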
Example #5
import shutil

import torch
from torchtext.data import Iterator, LabelField
from torchtext.datasets import SNLI
# Assumed import location in legacy torchtext; adjust to your version.
from torchtext.datasets.nli import ParsedTextField, ShiftReduceField


class TestSNLI:

    def test_snli(self):
        batch_size = 4

        # create fields
        TEXT = ParsedTextField()
        TREE = ShiftReduceField()
        LABEL = LabelField()

        # create train/val/test splits
        train, val, test = SNLI.splits(TEXT, LABEL, TREE)

        # check all are SNLI datasets
        assert type(train) == type(val) == type(test) == SNLI

        # check all have correct number of fields
        assert len(train.fields) == len(val.fields) == len(test.fields) == 5

        # check fields are the correct type
        assert type(train.fields['premise']) == ParsedTextField
        assert type(train.fields['premise_transitions']) == ShiftReduceField
        assert type(train.fields['hypothesis']) == ParsedTextField
        assert type(train.fields['hypothesis_transitions']) == ShiftReduceField
        assert type(train.fields['label']) == LabelField

        assert type(val.fields['premise']) == ParsedTextField
        assert type(val.fields['premise_transitions']) == ShiftReduceField
        assert type(val.fields['hypothesis']) == ParsedTextField
        assert type(val.fields['hypothesis_transitions']) == ShiftReduceField
        assert type(val.fields['label']) == LabelField

        assert type(test.fields['premise']) == ParsedTextField
        assert type(test.fields['premise_transitions']) == ShiftReduceField
        assert type(test.fields['hypothesis']) == ParsedTextField
        assert type(test.fields['hypothesis_transitions']) == ShiftReduceField
        assert type(test.fields['label']) == LabelField

        # check each is the correct length
        assert len(train) == 549367
        assert len(val) == 9842
        assert len(test) == 9824

        # build vocabulary
        TEXT.build_vocab(train)
        LABEL.build_vocab(train)

        # ensure vocabulary has been created
        assert hasattr(TEXT, 'vocab')
        assert hasattr(TEXT.vocab, 'itos')
        assert hasattr(TEXT.vocab, 'stoi')

        # create iterators
        train_iter, val_iter, test_iter = Iterator.splits(
            (train, val, test), batch_size=batch_size)

        # get a batch to test
        batch = next(iter(train_iter))

        # split premise and hypothesis from tuples to tensors
        premise, premise_transitions = batch.premise
        hypothesis, hypothesis_transitions = batch.hypothesis
        label = batch.label

        # check each is actually a tensor
        assert type(premise) == torch.Tensor
        assert type(premise_transitions) == torch.Tensor
        assert type(hypothesis) == torch.Tensor
        assert type(hypothesis_transitions) == torch.Tensor
        assert type(label) == torch.Tensor

        # check have the correct batch dimension
        assert premise.shape[-1] == batch_size
        assert premise_transitions.shape[-1] == batch_size
        assert hypothesis.shape[-1] == batch_size
        assert hypothesis_transitions.shape[-1] == batch_size
        assert label.shape[-1] == batch_size

        # repeat the same tests with iters instead of split
        train_iter, val_iter, test_iter = SNLI.iters(batch_size=batch_size,
                                                     trees=True)

        # get a fresh batch from the new train iterator
        batch = next(iter(train_iter))

        # split premise and hypothesis from tuples to tensors
        premise, premise_transitions = batch.premise
        hypothesis, hypothesis_transitions = batch.hypothesis
        label = batch.label

        # check each is actually a tensor
        assert type(premise) == torch.Tensor
        assert type(premise_transitions) == torch.Tensor
        assert type(hypothesis) == torch.Tensor
        assert type(hypothesis_transitions) == torch.Tensor
        assert type(label) == torch.Tensor

        # check have the correct batch dimension
        assert premise.shape[-1] == batch_size
        assert premise_transitions.shape[-1] == batch_size
        assert hypothesis.shape[-1] == batch_size
        assert hypothesis_transitions.shape[-1] == batch_size
        assert label.shape[-1] == batch_size

        # remove downloaded snli directory
        shutil.rmtree('.data/snli')
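For reference, a hedged sketch of the SNLI.iters shortcut the test exercises; per legacy torchtext it downloads the data, builds fields and vocabulary, and returns the three iterators in one call.

# Sketch only: standalone use of SNLI.iters with parse trees.
train_iter, val_iter, test_iter = SNLI.iters(batch_size=4, trees=True)
batch = next(iter(train_iter))
premise, premise_transitions = batch.premise  # trees=True adds transition sequences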
Example #6
from torchtext.datasets import SNLI


def get_snli(text_field, label_field):
    # SNLI.splits also filters out examples with the unknown label '-'
    return SNLI.splits(text_field, label_field)
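A usage sketch for Example #6, with assumed field definitions.

# Sketch only: field choices are illustrative assumptions.
from torchtext.data import Field, LabelField
TEXT = Field(lower=True, batch_first=True)
LABEL = LabelField()
train, dev, test = get_snli(TEXT, LABEL)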