def _onnx_initializer_to_input_dict_items(cls, initializer, input_format='NCHW'):
    """ Convert ONNX graph initializer to input dict items.

    :param initializer: ONNX graph initializer, list of TensorProto.
    :param input_format: layout of the converted model, 'NCHW' (default) or
      'NHWC'. With 'NHWC', every initializer whose array has shape (4,)
      (except one named "data") is reordered from [x0, x1, x2, x3] to
      [x0, x2, x3, x1], i.e. an NCHW vector becomes NHWC.
    :return: List of (initializer name, tf.constant) pairs.
    """

    def tensor2list(onnx_tensor, input_format='NCHW', name=None):
        # Use the onnx.numpy_helper because the data may be raw
        row_list = numpy_helper.to_array(onnx_tensor)
        # Constants that hold a 4-element dimension vector must be converted
        # to NHWC order. NOTE(review): this treats every 4-element 1-D
        # initializer as an NCHW shape/axis vector; the "data" exclusion is
        # hard-coded and its reason is not visible here.
        if (input_format == 'NHWC' and row_list is not None
                and list(np.array(row_list).shape) == [4]
                and name != "data"):
            n, c, h, w = row_list[0], row_list[1], row_list[2], row_list[3]
            return [n, h, w, c]
        return row_list

    return [(init.name,
             tf.constant(tensor2list(init, input_format, init.name),
                         shape=init.dims,
                         dtype=data_type.onnx2tf(init.data_type)))
            for init in initializer]
def _onnx_initializer_to_input_dict_items(cls, initializer):
    """ Convert ONNX graph initializer to input dict items.

    :param initializer: ONNX graph initializer, list of TensorProto.
    :return: List of (initializer name, tf.constant) pairs. The original
      ONNX name is kept as the dict key; only the TF op name passed to
      tf.constant is sanitized.
    """

    def tensor2list(onnx_tensor):
        # Use the onnx.numpy_helper because the data may be raw
        return numpy_helper.to_array(onnx_tensor).flatten().tolist()

    def validate_initializer_name(name):
        # Prepend a unique prefix if the leading character is "_"
        # (presumably because TF op names may not start with "_").
        # startswith() replaces `name[0] is "_"`: identity comparison with a
        # str literal depends on interning, and name[0] fails on "".
        if name.startswith("_"):
            name = get_unique_suffix() + name
        # Replace ":" with "_tf_" and append a unique suffix for
        # traceability
        if ":" in name:
            return name.replace(":", "_tf_") + "_" + get_unique_suffix()
        return name

    return [(init.name,
             tf.constant(tensor2list(init),
                         shape=init.dims,
                         dtype=data_type.onnx2tf(init.data_type),
                         name=validate_initializer_name(init.name)))
            for init in initializer]
def version_11(cls, node, **kwargs):
    """Return a single empty ragged tensor (presumably for SequenceEmpty).

    The element dtype comes from the node's "dtype" attribute. When the
    attribute is absent, the ONNX tensor type of float32 is used.

    :return: one-element list holding the empty tf.RaggedTensor.
    """
    onnx_float32 = mapping.NP_TYPE_TO_TENSOR_TYPE[np.dtype('float32')]
    target_dtype = data_type.onnx2tf(node.attrs.get("dtype", onnx_float32))
    # Build an empty ragged tensor, then round-trip it through a
    # SparseTensor so that the cast to the requested dtype can be applied.
    empty_ragged = tf.RaggedTensor.from_row_lengths(values=[], row_lengths=[])
    typed_sparse = tf.cast(empty_ragged.to_sparse(), target_dtype)
    return [tf.RaggedTensor.from_sparse(typed_sparse)]
def version_1(cls, node, **kwargs):
    """Turn the node's "value" attribute into a TF tensor.

    The attribute is decoded with numpy_helper.to_array, and its data_type is
    mapped to a TF dtype. Both are handed to cls.make_tensor_from_onnx_node,
    as the only input and as the "dtype" attr respectively.

    :return: one-element list with the produced tensor.
    """
    proto = node.attrs["value"]
    tf_dtype = data_type.onnx2tf(proto.data_type)
    np_value = numpy_helper.to_array(proto)
    produced = cls.make_tensor_from_onnx_node(node,
                                              inputs=[np_value],
                                              attrs={"dtype": tf_dtype})
    return [produced]
def _onnx_graph_to_tensorflow_rep(cls, graph_def, opset, strict):
    """ Convert ONNX graph to TensorflowRep.

    :param graph_def: ONNX GraphProto object.
    :param opset: ONNX OperatorSetIdProto list.
    :param strict: whether to enforce semantic equivalence between the
      original model and the converted tensorflow model.
    :return: TensorflowRep object whose tf_module is a BackendTFModule and
      whose signatures map every non-initializer graph input to a
      tf.TensorSpec.
    """
    handlers = cls._get_handlers(opset)

    # Graph inputs that are backed by an initializer are not fed by the
    # caller, so they get no signature.
    initialized = {init.name for init in graph_def.initializer}

    module = BackendTFModule(handlers, opset, strict, graph_def, cls)

    feed_inputs = [vi for vi in graph_def.input if vi.name not in initialized]

    signatures = {}
    for vi in feed_inputs:
        tensor_type = vi.type.tensor_type
        # Keep only positive, non-symbolic dims; everything else is unknown.
        dims = [
            dim.dim_value if (dim.dim_value > 0 and dim.dim_param == "") else None
            for dim in tensor_type.shape.dim
        ]
        # ":" is not kept in the TF-side name: it is replaced and a unique
        # suffix is appended.
        if ":" in vi.name:
            spec_name = vi.name.replace(":", "_tf_") + "_" + get_unique_suffix()
        else:
            spec_name = vi.name
        signatures[vi.name] = tf.TensorSpec(
            dims, data_type.onnx2tf(tensor_type.elem_type), spec_name)

    tf_rep = TensorflowRep()
    tf_rep.inputs = [vi.name for vi in feed_inputs]
    tf_rep.outputs = [vi.name for vi in graph_def.output]
    module.outputs = tf_rep.outputs
    tf_rep.tf_module = module
    tf_rep.signatures = signatures
    return tf_rep
def _onnx_initializer_to_input_dict_items(cls, initializer):
    """ Convert ONNX graph initializer to input dict items.

    :param initializer: ONNX graph initializer, list of TensorProto.
    :return: List of (initializer name, tf.constant) pairs. Each constant is
      built from the flattened values and reshaped to the initializer's dims.
    """
    items = []
    for proto in initializer:
        # numpy_helper is used because the data may be stored raw.
        flat_values = numpy_helper.to_array(proto).flatten().tolist()
        tf_dtype = data_type.onnx2tf(proto.data_type)
        items.append((proto.name,
                      tf.constant(flat_values, shape=proto.dims, dtype=tf_dtype)))
    return items
def false_fn():
    """Recreate the scan-output TensorArrays for a loop that did not run.

    Each new TensorArray is empty (size 0). Every unknown dimension of the
    element shape recorded on the matching initial TensorArray is replaced
    by 0, so stacking the empty array can yield a zero-sized tensor. If the
    recorded element shape has unknown rank, it is kept unknown. The
    original code called range(None) in that case and raised while the
    tf.cond branch was being traced.

    Relies on `body`, `scan_outputs_start_index` and `scan_outputs_init`
    from the enclosing scope.

    :return: list of new, empty tf.TensorArray objects, one per scan output.
    """
    new_scan_outputs = []
    for i in range(scan_outputs_start_index, len(body.output)):
        exp_elem_shape = scan_outputs_init[
            i - scan_outputs_start_index].element_shape
        if exp_elem_shape.rank is None:
            elem_shape = tf.TensorShape(None)
        else:
            elem_shape = tf.TensorShape([
                0 if exp_elem_shape[j] is None else exp_elem_shape[j]
                for j in range(exp_elem_shape.rank)
            ])
        new_scan_outputs.append(
            tf.TensorArray(
                dtype=data_type.onnx2tf(
                    body.output[i].type.tensor_type.elem_type),
                size=0,
                element_shape=elem_shape))
    return new_scan_outputs
def _onnx_initializer_to_input_dict_items(cls, initializer):
    """ Convert ONNX graph initializer to input dict items.

    Each initializer becomes a trainable tf.Variable whose initial value has
    the initializer's own shape (init.dims).

    :param initializer: ONNX graph initializer, list of TensorProto.
    :return: List of (initializer name, tf.Variable) pairs.
    """

    def tensor2list(onnx_tensor):
        # Use the onnx.numpy_helper because the data may be raw
        return numpy_helper.to_array(onnx_tensor).flatten().tolist()

    input_dict = []
    for init in initializer:
        tf_dtype = data_type.onnx2tf(init.data_type)
        # Reshape the flat values to init.dims. Wrapping the flat list in
        # another list would give the variable shape [1, numel]. The
        # `expected_shape` argument is not used: TF1 ignored it, and the
        # TF2 tf.Variable constructor does not accept it.
        initial_value = tf.constant(tensor2list(init),
                                    shape=init.dims,
                                    dtype=tf_dtype)
        input_dict.append((init.name,
                           tf.Variable(
                               initial_value,
                               trainable=True,  # True for prototype, will be based on training info later
                               name=init.name,
                               dtype=tf_dtype)))
    return input_dict
def _onnx_graph_to_tensorflow_rep(cls, graph_def, opset, strict, **kwargs):
    """ Convert ONNX graph to TensorflowRep.

    :param graph_def: ONNX GraphProto object.
    :param opset: ONNX OperatorSetIdProto list.
    :param strict: whether to enforce semantic equivalence between the
      original model and the converted tensorflow model.
    :kwargs: additional arguments to generate tensor_dict for model debugging:
      gen_tensor_dict -- when truthy, tf_rep.tensor_dict is built with
        module.gen_tensor_dict; otherwise it is None. Defaults to False.
      input_tensor_dict -- user-provided input tensors keyed by ONNX input
        name, for models whose inputs have unknown shapes. Any input not
        listed is fed tf.constant(0) filled to the declared shape.
        NOTE(review): that fallback needs a fully known shape. Defaults to {}.
    :return: TensorflowRep object.
    """
    # To generate tensor_dict or not, default is False
    gen_tensor_dict = kwargs.get('gen_tensor_dict', False)
    # User provided input tensors, in the case the model inputs have unknown shapes
    input_tensor_dict = kwargs.get('input_tensor_dict', dict())

    handlers = cls._get_handlers(opset)

    # Names of graph inputs that are backed by an initializer.
    initialized = {init.name for init in graph_def.initializer}

    input_dict = {}
    module = BackendTFModule(handlers, opset, strict, graph_def, cls)
    signatures = {}

    feed_inputs = [vi for vi in graph_def.input if vi.name not in initialized]
    for vi in feed_inputs:
        tensor_type = vi.type.tensor_type
        shape = [
            dim.dim_value if (dim.dim_value > 0 and dim.dim_param == "") else None
            for dim in tensor_type.shape.dim
        ]
        if ":" in vi.name:
            tf_name = vi.name.replace(":", "_tf_") + "_" + get_unique_suffix()
        else:
            tf_name = vi.name
        signatures[vi.name] = tf.TensorSpec(
            shape, data_type.onnx2tf(tensor_type.elem_type), tf_name)

        if not gen_tensor_dict:
            continue
        if vi.name in input_tensor_dict:
            input_dict[vi.name] = input_tensor_dict[vi.name]
        else:
            input_dict[vi.name] = tf.constant(
                0,
                dtype=data_type.onnx2tf(tensor_type.elem_type),
                name=tf_name,
                shape=shape)

    tf_rep = TensorflowRep()
    tf_rep.inputs = [vi.name for vi in feed_inputs]
    tf_rep.outputs = [vi.name for vi in graph_def.output]
    module.outputs = tf_rep.outputs
    tf_rep.tf_module = module
    tf_rep.signatures = signatures
    tf_rep.tensor_dict = (module.gen_tensor_dict(input_dict)
                          if gen_tensor_dict else None)
    tf_rep.onnx_op_list = cls._get_onnx_op_list(graph_def)
    return tf_rep
def _onnx_graph_to_tensorflow_rep(cls, graph_def, opset, strict):
    """ Convert ONNX graph to TensorflowRep.

    Builds a dedicated tf.Graph. Initializers are converted through
    cls._onnx_initializer_to_input_dict_items, every other graph input
    becomes a tf.compat.v1.placeholder, and each ONNX node is then converted
    in order.

    :param graph_def: ONNX GraphProto object.
    :param opset: ONNX OperatorSetIdProto list.
    :param strict: whether to enforce semantic equivalence between the
      original model and the converted tensorflow model.
    :return: TensorflowRep object.
    """
    handlers = cls._get_handlers(opset)

    tf_rep_graph = tf.Graph()
    with tf_rep_graph.as_default():
        # initializer: TensorProtos holding the values of some graph inputs.
        # initialized: names of those inputs, which need no placeholder.
        if graph_def.initializer:
            input_dict_items = cls._onnx_initializer_to_input_dict_items(
                graph_def.initializer)
            initialized = {init.name for init in graph_def.initializer}
        else:
            input_dict_items = []
            initialized = set()

        # Placeholders for the inputs that are not known yet.
        pending = (vi for vi in graph_def.input if vi.name not in initialized)
        for vi in pending:
            tensor_type = vi.type.tensor_type
            dims = [
                dim.dim_value if (dim.dim_value > 0 and dim.dim_param == "") else None
                for dim in tensor_type.shape.dim
            ]
            if ":" in vi.name:
                ph_name = vi.name.replace(":", "_tf_") + "_" + get_unique_suffix()
            else:
                ph_name = vi.name
            placeholder = tf.compat.v1.placeholder(
                data_type.onnx2tf(tensor_type.elem_type), name=ph_name, shape=dims)
            input_dict_items.append((vi.name, placeholder))

        # tensor_dict maps each ONNX tensor name to the most recently produced
        # TF tensor of that name; it is extended as every node is converted.
        tensor_dict = dict(input_dict_items)

        for node in graph_def.node:
            onnx_node = OnnxNode(node)
            output_ops = cls._onnx_node_to_tensorflow_op(onnx_node,
                                                         tensor_dict,
                                                         handlers,
                                                         opset=opset,
                                                         strict=strict)
            tensor_dict.update(zip(onnx_node.outputs, output_ops))

    tf_rep = TensorflowRep()
    tf_rep.graph = tf_rep_graph
    tf_rep.inputs = [
        vi.name for vi in graph_def.input if vi.name not in initialized
    ]
    tf_rep.outputs = [vi.name for vi in graph_def.output]
    tf_rep.tensor_dict = tensor_dict
    return tf_rep
"value": lambda x: MakeNdarray(x.tensor), "seed2": lambda x: float(x.i), "seed": lambda x: float(x.i), "keep_dims": lambda x: int(x.b), "squeeze_dims": lambda x: list(x.list.i), } __onnx_attr_translator = { "axis": lambda x: int(x), "axes": lambda x: [int(a) for a in x], "dtype": lambda x: data_type.onnx2tf(x), "keepdims": lambda x: bool(x), "to": lambda x: data_type.onnx2tf(x), } def translate_tf(key, val): return __tf_attr_translator.get(key, lambda x: x)(val) def translate_onnx(key, val): return __onnx_attr_translator.get(key, lambda x: x)(val) def get_tf_shape_as_list(tf_shape_dim): return list(map(lambda x: x.size, list(tf_shape_dim)))
def scan(cls, node, input_dict, strict):
    """Convert an ONNX Scan node to TF ops, running the body with tf.scan.

    For SINCE_VERSION 8, node.inputs[0] is the optional sequence_lens and
    the inputs carry a leading batch axis. Each batch entry is scanned
    separately via tf.map_fn: its scan inputs are sliced to its sequence
    length, and its outputs are zero-padded back to the full length.

    :param node: Scan node; node.attrs["body"] is the body graph.
    :param input_dict: map from tensor names to TF tensors; it is also
      visible to the body graph.
    :param strict: forwarded to the body graph conversion.
    :return: final values of the N state variables followed by the K scan
      outputs.
    """
    current_opset = [make_opsetid(cls.DOMAIN, cls.VERSION)]

    body = node.attrs["body"]

    # in version 8, node.inputs[0] is the sequence_lens
    node_inputs = node.inputs if cls.SINCE_VERSION != 8 else \
        node.inputs[1:]
    # M
    num_scan_inputs = int(node.attrs["num_scan_inputs"])
    # N = num_inputs - M
    num_state_vars = len(node_inputs) - num_scan_inputs
    # K = num_outputs - N
    num_scan_outputs = len(node.outputs) - num_state_vars

    # Function to run the subgraph, used with tf.scan. `a` is the scan
    # accumulator (only its first N entries, the state variables, are read)
    # and `b` holds the current slice of every scan input.
    def run_subgraph(a, b):
        input_values = dict(input_dict)
        # set the input values for the subgraph
        # set the values for the state variables
        for i in range(num_state_vars):
            input_values[body.input[i].name] = a[i]
        # set the values for the scan inputs
        for i in range(num_scan_inputs):
            input_values[body.input[i + num_state_vars].name] = b[i]

        # get the tensor operations for the onnx graph
        input_values = onnx_tf.backend.onnx_graph_to_tensorflow_ops(
            subgraph=body,
            tensor_dict=input_values,
            opset=current_opset,
            strict=strict)
        # return sequence of tensors for every subgraph output
        outputs = [input_values[output.name] for output in body.output]
        return outputs

    scan_input_axes = node.attrs.get("scan_input_axes",
                                     [0] * num_scan_inputs)
    scan_input_directions = node.attrs.get(
        "directions" if cls.SINCE_VERSION == 8 else "scan_input_directions",
        [0] * num_scan_inputs)
    scan_output_axes = node.attrs.get("scan_output_axes",
                                      [0] * num_scan_outputs)
    scan_output_directions = node.attrs.get("scan_output_directions",
                                            [0] * num_scan_outputs)

    # if version 8 read the sequence_lens from the first input
    if cls.SINCE_VERSION == 8:
        sequence_lens = input_dict[node.inputs[0]] \
            if node.inputs[0] != '' else None

    inputs = [input_dict[node_input] for node_input in node_inputs]

    scan_inputs = inputs[num_state_vars:]
    # loop over all the scan inputs and apply transpose depending
    # on input axes provided and also reverse the scan inputs if
    # reverse direction for scan is provided
    for i in range(num_scan_inputs):
        # if input axes are different than 0, use transpose to scan over
        # the provided axes
        if scan_input_axes[i] != 0:
            transpose_perm = cls._calc_transpose_perm_input(
                tf.rank(scan_inputs[i]), scan_input_axes[i])
            scan_inputs[i] = tf.transpose(scan_inputs[i], transpose_perm)

        # check for reverse direction scans
        if scan_input_directions[i] == 1:
            # version 8 has a batch dimension
            axis = 0 if cls.SINCE_VERSION != 8 else 1
            scan_inputs[i] = tf.reverse(scan_inputs[i], [axis])

    state_vars_init = inputs[:num_state_vars]

    scan_outputs_init = []
    # generate sequence of zero tensors for all scan outputs
    # with the correct shape and dtype
    for scan_output in body.output[num_state_vars:]:
        tensor_type = scan_output.type.tensor_type
        shape = [
            d.dim_value if (d.dim_value > 0 and d.dim_param == "") else None
            for d in tensor_type.shape.dim
        ]
        dtype = data_type.onnx2tf(tensor_type.elem_type)
        scan_outputs_init.append(tf.zeros(shape, dtype=dtype))

    # tf.scan initializer is state_variables_init + scan_outputs_init
    initializer = state_vars_init + scan_outputs_init

    if cls.SINCE_VERSION == 8:
        # version == 8
        # function to process the batches. it is used with tf.map_fn
        def run_batches(x):
            # state vars initial values per batch
            initial = x[0]
            # scan inputs per batch
            scan_inputs = x[1]
            # sequence length for the batch
            seq_len = x[2]

            # slice the input to the current sequence len
            scan_inputs = [
                scan_input[:seq_len, ...] for scan_input in scan_inputs
            ]

            # run scan on the current batch
            out = tf.scan(run_subgraph,
                          scan_inputs,
                          initializer=initial + scan_outputs_init)

            # pad to the original shape with zeros
            paddings = [[
                0, tf.shape(x[1][0], out_type=seq_len.dtype)[0] - seq_len
            ]]
            for i in range(len(out)):
                # The zero rows for the remaining axes must use the same
                # dtype as `paddings` (seq_len's dtype). tf.concat does not
                # mix dtypes, so an int64 sequence_lens input would clash
                # with int32 zeros.
                pads = tf.concat([
                    paddings,
                    tf.zeros([(tf.rank(out[i]) - 1), 2], dtype=seq_len.dtype)
                ], axis=0)
                out[i] = tf.pad(out[i], pads)
            return out

        if sequence_lens is None:
            # if sequence_lens is None, fill it with the shape of
            # the input axis 1
            sequence_lens = tf.fill(
                [tf.shape(scan_inputs[0])[0]],
                tf.shape(scan_inputs[0], out_type=tf.int32)[1])

        output_types = [
            data_type.onnx2tf(output.type.tensor_type.elem_type)
            for output in body.output
        ]
        # run scan for every batch
        out = tf.map_fn(run_batches,
                        (state_vars_init, scan_inputs, sequence_lens),
                        dtype=output_types)

        state_vars_outputs = []
        # extract the final values of the state variables
        for state_var in out[:num_state_vars]:
            state_vars_outputs.append(
                tf.map_fn(lambda x: x[0][x[1] - 1],
                          (state_var, sequence_lens), state_var.dtype))
    else:
        # version > 8
        # run the scan
        out = tf.scan(run_subgraph, scan_inputs, initializer=initializer)

        # extract the final values of the state variables
        state_vars_outputs = [
            state_var[tf.shape(state_var)[0] - 1]
            for state_var in out[:num_state_vars]
        ]

    scan_outputs = out[num_state_vars:]

    # post process the scan outputs depending on the directions and
    # axes provided.
    for i in range(num_scan_outputs):
        # check for reverse direction scan outputs
        if scan_output_directions[i] == 1:
            scan_outputs[i] = tf.reverse(scan_outputs[i], [0])

        if scan_output_axes[i] != 0:
            transpose_perm = cls._calc_transpose_perm_output(
                tf.rank(scan_outputs[i]), scan_output_axes[i])
            scan_outputs[i] = tf.transpose(scan_outputs[i], transpose_perm)

    return state_vars_outputs + scan_outputs
def _common(cls, node, **kwargs):
    """Convert an ONNX Loop node with tf.while_loop.

    node.inputs is (M, cond, initial loop-carried values...), where M and
    cond may be "" (omitted). The body graph receives (iteration count,
    cond, loop-carried values...) and produces (cond, loop-carried
    values..., scan outputs...).

    :return: the final loop-carried values followed by one stacked tensor
      per scan output.
    """
    body = node.attrs["body"]
    tensor_dict = kwargs["tensor_dict"]
    M = tensor_dict[node.inputs[0]] if node.inputs[0] != "" else None
    # Clamp the trip count to int32 max before casting it to int32 for
    # tf.while_loop's maximum_iterations.
    M = tf.where(tf.greater(M, tf.int32.max),
                 tf.constant(tf.int32.max, tf.int32),
                 tf.cast(M, tf.int32)) if M is not None else M
    cond_init = tf.cast(tensor_dict[node.inputs[1]],
                        tf.bool) if node.inputs[1] != "" else None
    v_init = [tensor_dict[graph_input] for graph_input in node.inputs[2:]]
    # Loop-carried values may change shape between iterations, so every
    # dimension is relaxed to None. A value of unknown rank keeps a fully
    # unknown shape; range(None) would raise here.
    v_shapes = [
        tf.TensorShape(None) if v.shape.rank is None else
        tf.TensorShape([None] * v.shape.rank) for v in v_init
    ]
    iter_cnt_init = np.int64(0)
    current_opset = [make_opsetid(cls.DOMAIN, cls.VERSION)]

    # outputs of the body will be in this format:
    # (condition, loop carried dependencies..., scan_outputs...)
    scan_outputs_start_index = 1 + len(v_init)
    scan_outputs_init = [
        tf.TensorArray(dtype=data_type.onnx2tf(
            body.output[i].type.tensor_type.elem_type),
                       size=0,
                       dynamic_size=True)
        for i in range(scan_outputs_start_index, len(body.output))
    ]
    scan_outputs_shapes = [tf.TensorShape(None) for o in scan_outputs_init]
    shape_invariants = [
        tf.TensorShape([]),
        tf.TensorShape(None), v_shapes, scan_outputs_shapes
    ]

    def run_subgraph(iter_cnt, cond, v, scan_outputs):
        subgraph_tensor_dict = dict(tensor_dict)
        subgraph_tensor_dict[body.input[0].name] = iter_cnt
        subgraph_tensor_dict[body.input[1].name] = cond
        for i in range(2, len(body.input)):
            subgraph_tensor_dict[body.input[i].name] = v[i - 2]
        subgraph_tensor_dict = onnx_tf.backend.onnx_graph_to_tensorflow_ops(
            subgraph=body, tensor_dict=subgraph_tensor_dict, opset=current_opset)
        outputs = [
            subgraph_tensor_dict[output.name] for output in body.output
        ]
        # Append this iteration's scan outputs to their TensorArrays.
        for i in range(scan_outputs_start_index, len(outputs)):
            s_index = i - scan_outputs_start_index
            insert_index = scan_outputs[s_index].size()
            scan_outputs[s_index] = scan_outputs[s_index].write(
                insert_index, outputs[i])
        iter_cnt += 1
        return iter_cnt, outputs[0], outputs[
            1:scan_outputs_start_index], scan_outputs

    if M is not None and cond_init is None:
        # for loop
        condition = lambda iter_cnt, cond, v, scan_outputs: True
        iter_cnt_final, _, v_final, scan_outputs_final = tf.while_loop(
            cond=condition,
            body=run_subgraph,
            loop_vars=[iter_cnt_init, "", v_init, scan_outputs_init],
            shape_invariants=shape_invariants,
            maximum_iterations=M)
    elif M is None and cond_init is not None:
        # while and do-while loop
        condition = lambda iter_cnt, cond, v, scan_outputs: tf.reduce_all(
            tf.equal(cond, True))
        iter_cnt_final, _, v_final, scan_outputs_final = tf.while_loop(
            cond=condition,
            body=run_subgraph,
            loop_vars=[iter_cnt_init, cond_init, v_init, scan_outputs_init],
            shape_invariants=shape_invariants)
    elif M is not None and cond_init is not None:
        # combine for loop and while loop together
        condition = lambda iter_cnt, cond, v, scan_outputs: tf.reduce_all(
            tf.equal(cond, True))
        iter_cnt_final, _, v_final, scan_outputs_final = tf.while_loop(
            cond=condition,
            body=run_subgraph,
            loop_vars=[iter_cnt_init, cond_init, v_init, scan_outputs_init],
            shape_invariants=shape_invariants,
            maximum_iterations=M)
    else:  # M is None and cond is None
        exception.OP_UNSUPPORTED_EXCEPT(
            "Both M and cond in Loop are not set at the same time",
            "Tensorflow.(PS. if you want to create a do-while loop " +
            "then please set cond to True or 1)")

    if scan_outputs_start_index == len(body.output):
        # there is no scan_output in the body graph
        return v_final

    # if the loop has run >= 1 time then do nothing
    def true_fn():
        return scan_outputs_final

    # if the loop didn't run at all then recreate the scan_outputs'
    # TensorArray and set the unknown element_shape dims to 0.
    # Then tensorflow will allow to append the empty tensor
    # to v_final
    def false_fn():
        new_scan_outputs = []
        for i in range(scan_outputs_start_index, len(body.output)):
            exp_elem_shape = scan_outputs_init[
                i - scan_outputs_start_index].element_shape
            if exp_elem_shape.rank is None:
                # Unknown rank: nothing to zero out. Keep it unknown instead
                # of failing while tf.cond traces this branch.
                elem_shape = tf.TensorShape(None)
            else:
                elem_shape = tf.TensorShape([
                    0 if exp_elem_shape[j] is None else exp_elem_shape[j]
                    for j in range(exp_elem_shape.rank)
                ])
            new_scan_outputs.append(
                tf.TensorArray(dtype=data_type.onnx2tf(
                    body.output[i].type.tensor_type.elem_type),
                               size=0,
                               element_shape=elem_shape))
        return new_scan_outputs

    scan_out_final = tf.cond(tf.greater(iter_cnt_final, 0), true_fn,
                             false_fn)
    scan_outputs_tensors = [o.stack() for o in scan_out_final]
    return v_final + scan_outputs_tensors
def _onnx_graph_to_tensorflow_rep(cls, graph_def, opset, strict, **kwargs):
    """ Convert ONNX graph to TensorflowRep.

    :param graph_def: ONNX GraphProto object.
    :param opset: ONNX OperatorSetIdProto list.
    :param strict: whether to enforce semantic equivalence between the
      original model and the converted tensorflow model.
    :kwargs: additional arguments to generate tensor_dict for model debugging:
      gen_tensor_dict -- build placeholders and convert every node inside a
        dedicated tf.Graph, exposing the result as tf_rep.tensor_dict.
        Defaults to False.
      input_tensor_dict -- user-provided input tensors keyed by ONNX input
        name, used instead of placeholders (e.g. for inputs with unknown
        shapes). Defaults to {}.
      training_mode -- same graph construction as gen_tensor_dict, and also
        attaches the tf.Graph as tf_rep.graph. Defaults to False.
    :return: TensorflowRep object.
    """
    # To generate tensor_dict or not, default is False
    gen_tensor_dict = kwargs.get('gen_tensor_dict', False)
    # User provided input tensors, in the case the model inputs have unknown shapes
    input_tensor_dict = kwargs.get('input_tensor_dict', dict())
    training_mode = kwargs.get('training_mode', False)
    # The TF1-style graph is only populated when either option is requested.
    build_graph = gen_tensor_dict or training_mode

    handlers = cls._get_handlers(opset)

    # Names of graph inputs that are backed by an initializer.
    initialized = {init.name for init in graph_def.initializer}

    input_dict = {}
    module = BackendTFModule(handlers, opset, strict, graph_def, cls)
    signatures = {}

    tf_rep_graph = tf.Graph()
    with tf_rep_graph.as_default():
        for value_info in graph_def.input:
            if value_info.name in initialized:
                continue
            tensor_type = value_info.type.tensor_type
            shape = [
                dim.dim_value if (dim.dim_value > 0 and dim.dim_param == "") else None
                for dim in tensor_type.shape.dim
            ]
            if ":" in value_info.name:
                value_info_name = (value_info.name.replace(":", "_tf_") + "_" +
                                   get_unique_suffix())
            else:
                value_info_name = value_info.name
            signatures[value_info.name] = tf.TensorSpec(
                shape, data_type.onnx2tf(tensor_type.elem_type),
                value_info_name)

            if not build_graph:
                continue
            if value_info.name in input_tensor_dict:
                input_dict[value_info.name] = input_tensor_dict[value_info.name]
            else:
                input_dict[value_info.name] = tf.compat.v1.placeholder(
                    data_type.onnx2tf(tensor_type.elem_type),
                    name=value_info_name,
                    shape=shape)

        if build_graph:
            input_dict_items = cls._onnx_initializer_to_input_dict_items(
                graph_def.initializer, training_mode=True)
            tensor_dict = dict(input_dict)
            tensor_dict.update(input_dict_items)
            # Scalar bool flag defaulting to False, stored under
            # training_flag_name.
            tensor_dict[training_flag_name] = (
                tf.compat.v1.placeholder_with_default(False, shape=[]))
            for node in graph_def.node:
                onnx_node = OnnxNode(node)
                output_ops = cls._onnx_node_to_tensorflow_op(onnx_node,
                                                             tensor_dict,
                                                             handlers,
                                                             opset=opset,
                                                             strict=strict)
                tensor_dict.update(zip(onnx_node.outputs, output_ops))

    tf_rep = TensorflowRep()
    tf_rep.inputs = [
        value_info.name
        for value_info in graph_def.input
        if value_info.name not in initialized
    ]
    tf_rep.outputs = [value_info.name for value_info in graph_def.output]
    module.outputs = tf_rep.outputs
    tf_rep.tf_module = module
    tf_rep.signatures = signatures
    if build_graph:
        tf_rep.tensor_dict = tensor_dict
    if training_mode:
        tf_rep.graph = tf_rep_graph
    tf_rep.onnx_op_list = cls._get_onnx_op_list(graph_def)
    return tf_rep
def _common(cls, node, **kwargs):
    """Convert an ONNX Loop node with tf.while_loop.

    node.inputs is (M, cond, initial loop-carried values...). The body graph
    receives (iteration number, cond, loop-carried values...) and produces
    (cond, loop-carried values..., scan outputs...).

    :return: flat list with the final loop-carried values followed by one
      stacked tensor per scan output, i.e. one entry per node output.
    """
    body = node.attrs["body"]
    tensor_dict = kwargs["tensor_dict"]
    # An omitted optional input is detected by dtype: M is used only when
    # its tensor is int64, and cond only when its tensor is not a string.
    # NOTE(review): presumably omitted inputs are materialized as string
    # tensors by the caller; confirm.
    M = tensor_dict[node.inputs[0]] if tensor_dict[
        node.inputs[0]].dtype == tf.int64 else None
    cond = None if tensor_dict[
        node.inputs[1]].dtype == tf.string else tf.cast(
            tensor_dict[node.inputs[1]], tf.bool)
    v_initial = [
        tensor_dict[graph_input] for graph_input in node.inputs[2:]
    ]
    v_shapes = [v.get_shape() for v in v_initial]
    # The body's first input is the current iteration number, starting at 0.
    # It was previously fed M (the trip count) on every iteration.
    iter_cnt_initial = np.int64(0)
    current_opset = [make_opsetid(cls.DOMAIN, cls.VERSION)]
    # outputs of the body will be in this format:
    # (condition, loop carried dependencies..., scan_outputs...)
    scan_outputs_start_index = 1 + len(v_initial)
    scan_outputs = [
        tf.TensorArray(dtype=data_type.onnx2tf(
            body.output[i].type.tensor_type.elem_type),
                       size=0,
                       dynamic_size=True)
        for i in range(scan_outputs_start_index, len(body.output))
    ]
    scan_outputs_shapes = [tf.TensorShape(None) for o in scan_outputs]
    shape_invariants = [
        tf.TensorShape([]),
        tf.TensorShape(None), v_shapes, scan_outputs_shapes
    ]

    def run_subgraph(iter_cnt, cond, v, scan_outputs):
        input_values = {}
        input_values[body.input[0].name] = iter_cnt
        input_values[body.input[1].name] = cond
        for i in range(2, len(body.input)):
            input_values[body.input[i].name] = v[i - 2]
        tensor_dict = onnx_tf.backend.onnx_graph_to_tensorflow_ops(
            graph_def=body, input_values=input_values, opset=current_opset)
        outputs = [tensor_dict[output.name] for output in body.output]
        # Append this iteration's scan outputs to their TensorArrays.
        for i in range(scan_outputs_start_index, len(outputs)):
            s_index = i - scan_outputs_start_index
            insert_index = scan_outputs[s_index].size()
            scan_outputs[s_index] = scan_outputs[s_index].write(
                insert_index, outputs[i])
        return iter_cnt + 1, outputs[0], outputs[
            1:scan_outputs_start_index], scan_outputs

    if M is not None and cond is None:
        # for loop
        M = tf.cast(M, tf.int32)
        condition = lambda iter_cnt, cond, v, scan_outputs: True
        _, _, v_final, scan_outputs = tf.while_loop(
            cond=condition,
            body=run_subgraph,
            loop_vars=[iter_cnt_initial, "", v_initial, scan_outputs],
            shape_invariants=shape_invariants,
            maximum_iterations=M)
    elif M is None and cond is not None:
        # while and do-while loop
        condition = lambda iter_cnt, cond, v, scan_outputs: tf.reduce_all(
            tf.equal(cond, True))
        _, _, v_final, scan_outputs = tf.while_loop(
            cond=condition,
            body=run_subgraph,
            loop_vars=[iter_cnt_initial, cond, v_initial, scan_outputs],
            shape_invariants=shape_invariants)
    elif M is not None and cond is not None:
        # combine for loop and while loop together
        M = tf.cast(M, tf.int32)
        condition = lambda iter_cnt, cond, v, scan_outputs: tf.reduce_all(
            tf.equal(cond, True))
        _, _, v_final, scan_outputs = tf.while_loop(
            cond=condition,
            body=run_subgraph,
            loop_vars=[iter_cnt_initial, cond, v_initial, scan_outputs],
            shape_invariants=shape_invariants,
            maximum_iterations=M)
    else:  # M is None and cond is None
        exception.OP_UNSUPPORTED_EXCEPT(
            "Both M and cond in Loop are not set at the same time",
            "Tensorflow.(PS. if you want to create a do-while loop " +
            "then please set cond to True or 1)")

    scan_outputs_tensors = [o.stack() for o in scan_outputs]
    # One tensor per node output: the N final loop-carried values, then the
    # K scan outputs (empty when the body has no scan outputs). Previously
    # these were returned as nested lists.
    return list(v_final) + scan_outputs_tensors