Example #1
Votes: 0
File: models.py — Project: axelbr/dreamer
 def __call__(self, features, training=False):
     """Map feature vectors to an action distribution.

     Args:
         features: Input feature tensor; all leading dimensions are kept as
             the batch shape of the returned distribution.
         training: Forwarded to BatchNormalization in the normalized variant
             only (per the original comment, true only in imagination).

     Returns:
         A tools.SampleDist wrapping a tanh-transformed diagonal Normal.

     Raises:
         NotImplementedError: If self._dist is not a supported mode.
     """
     # Inverse softplus of the initial std, so softplus(std + raw_init_std)
     # starts near self._init_std when the dense output is near zero.
     raw_init_std = np.log(np.exp(self._init_std) - 1)
     x = features
     for index in range(self._layers):
         x = self.get(f'h{index}', tfkl.Dense, self._units, self._act)(x)

     def _tanh_normal(mean, std):
         # Shared tail: Normal -> tanh squash -> event dim -> sample wrapper.
         dist = tfd.Normal(mean, std)
         dist = tfd.TransformedDistribution(dist, tools.TanhBijector())
         dist = tfd.Independent(dist, 1)
         return tools.SampleDist(dist)

     if self._dist == 'tanh_normal':  # Original from Dreamer
         # https://www.desmos.com/calculator/rcmcf5jwe7
         x = self.get('hout', tfkl.Dense, 2 * self._size)(x)
         mean, std = tf.split(x, 2, -1)
         mean = self._mean_scale * tf.tanh(mean / self._mean_scale)
         std = tf.nn.softplus(std + raw_init_std) + self._min_std
         dist = _tanh_normal(mean, std)
     elif self._dist == 'normalized_tanhtransformed_normal':
         # Normalized variation of the original actor: (mu, std) normalized,
         # then create tanh normal from them.  The normalization params
         # (moving avg, std) are updated only during training.
         x = self.get('hout', tfkl.Dense, 2 * self._size)(x)
         # BatchNormalization expects a flat batch; restore the original
         # leading dims afterwards.
         x = tf.reshape(x, [-1, 2 * self._size])
         x = self.get('hnorm', tfkl.BatchNormalization)(x, training=training)
         x = tf.reshape(x, [*features.shape[:-1], -1])
         mean, std = tf.split(x, 2, -1)
         std = tf.nn.softplus(std) + self._min_std  # to have positive values
         dist = _tanh_normal(mean, std)
     else:
         raise NotImplementedError(self._dist)
     return dist
Example #2
Votes: 0
 def __call__(self, features):
     """Map feature vectors to an action distribution.

     Args:
         features: Input feature tensor; leading dimensions become the
             batch shape of the returned distribution.

     Returns:
         Depending on self._dist: a tools.SampleDist over a tanh-squashed
         Normal ('tanh_normal'), a tools.OneHotDist ('onehot'), or a
         tools.SampleDist over a RelaxedOneHotCategorical ('gumbel').

     Raises:
         NotImplementedError: If self._dist is not a supported mode.
     """
     # Inverse softplus of the initial std; only used by 'tanh_normal'.
     raw_init_std = np.log(np.exp(self._init_std) - 1)
     x = features
     for index in range(self._layers):
         x = self.get(f'h{index}', tfkl.Dense, self._units, self._act)(x)
     if self._dist == 'tanh_normal':
         # https://www.desmos.com/calculator/rcmcf5jwe7
         x = self.get('hout', tfkl.Dense, 2 * self._size)(x)
         mean, std = tf.split(x, 2, -1)
         mean = self._mean_scale * tf.tanh(mean / self._mean_scale)
         std = tf.nn.softplus(std + raw_init_std) + self._min_std
         dist = tfd.Normal(mean, std)
         dist = tfd.TransformedDistribution(dist, tools.TanhBijector())
         dist = tfd.Independent(dist, 1)
         dist = tools.SampleDist(dist)
     elif self._dist == 'onehot':
         x = self.get('hout', tfkl.Dense, self._size)(x)
         dist = tools.OneHotDist(x)
     elif self._dist == 'gumbel':
         x = self.get('hout', tfkl.Dense, self._size)(x)
         # Fixed relaxation temperature of 0.1.
         dist = tfd.RelaxedOneHotCategorical(temperature=1e-1, logits=x)
         dist = tools.SampleDist(dist)
     else:
         # Include the unsupported mode name in the error for debuggability.
         raise NotImplementedError(self._dist)
     return dist
Example #3
Votes: 0
 def __call__(self, features):
     """Map feature vectors to an action distribution.

     Args:
         features: Input feature tensor; leading dimensions become the
             batch shape of the returned distribution.

     Returns:
         A tools.SampleDist over a tanh-squashed Normal ('tanh_normal')
         or a tools.OneHotDist ('onehot').

     Raises:
         NotImplementedError: If self._dist is not a supported mode.
     """
     # Inverse softplus of the initial std; only used by 'tanh_normal'.
     raw_init_std = np.log(np.exp(self._init_std) - 1)
     x = features
     for index in range(self._layers_num):
         x = self.get(f"h{index}", tf.keras.layers.Dense, self._units,
                      self._act)(x)
     if self._dist == "tanh_normal":
         # https://www.desmos.com/calculator/rcmcf5jwe7
         x = self.get("hout", tf.keras.layers.Dense, 2 * self._size)(x)
         mean, std = tf.split(x, 2, -1)
         mean = self._mean_scale * tf.tanh(mean / self._mean_scale)
         std = tf.nn.softplus(std + raw_init_std) + self._min_std
         dist = tfd.Normal(mean, std)
         dist = tfd.TransformedDistribution(dist, tools.TanhBijector())
         dist = tfd.Independent(dist, 1)
         dist = tools.SampleDist(dist)
     elif self._dist == "onehot":
         x = self.get("hout", tf.keras.layers.Dense, self._size)(x)
         dist = tools.OneHotDist(x)
     else:
         # BUG FIX: the original raised NotImplementedError(dist), but
         # `dist` is unassigned on this path, masking the intended error
         # with a NameError.  Report the unsupported mode instead.
         raise NotImplementedError(self._dist)
     return dist
Example #4
Votes: 0
 def __call__(self, features, dtype=None):
     """Map feature vectors to an action distribution.

     Args:
         features: Input feature tensor; leading dimensions become the
             batch shape of the returned distribution.
         dtype: Optional dtype to cast the network output (or wrap the
             distribution) to; when falsy, no cast is applied except in the
             branches that always cast to float32.

     Returns:
         A distribution object whose type depends on self._dist; see the
         branch comments below.

     Raises:
         NotImplementedError: If self._dist is not a supported mode.
     """
     x = features
     for index in range(self._layers):
         kw = {}
         # Optionally rescale the initialization of the last hidden layer.
         if index == self._layers - 1 and self._outscale:
             kw['kernel_initializer'] = tf.keras.initializers.VarianceScaling(
                 self._outscale)
         x = self.get(f'h{index}', tfkl.Dense, self._units, self._act,
                      **kw)(x)
     if self._dist == 'tanh_normal':
         # Tanh-squashed Normal; mean squashed to (-1, 1), std shifted by
         # the configured initial std.  https://www.desmos.com/calculator/rcmcf5jwe7
         x = self.get('hout', tfkl.Dense, 2 * self._size)(x)
         if dtype:
             x = tf.cast(x, dtype)
         mean, std = tf.split(x, 2, -1)
         mean = tf.tanh(mean)
         std = tf.nn.softplus(std + self._init_std) + self._min_std
         dist = tfd.Normal(mean, std)
         dist = tfd.TransformedDistribution(dist, tools.TanhBijector())
         dist = tfd.Independent(dist, 1)
         dist = tools.SampleDist(dist)
     elif self._dist == 'tanh_normal_5':
         # Variant with hard-coded scale 5 for both mean squashing and std.
         x = self.get('hout', tfkl.Dense, 2 * self._size)(x)
         if dtype:
             x = tf.cast(x, dtype)
         mean, std = tf.split(x, 2, -1)
         mean = 5 * tf.tanh(mean / 5)
         std = tf.nn.softplus(std + 5) + 5
         dist = tfd.Normal(mean, std)
         dist = tfd.TransformedDistribution(dist, tools.TanhBijector())
         dist = tfd.Independent(dist, 1)
         dist = tools.SampleDist(dist)
     elif self._dist == 'normal':
         # Unsquashed diagonal Normal with learned mean and std.
         x = self.get('hout', tfkl.Dense, 2 * self._size)(x)
         if dtype:
             x = tf.cast(x, dtype)
         mean, std = tf.split(x, 2, -1)
         std = tf.nn.softplus(std + self._init_std) + self._min_std
         dist = tfd.Normal(mean, std)
         dist = tfd.Independent(dist, 1)
     elif self._dist == 'normal_1':
         # Diagonal Normal with fixed unit std.
         mean = self.get('hout', tfkl.Dense, self._size)(x)
         if dtype:
             mean = tf.cast(mean, dtype)
         dist = tfd.Normal(mean, 1)
         dist = tfd.Independent(dist, 1)
     elif self._dist == 'trunc_normal':
         # Truncated Normal on [-1, 1]; computed in float32 for stability.
         # https://www.desmos.com/calculator/mmuvuhnyxo
         x = self.get('hout', tfkl.Dense, 2 * self._size)(x)
         x = tf.cast(x, tf.float32)
         mean, std = tf.split(x, 2, -1)
         mean = tf.tanh(mean)
         std = 2 * tf.nn.sigmoid(std / 2) + self._min_std
         dist = tools.SafeTruncatedNormal(mean, std, -1, 1)
         dist = tools.DtypeDist(dist, dtype)
         dist = tfd.Independent(dist, 1)
     elif self._dist == 'onehot':
         # Categorical over one-hot actions; logits kept in float32.
         x = self.get('hout', tfkl.Dense, self._size)(x)
         x = tf.cast(x, tf.float32)
         dist = tools.OneHotDist(x, dtype=dtype)
         dist = tools.DtypeDist(dist, dtype)
     elif self._dist == 'onehot_gumble':
         # Gumbel-softmax relaxation with configured temperature.
         x = self.get('hout', tfkl.Dense, self._size)(x)
         if dtype:
             x = tf.cast(x, dtype)
         temp = self._temp
         dist = tools.GumbleDist(temp, x, dtype=dtype)
     else:
         raise NotImplementedError(self._dist)
     return dist