Example #1
0
    def test_combined(self):
        # Covers a graph that needs both custom rules and metatensors:
        # a conv2d whose result is added to `y`, flattened, and added to `y`
        # again.

        def conv_then_add(x, weight, bias, y):
            # conv2d -> add -> flatten (from dim 2) -> add
            shifted = torch.nn.functional.conv2d(x, weight, bias) + y
            return torch.flatten(shifted, start_dim=2) + y

        # Only the last int8 conv2d sample is exercised.
        samples = list(sample_inputs_conv2d(None, "cpu", torch.int8, False))
        last_sample = samples[-1]

        # A single-element float32 tensor is used for both additions.
        addend = torch.rand((1, ), dtype=torch.float32)

        self.assert_dtype_equal_custom_args(
            conv_then_add,
            [last_sample.input, *last_sample.args, addend],
        )
Example #2
0
    def test_conv_no_mixed_args(self):
        # conv2d is expected to reject arguments of mixed dtypes: calling it
        # must raise, and no output dtype should be recorded on its graph.

        def conv2d_fn(input, weight, bias):
            return torch.nn.functional.conv2d(input, weight, bias)

        # Take the last float sample and cast only its weight to long so the
        # arguments no longer share a dtype.
        samples = list(sample_inputs_conv2d(None, "cpu", torch.float, False))
        sample = samples[-1]
        float_weight, bias = sample.args
        long_weight = float_weight.type(torch.long)
        mixed_args = [sample.input, long_weight, bias]

        # Now make sure that conv2d doesn't support mixed args
        with self.assertRaises(RuntimeError):
            conv2d_fn(*mixed_args)

        # Check that we also don't propagate
        scripted_fn = torch.jit.script(conv2d_fn)
        graph = scripted_fn.graph  # Note this is a cached graph
        self.prop_dtype_on_graph(graph, mixed_args)
        self.assertEqual(self.node_output_dtype_single(graph), None)