Esempio n. 1
0
    def replace_pattern(self, graph: Graph, match: dict):
        """Decompose a FusedBatchNorm whose mean/variance are not constants.

        The rewrite applies only when all of these hold:
        - the matched node has data_format b'NHWC';
        - it has exactly five inputs (input, scale, offset, mean, variance);
        - the input, mean and variance carry no constant value;
        - scale and offset are constant 1-D arrays.
        Otherwise the graph is left untouched.

        The node is replaced by the sub-graph

            ((x + (-mean)) * (variance + eps) ** -0.5) * scale + offset

        built from Power and Eltwise ops. This assumes Power(scale=s,
        shift=t, power=p) evaluates (t + s * x) ** p with defaults s=1,
        t=0, p=1; confirm against the Power op definition. The final result
        is written into the original node's output data node, and the
        FusedBatchNorm op node is then removed from the graph.

        :param graph: graph being transformed.
        :param match: pattern match; the FusedBatchNorm op is under 'op'.
        """
        node = match['op']
        # Only the NHWC case with runtime statistics and constant 1-D
        # scale/offset is handled here; anything else is left as is.
        if (node.data_format != b'NHWC' or len(node.in_nodes()) != 5
                or node.in_node(0).value is not None  # input
                or node.in_node(1).value is None  # scale
                or node.in_node(2).value is None  # offset
                or node.in_node(3).value is not None  # mean
                or node.in_node(4).value is not None  # variance
                or node.in_node(1).value.ndim != 1
                or node.in_node(2).value.ndim != 1):
            return

        scale_mul = Eltwise(
            graph, dict(operation='mul', name=node.name + '/scale_mul_'))
        shift_add = Eltwise(
            graph, dict(operation='sum', name=node.name + '/shift_add_'))
        mean_add = Eltwise(
            graph, dict(operation='sum', name=node.name + '/mean_add_'))
        variance_mul = Eltwise(
            graph, dict(operation='mul', name=node.name + '/variance_mul_'))

        # x - mean, expressed as x + (-1 * mean) since only 'sum' is used.
        mean_negate = Power(graph,
                            dict(scale=-1, name=node.name + '/mean_negate_'))
        mean_arg = mean_add.create_node_with_data([
            node.in_node(0),
            mean_negate.create_node_with_data([node.in_node(3)])
        ])

        # (x - mean) * (variance + eps) ** -0.5; a single Power op produces
        # the reciprocal square root, so no separate squaring op is needed.
        variance_denom = Power(
            graph,
            dict(shift=node.eps,
                 power=-0.5,
                 name=node.name + '/variance_denom_'))
        variance_arg = variance_mul.create_node_with_data([
            mean_arg,
            variance_denom.create_node_with_data([node.in_node(4)])
        ])

        # ... * scale + offset, written into the original output data node
        # so downstream consumers stay connected.
        shift_add.create_node_with_data(
            [
                scale_mul.create_node_with_data([variance_arg,
                                                 node.in_node(1)]),
                node.in_node(2),
            ],
            data_nodes=node.out_node())

        node.graph.remove_node(node.id)
    def replace_pattern(self, graph: Graph, match: dict):
        """Rewrite Minimum(a, b) as -Max(-a, -b).

        When both inputs already carry constant values, nothing is done,
        because the node is left for constant propagation. Otherwise each
        input is negated with a Power(scale=-1) op. The negated inputs feed
        an Eltwise 'max' op, and that result is negated again into the
        original output data node. The Minimum op node is then removed.

        :param graph: graph being transformed.
        :param match: pattern match; the Minimum op is under 'minimum'.
        """
        minimum_node = match['minimum']
        lhs = minimum_node.in_node(0)

        # Both operands constant: leave it for constant propagation.
        if lhs.value is not None:
            if minimum_node.in_node(1).value is not None:
                return

        prefix = minimum_node.name
        neg_lhs_op = Power(graph, dict(scale=-1, name=prefix + '/negate1_'))
        neg_rhs_op = Power(graph, dict(scale=-1, name=prefix + '/negate2_'))
        max_op = Eltwise(graph, dict(operation='max', name=prefix + '/Max_'))
        neg_out_op = Power(graph, dict(scale=-1, name=prefix + '/negate_out_'))

        # min(a, b) == -max(-a, -b)
        neg_lhs = neg_lhs_op.create_node_with_data([minimum_node.in_node(0)])
        neg_rhs = neg_rhs_op.create_node_with_data([minimum_node.in_node(1)])
        max_of_negated = max_op.create_node_with_data([neg_lhs, neg_rhs])
        neg_out_op.create_node_with_data(
            inputs=[max_of_negated],
            data_nodes=minimum_node.out_node())

        # The original Minimum op is now fully bypassed.
        minimum_node.graph.remove_node(minimum_node.id)