def test_min_graph_size():
    """Test Glow fuser minimum fusion group size mechanism.

    Traces ``f`` (defined outside this function), runs the Glow custom fuser
    on its graph with ``aten::div`` blacklisted and a minimum fusion group
    size of 3, and asserts that exactly two Glow fusion nodes are produced.

    The fusion blacklist and the minimum group size are global fuser
    settings, so they are reset in a ``finally`` clause; a failing trace or
    assertion no longer leaks them.
    """
    torch_glow.disableFusionPass()

    try:
        # Disable aten::div so that each group of aten::mul nodes will be
        # forced into separate subgraphs
        torch_glow.setFusionBlacklist(["aten::div"])

        # Set minimum fusion group size to 3 nodes so that the smallest group
        # which contains only 2 nodes will not be created
        torch_glow.setMinFusionGroupSize(3)

        a = torch.randn(5, 5)
        b = torch.randn(5, 5)
        c = torch.randn(5, 5)

        jit_f = torch.jit.trace(f, (a, b, c))
        jit_f_graph = jit_f.graph_for(a, b, c)

        # Runs the fuser in place on the graph
        torch_glow.glowCustomFuseDebug_(jit_f_graph)

        fusion_nodes = sum(
            1 for node in jit_f_graph.nodes() if node.kind() == GLOW_NODE_NAME
        )
        assert fusion_nodes == 2, "Expected smallest fusion group to not be created"
    finally:
        # Always restore the global fuser settings touched above
        torch_glow.clearFusionBlacklist()
        torch_glow.setMinFusionGroupSize(0)
def test_max_fusion_merge_size(self):
    """Test Glow fuser maximum fusion merge size mechanism.

    Traces a chain of five multiplications, runs the Glow custom fuser with
    a maximum fusion merge size of 3, and asserts that more than one Glow
    fusion node is produced.

    The merge size is a global fuser setting, so it is reset to 0 in a
    ``finally`` clause even when tracing or the assertion fails.
    """

    def f(a):
        return a * a * a * a * a * a

    torch_glow.disableFusionPass()

    try:
        # Set maximum fusion merge size to 3 nodes so that the
        # graph will not fit into 1 node
        torch_glow.setMaxFusionMergeSize(3)

        a = torch.randn(5, 5)

        # The original passed ``(a)``, which is just ``a`` (not a 1-tuple)
        jit_f = torch.jit.trace(f, a)
        jit_f_graph = jit_f.graph_for(a)

        # Runs the fuser in place on the graph
        torch_glow.glowCustomFuseDebug_(jit_f_graph)

        fusion_nodes = sum(
            1 for node in jit_f_graph.nodes() if node.kind() == GLOW_NODE_NAME
        )
        assert fusion_nodes > 1, "Expected more than one fusion group to be created"
    finally:
        # Always restore the global merge-size setting
        torch_glow.setMaxFusionMergeSize(0)
def test_max_fusion_merge_size_zero(self):
    """Test Glow fuser maximum fusion merge size mechanism set to zero.

    With the merge size set to 0 the whole multiplication chain is expected
    to end up in a single fusion group.
    """

    def f(a):
        return a * a * a * a * a * a

    torch_glow.disableFusionPass()

    # A maximum fusion merge size of 0 means there is no limit to fusion
    torch_glow.setMaxFusionMergeSize(0)

    sample = torch.randn(5, 5)
    traced = torch.jit.trace(f, sample)
    graph = traced.graph_for(sample)

    # Runs the fuser in place on the graph
    torch_glow.glowCustomFuseDebug_(graph)

    fusion_groups = [
        node for node in graph.nodes() if node.kind() == GLOW_FUSION_GROUP
    ]
    group_count = len(fusion_groups)
    assert group_count == 1, "Expected just one fusion group to be created"

    torch_glow.setMaxFusionMergeSize(0)
def traceVsGlow(
    f_torch,
    f_glow,
    check_trace,
    atol,
    rtol,
    *inputs,
    expected_fused_ops=None,
    accept_all_ops=False,
    black_list=None
):
    """Compare a plain traced run of ``f_torch`` with a Glow-fused run of ``f_glow``.

    ``f_torch`` is traced and run with the fusion pass disabled; ``f_glow``
    is traced and run with the fusion pass enabled and ``black_list`` set as
    the fusion blacklist. The plain graph must contain no Glow nodes. The
    Glow graph and both results are then handed to ``checkExpectedOps`` and
    ``checkResult``.
    """
    blocked_ops = [] if black_list is None else black_list

    with torch.no_grad():
        # Reference run without Glow fusion
        torch_glow.disableFusionPass()
        reference_trace = torch.jit.trace(f_torch, inputs, check_trace=check_trace)
        torch_res = reference_trace(*inputs)

        # Run with Glow fusion
        torch_glow.enableFusionPass()
        torch_glow.setFusionBlacklist(blocked_ops)
        fused_trace = torch.jit.trace(f_glow, inputs, check_trace=check_trace)
        glow_res = fused_trace(*inputs)

        # The reference graph must not contain any Glow nodes
        reference_graph = reference_trace.graph_for(*inputs)
        print("torch_graph,", reference_graph)
        stray_glow_nodes = len(reference_graph.findAllNodes(GLOW_NODE_NAME))
        assert stray_glow_nodes == 0, "Expected no Glow nodes, found {}".format(
            stray_glow_nodes
        )

        glow_graph = fused_trace.graph_for(*inputs)
        print("glow_graph,", glow_graph)

        checkExpectedOps(glow_graph, expected_fused_ops, accept_all_ops)
        checkResult(torch_res, glow_res, atol, rtol)
def ephemeral_torchglow_settings(
    fp16=False, backend=DEFAULT_BACKEND, fusion=False, blocklist=None
):
    """Apply torch_glow settings around a single ``yield`` and restore them after.

    The current fp16 flags, backend name, fusion blacklist and fusion-pass
    state are captured first. The requested settings are then applied, control
    is yielded once, and the captured values are restored in ``finally``.

    NOTE(review): this is a bare generator in the visible source; it is
    presumably wrapped with a context-manager decorator outside this block.
    """
    # Snapshot the current global state so it can be restored afterwards
    saved_convert = torch_glow.get_convert_to_fp16()
    saved_clip = torch_glow.get_clip_fp16()
    saved_convert_fused = torch_glow.get_convert_fused_to_fp16()
    saved_backend = torch_glow.getGlowBackendName()
    saved_blocklist = torch_glow.getFusionBlacklist()
    saved_fusion = torch_glow.getFusionPassEnabled()

    try:
        if not fusion:
            torch_glow.disableFusionPass()
        else:
            torch_glow.enableFusionPass()

        if not fp16:
            torch_glow.disable_convert_to_fp16()
            torch_glow.disable_convert_fused_to_fp16()
            torch_glow.disable_clip_fp16()
        else:
            torch_glow.enable_convert_to_fp16()
            torch_glow.enable_convert_fused_to_fp16()
            torch_glow.enable_clip_fp16()

        if blocklist is not None:
            torch_glow.setFusionBlacklist(list(blocklist))
        else:
            torch_glow.clearFusionBlacklist()

        torch_glow.setGlowBackend(backend)
        yield
    finally:
        # Restore every captured setting
        if saved_convert:
            torch_glow.enable_convert_to_fp16()
        else:
            torch_glow.disable_convert_to_fp16()

        if saved_clip:
            torch_glow.enable_clip_fp16()
        else:
            torch_glow.disable_clip_fp16()

        if saved_convert_fused:
            torch_glow.enable_convert_fused_to_fp16()
        else:
            torch_glow.disable_convert_fused_to_fp16()

        if saved_fusion:
            torch_glow.enableFusionPass()
        else:
            torch_glow.disableFusionPass()

        torch_glow.setGlowBackend(saved_backend)
        torch_glow.setFusionBlacklist(saved_blocklist)
def test_only_tensor_outputs(self):
    """Test that Glow fuser only produces tensor outputs."""

    def f(a, b):
        x = (a + b).size(0)
        c = a.reshape(x, -1)
        return a + c

    torch_glow.disableFusionPass()

    lhs = torch.randn(5, 5)
    rhs = torch.randn(5, 5)

    traced = torch.jit.trace(f, (lhs, rhs))
    graph = traced.graph_for(lhs, rhs)

    # The graph has an aten::size (supported) feeding into an unsupported op
    # (prim::ListConstruct). Even though aten::size is supported, it would
    # produce a non-tensor output from the fusion group, so it must not be
    # fused.
    fusible_ops = ["prim::Constant", "aten::add", "aten::size", "aten::reshape"]
    torch_glow.glowCustomFuseDebug_(graph, fusible_ops)

    node_kinds = [node.kind() for node in graph.nodes()]
    fusion_nodes = node_kinds.count(GLOW_NODE_NAME)
    aten_sizes = node_kinds.count("aten::size")

    assert (
        fusion_nodes == 2
    ), "Expected two fusion nodes to be split up with aten::size between them"
    assert aten_sizes == 1, "Expected aten::size not to be fused"
def test_save_preprocessed_module(self):
    """Lower a traced ``Bar`` to Glow, save/reload it and inspect the stored module.

    After reloading, the processed module kept inside the lowered wrapper is
    re-wrapped as a TorchScript module, and its graph is asserted to contain
    a ``quantized::conv2d`` node.
    """
    with torch.no_grad():
        x = torch.randn([1, 4, 4, 4], dtype=torch.float32)

        model = Bar()
        model.eval()
        model = torch.jit.trace(model, x)

        spec = torch_glow.CompilationSpec()
        spec.get_settings().set_glow_backend("Interpreter")

        compilation_group = torch_glow.CompilationGroup()
        spec.compilation_groups_append(compilation_group)
        input_set = torch_glow.input_specs_from_tensors([x])
        compilation_group.input_sets_append(input_set)

        torch_glow.disableFusionPass()
        torch_glow.enable_convert_to_fp16()
        glow_mod = torch_glow.to_glow(model, spec)

        reloaded = utils.save_and_reload_model(glow_mod)

        # String-based getattr keeps the double-underscore names from being
        # rewritten by Python's private-name mangling inside a class body.
        lowered_wrapper = getattr(reloaded._c, "__loweredModule__")
        processed = getattr(lowered_wrapper, "__processed_module")
        pt_model = torch.jit._recursive.wrap_cpp_module(processed)

        graph = pt_model.graph_for(x)
        found = any(node.kind() == "quantized::conv2d" for node in graph.nodes())

        assert found
def jitVsGlow_(f_torch, f_glow, *inputs, expected_fused_ops=None, accept_all_ops=False):
    """Compare a plain traced run of ``f_torch`` with a Glow-fused run of ``f_glow``.

    Args:
        f_torch: Callable traced and run with the Glow fusion pass disabled.
        f_glow: Callable traced and run with the Glow fusion pass enabled.
        *inputs: Example inputs used both for tracing and for execution.
        expected_fused_ops: Iterable of op kinds that must appear inside Glow
            fusion groups and nowhere else in the Glow graph. ``None`` is
            treated as "no ops expected" instead of raising ``TypeError``.
        accept_all_ops: When true, the ``expected_fused_ops`` checks are
            skipped.

    Raises:
        AssertionError: If the plain graph contains Glow nodes, an expected
            op was left unfused or is missing, or the results differ.
    """
    # Normalise to a set: None becomes empty, and duplicate entries cannot
    # skew the "all expected ops were seen" comparison below.
    expected_ops = set(expected_fused_ops or ())

    with torch.no_grad():
        torch_glow.disableFusionPass()
        torch_trace = torch.jit.trace(f_torch, inputs)
        torch_res = torch_trace(*inputs)

        torch_glow.enableFusionPass()
        glow_trace = torch.jit.trace(f_glow, inputs)
        glow_res = glow_trace(*inputs)

        # check that there are no Glow nodes in the torch graph
        torch_graph = torch_trace.graph_for(*inputs)
        print("torch_graph,", torch_graph)

        num_glow_nodes = len(torch_graph.findAllNodes(GLOW_NODE_NAME))
        assert num_glow_nodes == 0, "Expected no Glow nodes, found {}".format(
            num_glow_nodes)

        glow_graph = glow_trace.graph_for(*inputs)
        print("glow_graph,", glow_graph)

        expected_fused_ops_seen = set()
        for node in glow_graph.nodes():
            kind = node.kind()
            if kind != GLOW_NODE_NAME:
                # A node outside any Glow fusion group must *not* be one of
                # the ops that were expected to be fused
                assert accept_all_ops or kind not in expected_ops, \
                    "Expected {} to be fused".format(kind)
                continue

            # Glow fusion group: record which expected ops its subgraph holds
            glow_group = node.g(SUBGRAPH_ATTR)
            for fused_node in glow_group.nodes():
                fused_node_kind = fused_node.kind()
                if accept_all_ops or fused_node_kind in expected_ops:
                    expected_fused_ops_seen.add(fused_node_kind)

        # Without accept_all_ops, the seen set only ever holds members of
        # expected_ops, so any difference means an expected op is absent
        # from the graph altogether.
        assert accept_all_ops or expected_ops == expected_fused_ops_seen, \
            "Expected all of expected_fused_ops to be in the graph"

        assert len(torch_res) == len(glow_res)
        for torch_out, glow_out in zip(torch_res, glow_res):
            assert torch.allclose(torch_out, glow_out, atol=1e-6)
def scriptVsGlow(
    f,
    atol,
    rtol,
    *inputs,
    expected_fused_ops=None,
    accept_all_ops=False,
    black_list=None,
    use_fp16=False,
    backend_name=None,
):
    """Compare eager execution of ``f`` with its scripted, Glow-fused execution.

    Args:
        f: Callable run eagerly for the reference result, then scripted with
            ``torch.jit.script`` and run with the Glow fusion pass enabled.
        atol, rtol: Tolerances forwarded to ``checkResult``.
        *inputs: Inputs passed to both executions.
        expected_fused_ops, accept_all_ops: Forwarded to ``checkExpectedOps``.
        black_list: Ops set as the fusion blacklist; ``None`` means empty.
        use_fp16: Enables the fp16 conversion/clipping flags when true,
            disables them otherwise.
        backend_name: Glow backend to use; falsy values select "Interpreter".

    The global fusion, fp16 and backend settings are reset in a ``finally``
    clause, so an exception while scripting or running ``f`` no longer
    leaves them set.
    """
    if black_list is None:
        black_list = []

    with torch.no_grad():
        torch_res = f(*inputs)

        try:
            torch_glow.enableFusionPass()
            torch_glow.setFusionBlacklist(black_list)

            if use_fp16:
                torch_glow.enable_convert_to_fp16()
                torch_glow.enable_convert_fused_to_fp16()
                torch_glow.enable_clip_fp16()
            else:
                torch_glow.disable_convert_to_fp16()
                torch_glow.disable_convert_fused_to_fp16()
                torch_glow.disable_clip_fp16()

            torch_glow.setGlowBackend(backend_name if backend_name else "Interpreter")

            glow_trace = torch.jit.script(f)
            glow_res = glow_trace(*inputs)

            glow_graph = glow_trace.graph_for(*inputs)
            print("glow_graph,", glow_graph)
        finally:
            # need to explicitly clear settings to avoid carry-over static
            # settings, including when the Glow run above raised
            torch_glow.disableFusionPass()
            torch_glow.disable_convert_to_fp16()
            torch_glow.disable_convert_fused_to_fp16()
            torch_glow.disable_clip_fp16()
            torch_glow.setGlowBackend("Interpreter")

        checkExpectedOps(glow_graph, expected_fused_ops, accept_all_ops)
        checkResult(torch_res, glow_res, atol, rtol)
def run_model(m, input, randomize):
    """Trace ``m``, lower it to Glow for the "Interpreter" backend and run it.

    ``randomize`` is stored in the compilation options as
    ``randomize_constants``. Returns the result of calling the lowered
    module's ``forward`` on ``input``.
    """
    torch_glow.disableFusionPass()
    traced = torch.jit.trace(m, input)

    # Describe the single input by mirroring the example tensor
    meta = InputMeta()
    meta.set_same_as(input)

    compile_options = CompilationOptions()
    compile_options.backend = "Interpreter"
    compile_options.randomize_constants = randomize

    compile_spec = GlowCompileSpec()
    compile_spec.set([meta], compile_options)

    lowered = torch_glow.to_glow(traced, {"forward": compile_spec})
    return lowered.forward(input)
def test_fuse_necessary_getattrs_only():
    """Run the Glow custom fuser on a traced ``Model`` with a restricted op list.

    Convolutions are deliberately left out of the fusible op list. Returns
    the eager output of the model on the sample input.
    """
    model = Model()
    sample = torch.randn(1, 3, 5, 5)

    torch_glow.disableFusionPass()

    traced = torch.jit.trace(model, sample)
    graph = traced.graph_for(sample)

    # don't fuse aten::_convolutions
    fusible_ops = [
        "prim::Constant",
        "prim::GetAttr",
        "aten::t",
        "aten::matmul",
        "aten::add_",
    ]
    torch_glow.glowCustomFuseDebug_(graph, fusible_ops)

    return model(sample)
def run_model(m, input, randomize):
    """Trace ``m``, lower it to Glow for the "Interpreter" backend and run it.

    The global randomize-constants flag is switched on or off according to
    ``randomize`` before tracing. Returns the result of calling the lowered
    module's ``forward`` on ``input``.
    """
    if not randomize:
        torch_glow.disable_randomize_constants()
    else:
        torch_glow.enable_randomize_constants()

    torch_glow.disableFusionPass()
    traced = torch.jit.trace(m, input)

    compile_spec = torch.classes.glow.GlowCompileSpec()
    compile_spec.setBackend("Interpreter")

    # Describe the single input by mirroring the example tensor
    input_meta = torch.classes.glow.SpecInputMeta()
    input_meta.setSameAs(input)
    compile_spec.addInputs([input_meta])

    lowered = torch_glow.to_glow(traced, {"forward": compile_spec})
    return lowered.forward(input)
def onnx_capture(filename_prefix=None, zip_mode=True, write_without_randomize=False):
    """Turn on torch_glow ONNX writing around a single ``yield``.

    Before yielding, the fusion pass is disabled and ONNX writing enabled.
    Write-without-randomize, zip mode and a file name prefix are applied
    only when requested. In ``finally`` all ONNX-writing flags are switched
    off and the file name prefix is reset to the empty string.

    NOTE(review): this is a bare generator in the visible source; it is
    presumably wrapped with a context-manager decorator outside this block.
    """
    try:
        torch_glow.disableFusionPass()
        torch_glow.enable_write_to_onnx()

        # Optional behaviours, applied only on request
        if write_without_randomize:
            torch_glow.enable_write_without_randomize()
        if zip_mode:
            torch_glow.enable_onnx_zip_mode()
        if filename_prefix is not None:
            torch_glow.set_onnx_file_name_prefix(filename_prefix)

        yield
    finally:
        # Unconditionally switch every ONNX-writing option back off
        torch_glow.disable_write_without_randomize()
        torch_glow.disable_write_to_onnx()
        torch_glow.disable_onnx_zip_mode()
        torch_glow.set_onnx_file_name_prefix("")
def test_serialization(self):
    """Check that a Glow-lowered module gives the same result after save/load.

    A traced ``Bar`` is lowered for the "NNPI" backend with serialization
    enabled, run, saved with ``torch.jit.save``, reloaded with global
    deserialization enabled, and run again. Both results must be close.

    The model is saved inside a fresh temporary directory rather than a
    fixed ``/tmp`` path, so the test works where ``/tmp`` does not exist,
    concurrent runs cannot clobber each other, and the file is removed
    afterwards.
    """
    import os
    import tempfile

    with torch.no_grad():
        x = torch.randn([1, 4, 4, 4], dtype=torch.float32)
        y = torch.randn([1, 4, 4, 4], dtype=torch.float32)

        model = Bar()
        model = torch.jit.trace(model, (x, y))

        spec = torch_glow.CompilationSpec()
        spec_settings = spec.get_settings()
        spec_settings.set_glow_backend("NNPI")
        # Enabled the serialize in this spec
        spec_settings.set_enable_serialize(True)

        compilation_group = torch_glow.CompilationGroup()
        compilation_group_settings = compilation_group.get_settings()
        compilation_group_settings.set_replication_count(1)
        compilation_group_settings.backend_specific_opts_insert(
            "NNPI_IceCores", "1")
        compilation_group.input_sets_append(
            torch_glow.input_specs_from_tensors([x, y]))

        spec.compilation_groups_append(compilation_group)
        torch_glow.disableFusionPass()
        torch_glow.enable_convert_to_fp16()

        # Enable global serialize
        # then compile(serialize) the model and save it
        torch_glow.enable_dump_serialized_model()
        glow_mod = torch_glow.to_glow(model, spec)
        res1 = glow_mod(x, y)

        with tempfile.TemporaryDirectory() as tmp_dir:
            model_path = os.path.join(tmp_dir, "serialize_to_glow.pt")
            torch.jit.save(glow_mod, model_path)

            # Enable global deserialize and disable serialize
            # and load(deserialize) the model to loaded_glow_mod
            torch_glow.enable_deserialize()
            torch_glow.disable_dump_serialized_model()
            loaded_glow_mod = torch.jit.load(model_path)
            # Run while the file still exists, in case loading is lazy
            res2 = loaded_glow_mod(x, y)

        # Positional tolerances are (rtol, atol)
        assert torch.allclose(res1, res2, 1e-5, 1e-5)
def run_model(m, input, randomize):
    """Trace ``m``, lower it to Glow for the "Interpreter" backend and run it.

    The global randomize-constants flag is switched on or off according to
    ``randomize`` after tracing. Returns the result of calling the lowered
    module on ``input``.
    """
    torch_glow.disableFusionPass()
    traced = torch.jit.trace(m, input)

    if not randomize:
        torch_glow.disable_randomize_constants()
    else:
        torch_glow.enable_randomize_constants()

    compile_spec = torch_glow.CompilationSpec()
    compile_spec.get_settings().set_glow_backend("Interpreter")

    group = torch_glow.CompilationGroup()
    compile_spec.compilation_groups_append(group)

    # Describe the single input by mirroring the example tensor
    mirrored_input = torch_glow.InputSpec()
    mirrored_input.set_same_as(input)
    group.input_sets_append([mirrored_input])

    lowered = torch_glow.to_glow(traced, {"forward": compile_spec})
    return lowered(input)
def jitVsGlow(f, *inputs):
    """Trace ``f`` with and without Glow fusion and compare the two runs.

    The graph traced with fusion disabled must contain no Glow nodes, the
    graph traced with fusion enabled must contain exactly one, and the two
    results must be close (``atol=1e-6``).
    """
    # Reference run without Glow fusion
    torch_glow.disableFusionPass()
    reference_trace = torch.jit.trace(f, inputs)
    torch_res = reference_trace(*inputs)

    # Run with Glow fusion
    torch_glow.enableFusionPass()
    fused_trace = torch.jit.trace(f, inputs)
    glow_res = fused_trace(*inputs)

    # The reference graph must not contain any Glow nodes
    reference_graph = reference_trace.graph_for(*inputs)
    print("torch_graph,", reference_graph)
    stray_glow_nodes = len(reference_graph.findAllNodes(GLOW_NODE_NAME))
    assert stray_glow_nodes == 0, "Expected no Glow nodes, found {}".format(
        stray_glow_nodes
    )

    # The fused graph must contain exactly one Glow node
    fused_graph = fused_trace.graph_for(*inputs)
    print("glow_graph,", fused_graph)
    fused_glow_nodes = len(fused_graph.findAllNodes(GLOW_NODE_NAME))
    assert fused_glow_nodes == 1, "Expected exactly 1 Glow node, found {}".format(
        fused_glow_nodes
    )

    assert torch.allclose(torch_res, glow_res, atol=1e-6)
def traceVsGlow(
    f_torch,
    f_glow,
    check_trace,
    atol,
    rtol,
    *inputs,
    expected_fused_ops=None,
    accept_all_ops=False,
    black_list=None,
    use_fp16=False,
    backend_name=None,
):
    """Compare a plain traced run of ``f_torch`` with a Glow-fused run of ``f_glow``.

    Args:
        f_torch: Callable traced and run with the Glow fusion pass disabled.
        f_glow: Callable traced and run with the Glow fusion pass enabled.
        check_trace: Forwarded to ``torch.jit.trace``.
        atol, rtol: Tolerances forwarded to ``checkResult``.
        *inputs: Example inputs used both for tracing and for execution.
        expected_fused_ops, accept_all_ops: Forwarded to ``checkExpectedOps``.
        black_list: Ops set as the fusion blacklist; ``None`` means empty.
        use_fp16: Enables the fp16 conversion/clipping flags when true,
            disables them otherwise.
        backend_name: Glow backend to use; falsy values select "Interpreter".

    The global fusion, fp16 and backend settings are reset in a ``finally``
    clause, so an exception while tracing or running no longer leaves them
    set.
    """
    if black_list is None:
        black_list = []

    with torch.no_grad():
        torch_glow.disableFusionPass()

        torch_trace = torch.jit.trace(f_torch, inputs, check_trace=check_trace)
        torch_res = torch_trace(*inputs)

        try:
            torch_glow.enableFusionPass()
            torch_glow.setFusionBlacklist(black_list)

            if use_fp16:
                torch_glow.enable_convert_to_fp16()
                torch_glow.enable_convert_fused_to_fp16()
                torch_glow.enable_clip_fp16()
            else:
                torch_glow.disable_convert_to_fp16()
                torch_glow.disable_convert_fused_to_fp16()
                torch_glow.disable_clip_fp16()

            torch_glow.setGlowBackend(backend_name if backend_name else "Interpreter")

            glow_trace = torch.jit.trace(f_glow, inputs, check_trace=check_trace)
            glow_res = glow_trace(*inputs)

            # check that there are no Glow nodes in the torch graph
            torch_graph = torch_trace.graph_for(*inputs)
            print("torch_graph,", torch_graph)

            num_glow_nodes = len(torch_graph.findAllNodes(GLOW_NODE_NAME))
            assert num_glow_nodes == 0, "Expected no Glow nodes, found {}".format(
                num_glow_nodes)

            glow_graph = glow_trace.graph_for(*inputs)
            print("glow_graph,", glow_graph)
        finally:
            # need to explicitly clear settings to avoid carry-over static
            # settings, including when the Glow run above raised
            torch_glow.disableFusionPass()
            torch_glow.disable_convert_to_fp16()
            torch_glow.disable_convert_fused_to_fp16()
            torch_glow.disable_clip_fp16()
            torch_glow.setGlowBackend("Interpreter")

        checkExpectedOps(glow_graph, expected_fused_ops, accept_all_ops)
        checkResult(torch_res, glow_res, atol, rtol)
def jitVsGlow_(f_torch, f_glow, check_trace, atol, rtol, *inputs,
               expected_fused_ops=None, accept_all_ops=False, black_list=None):
    """Compare a plain traced run of ``f_torch`` with a Glow-fused run of ``f_glow``.

    Args:
        f_torch: Callable traced and run with the Glow fusion pass disabled.
        f_glow: Callable traced and run with the Glow fusion pass enabled.
        check_trace: Forwarded to ``torch.jit.trace``.
        atol, rtol: Tolerances used with ``torch.allclose``.
        *inputs: Example inputs used both for tracing and for execution.
        expected_fused_ops: Iterable of op kinds that must appear inside Glow
            fusion groups and nowhere else in the Glow graph. ``None`` is
            treated as "no ops expected" instead of raising ``TypeError``.
        accept_all_ops: When true, the ``expected_fused_ops`` checks are
            skipped.
        black_list: Ops set as the fusion blacklist; ``None`` means empty.

    Raises:
        AssertionError: If the plain graph contains Glow nodes, nothing was
            fused, an expected op was left unfused or is missing, or the
            results differ beyond the given tolerances.
    """
    if black_list is None:
        black_list = []

    # Normalise to a set: None becomes empty, and duplicate entries cannot
    # skew the "all expected ops were seen" comparison below.
    expected_ops = set(expected_fused_ops or ())

    with torch.no_grad():
        torch_glow.disableFusionPass()

        torch_trace = torch.jit.trace(f_torch, inputs, check_trace=check_trace)
        torch_res = torch_trace(*inputs)

        torch_glow.enableFusionPass()
        torch_glow.setFusionBlacklist(black_list)
        glow_trace = torch.jit.trace(f_glow, inputs, check_trace=check_trace)
        glow_res = glow_trace(*inputs)

        # check that there are no Glow nodes in the torch graph
        torch_graph = torch_trace.graph_for(*inputs)
        print("torch_graph,", torch_graph)

        num_glow_nodes = len(torch_graph.findAllNodes(GLOW_NODE_NAME))
        assert num_glow_nodes == 0, "Expected no Glow nodes, found {}".format(
            num_glow_nodes)

        glow_graph = glow_trace.graph_for(*inputs)
        print("glow_graph,", glow_graph)

        expected_fused_ops_seen = set()

        # Whether or not at least one node was fused to Glow.
        nodes_were_fused = False

        for node in glow_graph.nodes():
            kind = node.kind()
            if kind != GLOW_NODE_NAME:
                # A node outside any Glow fusion group must *not* be one of
                # the ops that were expected to be fused
                assert accept_all_ops or kind not in expected_ops, \
                    "Expected {} to be fused".format(kind)
                continue

            # Glow fusion group: record which expected ops its subgraph holds
            glow_group = node.g(SUBGRAPH_ATTR)
            for fused_node in glow_group.nodes():
                nodes_were_fused = True
                fused_node_kind = fused_node.kind()
                if accept_all_ops or fused_node_kind in expected_ops:
                    expected_fused_ops_seen.add(fused_node_kind)

        assert nodes_were_fused, "Expected some nodes to be fused to Glow"

        # Without accept_all_ops, the seen set only ever holds members of
        # expected_ops, so any difference means an expected op is absent
        # from the graph altogether.
        assert accept_all_ops or expected_ops == expected_fused_ops_seen, \
            "Expected all of expected_fused_ops to be in the graph"

        if isinstance(torch_res, tuple) or isinstance(glow_res, tuple):
            # Tuple outputs: both sides must be tuples of equal length and
            # are compared element by element
            assert isinstance(torch_res, tuple) and isinstance(glow_res, tuple)
            assert len(torch_res) == len(glow_res)
            for torch_out, glow_out in zip(torch_res, glow_res):
                print("torch shape: {}".format(torch_out.shape), file=sys.stderr)
                print("glow shape: {}".format(glow_out.shape), file=sys.stderr)
                assert torch.allclose(torch_out, glow_out, atol=atol, rtol=rtol)
        else:
            print("torch shape: {}".format(torch_res.shape), file=sys.stderr)
            print("glow shape: {}".format(glow_res.shape), file=sys.stderr)
            is_all_close = torch.allclose(torch_res, glow_res, atol=atol, rtol=rtol)
            if not is_all_close:
                # Dump both results and their difference to aid debugging
                print("torch_res\n", torch_res)
                print("glow_res\n", glow_res)
                print("diff\n", torch.abs(glow_res - torch_res))
            assert is_all_close