Example 1
def slice_birnn_for_original_rnn_consumers(g, rnn_fw, rnn_bw, bi_rnn, rnn_output_index, all_nodes, to_remove):
    fw_consumers = g.find_output_consumers(rnn_fw.output[rnn_output_index])
    bw_consumers = g.find_output_consumers(rnn_bw.output[rnn_output_index])
    if not fw_consumers and not bw_consumers:
        return

    if rnn_output_index == 0:
        axis = 1
        # remove reverse op for rnn_bw
        reverse_nodes = get_reverse_nodes_after_y_output(g, rnn_bw)

        for r_op in reverse_nodes:
            logger.debug("remove reverse op %s", r_op.name)
            g.replace_all_inputs(all_nodes, r_op.output[0], r_op.input[0])
            to_remove.append(r_op.name)
    elif rnn_output_index in [1, 2]:
        axis = 0
    else:
        raise ValueError("rnn only should has 3 outputs.")

    if fw_consumers:
        attr = {"axes": [axis], "starts": [0], "ends": [1]}
        inputs_map = {"data": bi_rnn.output[rnn_output_index], **attr}
        slice_node_fw = GraphBuilder(g).make_slice(inputs_map)
        all_nodes.append(g.get_node_by_output(slice_node_fw))
        g.replace_all_inputs(fw_consumers, rnn_fw.output[rnn_output_index], slice_node_fw)

    if bw_consumers:
        attr = {"axes": [axis], "starts": [1], "ends": [2]}
        inputs_map = {"data": bi_rnn.output[rnn_output_index], **attr}
        slice_node_bw = GraphBuilder(g).make_slice(inputs_map)
        all_nodes.append(g.get_node_by_output(slice_node_bw))
        g.replace_all_inputs(bw_consumers, rnn_bw.output[rnn_output_index], slice_node_bw)
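
Note: a minimal NumPy sketch (illustrative only, not part of the converter) of the slicing above, assuming the ONNX shapes [seq_length, num_directions, batch, hidden] for Y and [num_directions, batch, hidden] for Y_h/Y_c:

import numpy as np

seq, batch, hidden = 5, 3, 4
bi_y = np.random.rand(seq, 2, batch, hidden)   # bidirectional Y output (rnn_output_index == 0)
fw_y = bi_y[:, 0:1]                            # Slice(axes=[1], starts=[0], ends=[1])
bw_y = bi_y[:, 1:2]                            # Slice(axes=[1], starts=[1], ends=[2])

bi_yh = np.random.rand(2, batch, hidden)       # Y_h / Y_c output (rnn_output_index in [1, 2])
fw_yh, bw_yh = bi_yh[0:1], bi_yh[1:2]          # same split along axis 0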
Example 2
    def create_rnn_node(self, context):
        gb = GraphBuilder(self.g)
        rnn_nodes = list()
        outputs = context.loop_properties.scan_outputs_exits
        logger.debug("number of rnn node outputs: %s", len(outputs))

        for i in range(self.num_lstm_layers):
            logger.debug("creating rnn node for layer: %s", i)
            rnn_nodes.append(self.create_single_rnn_node(context, i))
            output_id = rnn_nodes[i].output[0]
            rnn_output_shape = self.g.get_shape(output_id)
            squeeze_output_shape = [
                rnn_output_shape[0], rnn_output_shape[2], rnn_output_shape[3]
            ]
            squeeze_node = gb.make_squeeze(
                {
                    "data": output_id,
                    "axes": [1]
                },
                shapes=[squeeze_output_shape],
                dtypes=[self.g.get_dtype(output_id)],
                return_node=True)
            if i + 1 < self.num_lstm_layers:
                logger.debug("setting input for layer: %s", i + 1)
                context.onnx_input_ids[i + 1]["X"] = squeeze_node.output[0]
        return rnn_nodes

    def process_seq_length(self, context):
        # output: [time step, batch size, input size]
        seq_len_node = context.seq_len_node
        shape_node = self.g.make_node("Shape", [context.onnx_input_ids["X"]])
        # LSTMCell only allows inputs of [batch size, input_size], so we assume dynamic_rnn has 3 dims.
        # Slice cannot support Int64 in OPSET 7, so we cast here.
        cast_shape_node = self.g.make_node(
            "Cast", [shape_node.output[0]],
            attr={"to": TensorProto.FLOAT},
            shapes=[self.g.get_shape(shape_node.output[0])])

        attr = {"axes": [0], "starts": [1], "ends": [2]}
        inputs_map = {"data": cast_shape_node.output[0], **attr}
        batchsize_node = GraphBuilder(self.g).make_slice(inputs_map)
        if not seq_len_node:
            # Tile's repeats must be INT64
            repeat_node = self.g.make_node("Cast", [batchsize_node],
                                           attr={"to": TensorProto.INT64})

            attr = {"axes": [0], "starts": [0], "ends": [1]}
            inputs_map = {"data": cast_shape_node.output[0], **attr}
            timestep_node = GraphBuilder(self.g).make_slice(inputs_map)
            tile_node = self.g.make_node(
                "Tile", [timestep_node, repeat_node.output[0]])

            # LSTM sequence_lens needs to be int32
            seq_len_node = self.g.make_node("Cast", [tile_node.output[0]],
                                            attr={"to": TensorProto.INT32})
        context.onnx_input_ids["sequence_lens"] = seq_len_node.output[0]
Example 4
 def version_11(cls, ctx, node, **kwargs):
     # create a loop of resize to cater to tensorflow CropAndResize, one box per iteration
     mode = "nearest" if node.get_attr("method") is not None and node.get_attr(
         "method").s == b"nearest" else "linear"
     extrapolation_value = float(node.get_attr("extrapolation_value", "0").f)
     input_x = node.inputs[0]
     boxes = node.inputs[1]
     box_ind = node.inputs[2]
     crop_size = node.inputs[3]
     trip_name = utils.make_name(node.name + "_i")
     cond_name = utils.make_name(node.name + "_cond")
     cond_out_name = utils.make_name(node.name + "cond_out")
     g = ctx.create_new_graph_with_same_config()
     g.add_graph_input(trip_name, TensorProto.INT64, [1])
     g.add_graph_input(cond_name, TensorProto.BOOL, [])
     g.parent_graph = ctx
     const_zero = g.make_const(utils.make_name(node.name + "_const_zero"), np.array([0], dtype=np.int32))
     const_zero_long = g.make_const(utils.make_name(node.name + "_const_zero_long"), np.array([0], dtype=np.int64))
     const_one = g.make_const(utils.make_name(node.name + "_const_one"), np.array([1], dtype=np.int32))
     const_one_long = g.make_const(utils.make_name(node.name + "_const_one_long"), np.array([1], dtype=np.int64))
     index_end = g.make_node("Add", [trip_name, const_one_long.output[0]])
     box_index_from = g.make_node("Slice", [box_ind.output[0], trip_name, index_end.output[0]], name="Slice_a")
     box_index_to = g.make_node("Add", [box_index_from.output[0], const_one.output[0]])
     target_x = g.make_node("Slice", [input_x.output[0], box_index_from.output[0], box_index_to.output[0],
                                      const_zero.output[0]], name="Slice_b")
     transposed_x = g.make_node("Transpose", [target_x.output[0]], attr={'perm': constants.NHWC_TO_NCHW})
     shape_of_transposed_x = g.make_node("Shape", [transposed_x.output[0]])
     const_zero_zero = g.make_const(utils.make_name(node.name + "_const_zero_zero"),
                                    np.array([0, 0], dtype=np.float32))
     const_one_one = g.make_const(utils.make_name(node.name + "_const_one_one"),
                                  np.array([1, 1], dtype=np.float32))
     const_four = g.make_const(utils.make_name(node.name + "_const_four"), np.array([4], dtype=np.int64))
     const_empty_float = g.make_const(utils.make_name("const_empty_float"), np.array([], dtype=np.float32))
     first_half_of_shape = GraphBuilder(g).make_slice(
         {"data": shape_of_transposed_x.output[0], "ends": [2], "starts": [0]})
     box = g.make_node("Slice", [boxes.output[0], trip_name, index_end.output[0], const_zero_long.output[0]],
                       name="Slice_c")
     roi_raw = g.make_node("Reshape", [box.output[0], const_four.output[0]])
     roi_raw_first_half = GraphBuilder(g).make_slice({"data": roi_raw.output[0], "ends": [2], "starts": [0]})
     roi_raw_second_half = GraphBuilder(g).make_slice({"data": roi_raw.output[0], "ends": [4], "starts": [2]})
     roi_concat_1 = g.make_node("Concat", [const_zero_zero.output[0], roi_raw_first_half], attr={'axis': 0})
     roi_concat_2 = g.make_node("Concat", [const_one_one.output[0], roi_raw_second_half], attr={'axis': 0})
     final_roi = g.make_node("Concat", [roi_concat_1.output[0], roi_concat_2.output[0]], attr={'axis': 0})
     crop_size_int64 = g.make_node("Cast", [crop_size.output[0]], attr={'to': TensorProto.INT64})
     final_crop_size = g.make_node("Concat", [first_half_of_shape, crop_size_int64.output[0]], {'axis': 0})
     resized_x = g.make_node("Resize", [transposed_x.output[0], final_roi.output[0], const_empty_float.output[0],
                                        final_crop_size.output[0]],
                             attr={"mode": mode, "extrapolation_value": extrapolation_value,
                                   "coordinate_transformation_mode": "tf_crop_and_resize"})
     recovered_x = g.make_node("Transpose", [resized_x.output[0]], attr={'perm': constants.NCHW_TO_NHWC})
     squeeze_x = g.make_node("Squeeze", inputs=[recovered_x.output[0]], attr={"axes": [0]})
     g.make_node("Identity", [cond_name], outputs=[cond_out_name])
     g.add_graph_output(cond_out_name, TensorProto.BOOL, [])
     g.add_graph_output(squeeze_x.output[0], ctx.get_dtype(node.input[0]), [-1, -1, -1])
     trip_node = ctx.make_node("Size", [box_ind.output[0]])
     cond_const = ctx.make_const(utils.make_name("cond"), np.ones((), dtype=np.bool))
     ctx.remove_node(node.name)
     inner_loop = ctx.make_node("Loop", [trip_node.output[0], cond_const.output[0]], name=node.name,
                                outputs=node.output)
     inner_loop.set_body_graph_as_attr("body", g)
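
Note: an illustrative NumPy sketch (not the handler itself) of how the per-box ROI is assembled for Resize with coordinate_transformation_mode="tf_crop_and_resize", assuming the CropAndResize box layout [y1, x1, y2, x2]:

import numpy as np

box = np.array([0.1, 0.2, 0.8, 0.9], dtype=np.float32)   # [y1, x1, y2, x2]
roi_start = np.concatenate([[0.0, 0.0], box[:2]])         # N/C starts + [y1, x1]
roi_end = np.concatenate([[1.0, 1.0], box[2:]])           # N/C ends + [y2, x2]
final_roi = np.concatenate([roi_start, roi_end])          # [0, 0, y1, x1, 1, 1, y2, x2]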
Example 5
    def _connect_lstm_ych_to_graph(self, context, i):
        # in tf, concat of y_c and y_h output shape is: [batch, hidden *2]
        # in onnx, y_c/y_h output shape is: [number_directions, batch, hidden]
        gb = GraphBuilder(self.g)
        exit_output = context.state_variables["ct_ht" + str(i)].exit_output
        lstm_node = context.rnn_node[i]
        yc_shape = self.g.get_shape(lstm_node.output[2])
        concat_output_shape = [yc_shape[0], yc_shape[1], yc_shape[2] * 2]
        concat = self.g.make_node(
            "Concat", [lstm_node.output[2], lstm_node.output[1]],
            attr={"axis": 2},
            shapes=[concat_output_shape],
            dtypes=[self.g.get_dtype(lstm_node.output[2])])

        squeeze_output_shape = [concat_output_shape[1], concat_output_shape[2]]
        squeeze_node = gb.make_squeeze(
            {
                'data': concat.output[0],
                "axes": [0]
            },
            shapes=[squeeze_output_shape],
            dtypes=[self.g.get_dtype(concat.output[0])],
            return_node=True)

        self.g.replace_all_inputs(
            exit_output.id, squeeze_node.output[0])  # ops=self.g.get_nodes()

    def connect_unit_rnn_output_to_graph(self, context):
        outputs = context.loop_properties.scan_outputs_exits
        if not outputs:
            logger.debug("no one consume output")
            return

        gather_output_id = outputs[0].id
        logger.debug("found output for rnn: %s", gather_output_id)

        # in tf batch major mode, output shape is : [batch, time, hidden]
        # in time major mode, output shape is: [time, batch, hidden]
        # in onnx, output shape is : [time, num_directions, batch, hidden]

        rnn_node = context.rnn_node
        output_id = rnn_node.output[0]
        rnn_output_shape = self.g.get_shape(output_id)
        squeeze_output_shape = [
            rnn_output_shape[0], rnn_output_shape[2], rnn_output_shape[3]
        ]
        gb = GraphBuilder(self.g)
        squeeze_node = gb.make_squeeze({
            'data': output_id,
            "axes": [1]
        },
                                       shapes=[squeeze_output_shape],
                                       dtypes=[self.g.get_dtype(output_id)],
                                       return_node=True)
        self.g.replace_all_inputs(
            gather_output_id, squeeze_node.output[0])  # ops=self.g.get_nodes()
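
Note: a small NumPy sketch (illustrative only) of the shape handling in _connect_lstm_ych_to_graph: ONNX y_c/y_h are [num_directions, batch, hidden] and tf expects the packed state [batch, hidden * 2]:

import numpy as np

batch, hidden = 3, 4
y_c = np.random.rand(1, batch, hidden)
y_h = np.random.rand(1, batch, hidden)
ct_ht = np.squeeze(np.concatenate([y_c, y_h], axis=2), axis=0)   # [batch, hidden * 2]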
Example 7
    def version_13(cls, ctx, node, **kwargs):
        ctx.ta_reads.append(node.input[0])
        node.type = "Gather"
        ctx.replace_inputs(node, [node.input[0], node.input[1]])

        g = GraphBuilder(ctx)

        usq_node = g.make_unsqueeze({"axes": [0], 'name': node.child_name(), 'data': node.input[1]}, return_node=True)
        ctx.insert_node_on_output(usq_node)

        sq_node = g.make_squeeze({"axes": [0], 'name': node.child_name(), 'data': node.output[0]}, return_node=True)
        ctx.insert_node_on_output(sq_node)
Example 8
    def _process_non_tuple_ch_init_nodes(self, context, i):
        input_id = context.state_variables["ct_ht" + str(i)].enter_input_id
        hidden_size = context.hidden_size[i]

        attr = {"axes": [1], "starts": [0], "ends": [hidden_size]}
        inputs_map = {"data": input_id, **attr}
        slice_node1 = GraphBuilder(self.g).make_slice(inputs_map)
        unsqueeze_node_1 = self.g.make_node("Unsqueeze", [slice_node1], attr={"axes": [0]})

        attr = {"axes": [1], "starts": [hidden_size], "ends": [hidden_size * 2]}
        inputs_map = {"data": input_id, **attr}
        slice_node2 = GraphBuilder(self.g).make_slice(inputs_map)
        unsqueeze_node_2 = self.g.make_node("Unsqueeze", [slice_node2], attr={"axes": [0]})

        return unsqueeze_node_1.output[0], unsqueeze_node_2.output[0]
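
Note: an illustrative NumPy sketch (not part of the rewriter) of the split above: the packed non-tuple initial state [batch, hidden * 2] becomes the c and h inputs expected by ONNX LSTM, [1, batch, hidden] each:

import numpy as np

batch, hidden_size = 3, 4
ct_ht = np.random.rand(batch, 2 * hidden_size)
c0 = ct_ht[:, :hidden_size][np.newaxis]                   # Slice + Unsqueeze(axes=[0])
h0 = ct_ht[:, hidden_size:2 * hidden_size][np.newaxis]    # -> [1, batch, hidden_size]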
Example 9
    def version_9(cls, ctx, node, **kwargs):
        # float32/64 output = SparseSoftmaxCrossEntropyWithLogits(float32/64 features, int32/64 labels)
        # the detailed math of this op is: a = onehot(labels), b = logsoftmax(features), reduce_sum(mul(a, b))
        logit_node = node.inputs[0]
        logit_shape = ctx.get_shape(node.input[0])
        logit_dtype = ctx.get_dtype(node.input[0])

        label_name = node.input[1]

        if logit_shape is not None and logit_shape[-1] != -1:
            num_class = logit_shape[-1]
            node_name = utils.make_name("onehot_depth")
            depth_node = ctx.make_const(node_name, np.array([num_class]).astype(np.int64)).output[0]
        else:
            logit_shape = ctx.make_node("Shape", [node.input[0]]).output[0]
            slice_args = {"data": logit_shape,
                          "starts": [-1], "ends": [int(utils.get_max_value(np.int32))]}
            num_class = GraphBuilder(ctx).make_slice(kwargs=slice_args)
            depth_node = num_class
        values_node = ctx.make_const(utils.make_name("onehot_values"), np.array([0, 1]).astype(np.int64)).output[0]
        label_dtype = ctx.get_dtype(label_name)
        if label_dtype != TensorProto.INT64:
            onehot_indice = ctx.make_node("Cast", [label_name], attr={"to": TensorProto.INT64}).output[0]
        else:
            onehot_indice = label_name
        label_node = ctx.make_node(op_type="OneHot",
                                   inputs=[onehot_indice, depth_node, values_node])
        # the above logic makes output dtype of label_node now always int64
        # make sure label has same dtype as logit
        if logit_dtype != TensorProto.INT64:
            label_node = ctx.make_node("Cast", label_node.output, attr={"to": logit_dtype}, dtypes=[logit_dtype])

        _make_sparse_softmax_cross_entropy_with_logits(ctx, label_node, logit_node, node)
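
Note: a NumPy sketch (illustrative only) of the math named in the comment above (onehot + logsoftmax + reduce_sum); the sketch adds the usual negation to obtain the per-example loss:

import numpy as np

logits = np.array([[2.0, 1.0, 0.1]], dtype=np.float32)    # [batch, num_class]
labels = np.array([0], dtype=np.int64)

log_softmax = logits - np.log(np.exp(logits).sum(axis=-1, keepdims=True))
onehot = np.eye(logits.shape[-1])[labels]
loss = -np.sum(onehot * log_softmax, axis=-1)              # standard per-example cross entropy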
Example 10
 def version_1(cls, ctx, node, **kwargs):
     # in tf-2.0 grappler optimizes the graph pretty well and our matching logic
     # in the rewriter does not trigger. grappler will send the random uniform
     # with shape as input so we need to pick up the input here and if the shape is
     # const we make it an attribute.
     seed = node.get_attr("seed")
     node.set_attr("seed", float(seed.f))
     utils.make_sure(node.inputs[0].is_const(),
                     "%s node with non-const shape requires opset >= 9",
                     node.type)
     shape = node.inputs[0].get_tensor_value()
     ctx.remove_input(node, node.input[0], 0)
     if len(shape) == 0:
         # ORT can't take an empty shape (scalar)
         node.set_attr("shape", [1])
         ctx.set_shape(node.output[0], [1])
         squeeze_node = GraphBuilder(ctx).make_squeeze(
             {
                 'data': node.output[0],
                 'axes': [0]
             }, return_node=True)
         ctx.insert_node_on_output(squeeze_node, node.output[0])
         rand_out = squeeze_node.output[0]
     else:
         node.set_attr("shape", shape)
         ctx.set_shape(node.output[0], shape)
         rand_out = node.output[0]
     if node.type == "RandomUniformInt":
         cls.randuniform_int(ctx, node, rand_out, node.input[0],
                             node.input[1])
         node.type = "RandomUniform"
         ctx.replace_inputs(node, [])
     elif node.type == "RandomStandardNormal":
         node.type = "RandomNormal"
Example 11
    def version_10(cls, ctx, node, **kwargs):
        inp_shape = ctx.make_node("Shape", [node.input[0]]).output[0]
        dim_0 = GraphBuilder(ctx).make_slice({
            'data': inp_shape,
            'starts': [0],
            'ends': [1],
            'axes': [0]
        })
        zeros = ctx.make_node("ConstantOfShape", [dim_0],
                              shapes=[[-1]]).output[0]

        seed = node.get_attr_value("seed", 0)
        seed2 = node.get_attr_value("seed2", 0)
        onnx_seed = utils.combine_seeds(seed, seed2)
        rand_attr = {'dtype': onnx_pb.TensorProto.FLOAT}
        if onnx_seed is not None:
            rand_attr['seed'] = onnx_seed

        random_floats = ctx.make_node("RandomUniformLike", [zeros],
                                      op_name_scope=node.name,
                                      shapes=[[-1]],
                                      attr=rand_attr).output[0]
        # Use indices of the TopK to get a random ordering
        _, random_ordering = ctx.make_node("TopK", [random_floats, dim_0],
                                           output_count=2,
                                           attr={
                                               'axis': -1
                                           }).output
        shuffled_res = ctx.make_node(
            "Gather", [node.input[0], random_ordering]).output[0]
        ctx.replace_all_inputs(node.output[0], shuffled_res)
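
Note: an illustrative NumPy equivalent (not the handler) of the shuffle trick above: sorting by random keys yields a uniform permutation, with argsort standing in here for TopK over the RandomUniformLike values:

import numpy as np

x = np.array([10, 20, 30, 40, 50])
random_floats = np.random.rand(x.shape[0])     # RandomUniformLike over a zero tensor of the same length
random_ordering = np.argsort(-random_floats)   # indices a TopK over the random values would return
shuffled = x[random_ordering]                  # Gather with the random permutation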
Example 12
 def any_version(cls, opset, ctx, node, **kwargs):
     node.domain = constants.CONTRIB_OPS_DOMAIN
     separator = node.get_attr_value("separator")
     if separator is None:
         separator = b''
     separator = separator.decode('UTF-8')
     separator_node = ctx.make_const(utils.make_name("separator"),
                                     np.array([separator], np.object))
     axis_node = ctx.make_const(utils.make_name("axis"),
                                np.array([0], np.int64))
     inps_with_shapes = [i for i in node.input if ctx.get_shape(i) != []]
     shape_node = None
     if 0 < len(inps_with_shapes) < len(node.input):
         shape_node = ctx.make_node("Shape", [inps_with_shapes[0]])
     unsqueezes = []
     for inp in node.input:
         if ctx.get_shape(inp) == [] and shape_node is not None:
             expand_node = ctx.make_node("Expand",
                                         [inp, shape_node.output[0]])
             inp = expand_node.output[0]
         unsqueeze_node = GraphBuilder(ctx).make_unsqueeze({
             'data': inp,
             'axes': [0]
         })
         unsqueezes.append(unsqueeze_node)
     stack_node = ctx.make_node("Concat", unsqueezes, attr={'axis': 0})
     ctx.replace_inputs(node, [
         stack_node.output[0], separator_node.output[0], axis_node.output[0]
     ])
Example 13
    def version_6(cls, ctx, node, **kwargs):
        # T output = All(T x, list(int) reduce_indices, @bool keepdims)
        # T output = Any(T x, list(int) reduce_indices, @bool keepdims)
        reduce_dim = node.inputs[1].get_tensor_value()

        # for Any, the reduce_indices can be scalar as observed.
        if np.isscalar(reduce_dim):
            reduce_dim = [reduce_dim]

        if ctx.opset < 11:
            utils.make_sure(all(i >= 0 for i in reduce_dim), "negative reduce axis is not supported in onnx for now")

        cast = ctx.make_node(op_type="Cast", inputs=[node.input[0]], attr={"to": onnx_pb.TensorProto.FLOAT})
        keepdims = helper.get_attribute_value(node.get_attr("keep_dims"))
        op_type = "ReduceMin" if node.type == "All" else "ReduceSum"

        if op_type == "ReduceSum":
            reduce_node_output = GraphBuilder(ctx).make_reduce_sum(
                {"data": cast.output[0], "axes": reduce_dim, "keepdims": keepdims, "noop_with_empty_axes": 1})
        else:
            reduce_node_output = ctx.make_node(op_type=op_type, inputs=cast.output,
                                               attr={"axes": reduce_dim, "keepdims": keepdims}).output[0]

        zero_node = ctx.make_const(utils.make_name("zero_reduce"), np.array(0, dtype=np.float32))

        shapes = node.output_shapes
        dtypes = node.output_dtypes
        ctx.remove_node(node.name)
        ctx.make_node(op_type="Greater", inputs=[reduce_node_output, zero_node.output[0]],
                      name=node.name, outputs=node.output, shapes=shapes, dtypes=dtypes)
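
Note: a NumPy check (illustrative only) of why the cast-and-reduce above reproduces All/Any: reducing the float cast with min/sum and comparing against zero matches the boolean reductions:

import numpy as np

x = np.array([[True, False], [True, True]])
any_ = x.astype(np.float32).sum(axis=1) > 0    # ReduceSum + Greater -> [True, True]
all_ = x.astype(np.float32).min(axis=1) > 0    # ReduceMin + Greater -> [False, True]
assert (any_ == x.any(axis=1)).all() and (all_ == x.all(axis=1)).all()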
Example 14
    def _connect_lstm_yc_to_graph(self, context, i):
        # in tf, y_c output shape is: [batch, hidden]
        # in onnx, output shape is: [number_directions, batch, hidden]
        gb = GraphBuilder(self.g)
        exit_output = context.state_variables["ct" + str(i)].exit_output
        output_id = context.rnn_node[i].output[2]
        lstm_yc_shape = self.g.get_shape(output_id)
        squeeze_node = gb.make_squeeze(
            {
                "data": output_id,
                "axes": [0]
            },
            shapes=[[lstm_yc_shape[1], lstm_yc_shape[2]]],
            dtypes=[self.g.get_dtype(output_id)],
            return_node=True)

        self.g.replace_all_inputs(
            exit_output.id, squeeze_node.output[0])  # ops=self.g.get_nodes()

    def _adapt_scan_sequence_input_or_output(self, target_name, input_id, handle_output=False):
        nodes_to_add = []
        shape_node = self.g.make_node("Shape", [input_id])
        nodes_to_add.append(shape_node)
        inferred_shape = self.g.get_shape(input_id)
        if handle_output is True:
            # handle output:
            # if required dim values don't contain more than one -1,
            # just use a const for Reshape's shape input.
            if inferred_shape is not None and inferred_shape[1:].count(-1) <= 1:
                new_shape_node = self.g.make_const(utils.make_name(target_name + "_target_shape"),
                                                   np.array(inferred_shape[1:], dtype=np.int64))
                nodes_to_add.append(new_shape_node)
            else:
                # otherwise, get the dim dynamically, e.g. remove the fake batch size (e.g. 1)
                # from [1, time, real-batch, ...]
                origin_shape_node = self.g.make_node("Cast", [shape_node.output[0]],
                                                     {"to": onnx_pb.TensorProto.FLOAT})
                nodes_to_add.append(origin_shape_node)

                attr = {"axes": [0], "starts": [1], "ends": [sys.maxsize]}
                inputs_map = {"data": origin_shape_node.output[0], **attr}
                sliced_shape_node = GraphBuilder(self.g).make_slice(inputs_map)
                nodes_to_add.append(self.g.get_node_by_output(sliced_shape_node))

                new_shape_node = self.g.make_node("Cast", [sliced_shape_node],
                                                  {"to": onnx_pb.TensorProto.INT64})
                nodes_to_add.append(new_shape_node)

            new_shape = inferred_shape[1:]
        else:
            # handle input:
            if inferred_shape is not None and inferred_shape.count(-1) <= 1:
                new_shape_node = self.g.make_const(utils.make_name(target_name + "_target_shape"),
                                                   np.array([1] + inferred_shape, dtype=np.int64))
                nodes_to_add.append(new_shape_node)
            else:
                # add a fake batch size : 1
                fake_batch_size_node = self.g.make_const(utils.make_name(target_name + "_target_shape"),
                                                         np.array([1], dtype=np.int64))
                nodes_to_add.append(fake_batch_size_node)
                new_shape_node = self.g.make_node("Concat",
                                                  [fake_batch_size_node.output[0], shape_node.output[0]],
                                                  attr={"axis": 0})
                nodes_to_add.append(new_shape_node)
            new_shape = [1] + inferred_shape

        reshape_node = self.g.make_node("Reshape", [input_id, new_shape_node.output[0]],
                                        shapes=[new_shape],
                                        dtypes=[self.g.get_dtype(input_id)],
                                        op_name_scope=target_name)
        nodes_to_add.append(reshape_node)
        logger.debug("create Reshape for scan output %s, with output shape %s",
                     reshape_node.output[0], new_shape)
        return nodes_to_add
Example 16
    def _process_c_or_h_init_nodes(self, initializer_input_id, context):
        node = self.g.get_node_by_output(initializer_input_id)
        if node.is_const():
            val = node.get_tensor_value(as_list=False)
            initial_name = utils.make_name("Const")
            new_val = np.expand_dims(val, axis=0)
            const_node = self.g.make_const(initial_name, new_val)
            return const_node.output[0]

        gb = GraphBuilder(self.g)
        squeeze_node = gb.make_unsqueeze(
            {
                'data': initializer_input_id,
                "axes": [0]
            }, return_node=True)
        to_replace = [n for n in self.g.get_nodes() if n != squeeze_node]
        self.g.replace_all_inputs(initializer_input_id,
                                  squeeze_node.output[0],
                                  ops=to_replace)
        return squeeze_node.output[0]
Example 17
 def version_13(cls, ctx, node, **kwargs):
     keepdims = node.get_attr_value('keep_dims')
     reduce_input = node.input[0]
     if node.type == "All":
         reduce_input = ctx.make_node("Not", [reduce_input]).output[0]
     cast = ctx.make_node("Cast", inputs=[reduce_input], attr={"to": onnx_pb.TensorProto.FLOAT}).output[0]
     axes_cast = node.input[1]
     if ctx.get_rank(axes_cast) == 0:
         # Unsqueeze scalar axes
         axes_cast = GraphBuilder(ctx).make_unsqueeze({'data': axes_cast, 'axes': [0]})
     if ctx.get_dtype(axes_cast) != onnx_pb.TensorProto.INT64:
         axes_cast = ctx.make_node("Cast", inputs=[axes_cast], attr={"to": onnx_pb.TensorProto.INT64}).output[0]
     reduce_node_output = GraphBuilder(ctx).make_reduce_sum(
         {"data": cast, "axes": axes_cast, "keepdims": keepdims, "noop_with_empty_axes": 1},
         shapes=node.output_shapes, op_name_scope=node.name)
     zero_node = ctx.make_const(utils.make_name("zero_reduce"), np.array(0, dtype=np.float32))
     greater_node = ctx.make_node(op_type="Greater", inputs=[reduce_node_output, zero_node.output[0]])
     result = greater_node.output[0]
     if node.type == "All":
         result = ctx.make_node("Not", [result]).output[0]
     ctx.replace_all_inputs(node.output[0], result)
Example 18
    def any_version(cls, opset, ctx, node, **kwargs):
        """
        Computes the modulus of a complex tensor.
        If the input dtype is not complex64 or complex128,
        it assumes the first dimension holds the real part (index 0)
        and the imaginary part (index 1).
        """
        supported_dtypes = [
            onnx_pb.TensorProto.FLOAT,
            onnx_pb.TensorProto.FLOAT16,
            onnx_pb.TensorProto.DOUBLE,
            onnx_pb.TensorProto.COMPLEX64,
            onnx_pb.TensorProto.COMPLEX128,
        ]
        onnx_dtype = ctx.get_dtype(node.input[0])
        utils.make_sure(onnx_dtype in supported_dtypes, "Unsupported input type.")
        shape = ctx.get_shape(node.input[0])
        np_dtype = utils.map_onnx_to_numpy_type(onnx_dtype)
        utils.make_sure(shape[0] == 2, "ComplexAbs expected the first dimension to be 2 but shape is %r", shape)

        ind0 = ctx.make_const(name=utils.make_name('cst0'), np_val=np.array([0], dtype=np.int64))
        ind1 = ctx.make_const(name=utils.make_name('cst1'), np_val=np.array([1], dtype=np.int64))
        p2 = ctx.make_const(name=utils.make_name('p2'), np_val=np.array([2], dtype=np_dtype))

        real_part = ctx.make_node(
            'Gather', inputs=[node.input[0], ind0.name], attr=dict(axis=0),
            name=utils.make_name('Real_' + node.name))
        imag_part = ctx.make_node(
            'Gather', inputs=[node.input[0], ind1.name], attr=dict(axis=0),
            name=utils.make_name('Imag_' + node.name))

        real_part2 = ctx.make_node(
            'Pow', inputs=[real_part.output[0], p2.name],
            name=utils.make_name(real_part.name + 'p2p'))

        imag_part2 = ctx.make_node(
            'Pow', inputs=[imag_part.output[0], p2.name],
            name=utils.make_name(imag_part.name + 'p2p'))

        ctx.remove_node(node.name)
        add = ctx.make_node(
            "Add", inputs=[real_part2.output[0], imag_part2.output[0]],
            name=utils.make_name('ComplexAbs_' + node.name))

        squeezed = GraphBuilder(ctx).make_squeeze(
            {'data': add.output[0], 'axes': [0]}, name=utils.make_name('ComplexAbs' + node.name), return_node=True)

        last_node = ctx.make_node(
            "Sqrt", inputs=squeezed.output[:1],
            name=utils.make_name('ComplexAbs' + node.name),
            shapes=[shape[1:]], dtypes=[onnx_dtype])

        ctx.replace_all_inputs(node.output[0], last_node.output[0])  # ops=ctx.get_nodes()
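
Note: a NumPy sketch (illustrative only) of the ComplexAbs decomposition above, assuming the real and imaginary parts are packed along the first dimension:

import numpy as np

x = np.array([[3.0, 0.0], [4.0, 1.0]])           # x[0] = real part, x[1] = imaginary part
real, imag = x[0:1], x[1:2]                       # Gather with indices [0] and [1]
modulus = np.sqrt(np.squeeze(real ** 2 + imag ** 2, axis=0))   # Pow, Add, Squeeze, Sqrt -> [5.0, 1.0]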
Example 19
def slice_bilstm_for_original_lstm_consumers(g, lstm_fw, lstm_bw, bi_lstm,
                                             lstm_output_index, all_nodes,
                                             to_remove):
    fw_consumers = g.find_output_consumers(lstm_fw.output[lstm_output_index])
    bw_consumers = g.find_output_consumers(lstm_bw.output[lstm_output_index])
    if not fw_consumers and not bw_consumers:
        return

    if lstm_output_index == 0:
        axis = 1
        # remove reverse op for lstm_bw
        reverse_nodes = get_reverse_nodes_after_y_output(g, lstm_bw)
        if not reverse_nodes:
            raise ValueError(
                "should not happen y_output is not followed with reverse node")

        for r_op in reverse_nodes:
            logger.debug("remove reverse op called %s", r_op.name)
            g.replace_all_inputs(all_nodes, r_op.output[0], r_op.input[0])
            to_remove.append(r_op.name)
    elif lstm_output_index in [1, 2]:
        axis = 0
    else:
        raise ValueError("LSTM only should has 3 outputs.")

    if fw_consumers:
        attr = {"axes": [axis], "starts": [0], "ends": [1]}
        inputs_map = {"data": bi_lstm.output[lstm_output_index], **attr}
        slice_node_fw = GraphBuilder(g).make_slice(inputs_map)
        all_nodes.append(g.get_node_by_output(slice_node_fw))
        g.replace_all_inputs(fw_consumers, lstm_fw.output[lstm_output_index],
                             slice_node_fw)

    if bw_consumers:
        attr = {"axes": [axis], "starts": [1], "ends": [2]}
        inputs_map = {"data": bi_lstm.output[lstm_output_index], **attr}
        slice_node_bw = GraphBuilder(g).make_slice(inputs_map)
        all_nodes.append(g.get_node_by_output(slice_node_bw))
        g.replace_all_inputs(bw_consumers, lstm_bw.output[lstm_output_index],
                             slice_node_bw)
Example 20
    def version_7(cls, ctx, node, **kwargs):
        tfl_while_inputs = node.input
        output_shapes = node.output_shapes
        output_dtypes = node.output_dtypes
        output_names = node.output

        cond_name = node.get_attr_str("cond_subgraph_index")
        cond_graph = find_function(cond_name)
        cond_graph.parent_graph = ctx

        body_name = node.get_attr_str("body_subgraph_index")
        body = find_function(body_name)
        body.parent_graph = ctx

        ctx.remove_node(node.name)

        cond_binding = parameter_binding(cond_graph, tfl_while_inputs)
        cond_outputs = inline_subgraph(ctx, cond_graph, cond_name, cond_binding)

        # Potential scan output candidates are identified in the body subgraph using tfl_scan_output_rewriter.
        # They can then be optimized in this tfl loop handler provided they are not used in the cond subgraph.
        scan_outputs = sorted(body.scan_outputs, reverse=True)
        def input_is_unused(g, index):
            return len(g.find_output_consumers(g.inputs[index])) == 0
        scan_outputs = [(i, out) for i, out in scan_outputs if input_is_unused(cond_graph, i)]

        for idx, _ in scan_outputs:
            del tfl_while_inputs[idx]
            output_shapes.append(output_shapes.pop(idx))
            output_dtypes.append(output_dtypes.pop(idx))
            output_names.append(output_names.pop(idx))

        max_iterations = ctx.make_const(utils.make_name("max_iterations"), np.array(np.iinfo(np.int64).max))

        loop_node = ctx.make_node("Loop", [max_iterations.output[0], cond_outputs[0]] + tfl_while_inputs,
                                  output_count=len(output_shapes), name=node.name + "_loop",
                                  shapes=output_shapes, dtypes=output_dtypes, skip_conversion=True)

        output_map = dict(zip(output_names, loop_node.output))

        # shift output consumers
        for k, v in output_map.items():
            ctx.replace_all_inputs(k, v)  # ops=ctx.get_nodes()

        body = wire_tfl_while_body(body, loop_node.inputs, output_shapes, output_dtypes, cond_graph, scan_outputs)

        for i in range(len(scan_outputs)):
            squeeze_node = GraphBuilder(body).make_squeeze(
                {'data': body.outputs[-1-i], "axes": [0]}, return_node=True)
            body.outputs[-1-i] = squeeze_node.output[0]

        loop_node.set_body_graph_as_attr("body", body)
Example 21
    def any_version(cls, opset, ctx, node, **kwargs):
        if node.type == "StringSplit":
            skip_empty = node.get_attr_value('skip_empty', True)
        else:
            skip_empty = False
        node.type = "StringSplit"
        node.domain = constants.CONTRIB_OPS_DOMAIN
        for a in list(node.attr.keys()):
            del node.attr[a]
        unsqueeze_node = GraphBuilder(ctx).make_unsqueeze({'data': node.input[1], 'axes': [0]}, return_node=True)

        skip_empty_const = ctx.make_const(utils.make_name('skip_empty_const'), np.array([skip_empty], np.bool))
        ctx.replace_inputs(node, [node.input[0], unsqueeze_node.output[0], skip_empty_const.output[0]])
Example 22
 def _connect_gru_state_to_graph(self, context):
     # in tf, state output shape is: [batch, hidden]
     # in onnx, output shape is: [number_directions, batch, hidden]
     exit_output_id = context.state_variables["state"].exit_output.id
     if not exit_output_id:
         logger.debug("no one consume state variable")
         return
     output_id = context.rnn_node.output[1]
     gru_state_shape = self.g.get_shape(output_id)
     output_shape = [gru_state_shape[1], gru_state_shape[2]]
     squeeze_node = GraphBuilder(self.g).make_squeeze(
         {'data': output_id, "axes": [0]}, shapes=[output_shape],
         dtypes=[self.g.get_dtype(output_id)], return_node=True)
     self.g.replace_all_inputs(exit_output_id, squeeze_node.output[0])  # ops=self.g.get_nodes()
Example 23
    def version_1(cls, ctx, node, **kwargs):
        node.domain = constants.CONTRIB_OPS_DOMAIN
        input_node = node.inputs[0]
        utils.make_sure(input_node.type == "SentencepieceOp", "Input 0 to node %s is not SentencepieceOp", node.name)
        ctx.remove_input(node, node.input[0], 0)

        nbest_size_cast = ctx.make_node("Cast", [node.input[1]], attr={'to': TensorProto.INT64}).output[0]
        ctx.replace_input(node, node.input[1], nbest_size_cast, 1)
        for i in range(1, len(node.input)):
            unsqueeze = GraphBuilder(ctx).make_unsqueeze({'data': node.input[i], 'axes': [0]})
            ctx.replace_input(node, node.input[i], unsqueeze, i)
        node.set_attr("model", input_node.attr['model'].s)
        node.type = "SentencepieceTokenizer"
        if ctx.is_safe_to_remove_nodes([input_node]):
            ctx.remove_node(input_node.name)
Example 24
    def _convert_since_9(cls, ctx, node, op_type, roi_required=False):

        # float32 out = ResizeBilinear/ResizeNearestNeighbor(T images, int size)
        # https://www.tensorflow.org/api_docs/python/tf/image/resize_nearest_neighbor
        # wants the input to be NHWC - adjust target_shape to this.
        mode = "linear" if node.type == "ResizeBilinear" else "nearest"

        # first create "scales" info for onnx upsample
        # if the shapes of input and output are known then "scales" is calculated statically and set as a const node
        shape = ctx.get_shape(node.input[0])
        if shape and shape[2] != -1 and shape[1] != -1 and node.inputs[1].is_const():
            target_shape = node.inputs[1].get_tensor_value()
            n, h, w, c = shape
            nh, nw = target_shape
            # scales is nchw
            # the reason not storing data at raw field is because of the bug: https://github.com/onnx/onnx/issues/1852
            scale_val = np.array([1.0, 1.0, float(nh) / h, float(nw) / w]).astype(np.float32)
            scales = ctx.make_const(utils.make_name("scales"), scale_val, raw=False)
        else:
            ori_shape = ctx.make_node("Shape", [node.input[0]])
            attr = {"axes": [0], "starts": [1], "ends": [3]}
            inputs_map = {"data": ori_shape.output[0], **attr}
            ori_shape_hw = GraphBuilder(ctx).make_slice(inputs_map)
            ori_shape_hw_float = ctx.make_node("Cast", [ori_shape_hw], attr={"to": onnx_pb.TensorProto.FLOAT})

            target_hw = node.inputs[1]
            target_hw_float = ctx.make_node("Cast", target_hw.output, attr={"to": onnx_pb.TensorProto.FLOAT})

            scales_hw = ctx.make_node("Div", [target_hw_float.output[0], ori_shape_hw_float.output[0]])

            const_one_array = ctx.make_const(utils.make_name("one"), np.array([1.0, 1.0]).astype(np.float32))
            # scales is nchw
            scales = ctx.make_node("Concat", [const_one_array.output[0], scales_hw.output[0]], {"axis": 0})
        # because onnxruntime only supports scaling the last two dims, a transpose is inserted
        input_nchw = ctx.make_node("Transpose", [node.input[0]], {"perm": constants.NHWC_TO_NCHW})
        if roi_required:
            roi = ctx.make_const(utils.make_name("roi"), np.array([]).astype(np.float32))
            upsample = ctx.make_node("Resize", [input_nchw.output[0], roi.output[0], scales.output[0]],
                                     attr={"mode": mode, "nearest_mode": "floor",
                                           "coordinate_transformation_mode": "asymmetric"})
        else:
            upsample = ctx.make_node(op_type, [input_nchw.output[0], scales.output[0]], attr={"mode": mode})

        shapes = node.output_shapes
        dtypes = node.output_dtypes
        ctx.remove_node(node.name)
        ctx.make_node("Transpose", upsample.output, {"perm": constants.NCHW_TO_NHWC},
                      name=node.name, outputs=node.output, shapes=shapes, dtypes=dtypes)
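
Note: an illustrative sketch (not the converter code) of the static "scales" computed in the const branch above; scales are ordered NCHW, so the batch and channel factors stay at 1.0:

import numpy as np

n, h, w, c = 1, 32, 48, 3                       # NHWC input shape
nh, nw = 64, 96                                 # requested output size
scales = np.array([1.0, 1.0, nh / h, nw / w], dtype=np.float32)   # [N, C, H, W] scale factors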
Example 25
def create_onnx_random_uniform_op(g, tmax, tmin, ru_op, output, to_delete):
    dtype = g.get_dtype(output.output[0])
    op_name = utils.make_name("RandomUniform")
    shape_node = ru_op.inputs[0]
    shape = g.get_shape(output.output[0])
    if shape_node.is_const():
        # if the tensorflow input (aka the shape) is const we can use the RandomUniform op
        needs_squeeze = False
        if len(shape) == 0:
            shape = [1]
            needs_squeeze = True
        new_node = g.make_node("RandomUniform", [], name=op_name,
                               attr={"low": tmin, "high": tmax, "dtype": dtype, "shape": shape},
                               shapes=[shape], dtypes=[dtype])
        if needs_squeeze:
            new_node = GraphBuilder(g).make_squeeze({"data": new_node.output[0], "axes": [0]}, return_node=True)
    else:
        if shape_node.type == "Shape":
            # if shape is dynamic - in tensorflow shape comes as tensor VALUE,
            # in onnx RandomUniformLike takes the shape from the tensor itself.
            # In many cases there is a shape op in tensorflow before RandomUniform and
            # to make that work for onnx we just need to remove the shape op.
            new_node = g.make_node("RandomUniformLike", inputs=[shape_node.input[0]], name=op_name,
                                   attr={"low": tmin, "high": tmax, "dtype": dtype},
                                   shapes=[shape], dtypes=[dtype])
        else:
            # if the shape is calculated we need to create a tensor so RandomUniformLike
            # can take the shape from there. Pre opset9 this is somewhat hacky because there is
            # no real fill op in onnx. In general this is not going to help performance but the tensors
            # created are expected to be small.

            # tell the caller to not delete the shape node
            to_delete.remove(shape_node)
            # create a fill op with the shape of the value of the input tensor
            zero = g.make_const(utils.make_name("zero"), np.zeros((), dtype=np.float32))
            fill_node = g.make_node("Fill", inputs=[shape_node.output[0], zero.name],
                                    shapes=[shape], dtypes=[dtype])
            func, _ = handler.tf_op.find_effective_op("Fill")
            func(g, fill_node)
            # and use RandomUniformLike to create the random tensor
            new_node = g.make_node("RandomUniformLike", inputs=[fill_node.output[0]], name=op_name,
                                   attr={"low": tmin, "high": tmax, "dtype": dtype},
                                   shapes=[shape], dtypes=[dtype])
    return new_node
Example 26
 def process_var_init_nodes(self, context):
     assert "state" in context.state_variables.keys()
     initializer_input_id = context.state_variables["state"].enter_input_id
     node = self.g.get_node_by_output(initializer_input_id)
     if node.is_const():
         val = node.get_tensor_value(as_list=False)
         initial_name = utils.make_name("Const")
         new_val = np.expand_dims(val, axis=0)
         const_node = self.g.make_const(initial_name, new_val)
         context.onnx_input_ids["initial_state"] = const_node.output[0]
         return
     squeeze_node = GraphBuilder(self.g).make_unsqueeze(
         {
             'data': initializer_input_id,
             'axes': [0]
         }, return_node=True)
     to_replace = [n for n in self.g.get_nodes() if n != squeeze_node]
     self.g.replace_all_inputs(initializer_input_id,
                               squeeze_node.output[0],
                               ops=to_replace)
     context.onnx_input_ids["initial_state"] = squeeze_node.output[0]
Example 27
 def _optimize_reduce(self, node, graph):
     if graph.get_dtype(
             node.output[0]) not in [TensorProto.FLOAT, TensorProto.DOUBLE]:
         return False
     if node.output[0] in graph.outputs:
         # Replacement is unsafe
         return False
     axes = node.get_attr_value('axes')
     inp_rank = graph.get_rank(node.input[0])
     if inp_rank is None:
         return False
     if axes != list(range(2, inp_rank)):
         return False
     op_map = {
         "ReduceMean": "GlobalAveragePool",
         "ReduceMax": "GlobalMaxPool"
     }
     node.type = op_map[node.type]
     del node.attr['axes']
     if not node.get_attr_value('keepdims', True):
         out_shapes = node.output_shapes
         out_dtypes = node.output_dtypes
         new_out_shape = graph.get_shape(
             node.input[0])[:2] + [1] * len(axes)
         graph.set_shape(node.output[0], new_out_shape)
         squeeze_node = GraphBuilder(graph).make_squeeze(
             {
                 'data': node.output[0],
                 'axes': axes
             },
             shapes=out_shapes,
             dtypes=out_dtypes,
             return_node=True,
             op_name_scope=node.name)
         graph.insert_node_on_output(squeeze_node, node.output[0])
     if 'keepdims' in node.attr:
         del node.attr['keepdims']
     return True
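
Note: a NumPy check (illustrative only) of the equivalence this optimizer relies on: reducing over all trailing axes [2, ..., rank) with keepdims is exactly a global pooling over an NCHW tensor:

import numpy as np

x = np.random.rand(2, 3, 8, 8)                                        # NCHW
reduce_mean = x.mean(axis=(2, 3), keepdims=True)                      # ReduceMean(axes=[2, 3], keepdims=1)
global_avg = x.reshape(2, 3, -1).mean(axis=-1).reshape(2, 3, 1, 1)    # GlobalAveragePool
assert np.allclose(reduce_mean, global_avg)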
Example 28
    def _optimize_reshape(self, node, graph):
        if node.inputs[1].is_const():
            return False
        inp_shape = graph.get_shape(node.input[0])
        if inp_shape is None:
            # The rank must be known
            return False
        feed_dict = {}
        for n in graph.find_output_consumers(node.input[0]):
            if n.type == "Shape":
                symbolic_shape = []
                for i, d in enumerate(inp_shape):
                    if d == -1:
                        # Make a variable representing each unknown dim
                        symbolic_shape.append(
                            SymbolicTensorElement.from_variable(i))
                    else:
                        symbolic_shape.append(
                            SymbolicTensorElement.from_const(d))
                feed_dict[n.output[0]] = np.array(symbolic_shape, np.object)
        try:
            symbolic_res = SymbolicExecutor(graph).compute_outputs(
                [node.input[1]], feed_dict)
        except SymbolicExecutionException:
            return False
        utils.make_sure(
            len(symbolic_res[0].shape) == 1, "Shape must have rank 1")
        symbolic_shape = symbolic_res[0].tolist()
        product_cnt = len(
            [val for val in symbolic_shape if val.has_multiple_terms()])
        idx_cnt = len([val for val in symbolic_shape if val.is_single_var()])
        if product_cnt > 1:
            # The -1 lets us handle at most one dim with multiple terms
            return False
        if idx_cnt + product_cnt <= 1:
            # Only 1 non-const dim. Use -1 and consts for the rest.
            new_shape = [
                v.constant if v.is_const() else -1 for v in symbolic_shape
            ]
            shift = 0
        else:
            # We will need to use some 0s. We can shift using squeeze/unsqueeze to line up equal dims
            def get_shift(val, i):
                if not val.is_single_var():
                    return None
                return val.terms[0] - i

            shifts = [
                get_shift(val, i) for i, val in enumerate(symbolic_shape)
            ]
            # Find the most popular shift
            most_common = Counter(s for s in shifts
                                  if s is not None).most_common(1)
            shift = most_common[0][0] if most_common else 0

            def get_reshape_dim(val, i, shift):
                if val.is_const():
                    return val.constant
                if get_shift(val, i) == shift:
                    return 0
                # Use -1 only as a last resort
                return -1

            new_shape = [
                get_reshape_dim(v, i, shift)
                for i, v in enumerate(symbolic_shape)
            ]
        if new_shape.count(-1) > 1:
            return False

        if shift > 0:
            new_shape = [1] * shift + new_shape
            squeeze_node = GraphBuilder(graph).make_squeeze(
                {
                    'data': node.output[0],
                    'axes': list(range(shift))
                },
                return_node=True,
                shapes=node.output_shapes,
                dtypes=node.output_dtypes)
            graph.insert_node_on_output(squeeze_node, node.output[0])
        const_shape = graph.make_const(utils.make_name(node.name + "_shape"),
                                       np.array(new_shape, np.int64)).output[0]
        graph.replace_inputs(node, [node.input[0], const_shape])
        if shift < 0:
            unsqueeze_node = GraphBuilder(graph).make_unsqueeze({
                'data': node.input[0],
                'axes': list(range(-shift))
            })
            graph.replace_inputs(node, [unsqueeze_node, const_shape])

        return True
Example 29
    def version_7(cls, ctx, node, **kwargs):
        # T output = MatrixBandPart(T input, int num_lower, int num_upper)
        # data-flow: first generate mask matrix and then use element-wise mul op
        input_rank = len(ctx.get_shape(node.input[0]))
        utils.make_sure(
            input_rank == 2,
            error_msg="MatrixBandPart op: only rank 2 is supported")
        bandpart = [node.inputs[ind].get_tensor_value() for ind in [1, 2]]
        utils.make_sure(bandpart in [[-1, 0], [0, -1]],
                        "only support Lower/Upper triangular for now")
        # method to generate the mask matrix: if the lower triangle is needed, generate it column by column,
        # otherwise generate it row by row.
        axis, counter_axis, squeeze_axis = (1, 0, 2) if bandpart == [-1, 0] else (0, 1, 1)
        # 1: subgraph to implement tf.ones_like(input[:, 0]),
        # no need to worry about the dtype, because bool type is needed as Xor only supports bool
        node_name = utils.make_name("const_zero")
        const_zero = ctx.make_const(name=node_name,
                                    np_val=np.array([0]).astype(np.int32))
        first_col_or_row = ctx.make_node(
            op_type="Gather",
            inputs=[node.input[0], const_zero.output[0]],
            attr={"axis": axis})
        first_col_or_row_casted = ctx.make_node(
            op_type="Cast",
            inputs=first_col_or_row.output,
            attr={"to": onnx_pb.TensorProto.BOOL})
        # line means one col or one row
        zero_line = ctx.make_node(op_type="Xor",
                                  inputs=first_col_or_row_casted.output * 2)
        one_line = ctx.make_node(op_type="Not", inputs=zero_line.output)

        # 2: "loop" to generate mask matrix: generate col or row of matrix one by one
        g = ctx.create_new_graph_with_same_config()
        node_name = utils.make_name("const_zero_bool")
        const_zero_bool = ctx.make_const(name=node_name,
                                         np_val=np.array([[0]]).astype(np.bool))
        ctx.set_dtype(const_zero_bool.output[0], onnx_pb.TensorProto.BOOL)

        # shift right the line and add zero at the left.
        new_line = g.make_node(op_type="Concat",
                               inputs=[const_zero_bool.output[0], "line"],
                               attr={"axis": counter_axis},
                               dtypes=[onnx_pb.TensorProto.BOOL])
        attr = {"axes": [counter_axis], "starts": [0], "ends": [-1]}
        inputs_map = {"data": new_line.output[0], **attr}
        slice_node = GraphBuilder(g).make_slice(inputs_map)

        g.make_node("Identity", ["cond"], outputs=["cond_out"])
        g.make_node("Identity", ["line"], outputs=["res"])
        g.make_node("Identity", [slice_node], outputs=["line_out"])

        g.add_graph_input("trip", onnx_pb.TensorProto.INT64, [])
        g.add_graph_input("cond", onnx_pb.TensorProto.BOOL, [])
        g.add_graph_input("line", onnx_pb.TensorProto.BOOL, [-1, -1])

        g.add_graph_output("cond_out", onnx_pb.TensorProto.BOOL, [])
        g.add_graph_output("line_out", onnx_pb.TensorProto.BOOL, [-1, -1])
        g.add_graph_output("res", onnx_pb.TensorProto.BOOL, [-1, -1])

        # initial value of body vars
        shape = ctx.make_node(op_type="Shape",
                              inputs=[node.input[0]
                                      ])  # dtype of result is int64
        node_name = utils.make_name("line_num_index")
        col_or_row_num_index = ctx.make_const(name=node_name,
                                              np_val=np.array(axis).astype(
                                                  np.int32))
        line_num = ctx.make_node(
            op_type="Gather",
            inputs=[shape.output[0], col_or_row_num_index.output[0]])
        trip_cnt = line_num.output[0]
        node_name = utils.make_name("true")
        cond = ctx.make_const(name=node_name,
                              np_val=np.array(1).astype(np.bool))
        col_init = one_line.output[0]

        loop_node = ctx.make_node(op_type="Loop",
                                  inputs=[trip_cnt, cond.output[0], col_init],
                                  output_count=2)
        loop_node.set_body_graph_as_attr("body", g)
        # convert generated mask matrix from bool to right shape and data type
        squeeze = ctx.make_node(op_type="Squeeze",
                                inputs=[loop_node.output[1]],
                                attr={"axes": [squeeze_axis]})
        cast1 = ctx.make_node(op_type="Cast",
                              inputs=squeeze.output,
                              attr={"to": onnx_pb.TensorProto.FLOAT})
        if axis == 1:
            mask_matrix = ctx.make_node(op_type="Transpose",
                                        inputs=cast1.output)
        else:
            mask_matrix = squeeze
        cast2 = ctx.make_node(op_type="Cast",
                              inputs=mask_matrix.output,
                              attr={"to": ctx.get_dtype(node.input[0])})
        shapes = node.output_shapes
        dtypes = node.output_dtypes
        ctx.remove_node(node.name)
        ctx.make_node(op_type="Mul",
                      inputs=[cast2.output[0], node.input[0]],
                      name=node.name,
                      outputs=node.output,
                      shapes=shapes,
                      dtypes=dtypes)
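
Note: a NumPy sketch (illustrative only) of the mask generation idea used by the Loop body above, shown for the [0, -1] (upper-triangular) band: start from a line of ones and shift it right by one position per iteration:

import numpy as np

n = 4
line = np.ones(n, dtype=bool)
rows = []
for _ in range(n):
    rows.append(line)
    line = np.concatenate([[False], line])[:-1]   # prepend a zero, drop the last element
mask = np.stack(rows).astype(int)                 # upper-triangular mask of ones
# for the [-1, 0] band the loop builds columns instead and the result is transposed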
Example 30
    def version_9(cls, ctx, node, **kwargs):
        node_inputs = node.input
        num_segments_specified = False
        if node.type.endswith("WithNumSegments") or node.type.startswith(
                "Unsorted"):
            num_segments_specified = True
            num_segments = node_inputs.pop()
            node.type = node.type.replace("WithNumSegments", "")
            node.type = node.type.replace("Unsorted", "")
        if node.type.startswith("Sparse"):
            data_inp, indices_inp, segment_inp = node_inputs
            gather_node = ctx.make_node("Gather", [data_inp, indices_inp],
                                        attr={'axis': 0})
            data_inp = gather_node.output[0]
            node.type = node.type.replace("Sparse", "")
        else:
            data_inp, segment_inp = node_inputs

        # Data has shape [n, a, b, ..., c]
        data_shape = ctx.get_shape(data_inp)
        data_rank = len(data_shape) if data_shape is not None else None
        data_dtype = ctx.get_dtype(data_inp)
        data_np_dtype = utils.map_onnx_to_numpy_type(data_dtype)
        seg_np_dtype = utils.map_onnx_to_numpy_type(ctx.get_dtype(segment_inp))

        if num_segments_specified and ctx.get_dtype(
                segment_inp) != ctx.get_dtype(num_segments):
            num_segments = ctx.make_node("Cast", [num_segments],
                                         attr={
                                             "to": ctx.get_dtype(segment_inp)
                                         }).output[0]

        data_is_float = np.dtype(data_np_dtype).kind == 'f'
        data_is_int = np.dtype(data_np_dtype).kind == 'i'
        utils.make_sure(data_is_float or data_is_int,
                        "dtype for Segment ops must be float or int")

        if node.type in ["SegmentSum", "SegmentMean", "SegmentSqrtN"]:
            onnx_op = "ReduceSum"
            identity_value = np.array(0, dtype=data_np_dtype)
        elif node.type == "SegmentProd":
            onnx_op = "ReduceProd"
            identity_value = np.array(1, dtype=data_np_dtype)
        elif node.type == "SegmentMax":
            onnx_op = "ReduceMax"
            if data_is_float:
                identity_value = np.array('-inf', dtype=data_np_dtype)
            else:
                identity_value = np.iinfo(data_np_dtype).min
        elif node.type == "SegmentMin":
            onnx_op = "ReduceMin"
            if data_is_float:
                identity_value = np.array('inf', dtype=data_np_dtype)
            else:
                identity_value = np.iinfo(data_np_dtype).max

        if not num_segments_specified:
            max_segment = ctx.make_node("ReduceMax", [segment_inp],
                                        attr={
                                            'axes': [0],
                                            'keepdims': 0
                                        })
            one_const = ctx.make_const(utils.make_name("const_one"),
                                       np.array(1, dtype=seg_np_dtype))
            num_segments = ctx.make_node(
                "Add", [max_segment.output[0], one_const.output[0]]).output[0]
        # ORT doesn't support bool for OneHot so we use float32 and cast to bool
        onehot_values = ctx.make_const(utils.make_name("onehot_values"),
                                       np.array([0, 1], dtype=np.float32))
        # one_hot_node has shape [s, n] (s is # segments)
        one_hot_node = ctx.make_node(
            "OneHot", [segment_inp, num_segments, onehot_values.output[0]],
            attr={'axis': 0})
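        # SegmentMean / SegmentSqrtN also need the per-segment element counts
        # so the summed values can be scaled afterwards.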
        if node.type == "SegmentMean":
            scaling_node_output = GraphBuilder(ctx).make_reduce_sum({
                "data": one_hot_node.output[0],
                "axes": [1],
                "keepdims": 0,
                "noop_with_empty_axes": 1
            })
        elif node.type == "SegmentSqrtN":
            seg_cnts_node_output = GraphBuilder(ctx).make_reduce_sum({
                "data": one_hot_node.output[0],
                "axes": [1],
                "keepdims": 0,
                "noop_with_empty_axes": 1
            })
            scaling_node_output = ctx.make_node(
                "Sqrt", [seg_cnts_node_output]).output[0]
        else:
            scaling_node_output = None

        if scaling_node_output is not None and num_segments_specified:
            # If empty segments are possible, we must avoid division by zero
            const_one_float = ctx.make_const(
                utils.make_name("const_one_float"),
                np.array(1, dtype=np.float32))
            scaling_node_output = ctx.make_node(
                "Max",
                [scaling_node_output, const_one_float.output[0]]).output[0]

        if onnx_op == "ReduceSum":
            # If the op is a summation, we can use MatMul instead of Where, which is faster

            # Data shape is [n, a, b, ..., c]
            data_shape_node = ctx.make_node("Shape", [data_inp])
            new_shape = ctx.make_const(utils.make_name("reshape_const"),
                                       np.array([0, -1], dtype=np.int64))
            # Reshape the data from [n, a, b, ..., c] to [n, P]
            data_reshape = ctx.make_node("Reshape",
                                         [data_inp, new_shape.output[0]])

            one_hot_cast = one_hot_node
            if data_dtype != onnx_pb.TensorProto.FLOAT:
                one_hot_cast = ctx.make_node("Cast", [one_hot_node.output[0]],
                                             attr={'to': data_dtype})

            # Shapes [s, n] * [n, P] => [s, P]
            product = ctx.make_node(
                "MatMul", [one_hot_cast.output[0], data_reshape.output[0]],
                op_name_scope=node.name)
            if scaling_node_output is not None:
                scaling_node_unsqueeze = ctx.make_node("Unsqueeze",
                                                       [scaling_node_output],
                                                       attr={'axes': [1]})
                product = ctx.make_node(
                    "Div",
                    [product.output[0], scaling_node_unsqueeze.output[0]])

            # Create new shape [0, a, b, ..., c]
            max_int64 = int(utils.get_max_value(np.int64))
            new_shape_slice = GraphBuilder(ctx).make_slice({
                "data": data_shape_node.output[0],
                "starts": [1],
                "ends": [max_int64],
                "axes": [0]
            })
            zero_const = ctx.make_const(utils.make_name("zero_const"),
                                        np.array([0], dtype=np.int64))
            new_shape = ctx.make_node("Concat",
                                      [zero_const.output[0], new_shape_slice],
                                      attr={'axis': 0})

            shapes = node.output_shapes
            dtypes = node.output_dtypes
            ctx.remove_node(node.name)
            # Reshape result from [s, P] to [s, a, b, ..., c]
            ctx.make_node("Reshape", [product.output[0], new_shape.output[0]],
                          name=node.name,
                          outputs=node.output,
                          shapes=shapes,
                          dtypes=dtypes)
            return

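        # General path (Prod/Max/Min): broadcast a boolean one-hot mask of shape
        # [s, n, 1, ..., 1] against the data, fill non-member positions with the
        # identity value, then reduce over axis 1.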
        identity_const = ctx.make_const(utils.make_name("const_identity"),
                                        identity_value)
        one_hot_bool = ctx.make_node("Cast", [one_hot_node.output[0]],
                                     attr={"to": onnx_pb.TensorProto.BOOL})
        one_hot_unsqueeze = one_hot_bool

        # Make one_hot_unsqueeze have shape [s, n, 1, 1, ..., 1]
        if data_rank is None:
            # Unsqueeze requires known rank, but we can use Reshape if rank is unknown
            shape_node = ctx.make_node("Shape", [data_inp])
            rank_node = ctx.make_node("Shape", [shape_node.output[0]])
            one_const_int64 = ctx.make_const(utils.make_name("const_one"),
                                             np.array([1], dtype=np.int64))
            num_unsqueeze_dims = ctx.make_node(
                "Sub", [rank_node.output[0], one_const_int64.output[0]])

            one_tensor = helper.make_tensor("value",
                                            onnx_pb.TensorProto.INT64,
                                            dims=[1],
                                            vals=[1])
            unsqueeze_dims = ctx.make_node(
                "ConstantOfShape",
                inputs=[num_unsqueeze_dims.output[0]],
                attr={"value": one_tensor})
            # Zero indicates a dimension should be unchanged
            double_zero_const = ctx.make_const(
                utils.make_name("double_zero"),
                np.array([0, 0], dtype=np.int64))
            expanded_shape = ctx.make_node(
                "Concat",
                [double_zero_const.output[0], unsqueeze_dims.output[0]],
                attr={'axis': 0})
            one_hot_unsqueeze = ctx.make_node(
                "Reshape", [one_hot_bool.output[0], expanded_shape.output[0]])
        elif data_rank > 1:
            new_dims = list(range(2, 2 + data_rank - 1))
            one_hot_unsqueeze = ctx.make_node("Unsqueeze",
                                              [one_hot_bool.output[0]],
                                              attr={'axes': new_dims})

        # Shape of data:       [n, a, b, ..., c]
        # Shape of one_hot: [s, n, 1, 1, ..., 1]
        # Broadcast left-pads shape with 1s, so result is shape: [s, n, a, b, ..., c]
        where_node = ctx.make_node(
            "Where",
            [one_hot_unsqueeze.output[0], data_inp, identity_const.output[0]])

        shapes = node.output_shapes
        dtypes = node.output_dtypes
        ctx.remove_node(node.name)
        # After reduction over axis 1, shape is: [s, a, b, ..., c]
        ctx.make_node(onnx_op, [where_node.output[0]],
                      attr={'axes': [1], 'keepdims': 0},
                      name=node.name,
                      outputs=node.output,
                      shapes=shapes,
                      dtypes=dtypes)
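
For reference, below is a minimal NumPy sketch (not part of tf2onnx; the function names are illustrative) of the two decompositions the handler above emits: the one-hot MatMul trick used for the sum-based ops and the one-hot Where-plus-Reduce trick used for Prod/Max/Min.

import numpy as np

def segment_sum_reference(data, segment_ids, num_segments):
    # MatMul path: one_hot [s, n] @ data flattened to [n, P], then reshape back to [s, a, ..., c].
    one_hot = np.eye(num_segments, dtype=data.dtype)[segment_ids].T
    flat = data.reshape(data.shape[0], -1)
    return (one_hot @ flat).reshape((num_segments,) + data.shape[1:])

def segment_max_reference(data, segment_ids, num_segments):
    # Where/Reduce path: a mask of shape [s, n, 1, ..., 1] keeps segment members,
    # every other position becomes the identity value (-inf for max), then reduce over axis 1.
    one_hot = np.eye(num_segments, dtype=bool)[segment_ids].T
    mask = one_hot.reshape(one_hot.shape + (1,) * (data.ndim - 1))
    filled = np.where(mask, data, np.array(-np.inf, dtype=data.dtype))
    return filled.max(axis=1)

data = np.array([[1., 2.], [3., 4.], [5., 0.]])
seg = np.array([0, 0, 1])
print(segment_sum_reference(data, seg, 2))  # [[4. 6.] [5. 0.]]
print(segment_max_reference(data, seg, 2))  # [[3. 4.] [5. 0.]]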