def embed_stuff(self, batch, seq_lengths):
    """
    Build tensors from the token indices.

    Args:
        batch: List[Example]
        seq_lengths: (batch_size,)
    Returns:
        bert_out: (batch_size, max_seq_len + 2, seq_in_size)
    """
    batch_token_indices = []
    for example in batch:
        token_indices = example.token_indices
        # Wrap with the special [CLS] and [SEP] tokens
        token_indices = [self.CLS_index] + token_indices + [self.SEP_index]
        # Add to list
        token_indices = torch.tensor(token_indices)
        batch_token_indices.append(token_indices)
    # input_ids: (batch_size, max_seq_len + 2)
    input_ids = pad_sequence(batch_token_indices, batch_first=True)
    input_ids = try_gpu(input_ids)
    # attention_mask and token_type_ids
    positions = try_gpu(torch.arange(input_ids.size()[1])[None, :])
    attention_mask = 1 * (positions < (seq_lengths + 2)[:, None])
    token_type_ids = try_gpu(torch.zeros(input_ids.size(), dtype=int))
    # Encode!
    bert_output = self.bert(
        input_ids=input_ids,
        attention_mask=attention_mask,
        token_type_ids=token_type_ids,
    )
    return bert_output.last_hidden_state
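# A minimal, self-contained sketch (not part of the model) of how the
# attention_mask above is built: broadcasting a row of positions against the
# per-example lengths marks real tokens (including [CLS]/[SEP]) with 1 and
# padding with 0.
import torch

def demo_attention_mask():
    seq_lengths = torch.tensor([3, 1])          # token counts without [CLS]/[SEP]
    max_len = int(seq_lengths.max()) + 2        # +2 for [CLS] and [SEP]
    positions = torch.arange(max_len)[None, :]  # (1, max_len)
    mask = 1 * (positions < (seq_lengths + 2)[:, None])
    # -> tensor([[1, 1, 1, 1, 1],
    #            [1, 1, 1, 0, 0]])
    return mask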
def forward(self, logit: TreeParserOutput, batch):
    losses: List[torch.Tensor] = []
    # For each training example
    for i in range(len(logit.batch_span_scores)):
        if self.edge_loss is not None:
            decoded_spans, decoded_edges = logit.batch_span_scores[i]
        else:
            decoded_spans = logit.batch_span_scores[i]
        gold_proto_node = logit.batch_golds[i]
        gold_chains = self.tree_to_chains(gold_proto_node)
        total_loss = 0.0
        # Add node losses
        # prediction: (num_spans, num_classes)
        #   where num_classes = 1 + len(self.unary_chains)
        #   The first column, for the NULL class, is a zero vector.
        prediction = torch.cat(
            [
                try_gpu(torch.zeros(len(decoded_spans), 1)),
                torch.stack(
                    [decoded_span.labels for decoded_span in decoded_spans]
                ),
            ],
            dim=1,
        )
        # targets: (num_spans,)
        # The unary chain indices are shifted by 1 since we padded
        # the prediction with the NULL class at index 0.
        targets = [
            gold_chains.get((decoded_span.start, decoded_span.end), -1) + 1
            for decoded_span in decoded_spans
        ]
        targets = try_gpu(torch.tensor(targets))
        total_loss = total_loss + self.node_loss(prediction, targets)
        # Add edge losses
        if self.edge_loss is not None:
            span_edges = {(span.start, span.end): span for span in decoded_edges}
            gold_edges = self.tree_to_edges(gold_proto_node)
            if not gold_edges:
                edge_loss = 0.0
            else:
                # prediction: (num_gold_edges, num_labels)
                prediction = []
                # targets: (num_gold_edges,)
                targets = []
                for key, (child_label, parent_label) in gold_edges.items():
                    prediction.append(span_edges[key].edges[child_label])
                    targets.append(parent_label)
                edge_loss = self.edge_loss(
                    torch.stack(prediction), try_gpu(torch.tensor(targets))
                )
            total_loss = total_loss + edge_loss
        losses.append(total_loss)
    stacked_losses = torch.stack(losses)
    return stacked_losses.mean()
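# A minimal sketch (hypothetical tensors, not real decoder output) of the
# NULL-class padding used in the node loss above: a zero column is prepended
# as class 0, so gold chain indices shift by 1 and spans absent from the gold
# tree map to class 0.
import torch
import torch.nn as nn

def demo_node_loss():
    num_spans, num_chains = 4, 3
    scores = torch.randn(num_spans, num_chains)          # stacked decoded_span.labels
    prediction = torch.cat([torch.zeros(num_spans, 1), scores], dim=1)
    # gold chain index per span, or -1 if the span is not in the gold tree
    raw_targets = torch.tensor([2, -1, 0, -1])
    targets = raw_targets + 1                            # NULL becomes class 0
    return nn.CrossEntropyLoss()(prediction, targets)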
def get_batch_representation(
    self,
    bert_out: torch.Tensor,
    spans: List[List[int]],
    seq_lengths: torch.Tensor,
) -> torch.Tensor:
    """
    Args:
        bert_out: (batch_size, max_seq_len + 2, seq_in_size)
        spans: (num_spans, 2)
        seq_lengths: (batch_size,)
    Returns:
        representation: (batch_size, num_spans, 3 * seq_in_size)
            (start, end, and average hidden states, concatenated)
    """
    batch_size, padded_max_seq_len, _ = bert_out.size()
    max_seq_len = padded_max_seq_len - 2
    num_spans = len(spans)
    # Create python lists for span features
    index_matrix = [None] * num_spans
    for i, span in enumerate(spans):
        index_matrix[i] = torch.zeros(padded_max_seq_len).index_put(
            [torch.arange(span[0] + 1, span[1] + 1)], torch.tensor([1.0]))
    start_indices = try_gpu(torch.tensor([x[0] for x in spans]))
    end_indices = try_gpu(torch.tensor([x[1] for x in spans]))
    # (batch_size, num_spans, max_seq_len + 2)
    batch_index_matrix = try_gpu(
        torch.stack(index_matrix).expand(batch_size, -1, -1))
    # List[(batch_size, num_spans, ???)]
    things_to_concat: List[torch.Tensor] = []
    # Index select the BERT outputs at the span boundaries.
    # (batch_size, num_spans, seq_in_size)
    start_hidden = bert_out[:, start_indices + 1, :]
    # (batch_size, num_spans, seq_in_size)
    end_hidden = bert_out[:, end_indices, :]
    things_to_concat += [start_hidden, end_hidden]
    average_weights = batch_index_matrix / torch.sum(
        batch_index_matrix, dim=2, keepdim=True)
    # (batch_size, num_spans, seq_in_size)
    average_hidden = torch.bmm(average_weights, bert_out)
    things_to_concat.append(average_hidden)
    # Concatenate all portions of the representations.
    representation = torch.cat(things_to_concat, dim=2)
    return representation
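# A small sketch (toy numbers) of the span-averaging trick above: a 0/1
# indicator over token positions, normalized to sum to 1, turns a batched
# matrix multiply into a mean over each span's hidden states.
import torch

def demo_span_average():
    batch_size, padded_len, hidden = 2, 6, 4
    bert_out = torch.randn(batch_size, padded_len, hidden)
    spans = [[0, 2], [1, 4]]                      # [start, end) in token positions
    index_matrix = torch.stack([
        torch.zeros(padded_len).index_put(
            [torch.arange(s + 1, e + 1)], torch.tensor([1.0]))
        for s, e in spans
    ])
    weights = index_matrix / index_matrix.sum(dim=1, keepdim=True)
    weights = weights.expand(batch_size, -1, -1)  # (batch_size, num_spans, padded_len)
    return torch.bmm(weights, bert_out)           # (batch_size, num_spans, hidden)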
def create_model(self):
    config = self.config
    self.model = create_model(config, self.meta)
    self.model = try_gpu(self.model)
    self.optimizer = optim.Adam(
        self.model.parameters(),
        lr=config.train.learning_rate,
        weight_decay=config.train.l2_reg,
    )
def _score_tree(
    self,
    span_scores: Tuple[List[DecodedSpan], List[DecodedEdges]],
    root_proto_node: ProtoNode,
) -> torch.Tensor:
    """
    Sums the unary chain scores and edge scores of all spans.

    At test time, unknown unary chains are simply skipped.
    This should happen only when scoring gold trees.
    """
    decoded_spans, decoded_edges = span_scores
    proto_node_stack: List[Tuple[ProtoNode, int, List[int]]] = [
        (root_proto_node, 0, [])
    ]
    span_cands = {(span.start, span.end): span for span in decoded_spans}
    edge_cands = {(span.start, span.end): span for span in decoded_edges}
    tree_score = 0
    while proto_node_stack:
        node, parent_label, chain_so_far = proto_node_stack.pop()
        # If the node continues a unary chain (its only child covers the
        # same span), descend into the child and extend the chain.
        if (
            len(node.children) == 1
            and node.children[0].start == node.start
            and node.children[0].end == node.end
        ):
            proto_node_stack.append(
                (
                    node.children[0],
                    parent_label,
                    chain_so_far + [self.labels_idx[node.label]],
                )
            )
        else:
            chain = tuple(chain_so_far + [self.labels_idx[node.label]])
            # Node score
            if chain not in self.unary_chain_idx:
                print("WARNING: chain {} not in training data".format(chain))
            else:
                span = span_cands[node.start, node.end]
                tree_score += span.labels[self.unary_chain_idx[chain]]
            # Edge score
            if parent_label != 0:
                edges = edge_cands[node.start, node.end]
                tree_score += edges.edges[chain[0]][parent_label]
            # Add the children
            for child in node.children:
                proto_node_stack.append((child, self.labels_idx[node.label], []))
    if isinstance(tree_score, int):
        # This can happen if all unary chains are unknown.
        print("WARNING: tree_score is 0")
        return try_gpu(torch.zeros(1))
    else:
        return tree_score.view(1)  # noqa
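# A toy sketch (hypothetical ProtoNode stand-in, not the real class) of the
# unary-chain collapsing in _score_tree: nodes whose single child covers the
# same span are folded into one tuple of labels, e.g. an S over a VP on the
# same span becomes the chain ("S", "VP").
from dataclasses import dataclass, field
from typing import List, Tuple

@dataclass
class _Node:
    label: str
    start: int
    end: int
    children: List["_Node"] = field(default_factory=list)

def demo_collapse_chain(node: _Node) -> Tuple[str, ...]:
    chain = [node.label]
    while (len(node.children) == 1
           and node.children[0].start == node.start
           and node.children[0].end == node.end):
        node = node.children[0]
        chain.append(node.label)
    return tuple(chain)

# demo_collapse_chain(_Node("S", 0, 3, [_Node("VP", 0, 3)])) == ("S", "VP")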
def forward(self, batch):
    seq_lengths = try_gpu(torch.tensor([x.length for x in batch]))
    max_seq_len = max(x.length for x in batch)
    # bert_out: (batch_size, max_seq_len + 2, seq_in_size)
    bert_out = self.embed_stuff(batch, seq_lengths)
    # Construct a list of all possible spans
    spans = self.all_spans(max_seq_len)
    # Compute batched span representations
    batch_representations = self.get_batch_representation(
        bert_out, spans, seq_lengths,
    )
    return batch_representations, spans, seq_lengths
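# all_spans is defined elsewhere; a plausible sketch (an assumption, not the
# actual implementation) is an enumeration of every (start, end) pair with
# 0 <= start < end <= max_seq_len, in token positions excluding [CLS]/[SEP]:
def all_spans_sketch(max_seq_len: int):
    return [[start, end]
            for start in range(max_seq_len)
            for end in range(start + 1, max_seq_len + 1)]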
def forward(self, batch):
    seq_lengths = try_gpu(torch.tensor([x.length for x in batch]))
    max_seq_len = max(x.length for x in batch)
    # embedded_tokens: (batch_size, max_seq_len + 2, embed_dim)
    # lstm_out: (batch_size, max_seq_len + 2, seq_in_size)
    embedded_tokens, lstm_out = self.embed_stuff(batch, seq_lengths)
    # Step 1: construct a list of all possible spans
    spans = self.all_spans(max_seq_len)
    # Step 2: pass the LSTM output to the attention FFNN
    attention_weights = self.attention_ffnn(lstm_out)
    # Step 3: compute batched span representations
    batch_representations = self.get_batch_representation(
        lstm_out, attention_weights, embedded_tokens, spans, seq_lengths)
    return batch_representations, spans, seq_lengths
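# attention_ffnn is defined elsewhere; a plausible sketch (an assumption, not
# the actual module) is a small feed-forward scorer mapping each hidden state
# to a scalar attention logit, i.e.
# (batch_size, max_seq_len + 2, seq_in_size) -> (batch_size, max_seq_len + 2, 1):
import torch.nn as nn

def make_attention_ffnn_sketch(seq_in_size: int, hidden_dim: int = 128):
    return nn.Sequential(
        nn.Linear(seq_in_size, hidden_dim),
        nn.ReLU(),
        nn.Linear(hidden_dim, 1),
    )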
def embed_stuff(self, batch, seq_lengths):
    """
    Build tensors from the token indices.

    Args:
        batch: List[Example]
        seq_lengths: (batch_size,)
    Returns:
        embedded_tokens: (batch_size, max_seq_len + 2, embed_dim)
        lstm_out: (batch_size, max_seq_len + 2, seq_in_size)
    """
    batch_token_indices = []
    for example in batch:
        token_indices = example.token_indices
        if self.word_dropout and self.training:
            # Do word-level dropout
            token_indices = [
                self.UNK_index
                if np.random.rand() < self._dropout_probs.get(x, 0)
                else x
                for x in token_indices
            ]
        # Wrap with SOS and EOS
        token_indices = [self.SOS_index] + token_indices + [self.EOS_index]
        # Add to list
        token_indices = torch.tensor(token_indices)
        batch_token_indices.append(token_indices)
    padded_token_indices = pad_sequence(batch_token_indices, batch_first=True)
    padded_token_indices = try_gpu(padded_token_indices)
    # embedded_tokens: (batch_size, max_seq_len + 2, embed_dim)
    embedded_tokens = self.token_embedder(padded_token_indices)
    # pack_padded_sequence needs CPU lengths; enforce_sorted=False lets the
    # batch stay in its original order.
    packed_embedded_tokens = pack_padded_sequence(
        embedded_tokens,
        (seq_lengths + 2).cpu(),
        batch_first=True,
        enforce_sorted=False,
    )
    packed_lstm_out, _ = self.lstm(packed_embedded_tokens)
    # lstm_out: (batch_size, max_seq_len + 2, seq_in_size)
    lstm_out, _ = pad_packed_sequence(packed_lstm_out, batch_first=True)
    return embedded_tokens, lstm_out
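# A minimal sketch (illustrative numbers) of the word-level dropout above:
# each token is replaced by UNK with a per-word probability, typically higher
# for rare words. The probability table here is a stand-in for however
# _dropout_probs is actually built (assumption: 1 / (1 + count)).
import numpy as np

def demo_word_dropout(token_indices, counts, unk_index=0):
    dropout_probs = {tok: 1.0 / (1.0 + c) for tok, c in counts.items()}
    return [unk_index if np.random.rand() < dropout_probs.get(tok, 0) else tok
            for tok in token_indices]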
def __init__(self, config, meta):
    """
    Configs:
        add_edge_loss (bool): Whether to add the edge score loss.
        null_weight (float): A multiplier for the loss term when the
            gold class is NULL (i.e., the span is not in the gold tree).
            Higher null_weight -> fewer spans are predicted.
    """
    super().__init__()
    self.labels = meta.nt
    self.labels_idx = meta.nt_x
    self.unary_chains = meta.unary_chains
    self.chains_idx = meta.unary_chains_x
    # Construct the loss weight; the first class is the NULL class
    self.null_weight = config.model.output_layer.null_weight
    weight = [self.null_weight] + [1.0] * len(self.chains_idx)
    weight = try_gpu(torch.tensor(weight))
    self.node_loss = nn.CrossEntropyLoss(weight=weight)
    if config.model.output_layer.add_edge_loss:
        # Edge scores are already softmax-ed.
        self.edge_loss = nn.NLLLoss()
    else:
        self.edge_loss = None
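# A toy sketch (hypothetical scores) of the null_weight effect above: with a
# larger weight on class 0 (NULL), wrongly labeling a gold-NULL span costs
# more, so training pushes the model to predict fewer spans. reduction="sum"
# is used here only to make the scaling visible; the model uses the default
# "mean", where the weights rebalance the classes relative to each other.
import torch
import torch.nn as nn

def demo_null_weight(null_weight: float) -> torch.Tensor:
    weight = torch.tensor([null_weight, 1.0, 1.0])
    loss = nn.CrossEntropyLoss(weight=weight, reduction="sum")
    prediction = torch.tensor([[0.0, 2.0, -1.0]])  # model favors class 1
    gold = torch.tensor([0])                       # but the gold class is NULL
    return loss(prediction, gold)                  # scales with null_weight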
def get_batch_representation(
    self,
    lstm_out: torch.Tensor,
    attention_weights: torch.Tensor,
    embedded_tokens: torch.Tensor,
    spans: List[List[int]],
    seq_lengths: torch.Tensor,
) -> torch.Tensor:
    """
    Args:
        lstm_out: (batch_size, max_seq_len + 2, seq_in_size)
            (seq_in_size is the total of the LSTM hidden dims)
        attention_weights: (batch_size, max_seq_len + 2, 1)
        embedded_tokens: (batch_size, max_seq_len + 2, embed_dim)
            (embed_dim is the word embedding dim)
        spans: (num_spans, 2)
        seq_lengths: (batch_size,)
    Returns:
        representation: (batch_size, num_spans, ???)
            (the last dim depends on which span features are enabled)
    """
    batch_size, padded_max_seq_len, _ = attention_weights.size()
    max_seq_len = padded_max_seq_len - 2
    num_spans = len(spans)
    # Create python lists for span features
    index_matrix = [None] * num_spans
    length_buckets = [0] * num_spans
    for i, span in enumerate(spans):
        index_matrix[i] = torch.zeros(padded_max_seq_len).index_put(
            [torch.arange(span[0] + 1, span[1] + 1)], torch.tensor([1.0]))
        length_buckets[i] = self.length_buckets.get(
            span[1] - span[0], self.num_buckets - 1)
    start_indices = try_gpu(torch.tensor([x[0] for x in spans]))
    end_indices = try_gpu(torch.tensor([x[1] for x in spans]))
    # (batch_size, num_spans, max_seq_len + 2)
    batch_index_matrix = try_gpu(
        torch.stack(index_matrix).expand(batch_size, -1, -1))
    # List[(batch_size, num_spans, ???)]
    things_to_concat: List[torch.Tensor] = []
    if SpanFeature.INSIDE_TOKEN in self.span_features:
        # Index select the start and end word embeddings.
        # (batch_size, num_spans, embed_dim)
        start_token = embedded_tokens[:, start_indices + 1, :]
        # (batch_size, num_spans, embed_dim)
        end_token = embedded_tokens[:, end_indices, :]
        things_to_concat += [start_token, end_token]
    if SpanFeature.INSIDE_HIDDEN in self.span_features:
        # Index select the start and end LSTM outputs.
        # (batch_size, num_spans, seq_in_size)
        start_hidden = lstm_out[:, start_indices + 1, :]
        # (batch_size, num_spans, seq_in_size)
        end_hidden = lstm_out[:, end_indices, :]
        things_to_concat += [start_hidden, end_hidden]
    if SpanFeature.STERN_HIDDEN in self.span_features:
        # Span features in the style of Stern et al. (2017): differences of
        # the forward and backward LSTM states at the span boundaries.
        forward_before = lstm_out[:, start_indices, :self.lstm_dim]
        forward_after = lstm_out[:, end_indices, :self.lstm_dim]
        backward_before = lstm_out[:, end_indices + 1, self.lstm_dim:]
        backward_after = lstm_out[:, start_indices + 1, self.lstm_dim:]
        things_to_concat += [
            forward_after - forward_before,
            backward_after - backward_before,
        ]
    if (SpanFeature.AVERAGE_TOKEN in self.span_features
            or SpanFeature.AVERAGE_HIDDEN in self.span_features):
        average_weights = batch_index_matrix / torch.sum(
            batch_index_matrix, dim=2, keepdim=True)
        if SpanFeature.AVERAGE_TOKEN in self.span_features:
            # (batch_size, num_spans, embed_dim)
            average_token = torch.bmm(average_weights, embedded_tokens)
            things_to_concat.append(average_token)
        if SpanFeature.AVERAGE_HIDDEN in self.span_features:
            # (batch_size, num_spans, seq_in_size)
            average_hidden = torch.bmm(average_weights, lstm_out)
            things_to_concat.append(average_hidden)
    if (SpanFeature.ATTENTION_TOKEN in self.span_features
            or SpanFeature.ATTENTION_HIDDEN in self.span_features):
        # Batch the attention weights.
        # (batch_size, 1, max_seq_len + 2)
        attention = attention_weights.transpose(1, 2)
        # Element-wise multiplication with the 0/1 index matrices zeroes out
        # attention scores at positions outside each span.
        # (batch_size, num_spans, max_seq_len + 2)
        attention_matrix = batch_index_matrix * attention
        # Replace the zeroed-out positions with -inf so they receive exactly
        # zero weight after the softmax.
        attention_matrix[batch_index_matrix == 0.0] = float("-inf")
        # (batch_size, num_spans, max_seq_len + 2)
        attention_matrix_normalized = F.softmax(attention_matrix, dim=2)
        # Batch matrix multiplication to obtain representations weighted by
        # the normalized attention weights.
        if SpanFeature.ATTENTION_TOKEN in self.span_features:
            # (batch_size, num_spans, embed_dim)
            attention_token = torch.bmm(attention_matrix_normalized,
                                        embedded_tokens)
            things_to_concat.append(attention_token)
        if SpanFeature.ATTENTION_HIDDEN in self.span_features:
            # (batch_size, num_spans, seq_in_size)
            attention_hidden = torch.bmm(attention_matrix_normalized, lstm_out)
            things_to_concat.append(attention_hidden)
    if SpanFeature.LENGTH in self.span_features:
        # (num_spans,)
        length_buckets = try_gpu(torch.tensor(length_buckets))
        # (num_spans, length_embed_dim)
        length_embeddings = self.length_embeddings(length_buckets)
        # (batch_size, num_spans, length_embed_dim)
        things_to_concat.append(
            length_embeddings.expand(batch_size, -1, -1))
    # Concatenate all portions of the representations.
    representation = torch.cat(things_to_concat, dim=2)
    return representation
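# A toy sketch (hypothetical numbers) of the masked softmax above: positions
# outside a span are set to -inf before the softmax, so they get exactly zero
# attention weight, and each span's weights sum to 1 over its own tokens.
import torch
import torch.nn.functional as F

def demo_masked_softmax():
    index_matrix = torch.tensor([[0.0, 1.0, 1.0, 0.0]])   # span covers positions 1-2
    attention = torch.tensor([[0.5, 2.0, 1.0, 3.0]])      # raw attention scores
    masked = index_matrix * attention
    masked[index_matrix == 0.0] = float("-inf")
    return F.softmax(masked, dim=1)  # zeros outside the span, sums to 1 inside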