def version_11(cls, ctx, node, **kwargs):
    supported_dtypes = [
        onnx_pb.TensorProto.INT32,
        onnx_pb.TensorProto.INT64
    ]
    onnx_dtype = ctx.get_dtype(node.input[0])
    utils.make_sure(onnx_dtype in supported_dtypes,
                    "InvertPermutation only applies on INT32, INT64.")
    shape = ctx.get_shape(node.input[0])
    shape_node = ctx.make_node("Shape", inputs=node.input,
                               name=utils.make_name(node.name + '_shape'))
    neg_node = ctx.make_node("Neg", inputs=node.input,
                             name=utils.make_name(node.name + '_neg'))
    topk_node = ctx.make_node(
        "TopK", inputs=[neg_node.output[0], shape_node.output[0]],
        name=utils.make_name(node.name + '_topk'), output_count=2)
    ctx.remove_node(node.name)
    last_node = ctx.make_node("Identity", inputs=topk_node.output[1:],
                              name=utils.make_name(node.name + '_indices'),
                              shapes=[shape], dtypes=[onnx_dtype])
    ctx.replace_all_inputs(node.output[0], last_node.output[0])  # ops=ctx.get_nodes()
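

# Illustrative sketch, not part of the converter (assumes numpy is imported as np,
# as elsewhere in this module): the TopK construction above works because TopK over
# the negated permutation with k = len(p) sorts p ascending, and its index output is
# argsort(p), which is exactly the inverse permutation tf.math.invert_permutation computes.
def _invert_permutation_reference(p):
    p = np.asarray(p, dtype=np.int64)
    inv = np.argsort(p)  # same indices TopK yields for inputs (-p, k=len(p))
    assert np.array_equal(inv[p], np.arange(len(p)))  # inv composed with p is the identity
    return inv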
def version_1(cls, ctx, node, **kwargs):
    # in tf-2.0 grappler optimizes the graph pretty well and our matching logic
    # in the rewriter does not trigger. grappler will send the random uniform
    # with shape as input so we need to pick up the input here and, if the shape is
    # const, we make it an attribute.
    seed = node.get_attr("seed")
    node.set_attr("seed", float(seed.f))
    utils.make_sure(node.inputs[0].is_const(),
                    "%s node with non-const shape requires opset >= 9", node.type)
    shape = node.inputs[0].get_tensor_value()
    ctx.remove_input(node, node.input[0], 0)
    if len(shape) == 0:
        # ORT can't take an empty shape (scalar)
        node.set_attr("shape", [1])
        ctx.set_shape(node.output[0], [1])
        squeeze_node = GraphBuilder(ctx).make_squeeze(
            {'data': node.output[0], 'axes': [0]}, return_node=True)
        ctx.insert_node_on_output(squeeze_node, node.output[0])
        rand_out = squeeze_node.output[0]
    else:
        node.set_attr("shape", shape)
        ctx.set_shape(node.output[0], shape)
        rand_out = node.output[0]
    if node.type == "RandomUniformInt":
        cls.randuniform_int(ctx, node, rand_out, node.input[0], node.input[1])
        node.type = "RandomUniform"
        ctx.replace_inputs(node, [])
def expand_tensor(t):
    if t.shape == (1,):
        return t[0]
    utils.make_sure(in_rank is not None,
                    "Cannot dequantize node %s with unknown input rank", node.name)
    new_shape = [1] * in_rank
    new_shape[axis] = t.shape[0]
    return t.reshape(new_shape)
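

# Hedged illustration (assumes numpy as np; the helper name below is hypothetical):
# for a rank-3 input quantized along axis=1, expand_tensor turns a per-channel scale
# of shape (3,) into shape (1, 3, 1) so the dequantize arithmetic broadcasts per channel.
def _per_axis_broadcast_example():
    scale = np.array([0.1, 0.2, 0.3], dtype=np.float32).reshape([1, 3, 1])
    quantized = np.ones((2, 3, 4), dtype=np.float32)
    return quantized * scale  # shape (2, 3, 4), one scale per channel of axis 1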
def version_9(cls, ctx, node, **kwargs):
    # T output = Select(bool condition, T x, T y)
    # T1 output = Where(bool condition, T1 x, T1 y)
    # NOTE: condition can be 1-dimensional in tensorflow, while in onnx
    # it should be broadcastable with the other two inputs
    if ctx.get_dtype(node.output[0]) != TensorProto.STRING:
        # Due to bad ORT implementation, Mul/Add ops are faster than Where op
        cls.version_7(ctx, node, **kwargs)
        return
    cond_shape = ctx.get_shape(node.input[0])
    input_shape = ctx.get_shape(node.input[1])
    if input_shape is None:
        input_shape = ctx.get_shape(node.input[2])
    input_rank = len(input_shape) if input_shape is not None else None
    cond_rank = len(cond_shape) if cond_shape is not None else None
    # if cond shape is 1-dimensional while input has higher rank, it needs to be reshaped to broadcast
    if node.type == "Select" and cond_rank == 1 and input_rank != 1:
        utils.make_sure(input_rank is not None, "input_rank unknown and cond_rank == 1")
        broadcast_shape = [cond_shape[0]] + [1] * (input_rank - 1)
        shape_const = ctx.make_const(utils.make_name(node.name),
                                     np.array(broadcast_shape, dtype=np.int64))
        reshape = ctx.make_node("Reshape", [node.input[0], shape_const.output[0]])
        ctx.replace_input(node, node.input[0], reshape.output[0], 0)
    node.type = "Where"
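

# Hedged numpy sketch (assumes np) of why the reshape above is needed: TF's Select
# treats a 1-D condition as selecting whole rows of x/y, while ONNX Where broadcasts
# from the trailing dimensions, so the condition must become shape [n, 1, ..., 1].
def _select_cond_broadcast_example():
    cond = np.array([True, False])
    x, y = np.zeros((2, 3)), np.ones((2, 3))
    # reshaping cond to (2, 1) reproduces tf.where(cond, x, y): row 0 from x, row 1 from y
    return np.where(cond.reshape(2, 1), x, y)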
def _get_output_shape_dtype(self, cond_context):
    output_shapes = []
    output_dtypes = []
    for i, _ in enumerate(cond_context.true_branch_context.output):
        true_output = cond_context.true_branch_context.output[i]
        false_output = cond_context.false_branch_context.output[i]
        true_shape = self.g.get_shape(true_output)
        utils.make_sure(true_shape is not None, "Shape of {} is None".format(true_output))
        true_rank = len(true_shape)
        true_dtype = self.g.get_dtype(true_output)
        false_shape = self.g.get_shape(false_output)
        utils.make_sure(false_shape is not None, "Shape of {} is None".format(false_output))
        false_rank = len(false_shape)
        false_dtype = self.g.get_dtype(false_output)
        # just require rank is equal
        if true_rank != false_rank:
            raise RuntimeError(
                "the rank of outputs {} and {} mismatch: {}, {}".format(
                    true_output, false_output, true_rank, false_rank))
        if true_dtype != false_dtype:
            raise RuntimeError(
                "the dtype of outputs {} and {} mismatch: {}, {}".format(
                    true_output, false_output, true_dtype, false_dtype))
        output_shapes.append(utils.create_vague_shape_like(true_shape))
        output_dtypes.append(true_dtype)
    return output_shapes, output_dtypes
def get_tf_tensor_data(tensor):
    """Get data from tensor."""
    make_sure(isinstance(tensor, tensor_pb2.TensorProto), "Require TensorProto")
    np_data = tensor_util.MakeNdarray(tensor)
    make_sure(isinstance(np_data, np.ndarray), "%r isn't ndarray", np_data)
    return np_data
def __init__(self, enter_name, enter_input_id, next_iteration_input_id,
             switch_true_identity_output_id, exit_output_id, is_tensor_array,
             ta_index_id, g):
    self.enter_name = enter_name
    self.enter_input_id = enter_input_id

    # the output of the iteration body graph for this variable;
    # should not be None
    utils.make_sure(next_iteration_input_id, "next_iteration_input_id should not be None")
    self.next_iteration_input = TensorValueInfo(next_iteration_input_id, g)

    # the starting point of the iteration body graph;
    # might be None when this variable's value (either the initial value or the last
    # iteration's output value) is not consumed by iteration body graph nodes.
    self.switch_true_identity_output = TensorValueInfo(switch_true_identity_output_id, g)

    # the switch_false branch ends with Exit, which is a boundary for the loop;
    # might be None when there are no consumers for the variable output.
    self.exit_output = TensorValueInfo(exit_output_id, g)

    # only applicable for tensor array variables
    self.is_tensor_array = is_tensor_array
    # todo: need to check that ta's index variable is a scalar starting from 1 and increasing
    # by 1 each iteration; then we can be sure this is equivalent to scan output behavior.
    self.ta_index_id = ta_index_id
def version_6(cls, ctx, node, **kwargs):
    # T output = All(T x, list(int) reduce_indices, @bool keepdims)
    # T output = Any(T x, list(int) reduce_indices, @bool keepdims)
    reduce_dim = node.inputs[1].get_tensor_value()
    # for Any, the reduce_indices can be scalar as observed.
    if np.isscalar(reduce_dim):
        reduce_dim = [reduce_dim]
    if ctx.opset < 11:
        utils.make_sure(all(i >= 0 for i in reduce_dim),
                        "negative reduce axis is not supported in onnx for now")
    cast = ctx.make_node(op_type="Cast", inputs=[node.input[0]],
                         attr={"to": onnx_pb.TensorProto.FLOAT})
    keepdims = helper.get_attribute_value(node.get_attr("keep_dims"))
    op_type = "ReduceMin" if node.type == "All" else "ReduceSum"
    if op_type == "ReduceSum":
        reduce_node_output = GraphBuilder(ctx).make_reduce_sum(
            {"data": cast.output[0], "axes": reduce_dim,
             "keepdims": keepdims, "noop_with_empty_axes": 1})
    else:
        reduce_node_output = ctx.make_node(op_type=op_type, inputs=cast.output,
                                           attr={"axes": reduce_dim,
                                                 "keepdims": keepdims}).output[0]
    zero_node = ctx.make_const(utils.make_name("zero_reduce"), np.array(0, dtype=np.float32))
    shapes = node.output_shapes
    dtypes = node.output_dtypes
    ctx.remove_node(node.name)
    ctx.make_node(op_type="Greater", inputs=[reduce_node_output, zero_node.output[0]],
                  name=node.name, outputs=node.output, shapes=shapes, dtypes=dtypes)
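

# Reference sketch (assumes np) for the reduction trick above: on boolean input,
# All == ReduceMin(cast-to-float) > 0 and Any == ReduceSum(cast-to-float) > 0.
def _all_any_reference(x, axes, keepdims=False):
    xf = x.astype(np.float32)
    all_res = np.min(xf, axis=tuple(axes), keepdims=keepdims) > 0
    any_res = np.sum(xf, axis=tuple(axes), keepdims=keepdims) > 0
    return all_res, any_res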
def version_1(cls, ctx, node, **kwargs):
    """Range."""
    # T range = Range(T start, T limit, T delta)
    dtype = node.get_attr_int("Tidx")
    shape = node.output_shapes[0]
    utils.make_sure(dtype is not None, "Tidx of %s is None", node.name)
    ctx.remove_node(node.name)
    make_range(ctx, node.input[0], node.input[1], node.input[2],
               node.output[0], node.name, shape, dtype)
def version_9(cls, ctx, node, **kwargs):
    node_dtype = ctx.get_dtype(node.output[0])
    utils.make_sure(node_dtype, "dtype of {} is None".format(node.name))
    if node_dtype in [onnx_pb.TensorProto.BOOL, onnx_pb.TensorProto.COMPLEX64,
                      onnx_pb.TensorProto.COMPLEX128]:
        raise ValueError("dtype " + str(node_dtype) + " is not supported in onnx for now")
def version_7(cls, ctx, node, **kwargs):
    """Range."""
    # T range = Range(T start, T limit, T delta)
    # V v_final_and_scan_outputs = Loop(int64 M, B cond, V v_initial)
    dtype = node.get_attr_int("Tidx")
    shape = node.output_shapes[0]
    utils.make_sure(dtype is not None, "Tidx of %s is None", node.name)
    ctx.remove_node(node.name)
    make_range(ctx, node.input[0], node.input[1], node.input[2],
               node.output[0], node.name, shape, dtype)
def make_unsqueeze(self, kwargs, name=None, shapes=None, dtypes=None, return_node=False,
                   op_name_scope=None):
    """
    Unsqueeze changes its schema at opset 13: it treats axes as a dynamic input
    kwargs: key could be ["data", "axes"].
    """
    outputs = kwargs.pop("outputs", None)

    if self.graph.opset < 13:
        data = kwargs.pop("data")
        axes = self.convert_to_attribute(kwargs.pop("axes", None), is_optional=True)
        attr = {"axes": axes}
        inputs = [data]
    else:
        data = kwargs.pop("data")
        axes = self.convert_to_input(kwargs.pop("axes", None), "const_axes",
                                     is_optional=True, dtype=np.int64)
        attr = {}
        inputs = [data, axes]

    utils.make_sure(not kwargs, "kwargs contains unused key")

    new_attr = {}
    for key, val in attr.items():
        if val is not None:
            new_attr[key] = val
    attr = new_attr

    for ind, val in enumerate(inputs):
        if val is None:
            inputs[ind] = utils.ONNX_EMPTY_INPUT  # empty string means no connection in ONNX
    # remove trailing ""
    while inputs[-1] == utils.ONNX_EMPTY_INPUT:
        inputs = inputs[:-1]

    node = self.graph.make_node(op_type="Unsqueeze", inputs=inputs, attr=attr, name=name,
                                outputs=outputs, shapes=shapes, dtypes=dtypes,
                                op_name_scope=op_name_scope)
    if return_node:
        return node
    return node.output[0]
def add_variable(self, var):
    utils.make_sure(var.enter_name not in self.scan_variables,
                    "variable %s already exists as scan variable.", var.enter_name)
    utils.make_sure(var.enter_name not in self.state_variables,
                    "variable %s already exists as state variable.", var.enter_name)
    if not var.is_tensor_array:
        self.state_variables[var.enter_name] = var
    else:
        self.scan_variables[var.enter_name] = var
def _find_tensorarray_write(op):
    utils.make_sure(op.type == "TensorArrayV3", "op should be tensorarray")
    tensor_array_consumers = op.outputs[0].consumers()
    for i in tensor_array_consumers:
        if i.type == "Enter":
            consumer_ops = i.outputs[0].consumers()
            for j in consumer_ops:
                if j.type == "TensorArrayWriteV3":
                    return j
    return None
def parameter_binding(g, inputs, state_vars=None):
    binding = {}
    i = 0
    for k in g.input_names:
        if state_vars and k in state_vars:
            binding[k] = state_vars[k]
        else:
            binding[k] = inputs[i]
            i += 1
    utils.make_sure(i == len(inputs), "Parameter count mismatch while binding controlflow")
    return binding
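

# Hedged usage sketch: body-graph inputs are bound positionally, while any input listed
# in state_vars is bound by name instead. SimpleNamespace only stands in for a real
# Graph here, and the tensor names are made up for illustration.
from types import SimpleNamespace

def _parameter_binding_example():
    g = SimpleNamespace(input_names=["iter_in", "cond_in", "x_in"])
    # returns {"iter_in": "iter:0", "cond_in": "cond:0", "x_in": "x:0"}
    return parameter_binding(g, ["iter:0", "x:0"], state_vars={"cond_in": "cond:0"})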
def any_version(cls, opset, ctx, node, **kwargs):
    """
    Computes the modulus of a complex number.
    If the tensor dtype is not complex64 or complex128, it assumes the first
    dimension holds the real part (index 0) and the imaginary part (index 1).
    """
    supported_dtypes = [
        onnx_pb.TensorProto.FLOAT,
        onnx_pb.TensorProto.FLOAT16,
        onnx_pb.TensorProto.DOUBLE,
        onnx_pb.TensorProto.COMPLEX64,
        onnx_pb.TensorProto.COMPLEX128,
    ]
    onnx_dtype = ctx.get_dtype(node.input[0])
    utils.make_sure(onnx_dtype in supported_dtypes, "Unsupported input type.")
    shape = ctx.get_shape(node.input[0])
    np_dtype = utils.map_onnx_to_numpy_type(onnx_dtype)
    utils.make_sure(shape[0] == 2,
                    "ComplexAbs expected the first dimension to be 2 but shape is %r", shape)

    ind0 = ctx.make_const(name=utils.make_name('cst0'), np_val=np.array([0], dtype=np.int64))
    ind1 = ctx.make_const(name=utils.make_name('cst1'), np_val=np.array([1], dtype=np.int64))
    p2 = ctx.make_const(name=utils.make_name('p2'), np_val=np.array([2], dtype=np_dtype))

    real_part = ctx.make_node(
        'Gather', inputs=[node.input[0], ind0.name], attr=dict(axis=0),
        name=utils.make_name('Real_' + node.name))
    imag_part = ctx.make_node(
        'Gather', inputs=[node.input[0], ind1.name], attr=dict(axis=0),
        name=utils.make_name('Imag_' + node.name))

    real_part2 = ctx.make_node(
        'Pow', inputs=[real_part.output[0], p2.name],
        name=utils.make_name(real_part.name + 'p2p'))
    imag_part2 = ctx.make_node(
        'Pow', inputs=[imag_part.output[0], p2.name],
        name=utils.make_name(imag_part.name + 'p2p'))

    ctx.remove_node(node.name)
    add = ctx.make_node(
        "Add", inputs=[real_part2.output[0], imag_part2.output[0]],
        name=utils.make_name('ComplexAbs_' + node.name))
    squeezed = GraphBuilder(ctx).make_squeeze(
        {'data': add.output[0], 'axes': [0]},
        name=utils.make_name('ComplexAbs' + node.name), return_node=True)
    last_node = ctx.make_node(
        "Sqrt", inputs=squeezed.output[:1],
        name=utils.make_name('ComplexAbs' + node.name),
        shapes=[shape[1:]], dtypes=[onnx_dtype])
    ctx.replace_all_inputs(node.output[0], last_node.output[0])  # ops=ctx.get_nodes()
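

# Reference sketch (assumes np) for the graph built above: with the real part at
# index 0 and the imaginary part at index 1 of the first axis, the complex modulus
# is sqrt(real**2 + imag**2), computed over the remaining dimensions.
def _complex_abs_reference(x):
    real, imag = x[0], x[1]
    return np.sqrt(real * real + imag * imag)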
def get_tf_const_value(op, as_list=True):
    """
    If as_list=True, return the array as a (possibly nested) list.
    Otherwise, return data of type np.ndarray.
    If a tensor is a scalar having value 1:
        when as_list=False, return np.array(1), type is <class 'numpy.ndarray'>
        when as_list=True, return 1, type is <class 'int'>.
    """
    make_sure(is_tf_const_op(op), "%r isn't a const op", op.name)
    value = get_tf_tensor_data(op.get_attr("value"))
    if as_list:
        value = value.tolist()
    return value
def _make_range_non_const(ctx, start, limit, delta, output, scope_name, shape, dtype):
    utils.make_sure(
        dtype in [TensorProto.FLOAT, TensorProto.DOUBLE, TensorProto.INT16,
                  TensorProto.INT32, TensorProto.INT64, TensorProto.COMPLEX64,
                  TensorProto.COMPLEX128],
        "dtype %s is not supported", dtype)
    ctx.make_node("Range", [start, limit, delta], outputs=[output], name=scope_name,
                  shapes=[shape], dtypes=[dtype], domain=constants.MICROSOFT_DOMAIN)
def convert_to_attribute(self, tensor, is_optional=False):
    if is_optional and tensor is None:
        return None

    utils.make_sure(tensor is not None, "input is required so it couldn't be None")

    res = tensor
    if isinstance(tensor, str):
        const_node = self.graph.get_node_by_output(tensor)
        res = const_node.get_tensor_value(as_list=True)

    utils.make_sure(isinstance(res, list), "input is an attr, so a list is needed")

    return res
def make_reduce_sum(self, kwargs, name=None, shapes=None, dtypes=None):
    """
    ReduceSum changes its schema at opset 13: it treats axes as a dynamic input
    kwargs: key could be ["data", "axes", "keepdims", "noop_with_empty_axes", "outputs"].
    """
    outputs = kwargs.pop("outputs", None)

    if self.graph.opset < 13:
        data = kwargs.pop("data")
        axes = self.convert_to_attribute(kwargs.pop("axes", None), is_optional=True)
        keepdims = kwargs.pop("keepdims", 1)
        noop_with_empty_axes = kwargs.pop("noop_with_empty_axes", 0)
        if noop_with_empty_axes == 0 and axes == []:
            axes = None
        attr = {"axes": axes, "keepdims": keepdims}
        inputs = [data]
    else:
        keepdims = kwargs.pop("keepdims", 1)
        noop_with_empty_axes = kwargs.pop("noop_with_empty_axes", 0)
        data = self.convert_to_input(kwargs.pop("data"), "const_data")
        axes = self.convert_to_input(kwargs.pop("axes", None), "const_axes",
                                     is_optional=True, dtype=np.int64)
        attr = {"keepdims": keepdims, "noop_with_empty_axes": noop_with_empty_axes}
        inputs = [data, axes]

    utils.make_sure(not kwargs, "kwargs contains unused key")

    new_attr = {}
    for key, val in attr.items():
        if val is not None:
            new_attr[key] = val
    attr = new_attr

    return self.graph.make_node(op_type="ReduceSum", inputs=inputs, attr=attr, name=name,
                                outputs=outputs, shapes=shapes, dtypes=dtypes).output[0]
def separate_fused_activation_function(ctx, node):
    activation_fn = node.attr['fused_activation_function'].s
    del node.attr['fused_activation_function']
    if activation_fn == b'RELU':
        ctx.insert_new_node_on_output("Relu", node.output[0])
    elif activation_fn == b'RELU6':
        # This is a TF op. We will convert it on the 2nd pass.
        shape = ctx.get_shape(node.output[0])
        dtype = ctx.get_dtype(node.output[0])
        new_node = ctx.make_node("Relu6", [node.output[0]], skip_conversion=False,
                                 shapes=[shape], dtypes=[dtype])
        ctx.insert_node_on_output(new_node, node.output[0])
    elif activation_fn == b'TANH':
        ctx.insert_new_node_on_output("Tanh", node.output[0])
    else:
        # TODO: SIGN_BIT and RELU_N1_TO_1 not supported yet
        utils.make_sure(activation_fn == b'NONE',
                        "Unsupported fused activation function %s on node %s",
                        activation_fn, node.name)
def compress_graph_def(graph_def):
    """
    Remove large const values from graph. This lets us import the graph and run
    shape inference without TF crashing.
    """
    node_defs = list(graph_def.node)
    const_node_values = {}
    for node_def in node_defs:
        if node_def.op == 'Const':
            tensor = node_def.attr["value"].tensor
            # Small constants are sometimes used to store shape information and must be maintained
            if len(tensor.tensor_content) > 1000:
                make_sure(node_def.name not in const_node_values,
                          "Two nodes in graph have same name %s", node_def.name)
                const_node_values[node_def.name] = tensor.tensor_content
                tensor.tensor_content = b''
    return const_node_values
def version_10(cls, ctx, node, **kwargs):
    scale = node.get_attr_value('scale')
    zero_point = node.get_attr_value('zero_point')
    axis = node.get_attr_value('quantized_dimension')
    np_q_type = utils.map_onnx_to_numpy_type(ctx.get_dtype(node.output[0]))
    if len(scale) > 1 or len(zero_point) > 1:
        utils.make_sure(ctx.opset >= 13,
                        "Opset 13 is required for per-axis quantization for node %s", node.name)
        node.set_attr("axis", axis)
        # per-axis quantization: keep the full scale/zero_point vectors
        scale_node = ctx.make_const(utils.make_name("scale"), np.array(scale, dtype=np.float32))
        zero_point_node = ctx.make_const(utils.make_name("zero_point"),
                                         np.array(zero_point, dtype=np_q_type))
    else:
        # per-tensor quantization: a single scalar scale/zero_point
        scale_node = ctx.make_const(utils.make_name("scale"), np.array(scale[0], dtype=np.float32))
        zero_point_node = ctx.make_const(utils.make_name("zero_point"),
                                         np.array(zero_point[0], dtype=np_q_type))
    ctx.replace_inputs(node, [node.input[0], scale_node.output[0], zero_point_node.output[0]])
    del node.attr["scale"]
    del node.attr["zero_point"]
    del node.attr["quantized_dimension"]
    if "min" in node.attr:
        del node.attr["min"]
    if "max" in node.attr:
        del node.attr["max"]
def version_1(cls, ctx, node, **kwargs):
    node.domain = constants.CONTRIB_OPS_DOMAIN
    node.type = "StringRegexReplace"
    pattern = node.get_attr_str("pattern")
    rewrite = node.get_attr_str("rewrite")
    utils.make_sure(node.get_attr_value("replace_global") != 0,
                    "Can not convert StaticRegexReplace if replace_global is False")
    pattern_node = ctx.make_const(utils.make_name("pattern"), np.array([pattern], np.object))
    rewrite_node = ctx.make_const(utils.make_name("rewrite"), np.array([rewrite], np.object))
    del node.attr["pattern"]
    del node.attr["rewrite"]
    del node.attr["replace_global"]
    ctx.replace_inputs(node, [node.input[0], pattern_node.output[0], rewrite_node.output[0]])
def convert_to_input(self, tensor, const_name, is_optional=False, dtype=None):
    """In ONNX, an input should come from a node, so it must be a string."""
    if is_optional and tensor is None:
        return None

    utils.make_sure(tensor is not None, "input is required so it couldn't be None")

    res = tensor
    if isinstance(tensor, list):
        res = self.graph.make_const(utils.make_name(const_name), np.array(tensor, dtype)).output[0]

    utils.make_sure(isinstance(res, str), "input is a dynamic input, so a str is needed")

    return res
def _merge_shapes_for_tf(shape1, shape2):
    """
    Merge 2 shapes, return merged shape, set unknown for dims with different values.
    Raise exception for mismatch.
    """
    if shape1 is None:
        return shape2
    if shape2 is None:
        return shape1

    utils.make_sure(utils.is_list_or_tuple(shape1), "invalid type for shape1")
    utils.make_sure(utils.is_list_or_tuple(shape2), "invalid type for shape2")
    utils.make_sure(len(shape1) == len(shape2),
                    "shapes rank mismatch: shape1=%s, shape2=%s", shape1, shape2)

    merged = []
    for d1, d2 in zip(shape1, shape2):
        d = d1
        if d1 is None:
            d = d2
        elif d2 is not None and d1 != d2:
            # None means unknown in tensorflow
            d = None
        merged.append(d)
    return merged
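

# Hedged examples of the merge behaviour documented above (illustrative only):
def _merge_shapes_examples():
    merged_known = _merge_shapes_for_tf([None, 3, 4], [2, 3, 4])  # -> [2, 3, 4]
    merged_conflict = _merge_shapes_for_tf([2, 3, 4], [2, 3, 5])  # -> [2, 3, None]
    return merged_known, merged_conflict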
def version_1(cls, ctx, node, **kwargs):
    node.domain = constants.CONTRIB_OPS_DOMAIN
    input_node = node.inputs[0]
    utils.make_sure(input_node.type == "SentencepieceOp",
                    "Input 0 to node %s is not SentencepieceOp", node.name)
    ctx.remove_input(node, node.input[0], 0)

    nbest_size_cast = ctx.make_node("Cast", [node.input[1]],
                                    attr={'to': TensorProto.INT64}).output[0]
    ctx.replace_input(node, node.input[1], nbest_size_cast, 1)
    for i in range(1, len(node.input)):
        unsqueeze = GraphBuilder(ctx).make_unsqueeze({'data': node.input[i], 'axes': [0]})
        ctx.replace_input(node, node.input[i], unsqueeze, i)
    node.set_attr("model", input_node.attr['model'].s)
    node.type = "SentencepieceTokenizer"
    if ctx.is_safe_to_remove_nodes([input_node]):
        ctx.remove_node(input_node.name)
def version_1(cls, ctx, node, initialized_tables, **kwargs):
    table_node = node.inputs[0]
    while table_node.type == 'Identity':
        table_node = table_node.inputs[0]
    shared_name = table_node.get_attr_value("shared_name")
    utils.make_sure(shared_name is not None,
                    "Could not determine table shared name for node %s", node.name)
    utils.make_sure(shared_name in initialized_tables,
                    "Initialized table %s for node %s not found.", shared_name, node.name)
    keys, _ = initialized_tables[shared_name]

    node_name = node.name
    node_outputs = node.output
    ctx.remove_node(node.name)
    size_const = ctx.make_const(node_name, np.array(len(keys), dtype=np.int64))
    ctx.replace_all_inputs(node_outputs[0], size_const.output[0])

    consumer_nodes = ctx.find_output_consumers(table_node.output[0])
    if len(consumer_nodes) == 0:
        ctx.remove_node(table_node.name)
def version_1(cls, ctx, node, **kwargs):
    """Sign op."""
    # T sign = Sign(T Input)
    node_dtype = ctx.get_dtype(node.output[0])
    utils.make_sure(node_dtype, "Dtype of {} is None".format(node.name))
    if node_dtype in [onnx_pb.TensorProto.COMPLEX64, onnx_pb.TensorProto.COMPLEX128]:
        raise ValueError("dtype " + str(node_dtype) + " is not supported in onnx for now")
    zero_name = utils.make_name("{}_zero".format(node.name))
    ctx.make_const(zero_name, np.array(0, dtype=np.float32))
    if node_dtype not in [onnx_pb.TensorProto.FLOAT16, onnx_pb.TensorProto.FLOAT,
                          onnx_pb.TensorProto.DOUBLE]:
        cast_node_0 = ctx.make_node("Cast", [node.input[0]], {"to": onnx_pb.TensorProto.FLOAT})
        greater_node = ctx.make_node("Greater", [cast_node_0.output[0], zero_name])
        less_node = ctx.make_node("Less", [cast_node_0.output[0], zero_name])
    else:
        greater_node = ctx.make_node("Greater", [node.input[0], zero_name])
        less_node = ctx.make_node("Less", [node.input[0], zero_name])
    cast_node_1 = ctx.make_node("Cast", [greater_node.output[0]], {"to": node_dtype})
    cast_node_2 = ctx.make_node("Cast", [less_node.output[0]], {"to": node_dtype})
    shapes = node.output_shapes
    dtypes = node.output_dtypes
    ctx.remove_node(node.name)
    ctx.make_node("Sub", [cast_node_1.output[0], cast_node_2.output[0]],
                  outputs=[node.output[0]], shapes=shapes, dtypes=dtypes)
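

# Reference sketch (assumes np) of the construction above:
# Sign(x) == Cast(x > 0) - Cast(x < 0), mapping negatives to -1, zero to 0, positives to +1.
def _sign_reference(x):
    return (x > 0).astype(x.dtype) - (x < 0).astype(x.dtype)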
def version_7(cls, ctx, node, **kwargs):
    # T output = Select(bool condition, T x, T y)
    # Select_res = Add(Multiply(Cast(bool condition, T), T x,),
    #                  Multiply(Cast(Not(bool condition), T), T y)).
    # TODO: Fix case where condition is 1-dimensional
    utils.make_sure(len(node.input) > 1, "Select with only condition is not supported.")
    dtype = ctx.get_dtype(node.output[0])
    utils.make_sure(dtype != TensorProto.STRING, "Select with dtype string requires opset 9")

    cond_shape = ctx.get_shape(node.input[0])
    input_shape = ctx.get_shape(node.input[1])
    if input_shape is None:
        input_shape = ctx.get_shape(node.input[2])
    input_rank = len(input_shape) if input_shape is not None else None
    cond_rank = len(cond_shape) if cond_shape is not None else None
    # if cond shape is 1-dimensional while input has higher rank, need to be reshaped to broadcast
    if node.type == "Select" and cond_rank == 1 and input_rank != 1:
        utils.make_sure(input_rank is not None, "input_rank unknown and cond_rank == 1")
        broadcast_shape = [cond_shape[0]] + [1] * (input_rank - 1)
        shape_const = ctx.make_const(utils.make_name(node.name),
                                     np.array(broadcast_shape, dtype=np.int64))
        reshape = ctx.make_node("Reshape", [node.input[0], shape_const.output[0]])
        ctx.replace_input(node, node.input[0], reshape.output[0], 0)

    positive_cast = ctx.make_node("Cast", [node.input[0]],
                                  name=utils.make_name(node.name), attr={"to": dtype})
    negative = ctx.make_node("Not", [node.input[0]], name=utils.make_name(node.name))
    negative_cast = ctx.make_node("Cast", [negative.output[0]],
                                  name=utils.make_name(node.name), attr={"to": dtype})
    multiply_1 = ctx.make_node("Mul", [positive_cast.output[0], node.input[1]],
                               name=utils.make_name(node.name))
    multiply_2 = ctx.make_node("Mul", [node.input[2], negative_cast.output[0]],
                               name=utils.make_name(node.name))

    add_name = node.name
    add_out = node.output
    shape = ctx.get_shape(node.output[0])
    ctx.remove_node(node.name)
    ctx.make_node("Add", [multiply_1.output[0], multiply_2.output[0]],
                  outputs=add_out, name=add_name, dtypes=[dtype], shapes=[shape])
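

# Reference sketch (assumes np) of the Mul/Add decomposition above:
# where(cond, x, y) == Cast(cond) * x + Cast(Not(cond)) * y, i.e. c*x + (1 - c)*y.
def _select_reference(cond, x, y):
    c = cond.astype(x.dtype)
    return c * x + (1 - c) * y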