def version_6(cls, ctx, node, **kwargs):
    """Convert TF FloorDiv to ONNX Div, adding a Floor for floating-point inputs.

    TF: T output = FloorDiv(T x, T y)

    The node is retyped to "Div". When input 0 is float16/float/double, a
    Floor node is inserted on the Div output, and that output's dtype and
    shape are copied onto the Floor result. Other dtypes are left as a
    plain Div.
    """
    node.type = "Div"
    floating_types = (
        onnx_pb.TensorProto.FLOAT,
        onnx_pb.TensorProto.FLOAT16,
        onnx_pb.TensorProto.DOUBLE,
    )
    if ctx.get_dtype(node.input[0]) not in floating_types:
        return

    floor_name = utils.make_name("floor_div_res")
    floor_node = ctx.insert_new_node_on_output(
        op_type="Floor", output_name=node.output[0], name=floor_name)
    ctx.copy_dtype(node.output[0], floor_node.output[0])
    ctx.copy_shape(node.output[0], floor_node.output[0])
def process_var_init_nodes(self, context):
    """Provide the "state" variable's initial value with an extra leading axis of size 1.

    Case 1: the initializer (looked up via
    ``context.state_variables["state"].enter_input_id``) is constant. A new
    constant with ``np.expand_dims(value, axis=0)`` is created.

    Case 2: otherwise, an ``Unsqueeze(axes=[0])`` node is added. Every other
    node in the graph that consumed the initializer is rewired to the
    Unsqueeze output.

    In both cases, the resulting tensor name is stored in
    ``context.onnx_input_ids["initial_state"]``.
    """
    assert "state" in context.state_variables.keys()
    init_id = context.state_variables["state"].enter_input_id
    init_node = self.g.get_node_by_output(init_id)

    if not init_node.is_const():
        unsqueeze_node = self.g.make_node("Unsqueeze", [init_id], attr={"axes": [0]})
        # Exclude the new node itself so it keeps reading the original tensor.
        consumers = [n for n in self.g.get_nodes() if n != unsqueeze_node]
        self.g.replace_all_inputs(init_id, unsqueeze_node.output[0], ops=consumers)
        context.onnx_input_ids["initial_state"] = unsqueeze_node.output[0]
        return

    value = init_node.get_tensor_value(as_list=False)
    const_name = utils.make_name("Const")
    const_node = self.g.make_const(const_name, np.expand_dims(value, axis=0))
    context.onnx_input_ids["initial_state"] = const_node.output[0]
def make_range_const(ctx, start, limit, delta, output, scope_name, shape, dtype):
    """Make Range subgraph if all inputs are const.

    T range = Range(T start, T limit, T delta)

    The function assumes that ``start``, ``limit`` and ``delta`` are outputs of
    constant nodes; it does not check this itself. Their values are evaluated
    eagerly with ``np.arange`` (using start's dtype) and stored in one Const
    node. An Identity node then exposes that Const as ``output`` with the
    given ``shape`` and ``dtype``.
    """
    base_name = utils.make_name(scope_name)
    start_val, limit_val, delta_val = [
        ctx.get_node_by_output(tensor_name).get_tensor_value(as_list=False)
        for tensor_name in (start, limit, delta)
    ]
    range_values = np.arange(start_val, limit_val, delta_val, dtype=start_val.dtype)
    range_const = ctx.make_const(base_name, range_values)
    ctx.make_node("Identity", [range_const.output[0]], shapes=[shape], dtypes=[dtype], outputs=[output])
def version_1(cls, ctx, node, **kwargs):
    """Convert the node to the contrib-domain ``StringSplit`` op.

    ``skip_empty`` is chosen as follows:
    - If the node is already a StringSplit, it comes from the node's
      'skip_empty' attribute (default True).
    - Otherwise it is False.

    All existing attributes are dropped. The inputs become::

        (input[0], Unsqueeze(input[1], axes=[0]), Const([skip_empty]))

    (input[1] is presumably the delimiter — confirm against the TF op.)
    """
    if node.type == "StringSplit":
        skip_empty = node.get_attr_value('skip_empty', True)
    else:
        skip_empty = False
    node.type = "StringSplit"
    node.domain = constants.CONTRIB_OPS_DOMAIN
    # skip_empty is passed as the third input below rather than as an
    # attribute, so clear every attribute the TF node carried.
    for attr_name in list(node.attr.keys()):
        del node.attr[attr_name]
    unsqueeze_node = ctx.make_node("Unsqueeze", [node.input[1]], attr={'axes': [0]})
    # np.bool (a deprecated alias of bool) was removed in NumPy 1.24; np.bool_
    # is the NumPy boolean scalar type and works on all NumPy versions.
    skip_empty_const = ctx.make_const(utils.make_name('skip_empty_const'), np.array([skip_empty], np.bool_))
    ctx.replace_inputs(node, [
        node.input[0], unsqueeze_node.output[0], skip_empty_const.output[0]
    ])
def test_override_shape(self):
    """Re-infer a Transpose output whose shape was pre-seeded inconsistently, then run it.

    The input is [1, 3, 4, 1] and the permutation is [1, 0, 2, 3]. The
    output shape is pre-set to [-1, -1, 2, 3]. The node's shape/dtype is then
    updated with ``override=True`` before the graph is executed.
    """
    inputs, shapes, dtypes = [INPUT1], [[1, 3, 4, 1]], [TensorProto.FLOAT]
    graph = self._create_empty_graph(inputs, shapes, dtypes)

    out_name = utils.make_name("output")
    graph._output_shapes[out_name] = [-1, -1, 2, 3]  # pylint: disable=protected-access
    transpose = graph.make_node("Transpose", [INPUT1], attr={"perm": [1, 0, 2, 3]}, outputs=[out_name])
    graph.update_node_shape_dtype(transpose, override=True)
    graph.add_graph_output(transpose.output[0])

    feed = self._generate_random_inputs(inputs, shapes, dtypes)
    self._run_test_case(graph, feed)
def connect_initializer_node(self, initializer_input_id, hidden_size):
    """Return a tensor name holding the initializer with an extra leading axis of size 1.

    Case 1: a constant initializer is re-emitted as a new Const with
    ``np.expand_dims(value, axis=0)``.

    Case 2: otherwise, an ``Unsqueeze(axes=[0])`` node is created. It
    replaces the initializer as an input across ``self.g.get_nodes()`` and is
    appended to ``self.all_nodes``.

    ``hidden_size`` is accepted but not used.
    """
    init_node = self.g.get_node_by_name(initializer_input_id)
    if not init_node.is_const():
        unsqueeze = make_onnx_node(self.g, "Unsqueeze", [initializer_input_id], attr={"axes": [0]})
        self.g.replace_all_inputs(self.g.get_nodes(), initializer_input_id, unsqueeze.output[0])
        self.all_nodes.append(unsqueeze)
        return unsqueeze.output[0]

    expanded = np.expand_dims(init_node.get_tensor_value(), axis=0)
    const_node = self.g.make_const(utils.make_name("Const"), expanded)
    return const_node.output[0]
def create_if_op(input_ids, output_data_type, output_shape):
    """Build an ONNX ``If`` NodeProto that selects between two branch inputs.

    ``input_ids[0]`` is the condition. ``input_ids[1]`` and ``input_ids[2]``
    are wired into the then and else body graphs respectively.

    Returns:
        ``(if_node, output_name)``. The node has a single output (the
        original code notes it as a scalar).
    """
    op_name = utils.make_name("If")
    then_graph, else_graph = [
        create_body_graph_for_if_branch(output_data_type, output_shape, branch_input, op_name)
        for branch_input in (input_ids[1], input_ids[2])
    ]
    out_name = port_name(op_name)
    if_node = helper.make_node(
        "If", [input_ids[0]], [out_name], name=op_name,
        then_branch=then_graph, else_branch=else_graph)
    return if_node, out_name
def create_if_op(ctx, node, cur_cond_val_out_name):
    """Build an ONNX ``If`` NodeProto over ``node``'s two value inputs.

    ``cur_cond_val_out_name`` is the condition. ``node.input[1]`` and
    ``node.input[2]`` feed the then and else body graphs. Both body graphs
    receive the shape from ``get_hidden_size_best_effort(ctx, node)``.

    Returns:
        ``(if_node, output_name)``. The node has a single output (the
        original code notes it as a scalar).
    """
    data_shape = get_hidden_size_best_effort(ctx, node)
    then_graph, else_graph = [
        create_body_graph_for_if_branch(ctx, branch_input, data_shape)
        for branch_input in (node.input[1], node.input[2])
    ]
    op_name = utils.make_name("If")
    out_name = port_name(op_name)
    if_node = helper.make_node(
        "If", [cur_cond_val_out_name], [out_name], name=op_name,
        then_branch=then_graph, else_branch=else_graph)
    return if_node, out_name
def version_1(cls, ctx, node, **kwargs):
    """Convert string comparison to the contrib-domain ``StringEqual`` op.

    Non-string inputs are delegated to the ONNX-domain handler registered for
    the same op type. For string inputs, the node becomes ``StringEqual``. A
    ``NotEqual`` node additionally gets a ``Not`` inserted on its output,
    with shape and dtype copied from the StringEqual output.
    """
    if ctx.get_dtype(node.input[0]) != TensorProto.STRING:
        # Fallback to normal domain conversion
        onnx_handler, _ = handler.tf_op.find_effective_op(node.type, constants.ONNX_DOMAIN)
        onnx_handler(ctx, node, **kwargs)
        return

    is_not_equal = node.type == "NotEqual"
    node.type = "StringEqual"
    node.domain = constants.CONTRIB_OPS_DOMAIN
    if not is_not_equal:
        return

    eq_output = node.output[0]
    not_node = ctx.insert_new_node_on_output("Not", eq_output, name=utils.make_name(node.name))
    for copier in (ctx.copy_shape, ctx.copy_dtype):
        copier(eq_output, not_node.output[0])
def version_7(cls, ctx, node, **kwargs):
    """Convert TF Fill(dims, value) into ONNX Tile(value, repeats=dims).

    Steps:
    1. ``value`` (input 1; presumably a scalar — TODO confirm) is Unsqueezed
       on axis 0 once per element of ``dims``. Its rank then grows by that
       many dimensions.
    2. ``dims`` (input 0) is cast to INT64 for Tile's repeats.
    3. The two inputs are swapped and the node is retyped to "Tile".

    When ``ctx.opset < 9`` and ``value`` is not FLOAT, the value is cast to
    FLOAT before tiling and the Tile output is cast back to the original
    dtype. The reason is not visible here; presumably a dtype restriction
    before opset 9.

    Fails via utils.make_sure if the shape of ``dims`` is unknown or its
    length is not positive.
    """
    # T output = Fill(int32 dims, T value, @int32 index_type)
    # T outputs = Tile(T value, int64 repeats (e.g. dims))
    fill_shape = ctx.get_shape(node.input[0])
    utils.make_sure(fill_shape is not None, "shape of {} is None".format(node.input[0]))
    # Number of elements in the 1-D dims tensor, i.e. how many axes to add.
    fill_shape_dims = fill_shape[0]
    utils.make_sure(fill_shape_dims > 0, "opset 7 requires fill shape length > 0, or please try opset > 7")
    val_dtype = ctx.get_dtype(node.input[1])
    val_shape = ctx.get_shape(node.input[1])

    need_cast = val_dtype != onnx_pb.TensorProto.FLOAT and ctx.opset < 9
    new_dtype = val_dtype
    if need_cast:
        new_dtype = onnx_pb.TensorProto.FLOAT
        attr = {"to": new_dtype}
        cast_to_float = ctx.insert_new_node_on_input(node, "Cast", node.input[1], name=None, **attr)
        ctx.set_dtype(cast_to_float.output[0], new_dtype)
        ctx.set_shape(cast_to_float.output[0], val_shape)

    # Prepend one axis of size 1 per fill dimension. node.input[1] is
    # re-read each iteration because each insertion rewires it to the newest
    # Unsqueeze output.
    for _ in range(fill_shape_dims):
        attr = {"axes": [0]}
        shape = ctx.get_shape(node.input[1])
        unsqueeze_node = ctx.insert_new_node_on_input(node, "Unsqueeze", node.input[1], name=None, **attr)
        ctx.set_dtype(unsqueeze_node.output[0], new_dtype)
        if shape:
            shape = [1] + shape
        else:
            shape = [1]
        ctx.set_shape(unsqueeze_node.output[0], shape)

    # Tile's repeats must be INT64
    attr = {"to": onnx_pb.TensorProto.INT64}
    tile_shape_int64 = ctx.insert_new_node_on_input(node, "Cast", node.input[0], name=None, **attr)
    ctx.set_dtype(tile_shape_int64.output[0], onnx_pb.TensorProto.INT64)
    ctx.set_shape(tile_shape_int64.output[0], fill_shape)

    # Swap inputs: Fill is (dims, value) while Tile is (input, repeats).
    # The explicit indices target each slot, because after the first call
    # both slots briefly hold the same tensor name.
    tmp = node.input[0]
    ctx.replace_input(node, node.input[0], node.input[1], 0)
    ctx.replace_input(node, node.input[1], tmp, 1)
    node.type = "Tile"
    ctx.set_dtype(node.output[0], new_dtype)

    if need_cast:
        # Restore the original value dtype on the Tile result.
        attr = {"to": val_dtype}
        op_name = utils.make_name(node.name + "/cast_back")
        cast_back = ctx.insert_new_node_on_output("Cast", node.output[0], name=op_name, **attr)
        ctx.set_dtype(cast_back.output[0], val_dtype)
def create_graph_from_onnx_graph(graph_proto):
    """Create Graph loading onnx graph proto.

    Steps:
    1. Shapes and dtypes are collected from ``graph_proto.value_info`` first
       and then from ``graph_proto.output``; later entries overwrite earlier
       ones.
    2. ``Constant`` nodes are renamed to ``Const``, and nodes with an empty
       name get a generated one.
    3. Initializers are appended as const nodes, and graph inputs are parsed.
    4. Every graph-valued attribute is converted recursively into a
       sub-graph body.
    """
    output_shapes, output_dtypes = {}, {}
    for value_infos in (graph_proto.value_info, graph_proto.output):
        shapes, dtypes = GraphUtil._parse_shape_and_type_from_value_infos(value_infos)
        output_shapes.update(shapes)
        output_dtypes.update(dtypes)

    nodes_to_append = []
    for onnx_node in graph_proto.node:
        if onnx_node.op_type == "Constant":
            onnx_node.op_type = "Const"
        # some pytorch model had empty names - make one up
        if not onnx_node.name:
            onnx_node.name = utils.make_name("was_empty")
        nodes_to_append.append(onnx_node)

    output_names = [out.name for out in graph_proto.output]
    g = Graph(nodes_to_append, output_shapes, output_dtypes, None, None, None, output_names)

    initializer_consts = GraphUtil._parse_graph_initializer(g, graph_proto)
    all_nodes = g.get_nodes()
    all_nodes.extend(initializer_consts)
    g.set_nodes(all_nodes)

    GraphUtil._parse_graph_input(g, graph_proto)

    for graph_node in g.get_nodes():
        for attr_name, attr_val in graph_node.attr.items():
            if not attr_val.HasField('g'):
                continue
            # Assumes attr_val.g already carries inferred shapes/dtypes.
            sub_g = GraphUtil.create_graph_from_onnx_graph(attr_val.g)
            graph_node.set_body_graph_as_attr(attr_name, sub_g)
    return g
def convert_to_input(self, tensor, is_optional=False):
    """In ONNX, an input must come from a node, so return it as a tensor-name string.

    Behavior:
    - An absent optional input (``None`` with ``is_optional``) yields ``None``.
    - A Python list is materialized as a Const node via ``np.array``, and
      that node's output name is returned.
    - Any other value is returned as-is.

    Fails via utils.make_sure if a required input is ``None`` or the result
    is not a ``str``.
    """
    if is_optional and tensor is None:
        return None
    utils.make_sure(tensor is not None, "input is required so it couldn't be None")

    if isinstance(tensor, list):
        const_node = self.graph.make_const(utils.make_name("const_slice"), np.array(tensor))
        result = const_node.output[0]
    else:
        result = tensor
    utils.make_sure(isinstance(result, str), "input is a dynamic input, so a str is needed")
    return result
def rewrite_tfl_rfft(g, ops):
    """Rewrite TFLite ``ComplexAbs(Reshape(RFFT2D(x, length), shape))`` around an RFFT node.

    A match is rewritten only when both conditions hold:
    - The RFFT2D output shape equals ``shape[:-1] + [1, shape[-1]]``, i.e.
      the Reshape only drops a size-1 second-to-last axis.
    - ``length[0] == 1``.

    For such matches:
    - RFFT2D is retyped to ``RFFT`` and its length input becomes the
      one-element int64 constant ``[length[1]]``.
    - ComplexAbs reads the RFFT output directly.
    - The Reshape is moved so that it consumes the ComplexAbs output.

    Returns:
        The ``ops`` argument, unchanged as an object.
    """
    pattern0 = \
        OpTypePattern('TFL_COMPLEX_ABS', name='complex_abs', inputs=[
            OpTypePattern('TFL_RESHAPE', name='reshape', inputs=[
                OpTypePattern('TFL_RFFT2D', name='rfft2d', inputs=[
                    OpTypePattern('*'),
                    OpTypePattern('Const|ConstV2', name='length'),
                ]),
                OpTypePattern('Const|ConstV2', name='shape'),
            ], allow_reorder=True),
        ])

    matcher = GraphMatcher(pattern0, allow_reorder=False)
    match_results = list(matcher.match_ops(ops))
    if match_results:
        for match in match_results:
            length = match.get_op("length").get_tensor_value(as_list=True)
            rfft2d = match.get_op("rfft2d")
            complex_abs = match.get_op("complex_abs")
            reshape = match.get_op("reshape")
            shape = match.get_op("shape").get_tensor_value(as_list=True)
            output_shape = g.get_shape(rfft2d.output[0])

            # Only handle a Reshape that removes a single size-1 axis at
            # position -2 of the RFFT2D output.
            if output_shape is None or output_shape != shape[:-1] + [
                    1, shape[-1]
            ]:
                continue
            # A first FFT length of 1 lets RFFT2D be treated as a 1-D RFFT
            # over length[1].
            if length[0] != 1:
                continue

            rfft2d.type = "RFFT"
            # complex_abs.input[0] is still the Reshape output here, so the
            # RFFT output takes the reshaped (axis-dropped) shape.
            g.copy_shape(complex_abs.input[0], rfft2d.output[0])
            # Skip the Reshape
            g.replace_input(complex_abs, complex_abs.input[0], rfft2d.output[0], 0)

            new_length = g.make_const(utils.make_name("rfft_length"), np.array([length[1]], np.int64))
            g.replace_input(rfft2d, rfft2d.input[1], new_length.output[0], 1)

            # Redirect ComplexAbs consumers to the Reshape output. This must
            # run before the Reshape is fed from ComplexAbs (next line);
            # otherwise the Reshape would be rewired to consume its own output.
            g.replace_all_inputs(complex_abs.output[0], reshape.output[0])
            # Move reshape below complex abs
            g.replace_input(reshape, reshape.input[0], complex_abs.output[0], 0)

    return ops
def _const_scalar_as_float(node):
    """Return the single value held by constant ``node`` as a float, or None.

    None is returned when ``node`` is not constant or holds more than one element.
    """
    if not node.is_const():
        return None
    value = node.get_tensor_value()
    # Accept either a NumPy value or (nested) Python lists.
    if hasattr(value, "tolist"):
        value = value.tolist()
    while isinstance(value, (list, tuple)):
        if len(value) != 1:
            return None
        value = value[0]
    return float(value)


def rewrite_random_normal(g, ops):
    """Fuse ``Add(Mul(RandomStandardNormal(shape), stddev), mean)`` into ONNX RandomNormal(Like).

    The fused op depends on where the shape comes from:
    - If the RandomStandardNormal shape comes from a ``Shape`` node, a
      ``RandomNormalLike`` on that Shape's input is emitted.
    - Otherwise a ``RandomNormal`` with the Add output's static shape is
      emitted.

    The Mul's second input becomes ``scale`` and the Add's second input
    becomes ``mean``. A match is skipped (left untouched) when either of
    these is not a constant scalar, or when a static shape is needed but
    unknown or dynamic.

    Returns:
        The ``ops`` argument.
    """
    pattern = \
        OpTypePattern('Add', name='output', inputs=[
            OpTypePattern('Mul', name='input2', inputs=[
                OpTypePattern('RandomStandardNormal', name='input1', inputs=["*"]), "*"
            ]), "*"
        ])

    matcher = GraphMatcher(pattern)
    match_results = list(matcher.match_ops(ops))
    for match in match_results:
        output = match.get_op('output')
        mul = match.get_op('input2')
        rn_op = match.get_op('input1')

        # mean/scale are attributes of RandomNormal, so they must be constant
        # scalars. The original code ignored the Mul operand and always used
        # scale=1.0, silently dropping any stddev != 1.
        mean = _const_scalar_as_float(output.inputs[1])
        scale = _const_scalar_as_float(mul.inputs[1])
        if mean is None or scale is None:
            continue

        uses_shape_node = rn_op.inputs[0].type == "Shape"
        shape = None
        if not uses_shape_node:
            shape = g.get_shape(output.output[0])
            # RandomNormal needs a fully static shape attribute.
            if shape is None or -1 in shape:
                continue

        dtype = g.get_dtype(output.output[0])
        op_name = utils.make_name("RandomNormal")
        out_name = utils.port_name(op_name)
        # ONNX declares the RandomNormal(Like) 'seed' attribute as FLOAT; an
        # int here would be emitted as an INT attribute.
        seed = float(rn_op.get_attr('seed2').i)
        attr = {"mean": mean, "scale": scale, "dtype": dtype, "seed": seed}

        if uses_shape_node:
            shape_node = rn_op.inputs[0]
            new_node = g.make_node("RandomNormalLike", [shape_node.input[0]], outputs=[out_name], name=op_name,
                                   attr=attr)
        else:
            attr["shape"] = shape
            new_node = g.make_node("RandomNormal", [], outputs=[out_name], name=op_name, attr=attr)

        g.replace_all_inputs(ops, output.output[0], new_node.output[0])
        g.safe_remove_nodes(match.get_nodes())
    return ops
def version_6(cls, ctx, node, **kwargs):
    """Run ``Log`` on float when its input is double, casting the result back.

    Only ``Log`` nodes are touched. ``node.maybe_cast_input`` is asked to map
    DOUBLE inputs to FLOAT. When it reports a change (truthy result):
    - A Cast back to the original output dtype is inserted on the output,
      with the original output's shape.
    - The Log output's dtype is updated to match its (now float) input.
    """
    if node.type != "Log":
        return
    # ORT doesn't implement Log on doubles
    double_to_float = {onnx_pb.TensorProto.DOUBLE: onnx_pb.TensorProto.FLOAT}
    dtypes = node.output_dtypes
    if not node.maybe_cast_input([[onnx_pb.TensorProto.FLOAT]], double_to_float):
        return
    cast_back_node = ctx.insert_new_node_on_output(
        "Cast", node.output[0], name=utils.make_name(node.name + "_castback"),
        to=dtypes[0])
    ctx.set_dtype(cast_back_node.output[0], dtypes[0])
    # Copy from the Log output tensor. The previous code passed node.name,
    # which names the node rather than a tensor, so no shape was copied.
    ctx.copy_shape(node.output[0], cast_back_node.output[0])
    ctx.copy_dtype(node.input[0], node.output[0])
def shape_op(ctx, node, name, args):
    """Replace a Shape node's consumers with a float32 all-zeros constant.

    The constant's shape is taken from ``node.input[0]``:
    - An unknown (None) or empty shape becomes ``[1]``.
    - An unknown first dimension (None or -1) becomes 1.

    Every consumer input that referenced the Shape output is rewired to the
    constant.

    Returns:
        The new constant node.
    """
    # FIXME - this is not correct
    shape = ctx.get_shape(node.input[0])
    # Work on a copy: get_shape may return the graph's own list, and the
    # edits below must not change the recorded shape of node.input[0].
    # This also handles shape=None, which previously crashed on .append().
    shape = list(shape) if shape else [1]
    if shape[0] is None or shape[0] == -1:
        shape[0] = 1
    old_output = node.output[0]
    node_name = utils.make_name(node.name)
    new_output = node_name + ":0"
    new_node = ctx.make_const(node_name, "Const", np.zeros(shape, dtype=np.float32))
    new_node.output.append(new_output)
    for consumer in ctx.get_nodes():
        # Rewire every occurrence. A node may consume the same tensor more
        # than once, so there is no early break.
        for i, input_name in enumerate(consumer.input):
            if input_name == old_output:
                consumer.input[i] = new_output
    return new_node
def version_4(cls, ctx, node, **kwargs):
    """Convert ZerosLike(x) into ``Mul(x, 0)``, with the zero constant in x's dtype.

    The original node is removed. A Mul with the same name, outputs, shapes
    and dtypes replaces it.

    NOTE(review): under IEEE arithmetic, NaN * 0 and inf * 0 are NaN, so
    such elements of x do not become 0 as they would with TF ZerosLike.
    """
    # T output = ZerosLike(T x)
    # when params "dtype" used, tf will call another op "Fill" instead, so Cast is not needed here.
    np_dtype = utils.map_onnx_to_numpy_type(ctx.get_dtype(node.input[0]))
    zero = ctx.make_const(utils.make_name("zero"), np.array(0).astype(np_dtype))

    out_shapes, out_dtypes = node.output_shapes, node.output_dtypes
    ctx.remove_node(node.name)
    ctx.make_node(op_type="Mul", inputs=[node.input[0], zero.output[0]],
                  name=node.name, outputs=node.output,
                  shapes=out_shapes, dtypes=out_dtypes)
def _make_softmax_cross_entropy_with_logits(ctx, label, logit, tf_ori_node):
    """Emit ``Squeeze(-1 * ReduceSum(label * LogSoftmax(logit), axes=[-1]), axes=[1])``.

    This implements ``tf.multiply(-1, tf.reduce_sum(tf.multiply(label, log_softmax), axis=1))``.

    ``tf_ori_node`` is removed from ``ctx``. The final Squeeze writes its
    first output with its recorded shape and dtype.

    Fails via utils.make_sure if label and logit dtypes differ.
    """
    label_dtype = ctx.get_dtype(label.output[0])
    logit_dtype = ctx.get_dtype(logit.output[0])
    utils.make_sure(label_dtype == logit_dtype, "the following logic only works on same dtype of label and logit")

    log_softmax = ctx.make_node(op_type="LogSoftmax", inputs=logit.output)
    weighted = ctx.make_node(op_type="Mul", inputs=[label.output[0], log_softmax.output[0]])
    summed = ctx.make_node(op_type="ReduceSum", inputs=[weighted.output[0]], attr={"axes": [-1]})
    neg_one = ctx.make_const(name=utils.make_name("const_negative_one"),
                             np_val=np.array(-1).astype(utils.ONNX_TO_NUMPY_DTYPE[logit_dtype]))
    negated = ctx.make_node(op_type="Mul", inputs=[neg_one.output[0], summed.output[0]])

    out_shape = tf_ori_node.output_shapes[0]
    out_dtype = tf_ori_node.output_dtypes[0]
    ctx.remove_node(tf_ori_node.name)
    ctx.make_node(op_type="Squeeze", inputs=[negated.output[0]], attr={"axes": [1]},
                  outputs=[tf_ori_node.output[0]], shapes=[out_shape], dtypes=[out_dtype])
def version_7(cls, ctx, node, **kwargs):
    """Lower a Select-style op (cond ? x : y) to Cast/Not/Mul/Add.

    The result is computed as::

        Add(Mul(Cast(cond, T), x), Mul(y, Cast(Not(cond), T)))

    The original node is then replaced by the Add, which keeps its name and
    outputs.

    For a ``Select`` node with a 1-D condition and inputs of another rank,
    the condition is first reshaped to ``[n, 1, ..., 1]`` so it broadcasts
    along the first axis.

    Fails via utils.make_sure if only the condition is given, if the output
    dtype is STRING, or if that reshape is needed but the input rank is
    unknown.

    NOTE(review): because both branches are multiplied and summed, a NaN/inf
    in the non-selected branch still propagates (inf * 0 is NaN under IEEE).
    """
    # T output = Select(bool condition, T x, T y)
    # Select_res = Add(Multiply(Cast(bool condition, T), T x,),
    #                  Multiply(Cast(Not(bool condition), T), T y)).
    # TODO: Fix case where condition is 1-dimensional
    # NOTE(review): the 1-D condition case is handled below for "Select"
    # when the input rank is known; this TODO may be stale.
    utils.make_sure(
        len(node.input) > 1, "Select with only condition is not supported.")
    dtype = ctx.get_dtype(node.output[0])
    utils.make_sure(dtype != TensorProto.STRING, "Select with dtype string requires opset 9")

    cond_shape = ctx.get_shape(node.input[0])
    # Use x's shape, falling back to y's when x's is unknown.
    input_shape = ctx.get_shape(node.input[1])
    if input_shape is None:
        input_shape = ctx.get_shape(node.input[2])
    input_rank = len(input_shape) if input_shape is not None else None
    cond_rank = len(cond_shape) if cond_shape is not None else None
    # if cond shape is 1-dimensional while input has higher rank, need to be reshaped to broadcast
    if node.type == "Select" and cond_rank == 1 and input_rank != 1:
        utils.make_sure(input_rank is not None, "input_rank unknown and cond_rank == 1")
        broadcast_shape = [cond_shape[0]] + [1] * (input_rank - 1)
        shape_const = ctx.make_const(
            utils.make_name(node.name), np.array(broadcast_shape, dtype=np.int64))
        reshape = ctx.make_node("Reshape", [node.input[0], shape_const.output[0]])
        ctx.replace_input(node, node.input[0], reshape.output[0], 0)

    # Build the arithmetic select: cond-as-T times x, plus not-cond-as-T times y.
    positive_cast = ctx.make_node("Cast", [node.input[0]], name=utils.make_name(node.name),
                                  attr={"to": dtype})
    negative = ctx.make_node("Not", [node.input[0]], name=utils.make_name(node.name))
    negative_cast = ctx.make_node("Cast", [negative.output[0]], name=utils.make_name(node.name),
                                  attr={"to": dtype})
    multiply_1 = ctx.make_node("Mul", [positive_cast.output[0], node.input[1]], name=utils.make_name(node.name))
    multiply_2 = ctx.make_node("Mul", [node.input[2], negative_cast.output[0]], name=utils.make_name(node.name))
    # Capture name/outputs/shape before removal so the Add takes the node's place.
    add_name = node.name
    add_out = node.output
    shape = ctx.get_shape(node.output[0])
    ctx.remove_node(node.name)
    ctx.make_node("Add", [multiply_1.output[0], multiply_2.output[0]], outputs=add_out,
                  name=add_name, dtypes=[dtype], shapes=[shape])
def create_if_op(g, input_ids, output_data_type, output_shape):
    """Add an ``If`` node to graph ``g`` that selects between two branch inputs.

    ``input_ids[0]`` is the condition. ``input_ids[1]`` and ``input_ids[2]``
    are wired into the then and else body graphs, which are attached via
    ``set_body_graph_as_attr``.

    Returns:
        ``(if_node, output_name)``. The node has a single output (the
        original code notes it as a scalar).
    """
    op_name = utils.make_name("If")
    then_graph, else_graph = [
        create_body_graph_for_if_branch(g, output_data_type, output_shape, branch_input, op_name)
        for branch_input in (input_ids[1], input_ids[2])
    ]
    out_name = utils.port_name(op_name)
    if_node = g.make_node("If", [input_ids[0]], outputs=[out_name], name=op_name, skip_conversion=True)
    for attr_name, body in (("then_branch", then_graph), ("else_branch", else_graph)):
        if_node.set_body_graph_as_attr(attr_name, body)
    return if_node, out_name
def create_if_op(g, input_ids, output_data_type, output_shape):
    """Add an ``If`` node to graph ``g`` that selects between two branch inputs.

    ``input_ids[0]`` is the condition. ``input_ids[1]`` and ``input_ids[2]``
    are wired into the then and else body graphs, which are passed to
    ``make_node`` through its ``branches`` argument.

    Returns:
        ``(if_node, output_name)``. The node has a single output (the
        original code notes it as a scalar).
    """
    op_name = utils.make_name("If")
    then_graph, else_graph = [
        create_body_graph_for_if_branch(g, output_data_type, output_shape, branch_input, op_name)
        for branch_input in (input_ids[1], input_ids[2])
    ]
    out_name = utils.port_name(op_name)
    if_node = g.make_node(
        "If", [input_ids[0]], outputs=[out_name], name=op_name, skip_conversion=True,
        branches={"then_branch": then_graph, "else_branch": else_graph})
    return if_node, out_name
def reshape_op5(ctx, node, name, args):
    """Make Reshape's shape input int64, as ONNX requires (TF's may not be).

    Case 1: a constant shape is rewritten in place as an int64 tensor,
    which is also registered in ``ctx._initializers`` under the shape's
    tensor name. The node is returned.

    Case 2: otherwise, a Cast to INT64 is inserted on input 1, and
    ``[cast_op, node]`` is returned.

    The ``name`` and ``args`` parameters are not used.
    """
    shape_node = node.inputs[1]
    shape_input = node.input[1]
    if not shape_node.is_const():
        cast_name = utils.make_name(node.name)
        cast_op = ctx.insert_new_node_on_input(node, "Cast", shape_input, name=cast_name)
        cast_op.set_attr("to", onnx_pb.TensorProto.INT64)
        ctx.copy_shape(shape_input, cast_name + ":0")
        return [cast_op, node]

    int64_shape = np.array(list(shape_node.get_tensor_value()), dtype=np.int64)
    onnx_tensor = numpy_helper.from_array(int64_shape, shape_input)
    ctx._initializers[shape_input] = onnx_tensor
    shape_node.set_attr("value", onnx_tensor)
    return node
def _make_onnx_node(self, operation_type, input_names_with_output_id, attribute=None, output_num=1):
    """Create a uniquely named ONNX node and wrap it in a ``Node`` bound to ``self._g``.

    Outputs are named ``"<op_name>:0"`` through ``"<op_name>:<output_num - 1>"``.
    When ``attribute`` is truthy, its entries are appended to the proto's
    attribute list.
    """
    op_name = utils.make_name(operation_type)
    out_names = ["{}:{}".format(op_name, idx) for idx in range(output_num)]
    onnx_node = helper.make_node(operation_type, input_names_with_output_id, out_names, name=op_name)
    if attribute:
        onnx_node.attribute.extend(attribute)
    return Node(onnx_node, self._g)
def version_1(cls, ctx, node, **kwargs):
    """Replace the node with ``Cast(Mul(Cast(x, INT64), 0), <original output dtype>)``.

    The multiplication by an int64 zero is done in INT64. The final Cast
    keeps the original node's name, outputs, shapes and dtypes.
    """
    out_shapes = node.output_shapes
    out_dtypes = node.output_dtypes
    ctx.remove_node(node.name)

    as_int64 = ctx.make_node("Cast", node.input, attr={'to': onnx_pb.TensorProto.INT64})
    zero_i64 = ctx.make_const(utils.make_name("zero"), np.array(0, dtype=np.int64))
    product = ctx.make_node('Mul', inputs=[as_int64.output[0], zero_i64.output[0]])
    ctx.make_node("Cast", inputs=[product.output[0]], attr={'to': out_dtypes[0]},
                  name=node.name, outputs=node.output, shapes=out_shapes, dtypes=out_dtypes)
def _process_init_nodes(self, initializer_input_id, rnn_props):
    """Return a tensor name for the initial value with an extra leading axis of size 1.

    The initial value is resolved in this order:
    1. The Fill workaround; if it yields a node, that node's output is used.
    2. A constant initializer is re-emitted with
       ``np.expand_dims(value, axis=0)``.
    3. Otherwise an ``Unsqueeze(axes=[0])`` node is created. It replaces the
       initializer as an input across ``self.g.get_nodes()`` and is appended
       to ``self.all_nodes``.
    """
    # copy from lstm_rewriter
    # todo: remove this once Fill ops is supported
    fill_node = self._workaround_fill_ch_init_node(initializer_input_id, rnn_props)
    if fill_node:
        return fill_node.output[0]

    init_node = self.g.get_node_by_name(initializer_input_id)
    if not init_node.is_const():
        unsqueeze = make_onnx_node(self.g, "Unsqueeze", [initializer_input_id], attr={"axes": [0]})
        self.g.replace_all_inputs(self.g.get_nodes(), initializer_input_id, unsqueeze.output[0])
        self.all_nodes.append(unsqueeze)
        return unsqueeze.output[0]

    expanded = np.expand_dims(init_node.get_tensor_value(), axis=0)
    return self.g.make_const(utils.make_name("Const"), expanded).output[0]
def version_10(cls, ctx, node, **kwargs):
    """Convert StringLower (or, otherwise, an upper-casing op) to ONNX StringNormalizer.

    ``case_change_action`` is "LOWER" for StringLower and "UPPER" for any
    other op type.

    Inputs whose rank is not 1 (including unknown rank) are handled as
    follows:
    - The input is flattened with ``Flatten(axis=0)``.
    - The result is reshaped back to the original shape. A constant is used
      when the shape is fully known; otherwise a runtime ``Shape`` of the
      original input is used.
    """
    case_action = "LOWER" if node.type == "StringLower" else "UPPER"
    node.type = "StringNormalizer"
    str_input = node.input[0]
    rank = ctx.get_rank(str_input)
    shape = ctx.get_shape(str_input)
    needs_reshape = rank != 1

    if needs_reshape:
        ctx.insert_new_node_on_input(node, "Flatten", str_input, axis=0)
    node.set_attr("case_change_action", case_action)
    if not needs_reshape:
        return

    if shape is None or -1 in shape:
        target_shape = ctx.make_node("Shape", [str_input]).output[0]
    else:
        target_shape = ctx.make_const(utils.make_name("shape"), np.array(shape, np.int64)).output[0]
    ctx.insert_new_node_on_output("Reshape", node.output[0], inputs=[node.output[0], target_shape])
def version_7(cls, ctx, node, **kwargs):
    """Convert Equal/NotEqual for opset 7.

    T2 output = Equal(T1 x, T1 y), with T1 in {bool, int32, int64}.

    Inputs are passed through ``_add_cast_to_inputs`` with those supported
    dtypes and INT32 as the target. Presumably this casts inputs of other
    dtypes to int32 — confirm against the helper. A ``NotEqual`` node is
    retyped to ``Equal`` and gets a ``Not`` inserted on its output, with
    shape and dtype copied over.
    """
    is_not_equal = node.type == "NotEqual"
    allowed_dtypes = [TensorProto.BOOL, TensorProto.INT32, TensorProto.INT64]
    # FIXME: casting is not the same as equal
    _add_cast_to_inputs(ctx, node, allowed_dtypes, TensorProto.INT32)
    if not is_not_equal:
        return

    node.type = "Equal"
    eq_output = node.output[0]
    not_node = ctx.insert_new_node_on_output("Not", eq_output, name=utils.make_name(node.name))
    ctx.copy_shape(eq_output, not_node.output[0])
    ctx.copy_dtype(eq_output, not_node.output[0])
def _make_sparse_softmax_cross_entropy_with_logits(ctx, label, logit, tf_ori_node):
    """Emit ONNX ops computing ``-log(sum(label * softmax(logit), axis=-1))``.

    When ``label`` is one-hot, this equals
    ``tf.multiply(-1, tf.reduce_sum(tf.multiply(label, log_softmax), axis=1))``,
    i.e. ``-log(q_i)`` for the selected class ``i``, where
    ``q_i = exp(logit_i) / sum_j exp(logit_j)``. The graph built is::

        shifted    = logit - reduce_max(logit, axis=-1, keepdims)
        logit_exp  = exp(shifted)
        exp_sum    = reduce_sum(logit_exp, axis=-1)
        masked_sum = reduce_sum(label * logit_exp, axis=-1)
        result     = -1 * log(masked_sum / exp_sum)

    Subtracting the per-row max leaves ``masked_sum / exp_sum`` unchanged.
    It prevents ``exp`` from overflowing to inf for large logits, which
    would otherwise turn the ratio into inf/inf = NaN.

    ``tf_ori_node`` is removed from ``ctx``. The final Mul writes its first
    output with its recorded shape and dtype.

    Fails via utils.make_sure if label and logit dtypes differ.
    """
    logit = logit.output[0]
    label = label.output[0]
    label_dtype = ctx.get_dtype(label)
    logit_dtype = ctx.get_dtype(logit)
    utils.make_sure(
        label_dtype == logit_dtype,
        "the following logic only works on same dtype of label and logit")

    logit_max = ctx.make_node(op_type="ReduceMax", inputs=[logit],
                              attr={"axes": [-1], "keepdims": 1}).output[0]
    shifted = ctx.make_node(op_type="Sub", inputs=[logit, logit_max]).output[0]
    logit_exp = ctx.make_node(op_type="Exp", inputs=[shifted]).output[0]
    logit_exp_sum = ctx.make_node(op_type="ReduceSum", inputs=[logit_exp],
                                  attr={"axes": [-1], "keepdims": 0}).output[0]
    masked = ctx.make_node(op_type="Mul", inputs=[label, logit_exp]).output[0]
    masked_sum = ctx.make_node(op_type="ReduceSum", inputs=[masked],
                               attr={"axes": [-1], "keepdims": 0}).output[0]
    probability = ctx.make_node(op_type="Div", inputs=[masked_sum, logit_exp_sum]).output[0]
    log_prob = ctx.make_node(op_type="Log", inputs=[probability]).output[0]
    const_negative_one = ctx.make_const(
        name=utils.make_name("const_negative_one"),
        np_val=np.array(-1).astype(utils.ONNX_TO_NUMPY_DTYPE[logit_dtype])).output[0]

    shapes = tf_ori_node.output_shapes
    dtypes = tf_ori_node.output_dtypes
    ctx.remove_node(tf_ori_node.name)
    ctx.make_node(op_type="Mul", inputs=[log_prob, const_negative_one],
                  outputs=[tf_ori_node.output[0]], shapes=[shapes[0]], dtypes=[dtypes[0]])
def version_1(cls, ctx, node, **kwargs):
    """Convert TF LRN (NHWC) to ONNX LRN (NCHW), wrapped in Transposes.

    ONNX: each input value is divided by
        (bias + (alpha / size) * sum(xi^2 for every xi in the local region)) ^ beta
    TF:   sqr_sum[a, b, c, d] = sum(input[a, b, c, d - depth_radius : d + depth_radius + 1] ** 2)
          output = input / (bias + alpha * sqr_sum) ** beta

    Hence ``size = 2 * depth_radius + 1`` and ``onnx_alpha = size * tf_alpha``.
    The original NHWC output shape and dtype are set on the final Transpose.
    """
    # by default, depth_radius is 5 in tensorflow
    size = node.get_attr_value("depth_radius", 5) * 2 + 1
    node.set_attr("size", size)
    # TF's LRN defaults are alpha=1.0 and beta=0.5, while ONNX's are
    # alpha=1e-4 and beta=0.75 (bias is 1.0 in both). Write alpha and beta
    # explicitly so a TF node lacking them still converts with TF semantics.
    node.set_attr("alpha", size * node.get_attr_value("alpha", 1.0))
    node.set_attr("beta", node.get_attr_value("beta", 0.5))

    out_shape = node.output_shapes[0]
    out_dtype = node.output_dtypes[0]

    ctx.insert_new_node_on_input(node, "Transpose", node.input[0], perm=constants.NHWC_TO_NCHW)
    ctx.update_node_shape_dtype(node, override=True)
    op_name = utils.make_name(node.name)
    # Extra keyword arguments to insert_new_node_on_output become node
    # attributes (as with perm=), so passing shapes=/dtypes= would add bogus
    # attributes instead of shape info. Set the NHWC shape/dtype explicitly.
    transpose_back = ctx.insert_new_node_on_output("Transpose", node.output[0], perm=constants.NCHW_TO_NHWC,
                                                   name=op_name)
    if out_shape is not None:
        ctx.set_shape(transpose_back.output[0], out_shape)
    if out_dtype is not None:
        ctx.set_dtype(transpose_back.output[0], out_dtype)
def version_9(cls, ctx, node, **kwargs):
    """Convert a Select-style op to ONNX ``Where``.

    TF ``Select`` with a 1-D condition and higher-rank inputs selects whole
    entries along the first axis. ONNX broadcasting would instead align the
    condition with the last axis, so for ``Select`` the condition is first
    reshaped to ``[n, 1, ..., 1]``.

    The reshape is applied only when the incoming op type is ``Select``.
    Any other op type routed here (e.g. ``SelectV2``, if registered) already
    uses NumPy-style broadcasting, which ``Where`` implements directly. This
    matches the ``node.type == "Select"`` guard in the opset-7 lowering.

    Fails via make_sure if the condition shape, or both value-input shapes,
    are unknown.
    """
    # T output = Select(bool condition, T x, T y)
    # T1 output = Where(bool condition, T1 x, T1 y)
    # NOTE: condition can be 1-dimension in tensorflow, while in onnx,
    # it should be broadcastable with other two inputs
    # Record the TF op type before it is overwritten; only Select needs the reshape.
    is_select = node.type == "Select"
    node.type = "Where"
    cond_shape = ctx.get_shape(node.input[0])
    make_sure(cond_shape is not None, "shape of {} is None".format(node.input[0]))
    input_shape = ctx.get_shape(node.input[1])
    if input_shape is None:
        input_shape = ctx.get_shape(node.input[2])
    make_sure(input_shape is not None, "input shape of {} is None".format(node.name))
    input_rank = len(input_shape)
    # if cond shape is 1-dimensional while input has higher rank, need to be reshaped to broadcast
    if is_select and len(cond_shape) == 1 and input_rank > 1:
        broadcast_shape = [cond_shape[0]] + [1] * (input_rank - 1)
        shape_const = ctx.make_const(utils.make_name(node.name), np.array(broadcast_shape, dtype=np.int64))
        reshape = ctx.make_node("Reshape", [node.input[0], shape_const.output[0]])
        ctx.replace_input(node, node.input[0], reshape.output[0])