def test_lowering(
    module,
    *inputs,
    fusible_ops=None,
    fusion_blocklist=None,
    fp16=False,
    scripted=False,
    check_trace=True,
    accept_all_layouts=False,
):
    """Check that ``module`` can be turned into a JIT module and lowered to Glow.

    Only the lowering step is exercised; the lowered module is never run, so no
    numerical comparison happens here.

    Args:
        module: the ``torch.nn.Module`` to lower.
        *inputs: example inputs; a deep copy is used for tracing and lowering.
        fusible_ops: accepted but not used by this helper.
        fusion_blocklist: accepted but not used by this helper.
        fp16: forwarded to ``ephemeral_torchglow_settings``.
        scripted: use ``torch.jit.script`` instead of ``torch.jit.trace``.
        check_trace: forwarded to ``torch.jit.trace`` (irrelevant when scripted).
        accept_all_layouts: forwarded to ``ephemeral_torchglow_settings``.

    Raises:
        AssertionError: if ``module`` is not an ``nn.Module``.
    """
    if not isinstance(module, torch.nn.Module):
        raise AssertionError("to_glow only supports nn.Modules")

    with torch.no_grad(), ephemeral_torchglow_settings(
        fusion=False, fp16=fp16, accept_all_layouts=accept_all_layouts
    ):
        example_inputs = deepcopy(inputs)
        jit_module = (
            torch.jit.script(module)
            if scripted
            else torch.jit.trace(module, example_inputs, check_trace=check_trace)
        )
        # Lowering throws a runtime exception if no deferred weight loader is
        # set; the lowered module itself is deliberately discarded.
        torch_glow.lower(jit_module, example_inputs, DEFAULT_BACKEND)
def compare_tracing_methods_error(
    module,
    *inputs,
    fusible_ops=None,
    fusion_blocklist=None,
    fp16=False,
):
    """Assert that ``module`` fails under every execution path.

    ``module`` is traced and run three ways, and each way is expected to raise:

    * "fusion": traced with Glow fusion enabled; the graph is checked with
      ``assert_fused`` and then run (a failing ``assert_fused`` also counts as
      the expected error).
    * "torchscript": traced with fusion disabled and run.
    * "glow": traced, lowered with ``torch_glow.lower`` on ``DEFAULT_BACKEND``,
      and the lowered module is run.

    Args:
        module: the ``torch.nn.Module`` under test.
        *inputs: example inputs; every path works on its own deep copy.
        fusible_ops: ops expected to be fused on the fusion path; ``None``
            accepts any fusion.
        fusion_blocklist: forwarded as ``blocklist`` to the fusion settings.
        fp16: forwarded to ``ephemeral_torchglow_settings`` on every path.

    Raises:
        AssertionError: if ``module`` is not an ``nn.Module``, or if any path
            completes without raising.
    """
    if not isinstance(module, torch.nn.Module):
        raise AssertionError("to_glow only supports nn.Modules")

    def trace(mod, ins):
        return torch.jit.trace(mod, ins)

    def expect_error(label, run):
        # Any exception is the expected outcome here, so it is deliberately
        # swallowed; only a clean run is reported as a failure.
        try:
            run()
        except Exception:
            return
        raise AssertionError(f"Error expected ({label}), but none were received")

    with torch.no_grad():
        with ephemeral_torchglow_settings(
            fusion=True, fp16=fp16, blocklist=fusion_blocklist
        ):
            fusion_inputs = deepcopy(inputs)

            def run_fusion():
                fusion_trace = trace(module, fusion_inputs)
                assert_fused(
                    fusion_trace.graph_for(*fusion_inputs),
                    *(fusible_ops or []),
                    accept_any=fusible_ops is None,
                )
                fusion_trace(*fusion_inputs)

            expect_error("fusion", run_fusion)

        with ephemeral_torchglow_settings(fusion=False, fp16=fp16):

            def run_torchscript():
                torchscript_inputs = deepcopy(inputs)
                torchscript_trace = trace(module, torchscript_inputs)
                torchscript_trace(*torchscript_inputs)

            expect_error("torchscript", run_torchscript)

        with ephemeral_torchglow_settings(fusion=False, fp16=fp16):

            def run_glow():
                glow_inputs = deepcopy(inputs)
                # Same sequence as compare_tracing_methods: lower the *traced*
                # module and run the result, so an exception here reflects the
                # module itself rather than a misuse of the lowering API.
                traced_module = trace(module, glow_inputs)
                lowered_module = torch_glow.lower(
                    traced_module, glow_inputs, DEFAULT_BACKEND
                )
                lowered_module(*glow_inputs)

            expect_error("glow", run_glow)
def prepare(m, inputs, fp16, backend, fusion):
    """Build a JIT module from ``m`` that runs either on PyTorch or on Glow.

    ``scripted``, ``check_trace``, ``fusion_blocklist`` and ``fusible_ops`` are
    not parameters; they are read from an enclosing scope.
    NOTE(review): presumably this is nested inside a comparison-test runner —
    confirm.

    Args:
        m: module to script or trace.
        inputs: example inputs (deep-copied before use).
        fp16: forwarded to the fusion settings and, for Glow backends, to
            ``torch_glow.lower`` as ``convert_to_fp16``.
        backend: ``"PyTorch"`` returns the JIT module as is; any other value
            lowers it to that Glow backend.
        fusion: when true, build the JIT module with Glow fusion enabled and
            check its graph with ``assert_fused``.

    Returns:
        The scripted/traced module, or its Glow-lowered counterpart.
    """
    example_inputs = deepcopy(inputs)

    def build_jit_module():
        if scripted:
            jit_module = torch.jit.script(m)
        else:
            jit_module = torch.jit.trace(m, example_inputs, check_trace=check_trace)
        # Run once to activate the fuser if it has not been run yet.
        if scripted or not check_trace:
            jit_module(*example_inputs)
        return jit_module

    with torch.no_grad():
        if not fusion:
            jit_module = build_jit_module()
        else:
            with ephemeral_torchglow_settings(
                fusion=True, fp16=fp16, backend=backend, blocklist=fusion_blocklist
            ):
                jit_module = build_jit_module()
                # NOTE(review): fusible_ops is passed as one positional argument
                # here, while compare_tracing_methods unpacks it with
                # ``*(fusible_ops or [])`` — confirm against assert_fused's
                # signature.
                assert_fused(
                    jit_module.graph_for(*deepcopy(example_inputs)),
                    fusible_ops,
                )

        if backend == "PyTorch":
            return jit_module

        # to_glow
        return torch_glow.lower(
            model=jit_module,
            example_inputs=example_inputs,
            backend=backend,
            convert_to_fp16=fp16,
        )
def compare_tracing_methods(
    module,
    *inputs,
    atol=5e-4,
    rtol=1e-3,
    reference=None,
    fusible_ops=None,
    fusion_blocklist=None,
    fp16=False,
    scripted=False,
    check_trace=True,
    accept_all_layouts=False,
    skip_to_glow=False,  # Ugly hack, TODO: Remove
):
    """Run ``module`` through several execution paths and compare the results.

    Paths:

    * "Glow fusion": traced (or scripted) with Glow fusion enabled; the graph
      is checked with ``assert_fused`` and then run.
    * "TorchScript": fusion disabled; when ``scripted`` is true ``module`` is
      called directly, otherwise a trace of it is run.
    * "Glow": traced (or scripted), lowered with ``torch_glow.lower`` on
      ``DEFAULT_BACKEND`` and run. Skipped when ``skip_to_glow`` is true.

    Every pair of path results is compared with ``assert_equivalent``. When
    ``reference`` is given, each path result is also compared against it.

    Args:
        module: the ``torch.nn.Module`` under test.
        *inputs: example inputs; each path works on its own deep copy.
        atol, rtol: tolerances forwarded to ``assert_equivalent``.
        reference: optional expected output; ``None`` skips those comparisons.
        fusible_ops: ops expected to be fused; ``None`` accepts any fusion.
        fusion_blocklist: forwarded as ``blocklist`` to the fusion settings.
        fp16: forwarded to ``ephemeral_torchglow_settings`` on every path.
        scripted: use ``torch.jit.script`` instead of ``torch.jit.trace``.
        check_trace: forwarded to ``torch.jit.trace``.
        accept_all_layouts: forwarded to ``ephemeral_torchglow_settings``.
        skip_to_glow: skip the lowered-Glow path and its comparisons.

    Raises:
        AssertionError: if ``module`` is not an ``nn.Module``. Mismatches are
            reported by ``assert_fused`` / ``assert_equivalent``.
    """
    if not isinstance(module, torch.nn.Module):
        raise AssertionError("to_glow only supports nn.Modules")

    def trace(mod, ins):
        if scripted:
            return torch.jit.script(mod)
        else:
            return torch.jit.trace(mod, ins, check_trace=check_trace)

    with torch.no_grad():
        with ephemeral_torchglow_settings(
            fusion=True,
            fp16=fp16,
            blocklist=fusion_blocklist,
            accept_all_layouts=accept_all_layouts,
        ):
            fusion_inputs = deepcopy(inputs)
            fusion_trace = trace(module, fusion_inputs)
            assert_fused(
                fusion_trace.graph_for(*fusion_inputs),
                *(fusible_ops or []),
                accept_any=fusible_ops is None,
            )
            fusion_result = fusion_trace(*fusion_inputs)

        with ephemeral_torchglow_settings(
            fusion=False, fp16=fp16, accept_all_layouts=accept_all_layouts
        ):
            if scripted:
                torchscript_result = module(*deepcopy(inputs))
            else:
                torchscript_inputs = deepcopy(inputs)
                torchscript_trace = trace(module, torchscript_inputs)
                torchscript_result = torchscript_trace(*torchscript_inputs)

        with ephemeral_torchglow_settings(
            fusion=False, fp16=fp16, accept_all_layouts=accept_all_layouts
        ):
            if not skip_to_glow:
                glow_inputs = deepcopy(inputs)
                traced_module = trace(module, glow_inputs)
                lowered_module = torch_glow.lower(
                    traced_module, glow_inputs, DEFAULT_BACKEND
                )
                glow_result = lowered_module(*glow_inputs)

        # Test for presence explicitly: ``reference`` is presumably a tensor
        # (or a structure of tensors), and truth-testing a multi-element tensor
        # raises while a zero scalar tensor would silently skip the checks.
        if reference is not None:
            assert_equivalent(
                "Reference",
                reference,
                "Glow fusion",
                fusion_result,  # compare the fused *result*, not the traced module
                atol=atol,
                rtol=rtol,
            )
            assert_equivalent(
                "Reference",
                reference,
                "TorchScript",
                torchscript_result,
                atol=atol,
                rtol=rtol,
            )
            if not skip_to_glow:
                assert_equivalent(
                    "Reference", reference, "Glow", glow_result, atol=atol, rtol=rtol
                )

        # This is written out manually instead of using combinations in order to aid
        # debugging. TODO: Clean up.
        assert_equivalent(
            "Glow fusion",
            fusion_result,
            "TorchScript",
            torchscript_result,
            atol=atol,
            rtol=rtol,
        )

        # glow_result only exists when the Glow path ran.
        if not skip_to_glow:
            assert_equivalent(
                "Glow fusion", fusion_result, "Glow", glow_result, atol=atol, rtol=rtol
            )
            assert_equivalent(
                "TorchScript",
                torchscript_result,
                "Glow",
                glow_result,
                atol=atol,
                rtol=rtol,
            )