def __init__(
    self,
    module: torch.fx.GraphModule,
    sample_input: Tuple[torch.Tensor],
    operator_support: op_support.OperatorSupport = None,
    settings: splitter_base._SplitterSettingBase = None,
):
    """Initialise the splitter with non-tensor support and fusion skipping forced on.

    Args:
        module: The ``torch.fx.GraphModule`` to be split; forwarded unchanged
            to the base-class constructor.
        sample_input: Example inputs for ``module``; forwarded unchanged.
        operator_support: Policy deciding which nodes are supported. When
            falsy (e.g. ``None``), a fresh ``op_support.OperatorSupport()`` is
            used instead.
        settings: Splitter settings. When falsy (e.g. ``None``), a fresh
            ``splitter_base._SplitterSettingBase()`` is used instead.

    Note:
        ``allow_non_tensor`` and ``skip_fusion`` are assigned ``True`` on the
        settings object in place. If the caller passes its own ``settings``
        instance, that instance is modified by this constructor.
    """
    # `a or b` applies the same truthiness test as `if not a: a = b`.
    operator_support = operator_support or op_support.OperatorSupport()
    settings = settings or splitter_base._SplitterSettingBase()

    # Forced unconditionally, overriding whatever the caller configured.
    settings.allow_non_tensor = True
    settings.skip_fusion = True

    super().__init__(module, sample_input, operator_support, settings)
def create_trt_operator_support() -> ops.OperatorSupportBase:
    """Build the operator-support policy used when splitting a graph for TRT.

    The returned policy combines two checks via ``ops.chain``:

    1. A node is not supported if it has args with ``torch.int64`` dtype.
    2. A node is supported if a TRT converter is registered for it, i.e. its
       acc-ops name (``get_acc_ops_name``) derives from a key of ``CONVERTERS``.

    Returns:
        The chained ``OperatorSupportBase`` instance.
    """
    # Every registered converter's acc-ops name maps to None, in CONVERTERS order.
    registered_names = dict.fromkeys(map(get_acc_ops_name, CONVERTERS.keys()))
    accept_if_converter_registered = ops.OperatorSupport(
        support_dict=registered_names
    )
    reject_int64_inputs = ops.OpSupports.decline_if_input_dtype(torch.int64)
    return ops.chain(reject_int64_inputs, accept_if_converter_registered)
def create_trt_operator_support(
    use_implicit_batch_dim=True,
) -> ops.OperatorSupportBase:
    """Build the operator-support policy used when splitting a graph for TRT.

    Only converters that are not listed in the exclusion table for the
    requested batch-dimension mode are treated as supported.

    Args:
        use_implicit_batch_dim: When truthy, converters whose keys appear in
            ``NO_IMPLICIT_BATCH_DIM_SUPPORT`` are excluded; otherwise those
            whose keys appear in ``NO_EXPLICIT_BATCH_DIM_SUPPORT`` are excluded.

    Returns:
        An ``ops.chain`` of two checks: first, a node is not supported if it
        has args with ``torch.int64`` dtype; second, a node is supported if
        its acc-ops name is in the filtered converter set.
    """
    # Pick the exclusion table for this mode once, rather than per converter.
    excluded = (
        NO_IMPLICIT_BATCH_DIM_SUPPORT
        if use_implicit_batch_dim
        else NO_EXPLICIT_BATCH_DIM_SUPPORT
    ).keys()

    support_dict: Dict[str, None] = {
        get_acc_ops_name(target): None
        for target in CONVERTERS.keys()
        if target not in excluded
    }

    accept_if_converter_registered = ops.OperatorSupport(support_dict=support_dict)
    reject_int64_inputs = ops.OpSupports.decline_if_input_dtype(torch.int64)
    return ops.chain(reject_int64_inputs, accept_if_converter_registered)
def op_support_with_support_dict(
    support_dict: dict,
) -> op_support.OperatorSupportBase:
    """Wrap ``support_dict`` in a new ``op_support.OperatorSupport``.

    Args:
        support_dict: Passed positionally and unchanged to the
            ``op_support.OperatorSupport`` constructor.

    Returns:
        The newly constructed ``OperatorSupport`` instance.
    """
    support = op_support.OperatorSupport(support_dict)
    return support