コード例 #1
0
class NetworkBaseline(Baseline):
    """
    Baseline based on a TensorForce network, used when parameters are shared between
    the value function and the baseline.

    The output of the network is fed through a single-unit linear layer to obtain
    the prediction.
    """

    def __init__(self, network_spec, scope='network-baseline', summary_labels=()):
        """
        Network baseline.

        Args:
            network_spec: Network specification dict
            scope: Name scope under which the network and the prediction layer are
                created; also forwarded to the base class constructor
            summary_labels: Summary labels, forwarded to the base class constructor
        """
        with tf.name_scope(name=scope):
            self.network = Network.from_spec(spec=network_spec)
            # Only networks without internal inputs are accepted here.
            assert len(self.network.internal_inputs()) == 0

            self.linear = Linear(size=1, bias=0.0, scope='prediction')

        super(NetworkBaseline, self).__init__(scope, summary_labels)

    def tf_predict(self, states):
        """
        Creates the TensorFlow operations for the baseline prediction.

        Args:
            states: Input passed to the network

        Returns:
            Output of the prediction layer with axis 1 squeezed out
        """
        embedding = self.network.apply(x=states)
        prediction = self.linear.apply(x=embedding)
        return tf.squeeze(input=prediction, axis=1)

    def tf_regularization_loss(self):
        """
        Creates the TensorFlow operations for the baseline regularization loss.

        Returns:
            Regularization loss tensor (sum of the base class, network and prediction
            layer losses), or None if none of them defines a loss
        """
        # Each getter is called exactly once: calling it a second time just to append
        # the result would construct the corresponding loss operations twice.
        loss_getters = (
            super(NetworkBaseline, self).tf_regularization_loss,
            self.network.get_regularization_loss,
            self.linear.get_regularization_loss
        )

        losses = list()
        for get_loss in loss_getters:
            loss = get_loss()
            if loss is not None:
                losses.append(loss)

        if len(losses) > 0:
            return tf.add_n(inputs=losses)
        else:
            return None

    def get_variables(self, include_non_trainable=False):
        """
        Returns the variables of the baseline, its network and its prediction layer.

        Args:
            include_non_trainable: Forwarded unchanged to the base class, the network
                and the prediction layer

        Returns:
            Concatenation of base class, network and prediction layer variables,
            in that order
        """
        baseline_variables = super(NetworkBaseline, self).get_variables(
            include_non_trainable=include_non_trainable
        )

        network_variables = self.network.get_variables(include_non_trainable=include_non_trainable)

        layer_variables = self.linear.get_variables(include_non_trainable=include_non_trainable)

        return baseline_variables + network_variables + layer_variables
コード例 #2
0
class AggregatedBaseline(Baseline):
    """
    Baseline which aggregates per-state baselines.

    One baseline is created per named state; their predictions are stacked and
    combined by a single-unit linear layer.
    """

    def __init__(self, baselines, scope='aggregated-baseline', summary_labels=()):
        """
        Aggregated baseline.

        Args:
            baselines: Dict of per-state baseline specification dicts
            scope: Name scope under which the per-state baselines and the prediction
                layer are created; also forwarded to the base class constructor
            summary_labels: Summary labels, forwarded to every per-state baseline and
                to the base class constructor
        """

        with tf.name_scope(name=scope):
            self.baselines = dict()
            for name, baseline_spec in baselines.items():
                with tf.name_scope(name=(name + '-baseline')):
                    self.baselines[name] = Baseline.from_spec(
                        spec=baseline_spec,
                        kwargs=dict(summary_labels=summary_labels)
                    )

            self.linear = Linear(size=1, bias=0.0, scope='prediction')

        super(AggregatedBaseline, self).__init__(scope, summary_labels)

    def tf_predict(self, states):
        """
        Creates the TensorFlow operations for the aggregated baseline prediction.

        Args:
            states: Dict mapping state names to the corresponding state inputs; every
                name must have a baseline in ``self.baselines``

        Returns:
            Output of the prediction layer with axis 1 squeezed out
        """
        # Iterate in sorted name order (as get_variables does) so that each column fed
        # to the linear layer always corresponds to the same state, independent of the
        # iteration order of the dict that was passed in.
        predictions = [
            self.baselines[name].predict(states=states[name])
            for name in sorted(states)
        ]
        # Presumably each prediction is one value per batch entry, so stacking on
        # axis 1 yields one column per state -- TODO confirm against Baseline.predict.
        predictions = tf.stack(values=predictions, axis=1)
        prediction = self.linear.apply(x=predictions)
        return tf.squeeze(input=prediction, axis=1)

    def tf_regularization_loss(self):
        """
        Creates the TensorFlow operations for the baseline regularization loss.

        Returns:
            Regularization loss tensor (sum of the base class, per-state baseline and
            prediction layer losses), or None if none of them defines a loss
        """
        losses = list()

        # Every loss is requested exactly once; requesting it a second time just to
        # append it would construct the corresponding operations twice.
        regularization_loss = super(AggregatedBaseline, self).tf_regularization_loss()
        if regularization_loss is not None:
            losses.append(regularization_loss)

        # Note: the attribute is `self.baselines` (set in __init__).
        for baseline in self.baselines.values():
            regularization_loss = baseline.regularization_loss()
            if regularization_loss is not None:
                losses.append(regularization_loss)

        regularization_loss = self.linear.get_regularization_loss()
        if regularization_loss is not None:
            losses.append(regularization_loss)

        if len(losses) > 0:
            return tf.add_n(inputs=losses)
        else:
            return None

    def get_variables(self, include_non_trainable=False):
        """
        Returns the variables of the baseline, its per-state baselines and its
        prediction layer.

        Args:
            include_non_trainable: Forwarded unchanged to the base class, every
                per-state baseline and the prediction layer

        Returns:
            Concatenation of base class variables, per-state baseline variables
            (baselines visited in sorted name order) and prediction layer variables
        """
        baseline_variables = super(AggregatedBaseline, self).get_variables(
            include_non_trainable=include_non_trainable
        )

        baselines_variables = [
            variable for name in sorted(self.baselines)
            for variable in self.baselines[name].get_variables(include_non_trainable=include_non_trainable)
        ]

        linear_variables = self.linear.get_variables(include_non_trainable=include_non_trainable)

        return baseline_variables + baselines_variables + linear_variables