Exemplo n.º 1
0
    def __init__(self,
                 module: torch.fx.GraphModule,
                 sample_input: Tuple[torch.Tensor],
                 operator_support: op_support.OperatorSupport = None,
                 settings: splitter_base._SplitterSettingBase = None):
        """Initialise the splitter, substituting defaults for missing inputs.

        Args:
            module: The ``torch.fx.GraphModule`` to be split.
            sample_input: Example inputs, forwarded unchanged to the base
                class constructor.
            operator_support: Object deciding which nodes are supported. When
                falsy (e.g. ``None``), a fresh ``op_support.OperatorSupport()``
                is used instead.
            settings: Splitter settings. When falsy, a default
                ``splitter_base._SplitterSettingBase`` is created with
                ``allow_non_tensor`` and ``skip_fusion`` switched on.
                Caller-supplied settings are passed through untouched.
        """
        # `or` reproduces the truthiness test of `if not operator_support`.
        operator_support = operator_support or op_support.OperatorSupport()

        if not settings:
            defaults = splitter_base._SplitterSettingBase()
            # These overrides apply only to the auto-created settings object,
            # never to settings the caller provided.
            defaults.allow_non_tensor = True
            defaults.skip_fusion = True
            settings = defaults

        super().__init__(module, sample_input, operator_support, settings)
Exemplo n.º 2
0
def create_trt_operator_support() -> ops.OperatorSupportBase:
    """Build the ``OperatorSupportBase`` used when splitting a graph for TRT.

    The returned support object is an ``ops.chain`` of two checks, applied in
    this order:

    1. ``ops.OpSupports.decline_if_input_dtype(torch.int64)``: declines nodes
       with int64-typed inputs.
    2. An ``ops.OperatorSupport`` whose ``support_dict`` maps the acc-op name
       (via ``get_acc_ops_name``) of every key in ``CONVERTERS`` to ``None``.
       This means "supported if a TRT converter is registered".
    """
    # One entry per registered converter target; the value is left as None.
    converter_op_names = {
        get_acc_ops_name(target): None for target in CONVERTERS
    }
    supported_if_converter_registered = ops.OperatorSupport(
        support_dict=converter_op_names)

    int64_input_guard = ops.OpSupports.decline_if_input_dtype(torch.int64)
    return ops.chain(int64_input_guard, supported_if_converter_registered)
Exemplo n.º 3
0
def create_trt_operator_support(
        use_implicit_batch_dim=True) -> ops.OperatorSupportBase:
    """Build the ``OperatorSupportBase`` used when splitting a graph for TRT.

    Args:
        use_implicit_batch_dim: Selects which exclusion table filters the
            registered converters. When truthy, converter targets listed in
            ``NO_IMPLICIT_BATCH_DIM_SUPPORT`` are dropped. Otherwise, targets
            listed in ``NO_EXPLICIT_BATCH_DIM_SUPPORT`` are dropped.

    Returns:
        An ``ops.chain`` that first applies
        ``ops.OpSupports.decline_if_input_dtype(torch.int64)``. It then
        applies an ``ops.OperatorSupport`` built from the acc-op names
        (via ``get_acc_ops_name``) of the remaining ``CONVERTERS`` keys.
    """
    # Choose once which "not supported in this batch-dim mode" table applies.
    if use_implicit_batch_dim:
        excluded_targets = NO_IMPLICIT_BATCH_DIM_SUPPORT
    else:
        excluded_targets = NO_EXPLICIT_BATCH_DIM_SUPPORT

    support_dict: Dict[str, None] = {
        get_acc_ops_name(target): None
        for target in CONVERTERS
        if target not in excluded_targets
    }
    supported_if_converter_registered = ops.OperatorSupport(
        support_dict=support_dict)

    int64_input_guard = ops.OpSupports.decline_if_input_dtype(torch.int64)
    return ops.chain(int64_input_guard, supported_if_converter_registered)
Exemplo n.º 4
0
def op_support_with_support_dict(
        support_dict: dict) -> op_support.OperatorSupportBase:
    """Wrap ``support_dict`` in a new ``op_support.OperatorSupport``.

    ``support_dict`` is passed positionally to the ``OperatorSupport``
    constructor, and the resulting instance is returned as-is.
    """
    wrapped = op_support.OperatorSupport(support_dict)
    return wrapped