def is_node_supported(self, submodules: Dict[str, torch.nn.Module], node: torch.fx.Node):
    """
    Decide whether ``node`` is allowed to run on TensorRT.

    Linear layers are deliberately kept off TensorRT: any node whose
    resolved target is ``"torch.nn.modules.linear.Linear"`` is rejected,
    and every other node is accepted.

    Args:
        submodules: Name-to-module mapping, forwarded unchanged to
            ``op_support.get_node_target``.
        node: The FX node under consideration.

    Returns:
        ``False`` for a Linear-layer target, ``True`` for anything else.
    """
    excluded_target = "torch.nn.modules.linear.Linear"
    return op_support.get_node_target(submodules, node) != excluded_target
def is_node_supported(self, submodules, node):
    """
    Accept every node except those whose target resolves to ``acc_ops.sub``.

    Args:
        submodules: Forwarded unchanged to ``op_support.get_node_target``.
        node: The FX node under consideration.

    Returns:
        ``False`` when the resolved target is ``"acc_ops.sub"``,
        ``True`` otherwise.
    """
    resolved = op_support.get_node_target(submodules, node)
    if resolved == "acc_ops.sub":
        return False
    return True