Exemplo n.º 1
0
 def extract(cls, node: Node):
     """Populate ReduceSum attributes on ``node`` from its attributes.

     Reads ``axes`` (an ``ints`` field, converted with ``int64_array``) and
     ``keepdims`` (an ``i`` field) through ``onnx_attr``. ``None`` is passed
     as the default for ``axes`` and ``True`` as the default for
     ``keepdims``. The values are stored as ``axis`` and ``keep_dims``.

     :param node: node whose attributes are read and then updated.
     :return: ``cls.enabled``.
     """
     # Dict values are evaluated in order, so 'axes' is still read before
     # 'keepdims'.
     reduce_attrs = {
         'axis': onnx_attr(node, 'axes', 'ints', default=None,
                           dst_type=int64_array),
         'keep_dims': onnx_attr(node, 'keepdims', 'i', default=True),
     }
     ReduceSum.update_node_stat(node, reduce_attrs)
     return cls.enabled
Exemplo n.º 2
0
    def replace_op(self, graph: Graph, node: Node):
        """Replace ``node`` with sqrt(ReduceSum(x ** 2, axis=-1)).

        Builds Pow(x, 2) -> ReduceSum(axis const -1) -> Pow(., 0.5) on the
        node's input 0. The new node names are derived from that input's
        name.

        :param graph: graph in which the replacement nodes are created.
        :param node: node being replaced; only its input 0 is used.
        :return: list containing the id of the final Pow node.
        """
        input_node = node.in_node(0)
        # soft_get falls back to the node id rather than requiring a 'name'
        # attribute on the input node.
        input_name = input_node.soft_get('name', input_node.id)

        pow_2 = Const(graph, {'value': np.float32(2.0)}).create_node()
        reduce_axis = Const(graph, {'value': np.int32(-1)}).create_node()
        pow_0_5 = Const(graph, {'value': np.float32(0.5)}).create_node()

        sq = Pow(graph, dict(name=input_name + '/sq',
                             power=2.0)).create_node([input_node, pow_2])
        # Named sum_node to avoid shadowing the builtin ``sum``.
        sum_node = ReduceSum(graph, dict(name=sq.name + '/sum')).create_node(
            [sq, reduce_axis])
        sqrt = Pow(graph, dict(name=sum_node.name + '/sqrt',
                               power=0.5)).create_node([sum_node, pow_0_5])
        return [sqrt.id]
Exemplo n.º 3
0
    def replace_op(self, graph: Graph, node: Node):
        """Decompose ``node`` into Pow(ReduceSum(x * x), 0.5).

        The original node is renamed with a ``/TBR`` suffix. The final Pow
        node is renamed to the original node's name. ReduceSum copies
        ``keep_dims`` (default 0) and ``axis`` (default None) from ``node``.

        :param graph: graph in which the replacement nodes are created.
        :param node: node being replaced; the connection on its input 0 is
            moved onto the new Mul node.
        :return: list with the id of the final Pow node.
        """
        original_name = node.soft_get('name', node.id)
        rename_node(node, original_name + '/TBR')

        square = Mul(graph, {}).create_node()
        reduce_attrs = {
            'keep_dims': node.soft_get('keep_dims', 0),
            'axis': node.soft_get('axis', None),
        }
        summation = ReduceSum(graph, reduce_attrs).create_node()
        root = create_op_with_const_inputs(graph, Pow, {1: float_array(0.5)})
        rename_node(root, original_name)

        # x * x: move the source of input 0 onto Mul port 0, then also feed
        # that same source into Mul port 1.
        input_connection = node.in_port(0).get_connection()
        input_connection.set_destination(square.in_port(0))
        square.in_port(0).get_connection().add_destination(square.in_port(1))

        # Chain Mul -> ReduceSum -> Pow(., 0.5).
        square.out_port(0).connect(summation.in_port(0))
        summation.out_port(0).connect(root.in_port(0))

        return [root.id]
Exemplo n.º 4
0
 def extract(cls, node: Node):
     """Store ``keep_dims`` on ``node`` from its protobuf attributes.

     :param node: node whose ``pb.attr["keep_dims"].b`` value is copied
         into the ``keep_dims`` attribute via ``ReduceSum.update_node_stat``.
     :return: ``cls.enabled``.
     """
     keep_dims_flag = node.pb.attr["keep_dims"].b
     ReduceSum.update_node_stat(node, dict(keep_dims=keep_dims_flag))
     return cls.enabled