Ejemplo n.º 1
0
    def replace_pattern(graph: Graph, match: dict):
        """Replace the matched node with a Power -> Reshape -> ReduceSum -> Power chain.

        The new sub-graph:
          1. raises the input to ``node.p``;
          2. reshapes it to ``[shape[0], shape[1] // node.group, node.group]``;
          3. sums over the last axis (``keep_dims=False``);
          4. raises the result to ``1 / node.p``.

        The producer of ``node``'s input 0 is moved to the first Power node, and
        all consumers of ``node``'s output 0 are moved to the last Power node.

        :param graph: graph being transformed.
        :param match: pattern match; ``match['op']`` is the node to replace and
            must carry ``p`` and ``group`` attributes.
        :raises AssertionError: if ``shape[1]`` is not divisible by ``node.group``.
        """
        node = match['op']
        # NOTE(review): only shape[0] and shape[1] feed the reshape target, so this
        # presumably expects a rank-2 input — confirm against the matched pattern.
        shape = node.in_port(0).data.get_shape().copy()

        assert shape[1] % node.group == 0, \
            'Dimension 1 ({}) of node "{}" is not divisible by group {}'.format(shape[1], node.id, node.group)

        power_node = Power(graph, attrs={'name': node.id + '_power',
                                         'power': node.p}).create_node()

        # Integer division: the assert above guarantees exact divisibility, and `//`
        # keeps the reshape target integral instead of round-tripping through float.
        reshape_node = create_op_node_with_second_input(graph, Reshape,
                                                        int64_array([shape[0], shape[1] // node.group, node.group]),
                                                        {'name': node.id + '_reshape'})
        reshape_node.in_port(0).connect(power_node.out_port(0))

        # Sum over the trailing `group` axis created by the reshape.
        reducesum_node = create_op_node_with_second_input(graph, ReduceSum,
                                                          int64_array([2]),
                                                          {'name': node.id + '_sum', 'keep_dims': False})
        reducesum_node.in_port(0).connect(reshape_node.out_port(0))

        invpower_node = Power(graph, attrs={'name': node.id + '_invpower',
                                            'power': 1.0 / node.p}).create_node()
        invpower_node.in_port(0).connect(reducesum_node.out_port(0))

        # Splice the chain in place of the original node.
        node.in_port(0).get_connection().set_destination(power_node.in_port(0))
        node.out_port(0).get_connection().set_source(invpower_node.out_port(0))
Ejemplo n.º 2
0
    def replace_op(self, graph: Graph, node: Node):
        """Rebuild ``node`` as ``input0 * input1 ** -1``.

        Input 1 is routed through a Power node (``scale=1``, ``power=-1``,
        ``shift=0``), and the result is multiplied with input 0 by an Eltwise
        ``mul`` node.

        :param graph: graph being transformed.
        :param node: node being replaced.
        :return: single-element list with the id of the Eltwise node. Output
            port 0 is implied, which is equivalent to ``[(mul_id, 0)]``.
        """
        prefix = node.name
        inverse = Power(graph, {'scale': 1, 'power': np.float64(-1), 'shift': 0,
                                'name': prefix + '/reciprocal_'}).create_node()
        product = Eltwise(graph, {'operation': 'mul', 'name': prefix + '/mul_'}).create_node()

        # Input 1 is inverted by the Power node ...
        node.in_port(1).get_connection().set_destination(inverse.in_port(0))
        # ... input 0 becomes the second multiplicand ...
        node.in_port(0).get_connection().set_destination(product.in_port(1))
        # ... and the inverted value is the first multiplicand.
        inverse.out_port(0).connect(product.in_port(0))

        return [product.id]
Ejemplo n.º 3
0
    def replace_op(self, graph: Graph, node: Node):
        """Substitute ``node`` with a Power node created with ``scale=0``.

        :param graph: graph being transformed.
        :param node: node being replaced. Its input 0 producer and output 0
            consumers are moved to the new Power node.
        :return: single-element list with the id of the new Power node.
        """
        zero_scale_power = Power(graph, {'scale': 0, 'name': node.name + '/Power/'}).create_node()

        # Hand over the incoming edge and every outgoing edge to the new node.
        node.in_port(0).get_connection().set_destination(zero_scale_power.in_port(0))
        node.out_port(0).get_connection().set_source(zero_scale_power.out_port(0))
        return [zero_scale_power.id]
    def replace_pattern(graph: Graph, match: dict):
        """Fuse the matched ``mul_add`` and ``pow`` nodes into one Power node.

        The fused node:
          * takes ``shift`` and ``scale`` from ``match['mul_add']``;
          * takes ``power`` from ``match['pow']``;
          * is fed by the producer of ``mul_add``'s input 0;
          * drives every consumer of ``pow``'s output 0.

        The fusion is skipped, with a warning, when any matched node other than
        ``mul_add`` and ``pow_d`` fails ``check_node_usages_out_of_match``.
        NOTE(review): presumably that check detects neighbours outside the
        match — confirm against its definition.

        :param graph: graph being transformed.
        :param match: pattern match mapping pattern names to nodes.
        """
        outside_users = [
            name for name in match
            if name not in ('mul_add', 'pow_d') and not check_node_usages_out_of_match(match, name)
        ]
        if outside_users:
            # Lazy %-style arguments: the message is only formatted if emitted.
            log.warning('Power(mul_add,pow) pattern was detected. Non pattern consumers of nodes: "%s" were found.'
                        ' Won\'t replace', ', '.join([match[name].id for name in outside_users]))
            return

        mul_add = match['mul_add']
        pow_node = match['pow']  # not named `pow`, to avoid shadowing the builtin
        fused_power = Power(
            graph, {
                'name': mul_add.name + '/fused_power',
                'shift': mul_add.shift,
                'scale': mul_add.scale,
                'power': pow_node.power
            }).create_node()

        # Move mul_add's incoming edge onto the fused node ...
        source = mul_add.in_port(0).get_connection().get_source()
        mul_add.in_port(0).disconnect()
        fused_power.in_port(0).connect(source)
        # ... and hand pow's consumers over to the fused node.
        pow_node.out_port(0).get_connection().set_source(fused_power.out_port(0))

        log.debug('Power nodes %s and %s were fused to single Power node %s',
                  mul_add.name, pow_node.name, fused_power.name)
Ejemplo n.º 5
0
    def replace_pattern(self, graph: Graph, match: dict):
        """Replace the node matched as ``merge`` with a PNORM Power node plus an
        exp-based sub-graph, then reshape back to the original input shape.

        For every ``idx`` in ``[pnorm.significant, pnorm.to_significant]``
        (inclusive), one branch computes ``0.99 ** (-idx) * exp(0.99 ** idx * y)``.
        Here ``y`` is the PNORM output reshaped to ``[1] + input_shape``.

        The branches are then processed as follows:
          1. concatenated along axis 0;
          2. reduced with ReduceMean over axis 0 (``keep_dims=True``);
          3. reshaped back to the PNORM input shape.

        All consumers of ``merge``'s output port 0 are reconnected to that final
        Reshape.

        :param graph: graph being transformed.
        :param match: pattern match; ``match['merge']`` is the node replaced.
        """

        # Base for the per-branch coefficients 0.99 ** idx and 0.99 ** (-idx).
        const = 0.99
        merge = match['merge']
        # NOTE(review): significant_digits() is defined elsewhere; it presumably
        # returns the (start, end) pair of the inclusive idx range used below — confirm.
        digits = significant_digits()
        # NOTE(review): the '/reciprocal_' name suffix looks copied from a
        # reciprocal transform; this node is a PNORM-typed Power node.
        pnorm = Power(
            graph, {
                'name': merge.name + '/reciprocal_',
                'type': 'PNORM',
                'significant': digits[0],
                'to_significant': digits[1],
                'scale': 1,
                'shift': 0,
                'power': get_power_attr()
            }).create_node()
        # The producer feeding merge's input 0 now feeds the PNORM node instead.
        merge.in_port(0).get_connection().set_destination(pnorm.in_port(0))

        # Target shape for reshape1: the PNORM input shape with a leading 1, so
        # the branches below can be stacked along axis 0 by the Concat.
        in_shape = pnorm.in_port(0).data.get_shape()
        in_shape = list(in_shape)
        in_shape.insert(0, 1)

        reshape1 = Reshape(graph, {
            'name': merge.name + '/Reshape_Node1'
        }).create_node()
        reshape_dim1 = Const(graph, {
            'value': np.array(in_shape),
            'name': merge.id + '/Reshape_Dim1'
        }).create_node()
        pnorm.out_port(0).connect(reshape1.in_port(0))
        reshape1.in_port(1).connect(reshape_dim1.out_port(0))

        concat_node = Concat(
            graph, {
                'axis': 0,
                'name': merge.name + '/Concat_',
                'override_output_shape': True
            }).create_node()
        # Constant 0 fed into input 1 of the ReduceMean created after the loop
        # (its reduction axis, matching that node's 'axis': 0 attribute).
        const3 = Const(graph, {
            'name': merge.name + '/const_reduce',
            'value': 0
        }).create_node()

        # One branch per idx in the inclusive range [significant, to_significant]:
        #   0.99 ** (-idx) * exp(0.99 ** idx * reshape1_out)
        # ii is the 0-based branch number and also the Concat input port index.
        for ii, idx in enumerate(
                range(pnorm.significant, pnorm.to_significant + 1, 1)):
            const_node = Const(
                graph, {
                    'value': float_array(math.pow(const, idx)),
                    'name': merge.name + '/Const_' + ii.__str__()
                }).create_node()

            mul_node = Mul(graph, {
                'name': merge.name + '/Mul_' + ii.__str__()
            }).create_node()
            const_node.out_port(0).connect(mul_node.in_port(0))

            reshape1.out_port(0).connect(
                mul_node.in_port(1))  # connect to the graph node
            mul_node2 = Mul(graph, {
                'name': merge.name + '/Mul_Div_' + ii.__str__()
            }).create_node()

            const_node2 = Const(
                graph, {
                    'value': float_array(math.pow(const, -1 * idx)),
                    'name': merge.name + '/Const_Pow_' + ii.__str__()
                }).create_node()
            # Despite the variable name, this is an Exp op, not a cast.
            # NOTE(review): its name suffix uses idx while sibling nodes use ii.
            cast_node = ExpOp(graph, {
                'name': merge.name + '/Exp_' + idx.__str__()
            }).create_node()

            mul_node.out_port(0).connect(cast_node.in_port(0))
            const_node2.out_port(0).connect(mul_node2.in_port(1))
            cast_node.out_port(0).connect(mul_node2.in_port(0))
            # Create Concat input port ii if it does not exist yet, then feed this branch into it.
            concat_node.add_input_port(ii, skip_if_exist=True)
            concat_node.in_port(ii).get_connection().set_source(
                mul_node2.out_port(0))

        # PNORM input shape without the leading 1. It is used as the reduce
        # node's declared shape and as the final reshape target.
        in_shape = pnorm.in_port(0).data.get_shape()
        in_shape = list(in_shape)

        # NOTE(review): despite the '_pnorm_reduced_sum' name, this is a ReduceMean
        # over axis 0 (the stacked-branch axis), keeping that axis with size 1.
        # It uses an explicitly supplied infer function (reduce_infer).
        reducesum_node = ReduceMean(
            graph, {
                'name': merge.id + '/_pnorm_reduced_sum',
                'keep_dims': True,
                'in_ports_count': 2,
                'shape': in_shape,
                'axis': 0,
                'need_shape_inference': None,
                'infer': reduce_infer
            }).create_node()

        const3.out_port(0).connect(reducesum_node.in_port(1))
        reducesum_node.in_port(0).get_connection().set_source(
            concat_node.out_port(0))

        # Reshape back to the PNORM input shape and hand merge's consumers over to it.
        reshape = Reshape(graph, {
            'name': merge.name + '/Reshape_Node'
        }).create_node()
        reshape_dim = Const(graph, {
            'value': np.array(in_shape),
            'name': merge.id + '/Reshape_Dim'
        }).create_node()
        reducesum_node.out_port(0).connect(reshape.in_port(0))
        reshape.in_port(1).connect(reshape_dim.out_port(0))
        merge.out_port(0).get_connection().set_source(reshape.out_port(0))