Example #1
  def add(self, layer_func):
    if isinstance(layer_func, base.Layer):
      args = estimator_util.fn_args(layer_func.call)
      self.track_layer(layer_func)
    elif callable(layer_func):
      args = estimator_util.fn_args(layer_func)
    else:
      raise TypeError(
          "Sequential.add() takes only tf.layers.Layer objects or callables; "
          "not '%s' of type '%s'." % (layer_func, type(layer_func)))
    self._layers_funcs.append((("training" in args), layer_func))
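All of these examples share one pattern: inspect a callable's signature with `fn_args`, then forward only the keyword arguments it actually declares. As a rough, hedged sketch of what `fn_args` is assumed to return (this is illustrative, not the real `estimator_util`/`util` implementation), a minimal version consistent with the tests shown later (Examples #9, #10, #15, #16, #18, #20) could look like:

```python
import functools
import inspect

def fn_args(fn):
  """Returns the tuple of argument names `fn` accepts (illustrative sketch)."""
  if isinstance(fn, functools.partial):
    # Recurse into the wrapped callable, then drop arguments the partial has
    # already bound, positionally or by keyword.
    args = fn_args(fn.func)[len(fn.args):]
    return tuple(a for a in args if a not in (fn.keywords or {}))
  if callable(fn) and not inspect.isfunction(fn) and not inspect.ismethod(fn):
    # Callable object: inspect __call__; 'self' is kept, as Example #10 expects.
    return tuple(inspect.getfullargspec(fn.__call__).args)
  args = inspect.getfullargspec(fn).args
  if inspect.ismethod(fn) and args and args[0] == 'self':
    args = args[1:]  # 'self' is already bound on a method (Example #9).
  return tuple(args)
```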
Example #2
  def _call_model_fn(self, features, labels, add_batch_size_in_params=False):
    """Calls the model_fn with required parameters."""
    model_fn_args = util.fn_args(self._model_fn)
    kwargs = {}

    config = copy.deepcopy(self._config)
    params = copy.deepcopy(self._params)

    if 'labels' in model_fn_args:
      kwargs['labels'] = labels
    else:
      if labels is not None:
        raise ValueError(
            'model_fn does not take labels, but input_fn returns labels.')
    if 'mode' in model_fn_args:
      kwargs['mode'] = self._mode
    if 'config' in model_fn_args:
      kwargs['config'] = config
    if 'params' in model_fn_args:
      kwargs['params'] = params

    if add_batch_size_in_params:
      if 'params' not in model_fn_args:
        raise ValueError(
            'model_fn ({}) does not include params argument, '
            'required by TPUEstimator to pass batch size as '
            'params[\'batch_size\']'.format(self._model_fn))
      if self._mode == model_fn_lib.ModeKeys.TRAIN:
        # For TPU training. `params` is never `None`.
        params[_BATCH_SIZE_KEY] = _per_shard_batch_size(self._train_batch_size,
                                                        config)

    return self._model_fn(features=features, **kwargs)
Example #3
  def _call_model_fn(self, features, labels, mode, config):
    """Calls model function.

    Args:
      features: features dict.
      labels: labels dict.
      mode: ModeKeys
      config: RunConfig

    Returns:
      An `EstimatorSpec` object.

    Raises:
      ValueError: if model_fn returns invalid objects.
    """
    model_fn_args = util.fn_args(self._model_fn)
    kwargs = {}
    if 'labels' in model_fn_args:
      kwargs['labels'] = labels
    else:
      if labels is not None:
        raise ValueError(
            'model_fn does not take labels, but input_fn returns labels.')
    if 'mode' in model_fn_args:
      kwargs['mode'] = mode
    if 'params' in model_fn_args:
      kwargs['params'] = self.params
    if 'config' in model_fn_args:
      kwargs['config'] = config
    model_fn_results = self._model_fn(features=features, **kwargs)

    if not isinstance(model_fn_results, model_fn_lib.EstimatorSpec):
      raise ValueError('model_fn should return an EstimatorSpec.')

    return model_fn_results
Example #4
def _call_loss_fn(loss_fn, labels, logits, features):
  """Calls loss_fn and checks the returned shape.

  Args:
    loss_fn: The loss function.
    labels: Processed labels Tensor.
    logits: Logits Tensor of shape [batch_size, logits_dimension].
    features: Features dict.
  Returns:
    Loss Tensor with shape [batch_size, 1].
  """
  loss_fn_args = util.fn_args(loss_fn)
  kwargs = {}
  if 'features' in loss_fn_args:
    kwargs['features'] = features
  unweighted_loss = loss_fn(labels=labels, logits=logits, **kwargs)
  batch_size = array_ops.shape(logits)[0]
  loss_shape = array_ops.shape(unweighted_loss)
  check_shape_op = control_flow_ops.Assert(
      math_ops.reduce_all(math_ops.equal(loss_shape, [batch_size, 1])),
      data=[
          'loss_fn must return Tensor of shape [batch_size, 1]. Given: ',
          loss_shape])
  with ops.control_dependencies([check_shape_op]):
    return array_ops.identity(unweighted_loss)
Example #5
  def run_step_fn(self, step_fn):
    """Run ops using a step function.

    Args:
      step_fn: A function or a method with a single argument of type
        `StepContext`.  The function may use methods of the argument to
        perform computations with access to a raw session.

        The returned value of the `step_fn` will be returned from `run_step_fn`,
        unless a stop is requested.  In that case, the next `should_stop` call
        will return True.

        Example usage:

        ```python
           with tf.Graph().as_default():
             c = tf.placeholder(dtypes.float32)
             v = tf.add(c, 4.0)
             w = tf.add(c, 0.5)

             def step_fn(step_context):
               a = step_context.session.run(fetches=v, feed_dict={c: 0.5})
               if a <= 4.5:
                 step_context.request_stop()
               return step_context.run_with_hooks(fetches=w, feed_dict={c: 0.1})

             with tf.MonitoredSession() as session:
               while not session.should_stop():
                 a = session.run_step_fn(step_fn)
        ```

        Hooks interact with the `run_with_hooks()` call inside the `step_fn`
        as they do with a `MonitoredSession.run` call.

    Returns:
      Returns the returned value of `step_fn`.

    Raises:
      StopIteration: if `step_fn` has called `request_stop()`.  It may be
        caught by `with tf.MonitoredSession()` to close the session.
      ValueError: if `step_fn` doesn't have a single argument called
        `step_context`. It may also optionally have `self` for cases when it
        belongs to an object.
    """
    step_fn_arguments = util.fn_args(step_fn)
    if step_fn_arguments != ('step_context',) and step_fn_arguments != (
        'self',
        'step_context',
    ):
      raise ValueError(
          '`step_fn` may either have one `step_context` argument, or'
          ' `self` and `step_context` arguments if it\'s an instance'
          ' method. Got {} instead.'.format(step_fn_arguments))

    # `self._sess` is either `_RecoverableSession` or a `_CoordinatedSession`.
    # Setting `run_with_hooks` to `None` will cause `run_with_hooks` to be
    # `_CoordinatedSession.run` downstream in either case. This allows
    # `_PREEMPTION_ERRORS` to propagate from within `step_fn` to
    # `_RecoverableSession.run_step_fn`.
    return self._sess.run_step_fn(step_fn, self._tf_sess(), run_with_hooks=None)
Example #6
def call_logit_fn(logit_fn, features, mode, params, config):
  """Calls logit_fn.

  A utility function that calls the provided logit_fn with the relevant subset
  of provided arguments.  Similar to tf.estimator._call_model_fn().

  Args:
    logit_fn: A logit_fn as defined above.
    features: The features dict.
    mode: TRAIN / EVAL / PREDICT ModeKeys.
    params: The hyperparameter dict.
    config: The configuration object.

  Returns:
    A logit Tensor, the output of logit_fn.

  Raises:
    ValueError: if logit_fn does not return a Tensor.
  """
  logit_fn_args = util.fn_args(logit_fn)
  kwargs = {}
  if 'mode' in logit_fn_args:
    kwargs['mode'] = mode
  if 'params' in logit_fn_args:
    kwargs['params'] = params
  if 'config' in logit_fn_args:
    kwargs['config'] = config
  logit_fn_results = logit_fn(features=features, **kwargs)

  if not isinstance(logit_fn_results, ops.Tensor):
    raise ValueError('logit_fn should return a Tensor.')

  return logit_fn_results
Example #7
  def export(self,
             estimator,
             export_path,
             checkpoint_path=None,
             eval_result=None):
    """Exports the given Estimator to a specific format.

    Args:
      estimator: the Estimator to export.
      export_path: A string containing a directory where to write the export.
      checkpoint_path: The checkpoint path to export.  If None (the default),
        the strategy may locate a checkpoint (e.g. the most recent) by itself.
      eval_result: The output of Estimator.evaluate on this checkpoint.  This
        should be set only if checkpoint_path is provided (otherwise it is
        unclear which checkpoint this eval refers to).

    Returns:
      The string path to the exported directory.

    Raises:
      ValueError: if the export_fn does not have the required signature.
    """
    export_fn_args = util.fn_args(self.export_fn)
    kwargs = {}
    if 'checkpoint_path' in export_fn_args:
      kwargs['checkpoint_path'] = checkpoint_path
    if 'eval_result' in export_fn_args:
      if 'checkpoint_path' not in export_fn_args:
        raise ValueError('An export_fn accepting eval_result must also accept '
                         'checkpoint_path.')
      kwargs['eval_result'] = eval_result

    return self.export_fn(estimator, export_path, **kwargs)
Example #8
  def _call_input_fn(self, input_fn, mode):
    """Calls the input function.

    Args:
      input_fn: The input function.
      mode: ModeKeys

    Returns:
      Either features or (features, labels) where features and labels are:
        features - `Tensor` or dictionary of string feature name to `Tensor`.
        labels - `Tensor` or dictionary of `Tensor` with labels.

    Raises:
      ValueError: if input_fn takes invalid arguments.
    """
    input_fn_args = util.fn_args(input_fn)
    kwargs = {}
    if 'mode' in input_fn_args:
      kwargs['mode'] = mode
    if 'params' in input_fn_args:
      kwargs['params'] = self.params
    if 'config' in input_fn_args:
      kwargs['config'] = self.config
    with ops.device('/cpu:0'):
      return input_fn(**kwargs)
Example #9
  def test_bounded_method(self):

    class Foo(object):

      def bar(self, a, b):
        return a + b

    self.assertEqual(('a', 'b'), util.fn_args(Foo().bar))
Example #10
  def test_callable(self):

    class Foo(object):

      def __call__(self, a, b):
        return a + b

    self.assertEqual(('self', 'a', 'b'), util.fn_args(Foo()))
Example #11
def _verify_metric_fn_args(metric_fn):
  args = set(estimator_util.fn_args(metric_fn))
  if tf_inspect.ismethod(metric_fn):
    if 'self' in args:
      args.remove('self')
  invalid_args = list(args - _VALID_METRIC_FN_ARGS)
  if invalid_args:
    raise ValueError('metric_fn (%s) has following not expected args: %s' %
                     (metric_fn, invalid_args))
Example #12
  def _call_input_fn(self, input_fn, mode):
    """Calls the input function.

    Args:
      input_fn: The input function.
      mode: ModeKeys

    Returns:
      Either features or (features, labels) where features and labels are:
        features - `Tensor` or dictionary of string feature name to `Tensor`.
        labels - `Tensor` or dictionary of `Tensor` with labels.

    Raises:
      ValueError: if input_fn takes invalid arguments or does not have `params`.
    """
    if not self._use_tpu or mode != model_fn_lib.ModeKeys.TRAIN:
      return super(TPUEstimator, self)._call_input_fn(input_fn, mode)

    input_fn_args = util.fn_args(input_fn)
    config = self.config  # a deep copy.
    kwargs = {}
    if 'params' in input_fn_args:
      kwargs['params'] = self.params  # a deep copy.
    else:
      raise ValueError('input_fn ({}) does not include params argument, '
                       'required by TPUEstimator to pass batch size as '
                       'params["batch_size"]'.format(input_fn))
    if 'config' in input_fn_args:
      kwargs['config'] = config

    # Now for TPU training.
    per_shard_batch_size = _per_shard_batch_size(self._train_batch_size, config)
    kwargs['params'][_BATCH_SIZE_KEY] = per_shard_batch_size

    job = _tpu_job(config)
    def placement_function(index):
      if job is None:
        return '/replica:0/task:0/device:CPU:0'
      else:
        return '/job:%s/replica:0/task:%d/device:CPU:0' % (job, index / 8)

    features = []
    labels = []
    for i in range(config.tpu_config.num_shards):
      with ops.device(placement_function(i)):
        result = input_fn(**kwargs)
        # input_fn may return either features or (features, labels)
        if isinstance(result, tuple):
          features.append(result[0])
          labels.append(result[1])
        else:
          features.append(result)

    if not labels or all(l is None for l in labels):
      return _PerShardOutput(features), None

    return _PerShardOutput(features), _PerShardOutput(labels)
Example #13
  def _call_input_fn(self, input_fn, mode):
    """Calls the input function.

    Args:
      input_fn: The input function.
      mode: ModeKeys

    Returns:
      Either features or (features, labels) where features and labels are:
        features - `Tensor` or dictionary of string feature name to `Tensor`.
        labels - `Tensor` or dictionary of `Tensor` with labels.

    Raises:
      ValueError: if input_fn takes invalid arguments or does not have `params`.
    """
    input_fn_args = util.fn_args(input_fn)
    config = self.config  # a deep copy.
    kwargs = {}
    if 'params' in input_fn_args:
      kwargs['params'] = self.params  # a deep copy.
    else:
      raise ValueError('input_fn ({}) does not include params argument, '
                       'required by TPUEstimator to pass batch size as '
                       'params["batch_size"]'.format(input_fn))
    if 'config' in input_fn_args:
      kwargs['config'] = config

    # Now for TPU training.
    if mode == model_fn_lib.ModeKeys.TRAIN:
      kwargs['params'][_BATCH_SIZE_KEY] = (
          _per_shard_batch_size(self._train_batch_size, config, self._use_tpu)
          if not config.tpu_config.per_host_input_for_training else
          self._train_batch_size)

    if not self._use_tpu or mode != model_fn_lib.ModeKeys.TRAIN:
      with ops.device('/cpu:0'):
        return input_fn(**kwargs)

    job = _tpu_job(config)
    def placement_function(index):
      if job is None:
        return '/replica:0/task:0/device:CPU:0'
      else:
        return '/job:%s/replica:0/task:%d/device:CPU:0' % (job, index / 8)

    if not config.tpu_config.per_host_input_for_training:
      num_shards = config.tpu_config.num_shards
      inputs = _InputsHolder(num_shards=num_shards)
      for i in range(config.tpu_config.num_shards):
        with ops.device(placement_function(i)):
          inputs.append_tuple(input_fn(**kwargs))

      return inputs.as_features_and_labels_tuple()
    else:
      # TODO(xiejw): Extend this to multi-host support.
      with ops.device(placement_function(0)):
        return input_fn(**kwargs)
Example #14
def _get_standardized_predicate_fn(predicate_fn):
  pred_fn_args = estimator_util.fn_args(predicate_fn)
  if "checkpoint_path" not in pred_fn_args:
    # pylint: disable=unused-argument
    def _pred_fn_wrapper(eval_results, checkpoint_path):
      return predicate_fn(eval_results)

    return _pred_fn_wrapper
  else:
    return predicate_fn
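A hedged usage sketch (the predicate below is invented, and the call assumes the module context of the example above): the wrapper gives one-argument predicates the same two-argument calling convention as the rest.

```python
def stops_early(eval_results):
  # Hypothetical predicate that only looks at eval metrics.
  return eval_results['loss'] < 0.1

standardized = _get_standardized_predicate_fn(stops_early)
# The checkpoint path is accepted and ignored by the wrapper.
standardized({'loss': 0.05}, checkpoint_path='/tmp/model.ckpt')  # -> True
```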
Example #15
  def test_partial_function(self):
    expected_test_arg = 123

    def fn(a, test_arg):
      if test_arg != expected_test_arg:
        raise ValueError('partial fn does not work correctly')
      return a

    wrapped_fn = functools.partial(fn, test_arg=123)

    self.assertEqual(('a',), util.fn_args(wrapped_fn))
Example #16
  def test_double_partial(self):
    expected_test_arg1 = 123
    expected_test_arg2 = 456

    def fn(a, test_arg1, test_arg2):
      if test_arg1 != expected_test_arg1 or test_arg2 != expected_test_arg2:
        raise ValueError('partial does not work correctly')
      return a

    wrapped_fn = functools.partial(fn, test_arg2=456)
    double_wrapped_fn = functools.partial(wrapped_fn, test_arg1=123)

    self.assertEqual(('a',), util.fn_args(double_wrapped_fn))
Example #17
def _call_metric_fn(metric_fn, features, labels, predictions, config):
  """Calls metric fn with proper arguments."""
  metric_fn_args = estimator_util.fn_args(metric_fn)
  kwargs = {}
  if 'features' in metric_fn_args:
    kwargs['features'] = features
  if 'labels' in metric_fn_args:
    kwargs['labels'] = labels
  if 'predictions' in metric_fn_args:
    kwargs['predictions'] = predictions
  if 'config' in metric_fn_args:
    kwargs['config'] = config
  return metric_fn(**kwargs)
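For illustration (the metric below is hypothetical and assumes numpy-like inputs), only the arguments a metric_fn declares are forwarded, so a two-argument metric works unchanged:

```python
def mean_squared_error(labels, predictions):
  # 'features' and 'config' are simply never passed to this metric_fn.
  return ((labels - predictions) ** 2).mean()

# Assuming features/labels/predictions/config are in scope as above:
metrics = _call_metric_fn(mean_squared_error, features, labels, predictions,
                          config)
```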
Example #18
  def test_partial_function_with_positional_args(self):
    expected_test_arg = 123

    def fn(test_arg, a):
      if test_arg != expected_test_arg:
        raise ValueError('partial fn does not work correctly')
      return a

    wrapped_fn = functools.partial(fn, 123)

    self.assertEqual(('a',), util.fn_args(wrapped_fn))

    self.assertEqual(3, wrapped_fn(3))
    self.assertEqual(3, wrapped_fn(a=3))
Example #19
def _verify_compare_fn_args(compare_fn):
  """Verifies compare_fn arguments."""
  args = set(util.fn_args(compare_fn))
  if 'best_eval_result' not in args:
    raise ValueError(
        'compare_fn (%s) must include best_eval_result argument.' % compare_fn)
  if 'current_eval_result' not in args:
    raise ValueError(
        'compare_fn (%s) must include current_eval_result argument.' %
        compare_fn)
  non_valid_args = list(args - set(['best_eval_result', 'current_eval_result']))
  if non_valid_args:
    raise ValueError('compare_fn (%s) has following not expected args: %s' %
                     (compare_fn, non_valid_args))
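A compare_fn that satisfies these checks names exactly the two required arguments. A plausible sketch (a loss comparison, not necessarily the library default):

```python
def loss_smaller(best_eval_result, current_eval_result):
  # Prefer the new checkpoint only if its loss improves on the best so far.
  return current_eval_result['loss'] < best_eval_result['loss']

_verify_compare_fn_args(loss_smaller)  # passes silently

def incomplete_fn(best_eval_result):   # missing current_eval_result
  return True

# _verify_compare_fn_args(incomplete_fn)  # would raise ValueError
```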
Example #20
  def test_double_partial_with_positional_args_in_outer_layer(self):
    expected_test_arg1 = 123
    expected_test_arg2 = 456

    def fn(test_arg1, a, test_arg2):
      if test_arg1 != expected_test_arg1 or test_arg2 != expected_test_arg2:
        raise ValueError('partial fn does not work correctly')
      return a

    wrapped_fn = functools.partial(fn, test_arg2=456)
    double_wrapped_fn = functools.partial(wrapped_fn, 123)

    self.assertEqual(('a',), util.fn_args(double_wrapped_fn))

    self.assertEqual(3, double_wrapped_fn(3))
    self.assertEqual(3, double_wrapped_fn(a=3))
Example #21
def _verify_model_fn_args(model_fn, params):
  """Verifies model fn arguments."""
  args = set(util.fn_args(model_fn))
  if 'features' not in args:
    raise ValueError('model_fn (%s) must include features argument.' % model_fn)
  if params is not None and 'params' not in args:
    raise ValueError('model_fn (%s) does not include params argument, '
                     'but params (%s) is passed to Estimator.' % (model_fn,
                                                                  params))
  if params is None and 'params' in args:
    logging.warning('Estimator\'s model_fn (%s) includes params '
                    'argument, but params are not passed to Estimator.',
                    model_fn)
  non_valid_args = list(args - _VALID_MODEL_FN_ARGS)
  if non_valid_args:
    raise ValueError('model_fn (%s) has following not expected args: %s' %
                     (model_fn, non_valid_args))
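To make the checks concrete, a hedged sketch (assuming `_VALID_MODEL_FN_ARGS` covers the usual `features`/`labels`/`mode`/`params`/`config` names used throughout these examples):

```python
def my_model_fn(features, labels, mode):
  ...  # hypothetical model_fn body

_verify_model_fn_args(my_model_fn, params=None)         # passes
_verify_model_fn_args(my_model_fn, params={'lr': 0.1})  # raises: fn lacks 'params'
```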
Example #22
def _get_loss_towers(model_fn,
                     mode,
                     features,
                     labels,
                     params,
                     config,
                     devices,
                     local_ps_device,
                     name_scope_pattern=_DEFAULT_NAME_SCOPE_PATTERN):
  """Replicate the loss computation across devices."""
  tower_specs = []

  model_fn_args = util.fn_args(model_fn)
  optional_params = {}
  if 'params' in model_fn_args:
    optional_params['params'] = copy.deepcopy(params)
  if 'config' in model_fn_args:
    optional_params['config'] = copy.deepcopy(config)

  for i, device in enumerate(devices):
    is_the_first_tower = (i == 0)

    device_setter = _local_device_setter(
        worker_device=device, ps_device=local_ps_device)

    # We would like to preserve the names of the variables and ops that a user
    # might be relying on. Names without a prefix are going to resolve to
    # variables and ops of the first tower.
    name_scope = name_scope_pattern
    if is_the_first_tower:
      name_scope = ''

    with variable_scope.variable_scope('', reuse=not is_the_first_tower):
      with ops_lib.name_scope(name_scope.format(i)):
        with ops_lib.device(device_setter):
          labels_shard = None
          if labels:
            labels_shard = labels[i]

          tower_specs.append(
              model_fn(
                  mode=mode,
                  features=features[i],
                  labels=labels_shard,
                  **optional_params))
  return tower_specs
Example #23
def call_logit_fn(logit_fn, features, mode, params, config):
  """Calls logit_fn.

  A utility function that calls the provided logit_fn with the relevant subset
  of provided arguments.  Similar to tf.estimator._call_model_fn().

  Args:
    logit_fn: A logit_fn as defined above.
    features: The features dict.
    mode: TRAIN / EVAL / PREDICT ModeKeys.
    params: The hyperparameter dict.
    config: The configuration object.

  Returns:
    A logit Tensor, the output of logit_fn.

  Raises:
    ValueError: if logit_fn does not return a Tensor or a dictionary mapping
      strings to Tensors.
  """
  logit_fn_args = util.fn_args(logit_fn)
  kwargs = {}
  if 'mode' in logit_fn_args:
    kwargs['mode'] = mode
  if 'params' in logit_fn_args:
    kwargs['params'] = params
  if 'config' in logit_fn_args:
    kwargs['config'] = config
  logit_fn_results = logit_fn(features=features, **kwargs)

  result_is_valid_dictionary = (
      isinstance(logit_fn_results, dict) and
      all([(isinstance(k, str) and isinstance(v, ops.Tensor))
           for k, v in six.iteritems(logit_fn_results)]))
  result_is_tensor = isinstance(logit_fn_results, ops.Tensor)

  if not (result_is_valid_dictionary or result_is_tensor):
    raise ValueError('logit_fn should return a Tensor or a dictionary mapping '
                     'strings to Tensors.  logit_fn returned: %s' %
                     logit_fn_results)

  return logit_fn_results
Example #24
def _validate_loss_fn_args(loss_fn):
  """Validates loss_fn arguments.

  Required arguments: labels, logits.
  Optional arguments: features.

  Args:
    loss_fn: The loss function.
  Raises:
    ValueError: If the signature is unexpected.
  """
  loss_fn_args = util.fn_args(loss_fn)
  for required_arg in ['labels', 'logits']:
    if required_arg not in loss_fn_args:
      raise ValueError(
          'loss_fn must contain argument: {}. '
          'Given arguments: {}'.format(required_arg, loss_fn_args))
  invalid_args = list(set(loss_fn_args) - set(['labels', 'logits', 'features']))
  if invalid_args:
    raise ValueError('loss_fn has unexpected args: {}'.format(invalid_args))
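A quick hedged illustration of these rules (function bodies elided):

```python
def ok_loss_fn(labels, logits):
  ...  # required args only: passes

def ok_loss_fn_with_features(labels, logits, features):
  ...  # optional 'features': passes

def bad_loss_fn(labels, logits, weights):
  ...  # unexpected 'weights'

_validate_loss_fn_args(ok_loss_fn)
_validate_loss_fn_args(ok_loss_fn_with_features)
_validate_loss_fn_args(bad_loss_fn)  # ValueError: loss_fn has unexpected args: ['weights']
```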
Example #25
def _call_model_fn(model_fn, features, labels, mode, config, params,
                   require_params=False):
  """Calls the model_fn with required parameters."""
  model_fn_args = util.fn_args(model_fn)
  kwargs = {}
  if 'labels' in model_fn_args:
    kwargs['labels'] = labels
  else:
    if labels is not None:
      raise ValueError(
          'model_fn does not take labels, but input_fn returns labels.')
  if 'mode' in model_fn_args:
    kwargs['mode'] = mode
  if 'config' in model_fn_args:
    kwargs['config'] = config
  if 'params' in model_fn_args:
    kwargs['params'] = params
  elif require_params:
    raise ValueError(
        'model_fn ({}) does not include params argument, '
        'required by TPUEstimator to pass batch size as '
        'params[\'batch_size\']'.format(model_fn))
  return model_fn(features=features, **kwargs)
Example #26
def _validate_properties(run_config):
  """Validates the properties."""
  def _validate(property_name, cond, message):
    property_value = getattr(run_config, property_name)
    if property_value is not None and not cond(property_value):
      raise ValueError(message)

  _validate('model_dir', lambda dir: dir,
            message='model_dir should be non-empty')

  _validate('save_summary_steps', lambda steps: steps >= 0,
            message='save_summary_steps should be >= 0')

  _validate('save_checkpoints_steps', lambda steps: steps >= 0,
            message='save_checkpoints_steps should be >= 0')
  _validate('save_checkpoints_secs', lambda secs: secs >= 0,
            message='save_checkpoints_secs should be >= 0')

  _validate('session_config',
            lambda sc: isinstance(sc, config_pb2.ConfigProto),
            message='session_config must be instance of ConfigProto')

  _validate('keep_checkpoint_max', lambda keep_max: keep_max >= 0,
            message='keep_checkpoint_max should be >= 0')
  _validate('keep_checkpoint_every_n_hours', lambda keep_hours: keep_hours > 0,
            message='keep_checkpoint_every_n_hours should be > 0')
  _validate('log_step_count_steps', lambda num_steps: num_steps > 0,
            message='log_step_count_steps should be > 0')

  _validate('tf_random_seed', lambda seed: isinstance(seed, six.integer_types),
            message='tf_random_seed must be integer.')

  _validate('device_fn', lambda device_fn: six.callable(device_fn) and
            set(util.fn_args(device_fn)) == _VALID_DEVICE_FN_ARGS,
            message='device_fn must be callable with exactly'
                    ' one argument "op".')
Example #27
def _get_loss_towers(model_fn,
                     mode,
                     features,
                     labels,
                     params,
                     config,
                     devices,
                     local_ps_devices,
                     loss_reduction,
                     name_scope_pattern=_DEFAULT_NAME_SCOPE_PATTERN):
    """Replicate the loss computation across devices."""
    tower_specs = []

    model_fn_args = util.fn_args(model_fn)
    optional_params = {}
    if 'params' in model_fn_args:
        optional_params['params'] = copy.deepcopy(params)
    if 'config' in model_fn_args:
        optional_params['config'] = copy.deepcopy(config)

    # pylint: disable=protected-access
    round_robin_strategy = device_setter_lib._RoundRobinStrategy(
        num_tasks=len(local_ps_devices))
    TowerOptimizer._graph_state().set_reduction_across_towers(
        loss_reduction, len(devices))

    for i, device in enumerate(devices):
        is_the_first_tower = (i == 0)

        device_setter = _local_device_setter(worker_device=device,
                                             ps_devices=local_ps_devices,
                                             ps_strategy=round_robin_strategy)

        # We would like to preserve the names of the variables and ops that the user
        # might be relying on. Names without a prefix are going to resolve to
        # variables and ops of the first tower.
        name_scope = name_scope_pattern
        if is_the_first_tower:
            name_scope = ''

        with variable_scope.variable_scope(
                '', reuse=not is_the_first_tower) as var_scope:
            with ops_lib.name_scope(name_scope.format(i)) as name_scope:
                with TowerOptimizer._graph_state().tower(
                        tower_id=i, var_scope=var_scope,
                        name_scope=name_scope):
                    with ops_lib.device(device_setter):
                        labels_shard = None
                        if labels:
                            labels_shard = labels[i]

                        tower_spec = model_fn(mode=mode,
                                              features=features[i],
                                              labels=labels_shard,
                                              **optional_params)

                        if (tower_spec.train_op is not None
                                and len(devices) > 1
                                and not TowerOptimizer.has_been_used()):
                            raise ValueError(
                                'Please wrap optimizers with TowerOptimizer'
                                ' in order to use replicate_model_fn with'
                                ' multiple `devices`.')

                        # Scaling the loss here doesn't actually affect gradients.  Another
                        # instance of scaling happens inside the TowerOptimizer.
                        tower_spec = _scale_tower_loss(
                            tower_spec,
                            loss_reduction,
                            number_of_towers=len(devices))
                        tower_specs.append(tower_spec)

    if not TowerOptimizer._did_towers_have_same_optimizer_calls():
        raise ValueError(
            'Each invocation of model_fn was supposed to make the same'
            ' optimizer calls.')
    TowerOptimizer._clear_graph_state()
    # pylint: enable=protected-access
    return tower_specs
Example #28
    def dynamic_decode_and_search(self,
                                  embedding,
                                  start_tokens,
                                  end_token,
                                  vocab_size=None,
                                  initial_state=None,
                                  output_layer=None,
                                  beam_width=5,
                                  length_penalty=0.0,
                                  maximum_iterations=250,
                                  mode=tf.estimator.ModeKeys.PREDICT,
                                  memory=None,
                                  memory_sequence_length=None,
                                  dtype=None,
                                  return_alignment_history=False):
        if (return_alignment_history and "reorder_tensor_arrays"
                not in fn_args(tf.contrib.seq2seq.BeamSearchDecoder.__init__)):
            tf.logging.warn(
                "The current version of tf.contrib.seq2seq.BeamSearchDecoder "
                "does not support returning the alignment history. None will "
                "be returned instead. Consider upgrading TensorFlow.")
            alignment_history = False
        else:
            alignment_history = return_alignment_history

        batch_size = tf.shape(start_tokens)[0]

        # Replicate batch `beam_width` times.
        if initial_state is not None:
            initial_state = tf.contrib.seq2seq.tile_batch(
                initial_state, multiplier=beam_width)
        if memory is not None:
            memory = tf.contrib.seq2seq.tile_batch(memory,
                                                   multiplier=beam_width)
        if memory_sequence_length is not None:
            memory_sequence_length = tf.contrib.seq2seq.tile_batch(
                memory_sequence_length, multiplier=beam_width)

        cell, initial_state = self._build_cell(
            mode,
            batch_size * beam_width,
            initial_state=initial_state,
            memory=memory,
            memory_sequence_length=memory_sequence_length,
            dtype=dtype,
            alignment_history=alignment_history)

        if output_layer is None:
            output_layer = build_output_layer(self.num_units,
                                              vocab_size,
                                              dtype=dtype or memory.dtype)

        decoder = tf.contrib.seq2seq.BeamSearchDecoder(
            cell,
            embedding,
            start_tokens,
            end_token,
            initial_state,
            beam_width,
            output_layer=output_layer,
            length_penalty_weight=length_penalty)

        outputs, beam_state, length = tf.contrib.seq2seq.dynamic_decode(
            decoder, maximum_iterations=maximum_iterations)

        predicted_ids = tf.transpose(outputs.predicted_ids, perm=[0, 2, 1])
        log_probs = beam_state.log_probs
        state = beam_state.cell_state

        if return_alignment_history:
            alignment_history = _get_alignment_history(state)
            if alignment_history is not None:
                alignment_history = tf.reshape(
                    alignment_history,
                    [-1, batch_size, beam_width,
                     tf.shape(memory)[1]])
            return (predicted_ids, state, length, log_probs, alignment_history)
        return (predicted_ids, state, length, log_probs)
Example #29
    def run_step_fn(self, step_fn):
        """Run ops using a step function.

    Args:
      step_fn: A function or a method with a single argument of type
        `StepContext`.  The function may use methods of the argument to
        perform computations with access to a raw session.

        The returned value of the `step_fn` will be returned from `run_step_fn`,
        unless a stop is requested.  In that case, the next `should_stop` call
        will return True.

        Example usage:

        ```python
           with tf.Graph().as_default():
             c = tf.placeholder(dtypes.float32)
             v = tf.add(c, 4.0)
             w = tf.add(c, 0.5)

             def step_fn(step_context):
               a = step_context.session.run(fetches=v, feed_dict={c: 0.5})
               if a <= 4.5:
                 step_context.request_stop()
               return step_context.run_with_hooks(fetches=w, feed_dict={c: 0.1})

             with tf.MonitoredSession() as session:
               while not session.should_stop():
                 a = session.run_step_fn(step_fn)
        ```

        Hooks interact with the `run_with_hooks()` call inside the `step_fn`
        as they do with a `MonitoredSession.run` call.

    Returns:
      Returns the returned value of `step_fn`.

    Raises:
      StopIteration: if `step_fn` has called `request_stop()`.  It may be
        caught by `with tf.MonitoredSession()` to close the session.
      ValueError: if `step_fn` doesn't have a single argument called
        `step_context`. It may also optionally have `self` for cases when it
        belongs to an object.
    """
        step_fn_arguments = util.fn_args(step_fn)
        if step_fn_arguments != ('step_context', ) and step_fn_arguments != (
                'self',
                'step_context',
        ):
            raise ValueError(
                '`step_fn` may either have one `step_context` argument, or'
                ' `self` and `step_context` arguments if it\'s an instance'
                ' method. Got {} instead.'.format(step_fn_arguments))

        # `self._sess` is either `_RecoverableSession` or a `_CoordinatedSession`.
        # Setting `run_with_hooks` to `None` will cause `run_with_hooks` to be
        # `_CoordinatedSession.run` downstream in either case. This allows
        # `_PREEMPTION_ERRORS` to propagate from within `step_fn` to
        # `_RecoverableSession.run_step_fn`.
        return self._sess.run_step_fn(step_fn,
                                      self._tf_sess(),
                                      run_with_hooks=None)
Example #30
def _verify_metric_fn_args(metric_fn):
  args = set(estimator_util.fn_args(metric_fn))
  invalid_args = list(args - _VALID_METRIC_FN_ARGS)
  if invalid_args:
    raise ValueError('metric_fn (%s) has following not expected args: %s' %
                     (metric_fn, invalid_args))
Example #31
def _get_loss_towers(model_fn,
                     mode,
                     features,
                     labels,
                     params,
                     config,
                     devices,
                     local_ps_devices,
                     loss_reduction,
                     name_scope_pattern=_DEFAULT_NAME_SCOPE_PATTERN):
  """Replicate the loss computation across devices."""
  tower_specs = []

  model_fn_args = util.fn_args(model_fn)
  optional_params = {}
  if 'params' in model_fn_args:
    optional_params['params'] = copy.deepcopy(params)
  if 'config' in model_fn_args:
    optional_params['config'] = copy.deepcopy(config)

  # pylint: disable=protected-access
  round_robin_strategy = device_setter_lib._RoundRobinStrategy(
      num_tasks=len(local_ps_devices))
  TowerOptimizer._graph_state().set_reduction_across_towers(
      loss_reduction, len(devices))

  for i, device in enumerate(devices):
    is_the_first_tower = (i == 0)

    device_setter = _local_device_setter(
        worker_device=device,
        ps_devices=local_ps_devices,
        ps_strategy=round_robin_strategy)

    # We would like to preserve the names of the variables and ops that the user
    # might be relying on. Names without a prefix are going to resolve to
    # variables and ops of the first tower.
    name_scope = name_scope_pattern
    if is_the_first_tower:
      name_scope = ''

    with variable_scope.variable_scope(
        '', reuse=not is_the_first_tower) as var_scope:
      with ops_lib.name_scope(name_scope.format(i)) as name_scope:
        with TowerOptimizer._graph_state().tower(
            tower_id=i, var_scope=var_scope, name_scope=name_scope):
          with ops_lib.device(device_setter):
            labels_shard = None
            if labels:
              labels_shard = labels[i]

            tower_spec = model_fn(
                mode=mode,
                features=features[i],
                labels=labels_shard,
                **optional_params)

            if (tower_spec.train_op is not None and len(devices) > 1 and
                not TowerOptimizer.has_been_used()):
              raise ValueError('Please wrap optimizers with TowerOptimizer'
                               ' in order to use replicate_model_fn with'
                               ' multiple `devices`.')

            # Scaling the loss here doesn't actually affect gradients.  Another
            # instance of scaling happens inside the TowerOptimizer.
            tower_spec = _scale_tower_loss(
                tower_spec, loss_reduction, number_of_towers=len(devices))
            tower_specs.append(tower_spec)

  if not TowerOptimizer._did_towers_have_same_optimizer_calls():
    raise ValueError('Each invocation of model_fn was supposed to make the same'
                     ' optimizer calls.')
  TowerOptimizer._clear_graph_state()
  # pylint: enable=protected-access
  return tower_specs
Example #32
  def __call__(self, inputs, *args, **kwargs):
    """Wraps `call`, applying pre- and post-processing steps.

    Arguments:
      inputs: input tensor(s).
      *args: additional positional arguments to be passed to `self.call`.
      **kwargs: additional keyword arguments to be passed to `self.call`.
        **Note**: kwarg `scope` is reserved for use by the layer.

    Returns:
      Output tensor(s).

    Note:
      - If the layer's `call` method takes a `scope` keyword argument,
        this argument will be automatically set to the current variable scope.
      - If the layer's `call` method takes a `mask` argument (as some Keras
        layers do), its default value will be set to the mask generated
        for `inputs` by the previous layer (if `inputs` did come from
        a layer that generated a corresponding mask, i.e. if it came from
        a Keras layer with masking support).

    Raises:
      ValueError: if the layer's `call` method returns None (an invalid value).
    """
    self._set_scope(kwargs.pop('scope', None))

    if not context.executing_eagerly():
      try:
        # Set layer's "graph" at build time
        self._graph = ops._get_graph_from_inputs(nest.flatten(inputs),  # pylint: disable=protected-access
                                                 graph=self._graph)
      except ValueError as e:
        raise ValueError('Input graph and Layer graph are not the same: %s' % e)

    if self.built:
      try:
        # Some classes which inherit from Layer do not use its constructor, so
        # rather than initializing to None we check for an AttributeError.
        scope_context_manager = self._always_reuse_variable_scope
      except AttributeError:
        # From this point we will always set reuse=True, so create a "final"
        # variable scope with this setting. We avoid re-creating variable scopes
        # after this point as an optimization.
        self._always_reuse_variable_scope = vs.variable_scope(
            self._scope, reuse=True, auxiliary_name_scope=False)
        scope_context_manager = self._always_reuse_variable_scope
    else:
      scope_context_manager = vs.variable_scope(
          self._scope, reuse=self._reuse, auxiliary_name_scope=False)

    with scope_context_manager as scope:
      self._current_scope = scope

      try:
        call_has_scope_arg = self._call_has_scope_arg
      except AttributeError:
        self._call_fn_args = estimator_util.fn_args(self.call)
        self._call_has_scope_arg = 'scope' in self._call_fn_args
        call_has_scope_arg = self._call_has_scope_arg
      if call_has_scope_arg:
        kwargs['scope'] = scope

      # Actually call layer
      outputs = super(Layer, self).__call__(inputs, *args, **kwargs)

    if not context.executing_eagerly():
      # Update global default collections.
      _add_elements_to_collection(self.updates, ops.GraphKeys.UPDATE_OPS)
    return outputs
Example #33
  def _run_internal_graph(self, inputs, masks=None):
    """Computes output tensors for new inputs.

    # Note:
        - Expects `inputs` to be a list (potentially with 1 element).
        - Can be run on non-Keras tensors.

    Arguments:
        inputs: List of tensors
        masks: List of masks (tensors or None).

    Returns:
        Two lists: output_tensors, output_masks.
    """
    # Note: masking support is relevant mainly for Keras.
    # It cannot be factored out without having to fully reimplement the network
    # calling logic on the Keras side. We choose to incorporate it in
    # GraphNetwork because 1) it may be useful to fully support in tf.layers in
    # the future and 2) Keras is a major user of GraphNetwork.  If you don't
    # use masking, it does not interfere with regular behavior at all and you
    # can ignore it.
    if masks is None:
      masks = [None for _ in range(len(inputs))]

    # Dictionary mapping reference tensors to tuples
    # (computed tensor, compute mask)
    # we assume a 1:1 mapping from tensor to mask
    # TODO(fchollet): raise exception when a `.compute_mask()` call
    # does not return a list the same size as `call`
    tensor_map = {}
    for x, y, mask in zip(self.inputs, inputs, masks):
      tensor_map[str(id(x))] = (y, mask)

    depth_keys = list(self._nodes_by_depth.keys())
    depth_keys.sort(reverse=True)
    for depth in depth_keys:
      nodes = self._nodes_by_depth[depth]
      for node in nodes:
        # This is always a single layer, never a list.
        layer = node.outbound_layer

        reference_input_tensors = node.input_tensors
        reference_output_tensors = node.output_tensors

        # If all previous input tensors are available in tensor_map,
        # then call node.inbound_layer on them.
        computed_data = []  # List of tuples (input, mask).
        for x in reference_input_tensors:
          if str(id(x)) in tensor_map:
            computed_data.append(tensor_map[str(id(x))])

        if len(computed_data) == len(reference_input_tensors):
          # Call layer (reapplying ops to new inputs).
          with ops.name_scope(layer.name):
            if node.arguments:
              kwargs = node.arguments
            else:
              kwargs = {}
            if len(computed_data) == 1:
              computed_tensor, computed_mask = computed_data[0]
              # Ensure mask propagation if applicable.
              if 'mask' in estimator_util.fn_args(layer.call):
                if 'mask' not in kwargs:
                  kwargs['mask'] = computed_mask

              output_tensors = nest.flatten(
                  layer.call(computed_tensor, **kwargs))
              if hasattr(layer, 'compute_mask'):
                output_masks = nest.flatten(
                    layer.compute_mask(computed_tensor, computed_mask))
              else:
                output_masks = [None for _ in range(len(output_tensors))]
              computed_tensors = [computed_tensor]
              computed_masks = [computed_mask]
            else:
              computed_tensors = [x[0] for x in computed_data]
              computed_masks = [x[1] for x in computed_data]
              if 'mask' in estimator_util.fn_args(layer.call):
                if 'mask' not in kwargs:
                  kwargs['mask'] = computed_masks
              output_tensors = nest.flatten(
                  layer.call(computed_tensors, **kwargs))
              if hasattr(layer, 'compute_mask'):
                output_masks = nest.flatten(
                    layer.compute_mask(computed_tensors, computed_masks))
              else:
                output_masks = [None for _ in range(len(output_tensors))]

            # Apply activity regularizer if any:
            if layer.activity_regularizer is not None:
              regularization_losses = [
                  layer.activity_regularizer(x) for x in computed_tensors
              ]
              layer.add_loss(regularization_losses, computed_tensors)

          if context.in_graph_mode():
            # Update model updates and losses:
            # Keep track of updates that depend on the inputs
            # (e.g. BN updates).
            self.add_update(layer.get_updates_for(computed_tensors), inputs)
            # Keep track of unconditional updates (e.g. a counter).
            self.add_update(layer.get_updates_for(None), None)
            # Keep track of losses that depend on the inputs
            # (e.g. activity regularizers).
            self.add_loss(layer.get_losses_for(computed_tensors), inputs)
            # Keep track of unconditional losses
            # (e.g. weight regularizers).
            self.add_loss(layer.get_losses_for(None), None)

          # Update tensor_map.
          for x, y, mask in zip(reference_output_tensors, output_tensors,
                                output_masks):
            tensor_map[str(id(x))] = (y, mask)

    output_tensors = []
    output_masks = []
    output_shapes = []
    for x in self.outputs:
      assert str(id(x)) in tensor_map, 'Could not compute output ' + str(x)
      tensor, mask = tensor_map[str(id(x))]
      output_shapes.append(layers_util.static_shape(x))
      output_tensors.append(tensor)
      output_masks.append(mask)

    if len(output_tensors) == 1:
      output_tensors = output_tensors[0]
      if output_shapes is not None:
        output_shapes = output_shapes[0]
      if output_masks is not None:
        output_masks = output_masks[0]

    if context.in_graph_mode():
      # Update cache;
      # keys are based on ids on input tensors and inputs masks.
      cache_key = (layers_util.object_list_uid(inputs)
                   + '_' + layers_util.object_list_uid(masks))
      self._output_tensor_cache[cache_key] = output_tensors
      if output_masks is not None:
        self._output_mask_cache[cache_key] = output_masks
      if output_shapes is not None:
        input_shapes = [layers_util.static_shape(x) for x in inputs]
        cache_key = layers_util.object_list_uid(input_shapes)
        self._output_shape_cache[cache_key] = output_shapes

    return output_tensors, output_masks
Example #34
def _call_optimizer_fn(optimizer_fn, params):
  arguments = {}
  optimizer_fn_arguments = util.fn_args(optimizer_fn)
  if 'params' in optimizer_fn_arguments:
    arguments['params'] = params
  return optimizer_fn(**arguments)
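A hedged usage sketch (both optimizer_fns are invented; assumes TF 1.x's `tf.train.GradientDescentOptimizer`):

```python
import tensorflow as tf

def optimizer_fn_with_params(params):
  return tf.train.GradientDescentOptimizer(params['learning_rate'])

def optimizer_fn_plain():
  return tf.train.GradientDescentOptimizer(0.05)

# `params` is forwarded only when the optimizer_fn declares it.
opt_a = _call_optimizer_fn(optimizer_fn_with_params, {'learning_rate': 0.05})
opt_b = _call_optimizer_fn(optimizer_fn_plain, {'learning_rate': 0.05})
```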
Example #35
  def _run_internal_graph(self, inputs, masks=None):
    """Computes output tensors for new inputs.

    # Note:
        - Expects `inputs` to be a list (potentially with 1 element).
        - Can be run on non-Keras tensors.

    Arguments:
        inputs: List of tensors
        masks: List of masks (tensors or None).

    Returns:
        Two lists: output_tensors, output_masks.
    """
    # Note: masking support is relevant mainly for Keras.
    # It cannot be factored out without having to fully reimplement the network
    # calling logic on the Keras side. We choose to incorporate it in
    # GraphNetwork because 1) it may be useful to fully support in tf.layers in
    # the future and 2) Keras is a major user of GraphNetwork.  If you don't
    # use masking, it does not interfere with regular behavior at all and you
    # can ignore it.
    if masks is None:
      masks = [None for _ in range(len(inputs))]

    # Dictionary mapping reference tensors to tuples
    # (computed tensor, compute mask)
    # we assume a 1:1 mapping from tensor to mask
    # TODO(fchollet): raise exception when a `.compute_mask()` call
    # does not return a list the same size as `call`
    tensor_map = {}
    for x, y, mask in zip(self.inputs, inputs, masks):
      tensor_map[str(id(x))] = (y, mask)

    depth_keys = list(self._nodes_by_depth.keys())
    depth_keys.sort(reverse=True)
    for depth in depth_keys:
      nodes = self._nodes_by_depth[depth]
      for node in nodes:
        # This is always a single layer, never a list.
        layer = node.outbound_layer
        reference_input_tensors = node.input_tensors
        reference_output_tensors = node.output_tensors

        # If all previous input tensors are available in tensor_map,
        # then call node.inbound_layer on them.
        computed_data = []  # List of tuples (input, mask).
        for x in reference_input_tensors:
          if str(id(x)) in tensor_map:
            computed_data.append(tensor_map[str(id(x))])

        if len(computed_data) == len(reference_input_tensors):
          # Call layer (reapplying ops to new inputs).
          with ops.name_scope(layer.name):
            if node.arguments:
              kwargs = node.arguments
            else:
              kwargs = {}
            if len(computed_data) == 1:
              computed_tensor, computed_mask = computed_data[0]
              # Ensure mask propagation if applicable.
              if 'mask' in estimator_util.fn_args(layer.call):
                if 'mask' not in kwargs:
                  kwargs['mask'] = computed_mask

              output_tensors = nest.flatten(
                  layer.call(computed_tensor, **kwargs))
              if hasattr(layer, 'compute_mask'):
                output_masks = nest.flatten(
                    layer.compute_mask(computed_tensor, computed_mask))
              else:
                output_masks = [None for _ in range(len(output_tensors))]
              computed_tensors = [computed_tensor]
              computed_masks = [computed_mask]
            else:
              computed_tensors = [x[0] for x in computed_data]
              computed_masks = [x[1] for x in computed_data]
              if 'mask' in estimator_util.fn_args(layer.call):
                if 'mask' not in kwargs:
                  kwargs['mask'] = computed_masks
              output_tensors = nest.flatten(
                  layer.call(computed_tensors, **kwargs))
              if hasattr(layer, 'compute_mask'):
                output_masks = nest.flatten(
                    layer.compute_mask(computed_tensors, computed_masks))
              else:
                output_masks = [None for _ in range(len(output_tensors))]

            if context.in_graph_mode():
              if layer.activity_regularizer is not None:
                regularization_losses = [
                    layer.activity_regularizer(x) for x in computed_tensors
                ]
                # Apply activity regularizer if any:
                layer.add_loss(regularization_losses, computed_tensors)

          if context.in_graph_mode():
            # Update model updates and losses:
            # Keep track of updates that depend on the inputs
            # (e.g. BN updates).
            self.add_update(layer.get_updates_for(computed_tensors), inputs)
            # Keep track of unconditional updates (e.g. a counter).
            self.add_update(layer.get_updates_for(None), None)
            # Keep track of losses that depend on the inputs
            # (e.g. activity regularizers).
            self.add_loss(layer.get_losses_for(computed_tensors), inputs)
            # Keep track of unconditional losses
            # (e.g. weight regularizers).
            self.add_loss(layer.get_losses_for(None), None)

          # Update tensor_map.
          for x, y, mask in zip(reference_output_tensors, output_tensors,
                                output_masks):
            tensor_map[str(id(x))] = (y, mask)

    output_tensors = []
    output_masks = []
    output_shapes = []
    for x in self.outputs:
      assert str(id(x)) in tensor_map, 'Could not compute output ' + str(x)
      tensor, mask = tensor_map[str(id(x))]
      output_shapes.append(layers_util.static_shape(x))
      output_tensors.append(tensor)
      output_masks.append(mask)

    if len(output_tensors) == 1:
      output_tensors = output_tensors[0]
      if output_shapes is not None:
        output_shapes = output_shapes[0]
      if output_masks is not None:
        output_masks = output_masks[0]

    if context.in_graph_mode():
      # Update cache;
      # keys are based on ids on input tensors and inputs masks.
      cache_key = (layers_util.object_list_uid(inputs)
                   + '_' + layers_util.object_list_uid(masks))
      self._output_tensor_cache[cache_key] = output_tensors
      if output_masks is not None:
        self._output_mask_cache[cache_key] = output_masks
      if output_shapes is not None:
        input_shapes = [layers_util.static_shape(x) for x in inputs]
        cache_key = layers_util.object_list_uid(input_shapes)
        self._output_shape_cache[cache_key] = output_shapes

    return output_tensors, output_masks
Example #36
  def __init__(self, inputs, outputs, name=None):  # pylint: disable=super-init-not-called
    if context.in_eager_mode():
      # TODO(fchollet): check that all inputs and outputs are DeferredTensors.
      pass

    self._init_set_name(name)
    self._activity_regularizer = None
    with vs.variable_scope(
        None, default_name=self._base_name) as captured_scope:
      self._scope = captured_scope
    call_fn_args = estimator_util.fn_args(self.call)
    self._compute_previous_mask = ('mask' in call_fn_args or
                                   hasattr(self, 'compute_mask'))
    self._call_has_scope_arg = 'scope' in call_fn_args

    # This acts just like the `trainable` attribute of any layer instance.
    # It does not affect users of the underlying layers, only users of the
    # GraphNetwork instance.
    self.trainable = True
    # A GraphNetwork does not create weights of its own, thus it is already
    # built.
    self.built = True
    # A GraphNetwork does not create weights of its own, thus has no dtype.
    self._dtype = None
    # The following are implemented as property functions:
    # self.trainable_weights
    # self.non_trainable_weights
    # self.input_spec

    # Private attributes to implement compatibility with Layer.
    self._per_input_losses = {}
    self._per_input_updates = {}
    self._updates = []
    self._losses = []
    self._scope = None
    self._reuse = None
    self._graph = ops.get_default_graph()

    # GraphNetwork-specific properties.
    if isinstance(inputs, (list, tuple)):
      self.inputs = list(inputs)  # Tensor or list of tensors.
    else:
      self.inputs = [inputs]
    if isinstance(outputs, (list, tuple)):
      self.outputs = list(outputs)
    else:
      self.outputs = [outputs]
    # All layers in order of horizontal graph traversal.
    # Entries are unique. Includes input and output layers.
    self.layers = []

    # Check for redundancy in inputs.
    if len(set(self.inputs)) != len(self.inputs):
      raise ValueError('The list of inputs passed to the model '
                       'is redundant. '
                       'All inputs should only appear once.'
                       ' Found: ' + str(self.inputs))

    # # List of initial layers (1 to 1 mapping with self.inputs,
    # # hence the same layer might appear twice)
    # self._input_layers = []
    # self._input_layers_node_indices = []
    # self._input_layers_tensor_indices = []
    # # list of layers (1 to 1 mapping with self.inputs,
    # # hence the same layer might appear twice)
    # self._output_layers = []
    # self._output_layers_node_indices = []
    # self._output_layers_tensor_indices = []

    self._input_layers = []
    self._output_layers = []
    self._input_coordinates = []
    self._output_coordinates = []

    # This is for performance optimization when calling the GraphNetwork on new
    # inputs. Every time the GraphNetwork is called on a set on input tensors,
    # we compute the output tensors, output masks and output shapes in one pass,
    # then cache them here. When any of these outputs is queried later, we
    # retrieve it from there instead of recomputing it.
    self._output_mask_cache = {}
    self._output_tensor_cache = {}
    self._output_shape_cache = {}

    # User-provided arguments validation.
    for x in self.inputs:
      # Check that x has appropriate `_keras_history` metadata.
      if not hasattr(x, '_keras_history'):
        cls_name = self.__class__.__name__
        raise ValueError('Input tensors to a ' + cls_name + ' ' +
                         'must come from `tf.layers.Input`. '
                         'Received: ' + str(x) +
                         ' (missing previous layer metadata).')
      # Check that x is an input tensor.
      # pylint: disable=protected-access
      layer, node_index, tensor_index = x._keras_history
      if len(layer._inbound_nodes) > 1 or (
          layer._inbound_nodes and layer._inbound_nodes[0].inbound_layers):
        cls_name = self.__class__.__name__
        logging.warning(cls_name + ' inputs must come from '
                        '`tf.layers.Input` (thus holding past layer metadata), '
                        'they cannot be the output of '
                        'a previous non-Input layer. '
                        'Here, a tensor specified as '
                        'input to "' + self.name + '" was not an Input tensor; '
                        'it was generated by layer ' + layer.name + '.\n'
                        'Note that input tensors are '
                        'instantiated via `tensor = tf.layers.Input(shape)`.\n'
                        'The tensor that caused the issue was: ' + str(x.name))
      # pylint: enable=protected-access
    for x in self.outputs:
      if not hasattr(x, '_keras_history'):
        cls_name = self.__class__.__name__
        raise ValueError('Output tensors to a ' + cls_name + ' must be '
                         'the output of a TensorFlow `Layer` '
                         '(thus holding past layer metadata). Found: ' + str(x))

    # Build self._output_layers:
    for x in self.outputs:
      layer, node_index, tensor_index = x._keras_history  # pylint: disable=protected-access
      self._output_layers.append(layer)
      self._output_coordinates.append((layer, node_index, tensor_index))

    # Build self._input_layers:
    for x in self.inputs:
      layer, node_index, tensor_index = x._keras_history  # pylint: disable=protected-access
      # It's supposed to be an input layer, so only one node
      # and one tensor output.
      assert node_index == 0
      assert tensor_index == 0
      self._input_layers.append(layer)
      self._input_coordinates.append((layer, node_index, tensor_index))

    # Network_nodes: set of nodes included in the graph
    # (not all nodes included in the layers
    # are relevant to the current graph).
    network_nodes = set()  # ids of all nodes relevant to the GraphNetwork
    nodes_depths = {}  # dict {node: depth value}
    layers_depths = {}  # dict {layer: depth value}
    layer_indices = {}  # dict {layer: index in traversal}
    nodes_in_decreasing_depth = []

    def build_map_of_graph(tensor,
                           finished_nodes,
                           nodes_in_progress,
                           layer,
                           node_index,
                           tensor_index):
      """Builds a map of the graph of layers.

      This recursively updates the map `layer_indices`,
      the list `nodes_in_decreasing_depth` and the set `network_nodes`.

      Arguments:
          tensor: Some tensor in a graph.
          finished_nodes: Set of nodes whose subgraphs have been traversed
              completely. Useful to prevent duplicated work.
          nodes_in_progress: Set of nodes that are currently active on the
              recursion stack. Useful to detect cycles.
          layer: Layer from which `tensor` originates. If not provided,
              will be obtained from `tensor._keras_history`.
          node_index: Node index from which `tensor` originates.
          tensor_index: Tensor index from which `tensor` originates.

      Raises:
          ValueError: if a cycle is detected.
      """
      node = layer._inbound_nodes[node_index]  # pylint: disable=protected-access

      # Prevent cycles.
      if node in nodes_in_progress:
        raise ValueError('The tensor ' + str(tensor) + ' at layer "' +
                         layer.name + '" is part of a cycle.')

      # Don't repeat work for shared subgraphs
      if node in finished_nodes:
        return

      node_key = _make_node_key(layer.name, node_index)
      # Update network_nodes.
      network_nodes.add(node_key)

      # Store the traversal order for layer sorting.
      if layer not in layer_indices:
        layer_indices[layer] = len(layer_indices)

      nodes_in_progress.add(node)

      # Propagate to all previous tensors connected to this node.
      for i in range(len(node.inbound_layers)):
        x = node.input_tensors[i]
        layer = node.inbound_layers[i]
        node_index = node.node_indices[i]
        tensor_index = node.tensor_indices[i]
        build_map_of_graph(x, finished_nodes, nodes_in_progress, layer,
                           node_index, tensor_index)

      finished_nodes.add(node)
      nodes_in_progress.remove(node)
      nodes_in_decreasing_depth.append(node)

    finished_nodes = set()
    nodes_in_progress = set()
    for x in self.outputs:
      layer, node_index, tensor_index = x._keras_history  # pylint: disable=protected-access
      build_map_of_graph(x, finished_nodes, nodes_in_progress,
                         layer=layer,
                         node_index=node_index,
                         tensor_index=tensor_index)

    for node in reversed(nodes_in_decreasing_depth):
      # If the depth is not set, the node has no outbound nodes (depth 0).
      depth = nodes_depths.setdefault(node, 0)

      # Update the depth of the corresponding layer
      previous_depth = layers_depths.get(node.outbound_layer, 0)
      # If we've seen this layer before at a higher depth,
      # we should use that depth instead of the node depth.
      # This is necessary for shared layers that have inputs at different
      # depth levels in the graph.
      depth = max(depth, previous_depth)
      layers_depths[node.outbound_layer] = depth
      nodes_depths[node] = depth

      # Update the depth of inbound nodes.
      # The "depth" of a node is the max of the depths
      # of all layers it is connected to.
      for i in range(len(node.inbound_layers)):
        inbound_layer = node.inbound_layers[i]
        node_index = node.node_indices[i]
        inbound_node = inbound_layer._inbound_nodes[node_index]  # pylint: disable=protected-access
        previous_depth = nodes_depths.get(inbound_node, 0)
        nodes_depths[inbound_node] = max(depth + 1, previous_depth)

    # Build a dict {depth: list of nodes with this depth}
    nodes_by_depth = {}
    for node, depth in nodes_depths.items():
      if depth not in nodes_by_depth:
        nodes_by_depth[depth] = []
      nodes_by_depth[depth].append(node)

    # Build a dict {depth: list of layers with this depth}
    layers_by_depth = {}
    for layer, depth in layers_depths.items():
      if depth not in layers_by_depth:
        layers_by_depth[depth] = []
      layers_by_depth[depth].append(layer)

    # Get sorted list of layer depths.
    depth_keys = list(layers_by_depth.keys())
    depth_keys.sort(reverse=True)

    # Set self.layers and self._layers_by_depth.
    layers = []
    for depth in depth_keys:
      layers_for_depth = layers_by_depth[depth]
      # GraphNetwork.layers needs to have a deterministic order:
      # here we order them by traversal order.
      layers_for_depth.sort(key=lambda x: layer_indices[x])
      layers.extend(layers_for_depth)
    self.layers = layers
    self._layers_by_depth = layers_by_depth

    # Get sorted list of node depths.
    depth_keys = list(nodes_by_depth.keys())
    depth_keys.sort(reverse=True)

    # Check that all tensors required are computable.
    # computable_tensors: all tensors in the graph
    # that can be computed from the inputs provided.
    computable_tensors = []
    for x in self.inputs:
      computable_tensors.append(x)

    layers_with_complete_input = []  # To provide a better error message.
    for depth in depth_keys:
      for node in nodes_by_depth[depth]:
        layer = node.outbound_layer
        if layer:
          for x in node.input_tensors:
            if x not in computable_tensors:
              raise ValueError('Graph disconnected: '
                               'cannot obtain value for tensor ' + str(x) +
                               ' at layer "' + layer.name + '". '
                               'The following previous layers '
                               'were accessed without issue: ' +
                               str(layers_with_complete_input))
          for x in node.output_tensors:
            computable_tensors.append(x)
          layers_with_complete_input.append(layer.name)

    # Keep track of the network's nodes.
    self._network_nodes = network_nodes
    self._nodes_by_depth = nodes_by_depth

    # Ensure name uniqueness, which will be crucial for serialization
    # (since serialized nodes refer to layers by their name).
    all_names = [layer.name for layer in self.layers]
    for name in all_names:
      if all_names.count(name) != 1:
        raise ValueError('The name "' + name + '" is used ' +
                         str(all_names.count(name)) + ' times in the model. '
                         'All layer names should be unique.')

    # Layer parameters.
    # The new network starts with a single inbound node
    # for its inputs, and no outbound nodes.
    self._outbound_nodes = []  # Appended to by future calls to __call__.
    self._inbound_nodes = []  # Appended to below and by future calls to __call__.
    # Create the node linking internal inputs to internal outputs.
    base.Node(
        outbound_layer=self,
        inbound_layers=[],
        node_indices=[],
        tensor_indices=[],
        input_tensors=self.inputs,
        output_tensors=self.outputs)
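
The depth computation above (the `reversed(nodes_in_decreasing_depth)` loop)
is the core of the topological ordering: nodes are collected in post-order
from the outputs, then each inbound node is pushed at least one level deeper
than any node that consumes it. Below is a minimal standalone sketch of that
node-depth pass over a plain dict-of-lists graph; every name in it is
illustrative and none of it is TensorFlow API.

# Illustrative re-implementation of the node-depth pass, independent of Keras.
# `graph` maps each node to the list of nodes it consumes (its inputs).
def compute_depths(graph, outputs):
  order = []  # Reverse post-order, like `nodes_in_decreasing_depth`.
  visited, in_progress = set(), set()

  def visit(node):
    if node in in_progress:
      raise ValueError('cycle detected at %r' % (node,))
    if node in visited:
      return
    in_progress.add(node)
    for inbound in graph.get(node, []):
      visit(inbound)
    in_progress.remove(node)
    visited.add(node)
    order.append(node)

  for out in outputs:
    visit(out)

  depths = {}
  for node in reversed(order):
    depth = depths.setdefault(node, 0)
    for inbound in graph.get(node, []):
      # An input must sit at least one level deeper than its consumer.
      depths[inbound] = max(depth + 1, depths.get(inbound, 0))
  return depths

# Example: d consumes b and c, which both consume a.
# Expected depths: d=0, b=1, c=1, a=2.
print(compute_depths({'d': ['b', 'c'], 'b': ['a'], 'c': ['a']}, ['d']))
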
def _verify_metric_fn_args(metric_fn):
  args = set(estimator_util.fn_args(metric_fn))
  invalid_args = list(args - _VALID_METRIC_FN_ARGS)
  if invalid_args:
    raise ValueError('metric_fn (%s) has the following unexpected args: %s' %
                     (metric_fn, invalid_args))
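
`_verify_metric_fn_args` is a whitelist check on a user-supplied callable's
signature. Here is a minimal sketch of the same idea using only the
standard-library `inspect` module; the whitelist below is an assumed example,
not the actual `_VALID_METRIC_FN_ARGS` constant.

import inspect

# Assumed whitelist, for illustration only.
_VALID_ARGS = {'labels', 'predictions', 'features', 'config'}

def verify_fn_args(fn, valid_args=_VALID_ARGS):
  """Raises ValueError if `fn` declares arguments outside `valid_args`."""
  args = set(inspect.signature(fn).parameters)
  invalid = sorted(args - valid_args)
  if invalid:
    raise ValueError('metric_fn (%s) has the following unexpected args: %s' %
                     (fn, invalid))

def my_metric(labels, predictions):  # Both names are whitelisted.
  return labels == predictions

verify_fn_args(my_metric)  # Passes silently.

def bad_metric(labels, scores):  # 'scores' is not whitelisted.
  return scores

# verify_fn_args(bad_metric)  # Would raise ValueError.
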
  def _testDecoderGeneric(self,
                          decoder,
                          with_beam_search=False,
                          with_alignment_history=False,
                          support_alignment_history=True):
    batch_size = 4
    beam_width = 5
    num_hyps = beam_width if with_beam_search else 1
    vocab_size = 10
    depth = 6
    end_token = 2
    start_tokens = tf.placeholder_with_default([1] * batch_size, shape=[None])
    memory_sequence_length = [3, 7, 5, 4]
    memory_time = max(memory_sequence_length)
    memory = tf.placeholder_with_default(
        np.random.randn(batch_size, memory_time, depth).astype(np.float32),
        shape=(None, None, depth))
    memory_sequence_length = tf.placeholder_with_default(
        memory_sequence_length, shape=[None])
    embedding = tf.placeholder_with_default(
        np.random.randn(vocab_size, depth).astype(np.float32),
        shape=(vocab_size, depth))

    if with_beam_search:
      decode_fn = decoder.dynamic_decode_and_search
    else:
      decode_fn = decoder.dynamic_decode

    additional_kwargs = {}
    if with_alignment_history:
      additional_kwargs["return_alignment_history"] = True
    if with_beam_search:
      additional_kwargs["beam_width"] = beam_width

    if (with_beam_search and with_alignment_history
        and "RNN" in decoder.__class__.__name__
        and "reorder_tensor_arrays" not in fn_args(
            tf.contrib.seq2seq.BeamSearchDecoder.__init__)):
      support_alignment_history = False

    outputs = decode_fn(
        embedding,
        start_tokens,
        end_token,
        vocab_size=vocab_size,
        maximum_iterations=10,
        memory=memory,
        memory_sequence_length=memory_sequence_length,
        **additional_kwargs)

    ids = outputs[0]
    state = outputs[1]
    lengths = outputs[2]
    log_probs = outputs[3]

    decode_time = tf.shape(ids)[-1]

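    # Note: `self.test_session()` returns a cached session, so the variables
    # initialized here remain initialized in the sessions opened below.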
    with self.test_session() as sess:
      sess.run(tf.global_variables_initializer())

    if not with_alignment_history:
      self.assertEqual(4, len(outputs))
    else:
      self.assertEqual(5, len(outputs))
      alignment_history = outputs[4]
      if support_alignment_history:
        self.assertIsInstance(alignment_history, tf.Tensor)
        with self.test_session() as sess:
          alignment_history, decode_time = sess.run([alignment_history, decode_time])
          self.assertAllEqual(
              [batch_size, num_hyps, decode_time, memory_time], alignment_history.shape)
      else:
        self.assertIsNone(alignment_history)

    with self.test_session() as sess:
      ids, lengths, log_probs = sess.run([ids, lengths, log_probs])
      self.assertAllEqual([batch_size, num_hyps], ids.shape[0:2])
      self.assertAllEqual([batch_size, num_hyps], lengths.shape)
      self.assertAllEqual([batch_size, num_hyps], log_probs.shape)
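
The `fn_args(tf.contrib.seq2seq.BeamSearchDecoder.__init__)` check in the
test above is a feature-detection idiom: probe the installed TensorFlow for
the `reorder_tensor_arrays` constructor argument before relying on it,
instead of comparing version strings. The sketch below shows the same
pattern against a hypothetical callable; nothing in it is a real TensorFlow
API.

import inspect

def supports_kwarg(fn, name):
  """Returns True if callable `fn` accepts a keyword argument `name`."""
  params = inspect.signature(fn).parameters
  if name in params:
    return True
  # A **kwargs catch-all accepts any keyword.
  return any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values())

# Hypothetical decoder constructor, used only to demonstrate the pattern.
def make_decoder(cell, beam_width, reorder_tensor_arrays=False):
  return (cell, beam_width, reorder_tensor_arrays)

kwargs = {}
if supports_kwarg(make_decoder, 'reorder_tensor_arrays'):
  kwargs['reorder_tensor_arrays'] = True  # Only passed when supported.
decoder = make_decoder('cell', 5, **kwargs)
assert decoder == ('cell', 5, True)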