Code Example #1
def test_min_graph_size():
    """Test Glow fuser minimum fusion group size mechanism."""

    # The original snippet omits f; this definition is an assumption chosen so
    # that the comments below hold: chains of aten::mul (3, 3, and 2 nodes)
    # separated by aten::div.
    def f(a, b, c):
        return (a * a * a * a) / (b * b * b * b) / (c * c * c)

    torch_glow.disableFusionPass()

    # Disable aten::div so that each group of aten::mul nodes will be forced
    # into separate subgraphs
    torch_glow.setFusionBlacklist(["aten::div"])

    # Set minimum fusion group size to 3 nodes so that the smallest group which
    # contains only 2 nodes will not be created
    torch_glow.setMinFusionGroupSize(3)

    a = torch.randn(5, 5)
    b = torch.randn(5, 5)
    c = torch.randn(5, 5)

    jit_f = torch.jit.trace(f, (a, b, c))
    jit_f_graph = jit_f.graph_for(a, b, c)

    # print("before: ", jit_f_graph)

    torch_glow.glowCustomFuseDebug_(jit_f_graph)

    # print("after: ", jit_f_graph)

    fusion_nodes = 0
    for node in jit_f_graph.nodes():
        if node.kind() == GLOW_NODE_NAME:
            fusion_nodes += 1

    assert fusion_nodes == 2, "Expected smallest fusion group to not be created"

    torch_glow.clearFusionBlacklist()
    torch_glow.setMinFusionGroupSize(0)
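
Note: the snippets on this page are excerpts and assume a common preamble from torch_glow's test utilities. A minimal sketch of that preamble follows; the constant values match the torch_glow sources as far as I know, but treat them as assumptions rather than a verified API.

import sys
from contextlib import contextmanager

import torch
import torch_glow

# Kind string of a Glow fusion group node, and the attribute holding its
# subgraph. GLOW_FUSION_GROUP is the newer alias some snippets use.
GLOW_NODE_NAME = "glow::FusionGroup"
GLOW_FUSION_GROUP = GLOW_NODE_NAME
SUBGRAPH_ATTR = "Subgraph"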
Code Example #2
    def test_max_fusion_merge_size(self):
        """Test Glow fuser maximum fusion merge size mechanism."""
        def f(a):
            return a * a * a * a * a * a

        torch_glow.disableFusionPass()

        # Set maximum fusion merge size to 3 nodes so that the
        # graph will not fit into 1 node
        torch_glow.setMaxFusionMergeSize(3)

        a = torch.randn(5, 5)

        jit_f = torch.jit.trace(f, (a,))
        jit_f_graph = jit_f.graph_for(a)

        # print("before: ", jit_f_graph)

        torch_glow.glowCustomFuseDebug_(jit_f_graph)

        # print("after: ", jit_f_graph)

        fusion_nodes = 0
        for node in jit_f_graph.nodes():
            if node.kind() == GLOW_NODE_NAME:
                fusion_nodes += 1

        assert fusion_nodes > 1, "Expected more than one fusion group to be created"

        torch_glow.setMaxFusionMergeSize(0)
Code Example #3
    def test_max_fusion_merge_size_zero(self):
        """Test Glow fuser maximum fusion merge size mechanism set to zero."""
        def f(a):
            return a * a * a * a * a * a

        torch_glow.disableFusionPass()

        # Set maximum fusion merge size to 0 so that there is
        # no limit to fusion
        torch_glow.setMaxFusionMergeSize(0)

        a = torch.randn(5, 5)

        jit_f = torch.jit.trace(f, (a,))
        jit_f_graph = jit_f.graph_for(a)

        # print("before: ", jit_f_graph)

        torch_glow.glowCustomFuseDebug_(jit_f_graph)

        # print("after: ", jit_f_graph)

        fusion_nodes = 0
        for node in jit_f_graph.nodes():
            if node.kind() == GLOW_FUSION_GROUP:
                fusion_nodes += 1

        assert fusion_nodes == 1, "Expected just one fusion group to be created"

        torch_glow.setMaxFusionMergeSize(0)
Code Example #4
File: utils.py Project: zhangnju/glow
def traceVsGlow(f_torch,
                f_glow,
                check_trace,
                atol,
                rtol,
                *inputs,
                expected_fused_ops=None,
                accept_all_ops=False,
                black_list=None):
    if black_list is None:
        black_list = []
    with torch.no_grad():
        torch_glow.disableFusionPass()

        torch_trace = torch.jit.trace(f_torch, inputs, check_trace=check_trace)
        torch_res = torch_trace(*inputs)

        torch_glow.enableFusionPass()
        torch_glow.setFusionBlacklist(black_list)
        glow_trace = torch.jit.trace(f_glow, inputs, check_trace=check_trace)
        glow_res = glow_trace(*inputs)

        # check that there are no Glow nodes in the torch graph
        torch_graph = torch_trace.graph_for(*inputs)
        print("torch_graph,", torch_graph)

        num_glow_nodes = len(torch_graph.findAllNodes(GLOW_NODE_NAME))
        assert num_glow_nodes == 0, "Expected no Glow nodes, found {}".format(
            num_glow_nodes)

        glow_graph = glow_trace.graph_for(*inputs)
        print("glow_graph,", glow_graph)

    checkExpectedOps(glow_graph, expected_fused_ops, accept_all_ops)
    checkResult(torch_res, glow_res, atol, rtol)
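
For illustration only (not part of the original page), traceVsGlow might be called like this; my_add and the tolerances are invented for the example:

def my_add(a, b):
    return a + b

# compare the eager/JIT result against the Glow-fused result
traceVsGlow(my_add, my_add, True, 1e-6, 1e-6,
            torch.randn(3, 3), torch.randn(3, 3),
            expected_fused_ops={"aten::add"})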
Code Example #5
@contextmanager  # from contextlib; the upstream source decorates this generator
def ephemeral_torchglow_settings(
    fp16=False, backend=DEFAULT_BACKEND, fusion=False, blocklist=None
):
    old_fp16 = torch_glow.get_convert_to_fp16()
    old_clip = torch_glow.get_clip_fp16()
    old_convert_fused = torch_glow.get_convert_fused_to_fp16()
    old_backend = torch_glow.getGlowBackendName()
    old_blocklist = torch_glow.getFusionBlacklist()
    old_fusion = torch_glow.getFusionPassEnabled()
    try:
        if fusion:
            torch_glow.enableFusionPass()
        else:
            torch_glow.disableFusionPass()
        if fp16:
            torch_glow.enable_convert_to_fp16()
            torch_glow.enable_convert_fused_to_fp16()
            torch_glow.enable_clip_fp16()
        else:
            torch_glow.disable_convert_to_fp16()
            torch_glow.disable_convert_fused_to_fp16()
            torch_glow.disable_clip_fp16()
        if blocklist is None:
            torch_glow.clearFusionBlacklist()
        else:
            torch_glow.setFusionBlacklist(list(blocklist))
        torch_glow.setGlowBackend(backend)
        yield
    finally:
        torch_glow.enable_convert_to_fp16() if old_fp16 else torch_glow.disable_convert_to_fp16()
        torch_glow.enable_clip_fp16() if old_clip else torch_glow.disable_clip_fp16()
        torch_glow.enable_convert_fused_to_fp16() if old_convert_fused else torch_glow.disable_convert_fused_to_fp16()
        torch_glow.enableFusionPass() if old_fusion else torch_glow.disableFusionPass()
        torch_glow.setGlowBackend(old_backend)
        torch_glow.setFusionBlacklist(old_blocklist)
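
A sketch of how this context manager might be used; model and example_input are hypothetical names:

with ephemeral_torchglow_settings(fp16=True, fusion=True):
    # settings apply only inside the with-block and are restored afterwards
    traced = torch.jit.trace(model, (example_input,))
    out = traced(example_input)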
Code Example #6
    def test_only_tensor_outputs(self):
        """Test that Glow fuser only produces tensor outputs."""
        def f(a, b):
            x = (a + b).size(0)
            c = a.reshape(x, -1)
            return a + c

        torch_glow.disableFusionPass()

        a = torch.randn(5, 5)
        b = torch.randn(5, 5)

        jit_f = torch.jit.trace(f, (a, b))
        jit_f_graph = jit_f.graph_for(a, b)

        # By creating a graph with an aten::size (supported) feeding into an
        # unsupported op (prim::ListConstruct), we see that even a supported op
        # will not be fused if it produces a non-tensor output for the fusion
        # group.
        torch_glow.glowCustomFuseDebug_(
            jit_f_graph,
            ["prim::Constant", "aten::add", "aten::size", "aten::reshape"])

        fusion_nodes = 0
        aten_sizes = 0
        for node in jit_f_graph.nodes():
            if node.kind() == GLOW_NODE_NAME:
                fusion_nodes += 1
            if node.kind() == "aten::size":
                aten_sizes += 1

        assert (
            fusion_nodes == 2
        ), "Expected two fusion nodes to be split up with aten::size between them"
        assert aten_sizes == 1, "Expected aten::size not to be fused"
Code Example #7
    def test_save_preprocessed_module(self):
        with torch.no_grad():
            x = torch.randn([1, 4, 4, 4], dtype=torch.float32)
            model = Bar()
            model.eval()
            model = torch.jit.trace(model, x)

            spec = torch_glow.CompilationSpec()
            spec.get_settings().set_glow_backend("Interpreter")

            compilation_group = torch_glow.CompilationGroup()
            spec.compilation_groups_append(compilation_group)

            compilation_group.input_sets_append(
                torch_glow.input_specs_from_tensors([x]))

            torch_glow.disableFusionPass()
            torch_glow.enable_convert_to_fp16()
            glow_mod = torch_glow.to_glow(model, spec)

            reloaded = utils.save_and_reload_model(glow_mod)

            wrappername = "__loweredModule__"
            attrname = "__processed_module"
            wp = getattr(reloaded._c, wrappername)
            pp = getattr(wp, attrname)
            pt_model = torch.jit._recursive.wrap_cpp_module(pp)
            graph = pt_model.graph_for(x)
            found = False
            for node in graph.nodes():
                if node.kind() == "quantized::conv2d":
                    found = True

            assert found
Code Example #8
def jitVsGlow_(f_torch,
               f_glow,
               *inputs,
               expected_fused_ops=None,
               accept_all_ops=False):

    with torch.no_grad():
        torch_glow.disableFusionPass()
        torch_trace = torch.jit.trace(f_torch, inputs)
        torch_res = torch_trace(*inputs)

        torch_glow.enableFusionPass()
        glow_trace = torch.jit.trace(f_glow, inputs)
        glow_res = glow_trace(*inputs)

        # check that there are no Glow nodes in the torch graph
        torch_graph = torch_trace.graph_for(*inputs)
        print("torch_graph,", torch_graph)

        num_glow_nodes = len(torch_graph.findAllNodes(GLOW_NODE_NAME))
        assert num_glow_nodes == 0, "Expected no Glow nodes, found {}".format(
            num_glow_nodes)

        glow_graph = glow_trace.graph_for(*inputs)
        print("glow_graph,", glow_graph)

        expected_fused_ops_seen = set()

        # Check that ops that were *not* fused are *not* in expected_fused_ops
        for node in glow_graph.nodes():
            kind = node.kind()
            if kind != GLOW_NODE_NAME:
                # If the node is not a Glow fusion group, check that it is
                # *not* in expected_fused_ops
                assert accept_all_ops or kind not in expected_fused_ops, \
                    "Expected {} to be fused".format(kind)
            else:
                # If the node is a Glow fusion group, record which ops from
                # expected_fused_ops were in it

                # Get the definition of the fusion group
                glow_group = node.g(SUBGRAPH_ATTR)

                # Put all nodes that are in the group and in expected_fused_ops
                # into expected_fused_ops_seen
                for fused_node in glow_group.nodes():
                    fused_node_kind = fused_node.kind()

                    if accept_all_ops or fused_node_kind in expected_fused_ops:
                        expected_fused_ops_seen.add(fused_node_kind)

        # If the sizes of expected_fused_ops and expected_fused_ops_seen are
        # different, some ops in expected_fused_ops are not in the graph at all
        assert accept_all_ops or len(expected_fused_ops) == len(expected_fused_ops_seen), \
            "Expected all of expected_fused_ops to be in the graph"
        assert len(torch_res) == len(glow_res)
        for i in range(len(torch_res)):
            assert torch.allclose(torch_res[i], glow_res[i], atol=1e-6)
Code Example #9
File: utils.py Project: wwwyiwenchen/glow
def scriptVsGlow(
    f,
    atol,
    rtol,
    *inputs,
    expected_fused_ops=None,
    accept_all_ops=False,
    black_list=None,
    use_fp16=False,
    backend_name=None,
):
    if black_list is None:
        black_list = []
    with torch.no_grad():

        torch_res = f(*inputs)

        torch_glow.enableFusionPass()
        torch_glow.setFusionBlacklist(black_list)

        if use_fp16:
            torch_glow.enable_convert_to_fp16()
            torch_glow.enable_convert_fused_to_fp16()
            torch_glow.enable_clip_fp16()
        else:
            torch_glow.disable_convert_to_fp16()
            torch_glow.disable_convert_fused_to_fp16()
            torch_glow.disable_clip_fp16()

        if backend_name:
            torch_glow.setGlowBackend(backend_name)
        else:
            torch_glow.setGlowBackend("Interpreter")

        glow_trace = torch.jit.script(f)
        glow_res = glow_trace(*inputs)

        glow_graph = glow_trace.graph_for(*inputs)
        print("glow_graph,", glow_graph)

        # explicitly clear settings to avoid carrying over static settings
        torch_glow.disableFusionPass()
        torch_glow.disable_convert_to_fp16()
        torch_glow.disable_convert_fused_to_fp16()
        torch_glow.disable_clip_fp16()
        torch_glow.setGlowBackend("Interpreter")

    checkExpectedOps(glow_graph, expected_fused_ops, accept_all_ops)
    checkResult(torch_res, glow_res, atol, rtol)
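
As a hypothetical usage example (my_mul is invented), note that scriptVsGlow scripts the function itself, so a plain Python function suffices:

def my_mul(a, b):
    return a * b

scriptVsGlow(my_mul, 1e-6, 1e-6,
             torch.randn(4), torch.randn(4),
             expected_fused_ops={"aten::mul"})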
Code Example #10
def run_model(m, input, randomize):
    torch_glow.disableFusionPass()
    traced_m = torch.jit.trace(m, input)

    input_meta = InputMeta()
    input_meta.set_same_as(input)
    inputs = [input_meta]
    options = CompilationOptions()
    options.backend = "Interpreter"
    options.randomize_constants = randomize
    spec = GlowCompileSpec()
    spec.set(inputs, options)

    glow_m = torch_glow.to_glow(traced_m, {"forward": spec})
    return glow_m.forward(input)
Code Example #11
def test_fuse_necessary_getattrs_only():
    m = Model()
    x = torch.randn(1, 3, 5, 5)

    torch_glow.disableFusionPass()

    jit_m = torch.jit.trace(m, x)
    jit_m_graph = jit_m.graph_for(x)

    # don't fuse aten::_convolution ops
    torch_glow.glowCustomFuseDebug_(
        jit_m_graph,
        ["prim::Constant", "prim::GetAttr", "aten::t", "aten::matmul", "aten::add_"],
    )

    return m(x)
Code Example #12
def run_model(m, input, randomize):
    if randomize:
        torch_glow.enable_randomize_constants()
    else:
        torch_glow.disable_randomize_constants()

    torch_glow.disableFusionPass()
    traced_m = torch.jit.trace(m, input)

    spec = torch.classes.glow.GlowCompileSpec()
    spec.setBackend("Interpreter")
    sim = torch.classes.glow.SpecInputMeta()
    sim.setSameAs(input)
    spec.addInputs([sim])

    glow_m = torch_glow.to_glow(traced_m, {"forward": spec})
    return glow_m.forward(input)
Code Example #13
File: to_glow.py Project: tarunbansal11/glow
@contextmanager  # from contextlib; the upstream source decorates this generator
def onnx_capture(filename_prefix=None,
                 zip_mode=True,
                 write_without_randomize=False):
    try:
        torch_glow.disableFusionPass()
        torch_glow.enable_write_to_onnx()
        if write_without_randomize:
            torch_glow.enable_write_without_randomize()
        if zip_mode:
            torch_glow.enable_onnx_zip_mode()
        if filename_prefix is not None:
            torch_glow.set_onnx_file_name_prefix(filename_prefix)
        yield
    finally:
        torch_glow.disable_write_without_randomize()
        torch_glow.disable_write_to_onnx()
        torch_glow.disable_onnx_zip_mode()
        torch_glow.set_onnx_file_name_prefix("")
Code Example #14
    def test_serialization(self):
        with torch.no_grad():
            x = torch.randn([1, 4, 4, 4], dtype=torch.float32)
            y = torch.randn([1, 4, 4, 4], dtype=torch.float32)
            model = Bar()
            model = torch.jit.trace(model, (x, y))

            spec = torch_glow.CompilationSpec()
            spec_settings = spec.get_settings()
            spec_settings.set_glow_backend("NNPI")
            # Enable serialization in this spec
            spec_settings.set_enable_serialize(True)

            compilation_group = torch_glow.CompilationGroup()
            compilation_group_settings = compilation_group.get_settings()
            compilation_group_settings.set_replication_count(1)
            compilation_group_settings.backend_specific_opts_insert(
                "NNPI_IceCores", "1")

            compilation_group.input_sets_append(
                torch_glow.input_specs_from_tensors([x, y]))

            spec.compilation_groups_append(compilation_group)
            torch_glow.disableFusionPass()
            torch_glow.enable_convert_to_fp16()

            # Enable global serialize
            # then compile(serialize) the model and save it
            torch_glow.enable_dump_serialized_model()
            glow_mod = torch_glow.to_glow(model, spec)
            res1 = glow_mod(x, y)
            torch.jit.save(glow_mod, "/tmp/serialize_to_glow.pt")

            # Enable global deserialize and disable serialize
            # and load(deserialize) the model to loaded_glow_mod
            torch_glow.enable_deserialize()
            torch_glow.disable_dump_serialized_model()
            loaded_glow_mod = torch.jit.load("/tmp/serialize_to_glow.pt")
            res2 = loaded_glow_mod(x, y)
            assert torch.allclose(res1, res2, 1e-5, 1e-5)
Code Example #15
def run_model(m, input, randomize):
    torch_glow.disableFusionPass()
    traced_m = torch.jit.trace(m, input)

    if randomize:
        torch_glow.enable_randomize_constants()
    else:
        torch_glow.disable_randomize_constants()

    spec = torch_glow.CompilationSpec()
    spec.get_settings().set_glow_backend("Interpreter")

    compilation_group = torch_glow.CompilationGroup()
    spec.compilation_groups_append(compilation_group)

    input_spec = torch_glow.InputSpec()
    input_spec.set_same_as(input)

    compilation_group.input_sets_append([input_spec])

    glow_m = torch_glow.to_glow(traced_m, {"forward": spec})
    return glow_m(input)
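
A hypothetical call (Model and the input shape are invented for illustration):

out = run_model(Model(), torch.randn(1, 4), randomize=True)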
Code Example #16
File: utils.py Project: tarunkp/glow
def jitVsGlow(f, *inputs):
    torch_glow.disableFusionPass()
    torch_trace = torch.jit.trace(f, inputs)
    torch_res = torch_trace(*inputs)

    torch_glow.enableFusionPass()
    glow_trace = torch.jit.trace(f, inputs)
    glow_res = glow_trace(*inputs)

    # check that there are no Glow nodes in the torch graph
    torch_graph = torch_trace.graph_for(*inputs)
    print("torch_graph,", torch_graph)
    num_glow_nodes = len(torch_graph.findAllNodes(GLOW_NODE_NAME))
    assert num_glow_nodes == 0, "Expected no Glow nodes, found {}".format(
        num_glow_nodes)

    # check that there is exactly 1 Glow node in the glow graph
    glow_graph = glow_trace.graph_for(*inputs)
    print("glow_graph,", glow_graph)
    num_glow_nodes = len(glow_graph.findAllNodes(GLOW_NODE_NAME))
    assert num_glow_nodes == 1, "Expected exactly 1 Glow node, found {}".format(
        num_glow_nodes)

    assert torch.allclose(torch_res, glow_res, atol=1e-6)
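
A hypothetical call; my_relu is invented for the example and should fuse into a single Glow group:

def my_relu(x):
    return torch.relu(x)

jitVsGlow(my_relu, torch.randn(5))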
Code Example #17
File: utils.py Project: wwwyiwenchen/glow
def traceVsGlow(
    f_torch,
    f_glow,
    check_trace,
    atol,
    rtol,
    *inputs,
    expected_fused_ops=None,
    accept_all_ops=False,
    black_list=None,
    use_fp16=False,
    backend_name=None,
):
    if black_list is None:
        black_list = []
    with torch.no_grad():
        torch_glow.disableFusionPass()

        torch_trace = torch.jit.trace(f_torch, inputs, check_trace=check_trace)
        torch_res = torch_trace(*inputs)

        torch_glow.enableFusionPass()
        torch_glow.setFusionBlacklist(black_list)

        if use_fp16:
            torch_glow.enable_convert_to_fp16()
            torch_glow.enable_convert_fused_to_fp16()
            torch_glow.enable_clip_fp16()
        else:
            torch_glow.disable_convert_to_fp16()
            torch_glow.disable_convert_fused_to_fp16()
            torch_glow.disable_clip_fp16()

        if backend_name:
            torch_glow.setGlowBackend(backend_name)
        else:
            torch_glow.setGlowBackend("Interpreter")

        glow_trace = torch.jit.trace(f_glow, inputs, check_trace=check_trace)
        glow_res = glow_trace(*inputs)

        # check that there are no Glow nodes in the torch graph
        torch_graph = torch_trace.graph_for(*inputs)
        print("torch_graph,", torch_graph)

        num_glow_nodes = len(torch_graph.findAllNodes(GLOW_NODE_NAME))
        assert num_glow_nodes == 0, "Expected no Glow nodes, found {}".format(
            num_glow_nodes)

        glow_graph = glow_trace.graph_for(*inputs)
        print("glow_graph,", glow_graph)

        # explicitly clear settings to avoid carrying over static settings
        torch_glow.disableFusionPass()
        torch_glow.disable_convert_to_fp16()
        torch_glow.disable_convert_fused_to_fp16()
        torch_glow.disable_clip_fp16()
        torch_glow.setGlowBackend("Interpreter")

    checkExpectedOps(glow_graph, expected_fused_ops, accept_all_ops)
    checkResult(torch_res, glow_res, atol, rtol)
Code Example #18
def jitVsGlow_(f_torch,
               f_glow,
               check_trace,
               atol,
               rtol,
               *inputs,
               expected_fused_ops=None,
               accept_all_ops=False,
               black_list=None):
    if black_list is None:
        black_list = []
    with torch.no_grad():
        torch_glow.disableFusionPass()
        torch_trace = torch.jit.trace(f_torch, inputs, check_trace=check_trace)
        torch_res = torch_trace(*inputs)

        torch_glow.enableFusionPass()
        torch_glow.setFusionBlacklist(black_list)
        glow_trace = torch.jit.trace(f_glow, inputs, check_trace=check_trace)
        glow_res = glow_trace(*inputs)

        # check that there are no Glow nodes in the torch graph
        torch_graph = torch_trace.graph_for(*inputs)
        print("torch_graph,", torch_graph)

        num_glow_nodes = len(torch_graph.findAllNodes(GLOW_NODE_NAME))
        assert num_glow_nodes == 0, "Expected no Glow nodes, found {}".format(
            num_glow_nodes)

        glow_graph = glow_trace.graph_for(*inputs)
        print("glow_graph,", glow_graph)

        expected_fused_ops_seen = set()

        # Whether or not at least one node was fused to Glow.
        nodes_were_fused = False

        # Check that ops that were *not* fused are *not* in expected_fused_ops
        for node in glow_graph.nodes():
            kind = node.kind()
            if kind != GLOW_NODE_NAME:
                # If the node is not a Glow fusion group, check that it is
                # *not* in expected_fused_ops
                assert accept_all_ops or kind not in expected_fused_ops, \
                    "Expected {} to be fused".format(kind)
            else:
                # If the node is a Glow fusion group, record which ops from
                # expected_fused_ops were in it

                # Get the definition of the fusion group
                glow_group = node.g(SUBGRAPH_ATTR)

                # Put all nodes that are in the group and in expected_fused_ops
                # into expected_fused_ops_seen
                for fused_node in glow_group.nodes():
                    nodes_were_fused = True
                    fused_node_kind = fused_node.kind()

                    if accept_all_ops or fused_node_kind in expected_fused_ops:
                        expected_fused_ops_seen.add(fused_node_kind)

        assert nodes_were_fused, "Expected some nodes to be fused to Glow"

        # If the sizes of expected_fused_ops and expected_fused_ops_seen are
        # different, some ops in expected_fused_ops are not in the graph at all
        assert accept_all_ops or len(expected_fused_ops) == len(expected_fused_ops_seen), \
            "Expected all of expected_fused_ops to be in the graph"

        if isinstance(torch_res, tuple) or isinstance(glow_res, tuple):
            assert isinstance(torch_res, tuple) and isinstance(glow_res, tuple)
            assert len(torch_res) == len(glow_res)
            for i in range(len(torch_res)):
                print("torch shape: {}".format(torch_res[i].shape),
                      file=sys.stderr)
                print("glow shape: {}".format(glow_res[i].shape),
                      file=sys.stderr)
                assert torch.allclose(torch_res[i],
                                      glow_res[i],
                                      atol=atol,
                                      rtol=rtol)
        else:
            print("torch shape: {}".format(torch_res.shape), file=sys.stderr)
            print("glow shape: {}".format(glow_res.shape), file=sys.stderr)
            is_all_close = torch.allclose(torch_res,
                                          glow_res,
                                          atol=atol,
                                          rtol=rtol)
            if not is_all_close:
                print("torch_res\n", torch_res)
                print("glow_res\n", glow_res)
                print("diff\n", torch.abs(glow_res - torch_res))
            assert is_all_close
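
Finally, a hypothetical call to this variant (my_tanh and the tolerances are invented):

def my_tanh(x):
    return torch.tanh(x)

jitVsGlow_(my_tanh, my_tanh, True, 1e-6, 1e-6,
           torch.randn(3),
           expected_fused_ops={"aten::tanh"})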