def _get_checklist_info(agenda: torch.LongTensor,
                        all_actions: List[ProductionRule],
                        terminal_productions: Set[str],
                        max_num_terminals: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Takes an agenda, a list of all actions, a set of terminal productions in the corresponding
    world, and a length to pad the checklist vectors to, and returns a target checklist against
    which the checklist at each state will be compared to compute a loss, indices of
    ``terminal_actions``, and a ``checklist_mask`` that indicates which of the terminal actions are
    relevant for checklist loss computation.

    Parameters
    ----------
    ``agenda`` : ``torch.LongTensor``
        Agenda of one instance of size ``(agenda_size, 1)``.
    ``all_actions`` : ``List[ProductionRule]``
        All actions for one instance.
    ``terminal_productions`` : ``Set[str]``
        String representations of terminal productions in the corresponding world.
    ``max_num_terminals`` : ``int``
        Length to which the checklist vectors will be padded till. This is the max number of
        terminal productions in all the worlds in the batch.

    Returns
    -------
    ``target_checklist`` : ``torch.Tensor``
        Float column vector with one row per terminal action (plus padding rows): 1 where the
        terminal action is on the agenda, 0 otherwise.
    ``terminal_actions`` : ``torch.Tensor``
        Column vector, same dtype as ``agenda``, holding the index into ``all_actions`` of each
        terminal action; padding rows hold -1.
    ``checklist_mask`` : ``torch.Tensor``
        Float column vector that is 1 exactly where ``target_checklist`` is non-zero.
    """
    # Flatten the agenda before converting it to Python ints. ``squeeze(0)`` leaves an
    # ``(agenda_size, 1)`` tensor two-dimensional, and calling ``int()`` on the resulting
    # one-element arrays is deprecated in NumPy.
    agenda_indices_set = set(agenda.detach().reshape(-1).tolist())

    terminal_indices = []
    target_checklist_list = []
    # We want to return checklist target and terminal actions that are column vectors to make
    # computing softmax over the difference between checklist and target easier.
    for index, action in enumerate(all_actions):
        # Each action is a ProductionRule, a tuple where the first item is the production
        # rule string.
        if action[0] in terminal_productions:
            terminal_indices.append([index])
            target_checklist_list.append([1 if index in agenda_indices_set else 0])

    # Pad up to ``max_num_terminals`` rows; nothing is added when there are already enough.
    num_padding = max(0, max_num_terminals - len(target_checklist_list))
    target_checklist_list.extend([0] for _ in range(num_padding))
    terminal_indices.extend([-1] for _ in range(num_padding))

    # (max_num_terminals, 1)
    terminal_actions = agenda.new_tensor(terminal_indices)
    # (max_num_terminals, 1)
    target_checklist = agenda.new_tensor(target_checklist_list, dtype=torch.float)
    checklist_mask = (target_checklist != 0).float()
    return target_checklist, terminal_actions, checklist_mask
def _action_history_match(predicted: List[int], targets: torch.LongTensor) -> int:
    """
    Returns 1 if ``predicted`` matches the leading ``len(predicted)`` actions of any target
    sequence, and 0 otherwise.

    Parameters
    ----------
    ``predicted`` : ``List[int]``
        Predicted action indices.
    ``targets`` : ``torch.LongTensor``
        Target action sequences, one per row, of shape ``(num_targets, target_length)``.

    Returns
    -------
    ``int``
        1 on a match, 0 otherwise. An empty prediction, an empty set of targets, or a prediction
        longer than the target rows never matches.
    """
    # TODO(mattg): this could probably be moved into a FullSequenceMatch metric, or something.
    num_predicted = len(predicted)
    # An empty prediction is treated as a miss rather than trivially matching every target.
    if num_predicted == 0:
        return 0
    # Check if target is big enough to cover prediction (including start/end symbols)
    if num_predicted > targets.size(1):
        return 0
    predicted_tensor = targets.new_tensor(predicted)
    targets_trimmed = targets[:, :num_predicted]
    # A row matches when every position agrees; ``any`` over zero rows is False, so an empty
    # set of targets yields 0 instead of raising.
    row_matches = targets_trimmed.eq(predicted_tensor).all(dim=1)
    # Return 1 if the predicted sequence is anywhere in the list of targets.
    return int(row_matches.any().item())
def _get_checklist_info(self,
                        agenda: torch.LongTensor,
                        all_actions: List[ProductionRuleArray]) -> Tuple[torch.Tensor,
                                                                          torch.Tensor,
                                                                          torch.Tensor]:
    """
    Takes an agenda and a list of all actions and returns a target checklist against which the
    checklist at each state will be compared to compute a loss, indices of ``terminal_actions``,
    and a ``checklist_mask`` that indicates which of the terminal actions are relevant for
    checklist loss computation. If ``self._penalize_non_agenda_actions`` is set to ``True``,
    ``checklist_mask`` will be all 1s (i.e., all terminal actions are relevant). If it is set to
    ``False``, indices of all terminals that are not in the agenda will be masked.

    Parameters
    ----------
    ``agenda`` : ``torch.LongTensor``
        Agenda of one instance of size ``(agenda_size, 1)``.
    ``all_actions`` : ``List[ProductionRuleArray]``
        All actions for one instance.

    Returns
    -------
    ``target_checklist`` : ``torch.Tensor``
        Float column vector with one row per terminal action: 1 if that action is on the agenda,
        0 otherwise.
    ``terminal_actions`` : ``torch.Tensor``
        Column vector, same dtype as ``agenda``, with the index into ``all_actions`` of each
        terminal action.
    ``checklist_mask`` : ``torch.Tensor``
        Float column vector marking the rows that take part in the checklist loss.
    """
    # ``squeeze(0)`` does not flatten an ``(agenda_size, 1)`` tensor, so flatten explicitly and
    # convert with ``tolist()``; this avoids calling ``int()`` on one-element NumPy arrays,
    # which is deprecated.
    agenda_indices_set = set(agenda.detach().reshape(-1).tolist())

    terminal_indices = []
    target_checklist_list = []
    for index, action in enumerate(all_actions):
        # Each action is a ProductionRuleArray, a tuple where the first item is the production
        # rule string.
        if action[0] not in self._terminal_productions:
            continue
        terminal_indices.append([index])
        target_checklist_list.append([1 if index in agenda_indices_set else 0])

    # We want to return checklist target and terminal actions that are column vectors to make
    # computing softmax over the difference between checklist and target easier.
    # (num_terminals, 1)
    terminal_actions = agenda.new_tensor(terminal_indices)
    # (num_terminals, 1)
    target_checklist = agenda.new_tensor(target_checklist_list, dtype=torch.float)
    if self._penalize_non_agenda_actions:
        # All terminal actions are relevant
        checklist_mask = torch.ones_like(target_checklist)
    else:
        checklist_mask = (target_checklist != 0).float()
    return target_checklist, terminal_actions, checklist_mask
def _action_history_match(predicted: List[int], targets: torch.LongTensor) -> int:
    """
    Returns 1 if ``predicted`` equals the first ``len(predicted)`` entries of ``targets``, and 0
    otherwise.

    Parameters
    ----------
    ``predicted`` : ``List[int]``
        Predicted action indices.
    ``targets`` : ``torch.LongTensor``
        A single target action sequence, indexed along its first dimension.

    Returns
    -------
    ``int``
        1 on a match, 0 otherwise (always a plain ``int``, as the signature declares).
    """
    # TODO(mattg): this could probably be moved into a FullSequenceMatch metric, or something.
    num_predicted = len(predicted)
    # Check if target is big enough to cover prediction (including start/end symbols)
    if num_predicted > targets.size(0):
        return 0
    predicted_tensor = targets.new_tensor(predicted)
    # Only the leading ``num_predicted`` target entries take part in the comparison.
    targets_trimmed = targets[:num_predicted]
    # ``Tensor.equal`` yields a Python bool; convert so callers get the declared 0/1 int.
    return int(predicted_tensor.equal(targets_trimmed))
def _get_checklist_info(
    self, agenda: torch.LongTensor, all_actions: List[ProductionRule]
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Takes an agenda and a list of all actions and returns a target checklist against which the
    checklist at each state will be compared to compute a loss, indices of ``terminal_actions``,
    and a ``checklist_mask`` that indicates which of the terminal actions are relevant for
    checklist loss computation. If ``self._penalize_non_agenda_actions`` is set to ``True``,
    ``checklist_mask`` will be all 1s (i.e., all terminal actions are relevant). If it is set to
    ``False``, indices of all terminals that are not in the agenda will be masked.

    Parameters
    ----------
    ``agenda`` : ``torch.LongTensor``
        Agenda of one instance of size ``(agenda_size, 1)``.
    ``all_actions`` : ``List[ProductionRule]``
        All actions for one instance.

    Returns
    -------
    ``target_checklist`` : ``torch.Tensor``
        Float column vector with one row per terminal action: 1 if that action is on the agenda,
        0 otherwise.
    ``terminal_actions`` : ``torch.Tensor``
        Column vector, same dtype as ``agenda``, with the index into ``all_actions`` of each
        terminal action.
    ``checklist_mask`` : ``torch.Tensor``
        Float column vector marking the rows that take part in the checklist loss.
    """
    # Flatten explicitly: ``squeeze(0)`` leaves an ``(agenda_size, 1)`` agenda two-dimensional,
    # and converting the resulting one-element NumPy rows with ``int()`` is deprecated.
    agenda_indices_set = set(agenda.detach().flatten().tolist())

    # Each action is a ProductionRule, a tuple where the first item is the production
    # rule string.
    terminal_positions = [
        index
        for index, action in enumerate(all_actions)
        if action[0] in self._terminal_productions
    ]
    terminal_indices = [[index] for index in terminal_positions]
    target_checklist_list = [[int(index in agenda_indices_set)] for index in terminal_positions]

    # We want to return checklist target and terminal actions that are column vectors to make
    # computing softmax over the difference between checklist and target easier.
    # (num_terminals, 1)
    terminal_actions = agenda.new_tensor(terminal_indices)
    # (num_terminals, 1)
    target_checklist = agenda.new_tensor(target_checklist_list, dtype=torch.float)
    if self._penalize_non_agenda_actions:
        # All terminal actions are relevant
        checklist_mask = torch.ones_like(target_checklist)
    else:
        checklist_mask = (target_checklist != 0).float()
    return target_checklist, terminal_actions, checklist_mask
def __init__(self,
             start_tokens: torch.LongTensor,
             end_token: Union[int, torch.LongTensor]):
    """
    Store the start tokens and the end token after validating their shapes.

    Parameters
    ----------
    start_tokens : ``torch.LongTensor``
        One-dimensional tensor; its length is recorded as the batch size.
    end_token : ``Union[int, torch.LongTensor]``
        Either a Python ``int`` or a zero-dimensional tensor. An ``int`` is converted to a
        tensor via ``start_tokens.new_tensor``; a tensor is stored as given.

    Raises
    ------
    ValueError
        If ``start_tokens`` is not one-dimensional, or a non-``int`` ``end_token`` is not
        zero-dimensional.
    """
    if start_tokens.dim() != 1:
        raise ValueError("start_tokens must be a vector")

    end_is_python_int = isinstance(end_token, int)
    if not (end_is_python_int or end_token.dim() == 0):
        raise ValueError("end_token must be a scalar")

    self._start_tokens = start_tokens
    self._batch_size = start_tokens.size(0)
    # Plain ints are lifted to a tensor built from ``start_tokens``; tensors are kept as-is.
    self._end_token = start_tokens.new_tensor(end_token) if end_is_python_int else end_token
def is_equal(self,
             predicted: List[int],
             targets: torch.LongTensor,
             target_mask: torch.LongTensor) -> int:
    """
    Compare a predicted sequence with the unpadded part of the target sequence.

    The number of real target entries is taken to be ``target_mask.sum()``, and the first that
    many entries of ``targets`` are compared with ``predicted``.

    :param predicted: predicted token/action ids
    :param targets: target ids, indexed along the first dimension
    :param target_mask: mask whose sum gives the unpadded target length
    :return: 1 if the prediction equals the unpadded target, otherwise 0
    """
    # A prediction longer than the (padded) target can never match.
    if targets.size(0) < len(predicted):
        return 0

    candidate = targets.new_tensor(predicted)
    # Drop the padded tail of the target before comparing.
    unpadded_length = target_mask.sum()
    gold = targets[:unpadded_length]
    return 1 if torch.equal(candidate, gold) else 0
def _get_candidates(self, entity_ids: torch.LongTensor) -> torch.LongTensor:
    """
    Merges the ids seen in the current batch with the ids carried over in ``self._remaining``
    and returns, per batch element, the sorted set of **all** relevant ids.

    Parameters
    ----------
    entity_ids : ``torch.LongTensor``
        A tensor of shape ``(batch_size, sequence_length)`` whose elements are the ids of the
        corresponding token in the ``target`` sequence.

    Returns
    -------
    unique_entity_ids : ``torch.LongTensor``
        A tensor of shape ``(batch_size, max_num_parents)`` containing all of the unique
        candidate ids, with rows right-padded with zeros.
    """
    # Collect one sorted tensor of unique ids per batch element.
    per_row_unique = []
    for row_index, row_ids in enumerate(entity_ids):
        carried_over = self._remaining[row_index]
        if carried_over is not None:
            # Append the ids remembered from earlier steps before de-duplicating.
            carried_ids = entity_ids.new_tensor(list(carried_over.keys()))
            row_ids = torch.cat((row_ids.view(-1), carried_ids), dim=0)
        per_row_unique.append(torch.unique(row_ids, sorted=True))

    # Stack the ragged rows into a single zero-padded tensor.
    num_rows = entity_ids.shape[0]
    widest = max(row.shape[0] for row in per_row_unique)
    unique_entity_ids = entity_ids.new_zeros(size=(num_rows, widest))
    for row_index, row in enumerate(per_row_unique):
        unique_entity_ids[row_index, :row.shape[0]] = row
    return unique_entity_ids
def segment_lengths_to_ids(
        segment_lengths: torch.LongTensor) -> torch.LongTensor:
    """
    Expand per-segment lengths into a per-element segment id.

    Args:
        segment_lengths: Non-negative lengths of the tensor segments

    Returns:
        A tensor containing ids for every element in the tensor to be
        segmented, with the same dtype and device as ``segment_lengths``

    Raises:
        ValueError: If ``segment_lengths`` is not one-dimensional or has a
            negative entry.
        TypeError: If ``segment_lengths`` has a floating-point dtype.

    Examples:
        >>> segments = torch.tensor([2, 4, 3, 1])
        >>> segment_lengths_to_ids(segments)
        tensor([0, 0, 1, 1, 1, 1, 2, 2, 2, 3])
    """
    if segment_lengths.dim() != 1:
        raise ValueError(
            f'`segment_lengths` should have a single dimension, got shape {segment_lengths.shape}'
        )
    if (segment_lengths < 0).any():
        raise ValueError(
            'All entries in `segment_lengths` should be non-negative')
    # Fractional repeat counts are meaningless; reject them instead of truncating silently.
    if segment_lengths.is_floating_point():
        raise TypeError(
            f'`segment_lengths` should have an integer dtype, got {segment_lengths.dtype}')

    # Repeat each segment index by its length directly on the input's device,
    # avoiding a round trip through NumPy on the CPU.
    repeats = segment_lengths.to(torch.long)
    segment_ids = torch.arange(repeats.numel(), device=repeats.device)
    return torch.repeat_interleave(segment_ids, repeats).to(segment_lengths.dtype)
def forward(self,
            word_ids: torch.LongTensor,
            word_segment_ids: torch.LongTensor,
            word_attention_mask: torch.LongTensor,
            entity_ids: torch.LongTensor,
            entity_position_ids: torch.LongTensor,
            entity_segment_ids: torch.LongTensor,
            entity_attention_mask: torch.LongTensor,
            masked_entity_labels: Optional[torch.LongTensor] = None,
            masked_lm_labels: Optional[torch.LongTensor] = None,
            **kwargs):
    """
    Run the parent model and compute the masked-entity and masked-word prediction losses.

    The seven word/entity inputs are forwarded unchanged to the parent class's ``forward``;
    the first two items of its output are used as the word and entity sequence outputs.
    ``**kwargs`` is accepted but not used in this method.

    Label tensors use -1 for positions that should not be predicted: those positions are
    filtered out before scoring, and the loss function is also built with ``ignore_index=-1``.
    NOTE(review): the label tensors are presumably aligned position-by-position with the
    corresponding sequence outputs -- confirm against the caller.

    Returns
    -------
    A ``dict`` that always contains ``"loss"`` (the sum of whichever losses were computed,
    zero if none). When ``masked_entity_labels`` is given it also contains
    ``"masked_entity_loss"``, ``"masked_entity_correct"`` and ``"masked_entity_total"``; when
    ``masked_lm_labels`` is given it also contains ``"masked_lm_loss"``,
    ``"masked_lm_correct"`` and ``"masked_lm_total"``.
    """
    model_dtype = next(self.parameters()).dtype  # for fp16 compatibility
    output = super().forward(
        word_ids,
        word_segment_ids,
        word_attention_mask,
        entity_ids,
        entity_position_ids,
        entity_segment_ids,
        entity_attention_mask,
    )
    word_sequence_output, entity_sequence_output = output[:2]
    loss_fn = CrossEntropyLoss(ignore_index=-1)
    # Start from a zero loss in the model's dtype; the per-task losses are added in place.
    ret = dict(loss=word_ids.new_tensor(0.0, dtype=model_dtype))

    if masked_entity_labels is not None:
        # True where an entity label is present (i.e. not the -1 filler).
        entity_mask = masked_entity_labels != -1
        if entity_mask.sum() > 0:
            # Keep only the hidden vectors at labelled positions; the mask is broadcast over
            # the last dimension and the flat result is reshaped to rows of hidden_size.
            target_entity_sequence_output = torch.masked_select(
                entity_sequence_output, entity_mask.unsqueeze(-1))
            target_entity_sequence_output = target_entity_sequence_output.view(
                -1, self.config.hidden_size)
            target_entity_labels = torch.masked_select(
                masked_entity_labels, entity_mask)
            entity_scores = self.entity_predictions(
                target_entity_sequence_output)
            entity_scores = entity_scores.view(
                -1, self.config.entity_vocab_size)
            ret["masked_entity_loss"] = loss_fn(entity_scores,
                                                target_entity_labels)
            # Number of labelled positions whose arg-max score equals the label.
            ret["masked_entity_correct"] = (torch.argmax(
                entity_scores, 1).data == target_entity_labels.data).sum()
            ret["masked_entity_total"] = target_entity_labels.ne(-1).sum()
            ret["loss"] += ret["masked_entity_loss"]
        else:
            # Labels were supplied but none is set: report zeros so the keys still exist.
            ret["masked_entity_loss"] = word_ids.new_tensor(
                0.0, dtype=model_dtype)
            ret["masked_entity_correct"] = word_ids.new_tensor(
                0, dtype=torch.long)
            ret["masked_entity_total"] = word_ids.new_tensor(
                0, dtype=torch.long)

    if masked_lm_labels is not None:
        # True where a word label is present (i.e. not the -1 filler).
        masked_lm_mask = masked_lm_labels != -1
        if masked_lm_mask.sum() > 0:
            # Same selection as for entities, applied to the word sequence output.
            masked_word_sequence_output = torch.masked_select(
                word_sequence_output,
                masked_lm_mask.unsqueeze(-1))
            masked_word_sequence_output = masked_word_sequence_output.view(
                -1, self.config.hidden_size)
            # The prediction head depends on the configured base model name: ``lm_head``
            # when the name contains "roberta", ``cls.predictions`` otherwise.
            if self.config.bert_model_name and "roberta" in self.config.bert_model_name:
                masked_lm_scores = self.lm_head(
                    masked_word_sequence_output)
            else:
                masked_lm_scores = self.cls.predictions(
                    masked_word_sequence_output)
            masked_lm_scores = masked_lm_scores.view(
                -1, self.config.vocab_size)
            # From here on ``masked_lm_labels`` holds only the labelled positions.
            masked_lm_labels = torch.masked_select(masked_lm_labels,
                                                   masked_lm_mask)
            ret["masked_lm_loss"] = loss_fn(masked_lm_scores,
                                            masked_lm_labels)
            ret["masked_lm_correct"] = (torch.argmax(
                masked_lm_scores, 1).data == masked_lm_labels.data).sum()
            ret["masked_lm_total"] = masked_lm_labels.ne(-1).sum()
            ret["loss"] += ret["masked_lm_loss"]
        else:
            # Labels were supplied but none is set: report zeros so the keys still exist.
            ret["masked_lm_loss"] = word_ids.new_tensor(0.0,
                                                        dtype=model_dtype)
            ret["masked_lm_correct"] = word_ids.new_tensor(
                0, dtype=torch.long)
            ret["masked_lm_total"] = word_ids.new_tensor(0,
                                                         dtype=torch.long)
    return ret