def _get_output_shape_dtype(self, cond_context):
    """Collect output shapes/dtypes for a cond construct from its two branches.

    Walks the paired outputs of the true and false branch contexts, requires
    each pair to agree on rank and dtype (shapes themselves may differ per
    branch), and returns "vague" shapes (rank preserved, dims relaxed) so the
    downstream ONNX If/Loop outputs are not over-constrained.

    Args:
        cond_context: context object holding true_branch_context and
            false_branch_context, each with a parallel ``output`` list.

    Returns:
        Tuple (output_shapes, output_dtypes), one entry per branch output pair.

    Raises:
        RuntimeError: if a pair of branch outputs differs in rank or dtype.
    """
    output_shapes = []
    output_dtypes = []
    # branch outputs are parallel lists: pair them up directly instead of
    # indexing both lists by position.
    paired_outputs = zip(cond_context.true_branch_context.output,
                         cond_context.false_branch_context.output)
    for true_output, false_output in paired_outputs:
        true_shape = self.g.get_shape(true_output)
        utils.make_sure(true_shape is not None, "Shape of {} is None".format(true_output))
        true_rank = len(true_shape)
        true_dtype = self.g.get_dtype(true_output)

        false_shape = self.g.get_shape(false_output)
        utils.make_sure(false_shape is not None, "Shape of {} is None".format(false_output))
        false_rank = len(false_shape)
        false_dtype = self.g.get_dtype(false_output)

        # just require rank is equal; concrete dims may legitimately differ
        # between branches.
        if true_rank != false_rank:
            raise RuntimeError(
                "the rank of outputs {} and {} mismatch: {}, {}".format(
                    true_output, false_output, true_rank, false_rank))
        if true_dtype != false_dtype:
            raise RuntimeError(
                "the dtype of outputs {} and {} mismatch: {}, {}".format(
                    true_output, false_output, true_dtype, false_dtype))

        output_shapes.append(utils.create_vague_shape_like(true_shape))
        output_dtypes.append(true_dtype)
    return output_shapes, output_dtypes
def _get_output_shape_dtype(self, cond_context):
    """Collect output shapes/dtypes for a cond construct from its two branches.

    Pairs up the true/false branch outputs, requires shape compatibility and
    exact dtype equality for each pair, and merges the two branch shapes into
    a single (vague) output shape.

    Args:
        cond_context: context object holding true_branch_context and
            false_branch_context, each with a parallel ``output`` list.

    Returns:
        Tuple (output_shapes, output_dtypes), one entry per branch output pair.

    Raises:
        RuntimeError: if a pair of branch outputs has incompatible shapes or
            mismatching dtypes.
    """
    output_shapes = []
    output_dtypes = []
    # branch outputs are parallel lists: pair them up directly instead of
    # indexing both lists by position.
    paired_outputs = zip(cond_context.true_branch_context.output,
                         cond_context.false_branch_context.output)
    for true_output, false_output in paired_outputs:
        true_shape = self.g.get_shape(true_output)
        true_dtype = self.g.get_dtype(true_output)
        false_shape = self.g.get_shape(false_output)
        false_dtype = self.g.get_dtype(false_output)

        if not utils.are_shapes_compatible(true_shape, false_shape):
            raise RuntimeError(
                "the shape of outputs {} and {} mismatch: {}, {}".format(
                    true_output, false_output, true_shape, false_shape
                )
            )
        if true_dtype != false_dtype:
            raise RuntimeError(
                "the dtype of outputs {} and {} mismatch: {}, {}".format(
                    true_output, false_output, true_dtype, false_dtype
                )
            )
        # in tf, the shapes of different branches can be different,
        # for example output shape of branch A can be [-1] while branch B can be [1].
        # Under this case, we should set output shape to be [-1]
        output_shapes.append(utils.create_vague_shape_like(utils.merge_shapes(true_shape, false_shape)))
        output_dtypes.append(true_dtype)
    return output_shapes, output_dtypes
def create_body_graph_for_if_branch(parent_g, data_type, output_shape, chosen_cur_cond_val_out_name, op_name):
    """Build a minimal branch subgraph for an ONNX If node.

    The branch graph contains a single Identity node that forwards
    ``chosen_cur_cond_val_out_name`` to the graph output ``y``, typed with
    ``data_type`` and a vague (dims-relaxed) version of ``output_shape``.

    Note: ``op_name`` is accepted but not used by this implementation.

    Returns:
        The newly created branch graph.
    """
    branch_graph = parent_g.create_new_graph_with_same_config()
    identity_name = utils.make_name("Identity")
    branch_graph.make_node(
        'Identity',
        inputs=[chosen_cur_cond_val_out_name],
        outputs=['y'],
        name=identity_name
    )
    relaxed_shape = utils.create_vague_shape_like(output_shape)
    branch_graph.add_graph_output("y", data_type, relaxed_shape)
    return branch_graph
def create_loop_body_graph(parent_g, gather_input_ids, output_data_type, output_shape,
                           trip_count_input_ids, rank, loop_name):
    """Build the body subgraph for one level of a nested ONNX Loop.

    Each recursion level gathers the i'th slice of the cond/true/false inputs
    and either recurses via another Loop (rank still >= 1 after decrement) or
    bottoms out with an If op (rank reaches 0).

    Note: ``loop_name`` is accepted but not referenced in this body.

    Args:
        parent_g: parent graph; the new graph copies its config.
        gather_input_ids: [cond_id, true_values_id, false_values_id].
        output_data_type: dtype of the loop output.
        output_shape: full output shape; output_shape[1:] is passed down a level.
        trip_count_input_ids: ids supplying trip counts for nested loops.
        rank: remaining rank to peel off; decremented once per level.

    Returns:
        The constructed loop-body graph.
    """
    g = parent_g.create_new_graph_with_same_config()
    g.parent_graph = parent_g
    iter_name = utils.make_name("i")
    cond_name = utils.make_name("cond")
    fake_var_name = utils.make_name("fake_var")

    # mandatory ONNX Loop body inputs: iteration number, condition, and one
    # loop-carried dependency (a float placeholder, unused, kept only to
    # satisfy the Loop signature).
    g.add_graph_input(iter_name, TensorProto.INT64, (1,))  # iteration_num
    g.add_graph_input(cond_name, TensorProto.BOOL, ())  # condition
    g.add_graph_input(fake_var_name, TensorProto.FLOAT, ())  # loop-carried dependency

    # get the i'th value of condition
    cond_input_id = gather_input_ids[0]
    cond_input_id_for_current_iter = get_inputs_for_current_iteration(g, cond_input_id, iter_name)

    # get the i'th value of true values
    true_input_id = gather_input_ids[1]
    true_input_id_for_current_iter = get_inputs_for_current_iteration(g, true_input_id, iter_name)

    # get the i'th value of false values
    false_input_id = gather_input_ids[2]
    false_input_id_for_current_iter = get_inputs_for_current_iteration(g, false_input_id, iter_name)

    input_ids_for_current_iter = [cond_input_id_for_current_iter, true_input_id_for_current_iter,
                                  false_input_id_for_current_iter]
    output_id = None
    rank -= 1
    if rank >= 1:
        # still more than one dim left: recurse with a nested Loop over the
        # next dimension (note output_shape[1:] drops the dim handled here).
        loop_1 = create_loop_op(g, input_ids_for_current_iter, output_data_type, output_shape[1:],
                                trip_count_input_ids, rank)
        output_id = loop_1.output[1]
    elif rank == 0:
        # scalar level reached: select between true/false with an If op.
        _, if_node_output_id = create_if_op(g, input_ids_for_current_iter, output_data_type,
                                            output_shape[1:])
        output_id = if_node_output_id

    # wire the three required body outputs (scan output, condition pass-through,
    # loop-carried pass-through) via Identity nodes.
    output_identity_name = utils.make_name("loop_output")
    loop_output_id = utils.port_name(output_identity_name)
    g.make_node(
        'Identity',
        [output_id],
        outputs=[loop_output_id],
        name=output_identity_name
    )

    cond_identity_name = utils.make_name("cond_output")
    cond_output_id = utils.port_name(cond_identity_name)
    g.make_node(
        'Identity',
        [cond_name],
        outputs=[cond_output_id],
        name=cond_identity_name
    )

    fake_var_identity_name = utils.make_name("fake_var_output")
    fake_var_output_id = utils.port_name(fake_var_identity_name)
    g.make_node(
        'Identity',
        [fake_var_name],
        outputs=[fake_var_output_id],
        name=fake_var_identity_name
    )

    g.add_graph_output(cond_output_id, TensorProto.BOOL, ())
    g.add_graph_output(fake_var_output_id, TensorProto.FLOAT, ())

    # use None for all dims, just keep original rank. Because it is observed, dims might be changed in loop.
    g.add_graph_output(loop_output_id, output_data_type, utils.create_vague_shape_like(output_shape[1:]))

    return g
def rewrite(self, context):
    """Rewrite a detected TF while-loop into an ONNX Loop node.

    Builds the Loop body graph from the cell and condition subgraphs captured
    in ``context``, registers its inputs (iteration counter, condition, state
    variables, tensor-array reads), and creates the Loop node with the body
    attached via the ``branches`` mapping.

    Returns:
        REWRITER_RESULT.OK on success, REWRITER_RESULT.FAIL otherwise; any
        exception is caught, logged, and reported as FAIL.
    """
    logger.debug("enter rewrite function")
    loop_node = None
    try:
        loop_props = context.loop_properties
        cell_g_info = context.cell_graph
        cond_g_info = context.cond_graph

        # create a dummy loop to calculate the init condition
        init_cond_output = self._create_subgraph_initial_cond(cond_g_info)

        ## create Loop body graph with existing nodes
        body_nodes = set(cell_g_info.nodes + cond_g_info.nodes)
        body_outputs = cond_g_info.outputs + cell_g_info.outputs
        for out_tensor_value_info in body_outputs:
            shape = out_tensor_value_info.shape
            utils.make_sure(
                shape is not None,
                "Conversion of Loop requries output shape [{}] exists".
                format(out_tensor_value_info.id))
            # relax concrete dims (keep rank only): dims may change across
            # loop iterations.
            out_tensor_value_info.shape = utils.create_vague_shape_like(
                shape)

        loop_body_g = LoopRewriterBase.construct_graph_from_nodes(
            self.g, body_nodes, body_outputs)

        # create loop body graph inputs; order matters: iteration number,
        # condition, then the loop-carried state variables.
        loop_body_g.add_graph_input(utils.make_name("i"),
                                    TensorProto.INT64, ())
        loop_body_g.add_graph_input(utils.make_name("cond"),
                                    TensorProto.BOOL, ())
        for i, tensor_value_info in enumerate(loop_props.state_inputs):
            input_name = tensor_value_info.id
            if input_name is None:
                # if the variable is not used in the body graph, then we created a fake one,
                # the same type and shape as its corresponding output.
                out_tensor_value_info = loop_props.state_outputs[i]
                dtype = out_tensor_value_info.dtype
                shape = out_tensor_value_info.shape
                input_name = utils.make_name("unused_state_input_")
            else:
                dtype = tensor_value_info.dtype
                shape = tensor_value_info.shape

            loop_body_g.add_graph_input(
                input_name, dtype, utils.create_vague_shape_like(shape))

        for input_ta in loop_props.tensor_array_inputs:
            # Loop does not have scan inputs, so we use Gather to get data for each iteration.
            # Unsqueeze+Gather+Squeeze emulates reading element i of the
            # tensor array inside the body graph.
            index_node = loop_body_g.make_node("Unsqueeze",
                                               [input_ta.index_input_id],
                                               attr={"axes": [0]})
            gather_node = loop_body_g.make_node(
                "Gather", [input_ta.data_input_id, index_node.output[0]])
            data_node = loop_body_g.make_node("Squeeze",
                                              [gather_node.output[0]],
                                              attr={"axes": [0]})
            loop_body_g.replace_all_inputs(
                input_ta.consumer.id,
                data_node.output[0])  # ops=loop_body_g.get_nodes()

        ## create Loop node
        branches = {"body": loop_body_g}
        loop_node = self._create_loop_node(context,
                                           loop_props,
                                           init_cond_output,
                                           branches=branches)
        if not loop_node:
            logger.error("failed to create loop node during rewrite")
            return REWRITER_RESULT.FAIL

        logger.debug("rewrite successfully")
        return REWRITER_RESULT.OK
    except Exception as ex:
        # top-level rewriter boundary: log and report failure instead of
        # propagating, so other rewriters can still run.
        tb = traceback.format_exc()
        logger.error(
            "loop rewrite failed, due to exception: %s, details:%s", ex,
            tb)
        return REWRITER_RESULT.FAIL
def rewrite(self, context):
    """Rewrite a detected TF while-loop into an ONNX Loop node.

    Variant that rewires the condition graph to consume the cell graph's
    next-iteration outputs before constructing the Loop body graph, then
    attaches the body graph to the Loop node as the "body" attribute.

    Returns:
        REWRITER_RESULT.OK on success, REWRITER_RESULT.FAIL otherwise; any
        exception is caught, logged, and reported as FAIL.
    """
    logger.debug("enter rewrite function")
    loop_node = None
    try:
        loop_props = context.loop_properties
        cell_g_info = context.cell_graph
        cond_g_info = context.cond_graph
        # todo(pengwa): we don't check the case where loop body won't be executed at all.

        ## create Loop body graph with existing nodes
        # replace condition graph's inputs to be cell graph's outputs, because we want condition graph
        # to consumer cell graph outputs.
        for loop_var in cond_g_info.dependent_vars:
            self.g.replace_all_inputs(
                cond_g_info.nodes,
                loop_var.switch_true_identity_output.id,
                loop_var.next_iteration_input.id)

        body_nodes = set(cell_g_info.nodes + cond_g_info.nodes)
        body_outputs = cond_g_info.outputs + cell_g_info.outputs
        for out_tensor_value_info in body_outputs:
            # relax concrete dims (keep rank only): dims may change across
            # loop iterations.
            out_tensor_value_info.shape = utils.create_vague_shape_like(
                out_tensor_value_info.shape)

        loop_body_g = LoopRewriterBase.construct_graph_from_nodes(
            self.g, body_nodes, body_outputs)

        # create loop body graph inputs; order matters: iteration number,
        # condition, then the loop-carried state variables.
        loop_body_g.add_graph_input(utils.make_name("i"),
                                    TensorProto.INT64, ())
        loop_body_g.add_graph_input(utils.make_name("cond"),
                                    TensorProto.BOOL, ())
        for i, tensor_value_info in enumerate(loop_props.state_inputs):
            input_name = tensor_value_info.id
            if input_name is None:
                # if the variable is not used in the body graph, then we created a fake one,
                # the same type and shape as its corresponding output.
                out_tensor_value_info = loop_props.state_outputs[i]
                dtype = out_tensor_value_info.dtype
                shape = out_tensor_value_info.shape
                input_name = utils.make_name("unused_state_input_")
            else:
                dtype = tensor_value_info.dtype
                shape = tensor_value_info.shape

            loop_body_g.add_graph_input(
                input_name, dtype, utils.create_vague_shape_like(shape))

        for input_ta in loop_props.tensor_array_inputs:
            # Loop does not have scan inputs, so we use Gather to get data for each iteration.
            # Unsqueeze+Gather+Squeeze emulates reading element i of the
            # tensor array inside the body graph.
            index_node = loop_body_g.make_node("Unsqueeze",
                                               [input_ta.index_input_id],
                                               attr={"axes": [0]})
            gather_node = loop_body_g.make_node(
                "Gather", [input_ta.data_input_id, index_node.output[0]])
            data_node = loop_body_g.make_node("Squeeze",
                                              [gather_node.output[0]],
                                              attr={"axes": [0]})
            loop_body_g.replace_all_inputs(loop_body_g.get_nodes(),
                                           input_ta.consumer.id,
                                           data_node.output[0])

        ## create Loop node
        loop_node = self._create_loop_node(context, loop_props)
        if not loop_node:
            logger.error("failed to create loop node during rewrite")
            return REWRITER_RESULT.FAIL
        loop_node.set_body_graph_as_attr("body", loop_body_g)

        logger.debug("rewrite successfully")
        return REWRITER_RESULT.OK
    except Exception as ex:
        # top-level rewriter boundary: log and report failure instead of
        # propagating, so other rewriters can still run.
        tb = traceback.format_exc()
        logger.error(
            "loop rewrite failed, due to exception: %s, details:%s", ex,
            tb)
        return REWRITER_RESULT.FAIL