def fn_summary():
    # Deferred entropy summary: alpha, beta, alpha_beta and log_norm are
    # captured from the enclosing scope, so the digamma terms are only
    # computed when the summary closure is actually invoked.
    one = tf_util.constant(value=1.0, dtype='float')
    digamma_alpha = tf_util.cast(
        x=tf.math.digamma(x=tf_util.float32(x=alpha)), dtype='float'
    )
    digamma_beta = tf_util.cast(x=tf.math.digamma(x=tf_util.float32(x=beta)), dtype='float')
    digamma_alpha_beta = tf_util.cast(
        x=tf.math.digamma(x=tf_util.float32(x=alpha_beta)), dtype='float'
    )
    entropy = log_norm - (beta - one) * digamma_beta - (alpha - one) * digamma_alpha + \
        (alpha_beta - one - one) * digamma_alpha_beta
    return tf.math.reduce_mean(input_tensor=entropy)
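# A minimal sketch of the deferred-summary pattern assumed above: the closure
# is passed along unevaluated and only called when a summary is actually
# written. The helper name 'record_if_active' and the 'writer_active' flag are
# hypothetical, not Tensorforce API; only tf.summary.scalar is real.
import tensorflow as tf

def record_if_active(writer_active, name, fn_summary, step):
    # Only invoke the (potentially expensive) summary closure when a summary
    # writer is active for this step.
    if writer_active:
        tf.summary.scalar(name=name, data=fn_summary(), step=step)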
def entropy(self, *, parameters):
    alpha, beta, alpha_beta, log_norm = parameters.get(
        ('alpha', 'beta', 'alpha_beta', 'log_norm'))
    one = tf_util.constant(value=1.0, dtype='float')
    digamma_alpha = tf_util.cast(
        x=tf.math.digamma(x=tf_util.float32(x=alpha)), dtype='float')
    digamma_beta = tf_util.cast(
        x=tf.math.digamma(x=tf_util.float32(x=beta)), dtype='float')
    digamma_alpha_beta = tf_util.cast(
        x=tf.math.digamma(x=tf_util.float32(x=alpha_beta)), dtype='float')
    # Closed-form Beta entropy, with log_norm = log B(alpha, beta):
    # H = log B(a, b) - (b - 1) psi(b) - (a - 1) psi(a) + (a + b - 2) psi(a + b)
    return log_norm - (beta - one) * digamma_beta - (alpha - one) * digamma_alpha + \
        (alpha_beta - one - one) * digamma_alpha_beta
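# A NumPy/SciPy cross-check of the closed-form Beta entropy used above
# (illustrative sketch only, not part of the module; it assumes log_norm
# corresponds to the log-normalization constant log B(alpha, beta)):
import numpy as np
from scipy import special, stats

alpha, beta = 2.5, 4.0
log_norm = special.betaln(alpha, beta)
entropy = (log_norm
           - (beta - 1.0) * special.digamma(beta)
           - (alpha - 1.0) * special.digamma(alpha)
           + (alpha + beta - 2.0) * special.digamma(alpha + beta))
assert np.isclose(entropy, stats.beta(a=alpha, b=beta).entropy())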
def kl_divergence(self, *, parameters1, parameters2):
    alpha1, beta1, alpha_beta1, log_norm1 = parameters1.get(
        ('alpha', 'beta', 'alpha_beta', 'log_norm'))
    alpha2, beta2, alpha_beta2, log_norm2 = parameters2.get(
        ('alpha', 'beta', 'alpha_beta', 'log_norm'))
    digamma_alpha1 = tf_util.cast(
        x=tf.math.digamma(x=tf_util.float32(x=alpha1)), dtype='float')
    digamma_beta1 = tf_util.cast(
        x=tf.math.digamma(x=tf_util.float32(x=beta1)), dtype='float')
    digamma_alpha_beta1 = tf_util.cast(
        x=tf.math.digamma(x=tf_util.float32(x=alpha_beta1)), dtype='float')
    # Closed-form KL(Beta1 || Beta2); only digamma terms of the first
    # distribution appear, since the expectation is taken under Beta1.
    return log_norm2 - log_norm1 - digamma_beta1 * (beta2 - beta1) - \
        digamma_alpha1 * (alpha2 - alpha1) + digamma_alpha_beta1 * \
        (alpha_beta2 - alpha_beta1)
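# Cross-check of the closed-form KL(Beta1 || Beta2) against numerical
# integration of p * log(p / q) over (0, 1) (illustrative sketch only):
import numpy as np
from scipy import integrate, special, stats

a1, b1, a2, b2 = 2.5, 4.0, 1.5, 3.0
kl_closed = (special.betaln(a2, b2) - special.betaln(a1, b1)
             - special.digamma(b1) * (b2 - b1)
             - special.digamma(a1) * (a2 - a1)
             + special.digamma(a1 + b1) * ((a2 + b2) - (a1 + b1)))
p, q = stats.beta(a=a1, b=b1), stats.beta(a=a2, b=b2)
kl_numeric, _ = integrate.quad(
    func=lambda x: p.pdf(x) * (p.logpdf(x) - q.logpdf(x)), a=0.0, b=1.0)
assert np.isclose(kl_closed, kl_numeric)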
def log_probability(self, *, parameters, action):
    mean, stddev, log_stddev = parameters.get(('mean', 'stddev', 'log_stddev'))

    # Inverse bounded transformation: map the action back to the unbounded
    # pre-transform space before evaluating the Gaussian density.
    if self.bounded_transform is not None:
        if self.action_spec.min_value is not None and self.action_spec.max_value is not None:
            one = tf_util.constant(value=1.0, dtype='float')
            two = tf_util.constant(value=2.0, dtype='float')
            min_value = tf_util.constant(value=self.action_spec.min_value, dtype='float')
            max_value = tf_util.constant(value=self.action_spec.max_value, dtype='float')
            action = two * (action - min_value) / (max_value - min_value) - one
        if self.bounded_transform == 'tanh':
            # Clip to the open interval (-1, 1) so that atanh stays finite.
            clip = tf_util.constant(value=(1.0 - util.epsilon), dtype='float')
            action = tf.clip_by_value(t=action, clip_value_min=-clip, clip_value_max=clip)
            action = tf_util.cast(x=tf.math.atanh(x=tf_util.float32(x=action)), dtype='float')

    epsilon = tf_util.constant(value=util.epsilon, dtype='float')
    half = tf_util.constant(value=0.5, dtype='float')
    half_log_two_pi = tf_util.constant(value=(0.5 * np.log(2.0 * np.pi)), dtype='float')
    sq_mean_distance = tf.square(x=(action - mean))
    sq_stddev = tf.maximum(x=tf.square(x=stddev), y=epsilon)
    log_prob = -half * sq_mean_distance / sq_stddev - log_stddev - half_log_two_pi

    if self.bounded_transform == 'tanh':
        # Tanh change-of-variables correction, via the numerically stable
        # identity log(1 - tanh(a)^2) = 2 * (log 2 - a - softplus(-2a)).
        # 'two' is defined here as well, since the min/max branch above is
        # skipped for purely tanh-bounded actions.
        two = tf_util.constant(value=2.0, dtype='float')
        log_two = tf_util.constant(value=np.log(2.0), dtype='float')
        log_prob -= two * (log_two - action - tf.math.softplus(features=(-two * action)))

    return log_prob
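# Sketch of the tanh change-of-variables correction used above: for
# y = tanh(x), log p_Y(y) = log p_X(x) - log(1 - tanh(x)^2), and the stable
# identity log(1 - tanh(x)^2) = 2 * (log 2 - x - softplus(-2x)) avoids
# evaluating log(1 - y^2) near |y| = 1. Values below are illustrative only:
import numpy as np
from scipy import stats

mean, stddev = 0.3, 0.8
y = 0.9                         # squashed action in (-1, 1)
x = np.arctanh(y)               # pre-squash Gaussian sample
log_det = 2.0 * (np.log(2.0) - x - np.logaddexp(0.0, -2.0 * x))  # softplus(-2x)
assert np.isclose(log_det, np.log1p(-np.tanh(x) ** 2))
log_prob_y = stats.norm(loc=mean, scale=stddev).logpdf(x) - log_det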
def iterative_apply(self, *, x, internals):
    x = tf_util.float32(x=x)
    state = tf_util.float32(x=internals['state'])

    # Unpack the stored state tensor into the tuple format expected by the
    # Keras cell: a single state for GRU, (hidden, cell) states for LSTM.
    if self.cell_type == 'gru':
        state = (state,)
    elif self.cell_type == 'lstm':
        state = (state[:, 0, :], state[:, 1, :])

    x, state = self.cell(inputs=x, states=state)

    # Re-pack the new state into a single tensor for storage in internals.
    if self.cell_type == 'gru':
        state = state[0]
    elif self.cell_type == 'lstm':
        state = tf.stack(values=state, axis=1)

    x = tf_util.cast(x=x, dtype='float')
    internals['state'] = tf_util.cast(x=state, dtype='float')
    return x, internals
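# The packing above follows the Keras cell convention: GRUCell carries a
# single state tensor, LSTMCell a (hidden, cell) pair. A standalone sketch
# (shapes and variable names are illustrative, not taken from the module):
import tensorflow as tf

batch, input_size, units = 4, 8, 16
x = tf.random.normal(shape=(batch, input_size))
h = tf.zeros(shape=(batch, units))

gru_cell = tf.keras.layers.GRUCell(units=units)
y, gru_state = gru_cell(inputs=x, states=(h,))            # one state tensor

lstm_cell = tf.keras.layers.LSTMCell(units=units)
y, lstm_state = lstm_cell(inputs=x, states=(h, tf.zeros(shape=(batch, units))))
stacked = tf.stack(values=lstm_state, axis=1)             # shape (batch, 2, units)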
def apply(self, *, x):
    x = tf_util.float32(x=x)
    x = self.rnn(inputs=x, initial_state=None)

    # With return_state=True, the Keras RNN returns the sequence output
    # followed by the final state(s): one tensor for GRU, two (hidden and
    # cell state) for LSTM.
    if not self.return_final_state:
        x = tf_util.cast(x=x[0], dtype='float')
    elif self.cell_type == 'gru':
        x = tf_util.cast(x=x[1], dtype='float')
    elif self.cell_type == 'lstm':
        x = tf_util.cast(x=tf.concat(values=(x[1], x[2]), axis=1), dtype='float')

    return super().apply(x=x)
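# Illustrative sketch of the output structure indexed above: with
# return_state=True, Keras GRU yields (sequence, final state), while LSTM
# yields (sequence, final hidden state, final cell state), which is why
# x[1] and x[2] are concatenated for LSTM. Shapes below are illustrative:
import tensorflow as tf

batch, timesteps, input_size, units = 4, 10, 8, 16
x = tf.random.normal(shape=(batch, timesteps, input_size))

gru = tf.keras.layers.GRU(units=units, return_sequences=True, return_state=True)
seq, final_state = gru(inputs=x)                        # 2 outputs

lstm = tf.keras.layers.LSTM(units=units, return_sequences=True, return_state=True)
seq, final_h, final_c = lstm(inputs=x)                  # 3 outputs
final = tf.concat(values=(final_h, final_c), axis=1)    # shape (batch, 2 * units)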