Example #1
 def K_n_choose_k(n, k, seed=None):
     from theano.tensor.shared_randomstreams import RandomStreams
     if seed is None:
         seed = np.random.randint(1, 10e6)
     rng = RandomStreams(seed=seed)
     r = rng.choice(size=(k, ), a=n, replace=False, dtype='int32')
     return r
Example #2
    def __init__(self, classifier, args, noise_dist):
        self.y = T.ivector("y")

        ## Cost function
        #  Sum over minibatch instances: log( u(w|c) / (u(w|c) + k * p_n(w)) ) + sum over noise samples of log( k * p_n(x) / (u(x|c) + k * p_n(x)) )

        # Generating noise samples
        srng = RandomStreams(seed=1234)
        noise_samples = srng.choice(
            size=(self.y.shape[0], args.num_noise_samples), a=args.num_classes, p=noise_dist, dtype="int32"
        )

        log_noise_dist = theano.shared(np.log(noise_dist.get_value()), borrow=True)
        # log_num_noise_samples = theano.shared(math.log(args.num_noise_samples)).astype(theano.config.floatX)
        log_num_noise_samples = theano.shared(np.log(args.num_noise_samples, dtype=theano.config.floatX))
        # Data part of cost function: log ( u(w|c) / (u(w|c) + k * p_n(w)) )
        data_scores = classifier.output[T.arange(self.y.shape[0]), self.y]
        data_denom = self.logadd(data_scores, log_num_noise_samples + log_noise_dist[self.y])
        data_prob = data_scores - data_denom
        # Summation of noise part of cost function: sum over noise samples of log( k * p_n(x) / (u(x|c) + k * p_n(x)) )
        noise_mass = (
            log_num_noise_samples + log_noise_dist[noise_samples]
        )  # log(k) + log(p_n(x)) for all noise samples (Shape: #instances x k)
        noise_scores = classifier.output[T.arange(noise_samples.shape[0]).reshape((-1, 1)), noise_samples]
        noise_denom = self.logadd(noise_scores, noise_mass)
        noise_prob_sum = T.sum(noise_mass - noise_denom, axis=1)

        self.cost = -T.mean(data_prob + noise_prob_sum)
        self.test = T.sum(data_scores)
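
The comment at the top of this example describes the NCE objective that the Theano graph builds. As a rough illustration of the same arithmetic outside Theano, here is a minimal NumPy sketch; the function name and the plain-array inputs are hypothetical, assuming classifier.output holds unnormalised log-scores and self.logadd is an elementwise log-add-exp:

import numpy as np

def nce_cost(data_scores, noise_scores, log_pn_data, log_pn_noise, k):
    # data_scores  : (batch,)    log u(w|c) of the true words
    # noise_scores : (batch, k)  log u(x|c) of the k noise samples per instance
    # log_pn_data  : (batch,)    log p_n(w) of the true words
    # log_pn_noise : (batch, k)  log p_n(x) of the noise samples
    log_k = np.log(k)
    # log( u(w|c) / (u(w|c) + k * p_n(w)) ), computed in log-space
    data_prob = data_scores - np.logaddexp(data_scores, log_k + log_pn_data)
    # log( k * p_n(x) / (u(x|c) + k * p_n(x)) ), summed over the k noise samples
    noise_mass = log_k + log_pn_noise
    noise_prob_sum = np.sum(noise_mass - np.logaddexp(noise_scores, noise_mass), axis=1)
    return -np.mean(data_prob + noise_prob_sum)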
Example #3
    def _negative_sampling(self, num_negative_samples, target_indices):
        assert num_negative_samples > 0

        logging.debug(
            'Stochastically sampling %d negative instances '
            'out of %d classes (%.2f%%).', num_negative_samples,
            self.num_entities,
            100.0 * float(num_negative_samples) / self.num_entities)

        from theano.tensor.shared_randomstreams import RandomStreams

        srng = RandomStreams(seed=np.random.randint(low=0, high=(1 << 30)))

        rng_sample_size = (
            self.batch_size,
            num_negative_samples,
        )

        logging.debug('Using %s for random sample generation of %s tensors.',
                      RandomStreams, rng_sample_size)

        logging.debug('For every batch %d random integers are sampled.',
                      np.prod(rng_sample_size))

        random_negative_indices = srng.choice(rng_sample_size,
                                              a=self.num_entities,
                                              p=self.clazz_distribution)

        if self.__DEBUG__:
            random_negative_indices = theano.printing.Print(
                'random_negative_indices')(random_negative_indices)

        return random_negative_indices
Example #4
class BlackoutLayer(DenseLayer):
	def __init__(self, incoming, num_units, num_outputs=0.01, **kwargs):
		super(BlackoutLayer, self).__init__(incoming, num_units, **kwargs)
		self._srng = RandomStreams(get_rng().randint(1, 2147462579))
		if num_outputs < 1:
			num_outputs = num_outputs * num_units
		self.num_outputs = int(num_outputs)

	def get_output_for(self, input, deterministic=False, targets=None, samples=None, **kwargs):
		if input.ndim > 2:
			# if the input has more than two dimensions, flatten it into a
			# batch of feature vectors.
			input = input.flatten(2)

		if deterministic:
			activation = T.dot(input, self.W)
			if self.b is not None:
				activation = activation + self.b.dimshuffle('x', 0)
		else:

			if samples is None:
				output_cells = self._srng.choice(a=self.num_units, size=(self.num_outputs,))
			else:
				output_cells = samples

			if targets is not None:
				#output_cells = [x for x in output_cells if x not in targets]
				output_cells = T.concatenate((targets, output_cells))

			activation = T.dot(input, self.W[:,output_cells])
			if self.b is not None:
				activation = activation + self.b[output_cells].dimshuffle('x', 0)

		return self.nonlinearity(activation)
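
To make the sampling trick in get_output_for above concrete: during training the layer only scores the target units plus a random subset of output units, while at test time it scores all of them. A small NumPy sketch of that column subsampling (all shapes and names here are made up for illustration):

import numpy as np

rng = np.random.RandomState(0)
batch, num_inputs, num_units, num_outputs = 4, 8, 100, 10

x = rng.rand(batch, num_inputs)
W = rng.rand(num_inputs, num_units)
b = rng.rand(num_units)
targets = np.array([3, 17, 42, 99])  # one target unit per instance

# Training path: targets plus a sampled subset of output units
sampled = rng.choice(num_units, size=num_outputs)
output_cells = np.concatenate([targets, sampled])
train_scores = x.dot(W[:, output_cells]) + b[output_cells]  # (batch, len(output_cells))

# Deterministic path: every output unit
test_scores = x.dot(W) + b  # (batch, num_units)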
Example #5
 def sample(self, b):
     r = RandomStreams(seed=1)
     return r.choice(size=(self.nr_neg_samples * b, ),
                     replace=True,
                     a=self.v,
                     p=self.p,
                     dtype='int32')
Example #6
    def __init__(self, classifier, args, noise_dist):
        self.y = T.ivector('y')

        ## Cost function
        #  Sum over minibatch instances: log( u(w|c) / (u(w|c) + k * p_n(w)) ) + sum over noise samples of log( k * p_n(x) / (u(x|c) + k * p_n(x)) )

        # Generating noise samples
        srng = RandomStreams(seed=1234)
        noise_samples = srng.choice(size=(self.y.shape[0],
                                          args.num_noise_samples),
                                    a=args.num_classes,
                                    p=noise_dist,
                                    dtype='int32')

        log_noise_dist = theano.shared(np.log(noise_dist.get_value()),
                                       borrow=True)
        #log_num_noise_samples = theano.shared(math.log(args.num_noise_samples)).astype(theano.config.floatX)
        log_num_noise_samples = theano.shared(
            np.log(args.num_noise_samples, dtype=theano.config.floatX))
        # Data part of cost function: log ( u(w|c) / (u(w|c) + k * p_n(w)) )
        data_scores = classifier.output[T.arange(self.y.shape[0]), self.y]
        data_denom = self.logadd(
            data_scores, log_num_noise_samples + log_noise_dist[self.y])
        data_prob = data_scores - data_denom
        # Summation of noise part of cost function: sum over noise samples of log( k * p_n(x) / (u(x|c) + k * p_n(x)) )
        noise_mass = log_num_noise_samples + log_noise_dist[
            noise_samples]  # log(k) + log(p_n(x)) for all noise samples (Shape: #instances x k)
        noise_scores = classifier.output[
            T.arange(noise_samples.shape[0]).reshape((-1, 1)), noise_samples]
        noise_denom = self.logadd(noise_scores, noise_mass)
        noise_prob_sum = T.sum(noise_mass - noise_denom, axis=1)

        self.cost = (-T.mean(data_prob + noise_prob_sum))
        self.test = (T.sum(data_scores))
Example #7
    def _negative_sampling(self, num_negative_samples, target_indices):
        assert num_negative_samples > 0

        logging.debug('Stochastically sampling %d negative instances '
                      'out of %d classes (%.2f%%).',
                      num_negative_samples, self.num_entities,
                      100.0 *
                      float(num_negative_samples) / self.num_entities)

        from theano.tensor.shared_randomstreams import RandomStreams

        srng = RandomStreams(
            seed=np.random.randint(low=0, high=(1 << 30)))

        rng_sample_size = (self.batch_size, num_negative_samples,)

        logging.debug(
            'Using %s for random sample generation of %s tensors.',
            RandomStreams, rng_sample_size)

        logging.debug('For every batch %d random integers are sampled.',
                      np.prod(rng_sample_size))

        random_negative_indices = srng.choice(
            rng_sample_size,
            a=self.num_entities,
            p=self.clazz_distribution)

        if self.__DEBUG__:
            random_negative_indices = theano.printing.Print(
                'random_negative_indices')(random_negative_indices)

        return random_negative_indices
Example #8
def tied_losses(preds, n_sample_preds, n_classes, n_pairs):
    preds_per_trial_row = preds.reshape((-1, n_sample_preds, n_classes))
    _srng = RandomStreams(get_rng().randint(1, 2147462579))
    rand_inds = _srng.choice([n_pairs * 2], n_sample_preds, replace=False)
    part_1 = preds_per_trial_row[:, rand_inds[:n_pairs]]
    part_2 = preds_per_trial_row[:, rand_inds[n_pairs:]]
    # Have to now ensure first values are larger than zero
    # for numerical stability :/
    eps = 1e-4
    part_1 = T.maximum(part_1, eps)
    loss = categorical_crossentropy(part_1, part_2)
    return loss
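
The eps clamp above exists because categorical_crossentropy takes a log of part_1, which blows up if any predicted probability is exactly zero. A tiny NumPy illustration of the same guard (values made up):

import numpy as np

eps = 1e-4
part_1 = np.array([[0.0, 0.7, 0.3]])  # prediction containing an exact zero
part_2 = np.array([[0.2, 0.5, 0.3]])

unsafe = -np.sum(part_2 * np.log(part_1), axis=1)                 # inf (log of 0)
safe = -np.sum(part_2 * np.log(np.maximum(part_1, eps)), axis=1)  # finite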
Example #9
def tied_losses(preds, n_sample_preds, n_classes, n_pairs):
    preds_per_trial_row = preds.reshape((-1, n_sample_preds, n_classes))
    _srng = RandomStreams(get_rng().randint(1, 2147462579))
    rand_inds = _srng.choice([n_pairs * 2], n_sample_preds, replace=False)
    part_1 = preds_per_trial_row[:, rand_inds[:n_pairs]]
    part_2 = preds_per_trial_row[:, rand_inds[n_pairs:]]
    # Have to now ensure first values are larger than zero
    # for numerical stability :/
    eps = 1e-4
    part_1 = T.maximum(part_1, eps)
    loss = categorical_crossentropy(part_1, part_2)
    return loss
Example #10
class KernelDensityEstimateDistribution(Distribution):
    """Randomly samples from a kernel density estimate yielded by a set
    of training points.

    Simple sampling procedure [1]:

    1. With training points $x_1, ... x_n$, sample a point $x_i$
       uniformly
    2. From original KDE, we have a kernel defined at point $x_i$;
       sample randomly from this kernel

    [1]: http://www.stat.cmu.edu/~cshalizi/350/lectures/28/lecture-28.pdf
    """
    def __init__(self, X, bandwidth=1, space=None, rng=None):
        """
        Parameters
        ----------
        X : ndarray of shape (num_examples, num_features)
            Training examples from which to generate a kernel density
            estimate

        bandwidth : float
            Bandwidth (or h, or sigma) of the generated kernels
        """

        assert X.ndim == 2
        if space is None:
            space = VectorSpace(dim=X.shape[1], dtype=X.dtype)

        # super(KernelDensityEstimateDistribution, self).__init__(space)

        self.X = sharedX(X, name='KDE_X')

        self.bandwidth = sharedX(bandwidth, name='bandwidth')
        self.rng = RandomStreams() if rng is None else rng

    def sample(self, n):
        # Sample $n$ training examples
        training_samples = self.X[self.rng.choice(size=(n, ),
                                                  a=self.X.shape[0],
                                                  replace=True)]

        # Sample individually from each selected associated kernel
        #
        # (not well documented within NumPy / Theano, but rng.normal
        # call samples from a multivariate normal with diagonal
        # covariance matrix)
        ret = self.rng.normal(size=(n, self.X.shape[1]),
                              avg=training_samples,
                              std=self.bandwidth,
                              dtype=theano.config.floatX)

        return ret
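
The two-step procedure described in the docstring maps directly onto plain NumPy; a minimal sketch under the same assumptions (isotropic Gaussian kernels with a shared bandwidth; the helper name is hypothetical):

import numpy as np

def kde_sample(X, n, bandwidth=1.0, rng=None):
    # X: (num_examples, num_features) training points of the KDE
    rng = np.random.RandomState() if rng is None else rng
    # 1. Pick n training points uniformly at random (with replacement)
    centers = X[rng.choice(X.shape[0], size=n, replace=True)]
    # 2. Sample from the Gaussian kernel centred at each chosen point
    return rng.normal(loc=centers, scale=bandwidth, size=(n, X.shape[1]))

samples = kde_sample(np.random.rand(100, 3), n=10, bandwidth=0.5)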
Example #11
class KernelDensityEstimateDistribution(Distribution):
    """Randomly samples from a kernel density estimate yielded by a set
    of training points.

    Simple sampling procedure [1]:

    1. With training points $x_1, ... x_n$, sample a point $x_i$
       uniformly
    2. From original KDE, we have a kernel defined at point $x_i$;
       sample randomly from this kernel

    [1]: http://www.stat.cmu.edu/~cshalizi/350/lectures/28/lecture-28.pdf
    """

    def __init__(self, X, bandwidth=1, space=None, rng=None):
        """
        Parameters
        ----------
        X : ndarray of shape (num_examples, num_features)
            Training examples from which to generate a kernel density
            estimate

        bandwidth : float
            Bandwidth (or h, or sigma) of the generated kernels
        """

        assert X.ndim == 2
        if space is None:
            space = VectorSpace(dim=X.shape[1], dtype=X.dtype)

        # super(KernelDensityEstimateDistribution, self).__init__(space)

        self.X = sharedX(X, name='KDE_X')

        self.bandwidth = sharedX(bandwidth, name='bandwidth')
        self.rng = RandomStreams() if rng is None else rng

    def sample(self, n):
        # Sample $n$ training examples
        training_samples = self.X[self.rng.choice(size=(n,), a=self.X.shape[0], replace=True)]

        # Sample individually from each selected associated kernel
        #
        # (not well documented within NumPy / Theano, but rng.normal
        # call samples from a multivariate normal with diagonal
        # covariance matrix)
        ret = self.rng.normal(size=(n, self.X.shape[1]),
                              avg=training_samples, std=self.bandwidth,
                              dtype=theano.config.floatX)

        return ret
Example #12
class GibbsRegressor(ISymbolicPredictor):

    def __init__(self, n_dim_in, n_dim_out, sample_y = False, n_alpha = 1, possible_ws = [0, 1],
            alpha_update_policy = 'sequential', seed = None):
        self._w = theano.shared(np.zeros((n_dim_in, n_dim_out), dtype = theano.config.floatX), name = 'w')
        self._rng = RandomStreams(seed)
        if n_alpha == 'all':
            n_alpha = n_dim_in
        self._n_alpha = n_alpha
        self._alpha = theano.shared(np.arange(n_alpha))  # (n_alpha, ) array of weight-row indices
        self._sample_y = sample_y
        self._possible_ws = theano.shared(np.array(possible_ws), name = 'possible_ws')
        assert alpha_update_policy in ('sequential', 'random')
        self._alpha_update_policy = alpha_update_policy

    def _add_alpha_update(self):
        new_alpha = (self._alpha+self._n_alpha) % self._w.shape[0] \
            if self._alpha_update_policy == 'sequential' else \
            self._rng.choice(a=self._w.shape[0], size = (self._n_alpha, ), replace = False).reshape([-1])  # Reshape is for some reason necessary when n_alpha=1
        add_update(self._alpha, new_alpha)

    @staticmethod
    def compute_p_wa(w, x, y, alpha, possible_ws = np.array([0, 1])):
        """
        Compute the probability of the weights at indices alpha taking on each of the values in possible_ws
        """
        assert x.tag.test_value.ndim == y.tag.test_value.ndim == 2
        assert x.tag.test_value.shape[0] == y.tag.test_value.shape[0]
        assert w.get_value().shape[1] == y.tag.test_value.shape[1]
        v_current = x.dot(w)  # (n_samples, n_dim_out)
        v_0 = v_current[None, :, :] - w[alpha, None, :]*x.T[alpha, :, None]  # (n_alpha, n_samples, n_dim_out)
        possible_vs = v_0[:, :, :, None] + possible_ws[None, None, None, :]*x.T[alpha, :, None, None]  # (n_alpha, n_samples, n_dim_out, n_possible_ws)
        all_zs = tt.nnet.sigmoid(possible_vs)  # (n_alpha, n_samples, n_dim_out, n_possible_ws)
        log_likelihoods = tt.sum(tt.log(bernoulli(y[None, :, :, None], all_zs[:, :, :, :])), axis = 1)  # (n_alpha, n_dim_out, n_possible_ws)
        # Question: Need to shift for stability here or will Theano take care of that?
        # Stupid theano didn't implement softmax very nicely so we have to do some reshaping.
        return tt.nnet.softmax(log_likelihoods.reshape([alpha.shape[0]*w.shape[1], possible_ws.shape[0]]))\
            .reshape([alpha.shape[0], w.shape[1], possible_ws.shape[0]])  # (n_alpha, n_dim_out, n_possible_ws)

    @symbolic_updater
    def train(self, x, y):
        p_wa = self.compute_p_wa(self._w, x, y, self._alpha, self._possible_ws)  # (n_alpha, n_dim_out, n_possible_ws)
        w_sample = sample_categorical(self._rng, p_wa, values = self._possible_ws)
        w_new = tt.set_subtensor(self._w[self._alpha], w_sample)  # (n_dim_in, n_dim_out)
        add_update(self._w, w_new)
        self._add_alpha_update()

    @symbolic_simple
    def predict(self, x):
        p_y = tt.nnet.sigmoid(x.dot(self._w))
        return self._rng.binomial(p = p_y) if self._sample_y else p_y
Example #13
    def test_choice(self):
        """Test that RandomStreams.choice generates the same results as numpy"""
        # Check over two calls to see if the random state is correctly updated.
        random = RandomStreams(utt.fetch_seed())
        fn = function([], random.choice((11, 8), 10, 1, 0))
        fn_val0 = fn()
        fn_val1 = fn()

        rng_seed = np.random.RandomState(utt.fetch_seed()).randint(2**30)
        rng = np.random.RandomState(int(rng_seed))  # int() is for 32bit
        numpy_val0 = rng.choice(10, (11, 8), True, None)
        numpy_val1 = rng.choice(10, (11, 8), True, None)

        assert np.all(fn_val0 == numpy_val0)
        assert np.all(fn_val1 == numpy_val1)
Example #14
    def test_choice(self):
        # Test that RandomStreams.choice generates the same results as numpy
        # Check over two calls to see if the random state is correctly updated.
        random = RandomStreams(utt.fetch_seed())
        fn = function([], random.choice((11, 8), 10, 1, 0))
        fn_val0 = fn()
        fn_val1 = fn()

        rng_seed = np.random.RandomState(utt.fetch_seed()).randint(2**30)
        rng = np.random.RandomState(int(rng_seed))  # int() is for 32bit
        numpy_val0 = rng.choice(10, (11, 8), True, None)
        numpy_val1 = rng.choice(10, (11, 8), True, None)

        assert np.all(fn_val0 == numpy_val0)
        assert np.all(fn_val1 == numpy_val1)
Example #15
    def test_choice(self):
        """Test that RandomStreams.choice generates the same results as numpy"""
        # numpy.random.choice is only available for numpy versions >= 1.7
        major, minor, _ = numpy.version.short_version.split('.')
        if (int(major), int(minor)) < (1, 7):
            raise utt.SkipTest('choice requires NumPy version >= 1.7 '
                               '(%s)' % numpy.__version__)

        # Check over two calls to see if the random state is correctly updated.
        random = RandomStreams(utt.fetch_seed())
        fn = function([], random.choice((11, 8), 10, 1, 0))
        fn_val0 = fn()
        fn_val1 = fn()

        rng_seed = numpy.random.RandomState(utt.fetch_seed()).randint(2**30)
        rng = numpy.random.RandomState(int(rng_seed))  # int() is for 32bit
        numpy_val0 = rng.choice(10, (11, 8), True, None)
        numpy_val1 = rng.choice(10, (11, 8), True, None)

        assert numpy.all(fn_val0 == numpy_val0)
        assert numpy.all(fn_val1 == numpy_val1)
Example #16
    def test_choice(self):
        """Test that RandomStreams.choice generates the same results as numpy"""
        # numpy.random.choice is only available for numpy versions >= 1.7
        major, minor, _ = numpy.version.short_version.split('.')
        if (int(major), int(minor)) < (1, 7):
            raise utt.SkipTest('choice requires NumPy version >= 1.7 '
                               '(%s)' % numpy.__version__)
        
        # Check over two calls to see if the random state is correctly updated.
        random = RandomStreams(utt.fetch_seed())
        fn = function([], random.choice((11, 8), 10, 1, 0))
        fn_val0 = fn()
        fn_val1 = fn()

        rng_seed = numpy.random.RandomState(utt.fetch_seed()).randint(2**30)
        rng = numpy.random.RandomState(int(rng_seed))  # int() is for 32bit
        numpy_val0 = rng.choice(10, (11, 8), True, None)
        numpy_val1 = rng.choice(10, (11, 8), True, None)

        assert numpy.all(fn_val0 == numpy_val0)
        assert numpy.all(fn_val1 == numpy_val1)
Example #17
import numpy as np
import theano as th
from theano import function
from theano.tensor.shared_randomstreams import RandomStreams

data = np.random.rand(10,3)




it = th.shared(0)
y = th.shared(data)




srng = RandomStreams(seed=234)

expectRvs   = srng.normal(size=(3,1))
expectRvs.name='expectRvs'
epochStream = srng.permutation(n=10)
currentBatch = epochStream.reshape((5,2))[:,it]
y_mini = y[ currentBatch, :]
L = th.tensor.sum(th.tensor.dot( y_mini, expectRvs ))
L_func = function([], L, no_default_updates=True)

padding = srng.choice(size=(3,), a=10, replace=False, p=None, ndim=None, dtype='int64')



f1 = function([], expectRvs, no_default_updates=True)
f2 = function([], expectRvs)
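
The two compiled functions at the end differ only in no_default_updates, which controls whether the shared random-state updates are applied. A behavioural sketch of what repeated calls would show (illustrative, not captured output):

a1, a2 = f1(), f1()  # identical draws: the RNG state update is suppressed
b1, b2 = f2(), f2()  # fresh draws: each call advances the shared RNG state
assert (a1 == a2).all()
assert not (b1 == b2).all()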

Example #18
print "Single Normal ", norm()

#############Random integer list

rn_i = srng.random_integers(size = (4, ), low=1, high=900)
inte = function([], rn_i)
print "Integer list ", inte()

#############Generating a permutation uniformly at random

rn_p = srng.permutation(size=(), n = 10)
perm = function([], rn_p)
print "Random permutation of 0 to 9", perm()

#############Choosing from a list randomly

rn_list = srng.choice(size=(), a=[2, 3, 4.5, 6], replace=True, p=[.5, 0, .5, 0], dtype='float64')
lis = function([], rn_list)
print("Choosing 3 times from the specified list ", lis())
print(lis())
print(lis())

rn_another_list = srng.choice(size=(), a=3, replace=True, p=None)
an_list = function([], rn_another_list)

print "Choosing 3 times from [0,1, 2] since a is scalar", an_list()
print an_list()
print an_list()


Example #19
class CSDGM(Model):
    """
    The :class:`CSDGM` class represents the implementation of the model described in the
    Auxiliary Deep Generative Models article on arXiv.
    """
    def __init__(self,
                 n_c,
                 n_l,
                 n_a,
                 n_z,
                 n_y,
                 qa_hid,
                 qz_hid,
                 qy_hid,
                 px_hid,
                 pa_hid,
                 filters,
                 nonlinearity=rectify,
                 px_nonlinearity=None,
                 x_dist='bernoulli',
                 batchnorm=False,
                 seed=1234):
        """
        Initialize a skip deep generative model consisting of
        discriminative classifier q(y|a,x),
        generative model P p(a|z,y) and p(x|a,z,y),
        inference model Q q(a|x) and q(z|a,x,y).
        Weights are initialized using the Glorot and Bengio (2010) initialization scheme.
        :param n_c: Number of input channels.
        :param n_l: Length of the input segments (number of time steps).
        :param n_a: Number of auxiliary.
        :param n_z: Number of latent.
        :param n_y: Number of classes.
        :param qa_hid: List of number of deterministic hidden q(a|x).
        :param qz_hid: List of number of deterministic hidden q(z|a,x,y).
        :param qy_hid: List of number of deterministic hidden q(y|a,x).
        :param px_hid: List of number of deterministic hidden p(x|a,z,y).
        :param pa_hid: List of number of deterministic hidden p(a|z,y).
        :param filters: List of (num_filters, stride, pool) tuples for the convolutional encoder.
        :param nonlinearity: The transfer function used in the deterministic layers.
        :param px_nonlinearity: Output nonlinearity used when x_dist is 'gaussian'.
        :param x_dist: The x distribution, 'bernoulli', 'multinomial', 'gaussian', or 'linear'.
        :param batchnorm: Boolean value for batch normalization.
        :param seed: The random seed.
        """
        super(CSDGM, self).__init__(n_c, qz_hid + px_hid, n_a + n_z,
                                    nonlinearity)
        self.x_dist = x_dist
        self.n_y = n_y
        self.n_c = n_c
        self.n_l = n_l
        self.n_a = n_a
        self.n_z = n_z
        self.batchnorm = batchnorm
        self._srng = RandomStreams(seed)

        # Decide Glorot initialization of weights.
        init_w = 1e-3
        hid_w = ""
        if nonlinearity == rectify or nonlinearity == softplus:
            hid_w = "relu"

        pool_layers = []

        # Define symbolic variables for theano functions.
        self.sym_beta = T.scalar('beta')  # scaling constant beta
        self.sym_x_l = T.tensor3('x')  # labeled inputs
        self.sym_t_l = T.matrix('t')  # labeled targets
        self.sym_x_u = T.tensor3('x')  # unlabeled inputs
        self.sym_bs_l = T.iscalar('bs_l')  # number of labeled data
        self.sym_samples = T.iscalar('samples')  # MC samples
        self.sym_z = T.matrix('z')  # latent variable z
        self.sym_a = T.matrix('a')  # auxiliary variable a
        self.sym_warmup = T.fscalar('warmup')  # warmup to scale KL term

        # Assist methods for collecting the layers
        def dense_layer(layer_in,
                        n,
                        dist_w=init.GlorotNormal,
                        dist_b=init.Normal):
            dense = DenseLayer(layer_in, n, dist_w(hid_w), dist_b(init_w),
                               None)
            if batchnorm:
                dense = BatchNormLayer(dense)
            return NonlinearityLayer(dense, self.transf)

        def stochastic_layer(layer_in, n, samples, nonlin=None):
            mu = DenseLayer(layer_in, n, init.Normal(init_w),
                            init.Normal(init_w), nonlin)
            logvar = DenseLayer(layer_in, n, init.Normal(init_w),
                                init.Normal(init_w), nonlin)
            return SampleLayer(mu, logvar, eq_samples=samples,
                               iw_samples=1), mu, logvar

        def conv_layer(layer_in,
                       filter,
                       stride=(1, 1),
                       pool=1,
                       name='conv',
                       dist_w=init.GlorotNormal,
                       dist_b=init.Normal):
            l_conv = Conv2DLayer(layer_in,
                                 num_filters=filter,
                                 filter_size=(3, 1),
                                 stride=stride,
                                 pad='full',
                                 W=dist_w(hid_w),
                                 b=dist_b(init_w),
                                 name=name)
            if pool > 1:
                l_conv = MaxPool2DLayer(l_conv, pool_size=(pool, 1))
                pool_layers.append(l_conv)
            return l_conv

        # Input layers
        l_y_in = InputLayer((None, n_y))
        l_x_in = InputLayer((None, n_l, n_c), name='Input')

        # Reshape input
        l_x_in_reshp = ReshapeLayer(l_x_in, (-1, 1, n_l, n_c))
        print("l_x_in_reshp", l_x_in_reshp.output_shape)

        # CNN encoder implementation
        l_conv_enc = l_x_in_reshp
        for filter, stride, pool in filters:
            l_conv_enc = conv_layer(l_conv_enc, filter, stride, pool)
            print("l_conv_enc", l_conv_enc.output_shape)

        # Pool along last 2 axes
        l_global_pool_enc = GlobalPoolLayer(l_conv_enc, pool_function=T.mean)
        l_enc = dense_layer(l_global_pool_enc, n_z)
        print("l_enc", l_enc.output_shape)

        # Auxiliary q(a|x)
        l_qa_x = l_enc
        for hid in qa_hid:
            l_qa_x = dense_layer(l_qa_x, hid)
        l_qa_x, l_qa_x_mu, l_qa_x_logvar = stochastic_layer(
            l_qa_x, n_a, self.sym_samples)

        # Classifier q(y|a,x)
        l_qa_to_qy = DenseLayer(l_qa_x, qy_hid[0], init.GlorotNormal(hid_w),
                                init.Normal(init_w), None)
        l_qa_to_qy = ReshapeLayer(l_qa_to_qy,
                                  (-1, self.sym_samples, 1, qy_hid[0]))
        l_x_to_qy = DenseLayer(l_enc, qy_hid[0], init.GlorotNormal(hid_w),
                               init.Normal(init_w), None)
        l_x_to_qy = DimshuffleLayer(l_x_to_qy, (0, 'x', 'x', 1))
        l_qy_xa = ReshapeLayer(ElemwiseSumLayer([l_qa_to_qy, l_x_to_qy]),
                               (-1, qy_hid[0]))
        if batchnorm:
            l_qy_xa = BatchNormLayer(l_qy_xa)
        l_qy_xa = NonlinearityLayer(l_qy_xa, self.transf)
        if len(qy_hid) > 1:
            for hid in qy_hid[1:]:
                l_qy_xa = dense_layer(l_qy_xa, hid)
        l_qy_xa = DenseLayer(l_qy_xa, n_y, init.GlorotNormal(),
                             init.Normal(init_w), softmax)

        # Recognition q(z|x,a,y)
        l_qa_to_qz = DenseLayer(l_qa_x, qz_hid[0], init.GlorotNormal(hid_w),
                                init.Normal(init_w), None)
        l_qa_to_qz = ReshapeLayer(l_qa_to_qz,
                                  (-1, self.sym_samples, 1, qz_hid[0]))
        l_x_to_qz = DenseLayer(l_enc, qz_hid[0], init.GlorotNormal(hid_w),
                               init.Normal(init_w), None)
        l_x_to_qz = DimshuffleLayer(l_x_to_qz, (0, 'x', 'x', 1))
        l_y_to_qz = DenseLayer(l_y_in, qz_hid[0], init.GlorotNormal(hid_w),
                               init.Normal(init_w), None)
        l_y_to_qz = DimshuffleLayer(l_y_to_qz, (0, 'x', 'x', 1))
        l_qz_axy = ReshapeLayer(
            ElemwiseSumLayer([l_qa_to_qz, l_x_to_qz, l_y_to_qz]),
            (-1, qz_hid[0]))
        if batchnorm:
            l_qz_axy = BatchNormLayer(l_qz_axy)
        l_qz_axy = NonlinearityLayer(l_qz_axy, self.transf)
        if len(qz_hid) > 1:
            for hid in qz_hid[1:]:
                l_qz_axy = dense_layer(l_qz_axy, hid)
        l_qz_axy, l_qz_axy_mu, l_qz_axy_logvar = stochastic_layer(
            l_qz_axy, n_z, 1)

        # Generative p(a|z,y)
        l_y_to_pa = DenseLayer(l_y_in, pa_hid[0], init.GlorotNormal(hid_w),
                               init.Normal(init_w), None)
        l_y_to_pa = DimshuffleLayer(l_y_to_pa, (0, 'x', 'x', 1))
        l_qz_to_pa = DenseLayer(l_qz_axy, pa_hid[0], init.GlorotNormal(hid_w),
                                init.Normal(init_w), None)
        l_qz_to_pa = ReshapeLayer(l_qz_to_pa,
                                  (-1, self.sym_samples, 1, pa_hid[0]))
        l_pa_zy = ReshapeLayer(ElemwiseSumLayer([l_qz_to_pa, l_y_to_pa]),
                               [-1, pa_hid[0]])
        if batchnorm:
            l_pa_zy = BatchNormLayer(l_pa_zy)
        l_pa_zy = NonlinearityLayer(l_pa_zy, self.transf)
        if len(pa_hid) > 1:
            for hid in pa_hid[1:]:
                l_pa_zy = dense_layer(l_pa_zy, hid)
        l_pa_zy, l_pa_zy_mu, l_pa_zy_logvar = stochastic_layer(l_pa_zy, n_a, 1)

        # Generative p(x|a,z,y)
        l_qa_to_px = DenseLayer(l_qa_x, px_hid[0], init.GlorotNormal(hid_w),
                                init.Normal(init_w), None)
        l_qa_to_px = ReshapeLayer(l_qa_to_px,
                                  (-1, self.sym_samples, 1, px_hid[0]))
        l_y_to_px = DenseLayer(l_y_in, px_hid[0], init.GlorotNormal(hid_w),
                               init.Normal(init_w), None)
        l_y_to_px = DimshuffleLayer(l_y_to_px, (0, 'x', 'x', 1))
        l_qz_to_px = DenseLayer(l_qz_axy, px_hid[0], init.GlorotNormal(hid_w),
                                init.Normal(init_w), None)
        l_qz_to_px = ReshapeLayer(l_qz_to_px,
                                  (-1, self.sym_samples, 1, px_hid[0]))
        l_px_azy = ReshapeLayer(
            ElemwiseSumLayer([l_qa_to_px, l_qz_to_px, l_y_to_px]),
            [-1, px_hid[0]])
        if batchnorm:
            l_px_azy = BatchNormLayer(l_px_azy)
        l_px_azy = NonlinearityLayer(l_px_azy, self.transf)

        # Note that px_hid[0] has to be equal to the number of filters in the first convolution.
        # Otherwise add a dense layer here.

        # Inverse pooling
        l_global_depool = InverseLayer(l_px_azy, l_global_pool_enc)
        print("l_global_depool", l_global_depool.output_shape)

        # Reverse pool layer order
        pool_layers = pool_layers[::-1]

        # Decode
        l_deconv = l_global_depool
        for idx, filter in enumerate(filters[::-1]):
            filter, stride, pool = filter
            if pool > 1:
                l_deconv = InverseLayer(l_deconv, pool_layers[idx])
            l_deconv = Conv2DLayer(l_deconv,
                                   num_filters=filter,
                                   filter_size=(3, 1),
                                   stride=(stride, 1),
                                   W=init.GlorotNormal('relu'))
            print("l_deconv", l_deconv.output_shape)

        # The last l_conv layer should give us the input shape
        l_px_azy = Conv2DLayer(l_deconv,
                               num_filters=1,
                               filter_size=(3, 1),
                               pad='same',
                               nonlinearity=None)
        print("l_dec", l_px_azy.output_shape)

        # Flatten first two dimensions
        l_px_azy = ReshapeLayer(l_px_azy, (-1, n_c))

        if x_dist == 'bernoulli':
            l_px_azy = DenseLayer(l_px_azy, n_c, init.GlorotNormal(),
                                  init.Normal(init_w), sigmoid)
        elif x_dist == 'multinomial':
            l_px_azy = DenseLayer(l_px_azy, n_c, init.GlorotNormal(),
                                  init.Normal(init_w), softmax)
        elif x_dist == 'gaussian':
            l_px_azy, l_px_zy_mu, l_px_zy_logvar = stochastic_layer(
                l_px_azy, n_c, self.sym_samples, px_nonlinearity)
        elif x_dist == 'linear':
            l_px_azy = DenseLayer(l_px_azy, n_c, nonlinearity=None)

        # Reshape all the model layers to have the same size
        self.l_x_in = l_x_in
        self.l_y_in = l_y_in
        self.l_a_in = l_qa_x

        self.l_qa = ReshapeLayer(l_qa_x, (-1, self.sym_samples, 1, n_a))
        self.l_qa_mu = DimshuffleLayer(l_qa_x_mu, (0, 'x', 'x', 1))
        self.l_qa_logvar = DimshuffleLayer(l_qa_x_logvar, (0, 'x', 'x', 1))

        self.l_qz = ReshapeLayer(l_qz_axy, (-1, self.sym_samples, 1, n_z))
        self.l_qz_mu = ReshapeLayer(l_qz_axy_mu,
                                    (-1, self.sym_samples, 1, n_z))
        self.l_qz_logvar = ReshapeLayer(l_qz_axy_logvar,
                                        (-1, self.sym_samples, 1, n_z))

        self.l_qy = ReshapeLayer(l_qy_xa, (-1, self.sym_samples, 1, n_y))

        self.l_pa = ReshapeLayer(l_pa_zy, (-1, self.sym_samples, 1, n_a))
        self.l_pa_mu = ReshapeLayer(l_pa_zy_mu, (-1, self.sym_samples, 1, n_a))
        self.l_pa_logvar = ReshapeLayer(l_pa_zy_logvar,
                                        (-1, self.sym_samples, 1, n_a))

        # Here we assume that we pass (batch size * segment length, number of features) to the sample layer from
        # which we then get (batch size * segment length, samples, IW samples, features)
        self.l_px = ReshapeLayer(l_px_azy, (-1, n_l, self.sym_samples, 1, n_c))
        self.l_px_mu = ReshapeLayer(l_px_zy_mu, (-1, n_l, self.sym_samples, 1, n_c)) \
            if x_dist == "gaussian" else None
        self.l_px_logvar = ReshapeLayer(l_px_zy_logvar, (-1, n_l, self.sym_samples, 1, n_c)) \
            if x_dist == "gaussian" else None

        # Predefined functions
        inputs = {l_x_in: self.sym_x_l}
        outputs = get_output(self.l_qy, inputs,
                             deterministic=True).mean(axis=(1, 2))
        self.f_qy = theano.function([self.sym_x_l, self.sym_samples], outputs)

        outputs = get_output(l_qa_x, inputs, deterministic=True)
        self.f_qa = theano.function([self.sym_x_l, self.sym_samples], outputs)

        inputs = {l_x_in: self.sym_x_l, l_y_in: self.sym_t_l}
        outputs = get_output(l_qz_axy, inputs, deterministic=True)
        self.f_qz = theano.function(
            [self.sym_x_l, self.sym_t_l, self.sym_samples], outputs)

        inputs = {l_qz_axy: self.sym_z, l_y_in: self.sym_t_l}
        outputs = get_output(self.l_pa, inputs,
                             deterministic=True).mean(axis=(1, 2))
        self.f_pa = theano.function(
            [self.sym_z, self.sym_t_l, self.sym_samples], outputs)

        inputs = {
            l_x_in: self.sym_x_l,
            l_qa_x: self.sym_a,
            l_qz_axy: self.sym_z,
            l_y_in: self.sym_t_l
        }
        outputs = get_output(self.l_px, inputs,
                             deterministic=True).mean(axis=(2, 3))
        self.f_px = theano.function([
            self.sym_x_l, self.sym_a, self.sym_z, self.sym_t_l,
            self.sym_samples
        ], outputs)

        outputs = get_output(self.l_px_mu, inputs,
                             deterministic=True).mean(axis=(2, 3))
        self.f_mu = theano.function([
            self.sym_x_l, self.sym_a, self.sym_z, self.sym_t_l,
            self.sym_samples
        ], outputs)

        outputs = get_output(self.l_px_logvar, inputs,
                             deterministic=True).mean(axis=(2, 3))
        self.f_var = theano.function([
            self.sym_x_l, self.sym_a, self.sym_z, self.sym_t_l,
            self.sym_samples
        ], outputs)

        # Define model parameters
        self.model_params = get_all_params([self.l_qy, self.l_pa, self.l_px])
        self.trainable_model_params = get_all_params(
            [self.l_qy, self.l_pa, self.l_px], trainable=True)

    def build_model(self,
                    train_set_unlabeled,
                    train_set_labeled,
                    test_set,
                    validation_set=None):
        """
        Build the auxiliary deep generative model from the initialized hyperparameters.
        Define the lower bound term and compile it into a training function.
        :param train_set_unlabeled: Unlabeled train set containing variables x, t.
        :param train_set_labeled: Labeled train set containing variables x, t.
        :param test_set: Test set containing variables x, t.
        :param validation_set: Validation set containing variables x, t.
        :return: train, test, validation function and dicts of arguments.
        """
        super(CSDGM, self).build_model(train_set_unlabeled, test_set,
                                       validation_set)

        sh_train_x_l = theano.shared(np.asarray(train_set_labeled[0],
                                                dtype=theano.config.floatX),
                                     borrow=True)
        sh_train_t_l = theano.shared(np.asarray(train_set_labeled[1],
                                                dtype=theano.config.floatX),
                                     borrow=True)
        n = self.sh_train_x.shape[0].astype(
            theano.config.floatX)  # no. of data points
        n_l = sh_train_x_l.shape[0].astype(
            theano.config.floatX)  # no. of labeled data points

        # Define the layers for the density estimation used in the lower bound.
        l_log_qa = GaussianLogDensityLayer(self.l_qa, self.l_qa_mu,
                                           self.l_qa_logvar)
        l_log_qz = GaussianLogDensityLayer(self.l_qz, self.l_qz_mu,
                                           self.l_qz_logvar)
        l_log_qy = MultinomialLogDensityLayer(self.l_qy, self.l_y_in, eps=1e-8)

        l_log_pz = StandardNormalLogDensityLayer(self.l_qz)
        l_log_pa = GaussianLogDensityLayer(self.l_qa, self.l_pa_mu,
                                           self.l_pa_logvar)

        l_x_in = ReshapeLayer(self.l_x_in, (-1, self.n_l * self.n_c))
        l_px = DimshuffleLayer(self.l_px, (0, 3, 1, 2, 4))
        l_px = ReshapeLayer(l_px, (-1, self.sym_samples, 1, self.n_c))
        if self.x_dist == 'bernoulli':
            l_log_px = BernoulliLogDensityLayer(self.l_px, self.l_x_in)
        elif self.x_dist == 'multinomial':
            l_log_px = MultinomialLogDensityLayer(l_px, l_x_in)
            l_log_px = ReshapeLayer(l_log_px, (-1, self.n_l, 1, 1, 1))
            l_log_px = MeanLayer(l_log_px, axis=1)
        elif self.x_dist == 'gaussian':
            l_px_mu = ReshapeLayer(
                DimshuffleLayer(self.l_px_mu, (0, 2, 3, 1, 4)),
                (-1, self.sym_samples, 1, self.n_l * self.n_c))
            l_px_logvar = ReshapeLayer(
                DimshuffleLayer(self.l_px_logvar, (0, 2, 3, 1, 4)),
                (-1, self.sym_samples, 1, self.n_l * self.n_c))
            l_log_px = GaussianLogDensityLayer(l_x_in, l_px_mu, l_px_logvar)

        def lower_bound(log_pa, log_qa, log_pz, log_qz, log_py, log_px):
            lb = log_px + log_py + (log_pz + log_pa - log_qa -
                                    log_qz) * (1.1 - self.sym_warmup)
            return lb

        # Lower bound for labeled data
        out_layers = [
            l_log_pa, l_log_pz, l_log_qa, l_log_qz, l_log_px, l_log_qy
        ]
        inputs = {self.l_x_in: self.sym_x_l, self.l_y_in: self.sym_t_l}
        out = get_output(out_layers,
                         inputs,
                         batch_norm_update_averages=False,
                         batch_norm_use_averages=False)
        log_pa_l, log_pz_l, log_qa_x_l, log_qz_axy_l, log_px_zy_l, log_qy_ax_l = out

        # Prior p(y) expecting that all classes are evenly distributed
        py_l = softmax(T.zeros((self.sym_x_l.shape[0], self.n_y)))
        log_py_l = -categorical_crossentropy(py_l, self.sym_t_l).reshape(
            (-1, 1)).dimshuffle((0, 'x', 'x', 1))
        lb_l = lower_bound(log_pa_l, log_qa_x_l, log_pz_l, log_qz_axy_l,
                           log_py_l, log_px_zy_l)
        lb_l = lb_l.mean(axis=(1, 2))  # Mean over the sampling dimensions
        log_qy_ax_l *= (
            self.sym_beta * (n / n_l)
        )  # Scale the supervised cross entropy with the alpha constant
        lb_l += log_qy_ax_l.mean(axis=(
            1, 2
        ))  # Collect the lower bound term and mean over sampling dimensions

        # Lower bound for unlabeled data
        bs_u = self.sym_x_u.shape[0]

        # For the integrating out approach, we repeat the input matrix x, and construct a target (bs * n_y) x n_y
        # Example of input and target matrix for a 3 class problem and batch_size=2. 2D tensors of the form
        #               x_repeat                     t_repeat
        #  [[x[0,0], x[0,1], ..., x[0,n_x]]         [[1, 0, 0]
        #   [x[1,0], x[1,1], ..., x[1,n_x]]          [1, 0, 0]
        #   [x[0,0], x[0,1], ..., x[0,n_x]]          [0, 1, 0]
        #   [x[1,0], x[1,1], ..., x[1,n_x]]          [0, 1, 0]
        #   [x[0,0], x[0,1], ..., x[0,n_x]]          [0, 0, 1]
        #   [x[1,0], x[1,1], ..., x[1,n_x]]]         [0, 0, 1]]
        t_eye = T.eye(self.n_y, k=0)
        t_u = t_eye.reshape((self.n_y, 1, self.n_y)).repeat(bs_u,
                                                            axis=1).reshape(
                                                                (-1, self.n_y))
        x_u = self.sym_x_u.reshape(
            (1, bs_u, self.n_l, self.n_c)).repeat(self.n_y, axis=0).reshape(
                (-1, self.n_l, self.n_c))

        # Since the expectation of var a is outside the integration we calculate E_q(a|x) first
        a_x_u = get_output(self.l_qa,
                           self.sym_x_u,
                           batch_norm_update_averages=True,
                           batch_norm_use_averages=False)
        a_x_u_rep = a_x_u.reshape(
            (1, bs_u * self.sym_samples, self.n_a)).repeat(self.n_y,
                                                           axis=0).reshape(
                                                               (-1, self.n_a))
        out_layers = [l_log_pa, l_log_pz, l_log_qa, l_log_qz, l_log_px]
        inputs = {self.l_x_in: x_u, self.l_y_in: t_u, self.l_a_in: a_x_u_rep}
        out = get_output(out_layers,
                         inputs,
                         batch_norm_update_averages=False,
                         batch_norm_use_averages=False)
        log_pa_u, log_pz_u, log_qa_x_u, log_qz_axy_u, log_px_zy_u = out

        # Prior p(y) expecting that all classes are evenly distributed
        py_u = softmax(T.zeros((bs_u * self.n_y, self.n_y)))
        log_py_u = -categorical_crossentropy(py_u, t_u).reshape(
            (-1, 1)).dimshuffle((0, 'x', 'x', 1))
        lb_u = lower_bound(log_pa_u, log_qa_x_u, log_pz_u, log_qz_axy_u,
                           log_py_u, log_px_zy_u)
        lb_u = lb_u.reshape(
            (self.n_y, 1, 1, bs_u)).transpose(3, 1, 2, 0).mean(axis=(1, 2))
        inputs = {
            self.l_x_in: self.sym_x_u,
            self.l_a_in: a_x_u.reshape((-1, self.n_a))
        }
        y_u = get_output(self.l_qy,
                         inputs,
                         batch_norm_update_averages=True,
                         batch_norm_use_averages=False).mean(axis=(1, 2))
        y_u += 1e-8  # Ensure that we get no NANs when calculating the entropy
        y_u /= T.sum(y_u, axis=1, keepdims=True)
        lb_u = (y_u * (lb_u - T.log(y_u))).sum(axis=1)

        # Regularizing with weight priors p(theta|N(0,1)), collecting and clipping gradients
        weight_priors = 0.0
        for p in self.trainable_model_params:
            if 'W' not in str(p):
                continue
            weight_priors += log_normal(p, 0, 1).sum()

        # Collect the lower bound and scale it with the weight priors.
        elbo = ((lb_l.mean() + lb_u.mean()) * n + weight_priors) / -n
        lb_labeled = -lb_l.mean()
        lb_unlabeled = -lb_u.mean()
        log_px = log_px_zy_l.mean() + log_px_zy_u.mean()
        log_pz = log_pz_l.mean() + log_pz_u.mean()
        log_qz = log_qz_axy_l.mean() + log_qz_axy_u.mean()
        log_pa = log_pa_l.mean() + log_pa_u.mean()
        log_qa = log_qa_x_l.mean() + log_qa_x_u.mean()

        grads_collect = T.grad(elbo, self.trainable_model_params)
        params_collect = self.trainable_model_params
        sym_beta1 = T.scalar('beta1')
        sym_beta2 = T.scalar('beta2')
        clip_grad, max_norm = 1, 5
        mgrads = total_norm_constraint(grads_collect, max_norm=max_norm)
        mgrads = [T.clip(g, -clip_grad, clip_grad) for g in mgrads]
        updates = adam(mgrads, params_collect, self.sym_lr, sym_beta1,
                       sym_beta2)

        # Training function
        indices = self._srng.choice(size=[self.sym_bs_l],
                                    a=sh_train_x_l.shape[0],
                                    replace=False)
        x_batch_l = sh_train_x_l[indices]
        t_batch_l = sh_train_t_l[indices]
        x_batch_u = self.sh_train_x[self.batch_slice]
        if self.x_dist == 'bernoulli':  # Sample bernoulli input.
            x_batch_u = self._srng.binomial(size=x_batch_u.shape,
                                            n=1,
                                            p=x_batch_u,
                                            dtype=theano.config.floatX)
            x_batch_l = self._srng.binomial(size=x_batch_l.shape,
                                            n=1,
                                            p=x_batch_l,
                                            dtype=theano.config.floatX)

        givens = {
            self.sym_x_l: x_batch_l,
            self.sym_x_u: x_batch_u,
            self.sym_t_l: t_batch_l
        }
        inputs = [
            self.sym_index, self.sym_batchsize, self.sym_bs_l, self.sym_beta,
            self.sym_lr, sym_beta1, sym_beta2, self.sym_samples,
            self.sym_warmup
        ]
        outputs = [
            elbo, lb_labeled, lb_unlabeled, log_px, log_pz, log_qz, log_pa,
            log_qa
        ]
        f_train = theano.function(inputs=inputs,
                                  outputs=outputs,
                                  givens=givens,
                                  updates=updates)

        # Default training args. Note that these can be changed during or prior to training.
        self.train_args['inputs']['batchsize_unlabeled'] = 100
        self.train_args['inputs']['batchsize_labeled'] = 100
        self.train_args['inputs']['beta'] = 0.1
        self.train_args['inputs']['learningrate'] = 3e-4
        self.train_args['inputs']['beta1'] = 0.9
        self.train_args['inputs']['beta2'] = 0.999
        self.train_args['inputs']['samples'] = 1
        self.train_args['inputs']['warmup'] = 0.1
        self.train_args['outputs']['lb'] = '%0.3f'
        self.train_args['outputs']['lb-l'] = '%0.3f'
        self.train_args['outputs']['lb-u'] = '%0.3f'
        self.train_args['outputs']['px'] = '%0.3f'
        self.train_args['outputs']['pz'] = '%0.3f'
        self.train_args['outputs']['qz'] = '%0.3f'
        self.train_args['outputs']['pa'] = '%0.3f'
        self.train_args['outputs']['qa'] = '%0.3f'

        # Validation and test function
        y = get_output(self.l_qy, self.sym_x_l,
                       deterministic=True).mean(axis=(1, 2))
        class_err = (1. - categorical_accuracy(y, self.sym_t_l).mean()) * 100
        givens = {self.sym_x_l: self.sh_test_x, self.sym_t_l: self.sh_test_t}
        f_test = theano.function(inputs=[self.sym_samples],
                                 outputs=[class_err],
                                 givens=givens)

        # Test args.  Note that these can be changed during or prior to training.
        self.test_args['inputs']['samples'] = 1
        self.test_args['outputs']['test'] = '%0.2f%%'

        f_validate = None
        if validation_set is not None:
            givens = {
                self.sym_x_l: self.sh_valid_x,
                self.sym_t_l: self.sh_valid_t
            }
            f_validate = theano.function(inputs=[self.sym_samples],
                                         outputs=[class_err],
                                         givens=givens)
            # Default validation args. Note that these can be changed during or prior to training.
            self.validate_args['inputs']['samples'] = 1
            self.validate_args['outputs']['validation'] = '%0.2f%%'

        return f_train, f_test, f_validate, self.train_args, self.test_args, self.validate_args

    def get_output(self, x, samples=1):
        return self.f_qy(x, samples)

    def model_info(self):
        qa_shapes = self.get_model_shape(get_all_params(self.l_qa))
        qy_shapes = self.get_model_shape(get_all_params(
            self.l_qy))[len(qa_shapes) - 1:]
        qz_shapes = self.get_model_shape(get_all_params(
            self.l_qz))[len(qa_shapes) - 1:]
        px_shapes = self.get_model_shape(get_all_params(
            self.l_px))[(len(qz_shapes) - 1) + (len(qa_shapes) - 1):]
        pa_shapes = self.get_model_shape(get_all_params(
            self.l_pa))[(len(qz_shapes) - 1) + (len(qa_shapes) - 1):]
        s = ""
        s += 'batch norm: %s.\n' % (str(self.batchnorm))
        s += 'x distribution: %s.\n' % (str(self.x_dist))
        s += 'model q(a|x): %s.\n' % str(qa_shapes)[1:-1]
        s += 'model q(z|a,x,y): %s.\n' % str(qz_shapes)[1:-1]
        s += 'model q(y|a,x): %s.\n' % str(qy_shapes)[1:-1]
        s += 'model p(x|a,z,y): %s.\n' % str(px_shapes)[1:-1]
        s += 'model p(a|z,y): %s.' % str(pa_shapes)[1:-1]
        return s
Example #20
class ADGM(Model):
    """
    The :class:`ADGM` class represents the implementation of the model described in the
    Auxiliary Deep Generative Models article on arXiv.
    """
    def __init__(self,
                 n_x,
                 n_a,
                 n_z,
                 n_y,
                 qa_hid,
                 qz_hid,
                 qy_hid,
                 pax_hid,
                 trans_func=rectify,
                 x_dist='bernoulli',
                 batchnorm=False,
                 seed=1234):
        """
        Initialize an auxiliary deep generative model consisting of
        discriminative classifier q(y|a,x),
        generative model P p(a|z,y) and p(x|z,y),
        inference model Q q(a|x) and q(z|a,x,y).
        Weights are initialized using the Glorot and Bengio (2010) initialization scheme.
        :param n_x: Number of inputs.
        :param n_a: Number of auxiliary.
        :param n_z: Number of latent.
        :param n_y: Number of classes.
        :param qa_hid: List of number of deterministic hidden q(a|x).
        :param qz_hid: List of number of deterministic hidden q(z|a,x,y).
        :param qy_hid: List of number of deterministic hidden q(y|a,x).
        :param pax_hid: List of number of deterministic hidden p(a|z,y) & p(x|z,y).
        :param trans_func: The transfer function used in the deterministic layers.
        :param x_dist: The x distribution, 'bernoulli', 'multinomial', or 'gaussian'.
        :param batchnorm: Boolean value for batch normalization.
        :param seed: The random seed.
        """
        super(ADGM, self).__init__(n_x, qz_hid + pax_hid, n_a + n_z,
                                   trans_func)
        self.x_dist = x_dist
        self.n_y = n_y
        self.n_x = n_x
        self.n_a = n_a
        self.n_z = n_z
        self.batchnorm = batchnorm
        self._srng = RandomStreams(seed)

        # Decide Glorot initialization of weights.
        init_w = 1e-3
        hid_w = ""
        if trans_func == rectify or trans_func == softplus:
            hid_w = "relu"

        # Define symbolic variables for theano functions.
        self.sym_beta = T.scalar('beta')  # scaling constant beta
        self.sym_x_l = T.matrix('x')  # labeled inputs
        self.sym_t_l = T.matrix('t')  # labeled targets
        self.sym_x_u = T.matrix('x')  # unlabeled inputs
        self.sym_bs_l = T.iscalar('bs_l')  # number of labeled data
        self.sym_samples = T.iscalar('samples')  # MC samples
        self.sym_z = T.matrix('z')  # latent variable z
        self.sym_a = T.matrix('a')  # auxiliary variable a

        # Assist methods for collecting the layers
        def dense_layer(layer_in,
                        n,
                        dist_w=init.GlorotNormal,
                        dist_b=init.Normal):
            dense = DenseLayer(layer_in, n, dist_w(hid_w), dist_b(init_w),
                               None)
            if batchnorm:
                dense = BatchNormLayer(dense)
            return NonlinearityLayer(dense, self.transf)

        def stochastic_layer(layer_in, n, samples):
            mu = DenseLayer(layer_in, n, init.Normal(init_w),
                            init.Normal(init_w), None)
            logvar = DenseLayer(layer_in, n, init.Normal(init_w),
                                init.Normal(init_w), None)
            return SampleLayer(mu, logvar, eq_samples=samples,
                               iw_samples=1), mu, logvar

        # Input layers
        l_x_in = InputLayer((None, n_x))
        l_y_in = InputLayer((None, n_y))

        # Auxiliary q(a|x)
        l_qa_x = l_x_in
        for hid in qa_hid:
            l_qa_x = dense_layer(l_qa_x, hid)
        l_qa_x, l_qa_x_mu, l_qa_x_logvar = stochastic_layer(
            l_qa_x, n_a, self.sym_samples)

        # Classifier q(y|a,x)
        l_qa_to_qy = DenseLayer(l_qa_x, qy_hid[0], init.GlorotNormal(hid_w),
                                init.Normal(init_w), None)
        l_qa_to_qy = ReshapeLayer(l_qa_to_qy,
                                  (-1, self.sym_samples, 1, qy_hid[0]))
        l_x_to_qy = DenseLayer(l_x_in, qy_hid[0], init.GlorotNormal(hid_w),
                               init.Normal(init_w), None)
        self.l_x_to_qy = l_x_to_qy
        l_x_to_qy = DimshuffleLayer(l_x_to_qy, (0, 'x', 'x', 1))
        l_qy_xa = ReshapeLayer(ElemwiseSumLayer([l_qa_to_qy, l_x_to_qy]),
                               (-1, qy_hid[0]))
        if batchnorm:
            l_qy_xa = BatchNormLayer(l_qy_xa)
        l_qy_xa = NonlinearityLayer(l_qy_xa, self.transf)
        if len(qy_hid) > 1:
            for hid in qy_hid[1:]:
                l_qy_xa = dense_layer(l_qy_xa, hid)
        l_qy_xa = DenseLayer(l_qy_xa, n_y, init.GlorotNormal(),
                             init.Normal(init_w), softmax)

        # Recognition q(z|x,a,y)
        l_qa_to_qz = DenseLayer(l_qa_x, qz_hid[0], init.GlorotNormal(hid_w),
                                init.Normal(init_w), None)
        l_qa_to_qz = ReshapeLayer(l_qa_to_qz,
                                  (-1, self.sym_samples, 1, qz_hid[0]))
        l_x_to_qz = DenseLayer(l_x_in, qz_hid[0], init.GlorotNormal(hid_w),
                               init.Normal(init_w), None)
        l_x_to_qz = DimshuffleLayer(l_x_to_qz, (0, 'x', 'x', 1))
        l_y_to_qz = DenseLayer(l_y_in, qz_hid[0], init.GlorotNormal(hid_w),
                               init.Normal(init_w), None)
        l_y_to_qz = DimshuffleLayer(l_y_to_qz, (0, 'x', 'x', 1))
        l_qz_axy = ReshapeLayer(
            ElemwiseSumLayer([l_qa_to_qz, l_x_to_qz, l_y_to_qz]),
            (-1, qz_hid[0]))
        if batchnorm:
            l_qz_axy = BatchNormLayer(l_qz_axy)
        l_qz_axy = NonlinearityLayer(l_qz_axy, self.transf)
        if len(qz_hid) > 1:
            for hid in qz_hid[1:]:
                l_qz_axy = dense_layer(l_qz_axy, hid)
        l_qz_axy, l_qz_axy_mu, l_qz_axy_logvar = stochastic_layer(
            l_qz_axy, n_z, 1)

        # Generative p(x|z,y), p(a|z,y)
        l_y_to_px = DenseLayer(l_y_in, pax_hid[0], init.GlorotNormal(hid_w),
                               init.Normal(init_w), None)
        l_y_to_px = DimshuffleLayer(l_y_to_px, (0, 'x', 'x', 1))
        l_qz_to_px = DenseLayer(l_qz_axy, pax_hid[0], init.GlorotNormal(hid_w),
                                init.Normal(init_w), None)
        l_qz_to_px = ReshapeLayer(l_qz_to_px,
                                  (-1, self.sym_samples, 1, pax_hid[0]))
        l_px_zy = ReshapeLayer(ElemwiseSumLayer([l_qz_to_px, l_y_to_px]),
                               [-1, pax_hid[0]])
        if batchnorm:
            l_px_zy = BatchNormLayer(l_px_zy)
        l_px_zy = NonlinearityLayer(l_px_zy, self.transf)
        if len(pax_hid) > 1:
            for hid in pax_hid[1:]:
                l_px_zy = dense_layer(l_px_zy, hid)

        l_pa_zy, l_pa_zy_mu, l_pa_zy_logvar = stochastic_layer(l_px_zy, n_a, 1)
        if x_dist == 'bernoulli':
            l_px_zy = DenseLayer(l_px_zy, n_x, init.GlorotNormal(),
                                 init.Normal(init_w), sigmoid)
        elif x_dist == 'multinomial':
            l_px_zy = DenseLayer(l_px_zy, n_x, init.GlorotNormal(),
                                 init.Normal(init_w), softmax)
        elif x_dist == 'gaussian':
            l_px_zy, l_px_zy_mu, l_px_zy_logvar = stochastic_layer(
                l_px_zy, n_x, 1)

        # Reshape all the model layers to have the same size
        self.l_x_in = l_x_in
        self.l_y_in = l_y_in
        self.l_a_in = l_qa_x

        self.l_qa = ReshapeLayer(l_qa_x, (-1, self.sym_samples, 1, n_a))
        self.l_qa_mu = DimshuffleLayer(l_qa_x_mu, (0, 'x', 'x', 1))
        self.l_qa_logvar = DimshuffleLayer(l_qa_x_logvar, (0, 'x', 'x', 1))

        self.l_qz = ReshapeLayer(l_qz_axy, (-1, self.sym_samples, 1, n_z))
        self.l_qz_mu = ReshapeLayer(l_qz_axy_mu,
                                    (-1, self.sym_samples, 1, n_z))
        self.l_qz_logvar = ReshapeLayer(l_qz_axy_logvar,
                                        (-1, self.sym_samples, 1, n_z))

        self.l_qy = ReshapeLayer(l_qy_xa, (-1, self.sym_samples, 1, n_y))

        self.l_pa = ReshapeLayer(l_pa_zy, (-1, self.sym_samples, 1, n_a))
        self.l_pa_mu = ReshapeLayer(l_pa_zy_mu, (-1, self.sym_samples, 1, n_a))
        self.l_pa_logvar = ReshapeLayer(l_pa_zy_logvar,
                                        (-1, self.sym_samples, 1, n_a))

        self.l_px = ReshapeLayer(l_px_zy, (-1, self.sym_samples, 1, n_x))
        self.l_px_mu = ReshapeLayer(l_px_zy_mu,
                                    (-1, self.sym_samples, 1,
                                     n_x)) if x_dist == "gaussian" else None
        self.l_px_logvar = ReshapeLayer(
            l_px_zy_logvar,
            (-1, self.sym_samples, 1, n_x)) if x_dist == "gaussian" else None

        # Predefined functions
        inputs = [self.sym_x_l, self.sym_samples]
        outputs = get_output(self.l_qy, self.sym_x_l,
                             deterministic=True).mean(axis=(1, 2))
        self.f_qy = theano.function(inputs, outputs)

        inputs = [self.sym_x_l, self.sym_samples]
        outputs = get_output(self.l_qa, self.sym_x_l,
                             deterministic=True).mean(axis=(1, 2))
        self.f_qa = theano.function(inputs, outputs)

        # Define model parameters
        self.model_params = get_all_params([self.l_qy, self.l_pa, self.l_px])
        self.trainable_model_params = get_all_params(
            [self.l_qy, self.l_pa, self.l_px], trainable=True)

    def build_model(self,
                    train_set_unlabeled,
                    train_set_labeled,
                    test_set,
                    validation_set=None):
        """
        Build the auxiliary deep generative model from the initialized hyperparameters.
        Define the lower bound term and compile it into a training function.
        :param train_set_unlabeled: Unlabeled train set containing variables x, t.
        :param train_set_labeled: Labeled train set containing variables x, t.
        :param test_set: Test set containing variables x, t.
        :param validation_set: Validation set containing variables x, t.
        :return: train, test, validation function and dicts of arguments.
        """
        super(ADGM, self).build_model(train_set_unlabeled, test_set,
                                      validation_set)

        sh_train_x_l = theano.shared(np.asarray(train_set_labeled[0],
                                                dtype=theano.config.floatX),
                                     borrow=True)
        sh_train_t_l = theano.shared(np.asarray(train_set_labeled[1],
                                                dtype=theano.config.floatX),
                                     borrow=True)
        n = self.sh_train_x.shape[0].astype(
            theano.config.floatX)  # no. of data points
        n_l = sh_train_x_l.shape[0].astype(
            theano.config.floatX)  # no. of labeled data points

        params = get_all_params([self.l_pa, self.l_px], trainable=True)
        grads = [T.zeros(p.shape) for p in params]
        params_qy = get_all_params(
            [self.l_qy],
            trainable=True)[len(get_all_params(self.l_qa, trainable=True)):]
        grads_qy = [T.zeros(p.shape) for p in params_qy]

        # Define the layers for the density estimation used in the lower bound.
        l_log_qa = GaussianLogDensityLayer(self.l_qa, self.l_qa_mu,
                                           self.l_qa_logvar)
        l_log_qz = GaussianLogDensityLayer(self.l_qz, self.l_qz_mu,
                                           self.l_qz_logvar)
        l_log_qy = MultinomialLogDensityLayer(self.l_qy, self.l_y_in, eps=1e-8)

        l_log_pz = StandardNormalLogDensityLayer(self.l_qz)
        l_log_pa = GaussianLogDensityLayer(self.l_qa, self.l_pa_mu,
                                           self.l_pa_logvar)
        if self.x_dist == 'bernoulli':
            l_log_px = BernoulliLogDensityLayer(self.l_px, self.l_x_in)
        elif self.x_dist == 'multinomial':
            l_log_px = MultinomialLogDensityLayer(self.l_px, self.l_x_in)
        elif self.x_dist == 'gaussian':
            l_log_px = GaussianLogDensityLayer(self.l_x_in, self.l_px_mu,
                                               self.l_px_logvar)

        def lower_bound(log_pa, log_qa, log_pz, log_qz, log_py, log_px):
            lb = log_px + log_py + log_pz + log_pa.mean(
                axis=(1, 2), keepdims=True) - log_qa - log_qz
            return lb

        # Lower bound for labeled data
        out_layers = [
            l_log_pa, l_log_pz, l_log_qa, l_log_qz, l_log_px, l_log_qy
        ]
        inputs = {self.l_x_in: self.sym_x_l, self.l_y_in: self.sym_t_l}
        out = get_output(out_layers,
                         inputs,
                         batch_norm_update_averages=False,
                         batch_norm_use_averages=False)
        log_pa_l, log_pz_l, log_qa_x_l, log_qz_axy_l, log_px_zy_l, log_qy_ax_l = out
        # Prior p(y) expecting that all classes are evenly distributed
        py_l = softmax(T.zeros((self.sym_x_l.shape[0], self.n_y)))
        log_py_l = -categorical_crossentropy(py_l, self.sym_t_l).reshape(
            (-1, 1)).dimshuffle((0, 'x', 'x', 1))
        lb_l = lower_bound(log_pa_l, log_qa_x_l, log_pz_l, log_qz_axy_l,
                           log_py_l, log_px_zy_l)
        lb_l = lb_l.mean(axis=(1, 2))  # Mean over the sampling dimensions
        log_qy_ax_l *= (
            self.sym_beta * (n / n_l)
        )  # Scale the supervised cross entropy with the beta constant
        grads_qy_l = T.grad(
            -log_qy_ax_l.mean(),
            params_qy)  # Calculate gradients on q(y|a,x) for labeled data
        grads_l = T.grad(
            lb_l.mean(),
            params)  # Calculate gradients on the remainder for labeled data
        lb_l += log_qy_ax_l.mean(axis=(
            1, 2
        ))  # Collect the lower bound term and mean over sampling dimensions

        # Lower bound for unlabeled data
        bs_u = self.sym_x_u.shape[0]

        # For the integrating out approach, we repeat the input matrix x, and construct a target (bs * n_y) x n_y
        # Example of input and target matrix for a 3 class problem and batch_size=2. 2D tensors of the form
        #               x_repeat                     t_repeat
        #  [[x[0,0], x[0,1], ..., x[0,n_x]]         [[1, 0, 0]
        #   [x[1,0], x[1,1], ..., x[1,n_x]]          [1, 0, 0]
        #   [x[0,0], x[0,1], ..., x[0,n_x]]          [0, 1, 0]
        #   [x[1,0], x[1,1], ..., x[1,n_x]]          [0, 1, 0]
        #   [x[0,0], x[0,1], ..., x[0,n_x]]          [0, 0, 1]
        #   [x[1,0], x[1,1], ..., x[1,n_x]]]         [0, 0, 1]]
        t_eye = T.eye(self.n_y, k=0)
        t_u = t_eye.reshape((self.n_y, 1, self.n_y)).repeat(bs_u,
                                                            axis=1).reshape(
                                                                (-1, self.n_y))
        x_u = self.sym_x_u.reshape(
            (1, bs_u, self.n_x)).repeat(self.n_y, axis=0).reshape(
                (-1, self.n_x))

        # Since the expectation of var a is outside the integration we calculate E_q(a|x) first
        a_x_u = get_output(self.l_qa,
                           self.sym_x_u,
                           batch_norm_update_averages=True,
                           batch_norm_use_averages=False)
        a_x_u_rep = a_x_u.reshape(
            (1, bs_u * self.sym_samples, self.n_a)).repeat(self.n_y,
                                                           axis=0).reshape(
                                                               (-1, self.n_a))
        out_layers = [l_log_pa, l_log_pz, l_log_qa, l_log_qz, l_log_px]
        inputs = {self.l_x_in: x_u, self.l_y_in: t_u, self.l_a_in: a_x_u_rep}
        out = get_output(out_layers,
                         inputs,
                         batch_norm_update_averages=False,
                         batch_norm_use_averages=False)
        log_pa_u, log_pz_u, log_qa_x_u, log_qz_axy_u, log_px_zy_u = out
        # Prior p(y) expecting that all classes are evenly distributed
        py_u = softmax(T.zeros((bs_u * self.n_y, self.n_y)))
        log_py_u = -categorical_crossentropy(py_u, t_u).reshape(
            (-1, 1)).dimshuffle((0, 'x', 'x', 1))
        lb_u = lower_bound(log_pa_u, log_qa_x_u, log_pz_u, log_qz_axy_u,
                           log_py_u, log_px_zy_u)
        lb_u = lb_u.reshape((self.n_y, self.sym_samples, 1,
                             bs_u)).transpose(3, 1, 2, 0).mean(axis=(1, 2))
        inputs = {
            self.l_x_in: self.sym_x_u,
            self.l_a_in: a_x_u.reshape((-1, self.n_a))
        }
        y_u = get_output(self.l_qy,
                         inputs,
                         batch_norm_update_averages=True,
                         batch_norm_use_averages=False).mean(axis=(1, 2))
        y_u += 1e-8  # Ensure that we get no NaNs when calculating the entropy
        y_u /= T.sum(y_u, axis=1, keepdims=True)
        grads_u = T.grad(
            (y_u * lb_u).sum(axis=1).mean(),
            params)  # Calculate gradients on the model for unlabeled data
        lb_u = (y_u * (lb_u - T.log(y_u))).sum(axis=1)
        grads_qy_u = T.grad(
            lb_u.mean(),
            params_qy)  # Calculate gradients on q(y|a,x) for unlabeled data

        if self.batchnorm:
            # TODO: implement the BN layer correctly.
            inputs = {
                self.l_x_in: self.sym_x_u,
                self.l_y_in: y_u,
                self.l_a_in: a_x_u
            }
            get_output(out_layers,
                       inputs,
                       weighting=None,
                       batch_norm_update_averages=True,
                       batch_norm_use_averages=False)

        # Regularizing with weight priors p(theta|N(0,1)), collecting and clipping gradients
        weight_priors = 0.0
        for p in params:
            if 'W' not in str(p):
                continue
            weight_priors += log_normal(p, 0, 1).sum()
        grads_priors = T.grad(weight_priors,
                              params,
                              disconnected_inputs='ignore')

        weight_priors_qy = 0.0
        for p in params_qy:
            if 'W' not in str(p):
                continue
            weight_priors_qy += log_normal(p, 0, 1).sum()
        grads_qy_priors = T.grad(weight_priors_qy,
                                 params_qy,
                                 disconnected_inputs='ignore')

        for i in range(len(params)):
            grads[i] = ((grads_l[i] + grads_u[i]) * n + grads_priors[i]) / -n
        for i in range(len(params_qy)):
            grads_qy[i] = (
                (grads_qy_l[i] + grads_qy_u[i]) * n + grads_qy_priors[i]) / -n

        grads_collect = grads + grads_qy
        params_collect = params + params_qy

        # Collect the lower bound and scale it with the weight priors.
        elbo = ((lb_l.mean() + lb_u.mean()) * n + weight_priors +
                weight_priors_qy) / -n
        lb_labeled = -lb_l.mean()
        lb_unlabeled = -lb_u.mean()
        out_px_zy = log_px_zy_u.mean() + log_px_zy_l.mean()
        out_a = (log_pa_l.mean() + log_pa_u.mean()) - (log_qa_x_l.mean() +
                                                       log_qa_x_u.mean())
        out_z = (log_pz_l.mean() + log_pz_u.mean()) - (log_qz_axy_l.mean() +
                                                       log_qz_axy_u.mean())

        # grads = T.grad(elbo, self.trainable_model_params)
        clip_grad, max_norm = 1, 5
        mgrads = total_norm_constraint(grads_collect, max_norm=max_norm)
        mgrads = [T.clip(g, -clip_grad, clip_grad) for g in mgrads]
        sym_beta1 = T.scalar('beta1')
        sym_beta2 = T.scalar('beta2')
        updates = adam(mgrads, params_collect, self.sym_lr, sym_beta1,
                       sym_beta2)

        # Training function
        indices = self._srng.choice(size=[self.sym_bs_l],
                                    a=sh_train_x_l.shape[0],
                                    replace=False)
        x_batch_l = sh_train_x_l[indices]
        t_batch_l = sh_train_t_l[indices]
        x_batch_u = self.sh_train_x[self.batch_slice]
        if self.x_dist == 'bernoulli':  # Sample bernoulli input.
            x_batch_u = self._srng.binomial(size=x_batch_u.shape,
                                            n=1,
                                            p=x_batch_u,
                                            dtype=theano.config.floatX)
            x_batch_l = self._srng.binomial(size=x_batch_l.shape,
                                            n=1,
                                            p=x_batch_l,
                                            dtype=theano.config.floatX)

        givens = {
            self.sym_x_l: x_batch_l,
            self.sym_x_u: x_batch_u,
            self.sym_t_l: t_batch_l
        }
        inputs = [
            self.sym_index, self.sym_batchsize, self.sym_bs_l, self.sym_beta,
            self.sym_lr, sym_beta1, sym_beta2, self.sym_samples
        ]
        outputs = [elbo, lb_labeled, lb_unlabeled, out_px_zy, out_a, out_z]
        f_train = theano.function(inputs=inputs,
                                  outputs=outputs,
                                  givens=givens,
                                  updates=updates)

        # Default training args. Note that these can be changed during or prior to training.
        self.train_args['inputs']['batchsize_unlabeled'] = 100
        self.train_args['inputs']['batchsize_labeled'] = 100
        self.train_args['inputs']['beta'] = 0.1
        self.train_args['inputs']['learningrate'] = 3e-4
        self.train_args['inputs']['beta1'] = 0.9
        self.train_args['inputs']['beta2'] = 0.999
        self.train_args['inputs']['samples'] = 1
        self.train_args['outputs']['lb'] = '%0.4f'
        self.train_args['outputs']['lb-labeled'] = '%0.4f'
        self.train_args['outputs']['lb-unlabeled'] = '%0.4f'
        self.train_args['outputs']['log(px)'] = '%0.4f'
        self.train_args['outputs']['KL(p(a)||q(a))'] = '%0.4f'
        self.train_args['outputs']['KL(p(z)||q(z))'] = '%0.4f'

        # Validation and test function
        y = get_output(self.l_qy, self.sym_x_l,
                       deterministic=True).mean(axis=(1, 2))
        class_err = (1. - categorical_accuracy(y, self.sym_t_l).mean()) * 100
        givens = {self.sym_x_l: self.sh_test_x, self.sym_t_l: self.sh_test_t}
        f_test = theano.function(inputs=[self.sym_samples],
                                 outputs=[class_err],
                                 givens=givens)

        # Test args.  Note that these can be changed during or prior to training.
        self.test_args['inputs']['samples'] = 1
        self.test_args['outputs']['test'] = '%0.2f%%'

        f_validate = None
        if validation_set is not None:
            givens = {
                self.sym_x_l: self.sh_valid_x,
                self.sym_t_l: self.sh_valid_t
            }
            f_validate = theano.function(inputs=[self.sym_samples],
                                         outputs=[class_err],
                                         givens=givens)
        # Default validation args. Note that these can be changed during or prior to training.
        self.validate_args['inputs']['samples'] = 1
        self.validate_args['outputs']['validation'] = '%0.2f%%'

        return f_train, f_test, f_validate, self.train_args, self.test_args, self.validate_args

    def get_output(self, x, samples=1):
        return self.f_qy(x, samples)

    def model_info(self):
        qa_shapes = self.get_model_shape(get_all_params(self.l_qa))
        qy_shapes = self.get_model_shape(get_all_params(
            self.l_qy))[len(qa_shapes) - 1:]
        qz_shapes = self.get_model_shape(get_all_params(
            self.l_qz))[len(qa_shapes) - 1:]
        px_shapes = self.get_model_shape(get_all_params(
            self.l_px))[(len(qz_shapes) - 1) + (len(qa_shapes) - 1):]
        pa_shapes = self.get_model_shape(get_all_params(
            self.l_pa))[(len(qz_shapes) - 1) + (len(qa_shapes) - 1):]
        s = ""
        s += 'batch norm: %s.\n' % (str(self.batchnorm))
        s += 'x distribution: %s.\n' % (str(self.x_dist))
        s += 'model q(a|x): %s.\n' % str(qa_shapes)[1:-1]
        s += 'model q(z|a,x,y): %s.\n' % str(qz_shapes)[1:-1]
        s += 'model q(y|a,x): %s.\n' % str(qy_shapes)[1:-1]
        s += 'model p(x|z,y): %s.\n' % str(px_shapes)[1:-1]
        s += 'model p(a|z,y): %s.' % str(pa_shapes)[1:-1]
        return s
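A minimal sketch of how the compiled functions returned by build_model might be driven is shown below. The `model` object, the dataset tuples, and the loop bounds are placeholder assumptions rather than part of the original example; the positional arguments of f_train follow the compiled `inputs` list and the defaults stored in train_args.

# Hypothetical driver; `model` is assumed to be an already constructed ADGM
# instance and the *_set variables are (x, t) tuples of float32 numpy arrays.
f_train, f_test, f_validate, train_args, test_args, validate_args = \
    model.build_model(train_set_unlabeled, train_set_labeled, test_set)

n_batches = train_set_unlabeled[0].shape[0] // 100  # unlabeled batch size of 100
for epoch in range(100):
    for i in range(n_batches):
        # Argument order mirrors the compiled `inputs` list: batch index,
        # batchsize_unlabeled, batchsize_labeled, beta, learning rate,
        # beta1, beta2, MC samples.
        elbo, lb_l, lb_u, log_px, kl_a, kl_z = f_train(
            i, 100, 100, 0.1, 3e-4, 0.9, 0.999, 1)
    test_err, = f_test(1)  # classification error in percent
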
class SDGMSSL(Model):
    """
    The :class:'SDGMSSL' class implements the skip deep generative model described in the
    Auxiliary Deep Generative Models article on arXiv.org.
    """

    def __init__(self, n_x, n_a, n_z, n_y, qa_hid, qz_hid, qy_hid, px_hid, pa_hid, nonlinearity=rectify,
                 px_nonlinearity=None, x_dist='bernoulli', batchnorm=False, seed=1234):
        """
        Initialize a skip deep generative model consisting of
        discriminative classifier q(y|a,x),
        generative model P p(a|z,y) and p(x|a,z,y),
        inference model Q q(a|x) and q(z|a,x,y).
        Weights are initialized using the Glorot and Bengio (2010) initialization scheme.
        :param n_x: Number of inputs.
        :param n_a: Number of auxiliary.
        :param n_z: Number of latent.
        :param n_y: Number of classes.
        :param qa_hid: List of number of deterministic hidden q(a|x).
        :param qz_hid: List of number of deterministic hidden q(z|a,x,y).
        :param qy_hid: List of number of deterministic hidden q(y|a,x).
        :param px_hid: List of number of deterministic hidden p(x|a,z,y).
        :param pa_hid: List of number of deterministic hidden p(a|z,y).
        :param nonlinearity: The transfer function used in the deterministic layers.
        :param px_nonlinearity: Output nonlinearity of p(x|a,z,y); used when x_dist is 'gaussian'.
        :param x_dist: The x distribution, 'bernoulli', 'multinomial', or 'gaussian'.
        :param batchnorm: Boolean value for batch normalization.
        :param seed: The random seed.
        """
        super(SDGMSSL, self).__init__(n_x, qz_hid + px_hid, n_a + n_z, nonlinearity)
        self.x_dist = x_dist
        self.n_y = n_y
        self.n_x = n_x
        self.n_a = n_a
        self.n_z = n_z
        self.batchnorm = batchnorm
        self._srng = RandomStreams(seed)

        # Decide Glorot initialization of weights.
        init_w = 1e-3
        hid_w = ""
        if nonlinearity == rectify or nonlinearity == softplus:
            hid_w = "relu"

        # Define symbolic variables for theano functions.
        self.sym_beta = T.scalar('beta')  # scaling constant beta
        self.sym_x_l = T.matrix('x')  # labeled inputs
        self.sym_t_l = T.matrix('t')  # labeled targets
        self.sym_x_u = T.matrix('x')  # unlabeled inputs
        self.sym_bs_l = T.iscalar('bs_l')  # number of labeled data
        self.sym_samples = T.iscalar('samples')  # MC samples
        self.sym_z = T.matrix('z')  # latent variable z
        self.sym_a = T.matrix('a')  # auxiliary variable a

        # Assist methods for collecting the layers
        def dense_layer(layer_in, n, dist_w=init.GlorotNormal, dist_b=init.Normal):
            dense = DenseLayer(layer_in, n, dist_w(hid_w), dist_b(init_w), None)
            if batchnorm:
                dense = BatchNormLayer(dense)
            return NonlinearityLayer(dense, self.transf)

        def stochastic_layer(layer_in, n, samples, nonlin=None):
            mu = DenseLayer(layer_in, n, init.Normal(init_w), init.Normal(init_w), nonlin)
            logvar = DenseLayer(layer_in, n, init.Normal(init_w), init.Normal(init_w), nonlin)
            return SampleLayer(mu, logvar, eq_samples=samples, iw_samples=1), mu, logvar

        # Input layers
        l_x_in = InputLayer((None, n_x))
        l_y_in = InputLayer((None, n_y))

        # Auxiliary q(a|x)
        l_qa_x = l_x_in
        for hid in qa_hid:
            l_qa_x = dense_layer(l_qa_x, hid)
        l_qa_x, l_qa_x_mu, l_qa_x_logvar = stochastic_layer(l_qa_x, n_a, self.sym_samples)

        # Classifier q(y|a,x)
        l_qa_to_qy = DenseLayer(l_qa_x, qy_hid[0], init.GlorotNormal(hid_w), init.Normal(init_w), None)
        l_qa_to_qy = ReshapeLayer(l_qa_to_qy, (-1, self.sym_samples, 1, qy_hid[0]))
        l_x_to_qy = DenseLayer(l_x_in, qy_hid[0], init.GlorotNormal(hid_w), init.Normal(init_w), None)
        self.l_x_to_qy = l_x_to_qy
        l_x_to_qy = DimshuffleLayer(l_x_to_qy, (0, 'x', 'x', 1))
        l_qy_xa = ReshapeLayer(ElemwiseSumLayer([l_qa_to_qy, l_x_to_qy]), (-1, qy_hid[0]))
        if batchnorm:
            l_qy_xa = BatchNormLayer(l_qy_xa)
        l_qy_xa = NonlinearityLayer(l_qy_xa, self.transf)
        if len(qy_hid) > 1:
            for hid in qy_hid[1:]:
                l_qy_xa = dense_layer(l_qy_xa, hid)
        l_qy_xa = DenseLayer(l_qy_xa, n_y, init.GlorotNormal(), init.Normal(init_w), softmax)

        # Recognition q(z|x,a,y)
        l_qa_to_qz = DenseLayer(l_qa_x, qz_hid[0], init.GlorotNormal(hid_w), init.Normal(init_w), None)
        l_qa_to_qz = ReshapeLayer(l_qa_to_qz, (-1, self.sym_samples, 1, qz_hid[0]))
        l_x_to_qz = DenseLayer(l_x_in, qz_hid[0], init.GlorotNormal(hid_w), init.Normal(init_w), None)
        l_x_to_qz = DimshuffleLayer(l_x_to_qz, (0, 'x', 'x', 1))
        l_y_to_qz = DenseLayer(l_y_in, qz_hid[0], init.GlorotNormal(hid_w), init.Normal(init_w), None)
        l_y_to_qz = DimshuffleLayer(l_y_to_qz, (0, 'x', 'x', 1))
        l_qz_axy = ReshapeLayer(ElemwiseSumLayer([l_qa_to_qz, l_x_to_qz, l_y_to_qz]), (-1, qz_hid[0]))
        if batchnorm:
            l_qz_axy = BatchNormLayer(l_qz_axy)
        l_qz_axy = NonlinearityLayer(l_qz_axy, self.transf)
        if len(qz_hid) > 1:
            for hid in qz_hid[1:]:
                l_qz_axy = dense_layer(l_qz_axy, hid)
        l_qz_axy, l_qz_axy_mu, l_qz_axy_logvar = stochastic_layer(l_qz_axy, n_z, 1)

        # Generative p(a|z,y)
        l_y_to_pa = DenseLayer(l_y_in, pa_hid[0], init.GlorotNormal(hid_w), init.Normal(init_w), None)
        l_y_to_pa = DimshuffleLayer(l_y_to_pa, (0, 'x', 'x', 1))
        l_qz_to_pa = DenseLayer(l_qz_axy, pa_hid[0], init.GlorotNormal(hid_w), init.Normal(init_w), None)
        l_qz_to_pa = ReshapeLayer(l_qz_to_pa, (-1, self.sym_samples, 1, pa_hid[0]))
        l_pa_zy = ReshapeLayer(ElemwiseSumLayer([l_qz_to_pa, l_y_to_pa]), [-1, pa_hid[0]])
        if batchnorm:
            l_pa_zy = BatchNormLayer(l_pa_zy)
        l_pa_zy = NonlinearityLayer(l_pa_zy, self.transf)
        if len(pa_hid) > 1:
            for hid in pa_hid[1:]:
                l_pa_zy = dense_layer(l_pa_zy, hid)
        l_pa_zy, l_pa_zy_mu, l_pa_zy_logvar = stochastic_layer(l_pa_zy, n_a, 1)

        # Generative p(x|a,z,y)
        l_qa_to_px = DenseLayer(l_qa_x, px_hid[0], init.GlorotNormal(hid_w), init.Normal(init_w), None)
        l_qa_to_px = ReshapeLayer(l_qa_to_px, (-1, self.sym_samples, 1, px_hid[0]))
        l_y_to_px = DenseLayer(l_y_in, px_hid[0], init.GlorotNormal(hid_w), init.Normal(init_w), None)
        l_y_to_px = DimshuffleLayer(l_y_to_px, (0, 'x', 'x', 1))
        l_qz_to_px = DenseLayer(l_qz_axy, px_hid[0], init.GlorotNormal(hid_w), init.Normal(init_w), None)
        l_qz_to_px = ReshapeLayer(l_qz_to_px, (-1, self.sym_samples, 1, px_hid[0]))
        l_px_azy = ReshapeLayer(ElemwiseSumLayer([l_qa_to_px, l_qz_to_px, l_y_to_px]), [-1, px_hid[0]])
        if batchnorm:
            l_px_azy = BatchNormLayer(l_px_azy)
        l_px_azy = NonlinearityLayer(l_px_azy, self.transf)
        if len(px_hid) > 1:
            for hid in px_hid[1:]:
                l_px_azy = dense_layer(l_px_azy, hid)

        if x_dist == 'bernoulli':
            l_px_azy = DenseLayer(l_px_azy, n_x, init.GlorotNormal(), init.Normal(init_w), sigmoid)
        elif x_dist == 'multinomial':
            l_px_azy = DenseLayer(l_px_azy, n_x, init.GlorotNormal(), init.Normal(init_w), softmax)
        elif x_dist == 'gaussian':
            l_px_azy, l_px_zy_mu, l_px_zy_logvar = stochastic_layer(l_px_azy, n_x, 1, px_nonlinearity)

        # Reshape all the model layers to have the same size
        self.l_x_in = l_x_in
        self.l_y_in = l_y_in
        self.l_a_in = l_qa_x

        self.l_qa = ReshapeLayer(l_qa_x, (-1, self.sym_samples, 1, n_a))
        self.l_qa_mu = DimshuffleLayer(l_qa_x_mu, (0, 'x', 'x', 1))
        self.l_qa_logvar = DimshuffleLayer(l_qa_x_logvar, (0, 'x', 'x', 1))

        self.l_qz = ReshapeLayer(l_qz_axy, (-1, self.sym_samples, 1, n_z))
        self.l_qz_mu = ReshapeLayer(l_qz_axy_mu, (-1, self.sym_samples, 1, n_z))
        self.l_qz_logvar = ReshapeLayer(l_qz_axy_logvar, (-1, self.sym_samples, 1, n_z))

        self.l_qy = ReshapeLayer(l_qy_xa, (-1, self.sym_samples, 1, n_y))

        self.l_pa = ReshapeLayer(l_pa_zy, (-1, self.sym_samples, 1, n_a))
        self.l_pa_mu = ReshapeLayer(l_pa_zy_mu, (-1, self.sym_samples, 1, n_a))
        self.l_pa_logvar = ReshapeLayer(l_pa_zy_logvar, (-1, self.sym_samples, 1, n_a))

        self.l_px = ReshapeLayer(l_px_azy, (-1, self.sym_samples, 1, n_x))
        self.l_px_mu = ReshapeLayer(l_px_zy_mu, (-1, self.sym_samples, 1, n_x)) if x_dist == "gaussian" else None
        self.l_px_logvar = ReshapeLayer(l_px_zy_logvar,
                                        (-1, self.sym_samples, 1, n_x)) if x_dist == "gaussian" else None

        # Predefined functions
        inputs = [self.sym_x_l, self.sym_samples]
        outputs = get_output(self.l_qy, self.sym_x_l, deterministic=True).mean(axis=(1, 2))
        self.f_qy = theano.function(inputs, outputs)

        inputs = [self.sym_x_l, self.sym_samples]
        outputs = get_output(self.l_qa, self.sym_x_l, deterministic=True).mean(axis=(1, 2))
        self.f_qa = theano.function(inputs, outputs)

        inputs = {l_qz_axy: self.sym_z, l_y_in: self.sym_t_l}
        outputs = get_output(self.l_pa, inputs, deterministic=True)
        self.f_pa = theano.function([self.sym_z, self.sym_t_l, self.sym_samples], outputs)

        inputs = {l_qa_x: self.sym_a, l_qz_axy: self.sym_z, l_y_in: self.sym_t_l}
        outputs = get_output(self.l_px, inputs, deterministic=True)
        self.f_px = theano.function([self.sym_a, self.sym_z, self.sym_t_l, self.sym_samples], outputs)

        # Define model parameters
        self.model_params = get_all_params([self.l_qy, self.l_pa, self.l_px])
        self.trainable_model_params = get_all_params([self.l_qy, self.l_pa, self.l_px], trainable=True)

    def build_model(self, train_set_unlabeled, train_set_labeled, test_set, validation_set=None):
        """
        Build the auxiliary deep generative model from the initialized hyperparameters.
        Define the lower bound term and compile it into a training function.
        :param train_set_unlabeled: Unlabeled train set containing variables x, t.
        :param train_set_labeled: Labeled train set containing variables x, t.
        :param test_set: Test set containing variables x, t.
        :param validation_set: Validation set containing variables x, t.
        :return: train, test, validation function and dicts of arguments.
        """
        super(SDGMSSL, self).build_model(train_set_unlabeled, test_set, validation_set)

        sh_train_x_l = theano.shared(np.asarray(train_set_labeled[0], dtype=theano.config.floatX), borrow=True)
        sh_train_t_l = theano.shared(np.asarray(train_set_labeled[1], dtype=theano.config.floatX), borrow=True)
        n = self.sh_train_x.shape[0].astype(theano.config.floatX)  # no. of data points
        n_l = sh_train_x_l.shape[0].astype(theano.config.floatX)  # no. of labeled data points

        # Define the layers for the density estimation used in the lower bound.
        l_log_qa = GaussianLogDensityLayer(self.l_qa, self.l_qa_mu, self.l_qa_logvar)
        l_log_qz = GaussianLogDensityLayer(self.l_qz, self.l_qz_mu, self.l_qz_logvar)
        l_log_qy = MultinomialLogDensityLayer(self.l_qy, self.l_y_in, eps=1e-8)

        l_log_pz = StandardNormalLogDensityLayer(self.l_qz)
        l_log_pa = GaussianLogDensityLayer(self.l_qa, self.l_pa_mu, self.l_pa_logvar)
        if self.x_dist == 'bernoulli':
            l_log_px = BernoulliLogDensityLayer(self.l_px, self.l_x_in)
        elif self.x_dist == 'multinomial':
            l_log_px = MultinomialLogDensityLayer(self.l_px, self.l_x_in)
        elif self.x_dist == 'gaussian':
            l_log_px = GaussianLogDensityLayer(self.l_x_in, self.l_px_mu, self.l_px_logvar)

        def lower_bound(log_pa, log_qa, log_pz, log_qz, log_py, log_px):
            lb = log_px + log_py + log_pz + log_pa - log_qa - log_qz
            return lb

        # Lower bound for labeled data
        out_layers = [l_log_pa, l_log_pz, l_log_qa, l_log_qz, l_log_px, l_log_qy]
        inputs = {self.l_x_in: self.sym_x_l, self.l_y_in: self.sym_t_l}
        out = get_output(out_layers, inputs, batch_norm_update_averages=False, batch_norm_use_averages=False)
        log_pa_l, log_pz_l, log_qa_x_l, log_qz_axy_l, log_px_zy_l, log_qy_ax_l = out
        # Prior p(y) expecting that all classes are evenly distributed
        py_l = softmax(T.zeros((self.sym_x_l.shape[0], self.n_y)))
        log_py_l = -categorical_crossentropy(py_l, self.sym_t_l).reshape((-1, 1)).dimshuffle((0, 'x', 'x', 1))
        lb_l = lower_bound(log_pa_l, log_qa_x_l, log_pz_l, log_qz_axy_l, log_py_l, log_px_zy_l)
        lb_l = lb_l.mean(axis=(1, 2))  # Mean over the sampling dimensions
        log_qy_ax_l *= (self.sym_beta * (n / n_l))  # Scale the supervised cross entropy with the beta constant
        lb_l -= log_qy_ax_l.mean(axis=(1, 2))  # Collect the lower bound term and mean over sampling dimensions

        # Lower bound for unlabeled data
        bs_u = self.sym_x_u.shape[0]

        # For the integrating out approach, we repeat the input matrix x, and construct a target (bs * n_y) x n_y
        # Example of input and target matrix for a 3 class problem and batch_size=2. 2D tensors of the form
        #               x_repeat                     t_repeat
        #  [[x[0,0], x[0,1], ..., x[0,n_x]]         [[1, 0, 0]
        #   [x[1,0], x[1,1], ..., x[1,n_x]]          [1, 0, 0]
        #   [x[0,0], x[0,1], ..., x[0,n_x]]          [0, 1, 0]
        #   [x[1,0], x[1,1], ..., x[1,n_x]]          [0, 1, 0]
        #   [x[0,0], x[0,1], ..., x[0,n_x]]          [0, 0, 1]
        #   [x[1,0], x[1,1], ..., x[1,n_x]]]         [0, 0, 1]]
        t_eye = T.eye(self.n_y, k=0)
        t_u = t_eye.reshape((self.n_y, 1, self.n_y)).repeat(bs_u, axis=1).reshape((-1, self.n_y))
        x_u = self.sym_x_u.reshape((1, bs_u, self.n_x)).repeat(self.n_y, axis=0).reshape((-1, self.n_x))

        # Since the expectation of var a is outside the integration we calculate E_q(a|x) first
        a_x_u = get_output(self.l_qa, self.sym_x_u, batch_norm_update_averages=True, batch_norm_use_averages=False)
        a_x_u_rep = a_x_u.reshape((1, bs_u * self.sym_samples, self.n_a)).repeat(self.n_y, axis=0).reshape(
            (-1, self.n_a))
        out_layers = [l_log_pa, l_log_pz, l_log_qa, l_log_qz, l_log_px]
        inputs = {self.l_x_in: x_u, self.l_y_in: t_u, self.l_a_in: a_x_u_rep}
        out = get_output(out_layers, inputs, batch_norm_update_averages=False, batch_norm_use_averages=False)
        log_pa_u, log_pz_u, log_qa_x_u, log_qz_axy_u, log_px_zy_u = out
        # Prior p(y) expecting that all classes are evenly distributed
        py_u = softmax(T.zeros((bs_u * self.n_y, self.n_y)))
        log_py_u = -categorical_crossentropy(py_u, t_u).reshape((-1, 1)).dimshuffle((0, 'x', 'x', 1))
        lb_u = lower_bound(log_pa_u, log_qa_x_u, log_pz_u, log_qz_axy_u, log_py_u, log_px_zy_u)
        lb_u = lb_u.reshape((self.n_y, 1, 1, bs_u)).transpose(3, 1, 2, 0).mean(axis=(1, 2))
        inputs = {self.l_x_in: self.sym_x_u, self.l_a_in: a_x_u.reshape((-1, self.n_a))}
        y_u = get_output(self.l_qy, inputs, batch_norm_update_averages=True, batch_norm_use_averages=False).mean(
            axis=(1, 2))
        y_u += 1e-8  # Ensure that we get no NaNs when calculating the entropy
        y_u /= T.sum(y_u, axis=1, keepdims=True)
        lb_u = (y_u * (lb_u - T.log(y_u))).sum(axis=1)

        if self.batchnorm:
            # TODO: implement the BN layer correctly.
            inputs = {self.l_x_in: self.sym_x_u, self.l_y_in: y_u, self.l_a_in: a_x_u}
            get_output(out_layers, inputs, weighting=None, batch_norm_update_averages=True,
                       batch_norm_use_averages=False)

        # Regularizing with weight priors p(theta|N(0,1)), collecting and clipping gradients
        weight_priors = 0.0
        for p in self.trainable_model_params:
            if 'W' not in str(p):
                continue
            weight_priors += log_normal(p, 0, 1).sum()

        # Collect the lower bound and scale it with the weight priors.
        elbo = ((lb_l.mean() + lb_u.mean()) * n + weight_priors) / -n
        lb_labeled = -lb_l.mean()
        lb_unlabeled = -lb_u.mean()

        grads_collect = T.grad(elbo, self.trainable_model_params)
        params_collect = self.trainable_model_params
        sym_beta1 = T.scalar('beta1')
        sym_beta2 = T.scalar('beta2')
        clip_grad, max_norm = 1, 5
        mgrads = total_norm_constraint(grads_collect, max_norm=max_norm)
        mgrads = [T.clip(g, -clip_grad, clip_grad) for g in mgrads]
        updates = adam(mgrads, params_collect, self.sym_lr, sym_beta1, sym_beta2)

        # Training function
        indices = self._srng.choice(size=[self.sym_bs_l], a=sh_train_x_l.shape[0], replace=False)
        x_batch_l = sh_train_x_l[indices]
        t_batch_l = sh_train_t_l[indices]
        x_batch_u = self.sh_train_x[self.batch_slice]
        if self.x_dist == 'bernoulli':  # Sample bernoulli input.
            x_batch_u = self._srng.binomial(size=x_batch_u.shape, n=1, p=x_batch_u, dtype=theano.config.floatX)
            x_batch_l = self._srng.binomial(size=x_batch_l.shape, n=1, p=x_batch_l, dtype=theano.config.floatX)

        givens = {self.sym_x_l: x_batch_l,
                  self.sym_x_u: x_batch_u,
                  self.sym_t_l: t_batch_l}
        inputs = [self.sym_index, self.sym_batchsize, self.sym_bs_l, self.sym_beta,
                  self.sym_lr, sym_beta1, sym_beta2, self.sym_samples]
        outputs = [elbo, lb_labeled, lb_unlabeled]
        f_train = theano.function(inputs=inputs, outputs=outputs, givens=givens, updates=updates)

        # Default training args. Note that these can be changed during or prior to training.
        self.train_args['inputs']['batchsize_unlabeled'] = 100
        self.train_args['inputs']['batchsize_labeled'] = 100
        self.train_args['inputs']['beta'] = 0.1
        self.train_args['inputs']['learningrate'] = 3e-4
        self.train_args['inputs']['beta1'] = 0.9
        self.train_args['inputs']['beta2'] = 0.999
        self.train_args['inputs']['samples'] = 1
        self.train_args['outputs']['lb'] = '%0.4f'
        self.train_args['outputs']['lb-labeled'] = '%0.4f'
        self.train_args['outputs']['lb-unlabeled'] = '%0.4f'

        # Validation and test function
        y = get_output(self.l_qy, self.sym_x_l, deterministic=True).mean(axis=(1, 2))
        class_err = (1. - categorical_accuracy(y, self.sym_t_l).mean()) * 100
        givens = {self.sym_x_l: self.sh_test_x,
                  self.sym_t_l: self.sh_test_t}
        f_test = theano.function(inputs=[self.sym_samples], outputs=[class_err], givens=givens)

        # Test args.  Note that these can be changed during or prior to training.
        self.test_args['inputs']['samples'] = 1
        self.test_args['outputs']['test'] = '%0.2f%%'

        f_validate = None
        if validation_set is not None:
            givens = {self.sym_x_l: self.sh_valid_x,
                      self.sym_t_l: self.sh_valid_t}
            f_validate = theano.function(inputs=[self.sym_samples], outputs=[class_err], givens=givens)
        # Default validation args. Note that these can be changed during or prior to training.
        self.validate_args['inputs']['samples'] = 1
        self.validate_args['outputs']['validation'] = '%0.2f%%'

        return f_train, f_test, f_validate, self.train_args, self.test_args, self.validate_args

    def get_output(self, x, samples=1):
        return self.f_qy(x, samples)

    def model_info(self):
        qa_shapes = self.get_model_shape(get_all_params(self.l_qa))
        qy_shapes = self.get_model_shape(get_all_params(self.l_qy))[len(qa_shapes) - 1:]
        qz_shapes = self.get_model_shape(get_all_params(self.l_qz))[len(qa_shapes) - 1:]
        px_shapes = self.get_model_shape(get_all_params(self.l_px))[(len(qz_shapes) - 1) + (len(qa_shapes) - 1):]
        pa_shapes = self.get_model_shape(get_all_params(self.l_pa))[(len(qz_shapes) - 1) + (len(qa_shapes) - 1):]
        s = ""
        s += 'batch norm: %s.\n' % (str(self.batchnorm))
        s += 'x distribution: %s.\n' % (str(self.x_dist))
        s += 'model q(a|x): %s.\n' % str(qa_shapes)[1:-1]
        s += 'model q(z|a,x,y): %s.\n' % str(qz_shapes)[1:-1]
        s += 'model q(y|a,x): %s.\n' % str(qy_shapes)[1:-1]
        s += 'model p(x|a,z,y): %s.\n' % str(px_shapes)[1:-1]
        s += 'model p(a|z,y): %s.' % str(pa_shapes)[1:-1]
        return s
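For completeness, a minimal usage sketch of the SDGMSSL class defined above. The input dimension, hidden layer sizes, and the random batch are illustrative assumptions only (roughly an MNIST-sized problem), not values taken from the original example, and theano.config.floatX is assumed to be 'float32'.

import numpy as np

# Hypothetical hyperparameters for an MNIST-sized semi-supervised problem.
model = SDGMSSL(n_x=784, n_a=100, n_z=100, n_y=10,
                qa_hid=[500, 500], qz_hid=[500, 500], qy_hid=[500, 500],
                px_hid=[500, 500], pa_hid=[500, 500],
                x_dist='bernoulli', batchnorm=False, seed=1234)

x = np.random.rand(32, 784).astype('float32')  # placeholder input batch
y_prob = model.get_output(x, samples=10)  # q(y|a,x) class probabilities, shape (32, 10)
a_mean = model.f_qa(x, 10)                # mean of the q(a|x) samples, shape (32, 100)
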
class GibbsRegressor(ISymbolicPredictor):
    def __init__(self,
                 n_dim_in,
                 n_dim_out,
                 sample_y=False,
                 n_alpha=1,
                 possible_ws=[0, 1],
                 alpha_update_policy='sequential',
                 seed=None):
        self._w = theano.shared(np.zeros((n_dim_in, n_dim_out),
                                         dtype=theano.config.floatX),
                                name='w')
        self._rng = RandomStreams(seed)
        if n_alpha == 'all':
            n_alpha = n_dim_in
        self._n_alpha = n_alpha
        self._alpha = theano.shared(np.arange(n_alpha))  # indices of the weights currently being resampled
        self._sample_y = sample_y
        self._possible_ws = theano.shared(np.array(possible_ws),
                                          name='possible_ws')
        assert alpha_update_policy in ('sequential', 'random')
        self._alpha_update_policy = alpha_update_policy

    def _get_alpha_update(self):
        new_alpha = (self._alpha+self._n_alpha) % self._w.shape[0] \
            if self._alpha_update_policy == 'sequential' else \
            self._rng.choice(a=self._w.shape[0], size = (self._n_alpha, ), replace = False).reshape([-1])  # Reshape is for some reason necessary when n_alpha=1
        return self._alpha, new_alpha

    @staticmethod
    def compute_p_wa(w, x, y, alpha, possible_ws=np.array([0, 1])):
        """
        Compute the probability of the weights at the indices in alpha taking on each of the values in possible_ws.
        """
        assert x.tag.test_value.ndim == y.tag.test_value.ndim == 2
        assert x.tag.test_value.shape[0] == y.tag.test_value.shape[0]
        assert w.get_value().shape[1] == y.tag.test_value.shape[1]
        v_current = x.dot(w)  # (n_samples, n_dim_out)
        v_0 = v_current[None, :, :] - w[alpha, None, :] * x.T[
            alpha, :, None]  # (n_alpha, n_samples, n_dim_out)
        possible_vs = v_0[:, :, :, None] + possible_ws[
            None, None, None, :] * x.T[
                alpha, :, None,
                None]  # (n_alpha, n_samples, n_dim_out, n_possible_ws)
        all_zs = tt.nnet.sigmoid(
            possible_vs)  # (n_alpha, n_samples, n_dim_out, n_possible_ws)
        log_likelihoods = tt.sum(tt.log(
            bernoulli(y[None, :, :, None], all_zs[:, :, :, :])),
                                 axis=1)  # (n_alpha, n_dim_out, n_possible_ws)
        # Theano's softmax subtracts the row max internally, so no extra shift is
        # needed for stability; it only accepts 2D inputs, hence the reshaping below.
        return tt.nnet.softmax(log_likelihoods.reshape([alpha.shape[0]*w.shape[1], possible_ws.shape[0]]))\
            .reshape([alpha.shape[0], w.shape[1], possible_ws.shape[0]])  # (n_alpha, n_dim_out, n_possible_ws)

    @symbolic_updater
    def train(self, x, y):
        p_wa = self.compute_p_wa(
            self._w, x, y, self._alpha,
            self._possible_ws)  # (n_alpha, n_dim_out, n_possible_ws)
        w_sample = sample_categorical(self._rng,
                                      p_wa,
                                      values=self._possible_ws)
        w_new = tt.set_subtensor(self._w[self._alpha],
                                 w_sample)  # (n_dim_in, n_dim_out)
        return [(self._w, w_new), self._get_alpha_update()]

    @symbolic_stateless
    def predict(self, x):
        p_y = tt.nnet.sigmoid(x.dot(self._w))
        return self._rng.binomial(p=p_y) if self._sample_y else p_y
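The alpha update above uses RandomStreams.choice with replace=False to pick which weight indices to resample at each Gibbs step. Below is a stripped-down, self-contained sketch of that pattern; the sizes and seed are arbitrary placeholders.

import numpy as np
import theano
from theano.tensor.shared_randomstreams import RandomStreams

n_dim_in, n_alpha = 20, 3  # placeholder sizes
rng = RandomStreams(seed=0)
alpha = theano.shared(np.arange(n_alpha, dtype='int64'), name='alpha')

# Draw n_alpha distinct indices in [0, n_dim_in) and store them in `alpha`,
# mirroring the 'random' alpha_update_policy of GibbsRegressor above.
new_alpha = rng.choice(a=n_dim_in, size=(n_alpha, ), replace=False).reshape([-1])
resample_alpha = theano.function([], new_alpha, updates=[(alpha, new_alpha)])

print(resample_alpha())  # e.g. [ 7 12  1]; three distinct indices per call
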
Exemple #23
0
class ConvSDGMSSL(Model):
    """
    The :class:'ConvSDGMSSL' class implements a convolutional variant of the model described in the
    Auxiliary Deep Generative Models article on arXiv.org.
    """
    def __init__(self,
                 input_size,
                 n_a,
                 n_z,
                 n_y,
                 qa_hid,
                 qz_hid,
                 qy_hid,
                 px_hid,
                 pa_hid,
                 nonlinearity=rectify,
                 n_mi_features=0,
                 dropout_prob=0.0,
                 px_nonlinearity=None,
                 x_dist='bernoulli',
                 batchnorm=False,
                 seed=1234,
                 conv_output_size=512):

        super(ConvSDGMSSL, self).__init__(input_size**2, qz_hid + px_hid,
                                          n_a + n_z, nonlinearity)
        self.x_dist = x_dist
        self.n_y = n_y
        self.input_size = input_size
        self.n_mi_features = n_mi_features
        self.n_a = n_a
        self.n_z = n_z
        self.batchnorm = batchnorm
        self._srng = RandomStreams(seed)

        # Decide Glorot initialization of weights.
        init_w = 1e-3
        hid_w = ""
        if nonlinearity == rectify or nonlinearity == softplus:
            hid_w = "relu"

        # Define symbolic variables for theano functions.
        self.sym_beta = T.scalar('beta')  # scaling constant beta
        self.sym_x_l = T.matrix('x')  # labeled inputs
        self.sym_t_l = T.matrix('t')  # labeled targets
        self.sym_x_u = T.matrix('x')  # unlabeled inputs
        self.sym_bs_l = T.iscalar('bs_l')  # number of labeled data
        self.sym_samples = T.iscalar('samples')  # MC samples
        self.sym_z = T.matrix('z')  # latent variable z
        self.sym_a = T.matrix('a')  # auxiliary variable a

        # Assist methods for collecting the layers
        def dense_layer(layer_in,
                        n,
                        dist_w=init.GlorotNormal,
                        dist_b=init.Normal):
            dense = DenseLayer(layer_in, n, dist_w(hid_w), dist_b(init_w),
                               None)
            if batchnorm:
                dense = BatchNormLayer(dense)
            if dropout_prob != 0.0:
                dense = DropoutLayer(dense, dropout_prob)
            return NonlinearityLayer(dense, self.transf)

        def stochastic_layer(layer_in, n, samples, nonlin=None):
            mu = DenseLayer(layer_in, n, init.Normal(init_w),
                            init.Normal(init_w), nonlin)
            logvar = DenseLayer(layer_in, n, init.Normal(init_w),
                                init.Normal(init_w), nonlin)
            return SampleLayer(mu, logvar, eq_samples=samples,
                               iw_samples=1), mu, logvar

        # Return the number of elements in a layer's output, ignoring the
        # first (batch size) axis.
        def num_elems(tensor):
            try:
                n_elems = 1
                for val in tensor.output_shape[1:]:
                    n_elems *= val
                return n_elems
            except Exception:
                # Output shape not fully known at build time; fall back to a sentinel.
                return -2

        #
        # Functions that define the convolutional and deconvolutional sections of the
        # networks (they are opposites of one another)
        #

        def conv_net(input_layer):
            if self.n_mi_features != 0:
                conv_input = SliceLayer(
                    input_layer,
                    indices=slice(0,
                                  input_layer.shape[1] - self.n_mi_features))
                mi_input = SliceLayer(
                    input_layer,
                    indices=slice(input_layer.shape[1] - self.n_mi_features,
                                  None))
            else:
                conv_input = input_layer
                mi_input = None

            conv_input = ReshapeLayer(
                conv_input, (-1, 1, self.input_size, self.input_size))

            conv_layer_output_shapes = []
            output = Conv2DLayer(conv_input, 64, 5, stride=2, pad='same')
            conv_layer_output_shapes.append(output.output_shape[2])
            output = Conv2DLayer(output, 128, 5, stride=2, pad='same')
            conv_layer_output_shapes.append(output.output_shape[2])
            output = ReshapeLayer(output, (-1, num_elems(output)))
            if mi_input is not None:
                output = ConcatLayer([output, mi_input], axis=1)
            output = BatchNormLayer(DenseLayer(output, conv_output_size))
            return output, conv_layer_output_shapes

        def deconv_net(input_layer, conv_layer_output_shapes):
            output = BatchNormLayer(
                DenseLayer(input_layer, 128 * 7 * 7 + self.n_mi_features))
            if self.n_mi_features != 0:
                deconv_input = SliceLayer(output,
                                          indices=slice(0, 128 * 7 * 7))
                mi_features = SliceLayer(output,
                                         indices=slice(
                                             128 * 7 * 7,
                                             128 * 7 * 7 + self.n_mi_features))

            else:
                deconv_input = output
                mi_features = None

            output = ReshapeLayer(deconv_input, (-1, 128, 7, 7))
            output = TransposedConv2DLayer(
                output,
                64,
                5,
                stride=2,
                crop='same',
                output_size=conv_layer_output_shapes[0])
            output = TransposedConv2DLayer(output,
                                           1,
                                           5,
                                           stride=2,
                                           crop='same',
                                           output_size=self.input_size,
                                           nonlinearity=sigmoid)
            output = ReshapeLayer(output, (-1, self.input_size**2))

            if mi_features is not None:
                output = ConcatLayer([output, mi_features], axis=1)

            return output

        # Input layers
        l_x_in = InputLayer((None, self.input_size**2 + self.n_mi_features))
        l_y_in = InputLayer((None, n_y))

        # Reshape x to a square 2D array so that we can keep using the previous
        # implementation of the integration over y (see build_model).

        ############################################################################
        #                                Auxiliary q(a|x)                          #
        ############################################################################

        # Two convolutional layers. Can add batch norm or change nonlinearity to lrelu
        l_qa_x, conv_layer_output_shapes = conv_net(l_x_in)
        # Add mutual information features

        if len(qa_hid) > 1:
            for hid in qa_hid[1:]:
                l_qa_x = dense_layer(l_qa_x, hid)
        l_qa_x, l_qa_x_mu, l_qa_x_logvar = stochastic_layer(
            l_qa_x, n_a, self.sym_samples)

        ############################################################################

        ############################################################################
        #                                Classifier q(y|a,x)                       #
        ############################################################################

        # Dense layers for input a
        l_qa_to_qy = dense_layer(l_qa_x, conv_output_size)
        l_qa_to_qy = ReshapeLayer(l_qa_to_qy,
                                  (-1, self.sym_samples, 1, conv_output_size))

        # Convolutional layers for input x
        l_x_to_qy, _ = conv_net(l_x_in)
        l_x_to_qy = DimshuffleLayer(l_x_to_qy, (0, 'x', 'x', 1))

        # Combine layers from x and a
        l_qy_xa = ReshapeLayer(ElemwiseSumLayer([l_qa_to_qy, l_x_to_qy]),
                               (-1, conv_output_size))
        if batchnorm:
            l_qy_xa = BatchNormLayer(l_qy_xa)

        if len(qy_hid) > 1:
            for hid in qy_hid[1:]:
                l_qy_xa = dense_layer(l_qy_xa, hid)

        l_qy_xa = DenseLayer(l_qy_xa, n_y, init.GlorotNormal(),
                             init.Normal(init_w), softmax)

        #############################################################################

        ############################################################################
        #                            Recognition q(z|x,a,y)                        #
        ############################################################################

        # Dense layers for a
        l_qa_to_qz = DenseLayer(l_qa_x, conv_output_size,
                                init.GlorotNormal(hid_w), init.Normal(init_w),
                                None)
        l_qa_to_qz = ReshapeLayer(l_qa_to_qz,
                                  (-1, self.sym_samples, 1, conv_output_size))

        # Convolutional layers for x
        l_x_to_qz, _ = conv_net(l_x_in)
        l_x_to_qz = DimshuffleLayer(l_x_to_qz, (0, 'x', 'x', 1))

        # Dense layers for y
        l_y_to_qz = DenseLayer(l_y_in, conv_output_size,
                               init.GlorotNormal(hid_w), init.Normal(init_w),
                               None)
        l_y_to_qz = DimshuffleLayer(l_y_to_qz, (0, 'x', 'x', 1))

        # Combine layers from a, x, and y
        l_qz_axy = ReshapeLayer(
            ElemwiseSumLayer([l_qa_to_qz, l_x_to_qz, l_y_to_qz]),
            (-1, conv_output_size))
        if batchnorm:
            l_qz_axy = BatchNormLayer(l_qz_axy)
        if len(qz_hid) > 1:
            for hid in qz_hid[1:]:
                l_qz_axy = dense_layer(l_qz_axy, hid)

        l_qz_axy, l_qz_axy_mu, l_qz_axy_logvar = stochastic_layer(
            l_qz_axy, n_z, 1)

        ############################################################################

        ############################################################################
        #                            Generative p(a|z,y)                           #
        ############################################################################

        l_y_to_pa = DenseLayer(l_y_in, pa_hid[0], init.GlorotNormal(hid_w),
                               init.Normal(init_w), None)
        l_y_to_pa = DimshuffleLayer(l_y_to_pa, (0, 'x', 'x', 1))
        l_qz_to_pa = DenseLayer(l_qz_axy, pa_hid[0], init.GlorotNormal(hid_w),
                                init.Normal(init_w), None)
        l_qz_to_pa = ReshapeLayer(l_qz_to_pa,
                                  (-1, self.sym_samples, 1, pa_hid[0]))
        l_pa_zy = ReshapeLayer(ElemwiseSumLayer([l_qz_to_pa, l_y_to_pa]),
                               [-1, pa_hid[0]])
        if batchnorm:
            l_pa_zy = BatchNormLayer(l_pa_zy)
        l_pa_zy = NonlinearityLayer(l_pa_zy, self.transf)
        if len(pa_hid) > 1:
            for hid in pa_hid[1:]:
                l_pa_zy = dense_layer(l_pa_zy, hid)
        l_pa_zy, l_pa_zy_mu, l_pa_zy_logvar = stochastic_layer(l_pa_zy, n_a, 1)

        ############################################################################

        ############################################################################
        #                            Generative p(x|a,z,y)                         #
        ############################################################################

        # Pass a,y,z through dense layers
        l_qa_to_px = DenseLayer(l_qa_x, conv_output_size,
                                init.GlorotNormal(hid_w), init.Normal(init_w),
                                None)
        l_qa_to_px = ReshapeLayer(l_qa_to_px,
                                  (-1, self.sym_samples, 1, conv_output_size))
        l_y_to_px = DenseLayer(l_y_in, conv_output_size,
                               init.GlorotNormal(hid_w), init.Normal(init_w),
                               None)
        l_y_to_px = DimshuffleLayer(l_y_to_px, (0, 'x', 'x', 1))
        l_qz_to_px = DenseLayer(l_qz_axy, conv_output_size,
                                init.GlorotNormal(hid_w), init.Normal(init_w),
                                None)
        l_qz_to_px = ReshapeLayer(l_qz_to_px,
                                  (-1, self.sym_samples, 1, conv_output_size))

        # Combine the results
        l_px_azy = ReshapeLayer(
            ElemwiseSumLayer([l_qa_to_px, l_qz_to_px, l_y_to_px]),
            [-1, conv_output_size])
        #if batchnorm:
        #    l_px_azy = BatchNormLayer(l_px_azy)
        l_px_azy = NonlinearityLayer(l_px_azy, self.transf)

        # Generate x using transposed convolutional layers
        l_px_azy = deconv_net(l_px_azy, conv_layer_output_shapes)
        l_px_azy = ReshapeLayer(l_px_azy,
                                (-1, self.input_size**2 + self.n_mi_features))

        if x_dist == 'bernoulli':
            l_px_azy = DenseLayer(l_px_azy,
                                  self.input_size**2 + self.n_mi_features,
                                  init.GlorotNormal(), init.Normal(init_w),
                                  sigmoid)
        elif x_dist == 'multinomial':
            l_px_azy = DenseLayer(l_px_azy,
                                  self.input_size**2 + self.n_mi_features,
                                  init.GlorotNormal(), init.Normal(init_w),
                                  softmax)
        elif x_dist == 'gaussian':
            l_px_azy, l_px_zy_mu, l_px_zy_logvar = stochastic_layer(
                l_px_azy, self.input_size**2 + self.n_mi_features, 1,
                px_nonlinearity)

        ############################################################################

        # Reshape all the model layers to have the same size
        self.l_x_in = l_x_in
        self.l_y_in = l_y_in
        self.l_a_in = l_qa_x

        # Output of the auxiliary network q(a|x)
        self.l_qa = ReshapeLayer(l_qa_x, (-1, self.sym_samples, 1, n_a))
        self.l_qa_mu = DimshuffleLayer(l_qa_x_mu, (0, 'x', 'x', 1))
        self.l_qa_logvar = DimshuffleLayer(l_qa_x_logvar, (0, 'x', 'x', 1))

        # Output of the recognition network q(z|x,a,y)
        self.l_qz = ReshapeLayer(l_qz_axy, (-1, self.sym_samples, 1, n_z))
        self.l_qz_mu = ReshapeLayer(l_qz_axy_mu,
                                    (-1, self.sym_samples, 1, n_z))
        self.l_qz_logvar = ReshapeLayer(l_qz_axy_logvar,
                                        (-1, self.sym_samples, 1, n_z))

        # Output of the classifier network q(y|a,x)
        self.l_qy = ReshapeLayer(l_qy_xa, (-1, self.sym_samples, 1, n_y))

        # Output of the generative network p(a|z,y)
        self.l_pa = ReshapeLayer(l_pa_zy, (-1, self.sym_samples, 1, n_a))
        self.l_pa_mu = ReshapeLayer(l_pa_zy_mu, (-1, self.sym_samples, 1, n_a))
        self.l_pa_logvar = ReshapeLayer(l_pa_zy_logvar,
                                        (-1, self.sym_samples, 1, n_a))

        # Output of the generative network p(x|a,z,y)
        self.l_px = ReshapeLayer(
            l_px_azy,
            (-1, self.sym_samples, 1, self.input_size**2 + self.n_mi_features))
        self.l_px_mu = ReshapeLayer(
            l_px_zy_mu, (-1, self.sym_samples, 1, self.input_size**2 +
                         self.n_mi_features)) if x_dist == "gaussian" else None
        self.l_px_logvar = ReshapeLayer(
            l_px_zy_logvar,
            (-1, self.sym_samples, 1, self.input_size**2 +
             self.n_mi_features)) if x_dist == "gaussian" else None

        # Predefined functions

        # Classifier
        inputs = [self.sym_x_l, self.sym_samples]
        outputs = get_output(self.l_qy, self.sym_x_l,
                             deterministic=True).mean(axis=(1, 2))
        self.f_qy = theano.function(inputs, outputs)

        # Auxiliary
        inputs = [self.sym_x_l, self.sym_samples]
        outputs = get_output(self.l_qa, self.sym_x_l,
                             deterministic=True).mean(axis=(1, 2))
        self.f_qa = theano.function(inputs, outputs)

        # Generative network p(a|z,y)
        inputs = {l_qz_axy: self.sym_z, l_y_in: self.sym_t_l}
        outputs = get_output(self.l_pa, inputs, deterministic=True)
        self.f_pa = theano.function(
            [self.sym_z, self.sym_t_l, self.sym_samples], outputs)

        inputs = {
            l_qa_x: self.sym_a,
            l_qz_axy: self.sym_z,
            l_y_in: self.sym_t_l
        }
        outputs = get_output(self.l_px, inputs, deterministic=True)
        self.f_px = theano.function(
            [self.sym_a, self.sym_z, self.sym_t_l, self.sym_samples], outputs)

        # Define model parameters
        self.model_params = get_all_params([self.l_qy, self.l_pa, self.l_px])
        self.trainable_model_params = get_all_params(
            [self.l_qy, self.l_pa, self.l_px], trainable=True)

    def build_model(self,
                    train_set_unlabeled,
                    train_set_labeled,
                    test_set,
                    validation_set=None):
        """
        Build the auxiliary deep generative model from the initialized hyperparameters.
        Define the lower bound term and compile it into a training function.
        :param train_set_unlabeled: Unlabeled train set containing variables x, t.
        :param train_set_labeled: Labeled train set containing variables x, t.
        :param test_set: Test set containing variables x, t.
        :param validation_set: Validation set containing variables x, t.
        :return: train, test, validation function and dicts of arguments.
        """
        super(ConvSDGMSSL, self).build_model(train_set_unlabeled, test_set,
                                             validation_set)

        sh_train_x_l = theano.shared(np.asarray(train_set_labeled[0],
                                                dtype=theano.config.floatX),
                                     borrow=True)
        sh_train_t_l = theano.shared(np.asarray(train_set_labeled[1],
                                                dtype=theano.config.floatX),
                                     borrow=True)
        n = self.sh_train_x.shape[0].astype(
            theano.config.floatX)  # no. of data points
        n_l = sh_train_x_l.shape[0].astype(
            theano.config.floatX)  # no. of labeled data points

        # Define the layers for the density estimation used in the lower bound.
        l_log_qa = GaussianLogDensityLayer(self.l_qa, self.l_qa_mu,
                                           self.l_qa_logvar)
        l_log_qz = GaussianLogDensityLayer(self.l_qz, self.l_qz_mu,
                                           self.l_qz_logvar)
        l_log_qy = MultinomialLogDensityLayer(self.l_qy, self.l_y_in, eps=1e-8)

        l_log_pz = StandardNormalLogDensityLayer(self.l_qz)
        l_log_pa = GaussianLogDensityLayer(self.l_qa, self.l_pa_mu,
                                           self.l_pa_logvar)
        if self.x_dist == 'bernoulli':
            l_log_px = BernoulliLogDensityLayer(self.l_px, self.l_x_in)
        elif self.x_dist == 'multinomial':
            l_log_px = MultinomialLogDensityLayer(self.l_px, self.l_x_in)
        elif self.x_dist == 'gaussian':
            l_log_px = GaussianLogDensityLayer(self.l_x_in, self.l_px_mu,
                                               self.l_px_logvar)

        def lower_bound(log_pa, log_qa, log_pz, log_qz, log_py, log_px):
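            # Per-sample variational lower bound:
            # log p(x|a,z,y) + log p(y) + log p(z) + log p(a|z,y) - log q(a|x) - log q(z|a,x,y)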
            lb = log_px + log_py + log_pz + log_pa - log_qa - log_qz
            return lb

        # Lower bound for labeled data
        out_layers = [
            l_log_pa, l_log_pz, l_log_qa, l_log_qz, l_log_px, l_log_qy
        ]
        inputs = {self.l_x_in: self.sym_x_l, self.l_y_in: self.sym_t_l}
        out = get_output(out_layers,
                         inputs,
                         batch_norm_update_averages=False,
                         batch_norm_use_averages=False)
        log_pa_l, log_pz_l, log_qa_x_l, log_qz_axy_l, log_px_zy_l, log_qy_ax_l = out
        # Prior p(y): assume all classes are evenly distributed (uniform)
        py_l = softmax(T.zeros((self.sym_x_l.shape[0], self.n_y)))
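        # With one-hot targets, -categorical_crossentropy(uniform, t) reduces to log(1 / n_y) per example.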
        log_py_l = -categorical_crossentropy(py_l, self.sym_t_l).reshape(
            (-1, 1)).dimshuffle((0, 'x', 'x', 1))
        lb_l = lower_bound(log_pa_l, log_qa_x_l, log_pz_l, log_qz_axy_l,
                           log_py_l, log_px_zy_l)
        lb_l = lb_l.mean(axis=(1, 2))  # Mean over the sampling dimensions
        # Scale the supervised cross entropy with the beta constant
        log_qy_ax_l *= self.sym_beta * (n / n_l)
        # Collect the lower bound term and mean over the sampling dimensions
        lb_l -= log_qy_ax_l.mean(axis=(1, 2))

        # Lower bound for unlabeled data
        bs_u = self.sym_x_u.shape[0]

        # For the integrating-out approach we repeat the input matrix x n_y times and construct a
        # one-hot target matrix of shape (bs_u * n_y) x n_y.
        # Example of the input and target matrices for a 3-class problem with batch_size=2:
        #               x_repeat                     t_repeat
        #  [[x[0,0], x[0,1], ..., x[0,n_x]]         [[1, 0, 0]
        #   [x[1,0], x[1,1], ..., x[1,n_x]]          [1, 0, 0]
        #   [x[0,0], x[0,1], ..., x[0,n_x]]          [0, 1, 0]
        #   [x[1,0], x[1,1], ..., x[1,n_x]]          [0, 1, 0]
        #   [x[0,0], x[0,1], ..., x[0,n_x]]          [0, 0, 1]
        #   [x[1,0], x[1,1], ..., x[1,n_x]]]         [0, 0, 1]]
        t_eye = T.eye(self.n_y, k=0)
        t_u = t_eye.reshape((self.n_y, 1, self.n_y)).repeat(bs_u,
                                                            axis=1).reshape(
                                                                (-1, self.n_y))
        x_u = self.sym_x_u.reshape(
            (1, bs_u, self.input_size**2 + self.n_mi_features)).repeat(
                self.n_y, axis=0).reshape(
                    (-1, self.input_size**2 + self.n_mi_features))
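        # Resulting shapes: t_u is (n_y * bs_u, n_y) with one-hot rows; x_u is
        # (n_y * bs_u, input_size**2 + n_mi_features) with the unlabeled batch repeated once per class.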

        # The expectation over a lies outside the sum over y, so we sample a ~ q(a|x) once and reuse it for every candidate class
        a_x_u = get_output(self.l_qa,
                           self.sym_x_u,
                           batch_norm_update_averages=True,
                           batch_norm_use_averages=False)
        a_x_u_rep = a_x_u.reshape(
            (1, bs_u * self.sym_samples, self.n_a)).repeat(self.n_y,
                                                           axis=0).reshape(
                                                               (-1, self.n_a))
        out_layers = [l_log_pa, l_log_pz, l_log_qa, l_log_qz, l_log_px]
        inputs = {self.l_x_in: x_u, self.l_y_in: t_u, self.l_a_in: a_x_u_rep}
        out = get_output(out_layers,
                         inputs,
                         batch_norm_update_averages=False,
                         batch_norm_use_averages=False)
        log_pa_u, log_pz_u, log_qa_x_u, log_qz_axy_u, log_px_zy_u = out

        # Prior p(y): assume all classes are evenly distributed (uniform).
        # TODO: verify that a uniform prior is appropriate for this dataset.

        py_u = softmax(T.zeros((bs_u * self.n_y, self.n_y)))
        log_py_u = -categorical_crossentropy(py_u, t_u).reshape(
            (-1, 1)).dimshuffle((0, 'x', 'x', 1))
        lb_u = lower_bound(log_pa_u, log_qa_x_u, log_pz_u, log_qz_axy_u,
                           log_py_u, log_px_zy_u)
        lb_u = lb_u.reshape(
            (self.n_y, 1, 1, bs_u)).transpose(3, 1, 2, 0).mean(axis=(1, 2))
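        # lb_u now holds one lower-bound value per unlabeled example and candidate class (shape (bs_u, n_y)).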
        inputs = {
            self.l_x_in: self.sym_x_u,
            self.l_a_in: a_x_u.reshape((-1, self.n_a))
        }
        y_u = get_output(self.l_qy,
                         inputs,
                         batch_norm_update_averages=True,
                         batch_norm_use_averages=False).mean(axis=(1, 2))
        y_u += 1e-8  # Ensure that we get no NaNs when calculating the entropy
        y_u /= T.sum(y_u, axis=1, keepdims=True)
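        # Unlabeled bound: U(x) = sum_y q(y|a,x) * (L(x,y) - log q(y|a,x)),
        # i.e. the class-conditional bounds weighted by q(y|a,x) plus the entropy of q(y|a,x).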
        lb_u = (y_u * (lb_u - T.log(y_u))).sum(axis=1)

        if self.batchnorm:
            # TODO: implement the BN layer correctly.
            inputs = {
                self.l_x_in: self.sym_x_u,
                self.l_y_in: y_u,
                self.l_a_in: a_x_u
            }
            get_output(out_layers,
                       inputs,
                       weighting=None,
                       batch_norm_update_averages=True,
                       batch_norm_use_averages=False)

        # Regularizing with weight priors p(theta|N(0,1)), collecting and clipping gradients
        weight_priors = 0.0
        for p in self.trainable_model_params:
            if 'W' not in str(p):
                continue
            weight_priors += log_normal(p, 0, 1).sum()

        # Add the weight prior to the summed lower bound and negate, so the objective can be minimized.
        elbo = ((lb_l.mean() + lb_u.mean()) * n + weight_priors) / -n
        lb_labeled = -lb_l.mean()
        lb_unlabeled = -lb_u.mean()

        grads_collect = T.grad(elbo, self.trainable_model_params)
        params_collect = self.trainable_model_params
        sym_beta1 = T.scalar('beta1')
        sym_beta2 = T.scalar('beta2')
        clip_grad, max_norm = 1, 5
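        # Rescale gradients to a total norm of at most max_norm, then clip each element to [-clip_grad, clip_grad].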
        mgrads = total_norm_constraint(grads_collect, max_norm=max_norm)
        mgrads = [T.clip(g, -clip_grad, clip_grad) for g in mgrads]
        updates = adam(mgrads, params_collect, self.sym_lr, sym_beta1,
                       sym_beta2)

        # Training function
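        # Draw a random labeled minibatch: RandomStreams.choice samples sym_bs_l distinct
        # row indices (replace=False) uniformly at random from the labeled training set.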
        indices = self._srng.choice(size=[self.sym_bs_l],
                                    a=sh_train_x_l.shape[0],
                                    replace=False)
        # TODO: change these to symbolic variables and generate them in the training loop.
        x_batch_l = sh_train_x_l[indices]
        t_batch_l = sh_train_t_l[indices]
        x_batch_u = self.sh_train_x[self.batch_slice]
        if self.x_dist == 'bernoulli':  # Sample bernoulli input.
            x_batch_u = self._srng.binomial(size=x_batch_u.shape,
                                            n=1,
                                            p=x_batch_u,
                                            dtype=theano.config.floatX)
            x_batch_l = self._srng.binomial(size=x_batch_l.shape,
                                            n=1,
                                            p=x_batch_l,
                                            dtype=theano.config.floatX)

        givens = {
            self.sym_x_l: x_batch_l,
            self.sym_x_u: x_batch_u,
            self.sym_t_l: t_batch_l
        }
        inputs = [
            self.sym_index, self.sym_batchsize, self.sym_bs_l, self.sym_beta,
            self.sym_lr, sym_beta1, sym_beta2, self.sym_samples
        ]
        outputs = [elbo, lb_labeled, lb_unlabeled]
        f_train = theano.function(inputs=inputs,
                                  outputs=outputs,
                                  givens=givens,
                                  updates=updates)

        # Default training args. Note that these can be changed during or prior to training.
        self.train_args['inputs']['batchsize_unlabeled'] = 100
        self.train_args['inputs']['batchsize_labeled'] = 100
        self.train_args['inputs']['beta'] = 0.1
        self.train_args['inputs']['learningrate'] = 3e-4
        self.train_args['inputs']['beta1'] = 0.9
        self.train_args['inputs']['beta2'] = 0.999
        self.train_args['inputs']['samples'] = 1
        self.train_args['outputs']['lb'] = '%0.4f'
        self.train_args['outputs']['lb-labeled'] = '%0.4f'
        self.train_args['outputs']['lb-unlabeled'] = '%0.4f'

        # Validation and test function
        y = get_output(self.l_qy, self.sym_x_l,
                       deterministic=True).mean(axis=(1, 2))
        class_err = (1. - categorical_accuracy(y, self.sym_t_l).mean()) * 100
        givens = {self.sym_x_l: self.sh_test_x, self.sym_t_l: self.sh_test_t}
        f_test = theano.function(inputs=[self.sym_samples],
                                 outputs=[class_err],
                                 givens=givens)

        # Default test args. Note that these can be changed during or prior to training.
        self.test_args['inputs']['samples'] = 1
        self.test_args['outputs']['test'] = '%0.2f%%'

        f_validate = None
        if validation_set is not None:
            givens = {
                self.sym_x_l: self.sh_valid_x,
                self.sym_t_l: self.sh_valid_t
            }
            f_validate = theano.function(inputs=[self.sym_samples],
                                         outputs=[class_err],
                                         givens=givens)
        # Default validation args. Note that these can be changed during or prior to training.
        self.validate_args['inputs']['samples'] = 1
        self.validate_args['outputs']['validation'] = '%0.2f%%'

        return f_train, f_test, f_validate, self.train_args, self.test_args, self.validate_args

    def get_output(self, x, samples=1):
        return self.f_qy(x, samples)

    def model_info(self):
        qa_shapes = self.get_model_shape(get_all_params(self.l_qa))
        qy_shapes = self.get_model_shape(get_all_params(
            self.l_qy))[len(qa_shapes) - 1:]
        qz_shapes = self.get_model_shape(get_all_params(
            self.l_qz))[len(qa_shapes) - 1:]
        px_shapes = self.get_model_shape(get_all_params(
            self.l_px))[(len(qz_shapes) - 1) + (len(qa_shapes) - 1):]
        pa_shapes = self.get_model_shape(get_all_params(
            self.l_pa))[(len(qz_shapes) - 1) + (len(qa_shapes) - 1):]
        s = ""
        s += 'batch norm: %s.\n' % (str(self.batchnorm))
        s += 'x distribution: %s.\n' % (str(self.x_dist))
        s += 'model q(a|x): %s.\n' % str(qa_shapes)[1:-1]
        s += 'model q(z|a,x,y): %s.\n' % str(qz_shapes)[1:-1]
        s += 'model q(y|a,x): %s.\n' % str(qy_shapes)[1:-1]
        s += 'model p(x|a,z,y): %s.\n' % str(px_shapes)[1:-1]
        s += 'model p(a|z,y): %s.' % str(pa_shapes)[1:-1]
        return s
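

# A minimal, self-contained sketch (not part of the original model) of the same
# RandomStreams.choice pattern used in build_model above: sampling distinct row
# indices without replacement and indexing a shared dataset with them.
# Variable names and sizes below are illustrative only.
import numpy as np
import theano
import theano.tensor as T
from theano.tensor.shared_randomstreams import RandomStreams

sh_data = theano.shared(
    np.arange(20, dtype=theano.config.floatX).reshape(10, 2), borrow=True)
srng = RandomStreams(seed=1234)
sym_bs = T.iscalar('batch_size')

# choice with replace=False draws sym_bs distinct indices from [0, sh_data.shape[0]).
indices = srng.choice(size=[sym_bs], a=sh_data.shape[0], replace=False)
f_minibatch = theano.function([sym_bs], sh_data[indices])

# f_minibatch(4) returns 4 distinct rows of sh_data, reshuffled on every call.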