def get_dist(self, state, dtype=None):
  # Distribution over the stochastic latent: a batch of one-hot categoricals
  # when the latent is discrete, a diagonal Gaussian otherwise.
  if self._discrete:
    logit = state['logit']
    logit = tf.cast(logit, tf.float32)
    dist = tfd.Independent(tools.OneHotDist(logit), 1)
    if dtype != tf.float32:
      dist = tools.DtypeDist(dist, dtype or state['logit'].dtype)
  else:
    mean, std = state['mean'], state['std']
    if dtype:
      mean = tf.cast(mean, dtype)
      std = tf.cast(std, dtype)
    dist = tfd.MultivariateNormalDiag(mean, std)
  return dist
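# Usage/shape sketch for the two branches above, using stock TFP distributions
# as stand-ins. Assumptions: tools.OneHotDist behaves like tfd.OneHotCategorical
# (plus straight-through gradients) and tools.DtypeDist only casts dtypes; the
# batch and latent sizes below are illustrative, not taken from any config.
import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions

batch, groups, classes = 16, 32, 32
logit = tf.zeros([batch, groups, classes])
discrete = tfd.Independent(tfd.OneHotCategorical(logit, dtype=tf.float32), 1)
print(discrete.sample().shape)                           # (16, 32, 32)
print(discrete.log_prob(discrete.sample()).shape)        # (16,): groups summed out

mean, std = tf.zeros([batch, 30]), tf.ones([batch, 30])
continuous = tfd.MultivariateNormalDiag(mean, std)
print(continuous.sample().shape)                         # (16, 30)
print(continuous.log_prob(tf.zeros([batch, 30])).shape)  # (16,)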
def __call__(self, features, dtype=None):
  # Shared MLP trunk, followed by an output layer whose parameterization
  # depends on the requested distribution type.
  x = features
  for index in range(self._layers):
    kw = {}
    if index == self._layers - 1 and self._outscale:
      kw['kernel_initializer'] = tf.keras.initializers.VarianceScaling(
          self._outscale)
    x = self.get(f'h{index}', tfkl.Dense, self._units, self._act, **kw)(x)
  if self._dist == 'tanh_normal':
    # Squashed Gaussian: tanh of the mean, softplus std, tanh bijector on top.
    # https://www.desmos.com/calculator/rcmcf5jwe7
    x = self.get(f'hout', tfkl.Dense, 2 * self._size)(x)
    if dtype:
      x = tf.cast(x, dtype)
    mean, std = tf.split(x, 2, -1)
    mean = tf.tanh(mean)
    std = tf.nn.softplus(std + self._init_std) + self._min_std
    dist = tfd.Normal(mean, std)
    dist = tfd.TransformedDistribution(dist, tools.TanhBijector())
    dist = tfd.Independent(dist, 1)
    dist = tools.SampleDist(dist)
  elif self._dist == 'tanh_normal_5':
    # Same as 'tanh_normal' but with fixed scaling constants of 5.
    x = self.get(f'hout', tfkl.Dense, 2 * self._size)(x)
    if dtype:
      x = tf.cast(x, dtype)
    mean, std = tf.split(x, 2, -1)
    mean = 5 * tf.tanh(mean / 5)
    std = tf.nn.softplus(std + 5) + 5
    dist = tfd.Normal(mean, std)
    dist = tfd.TransformedDistribution(dist, tools.TanhBijector())
    dist = tfd.Independent(dist, 1)
    dist = tools.SampleDist(dist)
  elif self._dist == 'normal':
    # Unsquashed diagonal Gaussian.
    x = self.get(f'hout', tfkl.Dense, 2 * self._size)(x)
    if dtype:
      x = tf.cast(x, dtype)
    mean, std = tf.split(x, 2, -1)
    std = tf.nn.softplus(std + self._init_std) + self._min_std
    dist = tfd.Normal(mean, std)
    dist = tfd.Independent(dist, 1)
  elif self._dist == 'normal_1':
    # Gaussian with unit standard deviation; only the mean is predicted.
    mean = self.get(f'hout', tfkl.Dense, self._size)(x)
    if dtype:
      mean = tf.cast(mean, dtype)
    dist = tfd.Normal(mean, 1)
    dist = tfd.Independent(dist, 1)
  elif self._dist == 'trunc_normal':
    # Gaussian truncated to [-1, 1].
    # https://www.desmos.com/calculator/mmuvuhnyxo
    x = self.get(f'hout', tfkl.Dense, 2 * self._size)(x)
    x = tf.cast(x, tf.float32)
    mean, std = tf.split(x, 2, -1)
    mean = tf.tanh(mean)
    std = 2 * tf.nn.sigmoid(std / 2) + self._min_std
    dist = tools.SafeTruncatedNormal(mean, std, -1, 1)
    dist = tools.DtypeDist(dist, dtype)
    dist = tfd.Independent(dist, 1)
  elif self._dist == 'onehot':
    # Categorical over discrete actions, parameterized by logits.
    x = self.get(f'hout', tfkl.Dense, self._size)(x)
    x = tf.cast(x, tf.float32)
    dist = tools.OneHotDist(x, dtype=dtype)
    dist = tools.DtypeDist(dist, dtype)
  elif self._dist == 'onehot_gumble':
    # Gumbel-softmax relaxation of the categorical, with temperature self._temp.
    x = self.get(f'hout', tfkl.Dense, self._size)(x)
    if dtype:
      x = tf.cast(x, dtype)
    temp = self._temp
    dist = tools.GumbleDist(temp, x, dtype=dtype)
  else:
    raise NotImplementedError(self._dist)
  return dist
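# Sketch of the 'tanh_normal' branch using stock TFP pieces: tfb.Tanh stands in
# for tools.TanhBijector, and tools.SampleDist (which estimates statistics such
# as the mode by sampling) is omitted. The layer output, std floor, and sizes
# are illustrative assumptions, not values taken from the model above.
import tensorflow as tf
import tensorflow_probability as tfp
tfd, tfb = tfp.distributions, tfp.bijectors

raw = tf.random.normal([8, 2 * 4])       # stand-in for the 'hout' layer output
mean, std = tf.split(raw, 2, -1)
mean = tf.tanh(mean)
std = tf.nn.softplus(std) + 0.1          # illustrative floor instead of _min_std
dist = tfd.TransformedDistribution(tfd.Normal(mean, std), tfb.Tanh())
dist = tfd.Independent(dist, 1)
action = dist.sample()
print(action.shape)                                # (8, 4), entries in (-1, 1)
print(dist.log_prob(0.5 * tf.ones([8, 4])).shape)  # (8,)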