class NetworkBaseline(Baseline):
    """
    Baseline based on a TensorForce network, used when parameters are shared
    between the value function and the baseline.
    """

    def __init__(self, network_spec, scope='network-baseline', summary_labels=()):
        """
        Network baseline.

        Args:
            network_spec: Network specification dict
        """
        with tf.name_scope(name=scope):
            self.network = Network.from_spec(spec=network_spec)
            # Recurrent/internal network states are not supported by this baseline.
            assert len(self.network.internal_inputs()) == 0

            # Single linear output unit producing the scalar state-value estimate.
            self.linear = Linear(size=1, bias=0.0, scope='prediction')

        super(NetworkBaseline, self).__init__(scope, summary_labels)

    def tf_predict(self, states):
        """
        Creates the TensorFlow operations for predicting the state value.

        Args:
            states: State tensor(s), fed through the shared network.

        Returns:
            1-D tensor of per-instance state-value predictions.
        """
        embedding = self.network.apply(x=states)
        prediction = self.linear.apply(x=embedding)
        # Drop the trailing size-1 output dimension: (batch, 1) -> (batch,).
        return tf.squeeze(input=prediction, axis=1)

    def tf_regularization_loss(self):
        """
        Creates the TensorFlow operations for the baseline regularization loss.

        Returns:
            Regularization loss tensor, or None if no component is regularized.
        """
        losses = list()

        # Call each contributor exactly once and reuse the result; the original
        # None-check-then-call pattern invoked each graph-building method twice,
        # creating duplicate ops in the graph.
        super_loss = super(NetworkBaseline, self).tf_regularization_loss()
        if super_loss is not None:
            losses.append(super_loss)

        network_loss = self.network.get_regularization_loss()
        if network_loss is not None:
            losses.append(network_loss)

        linear_loss = self.linear.get_regularization_loss()
        if linear_loss is not None:
            losses.append(linear_loss)

        if len(losses) > 0:
            return tf.add_n(inputs=losses)
        else:
            return None

    def get_variables(self, include_non_trainable=False):
        """
        Returns this baseline's variables: base-class variables, the shared
        network's variables, and the prediction layer's variables, concatenated.

        Args:
            include_non_trainable: If True, also include non-trainable variables.

        Returns:
            List of TensorFlow variables.
        """
        baseline_variables = super(NetworkBaseline, self).get_variables(
            include_non_trainable=include_non_trainable
        )
        network_variables = self.network.get_variables(include_non_trainable=include_non_trainable)
        layer_variables = self.linear.get_variables(include_non_trainable=include_non_trainable)

        return baseline_variables + network_variables + layer_variables
class AggregatedBaseline(Baseline):
    """
    Baseline which aggregates per-state baselines.
    """

    def __init__(self, baselines, scope='aggregated-baseline', summary_labels=()):
        """
        Aggregated baseline.

        Args:
            baselines: Dict of per-state baseline specification dicts
        """
        with tf.name_scope(name=scope):
            self.baselines = dict()
            for name, baseline_spec in baselines.items():
                with tf.name_scope(name=(name + '-baseline')):
                    self.baselines[name] = Baseline.from_spec(
                        spec=baseline_spec,
                        kwargs=dict(summary_labels=summary_labels)
                    )

            # Linear layer combining the per-state predictions into one scalar.
            self.linear = Linear(size=1, bias=0.0, scope='prediction')

        super(AggregatedBaseline, self).__init__(scope, summary_labels)

    def tf_predict(self, states):
        """
        Creates the TensorFlow operations for predicting the state value by
        combining the per-state baseline predictions through a linear layer.

        Args:
            states: Dict of state tensors, keyed by state name.

        Returns:
            1-D tensor of per-instance state-value predictions.
        """
        predictions = list()
        for name, state in states.items():
            prediction = self.baselines[name].predict(states=state)
            predictions.append(prediction)
        predictions = tf.stack(values=predictions, axis=1)
        prediction = self.linear.apply(x=predictions)
        # Drop the trailing size-1 output dimension: (batch, 1) -> (batch,).
        return tf.squeeze(input=prediction, axis=1)

    def tf_regularization_loss(self):
        """
        Creates the TensorFlow operations for the baseline regularization loss.

        Returns:
            Regularization loss tensor, or None if no component is regularized.
        """
        losses = list()

        # Call each contributor exactly once and reuse the result; the original
        # None-check-then-call pattern invoked each graph-building method twice.
        super_loss = super(AggregatedBaseline, self).tf_regularization_loss()
        if super_loss is not None:
            losses.append(super_loss)

        # Bug fix: iterate self.baselines (the dict created in __init__); the
        # original referenced the non-existent attribute self.baseline, which
        # raised AttributeError as soon as this method was called.
        for baseline in self.baselines.values():
            baseline_loss = baseline.regularization_loss()
            if baseline_loss is not None:
                losses.append(baseline_loss)

        linear_loss = self.linear.get_regularization_loss()
        if linear_loss is not None:
            losses.append(linear_loss)

        if len(losses) > 0:
            return tf.add_n(inputs=losses)
        else:
            return None

    def get_variables(self, include_non_trainable=False):
        """
        Returns this baseline's variables: base-class variables, each per-state
        baseline's variables (in sorted name order, for determinism), and the
        combining linear layer's variables, concatenated.

        Args:
            include_non_trainable: If True, also include non-trainable variables.

        Returns:
            List of TensorFlow variables.
        """
        baseline_variables = super(AggregatedBaseline, self).get_variables(
            include_non_trainable=include_non_trainable
        )
        baselines_variables = [
            variable
            for name in sorted(self.baselines)
            for variable in self.baselines[name].get_variables(include_non_trainable=include_non_trainable)
        ]
        linear_variables = self.linear.get_variables(include_non_trainable=include_non_trainable)

        return baseline_variables + baselines_variables + linear_variables