def add_trace(self, *args, **kwargs):
  """Traces all functions with the same args and kwargs.

  Every function in ``self._functions`` has ``get_concrete_function`` called
  with the given arguments. When ``self._expects_training_arg`` is set, each
  function is traced twice: first with the training argument set to
  ``False``, then with it set to ``True`` (the argument position is given by
  ``self._training_arg_index``).

  ``self.tracing`` is ``True`` while tracing runs. It is always reset to
  ``False`` afterwards, including when a trace raises; the exception still
  propagates to the caller.

  Args:
    *args: Positional args passed to the original function.
    **kwargs: Keyword args passed to the original function.
  """
  # Mutable copies: `utils.set_training_arg` may update them before each trace.
  args = list(args)
  kwargs = kwargs.copy()

  self.tracing = True
  try:
    for fn in self._functions.values():
      # TODO(kathywu): Replace arguments with broader shapes defined in the
      # input signature.
      if self._expects_training_arg:
        for training in (False, True):
          args, kwargs = utils.set_training_arg(
              training, self._training_arg_index, args, kwargs)
          fn.get_concrete_function(*args, **kwargs)
      else:
        fn.get_concrete_function(*args, **kwargs)
  finally:
    # Reset even if a trace raised, so the flag can never stay stuck at True.
    self.tracing = False
def trace_with_training(value, fn=fn):
  """Traces ``fn`` once with the training argument set to ``value``.

  Nested helper: ``self``, ``args`` and ``kwargs`` are read from the
  enclosing scope, which is not visible here. ``fn`` is bound as a default
  argument, so the helper keeps the ``fn`` that was current when it was
  defined rather than whatever the name refers to later.

  Args:
    value: Value written into the training argument via
      ``utils.set_training_arg`` and passed to ``K.learning_phase_scope``.
    fn: Function whose ``get_concrete_function`` is called.
  """
  # Use the returned pair (as add_trace does) instead of relying on
  # set_training_arg mutating `args`/`kwargs` in place. New local names are
  # required: rebinding `args`/`kwargs` would make them local to this helper.
  call_args, call_kwargs = utils.set_training_arg(
      value, self._training_arg_index, args, kwargs)
  # NOTE(review): presumably keeps the Keras learning phase in sync with the
  # training argument during tracing — confirm.
  with K.learning_phase_scope(value):
    fn.get_concrete_function(*call_args, **call_kwargs)
def trace_with_training(value, fn=fn):
  """Queues a trace of ``fn`` with the training argument set to ``value``.

  Nested helper: ``self``, ``args`` and ``kwargs`` are read from the
  enclosing scope, which is not visible here. ``fn`` is bound as a default
  argument, so the helper keeps the ``fn`` that was current when it was
  defined.

  Args:
    value: Value written into the training argument via
      ``utils.set_training_arg``; also forwarded to ``add_trace_to_queue``.
    fn: Function to be traced.
  """
  # Use the returned pair (as add_trace does) instead of relying on in-place
  # mutation. New names are required so `args`/`kwargs` stay closure variables.
  call_args, call_kwargs = utils.set_training_arg(
      value, self._training_arg_index, args, kwargs)
  # The queue receives snapshots: `args`/`kwargs` are shared with other calls
  # of this helper (e.g. the opposite training value), and a later call could
  # otherwise overwrite an entry that is still waiting in the queue.
  add_trace_to_queue(fn, list(call_args), dict(call_kwargs), value)