def test_addmm_broadcast_with_alpha_and_beta(self):
    """Compare addmm on Glow vs. PyTorch with a broadcast first input.

    The module is built with the two scalars (2.0, 3.0); presumably these
    are alpha and beta (confirm against SimpleAddMmModule). The first input
    has shape (4,) and has to broadcast against the (6, 4) product of the
    (6, 10) and (10, 4) inputs. This assumes the module feeds its inputs to
    addmm in this order.
    """
    addmm_module = SimpleAddMmModule(2.0, 3.0)
    # Same randn call order as before, so the RNG stream is unchanged.
    first = torch.randn(4)
    mat1 = torch.randn(6, 10)
    mat2 = torch.randn(10, 4)
    expected_ops = {"aten::add", "aten::mm"}
    utils.run_comparison_tests(
        addmm_module,
        (first, mat1, mat2),
        fusible_ops=expected_ops,
    )
def test_div(self, _, module, a, b, skip_for_backends=None):
    """Compare a parametrized division module on Glow against PyTorch.

    Args:
        _: Unused. Presumably the case name supplied by the parametrization
            decorator, which is not visible here.
        module: Module under test. It is expected to fuse ``aten::div``.
        a: First input passed to ``module``.
        b: Second input passed to ``module``.
        skip_for_backends: Backends the harness should skip. ``None`` (the
            default) means an empty collection.
    """
    # Use a None sentinel instead of a ``{}`` default. A mutable default is
    # created once at definition time and shared by every call.
    if skip_for_backends is None:
        skip_for_backends = {}
    utils.run_comparison_tests(
        module,
        (a, b),
        fusible_ops={"aten::div"},
        skip_for_backends=skip_for_backends,
    )
def test_floor_div(self, _, module, left, right):
    """Compare a parametrized floor-division module on Glow against PyTorch.

    The harness is asked to skip the NNPI backend for this test.

    Args:
        _: Unused. Presumably the case name supplied by the parametrization
            decorator, which is not visible here.
        module: Module under test. It is expected to fuse
            ``aten::floor_divide``.
        left: First input passed to ``module``.
        right: Second input passed to ``module``.
    """
    utils.run_comparison_tests(
        module,
        (left, right),
        fusible_ops={"aten::floor_divide"},
        # Pass a collection, not the bare string "NNPI". With a string,
        # ``in`` does substring matching and iteration yields characters.
        skip_for_backends={"NNPI"},
    )
def test_baddbmm_broadcast_with_alpha_and_beta(self):
    """Compare baddbmm on Glow vs. PyTorch with broadcasting and a=2/b=3.

    The first input has shape (1, 4) and has to broadcast against the
    batched (3, 6, 4) product of the (3, 6, 10) and (3, 10, 4) inputs.
    """
    scaled_module = SimpleBAddBmmModule(2.0, 3.0)
    operands = (
        torch.randn(1, 4),
        torch.randn(3, 6, 10),
        torch.randn(3, 10, 4),
    )
    utils.run_comparison_tests(
        scaled_module, operands, fusible_ops={"aten::baddbmm"}
    )
def test_baddbmm_broadcast(self):
    """Compare default-argument baddbmm on Glow vs. PyTorch with broadcasting.

    The (1, 4) first input has to broadcast across the (3, 6, 4) batched
    matrix product of the two remaining inputs.
    """
    baddbmm = SimpleBAddBmmModule()
    addend = torch.randn(1, 4)
    batch1 = torch.randn(3, 6, 10)
    batch2 = torch.randn(3, 10, 4)
    utils.run_comparison_tests(
        baddbmm,
        (addend, batch1, batch2),
        fusible_ops={"aten::baddbmm"},
    )
def test_and(self, _, a, b, skip_to_glow=False):
    """Compare ``SimpleAndModule`` on Glow against PyTorch.

    Args:
        _: Unused. Presumably the case name supplied by the parametrization
            decorator, which is not visible here.
        a: First input.
        b: Second input.
        skip_to_glow: Forwarded unchanged to ``utils.run_comparison_tests``.
    """
    and_module = SimpleAndModule()
    operands = (a, b)
    utils.run_comparison_tests(
        and_module,
        operands,
        fusible_ops={"aten::__and__"},
        skip_to_glow=skip_to_glow,
    )
def test_avg_pool3d_basic(self):
    """Compare avg_pool3d on Glow vs. PyTorch for a single 5x5x5 volume.

    The module is built with the single argument 3 (presumably the kernel
    size; confirm against SimpleAvgPool3dModule).
    """
    volume = torch.randn(1, 4, 5, 5, 5)
    pool = SimpleAvgPool3dModule(3)
    utils.run_comparison_tests(pool, volume, fusible_ops={"aten::avg_pool3d"})
def test_baddbmm_basic(self):
    """Compare baddbmm on Glow vs. PyTorch with no broadcasting.

    The first input already has the (3, 6, 4) shape of the batched product
    of the (3, 6, 10) and (3, 10, 4) inputs.
    """
    module_under_test = SimpleBAddBmmModule()
    shapes = ((3, 6, 4), (3, 6, 10), (3, 10, 4))
    # The tuple is built left to right, keeping the original randn order.
    tensors = tuple(torch.randn(*shape) for shape in shapes)
    utils.run_comparison_tests(
        module_under_test,
        tensors,
        fusible_ops={"aten::baddbmm"},
    )
def test_avg_pool3d_with_args(self):
    """Compare avg_pool3d on Glow vs. PyTorch with extra pooling arguments.

    The module gets 3 and (4, 7, 7) positionally. Presumably these are the
    kernel size and stride (confirm against SimpleAvgPool3dModule).
    """
    volume = torch.randn(1, 4, 10, 10, 10)
    pooling_args = (3, (4, 7, 7))
    utils.run_comparison_tests(
        SimpleAvgPool3dModule(*pooling_args),
        volume,
        fusible_ops={"aten::avg_pool3d"},
    )
def test_addmm_basic(self):
    """Compare addmm on Glow vs. PyTorch with matching (6, 4) shapes.

    The fp16-vs-fp16 absolute and relative tolerances are both set to 1e-3.
    """
    addmm_module = SimpleAddMmModule()
    inputs = (torch.randn(6, 4), torch.randn(6, 10), torch.randn(10, 4))
    comparison_options = {
        "fusible_ops": {"aten::add", "aten::mm"},
        "fp16vfp16_atol": 1e-3,
        "fp16vfp16_rtol": 1e-3,
    }
    utils.run_comparison_tests(addmm_module, inputs, **comparison_options)
def test_abs_3d(self):
    """Compare abs on Glow vs. PyTorch for a rank-3 (2, 3, 5) tensor."""
    tensor3d = torch.randn(2, 3, 5)
    abs_module = SimpleAbsModule()
    utils.run_comparison_tests(abs_module, tensor3d, fusible_ops={"aten::abs"})
def test_abs_basic(self):
    """Compare abs on Glow vs. PyTorch for a 1-D tensor of 10 elements."""
    vector = torch.randn(10)
    expected_ops = {"aten::abs"}
    utils.run_comparison_tests(
        SimpleAbsModule(),
        vector,
        fusible_ops=expected_ops,
    )
def test_adaptive_avg_pool2d_basic(self):
    """Compare adaptive_avg_pool2d on Glow vs. PyTorch with a square output.

    A (3, 6, 14, 14) input is pooled with output size (5, 5).
    """
    feature_map = torch.randn(3, 6, 14, 14)
    output_size = (5, 5)
    pool = SimpleAdapativeAvgPool2dModule(output_size)
    utils.run_comparison_tests(
        pool, feature_map, fusible_ops={"aten::adaptive_avg_pool2d"}
    )
def test_mul(self, _, left, right, skip_to_glow=False):
    """Compare ``SimpleMulModule`` on Glow against PyTorch.

    Args:
        _: Unused. Presumably the case name supplied by the parametrization
            decorator, which is not visible here.
        left: First input.
        right: Second input.
        skip_to_glow: Forwarded unchanged to ``utils.run_comparison_tests``.
    """
    options = {"fusible_ops": {"aten::mul"}, "skip_to_glow": skip_to_glow}
    utils.run_comparison_tests(SimpleMulModule(), (left, right), **options)
def test_adaptive_avg_pool2d_nonsquare_outputs(self):
    """Compare adaptive_avg_pool2d on Glow vs. PyTorch with a 5x3 output.

    Exercises a non-square output size on a (3, 6, 14, 14) input.
    """
    feature_map = torch.randn(3, 6, 14, 14)
    utils.run_comparison_tests(
        SimpleAdapativeAvgPool2dModule((5, 3)),
        feature_map,
        fusible_ops={"aten::adaptive_avg_pool2d"},
    )
def test_avg_pool2d_with_args(self):
    """Compare avg_pool2d on Glow vs. PyTorch with an explicit stride of 7.

    The fp16-vs-fp16 absolute tolerance is set to 1e-3.
    """
    image = torch.randn(1, 4, 10, 10)
    pool = SimpleAvgPool2dModule(3, stride=7)
    utils.run_comparison_tests(
        pool,
        image,
        fusible_ops={"aten::avg_pool2d"},
        fp16vfp16_atol=1e-3,
    )
def test_subtract(self, _, module, tensor, other):
    """Compare a parametrized subtraction module on Glow against PyTorch.

    Args:
        _: Unused. Presumably the case name supplied by the parametrization
            decorator, which is not visible here.
        module: Module under test. It is expected to fuse ``aten::sub``.
        tensor: First input passed to ``module``.
        other: Second input passed to ``module``.
    """
    operands = (tensor, other)
    expected_ops = {"aten::sub"}
    utils.run_comparison_tests(module, operands, fusible_ops=expected_ops)
def test_argmax_node(self, _, module, tensor):
    """Compare a parametrized ArgMax module on Glow against PyTorch.

    Args:
        _: Unused. Presumably the case name supplied by the parametrization
            decorator, which is not visible here.
        module: Module under test. It is expected to fuse ``aten::argmax``.
        tensor: Input passed to ``module``.
    """
    argmax_ops = {"aten::argmax"}
    utils.run_comparison_tests(
        module,
        tensor,
        fusible_ops=argmax_ops,
    )
def test_arange(self, _, module, dummy):
    """Compare a parametrized arange module on Glow against PyTorch.

    Covers arange with the minimum set of parameters.

    Args:
        _: Unused. Presumably the case name supplied by the parametrization
            decorator, which is not visible here.
        module: Module under test. It is expected to fuse ``aten::arange``.
        dummy: Input passed to ``module``. The name suggests a placeholder
            that arange itself does not consume (confirm).
    """
    arange_ops = {"aten::arange"}
    utils.run_comparison_tests(module, dummy, fusible_ops=arange_ops)
def test_add(self, _, module, a, b, skip_to_glow=False):
    """Compare a parametrized addition module on Glow against PyTorch.

    Args:
        _: Unused. Presumably the case name supplied by the parametrization
            decorator, which is not visible here.
        module: Module under test. Its ``inplace`` attribute selects the
            expected fused op: ``aten::add_`` when truthy, ``aten::add``
            otherwise.
        a: First input passed to ``module``.
        b: Second input passed to ``module``.
        skip_to_glow: Forwarded to ``utils.run_comparison_tests``, as in
            ``test_and`` and ``test_mul``.
    """
    expected_op = "aten::add_" if module.inplace else "aten::add"
    utils.run_comparison_tests(
        module,
        (a, b),
        fusible_ops={expected_op},
        # Previously accepted but never passed on, so parametrized cases
        # that set this flag were silently ignored.
        skip_to_glow=skip_to_glow,
    )