def _get_func_graph_for_branch(name_attr_list):
  """Builds and returns the FuncGraph for one branch of `op`.

  NOTE(review): `op` is not a parameter. It is resolved from an enclosing
  scope, presumably the cond op whose branches are being rebuilt (confirm
  against the enclosing function).

  Args:
    name_attr_list: object whose `.name` is passed to `util.get_func_graph`
      as the branch function name (presumably a `NameAttrList` proto).

  Returns:
    The branch `FuncGraph`, with handle data copied from `op`'s non-predicate
    inputs onto the graph's inputs, captures reset to those input pairs, and
    `_forward_cond` pointing back at `op`.
  """
  # op.inputs[0] is the predicate; every later input feeds the branch.
  branch_inputs = op.inputs[1:]
  shapes = [tensor.shape for tensor in branch_inputs]
  graph = util.get_func_graph(op, shapes, name_attr_list.name)
  # Pair each outer tensor with the graph input it corresponds to.
  outer_to_inner = list(zip(branch_inputs, graph.inputs))
  for outer, inner in outer_to_inner:
    custom_gradient.copy_handle_data(outer, inner)
  graph.reset_captures(outer_to_inner)
  # Link the op so that the gradient code can use it.
  graph._forward_cond = op  # pylint: disable=protected-access
  return graph
def _get_graph(while_op, func_attr_name):
  """Rebuilds the `FuncGraph` named by a function-valued attr of a While op.

  The graph's input shapes come from `while_op`'s "output_shapes" attr. The
  resulting graph is linked back to `while_op` through its `_while` attribute.

  Args:
    while_op: The While Operation.
    func_attr_name: string name of the attr on `while_op` whose `.name`
      identifies the function to load.

  Returns:
    `FuncGraph`
  """
  # TODO(srbs): Handle TensorShapeProto in function_def_to_graph.input_shapes.
  shapes = list(
      map(tensor_shape.TensorShape, while_op.get_attr("output_shapes")))
  graph = util.get_func_graph(
      while_op, shapes, while_op.get_attr(func_attr_name).name)
  graph._while = while_op  # pylint: disable=protected-access
  return graph