Example #1
def scatter_sort(
    src: Tensor,
    index: LongTensor,
    descending=False,
    dim_size=None,
    out: Optional[Tuple[Tensor, LongTensor]] = None,
) -> Tuple[Tensor, LongTensor]:
    if src.ndimension() > 1:
        raise ValueError("Only implemented for 1D tensors")

    if dim_size is None:
        dim_size = int(index.max().item()) + 1

    if out is None:
        result_values = torch.empty_like(src)
        result_indexes = index.new_empty(src.shape)
    else:
        result_values, result_indexes = out

    sizes = (
        index.new_zeros(dim_size)
        .scatter_add_(dim=0, index=index, src=torch.ones_like(index))
        .tolist()
    )

    start = 0
    for size in sizes:
        end = start + size
        values, indexes = torch.sort(src[start:end], dim=0, descending=descending)
        result_values[start:end] = values
        result_indexes[start:end] = indexes + start
        start = end

    return result_values, result_indexes
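
A minimal usage sketch (hypothetical values; assumes index is sorted so that each group occupies a contiguous slice, which the loop above relies on):

import torch

src = torch.tensor([3.0, 1.0, 2.0, 5.0, 4.0])
index = torch.tensor([0, 0, 0, 1, 1])   # two contiguous groups
values, perm = scatter_sort(src, index)
# values -> tensor([1., 2., 3., 4., 5.])
# perm   -> tensor([1, 2, 0, 4, 3]), positions into src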
Example #2
def scatter_topk(
    src: Tensor, index: LongTensor, k: int, num_chunks=None, fill_value=None
) -> Tuple[Tensor, LongTensor, LongTensor]:
    """

    Args:
        src:
        index: must be sorted in ascending order
        k:
        num_chunks:
        fill_value:

    Returns: A 1D tensor of shape [num_chunks * k]

    """
    if src.ndimension() > 1:
        raise ValueError("Only implemented for 1D tensors")

    if num_chunks is None:
        num_chunks = index.max().item() + 1

    if fill_value is None:
        fill_value = float("NaN")

    result_values = src.new_full((num_chunks * k,), fill_value=fill_value)
    result_indexes_whole = index.new_full((num_chunks * k,), fill_value=-1)
    result_indexes_within_chunk = index.new_full((num_chunks * k,), fill_value=-1)

    chunk_sizes = (
        index.new_zeros(num_chunks)
        .scatter_add_(dim=0, index=index, src=torch.ones_like(index))
        .tolist()
    )

    start = 0
    for chunk_idx, chunk_size in enumerate(chunk_sizes):
        chunk = src[start : start + chunk_size]
        values, indexes = torch.topk(chunk, k=min(k, chunk_size), dim=0)

        result_values[chunk_idx * k : chunk_idx * k + len(values)] = values
        result_indexes_within_chunk[
            chunk_idx * k : chunk_idx * k + len(indexes)
        ] = indexes
        result_indexes_whole[chunk_idx * k : chunk_idx * k + len(indexes)] = (
            indexes + start
        )

        start += chunk_size

    return result_values, result_indexes_whole, result_indexes_within_chunk
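
A quick sketch of the padding behaviour (hypothetical values; index sorted, as the docstring requires):

import torch

src = torch.tensor([0.1, 0.9, 0.5, 0.3])
index = torch.tensor([0, 0, 0, 1])
values, idx_whole, idx_within = scatter_topk(src, index, k=2)
# values     -> tensor([0.9000, 0.5000, 0.3000,    nan])
# idx_whole  -> tensor([1, 2, 3, -1])   # indices into src
# idx_within -> tensor([1, 2, 0, -1])   # indices inside each chunk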
Example #3
def to_onehot(labels: torch.LongTensor, seq_len: int) -> torch.FloatTensor:
    """
    Convert categorical start/end labels to one-hot labels, used for computing BCE loss.
    Args:
        labels: tensor of shape [bsz]
        seq_len: sequence length
    Returns:
        onehot_labels: tensor of shape [bsz, seq_len]
    """
    bsz = labels.size(0)
    labels = labels.unsqueeze(-1)  # [bsz, 1]
    onehot = labels.new_zeros([bsz, seq_len])
    onehot.scatter_(1, labels, 1)  # onehot[i][labels[i][j]] = 1
    onehot = onehot.float()
    return onehot
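
For instance (a hypothetical call):

import torch

labels = torch.tensor([2, 0])
to_onehot(labels, seq_len=4)
# tensor([[0., 0., 1., 0.],
#         [1., 0., 0., 0.]])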
Example #4
    def _action_to_token(self, action_tokens: torch.LongTensor,
                         draft_tokens: torch.LongTensor) -> torch.LongTensor:
        predicted_pointer = action_tokens.new_zeros((draft_tokens.size(0), 1))
        draft_pointer = draft_tokens.new_ones((draft_tokens.size(0), 1))

        predicted_tokens = action_tokens.new_full((action_tokens.size()),
                                                  self.END)

        for act_step in action_tokens.t():
            # KEEP, DROP, COPY, ADD (other)
            keep_mask = act_step == self.KEEP
            drop_mask = act_step == self.DROP
            add_mask = ~(keep_mask | drop_mask)

            predicted_tokens.scatter_(1, predicted_pointer,
                                      draft_tokens.gather(1, draft_pointer))
            predicted_tokens[add_mask] = predicted_tokens[add_mask].scatter(
                1, predicted_pointer[add_mask],
                act_step[add_mask].unsqueeze(1))

            draft_pointer[keep_mask | drop_mask] += 1
            predicted_pointer[~drop_mask] += 1
        return predicted_tokens
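
The core move above is the paired gather/scatter on per-row pointers: read one draft token per row, write it at one predicted position per row. A standalone sketch of that primitive (toy values, independent of the class):

import torch

draft = torch.tensor([[10, 11, 12], [20, 21, 22]])
pointer = torch.tensor([[1], [2]])            # one column index per row
out = torch.zeros_like(draft)
out.scatter_(1, pointer, draft.gather(1, pointer))
# out -> tensor([[ 0, 11,  0], [ 0,  0, 22]])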
Example #5
    def _get_candidates(self,
                        entity_ids: torch.LongTensor) -> torch.LongTensor:
        """
        Combines the unique ids from the current batch with the previous set of ids to form the
        collection of **all** relevant ids.

        Parameters
        ----------
        entity_ids : ``torch.LongTensor``
            A tensor of shape ``(batch_size, sequence_length)`` whose elements are the ids
            of the corresponding token in the ``target`` sequence.

        Returns
        -------
        unique_entity_ids : ``torch.LongTensor``
            A tensor of shape ``(batch_size, max_num_parents)`` containing all of the unique
            candidate ids.
        """
        # Get the tensors of unique ids for each batch element and store them in a list
        all_unique: List[torch.LongTensor] = []
        for i, ids in enumerate(entity_ids):
            if self._remaining[i] is not None:
                previous_ids = list(self._remaining[i].keys())
                previous_ids = entity_ids.new_tensor(previous_ids)
                ids = torch.cat((ids.view(-1), previous_ids), dim=0)
            unique = torch.unique(ids, sorted=True)
            all_unique.append(unique)

        # Convert the list to a tensor by adding adequate padding.
        batch_size = entity_ids.shape[0]
        max_num_parents = max(unique.shape[0] for unique in all_unique)
        unique_entity_ids = entity_ids.new_zeros(
            size=(batch_size, max_num_parents))
        for i, unique in enumerate(all_unique):
            unique_entity_ids[i, :unique.shape[0]] = unique

        return unique_entity_ids
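
The unique-then-pad idiom this method relies on, reduced to its essentials (toy ids; padding with 0 as in new_zeros):

import torch

ids = torch.tensor([[3, 1, 3, 2], [5, 5, 0, 0]])
uniques = [torch.unique(row, sorted=True) for row in ids]
width = max(u.numel() for u in uniques)
padded = ids.new_zeros((ids.size(0), width))
for i, u in enumerate(uniques):
    padded[i, :u.numel()] = u
# padded -> tensor([[1, 2, 3], [0, 5, 0]])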
Example #6
    def _calculate_edit_distance(self, output_symbols: torch.LongTensor,
                                 targets: torch.LongTensor,
                                 mask: torch.BoolTensor) -> torch.FloatTensor:
        batch_size, max_pred_len = output_symbols.size()
        _, max_len = targets.size()

        distances = output_symbols.new_zeros(batch_size, max_pred_len, max_len)
        distances[:, :, 0] = torch.arange(max_pred_len)
        distances[:, 0, :] = torch.arange(max_len)
        distances = distances.float()

        for i in range(1, max_pred_len):
            for j in range(1, max_len):
                diagonal = distances[:, i-1, j-1] + \
                    self.dsub * (output_symbols[:, i-1] != targets[:, j-1]).float()
                comp = torch.stack(
                    (diagonal, distances[:, i - 1, j] + self.dins,
                     distances[:, i, j - 1] + self.ddel),
                    dim=-1)
                distances[:, i, j], _ = torch.min(comp, dim=-1)

        #edit_distance_mask = self._get_edit_distance_mask(mask, output_symbols)
        distances = distances.masked_fill(~mask.unsqueeze(1), float('inf'))
        return distances
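
For reference, the same Wagner-Fischer recurrence unbatched and with unit costs (the method above vectorizes this over the batch, keeps the full DP table, and parameterizes the costs as dsub/dins/ddel):

def edit_distance(a, b):
    d = [[0] * (len(b) + 1) for _ in range(len(a) + 1)]
    for i in range(len(a) + 1):
        d[i][0] = i
    for j in range(len(b) + 1):
        d[0][j] = j
    for i in range(1, len(a) + 1):
        for j in range(1, len(b) + 1):
            d[i][j] = min(d[i - 1][j - 1] + (a[i - 1] != b[j - 1]),  # substitution
                          d[i - 1][j] + 1,                           # insertion
                          d[i][j - 1] + 1)                           # deletion
    return d[len(a)][len(b)]

assert edit_distance("kitten", "sitting") == 3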
Example #7
    def forward(self,  # type: ignore
                tokens: Dict[str, torch.LongTensor],
                spans: torch.LongTensor, 
                gold_spans: torch.LongTensor, 
                tags: torch.LongTensor = None,
                span_labels: torch.LongTensor = None,
                gold_span_labels: torch.LongTensor = None,
                metadata: List[Dict[str, Any]] = None,
                **kwargs) -> Dict[str, torch.Tensor]:
        '''
            tags: Shape (batch_size, seq_len)
                BILOU scheme tags for CRF modelling
        '''
        
        batch_size = spans.size(0)
        # Adding mask
        mask = util.get_text_field_mask(tokens)

        token_mask = torch.cat([mask, 
                                mask.new_zeros(batch_size, 1)],
                                dim=1)

        embedded_text_input = self.text_field_embedder(tokens)

        embedded_text_input = torch.cat([embedded_text_input, 
                                         embedded_text_input.new_zeros(batch_size, 1, embedded_text_input.size(2))],
                                        dim=1)

        # span_mask Shape: (batch_size, num_spans), 1 or 0
        span_mask = (spans[:, :, 0] >= 0).squeeze(-1).float()
        gold_span_mask = (gold_spans[:,:,0] >=0).squeeze(-1).float()
        last_span_indices = gold_span_mask.sum(-1,keepdim=True).long()

        batch_indices = torch.arange(batch_size).unsqueeze(-1)
        batch_indices = util.move_to_device(batch_indices, 
                                            util.get_device_of(embedded_text_input))
        last_span_indices = torch.cat([batch_indices, last_span_indices],dim=-1)
        embedded_text_input[last_span_indices[:,0], last_span_indices[:,1]] += self.end_token_embedding.cuda(util.get_device_of(spans))

        token_mask[last_span_indices[:,0], last_span_indices[:,1]] += 1.
        # SpanFields return -1 when they are used as padding. As we do
        # some comparisons based on span widths when we attend over the
        # span representations that we generate from these indices, we
        # need them to be <= 0. This is only relevant in edge cases where
        # the number of spans we consider after the pruning stage is >= the
        # total number of spans, because in this case, it is possible we might
        # consider a masked span.

        # spans Shape: (batch_size, num_spans, 2)
        spans = F.relu(spans.float()).long()
        gold_spans = F.relu(gold_spans.float()).long()
        num_spans = spans.size(1)
        num_gold_spans = gold_spans.size(1)

        # Shape (batch_size, num_gold_spans, 4)
        hscrf_target = torch.cat([gold_spans, gold_spans.new_zeros(*gold_spans.size())],
                                 dim=-1)
        hscrf_target[:,:,2] = torch.cat([
            (gold_span_labels.new_zeros(batch_size, 1)+self.hscrf_layer.start_id).long(), # start tags in the front
            gold_span_labels.squeeze()[:,0:-1]],
            dim=-1)
        hscrf_target[:,:,3] = gold_span_labels.squeeze()
        # Shape (batch_size, num_gold_spans+1, 4)  including an <end> singular-span
        hscrf_target = torch.cat([hscrf_target, gold_spans.new_zeros(batch_size, 1, 4)],
                                 dim=1)

        hscrf_target[last_span_indices[:,0], last_span_indices[:,1],0:2] = \
                hscrf_target[last_span_indices[:,0], last_span_indices[:,1]-1][:,1:2] + 1

        hscrf_target[last_span_indices[:,0], last_span_indices[:,1],2] = \
                hscrf_target[last_span_indices[:,0], last_span_indices[:,1]-1][:,3]

        hscrf_target[last_span_indices[:,0], last_span_indices[:,1],3] = \
                self.hscrf_layer.stop_id

        # span_mask Shape: (batch_size, num_spans), 1 or 0
        span_mask = (spans[:, :, 0] >= 0).squeeze(-1).float()

        gold_span_mask = torch.cat([gold_span_mask.float(), 
                                gold_span_mask.new_zeros(batch_size, 1).float()], dim=-1)
        gold_span_mask[last_span_indices[:,0], last_span_indices[:,1]] = 1.


        # SpanFields return -1 when they are used as padding. As we do
        # some comparisons based on span widths when we attend over the
        # span representations that we generate from these indices, we
        # need them to be <= 0. This is only relevant in edge cases where
        # the number of spans we consider after the pruning stage is >= the
        # total number of spans, because in this case, it is possible we might
        # consider a masked span.

        # spans Shape: (batch_size, num_spans, 2)
        spans = F.relu(spans.float()).long()
        num_spans = spans.size(1)

        if self.dropout:
            embedded_text_input = self.dropout(embedded_text_input)

        encoded_text = self.encoder(embedded_text_input, token_mask)

        if self.dropout:
            encoded_text = self.dropout(encoded_text)

        if self._feedforward is not None:
            encoded_text = self._feedforward(encoded_text)

        hscrf_neg_log_likelihood = self.hscrf_layer(
            encoded_text, 
            tokens,
            token_mask.sum(-1).squeeze(),
            hscrf_target,
            gold_span_mask
        )

        pred_results = self.hscrf_layer.get_scrf_decode(
            token_mask.sum(-1).squeeze()
        )
        self._span_f1_metric(
            pred_results, 
            [dic['gold_spans'] for dic in metadata],
            sentences=[x["words"] for x in metadata])
        output = {
            "mask": token_mask,
            "loss": hscrf_neg_log_likelihood,
            "results": pred_results
                 }
        
        if metadata is not None:
            output["words"] = [x["words"] for x in metadata]
        return output
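
Much of the end-token bookkeeping above rests on one indexing trick: the row-wise sum of a 0/1 mask points at the first padding slot, where the extra <end> position is written. In isolation:

import torch

mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])
last = mask.sum(-1, keepdim=True)                 # first padding slot per row
rows = torch.arange(mask.size(0)).unsqueeze(-1)
ext = torch.cat([mask, mask.new_zeros(mask.size(0), 1)], dim=1)
ext[rows[:, 0], last[:, 0]] = 1                   # mark the <end> position
# ext -> tensor([[1, 1, 1, 1, 0], [1, 1, 1, 0, 0]])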
Example #8
    def forward(
        self,  # type: ignore
        sentence: Dict[str, torch.LongTensor],
        worlds: List[List[NlvrLanguage]],
        actions: List[List[ProductionRule]],
        agenda: torch.LongTensor,
        identifier: List[str] = None,
        labels: torch.LongTensor = None,
        metadata: List[Dict[str, Any]] = None,
    ) -> Dict[str, torch.Tensor]:
        """
        Decoder logic for producing type constrained target sequences that maximize coverage of
        their respective agendas, and minimize a denotation based loss.
        """
        if self._dynamic_cost_rate is not None:
            # This could be added back pretty easily with an EpochCallback passed to the Trainer (it
            # just has to set the epoch number on the model, which could then be queried in here).
            logger.warning(
                "Dynamic cost rate functionality was removed in AllenNLP 1.0. If you want this, "
                "use version 0.9.  We will just use the static checklist cost weight."
            )
        batch_size = len(worlds)

        initial_rnn_state = self._get_initial_rnn_state(sentence)
        initial_score_list = [agenda.new_zeros(1, dtype=torch.float) for i in range(batch_size)]
        # TODO (pradeep): Assuming all worlds give the same set of valid actions.
        initial_grammar_state = [
            self._create_grammar_state(worlds[i][0], actions[i]) for i in range(batch_size)
        ]

        label_strings = self._get_label_strings(labels) if labels is not None else None
        # Each instance's agenda is of size (agenda_size, 1)
        # TODO(mattg): It looks like the agenda is only ever used on the CPU.  In that case, it's a
        # waste to copy it to the GPU and then back, and this should probably be a MetadataField.
        agenda_list = [agenda[i] for i in range(batch_size)]
        initial_checklist_states = []
        for instance_actions, instance_agenda in zip(actions, agenda_list):
            checklist_info = self._get_checklist_info(instance_agenda, instance_actions)
            checklist_target, terminal_actions, checklist_mask = checklist_info

            initial_checklist = checklist_target.new_zeros(checklist_target.size())
            initial_checklist_states.append(
                ChecklistStatelet(
                    terminal_actions=terminal_actions,
                    checklist_target=checklist_target,
                    checklist_mask=checklist_mask,
                    checklist=initial_checklist,
                )
            )
        initial_state = CoverageState(
            batch_indices=list(range(batch_size)),
            action_history=[[] for _ in range(batch_size)],
            score=initial_score_list,
            rnn_state=initial_rnn_state,
            grammar_state=initial_grammar_state,
            possible_actions=actions,
            extras=label_strings,
            checklist_state=initial_checklist_states,
        )
        if not self.training:
            initial_state.debug_info = [[] for _ in range(batch_size)]

        agenda_data = [agenda_[:, 0].cpu().data for agenda_ in agenda_list]
        outputs = self._decoder_trainer.decode(  # type: ignore
            initial_state, self._decoder_step, partial(self._get_state_cost, worlds)
        )
        if identifier is not None:
            outputs["identifier"] = identifier
        best_final_states = outputs["best_final_states"]
        best_action_sequences = {}
        for batch_index, states in best_final_states.items():
            best_action_sequences[batch_index] = [state.action_history[0] for state in states]
        batch_action_strings = self._get_action_strings(actions, best_action_sequences)
        batch_denotations = self._get_denotations(batch_action_strings, worlds)
        if labels is not None:
            # We're either training or validating.
            self._update_metrics(
                action_strings=batch_action_strings,
                worlds=worlds,
                label_strings=label_strings,
                possible_actions=actions,
                agenda_data=agenda_data,
            )
        else:
            # We're testing.
            if metadata is not None:
                outputs["sentence_tokens"] = [x["sentence_tokens"] for x in metadata]
            outputs["debug_info"] = []
            for i in range(batch_size):
                outputs["debug_info"].append(best_final_states[i][0].debug_info[0])  # type: ignore
            outputs["best_action_strings"] = batch_action_strings
            outputs["denotations"] = batch_denotations
            action_mapping = {}
            for batch_index, batch_actions in enumerate(actions):
                for action_index, action in enumerate(batch_actions):
                    action_mapping[(batch_index, action_index)] = action[0]
            outputs["action_mapping"] = action_mapping
        return outputs
Example #9
    def forward(self,  # type: ignore
                words: Dict[str, torch.LongTensor],
                pos_tags: torch.LongTensor,
                lemmas: torch.LongTensor,
                ner_tags: torch.LongTensor,
                metadata: List[Dict[str, Any]],
                supertags: torch.LongTensor = None,
                lexlabels: torch.LongTensor = None,
                head_tags: torch.LongTensor = None,
                head_indices: torch.LongTensor = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        words : Dict[str, torch.LongTensor], required
            The output of ``TextField.as_array()``, which should typically be passed directly to a
            ``TextFieldEmbedder``. This output is a dictionary mapping keys to ``TokenIndexer``
            tensors.  At its most basic, using a ``SingleIdTokenIndexer`` this is: ``{"tokens":
            Tensor(batch_size, sequence_length)}``. This dictionary will have the same keys as were used
            for the ``TokenIndexers`` when you created the ``TextField`` representing your
            sequence.  The dictionary is designed to be passed directly to a ``TextFieldEmbedder``,
            which knows how to combine different word representations into a single vector per
            token in your input.
        pos_tags : ``torch.LongTensor``, required
            The output of a ``SequenceLabelField`` containing POS tags.
            POS tags are required regardless of whether they are used in the model,
            because they are used to filter the evaluation metric to only consider
            heads of words which are not punctuation.
        metadata : List[Dict[str, Any]], optional (default=None)
            A dictionary of metadata for each batch element which has keys:
                words : ``List[str]``, required.
                    The tokens in the original sentence.
                pos : ``List[str]``, required.
                    The POS tags for each word in the dependency parse.
        head_tags : ``torch.LongTensor``, optional (default = None)
            A torch tensor representing the sequence of integer gold edge labels for the arcs
            in the dependency parse. Has shape ``(batch_size, sequence_length)``.
        head_indices : torch.LongTensor, optional (default = None)
            A torch tensor representing the sequence of integer indices denoting the parent of every
            word in the dependency parse. Has shape ``(batch_size, sequence_length)``.

        Returns
        -------
        An output dictionary consisting of:
        loss : ``torch.FloatTensor``, optional
            A scalar loss to be optimised.
        arc_loss : ``torch.FloatTensor``
            The loss contribution from the unlabeled arcs.
        edge_label_loss : ``torch.FloatTensor``
            The loss contribution from the edge labels.
        heads : ``torch.FloatTensor``
            The predicted head indices for each word. A tensor
            of shape (batch_size, sequence_length).
        edge_labels : ``torch.FloatTensor``
            The predicted head types for each arc. A tensor
            of shape (batch_size, sequence_length).
        mask : ``torch.LongTensor``
            A mask denoting the padded elements in the batch.
        """
        if 'formalism' not in metadata[0]:
            raise ConfigurationError("metadata is missing 'formalism' key. "
                                     "Please use the amconll dataset reader.")

        formalism_of_batch = metadata[0]['formalism']
        for entry in metadata:
            if entry['formalism'] != formalism_of_batch:
                raise ConfigurationError("Two formalisms in the same batch.")
        if formalism_of_batch not in self.tasks:
            raise ConfigurationError(f"Got formalism {formalism_of_batch} but I only have these tasks: {list(self.tasks.keys())}")

        if self.tok2vec:
            token_ids = words["tokens"]
            embedded_text_input = self.tok2vec.embed(self.vocab, token_ids) #shape (batch_size, seq len, encoder dim)
            concatenated_input = [embedded_text_input, self.text_field_embedder(words)]
        else:
            embedded_text_input = self.text_field_embedder(words)
            concatenated_input = [embedded_text_input]

        if pos_tags is not None and self._pos_tag_embedding is not None:
            concatenated_input.append(self._pos_tag_embedding(pos_tags))
        elif self._pos_tag_embedding is not None:
            raise ConfigurationError("Model uses a POS embedding, but no POS tags were passed.")

        if self._lemma_embedding is not None:
            concatenated_input.append(self._lemma_embedding(lemmas))
        if self._ne_embedding is not None:
            concatenated_input.append(self._ne_embedding(ner_tags))

        if len(concatenated_input) > 1:
            embedded_text_input = torch.cat(concatenated_input, -1)
        mask = get_text_field_mask(words)
        embedded_text_input = self._input_dropout(embedded_text_input)
        encoded_text_parsing, encoded_text_tagging = self.encoder(formalism_of_batch, embedded_text_input, mask) #potentially weight-sharing

        batch_size, seq_len, encoding_dim = encoded_text_parsing.size()
        head_sentinel = self._head_sentinel.expand(batch_size, 1, encoding_dim)
        # Concatenate the artificial root onto the sentence representation.
        encoded_text_parsing = torch.cat([head_sentinel, encoded_text_parsing], 1)

        if encoded_text_tagging is not None: #might be none when batch is of formalism without tagging (UD)
            batch_size, seq_len, encoding_dim = encoded_text_tagging.size()
            head_sentinel = self._head_sentinel.expand(batch_size, 1, encoding_dim)
            # Concatenate the artificial root onto the sentence representation.
            encoded_text_tagging = torch.cat([head_sentinel, encoded_text_tagging], 1)

        mask = torch.cat([mask.new_ones(batch_size, 1), mask], 1)
        if head_indices is not None:
            head_indices = torch.cat([head_indices.new_zeros(batch_size, 1), head_indices], 1)
        if head_tags is not None:
            head_tags = torch.cat([head_tags.new_zeros(batch_size, 1), head_tags], 1)

        return self.tasks[formalism_of_batch](encoded_text_parsing, encoded_text_tagging, mask, pos_tags, metadata, supertags, lexlabels, head_tags, head_indices)
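
The ROOT-sentinel prepending that recurs throughout these parsers, as a standalone sketch (the real _head_sentinel is a learned parameter; zeros stand in for it here):

import torch

encoded = torch.randn(2, 5, 8)                    # (batch, seq_len, dim)
sentinel = torch.zeros(1, 1, 8)                   # stand-in for the learned root
encoded = torch.cat([sentinel.expand(2, 1, 8), encoded], 1)
mask = torch.ones(2, 5, dtype=torch.bool)
mask = torch.cat([mask.new_ones(2, 1), mask], 1)
# encoded: (2, 6, 8); mask: (2, 6); position 0 is the artificial ROOT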
Example #10
    def forward(self,  # type: ignore
                words: Dict[str, torch.LongTensor],
                weight: torch.Tensor,
                metadata: List[Dict[str, Any]],
                head_tags: torch.LongTensor = None,
                head_indices: torch.LongTensor = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        embedded_text_input = self.text_field_embedder(words)
        embedded_text_input = self._input_dropout(embedded_text_input)

        mask = get_text_field_mask(words)
        encoded_text = self.encoder(embedded_text_input, mask)

        batch_size, _, encoding_dim = encoded_text.size()

        head_sentinel = self._head_sentinel.expand(batch_size, 1, encoding_dim)
        # Concatenate the head sentinel onto the sentence representation.
        encoded_text = torch.cat([head_sentinel, encoded_text], 1)
        mask = torch.cat([mask.new_ones(batch_size, 1), mask], 1)
        if head_indices is not None:
            head_indices = torch.cat([head_indices.new_zeros(batch_size, 1), head_indices], 1)
        if head_tags is not None:
            head_tags = torch.cat([head_tags.new_zeros(batch_size, 1), head_tags], 1)

        # shape (batch_size, sequence_length, arc_representation_dim)
        head_arc_representation = self.head_arc_feedforward(encoded_text)
        child_arc_representation = self.child_arc_feedforward(encoded_text)

        # shape (batch_size, sequence_length, tag_representation_dim)
        head_tag_representation = self.head_tag_feedforward(encoded_text)
        child_tag_representation = self.child_tag_feedforward(encoded_text)
        # shape (batch_size, sequence_length, sequence_length)
        attended_arcs = self.arc_attention(head_arc_representation,
                                           child_arc_representation)

        if head_indices is not None and head_tags is not None:
            loss, normalised_arc_logits, normalised_head_tag_logits = \
                self._construct_loss(head_tag_representation=head_tag_representation,
                                     child_tag_representation=child_tag_representation,
                                     attended_arcs=attended_arcs,
                                     head_indices=head_indices,
                                     head_tags=head_tags,
                                     mask=mask,
                                     weight=weight)

            normalised_arc_logits = _apply_head_mask(normalised_arc_logits, mask)
            tag_mask = self._get_unknown_tag_mask(mask, head_tags)
            self._attachment_scores(normalised_arc_logits[:, 1:], head_indices[:, 1:], mask[:, 1:])
            self._tagging_accuracy(normalised_head_tag_logits[:, 1:], head_tags[:, 1:], tag_mask[:, 1:])
            predicted_heads, predicted_head_tags = None, None
        else:
            attended_arcs = _apply_head_mask(attended_arcs, mask)
            # Compute the heads greedily.
            # shape (batch_size, sequence_length)
            _, predicted_heads = attended_arcs.max(dim=2)

            # Given the greedily predicted heads, decode their dependency tags.
            # shape (batch_size, sequence_length, num_head_tags)
            head_tag_logits = self._get_head_tags(head_tag_representation,
                                                  child_tag_representation,
                                                  predicted_heads)
            _, predicted_head_tags = head_tag_logits.max(dim=2)

            loss, normalised_arc_logits, normalised_head_tag_logits = \
                self._construct_loss(head_tag_representation=head_tag_representation,
                                     child_tag_representation=child_tag_representation,
                                     attended_arcs=attended_arcs,
                                     head_indices=predicted_heads.long(),
                                     head_tags=predicted_head_tags.long(),
                                     mask=mask,
                                     weight=weight)
            normalised_arc_logits = _apply_head_mask(normalised_arc_logits, mask)

        output_dict = {
            "heads": normalised_arc_logits,
            "head_tags": normalised_head_tag_logits,
            "loss": loss,
            "mask": mask,
            "words": [meta["words"] for meta in metadata],
        }
        if predicted_heads is not None and predicted_head_tags is not None:
            output_dict['predicted_heads'] = predicted_heads[:, 1:]
            output_dict['predicted_head_tags'] = predicted_head_tags[:, 1:]
        return output_dict
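
When gold heads are absent, decoding above reduces to a row-wise argmax over the head dimension (sketch with random scores):

import torch

arcs = torch.randn(2, 5, 5)    # (batch, child i, head j) arc scores
_, heads = arcs.max(dim=2)     # greedy head index per token
# heads: (2, 5); position 0 is the ROOT sentinel and is ignored downstream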
Example #11
    def forward(
            self,  # type: ignore
            words: Dict[str, torch.LongTensor],
            pos_tags: torch.LongTensor,
            metadata: List[Dict[str, Any]],
            head_tags: torch.LongTensor = None,
            head_indices: torch.LongTensor = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        words : Dict[str, torch.LongTensor], required
            The output of ``TextField.as_array()``, which should typically be passed directly to a
            ``TextFieldEmbedder``. This output is a dictionary mapping keys to ``TokenIndexer``
            tensors.  At its most basic, using a ``SingleIdTokenIndexer`` this is: ``{"tokens":
            Tensor(batch_size, sequence_length)}``. This dictionary will have the same keys as were used
            for the ``TokenIndexers`` when you created the ``TextField`` representing your
            sequence.  The dictionary is designed to be passed directly to a ``TextFieldEmbedder``,
            which knows how to combine different word representations into a single vector per
            token in your input.
        pos_tags : ``torch.LongTensor``, required
            The output of a ``SequenceLabelField`` containing POS tags.
            POS tags are required regardless of whether they are used in the model,
            because they are used to filter the evaluation metric to only consider
            heads of words which are not punctuation.
        metadata : List[Dict[str, Any]], optional (default=None)
            A dictionary of metadata for each batch element which has keys:
                words : ``List[str]``, required.
                    The tokens in the original sentence.
                pos : ``List[str]``, required.
                    The POS tags for each word in the dependency parse.
        head_tags : ``torch.LongTensor``, optional (default = None)
            A torch tensor representing the sequence of integer gold edge labels for the arcs
            in the dependency parse. Has shape ``(batch_size, sequence_length)``.
        head_indices : torch.LongTensor, optional (default = None)
            A torch tensor representing the sequence of integer indices denoting the parent of every
            word in the dependency parse. Has shape ``(batch_size, sequence_length)``.

        Returns
        -------
        An output dictionary consisting of:
        loss : ``torch.FloatTensor``, optional
            A scalar loss to be optimised.
        arc_loss : ``torch.FloatTensor``
            The loss contribution from the unlabeled arcs.
        edge_label_loss : ``torch.FloatTensor``
            The loss contribution from the edge labels.
        heads : ``torch.FloatTensor``
            The predicted head indices for each word. A tensor
            of shape (batch_size, sequence_length).
        edge_labels : ``torch.FloatTensor``
            The predicted head types for each arc. A tensor
            of shape (batch_size, sequence_length).
        mask : ``torch.LongTensor``
            A mask denoting the padded elements in the batch.
        """
        embedded_text_input = self.text_field_embedder(words)
        if pos_tags is not None and self._pos_tag_embedding is not None:
            embedded_pos_tags = self._pos_tag_embedding(pos_tags)
            embedded_text_input = torch.cat(
                [embedded_text_input, embedded_pos_tags], -1)
        elif self._pos_tag_embedding is not None:
            raise ConfigurationError(
                "Model uses a POS embedding, but no POS tags were passed.")

        mask = get_text_field_mask(words)
        embedded_text_input = self._input_dropout(embedded_text_input)
        encoded_text = self.encoder(embedded_text_input, mask)

        batch_size, _, encoding_dim = encoded_text.size()

        head_sentinel = self._head_sentinel.expand(batch_size, 1, encoding_dim)
        # Concatenate the head sentinel onto the sentence representation.
        encoded_text = torch.cat([head_sentinel, encoded_text], 1)
        mask = torch.cat([mask.new_ones(batch_size, 1), mask], 1)
        if head_indices is not None:
            head_indices = torch.cat(
                [head_indices.new_zeros(batch_size, 1), head_indices], 1)
        if head_tags is not None:
            head_tags = torch.cat(
                [head_tags.new_zeros(batch_size, 1), head_tags], 1)
        encoded_text = self._dropout(encoded_text)

        edge_existence_scores = self.edge_model.edge_existence(
            encoded_text, mask)

        if self.training or not self.use_mst_decoding_for_validation:
            predicted_heads = self._greedy_decode_arcs(edge_existence_scores,
                                                       mask)
            edge_label_logits = self.edge_model.label_scores(
                encoded_text, predicted_heads)
            predicted_edge_labels = self._greedy_decode_edge_labels(
                edge_label_logits)
        else:
            # Find the best tree with CLE (Chu-Liu/Edmonds)
            predicted_heads = cle_decode(edge_existence_scores,
                                         mask.data.sum(dim=1).long())
            # With info about the tree structure, get edge label scores
            edge_label_logits = self.edge_model.label_scores(
                encoded_text, predicted_heads)
            # Predict edge labels
            predicted_edge_labels = self._greedy_decode_edge_labels(
                edge_label_logits)

        output_dict = {
            "heads": predicted_heads,
            "edge_labels": predicted_edge_labels,
            "mask": mask,
            "words": [meta["words"] for meta in metadata],
            "pos": [meta["pos"] for meta in metadata],
            "position_in_corpus": [meta["position_in_corpus"] for meta in metadata],
        }

        if head_indices is not None and head_tags is not None:
            gold_edge_label_logits = self.edge_model.label_scores(
                encoded_text, head_indices)
            edge_label_loss = self.loss_function.label_loss(
                gold_edge_label_logits, mask, head_tags)

            arc_nll = self.loss_function.edge_existence_loss(
                edge_existence_scores, head_indices, mask)

            loss = arc_nll + edge_label_loss

            evaluation_mask = self._get_mask_for_eval(mask[:, 1:], pos_tags)
            # We calculate attachment scores for the whole sentence
            # but excluding the symbolic ROOT token at the start,
            # which is why we start from the second element in the sequence.
            self._attachment_scores(predicted_heads[:, 1:],
                                    predicted_edge_labels[:, 1:],
                                    head_indices[:, 1:], head_tags[:, 1:],
                                    evaluation_mask)
            output_dict["arc_loss"] = arc_nll
            output_dict["edge_label_loss"] = edge_label_loss
            output_dict["loss"] = loss

        if self.pass_over_data_just_started:
            # here we could decide if we want to start collecting info for the decoder.
            pass
        self.pass_over_data_just_started = False
        return output_dict
Example #12
    def get_classification_embedding(self, token_ids: torch.LongTensor, mask: torch.LongTensor) -> torch.Tensor:
        embeddings = self.transformer_model(input_ids=token_ids, attention_mask=mask)[0]
        batch_size = token_ids.size(0)
        cls_offsets = token_ids.new_zeros([batch_size]).unsqueeze(1)
        classification_embedding = get_select_embedding(embeddings, cls_offsets).squeeze(1)
        return classification_embedding
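
get_select_embedding is project-specific; a plausible stand-in that gathers one embedding per row at the given offsets (here offset 0, the [CLS] position) might look like this:

import torch

embeddings = torch.randn(2, 6, 8)                     # (batch, seq_len, dim)
offsets = torch.zeros(2, 1, dtype=torch.long)         # [CLS] sits at position 0
idx = offsets.unsqueeze(-1).expand(-1, -1, embeddings.size(-1))
cls_embedding = embeddings.gather(1, idx).squeeze(1)  # (batch, dim)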
Example #13
    def _parse(self,
               embedded_text_input: torch.Tensor,
               mask: torch.LongTensor,
               head_tags: torch.LongTensor = None,
               head_indices: torch.LongTensor = None,
               grammar_values: torch.LongTensor = None,
               lemma_indices: torch.LongTensor = None):

        embedded_text_input = self._input_dropout(embedded_text_input)
        encoded_text = self.encoder(embedded_text_input, mask)

        grammar_value_logits = self._gram_val_output(encoded_text)
        predicted_gram_vals = grammar_value_logits.argmax(-1)

        # Feed the predicted grammatical label into the lemmatizer input -- EXPERIMENTAL
        #l_ext_input = encoded_text
        l_ext_input = torch.cat([encoded_text, grammar_value_logits], -1)
        lemma_logits = self._lemma_output(l_ext_input)
        predicted_lemmas = lemma_logits.argmax(-1)

        # Get the top-N most probable lemmatization variants and their probability estimates
        lemma_probs = torch.nn.functional.softmax(lemma_logits, -1)
        top_lemmas_indices = (-lemma_logits).argsort(-1)[:, :, :self.TopNCnt]
        #top_lemmas_indices = (-lemma_probs).argsort(-1)[:,:,:self.TopNCnt]
        top_lemmas_prob = torch.gather(lemma_probs, -1, top_lemmas_indices)
        #top_lemmas_prob = torch.gather(lemma_logits, -1, top_lemmas_indices)

        # Same for grammemes
        gramm_probs = torch.nn.functional.softmax(grammar_value_logits, -1)
        top_gramms_indices = (
            -grammar_value_logits).argsort(-1)[:, :, :self.TopNCnt]
        top_gramms_prob = torch.gather(gramm_probs, -1, top_gramms_indices)

        batch_size, _, encoding_dim = encoded_text.size()

        head_sentinel = self._head_sentinel.expand(batch_size, 1, encoding_dim)
        # Concatenate the head sentinel onto the sentence representation.
        encoded_text = torch.cat([head_sentinel, encoded_text], 1)
        token_mask = mask.float()
        mask = torch.cat([mask.new_ones(batch_size, 1), mask], 1)
        if head_indices is not None:
            head_indices = torch.cat(
                [head_indices.new_zeros(batch_size, 1), head_indices], 1)
        if head_tags is not None:
            head_tags = torch.cat(
                [head_tags.new_zeros(batch_size, 1), head_tags], 1)
        float_mask = mask.float()
        encoded_text = self._dropout(encoded_text)

        # shape (batch_size, sequence_length, arc_representation_dim)
        head_arc_representation = self._dropout(
            self.head_arc_feedforward(encoded_text))
        child_arc_representation = self._dropout(
            self.child_arc_feedforward(encoded_text))

        # shape (batch_size, sequence_length, tag_representation_dim)
        head_tag_representation = self._dropout(
            self.head_tag_feedforward(encoded_text))
        child_tag_representation = self._dropout(
            self.child_tag_feedforward(encoded_text))
        # shape (batch_size, sequence_length, sequence_length)
        attended_arcs = self.arc_attention(head_arc_representation,
                                           child_arc_representation)

        minus_inf = -1e8
        minus_mask = (1 - float_mask) * minus_inf
        attended_arcs = attended_arcs + minus_mask.unsqueeze(
            2) + minus_mask.unsqueeze(1)

        if self.training or not self.use_mst_decoding_for_validation:
            predicted_heads, predicted_head_tags = self._greedy_decode(
                head_tag_representation, child_tag_representation,
                attended_arcs, mask)
        else:
            synt_prediction, benrg = self._mst_decode(
                head_tag_representation, child_tag_representation,
                attended_arcs, mask)
            predicted_heads, predicted_head_tags = synt_prediction

        # Get the top-N most probable LOCAL (not MST) parse variants and their probability estimates
        benrgf = torch.flatten(benrg, start_dim=1, end_dim=2).permute(
            0, 2, 1)  # fuses the syntactic relation type with the parent index
        top_deprels_indices = (-benrgf).argsort(
            -1)[:, :, :self.TopNCnt]  # pick the best combinations
        top_deprels_prob = torch.gather(benrgf, -1, top_deprels_indices)
        seqlen = benrg.shape[2]
        top_heads = torch.fmod(top_deprels_indices, seqlen)
        top_deprels_indices = torch.div(top_deprels_indices,
                                        seqlen)  # torch.floor does not work here

        if head_indices is not None and head_tags is not None:

            arc_nll, tag_nll = self._construct_loss(
                head_tag_representation=head_tag_representation,
                child_tag_representation=child_tag_representation,
                attended_arcs=attended_arcs,
                head_indices=head_indices,
                head_tags=head_tags,
                mask=mask)
        else:
            arc_nll, tag_nll = self._construct_loss(
                head_tag_representation=head_tag_representation,
                child_tag_representation=child_tag_representation,
                attended_arcs=attended_arcs,
                head_indices=predicted_heads.long(),
                head_tags=predicted_head_tags.long(),
                mask=mask)

        grammar_nll = torch.tensor(0.)
        if grammar_values is not None:
            grammar_nll = self._update_multiclass_prediction_metrics(
                logits=grammar_value_logits,
                targets=grammar_values,
                mask=token_mask,
                accuracy_metric=self._gram_val_prediction_accuracy)

        lemma_nll = torch.tensor(0.)
        if lemma_indices is not None:
            lemma_nll = self._update_multiclass_prediction_metrics(
                logits=lemma_logits,
                targets=lemma_indices,
                mask=token_mask,
                accuracy_metric=self._lemma_prediction_accuracy,
                masked_index=self.lemmatize_helper.UNKNOWN_RULE_INDEX)

        output_dict = {
            "heads": predicted_heads,
            "head_tags": predicted_head_tags,
            "gram_vals": predicted_gram_vals,
            "lemmas": predicted_lemmas,
            "mask": mask,
            "arc_nll": arc_nll,
            "tag_nll": tag_nll,
            "grammar_nll": grammar_nll,
            "lemma_nll": lemma_nll,
            "top_lemmas": top_lemmas_indices,
            "top_lemmas_prob": top_lemmas_prob,
            "top_gramms": top_gramms_indices,
            "top_gramms_prob": top_gramms_prob,
            "top_heads": top_heads,
            "top_deprels": top_deprels_indices,
            "top_deprels_prob": top_deprels_prob,
        }

        return output_dict
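
The additive masking used above (and in the next examples) pushes scores of padded rows and columns to about -1e8 so they vanish after a softmax. In isolation:

import torch

mask = torch.tensor([[1., 1., 0.]])               # last position is padding
minus_mask = (1 - mask) * -1e8
scores = torch.zeros(1, 3, 3)
scores = scores + minus_mask.unsqueeze(2) + minus_mask.unsqueeze(1)
# row 2 and column 2 of scores now hold ~-1e8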
Example #14
    def _parse(
        self,
        embedded_text_input: torch.Tensor,
        mask: torch.BoolTensor,
        head_tags: torch.LongTensor = None,
        head_indices: torch.LongTensor = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:

        embedded_text_input = self._input_dropout(embedded_text_input)
        encoded_text = self.encoder(embedded_text_input, mask)

        batch_size, sequence_length, encoding_dim = encoded_text.size()

        head_sentinel = self._head_sentinel.expand(batch_size, 1, encoding_dim)
        # Concatenate the head sentinel onto the sentence representation.
        encoded_text = torch.cat([head_sentinel, encoded_text], 1)
        mask = torch.cat([mask.new_ones(batch_size, 1), mask], 1)
        if head_indices is not None:
            head_indices = torch.cat([head_indices.new_zeros(batch_size, 1), head_indices], 1)
        if head_tags is not None:
            head_tags = torch.cat([head_tags.new_zeros(batch_size, 1), head_tags], 1)
        encoded_text = self._dropout(encoded_text)

        # shape (batch_size, sequence_length, arc_representation_dim)
        head_arc_representation = self.head_arc_feedforward(encoded_text)
        child_arc_representation = self.child_arc_feedforward(encoded_text)

        # shape (batch_size, sequence_length, tag_representation_dim)
        head_tag_representation = self.head_tag_feedforward(encoded_text)
        child_tag_representation = self.child_tag_feedforward(encoded_text)

        # calculate dimensions again as sequence_length is now + 1 from adding the head_sentinel
        batch_size, sequence_length, arc_dim = head_arc_representation.size()
        
        # now repeat the token representations to form a matrix:
        # shape (batch_size, sequence_length, sequence_length, arc_representation_dim)
        heads = head_arc_representation.repeat(1, sequence_length, 1).reshape(batch_size, sequence_length, sequence_length, arc_dim) # heads in one direction
        deps = child_arc_representation.repeat(1, sequence_length, 1).reshape(batch_size, sequence_length, sequence_length, arc_dim).transpose(1, 2) # deps in the other direction  
        
        # shape (batch_size, sequence_length, sequence_length, arc_representation_dim)
        combined_arcs = self.activation(heads + deps)

        # shape (batch_size, sequence_length, sequence_length)
        attended_arcs = self.arc_out_layer(combined_arcs).squeeze(3)
        
        minus_inf = -1e8
        minus_mask = ~mask * minus_inf
        attended_arcs = attended_arcs + minus_mask.unsqueeze(2) + minus_mask.unsqueeze(1)

        if self.training or not self.use_mst_decoding_for_validation:
            predicted_heads, predicted_head_tags = self._greedy_decode(
                head_tag_representation, child_tag_representation, attended_arcs, mask
            )
        else:
            predicted_heads, predicted_head_tags = self._mst_decode(
                head_tag_representation, child_tag_representation, attended_arcs, mask
            )
        if head_indices is not None and head_tags is not None:

            arc_nll, tag_nll = self._construct_loss(
                head_tag_representation=head_tag_representation,
                child_tag_representation=child_tag_representation,
                attended_arcs=attended_arcs,
                head_indices=head_indices,
                head_tags=head_tags,
                mask=mask,
            )
        else:
            arc_nll, tag_nll = self._construct_loss(
                head_tag_representation=head_tag_representation,
                child_tag_representation=child_tag_representation,
                attended_arcs=attended_arcs,
                head_indices=predicted_heads.long(),
                head_tags=predicted_head_tags.long(),
                mask=mask,
            )

        return predicted_heads, predicted_head_tags, mask, arc_nll, tag_nll
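
The repeat/reshape/transpose above materializes every head-dependent pair; the same sum can be written with broadcasting (a sketch, not a drop-in for the class):

import torch

h = torch.randn(2, 4, 3)   # head_arc_representation: (batch, seq, arc_dim)
d = torch.randn(2, 4, 3)   # child_arc_representation
pairwise = h.unsqueeze(1) + d.unsqueeze(2)   # [b, i, j] = h[b, j] + d[b, i]
# identical to the heads + deps construction: (batch, seq, seq, arc_dim)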
Example #15
    def forward(
        self,  # type: ignore
        sentence: Dict[str, torch.LongTensor],
        worlds: List[List[NlvrLanguage]],
        actions: List[List[ProductionRule]],
        agenda: torch.LongTensor,
        identifier: List[str] = None,
        labels: torch.LongTensor = None,
        epoch_num: List[int] = None,
        metadata: List[Dict[str, Any]] = None,
    ) -> Dict[str, torch.Tensor]:
        """
        Decoder logic for producing type constrained target sequences that maximize coverage of
        their respective agendas, and minimize a denotation based loss.
        """
        # We look at the epoch number and adjust the checklist cost weight if needed here.
        instance_epoch_num = epoch_num[0] if epoch_num is not None else None
        if self._dynamic_cost_rate is not None:
            if self.training and instance_epoch_num is None:
                raise RuntimeError(
                    "If you want a dynamic cost weight, use the "
                    "BucketIterator with track_epoch=True.")
            if instance_epoch_num != self._last_epoch_in_forward:
                if instance_epoch_num >= self._dynamic_cost_wait_epochs:
                    decrement = self._checklist_cost_weight * self._dynamic_cost_rate
                    self._checklist_cost_weight -= decrement
                    logger.info("Checklist cost weight is now %f",
                                self._checklist_cost_weight)
                self._last_epoch_in_forward = instance_epoch_num
        batch_size = len(worlds)

        initial_rnn_state = self._get_initial_rnn_state(sentence)
        initial_score_list = [
            agenda.new_zeros(1, dtype=torch.float) for i in range(batch_size)
        ]
        # TODO (pradeep): Assuming all worlds give the same set of valid actions.
        initial_grammar_state = [
            self._create_grammar_state(worlds[i][0], actions[i])
            for i in range(batch_size)
        ]

        label_strings = self._get_label_strings(
            labels) if labels is not None else None
        # Each instance's agenda is of size (agenda_size, 1)
        # TODO(mattg): It looks like the agenda is only ever used on the CPU.  In that case, it's a
        # waste to copy it to the GPU and then back, and this should probably be a MetadataField.
        agenda_list = [agenda[i] for i in range(batch_size)]
        initial_checklist_states = []
        for instance_actions, instance_agenda in zip(actions, agenda_list):
            checklist_info = self._get_checklist_info(instance_agenda,
                                                      instance_actions)
            checklist_target, terminal_actions, checklist_mask = checklist_info

            initial_checklist = checklist_target.new_zeros(
                checklist_target.size())
            initial_checklist_states.append(
                ChecklistStatelet(
                    terminal_actions=terminal_actions,
                    checklist_target=checklist_target,
                    checklist_mask=checklist_mask,
                    checklist=initial_checklist,
                ))
        initial_state = CoverageState(
            batch_indices=list(range(batch_size)),
            action_history=[[] for _ in range(batch_size)],
            score=initial_score_list,
            rnn_state=initial_rnn_state,
            grammar_state=initial_grammar_state,
            possible_actions=actions,
            extras=label_strings,
            checklist_state=initial_checklist_states,
        )
        if not self.training:
            initial_state.debug_info = [[] for _ in range(batch_size)]

        agenda_data = [agenda_[:, 0].cpu().data for agenda_ in agenda_list]
        outputs = self._decoder_trainer.decode(  # type: ignore
            initial_state, self._decoder_step,
            partial(self._get_state_cost, worlds))
        if identifier is not None:
            outputs["identifier"] = identifier
        best_final_states = outputs["best_final_states"]
        best_action_sequences = {}
        for batch_index, states in best_final_states.items():
            best_action_sequences[batch_index] = [
                state.action_history[0] for state in states
            ]
        batch_action_strings = self._get_action_strings(
            actions, best_action_sequences)
        batch_denotations = self._get_denotations(batch_action_strings, worlds)
        if labels is not None:
            # We're either training or validating.
            self._update_metrics(
                action_strings=batch_action_strings,
                worlds=worlds,
                label_strings=label_strings,
                possible_actions=actions,
                agenda_data=agenda_data,
            )
        else:
            # We're testing.
            if metadata is not None:
                outputs["sentence_tokens"] = [
                    x["sentence_tokens"] for x in metadata
                ]
            outputs["debug_info"] = []
            for i in range(batch_size):
                outputs["debug_info"].append(
                    best_final_states[i][0].debug_info[0])  # type: ignore
            outputs["best_action_strings"] = batch_action_strings
            outputs["denotations"] = batch_denotations
            action_mapping = {}
            for batch_index, batch_actions in enumerate(actions):
                for action_index, action in enumerate(batch_actions):
                    action_mapping[(batch_index, action_index)] = action[0]
            outputs["action_mapping"] = action_mapping
        return outputs
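
The dynamic cost schedule above shrinks the checklist weight once per epoch by weight * rate, i.e. exponential decay. Numerically, with hypothetical values:

w, rate = 0.8, 0.1
for _ in range(3):
    w -= w * rate          # same as w *= (1 - rate)
# 0.8 -> 0.72 -> 0.648 -> 0.5832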
Example #16
    def forward(self,  # type: ignore
                # words: Dict[str, torch.LongTensor],
                encoded_text: torch.FloatTensor,
                mask: torch.LongTensor,
                pos_logits: torch.LongTensor = None,  # predicted
                head_tags: torch.LongTensor = None,
                head_indices: torch.LongTensor = None,
                metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:

        batch_size, _, _ = encoded_text.size()

        pos_tags = None
        if pos_logits is not None and self.pos_tag_embedding is not None:
            # Embed the predicted POS tags and concatenate the embeddings to the input
            num_pos_classes = pos_logits.size(-1)
            pos_logits = pos_logits.view(-1, num_pos_classes)
            _, pos_tags = pos_logits.max(-1)

            pos_embed_size = self.pos_tag_embedding.get_output_dim()
            embedded_pos_tags = self.dropout(self.pos_tag_embedding(pos_tags))
            embedded_pos_tags = embedded_pos_tags.view(batch_size, -1, pos_embed_size)
            encoded_text = torch.cat([encoded_text, embedded_pos_tags], -1)

        encoded_text = self.encoder(encoded_text, mask)

        batch_size, _, encoding_dim = encoded_text.size()

        head_sentinel = self._head_sentinel.expand(batch_size, 1, encoding_dim)
        # Concatenate the head sentinel onto the sentence representation.
        encoded_text = torch.cat([head_sentinel, encoded_text], 1)
        mask = torch.cat([mask.new_ones(batch_size, 1), mask], 1)
        if head_indices is not None:
            head_indices = torch.cat([head_indices.new_zeros(batch_size, 1), head_indices], 1)
        if head_tags is not None:
            head_tags = torch.cat([head_tags.new_zeros(batch_size, 1), head_tags], 1)
        float_mask = mask.float()
        encoded_text = self._dropout(encoded_text)

        # shape (batch_size, sequence_length, arc_representation_dim)
        head_arc_representation = self._dropout(self.head_arc_feedforward(encoded_text))
        child_arc_representation = self._dropout(self.child_arc_feedforward(encoded_text))

        # shape (batch_size, sequence_length, tag_representation_dim)
        head_tag_representation = self._dropout(self.head_tag_feedforward(encoded_text))
        child_tag_representation = self._dropout(self.child_tag_feedforward(encoded_text))
        # shape (batch_size, sequence_length, sequence_length)
        attended_arcs = self.arc_attention(head_arc_representation,
                                           child_arc_representation)

        minus_inf = -1e8
        minus_mask = (1 - float_mask) * minus_inf
        attended_arcs = attended_arcs + minus_mask.unsqueeze(2) + minus_mask.unsqueeze(1)

        if self.training or not self.use_mst_decoding_for_validation:
            predicted_heads, predicted_head_tags = self._greedy_decode(head_tag_representation,
                                                                       child_tag_representation,
                                                                       attended_arcs,
                                                                       mask)
        else:
            predicted_heads, predicted_head_tags = self._mst_decode(head_tag_representation,
                                                                    child_tag_representation,
                                                                    attended_arcs,
                                                                    mask)
        if head_indices is not None and head_tags is not None:

            arc_nll, tag_nll = self._construct_loss(head_tag_representation=head_tag_representation,
                                                    child_tag_representation=child_tag_representation,
                                                    attended_arcs=attended_arcs,
                                                    head_indices=head_indices,
                                                    head_tags=head_tags,
                                                    mask=mask)
            loss = arc_nll + tag_nll

            evaluation_mask = self._get_mask_for_eval(mask[:, 1:], pos_tags)
            # We calculate attachment scores for the whole sentence
            # but excluding the symbolic ROOT token at the start,
            # which is why we start from the second element in the sequence.
            self._attachment_scores(predicted_heads[:, 1:],
                                    predicted_head_tags[:, 1:],
                                    head_indices[:, 1:],
                                    head_tags[:, 1:],
                                    evaluation_mask)
        else:
            arc_nll, tag_nll = self._construct_loss(head_tag_representation=head_tag_representation,
                                                    child_tag_representation=child_tag_representation,
                                                    attended_arcs=attended_arcs,
                                                    head_indices=predicted_heads.long(),
                                                    head_tags=predicted_head_tags.long(),
                                                    mask=mask)
            loss = arc_nll + tag_nll

        output_dict = {
            "heads": predicted_heads,
            "head_tags": predicted_head_tags,
            "arc_loss": arc_nll,
            "tag_loss": tag_nll,
            "loss": loss,
            "mask": mask,
            "words": [meta["words"] for meta in metadata],
            # "pos": [meta["pos"] for meta in metadata]
        }

        return output_dict
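A quick aside on the `minus_inf` masking used in the parser above: adding a large negative number to the rows and columns of padded positions drives their softmax probability to (near) zero. A minimal, self-contained sketch of the trick on toy tensors (not the model's actual inputs):

import torch

batch_size, seq_len = 1, 4
attended_arcs = torch.randn(batch_size, seq_len, seq_len)
mask = torch.tensor([[1, 1, 1, 0]])  # last position is padding

minus_inf = -1e8
minus_mask = (1 - mask.float()) * minus_inf
# Broadcasting penalises every row and every column of a padded position,
# so no real token can pick it as a head and it cannot pick real heads.
masked = attended_arcs + minus_mask.unsqueeze(2) + minus_mask.unsqueeze(1)
probs = torch.softmax(masked, dim=2)
print(probs[0, :, 3])  # ~0 everywhere: nothing attaches to the padded position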
Beispiel #17
0
    def forward(self,  # type: ignore
                words: Dict[str, torch.LongTensor],
                pos_tags: torch.LongTensor,
                metadata: List[Dict[str, Any]],
                head_tags: torch.LongTensor = None,
                head_indices: torch.LongTensor = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        words : Dict[str, torch.LongTensor], required
            The output of ``TextField.as_array()``, which should typically be passed directly to a
            ``TextFieldEmbedder``. This output is a dictionary mapping keys to ``TokenIndexer``
            tensors.  At its most basic, using a ``SingleIdTokenIndexer`` this is: ``{"tokens":
            Tensor(batch_size, sequence_length)}``. This dictionary will have the same keys as were used
            for the ``TokenIndexers`` when you created the ``TextField`` representing your
            sequence.  The dictionary is designed to be passed directly to a ``TextFieldEmbedder``,
            which knows how to combine different word representations into a single vector per
            token in your input.
        pos_tags : ``torch.LongTensor``, required.
            The output of a ``SequenceLabelField`` containing POS tags.
            POS tags are required regardless of whether they are used in the model,
            because they are used to filter the evaluation metric to only consider
            heads of words which are not punctuation.
        head_tags : torch.LongTensor, optional (default = None)
            A torch tensor representing the sequence of integer gold class labels for the arcs
            in the dependency parse. Has shape ``(batch_size, sequence_length)``.
        head_indices : torch.LongTensor, optional (default = None)
            A torch tensor representing the sequence of integer indices denoting the parent of every
            word in the dependency parse. Has shape ``(batch_size, sequence_length)``.

        Returns
        -------
        An output dictionary consisting of:
        loss : ``torch.FloatTensor``, optional
            A scalar loss to be optimised.
        arc_loss : ``torch.FloatTensor``
            The loss contribution from the unlabeled arcs.
        tag_loss : ``torch.FloatTensor``
            The loss contribution from predicting the dependency
            tags for the gold arcs.
        heads : ``torch.FloatTensor``
            The predicted head indices for each word. A tensor
            of shape (batch_size, sequence_length).
        head_tags : ``torch.FloatTensor``
            The predicted head tags for each arc. A tensor
            of shape (batch_size, sequence_length).
        mask : ``torch.LongTensor``
            A mask denoting the padded elements in the batch.
        """
        embedded_text_input = self.text_field_embedder(words)
        if pos_tags is not None and self._pos_tag_embedding is not None:
            embedded_pos_tags = self._pos_tag_embedding(pos_tags)
            embedded_text_input = torch.cat([embedded_text_input, embedded_pos_tags], -1)
        elif self._pos_tag_embedding is not None:
            raise ConfigurationError("Model uses a POS embedding, but no POS tags were passed.")

        mask = get_text_field_mask(words)
        embedded_text_input = self._input_dropout(embedded_text_input)
        encoded_text_orig = self.encoder(embedded_text_input, mask)

        encoder_final_state = get_final_encoder_states(encoded_text_orig, mask)

        batch_size, _, encoding_dim = encoded_text_orig.size()

        head_sentinel = self._head_sentinel.expand(batch_size, 1, encoding_dim)
        # Concatenate the head sentinel onto the sentence representation.
        encoded_text = torch.cat([head_sentinel, encoded_text_orig], 1)
        mask = torch.cat([mask.new_ones(batch_size, 1), mask], 1)
        if head_indices is not None:
            head_indices = torch.cat([head_indices.new_zeros(batch_size, 1), head_indices], 1)
        if head_tags is not None:
            head_tags = torch.cat([head_tags.new_zeros(batch_size, 1), head_tags], 1)
        float_mask = mask.float()
        encoded_text = self._dropout(encoded_text)

        # shape (batch_size, sequence_length, arc_representation_dim)
        head_arc_representation = self._dropout(self.head_arc_feedforward(encoded_text))
        child_arc_representation = self._dropout(self.child_arc_feedforward(encoded_text))

        # shape (batch_size, sequence_length, tag_representation_dim)
        head_tag_representation = self._dropout(self.head_tag_feedforward(encoded_text))
        child_tag_representation = self._dropout(self.child_tag_feedforward(encoded_text))
        # shape (batch_size, sequence_length, sequence_length)
        attended_arcs = self.arc_attention(head_arc_representation,
                                           child_arc_representation)

        minus_inf = -1e8
        minus_mask = (1 - float_mask) * minus_inf
        attended_arcs = attended_arcs + minus_mask.unsqueeze(2) + minus_mask.unsqueeze(1)

        if self.training or not self.use_mst_decoding_for_validation:
            predicted_heads, predicted_head_tags = self._greedy_decode(head_tag_representation,
                                                                       child_tag_representation,
                                                                       attended_arcs,
                                                                       mask)
        else:
            predicted_heads, predicted_head_tags = self._mst_decode(head_tag_representation,
                                                                    child_tag_representation,
                                                                    attended_arcs,
                                                                    mask)
        if head_indices is not None and head_tags is not None:

            arc_nll, tag_nll = self._construct_loss(head_tag_representation=head_tag_representation,
                                                    child_tag_representation=child_tag_representation,
                                                    attended_arcs=attended_arcs,
                                                    head_indices=head_indices,
                                                    head_tags=head_tags,
                                                    mask=mask)
            loss = arc_nll + tag_nll

            evaluation_mask = self._get_mask_for_eval(mask[:, 1:], pos_tags)
            # We calculate attachment scores for the whole sentence
            # but excluding the symbolic ROOT token at the start,
            # which is why we start from the second element in the sequence.
            self._attachment_scores(predicted_heads[:, 1:],
                                    predicted_head_tags[:, 1:],
                                    head_indices[:, 1:],
                                    head_tags[:, 1:],
                                    evaluation_mask)
        else:
            arc_nll, tag_nll = self._construct_loss(head_tag_representation=head_tag_representation,
                                                    child_tag_representation=child_tag_representation,
                                                    attended_arcs=attended_arcs,
                                                    head_indices=predicted_heads.long(),
                                                    head_tags=predicted_head_tags.long(),
                                                    mask=mask)
            loss = arc_nll + tag_nll

        output_dict = {
                "encoder_final_state": encoder_final_state,
                "encoded_text": encoded_text_orig,
                "heads": predicted_heads,
                "head_tags": predicted_head_tags,
                "arc_loss": arc_nll,
                "tag_loss": tag_nll,
                "loss": loss,
                "mask": mask,
                "words": [meta["words"] for meta in metadata],
                "pos": [meta["pos"] for meta in metadata]
                }

        return output_dict
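Why the head sentinel matters: gold head indices use 0 to mean "attached to ROOT", and prepending one sentinel vector makes those indices line up with positions in the extended encoding. A toy illustration with plain tensors (the zero sentinel stands in for the learned parameter):

import torch

encoded_text = torch.randn(2, 5, 8)            # (batch, seq_len, dim)
head_sentinel = torch.zeros(2, 1, 8)           # stands in for the learned _head_sentinel
encoded_text = torch.cat([head_sentinel, encoded_text], 1)  # (batch, seq_len + 1, dim)

head_indices = torch.tensor([[2, 0, 2, 3, 3],
                             [0, 1, 1, 3, 4]])
# Prepend a dummy 0 for the sentinel itself; head_indices[b, i] now indexes
# straight into dimension 1 of the extended encoded_text.
head_indices = torch.cat([head_indices.new_zeros(2, 1), head_indices], 1)
selected_heads = encoded_text[torch.arange(2).unsqueeze(1), head_indices]
print(selected_heads.shape)  # torch.Size([2, 6, 8])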
Beispiel #18
0
    def _parse(self,
               embedded_text_input: torch.Tensor,
               mask: torch.LongTensor,
               head_tags: torch.LongTensor = None,
               head_indices: torch.LongTensor = None,
               grammar_values: torch.LongTensor = None,
               lemma_indices: torch.LongTensor = None):

        embedded_text_input = self._input_dropout(embedded_text_input)
        encoded_text = self.encoder(embedded_text_input, mask)

        # add a dimension that maps each encoder output to three copies of itself
        encoded_text_3 = encoded_text.unsqueeze(2).repeat(1, 1, 3, 1)
        # run the three copies of each vector (from the encoder output) through the LSTM
        seq_len = encoded_text.size()[1]
        emb_div_val = encoded_text.size()[2]
        multi_triplets = torch.reshape(encoded_text_3, (-1, 3, emb_div_val))
        label_variants, _ = self.multilabeler_lstm(multi_triplets)
        batched_label_variants = torch.reshape(label_variants, (-1, seq_len, 3, emb_div_val))

        grammar_value_logits = self._gram_val_output(batched_label_variants)
        predicted_gram_vals = grammar_value_logits.argmax(-1)

        lemma_logits = self._lemma_output(batched_label_variants)
        predicted_lemmas = lemma_logits.argmax(-1)

        batch_size, _, encoding_dim = encoded_text.size()

        head_sentinel = self._head_sentinel.expand(batch_size, 1, encoding_dim)
        # Concatenate the head sentinel onto the sentence representation.
        encoded_text = torch.cat([head_sentinel, encoded_text], 1)
        token_mask = mask.float()
        mask = torch.cat([mask.new_ones(batch_size, 1), mask], 1)
        if head_indices is not None:
            head_indices = torch.cat([head_indices.new_zeros(batch_size, 1), head_indices], 1)
        if head_tags is not None:
            head_tags = torch.cat([head_tags.new_zeros(batch_size, 1), head_tags], 1)
        float_mask = mask.float()
        encoded_text = self._dropout(encoded_text)

        # shape (batch_size, sequence_length, arc_representation_dim)
        head_arc_representation = self._dropout(self.head_arc_feedforward(encoded_text))
        child_arc_representation = self._dropout(self.child_arc_feedforward(encoded_text))

        # shape (batch_size, sequence_length, tag_representation_dim)
        head_tag_representation = self._dropout(self.head_tag_feedforward(encoded_text))
        child_tag_representation = self._dropout(self.child_tag_feedforward(encoded_text))
        # shape (batch_size, sequence_length, sequence_length)
        attended_arcs = self.arc_attention(head_arc_representation,
                                           child_arc_representation)

        minus_inf = -1e8
        minus_mask = (1 - float_mask) * minus_inf
        attended_arcs = attended_arcs + minus_mask.unsqueeze(2) + minus_mask.unsqueeze(1)

        if self.training or not self.use_mst_decoding_for_validation:
            predicted_heads, predicted_head_tags = self._greedy_decode(head_tag_representation,
                                                                       child_tag_representation,
                                                                       attended_arcs,
                                                                       mask)
        else:
            predicted_heads, predicted_head_tags = self._mst_decode(head_tag_representation,
                                                                    child_tag_representation,
                                                                    attended_arcs,
                                                                    mask)
        if head_indices is not None and head_tags is not None:

            arc_nll, tag_nll = self._construct_loss(head_tag_representation=head_tag_representation,
                                                    child_tag_representation=child_tag_representation,
                                                    attended_arcs=attended_arcs,
                                                    head_indices=head_indices,
                                                    head_tags=head_tags,
                                                    mask=mask)
        else:
            arc_nll, tag_nll = self._construct_loss(head_tag_representation=head_tag_representation,
                                                    child_tag_representation=child_tag_representation,
                                                    attended_arcs=attended_arcs,
                                                    head_indices=predicted_heads.long(),
                                                    head_tags=predicted_head_tags.long(),
                                                    mask=mask)

        grammar_nll = torch.tensor(0.)
        if grammar_values is not None:
            token_mask_3 = token_mask.unsqueeze(2).repeat(1, 1, 3)
            grammar_nll = self._update_multiclass_prediction_metrics_3(
                logits=grammar_value_logits, targets=grammar_values,
                mask=token_mask_3, accuracy_metric=self._gram_val_prediction_accuracy
            )

        lemma_nll = torch.tensor(0.)
        if lemma_indices is not None:
            token_mask_3 = token_mask.unsqueeze(2).repeat(1, 1, 3)
            lemma_nll = self._update_multiclass_prediction_metrics_3(
                logits=lemma_logits, targets=lemma_indices,
                mask=token_mask_3, accuracy_metric=self._lemma_prediction_accuracy
            )

        output_dict = {
            "heads": predicted_heads,
            "head_tags": predicted_head_tags,
            "gram_vals": predicted_gram_vals,
            "lemmas": predicted_lemmas,
            "mask": mask,
            "arc_nll": arc_nll,
            "tag_nll": tag_nll,
            "grammar_nll": grammar_nll,
            "lemma_nll": lemma_nll,
        }

        return output_dict
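The unsqueeze/repeat/reshape round-trip above is easy to get wrong; a small standalone sketch of the same shape gymnastics (the LSTM step is elided):

import torch

batch, seq_len, dim = 2, 5, 8
encoded_text = torch.randn(batch, seq_len, dim)

encoded_text_3 = encoded_text.unsqueeze(2).repeat(1, 1, 3, 1)  # (batch, seq, 3, dim)
multi_triplets = encoded_text_3.reshape(-1, 3, dim)            # (batch*seq, 3, dim)
# ... multilabeler_lstm would run over each length-3 triplet here ...
batched = multi_triplets.reshape(batch, seq_len, 3, dim)       # restore batch/seq dims
assert torch.equal(batched, encoded_text_3)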
Beispiel #19
0
def scatter_topk_2d_flat(
    src: Tensor, index: LongTensor, k: int, dim_size=None, fill_value=None
) -> Tuple[Tensor, Tuple[LongTensor, LongTensor], Tuple[LongTensor, LongTensor]]:
    """Finds the top k values in a 2D array partitioned along the dimension 0.

    ::

        +-----------------------+
        |          X            |
        |  X                    |
        |              X        |
        |     X                 |
        +-----------------------+
        |                       |
        |                 Y     |
        |       Y               |              +-------+
        |                       |              |X X X X|
        |                       |    top 4     +-------+
        |                       |  -------->   |X X X X|
        |                       |              +-------+
        |             Y         |              |Z Z Z Z|
        |                       |              +-------+
        |   Y                   |
        |                       |
        +-----------------------+
        |                       |
        |     Z       Z         |
        |                       |
        |        Z        Z     |
        |                       |
        +-----------------------+


    Args:
        src:
        index:
        k:
        dim_size:
        fill_value:

    Returns:

    """
    if src.ndimension() != 2:
        raise ValueError("Only implemented for 2D tensors")

    if dim_size is None:
        dim_size = index.max().item() + 1

    if fill_value is None:
        fill_value = float("NaN")

    ncols = src.shape[1]

    result_values = src.new_full((dim_size, k), fill_value=fill_value)
    result_indexes_whole_0 = index.new_full((dim_size, k), fill_value=-1)
    result_indexes_whole_1 = index.new_full((dim_size, k), fill_value=-1)
    result_indexes_within_chunk_0 = index.new_full((dim_size, k), fill_value=-1)
    result_indexes_within_chunk_1 = index.new_full((dim_size, k), fill_value=-1)

    chunk_sizes = (
        index.new_zeros(dim_size)
        .scatter_add_(dim=0, index=index, src=torch.ones_like(index))
        .tolist()
    )

    start_src = 0
    for chunk_idx, chunk_size in enumerate(chunk_sizes):
        flat_chunk = src[start_src : start_src + chunk_size, :].flatten()
        flat_values, flat_indexes = torch.topk(
            flat_chunk, k=min(k, chunk_size * ncols), dim=0
        )
        result_values[chunk_idx, : len(flat_values)] = flat_values

        indexes_0 = flat_indexes // ncols  # floor division: `/` yields floats in current PyTorch
        indexes_1 = flat_indexes % ncols
        result_indexes_within_chunk_0[chunk_idx, : len(flat_indexes)] = indexes_0
        result_indexes_within_chunk_1[chunk_idx, : len(flat_indexes)] = indexes_1

        result_indexes_whole_0[chunk_idx, : len(flat_indexes)] = indexes_0 + start_src
        result_indexes_whole_1[chunk_idx, : len(flat_indexes)] = indexes_1

        start_src += chunk_size

    return (
        result_values,
        (result_indexes_whole_0, result_indexes_whole_1),
        (result_indexes_within_chunk_0, result_indexes_within_chunk_1),
    )
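A small usage sketch for `scatter_topk_2d_flat`, assuming `torch` and the function above are in scope:

import torch

src = torch.tensor([[1., 9., 2.],
                    [8., 0., 3.],   # rows 0-1 belong to segment 0
                    [4., 7., 5.]])  # row 2 belongs to segment 1
index = torch.tensor([0, 0, 1])     # must be sorted in ascending order

values, (rows, cols), _ = scatter_topk_2d_flat(src, index, k=2)
print(values)  # tensor([[9., 8.], [7., 5.]])
print(rows)    # tensor([[0, 1], [2, 2]])  row indices into the whole src
print(cols)    # tensor([[1, 0], [1, 2]])  column indices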
Beispiel #20
0
    def _parse(self,
               embedded_text_input: torch.Tensor,
               mask: torch.LongTensor,
               head_tags: torch.LongTensor = None,
               head_indices: torch.LongTensor = None,
               grammar_values: torch.LongTensor = None,
               lemma_indices: torch.LongTensor = None):

        embedded_text_input = self._input_dropout(embedded_text_input)
        encoded_text = self.encoder(embedded_text_input, mask)

        grammar_value_logits = self._gram_val_output(encoded_text)
        predicted_gram_vals = grammar_value_logits.argmax(-1)

        lemma_logits = self._lemma_output(encoded_text)
        predicted_lemmas = lemma_logits.argmax(-1)

        batch_size, _, encoding_dim = encoded_text.size()

        head_sentinel = self._head_sentinel.expand(batch_size, 1, encoding_dim)
        # Concatenate the head sentinel onto the sentence representation.
        encoded_text = torch.cat([head_sentinel, encoded_text], 1)
        token_mask = mask.float()
        mask = torch.cat([mask.new_ones(batch_size, 1), mask], 1)
        if head_indices is not None:
            head_indices = torch.cat(
                [head_indices.new_zeros(batch_size, 1), head_indices], 1)
        if head_tags is not None:
            head_tags = torch.cat(
                [head_tags.new_zeros(batch_size, 1), head_tags], 1)
        float_mask = mask.float()
        encoded_text = self._dropout(encoded_text)

        # shape (batch_size, sequence_length, arc_representation_dim)
        head_arc_representation = self._dropout(
            self.head_arc_feedforward(encoded_text))
        child_arc_representation = self._dropout(
            self.child_arc_feedforward(encoded_text))

        # shape (batch_size, sequence_length, tag_representation_dim)
        head_tag_representation = self._dropout(
            self.head_tag_feedforward(encoded_text))
        child_tag_representation = self._dropout(
            self.child_tag_feedforward(encoded_text))
        # shape (batch_size, sequence_length, sequence_length)
        attended_arcs = self.arc_attention(head_arc_representation,
                                           child_arc_representation)

        minus_inf = -1e8
        minus_mask = (1 - float_mask) * minus_inf
        attended_arcs = attended_arcs + minus_mask.unsqueeze(
            2) + minus_mask.unsqueeze(1)

        if self.training or not self.use_mst_decoding_for_validation:
            predicted_heads, predicted_head_tags = self._greedy_decode(
                head_tag_representation, child_tag_representation,
                attended_arcs, mask)
        else:
            predicted_heads, predicted_head_tags = self._mst_decode(
                head_tag_representation, child_tag_representation,
                attended_arcs, mask)
        if head_indices is not None and head_tags is not None:

            arc_nll, tag_nll = self._construct_loss(
                head_tag_representation=head_tag_representation,
                child_tag_representation=child_tag_representation,
                attended_arcs=attended_arcs,
                head_indices=head_indices,
                head_tags=head_tags,
                mask=mask)
        else:
            arc_nll, tag_nll = self._construct_loss(
                head_tag_representation=head_tag_representation,
                child_tag_representation=child_tag_representation,
                attended_arcs=attended_arcs,
                head_indices=predicted_heads.long(),
                head_tags=predicted_head_tags.long(),
                mask=mask)

        grammar_nll = torch.tensor(0.)
        if grammar_values is not None:
            grammar_nll = self._update_multiclass_prediction_metrics(
                logits=grammar_value_logits,
                targets=grammar_values,
                mask=token_mask,
                accuracy_metric=self._gram_val_prediction_accuracy)

        lemma_nll = torch.tensor(0.)
        if lemma_indices is not None:
            lemma_nll = self._update_multiclass_prediction_metrics(
                logits=lemma_logits,
                targets=lemma_indices,
                mask=token_mask,
                accuracy_metric=self._lemma_prediction_accuracy,
                masked_index=self.lemmatize_helper.UNKNOWN_RULE_INDEX)

        output_dict = {
            "heads": predicted_heads,
            "head_tags": predicted_head_tags,
            "gram_vals": predicted_gram_vals,
            "lemmas": predicted_lemmas,
            "mask": mask,
            "arc_nll": arc_nll,
            "tag_nll": tag_nll,
            "grammar_nll": grammar_nll,
            "lemma_nll": lemma_nll,
        }

        return output_dict
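    # Editor's sketch: _update_multiclass_prediction_metrics is not shown in this
    # excerpt; a plausible minimal stand-in (name and behaviour are assumptions)
    # computes a masked token-level cross entropy and would also update the metric.
    @staticmethod
    def _masked_token_nll(logits, targets, mask):
        import torch.nn.functional as F
        num_classes = logits.size(-1)
        nll = F.cross_entropy(logits.reshape(-1, num_classes),
                              targets.reshape(-1), reduction="none")
        flat_mask = mask.reshape(-1).float()
        return (nll * flat_mask).sum() / flat_mask.sum().clamp(min=1.0)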
    def __call__(self,
                 entity_ids: torch.LongTensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Returns the set of valid parent entities for a given mention.

        Parameters
        ----------
        entity_ids : ``torch.LongTensor``
            A tensor of shape ``(batch_size, sequence_length)`` whose elements are the ids
            of the corresponding token in the ``target`` sequence.

        Returns
        -------
        A tuple ``(candidate_ids, candidate_mask)`` containing the following elements:
        candidate_ids : ``torch.LongTensor``
            A tensor of shape ``(batch_size, n_candidates)`` of all of the candidates for each
            batch element.
        candidate_mask : ``torch.LongTensor``
            A tensor of shape ``(batch_size, sequence_length, n_candidates)`` defining which
            subset of candidates can be selected at the given point in the sequence.
        """
        batch_size, sequence_length = entity_ids.shape[:2]

        # TODO: See if we can get away without nested loops / cast to CPU.
        candidate_ids = self._get_candidates(entity_ids)
        candidate_lookup = [{parent_id: j for j, parent_id in enumerate(
            l)} for l in candidate_ids.tolist()]

        # Create mask
        candidate_mask = entity_ids.new_zeros(size=(batch_size, sequence_length, candidate_ids.shape[-1]),
                                              dtype=torch.uint8)

        # Start by accounting for unfinished masks that remain from the last batch
        for i, lookup in enumerate(self._remaining):
            for parent_id, remainder in lookup.items():
                # Find index w.r.t. the **current** set of candidates
                k = candidate_lookup[i][parent_id]
                # Fill in the remaining amount of mask
                candidate_mask[i, :remainder, k] = 1
                # If splits are really short, then we might still have some remaining
                lookup[parent_id] -= sequence_length

        # Cast to list so we can use elements as keys (not possible for tensors)
        parent_id_list = entity_ids.tolist()
        for i, j, *_, parent_id in nested_enumerate(parent_id_list):
            if parent_id == 0:
                continue
            else:
                # Fill in mask
                k = candidate_lookup[i][parent_id]
                candidate_mask[i, j + 1: j + self._cutoff + 1, k] = 1
                # Track how far the mask extends past the end of this sequence,
                # so it can be continued for the next batch of splits.
                self._remaining[i][parent_id] = (
                    j + self._cutoff + 1) - sequence_length

        # Remove any ids for non-recent parents (e.g. those without remaining mask)
        for i, lookup in enumerate(self._remaining):
            self._remaining[i] = {key: value for key,
                                  value in lookup.items() if value > 0}

        return candidate_ids, candidate_mask
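`nested_enumerate` is referenced above but not defined in this excerpt. A minimal recursive implementation consistent with the call site (an assumption, not the original helper) would yield every index path plus the leaf value:

def nested_enumerate(nested, prefix=()):
    """Yield (*indices, leaf) tuples for every leaf of a nested list, so that
    `for i, j, *_, parent_id in nested_enumerate(batch_of_rows)` unpacks the
    batch index, the sequence index, and the value itself."""
    for i, item in enumerate(nested):
        if isinstance(item, list):
            yield from nested_enumerate(item, prefix + (i,))
        else:
            yield prefix + (i, item)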
Beispiel #22
0
    def forward(
        self,  # type: ignore
        token_ids: torch.LongTensor,
        type_ids: torch.LongTensor,
        offsets: torch.LongTensor,
        wordpiece_mask: torch.BoolTensor,
        dep_idxs: torch.LongTensor,
        dep_tags: torch.LongTensor,
        pos_tags: torch.LongTensor,
        word_mask: torch.BoolTensor,
    ):

        embedded_text_input = self.get_word_embedding(
            token_ids=token_ids,
            offsets=offsets,
            wordpiece_mask=wordpiece_mask,
            type_ids=type_ids,
        )
        if self.pos_embedding is not None:
            embedded_pos_tags = self.pos_embedding(pos_tags)
            embedded_text_input = torch.cat(
                [embedded_text_input, embedded_pos_tags], -1)
            if self.fuse_layer is not None:
                embedded_text_input = self.fuse_layer(embedded_text_input)
        # todo compare normal dropout with InputVariationalDropout
        embedded_text_input = self._input_dropout(embedded_text_input)

        if self.additional_encoder is not None:
            if self.config.additional_layer_type == "transformer":
                extended_attention_mask = self.bert.get_extended_attention_mask(
                    word_mask, word_mask.size(), word_mask.device)
                encoded_text = self.additional_encoder(
                    hidden_states=embedded_text_input,
                    attention_mask=extended_attention_mask)[0]
            else:
                encoded_text = self.additional_encoder(
                    inputs=embedded_text_input, mask=word_mask)
        else:
            encoded_text = embedded_text_input

        batch_size, _, encoding_dim = encoded_text.size()
        head_sentinel = self._head_sentinel.expand(batch_size, 1, encoding_dim)
        # Concatenate the head sentinel onto the sentence representation.
        encoded_text = torch.cat([head_sentinel, encoded_text], 1)
        word_mask = torch.cat([word_mask.new_ones(batch_size, 1), word_mask],
                              1)
        dep_idxs = torch.cat([dep_idxs.new_zeros(batch_size, 1), dep_idxs], 1)
        dep_tags = torch.cat([dep_tags.new_zeros(batch_size, 1), dep_tags], 1)

        encoded_text = self._dropout(encoded_text)

        # shape (batch_size, sequence_length, arc_representation_dim)
        head_arc_representation = self._dropout(
            self.head_arc_feedforward(encoded_text))
        child_arc_representation = self._dropout(
            self.child_arc_feedforward(encoded_text))

        # shape (batch_size, sequence_length, tag_representation_dim)
        head_tag_representation = self._dropout(
            self.head_tag_feedforward(encoded_text))
        child_tag_representation = self._dropout(
            self.child_tag_feedforward(encoded_text))
        # shape (batch_size, sequence_length, sequence_length)
        attended_arcs = self.arc_attention(head_arc_representation,
                                           child_arc_representation)

        minus_inf = -1e8
        minus_mask = ~word_mask * minus_inf
        attended_arcs = attended_arcs + minus_mask.unsqueeze(
            2) + minus_mask.unsqueeze(1)

        if self.training:
            predicted_heads, predicted_head_tags = self._greedy_decode(
                head_tag_representation, child_tag_representation,
                attended_arcs, word_mask)
        else:
            predicted_heads, predicted_head_tags = self._mst_decode(
                head_tag_representation, child_tag_representation,
                attended_arcs, word_mask)

        arc_nll, tag_nll = self._construct_loss(
            head_tag_representation=head_tag_representation,
            child_tag_representation=child_tag_representation,
            attended_arcs=attended_arcs,
            head_indices=dep_idxs,
            head_tags=dep_tags,
            mask=word_mask,
        )

        return predicted_heads, predicted_head_tags, arc_nll, tag_nll
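    # Editor's note: get_extended_attention_mask (HuggingFace's ModuleUtilsMixin)
    # conceptually turns the (batch, seq) 0/1 word_mask into an additive
    # (batch, 1, 1, seq) mask, roughly
    #     extended = word_mask[:, None, None, :].float()
    #     extended = (1.0 - extended) * -10000.0
    # so padded positions receive a large negative score before attention softmax.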
    def _parse(
        self,
        encoded_text: torch.Tensor,
        mask: torch.LongTensor,
        head_tags: torch.LongTensor = None,
        head_indices: torch.LongTensor = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor,
               torch.Tensor]:

        batch_size, _, encoding_dim = encoded_text.size()

        head_sentinel = self._head_sentinel.expand(batch_size, 1, encoding_dim)
        # Concatenate the head sentinel onto the sentence representation.
        encoded_text = torch.cat([head_sentinel, encoded_text], 1)
        mask = torch.cat([mask.new_ones(batch_size, 1), mask], 1)
        if head_indices is not None:
            head_indices = torch.cat(
                [head_indices.new_zeros(batch_size, 1), head_indices], 1)
        if head_tags is not None:
            head_tags = torch.cat(
                [head_tags.new_zeros(batch_size, 1), head_tags], 1)
        float_mask = mask.float()
        encoded_text = self._dropout(encoded_text)

        # shape (batch_size, sequence_length, arc_representation_dim)
        head_arc_representation = self._dropout(
            self.head_arc_feedforward(encoded_text))
        child_arc_representation = self._dropout(
            self.child_arc_feedforward(encoded_text))

        # shape (batch_size, sequence_length, tag_representation_dim)
        head_tag_representation = self._dropout(
            self.head_tag_feedforward(encoded_text))
        child_tag_representation = self._dropout(
            self.child_tag_feedforward(encoded_text))
        # shape (batch_size, sequence_length, sequence_length)
        attended_arcs = self.arc_attention(head_arc_representation,
                                           child_arc_representation)

        minus_inf = -1e8
        minus_mask = (1 - float_mask) * minus_inf
        attended_arcs = attended_arcs + minus_mask.unsqueeze(
            2) + minus_mask.unsqueeze(1)

        if self.training or not self.use_mst_decoding_for_validation:
            predicted_heads, predicted_head_tags = self._greedy_decode(
                head_tag_representation, child_tag_representation,
                attended_arcs, mask)
        else:
            predicted_heads, predicted_head_tags = self._mst_decode(
                head_tag_representation, child_tag_representation,
                attended_arcs, mask)
        if head_indices is not None and head_tags is not None:

            arc_nll, tag_nll = self._construct_loss(
                head_tag_representation=head_tag_representation,
                child_tag_representation=child_tag_representation,
                attended_arcs=attended_arcs,
                head_indices=head_indices,
                head_tags=head_tags,
                mask=mask,
            )
        else:
            arc_nll, tag_nll = self._construct_loss(
                head_tag_representation=head_tag_representation,
                child_tag_representation=child_tag_representation,
                attended_arcs=attended_arcs,
                head_indices=predicted_heads.long(),
                head_tags=predicted_head_tags.long(),
                mask=mask,
            )

        return predicted_heads, predicted_head_tags, mask, arc_nll, tag_nll
Beispiel #24
0
    def edit_distance_q_values(sampled_y: torch.LongTensor,
                               gold_y: torch.LongTensor,
                               end_symbol_id: int,
                               vocab_size: int) -> torch.FloatTensor:
        """ OCD Edit Distance to compute QValues.
        Args:
            sampled_y (`~torch.LongTensor`): ``(batch_size, sequence_length)``
            gold_y (`~torch.LongTensor`): ``(batch_size, sequence_length)``
            end_symbol_id (int): index of the end symbol in your vocabulary.
            vocab_size (int): the number of possible output tokens.

        Returns:
            `~torch.FloatTensor`: ``(batch_size, sequence_length, vocab_size)``

        Example:
            from paper `https://arxiv.org/abs/1810.01398`

            vocabulary = {'S':0, 'U':1, 'N':2, 'D':3, 'A':4 , 'Y':5,
                          'T':6, 'R':7, 'P':8, '</s>':9, '<pad>': 10}
            vocab_size = 11
            end_symbol_id = 9

            gold Y = {'SUNDAY</s><pad><pad>', 'SUNDAY</s><pad><pad>'}
            gold_y = [[0, 1, 2, 3, 4, 5, 9, 10, 10],
                      [0, 1, 2, 3, 4, 5, 9, 10, 10]]

            sampled Y = {'SATURDAY</s>', 'SATRAPY</s>U'}
            sampled_y = [[0, 4, 6, 1, 7, 3, 4, 5, 9],
                         [0, 4, 6, 7, 4, 8, 5, 9, 1]]


            # expected size: (batch_size=2, sequence_length=9, vocab_size=11)
            expected q_values = [[[ 0., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.],
                                  [-1.,  0., -1., -1., -1., -1., -1., -1., -1., -1., -1.],
                                  [-2., -1., -1., -2., -2., -2., -2., -2., -2., -2., -2.],
                                  [-3., -2., -2., -2., -3., -3., -3., -3., -3., -3., -3.],
                                  [-3., -3., -2., -3., -3., -3., -3., -3., -3., -3., -3.],
                                  [-4., -4., -3., -3., -4., -4., -4., -4., -4., -4., -4.],
                                  [-4., -4., -4., -4., -3., -4., -4., -4., -4., -4., -4.],
                                  [-4., -4., -4., -4., -4., -3., -4., -4., -4., -4., -4.],
                                  [-4., -4., -4., -4., -4., -4., -4., -4., -4., -3., -4.]],

                                 [[ 0., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.],
                                  [-1.,  0., -1., -1., -1., -1., -1., -1., -1., -1., -1.],
                                  [-2., -1., -1., -2., -2., -2., -2., -2., -2., -2., -2.],
                                  [-3., -2., -2., -2., -3., -3., -3., -3., -3., -3., -3.],
                                  [-4., -3., -3., -3., -3., -4., -4., -4., -4., -4., -4.],
                                  [-4., -4., -4., -4., -4., -3., -4., -4., -4., -4., -4.],
                                  [-5., -5., -5., -5., -5., -4., -5., -5., -5., -4., -5.],
                                  [-5., -5., -5., -5., -5., -5., -5., -5., -5., -4., -5.],
                                  [-6., -6., -6., -6., -6., -6., -6., -6., -6., -5., -6.]]]
        """
        assert gold_y.size() == sampled_y.size()
        b_sz, seq_len = gold_y.size()
        q_values = gold_y.new_zeros((b_sz, seq_len + 1, vocab_size), dtype=torch.float)
        edit_dists = gold_y.new_zeros((b_sz, seq_len + 1, seq_len + 1), dtype=torch.float)

        # run batch version of the levenshtein algorithm
        edit_dists[:, :, 0] = torch.arange(seq_len + 1)
        edit_dists[:, 0, :] = torch.arange(seq_len + 1)
        for i in range(1, seq_len + 1):
            for j in range(1, seq_len + 1):
                cost = (sampled_y[:, i-1] != gold_y[:, j-1]).float()
                min_cost, _ = torch.cat(((edit_dists[:, i-1, j] + 1).unsqueeze(dim=1),
                                         (edit_dists[:, i, j-1] + 1).unsqueeze(dim=1),
                                         (edit_dists[:, i-1, j-1] + cost).unsqueeze(dim=1)),
                                        dim=1).min(dim=1)
                edit_dists[:, i, j] = min_cost

        # find gold next tokens and update their QValues
        edit_dists_mask = OCD.edit_distance_mask(gold_y, end_symbol_id)
        edit_dists = edit_dists.masked_fill_(edit_dists_mask, LARGE_NUM)
        min_dist, _ = edit_dists.min(dim=2)
        min_dist = min_dist.unsqueeze(dim=2)
        steps_with_min_dists = (edit_dists == min_dist)
        extended_gold_y = gold_y.repeat(1, seq_len + 1).view(b_sz, seq_len + 1, seq_len)
        indices = steps_with_min_dists.nonzero()
        gold_next_tokens = extended_gold_y[indices.split(1, dim=1)]
        indices[:, 2] = gold_next_tokens.squeeze(dim=1)
        q_values[indices.split(1, dim=1)] = 1
        q_values = q_values - (1 + min_dist)
        return q_values[:, :-1, :]  # ignore the step 'seq_len + 1'
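    # Editor's sketch (hypothetical helper, not part of the original class): the
    # nested loop above is a batched dynamic-programming Levenshtein; this plain
    # single-pair version is handy for sanity-checking it.
    @staticmethod
    def _levenshtein_reference(a: str, b: str) -> int:
        n, m = len(a), len(b)
        dist = [[0] * (m + 1) for _ in range(n + 1)]
        for i in range(n + 1):
            dist[i][0] = i
        for j in range(m + 1):
            dist[0][j] = j
        for i in range(1, n + 1):
            for j in range(1, m + 1):
                cost = 0 if a[i - 1] == b[j - 1] else 1
                dist[i][j] = min(dist[i - 1][j] + 1,         # deletion
                                 dist[i][j - 1] + 1,         # insertion
                                 dist[i - 1][j - 1] + cost)  # substitution
        return dist[n][m]
        # e.g. _levenshtein_reference("SATURDAY", "SUNDAY") == 3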
    def forward(self,  # type: ignore
                words: Dict[str, torch.LongTensor],
                pos_tags: torch.LongTensor,
                metadata: List[Dict[str, Any]],
                head_tags: torch.LongTensor = None,
                head_indices: torch.LongTensor = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        words : Dict[str, torch.LongTensor], required
            The output of ``TextField.as_array()``, which should typically be passed directly to a
            ``TextFieldEmbedder``. This output is a dictionary mapping keys to ``TokenIndexer``
            tensors.  At its most basic, using a ``SingleIdTokenIndexer`` this is: ``{"tokens":
            Tensor(batch_size, sequence_length)}``. This dictionary will have the same keys as were used
            for the ``TokenIndexers`` when you created the ``TextField`` representing your
            sequence.  The dictionary is designed to be passed directly to a ``TextFieldEmbedder``,
            which knows how to combine different word representations into a single vector per
            token in your input.
        pos_tags : ``torch.LongTensor``, required.
            The output of a ``SequenceLabelField`` containing POS tags.
            POS tags are required regardless of whether they are used in the model,
            because they are used to filter the evaluation metric to only consider
            heads of words which are not punctuation.
        head_tags : torch.LongTensor, optional (default = None)
            A torch tensor representing the sequence of integer gold class labels for the arcs
            in the dependency parse. Has shape ``(batch_size, sequence_length)``.
        head_indices : torch.LongTensor, optional (default = None)
            A torch tensor representing the sequence of integer indices denoting the parent of every
            word in the dependency parse. Has shape ``(batch_size, sequence_length)``.

        Returns
        -------
        An output dictionary consisting of:
        loss : ``torch.FloatTensor``, optional
            A scalar loss to be optimised.
        arc_loss : ``torch.FloatTensor``
            The loss contribution from the unlabeled arcs.
        tag_loss : ``torch.FloatTensor``
            The loss contribution from predicting the dependency
            tags for the gold arcs.
        heads : ``torch.FloatTensor``
            The predicted head indices for each word. A tensor
            of shape (batch_size, sequence_length).
        head_tags : ``torch.FloatTensor``
            The predicted head tags for each arc. A tensor
            of shape (batch_size, sequence_length).
        mask : ``torch.LongTensor``
            A mask denoting the padded elements in the batch.
        """
        embedded_text_input = self.text_field_embedder(words)
        if pos_tags is not None and self._pos_tag_embedding is not None:
            embedded_pos_tags = self._pos_tag_embedding(pos_tags)
            embedded_text_input = torch.cat([embedded_text_input, embedded_pos_tags], -1)
        elif self._pos_tag_embedding is not None:
            raise ConfigurationError("Model uses a POS embedding, but no POS tags were passed.")

        mask = get_text_field_mask(words)
        embedded_text_input = self._input_dropout(embedded_text_input)
        encoded_text = self.encoder(embedded_text_input, mask)

        batch_size, _, encoding_dim = encoded_text.size()

        head_sentinel = self._head_sentinel.expand(batch_size, 1, encoding_dim)
        # Concatenate the head sentinel onto the sentence representation.
        encoded_text = torch.cat([head_sentinel, encoded_text], 1)
        mask = torch.cat([mask.new_ones(batch_size, 1), mask], 1)
        if head_indices is not None:
            head_indices = torch.cat([head_indices.new_zeros(batch_size, 1), head_indices], 1)
        if head_tags is not None:
            head_tags = torch.cat([head_tags.new_zeros(batch_size, 1), head_tags], 1)
        float_mask = mask.float()
        encoded_text = self._dropout(encoded_text)

        # shape (batch_size, sequence_length, arc_representation_dim)
        head_arc_representation = self._dropout(self.head_arc_feedforward(encoded_text))
        child_arc_representation = self._dropout(self.child_arc_feedforward(encoded_text))

        # shape (batch_size, sequence_length, tag_representation_dim)
        head_tag_representation = self._dropout(self.head_tag_feedforward(encoded_text))
        child_tag_representation = self._dropout(self.child_tag_feedforward(encoded_text))
        # shape (batch_size, sequence_length, sequence_length)
        attended_arcs = self.arc_attention(head_arc_representation,
                                           child_arc_representation)

        minus_inf = -1e8
        minus_mask = (1 - float_mask) * minus_inf
        attended_arcs = attended_arcs + minus_mask.unsqueeze(2) + minus_mask.unsqueeze(1)

        if self.training or not self.use_mst_decoding_for_validation:
            predicted_heads, predicted_head_tags = self._greedy_decode(head_tag_representation,
                                                                       child_tag_representation,
                                                                       attended_arcs,
                                                                       mask)
        else:
            predicted_heads, predicted_head_tags = self._mst_decode(head_tag_representation,
                                                                    child_tag_representation,
                                                                    attended_arcs,
                                                                    mask)
        if head_indices is not None and head_tags is not None:

            arc_nll, tag_nll = self._construct_loss(head_tag_representation=head_tag_representation,
                                                    child_tag_representation=child_tag_representation,
                                                    attended_arcs=attended_arcs,
                                                    head_indices=head_indices,
                                                    head_tags=head_tags,
                                                    mask=mask)
            loss = arc_nll + tag_nll

            evaluation_mask = self._get_mask_for_eval(mask[:, 1:], pos_tags)
            # We calculate attachment scores for the whole sentence
            # but excluding the symbolic ROOT token at the start,
            # which is why we start from the second element in the sequence.
            self._attachment_scores(predicted_heads[:, 1:],
                                    predicted_head_tags[:, 1:],
                                    head_indices[:, 1:],
                                    head_tags[:, 1:],
                                    evaluation_mask)
        else:
            arc_nll, tag_nll = self._construct_loss(head_tag_representation=head_tag_representation,
                                                    child_tag_representation=child_tag_representation,
                                                    attended_arcs=attended_arcs,
                                                    head_indices=predicted_heads.long(),
                                                    head_tags=predicted_head_tags.long(),
                                                    mask=mask)
            loss = arc_nll + tag_nll

        output_dict = {
                "heads": predicted_heads,
                "head_tags": predicted_head_tags,
                "arc_loss": arc_nll,
                "tag_loss": tag_nll,
                "loss": loss,
                "mask": mask,
                "words": [meta["words"] for meta in metadata],
                "pos": [meta["pos"] for meta in metadata]
                }

        return output_dict
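On the greedy/MST switch above: during training (and when `use_mst_decoding_for_validation` is off) heads are picked by a simple masked argmax. A hedged sketch of the arc part of such a greedy decode (the real `_greedy_decode` also masks padding and predicts head tags):

import torch

def greedy_heads(attended_arcs: torch.Tensor) -> torch.LongTensor:
    # Forbid self-loops by putting a large negative score on the diagonal,
    # then let every word pick its highest-scoring head.
    seq_len = attended_arcs.size(1)
    scores = attended_arcs + torch.diag(attended_arcs.new_full((seq_len,), -1e8))
    return scores.argmax(dim=2)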
Beispiel #26
0
    def forward(self,
                inputs: Dict[str, Any],
                head_tags: torch.LongTensor = None,
                head_indices: torch.LongTensor = None):
        with self.input_encoding_timer as _:
            embeded_input = {}
            for name, fn in self.input_layers.items():
                input_ = inputs[name]
                embeded_input[name] = fn(input_)

            encoded_input = []
            for encoder_ in self.input_encoders:
                ordered_names = encoder_.get_ordered_names()
                args_ = {name: embeded_input[name] for name in ordered_names}
                encoded_input.append(self.input_dropout_(encoder_(args_)))

            encoded_input = torch.cat(encoded_input, dim=-1)
            # encoded_input: (batch_size, seq_len, input_dim)

        with self.context_encoding_timer as _:
            mask = get_mask_from_sequence_lengths(inputs['length'],
                                                  inputs['length'].max())
            # mask: (batch_size, seq_len)

            context_encoded_input = self.encoder(encoded_input, mask)
            # context_encoded_input: (batch_size, seq_len, encoded_dim)

            # handle the sentinel/dummy root.
            batch_size, _, encoding_dim = context_encoded_input.size()
            head_sentinel = self.head_sentinel_.expand(batch_size, 1,
                                                       encoding_dim)
            context_encoded_input = torch.cat(
                [head_sentinel, context_encoded_input], 1)
            # context_encoded_input: (batch_size, seq_len + 1, encoded_dim)

        with self.classification_timer as _:
            mask = torch.cat([mask.new_ones(batch_size, 1), mask], 1)
            # mask: (batch_size, seq_len + 1)

            if head_indices is not None:
                head_indices = torch.cat(
                    [head_indices.new_zeros(batch_size, 1), head_indices], 1)
            if head_tags is not None:
                head_tags = torch.cat(
                    [head_tags.new_zeros(batch_size, 1), head_tags], 1)

            context_encoded_input = self.dropout_(context_encoded_input)

            head_arc_representation = self.head_arc_feedforward(
                context_encoded_input)
            child_arc_representation = self.child_arc_feedforward(
                context_encoded_input)

            head_tag_representation = self.head_tag_feedforward(
                context_encoded_input)
            child_tag_representation = self.child_tag_feedforward(
                context_encoded_input)

            # head_tag_representation / child_tag_representation: (batch_size, seq_len + 1, dim)
            arc_representation = self.dropout_(
                torch.cat([head_arc_representation, child_arc_representation],
                          dim=1))
            tag_representation = self.dropout_(
                torch.cat([head_tag_representation, child_tag_representation],
                          dim=1))

            head_arc_representation, child_arc_representation = arc_representation.chunk(
                2, dim=1)
            head_tag_representation, child_tag_representation = tag_representation.chunk(
                2, dim=1)

            head_tag_representation = head_tag_representation.contiguous()
            child_tag_representation = child_tag_representation.contiguous()

            # attended_arcs: (batch_size, seq_len + 1, seq_len + 1)
            attended_arcs = self.arc_attention(head_arc_representation,
                                               child_arc_representation)

            if not self.training:
                if not self.use_mst_decoding_for_validation:
                    predicted_heads, predicted_head_tags = self._greedy_decode(
                        head_tag_representation, child_tag_representation,
                        attended_arcs, mask)
                else:
                    predicted_heads, predicted_head_tags = self._mst_decode(
                        head_tag_representation, child_tag_representation,
                        attended_arcs, mask)
            else:
                predicted_heads, predicted_head_tags = None, None

            if head_indices is not None and head_tags is not None:

                arc_nll, tag_nll = self._construct_loss(
                    head_tag_representation=head_tag_representation,
                    child_tag_representation=child_tag_representation,
                    attended_arcs=attended_arcs,
                    head_indices=head_indices,
                    head_tags=head_tags,
                    mask=mask)
                loss = arc_nll + tag_nll
            else:
                arc_nll, tag_nll = self._construct_loss(
                    head_tag_representation=head_tag_representation,
                    child_tag_representation=child_tag_representation,
                    attended_arcs=attended_arcs,
                    head_indices=predicted_heads.long(),
                    head_tags=predicted_head_tags.long(),
                    mask=mask)
                loss = arc_nll + tag_nll

        output_dict = {
            "heads": predicted_heads,
            "head_tags": predicted_head_tags,
            "arc_loss": arc_nll,
            "tag_loss": tag_nll,
            "loss": loss,
            "mask": mask,
        }

        return output_dict
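The cat-then-chunk pattern above (concatenating head and child representations along the sequence dimension, applying dropout once, then splitting them back) relies on `chunk` being an exact inverse of `cat`. A toy check:

import torch

head = torch.randn(2, 5, 4)
child = torch.randn(2, 5, 4)
joint = torch.cat([head, child], dim=1)  # (2, 10, 4)
head2, child2 = joint.chunk(2, dim=1)    # each (2, 5, 4) again
assert torch.equal(head, head2) and torch.equal(child, child2)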