def create_lazy_tensor(self):
    """Return a batched low-rank-root lazy tensor with an added diagonal.

    A random root of shape (4, 3, 5, 2) is wrapped in LowRankRootLazyTensor,
    then a (4, 3, 5) diagonal is added via ``add_diag``. The diagonal is three
    fixed rows tiled over an outer batch of 4. The result is checked to be a
    LowRankRootAddedDiagLazyTensor before it is returned.
    """
    root = torch.randn(4, 3, 5, 2)
    # Three fixed diagonal rows, shape (3, 5).
    diag_rows = torch.tensor([
        [1.0, 2.0, 4.0, 2.0, 3.0],
        [2.0, 1.0, 2.0, 1.0, 4.0],
        [1.0, 2.0, 2.0, 3.0, 4.0],
    ])
    # Tile along a new leading dimension -> shape (4, 3, 5).
    batch_diag = diag_rows.repeat(4, 1, 1)
    base = LowRankRootLazyTensor(root)
    result = base.add_diag(batch_diag)
    # The fixture requires add_diag to produce the specialised subclass.
    assert isinstance(result, LowRankRootAddedDiagLazyTensor)
    return result
 def create_lazy_tensor(self):
     """Return a LowRankRootLazyTensor wrapping a random (4, 3, 5, 2) root."""
     return LowRankRootLazyTensor(torch.randn(4, 3, 5, 2))
 def create_lazy_tensor(self):
     """Return a LowRankRootLazyTensor wrapping a random (3, 1) root.

     The root is created with ``requires_grad=True``, so autograd tracks it.
     """
     return LowRankRootLazyTensor(torch.randn(3, 1, requires_grad=True))
 def create_lazy_tensor(self):
     """Return a non-batched low-rank-root lazy tensor with an added diagonal.

     A random (5, 2) root is wrapped in LowRankRootLazyTensor, and a fixed
     length-5 diagonal is added via ``add_diag``. The result is checked to be a
     LowRankRootAddedDiagLazyTensor before it is returned.
     """
     root = torch.randn(5, 2)
     diag_values = [1.0, 2.0, 4.0, 2.0, 3.0]
     result = LowRankRootLazyTensor(root).add_diag(torch.tensor(diag_values))
     # The fixture requires add_diag to produce the specialised subclass.
     assert isinstance(result, LowRankRootAddedDiagLazyTensor)
     return result