def test_set_glow_backend():
    """Test setting the Glow backend type.

    Switches the global Glow backend to CPU with 4 devices, checks that both
    the name and the device count are reported back, and always restores the
    previous backend configuration afterwards.
    """
    backend_name_before = torch_glow.getGlowBackendName()
    backend_num_devices_before = torch_glow.getGlowBackendNumDevices()

    try:
        torch_glow.setGlowBackend("CPU", 4)
        assert torch_glow.getGlowBackendName() == "CPU"
        assert torch_glow.getGlowBackendNumDevices() == 4
    finally:
        # Restore the global backend even when an assertion fails, so later
        # tests do not inherit the CPU / 4-device configuration.
        torch_glow.setGlowBackend(backend_name_before, backend_num_devices_before)
def ephemeral_torchglow_settings(
    fp16=False, backend=DEFAULT_BACKEND, fusion=False, blocklist=None
):
    """Temporarily apply torch_glow global settings, restoring them on exit.

    This is a single-``yield`` generator. ``run_comparison_tests`` uses it in a
    ``with`` statement, so it must be wrapped by ``contextlib.contextmanager``.
    NOTE(review): that decorator is not visible in this view — confirm it is
    applied where this function is defined.

    Args:
        fp16: when True, enable fp16 conversion, fused-to-fp16 conversion and
            fp16 clipping; when False, disable all three.
        backend: Glow backend name passed to ``torch_glow.setGlowBackend``.
        fusion: when True, enable the Glow fusion pass; otherwise disable it.
        blocklist: iterable of ops to keep out of fusion; ``None`` clears the
            fusion blacklist.
    """
    # Snapshot every global setting touched below so ``finally`` can restore it.
    old_fp16 = torch_glow.get_convert_to_fp16()
    old_clip = torch_glow.get_clip_fp16()
    old_convert_fused = torch_glow.get_convert_fused_to_fp16()
    old_backend = torch_glow.getGlowBackendName()
    # setGlowBackend accepts (name, num_devices); capture the device count too
    # so the backend is restored exactly, not just by name.
    old_num_devices = torch_glow.getGlowBackendNumDevices()
    old_blocklist = torch_glow.getFusionBlacklist()
    old_fusion = torch_glow.getFusionPassEnabled()

    try:
        if fusion:
            torch_glow.enableFusionPass()
        else:
            torch_glow.disableFusionPass()

        if fp16:
            torch_glow.enable_convert_to_fp16()
            torch_glow.enable_convert_fused_to_fp16()
            torch_glow.enable_clip_fp16()
        else:
            torch_glow.disable_convert_to_fp16()
            torch_glow.disable_convert_fused_to_fp16()
            torch_glow.disable_clip_fp16()

        if blocklist is None:
            torch_glow.clearFusionBlacklist()
        else:
            torch_glow.setFusionBlacklist(list(blocklist))

        torch_glow.setGlowBackend(backend)

        yield
    finally:
        # Restore in the same groups as they were applied; this also runs if
        # the setup above raised part-way through.
        if old_fp16:
            torch_glow.enable_convert_to_fp16()
        else:
            torch_glow.disable_convert_to_fp16()

        if old_clip:
            torch_glow.enable_clip_fp16()
        else:
            torch_glow.disable_clip_fp16()

        if old_convert_fused:
            torch_glow.enable_convert_fused_to_fp16()
        else:
            torch_glow.disable_convert_fused_to_fp16()

        if old_fusion:
            torch_glow.enableFusionPass()
        else:
            torch_glow.disableFusionPass()

        torch_glow.setGlowBackend(old_backend, old_num_devices)
        torch_glow.setFusionBlacklist(old_blocklist)
def run_comparison_tests(
    module,
    inputs,
    fusible_ops,
    fp32vfp32_atol=5e-4,
    fp32vfp32_rtol=1e-3,
    fp32vfp16_atol=1e-2,
    fp32vfp16_rtol=1e-2,
    fp16vfp16_atol=5e-4,
    fp16vfp16_rtol=1e-3,
    fusion_blocklist=None,
    scripted=False,
    check_trace=True,
    skip_for_backends=DEFAULT_SKIP_BACKENDS_SET,
    skip_fp32_vs_fp16=False,
    skip_to_glow=False,  # Ugly hack, TODO: Remove
):
    """Run ``module`` on PyTorch and on Glow and check that the outputs agree.

    The module is prepared in up to five ways:

    * ``m_pytorch_fp32``: plain TorchScript (traced or scripted) — the reference.
    * ``m_interpreter_fuser_fp32``: TorchScript with the Glow fusion pass enabled
      for the Interpreter backend; ``fusible_ops`` are asserted to be fused.
    * ``m_interpreter_fp32`` / ``m_interpreter_fp16``: lowered to the Interpreter
      backend with ``torch_glow.lower`` (to_glow), fp32 and fp16.
    * ``m_other_fp16``: lowered with to_glow to the currently configured Glow
      backend, only when that backend is not the Interpreter.

    Args:
        module: the ``torch.nn.Module`` under test.
        inputs: a single value or a tuple of values passed positionally to the
            module; a non-tuple is wrapped into a 1-tuple.
        fusible_ops: ops expected to be fused by the Glow fusion pass (checked
            with ``assert_fused``).
        fp32vfp32_atol, fp32vfp32_rtol: tolerances for fp32-vs-fp32 comparisons.
        fp32vfp16_atol, fp32vfp16_rtol: tolerances for fp32-vs-fp16 comparisons.
        fp16vfp16_atol, fp16vfp16_rtol: tolerances for fp16-vs-fp16 comparisons.
        fusion_blocklist: ops to keep out of fusion, or None for no blocklist.
        scripted: use ``torch.jit.script`` instead of ``torch.jit.trace``.
        check_trace: forwarded to ``torch.jit.trace``.
        skip_for_backends: Glow backends for which the test is skipped.
        skip_fp32_vs_fp16: skip the PyTorch-fp32 vs Interpreter-fp16 check.
        skip_to_glow: only exercise the fusion path (no to_glow lowering);
            allowed only when the configured backend is the Interpreter.

    Raises:
        AssertionError: when the test is misconfigured or a comparison fails.
        unittest.SkipTest: when the configured backend is in
            ``skip_for_backends``.
    """
    # tuplify inputs
    if not isinstance(inputs, tuple):
        inputs = (inputs,)

    # Check that test is setup properly
    if not isinstance(module, torch.nn.Module):
        raise AssertionError("to_glow only supports nn.Modules")

    if "Interpreter" in skip_for_backends:
        raise AssertionError(
            "Interpreter backend can't be skipped, skip entire test until Interpreter is supported"
        )

    # If other_backend isn't supported then skip the test
    other_backend = torch_glow.getGlowBackendName()
    if other_backend in skip_for_backends:
        raise unittest.SkipTest(f"backend {other_backend} not supported for this test")

    # Get other Glow backend besides Interpreter to test if applicable
    if other_backend == "Interpreter":
        other_backend = None

    if skip_to_glow and other_backend:
        raise AssertionError(
            f"to_glow must be used for non-interpreter backends, skip this test for {other_backend} backend until the test supports to_glow"
        )

    def prepare(m, inputs, fp16, backend, fusion):
        """Prepare a JIT module to run on PyTorch, the Glow fuser, or to_glow.

        With ``fusion=True`` the module is traced/scripted while the Glow
        fusion pass is enabled for ``backend``, and ``fusible_ops`` are
        asserted fused; no to_glow lowering is applied, so ``skip_to_glow``
        really avoids to_glow. Otherwise the module is traced/scripted
        normally and, unless ``backend == "PyTorch"``, lowered with
        ``torch_glow.lower``.
        """
        inputs = deepcopy(inputs)

        def get_jit_module():
            # Trace or script ``m`` according to the outer test options.
            if scripted:
                m_jit = torch.jit.script(m)
            else:
                m_jit = torch.jit.trace(m, inputs, check_trace=check_trace)
            if scripted or not check_trace:
                # run it once to activate the fuser if not run yet
                m_jit(*inputs)
            return m_jit

        with torch.no_grad():
            if fusion:
                with ephemeral_torchglow_settings(
                    fusion=True, fp16=fp16, backend=backend, blocklist=fusion_blocklist
                ):
                    m_jit = get_jit_module()
                    assert_fused(m_jit.graph_for(*deepcopy(inputs)), fusible_ops)
                return m_jit

            m_jit = get_jit_module()
            if backend != "PyTorch":  # to_glow
                m_jit = torch_glow.lower(
                    model=m_jit,
                    example_inputs=inputs,
                    backend=backend,
                    convert_to_fp16=fp16,
                )
            return m_jit

    def compare(a_name, a, b_name, b, atol, rtol, use_eq=False):
        """Compare two prepared modules on fresh copies of ``inputs``.

        The comparison is skipped (with a printed message) when either module
        is None, i.e. was not prepared for this configuration.
        """
        # Identity check rather than truthiness: container-like modules may
        # define __len__, which would make an empty one look "not computed".
        if a is None:
            print(f"Skipping {a_name} vs {b_name} because {a_name} not computed")
            return
        if b is None:
            print(f"Skipping {a_name} vs {b_name} because {b_name} not computed")
            return
        a_outputs = a(*deepcopy(inputs))
        b_outputs = b(*deepcopy(inputs))
        assert_equivalent(a_name, a_outputs, b_name, b_outputs, atol, rtol, use_eq)

    # Prepare modules for testing
    m_pytorch_fp32 = prepare(
        module, inputs, fp16=False, backend="PyTorch", fusion=False
    )

    m_interpreter_fuser_fp32 = prepare(
        module, inputs, fp16=False, backend="Interpreter", fusion=True
    )

    m_interpreter_fp32 = None
    m_interpreter_fp16 = None
    m_other_fp16 = None

    if not skip_to_glow:
        # The to_glow modules use fusion=False so they are genuinely lowered
        # via torch_glow.lower. With fusion=True they would duplicate the
        # fuser module and the fuser-vs-to_glow check below would be vacuous.
        m_interpreter_fp32 = prepare(
            module, inputs, fp16=False, backend="Interpreter", fusion=False
        )

        m_interpreter_fp16 = prepare(
            module, inputs, fp16=True, backend="Interpreter", fusion=False
        )

        if other_backend:
            m_other_fp16 = prepare(
                module, inputs, fp16=True, backend=other_backend, fusion=False
            )

    # JIT vs Interpreter, via fuser, fp32-fp32. This is the only numeric check
    # that still runs when skip_to_glow is set.
    compare(
        "m_pytorch_fp32",
        m_pytorch_fp32,
        "m_interpreter_fuser_fp32",
        m_interpreter_fuser_fp32,
        fp32vfp32_atol,
        fp32vfp32_rtol,
    )

    # JIT vs Interpreter, via to_glow, fp32-fp32
    compare(
        "m_pytorch_fp32",
        m_pytorch_fp32,
        "m_interpreter_fp32",
        m_interpreter_fp32,
        fp32vfp32_atol,
        fp32vfp32_rtol,
    )

    # Interpreter vs Interpreter, via to_glow and fuser, fp32-fp32
    compare(
        "m_interpreter_fp32",
        m_interpreter_fp32,
        "m_interpreter_fuser_fp32",
        m_interpreter_fuser_fp32,
        fp32vfp32_atol,
        fp32vfp32_rtol,
        use_eq=True,  # fuser and to_glow should match exactly
    )

    # Interpreter vs Other, via to_glow, fp16-fp16
    compare(
        "m_interpreter_fp16",
        m_interpreter_fp16,
        "m_other_fp16",
        m_other_fp16,
        fp16vfp16_atol,
        fp16vfp16_rtol,
    )

    if not skip_fp32_vs_fp16:
        # JIT vs Interpreter, via to_glow, fp32-fp16
        compare(
            "m_pytorch_fp32",
            m_pytorch_fp32,
            "m_interpreter_fp16",
            m_interpreter_fp16,
            fp32vfp16_atol,
            fp32vfp16_rtol,
        )