Example #1
0
def _create_zeros_for_none_grads(forward_graphs, grad_graphs):
  """Creates zeros for None out grads if atleast one branch has non-None grad.

  Args:
    forward_graphs: List of forward FuncGraphs.
    grad_graphs: List of grad FuncGraphs.
  """
  assert len(forward_graphs) == len(grad_graphs)
  branch_outputs = [g.structured_outputs for g in grad_graphs]
  num_outputs_per_branch = [len(outs) for outs in branch_outputs]
  assert len(set(num_outputs_per_branch)) == 1, num_outputs_per_branch
  for output_idx, branch_outs in enumerate(zip(*branch_outputs)):
    if (any(t is None for t in branch_outs) and
        any(t is not None for t in branch_outs)):
      for branch_index, t in enumerate(branch_outs):
        if t is None:
          with grad_graphs[branch_index].as_default():
            zeros = default_gradient.zeros_like(
                forward_graphs[branch_index].inputs[output_idx])
            grad_graphs[branch_index].structured_outputs[output_idx] = zeros

  for grad_graph in grad_graphs:
    grad_graph.outputs = [
        t for t in func_graph_module.flatten(grad_graph.structured_outputs)
        if t is not None
    ]
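The padding rule above can be hard to see through the FuncGraph plumbing, so here is a minimal, framework-free sketch of the same idea using plain Python lists in place of graph inputs/outputs. The names (pad_none_grads, forward_inputs, branch_grads) are illustrative only and not part of the TensorFlow source.

def pad_none_grads(forward_inputs, branch_grads):
  """Fills a None gradient with zeros when some other branch has a real one."""
  assert len(set(len(grads) for grads in branch_grads)) == 1
  for idx, grads_at_idx in enumerate(zip(*branch_grads)):
    has_none = any(g is None for g in grads_at_idx)
    has_grad = any(g is not None for g in grads_at_idx)
    if has_none and has_grad:
      for branch, grad in enumerate(grads_at_idx):
        if grad is None:
          # Zeros shaped like the corresponding forward input of that branch.
          branch_grads[branch][idx] = [0.0] * len(forward_inputs[branch][idx])
  return branch_grads

# Branch 0 has no gradient for the second output while branch 1 does, so the
# None is padded; the first output is None in every branch and stays None.
print(pad_none_grads(
    forward_inputs=[[[1.0], [2.0, 3.0]], [[1.0], [2.0, 3.0]]],
    branch_grads=[[None, None], [None, [0.5, 0.5]]]))
# [[None, [0.0, 0.0]], [None, [0.5, 0.5]]]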
Example #2
0
def _make_output_composite_tensors_match(true_graph, false_graph):
    """Modifies true_graph and false_graph so they have the same output signature.

    Currently the only transformation implemented is turning a Tensor into an
    equivalent IndexedSlices if the other branch returns an IndexedSlices.
    Updates {true,false}_graph.{outputs,structured_outputs}.

    Args:
      true_graph: FuncGraph
      false_graph: FuncGraph

    Raises:
      TypeError: if a pair of outputs cannot be rewritten.
    """
    # Note: since this is only used for gradient graphs, we do not expect the
    # outputs to be structured (e.g. nested lists), and thus do not need to use
    # nest.flatten, etc.
    true_outputs = list(true_graph.structured_outputs)
    false_outputs = list(false_graph.structured_outputs)
    assert len(true_outputs) == len(false_outputs)

    for idx, (true_out,
              false_out) in enumerate(zip(true_outputs, false_outputs)):
        if type(true_out) == type(false_out):  # pylint: disable=unidiomatic-typecheck
            continue
        if (isinstance(true_out, ops.IndexedSlices)
                and isinstance(false_out, ops.Tensor)):
            with false_graph.as_default():
                false_outputs[idx] = math_ops._as_indexed_slices(false_out)
        elif (isinstance(true_out, ops.Tensor)
              and isinstance(false_out, ops.IndexedSlices)):
            with true_graph.as_default():
                true_outputs[idx] = math_ops._as_indexed_slices(true_out)
        else:
            raise TypeError("Cannot reconcile tf.cond %i-th outputs:\n"
                            "  true_fn returned:  %s\n"
                            "  false_fn returned: %s" %
                            (idx, true_out, false_out))

    true_graph.structured_outputs = true_outputs
    true_graph.outputs = func_graph_module.flatten(true_outputs)
    false_graph.structured_outputs = false_outputs
    false_graph.outputs = func_graph_module.flatten(false_outputs)
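For context, the mismatch this function resolves typically arises because some gradients (for example the gradient of a gather) are produced as IndexedSlices while the other branch yields a dense Tensor. The sketch below shows roughly what the promotion amounts to using only public TensorFlow 2.x APIs; as_indexed_slices here is an illustrative stand-in for the private math_ops._as_indexed_slices, not the actual helper.

import tensorflow as tf

def as_indexed_slices(dense):
  # Wrap a dense Tensor as an IndexedSlices covering every row, so its
  # type matches an IndexedSlices produced by the other branch.
  dense_shape = tf.shape(dense, out_type=tf.int64)
  return tf.IndexedSlices(values=dense,
                          indices=tf.range(dense_shape[0]),
                          dense_shape=dense_shape)

true_out = tf.IndexedSlices(values=tf.ones([2, 3]),
                            indices=tf.constant([0, 4], dtype=tf.int64),
                            dense_shape=tf.constant([5, 3], dtype=tf.int64))
false_out = tf.zeros([5, 3])               # dense Tensor: types disagree
false_out = as_indexed_slices(false_out)   # now both branches agree
print(type(true_out) == type(false_out))   # True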
Example #3
0
def _make_output_composite_tensors_match(true_graph, false_graph):
  """Rewrites {true,false}_graph's outputs to use the same _TensorLike classes.

  Currently the only transformation implemented is turning a Tensor into an
  equivalent IndexedSlices if the other branch returns an IndexedSlices.
  Updates {true,false}_graph.{outputs,structured_outputs}.

  Args:
    true_graph: FuncGraph
    false_graph: FuncGraph

  Raises:
    TypeError: if a pair of outputs cannot be rewritten.
  """
  # Note: since this is only used for gradient graphs, we do not expect the
  # outputs to be structured (e.g. nested lists), and thus do not need to use
  # nest.flatten, etc.
  true_outputs = list(true_graph.structured_outputs)
  false_outputs = list(false_graph.structured_outputs)
  assert len(true_outputs) == len(false_outputs)

  for idx, (true_out, false_out) in enumerate(zip(true_outputs, false_outputs)):
    if type(true_out) == type(false_out):  # pylint: disable=unidiomatic-typecheck
      continue
    if (isinstance(true_out, ops.IndexedSlices) and
        isinstance(false_out, ops.Tensor)):
      with false_graph.as_default():
        false_outputs[idx] = math_ops._as_indexed_slices(false_out)
    elif (isinstance(true_out, ops.Tensor) and
          isinstance(false_out, ops.IndexedSlices)):
      with true_graph.as_default():
        true_outputs[idx] = math_ops._as_indexed_slices(true_out)
    else:
      raise TypeError(
          "Cannot reconcile tf.cond %i-th outputs:\n"
          "  true_fn returned:  %s\n"
          "  false_fn returned: %s" % (idx, true_out, false_out))

  true_graph.structured_outputs = true_outputs
  true_graph.outputs = func_graph_module.flatten(true_outputs)
  false_graph.structured_outputs = false_outputs
  false_graph.outputs = func_graph_module.flatten(false_outputs)
def _update_indexed_slices_param(graph, loop_vars, init_slices, input_slices,
                                 output_slices, old_output_slices):
    """Updates graph with new IndexedSlices input/output.

    Updates graph's metadata to output the gradient computation defined by
    init_slices, input_slices, and output_slices, instead of outputting
    old_output_slices. Also returns a new version of loop_vars with init_slices
    replacing the old input.

    Args:
      graph: _WhileBodyGradFuncGraph.
      loop_vars: the inputs to graph.
      init_slices: the new IndexedSlices to use as input to graph.
      input_slices: the new IndexedSlices in graph that should be fed by
        init_slices.
      output_slices: the new IndexedSlices in graph that should be the
        corresponding output to input_slices.
      old_output_slices: the IndexedSlices in graph that are currently being
        output.

    Returns:
      New loop_vars to pass to graph.
    """
    structured_idx = _get_tensor_index_in_iterable(graph.structured_outputs,
                                                   old_output_slices)
    # We assume that the component tensors of old_output_slices appear
    # sequentially in graph.outputs. We use the first of these tensors
    # as the reference index.
    flat_idx = _get_tensor_index_in_iterable(
        graph.outputs,
        func_graph.flatten(old_output_slices)[0])

    graph.structured_outputs[structured_idx] = output_slices
    graph.outputs = func_graph.flatten(graph.structured_outputs)

    graph.inputs = (graph.inputs[:flat_idx] + _flatten(input_slices) +
                    graph.inputs[flat_idx + 1:])

    return loop_vars[:flat_idx] + _flatten(init_slices) + loop_vars[flat_idx +
                                                                    1:]
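The index bookkeeping above boils down to one splice: the single flat slot that used to hold the old gradient tensor is replaced by the flattened components of the new IndexedSlices (values, indices, dense_shape). A plain-Python illustration of that splice, with made-up names:

def splice(flat_list, flat_idx, components):
  # Replace the single element at flat_idx with the expanded components.
  return flat_list[:flat_idx] + components + flat_list[flat_idx + 1:]

loop_vars = ["counter", "old_grad_slot", "accumulator"]
print(splice(loop_vars, 1, ["grad_values", "grad_indices", "grad_dense_shape"]))
# ['counter', 'grad_values', 'grad_indices', 'grad_dense_shape', 'accumulator']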
def _rewrite_output_as_tensor(body_grad_graph, grad_output_slices):
  """Rewrites grad_output_slices to be a Tensor output.

  Args:
    body_grad_graph: _WhileBodyGradFuncGraph.
    grad_output_slices: IndexedSlices output of body_grad_graph.
  """
  with body_grad_graph.as_default():
    new_output = ops.convert_to_tensor_v2(grad_output_slices)

  idx = body_grad_graph.structured_outputs.index(grad_output_slices)
  body_grad_graph.structured_outputs[idx] = new_output
  body_grad_graph.outputs = func_graph.flatten(
      body_grad_graph.structured_outputs)
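The rewrite relies on the fact that an IndexedSlices with a known dense_shape can be converted into an equivalent dense Tensor. A small sketch of that conversion with the public tf.convert_to_tensor (the function above uses the internal ops.convert_to_tensor_v2); the values here are arbitrary:

import tensorflow as tf

slices = tf.IndexedSlices(values=tf.constant([[1., 1., 1.]]),
                          indices=tf.constant([2], dtype=tf.int64),
                          dense_shape=tf.constant([4, 3], dtype=tf.int64))
dense = tf.convert_to_tensor(slices)  # densify: row 2 holds the values,
print(dense.shape)                    # every other row is zero -> (4, 3)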
Example #6
0
def _rewrite_output_as_tensor(body_grad_graph, grad_output_slices):
    """Rewrites grad_output_slices to be a Tensor output.

    Args:
      body_grad_graph: _WhileBodyGradFuncGraph.
      grad_output_slices: IndexedSlices output of body_grad_graph.
    """
    with body_grad_graph.as_default():
        new_output = ops.convert_to_tensor_v2(grad_output_slices)

    idx = body_grad_graph.structured_outputs.index(grad_output_slices)
    body_grad_graph.structured_outputs[idx] = new_output
    body_grad_graph.outputs = func_graph.flatten(
        body_grad_graph.structured_outputs)
def _update_indexed_slices_param(graph, loop_vars, init_slices, input_slices,
                                 output_slices, old_output_slices):
  """Updates graph with new IndexedSlices input/output.

  Updates graph's metadata to output the gradient computation defined by
  init_slices, input_slices, and output_slices, instead of outputting
  old_output_slices. Also returns a new version of loop_vars with init_slices
  replacing the old input.

  Args:
    graph: _WhileBodyGradFuncGraph.
    loop_vars: the inputs to graph.
    init_slices: the new IndexedSlices to use as input to graph.
    input_slices: the new IndexedSlices in graph that should be fed by
      init_slices.
    output_slices: the new IndexedSlices in graph that should be the
      corresponding output to input_slices.
    old_output_slices: the IndexedSlices in graph that are currently
      being output.

  Returns:
    New loop_vars to pass to graph.
  """
  structured_idx = graph.structured_outputs.index(old_output_slices)
  # We assume that the component tensors of old_output_slices appear
  # sequentially in graph.outputs. We use the first of these tensors
  # as the reference index.
  flat_idx = graph.outputs.index(func_graph.flatten(old_output_slices)[0])

  graph.structured_outputs[structured_idx] = output_slices
  graph.outputs = func_graph.flatten(
      graph.structured_outputs)

  graph.inputs = (graph.inputs[:flat_idx] + _flatten(input_slices) +
                  graph.inputs[flat_idx + 1:])

  return loop_vars[:flat_idx] + _flatten(init_slices) + loop_vars[flat_idx + 1:]
Example #8
0
def _make_output_composite_tensors_match(op_type, branch_graphs):
    """Modifies each branch_graph's outputs to have the same output signature.

    Currently the only transformation implemented is turning a Tensor into an
    equivalent IndexedSlices if the other branch returns an IndexedSlices.
    Updates branch_graph.{outputs,structured_outputs} for each branch_graph in
    branch_graphs.

    Args:
      op_type: _COND or _CASE
      branch_graphs: `list` of `FuncGraph`

    Raises:
      TypeError: if a set of outputs cannot be rewritten.
    """
    # Note: since this is only used for gradient graphs, we do not expect the
    # outputs to be structured (e.g. nested lists), and thus do not need to use
    # nest.flatten, etc.
    assert branch_graphs
    branch_outputs = [g.structured_outputs for g in branch_graphs]
    outputs_per_branch = list(len(outs) for outs in branch_outputs)
    assert len(set(outputs_per_branch)) == 1, outputs_per_branch

    for output_idx, branch_outs in enumerate(zip(*branch_outputs)):
        if len(set(type(out) for out in branch_outs)) == 1:
            continue
        if not any(isinstance(out, ops.IndexedSlices) for out in branch_outs):
            continue
        for branch_idx, branch_out in enumerate(branch_outs):
            if isinstance(branch_out, ops.IndexedSlices):
                continue
            elif isinstance(branch_out, ops.Tensor):
                with branch_graphs[branch_idx].as_default():
                    branch_outputs[branch_idx][
                        output_idx] = math_ops._as_indexed_slices(branch_out)
            else:
                raise TypeError(
                    "Cannot reconcile {op_name} {output_idx}-th outputs:\n"
                    "  outputs from all branches: {outputs}".format(
                        op_name="tf.cond"
                        if op_type == _COND else "tf.switch_case",
                        output_idx=output_idx,
                        outputs=branch_outs))

    for branch_graph, branch_outs in zip(branch_graphs, branch_outputs):
        branch_graph.structured_outputs = branch_outs
        branch_graph.outputs = [
            t for t in func_graph_module.flatten(branch_outs) if t is not None
        ]
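Summarizing the per-position rule applied above: if all branches already agree on the type, nothing happens; if they disagree and at least one branch produced an IndexedSlices, every plain Tensor is promoted; any other mismatch raises TypeError. A framework-free sketch with stand-in classes (the real code operates on ops.Tensor and ops.IndexedSlices inside FuncGraphs):

class Tensor: ...          # stand-in for ops.Tensor
class IndexedSlices: ...   # stand-in for ops.IndexedSlices

def reconcile(branch_outs):
  if len(set(type(out) for out in branch_outs)) == 1:
    return list(branch_outs)
  if not any(isinstance(out, IndexedSlices) for out in branch_outs):
    return list(branch_outs)
  promoted = []
  for out in branch_outs:
    if isinstance(out, IndexedSlices):
      promoted.append(out)
    elif isinstance(out, Tensor):
      promoted.append(IndexedSlices())  # stands in for _as_indexed_slices(out)
    else:
      raise TypeError("Cannot reconcile outputs: %r" % (branch_outs,))
  return promoted

print([type(out).__name__
       for out in reconcile([Tensor(), IndexedSlices(), Tensor()])])
# ['IndexedSlices', 'IndexedSlices', 'IndexedSlices']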