Example #1
def linear(input, weight, bias=None):
    if input.dim() == 2 and bias is not None:
        # fused op is marginally faster
        ret = flow.addmm(bias, input, weight.transpose(0, 1))
    else:
        output = input.matmul(weight.transpose(0, 1))
        if bias is not None:
            output += bias
        ret = output
    return ret
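
A quick usage sketch for the helper above, assuming import oneflow as flow, that linear is in scope, and a weight of shape (out_features, in_features); with a 2D input and a bias, the fused flow.addmm branch is taken:

import oneflow as flow

x = flow.randn(4, 3)   # batch of 4, in_features = 3
w = flow.randn(5, 3)   # out_features = 5, in_features = 3
b = flow.zeros(5)

y = linear(x, w, b)    # 2D input with bias -> fused flow.addmm path
print(y.shape)         # (4, 5)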
Example #2
def _test_addmm(test_case, shape, alpha, beta, device):
    mat1 = np.random.randn(*shape)
    mat2 = np.random.randn(*shape)
    input = np.random.randn(*shape)
    mat1_tensor = flow.tensor(mat1, dtype=flow.float32, device=flow.device(device))
    mat2_tensor = flow.tensor(mat2, dtype=flow.float32, device=flow.device(device))
    input_tensor = flow.tensor(input, dtype=flow.float32, device=flow.device(device))
    # OneFlow result: beta * input + alpha * (mat1 @ mat2)
    of_out = flow.addmm(input_tensor, mat1_tensor, mat2_tensor, alpha, beta)
    # NumPy reference for the same computation
    np_out = np.add(beta * input, alpha * np.matmul(mat1, mat2))
    test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-05, 1e-05))
Example #3
def _test_addmm_backward(test_case, shape, alpha, beta, device):
    mat1 = np.random.randn(*shape)
    mat2 = np.random.randn(*shape)
    input = np.random.randn(*shape)
    mat1_tensor = flow.tensor(mat1, dtype=flow.float32, device=flow.device(device))
    mat2_tensor = flow.tensor(mat2, dtype=flow.float32, device=flow.device(device))
    input_tensor = flow.tensor(
        input, dtype=flow.float32, requires_grad=True, device=flow.device(device)
    )
    of_out = flow.addmm(input_tensor, mat1_tensor, mat2_tensor, alpha, beta).sum()
    of_out.backward()
    # gradient of (beta * input + alpha * mat1 @ mat2).sum() w.r.t. input is beta everywhere
    np_grad_out = np.ones_like(input) * beta
    test_case.assertTrue(
        np.allclose(input_tensor.grad.numpy(), np_grad_out, 1e-05, 1e-05)
    )
Example #4
def _addmm(self, mat1, mat2, alpha=1, beta=1):
    # Tensor-method style wrapper that forwards to the functional flow.addmm
    return flow.addmm(self, mat1, mat2, alpha, beta)
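
The examples above all rely on the same identity: addmm(input, mat1, mat2) computes beta * input + alpha * (mat1 @ mat2). A minimal check of that identity, assuming import oneflow as flow and that alpha and beta are accepted as keyword arguments (the tests above pass them positionally):

import numpy as np
import oneflow as flow

inp = flow.randn(3, 3)
a = flow.randn(3, 3)
b = flow.randn(3, 3)

out = flow.addmm(inp, a, b, alpha=2.0, beta=0.5)   # 0.5 * inp + 2.0 * (a @ b)
ref = 0.5 * inp.numpy() + 2.0 * np.matmul(a.numpy(), b.numpy())
assert np.allclose(out.numpy(), ref, atol=1e-05)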
Example #5
def forward(self, x):
    bsz, seq_len, channels = x.size()
    # size_out = x.size()[:-1] + (self.nf,)
    # fused bias add + matmul over the flattened (bsz * seq_len, channels) input
    x = flow.addmm(self.bias, x.view(-1, channels), self.weight)
    x = x.view(bsz, seq_len, self.nf)
    return x
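
The forward above comes from a Conv1D-style projection layer (as used in GPT-2 ports), where the weight is stored as (channels, nf) so no transpose is needed before the matmul. A minimal sketch of such a module, assuming OneFlow's torch-compatible nn API; the class name and initialization scale are illustrative:

import oneflow as flow
import oneflow.nn as nn

class Conv1D(nn.Module):
    def __init__(self, nf, channels):
        super().__init__()
        self.nf = nf
        # weight stored as (in_features, out_features), so forward needs no transpose
        self.weight = nn.Parameter(flow.randn(channels, nf) * 0.02)
        self.bias = nn.Parameter(flow.zeros(nf))

    def forward(self, x):
        bsz, seq_len, channels = x.size()
        # bias + flattened input @ weight, then restore the (bsz, seq_len, nf) shape
        x = flow.addmm(self.bias, x.view(-1, channels), self.weight)
        return x.view(bsz, seq_len, self.nf)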