コード例 #1
0
ファイル: networks.py プロジェクト: twoletters/dreamerv2-old
 def get_dist(self, state, dtype=None):
     """Build the distribution described by the entries of `state`.

     Args:
       state: Mapping holding 'logit' when `self._discrete` is set, otherwise
         'mean' and 'std'.
       dtype: Optional dtype. In the continuous case the mean and std are
         cast to it when it is truthy. In the discrete case the one-hot
         distribution is computed from float32 logits and, unless `dtype`
         is `tf.float32`, wrapped in `tools.DtypeDist` using `dtype` or,
         when that is falsy, the dtype of the stored logits.

     Returns:
       A `tfd.MultivariateNormalDiag` for continuous states, or an
       `Independent` one-hot distribution (possibly wrapped in
       `tools.DtypeDist`) for discrete states.
     """
     if not self._discrete:
         loc = state['mean']
         scale = state['std']
         if dtype:
             loc = tf.cast(loc, dtype)
             scale = tf.cast(scale, dtype)
         return tfd.MultivariateNormalDiag(loc, scale)

     # The one-hot distribution is always built from float32 logits; the
     # requested dtype is only applied through the DtypeDist wrapper below.
     logits32 = tf.cast(state['logit'], tf.float32)
     categorical = tfd.Independent(tools.OneHotDist(logits32), 1)
     if dtype != tf.float32:
         target = dtype or state['logit'].dtype
         categorical = tools.DtypeDist(categorical, target)
     return categorical
コード例 #2
0
ファイル: networks.py プロジェクト: twoletters/dreamerv2-old
 def __call__(self, features, dtype=None):
     """Run `features` through the dense stack and return an output distribution.

     The hidden layers `h0 .. h{layers-1}` are applied first; the last one
     uses a `VarianceScaling(self._outscale)` kernel initializer when
     `self._outscale` is truthy. An output layer named 'hout' then feeds the
     head selected by `self._dist`.

     Args:
       features: Input tensor for the first hidden layer.
       dtype: Optional dtype. The 'tanh_normal', 'tanh_normal_5', 'normal',
         'normal_1' and 'onehot_gumble' heads cast the output-layer
         activations to it when it is truthy. The 'trunc_normal' and
         'onehot' heads compute in float32 and hand `dtype` to
         `tools.DtypeDist` instead.

     Returns:
       The distribution object for the configured head.

     Raises:
       NotImplementedError: If `self._dist` names no supported head.
     """
     hidden = features
     final = self._layers - 1
     for i in range(self._layers):
         extra = {}
         # Only the last hidden layer receives the scaled initializer.
         if i == final and self._outscale:
             extra['kernel_initializer'] = (
                 tf.keras.initializers.VarianceScaling(self._outscale))
         layer = self.get(f'h{i}', tfkl.Dense, self._units, self._act, **extra)
         hidden = layer(hidden)

     kind = self._dist

     if kind == 'tanh_normal':
         # https://www.desmos.com/calculator/rcmcf5jwe7
         out = self.get('hout', tfkl.Dense, 2 * self._size)(hidden)
         if dtype:
             out = tf.cast(out, dtype)
         loc, raw_scale = tf.split(out, 2, -1)
         loc = tf.tanh(loc)
         scale = tf.nn.softplus(raw_scale + self._init_std) + self._min_std
         squashed = tfd.TransformedDistribution(
             tfd.Normal(loc, scale), tools.TanhBijector())
         return tools.SampleDist(tfd.Independent(squashed, 1))

     if kind == 'tanh_normal_5':
         out = self.get('hout', tfkl.Dense, 2 * self._size)(hidden)
         if dtype:
             out = tf.cast(out, dtype)
         loc, raw_scale = tf.split(out, 2, -1)
         # Soft-clip the mean to (-5, 5); the std offsets are fixed at 5 here
         # rather than taken from self._init_std / self._min_std.
         loc = 5 * tf.tanh(loc / 5)
         scale = tf.nn.softplus(raw_scale + 5) + 5
         squashed = tfd.TransformedDistribution(
             tfd.Normal(loc, scale), tools.TanhBijector())
         return tools.SampleDist(tfd.Independent(squashed, 1))

     if kind == 'normal':
         out = self.get('hout', tfkl.Dense, 2 * self._size)(hidden)
         if dtype:
             out = tf.cast(out, dtype)
         loc, raw_scale = tf.split(out, 2, -1)
         scale = tf.nn.softplus(raw_scale + self._init_std) + self._min_std
         return tfd.Independent(tfd.Normal(loc, scale), 1)

     if kind == 'normal_1':
         # Unit-scale normal: the output layer predicts the mean only.
         loc = self.get('hout', tfkl.Dense, self._size)(hidden)
         if dtype:
             loc = tf.cast(loc, dtype)
         return tfd.Independent(tfd.Normal(loc, 1), 1)

     if kind == 'trunc_normal':
         # https://www.desmos.com/calculator/mmuvuhnyxo
         out = self.get('hout', tfkl.Dense, 2 * self._size)(hidden)
         out = tf.cast(out, tf.float32)
         loc, raw_scale = tf.split(out, 2, -1)
         loc = tf.tanh(loc)
         scale = 2 * tf.nn.sigmoid(raw_scale / 2) + self._min_std
         truncated = tools.SafeTruncatedNormal(loc, scale, -1, 1)
         return tfd.Independent(tools.DtypeDist(truncated, dtype), 1)

     if kind == 'onehot':
         logits = self.get('hout', tfkl.Dense, self._size)(hidden)
         logits = tf.cast(logits, tf.float32)
         onehot = tools.OneHotDist(logits, dtype=dtype)
         return tools.DtypeDist(onehot, dtype)

     if kind == 'onehot_gumble':
         logits = self.get('hout', tfkl.Dense, self._size)(hidden)
         if dtype:
             logits = tf.cast(logits, dtype)
         return tools.GumbleDist(self._temp, logits, dtype=dtype)

     raise NotImplementedError(kind)