예제 #1
0
 def add(self, layer_func):
   """Appends a layer or plain callable to the sequence.

   The stored entry is a pair `(takes_training, layer_func)`. The flag records
   whether the callable's signature has a `training` argument. For a
   `base.Layer`, the signature inspected is that of its `call` method, and the
   layer is also tracked via `self.track_layer`.

   Raises:
     TypeError: if `layer_func` is neither a `base.Layer` nor callable.
   """
   if not isinstance(layer_func, base.Layer):
     if not callable(layer_func):
       raise TypeError(
           "Sequential.add() takes only tf.layers.Layer objects or callables; "
           "not '%s' of type '%s'." % (layer_func, type(layer_func)))
     arg_names = function_utils.fn_args(layer_func)
   else:
     arg_names = function_utils.fn_args(layer_func.call)
     self.track_layer(layer_func)
   takes_training = "training" in arg_names
   self._layers_funcs.append((takes_training, layer_func))
예제 #2
0
  def run_step_fn(self, step_fn):
    """Runs ops through a user-supplied step function.

    Args:
      step_fn: A function or a method taking a single `StepContext` argument
        named `step_context`. Through that argument it can run computations
        with access to a raw session.

        Whatever `step_fn` returns is returned from `run_step_fn`, unless a
        stop is requested; in that case the next `should_stop` call returns
        True.

        Example usage:

        ```python
           with tf.Graph().as_default():
             c = tf.placeholder(dtypes.float32)
             v = tf.add(c, 4.0)
             w = tf.add(c, 0.5)

             def step_fn(step_context):
               a = step_context.session.run(fetches=v, feed_dict={c: 0.5})
               if a <= 4.5:
                 step_context.request_stop()
               return step_context.run_with_hooks(fetches=w, feed_dict={c: 0.1})

             with tf.MonitoredSession() as session:
               while not session.should_stop():
                 a = session.run_step_fn(step_fn)
        ```

        Hooks treat the `run_with_hooks()` call inside `step_fn` the same way
        they treat a `MonitoredSession.run` call.

    Returns:
      The value returned by `step_fn`.

    Raises:
      StopIteration: if `step_fn` has called `request_stop()`. It may be
        caught by `with tf.MonitoredSession()` to close the session.
      ValueError: if the arguments of `step_fn` are not exactly
        `(step_context,)` or, for an instance method, `(self, step_context)`.
    """
    accepted_signatures = (('step_context',), ('self', 'step_context'))
    step_fn_arguments = function_utils.fn_args(step_fn)
    if step_fn_arguments not in accepted_signatures:
      raise ValueError(
          '`step_fn` may either have one `step_context` argument, or'
          ' `self` and `step_context` arguments if it\'s an instance'
          ' method. Got {} instead.'.format(step_fn_arguments))

    # `self._sess` is either a `_RecoverableSession` or a `_CoordinatedSession`.
    # Passing `run_with_hooks=None` makes `run_with_hooks` resolve to
    # `_CoordinatedSession.run` downstream in both cases, which lets
    # `_PREEMPTION_ERRORS` propagate from inside `step_fn` up to
    # `_RecoverableSession.run_step_fn`.
    return self._sess.run_step_fn(step_fn, self._tf_sess(), run_with_hooks=None)
    def call(*args):
      """Builds a `tf.py_func` op that forwards a call to method `name`.

      `self` and `name` come from the enclosing scope (not visible here).
      Positional `args` are named after the parameters of
      `self._type.<name>` (skipping its first parameter) to look up the
      output tensor specs.
      """
      method_arg_names = function_utils.fn_args(getattr(self._type, name))[1:]
      named_args = dict(zip(method_arg_names, args))
      specs = self._type._tensor_specs(name, named_args,
                                       self._constructor_kwargs)
      if specs is None:
        raise ValueError(
            'No tensor specifications were provided for: %s' % name)

      flat_dtypes = nest.flatten(
          nest.map_structure(lambda spec: spec.dtype, specs))
      flat_shapes = nest.flatten(
          nest.map_structure(lambda spec: spec.shape, specs))

      def py_call(*py_args):
        # Send the request over `self._out` (presumably a pipe to the process
        # owning the object — confirm) and relay the reply. A reply that is an
        # exception is re-raised; an IOError on the channel ends iteration.
        try:
          self._out.send(py_args)
          reply = self._out.recv()
          if isinstance(reply, Exception):
            raise reply
        except IOError:
          raise StopIteration()  # Clean exit.
        if reply is not None:
          return reply

      outputs = tf.py_func(py_call, (name,) + tuple(args), flat_dtypes,
                           name=name)
      if isinstance(outputs, tf.Operation):
        return outputs

      for tensor, static_shape in zip(outputs, flat_shapes):
        tensor.set_shape(static_shape)
      return nest.pack_sequence_as(specs, outputs)
예제 #4
0
    def eval_step():
      """A single step of evaluation.

      Calls the model function in EVAL mode, records its `scaffold_fn` and
      metric function via `captured_scaffold_fn` / `captured_eval_metric_fn`,
      and returns the loss followed by the metric tensors.

      Returns:
        A tuple `(loss, *metric_tensors)`. When the metric tensors are given
        as a dict, they are ordered by the positional arguments of the metric
        function.
      """
      estimator_spec = self._call_model_fn(features, labels,
                                           model_fn_lib.ModeKeys.EVAL, params)

      # Specs without a `scaffold_fn` attribute capture None instead.
      try:
        captured_scaffold_fn.capture(estimator_spec.scaffold_fn)
      except AttributeError:
        captured_scaffold_fn.capture(None)

      eval_metric_fn = None
      eval_metric_fn_tensors = []
      try:
        if estimator_spec.eval_metrics:
          (eval_metric_fn, eval_metric_fn_tensors) = estimator_spec.eval_metrics
      except AttributeError:
        # Specs without `eval_metrics` contribute no metric tensors.
        pass

      # If a dictionary is provided, we need to convert it into a list sorted
      # according to order of eval_metric_fn positional arguments.
      if isinstance(eval_metric_fn_tensors, dict):
        eval_metric_fn_args = function_utils.fn_args(eval_metric_fn)
        eval_metric_fn_tensors = [
            eval_metric_fn_tensors[i] for i in eval_metric_fn_args
        ]

      captured_eval_metric_fn.capture(eval_metric_fn)

      # Normalize to a tuple so a tuple of metric tensors works as well as a
      # list (`[loss] + some_tuple` would raise TypeError).
      return (estimator_spec.loss,) + tuple(eval_metric_fn_tensors)
  def test_callable(self):
    """fn_args on a callable instance reports the arguments of `__call__`."""

    class Foo(object):

      def __call__(self, a, b):
        return a + b

    callable_instance = Foo()
    expected_args = ('a', 'b')
    self.assertEqual(expected_args, function_utils.fn_args(callable_instance))
  def test_bounded_method(self):
    """fn_args on a bound method omits the bound `self` argument."""

    class Foo(object):

      def bar(self, a, b):
        return a + b

    bound_method = Foo().bar
    expected_args = ('a', 'b')
    self.assertEqual(expected_args, function_utils.fn_args(bound_method))
  def __init__(self, type_, *constructor_args, **constructor_kwargs):
    """Stores `type_` and its constructor arguments, then registers `self`.

    Args:
      type_: The class to wrap. The parameter names of its `__init__`
        (excluding the first) are used to name the positional arguments.
      *constructor_args: Positional constructor arguments for `type_`. Any
        beyond the named parameters are dropped by `zip`.
      **constructor_kwargs: Keyword constructor arguments. They take
        precedence over positional arguments mapped to the same name.
    """
    self._type = type_
    init_arg_names = function_utils.fn_args(type_.__init__)[1:]
    merged_kwargs = dict(zip(init_arg_names, constructor_args))
    merged_kwargs.update(constructor_kwargs)
    self._constructor_kwargs = merged_kwargs

    # Make this instance discoverable through the graph collection.
    tf.add_to_collection(PyProcess.COLLECTION, self)

    self._proxy = _TFProxy(type_, self._constructor_kwargs)
예제 #8
0
def _get_standardized_predicate_fn(predicate_fn):
  """Returns a predicate callable as `fn(eval_results, checkpoint_path)`.

  If `predicate_fn` already declares a `checkpoint_path` argument, it is
  returned unchanged. Otherwise it is wrapped so that `checkpoint_path` is
  accepted and ignored.
  """
  if "checkpoint_path" in function_utils.fn_args(predicate_fn):
    return predicate_fn

  def _pred_fn_wrapper(eval_results, checkpoint_path):  # pylint: disable=unused-argument
    return predicate_fn(eval_results)

  return _pred_fn_wrapper
예제 #9
0
  def _verify_estimator_spec(self, estimator_spec):
    """Verifies estimator spec contains correct data.

    - Warns when `scaffold` is set, since it is ignored under XLA compilation.
    - Rejects `eval_metric_ops`.
    - In EVAL mode, when `eval_metrics` carries a dict of tensors, requires its
      keys to match the arguments of the metric function exactly.

    Attributes missing from the spec are treated as unset.

    Args:
      estimator_spec: The spec produced by the model function.

    Returns:
      `estimator_spec`, unchanged.

    Raises:
      ValueError: if `eval_metric_ops` is set, or the evaluation-tensor dict
        lacks arguments needed by metric_fn or has extra entries.
    """
    # TODO(ycao): Implement estimator spec verification for other modes.

    if getattr(estimator_spec, 'scaffold', None):
      logging.warning('EstimatorSpec.scaffold is ignored with XLA compilation'
                      '. Please use TPUEstimatorSpec.scaffold_fn instead.')

    if getattr(estimator_spec, 'eval_metric_ops', None):
      raise ValueError('EstimatorSpec.eval_metric_ops is not supported with '
                       'XLA compilation. Please use '
                       'TPUEstimatorSpec.eval_metrics instead.')

    if estimator_spec.mode != model_fn_lib.ModeKeys.EVAL:
      return estimator_spec

    eval_metrics = getattr(estimator_spec, 'eval_metrics', None)
    if not eval_metrics:
      return estimator_spec

    metric_fn, metric_tensors = eval_metrics
    metric_fn_args = function_utils.fn_args(metric_fn)
    if not isinstance(metric_tensors, dict):
      return estimator_spec

    missing_tensors = [
        arg for arg in metric_fn_args if arg not in metric_tensors
    ]
    if missing_tensors:
      raise ValueError('Arguments %s are needed by metric_fn (first '
                       'element of TPUEstimatorSpec.eval_metrics) but '
                       'they are not provided by evaluation tensors '
                       '(second element of TPUEstimatorSpec.eval_metrics)'
                       '.' % missing_tensors)

    additional_tensors = [
        key for key in metric_tensors if key not in metric_fn_args
    ]
    if additional_tensors:
      raise ValueError('Arguments %s are provided by evaluation tensors '
                       '(second element of TPUEstimatorSpec.eval_metrics)'
                       ' but they are not needed by metric_fn (first '
                       'element of TPUEstimatorSpec.eval_metrics).' %
                       additional_tensors)

    return estimator_spec
  def test_partial_function(self):
    """fn_args omits an argument bound by keyword in a functools.partial."""
    expected_test_arg = 123

    def fn(a, test_arg):
      if test_arg != expected_test_arg:
        # Raise (not return) so a wrongly bound argument fails the test.
        raise ValueError('partial fn does not work correctly')
      return a

    wrapped_fn = functools.partial(fn, test_arg=expected_test_arg)

    self.assertEqual(('a',), function_utils.fn_args(wrapped_fn))

    # Call the partial so the binding check inside `fn` actually runs.
    self.assertEqual(3, wrapped_fn(3))
    self.assertEqual(3, wrapped_fn(a=3))
예제 #11
0
  def test_double_partial(self):
    """fn_args omits keyword arguments bound across two nested partials."""
    expected_test_arg1 = 123
    expected_test_arg2 = 456

    def fn(a, test_arg1, test_arg2):
      if test_arg1 != expected_test_arg1 or test_arg2 != expected_test_arg2:
        # Raise (not return) so a wrongly bound argument fails the test.
        raise ValueError('partial does not work correctly')
      return a

    wrapped_fn = functools.partial(fn, test_arg2=expected_test_arg2)
    double_wrapped_fn = functools.partial(wrapped_fn,
                                          test_arg1=expected_test_arg1)

    self.assertEqual(('a',), function_utils.fn_args(double_wrapped_fn))

    # Call the partial so the binding check inside `fn` actually runs.
    self.assertEqual(3, double_wrapped_fn(3))
    self.assertEqual(3, double_wrapped_fn(a=3))
예제 #12
0
def _call_metric_fn(metric_fn, features, labels, predictions, config):
  """Calls `metric_fn` with only the keyword arguments it declares.

  Each of `features`, `labels`, `predictions` and `config` is passed by
  keyword only when `metric_fn` has an argument with that name.

  Returns:
    Whatever `metric_fn` returns.
  """
  declared_args = function_utils.fn_args(metric_fn)
  candidates = (
      ('features', features),
      ('labels', labels),
      ('predictions', predictions),
      ('config', config),
  )
  call_kwargs = {
      arg_name: value for arg_name, value in candidates
      if arg_name in declared_args
  }
  return metric_fn(**call_kwargs)
  def test_double_partial(self):
    """fn_args omits keyword arguments bound across two nested partials."""
    expected_test_arg1 = 123
    expected_test_arg2 = 456

    def fn(a, test_arg1, test_arg2):
      if test_arg1 != expected_test_arg1 or test_arg2 != expected_test_arg2:
        # Raise (not return) so a wrongly bound argument fails the test.
        raise ValueError('partial does not work correctly')
      return a

    wrapped_fn = functools.partial(fn, test_arg2=expected_test_arg2)
    double_wrapped_fn = functools.partial(wrapped_fn,
                                          test_arg1=expected_test_arg1)

    self.assertEqual(('a',), function_utils.fn_args(double_wrapped_fn))

    # Call the partial so the binding check inside `fn` actually runs.
    self.assertEqual(3, double_wrapped_fn(3))
    self.assertEqual(3, double_wrapped_fn(a=3))
예제 #14
0
def _call_metric_fn(metric_fn, features, labels, predictions, config):
    """Calls `metric_fn`, forwarding only the inputs it declares.

    Any of `features`, `labels`, `predictions` and `config` whose name
    appears among `metric_fn`'s arguments is passed by keyword.

    Returns:
        The result of `metric_fn`.
    """
    declared_args = set(function_utils.fn_args(metric_fn))
    available = {
        'features': features,
        'labels': labels,
        'predictions': predictions,
        'config': config,
    }
    forwarded = {}
    for arg_name, value in available.items():
        if arg_name in declared_args:
            forwarded[arg_name] = value
    return metric_fn(**forwarded)
예제 #15
0
    def test_partial_function_with_positional_args(self):
        """fn_args omits a leading argument bound positionally by a partial."""
        expected_test_arg = 123

        def fn(test_arg, a):
            if test_arg != expected_test_arg:
                # Raise (not return) so a wrong binding fails loudly instead
                # of surfacing only through the equality checks below.
                raise ValueError('partial fn does not work correctly')
            return a

        wrapped_fn = functools.partial(fn, expected_test_arg)

        self.assertEqual(('a', ), function_utils.fn_args(wrapped_fn))

        self.assertEqual(3, wrapped_fn(3))
        self.assertEqual(3, wrapped_fn(a=3))
예제 #16
0
    def test_do_not_convert_argspec(self):
        """A method wrapped by `api.do_not_convert` reports no arguments."""
        class TestClass(object):
            def test_method(self, x, y):
                return x + y

            test_method_whitelisted = api.do_not_convert(test_method)

        instance = TestClass()
        self.assertTrue(tf_inspect.ismethod(instance.test_method_whitelisted))
        # The wrapper is not a generated function, so its argument spec
        # cannot be preserved and fn_args comes back empty.
        reported_args = function_utils.fn_args(
            instance.test_method_whitelisted)
        self.assertEqual((), tuple(reported_args))
예제 #17
0
def _validate_properties(run_config):
    """Validates the properties."""
    def _validate(property_name, cond, message):
        property_value = getattr(run_config, property_name)
        if property_value is not None and not cond(property_value):
            raise ValueError(message)

    _validate('model_dir',
              lambda dir: dir,
              message='model_dir should be non-empty')

    _validate('save_summary_steps',
              lambda steps: steps >= 0,
              message='save_summary_steps should be >= 0')

    _validate('save_checkpoints_steps',
              lambda steps: steps >= 0,
              message='save_checkpoints_steps should be >= 0')
    _validate('save_checkpoints_secs',
              lambda secs: secs >= 0,
              message='save_checkpoints_secs should be >= 0')

    _validate('session_config',
              lambda sc: isinstance(sc, config_pb2.ConfigProto),
              message='session_config must be instance of ConfigProto')

    _validate('keep_checkpoint_max',
              lambda keep_max: keep_max >= 0,
              message='keep_checkpoint_max should be >= 0')
    _validate('keep_checkpoint_every_n_hours',
              lambda keep_hours: keep_hours > 0,
              message='keep_checkpoint_every_n_hours should be > 0')
    _validate('log_step_count_steps',
              lambda num_steps: num_steps > 0,
              message='log_step_count_steps should be > 0')

    _validate('tf_random_seed',
              lambda seed: isinstance(seed, six.integer_types),
              message='tf_random_seed must be integer.')

    _validate('device_fn',
              lambda device_fn: six.callable(device_fn) and set(
                  function_utils.fn_args(device_fn)) == _VALID_DEVICE_FN_ARGS,
              message='device_fn must be callable with exactly'
              ' one argument "op".')

    _validate('protocol',
              lambda protocol: protocol in (None, "grpc", "grpc+verbs"),
              message='protocol should be grpc or grpc+verbs')
예제 #18
0
    def _call_adanet_model_fn(self, input_fn, mode, config):
        """See the `Estimator` base class for details."""

        # The parent's input_fn is not expected to take any arguments, so
        # pre-bind whichever of `mode`, `params` and `config` it declares.
        declared_args = set(function_utils.fn_args(input_fn))
        bound_kwargs = {}
        if "mode" in declared_args:
            bound_kwargs["mode"] = mode
        if "params" in declared_args:
            bound_kwargs["params"] = self.params
        if "config" in declared_args:
            bound_kwargs["config"] = config
        bound_input_fn = functools.partial(input_fn, **bound_kwargs)
        # NOTE(review): the parent's return value is discarded here; confirm
        # the parent implementation returns nothing callers need.
        parent = super(TPUEstimator, self)
        parent._call_adanet_model_fn(bound_input_fn, mode, config)
예제 #19
0
    def _call_model_fn(self, features, labels, mode, config):
        """Calls `self._model_fn` with the optional arguments it declares.

        `features` is always passed. `labels`, `mode`, `params` (read from
        `self.params`) and `config` are passed by keyword only when the model
        function has an argument of that name.

        Returns:
            Whatever `self._model_fn` returns.
        """
        declared_args = function_utils.fn_args(self._model_fn)
        # Suppliers keep `self.params` lazily evaluated, only when requested.
        suppliers = (
            ('labels', lambda: labels),
            ('mode', lambda: mode),
            ('params', lambda: self.params),
            ('config', lambda: config),
        )
        optional_kwargs = {}
        for arg_name, supply in suppliers:
            if arg_name in declared_args:
                optional_kwargs[arg_name] = supply()

        return self._model_fn(features=features, **optional_kwargs)
예제 #20
0
  def test_double_partial_with_positional_args_in_both_layers(self):
    """fn_args omits leading arguments bound positionally by two partials."""
    expected_test_arg1 = 123
    expected_test_arg2 = 456

    def fn(test_arg1, test_arg2, a):
      if test_arg1 != expected_test_arg1 or test_arg2 != expected_test_arg2:
        # Raise (not return) so a wrong binding fails loudly.
        raise ValueError('partial fn does not work correctly')
      return a

    wrapped_fn = functools.partial(fn, expected_test_arg1)  # binds to test_arg1
    double_wrapped_fn = functools.partial(
        wrapped_fn, expected_test_arg2)  # binds to test_arg2

    self.assertEqual(('a',), function_utils.fn_args(double_wrapped_fn))

    self.assertEqual(3, double_wrapped_fn(3))  # pylint: disable=no-value-for-parameter
    self.assertEqual(3, double_wrapped_fn(a=3))  # pylint: disable=no-value-for-parameter
  def test_double_partial_with_positional_args_in_both_layers(self):
    """fn_args omits leading arguments bound positionally by two partials."""
    expected_test_arg1 = 123
    expected_test_arg2 = 456

    def fn(test_arg1, test_arg2, a):
      if test_arg1 != expected_test_arg1 or test_arg2 != expected_test_arg2:
        # Raise (not return) so a wrong binding fails loudly.
        raise ValueError('partial fn does not work correctly')
      return a

    wrapped_fn = functools.partial(fn, expected_test_arg1)  # binds to test_arg1
    double_wrapped_fn = functools.partial(
        wrapped_fn, expected_test_arg2)  # binds to test_arg2

    self.assertEqual(('a',), function_utils.fn_args(double_wrapped_fn))

    self.assertEqual(3, double_wrapped_fn(3))
    self.assertEqual(3, double_wrapped_fn(a=3))
예제 #22
0
def _validate_function_args(function, required_args):
    """Checks that `function` takes exactly the arguments in `required_args`.

    Args:
      function: (callable) A Python function, or any callable accepted by
        `function_utils.fn_args` (e.g. a `functools.partial` or an object
        with `__call__`).
      required_args: (list) Strings naming the required arguments.

    Raises:
      ValueError: If the set of arguments of `function` differs from the set
        of `required_args` (missing or extra arguments).
    """
    fn_args = function_utils.fn_args(function)
    if set(fn_args) != set(required_args):
        # Partials and callable instances have no `__name__`; fall back to
        # their repr so the intended ValueError is still raised.
        function_name = getattr(function, '__name__', repr(function))
        raise ValueError(
            "Function `%s` needs to have the following arguments: %s."
            " What were provided are the following: %s." %
            (function_name, sorted(required_args), sorted(fn_args)))
예제 #23
0
    def test_double_partial_with_positional_args_in_outer_layer(self):
        """fn_args handles a keyword partial wrapped by a positional one."""
        expected_test_arg1 = 123
        expected_test_arg2 = 456

        def fn(test_arg1, a, test_arg2):
            if test_arg1 != expected_test_arg1 or test_arg2 != expected_test_arg2:
                # Raise (not return) so a wrong binding fails loudly.
                raise ValueError('partial fn does not work correctly')
            return a

        wrapped_fn = functools.partial(fn, test_arg2=expected_test_arg2)
        double_wrapped_fn = functools.partial(wrapped_fn, expected_test_arg1)

        self.assertEqual(('a', ), function_utils.fn_args(double_wrapped_fn))

        self.assertEqual(3, double_wrapped_fn(3))
        self.assertEqual(3, double_wrapped_fn(a=3))
예제 #24
0
def verify_model_fn_args(model_fn, params):
  """Verifies `model_fn` arguments.

  Args:
    model_fn: The model function whose signature is inspected.
    params: Hyperparameters passed to the Estimator, or None.

  Raises:
    ValueError: if `model_fn` has no `features` argument, has no `params`
      argument while `params` is given, or declares arguments outside
      `_VALID_MODEL_FN_ARGS`.
  """
  args = set(function_utils.fn_args(model_fn))
  if 'features' not in args:
    raise ValueError('model_fn (%s) must include features argument.' % model_fn)
  if params is not None and 'params' not in args:
    raise ValueError('model_fn (%s) does not include params argument, '
                     'but params (%s) is passed to Estimator.' %
                     (model_fn, params))
  if params is None and 'params' in args:
    # `warn` is a deprecated alias of `warning`.
    tf.compat.v1.logging.warning(
        'Estimator\'s model_fn (%s) includes params '
        'argument, but params are not passed to Estimator.', model_fn)
  # Sort so the error message is stable across runs (set order is not).
  non_valid_args = sorted(args - _VALID_MODEL_FN_ARGS)
  if non_valid_args:
    raise ValueError('model_fn (%s) has following not expected args: %s' %
                     (model_fn, non_valid_args))
예제 #25
0
    def test_do_not_convert_preserves_argspec(self):
        """`api.do_not_convert(run_as=GRAPH)` keeps the method's argspec."""
        class TestClass(object):
            @api.do_not_convert(run_as=api.RunMode.GRAPH)
            def test_method(self, x, y):
                return x + y

            test_method_do_not_convert = api.do_not_convert(
                run_as=api.RunMode.GRAPH)(test_method)

        instance = TestClass()
        self.assertTrue(tf_inspect.ismethod(instance.test_method_do_not_convert))
        self.assertAllEqual(
            ('x', 'y'),
            function_utils.fn_args(instance.test_method_do_not_convert))
        original_spec = list(tf_inspect.getfullargspec(instance.test_method))
        wrapped_spec = list(
            tf_inspect.getfullargspec(instance.test_method_do_not_convert))
        self.assertListEqual(original_spec, wrapped_spec)
예제 #26
0
def _wrap_and_verify_model_fn(model_fn,
                              mode=None,
                              config=None,
                              params=None,
                              input_signature=None):
    """Returns a function that only has tensor arguments (features, labels).

    `mode`, `params` and `config` are bound by keyword when `model_fn`
    declares them, so the returned function takes only tensors.

    Args:
      model_fn: Model function. Must follow the signature defined in
        `tf.estimator.Estimator`.
      mode: Optional string `tf.estimator.ModeKey`.
      config: Optional `estimator.RunConfig` object.
      params: Optional `dict` of hyperparameters.
      input_signature: Possibly nested TensorSpec of the tensor arguments.

    Returns:
      tuple of (
        function that only accepts tensor arguments (features and/or labels),
        whether the returned function expects a labels argument)

    Raises:
      ValueError: propagated from `model_fn_lib.verify_model_fn_args` when
        the signature of `model_fn` is invalid.
    """
    model_fn_lib.verify_model_fn_args(model_fn, params)
    args = function_utils.fn_args(model_fn)
    kwargs = {}
    if 'mode' in args:
        kwargs['mode'] = mode
    if 'params' in args:
        kwargs['params'] = params
    if 'config' in args:
        kwargs['config'] = config

    takes_labels = 'labels' in args
    if not takes_labels:

        def wrapped_model_fn(features):
            return model_fn(features=features, **kwargs)
    elif input_signature is None or len(input_signature) == 2:

        def wrapped_model_fn(features, labels=None):
            return model_fn(features=features, labels=labels, **kwargs)
    else:
        # Presumably a signature without two entries describes features
        # only, so labels are always None here — confirm with callers.
        def wrapped_model_fn(features):
            return model_fn(features=features, labels=None, **kwargs)

    return wrapped_model_fn, takes_labels
예제 #27
0
def call_logit_fn(logit_fn, features, mode, params, config):
  """Calls logit_fn (experimental).

  THIS FUNCTION IS EXPERIMENTAL. Keras layers/models are the recommended APIs
  for logit and model composition.

  A utility function that calls the provided logit_fn with the relevant subset
  of provided arguments. Similar to tf.estimator._call_model_fn().

  Args:
    logit_fn: A logit_fn as defined above.
    features: The features dict.
    mode: TRAIN / EVAL / PREDICT ModeKeys.
    params: The hyperparameter dict.
    config: The configuration object.

  Returns:
    A logit Tensor, the output of logit_fn.

  Raises:
    ValueError: if logit_fn does not return a Tensor or a dictionary mapping
      strings to Tensors.
  """
  logit_fn_args = function_utils.fn_args(logit_fn)
  kwargs = {}
  if 'mode' in logit_fn_args:
    kwargs['mode'] = mode
  if 'params' in logit_fn_args:
    kwargs['params'] = params
  if 'config' in logit_fn_args:
    kwargs['config'] = config
  logit_fn_results = logit_fn(features=features, **kwargs)

  result_is_valid_dictionary = (
      isinstance(logit_fn_results, dict) and
      all(isinstance(k, six.string_types) and isinstance(v, tf.Tensor)
          for k, v in six.iteritems(logit_fn_results)))
  result_is_tensor = isinstance(logit_fn_results, tf.Tensor)

  if not (result_is_valid_dictionary or result_is_tensor):
    # Wrap in a 1-tuple: a tuple result would otherwise be treated as the
    # `%` format arguments and raise TypeError instead of this ValueError.
    raise ValueError('logit_fn should return a Tensor or a dictionary mapping '
                     'strings to Tensors.  logit_fn returned: %s' %
                     (logit_fn_results,))

  return logit_fn_results
예제 #28
0
  def _call_input_fn_in_new_graph(self, input_fn, mode, config):
    """See the `Estimator` base class for details."""

    # The parent's input_fn is not expected to take any arguments, so pre-bind
    # whichever of `mode`, `params` and `config` this input_fn declares.
    from tensorflow.python.util import function_utils  # pylint: disable=g-direct-tensorflow-import,g-import-not-at-top
    declared_args = set(function_utils.fn_args(input_fn))
    bound_kwargs = {}
    if "mode" in declared_args:
      bound_kwargs["mode"] = mode
    if "params" in declared_args:
      bound_kwargs["params"] = self.params
    if "config" in declared_args:
      bound_kwargs["config"] = config
    bound_input_fn = functools.partial(input_fn, **bound_kwargs)
    parent = super(TPUEstimator, self)
    with parent._call_input_fn_in_new_graph(bound_input_fn, mode,
                                            config) as res:
      yield res
예제 #29
0
    def _call_model_fn(self, features, labels, mode, config):
        """Calls `self._model_fn` and checks that it returns a TFEstimatorSpec.

        `labels`, `mode`, `params` (from `self.params`) and `config` are
        passed by keyword only when the model function declares them.

        Raises:
            ValueError: if the model function's result is not a
                `TFEstimatorSpec`.
        """
        declared_args = function_utils.fn_args(self._model_fn)
        call_kwargs = {}
        if 'labels' in declared_args:
            call_kwargs['labels'] = labels
        if 'mode' in declared_args:
            call_kwargs['mode'] = mode
        if 'params' in declared_args:
            call_kwargs['params'] = self.params
        if 'config' in declared_args:
            call_kwargs['config'] = config

        spec = self._model_fn(features=features, **call_kwargs)
        if isinstance(spec, TFEstimatorSpec):
            return spec
        raise ValueError('model_fn should return an TFEstimatorSpec.')
예제 #30
0
  def _call_model_fn(self, features, labels, mode, params):
    """Calls the model_fn with required parameters.

    Returns:
      The model function's result after `self._verify_estimator_spec`.

    Raises:
      ValueError: if `labels` is not None but the model function does not
        declare a `labels` argument.
    """
    declared_args = function_utils.fn_args(self._model_fn)
    accepts_labels = 'labels' in declared_args
    if not accepts_labels and labels is not None:
      raise ValueError(
          'model_fn does not take labels, but input_fn returns labels.')

    call_kwargs = {}
    if accepts_labels:
      call_kwargs['labels'] = labels
    if 'mode' in declared_args:
      call_kwargs['mode'] = mode
    if 'params' in declared_args:
      call_kwargs['params'] = params

    spec = self._model_fn(features=features, **call_kwargs)
    return self._verify_estimator_spec(spec)
예제 #31
0
def add_train_op(model_fn, features, labels, mode, params, config, optimizer):
    """Calls `model_fn` and, in TRAIN mode, builds a train op if it lacks one.

    Args:
        model_fn: Model function; `labels`, `mode`, `params` and `config` are
            passed by keyword only when it declares them.
        features: Features passed to `model_fn`.
        labels: Labels passed to `model_fn`; must be None if `model_fn`
            does not take labels.
        mode: A `tf.estimator.ModeKeys` value.
        params: Hyperparameters for `model_fn`.
        config: Run configuration for `model_fn`.
        optimizer: Optimizer used to build the train op when `model_fn`
            does not supply one in TRAIN mode.

    Returns:
        A `tf.estimator.EstimatorSpec` with the (possibly newly built)
        train op.

    Raises:
        ValueError: if labels are given but not accepted, if an optimizer is
            needed but missing, or if no variable receives a gradient.
    """
    model_fn_args = function_utils.fn_args(model_fn)
    kwargs = {}
    if 'labels' in model_fn_args:
        kwargs['labels'] = labels
    elif labels is not None:
        raise ValueError(
            'model_fn does not take labels, but input_fn returns labels.')
    if 'mode' in model_fn_args:
        kwargs['mode'] = mode
    if 'params' in model_fn_args:
        kwargs['params'] = params
    if 'config' in model_fn_args:
        kwargs['config'] = config

    spec = model_fn(features=features, **kwargs)

    if isinstance(spec, tf.estimator.EstimatorSpec):
        train_op = spec.train_op
    else:
        train_op = None

    if mode == tf.estimator.ModeKeys.TRAIN and train_op is None:
        if optimizer is None:
            raise ValueError(
                "optimizer should be set when used for training. For example:"
                + " Estimator(model_fn, tf.train.AdamOptimizer())")
        grads_and_vars = optimizer.compute_gradients(spec.loss)

        vars_with_grad = [v for g, v in grads_and_vars if g is not None]
        if not vars_with_grad:
            raise ValueError(
                "No gradients provided for any variable, check your graph for ops"
                " that do not support gradients, between variables %s and loss %s."
                % ([str(v) for _, v in grads_and_vars], spec.loss))

        # Apply with the `optimizer` argument: `EstimatorSpec` has no
        # `optimizer` field, so `spec.optimizer` raised AttributeError.
        train_op = optimizer.apply_gradients(
            grads_and_vars, global_step=tf.train.get_global_step())

    return tf.estimator.EstimatorSpec(mode, spec.predictions, spec.loss,
                                      train_op, spec.eval_metric_ops,
                                      spec.export_outputs)
예제 #32
0
def call_logit_fn(logit_fn, features, mode, params, config):
  """Calls logit_fn.

  A utility function that calls the provided logit_fn with the relevant subset
  of provided arguments.  Similar to tf.estimator._call_model_fn().

  Args:
    logit_fn: A logit_fn as defined above.
    features: The features dict.
    mode: TRAIN / EVAL / PREDICT ModeKeys.
    params: The hyperparameter dict.
    config: The configuration object.

  Returns:
    A logit Tensor, the output of logit_fn.

  Raises:
    ValueError: if logit_fn does not return a Tensor or a dictionary mapping
      strings to Tensors.
  """
  logit_fn_args = function_utils.fn_args(logit_fn)
  kwargs = {}
  if 'mode' in logit_fn_args:
    kwargs['mode'] = mode
  if 'params' in logit_fn_args:
    kwargs['params'] = params
  if 'config' in logit_fn_args:
    kwargs['config'] = config
  logit_fn_results = logit_fn(features=features, **kwargs)

  result_is_valid_dictionary = (
      isinstance(logit_fn_results, dict) and
      all(isinstance(k, six.string_types) and isinstance(v, ops.Tensor)
          for k, v in six.iteritems(logit_fn_results)))
  result_is_tensor = isinstance(logit_fn_results, ops.Tensor)

  if not (result_is_valid_dictionary or result_is_tensor):
    # Wrap in a 1-tuple: a tuple result would otherwise be treated as the
    # `%` format arguments and raise TypeError instead of this ValueError.
    raise ValueError('logit_fn should return a Tensor or a dictionary mapping '
                     'strings to Tensors.  logit_fn returned: %s' %
                     (logit_fn_results,))

  return logit_fn_results
        def call(*args):
            """Wraps a call to method `name` of `self._type` in `tf.py_func`.

            `self` and `name` come from the enclosing scope (not visible
            here). Positional `args` are named after the method's parameters
            (skipping the first) to look up the output tensor specs.
            """
            param_names = function_utils.fn_args(getattr(self._type, name))[1:]
            named_args = dict(zip(param_names, args))
            specs = self._type._tensor_specs(name, named_args,
                                             self._constructor_kwargs)
            if specs is None:
                raise ValueError(
                    'No tensor specifications were provided for: %s' % name)

            flat_dtypes = nest.flatten(
                nest.map_structure(lambda spec: spec.dtype, specs))
            flat_shapes = nest.flatten(
                nest.map_structure(lambda spec: spec.shape, specs))

            def py_call(*py_args):
                # Relay the request over `self._out` and return the reply.
                # A reply that is an exception is re-raised; any IOError on
                # the channel is turned into StopIteration.
                try:
                    self._out.send(py_args)
                    response = self._out.recv()
                    if isinstance(response, Exception):
                        raise response
                except IOError:
                    raise StopIteration()  # Clean exit.
                return response

            outputs = tf.py_func(py_call, (name, ) + tuple(args),
                                 flat_dtypes,
                                 name=name)
            if not isinstance(outputs, tf.Operation):
                for tensor, static_shape in zip(outputs, flat_shapes):
                    tensor.set_shape(static_shape)
                outputs = nest.pack_sequence_as(specs, outputs)
            return outputs
예제 #34
0
def _validate_properties(run_config):
  """Validates the properties."""
  def _validate(property_name, cond, message):
    property_value = getattr(run_config, property_name)
    if property_value is not None and not cond(property_value):
      raise ValueError(message)

  _validate('model_dir', lambda dir: dir,
            message='model_dir should be non-empty')

  _validate('save_summary_steps', lambda steps: steps >= 0,
            message='save_summary_steps should be >= 0')

  _validate('save_checkpoints_steps', lambda steps: steps >= 0,
            message='save_checkpoints_steps should be >= 0')
  _validate('save_checkpoints_secs', lambda secs: secs >= 0,
            message='save_checkpoints_secs should be >= 0')

  _validate('session_config',
            lambda sc: isinstance(sc, config_pb2.ConfigProto),
            message='session_config must be instance of ConfigProto')

  _validate('keep_checkpoint_max', lambda keep_max: keep_max >= 0,
            message='keep_checkpoint_max should be >= 0')
  _validate('keep_checkpoint_every_n_hours', lambda keep_hours: keep_hours > 0,
            message='keep_checkpoint_every_n_hours should be > 0')
  _validate('log_step_count_steps', lambda num_steps: num_steps > 0,
            message='log_step_count_steps should be > 0')

  _validate('tf_random_seed', lambda seed: isinstance(seed, six.integer_types),
            message='tf_random_seed must be integer.')

  _validate('device_fn', lambda device_fn: six.callable(device_fn) and
            set(function_utils.fn_args(device_fn)) == _VALID_DEVICE_FN_ARGS,
            message='device_fn must be callable with exactly'
                    ' one argument "op".')

  _validate('protocol',
            lambda protocol: protocol in (None, "grpc", "grpc+verbs"),
            message='protocol should be grpc or grpc+verbs')
예제 #35
0
def validate_loss_fn_args(loss_fn):
  """Checks that `loss_fn` has an acceptable signature.

  `labels` and `logits` must both be declared. Apart from those, only
  `features` and `loss_reduction` may appear.

  Args:
    loss_fn: The loss function.

  Raises:
    ValueError: if a required argument is missing (`labels` is checked before
      `logits`), or if any argument outside the allowed set is declared.
  """
  declared_args = function_utils.fn_args(loss_fn)
  for required_arg in ('labels', 'logits'):
    if required_arg in declared_args:
      continue
    raise ValueError(
        'loss_fn must contain argument: {}. '
        'Given arguments: {}'.format(required_arg, declared_args))

  allowed_args = {'labels', 'logits', 'features', 'loss_reduction'}
  unexpected_args = list(set(declared_args) - allowed_args)
  if unexpected_args:
    raise ValueError('loss_fn has unexpected args: {}'.format(unexpected_args))
예제 #36
0
        def wrapper(features: Any, labels: Any, mode: Any, params: Any, config: Any) -> Any:
            # TensorFlow inspects the signature of `model_fn()`. This wrapper
            # accepts every argument TensorFlow may supply and forwards only the
            # optional ones that `f` actually declares (features always goes
            # through positionally).
            declared_args = function_utils.fn_args(f)
            optional_args = (
                ("labels", labels),
                ("mode", mode),
                ("params", params),
                ("config", config),
            )
            forwarded = {key: value for key, value in optional_args if key in declared_args}

            self._set_default_tensorflow_session(
                env=self.env,
                hvd_config=self.hvd_config,
                session_config=config.session_config,
            )

            return f(features, **forwarded)
예제 #37
0
    def _easy_call_model_fn(self, features, labels, mode, config):
        """Invokes `self._model_fn` with only the arguments it declares.

        `features` is always passed. `labels`, `mode`, `self.params` and
        `config` are passed as keyword arguments only when the model function's
        signature (per `function_utils.fn_args`) names them.

        Args:
          features: Passed to the model function as `features=`.
          labels: Passed as `labels=` if declared; must be None otherwise.
          mode: Passed as `mode=` if declared.
          config: Passed as `config=` if declared.

        Returns:
          The `tf.estimator.EstimatorSpec` produced by the model function.

        Raises:
          ValueError: if labels are given but the model function does not take
            them, or if the model function returns something other than an
            `EstimatorSpec`.
        """
        declared_args = function_utils.fn_args(self._model_fn)

        call_kwargs = {}
        if 'labels' in declared_args:
            call_kwargs['labels'] = labels
        elif labels is not None:
            raise ValueError('model_fn does not take labels, '
                             'but input_fn returns labels.')

        # Values are produced lazily so `self.params` is read only when the
        # model function actually declares `params`.
        optional_sources = (
            ('mode', lambda: mode),
            ('params', lambda: self.params),
            ('config', lambda: config),
        )
        for arg_name, read_value in optional_sources:
            if arg_name in declared_args:
                call_kwargs[arg_name] = read_value()

        tf.logging.info('Calling model_fn.')
        spec = self._model_fn(features=features, **call_kwargs)
        tf.logging.info('Done calling model_fn.\n')

        if isinstance(spec, tf.estimator.EstimatorSpec):
            return spec
        raise ValueError('model_fn should return an EstimatorSpec.')
예제 #38
0
    def _call_input_fn(self, input_fn, mode, input_context=None):
        """Calls the input function with the keyword arguments it declares.

        The argument names of `input_fn` are read with `function_utils.fn_args`.
        `mode`, `self.params` and `self.config` are each passed only if
        `input_fn` declares a parameter of that name; `input_context` is passed
        only if it is truthy AND `input_fn` declares `input_context`.

        Args:
          input_fn: The input function.
          mode: `tf.estimator.ModeKeys`
          input_context: Optional context object forwarded to `input_fn` under
            the conditions above.

        Returns:
          The return value of the passed `input_fn`, which should be one of:

            * A 'tf.data.Dataset' object: Outputs of `Dataset` object must be a
              tuple `(features, labels)` with same constraints as below.
            * A tuple `(features, labels)`: Where `features` is a `Tensor` or a
              dictionary of string feature name to `Tensor` and `labels` is a
              `Tensor` or a dictionary of string label name to `Tensor`. Both
              `features` and `labels` are consumed by `model_fn`. They should
              satisfy the expectation of `model_fn` from inputs.

        This method itself performs no validation of `input_fn`'s signature;
        any exception comes from `function_utils.fn_args` or `input_fn`.
        """
        declared_args = function_utils.fn_args(input_fn)
        call_kwargs = {}
        if 'mode' in declared_args:
            call_kwargs['mode'] = mode
        if 'params' in declared_args:
            call_kwargs['params'] = self.params
        if 'config' in declared_args:
            call_kwargs['config'] = self.config
        if input_context and 'input_context' in declared_args:
            logging.info(
                'The `input_fn` accepts an `input_context` which will '
                'be given by DistributionStrategy')
            call_kwargs['input_context'] = input_context
        return input_fn(**call_kwargs)
예제 #39
0
    def validate(host_calls):
        """Validates the `eval_metrics` and `host_call` in `NPUEstimatorSpec`.

        Every value of `host_calls` must be a 2-element tuple/list
        `(fn, tensors)` where `fn` is callable and `tensors` is a tuple, list or
        dict. When `tensors` is a tuple/list, its length must equal the number of
        arguments `fn` declares, unless `fn` takes `*args`.

        Args:
          host_calls: Mapping from a name (used in error messages) to a
            `(fn, tensors)` pair.

        Raises:
          ValueError: if a pair is not a 2-element tuple/list, or its tensors
            are not a tuple, list or dict.
          TypeError: if the first element of a pair is not callable.
          RuntimeError: if positional tensors do not match the function's
            argument count.
        """
        for call_name, call_spec in host_calls.items():
            if not isinstance(call_spec, (tuple, list)):
                raise ValueError('{} should be tuple or list'.format(call_name))
            if len(call_spec) != 2:
                raise ValueError('{} should have two elements.'.format(call_name))

            host_fn, host_args = call_spec
            if not callable(host_fn):
                raise TypeError('{}[0] should be callable.'.format(call_name))
            if isinstance(host_args, dict):
                # Keyword-style tensors are not length-checked here.
                continue
            if not isinstance(host_args, (tuple, list)):
                raise ValueError(
                    '{}[1] should be tuple or list, or dict.'.format(call_name))

            argspec = tf_inspect.getfullargspec(host_fn)
            declared_args = function_utils.fn_args(host_fn)
            # wrapped_hostcall_with_global_step uses varargs, so we allow that.
            if argspec.varargs is not None:
                continue
            if len(host_args) != len(declared_args):
                raise RuntimeError(
                    'In NPUEstimatorSpec.{}, length of tensors {} does not match '
                    'method args of the function, which takes {}.'.format(
                        call_name, len(host_args), len(declared_args)))
 def test_simple_function(self):
   """fn_args reports a plain function's parameter names, in order."""
   def add_two(a, b):
     return a + b
   expected_args = ('a', 'b')
   self.assertEqual(expected_args, function_utils.fn_args(add_two))
예제 #41
0
    def _verify_estimator_spec(self, estimator_spec):
        """Verifies estimator spec contains correct data.

        Attributes that exist only on some spec classes (`scaffold`,
        `eval_metric_ops`, `eval_metrics`) are probed with AttributeError
        handling, so specs lacking them are accepted.

        - A truthy `scaffold` only logs a warning.
        - A truthy `eval_metric_ops` is rejected.
        - In EVAL mode, if `eval_metrics` is set and its tensors are a dict, the
          dict keys must match the metric function's argument names exactly.
          Tensors given as a non-dict are not checked.

        Args:
          estimator_spec: The spec returned by the model function.

        Returns:
          `estimator_spec`, unchanged.

        Raises:
          ValueError: on `eval_metric_ops`, or on a mismatch between the metric
            function's arguments and the dict of evaluation tensors.
        """
        # TODO(ycao): Implement estimator spec verification for other modes.

        try:
            scaffold_is_set = bool(estimator_spec.scaffold)
        except AttributeError:
            scaffold_is_set = False
        if scaffold_is_set:
            logging.warning(
                'EstimatorSpec.scaffold is ignored with XLA compilation'
                '. Please use TPUEstimatorSpec.scaffold_fn instead.')

        try:
            metric_ops_are_set = bool(estimator_spec.eval_metric_ops)
        except AttributeError:
            metric_ops_are_set = False
        if metric_ops_are_set:
            raise ValueError(
                'EstimatorSpec.eval_metric_ops is not supported with '
                'XLA compilation. Please use '
                'TPUEstimatorSpec.eval_metrics instead.')

        in_eval_mode = estimator_spec.mode == model_fn_lib.ModeKeys.EVAL
        if not in_eval_mode:
            return estimator_spec

        # For TPUEstimatorSpec-like specs, eval_metrics is a
        # (metric_fn, metric_fn_tensors) pair whose arguments must line up.
        try:
            eval_metrics = estimator_spec.eval_metrics
        except AttributeError:
            eval_metrics = None
        if not eval_metrics:
            return estimator_spec

        metric_fn, metric_fn_tensors = eval_metrics
        metric_fn_args = function_utils.fn_args(metric_fn)
        if not isinstance(metric_fn_tensors, dict):
            return estimator_spec

        missing_tensors = [
            arg for arg in metric_fn_args if arg not in metric_fn_tensors
        ]
        if missing_tensors:
            raise ValueError(
                'Arguments %s are needed by metric_fn (first '
                'element of TPUEstimatorSpec.eval_metrics) but '
                'they are not provided by evaluation tensors '
                '(second element of TPUEstimatorSpec.eval_metrics)'
                '.' % missing_tensors)

        additional_tensors = [
            key for key in metric_fn_tensors if key not in metric_fn_args
        ]
        if additional_tensors:
            raise ValueError(
                'Arguments %s are provided by evaluation tensors '
                '(second element of TPUEstimatorSpec.eval_metrics)'
                ' but they are not needed by metric_fn (first '
                'element of TPUEstimatorSpec.eval_metrics).' %
                additional_tensors)

        return estimator_spec
예제 #42
0
def _verify_metric_fn_args(metric_fn):
    """Rejects a `metric_fn` that declares arguments outside the allowed set.

    Args:
      metric_fn: Callable whose argument names are read with
        `function_utils.fn_args`.

    Raises:
      ValueError: listing every declared argument not in
        `_VALID_METRIC_FN_ARGS`.
    """
    declared = set(function_utils.fn_args(metric_fn))
    unexpected = list(declared - _VALID_METRIC_FN_ARGS)
    if not unexpected:
        return
    raise ValueError('metric_fn (%s) has following not expected args: %s' %
                     (metric_fn, unexpected))
예제 #43
0
def _validate_properties(run_config):
    """Checks the user-settable fields of `run_config`, raising on the first bad one.

    Every field is read with `getattr`; a value of `None` means "not set" and is
    skipped. The checks run in the order of the table below, so when several
    fields are invalid only the earliest one is reported.

    Args:
      run_config: Object exposing every attribute named in the table below.

    Raises:
      ValueError: carrying the message of the first failing check.
    """

    def _is_integer(delay):
        """Returns whether `delay` is an integer on both Python 2 and Python 3.

        PEP 237 folded `long` into `int`, so on Python 3 the name `long` does
        not exist and looking it up raises NameError; plain `int` then suffices.
        """
        try:
            integer_types = (int, long)  # pylint: disable=undefined-variable
        except NameError:
            integer_types = (int,)
        return isinstance(delay, integer_types)

    # (attribute name, predicate a non-None value must satisfy, error message)
    checks = (
        ('model_dir', lambda path: path,
         'model_dir should be non-empty'),
        ('save_summary_steps', lambda steps: steps >= 0,
         'save_summary_steps should be >= 0'),
        ('save_checkpoints_steps', lambda steps: steps >= 0,
         'save_checkpoints_steps should be >= 0'),
        ('save_checkpoints_secs', lambda secs: secs >= 0,
         'save_checkpoints_secs should be >= 0'),
        ('session_config',
         lambda sc: isinstance(sc, config_pb2.ConfigProto),
         'session_config must be instance of ConfigProto'),
        ('keep_checkpoint_max', lambda keep_max: keep_max >= 0,
         'keep_checkpoint_max should be >= 0'),
        ('keep_checkpoint_every_n_hours', lambda hours: hours > 0,
         'keep_checkpoint_every_n_hours should be > 0'),
        ('log_step_count_steps', lambda steps: steps > 0,
         'log_step_count_steps should be > 0'),
        ('tf_random_seed', lambda seed: isinstance(seed, six.integer_types),
         'tf_random_seed must be integer.'),
        ('experimental_max_worker_delay_secs', _is_integer,
         'experimental_max_worker_delay_secs must be an integer if set.'),
        ('session_creation_timeout_secs', lambda secs: secs > 0,
         'session_creation_timeout_secs should be > 0'),
        ('device_fn',
         lambda fn: (six.callable(fn) and
                     set(function_utils.fn_args(fn)) == _VALID_DEVICE_FN_ARGS),
         'device_fn must be callable with exactly one argument "op".'),
        ('protocol', lambda protocol: protocol in (None, 'grpc', 'grpc+verbs'),
         'protocol should be grpc or grpc+verbs'),
    )

    for property_name, is_valid, message in checks:
        value = getattr(run_config, property_name)
        if value is not None and not is_valid(value):
            raise ValueError(message)
예제 #44
0
def _get_loss_towers(model_fn,
                     mode,
                     features,
                     labels,
                     params,
                     config,
                     devices,
                     local_ps_devices,
                     loss_reduction,
                     name_scope_pattern=_DEFAULT_NAME_SCOPE_PATTERN):
  """Replicate the loss computation across devices.

  Calls `model_fn` once per entry of `devices`. Each call runs inside its own
  variable scope (reusing variables after the first tower), name scope,
  `TowerOptimizer` tower context and device setter.

  Args:
    model_fn: Model function; called with `mode`, `features`, `labels` and,
      if its signature declares them, `params` and `config`.
    mode: Passed through to every `model_fn` call.
    features: Indexable per-tower features; tower `i` receives `features[i]`.
    labels: Indexable per-tower labels, or a falsy value, in which case every
      tower receives `labels=None`.
    params: Deep-copied once (only if `model_fn` declares `params`); that
      single copy is passed to every tower.
    config: Treated the same way as `params`.
    devices: Worker devices; one tower is built per entry.
    local_ps_devices: Passed to `_local_device_setter` as `ps_devices`, with a
      `_RoundRobinStrategy` sized to this list.
    loss_reduction: Forwarded to `TowerOptimizer` and `_scale_tower_loss`.
    name_scope_pattern: Format string receiving the tower index; used for
      every tower except the first, which gets an empty name scope.

  Returns:
    List of the (loss-scaled) specs returned by each tower's `model_fn`.

  Raises:
    ValueError: if several towers train without wrapping optimizers in
      `TowerOptimizer`, or if the towers did not make the same optimizer calls.
  """
  tower_specs = []

  declared_args = function_utils.fn_args(model_fn)
  # Deep-copied once up front; every tower receives these same copies.
  shared_kwargs = {}
  for arg_name, arg_value in (('params', params), ('config', config)):
    if arg_name in declared_args:
      shared_kwargs[arg_name] = copy.deepcopy(arg_value)

  # pylint: disable=protected-access
  round_robin_strategy = device_setter_lib._RoundRobinStrategy(
      num_tasks=len(local_ps_devices))
  num_towers = len(devices)
  TowerOptimizer._graph_state().set_reduction_across_towers(
      loss_reduction, num_towers)

  for tower_index, device in enumerate(devices):
    first_tower = tower_index == 0

    device_setter = _local_device_setter(
        worker_device=device,
        ps_devices=local_ps_devices,
        ps_strategy=round_robin_strategy)

    # The first tower gets no name prefix, so un-prefixed variable and op names
    # the user may rely on keep resolving to the first tower's objects.
    scope_template = '' if first_tower else name_scope_pattern

    with variable_scope.variable_scope(
        '', reuse=not first_tower) as var_scope:
      with ops_lib.name_scope(
          scope_template.format(tower_index)) as tower_name_scope:
        with TowerOptimizer._graph_state().tower(
            tower_id=tower_index, var_scope=var_scope,
            name_scope=tower_name_scope):
          with ops_lib.device(device_setter):
            labels_shard = labels[tower_index] if labels else None

            tower_spec = model_fn(
                mode=mode,
                features=features[tower_index],
                labels=labels_shard,
                **shared_kwargs)

            trains_on_many_towers = (
                tower_spec.train_op is not None and num_towers > 1)
            if trains_on_many_towers and not TowerOptimizer.has_been_used():
              raise ValueError('Please wrap optimizers with TowerOptimizer'
                               ' in order to use replicate_model_fn with'
                               ' multiple `devices`.')

            # Scaling the loss here doesn't actually affect gradients.  Another
            # instance of scaling happens inside the TowerOptimizer.
            tower_specs.append(_scale_tower_loss(
                tower_spec, loss_reduction, number_of_towers=num_towers))

  if not TowerOptimizer._did_towers_have_same_optimizer_calls():
    raise ValueError('Each invocation of model_fn was supposed to make the same'
                     ' optimizer calls.')
  TowerOptimizer._clear_graph_state()
  # pylint: enable=protected-access
  return tower_specs
예제 #45
0
        def call(*args):
            # Map the positional arguments onto the remote method's parameter
            # names (dropping its leading `self`) so the type can report the
            # tensor specs of the result.
            method_params = function_utils.fn_args(getattr(self._type, name))
            kwargs = dict(zip(method_params[1:], args))
            specs = self._type._tensor_specs(name, kwargs,
                                             self._constructor_kwargs)
            if specs is None:
                raise ValueError(
                    'No tensor specifications were provided for: %s' % name)

            flat_dtypes = nest.flatten(
                nest.map_structure(lambda spec: spec.dtype, specs))
            flat_shapes = nest.flatten(
                nest.map_structure(lambda spec: spec.shape, specs))

            def py_call(*args):
                # Round-trip the request through the pipe. An Exception object
                # sent back by the other side is re-raised here; an IOError
                # (including one sent back) is turned into StopIteration as a
                # clean exit. A None reply makes this function return None.
                try:
                    self._out.send(args)
                    reply = self._out.recv()
                    if isinstance(reply, Exception):
                        raise reply
                    if reply is not None:
                        return reply
                except IOError:
                    raise StopIteration()  # Clean exit.

            result = tf.py_func(py_call, (name, ) + tuple(args),
                                flat_dtypes,
                                name=name)
            if isinstance(result, tf.Operation):
                return result

            for tensor, shape in zip(result, flat_shapes):
                tensor.set_shape(shape)
            return nest.pack_sequence_as(specs, result)
예제 #46
0
def _verify_metric_fn_args(metric_fn):
  """Rejects a `metric_fn` that declares arguments outside the allowed set.

  Args:
    metric_fn: Callable whose argument names are read with
      `function_utils.fn_args`.

  Raises:
    ValueError: listing every declared argument not in `_VALID_METRIC_FN_ARGS`.
  """
  declared = set(function_utils.fn_args(metric_fn))
  unexpected = list(declared - _VALID_METRIC_FN_ARGS)
  if not unexpected:
    return
  raise ValueError('metric_fn (%s) has following not expected args: %s' %
                   (metric_fn, unexpected))
예제 #47
0
def _get_loss_towers(model_fn,
                     mode,
                     features,
                     labels,
                     params,
                     config,
                     devices,
                     local_ps_devices,
                     loss_reduction,
                     name_scope_pattern=_DEFAULT_NAME_SCOPE_PATTERN):
  """Replicate the loss computation across devices.

  Calls `model_fn` once per entry of `devices`. Each call runs inside its own
  variable scope (reusing variables after the first tower), name scope,
  `TowerOptimizer` tower context and device setter.

  Args:
    model_fn: Model function; called with `mode`, `features`, `labels` and,
      if its signature declares them, `params` and `config`.
    mode: Passed through to every `model_fn` call.
    features: Indexable per-tower features; tower `i` receives `features[i]`.
    labels: Indexable per-tower labels, or a falsy value, in which case every
      tower receives `labels=None`.
    params: Deep-copied once (only if `model_fn` declares `params`); that
      single copy is passed to every tower.
    config: Treated the same way as `params`.
    devices: Worker devices; one tower is built per entry.
    local_ps_devices: Passed to `_local_device_setter` as `ps_devices`, with a
      `_RoundRobinStrategy` sized to this list.
    loss_reduction: Forwarded to `TowerOptimizer` and `_scale_tower_loss`.
    name_scope_pattern: Format string receiving the tower index; used for
      every tower except the first, which gets an empty name scope.

  Returns:
    List of the (loss-scaled) specs returned by each tower's `model_fn`.

  Raises:
    ValueError: if several towers train without wrapping optimizers in
      `TowerOptimizer`, or if the towers did not make the same optimizer calls.
  """
  tower_specs = []

  declared_args = function_utils.fn_args(model_fn)
  # Deep-copied once up front; every tower receives these same copies.
  shared_kwargs = {}
  for arg_name, arg_value in (('params', params), ('config', config)):
    if arg_name in declared_args:
      shared_kwargs[arg_name] = copy.deepcopy(arg_value)

  # pylint: disable=protected-access
  round_robin_strategy = device_setter_lib._RoundRobinStrategy(
      num_tasks=len(local_ps_devices))
  num_towers = len(devices)
  TowerOptimizer._graph_state().set_reduction_across_towers(
      loss_reduction, num_towers)

  for tower_index, device in enumerate(devices):
    first_tower = tower_index == 0

    device_setter = _local_device_setter(
        worker_device=device,
        ps_devices=local_ps_devices,
        ps_strategy=round_robin_strategy)

    # The first tower gets no name prefix, so un-prefixed variable and op names
    # the user may rely on keep resolving to the first tower's objects.
    scope_template = '' if first_tower else name_scope_pattern

    with variable_scope.variable_scope(
        '', reuse=not first_tower) as var_scope:
      with ops_lib.name_scope(
          scope_template.format(tower_index)) as tower_name_scope:
        with TowerOptimizer._graph_state().tower(
            tower_id=tower_index, var_scope=var_scope,
            name_scope=tower_name_scope):
          with ops_lib.device(device_setter):
            labels_shard = labels[tower_index] if labels else None

            tower_spec = model_fn(
                mode=mode,
                features=features[tower_index],
                labels=labels_shard,
                **shared_kwargs)

            trains_on_many_towers = (
                tower_spec.train_op is not None and num_towers > 1)
            if trains_on_many_towers and not TowerOptimizer.has_been_used():
              raise ValueError('Please wrap optimizers with TowerOptimizer'
                               ' in order to use replicate_model_fn with'
                               ' multiple `devices`.')

            # Scaling the loss here doesn't actually affect gradients.  Another
            # instance of scaling happens inside the TowerOptimizer.
            tower_specs.append(_scale_tower_loss(
                tower_spec, loss_reduction, number_of_towers=num_towers))

  if not TowerOptimizer._did_towers_have_same_optimizer_calls():
    raise ValueError('Each invocation of model_fn was supposed to make the same'
                     ' optimizer calls.')
  TowerOptimizer._clear_graph_state()
  # pylint: enable=protected-access
  return tower_specs
예제 #48
0
파일: base.py 프로젝트: neuroph12/CNNDDDD
  def __call__(self, inputs, *args, **kwargs):
    """Wraps `call`, applying pre- and post-processing steps.

    Pops the optional `scope` kwarg and hands it to `self._set_scope`, enters
    the layer's variable scope (with `reuse=True` once the layer is built,
    otherwise `reuse=self._reuse`), injects `scope` into `call`'s kwargs when
    `call` declares a `scope` parameter, and delegates to the parent class's
    `__call__`. In graph mode it also adds `self.updates` to the
    `UPDATE_OPS` collection.

    Arguments:
      inputs: input tensor(s).
      *args: additional positional arguments to be passed to `self.call`.
      **kwargs: additional keyword arguments to be passed to `self.call`.
        **Note**: kwarg `scope` is reserved for use by the layer.

    Returns:
      Output tensor(s).

    Note:
      - If the layer's `call` method takes a `scope` keyword argument,
        this argument will be automatically set to the current variable scope.
      - If the layer's `call` method takes a `mask` argument (as some Keras
        layers do), its default value will be set to the mask generated
        for `inputs` by the previous layer (if `inputs` did come from
        a layer that generated a corresponding mask, i.e. if it came from
        a Keras layer with masking support). NOTE(review): the mask handling
        is not in this method; presumably the parent `__call__` does it.

    Raises:
      ValueError: in graph mode, if `inputs` belong to a different graph than
        the layer's recorded graph.
      NOTE(review): the original docstring also listed "ValueError if the
      layer's `call` method returns None"; no such check exists in this
      method, so it presumably lives in the parent `__call__` — confirm.
    """
    # Consume `scope` here so it is not forwarded to `call` as-is.
    self._set_scope(kwargs.pop('scope', None))

    if not context.executing_eagerly():
      try:
        # Set layer's "graph" at build time
        self._graph = ops._get_graph_from_inputs(nest.flatten(inputs),  # pylint: disable=protected-access
                                                 graph=self._graph)
      except ValueError as e:
        raise ValueError('Input graph and Layer graph are not the same: %s' % e)

    # Pick the variable scope to enter: after the layer is built, variables
    # are always reused; before that, `self._reuse` decides.
    if self.built:
      try:
        # Some classes which inherit from Layer do not use its constructor, so
        # rather than initializing to None we check for an AttributeError.
        scope_context_manager = self._always_reuse_variable_scope
      except AttributeError:
        # From this point we will always set reuse=True, so create a "final"
        # variable scope with this setting. We avoid re-creating variable scopes
        # after this point as an optimization.
        self._always_reuse_variable_scope = vs.variable_scope(
            self._scope, reuse=True, auxiliary_name_scope=False)
        scope_context_manager = self._always_reuse_variable_scope
    else:
      scope_context_manager = vs.variable_scope(
          self._scope, reuse=self._reuse, auxiliary_name_scope=False)

    with scope_context_manager as scope:
      # Record the variable scope active for this call.
      self._current_scope = scope

      # Whether `call` accepts `scope` is computed once and cached on the
      # instance; AttributeError means the cache has not been filled yet.
      try:
        call_has_scope_arg = self._call_has_scope_arg
      except AttributeError:
        self._call_fn_args = function_utils.fn_args(self.call)
        self._call_has_scope_arg = 'scope' in self._call_fn_args
        call_has_scope_arg = self._call_has_scope_arg
      if call_has_scope_arg:
        kwargs['scope'] = scope

      # Actually call layer
      outputs = super(Layer, self).__call__(inputs, *args, **kwargs)

    if not context.executing_eagerly():
      # Update global default collections.
      _add_elements_to_collection(self.updates, ops.GraphKeys.UPDATE_OPS)
    return outputs
예제 #49
0
  def __call__(self, inputs, *args, **kwargs):
    """Wraps `call`, applying pre- and post-processing steps.

    Keras-style layers (`self._keras_style`) skip all scope handling and go
    straight to the parent class's `__call__`. Otherwise this pops the
    optional `scope` kwarg and hands it to `self._set_scope`, enters the
    layer's variable scope (with `reuse=True` once the layer is built,
    otherwise `reuse=self._reuse`), injects `scope` into `call`'s kwargs when
    `call` declares a `scope` parameter, and delegates to the parent class's
    `__call__`. In graph mode it also adds `self.updates` to the `UPDATE_OPS`
    collection.

    Arguments:
      inputs: input tensor(s).
      *args: additional positional arguments to be passed to `self.call`.
      **kwargs: additional keyword arguments to be passed to `self.call`.
        **Note**: kwarg `scope` is reserved for use by the layer.

    Returns:
      Output tensor(s).

    Note:
      - If the layer's `call` method takes a `scope` keyword argument,
        this argument will be automatically set to the current variable scope.
      - If the layer's `call` method takes a `mask` argument (as some Keras
        layers do), its default value will be set to the mask generated
        for `inputs` by the previous layer (if `inputs` did come from
        a layer that generated a corresponding mask, i.e. if it came from
        a Keras layer with masking support). NOTE(review): the mask handling
        is not in this method; presumably the parent `__call__` does it.

    Raises:
      ValueError: if a `scope` kwarg is passed to a keras-style layer, or, in
        graph mode, if `inputs` belong to a different graph than the layer's
        recorded graph.
      NOTE(review): the original docstring also listed "ValueError if the
      layer's `call` method returns None"; no such check exists in this
      method, so it presumably lives in the parent `__call__` — confirm.
    """
    # Consume `scope` here so it is not forwarded to `call` as-is.
    scope = kwargs.pop('scope', None)

    # Keras-style layers do not support variable-scope handling at all.
    if self._keras_style:
      if scope is not None:
        raise ValueError(
            'scope argument not allowed when keras style layers are enabled, '
            'but saw: {}'.format(scope))
      return super(Layer, self).__call__(inputs, *args, **kwargs)

    self._set_scope(scope)

    if not context.executing_eagerly():
      try:
        # Set layer's "graph" at build time
        self._graph = ops._get_graph_from_inputs(nest.flatten(inputs),  # pylint: disable=protected-access
                                                 graph=self._graph)
      except ValueError as e:
        raise ValueError('Input graph and Layer graph are not the same: %s' % e)

    # Pick the variable scope to enter: after the layer is built, variables
    # are always reused; before that, `self._reuse` decides.
    if self.built:
      try:
        # Some classes which inherit from Layer do not use its constructor, so
        # rather than initializing to None we check for an AttributeError.
        scope_context_manager = self._always_reuse_variable_scope
      except AttributeError:
        # From this point we will always set reuse=True, so create a "final"
        # variable scope with this setting. We avoid re-creating variable scopes
        # after this point as an optimization.
        self._always_reuse_variable_scope = vs.variable_scope(
            self._scope, reuse=True, auxiliary_name_scope=False)
        scope_context_manager = self._always_reuse_variable_scope
    else:
      scope_context_manager = vs.variable_scope(
          self._scope, reuse=self._reuse, auxiliary_name_scope=False)

    # Note: `scope` is rebound here to the entered variable scope object.
    with scope_context_manager as scope:
      # Record the variable scope active for this call.
      self._current_scope = scope

      # Whether `call` accepts `scope` is computed once and cached on the
      # instance; AttributeError means the cache has not been filled yet.
      try:
        call_has_scope_arg = self._call_has_scope_arg
      except AttributeError:
        self._call_fn_args = function_utils.fn_args(self.call)
        self._call_has_scope_arg = 'scope' in self._call_fn_args
        call_has_scope_arg = self._call_has_scope_arg
      if call_has_scope_arg:
        kwargs['scope'] = scope

      # Actually call layer
      outputs = super(Layer, self).__call__(inputs, *args, **kwargs)

    if not context.executing_eagerly():
      # Update global default collections.
      _add_elements_to_collection(self.updates, ops.GraphKeys.UPDATE_OPS)
    return outputs