def embed_stuff(self, batch, seq_lengths):
        """
        Build tensors from the token indices

        Args:
            batch: List[Example]
            seq_lengths: (batch_size,)
        Returns:
            bert_out: (batch_size, max_seq_len + 2, seq_in_size)
        """
        batch_token_indices = []
        for example in batch:
            token_indices = example.token_indices
            # Pad with [CLS] and [SEP]
            token_indices = [self.CLS_index] + token_indices + [self.SEP_index]
            # Convert to a tensor and add to the list
            token_indices = torch.tensor(token_indices)
            batch_token_indices.append(token_indices)
        # input_ids: (batch_size, max_seq_len + 2)
        input_ids = pad_sequence(batch_token_indices, batch_first=True)
        input_ids = try_gpu(input_ids)
        # attention_mask and token_type_ids
        positions = try_gpu(torch.arange(input_ids.size()[1])[None, :])
        attention_mask = 1 * (positions < (seq_lengths + 2)[:, None])
        token_type_ids = try_gpu(torch.zeros(input_ids.size(), dtype=torch.long))
        # Encode!
        bert_output = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )
        return bert_output.last_hidden_state
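
The attention mask here is built by broadcasting a row of positions against the per-example lengths; every position past a sequence's length becomes 0. A minimal standalone sketch of that padding-and-masking trick (the tensors are made up):

import torch
from torch.nn.utils.rnn import pad_sequence

seqs = [torch.tensor([5, 6, 7]), torch.tensor([8, 9])]
seq_lengths = torch.tensor([3, 2])

input_ids = pad_sequence(seqs, batch_first=True)      # (2, 3), zero-padded
positions = torch.arange(input_ids.size(1))[None, :]  # (1, 3)
attention_mask = 1 * (positions < seq_lengths[:, None])
# tensor([[1, 1, 1],
#         [1, 1, 0]])
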
    def forward(self, logit: TreeParserOutput, batch):
        losses: List[torch.Tensor] = []

        # for each training example
        for i in range(len(logit.batch_span_scores)):
            if self.edge_loss is not None:
                decoded_spans, decoded_edges = logit.batch_span_scores[i]
            else:
                decoded_spans = logit.batch_span_scores[i]
            gold_proto_node = logit.batch_golds[i]
            gold_chains = self.tree_to_chains(gold_proto_node)
            total_loss = 0.0

            # Add node losses
            # prediction: (num_spans, num_classes)
            # where num_classes = 1 + len(self.unary_chains)
            # The first column for NULL class is a zero vector.
            prediction = torch.cat(
                [
                    try_gpu(torch.zeros(len(decoded_spans), 1)),
                    torch.stack(
                        [decoded_span.labels for decoded_span in decoded_spans]
                    ),
                ],
                dim=1,
            )
            # targets: (num_spans,)
            # The unary chain indices are shifted by 1 since we padded
            # the prediction with the NULL class at index 0.
            targets = [
                gold_chains.get((decoded_span.start, decoded_span.end), -1) + 1
                for decoded_span in decoded_spans
            ]
            targets = try_gpu(torch.tensor(targets))
            total_loss = total_loss + self.node_loss(prediction, targets)

            # Add edge losses
            if self.edge_loss is not None:
                span_edges = {(span.start, span.end): span for span in decoded_edges}
                gold_edges = self.tree_to_edges(gold_proto_node)
                if not gold_edges:
                    edge_loss = 0.0
                else:
                    # prediction: (num_gold_edges, num_labels)
                    prediction = []
                    # targets: (num_gold_edges,)
                    targets = []
                    for key, (child_label, parent_label) in gold_edges.items():
                        prediction.append(span_edges[key].edges[child_label])
                        targets.append(parent_label)
                    edge_loss = self.edge_loss(
                        torch.stack(prediction), try_gpu(torch.tensor(targets))
                    )
                total_loss = total_loss + edge_loss

            losses.append(total_loss)

        stacked_losses = torch.stack(losses)
        return stacked_losses.mean()
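
The node-loss construction pins the NULL class's logit at zero and shifts every gold chain index up by one, so spans absent from the gold tree get target 0. A standalone sketch of just that padding step (all sizes and names are illustrative):

import torch
import torch.nn as nn

num_spans, num_chains = 4, 3
span_scores = torch.randn(num_spans, num_chains)      # per-chain logits
# Prepend a zero column for the NULL class at index 0
prediction = torch.cat([torch.zeros(num_spans, 1), span_scores], dim=1)

gold = {(0, 2): 1}                                    # span -> chain index
spans = [(0, 1), (0, 2), (1, 2), (0, 3)]
targets = torch.tensor([gold.get(s, -1) + 1 for s in spans])
loss = nn.CrossEntropyLoss()(prediction, targets)
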
    def get_batch_representation(
        self,
        bert_out: torch.Tensor,
        spans: List[List[int]],
        seq_lengths: torch.Tensor,
    ) -> torch.Tensor:
        """
        Args:
            bert_out: (batch_size, max_seq_len + 2, seq_in_size)
            spans: (num_spans, 2)
            seq_lengths: (batch_size,)
        Returns:
            representation: (batch_size, num_spans, ???)
        """
        batch_size, padded_max_seq_len, _ = bert_out.size()
        max_seq_len = padded_max_seq_len - 2
        num_spans = len(spans)

        # Create python lists for span-features
        index_matrix = [None] * num_spans
        for i, span in enumerate(spans):
            index_matrix[i] = torch.zeros(padded_max_seq_len).index_put(
                [torch.arange(span[0] + 1, span[1] + 1)], torch.tensor([1.0]))

        start_indices = try_gpu(torch.tensor([x[0] for x in spans]))
        end_indices = try_gpu(torch.tensor([x[1] for x in spans]))
        # (batch_size, num_spans, max_seq_len + 2)
        batch_index_matrix = try_gpu(
            torch.stack(index_matrix).expand(batch_size, -1, -1))

        # List[(batch_size, num_spans, ???)]
        things_to_concat: List[torch.Tensor] = []

        # Index select the start and end BERT hidden states.
        # (batch_size, num_spans, seq_in_size)
        start_hidden = bert_out[:, start_indices + 1, :]
        # (batch_size, num_spans, seq_in_size)
        end_hidden = bert_out[:, end_indices, :]
        things_to_concat += [start_hidden, end_hidden]

        average_weights = batch_index_matrix / torch.sum(
            batch_index_matrix, dim=2, keepdim=True)
        # (batch_size, num_spans, seq_in_size)
        average_hidden = torch.bmm(average_weights, bert_out)
        things_to_concat.append(average_hidden)

        # Concatenate all portions of the representations.
        representation = torch.cat(things_to_concat, dim=2)
        return representation
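
The span-averaging step builds a 0/1 index matrix per span (offset by one for the leading [CLS] token), normalizes it into averaging weights, and applies it with a batched matrix multiply. A self-contained sketch under those same conventions (sizes are made up):

import torch

batch_size, padded_len, hidden = 2, 6, 4
bert_out = torch.randn(batch_size, padded_len, hidden)
spans = [(0, 2), (1, 4)]                      # [start, end) in token space

index_matrix = torch.stack([
    torch.zeros(padded_len).index_put(
        [torch.arange(s + 1, e + 1)], torch.tensor([1.0]))
    for s, e in spans
])                                            # (num_spans, padded_len)
weights = index_matrix / index_matrix.sum(dim=1, keepdim=True)
weights = weights.expand(batch_size, -1, -1)  # (B, num_spans, padded_len)
average_hidden = torch.bmm(weights, bert_out) # (B, num_spans, hidden)
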
    def create_model(self):
        config = self.config
        self.model = create_model(config, self.meta)
        self.model = try_gpu(self.model)
        self.optimizer = optim.Adam(self.model.parameters(),
                                    lr=config.train.learning_rate,
                                    weight_decay=config.train.l2_reg)
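
For reference, Adam's weight_decay argument is what realizes the l2_reg setting above; PyTorch's Adam applies it as a classic L2 penalty on the gradients (unlike the decoupled AdamW). A toy instantiation with made-up hyperparameters:

import torch
import torch.optim as optim

model = torch.nn.Linear(4, 2)   # stand-in for the parser model
optimizer = optim.Adam(model.parameters(), lr=3e-4, weight_decay=1e-5)
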
Example #5
    def _score_tree(
        self,
        span_scores: Tuple[List[DecodedSpan], List[DecodedEdges]],
        root_proto_node: ProtoNode,
    ) -> torch.Tensor:
        """
        Sums the unary chain scores and edge scores of all spans.

        At test time, unknown unary chains are simply skipped. This should
        happen only when scoring gold trees.
        """
        decoded_spans, decoded_edges = span_scores
        proto_node_stack: List[Tuple[ProtoNode, int, List[int]]] = [
            (root_proto_node, 0, [])
        ]
        span_cands = {(span.start, span.end): span for span in decoded_spans}
        edge_cands = {(span.start, span.end): span for span in decoded_edges}
        tree_score = 0
        while proto_node_stack:
            node, parent_label, chain_so_far = proto_node_stack.pop()
            # If it is still in a chain, push it back to the stack
            if (
                len(node.children) == 1
                and node.children[0].start == node.start
                and node.children[0].end == node.end
            ):
                proto_node_stack.append(
                    (
                        node.children[0],
                        parent_label,
                        chain_so_far + [self.labels_idx[node.label]],
                    )
                )
            else:
                chain = tuple(chain_so_far + [self.labels_idx[node.label]])
                # Node score
                if chain not in self.unary_chain_idx:
                    print("WARNING: chain {} not in training data".format(chain))
                else:
                    span = span_cands[node.start, node.end]
                    tree_score += span.labels[self.unary_chain_idx[chain]]
                # Edge score
                if parent_label != 0:
                    edges = edge_cands[node.start, node.end]
                    tree_score += edges.edges[chain[0]][parent_label]
                # Add the children
                for child in node.children:
                    proto_node_stack.append((child, self.labels_idx[node.label], []))
        if isinstance(tree_score, int):
            # This can happen if all unary chains are unknown.
            print("WARNING: tree_score is 0")
            return try_gpu(torch.zeros(1))
        else:
            return tree_score.view(1)  # noqa
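
The stack walk above collapses unary chains: a node whose only child covers the same span extends the current chain instead of closing it. A toy illustration of just that collapsing logic, with a stand-in ProtoNode dataclass (the real class is not shown in these snippets):

from dataclasses import dataclass, field
from typing import List

@dataclass
class ProtoNode:
    label: str
    start: int
    end: int
    children: List["ProtoNode"] = field(default_factory=list)

# S -> NP over the same span (0, 2) collapses into the chain ('S', 'NP')
leaf = ProtoNode("NP", 0, 2)
root = ProtoNode("S", 0, 2, [leaf])

chains = []
stack = [(root, [])]
while stack:
    node, so_far = stack.pop()
    child_same_span = (
        len(node.children) == 1
        and node.children[0].start == node.start
        and node.children[0].end == node.end
    )
    if child_same_span:
        stack.append((node.children[0], so_far + [node.label]))
    else:
        chains.append(((node.start, node.end), tuple(so_far + [node.label])))
        for child in node.children:
            stack.append((child, []))
print(chains)  # [((0, 2), ('S', 'NP'))]
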
    def forward(self, batch):
        seq_lengths = try_gpu(torch.tensor([x.length for x in batch]))
        max_seq_len = max(x.length for x in batch)

        # bert_out: (batch_size, max_seq_len + 2, seq_in_size)
        bert_out = self.embed_stuff(batch, seq_lengths)

        # Construct a list of all possible spans
        spans = self.all_spans(max_seq_len)

        # Batch head-word representation
        batch_representations = self.get_batch_representation(
            bert_out,
            spans,
            seq_lengths,
        )

        return batch_representations, spans, seq_lengths
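
all_spans is not shown in these snippets. Given the exclusive-end indexing used in get_batch_representation (start token at start + 1, end token at end after the [CLS] shift), a plausible implementation would enumerate every [start, end) pair:

def all_spans(max_seq_len):
    # Every span [start, end) with 0 <= start < end <= max_seq_len
    return [[i, j] for i in range(max_seq_len)
            for j in range(i + 1, max_seq_len + 1)]
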
Example #7
    def forward(self, batch):
        seq_lengths = try_gpu(torch.tensor([x.length for x in batch]))
        max_seq_len = max(x.length for x in batch)

        # embedded_tokens: (batch_size, max_seq_len + 2, embed_dim)
        # lstm_out: (batch_size, max_seq_len + 2, seq_in_size)
        embedded_tokens, lstm_out = self.embed_stuff(batch, seq_lengths)

        # Construct a list of all possible spans
        spans = self.all_spans(max_seq_len)

        # Pass the LSTM output through the attention FFNN
        attention_weights = self.attention_ffnn(lstm_out)

        # Batch head-word representation
        batch_representations = self.get_batch_representation(
            lstm_out, attention_weights, embedded_tokens, spans, seq_lengths)

        return batch_representations, spans, seq_lengths
Example #8
    def embed_stuff(self, batch, seq_lengths):
        """
        Build tensors from the token indices

        Args:
            batch: List[Example]
            seq_lengths: (batch_size,)
        Returns:
            embedded_tokens: (batch_size, max_seq_len + 2, embed_dim)
            lstm_out: (batch_size, max_seq_len + 2, seq_in_size)
        """
        batch_token_indices = []
        for example in batch:
            token_indices = example.token_indices
            if self.word_dropout and self.training:
                # Do word-level dropout
                token_indices = [
                    self.UNK_index
                    if np.random.rand() < self._dropout_probs.get(x, 0) else x
                    for x in token_indices
                ]
            # Pad with SOS and EOS
            token_indices = [self.SOS_index] + token_indices + [self.EOS_index]
            # Convert to a tensor and add to the list
            token_indices = torch.tensor(token_indices)
            batch_token_indices.append(token_indices)
        padded_token_indices = pad_sequence(batch_token_indices,
                                            batch_first=True)
        padded_token_indices = try_gpu(padded_token_indices)
        # embedded_tokens: (batch_size, max_seq_len + 2, embed_dim)
        embedded_tokens = self.token_embedder(padded_token_indices)
        packed_embedded_tokens = pack_padded_sequence(
            embedded_tokens,
            seq_lengths + 2,
            batch_first=True,
        )
        packed_lstm_out, _ = self.lstm(packed_embedded_tokens)
        # lstm_out: (batch_size, max_seq_len + 2, seq_in_size)
        lstm_out, _ = pad_packed_sequence(packed_lstm_out, batch_first=True)
        return embedded_tokens, lstm_out
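
One caveat worth noting: recent PyTorch versions require the lengths passed to pack_padded_sequence to be a CPU tensor, and sequences must be sorted by length unless enforce_sorted=False is given. A minimal standalone sketch of the pack/pad round trip (sizes are made up):

import torch
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

lstm = torch.nn.LSTM(input_size=8, hidden_size=16, batch_first=True)
embedded = torch.randn(2, 5, 8)   # (batch, max_len, embed_dim)
lengths = torch.tensor([5, 3])    # CPU tensor, as newer PyTorch requires

packed = pack_padded_sequence(embedded, lengths, batch_first=True,
                              enforce_sorted=False)
packed_out, _ = lstm(packed)
lstm_out, _ = pad_packed_sequence(packed_out, batch_first=True)
# lstm_out: (2, 5, 16); positions past each length are zero-padded
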
    def __init__(self, config, meta):
        """
        Configs:
            add_edge_loss (bool): Whether to add edge score loss.
            null_weight (float): A multiplier for the loss term when the
                gold class is null (i.e., the span is not in the gold tree).
                Higher null_weight -> fewer spans are predicted.
        """
        super().__init__()
        self.labels = meta.nt
        self.labels_idx = meta.nt_x
        self.unary_chains = meta.unary_chains
        self.chains_idx = meta.unary_chains_x
        # Construct the loss weight; the first class is the NULL class
        self.null_weight = config.model.output_layer.null_weight
        weight = [self.null_weight] + [1.0] * len(self.chains_idx)
        weight = try_gpu(torch.tensor(weight))
        self.node_loss = nn.CrossEntropyLoss(weight=weight)
        if config.model.output_layer.add_edge_loss:
            # Edge scores are already softmax-ed (NLLLoss expects
            # log-probabilities, so the scores are assumed to be in log space).
            self.edge_loss = nn.NLLLoss()
        else:
            self.edge_loss = None
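
The effect of null_weight can be checked in isolation: when the gold class is NULL (index 0), its loss term is scaled by null_weight, so values above 1.0 push the model toward predicting NULL, i.e. fewer spans. A small sketch with made-up sizes:

import torch
import torch.nn as nn

null_weight, num_chains = 2.0, 3
weight = torch.tensor([null_weight] + [1.0] * num_chains)
node_loss = nn.CrossEntropyLoss(weight=weight)

logits = torch.randn(5, num_chains + 1)   # (num_spans, 1 + num_chains)
targets = torch.tensor([0, 2, 0, 1, 0])   # 0 = NULL (span not in gold tree)
print(node_loss(logits, targets))         # NULL terms count double here
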
Example #10
    def get_batch_representation(
        self,
        lstm_out: torch.Tensor,
        attention_weights: torch.Tensor,
        embedded_tokens: torch.Tensor,
        spans: List[List[int]],
        seq_lengths: torch.Tensor,
    ) -> torch.Tensor:
        """
        Args:
            lstm_out: (batch_size, max_seq_len + 2, seq_in_size)
                (seq_in_size is the total of LSTM hidden dims)
            attention_weights: (batch_size, max_seq_len + 2, 1)
            embedded_tokens: (batch_size, max_seq_len + 2, embed_dim)
                (embed_dim is the word embedding dim)
            spans: (num_spans, 2)
            seq_lengths: (batch_size,)
        Returns:
            representation: (batch_size, num_spans, ???)
        """
        batch_size, padded_max_seq_len, _ = attention_weights.size()
        max_seq_len = padded_max_seq_len - 2
        num_spans = len(spans)

        # Create python lists for span-features
        index_matrix = [None] * num_spans
        length_buckets = [0] * num_spans
        for i, span in enumerate(spans):
            index_matrix[i] = torch.zeros(padded_max_seq_len).index_put(
                [torch.arange(span[0] + 1, span[1] + 1)], torch.tensor([1.0]))
            length_buckets[i] = self.length_buckets.get(
                span[1] - span[0], self.num_buckets - 1)

        start_indices = try_gpu(torch.tensor([x[0] for x in spans]))
        end_indices = try_gpu(torch.tensor([x[1] for x in spans]))
        # (batch_size, num_spans, max_seq_len + 2)
        batch_index_matrix = try_gpu(
            torch.stack(index_matrix).expand(batch_size, -1, -1))

        # List[(batch_size, num_spans, ???)]
        things_to_concat: List[torch.Tensor] = []

        if SpanFeature.INSIDE_TOKEN in self.span_features:
            # Index select the start and end word embeddings
            # (batch_size, num_spans, embed_dim)
            start_token = embedded_tokens[:, start_indices + 1, :]
            # (batch_size, num_spans, embed_dim)
            end_token = embedded_tokens[:, end_indices, :]
            things_to_concat += [start_token, end_token]

        if SpanFeature.INSIDE_HIDDEN in self.span_features:
            # Index select the start and end lstm outputs.
            # (batch_size, num_spans, seq_in_size)
            start_hidden = lstm_out[:, start_indices + 1, :]
            # (batch_size, num_spans, seq_in_size)
            end_hidden = lstm_out[:, end_indices, :]
            things_to_concat += [start_hidden, end_hidden]

        if SpanFeature.STERN_HIDDEN in self.span_features:
            # Span-difference features in the style of Stern et al. (2017):
            # subtract the forward LSTM states at the span boundaries, and
            # likewise for the backward states.
            forward_before = lstm_out[:, start_indices, :self.lstm_dim]
            forward_after = lstm_out[:, end_indices, :self.lstm_dim]
            backward_before = lstm_out[:, end_indices + 1, self.lstm_dim:]
            backward_after = lstm_out[:, start_indices + 1, self.lstm_dim:]
            things_to_concat += [
                forward_after - forward_before,
                backward_after - backward_before,
            ]

        if (SpanFeature.AVERAGE_TOKEN in self.span_features
                or SpanFeature.AVERAGE_HIDDEN in self.span_features):
            average_weights = batch_index_matrix / torch.sum(
                batch_index_matrix, dim=2, keepdim=True)
            if SpanFeature.AVERAGE_TOKEN in self.span_features:
                # (batch_size, num_spans, embed_dim)
                average_token = torch.bmm(average_weights, embedded_tokens)
                things_to_concat.append(average_token)
            if SpanFeature.AVERAGE_HIDDEN in self.span_features:
                # (batch_size, num_spans, seq_in_size)
                average_hidden = torch.bmm(average_weights, lstm_out)
                things_to_concat.append(average_hidden)

        if (SpanFeature.ATTENTION_TOKEN in self.span_features
                or SpanFeature.ATTENTION_HIDDEN in self.span_features):
            # Batch the attention weights.
            # (batch_size, 1, max_seq_len + 2)
            attention = attention_weights.transpose(1, 2)
            # Element-wise multiplication with the index matrices zeroes out
            # attention entries outside each span.
            # (batch_size, num_spans, max_seq_len + 2)
            attention_matrix = batch_index_matrix * attention
            # Replace all zeroed entries with -inf so that softmax assigns
            # them zero probability.
            attention_matrix[batch_index_matrix == 0.0] = float("-inf")
            # (batch_size, num_spans, max_seq_len + 2)
            attention_matrix_normalized = F.softmax(attention_matrix, dim=2)
            # Batch matrix multiplication to obtain representation weighted by
            # normalized attention weights.
            if SpanFeature.ATTENTION_TOKEN in self.span_features:
                # (batch_size, num_spans, embed_dim)
                attention_token = torch.bmm(attention_matrix_normalized,
                                            embedded_tokens)
                things_to_concat.append(attention_token)
            if SpanFeature.ATTENTION_HIDDEN in self.span_features:
                # (batch_size, num_spans, seq_in_size)
                attention_hidden = torch.bmm(attention_matrix_normalized,
                                             lstm_out)
                things_to_concat.append(attention_hidden)

        if SpanFeature.LENGTH in self.span_features:
            # (num_spans,)
            length_buckets = try_gpu(torch.tensor(length_buckets))
            # (num_spans, length_embed_dim)
            length_embeddings = self.length_embeddings(length_buckets)
            # (batch_size, num_spans, length_embed_dim)
            things_to_concat.append(
                length_embeddings.expand(batch_size, -1, -1))

        # Concatenate all portions of the representations.
        representation = torch.cat(things_to_concat, dim=2)
        return representation
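
The masked softmax used for the attention features can be distilled into a few lines: entries outside a span are set to -inf so softmax gives them exactly zero weight. A self-contained sketch with made-up numbers:

import torch
import torch.nn.functional as F

scores = torch.tensor([[0.2, 1.5, -0.3, 0.7]])   # (1, seq_len) attention logits
mask = torch.tensor([[0.0, 1.0, 1.0, 0.0]])      # 1.0 inside the span

masked = mask * scores
masked[mask == 0.0] = float("-inf")
weights = F.softmax(masked, dim=1)
print(weights)   # nonzero only inside the span; rows still sum to 1
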