Example #1
import torch_glow


def test_set_glow_backend():
    """Test setting the Glow backend type"""

    backend_name_before = torch_glow.getGlowBackendName()
    backend_num_devices_before = torch_glow.getGlowBackendNumDevices()

    torch_glow.setGlowBackend("CPU", 4)

    assert torch_glow.getGlowBackendName() == "CPU"
    assert torch_glow.getGlowBackendNumDevices() == 4

    # reset everything
    torch_glow.setGlowBackend(backend_name_before, backend_num_devices_before)
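
For reference, the backend getters and setters exercised by this test can also be called directly. A minimal sketch, assuming torch_glow is importable and the CPU backend is compiled in (the device count here is an arbitrary illustrative value):

import torch_glow

# Select the CPU backend with 2 devices (hypothetical values).
torch_glow.setGlowBackend("CPU", 2)
print(torch_glow.getGlowBackendName())        # -> "CPU"
print(torch_glow.getGlowBackendNumDevices())  # -> 2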
Example #2
from contextlib import contextmanager

import torch_glow

# DEFAULT_BACKEND is a module-level constant defined elsewhere in the test utilities.


@contextmanager
def ephemeral_torchglow_settings(
    fp16=False, backend=DEFAULT_BACKEND, fusion=False, blocklist=None
):
    old_fp16 = torch_glow.get_convert_to_fp16()
    old_clip = torch_glow.get_clip_fp16()
    old_convert_fused = torch_glow.get_convert_fused_to_fp16()
    old_backend = torch_glow.getGlowBackendName()
    old_blocklist = torch_glow.getFusionBlacklist()
    old_fusion = torch_glow.getFusionPassEnabled()
    try:
        if fusion:
            torch_glow.enableFusionPass()
        else:
            torch_glow.disableFusionPass()
        if fp16:
            torch_glow.enable_convert_to_fp16()
            torch_glow.enable_convert_fused_to_fp16()
            torch_glow.enable_clip_fp16()
        else:
            torch_glow.disable_convert_to_fp16()
            torch_glow.disable_convert_fused_to_fp16()
            torch_glow.disable_clip_fp16()
        if blocklist is None:
            torch_glow.clearFusionBlacklist()
        else:
            torch_glow.setFusionBlacklist(list(blocklist))
        torch_glow.setGlowBackend(backend)
        yield
    finally:
        if old_fp16:
            torch_glow.enable_convert_to_fp16()
        else:
            torch_glow.disable_convert_to_fp16()
        if old_clip:
            torch_glow.enable_clip_fp16()
        else:
            torch_glow.disable_clip_fp16()
        if old_convert_fused:
            torch_glow.enable_convert_fused_to_fp16()
        else:
            torch_glow.disable_convert_fused_to_fp16()
        if old_fusion:
            torch_glow.enableFusionPass()
        else:
            torch_glow.disableFusionPass()
        torch_glow.setGlowBackend(old_backend)
        torch_glow.setFusionBlacklist(old_blocklist)
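
Example #3 below consumes this helper as a context manager. A minimal standalone usage sketch, assuming the @contextmanager-decorated version above and an available Interpreter backend (the module and input are placeholders):

import torch

model = torch.nn.ReLU()
x = torch.randn(4)

# Trace and run with the Glow fusion pass enabled and fp16 conversion off;
# all global torch_glow settings are restored when the block exits.
with ephemeral_torchglow_settings(fusion=True, fp16=False, backend="Interpreter"):
    traced = torch.jit.trace(model, x)
    traced(x)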
Example #3
import unittest
from copy import deepcopy

import torch
import torch_glow

# assert_fused, assert_equivalent, DEFAULT_SKIP_BACKENDS_SET and
# ephemeral_torchglow_settings are provided by the surrounding test utilities.


def run_comparison_tests(
        module,
        inputs,
        fusible_ops,
        fp32vfp32_atol=5e-4,
        fp32vfp32_rtol=1e-3,
        fp32vfp16_atol=1e-2,
        fp32vfp16_rtol=1e-2,
        fp16vfp16_atol=5e-4,
        fp16vfp16_rtol=1e-3,
        fusion_blocklist=None,
        scripted=False,
        check_trace=True,
        skip_for_backends=DEFAULT_SKIP_BACKENDS_SET,
        skip_fp32_vs_fp16=False,
        skip_to_glow=False,  # Ugly hack, TODO: Remove
):
    # tuplify inputs
    if not isinstance(inputs, tuple):
        inputs = (inputs, )

    # Check that the test is set up properly
    if not isinstance(module, torch.nn.Module):
        raise AssertionError("to_glow only supports nn.Modules")

    if "Interpreter" in skip_for_backends:
        raise AssertionError(
            "Interpreter backend can't be skipped, skip entire test until Interpreter is supported"
        )

    # Skip the test if the active Glow backend is in the skip set
    other_backend = torch_glow.getGlowBackendName()
    if other_backend in skip_for_backends:
        raise unittest.SkipTest(
            f"backend {other_backend} not supported for this test")

    # Only treat the active backend as an extra backend to test if it isn't the Interpreter
    if other_backend == "Interpreter":
        other_backend = None

    if skip_to_glow and other_backend:
        raise AssertionError(
            f"to_glow must be used for non-interpreter backends, skip this test for {other_backend} backend until the test supports to_glow"
        )

    def prepare(m, inputs, fp16, backend, fusion):
        """ "Helper to prepare a JIT module to run either on PyTorch or Glow"""

        inputs = deepcopy(inputs)

        def getJITModule():
            m_jit = None
            if scripted:
                m_jit = torch.jit.script(m)
            else:
                m_jit = torch.jit.trace(m, inputs, check_trace=check_trace)
            if scripted or not check_trace:
                # run it once to activate the fuser if not run yet
                m_jit(*inputs)
            return m_jit

        with torch.no_grad():
            m_jit = None
            if fusion:
                with ephemeral_torchglow_settings(fusion=True,
                                                  fp16=fp16,
                                                  backend=backend,
                                                  blocklist=fusion_blocklist):
                    m_jit = getJITModule()
                    assert_fused(
                        m_jit.graph_for(*(deepcopy(inputs))),
                        fusible_ops,
                    )
            else:
                m_jit = getJITModule()

            if backend != "PyTorch":  # to_glow
                m_jit = torch_glow.lower(
                    model=m_jit,
                    example_inputs=inputs,
                    backend=backend,
                    convert_to_fp16=fp16,
                )

            return m_jit

    def compare(a_name, a, b_name, b, atol, rtol, use_eq=False):
        """ "Helper to compare two JIT modules, skip comparison if either is None"""

        if a is None:
            print(
                f"Skipping {a_name} vs {b_name} because {a_name} was not computed")
            return
        if b is None:
            print(
                f"Skipping {a_name} vs {b_name} because {b_name} was not computed")
            return
        a_outputs = a(*deepcopy(inputs))
        b_outputs = b(*deepcopy(inputs))
        assert_equivalent(a_name, a_outputs, b_name, b_outputs, atol, rtol,
                          use_eq)

    # Prepare modules for testing
    m_pytorch_fp32 = prepare(module,
                             inputs,
                             fp16=False,
                             backend="PyTorch",
                             fusion=False)

    m_interpreter_fuser_fp32 = prepare(module,
                                       inputs,
                                       fp16=False,
                                       backend="Interpreter",
                                       fusion=True)

    m_interpreter_fp32 = None
    m_interpreter_fp16 = None
    m_other_fp16 = None

    if not skip_to_glow:
        m_interpreter_fp32 = prepare(module,
                                     inputs,
                                     fp16=False,
                                     backend="Interpreter",
                                     fusion=True)

        m_interpreter_fp16 = prepare(module,
                                     inputs,
                                     fp16=True,
                                     backend="Interpreter",
                                     fusion=True)

        m_other_fp16 = None
        if other_backend:
            m_other_fp16 = prepare(module,
                                   inputs,
                                   fp16=True,
                                   backend=other_backend,
                                   fusion=False)

    # JIT vs Interpreter, via to_glow, fp32-fp32
    compare(
        "m_pytorch_fp32",
        m_pytorch_fp32,
        "m_interpreter_fp32",
        m_interpreter_fp32,
        fp32vfp32_atol,
        fp32vfp32_rtol,
    )

    # Interpreter vs Interpreter, via to_glow and fuser, fp32-fp32
    compare(
        "m_interpreter_fp32",
        m_interpreter_fp32,
        "m_interpreter_fuser_fp32",
        m_interpreter_fuser_fp32,
        fp32vfp32_atol,
        fp32vfp32_rtol,
        use_eq=True,  # fuser and to_glow should match exactly
    )

    # Interpreter vs Other, via to_glow, fp16-fp16
    compare(
        "m_interpreter_fp16",
        m_interpreter_fp16,
        "m_other_fp16",
        m_other_fp16,
        fp16vfp16_atol,
        fp16vfp16_rtol,
    )

    if not skip_fp32_vs_fp16:
        # JIT vs Interpreter, via to_glow, fp32-fp16
        compare(
            "m_pytorch_fp32",
            m_pytorch_fp32,
            "m_interpreter_fp16",
            m_interpreter_fp16,
            fp32vfp16_atol,
            fp32vfp16_rtol,
        )
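
A hypothetical call site for run_comparison_tests, assuming fusible_ops is the collection of ATen node kinds that assert_fused looks for in the fused graph (the module, test name and op set below are illustrative only):

import torch


class SimpleAddModule(torch.nn.Module):
    def forward(self, a, b):
        return a + b


def test_add_comparison():
    # Compare PyTorch vs Glow (Interpreter and, if available, another backend)
    # for a simple elementwise add, using the default tolerances.
    run_comparison_tests(
        SimpleAddModule(),
        (torch.randn(4), torch.randn(4)),
        fusible_ops={"aten::add"},
    )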