def ephemeral_torchglow_settings(
    fp16=False, backend=DEFAULT_BACKEND, fusion=False, blocklist=None
):
    """Temporarily override torch_glow's global settings.

    This is a generator with a single ``yield``. It is presumably wrapped by
    ``contextlib.contextmanager`` where it is defined (the decorator is not
    visible here). The requested settings are applied before the ``yield``.
    The values that were active on entry are reinstated in the ``finally``
    clause.

    Args:
        fp16: If true, enable fp16 conversion, fused fp16 conversion and fp16
            clipping. Otherwise disable all three.
        backend: Backend name handed to ``torch_glow.setGlowBackend``.
        fusion: If true, enable the fusion pass. Otherwise disable it.
        blocklist: Iterable converted to a list and passed to
            ``torch_glow.setFusionBlacklist``. ``None`` clears the fusion
            blacklist instead.
    """
    # Record every global setting this helper is about to change.
    prev_convert = torch_glow.get_convert_to_fp16()
    prev_clip = torch_glow.get_clip_fp16()
    prev_convert_fused = torch_glow.get_convert_fused_to_fp16()
    prev_backend = torch_glow.getGlowBackendName()
    prev_blocklist = torch_glow.getFusionBlacklist()
    prev_fusion = torch_glow.getFusionPassEnabled()

    try:
        set_fusion = (
            torch_glow.enableFusionPass if fusion else torch_glow.disableFusionPass
        )
        set_fusion()

        if fp16:
            fp16_switches = (
                torch_glow.enable_convert_to_fp16,
                torch_glow.enable_convert_fused_to_fp16,
                torch_glow.enable_clip_fp16,
            )
        else:
            fp16_switches = (
                torch_glow.disable_convert_to_fp16,
                torch_glow.disable_convert_fused_to_fp16,
                torch_glow.disable_clip_fp16,
            )
        for switch in fp16_switches:
            switch()

        if blocklist is not None:
            torch_glow.setFusionBlacklist(list(blocklist))
        else:
            torch_glow.clearFusionBlacklist()

        torch_glow.setGlowBackend(backend)
        yield
    finally:
        # Each entry: (value on entry, call to turn it on, call to turn it off).
        restore_steps = (
            (
                prev_convert,
                torch_glow.enable_convert_to_fp16,
                torch_glow.disable_convert_to_fp16,
            ),
            (
                prev_clip,
                torch_glow.enable_clip_fp16,
                torch_glow.disable_clip_fp16,
            ),
            (
                prev_convert_fused,
                torch_glow.enable_convert_fused_to_fp16,
                torch_glow.disable_convert_fused_to_fp16,
            ),
            (
                prev_fusion,
                torch_glow.enableFusionPass,
                torch_glow.disableFusionPass,
            ),
        )
        for was_enabled, turn_on, turn_off in restore_steps:
            (turn_on if was_enabled else turn_off)()
        torch_glow.setGlowBackend(prev_backend)
        torch_glow.setFusionBlacklist(prev_blocklist)
def test_save_preprocessed_module(self):
    """Lower ``Bar`` to Glow, save/reload it, and inspect the stored module.

    The reloaded wrapper's ``__processed_module`` attribute is wrapped back
    into a TorchScript module. Its graph must contain a ``quantized::conv2d``
    node.

    The global fusion-pass and fp16-conversion settings changed here are
    restored on exit, so they do not carry over into other tests.
    """
    # Snapshot the global settings this test mutates.
    prev_fusion = torch_glow.getFusionPassEnabled()
    prev_fp16 = torch_glow.get_convert_to_fp16()
    try:
        with torch.no_grad():
            x = torch.randn([1, 4, 4, 4], dtype=torch.float32)
            model = Bar()
            model.eval()
            model = torch.jit.trace(model, x)

            spec = torch_glow.CompilationSpec()
            spec.get_settings().set_glow_backend("Interpreter")

            compilation_group = torch_glow.CompilationGroup()
            spec.compilation_groups_append(compilation_group)
            compilation_group.input_sets_append(
                torch_glow.input_specs_from_tensors([x])
            )

            torch_glow.disableFusionPass()
            torch_glow.enable_convert_to_fp16()
            glow_mod = torch_glow.to_glow(model, spec)

            reloaded = utils.save_and_reload_model(glow_mod)

            # getattr with string names avoids Python's private-name mangling
            # of the double-underscore attribute names.
            wrapper = getattr(reloaded._c, "__loweredModule__")
            processed = getattr(wrapper, "__processed_module")
            pt_model = torch.jit._recursive.wrap_cpp_module(processed)
            graph = pt_model.graph_for(x)

            found = any(
                node.kind() == "quantized::conv2d" for node in graph.nodes()
            )
            assert found, "expected a quantized::conv2d node in the processed module graph"
    finally:
        # Put the global settings back so they don't leak into other tests.
        if prev_fusion:
            torch_glow.enableFusionPass()
        else:
            torch_glow.disableFusionPass()
        if prev_fp16:
            torch_glow.enable_convert_to_fp16()
        else:
            torch_glow.disable_convert_to_fp16()
def test_batchnorm_relu_basic(self):
    """
    Basic test of the PyTorch 3D batchnorm RELU Node on Glow.

    A float BatchNorm3d + ReLU module is fused and quantized with
    prepare/convert, then compared against Glow. The global fp16-conversion
    flag toggled here is restored afterwards.
    """

    class SimpleQuantizedBatchNormRelu(nn.Module):
        def __init__(self, w, b, m, v):
            super(SimpleQuantizedBatchNormRelu, self).__init__()
            self.bn = torch.nn.BatchNorm3d(4)
            self.relu = torch.nn.ReLU()
            self.bn.weight = torch.nn.Parameter(w)
            self.bn.bias = torch.nn.Parameter(b)
            self.bn.running_mean = m
            self.bn.running_var = v
            self.q = QuantStub()
            self.dq = DeQuantStub()

        def forward(self, x):
            qx = self.q(x)
            qy = self.bn(qx)
            qy_relu = self.relu(qy)
            y = self.dq(qy_relu)
            return y

    C = 4
    weight = torch.ones(C) + torch.rand(C) * 0.001
    bias = torch.rand(C) * 0.0001
    running_mean = torch.zeros(C)
    running_var = torch.ones(C)

    inputs = torch.randn((10, C, 2, 3, 4), requires_grad=False)
    model = SimpleQuantizedBatchNormRelu(weight, bias, running_mean, running_var)
    model.eval()
    model.qconfig = my_qconfig
    modules_to_fuse = [["bn", "relu"]]
    fuse_modules(model, modules_to_fuse, inplace=True)
    prepare(model, inplace=True)
    # Run the prepared model once. This presumably calibrates the observers
    # that prepare() inserted before convert().
    model.forward(inputs)
    convert(model, inplace=True)

    # Snapshot the global fp16 flag so this test does not leak it.
    prev_fp16 = torch_glow.get_convert_to_fp16()
    torch_glow.enable_convert_to_fp16()
    try:
        # Quantization differs between PyTorch and Glow, so the tolerance is
        # loose. Batchnorm introduces notable accuracy issues and can produce
        # differences of up to ~1e-2 in rare cases. atol is set to 0.1 to keep
        # this test from being flaky.
        # NOTE(review): if jitVsGlow re-applies fp16 settings from its own
        # use_fp16 argument (as the traceVsGlow helper does), the enable call
        # above may be overridden. Confirm the intended precision.
        jitVsGlow(
            model,
            inputs,
            expected_fused_ops={"quantized::batch_norm3d_relu"},
            atol=1e-1,
        )
    finally:
        if prev_fp16:
            torch_glow.enable_convert_to_fp16()
        else:
            torch_glow.disable_convert_to_fp16()
def scriptVsGlow(
    f,
    atol,
    rtol,
    *inputs,
    expected_fused_ops=None,
    accept_all_ops=False,
    black_list=None,
    use_fp16=False,
    backend_name=None,
):
    """Compare eager ``f`` with ``torch.jit.script(f)`` run through Glow fusion.

    Args:
        f: Callable. It is run eagerly for the reference result and scripted
            for the Glow run.
        atol: Absolute tolerance forwarded to ``checkResult``.
        rtol: Relative tolerance forwarded to ``checkResult``.
        *inputs: Positional inputs passed to ``f`` and to the scripted module.
        expected_fused_ops: Forwarded to ``checkExpectedOps``.
        accept_all_ops: Forwarded to ``checkExpectedOps``.
        black_list: Passed to ``torch_glow.setFusionBlacklist``. Defaults to
            an empty list.
        use_fp16: If true, enable fp16 conversion, fused fp16 conversion and
            fp16 clipping for the Glow run. Otherwise disable all three.
        backend_name: Glow backend to use. ``"Interpreter"`` is used when falsy.

    All global torch_glow settings touched here are reset in a ``finally``
    block, so a failure while scripting or running does not leak them.
    """
    if black_list is None:
        black_list = []
    with torch.no_grad():
        torch_res = f(*inputs)

        try:
            torch_glow.enableFusionPass()
            torch_glow.setFusionBlacklist(black_list)
            if use_fp16:
                torch_glow.enable_convert_to_fp16()
                torch_glow.enable_convert_fused_to_fp16()
                torch_glow.enable_clip_fp16()
            else:
                torch_glow.disable_convert_to_fp16()
                torch_glow.disable_convert_fused_to_fp16()
                torch_glow.disable_clip_fp16()
            if backend_name:
                torch_glow.setGlowBackend(backend_name)
            else:
                torch_glow.setGlowBackend("Interpreter")

            glow_trace = torch.jit.script(f)
            glow_res = glow_trace(*inputs)
            glow_graph = glow_trace.graph_for(*inputs)
            print("glow_graph,", glow_graph)
        finally:
            # need to explicitly clear settings to avoid carry-over static settings
            torch_glow.disableFusionPass()
            torch_glow.disable_convert_to_fp16()
            torch_glow.disable_convert_fused_to_fp16()
            torch_glow.disable_clip_fp16()
            torch_glow.setGlowBackend("Interpreter")
            # The blacklist is also a static setting; don't let it outlive this call.
            torch_glow.clearFusionBlacklist()

    checkExpectedOps(glow_graph, expected_fused_ops, accept_all_ops)
    checkResult(torch_res, glow_res, atol, rtol)
def test_batchnorm_with_weights(self):
    """
    Test of the PyTorch 3D batchnorm Node with weights and biases on Glow.

    A quantized BatchNorm3d with explicit weight, bias and running stats is
    followed by ReLU and a dequantize. The global fp16-conversion flag toggled
    here is restored afterwards.
    """

    class SimpleQuantizedBatchNorm(nn.Module):
        def __init__(
            self, C, weight, bias, running_mean, running_var, scale, zero_point
        ):
            super(SimpleQuantizedBatchNorm, self).__init__()
            self.qconfig = my_qconfig
            self.batchnorm = nn.quantized.BatchNorm3d(C)
            self.batchnorm.scale = scale
            self.batchnorm.zero_point = zero_point
            self.batchnorm.weight = torch.nn.Parameter(weight)
            self.batchnorm.bias = torch.nn.Parameter(bias)
            self.batchnorm.running_mean = running_mean
            self.batchnorm.running_var = running_var
            self.relu = torch.nn.ReLU()
            self.dq = torch.nn.quantized.DeQuantize()

        def forward(self, x):
            return self.dq(self.relu(self.batchnorm(x)))

    C = 7
    in_scale = out_scale = 0.0047
    in_zero_point = out_zero_point = -7
    weight = torch.ones(C) + torch.rand(C) * 0.001
    bias = torch.rand(C) * 0.0001
    running_mean = torch.zeros(C)
    running_var = torch.ones(C)

    # 5-D (N, C, D, H, W) input, as required by BatchNorm3d.
    inputs = torch.randn(6, C, 4, 33, 42)
    inputs = torch.quantize_per_tensor(
        inputs, scale=in_scale, zero_point=in_zero_point, dtype=torch.qint8
    )
    model = SimpleQuantizedBatchNorm(
        C, weight, bias, running_mean, running_var, out_scale, out_zero_point
    )
    model.eval()

    # Snapshot the global fp16 flag so this test does not leak it.
    prev_fp16 = torch_glow.get_convert_to_fp16()
    torch_glow.enable_convert_to_fp16()
    try:
        jitVsGlow(model, inputs, expected_fused_ops={"quantized::batch_norm3d"})
    finally:
        if prev_fp16:
            torch_glow.enable_convert_to_fp16()
        else:
            torch_glow.disable_convert_to_fp16()
def test_serialization(self):
    """Round-trip a Glow-lowered module through ``torch.jit.save``/``load``.

    The module is compiled for the NNPI backend with serialization enabled,
    saved, and then reloaded with deserialization enabled. Outputs before and
    after the round-trip must match.

    The saved file goes into a private temporary directory rather than a fixed
    ``/tmp`` path, so concurrent runs cannot collide. Global fusion, fp16 and
    dump-serialized-model settings are restored on exit.
    """
    import os
    import tempfile

    prev_fusion = torch_glow.getFusionPassEnabled()
    prev_fp16 = torch_glow.get_convert_to_fp16()
    try:
        with torch.no_grad(), tempfile.TemporaryDirectory() as tmp_dir:
            model_path = os.path.join(tmp_dir, "serialize_to_glow.pt")

            x = torch.randn([1, 4, 4, 4], dtype=torch.float32)
            y = torch.randn([1, 4, 4, 4], dtype=torch.float32)
            model = Bar()
            model = torch.jit.trace(model, (x, y))

            spec = torch_glow.CompilationSpec()
            spec_settings = spec.get_settings()
            spec_settings.set_glow_backend("NNPI")
            # Enable serialization in this spec.
            spec_settings.set_enable_serialize(True)

            compilation_group = torch_glow.CompilationGroup()
            compilation_group_settings = compilation_group.get_settings()
            compilation_group_settings.set_replication_count(1)
            compilation_group_settings.backend_specific_opts_insert(
                "NNPI_IceCores", "1"
            )

            compilation_group.input_sets_append(
                torch_glow.input_specs_from_tensors([x, y])
            )
            spec.compilation_groups_append(compilation_group)

            torch_glow.disableFusionPass()
            torch_glow.enable_convert_to_fp16()

            # Enable global serialization, then compile (serialize) the model
            # and save it.
            torch_glow.enable_dump_serialized_model()
            glow_mod = torch_glow.to_glow(model, spec)
            res1 = glow_mod(x, y)
            torch.jit.save(glow_mod, model_path)

            # Enable global deserialization, disable serialization, and load
            # (deserialize) the saved model.
            torch_glow.enable_deserialize()
            torch_glow.disable_dump_serialized_model()
            loaded_glow_mod = torch.jit.load(model_path)
            res2 = loaded_glow_mod(x, y)
            assert torch.allclose(res1, res2, 1e-5, 1e-5)
    finally:
        # Undo global settings even if compilation/saving/loading failed.
        torch_glow.disable_dump_serialized_model()
        if prev_fusion:
            torch_glow.enableFusionPass()
        else:
            torch_glow.disableFusionPass()
        if prev_fp16:
            torch_glow.enable_convert_to_fp16()
        else:
            torch_glow.disable_convert_to_fp16()
        # TODO(review): enable_deserialize() is left on. Reset it here if
        # torch_glow exposes a matching disable call.
def test_batchnorm_basic(self):
    """
    Basic test of the PyTorch 3D batchnorm Node on Glow.

    A quantized BatchNorm3d using its default weight/bias and the given
    running stats is followed by ReLU and a dequantize. The global
    fp16-conversion flag toggled here is restored afterwards.
    """

    class SimpleQuantizedBatchNorm(nn.Module):
        def __init__(self, C, running_mean, running_var, scale, zero_point):
            super(SimpleQuantizedBatchNorm, self).__init__()
            self.qconfig = my_qconfig
            self.batchnorm = nn.quantized.BatchNorm3d(C)
            self.batchnorm.scale = scale
            self.batchnorm.zero_point = zero_point
            self.batchnorm.running_mean = running_mean
            self.batchnorm.running_var = running_var
            self.relu = torch.nn.ReLU()
            self.dq = torch.nn.quantized.DeQuantize()

        def forward(self, x):
            return self.dq(self.relu(self.batchnorm(x)))

    C = 4
    in_scale = out_scale = 0.004
    in_zero_point = out_zero_point = 4
    running_mean = torch.zeros(C)
    running_var = torch.ones(C)

    inputs = torch.randn((5, C, 6, 32, 73), requires_grad=False)
    inputs = torch.quantize_per_tensor(
        inputs, scale=in_scale, zero_point=in_zero_point, dtype=torch.qint8
    )
    model = SimpleQuantizedBatchNorm(
        C, running_mean, running_var, out_scale, out_zero_point
    )
    model.eval()

    # Snapshot the global fp16 flag so this test does not leak it.
    prev_fp16 = torch_glow.get_convert_to_fp16()
    torch_glow.enable_convert_to_fp16()
    try:
        jitVsGlow(model, inputs, expected_fused_ops={"quantized::batch_norm3d"})
    finally:
        if prev_fp16:
            torch_glow.enable_convert_to_fp16()
        else:
            torch_glow.disable_convert_to_fp16()
def traceVsGlow(
    f_torch,
    f_glow,
    check_trace,
    atol,
    rtol,
    *inputs,
    expected_fused_ops=None,
    accept_all_ops=False,
    black_list=None,
    use_fp16=False,
    backend_name=None,
):
    """Compare a plain TorchScript trace of ``f_torch`` with a Glow-fused trace of ``f_glow``.

    Args:
        f_torch: Callable traced with the fusion pass disabled (reference).
        f_glow: Callable traced with the fusion pass enabled (Glow run).
        check_trace: Forwarded to ``torch.jit.trace`` for both traces.
        atol: Absolute tolerance forwarded to ``checkResult``.
        rtol: Relative tolerance forwarded to ``checkResult``.
        *inputs: Example/positional inputs for both traces.
        expected_fused_ops: Forwarded to ``checkExpectedOps``.
        accept_all_ops: Forwarded to ``checkExpectedOps``.
        black_list: Passed to ``torch_glow.setFusionBlacklist``. Defaults to
            an empty list.
        use_fp16: If true, enable fp16 conversion, fused fp16 conversion and
            fp16 clipping for the Glow run. Otherwise disable all three.
        backend_name: Glow backend to use. ``"Interpreter"`` is used when falsy.

    Raises:
        AssertionError: If the reference graph contains any Glow nodes.

    Global torch_glow settings are reset in a ``finally`` block. A failure
    during tracing, running, or the no-Glow-nodes assertion therefore does not
    leak them into later tests.
    """
    if black_list is None:
        black_list = []
    with torch.no_grad():
        # Reference run: trace f_torch with the Glow fusion pass off.
        torch_glow.disableFusionPass()
        torch_trace = torch.jit.trace(f_torch, inputs, check_trace=check_trace)
        torch_res = torch_trace(*inputs)

        try:
            torch_glow.enableFusionPass()
            torch_glow.setFusionBlacklist(black_list)
            if use_fp16:
                torch_glow.enable_convert_to_fp16()
                torch_glow.enable_convert_fused_to_fp16()
                torch_glow.enable_clip_fp16()
            else:
                torch_glow.disable_convert_to_fp16()
                torch_glow.disable_convert_fused_to_fp16()
                torch_glow.disable_clip_fp16()
            if backend_name:
                torch_glow.setGlowBackend(backend_name)
            else:
                torch_glow.setGlowBackend("Interpreter")

            glow_trace = torch.jit.trace(f_glow, inputs, check_trace=check_trace)
            glow_res = glow_trace(*inputs)

            # check that there are no Glow nodes in the torch graph
            torch_graph = torch_trace.graph_for(*inputs)
            print("torch_graph,", torch_graph)
            num_glow_nodes = len(torch_graph.findAllNodes(GLOW_NODE_NAME))
            assert num_glow_nodes == 0, "Expected no Glow nodes, found {}".format(
                num_glow_nodes
            )

            glow_graph = glow_trace.graph_for(*inputs)
            print("glow_graph,", glow_graph)
        finally:
            # need to explicitly clear settings to avoid carry-over static settings
            torch_glow.disableFusionPass()
            torch_glow.disable_convert_to_fp16()
            torch_glow.disable_convert_fused_to_fp16()
            torch_glow.disable_clip_fp16()
            torch_glow.setGlowBackend("Interpreter")
            # The blacklist is also a static setting; don't let it outlive this call.
            torch_glow.clearFusionBlacklist()

    checkExpectedOps(glow_graph, expected_fused_ops, accept_all_ops)
    checkResult(torch_res, glow_res, atol, rtol)