def wrapped_body(loop_counter, maximum_iterations_arg, *args):
  """Invoke the user-supplied loop body and advance the loop counter.

  Args:
    loop_counter: Current loop counter; `loop_counter + 1` is emitted as the
      first result.
    maximum_iterations_arg: Maximum iterations of the loop; emitted unchanged
      as the second result.
    *args: Flat loop variables. They are packed back into the structure of
      `orig_loop_vars` (flows becoming TensorArrays) before `body` is called.

  Returns:
    A list made of `[loop_counter + 1, maximum_iterations_arg]` followed by the
    top-level elements of `body`'s outputs after `_tensor_array_to_flow`.
  """
  # Capturing cond_graph's external captures up front makes them appear in the
  # same order in body_graph.external_captures.
  for captured in cond_graph.external_captures:
    ops.get_default_graph().capture(captured)

  # There is no nest.zip, so `_pack_sequence_as` flattens both
  # `orig_loop_vars` and `args`, converts flows in `args` to TensorArrays, and
  # packs the result into the structure of `orig_loop_vars`.
  packed_loop_vars = _pack_sequence_as(orig_loop_vars, args)
  body_results = body(*packed_loop_vars)
  if not nest.is_sequence_or_composite(body_results):
    # A single bare value is treated as a one-element output list.
    body_results = [body_results]

  # Top-level tuples are converted to lists before comparing, for
  # compatibility with the legacy while_loop.
  nest.assert_same_structure(
      list(body_results), list(orig_loop_vars), expand_composites=True)

  flow_results = _tensor_array_to_flow(body_results)

  # TODO(srbs): Update lowering code to create _Enter nodes with
  # is_constant=True for inputs that are directly passed to outputs.
  counter_outputs = [loop_counter + 1, maximum_iterations_arg]
  return counter_outputs + list(flow_results)
def _make_indexed_slices_indices_types_match(op_type, branch_graphs):
  """Match dtype of IndexedSlices.indices in outputs of branch_graphs.

  At every output position where the branches return `IndexedSlices`, checks
  that each branch's `indices` tensor is int32 or int64. If the branches
  disagree, the int32 ones are cast to int64 inside their own graph. Every
  branch's `structured_outputs` is then repacked from its (possibly updated)
  `outputs`.

  Args:
    op_type: Compared against `_COND` to choose the op name ("cond" or
      "switch_case") used in error messages.
    branch_graphs: Non-empty list of branch graphs whose `structured_outputs`
      flatten (without expanding composites) to the same number of elements.

  Raises:
    TypeError: If, at some output position, only some branches return
      `IndexedSlices`, or if an `IndexedSlices.indices` dtype is neither int32
      nor int64.
    ValueError: If the flat output count implied by `structured_outputs`
      differs from `len(branch_graphs[0].outputs)`.
  """
  assert branch_graphs
  # Positions of `IndexedSlices.indices` tensors in `branch_graphs[i].outputs`.
  indexed_slice_indices = []
  current_index = 0
  # This still contains Nones. They are kept so that error messages report
  # the position the user returned; they are skipped when advancing
  # `current_index` below.
  branch_outputs_flat_with_composites = [
      nest.flatten(branch_graph.structured_outputs, expand_composites=False)
      for branch_graph in branch_graphs
  ]
  outs_per_branch = [len(outs) for outs in branch_outputs_flat_with_composites]
  assert len(set(outs_per_branch)) == 1, outs_per_branch
  # Store indices of IndexedSlices.indices in `indexed_slice_indices`.
  for output_idx, branch_outs in enumerate(
      zip(*branch_outputs_flat_with_composites)):
    if len(set(isinstance(out, ops.IndexedSlices) for out in branch_outs)) != 1:
      raise TypeError(
          "Cannot reconcile tf.{op_name} {output_idx}-th outputs:\n"
          " branches returned: {outputs}".format(
              op_name="cond" if op_type == _COND else "switch_case",
              output_idx=output_idx,
              outputs=branch_outs))
    if isinstance(branch_outs[0], ops.IndexedSlices):
      # indices is the second component of the composite tensor.
      indexed_slice_indices.append(current_index + 1)
    if nest.is_sequence_or_composite(branch_outs[0]):
      current_index += len(
          nest.flatten(branch_outs[0], expand_composites=True))
    elif branch_outs[0] is not None:
      # `FuncGraph.outputs` does not contain Nones, so a None must not
      # advance the position counter.
      current_index += 1

  if not indexed_slice_indices:
    return

  # `FuncGraph.outputs` is the flattened `structured_outputs` minus the Nones.
  if current_index != len(branch_graphs[0].outputs):
    raise ValueError("Insufficient elements in branch_graphs[0].outputs.\n"
                     "Expected: %i\n"
                     "Actual: %i" % (current_index,
                                     len(branch_graphs[0].outputs)))

  # Cast indices with mismatching types to int64.
  for index in indexed_slice_indices:
    if any(bg.outputs[index].dtype not in (dtypes.int32, dtypes.int64)
           for bg in branch_graphs):
      raise TypeError("Type of IndexedSlices.indices must be int32 or int64. "
                      "Found: %s" %
                      str([bg.outputs[index].dtype for bg in branch_graphs]))
    if len(set(bg.outputs[index].dtype for bg in branch_graphs)) != 1:
      for branch_graph in branch_graphs:
        if branch_graph.outputs[index].dtype == dtypes.int32:
          with branch_graph.as_default():
            branch_graph.outputs[index] = math_ops.cast(
                branch_graph.outputs[index], dtypes.int64)

  for branch_graph in branch_graphs:
    # `outputs` omits the Nones present in `structured_outputs`; re-insert
    # them so the flat sequence lines up element-for-element before repacking.
    outputs_with_nones = []
    next_output = 0
    for structured_out in nest.flatten(branch_graph.structured_outputs,
                                       expand_composites=True):
      if structured_out is None:
        outputs_with_nones.append(None)
      else:
        outputs_with_nones.append(branch_graph.outputs[next_output])
        next_output += 1
    branch_graph.structured_outputs = func_graph_module.pack_sequence_as(
        branch_graph.structured_outputs, outputs_with_nones)
def wrapped_body(loop_counter, maximum_iterations_arg, *args):
  """Loop body wrapper that also advances the iteration counter.

  Args:
    loop_counter: Current loop counter; the first returned entry is
      `loop_counter + 1`.
    maximum_iterations_arg: Maximum iterations of the loop; returned unchanged
      as the second entry.
    *args: Flat loop variables, restored to the structure of
      `orig_loop_vars` (flows turned into TensorArrays) before calling `body`.

  Returns:
    `[loop_counter + 1, maximum_iterations_arg]` concatenated with the
    top-level elements of `body`'s outputs after `_tensor_array_to_flow`.
  """
  # Re-capture cond_graph's external captures first so they appear in the
  # same order in body_graph.external_captures.
  for external in cond_graph.external_captures:
    ops.get_default_graph().capture(external)

  # Without a nest.zip, `_pack_sequence_as` flattens `orig_loop_vars` and
  # `args`, converts flows in `args` to TensorArrays, and repacks them into
  # the structure of `orig_loop_vars`.
  raw = body(*_pack_sequence_as(orig_loop_vars, args))
  structured = raw if nest.is_sequence_or_composite(raw) else [raw]

  # Top-level tuples are compared as lists for compatibility with the legacy
  # while_loop.
  nest.assert_same_structure(list(structured), list(orig_loop_vars),
                             expand_composites=True)

  as_flows = _tensor_array_to_flow(structured)

  # TODO(srbs): Update lowering code to create _Enter nodes with
  # is_constant=True for inputs that are directly passed to outputs.
  head = [loop_counter + 1, maximum_iterations_arg]
  tail = list(as_flows)
  return head + tail
def _make_indexed_slices_indices_types_match(true_graph, false_graph):
  """Match dtype of IndexedSlices.indices in outputs of {true|false}_graphs.

  At every output position where both branches return `IndexedSlices`,
  checks that each `indices` tensor is int32 or int64. If the two disagree,
  the int32 one is cast to int64 inside its own graph. Both graphs'
  `structured_outputs` are then repacked from their (possibly updated)
  `outputs`.

  Args:
    true_graph: Graph of the true branch.
    false_graph: Graph of the false branch.

  Raises:
    TypeError: If only one branch returns `IndexedSlices` at some position,
      or if an `IndexedSlices.indices` dtype is neither int32 nor int64.
    ValueError: If the flat output count implied by
      `true_graph.structured_outputs` differs from `len(true_graph.outputs)`.
  """
  # Positions of `IndexedSlices.indices` tensors in `{true|false}_graph.outputs`.
  indexed_slice_indices = []
  current_index = 0
  true_outputs_flat_with_composites = nest.flatten(
      true_graph.structured_outputs, expand_composites=False)
  false_outputs_flat_with_composites = nest.flatten(
      false_graph.structured_outputs, expand_composites=False)
  # Store indices of IndexedSlices.indices in `indexed_slice_indices`.
  for idx, (true_out, false_out) in enumerate(
      zip(true_outputs_flat_with_composites,
          false_outputs_flat_with_composites)):
    if isinstance(true_out, ops.IndexedSlices) != isinstance(
        false_out, ops.IndexedSlices):
      raise TypeError("Cannot reconcile tf.cond %i-th outputs:\n"
                      " true_fn returned: %s\n"
                      " false_fn returned: %s" % (idx, true_out, false_out))
    if isinstance(true_out, ops.IndexedSlices):
      # indices is the second component of the composite tensor.
      indexed_slice_indices.append(current_index + 1)
    if nest.is_sequence_or_composite(true_out):
      current_index += len(nest.flatten(true_out, expand_composites=True))
    elif true_out is not None:
      # `FuncGraph.outputs` does not contain Nones, so a None must not
      # advance the position counter. NOTE(review): assumes both branches
      # return None at the same positions.
      current_index += 1

  if not indexed_slice_indices:
    return

  # `FuncGraph.outputs` is the flattened `structured_outputs` minus the Nones.
  if current_index != len(true_graph.outputs):
    raise ValueError("Insufficient elements in true_graph.outputs.\n"
                     "Expected: %i\n"
                     "Actual: %i" % (current_index, len(true_graph.outputs)))

  # Cast indices with mismatching types to int64.
  for index in indexed_slice_indices:
    if true_graph.outputs[index].dtype not in (dtypes.int32, dtypes.int64):
      raise TypeError("Type of IndexedSlices.indices must be int32 or int64. "
                      "Found: %s" % str(true_graph.outputs[index].dtype))
    if false_graph.outputs[index].dtype not in (dtypes.int32, dtypes.int64):
      raise TypeError("Type of IndexedSlices.indices must be int32 or int64. "
                      "Found: %s" % str(false_graph.outputs[index].dtype))
    if true_graph.outputs[index].dtype != false_graph.outputs[index].dtype:
      if false_graph.outputs[index].dtype == dtypes.int32:
        with false_graph.as_default():
          false_graph.outputs[index] = math_ops.cast(
              false_graph.outputs[index], dtypes.int64)
      else:
        with true_graph.as_default():
          true_graph.outputs[index] = math_ops.cast(
              true_graph.outputs[index], dtypes.int64)

  def _outputs_with_nones(graph):
    """Return `graph.outputs` with the Nones of `structured_outputs` restored."""
    # `outputs` omits Nones, but `pack_sequence_as` needs one flat element per
    # element of the flattened `structured_outputs`.
    restored = []
    next_output = 0
    for structured_out in nest.flatten(graph.structured_outputs,
                                       expand_composites=True):
      if structured_out is None:
        restored.append(None)
      else:
        restored.append(graph.outputs[next_output])
        next_output += 1
    return restored

  true_graph.structured_outputs = func_graph_module.pack_sequence_as(
      true_graph.structured_outputs, _outputs_with_nones(true_graph))
  false_graph.structured_outputs = func_graph_module.pack_sequence_as(
      false_graph.structured_outputs, _outputs_with_nones(false_graph))
def _make_indexed_slices_indices_types_match(op_type, branch_graphs):
  """Unify the dtype of `IndexedSlices.indices` across all branch outputs.

  At each output position where the branches return `IndexedSlices`, every
  branch's `indices` tensor must be int32 or int64. When the branches
  disagree, the int32 ones are cast to int64 inside their own graph.
  Afterwards each branch's `structured_outputs` is rebuilt from its `outputs`
  via `_pack_sequence_as`.

  Args:
    op_type: Compared against `_COND` to pick the op name ("cond" or
      "switch_case") used in error messages.
    branch_graphs: Non-empty list of branch graphs whose `structured_outputs`
      flatten (without expanding composites) to the same number of elements.

  Raises:
    TypeError: If only some branches return `IndexedSlices` at a position, or
      if an indices dtype is neither int32 nor int64.
    ValueError: If the flat output count derived from `structured_outputs`
      does not equal `len(branch_graphs[0].outputs)`.
  """
  assert branch_graphs

  # Flatten each branch while keeping composites whole. Nones stay in so that
  # the positions in error messages match what the user returned.
  flat_per_branch = [
      nest.flatten(graph.structured_outputs, expand_composites=False)
      for graph in branch_graphs
  ]
  lengths = [len(flat) for flat in flat_per_branch]
  assert len(set(lengths)) == 1, lengths

  # Walk output positions, tracking where each element lands in
  # `FuncGraph.outputs` (the flattened `structured_outputs` minus the Nones).
  indices_positions = []
  offset = 0
  for position, outs in enumerate(zip(*flat_per_branch)):
    if len({isinstance(o, ops.IndexedSlices) for o in outs}) != 1:
      raise TypeError("Cannot reconcile tf.{op_name} {output_idx}-th outputs:\n"
                      " branches returned: {outputs}".format(
                          op_name="cond" if op_type == _COND else "switch_case",
                          output_idx=position,
                          outputs=outs))
    leading = outs[0]
    if isinstance(leading, ops.IndexedSlices):
      # The indices tensor is the composite's second component.
      indices_positions.append(offset + 1)
    if nest.is_sequence_or_composite(leading):
      offset += len(nest.flatten(leading, expand_composites=True))
    elif leading is not None:
      # Nones occupy no slot in `FuncGraph.outputs`.
      offset += 1

  if not indices_positions:
    return

  if offset != len(branch_graphs[0].outputs):
    raise ValueError("Insufficient elements in branch_graphs[0].outputs.\n"
                     "Expected: %i\n"
                     "Actual: %i" % (offset, len(branch_graphs[0].outputs)))

  allowed_dtypes = (dtypes.int32, dtypes.int64)
  for pos in indices_positions:
    found = [graph.outputs[pos].dtype for graph in branch_graphs]
    if any(dt not in allowed_dtypes for dt in found):
      raise TypeError("Type of IndexedSlices.indices must be int32 or int64. "
                      "Found: %s" % str(found))
    if len(set(found)) == 1:
      continue
    # Mixed int32/int64: promote the int32 ones to int64.
    for graph in branch_graphs:
      if graph.outputs[pos].dtype == dtypes.int32:
        with graph.as_default():
          graph.outputs[pos] = math_ops.cast(graph.outputs[pos], dtypes.int64)

  for graph in branch_graphs:
    graph.structured_outputs = _pack_sequence_as(graph.structured_outputs,
                                                 graph.outputs)
def _make_indexed_slices_indices_types_match(true_graph, false_graph):
  """Match dtype of IndexedSlices.indices in outputs of {true|false}_graphs.

  At every output position where both branches return `IndexedSlices`,
  checks that each `indices` tensor is int32 or int64. If the two disagree,
  the int32 one is cast to int64 inside its own graph. Both graphs'
  `structured_outputs` are then repacked from their (possibly updated)
  `outputs`.

  Args:
    true_graph: Graph of the true branch.
    false_graph: Graph of the false branch.

  Raises:
    TypeError: If only one branch returns `IndexedSlices` at some position,
      or if an `IndexedSlices.indices` dtype is neither int32 nor int64.
    ValueError: If the flat output count implied by
      `true_graph.structured_outputs` differs from `len(true_graph.outputs)`.
  """
  # Positions of `IndexedSlices.indices` tensors in `{true|false}_graph.outputs`.
  indexed_slice_indices = []
  current_index = 0
  true_outputs_flat_with_composites = nest.flatten(
      true_graph.structured_outputs, expand_composites=False)
  false_outputs_flat_with_composites = nest.flatten(
      false_graph.structured_outputs, expand_composites=False)
  # Store indices of IndexedSlices.indices in `indexed_slice_indices`.
  for idx, (true_out, false_out) in enumerate(
      zip(true_outputs_flat_with_composites,
          false_outputs_flat_with_composites)):
    if isinstance(true_out, ops.IndexedSlices) != isinstance(
        false_out, ops.IndexedSlices):
      raise TypeError("Cannot reconcile tf.cond %i-th outputs:\n"
                      " true_fn returned: %s\n"
                      " false_fn returned: %s" % (idx, true_out, false_out))
    if isinstance(true_out, ops.IndexedSlices):
      # indices is the second component of the composite tensor.
      indexed_slice_indices.append(current_index + 1)
    if nest.is_sequence_or_composite(true_out):
      current_index += len(nest.flatten(true_out, expand_composites=True))
    elif true_out is not None:
      # `FuncGraph.outputs` does not contain Nones, so a None must not
      # advance the position counter. NOTE(review): assumes both branches
      # return None at the same positions.
      current_index += 1

  if not indexed_slice_indices:
    return

  # `FuncGraph.outputs` is the flattened `structured_outputs` minus the Nones.
  if current_index != len(true_graph.outputs):
    raise ValueError("Insufficient elements in true_graph.outputs.\n"
                     "Expected: %i\n"
                     "Actual: %i" % (current_index, len(true_graph.outputs)))

  # Cast indices with mismatching types to int64.
  for index in indexed_slice_indices:
    if true_graph.outputs[index].dtype not in (dtypes.int32, dtypes.int64):
      raise TypeError(
          "Type of IndexedSlices.indices must be int32 or int64. "
          "Found: %s" % str(true_graph.outputs[index].dtype))
    if false_graph.outputs[index].dtype not in (dtypes.int32, dtypes.int64):
      raise TypeError(
          "Type of IndexedSlices.indices must be int32 or int64. "
          "Found: %s" % str(false_graph.outputs[index].dtype))
    if true_graph.outputs[index].dtype != false_graph.outputs[index].dtype:
      if false_graph.outputs[index].dtype == dtypes.int32:
        with false_graph.as_default():
          false_graph.outputs[index] = math_ops.cast(
              false_graph.outputs[index], dtypes.int64)
      else:
        with true_graph.as_default():
          true_graph.outputs[index] = math_ops.cast(
              true_graph.outputs[index], dtypes.int64)

  def _outputs_with_nones(graph):
    """Return `graph.outputs` with the Nones of `structured_outputs` restored."""
    # `outputs` omits Nones, but `pack_sequence_as` needs one flat element per
    # element of the flattened `structured_outputs`.
    restored = []
    next_output = 0
    for structured_out in nest.flatten(graph.structured_outputs,
                                       expand_composites=True):
      if structured_out is None:
        restored.append(None)
      else:
        restored.append(graph.outputs[next_output])
        next_output += 1
    return restored

  true_graph.structured_outputs = func_graph_module.pack_sequence_as(
      true_graph.structured_outputs, _outputs_with_nones(true_graph))
  false_graph.structured_outputs = func_graph_module.pack_sequence_as(
      false_graph.structured_outputs, _outputs_with_nones(false_graph))