def test_split_runtime_error(test_case):
    with test_case.assertRaises(Exception) as context:
        x = flow.ones((1, 2, 3), dtype=flow.float32, requires_grad=True)
        y = flow.split(x, split_size_or_sections=-1)
    test_case.assertTrue(
        "split expects split_size be non-negative, but got split_size"
        in str(context.exception)
    )

def test_splitwithsize_runtime_error(test_case):
    with test_case.assertRaises(Exception) as context:
        x = flow.ones((5, 2), dtype=flow.float32, requires_grad=True)
        y = flow.split(x, [1, 3])
    test_case.assertTrue(
        "split_with_sizes expects split_sizes to sum exactly to "
        in str(context.exception)
    )
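For contrast with the failure cases above, a minimal sketch of valid calls, assuming the torch-style flow.split(tensor, split_size_or_sections, dim=0) signature these tests exercise:

import oneflow as flow

x = flow.ones((5, 2), dtype=flow.float32)

# An int split size chunks the dimension into pieces of that size; the last piece may be smaller.
chunks = flow.split(x, 2, dim=0)
print([c.shape for c in chunks])  # shapes (2, 2), (2, 2), (1, 2)

# A list of sizes must sum exactly to x.shape[dim] (here 5).
parts = flow.split(x, [1, 4], dim=0)
print([p.shape for p in parts])  # shapes (1, 2), (4, 2)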
Example #3
    def forward(self, x, mask):
        """Compute 'Scaled Dot Product Attention'

        :param torch.Tensor query: (batch, time1, size)
        :param torch.Tensor mask: (batch, time1 or 1, time2)
        :return torch.Tensor: attentined and transformed `value` (batch, time1, d_model)
        """

        x = self.qvk_proj(x)

        if self.share_qvk_proj:
            query = key = value = x
        else:
            query, key, value = flow.split(x, self.d_model, dim=-1)

        batch_size = x.size(0)
        query = query.reshape(batch_size, -1, self.nheads, self.d_k).transpose(1, 2)
        key = key.reshape(batch_size, -1, self.nheads, self.d_k).transpose(1, 2)
        value = value.reshape(batch_size, -1, self.nheads, self.d_k).transpose(1, 2)

        scores = flow.matmul(query, key.transpose(2, 3)) / math.sqrt(self.d_k)

        context, attn_weights = self.compute_context(
            value, scores, mask.unsqueeze(1) if mask is not None else None
        )

        return context, attn_weights
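A standalone sketch of the split-and-reshape step in this forward pass, with hypothetical sizes (batch=4, time=10, d_model=8, nheads=2): when qvk_proj produces 3 * d_model features, flow.split with split size d_model on the last dim yields equal query/key/value chunks.

import oneflow as flow

batch, time, d_model, nheads = 4, 10, 8, 2
d_k = d_model // nheads

x = flow.ones((batch, time, 3 * d_model))            # stand-in for self.qvk_proj(x)
query, key, value = flow.split(x, d_model, dim=-1)   # three (4, 10, 8) tensors

# Reshape to (batch, nheads, time, d_k) before the scaled dot product.
query = query.reshape(batch, -1, nheads, d_k).transpose(1, 2)
print(query.shape)  # (4, 2, 10, 4)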
Example #4
    def inference(self, query, memory, memory_mask, cache=None):
        """Compute 'Scaled Dot Product Attention'

        :param torch.Tensor query: (batch, time1, size)
        :param torch.Tensor memory: (batch, time2, size)
        :param torch.Tensor mask: (batch, time1 or 1, time2)
        :return torch.Tensor: attentined and transformed `value` (batch, time1, d_model)
        """

        query = self.q_proj(query)
        memory = self.vk_proj(memory)

        if self.share_vk_proj:
            key = value = memory
        else:
            key, value = flow.split(memory, self.d_model, dim=-1)

        batch_size = query.size(0)
        query = query.reshape(batch_size, -1, self.nheads, self.d_k).transpose(1, 2)
        key = key.reshape(batch_size, -1, self.nheads, self.d_k).transpose(1, 2)
        value = value.reshape(batch_size, -1, self.nheads, self.d_k).transpose(1, 2)

        scores = flow.matmul(query, key.transpose(2, 3)) / math.sqrt(self.d_k)

        context, attn_weights = self.compute_context(
            value, scores, memory_mask.unsqueeze(1)
        )
        return context, attn_weights, cache
Example #5
    def expert_to_gates(self):
        """Gate values corresponding to the examples in the per-expert `Tensor`s.

        Returns:
          a list of `num_experts` one-dimensional `Tensor`s with type `flow.float32`
            and shapes `[expert_batch_size_i]`
        """
        # split the nonzero gate values into one slice per expert
        return flow.split(self._nonzero_gates, self._part_sizes, dim=0)
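A usage sketch with hypothetical sizes: _part_sizes holds the number of assigned examples per expert, so splitting the flat, expert-sorted gate tensor along dim 0 returns one slice per expert.

import oneflow as flow

part_sizes = [3, 1, 2]                          # examples routed to each expert
nonzero_gates = flow.ones((sum(part_sizes),))   # flat, expert-sorted gate values

per_expert = flow.split(nonzero_gates, part_sizes, dim=0)
print([g.shape for g in per_expert])  # shapes (3,), (1,), (2,)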
Example #6
    def dispatch(self, inp):
        """Create one input Tensor for each expert.
        The `Tensor` for a expert `i` contains the slices of `inp` corresponding
        to the batch elements `b` where `gates[b, i] > 0`.
        Args:
          inp: a `Tensor` of shape "[batch_size, <extra_input_dims>]`
        Returns:
          a list of `num_experts` `Tensor`s with shapes
            `[expert_batch_size_i, <extra_input_dims>]`.
        """

        # assigns samples to experts whose gate is nonzero
        # expand according to batch index so we can just split by _part_sizes
        inp_exp = inp[self._batch_index].squeeze(1)
        return flow.split(inp_exp, self._part_sizes, dim=0)
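A shape-only sketch of the dispatch pattern, with hypothetical routing state: rows of inp are gathered (and possibly duplicated) in expert order via _batch_index, then flow.split carves the result into per-expert batches according to _part_sizes.

import oneflow as flow

inp = flow.arange(12, dtype=flow.float32).reshape(3, 4)  # (batch_size=3, <extra_input_dims>=4)
batch_index = flow.tensor([0, 2, 2, 1])                  # expert-sorted row assignments
part_sizes = [2, 1, 1]                                   # rows per expert

inp_exp = inp[batch_index]                               # (4, 4): rows repeated per assignment
expert_inputs = flow.split(inp_exp, part_sizes, dim=0)
print([t.shape for t in expert_inputs])  # shapes (2, 4), (1, 4), (1, 4)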
Example #7
    def inference(self, x, mask, cache=None):
        """Self-attention at inference time; `cache` is returned unchanged."""

        x = self.qvk_proj(x)

        if self.share_qvk_proj:
            query = key = value = x
        else:
            query, key, value = flow.split(x, self.d_model, dim=-1)

        batch_size = x.size(0)
        query = query.reshape(batch_size, -1, self.nheads, self.d_k).transpose(1, 2)
        key = key.reshape(batch_size, -1, self.nheads, self.d_k).transpose(1, 2)
        value = value.reshape(batch_size, -1, self.nheads, self.d_k).transpose(1, 2)

        scores = flow.matmul(query, key.transpose(2, 3)) / math.sqrt(self.d_k)

        context, attn_weights = self.compute_context(
            value, scores, mask.unsqueeze(1) if mask is not None else None
        )

        return context, attn_weights, cache
Example #8
    def forward(self, x, mask, pos):
        """
        Args:
            x: [B, T, V]
            mask: [B, 1, T]
            pos: positional embedding [B, S=2T-1, V] 
        """

        x = self.qvk_proj(x)

        if self.share_qvk_proj:
            query = key = value = x
        else:
            query, key, value = flow.split(x, self.d_model, dim=-1)

        batch_size = x.size(0)
        query = query.reshape(batch_size, -1, self.nheads, self.d_k)
        key = key.reshape(batch_size, -1, self.nheads, self.d_k).transpose(1, 2)
        value = value.reshape(batch_size, -1, self.nheads, self.d_k).transpose(1, 2)

        bpos = pos.size(0)
        pos = (
            self.pos_proj(pos).reshape(bpos, -1, self.nheads, self.d_k).transpose(1, 2)
        )

        query_with_bias_u = query + self.posu
        query_with_bias_u = query_with_bias_u.transpose(1, 2)
        matrix_ac = flow.matmul(query_with_bias_u, key.transpose(-2, -1))

        matrix_bd = self._RelPosBias(
            query + self.posv if not self.skip_term_b else self.posv, pos
        )

        scores = (matrix_ac + matrix_bd) / math.sqrt(self.d_k)
        context, attn_weights = self.compute_context(
            value, scores, mask.unsqueeze(1) if mask is not None else None
        )

        return context, attn_weights
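For orientation, a shape-only sketch with hypothetical sizes of how the two score terms combine: matrix_ac compares the bias-u query against the keys, matrix_bd (produced by _RelPosBias above, taken as given here) scores the query against the relative positions, and both have shape (batch, nheads, time1, time2) before scaling.

import math
import oneflow as flow

batch, nheads, time, d_k = 2, 4, 10, 16

query_with_bias_u = flow.ones((batch, nheads, time, d_k))  # query + posu, heads-first layout
key = flow.ones((batch, nheads, time, d_k))
matrix_ac = flow.matmul(query_with_bias_u, key.transpose(-2, -1))  # (2, 4, 10, 10)

matrix_bd = flow.ones((batch, nheads, time, time))         # stand-in for self._RelPosBias(...)

scores = (matrix_ac + matrix_bd) / math.sqrt(d_k)
print(scores.shape)  # (2, 4, 10, 10)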
def test_split_index_error(test_case):
    with test_case.assertRaises(Exception) as context:
        x = flow.ones((1, 2, 3), dtype=flow.float32, requires_grad=True)
        y = flow.split(x, split_size_or_sections=0, dim=4)
    test_case.assertTrue("Dimension out of range" in str(context.exception))