def connect_unit_rnn_output_to_graph(self, context):
    outputs = context.loop_properties.scan_outputs_exits
    if not outputs:
        logger.debug("no consumer of rnn output")
        return

    gb = GraphBuilder(self.g)
    gather_output_id = outputs[0].id
    logger.debug("found output for rnn: %s", gather_output_id)

    # in tf batch major mode, output shape is: [batch, time, hidden]
    # in time major mode, output shape is: [time, batch, hidden]
    # in onnx, output shape is: [time, num_directions, batch, hidden]
    rnn_node = context.rnn_node[len(context.rnn_node) - 1]
    output_id = rnn_node.output[0]
    rnn_output_shape = self.g.get_shape(output_id)
    squeeze_output_shape = [rnn_output_shape[0], rnn_output_shape[2], rnn_output_shape[3]]
    squeeze_node = gb.make_squeeze({'data': output_id, "axes": [1]},
                                   shapes=[squeeze_output_shape],
                                   dtypes=[self.g.get_dtype(output_id)],
                                   return_node=True)
    self.g.replace_all_inputs(gather_output_id, squeeze_node.output[0])  # ops=self.g.get_nodes()
def _connect_lstm_ych_to_graph(self, context, i):
    # in tf, concat of y_c and y_h output shape is: [batch, hidden * 2]
    # in onnx, y_c/y_h output shape is: [num_directions, batch, hidden]
    gb = GraphBuilder(self.g)
    exit_output = context.state_variables["ct_ht" + str(i)].exit_output
    lstm_node = context.rnn_node[i]
    yc_shape = self.g.get_shape(lstm_node.output[2])
    concat_output_shape = [yc_shape[0], yc_shape[1], yc_shape[2] * 2]
    concat = self.g.make_node("Concat", [lstm_node.output[2], lstm_node.output[1]],
                              attr={"axis": 2},
                              shapes=[concat_output_shape],
                              dtypes=[self.g.get_dtype(lstm_node.output[2])])

    squeeze_output_shape = [concat_output_shape[1], concat_output_shape[2]]
    squeeze_node = gb.make_squeeze({'data': concat.output[0], "axes": [0]},
                                   shapes=[squeeze_output_shape],
                                   dtypes=[self.g.get_dtype(concat.output[0])],
                                   return_node=True)

    self.g.replace_all_inputs(exit_output.id, squeeze_node.output[0])  # ops=self.g.get_nodes()
def create_rnn_node(self, context):
    gb = GraphBuilder(self.g)
    rnn_nodes = list()
    outputs = context.loop_properties.scan_outputs_exits
    logger.debug("number of rnn node outputs: %s", len(outputs))

    for i in range(self.num_lstm_layers):
        logger.debug("creating rnn node for layer: %s", i)
        rnn_nodes.append(self.create_single_rnn_node(context, i))

        output_id = rnn_nodes[i].output[0]
        rnn_output_shape = self.g.get_shape(output_id)
        squeeze_output_shape = [rnn_output_shape[0], rnn_output_shape[2], rnn_output_shape[3]]
        squeeze_node = gb.make_squeeze({"data": output_id, "axes": [1]},
                                       shapes=[squeeze_output_shape],
                                       dtypes=[self.g.get_dtype(output_id)],
                                       return_node=True)

        if i + 1 < self.num_lstm_layers:
            logger.debug("setting input for layer: %s", i + 1)
            context.onnx_input_ids[i + 1]["X"] = squeeze_node.output[0]

    return rnn_nodes
def any_version(cls, opset, ctx, node, **kwargs):
    node.domain = constants.CONTRIB_OPS_DOMAIN
    separator = node.get_attr_value("separator")
    if separator is None:
        separator = b''
    separator = separator.decode('UTF-8')
    separator_node = ctx.make_const(utils.make_name("separator"), np.array([separator], np.object))
    axis_node = ctx.make_const(utils.make_name("axis"), np.array([0], np.int64))
    inps_with_shapes = [i for i in node.input if ctx.get_shape(i) != []]
    shape_node = None
    if 0 < len(inps_with_shapes) < len(node.input):
        shape_node = ctx.make_node("Shape", [inps_with_shapes[0]])
    unsqueezes = []
    for inp in node.input:
        if ctx.get_shape(inp) == [] and shape_node is not None:
            expand_node = ctx.make_node("Expand", [inp, shape_node.output[0]])
            inp = expand_node.output[0]
        unsqueeze_node = GraphBuilder(ctx).make_unsqueeze({'data': inp, 'axes': [0]})
        unsqueezes.append(unsqueeze_node)
    stack_node = ctx.make_node("Concat", unsqueezes, attr={'axis': 0})
    ctx.replace_inputs(node, [stack_node.output[0], separator_node.output[0], axis_node.output[0]])
def version_6(cls, ctx, node, **kwargs):
    # T output = All(T x, list(int) reduce_indices, @bool keepdims)
    # T output = Any(T x, list(int) reduce_indices, @bool keepdims)
    reduce_dim = node.inputs[1].get_tensor_value()

    # for Any, the reduce_indices can be scalar as observed.
    if np.isscalar(reduce_dim):
        reduce_dim = [reduce_dim]

    if ctx.opset < 11:
        utils.make_sure(all(i >= 0 for i in reduce_dim),
                        "negative reduce axis is not supported in onnx for now")

    cast = ctx.make_node(op_type="Cast", inputs=[node.input[0]],
                         attr={"to": onnx_pb.TensorProto.FLOAT})
    keepdims = helper.get_attribute_value(node.get_attr("keep_dims"))
    op_type = "ReduceMin" if node.type == "All" else "ReduceSum"

    if op_type == "ReduceSum":
        reduce_node_output = GraphBuilder(ctx).make_reduce_sum(
            {"data": cast.output[0], "axes": reduce_dim, "keepdims": keepdims, "noop_with_empty_axes": 1})
    else:
        reduce_node_output = ctx.make_node(op_type=op_type, inputs=cast.output,
                                           attr={"axes": reduce_dim, "keepdims": keepdims}).output[0]

    zero_node = ctx.make_const(utils.make_name("zero_reduce"), np.array(0, dtype=np.float32))

    shapes = node.output_shapes
    dtypes = node.output_dtypes
    ctx.remove_node(node.name)
    ctx.make_node(op_type="Greater", inputs=[reduce_node_output, zero_node.output[0]],
                  name=node.name, outputs=node.output, shapes=shapes, dtypes=dtypes)
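# NOTE: illustrative numpy sketch (not part of tf2onnx) of the mapping used above:
# "All" becomes Cast-to-float + ReduceMin + Greater(>0); "Any" becomes Cast + ReduceSum + Greater(>0).
import numpy as np

x = np.array([[True, False, True],
              [True, True, True]])
axis = 1

all_via_min = np.min(x.astype(np.float32), axis=axis) > 0   # ReduceMin path (All)
any_via_sum = np.sum(x.astype(np.float32), axis=axis) > 0   # ReduceSum path (Any)

assert np.array_equal(all_via_min, np.all(x, axis=axis))
assert np.array_equal(any_via_sum, np.any(x, axis=axis))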
def version_1(cls, ctx, node, **kwargs):
    # in tf-2.0 grappler optimizes the graph pretty well and our matching logic
    # in the rewriter does not trigger. grappler will send the random uniform
    # with shape as input so we need to pick up the input here and, if the shape is
    # const, make it an attribute.
    seed = node.get_attr("seed")
    node.set_attr("seed", float(seed.f))
    utils.make_sure(node.inputs[0].is_const(),
                    "%s node with non-const shape requires opset >= 9", node.type)
    shape = node.inputs[0].get_tensor_value()
    ctx.remove_input(node, node.input[0], 0)
    if len(shape) == 0:
        # ORT can't take an empty shape (scalar)
        node.set_attr("shape", [1])
        ctx.set_shape(node.output[0], [1])
        squeeze_node = GraphBuilder(ctx).make_squeeze({'data': node.output[0], 'axes': [0]},
                                                      return_node=True)
        ctx.insert_node_on_output(squeeze_node, node.output[0])
        rand_out = squeeze_node.output[0]
    else:
        node.set_attr("shape", shape)
        ctx.set_shape(node.output[0], shape)
        rand_out = node.output[0]
    if node.type == "RandomUniformInt":
        cls.randuniform_int(ctx, node, rand_out, node.input[0], node.input[1])
        node.type = "RandomUniform"
        ctx.replace_inputs(node, [])
def _connect_lstm_yc_to_graph(self, context, i):
    # in tf, y_c output shape is: [batch, hidden]
    # in onnx, output shape is: [num_directions, batch, hidden]
    gb = GraphBuilder(self.g)
    exit_output = context.state_variables["ct" + str(i)].exit_output
    output_id = context.rnn_node[i].output[2]
    lstm_yc_shape = self.g.get_shape(output_id)
    squeeze_node = gb.make_squeeze({"data": output_id, "axes": [0]},
                                   shapes=[[lstm_yc_shape[1], lstm_yc_shape[2]]],
                                   dtypes=[self.g.get_dtype(output_id)],
                                   return_node=True)

    self.g.replace_all_inputs(exit_output.id, squeeze_node.output[0])  # ops=self.g.get_nodes()
def _process_c_or_h_init_nodes(self, initializer_input_id, context):
    node = self.g.get_node_by_output(initializer_input_id)
    if node.is_const():
        val = node.get_tensor_value(as_list=False)
        initial_name = utils.make_name("Const")
        new_val = np.expand_dims(val, axis=0)
        const_node = self.g.make_const(initial_name, new_val)
        return const_node.output[0]

    gb = GraphBuilder(self.g)
    squeeze_node = gb.make_unsqueeze({'data': initializer_input_id, "axes": [0]}, return_node=True)
    to_replace = [n for n in self.g.get_nodes() if n != squeeze_node]
    self.g.replace_all_inputs(initializer_input_id, squeeze_node.output[0], ops=to_replace)
    return squeeze_node.output[0]
def any_version(cls, opset, ctx, node, **kwargs): """ Computes the modules of a complex. If the matrix dtype is not complex64 or complex128, it assumes the first dimension means real part (0) and imaginary part (1, :, :...). """ supported_dtypes = [ onnx_pb.TensorProto.FLOAT, onnx_pb.TensorProto.FLOAT16, onnx_pb.TensorProto.DOUBLE, onnx_pb.TensorProto.COMPLEX64, onnx_pb.TensorProto.COMPLEX128, ] onnx_dtype = ctx.get_dtype(node.input[0]) utils.make_sure(onnx_dtype in supported_dtypes, "Unsupported input type.") shape = ctx.get_shape(node.input[0]) np_dtype = utils.map_onnx_to_numpy_type(onnx_dtype) utils.make_sure(shape[0] == 2, "ComplexAbs expected the first dimension to be 2 but shape is %r", shape) ind0 = ctx.make_const(name=utils.make_name('cst0'), np_val=np.array([0], dtype=np.int64)) ind1 = ctx.make_const(name=utils.make_name('cst1'), np_val=np.array([1], dtype=np.int64)) p2 = ctx.make_const(name=utils.make_name('p2'), np_val=np.array([2], dtype=np_dtype)) real_part = ctx.make_node( 'Gather', inputs=[node.input[0], ind0.name], attr=dict(axis=0), name=utils.make_name('Real_' + node.name)) imag_part = ctx.make_node( 'Gather', inputs=[node.input[0], ind1.name], attr=dict(axis=0), name=utils.make_name('Imag_' + node.name)) real_part2 = ctx.make_node( 'Pow', inputs=[real_part.output[0], p2.name], name=utils.make_name(real_part.name + 'p2p')) imag_part2 = ctx.make_node( 'Pow', inputs=[imag_part.output[0], p2.name], name=utils.make_name(imag_part.name + 'p2p')) ctx.remove_node(node.name) add = ctx.make_node( "Add", inputs=[real_part2.output[0], imag_part2.output[0]], name=utils.make_name('ComplexAbs_' + node.name)) squeezed = GraphBuilder(ctx).make_squeeze( {'data': add.output[0], 'axes': [0]}, name=utils.make_name('ComplexAbs' + node.name), return_node=True) last_node = ctx.make_node( "Sqrt", inputs=squeezed.output[:1], name=utils.make_name('ComplexAbs' + node.name), shapes=[shape[1:]], dtypes=[onnx_dtype]) ctx.replace_all_inputs(node.output[0], last_node.output[0]) # ops=ctx.get_nodes()
def _process_non_tuple_ch_init_nodes(self, context, i):
    gb = GraphBuilder(self.g)
    input_id = context.state_variables["ct_ht" + str(i)].enter_input_id
    hidden_size = context.hidden_size[i]

    attr = {"axes": [1], "starts": [0], "ends": [hidden_size]}
    inputs_map = {"data": input_id, **attr}
    slice_node1 = GraphBuilder(self.g).make_slice(inputs_map)
    unsqueeze_node_1 = gb.make_unsqueeze({'data': slice_node1, "axes": [0]}, return_node=True)

    attr = {"axes": [1], "starts": [hidden_size], "ends": [hidden_size * 2]}
    inputs_map = {"data": input_id, **attr}
    slice_node2 = GraphBuilder(self.g).make_slice(inputs_map)
    unsqueeze_node_2 = gb.make_unsqueeze({'data': slice_node2, "axes": [0]}, return_node=True)

    return unsqueeze_node_1.output[0], unsqueeze_node_2.output[0]
def slice_birnn_for_original_rnn_consumers(g, rnn_fw, rnn_bw, bi_rnn, rnn_output_index, all_nodes, to_remove):
    fw_consumers = g.find_output_consumers(rnn_fw.output[rnn_output_index])
    bw_consumers = g.find_output_consumers(rnn_bw.output[rnn_output_index])
    if not fw_consumers and not bw_consumers:
        return

    if rnn_output_index == 0:
        axis = 1
        # remove reverse op for rnn_bw
        reverse_nodes = get_reverse_nodes_after_y_output(g, rnn_bw)
        for r_op in reverse_nodes:
            logger.debug("remove reverse op %s", r_op.name)
            g.replace_all_inputs(r_op.output[0], r_op.input[0], ops=all_nodes)
            to_remove.append(r_op.name)
    elif rnn_output_index in [1, 2]:
        axis = 0
    else:
        raise ValueError("rnn should only have 3 outputs.")

    if fw_consumers:
        attr = {"axes": [axis], "starts": [0], "ends": [1]}
        inputs_map = {"data": bi_rnn.output[rnn_output_index], **attr}
        slice_node_fw = GraphBuilder(g).make_slice(inputs_map)
        all_nodes.append(g.get_node_by_output(slice_node_fw))
        g.replace_all_inputs(rnn_fw.output[rnn_output_index], slice_node_fw, ops=fw_consumers)

    if bw_consumers:
        attr = {"axes": [axis], "starts": [1], "ends": [2]}
        inputs_map = {"data": bi_rnn.output[rnn_output_index], **attr}
        slice_node_bw = GraphBuilder(g).make_slice(inputs_map)
        all_nodes.append(g.get_node_by_output(slice_node_bw))
        g.replace_all_inputs(rnn_bw.output[rnn_output_index], slice_node_bw, ops=bw_consumers)
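# NOTE: illustrative sketch (not tf2onnx code) of the axis choice above for a bidirectional ONNX RNN:
#   Y       has shape [seq_length, num_directions, batch, hidden] -> slice on axis 1
#   Y_h/Y_c have shape [num_directions, batch, hidden]            -> slice on axis 0
import numpy as np

y = np.random.rand(7, 2, 3, 4)
y_fw, y_bw = y[:, 0:1], y[:, 1:2]       # starts/ends [0, 1] and [1, 2] on axis 1
y_h = np.random.rand(2, 3, 4)
y_h_fw, y_h_bw = y_h[0:1], y_h[1:2]     # starts/ends [0, 1] and [1, 2] on axis 0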
def any_version(cls, opset, ctx, node, **kwargs): if node.type == "StringSplit": skip_empty = node.get_attr_value('skip_empty', True) else: skip_empty = False node.type = "StringSplit" node.domain = constants.CONTRIB_OPS_DOMAIN for a in list(node.attr.keys()): del node.attr[a] unsqueeze_node = GraphBuilder(ctx).make_unsqueeze( { 'data': node.input[1], 'axes': [0] }, return_node=True) skip_empty_const = ctx.make_const(utils.make_name('skip_empty_const'), np.array([skip_empty], np.bool)) ctx.replace_inputs(node, [ node.input[0], unsqueeze_node.output[0], skip_empty_const.output[0] ])
def version_1(cls, ctx, node, **kwargs):
    node.domain = constants.CONTRIB_OPS_DOMAIN
    input_node = node.inputs[0]
    utils.make_sure(input_node.type == "SentencepieceOp",
                    "Input 0 to node %s is not SentencepieceOp", node.name)
    ctx.remove_input(node, node.input[0], 0)

    nbest_size_cast = ctx.make_node("Cast", [node.input[1]],
                                    attr={'to': TensorProto.INT64}).output[0]
    ctx.replace_input(node, node.input[1], nbest_size_cast, 1)
    for i in range(1, len(node.input)):
        unsqueeze = GraphBuilder(ctx).make_unsqueeze({'data': node.input[i], 'axes': [0]})
        ctx.replace_input(node, node.input[i], unsqueeze, i)
    node.set_attr("model", input_node.attr['model'].s)
    node.type = "SentencepieceTokenizer"
    if ctx.is_safe_to_remove_nodes([input_node]):
        ctx.remove_node(input_node.name)
def any_version(cls, opset, ctx, node, **kwargs):
    node_inputs = node.input
    num_segments_specified = False
    if node.type.endswith("WithNumSegments") or node.type.startswith("Unsorted"):
        num_segments_specified = True
        num_segments = node_inputs.pop()
        node.type = node.type.replace("WithNumSegments", "")
        node.type = node.type.replace("Unsorted", "")
    if node.type.startswith("Sparse"):
        data_inp, indices_inp, segment_inp = node_inputs
        gather_node = ctx.make_node("Gather", [data_inp, indices_inp], attr={'axis': 0})
        data_inp = gather_node.output[0]
        node.type = node.type.replace("Sparse", "")
    else:
        data_inp, segment_inp = node_inputs

    # Data has shape [n, a, b, ..., c]
    data_shape = ctx.get_shape(data_inp)
    data_rank = len(data_shape) if data_shape is not None else None
    data_dtype = ctx.get_dtype(data_inp)
    data_np_dtype = utils.map_onnx_to_numpy_type(data_dtype)
    seg_np_dtype = utils.map_onnx_to_numpy_type(ctx.get_dtype(segment_inp))

    if num_segments_specified and ctx.get_dtype(segment_inp) != ctx.get_dtype(num_segments):
        num_segments = ctx.make_node("Cast", [num_segments],
                                     attr={"to": ctx.get_dtype(segment_inp)}).output[0]

    data_is_float = np.dtype(data_np_dtype).kind == 'f'
    data_is_int = np.dtype(data_np_dtype).kind == 'i'
    utils.make_sure(data_is_float or data_is_int, "dtype for Segment ops must be float or int")

    if node.type in ["SegmentSum", "SegmentMean", "SegmentSqrtN"]:
        onnx_op = "ReduceSum"
        identity_value = np.array(0, dtype=data_np_dtype)
    elif node.type == "SegmentProd":
        onnx_op = "ReduceProd"
        identity_value = np.array(1, dtype=data_np_dtype)
    elif node.type == "SegmentMax":
        onnx_op = "ReduceMax"
        if data_is_float:
            identity_value = np.array('-inf', dtype=data_np_dtype)
        else:
            identity_value = np.iinfo(data_np_dtype).min
    elif node.type == "SegmentMin":
        onnx_op = "ReduceMin"
        if data_is_float:
            identity_value = np.array('inf', dtype=data_np_dtype)
        else:
            identity_value = np.iinfo(data_np_dtype).max

    if not num_segments_specified:
        max_segment = ctx.make_node("ReduceMax", [segment_inp], attr={'axes': [0], 'keepdims': 0})
        one_const = ctx.make_const(utils.make_name("const_one"), np.array(1, dtype=seg_np_dtype))
        num_segments = ctx.make_node("Add", [max_segment.output[0], one_const.output[0]]).output[0]

    # ORT doesn't support bool for OneHot so we use float32 and cast to bool
    onehot_values = ctx.make_const(utils.make_name("onehot_values"), np.array([0, 1], dtype=np.float32))
    # one_hot_node has shape [s, n] (s is # segments)
    one_hot_node = ctx.make_node("OneHot", [segment_inp, num_segments, onehot_values.output[0]],
                                 attr={'axis': 0})

    if node.type == "SegmentMean":
        scaling_node_output = GraphBuilder(ctx).make_reduce_sum(
            {"data": one_hot_node.output[0], "axes": [1], "keepdims": 0, "noop_with_empty_axes": 1})
    elif node.type == "SegmentSqrtN":
        seg_cnts_node_output = GraphBuilder(ctx).make_reduce_sum(
            {"data": one_hot_node.output[0], "axes": [1], "keepdims": 0, "noop_with_empty_axes": 1})
        scaling_node_output = ctx.make_node("Sqrt", [seg_cnts_node_output]).output[0]
    else:
        scaling_node_output = None

    if scaling_node_output is not None and num_segments_specified:
        # If empty segments are possible, we must avoid division by zero
        const_one_float = ctx.make_const(utils.make_name("const_one_float"),
                                         np.array(1, dtype=np.float32))
        scaling_node_output = ctx.make_node("Max", [scaling_node_output,
                                                    const_one_float.output[0]]).output[0]

    if onnx_op == "ReduceSum":
        # If the op is a summation, we can use MatMul instead of Where, which is faster
        # Data shape is [n, a, b, ..., c]
        data_shape_node = ctx.make_node("Shape", [data_inp])
        new_shape = ctx.make_const(utils.make_name("reshape_const"), np.array([0, -1], dtype=np.int64))
        # Reshape the data from [n, a, b, ..., c] to [n, P]
        data_reshape = ctx.make_node("Reshape", [data_inp, new_shape.output[0]])

        one_hot_cast = one_hot_node
        if data_dtype != onnx_pb.TensorProto.FLOAT:
            one_hot_cast = ctx.make_node("Cast", [one_hot_node.output[0]], attr={'to': data_dtype})

        # Shapes [s, n] * [n, P] => [s, P]
        product = ctx.make_node("MatMul", [one_hot_cast.output[0], data_reshape.output[0]],
                                op_name_scope=node.name)
        if scaling_node_output is not None:
            scaling_node_unsqueeze = GraphBuilder(ctx).make_unsqueeze(
                {'data': scaling_node_output, 'axes': [1]}, return_node=True)
            product = ctx.make_node("Div", [product.output[0], scaling_node_unsqueeze.output[0]])

        # Create new shape [0, a, b, ..., c]
        max_int64 = int(utils.get_max_value(np.int64))
        new_shape_slice = GraphBuilder(ctx).make_slice(
            {"data": data_shape_node.output[0], "ends": [max_int64], "starts": [1], "axes": [0]})
        zero_const = ctx.make_const(utils.make_name("zero_const"), np.array([0], dtype=np.int64))
        new_shape = ctx.make_node("Concat", [zero_const.output[0], new_shape_slice], attr={'axis': 0})

        shapes = node.output_shapes
        dtypes = node.output_dtypes
        ctx.remove_node(node.name)
        # Reshape result from [s, P] to [s, a, b, ..., c]
        ctx.make_node("Reshape", [product.output[0], new_shape.output[0]],
                      name=node.name, outputs=node.output, shapes=shapes, dtypes=dtypes)
        return

    identity_const = ctx.make_const(utils.make_name("const_identity"), identity_value)
    one_hot_bool = ctx.make_node("Cast", [one_hot_node.output[0]],
                                 attr={"to": onnx_pb.TensorProto.BOOL})
    one_hot_unsqueeze = one_hot_bool

    # Make one_hot_unsqueeze have shape [s, n, 1, 1, ..., 1]
    if data_rank is None:
        # Unsqueeze requires known rank, but we can use Reshape if rank is unknown
        shape_node = ctx.make_node("Shape", [data_inp])
        rank_node = ctx.make_node("Shape", [shape_node.output[0]])
        one_const_int64 = ctx.make_const(utils.make_name("const_one"), np.array([1], dtype=np.int64))
        num_unsqueeze_dims = ctx.make_node("Sub", [rank_node.output[0], one_const_int64.output[0]])

        one_tensor = helper.make_tensor("value", onnx_pb.TensorProto.INT64, dims=[1], vals=[1])
        unsqueeze_dims = ctx.make_node("ConstantOfShape", inputs=[num_unsqueeze_dims.output[0]],
                                       attr={"value": one_tensor})
        # Zero indicates a dimension should be unchanged
        double_zero_const = ctx.make_const(utils.make_name("double_zero"),
                                           np.array([0, 0], dtype=np.int64))
        expanded_shape = ctx.make_node("Concat", [double_zero_const.output[0], unsqueeze_dims.output[0]],
                                       attr={'axis': 0})
        one_hot_unsqueeze = ctx.make_node("Reshape", [one_hot_bool.output[0], expanded_shape.output[0]])
    elif data_rank > 1:
        new_dims = list(range(2, 2 + data_rank - 1))
        one_hot_unsqueeze = GraphBuilder(ctx).make_unsqueeze(
            {'data': one_hot_bool.output[0], 'axes': new_dims}, return_node=True)

    # Shape of data: [n, a, b, ..., c]
    # Shape of one_hot: [s, n, 1, 1, ..., 1]
    # Broadcast left-pads shape with 1s, so result is shape: [s, n, a, b, ..., c]
    where_node = ctx.make_node("Where", [one_hot_unsqueeze.output[0], data_inp, identity_const.output[0]])

    shapes = node.output_shapes
    dtypes = node.output_dtypes
    ctx.remove_node(node.name)
    # After reduction over axis 1, shape is: [s, a, b, ..., c]
    ctx.make_node(onnx_op, [where_node.output[0]], attr={'axes': [1], 'keepdims': 0},
                  name=node.name, outputs=node.output, shapes=shapes, dtypes=dtypes)
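# NOTE: small numpy sketch (purely illustrative, arbitrary values) of why the one-hot MatMul
# in the summation path above reproduces tf.math.segment_sum on the flattened data.
import numpy as np

data = np.array([[1., 2.], [3., 4.], [5., 6.], [7., 8.]])   # shape [n, ...]
segment_ids = np.array([0, 0, 1, 2])                        # shape [n]
num_segments = 3

one_hot = (segment_ids == np.arange(num_segments)[:, None]).astype(data.dtype)  # [s, n]
segment_sum = one_hot @ data.reshape(len(data), -1)                              # [s, P]

assert np.allclose(segment_sum, np.array([[4., 6.], [5., 6.], [7., 8.]]))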
def _adapt_scan_sequence_input_or_output(self, target_name, input_id, handle_output=False):
    nodes_to_add = []
    shape_node = self.g.make_node("Shape", [input_id])
    nodes_to_add.append(shape_node)
    inferred_shape = self.g.get_shape(input_id)
    if handle_output is True:
        # handle output:
        # if required dim values don't contain more than one -1,
        # just use a const for Reshape's shape input.
        if inferred_shape is not None and inferred_shape[1:].count(-1) <= 1:
            new_shape_node = self.g.make_const(
                utils.make_name(target_name + "_target_shape"),
                np.array(inferred_shape[1:], dtype=np.int64))
            nodes_to_add.append(new_shape_node)
        else:
            # otherwise, get the dim dynamically, e.g. remove the fake batch size (e.g. 1)
            # from [1, time, real-batch, ...]
            origin_shape_node = self.g.make_node(
                "Cast", [shape_node.output[0]], {"to": onnx_pb.TensorProto.FLOAT})
            nodes_to_add.append(origin_shape_node)

            attr = {"axes": [0], "starts": [1], "ends": [sys.maxsize]}
            inputs_map = {"data": origin_shape_node.output[0], **attr}
            sliced_shape_node = GraphBuilder(self.g).make_slice(inputs_map)
            nodes_to_add.append(self.g.get_node_by_output(sliced_shape_node))

            new_shape_node = self.g.make_node(
                "Cast", [sliced_shape_node], {"to": onnx_pb.TensorProto.INT64})
            nodes_to_add.append(new_shape_node)

        new_shape = inferred_shape[1:]
    else:
        # handle input:
        if inferred_shape is not None and inferred_shape.count(-1) <= 1:
            new_shape_node = self.g.make_const(
                utils.make_name(target_name + "_target_shape"),
                np.array([1] + inferred_shape, dtype=np.int64))
            nodes_to_add.append(new_shape_node)
        else:
            # add a fake batch size: 1
            fake_batch_size_node = self.g.make_const(
                utils.make_name(target_name + "_target_shape"),
                np.array([1], dtype=np.int64))
            nodes_to_add.append(fake_batch_size_node)
            new_shape_node = self.g.make_node(
                "Concat",
                [fake_batch_size_node.output[0], shape_node.output[0]],
                attr={"axis": 0})
            nodes_to_add.append(new_shape_node)

        new_shape = [1] + inferred_shape

    reshape_node = self.g.make_node("Reshape", [input_id, new_shape_node.output[0]],
                                    shapes=[new_shape],
                                    dtypes=[self.g.get_dtype(input_id)],
                                    op_name_scope=target_name)
    nodes_to_add.append(reshape_node)
    logger.debug("create Reshape for scan output %s, with output shape %s",
                 reshape_node.output[0], new_shape)
    return nodes_to_add
def version_11(cls, ctx, node, **kwargs):
    # This op is basically NMS with a little post-processing.
    # TFLite implementation:
    # https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/micro/kernels/detection_postprocess.cc

    # box_encodings.shape = [batch_dim, box_num, 4]
    # class_predictions.shape = [batch_dim, box_num, num_classes(+1)]
    # anchors.shape = [box_num, 4]
    box_encodings, class_predictions, anchors = node.input

    classes_dtype = ctx.get_dtype(node.output[1])
    box_cnt_dtype = ctx.get_dtype(node.output[3])
    num_classes = node.get_attr_value('num_classes')
    max_detections = node.get_attr_value('max_detections')

    # Remove 'other' class if present.
    max_int64 = int(utils.get_max_value(np.int64))
    class_predictions = GraphBuilder(ctx).make_slice(
        {'data': class_predictions, 'starts': [-num_classes], 'ends': [max_int64], 'axes': [2]})

    scaling_vector = [node.get_attr_value(a) for a in ['y_scale', 'x_scale', 'h_scale', 'w_scale']]
    scale_const = ctx.make_const(utils.make_name('scale_const'),
                                 np.array(scaling_vector, np.float32)).output[0]

    scaled_boxes = ctx.make_node('Div', [box_encodings, scale_const]).output[0]
    anchors_yx = GraphBuilder(ctx).make_slice({'data': anchors, 'starts': [0], 'ends': [2], 'axes': [1]})
    anchors_hw = GraphBuilder(ctx).make_slice({'data': anchors, 'starts': [2], 'ends': [4], 'axes': [1]})
    boxes_yx = GraphBuilder(ctx).make_slice({'data': scaled_boxes, 'starts': [0], 'ends': [2], 'axes': [2]})
    boxes_hw = GraphBuilder(ctx).make_slice({'data': scaled_boxes, 'starts': [2], 'ends': [4], 'axes': [2]})

    scaled_boxes_yx = ctx.make_node('Mul', [boxes_yx, anchors_hw]).output[0]
    boxes_hw_exp = ctx.make_node('Exp', [boxes_hw]).output[0]
    scaled_boxes_hw = ctx.make_node('Mul', [boxes_hw_exp, anchors_hw]).output[0]
    const_half = ctx.make_const(utils.make_name('const_half'), np.array(0.5, np.float32)).output[0]
    boxes_half_hw = ctx.make_node('Mul', [scaled_boxes_hw, const_half]).output[0]
    boxes_center_yx = ctx.make_node('Add', [scaled_boxes_yx, anchors_yx]).output[0]

    boxes_lower_left = ctx.make_node('Sub', [boxes_center_yx, boxes_half_hw]).output[0]
    boxes_upper_right = ctx.make_node('Add', [boxes_center_yx, boxes_half_hw]).output[0]
    adjusted_boxes = ctx.make_node('Concat', [boxes_lower_left, boxes_upper_right],
                                   attr={'axis': 2}).output[0]

    iou_threshold = np.array(node.get_attr_value('nms_iou_threshold'), np.float32)
    iou_threshold_const = ctx.make_const(utils.make_name('iou_threshold'), iou_threshold).output[0]

    score_threshold = np.array(node.get_attr_value('nms_score_threshold'), np.float32)
    score_threshold_const = ctx.make_const(utils.make_name('score_threshold'), score_threshold).output[0]

    if node.get_attr_value('use_regular_nms', False):
        boxes_per_class = np.array(node.get_attr_value('detections_per_class', 100), np.int64)
    else:
        # When tflite uses FastNMS, detections_per_class is ignored.
        logging.warning("NMS node %s uses fast NMS. ONNX will approximate with standard NMS.", node.name)
        boxes_per_class = np.array(max_detections, np.int64)
    max_boxes_per_class_const = ctx.make_const(utils.make_name('max_boxes_per_class'),
                                               boxes_per_class).output[0]

    # scores.shape = [batch_dim, classes_num, box_num]
    scores = ctx.make_node('Transpose', [class_predictions], attr={'perm': [0, 2, 1]}).output[0]

    nms_inputs = [adjusted_boxes, scores, max_boxes_per_class_const, iou_threshold_const,
                  score_threshold_const]
    # shape: [-1, 3], elts of format [batch_index, class_index, box_index]
    selected_indices = ctx.make_node('NonMaxSuppression', nms_inputs, attr={'center_point_box': 0},
                                     op_name_scope=node.name).output[0]

    selected_boxes_idx = GraphBuilder(ctx).make_slice(
        {'data': selected_indices, 'starts': [2], 'ends': [3], 'axes': [1]})
    selected_boxes_idx_sq = GraphBuilder(ctx).make_squeeze({'data': selected_boxes_idx, 'axes': [1]})
    selected_classes = GraphBuilder(ctx).make_slice(
        {'data': selected_indices, 'starts': [1], 'ends': [2], 'axes': [1]})
    selected_classes_sq = GraphBuilder(ctx).make_squeeze({'data': selected_classes, 'axes': [1]})
    box_and_class_idx = ctx.make_node('Concat', [selected_boxes_idx, selected_classes],
                                      attr={'axis': 1}).output[0]

    box_cnt = ctx.make_node('Shape', [selected_classes_sq]).output[0]

    adjusted_boxes_sq = GraphBuilder(ctx).make_squeeze({'data': adjusted_boxes, 'axes': [0]})
    detection_boxes = ctx.make_node('Gather', [adjusted_boxes_sq, selected_boxes_idx_sq]).output[0]
    class_predictions_sq = GraphBuilder(ctx).make_squeeze({'data': class_predictions, 'axes': [0]})
    detection_scores = ctx.make_node('GatherND', [class_predictions_sq, box_and_class_idx]).output[0]

    k_const = ctx.make_const(utils.make_name('const_k'), np.array([max_detections], np.int64)).output[0]
    if ctx.opset >= 12:
        min_k = ctx.make_node('Min', [k_const, box_cnt]).output[0]
    else:
        # Lower opsets only support Min between floats
        box_cnt_float = ctx.make_node('Cast', [box_cnt], attr={'to': TensorProto.FLOAT}).output[0]
        k_const_float = ctx.make_node('Cast', [k_const], attr={'to': TensorProto.FLOAT}).output[0]
        min_k_float = ctx.make_node('Min', [k_const_float, box_cnt_float]).output[0]
        min_k = ctx.make_node('Cast', [min_k_float], attr={'to': TensorProto.INT64}).output[0]
    min_k_cast = ctx.make_node('Cast', [min_k], attr={'to': box_cnt_dtype}).output[0]

    scores_top_k, scores_top_k_idx = ctx.make_node('TopK', [detection_scores, min_k],
                                                   output_count=2).output

    scores_top_k_idx_unsq = GraphBuilder(ctx).make_unsqueeze({'data': scores_top_k_idx, 'axes': [0]})
    scores_top_k_unsq = GraphBuilder(ctx).make_unsqueeze({'data': scores_top_k, 'axes': [0]})

    selected_classes_sort = ctx.make_node('Gather', [selected_classes_sq, scores_top_k_idx_unsq]).output[0]
    classes_sort_cast = ctx.make_node('Cast', [selected_classes_sort], attr={'to': classes_dtype}).output[0]
    detection_boxes_sorted = ctx.make_node('Gather', [detection_boxes, scores_top_k_idx_unsq]).output[0]

    pad_amount = ctx.make_node('Sub', [k_const, min_k]).output[0]
    quad_zero_const = ctx.make_const(utils.make_name('quad_zero_const'),
                                     np.array([0, 0, 0, 0], np.int64)).output[0]
    duo_zero_const = ctx.make_const(utils.make_name('duo_zero_const'),
                                    np.array([0, 0], np.int64)).output[0]
    zero_const = ctx.make_const(utils.make_name('zero_const'), np.array([0], np.int64)).output[0]

    pads_3d = ctx.make_node('Concat', [quad_zero_const, pad_amount, zero_const], attr={'axis': 0}).output[0]
    pads_2d = ctx.make_node('Concat', [duo_zero_const, zero_const, pad_amount], attr={'axis': 0}).output[0]

    detection_boxes_padded = ctx.make_node('Pad', [detection_boxes_sorted, pads_3d]).output[0]
    detection_classes_padded = ctx.make_node('Pad', [classes_sort_cast, pads_2d]).output[0]
    detection_scores_padded = ctx.make_node('Pad', [scores_top_k_unsq, pads_2d]).output[0]

    ctx.replace_all_inputs(node.output[0], detection_boxes_padded)
    ctx.replace_all_inputs(node.output[1], detection_classes_padded)
    ctx.replace_all_inputs(node.output[2], detection_scores_padded)
    ctx.replace_all_inputs(node.output[3], min_k_cast)
    ctx.remove_node(node.name)
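# NOTE: numpy paraphrase (illustrative only, made-up values) of the anchor-relative box decoding
# performed by the Div/Mul/Exp/Add/Sub chain above, for a single anchor and encoding.
import numpy as np

anchor_yx, anchor_hw = np.array([0.5, 0.5]), np.array([0.2, 0.4])
enc_yx, enc_hw = np.array([0.1, -0.05]), np.array([0.0, 0.2])   # already divided by the scale vector

center_yx = enc_yx * anchor_hw + anchor_yx    # scaled_boxes_yx * anchors_hw + anchors_yx
size_hw = np.exp(enc_hw) * anchor_hw          # Exp(boxes_hw) * anchors_hw
box_min = center_yx - size_hw / 2             # boxes_lower_left
box_max = center_yx + size_hw / 2             # boxes_upper_right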
def rewrite(self, context):
    logger.debug("enter rewrite function")
    loop_node = None
    try:
        loop_props = context.loop_properties
        cell_g_info = context.cell_graph
        cond_g_info = context.cond_graph

        # create a dummy loop to calculate the init condition
        init_cond_output = self._create_subgraph_initial_cond(cond_g_info)

        ## create Loop body graph with existing nodes
        body_nodes = set(cell_g_info.nodes + cond_g_info.nodes)
        body_outputs = cond_g_info.outputs + cell_g_info.outputs
        for out_tensor_value_info in body_outputs:
            shape = out_tensor_value_info.shape
            utils.make_sure(
                shape is not None,
                "Conversion of Loop requires output shape [{}] exists".format(out_tensor_value_info.id)
            )
            out_tensor_value_info.shape = utils.create_vague_shape_like(shape)

        loop_body_g = LoopRewriterBase.construct_graph_from_nodes(self.g, body_nodes, body_outputs)

        # create loop body graph inputs
        loop_body_g.add_graph_input(utils.make_name("i"), TensorProto.INT64, ())
        loop_body_g.add_graph_input(utils.make_name("cond"), TensorProto.BOOL, ())
        for i, tensor_value_info in enumerate(loop_props.state_inputs):
            input_name = tensor_value_info.id
            if input_name is None:
                # if the variable is not used in the body graph, then we create a fake one
                # with the same type and shape as its corresponding output.
                out_tensor_value_info = loop_props.state_outputs[i]
                dtype = out_tensor_value_info.dtype
                shape = out_tensor_value_info.shape
                input_name = utils.make_name("unused_state_input_")
            else:
                dtype = tensor_value_info.dtype
                shape = tensor_value_info.shape
            loop_body_g.add_graph_input(input_name, dtype, utils.create_vague_shape_like(shape))

        for input_ta in loop_props.tensor_array_inputs:
            # Loop does not have scan inputs, so we use Gather to get data for each iteration.
            gb = GraphBuilder(loop_body_g)
            index_node = gb.make_unsqueeze({'data': input_ta.index_input_id, "axes": [0]}, return_node=True)
            gather_node = loop_body_g.make_node("Gather", [input_ta.data_input_id, index_node.output[0]])
            data_node = gb.make_squeeze({'data': gather_node.output[0], "axes": [0]}, return_node=True)
            loop_body_g.replace_all_inputs(input_ta.consumer.id, data_node.output[0])  # ops=loop_body_g.get_nodes()

        ## create Loop node
        branches = {"body": loop_body_g}
        loop_node = self._create_loop_node(context, loop_props, init_cond_output, branches=branches)
        if not loop_node:
            logger.error("failed to create loop node during rewrite")
            return REWRITER_RESULT.FAIL

        logger.debug("rewrite successfully")
        return REWRITER_RESULT.OK
    except Exception as ex:
        tb = traceback.format_exc()
        logger.error("loop rewrite failed, due to exception: %s, details:%s", ex, tb)
        return REWRITER_RESULT.FAIL
def version_10(cls, ctx, node, **kwargs):
    x = node.input[0]
    x_shape = ctx.get_shape(x)
    h = node.input[1]
    h_shape = ctx.get_shape(h)
    p = node.input[3]
    utils.make_sure(node.attr["rnn_mode"].s == b"gru",
                    "rnn mode other than gru are not supported yet")
    utils.make_sure(node.attr["dropout"].f == 0, "dropout not supported yet")
    utils.make_sure(node.attr["input_mode"].s == b"linear_input",
                    "input mode must be linear input")
    num_dirs = 1 if node.attr["direction"].s == b"unidirectional" else 2
    num_layers = int(h_shape[0] / num_dirs)
    num_units = hidden_size = h_shape[2]
    input_size = x_shape[2]
    w_shape = [num_layers * num_dirs, 3 * hidden_size, input_size]
    w_shape_const = ctx.make_const(utils.make_name("w_shape"), np.array(w_shape, dtype=np.int64))
    r_shape = [num_layers * num_dirs, 3 * hidden_size, hidden_size]
    r_shape_const = ctx.make_const(utils.make_name("r_shape"), np.array(r_shape, dtype=np.int64))
    b_shape = [num_layers * num_dirs, 6 * hidden_size]
    b_shape_const = ctx.make_const(utils.make_name("b_shape"), np.array(b_shape, dtype=np.int64))
    zero_const = ctx.make_const(utils.make_name("zero"), np.array([0], dtype=np.int64))
    w_end = np.prod(w_shape)
    w_end_const = ctx.make_const(utils.make_name("w_end"), np.array([w_end], dtype=np.int64))
    r_end = w_end + np.prod(r_shape)
    r_end_const = ctx.make_const(utils.make_name("r_end"), np.array([r_end], dtype=np.int64))
    b_end = r_end + np.prod(b_shape)
    b_end_const = ctx.make_const(utils.make_name("b_end"), np.array([b_end], dtype=np.int64))

    def name(nm):
        return node.name + "_" + nm

    ws = [name('W_' + str(i)) for i in range(num_layers * num_dirs)]
    rs = [name('R_' + str(i)) for i in range(num_layers * num_dirs)]
    bs = [name('B_' + str(i)) for i in range(num_layers * num_dirs)]
    hs = [name('H_' + str(i)) for i in range(num_layers * num_dirs)]
    yhs = [name('YH_' + str(i)) for i in range(num_layers * num_dirs)]

    w_flattened = ctx.make_node('Slice', [p, zero_const.output[0], w_end_const.output[0]])
    r_flattened = ctx.make_node('Slice', [p, w_end_const.output[0], r_end_const.output[0]])
    b_flattened = ctx.make_node('Slice', [p, r_end_const.output[0], b_end_const.output[0]])
    w = utils.make_name('W')
    r = utils.make_name('R')
    b = utils.make_name('B')
    ctx.make_node('Reshape', [w_flattened.output[0], w_shape_const.output[0]], outputs=[w])
    ctx.make_node('Reshape', [r_flattened.output[0], r_shape_const.output[0]], outputs=[r])
    ctx.make_node('Reshape', [b_flattened.output[0], b_shape_const.output[0]], outputs=[b])
    ctx.make_node('Split', [w], outputs=ws)
    ctx.make_node('Split', [r], outputs=rs)
    ctx.make_node('Split', [b], outputs=bs)
    ctx.make_node('Split', [h], outputs=hs)

    builder = GraphBuilder(ctx)
    xnf = xnb = x
    for i in range(num_layers):
        suffix = '_' + str(i * num_dirs)
        ctx.make_node('GRU',
                      [xnf, name('W' + suffix), name('R' + suffix), name('B' + suffix), '', name('H' + suffix)],
                      outputs=[name('Y' + suffix), name('YH' + suffix)],
                      attr={'direction': 'forward', 'hidden_size': num_units})
        xnf = name(x + suffix)
        builder.make_squeeze({'data': name('Y' + suffix), 'outputs': [xnf], 'axes': [1]})
        if num_dirs == 2:
            suffix = '_' + str(i * 2 + 1)
            ctx.make_node('GRU',
                          [xnb, name('W' + suffix), name('R' + suffix), name('B' + suffix), '', name('H' + suffix)],
                          outputs=[name('Y' + suffix), name('YH' + suffix)],
                          attr={'direction': 'reverse', 'hidden_size': num_units})
            xnb = name(x + suffix)
            builder.make_squeeze({'data': name('Y' + suffix), 'outputs': [xnb], 'axes': [1]})

    ctx.remove_node(node.name)
    if num_dirs == 2:
        ctx.make_node('Concat', [xnf, xnb], outputs=[node.output[0]], attr={'axis': -1})
    else:
        ctx.make_node('Identity', [xnf], outputs=[node.output[0]])
    ctx.make_node('Concat', yhs, outputs=[node.output[1]], attr={'axis': 0})
def version_7(cls, ctx, node, **kwargs):
    tfl_while_inputs = node.input
    output_shapes = node.output_shapes
    output_dtypes = node.output_dtypes
    output_names = node.output

    cond_name = node.get_attr_str("cond_subgraph_index")
    cond_graph = find_function(cond_name)
    cond_graph.parent_graph = ctx

    body_name = node.get_attr_str("body_subgraph_index")
    body = find_function(body_name)
    body.parent_graph = ctx

    ctx.remove_node(node.name)

    cond_binding = parameter_binding(cond_graph, tfl_while_inputs)
    cond_outputs = inline_subgraph(ctx, cond_graph, cond_name, cond_binding)

    # Potential scan output candidates are identified in the body subgraph using tfl_scan_output_rewriter.
    # They can then be optimized in this tfl loop handler provided they are not used in the cond subgraph.
    scan_outputs = sorted(body.scan_outputs, reverse=True)

    def input_is_unused(g, index):
        return len(g.find_output_consumers(g.inputs[index])) == 0

    scan_outputs = [(i, out) for i, out in scan_outputs if input_is_unused(cond_graph, i)]

    for idx, _ in scan_outputs:
        del tfl_while_inputs[idx]
        output_shapes.append(output_shapes.pop(idx))
        output_dtypes.append(output_dtypes.pop(idx))
        output_names.append(output_names.pop(idx))

    max_iterations = ctx.make_const(utils.make_name("max_iterations"),
                                    np.array(np.iinfo(np.int64).max))

    loop_node = ctx.make_node("Loop", [max_iterations.output[0], cond_outputs[0]] + tfl_while_inputs,
                              output_count=len(output_shapes), name=node.name + "_loop",
                              shapes=output_shapes, dtypes=output_dtypes, skip_conversion=True)

    output_map = dict(zip(output_names, loop_node.output))

    # shift output consumers
    for k, v in output_map.items():
        ctx.replace_all_inputs(k, v)  # ops=ctx.get_nodes()

    body = wire_tfl_while_body(body, loop_node.inputs, output_shapes, output_dtypes,
                               cond_graph, scan_outputs)

    for i in range(len(scan_outputs)):
        squeeze_node = GraphBuilder(body).make_squeeze({'data': body.outputs[-1 - i], "axes": [0]},
                                                       return_node=True)
        body.outputs[-1 - i] = squeeze_node.output[0]

    loop_node.set_body_graph_as_attr("body", body)
def rewrite_eye(g, ops):
    # the schema of eye is eye(num_rows, num_columns=None); if num_columns is not specified it defaults to num_rows
    # tf.eye is implemented by a sub-graph which contains op "MatrixDiag" or "MatrixSetDiag", and
    # these two ops are not supported directly in onnx,
    # but the onnx op EyeLike can be used to map the sub-graph
    # "rewrite_eye" supports tf.eye(non_const) and tf.eye(non_const1, non_const2).
    # tf.eye(const) and tf.eye(const1, const2) are not supported in this rewriter
    # ConstantOfShape from opset 9 is used, so if the opset is less than 9 do nothing
    if g.opset < 9:
        return g.get_nodes()

    pattern1 = \
        OpTypePattern("MatrixDiag", name="output_eye_matrix", inputs=[
            OpTypePattern("Fill", inputs=[
                OpTypePattern("Const", name="fill_value"),
                OpTypePattern("ConcatV2", inputs=[
                    "*",
                    "*",
                    OpTypePattern("Pack", inputs=[
                        OpTypePattern("Minimum|Cast", name="min_or_cast")
                    ])
                ])
            ])
        ])
    pattern2 = \
        OpTypePattern("MatrixSetDiag", name="output_eye_matrix", inputs=[
            OpTypePattern("Fill"),
            OpTypePattern("Fill", inputs=[
                OpTypePattern("Const", name="fill_value"),
                OpTypePattern("ConcatV2", inputs=[
                    "*",
                    "*",
                    OpTypePattern("Pack", inputs=[
                        OpTypePattern("Minimum|Cast", name="min_or_cast")
                    ])
                ])
            ])
        ])
    pattern3 = \
        OpTypePattern("MatrixDiag", name="output_eye_matrix", inputs=[
            OpTypePattern("Fill", inputs=[
                OpTypePattern("ConcatV2", inputs=[
                    "*",
                    OpTypePattern("ExpandDims", inputs=[
                        OpTypePattern("Minimum|Cast", name="min_or_cast"),
                        "*"
                    ]),
                    "*",
                ]),
                OpTypePattern("Const", name="fill_value"),
            ])
        ])
    pattern4 = \
        OpTypePattern("MatrixSetDiag", name="output_eye_matrix", inputs=[
            OpTypePattern("Fill"),
            OpTypePattern("Fill", inputs=[
                OpTypePattern("ConcatV2", inputs=[
                    "*",
                    OpTypePattern("ExpandDims", inputs=[
                        OpTypePattern("Minimum|Cast", name="min_or_cast"),
                        "*"
                    ]),
                    "*",
                ]),
                OpTypePattern("Const", name="fill_value"),
            ]),
        ])
    pattern5 = \
        OpTypePattern("MatrixDiagV3", name="output_eye_matrix", inputs=[
            OpTypePattern("Fill", inputs=[
                OpTypePattern("ConcatV2", inputs=[
                    "*",
                    OpTypePattern("ExpandDims", inputs=[
                        OpTypePattern("Minimum|Cast", name="min_or_cast"),
                        "*"
                    ]),
                    "*",
                ]),
                OpTypePattern("Const", name="fill_value"),
            ]),
            "*", "*", "*", "*",
        ])
    pattern6 = \
        OpTypePattern("MatrixSetDiagV3", name="output_eye_matrix", inputs=[
            OpTypePattern("Fill"),
            OpTypePattern("Fill", inputs=[
                OpTypePattern("ConcatV2", inputs=[
                    "*",
                    OpTypePattern("ExpandDims", inputs=[
                        OpTypePattern("Minimum|Cast", name="min_or_cast"),
                        "*"
                    ]),
                    "*",
                ]),
                OpTypePattern("Const", name="fill_value"),
            ]),
            "*"
        ])
    pattern7 = \
        OpTypePattern("MatrixDiag", name="output_eye_matrix", inputs=[
            OpTypePattern("Fill", inputs=[
                OpTypePattern("Reshape", inputs=[
                    OpTypePattern("Minimum|Cast", name="min_or_cast"),
                    "*",
                ]),
                OpTypePattern("Const", name="fill_value"),
            ])
        ])
    pattern8 = \
        OpTypePattern("MatrixSetDiag", name="output_eye_matrix", inputs=[
            OpTypePattern("Fill"),
            OpTypePattern("Fill", inputs=[
                OpTypePattern("Reshape", inputs=[
                    OpTypePattern("Minimum|Cast", name="min_or_cast"),
                    "*",
                ]),
                OpTypePattern("Const", name="fill_value"),
            ])
        ])

    for pattern in [pattern1, pattern2, pattern3, pattern4, pattern5, pattern6, pattern7, pattern8]:
        matcher = GraphMatcher(pattern, allow_reorder=True)
        match_results = list(matcher.match_ops(ops))
        for match_result in match_results:
            if match_result.get_op("fill_value").get_tensor_value() != 1:
                continue

            min_or_cast = match_result.get_op("min_or_cast")
            if min_or_cast.type == "Minimum":
                min_node = min_or_cast
            elif min_or_cast.type == "Cast" and min_or_cast.inputs[0].type == "Minimum":
                min_node = min_or_cast.inputs[0]
            else:
                continue

            num_rows = min_node.inputs[0]
            num_columns = min_node.inputs[1]

            old_output = match_result.get_op("output_eye_matrix")
            output_dtypes = [g.get_dtype(old_output.output[0])]
            output_shapes = [g.get_shape(old_output.output[0])]
            g.remove_node(old_output.name)

            # onnx op "EyeLike" needs a 2D tensor, so generate it
            num_rows = GraphBuilder(g).make_unsqueeze({"axes": [0], "data": num_rows.output[0]},
                                                      return_node=True)
            num_columns = GraphBuilder(g).make_unsqueeze({"axes": [0], "data": num_columns.output[0]},
                                                         return_node=True)
            matrix_shape = g.make_node("Concat", [num_rows.output[0], num_columns.output[0]],
                                       attr={"axis": 0})
            # cast nodes added because "ConstantOfShape" in ONNX only accepts int64 data.
            matrix_shape_int64 = g.make_node("Cast", matrix_shape.output,
                                             attr={"to": onnx_pb.TensorProto.INT64})
            zero_matrix = g.make_node("ConstantOfShape", matrix_shape_int64.output)

            g.make_node("EyeLike", zero_matrix.output, attr={"dtype": output_dtypes[0]},
                        name=old_output.name, shapes=output_shapes, dtypes=output_dtypes,
                        outputs=old_output.output)

    return g.get_nodes()
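# NOTE: illustrative numpy sketch (not part of the rewriter) of the ConstantOfShape + EyeLike
# replacement above: an all-zero [num_rows, num_columns] matrix fed to EyeLike yields tf.eye semantics.
import numpy as np

num_rows, num_columns = 3, 5
zero_matrix = np.zeros((num_rows, num_columns))       # ConstantOfShape output
eye_like = np.eye(*zero_matrix.shape)                 # EyeLike: ones on the main diagonal
assert np.array_equal(eye_like, np.eye(num_rows, num_columns))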