def _rebuild_sparse_csr_tensor(layout, data):
    """Reconstruct a sparse CSR tensor from its serialized components.

    Args:
        layout: the layout recorded for the tensor. Only ``torch.sparse_csr``
            is handled; anything else is rejected.
        data: a 4-item sequence unpacked as
            ``(crow_indices, col_indices, values, size)`` and forwarded
            positionally to ``torch._sparse_csr_tensor_unsafe``.

    Returns:
        The tensor built by ``torch._sparse_csr_tensor_unsafe``. Before being
        returned, it is also appended to the module-level
        ``_sparse_tensors_to_validate`` list.

    Raises:
        NotImplementedError: if ``layout`` is not ``torch.sparse_csr``.
    """
    # Guard clause: reject unsupported layouts up front.
    if layout != torch.sparse_csr:
        raise NotImplementedError("rebuilding sparse tensor for layout %s" % (layout))

    crow, col, vals, shape = data
    rebuilt = torch._sparse_csr_tensor_unsafe(crow, col, vals, shape)
    # NOTE(review): the tensor comes from the "_unsafe" constructor. The list
    # name suggests the checks are deferred and run later by whoever drains
    # `_sparse_tensors_to_validate` (not visible here); confirm.
    _sparse_tensors_to_validate.append(rebuilt)
    return rebuilt
def test_factory_override(self):
    """Tensor factory functions must be routed through an active torch-function mode.

    A mode whose ``__torch_function__`` unconditionally returns ``-1`` is
    pushed, and every factory call below is expected to produce ``-1`` instead
    of a tensor. The arguments are deliberately not meaningful tensor data.
    The test relies on the mode intercepting each call before the real
    constructor runs.
    """
    class AlwaysMinusOne(TorchFunctionMode):
        # Short-circuit every intercepted call with a sentinel value.
        def __torch_function__(self, *args, **kwargs):
            return -1

    # Thunks so that each factory is only invoked while the mode is active.
    factory_calls = (
        lambda: torch.tensor([1]),
        lambda: torch.sparse_coo_tensor(1, 1, 1),
        lambda: torch.sparse_csr_tensor(1, 1, 1),
        lambda: torch._sparse_coo_tensor_unsafe(1, 1, (1, 1)),
        lambda: torch._sparse_csr_tensor_unsafe(1, 1, 1, (1, 1)),
        lambda: torch.as_tensor([1]),
    )
    with torch.overrides.push_torch_function_mode(AlwaysMinusOne):
        for make in factory_calls:
            self.assertEqual(make(), -1)