def scatter_sort(
    src: Tensor,
    index: LongTensor,
    descending=False,
    dim_size=None,
    out: Optional[Tuple[Tensor, LongTensor]] = None,
) -> Tuple[Tensor, LongTensor]:
    """Sort every contiguous segment of a 1D tensor independently.

    ``src`` is processed as a concatenation of segments: segment ``i`` has as
    many elements as the value ``i`` occurs in ``index``, and the segments are
    taken from ``src`` in increasing order of ``i``. The grouping therefore
    only matches ``index`` when ``index`` is sorted in ascending order.

    Args:
        src: 1D tensor holding the values to sort.
        index: segment id of each element of ``src``.
        descending: sort each segment from largest to smallest when true.
        dim_size: number of segments. Defaults to ``index.max() + 1``, or ``0``
            when ``index`` is empty.
        out: optional ``(values, indexes)`` pair of preallocated tensors that
            are filled in place and returned.

    Returns:
        ``(values, indexes)``: the sorted values of every segment, and for each
        of them its position in ``src``.

    Raises:
        ValueError: if ``src`` has more than one dimension.
    """
    if src.ndimension() > 1:
        raise ValueError("Only implemented for 1D tensors")

    if dim_size is None:
        # ``max`` is undefined on an empty tensor: an empty index has no segments.
        dim_size = index.max().item() + 1 if index.numel() > 0 else 0

    if out is None:
        result_values = torch.empty_like(src)
        result_indexes = index.new_empty(src.shape)
    else:
        result_values, result_indexes = out

    # Number of elements that belong to each segment id.
    sizes = (
        index.new_zeros(dim_size)
        .scatter_add_(dim=0, index=index, src=torch.ones_like(index))
        .tolist()
    )

    start = 0
    for size in sizes:
        end = start + size
        values, indexes = torch.sort(src[start:end], dim=0, descending=descending)
        result_values[start:end] = values
        # ``torch.sort`` returns positions local to the slice; shift them back.
        result_indexes[start:end] = indexes + start
        start = end

    return result_values, result_indexes
def scatter_topk(
    src: Tensor, index: LongTensor, k: int, num_chunks=None, fill_value=None
) -> Tuple[Tensor, LongTensor, LongTensor]:
    """Take the top-``k`` values of every contiguous chunk of a 1D tensor.

    ``src`` is processed as a concatenation of chunks: chunk ``i`` has as many
    elements as the value ``i`` occurs in ``index``, and the chunks are taken
    from ``src`` in increasing order of ``i``.

    Args:
        src: 1D tensor holding the values.
        index: chunk id of each element of ``src``; must be sorted in
            ascending order.
        k: number of entries reserved per chunk in the outputs.
        num_chunks: number of chunks. Defaults to ``index.max() + 1``, or ``0``
            when ``index`` is empty.
        fill_value: value stored in the unused value slots of chunks that have
            fewer than ``k`` elements. Defaults to NaN.

    Returns:
        Three 1D tensors of shape ``[num_chunks * k]``, where chunk ``i`` owns
        the slots ``[i * k, (i + 1) * k)``:

        * the top values of each chunk, as ordered by ``torch.topk``;
        * the position of each selected value in ``src`` (``-1`` in unused slots);
        * the position of each selected value inside its own chunk (``-1`` in
          unused slots).

    Raises:
        ValueError: if ``src`` has more than one dimension.
    """
    if src.ndimension() > 1:
        raise ValueError("Only implemented for 1D tensors")

    if num_chunks is None:
        # ``max`` is undefined on an empty tensor: an empty index has no chunks.
        num_chunks = index.max().item() + 1 if index.numel() > 0 else 0

    if fill_value is None:
        fill_value = float("NaN")

    total_slots = num_chunks * k
    result_values = src.new_full((total_slots,), fill_value=fill_value)
    result_indexes_whole = index.new_full((total_slots,), fill_value=-1)
    result_indexes_within_chunk = index.new_full((total_slots,), fill_value=-1)

    # Number of elements that belong to each chunk id.
    chunk_sizes = (
        index.new_zeros(num_chunks)
        .scatter_add_(dim=0, index=index, src=torch.ones_like(index))
        .tolist()
    )

    start = 0
    for chunk_idx, chunk_size in enumerate(chunk_sizes):
        chunk = src[start : start + chunk_size]
        # A chunk shorter than k contributes only as many entries as it has.
        values, indexes = torch.topk(chunk, k=min(k, chunk_size), dim=0)

        first_slot = chunk_idx * k
        last_slot = first_slot + len(values)
        result_values[first_slot:last_slot] = values
        result_indexes_within_chunk[first_slot:last_slot] = indexes
        result_indexes_whole[first_slot:last_slot] = indexes + start
        start += chunk_size

    return result_values, result_indexes_whole, result_indexes_within_chunk
def to_onehot(labels: torch.LongTensor, seq_len: int) -> torch.FloatTensor:
    """
    Convert categorical start/end labels to one-hot rows, used for computing
    the BCE loss.

    Args:
        labels: tensor of shape [bsz]
        seq_len: sequence length (width of every one-hot row)
    Returns:
        onehot_labels: float tensor of shape [bsz, seq_len] with a single 1
        per row, at the column given by the matching label
    """
    num_rows = labels.size(0)
    hot_columns = labels[..., None]  # [bsz, 1]
    grid = torch.zeros((num_rows, seq_len), dtype=labels.dtype, device=labels.device)
    # grid[i][hot_columns[i][0]] = 1
    grid.scatter_(1, hot_columns, 1)
    return grid.float()
def _action_to_token(self, action_tokens: torch.LongTensor, draft_tokens: torch.LongTensor) -> torch.LongTensor: predicted_pointer = action_tokens.new_zeros((draft_tokens.size(0), 1)) draft_pointer = draft_tokens.new_ones((draft_tokens.size(0), 1)) predicted_tokens = action_tokens.new_full((action_tokens.size()), self.END) for act_step in action_tokens.t(): # KEEP, DELETE, COPY, ADD (other) keep_mask = act_step == self.KEEP drop_mask = act_step == self.DROP add_mask = ~(keep_mask | drop_mask) predicted_tokens.scatter_(1, predicted_pointer, draft_tokens.gather(1, draft_pointer)) predicted_tokens[add_mask] = predicted_tokens[add_mask].scatter( 1, predicted_pointer[add_mask], act_step[add_mask].unsqueeze(1)) draft_pointer[keep_mask | drop_mask] += 1 predicted_pointer[~drop_mask] += 1 return predicted_tokens
def _get_candidates(self, entity_ids: torch.LongTensor) -> torch.LongTensor: """ Combines the unique ids from the current batch with the previous set of ids to form the collection of **all** relevant ids. Parameters ---------- entity_ids : ``torch.LongTensor`` A tensor of shape ``(batch_size, sequence_length)`` whose elements are the ids of the corresponding token in the ``target`` sequence. Returns ------- unique_entity_ids : ``torch.LongTensor`` A tensor of shape ``(batch_size, max_num_parents)`` containing all of the unique candidate ids. """ # Get the tensors of unique ids for each batch element and store them in a list all_unique: List[torch.LongTensor] = [] for i, ids in enumerate(entity_ids): if self._remaining[i] is not None: previous_ids = list(self._remaining[i].keys()) previous_ids = entity_ids.new_tensor(previous_ids) ids = torch.cat((ids.view(-1), previous_ids), dim=0) unique = torch.unique(ids, sorted=True) all_unique.append(unique) # Convert the list to a tensor by adding adequete padding. batch_size = entity_ids.shape[0] max_num_parents = max(unique.shape[0] for unique in all_unique) unique_entity_ids = entity_ids.new_zeros( size=(batch_size, max_num_parents)) for i, unique in enumerate(all_unique): unique_entity_ids[i, :unique.shape[0]] = unique return unique_entity_ids
def _calculate_edit_distance(self, output_symbols: torch.LongTensor, targets: torch.LongTensor, mask: torch.BoolTensor) -> torch.FloatTensor: batch_size, max_pred_len = output_symbols.size() _, max_len = targets.size() distances = output_symbols.new_zeros(batch_size, max_pred_len, max_len) distances[:, :, 0] = torch.arange(max_pred_len) distances[:, 0, :] = torch.arange(max_len) distances = distances.float() for i in range(1, max_pred_len): for j in range(1, max_len): diagonal = distances[:, i-1, j-1] + \ self.dsub * (output_symbols[:, i-1] != targets[:, j-1]).float() comp = torch.stack( (diagonal, distances[:, i - 1, j] + self.dins, distances[:, i, j - 1] + self.ddel), dim=-1) distances[:, i, j], _ = torch.min(comp, dim=-1) #edit_distance_mask = self._get_edit_distance_mask(mask, output_symbols) distances = distances.masked_fill(~mask.unsqueeze(1), float('inf')) return distances
def forward(self,  # type: ignore
            tokens: Dict[str, torch.LongTensor],
            spans: torch.LongTensor,
            gold_spans: torch.LongTensor,
            tags: torch.LongTensor = None,
            span_labels: torch.LongTensor = None,
            gold_span_labels: torch.LongTensor = None,
            metadata: List[Dict[str, Any]] = None,
            **kwargs) -> Dict[str, torch.Tensor]:
    '''
    Build the HSCRF training target from the gold spans, encode the text and
    return the layer's loss together with its decoded spans.

    Parameters
    ----------
    tokens : text-field tensors; passed to the mask helper, the text field
        embedder and the HSCRF layer.
    spans : only its first dimension (the batch size) is read here.
    gold_spans : shape (batch_size, num_gold_spans, 2). Spans whose first
        coordinate is negative count as padding.
    tags : Shape(batch_size, seq_len) bilou scheme tags for crf modelling.
        Not read by this method.
    span_labels : not read by this method.
    gold_span_labels : label of every gold span. Dereferenced unconditionally,
        so it must be provided despite the ``None`` default.
        NOTE(review): ``.squeeze()`` is applied to it, which would also drop
        the batch axis for a batch of one - confirm the expected shape.
    metadata : per-instance dicts with the keys ``"gold_spans"`` and
        ``"words"``. When it is ``None`` the span F1 metric is not updated.

    Returns
    -------
    A dict with ``"mask"`` (the extended token mask), ``"loss"``,
    ``"results"`` (the decoded spans) and, when metadata is given, ``"words"``.
    '''
    batch_size = spans.size(0)

    # Token mask and embeddings, each extended with one trailing zero slot.
    mask = util.get_text_field_mask(tokens)
    token_mask = torch.cat([mask, mask.new_zeros(batch_size, 1)], dim=1)
    embedded_text_input = self.text_field_embedder(tokens)
    embedded_text_input = torch.cat(
        [embedded_text_input,
         embedded_text_input.new_zeros(batch_size, 1, embedded_text_input.size(2))],
        dim=1)
    device = embedded_text_input.device

    # Shape: (batch_size, num_gold_spans), 1 for real spans and 0 for padding.
    # No squeeze here: it would drop the span axis when there is one gold span.
    gold_span_mask = (gold_spans[:, :, 0] >= 0).float()

    # For every instance, the slot right after its last gold span.
    batch_rows = torch.arange(batch_size, device=device)
    end_slots = gold_span_mask.sum(-1).long()

    # ``.to`` works on CPU as well as GPU; ``.cuda(-1)`` does not.
    embedded_text_input[batch_rows, end_slots] += self.end_token_embedding.to(device)
    # Integer increment, so the in-place add is also valid on an integer mask.
    token_mask[batch_rows, end_slots] += 1

    # SpanFields return -1 when they are used as padding; clamp those to 0.
    # Shape: (batch_size, num_gold_spans, 2)
    gold_spans = F.relu(gold_spans.float()).long()

    # Shape (batch_size, num_gold_spans, 4): the two span coordinates, the
    # label of the previous span and the label of this span.
    hscrf_target = torch.cat([gold_spans,
                              gold_spans.new_zeros(*gold_spans.size())],
                             dim=-1)
    hscrf_target[:, :, 2] = torch.cat([
        # start tags in the front
        (gold_span_labels.new_zeros(batch_size, 1) + self.hscrf_layer.start_id).long(),
        gold_span_labels.squeeze()[:, 0:-1]],
        dim=-1)
    hscrf_target[:, :, 3] = gold_span_labels.squeeze()

    # Shape (batch_size, num_gold_spans+1, 4) including an <end> singular-span
    hscrf_target = torch.cat([hscrf_target,
                              gold_spans.new_zeros(batch_size, 1, 4)],
                             dim=1)
    # The <end> span starts and ends one past the end of the last gold span,
    # follows that span's label, and carries the stop label itself.
    hscrf_target[batch_rows, end_slots, 0:2] = \
        hscrf_target[batch_rows, end_slots - 1][:, 1:2] + 1
    hscrf_target[batch_rows, end_slots, 2] = \
        hscrf_target[batch_rows, end_slots - 1][:, 3]
    hscrf_target[batch_rows, end_slots, 3] = self.hscrf_layer.stop_id

    # Extend the gold span mask by one slot and switch the <end> span on.
    gold_span_mask = torch.cat([gold_span_mask,
                                gold_span_mask.new_zeros(batch_size, 1)],
                               dim=-1)
    gold_span_mask[batch_rows, end_slots] = 1.

    if self.dropout:
        embedded_text_input = self.dropout(embedded_text_input)
    encoded_text = self.encoder(embedded_text_input, token_mask)
    if self.dropout:
        encoded_text = self.dropout(encoded_text)
    if self._feedforward is not None:
        encoded_text = self._feedforward(encoded_text)

    hscrf_neg_log_likelihood = self.hscrf_layer(
        encoded_text,
        tokens,
        token_mask.sum(-1).squeeze(),
        hscrf_target,
        gold_span_mask
    )
    pred_results = self.hscrf_layer.get_scrf_decode(
        token_mask.sum(-1).squeeze()
    )

    output = {
        "mask": token_mask,
        "loss": hscrf_neg_log_likelihood,
        "results": pred_results
    }
    if metadata is not None:
        # The metric needs the gold spans and words stored in the metadata.
        self._span_f1_metric(
            pred_results,
            [dic['gold_spans'] for dic in metadata],
            sentences=[x["words"] for x in metadata])
        output["words"] = [x["words"] for x in metadata]
    return output
def forward(
    self,  # type: ignore
    sentence: Dict[str, torch.LongTensor],
    worlds: List[List[NlvrLanguage]],
    actions: List[List[ProductionRule]],
    agenda: torch.LongTensor,
    identifier: List[str] = None,
    labels: torch.LongTensor = None,
    metadata: List[Dict[str, Any]] = None,
) -> Dict[str, torch.Tensor]:
    """
    Decoder logic for producing type constrained target sequences that maximize coverage of
    their respective agendas, and minimize a denotation based loss.

    When ``labels`` is given the metrics are updated; otherwise the best action strings,
    denotations, debug info and the action mapping are added to the returned outputs.
    """
    if self._dynamic_cost_rate is not None:
        # This could be added back pretty easily with an EpochCallback passed to the Trainer
        # (it just has to set the epoch number on the model, which could then be queried here).
        logger.warning(
            "Dynamic cost rate functionality was removed in AllenNLP 1.0. If you want this, "
            "use version 0.9. We will just use the static checklist cost weight."
        )

    num_instances = len(worlds)
    instance_range = range(num_instances)

    rnn_state = self._get_initial_rnn_state(sentence)
    zero_scores = [agenda.new_zeros(1, dtype=torch.float) for _ in instance_range]
    # TODO (pradeep): Assuming all worlds give the same set of valid actions.
    grammar_states = [
        self._create_grammar_state(worlds[idx][0], actions[idx]) for idx in instance_range
    ]

    label_strings = None
    if labels is not None:
        label_strings = self._get_label_strings(labels)

    # Each instance's agenda is of size (agenda_size, 1)
    # TODO(mattg): It looks like the agenda is only ever used on the CPU. In that case, it's a
    # waste to copy it to the GPU and then back, and this should probably be a MetadataField.
    agenda_list = [agenda[idx] for idx in instance_range]

    checklist_states = []
    for instance_actions, instance_agenda in zip(actions, agenda_list):
        target, terminals, target_mask = self._get_checklist_info(
            instance_agenda, instance_actions
        )
        statelet = ChecklistStatelet(
            terminal_actions=terminals,
            checklist_target=target,
            checklist_mask=target_mask,
            # The checklist starts out empty, shaped like its target.
            checklist=target.new_zeros(target.size()),
        )
        checklist_states.append(statelet)

    initial_state = CoverageState(
        batch_indices=list(instance_range),
        action_history=[[] for _ in instance_range],
        score=zero_scores,
        rnn_state=rnn_state,
        grammar_state=grammar_states,
        possible_actions=actions,
        extras=label_strings,
        checklist_state=checklist_states,
    )
    if not self.training:
        initial_state.debug_info = [[] for _ in instance_range]

    agenda_data = [instance_agenda[:, 0].cpu().data for instance_agenda in agenda_list]
    outputs = self._decoder_trainer.decode(  # type: ignore
        initial_state, self._decoder_step, partial(self._get_state_cost, worlds)
    )
    if identifier is not None:
        outputs["identifier"] = identifier

    best_final_states = outputs["best_final_states"]
    best_action_sequences = {
        batch_index: [state.action_history[0] for state in states]
        for batch_index, states in best_final_states.items()
    }
    batch_action_strings = self._get_action_strings(actions, best_action_sequences)
    batch_denotations = self._get_denotations(batch_action_strings, worlds)

    if labels is not None:
        # We're either training or validating.
        self._update_metrics(
            action_strings=batch_action_strings,
            worlds=worlds,
            label_strings=label_strings,
            possible_actions=actions,
            agenda_data=agenda_data,
        )
        return outputs

    # We're testing.
    if metadata is not None:
        outputs["sentence_tokens"] = [x["sentence_tokens"] for x in metadata]
    outputs["debug_info"] = [
        best_final_states[idx][0].debug_info[0]  # type: ignore
        for idx in instance_range
    ]
    outputs["best_action_strings"] = batch_action_strings
    outputs["denotations"] = batch_denotations
    outputs["action_mapping"] = {
        (batch_index, action_index): action[0]
        for batch_index, batch_actions in enumerate(actions)
        for action_index, action in enumerate(batch_actions)
    }
    return outputs
def forward(self,  # type: ignore
            words: Dict[str, torch.LongTensor],
            pos_tags: torch.LongTensor,
            lemmas: torch.LongTensor,
            ner_tags: torch.LongTensor,
            metadata: List[Dict[str, Any]],
            supertags: torch.LongTensor = None,
            lexlabels: torch.LongTensor = None,
            head_tags: torch.LongTensor = None,
            head_indices: torch.LongTensor = None) -> Dict[str, torch.Tensor]:
    # pylint: disable=arguments-differ
    """
    Embed and encode a batch that belongs to a single formalism, then hand the
    encodings to the task registered for that formalism.

    Parameters
    ----------
    words : Dict[str, torch.LongTensor], required
        Text-field tensors, passed to the text field embedder and to the mask
        helper. When ``self.tok2vec`` is set, ``words["tokens"]`` is also
        embedded with it.
    pos_tags : ``torch.LongTensor``, required
        POS tag ids. Embedded and concatenated to the input when the model has
        a POS embedding, and always forwarded to the task.
    lemmas : ``torch.LongTensor``, required
        Lemma ids; used only when the model has a lemma embedding.
    ner_tags : ``torch.LongTensor``, required
        Named-entity tag ids; used only when the model has an NE embedding.
    metadata : List[Dict[str, Any]], required
        One dict per batch element. Each must contain the key ``"formalism"``
        and all elements must share the same value.
    supertags : torch.LongTensor, optional (default = None)
        Forwarded unchanged to the task.
    lexlabels : torch.LongTensor, optional (default = None)
        Forwarded unchanged to the task.
    head_tags : torch.LongTensor, optional (default = None)
        Gold edge labels; a zero column for the artificial root is prepended
        before they are forwarded to the task.
    head_indices : torch.LongTensor, optional (default = None)
        Gold head indices; a zero column for the artificial root is prepended
        before they are forwarded to the task.

    Returns
    -------
    Whatever the task selected by the batch's formalism returns.

    Raises
    ------
    ConfigurationError
        If the metadata lacks ``"formalism"``, mixes formalisms, names a
        formalism with no registered task, or if the model has a POS embedding
        but ``pos_tags`` is ``None``.
    """
    if 'formalism' not in metadata[0]:
        # Adjacent literals instead of a backslash continuation, which pulled the
        # source indentation into the message.
        raise ConfigurationError("metadata is missing 'formalism' key. "
                                 "Please use the amconll dataset reader.")

    formalism_of_batch = metadata[0]['formalism']
    if any(entry['formalism'] != formalism_of_batch for entry in metadata):
        raise ConfigurationError("Two formalisms in the same batch.")
    if formalism_of_batch not in self.tasks:
        raise ConfigurationError(f"Got formalism {formalism_of_batch} but I only have these tasks: {list(self.tasks.keys())}")

    if self.tok2vec:
        token_ids = words["tokens"]
        # shape (batch_size, seq len, encoder dim)
        embedded_text_input = self.tok2vec.embed(self.vocab, token_ids)
        concatenated_input = [embedded_text_input, self.text_field_embedder(words)]
    else:
        embedded_text_input = self.text_field_embedder(words)
        concatenated_input = [embedded_text_input]

    if pos_tags is not None and self._pos_tag_embedding is not None:
        concatenated_input.append(self._pos_tag_embedding(pos_tags))
    elif self._pos_tag_embedding is not None:
        raise ConfigurationError("Model uses a POS embedding, but no POS tags were passed.")

    if self._lemma_embedding is not None:
        concatenated_input.append(self._lemma_embedding(lemmas))
    if self._ne_embedding is not None:
        concatenated_input.append(self._ne_embedding(ner_tags))

    if len(concatenated_input) > 1:
        embedded_text_input = torch.cat(concatenated_input, -1)

    mask = get_text_field_mask(words)
    embedded_text_input = self._input_dropout(embedded_text_input)
    # The two encodings potentially share weights.
    encoded_text_parsing, encoded_text_tagging = self.encoder(formalism_of_batch, embedded_text_input, mask)

    batch_size, seq_len, encoding_dim = encoded_text_parsing.size()
    head_sentinel = self._head_sentinel.expand(batch_size, 1, encoding_dim)
    # Concatenate the artificial root onto the sentence representation.
    encoded_text_parsing = torch.cat([head_sentinel, encoded_text_parsing], 1)

    if encoded_text_tagging is not None:  # might be None when the batch's formalism has no tagging (UD)
        batch_size, seq_len, encoding_dim = encoded_text_tagging.size()
        head_sentinel = self._head_sentinel.expand(batch_size, 1, encoding_dim)
        # Concatenate the artificial root onto the sentence representation.
        encoded_text_tagging = torch.cat([head_sentinel, encoded_text_tagging], 1)

    # Account for the artificial root in the mask and in the gold annotations.
    mask = torch.cat([mask.new_ones(batch_size, 1), mask], 1)
    if head_indices is not None:
        head_indices = torch.cat([head_indices.new_zeros(batch_size, 1), head_indices], 1)
    if head_tags is not None:
        head_tags = torch.cat([head_tags.new_zeros(batch_size, 1), head_tags], 1)

    return self.tasks[formalism_of_batch](encoded_text_parsing, encoded_text_tagging, mask, pos_tags, metadata, supertags, lexlabels, head_tags, head_indices)
def forward(self,  # type: ignore
            words: Dict[str, torch.LongTensor],
            weight: torch.Tensor,
            metadata: List[Dict[str, Any]],
            head_tags: torch.LongTensor = None,
            head_indices: torch.LongTensor = None) -> Dict[str, torch.Tensor]:
    # pylint: disable=arguments-differ
    """
    Score arcs and head tags for a batch and compute the weighted loss.

    With gold ``head_indices`` and ``head_tags`` the loss is computed against
    them and the attachment / tagging metrics are updated. Without them the
    heads and tags are decoded greedily, the loss is computed against those
    predictions, and the predictions (minus the artificial root column) are
    added to the output as ``predicted_heads`` and ``predicted_head_tags``.

    The returned dict always holds ``heads`` (head-masked arc logits),
    ``head_tags`` (head tag logits), ``loss``, ``mask`` and ``words``.
    """
    embedded = self.text_field_embedder(words)
    embedded = self._input_dropout(embedded)
    mask = get_text_field_mask(words)
    encoded_text = self.encoder(embedded, mask)

    num_rows, _, hidden_dim = encoded_text.size()
    sentinel = self._head_sentinel.expand(num_rows, 1, hidden_dim)
    # Concatenate the head sentinel onto the sentence representation, and
    # extend the mask and the gold annotations by the same leading column.
    encoded_text = torch.cat([sentinel, encoded_text], 1)
    mask = torch.cat([mask.new_ones(num_rows, 1), mask], 1)
    if head_indices is not None:
        head_indices = torch.cat([head_indices.new_zeros(num_rows, 1), head_indices], 1)
    if head_tags is not None:
        head_tags = torch.cat([head_tags.new_zeros(num_rows, 1), head_tags], 1)

    # shape (batch_size, sequence_length, arc_representation_dim)
    head_arc = self.head_arc_feedforward(encoded_text)
    child_arc = self.child_arc_feedforward(encoded_text)
    # shape (batch_size, sequence_length, tag_representation_dim)
    head_tag = self.head_tag_feedforward(encoded_text)
    child_tag = self.child_tag_feedforward(encoded_text)
    # shape (batch_size, sequence_length, sequence_length)
    attended_arcs = self.arc_attention(head_arc, child_arc)

    has_gold = head_indices is not None and head_tags is not None
    predicted_heads = None
    predicted_head_tags = None
    if has_gold:
        loss_heads, loss_tags = head_indices, head_tags
    else:
        attended_arcs = _apply_head_mask(attended_arcs, mask)
        # Compute the heads greedily.
        # shape (batch_size, sequence_length)
        predicted_heads = attended_arcs.max(dim=2)[1]
        # Given the greedily predicted heads, decode their dependency tags.
        # shape (batch_size, sequence_length, num_head_tags)
        head_tag_logits = self._get_head_tags(head_tag, child_tag, predicted_heads)
        predicted_head_tags = head_tag_logits.max(dim=2)[1]
        loss_heads, loss_tags = predicted_heads.long(), predicted_head_tags.long()

    loss, normalised_arc_logits, normalised_head_tag_logits = self._construct_loss(
        head_tag_representation=head_tag,
        child_tag_representation=child_tag,
        attended_arcs=attended_arcs,
        head_indices=loss_heads,
        head_tags=loss_tags,
        mask=mask,
        weight=weight)
    normalised_arc_logits = _apply_head_mask(normalised_arc_logits, mask)

    if has_gold:
        # Metrics skip the sentinel in column 0.
        tag_mask = self._get_unknown_tag_mask(mask, head_tags)
        self._attachment_scores(normalised_arc_logits[:, 1:],
                                head_indices[:, 1:],
                                mask[:, 1:])
        self._tagging_accuracy(normalised_head_tag_logits[:, 1:],
                               head_tags[:, 1:],
                               tag_mask[:, 1:])

    output_dict = {
        "heads": normalised_arc_logits,
        "head_tags": normalised_head_tag_logits,
        "loss": loss,
        "mask": mask,
        "words": [meta["words"] for meta in metadata],
    }
    if predicted_heads is not None and predicted_head_tags is not None:
        output_dict['predicted_heads'] = predicted_heads[:, 1:]
        output_dict['predicted_head_tags'] = predicted_head_tags[:, 1:]
    return output_dict
def forward(self,  # type: ignore
            words: Dict[str, torch.LongTensor],
            pos_tags: torch.LongTensor,
            metadata: List[Dict[str, Any]],
            head_tags: torch.LongTensor = None,
            head_indices: torch.LongTensor = None) -> Dict[str, torch.Tensor]:
    # pylint: disable=arguments-differ
    """
    Parameters
    ----------
    words : Dict[str, torch.LongTensor], required
        Text-field tensors; passed to the text field embedder and to the mask helper.
    pos_tags : ``torch.LongTensor``, required
        POS tag ids. Embedded and concatenated to the word embeddings when the model has a
        POS embedding; also passed to ``_get_mask_for_eval`` when gold heads and labels
        are given.
    metadata : List[Dict[str, Any]], required
        One dict per batch element, with the keys ``words``, ``pos`` and
        ``position_in_corpus``.
    head_tags : torch.LongTensor, optional (default = None)
        Gold edge labels. A zero column for the head sentinel is prepended.
    head_indices : torch.LongTensor, optional (default = None)
        Gold head indices. A zero column for the head sentinel is prepended.

    Returns
    -------
    An output dictionary consisting of:
    heads : the predicted heads (greedy, or CLE-decoded when validating with MST decoding).
    edge_labels : the greedily decoded edge labels of the predicted arcs.
    mask : the token mask, extended by the head sentinel column.
    words, pos, position_in_corpus : copied from the metadata.
    arc_loss, edge_label_loss, loss : only when both ``head_indices`` and ``head_tags``
        are given; ``loss`` is the sum of the other two.
    """
    embedded_text_input = self.text_field_embedder(words)
    if self._pos_tag_embedding is not None:
        if pos_tags is None:
            raise ConfigurationError(
                "Model uses a POS embedding, but no POS tags were passed.")
        embedded_text_input = torch.cat(
            [embedded_text_input, self._pos_tag_embedding(pos_tags)], -1)

    mask = get_text_field_mask(words)
    embedded_text_input = self._input_dropout(embedded_text_input)
    encoded_text = self.encoder(embedded_text_input, mask)

    num_rows, _, hidden_dim = encoded_text.size()
    sentinel = self._head_sentinel.expand(num_rows, 1, hidden_dim)
    # Concatenate the head sentinel onto the sentence representation, and
    # extend the mask and the gold annotations by the same leading column.
    encoded_text = torch.cat([sentinel, encoded_text], 1)
    mask = torch.cat([mask.new_ones(num_rows, 1), mask], 1)
    if head_indices is not None:
        head_indices = torch.cat(
            [head_indices.new_zeros(num_rows, 1), head_indices], 1)
    if head_tags is not None:
        head_tags = torch.cat(
            [head_tags.new_zeros(num_rows, 1), head_tags], 1)
    encoded_text = self._dropout(encoded_text)

    edge_existence_scores = self.edge_model.edge_existence(encoded_text, mask)

    if self.training or not self.use_mst_decoding_for_validation:
        predicted_heads = self._greedy_decode_arcs(edge_existence_scores, mask)
    else:
        # Find best tree with CLE
        predicted_heads = cle_decode(edge_existence_scores,
                                     mask.data.sum(dim=1).long())
    # With info about the tree structure, get edge label scores and predict the labels.
    edge_label_logits = self.edge_model.label_scores(encoded_text, predicted_heads)
    predicted_edge_labels = self._greedy_decode_edge_labels(edge_label_logits)

    output_dict = {
        "heads": predicted_heads,
        "edge_labels": predicted_edge_labels,
        "mask": mask,
        "words": [meta["words"] for meta in metadata],
        "pos": [meta["pos"] for meta in metadata],
        "position_in_corpus": [meta["position_in_corpus"] for meta in metadata],
    }

    if head_indices is not None and head_tags is not None:
        gold_edge_label_logits = self.edge_model.label_scores(encoded_text, head_indices)
        edge_label_loss = self.loss_function.label_loss(
            gold_edge_label_logits, mask, head_tags)
        arc_nll = self.loss_function.edge_existence_loss(
            edge_existence_scores, head_indices, mask)

        # We calculate attachment scores for the whole sentence
        # but excluding the symbolic ROOT token at the start,
        # which is why we start from the second element in the sequence.
        evaluation_mask = self._get_mask_for_eval(mask[:, 1:], pos_tags)
        self._attachment_scores(predicted_heads[:, 1:],
                                predicted_edge_labels[:, 1:],
                                head_indices[:, 1:],
                                head_tags[:, 1:],
                                evaluation_mask)

        output_dict["arc_loss"] = arc_nll
        output_dict["edge_label_loss"] = edge_label_loss
        output_dict["loss"] = arc_nll + edge_label_loss

    if self.pass_over_data_just_started:
        # here we could decide if we want to start collecting info for the decoder.
        pass
    self.pass_over_data_just_started = False

    return output_dict
def get_classification_embedding(self,
                                 token_ids: torch.LongTensor,
                                 mask: torch.LongTensor) -> torch.Tensor:
    """Run the transformer and select the embedding at offset 0 of every sequence.

    The first element of the transformer's output is passed to
    ``get_select_embedding`` with one zero offset per batch element, and the
    resulting singleton dimension 1 is squeezed away.
    """
    transformer_output = self.transformer_model(input_ids=token_ids, attention_mask=mask)
    embeddings = transformer_output[0]
    # One offset per batch element, all pointing at position 0.
    first_position = token_ids.new_zeros((token_ids.size(0), 1))
    selected = get_select_embedding(embeddings, first_position)
    return selected.squeeze(1)
def _parse(self,
           embedded_text_input: torch.Tensor,
           mask: torch.LongTensor,
           head_tags: torch.LongTensor = None,
           head_indices: torch.LongTensor = None,
           grammar_values: torch.LongTensor = None,
           lemma_indices: torch.LongTensor = None):
    """
    Jointly predict grammar values, lemmas and the dependency parse.

    Parameters
    ----------
    embedded_text_input : embedded tokens; dropout and the encoder are applied here.
    mask : token mask, without the head sentinel column.
    head_tags, head_indices : optional gold annotations. When either is missing,
        the parsing losses are computed against the predictions instead.
    grammar_values : optional gold grammar values; when given, ``grammar_nll``
        is computed and the matching accuracy metric is updated.
    lemma_indices : optional gold lemma rule indices; when given, ``lemma_nll``
        is computed and the matching accuracy metric is updated.

    Returns
    -------
    A dict with the predictions (``heads``, ``head_tags``, ``gram_vals``,
    ``lemmas``), the extended ``mask``, the four losses (``arc_nll``,
    ``tag_nll``, ``grammar_nll``, ``lemma_nll``) and the top-N candidates with
    their scores (``top_lemmas``, ``top_gramms``, ``top_heads``,
    ``top_deprels`` and the matching ``*_prob`` entries).
    ``top_heads``, ``top_deprels`` and ``top_deprels_prob`` are ``None`` on the
    greedy decoding path, because only the MST decoder returns the scores they
    are derived from.
    """
    embedded_text_input = self._input_dropout(embedded_text_input)
    encoded_text = self.encoder(embedded_text_input, mask)

    grammar_value_logits = self._gram_val_output(encoded_text)
    predicted_gram_vals = grammar_value_logits.argmax(-1)

    # EXPERIMENTAL: feed the grammar-tag prediction output into the lemmatizer input.
    l_ext_input = torch.cat([encoded_text, grammar_value_logits], -1)
    lemma_logits = self._lemma_output(l_ext_input)
    predicted_lemmas = lemma_logits.argmax(-1)

    # Top-N most probable lemmatization candidates and their probabilities.
    lemma_probs = torch.nn.functional.softmax(lemma_logits, -1)
    top_lemmas_indices = (-lemma_logits).argsort(-1)[:, :, :self.TopNCnt]
    top_lemmas_prob = torch.gather(lemma_probs, -1, top_lemmas_indices)

    # Same for the grammemes.
    gramm_probs = torch.nn.functional.softmax(grammar_value_logits, -1)
    top_gramms_indices = (-grammar_value_logits).argsort(-1)[:, :, :self.TopNCnt]
    top_gramms_prob = torch.gather(gramm_probs, -1, top_gramms_indices)

    batch_size, _, encoding_dim = encoded_text.size()
    head_sentinel = self._head_sentinel.expand(batch_size, 1, encoding_dim)
    # Concatenate the head sentinel onto the sentence representation.
    encoded_text = torch.cat([head_sentinel, encoded_text], 1)
    # Token-level mask (no sentinel), used for the grammar / lemma losses.
    token_mask = mask.float()
    mask = torch.cat([mask.new_ones(batch_size, 1), mask], 1)
    if head_indices is not None:
        head_indices = torch.cat(
            [head_indices.new_zeros(batch_size, 1), head_indices], 1)
    if head_tags is not None:
        head_tags = torch.cat(
            [head_tags.new_zeros(batch_size, 1), head_tags], 1)
    float_mask = mask.float()
    encoded_text = self._dropout(encoded_text)

    # shape (batch_size, sequence_length, arc_representation_dim)
    head_arc_representation = self._dropout(
        self.head_arc_feedforward(encoded_text))
    child_arc_representation = self._dropout(
        self.child_arc_feedforward(encoded_text))

    # shape (batch_size, sequence_length, tag_representation_dim)
    head_tag_representation = self._dropout(
        self.head_tag_feedforward(encoded_text))
    child_tag_representation = self._dropout(
        self.child_tag_feedforward(encoded_text))

    # shape (batch_size, sequence_length, sequence_length)
    attended_arcs = self.arc_attention(head_arc_representation,
                                       child_arc_representation)

    # Push the score of every arc that touches a masked position far down.
    minus_inf = -1e8
    minus_mask = (1 - float_mask) * minus_inf
    attended_arcs = attended_arcs + minus_mask.unsqueeze(2) + minus_mask.unsqueeze(1)

    # Only the MST decoder returns the scores these are derived from; default
    # them so that the greedy path still produces a complete output dict.
    top_heads = None
    top_deprels_indices = None
    top_deprels_prob = None

    if self.training or not self.use_mst_decoding_for_validation:
        predicted_heads, predicted_head_tags = self._greedy_decode(
            head_tag_representation, child_tag_representation,
            attended_arcs, mask)
    else:
        synt_prediction, benrg = self._mst_decode(
            head_tag_representation, child_tag_representation,
            attended_arcs, mask)
        predicted_heads, predicted_head_tags = synt_prediction

        # Top-N most probable LOCAL (not MST) parse candidates and their scores.
        # Flattening merges dims 1 and 2 of ``benrg`` into one axis; per the
        # original author's note this glues the relation type to the parent index.
        benrgf = torch.flatten(benrg, start_dim=1, end_dim=2).permute(0, 2, 1)
        # Select the best combinations.
        top_deprels_indices = (-benrgf).argsort(-1)[:, :, :self.TopNCnt]
        top_deprels_prob = torch.gather(benrgf, -1, top_deprels_indices)
        # Split the flat index back: remainder and quotient by ``benrg.shape[2]``.
        seqlen = benrg.shape[2]
        top_heads = torch.fmod(top_deprels_indices, seqlen)
        # Integer division: ``torch.div`` would return floats on current PyTorch.
        top_deprels_indices = top_deprels_indices // seqlen

    if head_indices is not None and head_tags is not None:
        arc_nll, tag_nll = self._construct_loss(
            head_tag_representation=head_tag_representation,
            child_tag_representation=child_tag_representation,
            attended_arcs=attended_arcs,
            head_indices=head_indices,
            head_tags=head_tags,
            mask=mask)
    else:
        arc_nll, tag_nll = self._construct_loss(
            head_tag_representation=head_tag_representation,
            child_tag_representation=child_tag_representation,
            attended_arcs=attended_arcs,
            head_indices=predicted_heads.long(),
            head_tags=predicted_head_tags.long(),
            mask=mask)

    grammar_nll = torch.tensor(0.)
    if grammar_values is not None:
        grammar_nll = self._update_multiclass_prediction_metrics(
            logits=grammar_value_logits,
            targets=grammar_values,
            mask=token_mask,
            accuracy_metric=self._gram_val_prediction_accuracy)

    lemma_nll = torch.tensor(0.)
    if lemma_indices is not None:
        lemma_nll = self._update_multiclass_prediction_metrics(
            logits=lemma_logits,
            targets=lemma_indices,
            mask=token_mask,
            accuracy_metric=self._lemma_prediction_accuracy,
            masked_index=self.lemmatize_helper.UNKNOWN_RULE_INDEX)

    output_dict = {
        "heads": predicted_heads,
        "head_tags": predicted_head_tags,
        "gram_vals": predicted_gram_vals,
        "lemmas": predicted_lemmas,
        "mask": mask,
        "arc_nll": arc_nll,
        "tag_nll": tag_nll,
        "grammar_nll": grammar_nll,
        "lemma_nll": lemma_nll,
        "top_lemmas": top_lemmas_indices,
        "top_lemmas_prob": top_lemmas_prob,
        "top_gramms": top_gramms_indices,
        "top_gramms_prob": top_gramms_prob,
        "top_heads": top_heads,
        "top_deprels": top_deprels_indices,
        "top_deprels_prob": top_deprels_prob,
    }

    return output_dict
def _parse(
    self,
    embedded_text_input: torch.Tensor,
    mask: torch.BoolTensor,
    head_tags: torch.LongTensor = None,
    head_indices: torch.LongTensor = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Encode a batch, score every head/child pair and decode a parse.

    A ROOT sentinel (``self._head_sentinel``) is prepended to each sequence,
    so the mask and the predictions returned here are one position longer
    than the inputs along the sequence axis.

    Parameters
    ----------
    embedded_text_input : ``torch.Tensor``
        Token embeddings passed through input dropout and the encoder.
    mask : ``torch.BoolTensor``
        Padding mask for the tokens; it is inverted with ``~`` below.
    head_tags, head_indices : ``torch.LongTensor``, optional
        Gold annotations. If either one is missing, the loss is computed
        against the model's own predictions instead.

    Returns
    -------
    ``(predicted_heads, predicted_head_tags, mask, arc_nll, tag_nll)`` where
    ``mask`` is the sentinel-extended mask.
    """
    encoded_text = self.encoder(self._input_dropout(embedded_text_input), mask)

    batch_size, _, encoding_dim = encoded_text.size()
    root = self._head_sentinel.expand(batch_size, 1, encoding_dim)
    # Put the ROOT sentinel in front of every sentence; the mask and any gold
    # annotations get a matching extra leading column.
    encoded_text = torch.cat([root, encoded_text], 1)
    mask = torch.cat([mask.new_ones(batch_size, 1), mask], 1)
    if head_indices is not None:
        root_heads = head_indices.new_zeros(batch_size, 1)
        head_indices = torch.cat([root_heads, head_indices], 1)
    if head_tags is not None:
        root_tags = head_tags.new_zeros(batch_size, 1)
        head_tags = torch.cat([root_tags, head_tags], 1)
    encoded_text = self._dropout(encoded_text)

    # shape (batch_size, sequence_length, arc_representation_dim)
    head_arc_representation = self.head_arc_feedforward(encoded_text)
    child_arc_representation = self.child_arc_feedforward(encoded_text)
    # shape (batch_size, sequence_length, tag_representation_dim)
    head_tag_representation = self.head_tag_feedforward(encoded_text)
    child_tag_representation = self.child_tag_feedforward(encoded_text)

    # Re-read the sizes: sequence_length now includes the sentinel.
    batch_size, sequence_length, arc_dim = head_arc_representation.size()
    pair_shape = (batch_size, sequence_length, sequence_length, arc_dim)
    # Tile both views into a (batch, seq, seq, arc_dim) grid; the child grid
    # is transposed so that the two vary along different sequence axes.
    heads = head_arc_representation.repeat(1, sequence_length, 1).reshape(pair_shape)
    deps = (
        child_arc_representation.repeat(1, sequence_length, 1)
        .reshape(pair_shape)
        .transpose(1, 2)
    )
    combined_arcs = self.activation(heads + deps)
    # shape (batch_size, sequence_length, sequence_length)
    attended_arcs = self.arc_out_layer(combined_arcs).squeeze(3)

    # Push scores of arcs touching padded positions towards -inf.
    minus_inf = -1e8
    padding_penalty = ~mask * minus_inf
    attended_arcs = (
        attended_arcs + padding_penalty.unsqueeze(2) + padding_penalty.unsqueeze(1)
    )

    use_greedy = self.training or not self.use_mst_decoding_for_validation
    decode = self._greedy_decode if use_greedy else self._mst_decode
    predicted_heads, predicted_head_tags = decode(
        head_tag_representation, child_tag_representation, attended_arcs, mask
    )

    if head_indices is not None and head_tags is not None:
        target_heads, target_tags = head_indices, head_tags
    else:
        # No (complete) gold annotation: score the model's own prediction.
        target_heads = predicted_heads.long()
        target_tags = predicted_head_tags.long()

    arc_nll, tag_nll = self._construct_loss(
        head_tag_representation=head_tag_representation,
        child_tag_representation=child_tag_representation,
        attended_arcs=attended_arcs,
        head_indices=target_heads,
        head_tags=target_tags,
        mask=mask,
    )
    return predicted_heads, predicted_head_tags, mask, arc_nll, tag_nll
def forward(
    self,  # type: ignore
    sentence: Dict[str, torch.LongTensor],
    worlds: List[List[NlvrLanguage]],
    actions: List[List[ProductionRule]],
    agenda: torch.LongTensor,
    identifier: List[str] = None,
    labels: torch.LongTensor = None,
    epoch_num: List[int] = None,
    metadata: List[Dict[str, Any]] = None,
) -> Dict[str, torch.Tensor]:
    """
    Decoder logic for producing type constrained target sequences that maximize coverage of
    their respective agendas, and minimize a denotation based loss.

    The first element of ``epoch_num`` (when given) drives the dynamic
    checklist cost weight. When ``labels`` is given, metrics are updated;
    otherwise the decoded action strings, denotations, debug info and an
    action mapping are added to the returned ``outputs`` dictionary.

    Raises
    ------
    RuntimeError
        If a dynamic cost rate is configured, the model is training, and no
        epoch number was supplied.
    """
    # We look at the epoch number and adjust the checklist cost weight if needed here.
    instance_epoch_num = epoch_num[0] if epoch_num is not None else None
    if self._dynamic_cost_rate is not None:
        if self.training and instance_epoch_num is None:
            raise RuntimeError(
                "If you want a dynamic cost weight, use the "
                "BucketIterator with track_epoch=True.")
        # Outside training the epoch may legitimately be unknown. Comparing
        # ``None >= int`` raises TypeError, so only touch the weight (and the
        # remembered epoch) when an epoch number is actually available.
        if (instance_epoch_num is not None
                and instance_epoch_num != self._last_epoch_in_forward):
            if instance_epoch_num >= self._dynamic_cost_wait_epochs:
                decrement = self._checklist_cost_weight * self._dynamic_cost_rate
                self._checklist_cost_weight -= decrement
                logger.info("Checklist cost weight is now %f",
                            self._checklist_cost_weight)
            self._last_epoch_in_forward = instance_epoch_num

    batch_size = len(worlds)

    initial_rnn_state = self._get_initial_rnn_state(sentence)
    initial_score_list = [
        agenda.new_zeros(1, dtype=torch.float) for _ in range(batch_size)
    ]
    # TODO (pradeep): Assuming all worlds give the same set of valid actions.
    initial_grammar_state = [
        self._create_grammar_state(worlds[i][0], actions[i])
        for i in range(batch_size)
    ]

    label_strings = self._get_label_strings(labels) if labels is not None else None
    # Each instance's agenda is of size (agenda_size, 1)
    # TODO(mattg): It looks like the agenda is only ever used on the CPU. In that case, it's a
    # waste to copy it to the GPU and then back, and this should probably be a MetadataField.
    agenda_list = [agenda[i] for i in range(batch_size)]

    initial_checklist_states = []
    for instance_actions, instance_agenda in zip(actions, agenda_list):
        checklist_target, terminal_actions, checklist_mask = self._get_checklist_info(
            instance_agenda, instance_actions)
        # The checklist starts as zeros with the same size as its target.
        initial_checklist = checklist_target.new_zeros(checklist_target.size())
        initial_checklist_states.append(
            ChecklistStatelet(
                terminal_actions=terminal_actions,
                checklist_target=checklist_target,
                checklist_mask=checklist_mask,
                checklist=initial_checklist,
            ))

    initial_state = CoverageState(
        batch_indices=list(range(batch_size)),
        action_history=[[] for _ in range(batch_size)],
        score=initial_score_list,
        rnn_state=initial_rnn_state,
        grammar_state=initial_grammar_state,
        possible_actions=actions,
        extras=label_strings,
        checklist_state=initial_checklist_states,
    )
    if not self.training:
        initial_state.debug_info = [[] for _ in range(batch_size)]

    agenda_data = [agenda_[:, 0].cpu().data for agenda_ in agenda_list]
    outputs = self._decoder_trainer.decode(  # type: ignore
        initial_state, self._decoder_step, partial(self._get_state_cost, worlds))
    if identifier is not None:
        outputs["identifier"] = identifier

    best_final_states = outputs["best_final_states"]
    best_action_sequences = {
        batch_index: [state.action_history[0] for state in states]
        for batch_index, states in best_final_states.items()
    }
    batch_action_strings = self._get_action_strings(actions, best_action_sequences)
    batch_denotations = self._get_denotations(batch_action_strings, worlds)

    if labels is not None:
        # We're either training or validating.
        self._update_metrics(
            action_strings=batch_action_strings,
            worlds=worlds,
            label_strings=label_strings,
            possible_actions=actions,
            agenda_data=agenda_data,
        )
    else:
        # We're testing.
        if metadata is not None:
            outputs["sentence_tokens"] = [x["sentence_tokens"] for x in metadata]
        outputs["debug_info"] = [
            best_final_states[i][0].debug_info[0]  # type: ignore
            for i in range(batch_size)
        ]
        outputs["best_action_strings"] = batch_action_strings
        outputs["denotations"] = batch_denotations
        # Map (batch index, action index) to the first element of each action.
        outputs["action_mapping"] = {
            (batch_index, action_index): action[0]
            for batch_index, batch_actions in enumerate(actions)
            for action_index, action in enumerate(batch_actions)
        }
    return outputs
def forward(self,  # type: ignore
            encoded_text: torch.FloatTensor,
            mask: torch.LongTensor,
            pos_logits: torch.LongTensor = None,  # predicted
            head_tags: torch.LongTensor = None,
            head_indices: torch.LongTensor = None,
            metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
    """Parse already-encoded text, optionally enriched with predicted POS tags.

    Parameters
    ----------
    encoded_text : ``torch.FloatTensor``
        A 3-dimensional tensor (it is unpacked as ``batch, seq, dim``).
    mask : ``torch.LongTensor``
        Padding mask for the tokens.
    pos_logits : ``torch.LongTensor``, optional
        POS scores; when given and a POS embedding is configured, the argmax
        tags are embedded and concatenated onto ``encoded_text``.
    head_tags, head_indices : ``torch.LongTensor``, optional
        Gold annotations. With both present the loss is computed against
        them and the attachment scores are updated; otherwise the loss is
        computed against the model's own predictions.
    metadata : ``List[Dict[str, Any]]``, optional
        Per-instance metadata; each entry must contain ``"words"``.

    Returns
    -------
    A dict with ``heads``, ``head_tags``, ``arc_loss``, ``tag_loss``,
    ``loss`` and ``mask`` (all including the ROOT position), plus ``words``
    when ``metadata`` is supplied.
    """
    batch_size, _, _ = encoded_text.size()

    pos_tags = None
    if pos_logits is not None and self.pos_tag_embedding is not None:
        # Embed the predicted POS tags and concatenate the embeddings to the input
        num_pos_classes = pos_logits.size(-1)
        pos_logits = pos_logits.view(-1, num_pos_classes)
        _, pos_tags = pos_logits.max(-1)

        pos_embed_size = self.pos_tag_embedding.get_output_dim()
        embedded_pos_tags = self.dropout(self.pos_tag_embedding(pos_tags))
        embedded_pos_tags = embedded_pos_tags.view(batch_size, -1, pos_embed_size)
        encoded_text = torch.cat([encoded_text, embedded_pos_tags], -1)

    encoded_text = self.encoder(encoded_text, mask)

    batch_size, _, encoding_dim = encoded_text.size()
    head_sentinel = self._head_sentinel.expand(batch_size, 1, encoding_dim)
    # Concatenate the head sentinel onto the sentence representation.
    encoded_text = torch.cat([head_sentinel, encoded_text], 1)
    mask = torch.cat([mask.new_ones(batch_size, 1), mask], 1)
    if head_indices is not None:
        head_indices = torch.cat([head_indices.new_zeros(batch_size, 1), head_indices], 1)
    if head_tags is not None:
        head_tags = torch.cat([head_tags.new_zeros(batch_size, 1), head_tags], 1)
    float_mask = mask.float()
    encoded_text = self._dropout(encoded_text)

    # shape (batch_size, sequence_length, arc_representation_dim)
    head_arc_representation = self._dropout(self.head_arc_feedforward(encoded_text))
    child_arc_representation = self._dropout(self.child_arc_feedforward(encoded_text))

    # shape (batch_size, sequence_length, tag_representation_dim)
    head_tag_representation = self._dropout(self.head_tag_feedforward(encoded_text))
    child_tag_representation = self._dropout(self.child_tag_feedforward(encoded_text))

    # shape (batch_size, sequence_length, sequence_length)
    attended_arcs = self.arc_attention(head_arc_representation,
                                       child_arc_representation)

    # Push scores of arcs touching padded positions towards -inf.
    minus_inf = -1e8
    minus_mask = (1 - float_mask) * minus_inf
    attended_arcs = attended_arcs + minus_mask.unsqueeze(2) + minus_mask.unsqueeze(1)

    if self.training or not self.use_mst_decoding_for_validation:
        decode = self._greedy_decode
    else:
        decode = self._mst_decode
    predicted_heads, predicted_head_tags = decode(head_tag_representation,
                                                  child_tag_representation,
                                                  attended_arcs,
                                                  mask)

    has_gold = head_indices is not None and head_tags is not None
    if has_gold:
        target_heads, target_tags = head_indices, head_tags
    else:
        target_heads = predicted_heads.long()
        target_tags = predicted_head_tags.long()

    arc_nll, tag_nll = self._construct_loss(head_tag_representation=head_tag_representation,
                                            child_tag_representation=child_tag_representation,
                                            attended_arcs=attended_arcs,
                                            head_indices=target_heads,
                                            head_tags=target_tags,
                                            mask=mask)
    loss = arc_nll + tag_nll

    if has_gold:
        # NOTE(review): ``pos_tags`` is None when no POS logits/embedding are
        # in use and is flattened otherwise - confirm _get_mask_for_eval
        # accepts both.
        evaluation_mask = self._get_mask_for_eval(mask[:, 1:], pos_tags)
        # We calculate attachment scores for the whole sentence
        # but excluding the symbolic ROOT token at the start,
        # which is why we start from the second element in the sequence.
        self._attachment_scores(predicted_heads[:, 1:],
                                predicted_head_tags[:, 1:],
                                head_indices[:, 1:],
                                head_tags[:, 1:],
                                evaluation_mask)

    output_dict = {
        "heads": predicted_heads,
        "head_tags": predicted_head_tags,
        "arc_loss": arc_nll,
        "tag_loss": tag_nll,
        "loss": loss,
        "mask": mask,
    }
    # ``metadata`` defaults to None; iterating it unconditionally would raise
    # TypeError, so the words are only attached when metadata was provided.
    if metadata is not None:
        output_dict["words"] = [meta["words"] for meta in metadata]

    return output_dict
def forward(self,  # type: ignore
            words: Dict[str, torch.LongTensor],
            pos_tags: torch.LongTensor,
            metadata: List[Dict[str, Any]],
            head_tags: torch.LongTensor = None,
            head_indices: torch.LongTensor = None) -> Dict[str, torch.Tensor]:
    # pylint: disable=arguments-differ
    """
    Embed, encode and parse a batch of sentences.

    Parameters
    ----------
    words : Dict[str, torch.LongTensor], required
        The output of ``TextField.as_array()``; it is passed directly to the
        ``TextFieldEmbedder`` and used to derive the padding mask.
    pos_tags : ``torch.LongTensor``, required.
        The output of a ``SequenceLabelField`` containing POS tags. They are
        embedded when the model has a POS embedding and are also used to
        build the evaluation mask.
    metadata : List[Dict[str, Any]], required
        Per-instance metadata; every entry must provide ``"words"`` and
        ``"pos"``.
    head_tags : torch.LongTensor, optional (default = None)
        Gold class labels for the arcs in the dependency parse.
        Has shape ``(batch_size, sequence_length)``.
    head_indices : torch.LongTensor, optional (default = None)
        Gold indices denoting the parent of every word in the dependency
        parse. Has shape ``(batch_size, sequence_length)``.

    Returns
    -------
    An output dictionary consisting of:
    encoder_final_state :
        The result of ``get_final_encoder_states`` on the encoder output.
    encoded_text :
        The encoder output before the ROOT sentinel is prepended.
    heads :
        The predicted head indices for each position (ROOT included).
    head_tags :
        The predicted head tags for each arc (ROOT included).
    arc_loss : ``torch.FloatTensor``
        The loss contribution from the unlabeled arcs.
    tag_loss : ``torch.FloatTensor``
        The loss contribution from predicting the dependency tags.
    loss : ``torch.FloatTensor``
        ``arc_loss + tag_loss``; computed against the model's own
        predictions when gold heads/tags are not both supplied.
    mask : ``torch.LongTensor``
        The padding mask, extended by the ROOT position.
    words, pos :
        The corresponding entries collected from ``metadata``.

    Raises
    ------
    ConfigurationError
        If the model has a POS embedding but ``pos_tags`` is ``None``.
    """
    embedded_text_input = self.text_field_embedder(words)
    if self._pos_tag_embedding is not None:
        if pos_tags is None:
            raise ConfigurationError("Model uses a POS embedding, but no POS tags were passed.")
        pos_embeddings = self._pos_tag_embedding(pos_tags)
        embedded_text_input = torch.cat([embedded_text_input, pos_embeddings], -1)

    mask = get_text_field_mask(words)
    encoded_text_orig = self.encoder(self._input_dropout(embedded_text_input), mask)
    encoder_final_state = get_final_encoder_states(encoded_text_orig, mask)

    batch_size, _, encoding_dim = encoded_text_orig.size()
    root = self._head_sentinel.expand(batch_size, 1, encoding_dim)
    # Prepend the ROOT sentinel; mask and gold annotations get a matching
    # extra leading column.
    encoded_text = torch.cat([root, encoded_text_orig], 1)
    mask = torch.cat([mask.new_ones(batch_size, 1), mask], 1)
    if head_indices is not None:
        head_indices = torch.cat([head_indices.new_zeros(batch_size, 1), head_indices], 1)
    if head_tags is not None:
        head_tags = torch.cat([head_tags.new_zeros(batch_size, 1), head_tags], 1)
    float_mask = mask.float()
    encoded_text = self._dropout(encoded_text)

    # shape (batch_size, sequence_length, arc_representation_dim)
    head_arc_representation = self._dropout(self.head_arc_feedforward(encoded_text))
    child_arc_representation = self._dropout(self.child_arc_feedforward(encoded_text))
    # shape (batch_size, sequence_length, tag_representation_dim)
    head_tag_representation = self._dropout(self.head_tag_feedforward(encoded_text))
    child_tag_representation = self._dropout(self.child_tag_feedforward(encoded_text))

    # shape (batch_size, sequence_length, sequence_length)
    attended_arcs = self.arc_attention(head_arc_representation, child_arc_representation)

    # Push scores of arcs touching padded positions towards -inf.
    minus_inf = -1e8
    padding_penalty = (1 - float_mask) * minus_inf
    attended_arcs = (attended_arcs
                     + padding_penalty.unsqueeze(2)
                     + padding_penalty.unsqueeze(1))

    use_greedy = self.training or not self.use_mst_decoding_for_validation
    decode = self._greedy_decode if use_greedy else self._mst_decode
    predicted_heads, predicted_head_tags = decode(head_tag_representation,
                                                  child_tag_representation,
                                                  attended_arcs,
                                                  mask)

    has_gold = head_indices is not None and head_tags is not None
    if has_gold:
        target_heads, target_tags = head_indices, head_tags
    else:
        target_heads = predicted_heads.long()
        target_tags = predicted_head_tags.long()

    arc_nll, tag_nll = self._construct_loss(head_tag_representation=head_tag_representation,
                                            child_tag_representation=child_tag_representation,
                                            attended_arcs=attended_arcs,
                                            head_indices=target_heads,
                                            head_tags=target_tags,
                                            mask=mask)
    loss = arc_nll + tag_nll

    if has_gold:
        evaluation_mask = self._get_mask_for_eval(mask[:, 1:], pos_tags)
        # Attachment scores cover the whole sentence except the symbolic
        # ROOT token at the start, hence the slices from position 1.
        self._attachment_scores(predicted_heads[:, 1:],
                                predicted_head_tags[:, 1:],
                                head_indices[:, 1:],
                                head_tags[:, 1:],
                                evaluation_mask)

    return {
        "encoder_final_state": encoder_final_state,
        "encoded_text": encoded_text_orig,
        "heads": predicted_heads,
        "head_tags": predicted_head_tags,
        "arc_loss": arc_nll,
        "tag_loss": tag_nll,
        "loss": loss,
        "mask": mask,
        "words": [meta["words"] for meta in metadata],
        "pos": [meta["pos"] for meta in metadata],
    }
def _parse(self,
           embedded_text_input: torch.Tensor,
           mask: torch.LongTensor,
           head_tags: torch.LongTensor = None,
           head_indices: torch.LongTensor = None,
           grammar_values: torch.LongTensor = None,
           lemma_indices: torch.LongTensor = None):
    """Predict three grammar/lemma variants per token and a dependency parse.

    Every encoder output is copied three times and run through
    ``self.multilabeler_lstm``; the grammar-value and lemma heads are applied
    to each of the three LSTM outputs. The dependency part prepends a ROOT
    sentinel, so ``heads``, ``head_tags`` and ``mask`` in the result are one
    position longer than the input.

    Parameters
    ----------
    embedded_text_input : token embeddings fed to the encoder.
    mask : padding mask for the tokens.
    head_tags, head_indices : optional gold dependency annotations; when
        either is missing the parser loss uses the model's own predictions.
    grammar_values, lemma_indices : optional gold targets; when given the
        corresponding NLL is computed (and the accuracy metric is passed to
        the helper), otherwise that NLL is ``torch.tensor(0.)``.

    Returns
    -------
    A dict with keys ``heads``, ``head_tags``, ``gram_vals``, ``lemmas``,
    ``mask``, ``arc_nll``, ``tag_nll``, ``grammar_nll`` and ``lemma_nll``.
    """
    num_variants = 3

    embedded_text_input = self._input_dropout(embedded_text_input)
    encoded_text = self.encoder(embedded_text_input, mask)
    batch_size, seq_len, encoding_dim = encoded_text.size()

    # Add an axis that maps every encoder output to three copies of itself.
    encoded_text_3 = encoded_text.unsqueeze(2).repeat(1, 1, num_variants, 1)
    # Run the three copies of each encoder vector through the LSTM as a
    # length-3 sequence.
    multi_triplets = encoded_text_3.reshape(-1, num_variants, encoding_dim)
    label_variants, _ = self.multilabeler_lstm(multi_triplets)
    # Use the LSTM's own output width here: reshaping with the encoder width
    # would silently scramble the batch axis whenever the two differ.
    batched_label_variants = label_variants.reshape(
        -1, seq_len, num_variants, label_variants.size(-1))

    grammar_value_logits = self._gram_val_output(batched_label_variants)
    predicted_gram_vals = grammar_value_logits.argmax(-1)

    lemma_logits = self._lemma_output(batched_label_variants)
    predicted_lemmas = lemma_logits.argmax(-1)

    head_sentinel = self._head_sentinel.expand(batch_size, 1, encoding_dim)
    # Concatenate the head sentinel onto the sentence representation.
    encoded_text = torch.cat([head_sentinel, encoded_text], 1)

    # Keep the token-level mask (without ROOT) for the tagging losses.
    token_mask = mask.float()

    mask = torch.cat([mask.new_ones(batch_size, 1), mask], 1)
    if head_indices is not None:
        head_indices = torch.cat([head_indices.new_zeros(batch_size, 1), head_indices], 1)
    if head_tags is not None:
        head_tags = torch.cat([head_tags.new_zeros(batch_size, 1), head_tags], 1)
    float_mask = mask.float()
    encoded_text = self._dropout(encoded_text)

    # shape (batch_size, sequence_length, arc_representation_dim)
    head_arc_representation = self._dropout(self.head_arc_feedforward(encoded_text))
    child_arc_representation = self._dropout(self.child_arc_feedforward(encoded_text))

    # shape (batch_size, sequence_length, tag_representation_dim)
    head_tag_representation = self._dropout(self.head_tag_feedforward(encoded_text))
    child_tag_representation = self._dropout(self.child_tag_feedforward(encoded_text))

    # shape (batch_size, sequence_length, sequence_length)
    attended_arcs = self.arc_attention(head_arc_representation,
                                       child_arc_representation)

    # Push scores of arcs touching padded positions towards -inf.
    minus_inf = -1e8
    minus_mask = (1 - float_mask) * minus_inf
    attended_arcs = attended_arcs + minus_mask.unsqueeze(2) + minus_mask.unsqueeze(1)

    if self.training or not self.use_mst_decoding_for_validation:
        decode = self._greedy_decode
    else:
        decode = self._mst_decode
    predicted_heads, predicted_head_tags = decode(head_tag_representation,
                                                  child_tag_representation,
                                                  attended_arcs,
                                                  mask)

    if head_indices is not None and head_tags is not None:
        target_heads, target_tags = head_indices, head_tags
    else:
        target_heads = predicted_heads.long()
        target_tags = predicted_head_tags.long()

    arc_nll, tag_nll = self._construct_loss(head_tag_representation=head_tag_representation,
                                            child_tag_representation=child_tag_representation,
                                            attended_arcs=attended_arcs,
                                            head_indices=target_heads,
                                            head_tags=target_tags,
                                            mask=mask)

    grammar_nll = torch.tensor(0.)
    if grammar_values is not None:
        # Repeat the token mask once per variant.
        token_mask_3 = token_mask.unsqueeze(2).repeat(1, 1, num_variants)
        grammar_nll = self._update_multiclass_prediction_metrics_3(
            logits=grammar_value_logits, targets=grammar_values, mask=token_mask_3,
            accuracy_metric=self._gram_val_prediction_accuracy
        )

    lemma_nll = torch.tensor(0.)
    if lemma_indices is not None:
        token_mask_3 = token_mask.unsqueeze(2).repeat(1, 1, num_variants)
        lemma_nll = self._update_multiclass_prediction_metrics_3(
            logits=lemma_logits, targets=lemma_indices, mask=token_mask_3,
            accuracy_metric=self._lemma_prediction_accuracy
        )

    output_dict = {
        "heads": predicted_heads,
        "head_tags": predicted_head_tags,
        "gram_vals": predicted_gram_vals,
        "lemmas": predicted_lemmas,
        "mask": mask,
        "arc_nll": arc_nll,
        "tag_nll": tag_nll,
        "grammar_nll": grammar_nll,
        "lemma_nll": lemma_nll,
    }

    return output_dict
def scatter_topk_2d_flat(
    src: Tensor, index: LongTensor, k: int, dim_size=None, fill_value=None
) -> Tuple[Tensor, Tuple[LongTensor, LongTensor], Tuple[LongTensor, LongTensor]]:
    """Finds the top k values in a 2D array partitioned along the dimension 0.

    Consecutive rows of ``src`` form chunks; each chunk is flattened and its
    ``k`` largest entries are reported together with their (row, column)
    positions.

    ::

        +-----------------------+
        |      X                |
        |  X                    |
        |              X        |
        |    X                  |
        +-----------------------+
        |                       |
        |                 Y     |         +-------+
        |     Y                 |         |X X X X|
        |                       |  top 4  +-------+
        |                       | ------> |Y Y Y Y|
        |   Y                   |         +-------+
        |                       |         |Z Z Z Z|
        |            Y          |         +-------+
        +-----------------------+
        |    Z     Z            |
        |                       |
        |  Z         Z          |
        +-----------------------+

    Args:
        src: 2D tensor of values.
        index: chunk id of every row of ``src``. Rows are consumed
            consecutively using the per-id counts, so rows of one chunk must
            be contiguous and ids in ascending order.
        k: number of entries to keep per chunk.
        dim_size: number of chunks; defaults to ``index.max() + 1``.
        fill_value: value for unused result slots when a chunk has fewer
            than ``k`` entries; defaults to NaN.

    Returns:
        A triple ``(values, (rows_whole, cols_whole), (rows_in_chunk,
        cols_in_chunk))``. Every tensor has shape ``[dim_size, k]``. Row
        indices are given both relative to the whole of ``src`` and relative
        to the start of the chunk; unused index slots hold ``-1``.

    Raises:
        ValueError: if ``src`` is not 2-dimensional.
    """
    if src.ndimension() != 2:
        raise ValueError("Only implemented for 2D tensors")

    if dim_size is None:
        dim_size = index.max().item() + 1

    if fill_value is None:
        fill_value = float("NaN")

    ncols = src.shape[1]

    result_values = src.new_full((dim_size, k), fill_value=fill_value)
    result_indexes_whole_0 = index.new_full((dim_size, k), fill_value=-1)
    result_indexes_whole_1 = index.new_full((dim_size, k), fill_value=-1)
    result_indexes_within_chunk_0 = index.new_full((dim_size, k), fill_value=-1)
    result_indexes_within_chunk_1 = index.new_full((dim_size, k), fill_value=-1)

    # Number of rows belonging to each chunk id.
    chunk_sizes = (
        index.new_zeros(dim_size)
        .scatter_add_(dim=0, index=index, src=torch.ones_like(index))
        .tolist()
    )

    start_src = 0
    for chunk_idx, chunk_size in enumerate(chunk_sizes):
        flat_chunk = src[start_src : start_src + chunk_size, :].flatten()
        # A chunk may hold fewer than k entries; take what is available.
        flat_values, flat_indexes = torch.topk(
            flat_chunk, k=min(k, chunk_size * ncols), dim=0
        )
        found = len(flat_values)
        result_values[chunk_idx, :found] = flat_values

        # Flat position -> (row, column). Floor division keeps the row index
        # an exact integer tensor; true division would produce floats.
        indexes_0 = flat_indexes // ncols
        indexes_1 = flat_indexes % ncols
        result_indexes_within_chunk_0[chunk_idx, :found] = indexes_0
        result_indexes_within_chunk_1[chunk_idx, :found] = indexes_1
        result_indexes_whole_0[chunk_idx, :found] = indexes_0 + start_src
        result_indexes_whole_1[chunk_idx, :found] = indexes_1

        start_src += chunk_size

    return (
        result_values,
        (result_indexes_whole_0, result_indexes_whole_1),
        (result_indexes_within_chunk_0, result_indexes_within_chunk_1),
    )
def _parse(self,
           embedded_text_input: torch.Tensor,
           mask: torch.LongTensor,
           head_tags: torch.LongTensor = None,
           head_indices: torch.LongTensor = None,
           grammar_values: torch.LongTensor = None,
           lemma_indices: torch.LongTensor = None):
    """Jointly predict grammar values, lemma rules and a dependency parse.

    The grammar-value and lemma heads read the raw encoder output; the
    parser part prepends a ROOT sentinel, so ``heads``, ``head_tags`` and
    ``mask`` in the result are one position longer than the input.

    Parameters
    ----------
    embedded_text_input : token embeddings fed to the encoder.
    mask : padding mask for the tokens.
    head_tags, head_indices : optional gold dependency annotations; when
        either is missing the parser loss uses the model's own predictions.
    grammar_values, lemma_indices : optional gold targets; when absent the
        corresponding NLL is ``torch.tensor(0.)``.

    Returns
    -------
    A dict with keys ``heads``, ``head_tags``, ``gram_vals``, ``lemmas``,
    ``mask``, ``arc_nll``, ``tag_nll``, ``grammar_nll`` and ``lemma_nll``.
    """
    encoded_text = self.encoder(self._input_dropout(embedded_text_input), mask)

    grammar_value_logits = self._gram_val_output(encoded_text)
    predicted_gram_vals = grammar_value_logits.argmax(-1)

    lemma_logits = self._lemma_output(encoded_text)
    predicted_lemmas = lemma_logits.argmax(-1)

    batch_size, _, encoding_dim = encoded_text.size()
    root = self._head_sentinel.expand(batch_size, 1, encoding_dim)
    # Prepend the ROOT sentinel to the sentence representation.
    encoded_text = torch.cat([root, encoded_text], 1)

    # Token-level mask (without ROOT), used by the tagging losses below.
    token_mask = mask.float()

    mask = torch.cat([mask.new_ones(batch_size, 1), mask], 1)
    if head_indices is not None:
        root_heads = head_indices.new_zeros(batch_size, 1)
        head_indices = torch.cat([root_heads, head_indices], 1)
    if head_tags is not None:
        root_tags = head_tags.new_zeros(batch_size, 1)
        head_tags = torch.cat([root_tags, head_tags], 1)
    float_mask = mask.float()
    encoded_text = self._dropout(encoded_text)

    # shape (batch_size, sequence_length, arc_representation_dim)
    head_arc_representation = self._dropout(self.head_arc_feedforward(encoded_text))
    child_arc_representation = self._dropout(self.child_arc_feedforward(encoded_text))
    # shape (batch_size, sequence_length, tag_representation_dim)
    head_tag_representation = self._dropout(self.head_tag_feedforward(encoded_text))
    child_tag_representation = self._dropout(self.child_tag_feedforward(encoded_text))

    # shape (batch_size, sequence_length, sequence_length)
    attended_arcs = self.arc_attention(head_arc_representation,
                                       child_arc_representation)

    # Push scores of arcs touching padded positions towards -inf.
    minus_inf = -1e8
    padding_penalty = (1 - float_mask) * minus_inf
    attended_arcs = (attended_arcs
                     + padding_penalty.unsqueeze(2)
                     + padding_penalty.unsqueeze(1))

    use_greedy = self.training or not self.use_mst_decoding_for_validation
    decode = self._greedy_decode if use_greedy else self._mst_decode
    predicted_heads, predicted_head_tags = decode(
        head_tag_representation, child_tag_representation, attended_arcs, mask)

    if head_indices is not None and head_tags is not None:
        target_heads, target_tags = head_indices, head_tags
    else:
        target_heads = predicted_heads.long()
        target_tags = predicted_head_tags.long()

    arc_nll, tag_nll = self._construct_loss(
        head_tag_representation=head_tag_representation,
        child_tag_representation=child_tag_representation,
        attended_arcs=attended_arcs,
        head_indices=target_heads,
        head_tags=target_tags,
        mask=mask)

    if grammar_values is None:
        grammar_nll = torch.tensor(0.)
    else:
        grammar_nll = self._update_multiclass_prediction_metrics(
            logits=grammar_value_logits,
            targets=grammar_values,
            mask=token_mask,
            accuracy_metric=self._gram_val_prediction_accuracy)

    if lemma_indices is None:
        lemma_nll = torch.tensor(0.)
    else:
        lemma_nll = self._update_multiclass_prediction_metrics(
            logits=lemma_logits,
            targets=lemma_indices,
            mask=token_mask,
            accuracy_metric=self._lemma_prediction_accuracy,
            masked_index=self.lemmatize_helper.UNKNOWN_RULE_INDEX)

    return {
        "heads": predicted_heads,
        "head_tags": predicted_head_tags,
        "gram_vals": predicted_gram_vals,
        "lemmas": predicted_lemmas,
        "mask": mask,
        "arc_nll": arc_nll,
        "tag_nll": tag_nll,
        "grammar_nll": grammar_nll,
        "lemma_nll": lemma_nll,
    }
def __call__(self,
             entity_ids: torch.LongTensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Returns the set of valid parent entities for a given mention.

    Parameters
    ----------
    entity_ids : ``torch.LongTensor``
        A tensor of shape ``(batch_size, sequence_length)`` whose elements are the ids of the
        corresponding token in the ``target`` sequence.

    Returns
    -------
    A tuple ``(candidate_ids, candidate_mask)`` containing the following elements:
    candidate_ids : ``torch.LongTensor``
        A tensor of shape ``(batch_size, n_candidates)`` of all of the candidates for each
        batch element.
    candidate_mask : ``torch.LongTensor``
        A tensor of shape ``(batch_size, sequence_length, n_candidates)`` defining which
        subset of candidates can be selected at the given point in the sequence.

    Notes
    -----
    The call is stateful: ``self._remaining`` carries, per batch element, how
    many positions of a parent's window spill over into the next call.
    """
    batch_size, sequence_length = entity_ids.shape[:2]

    # TODO: See if we can get away without nested loops / cast to CPU.
    candidate_ids = self._get_candidates(entity_ids)
    # Per batch element: parent id -> column in the candidate axis.
    candidate_lookup = [
        {parent_id: column for column, parent_id in enumerate(row)}
        for row in candidate_ids.tolist()
    ]

    # Create mask
    candidate_mask = entity_ids.new_zeros(
        size=(batch_size, sequence_length, candidate_ids.shape[-1]),
        dtype=torch.uint8)

    # Start by accounting for unfinished masks that remain from the last batch
    for i, lookup in enumerate(self._remaining):
        for parent_id, remainder in lookup.items():
            # Find index w.r.t. the **current** set of candidates
            k = candidate_lookup[i][parent_id]
            # Fill in the remaining amount of mask
            candidate_mask[i, :remainder, k] = 1
            # If splits are really short, then we might still have some remaining
            lookup[parent_id] -= sequence_length

    # Cast to list so we can use elements as keys (not possible for tensors)
    parent_id_list = entity_ids.tolist()
    for i, j, *_, parent_id in nested_enumerate(parent_id_list):
        if parent_id == 0:
            # Id 0 never opens a window.
            continue
        k = candidate_lookup[i][parent_id]
        # The parent is selectable for the ``_cutoff`` positions after j.
        window_end = j + self._cutoff + 1
        candidate_mask[i, j + 1: window_end, k] = 1
        # Part of the window that falls beyond this sequence; non-positive
        # values are pruned below.
        self._remaining[i][parent_id] = window_end - sequence_length

    # Remove any ids for non-recent parents (e.g. those without remaining mask)
    for i, lookup in enumerate(self._remaining):
        self._remaining[i] = {
            key: value for key, value in lookup.items() if value > 0
        }

    return candidate_ids, candidate_mask
def forward(
    self,  # type: ignore
    token_ids: torch.LongTensor,
    type_ids: torch.LongTensor,
    offsets: torch.LongTensor,
    wordpiece_mask: torch.BoolTensor,
    dep_idxs: torch.LongTensor,
    dep_tags: torch.LongTensor,
    pos_tags: torch.LongTensor,
    word_mask: torch.BoolTensor,
):
    """
    Embed and encode a batch, score every head/child arc and compute the parsing loss.

    A learned sentinel (``self._head_sentinel``) is prepended to every sequence, so
    ``word_mask``, ``dep_idxs`` and ``dep_tags`` each get one extra leading column
    (ones for the mask, zeros for the gold heads and tags) before scoring.

    Returns
    -------
    A tuple ``(predicted_heads, predicted_head_tags, arc_nll, tag_nll)``. Heads and
    tags come from ``self._greedy_decode`` in training mode and from
    ``self._mst_decode`` otherwise; the two losses come from ``self._construct_loss``
    evaluated against the gold ``dep_idxs`` / ``dep_tags``.
    """
    features = self.get_word_embedding(
        token_ids=token_ids,
        offsets=offsets,
        wordpiece_mask=wordpiece_mask,
        type_ids=type_ids,
    )
    if self.pos_embedding is not None:
        features = torch.cat([features, self.pos_embedding(pos_tags)], -1)
    if self.fuse_layer is not None:
        features = self.fuse_layer(features)
    # todo compare normal dropout with InputVariationalDropout
    features = self._input_dropout(features)

    # Optional extra encoder stacked on top of the embeddings.
    if self.additional_encoder is None:
        encoded_text = features
    elif self.config.additional_layer_type == "transformer":
        attention_mask = self.bert.get_extended_attention_mask(
            word_mask, word_mask.size(), word_mask.device)
        # Only the first element of the encoder output is used.
        encoded_text = self.additional_encoder(
            hidden_states=features, attention_mask=attention_mask)[0]
    else:
        encoded_text = self.additional_encoder(inputs=features, mask=word_mask)

    # Prepend the head sentinel and pad mask / gold labels to match.
    batch_size, _, encoding_dim = encoded_text.size()
    sentinel = self._head_sentinel.expand(batch_size, 1, encoding_dim)
    encoded_text = torch.cat([sentinel, encoded_text], 1)
    word_mask = torch.cat([word_mask.new_ones(batch_size, 1), word_mask], 1)
    dep_idxs = torch.cat([dep_idxs.new_zeros(batch_size, 1), dep_idxs], 1)
    dep_tags = torch.cat([dep_tags.new_zeros(batch_size, 1), dep_tags], 1)

    encoded_text = self._dropout(encoded_text)

    # The four projections are computed in this fixed order so that the
    # dropout layers consume random numbers in a stable sequence.
    # shape (batch_size, sequence_length, arc_representation_dim)
    head_arc = self._dropout(self.head_arc_feedforward(encoded_text))
    child_arc = self._dropout(self.child_arc_feedforward(encoded_text))
    # shape (batch_size, sequence_length, tag_representation_dim)
    head_tag = self._dropout(self.head_tag_feedforward(encoded_text))
    child_tag = self._dropout(self.child_tag_feedforward(encoded_text))

    # shape (batch_size, sequence_length, sequence_length)
    attended_arcs = self.arc_attention(head_arc, child_arc)

    # Push every arc that touches a masked-out position towards minus infinity.
    minus_inf = -1e8
    penalty = (~word_mask) * minus_inf
    attended_arcs = attended_arcs + penalty.unsqueeze(2) + penalty.unsqueeze(1)

    decode = self._greedy_decode if self.training else self._mst_decode
    predicted_heads, predicted_head_tags = decode(
        head_tag, child_tag, attended_arcs, word_mask)

    arc_nll, tag_nll = self._construct_loss(
        head_tag_representation=head_tag,
        child_tag_representation=child_tag,
        attended_arcs=attended_arcs,
        head_indices=dep_idxs,
        head_tags=dep_tags,
        mask=word_mask,
    )

    return predicted_heads, predicted_head_tags, arc_nll, tag_nll
def _parse(
    self,
    encoded_text: torch.Tensor,
    mask: torch.LongTensor,
    head_tags: torch.LongTensor = None,
    head_indices: torch.LongTensor = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Score arcs over already-encoded text, decode a parse and compute its loss.

    A learned sentinel (``self._head_sentinel``) is prepended to every sequence;
    ``mask`` and, when given, ``head_indices`` / ``head_tags`` get a matching
    leading column (ones for the mask, zeros for the labels).

    Decoding is greedy in training mode or when
    ``self.use_mst_decoding_for_validation`` is false, and uses
    ``self._mst_decode`` otherwise. The loss is computed against the gold labels
    when both ``head_indices`` and ``head_tags`` are supplied, and against the
    model's own predictions otherwise.

    Returns
    -------
    ``(predicted_heads, predicted_head_tags, mask, arc_nll, tag_nll)`` where
    ``mask`` is the extended mask that includes the sentinel column.
    """
    batch_size, _, encoding_dim = encoded_text.size()

    # Concatenate the head sentinel onto the sentence representation.
    sentinel = self._head_sentinel.expand(batch_size, 1, encoding_dim)
    encoded_text = torch.cat([sentinel, encoded_text], 1)
    mask = torch.cat([mask.new_ones(batch_size, 1), mask], 1)
    if head_indices is not None:
        head_indices = torch.cat(
            [head_indices.new_zeros(batch_size, 1), head_indices], 1)
    if head_tags is not None:
        head_tags = torch.cat(
            [head_tags.new_zeros(batch_size, 1), head_tags], 1)

    encoded_text = self._dropout(encoded_text)

    # The projections are evaluated in this fixed order so that the dropout
    # layers consume random numbers in a stable sequence.
    # shape (batch_size, sequence_length, arc_representation_dim)
    head_arc = self._dropout(self.head_arc_feedforward(encoded_text))
    child_arc = self._dropout(self.child_arc_feedforward(encoded_text))
    # shape (batch_size, sequence_length, tag_representation_dim)
    head_tag = self._dropout(self.head_tag_feedforward(encoded_text))
    child_tag = self._dropout(self.child_tag_feedforward(encoded_text))

    # shape (batch_size, sequence_length, sequence_length)
    attended_arcs = self.arc_attention(head_arc, child_arc)

    # Push every arc that touches a masked-out position towards minus infinity.
    minus_inf = -1e8
    penalty = (1 - mask.float()) * minus_inf
    attended_arcs = attended_arcs + penalty.unsqueeze(2) + penalty.unsqueeze(1)

    use_mst = not self.training and self.use_mst_decoding_for_validation
    decode = self._mst_decode if use_mst else self._greedy_decode
    predicted_heads, predicted_head_tags = decode(
        head_tag, child_tag, attended_arcs, mask)

    # Without a complete set of gold labels, the predictions act as targets.
    if head_indices is not None and head_tags is not None:
        target_heads, target_tags = head_indices, head_tags
    else:
        target_heads = predicted_heads.long()
        target_tags = predicted_head_tags.long()

    arc_nll, tag_nll = self._construct_loss(
        head_tag_representation=head_tag,
        child_tag_representation=child_tag,
        attended_arcs=attended_arcs,
        head_indices=target_heads,
        head_tags=target_tags,
        mask=mask,
    )

    return predicted_heads, predicted_head_tags, mask, arc_nll, tag_nll
def edit_distance_q_values(sampled_y: torch.LongTensor,
                           gold_y: torch.LongTensor,
                           end_symbol_id: int,
                           vocab_size: int) -> torch.FloatTensor:
    """
    Compute OCD (Optimal Completion Distillation) Q-values from edit distances.

    For every prefix of ``sampled_y`` a batched Levenshtein table against all
    prefixes of ``gold_y`` is built. Gold prefixes selected by
    ``OCD.edit_distance_mask`` are excluded by overwriting them with
    ``LARGE_NUM``. For each step, every gold token that follows a
    minimum-distance gold prefix gets the value ``-min_dist``; all other tokens
    get ``-(1 + min_dist)``.

    Args:
        sampled_y (`~torch.LongTensor`): ``(batch_size, sequence_length)``
        gold_y (`~torch.LongTensor`): ``(batch_size, sequence_length)``
        end_symbol_id (int): index of the end symbol in your vocabulary.
        vocab_size (int): the number of possible output tokens.

    Returns:
        `~torch.FloatTensor`: ``(batch_size, sequence_length, vocab_size)``

    Example:
        from paper `https://arxiv.org/abs/1810.01398`

        vocabulary = {'S':0, 'U':1, 'N':2, 'D':3, 'A':4 , 'Y':5,
                      'T':6, 'R':7, 'P':8, '</s>':9, '<pad>': 10}
        vocab_size = 11
        end_symbol_id = 9

        gold Y = {'SUNDAY</s><pad><pad>', 'SUNDAY</s><pad><pad>'}
        gold_y = [[0, 1, 2, 3, 4, 5, 9, 10, 10],
                  [0, 1, 2, 3, 4, 5, 9, 10, 10]]

        sampled Y = {'SATURDAY</s>', 'SATRAPY</s>U'}
        sampled_y = [[0, 4, 6, 1, 7, 3, 4, 5, 9],
                     [0, 4, 6, 7, 4, 8, 5, 9, 1]]

        # expected size: (batch_size=2, sequence_length=9, vocab_size=11)
        expected q_values =
            [[[ 0., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.],
              [-1.,  0., -1., -1., -1., -1., -1., -1., -1., -1., -1.],
              [-2., -1., -1., -2., -2., -2., -2., -2., -2., -2., -2.],
              [-3., -2., -2., -2., -3., -3., -3., -3., -3., -3., -3.],
              [-3., -3., -2., -3., -3., -3., -3., -3., -3., -3., -3.],
              [-4., -4., -3., -3., -4., -4., -4., -4., -4., -4., -4.],
              [-4., -4., -4., -4., -3., -4., -4., -4., -4., -4., -4.],
              [-4., -4., -4., -4., -4., -3., -4., -4., -4., -4., -4.],
              [-4., -4., -4., -4., -4., -4., -4., -4., -4., -3., -4.]],
             [[ 0., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.],
              [-1.,  0., -1., -1., -1., -1., -1., -1., -1., -1., -1.],
              [-2., -1., -1., -2., -2., -2., -2., -2., -2., -2., -2.],
              [-3., -2., -2., -2., -3., -3., -3., -3., -3., -3., -3.],
              [-4., -3., -3., -3., -3., -4., -4., -4., -4., -4., -4.],
              [-4., -4., -4., -4., -4., -3., -4., -4., -4., -4., -4.],
              [-5., -5., -5., -5., -5., -4., -5., -5., -5., -4., -5.],
              [-5., -5., -5., -5., -5., -5., -5., -5., -5., -4., -5.],
              [-6., -6., -6., -6., -6., -6., -6., -6., -6., -5., -6.]]]
    """
    assert gold_y.size() == sampled_y.size()
    batch, length = gold_y.size()
    # One extra row/column for the empty prefix.
    n_steps = length + 1

    q_values = gold_y.new_zeros((batch, n_steps, vocab_size), dtype=torch.float)
    edit_dists = gold_y.new_zeros((batch, n_steps, n_steps), dtype=torch.float)

    # run batch version of the levenshtein algorithm
    # (rows index prefixes of sampled_y, columns index prefixes of gold_y)
    border = torch.arange(n_steps)
    edit_dists[:, :, 0] = border
    edit_dists[:, 0, :] = border
    for row in range(1, n_steps):
        for col in range(1, n_steps):
            substitution = (sampled_y[:, row - 1] != gold_y[:, col - 1]).float()
            options = torch.stack(
                (
                    edit_dists[:, row - 1, col] + 1,
                    edit_dists[:, row, col - 1] + 1,
                    edit_dists[:, row - 1, col - 1] + substitution,
                ),
                dim=1,
            )
            edit_dists[:, row, col] = options.min(dim=1)[0]

    # find gold next tokens and update their QValues
    edit_dists_mask = OCD.edit_distance_mask(gold_y, end_symbol_id)
    edit_dists = edit_dists.masked_fill_(edit_dists_mask, LARGE_NUM)

    min_dist = edit_dists.min(dim=2)[0].unsqueeze(dim=2)
    steps_with_min_dists = edit_dists == min_dist

    # Each hit (b, step, prefix_len) marks a gold prefix at minimal distance;
    # the token that follows that prefix is gold_y[b, prefix_len].
    batch_idx, step_idx, prefix_len = steps_with_min_dists.nonzero().unbind(dim=1)
    gold_next_tokens = gold_y[batch_idx, prefix_len]
    q_values[batch_idx, step_idx, gold_next_tokens] = 1

    q_values = q_values - (1 + min_dist)
    return q_values[:, :-1, :]  # ignore the step 'seq_len + 1'
def forward(self,  # type: ignore
            words: Dict[str, torch.LongTensor],
            pos_tags: torch.LongTensor,
            metadata: List[Dict[str, Any]],
            head_tags: torch.LongTensor = None,
            head_indices: torch.LongTensor = None) -> Dict[str, torch.Tensor]:
    # pylint: disable=arguments-differ
    """
    Parameters
    ----------
    words : Dict[str, torch.LongTensor], required
        The output of ``TextField.as_array()``, which should typically be passed
        directly to a ``TextFieldEmbedder``. It is a dictionary mapping keys to
        ``TokenIndexer`` tensors; at its most basic, using a
        ``SingleIdTokenIndexer``, this is
        ``{"tokens": Tensor(batch_size, sequence_length)}``. The keys are the same
        ones used for the ``TokenIndexers`` of the ``TextField`` representing the
        sequence, and the embedder knows how to combine the different word
        representations into a single vector per token.
    pos_tags : ``torch.LongTensor``, required.
        The output of a ``SequenceLabelField`` containing POS tags. They are
        needed even when the model does not embed them, because the evaluation
        mask is built from them (via ``self._get_mask_for_eval``).
    metadata : List[Dict[str, Any]], required
        One dictionary per instance; its ``"words"`` and ``"pos"`` entries are
        copied into the output.
    head_tags : torch.LongTensor, optional (default = None)
        The integer gold class labels for the arcs in the dependency parse.
        Has shape ``(batch_size, sequence_length)``.
    head_indices : torch.LongTensor, optional (default = None)
        The integer indices denoting the parent of every word in the dependency
        parse. Has shape ``(batch_size, sequence_length)``.

    Returns
    -------
    An output dictionary consisting of:
    heads : ``torch.FloatTensor``
        The predicted head indices for each position, including the prepended
        ROOT sentinel.
    head_tags : ``torch.FloatTensor``
        The predicted head types for each arc, including the ROOT position.
    arc_loss : ``torch.FloatTensor``
        The loss contribution from the unlabeled arcs.
    tag_loss : ``torch.FloatTensor``
        The loss contribution from predicting the dependency tags.
    loss : ``torch.FloatTensor``
        ``arc_loss + tag_loss``. When gold heads or tags are missing, the
        predictions themselves are used as targets.
    mask : ``torch.LongTensor``
        The padding mask, extended by one leading column for the ROOT sentinel.
    words, pos : ``List``
        The ``"words"`` and ``"pos"`` entries of every ``metadata`` item.
    """
    embedded_text_input = self.text_field_embedder(words)
    if self._pos_tag_embedding is not None:
        if pos_tags is None:
            raise ConfigurationError("Model uses a POS embedding, but no POS tags were passed.")
        embedded_text_input = torch.cat(
            [embedded_text_input, self._pos_tag_embedding(pos_tags)], -1)

    mask = get_text_field_mask(words)
    embedded_text_input = self._input_dropout(embedded_text_input)
    encoded_text = self.encoder(embedded_text_input, mask)

    # Concatenate the head sentinel onto the sentence representation and
    # extend the mask / gold labels by one leading column to match.
    batch_size, _, encoding_dim = encoded_text.size()
    sentinel = self._head_sentinel.expand(batch_size, 1, encoding_dim)
    encoded_text = torch.cat([sentinel, encoded_text], 1)
    mask = torch.cat([mask.new_ones(batch_size, 1), mask], 1)
    if head_indices is not None:
        head_indices = torch.cat([head_indices.new_zeros(batch_size, 1), head_indices], 1)
    if head_tags is not None:
        head_tags = torch.cat([head_tags.new_zeros(batch_size, 1), head_tags], 1)

    encoded_text = self._dropout(encoded_text)

    # The projections are evaluated in this fixed order so that the dropout
    # layers consume random numbers in a stable sequence.
    # shape (batch_size, sequence_length, arc_representation_dim)
    head_arc = self._dropout(self.head_arc_feedforward(encoded_text))
    child_arc = self._dropout(self.child_arc_feedforward(encoded_text))
    # shape (batch_size, sequence_length, tag_representation_dim)
    head_tag = self._dropout(self.head_tag_feedforward(encoded_text))
    child_tag = self._dropout(self.child_tag_feedforward(encoded_text))

    # shape (batch_size, sequence_length, sequence_length)
    attended_arcs = self.arc_attention(head_arc, child_arc)

    # Push every arc that touches a masked-out position towards minus infinity.
    minus_inf = -1e8
    penalty = (1 - mask.float()) * minus_inf
    attended_arcs = attended_arcs + penalty.unsqueeze(2) + penalty.unsqueeze(1)

    use_mst = not self.training and self.use_mst_decoding_for_validation
    decode = self._mst_decode if use_mst else self._greedy_decode
    predicted_heads, predicted_head_tags = decode(
        head_tag, child_tag, attended_arcs, mask)

    # Without a complete set of gold labels, the predictions act as targets.
    has_gold = head_indices is not None and head_tags is not None
    if has_gold:
        target_heads, target_tags = head_indices, head_tags
    else:
        target_heads = predicted_heads.long()
        target_tags = predicted_head_tags.long()

    arc_nll, tag_nll = self._construct_loss(head_tag_representation=head_tag,
                                            child_tag_representation=child_tag,
                                            attended_arcs=attended_arcs,
                                            head_indices=target_heads,
                                            head_tags=target_tags,
                                            mask=mask)
    loss = arc_nll + tag_nll

    if has_gold:
        evaluation_mask = self._get_mask_for_eval(mask[:, 1:], pos_tags)
        # We calculate attachment scores for the whole sentence
        # but excluding the symbolic ROOT token at the start,
        # which is why we start from the second element in the sequence.
        self._attachment_scores(predicted_heads[:, 1:],
                                predicted_head_tags[:, 1:],
                                head_indices[:, 1:],
                                head_tags[:, 1:],
                                evaluation_mask)

    return {
        "heads": predicted_heads,
        "head_tags": predicted_head_tags,
        "arc_loss": arc_nll,
        "tag_loss": tag_nll,
        "loss": loss,
        "mask": mask,
        "words": [meta["words"] for meta in metadata],
        "pos": [meta["pos"] for meta in metadata],
    }
def forward(self, inputs: Dict[str, Any],
            head_tags: torch.LongTensor = None,
            head_indices: torch.LongTensor = None):
    """
    Embed, encode and parse a batch, returning predictions and losses.

    Parameters
    ----------
    inputs : Dict[str, Any]
        Maps the name of every input layer in ``self.input_layers`` to that
        layer's input, plus a ``'length'`` entry holding the sequence lengths
        from which the padding mask is built.
    head_tags : torch.LongTensor, optional (default = None)
        Gold arc labels; a zero column is prepended for the root sentinel.
    head_indices : torch.LongTensor, optional (default = None)
        Gold head indices; a zero column is prepended for the root sentinel.

    Returns
    -------
    A dictionary with the keys ``heads``, ``head_tags``, ``arc_loss``,
    ``tag_loss``, ``loss`` (``arc_loss + tag_loss``) and ``mask`` (the padding
    mask including the sentinel column).

    In training mode with gold labels no decoding is run, so ``heads`` and
    ``head_tags`` are ``None``. When gold heads or tags are missing, the
    predictions are used as loss targets; in training mode they are then
    obtained by greedy decoding (previously this case crashed by calling
    ``.long()`` on ``None``).
    """
    with self.input_encoding_timer:
        embedded_input = {name: layer(inputs[name])
                          for name, layer in self.input_layers.items()}

        encoded_pieces = []
        for encoder_ in self.input_encoders:
            args_ = {name: embedded_input[name]
                     for name in encoder_.get_ordered_names()}
            encoded_pieces.append(self.input_dropout_(encoder_(args_)))

        encoded_input = torch.cat(encoded_pieces, dim=-1)
        # encoded_input: (batch_size, seq_len, input_dim)

    with self.context_encoding_timer:
        mask = get_mask_from_sequence_lengths(inputs['length'], inputs['length'].max())
        # mask: (batch_size, seq_len)

        context_encoded_input = self.encoder(encoded_input, mask)
        # context_encoded_input: (batch_size, seq_len, encoded_dim)

        # handle the sentinel/dummy root.
        batch_size, _, encoding_dim = context_encoded_input.size()
        head_sentinel = self.head_sentinel_.expand(batch_size, 1, encoding_dim)
        context_encoded_input = torch.cat(
            [head_sentinel, context_encoded_input], 1)
        # context_encoded_input: (batch_size, seq_len + 1, encoded_dim)

    with self.classification_timer:
        mask = torch.cat([mask.new_ones(batch_size, 1), mask], 1)
        # mask: (batch_size, seq_len + 1)

        if head_indices is not None:
            head_indices = torch.cat(
                [head_indices.new_zeros(batch_size, 1), head_indices], 1)
        if head_tags is not None:
            head_tags = torch.cat(
                [head_tags.new_zeros(batch_size, 1), head_tags], 1)

        context_encoded_input = self.dropout_(context_encoded_input)

        head_arc_representation = self.head_arc_feedforward(context_encoded_input)
        child_arc_representation = self.child_arc_feedforward(context_encoded_input)

        head_tag_representation = self.head_tag_feedforward(context_encoded_input)
        child_tag_representation = self.child_tag_feedforward(context_encoded_input)
        # head_tag_representation / child_tag_representation: (batch_size, seq_len + 1, dim)

        # Head and child representations are concatenated along the sequence
        # axis so that each pair passes through a single dropout call, then
        # split back into their two halves.
        arc_representation = self.dropout_(
            torch.cat([head_arc_representation, child_arc_representation], dim=1))
        tag_representation = self.dropout_(
            torch.cat([head_tag_representation, child_tag_representation], dim=1))

        head_arc_representation, child_arc_representation = arc_representation.chunk(2, dim=1)
        head_tag_representation, child_tag_representation = tag_representation.chunk(2, dim=1)

        head_tag_representation = head_tag_representation.contiguous()
        child_tag_representation = child_tag_representation.contiguous()

        # attended_arcs: (batch_size, seq_len + 1, seq_len + 1)
        # NOTE(review): no explicit -inf penalty is added for padded positions
        # here; presumably the decode/loss helpers apply `mask` themselves - confirm.
        attended_arcs = self.arc_attention(head_arc_representation,
                                           child_arc_representation)

        has_gold = head_indices is not None and head_tags is not None

        if not self.training:
            if self.use_mst_decoding_for_validation:
                decode = self._mst_decode
            else:
                decode = self._greedy_decode
            predicted_heads, predicted_head_tags = decode(
                head_tag_representation, child_tag_representation,
                attended_arcs, mask)
        elif not has_gold:
            # Training normally skips decoding, but without gold labels the
            # predictions are needed as loss targets below.
            predicted_heads, predicted_head_tags = self._greedy_decode(
                head_tag_representation, child_tag_representation,
                attended_arcs, mask)
        else:
            predicted_heads, predicted_head_tags = None, None

        if has_gold:
            target_heads, target_tags = head_indices, head_tags
        else:
            target_heads = predicted_heads.long()
            target_tags = predicted_head_tags.long()

        arc_nll, tag_nll = self._construct_loss(
            head_tag_representation=head_tag_representation,
            child_tag_representation=child_tag_representation,
            attended_arcs=attended_arcs,
            head_indices=target_heads,
            head_tags=target_tags,
            mask=mask)
        loss = arc_nll + tag_nll

    output_dict = {
        "heads": predicted_heads,
        "head_tags": predicted_head_tags,
        "arc_loss": arc_nll,
        "tag_loss": tag_nll,
        "loss": loss,
        "mask": mask,
    }

    return output_dict