Exemplo n.º 1
0
def ephemeral_torchglow_settings(
    fp16=False, backend=DEFAULT_BACKEND, fusion=False, blocklist=None
):
    """Temporarily apply torch_glow global settings, restoring the old ones afterwards.

    This is a generator: it snapshots the current fp16 / clip / fused-fp16 /
    backend / fusion-blocklist / fusion-pass settings, applies the requested
    ones, yields once, and restores the snapshot in a ``finally`` clause so
    the previous state comes back even if the body raises.

    NOTE(review): no ``contextlib.contextmanager`` decorator is visible in this
    snippet; presumably it is applied so callers can use ``with`` — confirm.

    Args:
        fp16: If true, enable fp16 conversion, fused-op fp16 conversion and
            fp16 clipping; otherwise disable all three.
        backend: Backend name passed to ``torch_glow.setGlowBackend``.
        fusion: If true, enable the Glow fusion pass; otherwise disable it.
        blocklist: Iterable of entries for ``torch_glow.setFusionBlacklist``,
            or ``None`` to clear the fusion blocklist.
    """

    def _set_flag(enabled, enable, disable):
        # Invoke exactly one of the two toggles depending on ``enabled``.
        if enabled:
            enable()
        else:
            disable()

    # Snapshot the current global settings so they can be restored on exit.
    old_fp16 = torch_glow.get_convert_to_fp16()
    old_clip = torch_glow.get_clip_fp16()
    old_convert_fused = torch_glow.get_convert_fused_to_fp16()
    old_backend = torch_glow.getGlowBackendName()
    old_blocklist = torch_glow.getFusionBlacklist()
    old_fusion = torch_glow.getFusionPassEnabled()
    try:
        _set_flag(fusion, torch_glow.enableFusionPass, torch_glow.disableFusionPass)
        _set_flag(
            fp16,
            torch_glow.enable_convert_to_fp16,
            torch_glow.disable_convert_to_fp16,
        )
        _set_flag(
            fp16,
            torch_glow.enable_convert_fused_to_fp16,
            torch_glow.disable_convert_fused_to_fp16,
        )
        _set_flag(fp16, torch_glow.enable_clip_fp16, torch_glow.disable_clip_fp16)
        if blocklist is None:
            torch_glow.clearFusionBlacklist()
        else:
            torch_glow.setFusionBlacklist(list(blocklist))
        torch_glow.setGlowBackend(backend)
        yield
    finally:
        # Restore every captured setting, in the same order as before.
        _set_flag(
            old_fp16,
            torch_glow.enable_convert_to_fp16,
            torch_glow.disable_convert_to_fp16,
        )
        _set_flag(old_clip, torch_glow.enable_clip_fp16, torch_glow.disable_clip_fp16)
        _set_flag(
            old_convert_fused,
            torch_glow.enable_convert_fused_to_fp16,
            torch_glow.disable_convert_fused_to_fp16,
        )
        _set_flag(
            old_fusion, torch_glow.enableFusionPass, torch_glow.disableFusionPass
        )
        torch_glow.setGlowBackend(old_backend)
        torch_glow.setFusionBlacklist(old_blocklist)
Exemplo n.º 2
0
def scriptVsGlow(
    f,
    atol,
    rtol,
    *inputs,
    expected_fused_ops=None,
    accept_all_ops=False,
    black_list=None,
    use_fp16=False,
    backend_name=None,
):
    """Compare eager execution of ``f`` with a ``torch.jit.script`` run lowered to Glow.

    ``f`` is first run eagerly (under ``torch.no_grad``) to get the reference
    result. The Glow fusion pass is then enabled with the requested blocklist,
    fp16 and backend settings, and ``f`` is scripted and run on the same
    inputs. The scripted graph is checked with ``checkExpectedOps`` and the
    two results are compared with ``checkResult``.

    The torch_glow settings changed here are global. They are reset in a
    ``finally`` clause so a failure while scripting/running does not leak
    them into later tests.

    Args:
        f: Callable under test; must be scriptable by ``torch.jit.script``.
        atol: Absolute tolerance forwarded to ``checkResult``.
        rtol: Relative tolerance forwarded to ``checkResult``.
        *inputs: Positional inputs passed to both ``f`` and the scripted ``f``.
        expected_fused_ops: Forwarded to ``checkExpectedOps``.
        accept_all_ops: Forwarded to ``checkExpectedOps``.
        black_list: Entries passed to ``torch_glow.setFusionBlacklist``;
            ``None`` means an empty list.
        use_fp16: Enable fp16 conversion, fused-op fp16 conversion and fp16
            clipping for the Glow run; otherwise disable them.
        backend_name: Glow backend to use; a falsy value selects "Interpreter".
    """
    if black_list is None:
        black_list = []
    with torch.no_grad():

        torch_res = f(*inputs)

        try:
            torch_glow.enableFusionPass()
            torch_glow.setFusionBlacklist(black_list)

            if use_fp16:
                torch_glow.enable_convert_to_fp16()
                torch_glow.enable_convert_fused_to_fp16()
                torch_glow.enable_clip_fp16()
            else:
                torch_glow.disable_convert_to_fp16()
                torch_glow.disable_convert_fused_to_fp16()
                torch_glow.disable_clip_fp16()

            torch_glow.setGlowBackend(backend_name or "Interpreter")

            glow_trace = torch.jit.script(f)
            glow_res = glow_trace(*inputs)

            glow_graph = glow_trace.graph_for(*inputs)
            print("glow_graph,", glow_graph)
        finally:
            # need to explicitly clear settings to avoid carry-over static
            # settings; done in ``finally`` so it also happens on failure
            torch_glow.disableFusionPass()
            torch_glow.clearFusionBlacklist()
            torch_glow.disable_convert_to_fp16()
            torch_glow.disable_convert_fused_to_fp16()
            torch_glow.disable_clip_fp16()
            torch_glow.setGlowBackend("Interpreter")

    checkExpectedOps(glow_graph, expected_fused_ops, accept_all_ops)
    checkResult(torch_res, glow_res, atol, rtol)
Exemplo n.º 3
0
def traceVsGlow(
    f_torch,
    f_glow,
    check_trace,
    atol,
    rtol,
    *inputs,
    expected_fused_ops=None,
    accept_all_ops=False,
    black_list=None,
    use_fp16=False,
    backend_name=None,
):
    """Compare a plain ``torch.jit.trace`` run of ``f_torch`` with a Glow-fused trace of ``f_glow``.

    ``f_torch`` is traced and run with the Glow fusion pass disabled. The pass
    is then enabled with the requested blocklist, fp16 and backend settings,
    and ``f_glow`` is traced and run on the same inputs. The function asserts
    that the plain torch graph contains no ``GLOW_NODE_NAME`` nodes, checks the
    Glow graph with ``checkExpectedOps``, and compares the results with
    ``checkResult``.

    The torch_glow settings changed here are global. They are reset in a
    ``finally`` clause so a failure (including the no-Glow-nodes assertion)
    does not leak them into later tests.

    Args:
        f_torch: Callable traced without Glow fusion (reference).
        f_glow: Callable traced with Glow fusion enabled.
        check_trace: Forwarded to ``torch.jit.trace`` for both traces.
        atol: Absolute tolerance forwarded to ``checkResult``.
        rtol: Relative tolerance forwarded to ``checkResult``.
        *inputs: Example inputs used for tracing and for running both traces.
        expected_fused_ops: Forwarded to ``checkExpectedOps``.
        accept_all_ops: Forwarded to ``checkExpectedOps``.
        black_list: Entries passed to ``torch_glow.setFusionBlacklist``;
            ``None`` means an empty list.
        use_fp16: Enable fp16 conversion, fused-op fp16 conversion and fp16
            clipping for the Glow run; otherwise disable them.
        backend_name: Glow backend to use; a falsy value selects "Interpreter".

    Raises:
        AssertionError: If the plain torch graph contains Glow nodes.
    """
    if black_list is None:
        black_list = []
    with torch.no_grad():
        try:
            # Reference trace must be produced without Glow fusion.
            torch_glow.disableFusionPass()

            torch_trace = torch.jit.trace(f_torch, inputs, check_trace=check_trace)
            torch_res = torch_trace(*inputs)

            torch_glow.enableFusionPass()
            torch_glow.setFusionBlacklist(black_list)

            if use_fp16:
                torch_glow.enable_convert_to_fp16()
                torch_glow.enable_convert_fused_to_fp16()
                torch_glow.enable_clip_fp16()
            else:
                torch_glow.disable_convert_to_fp16()
                torch_glow.disable_convert_fused_to_fp16()
                torch_glow.disable_clip_fp16()

            torch_glow.setGlowBackend(backend_name or "Interpreter")

            glow_trace = torch.jit.trace(f_glow, inputs, check_trace=check_trace)
            glow_res = glow_trace(*inputs)

            # check that there are no Glow nodes in the torch graph
            torch_graph = torch_trace.graph_for(*inputs)
            print("torch_graph,", torch_graph)

            num_glow_nodes = len(torch_graph.findAllNodes(GLOW_NODE_NAME))
            assert num_glow_nodes == 0, "Expected no Glow nodes, found {}".format(
                num_glow_nodes)

            glow_graph = glow_trace.graph_for(*inputs)
            print("glow_graph,", glow_graph)
        finally:
            # need to explicitly clear settings to avoid carry-over static
            # settings; done in ``finally`` so it also happens on failure
            torch_glow.disableFusionPass()
            torch_glow.clearFusionBlacklist()
            torch_glow.disable_convert_to_fp16()
            torch_glow.disable_convert_fused_to_fp16()
            torch_glow.disable_clip_fp16()
            torch_glow.setGlowBackend("Interpreter")

    checkExpectedOps(glow_graph, expected_fused_ops, accept_all_ops)
    checkResult(torch_res, glow_res, atol, rtol)