def test_split_runtime_error(test_case):
    """A negative ``split_size_or_sections`` must make ``flow.split`` raise.

    The raised exception's message is checked for the non-negativity
    diagnostic.
    """
    with test_case.assertRaises(Exception) as context:
        x = flow.ones((1, 2, 3), dtype=flow.float32, requires_grad=True)
        # Only the raise matters here; the split result is never used.
        flow.split(x, split_size_or_sections=-1)
    # assertIn (unlike assertTrue(a in b)) shows both strings on failure.
    test_case.assertIn(
        "split expects split_size be non-negative, but got split_size",
        str(context.exception),
    )
def test_splitwithsize_runtime_error(test_case):
    """Section sizes that do not add up to the split dimension must raise.

    The raised exception's message is checked for the ``split_with_sizes``
    diagnostic.
    """
    with test_case.assertRaises(Exception) as context:
        x = flow.ones((5, 2), dtype=flow.float32, requires_grad=True)
        # [1, 3] sums to 4, which matches neither dimension of the (5, 2) input.
        flow.split(x, [1, 3])
    # assertIn (unlike assertTrue(a in b)) shows both strings on failure.
    test_case.assertIn(
        "split_with_sizes expects split_sizes to sum exactly to ",
        str(context.exception),
    )
def forward(self, x, mask):
    """Compute multi-head scaled dot-product self-attention over ``x``.

    ``x`` is projected once by ``self.qvk_proj``. When ``self.share_qvk_proj``
    is set, that projection serves as query, key and value alike. Otherwise it
    is split into ``self.d_model``-wide chunks along the last dimension.
    Scores are ``query @ key^T`` scaled by ``1 / sqrt(self.d_k)``.

    :param x: input tensor; documented upstream as (batch, time1, size)
    :param mask: attention mask, documented upstream as
        (batch, time1 or 1, time2), or ``None`` for no mask
    :return: ``(context, attn_weights)`` as produced by ``self.compute_context``
    """
    x = self.qvk_proj(x)
    if self.share_qvk_proj:
        query = key = value = x
    else:
        query, key, value = flow.split(x, self.d_model, dim=-1)

    batch_size = x.size(0)
    # Reshape to (batch, -1, nheads, d_k), then swap axes 1 and 2 so the
    # head axis comes before the sequence axis.
    query, key, value = (
        t.reshape(batch_size, -1, self.nheads, self.d_k).transpose(1, 2)
        for t in (query, key, value)
    )

    scores = flow.matmul(query, key.transpose(2, 3)) / math.sqrt(self.d_k)
    # unsqueeze(1) adds an axis at the head position of ``scores``,
    # presumably so the mask broadcasts over heads.
    context, attn_weights = self.compute_context(
        value, scores, mask.unsqueeze(1) if mask is not None else None
    )
    return context, attn_weights
def inference(self, query, memory, memory_mask, cache=None):
    """Compute multi-head scaled dot-product attention of ``query`` over ``memory``.

    ``query`` is projected by ``self.q_proj`` and ``memory`` by
    ``self.vk_proj``. When ``self.share_vk_proj`` is set, the memory
    projection serves as both key and value. Otherwise it is split into
    ``self.d_model``-wide chunks along the last dimension.

    :param query: query tensor; documented upstream as (batch, time1, size)
    :param memory: memory tensor; documented upstream as (batch, time2, size)
    :param memory_mask: mask over memory positions, documented upstream as
        (batch, time1 or 1, time2), or ``None`` for no mask
    :param cache: returned unchanged
    :return: ``(context, attn_weights, cache)``
    """
    query = self.q_proj(query)
    memory = self.vk_proj(memory)
    if self.share_vk_proj:
        key = value = memory
    else:
        key, value = flow.split(memory, self.d_model, dim=-1)

    batch_size = query.size(0)
    # Reshape to (batch, -1, nheads, d_k), then swap axes 1 and 2 so the
    # head axis comes before the sequence axis.
    query, key, value = (
        t.reshape(batch_size, -1, self.nheads, self.d_k).transpose(1, 2)
        for t in (query, key, value)
    )

    scores = flow.matmul(query, key.transpose(2, 3)) / math.sqrt(self.d_k)
    # Previously ``memory_mask=None`` crashed on ``None.unsqueeze``; now None
    # is forwarded as "no mask".
    # NOTE(review): assumes compute_context accepts mask=None, as the
    # self-attention callers in this file already rely on.
    head_mask = memory_mask.unsqueeze(1) if memory_mask is not None else None
    context, attn_weights = self.compute_context(value, scores, head_mask)
    return context, attn_weights, cache
def expert_to_gates(self):
    """Return the nonzero gate values grouped per expert.

    ``self._nonzero_gates`` is cut along dim 0 into consecutive pieces whose
    lengths come from ``self._part_sizes``. Piece ``i`` holds the gate values
    for expert ``i``.

    Returns:
        the pieces produced by ``flow.split``, one per entry of
        ``self._part_sizes``.
    """
    gates, sizes = self._nonzero_gates, self._part_sizes
    return flow.split(gates, sizes, dim=0)
def dispatch(self, inp):
    """Build one input tensor per expert.

    Rows of ``inp`` are gathered in the order given by ``self._batch_index``.
    The gathered tensor is then cut along dim 0 into pieces of lengths
    ``self._part_sizes``, one piece per expert.

    Args:
        inp: a tensor of shape ``[batch_size, <extra_input_dims>]``.

    Returns:
        the pieces produced by ``flow.split``. Piece ``i`` has shape
        ``[expert_batch_size_i, <extra_input_dims>]`` and holds the rows
        routed to expert ``i``.
    """
    # NOTE(review): squeeze(1) also drops dim 1 whenever it has size 1 (e.g. a
    # single feature column). Confirm this is intended for how _batch_index
    # is built.
    return flow.split(inp[self._batch_index].squeeze(1), self._part_sizes, dim=0)
def inference(self, x, mask, cache=None):
    """Multi-head scaled dot-product self-attention, passing ``cache`` through.

    ``x`` is projected by ``self.qvk_proj``. The projection either serves as
    query/key/value directly (``self.share_qvk_proj``) or is split into
    ``self.d_model``-wide chunks along the last dimension.

    :param x: input tensor
    :param mask: attention mask, or ``None`` for no mask
    :param cache: returned unchanged
    :return: ``(context, attn_weights, cache)``
    """
    projected = self.qvk_proj(x)
    if not self.share_qvk_proj:
        query, key, value = flow.split(projected, self.d_model, dim=-1)
    else:
        query = key = value = projected

    batch = projected.size(0)

    def _to_heads(t):
        # (batch, -1, nheads, d_k), then swap axes 1 and 2.
        return t.reshape(batch, -1, self.nheads, self.d_k).transpose(1, 2)

    query = _to_heads(query)
    key = _to_heads(key)
    value = _to_heads(value)

    scores = flow.matmul(query, key.transpose(2, 3)) / math.sqrt(self.d_k)
    head_mask = None if mask is None else mask.unsqueeze(1)
    context, attn_weights = self.compute_context(value, scores, head_mask)
    return context, attn_weights, cache
def forward(self, x, mask, pos):
    """Multi-head self-attention with relative positional scores.

    Args:
        x: [B, T, V]
        mask: [B, 1, T]
        pos: positional embedding [B, S=2T-1, V]

    Returns:
        ``(context, attn_weights)`` from ``self.compute_context``.
    """
    projected = self.qvk_proj(x)
    if self.share_qvk_proj:
        query = key = value = projected
    else:
        query, key, value = flow.split(projected, self.d_model, dim=-1)

    batch = projected.size(0)
    # query stays in (batch, -1, nheads, d_k) layout; posu / posv are added
    # in this layout, before the head/sequence transpose.
    query = query.reshape(batch, -1, self.nheads, self.d_k)
    key, value = (
        t.reshape(batch, -1, self.nheads, self.d_k).transpose(1, 2)
        for t in (key, value)
    )

    n_pos = pos.size(0)
    pos = self.pos_proj(pos).reshape(n_pos, -1, self.nheads, self.d_k).transpose(1, 2)

    # Terms (a)+(c): query biased by posu, matched against the keys.
    biased_u = (query + self.posu).transpose(1, 2)
    matrix_ac = flow.matmul(biased_u, key.transpose(-2, -1))

    # Terms (b)+(d): with skip_term_b only posv is passed, without the query.
    bd_input = self.posv if self.skip_term_b else query + self.posv
    matrix_bd = self._RelPosBias(bd_input, pos)

    scores = (matrix_ac + matrix_bd) / math.sqrt(self.d_k)
    context, attn_weights = self.compute_context(
        value, scores, None if mask is None else mask.unsqueeze(1)
    )
    return context, attn_weights
def test_split_index_error(test_case):
    """Splitting along a dimension the tensor does not have must raise.

    The raised exception's message is checked for the out-of-range
    diagnostic.
    """
    with test_case.assertRaises(Exception) as context:
        x = flow.ones((1, 2, 3), dtype=flow.float32, requires_grad=True)
        # x is 3-D, so dim=4 is out of range; the result is never used.
        flow.split(x, split_size_or_sections=0, dim=4)
    # assertIn (unlike assertTrue(a in b)) shows both strings on failure.
    test_case.assertIn("Dimension out of range", str(context.exception))