def add(self, layer_func):
    """Appends a layer or plain callable to this sequential container.

    Args:
      layer_func: A `base.Layer` instance (whose `call` signature is inspected
        and which is tracked via `track_layer`) or any other callable.

    Raises:
      TypeError: If `layer_func` is neither a `base.Layer` nor callable.
    """
    if isinstance(layer_func, base.Layer):
        # For layers, the user-facing signature lives on `call`, not `__call__`.
        arg_names = function_utils.fn_args(layer_func.call)
        self.track_layer(layer_func)
    elif not callable(layer_func):
        raise TypeError(
            "Sequential.add() takes only tf.layers.Layer objects or callables; "
            "not '%s' of type '%s'." % (layer_func, type(layer_func)))
    else:
        arg_names = function_utils.fn_args(layer_func)
    # Store whether the callable accepts a `training` argument next to it.
    accepts_training = "training" in arg_names
    self._layers_funcs.append((accepts_training, layer_func))
def run_step_fn(self, step_fn):
    """Run ops using a step function.

    Args:
      step_fn: A function or a method with a single argument of type
        `StepContext`. The function may use methods of the argument to
        perform computations with access to a raw session. The value returned
        by `step_fn` is returned from `run_step_fn`, unless a stop is
        requested; in that case the next `should_stop` call returns True.

    Example usage:

    ```python
    with tf.Graph().as_default():
      c = tf.placeholder(dtypes.float32)
      v = tf.add(c, 4.0)
      w = tf.add(c, 0.5)
      def step_fn(step_context):
        a = step_context.session.run(fetches=v, feed_dict={c: 0.5})
        if a <= 4.5:
          step_context.request_stop()
          return step_context.run_with_hooks(fetches=w,
                                             feed_dict={c: 0.1})

      with tf.MonitoredSession() as session:
        while not session.should_stop():
          a = session.run_step_fn(step_fn)
    ```

    Hooks interact with the `run_with_hooks()` call inside `step_fn` the same
    way they do with a `MonitoredSession.run` call.

    Returns:
      The value returned by `step_fn`.

    Raises:
      StopIteration: if `step_fn` has called `request_stop()`. It may be
        caught by `with tf.MonitoredSession()` to close the session.
      ValueError: if `step_fn` doesn't have a single argument called
        `step_context`. It may also optionally have `self` for cases when it
        belongs to an object.
    """
    accepted_signatures = (('step_context',), ('self', 'step_context'))
    step_fn_arguments = function_utils.fn_args(step_fn)
    if step_fn_arguments not in accepted_signatures:
        raise ValueError(
            '`step_fn` may either have one `step_context` argument, or'
            ' `self` and `step_context` arguments if it\'s an instance'
            ' method. Got {} instead.'.format(step_fn_arguments))

    # `self._sess` is either a `_RecoverableSession` or a `_CoordinatedSession`.
    # Passing `run_with_hooks=None` makes `run_with_hooks` resolve to
    # `_CoordinatedSession.run` downstream in either case, which lets
    # `_PREEMPTION_ERRORS` propagate from within `step_fn` up to
    # `_RecoverableSession.run_step_fn`.
    return self._sess.run_step_fn(step_fn, self._tf_sess(), run_with_hooks=None)
def call(*args):
    """Proxies method `name` of `self._type` through a `tf.py_func` op.

    Positional `args` are matched to the method's parameter names (skipping
    the first, presumably `self`) to look up tensor specs. The op sends
    `(name, *args)` over `self._out` and returns the received reply.

    Returns:
      A `tf.Operation` if `tf.py_func` produced one, otherwise the result
      tensors (with static shapes set) packed into the structure of the specs.

    Raises:
      ValueError: If no tensor specifications exist for `name`.
    """
    method_arg_names = function_utils.fn_args(getattr(self._type, name))[1:]
    kwargs = dict(zip(method_arg_names, args))
    specs = self._type._tensor_specs(name, kwargs, self._constructor_kwargs)
    if specs is None:
        raise ValueError(
            'No tensor specifications were provided for: %s' % name)
    flat_dtypes = nest.flatten(nest.map_structure(lambda spec: spec.dtype, specs))
    flat_shapes = nest.flatten(nest.map_structure(lambda spec: spec.shape, specs))

    def py_call(*call_args):
        try:
            self._out.send(call_args)
            reply = self._out.recv()
            if isinstance(reply, Exception):
                raise reply
            if reply is not None:
                return reply
        except IOError:
            # A closed pipe means the worker process went away.
            raise StopIteration()  # Clean exit.

    result = tf.py_func(py_call, (name,) + tuple(args), flat_dtypes, name=name)
    if isinstance(result, tf.Operation):
        return result
    for tensor, shape in zip(result, flat_shapes):
        tensor.set_shape(shape)
    return nest.pack_sequence_as(specs, result)
def eval_step():
    """A single step of evaluation.

    Calls the model function in EVAL mode and captures its `scaffold_fn` and
    metric function (or None when the spec does not provide them).

    Returns:
      A tuple `(loss, *eval_metric_fn_tensors)`. Dict-valued metric tensors
      are ordered by the positional arguments of the metric function.
    """
    estimator_spec = self._call_model_fn(features, labels,
                                         model_fn_lib.ModeKeys.EVAL, params)
    try:
        captured_scaffold_fn.capture(estimator_spec.scaffold_fn)
    except AttributeError:
        # Specs without `scaffold_fn` (e.g. plain EstimatorSpec).
        captured_scaffold_fn.capture(None)

    eval_metric_fn = None
    eval_metric_fn_tensors = []
    try:
        if estimator_spec.eval_metrics:
            (eval_metric_fn, eval_metric_fn_tensors) = estimator_spec.eval_metrics
    except AttributeError:
        pass

    # If a dictionary is provided, we need to convert it into a list sorted
    # according to order of eval_metric_fn positional arguments.
    if isinstance(eval_metric_fn_tensors, dict):
        eval_metric_fn_args = function_utils.fn_args(eval_metric_fn)
        eval_metric_fn_tensors = [
            eval_metric_fn_tensors[i] for i in eval_metric_fn_args
        ]

    captured_eval_metric_fn.capture(eval_metric_fn)
    # The tensors may arrive as a tuple; `[loss] + tuple` would raise
    # TypeError, so convert to a tuple before concatenating.
    return (estimator_spec.loss,) + tuple(eval_metric_fn_tensors)
def test_callable(self):
    """fn_args on a callable instance reports `__call__`'s args minus `self`."""

    class Adder(object):

        def __call__(self, a, b):
            return a + b

    adder = Adder()
    self.assertEqual(('a', 'b'), function_utils.fn_args(adder))
def test_bounded_method(self):
    """fn_args on a bound method omits the implicit `self` argument."""

    class Holder(object):

        def bar(self, a, b):
            return a + b

    bound_bar = Holder().bar
    self.assertEqual(('a', 'b'), function_utils.fn_args(bound_bar))
def __init__(self, type_, *constructor_args, **constructor_kwargs):
    """Records how to construct `type_` and registers this object.

    Args:
      type_: The class to be proxied.
      *constructor_args: Positional constructor arguments. They are mapped to
        the parameter names of `type_.__init__` (excluding its first one).
      **constructor_kwargs: Keyword constructor arguments; these override
        same-named positional ones.
    """
    self._type = type_
    init_param_names = function_utils.fn_args(type_.__init__)[1:]
    merged_kwargs = dict(zip(init_param_names, constructor_args))
    merged_kwargs.update(constructor_kwargs)
    self._constructor_kwargs = merged_kwargs
    tf.add_to_collection(PyProcess.COLLECTION, self)
    self._proxy = _TFProxy(type_, self._constructor_kwargs)
def _get_standardized_predicate_fn(predicate_fn):
    """Returns a predicate taking `(eval_results, checkpoint_path)`.

    If `predicate_fn` already declares `checkpoint_path`, it is returned
    unchanged. Otherwise it is wrapped so that `checkpoint_path` is accepted
    and ignored.
    """
    if "checkpoint_path" in function_utils.fn_args(predicate_fn):
        return predicate_fn

    def _pred_fn_wrapper(eval_results, checkpoint_path):  # pylint: disable=unused-argument
        return predicate_fn(eval_results)

    return _pred_fn_wrapper
def _verify_estimator_spec(self, estimator_spec):
    """Verifies estimator spec contains correct data.

    Warns if `scaffold` is set and rejects `eval_metric_ops`, since neither is
    honored under XLA compilation. In EVAL mode with dict-valued evaluation
    tensors, the dict keys must exactly match the metric function's argument
    names.

    Args:
      estimator_spec: The spec returned by the model function.

    Returns:
      `estimator_spec`, unchanged.

    Raises:
      ValueError: If `eval_metric_ops` is set, or if dict-valued evaluation
        tensors are missing arguments or provide extra ones.
    """
    # TODO(ycao): Implement estimator spec verification for other modes.
    try:
        if estimator_spec.scaffold:
            logging.warning('EstimatorSpec.scaffold is ignored with XLA compilation'
                            '. Please use TPUEstimatorSpec.scaffold_fn instead.')
    except AttributeError:
        pass  # Spec type has no `scaffold`; nothing to warn about.

    try:
        if estimator_spec.eval_metric_ops:
            raise ValueError('EstimatorSpec.eval_metric_ops is not supported with '
                             'XLA compilation. Please use '
                             'TPUEstimatorSpec.eval_metrics instead.')
    except AttributeError:
        pass  # Spec type has no `eval_metric_ops`.

    if estimator_spec.mode != model_fn_lib.ModeKeys.EVAL:
        return estimator_spec

    # Only TPUEstimatorSpec-like objects carry `eval_metrics`.
    try:
        eval_metrics = estimator_spec.eval_metrics
    except AttributeError:
        eval_metrics = None
    if not eval_metrics:
        return estimator_spec

    metric_fn, metric_tensors = eval_metrics
    metric_fn_args = function_utils.fn_args(metric_fn)
    if not isinstance(metric_tensors, dict):
        return estimator_spec

    missing = [arg for arg in metric_fn_args if arg not in metric_tensors]
    if missing:
        raise ValueError('Arguments %s are needed by metric_fn (first '
                         'element of TPUEstimatorSpec.eval_metrics) but '
                         'they are not provided by evaluation tensors '
                         '(second element of TPUEstimatorSpec.eval_metrics)'
                         '.' % missing)
    extra = [key for key in metric_tensors if key not in metric_fn_args]
    if extra:
        raise ValueError('Arguments %s are provided by evaluation tensors '
                         '(second element of TPUEstimatorSpec.eval_metrics)'
                         ' but they are not needed by metric_fn (first '
                         'element of TPUEstimatorSpec.eval_metrics).' %
                         extra)
    return estimator_spec
def test_partial_function(self):
    """fn_args drops a keyword argument bound by functools.partial."""
    expected_test_arg = 123

    def fn(a, test_arg):
        if test_arg != expected_test_arg:
            # Raise (not return) so a mis-bound argument actually fails.
            raise ValueError('partial fn does not work correctly')
        return a

    wrapped_fn = functools.partial(fn, test_arg=expected_test_arg)
    self.assertEqual(('a',), function_utils.fn_args(wrapped_fn))
    # Exercise the partial so the binding check above really runs.
    self.assertEqual(3, wrapped_fn(3))
    self.assertEqual(3, wrapped_fn(a=3))
def test_double_partial(self):
    """fn_args drops keyword arguments bound across two nested partials."""
    expected_test_arg1 = 123
    expected_test_arg2 = 456

    def fn(a, test_arg1, test_arg2):
        if test_arg1 != expected_test_arg1 or test_arg2 != expected_test_arg2:
            # Raise (not return) so a mis-bound argument actually fails.
            raise ValueError('partial does not work correctly')
        return a

    wrapped_fn = functools.partial(fn, test_arg2=expected_test_arg2)
    double_wrapped_fn = functools.partial(wrapped_fn,
                                          test_arg1=expected_test_arg1)
    self.assertEqual(('a',), function_utils.fn_args(double_wrapped_fn))
    # Exercise the partial so the binding check above really runs.
    self.assertEqual(3, double_wrapped_fn(3))
    self.assertEqual(3, double_wrapped_fn(a=3))
def _call_metric_fn(metric_fn, features, labels, predictions, config):
    """Calls metric fn with proper arguments.

    Each of `features`, `labels`, `predictions` and `config` is passed as a
    keyword argument only if `metric_fn` declares a parameter of that name.

    Returns:
      Whatever `metric_fn` returns.
    """
    declared = function_utils.fn_args(metric_fn)
    candidates = (
        ('features', features),
        ('labels', labels),
        ('predictions', predictions),
        ('config', config),
    )
    kwargs = {key: value for key, value in candidates if key in declared}
    return metric_fn(**kwargs)
def test_partial_function_with_positional_args(self):
    """fn_args drops a leading positional argument bound by a partial."""
    expected_test_arg = 123

    def fn(test_arg, a):
        if test_arg != expected_test_arg:
            # Raise (not return) so a mis-bound argument actually fails.
            raise ValueError('partial fn does not work correctly')
        return a

    wrapped_fn = functools.partial(fn, expected_test_arg)
    self.assertEqual(('a',), function_utils.fn_args(wrapped_fn))
    self.assertEqual(3, wrapped_fn(3))
    self.assertEqual(3, wrapped_fn(a=3))
def test_do_not_convert_argspec(self):
    """Checks that a method wrapped by `api.do_not_convert` reports no args."""

    class TestClass(object):

        def test_method(self, x, y):
            return x + y

        unconverted_method = api.do_not_convert(test_method)

    instance = TestClass()
    self.assertTrue(tf_inspect.ismethod(instance.unconverted_method))
    # The wrapper is not a generated function, so its arg spec can't be
    # preserved and fn_args sees no named arguments.
    reported_args = tuple(function_utils.fn_args(instance.unconverted_method))
    self.assertEqual((), reported_args)
def _validate_properties(run_config):
    """Validates the properties of `run_config`.

    Each property is checked only when it is not None; the first failing
    check raises.

    Args:
      run_config: The config object whose attributes are validated.

    Raises:
      ValueError: If a set property fails its check.
    """

    def _validate(property_name, cond, message):
        property_value = getattr(run_config, property_name)
        if property_value is not None and not cond(property_value):
            raise ValueError(message)

    # Lambda parameter named to avoid shadowing the `dir` builtin.
    _validate('model_dir', lambda model_dir: model_dir,
              message='model_dir should be non-empty')
    _validate('save_summary_steps', lambda steps: steps >= 0,
              message='save_summary_steps should be >= 0')
    _validate('save_checkpoints_steps', lambda steps: steps >= 0,
              message='save_checkpoints_steps should be >= 0')
    _validate('save_checkpoints_secs', lambda secs: secs >= 0,
              message='save_checkpoints_secs should be >= 0')
    _validate('session_config',
              lambda sc: isinstance(sc, config_pb2.ConfigProto),
              message='session_config must be instance of ConfigProto')
    _validate('keep_checkpoint_max', lambda keep_max: keep_max >= 0,
              message='keep_checkpoint_max should be >= 0')
    _validate('keep_checkpoint_every_n_hours', lambda keep_hours: keep_hours > 0,
              message='keep_checkpoint_every_n_hours should be > 0')
    _validate('log_step_count_steps', lambda num_steps: num_steps > 0,
              message='log_step_count_steps should be > 0')
    _validate('tf_random_seed', lambda seed: isinstance(seed, six.integer_types),
              message='tf_random_seed must be integer.')
    _validate('device_fn', lambda device_fn: six.callable(device_fn) and set(
        function_utils.fn_args(device_fn)) == _VALID_DEVICE_FN_ARGS,
              message='device_fn must be callable with exactly'
                      ' one argument "op".')
    _validate('protocol',
              lambda protocol: protocol in (None, "grpc", "grpc+verbs"),
              message='protocol should be grpc or grpc+verbs')
def _call_adanet_model_fn(self, input_fn, mode, config):
    """See the `Estimator` base class for details."""
    # The parent expects a zero-argument input_fn, so pre-bind whichever of
    # `mode`, `params` and `config` this input_fn declares. The thunks keep
    # `self.params` from being read unless it is actually needed.
    declared = function_utils.fn_args(input_fn)
    suppliers = (
        ("mode", lambda: mode),
        ("params", lambda: self.params),
        ("config", lambda: config),
    )
    bound = {key: supply() for key, supply in suppliers if key in declared}
    bound_input_fn = functools.partial(input_fn, **bound)
    super(TPUEstimator, self)._call_adanet_model_fn(bound_input_fn, mode, config)
def _call_model_fn(self, features, labels, mode, config):
    """Calls `self._model_fn`, passing only the arguments it declares.

    `features` is always passed. `labels`, `mode`, `params` (read from
    `self.params`) and `config` are passed as keywords only when the model
    function's signature names them.

    Returns:
      The model function's return value, unchanged.
    """
    declared_args = function_utils.fn_args(self._model_fn)
    suppliers = (
        ('labels', lambda: labels),
        ('mode', lambda: mode),
        ('params', lambda: self.params),
        ('config', lambda: config),
    )
    kwargs = {key: supply() for key, supply in suppliers if key in declared_args}
    return self._model_fn(features=features, **kwargs)
def test_double_partial_with_positional_args_in_both_layers(self):
    """fn_args drops positional args bound in both partial layers."""
    expected_test_arg1 = 123
    expected_test_arg2 = 456

    def fn(test_arg1, test_arg2, a):
        if test_arg1 != expected_test_arg1 or test_arg2 != expected_test_arg2:
            # Raise (not return) so a mis-bound argument actually fails.
            raise ValueError('partial fn does not work correctly')
        return a

    wrapped_fn = functools.partial(fn, expected_test_arg1)  # binds to test_arg1
    double_wrapped_fn = functools.partial(wrapped_fn, expected_test_arg2)  # binds to test_arg2
    self.assertEqual(('a',), function_utils.fn_args(double_wrapped_fn))
    self.assertEqual(3, double_wrapped_fn(3))  # pylint: disable=no-value-for-parameter
    self.assertEqual(3, double_wrapped_fn(a=3))  # pylint: disable=no-value-for-parameter
def test_double_partial_with_positional_args_in_both_layers(self):
    """fn_args drops positional args bound in both partial layers."""
    expected_test_arg1 = 123
    expected_test_arg2 = 456

    def fn(test_arg1, test_arg2, a):
        if test_arg1 != expected_test_arg1 or test_arg2 != expected_test_arg2:
            # Raise (not return) so a mis-bound argument actually fails.
            raise ValueError('partial fn does not work correctly')
        return a

    wrapped_fn = functools.partial(fn, expected_test_arg1)  # binds to test_arg1
    double_wrapped_fn = functools.partial(wrapped_fn, expected_test_arg2)  # binds to test_arg2
    self.assertEqual(('a',), function_utils.fn_args(double_wrapped_fn))
    self.assertEqual(3, double_wrapped_fn(3))
    self.assertEqual(3, double_wrapped_fn(a=3))
def _validate_function_args(function, required_args):
    """Asserts that `function` takes exactly the arguments in `required_args`.

    Args:
      function: A Python callable (plain function, `functools.partial`,
        method or callable object) inspected with `function_utils.fn_args`.
      required_args: (list) A list of strings naming the required args.

    Raises:
      ValueError: If the set of argument names of `function` differs from
        `required_args`, whether through missing or extra arguments.
    """
    fn_args = function_utils.fn_args(function)
    if set(fn_args) != set(required_args):
        # `functools.partial` objects and callable instances have no
        # `__name__`; fall back to repr so the intended ValueError is raised
        # instead of an AttributeError.
        function_name = getattr(function, '__name__', repr(function))
        raise ValueError(
            "Function `%s` needs to have the following arguments: %s."
            " What were provided are the following: %s." %
            (function_name, sorted(required_args), sorted(fn_args)))
def test_double_partial_with_positional_args_in_outer_layer(self):
    """fn_args handles a keyword-bound inner partial wrapped positionally."""
    expected_test_arg1 = 123
    expected_test_arg2 = 456

    def fn(test_arg1, a, test_arg2):
        if test_arg1 != expected_test_arg1 or test_arg2 != expected_test_arg2:
            # Raise (not return) so a mis-bound argument actually fails.
            raise ValueError('partial fn does not work correctly')
        return a

    wrapped_fn = functools.partial(fn, test_arg2=expected_test_arg2)
    double_wrapped_fn = functools.partial(wrapped_fn, expected_test_arg1)
    self.assertEqual(('a',), function_utils.fn_args(double_wrapped_fn))
    self.assertEqual(3, double_wrapped_fn(3))
    self.assertEqual(3, double_wrapped_fn(a=3))
def verify_model_fn_args(model_fn, params):
    """Verifies `model_fn` arguments.

    Args:
      model_fn: The model function to inspect.
      params: Hyperparameters passed to the Estimator, or None.

    Raises:
      ValueError: If `model_fn` lacks a `features` argument, if `params` is
        given but `model_fn` has no `params` argument, or if `model_fn` has
        arguments outside `_VALID_MODEL_FN_ARGS`.
    """
    args = set(function_utils.fn_args(model_fn))
    if 'features' not in args:
        raise ValueError('model_fn (%s) must include features argument.' %
                         model_fn)
    if params is not None and 'params' not in args:
        raise ValueError('model_fn (%s) does not include params argument, '
                         'but params (%s) is passed to Estimator.' %
                         (model_fn, params))
    if params is None and 'params' in args:
        # `warning` is the canonical name; `warn` is only a legacy alias.
        tf.compat.v1.logging.warning(
            'Estimator\'s model_fn (%s) includes params '
            'argument, but params are not passed to Estimator.', model_fn)
    # Sort so the error message is deterministic (set order is not).
    non_valid_args = sorted(args - _VALID_MODEL_FN_ARGS)
    if non_valid_args:
        raise ValueError('model_fn (%s) has following not expected args: %s' %
                         (model_fn, non_valid_args))
def test_do_not_convert_preserves_argspec(self):
    """do_not_convert(run_as=GRAPH) keeps the wrapped method's signature."""

    class TestClass(object):

        @api.do_not_convert(run_as=api.RunMode.GRAPH)
        def test_method(self, x, y):
            z = x + y
            return z

        test_method_do_not_convert = api.do_not_convert(
            run_as=api.RunMode.GRAPH)(test_method)

    instance = TestClass()
    unconverted = instance.test_method_do_not_convert
    self.assertTrue(tf_inspect.ismethod(unconverted))
    self.assertAllEqual(('x', 'y'), function_utils.fn_args(unconverted))
    original_spec = list(tf_inspect.getfullargspec(instance.test_method))
    wrapped_spec = list(tf_inspect.getfullargspec(unconverted))
    self.assertListEqual(original_spec, wrapped_spec)
def _wrap_and_verify_model_fn(model_fn,
                              mode=None,
                              config=None,
                              params=None,
                              input_signature=None):
    """Returns a function that only has tensor arguments (features, labels).

    Args:
      model_fn: Model function. Must follow the signature defined in
        `tf.estimator.Estimator`.
      mode: Optional string `tf.estimator.ModeKey`.
      config: Optional `estimator.RunConfig` object.
      params: Optional `dict` of hyperparameters.
      input_signature: Possibly nested TensorSpec of the tensor arguments.

    Returns:
      A tuple of (function that only accepts tensor arguments (features and/or
      labels), whether the returned function expects a labels argument).
    """
    model_fn_lib.verify_model_fn_args(model_fn, params)
    args = function_utils.fn_args(model_fn)
    bindable = (('mode', mode), ('params', params), ('config', config))
    kwargs = {key: value for key, value in bindable if key in args}
    expects_labels = 'labels' in args

    if not expects_labels:
        def wrapped_model_fn(features):
            return model_fn(features=features, **kwargs)
    elif input_signature is None or len(input_signature) == 2:
        # The signature includes labels, so expose them (optionally).
        def wrapped_model_fn(features, labels=None):
            return model_fn(features=features, labels=labels, **kwargs)
    else:
        # model_fn takes labels but the signature provides only features.
        def wrapped_model_fn(features):
            return model_fn(features=features, labels=None, **kwargs)
    return wrapped_model_fn, expects_labels
def call_logit_fn(logit_fn, features, mode, params, config):
    """Calls logit_fn (experimental).

    THIS FUNCTION IS EXPERIMENTAL. Keras layers/models are the recommended APIs
    for logit and model composition.

    A utility function that calls the provided logit_fn with the relevant
    subset of provided arguments. Similar to tf.estimator._call_model_fn().

    Args:
      logit_fn: A logit_fn as defined above.
      features: The features dict.
      mode: TRAIN / EVAL / PREDICT ModeKeys.
      params: The hyperparameter dict.
      config: The configuration object.

    Returns:
      A logit Tensor, the output of logit_fn.

    Raises:
      ValueError: if logit_fn does not return a Tensor or a dictionary mapping
        strings to Tensors.
    """
    logit_fn_args = function_utils.fn_args(logit_fn)
    kwargs = {}
    if 'mode' in logit_fn_args:
        kwargs['mode'] = mode
    if 'params' in logit_fn_args:
        kwargs['params'] = params
    if 'config' in logit_fn_args:
        kwargs['config'] = config
    logit_fn_results = logit_fn(features=features, **kwargs)

    # Generator (not list) so all() short-circuits on the first bad entry.
    result_is_valid_dictionary = (
        isinstance(logit_fn_results, dict) and
        all(isinstance(k, six.string_types) and isinstance(v, tf.Tensor)
            for k, v in six.iteritems(logit_fn_results)))
    result_is_tensor = isinstance(logit_fn_results, tf.Tensor)

    if not (result_is_valid_dictionary or result_is_tensor):
        # Wrap in a 1-tuple: a tuple result would otherwise be consumed as the
        # `%` argument list and raise TypeError instead of this ValueError.
        raise ValueError('logit_fn should return a Tensor or a dictionary mapping '
                         'strings to Tensors.  logit_fn returned: %s' %
                         (logit_fn_results,))

    return logit_fn_results
def _call_input_fn_in_new_graph(self, input_fn, mode, config):
    """See the `Estimator` base class for details."""
    # Bind parameters to input_fn since the parent's input_fn is not expected to
    # have any arguments.
    from tensorflow.python.util import function_utils  # pylint: disable=g-direct-tensorflow-import,g-import-not-at-top
    declared_args = function_utils.fn_args(input_fn)
    # Thunks keep `self.params` from being read unless input_fn wants it.
    suppliers = (
        ("mode", lambda: mode),
        ("params", lambda: self.params),
        ("config", lambda: config),
    )
    bound_kwargs = {
        key: supply() for key, supply in suppliers if key in declared_args
    }
    bound_input_fn = functools.partial(input_fn, **bound_kwargs)
    parent = super(TPUEstimator, self)
    with parent._call_input_fn_in_new_graph(bound_input_fn, mode,
                                            config) as res:
        yield res
def _call_model_fn(self, features, labels, mode, config):
    """Calls `self._model_fn` and checks that it returns a `TFEstimatorSpec`.

    Only arguments named in the model function's signature are passed
    (`features` always; `labels`, `mode`, `params`, `config` as declared).

    Returns:
      The `TFEstimatorSpec` returned by the model function.

    Raises:
      ValueError: If the model function returns anything else.
    """
    declared_args = function_utils.fn_args(self._model_fn)
    suppliers = (
        ('labels', lambda: labels),
        ('mode', lambda: mode),
        ('params', lambda: self.params),
        ('config', lambda: config),
    )
    kwargs = {key: supply() for key, supply in suppliers if key in declared_args}
    model_fn_results = self._model_fn(features=features, **kwargs)
    if isinstance(model_fn_results, TFEstimatorSpec):
        return model_fn_results
    raise ValueError('model_fn should return an TFEstimatorSpec.')
def _call_model_fn(self, features, labels, mode, params):
    """Calls the model_fn with required parameters.

    Passes `labels`, `mode` and `params` only if the model function declares
    them, then verifies the resulting spec.

    Returns:
      The spec returned by `self._verify_estimator_spec`.

    Raises:
      ValueError: If labels are given but the model function takes none.
    """
    model_fn_args = function_utils.fn_args(self._model_fn)
    if labels is not None and 'labels' not in model_fn_args:
        raise ValueError(
            'model_fn does not take labels, but input_fn returns labels.')
    candidates = (('labels', labels), ('mode', mode), ('params', params))
    kwargs = {key: value for key, value in candidates if key in model_fn_args}
    spec = self._model_fn(features=features, **kwargs)
    return self._verify_estimator_spec(spec)
def add_train_op(model_fn, features, labels, mode, params, config, optimizer):
    """Calls `model_fn` and, in TRAIN mode, builds a train op if it lacks one.

    Args:
      model_fn: Model function; only the arguments it declares among
        `labels`, `mode`, `params` and `config` are passed.
      features: Features passed to `model_fn`.
      labels: Labels passed to `model_fn` (must be None if it takes none).
      mode: A `tf.estimator.ModeKeys` value.
      params: Hyperparameters passed to `model_fn`.
      config: Run configuration passed to `model_fn`.
      optimizer: Optimizer used to build `train_op` when the spec has none.

    Returns:
      A new `tf.estimator.EstimatorSpec` built from the model function's
      `predictions`, `loss`, `eval_metric_ops` and `export_outputs`, plus the
      (possibly newly built) `train_op`.

    Raises:
      ValueError: If `model_fn` takes no labels but `labels` is not None, if a
        train op is needed but `optimizer` is None, or if no variable
        receives a gradient.
    """
    model_fn_args = function_utils.fn_args(model_fn)
    kwargs = {}
    if 'labels' in model_fn_args:
        kwargs['labels'] = labels
    elif labels is not None:
        raise ValueError(
            'model_fn does not take labels, but input_fn returns labels.')
    if 'mode' in model_fn_args:
        kwargs['mode'] = mode
    if 'params' in model_fn_args:
        kwargs['params'] = params
    if 'config' in model_fn_args:
        kwargs['config'] = config
    spec = model_fn(features=features, **kwargs)

    if isinstance(spec, tf.estimator.EstimatorSpec):
        train_op = spec.train_op
    else:
        train_op = None

    if mode == tf.estimator.ModeKeys.TRAIN and train_op is None:
        if optimizer is None:
            raise ValueError(
                "optimizer should be set when used for training. For example:" +
                " Estimator(model_fn, tf.train.AdamOptimizer())")
        grads_and_vars = optimizer.compute_gradients(spec.loss)
        vars_with_grad = [v for g, v in grads_and_vars if g is not None]
        if not vars_with_grad:
            raise ValueError(
                "No gradients provided for any variable, check your graph for ops"
                " that do not support gradients, between variables %s and loss %s." %
                ([str(v) for _, v in grads_and_vars], spec.loss))
        # Use the `optimizer` argument: EstimatorSpec has no `optimizer`
        # attribute, so `spec.optimizer` raised AttributeError here.
        train_op = optimizer.apply_gradients(
            grads_and_vars, global_step=tf.train.get_global_step())

    # NOTE(review): scaffold and hooks on `spec` are not carried over.
    return tf.estimator.EstimatorSpec(mode, spec.predictions, spec.loss,
                                      train_op, spec.eval_metric_ops,
                                      spec.export_outputs)
def call_logit_fn(logit_fn, features, mode, params, config):
    """Calls logit_fn.

    A utility function that calls the provided logit_fn with the relevant
    subset of provided arguments. Similar to tf.estimator._call_model_fn().

    Args:
      logit_fn: A logit_fn as defined above.
      features: The features dict.
      mode: TRAIN / EVAL / PREDICT ModeKeys.
      params: The hyperparameter dict.
      config: The configuration object.

    Returns:
      A logit Tensor, the output of logit_fn.

    Raises:
      ValueError: if logit_fn does not return a Tensor or a dictionary mapping
        strings to Tensors.
    """
    logit_fn_args = function_utils.fn_args(logit_fn)
    kwargs = {}
    if 'mode' in logit_fn_args:
        kwargs['mode'] = mode
    if 'params' in logit_fn_args:
        kwargs['params'] = params
    if 'config' in logit_fn_args:
        kwargs['config'] = config
    logit_fn_results = logit_fn(features=features, **kwargs)

    # Generator (not list) so all() short-circuits on the first bad entry.
    result_is_valid_dictionary = (
        isinstance(logit_fn_results, dict) and
        all(isinstance(k, six.string_types) and isinstance(v, ops.Tensor)
            for k, v in six.iteritems(logit_fn_results)))
    result_is_tensor = isinstance(logit_fn_results, ops.Tensor)

    if not (result_is_valid_dictionary or result_is_tensor):
        # Wrap in a 1-tuple: a tuple result would otherwise be consumed as the
        # `%` argument list and raise TypeError instead of this ValueError.
        raise ValueError('logit_fn should return a Tensor or a dictionary mapping '
                         'strings to Tensors.  logit_fn returned: %s' %
                         (logit_fn_results,))

    return logit_fn_results
def call(*args):
    """Proxies method `name` of `self._type` through a `tf.py_func` op.

    Positional `args` are matched to the method's parameter names (skipping
    the first one) to look up tensor specs. The op sends `(name, *args)` over
    `self._out` and returns the reply it receives.

    Returns:
      A `tf.Operation` if `tf.py_func` produced one, otherwise the result
      tensors (with static shapes set) packed into the structure of the specs.

    Raises:
      ValueError: If no tensor specifications exist for `name`.
    """
    param_names = function_utils.fn_args(getattr(self._type, name))[1:]
    kwargs = dict(zip(param_names, args))
    specs = self._type._tensor_specs(name, kwargs, self._constructor_kwargs)
    if specs is None:
        raise ValueError(
            'No tensor specifications were provided for: %s' % name)
    dtypes_tree = nest.map_structure(lambda spec: spec.dtype, specs)
    shapes_tree = nest.map_structure(lambda spec: spec.shape, specs)
    flat_dtypes = nest.flatten(dtypes_tree)
    flat_shapes = nest.flatten(shapes_tree)

    def py_call(*payload):
        try:
            self._out.send(payload)
            response = self._out.recv()
            if isinstance(response, Exception):
                raise response
            if response is not None:
                return response
        except IOError:
            # The pipe to the worker is closed.
            raise StopIteration()  # Clean exit.

    outputs = tf.py_func(py_call, (name,) + tuple(args), flat_dtypes, name=name)
    if isinstance(outputs, tf.Operation):
        return outputs
    for output, shape in zip(outputs, flat_shapes):
        output.set_shape(shape)
    retval = nest.pack_sequence_as(specs, outputs)
    return retval
def _validate_properties(run_config):
    """Validates the properties of `run_config`.

    Properties whose value is None are skipped. The checks run in a fixed
    order, and the first failing check raises.

    Raises:
      ValueError: If a set property fails its check.
    """
    checks = (
        ('model_dir', lambda model_dir: model_dir,
         'model_dir should be non-empty'),
        ('save_summary_steps', lambda steps: steps >= 0,
         'save_summary_steps should be >= 0'),
        ('save_checkpoints_steps', lambda steps: steps >= 0,
         'save_checkpoints_steps should be >= 0'),
        ('save_checkpoints_secs', lambda secs: secs >= 0,
         'save_checkpoints_secs should be >= 0'),
        ('session_config', lambda sc: isinstance(sc, config_pb2.ConfigProto),
         'session_config must be instance of ConfigProto'),
        ('keep_checkpoint_max', lambda keep_max: keep_max >= 0,
         'keep_checkpoint_max should be >= 0'),
        ('keep_checkpoint_every_n_hours', lambda keep_hours: keep_hours > 0,
         'keep_checkpoint_every_n_hours should be > 0'),
        ('log_step_count_steps', lambda num_steps: num_steps > 0,
         'log_step_count_steps should be > 0'),
        ('tf_random_seed', lambda seed: isinstance(seed, six.integer_types),
         'tf_random_seed must be integer.'),
        ('device_fn',
         lambda device_fn: six.callable(device_fn) and set(
             function_utils.fn_args(device_fn)) == _VALID_DEVICE_FN_ARGS,
         'device_fn must be callable with exactly'
         ' one argument "op".'),
        ('protocol', lambda protocol: protocol in (None, "grpc", "grpc+verbs"),
         'protocol should be grpc or grpc+verbs'),
    )
    for property_name, is_valid, message in checks:
        value = getattr(run_config, property_name)
        if value is None:
            continue
        if not is_valid(value):
            raise ValueError(message)
def validate_loss_fn_args(loss_fn):
    """Validates loss_fn arguments.

    Required arguments: labels, logits.
    Optional arguments: features, loss_reduction.

    Args:
      loss_fn: The loss function.

    Raises:
      ValueError: If a required argument is missing or an unexpected one is
        present.
    """
    loss_fn_args = function_utils.fn_args(loss_fn)
    for required_arg in ['labels', 'logits']:
        if required_arg not in loss_fn_args:
            raise ValueError(
                'loss_fn must contain argument: {}. '
                'Given arguments: {}'.format(required_arg, loss_fn_args))
    # Sort so the error message is deterministic (set order is not).
    invalid_args = sorted(
        set(loss_fn_args) - {'labels', 'logits', 'features', 'loss_reduction'})
    if invalid_args:
        raise ValueError('loss_fn has unexpected args: {}'.format(invalid_args))
def wrapper(features: Any, labels: Any, mode: Any, params: Any, config: Any) -> Any:
    """Adapts the wrapped model function `f` to the full Estimator signature.

    TensorFlow inspects the arguments of `model_fn()`, so this wrapper accepts
    every possible argument and forwards only those that `f` declares. Before
    calling `f`, it sets the default TensorFlow session from
    `config.session_config`.
    """
    accepted = function_utils.fn_args(f)
    optional = (
        ("labels", labels),
        ("mode", mode),
        ("params", params),
        ("config", config),
    )
    kwargs = {key: value for key, value in optional if key in accepted}
    self._set_default_tensorflow_session(
        env=self.env,
        hvd_config=self.hvd_config,
        session_config=config.session_config,
    )
    return f(features, **kwargs)
def _easy_call_model_fn(self, features, labels, mode, config):
    """Calls `self._model_fn` with the arguments it declares and checks the result.

    Returns:
      The `tf.estimator.EstimatorSpec` produced by the model function.

    Raises:
      ValueError: If labels are given but the model function takes none, or if
        the model function does not return an `EstimatorSpec`.
    """
    declared_args = function_utils.fn_args(self._model_fn)
    if labels is not None and 'labels' not in declared_args:
        raise ValueError('model_fn does not take labels, '
                         'but input_fn returns labels.')
    suppliers = (
        ('labels', lambda: labels),
        ('mode', lambda: mode),
        ('params', lambda: self.params),
        ('config', lambda: config),
    )
    kwargs = {key: supply() for key, supply in suppliers if key in declared_args}

    tf.logging.info('Calling model_fn.')
    results = self._model_fn(features=features, **kwargs)
    tf.logging.info('Done calling model_fn.\n')

    if not isinstance(results, tf.estimator.EstimatorSpec):
        raise ValueError('model_fn should return an EstimatorSpec.')
    return results
def _call_input_fn(self, input_fn, mode, input_context=None):
    """Calls the input function.

    Args:
      input_fn: The input function.
      mode: `tf.estimator.ModeKeys`
      input_context: Optional context. It is passed to `input_fn` as the
        `input_context` keyword only when it is truthy and `input_fn`
        declares that argument.

    Returns:
      The return value of the passed `input_fn`, which should be one of:

        * A 'tf.data.Dataset' object: Outputs of `Dataset` object must be a
          tuple `(features, labels)` with same constraints as below.
        * A tuple `(features, labels)`: Where `features` is a `Tensor` or a
          dictionary of string feature name to `Tensor` and `labels` is a
          `Tensor` or a dictionary of string label name to `Tensor`. Both
          `features` and `labels` are consumed by `model_fn`. They should
          satisfy the expectation of `model_fn` from inputs.

    Raises:
      ValueError: if `input_fn` takes invalid arguments (per the original
        contract; not raised directly in this method).
    """
    input_fn_args = function_utils.fn_args(input_fn)
    # Thunks defer reading `self.params` / `self.config` until needed.
    suppliers = (
        ('mode', lambda: mode),
        ('params', lambda: self.params),
        ('config', lambda: self.config),
    )
    kwargs = {key: supply() for key, supply in suppliers if key in input_fn_args}
    if input_context and 'input_context' in input_fn_args:
        logging.info(
            'The `input_fn` accepts an `input_context` which will '
            'be given by DistributionStrategy')
        kwargs['input_context'] = input_context
    return input_fn(**kwargs)
def validate(host_calls):
    """Validates the `eval_metrics` and `host_call` in `NPUEstimatorSpec`.

    Args:
      host_calls: Mapping from a name to a `(fn, tensors)` pair.

    Raises:
      ValueError: If an entry is not a 2-element tuple/list, or its tensors
        are not a tuple, list or dict.
      TypeError: If the first element is not callable.
      RuntimeError: If positional tensors don't match the function's arg count
        (unless the function takes `*args`).
    """
    for name, host_call in host_calls.items():
        if not isinstance(host_call, (tuple, list)):
            raise ValueError('{} should be tuple or list'.format(name))
        if len(host_call) != 2:
            raise ValueError('{} should have two elements.'.format(name))
        call_fn, call_tensors = host_call
        if not callable(call_fn):
            raise TypeError('{}[0] should be callable.'.format(name))
        if not isinstance(call_tensors, (tuple, list, dict)):
            raise ValueError(
                '{}[1] should be tuple or list, or dict.'.format(name))
        if isinstance(call_tensors, dict):
            continue  # Dict tensors are matched by name, not by count.
        fullargspec = tf_inspect.getfullargspec(call_fn)
        fn_args = function_utils.fn_args(call_fn)
        # wrapped_hostcall_with_global_step uses varargs, so we allow that.
        if fullargspec.varargs is None and len(call_tensors) != len(fn_args):
            raise RuntimeError(
                'In NPUEstimatorSpec.{}, length of tensors {} does not match '
                'method args of the function, which takes {}.'.format(
                    name, len(call_tensors), len(fn_args)))
def test_simple_function(self):
    """fn_args returns a plain function's parameter names in order."""

    def add(a, b):
        return a + b

    self.assertEqual(('a', 'b'), function_utils.fn_args(add))
def _verify_estimator_spec(self, estimator_spec):
    """Verifies estimator spec contains correct data.

    Warns if `scaffold` is set and rejects `eval_metric_ops`, since neither is
    honored under XLA compilation. In EVAL mode with dict-valued evaluation
    tensors, the dict keys must exactly match the metric function's argument
    names.

    Args:
      estimator_spec: The spec returned by the model function.

    Returns:
      `estimator_spec`, unchanged.

    Raises:
      ValueError: If `eval_metric_ops` is set, or if dict-valued evaluation
        tensors are missing arguments or provide extra ones.
    """
    # TODO(ycao): Implement estimator spec verification for other modes.
    try:
        if estimator_spec.scaffold:
            logging.warning(
                'EstimatorSpec.scaffold is ignored with XLA compilation'
                '. Please use TPUEstimatorSpec.scaffold_fn instead.')
    except AttributeError:
        pass  # Spec type has no `scaffold`; nothing to warn about.

    try:
        if estimator_spec.eval_metric_ops:
            raise ValueError(
                'EstimatorSpec.eval_metric_ops is not supported with '
                'XLA compilation. Please use '
                'TPUEstimatorSpec.eval_metrics instead.')
    except AttributeError:
        pass  # Spec type has no `eval_metric_ops`.

    if estimator_spec.mode != model_fn_lib.ModeKeys.EVAL:
        return estimator_spec

    # Only TPUEstimatorSpec-like objects carry `eval_metrics`.
    try:
        eval_metrics = estimator_spec.eval_metrics
    except AttributeError:
        eval_metrics = None
    if not eval_metrics:
        return estimator_spec

    metric_fn, metric_tensors = eval_metrics
    metric_fn_args = function_utils.fn_args(metric_fn)
    if not isinstance(metric_tensors, dict):
        return estimator_spec

    missing = [arg for arg in metric_fn_args if arg not in metric_tensors]
    if missing:
        raise ValueError(
            'Arguments %s are needed by metric_fn (first '
            'element of TPUEstimatorSpec.eval_metrics) but '
            'they are not provided by evaluation tensors '
            '(second element of TPUEstimatorSpec.eval_metrics)'
            '.' % missing)
    extra = [key for key in metric_tensors if key not in metric_fn_args]
    if extra:
        raise ValueError(
            'Arguments %s are provided by evaluation tensors '
            '(second element of TPUEstimatorSpec.eval_metrics)'
            ' but they are not needed by metric_fn (first '
            'element of TPUEstimatorSpec.eval_metrics).' %
            extra)
    return estimator_spec
def _verify_metric_fn_args(metric_fn):
    """Checks that `metric_fn` declares only arguments in `_VALID_METRIC_FN_ARGS`.

    Args:
      metric_fn: The metric function to inspect.

    Raises:
      ValueError: If `metric_fn` has unexpected argument names.
    """
    args = set(function_utils.fn_args(metric_fn))
    # Sort so the error message is deterministic (set order is not).
    invalid_args = sorted(args - _VALID_METRIC_FN_ARGS)
    if invalid_args:
        raise ValueError('metric_fn (%s) has following not expected args: %s' %
                         (metric_fn, invalid_args))
def _validate_properties(run_config): """Validates the properties.""" def _validate(property_name, cond, message): property_value = getattr(run_config, property_name) if property_value is not None and not cond(property_value): raise ValueError(message) def _validate_delay(delay): """Check that delay is an integer value. Since this has to work for both Python2 and Python3 and PEP237 defines long to be basically int, we cannot just use a lambda function. """ try: return isinstance(delay, (int, long)) except NameError: # PEP237 redefines long to int for Python3 return isinstance(delay, int) _validate('model_dir', lambda dir: dir, message='model_dir should be non-empty') _validate('save_summary_steps', lambda steps: steps >= 0, message='save_summary_steps should be >= 0') _validate('save_checkpoints_steps', lambda steps: steps >= 0, message='save_checkpoints_steps should be >= 0') _validate('save_checkpoints_secs', lambda secs: secs >= 0, message='save_checkpoints_secs should be >= 0') _validate('session_config', lambda sc: isinstance(sc, config_pb2.ConfigProto), message='session_config must be instance of ConfigProto') _validate('keep_checkpoint_max', lambda keep_max: keep_max >= 0, message='keep_checkpoint_max should be >= 0') _validate('keep_checkpoint_every_n_hours', lambda keep_hours: keep_hours > 0, message='keep_checkpoint_every_n_hours should be > 0') _validate('log_step_count_steps', lambda num_steps: num_steps > 0, message='log_step_count_steps should be > 0') _validate('tf_random_seed', lambda seed: isinstance(seed, six.integer_types), message='tf_random_seed must be integer.') _validate( 'experimental_max_worker_delay_secs', _validate_delay, message='experimental_max_worker_delay_secs must be an integer if' ' set.') _validate('session_creation_timeout_secs', lambda timeout_secs: timeout_secs > 0, message='session_creation_timeout_secs should be > 0') _validate('device_fn', lambda device_fn: six.callable(device_fn) and set( function_utils.fn_args(device_fn)) 
== _VALID_DEVICE_FN_ARGS, message='device_fn must be callable with exactly' ' one argument "op".') _validate('protocol', lambda protocol: protocol in (None, 'grpc', 'grpc+verbs'), message='protocol should be grpc or grpc+verbs')
def _get_loss_towers(model_fn,
                     mode,
                     features,
                     labels,
                     params,
                     config,
                     devices,
                     local_ps_devices,
                     loss_reduction,
                     name_scope_pattern=_DEFAULT_NAME_SCOPE_PATTERN):
    """Replicate the loss computation across devices.

    Calls `model_fn` once per entry of `devices`. Each call runs under its own
    device setter, variable scope and name scope. The resulting specs are
    passed through `_scale_tower_loss` and collected.

    Args:
      model_fn: the model function to replicate. `params` and `config` are
        only passed to it when its signature names them.
      mode: forwarded unchanged to every `model_fn` call.
      features: indexed by tower; `features[i]` is passed to tower `i`.
      labels: indexed by tower when truthy. Otherwise every tower receives
        `labels=None`.
      params: deep-copied once (if `model_fn` accepts it); the same copy is
        passed to every tower.
      config: deep-copied once (if `model_fn` accepts it); the same copy is
        passed to every tower.
      devices: one tower is built per entry, used as that tower's worker
        device.
      local_ps_devices: passed as `ps_devices` to every tower's device setter.
        Its length sets the round-robin strategy's task count.
      loss_reduction: registered with TowerOptimizer's graph state and passed
        to `_scale_tower_loss`.
      name_scope_pattern: formatted with the tower index to name every tower
        except the first, which uses an empty name scope.

    Returns:
      A list with one loss-scaled spec per device, in `devices` order.

    Raises:
      ValueError: if a tower's spec has a non-None `train_op`, there is more
        than one device, and `TowerOptimizer` was not used. Also raised if
        the towers did not make the same optimizer calls.
    """
    tower_specs = []

    model_fn_args = function_utils.fn_args(model_fn)
    optional_params = {}
    # NOTE(review): these deep copies are made once, outside the loop, so all
    # towers share the same `params`/`config` objects. Confirm model_fn does
    # not mutate them per tower.
    if 'params' in model_fn_args:
        optional_params['params'] = copy.deepcopy(params)
    if 'config' in model_fn_args:
        optional_params['config'] = copy.deepcopy(config)

    # pylint: disable=protected-access
    # A single strategy instance is shared by all towers' device setters.
    round_robin_strategy = device_setter_lib._RoundRobinStrategy(
        num_tasks=len(local_ps_devices))
    TowerOptimizer._graph_state().set_reduction_across_towers(
        loss_reduction, len(devices))

    for i, device in enumerate(devices):
        is_the_first_tower = (i == 0)

        device_setter = _local_device_setter(
            worker_device=device,
            ps_devices=local_ps_devices,
            ps_strategy=round_robin_strategy)

        # We would like to preserve the names of the variables and ops that
        # the user might be relying on. Names without a prefix are going to
        # resolve to variables and ops of the first tower.
        name_scope = name_scope_pattern
        if is_the_first_tower:
            name_scope = ''

        # Only the first tower creates variables (reuse=False); later towers
        # reuse them.
        with variable_scope.variable_scope(
                '', reuse=not is_the_first_tower) as var_scope:
            # `as name_scope` rebinds the local to the value yielded by
            # ops_lib.name_scope. The pattern is re-read at the top of the
            # next iteration, so this is safe.
            with ops_lib.name_scope(name_scope.format(i)) as name_scope:
                with TowerOptimizer._graph_state().tower(
                        tower_id=i, var_scope=var_scope, name_scope=name_scope):
                    with ops_lib.device(device_setter):
                        labels_shard = None
                        # NOTE(review): truthiness test; assumes `labels` is a
                        # per-tower list (or None), not a single tensor.
                        if labels:
                            labels_shard = labels[i]

                        tower_spec = model_fn(
                            mode=mode,
                            features=features[i],
                            labels=labels_shard,
                            **optional_params)

                        if (tower_spec.train_op is not None and len(devices) > 1 and
                                not TowerOptimizer.has_been_used()):
                            raise ValueError('Please wrap optimizers with TowerOptimizer'
                                             ' in order to use replicate_model_fn with'
                                             ' multiple `devices`.')

                        # Scaling the loss here doesn't actually affect
                        # gradients. Another instance of scaling happens
                        # inside the TowerOptimizer.
                        tower_spec = _scale_tower_loss(
                            tower_spec, loss_reduction, number_of_towers=len(devices))
                        tower_specs.append(tower_spec)

    if not TowerOptimizer._did_towers_have_same_optimizer_calls():
        raise ValueError('Each invocation of model_fn was supposed to make the same'
                         ' optimizer calls.')
    # Reset the shared TowerOptimizer graph state once all towers are built.
    TowerOptimizer._clear_graph_state()
    # pylint: enable=protected-access
    return tower_specs
def call(*args):
    """Builds a `tf.py_func` op that forwards `(name, *args)` via `self._out`.

    The positional `args` are bound, in order, to the parameter names of
    `getattr(self._type, name)`, skipping its first parameter. The resulting
    kwargs are passed to `self._type._tensor_specs` to obtain the expected
    output specs. Those specs supply the dtypes for `tf.py_func` and the
    static shapes applied to its outputs.

    Returns:
      The `tf.Operation` returned by `tf.py_func` when it is one. Otherwise,
      the output tensors (with static shapes set from the specs) packed into
      the structure of the specs.

    Raises:
      ValueError: if `_tensor_specs` returns None for `name`.
    """
    kwargs = dict(
        zip(function_utils.fn_args(getattr(self._type, name))[1:], args))
    specs = self._type._tensor_specs(name, kwargs, self._constructor_kwargs)

    if specs is None:
        raise ValueError(
            'No tensor specifications were provided for: %s' % name)

    flat_dtypes = nest.flatten(nest.map_structure(lambda s: s.dtype, specs))
    flat_shapes = nest.flatten(nest.map_structure(lambda s: s.shape, specs))

    def py_call(*args):
        """Sends `args` through `self._out` and returns the reply.

        A reply that is an exception instance is re-raised. Any `IOError`,
        whether raised by the send/recv or received as the reply, is
        converted into `StopIteration` as a clean exit. A `None` reply
        yields `None`.
        """
        try:
            self._out.send(args)
            result = self._out.recv()
            if isinstance(result, Exception):
                raise result
            if result is not None:
                return result
        except IOError:
            raise StopIteration()  # Clean exit.

    result = tf.py_func(py_call, (name,) + tuple(args), flat_dtypes, name=name)

    if isinstance(result, tf.Operation):
        return result

    for t, shape in zip(result, flat_shapes):
        t.set_shape(shape)
    return nest.pack_sequence_as(specs, result)
def __call__(self, inputs, *args, **kwargs):
    """Wraps `call`, applying pre- and post-processing steps.

    Arguments:
      inputs: input tensor(s).
      *args: additional positional arguments to be passed to `self.call`.
      **kwargs: additional keyword arguments to be passed to `self.call`.
        **Note**: kwarg `scope` is reserved for use by the layer.

    Returns:
      Output tensor(s).

    Note:
      - If the layer's `call` method takes a `scope` keyword argument,
        this argument will be automatically set to the current variable scope.
      - If the layer's `call` method takes a `mask` argument (as some Keras
        layers do), its default value will be set to the mask generated
        for `inputs` by the previous layer (if `inputs` did come from
        a layer that generated a corresponding mask, i.e. if it came from
        a Keras layer with masking support).

    Raises:
      ValueError: if the layer's `call` method returns None (an invalid value).
      ValueError: in graph mode, if the graph of `inputs` differs from the
        layer's graph.
    """
    # `scope` is always removed from kwargs; it is re-added further down only
    # when `self.call` declares a `scope` parameter.
    self._set_scope(kwargs.pop('scope', None))

    if not context.executing_eagerly():
        try:
            # Record the graph of `inputs` as the layer's graph. The ValueError
            # handled below is presumably raised when `inputs` belong to a
            # graph other than `self._graph`.
            self._graph = ops._get_graph_from_inputs(nest.flatten(inputs),  # pylint: disable=protected-access
                                                     graph=self._graph)
        except ValueError as e:
            raise ValueError('Input graph and Layer graph are not the same: %s' % e)

    if self.built:
        try:
            # Some classes which inherit from Layer do not use its constructor, so
            # rather than initializing to None we check for an AttributeError.
            scope_context_manager = self._always_reuse_variable_scope
        except AttributeError:
            # From this point we will always set reuse=True, so create a "final"
            # variable scope with this setting. We avoid re-creating variable scopes
            # after this point as an optimization.
            self._always_reuse_variable_scope = vs.variable_scope(
                self._scope, reuse=True, auxiliary_name_scope=False)
            scope_context_manager = self._always_reuse_variable_scope
    else:
        # Not yet built: enter the layer's scope with its configured reuse flag.
        scope_context_manager = vs.variable_scope(
            self._scope, reuse=self._reuse, auxiliary_name_scope=False)

    with scope_context_manager as scope:
        self._current_scope = scope

        # Whether `self.call` accepts `scope` is computed once and cached on
        # the instance (same AttributeError idiom as above).
        try:
            call_has_scope_arg = self._call_has_scope_arg
        except AttributeError:
            self._call_fn_args = function_utils.fn_args(self.call)
            self._call_has_scope_arg = 'scope' in self._call_fn_args
            call_has_scope_arg = self._call_has_scope_arg
        if call_has_scope_arg:
            kwargs['scope'] = scope

        # Actually call layer
        outputs = super(Layer, self).__call__(inputs, *args, **kwargs)

    if not context.executing_eagerly():
        # Update global default collections.
        _add_elements_to_collection(self.updates, ops.GraphKeys.UPDATE_OPS)
    return outputs
def __call__(self, inputs, *args, **kwargs):
    """Wraps `call`, applying pre- and post-processing steps.

    Arguments:
      inputs: input tensor(s).
      *args: additional positional arguments to be passed to `self.call`.
      **kwargs: additional keyword arguments to be passed to `self.call`.
        **Note**: kwarg `scope` is reserved for use by the layer.

    Returns:
      Output tensor(s).

    Note:
      - If the layer's `call` method takes a `scope` keyword argument,
        this argument will be automatically set to the current variable scope.
      - If the layer's `call` method takes a `mask` argument (as some Keras
        layers do), its default value will be set to the mask generated
        for `inputs` by the previous layer (if `inputs` did come from
        a layer that generated a corresponding mask, i.e. if it came from
        a Keras layer with masking support).
      - When `self._keras_style` is set, no variable scope is entered. A
        `scope` kwarg is rejected, and the call is delegated directly to the
        parent class's `__call__`.

    Raises:
      ValueError: if the layer's `call` method returns None (an invalid value).
      ValueError: if `self._keras_style` is set and a `scope` kwarg is given.
      ValueError: in graph mode, if the graph of `inputs` differs from the
        layer's graph.
    """
    # `scope` is always removed from kwargs; in the non-keras path it is
    # re-added further down only when `self.call` declares a `scope` parameter.
    scope = kwargs.pop('scope', None)

    if self._keras_style:
        # Keras-style layers bypass all variable-scope handling below.
        if scope is not None:
            raise ValueError(
                'scope argument not allowed when keras style layers are enabled, '
                'but saw: {}'.format(scope))
        return super(Layer, self).__call__(inputs, *args, **kwargs)

    self._set_scope(scope)

    if not context.executing_eagerly():
        try:
            # Record the graph of `inputs` as the layer's graph. The ValueError
            # handled below is presumably raised when `inputs` belong to a
            # graph other than `self._graph`.
            self._graph = ops._get_graph_from_inputs(nest.flatten(inputs),  # pylint: disable=protected-access
                                                     graph=self._graph)
        except ValueError as e:
            raise ValueError('Input graph and Layer graph are not the same: %s' % e)

    if self.built:
        try:
            # Some classes which inherit from Layer do not use its constructor, so
            # rather than initializing to None we check for an AttributeError.
            scope_context_manager = self._always_reuse_variable_scope
        except AttributeError:
            # From this point we will always set reuse=True, so create a "final"
            # variable scope with this setting. We avoid re-creating variable scopes
            # after this point as an optimization.
            self._always_reuse_variable_scope = vs.variable_scope(
                self._scope, reuse=True, auxiliary_name_scope=False)
            scope_context_manager = self._always_reuse_variable_scope
    else:
        # Not yet built: enter the layer's scope with its configured reuse flag.
        scope_context_manager = vs.variable_scope(
            self._scope, reuse=self._reuse, auxiliary_name_scope=False)

    # Note: this rebinds `scope` from the caller-supplied kwarg to the entered
    # variable scope.
    with scope_context_manager as scope:
        self._current_scope = scope

        # Whether `self.call` accepts `scope` is computed once and cached on
        # the instance (same AttributeError idiom as above).
        try:
            call_has_scope_arg = self._call_has_scope_arg
        except AttributeError:
            self._call_fn_args = function_utils.fn_args(self.call)
            self._call_has_scope_arg = 'scope' in self._call_fn_args
            call_has_scope_arg = self._call_has_scope_arg
        if call_has_scope_arg:
            kwargs['scope'] = scope

        # Actually call layer
        outputs = super(Layer, self).__call__(inputs, *args, **kwargs)

    if not context.executing_eagerly():
        # Update global default collections.
        _add_elements_to_collection(self.updates, ops.GraphKeys.UPDATE_OPS)
    return outputs