Ejemplo n.º 1
0
  def add_trace(self, *args, **kwargs):
    """Traces all functions with the same args and kwargs.

    Every function in `self._functions` gets `get_concrete_function` called
    with the given arguments. When `self._expects_training_arg` is set, each
    function is traced twice: first with the training argument set to False,
    then with it set to True (via `utils.set_training_arg`).

    `self.tracing` is True while tracing runs. It is always reset to False
    afterwards, including when a trace raises (the exception still
    propagates).

    Args:
      *args: Positional args passed to the original function.
      **kwargs: Keyword args passed to the original function.
    """
    # Work on copies so the training-arg updates below never touch the
    # caller's kwargs dict. NOTE(review): args is made a list, presumably so
    # utils.set_training_arg can overwrite a positional slot — confirm.
    args = list(args)
    kwargs = kwargs.copy()
    self.tracing = True
    try:
      for fn in self._functions.values():
        # TODO(kathywu): Replace arguments with broader shapes defined in the
        # input signature.
        if self._expects_training_arg:
          # Same order as before: trace with training=False, then True.
          for training in (False, True):
            args, kwargs = utils.set_training_arg(
                training, self._training_arg_index, args, kwargs)
            fn.get_concrete_function(*args, **kwargs)
        else:
          fn.get_concrete_function(*args, **kwargs)
    finally:
      # Without this, a failed trace would leave `self.tracing` stuck at True.
      self.tracing = False
Ejemplo n.º 2
0
 def trace_with_training(value, fn=fn):
     """Traces `fn` once with the training argument set to `value`.

     Uses `self`, `args` and `kwargs` from the enclosing scope. The training
     argument is written into `args`/`kwargs` first. Then
     `fn.get_concrete_function` is called inside `K.learning_phase_scope(value)`.

     Args:
       value: Training value to apply (the call sites are not visible here;
         presumably True/False).
       fn: Function to trace. It is bound as a default argument so each
         closure keeps the `fn` current at definition time.
     """
     # NOTE(review): the return value is discarded, so this presumably relies
     # on set_training_arg updating `args`/`kwargs` in place — confirm.
     utils.set_training_arg(
         value, self._training_arg_index, args, kwargs)
     phase_scope = K.learning_phase_scope(value)
     with phase_scope:
         fn.get_concrete_function(*args, **kwargs)
Ejemplo n.º 3
0
 def trace_with_training(value, fn=fn):
     """Sets the training argument to `value` and queues `fn` for tracing.

     Uses `self`, `args` and `kwargs` from the enclosing scope.

     Args:
       value: Training value applied to the arguments. It is also forwarded to
         `add_trace_to_queue` as its last argument.
       fn: Function to queue. It is bound as a default argument so each
         closure keeps the `fn` current at definition time.
     """
     # NOTE(review): the return value is discarded, so this presumably relies
     # on set_training_arg updating `args`/`kwargs` in place — confirm.
     utils.set_training_arg(value, self._training_arg_index,
                            args, kwargs)
     # Queue snapshots rather than the live objects. Every call of this
     # closure rewrites the same captured args/kwargs, so a later call (e.g.
     # with the opposite `value`) could otherwise change an entry that is
     # still waiting in the queue. Harmless if add_trace_to_queue already
     # copies its inputs.
     add_trace_to_queue(fn, args[:], kwargs.copy(), value)