def add(self, layer_func):
  if isinstance(layer_func, base.Layer):
    args = estimator_util.fn_args(layer_func.call)
    self.track_layer(layer_func)
  elif callable(layer_func):
    args = estimator_util.fn_args(layer_func)
  else:
    raise TypeError(
        "Sequential.add() takes only tf.layers.Layer objects or callables; "
        "not '%s' of type '%s'." % (layer_func, type(layer_func)))
  self._layers_funcs.append((("training" in args), layer_func))
def _call_model_fn(self, features, labels, add_batch_size_in_params=False):
  """Calls the model_fn with required parameters."""
  model_fn_args = util.fn_args(self._model_fn)
  kwargs = {}

  config = copy.deepcopy(self._config)
  params = copy.deepcopy(self._params)

  if 'labels' in model_fn_args:
    kwargs['labels'] = labels
  else:
    if labels is not None:
      raise ValueError(
          'model_fn does not take labels, but input_fn returns labels.')
  if 'mode' in model_fn_args:
    kwargs['mode'] = self._mode
  if 'config' in model_fn_args:
    kwargs['config'] = config
  if 'params' in model_fn_args:
    kwargs['params'] = params

  if add_batch_size_in_params:
    if 'params' not in model_fn_args:
      raise ValueError(
          'model_fn ({}) does not include params argument, '
          'required by TPUEstimator to pass batch size as '
          'params[\'batch_size\']'.format(self._model_fn))
    if self._mode == model_fn_lib.ModeKeys.TRAIN:
      # For TPU training. `params` is never `None`.
      params[_BATCH_SIZE_KEY] = _per_shard_batch_size(self._train_batch_size,
                                                      config)

  return self._model_fn(features=features, **kwargs)
def _call_model_fn(self, features, labels, mode, config):
  """Calls model function.

  Args:
    features: features dict.
    labels: labels dict.
    mode: ModeKeys
    config: RunConfig

  Returns:
    An `EstimatorSpec` object.

  Raises:
    ValueError: if model_fn returns invalid objects.
  """
  model_fn_args = util.fn_args(self._model_fn)
  kwargs = {}
  if 'labels' in model_fn_args:
    kwargs['labels'] = labels
  else:
    if labels is not None:
      raise ValueError(
          'model_fn does not take labels, but input_fn returns labels.')
  if 'mode' in model_fn_args:
    kwargs['mode'] = mode
  if 'params' in model_fn_args:
    kwargs['params'] = self.params
  if 'config' in model_fn_args:
    kwargs['config'] = config
  model_fn_results = self._model_fn(features=features, **kwargs)

  if not isinstance(model_fn_results, model_fn_lib.EstimatorSpec):
    raise ValueError('model_fn should return an EstimatorSpec.')

  return model_fn_results
def _call_loss_fn(loss_fn, labels, logits, features):
  """Calls loss_fn and checks the returned shape.

  Args:
    loss_fn: The loss function.
    labels: Processed labels Tensor.
    logits: Logits Tensor of shape [batch_size, logits_dimension].
    features: Features dict.

  Returns:
    Loss Tensor with shape [batch_size, 1].
  """
  loss_fn_args = util.fn_args(loss_fn)
  kwargs = {}
  if 'features' in loss_fn_args:
    kwargs['features'] = features
  unweighted_loss = loss_fn(labels=labels, logits=logits, **kwargs)
  batch_size = array_ops.shape(logits)[0]
  loss_shape = array_ops.shape(unweighted_loss)
  check_shape_op = control_flow_ops.Assert(
      math_ops.reduce_all(math_ops.equal(loss_shape, [batch_size, 1])),
      data=[
          'loss_fn must return Tensor of shape [batch_size, 1]. Given: ',
          loss_shape])
  with ops.control_dependencies([check_shape_op]):
    return array_ops.identity(unweighted_loss)
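# A hedged sketch (not from the source) of a custom `loss_fn` that
# `_call_loss_fn` could dispatch to. `my_loss_fn` and the 'weight' feature
# key are hypothetical; TF 1.x API is assumed. Because the signature
# declares `features`, the dispatcher above passes it through.
import tensorflow as tf

def my_loss_fn(labels, logits, features):
  # Per-example squared error, kept at shape [batch_size, 1] as required.
  per_example = tf.reduce_sum(
      tf.squared_difference(labels, logits), axis=-1, keepdims=True)
  return per_example * features['weight']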
def run_step_fn(self, step_fn):
  """Run ops using a step function.

  Args:
    step_fn: A function or a method with a single argument of type
      `StepContext`.  The function may use methods of the argument to
      perform computations with access to a raw session.

      The returned value of the `step_fn` will be returned from `run_step_fn`,
      unless a stop is requested.  In that case, the next `should_stop` call
      will return True.

      Example usage:

      ```python
      with tf.Graph().as_default():
        c = tf.placeholder(dtypes.float32)
        v = tf.add(c, 4.0)
        w = tf.add(c, 0.5)

        def step_fn(step_context):
          a = step_context.session.run(fetches=v, feed_dict={c: 0.5})
          if a <= 4.5:
            step_context.request_stop()
          return step_context.run_with_hooks(fetches=w, feed_dict={c: 0.1})

        with tf.MonitoredSession() as session:
          while not session.should_stop():
            a = session.run_step_fn(step_fn)
      ```

      Hooks interact with the `run_with_hooks()` call inside the `step_fn`
      as they do with a `MonitoredSession.run` call.

  Returns:
    Returns the returned value of `step_fn`.

  Raises:
    StopIteration: if `step_fn` has called `request_stop()`.  It may be
      caught by `with tf.MonitoredSession()` to close the session.
    ValueError: if `step_fn` doesn't have a single argument called
      `step_context`. It may also optionally have `self` for cases when it
      belongs to an object.
  """
  step_fn_arguments = util.fn_args(step_fn)
  if step_fn_arguments != ('step_context',) and step_fn_arguments != (
      'self',
      'step_context',
  ):
    raise ValueError(
        '`step_fn` may either have one `step_context` argument, or'
        ' `self` and `step_context` arguments if it\'s an instance'
        ' method. Got {} instead.'.format(step_fn_arguments))

  # `self._sess` is either `_RecoverableSession` or a `_CoordinatedSession`.
  # Setting `run_with_hooks` to `None` will cause `run_with_hooks` to be
  # `_CoordinatedSession.run` downstream in either case. This allows
  # `_PREEMPTION_ERRORS` to propagate from within `step_fn` to
  # `_RecoverableSession.run_step_fn`.
  return self._sess.run_step_fn(step_fn, self._tf_sess(), run_with_hooks=None)
def call_logit_fn(logit_fn, features, mode, params, config):
  """Calls logit_fn.

  A utility function that calls the provided logit_fn with the relevant subset
  of provided arguments. Similar to tf.estimator._call_model_fn().

  Args:
    logit_fn: A logit_fn as defined above.
    features: The features dict.
    mode: TRAIN / EVAL / PREDICT ModeKeys.
    params: The hyperparameter dict.
    config: The configuration object.

  Returns:
    A logit Tensor, the output of logit_fn.

  Raises:
    ValueError: if logit_fn does not return a Tensor.
  """
  logit_fn_args = util.fn_args(logit_fn)
  kwargs = {}
  if 'mode' in logit_fn_args:
    kwargs['mode'] = mode
  if 'params' in logit_fn_args:
    kwargs['params'] = params
  if 'config' in logit_fn_args:
    kwargs['config'] = config
  logit_fn_results = logit_fn(features=features, **kwargs)
  if not isinstance(logit_fn_results, ops.Tensor):
    raise ValueError('logit_fn should return a Tensor.')
  return logit_fn_results
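# A hedged sketch of a `logit_fn` that `call_logit_fn` could invoke.
# `my_logit_fn` and the params keys are hypothetical; TF 1.x layers API is
# assumed. Since 'mode' and 'config' are absent from the signature, only
# 'params' is passed through by the dispatcher above.
import tensorflow as tf

def my_logit_fn(features, params):
  hidden = tf.layers.dense(features['x'], params['hidden_units'],
                           activation=tf.nn.relu)
  return tf.layers.dense(hidden, params['logits_dimension'])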
def export(self, estimator, export_path, checkpoint_path=None,
           eval_result=None):
  """Exports the given Estimator to a specific format.

  Args:
    estimator: the Estimator to export.
    export_path: A string containing a directory where to write the export.
    checkpoint_path: The checkpoint path to export.  If None (the default),
      the strategy may locate a checkpoint (e.g. the most recent) by itself.
    eval_result: The output of Estimator.evaluate on this checkpoint.  This
      should be set only if checkpoint_path is provided (otherwise it is
      unclear which checkpoint this eval refers to).

  Returns:
    The string path to the exported directory.

  Raises:
    ValueError: if the export_fn does not have the required signature.
  """
  export_fn_args = util.fn_args(self.export_fn)
  kwargs = {}
  if 'checkpoint_path' in export_fn_args:
    kwargs['checkpoint_path'] = checkpoint_path
  if 'eval_result' in export_fn_args:
    if 'checkpoint_path' not in export_fn_args:
      raise ValueError('An export_fn accepting eval_result must also accept '
                       'checkpoint_path.')
    kwargs['eval_result'] = eval_result

  return self.export_fn(estimator, export_path, **kwargs)
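# A hedged sketch of an `export_fn` with the optional `checkpoint_path`
# argument that the signature inspection above detects. `my_export_fn` and
# `my_serving_input_fn` are hypothetical names.
def my_export_fn(estimator, export_path, checkpoint_path=None):
  return estimator.export_savedmodel(
      export_path, my_serving_input_fn, checkpoint_path=checkpoint_path)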
def _call_input_fn(self, input_fn, mode):
  """Calls the input function.

  Args:
    input_fn: The input function.
    mode: ModeKeys

  Returns:
    Either features or (features, labels) where features and labels are:
      features - `Tensor` or dictionary of string feature name to `Tensor`.
      labels - `Tensor` or dictionary of `Tensor` with labels.

  Raises:
    ValueError: if input_fn takes invalid arguments.
  """
  input_fn_args = util.fn_args(input_fn)
  kwargs = {}
  if 'mode' in input_fn_args:
    kwargs['mode'] = mode
  if 'params' in input_fn_args:
    kwargs['params'] = self.params
  if 'config' in input_fn_args:
    kwargs['config'] = self.config
  with ops.device('/cpu:0'):
    return input_fn(**kwargs)
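# A hedged sketch of an `input_fn` the dispatcher above could call.
# `my_input_fn` is hypothetical and only declares 'params', so 'mode' and
# 'config' are not passed. TF 1.x Dataset API is assumed.
import numpy as np
import tensorflow as tf

def my_input_fn(params):
  features = {'x': np.random.rand(100, 4).astype(np.float32)}
  labels = np.random.randint(0, 2, size=100)
  dataset = tf.data.Dataset.from_tensor_slices((features, labels))
  dataset = dataset.repeat().batch(params['batch_size'])
  return dataset.make_one_shot_iterator().get_next()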
def test_bounded_method(self):

  class Foo(object):

    def bar(self, a, b):
      return a + b

  self.assertEqual(('a', 'b'), util.fn_args(Foo().bar))
def test_callable(self):

  class Foo(object):

    def __call__(self, a, b):
      return a + b

  self.assertEqual(('self', 'a', 'b'), util.fn_args(Foo()))
def _verify_metric_fn_args(metric_fn):
  args = set(estimator_util.fn_args(metric_fn))
  if tf_inspect.ismethod(metric_fn):
    if 'self' in args:
      args.remove('self')
  invalid_args = list(args - _VALID_METRIC_FN_ARGS)
  if invalid_args:
    raise ValueError('metric_fn (%s) has the following unexpected args: %s' %
                     (metric_fn, invalid_args))
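# A hedged sketch of a `metric_fn` that should pass the verification above
# (assuming _VALID_METRIC_FN_ARGS covers 'labels' and 'predictions').
# `my_metric_fn` is a hypothetical name; TF 1.x metrics API is assumed.
import tensorflow as tf

def my_metric_fn(labels, predictions):
  return {'accuracy': tf.metrics.accuracy(labels, predictions)}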
def _call_input_fn(self, input_fn, mode):
  """Calls the input function.

  Args:
    input_fn: The input function.
    mode: ModeKeys

  Returns:
    Either features or (features, labels) where features and labels are:
      features - `Tensor` or dictionary of string feature name to `Tensor`.
      labels - `Tensor` or dictionary of `Tensor` with labels.

  Raises:
    ValueError: if input_fn takes invalid arguments or does not have `params`.
  """
  if not self._use_tpu or mode != model_fn_lib.ModeKeys.TRAIN:
    return super(TpuEstimator, self)._call_input_fn(input_fn, mode)

  input_fn_args = util.fn_args(input_fn)
  config = self.config  # a deep copy.
  kwargs = {}
  if 'params' in input_fn_args:
    kwargs['params'] = self.params  # a deep copy.
  else:
    raise ValueError('input_fn ({}) does not include params argument, '
                     'required by TPUEstimator to pass batch size as '
                     'params["batch_size"]'.format(input_fn))
  if 'config' in input_fn_args:
    kwargs['config'] = config

  # Now for TPU training.
  per_shard_batch_size = _per_shard_batch_size(self._train_batch_size, config)
  kwargs['params'][_BATCH_SIZE_KEY] = per_shard_batch_size

  job = _tpu_job(config)

  def placement_function(index):
    if job is None:
      return '/replica:0/task:0/device:CPU:0'
    else:
      # Integer division keeps the task index an int under Python 3.
      return '/job:%s/replica:0/task:%d/device:CPU:0' % (job, index // 8)

  features = []
  labels = []
  for i in range(config.tpu_config.num_shards):
    with ops.device(placement_function(i)):
      result = input_fn(**kwargs)
      # input_fn may return either features or (features, labels).
      if isinstance(result, tuple):
        features.append(result[0])
        labels.append(result[1])
      else:
        features.append(result)

  if not labels or all(l is None for l in labels):
    return _PerShardOutput(features), None

  return _PerShardOutput(features), _PerShardOutput(labels)
def _call_input_fn(self, input_fn, mode):
  """Calls the input function.

  Args:
    input_fn: The input function.
    mode: ModeKeys

  Returns:
    Either features or (features, labels) where features and labels are:
      features - `Tensor` or dictionary of string feature name to `Tensor`.
      labels - `Tensor` or dictionary of `Tensor` with labels.

  Raises:
    ValueError: if input_fn takes invalid arguments or does not have `params`.
  """
  input_fn_args = util.fn_args(input_fn)
  config = self.config  # a deep copy.
  kwargs = {}
  if 'params' in input_fn_args:
    kwargs['params'] = self.params  # a deep copy.
  else:
    raise ValueError('input_fn ({}) does not include params argument, '
                     'required by TPUEstimator to pass batch size as '
                     'params["batch_size"]'.format(input_fn))
  if 'config' in input_fn_args:
    kwargs['config'] = config

  # Now for TPU training.
  if mode == model_fn_lib.ModeKeys.TRAIN:
    kwargs['params'][_BATCH_SIZE_KEY] = (
        _per_shard_batch_size(self._train_batch_size, config, self._use_tpu)
        if not config.tpu_config.per_host_input_for_training
        else self._train_batch_size)

  if not self._use_tpu or mode != model_fn_lib.ModeKeys.TRAIN:
    with ops.device('/cpu:0'):
      return input_fn(**kwargs)

  job = _tpu_job(config)

  def placement_function(index):
    if job is None:
      return '/replica:0/task:0/device:CPU:0'
    else:
      # Integer division keeps the task index an int under Python 3.
      return '/job:%s/replica:0/task:%d/device:CPU:0' % (job, index // 8)

  if not config.tpu_config.per_host_input_for_training:
    num_shards = config.tpu_config.num_shards
    inputs = _InputsHolder(num_shards=num_shards)
    for i in range(config.tpu_config.num_shards):
      with ops.device(placement_function(i)):
        inputs.append_tuple(input_fn(**kwargs))

    return inputs.as_features_and_labels_tuple()
  else:
    # TODO(xiejw): Extend this to multi-host support.
    with ops.device(placement_function(0)):
      return input_fn(**kwargs)
def _get_standardized_predicate_fn(predicate_fn):
  pred_fn_args = estimator_util.fn_args(predicate_fn)
  if "checkpoint_path" not in pred_fn_args:
    # pylint: disable=unused-argument
    def _pred_fn_wrapper(eval_results, checkpoint_path):
      return predicate_fn(eval_results)
    return _pred_fn_wrapper
  else:
    return predicate_fn
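# A hedged illustration of the two predicate signatures the wrapper above
# standardizes; `stop_if_loss_low` is a hypothetical name.
def stop_if_loss_low(eval_results):
  return eval_results['loss'] > 0.1  # continue while loss is still high

standardized = _get_standardized_predicate_fn(stop_if_loss_low)
# The wrapped form now accepts the uniform two-argument calling convention:
keep_going = standardized(eval_results={'loss': 0.5}, checkpoint_path=None)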
def test_partial_function(self):
  expected_test_arg = 123

  def fn(a, test_arg):
    if test_arg != expected_test_arg:
      return ValueError('partial fn does not work correctly')
    return a

  wrapped_fn = functools.partial(fn, test_arg=123)

  self.assertEqual(('a',), util.fn_args(wrapped_fn))
def test_double_partial(self):
  expected_test_arg1 = 123
  expected_test_arg2 = 456

  def fn(a, test_arg1, test_arg2):
    if test_arg1 != expected_test_arg1 or test_arg2 != expected_test_arg2:
      return ValueError('partial does not work correctly')
    return a

  wrapped_fn = functools.partial(fn, test_arg2=456)
  double_wrapped_fn = functools.partial(wrapped_fn, test_arg1=123)

  self.assertEqual(('a',), util.fn_args(double_wrapped_fn))
def _call_metric_fn(metric_fn, features, labels, predictions, config):
  """Calls metric fn with proper arguments."""
  metric_fn_args = estimator_util.fn_args(metric_fn)
  kwargs = {}
  if 'features' in metric_fn_args:
    kwargs['features'] = features
  if 'labels' in metric_fn_args:
    kwargs['labels'] = labels
  if 'predictions' in metric_fn_args:
    kwargs['predictions'] = predictions
  if 'config' in metric_fn_args:
    kwargs['config'] = config
  return metric_fn(**kwargs)
def test_partial_function_with_positional_args(self):
  expected_test_arg = 123

  def fn(test_arg, a):
    if test_arg != expected_test_arg:
      return ValueError('partial fn does not work correctly')
    return a

  wrapped_fn = functools.partial(fn, 123)

  self.assertEqual(('a',), util.fn_args(wrapped_fn))

  self.assertEqual(3, wrapped_fn(3))
  self.assertEqual(3, wrapped_fn(a=3))
def _verify_compare_fn_args(compare_fn):
  """Verifies compare_fn arguments."""
  args = set(util.fn_args(compare_fn))
  if 'best_eval_result' not in args:
    raise ValueError(
        'compare_fn (%s) must include best_eval_result argument.' % compare_fn)
  if 'current_eval_result' not in args:
    raise ValueError(
        'compare_fn (%s) must include current_eval_result argument.' %
        compare_fn)
  non_valid_args = list(
      args - set(['best_eval_result', 'current_eval_result']))
  if non_valid_args:
    raise ValueError('compare_fn (%s) has the following unexpected args: %s' %
                     (compare_fn, non_valid_args))
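# A hedged sketch of a `compare_fn` accepted by the check above, e.g. one
# that prefers the checkpoint with the lower eval loss; `loss_smaller` is a
# hypothetical name.
def loss_smaller(best_eval_result, current_eval_result):
  return current_eval_result['loss'] < best_eval_result['loss']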
def test_double_partial_with_positional_args_in_outer_layer(self):
  expected_test_arg1 = 123
  expected_test_arg2 = 456

  def fn(test_arg1, a, test_arg2):
    if test_arg1 != expected_test_arg1 or test_arg2 != expected_test_arg2:
      return ValueError('partial fn does not work correctly')
    return a

  wrapped_fn = functools.partial(fn, test_arg2=456)
  double_wrapped_fn = functools.partial(wrapped_fn, 123)

  self.assertEqual(('a',), util.fn_args(double_wrapped_fn))

  self.assertEqual(3, double_wrapped_fn(3))
  self.assertEqual(3, double_wrapped_fn(a=3))
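# A hedged sketch (not the actual `util.fn_args` implementation) of the
# partial-handling behavior these tests exercise: unwrap the
# `functools.partial` chain, drop leading positionally bound arguments,
# then drop keyword-bound ones. `sketch_fn_args` is a hypothetical name.
import functools
import inspect

def sketch_fn_args(fn):
  bound_kwargs = {}
  n_bound_positional = 0
  while isinstance(fn, functools.partial):
    bound_kwargs.update(fn.keywords or {})
    n_bound_positional += len(fn.args)
    fn = fn.func
  args = inspect.getfullargspec(fn).args  # Python 3 introspection
  args = args[n_bound_positional:]
  return tuple(a for a in args if a not in bound_kwargs)

# Mirrors the double-partial case above:
assert sketch_fn_args(functools.partial(
    functools.partial(lambda test_arg1, a, test_arg2: a, test_arg2=456),
    123)) == ('a',)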
def _verify_model_fn_args(model_fn, params):
  """Verifies model fn arguments."""
  args = set(util.fn_args(model_fn))
  if 'features' not in args:
    raise ValueError('model_fn (%s) must include features argument.' %
                     model_fn)
  if params is not None and 'params' not in args:
    raise ValueError('model_fn (%s) does not include params argument, '
                     'but params (%s) is passed to Estimator.' %
                     (model_fn, params))
  if params is None and 'params' in args:
    logging.warning('Estimator\'s model_fn (%s) includes params '
                    'argument, but params are not passed to Estimator.',
                    model_fn)
  non_valid_args = list(args - _VALID_MODEL_FN_ARGS)
  if non_valid_args:
    raise ValueError('model_fn (%s) has the following unexpected args: %s' %
                     (model_fn, non_valid_args))
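# A hedged sketch of a `model_fn` that passes the verification above.
# `my_model_fn` and the params keys are hypothetical; TF 1.x estimator API
# is assumed.
import tensorflow as tf

def my_model_fn(features, labels, mode, params):
  logits = tf.layers.dense(features['x'], params['n_classes'])
  loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
  train_op = tf.train.GradientDescentOptimizer(0.1).minimize(
      loss, global_step=tf.train.get_global_step())
  return tf.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)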
def _get_loss_towers(model_fn,
                     mode,
                     features,
                     labels,
                     params,
                     config,
                     devices,
                     local_ps_device,
                     name_scope_pattern=_DEFAULT_NAME_SCOPE_PATTERN):
  """Replicate the loss computation across devices."""
  tower_specs = []

  model_fn_args = util.fn_args(model_fn)
  optional_params = {}
  if 'params' in model_fn_args:
    optional_params['params'] = copy.deepcopy(params)
  if 'config' in model_fn_args:
    optional_params['config'] = copy.deepcopy(config)

  for i, device in enumerate(devices):
    is_the_first_tower = (i == 0)

    device_setter = _local_device_setter(
        worker_device=device, ps_device=local_ps_device)

    # We would like to preserve the names of the variables and ops that a user
    # might be relying on. Names without a prefix are going to resolve to
    # variables and ops of the first tower.
    name_scope = name_scope_pattern
    if is_the_first_tower:
      name_scope = ''

    with variable_scope.variable_scope('', reuse=not is_the_first_tower):
      with ops_lib.name_scope(name_scope.format(i)):
        with ops_lib.device(device_setter):
          labels_shard = None
          if labels:
            labels_shard = labels[i]

          tower_specs.append(
              model_fn(
                  mode=mode,
                  features=features[i],
                  labels=labels_shard,
                  **optional_params))
  return tower_specs
def call_logit_fn(logit_fn, features, mode, params, config):
  """Calls logit_fn.

  A utility function that calls the provided logit_fn with the relevant subset
  of provided arguments. Similar to tf.estimator._call_model_fn().

  Args:
    logit_fn: A logit_fn as defined above.
    features: The features dict.
    mode: TRAIN / EVAL / PREDICT ModeKeys.
    params: The hyperparameter dict.
    config: The configuration object.

  Returns:
    A logit Tensor, the output of logit_fn.

  Raises:
    ValueError: if logit_fn does not return a Tensor or a dictionary mapping
      strings to Tensors.
  """
  logit_fn_args = util.fn_args(logit_fn)
  kwargs = {}
  if 'mode' in logit_fn_args:
    kwargs['mode'] = mode
  if 'params' in logit_fn_args:
    kwargs['params'] = params
  if 'config' in logit_fn_args:
    kwargs['config'] = config
  logit_fn_results = logit_fn(features=features, **kwargs)

  result_is_valid_dictionary = (
      isinstance(logit_fn_results, dict) and
      all([(isinstance(k, str) and isinstance(v, ops.Tensor))
           for k, v in six.iteritems(logit_fn_results)]))
  result_is_tensor = isinstance(logit_fn_results, ops.Tensor)

  if not (result_is_valid_dictionary or result_is_tensor):
    raise ValueError('logit_fn should return a Tensor or a dictionary mapping '
                     'strings to Tensors. logit_fn returned: %s' %
                     logit_fn_results)

  return logit_fn_results
def _validate_loss_fn_args(loss_fn):
  """Validates loss_fn arguments.

  Required arguments: labels, logits.
  Optional arguments: features.

  Args:
    loss_fn: The loss function.

  Raises:
    ValueError: If the signature is unexpected.
  """
  loss_fn_args = util.fn_args(loss_fn)
  for required_arg in ['labels', 'logits']:
    if required_arg not in loss_fn_args:
      raise ValueError(
          'loss_fn must contain argument: {}. '
          'Given arguments: {}'.format(required_arg, loss_fn_args))
  invalid_args = list(
      set(loss_fn_args) - set(['labels', 'logits', 'features']))
  if invalid_args:
    raise ValueError('loss_fn has unexpected args: {}'.format(invalid_args))
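# A quick hedged illustration of the validation above; both functions are
# hypothetical.
def good_loss_fn(labels, logits):
  return (labels - logits) ** 2

def bad_loss_fn(labels, logits, weights):  # 'weights' is not an allowed arg
  return (labels - logits) ** 2 * weights

_validate_loss_fn_args(good_loss_fn)  # passes silently
# _validate_loss_fn_args(bad_loss_fn) would raise ValueError.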
def _call_model_fn(model_fn, features, labels, mode, config, params,
                   require_params=False):
  """Calls the model_fn with required parameters."""
  model_fn_args = util.fn_args(model_fn)
  kwargs = {}
  if 'labels' in model_fn_args:
    kwargs['labels'] = labels
  else:
    if labels is not None:
      raise ValueError(
          'model_fn does not take labels, but input_fn returns labels.')
  if 'mode' in model_fn_args:
    kwargs['mode'] = mode
  if 'config' in model_fn_args:
    kwargs['config'] = config
  if 'params' in model_fn_args:
    kwargs['params'] = params
  elif require_params:
    raise ValueError(
        'model_fn ({}) does not include params argument, '
        'required by TPUEstimator to pass batch size as '
        'params[\'batch_size\']'.format(model_fn))
  return model_fn(features=features, **kwargs)
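# A hedged illustration of the dispatch above: a model_fn that omits
# 'labels' from its signature must be paired with labels=None, and with
# require_params=True its signature must declare 'params'.
# `predict_only_model_fn` is hypothetical.
def predict_only_model_fn(features, mode, params):
  return None  # stand-in for an EstimatorSpec

spec = _call_model_fn(predict_only_model_fn, features={'x': 1}, labels=None,
                      mode='infer', config=None, params={},
                      require_params=True)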
def _validate_properties(run_config):
  """Validates the properties."""
  def _validate(property_name, cond, message):
    property_value = getattr(run_config, property_name)
    if property_value is not None and not cond(property_value):
      raise ValueError(message)

  _validate('model_dir', lambda dir: dir,
            message='model_dir should be non-empty')

  _validate('save_summary_steps', lambda steps: steps >= 0,
            message='save_summary_steps should be >= 0')

  _validate('save_checkpoints_steps', lambda steps: steps >= 0,
            message='save_checkpoints_steps should be >= 0')
  _validate('save_checkpoints_secs', lambda secs: secs >= 0,
            message='save_checkpoints_secs should be >= 0')

  _validate('session_config',
            lambda sc: isinstance(sc, config_pb2.ConfigProto),
            message='session_config must be instance of ConfigProto')

  _validate('keep_checkpoint_max', lambda keep_max: keep_max >= 0,
            message='keep_checkpoint_max should be >= 0')
  _validate('keep_checkpoint_every_n_hours', lambda keep_hours: keep_hours > 0,
            message='keep_checkpoint_every_n_hours should be > 0')
  _validate('log_step_count_steps', lambda num_steps: num_steps > 0,
            message='log_step_count_steps should be > 0')

  _validate('tf_random_seed', lambda seed: isinstance(seed, six.integer_types),
            message='tf_random_seed must be integer.')

  _validate('device_fn',
            lambda device_fn: six.callable(device_fn) and
            set(util.fn_args(device_fn)) == _VALID_DEVICE_FN_ARGS,
            message='device_fn must be callable with exactly'
                    ' one argument "op".')
def _get_loss_towers(model_fn,
                     mode,
                     features,
                     labels,
                     params,
                     config,
                     devices,
                     local_ps_devices,
                     loss_reduction,
                     name_scope_pattern=_DEFAULT_NAME_SCOPE_PATTERN):
  """Replicate the loss computation across devices."""
  tower_specs = []

  model_fn_args = util.fn_args(model_fn)
  optional_params = {}
  if 'params' in model_fn_args:
    optional_params['params'] = copy.deepcopy(params)
  if 'config' in model_fn_args:
    optional_params['config'] = copy.deepcopy(config)

  # pylint: disable=protected-access
  round_robin_strategy = device_setter_lib._RoundRobinStrategy(
      num_tasks=len(local_ps_devices))
  TowerOptimizer._graph_state().set_reduction_across_towers(
      loss_reduction, len(devices))

  for i, device in enumerate(devices):
    is_the_first_tower = (i == 0)

    device_setter = _local_device_setter(
        worker_device=device,
        ps_devices=local_ps_devices,
        ps_strategy=round_robin_strategy)

    # We would like to preserve the names of the variables and ops that the
    # user might be relying on. Names without a prefix are going to resolve
    # to variables and ops of the first tower.
    name_scope = name_scope_pattern
    if is_the_first_tower:
      name_scope = ''

    with variable_scope.variable_scope(
        '', reuse=not is_the_first_tower) as var_scope:
      with ops_lib.name_scope(name_scope.format(i)) as name_scope:
        with TowerOptimizer._graph_state().tower(
            tower_id=i, var_scope=var_scope, name_scope=name_scope):
          with ops_lib.device(device_setter):
            labels_shard = None
            if labels:
              labels_shard = labels[i]

            tower_spec = model_fn(
                mode=mode,
                features=features[i],
                labels=labels_shard,
                **optional_params)

            if (tower_spec.train_op is not None and len(devices) > 1 and
                not TowerOptimizer.has_been_used()):
              raise ValueError('Please wrap optimizers with TowerOptimizer'
                               ' in order to use replicate_model_fn with'
                               ' multiple `devices`.')

            # Scaling the loss here doesn't actually affect gradients.
            # Another instance of scaling happens inside the TowerOptimizer.
            tower_spec = _scale_tower_loss(
                tower_spec, loss_reduction, number_of_towers=len(devices))

            tower_specs.append(tower_spec)

  if not TowerOptimizer._did_towers_have_same_optimizer_calls():
    raise ValueError('Each invocation of model_fn was supposed to make the'
                     ' same optimizer calls.')
  TowerOptimizer._clear_graph_state()
  # pylint: enable=protected-access
  return tower_specs
def dynamic_decode_and_search(self,
                              embedding,
                              start_tokens,
                              end_token,
                              vocab_size=None,
                              initial_state=None,
                              output_layer=None,
                              beam_width=5,
                              length_penalty=0.0,
                              maximum_iterations=250,
                              mode=tf.estimator.ModeKeys.PREDICT,
                              memory=None,
                              memory_sequence_length=None,
                              dtype=None,
                              return_alignment_history=False):
  if (return_alignment_history and
      "reorder_tensor_arrays" not in fn_args(
          tf.contrib.seq2seq.BeamSearchDecoder.__init__)):
    tf.logging.warn(
        "The current version of tf.contrib.seq2seq.BeamSearchDecoder "
        "does not support returning the alignment history. None will "
        "be returned instead. Consider upgrading TensorFlow.")
    alignment_history = False
  else:
    alignment_history = return_alignment_history

  batch_size = tf.shape(start_tokens)[0]

  # Replicate batch `beam_width` times.
  if initial_state is not None:
    initial_state = tf.contrib.seq2seq.tile_batch(
        initial_state, multiplier=beam_width)
  if memory is not None:
    memory = tf.contrib.seq2seq.tile_batch(memory, multiplier=beam_width)
  if memory_sequence_length is not None:
    memory_sequence_length = tf.contrib.seq2seq.tile_batch(
        memory_sequence_length, multiplier=beam_width)

  cell, initial_state = self._build_cell(
      mode,
      batch_size * beam_width,
      initial_state=initial_state,
      memory=memory,
      memory_sequence_length=memory_sequence_length,
      dtype=dtype,
      alignment_history=alignment_history)

  if output_layer is None:
    output_layer = build_output_layer(
        self.num_units, vocab_size, dtype=dtype or memory.dtype)

  decoder = tf.contrib.seq2seq.BeamSearchDecoder(
      cell,
      embedding,
      start_tokens,
      end_token,
      initial_state,
      beam_width,
      output_layer=output_layer,
      length_penalty_weight=length_penalty)

  outputs, beam_state, length = tf.contrib.seq2seq.dynamic_decode(
      decoder, maximum_iterations=maximum_iterations)

  predicted_ids = tf.transpose(outputs.predicted_ids, perm=[0, 2, 1])
  log_probs = beam_state.log_probs
  state = beam_state.cell_state

  if return_alignment_history:
    alignment_history = _get_alignment_history(state)
    if alignment_history is not None:
      alignment_history = tf.reshape(
          alignment_history,
          [-1, batch_size, beam_width, tf.shape(memory)[1]])
    return (predicted_ids, state, length, log_probs, alignment_history)
  return (predicted_ids, state, length, log_probs)
def _verify_metric_fn_args(metric_fn):
  args = set(estimator_util.fn_args(metric_fn))
  invalid_args = list(args - _VALID_METRIC_FN_ARGS)
  if invalid_args:
    raise ValueError('metric_fn (%s) has the following unexpected args: %s' %
                     (metric_fn, invalid_args))
def __call__(self, inputs, *args, **kwargs):
  """Wraps `call`, applying pre- and post-processing steps.

  Arguments:
    inputs: input tensor(s).
    *args: additional positional arguments to be passed to `self.call`.
    **kwargs: additional keyword arguments to be passed to `self.call`.
      **Note**: kwarg `scope` is reserved for use by the layer.

  Returns:
    Output tensor(s).

  Note:
    - If the layer's `call` method takes a `scope` keyword argument, this
      argument will be automatically set to the current variable scope.
    - If the layer's `call` method takes a `mask` argument (as some Keras
      layers do), its default value will be set to the mask generated for
      `inputs` by the previous layer (if `inputs` did come from a layer that
      generated a corresponding mask, i.e. if it came from a Keras layer with
      masking support).

  Raises:
    ValueError: if the layer's `call` method returns None (an invalid value).
  """
  self._set_scope(kwargs.pop('scope', None))

  if not context.executing_eagerly():
    try:
      # Set layer's "graph" at build time.
      self._graph = ops._get_graph_from_inputs(  # pylint: disable=protected-access
          nest.flatten(inputs), graph=self._graph)
    except ValueError as e:
      raise ValueError('Input graph and Layer graph are not the same: %s' % e)

  if self.built:
    try:
      # Some classes which inherit from Layer do not use its constructor, so
      # rather than initializing to None we check for an AttributeError.
      scope_context_manager = self._always_reuse_variable_scope
    except AttributeError:
      # From this point we will always set reuse=True, so create a "final"
      # variable scope with this setting. We avoid re-creating variable
      # scopes after this point as an optimization.
      self._always_reuse_variable_scope = vs.variable_scope(
          self._scope, reuse=True, auxiliary_name_scope=False)
      scope_context_manager = self._always_reuse_variable_scope
  else:
    scope_context_manager = vs.variable_scope(
        self._scope, reuse=self._reuse, auxiliary_name_scope=False)

  with scope_context_manager as scope:
    self._current_scope = scope

    try:
      call_has_scope_arg = self._call_has_scope_arg
    except AttributeError:
      self._call_fn_args = estimator_util.fn_args(self.call)
      self._call_has_scope_arg = 'scope' in self._call_fn_args
      call_has_scope_arg = self._call_has_scope_arg
    if call_has_scope_arg:
      kwargs['scope'] = scope

    # Actually call layer.
    outputs = super(Layer, self).__call__(inputs, *args, **kwargs)

  if not context.executing_eagerly():
    # Update global default collections.
    _add_elements_to_collection(self.updates, ops.GraphKeys.UPDATE_OPS)

  return outputs
def _run_internal_graph(self, inputs, masks=None):
  """Computes output tensors for new inputs.

  # Note:
    - Expects `inputs` to be a list (potentially with 1 element).
    - Can be run on non-Keras tensors.

  Arguments:
    inputs: List of tensors
    masks: List of masks (tensors or None).

  Returns:
    Three lists: output_tensors, output_masks, output_shapes
  """
  # Note: masking support is relevant mainly for Keras.
  # It cannot be factored out without having to fully reimplement the network
  # calling logic on the Keras side. We choose to incorporate it in
  # GraphNetwork because 1) it may be useful to fully support in tf.layers in
  # the future and 2) Keras is a major user of GraphNetwork. If you don't
  # use masking, it does not interfere with regular behavior at all and you
  # can ignore it.
  if masks is None:
    masks = [None for _ in range(len(inputs))]

  # Dictionary mapping reference tensors to tuples
  # (computed tensor, compute mask)
  # we assume a 1:1 mapping from tensor to mask
  # TODO(fchollet): raise exception when a `.compute_mask()` call
  # does not return a list the same size as `call`
  tensor_map = {}
  for x, y, mask in zip(self.inputs, inputs, masks):
    tensor_map[str(id(x))] = (y, mask)

  depth_keys = list(self._nodes_by_depth.keys())
  depth_keys.sort(reverse=True)
  for depth in depth_keys:
    nodes = self._nodes_by_depth[depth]
    for node in nodes:
      # This is always a single layer, never a list.
      layer = node.outbound_layer

      reference_input_tensors = node.input_tensors
      reference_output_tensors = node.output_tensors

      # If all previous input tensors are available in tensor_map,
      # then call node.inbound_layer on them.
      computed_data = []  # List of tuples (input, mask).
      for x in reference_input_tensors:
        if str(id(x)) in tensor_map:
          computed_data.append(tensor_map[str(id(x))])

      if len(computed_data) == len(reference_input_tensors):
        # Call layer (reapplying ops to new inputs).
        with ops.name_scope(layer.name):
          if node.arguments:
            kwargs = node.arguments
          else:
            kwargs = {}
          if len(computed_data) == 1:
            computed_tensor, computed_mask = computed_data[0]
            # Ensure mask propagation if applicable.
            if 'mask' in estimator_util.fn_args(layer.call):
              if 'mask' not in kwargs:
                kwargs['mask'] = computed_mask
            output_tensors = nest.flatten(
                layer.call(computed_tensor, **kwargs))
            if hasattr(layer, 'compute_mask'):
              output_masks = nest.flatten(
                  layer.compute_mask(computed_tensor, computed_mask))
            else:
              output_masks = [None for _ in range(len(output_tensors))]
            computed_tensors = [computed_tensor]
            computed_masks = [computed_mask]
          else:
            computed_tensors = [x[0] for x in computed_data]
            computed_masks = [x[1] for x in computed_data]
            if 'mask' in estimator_util.fn_args(layer.call):
              if 'mask' not in kwargs:
                kwargs['mask'] = computed_masks
            output_tensors = nest.flatten(
                layer.call(computed_tensors, **kwargs))
            if hasattr(layer, 'compute_mask'):
              output_masks = nest.flatten(
                  layer.compute_mask(computed_tensors, computed_masks))
            else:
              output_masks = [None for _ in range(len(output_tensors))]

          # Apply activity regularizer if any:
          if layer.activity_regularizer is not None:
            regularization_losses = [
                layer.activity_regularizer(x) for x in computed_tensors
            ]
            layer.add_loss(regularization_losses, computed_tensors)

        if context.in_graph_mode():
          # Update model updates and losses:
          # Keep track of updates that depend on the inputs
          # (e.g. BN updates).
          self.add_update(layer.get_updates_for(computed_tensors), inputs)
          # Keep track of unconditional updates (e.g. a counter).
          self.add_update(layer.get_updates_for(None), None)
          # Keep track of losses that depend on the inputs
          # (e.g. activity regularizers).
          self.add_loss(layer.get_losses_for(computed_tensors), inputs)
          # Keep track of unconditional losses
          # (e.g. weight regularizers).
          self.add_loss(layer.get_losses_for(None), None)

        # Update tensor_map.
        for x, y, mask in zip(reference_output_tensors, output_tensors,
                              output_masks):
          tensor_map[str(id(x))] = (y, mask)

  output_tensors = []
  output_masks = []
  output_shapes = []
  for x in self.outputs:
    assert str(id(x)) in tensor_map, 'Could not compute output ' + str(x)
    tensor, mask = tensor_map[str(id(x))]
    output_shapes.append(layers_util.static_shape(x))
    output_tensors.append(tensor)
    output_masks.append(mask)

  if len(output_tensors) == 1:
    output_tensors = output_tensors[0]
    if output_shapes is not None:
      output_shapes = output_shapes[0]
    if output_masks is not None:
      output_masks = output_masks[0]

  if context.in_graph_mode():
    # Update cache;
    # keys are based on ids on input tensors and inputs masks.
    cache_key = (layers_util.object_list_uid(inputs) + '_' +
                 layers_util.object_list_uid(masks))
    self._output_tensor_cache[cache_key] = output_tensors
    if output_masks is not None:
      self._output_mask_cache[cache_key] = output_masks
    if output_shapes is not None:
      input_shapes = [layers_util.static_shape(x) for x in inputs]
      cache_key = layers_util.object_list_uid(input_shapes)
      self._output_shape_cache[cache_key] = output_shapes

  return output_tensors, output_masks
def _call_optimizer_fn(optimizer_fn, params):
  arguments = {}
  optimizer_fn_arguments = util.fn_args(optimizer_fn)
  if 'params' in optimizer_fn_arguments:
    arguments['params'] = params
  return optimizer_fn(**arguments)
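# A hedged sketch of the two `optimizer_fn` signatures the dispatcher above
# supports; the names are hypothetical and TF 1.x API is assumed.
import tensorflow as tf

def optimizer_without_params():
  return tf.train.AdamOptimizer()

def optimizer_with_params(params):
  return tf.train.AdamOptimizer(learning_rate=params['learning_rate'])

# Both can be invoked uniformly:
# opt = _call_optimizer_fn(optimizer_with_params, {'learning_rate': 1e-3})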
def _run_internal_graph(self, inputs, masks=None):
  """Computes output tensors for new inputs.

  # Note:
    - Expects `inputs` to be a list (potentially with 1 element).
    - Can be run on non-Keras tensors.

  Arguments:
    inputs: List of tensors
    masks: List of masks (tensors or None).

  Returns:
    Three lists: output_tensors, output_masks, output_shapes
  """
  # Note: masking support is relevant mainly for Keras.
  # It cannot be factored out without having to fully reimplement the network
  # calling logic on the Keras side. We choose to incorporate it in
  # GraphNetwork because 1) it may be useful to fully support in tf.layers in
  # the future and 2) Keras is a major user of GraphNetwork. If you don't
  # use masking, it does not interfere with regular behavior at all and you
  # can ignore it.
  if masks is None:
    masks = [None for _ in range(len(inputs))]

  # Dictionary mapping reference tensors to tuples
  # (computed tensor, compute mask)
  # we assume a 1:1 mapping from tensor to mask
  # TODO(fchollet): raise exception when a `.compute_mask()` call
  # does not return a list the same size as `call`
  tensor_map = {}
  for x, y, mask in zip(self.inputs, inputs, masks):
    tensor_map[str(id(x))] = (y, mask)

  depth_keys = list(self._nodes_by_depth.keys())
  depth_keys.sort(reverse=True)
  for depth in depth_keys:
    nodes = self._nodes_by_depth[depth]
    for node in nodes:
      # This is always a single layer, never a list.
      layer = node.outbound_layer

      reference_input_tensors = node.input_tensors
      reference_output_tensors = node.output_tensors

      # If all previous input tensors are available in tensor_map,
      # then call node.inbound_layer on them.
      computed_data = []  # List of tuples (input, mask).
      for x in reference_input_tensors:
        if str(id(x)) in tensor_map:
          computed_data.append(tensor_map[str(id(x))])

      if len(computed_data) == len(reference_input_tensors):
        # Call layer (reapplying ops to new inputs).
        with ops.name_scope(layer.name):
          if node.arguments:
            kwargs = node.arguments
          else:
            kwargs = {}
          if len(computed_data) == 1:
            computed_tensor, computed_mask = computed_data[0]
            # Ensure mask propagation if applicable.
            if 'mask' in estimator_util.fn_args(layer.call):
              if 'mask' not in kwargs:
                kwargs['mask'] = computed_mask
            output_tensors = nest.flatten(
                layer.call(computed_tensor, **kwargs))
            if hasattr(layer, 'compute_mask'):
              output_masks = nest.flatten(
                  layer.compute_mask(computed_tensor, computed_mask))
            else:
              output_masks = [None for _ in range(len(output_tensors))]
            computed_tensors = [computed_tensor]
            computed_masks = [computed_mask]
          else:
            computed_tensors = [x[0] for x in computed_data]
            computed_masks = [x[1] for x in computed_data]
            if 'mask' in estimator_util.fn_args(layer.call):
              if 'mask' not in kwargs:
                kwargs['mask'] = computed_masks
            output_tensors = nest.flatten(
                layer.call(computed_tensors, **kwargs))
            if hasattr(layer, 'compute_mask'):
              output_masks = nest.flatten(
                  layer.compute_mask(computed_tensors, computed_masks))
            else:
              output_masks = [None for _ in range(len(output_tensors))]

          if context.in_graph_mode():
            if layer.activity_regularizer is not None:
              regularization_losses = [
                  layer.activity_regularizer(x) for x in computed_tensors
              ]
              # Apply activity regularizer if any:
              layer.add_loss(regularization_losses, computed_tensors)

        if context.in_graph_mode():
          # Update model updates and losses:
          # Keep track of updates that depend on the inputs
          # (e.g. BN updates).
          self.add_update(layer.get_updates_for(computed_tensors), inputs)
          # Keep track of unconditional updates (e.g. a counter).
          self.add_update(layer.get_updates_for(None), None)
          # Keep track of losses that depend on the inputs
          # (e.g. activity regularizers).
          self.add_loss(layer.get_losses_for(computed_tensors), inputs)
          # Keep track of unconditional losses
          # (e.g. weight regularizers).
          self.add_loss(layer.get_losses_for(None), None)

        # Update tensor_map.
        for x, y, mask in zip(reference_output_tensors, output_tensors,
                              output_masks):
          tensor_map[str(id(x))] = (y, mask)

  output_tensors = []
  output_masks = []
  output_shapes = []
  for x in self.outputs:
    assert str(id(x)) in tensor_map, 'Could not compute output ' + str(x)
    tensor, mask = tensor_map[str(id(x))]
    output_shapes.append(layers_util.static_shape(x))
    output_tensors.append(tensor)
    output_masks.append(mask)

  if len(output_tensors) == 1:
    output_tensors = output_tensors[0]
    if output_shapes is not None:
      output_shapes = output_shapes[0]
    if output_masks is not None:
      output_masks = output_masks[0]

  if context.in_graph_mode():
    # Update cache;
    # keys are based on ids on input tensors and inputs masks.
    cache_key = (layers_util.object_list_uid(inputs) + '_' +
                 layers_util.object_list_uid(masks))
    self._output_tensor_cache[cache_key] = output_tensors
    if output_masks is not None:
      self._output_mask_cache[cache_key] = output_masks
    if output_shapes is not None:
      input_shapes = [layers_util.static_shape(x) for x in inputs]
      cache_key = layers_util.object_list_uid(input_shapes)
      self._output_shape_cache[cache_key] = output_shapes

  return output_tensors, output_masks
def __init__(self, inputs, outputs, name=None):  # pylint: disable=super-init-not-called
  if context.in_eager_mode():
    # TODO(fchollet): check that all inputs and outputs are DeferredTensors.
    pass

  self._init_set_name(name)
  self._activity_regularizer = None
  with vs.variable_scope(
      None, default_name=self._base_name) as captured_scope:
    self._scope = captured_scope
  call_fn_args = estimator_util.fn_args(self.call)
  self._compute_previous_mask = ('mask' in call_fn_args or
                                 hasattr(self, 'compute_mask'))
  self._call_has_scope_arg = 'scope' in call_fn_args

  # This acts just like the `trainable` attribute of any layer instance.
  # It does not affect users of the underlying layers, only users of the
  # GraphNetwork instance.
  self.trainable = True
  # A GraphNetwork does not create weights of its own, thus it is already
  # built.
  self.built = True
  # A GraphNetwork does not create weights of its own, thus has no dtype.
  self._dtype = None
  # The following are implemented as property functions:
  # self.trainable_weights
  # self.non_trainable_weights
  # self.input_spec

  # Private attributes to implement compatibility with Layer.
  self._per_input_losses = {}
  self._per_input_updates = {}
  self._updates = []
  self._losses = []
  self._scope = None
  self._reuse = None
  self._graph = ops.get_default_graph()

  # GraphNetwork-specific properties.
  if isinstance(inputs, (list, tuple)):
    self.inputs = list(inputs)  # Tensor or list of tensors.
  else:
    self.inputs = [inputs]
  if isinstance(outputs, (list, tuple)):
    self.outputs = list(outputs)
  else:
    self.outputs = [outputs]

  # All layers in order of horizontal graph traversal.
  # Entries are unique. Includes input and output layers.
  self.layers = []

  # Check for redundancy in inputs.
  if len(set(self.inputs)) != len(self.inputs):
    raise ValueError('The list of inputs passed to the model '
                     'is redundant. '
                     'All inputs should only appear once.'
                     ' Found: ' + str(self.inputs))

  # # List of initial layers (1 to 1 mapping with self.inputs,
  # # hence the same layer might appear twice)
  # self._input_layers = []
  # self._input_layers_node_indices = []
  # self._input_layers_tensor_indices = []
  # # list of layers (1 to 1 mapping with self.inputs,
  # # hence the same layer might appear twice)
  # self._output_layers = []
  # self._output_layers_node_indices = []
  # self._output_layers_tensor_indices = []

  self._input_layers = []
  self._output_layers = []
  self._input_coordinates = []
  self._output_coordinates = []

  # This is for performance optimization when calling the GraphNetwork on new
  # inputs. Every time the GraphNetwork is called on a set on input tensors,
  # we compute the output tensors, output masks and output shapes in one
  # pass, then cache them here. When any of these outputs is queried later,
  # we retrieve it from there instead of recomputing it.
  self._output_mask_cache = {}
  self._output_tensor_cache = {}
  self._output_shape_cache = {}

  # User-provided arguments validation.
  for x in self.inputs:
    # Check that x has appropriate `_keras_history` metadata.
    if not hasattr(x, '_keras_history'):
      cls_name = self.__class__.__name__
      raise ValueError('Input tensors to a ' + cls_name + ' ' +
                       'must come from `tf.layers.Input`. '
                       'Received: ' + str(x) +
                       ' (missing previous layer metadata).')
    # Check that x is an input tensor.
    # pylint: disable=protected-access
    layer, node_index, tensor_index = x._keras_history
    if len(layer._inbound_nodes) > 1 or (
        layer._inbound_nodes and layer._inbound_nodes[0].inbound_layers):
      cls_name = self.__class__.__name__
      logging.warning(cls_name + ' inputs must come from '
                      '`tf.layers.Input` (thus holding past layer metadata), '
                      'they cannot be the output of '
                      'a previous non-Input layer. '
                      'Here, a tensor specified as '
                      'input to "' + self.name + '" was not an Input tensor, '
                      'it was generated by layer ' + layer.name + '.\n'
                      'Note that input tensors are '
                      'instantiated via `tensor = tf.layers.Input(shape)`.\n'
                      'The tensor that caused the issue was: ' + str(x.name))
    # pylint: enable=protected-access
  for x in self.outputs:
    if not hasattr(x, '_keras_history'):
      cls_name = self.__class__.__name__
      raise ValueError('Output tensors to a ' + cls_name + ' must be '
                       'the output of a TensorFlow `Layer` '
                       '(thus holding past layer metadata). Found: ' + str(x))

  # Build self._output_layers:
  for x in self.outputs:
    layer, node_index, tensor_index = x._keras_history  # pylint: disable=protected-access
    self._output_layers.append(layer)
    self._output_coordinates.append((layer, node_index, tensor_index))

  # Build self._input_layers:
  for x in self.inputs:
    layer, node_index, tensor_index = x._keras_history  # pylint: disable=protected-access
    # It's supposed to be an input layer, so only one node
    # and one tensor output.
    assert node_index == 0
    assert tensor_index == 0
    self._input_layers.append(layer)
    self._input_coordinates.append((layer, node_index, tensor_index))

  # Network_nodes: set of nodes included in the graph
  # (not all nodes included in the layers
  # are relevant to the current graph).
  network_nodes = set()  # ids of all nodes relevant to the GraphNetwork
  nodes_depths = {}  # dict {node: depth value}
  layers_depths = {}  # dict {layer: depth value}
  layer_indices = {}  # dict {layer: index in traversal}
  nodes_in_decreasing_depth = []

  def build_map_of_graph(tensor,
                         finished_nodes,
                         nodes_in_progress,
                         layer,
                         node_index,
                         tensor_index):
    """Builds a map of the graph of layers.

    This recursively updates the map `layer_indices`,
    the list `nodes_in_decreasing_depth` and the set `network_nodes`.

    Arguments:
      tensor: Some tensor in a graph.
      finished_nodes: Set of nodes whose subgraphs have been traversed
        completely. Useful to prevent duplicated work.
      nodes_in_progress: Set of nodes that are currently active on the
        recursion stack. Useful to detect cycles.
      layer: Layer from which `tensor` comes from. If not provided,
        will be obtained from `tensor._keras_history`.
      node_index: Node index from which `tensor` comes from.
      tensor_index: Tensor_index from which `tensor` comes from.

    Raises:
      ValueError: if a cycle is detected.
    """
    node = layer._inbound_nodes[node_index]  # pylint: disable=protected-access

    # Prevent cycles.
    if node in nodes_in_progress:
      raise ValueError('The tensor ' + str(tensor) + ' at layer "' +
                       layer.name + '" is part of a cycle.')

    # Don't repeat work for shared subgraphs
    if node in finished_nodes:
      return

    node_key = _make_node_key(layer.name, node_index)
    # Update network_nodes.
    network_nodes.add(node_key)

    # Store the traversal order for layer sorting.
    if layer not in layer_indices:
      layer_indices[layer] = len(layer_indices)

    nodes_in_progress.add(node)

    # Propagate to all previous tensors connected to this node.
    for i in range(len(node.inbound_layers)):
      x = node.input_tensors[i]
      layer = node.inbound_layers[i]
      node_index = node.node_indices[i]
      tensor_index = node.tensor_indices[i]
      build_map_of_graph(x, finished_nodes, nodes_in_progress, layer,
                         node_index, tensor_index)

    finished_nodes.add(node)
    nodes_in_progress.remove(node)

    nodes_in_decreasing_depth.append(node)

  finished_nodes = set()
  nodes_in_progress = set()
  for x in self.outputs:
    layer, node_index, tensor_index = x._keras_history  # pylint: disable=protected-access
    build_map_of_graph(x, finished_nodes, nodes_in_progress,
                       layer=layer,
                       node_index=node_index,
                       tensor_index=tensor_index)

  for node in reversed(nodes_in_decreasing_depth):
    # If the depth is not set, the node has no outbound nodes (depth 0).
    depth = nodes_depths.setdefault(node, 0)

    # Update the depth of the corresponding layer
    previous_depth = layers_depths.get(node.outbound_layer, 0)
    # If we've seen this layer before at a higher depth,
    # we should use that depth instead of the node depth.
    # This is necessary for shared layers that have inputs at different
    # depth levels in the graph.
    depth = max(depth, previous_depth)
    layers_depths[node.outbound_layer] = depth
    nodes_depths[node] = depth

    # Update the depth of inbound nodes.
    # The "depth" of a node is the max of the depths
    # of all layers it is connected to.
    for i in range(len(node.inbound_layers)):
      inbound_layer = node.inbound_layers[i]
      node_index = node.node_indices[i]
      inbound_node = inbound_layer._inbound_nodes[node_index]  # pylint: disable=protected-access
      previous_depth = nodes_depths.get(inbound_node, 0)
      nodes_depths[inbound_node] = max(depth + 1, previous_depth)

  # Build a dict {depth: list of nodes with this depth}
  nodes_by_depth = {}
  for node, depth in nodes_depths.items():
    if depth not in nodes_by_depth:
      nodes_by_depth[depth] = []
    nodes_by_depth[depth].append(node)

  # Build a dict {depth: list of layers with this depth}
  layers_by_depth = {}
  for layer, depth in layers_depths.items():
    if depth not in layers_by_depth:
      layers_by_depth[depth] = []
    layers_by_depth[depth].append(layer)

  # Get sorted list of layer depths.
  depth_keys = list(layers_by_depth.keys())
  depth_keys.sort(reverse=True)

  # Set self.layers and self._layers_by_depth.
  layers = []
  for depth in depth_keys:
    layers_for_depth = layers_by_depth[depth]
    # GraphNetwork.layers needs to have a deterministic order:
    # here we order them by traversal order.
    layers_for_depth.sort(key=lambda x: layer_indices[x])
    layers.extend(layers_for_depth)
  self.layers = layers
  self._layers_by_depth = layers_by_depth

  # Get sorted list of node depths.
  depth_keys = list(nodes_by_depth.keys())
  depth_keys.sort(reverse=True)

  # Check that all tensors required are computable.
  # computable_tensors: all tensors in the graph
  # that can be computed from the inputs provided.
  computable_tensors = []
  for x in self.inputs:
    computable_tensors.append(x)

  layers_with_complete_input = []  # To provide a better error msg.
  for depth in depth_keys:
    for node in nodes_by_depth[depth]:
      layer = node.outbound_layer
      if layer:
        for x in node.input_tensors:
          if x not in computable_tensors:
            raise ValueError('Graph disconnected: '
                             'cannot obtain value for tensor ' + str(x) +
                             ' at layer "' + layer.name + '". '
                             'The following previous layers '
                             'were accessed without issue: ' +
                             str(layers_with_complete_input))
        for x in node.output_tensors:
          computable_tensors.append(x)
        layers_with_complete_input.append(layer.name)

  # Keep track of the network's nodes.
  self._network_nodes = network_nodes
  self._nodes_by_depth = nodes_by_depth

  # Ensure name unicity, which will be crucial for serialization
  # (since serialized nodes refer to layers by their name).
  all_names = [layer.name for layer in self.layers]
  for name in all_names:
    if all_names.count(name) != 1:
      raise ValueError('The name "' + name + '" is used ' +
                       str(all_names.count(name)) + ' times in the model. '
                       'All layer names should be unique.')

  # Layer parameters.
  # The new network starts with a single inbound node
  # for its inputs, and no outbound nodes.
  self._outbound_nodes = []  # Will be appended to by future calls to __call__
  self._inbound_nodes = []  # Will be appended to below, and by future calls
                            # to __call__
  # Create the node linking internal inputs to internal outputs.
  base.Node(
      outbound_layer=self,
      inbound_layers=[],
      node_indices=[],
      tensor_indices=[],
      input_tensors=self.inputs,
      output_tensors=self.outputs)
def _testDecoderGeneric(self,
                        decoder,
                        with_beam_search=False,
                        with_alignment_history=False,
                        support_alignment_history=True):
  batch_size = 4
  beam_width = 5
  num_hyps = beam_width if with_beam_search else 1
  vocab_size = 10
  depth = 6
  end_token = 2
  start_tokens = tf.placeholder_with_default([1] * batch_size, shape=[None])
  memory_sequence_length = [3, 7, 5, 4]
  memory_time = max(memory_sequence_length)
  memory = tf.placeholder_with_default(
      np.random.randn(batch_size, memory_time, depth).astype(np.float32),
      shape=(None, None, depth))
  memory_sequence_length = tf.placeholder_with_default(
      memory_sequence_length, shape=[None])
  embedding = tf.placeholder_with_default(
      np.random.randn(vocab_size, depth).astype(np.float32),
      shape=(vocab_size, depth))

  if with_beam_search:
    decode_fn = decoder.dynamic_decode_and_search
  else:
    decode_fn = decoder.dynamic_decode

  additional_kwargs = {}
  if with_alignment_history:
    additional_kwargs["return_alignment_history"] = True
  if with_beam_search:
    additional_kwargs["beam_width"] = beam_width

  if (with_beam_search and with_alignment_history and
      "RNN" in decoder.__class__.__name__ and
      "reorder_tensor_arrays" not in fn_args(
          tf.contrib.seq2seq.BeamSearchDecoder.__init__)):
    support_alignment_history = False

  outputs = decode_fn(
      embedding,
      start_tokens,
      end_token,
      vocab_size=vocab_size,
      maximum_iterations=10,
      memory=memory,
      memory_sequence_length=memory_sequence_length,
      **additional_kwargs)

  ids = outputs[0]
  state = outputs[1]
  lengths = outputs[2]
  log_probs = outputs[3]
  decode_time = tf.shape(ids)[-1]

  with self.test_session() as sess:
    sess.run(tf.global_variables_initializer())

  if not with_alignment_history:
    self.assertEqual(4, len(outputs))
  else:
    self.assertEqual(5, len(outputs))
    alignment_history = outputs[4]
    if support_alignment_history:
      self.assertIsInstance(alignment_history, tf.Tensor)
      with self.test_session() as sess:
        alignment_history, decode_time = sess.run(
            [alignment_history, decode_time])
      self.assertAllEqual(
          [batch_size, num_hyps, decode_time, memory_time],
          alignment_history.shape)
    else:
      self.assertIsNone(alignment_history)

  with self.test_session() as sess:
    ids, lengths, log_probs = sess.run([ids, lengths, log_probs])

  self.assertAllEqual([batch_size, num_hyps], ids.shape[0:2])
  self.assertAllEqual([batch_size, num_hyps], lengths.shape)
  self.assertAllEqual([batch_size, num_hyps], log_probs.shape)