def version_7(cls, ctx, node, **kwargs):
    tfl_while_inputs = node.input
    output_shapes = node.output_shapes
    output_dtypes = node.output_dtypes
    output_names = node.output

    cond_name = node.get_attr_str("cond_subgraph_index")
    cond_graph = find_function(cond_name)
    cond_graph.parent_graph = ctx

    body_name = node.get_attr_str("body_subgraph_index")
    body = find_function(body_name)
    body.parent_graph = ctx

    ctx.remove_node(node.name)

    cond_binding = parameter_binding(cond_graph, tfl_while_inputs)
    cond_outputs = inline_subgraph(ctx, cond_graph, cond_name, cond_binding)

    # Potential scan output candidates are identified in the body subgraph using tfl_scan_output_rewriter.
    # They can then be optimized in this tfl loop handler provided they are not used in the cond subgraph.
    scan_outputs = sorted(body.scan_outputs, reverse=True)

    def input_is_unused(g, index):
        return len(g.find_output_consumers(g.inputs[index])) == 0

    scan_outputs = [(i, out) for i, out in scan_outputs if input_is_unused(cond_graph, i)]

    for idx, _ in scan_outputs:
        del tfl_while_inputs[idx]
        output_shapes.append(output_shapes.pop(idx))
        output_dtypes.append(output_dtypes.pop(idx))
        output_names.append(output_names.pop(idx))

    max_iterations = ctx.make_const(utils.make_name("max_iterations"), np.array(np.iinfo(np.int64).max))

    loop_node = ctx.make_node("Loop", [max_iterations.output[0], cond_outputs[0]] + tfl_while_inputs,
                              output_count=len(output_shapes), name=node.name + "_loop",
                              shapes=output_shapes, dtypes=output_dtypes, skip_conversion=True)

    output_map = dict(zip(output_names, loop_node.output))

    # shift output consumers
    for k, v in output_map.items():
        ctx.replace_all_inputs(k, v)  # ops=ctx.get_nodes()

    body = wire_tfl_while_body(body, loop_node.inputs, output_shapes, output_dtypes, cond_graph, scan_outputs)

    for i in range(len(scan_outputs)):
        squeeze_node = GraphBuilder(body).make_squeeze({'data': body.outputs[-1-i], "axes": [0]}, return_node=True)
        body.outputs[-1-i] = squeeze_node.output[0]

    loop_node.set_body_graph_as_attr("body", body)
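# Illustrative sketch, not tf2onnx source: how the ONNX Loop op stacks scan
# outputs. Any body output beyond the loop-carried vars is concatenated along
# a new axis 0 across iterations, which is presumably why the handler above
# squeezes the leading 1-sized axis from each scan output of the TFLite body
# before wiring it up. All names below are invented for the example.
import onnx
from onnx import helper, TensorProto

iter_num = helper.make_tensor_value_info("iter_num", TensorProto.INT64, [])
cond_in = helper.make_tensor_value_info("cond_in", TensorProto.BOOL, [])
cond_out = helper.make_tensor_value_info("cond_out", TensorProto.BOOL, [])
scan_out = helper.make_tensor_value_info("scan_out", TensorProto.INT64, [])
# body(iter_num, cond_in) -> (cond_out, scan_out); scan_out is iter_num squared
loop_body = helper.make_graph(
    [helper.make_node("Identity", ["cond_in"], ["cond_out"]),
     helper.make_node("Mul", ["iter_num", "iter_num"], ["scan_out"])],
    "loop_body", [iter_num, cond_in], [cond_out, scan_out])

trip = helper.make_tensor_value_info("M", TensorProto.INT64, [])
cond = helper.make_tensor_value_info("cond", TensorProto.BOOL, [])
res = helper.make_tensor_value_info("res", TensorProto.INT64, [None])
loop = helper.make_node("Loop", ["M", "cond"], ["res"], body=loop_body)
model = helper.make_model(helper.make_graph([loop], "main", [trip, cond], [res]))
onnx.checker.check_model(model)
# run with onnxruntime and M=4, cond=True and res should be [0, 1, 4, 9]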
def version_1(cls, ctx, node, **kwargs):
    """V2 control flow - If"""
    inputs = node.input[1:]
    if node.type == "If" and len(inputs) == 0:
        # this comes from the re-writers
        return

    output_shapes = node.output_shapes
    output_dtypes = node.output_dtypes
    ctx.remove_node(node.name)

    # replace the original node
    if_node = ctx.make_node("If", node.input[:1], name=node.name, output_count=len(output_shapes),
                            shapes=output_shapes, dtypes=output_dtypes, skip_conversion=True)

    for branch in ["then_branch", "else_branch"]:
        func_name = node.get_attr_str(branch)
        g = find_function(func_name)
        g.parent_graph = ctx
        wire_if_branch(ctx, g, inputs, output_shapes, output_dtypes, func_name, node.name)
        if_node.set_body_graph_as_attr(branch, g)
def version_1(cls, ctx, node, **kwargs):
    """V2 control flow - If"""
    inputs = node.input[1:]
    output_shapes = node.output_shapes
    output_dtypes = node.output_dtypes
    ctx.remove_node(node.name)

    # replace the original node
    branches = {}
    for branch in ["then_branch", "else_branch"]:
        func_name = node.get_attr_str(branch)
        g = find_function(func_name)
        g.parent_graph = ctx
        wire_if_branch(ctx, g, inputs, output_shapes, output_dtypes, func_name, node.name)
        branches[branch] = g

    _ = ctx.make_node("If", node.input[:1], name=node.name, output_count=len(output_shapes),
                      shapes=output_shapes, dtypes=output_dtypes, skip_conversion=True, branches=branches)
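# Illustrative sketch, not tf2onnx source: the ONNX If op that both handlers
# above build via ctx.make_node. The condition is a scalar bool input and the
# two branches are subgraph attributes with matching output counts, which is
# what wire_if_branch prepares. Names below are invented for the example.
import onnx
from onnx import helper, TensorProto

one = helper.make_tensor("one", TensorProto.FLOAT, [1], [1.0])
two = helper.make_tensor("two", TensorProto.FLOAT, [1], [2.0])
then_out = helper.make_tensor_value_info("then_out", TensorProto.FLOAT, [1])
else_out = helper.make_tensor_value_info("else_out", TensorProto.FLOAT, [1])
then_g = helper.make_graph([helper.make_node("Constant", [], ["then_out"], value=one)],
                           "then_branch", [], [then_out])
else_g = helper.make_graph([helper.make_node("Constant", [], ["else_out"], value=two)],
                           "else_branch", [], [else_out])

cond = helper.make_tensor_value_info("cond", TensorProto.BOOL, [])
out = helper.make_tensor_value_info("out", TensorProto.FLOAT, [1])
if_node = helper.make_node("If", ["cond"], ["out"],
                           then_branch=then_g, else_branch=else_g)
model = helper.make_model(helper.make_graph([if_node], "main", [cond], [out]))
onnx.checker.check_model(model)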
def version_7(cls, ctx, node, **kwargs):
    tfl_while_inputs = node.input
    output_shapes = node.output_shapes
    output_dtypes = node.output_dtypes
    output_names = node.output

    cond_name = node.get_attr_str("cond_subgraph_index")
    cond_graph = find_function(cond_name)
    cond_graph.parent_graph = ctx

    body_name = node.get_attr_str("body_subgraph_index")
    body = find_function(body_name)
    body.parent_graph = ctx

    ctx.remove_node(node.name)

    cond_binding = parameter_binding(cond_graph, tfl_while_inputs)
    cond_outputs = inline_subgraph(ctx, cond_graph, cond_name, cond_binding)

    max_iterations = ctx.make_const(utils.make_name("max_iterations"), np.array(np.iinfo(np.int64).max))

    loop_node = ctx.make_node("Loop", [max_iterations.output[0], cond_outputs[0]] + tfl_while_inputs,
                              output_count=len(output_shapes), name=node.name + "_loop",
                              shapes=output_shapes, dtypes=output_dtypes, skip_conversion=True)

    output_map = dict(zip(output_names, loop_node.output))

    # shift output consumers
    for k, v in output_map.items():
        ctx.replace_all_inputs(k, v)  # ops=ctx.get_nodes()

    body = wire_tfl_while_body(body, loop_node.inputs, output_shapes, output_dtypes, cond_graph)
    loop_node.set_body_graph_as_attr("body", body)
def version_7(cls, ctx, node, **kwargs):
    # the tensorflow while input is:
    #   loop_counter, max_iterations, [loop_vars]
    # cond and body use the same inputs
    # outputs are identical to inputs
    tf_while_inputs = node.input

    # the onnx loop input is:
    #   max_iterations, cond, [loop_vars]
    # body uses the inputs:
    #   iteration, cond, [loop_vars]
    # the onnx loop output is:
    #   cond [v_final_and_scan_outputs]
    output_shapes = node.output_shapes
    output_dtypes = node.output_dtypes
    # node.output must be copied as some elements
    # may be removed from output_names below
    output_names = node.output.copy()

    # Make maximum_iterations int64 and replace -1(tf) with maxsize(onnx). If the const node has no other
    # consumers, modify it in place. Otherwise, make a new const node and leave the original unchanged.
    maximum_iterations_name = node.input[1]
    maximum_iterations = node.inputs[1].get_tensor_value()
    if maximum_iterations == -1:
        maximum_iterations = np.iinfo(np.int64).max
    consumers = ctx.find_output_consumers(maximum_iterations_name)
    external_consumers = [c for c in consumers if c != node and c.type != 'TensorListReserve']
    if len(external_consumers) == 0:
        ctx.remove_node(node.inputs[1].name)
    else:
        maximum_iterations_name = utils.make_name(node.inputs[1].name)
    ctx.make_const(maximum_iterations_name, np.array(maximum_iterations, dtype=np.int64))
    node.input[1] = maximum_iterations_name

    cond_name = node.get_attr_str("cond")
    cond_graph = find_function(cond_name)
    cond_graph.parent_graph = ctx

    body_name = node.get_attr_str("body")
    body = find_function(body_name)
    body.parent_graph = ctx

    loop_vars = []  # passed into the loop
    body_input_to_state_var = {}  # map from body input name to state var name
    cond_input_to_state_var = {}
    to_remove = []
    input_idx_to_remove = []

    # remove TensorListReserve
    for idx, name in enumerate(tf_while_inputs):
        if idx == 1:
            # onnx does not know maximum_iterations in the body so move this to a state var
            body_input_to_state_var[body.func_inputs[idx]] = maximum_iterations_name
            cond_input_to_state_var[cond_graph.func_inputs[idx]] = maximum_iterations_name
            continue
        if idx < 2:
            # skip [0,1] loop_counter, max_iterations
            continue
        n = node.inputs[idx]
        if n.type in ["TensorListReserve", "TensorListResize"]:
            # there is no equivalent step in onnx and we should remove it
            to_remove.append((idx, n))
            continue
        # tensor arrays we read from can't be loop_vars and we fetch them from the outer context instead
        if body.func_inputs[idx] in body.ta_reads:
            body_input_to_state_var[body.func_inputs[idx]] = name
            cond_input_to_state_var[cond_graph.func_inputs[idx]] = name
            input_idx_to_remove.append(idx)
        else:
            loop_vars.append(name)

    # loop_vars that become state_vars need to be removed from output as well
    for idx in reversed(input_idx_to_remove):
        del output_shapes[idx]
        del output_dtypes[idx]
        del output_names[idx]
        del body.outputs[idx]

    removed_scan_outputs = {}
    # remove tensor arrays that are passed in to the loop
    for idx, n in reversed(to_remove):
        ctx.remove_node(n.name)
        # make the node output bad
        ctx.replace_all_inputs(ctx.get_nodes(), n.output[0], "@@ALLOC")
        del body.func_inputs[idx]
        del cond_graph.func_inputs[idx]
        del tf_while_inputs[idx]
        # save the index of the scan output
        removed_scan_outputs[body.outputs[idx]] = idx
        del body.outputs[idx]
        # FIXME: output shapes may be in the wrong order if there are multiple scan outputs
        output_shapes.append(output_shapes[idx])
        output_dtypes.append(output_dtypes[idx])
        output_names.append(output_names[idx])
        del output_shapes[idx]
        del output_dtypes[idx]
        del output_names[idx]

    utils.make_sure(len(removed_scan_outputs) <= 1,
                    "converter only supports while loops with a single scan output")

    ctx.remove_node(node.name)

    # In onnx 'cond' is a variable, not a function. We need to inject the subgraph into the main graph
    # before the loop and into the body.
    cond_binding = parameter_binding(cond_graph, tf_while_inputs)
    cond_outputs = inline_subgraph(ctx, cond_graph, cond_name, cond_binding)

    # onnx Loop op outputs only loop_vars so we need to shift output dtypes/shapes and consumers
    output_shapes = output_shapes[2:]
    output_dtypes = output_dtypes[2:]
    output_names = output_names[2:]
    loop_node = ctx.make_node("Loop", [maximum_iterations_name, cond_outputs[0]] + loop_vars,
                              output_count=len(output_shapes), name=node.name + "_loop",
                              shapes=output_shapes, dtypes=output_dtypes, skip_conversion=True)

    output_map = dict(zip(output_names, loop_node.output))

    # shift output consumers
    for k, v in output_map.items():
        ctx.replace_all_inputs(ctx.get_nodes(), k, v)

    wire_while_body(ctx, body, loop_node.inputs, body_input_to_state_var, cond_input_to_state_var,
                    output_shapes, output_dtypes, body_name, node.name, cond_graph, tf_while_inputs,
                    removed_scan_outputs)

    # if there was a tensorflow variant type, bind in a real type here
    # FIXME: I don't think this is needed anymore
    for i, n in enumerate(body.inputs):
        if body.get_dtype(n.output[0]) == onnx_pb.TensorProto.UNDEFINED:
            body.set_dtype(n.output[0], ctx.get_dtype(loop_node.input[i]))

    loop_node.set_body_graph_as_attr("body", body)
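# Schematic sketch with toy lists, not tf2onnx code: the input/output shuffle
# the handler above performs. TF's While carries (loop_counter, max_iterations,
# *loop_vars) and returns the same list; ONNX Loop takes (max_iterations, cond,
# *loop_vars) and returns only the loop_vars (plus scan outputs), so shapes,
# dtypes, names and consumers all shift by two positions.
tf_while_inputs = ["loop_counter", "max_iterations", "x", "y"]
cond_output = "cond_0"           # produced by inlining the cond subgraph
loop_vars = tf_while_inputs[2:]  # drop loop_counter and max_iterations

onnx_loop_inputs = ["max_iterations_int64", cond_output] + loop_vars
# body signature: (iteration_num, cond_in, *loop_vars) -> (cond_out, *loop_vars, *scan_outputs)
# Loop outputs:   (*final_loop_vars, *stacked_scan_outputs)
# hence output_shapes = output_shapes[2:] and the consumer remap via output_map
assert len(onnx_loop_inputs) == len(tf_while_inputs)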
def version_7(cls, ctx, node, **kwargs):
    # the tensorflow while input is:
    #   loop_counter, max_iterations, [loop_vars]
    # cond and body use the same inputs
    # outputs are identical to inputs
    tf_while_inputs = node.input

    # the onnx loop input is:
    #   max_iterations, cond, [loop_vars]
    # body uses the inputs:
    #   iteration, cond, [loop_vars]
    # the onnx loop output is:
    #   cond [v_final_and_scan_outputs]
    output_shapes = node.output_shapes
    output_dtypes = node.output_dtypes

    # make maximum_iterations int64 and replace -1(tf) with maxsize(onnx)
    maximum_iterations_name = node.input[1]
    maximum_iterations = node.inputs[1].get_tensor_value()
    ctx.remove_node(node.inputs[1].name)
    if maximum_iterations == -1:
        maximum_iterations = sys.maxsize
    ctx.make_const(maximum_iterations_name, np.array(maximum_iterations, dtype=np.int64))

    cond_name = node.get_attr_str("cond")
    cond_graph = find_function(cond_name)
    cond_graph.parent_graph = ctx

    body_name = node.get_attr_str("body")
    body = find_function(body_name)
    body.parent_graph = ctx

    loop_vars = []   # passed into the loop
    state_vars = {}  # comes from the outer context
    to_remove = []
    input_idx_to_remove = []

    # remove TensorListReserve
    for idx, name in enumerate(tf_while_inputs):
        if idx == 1:
            # onnx does not know maximum_iterations in the body so move this to a state var
            state_vars[body.func_inputs[idx]] = maximum_iterations_name
            continue
        if idx < 2:
            # skip [0,1] loop_counter, max_iterations
            continue
        n = node.inputs[idx]
        if n.type in ["TensorListReserve", "TensorListResize"]:
            # There is no equivalent step in onnx and we should remove it.
            # But we make this an identity to keep the loop_vars the same on input and output
            # of the body; there should be no access to this argument in the body.
            to_remove.append((idx, n))
            continue
        # tensor arrays we read from can't be loop_vars and we fetch them from the outer context instead
        if body.func_inputs[idx] in body.ta_reads:
            state_vars[body.func_inputs[idx]] = name
            input_idx_to_remove.append(idx)
        else:
            loop_vars.append(name)

    # loop_vars that become state_vars need to be removed from output as well
    for idx in reversed(input_idx_to_remove):
        del output_shapes[idx]
        del output_dtypes[idx]
        del body.outputs[idx]

    # remove tensor arrays that are passed in to the loop
    for idx, n in reversed(to_remove):
        ctx.remove_node(n.name)
        # make the node output bad
        ctx.replace_all_inputs(ctx.get_nodes(), n.output[0], "@@ALLOC")
        del body.func_inputs[idx]
        del cond_graph.func_inputs[idx]
        del tf_while_inputs[idx]

    ctx.remove_node(node.name)

    # In onnx 'cond' is a variable, not a function. We need to inject the subgraph into the main graph
    # before the loop and into the body.
    cond_binding = parameter_binding(cond_graph, tf_while_inputs)
    cond_outputs = inline_subgraph(ctx, cond_graph, cond_name, cond_binding)

    # onnx Loop op outputs only loop_vars so we need to shift output dtypes/shapes and consumers
    output_map = {node.output[i + 2]: node.output[i] for i in range(len(node.output) - 2)}
    output_shapes = output_shapes[2:]
    output_dtypes = output_dtypes[2:]

    loop_node = ctx.make_node("Loop", [maximum_iterations_name, cond_outputs[0]] + loop_vars,
                              output_count=len(output_shapes), name=node.name,
                              shapes=output_shapes, dtypes=output_dtypes, skip_conversion=True)

    # shift output consumers
    for k, v in output_map.items():
        ctx.replace_all_inputs(ctx.get_nodes(), k, v)

    wire_while_body(ctx, body, loop_node.inputs, state_vars, output_shapes, output_dtypes,
                    body_name, node.name, cond_graph, tf_while_inputs)

    # if there was a tensorflow variant type, bind in a real type here
    for i, n in enumerate(body.inputs):
        if body.get_dtype(n.output[0]) == onnx_pb.TensorProto.UNDEFINED:
            body.set_dtype(n.output[0], ctx.get_dtype(loop_node.input[i]))

    loop_node.set_body_graph_as_attr("body", body)
def version_7(cls, ctx, node, **kwargs):
    # the tensorflow while input is:
    #   loop_counter, max_iterations, [loop_vars]
    # cond and body use the same inputs
    # outputs are identical to inputs
    tf_while_inputs = node.input

    # the onnx loop input is:
    #   max_iterations, cond, [loop_vars]
    # body uses the inputs:
    #   iteration, cond, [loop_vars]
    # the onnx loop output is:
    #   cond [v_final_and_scan_outputs]
    output_shapes = node.output_shapes
    output_dtypes = node.output_dtypes
    # node.output must be copied as some elements
    # may be removed from output_names below
    output_names = node.output.copy()

    # Make maximum_iterations int64 and replace -1(tf) with maxsize(onnx). If the const node has no other
    # consumers, modify it in place. Otherwise, make a new const node and leave the original unchanged.
    # If maximum_iterations is not a const, add a Cast node (cast to int64).
    maximum_iterations_name = node.input[1]
    if node.inputs[1].is_const():
        maximum_iterations = node.inputs[1].get_tensor_value()
        if maximum_iterations == -1:
            maximum_iterations = np.iinfo(np.int64).max
        consumers = ctx.find_output_consumers(maximum_iterations_name)
        external_consumers = [c for c in consumers if c != node and c.type != 'TensorListReserve']
        if len(external_consumers) == 0:
            ctx.remove_node(node.inputs[1].name)
        else:
            maximum_iterations_name = utils.make_name(node.inputs[1].name)
        ctx.make_const(maximum_iterations_name, np.array(maximum_iterations, dtype=np.int64))
        ctx.replace_input(node, node.input[1], maximum_iterations_name, 1)
        maximum_iterations_int64 = maximum_iterations_name
    else:
        cast_inputs = [maximum_iterations_name]
        attr = {"to": onnx_pb.TensorProto.INT64}
        cast_name = node.name + "_cast"
        cast_node = ctx.make_node("Cast", cast_inputs, attr, name=cast_name)
        maximum_iterations_int64 = cast_node.output[0]

    cond_name = node.get_attr_str("cond")
    cond_graph = find_function(cond_name)
    cond_graph.parent_graph = ctx

    body_name = node.get_attr_str("body")
    body = find_function(body_name)
    body.parent_graph = ctx

    loop_vars = []                # passed into the loop
    body_input_to_state_var = {}  # map from body input name to state var name
    cond_input_to_state_var = {}
    scan_outputs = []
    input_idx_to_remove = []
    idx_to_ragged_writes = dict(body.ragged_variant_list_writes)

    # remove TensorListReserve
    for idx, name in enumerate(tf_while_inputs):
        if idx == 1:
            # onnx does not know maximum_iterations in the body so move this to a state var
            body_input_to_state_var[body.input_names[idx]] = maximum_iterations_name
            cond_input_to_state_var[cond_graph.input_names[idx]] = maximum_iterations_name
            continue
        if idx < 2:
            # skip [0,1] loop_counter, max_iterations
            continue
        n = node.inputs[idx]
        if n.type in ["TensorListReserve", "TensorListResize"]:
            # there is no equivalent step in onnx and we should remove it
            output_shape = None
            output_dtype = n.get_attr_value("element_dtype")
            is_ragged = False
            if n.type == "TensorListReserve" and n.inputs[0].is_const() and not n.inputs[0].is_scalar():
                output_shape = [-1] + n.inputs[0].get_tensor_value(as_list=True)
            if idx in idx_to_ragged_writes:
                output_shape = None
                output_dtype = body.get_dtype(idx_to_ragged_writes[idx].input[0])
                is_ragged = True
            loop_vars.append(name)
            scan_outputs.append((idx, n, output_shape, output_dtype, is_ragged))
            continue
        # tensor arrays we read from can't be loop_vars and we fetch them from the outer context instead
        if body.input_names[idx] in body.ta_reads:
            body_input_to_state_var[body.input_names[idx]] = name
            cond_input_to_state_var[cond_graph.input_names[idx]] = name
            input_idx_to_remove.append(idx)
        else:
            loop_vars.append(name)

    # loop_vars that become state_vars need to be removed from output as well
    for idx in reversed(input_idx_to_remove):
        del output_shapes[idx]
        del output_dtypes[idx]
        del output_names[idx]
        del body.outputs[idx]

    scan_output_names = []
    ragged_scan_output_names = []
    ragged_scan_output_to_len = {}

    # remove tensor arrays that are passed in to the loop
    for idx, n, output_shape, output_dtype, is_ragged in reversed(scan_outputs):
        if is_ragged:
            out = n.output[0]
            ctx.remove_node(n.name)
            seq_empty = ctx.make_node("SequenceEmpty", [], attr={'dtype': output_dtype}, name=n.name,
                                      outputs=[out], shapes=[None], dtypes=[utils.SeqType(output_dtype)])
            ctx.replace_all_inputs(n.output[0], seq_empty.output[0])
            # ragged tensors also must track the length of each row
            output_shapes.append([-1])
            output_dtypes.append(TensorProto.INT64)
            output_shapes[idx] = None
            output_dtypes[idx] = utils.SeqType(output_dtype)
            body_ragged_name = utils.make_name("ragged_scan_output")
            external_ragged_name = utils.make_name("ragged_output")
            scan_output_names.append(body_ragged_name)
            output_names.append(external_ragged_name)
            ragged_scan_output_names.append(body_ragged_name)
            ragged_scan_output_to_len[output_names[idx]] = external_ragged_name
            continue
        ctx.remove_node(n.name)
        # make the node output bad
        ctx.replace_all_inputs(n.output[0], "@@ALLOC")  # ops=ctx.get_nodes()
        del body.inputs[idx]
        del cond_graph.inputs[idx]
        del tf_while_inputs[idx]
        scan_output_names.append(body.outputs[idx])
        del body.outputs[idx]
        output_shapes.append(output_shape)
        output_dtypes.append(output_dtype)
        output_names.append(output_names[idx])
        del output_shapes[idx]
        del output_dtypes[idx]
        del output_names[idx]

    ctx.remove_node(node.name)

    # In onnx 'cond' is a variable, not a function. We need to inject the subgraph into the main graph
    # before the loop and into the body.
    cond_binding = parameter_binding(cond_graph, tf_while_inputs)
    cond_outputs = inline_subgraph(ctx, cond_graph, cond_name, cond_binding)

    # onnx Loop op outputs only loop_vars so we need to shift output dtypes/shapes and consumers
    output_shapes = output_shapes[2:]
    output_dtypes = output_dtypes[2:]
    output_names = output_names[2:]

    branches = {"body": body}
    loop_node = ctx.make_node("Loop", [maximum_iterations_int64, cond_outputs[0]] + loop_vars,
                              output_count=len(output_shapes), name=node.name + "_loop",
                              shapes=output_shapes, dtypes=output_dtypes, skip_conversion=True,
                              branches=branches)

    output_map = dict(zip(output_names, loop_node.output))

    # shift output consumers
    for k, v in output_map.items():
        if k not in ragged_scan_output_to_len.values():
            ctx.replace_all_inputs(k, v)  # ops=ctx.get_nodes()

    ragged_scan_output_to_len = {output_map[k]: output_map[v]
                                 for k, v in ragged_scan_output_to_len.items()}

    wire_while_body(ctx, body, loop_node, body_input_to_state_var, cond_input_to_state_var,
                    output_shapes, output_dtypes, body_name, node.name, cond_graph, tf_while_inputs,
                    scan_output_names, ragged_scan_output_names)
    loop_node.ragged_scan_output_to_len = ragged_scan_output_to_len

    # if there was a tensorflow variant type, bind in a real type here
    # FIXME: I don't think this is needed anymore
    for i, n in enumerate(body.inputs):
        if body.get_dtype(n.output[0]) == onnx_pb.TensorProto.UNDEFINED:
            body.set_dtype(n.output[0], ctx.get_dtype(loop_node.input[i]))
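# Illustrative sketch, not tf2onnx source: the sequence ops the ragged path
# above relies on. Rows of varying length are accumulated in a tensor sequence
# (SequenceEmpty/SequenceInsert) instead of a dense scan output, and per-row
# lengths are tracked in a separate INT64 output so the ragged structure can be
# rebuilt later. Names below are invented; opset >= 11 assumed.
import onnx
from onnx import helper, TensorProto

row = helper.make_tensor_value_info("row", TensorProto.FLOAT, [None])
flat = helper.make_tensor_value_info("flat", TensorProto.FLOAT, [None])
nodes = [
    helper.make_node("SequenceEmpty", [], ["seq0"], dtype=TensorProto.FLOAT),
    helper.make_node("SequenceInsert", ["seq0", "row"], ["seq1"]),
    helper.make_node("SequenceInsert", ["seq1", "row"], ["seq2"]),
    # flatten back to a single tensor; the row lengths travel separately
    helper.make_node("ConcatFromSequence", ["seq2"], ["flat"], axis=0),
]
model = helper.make_model(helper.make_graph(nodes, "ragged_sketch", [row], [flat]))
onnx.checker.check_model(model)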
def rewrite_gru_tf2(g, ops):
    pattern1 = make_grucell_pattern("Identity")
    pattern2 = keras_gru_pattern

    for pattern in [pattern1, pattern2]:
        matcher = GraphMatcher(pattern, allow_reorder=True)
        match_results = list(matcher.match_ops(ops))
        for match_result in match_results:
            activation_op = match_result.get_op("optional_activation")
            activations = ["Sigmoid", activation_op.type]
            if activation_op.type not in ["Relu", "Tanh", "Sigmoid"]:
                continue

            if pattern is pattern1:
                concat = match_result.get_op("cell_inputs")
                if len(concat.inputs) != 3:
                    continue
                get_item = concat.inputs[0]
                init_state = concat.inputs[1]
            else:
                get_item = match_result.get_op("gru_input")
                init_state = match_result.get_op("state")
            if not get_item.type == "TensorListGetItem":
                continue
            x_e = get_item.inputs[0]
            if not x_e.is_graph_input():
                continue
            x_idx = g.input_names.index(x_e.output[0])

            if not init_state.is_graph_input():
                continue
            init_state_idx = g.input_names.index(init_state.output[0])

            cell_output = match_result.get_op("cell_output")
            final_consumers = g.find_output_consumers(cell_output.output[0])
            select_ops = [n for n in final_consumers if n.type == "Select"]

            def has_tensor_list_consumer(n):
                return any(c.type == "TensorListSetItem" for c in g.find_output_consumers(n.output[0]))

            select_ops = [n for n in select_ops if has_tensor_list_consumer(n)]
            if len(select_ops) == 1:
                greater_eq = select_ops[0].inputs[0]
                if greater_eq.type != "GreaterEqual":
                    continue
                seq_len = greater_eq.inputs[1]
                if not seq_len.is_graph_input():
                    continue
                seq_len_idx = g.input_names.index(seq_len.output[0])
                final_consumers = g.find_output_consumers(select_ops[0].output[0])
            else:
                seq_len_idx = None

            tensor_set_items = [n for n in final_consumers if n.type == "TensorListSetItem"]
            if len(tensor_set_items) != 1:
                continue
            if not tensor_set_items[0].inputs[0].is_graph_input():
                continue
            out_idx = g.input_names.index(tensor_set_items[0].input[0])

            hk = match_result.get_op("hidden_kernel")
            while hk.type == "Identity":
                hk = hk.inputs[0]
            if not hk.is_graph_input():
                continue
            hk_idx = g.input_names.index(hk.output[0])

            hb = match_result.get_op("hidden_bias")
            if not hb.is_graph_input():
                continue
            hb_idx = g.input_names.index(hb.output[0])

            gk = match_result.get_op("gate_kernel")
            while gk.type == "Identity":
                gk = gk.inputs[0]
            if not gk.is_graph_input():
                continue
            gk_idx = g.input_names.index(gk.output[0])

            gb = match_result.get_op("gate_bias")
            if not gb.is_graph_input():
                continue
            gb_idx = g.input_names.index(gb.output[0])

            bias_add = match_result.get_op("bias_add")
            if bias_add is not None and bias_add.data_format != "NHWC":
                continue

            g.gru_rewriter_context = {
                "x_idx": x_idx,
                "out_idx": out_idx,
                "initial_state_idx": init_state_idx,
                "hidden_kernel_idx": hk_idx,
                "hidden_bias_idx": hb_idx,
                "gate_kernel_idx": gk_idx,
                "gate_bias_idx": gb_idx,
                "seq_len_idx": seq_len_idx,
                "activations": activations,
                "from_keras": pattern is pattern2,
                "linear_before_reset": 1 if pattern is pattern2 else 0,
            }

    for op in ops:
        if op.is_while():
            body_graph = find_function(op.get_attr_str("body"))
            if body_graph.gru_rewriter_context is None:
                continue
            body_context = body_graph.gru_rewriter_context

            hk = op.input[body_context["hidden_kernel_idx"]]
            hb = op.input[body_context["hidden_bias_idx"]]
            gk = op.input[body_context["gate_kernel_idx"]]
            gb = op.input[body_context["gate_bias_idx"]]
            if not all(g.is_const(w) for w in [hk, hb, gk, gb]):
                continue
            hk_const = g.get_tensor_value(hk, as_list=False)
            hb_const = g.get_tensor_value(hb, as_list=False)
            gk_const = g.get_tensor_value(gk, as_list=False)
            gb_const = g.get_tensor_value(gb, as_list=False)

            initial_state_sq = op.input[body_context["initial_state_idx"]]
            initial_state = GraphBuilder(g).make_unsqueeze({"data": initial_state_sq, "axes": [0]})

            context = UnitRnnContext()
            context.from_keras = body_context["from_keras"]
            context.weights.update({
                "hidden_kernel": hk_const,
                "hidden_bias": hb_const,
                "gate_kernel": gk_const,
                "gate_bias": gb_const
            })
            context.attributes["activations"] = body_context["activations"]
            context.attributes["linear_before_reset"] = body_context["linear_before_reset"]

            tensor_array_inp = op.inputs[body_context["x_idx"]]
            if not tensor_array_inp.type == "TensorListFromTensor":
                continue

            final_consumers = g.find_output_consumers(op.output[body_context["out_idx"]])
            output_ys = [n.output[0] for n in final_consumers if n.type == "TensorListStack"]

            context.onnx_input_ids["X"] = tensor_array_inp.input[0]
            if body_context["seq_len_idx"] is None:
                context.onnx_input_ids["sequence_lens"] = ""
            else:
                context.onnx_input_ids["sequence_lens"] = op.input[body_context["seq_len_idx"]]
            context.onnx_input_ids["initial_state"] = initial_state

            gru_rewriter = GRUUnitRewriter(g)
            gru_rewriter.process_weights_and_bias(context)
            gru_node = gru_rewriter.create_rnn_node(context)

            squeeze_output = GraphBuilder(g).make_squeeze({"data": gru_node.output[0], "axes": [1]})
            for output in output_ys:
                g.replace_all_inputs(output, squeeze_output)

            f_state_squeeze = GraphBuilder(g).make_squeeze({"data": gru_node.output[1], "axes": [0]})
            g.replace_all_inputs(op.output[body_context["initial_state_idx"]], f_state_squeeze)

    return g.get_nodes()
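# Illustrative sketch with toy shapes, not tf2onnx source: the ONNX GRU node
# the rewriter ultimately emits. GRU takes two activations per direction (here
# the fixed Sigmoid plus the matched cell activation), and linear_before_reset=1
# corresponds to the Keras GRU formulation, matching the from_keras /
# linear_before_reset fields collected above. All names are invented.
import onnx
from onnx import helper, TensorProto

seq_len, batch, input_size, hidden = 5, 2, 3, 4
X = helper.make_tensor_value_info("X", TensorProto.FLOAT, [seq_len, batch, input_size])
Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [seq_len, 1, batch, hidden])
W = helper.make_tensor("W", TensorProto.FLOAT, [1, 3 * hidden, input_size],
                       [0.0] * (3 * hidden * input_size))
R = helper.make_tensor("R", TensorProto.FLOAT, [1, 3 * hidden, hidden],
                       [0.0] * (3 * hidden * hidden))
gru = helper.make_node("GRU", ["X", "W", "R"], ["Y"], hidden_size=hidden,
                       activations=["Sigmoid", "Tanh"], linear_before_reset=1)
graph = helper.make_graph([gru], "gru_sketch", [X], [Y], initializer=[W, R])
onnx.checker.check_model(helper.make_model(graph))
# Y is [seq_length, num_directions, batch, hidden]; the rewriter squeezes
# axis 1 (num_directions) before handing it to the TensorListStack consumers.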
def rewriter_lstm_tf2(g, ops):
    pattern1 = make_lstm_pattern(enter_or_id="Identity")             # TF LSTM
    pattern2 = make_lstm_pattern(from_keras=True, use_bias=False)    # keras LSTM
    pattern3 = make_lstm_pattern(from_keras=True, use_bias=True)     # keras LSTM with bias

    for pattern in [pattern1, pattern2, pattern3]:
        matcher = GraphMatcher(pattern, allow_reorder=False)
        match_results = list(matcher.match_ops(ops))
        for match_result in match_results:
            from_keras = pattern != pattern1
            activations_fgh = [
                match_result.get_op("ft").type,
                match_result.get_op("gt").type,
                match_result.get_op("ct'").type
            ]
            supported_activations = ['Relu', 'Sigmoid', 'Tanh']
            if any(f not in supported_activations for f in activations_fgh):
                continue

            # extract input x_t
            if from_keras:
                get_item = match_result.get_op("xt")
            else:
                concat = match_result.get_op("xh")
                if len(concat.inputs) != 3:
                    continue
                get_item = concat.inputs[0]
            if not get_item.type == "TensorListGetItem":
                continue
            x_e = get_item.inputs[0]
            if not x_e.is_graph_input():
                continue
            x_idx = g.input_names.index(x_e.output[0])

            # extract output h_t
            ht_mul = match_result.get_op("ht")
            final_consumers = g.find_output_consumers(ht_mul.output[0])
            select_ops = [n for n in final_consumers if n.type == "Select"]

            def has_tensor_list_consumer(n):
                return any(c.type == "TensorListSetItem" for c in g.find_output_consumers(n.output[0]))

            select_ops = [n for n in select_ops if has_tensor_list_consumer(n)]
            if len(select_ops) == 1:
                greater_eq = select_ops[0].inputs[0]
                if greater_eq.type != "GreaterEqual":
                    continue
                seq_len = greater_eq.inputs[1]
                if not seq_len.is_graph_input():
                    continue
                seq_len_idx = g.input_names.index(seq_len.output[0])
                final_consumers = g.find_output_consumers(select_ops[0].output[0])
            else:
                seq_len_idx = None

            tensor_set_items = [n for n in final_consumers if n.type == "TensorListSetItem"]
            if len(tensor_set_items) != 1:
                continue
            if not tensor_set_items[0].inputs[0].is_graph_input():
                continue
            out_idx = g.input_names.index(tensor_set_items[0].input[0])

            # extract input h_(t-1) and c_(t-1)
            init_state = match_result.get_op("ht-1") if from_keras else concat.inputs[1]
            if init_state.is_graph_input():
                # c and h are separate
                h_idx = g.input_names.index(init_state.output[0])
                c_e = match_result.get_op("c")
                if not c_e.is_graph_input():
                    continue
                c_idx = g.input_names.index(c_e.output[0])
                ch_info = {
                    "state_is_tuple": True,
                    "c_idx": c_idx,
                    "h_idx": h_idx,
                }
            else:
                # c and h are concatenated
                if not init_state.type == "Slice":
                    continue
                ch_e = init_state.inputs[0]
                if not ch_e.is_graph_input():
                    continue
                ch_idx = g.input_names.index(ch_e.output[0])
                c_e = match_result.get_op("c")
                if not c_e.type == "Slice" or c_e.input[0] != ch_e.output[0]:
                    continue
                ch_info = {
                    "state_is_tuple": False,
                    "ch_idx": ch_idx,
                }

            # extract weights and bias
            w_idx = hk_idx = gk_idx = 0
            ft_bias = None
            if from_keras:
                # hidden kernel
                hk = match_result.get_op("R")
                while hk.type == "Identity":
                    hk = hk.inputs[0]
                if not hk.is_graph_input():
                    continue
                hk_idx = g.input_names.index(hk.output[0])
                # gate kernel
                gk = match_result.get_op("W")
                while gk.type == "Identity":
                    gk = gk.inputs[0]
                if not gk.is_graph_input():
                    continue
                gk_idx = g.input_names.index(gk.output[0])
                # Wb and Rb are concatenated
                b_idx = None
                if pattern is pattern3:
                    bias_add = match_result.get_op("bias_add")
                    if bias_add is not None and bias_add.data_format != "NHWC":
                        continue
                    b_e = match_result.get_op("cell_bias")
                    while b_e.type == "Identity":
                        b_e = b_e.inputs[0]
                    if not b_e.is_graph_input():
                        continue
                    b_idx = g.input_names.index(b_e.output[0])
            else:
                # W and R are concatenated
                w_e = match_result.get_op("cell_kernel")
                if not w_e.is_graph_input():
                    continue
                w_idx = g.input_names.index(w_e.output[0])
                bias_add = match_result.get_op("bias_add")
                if bias_add is not None and bias_add.data_format != "NHWC":
                    continue
                b_e = match_result.get_op("cell_bias")
                if not b_e.is_graph_input():
                    continue
                b_idx = g.input_names.index(b_e.output[0])
                ft_bias_node = match_result.get_op("ft_bias")
                if not ft_bias_node.is_const():
                    continue
                if g.get_dtype(ft_bias_node.output[0]) != g.get_dtype(b_e.output[0]):
                    continue
                ft_bias = ft_bias_node.get_tensor_value(as_list=False)

            g.lstm_rewriter_context = {
                # common
                "x_idx": x_idx,
                "out_idx": out_idx,
                "seq_len_idx": seq_len_idx,
                "bias_idx": b_idx,
                "from_keras": from_keras,
                "activations_fgh": activations_fgh,
                **ch_info,  # {state_is_tuple, h_idx, c_idx} or {state_is_tuple, ch_idx}
                # TF
                "weight_idx": w_idx,
                "ft_bias": ft_bias,
                # Keras
                "w_idx": gk_idx,
                "r_idx": hk_idx,
            }

    for op in ops:
        if op.is_while():
            body_graph = find_function(op.get_attr_str("body"))
            if body_graph.lstm_rewriter_context is None:
                continue
            body_context = body_graph.lstm_rewriter_context

            # parse weights
            consts = []
            if body_context["from_keras"]:
                wx = op.input[body_context["w_idx"]]
                wh = op.input[body_context["r_idx"]]
                wx_const = g.get_tensor_value(wx, as_list=False)
                wh_const = g.get_tensor_value(wh, as_list=False)
                consts.extend([wx, wh])
            else:
                w = op.input[body_context["weight_idx"]]
                w_const = g.get_tensor_value(w, as_list=False)
                consts.append(w)

            # parse bias
            if body_context["bias_idx"] is not None:
                b = op.input[body_context["bias_idx"]]
                b_const = g.get_tensor_value(b, as_list=False)
                consts.append(b)
            else:
                b_const = None
            if not all(g.is_const(c) for c in consts):
                continue

            # parse states
            if body_context["state_is_tuple"]:
                initial_c_sq = op.input[body_context["c_idx"]]
                initial_h_sq = op.input[body_context["h_idx"]]
                initial_c = GraphBuilder(g).make_unsqueeze({"data": initial_c_sq, "axes": [0]})
                initial_h = GraphBuilder(g).make_unsqueeze({"data": initial_h_sq, "axes": [0]})
            else:
                initial_ch = op.input[body_context["ch_idx"]]
                if not g.is_const(initial_ch):
                    continue
                initial_ch_const = g.get_tensor_value(initial_ch, as_list=False)
                if not len(initial_ch_const.shape) == 2:
                    continue
                initial_ch_const = np.expand_dims(initial_ch_const, axis=0)
                initial_c_const, initial_h_const = np.split(initial_ch_const, 2, axis=2)
                initial_c = g.make_const(utils.make_name("initial_c"), initial_c_const).output[0]
                initial_h = g.make_const(utils.make_name("initial_h"), initial_h_const).output[0]

            # build LSTMContext
            context = LSTMContext()
            context.from_keras = body_context["from_keras"]
            if context.from_keras:
                context.weights.append({"w": wx_const, "r": wh_const, "bias": b_const})
            else:
                context.weights.append({"weight": w_const, "bias": b_const,
                                        "ft_bias": body_context["ft_bias"]})
            context.onnx_input_ids.append({})
            context.input_size.append(None)
            context.hidden_size.append(None)
            context.attributes.append({"activations": body_context['activations_fgh']})

            tensor_array_inp = op.inputs[body_context["x_idx"]]
            if not tensor_array_inp.type == "TensorListFromTensor":
                continue

            final_consumers = g.find_output_consumers(op.output[body_context["out_idx"]])
            output_ys = [n.output[0] for n in final_consumers if n.type == "TensorListStack"]

            context.onnx_input_ids[0]["X"] = tensor_array_inp.input[0]
            if body_context["seq_len_idx"] is None:
                context.onnx_input_ids[0]["sequence_lens"] = ""
            else:
                context.onnx_input_ids[0]["sequence_lens"] = op.input[body_context["seq_len_idx"]]
            context.onnx_input_ids[0]["initial_c"] = initial_c
            context.onnx_input_ids[0]["initial_h"] = initial_h

            lstm_rewriter = LSTMRewriter(g)
            lstm_rewriter.num_lstm_layers = 1
            lstm_rewriter.process_weights_and_bias(context)
            lstm_node = lstm_rewriter.create_rnn_node(context)[0]

            squeeze_output = GraphBuilder(g).make_squeeze({"data": lstm_node.output[0], "axes": [1]})
            for output in output_ys:
                g.replace_all_inputs(output, squeeze_output)

            if body_context["state_is_tuple"]:
                c_squeeze = GraphBuilder(g).make_squeeze({"data": lstm_node.output[2], "axes": [0]})
                h_squeeze = GraphBuilder(g).make_squeeze({"data": lstm_node.output[1], "axes": [0]})
                g.replace_all_inputs(op.output[body_context["c_idx"]], c_squeeze)
                g.replace_all_inputs(op.output[body_context["h_idx"]], h_squeeze)
            else:
                concat_ch = g.make_node("Concat", [lstm_node.output[2], lstm_node.output[1]],
                                        attr={"axis": 2}).output[0]
                ch_squeeze = GraphBuilder(g).make_squeeze({"data": concat_ch, "axes": [0]})
                ch_output = op.output[body_context["ch_idx"]]
                g.replace_all_inputs(ch_output, ch_squeeze)

    return g.get_nodes()
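# Small numpy sketch of the non-tuple state handling above (toy values, not
# tf2onnx code): TF packs c and h into one [batch, 2*hidden] tensor; expanding
# to rank 3 and splitting the last axis yields the [1, batch, hidden]
# initial_c/initial_h layout that ONNX LSTM expects.
import numpy as np

batch, hidden = 2, 3
ch = np.arange(batch * 2 * hidden, dtype=np.float32).reshape(batch, 2 * hidden)
ch = np.expand_dims(ch, axis=0)                 # [1, batch, 2*hidden]
initial_c, initial_h = np.split(ch, 2, axis=2)  # each [1, batch, hidden]
assert initial_c.shape == initial_h.shape == (1, batch, hidden)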
def rewriter_lstm_tf2(g, ops):
    pattern1 = make_lstmcell_pattern("Identity")

    for pattern in [pattern1]:
        matcher = GraphMatcher(pattern, allow_reorder=False)
        match_results = list(matcher.match_ops(ops))
        for match_result in match_results:
            concat = match_result.get_op("xh")
            if len(concat.inputs) != 3:
                continue
            get_item = concat.inputs[0]
            if not get_item.type == "TensorListGetItem":
                continue
            x_e = get_item.inputs[0]
            if not x_e.is_graph_input():
                continue
            x_idx = g.input_names.index(x_e.output[0])

            ht_mul = match_result.get_op("ht")
            final_consumers = g.find_output_consumers(ht_mul.output[0])
            select_ops = [n for n in final_consumers if n.type == "Select"]

            def has_tensor_list_consumer(n):
                return any(c.type == "TensorListSetItem" for c in g.find_output_consumers(n.output[0]))

            select_ops = [n for n in select_ops if has_tensor_list_consumer(n)]
            if len(select_ops) == 1:
                greater_eq = select_ops[0].inputs[0]
                if greater_eq.type != "GreaterEqual":
                    continue
                seq_len = greater_eq.inputs[1]
                if not seq_len.is_graph_input():
                    continue
                seq_len_idx = g.input_names.index(seq_len.output[0])
                final_consumers = g.find_output_consumers(select_ops[0].output[0])
            else:
                seq_len_idx = None

            tensor_set_items = [n for n in final_consumers if n.type == "TensorListSetItem"]
            if len(tensor_set_items) != 1:
                continue
            if not tensor_set_items[0].inputs[0].is_graph_input():
                continue
            out_idx = g.input_names.index(tensor_set_items[0].input[0])

            if concat.inputs[1].is_graph_input():
                # c and h are separate
                h_idx = g.input_names.index(concat.input[1])
                c_e = match_result.get_op("c")
                if not c_e.is_graph_input():
                    continue
                c_idx = g.input_names.index(c_e.output[0])
                ch_info = {
                    "state_is_tuple": True,
                    "c_idx": c_idx,
                    "h_idx": h_idx,
                }
            else:
                # c and h are concatenated
                if not concat.inputs[1].type == "Slice":
                    continue
                ch_e = concat.inputs[1].inputs[0]
                if not ch_e.is_graph_input():
                    continue
                ch_idx = g.input_names.index(ch_e.output[0])
                c_e = match_result.get_op("c")
                if not c_e.type == "Slice" or c_e.input[0] != ch_e.output[0]:
                    continue
                ch_info = {
                    "state_is_tuple": False,
                    "ch_idx": ch_idx,
                }

            w_e = match_result.get_op("cell_kernel")
            if not w_e.is_graph_input():
                continue
            w_idx = g.input_names.index(w_e.output[0])

            bias_add = match_result.get_op("bias_add")
            if bias_add is not None and bias_add.data_format != "NHWC":
                continue

            b_e = match_result.get_op("cell_bias")
            if not b_e.is_graph_input():
                continue
            b_idx = g.input_names.index(b_e.output[0])

            ft_bias_node = match_result.get_op("ft_bias")
            if not ft_bias_node.is_const():
                continue
            if g.get_dtype(ft_bias_node.output[0]) != g.get_dtype(b_e.output[0]):
                continue
            ft_bias = ft_bias_node.get_tensor_value(as_list=False)

            g.lstm_rewriter_context = {
                "x_idx": x_idx,
                "out_idx": out_idx,
                "weight_idx": w_idx,
                "bias_idx": b_idx,
                "ft_bias": ft_bias,
                "seq_len_idx": seq_len_idx,
                **ch_info
            }

    for op in ops:
        if op.is_while():
            body_graph = find_function(op.get_attr_str("body"))
            if body_graph.lstm_rewriter_context is None:
                continue
            body_context = body_graph.lstm_rewriter_context

            w = op.input[body_context["weight_idx"]]
            b = op.input[body_context["bias_idx"]]
            if not g.is_const(w) or not g.is_const(b):
                continue
            w_const = g.get_tensor_value(w, as_list=False)
            b_const = g.get_tensor_value(b, as_list=False)

            if body_context["state_is_tuple"]:
                initial_c_sq = op.input[body_context["c_idx"]]
                initial_h_sq = op.input[body_context["h_idx"]]
                initial_c = GraphBuilder(g).make_unsqueeze({"data": initial_c_sq, "axes": [0]})
                initial_h = GraphBuilder(g).make_unsqueeze({"data": initial_h_sq, "axes": [0]})
            else:
                initial_ch = op.input[body_context["ch_idx"]]
                if not g.is_const(initial_ch):
                    continue
                initial_ch_const = g.get_tensor_value(initial_ch, as_list=False)
                if not len(initial_ch_const.shape) == 2:
                    continue
                initial_ch_const = np.expand_dims(initial_ch_const, axis=0)
                initial_c_const, initial_h_const = np.split(initial_ch_const, 2, axis=2)
                initial_c = g.make_const(utils.make_name("initial_c"), initial_c_const).output[0]
                initial_h = g.make_const(utils.make_name("initial_h"), initial_h_const).output[0]

            context = LSTMContext()
            context.weights.append({"weight": w_const, "bias": b_const,
                                    "ft_bias": body_context["ft_bias"]})
            context.onnx_input_ids.append({})
            context.input_size.append(None)
            context.hidden_size.append(None)
            context.attributes.append({})

            tensor_array_inp = op.inputs[body_context["x_idx"]]
            if not tensor_array_inp.type == "TensorListFromTensor":
                continue

            final_consumers = g.find_output_consumers(op.output[body_context["out_idx"]])
            output_ys = [n.output[0] for n in final_consumers if n.type == "TensorListStack"]

            context.onnx_input_ids[0]["X"] = tensor_array_inp.input[0]
            if body_context["seq_len_idx"] is None:
                context.onnx_input_ids[0]["sequence_lens"] = ""
            else:
                context.onnx_input_ids[0]["sequence_lens"] = op.input[body_context["seq_len_idx"]]
            context.onnx_input_ids[0]["initial_c"] = initial_c
            context.onnx_input_ids[0]["initial_h"] = initial_h

            lstm_rewriter = LSTMRewriter(g)
            lstm_rewriter.num_lstm_layers = 1
            lstm_rewriter.process_weights_and_bias(context)
            lstm_node = lstm_rewriter.create_rnn_node(context)[0]

            squeeze_output = GraphBuilder(g).make_squeeze({"data": lstm_node.output[0], "axes": [1]})
            for output in output_ys:
                g.replace_all_inputs(output, squeeze_output)

            if body_context["state_is_tuple"]:
                c_squeeze = GraphBuilder(g).make_squeeze({"data": lstm_node.output[2], "axes": [0]})
                h_squeeze = GraphBuilder(g).make_squeeze({"data": lstm_node.output[1], "axes": [0]})
                g.replace_all_inputs(op.output[body_context["c_idx"]], c_squeeze)
                g.replace_all_inputs(op.output[body_context["h_idx"]], h_squeeze)
            else:
                concat_ch = g.make_node("Concat", [lstm_node.output[2], lstm_node.output[1]],
                                        attr={"axis": 2}).output[0]
                ch_squeeze = GraphBuilder(g).make_squeeze({"data": concat_ch, "axes": [0]})
                ch_output = op.output[body_context["ch_idx"]]
                g.replace_all_inputs(ch_output, ch_squeeze)

    return g.get_nodes()
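# Shape-level sketch (toy tensors, not tf2onnx code) of why the rewriters above
# squeeze axis 1 of LSTM output 0 but axis 0 of outputs 1 and 2. ONNX LSTM
# yields Y: [seq_length, num_directions, batch, hidden] and Y_h/Y_c:
# [num_directions, batch, hidden]; with a single forward direction the
# num_directions axis is simply dropped.
import numpy as np

seq, dirs, batch, hidden = 5, 1, 2, 3
Y = np.zeros((seq, dirs, batch, hidden), np.float32)
Y_h = np.zeros((dirs, batch, hidden), np.float32)
assert np.squeeze(Y, axis=1).shape == (seq, batch, hidden)  # feeds TensorListStack consumers
assert np.squeeze(Y_h, axis=0).shape == (batch, hidden)     # replaces the while op's state outputs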