Example #1
    def get_initial_forced_decoder_input(
        self,
        bsz: int,
        inputs: torch.LongTensor,
        n_docs: int,
        start_idx: int,
        end_idx: int,
        input_turns_cnt: Optional[torch.LongTensor] = None,
    ) -> torch.LongTensor:
        """
        Return the initial input to the decoder during training.

        Repeat inputs n_docs times.

        :param bsz:
            batch size
        :param inputs:
            inputs to decode
        :param n_docs:
            number of docs per input
        :param start_idx:
            start token idx
        :param end_idx:
            end token idx
        :param input_turns_cnt:
            an optional tensor containing the number of turns of each corresponding context.

        :return initial_input:
            initial input for the decoder.
        """
        inputs = get_forced_decoder_inputs(
            inputs, bsz, start_idx, end_idx, self.generation_model
        )
        inputs = inputs.repeat(1, n_docs).reshape(-1, inputs.size(1))  # type: ignore
        return inputs
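
A minimal standalone sketch (tensor values and sizes below are illustrative, not taken from the example) of what the repeat/reshape pattern above does: each batch row is duplicated n_docs times, with the copies of one batch item kept adjacent.

import torch

bsz, seq_len, n_docs = 2, 3, 2
inputs = torch.arange(bsz * seq_len).view(bsz, seq_len)         # [[0, 1, 2], [3, 4, 5]]
expanded = inputs.repeat(1, n_docs).reshape(-1, inputs.size(1))
# expanded -> [[0, 1, 2], [0, 1, 2], [3, 4, 5], [3, 4, 5]]
assert expanded.shape == (bsz * n_docs, seq_len)
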
Example #2
    def thorough_generation(
        cls,
        hyps: List[torch.LongTensor],
        new_input: torch.LongTensor,
        null_idx: int,
        model: RagModel,
    ) -> List[Tuple[torch.LongTensor, torch.Tensor]]:
        """
        Apply RAG-sequence thorough generation for a single batch item.

        Recompute model scores for the given hypotheses and sort them accordingly.

        :param hyps:
            list of candidate hypotheses
        :param new_input:
            input for the model
        :param null_idx:
            null/padding token index used to pad the hypotheses
        :param model:
            RAG model used to re-score the hypotheses

        :return sorted_hyps:
            list of (hyp, score) tuples, sorted by score (ascending).
        """
        # deduplicate hypotheses and drop the BOS token
        hyps = list({str(h.tolist()): h[1:] for h in hyps}.values())  # type: ignore
        new_input = new_input.repeat(len(hyps), 1)  # type: ignore
        new_ys, _ = padded_tensor(
            hyps, fp16friendly=new_input.size(1) % FP16_PAD_SIZE == 0, pad_idx=null_idx
        )
        new_ys = new_ys.to(new_input.device)
        scores, *_ = model.seq2seq_forward_pass(new_input, new_ys)
        loss = cls._rag_sequence_loss(
            new_ys.unsqueeze(1).unsqueeze(-1), scores.unsqueeze(1), null_idx
        )  # type: ignore
        sorted_by_score = [
            (hyps[idx], loss[idx]) for idx in loss.sort()[-1]
        ]  # sort ascending
        return sorted_by_score
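
A standalone sketch of the deduplication trick used above, with illustrative tensors: keying a dict by the stringified token list keeps one copy per unique hypothesis, and h[1:] drops the BOS token.

import torch

hyps = [torch.tensor([1, 5, 6, 2]),
        torch.tensor([1, 5, 6, 2]),   # duplicate hypothesis
        torch.tensor([1, 7, 2])]
unique = list({str(h.tolist()): h[1:] for h in hyps}.values())
# unique -> [tensor([5, 6, 2]), tensor([7, 2])]
assert len(unique) == 2
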
Example #3
    def sample(self, positive_batch: torch.LongTensor) -> torch.LongTensor:
        """Generate negative samples from the positive batch."""
        if self.num_negs_per_pos > 1:
            positive_batch = positive_batch.repeat(self.num_negs_per_pos, 1)

        # Bind number of negatives to sample
        num_negs = positive_batch.shape[0]

        # Equally corrupt head and tail
        split_idx = num_negs // 2

        # Copy positive batch for corruption.
        # Do not detach, as no gradients should flow into the indices.
        negative_batch = positive_batch.clone()

        # Sample random entities as replacement
        negative_entities = torch.randint(high=self.num_entities - 1, size=(num_negs,), device=positive_batch.device)

        # Replace heads. To make sure we never reproduce the original head entity,
        # the random value is drawn from [0, num_entities - 1] and every value
        # greater than or equal to the original one is shifted up by one.
        filter_same_head = (negative_entities[:split_idx] >= positive_batch[:split_idx, 0])
        negative_batch[:split_idx, 0] = negative_entities[:split_idx] + filter_same_head.long()
        # Corrupt tails
        filter_same_tail = (negative_entities[split_idx:] >= positive_batch[split_idx:, 2])
        negative_batch[split_idx:, 2] = negative_entities[split_idx:] + filter_same_tail.long()

        return negative_batch
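
A standalone numeric sketch of the resampling-free corruption trick above: candidates are drawn from [0, num_entities - 1) and shifted up by one whenever they are greater than or equal to the original id, so the result is uniform over all entities except the original (values below are illustrative).

import torch

num_entities = 5
original = torch.tensor([2, 2, 2, 2])
candidate = torch.tensor([0, 1, 2, 3])                  # as if drawn from randint(high=num_entities - 1)
corrupted = candidate + (candidate >= original).long()
# corrupted -> tensor([0, 1, 3, 4]); the original id 2 is never produced
assert not (corrupted == original).any()
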
Example #4
    def get_initial_decoder_input(self, input: torch.LongTensor) -> torch.LongTensor:
        """
        Repeat the decoder input n_docs times.
        """
        return input.repeat(1, self.n_docs).reshape(-1, input.size(1))  # type: ignore
Example #5
    def sample(
        self, positive_batch: torch.LongTensor
    ) -> Tuple[torch.LongTensor, Optional[torch.Tensor]]:
        """Sample a negative batched based on the bern approach."""
        if self.num_negs_per_pos > 1:
            positive_batch = positive_batch.repeat(self.num_negs_per_pos, 1)

        # Bind number of negatives to sample
        num_negs = positive_batch.shape[0]

        # Copy positive batch for corruption.
        # Do not detach, as no gradients should flow into the indices.
        negative_batch = positive_batch.clone()

        device = positive_batch.device
        # Decide whether to corrupt head or tail
        head_corruption_probability = self.corrupt_head_probability[
            positive_batch[:, 1]]
        head_mask = torch.rand(
            num_negs,
            device=device) < head_corruption_probability.to(device=device)

        # Tails are corrupted if heads are not corrupted
        tail_mask = ~head_mask

        # Randomly sample corruption. See below for explanation of
        # why this is on a range of [0, num_entities - 1]
        negative_entities = torch.randint(
            self.triples_factory.num_entities - 1,
            size=(num_negs, ),
            device=positive_batch.device,
        )

        # Replace heads
        negative_batch[head_mask, 0] = negative_entities[head_mask]

        # Replace tails
        negative_batch[tail_mask, 2] = negative_entities[tail_mask]

        # If filtering is activated, all negative triples that are positive in the training dataset will be removed
        if self.filtered:
            negative_batch, batch_filter = self.filter_negative_triples(
                negative_batch=negative_batch)
        else:
            # To make sure we never reproduce the original entity, the random value
            # was drawn from [0, num_entities - 1]; every value greater than or
            # equal to the original one is shifted up by one.
            negative_batch[head_mask,
                           0] += (negative_batch[head_mask, 0] >=
                                  positive_batch[head_mask, 0]).long()
            negative_batch[tail_mask,
                           2] += (negative_batch[tail_mask, 2] >=
                                  positive_batch[tail_mask, 2]).long()
            batch_filter = None

        return negative_batch, batch_filter
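
A standalone sketch of the Bernoulli ("bern") decision above, with illustrative per-relation probabilities: the relation column indexes a corruption probability, and a uniform draw decides whether the head or the tail of each triple is corrupted.

import torch

corrupt_head_probability = torch.tensor([0.9, 0.1])    # one entry per relation (illustrative)
positive_batch = torch.tensor([[0, 0, 1],              # (head, relation, tail) triples
                               [2, 1, 3]])
p = corrupt_head_probability[positive_batch[:, 1]]     # -> tensor([0.9000, 0.1000])
head_mask = torch.rand(positive_batch.shape[0]) < p
tail_mask = ~head_mask                                 # exactly one side is corrupted per triple
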
Example #6
    def sample(
        self, positive_batch: torch.LongTensor
    ) -> Tuple[torch.LongTensor, Optional[torch.Tensor]]:
        """Generate negative samples from the positive batch."""
        if self.num_negs_per_pos > 1:
            positive_batch = positive_batch.repeat(self.num_negs_per_pos, 1)

        # Bind number of negatives to sample
        num_negs = positive_batch.shape[0]

        # Equally corrupt all sides
        split_idx = num_negs // len(self._corruption_indices)

        # Copy positive batch for corruption.
        # Do not detach, as no gradients should flow into the indices.
        negative_batch = positive_batch.clone()

        for index, start in zip(self._corruption_indices,
                                range(0, num_negs, split_idx)):
            stop = min(start + split_idx, num_negs)

            # Relations have a different index maximum than entities
            index_max = self.num_relations if index == 1 else self.num_entities

            # If we do not use a filterer, we at least make sure not to replace a component by its original value
            if self.filterer is None:
                index_max -= 1

            negative_batch[start:stop, index] = torch.randint(
                high=index_max,
                size=(stop - start, ),
                device=positive_batch.device,
            )

            # To make sure we never reproduce the original {head, relation, tail},
            # the random value was drawn from [0, num_{entities, relations} - 1];
            # every value greater than or equal to the original one is shifted up by one.
            if self.filterer is None:
                negative_batch[start:stop,
                               index] += (negative_batch[start:stop, index] >=
                                          positive_batch[start:stop,
                                                         index]).long()

        # If filtering is activated, all negative triples that are positive in the training dataset will be removed
        if self.filterer is not None:
            negative_batch, batch_filter = self.filterer(
                negative_batch=negative_batch)
        else:
            batch_filter = None

        return negative_batch, batch_filter
Example #7
def copy_idx(idx: torch.LongTensor, dim_size: int, ncopies: int,
             offset_both_idx: bool) -> torch.LongTensor:
    """Tile ``idx`` ``ncopies`` times along dim 1, offsetting each copy by ``dim_size``."""
    idx_copies = idx.repeat(1, ncopies)

    # Per-copy offset (0, dim_size, 2 * dim_size, ...), expanded to align with the
    # repeated columns of ``idx`` and flattened to a single row.
    offset = dim_size * torch.arange(
        ncopies, dtype=torch.long, device=idx.device)[:, None].expand(
            ncopies, idx.shape[1]).flatten()

    if offset_both_idx:
        # Shift both index rows (e.g. source and target rows of an edge index).
        idx_copies += offset[None, :]
    else:
        # Shift only the first index row.
        idx_copies[0] += offset

    return idx_copies
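
A standalone sketch of copy_idx on a tiny edge index (values are illustrative): the 2 x E index is tiled ncopies times and each copy is shifted into its own block of dim_size ids.

import torch

idx = torch.tensor([[0, 1],
                    [1, 2]])                  # two edges over dim_size = 3 nodes
dim_size, ncopies = 3, 2
idx_copies = idx.repeat(1, ncopies)           # [[0, 1, 0, 1], [1, 2, 1, 2]]
offset = dim_size * torch.arange(
    ncopies, dtype=torch.long)[:, None].expand(ncopies, idx.shape[1]).flatten()
idx_copies = idx_copies + offset[None, :]     # [[0, 1, 3, 4], [1, 2, 4, 5]]
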
Example #8
    def sample(self, positive_batch: torch.LongTensor) -> torch.LongTensor:
        """Sample a negative batched based on the bern approach."""
        if self.num_negs_per_pos > 1:
            positive_batch = positive_batch.repeat(self.num_negs_per_pos, 1)

        # Bind number of negatives to sample
        num_negs = positive_batch.shape[0]

        # Copy positive batch for corruption.
        # Do not detach, as no gradients should flow into the indices.
        negative_batch = positive_batch.clone()

        device = positive_batch.device
        # Decide whether to corrupt head or tail
        head_corruption_probability = self.corrupt_head_probability[
            positive_batch[:, 1]]
        head_mask = torch.rand(
            num_negs,
            device=device) < head_corruption_probability.to(device=device)

        # Tails are corrupted if heads are not corrupted
        tail_mask = ~head_mask

        # Randomly sample corruption
        negative_entities = torch.randint(
            self.triples_factory.num_entities - 1,
            size=(num_negs, ),
            device=positive_batch.device,
        )

        # Replace heads. To make sure we never reproduce the original head entity,
        # the random value was drawn from [0, num_entities - 1]; every value greater
        # than or equal to the original one is shifted up by one.
        filter_same_head = (negative_entities[head_mask] >=
                            positive_batch[:, 0][head_mask])
        negative_batch[:, 0][head_mask] = negative_entities[
            head_mask] + filter_same_head.long()

        # Replace tails
        filter_same_tail = (negative_entities[tail_mask] >=
                            positive_batch[:, 2][tail_mask])
        negative_batch[:, 2][tail_mask] = negative_entities[
            tail_mask] + filter_same_tail.long()

        return negative_batch
Example #9
    def _construct_q_values(self, outputs: torch.FloatTensor,
                            output_symbols: torch.LongTensor,
                            targets: torch.LongTensor,
                            mask: torch.BoolTensor) -> torch.FloatTensor:
        batch_size, max_pred_len, nlabels = outputs.size()
        _, max_len = targets.size()
        q_values = outputs.new_zeros((batch_size, max_pred_len, nlabels))
        distances = self._calculate_edit_distance(output_symbols, targets, mask)
        # distances = distances[:, :-1, :-1]
        min_dists, _ = torch.min(distances, dim=-1)
        truth_values = distances == min_dists.unsqueeze(-1)
        indices = truth_values.nonzero()
        extended_targets = targets.repeat(1, max_pred_len) \
            .view(-1, max_pred_len, max_len)
        # next_indices = indices.clone()
        gold_next_tokens = extended_targets[indices.split(1, dim=1)]
        indices[:, -1] = gold_next_tokens.squeeze(dim=1)
        q_values[indices.split(1, dim=1)] = 1
        q_values = q_values - (1 + min_dists).unsqueeze(-1)
        return q_values
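
A standalone sketch of the Q-value rule applied in the last two lines above (numbers are illustrative): every token is assigned -(1 + min_dist) for its step, and tokens that extend a minimum-edit-distance prefix are raised by one, ending up at -min_dist.

import torch

vocab_size = 4
min_dists = torch.tensor([0., 1.])                     # minimal edit distance per decoding step
gold_next = torch.tensor([2, 1])                       # token that keeps the distance minimal
q_values = torch.zeros(len(min_dists), vocab_size)
q_values[torch.arange(len(min_dists)), gold_next] = 1
q_values = q_values - (1 + min_dists).unsqueeze(-1)
# q_values -> [[-1., -1.,  0., -1.],
#              [-2., -1., -2., -2.]]
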
Example #10
    def forward(
            self,  # pylint: disable=arguments-differ
            inputs: torch.Tensor,
            mask: torch.LongTensor = None) -> torch.FloatTensor:
        """
        Parameters
        ----------
        inputs : ``torch.FloatTensor``, required.
            A tensor of shape (batch_size, timesteps, input_dim)
        mask : ``torch.FloatTensor``, optional (default = None).
            A tensor of shape (batch_size, timesteps).

        Returns
        -------
        A tensor of shape (batch_size, timesteps, output_projection_dim),
        where output_projection_dim = input_dim by default.
        """
        num_heads = self._num_heads

        batch_size, timesteps, _ = inputs.size()
        if mask is None:
            mask = inputs.new_ones(batch_size, timesteps)

        # Shape (batch_size, timesteps, 2 * attention_dim + values_dim)
        combined_projection = self._combined_projection(inputs)

        # split by attention dim - if values_dim > attention_dim, we will get more
        # than 3 elements returned. All of the rest are the values vector, so we
        # just concatenate them back together again below.
        queries, keys, *values = combined_projection.split(
            self._attention_dim, -1)
        queries = queries.contiguous()
        keys = keys.contiguous()
        values = torch.cat(values, -1).contiguous()
        # Shape (num_heads * batch_size, timesteps, values_dim / num_heads)
        values_per_head = values.view(batch_size, timesteps, num_heads,
                                      int(self._values_dim / num_heads))
        values_per_head = values_per_head.transpose(1, 2).contiguous()
        values_per_head = values_per_head.view(
            batch_size * num_heads, timesteps,
            int(self._values_dim / num_heads))

        # Shape (num_heads * batch_size, timesteps, attention_dim / num_heads)
        queries_per_head = queries.view(batch_size, timesteps, num_heads,
                                        int(self._attention_dim / num_heads))
        queries_per_head = queries_per_head.transpose(1, 2).contiguous()
        queries_per_head = queries_per_head.view(
            batch_size * num_heads, timesteps,
            int(self._attention_dim / num_heads))

        # Shape (num_heads * batch_size, timesteps, attention_dim / num_heads)
        keys_per_head = keys.view(batch_size, timesteps, num_heads,
                                  int(self._attention_dim / num_heads))
        keys_per_head = keys_per_head.transpose(1, 2).contiguous()
        keys_per_head = keys_per_head.view(
            batch_size * num_heads, timesteps,
            int(self._attention_dim / num_heads))

        # shape (num_heads * batch_size, timesteps, timesteps)
        scaled_similarities = torch.bmm(
            queries_per_head, keys_per_head.transpose(1, 2)) / self._scale

        # shape (num_heads * batch_size, timesteps, timesteps)
        # Normalise the distributions, using the same mask for all heads.
        attention = last_dim_softmax(
            scaled_similarities,
            mask.repeat(1, num_heads).view(batch_size * num_heads, timesteps))
        attention = self._attention_dropout(attention)
        # Take a weighted sum of the values with respect to the attention
        # distributions for each element in the num_heads * batch_size dimension.
        # shape (num_heads * batch_size, timesteps, values_dim/num_heads)
        outputs = weighted_sum(values_per_head, attention)
        # Reshape back to original shape (batch_size, timesteps, values_dim).
        # Note that we _cannot_ simply split along dim 0 and concatenate here,
        # because the view/transpose/view above interleaved the heads per batch item.
        # shape (batch_size, num_heads, timesteps, values_dim / num_heads)
        outputs = outputs.view(batch_size, num_heads, timesteps,
                               int(self._values_dim / num_heads))
        # shape (batch_size, timesteps, num_heads, values_dim / num_heads)
        outputs = outputs.transpose(1, 2).contiguous()
        # shape (batch_size, timesteps, values_dim)
        outputs = outputs.view(batch_size, timesteps, self._values_dim)

        # Project back to original input size.
        # shape (batch_size, timesteps, input_size)
        outputs = self._output_projection(outputs)
        return outputs
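
A standalone sketch of why the per-head mask has to be expanded with repeat(1, num_heads).view(...) here: the view/transpose/view reshape above lays the (batch, head) pairs out batch-major, i.e. all heads of batch item 0 first, whereas mask.repeat(num_heads, 1) produces a head-major layout that only matches the inputs.repeat(num_heads, 1, 1) style of reshaping (values are illustrative).

import torch

batch_size, timesteps, num_heads = 2, 2, 2
mask = torch.tensor([[1, 1],
                     [1, 0]])                 # batch item 1 has one padded timestep
batch_major = mask.repeat(1, num_heads).view(batch_size * num_heads, timesteps)
# [[1, 1], [1, 1], [1, 0], [1, 0]] -> rows are (b0,h0), (b0,h1), (b1,h0), (b1,h1)
head_major = mask.repeat(num_heads, 1)
# [[1, 1], [1, 0], [1, 1], [1, 0]] -> rows are b0, b1, b0, b1, one block per head
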
Example #11
    def edit_distance_q_values(sampled_y: torch.LongTensor,
                               gold_y: torch.LongTensor,
                               end_symbol_id: int,
                               vocab_size: int) -> torch.FloatTensor:
        """ OCD Edit Distance to compute QValues.
        Args:
            sampled_y (`~torch.LongTensor`): ``(batch_size, sequence_length)``
            gold_y (`~torch.LongTensor`): ``(batch_size, sequence_length)``
            end_symbol_id (int): index of the end symbol in your vocabulary.
            vocab_size (int): the number of possible output tokens.

        Returns:
            `~torch.FloatTensor`: ``(batch_size, sequence_length, vocab_size)``

        Example:
            from paper `https://arxiv.org/abs/1810.01398`

            vocabulary = {'S':0, 'U':1, 'N':2, 'D':3, 'A':4 , 'Y':5,
                          'T':6, 'R':7, 'P':8, '</s>':9, '<pad>': 10}
            vocab_size = 11
            end_symbol_id = 9

            gold Y = {'SUNDAY</s><pad><pad>', 'SUNDAY</s><pad><pad>'}
            gold_y = [[0, 1, 2, 3, 4, 5, 9, 10, 10],
                      [0, 1, 2, 3, 4, 5, 9, 10, 10]]

            sampled Y = {'SATURDAY</s>', 'SATRAPY</s>U'}
            sampled_y = [[0, 4, 6, 1, 7, 3, 4, 5, 9],
                         [0, 4, 6, 7, 4, 8, 5, 9, 1]]


            # expected size: (batch_size=2, sequence_length=9, vocab_size=11)
            expected q_values = [[[ 0., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.],
                                  [-1.,  0., -1., -1., -1., -1., -1., -1., -1., -1., -1.],
                                  [-2., -1., -1., -2., -2., -2., -2., -2., -2., -2., -2.],
                                  [-3., -2., -2., -2., -3., -3., -3., -3., -3., -3., -3.],
                                  [-3., -3., -2., -3., -3., -3., -3., -3., -3., -3., -3.],
                                  [-4., -4., -3., -3., -4., -4., -4., -4., -4., -4., -4.],
                                  [-4., -4., -4., -4., -3., -4., -4., -4., -4., -4., -4.],
                                  [-4., -4., -4., -4., -4., -3., -4., -4., -4., -4., -4.],
                                  [-4., -4., -4., -4., -4., -4., -4., -4., -4., -3., -4.]],

                                 [[ 0., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.],
                                  [-1.,  0., -1., -1., -1., -1., -1., -1., -1., -1., -1.],
                                  [-2., -1., -1., -2., -2., -2., -2., -2., -2., -2., -2.],
                                  [-3., -2., -2., -2., -3., -3., -3., -3., -3., -3., -3.],
                                  [-4., -3., -3., -3., -3., -4., -4., -4., -4., -4., -4.],
                                  [-4., -4., -4., -4., -4., -3., -4., -4., -4., -4., -4.],
                                  [-5., -5., -5., -5., -5., -4., -5., -5., -5., -4., -5.],
                                  [-5., -5., -5., -5., -5., -5., -5., -5., -5., -4., -5.],
                                  [-6., -6., -6., -6., -6., -6., -6., -6., -6., -5., -6.]]]
        """
        assert gold_y.size() == sampled_y.size()
        b_sz, seq_len = gold_y.size()
        q_values = gold_y.new_zeros((b_sz, seq_len + 1, vocab_size), dtype=torch.float)
        edit_dists = gold_y.new_zeros((b_sz, seq_len + 1, seq_len + 1), dtype=torch.float)

        # run batch version of the levenshtein algorithm
        edit_dists[:, :, 0] = torch.arange(seq_len + 1)
        edit_dists[:, 0, :] = torch.arange(seq_len + 1)
        for i in range(1, seq_len + 1):
            for j in range(1, seq_len + 1):
                cost = (sampled_y[:, i-1] != gold_y[:, j-1]).float()
                min_cost, _ = torch.cat(((edit_dists[:, i-1, j] + 1).unsqueeze(dim=1),
                                         (edit_dists[:, i, j-1] + 1).unsqueeze(dim=1),
                                         (edit_dists[:, i-1, j-1] + cost).unsqueeze(dim=1)),
                                        dim=1).min(dim=1)
                edit_dists[:, i, j] = min_cost

        # find gold next tokens and update their QValues
        edit_dists_mask = OCD.edit_distance_mask(gold_y, end_symbol_id)
        edit_dists = edit_dists.masked_fill_(edit_dists_mask, LARGE_NUM)
        min_dist, _ = edit_dists.min(dim=2)
        min_dist = min_dist.unsqueeze(dim=2)
        steps_with_min_dists = (edit_dists == min_dist)
        extended_gold_y = gold_y.repeat(1, seq_len + 1).view(b_sz, seq_len + 1, seq_len)
        indices = steps_with_min_dists.nonzero()
        gold_next_tokens = extended_gold_y[indices.split(1, dim=1)]
        indices[:, 2] = gold_next_tokens.squeeze(dim=1)
        q_values[indices.split(1, dim=1)] = 1
        q_values = q_values - (1 + min_dist)
        return q_values[:, :-1, :]  # ignore the step 'seq_len + 1'
Example #12
    def forward(self,  # pylint: disable=arguments-differ
                inputs: torch.Tensor,
                mask: torch.LongTensor = None) -> torch.FloatTensor:
        """
        Parameters
        ----------
        inputs : ``torch.FloatTensor``, required.
            A tensor of shape (batch_size, timesteps, input_dim)
        mask : ``torch.FloatTensor``, optional (default = None).
            A tensor of shape (batch_size, timesteps).

        Returns
        -------
        A tensor of shape (batch_size, timesteps, output_projection_dim),
        where output_projection_dim = input_dim by default.
        """
        num_heads = self._num_heads

        batch_size, timesteps, _ = inputs.size()
        if mask is None:
            mask = inputs.new_ones(batch_size, timesteps)

        # Shape (batch_size, timesteps, 2 * attention_dim + values_dim)
        combined_projection = self._combined_projection(inputs)
        # split by attention dim - if values_dim > attention_dim, we will get more
        # than 3 elements returned. All of the rest are the values vector, so we
        # just concatenate them back together again below.
        queries, keys, *values = combined_projection.split(self._attention_dim, -1)
        queries = queries.contiguous()
        keys = keys.contiguous()
        values = torch.cat(values, -1).contiguous()
        # Shape (num_heads * batch_size, timesteps, values_dim / num_heads)
        values_per_head = values.view(batch_size, timesteps, num_heads, int(self._values_dim/num_heads))
        values_per_head = values_per_head.transpose(1, 2).contiguous()
        values_per_head = values_per_head.view(batch_size * num_heads, timesteps, int(self._values_dim/num_heads))

        # Shape (num_heads * batch_size, timesteps, attention_dim / num_heads)
        queries_per_head = queries.view(batch_size, timesteps, num_heads, int(self._attention_dim/num_heads))
        queries_per_head = queries_per_head.transpose(1, 2).contiguous()
        queries_per_head = queries_per_head.view(batch_size * num_heads, timesteps, int(self._attention_dim/num_heads))

        # Shape (num_heads * batch_size, timesteps, attention_dim / num_heads)
        keys_per_head = keys.view(batch_size, timesteps, num_heads, int(self._attention_dim/num_heads))
        keys_per_head = keys_per_head.transpose(1, 2).contiguous()
        keys_per_head = keys_per_head.view(batch_size * num_heads, timesteps, int(self._attention_dim/num_heads))

        # shape (num_heads * batch_size, timesteps, timesteps)
        scaled_similarities = torch.bmm(queries_per_head, keys_per_head.transpose(1, 2)) / self._scale

        # shape (num_heads * batch_size, timesteps, timesteps)
        # Normalise the distributions, using the same mask for all heads.
        attention = last_dim_softmax(scaled_similarities, mask.repeat(1, num_heads).view(batch_size * num_heads, timesteps))
        attention = self._attention_dropout(attention)

        # Take a weighted sum of the values with respect to the attention
        # distributions for each element in the num_heads * batch_size dimension.
        # shape (num_heads * batch_size, timesteps, values_dim/num_heads)
        outputs = weighted_sum(values_per_head, attention)

        # Reshape back to original shape (batch_size, timesteps, values_dim)
        # shape (batch_size, num_heads, timesteps, values_dim/num_heads)
        outputs = outputs.view(batch_size, num_heads, timesteps, int(self._values_dim / num_heads))
        # shape (batch_size, timesteps, num_heads, values_dim/num_heads)
        outputs = outputs.transpose(1, 2).contiguous()
        # shape (batch_size, timesteps, values_dim)
        outputs = outputs.view(batch_size, timesteps, self._values_dim)

        # Project back to original input size.
        # shape (batch_size, timesteps, input_size)
        outputs = self._output_projection(outputs)
        return outputs
Example #13
    def forward(self,  # pylint: disable=arguments-differ
                inputs: torch.Tensor,
                semantic_views_q: torch.Tensor,
                semantic_views_sent_mask: torch.Tensor,
                mask: torch.LongTensor = None,
                return_output_metadata: bool = False) -> torch.FloatTensor:
        """
        Parameters
        ----------
        inputs : ``torch.FloatTensor``, required.
            A tensor of shape (batch_size, timesteps, input_dim)
        mask : ``torch.FloatTensor``, optional (default = None).
            A tensor of shape (batch_size, timesteps).

        Returns
        -------
        A tuple of (outputs, output_meta): a tensor of shape
        (batch_size, timesteps, output_projection_dim), where
        output_projection_dim = input_dim by default, and a metadata dict
        (or None unless ``return_output_metadata`` is True).
        """
        num_heads = self._num_heads

        batch_size, timesteps, _ = inputs.size()
        if mask is None:
            mask = inputs.new_ones(batch_size, timesteps)

        if self.use_semantic_views:
            # Shape (batch_size, timesteps, 2 * attention_dim + values_dim)

            values = self._values_projection(inputs)

            # split by attention dim - if values_dim > attention_dim, we will get more
            # than 3 elements returned. All of the rest are the values vector, so we
            # just concatenate them back together again below.

            bs = inputs.shape[0]
            seq_len = inputs.shape[1]
            input_dim = inputs.shape[-1]

            if not self.multi_head_attention_batch_computation:
                raise Exception("multi_head_attention_batch_computation = False is not supported!")
            else:
                # Shape (bs, num_heads, seq_len, d)
                head_dim = self._single_head_attention_dim

                if self._semantic_integration_mode == "projection":
                    inputs_by_head = inputs.unsqueeze(1).repeat(1, self._num_heads, 1, 1)

                    def get_input_per_head_using_sem_views_projection(inputs_by_head, semantic_view, emb_w, emb_b):
                        semantic_veiws_by_head_w = emb_w(semantic_view).view(
                            [bs, num_heads, seq_len, input_dim, head_dim])
                        semantic_veiws_by_head_b = emb_b(semantic_view).view([bs, num_heads, seq_len, head_dim])

                        res = torch.bmm(inputs_by_head.view(bs * num_heads * seq_len, 1, input_dim),
                                        semantic_veiws_by_head_w.view(bs * num_heads * seq_len, input_dim, head_dim)) \
                                  .view(bs, num_heads, seq_len, head_dim) \
                              + semantic_veiws_by_head_b

                        return res

                    queries_per_head = get_input_per_head_using_sem_views_projection(inputs_by_head, semantic_views_q,
                                                                          self._semantic_label_embedding_q_w,
                                                                          self._semantic_label_embedding_q_b)
                    queries_per_head = queries_per_head.view(batch_size * num_heads, timesteps, head_dim)

                    keys_per_head = get_input_per_head_using_sem_views_projection(inputs_by_head, semantic_views_q,
                                                                          self._semantic_label_embedding_k_w,
                                                                          self._semantic_label_embedding_k_b)
                    keys_per_head = keys_per_head.view(batch_size * num_heads, timesteps, head_dim)
                elif self._semantic_integration_mode == "concat":
                    inputs_by_head = inputs.unsqueeze(1).repeat(1, self._num_heads, 1, 1)

                    w_emb_dim = self._semantic_label_embeding_w_emb_dim
                    bias_emb_dim = self._semantic_label_embeding_b_emb_dim

                    def get_input_per_head_using_sem_views_concat(inputs_by_head, semantic_view, emb_w, emb_b, projection_per_head):
                        semantic_veiws_by_head_w = emb_w(semantic_view).view([bs, num_heads, seq_len, w_emb_dim])
                        semantic_veiws_by_head_b = emb_b(semantic_view).view([bs, num_heads, seq_len, bias_emb_dim])

                        inputs_by_head_concat = torch.cat([inputs_by_head, semantic_veiws_by_head_w], dim=-1)

                        res = torch.einsum("bhld,hdk->bhlk", [inputs_by_head_concat, projection_per_head])

                        assert not torch.isnan(res).any()

                        res = res + semantic_veiws_by_head_b

                        return res

                    query_projection_per_head = self._queries_projection
                    queries_per_head = get_input_per_head_using_sem_views_concat(inputs_by_head, semantic_views_q,
                                                                                 self._semantic_label_embedding_q_w,
                                                                                 self._semantic_label_embedding_q_b,
                                                                                 projection_per_head=query_projection_per_head
                                                                                 )
                    queries_per_head = queries_per_head.view(batch_size * num_heads, timesteps, head_dim)

                    key_projection_per_head = self._keys_projection
                    keys_per_head = get_input_per_head_using_sem_views_concat(inputs_by_head, semantic_views_q,
                                                                              self._semantic_label_embedding_k_w,
                                                                              self._semantic_label_embedding_k_b,
                                                                              projection_per_head=key_projection_per_head
                                                                              )
                    keys_per_head = keys_per_head.view(batch_size * num_heads, timesteps, head_dim)
                elif self._semantic_integration_mode == "concat_joint":
                    w_emb_dim = self._semantic_label_embeding_w_emb_dim
                    bias_emb_dim = self._semantic_label_embeding_b_emb_dim

                    def get_input_per_head_using_sem_views_concat_joint(inputs_not_by_head, semantic_view, emb_w, emb_b,
                                                                  projection):
                        semantic_veiws_by_head_w = emb_w(semantic_view).view([bs, num_heads, seq_len, w_emb_dim])
                        semantic_veiws_by_head_w = semantic_veiws_by_head_w.transpose(2, 1).contiguous().view(bs, seq_len, -1)

                        semantic_veiws_by_head_b = emb_b(semantic_view).view([bs, num_heads, seq_len, bias_emb_dim])

                        inputs_not_by_head_concat = torch.cat([inputs_not_by_head, semantic_veiws_by_head_w], dim=-1)

                        res = torch.einsum("bld,dk->blk", [inputs_not_by_head_concat, projection])
                        res = res.view(bs, seq_len, num_heads, bias_emb_dim).transpose(2, 1).contiguous()

                        assert not torch.isnan(res).any()

                        res = res + semantic_veiws_by_head_b

                        return res

                    query_projection_per_head = self._queries_projection
                    queries_per_head = get_input_per_head_using_sem_views_concat_joint(inputs, semantic_views_q,
                                                                                 self._semantic_label_embedding_q_w,
                                                                                 self._semantic_label_embedding_q_b,
                                                                                 projection=query_projection_per_head
                                                                                 )
                    queries_per_head = queries_per_head.view(batch_size * num_heads, timesteps, head_dim)

                    key_projection_per_head = self._keys_projection
                    keys_per_head = get_input_per_head_using_sem_views_concat_joint(inputs, semantic_views_q,
                                                                              self._semantic_label_embedding_k_w,
                                                                              self._semantic_label_embedding_k_b,
                                                                                    projection=key_projection_per_head
                                                                              )
                    keys_per_head = keys_per_head.view(batch_size * num_heads, timesteps, head_dim)
                else:
                    raise ValueError("semantic_integration_mode `{0}` is not yet supported!"
                                     .format(self._semantic_integration_mode))

                # shape (num_heads * batch_size, timesteps, timesteps)
                scaled_similarities = torch.bmm(queries_per_head / self._scale, keys_per_head.transpose(1, 2))

                # mask
                # Shape (bs, num_heads, seq_len, seq_len)
                semantic_views_sent_mask_tokenwise = semantic_views_sent_mask.unsqueeze(2).repeat(1, 1, seq_len, 1)
                # allow only per-scope mask - like sentence-wise, neighbouring sentences, etc.
                semantic_views_sent_mask_tokenwise = (semantic_views_sent_mask_tokenwise == semantic_views_sent_mask_tokenwise.transpose(3, 2)) \
                                                     * (semantic_views_sent_mask_tokenwise > 0) # this multiplication is the masking of padded zeros!
                semantic_views_sent_mask_tokenwise = semantic_views_sent_mask_tokenwise.float()
                semantic_views_sent_mask_tokenwise = semantic_views_sent_mask_tokenwise\
                                                        .view(bs * num_heads, seq_len, seq_len)
                # masked the similarities
                scaled_similarities = scaled_similarities * semantic_views_sent_mask_tokenwise

                # Shape (num_heads * batch_size, timesteps, values_dim / num_heads)
                values_per_head = values.view(batch_size, timesteps, num_heads, int(self._values_dim / num_heads))
                values_per_head = values_per_head.transpose(1, 2).contiguous()
                values_per_head = values_per_head.view(batch_size * num_heads, timesteps, int(self._values_dim / num_heads))

                # shape (num_heads * batch_size, timesteps, timesteps)
                # Normalise the distributions, using the same mask for all heads.
                attention = masked_softmax(scaled_similarities,
                                           semantic_views_sent_mask_tokenwise,
                                           memory_efficient=True)

        else:
            # Shape (batch_size, timesteps, 2 * attention_dim + values_dim)
            combined_projection = self._combined_projection(inputs)
            # split by attention dim - if values_dim > attention_dim, we will get more
            # than 3 elements returned. All of the rest are the values vector, so we
            # just concatenate them back together again below.
            queries, keys, *values = combined_projection.split(self._attention_dim, -1)
            queries = queries.contiguous()
            keys = keys.contiguous()
            values = torch.cat(values, -1).contiguous()

            # Shape (num_heads * batch_size, timesteps, values_dim / num_heads)
            values_per_head = values.view(batch_size, timesteps, num_heads, int(self._values_dim/num_heads))
            values_per_head = values_per_head.transpose(1, 2).contiguous()
            values_per_head = values_per_head.view(batch_size * num_heads, timesteps, int(self._values_dim/num_heads))

            # Shape (num_heads * batch_size, timesteps, attention_dim / num_heads)
            queries_per_head = queries.view(batch_size, timesteps, num_heads, int(self._attention_dim/num_heads))
            queries_per_head = queries_per_head.transpose(1, 2).contiguous()
            queries_per_head = queries_per_head.view(batch_size * num_heads, timesteps, int(self._attention_dim/num_heads))

            # Shape (num_heads * batch_size, timesteps, attention_dim / num_heads)
            keys_per_head = keys.view(batch_size, timesteps, num_heads, int(self._attention_dim/num_heads))
            keys_per_head = keys_per_head.transpose(1, 2).contiguous()
            keys_per_head = keys_per_head.view(batch_size * num_heads, timesteps, int(self._attention_dim/num_heads))

            # shape (num_heads * batch_size, timesteps, timesteps)
            scaled_similarities = torch.bmm(queries_per_head / self._scale, keys_per_head.transpose(1, 2))

            # shape (num_heads * batch_size, timesteps, timesteps)
            # Normalise the distributions, using the same mask for all heads.
            attention = masked_softmax(scaled_similarities,
                                       mask.repeat(1, num_heads).view(batch_size * num_heads, timesteps),
                                       memory_efficient=True)

        attention = self._attention_dropout(attention)

        # Take a weighted sum of the values with respect to the attention
        # distributions for each element in the num_heads * batch_size dimension.
        # shape (num_heads * batch_size, timesteps, values_dim/num_heads)
        outputs = weighted_sum(values_per_head, attention)

        # Reshape back to original shape (batch_size, timesteps, values_dim)
        # shape (batch_size, num_heads, timesteps, values_dim/num_heads)
        outputs = outputs.view(batch_size, num_heads, timesteps, int(self._values_dim / num_heads))
        # shape (batch_size, timesteps, num_heads, values_dim/num_heads)
        outputs = outputs.transpose(1, 2).contiguous()
        # shape (batch_size, timesteps, values_dim)
        outputs = outputs.view(batch_size, timesteps, self._values_dim)

        # Project back to original input size.
        # shape (batch_size, timesteps, input_size)
        outputs = self._output_projection(outputs)

        output_meta = None
        if return_output_metadata:
            output_meta = {"attention": attention,
                           "semantic_views_q": semantic_views_q,
                           "semantic_views_sent_mask": semantic_views_sent_mask,
                           "mask": mask,
                           }

        return outputs, output_meta
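
A standalone sketch of the sentence-scope mask built above (per-token sentence ids are illustrative, and the head dimension is dropped for brevity): expanding the ids and comparing with the transpose yields a block mask that lets tokens attend only within their own sentence, while the > 0 factor zeroes out padding.

import torch

sent_ids = torch.tensor([[1, 1, 2, 0]])                   # (bs=1, seq_len=4); 0 marks padding
m = sent_ids.unsqueeze(1).repeat(1, sent_ids.size(1), 1)  # (bs, seq_len, seq_len)
scope_mask = ((m == m.transpose(2, 1)) * (m > 0)).float()
# scope_mask[0] ->
# [[1., 1., 0., 0.],
#  [1., 1., 0., 0.],
#  [0., 0., 1., 0.],
#  [0., 0., 0., 0.]]
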
Example #14
    def forward(
        self,  # pylint: disable=arguments-differ
        inputs: torch.Tensor,
        mask: torch.LongTensor = None,
    ) -> torch.FloatTensor:
        """
        Parameters
        ----------
        inputs : ``torch.FloatTensor``, required.
            A tensor of shape (batch_size, timesteps, input_dim)
        mask : ``torch.FloatTensor``, optional (default = None).
            A tensor of shape (batch_size, timesteps).

        Returns
        -------
        A tensor of shape (batch_size, timesteps, output_projection_dim),
        where output_projection_dim = input_dim by default.
        """
        num_heads = self._num_heads

        batch_size, timesteps, hidden_dim = inputs.size()
        if mask is None:
            mask = inputs.new_ones(batch_size, timesteps)

        # Treat the queries, keys and values each as a ``num_heads`` size batch.
        # shape (num_heads, batch_size * timesteps, hidden_dim)
        inputs_per_head = inputs.repeat(num_heads, 1, 1).view(
            num_heads, batch_size * timesteps, hidden_dim
        )
        # Do the projections for all the heads at once.
        # Then reshape the result as though it had a
        # (num_heads * batch_size) sized batch.
        queries_per_head = torch.bmm(inputs_per_head, self._query_projections)
        # shape (num_heads * batch_size, timesteps, attention_dim)
        queries_per_head = queries_per_head.view(
            num_heads * batch_size, timesteps, self._attention_dim
        )

        keys_per_head = torch.bmm(inputs_per_head, self._key_projections)
        # shape (num_heads * batch_size, timesteps, attention_dim)
        keys_per_head = keys_per_head.view(num_heads * batch_size, timesteps, self._attention_dim)

        values_per_head = torch.bmm(inputs_per_head, self._value_projections)
        # shape (num_heads * batch_size, timesteps, attention_dim)
        values_per_head = values_per_head.view(num_heads * batch_size, timesteps, self._values_dim)

        # shape (num_heads * batch_size, timesteps, timesteps)
        scaled_similarities = (
            torch.bmm(queries_per_head, keys_per_head.transpose(1, 2)) / self._scale
        )

        # Apply a causal mask so each position only attends to itself and earlier positions.
        causality_mask = subsequent_mask(timesteps).to(inputs.device)
        masked_scaled_similarities = scaled_similarities.masked_fill(causality_mask == 0, -1e9)

        # shape (num_heads * batch_size, timesteps, timesteps)
        # Normalise the distributions, using the same mask for all heads.
        attention = masked_softmax(masked_scaled_similarities, mask.repeat(num_heads, 1))
        attention = self._attention_dropout(attention)
        # This is doing the following batch-wise matrix multiplication:
        # (num_heads * batch_size, timesteps, timesteps) *
        # (num_heads * batch_size, timesteps, values_dim)
        # which is equivalent to a weighted sum of the values with respect to
        # the attention distributions for each element in the num_heads * batch_size
        # dimension.
        # shape (num_heads * batch_size, timesteps, values_dim)
        outputs = torch.bmm(attention, values_per_head)

        # Reshape back to original shape (batch_size, timesteps, num_heads * values_dim)
        # Note that we _cannot_ use a reshape here, because this tensor was created
        # with num_heads being the first dimension, so reshaping naively would not
        # throw an error, but give an incorrect result.
        outputs = torch.cat(torch.split(outputs, batch_size, dim=0), dim=-1)

        # Project back to original input size.
        # shape (batch_size, timesteps, input_size)
        outputs = self._output_projection(outputs)
        return outputs
Example #15
    def forward(
            self,  # pylint: disable=arguments-differ
            inputs: torch.Tensor,
            semantic_views_q: torch.Tensor,
            semantic_views_k: torch.Tensor,
            mask: torch.LongTensor = None) -> torch.FloatTensor:
        """
        Parameters
        ----------
        inputs : ``torch.FloatTensor``, required.
            A tensor of shape (batch_size, timesteps, input_dim)
        mask : ``torch.FloatTensor``, optional (default = None).
            A tensor of shape (batch_size, timesteps).

        Returns
        -------
        A tensor of shape (batch_size, timesteps, output_projection_dim),
        where output_projection_dim = input_dim by default.
        """
        num_heads = self._num_heads

        batch_size, timesteps, _ = inputs.size()
        if mask is None:
            mask = inputs.new_ones(batch_size, timesteps)

        use_token_wise_semantic_labels = True
        if use_token_wise_semantic_labels:
            # Shape (batch_size, timesteps, 2 * attention_dim + values_dim)

            values = self._values_projection(inputs)

            # split by attention dim - if values_dim > attention_dim, we will get more
            # than 3 elements returned. All of the rest are the values vector, so we
            # just concatenate them back together again below.

            bs = inputs.shape[0]
            seq_len = inputs.shape[1]
            input_dim = inputs.shape[-1]

            # Shape (bs, timesteps, timesteps, input_dim)
            input_tokenwise = inputs.unsqueeze(1).repeat(1, seq_len, 1, 1)

            similarities_per_head = []
            for head_slice in zip(torch.split(semantic_views_q, 1, 1),
                                  torch.split(semantic_views_k, 1, 1)):
                semantic_views_q_head, semantic_views_k_head = head_slice
                semantic_views_q_head = semantic_views_q_head.contiguous(
                ).squeeze(1)
                semantic_views_k_head = semantic_views_k_head.contiguous(
                ).squeeze(1)

                q_head = torch.einsum('bjkd,bjkdn->bjkn', (input_tokenwise,
                                                                 self._semantic_label_embedding_q_w(
                                                                     semantic_views_q_head).view(
                                                                     list(semantic_views_q_head.shape)
                                                                     + [input_dim, -1]))) \
                         + self._semantic_label_embedding_q_b(semantic_views_q_head)
                k_head = torch.einsum('bjkd,bjkdn->bjkn', (input_tokenwise,
                                                                 self._semantic_label_embedding_k_w(
                                                                     semantic_views_k_head).view(
                                                                     list(semantic_views_k_head.shape)
                                                                     + [input_dim, -1]))) \
                         + self._semantic_label_embedding_k_b(semantic_views_k_head)

                att_head = torch.mul(q_head / self._scale, k_head.transpose(2, 1)).sum(-1)\
                                    .view(bs, 1, seq_len, seq_len)

                similarities_per_head.append(att_head)

            scaled_similarities = torch.cat(similarities_per_head, dim=1).view(
                bs * num_heads, seq_len, seq_len)

            # Shape (num_heads * batch_size, timesteps, values_dim / num_heads)
            values_per_head = values.view(batch_size, timesteps, num_heads,
                                          int(self._values_dim / num_heads))
            values_per_head = values_per_head.transpose(1, 2).contiguous()
            values_per_head = values_per_head.view(
                batch_size * num_heads, timesteps,
                int(self._values_dim / num_heads))

        else:
            # Shape (batch_size, timesteps, 2 * attention_dim + values_dim)
            combined_projection = self._combined_projection(inputs)
            # split by attention dim - if values_dim > attention_dim, we will get more
            # than 3 elements returned. All of the rest are the values vector, so we
            # just concatenate them back together again below.
            queries, keys, *values = combined_projection.split(
                self._attention_dim, -1)
            queries = queries.contiguous()
            keys = keys.contiguous()
            values = torch.cat(values, -1).contiguous()

            # Shape (num_heads * batch_size, timesteps, values_dim / num_heads)
            values_per_head = values.view(batch_size, timesteps, num_heads,
                                          int(self._values_dim / num_heads))
            values_per_head = values_per_head.transpose(1, 2).contiguous()
            values_per_head = values_per_head.view(
                batch_size * num_heads, timesteps,
                int(self._values_dim / num_heads))

            # Shape (num_heads * batch_size, timesteps, attention_dim / num_heads)
            queries_per_head = queries.view(
                batch_size, timesteps, num_heads,
                int(self._attention_dim / num_heads))
            queries_per_head = queries_per_head.transpose(1, 2).contiguous()
            queries_per_head = queries_per_head.view(
                batch_size * num_heads, timesteps,
                int(self._attention_dim / num_heads))

            # Shape (num_heads * batch_size, timesteps, attention_dim / num_heads)
            keys_per_head = keys.view(batch_size, timesteps, num_heads,
                                      int(self._attention_dim / num_heads))
            keys_per_head = keys_per_head.transpose(1, 2).contiguous()
            keys_per_head = keys_per_head.view(
                batch_size * num_heads, timesteps,
                int(self._attention_dim / num_heads))

            # shape (num_heads * batch_size, timesteps, timesteps)
            scaled_similarities = torch.bmm(queries_per_head / self._scale,
                                            keys_per_head.transpose(1, 2))

        # shape (num_heads * batch_size, timesteps, timesteps)
        # Normalise the distributions, using the same mask for all heads.
        attention = masked_softmax(scaled_similarities,
                                   mask.repeat(1, num_heads).view(
                                       batch_size * num_heads, timesteps),
                                   memory_efficient=True)
        attention = self._attention_dropout(attention)

        # Take a weighted sum of the values with respect to the attention
        # distributions for each element in the num_heads * batch_size dimension.
        # shape (num_heads * batch_size, timesteps, values_dim/num_heads)
        outputs = weighted_sum(values_per_head, attention)

        # Reshape back to original shape (batch_size, timesteps, values_dim)
        # shape (batch_size, num_heads, timesteps, values_dim/num_heads)
        outputs = outputs.view(batch_size, num_heads, timesteps,
                               int(self._values_dim / num_heads))
        # shape (batch_size, timesteps, num_heads, values_dim/num_heads)
        outputs = outputs.transpose(1, 2).contiguous()
        # shape (batch_size, timesteps, values_dim)
        outputs = outputs.view(batch_size, timesteps, self._values_dim)

        # Project back to original input size.
        # shape (batch_size, timesteps, input_size)
        outputs = self._output_projection(outputs)
        return outputs
Example #16
    def forward(self,
                inputs: torch.Tensor,
                mask: torch.LongTensor = None) -> torch.FloatTensor:
        """
        # Parameters

        inputs : ``torch.FloatTensor``, required.
            A tensor of shape (batch_size, timesteps, input_dim)
        mask : ``torch.FloatTensor``, optional (default = None).
            A tensor of shape (batch_size, timesteps).

        # Returns

        A tensor of shape (batch_size, timesteps, output_projection_dim),
        where output_projection_dim = input_dim by default.
        """
        num_heads = self._num_heads

        batch_size, timesteps, _ = inputs.size()
        if mask is None:
            mask = inputs.new_ones(batch_size, timesteps)

        # Shape (batch_size, timesteps, 2 * attention_dim + values_dim)
        combined_projection = self._combined_projection(inputs)
        # split by attention dim - if values_dim > attention_dim, we will get more
        # than 3 elements returned. All of the rest are the values vector, so we
        # just concatenate them back together again below.
        queries, keys, *values = combined_projection.split(
            self._attention_dim, -1)
        queries = queries.contiguous()
        keys = keys.contiguous()
        values = torch.cat(values, -1).contiguous()
        # Shape (num_heads * batch_size, timesteps, values_dim / num_heads)
        values_per_head = values.view(batch_size, timesteps, num_heads,
                                      int(self._values_dim / num_heads))
        values_per_head = values_per_head.transpose(1, 2).contiguous()
        values_per_head = values_per_head.view(
            batch_size * num_heads, timesteps,
            int(self._values_dim / num_heads))

        # Shape (num_heads * batch_size, timesteps, attention_dim / num_heads)
        queries_per_head = queries.view(batch_size, timesteps, num_heads,
                                        int(self._attention_dim / num_heads))
        queries_per_head = queries_per_head.transpose(1, 2).contiguous()
        queries_per_head = queries_per_head.view(
            batch_size * num_heads, timesteps,
            int(self._attention_dim / num_heads))

        # Shape (num_heads * batch_size, timesteps, attention_dim / num_heads)
        keys_per_head = keys.view(batch_size, timesteps, num_heads,
                                  int(self._attention_dim / num_heads))
        keys_per_head = keys_per_head.transpose(1, 2).contiguous()
        keys_per_head = keys_per_head.view(
            batch_size * num_heads, timesteps,
            int(self._attention_dim / num_heads))

        # shape (num_heads * batch_size, timesteps, timesteps)
        scaled_similarities = torch.bmm(queries_per_head / self._scale,
                                        keys_per_head.transpose(1, 2))

        # shape (num_heads * batch_size, timesteps, timesteps)
        # Normalise the distributions, using the same mask for all heads.
        attention = masked_softmax(
            scaled_similarities,
            mask.repeat(1, num_heads).view(batch_size * num_heads, timesteps),
            memory_efficient=True,
        )
        attention = self._attention_dropout(attention)

        # Take a weighted sum of the values with respect to the attention
        # distributions for each element in the num_heads * batch_size dimension.
        # shape (num_heads * batch_size, timesteps, values_dim/num_heads)
        outputs = weighted_sum(values_per_head, attention)

        # Reshape back to original shape (batch_size, timesteps, values_dim)
        # shape (batch_size, num_heads, timesteps, values_dim/num_heads)
        outputs = outputs.view(batch_size, num_heads, timesteps,
                               int(self._values_dim / num_heads))
        # shape (batch_size, timesteps, num_heads, values_dim/num_heads)
        outputs = outputs.transpose(1, 2).contiguous()
        # shape (batch_size, timesteps, values_dim)
        outputs = outputs.view(batch_size, timesteps, self._values_dim)

        # Project back to original input size.
        # shape (batch_size, timesteps, input_size)
        outputs = self._output_projection(outputs)
        return outputs
Example #17
    def forward(
            self,  # pylint: disable=arguments-differ
            inputs: torch.Tensor,
            semantic_views_q: torch.Tensor,
            semantic_views_sent_mask: torch.Tensor,
            mask: torch.LongTensor = None) -> torch.FloatTensor:
        """
        Parameters
        ----------
        inputs : ``torch.FloatTensor``, required.
            A tensor of shape (batch_size, timesteps, input_dim)
        mask : ``torch.FloatTensor``, optional (default = None).
            A tensor of shape (batch_size, timesteps).

        Returns
        -------
        A tensor of shape (batch_size, timesteps, output_projection_dim),
        where output_projection_dim = input_dim by default.
        """
        num_heads = self._num_heads

        batch_size, timesteps, _ = inputs.size()
        if mask is None:
            mask = inputs.new_ones(batch_size, timesteps)

        if self.use_semantic_views:
            # Shape (batch_size, timesteps, 2 * attention_dim + values_dim)

            values = self._values_projection(inputs)

            # split by attention dim - if values_dim > attention_dim, we will get more
            # than 3 elements returned. All of the rest are the values vector, so we
            # just concatenate them back together again below.

            bs = inputs.shape[0]
            seq_len = inputs.shape[1]
            input_dim = inputs.shape[-1]

            similarities_per_head = []
            masks_per_head = []

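            # The per-head similarities can be built in two ways: a Python loop over
            # the heads (below), or a single batched computation over all heads at
            # once (the else branch).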
            if not self.multi_head_attention_batch_computation:
                for head_slice in zip(
                        torch.split(semantic_views_q, 1, 1),
                        torch.split(semantic_views_sent_mask, 1, 1)):
                    semantic_views_q_head, semantic_views_sent_mask_head = head_slice
                    semantic_views_q_head = semantic_views_q_head.contiguous().squeeze(1)
                    semantic_views_sent_mask_head = semantic_views_sent_mask_head.contiguous()

                    # Broadcast the scope ids to a (bs, seq_len, seq_len) token-wise mask:
                    # two tokens may attend to each other only if they share the same
                    # non-zero scope id.
                    semantic_views_sent_mask_head = semantic_views_sent_mask_head.repeat(
                        1, seq_len, 1)
                    semantic_views_sent_mask_head = (
                        (semantic_views_sent_mask_head
                         == semantic_views_sent_mask_head.transpose(2, 1))
                        * (semantic_views_sent_mask_head > 0))
                    semantic_views_sent_mask_head = semantic_views_sent_mask_head.float()

                    # Also zero out padded positions using the ordinary padding mask.
                    mask_repeated = mask.unsqueeze(1).repeat(1, seq_len, 1).float()
                    semantic_views_sent_mask_head = semantic_views_sent_mask_head * mask_repeated

                    masks_per_head.append(semantic_views_sent_mask_head)

                    # q_head
                    q_head_w = self._semantic_label_embedding_q_w(
                        semantic_views_q_head).view(
                            list(semantic_views_q_head.shape) +
                            [input_dim, -1])
                    #q_head = torch.einsum('bid,bidn->bin', [inputs,  q_head_w])
                    q_head = torch.bmm(
                        inputs.view(bs * seq_len, 1, input_dim),
                        q_head_w.view(bs * seq_len, input_dim, -1))
                    q_head = q_head.view(bs, seq_len, -1)
                    q_head = q_head + self._semantic_label_embedding_q_b(
                        semantic_views_q_head)

                    if self.use_separate_label_embeddings_for_q_and_k:
                        # k_head, using the separate key label embeddings
                        k_head_w = self._semantic_label_embedding_k_w(
                            semantic_views_q_head).view(
                                list(semantic_views_q_head.shape) +
                                [input_dim, -1])
                        #k_head = torch.einsum('bid,bidn->bin', [inputs,  k_head_w])
                        k_head = torch.bmm(
                            inputs.view(bs * seq_len, 1, input_dim),
                            k_head_w.view(bs * seq_len, input_dim, -1))
                        k_head = k_head.view(bs, seq_len, -1)
                        k_head = k_head + self._semantic_label_embedding_k_b(
                            semantic_views_q_head)
                    else:
                        k_head = q_head

                    att_head = torch.bmm(q_head / self._scale,
                                         k_head.transpose(1, 2))
                    att_head = att_head * semantic_views_sent_mask_head
                    att_head = att_head.view(bs, 1, seq_len, seq_len)

                    similarities_per_head.append(att_head)

                scaled_similarities = torch.cat(similarities_per_head,
                                                dim=1).view(
                                                    bs * num_heads, seq_len,
                                                    seq_len)

                # Shape (num_heads * batch_size, timesteps, values_dim / num_heads)
                values_per_head = values.view(
                    batch_size, timesteps, num_heads,
                    int(self._values_dim / num_heads))
                values_per_head = values_per_head.transpose(1, 2).contiguous()
                values_per_head = values_per_head.view(
                    batch_size * num_heads, timesteps,
                    int(self._values_dim / num_heads))

                # shape (num_heads * batch_size, timesteps, timesteps)
                # Normalise the distributions, using the same mask for all heads.
                attention = masked_softmax(scaled_similarities,
                                           mask.repeat(1, num_heads).view(
                                               batch_size * num_heads,
                                               timesteps),
                                           memory_efficient=True)
            else:
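                # Batched path: expand the inputs once per head and apply the per-token
                # semantic-label projections for all heads in a single bmm.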
                # Shape (bs, num_heads, seq_len, d)
                inputs_by_head = inputs.unsqueeze(1).repeat(
                    1, self._num_heads, 1, 1)
                head_dim = self._single_head_attention_dim

                def get_input_per_head_using_sem_views(inputs_by_head,
                                                       semantic_view, emb_w,
                                                       emb_b):
                    # Look up a per-token projection matrix and bias for every head from
                    # the semantic view labels, then project the inputs with them.
                    semantic_views_by_head_w = emb_w(semantic_view).view(
                        [bs, num_heads, seq_len, input_dim, head_dim])
                    semantic_views_by_head_b = emb_b(semantic_view).view(
                        [bs, num_heads, seq_len, head_dim])

                    res = torch.bmm(
                        inputs_by_head.view(bs * num_heads * seq_len, 1, input_dim),
                        semantic_views_by_head_w.view(
                            bs * num_heads * seq_len, input_dim, head_dim)
                    ).view(bs, num_heads, seq_len, head_dim) + semantic_views_by_head_b

                    return res

                queries_per_head = get_input_per_head_using_sem_views(
                    inputs_by_head, semantic_views_q,
                    self._semantic_label_embedding_q_w,
                    self._semantic_label_embedding_q_b)
                queries_per_head = queries_per_head.view(
                    batch_size * num_heads, timesteps, head_dim)

                keys_per_head = get_input_per_head_using_sem_views(
                    inputs_by_head, semantic_views_q,
                    self._semantic_label_embedding_k_w,
                    self._semantic_label_embedding_k_b)
                keys_per_head = keys_per_head.view(batch_size * num_heads,
                                                   timesteps, head_dim)

                # shape (num_heads * batch_size, timesteps, timesteps)
                scaled_similarities = torch.bmm(queries_per_head / self._scale,
                                                keys_per_head.transpose(1, 2))

                # Mask
                # Shape (bs, num_heads, seq_len, seq_len)
                semantic_views_sent_mask_tokenwise = semantic_views_sent_mask.unsqueeze(
                    2).repeat(1, 1, seq_len, 1)
                # Allow attention only within the same scope (e.g. the same sentence or
                # neighbouring sentences); the `> 0` factor also masks padded positions.
                semantic_views_sent_mask_tokenwise = (
                    (semantic_views_sent_mask_tokenwise
                     == semantic_views_sent_mask_tokenwise.transpose(3, 2))
                    * (semantic_views_sent_mask_tokenwise > 0))
                semantic_views_sent_mask_tokenwise = semantic_views_sent_mask_tokenwise.float()
                semantic_views_sent_mask_tokenwise = semantic_views_sent_mask_tokenwise.view(
                    bs * num_heads, seq_len, seq_len)
                # Mask the similarities token-wise.
                scaled_similarities = scaled_similarities * semantic_views_sent_mask_tokenwise

                # Shape (num_heads * batch_size, timesteps, values_dim / num_heads)
                values_per_head = values.view(
                    batch_size, timesteps, num_heads,
                    int(self._values_dim / num_heads))
                values_per_head = values_per_head.transpose(1, 2).contiguous()
                values_per_head = values_per_head.view(
                    batch_size * num_heads, timesteps,
                    int(self._values_dim / num_heads))

                # shape (num_heads * batch_size, timesteps, timesteps)
                # Normalise the distributions, using the same mask for all heads.
                attention = masked_softmax(scaled_similarities,
                                           semantic_views_sent_mask_tokenwise,
                                           memory_efficient=True)

        else:
            # Shape (batch_size, timesteps, 2 * attention_dim + values_dim)
            combined_projection = self._combined_projection(inputs)
            # split by attention dim - if values_dim > attention_dim, we will get more
            # than 3 elements returned. All of the rest are the values vector, so we
            # just concatenate them back together again below.
            queries, keys, *values = combined_projection.split(
                self._attention_dim, -1)
            queries = queries.contiguous()
            keys = keys.contiguous()
            values = torch.cat(values, -1).contiguous()

            # Shape (num_heads * batch_size, timesteps, values_dim / num_heads)
            values_per_head = values.view(batch_size, timesteps, num_heads,
                                          int(self._values_dim / num_heads))
            values_per_head = values_per_head.transpose(1, 2).contiguous()
            values_per_head = values_per_head.view(
                batch_size * num_heads, timesteps,
                int(self._values_dim / num_heads))

            # Shape (num_heads * batch_size, timesteps, attention_dim / num_heads)
            queries_per_head = queries.view(
                batch_size, timesteps, num_heads,
                int(self._attention_dim / num_heads))
            queries_per_head = queries_per_head.transpose(1, 2).contiguous()
            queries_per_head = queries_per_head.view(
                batch_size * num_heads, timesteps,
                int(self._attention_dim / num_heads))

            # Shape (num_heads * batch_size, timesteps, attention_dim / num_heads)
            keys_per_head = keys.view(batch_size, timesteps, num_heads,
                                      int(self._attention_dim / num_heads))
            keys_per_head = keys_per_head.transpose(1, 2).contiguous()
            keys_per_head = keys_per_head.view(
                batch_size * num_heads, timesteps,
                int(self._attention_dim / num_heads))

            # shape (num_heads * batch_size, timesteps, timesteps)
            scaled_similarities = torch.bmm(queries_per_head / self._scale,
                                            keys_per_head.transpose(1, 2))

            # shape (num_heads * batch_size, timesteps, timesteps)
            # Normalise the distributions, using the same mask for all heads.
            attention = masked_softmax(scaled_similarities,
                                       mask.repeat(1, num_heads).view(
                                           batch_size * num_heads, timesteps),
                                       memory_efficient=True)

        attention = self._attention_dropout(attention)

        # Take a weighted sum of the values with respect to the attention
        # distributions for each element in the num_heads * batch_size dimension.
        # shape (num_heads * batch_size, timesteps, values_dim/num_heads)
        outputs = weighted_sum(values_per_head, attention)

        # Reshape back to original shape (batch_size, timesteps, values_dim)
        # shape (batch_size, num_heads, timesteps, values_dim/num_heads)
        outputs = outputs.view(batch_size, num_heads, timesteps,
                               int(self._values_dim / num_heads))
        # shape (batch_size, timesteps, num_heads, values_dim/num_heads)
        outputs = outputs.transpose(1, 2).contiguous()
        # shape (batch_size, timesteps, values_dim)
        outputs = outputs.view(batch_size, timesteps, self._values_dim)

        # Project back to original input size.
        # shape (batch_size, timesteps, input_size)
        outputs = self._output_projection(outputs)
        return outputs
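For the semantic-view branch, the key step is turning per-head scope ids into a token-wise attention mask. The minimal sketch below uses dummy tensors whose shapes are inferred from the usage above (the ids themselves are made up) and only demonstrates the equality/transpose trick applied at the masking step; it is an illustration, not the module's code.

import torch

bs, num_heads, seq_len = 2, 2, 6

# Per-head scope ids (e.g. sentence ids) of shape (bs, num_heads, seq_len); 0 = padding.
semantic_views_sent_mask = torch.tensor([[[1, 1, 2, 2, 0, 0],
                                          [1, 1, 1, 2, 0, 0]],
                                         [[1, 2, 2, 2, 2, 0],
                                          [1, 1, 2, 2, 2, 0]]])

# Broadcast to (bs, num_heads, seq_len, seq_len): entry (i, j) holds the id of token j.
tokenwise = semantic_views_sent_mask.unsqueeze(2).repeat(1, 1, seq_len, 1)

# Token i may attend to token j only if both carry the same non-zero scope id;
# the > 0 factor also removes padded positions.
tokenwise = (tokenwise == tokenwise.transpose(3, 2)) * (tokenwise > 0)
tokenwise = tokenwise.float().view(bs * num_heads, seq_len, seq_len)

# Block-diagonal pattern: tokens within the same scope attend to each other.
print(tokenwise[0])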