def _process_non_tuple_ch_init_nodes(self, context, i):
    """Split a packed initial-state tensor into two halves, each with a new leading axis.

    The tensor ``context.state_variables["ct_ht" + str(i)].enter_input_id`` is
    sliced along axis 1 into columns ``[0, hidden_size)`` and
    ``[hidden_size, 2 * hidden_size)``, where ``hidden_size`` is
    ``context.hidden_size[i]``. Each slice is then unsqueezed on axis 0.
    NOTE(review): the halves are presumably the c and h states (from the
    "ct_ht" key); confirm against the caller.

    Args:
        context: rewriter context providing ``state_variables`` and ``hidden_size``.
        i: index selecting the state variable and its hidden size.

    Returns:
        A tuple ``(first_half_id, second_half_id)`` of tensor names, the outputs
        of the two Unsqueeze nodes.
    """
    # A single builder is enough; the original created a fresh GraphBuilder per slice.
    gb = GraphBuilder(self.g)
    input_id = context.state_variables["ct_ht" + str(i)].enter_input_id
    hidden_size = context.hidden_size[i]

    def _slice_and_unsqueeze(start, end):
        # Take columns [start, end) along axis 1, then add a new axis 0.
        slice_node = gb.make_slice({"data": input_id, "axes": [1], "starts": [start], "ends": [end]})
        unsqueeze_node = gb.make_unsqueeze({'data': slice_node, "axes": [0]}, return_node=True)
        return unsqueeze_node.output[0]

    # Tuple elements are evaluated left to right, so nodes are created in the
    # same order as before: slice1, unsqueeze1, slice2, unsqueeze2.
    return (
        _slice_and_unsqueeze(0, hidden_size),
        _slice_and_unsqueeze(hidden_size, hidden_size * 2),
    )
def version_13(cls, ctx, node, **kwargs):
    """Rewrite a tensor-list item read as ``Squeeze(Gather(list, Unsqueeze(index)))``.

    ``node.input[0]`` is appended to ``ctx.ta_reads``. ``node`` is retyped to
    ``Gather`` and keeps only its first two inputs, with the index
    (``node.input[1]``) replaced by an axis-0 Unsqueeze of itself. An axis-0
    Squeeze is then inserted after the Gather output, and consumers of that
    output are rerouted to it.
    """
    ctx.ta_reads.append(node.input[0])
    node.type = "Gather"

    g = GraphBuilder(ctx)

    usq_node = g.make_unsqueeze({"axes": [0], 'name': node.child_name(), 'data': node.input[1]}, return_node=True)
    # Wire the unsqueezed index into this Gather only, and drop any inputs beyond
    # the first two. Rerouting every consumer of the index tensor (as
    # insert_node_on_output does) would also hand the rank-expanded index to
    # unrelated ops that share it.
    ctx.replace_inputs(node, [node.input[0], usq_node.output[0]])

    # Every consumer of the Gather result should see the squeezed value.
    sq_node = g.make_squeeze({"axes": [0], 'name': node.child_name(), 'data': node.output[0]}, return_node=True)
    ctx.insert_node_on_output(sq_node)
def _process_c_or_h_init_nodes(self, initializer_input_id, context):
    """Return the name of a tensor equal to ``initializer_input_id`` with a new axis 0.

    If the producer of ``initializer_input_id`` is a constant, a new Const is
    built from ``np.expand_dims(value, axis=0)`` and its output is returned.
    Existing consumers of ``initializer_input_id`` are not touched in that case.

    Otherwise an Unsqueeze(axes=[0]) node is added. Every other node in
    ``self.g`` that reads ``initializer_input_id`` is rewired to the Unsqueeze
    output, and that output is returned.

    ``context`` is part of the signature but is not used here.
    """
    producer = self.g.get_node_by_output(initializer_input_id)

    if not producer.is_const():
        builder = GraphBuilder(self.g)
        unsqueeze_node = builder.make_unsqueeze(
            {'data': initializer_input_id, "axes": [0]},
            return_node=True,
        )
        unsqueezed_id = unsqueeze_node.output[0]
        # Exclude the new node itself so it keeps reading the original tensor.
        consumers = [n for n in self.g.get_nodes() if n != unsqueeze_node]
        self.g.replace_all_inputs(initializer_input_id, unsqueezed_id, ops=consumers)
        return unsqueezed_id

    # Constant input: fold the extra axis into a fresh constant instead.
    expanded_value = np.expand_dims(producer.get_tensor_value(as_list=False), axis=0)
    folded_const = self.g.make_const(utils.make_name("Const"), expanded_value)
    return folded_const.output[0]
def rewrite(self, context):
    """Replace the matched loop with a Loop node whose body is built from existing nodes.

    Steps:
      1. Compute the initial condition via ``self._create_subgraph_initial_cond``.
      2. Build the body graph from the cell and condition nodes. Condition
         outputs come first, then cell outputs, and each output's shape is
         required to be known.
      3. Add body inputs: an INT64 iteration counter, a BOOL condition, then one
         input per entry of ``loop_props.state_inputs``. A state input with no
         id gets a placeholder name and takes its dtype/shape from the
         corresponding state output.
      4. For each tensor-array input, read one element per iteration with
         Unsqueeze(index) -> Gather -> Squeeze, and reroute the consumer's uses.
      5. Create the Loop node via ``self._create_loop_node``.

    Args:
        context: loop rewriter context exposing ``loop_properties``,
            ``cell_graph`` and ``cond_graph``.

    Returns:
        ``REWRITER_RESULT.OK`` on success. ``REWRITER_RESULT.FAIL`` if the Loop
        node could not be created or any exception was raised; exceptions are
        logged with their traceback and not propagated.
    """
    logger.debug("enter rewrite function")
    try:
        loop_props = context.loop_properties
        cell_g_info = context.cell_graph
        cond_g_info = context.cond_graph

        # create a dummy loop to calculate the init condition
        init_cond_output = self._create_subgraph_initial_cond(cond_g_info)

        ## create Loop body graph with existing nodes
        body_nodes = set(cell_g_info.nodes + cond_g_info.nodes)
        body_outputs = cond_g_info.outputs + cell_g_info.outputs
        for out_tensor_value_info in body_outputs:
            shape = out_tensor_value_info.shape
            utils.make_sure(
                shape is not None,
                "Conversion of Loop requires output shape [{}] exists".format(out_tensor_value_info.id)
            )
            # NOTE(review): create_vague_shape_like presumably loosens concrete
            # dimensions; confirm in utils. This mutates the shared value info.
            out_tensor_value_info.shape = utils.create_vague_shape_like(shape)

        loop_body_g = LoopRewriterBase.construct_graph_from_nodes(self.g, body_nodes, body_outputs)

        # create loop body graph inputs: iteration counter, condition, then state variables
        loop_body_g.add_graph_input(utils.make_name("i"), TensorProto.INT64, ())
        loop_body_g.add_graph_input(utils.make_name("cond"), TensorProto.BOOL, ())
        for i, tensor_value_info in enumerate(loop_props.state_inputs):
            input_name = tensor_value_info.id
            if input_name is None:
                # if the variable is not used in the body graph, then we created a fake one,
                # the same type and shape as its corresponding output.
                out_tensor_value_info = loop_props.state_outputs[i]
                dtype = out_tensor_value_info.dtype
                shape = out_tensor_value_info.shape
                input_name = utils.make_name("unused_state_input_")
            else:
                dtype = tensor_value_info.dtype
                shape = tensor_value_info.shape

            loop_body_g.add_graph_input(input_name, dtype, utils.create_vague_shape_like(shape))

        # Loop does not have scan inputs, so we use Gather to get data for each iteration.
        # The builder is loop-invariant, so one instance serves every tensor-array input.
        gb = GraphBuilder(loop_body_g)
        for input_ta in loop_props.tensor_array_inputs:
            index_node = gb.make_unsqueeze({'data': input_ta.index_input_id, "axes": [0]}, return_node=True)
            gather_node = loop_body_g.make_node("Gather", [input_ta.data_input_id, index_node.output[0]])
            data_node = gb.make_squeeze({'data': gather_node.output[0], "axes": [0]}, return_node=True)
            loop_body_g.replace_all_inputs(input_ta.consumer.id, data_node.output[0])

        ## create Loop node
        branches = {"body": loop_body_g}
        loop_node = self._create_loop_node(context, loop_props, init_cond_output, branches=branches)
        if not loop_node:
            logger.error("failed to create loop node during rewrite")
            return REWRITER_RESULT.FAIL

        logger.debug("rewrite successfully")
        return REWRITER_RESULT.OK

    except Exception as ex:
        # Rewriter boundary: report the failure as FAIL instead of aborting the conversion.
        tb = traceback.format_exc()
        logger.error("loop rewrite failed, due to exception: %s, details:%s", ex, tb)
        return REWRITER_RESULT.FAIL