Beispiel #1
0
    def backward(self, grad_output):
        # Backward pass of a lookup-table op (it delegates to the backend's
        # LookupTable_accGradParameters). Returns (None, grad_weight): no
        # gradient for the first forward input, and a weight gradient that is
        # built densely or via a sparse product depending on self.sparse.
        # NOTE(review): written against a legacy torch API (self._backend,
        # typed constructors such as torch.cuda.LongTensor).

        # Indices come from self._indices when it is set; otherwise from the
        # tensors saved during forward (exactly one saved tensor is unpacked).
        if self._indices is not None:
            indices = self._indices
        else:
            indices, = self.saved_tensors

        # A contiguous layout is required by the .view() in the sparse branch
        # (and presumably by the backend kernel in the dense branch).
        grad_output = grad_output.contiguous()
        if not self.sparse:
            # Dense path: flatten 2-D indices into a single 1-D index list.
            if indices.dim() == 2:
                indices = indices.view(-1)

            # Allocate the backend's scratch buffers under grad_output's
            # device context. On CUDA three empty LongTensors are passed
            # (presumably for sorted indices, original positions and counts);
            # on CPU only an IntTensor count buffer is allocated and the other
            # two buffers are passed as None.
            with torch.cuda.device_of(grad_output):
                if grad_output.is_cuda:
                    _sorted = torch.cuda.LongTensor()
                    _indices = torch.cuda.LongTensor()
                    _count = torch.cuda.LongTensor()
                else:
                    _count = torch.IntTensor()
                    _sorted = _indices = None

            # TODO: sparse updates...
            # Zero-filled dense gradient of size self._weight_size, of the same
            # tensor type as grad_output; the backend call below writes into it
            # (presumably accumulating, as its name accGradParameters suggests).
            grad_weight = grad_output.new(self._weight_size).zero_()
            # NOTE(review): the trailing 1 is presumably a gradient scale
            # factor — confirm against the backend's signature.
            self._backend.LookupTable_accGradParameters(
                self._backend.library_state, indices, grad_output, grad_weight,
                _count, _sorted, _indices, self.scale_grad_by_freq,
                self.padding_idx, 1)
        else:
            # Sparse path: build a sparse matrix from the indices via
            # self._make_sparse (using grad_output's tensor type), flatten
            # grad_output to 2-D as (rows, last dim), and multiply them with
            # torch.smm (sparse x dense -> sparse).
            sp = self._make_sparse(indices, type(grad_output))
            go = grad_output.view(-1, grad_output.size()[-1])
            grad_weight = torch.smm(sp, go)
        return None, grad_weight
def sds_bmm_torch(s_t1, d_t2):
    """
    Batched sparse x dense matrix multiplication returning sparse results.

    With ``s_t1`` holding ``b`` sparse matrices of shape ``(x, s)`` and
    ``d_t2.shape == (b, s, y)``, the result is a list of ``b`` sparse tensors
    of shape ``(x, y)``.

    This is a workaround that applies ``torch.smm`` (sparse x dense -> sparse)
    batch element by batch element. Every operand is moved to the CPU first,
    because the CUDA version of ``smm`` was not implemented when this was
    written. The outputs therefore always live on the CPU, whatever device
    the inputs were on.

    Gradients are not supported: sparse tensors cannot accept gradients
    because of limitations in the torch implementation.

    :param s_t1: list of 2-D sparse tensors, one per batch element
    :param d_t2: dense 3-D tensor whose first dimension is the batch
    :return: list of sparse CPU tensors, one per batch element
    :raises AssertionError: if ``s_t1`` is not a list, the batch sizes differ,
        or an inner matrix dimension does not match
    """
    assert isinstance(s_t1, list), 's_t1 must be a list of sparse tensors.'
    batch_num = len(s_t1)

    assert batch_num == d_t2.shape[0], 'Batch size mismatch.'

    outp = []
    # Batch sizes are equal (checked above), so zip never truncates.
    for sparse_mat, dense_mat in zip(s_t1, d_t2):
        # force cpu: CUDA version of smm is not implemented
        sparse_mat = sparse_mat.cpu()
        dense_mat = dense_mat.cpu()
        assert sparse_mat.shape[1] == dense_mat.shape[0], 'Matrix shape mismatch.'
        outp.append(torch.smm(sparse_mat, dense_mat))

    return outp
    def test(self):
        """Visual smoke test: warp a texture with a random sparse action matrix.

        Loads ``models/default_texture2.jpg`` with OpenCV and flattens it into
        a float column vector. It multiplies that vector by the ``'mat'`` entry
        of a randomly chosen item of ``self.action_sparse_tensor_data`` using
        ``torch.smm``, and prints the time spent. It then shows the reshaped
        result in an OpenCV window until a key is pressed.

        :raises FileNotFoundError: if the texture image cannot be read.
        """
        texture_path = 'models/default_texture2.jpg'
        texture_img = cv2.imread(texture_path)
        if texture_img is None:
            # cv2.imread signals failure by returning None rather than raising.
            raise FileNotFoundError(
                'could not read texture image: {}'.format(texture_path))
        # Image array -> float column vector of shape (num_elements, 1),
        # flattened in row-major order.
        texture_img = torch.from_numpy(texture_img).float().reshape(-1, 1)

        # perf_counter is monotonic and high-resolution, unlike time.time().
        start_time = time.perf_counter()
        action_tensor = random.choice(self.action_sparse_tensor_data)['mat']
        result_flat = torch.smm(action_tensor, texture_img).to_dense()
        # (N, 1) column -> (1, 224, 224, 3) image batch.
        # NOTE(review): the output size is hard-coded; this assumes every
        # action matrix has 224*224*3 rows — confirm.
        result_img = result_flat.reshape(1, 224, 224, 3)
        stop_time = time.perf_counter()
        print('time use: {}'.format(stop_time - start_time))

        result_np = result_img.numpy()[0]
        # Clip before casting: float values outside [0, 255] would otherwise
        # wrap around (or be platform-dependent) in the uint8 conversion.
        cv2.imshow('result', np.clip(result_np, 0, 255).astype(np.uint8))
        cv2.waitKey()
Beispiel #4
0
        def test_shape(di, dj, dk):
            """Check sparse saddmm/smm against dense addmm/mm for one shape.

            Builds a sparse ``x`` of size (di, dj), a sparse ``t`` of size
            (di, dk) and a dense ``y`` of size (dj, dk). It then compares each
            sparse op (densified) with its dense counterpart.
            NOTE(review): the ``_gen_sparse(2, 20, ...)`` arguments presumably
            mean sparse_dim=2 and nnz=20 — confirm against the helper.
            """
            x = self._gen_sparse(2, 20, [di, dj])[0]
            t = self._gen_sparse(2, 20, [di, dk])[0]
            y = torch.randn(dj, dk)
            alpha = random.random()
            beta = random.random()

            # Pass the scalars by keyword. The positional form
            # saddmm(beta, t, alpha, x, y) is a deprecated overload. Its first
            # scalar is beta (the multiplier of t), so the old positional call
            # applied the variables named alpha/beta to the opposite roles.
            res = torch.saddmm(t, x, y, beta=beta, alpha=alpha)
            expected = torch.addmm(self.safeToDense(t), self.safeToDense(x), y,
                                   beta=beta, alpha=alpha)
            self.assertEqual(self.safeToDense(res), expected)

            # Default scalars (beta=1, alpha=1).
            res = torch.saddmm(t, x, y)
            expected = torch.addmm(self.safeToDense(t), self.safeToDense(x), y)
            self.assertEqual(self.safeToDense(res), expected)

            # Plain sparse x dense product.
            res = torch.smm(x, y)
            expected = torch.mm(self.safeToDense(x), y)
            self.assertEqual(self.safeToDense(res), expected)
Beispiel #5
0
        def test_shape(di, dj, dk):
            """Check sparse saddmm/smm against dense addmm/mm for one shape.

            Builds a sparse ``x`` of size (di, dj), a sparse ``t`` of size
            (di, dk) and a dense ``y`` of size (dj, dk). It then compares each
            sparse op (densified) with its dense counterpart.
            NOTE(review): the ``_gen_sparse(2, 20, ...)`` arguments presumably
            mean sparse_dim=2 and nnz=20 — confirm against the helper.
            """
            x = self._gen_sparse(2, 20, [di, dj])[0]
            t = self._gen_sparse(2, 20, [di, dk])[0]
            y = torch.randn(dj, dk)
            alpha = random.random()
            beta = random.random()

            # Pass the scalars by keyword. The positional form
            # saddmm(beta, t, alpha, x, y) is a deprecated overload. Its first
            # scalar is beta (the multiplier of t), so the old positional call
            # applied the variables named alpha/beta to the opposite roles.
            res = torch.saddmm(t, x, y, beta=beta, alpha=alpha)
            expected = torch.addmm(self.safeToDense(t), self.safeToDense(x), y,
                                   beta=beta, alpha=alpha)
            self.assertEqual(self.safeToDense(res), expected)

            # Default scalars (beta=1, alpha=1).
            res = torch.saddmm(t, x, y)
            expected = torch.addmm(self.safeToDense(t), self.safeToDense(x), y)
            self.assertEqual(self.safeToDense(res), expected)

            # Plain sparse x dense product.
            res = torch.smm(x, y)
            expected = torch.mm(self.safeToDense(x), y)
            self.assertEqual(self.safeToDense(res), expected)
Beispiel #6
0
        def test_shape(di, dj, dk):
            """Check sparse saddmm/smm against dense addmm/mm for one shape.

            Builds a sparse ``x`` of size (di, dj), a sparse ``t`` of size
            (di, dk) and a dense ``y`` of size (dj, dk). It then compares each
            sparse op, converted with ``to_dense()``, with its dense
            counterpart.
            NOTE(review): the ``_gen_sparse(2, 20, ...)`` arguments presumably
            mean sparse_dim=2 and nnz=20 — confirm against the helper.
            """
            x = self._gen_sparse(2, 20, [di, dj])[0]
            t = self._gen_sparse(2, 20, [di, dk])[0]
            y = torch.randn(dj, dk)
            alpha = random.random()
            beta = random.random()

            # Pass the scalars by keyword. The positional form
            # addmm(beta, mat, alpha, mat1, mat2) is a deprecated overload. Its
            # first scalar is beta (the multiplier of mat), so the old
            # positional call applied the variables named alpha/beta to the
            # opposite roles.
            expected = torch.addmm(t.to_dense(), x.to_dense(), y,
                                   beta=beta, alpha=alpha)
            res = torch.saddmm(t, x, y, beta=beta, alpha=alpha)
            self.assertEqual(res.to_dense(), expected)

            # Default scalars (beta=1, alpha=1).
            expected = torch.addmm(t.to_dense(), x.to_dense(), y)
            res = torch.saddmm(t, x, y)
            self.assertEqual(res.to_dense(), expected)

            # Plain sparse x dense product.
            expected = torch.mm(x.to_dense(), y)
            res = torch.smm(x, y)
            self.assertEqual(res.to_dense(), expected)