Exemplo n.º 1
0
    def replace_op(self, graph: Graph, node: Node):
        """Replace the matched node with ``(in0 + (-1 * in1)) ** 2``.

        The sub-graph is built from primitive ops:
        ``negate`` (Power, scale=-1) on input 1, ``add`` (Eltwise sum) of
        input 0 and the negated input 1, and ``squared`` (Power, power=2)
        on the sum.

        :param graph: graph the new nodes are created in.
        :param node: the matched node being replaced; its input 0 and
            input 1 feed the new sub-graph.
        :return: list with the id of the node whose out port 0 replaces
            out port 0 of the matched node.
        """
        negate = Power(graph, dict(scale=-1, name=node.name + '/negate_'))
        add = Eltwise(graph, dict(operation='sum', name=node.name + '/add_'))
        squared = Power(graph, dict(power=2, name=node.name + '/squared_'))

        # Forward the producer's output port together with the producer node
        # (same convention as the Div replacer). A bare Node in the
        # create_node() input list presumably means out port 0, which would
        # mis-wire inputs coming from a non-zero output port -- confirm
        # against Op.create_node.
        minuend = (node.in_node(0), node.in_edge(0)['out'])
        subtrahend = (node.in_node(1), node.in_edge(1)['out'])

        out_node = squared.create_node([
            add.create_node([minuend, negate.create_node([subtrahend])])
        ])
        # Replace the edge from out port 0 of the matched node with an edge
        # from node out_node.id with port 0.
        # The "explicit" version of the return value is: [(out_node.id, 0)]
        return [out_node.id]
Exemplo n.º 2
0
    def replace_op(self, graph: nx.MultiDiGraph, node: Node):
        """Replace the matched node with a reciprocal of its first input.

        A single Power node (scale=1, power=-1, shift=0) named
        ``<node.name>/power_`` is created and fed from input 0 of the
        matched node.

        :param graph: graph the new node is created in.
        :param node: the matched node being replaced.
        :return: list with the id of the node whose out port 0 replaces
            out port 0 of the matched node.
        """
        # NOTE(review): the exponent is the Python int -1. The Div replacer
        # elsewhere uses np.float64(-1) for the same reciprocal. If integer
        # tensors can reach constant folding, numpy rejects integer bases
        # raised to negative integer powers -- confirm whether a float
        # exponent is needed here.
        reciprocal = Power(
            graph, dict(scale=1, power=-1, shift=0,
                        name=node.name + '/power_'))

        # Keep the producer's output port: a bare Node input presumably
        # defaults to out port 0 (confirm against Op.create_node).
        source = (node.in_node(0), node.in_edge(0)['out'])
        out_node = reciprocal.create_node([source])

        return [out_node.id]
Exemplo n.º 3
0
 def _create_sub(graph: Graph, input_1: Node, port_1: int, input_2: Node,
                 port_2: int):
     """Build ``input_1[port_1] + (-1 * input_2[port_2])`` in ``graph``.

     The subtrahend goes through a Power node (scale=-1) named
     ``<input_2.name>/negate_``. The result is summed with the minuend by an
     Eltwise 'sum' node named ``<input_1.name>/add_``.

     :param graph: graph the new nodes are created in.
     :param input_1: node producing the minuend.
     :param port_1: output port of ``input_1`` to read.
     :param input_2: node producing the subtrahend.
     :param port_2: output port of ``input_2`` to read.
     :return: the created Eltwise node holding the difference.
     """
     negate_op = Power(graph, {'scale': -1,
                               'name': input_2.name + '/negate_'})
     sum_op = Eltwise(graph, {'operation': 'sum',
                              'name': input_1.name + '/add_'})

     minuend = (input_1, port_1)
     # The negation node is created before the sum node that consumes it.
     negated_subtrahend = negate_op.create_node([(input_2, port_2)])
     return sum_op.create_node([minuend, negated_subtrahend])
Exemplo n.º 4
0
    def replace_op(self, graph: nx.MultiDiGraph, node: Node):
        """Rewrite ``in0 / in1`` as ``in0 * in1 ** -1``.

        A Power node (scale=1, power=np.float64(-1), shift=0) named
        ``<node.name>/reciprocal_`` inverts input 1. An Eltwise 'mul' node
        named ``<node.name>/mul_`` multiplies input 0 by that reciprocal.
        Both inputs are wired from the exact output port recorded on the
        matched node's incoming edges.

        :param graph: graph the new nodes are created in.
        :param node: the matched node being replaced.
        :return: list with the id of the node whose out port 0 replaces
            out port 0 of the matched node.
        """
        reciprocal_op = Power(graph, {
            'scale': 1,
            'power': np.float64(-1),
            'shift': 0,
            'name': node.name + '/reciprocal_',
        })
        mul_op = Eltwise(graph, {'operation': 'mul',
                                 'name': node.name + '/mul_'})

        # (producer node, producer out port) pairs for both operands.
        dividend = (node.in_node(0), node.in_edge(0)['out'])
        divisor = (node.in_node(1), node.in_edge(1)['out'])

        inverted_divisor = reciprocal_op.create_node([divisor])
        product = mul_op.create_node([dividend, inverted_divisor])

        # The matched node's out port 0 is replaced by out port 0 of
        # `product`; the explicit form would be [(product.id, 0)].
        return [product.id]