Пример #1
0
    def build_VAE_prior(self, params):
        """Build a VAE prior over x with loaded weights.

        Loads ``(biases, weights)`` via ``self.load_VAE_prior(params)``. It then
        chains them, starting from a Gaussian latent ``z_0``, through:
        linear (W_1) -> bias (b_1) -> LeakyRelu(0) -> linear (W_2) -> bias (b_2)
        -> HardTanh. Finally it reshapes from ``self.N`` to ``self.shape``.

        Args:
            params: dict whose ``'id'`` entry selects the architecture. It is
                also forwarded unchanged to ``self.load_VAE_prior``.

        Returns:
            The chained model expression, ending in the reshape channel.

        Raises:
            ValueError: if ``self.N`` is not 784 (28 * 28).
            NotImplementedError: if ``params['id']`` is not a supported
                architecture.
        """
        # Explicit check rather than ``assert`` so it is not stripped by ``python -O``.
        if self.N != 784:
            raise ValueError(
                "VAE prior expects N == 784 (28 * 28), got N={}".format(self.N))
        biases, weights = self.load_VAE_prior(params)

        if params['id'] == '20_relu_400_sigmoid_784_bias':
            D = 20  # size of the Gaussian latent z_0 (the "20" in the id)
            W1, W2 = weights
            b1, b2 = biases
            # NOTE(review): the id mentions a sigmoid output, but this code
            # applies HardTanhChannel -- confirm this substitution is intended.
            prior_x = (
                GaussianPrior(size=D) @ V(id="z_0")
                @ LinearChannel(W1, name="W_1") @ V(id="Wz_1")
                @ BiasChannel(b1) @ V(id="b_1")
                # presumably a leak slope of 0, i.e. a plain ReLU (id says "relu")
                @ LeakyReluChannel(0) @ V(id="z_1")
                @ LinearChannel(W2, name="W_2") @ V(id="Wz_2")
                @ BiasChannel(b2) @ V(id="b_2")
                @ HardTanhChannel() @ V(id="z_2")
                @ ReshapeChannel(prev_shape=self.N, next_shape=self.shape)
            )
        else:
            raise NotImplementedError(
                "Unsupported VAE prior id: {!r}".format(params['id']))

        return prior_x
Пример #2
0
 def test_hard_tanh_second_moment(self):
     """Run the shared second-moment check on a fresh HardTanhChannel.

     Uses ``self.records`` as inputs; ``places=2`` presumably sets the
     comparison precision of the shared helper.
     """
     hard_tanh = HardTanhChannel()
     self._test_function_second_moment(hard_tanh, self.records, places=2)
Пример #3
0
 def test_hard_tanh_posterior(self):
     """Run the shared posterior check on a fresh HardTanhChannel.

     Uses ``self.records`` as inputs; ``places=1`` presumably sets the
     comparison precision of the shared helper.
     """
     hard_tanh = HardTanhChannel()
     self._test_function_posterior(hard_tanh, self.records, places=1)