Example #1
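These snippets are test utilities for torch_glow, PyTorch's bridge to the Glow compiler. They assume the surrounding test module provides torch, torch_glow, deepcopy (from the copy standard library module), and the module-local helpers ephemeral_torchglow_settings, assert_fused, assert_equivalent, and DEFAULT_BACKEND.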
def test_lowering(
    module,
    *inputs,
    fusible_ops=None,
    fusion_blocklist=None,
    fp16=False,
    scripted=False,
    check_trace=True,
    accept_all_layouts=False,
):
    if not isinstance(module, torch.nn.Module):
        raise AssertionError("to_glow only supports nn.Modules")

    def trace(mod, ins):
        if scripted:
            return torch.jit.script(mod)
        else:
            return torch.jit.trace(mod, ins, check_trace=check_trace)

    with torch.no_grad():
        with ephemeral_torchglow_settings(
                fusion=False, fp16=fp16,
                accept_all_layouts=accept_all_layouts):
            glow_inputs = deepcopy(inputs)
            traced_module = trace(module, glow_inputs)
            # If a deferred weight loader is not set, lowering will throw a runtime exception
            _lowered_module = torch_glow.lower(traced_module, glow_inputs,
                                               DEFAULT_BACKEND)  # unused
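A minimal sketch of how test_lowering might be invoked, assuming a trivial module; SimpleModule and the input shapes below are hypothetical:

import torch

class SimpleModule(torch.nn.Module):  # hypothetical example module
    def forward(self, x):
        return x + x

# Checks only that JIT tracing/scripting and torch_glow.lower complete
# without raising.
test_lowering(SimpleModule(), torch.randn(2, 3))
test_lowering(SimpleModule(), torch.randn(2, 3), fp16=True, scripted=True)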
Example #2
def compare_tracing_methods_error(
    module,
    *inputs,
    fusible_ops=None,
    fusion_blocklist=None,
    fp16=False,
):
    if not isinstance(module, torch.nn.Module):
        raise AssertionError("to_glow only supports nn.Modules")

    def trace(mod, ins):
        return torch.jit.trace(mod, ins)

    with torch.no_grad():
        with ephemeral_torchglow_settings(fusion=True,
                                          fp16=fp16,
                                          blocklist=fusion_blocklist):
            fusion_inputs = deepcopy(inputs)
            try:
                fusion_trace = trace(module, fusion_inputs)
                assert_fused(
                    fusion_trace.graph_for(*fusion_inputs),
                    *(fusible_ops or []),
                    accept_any=fusible_ops is None,
                )
                fusion_trace(*fusion_inputs)
            except Exception:
                pass
            else:
                raise AssertionError(
                    "Error expected (fusion), but none was raised")
        with ephemeral_torchglow_settings(fusion=False, fp16=fp16):
            try:
                torchscript_inputs = deepcopy(inputs)
                torchscript_trace = trace(module, torchscript_inputs)
                torchscript_trace(*torchscript_inputs)
            except Exception:
                pass
            else:
                raise AssertionError(
                    "Error expected (torchscript), but none was raised")
        with ephemeral_torchglow_settings(fusion=False, fp16=fp16):
            try:
                glow_inputs = deepcopy(inputs)
                glow_spec = torch_glow.lower(
                    model=module,
                    example_inputs=glow_inputs,
                    backend=DEFAULT_BACKEND,
                )
                glow_trace = torch_glow.to_glow(trace(module, glow_inputs),
                                                glow_spec)
                glow_trace(*glow_inputs)
            except Exception:
                pass
            else:
                raise AssertionError(
                    "Error expected (glow), but none was raised")
Example #3
    def prepare(m, inputs, fp16, backend, fusion):
        """ "Helper to prepare a JIT module to run either on PyTorch or Glow"""

        inputs = deepcopy(inputs)

        def getJITModule():
            if scripted:
                m_jit = torch.jit.script(m)
            else:
                m_jit = torch.jit.trace(m, inputs, check_trace=check_trace)
            if scripted or not check_trace:
                # scripting (or tracing with check_trace=False) has not run the
                # module yet; run it once here to activate the fuser
                m_jit(*inputs)
            return m_jit

        with torch.no_grad():
            if fusion:
                with ephemeral_torchglow_settings(fusion=True,
                                                  fp16=fp16,
                                                  backend=backend,
                                                  blocklist=fusion_blocklist):
                    m_jit = getJITModule()
                    assert_fused(
                        m_jit.graph_for(*(deepcopy(inputs))),
                        fusible_ops,
                    )
            else:
                m_jit = getJITModule()

            if backend != "PyTorch":  # to_glow
                m_jit = torch_glow.lower(
                    model=m_jit,
                    example_inputs=inputs,
                    backend=backend,
                    convert_to_fp16=fp16,
                )

            return m_jit
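Since prepare is a nested helper, its free variables must be bound by an enclosing function. A hypothetical sketch of that enclosing scope follows; compare_backends and its call pattern are assumptions for illustration, not part of the original:

def compare_backends(module, *inputs, fusible_ops=None, fusion_blocklist=None,
                     fp16=False, scripted=False, check_trace=True):
    # `prepare` (above) would be defined here, capturing scripted,
    # check_trace, fusible_ops, and fusion_blocklist as closure variables.

    torch_jit = prepare(module, inputs, fp16, backend="PyTorch", fusion=True)
    glow_jit = prepare(module, inputs, fp16, backend=DEFAULT_BACKEND,
                       fusion=False)
    assert_equivalent("PyTorch", torch_jit(*deepcopy(inputs)),
                      "Glow", glow_jit(*deepcopy(inputs)),
                      atol=5e-4, rtol=1e-3)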
Example #4
def compare_tracing_methods(
    module,
    *inputs,
    atol=5e-4,
    rtol=1e-3,
    reference=None,
    fusible_ops=None,
    fusion_blocklist=None,
    fp16=False,
    scripted=False,
    check_trace=True,
    accept_all_layouts=False,
    skip_to_glow=False,  # Ugly hack, TODO: Remove
):
    if not isinstance(module, torch.nn.Module):
        raise AssertionError("to_glow only supports nn.Modules")

    def trace(mod, ins):
        if scripted:
            return torch.jit.script(mod)
        else:
            return torch.jit.trace(mod, ins, check_trace=check_trace)

    with torch.no_grad():
        with ephemeral_torchglow_settings(
            fusion=True,
            fp16=fp16,
            blocklist=fusion_blocklist,
            accept_all_layouts=accept_all_layouts,
        ):
            fusion_inputs = deepcopy(inputs)
            fusion_trace = trace(module, fusion_inputs)
            assert_fused(
                fusion_trace.graph_for(*fusion_inputs),
                *(fusible_ops or []),
                accept_any=fusible_ops is None,
            )
            fusion_result = fusion_trace(*fusion_inputs)
        with ephemeral_torchglow_settings(
            fusion=False, fp16=fp16, accept_all_layouts=accept_all_layouts
        ):
            if scripted:
                torchscript_result = module(*deepcopy(inputs))
            else:
                torchscript_inputs = deepcopy(inputs)
                torchscript_trace = trace(module, torchscript_inputs)
                torchscript_result = torchscript_trace(*torchscript_inputs)
        with ephemeral_torchglow_settings(
            fusion=False, fp16=fp16, accept_all_layouts=accept_all_layouts
        ):
            if not skip_to_glow:
                glow_inputs = deepcopy(inputs)
                traced_module = trace(module, glow_inputs)
                lowered_module = torch_glow.lower(
                    traced_module, glow_inputs, DEFAULT_BACKEND
                )
                glow_result = lowered_module(*glow_inputs)
        if reference is not None:
            assert_equivalent(
                "Reference",
                reference,
                "Glow fusion",
                fusion_result,
                atol=atol,
                rtol=rtol,
            )
            assert_equivalent(
                "Reference",
                reference,
                "TorchScript",
                torchscript_result,
                atol=atol,
                rtol=rtol,
            )
            if not skip_to_glow:
                assert_equivalent(
                    "Reference", reference, "Glow", glow_result, atol=atol, rtol=rtol
                )
        # This is written out manually instead of using combinations in order to aid
        # debugging. TODO: Clean up.
        assert_equivalent(
            "Glow fusion",
            fusion_result,
            "TorchScript",
            torchscript_result,
            atol=atol,
            rtol=rtol,
        )
        if not skip_to_glow:
            assert_equivalent(
                "Glow fusion", fusion_result, "Glow", glow_result, atol=atol, rtol=rtol
            )
            assert_equivalent(
                "TorchScript",
                torchscript_result,
                "Glow",
                glow_result,
                atol=atol,
                rtol=rtol,
            )
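A usage sketch for compare_tracing_methods, assuming a module with a single fusible add; AddModule and the "aten::add" op name are assumptions for illustration:

import torch

class AddModule(torch.nn.Module):  # hypothetical example module
    def forward(self, a, b):
        return a + b

compare_tracing_methods(
    AddModule(),
    torch.randn(3, 4),
    torch.randn(3, 4),
    fusible_ops={"aten::add"},  # assumed node kind expected in the fused graph
)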