def extract(cls, node: Node):
    """Copy the reduction attributes of ``node`` onto its ReduceSum state.

    Reads the ``axes`` (``ints``) and ``keepdims`` (``i``) attributes through
    ``onnx_attr`` and stores them on the node under the keys ``axis`` and
    ``keep_dims`` via ``ReduceSum.update_node_stat``.

    :param node: node whose attributes are read and updated.
    :return: the value of ``cls.enabled``.
    """
    # A missing 'axes' attribute is kept as None; otherwise the value is
    # converted with int64_array.
    reduction_axes = onnx_attr(
        node,
        'axes',
        'ints',
        default=None,
        dst_type=lambda values: int64_array(values),
    )
    # 'keepdims' falls back to True when the attribute is absent.
    keep_reduced_dims = onnx_attr(node, 'keepdims', 'i', default=True)

    reduce_attrs = {
        'axis': reduction_axes,
        'keep_dims': keep_reduced_dims,
    }
    ReduceSum.update_node_stat(node, reduce_attrs)
    return cls.enabled
def replace_op(self, graph: Graph, node: Node):
    """Replace ``node`` with a Pow -> ReduceSum -> Pow chain on its first input.

    The chain squares the input (exponent 2.0), sums it with the reduction
    axis constant set to -1, and raises the sum to the power 0.5.
    NOTE(review): this presumably implements an L2 norm over the last
    axis -- confirm against the enclosing replacer class.

    :param graph: graph in which the new nodes are created.
    :param node: node being replaced; only its first input node is used.
    :return: single-element list with the id of the final Pow node.
    """
    # Constant operands are created first, in the same order as before, so
    # that the sequence of node creation in the graph is unchanged.
    exponent_two = Const(graph, {'value': np.float32(2.0)}).create_node()
    last_axis = Const(graph, {'value': np.int32(-1)}).create_node()
    exponent_half = Const(graph, {'value': np.float32(0.5)}).create_node()

    data_node = node.in_node(0)

    square_attrs = {'name': data_node.name + '/sq', 'power': 2.0}
    squared = Pow(graph, square_attrs).create_node([data_node, exponent_two])

    sum_attrs = {'name': squared.name + '/sum'}
    summed = ReduceSum(graph, sum_attrs).create_node([squared, last_axis])

    root_attrs = {'name': summed.name + '/sqrt', 'power': 0.5}
    root = Pow(graph, root_attrs).create_node([summed, exponent_half])

    return [root.id]
def replace_op(self, graph: Graph, node: Node):
    """Decompose ``node`` into Mul -> ReduceSum -> Pow(0.5).

    The original node is renamed with a ``/TBR`` suffix and the final Pow
    node takes over the original name. The data feeding input port 0 of
    ``node`` is routed to both inputs of the Mul node (so the Mul multiplies
    the input by itself), whose result goes through ReduceSum and then Pow.

    :param graph: graph in which the new nodes are created.
    :param node: node being replaced; its ``keep_dims`` and ``axis``
        attributes (defaults 0 and None) are forwarded to ReduceSum.
    :return: single-element list with the id of the final Pow node.
    """
    original_name = node.soft_get('name', node.id)
    # Free the original name so the last node of the chain can take it.
    rename_node(node, original_name + '/TBR')

    square = Mul(graph, {}).create_node()

    reduce_attrs = {
        'keep_dims': node.soft_get('keep_dims', 0),
        'axis': node.soft_get('axis', None),
    }
    reduce_sum = ReduceSum(graph, reduce_attrs).create_node()

    root = create_op_with_const_inputs(graph, Pow, {1: float_array(0.5)})
    rename_node(root, original_name)

    # Re-route the producer of the original input to the first Mul input ...
    incoming = node.in_port(0).get_connection()
    incoming.set_destination(square.in_port(0))
    # ... and feed the very same source into the second Mul input as well.
    square_input = square.in_port(0).get_connection()
    square_input.add_destination(square.in_port(1))

    square.out_port(0).connect(reduce_sum.in_port(0))
    reduce_sum.out_port(0).connect(root.in_port(0))

    return [root.id]
def extract(cls, node: Node):
    """Store the ``keep_dims`` flag of ``node`` on its ReduceSum state.

    The flag is read from the boolean field ``b`` of the ``keep_dims`` entry
    in ``node.pb.attr`` and written through ``ReduceSum.update_node_stat``.

    :param node: node whose protobuf attribute is read and whose state is
        updated.
    :return: the value of ``cls.enabled``.
    """
    keep_dims_flag = node.pb.attr["keep_dims"].b
    reduce_attrs = {'keep_dims': keep_dims_flag}
    ReduceSum.update_node_stat(node, reduce_attrs)
    return cls.enabled