def get_act_layer(name='relu'):
    """ Activation Layer Factory

    Resolve an activation layer by name. Depending on the current config, an export- or
    TorchScript-friendly variant is chosen dynamically.

    Resolution order:
      1. ``_OVERRIDE_LAYER``: an explicit override always wins.
      2. ``_ACT_LAYER_ME``: consulted only when the config is not exportable, not
         scriptable, and not no-jit.
      3. ``_ACT_LAYER_JIT``: consulted only when the config is neither exportable nor no-jit.
      4. ``_ACT_LAYER_DEFAULT``: unconditional fallback.

    Args:
        name: lookup key of the activation layer.

    Returns:
        Whatever entry the first matching table holds for ``name``.

    NOTE(review): another ``def get_act_layer`` appears later in this file and rebinds this
    name at import time, so this definition is shadowed. Consider keeping only one of them.
    """
    # Explicit overrides take precedence over every config-dependent choice.
    if name in _OVERRIDE_LAYER:
        return _OVERRIDE_LAYER[name]

    # The "ME" variants are only allowed when we are not exporting, scripting, or disabling jit.
    me_allowed = not (config.is_exportable() or config.is_scriptable() or config.is_no_jit())
    if me_allowed and name in _ACT_LAYER_ME:
        return _ACT_LAYER_ME[name]

    # Scripting does not rule out jit variants; only export or no-jit does.
    jit_allowed = not (config.is_exportable() or config.is_no_jit())
    if jit_allowed and name in _ACT_LAYER_JIT:
        return _ACT_LAYER_JIT[name]

    return _ACT_LAYER_DEFAULT[name]
def get_act_fn(name='relu'):
    """ Activation Function Factory

    Fetching activation fns by name with this function allows export or torch script friendly
    functions to be returned dynamically based on current config.

    Lookup order:
      1. ``_OVERRIDE_FN``: explicit overrides always win.
      2. ``_ACT_FN_ME``: memory-optimized variants with custom autograd. Used only when the
         config is not exportable, not scriptable, and not no-jit.
      3. ``_ACT_FN_JIT``: jit scripted variants. Used only when the config is neither
         exportable nor no-jit.
      4. ``_ACT_FN_DEFAULT``: Python or Torch builtin fallback.

    Args:
        name: lookup key of the activation function.

    Returns:
        The entry registered under ``name`` in the first table that applies.

    Raises:
        Whatever ``_ACT_FN_DEFAULT[name]`` raises for an unknown name (KeyError if it is a
        plain dict).
    """
    if name in _OVERRIDE_FN:
        return _OVERRIDE_FN[name]

    use_me = not (config.is_exportable() or config.is_scriptable() or config.is_no_jit())
    if use_me and name in _ACT_FN_ME:
        # If not exporting or scripting the model, first look for a memory optimized version
        # activation with custom autograd, then fallback to jit scripted, then a Python or Torch builtin
        return _ACT_FN_ME[name]

    use_jit = not (config.is_exportable() or config.is_no_jit())
    # NOTE: export tracing should work with jit scripted components, but I keep running into issues.
    # The jit variant must be used only when jit is *allowed*. A previous version tested
    # `no_jit` here, which inverted the condition and picked jit fns exactly when they were
    # forbidden.
    if use_jit and name in _ACT_FN_JIT:
        # jit scripted models should be okay for export/scripting
        return _ACT_FN_JIT[name]

    return _ACT_FN_DEFAULT[name]
def get_act_layer(name='relu'):
    """ Activation Layer Factory

    Fetching activation layers by name with this function allows export or torch script friendly
    functions to be returned dynamically based on current config.

    Lookup order:
      1. ``_OVERRIDE_LAYER``: explicit overrides always win.
      2. ``_ACT_LAYER_ME``: used only when the config is not exportable, not scriptable,
         and not no-jit.
      3. ``Swish``: returned for 'silu'/'swish' when exporting (temporary ONNX workaround).
      4. ``_ACT_LAYER_JIT``: used only when the config is neither exportable nor no-jit.
      5. ``_ACT_LAYER_DEFAULT``: fallback.

    Args:
        name: lookup key of the activation layer.

    Returns:
        The layer entry selected by the order above.

    Raises:
        Whatever ``_ACT_LAYER_DEFAULT[name]`` raises for an unknown name (KeyError if it is a
        plain dict).
    """
    if name in _OVERRIDE_LAYER:
        return _OVERRIDE_LAYER[name]

    use_me = not (config.is_exportable() or config.is_scriptable() or config.is_no_jit())
    if use_me and name in _ACT_LAYER_ME:
        return _ACT_LAYER_ME[name]

    if config.is_exportable() and name in ('silu', 'swish'):
        # FIXME PyTorch SiLU doesn't ONNX export, this is a temp hack
        return Swish

    use_jit = not (config.is_exportable() or config.is_no_jit())
    # NOTE: export tracing should work with jit scripted components, but I keep running into issues.
    # The membership check must use the same table that is indexed. A previous version checked
    # `_ACT_FN_JIT` but indexed `_ACT_LAYER_JIT`, which could raise a lookup error or skip
    # valid jit layers.
    if use_jit and name in _ACT_LAYER_JIT:
        # jit scripted models should be okay for export/scripting
        return _ACT_LAYER_JIT[name]

    return _ACT_LAYER_DEFAULT[name]