Example #1
def fn_args(fn):
  """Get argument names for function-like object.

  Args:
    fn: Function, or function-like object (e.g., result of `functools.partial`).

  Returns:
    `tuple` of string argument names.

  Raises:
    ValueError: if partial function has positionally bound arguments
  """
  _, fn = tf_decorator.unwrap(fn)

  # Handle callables.
  if hasattr(fn, '__call__') and tf_inspect.ismethod(fn.__call__):
    return tuple(tf_inspect.getargspec(fn.__call__).args)

  # Handle functools.partial and similar objects.
  if hasattr(fn, 'func') and hasattr(fn, 'keywords') and hasattr(fn, 'args'):
    # Handle nested partial.
    original_args = fn_args(fn.func)
    if not original_args:
      return tuple()

    return tuple([
        arg for arg in original_args[len(fn.args):]
        if arg not in set((fn.keywords or {}).keys())
    ])

  # Handle function.
  return tuple(tf_inspect.getargspec(fn).args)
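
A usage sketch (hypothetical; assumes the module defining fn_args above is importable, and the function name is illustrative):

import functools

def train(features, labels, learning_rate=0.1):
  pass

fn_args(train)                                        # ('features', 'labels', 'learning_rate')
fn_args(functools.partial(train, learning_rate=0.5))  # ('features', 'labels')
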
Example #2
  def testGetArgSpecOnPartialInvalidArgspec(self):
    """Tests getargspec on partial function that doesn't have valid argspec."""

    def func(m, n, l, k=4):
      return 2 * m + l + n * k

    partial_func = functools.partial(func, n=7)

    exception_message = (r"Some arguments \['l'\] do not have default value, "
                         "but they are positioned after those with default "
                         "values. This can not be expressed with ArgSpec.")
    with self.assertRaisesRegexp(ValueError, exception_message):
      tf_inspect.getargspec(partial_func)
Example #3
  def testGetArgSpecOnPartialArgumentWithConvertibleToFalse(self):
    """Tests getargspec on partial function with args that convert to False."""

    def func(m, n):
      return 2 * m + n

    partial_func = functools.partial(func, m=0)

    exception_message = (r"Some arguments \['n'\] do not have default value, "
                         "but they are positioned after those with default "
                         "values. This can not be expressed with ArgSpec.")
    with self.assertRaisesRegexp(ValueError, exception_message):
      tf_inspect.getargspec(partial_func)
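
Both tests above exercise the same failure mode. A minimal pure-Python sketch of why such a partial cannot be expressed as a legacy ArgSpec:

import functools
import inspect

def func(m, n, l, k=4):
  return 2 * m + l + n * k

partial_func = functools.partial(func, n=7)
# Binding `n` by keyword makes every parameter after it keyword-only:
print(inspect.signature(partial_func))  # (m, *, n=7, l, k=4)
# `n` now has a default while `l`, which follows it, does not: an ordering
# the flat ArgSpec tuple (args, varargs, keywords, defaults) cannot encode,
# hence the ValueError checked above.
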
Example #4
def _get_arg_infos(func, elementwise_args):
  """Returns `_ArgInfo`s for each `func` arg specified by `elementwise_args`.

  Args:
    func: The function whose arguments should be described.
    elementwise_args: The names of the arguments to get info for.

  Returns:
    A dictionary that maps both names and positions of arguments to
    `_ArgInfo` tuples.
  """
  arg_infos = {}

  # Inspect the func's argspec to find the position of each arg.
  arg_spec = tf_inspect.getargspec(func)
  for argname in elementwise_args:
    assert isinstance(argname, str)
    is_list = argname.startswith('[') and argname.endswith(']')
    if is_list:
      argname = argname[1:-1]
    assert argname in arg_spec.args, (func, argname, arg_spec.args)
    arg_info = _ArgInfo(argname, arg_spec.args.index(argname), is_list)
    arg_infos[arg_info.name] = arg_info
    arg_infos[arg_info.position] = arg_info
  return arg_infos
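
A hypothetical illustration of the bracket convention (the op and argument names are made up; _ArgInfo is assumed to be a (name, position, is_list) namedtuple, as its use above suggests):

def gather(params, indices, axis=0):
  pass

# '[indices]' marks an argument that holds a *list* of values:
# _get_arg_infos(gather, ['params', '[indices]']) maps both 'params' and 0 to
# _ArgInfo(name='params', position=0, is_list=False), and both 'indices' and 1
# to _ArgInfo(name='indices', position=1, is_list=True).
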
Example #5
def assert_stmt(expression1, expression2):
  """Functional form of an assert statement.

  This follows the semantics of the Python assert statement, however the
  concrete implementations may deviate from it. See the respective
  implementation for details.

  In general, the assert statement should not be used for control flow.
  Furthermore, it is encouraged that the assertion expressions should not have
  side effects.

  Args:
    expression1: Any
    expression2: Callable[[], Any], returns the expression to include in the
        error message when expression1 evaluates to False. When expression1 is
        True, the result of expression2 will not be evaluated, however,
        expression2 itself may be evaluated in some implementations.

  Returns:
    Any, implementation-dependent.

  Raises:
    ValueError: if any arguments are illegal.
  """
  if not callable(expression2):
    raise ValueError('{} must be a callable'.format(expression2))
  args, _, keywords, _ = tf_inspect.getargspec(expression2)
  if args or keywords:
    raise ValueError('{} may not have any arguments'.format(expression2))

  if tensor_util.is_tensor(expression1):
    return _tf_assert_stmt(expression1, expression2)
  else:
    return _py_assert_stmt(expression1, expression2)
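
Usage sketch (assumes the definitions above are importable). AutoGraph rewrites a plain `assert cond, msg` into roughly this call; the message is wrapped in a zero-argument lambda so it is only constructed on failure:

x = 5
assert_stmt(x > 0, lambda: 'x must be positive')
# `x > 0` is a plain Python bool here, so this dispatches to _py_assert_stmt;
# a tf.Tensor condition would dispatch to _tf_assert_stmt instead.
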
Example #6
 def test_reduction_ops(self):
   ops_to_test = [
       (keras.backend.max, np.max),
       (keras.backend.min, np.min),
       (keras.backend.sum, np.sum),
       (keras.backend.prod, np.prod),
       (keras.backend.var, np.var),
       (keras.backend.std, np.std),
       (keras.backend.mean, np.mean),
       (keras.backend.argmin, np.argmin),
       (keras.backend.argmax, np.argmax),
   ]
   for keras_op, np_op in ops_to_test:
     with self.test_session():
       compare_single_input_op_to_numpy(keras_op, np_op, input_shape=(4, 7, 5),
                                        keras_kwargs={'axis': 1},
                                        np_kwargs={'axis': 1})
       compare_single_input_op_to_numpy(keras_op, np_op, input_shape=(4, 7, 5),
                                        keras_kwargs={'axis': -1},
                                        np_kwargs={'axis': -1})
       if 'keepdims' in tf_inspect.getargspec(keras_op).args:
         compare_single_input_op_to_numpy(keras_op, np_op,
                                          input_shape=(4, 7, 5),
                                          keras_kwargs={'axis': 1,
                                                        'keepdims': True},
                                          np_kwargs={'axis': 1,
                                                     'keepdims': True})
Example #7
def _make_prediction_gan_model(input_data, input_data_domain_label,
                               generator_fn, generator_scope):
  """Make a `StarGANModel` from just the generator."""
  # If `generator_fn` has an argument `mode`, pass mode to it.
  if 'mode' in inspect.getargspec(generator_fn).args:
    generator_fn = functools.partial(
        generator_fn, mode=model_fn_lib.ModeKeys.PREDICT)
  with variable_scope.variable_scope(generator_scope) as gen_scope:
    # pylint:disable=protected-access
    input_data = tfgan_train._convert_tensor_or_l_or_d(input_data)
    input_data_domain_label = tfgan_train._convert_tensor_or_l_or_d(
        input_data_domain_label)
    # pylint:enable=protected-access
    generated_data = generator_fn(input_data, input_data_domain_label)
  generator_variables = variable_lib.get_trainable_variables(gen_scope)

  return tfgan_tuples.StarGANModel(
      input_data=input_data,
      input_data_domain_label=None,
      generated_data=generated_data,
      generated_data_domain_target=input_data_domain_label,
      reconstructed_data=None,
      discriminator_input_data_source_predication=None,
      discriminator_generated_data_source_predication=None,
      discriminator_input_data_domain_predication=None,
      discriminator_generated_data_domain_predication=None,
      generator_variables=generator_variables,
      generator_scope=generator_scope,
      generator_fn=generator_fn,
      discriminator_variables=None,
      discriminator_scope=None,
      discriminator_fn=None)
Example #8
 def end(self, session):
   self._last_step = None
   for m in self._monitors:
     if "session" in tf_inspect.getargspec(m.end).args:
       m.end(session=session)
     else:
       m.end()
Example #9
  def export(self,
             estimator,
             export_path,
             checkpoint_path=None,
             eval_result=None):
    """Exports the given Estimator to a specific format.

    Args:
      estimator: the Estimator to export.
      export_path: A string containing a directory where to write the export.
      checkpoint_path: The checkpoint path to export.  If None (the default),
        the strategy may locate a checkpoint (e.g. the most recent) by itself.
      eval_result: The output of Estimator.evaluate on this checkpoint.  This
        should be set only if checkpoint_path is provided (otherwise it is
        unclear which checkpoint this eval refers to).

    Returns:
      The string path to the exported directory.

    Raises:
      ValueError: if the export_fn does not have the required signature
    """
    # don't break existing export_fns that don't accept checkpoint_path and
    # eval_result
    export_fn_args = tf_inspect.getargspec(self.export_fn).args
    kwargs = {}
    if 'checkpoint_path' in export_fn_args:
      kwargs['checkpoint_path'] = checkpoint_path
    if 'eval_result' in export_fn_args:
      if 'checkpoint_path' not in export_fn_args:
        raise ValueError('An export_fn accepting eval_result must also accept '
                         'checkpoint_path.')
      kwargs['eval_result'] = eval_result

    return self.export_fn(estimator, export_path, **kwargs)
Example #10
  def __call__(self, func):
    # Various sanity checks on the callable func.
    if not callable(func):
      raise ValueError("func %s must be callable" % func)

    # Func should not use kwargs and defaults.
    argspec = tf_inspect.getargspec(func)
    if argspec.keywords or argspec.defaults:
      raise ValueError("Functions with argument defaults or keyword "
                       "arguments are not supported.")

    # Computes how many arguments 'func' has.
    min_args = len(argspec.args)
    max_args = min_args
    if argspec.varargs:
      max_args = 1000000
    argnames = argspec.args
    if tf_inspect.ismethod(func):
      # 1st argument is the "class" type.
      min_args -= 1
      argnames = argnames[1:]

    if self._input_types:
      # If Defun is given a list of types for the inputs, the number
      # of input types should be compatible with 'func'.
      num = len(self._input_types)
      if num < min_args or num > max_args:
        raise ValueError(
            "The number of specified input types (%d) is not compatible with "
            "the function, which accepts between %d and %d arguments." %
            (num, min_args, max_args))
      return _DefinedFunction(
          func,
          argnames,
          self._input_types,
          self._func_name,
          self._grad_func,
          self._python_grad_func,
          out_names=self._out_names,
          **self._extra_kwargs)

    # 'func' expects no arguments and input types is an empty list.
    if min_args == 0 and max_args == 0:
      return _DefinedFunction(
          func, [], [],
          self._func_name,
          self._grad_func,
          self._python_grad_func,
          out_names=self._out_names,
          **self._extra_kwargs)

    # Input types are unknown. It's an overloaded function and hence
    # its definition needs to be deferred until it's called.
    return _OverloadedFunction(
        func,
        argnames,
        self._func_name,
        self._grad_func,
        self._python_grad_func,
        out_names=self._out_names,
        **self._extra_kwargs)
Example #11
  def call(self, inputs, training=None, mask=None):
    kwargs = {}
    func_args = tf_inspect.getargspec(self.layer.call).args
    if 'training' in func_args:
      kwargs['training'] = training
    if 'mask' in func_args:
      kwargs['mask'] = mask

    y = self.forward_layer.call(inputs, **kwargs)
    y_rev = self.backward_layer.call(inputs, **kwargs)
    if self.return_sequences:
      y_rev = K.reverse(y_rev, 1)
    if self.merge_mode == 'concat':
      output = K.concatenate([y, y_rev])
    elif self.merge_mode == 'sum':
      output = y + y_rev
    elif self.merge_mode == 'ave':
      output = (y + y_rev) / 2
    elif self.merge_mode == 'mul':
      output = y * y_rev
    elif self.merge_mode is None:
      output = [y, y_rev]

    # Properly set learning phase
    if 0 < self.layer.dropout + self.layer.recurrent_dropout:
      if self.merge_mode is None:
        for out in output:
          out._uses_learning_phase = True
      else:
        output._uses_learning_phase = True
    return output
Example #12
def get_args(symbol):
  if hasattr(inspect, "signature"):
    signature = inspect.signature(symbol)
    # Ignore *args and **kwargs for now.
    return [param.name for param in signature.parameters.values()
            if param.kind == param.POSITIONAL_OR_KEYWORD]
  return tf_inspect.getargspec(symbol)[0]
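
A usage sketch (the function name is illustrative). On Python 3 the inspect.signature branch is taken, so *args, **kwargs, and keyword-only parameters are all excluded; the tf_inspect fallback only matters on Python 2:

def conv2d(inputs, filters, *args, name=None, **kwargs):
  pass

get_args(conv2d)  # ['inputs', 'filters']  (`name` is keyword-only, so skipped)
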
Example #13
def _recompute_grad(fn, args, use_data_dep=_USE_DEFAULT, tupleize_grads=False):
  """See recompute_grad."""
  has_is_recompute_kwarg = "is_recomputing" in tf_inspect.getargspec(fn).args
  for arg in args:
    if not isinstance(arg, framework_ops.Tensor):
      raise ValueError("All inputs to function must be Tensors")
  use_data_dep_ = use_data_dep
  if use_data_dep_ == _USE_DEFAULT:
    use_data_dep_ = _is_on_tpu()

  # Use custom_gradient and return a grad_fn that recomputes on the backwards
  # pass.
  @custom_gradient.custom_gradient
  def fn_with_recompute(*args):
    """Wrapper for fn."""
    # Capture the variable and arg scopes so we can re-enter them when
    # recomputing.
    vs = variable_scope.get_variable_scope()
    arg_scope = contrib_framework_ops.current_arg_scope()
    # Track all variables touched in the function.
    with backprop.GradientTape() as tape:
      fn_kwargs = {}
      if has_is_recompute_kwarg:
        fn_kwargs["is_recomputing"] = False
      outputs = fn(*args, **fn_kwargs)
    original_vars = set(tape.watched_variables())

    def _grad_fn(output_grads, variables=None):
      # Validate that custom_gradient passes the right variables into grad_fn.
      if original_vars:
        assert variables, ("Fn created variables but the variables were not "
                           "passed to the gradient fn.")
        if set(variables) != original_vars:
          raise ValueError(_WRONG_VARS_ERR)

      return _recomputing_grad_fn(
          compute_fn=fn,
          original_args=args,
          original_vars=original_vars,
          output_grads=output_grads,
          grad_fn_variables=variables,
          use_data_dep=use_data_dep_,
          tupleize_grads=tupleize_grads,
          arg_scope=arg_scope,
          var_scope=vs,
          has_is_recompute_kwarg=has_is_recompute_kwarg)

    # custom_gradient inspects the signature of the function to determine
    # whether the user expects variables passed in the grad_fn. If the function
    # created variables, the grad_fn should accept the "variables" kwarg.
    if original_vars:
      def grad_fn(*output_grads, **kwargs):
        return _grad_fn(output_grads, kwargs["variables"])
    else:
      def grad_fn(*output_grads):
        return _grad_fn(output_grads)

    return outputs, grad_fn

  return fn_with_recompute(*args)
Example #14
def kwarg_only(f):
  """A wrapper that throws away all non-kwarg arguments."""
  def wrapper(**kwargs):
    return f(**kwargs)

  return tf_decorator.make_decorator(
      f, wrapper, decorator_argspec=tf_inspect.getargspec(f))
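
Usage sketch (assumes TensorFlow's tf_decorator is available): the decorated function keeps its original argspec for introspection via decorator_argspec, but rejects positional calls at runtime:

@kwarg_only
def scale(x, factor):
  return x * factor

scale(x=2, factor=3)  # 6
scale(2, 3)           # TypeError: wrapper() takes 0 positional arguments
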
Example #15
  def check_accepts(f):
    """Check the types."""
    spec = tf_inspect.getargspec(f)

    num_function_arguments = len(spec.args)
    if len(types) != num_function_arguments:
      raise Error(
          "Function %r has %d arguments but %d types were provided in the "
          "annotation." % (f, num_function_arguments, len(types)))

    if spec.defaults:
      num_defaults = len(spec.defaults)
      for (name, a, t) in zip(spec.args[-num_defaults:],
                              spec.defaults,
                              types[-num_defaults:]):
        allowed_type = _replace_forward_references(t, f.__globals__)
        if not isinstance(a, allowed_type):
          raise Error("default argument value %r of type %r is not an instance "
                      "of the allowed type %s for the %s argument to %r"
                      % (a, type(a), _type_repr(allowed_type), name, f))

    @functools.wraps(f)
    def new_f(*args, **kwds):
      """A helper function."""
      for (a, t) in zip(args, types):
        allowed_type = _replace_forward_references(t, f.__globals__)
        if not isinstance(a, allowed_type):
          raise Error("%r of type %r is not an instance of the allowed type %s "
                      "for %r" % (a, type(a), _type_repr(allowed_type), f))
      return f(*args, **kwds)

    return new_f
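
check_accepts is the inner function of a type-annotation decorator; the enclosing factory (assumed here to be called accepts and to take *types) supplies types, Error, and the _replace_forward_references/_type_repr helpers. A hypothetical usage:

@accepts(int, str)
def repeat(count, text):
  return text * count

repeat(3, 'ab')    # 'ababab'
repeat('3', 'ab')  # raises Error: '3' of type str is not an instance of int ...
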
Example #16
  def check_params(self, params):
    """Checks for user typos in "params".

    Arguments:
        params: dictionary; the parameters to be checked

    Raises:
        ValueError: if any member of `params` is not a valid argument.
    """
    legal_params_fns = [
        Sequential.fit, Sequential.predict, Sequential.predict_classes,
        Sequential.evaluate
    ]
    if self.build_fn is None:
      legal_params_fns.append(self.__call__)
    elif (not isinstance(self.build_fn, types.FunctionType) and
          not isinstance(self.build_fn, types.MethodType)):
      legal_params_fns.append(self.build_fn.__call__)
    else:
      legal_params_fns.append(self.build_fn)

    legal_params = []
    for fn in legal_params_fns:
      legal_params += tf_inspect.getargspec(fn)[0]
    legal_params = set(legal_params)

    for params_name in params:
      if params_name not in legal_params:
        if params_name != 'nb_epoch':
          raise ValueError('{} is not a legal parameter'.format(params_name))
Example #17
def _loop_fn_has_config(loop_fn):
  """Test if `loop_fn` has a `pfor_config` argument."""
  if tf_inspect.isfunction(loop_fn):
    argspec = tf_inspect.getargspec(loop_fn)
    return PFOR_CONFIG_ARG in argspec.args
  elif isinstance(loop_fn, functools.partial):
    fn = loop_fn.func
    argspec = tf_inspect.getargspec(fn)
    return (PFOR_CONFIG_ARG in argspec.args and
            PFOR_CONFIG_ARG not in loop_fn.keywords)
  else:
    loop_class = tf_decorator.unwrap(loop_fn)[1]
    if not hasattr(loop_class, "__call__"):
      raise ValueError("loop_fn object did not have a __call__ method")
    argspec = tf_inspect.getargspec(loop_class.__call__)
    return PFOR_CONFIG_ARG in argspec.args
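
Behavior sketch, assuming PFOR_CONFIG_ARG == 'pfor_config' as in TensorFlow's pfor module:

def loop_body(i):
  pass

def loop_body_with_config(i, pfor_config):
  pass

# _loop_fn_has_config(loop_body)              -> False
# _loop_fn_has_config(loop_body_with_config)  -> True
# A functools.partial that already binds pfor_config would return False.
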
Example #18
def _check_method_supports_args(method, kwargs):
  """Checks that the given method supports the given args."""
  supported_args = tuple(tf_inspect.getargspec(method).args)
  for kwarg in kwargs:
    if kwarg not in supported_args:
      raise ValueError(
          'Argument `{}` is not supported in method {}.'.format(kwarg, method))
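
Usage sketch with an illustrative method (assumes the function above is importable):

def fit(self, x, y, epochs=1):
  pass

_check_method_supports_args(fit, {'x': 1, 'epochs': 5})  # passes silently
_check_method_supports_args(fit, {'batch_size': 32})     # raises ValueError
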
Example #19
def _args(fn):
  """Get argument names for function-like object.

  Args:
    fn: Function, or function-like object (e.g., result of `functools.partial`).

  Returns:
    `tuple` of string argument names.
  """
  if hasattr(fn, 'func') and hasattr(fn, 'keywords'):
    # Handle functools.partial and similar objects.
    return tuple([
        arg for arg in tf_inspect.getargspec(fn.func).args
        if arg not in set(fn.keywords.keys())
    ])
  # Handle function.
  return tuple(tf_inspect.getargspec(fn).args)
Example #20
def _get_arg_spec(func):
  """Extracts signature information from a function or functools.partial object.

  For functions, uses `tf_inspect.getargspec`. For `functools.partial` objects,
  corrects the signature of the underlying function to take into account the
  removed arguments.

  Args:
    func: A function whose signature to extract.

  Returns:
    An `ArgSpec` namedtuple `(args, varargs, keywords, defaults)`, as returned
    by `tf_inspect.getargspec`.
  """
  # getargspec does not work for functools.partial objects directly.
  if isinstance(func, functools.partial):
    argspec = tf_inspect.getargspec(func.func)
    # Remove the args from the original function that have been used up.
    first_default_arg = (
        len(argspec.args or []) - len(argspec.defaults or []))
    partial_args = len(func.args)
    argspec_args = []

    if argspec.args:
      argspec_args = list(argspec.args[partial_args:])

    argspec_defaults = list(argspec.defaults or ())
    if argspec.defaults and partial_args > first_default_arg:
      argspec_defaults = list(argspec.defaults[partial_args-first_default_arg:])

    first_default_arg = max(0, first_default_arg - partial_args)
    for kwarg in (func.keywords or []):
      if kwarg in (argspec.args or []):
        i = argspec_args.index(kwarg)
        argspec_args.pop(i)
        if i >= first_default_arg:
          argspec_defaults.pop(i-first_default_arg)
        else:
          first_default_arg -= 1
    return tf_inspect.ArgSpec(args=argspec_args,
                              varargs=argspec.varargs,
                              keywords=argspec.keywords,
                              defaults=tuple(argspec_defaults))
  else:  # Regular function or method, getargspec will work fine.
    return tf_inspect.getargspec(func)
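
A worked example of the correction (the function is illustrative; output verified against the logic above):

import functools

def conv(inputs, filters, strides=1, padding='same'):
  pass

_get_arg_spec(functools.partial(conv, filters=32))
# -> ArgSpec(args=['inputs', 'strides', 'padding'], varargs=None,
#            keywords=None, defaults=(1, 'same'))
# `filters` is removed from args; because it sits before the first default,
# the defaults tuple is left untouched.
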
Example #21
  def __call__(self, inputs, *args, **kwargs):
    """Wraps `call`, applying pre- and post-processing steps.

    Arguments:
      inputs: input tensor(s).
      *args: additional positional arguments to be passed to `self.call`.
      **kwargs: additional keyword arguments to be passed to `self.call`.
        **Note**: kwarg `scope` is reserved for use by the layer.
    Returns:
      Output tensor(s).
    """
    self._set_scope(kwargs.pop('scope', None))

    # Ensure the Layer, if being reused, is working with inputs from
    # the same graph as where it was created.
    try:
      ops._get_graph_from_inputs(nest.flatten(inputs), graph=self.graph)  # pylint: disable=protected-access
    except ValueError as e:
      raise ValueError('Input graph and Layer graph are not the same: %s' % e)

    with vs.variable_scope(self._scope,
                           reuse=self.built or self._reuse) as scope:
      with ops.name_scope(scope.original_name_scope):
        if not self.built:
          # Check input assumptions set before layer building, e.g. input rank.
          self._assert_input_compatibility(inputs)
          input_list = [
              ops.convert_to_tensor(x, name='input')
              for x in nest.flatten(inputs)]
          input_shapes = [x.get_shape() for x in input_list]
          if len(input_shapes) == 1:
            self.build(input_shapes[0])
          else:
            self.build(input_shapes)
        if 'scope' in tf_inspect.getargspec(self.call).args:
          kwargs['scope'] = scope
        # Check input assumptions set after layer building, e.g. input shape.
        self._assert_input_compatibility(inputs)
        outputs = self.call(inputs, *args, **kwargs)

        # Apply activity regularization.
        # Note that it should be applied every time the layer creates a new
        # output, since it is output-specific.
        if hasattr(self, 'activity_regularizer') and self.activity_regularizer:
          output_list = _to_list(outputs)
          for output in output_list:
            with ops.name_scope('ActivityRegularizer'):
              activity_regularization = self.activity_regularizer(output)
            self.add_loss(activity_regularization)
            _add_elements_to_collection(
                activity_regularization, ops.GraphKeys.REGULARIZATION_LOSSES)

    # Update global default collections.
    _add_elements_to_collection(self.updates, ops.GraphKeys.UPDATE_OPS)
    self.built = True
    return outputs
Example #22
  def testGetArgSpecOnDecoratorThatChangesArgspec(self):
    argspec = tf_inspect.ArgSpec(
        args=['a', 'b', 'c'],
        varargs=None,
        keywords=None,
        defaults=(1, 'hello'))

    decorator = tf_decorator.TFDecorator('', test_undecorated_function, '',
                                         argspec)
    self.assertEqual(argspec, tf_inspect.getargspec(decorator))
Example #23
  def testGetArgSpecOnPartialNoArgumentsLeft(self):
    """Tests getargspec on partial function that prunes all arguments."""

    def func(m, n):
      return 2 * m + n

    partial_func = functools.partial(func, 7, 10)
    argspec = tf_inspect.ArgSpec(
        args=[], varargs=None, keywords=None, defaults=None)

    self.assertEqual(argspec, tf_inspect.getargspec(partial_func))
Example #24
  def testGetArgSpecReturnsOutermostDecoratorThatChangesArgspec(self):
    outer_argspec = tf_inspect.ArgSpec(
        args=['a'], varargs=None, keywords=None, defaults=None)
    inner_argspec = tf_inspect.ArgSpec(
        args=['b'], varargs=None, keywords=None, defaults=None)

    inner_decorator = tf_decorator.TFDecorator('', test_undecorated_function,
                                               '', inner_argspec)
    outer_decorator = tf_decorator.TFDecorator('', inner_decorator, '',
                                               outer_argspec)
    self.assertEqual(outer_argspec, tf_inspect.getargspec(outer_decorator))
Example #25
  def testGetArgSpecIgnoresDecoratorsThatDontProvideArgspec(self):
    argspec = tf_inspect.ArgSpec(
        args=['a', 'b', 'c'],
        varargs=None,
        keywords=None,
        defaults=(1, 'hello'))

    inner_decorator = tf_decorator.TFDecorator('', test_undecorated_function,
                                               '', argspec)
    outer_decorator = tf_decorator.TFDecorator('', inner_decorator)
    self.assertEqual(argspec, tf_inspect.getargspec(outer_decorator))
Example #26
  def testGetArgSpecOnPartialPositionalArgumentOnly(self):
    """Tests getargspec on partial function with only positional arguments."""

    def func(m, n):
      return 2 * m + n

    partial_func = functools.partial(func, 7)
    argspec = tf_inspect.ArgSpec(
        args=['n'], varargs=None, keywords=None, defaults=None)

    self.assertEqual(argspec, tf_inspect.getargspec(partial_func))
Example #27
  def testGetArgSpecOnPartialWithVarkwargs(self):
    """Tests getargspec on partial function with variable keyword arguments."""

    def func(m, n, **kwarg):
      return m * n + len(kwarg)

    partial_func = functools.partial(func, 7)
    argspec = tf_inspect.ArgSpec(
        args=['n'], varargs=None, keywords='kwarg', defaults=None)

    self.assertEqual(argspec, tf_inspect.getargspec(partial_func))
Example #28
  def testGetArgSpecOnPartialKeywordArgumentWithDefaultValue(self):
    """Tests getargspec on partial function that prunes argument by keyword."""

    def func(m=1, n=2):
      return 2 * m + n

    partial_func = functools.partial(func, n=7)
    argspec = tf_inspect.ArgSpec(
        args=['m', 'n'], varargs=None, keywords=None, defaults=(1, 7))

    self.assertEqual(argspec, tf_inspect.getargspec(partial_func))
Example #29
  def testGetArgSpecOnPartialKeywordArgument(self):
    """Tests getargspec on partial function that prunes some arguments."""

    def func(m, n):
      return 2 * m + n

    partial_func = functools.partial(func, n=7)
    argspec = tf_inspect.ArgSpec(
        args=['m', 'n'], varargs=None, keywords=None, defaults=(7,))

    self.assertEqual(argspec, tf_inspect.getargspec(partial_func))
Example #30
  def testGetArgSpecOnPartialWithDecoratorThatChangesArgspec(self):
    """Tests getargspec on partial function with decorated argspec."""

    argspec = tf_inspect.ArgSpec(
        args=['a', 'b', 'c'],
        varargs=None,
        keywords=None,
        defaults=(1, 'hello'))
    decorator = tf_decorator.TFDecorator('', test_undecorated_function, '',
                                         argspec)
    partial_argspec = tf_inspect.ArgSpec(
        args=['a', 'b', 'c'],
        varargs=None,
        keywords=None,
        defaults=(2, 1, 'hello'))
    partial_with_decorator = functools.partial(decorator, a=2)

    self.assertEqual(argspec, tf_inspect.getargspec(decorator))
    self.assertEqual(partial_argspec,
                     tf_inspect.getargspec(partial_with_decorator))
Example #31
    def arg_test_visitor(unused_path, unused_parent, children):
      for child in children:
        _, attr = tf_decorator.unwrap(child[1])
        names_v1 = tf_export.get_v1_names(attr)

        for name in names_v1:
          name = "tf.%s" % name
          if name not in all_keyword_renames:
            continue
          arg_names_v1 = tf_inspect.getargspec(attr)[0]
          keyword_renames = all_keyword_renames[name]
          self.assertEqual(type(keyword_renames), dict)

          # Assert that v1 function has valid v1 argument names.
          for from_name, _ in keyword_renames.items():
            self.assertIn(
                from_name, arg_names_v1,
                "%s not found in %s arguments: %s" %
                (from_name, name, str(arg_names_v1)))
Example #32
def check_function_argument_count(func, input_arity, infeed_queue):
    """Validate the number of input arguments to a TPU function.

    Args:
      func: the Python function that will be called to generate the body of an
        XLA computation graph.
      input_arity: the number of explicit arguments supplied by the caller.
      infeed_queue: if not None, the infeed queue that will supply additional
        arguments to the function.

    Returns:
      None if the function can be called with the supplied number of
        arguments, or an error string if it cannot.
    """
    def format_error(complaint, quantity):
        return "%s %d argument%s" % (complaint, quantity,
                                     "" if quantity == 1 else "s")

    number_of_arguments_needed = input_arity
    if infeed_queue is not None:
        number_of_arguments_needed += infeed_queue.number_of_tuple_elements
    arg_spec = tf_inspect.getargspec(func)
    number_of_args = len(arg_spec.args)
    if arg_spec.defaults is None:
        number_of_defaults = 0
    else:
        number_of_defaults = len(arg_spec.defaults)
    min_required_arguments = number_of_args - number_of_defaults
    if number_of_arguments_needed < min_required_arguments:
        # The required number of arguments is not enough to call the function.
        if number_of_defaults == 0 and arg_spec.varargs is None:
            return format_error("exactly", number_of_args)
        else:
            return format_error("at least", min_required_arguments)
    if arg_spec.varargs is None and number_of_arguments_needed > number_of_args:
        # The required number of arguments is too many to call the function.
        if number_of_defaults == 0:
            return format_error("exactly", number_of_args)
        else:
            return format_error("at most", number_of_args)
    # Since there are varargs, func can accept any number of arguments
    # greater than the minimum.
    return None
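
Worked example (illustrative body function, no infeed queue; results follow the logic above):

def body(images, labels, step=0):
  pass

check_function_argument_count(body, 2, None)  # None: callable as requested
check_function_argument_count(body, 1, None)  # 'at least 2 arguments'
check_function_argument_count(body, 4, None)  # 'at most 3 arguments'
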
Example #33
def _add_name_scope_wrapper(func, api_signature):
    """Wraps `func` to expect a "name" arg, and use it to call `ops.name_scope`.

  If `func` already expects a "name" arg, or if `api_signature` does not
  expect a "name" arg, then returns `func` as-is.

  Args:
    func: The function to wrap.  Signature must match `api_signature` (except
      the "name" parameter may be missing.
    api_signature: The signature of the original API (used to find the index for
      the "name" parameter).

  Returns:
    The wrapped function (or the original function if no wrapping is needed).
  """
    if "name" not in api_signature.parameters:
        return func  # no wrapping needed (API has no name parameter).

    func_signature = tf_inspect.signature(func)
    func_argspec = tf_inspect.getargspec(func)
    if "name" in func_signature.parameters or func_argspec.keywords is not None:
        return func  # No wrapping needed (already has name parameter).

    name_index = list(api_signature.parameters).index("name")

    def wrapped_func(*args, **kwargs):
        if name_index < len(args):
            name = args[name_index]
            args = args[:name_index] + args[name_index + 1:]
        else:
            name = kwargs.pop("name", None)
        if name is None:
            return func(*args, **kwargs)
        else:
            with ops.name_scope(name):
                return func(*args, **kwargs)

    wrapped_func = tf_decorator.make_decorator(func, wrapped_func)
    wrapped_func.__signature__ = func_signature.replace(
        parameters=(list(func_signature.parameters.values()) +
                    [api_signature.parameters["name"]]))
    del wrapped_func._tf_decorator
    return wrapped_func
Example #34
    def filter_sk_params(self, fn, override=None):
        """Filters `sk_params` and returns those in `fn`'s arguments.

        Arguments:
            fn: arbitrary function
            override: dictionary, values to override `sk_params`

        Returns:
            res: dictionary containing variables in both `sk_params` and
                `fn`'s arguments.
        """
        override = override or {}
        res = {}
        fn_args = tf_inspect.getargspec(fn)[0]
        for name, value in self.sk_params.items():
            if name in fn_args:
                res.update({name: value})
        res.update(override)
        return res
Example #35
    def decorator(dispatch_target):

        # Get the name & index for each iterable parameter.
        if iterable_parameters is None:
            iterable_params = None
        else:
            arg_names = tf_inspect.getargspec(dispatch_target).args
            iterable_params = [(name, arg_names.index(name))
                               for name in iterable_parameters]

        @traceback_utils.filter_traceback
        def op_dispatch_handler(*args, **kwargs):
            """Call `dispatch_target`, peforming dispatch when appropriate."""

            # Type-based dispatch system (dispatch v2):
            if api_dispatcher is not None:
                if iterable_params is not None:
                    args, kwargs = replace_iterable_params(
                        args, kwargs, iterable_params)
                result = api_dispatcher.Dispatch(args, kwargs)
                if result is not NotImplemented:
                    return result

            # Fallback dispatch system (dispatch v1):
            try:
                return dispatch_target(*args, **kwargs)
            except (TypeError, ValueError):
                # Note: convert_to_eager_tensor currently raises a ValueError, not a
                # TypeError, when given unexpected types.  So we need to catch both.
                result = dispatch(op_dispatch_handler, args, kwargs)
                if result is not OpDispatcher.NOT_SUPPORTED:
                    return result
                else:
                    raise

        add_fallback_dispatch_list(op_dispatch_handler)
        op_dispatch_handler = tf_decorator.make_decorator(
            dispatch_target, op_dispatch_handler)
        add_type_based_api_dispatcher(op_dispatch_handler)
        api_dispatcher = getattr(op_dispatch_handler, TYPE_BASED_DISPATCH_ATTR,
                                 None)
        return op_dispatch_handler
Example #36
def _make_gan_model(generator_fn, discriminator_fn, real_data,
                    generator_inputs, generator_scope, add_summaries, mode):
    """Make a `GANModel`, and optionally pass in `mode`."""
    # If `generator_fn` has an argument `mode`, pass mode to it.
    if 'mode' in inspect.getargspec(generator_fn).args:
        generator_fn = functools.partial(generator_fn, mode=mode)
    gan_model = tfgan_train.gan_model(generator_fn,
                                      discriminator_fn,
                                      real_data,
                                      generator_inputs,
                                      generator_scope=generator_scope,
                                      check_shapes=False)
    if add_summaries:
        if not isinstance(add_summaries, (tuple, list)):
            add_summaries = [add_summaries]
        with ops.name_scope(None):
            for summary_type in add_summaries:
                _summary_type_map[summary_type](gan_model)

    return gan_model
Example #37
def add_type_based_api_dispatcher(target):
    """Adds a PythonAPIDispatcher to the given TensorFlow API function."""
    if hasattr(target, TYPE_BASED_DISPATCH_ATTR):
        raise ValueError(f"{target} already has a type-based API dispatcher.")

    _, unwrapped = tf_decorator.unwrap(target)
    target_argspec = tf_inspect.getargspec(unwrapped)
    if target_argspec.varargs or target_argspec.keywords:
        # @TODO(b/194903203) Add v2 dispatch support for APIs that take varargs
        # and keywords.  Examples of APIs that take varargs and kwargs: meshgrid,
        # einsum, map_values, map_flat_values.
        return target

    setattr(
        target, TYPE_BASED_DISPATCH_ATTR,
        _api_dispatcher.PythonAPIDispatcher(unwrapped.__name__,
                                            target_argspec.args,
                                            target_argspec.defaults))
    _TYPE_BASED_DISPATCH_SIGNATURES[target] = collections.defaultdict(list)
    return target
Example #38
 def wrapper(*args, **kwargs):
     """Wrapper that calls the compiled version of the wrapped function."""
     partial_types = ()
     arg_names = tf_inspect.getargspec(f)[0]
     for name, arg in zip(arg_names, args):
         arg_class = arg.__class__
         if tf_inspect.isclass(arg_class):
             # If arg_value_hints specifies any name, use that instead.
             # TODO(mdan): Shouldn't this just be in the func's globals?
             if name not in arg_value_hints:
                 arg_value_hints[name] = (arg_class.__name__, arg_class)
             # Annotated methods need to specify that their owner type is partial,
             # otherwise other members they call will not be converted.
             if name == 'self':
                 partial_types = (arg_class, )
     wrapped = to_graph(f,
                        recursive=recursive,
                        arg_value_hints=arg_value_hints,
                        partial_types=partial_types)
     return wrapped(*args, **kwargs)
Example #39
def ragged_op_list(tf_version=2):
    """Returns a string listing operations that have dispathers registered."""
    lines = []
    api_signatures = dispatch.type_based_dispatch_signatures_for(
        ragged_tensor.RaggedTensor)
    for api, signatures in api_signatures.items():
        arg_names = tf_inspect.getargspec(api).args
        ragged_args = set()
        for signature in signatures:
            for arg in signature:
                ragged_args.add(
                    arg if isinstance(arg, int) else arg_names.index(arg))
        if _op_is_in_tf_version(api, tf_version):
            lines.append(_ragged_op_signature(api, ragged_args))

    lines.append(
        _ragged_op_signature(logging_ops.print_v2, [], ragged_varargs=True))
    return (
        '\n\n### Additional ops that support `RaggedTensor`\n\n'
        'Arguments that accept `RaggedTensor`s are marked in **bold**.\n\n' +
        '\n'.join(sorted(lines)) + '\n')
Example #40
def _eager_mode_decorator(f, *args, **kwargs):
    """Implement custom gradient decorator for eager mode."""
    with backprop.GradientTape() as tape:
        result, grad_fn = f(*args, **kwargs)
    all_inputs = list(args) + list(kwargs.values())
    # The variables that grad_fn needs to return gradients for are the set of
    # variables used that are *not* part of the inputs.
    variables = [
        v for v in set(tape.watched_variables()) if v not in all_inputs
    ]
    grad_argspec = tf_inspect.getargspec(grad_fn)
    if (variables and
            not ("variables" in grad_argspec.args or grad_argspec.keywords)):
        raise TypeError("If using @custom_gradient with a function that "
                        "uses variables, then grad_fn must accept a keyword "
                        "argument 'variables'.")
    flat_result = nest.flatten(result)
    # TODO(apassos) consider removing the identity below.
    flat_result = [gen_array_ops.identity(x) for x in flat_result]

    def actual_grad_fn(*result_grads):
        """Custom grad fn wrapper."""
        if variables:
            input_grads, variable_grads = grad_fn(*result_grads,
                                                  variables=variables)
            if len(variable_grads) != len(variables):
                raise ValueError("Must return gradient for each variable from "
                                 "@custom_gradient grad_fn.")
        else:
            input_grads = grad_fn(*result_grads)
            variable_grads = []
        return nest.flatten(input_grads) + variable_grads

    input_tensors = [
        ops.convert_to_tensor(x) for x in list(args) + list(variables)
    ]
    tape_lib.record_operation(f.__name__, flat_result, input_tensors,
                              actual_grad_fn)
    flat_result = list(flat_result)
    return nest.pack_sequence_as(result, flat_result)
Example #41
 def wrapper(*args, **kwargs):
     """Wrapper that calls the compiled version of the wrapped function."""
     partial_types = ()
     arg_values = {}
     arg_names = tf_inspect.getargspec(f)[0]
     for name, arg in zip(arg_names, args):
         arg_values[name] = arg
         arg_class = arg.__class__
         # If arg_value_hints specifies any name, use that instead.
         if name not in arg_types:
             arg_types[name] = (arg_class.__name__, arg_class)
         if name == 'self' and tf_inspect.isclass(arg_class):
             # Annotated methods need to specify that their owner type is partial,
             # otherwise other members they call will not be converted.
             partial_types = (arg_class, )
     wrapped = to_graph(f,
                        recursive=recursive,
                        verbose=verbose,
                        arg_values=arg_values,
                        arg_types=arg_types,
                        partial_types=partial_types)
     return wrapped(*args, **kwargs)
Example #42
 def _FlatOutputProcessor(source_id, record):
     """Returns a flattened list of 'processor(inputs)'."""
     processor_spec = tf_inspect.getargspec(processor)
     tf.logging.debug('GenericInput.processor.argspec=%s', processor_spec)
     processor_args = set(processor_spec.args) - set(['self'])
     if len(processor_args) == 1:
         output, bucketing_key = processor(record)
     elif processor_args == set(['source_id', 'record']):
         output, bucketing_key = processor(source_id=source_id,
                                           record=record)
     else:
         raise ValueError(
             'GenericInput: processor should take either a single arg '
             'or two args named as "source_id" and "record". '
             'Actual: %s' % processor_args)
     if isinstance(output, list):
         assert output
         assert all(isinstance(x, tf.Tensor)
                    for x in output), '{}'.format(output)
     else:
         assert isinstance(output, py_utils.NestedMap), '{}'.format(output)
         assert output
         assert all(isinstance(x, tf.Tensor)
                    for x in output.Flatten()), '{}'.format(
                        output.DebugString())
     bucketing_key = tf.cast(bucketing_key, tf.int32)
     tf.logging.debug('Processor outputs=%s bucketing_key=%s', output,
                      bucketing_key)
     output_tmpl.out_values = output
     flat_output_tmpl = output_tmpl.Flatten()
     tf.logging.debug('Processor flat outputs=%s', flat_output_tmpl)
     tf.logging.debug('extra_inputs=%s extra_args=%s extra_vars=%s',
                      py_utils.GetExtraInputs(), py_utils.GetExtraArgs(),
                      py_utils.GetExtraVars())
     assert not py_utils.GetExtraArgs(), (
         'fns {} is not pure: extra_args={}'.format(
             processor, py_utils.GetExtraArgs()))
     return flat_output_tmpl + [bucketing_key]
Example #43
def fn_args(fn):
  """Get argument names for function-like object.

  Args:
    fn: Function, or function-like object (e.g., result of `functools.partial`).

  Returns:
    `tuple` of string argument names.

  Raises:
    ValueError: if partial function has positionally bound arguments
  """
  _, fn = tf_decorator.unwrap(fn)
  if isinstance(fn, functools.partial):
    args = fn_args(fn.func)
    args = [a for a in args[len(fn.args):] if a not in (fn.keywords or [])]
  else:
    if _is_callable_object(fn):
      fn = fn.__call__
    args = tf_inspect.getargspec(fn).args
    if _is_bounded_method(fn):
      args.remove('self')
  return tuple(args)
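
Unlike Example #1, this version also strips `self` from bound methods. A usage sketch (assumes the _is_callable_object and _is_bounded_method helpers behave as their names suggest):

class Model(object):

  def call(self, inputs, training=False):
    pass

fn_args(Model().call)  # ('inputs', 'training'): 'self' removed for bound methods
fn_args(Model.call)    # ('self', 'inputs', 'training') on Python 3
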
Example #44
def _make_prediction_gan_model(generator_inputs, generator_fn, generator_scope):
  """Make a `GANModel` from just the generator."""
  # If `generator_fn` has an argument `mode`, pass mode to it.
  if 'mode' in inspect.getargspec(generator_fn).args:
    generator_fn = functools.partial(generator_fn,
                                     mode=model_fn_lib.ModeKeys.PREDICT)
  with variable_scope.variable_scope(generator_scope) as gen_scope:
    generator_inputs = tfgan_train._convert_tensor_or_l_or_d(generator_inputs)  # pylint:disable=protected-access
    generated_data = generator_fn(generator_inputs)
  generator_variables = variable_lib.get_trainable_variables(gen_scope)

  return tfgan_tuples.GANModel(
      generator_inputs,
      generated_data,
      generator_variables,
      gen_scope,
      generator_fn,
      real_data=None,
      discriminator_real_outputs=None,
      discriminator_gen_outputs=None,
      discriminator_variables=None,
      discriminator_scope=None,
      discriminator_fn=None)
Example #45
    def export(self,
               estimator,
               export_path,
               checkpoint_path=None,
               eval_result=None):
        """Exports the given Estimator to a specific format.

    Args:
      estimator: the Estimator to export.
      export_path: A string containing a directory where to write the export.
      checkpoint_path: The checkpoint path to export.  If None (the default),
        the strategy may locate a checkpoint (e.g. the most recent) by itself.
      eval_result: The output of Estimator.evaluate on this checkpoint.  This
        should be set only if checkpoint_path is provided (otherwise it is
        unclear which checkpoint this eval refers to).

    Returns:
      The string path to the exported directory.

    Raises:
      ValueError: if the export_fn does not have the required signature
    """
        # don't break existing export_fns that don't accept checkpoint_path and
        # eval_result
        export_fn_args = tf_inspect.getargspec(self.export_fn).args
        kwargs = {}
        if 'checkpoint_path' in export_fn_args:
            kwargs['checkpoint_path'] = checkpoint_path
        if 'eval_result' in export_fn_args:
            if 'checkpoint_path' not in export_fn_args:
                raise ValueError(
                    'An export_fn accepting eval_result must also accept '
                    'checkpoint_path.')
            kwargs['eval_result'] = eval_result
        if 'strip_default_attrs' in export_fn_args:
            kwargs['strip_default_attrs'] = self.strip_default_attrs
        return self.export_fn(estimator, export_path, **kwargs)
Example #46
def _SanitizedArgSpec(obj):
    """Get an ArgSpec string that is free of addresses.

  We have callables as function arg defaults. This results in addresses in
  getargspec output. This function returns a sanitized string list of base
  classes.

  Args:
    obj: A python routine for us the create the sanitized arspec of.

  Returns:
    string, a string representation of the argspec.
  """
    output_string = ''
    unsanitized_arg_spec = tf_inspect.getargspec(obj)

    for clean_attr in ('args', 'varargs', 'keywords'):
        output_string += '%s=%s, ' % (
            clean_attr, getattr(unsanitized_arg_spec, clean_attr))

    if unsanitized_arg_spec.defaults:
        sanitized_defaults = []
        for val in unsanitized_arg_spec.defaults:
            str_val = str(val)
            # Sanitize argspecs that have hex code in them.
            if ' at 0x' in str_val:
                sanitized_defaults.append('%s instance>' %
                                          str_val.split(' at ')[0])
            else:
                sanitized_defaults.append(str_val)

        output_string += 'defaults=%s, ' % sanitized_defaults

    else:
        output_string += 'defaults=None'

    return output_string
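
The address-stripping step in isolation, as a runnable sketch:

def default_fn():
  pass

str_val = str(default_fn)  # e.g. '<function default_fn at 0x7f3a2c0d5e18>'
if ' at 0x' in str_val:
  str_val = '%s instance>' % str_val.split(' at ')[0]
print(str_val)             # '<function default_fn instance>'
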
Example #47
 def test_reduction_ops(self):
     ops_to_test = [
         (keras.backend.max, np.max),
         (keras.backend.min, np.min),
         (keras.backend.sum, np.sum),
         (keras.backend.prod, np.prod),
         (keras.backend.var, np.var),
         (keras.backend.std, np.std),
         (keras.backend.mean, np.mean),
         (keras.backend.argmin, np.argmin),
         (keras.backend.argmax, np.argmax),
     ]
     for keras_op, np_op in ops_to_test:
         with self.test_session():
             compare_single_input_op_to_numpy(keras_op,
                                              np_op,
                                              input_shape=(4, 7, 5),
                                              keras_kwargs={'axis': 1},
                                              np_kwargs={'axis': 1})
             compare_single_input_op_to_numpy(keras_op,
                                              np_op,
                                              input_shape=(4, 7, 5),
                                              keras_kwargs={'axis': -1},
                                              np_kwargs={'axis': -1})
             if 'keepdims' in tf_inspect.getargspec(keras_op).args:
                 compare_single_input_op_to_numpy(keras_op,
                                                  np_op,
                                                  input_shape=(4, 7, 5),
                                                  keras_kwargs={
                                                      'axis': 1,
                                                      'keepdims': True
                                                  },
                                                  np_kwargs={
                                                      'axis': 1,
                                                      'keepdims': True
                                                  })
Example #48
def _get_arg_infos(func, arg_names):
  """Returns an `_ArgInfo` for each argument of `func` specified by `arg_names`.

  Args:
    func: The function whose arguments should be described.
    arg_names: The names of the arguments to get info for.

  Returns:
    A list of `_ArgInfo`s.
  """
  arg_infos = []

  # Inspect the func's argspec to find the position of each arg.
  arg_spec = tf_inspect.getargspec(func)
  for argname in arg_names:
    assert isinstance(argname, str)
    is_list = argname.startswith('[') and argname.endswith(']')
    if is_list:
      argname = argname[1:-1]
    if argname not in arg_spec.args:
      raise ValueError('Argument %r not found in function %s.  Args=%s' %
                       (argname, func, arg_spec.args))
    arg_infos.append(_ArgInfo(argname, arg_spec.args.index(argname), is_list))
  return arg_infos
Example #49
def TPUDistributionStrategy(tpu_cluster_resolver=None, num_cores=None):  # pylint: disable=invalid-name
    """Construct a TPUDistributionStrategy."""
    from tensorflow.contrib.distribute.python import tpu_strategy  # pylint: disable=g-import-not-at-top
    # TODO(b/112705069): Remove this when TPUStrategy API is consistent.
    # We are including this for (a) backwards compatibility for open sourced
    # releases of TensorFlow and (b) to work around a circular dependency
    # where keras_support and tpu_strategy depends on each other. Once we release
    # a final version and remove support for the old API, this will be deleted.
    # (See bug above for more details)
    if tpu_cluster_resolver is None:
        tpu_cluster_resolver = tpu_cluster_resolver_lib.TPUClusterResolver('')

    args, _, _, _ = tf_inspect.getargspec(tpu_strategy.TPUStrategy.__init__)
    if len(args) == 4:
        logging.info('Detected new TPUStrategy API.')
        return tpu_strategy.TPUStrategy(tpu_cluster_resolver,
                                        steps_per_run=1,
                                        num_cores=num_cores)
    else:
        logging.info('Detected old TPUStrategy API.')
        strategy = tpu_strategy.TPUStrategy(num_cores_per_host=8)
        strategy._tpu_cluster_resolver = tpu_cluster_resolver

    return strategy
Example #50
def _get_arg_spec(f, params, param_args):
  """The positions of the parameters of f to be differentiated in param_args."""
  try:
    args = tf_inspect.getargspec(f).args
  except TypeError as e:
    # TypeError can happen when f is a callable object.
    if params is None:
      return range(len(param_args))
    elif all(isinstance(x, int) for x in params):
      return params
    raise ValueError("Either callable provided is not a function or could not "
                     "inspect its arguments by name: %s. Original error: %s"
                     % (f, e))
  if params is None:
    if not args:
      return range(len(param_args))
    return range(len(args))
  elif all(isinstance(x, six.string_types) for x in params):
    return [args.index(n) for n in params]
  elif all(isinstance(x, int) for x in params):
    return params
  else:
    raise ValueError(
        "params must be all strings or all integers; got %s." % params)
Example #51
    def __call__(self, inputs, *args, **kwargs):
        """Wraps `call`, applying pre- and post-processing steps.

    Arguments:
      inputs: input tensor(s).
      *args: additional positional arguments to be passed to `self.call`.
      **kwargs: additional keyword arguments to be passed to `self.call`.
        **Note**: kwarg `scope` is reserved for use by the layer.
    Returns:
      Output tensor(s).
    """
        self._set_scope(kwargs.pop('scope', None))

        # Ensure the Layer, if being reused, is working with inputs from
        # the same graph as where it was created.
        try:
            ops._get_graph_from_inputs(nest.flatten(inputs), graph=self.graph)  # pylint: disable=protected-access
        except ValueError as e:
            raise ValueError(
                'Input graph and Layer graph are not the same: %s' % e)

        with vs.variable_scope(self._scope, reuse=self.built
                               or self._reuse) as scope:
            with ops.name_scope(scope.original_name_scope):
                if not self.built:
                    # Check input assumptions set before layer building, e.g. input rank.
                    self._assert_input_compatibility(inputs)
                    input_list = [
                        ops.convert_to_tensor(x, name='input')
                        for x in nest.flatten(inputs)
                    ]
                    input_shapes = [x.get_shape() for x in input_list]
                    if len(input_shapes) == 1:
                        self.build(input_shapes[0])
                    else:
                        self.build(input_shapes)
                if 'scope' in tf_inspect.getargspec(self.call).args:
                    kwargs['scope'] = scope
                # Check input assumptions set after layer building, e.g. input shape.
                self._assert_input_compatibility(inputs)
                outputs = self.call(inputs, *args, **kwargs)

                # Apply activity regularization.
                # Note that it should be applied every time the layer creates a new
                # output, since it is output-specific.
                if hasattr(
                        self,
                        'activity_regularizer') and self.activity_regularizer:
                    output_list = _to_list(outputs)
                    for output in output_list:
                        with ops.name_scope('ActivityRegularizer'):
                            activity_regularization = self.activity_regularizer(
                                output)
                        self.add_loss(activity_regularization)
                        _add_elements_to_collection(
                            activity_regularization,
                            ops.GraphKeys.REGULARIZATION_LOSSES)

        # Update global default collections.
        _add_elements_to_collection(self.updates, ops.GraphKeys.UPDATE_OPS)
        self.built = True
        return outputs
Example #52
def _recompute_grad(fn, args, use_data_dep=True, tupleize_grads=False):
    """See recompute_grad."""
    has_is_recompute_kwarg = "is_recomputing" in tf_inspect.getargspec(fn).args
    for arg in args:
        if not isinstance(arg, (framework_ops.Tensor, variables_lib.Variable)):
            raise ValueError("All inputs to function must be Tensors")
    use_data_dep_ = use_data_dep
    if use_data_dep_:  # USE_DEFAULT == True
        use_data_dep_ = _is_on_tpu()

    # Use custom_gradient and return a grad_fn that recomputes on the backwards pass.
    @custom_gradient.custom_gradient
    def fn_with_recompute(*args):
        """Wrapper for fn."""
        # Capture the variable and arg scopes so we can re-enter them when recomputing.
        vs = variable_scope.get_variable_scope()
        arg_scope = arg_scope_lib.current_arg_scope()
        # Track all variables touched in the function.
        with backprop.GradientTape() as tape:
            fn_kwargs = {}
            if has_is_recompute_kwarg:
                fn_kwargs["is_recomputing"] = False
            outputs = fn(*args, **fn_kwargs)
        original_vars = set(_as_ref(v) for v in tape.watched_variables())

        def _grad_fn(output_grads, variables=None):
            # Validate that custom_gradient passes the right variables into grad_fn.
            if original_vars:
                assert variables, (
                    "Fn created variables but the variables were not passed to the gradient fn."
                )
                if set(_as_ref(v) for v in variables) != original_vars:
                    raise ValueError(_WRONG_VARS_ERR)

            return _recomputing_grad_fn(
                compute_fn=fn,
                original_args=args,
                original_vars=original_vars,
                output_grads=output_grads,
                grad_fn_variables=variables,
                use_data_dep=use_data_dep_,
                tupleize_grads=tupleize_grads,
                arg_scope=arg_scope,
                var_scope=vs,
                has_is_recompute_kwarg=has_is_recompute_kwarg)

        # custom_gradient inspects the signature of the function to determine
        # whether the user expects variables passed in the grad_fn. If the function
        # created variables, the grad_fn should accept the "variables" kwarg.
        if original_vars:

            def grad_fn(*output_grads, **kwargs):
                return _grad_fn(output_grads, kwargs["variables"])
        else:

            def grad_fn(*output_grads):
                return _grad_fn(output_grads)

        return outputs, grad_fn

    return fn_with_recompute(*args)
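
The `is_recomputing` handling above is pure signature introspection: the kwarg
is injected only when `fn` declares it. A hedged stdlib sketch of that
detection (no TF required):

import inspect

def wants_is_recomputing(fn):
    # Mirrors: "is_recomputing" in tf_inspect.getargspec(fn).args
    return "is_recomputing" in inspect.getfullargspec(fn).args

def block_a(x):
    return x * 2

def block_b(x, is_recomputing=False):
    # A block might, e.g., skip summaries on the recompute pass.
    return x * 2

print(wants_is_recomputing(block_a))  # False
print(wants_is_recomputing(block_b))  # True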
  def __call__(self, func):
    # Various sanity checks on the callable func.
    if not callable(func):
      raise ValueError(f"Function {func} must be a callable.")

    # Func should not use kwargs and defaults.
    argspec = tf_inspect.getargspec(func)
    if argspec.keywords or argspec.defaults:
      raise ValueError(
          "Functions with argument defaults or keywords arguments are not "
          f"supported. {func} has defaults {argspec.defaults} and keywords "
          f"{argspec.keywords}.")

    # Computes how many arguments 'func' has.
    min_args = len(argspec.args)
    max_args = min_args
    if argspec.varargs:
      max_args = 1000000
    argnames = argspec.args
    if tf_inspect.ismethod(func):
      # 1st argument is the "class" type.
      min_args -= 1
      argnames = argnames[1:]

    if self._input_types:
      # If Defun is given a list of types for the inputs, the number
      # of input types should be compatible with 'func'.
      num = len(self._input_types)
      if num < min_args or num > max_args:
        raise ValueError(
            "The number of tf.function input types is not compatible with the "
            f"allowed arguments of {func}. The tf.function has {num} input "
            f"types, while the Python function allows a minimum of {min_args} "
            f"and a maximum of {max_args} arguments.")
      return _DefinedFunction(
          func,
          argnames,
          self._input_types,
          self._func_name,
          self._grad_func,
          self._python_grad_func,
          out_names=self._out_names,
          **self._extra_kwargs)

    # 'func' expects no arguments and input types is an empty list.
    if min_args == 0 and max_args == 0:
      return _DefinedFunction(
          func, [], [],
          self._func_name,
          self._grad_func,
          self._python_grad_func,
          out_names=self._out_names,
          **self._extra_kwargs)

    # Input types are unknown. It's an overloaded function and hence
    # its definition needs to be deferred until it's called.
    return _OverloadedFunction(
        func,
        argnames,
        self._func_name,
        self._grad_func,
        self._python_grad_func,
        out_names=self._out_names,
        **self._extra_kwargs)
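
The arity bookkeeping above can be isolated; this sketch mirrors how the
wrapper derives the allowed argument range, with 1000000 standing in for
"unbounded" whenever *args is present:

import inspect

def arg_bounds(func):
    """(min_args, max_args) as the __call__ wrapper above counts them."""
    spec = inspect.getfullargspec(func)
    min_args = max_args = len(spec.args)
    if spec.varargs:
        max_args = 1000000  # sentinel meaning "unbounded"
    if inspect.ismethod(func):
        min_args -= 1       # the bound `self` is supplied implicitly
    return min_args, max_args

def f(a, b): pass
def g(a, *rest): pass

print(arg_bounds(f))  # (2, 2)
print(arg_bounds(g))  # (1, 1000000)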
Example #54
0
def supports_kwargs(module_or_fn, kwargs_list):
    """Determines whether the provided callable supports all the kwargs.

  This is useful when you have a module that might or might not support a
  kwarg such as `is_training`. Rather than calling the module and catching the
  error, risking the potential modification of underlying state, this function
  introspects the module to see what kwargs are actually supported, using
  the python `inspect` module.

  Note that many TF functions do not export a valid argspec object, rather they
  have a generic *args, **kwargs signature due to various layers of wrapping
  (deprecation decorators, etc). In those circumstances we return
  MAYBE_SUPPORTED, and users will have to use another method to tell whether
  the kwargs are supported (e.g. by just calling the function).

  Args:
    module_or_fn: some callable, generally an object or a method of some object.
      If an object is provided, we check whether `module_or_fn.__call__` supports
      the provided kwargs, which for a Sonnet module will automatically check
      the signature of _build. If `module_or_fn` is a function/method, then
      we check its signature directly, so non-Sonnet functions can be used.
    kwargs_list: string or iterable of strings of keyword arg names to test for.
      If an empty iterable is provided, this function will always return
      SUPPORTED.

  Raises:
    ValueError: if a non-string is provided in `kwargs_list`.

  Returns:
    a string, one of 'supported', 'not_supported' or 'maybe_supported'.
  """
    if isinstance(kwargs_list, six.string_types):
        kwargs_list = [kwargs_list]

    # If it's not a function or method, then assume it's a module, so introspect
    # the __call__ method. wrapt ensures that for Sonnet modules the _build
    # signature is available here.
    if not (inspect.isfunction(module_or_fn)
            or inspect.ismethod(module_or_fn)):
        module_or_fn = module_or_fn.__call__

    arg_spec = tf_inspect.getargspec(module_or_fn)

    # If there is a keywords element, then an arbitrary kwargs will work, as far
    # as we can tell from here.
    takes_arbitrary_kwargs = (arg_spec.keywords is not None)

    for kwarg in kwargs_list:
        if not isinstance(kwarg, six.string_types):
            raise ValueError(
                "kwargs should be strings, instead got {}".format(kwarg))
        if kwarg not in arg_spec.args:
            if not takes_arbitrary_kwargs:
                # The function doesn't take **kwargs, and this name is not in the
                # regular args, so it would definitely cause an error to call this.
                return NOT_SUPPORTED
            else:
                # The function may accept the kwarg, but we can't say for sure. Even
                # though this is only one kwarg, we can't be certain about the whole
                # lot, so the combined answer is now "maybe".
                return MAYBE_SUPPORTED
    # All the kwargs must actually be present in the specific args list
    return SUPPORTED
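
For reference, the same three-way answer can be reproduced with the stdlib
alone; the string constants below stand in for the module-level
SUPPORTED/NOT_SUPPORTED/MAYBE_SUPPORTED values referenced above. A minimal
sketch:

import inspect

SUPPORTED, NOT_SUPPORTED, MAYBE_SUPPORTED = (
    "supported", "not_supported", "maybe_supported")

def kwarg_support(fn, names):
    spec = inspect.getfullargspec(fn)
    takes_arbitrary_kwargs = spec.varkw is not None  # fn has **kwargs
    for name in names:
        if name not in spec.args:
            return MAYBE_SUPPORTED if takes_arbitrary_kwargs else NOT_SUPPORTED
    return SUPPORTED

def build(inputs, is_training=False):
    return inputs

print(kwarg_support(build, ["is_training"]))  # supported
print(kwarg_support(build, ["rate"]))         # not_supported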
Example #55
0
def layer_test(layer_cls, kwargs=None, input_shape=None, input_dtype=None,
               input_data=None, expected_output=None,
               expected_output_dtype=None):
  """Test routine for a layer with a single input and single output.

  Arguments:
    layer_cls: Layer class object.
    kwargs: Optional dictionary of keyword arguments for instantiating the
      layer.
    input_shape: Input shape tuple.
    input_dtype: Data type of the input data.
    input_data: Numpy array of input data.
    expected_output: Numpy array of the expected output data.
    expected_output_dtype: Data type expected for the output.

  Returns:
    The output data (Numpy array) returned by the layer, for additional
    checks to be done by the calling code.

  Raises:
    ValueError: if `input_shape is None`.
  """
  if input_data is None:
    if input_shape is None:
      raise ValueError('input_shape is None')
    if not input_dtype:
      input_dtype = 'float32'
    input_data_shape = list(input_shape)
    for i, e in enumerate(input_data_shape):
      if e is None:
        input_data_shape[i] = np.random.randint(1, 4)
    input_data = 10 * np.random.random(input_data_shape)
    if input_dtype[:5] == 'float':
      input_data -= 0.5
    input_data = input_data.astype(input_dtype)
  elif input_shape is None:
    input_shape = input_data.shape
  if input_dtype is None:
    input_dtype = input_data.dtype
  if expected_output_dtype is None:
    expected_output_dtype = input_dtype

  # instantiation
  kwargs = kwargs or {}
  layer = layer_cls(**kwargs)

  # test get_weights, set_weights at layer level
  weights = layer.get_weights()
  layer.set_weights(weights)

  # test instantiation from weights
  if 'weights' in tf_inspect.getargspec(layer_cls.__init__).args:
    kwargs['weights'] = weights
    layer = layer_cls(**kwargs)

  # test in functional API
  x = keras.layers.Input(shape=input_shape[1:], dtype=input_dtype)
  y = layer(x)
  if keras.backend.dtype(y) != expected_output_dtype:
    raise AssertionError('When testing layer %s, for input %s, found output '
                         'dtype=%s but expected to find %s.\nFull kwargs: %s' %
                         (layer_cls.__name__,
                          x,
                          keras.backend.dtype(y),
                          expected_output_dtype,
                          kwargs))
  # check shape inference
  model = keras.models.Model(x, y)
  expected_output_shape = tuple(
      layer.compute_output_shape(
          tensor_shape.TensorShape(input_shape)).as_list())
  actual_output = model.predict(input_data)
  actual_output_shape = actual_output.shape
  for expected_dim, actual_dim in zip(expected_output_shape,
                                      actual_output_shape):
    if expected_dim is not None:
      if expected_dim != actual_dim:
        raise AssertionError(
            'When testing layer %s, for input %s, found output_shape='
            '%s but expected to find %s.\nFull kwargs: %s' %
            (layer_cls.__name__,
             x,
             actual_output_shape,
             expected_output_shape,
             kwargs))
  if expected_output is not None:
    np.testing.assert_allclose(actual_output, expected_output, rtol=1e-3)

  # test serialization, weight setting at model level
  model_config = model.get_config()
  recovered_model = keras.models.Model.from_config(model_config)
  if model.weights:
    weights = model.get_weights()
    recovered_model.set_weights(weights)
    output = recovered_model.predict(input_data)
    np.testing.assert_allclose(output, actual_output, rtol=2e-3)

  # test training mode (e.g. useful for dropout tests)
  # Rebuild the model to avoid the graph being reused between predict() and
  # train(). This was causing errors for layers with a Defun as their body.
  # See b/120160788 for more details. This should be mitigated after 2.0.
  model = keras.models.Model(x, layer(x))
  if _thread_local_data.run_eagerly is not None:
    model.compile(
        'rmsprop',
        'mse',
        weighted_metrics=['acc'],
        run_eagerly=should_run_eagerly())
  else:
    model.compile('rmsprop', 'mse', weighted_metrics=['acc'])
  model.train_on_batch(input_data, actual_output)

  # test as first layer in Sequential API
  layer_config = layer.get_config()
  layer_config['batch_input_shape'] = input_shape
  layer = layer.__class__.from_config(layer_config)

  model = keras.models.Sequential()
  model.add(layer)
  actual_output = model.predict(input_data)
  actual_output_shape = actual_output.shape
  for expected_dim, actual_dim in zip(expected_output_shape,
                                      actual_output_shape):
    if expected_dim is not None:
      if expected_dim != actual_dim:
        raise AssertionError(
            'When testing layer %s **after deserialization**, '
            'for input %s, found output_shape='
            '%s but expected to find inferred shape %s.\nFull kwargs: %s' %
            (layer_cls.__name__,
             x,
             actual_output_shape,
             expected_output_shape,
             kwargs))
  if expected_output is not None:
    np.testing.assert_allclose(actual_output, expected_output, rtol=1e-3)

  # test serialization, weight setting at model level
  model_config = model.get_config()
  recovered_model = keras.models.Sequential.from_config(model_config)
  if model.weights:
    weights = model.get_weights()
    recovered_model.set_weights(weights)
    output = recovered_model.predict(input_data)
    np.testing.assert_allclose(output, actual_output, rtol=2e-3)

  # for further checks in the caller function
  return actual_output
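
A typical invocation, assuming a standard Keras layer class such as Dense is
in scope (illustrative usage only, not part of the source above):

# Hypothetical call from a layer test:
output = layer_test(
    keras.layers.Dense,
    kwargs={'units': 3},
    input_shape=(2, 4))
assert output.shape == (2, 3)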
Example #56
0
    def deprecated_wrapper(func):
        """Deprecation decorator."""
        decorator_utils.validate_callable(func, 'deprecated_args')
        deprecated_arg_names = _get_arg_names_to_ok_vals()

        arg_spec = tf_inspect.getargspec(func)
        deprecated_positions = _get_deprecated_positional_arguments(
            deprecated_arg_names, arg_spec)

        is_varargs_deprecated = arg_spec.varargs in deprecated_arg_names
        is_kwargs_deprecated = arg_spec.keywords in deprecated_arg_names

        if (len(deprecated_positions) + is_varargs_deprecated +
                is_kwargs_deprecated != len(deprecated_arg_names_or_tuples)):
            known_args = arg_spec.args + [arg_spec.varargs, arg_spec.keywords]
            missing_args = [
                arg_name for arg_name in deprecated_arg_names
                if arg_name not in known_args
            ]
            raise ValueError(
                'The following deprecated arguments are not present '
                'in the function signature: %s. '
                'The function accepts the following arguments: %s.'
                % (missing_args, known_args))

        def _same_value(a, b):
            """A comparison operation that works for multiple object types.

      Returns True for two empty lists, two numeric values with the
      same value, etc.

      Returns False for (pd.DataFrame, None), and other pairs which
      should not be considered equivalent.

      Args:
        a: value one of the comparison.
        b: value two of the comparison.

      Returns:
        A boolean indicating whether the two inputs are the same value
        for the purposes of deprecation.
      """
            if a is b:
                return True
            try:
                equality = a == b
                if isinstance(equality, bool):
                    return equality
            except TypeError:
                return False
            return False

        @functools.wraps(func)
        def new_func(*args, **kwargs):
            """Deprecation wrapper."""
            # TODO(apassos) figure out a way to have reasonable performance with
            # deprecation warnings and eager mode.
            if context.in_graph_mode() and _PRINT_DEPRECATION_WARNINGS:
                invalid_args = []
                named_args = tf_inspect.getcallargs(func, *args, **kwargs)
                for arg_name, spec in iter(deprecated_positions.items()):
                    if (spec.position < len(args)
                            and not (spec.has_ok_value and _same_value(
                                named_args[arg_name], spec.ok_value))):
                        invalid_args.append(arg_name)
                if is_varargs_deprecated and len(args) > len(arg_spec.args):
                    invalid_args.append(arg_spec.varargs)
                if is_kwargs_deprecated and kwargs:
                    invalid_args.append(arg_spec.keywords)
                for arg_name in deprecated_arg_names:
                    if (arg_name in kwargs and not (
                            deprecated_positions[arg_name].has_ok_value
                            and _same_value(
                                named_args[arg_name],
                                deprecated_positions[arg_name].ok_value))):
                        invalid_args.append(arg_name)
                for arg_name in invalid_args:
                    if (func, arg_name) not in _PRINTED_WARNING:
                        if warn_once:
                            _PRINTED_WARNING[(func, arg_name)] = True
                        logging.warning(
                            'From %s: calling %s (from %s) with %s is deprecated and will '
                            'be removed %s.\nInstructions for updating:\n%s',
                            _call_location(),
                            decorator_utils.get_qualified_name(func),
                            func.__module__, arg_name,
                            'in a future version' if date is None else
                            ('after %s' % date), instructions)
            return func(*args, **kwargs)

        return tf_decorator.make_decorator(
            func, new_func, 'deprecated',
            _add_deprecated_arg_notice_to_docstring(func.__doc__, date,
                                                    instructions))
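
The wrapper leans on `getcallargs`, which resolves any call into a
name-to-value map regardless of whether arguments were passed positionally or
by keyword; that is what lets a deprecated argument be compared against its
`ok_value` either way. Stdlib behavior:

import inspect

def f(a, b, c=3):
    return a + b + c

print(inspect.getcallargs(f, 1, 2))         # {'a': 1, 'b': 2, 'c': 3}
print(inspect.getcallargs(f, 1, b=5, c=0))  # {'a': 1, 'b': 5, 'c': 0}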
Example #57
0
def layer_test(layer_cls,
               kwargs=None,
               input_shape=None,
               input_dtype=None,
               input_data=None,
               expected_output=None,
               expected_output_dtype=None,
               expected_output_shape=None,
               validate_training=True,
               adapt_data=None):
    """Test routine for a layer with a single input and single output.

  Arguments:
    layer_cls: Layer class object.
    kwargs: Optional dictionary of keyword arguments for instantiating the
      layer.
    input_shape: Input shape tuple.
    input_dtype: Data type of the input data.
    input_data: Numpy array of input data.
    expected_output: Numpy array of the expected output.
    expected_output_dtype: Data type expected for the output.
    expected_output_shape: Shape tuple for the expected shape of the output.
    validate_training: Whether to attempt to validate training on this layer.
      This might be set to False for non-differentiable layers that output
      string or integer values.
    adapt_data: Optional data for an 'adapt' call. If None, adapt() will not
      be tested for this layer. This is only relevant for PreprocessingLayers.

  Returns:
    The output data (Numpy array) returned by the layer, for additional
    checks to be done by the calling code.

  Raises:
    ValueError: if `input_shape is None`.
  """
    if input_data is None:
        if input_shape is None:
            raise ValueError('input_shape is None')
        if not input_dtype:
            input_dtype = 'float32'
        input_data_shape = list(input_shape)
        for i, e in enumerate(input_data_shape):
            if e is None:
                input_data_shape[i] = np.random.randint(1, 4)
        input_data = 10 * np.random.random(input_data_shape)
        if input_dtype[:5] == 'float':
            input_data -= 0.5
        input_data = input_data.astype(input_dtype)
    elif input_shape is None:
        input_shape = input_data.shape
    if input_dtype is None:
        input_dtype = input_data.dtype
    if expected_output_dtype is None:
        expected_output_dtype = input_dtype

    # instantiation
    kwargs = kwargs or {}
    layer = layer_cls(**kwargs)

    # Test adapt, if data was passed.
    if adapt_data is not None:
        layer.adapt(adapt_data)

    # test get_weights, set_weights at layer level
    weights = layer.get_weights()
    layer.set_weights(weights)

    # test instantiation from weights
    if 'weights' in tf_inspect.getargspec(layer_cls.__init__).args:
        kwargs['weights'] = weights
        layer = layer_cls(**kwargs)

    # test in functional API
    x = keras.layers.Input(shape=input_shape[1:], dtype=input_dtype)
    y = layer(x)
    if keras.backend.dtype(y) != expected_output_dtype:
        raise AssertionError(
            'When testing layer %s, for input %s, found output '
            'dtype=%s but expected to find %s.\nFull kwargs: %s' %
            (layer_cls.__name__, x, keras.backend.dtype(y),
             expected_output_dtype, kwargs))

    def assert_shapes_equal(expected, actual):
        """Asserts that the output shape from the layer matches the actual shape."""
        if len(expected) != len(actual):
            raise AssertionError(
                'When testing layer %s, for input %s, found output_shape='
                '%s but expected to find %s.\nFull kwargs: %s' %
                (layer_cls.__name__, x, actual, expected, kwargs))

        for expected_dim, actual_dim in zip(expected, actual):
            if isinstance(expected_dim, tensor_shape.Dimension):
                expected_dim = expected_dim.value
            if isinstance(actual_dim, tensor_shape.Dimension):
                actual_dim = actual_dim.value
            if expected_dim is not None and expected_dim != actual_dim:
                raise AssertionError(
                    'When testing layer %s, for input %s, found output_shape='
                    '%s but expected to find %s.\nFull kwargs: %s' %
                    (layer_cls.__name__, x, actual, expected, kwargs))

    if expected_output_shape is not None:
        assert_shapes_equal(tensor_shape.TensorShape(expected_output_shape),
                            y.shape)

    # check shape inference
    model = keras.models.Model(x, y)
    computed_output_shape = tuple(
        layer.compute_output_shape(
            tensor_shape.TensorShape(input_shape)).as_list())
    computed_output_signature = layer.compute_output_signature(
        tensor_spec.TensorSpec(shape=input_shape, dtype=input_dtype))
    actual_output = model.predict(input_data)
    actual_output_shape = actual_output.shape
    assert_shapes_equal(computed_output_shape, actual_output_shape)
    assert_shapes_equal(computed_output_signature.shape, actual_output_shape)
    if computed_output_signature.dtype != actual_output.dtype:
        raise AssertionError(
            'When testing layer %s, for input %s, found output_dtype='
            '%s but expected to find %s.\nFull kwargs: %s' %
            (layer_cls.__name__, x, actual_output.dtype,
             computed_output_signature.dtype, kwargs))
    if expected_output is not None:
        np.testing.assert_allclose(actual_output,
                                   expected_output,
                                   rtol=1e-3,
                                   atol=1e-6)

    # test serialization, weight setting at model level
    model_config = model.get_config()
    recovered_model = keras.models.Model.from_config(model_config)
    if model.weights:
        weights = model.get_weights()
        recovered_model.set_weights(weights)
        output = recovered_model.predict(input_data)
        np.testing.assert_allclose(output, actual_output, rtol=1e-3, atol=1e-6)

    # test training mode (e.g. useful for dropout tests)
    # Rebuild the model to avoid the graph being reused between predict() and
    # train(). See b/120160788 for more details. This should be mitigated
    # after 2.0.
    if validate_training:
        model = keras.models.Model(x, layer(x))
        if _thread_local_data.run_eagerly is not None:
            model.compile('rmsprop',
                          'mse',
                          weighted_metrics=['acc'],
                          run_eagerly=should_run_eagerly())
        else:
            model.compile('rmsprop', 'mse', weighted_metrics=['acc'])
        model.train_on_batch(input_data, actual_output)

    # test as first layer in Sequential API
    layer_config = layer.get_config()
    layer_config['batch_input_shape'] = input_shape
    layer = layer.__class__.from_config(layer_config)

    # Test adapt, if data was passed.
    if adapt_data is not None:
        layer.adapt(adapt_data)

    model = keras.models.Sequential()
    model.add(layer)
    actual_output = model.predict(input_data)
    actual_output_shape = actual_output.shape
    for expected_dim, actual_dim in zip(computed_output_shape,
                                        actual_output_shape):
        if expected_dim is not None:
            if expected_dim != actual_dim:
                raise AssertionError(
                    'When testing layer %s **after deserialization**, '
                    'for input %s, found output_shape='
                    '%s but expected to find inferred shape %s.\nFull kwargs: %s'
                    % (layer_cls.__name__, x, actual_output_shape,
                       computed_output_shape, kwargs))
    if expected_output is not None:
        np.testing.assert_allclose(actual_output,
                                   expected_output,
                                   rtol=1e-3,
                                   atol=1e-6)

    # test serialization, weight setting at model level
    model_config = model.get_config()
    recovered_model = keras.models.Sequential.from_config(model_config)
    if model.weights:
        weights = model.get_weights()
        recovered_model.set_weights(weights)
        output = recovered_model.predict(input_data)
        np.testing.assert_allclose(output, actual_output, rtol=1e-3, atol=1e-6)

    # for further checks in the caller function
    return actual_output
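
Note the `.args` suffix in the weights check above (also applied to the other
layer_test variants): an ArgSpec is a namedtuple, so `name in argspec`
compares against the field values (the args list, varargs, keywords, defaults)
rather than the argument names. A stdlib demonstration of the pitfall:

import inspect

def init(self, units, weights=None):
    pass

spec = inspect.getfullargspec(init)
print('weights' in spec)       # False: tests against whole field values
print('weights' in spec.args)  # True: searches the argument names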
Example #58
0
def layer_test(layer_cls,
               kwargs=None,
               input_shape=None,
               input_dtype=None,
               input_data=None,
               expected_output=None,
               expected_output_dtype=None):
    """Test routine for a layer with a single input and single output.

  Arguments:
    layer_cls: Layer class object.
    kwargs: Optional dictionary of keyword arguments for instantiating the
      layer.
    input_shape: Input shape tuple.
    input_dtype: Data type of the input data.
    input_data: Numpy array of input data.
    expected_output: Numpy array of the expected output.
    expected_output_dtype: Data type expected for the output.

  Returns:
    The output data (Numpy array) returned by the layer, for additional
    checks to be done by the calling code.
  """
    if input_data is None:
        assert input_shape
        if not input_dtype:
            input_dtype = 'float32'
        input_data_shape = list(input_shape)
        for i, e in enumerate(input_data_shape):
            if e is None:
                input_data_shape[i] = np.random.randint(1, 4)
        input_data = 10 * np.random.random(input_data_shape)
        if input_dtype[:5] == 'float':
            input_data -= 0.5
        input_data = input_data.astype(input_dtype)
    elif input_shape is None:
        input_shape = input_data.shape
    if input_dtype is None:
        input_dtype = input_data.dtype
    if expected_output_dtype is None:
        expected_output_dtype = input_dtype

    # instantiation
    kwargs = kwargs or {}
    layer = layer_cls(**kwargs)

    # test get_weights, set_weights at layer level
    weights = layer.get_weights()
    layer.set_weights(weights)

    # test instantiation from weights
    if 'weights' in tf_inspect.getargspec(layer_cls.__init__).args:
        kwargs['weights'] = weights
        layer = layer_cls(**kwargs)

    # test in functional API
    x = keras.layers.Input(shape=input_shape[1:], dtype=input_dtype)
    y = layer(x)
    assert keras.backend.dtype(y) == expected_output_dtype

    # check shape inference
    model = keras.models.Model(x, y)
    expected_output_shape = tuple(
        layer.compute_output_shape(
            tensor_shape.TensorShape(input_shape)).as_list())
    actual_output = model.predict(input_data)
    actual_output_shape = actual_output.shape
    for expected_dim, actual_dim in zip(expected_output_shape,
                                        actual_output_shape):
        if expected_dim is not None:
            assert expected_dim == actual_dim
    if expected_output is not None:
        np.testing.assert_allclose(actual_output, expected_output, rtol=1e-3)

    # test serialization, weight setting at model level
    model_config = model.get_config()
    recovered_model = keras.models.Model.from_config(model_config)
    if model.weights:
        weights = model.get_weights()
        recovered_model.set_weights(weights)
        output = recovered_model.predict(input_data)
        np.testing.assert_allclose(output, actual_output, rtol=1e-3)

    # test training mode (e.g. useful for dropout tests)
    model.compile('rmsprop', 'mse')
    model.train_on_batch(input_data, actual_output)

    # test as first layer in Sequential API
    layer_config = layer.get_config()
    layer_config['batch_input_shape'] = input_shape
    layer = layer.__class__.from_config(layer_config)

    model = keras.models.Sequential()
    model.add(layer)
    actual_output = model.predict(input_data)
    actual_output_shape = actual_output.shape
    for expected_dim, actual_dim in zip(expected_output_shape,
                                        actual_output_shape):
        if expected_dim is not None:
            assert expected_dim == actual_dim
    if expected_output is not None:
        np.testing.assert_allclose(actual_output, expected_output, rtol=1e-3)

    # test serialization, weight setting at model level
    model_config = model.get_config()
    recovered_model = keras.models.Sequential.from_config(model_config)
    if model.weights:
        weights = model.get_weights()
        recovered_model.set_weights(weights)
        output = recovered_model.predict(input_data)
        np.testing.assert_allclose(output, actual_output, rtol=1e-3)

    # test training mode (e.g. useful for dropout tests)
    model.compile('rmsprop', 'mse')
    model.train_on_batch(input_data, actual_output)

    # for further checks in the caller function
    return actual_output
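
The input synthesis at the top of each layer_test variant concretizes unknown
(None) dimensions before drawing random data; the same logic in isolation:

import numpy as np

input_shape = (None, 4)
shape = [np.random.randint(1, 4) if d is None else d for d in input_shape]
data = (10 * np.random.random(shape) - 0.5).astype('float32')
print(data.shape)  # e.g. (2, 4)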
Example #59
0
def _graph_callable_internal(func, shape_and_dtypes):
  """Defines and returns a template version of func.

  Under the hood we make two function objects, each wrapping a different version
  of the graph-mode code. One version immediately runs variable initialization
  before making the variable's Tensors available for use, while the other
  version replaces the Variables with placeholders which become function
  arguments and get the current variable's value.

  Limitations (2) and (4) exist because this does not implement a graph-mode
  Variable class that has a convert_to_tensor(as_ref=True) method and an
  initialized_value method. This is fixable.

  Args:
    func: The tfe Python function to compile.
    shape_and_dtypes: A possibly nested list or tuple of ShapeAndDtype objects.

  Raises:
    ValueError: If any one of func's outputs is not a Tensor.

  Returns:
    Callable graph object.
  """
  container = tf_ops.get_default_graph()._container  # pylint: disable=protected-access
  graph_key = tf_ops.get_default_graph()._graph_key  # pylint: disable=protected-access
  with context.graph_mode():
    # This graph will store both the initialization and the call version of the
    # wrapped function. It will later be used by the backprop code to build the
    # backprop graph, if necessary.
    captures = {}
    tmp_graph = function.CapturingGraph(captures)
    # Inherit the graph key from the original graph to ensure optimizers don't
    # misbehave.
    tmp_graph._container = container  # pylint: disable=protected-access
    tmp_graph._graph_key = graph_key  # pylint: disable=protected-access
    with tmp_graph.as_default():
      # Placeholders for the non-variable inputs.
      func_inputs = _get_graph_callable_inputs(shape_and_dtypes)
      func_num_args = len(tf_inspect.getargspec(func).args)
      if len(func_inputs) != func_num_args:
        raise TypeError("The number of arguments accepted by the decorated "
                        "function `%s` (%d) must match the number of "
                        "ShapeAndDtype objects passed to the graph_callable() "
                        "decorator (%d)." %
                        (func.__name__, func_num_args, len(func_inputs)))

      # First call the function to generate a graph which can initialize all
      # variables. As a side-effect this will populate the variable capturing
      # scope's view of which variables exist.
      variable_captures = _VariableCapturingScope()
      with variable_captures.initializing_scope(), function.capture_tensors(
          captures), function.AutomaticControlDependencies() as a:
        func_outputs = func(*func_inputs)
        outputs_list = nest.flatten(func_outputs)
        for i, x in enumerate(outputs_list):
          if x is not None:
            outputs_list[i] = a.mark_as_return(x)
      if len(outputs_list) == 1 and outputs_list[0] is None:
        outputs_list = []
      output_shapes = [x.shape for x in outputs_list]
      if not all(isinstance(x, tf_ops.Tensor) for x in outputs_list):
        raise ValueError("Found non-tensor output in %s" % str(outputs_list))
      initializing_operations = tmp_graph.get_operations()

      # Call the function again, now replacing usages of variables with
      # placeholders. This assumes the variable capturing scope created above
      # knows about all variables.
      tmp_graph.clear_resource_control_flow_state()
      with variable_captures.capturing_scope(), function.capture_tensors(
          captures), function.AutomaticControlDependencies() as a:
        captured_outputs = func(*func_inputs)
      captured_outlist = nest.flatten(captured_outputs)
      for i, x in enumerate(captured_outlist):
        if x is not None:
          captured_outlist[i] = a.mark_as_return(x)
      capturing_operations = tmp_graph.get_operations()[
          len(initializing_operations):]

  sorted_variables = sorted(variable_captures.variables.values(),
                            key=lambda x: x.name)
  ids = list(sorted(captures.keys()))
  if ids:
    extra_inputs, extra_placeholders = zip(*[captures[x] for x in ids])
  else:
    extra_inputs = []
    extra_placeholders = []

  flat_inputs = [x for x in nest.flatten(func_inputs)
                 if isinstance(x, tf_ops.Tensor)]
  placeholder_inputs = flat_inputs + list(extra_placeholders)

  func_def_outputs = [x for x in outputs_list if isinstance(x, tf_ops.Tensor)]
  initialization_name = function._inference_name(func.__name__)  # pylint: disable=protected-access
  # TODO(ashankar): Oh lord, forgive me for this lint travesty.
  # Also, what about the gradient registry of these functions? Those need to be
  # addressed as well.
  for f in tmp_graph._functions.values():  # pylint: disable=protected-access
    function._register(f._c_func.func)  # pylint: disable=protected-access
  initializer_function = function.GraphModeFunction(
      initialization_name,
      placeholder_inputs,
      extra_inputs,
      tmp_graph,
      initializing_operations,
      func_def_outputs,
      func_outputs,
      output_shapes)

  capture_func_def_outputs = [
      x for x in captured_outlist if isinstance(x, tf_ops.Tensor)]
  captured_function_name = function._inference_name(func.__name__)  # pylint: disable=protected-access
  captured_function = function.GraphModeFunction(
      captured_function_name,
      placeholder_inputs,
      extra_inputs,
      tmp_graph,
      capturing_operations,
      capture_func_def_outputs,
      captured_outputs,
      output_shapes,
      variables=[x.variable for x in sorted_variables])

  return _InitializingFunctionObject(captured_function, initializer_function,
                                     shape_and_dtypes)
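
The decorator's arity check is plain argspec counting: the number of
placeholders built from shape_and_dtypes must equal the number of parameters
`func` declares. The check in isolation:

import inspect

def check_arity(func, placeholders):
    n_declared = len(inspect.getfullargspec(func).args)
    if len(placeholders) != n_declared:
        raise TypeError(
            "function %r takes %d arguments but %d placeholders were given"
            % (func.__name__, n_declared, len(placeholders)))

def net(x, y):
    return x

check_arity(net, ['ph_x', 'ph_y'])  # passes silently
# check_arity(net, ['ph_x'])        # would raise TypeError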
Example #60
0
def reuse_variables(method):
    """Wraps an arbitrary method so it does variable sharing.

  This decorator creates variables the first time it calls `method`, and reuses
  them for subsequent calls. The object that calls `method` provides a
  `tf.VariableScope`, either as a `variable_scope` attribute or as the return
  value of an `_enter_variable_scope()` method.

  The first time the wrapped method is invoked, it enters the caller's
  `tf.VariableScope` with `reuse=False`. On all subsequent calls it enters the
  same variable scope with `reuse=True`.

  Variables are created in the context of the `tf.VariableScope` provided by the
  caller object. Ops are created with an additional `tf.name_scope()`, which
  adds a scope for the wrapped method name. For example:

  ```python
  class MyClass(object):

    def __init__(self, name):
      with tf.variable_scope(None, default_name=name) as variable_scope:
        self.variable_scope = variable_scope

    @snt.reuse_variables
    def add_x(self, tensor):
      x = tf.get_variable("x", shape=tensor.get_shape())
      return tensor + x

  module = MyClass("my_module_name")
  input_tensor = tf.zeros(shape=(5,))

  # This creates the variable "my_module_name/x"
  # and op "my_module_name/add_x/add"
  output = module.add_x(input_tensor)
  ```

  For performance when executing eagerly it may be desirable to additionally
  annotate these methods using `defun`, such that they are encapsulated as
  graph functions. This is not recommended if your method returns a variable
  since the output of `defun` would be an op that returned the variable's value
  when evaluated (rather than the variable instance).

  ```python
  class FooModule(snt.AbstractModule):
    def _build(self, inputs):
      return complex_math(inputs)

    @tfe.defun
    @snt.reuse_variables
    def more_complex_stuff(self, inputs):
      return more_complex_math(inputs)
  ```

  Args:
    method: The method to wrap.

  Returns:
    The wrapped method.
  """
    initialized_variable_scopes_eager = set()
    initialized_variable_scopes_graph = weakref.WeakKeyDictionary()

    # Ensure that the argument passed in is really a method by checking that the
    # first positional argument to it is "self".
    arg_spec = tf_inspect.getargspec(method)
    is_method = arg_spec.args and arg_spec.args[0] == "self"

    if not is_method:
        raise TypeError("reuse_variables can only be used with methods.")

    @wrapt.decorator
    def eager_test(method, obj, args, kwargs):
        """Validates runtime state in eager mode."""
        # If @reuse_variables is combined with @property, obj is passed in args
        # and method is still unbound at this stage.
        if obj is None:
            obj = args[0]

        if tf.executing_eagerly() and not hasattr(obj, "_template"):
            raise ValueError(
                "reuse_variables is not supported in eager mode except in Sonnet "
                "modules.")

        return method(*args, **kwargs)

    @wrapt.decorator
    def call_method(method, obj, args, kwargs):
        """Calls `method` with a variable scope whose reuse flag is set correctly.

    The first time the wrapper is called it creates a
    `(tf.Graph, tf.VariableScope)` key and checks it for membership in
    `initialized_variable_scopes`. The check is `False` if and only if this is
    the first time the wrapper has been called with the key, otherwise the
    check is `True`. The result of this check is used as the `reuse` flag for
    entering the provided variable scope before calling `method`.

    Here are two examples of how to use the reuse_variables decorator.

    1. Decorate an arbitrary instance method with a `variable_scope` attribute:

      ```python
      class Reusable(object):

        def __init__(self, name):
          with tf.variable_scope(None, default_name=name) as vs:
            self.variable_scope = vs

        @snt.reuse_variables
        def add_a(self, input_tensor):
          a = tf.get_variable("a", shape=input_tensor.get_shape())
          return a + input_tensor

      obj = Reusable("reusable")
      x = tf.constant(5.0)
      out1 = obj.add_a(x)
      out2 = obj.add_a(x)
      # out1 == out2
      ```

    2. Decorating a snt.AbstractModule instance method:

      ```python
      class ReusableModule(snt.AbstractModule):

        @snt.reuse_variables
        def add_a(self, input_tensor):
          a = tf.get_variable("a", shape=input_tensor.get_shape())
          return a + input_tensor

        # We don't need @snt.reuse_variables here because `_build` is
        # wrapped by `tf.make_template` inside `snt.AbstractModule`.
        def _build(self, input_tensor):
          b = tf.get_variable("b", shape=input_tensor.get_shape())
          return b + self.add_a(input_tensor)

      obj = ReusableModule(name="reusable")
      x = tf.constant(5.0)
      out1 = obj(x)
      out2 = obj(x)
      # out1 == out2
      ```

    Args:
      method: The method to wrap.
      obj: The object instance passed to the wrapped method.
      args: The positional arguments (Tensors) passed to the wrapped method.
      kwargs: The keyword arguments passed to the wrapped method.

    Returns:
      Output of the wrapped method.

    Raises:
      ValueError: If no variable scope is provided or if `method` is a method
                  and a variable_scope keyword argument is also provided.
    """

        # If @reuse_variables is combined with @property, obj is passed in args
        # and method is still unbound at this stage.
        if obj is None:
            obj = args[0]

        def default_context_manager(reuse=None):
            variable_scope = obj.variable_scope
            return tf.variable_scope(variable_scope, reuse=reuse)

        variable_scope_context_manager = getattr(obj, "_enter_variable_scope",
                                                 default_context_manager)

        with tf.init_scope():
            # We need `init_scope` in case we're running inside a defun. In
            # that case, what we want is information about where the function
            # will be called, not where the function is being built.
            graph = tf.get_default_graph()
            will_call_in_eager_context = tf.executing_eagerly()

        if will_call_in_eager_context:
            initialized_variable_scopes = initialized_variable_scopes_eager
        else:
            if graph not in initialized_variable_scopes_graph:
                initialized_variable_scopes_graph[graph] = set()
            initialized_variable_scopes = initialized_variable_scopes_graph[
                graph]

        # Temporarily enter the variable scope to capture it
        with variable_scope_context_manager() as tmp_variable_scope:
            variable_scope = tmp_variable_scope

        reuse = variable_scope.name in initialized_variable_scopes

        # Enter the pure variable scope with reuse correctly set
        with variable_scope_ops._pure_variable_scope(  # pylint:disable=protected-access
                variable_scope, reuse=reuse) as pure_variable_scope:
            current_name_scope = tf.get_default_graph().get_name_scope()
            # Force tf.name_scope to treat current_name_scope as an "absolute" scope
            # so we can re-enter it.
            if current_name_scope and current_name_scope[-1] != "/":
                current_name_scope += "/"
            with tf.name_scope(current_name_scope):
                module_name = pure_variable_scope.name
                method_name = to_snake_case(method.__name__)
                method_name_scope = "{}/{}".format(module_name, method_name)
                with tf.name_scope(method_name_scope) as scope:
                    if hasattr(obj, "_capture_variables"):
                        with obj._capture_variables():  # pylint: disable=protected-access
                            out_ops = method(*args, **kwargs)
                    else:
                        out_ops = method(*args, **kwargs)
            initialized_variable_scopes.add(pure_variable_scope.name)
            try:
                # If `obj` is a Sonnet module, let it know it's been connected
                # to the TF graph.
                obj._is_connected = True  # pylint: disable=protected-access
                if not tf.executing_eagerly():
                    obj._add_connected_subgraph(  # pylint: disable=protected-access
                        method, out_ops, scope, args, kwargs)
            except AttributeError:
                pass
        return out_ops

    return eager_test(call_method(method))  # pylint: disable=no-value-for-parameter
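
The method guard at the top of reuse_variables relies only on the argspec: an
as-yet-unbound function destined to become a method still names its first
parameter self. Stdlib sketch:

import inspect

def looks_like_method(fn):
    # Mirrors: arg_spec.args and arg_spec.args[0] == "self"
    spec = inspect.getfullargspec(fn)
    return bool(spec.args) and spec.args[0] == "self"

def add_x(self, tensor):
    return tensor

def free_fn(tensor):
    return tensor

print(looks_like_method(add_x))    # True
print(looks_like_method(free_fn))  # False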