def experimental_run_v2(self, fn, args=(), kwargs=None):
  """See base class.

  Wraps `fn` with AutoGraph and hands it to `self.extended.tpu_run`.

  Args:
    fn: The function to run.
    args: Positional arguments forwarded, unchanged, to `tpu_run`.
    kwargs: Keyword arguments forwarded, unchanged, to `tpu_run`.

  Returns:
    Whatever `self.extended.tpu_run` returns.
  """
  # Note: the target function is converted to graph even when in Eager mode,
  # so autograph is on by default here.
  autograph_status = ag_ctx.control_status_ctx()
  converted_fn = autograph.tf_convert(fn, autograph_status)
  return self.extended.tpu_run(converted_fn, args, kwargs)
def experimental_run_v2(self, fn, args=(), kwargs=None):
  """See base class.

  The function is passed through `autograph.tf_convert` using the current
  AutoGraph control status, then executed with `self.extended.tpu_run`.

  Args:
    fn: The function to run.
    args: Positional arguments passed through to `tpu_run`.
    kwargs: Keyword arguments passed through to `tpu_run`.

  Returns:
    The result of `self.extended.tpu_run`.
  """
  ctx_status = ag_ctx.control_status_ctx()
  wrapped = autograph.tf_convert(fn, ctx_status)
  return self.extended.tpu_run(wrapped, args, kwargs)
def inference(self, inputs, *args, **kwargs):
  """Runs this layer on `inputs`, symbolically or eagerly.

  NumPy arrays and Python `float`/`int` values found in `inputs` are first
  converted to tensors. A `mask` and/or `training` keyword argument is then
  injected into `kwargs` when the layer expects one and the caller did not
  supply it. If every input is symbolic, the computation is built inside
  `backend.get_graph()`; otherwise `self._inference` is invoked directly
  under the layer's name scope.

  Args:
    inputs: Input value(s); may be a nested structure (it is flattened with
      `nest.flatten`).
    *args: Extra positional arguments forwarded to `self._inference`.
    **kwargs: Extra keyword arguments forwarded to `self._inference`. `mask`
      and `training` may be added here by this method.

  Returns:
    The outputs produced by `self._inference`, or by `self._symbolic_call`
    when building a graph for a layer with `self.dynamic` set.

  Raises:
    ValueError: When building a graph, if a `RaggedTensor` is passed to a
      layer whose `_supports_ragged_inputs` is `False`, or if the layer
      returned `None`.
    TypeError: When building a graph for a non-dynamic layer, if the call
      raised `errors.OperatorNotAllowedInGraphError`.
  """
  call_context = base_layer_utils.call_context()
  input_list = nest.flatten(inputs)

  # We will attempt to build a TF graph if & only if all inputs are symbolic.
  # This is always the case in graph mode. It can also be the case in eager
  # mode when all inputs can be traced back to `keras.Input()` (when building
  # models using the functional API).
  build_graph = tf_utils.are_all_symbolic_tensors(input_list)

  # Accept NumPy and scalar inputs by converting to Tensors.
  if any(isinstance(x, (np.ndarray, float, int)) for x in input_list):
    def _convert_non_tensor(x):
      # Don't call `ops.convert_to_tensor` on all `inputs` because
      # `SparseTensors` can't be converted to `Tensor`.
      if isinstance(x, (np.ndarray, float, int)):
        return ops.convert_to_tensor(x)
      return x
    inputs = nest.map_structure(_convert_non_tensor, inputs)
    # Re-flatten so that `input_list` reflects the converted values.
    input_list = nest.flatten(inputs)

  # Handle `mask` propagation from previous layer to current layer. Masks can
  # be propagated explicitly via the `mask` argument, or implicitly via
  # setting the `_keras_mask` attribute on the inputs to a Layer. Masks passed
  # explicitly take priority.
  mask_arg_passed_by_framework = False
  input_masks = self._collect_input_masks(inputs, args, kwargs)
  if (self._expects_mask_arg and input_masks is not None and
      not self._call_arg_was_passed('mask', args, kwargs)):
    mask_arg_passed_by_framework = True
    kwargs['mask'] = input_masks

  # If `training` argument was not explicitly passed, propagate `training`
  # value from this layer's calling layer.
  training_arg_passed_by_framework = False
  # Priority 1: `training` was explicitly passed.
  if self._call_arg_was_passed('training', args, kwargs):
    training_value = self._get_call_arg_value('training', args, kwargs)
    if not self._expects_training_arg:
      # The value is still recorded in the call context below, but it is not
      # forwarded to a layer that does not expect it.
      kwargs.pop('training')
  else:
    training_value = None
    # Priority 2: `training` was passed to a parent layer.
    if call_context.training is not None:
      training_value = call_context.training
    # Priority 3a: `learning_phase()` has been set.
    elif backend.global_learning_phase_is_set():
      training_value = backend.learning_phase()
    # Priority 3b: Pass the `learning_phase()` if in the Keras FuncGraph.
    elif build_graph:
      with backend.get_graph().as_default():
        if base_layer_utils.is_in_keras_graph():
          training_value = backend.learning_phase()

    if self._expects_training_arg and training_value is not None:
      # Force the training_value to be bool type which matches to the contract
      # for layer/model call args.
      if tensor_util.is_tensor(training_value):
        training_value = math_ops.cast(training_value, dtypes.bool)
      else:
        training_value = bool(training_value)
      kwargs['training'] = training_value
      training_arg_passed_by_framework = True

  # Only create Keras history if at least one tensor originates from a
  # `keras.Input`. Otherwise this Layer may be being used outside the Keras
  # framework.
  if build_graph and base_layer_utils.needs_keras_history(inputs):
    base_layer_utils.create_keras_history(inputs)

  # Clear eager losses on top level model call.
  # We are clearing the losses only on the top level model call and not on
  # every layer/model call because layer/model may be reused.
  if (base_layer_utils.is_in_eager_or_tf_function() and
      not call_context.in_call):
    self._clear_losses()

  with call_context.enter(self, inputs, build_graph, training_value):
    # Check input assumptions set after layer building, e.g. input shape.
    if build_graph:
      # Symbolic execution on symbolic tensors. We will attempt to build
      # the corresponding TF subgraph inside `backend.get_graph()`
      # TODO(reedwm): We should assert input compatibility after the inputs
      # are casted, not before.
      input_spec.assert_input_compatibility(self.input_spec, inputs,
                                            self.name)
      if (any(isinstance(x, ragged_tensor.RaggedTensor) for x in input_list)
          and self._supports_ragged_inputs is False):  # pylint: disable=g-bool-id-comparison
        raise ValueError('Layer %s does not support RaggedTensors as input. '
                         'Inputs received: %s. You can try converting your '
                         'input to an uniform tensor.' % (self.name, inputs))

      graph = backend.get_graph()
      with graph.as_default(), backend.name_scope(self._name_scope()):
        # Build layer if applicable (if the `build` method has been
        # overridden).
        self._maybe_build(inputs)
        cast_inputs = self._maybe_cast_inputs(inputs)

        # Wrapping `call` function in autograph to allow for dynamic control
        # flow and control dependencies in call. We are limiting this to
        # subclassed layers as autograph is strictly needed only for
        # subclassed layers and models.
        # tf_convert will respect the value of autograph setting in the
        # enclosing tf.function, if any.
        if (base_layer_utils.is_subclassed(self) and
            not base_layer_utils.from_saved_model(self)):
          call_fn = autograph.tf_convert(
              self._inference, ag_ctx.control_status_ctx())
        else:
          call_fn = self._inference

        if not self.dynamic:
          try:
            with base_layer_utils.autocast_context_manager(
                self._compute_dtype):
              # Add auto_control_deps in V2 when they are not already added by
              # a `tf.function`.
              if (ops.executing_eagerly_outside_functions() and
                  not base_layer_utils.is_in_eager_or_tf_function()):
                with auto_control_deps.AutomaticControlDependencies() as acd:
                  outputs = call_fn(cast_inputs, *args, **kwargs)
                  # Wrap Tensors in `outputs` in `tf.identity` to avoid
                  # circular dependencies.
                  outputs = base_layer_utils.mark_as_return(outputs, acd)
              else:
                outputs = call_fn(cast_inputs, *args, **kwargs)

          except errors.OperatorNotAllowedInGraphError as e:
            # Surface a more actionable error pointing at `dynamic=True`.
            raise TypeError('You are attempting to use Python control '
                            'flow in a layer that was not declared to be '
                            'dynamic. Pass `dynamic=True` to the class '
                            'constructor.\nEncountered error:\n"""\n' +
                            str(e) + '\n"""')
        else:
          # We will use static shape inference to return symbolic tensors
          # matching the specifications of the layer outputs.
          # Since `self.dynamic` is True, we will never attempt to
          # run the underlying TF graph (which is disconnected).
          # TODO(fchollet): consider py_func as an alternative, which
          # would enable us to run the underlying graph if needed.
          outputs = self._symbolic_call(inputs)

        if outputs is None:
          raise ValueError('A layer\'s `call` method should return a '
                           'Tensor or a list of Tensors, not None '
                           '(layer: ' + self.name + ').')
        if base_layer_utils.have_all_keras_metadata(inputs):
          # Drop the arguments injected above so that only caller-supplied
          # kwargs are handed to `_set_connectivity_metadata_`.
          if training_arg_passed_by_framework:
            kwargs.pop('training')
          if mask_arg_passed_by_framework:
            kwargs.pop('mask')
          inputs, outputs = self._set_connectivity_metadata_(
              inputs, outputs, args, kwargs)
        self._handle_activity_regularization(inputs, outputs)
        self._set_mask_metadata(inputs, outputs, input_masks)
        if hasattr(self, '_set_inputs') and not self.inputs:
          # Subclassed network: explicitly set metadata normally set by
          # a call to self._set_inputs().
          # TODO(b/120997007): This should be done in Eager as well, but
          # causes garbage collection issues because of the placeholders
          # created on the default Keras graph.
          self._set_inputs(inputs, outputs)
    else:
      # Eager execution on data tensors.
      with backend.name_scope(self._name_scope()):
        self._maybe_build(inputs)
        cast_inputs = self._maybe_cast_inputs(inputs)
        with base_layer_utils.autocast_context_manager(
            self._compute_dtype):
          outputs = self._inference(cast_inputs, *args, **kwargs)
        self._handle_activity_regularization(inputs, outputs)
        self._set_mask_metadata(inputs, outputs, input_masks)

  return outputs
def converted_call(f, args, kwargs, caller_fn_scope=None, options=None):
  """Converts a function call inline.

  For internal use only.

  Note: The argument list is optimized for readability of generated code, which
  may look like this:

    ag__.converted_call(f, (arg1, arg2), None, fscope)
    ag__.converted_call(f, (), dict(arg1=val1, **kwargs), fscope)
    ag__.converted_call(f, (arg1, arg2) + varargs, dict(**kwargs), lscope)

  The function runs a sequence of checks, each of which may short-circuit and
  call `f` unconverted (allowlist cache, AutoGraph disabled in context,
  AutoGraph artifacts, builtins, unsupported or allowlisted entities, native
  bindings, dynamically generated code). Only when none of them applies is
  `f` converted and the converted function called.

  Args:
    f: The function to convert.
    args: Tuple, the original positional arguments of f
    kwargs: Optional[Dict], the original keyword arguments of f
    caller_fn_scope: Optional[function_wrappers.FunctionScope], the function
      scope of the converted function in which this call was originally made.
    options: Optional[converter.ConversionOptions], conversion options. If not
      specified, the value of caller_fn_scope.callopts is used. Either options
      or caller_fn_scope must be present.

  Returns:
    Any, the result of executing a possibly-converted `f` with the given
      arguments.

  Raises:
    ValueError: If both `options` and `caller_fn_scope` are None.
  """
  logging.log(1, 'Converted call: %s\n args: %s\n kwargs: %s\n', f, args,
              kwargs)

  if options is None:
    if caller_fn_scope is None:
      raise ValueError('either caller_fn_scope or options must have a value')
    options = caller_fn_scope.callopts

  # NOTE(review): the trailing positional `False` passed to
  # `_call_unconverted` in the next two checks is presumably a flag that
  # disables cache updating -- confirm against its definition.
  if conversion.is_in_allowlist_cache(f, options):
    logging.log(2, 'Allowlisted %s: from cache', f)
    return _call_unconverted(f, args, kwargs, options, False)

  if ag_ctx.control_status_ctx().status == ag_ctx.Status.DISABLED:
    logging.log(2, 'Allowlisted: %s: AutoGraph is disabled in context', f)
    return _call_unconverted(f, args, kwargs, options, False)

  if is_autograph_artifact(f):
    logging.log(2, 'Permanently allowed: %s: AutoGraph artifact', f)
    return _call_unconverted(f, args, kwargs, options)

  # If this is a partial, unwrap it and redo all the checks.
  if isinstance(f, functools.partial):
    new_kwargs = {}
    if f.keywords is not None:
      # Use copy to avoid mutating the underlying keywords.
      new_kwargs = f.keywords.copy()
    if kwargs is not None:
      # Call-site keywords override the ones bound in the partial.
      new_kwargs.update(kwargs)
    new_args = f.args + args
    logging.log(3, 'Forwarding call of partial %s with\n%s\n%s\n', f, new_args,
                new_kwargs)
    return converted_call(
        f.func,
        new_args,
        new_kwargs,
        caller_fn_scope=caller_fn_scope,
        options=options)

  if inspect_utils.isbuiltin(f):
    # These builtins are routed to helpers that also receive
    # `caller_fn_scope`, the scope of the function that made the call.
    if f is eval:
      return py_builtins.eval_in_original_context(f, args, caller_fn_scope)
    if f is super:
      return py_builtins.super_in_original_context(f, args, caller_fn_scope)
    if f is globals:
      return py_builtins.globals_in_original_context(caller_fn_scope)
    if f is locals:
      return py_builtins.locals_in_original_context(caller_fn_scope)
    if kwargs:
      return py_builtins.overload_of(f)(*args, **kwargs)
    else:
      return py_builtins.overload_of(f)(*args)

  if conversion.is_unsupported(f):
    return _call_unconverted(f, args, kwargs, options)

  if not options.user_requested and conversion.is_allowlisted(f):
    return _call_unconverted(f, args, kwargs, options)

  # internal_convert_user_code is for example turned off when issuing a dynamic
  # call conversion from generated code while in nonrecursive mode. In that
  # case we evidently don't want to recurse, but we still have to convert
  # things like builtins.
  if not options.internal_convert_user_code:
    return _call_unconverted(f, args, kwargs, options)

  # Determine the entity to convert and the positional arguments it must be
  # called with.
  try:
    if inspect.ismethod(f) or inspect.isfunction(f):
      target_entity = f
      effective_args = args

      f_self = getattr(f, '__self__', None)
      if f_self is not None:
        if isinstance(f_self, function.TfMethodTarget):
          f_self = f_self.target
        # The bound instance is passed explicitly as the first argument.
        effective_args = (f_self, ) + effective_args

    elif hasattr(f, '__class__') and hasattr(f.__class__, '__call__'):
      # Callable objects. Dunder methods have special lookup rules, see:
      # https://docs.python.org/3/reference/datamodel.html#specialnames
      # TODO(mdan): Recurse into converted_call to simplify other verifications.
      # This should be handled in the same way as partials.
      target_entity = f.__class__.__call__
      effective_args = (f, ) + args

    else:
      # Assigned before raising so that the handler below can log it.
      target_entity = f
      raise NotImplementedError('unknown callable type "%s"' % type(f))

  except Exception as e:  # pylint:disable=broad-except
    logging.log(1, 'Error transforming entity %s', target_entity, exc_info=True)
    if is_autograph_strict_conversion_mode():
      raise
    return _fall_back_unconverted(f, args, kwargs, options, e)

  if not hasattr(target_entity, '__code__'):
    logging.log(2, 'Permanently allowed: %s: native binding', target_entity)
    return _call_unconverted(f, args, kwargs, options)
  elif (hasattr(target_entity.__code__, 'co_filename') and
        target_entity.__code__.co_filename == '<string>'):
    # TODO(mdan): __globals__['txt'] might work in Py3.
    logging.log(2, 'Permanently allowed: %s: dynamic code (exec?)',
                target_entity)
    return _call_unconverted(f, args, kwargs, options)

  try:
    program_ctx = converter.ProgramContext(options=options)
    converted_f = _convert_actual(target_entity, program_ctx)
    if logging.has_verbosity(2):
      _log_callargs(converted_f, effective_args, kwargs)
  except Exception as e:  # pylint:disable=broad-except
    logging.log(1, 'Error transforming entity %s', target_entity, exc_info=True)
    if is_autograph_strict_conversion_mode():
      raise
    return _fall_back_unconverted(f, args, kwargs, options, e)

  with StackTraceMapper(converted_f), tf_stack.CurrentModuleFilter():
    try:
      if kwargs is not None:
        result = converted_f(*effective_args, **kwargs)
      else:
        result = converted_f(*effective_args)
    except Exception as e:
      # Annotate the error, then re-raise it unchanged.
      _attach_error_metadata(e, converted_f)
      raise

  return result
def unspecified_fn():
  # The AutoGraph control status must be UNSPECIFIED while this runs.
  current_status = ag_ctx.control_status_ctx().status
  self.assertEqual(current_status, ag_ctx.Status.UNSPECIFIED)
def converted_call(f, args, kwargs, caller_fn_scope=None, options=None):
  """Compiles a function call inline.

  For internal use only.

  Note: The argument list is optimized for readability of generated code, which
  may look like this:

    ag__.converted_call(f, (arg1, arg2), None, fscope)
    ag__.converted_call(f, (), dict(arg1=val1, **kwargs), fscope)
    ag__.converted_call(f, (arg1, arg2) + varargs, dict(**kwargs), lscope)

  Args:
    f: The function to convert.
    args: Tuple, the original positional arguments of f
    kwargs: Optional[Dict], the original keyword arguments of f
    caller_fn_scope: Optional[function_wrappers.FunctionScope], the function
      scope of the converted function in which this call was originally made.
    options: Optional[converter.ConversionOptions], conversion options. If not
      specified, the value of caller_fn_scope.callopts is used. Either options
      or caller_fn_scope must be present.

  Returns:
    Any, the result of executing a possibly-converted `f` with the given
      arguments.

  Raises:
    ValueError: If both `options` and `caller_fn_scope` are None.
  """
  logging.log(1, 'Converted call: %s\n args: %s\n kwargs: %s\n', f, args,
              kwargs)

  if options is None:
    if caller_fn_scope is None:
      raise ValueError('either caller_fn_scope or options must have a value')
    options = caller_fn_scope.callopts

  if conversion.is_in_whitelist_cache(f, options):
    logging.log(2, 'Whitelisted %s: from cache', f)
    return _call_unconverted(f, args, kwargs, options, False)

  if ag_ctx.control_status_ctx().status == ag_ctx.Status.DISABLED:
    logging.log(2, 'Whitelisted: %s: AutoGraph is disabled in context', f)
    return _call_unconverted(f, args, kwargs, options, False)

  if is_autograph_artifact(f):
    logging.log(2, 'Permanently whitelisted: %s: AutoGraph artifact', f)
    return _call_unconverted(f, args, kwargs, options)

  # If this is a partial, unwrap it and redo all the checks.
  if isinstance(f, functools.partial):
    new_kwargs = {}
    if f.keywords is not None:
      # Copy: `f.keywords` is the partial's own dict, and the `update` below
      # would otherwise leak this call's keyword arguments into every later
      # call of the same partial.
      new_kwargs = f.keywords.copy()
    if kwargs is not None:
      # Call-site keywords override the ones bound in the partial.
      new_kwargs.update(kwargs)
    new_args = f.args + args
    logging.log(3, 'Forwarding call of partial %s with\n%s\n%s\n', f, new_args,
                new_kwargs)
    return converted_call(
        f.func,
        new_args,
        new_kwargs,
        caller_fn_scope=caller_fn_scope,
        options=options)

  if inspect_utils.isbuiltin(f):
    # `eval` and `super` are routed to helpers that also receive
    # `caller_fn_scope`, the scope of the function that made the call.
    if f is eval:
      return py_builtins.eval_in_original_context(f, args, caller_fn_scope)
    if f is super:
      return py_builtins.super_in_original_context(f, args, caller_fn_scope)
    if kwargs:
      return py_builtins.overload_of(f)(*args, **kwargs)
    else:
      return py_builtins.overload_of(f)(*args)

  # TODO(b/122265385): Remove this bypass.
  if (_is_known_loaded_type(f, 'wrapt', 'FunctionWrapper') or
      _is_known_loaded_type(f, 'wrapt', 'BoundFunctionWrapper')):
    logging.warn(
        '{} appears to be decorated by wrapt, which is not yet supported'
        ' by AutoGraph. The function will run as-is.'
        ' You may still apply AutoGraph before the wrapt decorator.'.format(f))
    logging.log(2, 'Permanently whitelisted: %s: wrapt decorated', f)
    return _call_unconverted(f, args, kwargs, options)

  if _is_known_loaded_type(f, 'functools', '_lru_cache_wrapper'):
    logging.log(2, 'Permanently whitelisted: %s: lru_cache', f)
    return _call_unconverted(f, args, kwargs, options)

  # Constructors are permanently whitelisted.
  # TODO(mdan): Toggle as experimental feature instead.
  # TODO(b/124016764): Remove this limitation.
  if inspect_utils.isconstructor(f):
    logging.log(2, 'Permanently whitelisted: %s: constructor', f)
    return _call_unconverted(f, args, kwargs, options)

  # Other built-in modules are permanently whitelisted.
  # TODO(mdan): Figure out how to do this consistently for all stdlib modules.
  if any(
      f in m.__dict__.values() for m in (collections, pdb, copy, inspect, re)):
    logging.log(2, 'Permanently whitelisted: %s: part of builtin module', f)
    return _call_unconverted(f, args, kwargs, options)

  # Custom ops and kernels are also permanently whitelisted.
  # See tensorflow.framework.load_library.
  if (hasattr(f, '__module__') and
      hasattr(f.__module__, '_IS_TENSORFLOW_PLUGIN')):
    logging.log(2, 'Permanently whitelisted: %s: TensorFlow plugin', f)
    return _call_unconverted(f, args, kwargs, options)

  if not options.user_requested and conversion.is_whitelisted(f):
    return _call_unconverted(f, args, kwargs, options)

  # internal_convert_user_code is for example turned off when issuing a dynamic
  # call conversion from generated code while in nonrecursive mode. In that
  # case we evidently don't want to recurse, but we still have to convert
  # things like builtins.
  if not options.internal_convert_user_code:
    return _call_unconverted(f, args, kwargs, options)

  # TODO(mdan): Move this entire block inside to_graph.
  try:  # Begin of transformation error guards

    if tf_inspect.isfunction(f) or tf_inspect.ismethod(f):
      # Regular functions
      target_entity = f
      f_self = inspect_utils.getmethodself(f)

      # TODO(b/119246461): This may be more elegantly handled using __get__?
      if f_self is not None:
        # The bound instance is passed explicitly as the first argument.
        effective_args = (f_self,) + args
      else:
        effective_args = args

    elif hasattr(f, '__class__') and hasattr(f.__class__, '__call__'):
      # Callable objects. Dunder methods have special lookup rules, see:
      # https://docs.python.org/3/reference/datamodel.html#specialnames
      target_entity = f.__class__.__call__
      effective_args = (f,) + args

    else:
      # Assigned before raising so that the handler below can log it.
      target_entity = f
      raise NotImplementedError('unknown callable type "%s"' % type(f))

    if not tf_inspect.isclass(target_entity):
      if not hasattr(target_entity, '__code__'):
        logging.log(2, 'Permanently whitelisted: %s: native binding',
                    target_entity)
        return _call_unconverted(f, args, kwargs, options)
      elif (hasattr(target_entity.__code__, 'co_filename') and
            target_entity.__code__.co_filename == '<string>'):
        # TODO(mdan): __globals__['txt'] might work in Py3.
        logging.log(2, 'Permanently whitelisted: %s: dynamic code (exec?)',
                    target_entity)
        return _call_unconverted(f, args, kwargs, options)

    program_ctx = converter.ProgramContext(
        options=options, autograph_module=tf_inspect.getmodule(converted_call))
    converted_f = conversion.convert(target_entity, program_ctx)

    # Computing the call arguments is only needed for logging, so it is
    # skipped unless verbose logging is on.
    if logging.has_verbosity(2):
      logging.log(2, 'Defaults of %s : %s', converted_f,
                  converted_f.__defaults__)
      if not six.PY2:
        logging.log(2, 'KW defaults of %s : %s', converted_f,
                    converted_f.__kwdefaults__)

      if kwargs is not None:
        callargs = tf_inspect.getcallargs(converted_f, *effective_args,
                                          **kwargs)
      else:
        callargs = tf_inspect.getcallargs(converted_f, *effective_args)

      formatted_callargs = '\n'.join(
          ' {}: {}'.format(k, v) for k, v in callargs.items())
      logging.log(2, 'Calling %s with\n%s\n', converted_f, formatted_callargs)

  except Exception as e:  # pylint:disable=broad-except
    logging.log(1, 'Error transforming entity %s', target_entity, exc_info=True)
    if is_autograph_strict_conversion_mode():
      raise

    if isinstance(e, errors.UnsupportedLanguageElementError):
      # Repeating the check made upon function entry because the state might
      # have updated in the meantime.
      if not conversion.is_in_whitelist_cache(f, options):
        logging.warn(
            'AutoGraph could not transform %s and will run it as-is.\n'
            'Cause: %s', target_entity, e)
    else:
      logging.warn(
          'AutoGraph could not transform %s and will run it as-is.\n'
          'Please report this to the TensorFlow team. When filing the bug, set'
          ' the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and'
          ' attach the full output.\n'
          'Cause: %s', target_entity, e)

    # Best effort: fall back to running the original, unconverted callable.
    return _call_unconverted(f, args, kwargs, options)

  with StackTraceMapper(converted_f), tf_stack.CurrentModuleFilter():
    try:
      if kwargs is not None:
        result = converted_f(*effective_args, **kwargs)
      else:
        result = converted_f(*effective_args)
    except Exception as e:
      # Annotate the error, then re-raise it unchanged.
      _attach_metadata(e, converted_f)
      raise

  return result
def f():
  # Expect an UNSPECIFIED AutoGraph status and no generated-code frame.
  observed_status = ag_ctx.control_status_ctx().status
  self.assertEqual(observed_status, ag_ctx.Status.UNSPECIFIED)
  in_generated_code = converter_testing.is_inside_generated_code()
  self.assertFalse(in_generated_code)
def converted_fn():
  # The status must be ENABLED both before and after calling the unconverted
  # function, i.e. the nested call must not leak its own status.
  status_before = ag_ctx.control_status_ctx().status
  self.assertEqual(status_before, ag_ctx.Status.ENABLED)
  unconverted_fn()
  status_after = ag_ctx.control_status_ctx().status
  self.assertEqual(status_after, ag_ctx.Status.ENABLED)
def g():
  # AutoGraph must be ENABLED while this function executes.
  active_status = ag_ctx.control_status_ctx().status
  self.assertEqual(active_status, ag_ctx.Status.ENABLED)
  return 0
def f():
  # The status is checked first; the branch below then depends on a tensor
  # value (the sum of [1, 2]).
  status_now = ag_ctx.control_status_ctx().status
  self.assertEqual(status_now, ag_ctx.Status.UNSPECIFIED)
  if tf.reduce_sum([1, 2]) > 0:
    return -1
  return 1
def _pfor_impl(loop_fn,
               iters,
               fallback_to_while_loop,
               parallel_iterations=None,
               pfor_config=None,
               warn=False):
  """Implementation of pfor.

  Traces `loop_fn` once on a placeholder loop variable, records the ops this
  creates, and converts the outputs with a `PFor` converter. When
  `parallel_iterations` is set (and not larger than a statically known
  `iters`), the work is split into a remainder part of
  `iters % parallel_iterations` iterations and `iters // parallel_iterations`
  tiles that are run through `for_loop`, each tile recursing into
  `_pfor_impl`.

  Args:
    loop_fn: Callable invoked with the loop variable. If
      `_loop_fn_has_config(loop_fn)` is true it is also passed the config under
      the `PFOR_CONFIG_ARG` keyword; otherwise it is wrapped with
      `autograph.tf_convert` before being called.
    iters: Number of iterations; converted to a tensor with
      `ops.convert_to_tensor`.
    fallback_to_while_loop: Forwarded to `PFor` and to recursive calls.
    parallel_iterations: None, or an integer greater than 1. Reset to None
      when the static value of `iters` is smaller than it.
    pfor_config: Optional config object. Must be None when `loop_fn` does not
      take a config; created here if `loop_fn` takes one and none was given.
    warn: Forwarded to `PFor` on the non-tiled path only.

  Returns:
    The converted outputs, packed back into the structure of `loop_fn`'s
    outputs via `nest.map_structure_up_to` and `_composite_from_tensors`.

  Raises:
    ValueError: If `parallel_iterations` is less than 1 or equal to 1, or if
      tiling is requested while `pfor_config` has reductions.
  """
  # Must be called while building a graph.
  assert not context.executing_eagerly()
  loop_fn_has_config = _loop_fn_has_config(loop_fn)
  # Snapshot of the graph's ops, used below to find the ops `loop_fn` adds.
  existing_ops = set(ops.get_default_graph().get_operations())
  iters_value = tensor_util.constant_value(iters)
  # Run the loop body
  with ops.name_scope("loop_body"):
    loop_var = array_ops.placeholder_with_default(0, shape=[])
    if loop_fn_has_config:
      if pfor_config is None:
        pfor_config = PForConfig()
        pfor_config._set_iters(iters)  # pylint: disable=protected-access
      loop_fn_outputs = loop_fn(loop_var, **{PFOR_CONFIG_ARG: pfor_config})
    else:
      assert pfor_config is None
      f = autograph.tf_convert(loop_fn, autograph_ctx.control_status_ctx())
      loop_fn_outputs = f(loop_var)
    loop_fn_output_tensors = nest.map_structure(_composite_to_tensors,
                                                loop_fn_outputs)

  # Convert outputs to Tensor if needed.
  tmp_loop_fn_outputs = []
  for loop_fn_output in nest.flatten(loop_fn_output_tensors):
    if (loop_fn_output is not None and not isinstance(
        loop_fn_output,
        (ops.Operation, ops.Tensor, sparse_tensor.SparseTensor))):
      if isinstance(loop_fn_output, indexed_slices.IndexedSlices):
        # Same conversion as the generic case, with a warning first.
        logging.warn("Converting %s to a dense representation may make it slow."
                     " Alternatively, output the indices and values of the"
                     " IndexedSlices separately, and handle the vectorized"
                     " outputs directly." % loop_fn_output)
        loop_fn_output = ops.convert_to_tensor(loop_fn_output)
      else:
        loop_fn_output = ops.convert_to_tensor(loop_fn_output)
    tmp_loop_fn_outputs.append(loop_fn_output)
  loop_fn_output_tensors = nest.pack_sequence_as(loop_fn_output_tensors,
                                                 tmp_loop_fn_outputs)

  # Ops created since the snapshot above, i.e. the traced loop body.
  new_ops = set(ops.get_default_graph().get_operations()) - existing_ops
  iters = ops.convert_to_tensor(iters)
  if parallel_iterations is not None:
    if parallel_iterations < 1:
      raise ValueError(
          "Argument `parallel_iterations` must be None or a positive integer. "
          f"Received: {parallel_iterations}.")
    if parallel_iterations == 1:
      raise ValueError(
          "Found `parallel_iterations == 1`. Use `for_loop` instead.")
    # Fewer iterations than the tile size: take the non-tiled path.
    if iters_value is not None and iters_value < parallel_iterations:
      parallel_iterations = None
  if parallel_iterations is None:
    with ops.name_scope("pfor"):
      converter = PFor(loop_var, iters, new_ops,
                       fallback_to_while_loop=fallback_to_while_loop,
                       pfor_config=pfor_config, warn=warn)
      flattened_output_tensors = []
      for loop_fn_output in nest.flatten(loop_fn_output_tensors):
        output = converter.convert(loop_fn_output)
        flattened_output_tensors.append(output)
  else:
    if pfor_config is not None and pfor_config._has_reductions():  # pylint: disable=protected-access
      raise ValueError(
          "Setting `parallel_iterations` currently unsupported if "
          "reductions across iterations are performed.")
    num_tiled_iterations = iters // parallel_iterations
    num_remaining_iterations = iters % parallel_iterations
    # TODO(agarwal): Avoid calling loop_fn twice. Generate the loop body inside
    # a tf.function and extract the graph from there to vectorize it.
    with ops.name_scope("pfor_untiled"):
      # Converts the first `num_remaining_iterations` iterations.
      # NOTE(review): `warn` is not forwarded here or to the recursive call
      # below, unlike the non-tiled path -- confirm this is intentional.
      converter = PFor(loop_var, num_remaining_iterations, new_ops,
                       fallback_to_while_loop=fallback_to_while_loop,
                       pfor_config=pfor_config)
      remaining_output_tensors = []
      flattened_output_tensors = nest.flatten(loop_fn_output_tensors)
      for loop_fn_output in flattened_output_tensors:
        output = converter.convert(loop_fn_output)
        remaining_output_tensors.append(output)

    with ops.name_scope("pfor_tiled"):
      loop_fn_dtypes = [ops.convert_to_tensor(x).dtype
                        for x in flattened_output_tensors]

      def tiled_loop_body(j):
        # Tile `j` starts after the remainder iterations and `j` full tiles.
        offset = j * parallel_iterations + num_remaining_iterations

        def tiled_loop_fn(i, pfor_config=None):
          if loop_fn_has_config:
            loop_fn_outputs = loop_fn(i + offset, pfor_config=pfor_config)
          else:
            loop_fn_outputs = loop_fn(i + offset)
          return nest.flatten(
              # Stacking across iterations requires explicit Tensors.
              nest.map_structure(_composite_to_tensors, loop_fn_outputs))

        return _pfor_impl(
            tiled_loop_fn,
            parallel_iterations,
            fallback_to_while_loop=fallback_to_while_loop,
            pfor_config=pfor_config)

      tiled_output_tensors = for_loop(
          tiled_loop_body, loop_fn_dtypes,
          num_tiled_iterations, parallel_iterations=1)
      tiled_output_tensors = [
          _flatten_first_two_dims(y) for y in tiled_output_tensors]

    with ops.name_scope("pfor"):
      # A runtime `cond` is only needed when a non-zero remainder cannot be
      # ruled out statically.
      if iters_value is None or iters_value % parallel_iterations:
        output_tensors = control_flow_ops.cond(
            math_ops.equal(num_remaining_iterations, 0),
            lambda: tiled_output_tensors,
            lambda: [array_ops.concat([x, y], axis=0)  # pylint: disable=g-long-lambda
                     for x, y in zip(remaining_output_tensors,
                                     tiled_output_tensors)])
      else:
        output_tensors = tiled_output_tensors
      flattened_output_tensors = nest.flatten(output_tensors)

      for output, original_output in zip(flattened_output_tensors,
                                         nest.flatten(loop_fn_output_tensors)):
        # Restore any shape information lost from tiling.
        # TODO(b/174254748): this may not be correct for stacked `variant`s.
        output.set_shape(
            tensor_shape.TensorShape([iters_value]).concatenate(
                original_output.shape))

  return nest.map_structure_up_to(
      loop_fn_outputs,
      functools.partial(_composite_from_tensors, batch_size=iters_value),
      nest.pack_sequence_as(loop_fn_output_tensors,
                            flattened_output_tensors),
      loop_fn_outputs)
def inner_fn_callee():
  # AutoGraph must be DISABLED while this callee runs.
  callee_status = ag_ctx.control_status_ctx().status
  self.assertEqual(callee_status, ag_ctx.Status.DISABLED)
def call(self, y_true, y_pred):
  """Invokes the wrapped function `self.fn` on `y_true` and `y_pred`.

  When both arguments are tensors they are first passed through
  `tf_losses_util.squeeze_or_expand_dimensions`. `self.fn` is wrapped with
  AutoGraph before being called.

  Args:
    y_true: First argument passed to `self.fn`.
    y_pred: Second argument passed to `self.fn`.

  Returns:
    The result of `self.fn(y_true, y_pred, **self._fn_kwargs)`.
  """
  both_are_tensors = (
      tensor_util.is_tensor(y_pred) and tensor_util.is_tensor(y_true))
  if both_are_tensors:
    adjusted = tf_losses_util.squeeze_or_expand_dimensions(y_pred, y_true)
    y_pred, y_true = adjusted
  wrapped_fn = autograph.tf_convert(self.fn, ag_ctx.control_status_ctx())
  return wrapped_fn(y_true, y_pred, **self._fn_kwargs)