def _create_zeros_for_none_grads(forward_graphs, grad_graphs):
  """Creates zeros for None out grads if at least one branch has non-None grad.

  Args:
    forward_graphs: List of forward FuncGraphs.
    grad_graphs: List of grad FuncGraphs.
  """
  assert len(forward_graphs) == len(grad_graphs)
  branch_outputs = [g.structured_outputs for g in grad_graphs]
  num_outputs_per_branch = [len(outs) for outs in branch_outputs]
  assert len(set(num_outputs_per_branch)) == 1, num_outputs_per_branch
  for output_idx, branch_outs in enumerate(zip(*branch_outputs)):
    if (any(t is None for t in branch_outs) and
        any(t is not None for t in branch_outs)):
      # At least one branch has a None gradient for this output while another
      # does not, so fill the None slots with zeros of the matching shape.
      for branch_index, t in enumerate(branch_outs):
        if t is None:
          with grad_graphs[branch_index].as_default():
            zeros = default_gradient.zeros_like(
                forward_graphs[branch_index].inputs[output_idx])
            grad_graphs[branch_index].structured_outputs[output_idx] = zeros

  for grad_graph in grad_graphs:
    grad_graph.outputs = [
        t for t in func_graph_module.flatten(grad_graph.structured_outputs)
        if t is not None
    ]
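# A minimal usage sketch (not part of the library) of the situation the
# rewrite above handles: `y` only influences the true branch, so the raw
# false-branch gradient for `y` is None and gets zero-filled so that both
# branch gradient functions share one output signature. The `_example_*`
# name is an illustrative assumption, not TensorFlow API.
import tensorflow as tf

@tf.function
def _example_none_grad_across_branches(x, y):
  with tf.GradientTape() as tape:
    tape.watch([x, y])
    out = tf.cond(x > 0, lambda: x * y, lambda: x * x)
  # Both gradients are defined: when x <= 0, the grad for y is zeros rather
  # than None, because the false-branch None grad was zero-filled.
  return tape.gradient(out, [x, y])

# e.g. _example_none_grad_across_branches(tf.constant(-2.0), tf.constant(3.0))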
def _make_output_composite_tensors_match(true_graph, false_graph):
  """Rewrites {true,false}_graph's outputs to use the same _TensorLike classes.

  Currently the only transformation implemented is turning a Tensor into an
  equivalent IndexedSlices if the other branch returns an IndexedSlices.

  Updates {true,false}_graph.{outputs,structured_outputs}.

  Args:
    true_graph: FuncGraph
    false_graph: FuncGraph

  Raises:
    TypeError: if a pair of outputs cannot be rewritten.
  """
  # Note: since this is only used for gradient graphs, we do not expect the
  # outputs to be structured (e.g. nested lists), and thus do not need to use
  # nest.flatten, etc.
  true_outputs = list(true_graph.structured_outputs)
  false_outputs = list(false_graph.structured_outputs)
  assert len(true_outputs) == len(false_outputs)

  for idx, (true_out, false_out) in enumerate(zip(true_outputs,
                                                  false_outputs)):
    if type(true_out) == type(false_out):  # pylint: disable=unidiomatic-typecheck
      continue
    if (isinstance(true_out, ops.IndexedSlices) and
        isinstance(false_out, ops.Tensor)):
      with false_graph.as_default():
        false_outputs[idx] = math_ops._as_indexed_slices(false_out)
    elif (isinstance(true_out, ops.Tensor) and
          isinstance(false_out, ops.IndexedSlices)):
      with true_graph.as_default():
        true_outputs[idx] = math_ops._as_indexed_slices(true_out)
    else:
      raise TypeError(
          "Cannot reconcile tf.cond %i-th outputs:\n"
          "  true_fn returned:  %s\n"
          "  false_fn returned: %s" % (idx, true_out, false_out))

  true_graph.structured_outputs = true_outputs
  true_graph.outputs = func_graph_module.flatten(true_outputs)
  false_graph.structured_outputs = false_outputs
  false_graph.outputs = func_graph_module.flatten(false_outputs)
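# A hedged sketch (illustrative only) of the Tensor -> IndexedSlices rewrite,
# using public TF APIs; the library code above calls the private
# math_ops._as_indexed_slices instead. The helper name is an assumption.
import tensorflow as tf

def _tensor_as_indexed_slices(t):
  # A dense tensor is equivalent to IndexedSlices covering every row:
  # values == t, indices == [0, ..., rows - 1], dense_shape == shape(t).
  shape = tf.shape(t)
  return tf.IndexedSlices(values=t, indices=tf.range(shape[0]),
                          dense_shape=shape)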
def _update_indexed_slices_param(graph, loop_vars, init_slices, input_slices,
                                 output_slices, old_output_slices):
  """Updates graph with new IndexedSlices input/output.

  Updates graph's metadata to output the gradient computation defined by
  init_slices, input_slices, and output_slices, instead of outputting
  old_output_slices. Also returns a new version of loop_vars with init_slices
  replacing the old input.

  Args:
    graph: _WhileBodyGradFuncGraph.
    loop_vars: the inputs to graph.
    init_slices: the new IndexedSlices to use as input to graph.
    input_slices: the new IndexedSlices in graph that should be fed by
      init_slices.
    output_slices: the new IndexedSlices in graph that should be the
      corresponding output to input_slices.
    old_output_slices: the IndexedSlices in graph that are currently being
      output.

  Returns:
    New loop_vars to pass to graph.
  """
  structured_idx = _get_tensor_index_in_iterable(graph.structured_outputs,
                                                 old_output_slices)
  # We assume that the component tensors of old_output_slices appear
  # sequentially in graph.outputs. We use the first of these tensors
  # as the reference index.
  flat_idx = _get_tensor_index_in_iterable(
      graph.outputs,
      func_graph.flatten(old_output_slices)[0])

  graph.structured_outputs[structured_idx] = output_slices
  graph.outputs = func_graph.flatten(graph.structured_outputs)

  graph.inputs = (graph.inputs[:flat_idx] + _flatten(input_slices) +
                  graph.inputs[flat_idx + 1:])

  return loop_vars[:flat_idx] + _flatten(init_slices) + loop_vars[flat_idx + 1:]
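# A toy sketch (plain Python lists, not FuncGraph state) of the splice the
# function above performs: the single old placeholder at flat_idx is replaced
# by the flattened IndexedSlices components (values, indices, dense_shape).
# The helper name is illustrative.
def _splice_flat_components(flat_list, flat_idx, new_components):
  return flat_list[:flat_idx] + list(new_components) + flat_list[flat_idx + 1:]

# _splice_flat_components(["a", "old", "b"], 1,
#                         ["values", "indices", "dense_shape"])
# -> ["a", "values", "indices", "dense_shape", "b"]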
def _rewrite_output_as_tensor(body_grad_graph, grad_output_slices):
  """Rewrites grad_output_slices to be a Tensor output.

  Args:
    body_grad_graph: _WhileBodyGradFuncGraph.
    grad_output_slices: IndexedSlices output of body_grad_graph.
  """
  with body_grad_graph.as_default():
    new_output = ops.convert_to_tensor_v2(grad_output_slices)

  idx = body_grad_graph.structured_outputs.index(grad_output_slices)
  body_grad_graph.structured_outputs[idx] = new_output
  body_grad_graph.outputs = func_graph.flatten(
      body_grad_graph.structured_outputs)
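# A minimal sketch of the densification the function above performs, using the
# public conversion API (`ops.convert_to_tensor_v2` corresponds to
# `tf.convert_to_tensor` here). The example values are arbitrary.
import tensorflow as tf

def _example_densify_indexed_slices():
  slices = tf.IndexedSlices(values=tf.ones([2, 3]),
                            indices=tf.constant([0, 2]),
                            dense_shape=tf.constant([4, 3]))
  # Rows 0 and 2 of the result come from `values`; rows 1 and 3 are zeros.
  return tf.convert_to_tensor(slices)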
def _make_output_composite_tensors_match(op_type, branch_graphs):
  """Modifies each branch_graph's outputs to have the same output signature.

  Currently the only transformation implemented is turning a Tensor into an
  equivalent IndexedSlices if the other branch returns an IndexedSlices.

  Updates branch_graph.{outputs,structured_outputs} for each branch_graph in
  branch_graphs.

  Args:
    op_type: _COND or _CASE
    branch_graphs: `list` of `FuncGraph`

  Raises:
    TypeError: if a set of outputs cannot be rewritten.
  """
  # Note: since this is only used for gradient graphs, we do not expect the
  # outputs to be structured (e.g. nested lists), and thus do not need to use
  # nest.flatten, etc.
  assert branch_graphs
  branch_outputs = [g.structured_outputs for g in branch_graphs]
  outputs_per_branch = list(len(outs) for outs in branch_outputs)
  assert len(set(outputs_per_branch)) == 1, outputs_per_branch

  for output_idx, branch_outs in enumerate(zip(*branch_outputs)):
    if len(set(type(out) for out in branch_outs)) == 1:
      continue
    if not any(isinstance(out, ops.IndexedSlices) for out in branch_outs):
      continue
    for branch_idx, branch_out in enumerate(branch_outs):
      if isinstance(branch_out, ops.IndexedSlices):
        continue
      elif isinstance(branch_out, ops.Tensor):
        with branch_graphs[branch_idx].as_default():
          branch_outputs[branch_idx][output_idx] = math_ops._as_indexed_slices(
              branch_out)
      else:
        raise TypeError(
            "Cannot reconcile {op_name} {output_idx}-th outputs:\n"
            "  outputs from all branches: {outputs}".format(
                op_name="tf.cond" if op_type == _COND else "tf.switch_case",
                output_idx=output_idx,
                outputs=branch_outs))

  for branch_graph, branch_outs in zip(branch_graphs, branch_outputs):
    branch_graph.structured_outputs = branch_outs
    branch_graph.outputs = [
        t for t in func_graph_module.flatten(branch_outs) if t is not None
    ]
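# A hedged sketch of the kind of program whose branch gradients the function
# above reconciles: tf.gather yields an IndexedSlices gradient for `params` in
# one branch while another branch yields a dense Tensor. The function name is
# illustrative, not TensorFlow API.
import tensorflow as tf

@tf.function
def _example_mixed_branch_grads(params, branch_index):
  with tf.GradientTape() as tape:
    tape.watch(params)
    out = tf.switch_case(
        branch_index,
        branch_fns=[
            lambda: tf.reduce_sum(tf.gather(params, [0, 1])),  # sparse grad
            lambda: tf.reduce_sum(params * 2.0),               # dense grad
        ])
  return tape.gradient(out, params)

# e.g. _example_mixed_branch_grads(tf.ones([4]), tf.constant(0))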