Example #1
File: nklm.py  Project: neulab/lrlm
    def create_one_batch(self, examples: List[NKLMExample], bptt_size: int) \
            -> Tuple[List[List[Relation]], List[BatchSequence]]:
        sentences = [ex.sentence for ex in examples]
        rels = [ex.relations for ex in examples]
        rel_ids = [ex.rel_ids for ex in examples]
        copy_pos = [ex.copy_pos for ex in examples]
        surface_indices = [ex.surface_indices for ex in examples]

        max_length = max(len(s) for s in sentences)
        # minus 1 because each slice carries one extra token (the shifted target)
        n_splits = utils.ceil_div(max_length - 1, bptt_size)

        batches = []
        for split_idx in range(n_splits):
            interval = slice(split_idx * bptt_size,
                             (split_idx + 1) * bptt_size + 1)
            sequence = [s[interval] for s in sentences]
            batch = BatchSequence(
                self.word_vocab,
                sequence,
                rel_ids=([s[interval] for s in rel_ids], -1),
                copy_pos=([s[interval] for s in copy_pos], -1),
                surface_indices=([s[interval] for s in surface_indices], -1))
            batches.append(batch)
        return rels, batches
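
The max_length - 1 passed to ceil_div and the bptt_size + 1 slice width work together: every segment carries one extra token so that inputs and shifted targets can be read from the same slice. A minimal standalone sketch of that arithmetic, assuming utils.ceil_div is plain ceiling division (the helper itself is not shown in this listing):

# Minimal standalone sketch (not from the repository) of the BPTT slicing above.
# ceil_div here is an assumption about what utils.ceil_div computes.
def ceil_div(a: int, b: int) -> int:
    return -(-a // b)  # ceiling division

sentence = list(range(10))  # toy token sequence of length 10
bptt_size = 4
n_splits = ceil_div(len(sentence) - 1, bptt_size)  # 3 splits cover the 9 target positions

for split_idx in range(n_splits):
    interval = slice(split_idx * bptt_size, (split_idx + 1) * bptt_size + 1)
    segment = sentence[interval]
    # segment[:-1] are the inputs, segment[1:] the shifted targets
    print(split_idx, segment[:-1], segment[1:])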
Example #2
File: lrlm.py  Project: neulab/lrlm
    def create_one_batch(self, examples: List[LRLMExample], bptt_size: int) \
            -> Tuple[List[List[Relation]], List[BatchSequence]]:
        sentences = [x.sentence for x in examples]  # <s> is accounted for in the construct_example method
        max_length = max(len(s) for s in sentences)
        n_splits = utils.ceil_div(max_length - 1, bptt_size)
        init_batch = [ex.relations for ex in examples]

        split_spans: List[List[List[MatchedSpan]]] = [[[] for __ in examples] for _ in range(n_splits)]
        for b_idx, ex in enumerate(examples):
            for span in ex.spans:
                # The span's begin index is the position right before the expression begins.
                # The span's end index points to the last word of the entity.
                start = span.start % bptt_size
                end = (span.end - 1) % bptt_size
                # if start > end:
                #     continue
                split_spans[span.start // bptt_size][b_idx].append(span._replace(start=start, end=end))
        batches = []
        for split_idx in range(n_splits):
            interval = slice(split_idx * bptt_size, (split_idx + 1) * bptt_size + 1)
            rels_interval = split_spans[split_idx]
            sentence_interval = [sent[interval] for sent in sentences]
            batch = BatchSequence(self.word_vocab, sentence_interval, rels_interval)
            batches.append(batch)
        return init_batch, batches
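
The span bookkeeping above assigns each span to the segment its start index falls in and converts both endpoints to segment-local offsets with modular arithmetic. A toy sketch of that remapping; the MatchedSpan namedtuple below is an illustrative stand-in, not the repository's own definition:

# Toy sketch of the span re-indexing above; this MatchedSpan is an illustrative
# stand-in for the repository's own type.
from collections import namedtuple

MatchedSpan = namedtuple("MatchedSpan", ["start", "end", "rel_idx"])

bptt_size = 4
span = MatchedSpan(start=6, end=8, rel_idx=0)   # a span in the full sentence

split_idx = span.start // bptt_size             # which BPTT segment the span starts in
local = span._replace(start=span.start % bptt_size, end=(span.end - 1) % bptt_size)
print(split_idx, local)                         # 1 MatchedSpan(start=2, end=3, rel_idx=0)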
Example #3
    def create_one_batch(self, examples: List[LMExample],
                         bptt_size: int) -> Tuple[None, List[BatchSequence]]:
        sentences = [x.sentence for x in examples]
        max_length = max(len(s) for s in sentences)
        n_splits = utils.ceil_div(max_length - 1, bptt_size)

        batches = []
        for split_idx in range(n_splits):
            interval = slice(split_idx * bptt_size,
                             (split_idx + 1) * bptt_size + 1)
            sentence_interval = [s[interval] for s in sentences]
            batch = BatchSequence(self.word_vocab, sentence_interval)
            batches.append(batch)

        return None, batches
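
Note that sentences shorter than the current interval are not padded here: slicing simply returns a shorter (possibly empty) list, and any padding is presumably handled inside BatchSequence. A toy illustration with plain lists:

# Toy illustration (plain lists instead of LMExample sentences) of how shorter
# sentences behave under the same slicing.
sentences = [list(range(9)), list(range(5))]
bptt_size = 4
max_length = max(len(s) for s in sentences)      # 9
n_splits = -(-(max_length - 1) // bptt_size)     # ceiling division, as utils.ceil_div presumably does

for split_idx in range(n_splits):
    interval = slice(split_idx * bptt_size, (split_idx + 1) * bptt_size + 1)
    print([s[interval] for s in sentences])
# split 0: [[0, 1, 2, 3, 4], [0, 1, 2, 3, 4]]
# split 1: [[4, 5, 6, 7, 8], [4]]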
Example #4
    def create_one_batch(self, examples: List[LRLMExample],
                         bptt_size: int) -> Tuple[Tensor, List[BatchSequence]]:
        sentences = [x.sentence for x in examples]  # <s> is accounted for in the construct_example method
        max_length = max(len(s) for s in sentences)
        n_splits = utils.ceil_div(max_length - 1, bptt_size)
        init_batch = ([(ex.spans, ex.relations) for ex in examples], self.vocab)

        mentioned_relation_ids = [
            (ex.relations, list(set(span.rel_idx for span in ex.spans)))
            for ex in examples]

        # Linearized alias surface forms for relations that actually appear in the article
        linearized_aliases: List[List[int]] = [
            [self.word_vocab.w2i.get(word, 0)
             for rel_idx, relation in enumerate(relations) if rel_idx in mentioned
             for words in relation.obj_alias
             for word in words.split(" ")]
            for relations, mentioned in mentioned_relation_ids]

        for i, aliases in enumerate(linearized_aliases):
            if len(aliases) == 0:
                linearized_aliases[i].append(0)  # Add an UNK as a dummy, if the seq_len is zero.

        padded_aliases = torch.tensor(
            pad(linearized_aliases, pad_symbol=1, append=False), dtype=torch.long)

        batches = []
        for split_idx in range(n_splits):
            interval = slice(split_idx * bptt_size,
                             (split_idx + 1) * bptt_size + 1)
            sentence_interval = [sent[interval] for sent in sentences]
            batch = BatchSequence(self.word_vocab, sentence_interval)
            batches.append(batch)

        return padded_aliases, batches
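
The pad(...) / torch.tensor(...) step turns the ragged alias lists into a rectangular LongTensor. The repository's pad helper is not shown in this listing; the sketch below is a hypothetical reading of its arguments (pad_symbol=1 as the filler, append=False taken to mean left-padding), not the actual implementation:

# Hypothetical sketch of the assumed behavior of pad(...) above; not the
# repository's helper.
from typing import List

def pad_sketch(seqs: List[List[int]], pad_symbol: int, append: bool) -> List[List[int]]:
    max_len = max(len(s) for s in seqs)
    padded = []
    for s in seqs:
        filler = [pad_symbol] * (max_len - len(s))
        padded.append(s + filler if append else filler + s)
    return padded

print(pad_sketch([[3, 7, 2], [5]], pad_symbol=1, append=False))  # [[3, 7, 2], [1, 1, 5]]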
Example #5
    def create_batches(self, batch_size: int, bptt_size: int):
        r"""A general routine to create batches of specified batch size and BPTT length.

        :param batch_size: The number of examples in one batch.
        :param bptt_size: The length for truncated backprop, i.e. the maximum length of each segment within a batch.
        """
        self.batches = {}
        self.ntokens = {}
        for split, raw_dataset in self.data.items():
            ntokens = 0
            # sort the data by document length
            parts = sorted(raw_dataset, key=len)
            num_batches = utils.ceil_div(len(parts), batch_size)
            all_batches = []
            for batch_idx in utils.progress(num_batches, desc="Creating batches",
                                            ascii=True, ncols=80):
                part = parts[(batch_idx * batch_size):((batch_idx + 1) * batch_size)]
                init_batch, batches = self.create_one_batch(part, bptt_size)
                ntokens += sum(batch.ntokens for batch in batches)
                all_batches.append((init_batch, batches))
            self.batches[split] = all_batches
            self.ntokens[split] = ntokens

        unk_probs = self.unk_probs
        if unk_probs is not None:
            total_w2i = self.total_w2i
            for split, dataset in self.batches.items():
                dataset = utils.progress(
                    dataset,
                    ncols=80,
                    desc=f"Adding unk vocab for {split} set",
                    ascii=True)
                for _, batches in dataset:
                    for batch in batches:
                        batch.add_unk_probs(unk_probs, total_w2i)
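
Sorting the documents by length before chunking them into batches means each batch groups documents of similar length, which keeps the per-batch max_length, and with it the padding and the number of BPTT splits, small. A standalone toy sketch of that effect (ceil_div is again an assumed stand-in for utils.ceil_div):

# Standalone toy sketch of the sort-then-chunk batching above.
def ceil_div(a: int, b: int) -> int:  # assumed behavior of utils.ceil_div
    return -(-a // b)

docs = ["a" * n for n in (12, 3, 30, 7, 25, 5)]   # toy "documents"
batch_size = 2
parts = sorted(docs, key=len)                     # group similar lengths together
num_batches = ceil_div(len(parts), batch_size)
for batch_idx in range(num_batches):
    part = parts[batch_idx * batch_size:(batch_idx + 1) * batch_size]
    print([len(p) for p in part])                 # [3, 5] / [7, 12] / [25, 30]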