def slice_birnn_for_original_rnn_consumers(g, rnn_fw, rnn_bw, bi_rnn, rnn_output_index, all_nodes, to_remove):
    fw_consumers = g.find_output_consumers(rnn_fw.output[rnn_output_index])
    bw_consumers = g.find_output_consumers(rnn_bw.output[rnn_output_index])
    if not fw_consumers and not bw_consumers:
        return

    if rnn_output_index == 0:
        axis = 1
        # remove reverse op for rnn_bw
        reverse_nodes = get_reverse_nodes_after_y_output(g, rnn_bw)

        for r_op in reverse_nodes:
            logger.debug("remove reverse op %s", r_op.name)
            g.replace_all_inputs(all_nodes, r_op.output[0], r_op.input[0])
            to_remove.append(r_op.name)
    elif rnn_output_index in [1, 2]:
        axis = 0
    else:
        raise ValueError("rnn should only have 3 outputs.")

    if fw_consumers:
        attr = {"axes": [axis], "starts": [0], "ends": [1]}
        inputs_map = {"data": bi_rnn.output[rnn_output_index], **attr}
        slice_node_fw = GraphBuilder(g).make_slice(inputs_map)
        all_nodes.append(g.get_node_by_output(slice_node_fw))
        g.replace_all_inputs(fw_consumers, rnn_fw.output[rnn_output_index], slice_node_fw)

    if bw_consumers:
        attr = {"axes": [axis], "starts": [1], "ends": [2]}
        inputs_map = {"data": bi_rnn.output[rnn_output_index], **attr}
        slice_node_bw = GraphBuilder(g).make_slice(inputs_map)
        all_nodes.append(g.get_node_by_output(slice_node_bw))
        g.replace_all_inputs(bw_consumers, rnn_bw.output[rnn_output_index], slice_node_bw)
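# Illustrative sketch (not part of the converter): in plain numpy, what the inserted
# Slice nodes compute on a bidirectional Y output of shape
# [seq_length, num_directions, batch, hidden]; direction 0 is forward, 1 is backward.
import numpy as np

Y = np.random.rand(5, 2, 3, 4).astype(np.float32)
fw = Y[:, 0:1, :, :]  # Slice(axes=[1], starts=[0], ends=[1])
bw = Y[:, 1:2, :, :]  # Slice(axes=[1], starts=[1], ends=[2])
# For Y_h / Y_c, shaped [num_directions, batch, hidden], the same slice runs on axis 0.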
def create_rnn_node(self, context):
    gb = GraphBuilder(self.g)
    rnn_nodes = list()
    outputs = context.loop_properties.scan_outputs_exits
    logger.debug("number of rnn node outputs: %s", len(outputs))

    for i in range(self.num_lstm_layers):
        logger.debug("creating rnn node for layer: %s", i)
        rnn_nodes.append(self.create_single_rnn_node(context, i))
        output_id = rnn_nodes[i].output[0]
        rnn_output_shape = self.g.get_shape(output_id)
        squeeze_output_shape = [rnn_output_shape[0], rnn_output_shape[2], rnn_output_shape[3]]
        squeeze_node = gb.make_squeeze({"data": output_id, "axes": [1]},
                                       shapes=[squeeze_output_shape],
                                       dtypes=[self.g.get_dtype(output_id)],
                                       return_node=True)
        if i + 1 < self.num_lstm_layers:
            logger.debug("setting input for layer: %s", i + 1)
            context.onnx_input_ids[i + 1]["X"] = squeeze_node.output[0]
    return rnn_nodes
def process_seq_length(self, context):
    # output: [time step, batch size, input size]
    seq_len_node = context.seq_len_node
    shape_node = self.g.make_node("Shape", [context.onnx_input_ids["X"]])
    # LSTMCell only allows inputs of [batch size, input size], so we assume dynamic_rnn has 3 dims.
    # Slice cannot support int64 in opset 7, so we cast here.
    cast_shape_node = self.g.make_node(
        "Cast", [shape_node.output[0]],
        attr={"to": TensorProto.FLOAT},
        shapes=[self.g.get_shape(shape_node.output[0])])

    attr = {"axes": [0], "starts": [1], "ends": [2]}
    inputs_map = {"data": cast_shape_node.output[0], **attr}
    batchsize_node = GraphBuilder(self.g).make_slice(inputs_map)
    if not seq_len_node:
        # Tile's repeats must be int64
        repeat_node = self.g.make_node("Cast", [batchsize_node], attr={"to": TensorProto.INT64})

        attr = {"axes": [0], "starts": [0], "ends": [1]}
        inputs_map = {"data": cast_shape_node.output[0], **attr}
        timestep_node = GraphBuilder(self.g).make_slice(inputs_map)
        tile_node = self.g.make_node("Tile", [timestep_node, repeat_node.output[0]])

        # LSTM sequence_lens needs to be int32
        seq_len_node = self.g.make_node("Cast", [tile_node.output[0]], attr={"to": TensorProto.INT32})
    context.onnx_input_ids["sequence_lens"] = seq_len_node.output[0]
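# Illustrative sketch (not part of the converter): building sequence_lens in numpy
# for a time-major X of shape [time, batch, input]. The float round-trip in the
# handler above exists only because opset-7 Slice cannot consume int64.
import numpy as np

x_shape = np.array([7, 3, 16], dtype=np.int64)               # Shape(X)
time_steps = x_shape[0:1]                                    # Slice(starts=[0], ends=[1])
batch_size = x_shape[1:2]                                    # Slice(starts=[1], ends=[2])
seq_lens = np.tile(time_steps, batch_size).astype(np.int32)  # Tile, then Cast to int32
# seq_lens == [7, 7, 7]: every batch entry runs for the full number of time steps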
def version_11(cls, ctx, node, **kwargs):
    # create a loop of Resize ops to cater to tensorflow CropAndResize: one box per iteration
    mode = "nearest" if node.get_attr("method") is not None and node.get_attr("method").s == b"nearest" \
        else "linear"
    extrapolation_value = float(node.get_attr("extrapolation_value", "0").f)
    input_x = node.inputs[0]
    boxes = node.inputs[1]
    box_ind = node.inputs[2]
    crop_size = node.inputs[3]

    trip_name = utils.make_name(node.name + "_i")
    cond_name = utils.make_name(node.name + "_cond")
    cond_out_name = utils.make_name(node.name + "_cond_out")

    g = ctx.create_new_graph_with_same_config()
    g.add_graph_input(trip_name, TensorProto.INT64, [1])
    g.add_graph_input(cond_name, TensorProto.BOOL, [])
    g.parent_graph = ctx

    const_zero = g.make_const(utils.make_name(node.name + "_const_zero"), np.array([0], dtype=np.int32))
    const_zero_long = g.make_const(utils.make_name(node.name + "_const_zero_long"),
                                   np.array([0], dtype=np.int64))
    const_one = g.make_const(utils.make_name(node.name + "_const_one"), np.array([1], dtype=np.int32))
    const_one_long = g.make_const(utils.make_name(node.name + "_const_one_long"),
                                  np.array([1], dtype=np.int64))

    index_end = g.make_node("Add", [trip_name, const_one_long.output[0]])
    box_index_from = g.make_node("Slice", [box_ind.output[0], trip_name, index_end.output[0]],
                                 name="Slice_a")
    box_index_to = g.make_node("Add", [box_index_from.output[0], const_one.output[0]])
    target_x = g.make_node("Slice", [input_x.output[0], box_index_from.output[0],
                                     box_index_to.output[0], const_zero.output[0]], name="Slice_b")
    transposed_x = g.make_node("Transpose", [target_x.output[0]], attr={'perm': constants.NHWC_TO_NCHW})
    shape_of_transposed_x = g.make_node("Shape", [transposed_x.output[0]])

    const_zero_zero = g.make_const(utils.make_name(node.name + "_const_zero_zero"),
                                   np.array([0, 0], dtype=np.float32))
    const_one_one = g.make_const(utils.make_name(node.name + "_const_one_one"),
                                 np.array([1, 1], dtype=np.float32))
    const_four = g.make_const(utils.make_name(node.name + "_const_four"), np.array([4], dtype=np.int64))
    const_empty_float = g.make_const(utils.make_name("const_empty_float"), np.array([], dtype=np.float32))

    first_half_of_shape = GraphBuilder(g).make_slice({"data": shape_of_transposed_x.output[0],
                                                      "ends": [2], "starts": [0]})
    box = g.make_node("Slice", [boxes.output[0], trip_name, index_end.output[0],
                                const_zero_long.output[0]], name="Slice_c")
    roi_raw = g.make_node("Reshape", [box.output[0], const_four.output[0]])
    roi_raw_first_half = GraphBuilder(g).make_slice({"data": roi_raw.output[0], "ends": [2], "starts": [0]})
    roi_raw_second_half = GraphBuilder(g).make_slice({"data": roi_raw.output[0], "ends": [4], "starts": [2]})
    roi_concat_1 = g.make_node("Concat", [const_zero_zero.output[0], roi_raw_first_half], attr={'axis': 0})
    roi_concat_2 = g.make_node("Concat", [const_one_one.output[0], roi_raw_second_half], attr={'axis': 0})
    final_roi = g.make_node("Concat", [roi_concat_1.output[0], roi_concat_2.output[0]], attr={'axis': 0})

    crop_size_int64 = g.make_node("Cast", [crop_size.output[0]], attr={'to': TensorProto.INT64})
    final_crop_size = g.make_node("Concat", [first_half_of_shape, crop_size_int64.output[0]],
                                  attr={'axis': 0})
    resized_x = g.make_node("Resize", [transposed_x.output[0], final_roi.output[0],
                                       const_empty_float.output[0], final_crop_size.output[0]],
                            attr={"mode": mode, "extrapolation_value": extrapolation_value,
                                  "coordinate_transformation_mode": "tf_crop_and_resize"})
    recovered_x = g.make_node("Transpose", [resized_x.output[0]], attr={'perm': constants.NCHW_TO_NHWC})
    squeeze_x = g.make_node("Squeeze", inputs=[recovered_x.output[0]], attr={"axes": [0]})
    g.make_node("Identity", [cond_name], outputs=[cond_out_name])

    g.add_graph_output(cond_out_name, TensorProto.BOOL, [])
    g.add_graph_output(squeeze_x.output[0], ctx.get_dtype(node.input[0]), [-1, -1, -1])

    trip_node = ctx.make_node("Size", [box_ind.output[0]])
    cond_const = ctx.make_const(utils.make_name("cond"), np.ones((), dtype=np.bool))
    ctx.remove_node(node.name)
    inner_loop = ctx.make_node("Loop", [trip_node.output[0], cond_const.output[0]],
                               name=node.name, outputs=node.output)
    inner_loop.set_body_graph_as_attr("body", g)
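# Illustrative sketch (not part of the converter): the per-box ROI layout fed to Resize
# with coordinate_transformation_mode="tf_crop_and_resize". The ROI lists normalized
# starts then ends per NCHW axis; batch and channel are kept whole.
import numpy as np

box = np.array([0.1, 0.2, 0.8, 0.9], dtype=np.float32)  # tf box: [y1, x1, y2, x2]
final_roi = np.concatenate([[0.0, 0.0], box[:2], [1.0, 1.0], box[2:]])
# -> [0, 0, 0.1, 0.2, 1, 1, 0.8, 0.9]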
def _connect_lstm_ych_to_graph(self, context, i):
    # in tf, the concat of y_c and y_h output has shape: [batch, hidden * 2]
    # in onnx, y_c/y_h output shape is: [num_directions, batch, hidden]
    gb = GraphBuilder(self.g)
    exit_output = context.state_variables["ct_ht" + str(i)].exit_output
    lstm_node = context.rnn_node[i]
    yc_shape = self.g.get_shape(lstm_node.output[2])
    concat_output_shape = [yc_shape[0], yc_shape[1], yc_shape[2] * 2]
    concat = self.g.make_node("Concat", [lstm_node.output[2], lstm_node.output[1]],
                              attr={"axis": 2},
                              shapes=[concat_output_shape],
                              dtypes=[self.g.get_dtype(lstm_node.output[2])])

    squeeze_output_shape = [concat_output_shape[1], concat_output_shape[2]]
    squeeze_node = gb.make_squeeze({'data': concat.output[0], "axes": [0]},
                                   shapes=[squeeze_output_shape],
                                   dtypes=[self.g.get_dtype(concat.output[0])],
                                   return_node=True)
    self.g.replace_all_inputs(exit_output.id, squeeze_node.output[0])  # ops=self.g.get_nodes()
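# Illustrative sketch (not part of the converter): how the Concat + Squeeze pair
# rebuilds tf's packed [batch, hidden*2] state from the ONNX Y_c/Y_h outputs in numpy.
import numpy as np

num_directions, batch, hidden = 1, 3, 4
y_c = np.random.rand(num_directions, batch, hidden).astype(np.float32)  # lstm output[2]
y_h = np.random.rand(num_directions, batch, hidden).astype(np.float32)  # lstm output[1]
ct_ht = np.squeeze(np.concatenate([y_c, y_h], axis=2), axis=0)          # [batch, hidden*2]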
def connect_unit_rnn_output_to_graph(self, context):
    outputs = context.loop_properties.scan_outputs_exits
    if not outputs:
        logger.debug("no consumer for rnn output")
        return

    gather_output_id = outputs[0].id
    logger.debug("found output for rnn: %s", gather_output_id)

    # in tf batch-major mode, output shape is: [batch, time, hidden]
    # in time-major mode, output shape is: [time, batch, hidden]
    # in onnx, output shape is: [time, num_directions, batch, hidden]
    rnn_node = context.rnn_node
    output_id = rnn_node.output[0]
    rnn_output_shape = self.g.get_shape(output_id)
    squeeze_output_shape = [rnn_output_shape[0], rnn_output_shape[2], rnn_output_shape[3]]
    gb = GraphBuilder(self.g)
    squeeze_node = gb.make_squeeze({'data': output_id, "axes": [1]},
                                   shapes=[squeeze_output_shape],
                                   dtypes=[self.g.get_dtype(output_id)],
                                   return_node=True)
    self.g.replace_all_inputs(gather_output_id, squeeze_node.output[0])  # ops=self.g.get_nodes()
def version_13(cls, ctx, node, **kwargs):
    ctx.ta_reads.append(node.input[0])
    node.type = "Gather"
    ctx.replace_inputs(node, [node.input[0], node.input[1]])
    g = GraphBuilder(ctx)

    usq_node = g.make_unsqueeze({"axes": [0], 'name': node.child_name(), 'data': node.input[1]},
                                return_node=True)
    ctx.insert_node_on_output(usq_node)

    sq_node = g.make_squeeze({"axes": [0], 'name': node.child_name(), 'data': node.output[0]},
                             return_node=True)
    ctx.insert_node_on_output(sq_node)
def _process_non_tuple_ch_init_nodes(self, context, i):
    input_id = context.state_variables["ct_ht" + str(i)].enter_input_id
    hidden_size = context.hidden_size[i]

    attr = {"axes": [1], "starts": [0], "ends": [hidden_size]}
    inputs_map = {"data": input_id, **attr}
    slice_node1 = GraphBuilder(self.g).make_slice(inputs_map)
    unsqueeze_node_1 = self.g.make_node("Unsqueeze", [slice_node1], attr={"axes": [0]})

    attr = {"axes": [1], "starts": [hidden_size], "ends": [hidden_size * 2]}
    inputs_map = {"data": input_id, **attr}
    slice_node2 = GraphBuilder(self.g).make_slice(inputs_map)
    unsqueeze_node_2 = self.g.make_node("Unsqueeze", [slice_node2], attr={"axes": [0]})

    return unsqueeze_node_1.output[0], unsqueeze_node_2.output[0]
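# Illustrative sketch (not part of the converter): splitting tf's non-tuple state
# [batch, hidden*2] into ONNX initial_c / initial_h, each [num_directions, batch, hidden].
import numpy as np

batch, hidden_size = 3, 4
init_state = np.random.rand(batch, hidden_size * 2).astype(np.float32)
initial_c = init_state[:, :hidden_size][np.newaxis]                # Slice + Unsqueeze(axes=[0])
initial_h = init_state[:, hidden_size:hidden_size * 2][np.newaxis]
assert initial_c.shape == (1, batch, hidden_size)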
def version_9(cls, ctx, node, **kwargs):
    # float32/64 output = SparseSoftmaxCrossEntropyWithLogits(float32/64 features, int32/64 labels)
    # the math of this op is: a = onehot(labels), b = logsoftmax(features), reduce_sum(mul(a, b))
    logit_node = node.inputs[0]
    logit_shape = ctx.get_shape(node.input[0])
    logit_dtype = ctx.get_dtype(node.input[0])
    label_name = node.input[1]

    if logit_shape is not None and logit_shape[-1] != -1:
        num_class = logit_shape[-1]
        node_nme = utils.make_name("onehot_depth")
        depth_node = ctx.make_const(node_nme, np.array([num_class]).astype(np.int64)).output[0]
    else:
        logit_shape = ctx.make_node("Shape", [node.input[0]]).output[0]
        slice_args = {"data": logit_shape, "starts": [-1], "ends": [int(utils.get_max_value(np.int32))]}
        num_class = GraphBuilder(ctx).make_slice(kwargs=slice_args)
        depth_node = num_class
    values_node = ctx.make_const(utils.make_name("onehot_values"),
                                 np.array([0, 1]).astype(np.int64)).output[0]
    label_dtype = ctx.get_dtype(label_name)
    if label_dtype != TensorProto.INT64:
        onehot_indice = ctx.make_node("Cast", [label_name], attr={"to": TensorProto.INT64}).output[0]
    else:
        onehot_indice = label_name
    label_node = ctx.make_node(op_type="OneHot", inputs=[onehot_indice, depth_node, values_node])
    # the above logic makes the output dtype of label_node always int64;
    # make sure label has the same dtype as logit
    if logit_dtype != TensorProto.INT64:
        label_node = ctx.make_node("Cast", label_node.output, attr={"to": logit_dtype},
                                   dtypes=[logit_dtype])

    _make_sparse_softmax_cross_entropy_with_logits(ctx, label_node, logit_node, node)
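# Illustrative sketch (not part of the converter): the decomposition named in the
# comment above, in numpy. The final sign and reduction are assumed to live in
# _make_sparse_softmax_cross_entropy_with_logits, which is not shown here.
import numpy as np

features = np.random.rand(3, 5).astype(np.float32)  # [batch, num_class]
labels = np.array([0, 2, 4], dtype=np.int64)
log_softmax = features - np.log(np.sum(np.exp(features), axis=-1, keepdims=True))
onehot = np.eye(5, dtype=np.float32)[labels]
loss = -np.sum(onehot * log_softmax, axis=-1)       # per-example cross entropy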
def version_1(cls, ctx, node, **kwargs):
    # in tf-2.0 grappler optimizes the graph pretty well and our matching logic
    # in the rewriter does not trigger. grappler will send the random uniform
    # with shape as input, so we need to pick up the input here and, if the shape
    # is const, make it an attribute.
    seed = node.get_attr("seed")
    node.set_attr("seed", float(seed.f))
    utils.make_sure(node.inputs[0].is_const(), "%s node with non-const shape requires opset >= 9",
                    node.type)
    shape = node.inputs[0].get_tensor_value()
    ctx.remove_input(node, node.input[0], 0)
    if len(shape) == 0:
        # ORT can't take an empty shape (scalar)
        node.set_attr("shape", [1])
        ctx.set_shape(node.output[0], [1])
        squeeze_node = GraphBuilder(ctx).make_squeeze({'data': node.output[0], 'axes': [0]},
                                                      return_node=True)
        ctx.insert_node_on_output(squeeze_node, node.output[0])
        rand_out = squeeze_node.output[0]
    else:
        node.set_attr("shape", shape)
        ctx.set_shape(node.output[0], shape)
        rand_out = node.output[0]
    if node.type == "RandomUniformInt":
        cls.randuniform_int(ctx, node, rand_out, node.input[0], node.input[1])
        node.type = "RandomUniform"
        ctx.replace_inputs(node, [])
    elif node.type == "RandomStandardNormal":
        node.type = "RandomNormal"
def version_10(cls, ctx, node, **kwargs):
    inp_shape = ctx.make_node("Shape", [node.input[0]]).output[0]
    dim_0 = GraphBuilder(ctx).make_slice({'data': inp_shape, 'starts': [0], 'ends': [1], 'axes': [0]})
    zeros = ctx.make_node("ConstantOfShape", [dim_0], shapes=[[-1]]).output[0]
    seed = node.get_attr_value("seed", 0)
    seed2 = node.get_attr_value("seed2", 0)
    onnx_seed = utils.combine_seeds(seed, seed2)
    rand_attr = {'dtype': onnx_pb.TensorProto.FLOAT}
    if onnx_seed is not None:
        rand_attr['seed'] = onnx_seed
    random_floats = ctx.make_node("RandomUniformLike", [zeros], op_name_scope=node.name,
                                  shapes=[[-1]], attr=rand_attr).output[0]
    # Use indices of the TopK to get a random ordering
    _, random_ordering = ctx.make_node("TopK", [random_floats, dim_0], output_count=2,
                                       attr={'axis': -1}).output
    shuffled_res = ctx.make_node("Gather", [node.input[0], random_ordering]).output[0]
    ctx.replace_all_inputs(node.output[0], shuffled_res)
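# Illustrative sketch (not part of the converter): RandomShuffle emulated by taking
# the TopK indices of one random key per row, i.e. an argsort of random values.
import numpy as np

x = np.arange(5)
keys = np.random.rand(5).astype(np.float32)  # RandomUniformLike over a zeros tensor
random_ordering = np.argsort(-keys)          # TopK(k=dim_0) yields indices of the largest keys
shuffled = x[random_ordering]                # Gather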
def any_version(cls, opset, ctx, node, **kwargs):
    node.domain = constants.CONTRIB_OPS_DOMAIN
    separator = node.get_attr_value("separator")
    if separator is None:
        separator = b''
    separator = separator.decode('UTF-8')
    separator_node = ctx.make_const(utils.make_name("separator"), np.array([separator], np.object))
    axis_node = ctx.make_const(utils.make_name("axis"), np.array([0], np.int64))
    inps_with_shapes = [i for i in node.input if ctx.get_shape(i) != []]
    shape_node = None
    if 0 < len(inps_with_shapes) < len(node.input):
        shape_node = ctx.make_node("Shape", [inps_with_shapes[0]])
    unsqueezes = []
    for inp in node.input:
        if ctx.get_shape(inp) == [] and shape_node is not None:
            expand_node = ctx.make_node("Expand", [inp, shape_node.output[0]])
            inp = expand_node.output[0]
        unsqueeze_node = GraphBuilder(ctx).make_unsqueeze({'data': inp, 'axes': [0]})
        unsqueezes.append(unsqueeze_node)
    stack_node = ctx.make_node("Concat", unsqueezes, attr={'axis': 0})
    ctx.replace_inputs(node, [stack_node.output[0], separator_node.output[0], axis_node.output[0]])
def version_6(cls, ctx, node, **kwargs):
    # T output = All(T x, list(int) reduce_indices, @bool keepdims)
    # T output = Any(T x, list(int) reduce_indices, @bool keepdims)
    reduce_dim = node.inputs[1].get_tensor_value()

    # for Any, the reduce_indices can be scalar as observed.
    if np.isscalar(reduce_dim):
        reduce_dim = [reduce_dim]

    if ctx.opset < 11:
        utils.make_sure(all(i >= 0 for i in reduce_dim),
                        "negative reduce axis is not supported in onnx for now")

    cast = ctx.make_node(op_type="Cast", inputs=[node.input[0]],
                         attr={"to": onnx_pb.TensorProto.FLOAT})
    keepdims = helper.get_attribute_value(node.get_attr("keep_dims"))
    op_type = "ReduceMin" if node.type == "All" else "ReduceSum"

    if op_type == "ReduceSum":
        reduce_node_output = GraphBuilder(ctx).make_reduce_sum(
            {"data": cast.output[0], "axes": reduce_dim, "keepdims": keepdims,
             "noop_with_empty_axes": 1})
    else:
        reduce_node_output = ctx.make_node(op_type=op_type, inputs=cast.output,
                                           attr={"axes": reduce_dim, "keepdims": keepdims}).output[0]

    zero_node = ctx.make_const(utils.make_name("zero_reduce"), np.array(0, dtype=np.float32))

    shapes = node.output_shapes
    dtypes = node.output_dtypes
    ctx.remove_node(node.name)
    ctx.make_node(op_type="Greater", inputs=[reduce_node_output, zero_node.output[0]],
                  name=node.name, outputs=node.output, shapes=shapes, dtypes=dtypes)
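# Illustrative sketch (not part of the converter): the float trick behind All/Any.
import numpy as np

x = np.array([[True, False], [True, True]])
f = x.astype(np.float32)         # Cast to FLOAT
all_ = np.min(f, axis=1) > 0     # All: min is 1.0 only if every entry is True -> [False, True]
any_ = np.sum(f, axis=1) > 0     # Any: sum is positive if any entry is True   -> [True, True]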
def _connect_lstm_yc_to_graph(self, context, i):
    # in tf, y_c output shape is: [batch, hidden]
    # in onnx, output shape is: [num_directions, batch, hidden]
    gb = GraphBuilder(self.g)
    exit_output = context.state_variables["ct" + str(i)].exit_output
    output_id = context.rnn_node[i].output[2]
    lstm_yc_shape = self.g.get_shape(output_id)
    squeeze_node = gb.make_squeeze({"data": output_id, "axes": [0]},
                                   shapes=[[lstm_yc_shape[1], lstm_yc_shape[2]]],
                                   dtypes=[self.g.get_dtype(output_id)],
                                   return_node=True)
    self.g.replace_all_inputs(exit_output.id, squeeze_node.output[0])  # ops=self.g.get_nodes()
def _adapt_scan_sequence_input_or_output(self, target_name, input_id, handle_output=False):
    nodes_to_add = []
    shape_node = self.g.make_node("Shape", [input_id])
    nodes_to_add.append(shape_node)
    inferred_shape = self.g.get_shape(input_id)
    if handle_output is True:
        # handle output:
        # if the required dim values don't contain more than one -1,
        # just use a const for Reshape's shape input.
        if inferred_shape is not None and inferred_shape[1:].count(-1) <= 1:
            new_shape_node = self.g.make_const(utils.make_name(target_name + "_target_shape"),
                                               np.array(inferred_shape[1:], dtype=np.int64))
            nodes_to_add.append(new_shape_node)
        else:
            # otherwise, get the dims dynamically, e.g. remove the fake batch size (e.g. 1)
            # from [1, time, real-batch, ...]
            origin_shape_node = self.g.make_node("Cast", [shape_node.output[0]],
                                                 {"to": onnx_pb.TensorProto.FLOAT})
            nodes_to_add.append(origin_shape_node)

            attr = {"axes": [0], "starts": [1], "ends": [sys.maxsize]}
            inputs_map = {"data": origin_shape_node.output[0], **attr}
            sliced_shape_node = GraphBuilder(self.g).make_slice(inputs_map)
            nodes_to_add.append(self.g.get_node_by_output(sliced_shape_node))

            new_shape_node = self.g.make_node("Cast", [sliced_shape_node],
                                              {"to": onnx_pb.TensorProto.INT64})
            nodes_to_add.append(new_shape_node)
        new_shape = inferred_shape[1:]
    else:
        # handle input:
        if inferred_shape is not None and inferred_shape.count(-1) <= 1:
            new_shape_node = self.g.make_const(utils.make_name(target_name + "_target_shape"),
                                               np.array([1] + inferred_shape, dtype=np.int64))
            nodes_to_add.append(new_shape_node)
        else:
            # add a fake batch size: 1
            fake_batch_size_node = self.g.make_const(utils.make_name(target_name + "_target_shape"),
                                                     np.array([1], dtype=np.int64))
            nodes_to_add.append(fake_batch_size_node)
            new_shape_node = self.g.make_node("Concat",
                                              [fake_batch_size_node.output[0], shape_node.output[0]],
                                              attr={"axis": 0})
            nodes_to_add.append(new_shape_node)
        new_shape = [1] + inferred_shape

    reshape_node = self.g.make_node("Reshape", [input_id, new_shape_node.output[0]],
                                    shapes=[new_shape],
                                    dtypes=[self.g.get_dtype(input_id)],
                                    op_name_scope=target_name)
    nodes_to_add.append(reshape_node)
    logger.debug("create Reshape for scan output %s, with output shape %s",
                 reshape_node.output[0], new_shape)
    return nodes_to_add
def _process_c_or_h_init_nodes(self, initializer_input_id, context):
    node = self.g.get_node_by_output(initializer_input_id)
    if node.is_const():
        val = node.get_tensor_value(as_list=False)
        initial_name = utils.make_name("Const")
        new_val = np.expand_dims(val, axis=0)
        const_node = self.g.make_const(initial_name, new_val)
        return const_node.output[0]

    gb = GraphBuilder(self.g)
    unsqueeze_node = gb.make_unsqueeze({'data': initializer_input_id, "axes": [0]}, return_node=True)
    to_replace = [n for n in self.g.get_nodes() if n != unsqueeze_node]
    self.g.replace_all_inputs(initializer_input_id, unsqueeze_node.output[0], ops=to_replace)
    return unsqueeze_node.output[0]
def version_13(cls, ctx, node, **kwargs):
    keepdims = node.get_attr_value('keep_dims')
    reduce_input = node.input[0]
    if node.type == "All":
        # All(x) is computed as Not(Any(Not(x)))
        reduce_input = ctx.make_node("Not", [reduce_input]).output[0]
    cast = ctx.make_node("Cast", inputs=[reduce_input],
                         attr={"to": onnx_pb.TensorProto.FLOAT}).output[0]
    axes_cast = node.input[1]
    if ctx.get_rank(axes_cast) == 0:
        # Unsqueeze scalar axes
        axes_cast = GraphBuilder(ctx).make_unsqueeze({'data': axes_cast, 'axes': [0]})
    if ctx.get_dtype(axes_cast) != onnx_pb.TensorProto.INT64:
        axes_cast = ctx.make_node("Cast", inputs=[axes_cast],
                                  attr={"to": onnx_pb.TensorProto.INT64}).output[0]
    reduce_node_output = GraphBuilder(ctx).make_reduce_sum(
        {"data": cast, "axes": axes_cast, "keepdims": keepdims, "noop_with_empty_axes": 1},
        shapes=node.output_shapes, op_name_scope=node.name)
    zero_node = ctx.make_const(utils.make_name("zero_reduce"), np.array(0, dtype=np.float32))
    greater_node = ctx.make_node(op_type="Greater", inputs=[reduce_node_output, zero_node.output[0]])
    result = greater_node.output[0]
    if node.type == "All":
        result = ctx.make_node("Not", [result]).output[0]
    ctx.replace_all_inputs(node.output[0], result)
def any_version(cls, opset, ctx, node, **kwargs):
    """
    Computes the modulus of a complex number.
    If the tensor dtype is not complex64 or complex128, it assumes the first dimension
    indexes the real part (0) and the imaginary part (1, :, :, ...).
    """
    supported_dtypes = [
        onnx_pb.TensorProto.FLOAT,
        onnx_pb.TensorProto.FLOAT16,
        onnx_pb.TensorProto.DOUBLE,
        onnx_pb.TensorProto.COMPLEX64,
        onnx_pb.TensorProto.COMPLEX128,
    ]
    onnx_dtype = ctx.get_dtype(node.input[0])
    utils.make_sure(onnx_dtype in supported_dtypes, "Unsupported input type.")
    shape = ctx.get_shape(node.input[0])
    np_dtype = utils.map_onnx_to_numpy_type(onnx_dtype)
    utils.make_sure(shape[0] == 2, "ComplexAbs expected the first dimension to be 2 but shape is %r",
                    shape)

    ind0 = ctx.make_const(name=utils.make_name('cst0'), np_val=np.array([0], dtype=np.int64))
    ind1 = ctx.make_const(name=utils.make_name('cst1'), np_val=np.array([1], dtype=np.int64))
    p2 = ctx.make_const(name=utils.make_name('p2'), np_val=np.array([2], dtype=np_dtype))

    real_part = ctx.make_node('Gather', inputs=[node.input[0], ind0.name], attr=dict(axis=0),
                              name=utils.make_name('Real_' + node.name))
    imag_part = ctx.make_node('Gather', inputs=[node.input[0], ind1.name], attr=dict(axis=0),
                              name=utils.make_name('Imag_' + node.name))
    real_part2 = ctx.make_node('Pow', inputs=[real_part.output[0], p2.name],
                               name=utils.make_name(real_part.name + 'p2p'))
    imag_part2 = ctx.make_node('Pow', inputs=[imag_part.output[0], p2.name],
                               name=utils.make_name(imag_part.name + 'p2p'))

    ctx.remove_node(node.name)
    add = ctx.make_node("Add", inputs=[real_part2.output[0], imag_part2.output[0]],
                        name=utils.make_name('ComplexAbs_' + node.name))
    squeezed = GraphBuilder(ctx).make_squeeze({'data': add.output[0], 'axes': [0]},
                                              name=utils.make_name('ComplexAbs' + node.name),
                                              return_node=True)
    last_node = ctx.make_node("Sqrt", inputs=squeezed.output[:1],
                              name=utils.make_name('ComplexAbs' + node.name),
                              shapes=[shape[1:]], dtypes=[onnx_dtype])
    ctx.replace_all_inputs(node.output[0], last_node.output[0])  # ops=ctx.get_nodes()
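# Illustrative sketch (not part of the converter): |a + bi| = sqrt(a^2 + b^2) on the
# packed layout, where index 0 along axis 0 holds real parts and index 1 imaginary parts.
import numpy as np

z = np.stack([np.array([3.0, 5.0]),     # real parts  (Gather index 0)
              np.array([4.0, 12.0])])   # imag parts  (Gather index 1)
modulus = np.sqrt(z[0] ** 2 + z[1] ** 2)  # [5., 13.]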
def slice_bilstm_for_original_lstm_consumers(g, lstm_fw, lstm_bw, bi_lstm, lstm_output_index, all_nodes,
                                             to_remove):
    fw_consumers = g.find_output_consumers(lstm_fw.output[lstm_output_index])
    bw_consumers = g.find_output_consumers(lstm_bw.output[lstm_output_index])
    if not fw_consumers and not bw_consumers:
        return

    if lstm_output_index == 0:
        axis = 1
        # remove reverse op for lstm_bw
        reverse_nodes = get_reverse_nodes_after_y_output(g, lstm_bw)
        if not reverse_nodes:
            raise ValueError("unexpected: y_output of the backward LSTM is not followed by a reverse node")

        for r_op in reverse_nodes:
            logger.debug("remove reverse op called %s", r_op.name)
            g.replace_all_inputs(all_nodes, r_op.output[0], r_op.input[0])
            to_remove.append(r_op.name)
    elif lstm_output_index in [1, 2]:
        axis = 0
    else:
        raise ValueError("LSTM should only have 3 outputs.")

    if fw_consumers:
        attr = {"axes": [axis], "starts": [0], "ends": [1]}
        inputs_map = {"data": bi_lstm.output[lstm_output_index], **attr}
        slice_node_fw = GraphBuilder(g).make_slice(inputs_map)
        all_nodes.append(g.get_node_by_output(slice_node_fw))
        g.replace_all_inputs(fw_consumers, lstm_fw.output[lstm_output_index], slice_node_fw)

    if bw_consumers:
        attr = {"axes": [axis], "starts": [1], "ends": [2]}
        inputs_map = {"data": bi_lstm.output[lstm_output_index], **attr}
        slice_node_bw = GraphBuilder(g).make_slice(inputs_map)
        all_nodes.append(g.get_node_by_output(slice_node_bw))
        g.replace_all_inputs(bw_consumers, lstm_bw.output[lstm_output_index], slice_node_bw)
def version_7(cls, ctx, node, **kwargs):
    tfl_while_inputs = node.input
    output_shapes = node.output_shapes
    output_dtypes = node.output_dtypes
    output_names = node.output

    cond_name = node.get_attr_str("cond_subgraph_index")
    cond_graph = find_function(cond_name)
    cond_graph.parent_graph = ctx

    body_name = node.get_attr_str("body_subgraph_index")
    body = find_function(body_name)
    body.parent_graph = ctx

    ctx.remove_node(node.name)

    cond_binding = parameter_binding(cond_graph, tfl_while_inputs)
    cond_outputs = inline_subgraph(ctx, cond_graph, cond_name, cond_binding)

    # Potential scan output candidates are identified in the body subgraph using
    # tfl_scan_output_rewriter. They can then be optimized in this tfl loop handler
    # provided they are not used in the cond subgraph.
    scan_outputs = sorted(body.scan_outputs, reverse=True)

    def input_is_unused(g, index):
        return len(g.find_output_consumers(g.inputs[index])) == 0

    scan_outputs = [(i, out) for i, out in scan_outputs if input_is_unused(cond_graph, i)]

    for idx, _ in scan_outputs:
        del tfl_while_inputs[idx]
        output_shapes.append(output_shapes.pop(idx))
        output_dtypes.append(output_dtypes.pop(idx))
        output_names.append(output_names.pop(idx))

    max_iterations = ctx.make_const(utils.make_name("max_iterations"),
                                    np.array(np.iinfo(np.int64).max))

    loop_node = ctx.make_node("Loop", [max_iterations.output[0], cond_outputs[0]] + tfl_while_inputs,
                              output_count=len(output_shapes), name=node.name + "_loop",
                              shapes=output_shapes, dtypes=output_dtypes, skip_conversion=True)
    output_map = dict(zip(output_names, loop_node.output))

    # shift output consumers
    for k, v in output_map.items():
        ctx.replace_all_inputs(k, v)  # ops=ctx.get_nodes()

    body = wire_tfl_while_body(body, loop_node.inputs, output_shapes, output_dtypes,
                               cond_graph, scan_outputs)

    for i in range(len(scan_outputs)):
        squeeze_node = GraphBuilder(body).make_squeeze({'data': body.outputs[-1 - i], "axes": [0]},
                                                       return_node=True)
        body.outputs[-1 - i] = squeeze_node.output[0]

    loop_node.set_body_graph_as_attr("body", body)
def any_version(cls, opset, ctx, node, **kwargs):
    if node.type == "StringSplit":
        skip_empty = node.get_attr_value('skip_empty', True)
    else:
        skip_empty = False
    node.type = "StringSplit"
    node.domain = constants.CONTRIB_OPS_DOMAIN
    for a in list(node.attr.keys()):
        del node.attr[a]
    unsqueeze_node = GraphBuilder(ctx).make_unsqueeze({'data': node.input[1], 'axes': [0]},
                                                      return_node=True)
    skip_empty_const = ctx.make_const(utils.make_name('skip_empty_const'),
                                      np.array([skip_empty], np.bool))
    ctx.replace_inputs(node, [node.input[0], unsqueeze_node.output[0], skip_empty_const.output[0]])
def _connect_gru_state_to_graph(self, context):
    # in tf, state output shape is: [batch, hidden]
    # in onnx, output shape is: [num_directions, batch, hidden]
    exit_output_id = context.state_variables["state"].exit_output.id
    if not exit_output_id:
        logger.debug("no consumer for state variable")
        return

    output_id = context.rnn_node.output[1]
    gru_state_shape = self.g.get_shape(output_id)
    output_shape = [gru_state_shape[1], gru_state_shape[2]]
    squeeze_node = GraphBuilder(self.g).make_squeeze({'data': output_id, "axes": [0]},
                                                     shapes=[output_shape],
                                                     dtypes=[self.g.get_dtype(output_id)],
                                                     return_node=True)
    self.g.replace_all_inputs(exit_output_id, squeeze_node.output[0])  # ops=self.g.get_nodes()
def version_1(cls, ctx, node, **kwargs):
    node.domain = constants.CONTRIB_OPS_DOMAIN
    input_node = node.inputs[0]
    utils.make_sure(input_node.type == "SentencepieceOp",
                    "Input 0 to node %s is not SentencepieceOp", node.name)
    ctx.remove_input(node, node.input[0], 0)

    nbest_size_cast = ctx.make_node("Cast", [node.input[1]], attr={'to': TensorProto.INT64}).output[0]
    ctx.replace_input(node, node.input[1], nbest_size_cast, 1)
    for i in range(1, len(node.input)):
        unsqueeze = GraphBuilder(ctx).make_unsqueeze({'data': node.input[i], 'axes': [0]})
        ctx.replace_input(node, node.input[i], unsqueeze, i)
    node.set_attr("model", input_node.attr['model'].s)
    node.type = "SentencepieceTokenizer"
    if ctx.is_safe_to_remove_nodes([input_node]):
        ctx.remove_node(input_node.name)
def _convert_since_9(cls, ctx, node, op_type, roi_required=False):
    # float32 out = ResizeBilinear/ResizeNearestNeighbor(T images, int size)
    # https://www.tensorflow.org/api_docs/python/tf/image/resize_nearest_neighbor
    # wants the input to be NHWC - adjust target_shape to this.
    mode = "linear" if node.type == "ResizeBilinear" else "nearest"

    # first create the "scales" info for onnx Upsample.
    # if the shapes of input and output are known, "scales" is calculated statically
    # and set as a const node
    shape = ctx.get_shape(node.input[0])
    if shape and shape[2] != -1 and shape[1] != -1 and node.inputs[1].is_const():
        target_shape = node.inputs[1].get_tensor_value()
        n, h, w, c = shape
        nh, nw = target_shape
        # scales is nchw
        # the reason we don't store the data in the raw field is this bug:
        # https://github.com/onnx/onnx/issues/1852
        scale_val = np.array([1.0, 1.0, float(nh) / h, float(nw) / w]).astype(np.float32)
        scales = ctx.make_const(utils.make_name("scales"), scale_val, raw=False)
    else:
        ori_shape = ctx.make_node("Shape", [node.input[0]])
        attr = {"axes": [0], "starts": [1], "ends": [3]}
        inputs_map = {"data": ori_shape.output[0], **attr}
        ori_shape_hw = GraphBuilder(ctx).make_slice(inputs_map)
        ori_shape_hw_float = ctx.make_node("Cast", [ori_shape_hw],
                                           attr={"to": onnx_pb.TensorProto.FLOAT})

        target_hw = node.inputs[1]
        target_hw_float = ctx.make_node("Cast", target_hw.output,
                                        attr={"to": onnx_pb.TensorProto.FLOAT})

        scales_hw = ctx.make_node("Div", [target_hw_float.output[0], ori_shape_hw_float.output[0]])

        const_one_array = ctx.make_const(utils.make_name("one"),
                                         np.array([1.0, 1.0]).astype(np.float32))
        # scales is nchw
        scales = ctx.make_node("Concat", [const_one_array.output[0], scales_hw.output[0]],
                               attr={"axis": 0})
    # because onnxruntime only supports scaling the last two dims, a transpose is inserted
    input_nchw = ctx.make_node("Transpose", [node.input[0]], attr={"perm": constants.NHWC_TO_NCHW})
    if roi_required:
        roi = ctx.make_const(utils.make_name("roi"), np.array([]).astype(np.float32))
        upsample = ctx.make_node("Resize", [input_nchw.output[0], roi.output[0], scales.output[0]],
                                 attr={"mode": mode, "nearest_mode": "floor",
                                       "coordinate_transformation_mode": "asymmetric"})
    else:
        upsample = ctx.make_node(op_type, [input_nchw.output[0], scales.output[0]],
                                 attr={"mode": mode})

    shapes = node.output_shapes
    dtypes = node.output_dtypes
    ctx.remove_node(node.name)
    ctx.make_node("Transpose", upsample.output, attr={"perm": constants.NCHW_TO_NHWC},
                  name=node.name, outputs=node.output, shapes=shapes, dtypes=dtypes)
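# Illustrative sketch (not part of the converter): the static "scales" computation,
# given per axis in NCHW order; batch and channel stay at 1.0.
import numpy as np

h, w = 32, 48      # input height/width
nh, nw = 64, 96    # target height/width
scales = np.array([1.0, 1.0, nh / h, nw / w], dtype=np.float32)  # [1., 1., 2., 2.]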
def create_onnx_random_uniform_op(g, tmax, tmin, ru_op, output, to_delete):
    dtype = g.get_dtype(output.output[0])
    op_name = utils.make_name("RandomUniform")
    shape_node = ru_op.inputs[0]
    shape = g.get_shape(output.output[0])
    if shape_node.is_const():
        # if the tensorflow input (aka the shape) is const we can use the RandomUniform op
        needs_squeeze = False
        if len(shape) == 0:
            shape = [1]
            needs_squeeze = True
        new_node = g.make_node("RandomUniform", [], name=op_name,
                               attr={"low": tmin, "high": tmax, "dtype": dtype, "shape": shape},
                               shapes=[shape], dtypes=[dtype])
        if needs_squeeze:
            new_node = GraphBuilder(g).make_squeeze({"data": new_node.output[0], "axes": [0]},
                                                    return_node=True)
    else:
        if shape_node.type == "Shape":
            # if the shape is dynamic - in tensorflow the shape comes as a tensor VALUE,
            # while onnx RandomUniformLike takes the shape from the tensor itself.
            # In many cases there is a Shape op in tensorflow before RandomUniform and
            # to make that work for onnx we just need to remove the Shape op.
            new_node = g.make_node("RandomUniformLike", inputs=[shape_node.input[0]], name=op_name,
                                   attr={"low": tmin, "high": tmax, "dtype": dtype},
                                   shapes=[shape], dtypes=[dtype])
        else:
            # if the shape is calculated we need to create a tensor so RandomUniformLike
            # can take the shape from there. Pre-opset-9 this is somewhat hacky because there is
            # no real fill op in onnx. In general this is not going to help performance but the
            # tensors created are expected to be small.

            # tell the caller not to delete the shape node
            to_delete.remove(shape_node)
            # create a fill op with the shape of the value of the input tensor
            zero = g.make_const(utils.make_name("zero"), np.zeros((), dtype=np.float32))
            fill_node = g.make_node("Fill", inputs=[shape_node.output[0], zero.name],
                                    shapes=[shape], dtypes=[dtype])
            func, _ = handler.tf_op.find_effective_op("Fill")
            func(g, fill_node)
            # and use RandomUniformLike to create the random tensor
            new_node = g.make_node("RandomUniformLike", inputs=[fill_node.output[0]], name=op_name,
                                   attr={"low": tmin, "high": tmax, "dtype": dtype},
                                   shapes=[shape], dtypes=[dtype])
    return new_node
def process_var_init_nodes(self, context):
    assert "state" in context.state_variables.keys()
    initializer_input_id = context.state_variables["state"].enter_input_id
    node = self.g.get_node_by_output(initializer_input_id)
    if node.is_const():
        val = node.get_tensor_value(as_list=False)
        initial_name = utils.make_name("Const")
        new_val = np.expand_dims(val, axis=0)
        const_node = self.g.make_const(initial_name, new_val)
        context.onnx_input_ids["initial_state"] = const_node.output[0]
        return
    unsqueeze_node = GraphBuilder(self.g).make_unsqueeze({'data': initializer_input_id, 'axes': [0]},
                                                         return_node=True)
    to_replace = [n for n in self.g.get_nodes() if n != unsqueeze_node]
    self.g.replace_all_inputs(initializer_input_id, unsqueeze_node.output[0], ops=to_replace)
    context.onnx_input_ids["initial_state"] = unsqueeze_node.output[0]
def _optimize_reduce(self, node, graph):
    if graph.get_dtype(node.output[0]) not in [TensorProto.FLOAT, TensorProto.DOUBLE]:
        return False
    if node.output[0] in graph.outputs:
        # Replacement is unsafe
        return False
    axes = node.get_attr_value('axes')
    inp_rank = graph.get_rank(node.input[0])
    if inp_rank is None:
        return False
    if axes != list(range(2, inp_rank)):
        return False
    op_map = {"ReduceMean": "GlobalAveragePool", "ReduceMax": "GlobalMaxPool"}
    node.type = op_map[node.type]
    del node.attr['axes']
    if not node.get_attr_value('keepdims', True):
        out_shapes = node.output_shapes
        out_dtypes = node.output_dtypes
        new_out_shape = graph.get_shape(node.input[0])[:2] + [1] * len(axes)
        graph.set_shape(node.output[0], new_out_shape)
        squeeze_node = GraphBuilder(graph).make_squeeze({'data': node.output[0], 'axes': axes},
                                                        shapes=out_shapes, dtypes=out_dtypes,
                                                        return_node=True, op_name_scope=node.name)
        graph.insert_node_on_output(squeeze_node, node.output[0])
    if 'keepdims' in node.attr:
        del node.attr['keepdims']
    return True
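# Illustrative sketch (not part of the optimizer): a ReduceMean over all trailing
# spatial axes is exactly GlobalAveragePool on an NCHW tensor.
import numpy as np

x = np.random.rand(2, 3, 8, 8).astype(np.float32)
pooled = np.mean(x, axis=(2, 3), keepdims=True)  # GlobalAveragePool output: [2, 3, 1, 1]
# When the replaced ReduceMean had keepdims=0, the inserted Squeeze restores its shape:
squeezed = np.squeeze(pooled, axis=(2, 3))       # [2, 3]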
def _optimize_reshape(self, node, graph):
    if node.inputs[1].is_const():
        return False
    inp_shape = graph.get_shape(node.input[0])
    if inp_shape is None:
        # The rank must be known
        return False
    feed_dict = {}
    for n in graph.find_output_consumers(node.input[0]):
        if n.type == "Shape":
            symbolic_shape = []
            for i, d in enumerate(inp_shape):
                if d == -1:
                    # Make a variable representing each unknown dim
                    symbolic_shape.append(SymbolicTensorElement.from_variable(i))
                else:
                    symbolic_shape.append(SymbolicTensorElement.from_const(d))
            feed_dict[n.output[0]] = np.array(symbolic_shape, np.object)
    try:
        symbolic_res = SymbolicExecutor(graph).compute_outputs([node.input[1]], feed_dict)
    except SymbolicExecutionException:
        return False
    utils.make_sure(len(symbolic_res[0].shape) == 1, "Shape must have rank 1")
    symbolic_shape = symbolic_res[0].tolist()
    product_cnt = len([val for val in symbolic_shape if val.has_multiple_terms()])
    idx_cnt = len([val for val in symbolic_shape if val.is_single_var()])
    if product_cnt > 1:
        # The -1 lets us handle at most one dim with multiple terms
        return False
    if idx_cnt + product_cnt <= 1:
        # Only 1 non-const dim. Use -1 and consts for the rest.
        new_shape = [v.constant if v.is_const() else -1 for v in symbolic_shape]
        shift = 0
    else:
        # We will need to use some 0s. We can shift using squeeze/unsqueeze to line up equal dims

        def get_shift(val, i):
            if not val.is_single_var():
                return None
            return val.terms[0] - i

        shifts = [get_shift(val, i) for i, val in enumerate(symbolic_shape)]
        # Find the most popular shift
        most_common = Counter(s for s in shifts if s is not None).most_common(1)
        shift = most_common[0][0] if most_common else 0

        def get_reshape_dim(val, i, shift):
            if val.is_const():
                return val.constant
            if get_shift(val, i) == shift:
                return 0
            # Use -1 only as a last resort
            return -1

        new_shape = [get_reshape_dim(v, i, shift) for i, v in enumerate(symbolic_shape)]
        if new_shape.count(-1) > 1:
            return False
    if shift > 0:
        new_shape = [1] * shift + new_shape
        squeeze_node = GraphBuilder(graph).make_squeeze({'data': node.output[0],
                                                         'axes': list(range(shift))},
                                                        return_node=True,
                                                        shapes=node.output_shapes,
                                                        dtypes=node.output_dtypes)
        graph.insert_node_on_output(squeeze_node, node.output[0])
    const_shape = graph.make_const(utils.make_name(node.name + "_shape"),
                                   np.array(new_shape, np.int64)).output[0]
    graph.replace_inputs(node, [node.input[0], const_shape])
    if shift < 0:
        unsqueeze_node = GraphBuilder(graph).make_unsqueeze({'data': node.input[0],
                                                             'axes': list(range(-shift))})
        graph.replace_inputs(node, [unsqueeze_node, const_shape])
    return True
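# Illustrative sketch (not part of the optimizer): the ONNX Reshape semantics this
# rewrite relies on. A 0 in the target shape copies the input dim at that position,
# and a single -1 infers whatever size remains.
import numpy as np

def onnx_reshape(x, target):
    dims = [x.shape[i] if d == 0 else d for i, d in enumerate(target)]
    return x.reshape(dims)  # numpy resolves the single -1

y = onnx_reshape(np.zeros((4, 5, 6)), [0, -1])  # keeps dim 0, flattens the rest -> (4, 30)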
def version_7(cls, ctx, node, **kwargs):
    # T output = MatrixBandPart(T input, int num_lower, int num_upper)
    # data-flow: first generate the mask matrix and then use an element-wise Mul op
    input_rank = len(ctx.get_shape(node.input[0]))
    utils.make_sure(input_rank == 2, error_msg="MatrixBandPart op: only rank 2 is supported")
    bandpart = [node.inputs[ind].get_tensor_value() for ind in [1, 2]]
    utils.make_sure(bandpart in [[-1, 0], [0, -1]], "only support Lower/Upper triangular for now")
    # method to generate the mask matrix: if the lower triangle is needed, the matrix is
    # generated column by column, otherwise row by row.
    axis, counter_axis, squeeze_axis = (1, 0, 2) if bandpart == [-1, 0] else (0, 1, 1)
    # 1: subgraph to implement tf.ones_like(input[:, 0]);
    # no need to worry about the dtype, because bool is needed as Xor only supports bool
    node_name = utils.make_name("const_zero")
    const_zero = ctx.make_const(name=node_name, np_val=np.array([0]).astype(np.int32))
    first_col_or_row = ctx.make_node(op_type="Gather", inputs=[node.input[0], const_zero.output[0]],
                                     attr={"axis": axis})
    first_col_or_row_casted = ctx.make_node(op_type="Cast", inputs=first_col_or_row.output,
                                            attr={"to": onnx_pb.TensorProto.BOOL})
    # "line" means one col or one row
    zero_line = ctx.make_node(op_type="Xor", inputs=first_col_or_row_casted.output * 2)
    one_line = ctx.make_node(op_type="Not", inputs=zero_line.output)

    # 2: "loop" to generate the mask matrix: generate the cols or rows of the matrix one by one
    g = ctx.create_new_graph_with_same_config()
    node_name = utils.make_name("const_zero_bool")
    const_zero_bool = ctx.make_const(name=node_name, np_val=np.array([[0]]).astype(np.bool))
    ctx.set_dtype(const_zero_bool.output[0], onnx_pb.TensorProto.BOOL)

    # shift the line right and add a zero at the left.
    new_line = g.make_node(op_type="Concat", inputs=[const_zero_bool.output[0], "line"],
                           attr={"axis": counter_axis}, dtypes=[onnx_pb.TensorProto.BOOL])
    attr = {"axes": [counter_axis], "starts": [0], "ends": [-1]}
    inputs_map = {"data": new_line.output[0], **attr}
    slice_node = GraphBuilder(g).make_slice(inputs_map)

    g.make_node("Identity", ["cond"], outputs=["cond_out"])
    g.make_node("Identity", ["line"], outputs=["res"])
    g.make_node("Identity", [slice_node], outputs=["line_out"])

    g.add_graph_input("trip", onnx_pb.TensorProto.INT64, [])
    g.add_graph_input("cond", onnx_pb.TensorProto.BOOL, [])
    g.add_graph_input("line", onnx_pb.TensorProto.BOOL, [-1, -1])

    g.add_graph_output("cond_out", onnx_pb.TensorProto.BOOL, [])
    g.add_graph_output("line_out", onnx_pb.TensorProto.BOOL, [-1, -1])
    g.add_graph_output("res", onnx_pb.TensorProto.BOOL, [-1, -1])

    # initial values of the body vars
    shape = ctx.make_node(op_type="Shape", inputs=[node.input[0]])  # dtype of the result is int64
    node_name = utils.make_name("line_num_index")
    col_or_row_num_index = ctx.make_const(name=node_name, np_val=np.array(axis).astype(np.int32))
    line_num = ctx.make_node(op_type="Gather",
                             inputs=[shape.output[0], col_or_row_num_index.output[0]])
    trip_cnt = line_num.output[0]
    node_name = utils.make_name("true")
    cond = ctx.make_const(name=node_name, np_val=np.array(1).astype(np.bool))
    col_init = one_line.output[0]

    loop_node = ctx.make_node(op_type="Loop", inputs=[trip_cnt, cond.output[0], col_init],
                              output_count=2)
    loop_node.set_body_graph_as_attr("body", g)

    # convert the generated mask matrix from bool to the right shape and data type
    squeeze = ctx.make_node(op_type="Squeeze", inputs=[loop_node.output[1]],
                            attr={"axes": [squeeze_axis]})
    cast1 = ctx.make_node(op_type="Cast", inputs=squeeze.output,
                          attr={"to": onnx_pb.TensorProto.FLOAT})
    if axis == 1:
        mask_matrix = ctx.make_node(op_type="Transpose", inputs=cast1.output)
    else:
        mask_matrix = squeeze
    cast2 = ctx.make_node(op_type="Cast", inputs=mask_matrix.output,
                          attr={"to": ctx.get_dtype(node.input[0])})
    shapes = node.output_shapes
    dtypes = node.output_dtypes
    ctx.remove_node(node.name)
    ctx.make_node(op_type="Mul", inputs=[cast2.output[0], node.input[0]],
                  name=node.name, outputs=node.output, shapes=shapes, dtypes=dtypes)
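# Illustrative sketch (not part of the converter): how the Loop body grows a
# lower-triangular mask column by column, shifting a zero in from the top each step.
import numpy as np

n = 4
line = np.ones(n, dtype=bool)
cols = []
for _ in range(n):
    cols.append(line.copy())
    line = np.concatenate([[False], line[:-1]])  # the Concat + Slice in the body graph
mask = np.stack(cols, axis=1)  # scan output: ones on and below the diagonal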
def version_9(cls, ctx, node, **kwargs):
    node_inputs = node.input
    num_segments_specified = False
    if node.type.endswith("WithNumSegments") or node.type.startswith("Unsorted"):
        num_segments_specified = True
        num_segments = node_inputs.pop()
        node.type = node.type.replace("WithNumSegments", "")
        node.type = node.type.replace("Unsorted", "")
    if node.type.startswith("Sparse"):
        data_inp, indices_inp, segment_inp = node_inputs
        gather_node = ctx.make_node("Gather", [data_inp, indices_inp], attr={'axis': 0})
        data_inp = gather_node.output[0]
        node.type = node.type.replace("Sparse", "")
    else:
        data_inp, segment_inp = node_inputs

    # Data has shape [n, a, b, ..., c]
    data_shape = ctx.get_shape(data_inp)
    data_rank = len(data_shape) if data_shape is not None else None
    data_dtype = ctx.get_dtype(data_inp)
    data_np_dtype = utils.map_onnx_to_numpy_type(data_dtype)
    seg_np_dtype = utils.map_onnx_to_numpy_type(ctx.get_dtype(segment_inp))

    if num_segments_specified and ctx.get_dtype(segment_inp) != ctx.get_dtype(num_segments):
        num_segments = ctx.make_node("Cast", [num_segments],
                                     attr={"to": ctx.get_dtype(segment_inp)}).output[0]

    data_is_float = np.dtype(data_np_dtype).kind == 'f'
    data_is_int = np.dtype(data_np_dtype).kind == 'i'
    utils.make_sure(data_is_float or data_is_int, "dtype for Segment ops must be float or int")

    if node.type in ["SegmentSum", "SegmentMean", "SegmentSqrtN"]:
        onnx_op = "ReduceSum"
        identity_value = np.array(0, dtype=data_np_dtype)
    elif node.type == "SegmentProd":
        onnx_op = "ReduceProd"
        identity_value = np.array(1, dtype=data_np_dtype)
    elif node.type == "SegmentMax":
        onnx_op = "ReduceMax"
        if data_is_float:
            identity_value = np.array('-inf', dtype=data_np_dtype)
        else:
            identity_value = np.iinfo(data_np_dtype).min
    elif node.type == "SegmentMin":
        onnx_op = "ReduceMin"
        if data_is_float:
            identity_value = np.array('inf', dtype=data_np_dtype)
        else:
            identity_value = np.iinfo(data_np_dtype).max

    if not num_segments_specified:
        max_segment = ctx.make_node("ReduceMax", [segment_inp], attr={'axes': [0], 'keepdims': 0})
        one_const = ctx.make_const(utils.make_name("const_one"), np.array(1, dtype=seg_np_dtype))
        num_segments = ctx.make_node("Add", [max_segment.output[0], one_const.output[0]]).output[0]
    # ORT doesn't support bool for OneHot so we use float32 and cast to bool
    onehot_values = ctx.make_const(utils.make_name("onehot_values"),
                                   np.array([0, 1], dtype=np.float32))
    # one_hot_node has shape [s, n] (s is # segments)
    one_hot_node = ctx.make_node("OneHot", [segment_inp, num_segments, onehot_values.output[0]],
                                 attr={'axis': 0})
    if node.type == "SegmentMean":
        scaling_node_output = GraphBuilder(ctx).make_reduce_sum(
            {"data": one_hot_node.output[0], "axes": [1], "keepdims": 0, "noop_with_empty_axes": 1})
    elif node.type == "SegmentSqrtN":
        seg_cnts_node_output = GraphBuilder(ctx).make_reduce_sum(
            {"data": one_hot_node.output[0], "axes": [1], "keepdims": 0, "noop_with_empty_axes": 1})
        scaling_node_output = ctx.make_node("Sqrt", [seg_cnts_node_output]).output[0]
    else:
        scaling_node_output = None

    if scaling_node_output is not None and num_segments_specified:
        # If empty segments are possible, we must avoid division by zero
        const_one_float = ctx.make_const(utils.make_name("const_one_float"),
                                         np.array(1, dtype=np.float32))
        scaling_node_output = ctx.make_node("Max", [scaling_node_output,
                                                    const_one_float.output[0]]).output[0]

    if onnx_op == "ReduceSum":
        # If the op is a summation, we can use MatMul instead of Where, which is faster
        # Data shape is [n, a, b, ..., c]
        data_shape_node = ctx.make_node("Shape", [data_inp])
        new_shape = ctx.make_const(utils.make_name("reshape_const"),
                                   np.array([0, -1], dtype=np.int64))
        # Reshape the data from [n, a, b, ..., c] to [n, P]
        data_reshape = ctx.make_node("Reshape", [data_inp, new_shape.output[0]])

        one_hot_cast = one_hot_node
        if data_dtype != onnx_pb.TensorProto.FLOAT:
            one_hot_cast = ctx.make_node("Cast", [one_hot_node.output[0]], attr={'to': data_dtype})

        # Shapes [s, n] * [n, P] => [s, P]
        product = ctx.make_node("MatMul", [one_hot_cast.output[0], data_reshape.output[0]],
                                op_name_scope=node.name)
        if scaling_node_output is not None:
            scaling_node_unsqueeze = ctx.make_node("Unsqueeze", [scaling_node_output],
                                                   attr={'axes': [1]})
            product = ctx.make_node("Div", [product.output[0], scaling_node_unsqueeze.output[0]])

        # Create new shape [0, a, b, ..., c]
        max_int64 = int(utils.get_max_value(np.int64))
        new_shape_slice = GraphBuilder(ctx).make_slice(
            {"data": data_shape_node.output[0], "ends": [max_int64], "starts": [1], "axes": [0]})
        zero_const = ctx.make_const(utils.make_name("zero_const"), np.array([0], dtype=np.int64))
        new_shape = ctx.make_node("Concat", [zero_const.output[0], new_shape_slice], attr={'axis': 0})

        shapes = node.output_shapes
        dtypes = node.output_dtypes
        ctx.remove_node(node.name)
        # Reshape the result from [s, P] to [s, a, b, ..., c]
        ctx.make_node("Reshape", [product.output[0], new_shape.output[0]],
                      name=node.name, outputs=node.output, shapes=shapes, dtypes=dtypes)
        return

    identity_const = ctx.make_const(utils.make_name("const_identity"), identity_value)
    one_hot_bool = ctx.make_node("Cast", [one_hot_node.output[0]],
                                 attr={"to": onnx_pb.TensorProto.BOOL})
    one_hot_unsqueeze = one_hot_bool

    # Make one_hot_unsqueeze have shape [s, n, 1, 1, ..., 1]
    if data_rank is None:
        # Unsqueeze requires a known rank, but we can use Reshape if the rank is unknown
        shape_node = ctx.make_node("Shape", [data_inp])
        rank_node = ctx.make_node("Shape", [shape_node.output[0]])
        one_const_int64 = ctx.make_const(utils.make_name("const_one"), np.array([1], dtype=np.int64))
        num_unsqueeze_dims = ctx.make_node("Sub", [rank_node.output[0], one_const_int64.output[0]])
        one_tensor = helper.make_tensor("value", onnx_pb.TensorProto.INT64, dims=[1], vals=[1])
        unsqueeze_dims = ctx.make_node("ConstantOfShape", inputs=[num_unsqueeze_dims.output[0]],
                                       attr={"value": one_tensor})
        # Zero indicates a dimension should be unchanged
        double_zero_const = ctx.make_const(utils.make_name("double_zero"),
                                           np.array([0, 0], dtype=np.int64))
        expanded_shape = ctx.make_node("Concat", [double_zero_const.output[0],
                                                  unsqueeze_dims.output[0]], attr={'axis': 0})
        one_hot_unsqueeze = ctx.make_node("Reshape", [one_hot_bool.output[0],
                                                      expanded_shape.output[0]])
    elif data_rank > 1:
        new_dims = list(range(2, 2 + data_rank - 1))
        one_hot_unsqueeze = ctx.make_node("Unsqueeze", [one_hot_bool.output[0]],
                                          attr={'axes': new_dims})

    # Shape of data:    [n, a, b, ..., c]
    # Shape of one_hot: [s, n, 1, 1, ..., 1]
    # Broadcasting left-pads the shape with 1s, so the result has shape [s, n, a, b, ..., c]
    where_node = ctx.make_node("Where", [one_hot_unsqueeze.output[0], data_inp,
                                         identity_const.output[0]])

    shapes = node.output_shapes
    dtypes = node.output_dtypes
    ctx.remove_node(node.name)
    # After reduction over axis 1, the shape is [s, a, b, ..., c]
    ctx.make_node(onnx_op, [where_node.output[0]], attr={'axes': [1], 'keepdims': 0},
                  name=node.name, outputs=node.output, shapes=shapes, dtypes=dtypes)
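# Illustrative sketch (not part of the converter): SegmentSum via the one-hot MatMul
# fast path. one_hot is [s, n]; data is flattened to [n, P]; the MatMul sums each segment.
import numpy as np

data = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])  # n = 3
segment_ids = np.array([0, 0, 1])                      # s = 2
one_hot = np.eye(2)[segment_ids].T                     # OneHot(axis=0) -> [s, n]
seg_sum = one_hot @ data.reshape(3, -1)                # [[4., 6.], [5., 6.]]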