Example #1
    def version_9(cls, ctx, node, **kwargs):
        node.type = "ConstantOfShape"
        # In TensorFlow both shape and value are passed as tensors.
        # In ONNX the value is an attribute, so we need to fetch the value as a
        # constant, which sooner or later will be a problem for tensorflow-onnx.
        # ConstantOfShape in onnxruntime only supports int64 shapes, so insert a Cast op.
        input_dtype_is_int64 = utils.map_onnx_to_numpy_type(
            ctx.get_dtype(node.input[0])) == np.int64
        if not input_dtype_is_int64:
            ctx.insert_new_node_on_input(node, "Cast", node.input[0],
                                         to=onnx_pb.TensorProto.INT64)
        dtype = ctx.get_dtype(node.output[0])
        value = np.array([node.inputs[1].get_tensor_value()]).astype(
            utils.map_onnx_to_numpy_type(dtype))
        value_proto = numpy_helper.from_array(value)
        node.set_attr("value", value_proto)
        ctx.remove_input(node, node.input[1], 1)
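
For intuition, ONNX's ConstantOfShape fills a tensor of the given shape with a single
value, which is why the handler must bake the TensorFlow value tensor into an
attribute. A minimal NumPy sketch of the equivalent computation (the helper name is
illustrative, not tf2onnx API)::

    import numpy as np

    def constant_of_shape(shape, value=0.0):
        # ONNX ConstantOfShape: every element of the output equals `value`.
        return np.full(shape, value)

    print(constant_of_shape(np.array([2, 3], dtype=np.int64), 7.0))
    # [[7. 7. 7.]
    #  [7. 7. 7.]]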
Example #2
    def any_version(cls, opset, ctx, node, **kwargs):
        """
        Computes the modules of a complex.
        If the matrix dtype is not complex64 or complex128,
        it assumes the first dimension means real part (0)
        and imaginary part (1, :, :...).
        """
        supported_dtypes = [
            onnx_pb.TensorProto.FLOAT,
            onnx_pb.TensorProto.FLOAT16,
            onnx_pb.TensorProto.DOUBLE,
            onnx_pb.TensorProto.COMPLEX64,
            onnx_pb.TensorProto.COMPLEX128,
        ]
        onnx_dtype = ctx.get_dtype(node.input[0])
        utils.make_sure(onnx_dtype in supported_dtypes, "Unsupported input type.")
        shape = ctx.get_shape(node.input[0])
        np_dtype = utils.map_onnx_to_numpy_type(onnx_dtype)
        utils.make_sure(shape[0] == 2, "ComplexAbs expected the first dimension to be 2 but shape is %r", shape)

        ind0 = ctx.make_const(name=utils.make_name('cst0'), np_val=np.array([0], dtype=np.int64))
        ind1 = ctx.make_const(name=utils.make_name('cst1'), np_val=np.array([1], dtype=np.int64))
        p2 = ctx.make_const(name=utils.make_name('p2'), np_val=np.array([2], dtype=np_dtype))

        real_part = ctx.make_node(
            'Gather', inputs=[node.input[0], ind0.name], attr=dict(axis=0),
            name=utils.make_name('Real_' + node.name))
        imag_part = ctx.make_node(
            'Gather', inputs=[node.input[0], ind1.name], attr=dict(axis=0),
            name=utils.make_name('Imag_' + node.name))

        real_part2 = ctx.make_node(
            'Pow', inputs=[real_part.output[0], p2.name],
            name=utils.make_name(real_part.name + 'p2p'))

        imag_part2 = ctx.make_node(
            'Pow', inputs=[imag_part.output[0], p2.name],
            name=utils.make_name(imag_part.name + 'p2p'))

        ctx.remove_node(node.name)
        add = ctx.make_node(
            "Add", inputs=[real_part2.output[0], imag_part2.output[0]],
            name=utils.make_name('ComplexAbs_' + node.name))

        squeezed = GraphBuilder(ctx).make_squeeze(
            {'data': add.output[0], 'axes': [0]}, name=utils.make_name('ComplexAbs' + node.name), return_node=True)

        last_node = ctx.make_node(
            "Sqrt", inputs=squeezed.output[:1],
            name=utils.make_name('ComplexAbs' + node.name),
            shapes=[shape[1:]], dtypes=[onnx_dtype])

        ctx.replace_all_inputs(node.output[0], last_node.output[0])  # ops=ctx.get_nodes()
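
The graph above is just the elementwise identity |z| = sqrt(re(z)^2 + im(z)^2) applied
to a tensor whose leading dimension stacks real and imaginary parts. A quick NumPy
check of that decomposition (standalone, not tf2onnx code)::

    import numpy as np

    z = np.array([3.0 + 4.0j, 1.0 - 1.0j])
    stacked = np.stack([z.real, z.imag])      # shape (2, n), as the handler expects
    modulus = np.sqrt(stacked[0] ** 2 + stacked[1] ** 2)
    assert np.allclose(modulus, np.abs(z))    # [5.0, 1.4142...]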
Example #3
    def version_10(cls, ctx, node, **kwargs):
        scale = node.get_attr_value('scale')
        zero_point = node.get_attr_value('zero_point')
        axis = node.get_attr_value('quantized_dimension')
        np_q_type = utils.map_onnx_to_numpy_type(ctx.get_dtype(node.output[0]))
        if len(scale) > 1 or len(zero_point) > 1:
            utils.make_sure(ctx.opset >= 13, "Opset 13 is required for per-axis quantization for node %s", node.name)
            node.set_attr("axis", axis)
            # Per-axis quantization: keep the full scale/zero_point vectors.
            scale_node = ctx.make_const(utils.make_name("scale"), np.array(scale, dtype=np.float32))
            zero_point_node = ctx.make_const(utils.make_name("zero_point"), np.array(zero_point, dtype=np_q_type))
        else:
            scale_node = ctx.make_const(utils.make_name("scale"), np.array(scale[0], dtype=np.float32))
            zero_point_node = ctx.make_const(utils.make_name("zero_point"), np.array(zero_point[0], dtype=np_q_type))
        ctx.replace_inputs(node, [node.input[0], scale_node.output[0], zero_point_node.output[0]])
        del node.attr["scale"]
        del node.attr["zero_point"]
        del node.attr["quantized_dimension"]
        if "min" in node.attr:
            del node.attr["min"]
        if "max" in node.attr:
            del node.attr["max"]
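
For reference, ONNX QuantizeLinear computes y = saturate(round(x / scale) + zero_point)
and DequantizeLinear inverts it as (y - zero_point) * scale. A hedged NumPy sketch of
the per-tensor uint8 case (illustrative only, not the converter's code path)::

    import numpy as np

    def quantize_linear(x, scale, zero_point, qmin=0, qmax=255):
        # Round half to even, as the ONNX spec requires, then saturate.
        q = np.rint(x / scale) + zero_point
        return np.clip(q, qmin, qmax).astype(np.uint8)

    def dequantize_linear(q, scale, zero_point):
        return (q.astype(np.float32) - zero_point) * scale

    x = np.array([-1.0, 0.0, 0.5, 2.0], dtype=np.float32)
    q = quantize_linear(x, scale=0.02, zero_point=128)
    print(q, dequantize_linear(q, 0.02, 128))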
Example #4
def get_weights_from_const_node(g, node):
    temp = node
    val = None
    # Skip over Identity nodes so the Const can still be found
    # in a graph that has not been const-folded.
    while temp.type == 'Identity':
        temp = temp.inputs[0]

    if temp and temp.type == 'Const':
        val = temp.get_tensor_value(as_list=False)
        dtype = utils.map_onnx_to_numpy_type(g.get_dtype(temp.output[0]))
        val = val.astype(dtype)
        logger.debug("found weights %s", temp.name)
    else:
        logger.debug(
            "weight node is not a Const, skipping; node name is %s",
            temp.name)
        return None

    return val
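
The loop walks backwards through chains like Const -> Identity -> Identity that
TensorFlow emits for variable reads. A self-contained toy illustration of that
traversal (FakeNode is a hypothetical stand-in, not a tf2onnx class)::

    class FakeNode:
        def __init__(self, type_, inputs=(), name=""):
            self.type, self.inputs, self.name = type_, list(inputs), name

    const = FakeNode('Const', name='w')
    node = FakeNode('Identity', [FakeNode('Identity', [const], name='w/read')])

    temp = node
    while temp.type == 'Identity':
        temp = temp.inputs[0]
    assert temp is const  # Identity -> Identity resolves to the underlying Const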
Example #5
    def version_9(cls, ctx, node, **kwargs):
        """
        Rewrites atan2(y, x) as an Atan call plus sign-based quadrant corrections,
        following this NumPy reference.

        ::

            def atan2(y, x):
                sx = numpy.sign(x)
                sy = numpy.sign(y)
                pi_part = (sy + sx * (sy ** 2 - 1)) * (sx - 1) * (-numpy.pi/2)
                atan_part = numpy.arctan(y / (x + (1 - sx ** 2))) * sx ** 2
                return atan_part + pi_part
        """
        supported_dtypes = [
            onnx_pb.TensorProto.FLOAT, onnx_pb.TensorProto.FLOAT16,
            onnx_pb.TensorProto.DOUBLE
        ]

        onnx_dtype = ctx.get_dtype(node.input[0])
        utils.make_sure(onnx_dtype in supported_dtypes,
                        "Unsupported input type.")
        shape = ctx.get_shape(node.input[0])
        np_dtype = utils.map_onnx_to_numpy_type(onnx_dtype)

        # sign part

        sign_x_node = ctx.make_node("Sign",
                                    inputs=node.input[1:],
                                    name=utils.make_name(node.name + 'signx'))
        sign_y_node = ctx.make_node("Sign",
                                    inputs=node.input[:1],
                                    name=utils.make_name(node.name + 'signy'))

        sx_node = ctx.make_node("Cast",
                                sign_x_node.output[:1],
                                attr={"to": onnx_dtype},
                                name=utils.make_name(node.name + 'csignx'))
        sy_node = ctx.make_node("Cast",
                                sign_y_node.output[:1],
                                attr={"to": onnx_dtype},
                                name=utils.make_name(node.name + 'csigny'))

        # cst

        one_node = ctx.make_const(utils.make_name("{}_one".format(node.name)),
                                  np.array([1], dtype=np_dtype))

        pib2_node = ctx.make_const(utils.make_name("{}_pi".format(node.name)),
                                   np.array(-np.pi / 2, dtype=np_dtype))

        # pi_part = (sy + sx * (sy ** 2 - 1)) * (sx - 1) * (-numpy.pi/2)

        sxm1_node = ctx.make_node("Sub",
                                  [sx_node.output[0], one_node.output[0]],
                                  name=utils.make_name(node.name + 'sxm1'))
        sy2_node = ctx.make_node("Mul", [sy_node.output[0], sy_node.output[0]],
                                 name=utils.make_name(node.name + 'sy2'))
        sy2m1_node = ctx.make_node("Sub",
                                   [sy2_node.output[0], one_node.output[0]],
                                   name=utils.make_name(node.name + 'sy2m1'))
        sxsy2m1_node = ctx.make_node(
            "Mul", [sx_node.output[0], sy2m1_node.output[0]],
            name=utils.make_name(node.name + 'sxsy2m1'))
        sysxsy2m1_node = ctx.make_node(
            "Add", [sy_node.output[0], sxsy2m1_node.output[0]],
            name=utils.make_name(node.name + 'sysxsy2m1'))
        m1_node = ctx.make_node(
            "Mul", [sysxsy2m1_node.output[0], sxm1_node.output[0]],
            name=utils.make_name(node.name + 'm1'))
        pi_part = ctx.make_node("Mul",
                                [m1_node.output[0], pib2_node.output[0]],
                                name=utils.make_name(node.name + 'pip'))

        # atan

        sx2_node = ctx.make_node("Mul", [sx_node.output[0], sx_node.output[0]],
                                 name=utils.make_name(node.name + 'sx2'))
        sx2m1_node = ctx.make_node("Sub",
                                   [sx2_node.output[0], one_node.output[0]],
                                   name=utils.make_name(node.name + 'sx2m1'))
        xsx2m1_node = ctx.make_node("Add",
                                    [node.input[1], sx2m1_node.output[0]],
                                    name=utils.make_name(node.name + 'xsx2m1'))
        div_node = ctx.make_node("Div",
                                 inputs=[node.input[0], xsx2m1_node.output[0]],
                                 name=utils.make_name(node.name + 'div'))
        atan0_node = ctx.make_node("Atan",
                                   inputs=[div_node.output[0]],
                                   name=utils.make_name(node.name + 'atan0'))
        atan_node = ctx.make_node(
            "Mul",
            inputs=[sx2_node.output[0], atan0_node.output[0]],
            name=utils.make_name(node.name + 'atan'))

        # final

        ctx.remove_node(node.name)

        last_node = ctx.make_node(
            "Add",
            inputs=[atan_node.output[0], pi_part.output[0]],
            op_name_scope=node.name + 'all',
            shapes=[shape],
            dtypes=[onnx_dtype])
        ctx.replace_all_inputs(node.output[0],
                               last_node.output[0])  # ops=ctx.get_nodes()
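
The docstring's NumPy reference can be checked directly against np.arctan2; it is
exactly the decomposition the graph above builds node by node (a standalone sanity
check, not tf2onnx code)::

    import numpy as np

    def atan2(y, x):
        sx = np.sign(x)
        sy = np.sign(y)
        pi_part = (sy + sx * (sy ** 2 - 1)) * (sx - 1) * (-np.pi / 2)
        atan_part = np.arctan(y / (x + (1 - sx ** 2))) * sx ** 2
        return atan_part + pi_part

    y = np.array([1.0, -1.0, 1.0, -1.0, 0.5])
    x = np.array([1.0, 1.0, -1.0, -1.0, 0.0])
    assert np.allclose(atan2(y, x), np.arctan2(y, x))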
Example #6
def compute_const_folding_using_tf(g, const_node_values, graph_outputs):
    """Find nodes with constant inputs and compute their values using TF"""
    if const_node_values is None:
        const_node_values = {}
    graph_outputs = set(graph_outputs)
    from tf2onnxnightly.tf_loader import tf_session, tf_placeholder  # pylint: disable=import-outside-toplevel

    ops = g.get_operations()
    outputs_to_values = {}
    outputs_to_dtypes = {}
    outputs_to_shapes = {}
    shape_node_outputs = {}

    def is_small_shape(x):
        return np.prod(x) <= 1000

    def is_huge_shape(x):
        return np.prod(x) >= 1000000

    for node in ops:
        # Load values of constants. Use const_node_values if possible
        if node.type in ["Const", "ConstV2"]:
            tensor = node.node_def.attr["value"].tensor
            if node.name in const_node_values:
                tensor.tensor_content = const_node_values[node.name]
            outputs_to_values[node.outputs[0].name] = get_tf_tensor_data(
                tensor)
            outputs_to_dtypes[node.outputs[0].name] = node.outputs[0].dtype
        for out in node.outputs:
            outputs_to_shapes[out.name] = get_tf_tensor_shape(out)

    for node in ops:
        if node.type == "Shape":
            shape = outputs_to_shapes.get(node.inputs[0].name)
            if shape is not None:
                shape_node_outputs[node.outputs[0].name] = shape

    unneeded_outputs = set()
    progress = True
    while progress:
        progress = False
        for node in ops:
            # Find ops with constant inputs and compute their values
            input_names = [i.name for i in node.inputs]
            output_names = [i.name for i in node.outputs]
            if node.type == 'StridedSlice' and input_names[0] in shape_node_outputs \
                                           and output_names[0] not in outputs_to_values:
                shape = shape_node_outputs[input_names[0]]
                i = get_index_from_strided_slice_of_shape(node, outputs_to_values)
                if i is not None and 0 <= i < len(shape) and shape[i] is not None:
                    np_dtype = map_onnx_to_numpy_type(map_tf_dtype(node.outputs[0].dtype))
                    outputs_to_values[output_names[0]] = np.array(shape[i], dtype=np_dtype)
                    outputs_to_dtypes[node.outputs[0].name] = node.outputs[0].dtype
                    progress = True
            can_fold = node.type not in ['Enter', 'Placeholder', 'PlaceholderWithDefault']
            can_fold = can_fold and not node.type.startswith('Random')
            can_fold = can_fold and len(input_names) > 0 and all(
                inp in outputs_to_values for inp in input_names)
            # We can only fold nodes with a single output
            can_fold = can_fold and len(output_names) == 1 and (
                output_names[0] not in outputs_to_values)
            # Skip if value already computed, used, and discarded
            can_fold = can_fold and (output_names[0] not in unneeded_outputs
                                     and output_names[0] not in graph_outputs)
            if can_fold:
                # Make a mini graph containing just the node to fold
                g2 = tf.Graph()
                with g2.as_default():
                    for inp in input_names:
                        tf_placeholder(outputs_to_dtypes[inp],
                                       name=inp.split(':')[0])
                    mini_graph_def = g2.as_graph_def()
                    mini_graph_def.node.append(node.node_def)
                g3 = tf.Graph()
                with g3.as_default():
                    feed_dict = {}
                    inp_shapes = []
                    for inp in input_names:
                        inp_np = outputs_to_values[inp]
                        feed_dict[inp] = inp_np
                        inp_shapes.append(inp_np.shape)
                    try:
                        with tf_session() as sess:
                            tf.import_graph_def(mini_graph_def, name='')
                            results = sess.run(output_names,
                                               feed_dict=feed_dict)
                        if is_huge_shape(results[0].shape) and all(
                                is_small_shape(inp) for inp in inp_shapes):
                            logger.debug(
                                "Skipping folding of node %s since result shape %s is much larger "
                                "than input shapes %s", node.name,
                                results[0].shape, inp_shapes)
                        else:
                            outputs_to_values[output_names[0]] = results[0]
                            outputs_to_dtypes[output_names[0]] = node.outputs[0].dtype
                            progress = True
                    except Exception:  # pylint: disable=broad-except
                        logger.debug("Could not fold node %s", node.name)
        unneeded_outputs.update(outputs_to_values.keys())
        for node in ops:
            # Mark values we need to keep
            input_names = [i.name for i in node.inputs]
            output_names = [i.name for i in node.outputs]
            if len(output_names) == 1 and output_names[0] in outputs_to_values:
                continue
            for i in input_names:
                if i in unneeded_outputs:
                    unneeded_outputs.remove(i)
        for out in unneeded_outputs:
            # Remove unneeded values to prevent memory usage explosion
            if out in outputs_to_values:
                del outputs_to_values[out]
                del outputs_to_dtypes[out]

    for node in ops:
        # We don't need the constants any more
        if (node.type in ["Const", "ConstV2"]
                and node.outputs[0].name in outputs_to_values):
            del outputs_to_values[node.outputs[0].name]
            del outputs_to_dtypes[node.outputs[0].name]

    logger.info("Computed %d values for constant folding",
                len(outputs_to_values))
    return outputs_to_values, outputs_to_dtypes
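
The folding trick is worth isolating: each candidate node is copied into a one-node
GraphDef whose inputs become placeholders, and TensorFlow itself evaluates it. A
minimal standalone sketch of the same pattern, assuming the tf.compat.v1 session and
placeholder APIs stand in for the tf_session/tf_placeholder wrappers::

    import numpy as np
    import tensorflow as tf

    # A trivial source graph with one op whose inputs are constants.
    g = tf.Graph()
    with g.as_default():
        a = tf.constant([1.0, 2.0], name="a")
        b = tf.constant([3.0, 4.0], name="b")
        c = tf.add(a, b, name="c")
    add_def = g.as_graph_def().node[2]  # the Add node

    # Mini graph containing just the node to fold, inputs as placeholders.
    g2 = tf.Graph()
    with g2.as_default():
        tf.compat.v1.placeholder(tf.float32, name="a")
        tf.compat.v1.placeholder(tf.float32, name="b")
        mini_graph_def = g2.as_graph_def()
        mini_graph_def.node.append(add_def)

    g3 = tf.Graph()
    with g3.as_default():
        with tf.compat.v1.Session() as sess:
            tf.import_graph_def(mini_graph_def, name='')
            print(sess.run("c:0", feed_dict={"a:0": np.array([1.0, 2.0]),
                                             "b:0": np.array([3.0, 4.0])}))
    # [4. 6.]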
Example #7
    def any_version(cls, opset, ctx, node, **kwargs):
        node_inputs = node.input
        num_segments_specified = False
        if node.type.endswith("WithNumSegments") or node.type.startswith("Unsorted"):
            num_segments_specified = True
            num_segments = node_inputs.pop()
            node.type = node.type.replace("WithNumSegments", "")
            node.type = node.type.replace("Unsorted", "")
        if node.type.startswith("Sparse"):
            data_inp, indices_inp, segment_inp = node_inputs
            gather_node = ctx.make_node("Gather", [data_inp, indices_inp], attr={'axis': 0})
            data_inp = gather_node.output[0]
            node.type = node.type.replace("Sparse", "")
        else:
            data_inp, segment_inp = node_inputs

        # Data has shape [n, a, b, ..., c]
        data_shape = ctx.get_shape(data_inp)
        data_rank = len(data_shape) if data_shape is not None else None
        data_dtype = ctx.get_dtype(data_inp)
        data_np_dtype = utils.map_onnx_to_numpy_type(data_dtype)
        seg_np_dtype = utils.map_onnx_to_numpy_type(ctx.get_dtype(segment_inp))

        if num_segments_specified and ctx.get_dtype(segment_inp) != ctx.get_dtype(num_segments):
            num_segments = ctx.make_node("Cast", [num_segments], attr={"to": ctx.get_dtype(segment_inp)}).output[0]

        data_is_float = np.dtype(data_np_dtype).kind == 'f'
        data_is_int = np.dtype(data_np_dtype).kind == 'i'
        utils.make_sure(data_is_float or data_is_int, "dtype for Segment ops must be float or int")

        if node.type in ["SegmentSum", "SegmentMean", "SegmentSqrtN"]:
            onnx_op = "ReduceSum"
            identity_value = np.array(0, dtype=data_np_dtype)
        elif node.type == "SegmentProd":
            onnx_op = "ReduceProd"
            identity_value = np.array(1, dtype=data_np_dtype)
        elif node.type == "SegmentMax":
            onnx_op = "ReduceMax"
            if data_is_float:
                identity_value = np.array('-inf', dtype=data_np_dtype)
            else:
                identity_value = np.iinfo(data_np_dtype).min
        elif node.type == "SegmentMin":
            onnx_op = "ReduceMin"
            if data_is_float:
                identity_value = np.array('inf', dtype=data_np_dtype)
            else:
                identity_value = np.iinfo(data_np_dtype).max

        if not num_segments_specified:
            max_segment = ctx.make_node("ReduceMax", [segment_inp], attr={'axes': [0], 'keepdims': 0})
            one_const = ctx.make_const(utils.make_name("const_one"), np.array(1, dtype=seg_np_dtype))
            num_segments = ctx.make_node("Add", [max_segment.output[0], one_const.output[0]]).output[0]
        # ORT doesn't support bool for OneHot so we use float32 and cast to bool
        onehot_values = ctx.make_const(utils.make_name("onehot_values"), np.array([0, 1], dtype=np.float32))
        # one_hot_node has shape [s, n] (s is # segments)
        one_hot_node = ctx.make_node("OneHot", [segment_inp, num_segments, onehot_values.output[0]],
                                     attr={'axis': 0})
        if node.type == "SegmentMean":
            scaling_node_output = GraphBuilder(ctx).make_reduce_sum(
                {"data": one_hot_node.output[0], "axes": [1], "keepdims": 0, "noop_with_empty_axes": 1})
        elif node.type == "SegmentSqrtN":
            seg_cnts_node_output = GraphBuilder(ctx).make_reduce_sum(
                {"data": one_hot_node.output[0], "axes": [1], "keepdims": 0, "noop_with_empty_axes": 1})
            scaling_node_output = ctx.make_node("Sqrt", [seg_cnts_node_output]).output[0]
        else:
            scaling_node_output = None

        if scaling_node_output is not None and num_segments_specified:
            # If empty segments are possible, we must avoid division by zero
            const_one_float = ctx.make_const(utils.make_name("const_one_float"), np.array(1, dtype=np.float32))
            scaling_node_output = ctx.make_node("Max", [scaling_node_output, const_one_float.output[0]]).output[0]


        if onnx_op == "ReduceSum":
            # If the op is a summation, we can use MatMul instead of Where, which is faster

            # Data shape is [n, a, b, ..., c]
            data_shape_node = ctx.make_node("Shape", [data_inp])
            new_shape = ctx.make_const(utils.make_name("reshape_const"), np.array([0, -1], dtype=np.int64))
            # Reshape the data from [n, a, b, ..., c] to [n, P]
            data_reshape = ctx.make_node("Reshape", [data_inp, new_shape.output[0]])

            one_hot_cast = one_hot_node
            if data_dtype != onnx_pb.TensorProto.FLOAT:
                one_hot_cast = ctx.make_node("Cast", [one_hot_node.output[0]], attr={'to': data_dtype})

            # Shapes [s, n] * [n, P] => [s, P]
            product = ctx.make_node("MatMul", [one_hot_cast.output[0], data_reshape.output[0]], op_name_scope=node.name)
            if scaling_node_output is not None:
                scaling_node_unsqueeze = GraphBuilder(ctx).make_unsqueeze(
                    {'data': scaling_node_output, 'axes': [1]}, return_node=True)
                product = ctx.make_node("Div", [product.output[0], scaling_node_unsqueeze.output[0]])

            # Create new shape [0, a, b, ..., c]
            max_int64 = int(utils.get_max_value(np.int64))
            new_shape_slice = GraphBuilder(ctx).make_slice(
                {"data": data_shape_node.output[0], "ends": [max_int64], "starts": [1], "axes": [0]})
            zero_const = ctx.make_const(utils.make_name("zero_const"), np.array([0], dtype=np.int64))
            new_shape = ctx.make_node("Concat", [zero_const.output[0], new_shape_slice], attr={'axis': 0})

            shapes = node.output_shapes
            dtypes = node.output_dtypes
            ctx.remove_node(node.name)
            # Reshape result from [s, P] to [s, a, b, ..., c]
            ctx.make_node("Reshape", [product.output[0], new_shape.output[0]],
                          name=node.name, outputs=node.output, shapes=shapes, dtypes=dtypes)
            return

        identity_const = ctx.make_const(utils.make_name("const_identity"), identity_value)
        one_hot_bool = ctx.make_node("Cast", [one_hot_node.output[0]], attr={"to": onnx_pb.TensorProto.BOOL})
        one_hot_unsqueeze = one_hot_bool

        # Make one_hot_unsqueeze have shape [s, n, 1, 1, ..., 1]
        if data_rank is None:
            # Unsqueeze requires known rank, but we can use Reshape if rank is unknown
            shape_node = ctx.make_node("Shape", [data_inp])
            rank_node = ctx.make_node("Shape", [shape_node.output[0]])
            one_const_int64 = ctx.make_const(utils.make_name("const_one"), np.array([1], dtype=np.int64))
            num_unsqueeze_dims = ctx.make_node("Sub", [rank_node.output[0], one_const_int64.output[0]])

            one_tensor = helper.make_tensor("value", onnx_pb.TensorProto.INT64, dims=[1], vals=[1])
            unsqueeze_dims = ctx.make_node("ConstantOfShape", inputs=[num_unsqueeze_dims.output[0]],
                                           attr={"value": one_tensor})
            # Zero indicates a dimension should be unchanged
            double_zero_const = ctx.make_const(utils.make_name("double_zero"), np.array([0, 0], dtype=np.int64))
            expanded_shape = ctx.make_node("Concat", [double_zero_const.output[0], unsqueeze_dims.output[0]],
                                           attr={'axis': 0})
            one_hot_unsqueeze = ctx.make_node("Reshape", [one_hot_bool.output[0], expanded_shape.output[0]])
        elif data_rank > 1:
            new_dims = list(range(2, 2 + data_rank - 1))
            one_hot_unsqueeze = GraphBuilder(ctx).make_unsqueeze(
                {'data': one_hot_bool.output[0], 'axes': new_dims}, return_node=True)

        # Shape of data:       [n, a, b, ..., c]
        # Shape of one_hot: [s, n, 1, 1, ..., 1]
        # Broadcast left-pads shape with 1s, so result is shape: [s, n, a, b, ..., c]
        where_node = ctx.make_node("Where", [one_hot_unsqueeze.output[0], data_inp, identity_const.output[0]])

        shapes = node.output_shapes
        dtypes = node.output_dtypes
        ctx.remove_node(node.name)
        # After reduction over axis 1, shape is: [s, a, b, ..., c]
        ctx.make_node(onnx_op, [where_node.output[0]], attr={'axes': [1], 'keepdims': 0},
                      name=node.name, outputs=node.output, shapes=shapes, dtypes=dtypes)
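
The ReduceSum fast path relies on the identity that a segment sum is a matrix product
with a one-hot segment matrix, shapes [s, n] x [n, P] -> [s, P]. A small NumPy
illustration of that trick (standalone, not converter code)::

    import numpy as np

    data = np.arange(12.0).reshape(4, 3)     # [n=4, a=3]
    segments = np.array([0, 0, 1, 2])        # maps each row of data to a segment
    s = segments.max() + 1

    one_hot = np.eye(s)[segments].T          # [s, n]; row i selects segment i's rows
    seg_sum = one_hot @ data.reshape(4, -1)  # [s, n] @ [n, P] -> [s, P]
    print(seg_sum)
    # row 0 = data[0] + data[1], row 1 = data[2], row 2 = data[3]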
Example #8
    def version_1(cls, ctx, node, **kwargs):
        """
        Args:
          x: A `Tensor`. Must be one of the following types: `float32`.
            The input to the LSTM cell, shape (batch_size, num_inputs).
          cs_prev: A `Tensor`. Must have the same type as `x`.
            Value of the cell state at previous time step.
          h_prev: A `Tensor`. Must have the same type as `x`.
            Output of the previous cell at previous time step.
          w: A `Tensor`. Must have the same type as `x`. The weight matrix.
          wci: A `Tensor`. Must have the same type as `x`.
            The weight matrix for input gate peephole connection.
          wcf: A `Tensor`. Must have the same type as `x`.
            The weight matrix for forget gate peephole connection.
          wco: A `Tensor`. Must have the same type as `x`.
            The weight matrix for output gate peephole connection.
          b: A `Tensor`. Must have the same type as `x`. The bias vector.
          forget_bias: An optional `float`. Defaults to `1`. The forget gate bias.
          cell_clip: An optional `float`. Defaults to `-1` (no clipping).
            Value to clip the 'cs' value to. Disable by setting to negative value.
          use_peephole: An optional `bool`. Defaults to `False`.
            Whether to use peephole weights.
          name: A name for the operation (optional).
        Returns:
          A tuple of `Tensor` objects (i, cs, f, o, ci, co, h).
          i: A `Tensor`. Has the same type as `x`. The input gate.
          cs: A `Tensor`. Has the same type as `x`. The cell state before the tanh.
          f: A `Tensor`. Has the same type as `x`. The forget gate.
          o: A `Tensor`. Has the same type as `x`. The output gate.
          ci: A `Tensor`. Has the same type as `x`. The cell input.
          co: A `Tensor`. Has the same type as `x`. The cell after the tanh.
          h: A `Tensor`. Has the same type as `x`. The output h vector.
        ```python
        xh = [x, h_prev]
        [i, ci, f, o] = xh * w + b
        f = f + forget_bias
        if not use_peephole:
          wci = wcf = wco = 0
        i = sigmoid(cs_prev .* wci + i)
        f = sigmoid(cs_prev .* wcf + f)
        ci = tanh(ci)
        cs = ci .* i + cs_prev .* f
        cs = clip(cs, cell_clip)
        o = sigmoid(cs .* wco + o)
        co = tanh(cs)
        h = co .* o
        ```
        """
        nodes = []
        x, cs_prev, h_prev, w, wci, wcf, wco, b = node.input
        forget_bias = float(node.get_attr("forget_bias").f)
        cell_clip = float(node.get_attr("cell_clip").f)
        use_peephole = bool(node.get_attr("use_peephole").i)

        def make_sigmoid(i, w, b):
            i_w_node = ctx.make_node("Mul", [i, w])
            i_w_b_node = ctx.make_node("Add", [i_w_node.output[0], b])
            output_node = ctx.make_node("Sigmoid", [i_w_b_node.output[0]])
            nodes.extend([i_w_node, i_w_b_node, output_node])
            return output_node.output[0]

        # xh = [x, h]
        xh_node = ctx.make_node("Concat", [x, h_prev], attr={"axis": 1})

        # i, ci, f, o = xh * w + b
        xh_w_node = ctx.make_node("MatMul", [xh_node.output[0], w])
        w_shape = ctx.get_shape(w)
        if len(w_shape) != 2 or w_shape[1] % 4 != 0:
            raise RuntimeError(
                "shape of W of LSTMBlockCell {} must be 2-D with a last "
                "dimension divisible by 4".format(node.name))
        merged_output_node = ctx.make_node("Add", [xh_w_node.output[0], b])
        w_last_dim = w_shape[1] // 4
        split_output_node = ctx.make_node("Split",
                                          [merged_output_node.output[0]],
                                          attr={"axis": 1},
                                          output_count=4)
        i, ci, f, o = split_output_node.output

        # f = f + forget_bias
        forget_bias_const = ctx.make_const(
            utils.make_name("{}__forget_bias".format(node.name)),
            np.array(forget_bias, dtype=np.float32))
        f_node = ctx.make_node("Add", [f, forget_bias_const.output[0]])

        if not use_peephole:
            zeros_const = ctx.make_const(
                utils.make_name("{}__zeros_const".format(node.name)),
                np.zeros([w_last_dim], dtype=np.float32))
            nodes.append(zeros_const)
            wci = zeros_const.output[0]
            wcf = zeros_const.output[0]
            wco = zeros_const.output[0]

        # i = sigmoid(cs_prev .* wci + i)
        i = make_sigmoid(cs_prev, wci, i)
        # f = sigmoid(cs_prev .* wcf + f)
        f = make_sigmoid(cs_prev, wcf, f_node.output[0])
        # ci = Tanh(ci)
        ci_node = ctx.make_node("Tanh", [ci])
        # cs = ci .* i + f .* cs_prev
        ci_i_node = ctx.make_node("Mul", [ci_node.output[0], i])
        cs_prev_f_node = ctx.make_node("Mul", [cs_prev, f])
        cs_node = ctx.make_node(
            "Add", [ci_i_node.output[0], cs_prev_f_node.output[0]])
        cs = cs_node.output[0]
        # cs = clip(cs)
        if cell_clip > 0:
            if ctx.opset < 11:
                cs_clip_node = ctx.make_node("Clip", [cs],
                                             attr={
                                                 "max": cell_clip,
                                                 "min": -cell_clip
                                             })
                nodes.append(cs_clip_node)
                cs = cs_clip_node.output[0]
            else:
                dtype = utils.map_onnx_to_numpy_type(ctx.get_dtype(cs))
                name_min = utils.make_name("{}_min".format(node.name))
                name_max = utils.make_name("{}_max".format(node.name))
                min_const = ctx.make_const(name_min,
                                           np.array(-cell_clip, dtype=dtype))
                max_const = ctx.make_const(name_max,
                                           np.array(cell_clip, dtype=dtype))
                cs_clip_node = ctx.make_node(
                    'Clip', [cs, min_const.output[0], max_const.output[0]])
                nodes.append(cs_clip_node)
                cs = cs_clip_node.output[0]

        # o = sigmoid(cs .* wco + o)
        o = make_sigmoid(cs, wco, o)
        # co = Tanh(cs)
        co_node = ctx.make_node("Tanh", [cs])
        # h = co .* o
        h_node = ctx.make_node("Mul", [co_node.output[0], o])

        def replace_output(old_output, new_output):
            ctx.replace_all_inputs(old_output,
                                   new_output)  # ops=ctx.get_nodes()
            ctx.copy_dtype(old_output, new_output)
            ctx.copy_shape(old_output, new_output)

        replace_output(node.output[0], i)
        replace_output(node.output[1], cs)
        replace_output(node.output[2], f)
        replace_output(node.output[3], o)
        replace_output(node.output[4], ci_node.output[0])
        replace_output(node.output[5], co_node.output[0])
        replace_output(node.output[6], h_node.output[0])
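
The docstring's pseudocode maps one-to-one onto the graph built above; in NumPy it
reads as follows (a reference sketch with arbitrary shapes, not tf2onnx code)::

    import numpy as np

    def sigmoid(v):
        return 1.0 / (1.0 + np.exp(-v))

    batch, num_inputs, n = 2, 3, 4
    rng = np.random.default_rng(0)
    x = rng.normal(size=(batch, num_inputs))
    h_prev = rng.normal(size=(batch, n))
    cs_prev = rng.normal(size=(batch, n))
    w = rng.normal(size=(num_inputs + n, 4 * n))
    b = rng.normal(size=(4 * n,))
    wci = wcf = wco = np.zeros(n)          # use_peephole=False
    forget_bias, cell_clip = 1.0, -1.0     # cell_clip < 0 disables clipping

    xh = np.concatenate([x, h_prev], axis=1)
    i, ci, f, o = np.split(xh @ w + b, 4, axis=1)
    i = sigmoid(cs_prev * wci + i)
    f = sigmoid(cs_prev * wcf + (f + forget_bias))
    ci = np.tanh(ci)
    cs = ci * i + cs_prev * f
    if cell_clip > 0:
        cs = np.clip(cs, -cell_clip, cell_clip)
    o = sigmoid(cs * wco + o)
    co = np.tanh(cs)
    h = co * o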
Example #9
    def any_version(cls, const_length, opset, ctx, node, **kwargs):
        """
        Inspired by the `Python implementation of RFFT
        <https://jakevdp.github.io/blog/2013/08/28/understanding-the-fft/>`_.

        Complex version:

        ::

            import numpy as np

            def _DFT_cst(N, fft_length):
                n = np.arange(N)
                k = n.reshape((N, 1)).astype(np.float64)
                M = np.exp(-2j * np.pi * k * n / N)
                return M[:fft_length // 2 + 1]

            def DFT(x, fft_length=None):
                if len(x.shape) == 1:
                    x = x.reshape((-1, 1))
                else:
                    x = x.T
                if fft_length is None:
                    fft_length = x.shape[0]
                cst = _DFT_cst(x.shape[0], fft_length)
                return np.dot(cst, x).T

        Real version, first axis is (real, imag) part:

        ::

            import numpy as np

            def _DFT_real_cst(N, fft_length):
                n = np.arange(N)
                k = n.reshape((N, 1)).astype(np.float64)
                M = np.exp(-2j * np.pi * k * n / N)
                M = M[:fft_length // 2 + 1]
                both = np.empty((2,) + M.shape)
                both[0, :, :] = np.real(M)
                both[1, :, :] = np.imag(M)
                return both

            def DFT_real(x, fft_length=None):
                if len(x.shape) == 1:
                    x = x.reshape((-1, 1))
                else:
                    x = x.T
                if fft_length is None:
                    fft_length = x.shape[0]
                cst = _DFT_real_cst(x.shape[0], fft_length)
                res = np.dot(cst, x)
                return np.transpose(res, (0, 2, 1))
        """
        supported_dtypes = [
            onnx_pb.TensorProto.FLOAT,
            onnx_pb.TensorProto.FLOAT16,
            onnx_pb.TensorProto.DOUBLE,
            onnx_pb.TensorProto.COMPLEX64,
            onnx_pb.TensorProto.COMPLEX128,
        ]
        consumers = ctx.find_output_consumers(node.output[0])
        consumer_types = set(op.type for op in consumers)
        utils.make_sure(
            consumer_types == {'ComplexAbs'},
            "Current implementation of RFFT or FFT only allows ComplexAbs as consumer not %r",
            consumer_types)

        input_name = node.input[0]
        onnx_dtype = ctx.get_dtype(input_name)
        utils.make_sure(onnx_dtype in supported_dtypes, "Unsupported input type.")
        shape = ctx.get_shape(node.input[0])
        shape_n = shape[-1]

        if onnx_dtype in (onnx_pb.TensorProto.COMPLEX64, onnx_pb.TensorProto.COMPLEX128):
            parent = ctx.get_node_by_output_in_current_graph(node.input[0])
            utils.make_sure(
                parent.type == 'Cast' and parent.get_attr_value('to') == onnx_dtype,
                "Current implementation of FFT or RFFT assumes the input is real or complex produced "
                "by a node Cast just before this one.")
            input_name = parent.input[0]
            onnx_dtype = ctx.get_dtype(input_name)

        np_dtype = utils.map_onnx_to_numpy_type(onnx_dtype)

        if np_dtype == np.float16:
            res_onnx_dtype = utils.map_numpy_to_onnx_dtype(np.float16)
            np_dtype = np.float16
        elif np_dtype in (np.float32, np.complex64):
            res_onnx_dtype = utils.map_numpy_to_onnx_dtype(np.float32)
            np_dtype = np.float32
        else:
            res_onnx_dtype = utils.map_numpy_to_onnx_dtype(np.float64)
            np_dtype = np.float64

        if const_length:
            # RFFT: length of FFT is known, some computation
            # (see function make_dft_constant)
            # can be done at conversion time and stored as constant
            utils.make_sure(len(node.input) == 2, "Two inputs expected not %r", len(node.input))

            # This input should be a constant.
            fft_length_name = node.input[1]
            node_fft_length = ctx.get_node_by_output(fft_length_name, search_in_parent_graphs=True)
            utils.make_sure(node_fft_length.type == 'Const',
                            "fft_length should be a constant, the other case is not implemented yet.")
            value = node_fft_length.get_attr("value")
            value_array = to_array(value.t)
            utils.make_sure(value_array.shape == (1,), "Unexpected shape for fft_length (%r)", value_array.shape)
            fft_length = value_array[0]

            # TODO: handle this parameter when onnx.helper.make_node is fixed.
            # Tcomplex = node.get_attr("Tcomplex")

            real_imag_part = make_dft_constant(shape_n, np_dtype, fft_length)
            onx_real_imag_part = ctx.make_const(
                name=utils.make_name('cst_rfft_%d' % shape_n), np_val=real_imag_part)
            onx_real_imag_part_name = onx_real_imag_part.name
        else:
            # FFT: length of FFT is unknown, so the matrix
            # created by function make_dft_constant must be
            # built with ONNX ops at runtime.
            dyn_shape_all = ctx.make_node("Shape", inputs=[input_name],
                                          name=utils.make_name('CPLX_' + node.name + 'shape'))
            m1_cst = ctx.make_const(name=utils.make_name('CPLX_m1'), np_val=np.array([-1], dtype=np.int64))
            dyn_shape = ctx.make_node('Gather', inputs=[dyn_shape_all.output[0], m1_cst.name])
            one_tensor = helper.make_tensor("value", res_onnx_dtype, dims=[1], vals=[1])
            cst_1 = ctx.make_node("ConstantOfShape", inputs=[dyn_shape.output[0]], attr={"value": one_tensor})
            just_0 = ctx.make_const(name=utils.make_name('CPLX1'), np_val=np.array([0], dtype=np.int64))
            rng1 = ctx.make_node("CumSum", inputs=[cst_1.output[0], just_0.name],
                                 name=utils.make_name('CPLX_' + node.name + 'range'))
            p1_cst = ctx.make_const(name=utils.make_name('CPLX_p1'), np_val=np.array([1], dtype=np_dtype))
            rng = ctx.make_node("Sub", inputs=[rng1.output[0], p1_cst.name],
                                name=utils.make_name('CPLX_' + node.name + 'range'))
            resh_cst = ctx.make_const(name=utils.make_name('CPLX_reshape'), np_val=np.array([1, -1], dtype=np.int64))
            rng_tr1 = ctx.make_node("Reshape", inputs=[rng.output[0], resh_cst.name],
                                    name=utils.make_name('CPLX_' + node.name + 'range'))
            resh_cst = ctx.make_const(name=utils.make_name('CPLX_reshape'), np_val=np.array([-1, 1], dtype=np.int64))
            rng_tr2 = ctx.make_node("Reshape", inputs=[rng.output[0], resh_cst.name],
                                    name=utils.make_name('CPLX_' + node.name + 'range'))
            rng_mat = ctx.make_node('MatMul', inputs=[rng_tr2.output[0], rng_tr1.output[0]],
                                    name=utils.make_name('CPLX_' + node.name + 'range2'))
            pi_cst = ctx.make_const(name=utils.make_name('CPLX_pi'), np_val=np.array([np.pi * 2], dtype=np_dtype))
            angle_pi = ctx.make_node("Mul", inputs=[rng_mat.output[0], pi_cst.name],
                                     name=utils.make_name('CPLX_' + node.name + 'angle_pi'))
            shape_cast = ctx.make_node('Cast', inputs=[dyn_shape.output[0]], attr={'to': res_onnx_dtype})
            angle_pibn = ctx.make_node("Div", inputs=[angle_pi.output[0], shape_cast.output[0]],
                                       name=utils.make_name('CPLX_' + node.name + 'angle'))
            if opset >= 13:
                angle = ctx.make_node("Unsqueeze", inputs=[angle_pibn.output[0], just_0.name],
                                      name=utils.make_name('CPLX_' + node.name + 'angles'))
            else:
                angle = ctx.make_node("Unsqueeze", inputs=[angle_pibn.output[0]],
                                      name=utils.make_name('CPLX_' + node.name + 'angles'),
                                      attr={'axes': [0]})
            rng_cos = ctx.make_node("Cos", inputs=[angle.output[0]],
                                    name=utils.make_name('CPLX_' + node.name + 'cos'))
            rng_sin = ctx.make_node("Sin", inputs=[angle.output[0]],
                                    name=utils.make_name('CPLX_' + node.name + 'sin'))
            onx_real_imag_part = ctx.make_node("Concat", inputs=[rng_cos.output[0], rng_sin.output[0]],
                                               name=utils.make_name('CPLX_' + node.name + '_cst_fft'),
                                               attr={'axis': 0})
            onx_real_imag_part_name = onx_real_imag_part.output[0]

        shapei = list(np.arange(len(shape)))
        perm = shapei[:-2] + [shapei[-1], shapei[-2]]
        trx = ctx.make_node(
            "Transpose", inputs=[input_name], attr=dict(perm=perm),
            name=utils.make_name(node.name + 'tr'))

        ctx.remove_node(node.name)
        mult = ctx.make_node(
            "MatMul", inputs=[onx_real_imag_part_name, trx.output[0]],
            name=utils.make_name('CPLX_' + node.name + 'rfft'))

        new_shape = [2] + list(shape)
        shapei = list(np.arange(len(new_shape)))
        perm = shapei[:-2] + [shapei[-1], shapei[-2]]
        last_node = ctx.make_node(
            "Transpose", inputs=[mult.output[0]], attr=dict(perm=perm),
            name=utils.make_name('CPLX_' + node.name + 'rfft'),
            shapes=[new_shape], dtypes=[res_onnx_dtype])

        ctx.replace_all_inputs(node.output[0], last_node.output[0])  # ops=ctx.get_nodes()
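
The docstring's real-valued DFT constant can be validated against np.fft.rfft, which
is exactly what the MatMul in the graph reproduces (a standalone check, not part of
the handler)::

    import numpy as np

    def _DFT_real_cst(N, fft_length):
        n = np.arange(N)
        k = n.reshape((N, 1)).astype(np.float64)
        M = np.exp(-2j * np.pi * k * n / N)
        M = M[:fft_length // 2 + 1]
        return np.stack([np.real(M), np.imag(M)])

    x = np.random.default_rng(0).normal(size=8)
    cst = _DFT_real_cst(8, 8)
    res = cst @ x                        # shape (2, fft_length // 2 + 1)
    assert np.allclose(res[0] + 1j * res[1], np.fft.rfft(x))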