Beispiel #1
0
 def _create_sub(graph: Graph, input_1: Node, port_1: int, input_2: Node,
                 port_2: int):
     """Build ``input_1 - input_2`` out of a negation and an addition.

     The second operand goes through a Power op configured with ``scale=-1``
     (named ``<input_2>/negate_``). Its result is then added to the first
     operand by an Eltwise ``sum`` op (named ``<input_1>/add_``).

     :param graph: graph in which the new nodes are created.
     :param input_1: producer of the left-hand operand.
     :param port_1: output port of ``input_1`` to connect.
     :param input_2: producer of the right-hand operand.
     :param port_2: output port of ``input_2`` to connect.
     :return: the newly created Eltwise ``sum`` node.
     """
     negate_op = Power(graph, {'scale': -1, 'name': input_2.name + '/negate_'})
     sum_op = Eltwise(graph, {'operation': 'sum', 'name': input_1.name + '/add_'})
     # The negation node must exist before it can be fed into the sum node.
     negated_rhs = negate_op.create_node([(input_2, port_2)])
     difference = sum_op.create_node([(input_1, port_1), negated_rhs])
     return difference
Beispiel #2
0
 def replace_op(self, graph: nx.MultiDiGraph, node: Node):
     """Decompose the matched node into ``in0 + (-1 * in1)``.

     Input 1 is routed through a Power op with ``scale=-1``
     (``<node>/negate_``). The result is summed with input 0 by an Eltwise
     ``sum`` op (``<node>/add_``). Both inputs keep the output port they
     originally fed ``node`` from.

     :param graph: graph being transformed.
     :param node: the matched node; its inputs 0 and 1 are consumed.
     :return: single-element list holding the id of the replacement node.
     """
     negate_op = Power(graph, {'scale': -1, 'name': node.name + '/negate_'})
     sum_op = Eltwise(graph, {'operation': 'sum', 'name': node.name + '/add_'})

     lhs = (node.in_node(0), node.in_edge(0)['out'])
     rhs = (node.in_node(1), node.in_edge(1)['out'])
     negated_rhs = negate_op.create_node([rhs])
     result = sum_op.create_node([lhs, negated_rhs])
     # A bare id means: consumers of output port 0 of the matched node are
     # re-attached to output port 0 of `result` (explicit form:
     # [(result.id, 0)]).
     return [result.id]
Beispiel #3
0
 def replace_op(self, graph: Graph, node: Node):
     """Decompose an N-input eltwise node into a chain of binary Eltwise ops.

     For inputs ``in0 .. in{k-1}`` this builds
     ``op(...op(op(in0, in1), in2)..., in{k-1})`` where ``op`` is
     ``node.operation``. The i-th intermediate node is named
     ``<node>/<operation>_<i>``.

     Every original input is connected from the same output port it fed
     ``node`` from. Intermediate results are taken from output port 0 of
     the freshly created Eltwise nodes.

     :param graph: graph being transformed.
     :param node: the matched node; all of its inputs are consumed.
     :return: single-element list holding the id of the last node in the
         chain. With only one input, that is the input's producer itself.
     """
     operation = node.operation
     # Carry (producer, output_port) pairs instead of bare Nodes.
     # NOTE(review): a bare Node in create_node's input list appears to be
     # wired from output port 0, which would mis-wire multi-output producers
     # -- confirm against Op.create_node.
     accumulated = (node.in_node(0), node.in_edge(0)['out'])
     out_node = accumulated[0]
     for ind in range(1, len(node.in_nodes())):
         eltwise_op = Eltwise(
             graph,
             dict(operation=operation,
                  name=node.name + '/' + operation + '_' + str(ind)))
         out_node = eltwise_op.create_node(
             [accumulated, (node.in_node(ind), node.in_edge(ind)['out'])])
         # The partial result lives on port 0 of the node just created.
         accumulated = (out_node, 0)
     return [out_node.id]
Beispiel #4
0
    def replace_sub_graph(self, graph: Graph, match: dict):
        """Fuse a matched normalization pattern into ``MVN(x) * gamma + beta``.

        ``match`` must contain the nodes ``'fbn'``, ``'mean'``, ``'sqdiff'``
        and ``'variance'``. The rewrite only happens when the data input
        (port 0) of ``fbn`` is the same node that feeds port 0 of both
        ``mean`` and ``sqdiff``. Otherwise the method returns without
        modifying the graph.

        On success, an MVN node is built from:

        * the data input of ``fbn``,
        * the port-1 inputs of ``mean`` and ``variance``.

        Its result is multiplied by ``fbn``'s port-1 input (gamma) and then
        added to its port-2 input (beta). ``fbn`` is replaced by the
        resulting ``sum`` node.

        NOTE(review): ``fbn`` is presumably a TF FusedBatchNorm node, since
        its ``data_format`` is compared against ``b'NHWC'`` -- confirm.
        """
        fbn = match['fbn']
        # Named `data_node` rather than `input` to avoid shadowing the builtin.
        data_node = fbn.in_node(0)
        log.debug('Found potential MVN pattern after {} with name {}'.format(
            data_node.op, data_node.name))
        # Fuse only when mean and sqdiff consume the same tensor as fbn.
        if data_node.id != match['mean'].in_node(
                0).id or data_node.id != match['sqdiff'].in_node(0).id:
            return

        log.debug('Confirmed MVN pattern after {} with name {}'.format(
            data_node.op, data_node.name))
        mvn_op_class = Op.get_op_class_by_name('MVN')

        mvn = mvn_op_class(
            graph,
            dict(name=fbn.name + '/MVN_',
                 eps=fbn.eps,
                 # Axes 1, 2 are H, W in NHWC; otherwise the layout is
                 # assumed to be NCHW with H, W at 2, 3.
                 required_reduction_indices=[1, 2]
                 if fbn.data_format == b'NHWC' else [2, 3]))
        # Route shape inference through this class's `infer`, keeping MVN's
        # own implementation available under 'old_infer'.
        mvn.attrs['old_infer'] = mvn.attrs['infer']
        mvn.attrs['infer'] = __class__.infer

        mul = Eltwise(graph, dict(operation='mul', name=fbn.name + '/Mul_'))
        add = Eltwise(graph, dict(operation='sum', name=fbn.name + '/Add_'))

        input_gamma = fbn.in_node(1)
        input_beta = fbn.in_node(2)

        mean_reduction = match['mean'].in_node(1)
        variance_reduction = match['variance'].in_node(1)

        # Nodes are created innermost-first: MVN, then Mul, then Add.
        # NOTE(review): these inputs are passed as bare Nodes, i.e. wired
        # from output port 0 -- confirm that is correct for multi-output
        # producers of `data_node`.
        mvn_node = mvn.create_node(
            [data_node, mean_reduction, variance_reduction])
        scaled = mul.create_node([mvn_node, input_gamma])
        new_subgraph = add.create_node([scaled, input_beta])
        fbn.replace_node(new_subgraph)
Beispiel #5
0
    def replace_op(self, graph: nx.MultiDiGraph, node: Node):
        """Decompose the matched node into ``in0 * reciprocal(in1)``.

        Input 1 is routed through a Power op configured with ``scale=1``,
        ``power=-1`` and ``shift=0`` (``<node>/reciprocal_``). The result is
        multiplied with input 0 by an Eltwise ``mul`` op (``<node>/mul_``).
        Both inputs keep the output port they originally fed ``node`` from.

        :param graph: graph being transformed.
        :param node: the matched node; its inputs 0 and 1 are consumed.
        :return: single-element list holding the id of the replacement node.
        """
        reciprocal_attrs = {
            'scale': 1,
            'power': np.float64(-1),
            'shift': 0,
            'name': node.name + '/reciprocal_',
        }
        reciprocal_op = Power(graph, reciprocal_attrs)
        mul_op = Eltwise(graph, {'operation': 'mul', 'name': node.name + '/mul_'})

        dividend = (node.in_node(0), node.in_edge(0)['out'])
        divisor = (node.in_node(1), node.in_edge(1)['out'])
        inverted_divisor = reciprocal_op.create_node([divisor])
        product = mul_op.create_node([dividend, inverted_divisor])
        # A bare id means: consumers of output port 0 of the matched node are
        # re-attached to output port 0 of `product` (explicit form:
        # [(product.id, 0)]).
        return [product.id]