def replace_sub_graph(self, graph: Graph, match: [dict, SubgraphMatch]):
    """Replace the matched Mul / MulBeta sub-graph with a single Swish node.

    The match must provide the nodes ``beta``, ``mul`` and ``mul_beta``.
    The replacement is performed only when ``mul`` and ``mul_beta`` consume
    the very same source port as their data input; otherwise the graph is
    left untouched.

    :param graph: graph in which the new Swish node is created
    :param match: mapping with the matched ``beta``, ``mul`` and ``mul_beta`` nodes
    :return: None
    """
    beta_node = match['beta']
    out_mul = match['mul']
    beta_mul = match['mul_beta']
    original_name = out_mul.soft_get('name', out_mul.id)

    # Each Mul has two inputs. The data input is the one that is NOT fed by
    # beta (for MulBeta) and NOT fed by the Sigmoid (for Mul). If port 0 is
    # occupied by beta / Sigmoid, the data input sits on port 1, else port 0.
    beta_mul_first_producer = beta_mul.in_port(0).get_connection().get_source().node
    beta_mul_data_idx = 1 if beta_mul_first_producer.id == beta_node.id else 0

    out_mul_first_producer = out_mul.in_port(0).get_connection().get_source().node
    out_mul_data_idx = 1 if out_mul_first_producer.soft_get('op') == 'Sigmoid' else 0

    # Both Muls must read their data input from the same tensor; bail out
    # without modifying the graph when they do not.
    out_mul_data_src = out_mul.in_port(out_mul_data_idx).get_source()
    beta_mul_data_src = beta_mul.in_port(beta_mul_data_idx).get_source()
    if out_mul_data_src != beta_mul_data_src:
        return

    swish_node = Swish(graph, {}).create_node()

    # Input 0 of Swish receives the shared data tensor.
    swish_node.in_port(0).connect(beta_mul.in_port(beta_mul_data_idx).get_source())

    # Input 1 of Swish receives the beta value, i.e. the other MulBeta input.
    beta_port_idx = 1 - beta_mul_data_idx
    swish_node.in_port(1).connect(beta_mul.in_port(beta_port_idx).get_source())

    # Consumers of the original Mul output are re-pointed to the Swish output.
    out_mul.out_port(0).get_connection().set_source(swish_node.out_port(0))

    # Swish takes over the original Mul name; the old Mul gets a '/TBR' suffix.
    rename_nodes([
        (out_mul, original_name + '/TBR'),
        (swish_node, original_name),
    ])
def extract(cls, node: Node):
    """Update ``node`` through ``Swish.update_node_stat`` with no extra attributes.

    :param node: node passed on to ``Swish.update_node_stat``
    :return: the value of ``cls.enabled``
    """
    extra_attrs = {}  # no additional attributes are supplied
    Swish.update_node_stat(node, extra_attrs)
    return cls.enabled