コード例 #1
0
 def decoder(self, z):
     """Decode latent codes into a Normal distribution over observations.

     The latent tensor is passed through ``small_deconvnet`` (with
     ``tf.nn.leaky_relu`` as ``nl`` and ``positional_bias=True``) inside the
     variable scope ``self.scope + "decoder"``.

     A rank-3 input is reshaped with ``flatten_two_dims`` before decoding and
     restored afterwards with ``unflatten_first_dim`` using its dynamic shape
     (presumably merging/splitting the first two axes -- confirm in helpers).

     Args:
         z: latent tensor. The original author's note says it is the mean of
             the VAE posterior with shape (None, None, 512); not checked here.

     Returns:
         ``tf.distributions.Normal``. With ``self.spherical_obs`` the decoder
         output (built with ``ch=4``) is the mean, and every element shares one
         learned scalar scale ``softplus(max(v, -4))``, where ``v`` is the
         variable ``scale`` initialised to 1. Otherwise the output (built with
         ``ch=8``) is split in half along the last axis into the mean and the
         pre-softplus scale.
     """
     has_time_axis = z.get_shape().ndims == 3
     if has_time_axis:
         # Keep the dynamic shape so the leading dims can be restored later.
         leading_shape = tf.shape(z)
         z = flatten_two_dims(z)
     out_channels = 4 if self.spherical_obs else 8
     with tf.variable_scope(self.scope + "decoder"):
         # Author's note: with spherical_obs=True the output shape is
         # (None, 84, 84, 4) -- not verifiable from this block.
         decoded = small_deconvnet(z,
                                   nl=tf.nn.leaky_relu,
                                   ch=out_channels,
                                   positional_bias=True)
         if has_time_axis:
             decoded = unflatten_first_dim(decoded, leading_shape)

         if not self.spherical_obs:
             # Two halves of the last axis: mean, then pre-softplus scale.
             loc, raw_scale = tf.split(decoded, 2, -1)
             return tf.distributions.Normal(loc=loc,
                                            scale=tf.nn.softplus(raw_scale))

         # Spherical case: one learned scalar scale shared by all dimensions.
         raw_scale = tf.get_variable(name="scale",
                                     shape=(),
                                     dtype=tf.float32,
                                     initializer=tf.ones_initializer())
         # Clamping before softplus bounds the scale below by softplus(-4) ~= 0.018.
         shared_scale = tf.nn.softplus(tf.maximum(raw_scale, -4.))
         return tf.distributions.Normal(loc=decoded,
                                        scale=shared_scale * tf.ones_like(decoded))
コード例 #2
0
 def decoder(self, z):
     """Map latent codes to a ``tf.distributions.Normal`` over observations.

     Decoding happens in the variable scope ``self.scope + "decoder"`` via
     ``small_deconvnet`` (leaky-ReLU ``nl``, ``positional_bias=True``). Rank-3
     inputs are flattened with ``flatten_two_dims`` first and restored with
     ``unflatten_first_dim`` afterwards.

     Args:
         z: latent tensor (rank 3 or not; exact shape assumed by the caller --
             TODO confirm).

     Returns:
         Normal distribution. If ``self.spherical_obs`` is truthy, the network
         is built with ``ch=4``, its output is the mean, and a single learned
         scalar ``softplus(max(scale_var, -4))`` is broadcast as the scale.
         Otherwise it is built with ``ch=8`` and its last axis is split into
         two halves: mean and pre-softplus scale.
     """
     rank3 = z.get_shape().ndims == 3
     saved_shape = tf.shape(z) if rank3 else None
     if rank3:
         z = flatten_two_dims(z)
     with tf.variable_scope(self.scope + "decoder"):
         net_out = small_deconvnet(
             z,
             nl=tf.nn.leaky_relu,
             ch=(4 if self.spherical_obs else 8),
             positional_bias=True,
         )
         if rank3:
             net_out = unflatten_first_dim(net_out, saved_shape)

         if self.spherical_obs:
             mean = net_out
             scale_var = tf.get_variable(
                 name="scale",
                 shape=(),
                 dtype=tf.float32,
                 initializer=tf.ones_initializer(),
             )
             # The -4 floor keeps the softplus output strictly away from zero.
             std = tf.nn.softplus(tf.maximum(scale_var, -4.)) * tf.ones_like(mean)
         else:
             mean, pre_std = tf.split(net_out, 2, -1)
             std = tf.nn.softplus(pre_std)
         return tf.distributions.Normal(loc=mean, scale=std)
コード例 #3
0
    def __init__(self,
                 policy,
                 features_shared_with_policy,
                 feat_dim=None,
                 layernormalize=False,
                 spherical_obs=False):
        """Build the VAE feature encoder, the observation decoder and their param groups.

        Args:
            policy: forwarded to the base-class constructor.
            features_shared_with_policy: forwarded to the base-class constructor.
            feat_dim: latent size forwarded to the base class. The encoder is
                built with ``2 * self.feat_dim`` outputs (presumably the mean
                and scale halves of the posterior -- confirm in the forward pass).
            layernormalize: must be falsy; the base class is always given
                ``layernormalize=False``.
            spherical_obs: if truthy, the decoder is built with ``ch=4`` and a
                learnable scalar ``self.scale`` (initialised to 1.0) is added as
                its own param group; otherwise the decoder uses ``ch=8``.

        Raises:
            AssertionError: if ``layernormalize`` is truthy (only while
                assertions are enabled).
        """
        assert not layernormalize, "VAE features should already have reasonable size, no need to layer normalize them"
        super(VAE, self).__init__(
            scope="vae",
            policy=policy,
            features_shared_with_policy=features_shared_with_policy,
            feat_dim=feat_dim,
            layernormalize=False)

        # Encoder emits twice the feature size (see the feat_dim note above).
        self.features_model = small_convnet(self.ob_space,
                                            nl=torch.nn.LeakyReLU,
                                            feat_dim=2 * self.feat_dim,
                                            last_nl=None,
                                            layernormalize=False)
        self.decoder_model = small_deconvnet(self.ob_space,
                                             feat_dim=self.feat_dim,
                                             nl=torch.nn.LeakyReLU,
                                             ch=4 if spherical_obs else 8,
                                             positional_bias=True)

        # Materialise the parameters: Module.parameters() returns a one-shot
        # generator, so a param_list traversed more than once (e.g. optimizer
        # construction plus gradient clipping) would see no params the second
        # time. A list is accepted everywhere the generator was.
        self.param_list = [
            dict(params=list(self.features_model.parameters())),
            dict(params=list(self.decoder_model.parameters())),
        ]

        self.features_std = None
        self.next_features_std = None

        self.spherical_obs = spherical_obs
        if self.spherical_obs:
            # Single scalar shared by every observation dimension, trained in
            # its own param group.
            self.scale = torch.nn.Parameter(torch.tensor(1.0),
                                            requires_grad=True)
            self.param_list.append(dict(params=[self.scale]))