Example #1
    def _call_model_fn(self, features, labels, mode):
        """Calls model function with support of 2, 3 or 4 arguments.

        Args:
            features: features dict.
            labels: labels dict.
            mode: ModeKeys

        Returns:
            An `EstimatorSpec` object.

        Raises:
            ValueError: if model_fn returns invalid objects.
        """
        model_fn_args = get_arguments(self._model_fn)
        kwargs = {}
        if 'mode' in model_fn_args:
            kwargs['mode'] = mode
        if 'params' in model_fn_args:
            kwargs['params'] = self.params
        if 'config' in model_fn_args:
            kwargs['config'] = self.config
        model_fn_results = self._model_fn(features=features,
                                          labels=labels,
                                          **kwargs)

        if not isinstance(model_fn_results, model_fn_lib.EstimatorSpec):
            raise ValueError('model_fn should return an EstimatorSpec.')

        return model_fn_results
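
All of the examples on this page rely on a `get_arguments` helper that reports the parameter names a callable declares, so that optional arguments such as `labels`, `mode`, `params`, or `config` are only forwarded when the target function actually accepts them. The helper itself is not shown here; the following is a minimal sketch of such a helper, built on Python's `inspect` module and assuming it only needs to list declared parameter names, not the library's actual implementation:

import inspect


def get_arguments(func):
    # Minimal sketch, not the library's helper: return the names of the
    # positional and keyword parameters declared by `func`.
    signature = inspect.signature(func)
    return [
        name for name, param in signature.parameters.items()
        if param.kind in (param.POSITIONAL_OR_KEYWORD, param.KEYWORD_ONLY)
    ]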
Example #2
        def graph_fn(mode, features, labels=None):
            kwargs = {}
            if 'labels' in get_arguments(self._graph_fn):
                kwargs['labels'] = labels

            graph_outputs = self._graph_fn(mode=mode,
                                           features=features,
                                           **kwargs)
            a = Dense(units=self.num_actions)(graph_outputs)
            v = None

            if self.dueling is not None:
                # Q = V(s) + A(s, a)
                v = Dense(units=1)(graph_outputs)
                if self.dueling == 'mean':
                    q = v + (a - tf.reduce_mean(a, axis=1, keep_dims=True))
                elif self.dueling == 'max':
                    q = v + (a - tf.reduce_max(a, axis=1, keep_dims=True))
                elif self.dueling == 'naive':
                    q = v + a
                elif self.dueling is True:
                    q = tf.identity(a)
                else:
                    raise ValueError("The value `{}` provided for "
                                     "dueling is unsupported.".format(
                                         self.dueling))
            else:
                q = tf.identity(a)

            return QModelSpec(graph_outputs=graph_outputs, a=a, v=v, q=q)
Example #3
def _check_method_supports_args(method, kwargs):
    """Checks that the given method supports the given args."""
    supported_args = tuple(get_arguments(method))
    for kwarg in kwargs:
        if kwarg not in supported_args:
            raise ValueError(
                'Argument `{}` is not supported in method {}.'.format(kwarg, method))
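
For illustration, here is a hypothetical call that exercises the check above; `conv2d` and its parameters are made-up stand-ins, not functions from the examples:

def conv2d(inputs, filters, kernel_size, strides=1):
    """Hypothetical method used only to demonstrate the check."""
    pass


# Passes silently: every keyword is a declared parameter of `conv2d`.
_check_method_supports_args(conv2d, {'filters': 32, 'strides': 2})

# Raises ValueError, because `padding` is not a parameter of `conv2d`.
try:
    _check_method_supports_args(conv2d, {'padding': 'same'})
except ValueError as e:
    print(e)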
Example #4
    def decode(self, features, labels, decoder_fn, *args, **kwargs):
        """Decodes the incoming tensor if it's validates against the state size of the decoder.
        Otherwise, generates a random value.

        Args:
            features: `Tensor`
            labels: `dict` or `Tensor`
            decoder_fn: `function`.
            *args:
            **kwargs:
        """
        incoming_shape = get_shape(features)
        if incoming_shape[1:] != self.state_size:
            raise ValueError(
                '`incoming` tensor is incompatible with decoder function, '
                'expects a tensor with shape `{}`, '
                'received instead `{}`'.format(self.state_size,
                                               incoming_shape[1:]))

        # TODO: make decode capable of generating values directly,
        # TODO: basically accepting None incoming values. Should also specify a distribution.

        # shape = self._get_decoder_shape(incoming)
        # return decoder_fn(mode=self.mode, inputs=tf.random_normal(shape=shape))
        if 'labels' in get_arguments(decoder_fn):
            kwargs['labels'] = labels

        x = decoder_fn(mode=self.mode, features=features, **kwargs)
        if not isinstance(x, DecoderSpec):
            raise ValueError('`decoder_fn` should return a DecoderSpec.')
        return x.output
Example #5
    def _call_model_fn(self, features, labels, mode):
        """Calls model function.

        Args:
          features: features dict.
          labels: labels dict.
          mode: ModeKeys

        Returns:
          An `EstimatorSpec` object.

        Raises:
          ValueError: if model_fn returns invalid objects.
        """
        model_fn_args = get_arguments(self._model_fn)
        kwargs = {}
        if 'labels' in model_fn_args:
            kwargs['labels'] = labels
        else:
            if labels is not None:
                raise ValueError(
                    'model_fn does not take labels, but input_fn returns labels.'
                )
        if 'mode' in model_fn_args:
            kwargs['mode'] = mode
        if 'params' in model_fn_args:
            kwargs['params'] = self.params
        if 'config' in model_fn_args:
            kwargs['config'] = self.config
        model_fn_results = self._model_fn(features=features, **kwargs)

        if not isinstance(model_fn_results, EstimatorSpec):
            raise ValueError('model_fn should return an EstimatorSpec.')

        return model_fn_results
Example #6
    def decay_fn(learning_rate, global_step):
        """The computed learning rate decay function."""
        global_step = tf.to_int32(global_step)
        decay_type_fn = getattr(tf.train, decay_type)
        kwargs = dict(
            learning_rate=learning_rate,
            global_step=tf.minimum(global_step, stop_decay_at) - start_decay_at,
            decay_steps=decay_steps,
            name="decayed_learning_rate"
        )
        decay_fn_args = get_arguments(decay_type_fn)
        if 'decay_rate' in decay_fn_args:
            kwargs['decay_rate'] = decay_rate
        if 'staircase' in decay_fn_args:
            kwargs['staircase'] = staircase

        decayed_learning_rate = decay_type_fn(**kwargs)

        final_lr = tf.train.piecewise_constant(
            x=global_step,
            boundaries=[start_decay_at],
            values=[learning_rate, decayed_learning_rate])

        if min_learning_rate:
            final_lr = tf.maximum(final_lr, min_learning_rate)

        return final_lr
Example #7
    def __init__(self,
                 mode,
                 name,
                 graph_fn,
                 loss_config,
                 optimizer_config,
                 model_type,
                 eval_metrics_config=None,
                 summaries='all',
                 clip_gradients=0.5,
                 params=None):
        super(BaseModel, self).__init__(mode, name, self.ModuleType.MODEL)
        self.loss_config = loss_config
        self.optimizer_config = optimizer_config
        self.eval_metrics_config = eval_metrics_config or []
        self.params = params
        self.model_type = model_type
        self.summaries = summarizer.SummaryOptions.validate(summaries)
        assert model_type in self.Types.VALUES, "`model_type` provided is unsupported."
        self._clip_gradients = clip_gradients
        self._grads_and_vars = None
        self._total_loss = None
        self._loss = None

        if graph_fn is not None:
            # Check number of arguments of the given function matches requirements.
            model_fn_args = get_arguments(graph_fn)
            if 'mode' not in model_fn_args or 'inputs' not in model_fn_args:
                raise ValueError(
                    "Model's graph_fn `{}` expects should have 2 args: "
                    "`mode` and `inputs`.".format(graph_fn))
        else:
            raise ValueError("`graph_fn` must be provided to Model.")

        self._graph_fn = graph_fn
Example #8
        def graph_fn(mode, features, labels=None):
            kwargs = {}
            if 'labels' in get_arguments(self._graph_fn):
                kwargs['labels'] = labels

            graph_outputs = self._graph_fn(mode=mode, features=features, **kwargs)
            a = FullyConnected(mode, num_units=self.num_actions)(graph_outputs)
            v = None

            if self.dueling is not None:
                # Q = V(s) + A(s, a)
                v = FullyConnected(mode, num_units=1)(graph_outputs)
                if self.dueling == 'mean':
                    q = v + (a - tf.reduce_mean(a, axis=1, keep_dims=True))
                elif self.dueling == 'max':
                    q = v + (a - tf.reduce_max(a, axis=1, keep_dims=True))
                elif self.dueling == 'naive':
                    q = v + a
                elif self.dueling is True:
                    q = tf.identity(a)
                else:
                    raise ValueError("The value `{}` provided for "
                                     "dueling is unsupported.".format(self.dueling))
            else:
                q = tf.identity(a)

            return QModelSpec(graph_outputs=graph_outputs, a=a, v=v, q=q)
Example #9
    def decay_fn(timestep):
        """The computed decayed exploration rate.

        Args:
            timestep: the current timestep.
        """
        timestep = tf.to_int32(timestep)
        decay_type_fn = getattr(exploration_decay, decay_type)
        kwargs = dict(
            exploration_rate=exploration_rate,
            timestep=tf.minimum(timestep, tf.to_int32(stop_decay_at)) -
            tf.to_int32(start_decay_at),
            decay_steps=decay_steps,
            name="decayed_exploration_rate")
        decay_fn_args = get_arguments(decay_type_fn)
        if 'decay_rate' in decay_fn_args:
            kwargs['decay_rate'] = decay_rate
        if 'staircase' in decay_fn_args:
            kwargs['staircase'] = staircase

        decayed_exploration_rate = decay_type_fn(**kwargs)

        final_exploration_rate = tf.train.piecewise_constant(
            x=timestep,
            boundaries=[start_decay_at],
            values=[exploration_rate, decayed_exploration_rate])

        if min_exploration_rate:
            final_exploration_rate = tf.maximum(final_exploration_rate,
                                                min_exploration_rate)

        return final_exploration_rate
Example #10
    def _call_model_fn(self, features, labels, mode):
        """Calls model function with support of 2, 3 or 4 arguments.

        Args:
            features: features dict.
            labels: labels dict.
            mode: Modes

        Returns:
            An `EstimatorSpec` object.

        Raises:
            ValueError: if model_fn returns invalid objects.
        """
        model_fn_args = get_arguments(self._model_fn)
        kwargs = {}
        if 'mode' in model_fn_args:
            kwargs['mode'] = mode
        if 'params' in model_fn_args:
            kwargs['params'] = self.params
        if 'config' in model_fn_args:
            kwargs['config'] = self.config
        model_fn_results = self._model_fn(features=features, labels=labels, **kwargs)

        if not isinstance(model_fn_results, EstimatorSpec):
            raise ValueError('model_fn should return an EstimatorSpec.')

        return model_fn_results
Example #11
    def _verify_model_fn_args(model_fn, params):
        """Verifies model fn arguments."""

        MODEL_FN_ARGS = {'features', 'labels', 'mode', 'params', 'config'}

        if model_fn is not None:
            # Check number of arguments of the given function matches requirements.
            model_fn_args = get_arguments(model_fn)
            if 'features' not in model_fn_args:
                raise ValueError('model_fn `{}` must include features argument.'.format(model_fn))
            if 'labels' not in model_fn_args:
                raise ValueError('model_fn `{}` must include labels argument.'.format(model_fn))

            if params is not None and 'params' not in model_fn_args:
                raise ValueError("Estimator's model_fn `{}` does not include params argument, "
                                 "but params `{}` are passed.".format(model_fn, params))
            if params is None and 'params' in model_fn_args:
                logging.warning("Estimator's model_fn (%s) includes params "
                                "argument, but params are not passed to Estimator.", model_fn)
        else:
            raise ValueError("`model_fn` must be provided to Estimator.")

        if 'self' in model_fn_args:
            model_fn_args.remove('self')

        non_valid_args = set(model_fn_args) - MODEL_FN_ARGS
        if non_valid_args:
            raise ValueError("model_fn `{}` has following not expected args: {}".format(
                model_fn, non_valid_args))
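
For reference, a skeleton whose signature satisfies every check in `_verify_model_fn_args` might look like the following; the name and body are placeholders, not an implementation:

def my_model_fn(features, labels, mode, params=None, config=None):
    # All argument names are drawn from MODEL_FN_ARGS, so
    # _verify_model_fn_args(my_model_fn, params={'learning_rate': 0.01})
    # accepts this signature. A real model_fn would build the graph here
    # and return an EstimatorSpec.
    raise NotImplementedError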
Example #12
    def decode(self, features, labels, decoder_fn, *args, **kwargs):
        """Decodes the incoming tensor if it's validates against the state size of the decoder.
        Otherwise, generates a random value.

        Args:
            features: `Tensor`
            labels: `dict` or `Tensor`
            decoder_fn: `function`.
            *args:
            **kwargs:
        """
        incoming_shape = get_shape(features)
        if incoming_shape[1:] != self.state_size:
            raise ValueError('`incoming` tensor is incompatible with decoder function, '
                             'expects a tensor with shape `{}`, '
                             'received instead `{}`'.format(self.state_size, incoming_shape[1:]))

        # TODO: make decode capable of generating values directly,
        # TODO: basically accepting None incoming values. Should also specify a distribution.

        # shape = self._get_decoder_shape(incoming)
        # return decoder_fn(mode=self.mode, inputs=tf.random_normal(shape=shape))
        if 'labels' in get_arguments(decoder_fn):
            kwargs['labels'] = labels

        x = decoder_fn(mode=self.mode, features=features, **kwargs)
        if not isinstance(x, DecoderSpec):
            raise ValueError('`decoder_fn` should return a DecoderSpec.')
        return x.output
Example #13
    def decay_fn(learning_rate, global_step):
        """The computed learning rate decay function."""
        global_step = tf.to_int32(global_step)
        decay_type_fn = getattr(tf.train, decay_type)
        kwargs = dict(
            learning_rate=learning_rate,
            global_step=tf.minimum(global_step, stop_decay_at) - start_decay_at,
            decay_steps=decay_steps,
            name="decayed_learning_rate"
        )
        decay_fn_args = get_arguments(decay_type_fn)
        if 'decay_rate' in decay_fn_args:
            kwargs['decay_rate'] = decay_rate
        if 'staircase' in decay_fn_args:
            kwargs['staircase'] = staircase

        decayed_learning_rate = decay_type_fn(**kwargs)

        final_lr = tf.train.piecewise_constant(
            x=global_step,
            boundaries=[start_decay_at],
            values=[learning_rate, decayed_learning_rate])

        if min_learning_rate:
            final_lr = tf.maximum(final_lr, min_learning_rate)

        return final_lr
Example #14
def _check_method_supports_args(method, kwargs):
    """Checks that the given method supports the given args."""
    supported_args = tuple(get_arguments(method))
    for kwarg in kwargs:
        if kwarg not in supported_args:
            raise ValueError(
                'Argument `{}` is not supported in method {}.'.format(
                    kwarg, method))
Example #15
    def _call_graph_fn(self, features, labels=None):
        """Calls graph function.

        Args:
            features: `Tensor` or `dict` of tensors
            labels: `Tensor` or `dict` of tensors
        """
        kwargs = {}
        if 'labels' in get_arguments(self._graph_fn):
            kwargs['labels'] = labels
        return self._graph_fn(mode=self.mode, features=features, **kwargs)
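
The `get_arguments` check is what lets both of the following graph functions (illustrative stand-ins, not functions from the library) be passed as `self._graph_fn`; `labels` is only forwarded to the second one:

def graph_fn_without_labels(mode, features):
    # Never receives `labels`, because it does not declare the argument.
    return features


def graph_fn_with_labels(mode, features, labels=None):
    # Receives `labels` as a keyword argument when the caller has them.
    return features, labels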
Example #16
    def _call_graph_fn(self, features, labels=None):
        """Calls graph function.

        Args:
            features: `Tensor` or `dict` of tensors
            labels: `Tensor` or `dict` of tensors
        """
        kwargs = {}
        if 'labels' in get_arguments(self._graph_fn):
            kwargs['labels'] = labels
        return self._graph_fn(mode=self.mode, features=features, **kwargs)
Example #17
def _check_subgraph_fn(function, function_name):
    """Checks that the function provided for constructing the graph has a valid signature."""
    if function is not None:
        # Check number of arguments of the given function matches requirements.
        model_fn_args = get_arguments(function)
        if 'mode' not in model_fn_args or 'features' not in model_fn_args:
            raise ValueError(
                "Model's `{}` `{}` should have at least 2 args: "
                "`mode`, `features`, and possibly `labels`.".format(function_name, function))
    else:
        raise ValueError("`{}` must be provided to Model.".format(function_name))
Example #18
def _check_subgraph_fn(function, function_name):
    """Checks that the function provided for constructing the graph has a valid signature."""
    if function is not None:
        # Check number of arguments of the given function matches requirements.
        model_fn_args = get_arguments(function)
        if 'mode' not in model_fn_args or 'features' not in model_fn_args:
            raise ValueError(
                "Model's `{}` `{}` should have at least 2 args: "
                "`mode`, `features`, and possibly `labels`.".format(function_name, function))
    else:
        raise ValueError("`{}` must be provided to Model.".format(function_name))
Example #19
def _check_subgraph_fn(function, function_name):
    if function is not None:
        # Check number of arguments of the given function matches requirements.
        model_fn_args = get_arguments(function)
        if 'mode' not in model_fn_args or 'inputs' not in model_fn_args:
            raise ValueError("Model's `{}` `{}` should have 2 args: "
                             "`mode` and `inputs`.".format(
                                 function_name, function))
    else:
        raise ValueError(
            "`{}` must be provided to Model.".format(function_name))
Example #20
    def __init__(self, mode, build_fn, name=None):
        if not callable(build_fn):
            raise TypeError("`build_fn` must be callable.")

        build_fn_args = get_arguments(build_fn)
        if 'mode' not in build_fn_args:
            raise ValueError("`build_fn` must include `mode` argument.")

        self._build_fn = build_fn
        super(FunctionModule,
              self).__init__(mode=mode,
                             name=name or get_function_name(build_fn),
                             module_type=self.ModuleType.IMAGE_PROCESSOR)
Example #21
def _decay_fn(timestep, exploration_rate, decay_type='polynomial_decay', start_decay_at=0,
              stop_decay_at=1e9, decay_rate=0., staircase=False, decay_steps=100000,
              min_exploration_rate=0):
    """The computed decayed exploration rate.

    Args:
        timestep: the current timestep.
        exploration_rate: `float` or `list` of `float` or `function`.
            The initial value of the exploration rate.
        decay_type: A decay function name defined in `exploration_decay`
            Possible values: exponential_decay, inverse_time_decay, natural_exp_decay,
                             piecewise_constant, polynomial_decay.
        start_decay_at: `int`. When to start the decay.
        stop_decay_at: `int`. When to stop the decay.
        decay_rate: A Python number.  The decay rate.
        staircase: Whether to apply decay in a discrete staircase,
            as opposed to continuous, fashion.
        decay_steps: How often to apply decay.
        min_exploration_rate: `float`. Don't decay below this number.
    """
    if isinstance(exploration_rate, partial):
        _exploration_rate = exploration_rate()
    else:
        _exploration_rate = exploration_rate

    timestep = tf.to_int32(timestep)
    decay_type_fn = getattr(exploration_decay, decay_type)
    kwargs = dict(
        exploration_rate=_exploration_rate,
        timestep=tf.minimum(timestep, tf.to_int32(stop_decay_at)) - tf.to_int32(start_decay_at),
        decay_steps=decay_steps,
        name="decayed_exploration_rate"
    )
    decay_fn_args = get_arguments(decay_type_fn)
    if 'decay_rate' in decay_fn_args:
        kwargs['decay_rate'] = decay_rate
    if 'staircase' in decay_fn_args:
        kwargs['staircase'] = staircase

    decayed_exploration_rate = decay_type_fn(**kwargs)

    final_exploration_rate = tf.train.piecewise_constant(
        x=timestep,
        boundaries=[start_decay_at],
        values=[_exploration_rate, decayed_exploration_rate])

    if min_exploration_rate:
        final_exploration_rate = tf.maximum(final_exploration_rate, min_exploration_rate)

    return final_exploration_rate
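
A hedged usage sketch of `_decay_fn`: the call below assumes graph-mode TensorFlow 1.x and that the `exploration_decay` module exposes an `exponential_decay` function, as listed in the docstring above; the numbers are purely illustrative.

import tensorflow as tf

global_timestep = tf.train.get_or_create_global_step()
exploration = _decay_fn(
    timestep=global_timestep,
    exploration_rate=0.3,
    decay_type='exponential_decay',
    start_decay_at=1000,
    stop_decay_at=100000,
    decay_rate=0.96,
    decay_steps=10000)
# `exploration` is a scalar tensor: 0.3 before step 1000, exponentially
# decayed between steps 1000 and 100000, and constant afterwards.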
Example #22
def _check_bridge_fn(function):
    if function is not None:
        # Check number of arguments of the given function matches requirements.
        model_fn_args = get_arguments(function)
        if ('mode' not in model_fn_args or
                'loss_config' not in model_fn_args or
                'inputs' not in model_fn_args or
                'encoder_fn' not in model_fn_args or
                'decoder_fn' not in model_fn_args):
            raise ValueError(
                "Model's `bridge` `{}` should have these args: `mode`, `inputs`, "
                "`loss_config`, `encoder_fn`, and `decoder_fn`.".format(function))
    else:
        raise ValueError("`bridge_fn` must be provided to Model.")
Example #23
def _decay_fn(timestep, exploration_rate, decay_type='polynomial_decay', start_decay_at=0,
              stop_decay_at=1e9, decay_rate=0., staircase=False, decay_steps=100000,
              min_exploration_rate=0):
    """The computed decayed exploration rate.

    Args:
        timestep: the current timestep.
        exploration_rate: `float` or `list` of `float` or `function`.
            The initial value of the exploration rate.
        decay_type: A decay function name defined in `exploration_decay`
            Possible values: exponential_decay, inverse_time_decay, natural_exp_decay,
                             piecewise_constant, polynomial_decay.
        start_decay_at: `int`. When to start the decay.
        stop_decay_at: `int`. When to stop the decay.
        decay_rate: A Python number.  The decay rate.
        staircase: Whether to apply decay in a discrete staircase,
            as opposed to continuous, fashion.
        decay_steps: How often to apply decay.
        min_exploration_rate: `float`. Don't decay below this number.
    """
    if isinstance(exploration_rate, partial):
        _exploration_rate = exploration_rate()
    else:
        _exploration_rate = exploration_rate

    timestep = tf.to_int32(timestep)
    decay_type_fn = getattr(exploration_decay, decay_type)
    kwargs = dict(
        exploration_rate=_exploration_rate,
        timestep=tf.minimum(timestep, tf.to_int32(stop_decay_at)) - tf.to_int32(start_decay_at),
        decay_steps=decay_steps,
        name="decayed_exploration_rate"
    )
    decay_fn_args = get_arguments(decay_type_fn)
    if 'decay_rate' in decay_fn_args:
        kwargs['decay_rate'] = decay_rate
    if 'staircase' in decay_fn_args:
        kwargs['staircase'] = staircase

    decayed_exploration_rate = decay_type_fn(**kwargs)

    final_exploration_rate = tf.train.piecewise_constant(
        x=timestep,
        boundaries=[start_decay_at],
        values=[_exploration_rate, decayed_exploration_rate])

    if min_exploration_rate:
        final_exploration_rate = tf.maximum(final_exploration_rate, min_exploration_rate)

    return final_exploration_rate
Example #24
        def graph_fn(mode, features, labels=None):
            kwargs = {}
            if 'labels' in get_arguments(self._graph_fn):
                kwargs['labels'] = labels

            graph_outputs = self._graph_fn(mode=mode, features=features, **kwargs)
            a = FullyConnected(mode, num_units=self.num_actions)(graph_outputs)
            if self.is_continuous:
                values = tf.concat(values=[a, tf.exp(a) + 1], axis=0)
                distribution = self._build_distribution(values=values)
            else:
                values = tf.identity(a)
                distribution = self._build_distribution(values=a)
            return PGModelSpec(
                graph_outputs=graph_outputs, a=a, distribution=distribution, dist_values=values)
Example #25
        def graph_fn(mode, features, labels=None):
            kwargs = {}
            if 'labels' in get_arguments(self._graph_fn):
                kwargs['labels'] = labels

            graph_outputs = self._graph_fn(mode=mode, features=features, **kwargs)
            a = FullyConnected(mode, num_units=self.num_actions)(graph_outputs)
            if self.is_continuous:
                values = tf.concat(values=[a, tf.exp(a) + 1], axis=0)
                distribution = self._build_distribution(values=values)
            else:
                values = tf.identity(a)
                distribution = self._build_distribution(values=a)
            return PGModelSpec(
                graph_outputs=graph_outputs, a=a, distribution=distribution, dist_values=values)
Example #26
def _check_bridge_fn(function):
    if function is not None:
        # Check number of arguments of the given function matches requirements.
        model_fn_args = get_arguments(function)
        if ('mode' not in model_fn_args or
                'loss_config' not in model_fn_args or
                'features' not in model_fn_args or
                'labels' not in model_fn_args or
                'encoder_fn' not in model_fn_args or
                'decoder_fn' not in model_fn_args):
            raise ValueError(
                "Model's `bridge` `{}` should have these args: "
                "`mode`, `features`, `labels`, `encoder_fn`, "
                "and `decoder_fn`.".format(function))
    else:
        raise ValueError("`bridge_fn` must be provided to Model.")
Example #27
def _check_bridge_fn(function):
    if function is not None:
        # Check number of arguments of the given function matches requirements.
        model_fn_args = get_arguments(function)
        cond = ('mode' not in model_fn_args or 'loss' not in model_fn_args
                or 'features' not in model_fn_args
                or 'labels' not in model_fn_args
                or 'encoder_fn' not in model_fn_args
                or 'decoder_fn' not in model_fn_args)
        if cond:
            raise ValueError(
                "Model's `bridge` `{}` should have these args: "
                "`mode`, `features`, `labels`, `encoder_fn`, "
                "and `decoder_fn`.".format(function))
    else:
        raise ValueError("`bridge_fn` must be provided to Model.")
Example #28
    def encode(self, features, labels, encoder_fn, *args, **kwargs):
        """Encodes the incoming tensor.

        Args:
            features: `Tensor`.
            labels: `dict` or `Tensor`
            encoder_fn: `function`.
            *args:
            **kwargs:
        """
        if 'labels' in get_arguments(encoder_fn):
            kwargs['labels'] = labels
        output = encoder_fn(mode=self.mode, features=features, **kwargs)

        if self.state_size is None:
            self.state_size = get_shape(output)[1:]
        return output
Example #29
    def decode(self, features, labels, decoder_fn, *args, **kwargs):
        """Decodes the incoming tensor if it's validates against the state size of the decoder.
        Otherwise, generates a random value.

        Args:
            features: `Tensor`
            labels: `dict` or `Tensor`
            decoder_fn: `function`.
            *args:
            **kwargs:
        """
        # TODO: make decode capable of generating values directly,
        # TODO: basically accepting None incoming values. Should also specify a distribution.

        # shape = self._get_decoder_shape(incoming)
        # return decoder_fn(mode=self.mode, inputs=tf.random_normal(shape=shape))
        if 'labels' in get_arguments(decoder_fn):
            kwargs['labels'] = labels

        return decoder_fn(mode=self.mode, features=features, **kwargs)
Example #30
    def encode(self, features, labels, encoder_fn, *args, **kwargs):
        """Encodes the incoming tensor.

        Args:
            features: `Tensor`.
            labels: `dict` or `Tensor`
            encoder_fn: `function`.
            *args:
            **kwargs:
        """
        if 'labels' in get_arguments(encoder_fn):
            kwargs['labels'] = labels
        x = encoder_fn(mode=self.mode, features=features, **kwargs)

        if not isinstance(x, EncoderSpec):
            raise ValueError('`encoder_fn` should return an EncoderSpec.')

        if self.state_size is None:
            self.state_size = x.output_size
        return x.output
Example #31
    def encode(self, features, labels, encoder_fn, *args, **kwargs):
        """Encodes the incoming tensor.

        Args:
            features: `Tensor`.
            labels: `dict` or `Tensor`
            encoder_fn: `function`.
            *args:
            **kwargs:
        """
        if 'labels' in get_arguments(encoder_fn):
            kwargs['labels'] = labels
        x = encoder_fn(mode=self.mode, features=features, **kwargs)

        if not isinstance(x, EncoderSpec):
            raise ValueError('`encoder_fn` should return an EncoderSpec.')

        if self.state_size is None:
            self.state_size = x.output_size
        return x.output
Example #32
    def _verify_model_fn_args(model_fn, params):
        """Verifies model fn arguments."""

        valid_model_fn_args = {
            'features', 'labels', 'mode', 'params', 'config'
        }

        if model_fn is not None:
            # Check number of arguments of the given function matches requirements.
            model_fn_args = get_arguments(model_fn)
            if 'features' not in model_fn_args:
                raise ValueError(
                    'model_fn `{}` must include features argument.'.format(
                        model_fn))
            if 'labels' not in model_fn_args:
                raise ValueError(
                    'model_fn `{}` must include labels argument.'.format(
                        model_fn))

            if params is not None and 'params' not in model_fn_args:
                raise ValueError(
                    "Estimator's model_fn `{}` does not include params argument, "
                    "but params `{}` are passed.".format(model_fn, params))
            if params is None and 'params' in model_fn_args:
                logging.warning(
                    "Estimator's model_fn (%s) includes params "
                    "argument, but params are not passed to Estimator.",
                    model_fn)
        else:
            raise ValueError("`model_fn` must be provided to Estimator.")

        if 'self' in model_fn_args:
            model_fn_args.remove('self')

        non_valid_args = set(model_fn_args) - valid_model_fn_args
        if non_valid_args:
            raise ValueError(
                "model_fn `{}` has following not expected args: {}".format(
                    model_fn, non_valid_args))