Example No. 1
    def helper(index):
        trees = []

        while index < len(tokens) and tokens[index] == "(":
            paren_count = 0
            while tokens[index] == "(":
                index += 1
                paren_count += 1

            label = tokens[index]
            index += 1

            if tokens[index] == "(":
                children, index = helper(index)
                trees.append(InternalTreebankNode(label, children))
            else:
                word = tokens[index]
                index += 1
                trees.append(LeafTreebankNode(label, word))

            while paren_count > 0:
                assert tokens[index] == ")"
                index += 1
                paren_count -= 1

        return trees, index
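
A minimal, self-contained sketch of how a nested reader like the one in Example No. 1 is typically driven. The tokenizer, the namedtuple node classes, and the read_trees wrapper below are illustrative stand-ins (not the project's own code); only the inner helper repeats the example above verbatim.

import re
from collections import namedtuple

# Stand-in node types; the real project defines richer treebank node classes.
InternalTreebankNode = namedtuple("InternalTreebankNode", ["label", "children"])
LeafTreebankNode = namedtuple("LeafTreebankNode", ["tag", "word"])

def read_trees(text):
    # Split "(", ")" and whitespace-delimited symbols into a flat token list.
    tokens = re.findall(r"\(|\)|[^\s()]+", text)

    def helper(index):
        trees = []
        while index < len(tokens) and tokens[index] == "(":
            paren_count = 0
            while tokens[index] == "(":
                index += 1
                paren_count += 1
            label = tokens[index]
            index += 1
            if tokens[index] == "(":
                children, index = helper(index)
                trees.append(InternalTreebankNode(label, children))
            else:
                word = tokens[index]
                index += 1
                trees.append(LeafTreebankNode(label, word))
            while paren_count > 0:
                assert tokens[index] == ")"
                index += 1
                paren_count -= 1
        return trees, index

    trees, index = helper(0)
    assert index == len(tokens)
    return trees

# One bracketed sentence yields a single S tree with NP and VP children.
print(read_trees("(S (NP (DT The) (NN cat)) (VP (VBZ sleeps)))"))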
Example No. 2
    def decode_from_chart_batch(self, sentences, charts_np, golds=None):
        trees = []
        scores = []
        if golds is None:
            golds = [None] * len(sentences)
        for sentence, chart_np, gold in zip(sentences, charts_np, golds):
            tree, score = self.decode_from_chart(sentence, chart_np, gold)
            trees.append(tree)
            scores.append(score)
        return trees, scores
Example No. 3
def tf_parse_batch_new(sentences):
    inp_val = charify_batch([[word for (tag, word) in sentence] for sentence in sentences])
    out_val = new_sess.run(new_out, {new_inp: inp_val})

    trees = []
    scores = []
    for snum, sentence in enumerate(sentences):
        chart_size = len(sentence) + 1
        tf_chart = out_val[snum,:chart_size,:chart_size,:]
        tree, score = parser.decode_from_chart(sentence, tf_chart)
        trees.append(tree)
        scores.append(score)
    return trees, scores
Example No. 4
def tf_parse_batch_new(sentences):
    inp_val_tokens, inp_val_mask = bertify_batch([[word for (tag, word) in sentence] for sentence in sentences])
    out_val_chart, out_val_tags = new_sess.run((new_out_chart, new_out_tags), {new_inp_tokens: inp_val_tokens, new_inp_mask: inp_val_mask})

    trees = []
    scores = []
    for snum, sentence in enumerate(sentences):
        chart_size = len(sentence) + 1
        tf_chart = out_val_chart[snum,:chart_size,:chart_size,:]
        sentence = list(zip([TAG_VOCAB[idx] for idx in out_val_tags[snum,1:chart_size]], [x[1] for x in sentence]))
        tree, score = parser.decode_from_chart(sentence, tf_chart)
        trees.append(tree)
        scores.append(score)
    return trees, scores
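
Examples No. 3 and No. 4 both slice a padded batch of label-score charts back down to one (len(sentence) + 1) x (len(sentence) + 1) chart per sentence, one row and column per fencepost. A small numpy sketch of that slicing; the shapes and names below are made up for illustration.

import numpy as np

num_labels = 4
sentences = [["The", "cat", "sleeps"], ["Go"]]
max_len = max(len(s) for s in sentences)

# Stand-in for the padded batch output returned by the session run.
out_val = np.random.rand(len(sentences), max_len + 1, max_len + 1, num_labels)

for snum, sentence in enumerate(sentences):
    chart_size = len(sentence) + 1  # one score row/column per fencepost
    chart = out_val[snum, :chart_size, :chart_size, :]
    print(snum, chart.shape)  # (4, 4, 4) for the first sentence, (2, 2, 4) for the second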
Example No. 5
    def parse_batch(self, sentences, golds=None, return_label_scores_charts=False):
        is_train = golds is not None
        self.train(is_train)
        torch.set_grad_enabled(is_train)

        if golds is None:
            golds = [None] * len(sentences)

        packed_len = sum([(len(sentence) + 2) for sentence in sentences])

        i = 0
        tag_idxs = np.zeros(packed_len, dtype=int)
        word_idxs = np.zeros(packed_len, dtype=int)
        batch_idxs = np.zeros(packed_len, dtype=int)
        for snum, sentence in enumerate(sentences):
            for (tag, word) in [(START, START)] + sentence + [(STOP, STOP)]:
                tag_idxs[i] = (0 if (not self.use_tags and self.f_tag is None)
                               else self.tag_vocab.index_or_unk(tag, TAG_UNK))
                if word not in (START, STOP):
                    count = self.word_vocab.count(word)
                    if not count or (is_train and np.random.rand() < 1 / (1 + count)):
                        word = UNK
                word_idxs[i] = self.word_vocab.index(word)
                batch_idxs[i] = snum
                i += 1
        assert i == packed_len

        batch_idxs = BatchIndices(batch_idxs)

        emb_idxs_map = {
            'tags': tag_idxs,
            'words': word_idxs,
        }
        emb_idxs = [
            from_numpy(emb_idxs_map[emb_type])
            for emb_type in self.emb_types
        ]

        if is_train and self.f_tag is not None:
            gold_tag_idxs = from_numpy(emb_idxs_map['tags'])

        extra_content_annotations = None
        if self.char_encoder is not None:
            assert isinstance(self.char_encoder, CharacterLSTM)
            max_word_len = max([max([len(word) for tag, word in sentence]) for sentence in sentences])
            # Add 2 for start/stop tokens
            max_word_len = max(max_word_len, 3) + 2
            char_idxs_encoder = np.zeros((packed_len, max_word_len), dtype=int)
            word_lens_encoder = np.zeros(packed_len, dtype=int)

            i = 0
            for snum, sentence in enumerate(sentences):
                for wordnum, (tag, word) in enumerate([(START, START)] + sentence + [(STOP, STOP)]):
                    j = 0
                    char_idxs_encoder[i, j] = self.char_vocab.index(CHAR_START_WORD)
                    j += 1
                    if word in (START, STOP):
                        char_idxs_encoder[i, j:j + 3] = self.char_vocab.index(
                            CHAR_START_SENTENCE if (word == START) else CHAR_STOP_SENTENCE
                        )
                        j += 3
                    else:
                        for char in word:
                            char_idxs_encoder[i, j] = self.char_vocab.index_or_unk(char, CHAR_UNK)
                            j += 1
                    char_idxs_encoder[i, j] = self.char_vocab.index(CHAR_STOP_WORD)
                    word_lens_encoder[i] = j + 1
                    i += 1
            assert i == packed_len

            extra_content_annotations = self.char_encoder(char_idxs_encoder, word_lens_encoder, batch_idxs)
        elif self.elmo is not None:
            # See https://github.com/allenai/allennlp/blob/c3c3549887a6b1fb0bc8abf77bc820a3ab97f788/allennlp/data/token_indexers/elmo_indexer.py#L61
            # ELMO_START_SENTENCE = 256
            # ELMO_STOP_SENTENCE = 257
            ELMO_START_WORD = 258
            ELMO_STOP_WORD = 259
            ELMO_CHAR_PAD = 260

            # Sentence start/stop tokens are added inside the ELMo module
            max_sentence_len = max([(len(sentence)) for sentence in sentences])
            max_word_len = 50
            char_idxs_encoder = np.zeros((len(sentences), max_sentence_len, max_word_len), dtype=int)

            for snum, sentence in enumerate(sentences):
                for wordnum, (tag, word) in enumerate(sentence):
                    char_idxs_encoder[snum, wordnum, :] = ELMO_CHAR_PAD

                    j = 0
                    char_idxs_encoder[snum, wordnum, j] = ELMO_START_WORD
                    j += 1
                    assert word not in (START, STOP)
                    for char_id in word.encode('utf-8', 'ignore')[:(max_word_len - 2)]:
                        char_idxs_encoder[snum, wordnum, j] = char_id
                        j += 1
                    char_idxs_encoder[snum, wordnum, j] = ELMO_STOP_WORD

                    # +1 for masking (everything that stays 0 is past the end of the sentence)
                    char_idxs_encoder[snum, wordnum, :] += 1

            char_idxs_encoder = from_numpy(char_idxs_encoder)

            elmo_out = self.elmo.forward(char_idxs_encoder)
            elmo_rep0 = elmo_out['elmo_representations'][0]
            elmo_mask = elmo_out['mask']

            elmo_annotations_packed = elmo_rep0[elmo_mask.byte()].view(packed_len, -1)

            # Apply projection to match dimensionality
            extra_content_annotations = self.project_elmo(elmo_annotations_packed)
        elif self.bert is not None:
            all_input_ids = np.zeros((len(sentences), self.bert_max_len), dtype=int)
            all_input_mask = np.zeros((len(sentences), self.bert_max_len), dtype=int)
            all_word_start_mask = np.zeros((len(sentences), self.bert_max_len), dtype=int)
            all_word_end_mask = np.zeros((len(sentences), self.bert_max_len), dtype=int)

            subword_max_len = 0
            for snum, sentence in enumerate(sentences):
                tokens = []
                word_start_mask = []
                word_end_mask = []

                tokens.append("[CLS]")
                word_start_mask.append(1)
                word_end_mask.append(1)

                if self.bert_transliterate is None:
                    cleaned_words = []
                    for _, word in sentence:
                        word = BERT_TOKEN_MAPPING.get(word, word)
                        if word == "n't" and cleaned_words:
                            cleaned_words[-1] = cleaned_words[-1] + "n"
                            word = "'t"
                        cleaned_words.append(word)
                else:
                    # When transliterating, assume that the token mapping is
                    # taken care of elsewhere
                    cleaned_words = [self.bert_transliterate(word) for _, word in sentence]

                for word in cleaned_words:
                    word_tokens = self.bert_tokenizer.tokenize(word)
                    for _ in range(len(word_tokens)):
                        word_start_mask.append(0)
                        word_end_mask.append(0)
                    word_start_mask[len(tokens)] = 1
                    word_end_mask[-1] = 1
                    tokens.extend(word_tokens)
                tokens.append("[SEP]")
                word_start_mask.append(1)
                word_end_mask.append(1)

                input_ids = self.bert_tokenizer.convert_tokens_to_ids(tokens)

                # The mask has 1 for real tokens and 0 for padding tokens. Only real
                # tokens are attended to.
                input_mask = [1] * len(input_ids)

                subword_max_len = max(subword_max_len, len(input_ids))

                all_input_ids[snum, :len(input_ids)] = input_ids
                all_input_mask[snum, :len(input_mask)] = input_mask
                all_word_start_mask[snum, :len(word_start_mask)] = word_start_mask
                all_word_end_mask[snum, :len(word_end_mask)] = word_end_mask

            all_input_ids = from_numpy(np.ascontiguousarray(all_input_ids[:, :subword_max_len]))
            all_input_mask = from_numpy(np.ascontiguousarray(all_input_mask[:, :subword_max_len]))
            all_word_start_mask = from_numpy(np.ascontiguousarray(all_word_start_mask[:, :subword_max_len]))
            all_word_end_mask = from_numpy(np.ascontiguousarray(all_word_end_mask[:, :subword_max_len]))
            all_encoder_layers, _ = self.bert(all_input_ids, attention_mask=all_input_mask)
            del _
            features = all_encoder_layers[-1]

            if self.encoder is not None:
                features_packed = features.masked_select(
                    all_word_end_mask.to(torch.uint8).unsqueeze(-1)
                ).reshape(-1, features.shape[-1])

                # For now, just project the features from the last word piece in each word
                extra_content_annotations = self.project_bert(features_packed)

        if self.encoder is not None:
            annotations, _ = self.encoder(emb_idxs, batch_idxs, extra_content_annotations=extra_content_annotations)

            if self.partitioned:
                # Rearrange the annotations to ensure that the transition to
                # fenceposts captures an even split between position and content.
                # TODO(nikita): try alternatives, such as omitting position entirely
                annotations = torch.cat([
                    annotations[:, 0::2],
                    annotations[:, 1::2],
                ], 1)

            if self.f_tag is not None:
                tag_annotations = annotations

            fencepost_annotations = torch.cat([
                annotations[:-1, :self.d_model // 2],
                annotations[1:, self.d_model // 2:],
            ], 1)
            fencepost_annotations_start = fencepost_annotations
            fencepost_annotations_end = fencepost_annotations
        else:
            assert self.bert is not None
            features = self.project_bert(features)
            fencepost_annotations_start = features.masked_select(
                all_word_start_mask.to(torch.uint8).unsqueeze(-1)).reshape(-1, features.shape[-1])
            fencepost_annotations_end = features.masked_select(all_word_end_mask.to(torch.uint8).unsqueeze(-1)).reshape(
                -1, features.shape[-1])
            if self.f_tag is not None:
                tag_annotations = fencepost_annotations_end

        if self.f_tag is not None:
            tag_logits = self.f_tag(tag_annotations)
            if is_train:
                tag_loss = self.tag_loss_scale * nn.functional.cross_entropy(tag_logits, gold_tag_idxs, reduction='sum')

        # Note that the subtraction above creates fenceposts at sentence
        # boundaries, which are not used by our parser. Hence subtract 1
        # when creating fp_endpoints
        fp_startpoints = batch_idxs.boundaries_np[:-1]
        fp_endpoints = batch_idxs.boundaries_np[1:] - 1

        # Just return the charts, for ensembling
        if return_label_scores_charts:
            charts = []
            for i, (start, end) in enumerate(zip(fp_startpoints, fp_endpoints)):
                chart = self.label_scores_from_annotations(fencepost_annotations_start[start:end, :],
                                                           fencepost_annotations_end[start:end, :])
                charts.append(chart.cpu().data.numpy())
            return charts

        if not is_train:
            trees = []
            scores = []
            if self.f_tag is not None:
                # Note that tag_logits includes tag predictions for start/stop tokens
                tag_idxs = torch.argmax(tag_logits, -1).cpu()
                per_sentence_tag_idxs = torch.split_with_sizes(tag_idxs, [len(sentence) + 2 for sentence in sentences])
                per_sentence_tags = [[self.tag_vocab.value(idx) for idx in idxs[1:-1]] for idxs in
                                     per_sentence_tag_idxs]

            for i, (start, end) in enumerate(zip(fp_startpoints, fp_endpoints)):
                sentence = sentences[i]
                if self.f_tag is not None:
                    sentence = list(zip(per_sentence_tags[i], [x[1] for x in sentence]))
                tree, score = self.parse_from_annotations(fencepost_annotations_start[start:end, :],
                                                          fencepost_annotations_end[start:end, :], sentence, golds[i])
                trees.append(tree)
                scores.append(score)
            return trees, scores

        # During training time, the forward pass needs to be computed for every
        # cell of the chart, but the backward pass only needs to be computed for
        # cells in either the predicted or the gold parse tree. It's slightly
        # faster to duplicate the forward pass for a subset of the chart than it
        # is to perform a backward pass that doesn't take advantage of sparsity.
        # Since this code is not undergoing algorithmic changes, it makes sense
        # to include the optimization even though it may only be a 10% speedup.
        # Note that no dropout occurs in the label portion of the network
        pis = []
        pjs = []
        plabels = []
        paugment_total = 0.0
        num_p = 0
        gis = []
        gjs = []
        glabels = []
        with torch.no_grad():
            for i, (start, end) in enumerate(zip(fp_startpoints, fp_endpoints)):
                p_i, p_j, p_label, p_augment, g_i, g_j, g_label = self.parse_from_annotations(
                    fencepost_annotations_start[start:end, :], fencepost_annotations_end[start:end, :], sentences[i],
                    golds[i])
                paugment_total += p_augment
                num_p += p_i.shape[0]
                pis.append(p_i + start)
                pjs.append(p_j + start)
                gis.append(g_i + start)
                gjs.append(g_j + start)
                plabels.append(p_label)
                glabels.append(g_label)

        cells_i = from_numpy(np.concatenate(pis + gis))
        cells_j = from_numpy(np.concatenate(pjs + gjs))
        cells_label = from_numpy(np.concatenate(plabels + glabels))

        cells_label_scores = self.f_label(fencepost_annotations_end[cells_j] - fencepost_annotations_start[cells_i])
        cells_label_scores = torch.cat([
            cells_label_scores.new_zeros((cells_label_scores.size(0), 1)),
            cells_label_scores
        ], 1)
        cells_scores = torch.gather(cells_label_scores, 1, cells_label[:, None])
        loss = cells_scores[:num_p].sum() - cells_scores[num_p:].sum() + paugment_total

        if self.f_tag is not None:
            return None, (loss, tag_loss)
        else:
            return None, loss
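
A toy sketch of the word_start_mask / word_end_mask bookkeeping in the BERT branch of Example No. 5 (the same bookkeeping reappears in Example No. 6): each word contributes exactly one 1 at its first subword position and one at its last, so per-word features can later be gathered with masked_select. The fake_tokenize table below is a made-up stand-in, not the real bert_tokenizer.

def fake_tokenize(word):
    # Pretend subword vocabulary: most words stay whole, a few split.
    table = {"sleeps": ["sleep", "##s"], "unhappily": ["un", "##happi", "##ly"]}
    return table.get(word, [word])

def subword_masks(words):
    tokens, word_start_mask, word_end_mask = ["[CLS]"], [1], [1]
    for word in words:
        word_tokens = fake_tokenize(word)
        for _ in word_tokens:
            word_start_mask.append(0)
            word_end_mask.append(0)
        word_start_mask[len(tokens)] = 1  # first subword of this word
        word_end_mask[-1] = 1             # last subword of this word
        tokens.extend(word_tokens)
    tokens.append("[SEP]")
    word_start_mask.append(1)
    word_end_mask.append(1)
    return tokens, word_start_mask, word_end_mask

print(subword_masks(["The", "cat", "sleeps", "unhappily"]))
# tokens: ['[CLS]', 'The', 'cat', 'sleep', '##s', 'un', '##happi', '##ly', '[SEP]']
# start:  [1, 1, 1, 1, 0, 1, 0, 0, 1]
# end:    [1, 1, 1, 0, 1, 0, 0, 1, 1]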
Example No. 6
    def parse_batch(self, sentences, golds=None):
        is_train = golds is not None
        self.train(is_train)
        torch.set_grad_enabled(is_train)

        if golds is None:
            golds = [None] * len(sentences)

        packed_len = sum([(len(sentence) + 2) for sentence in sentences])

        i = 0
        tag_idxs = np.zeros(packed_len, dtype=int)
        word_idxs = np.zeros(packed_len, dtype=int)
        batch_idxs = np.zeros(packed_len, dtype=int)
        for snum, sentence in enumerate(sentences):
            for (tag, word) in [(START, START)] + sentence + [(STOP, STOP)]:
                tag_idxs[i] = (0 if not self.use_tags else
                               self.tag_vocab.index_or_unk(tag, TAG_UNK))
                if word not in (START, STOP):
                    count = self.word_vocab.count(word)
                    if not count or (is_train and np.random.rand() < 1 /
                                     (1 + count)):
                        word = UNK
                word_idxs[i] = self.word_vocab.index(word)
                batch_idxs[i] = snum
                i += 1
        assert i == packed_len

        batch_idxs = BatchIndices(batch_idxs)

        emb_idxs_map = {
            'tags': tag_idxs,
            'words': word_idxs,
        }
        emb_idxs = [
            from_numpy(emb_idxs_map[emb_type]) for emb_type in self.emb_types
        ]

        all_input_ids = np.zeros((len(sentences), self.bert_max_len),
                                 dtype=int)
        all_input_mask = np.zeros((len(sentences), self.bert_max_len),
                                  dtype=int)
        all_word_start_mask = np.zeros((len(sentences), self.bert_max_len),
                                       dtype=int)
        all_word_end_mask = np.zeros((len(sentences), self.bert_max_len),
                                     dtype=int)

        subword_max_len = 0
        for snum, sentence in enumerate(sentences):
            tokens = []
            word_start_mask = []
            word_end_mask = []

            tokens.append("[CLS]")
            word_start_mask.append(1)
            word_end_mask.append(1)

            if self.bert_transliterate is None:
                cleaned_words = []
                for _, word in sentence:
                    word = BERT_TOKEN_MAPPING.get(word, word)
                    # This un-escaping for / and * was not yet added for the
                    # parser version in https://arxiv.org/abs/1812.11760v1
                    # and related model releases (e.g. benepar_en2)
                    word = word.replace('\\/', '/').replace('\\*', '*')
                    # Mid-token punctuation occurs in biomedical text
                    word = word.replace('-LSB-', '[').replace('-RSB-', ']')
                    word = word.replace('-LRB-', '(').replace('-RRB-', ')')
                    if word == "n't" and cleaned_words:
                        cleaned_words[-1] = cleaned_words[-1] + "n"
                        word = "'t"
                    cleaned_words.append(word)
            else:
                # When transliterating, assume that the token mapping is
                # taken care of elsewhere
                cleaned_words = [
                    self.bert_transliterate(word) for _, word in sentence
                ]

            for word in cleaned_words:
                word_tokens = self.bert_tokenizer.tokenize(word)
                for _ in range(len(word_tokens)):
                    word_start_mask.append(0)
                    word_end_mask.append(0)
                word_start_mask[len(tokens)] = 1
                word_end_mask[-1] = 1
                tokens.extend(word_tokens)
            tokens.append("[SEP]")
            word_start_mask.append(1)
            word_end_mask.append(1)

            input_ids = self.bert_tokenizer.convert_tokens_to_ids(tokens)

            # The mask has 1 for real tokens and 0 for padding tokens. Only real
            # tokens are attended to.
            input_mask = [1] * len(input_ids)

            subword_max_len = max(subword_max_len, len(input_ids))

            all_input_ids[snum, :len(input_ids)] = input_ids
            all_input_mask[snum, :len(input_mask)] = input_mask
            all_word_start_mask[snum, :len(word_start_mask)] = word_start_mask
            all_word_end_mask[snum, :len(word_end_mask)] = word_end_mask

        all_input_ids = from_numpy(
            np.ascontiguousarray(all_input_ids[:, :subword_max_len]))
        all_input_mask = from_numpy(
            np.ascontiguousarray(all_input_mask[:, :subword_max_len]))
        all_word_start_mask = from_numpy(
            np.ascontiguousarray(all_word_start_mask[:, :subword_max_len]))
        all_word_end_mask = from_numpy(
            np.ascontiguousarray(all_word_end_mask[:, :subword_max_len]))
        # all_encoder_layers, _ = self.bert(all_input_ids, attention_mask=all_input_mask)
        #
        # del _
        # features = all_encoder_layers[-1]
        features = self.bert(all_input_ids, attention_mask=all_input_mask)[0]

        features_packed = \
            features.masked_select(all_word_end_mask.to(torch.bool).unsqueeze(-1)).reshape(-1, features.shape[-1])

        # For now, just project the features from the last word piece in each word
        extra_content_annotations = self.project_bert(features_packed)

        annotations, _ = self.encoder(
            emb_idxs,
            batch_idxs,
            extra_content_annotations=extra_content_annotations)

        if self.partitioned:
            # Rearrange the annotations to ensure that the transition to
            # fenceposts captures an even split between position and content.
            # TODO(nikita): try alternatives, such as omitting position entirely
            annotations = torch.cat([
                annotations[:, 0::2],
                annotations[:, 1::2],
            ], 1)

        fencepost_annotations = torch.cat([
            annotations[:-1, :self.d_model // 2],
            annotations[1:, self.d_model // 2:],
        ], 1)
        fencepost_annotations_start = fencepost_annotations
        fencepost_annotations_end = fencepost_annotations

        # Note that the subtraction above creates fenceposts at sentence
        # boundaries, which are not used by our parser. Hence subtract 1
        # when creating fp_endpoints
        fp_startpoints = batch_idxs.boundaries_np[:-1]
        fp_endpoints = batch_idxs.boundaries_np[1:] - 1

        if not is_train:
            trees = []
            scores = []

            for i, (start, end) in enumerate(zip(fp_startpoints,
                                                 fp_endpoints)):
                sentence = sentences[i]
                tree, score = self.parse_from_annotations(
                    fencepost_annotations_start[start:end, :],
                    fencepost_annotations_end[start:end, :], sentence,
                    golds[i])
                trees.append(tree)
                scores.append(score)
            return trees, scores
        else:
            loss_total = None
            for i, (start, end) in enumerate(zip(fp_startpoints,
                                                 fp_endpoints)):
                tree, loss = self.parse_from_annotations(
                    fencepost_annotations_start[start:end, :],
                    fencepost_annotations_end[start:end, :], sentences[i],
                    golds[i])
                if loss_total is None:
                    loss_total = loss
                else:
                    loss_total = loss_total + loss
            return None, loss_total
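
A small torch sketch of the fencepost construction shared by the parse_batch variants in Examples No. 5 and No. 6 above (and No. 7 below): each fencepost concatenates the first half of the annotation to its left with the second half of the annotation to its right, so n packed token annotations yield n - 1 fencepost vectors. The sizes below are illustrative.

import torch

d_model = 8
n_tokens = 5  # START + 3 words + STOP for one sentence
annotations = torch.randn(n_tokens, d_model)

fencepost_annotations = torch.cat([
    annotations[:-1, :d_model // 2],   # first half, taken from the left token
    annotations[1:, d_model // 2:],    # second half, taken from the right token
], 1)
print(fencepost_annotations.shape)  # torch.Size([4, 8]); 4 fenceposts for 3 words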
Example No. 7
    def parse_batch(self,
                    sentences,
                    golds=None,
                    return_label_scores_charts=False):
        is_train = golds is not None
        self.train(is_train)

        if golds is None:
            golds = [None] * len(sentences)

        packed_len = sum([(len(sentence) + 2) for sentence in sentences])

        i = 0
        tag_idxs = np.zeros(packed_len, dtype=int)
        word_idxs = np.zeros(packed_len, dtype=int)
        batch_idxs = np.zeros(packed_len, dtype=int)
        for snum, sentence in enumerate(sentences):
            for (tag, word) in [(START, START)] + sentence + [(STOP, STOP)]:
                tag_idxs[i] = (0 if not self.use_tags else
                               self.tag_vocab.index_or_unk(tag, TAG_UNK))
                if word not in (START, STOP):
                    count = self.word_vocab.count(word)
                    if not count or (is_train and np.random.rand() < 1 /
                                     (1 + count)):
                        word = UNK
                word_idxs[i] = self.word_vocab.index(word)
                batch_idxs[i] = snum
                i += 1
        assert i == packed_len

        batch_idxs = BatchIndices(batch_idxs)

        emb_idxs_map = {
            'tags': tag_idxs,
            'words': word_idxs,
        }
        emb_idxs = [
            Variable(from_numpy(emb_idxs_map[emb_type]),
                     requires_grad=False,
                     volatile=not is_train) for emb_type in self.emb_types
        ]

        extra_content_annotations = None
        if self.char_encoder is not None:
            assert isinstance(self.char_encoder, CharacterLSTM)
            max_word_len = max([
                max([len(word) for tag, word in sentence])
                for sentence in sentences
            ])
            # Add 2 for start/stop tokens
            max_word_len = max(max_word_len, 3) + 2
            char_idxs_encoder = np.zeros((packed_len, max_word_len), dtype=int)
            word_lens_encoder = np.zeros(packed_len, dtype=int)

            i = 0
            for snum, sentence in enumerate(sentences):
                for wordnum, (tag, word) in enumerate(
                        [(START, START)] + sentence + [(STOP, STOP)]):
                    j = 0
                    char_idxs_encoder[i, j] = self.char_vocab.index(
                        CHAR_START_WORD)
                    j += 1
                    if word in (START, STOP):
                        char_idxs_encoder[i, j:j + 3] = self.char_vocab.index(
                            CHAR_START_SENTENCE if (
                                word == START) else CHAR_STOP_SENTENCE)
                        j += 3
                    else:
                        for char in word:
                            char_idxs_encoder[i, j] = \
                                self.char_vocab.index_or_unk(char, CHAR_UNK)
                            j += 1
                    char_idxs_encoder[i, j] = self.char_vocab.index(
                        CHAR_STOP_WORD)
                    word_lens_encoder[i] = j + 1
                    i += 1
            assert i == packed_len

            extra_content_annotations = self.char_encoder(
                char_idxs_encoder, word_lens_encoder, batch_idxs)
        elif self.char_embedding is not None:
            char_idxs_encoder = np.zeros((packed_len, self.num_chars_flat),
                                         dtype=int)

            i = 0
            for snum, sentence in enumerate(sentences):
                for wordnum, (tag, word) in enumerate(
                        [(START, START)] + sentence + [(STOP, STOP)]):
                    if word == START:
                        char_idxs_encoder[i, :] = self.char_vocab.index(
                            CHAR_START_SENTENCE)
                    elif word == STOP:
                        char_idxs_encoder[i, :] = self.char_vocab.index(
                            CHAR_STOP_SENTENCE)
                    else:
                        word_chars = (
                            [self.char_vocab.index(CHAR_START_WORD)] * self.num_chars_flat
                            + [self.char_vocab.index_or_unk(char, CHAR_UNK) for char in word]
                            + [self.char_vocab.index(CHAR_STOP_WORD)] * self.num_chars_flat)
                        char_idxs_encoder[i, :self.num_chars_flat // 2] = word_chars[
                            self.num_chars_flat:self.num_chars_flat + self.num_chars_flat // 2]
                        char_idxs_encoder[i, self.num_chars_flat // 2:] = word_chars[::-1][
                            self.num_chars_flat:self.num_chars_flat + self.num_chars_flat // 2]
                    i += 1
            assert i == packed_len

            char_idxs_encoder = Variable(from_numpy(char_idxs_encoder),
                                         requires_grad=False,
                                         volatile=not is_train)

            extra_content_annotations = self.char_embedding(char_idxs_encoder)
            extra_content_annotations = extra_content_annotations.view(
                -1, self.num_chars_flat * self.char_embedding.embedding_dim)
        if self.elmo is not None:
            # See https://github.com/allenai/allennlp/blob/c3c3549887a6b1fb0bc8abf77bc820a3ab97f788/allennlp/data/token_indexers/elmo_indexer.py#L61
            # ELMO_START_SENTENCE = 256
            # ELMO_STOP_SENTENCE = 257
            ELMO_START_WORD = 258
            ELMO_STOP_WORD = 259
            ELMO_CHAR_PAD = 260

            # Sentence start/stop tokens are added inside the ELMo module
            max_sentence_len = max([(len(sentence)) for sentence in sentences])
            max_word_len = 50
            char_idxs_encoder = np.zeros(
                (len(sentences), max_sentence_len, max_word_len), dtype=int)

            for snum, sentence in enumerate(sentences):
                for wordnum, (tag, word) in enumerate(sentence):
                    char_idxs_encoder[snum, wordnum, :] = ELMO_CHAR_PAD

                    j = 0
                    char_idxs_encoder[snum, wordnum, j] = ELMO_START_WORD
                    j += 1
                    assert word not in (START, STOP)
                    for char_id in word.encode('utf-8',
                                               'ignore')[:(max_word_len - 2)]:
                        char_idxs_encoder[snum, wordnum, j] = char_id
                        j += 1
                    char_idxs_encoder[snum, wordnum, j] = ELMO_STOP_WORD

                    # +1 for masking (everything that stays 0 is past the end of the sentence)
                    char_idxs_encoder[snum, wordnum, :] += 1

            char_idxs_encoder = Variable(from_numpy(char_idxs_encoder),
                                         requires_grad=False)

            elmo_out = self.elmo.forward(char_idxs_encoder)
            elmo_rep0 = elmo_out['elmo_representations'][0]
            elmo_mask = elmo_out['mask']

            elmo_annotations_packed = elmo_rep0[elmo_mask.byte().unsqueeze(
                -1)].view(packed_len, -1)

            # Apply projection to match dimensionality
            extra_content_annotations = self.project_elmo(
                elmo_annotations_packed)

        annotations, _ = self.encoder(
            emb_idxs,
            batch_idxs,
            extra_content_annotations=extra_content_annotations)

        if self.partitioned:
            # Rearrange the annotations to ensure that the transition to
            # fenceposts captures an even split between position and content.
            # TODO(nikita): try alternatives, such as omitting position entirely
            annotations = torch.cat([
                annotations[:, 0::2],
                annotations[:, 1::2],
            ], 1)

        fencepost_annotations = torch.cat([
            annotations[:-1, :self.d_model // 2],
            annotations[1:, self.d_model // 2:],
        ], 1)
        # Note that the subtraction above creates fenceposts at sentence
        # boundaries, which are not used by our parser. Hence subtract 1
        # when creating fp_endpoints
        fp_startpoints = batch_idxs.boundaries_np[:-1]
        fp_endpoints = batch_idxs.boundaries_np[1:] - 1

        # Just return the charts, for ensembling
        if return_label_scores_charts:
            charts = []
            for i, (start, end) in enumerate(zip(fp_startpoints,
                                                 fp_endpoints)):
                chart = self.label_scores_from_annotations(
                    fencepost_annotations[start:end, :])
                charts.append(chart.cpu().data.numpy())
            return charts

        if not is_train:
            trees = []
            scores = []
            for i, (start, end) in enumerate(zip(fp_startpoints,
                                                 fp_endpoints)):
                tree, score = self.parse_from_annotations(
                    fencepost_annotations[start:end, :], sentences[i],
                    golds[i])
                trees.append(tree)
                scores.append(score)
            return trees, scores

        # During training time, the forward pass needs to be computed for every
        # cell of the chart, but the backward pass only needs to be computed for
        # cells in either the predicted or the gold parse tree. It's slightly
        # faster to duplicate the forward pass for a subset of the chart than it
        # is to perform a backward pass that doesn't take advantage of sparsity.
        # Since this code is not undergoing algorithmic changes, it makes sense
        # to include the optimization even though it may only be a 10% speedup.
        # Note that no dropout occurs in the label portion of the network
        pis = []
        pjs = []
        plabels = []
        paugment_total = 0.0
        num_p = 0
        gis = []
        gjs = []
        glabels = []
        fencepost_annotations_d = fencepost_annotations.detach()
        fencepost_annotations_d.volatile = True
        for i, (start, end) in enumerate(zip(fp_startpoints, fp_endpoints)):
            p_i, p_j, p_label, p_augment, g_i, g_j, g_label = self.parse_from_annotations(
                fencepost_annotations_d[start:end, :], sentences[i], golds[i])
            paugment_total += p_augment
            num_p += p_i.shape[0]
            pis.append(p_i + start)
            pjs.append(p_j + start)
            gis.append(g_i + start)
            gjs.append(g_j + start)
            plabels.append(p_label)
            glabels.append(g_label)

        cells_i = from_numpy(np.concatenate(pis + gis))
        cells_j = from_numpy(np.concatenate(pjs + gjs))
        cells_label = from_numpy(np.concatenate(plabels + glabels))

        cells_label_scores = self.f_label(fencepost_annotations[cells_j] -
                                          fencepost_annotations[cells_i])
        cells_label_scores = torch.cat([
            Variable(torch_t.FloatTensor(cells_label_scores.size(0),
                                         1).fill_(0),
                     requires_grad=False), cells_label_scores
        ], 1)
        cells_scores = torch.gather(cells_label_scores, 1,
                                    Variable(cells_label[:, None]))
        loss = (cells_scores[:num_p].sum() - cells_scores[num_p:].sum()
                + paugment_total)
        return None, loss
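
A toy sketch of the training objective at the end of Examples No. 5 and No. 7: a zero column is prepended so the empty label (index 0) always scores 0, the score of each cell's label is gathered, and the loss is the predicted-tree total minus the gold-tree total plus the accumulated margin augmentation. All numbers below are made up.

import torch

num_p = 3       # first num_p cells come from the predicted tree, the rest from the gold tree
num_labels = 3
cells_label_scores = torch.randn(6, num_labels)
# Prepend an all-zero column so label index 0 (the empty label) scores exactly 0.
cells_label_scores = torch.cat([
    cells_label_scores.new_zeros((cells_label_scores.size(0), 1)),
    cells_label_scores,
], 1)
cells_label = torch.tensor([1, 0, 2, 1, 3, 0])
cells_scores = torch.gather(cells_label_scores, 1, cells_label[:, None])

paugment_total = 0.75  # margin augmentation accumulated during decoding
loss = cells_scores[:num_p].sum() - cells_scores[num_p:].sum() + paugment_total
print(loss)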