def pfor(loop_fn, iters):
  """Equivalent to running `loop_fn` `iters` times and stacking the outputs.

  `pfor` has functionality similar to `for_loop`, i.e. running `loop_fn`
  `iters` times, with input from 0 to `iters - 1`, and stacking corresponding
  output of each iteration. However the implementation does not use a
  tf.while_loop. Instead it adds new operations to the graph that collectively
  compute the same value as what running `loop_fn` in a loop would compute.

  This is an experimental feature and currently has a lot of limitations:
    - There should be no data dependency between the different iterations. For
      example, a future iteration should not depend on a value or side-effect
      of a previous iteration.
    - Stateful kernels may mostly not be supported since these often imply a
      data dependency or ordering of the iterations. We do support a limited
      set of such stateful kernels though (like RandomFoo, Variable operations
      like reads, etc).
    - Conversion works only on a limited set of kernels for which a converter
      has been registered.
    - loop_fn cannot currently contain control flow operations like
      tf.while_loop or tf.cond.
    - `loop_fn` should return nested structure of Tensors or Operations.
      However if an Operation is returned, it should have zero outputs.
    - The shape and dtype of `loop_fn` outputs should not depend on the input
      to loop_fn.

  Args:
    loop_fn: A function that takes an int32 scalar tf.Tensor object
      representing the iteration number, and returns a possibly nested
      structure of Tensor or Operation objects.
    iters: Number of iterations for which to run loop_fn.

  Returns:
    Returns a nested structure of stacked tensor objects with the same nested
    structure as the output of `loop_fn`.
  """
  # Snapshot the graph so the ops created by tracing `loop_fn` can be isolated.
  ops_before_trace = set(ops.get_default_graph().get_operations())
  with ops.name_scope("loop_body"):
    # Symbolic iteration index; PFor later replaces it with the vector of all
    # iteration indices.
    iteration_index = array_ops.placeholder(dtypes.int32, shape=[])
    traced_outputs = loop_fn(iteration_index)
  traced_ops = set(
      ops.get_default_graph().get_operations()).difference(ops_before_trace)

  iters = ops.convert_to_tensor(iters)
  with ops.name_scope("pfor"):
    vectorizer = PFor(iteration_index, iters, traced_ops)
    stacked = [vectorizer.convert(out) for out in nest.flatten(traced_outputs)]
    return nest.pack_sequence_as(traced_outputs, stacked)
def _pfor_impl(loop_fn,
               iters,
               fallback_to_while_loop,
               parallel_iterations=None,
               pfor_config=None):
  """Implementation of pfor.

  Traces `loop_fn` once on a scalar loop-variable placeholder, collects the
  ops that tracing created, and vectorizes them with `PFor`.

  Outputs of `loop_fn` that are not already `Operation`, `Tensor` or
  `SparseTensor` are converted to Tensors. If any output is an
  `np_arrays.ndarray`, its `.data` tensor is vectorized instead and every
  stacked dense Tensor result is wrapped back into an ndarray (on both the
  untiled and the chunked path).

  Args:
    loop_fn: Callable taking the scalar iteration index. If it accepts a
      `PFOR_CONFIG_ARG` keyword argument, a `PForConfig` is passed there.
    iters: Number of iterations.
    fallback_to_while_loop: Forwarded to `PFor`.
    parallel_iterations: None to vectorize all iterations at once; otherwise an
      integer > 1, in which case iterations are vectorized in chunks of that
      size that run sequentially inside a `for_loop`.
    pfor_config: Optional `PForConfig`. Must be None unless `loop_fn` accepts
      the config argument.

  Returns:
    A nested structure like the output of `loop_fn`, with each element stacked
    along a new leading dimension.

  Raises:
    ValueError: If `parallel_iterations` is < 1 or == 1, or if it is set while
      `pfor_config` performs reductions.
  """
  assert not context.executing_eagerly()
  loop_fn_has_config = _loop_fn_has_config(loop_fn)
  existing_ops = set(ops.get_default_graph().get_operations())
  # Run the loop body
  with ops.name_scope("loop_body"):
    loop_var = array_ops.placeholder_with_default(0, shape=[])
    if loop_fn_has_config:
      if pfor_config is None:
        pfor_config = PForConfig()
        pfor_config._set_iters(iters)  # pylint: disable=protected-access
      loop_fn_outputs = loop_fn(loop_var, **{PFOR_CONFIG_ARG: pfor_config})
    else:
      assert pfor_config is None
      loop_fn_outputs = loop_fn(loop_var)

    # Convert outputs to Tensor if needed.
    rewrap_as_ndarray = False
    tmp_loop_fn_outputs = []
    for loop_fn_output in nest.flatten(loop_fn_outputs):
      if (loop_fn_output is not None and not isinstance(
          loop_fn_output,
          (ops.Operation, ops.Tensor, sparse_tensor.SparseTensor))):
        if isinstance(loop_fn_output, indexed_slices.IndexedSlices):
          logging.warning(
              "Converting %s to a dense representation may make it slow."
              " Alternatively, output the indices and values of the"
              " IndexedSlices separately, and handle the vectorized"
              " outputs directly.", loop_fn_output)
          loop_fn_output = ops.convert_to_tensor(loop_fn_output)
        elif isinstance(loop_fn_output, np_arrays.ndarray):
          # Vectorize the backing tensor; results are rewrapped at the end.
          loop_fn_output = loop_fn_output.data
          rewrap_as_ndarray = True
        else:
          loop_fn_output = ops.convert_to_tensor(loop_fn_output)
      tmp_loop_fn_outputs.append(loop_fn_output)
    loop_fn_outputs = nest.pack_sequence_as(loop_fn_outputs,
                                            tmp_loop_fn_outputs)

  new_ops = set(ops.get_default_graph().get_operations()) - existing_ops
  iters = ops.convert_to_tensor(iters)

  def _maybe_rewrap(output):
    # Wrap a stacked result back into an ndarray when `loop_fn` produced
    # ndarrays. Only dense Tensors can be wrapped; Operations, None and
    # SparseTensors are passed through unchanged.
    if rewrap_as_ndarray and isinstance(output, ops.Tensor):
      return np_arrays.tensor_to_ndarray(output)
    return output

  if parallel_iterations is not None:
    if parallel_iterations < 1:
      raise ValueError(
          "parallel_iterations must be None or a positive integer")
    if parallel_iterations == 1:
      raise ValueError(
          "Found parallel_iterations == 1. Use for_loop instead.")
    iters_value = tensor_util.constant_value(iters)
    # A single chunk covers everything, so skip the chunked machinery.
    if iters_value is not None and iters_value < parallel_iterations:
      parallel_iterations = None

  if parallel_iterations is None:
    with ops.name_scope("pfor"):
      converter = PFor(loop_var, iters, new_ops,
                       fallback_to_while_loop=fallback_to_while_loop,
                       pfor_config=pfor_config)
      outputs = [
          _maybe_rewrap(converter.convert(loop_fn_output))
          for loop_fn_output in nest.flatten(loop_fn_outputs)
      ]
      return nest.pack_sequence_as(loop_fn_outputs, outputs)

  if pfor_config is not None and pfor_config._has_reductions():  # pylint: disable=protected-access
    raise ValueError(
        "Setting parallel_iterations currently unsupported if"
        " reductions across iterations are performed.")
  num_tiled_iterations = iters // parallel_iterations
  num_remaining_iterations = iters % parallel_iterations
  # TODO(agarwal): Avoid calling loop_fn twice. Generate the loop body inside
  # a tf.function and extract the graph from there to vectorize it.
  with ops.name_scope("pfor_untiled"):
    converter = PFor(loop_var, num_remaining_iterations, new_ops,
                     fallback_to_while_loop=fallback_to_while_loop,
                     pfor_config=pfor_config)
    flattened_loop_fn_outputs = nest.flatten(loop_fn_outputs)
    # Only consumed by the concat below, so they stay plain Tensors; ndarray
    # rewrapping is applied once to the final outputs.
    remaining_outputs = [
        converter.convert(loop_fn_output)
        for loop_fn_output in flattened_loop_fn_outputs
    ]

  with ops.name_scope("pfor_tiled"):
    loop_fn_dtypes = [
        ops.convert_to_tensor(x).dtype for x in flattened_loop_fn_outputs
    ]

    def tiled_loop_body(j):
      # Chunk `j` covers iterations [offset, offset + parallel_iterations).
      offset = j * parallel_iterations + num_remaining_iterations

      def tiled_loop_fn(i, pfor_config=None):
        if loop_fn_has_config:
          return nest.flatten(loop_fn(i + offset, pfor_config=pfor_config))
        else:
          return nest.flatten(loop_fn(i + offset))

      return _pfor_impl(
          tiled_loop_fn,
          parallel_iterations,
          fallback_to_while_loop=fallback_to_while_loop,
          pfor_config=pfor_config)

    tiled_outputs = for_loop(tiled_loop_body, loop_fn_dtypes,
                             num_tiled_iterations, parallel_iterations=1)
    tiled_outputs = [_flatten_first_two_dims(y) for y in tiled_outputs]

  with ops.name_scope("pfor"):
    iters_value = tensor_util.constant_value(iters)
    if iters_value is None or iters_value % parallel_iterations:
      outputs = control_flow_ops.cond(
          math_ops.equal(num_remaining_iterations, 0),
          lambda: tiled_outputs,
          lambda: [  # pylint: disable=g-long-lambda
              array_ops.concat([x, y], axis=0)
              for x, y in zip(remaining_outputs, tiled_outputs)
          ])
    else:
      outputs = tiled_outputs
    # Rewrap here so the chunked path returns ndarrays exactly like the
    # untiled path does, and return the rewrapped list itself.
    flattened_outputs = [_maybe_rewrap(x) for x in nest.flatten(outputs)]
    return nest.pack_sequence_as(loop_fn_outputs, flattened_outputs)
def _pfor_impl(loop_fn, iters, parallel_iterations=None):
  """Implementation of pfor.

  Traces `loop_fn` once on an int32 scalar placeholder, then hands the ops it
  created to `PFor`, which rewrites them so that all `iters` iterations are
  computed at once.

  When `parallel_iterations` is set (and the static value of `iters`, if
  known, is not smaller than it), the work is split in two parts:
    - the first `iters % parallel_iterations` iterations are vectorized
      directly;
    - the rest are vectorized `parallel_iterations` at a time inside a
      sequential `for_loop`.
  The two parts are concatenated along axis 0.

  Args:
    loop_fn: Callable mapping the iteration index to a nested structure of
      Tensors/Operations.
    iters: Number of iterations.
    parallel_iterations: None, or an integer > 1 bounding how many iterations
      are vectorized together.

  Returns:
    A nested structure like `loop_fn`'s output with stacked elements.

  Raises:
    ValueError: If `parallel_iterations` is < 1 or == 1.
  """
  graph_ops_before = set(ops.get_default_graph().get_operations())
  with ops.name_scope("loop_body"):
    index = array_ops.placeholder(dtypes.int32, shape=[])
    body_outputs = loop_fn(index)
  # Ops created while tracing `loop_fn`; these are what PFor vectorizes.
  body_ops = set(ops.get_default_graph().get_operations()) - graph_ops_before
  iters = ops.convert_to_tensor(iters)

  if parallel_iterations is not None:
    if parallel_iterations < 1:
      raise ValueError(
          "parallel_iterations must be None or a positive integer")
    elif parallel_iterations == 1:
      raise ValueError(
          "Found parallel_iterations == 1. Use for_loop instead.")
    static_iters = tensor_util.constant_value(iters)
    # Chunking buys nothing if every iteration already fits in one chunk.
    if static_iters is not None and static_iters < parallel_iterations:
      parallel_iterations = None

  flat_body_outputs = nest.flatten(body_outputs)

  if parallel_iterations is None:
    with ops.name_scope("pfor"):
      vectorizer = PFor(index, iters, body_ops)
      stacked = [vectorizer.convert(out) for out in flat_body_outputs]
      return nest.pack_sequence_as(body_outputs, stacked)

  chunk_count = iters // parallel_iterations
  leftover_count = iters % parallel_iterations
  # TODO(agarwal): Avoid calling loop_fn twice. Generate the loop body inside
  # a tf.function and extract the graph from there to vectorize it.
  with ops.name_scope("pfor_untiled"):
    vectorizer = PFor(index, leftover_count, body_ops)
    leftover_outputs = [vectorizer.convert(out) for out in flat_body_outputs]

  with ops.name_scope("pfor_tiled"):
    chunk_dtypes = [
        ops.convert_to_tensor(out).dtype for out in flat_body_outputs
    ]

    def run_chunk(chunk_index):
      first = chunk_index * parallel_iterations + leftover_count

      def shifted_loop_fn(i):
        return nest.flatten(loop_fn(i + first))

      return pfor(shifted_loop_fn, parallel_iterations)

    chunked = for_loop(run_chunk, chunk_dtypes, chunk_count,
                       parallel_iterations=1)
    chunked = [_flatten_first_two_dims(t) for t in chunked]

  with ops.name_scope("pfor"):
    static_iters = tensor_util.constant_value(iters)
    if static_iters is not None and not static_iters % parallel_iterations:
      # Statically known to divide evenly: there are no leftover iterations.
      merged = chunked
    else:
      merged = control_flow_ops.cond(
          math_ops.equal(leftover_count, 0),
          lambda: chunked,
          lambda: [array_ops.concat([head, tail], axis=0)
                   for head, tail in zip(leftover_outputs, chunked)])
    return nest.pack_sequence_as(body_outputs, nest.flatten(merged))
def _pfor_impl(loop_fn,
               iters,
               fallback_to_while_loop,
               parallel_iterations=None,
               pfor_config=None,
               warn=False):
  """Implementation of pfor.

  Traces `loop_fn` once on a scalar loop-variable placeholder, then:
    1. flattens composite outputs into Tensors with `_composite_to_tensors`;
    2. vectorizes the traced ops with `PFor`;
    3. rebuilds composite outputs with `_composite_from_tensors`.
  Tracing goes through autograph when `loop_fn` does not take a config.

  Args:
    loop_fn: Callable taking the scalar iteration index. If it accepts a
      `PFOR_CONFIG_ARG` keyword argument, a `PForConfig` is passed there.
    iters: Number of iterations.
    fallback_to_while_loop: Forwarded to `PFor`.
    parallel_iterations: None to vectorize all iterations at once; otherwise an
      integer > 1, in which case iterations are vectorized in chunks of that
      size that run sequentially inside a `for_loop`.
    pfor_config: Optional `PForConfig`. Must be None unless `loop_fn` accepts
      the config argument.
    warn: Forwarded to every `PFor` converter created here, including those
      built for the chunked path.

  Returns:
    A nested structure like the output of `loop_fn`, with each element stacked
    along a new leading dimension.

  Raises:
    ValueError: If `parallel_iterations` is < 1 or == 1, or if it is set while
      `pfor_config` performs reductions.
  """
  assert not context.executing_eagerly()
  loop_fn_has_config = _loop_fn_has_config(loop_fn)
  existing_ops = set(ops.get_default_graph().get_operations())
  iters_value = tensor_util.constant_value(iters)
  # Run the loop body
  with ops.name_scope("loop_body"):
    loop_var = array_ops.placeholder_with_default(0, shape=[])
    if loop_fn_has_config:
      if pfor_config is None:
        pfor_config = PForConfig()
        pfor_config._set_iters(iters)  # pylint: disable=protected-access
      loop_fn_outputs = loop_fn(loop_var, **{PFOR_CONFIG_ARG: pfor_config})
    else:
      assert pfor_config is None
      # Also used by the chunked path below, so both paths trace `loop_fn`
      # the same way.
      autograph_loop_fn = autograph.tf_convert(
          loop_fn, autograph_ctx.control_status_ctx())
      loop_fn_outputs = autograph_loop_fn(loop_var)
    loop_fn_output_tensors = nest.map_structure(_composite_to_tensors,
                                                loop_fn_outputs)

    # Convert outputs to Tensor if needed.
    tmp_loop_fn_outputs = []
    for loop_fn_output in nest.flatten(loop_fn_output_tensors):
      if (loop_fn_output is not None and not isinstance(
          loop_fn_output,
          (ops.Operation, ops.Tensor, sparse_tensor.SparseTensor))):
        if isinstance(loop_fn_output, indexed_slices.IndexedSlices):
          logging.warning(
              "Converting %s to a dense representation may make it slow."
              " Alternatively, output the indices and values of the"
              " IndexedSlices separately, and handle the vectorized"
              " outputs directly.", loop_fn_output)
        loop_fn_output = ops.convert_to_tensor(loop_fn_output)
      tmp_loop_fn_outputs.append(loop_fn_output)
    loop_fn_output_tensors = nest.pack_sequence_as(loop_fn_output_tensors,
                                                   tmp_loop_fn_outputs)

  new_ops = set(ops.get_default_graph().get_operations()) - existing_ops
  iters = ops.convert_to_tensor(iters)
  if parallel_iterations is not None:
    if parallel_iterations < 1:
      raise ValueError(
          "Argument `parallel_iterations` must be None or a positive integer. "
          f"Received: {parallel_iterations}.")
    if parallel_iterations == 1:
      raise ValueError(
          "Found `parallel_iterations == 1`. Use `for_loop` instead.")
    # A single chunk covers everything, so skip the chunked machinery.
    if iters_value is not None and iters_value < parallel_iterations:
      parallel_iterations = None
  if parallel_iterations is None:
    with ops.name_scope("pfor"):
      converter = PFor(loop_var, iters, new_ops,
                       fallback_to_while_loop=fallback_to_while_loop,
                       pfor_config=pfor_config,
                       warn=warn)
      flattened_output_tensors = []
      for loop_fn_output in nest.flatten(loop_fn_output_tensors):
        output = converter.convert(loop_fn_output)
        flattened_output_tensors.append(output)
  else:
    if pfor_config is not None and pfor_config._has_reductions():  # pylint: disable=protected-access
      raise ValueError(
          "Setting `parallel_iterations` currently unsupported if "
          "reductions across iterations are performed.")
    num_tiled_iterations = iters // parallel_iterations
    num_remaining_iterations = iters % parallel_iterations
    # TODO(agarwal): Avoid calling loop_fn twice. Generate the loop body inside
    # a tf.function and extract the graph from there to vectorize it.
    with ops.name_scope("pfor_untiled"):
      converter = PFor(loop_var, num_remaining_iterations, new_ops,
                       fallback_to_while_loop=fallback_to_while_loop,
                       pfor_config=pfor_config,
                       warn=warn)
      remaining_output_tensors = []
      flattened_output_tensors = nest.flatten(loop_fn_output_tensors)
      for loop_fn_output in flattened_output_tensors:
        output = converter.convert(loop_fn_output)
        remaining_output_tensors.append(output)

    with ops.name_scope("pfor_tiled"):
      loop_fn_dtypes = [
          ops.convert_to_tensor(x).dtype for x in flattened_output_tensors
      ]

      def tiled_loop_body(j):
        # Chunk `j` covers iterations [offset, offset + parallel_iterations).
        offset = j * parallel_iterations + num_remaining_iterations

        def tiled_loop_fn(i, pfor_config=None):
          if loop_fn_has_config:
            loop_fn_outputs = loop_fn(i + offset, pfor_config=pfor_config)
          else:
            loop_fn_outputs = autograph_loop_fn(i + offset)
          return nest.flatten(
              # Stacking across iterations requires explicit Tensors.
              nest.map_structure(_composite_to_tensors, loop_fn_outputs))

        return _pfor_impl(
            tiled_loop_fn,
            parallel_iterations,
            fallback_to_while_loop=fallback_to_while_loop,
            pfor_config=pfor_config,
            warn=warn)

      tiled_output_tensors = for_loop(
          tiled_loop_body, loop_fn_dtypes,
          num_tiled_iterations, parallel_iterations=1)
      tiled_output_tensors = [
          _flatten_first_two_dims(y) for y in tiled_output_tensors
      ]

    with ops.name_scope("pfor"):
      if iters_value is None or iters_value % parallel_iterations:
        output_tensors = control_flow_ops.cond(
            math_ops.equal(num_remaining_iterations, 0),
            lambda: tiled_output_tensors,
            lambda: [  # pylint: disable=g-long-lambda
                array_ops.concat([x, y], axis=0)
                for x, y in zip(remaining_output_tensors, tiled_output_tensors)
            ])
      else:
        output_tensors = tiled_output_tensors
      flattened_output_tensors = nest.flatten(output_tensors)

      for output, original_output in zip(flattened_output_tensors,
                                         nest.flatten(loop_fn_output_tensors)):
        # Restore any shape information lost from tiling.
        # TODO(b/174254748): this may not be correct for stacked `variant`s.
        output.set_shape(
            tensor_shape.TensorShape([iters_value]).concatenate(
                original_output.shape))

  return nest.map_structure_up_to(
      loop_fn_outputs,
      functools.partial(_composite_from_tensors, batch_size=iters_value),
      nest.pack_sequence_as(loop_fn_output_tensors, flattened_output_tensors),
      loop_fn_outputs)
def _pfor_impl(loop_fn, iters, parallel_iterations=None, pfor_config=None):
  """Implementation of pfor.

  Traces `loop_fn` once on an int32 scalar placeholder and vectorizes the ops
  it created with `PFor`. If `loop_fn` accepts a `PFOR_CONFIG_ARG` keyword
  argument, a `PForConfig` is created (when not supplied) and passed along.

  With `parallel_iterations` set, the first `iters % parallel_iterations`
  iterations are vectorized directly. The remaining iterations are vectorized
  in chunks of `parallel_iterations` inside a sequential `for_loop`. The two
  parts are concatenated along axis 0.

  Args:
    loop_fn: Callable taking the iteration index (and optionally the config).
    iters: Number of iterations.
    parallel_iterations: None, or an integer > 1 bounding how many iterations
      are vectorized together.
    pfor_config: Optional `PForConfig`. Must be None unless `loop_fn` accepts
      the config argument.

  Returns:
    A nested structure like `loop_fn`'s output with stacked elements.

  Raises:
    ValueError: If `parallel_iterations` is < 1 or == 1, or if it is set while
      `pfor_config` has reductions.
  """
  wants_config = _loop_fn_has_config(loop_fn)
  graph_ops_before = set(ops.get_default_graph().get_operations())
  with ops.name_scope("loop_body"):
    index = array_ops.placeholder(dtypes.int32, shape=[])
    if not wants_config:
      assert pfor_config is None
      body_outputs = loop_fn(index)
    else:
      if pfor_config is None:
        pfor_config = PForConfig()
        pfor_config._set_iters(iters)  # pylint: disable=protected-access
      body_outputs = loop_fn(index, **{PFOR_CONFIG_ARG: pfor_config})
  # Ops created while tracing `loop_fn`; these are what PFor vectorizes.
  body_ops = set(ops.get_default_graph().get_operations()) - graph_ops_before
  iters = ops.convert_to_tensor(iters)

  if parallel_iterations is not None:
    if parallel_iterations < 1:
      raise ValueError(
          "parallel_iterations must be None or a positive integer")
    elif parallel_iterations == 1:
      raise ValueError(
          "Found parallel_iterations == 1. Use for_loop instead.")
    static_iters = tensor_util.constant_value(iters)
    if static_iters is not None and static_iters < parallel_iterations:
      parallel_iterations = None

  flat_body_outputs = nest.flatten(body_outputs)

  if parallel_iterations is None:
    with ops.name_scope("pfor"):
      vectorizer = PFor(index, iters, body_ops, pfor_config=pfor_config)
      stacked = [vectorizer.convert(out) for out in flat_body_outputs]
      return nest.pack_sequence_as(body_outputs, stacked)

  if pfor_config is not None and pfor_config._has_reductions():  # pylint: disable=protected-access
    raise ValueError(
        "Setting parallel_iterations currently unsupported if"
        " reductions across iterations are performed.")
  chunk_count = iters // parallel_iterations
  leftover_count = iters % parallel_iterations
  # TODO(agarwal): Avoid calling loop_fn twice. Generate the loop body inside
  # a tf.function and extract the graph from there to vectorize it.
  with ops.name_scope("pfor_untiled"):
    vectorizer = PFor(index, leftover_count, body_ops, pfor_config=pfor_config)
    leftover_outputs = [vectorizer.convert(out) for out in flat_body_outputs]

  with ops.name_scope("pfor_tiled"):
    chunk_dtypes = [
        ops.convert_to_tensor(out).dtype for out in flat_body_outputs
    ]

    def run_chunk(chunk_index):
      first = chunk_index * parallel_iterations + leftover_count

      # NOTE: the keyword is deliberately named `pfor_config`, presumably so
      # `_loop_fn_has_config` recognizes it in the recursive call.
      def tiled_loop_fn(i, pfor_config=None):
        shifted = i + first
        if wants_config:
          return nest.flatten(loop_fn(shifted, pfor_config=pfor_config))
        return nest.flatten(loop_fn(shifted))

      return _pfor_impl(tiled_loop_fn, parallel_iterations,
                        pfor_config=pfor_config)

    chunked = for_loop(run_chunk, chunk_dtypes, chunk_count,
                       parallel_iterations=1)
    chunked = [_flatten_first_two_dims(t) for t in chunked]

  with ops.name_scope("pfor"):
    static_iters = tensor_util.constant_value(iters)
    if static_iters is not None and not static_iters % parallel_iterations:
      merged = chunked
    else:
      merged = control_flow_ops.cond(
          math_ops.equal(leftover_count, 0),
          lambda: chunked,
          lambda: [array_ops.concat([head, tail], axis=0)
                   for head, tail in zip(leftover_outputs, chunked)])
    return nest.pack_sequence_as(body_outputs, nest.flatten(merged))
def _pfor_impl(loop_fn, iters, parallel_iterations=None, pfor_config=None):
  """Implementation of pfor.

  Traces `loop_fn` once on an int32 scalar placeholder and vectorizes the ops
  it created with `PFor`. If `loop_fn` accepts a `PFOR_CONFIG_ARG` keyword
  argument, a `PForConfig` is created (when not supplied) and passed along.

  With `parallel_iterations` set, the first `iters % parallel_iterations`
  iterations are vectorized directly. The remaining iterations are vectorized
  in chunks of `parallel_iterations` inside a sequential `for_loop`. The two
  parts are concatenated along axis 0.

  Args:
    loop_fn: Callable taking the iteration index (and optionally the config).
    iters: Number of iterations.
    parallel_iterations: None, or an integer > 1 bounding how many iterations
      are vectorized together.
    pfor_config: Optional `PForConfig`. Must be None unless `loop_fn` accepts
      the config argument.

  Returns:
    A nested structure like `loop_fn`'s output with stacked elements.

  Raises:
    ValueError: If `parallel_iterations` is < 1 or == 1, or if it is set while
      `pfor_config` has reductions.
  """
  wants_config = _loop_fn_has_config(loop_fn)
  graph_ops_before = set(ops.get_default_graph().get_operations())
  with ops.name_scope("loop_body"):
    index = array_ops.placeholder(dtypes.int32, shape=[])
    if not wants_config:
      assert pfor_config is None
      body_outputs = loop_fn(index)
    else:
      if pfor_config is None:
        pfor_config = PForConfig()
        pfor_config._set_iters(iters)  # pylint: disable=protected-access
      body_outputs = loop_fn(index, **{PFOR_CONFIG_ARG: pfor_config})
  # Ops created while tracing `loop_fn`; these are what PFor vectorizes.
  body_ops = set(ops.get_default_graph().get_operations()) - graph_ops_before
  iters = ops.convert_to_tensor(iters)

  if parallel_iterations is not None:
    if parallel_iterations < 1:
      raise ValueError("parallel_iterations must be None or a positive integer")
    elif parallel_iterations == 1:
      raise ValueError("Found parallel_iterations == 1. Use for_loop instead.")
    static_iters = tensor_util.constant_value(iters)
    if static_iters is not None and static_iters < parallel_iterations:
      parallel_iterations = None

  flat_body_outputs = nest.flatten(body_outputs)

  if parallel_iterations is None:
    with ops.name_scope("pfor"):
      vectorizer = PFor(index, iters, body_ops, pfor_config=pfor_config)
      stacked = [vectorizer.convert(out) for out in flat_body_outputs]
      return nest.pack_sequence_as(body_outputs, stacked)

  if pfor_config is not None and pfor_config._has_reductions():  # pylint: disable=protected-access
    raise ValueError("Setting parallel_iterations currently unsupported if"
                     " reductions across iterations are performed.")
  chunk_count = iters // parallel_iterations
  leftover_count = iters % parallel_iterations
  # TODO(agarwal): Avoid calling loop_fn twice. Generate the loop body inside
  # a tf.function and extract the graph from there to vectorize it.
  with ops.name_scope("pfor_untiled"):
    vectorizer = PFor(index, leftover_count, body_ops, pfor_config=pfor_config)
    leftover_outputs = [vectorizer.convert(out) for out in flat_body_outputs]

  with ops.name_scope("pfor_tiled"):
    chunk_dtypes = [
        ops.convert_to_tensor(out).dtype for out in flat_body_outputs
    ]

    def run_chunk(chunk_index):
      first = chunk_index * parallel_iterations + leftover_count

      # NOTE: the keyword is deliberately named `pfor_config`, presumably so
      # `_loop_fn_has_config` recognizes it in the recursive call.
      def tiled_loop_fn(i, pfor_config=None):
        shifted = i + first
        if wants_config:
          return nest.flatten(loop_fn(shifted, pfor_config=pfor_config))
        return nest.flatten(loop_fn(shifted))

      return _pfor_impl(tiled_loop_fn, parallel_iterations,
                        pfor_config=pfor_config)

    chunked = for_loop(run_chunk, chunk_dtypes, chunk_count,
                       parallel_iterations=1)
    chunked = [_flatten_first_two_dims(t) for t in chunked]

  with ops.name_scope("pfor"):
    static_iters = tensor_util.constant_value(iters)
    if static_iters is not None and not static_iters % parallel_iterations:
      merged = chunked
    else:
      merged = control_flow_ops.cond(
          math_ops.equal(leftover_count, 0),
          lambda: chunked,
          lambda: [array_ops.concat([head, tail], axis=0)
                   for head, tail in zip(leftover_outputs, chunked)])
    return nest.pack_sequence_as(body_outputs, nest.flatten(merged))
def pfor(loop_fn, iters, parallel_iterations=None):
  """Equivalent to running `loop_fn` `iters` times and stacking the outputs.

  `pfor` has functionality similar to `for_loop`, i.e. running `loop_fn`
  `iters` times, with input from 0 to `iters - 1`, and stacking corresponding
  output of each iteration. However the implementation does not use a
  tf.while_loop. Instead it adds new operations to the graph that collectively
  compute the same value as what running `loop_fn` in a loop would compute.

  This is an experimental feature and currently has a lot of limitations:
    - There should be no data dependency between the different iterations. For
      example, a future iteration should not depend on a value or side-effect
      of a previous iteration.
    - Stateful kernels may mostly not be supported since these often imply a
      data dependency or ordering of the iterations. We do support a limited
      set of such stateful kernels though (like RandomFoo, Variable operations
      like reads, etc).
    - Conversion works only on a limited set of kernels for which a converter
      has been registered.
    - loop_fn has limited support for control flow operations. tf.cond in
      particular is not supported.
    - `loop_fn` should return nested structure of Tensors or Operations.
      However if an Operation is returned, it should have zero outputs.
    - The shape and dtype of `loop_fn` outputs should not depend on the input
      to loop_fn.

  Args:
    loop_fn: A function that takes an int32 scalar tf.Tensor object
      representing the iteration number, and returns a possibly nested
      structure of Tensor or Operation objects. Note that if setting the
      `parallel_iterations` argument to something other than None, `loop_fn`
      may be called more than once during graph construction. So it may need
      to avoid mutating global state.
    iters: Number of iterations for which to run loop_fn.
    parallel_iterations: A knob to control how many iterations are vectorized
      and dispatched in parallel. The default value of None corresponds to
      vectorizing all the iterations. If `parallel_iterations` is smaller than
      `iters`, then chunks of at most that many iterations are dispatched in
      sequence. This knob can be used to control the total memory usage.

  Returns:
    Returns a nested structure of stacked tensor objects with the same nested
    structure as the output of `loop_fn`.

  Raises:
    ValueError: If parallel_iterations is not None and not an integer > 1.
  """
  graph_ops_before = set(ops.get_default_graph().get_operations())
  with ops.name_scope("loop_body"):
    index = array_ops.placeholder(dtypes.int32, shape=[])
    body_outputs = loop_fn(index)
  # Ops created while tracing `loop_fn`; these are what PFor vectorizes.
  body_ops = set(ops.get_default_graph().get_operations()) - graph_ops_before
  iters = ops.convert_to_tensor(iters)

  if parallel_iterations is not None:
    if parallel_iterations < 1:
      raise ValueError("parallel_iterations must be None or a positive integer")
    elif parallel_iterations == 1:
      raise ValueError("Found parallel_iterations == 1. Use for_loop instead.")
    static_iters = tensor_util.constant_value(iters)
    # Chunking buys nothing if every iteration already fits in one chunk.
    if static_iters is not None and static_iters < parallel_iterations:
      parallel_iterations = None

  flat_body_outputs = nest.flatten(body_outputs)

  if parallel_iterations is None:
    with ops.name_scope("pfor"):
      vectorizer = PFor(index, iters, body_ops)
      stacked = [vectorizer.convert(out) for out in flat_body_outputs]
      return nest.pack_sequence_as(body_outputs, stacked)

  chunk_count = iters // parallel_iterations
  leftover_count = iters % parallel_iterations
  # TODO(agarwal): Avoid calling loop_fn twice. Generate the loop body inside
  # a tf.function and extract the graph from there to vectorize it.
  with ops.name_scope("pfor_untiled"):
    vectorizer = PFor(index, leftover_count, body_ops)
    leftover_outputs = [vectorizer.convert(out) for out in flat_body_outputs]

  with ops.name_scope("pfor_tiled"):
    chunk_dtypes = [
        ops.convert_to_tensor(out).dtype for out in flat_body_outputs
    ]

    def run_chunk(chunk_index):
      first = chunk_index * parallel_iterations + leftover_count

      def shifted_loop_fn(i):
        return nest.flatten(loop_fn(i + first))

      return pfor(shifted_loop_fn, parallel_iterations)

    chunked = for_loop(run_chunk, chunk_dtypes, chunk_count,
                       parallel_iterations=1)
    chunked = [_flatten_first_two_dims(t) for t in chunked]

  with ops.name_scope("pfor"):
    static_iters = tensor_util.constant_value(iters)
    if static_iters is not None and not static_iters % parallel_iterations:
      # Statically known to divide evenly: there are no leftover iterations.
      merged = chunked
    else:
      merged = control_flow_ops.cond(
          math_ops.equal(leftover_count, 0),
          lambda: chunked,
          lambda: [array_ops.concat([head, tail], axis=0)
                   for head, tail in zip(leftover_outputs, chunked)])
    return nest.pack_sequence_as(body_outputs, nest.flatten(merged))
def _pfor_impl(loop_fn, iters, parallel_iterations=None):
  """Implementation of pfor.

  Traces `loop_fn` once on an int32 scalar placeholder, then hands the ops it
  created to `PFor`, which rewrites them so that all `iters` iterations are
  computed at once.

  When `parallel_iterations` is set (and the static value of `iters`, if
  known, is not smaller than it), the work is split in two parts:
    - the first `iters % parallel_iterations` iterations are vectorized
      directly;
    - the rest are vectorized `parallel_iterations` at a time inside a
      sequential `for_loop`.
  The two parts are concatenated along axis 0.

  Args:
    loop_fn: Callable mapping the iteration index to a nested structure of
      Tensors/Operations.
    iters: Number of iterations.
    parallel_iterations: None, or an integer > 1 bounding how many iterations
      are vectorized together.

  Returns:
    A nested structure like `loop_fn`'s output with stacked elements.

  Raises:
    ValueError: If `parallel_iterations` is < 1 or == 1.
  """
  graph_ops_before = set(ops.get_default_graph().get_operations())
  with ops.name_scope("loop_body"):
    index = array_ops.placeholder(dtypes.int32, shape=[])
    body_outputs = loop_fn(index)
  # Ops created while tracing `loop_fn`; these are what PFor vectorizes.
  body_ops = set(ops.get_default_graph().get_operations()) - graph_ops_before
  iters = ops.convert_to_tensor(iters)

  if parallel_iterations is not None:
    if parallel_iterations < 1:
      raise ValueError("parallel_iterations must be None or a positive integer")
    elif parallel_iterations == 1:
      raise ValueError("Found parallel_iterations == 1. Use for_loop instead.")
    static_iters = tensor_util.constant_value(iters)
    # Chunking buys nothing if every iteration already fits in one chunk.
    if static_iters is not None and static_iters < parallel_iterations:
      parallel_iterations = None

  flat_body_outputs = nest.flatten(body_outputs)

  if parallel_iterations is None:
    with ops.name_scope("pfor"):
      vectorizer = PFor(index, iters, body_ops)
      stacked = [vectorizer.convert(out) for out in flat_body_outputs]
      return nest.pack_sequence_as(body_outputs, stacked)

  chunk_count = iters // parallel_iterations
  leftover_count = iters % parallel_iterations
  # TODO(agarwal): Avoid calling loop_fn twice. Generate the loop body inside
  # a tf.function and extract the graph from there to vectorize it.
  with ops.name_scope("pfor_untiled"):
    vectorizer = PFor(index, leftover_count, body_ops)
    leftover_outputs = [vectorizer.convert(out) for out in flat_body_outputs]

  with ops.name_scope("pfor_tiled"):
    chunk_dtypes = [
        ops.convert_to_tensor(out).dtype for out in flat_body_outputs
    ]

    def run_chunk(chunk_index):
      first = chunk_index * parallel_iterations + leftover_count

      def shifted_loop_fn(i):
        return nest.flatten(loop_fn(i + first))

      return pfor(shifted_loop_fn, parallel_iterations)

    chunked = for_loop(run_chunk, chunk_dtypes, chunk_count,
                       parallel_iterations=1)
    chunked = [_flatten_first_two_dims(t) for t in chunked]

  with ops.name_scope("pfor"):
    static_iters = tensor_util.constant_value(iters)
    if static_iters is not None and not static_iters % parallel_iterations:
      # Statically known to divide evenly: there are no leftover iterations.
      merged = chunked
    else:
      merged = control_flow_ops.cond(
          math_ops.equal(leftover_count, 0),
          lambda: chunked,
          lambda: [array_ops.concat([head, tail], axis=0)
                   for head, tail in zip(leftover_outputs, chunked)])
    return nest.pack_sequence_as(body_outputs, nest.flatten(merged))