def replace_sub_graph(self, graph: Graph, match: dict):
    """Replace the matched ``'op'`` node with a newly created Identity node.

    The Identity node takes the original node's ``name`` attribute (falling
    back to its ``id`` when no name is set). The connection feeding input
    port 0 of the original node is moved onto the Identity's input port 0.
    Every output connection of the original node is re-sourced from the
    Identity's output port 0.

    Args:
        graph: Graph in which the Identity node is created.
        match: Pattern-match result; must contain the key ``'op'``.
    """
    node = match['op']
    identity = IdentityOp(graph, {'name': node.soft_get('name', node.id)}).create_node()
    node.in_port(0).get_connection().set_destination(identity.in_port(0))
    # Port indices are irrelevant here: every consumer is fed from the
    # Identity's single output port 0.
    # NOTE(review): the original node itself is not removed in this method;
    # presumably a later cleanup pass drops it once disconnected -- confirm.
    for port in node.out_ports().values():
        port.get_connection().set_source(identity.out_port(0))
def extract(cls, node):
    """Convert an ONNX Dropout node into an Identity operation.

    Some Dropout flavors do not carry an ``is_test`` attribute. When it is
    missing, it is interpreted as 1 (test/inference mode).

    Returns:
        The ``enabled`` flag of the extractor class.

    Raises:
        Error: If ``node.out_nodes()`` has more than one entry, or if
            ``is_test`` is 0 (training mode).
    """
    inference_mode = onnx_attr(node, 'is_test', 'i', 1)

    has_several_outputs = len(node.out_nodes()) > 1
    if has_several_outputs:
        raise Error(
            'Dropout node {} has more than one consumer. Unsupported.',
            node.name)

    if not inference_mode:
        raise Error(
            'Dropout node {} has is_test: 0. This means training mode which is not supported.',
            node.name)

    IdentityOp.update_node_stat(node)
    return cls.enabled
def extract(cls, node):
    """Convert ``node`` into an Identity operation with no extra attributes.

    Returns:
        The ``enabled`` flag of the extractor class.
    """
    extra_attrs = {}
    IdentityOp.update_node_stat(node, extra_attrs)
    return cls.enabled
def extract(node: Node):
    """Convert ``node`` into an Identity operation with no extra attributes.

    Returns:
        The ``enabled`` flag of the enclosing class, read through the
        implicit ``__class__`` cell (this function takes no ``cls``).
    """
    no_extra_attrs = {}
    IdentityOp.update_node_stat(node, no_extra_attrs)
    return __class__.enabled
def extract(cls, node: Node):
    """Convert a StopGradient node into an Identity operation.

    The attribute override passed along sets ``op`` to ``'StopGradient'``.

    Returns:
        The ``enabled`` flag of the extractor class.
    """
    overrides = {'op': 'StopGradient'}
    IdentityOp.update_node_stat(node, overrides)
    return cls.enabled
def extract(cls, node: Node):
    """Convert a TensorFlow node into an Identity operation with a data type.

    The ``data_type`` attribute is derived from the protobuf attribute ``T``
    of the node via ``tf_dtype_extractor``.

    Returns:
        The ``enabled`` flag of the extractor class.
    """
    tf_type_enum = node.pb.attr["T"].type
    attrs = {'data_type': tf_dtype_extractor(tf_type_enum)}
    IdentityOp.update_node_stat(node, attrs)
    return cls.enabled