Exemplo n.º 1
0
 def forward(self, x):
     """Separate an input mixture into ``self.num_spks`` decoded outputs.

     Args:
         x: 1-D tensor (a single utterance, e.g. at inference) or 2-D tensor
             (a batch). Presumably (S,) / (n, S) samples -- TODO confirm
             against ``self.encoder_1d``.

     Returns:
         list with ``self.num_spks`` entries, each the result of
         ``self.decoder_1d(..., squeeze=True)`` for one masked encoding.

     Raises:
         RuntimeError: if ``x`` is not a 1-D or 2-D tensor.
     """
     # Reject 0-D as well as >=3-D input so the message below is accurate.
     if x.dim() not in (1, 2):
         raise RuntimeError(
             "{} accept 1/2D tensor as input, but got {:d}".format(
                 # Module instances have no ``__name__``; use the class name.
                 type(self).__name__, x.dim()
             )
         )
     # when inference, only one utt: add a leading batch axis
     if x.dim() == 1:
         x = flow.unsqueeze(x, 0)
     # n x 1 x S => n x N x T
     w = F.relu(self.encoder_1d(x))
     # n x B x T
     y = self.proj(self.ln(w))
     # n x B x T
     y = self.repeats(y)
     # n x 2N x T, split into num_spks chunks along dim 1
     e = flow.chunk(self.mask(y), self.num_spks, 1)
     # per-speaker chunks stacked on a new leading axis
     stacked = flow.stack(e, dim=0)
     if self.non_linear_type == "softmax":
         # softmax normalizes across the speaker axis
         m = self.non_linear(stacked, dim=0)
     else:
         m = self.non_linear(stacked)
     # spks x [n x N x T]
     masked = [w * m[spk] for spk in range(self.num_spks)]
     # spks x n x S
     return [self.decoder_1d(src, squeeze=True) for src in masked]
Exemplo n.º 2
0
    def forward(self, hidden_states, layer_past=None, use_cache=False):
        """Run the attention block over ``hidden_states``.

        ``self.c_attn`` projects the input. The projection is chunked into
        three parts along dim 2 (query, key, value), and each part is split
        into ``self.num_heads`` heads of size ``self.head_dim``.

        Args:
            hidden_states: input tensor; presumably (batch, seq, embed) given
                the chunk on dim 2 -- TODO confirm.
            layer_past: optional ``(past_key, past_value)`` pair whose tensors
                are concatenated in front of the new key/value along dim -2.
            use_cache: when it is exactly ``True`` the (possibly extended)
                key/value pair is returned as ``present``; otherwise
                ``present`` is None.

        Returns:
            tuple ``(attn_output, present, attn_weights)``.
        """
        mixed = self.c_attn(hidden_states)
        q_raw, k_raw, v_raw = flow.chunk(mixed, chunks=3, dim=2)
        q, k, v = (
            self._split_heads(part, self.num_heads, self.head_dim)
            for part in (q_raw, k_raw, v_raw)
        )

        # Prepend cached keys/values from earlier steps.
        if layer_past is not None:
            cached_k, cached_v = layer_past
            k = flow.cat((cached_k, k), dim=-2)
            v = flow.cat((cached_v, v), dim=-2)

        present = (k, v) if use_cache is True else None

        context, weights = self._attn(q, k, v)
        merged = self._merge_heads(context, self.num_heads, self.head_dim)
        context = self.resid_dropout(self.c_proj(merged))
        return context, present, weights
Exemplo n.º 3
0
 def test_chunk_0_dim_input_exception(test_case):
     """Chunking a 0-dim tensor must raise RuntimeError.

     Mirrors PyTorch, whose exception and message are:

         RuntimeError: chunk expects at least a 1-dimensional tensor.
     """
     x = flow.tensor(3.14)
     with test_case.assertRaises(RuntimeError) as ctx:
         flow.chunk(x, chunks=1, dim=0)
     # assertIn reports both strings on failure, unlike assertTrue(a in b).
     test_case.assertIn(
         "chunk expects at least a 1-dimensional tensor", str(ctx.exception)
     )
Exemplo n.º 4
0
 def test_chunk_dim_param_exception(test_case):
     """An out-of-range ``dim`` must raise IndexError.

     Mirrors PyTorch, whose exception and message are:

         IndexError: Dimension out of range (expected to be in range of
         [-2, 1], but got -3)
     """
     x = flow.tensor([[1, 2, 3], [4, 5, 6]])
     with test_case.assertRaises(IndexError) as ctx:
         flow.chunk(x, chunks=2, dim=-3)
     test_case.assertIn(
         "Dimension out of range (expected to be in range of [-2, 1], but got -3)",
         str(ctx.exception),
     )
Exemplo n.º 5
0
 def test_chunk_0_chunks_param_exception(test_case):
     """``chunks=0`` must raise RuntimeError.

     Mirrors PyTorch, whose exception and message are:

         RuntimeError: chunk expects `chunks` to be greater than 0, got: 0
     """
     x = flow.tensor([[1, 2, 3], [4, 5, 6]])
     with test_case.assertRaises(RuntimeError) as ctx:
         flow.chunk(x, chunks=0, dim=0)
     test_case.assertIn(
         "chunk expects `chunks` to be greater than 0, got: ", str(ctx.exception)
     )
Exemplo n.º 6
0
 def test_chunk_value_runtime_error(test_case):
     """A negative ``chunks`` must raise RuntimeError about ``chunks``.

     The expected type matches test_chunk_0_chunks_param_exception, which
     checks the same message.
     """
     # Build the input outside the context so only flow.chunk can satisfy
     # the assertion.
     x = flow.ones((1, 2, 3), dtype=flow.float32, requires_grad=True)
     with test_case.assertRaises(RuntimeError) as context:
         flow.chunk(x, chunks=-1, dim=4)
     test_case.assertIn(
         "chunk expects `chunks` to be greater than 0, got", str(context.exception)
     )
Exemplo n.º 7
0
 def test_chunk_tensor_dim_runtime_error(test_case):
     """Chunking a 0-dim tensor must raise RuntimeError.

     The expected type matches test_chunk_0_dim_input_exception, which
     checks the same message.
     """
     # Build the input outside the context so only flow.chunk can satisfy
     # the assertion.
     x = flow.tensor(1, dtype=flow.float32, requires_grad=True)
     with test_case.assertRaises(RuntimeError) as context:
         flow.chunk(x, chunks=2, dim=4)
     test_case.assertIn(
         "chunk expects at least a 1-dimensional tensor", str(context.exception)
     )
Exemplo n.º 8
0
 def test_chunk_index_error(test_case):
     """A ``dim`` beyond the tensor's rank must raise IndexError.

     The expected type matches test_chunk_dim_param_exception, which checks
     the same "Dimension out of range" message.
     """
     # Build the input outside the context so only flow.chunk can satisfy
     # the assertion.
     x = flow.ones((1, 2, 3), dtype=flow.float32, requires_grad=True)
     with test_case.assertRaises(IndexError) as context:
         flow.chunk(x, chunks=2, dim=4)
     test_case.assertIn("Dimension out of range", str(context.exception))