Example #1
0
    def replace_sub_graph(self, graph: Graph, match: [dict, SubgraphMatch]):
        """Replace the matched Mul / Mul-by-beta pair with one Swish node.

        The new Swish node receives, on port 0, the tensor consumed by both
        Mul nodes and, on port 1, whatever feeds the other (beta) input of
        ``mul_beta``. It then takes over the consumers of ``mul``.
        NOTE(review): the matched pattern is presumably
        ``mul(x, Sigmoid(mul_beta(x, beta)))``; the pattern is defined outside
        this method, so confirm against ``pattern()``.

        If the two Mul nodes do not consume the same tensor, the method
        returns without changing the graph.

        :param graph: graph being transformed; the Swish node is created in it.
        :param match: matched sub-graph providing 'beta', 'mul' and 'mul_beta'.
        """
        beta_node = match['beta']
        outer_mul = match['mul']
        beta_mul = match['mul_beta']
        outer_name = outer_mul.soft_get('name', outer_mul.id)

        # Index of the data input on each Mul: on mul_beta it is the port that
        # is NOT fed by beta; on mul it is the port NOT fed by a Sigmoid.
        beta_mul_first_producer = beta_mul.in_port(0).get_connection().get_source().node
        beta_mul_data_idx = int(beta_mul_first_producer.id == beta_node.id)
        outer_mul_first_producer = outer_mul.in_port(0).get_connection().get_source().node
        outer_mul_data_idx = int(outer_mul_first_producer.soft_get('op') == 'Sigmoid')

        outer_data_src = outer_mul.in_port(outer_mul_data_idx).get_source()
        beta_mul_data_src = beta_mul.in_port(beta_mul_data_idx).get_source()
        if outer_data_src != beta_mul_data_src:
            # The two Muls consume different tensors: not a Swish pattern.
            return

        swish = Swish(graph, {}).create_node()
        swish.in_port(0).connect(beta_mul_data_src)

        # Port 1 of Swish gets the producer of mul_beta's non-data (beta) input.
        beta_src = beta_mul.in_port(1 - beta_mul_data_idx).get_source()
        swish.in_port(1).connect(beta_src)

        # Move mul's consumers onto Swish. Swish takes mul's original name,
        # and mul is renamed with a '/TBR' suffix.
        outer_mul.out_port(0).get_connection().set_source(swish.out_port(0))
        rename_nodes([(outer_mul, outer_name + '/TBR'), (swish, outer_name)])
 def extract(cls, node: Node):
     """Fill ``node`` with Swish operation attributes and report enablement.

     Delegates to ``Swish.update_node_stat`` with an empty attribute dict.
     NOTE(review): the first parameter is ``cls``, so this is presumably a
     ``@classmethod`` in the full file; the decorator is not visible here.

     :param node: node handed to ``Swish.update_node_stat``.
     :return: the value of ``cls.enabled``.
     """
     extra_attrs = {}
     Swish.update_node_stat(node, extra_attrs)
     return cls.enabled