def test_shape_inference_basics(self):
    """Test Glow shape inference basic usage.

    Traces a one-argument function that multiplies its input by itself,
    fetches the graph specialized for that input, and checks that
    ``torch_glow.glow_shape_inference`` returns a truthy result when it is
    given an argument tuple matching the function's single input.
    """

    def f(a):
        return a * a

    a = torch.randn(1)
    # ``(a)`` is just ``a``; spell the example inputs as a real 1-tuple.
    jit_f = torch.jit.trace(f, (a,))
    jit_f_graph = jit_f.graph_for(a)

    args = (a,)

    actual = torch_glow.glow_shape_inference(
        jit_f_graph,
        args,
    )

    # Use the TestCase assertion rather than a bare ``assert`` so the check
    # is not stripped under ``python -O`` and is reported by the test runner.
    self.assertTrue(actual)
def test_shape_inference_input_mismatch(self):
    """Test Glow shape inference basic error handling.

    Runs shape inference on a one-input graph while supplying no inputs
    and expects the call to raise.
    """

    def f(a):
        return a * a

    sample = torch.randn(1)
    traced = torch.jit.trace(f, (sample))
    graph = traced.graph_for(sample)

    # The argument tuple is deliberately empty even though the traced
    # function expects one input; shape inference should raise an
    # exception in this case.
    no_args = ()

    with self.assertRaises(Exception):
        torch_glow.glow_shape_inference(
            graph,
            no_args,
        )