Example #1
0
    def test_shape_inference_basics(self):
        """Test Glow shape inference basic usage.

        Traces a trivial elementwise function ``f(a) = a * a``, takes the
        graph produced for a concrete input, and checks that
        ``torch_glow.glow_shape_inference`` returns a truthy result when it
        is given that graph together with a matching tuple of inputs.
        """

        def f(a):
            return a * a

        a = torch.randn(1)
        # A real 1-tuple: ``(a)`` would just be ``a`` wrapped in parentheses.
        args = (a,)

        jit_f = torch.jit.trace(f, args)
        jit_f_graph = jit_f.graph_for(a)

        actual = torch_glow.glow_shape_inference(
            jit_f_graph,
            args,
        )

        # Use the TestCase assertion instead of a bare ``assert``. This keeps
        # the check under ``python -O`` and gives a proper failure report.
        self.assertTrue(actual)
Example #2
0
    def test_shape_inference_input_mismatch(self):
        """Test Glow shape inference basic error handling.

        The traced function takes exactly one input, but shape inference is
        invoked with an empty argument tuple. The call is expected to raise.
        """

        def f(a):
            return a * a

        a = torch.randn(1)

        # A real 1-tuple: ``(a)`` would just be ``a`` wrapped in parentheses.
        jit_f = torch.jit.trace(f, (a,))
        jit_f_graph = jit_f.graph_for(a)

        # Input/args is empty, but the function expects one input.
        # Shape inference should raise an exception in this case.
        args = ()

        # NOTE(review): the concrete exception type raised by
        # glow_shape_inference is not visible here. Narrow ``Exception``
        # once it is confirmed.
        with self.assertRaises(Exception):
            torch_glow.glow_shape_inference(
                jit_f_graph,
                args,
            )