def fail_not_enough_args():
    """Call the script-function builder with an incomplete kwarg set.

    The ``groups`` argument required by ``aten.convolution.default`` is
    deliberately left out, so ``build_ts_script_function`` is expected to
    reject this call.
    """
    schema = torch.ops.aten.convolution.default._schema
    incomplete_kwargs = {
        "input": torch.randn(1, 3, 32, 32),
        "weight": torch.randn(3, 3, 3, 3),
        "bias": torch.randn(3),
        "stride": [1, 1],
        "padding": [0, 0],
        "dilation": [1, 1],
        "transposed": False,
        "output_padding": [0, 0],
        # "groups" intentionally missing.
    }
    build_ts_script_function(schema, incomplete_kwargs)
def simple():
    """Build a TorchScript function for ``aten.addmm.default`` and print its graph."""
    schema = torch.ops.aten.addmm.default._schema
    addmm_kwargs = {
        "input": torch.randn(1, 3, 32, 32),
        "mat1": torch.randn(1, 3, 32, 32),
        "mat2": torch.randn(1, 3, 32, 32),
        "beta": 1,
        "alpha": 1,
    }
    scripted = build_ts_script_function(schema, addmm_kwargs)
    print(scripted.graph)
def handle_ones_like():
    """Exercise ``aten.ones_like.default`` with every optional kwarg set to None."""
    schema = torch.ops.aten.ones_like.default._schema
    ones_like_kwargs = {
        "input": torch.randn(1, 3, 32, 32),
        # All optional arguments are None to check None handling in the builder.
        "dtype": None,
        "layout": None,
        "device": None,
        "pin_memory": None,
        "memory_format": None,
    }
    scripted = build_ts_script_function(schema, ones_like_kwargs)
    print(scripted.graph)
def check_legal_name():
    """Build a function for ``aten.native_batch_norm.default`` and print its name.

    Used to verify the builder produces a legal TorchScript identifier.
    """
    schema = torch.ops.aten.native_batch_norm.default._schema
    bn_kwargs = {
        "input": torch.randn(1, 3, 32, 32),
        "weight": torch.randn(1, 3, 32, 32),
        "bias": torch.randn(1, 3, 32, 32),
        "running_mean": None,
        "running_var": None,
        "training": False,
        "momentum": 1.0,
        "eps": 1.0,
    }
    scripted = build_ts_script_function(schema, bn_kwargs)
    print(scripted.name)
def handle_nones():
    """Exercise ``aten.max_pool2d_with_indices.default`` with ``stride=None``.

    Checks that the builder copes with a None value for an argument that is
    normally a list.
    """
    schema = torch.ops.aten.max_pool2d_with_indices.default._schema
    pool_kwargs = {
        "input": torch.randn((1, 3, 32, 32)),
        "kernel_size": [3, 3],
        "stride": None,  # the None under test
        "padding": [0, 0],
        "dilation": [1, 1],
        "ceil_mode": False,
    }
    scripted = build_ts_script_function(schema, pool_kwargs)
    print(scripted.graph)
def handle_optional_tensor_input():
    """Build ``aten.convolution.default`` with the optional ``bias`` tensor supplied."""
    schema = torch.ops.aten.convolution.default._schema
    conv_kwargs = {
        "input": torch.randn(1, 3, 32, 32),
        "weight": torch.randn(3, 3, 3, 3),
        "bias": torch.randn(3),  # optional tensor argument under test
        "stride": [1, 1],
        "padding": [0, 0],
        "dilation": [1, 1],
        "transposed": False,
        "output_padding": [0, 0],
        "groups": 1,
    }
    scripted = build_ts_script_function(schema, conv_kwargs)
    print(scripted.graph)
def simple_kwargs():
    """Build ``aten.convolution.default`` from a full kwarg set and print the graph."""
    schema = torch.ops.aten.convolution.default._schema
    conv_kwargs = {
        "input": torch.randn(1, 3, 32, 32),
        "weight": torch.randn(3, 3, 3, 3),
        "bias": torch.randn(3),
        "stride": [1, 1],
        "padding": [0, 0],
        "dilation": [1, 1],
        "transposed": False,
        "output_padding": [0, 0],
        "groups": 1,
    }
    scripted = build_ts_script_function(schema, conv_kwargs)
    print(scripted.graph)
def correctly_order_kwargs():
    """Build ``aten.native_batch_norm.out`` from kwargs given out of schema order.

    Verifies the builder reorders keyword arguments (including the ``out`` /
    ``save_mean`` / ``save_invstd`` output tensors) to match the op schema.
    """
    target = torch.ops.aten.native_batch_norm.out
    # Renamed from ``input`` so the local no longer shadows the builtin;
    # the "input" *kwarg key* below is part of the op schema and unchanged.
    inp = torch.randn(2, 5, 2, 3)
    running_mean = torch.randn(5)
    running_var = torch.randn(5)
    kwargs = dict(
        input=torch.randn(2, 5, 2, 3),
        weight=torch.randn(5),
        bias=torch.randn(5),
        running_mean=running_mean,
        running_var=running_var,
        training=False,
        momentum=0.1,
        eps=0.0001,
        out=torch.empty_like(inp),
        save_mean=torch.empty_like(running_mean),
        save_invstd=torch.empty_like(running_var),
    )
    script_fun = build_ts_script_function(target._schema, kwargs)
    print(script_fun.graph)