Example #1
        def graph_fn(mode, features, labels=None):
            kwargs = {}
            if 'labels' in get_arguments(self._graph_fn):
                kwargs['labels'] = labels

            graph_outputs = self._graph_fn(mode=mode,
                                           features=features,
                                           **kwargs)
            a = Dense(units=self.num_actions)(graph_outputs)
            v = None

            if self.dueling is not None:
                # Q = V(s) + A(s, a)
                v = Dense(units=1)(graph_outputs)
                if self.dueling == 'mean':
                    q = v + (a - tf.reduce_mean(a, axis=1, keep_dims=True))
                elif self.dueling == 'max':
                    q = v + (a - tf.reduce_max(a, axis=1, keep_dims=True))
                elif self.dueling == 'naive':
                    q = v + a
                elif self.dueling is True:
                    q = tf.identity(a)
                else:
                    raise ValueError("The value `{}` provided for "
                                     "dueling is unsupported.".format(
                                         self.dueling))
            else:
                q = tf.identity(a)

            return QModelSpec(graph_outputs=graph_outputs, a=a, v=v, q=q)
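
All of these snippets gate optional keyword arguments on a `get_arguments` helper that inspects a callable's signature. A minimal sketch of such a helper, assuming it only needs the parameter names and should unwrap `functools.partial` objects (not necessarily the library's actual implementation):

import inspect
from functools import partial

def get_arguments(fn):
    """Return the list of argument names accepted by `fn`."""
    if isinstance(fn, partial):
        # Inspect the wrapped callable so pre-bound partials still report
        # their original parameter names.
        return get_arguments(fn.func)
    return list(inspect.signature(fn).parameters)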
Example #2
    def _call_model_fn(self, features, labels, mode, config=None):
        """Calls model function.

        Args:
          features: features dict.
          labels: labels dict.
          mode: ModeKeys.
          config: optional configuration object.

        Returns:
          An `EstimatorSpec` object.

        Raises:
          ValueError: if model_fn returns invalid objects.
        """
        model_fn_args = get_arguments(self._model_fn)
        kwargs = {}
        if 'labels' in model_fn_args:
            kwargs['labels'] = labels
        else:
            if labels is not None:
                raise ValueError(
                    'model_fn does not take labels, but input_fn returns labels.'
                )
        if 'mode' in model_fn_args:
            kwargs['mode'] = mode
        if 'params' in model_fn_args:
            kwargs['params'] = self.params
        if 'config' in model_fn_args:
            kwargs['config'] = self.config
        model_fn_results = self._model_fn(features=features, **kwargs)

        if not isinstance(model_fn_results, EstimatorSpec):
            raise ValueError('model_fn should return an EstimatorSpec.')

        return model_fn_results
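
For reference, a minimal `model_fn` compatible with `_call_model_fn` only needs to declare the arguments it actually uses and return an `EstimatorSpec`; `params` and `config` are injected only when the signature names them. A hedged sketch, assuming a TF 1.x-style `EstimatorSpec(mode, predictions, loss, train_op)` constructor and a hypothetical `num_classes` entry in `params`:

def my_model_fn(features, labels, mode, params):
    # `labels` and `params` are forwarded because they appear in this
    # signature; `config` is omitted here and therefore never passed.
    # For brevity, assume `features` is a single 2-D tensor, not a dict.
    logits = tf.layers.dense(features, units=params['num_classes'])
    loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
    train_op = tf.train.AdamOptimizer().minimize(
        loss, global_step=tf.train.get_global_step())
    return EstimatorSpec(mode=mode, predictions=logits, loss=loss, train_op=train_op)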
Example #3
def _decay_fn(timestep,
              exploration_rate,
              decay_type='polynomial_decay',
              start_decay_at=0,
              stop_decay_at=1e9,
              decay_rate=0.,
              staircase=False,
              decay_steps=100000,
              min_exploration_rate=0):
    """The computed decayed exploration rate.

    Args:
        timestep: the current timestep.
        exploration_rate: `float` or `list` of `float` or `function`.
            The initial value of the exploration rate.
        decay_type: A decay function name defined in `exploration_decay`.
            Possible values: exponential_decay, inverse_time_decay, natural_exp_decay,
            piecewise_constant, polynomial_decay.
        start_decay_at: `int`. When to start the decay.
        stop_decay_at: `int`. When to stop the decay.
        decay_rate: A Python number.  The decay rate.
        staircase: Whether to apply decay in a discrete staircase,
            as opposed to continuous, fashion.
        decay_steps: How often to apply decay.
        min_exploration_rate: `float`. Don't decay below this number.
    """
    if isinstance(exploration_rate, partial):
        _exploration_rate = exploration_rate()
    else:
        _exploration_rate = exploration_rate

    timestep = tf.to_int32(timestep)
    decay_type_fn = getattr(exploration_decay, decay_type)
    kwargs = dict(exploration_rate=_exploration_rate,
                  timestep=tf.minimum(timestep, tf.to_int32(stop_decay_at)) -
                  tf.to_int32(start_decay_at),
                  decay_steps=decay_steps,
                  name="decayed_exploration_rate")
    decay_fn_args = get_arguments(decay_type_fn)
    if 'decay_rate' in decay_fn_args:
        kwargs['decay_rate'] = decay_rate
    if 'staircase' in decay_fn_args:
        kwargs['staircase'] = staircase

    decayed_exploration_rate = decay_type_fn(**kwargs)

    final_exploration_rate = tf.train.piecewise_constant(
        x=timestep,
        boundaries=[start_decay_at],
        values=[_exploration_rate, decayed_exploration_rate])

    if min_exploration_rate:
        final_exploration_rate = tf.maximum(final_exploration_rate,
                                            min_exploration_rate)

    return final_exploration_rate
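
Since `exploration_rate` may arrive as a `functools.partial`, its initial value can be resolved lazily at graph-construction time. A hedged usage sketch with illustrative decay settings:

from functools import partial

global_timestep = tf.train.get_or_create_global_step()

# Decay polynomially from 1.0 between timesteps 1000 and 100000,
# never dropping below 0.02.
exploration = _decay_fn(timestep=global_timestep,
                        exploration_rate=partial(float, 1.0),
                        decay_type='polynomial_decay',
                        start_decay_at=1000,
                        stop_decay_at=100000,
                        decay_steps=10000,
                        min_exploration_rate=0.02)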
Example #4
def _check_subgraph_fn(function, function_name):
    """Checks that the function provided for constructing the graph has a valid signature."""
    if function is not None:
        # Check that the given function's signature includes the required arguments.
        model_fn_args = get_arguments(function)
        if 'mode' not in model_fn_args or 'features' not in model_fn_args:
            raise ValueError(
                "Model's `{}` `{}` should have at least 2 args: "
                "`mode`, `features`, and possibly `labels`.".format(function_name, function))
    else:
        raise ValueError("`{}` must be provided to Model.".format(function_name))
Example #5
    def __init__(self, mode, build_fn, name=None):
        if not callable(build_fn):
            raise TypeError("`build_fn` must be callable.")

        build_fn_args = get_arguments(build_fn)
        if 'mode' not in build_fn_args:
            raise ValueError("`build_fn` must include `mode` argument.")

        self._build_fn = build_fn
        super(FunctionModule, self).__init__(
            mode=mode,
            name=name or get_function_name(build_fn),
            module_type=self.ModuleType.IMAGE_PROCESSOR)
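
A `build_fn` only has to be callable and declare a `mode` argument; everything else is left to the function. A hedged sketch of wrapping a simple image-preprocessing function (the `Modes.TRAIN` constant and the scaling are assumptions):

def scale_images(mode, features):
    # `mode` is required by FunctionModule even though the scaling
    # itself does not depend on it.
    return tf.cast(features, tf.float32) / 255.0

module = FunctionModule(mode=Modes.TRAIN, build_fn=scale_images)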
Example #6
    def _call_graph_fn(self, features, labels=None):
        """Calls graph function.

        Args:
            features: `Tensor` or `dict` of tensors
            labels: `Tensor` or `dict` of tensors
        """
        set_learning_phase(Modes.is_train(self.mode))

        kwargs = {}
        if 'labels' in get_arguments(self._graph_fn):
            kwargs['labels'] = labels
        return self._graph_fn(mode=self.mode, features=features, **kwargs)
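
Because `labels` is forwarded only when the wrapped function declares it, graph functions with and without a `labels` argument can both be passed unchanged. A small hedged illustration (the layer choice is arbitrary):

def supervised_graph_fn(mode, features, labels):
    # Declares `labels`, so _call_graph_fn forwards them as a kwarg.
    return Dense(units=10)(features)

def unsupervised_graph_fn(mode, features):
    # No `labels` parameter, so only `mode` and `features` are passed.
    return Dense(units=10)(features)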
Example #7
    def encode(self, features, labels, encoder_fn, *args, **kwargs):
        """Encodes the incoming tensor.

        Args:
            features: `Tensor`.
            labels: `dict` or `Tensor`
            encoder_fn: `function`.
            *args:
            **kwargs:
        """
        if 'labels' in get_arguments(encoder_fn):
            kwargs['labels'] = labels
        output = encoder_fn(mode=self.mode, features=features, **kwargs)

        if self.state_size is None:
            self.state_size = get_shape(output)[1:]
        return output
Example #8
def _check_bridge_fn(function):
    if function is not None:
        # Check that the given function's signature includes the required arguments.
        model_fn_args = get_arguments(function)
        cond = ('mode' not in model_fn_args or
                'loss' not in model_fn_args or
                'features' not in model_fn_args or
                'labels' not in model_fn_args or
                'encoder_fn' not in model_fn_args or
                'decoder_fn' not in model_fn_args)
        if cond:
            raise ValueError(
                "Model's `bridge` `{}` should have these args: "
                "`mode`, `features`, `labels`, `loss`, `encoder_fn`, "
                "and `decoder_fn`.".format(function))
    else:
        raise ValueError("`bridge_fn` must be provided to Model.")
Example #9
        def graph_fn(mode, features, labels=None):
            kwargs = {}
            if 'labels' in get_arguments(self._graph_fn):
                kwargs['labels'] = labels

            graph_outputs = self._graph_fn(mode=mode,
                                           features=features,
                                           **kwargs)
            a = Dense(units=self.num_actions)(graph_outputs)
            if self.is_continuous:
                values = tf.concat(values=[a, tf.exp(a) + 1], axis=0)
                distribution = self._build_distribution(values=values)
            else:
                values = tf.identity(a)
                distribution = self._build_distribution(values=a)
            return PGModelSpec(graph_outputs=graph_outputs,
                               a=a,
                               distribution=distribution,
                               dist_values=values)
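
In the continuous case the concatenated `values` presumably carry the distribution's location in the first half and a strictly positive scale (`tf.exp(a) + 1`) in the second. A hedged sketch of how a `_build_distribution` could split them back out; this is an assumption, not the library's implementation:

def _build_distribution(self, values):
    # Hypothetical: recover loc and scale from the concatenation above
    # and build a Gaussian over the continuous action.
    loc, scale = tf.split(values, num_or_size_splits=2, axis=0)
    return tf.distributions.Normal(loc=loc, scale=scale)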
Example #10
    def decode(self, features, labels, decoder_fn, *args, **kwargs):
        """Decodes the incoming tensor if it's validates against the state size of the decoder.
        Otherwise, generates a random value.

        Args:
            features: `Tensor`
            labels: `dict` or `Tensor`
            decoder_fn: `function`.
            *args:
            **kwargs:
        """
        # TODO: make decode capable of generating values directly,
        # TODO: basically accepting None incoming values. Should also specify a distribution.

        # shape = self._get_decoder_shape(incoming)
        # return decoder_fn(mode=self.mode, inputs=tf.random_normal(shape=shape))
        if 'labels' in get_arguments(decoder_fn):
            kwargs['labels'] = labels

        return decoder_fn(mode=self.mode, features=features, **kwargs)
Example #11
    def _verify_model_fn_args(model_fn, params):
        """Verifies model fn arguments."""

        valid_model_fn_args = {
            'features', 'labels', 'mode', 'params', 'config'
        }

        if model_fn is not None:
            # Check number of arguments of the given function matches requirements.
            model_fn_args = get_arguments(model_fn)
            if 'features' not in model_fn_args:
                raise ValueError(
                    'model_fn `{}` must include features argument.'.format(
                        model_fn))
            if 'labels' not in model_fn_args:
                raise ValueError(
                    'model_fn `{}` must include labels argument.'.format(
                        model_fn))

            if params is not None and 'params' not in model_fn_args:
                raise ValueError(
                    "Estimator's model_fn `{}` does not include params argument, "
                    "but params `{}` are passed.".format(model_fn, params))
            if params is None and 'params' in model_fn_args:
                logging.warning(
                    "Estimator's model_fn (%s) includes params "
                    "argument, but params are not passed to Estimator.",
                    model_fn)
        else:
            raise ValueError("`model_fn` must be provided to Estimator.")

        if 'self' in model_fn_args:
            model_fn_args.remove('self')

        non_valid_args = set(model_fn_args) - valid_model_fn_args
        if non_valid_args:
            raise ValueError(
                "model_fn `{}` has the following unexpected args: {}".format(
                    model_fn, non_valid_args))
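
Concretely, any argument outside the whitelist trips this last check. A hedged illustration with a hypothetical model_fn, calling the validator as a plain function:

def bad_model_fn(features, labels, mode, learning_rate):
    pass

# Raises ValueError reporting the unexpected `learning_rate` arg, since
# only features/labels/mode/params/config are accepted.
_verify_model_fn_args(model_fn=bad_model_fn, params=None)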