# Example #1
0
def fail_not_enough_args():
    """Call ``build_ts_script_function`` for ``aten.convolution`` without ``groups``.

    The kwargs deliberately leave out ``groups``, so this case exercises how
    the builder reacts to an incomplete argument set. Presumably the call is
    expected to raise — confirm against the surrounding test harness.
    """
    conv_overload = torch.ops.aten.convolution.default
    incomplete_kwargs = {
        "input": torch.randn(1, 3, 32, 32),
        "weight": torch.randn(3, 3, 3, 3),
        "bias": torch.randn(3),
        "stride": [1, 1],
        "padding": [0, 0],
        "dilation": [1, 1],
        "transposed": False,
        "output_padding": [0, 0],
        # "groups" (normally groups=1) is intentionally omitted.
    }
    build_ts_script_function(conv_overload._schema, incomplete_kwargs)
# Example #2
0
def simple():
    """Build a script function for ``aten.addmm`` and print its graph.

    NOTE(review): the 4-D tensors would not be valid operands if addmm were
    actually executed; presumably only the graph is built from the schema
    here, so values and shapes are placeholders — confirm.
    """
    addmm_schema = torch.ops.aten.addmm.default._schema
    addmm_kwargs = {
        "input": torch.randn(1, 3, 32, 32),
        "mat1": torch.randn(1, 3, 32, 32),
        "mat2": torch.randn(1, 3, 32, 32),
        "beta": 1,
        "alpha": 1,
    }

    addmm_fn = build_ts_script_function(addmm_schema, addmm_kwargs)
    print(addmm_fn.graph)
# Example #3
0
def handle_ones_like():
    """Build and print the graph for ``aten.ones_like`` with every non-input argument set to ``None``."""
    ones_like = torch.ops.aten.ones_like.default
    # dict.fromkeys fills each listed key with None, in the order given.
    none_args = dict.fromkeys(
        ("dtype", "layout", "device", "pin_memory", "memory_format")
    )
    kwargs = {"input": torch.randn(1, 3, 32, 32), **none_args}

    ones_like_fn = build_ts_script_function(ones_like._schema, kwargs)
    print(ones_like_fn.graph)
# Example #4
0
def check_legal_name():
    """Build a script function for ``aten.native_batch_norm`` and print its ``name``.

    Judging by the test name, the printed name is presumably checked to be a
    legal identifier — confirm against the expected output.
    """
    batch_norm = torch.ops.aten.native_batch_norm.default
    batch_norm_fn = build_ts_script_function(
        batch_norm._schema,
        {
            "input": torch.randn(1, 3, 32, 32),
            "weight": torch.randn(1, 3, 32, 32),
            "bias": torch.randn(1, 3, 32, 32),
            "running_mean": None,
            "running_var": None,
            "training": False,
            "momentum": 1.0,
            "eps": 1.0,
        },
    )
    print(batch_norm_fn.name)
# Example #5
0
def handle_nones():
    """Build and print the graph for ``aten.max_pool2d_with_indices`` with ``stride=None``."""
    pool_schema = torch.ops.aten.max_pool2d_with_indices.default._schema
    pool_kwargs = {
        "input": torch.randn(1, 3, 32, 32),
        "kernel_size": [3, 3],
        "stride": None,  # the None value this case is named after
        "padding": [0, 0],
        "dilation": [1, 1],
        "ceil_mode": False,
    }

    pool_graph = build_ts_script_function(pool_schema, pool_kwargs).graph
    print(pool_graph)
# Example #6
0
def handle_optional_tensor_input():
    """Build and print the graph for ``aten.convolution`` with all arguments supplied.

    NOTE(review): presumably the "optional tensor" in the name is ``bias``.
    It is passed a real tensor here, not ``None`` — confirm this is the
    intended case.
    """
    conv = torch.ops.aten.convolution.default
    activations = torch.randn(1, 3, 32, 32)
    filters = torch.randn(3, 3, 3, 3)
    bias = torch.randn(3)

    conv_kwargs = {
        "input": activations,
        "weight": filters,
        "bias": bias,
        "stride": [1, 1],
        "padding": [0, 0],
        "dilation": [1, 1],
        "transposed": False,
        "output_padding": [0, 0],
        "groups": 1,
    }
    print(build_ts_script_function(conv._schema, conv_kwargs).graph)
# Example #7
0
def simple_kwargs():
    """Build and print the graph for ``aten.convolution`` from a fully populated kwargs dict."""
    conv_schema = torch.ops.aten.convolution.default._schema
    conv_kwargs = {
        "input": torch.randn(1, 3, 32, 32),
        "weight": torch.randn(3, 3, 3, 3),
        "bias": torch.randn(3),
        "stride": [1, 1],
        "padding": [0, 0],
        "dilation": [1, 1],
        "transposed": False,
        "output_padding": [0, 0],
        "groups": 1,
    }

    conv_fn = build_ts_script_function(conv_schema, conv_kwargs)
    print(conv_fn.graph)
# Example #8
0
def correctly_order_kwargs():
    """Build and print the graph for the ``out`` overload of ``aten.native_batch_norm``.

    Judging by the test name, the builder is presumably expected to order
    these kwargs correctly for the schema — confirm against the expected
    graph. The kwargs below are listed in the same order as before; do not
    reorder them.

    The ``out``, ``save_mean`` and ``save_invstd`` buffers are allocated with
    ``torch.empty_like`` from the very tensors passed as ``input``,
    ``running_mean`` and ``running_var``. Each output buffer therefore has
    the same size, dtype and device as the tensor it mirrors.
    """
    target = torch.ops.aten.native_batch_norm.out

    # Named ``input_tensor`` (not ``input``) to avoid shadowing the builtin.
    input_tensor = torch.randn(2, 5, 2, 3)
    running_mean = torch.randn(5)
    running_var = torch.randn(5)

    kwargs = dict(
        # Pass the same tensor that ``out`` is allocated from, so the two
        # cannot drift apart.
        input=input_tensor,
        weight=torch.randn(5),
        bias=torch.randn(5),
        running_mean=running_mean,
        running_var=running_var,
        training=False,
        momentum=0.1,
        eps=0.0001,
        out=torch.empty_like(input_tensor),
        save_mean=torch.empty_like(running_mean),
        save_invstd=torch.empty_like(running_var),
    )

    script_fun = build_ts_script_function(target._schema, kwargs)
    print(script_fun.graph)