Esempio n. 1
0
 def _get_func_graph_for_branch(name_attr_list):
     """Builds the FuncGraph for one branch of the enclosing conditional `op`.

     `op` is read from the enclosing scope. Its input 0 is the predicate; the
     remaining inputs are the tensors handed to the branch function.

     Args:
       name_attr_list: Reference to the branch function; only its `.name` is
         read. (Presumably a NameAttrList attr value -- confirm with callers.)

     Returns:
       The branch FuncGraph. The non-predicate inputs of `op` have been paired
       with the graph's inputs: handle data is copied across and the graph's
       captures are reset to those pairs. `_forward_cond` is set to `op`.
     """
     branch_args = op.inputs[1:]  # op.inputs[0] is the predicate.
     branch_graph = util.get_func_graph(
         op, [arg.shape for arg in branch_args], name_attr_list.name)
     # Copy each outer tensor's handle data onto the matching graph input.
     for outer_t, placeholder in zip(branch_args, branch_graph.inputs):
         custom_gradient.copy_handle_data(outer_t, placeholder)
     branch_graph.reset_captures(zip(branch_args, branch_graph.inputs))
     # Back-link the forward cond op so gradient code can reach it.
     branch_graph._forward_cond = op
     return branch_graph
Esempio n. 2
0
def _get_graph(while_op, func_attr_name):
  """Builds the `FuncGraph` referenced by a function attribute of a While op.

  Args:
    while_op: The While `Operation` whose attribute is inspected.
    func_attr_name: string; name of the attribute on `while_op` that holds
      the function reference.

  Returns:
    The `FuncGraph` for that function, whose `_while` attribute is set to
    `while_op`.
  """
  # NOTE: the op's "output_shapes" attr is reused as the function's input
  # shapes (presumably because loop variables keep their shapes -- confirm).
  # TODO(srbs): Handle TensorShapeProto in function_def_to_graph.input_shapes.
  shapes = list(map(tensor_shape.TensorShape,
                    while_op.get_attr("output_shapes")))
  graph = util.get_func_graph(
      while_op, shapes, while_op.get_attr(func_attr_name).name)
  graph._while = while_op  # Back-reference to the owning While op.
  return graph