def test_constant_chunk_negative_indices():
    """Check prim::ConstantChunk on Glow when the chunk dim is negative.

    Dim -2 of a (10, 11) tensor is dim 0, so splitting into 3 chunks yields
    shapes (4, 11), (4, 11) and (2, 11).
    """

    def test_f(x):
        doubled = x + x
        return torch.chunk(doubled, 3, -2)

    inp = torch.rand((10, 11))
    jitVsGlow(test_f, inp, expected_fused_ops={"prim::ConstantChunk"})
def test_elementwise_max(self):
    """Check the elementwise (two-tensor) form of aten::max on Glow."""

    def test_f(a, b):
        pairwise = torch.max(a, b)
        # A second max of the result with itself exercises the op again.
        return torch.max(pairwise, pairwise)

    lhs, rhs = torch.randn(4), torch.randn(4)
    jitVsGlow(test_f, lhs, rhs, expected_fused_ops={"aten::max"})
def test_constant_chunk_basic():
    """Check prim::ConstantChunk on Glow with a positive chunk dim.

    Splitting a (10, 11) tensor into 3 chunks along dim 1 yields shapes
    (10, 4), (10, 4) and (10, 3).
    """

    def test_f(x):
        doubled = x + x
        return torch.chunk(doubled, 3, 1)

    inp = torch.rand((10, 11))
    jitVsGlow(test_f, inp, expected_fused_ops={"prim::ConstantChunk"})
def test_mm_basic():
    """Check a chain of two matrix multiplies (torch.mm) on Glow.

    (2, 3) @ (3, 4) gives (2, 4), which is then multiplied by (4, 2) to give (2, 2).
    """

    def test_mm(a, b, t):
        return torch.mm(a, b).mm(t)

    lhs = torch.randn(2, 3)
    rhs = torch.randn(3, 4)
    tail = torch.randn(4, 2)
    jitVsGlow(test_mm, lhs, rhs, tail)
def test_cat_basic(self):
    """Check that nested torch.cat calls fuse into prim::FusedConcat on Glow."""

    def test_f(a, b):
        out = torch.cat((a, b), 0)  # (4, 3, 4)
        out = torch.cat((out, out), 1)  # (4, 6, 4)
        return torch.cat((out, out), 2)  # (4, 6, 8)

    lhs = torch.randn(2, 3, 4)
    rhs = torch.randn(2, 3, 4)
    jitVsGlow(test_f, lhs, rhs, expected_fused_ops={"prim::FusedConcat"})
def test_masked_fill_inplace(self):
    """Check the in-place aten::masked_fill_ op on Glow.

    The fill is applied to the fresh tensor `a + a`, so the input tensor
    itself is left unchanged.
    """

    def masked_fill(a, mask):
        filled = a + a
        filled.masked_fill_(mask, 42.0)
        return filled

    values = torch.randn([3])
    keep = torch.tensor([True, False, True], dtype=torch.bool)
    jitVsGlow(masked_fill, values, keep, expected_fused_ops={"aten::masked_fill_"})
def test_cat_neg_dim(self):
    """Check torch.cat on Glow with negative dims (-3, -2, -1 of a 3-D tensor)."""

    def test_f(a, b):
        out = torch.cat((a, b), -3)  # same as dim 0
        out = torch.cat((out, out), -2)  # same as dim 1
        return torch.cat((out, out), -1)  # same as dim 2

    lhs = torch.randn(2, 3, 4)
    rhs = torch.randn(2, 3, 4)
    jitVsGlow(test_f, lhs, rhs, expected_fused_ops={"prim::FusedConcat"})
def test_masked_fill_broadcasted_multi_dim(self):
    """Check aten::masked_fill on Glow with a broadcast mask whose leading dim is not 1.

    The mask has shape (2, 1, 1, 3) and is broadcast against the
    (2, 4, 3, 3) input.
    """

    def masked_fill(a, mask):
        doubled = a + a
        return torch.masked_fill(doubled, mask, 42.0)

    values = torch.randn([2, 4, 3, 3])
    keep = torch.tensor(
        [[[[True, False, True]]], [[[True, False, True]]]], dtype=torch.bool
    )
    jitVsGlow(masked_fill, values, keep, expected_fused_ops={"aten::masked_fill"})
def test_stack_basic():
    """Check that nested aten::stack calls fuse into glow::fused_stack."""

    def test_f(a, b):
        out = torch.stack((a, b), 0)  # (2, 2, 3, 4)
        out = torch.stack((out, out), 1)  # (2, 2, 2, 3, 4)
        return torch.stack((out, out), 2)  # (2, 2, 2, 2, 3, 4)

    lhs = torch.randn(2, 3, 4)
    rhs = torch.randn(2, 3, 4)
    jitVsGlow(test_f, lhs, rhs, expected_fused_ops={"glow::fused_stack"})
def test_relu_inplace(self):
    """Check the in-place aten::relu_ op on Glow (F.relu with inplace=True)."""

    def test_f(a):
        once = F.relu(a, inplace=True)
        return F.relu(once, inplace=True)

    values = torch.randn(4)
    # Force at least one negative element so relu actually clamps something.
    values[0] = -2.0
    jitVsGlow(test_f, values, expected_fused_ops={"aten::relu_"})
def test_layernorm_basic(self):
    """Check aten::layer_norm on Glow, normalizing over the last dim (size 5)."""

    def test_f(inputs, weight, bias):
        return F.layer_norm(inputs, [5], weight, bias)

    activations = torch.randn(1, 4, 5, 5)
    gamma = torch.randn(5)
    beta = torch.randn(5)
    jitVsGlow(test_f, activations, gamma, beta, expected_fused_ops={"aten::layer_norm"})
def test_conv2d_with_bias():
    """Check F.conv2d with an explicit bias, followed by relu, on Glow.

    Eight 4-channel 3x3 filters are applied to a (1, 4, 5, 5) input, with a
    per-output-channel bias of length 8.
    """

    def conv2d_with_bias(inputs, filters, bias):
        return F.relu(F.conv2d(inputs, filters, bias))

    activations = torch.randn(1, 4, 5, 5)
    kernels = torch.randn(8, 4, 3, 3)
    offsets = torch.randn(8)
    jitVsGlow(conv2d_with_bias, activations, kernels, offsets)
def test_sqrt_basic():
    """Check two chained aten::sqrt ops on Glow."""

    def test_f(a):
        return torch.sqrt(torch.sqrt(a))

    # Inputs lie in [5, 6): positive and well away from zero.
    values = torch.rand(4) + 5
    jitVsGlow(test_f, values, expected_fused_ops={"aten::sqrt"})
def test_sqrt_inplace():
    """Check two chained in-place aten::sqrt_ ops on Glow."""

    def test_f(a):
        once = a.sqrt_()
        return once.sqrt_()

    # Inputs lie in [5, 6): positive and well away from zero.
    values = torch.rand(4) + 5
    # NOTE(review): sqrt_ overwrites `values` in place on every call. If
    # jitVsGlow does not clone its inputs between the reference and Glow
    # runs, each run sees a different input. Confirm.
    jitVsGlow(test_f, values, expected_fused_ops={"aten::sqrt_"})
def test_mm_basic(self):
    """Check a chain of two aten::mm ops on Glow."""

    def test_f(a, b, t):
        return torch.mm(a, b).mm(t)

    lhs = torch.randn(2, 3)
    # A transposed (4, 3) tensor: a (3, 4) view that is not contiguous.
    rhs = torch.randn(4, 3).t()
    tail = torch.randn(4, 2)
    jitVsGlow(test_f, lhs, rhs, tail, expected_fused_ops={"aten::mm"})
def test_sub_broadcast_3():
    """Check aten::sub on Glow when a (4, 2) tensor is broadcast against (8, 3, 4, 2)."""

    def test_f(a, b):
        diff = a.sub(b)
        return diff.sub(diff)

    small = torch.randn(4, 2)
    big = torch.randn(8, 3, 4, 2)
    jitVsGlow(test_f, small, big, expected_fused_ops={"aten::sub"})
def test_relu_basic(self):
    """Check two chained out-of-place aten::relu ops on Glow."""

    def test_f(a):
        return F.relu(F.relu(a))

    values = torch.randn(4)
    # Force at least one negative element so relu actually clamps something.
    values[0] = -2.0
    jitVsGlow(test_f, values, expected_fused_ops={"aten::relu"})
def test_sub_basic():
    """Check aten::sub on Glow with two same-shaped 1-D tensors."""

    def test_f(a, b):
        diff = a.sub(b)
        return diff.sub(diff)

    lhs, rhs = torch.randn(4), torch.randn(4)
    jitVsGlow(test_f, lhs, rhs, expected_fused_ops={"aten::sub"})
def test_typeas_self(self):
    """Test of the PyTorch type_as Node on Glow doing an empty convert (float to float).

    `y` is the very same float tensor as `x`, so `type_as` does not change
    the dtype. No op is listed in expected_fused_ops.
    """

    def test_f(a, b):
        a = a + a
        c = a.type_as(b)
        return c + c

    x = torch.randn(4)
    y = x

    # `set()` rather than `{}`: a bare `{}` is an empty dict, while the
    # other expected_fused_ops arguments in these tests are sets of op names.
    jitVsGlow(test_f, x, y, expected_fused_ops=set())
def test_typeas_self_f2f2(self):
    """Test of the PyTorch type_as Node on Glow float to float.

    Both tensors are float and differ only in shape. `type_as` uses only
    `b`'s dtype, so the conversion is a no-op. No op is listed in
    expected_fused_ops.
    """

    def test_f(a, b):
        a = a + a
        c = a.type_as(b)
        return c + c

    x = torch.randn(4, 2)
    y = torch.randn(8, 3, 4, 2)

    # `set()` rather than `{}`: a bare `{}` is an empty dict, while the
    # other expected_fused_ops arguments in these tests are sets of op names.
    jitVsGlow(test_f, x, y, expected_fused_ops=set())
def test_typeas_self_f2i2(self):
    """Check aten::type_as on Glow converting a float tensor to int32."""

    def test_f(a, b):
        doubled = a + a
        converted = doubled.type_as(b)
        return converted + converted

    source = torch.randn(4, 2)
    # Only the dtype of this int32 tensor matters to type_as.
    dtype_ref = torch.randn(8, 3, 4, 2).to(dtype=torch.int32)
    jitVsGlow(test_f, source, dtype_ref, expected_fused_ops={"aten::type_as"})
def test_cat_oob_neg_dim(self):
    """Check that torch.cat with an out-of-range negative dim raises IndexError.

    The inputs are 3-D, so dim -4 does not exist.
    """

    def test_f(a, b):
        out = torch.cat((a, b), -4)  # invalid for 3-D inputs
        out = torch.cat((out, out), -2)
        return torch.cat((out, out), -1)

    lhs = torch.randn(2, 3, 4)
    rhs = torch.randn(2, 3, 4)
    with self.assertRaises(IndexError):
        jitVsGlow(test_f, lhs, rhs, expected_fused_ops={"prim::FusedConcat"})
def test_batchnorm_basic():
    """Check aten::batch_norm on Glow using only running statistics (no affine params)."""

    def test_f(inputs, running_mean, running_var):
        return F.batch_norm(inputs, running_mean, running_var)

    activations = torch.randn(1, 4, 5, 5)
    mean = torch.rand(4)
    # NOTE(review): torch.rand draws from [0, 1), so a variance can come out
    # very close to 0. Consider a lower bound if this test turns out flaky.
    var = torch.rand(4)
    jitVsGlow(test_f, activations, mean, var, expected_fused_ops={"aten::batch_norm"})
def test_gelu_basic():
    """Check aten::gelu on Glow against 100 independently drawn random inputs."""

    def test_f(a):
        doubled = a + a
        return F.gelu(doubled)

    # Repeat with fresh random data to cover a wider range of input values.
    for _ in range(100):
        values = torch.randn(10)
        jitVsGlow(
            test_f,
            values,
            check_trace=False,
            atol=1e-3,
            expected_fused_ops={"aten::gelu"},
        )
def test_batch_permutation_basic(self):
    """Check the _caffe2::BatchPermutation op on Glow.

    The four batch entries of the input are reordered by a permutation of
    int32 indices.
    """

    def test_f(a, indices):
        doubled = a + a
        return torch.ops._caffe2.BatchPermutation(doubled, indices)

    batch = torch.randn(4, 2, 3)
    order = torch.tensor([1, 3, 0, 2], dtype=torch.int32)
    jitVsGlow(test_f, batch, order, expected_fused_ops={"_caffe2::BatchPermutation"})
def test_prelu_basic(self):
    """Check aten::prelu on Glow with a single shared slope of 0.25."""

    def prelu_basic(inputs, weight):
        doubled = inputs + inputs
        return F.prelu(doubled, weight)

    activations = torch.randn(1, 4, 5, 5)
    slope = torch.tensor([0.25])
    jitVsGlow(prelu_basic, activations, slope, expected_fused_ops={"aten::prelu"})
def test_batchnorm_relu_basic(self):
    """Check a quantized, fused BatchNorm3d + ReLU on Glow.

    A float module containing BatchNorm3d followed by ReLU is fused,
    prepared, run once on the test input and converted to a quantized
    module. The converted module is then compared against Glow, which is
    expected to fuse quantized::batch_norm3d_relu.
    """

    class SimpleQuantizedBatchNormRelu(nn.Module):
        def __init__(self, w, b, m, v):
            super().__init__()
            self.bn = torch.nn.BatchNorm3d(4)
            self.relu = torch.nn.ReLU()
            self.bn.weight = torch.nn.Parameter(w)
            self.bn.bias = torch.nn.Parameter(b)
            self.bn.running_mean = m
            self.bn.running_var = v
            self.q = QuantStub()
            self.dq = DeQuantStub()

        def forward(self, x):
            # quantize -> batchnorm -> relu -> dequantize
            return self.dq(self.relu(self.bn(self.q(x))))

    channels = 4
    # Weight is close to 1 and bias close to 0, so batchnorm stays near the identity.
    weight = torch.ones(channels) + torch.rand(channels) * 0.001
    bias = torch.rand(channels) * 0.0001
    running_mean = torch.zeros(channels)
    running_var = torch.ones(channels)
    inputs = torch.randn((10, channels, 2, 3, 4), requires_grad=False)

    model = SimpleQuantizedBatchNormRelu(weight, bias, running_mean, running_var)
    model.eval()
    model.qconfig = my_qconfig
    fuse_modules(model, [["bn", "relu"]], inplace=True)
    prepare(model, inplace=True)
    # Run the prepared model once on the test input before converting it.
    model.forward(inputs)
    convert(model, inplace=True)

    # PyTorch and Glow quantize differently, and batchnorm can introduce
    # differences of up to ~1e-2 in rare cases. atol is set to 0.1 so the
    # test does not flake.
    jitVsGlow(
        model,
        inputs,
        expected_fused_ops={"quantized::batch_norm3d_relu"},
        atol=1e-1,
        use_fp16=True,
    )
def test_jit_vs_glow_inplace(self):
    """Check JIT-vs-Glow comparison logging on a graph with an in-place add (aten::add_)."""
    # NOTE(review): the comparison mode is switched on here and never switched
    # back off. If it is process-wide state, it may leak into later tests.
    # Confirm, and restore it afterward if needed.
    torch_glow.enable_jit_vs_glow_compare()

    def test_f(a, b):
        a += b
        return a

    lhs = torch.randn(5, 6)
    rhs = torch.randn(5, 6)
    jitVsGlow(test_f, lhs, rhs, expected_fused_ops={"aten::add_"})
def test_conv2d_basic(self):
    """Check F.conv2d with padding=1, followed by relu, on Glow.

    The convolution is expected to appear as aten::_convolution.
    """

    def test_f(inputs, filters):
        return F.relu(F.conv2d(inputs, filters, padding=1))

    activations = torch.randn(1, 4, 5, 5)
    kernels = torch.randn(8, 4, 3, 3)
    jitVsGlow(test_f, activations, kernels, expected_fused_ops={"aten::_convolution"})
def test_conv2d_non_square_dilation(self):
    """Check F.conv2d on Glow with a non-square dilation of [1, 2].

    The dilation is 1 along height and 2 along width.
    """

    def test_f(inputs, filters):
        return F.relu(F.conv2d(inputs, filters, dilation=[1, 2]))

    activations = torch.randn(1, 4, 5, 5)
    kernels = torch.randn(8, 4, 3, 3)
    jitVsGlow(test_f, activations, kernels, expected_fused_ops={"aten::_convolution"})