def build_model(self, train_set, test_set, validation_set=None):
        """
        Build the lower bound of the model and compile the train, test and
        (optionally) validation functions.

        :param train_set: Train set containing variables x, t.
        :param test_set: Test set containing variables x, t.
        :param validation_set: Validation set containing variables x, t.
            When None, no validation function is compiled.
        :return: train, test, validation function and dicts of arguments.
        :raises ValueError: If self.x_dist is not one of 'bernoulli',
            'multinomial', 'gaussian' or 'linear'.
        """
        super(CVAE, self).build_model(train_set, test_set, validation_set)

        n = self.sh_train_x.shape[0].astype(
            theano.config.floatX)  # no. of data points

        # Define the layers for the density estimation used in the lower bound.
        l_log_qz = GaussianLogDensityLayer(self.l_qz, self.l_qz_mu,
                                           self.l_qz_logvar)
        l_log_pz = StandardNormalLogDensityLayer(self.l_qz)

        # Flatten the input so it lines up with the last axis of the
        # reshaped decoder outputs below.
        l_x_in = ReshapeLayer(self.l_x_in, (-1, self.seq_length * self.n_x))
        if self.x_dist == 'bernoulli':
            l_px = ReshapeLayer(
                self.l_px,
                (-1, self.sym_samples, 1, self.seq_length * self.n_x))
            l_log_px = BernoulliLogDensityLayer(l_px, l_x_in)
        elif self.x_dist == 'multinomial':
            l_px = ReshapeLayer(
                self.l_px,
                (-1, self.sym_samples, 1, self.seq_length * self.n_x))
            l_log_px = MultinomialLogDensityLayer(l_px, l_x_in)
        elif self.x_dist == 'gaussian':
            l_px_mu = ReshapeLayer(
                self.l_px_mu,
                (-1, self.sym_samples, 1, self.seq_length * self.n_x))
            l_px_logvar = ReshapeLayer(
                self.l_px_logvar,
                (-1, self.sym_samples, 1, self.seq_length * self.n_x))
            l_log_px = GaussianLogDensityLayer(l_x_in, l_px_mu, l_px_logvar)
        elif self.x_dist == 'linear':
            # The raw decoder output is used; it is turned into a negative
            # mean squared error further down.
            l_log_px = self.l_px
        else:
            # Fail with a clear message instead of an UnboundLocalError below.
            raise ValueError("Unsupported x_dist: %r" % (self.x_dist, ))

        self.sym_warmup = T.fscalar('warmup')

        def lower_bound(log_pz, log_qz, log_px):
            # The (log_pz - log_qz) term is weighted by (0.9 - warmup).
            return log_px + (log_pz - log_qz) * (1. - self.sym_warmup - 0.1)

        # Lower bound
        out_layers = [l_log_pz, l_log_qz, l_log_px]
        inputs = {self.l_x_in: self.sym_x}
        out = get_output(out_layers,
                         inputs,
                         batch_norm_update_averages=False,
                         batch_norm_use_averages=False)
        log_pz, log_qz, log_px = out

        # If the decoder output is linear we need the reconstruction error
        if self.x_dist == 'linear':
            log_px = -aggregate(squared_error(log_px.mean(axis=(1, 2)),
                                              self.sym_x),
                                mode='mean')

        lb = lower_bound(log_pz, log_qz, log_px)
        lb = lb.mean(axis=(1, 2))  # Mean over the sampling dimensions

        if self.batchnorm:
            # TODO: implement the BN layer correctly.
            # NOTE(review): the result of this call is discarded; presumably
            # it is only meant to trigger the running-average updates.
            inputs = {self.l_x_in: self.sym_x}
            get_output(out_layers,
                       inputs,
                       weighting=None,
                       batch_norm_update_averages=True,
                       batch_norm_use_averages=False)

        # Regularizing with weight priors p(theta|N(0,1)), collecting and clipping gradients
        weight_priors = 0.0
        for p in self.trainable_model_params:
            # Only parameters whose string form contains 'W' get a prior.
            if 'W' not in str(p):
                continue
            weight_priors += log_normal(p, 0, 1).sum()

        # Collect the lower bound and scale it with the weight priors.
        elbo = lb.mean()
        cost = (elbo * n + weight_priors) / -n

        grads_collect = T.grad(cost, self.trainable_model_params)
        sym_beta1 = T.scalar('beta1')
        sym_beta2 = T.scalar('beta2')
        clip_grad, max_norm = 1, 5
        # Rescale by the total norm first, then clip element-wise.
        mgrads = total_norm_constraint(grads_collect, max_norm=max_norm)
        mgrads = [T.clip(g, -clip_grad, clip_grad) for g in mgrads]
        updates = adam(mgrads, self.trainable_model_params, self.sym_lr,
                       sym_beta1, sym_beta2)

        # Training function
        x_batch = self.sh_train_x[self.batch_slice]
        if self.x_dist == 'bernoulli':  # Sample bernoulli input.
            x_batch = self._srng.binomial(size=x_batch.shape,
                                          n=1,
                                          p=x_batch,
                                          dtype=theano.config.floatX)

        givens = {self.sym_x: x_batch}
        inputs = [
            self.sym_index, self.sym_batchsize, self.sym_lr, sym_beta1,
            sym_beta2, self.sym_samples, self.sym_warmup
        ]
        outputs = [
            log_px.mean(),
            log_pz.mean(),
            log_qz.mean(), elbo, self.sym_warmup
        ]
        f_train = theano.function(inputs=inputs,
                                  outputs=outputs,
                                  givens=givens,
                                  updates=updates)

        # Default training args. Note that these can be changed during or prior to training.
        self.train_args['inputs']['batchsize'] = 100
        self.train_args['inputs']['learningrate'] = 1e-4
        self.train_args['inputs']['beta1'] = 0.9
        self.train_args['inputs']['beta2'] = 0.999
        self.train_args['inputs']['samples'] = 1
        self.train_args['inputs']['warmup'] = 0.1
        self.train_args['outputs']['log p(x)'] = '%0.6f'
        self.train_args['outputs']['log p(z)'] = '%0.6f'
        self.train_args['outputs']['log q(z)'] = '%0.6f'
        self.train_args['outputs']['elbo train'] = '%0.6f'
        self.train_args['outputs']['warmup'] = '%0.3f'

        # Validation and test function
        givens = {self.sym_x: self.sh_test_x}
        f_test = theano.function(inputs=[self.sym_samples, self.sym_warmup],
                                 outputs=[elbo],
                                 givens=givens)

        # Test args.  Note that these can be changed during or prior to training.
        self.test_args['inputs']['samples'] = 1
        self.test_args['inputs']['warmup'] = 0.1
        self.test_args['outputs']['elbo test'] = '%0.6f'

        f_validate = None
        if validation_set is not None:
            givens = {self.sym_x: self.sh_valid_x}
            # elbo depends on sym_warmup, so it has to be an explicit input
            # here as well (as it is for f_test); otherwise compilation fails
            # with a missing-input error.
            f_validate = theano.function(
                inputs=[self.sym_samples, self.sym_warmup],
                outputs=[elbo],
                givens=givens)
            # Default validation args. Note that these can be changed during or prior to training.
            self.validate_args['inputs']['samples'] = 1
            self.validate_args['inputs']['warmup'] = 0.1
            self.validate_args['outputs']['elbo validation'] = '%0.6f'

        return f_train, f_test, f_validate, self.train_args, self.test_args, self.validate_args
示例#2
0
    def build_model(self, train_set, test_set, validation_set=None):
        """
        Build the lower bound of the model and compile it into Theano
        functions for training, testing and (optionally) validation.

        :param train_set: Train set, forwarded to the parent build_model.
        :param test_set: Test set, forwarded to the parent build_model.
        :param validation_set: Validation set, forwarded to the parent
            build_model. When None, no validation function is compiled.
        :return: train, test, validation function and dicts of arguments.
        """
        super(VAE, self).build_model(train_set, test_set, validation_set)

        # Log-density layers: standard normal prior on z, Gaussian posterior
        # on z and the likelihood of x selected by self.x_dist.
        prior_density = StandardNormalLogDensityLayer(self.l_z)
        posterior_density = GaussianLogDensityLayer(self.l_z, self.l_z_mu,
                                                    self.l_z_logvar)
        if self.x_dist == 'bernoulli':
            likelihood_density = BernoulliLogDensityLayer(
                self.l_xhat, self.l_x_in)
        elif self.x_dist == 'gaussian':
            likelihood_density = GaussianLogDensityLayer(
                self.l_x_in, self.l_xhat_mu, self.l_xhat_logvar)

        # Negated bound: mean over axis 1, then mean over what is left.
        log_pz, log_qz_x, log_px_z = get_output(
            [prior_density, posterior_density, likelihood_density],
            {self.l_x_in: self.sym_x})
        bound = log_pz + log_px_z - log_qz_x
        lb = -bound.mean(axis=1).mean()

        # Adam updates for every trainable parameter below l_xhat.
        trainable = get_all_params(self.l_xhat, trainable=True)
        sym_beta1 = T.scalar('beta1')
        sym_beta2 = T.scalar('beta2')
        updates = adam(lb, trainable, self.sym_lr, sym_beta1, sym_beta2)

        # Training function on a slice of the shared train set.
        x_batch = self.sh_train_x[self.batch_slice]
        if self.x_dist == 'bernoulli':
            # Binarise the batch by sampling with the batch values as p.
            x_batch = self._srng.binomial(size=x_batch.shape,
                                          n=1,
                                          p=x_batch,
                                          dtype=theano.config.floatX)
        f_train = theano.function(
            inputs=[
                self.sym_index, self.sym_batchsize, self.sym_lr, sym_beta1,
                sym_beta2, self.sym_samples
            ],
            outputs=[lb],
            givens={self.sym_x: x_batch},
            updates=updates)

        # Default training args, registered in a fixed order.
        train_inputs = self.train_args['inputs']
        for key, value in (('batchsize', 100), ('learningrate', 3e-4),
                           ('beta1', 0.9), ('beta2', 0.999), ('samples', 1)):
            train_inputs[key] = value
        self.train_args['outputs']['lb'] = '%0.4f'

        # Test function evaluates the same bound on the whole test set.
        f_test = theano.function(inputs=[self.sym_samples],
                                 outputs=[lb],
                                 givens={self.sym_x: self.sh_test_x})
        self.test_args['inputs']['samples'] = 1
        self.test_args['outputs']['lb'] = '%0.4f'

        # Validation function, only when a validation set was supplied.
        f_validate = None
        if validation_set is not None:
            f_validate = theano.function(
                inputs=[self.sym_samples],
                outputs=[lb],
                givens={self.sym_x: self.sh_valid_x})
            self.validate_args['inputs']['samples'] = 1
            self.validate_args['outputs']['lb'] = '%0.4f'

        return (f_train, f_test, f_validate, self.train_args, self.test_args,
                self.validate_args)
示例#3
0
    def build_model(self,
                    train_set_unlabeled,
                    train_set_labeled,
                    test_set,
                    validation_set=None):
        """
        Build the auxiliary deep generative model from the initialized hyperparameters.
        Define the lower bound term and compile it into a training function.
        :param train_set_unlabeled: Unlabeled train set containing variables x, t.
            Forwarded to the parent build_model as the train set.
        :param train_set_labeled: Labeled train set containing variables x, t.
            Element [0] (x) and element [1] (t) are copied into shared variables.
        :param test_set: Test set containing variables x, t.
        :param validation_set: Validation set containing variables x, t.
            When None, no validation function is compiled.
        :return: train, test, validation function and dicts of arguments.
        """
        super(CSDGM, self).build_model(train_set_unlabeled, test_set,
                                       validation_set)

        # Labeled inputs and targets live in their own shared variables so
        # that labeled mini-batches can be drawn by symbolic indexing below.
        sh_train_x_l = theano.shared(np.asarray(train_set_labeled[0],
                                                dtype=theano.config.floatX),
                                     borrow=True)
        sh_train_t_l = theano.shared(np.asarray(train_set_labeled[1],
                                                dtype=theano.config.floatX),
                                     borrow=True)
        n = self.sh_train_x.shape[0].astype(
            theano.config.floatX)  # no. of data points
        n_l = sh_train_x_l.shape[0].astype(
            theano.config.floatX)  # no. of labeled data points

        # Define the layers for the density estimation used in the lower bound.
        l_log_qa = GaussianLogDensityLayer(self.l_qa, self.l_qa_mu,
                                           self.l_qa_logvar)
        l_log_qz = GaussianLogDensityLayer(self.l_qz, self.l_qz_mu,
                                           self.l_qz_logvar)
        l_log_qy = MultinomialLogDensityLayer(self.l_qy, self.l_y_in, eps=1e-8)

        l_log_pz = StandardNormalLogDensityLayer(self.l_qz)
        l_log_pa = GaussianLogDensityLayer(self.l_qa, self.l_pa_mu,
                                           self.l_pa_logvar)

        l_x_in = ReshapeLayer(self.l_x_in, (-1, self.n_l * self.n_c))
        l_px = DimshuffleLayer(self.l_px, (0, 3, 1, 2, 4))
        l_px = ReshapeLayer(l_px, (-1, self.sym_samples, 1, self.n_c))
        if self.x_dist == 'bernoulli':
            # NOTE(review): this branch uses self.l_px / self.l_x_in, not the
            # reshaped l_px / l_x_in built above - confirm this is intended.
            l_log_px = BernoulliLogDensityLayer(self.l_px, self.l_x_in)
        elif self.x_dist == 'multinomial':
            l_log_px = MultinomialLogDensityLayer(l_px, l_x_in)
            l_log_px = ReshapeLayer(l_log_px, (-1, self.n_l, 1, 1, 1))
            l_log_px = MeanLayer(l_log_px, axis=1)
        elif self.x_dist == 'gaussian':
            # NOTE(review): the dimshuffle pattern here, (0, 2, 3, 1, 4),
            # differs from the (0, 3, 1, 2, 4) used for l_px - confirm.
            l_px_mu = ReshapeLayer(
                DimshuffleLayer(self.l_px_mu, (0, 2, 3, 1, 4)),
                (-1, self.sym_samples, 1, self.n_l * self.n_c))
            l_px_logvar = ReshapeLayer(
                DimshuffleLayer(self.l_px_logvar, (0, 2, 3, 1, 4)),
                (-1, self.sym_samples, 1, self.n_l * self.n_c))
            l_log_px = GaussianLogDensityLayer(l_x_in, l_px_mu, l_px_logvar)

        def lower_bound(log_pa, log_qa, log_pz, log_qz, log_py, log_px):
            # The prior/posterior terms for a and z are weighted by
            # (1.1 - warmup); log_px and log_py are left unweighted.
            lb = log_px + log_py + (log_pz + log_pa - log_qa -
                                    log_qz) * (1.1 - self.sym_warmup)
            return lb

        # Lower bound for labeled data
        out_layers = [
            l_log_pa, l_log_pz, l_log_qa, l_log_qz, l_log_px, l_log_qy
        ]
        inputs = {self.l_x_in: self.sym_x_l, self.l_y_in: self.sym_t_l}
        out = get_output(out_layers,
                         inputs,
                         batch_norm_update_averages=False,
                         batch_norm_use_averages=False)
        log_pa_l, log_pz_l, log_qa_x_l, log_qz_axy_l, log_px_zy_l, log_qy_ax_l = out

        # Prior p(y) expecting that all classes are evenly distributed
        py_l = softmax(T.zeros((self.sym_x_l.shape[0], self.n_y)))
        log_py_l = -categorical_crossentropy(py_l, self.sym_t_l).reshape(
            (-1, 1)).dimshuffle((0, 'x', 'x', 1))
        lb_l = lower_bound(log_pa_l, log_qa_x_l, log_pz_l, log_qz_axy_l,
                           log_py_l, log_px_zy_l)
        lb_l = lb_l.mean(axis=(1, 2))  # Mean over the sampling dimensions
        log_qy_ax_l *= (
            self.sym_beta * (n / n_l)
        )  # Scale the supervised cross entropy with beta * (n / n_l)
        lb_l += log_qy_ax_l.mean(axis=(
            1, 2
        ))  # Collect the lower bound term and mean over sampling dimensions

        # Lower bound for unlabeled data
        bs_u = self.sym_x_u.shape[0]

        # For the integrating out approach, we repeat the input matrix x, and construct a target (bs * n_y) x n_y
        # Example of input and target matrix for a 3 class problem and batch_size=2. 2D tensors of the form
        #               x_repeat                     t_repeat
        #  [[x[0,0], x[0,1], ..., x[0,n_x]]         [[1, 0, 0]
        #   [x[1,0], x[1,1], ..., x[1,n_x]]          [1, 0, 0]
        #   [x[0,0], x[0,1], ..., x[0,n_x]]          [0, 1, 0]
        #   [x[1,0], x[1,1], ..., x[1,n_x]]          [0, 1, 0]
        #   [x[0,0], x[0,1], ..., x[0,n_x]]          [0, 0, 1]
        #   [x[1,0], x[1,1], ..., x[1,n_x]]]         [0, 0, 1]]
        t_eye = T.eye(self.n_y, k=0)
        t_u = t_eye.reshape((self.n_y, 1, self.n_y)).repeat(bs_u,
                                                            axis=1).reshape(
                                                                (-1, self.n_y))
        x_u = self.sym_x_u.reshape(
            (1, bs_u, self.n_l, self.n_c)).repeat(self.n_y, axis=0).reshape(
                (-1, self.n_l, self.n_c))

        # Since the expectation of var a is outside the integration we calculate E_q(a|x) first
        a_x_u = get_output(self.l_qa,
                           self.sym_x_u,
                           batch_norm_update_averages=True,
                           batch_norm_use_averages=False)
        # Repeat the sampled a once per class so it lines up with x_u / t_u.
        a_x_u_rep = a_x_u.reshape(
            (1, bs_u * self.sym_samples, self.n_a)).repeat(self.n_y,
                                                           axis=0).reshape(
                                                               (-1, self.n_a))
        out_layers = [l_log_pa, l_log_pz, l_log_qa, l_log_qz, l_log_px]
        inputs = {self.l_x_in: x_u, self.l_y_in: t_u, self.l_a_in: a_x_u_rep}
        out = get_output(out_layers,
                         inputs,
                         batch_norm_update_averages=False,
                         batch_norm_use_averages=False)
        log_pa_u, log_pz_u, log_qa_x_u, log_qz_axy_u, log_px_zy_u = out

        # Prior p(y) expecting that all classes are evenly distributed
        py_u = softmax(T.zeros((bs_u * self.n_y, self.n_y)))
        log_py_u = -categorical_crossentropy(py_u, t_u).reshape(
            (-1, 1)).dimshuffle((0, 'x', 'x', 1))
        lb_u = lower_bound(log_pa_u, log_qa_x_u, log_pz_u, log_qz_axy_u,
                           log_py_u, log_px_zy_u)
        # Move the class axis last, giving one bound per (example, class).
        lb_u = lb_u.reshape(
            (self.n_y, 1, 1, bs_u)).transpose(3, 1, 2, 0).mean(axis=(1, 2))
        inputs = {
            self.l_x_in: self.sym_x_u,
            self.l_a_in: a_x_u.reshape((-1, self.n_a))
        }
        y_u = get_output(self.l_qy,
                         inputs,
                         batch_norm_update_averages=True,
                         batch_norm_use_averages=False).mean(axis=(1, 2))
        y_u += 1e-8  # Ensure that we get no NANs when calculating the entropy
        y_u /= T.sum(y_u, axis=1, keepdims=True)
        # Weight the per-class bounds by q(y) and add the entropy of q(y).
        lb_u = (y_u * (lb_u - T.log(y_u))).sum(axis=1)

        # Regularizing with weight priors p(theta|N(0,1)), collecting and clipping gradients
        weight_priors = 0.0
        for p in self.trainable_model_params:
            # Only parameters whose string form contains 'W' get a prior.
            if 'W' not in str(p):
                continue
            weight_priors += log_normal(p, 0, 1).sum()

        # Collect the lower bound and scale it with the weight priors.
        elbo = ((lb_l.mean() + lb_u.mean()) * n + weight_priors) / -n
        # The remaining quantities are only reported as training outputs.
        lb_labeled = -lb_l.mean()
        lb_unlabeled = -lb_u.mean()
        log_px = log_px_zy_l.mean() + log_px_zy_u.mean()
        log_pz = log_pz_l.mean() + log_pz_u.mean()
        log_qz = log_qz_axy_l.mean() + log_qz_axy_u.mean()
        log_pa = log_pa_l.mean() + log_pa_u.mean()
        log_qa = log_qa_x_l.mean() + log_qa_x_u.mean()

        grads_collect = T.grad(elbo, self.trainable_model_params)
        params_collect = self.trainable_model_params
        sym_beta1 = T.scalar('beta1')
        sym_beta2 = T.scalar('beta2')
        clip_grad, max_norm = 1, 5
        # Rescale by the total norm first, then clip element-wise.
        mgrads = total_norm_constraint(grads_collect, max_norm=max_norm)
        mgrads = [T.clip(g, -clip_grad, clip_grad) for g in mgrads]
        updates = adam(mgrads, params_collect, self.sym_lr, sym_beta1,
                       sym_beta2)

        # Training function
        # Labeled rows are drawn at random without replacement on each call;
        # the unlabeled batch is a slice of the shared train set.
        indices = self._srng.choice(size=[self.sym_bs_l],
                                    a=sh_train_x_l.shape[0],
                                    replace=False)
        x_batch_l = sh_train_x_l[indices]
        t_batch_l = sh_train_t_l[indices]
        x_batch_u = self.sh_train_x[self.batch_slice]
        if self.x_dist == 'bernoulli':  # Sample bernoulli input.
            x_batch_u = self._srng.binomial(size=x_batch_u.shape,
                                            n=1,
                                            p=x_batch_u,
                                            dtype=theano.config.floatX)
            x_batch_l = self._srng.binomial(size=x_batch_l.shape,
                                            n=1,
                                            p=x_batch_l,
                                            dtype=theano.config.floatX)

        givens = {
            self.sym_x_l: x_batch_l,
            self.sym_x_u: x_batch_u,
            self.sym_t_l: t_batch_l
        }
        inputs = [
            self.sym_index, self.sym_batchsize, self.sym_bs_l, self.sym_beta,
            self.sym_lr, sym_beta1, sym_beta2, self.sym_samples,
            self.sym_warmup
        ]
        outputs = [
            elbo, lb_labeled, lb_unlabeled, log_px, log_pz, log_qz, log_pa,
            log_qa
        ]
        f_train = theano.function(inputs=inputs,
                                  outputs=outputs,
                                  givens=givens,
                                  updates=updates)

        # Default training args. Note that these can be changed during or prior to training.
        self.train_args['inputs']['batchsize_unlabeled'] = 100
        self.train_args['inputs']['batchsize_labeled'] = 100
        self.train_args['inputs']['beta'] = 0.1
        self.train_args['inputs']['learningrate'] = 3e-4
        self.train_args['inputs']['beta1'] = 0.9
        self.train_args['inputs']['beta2'] = 0.999
        self.train_args['inputs']['samples'] = 1
        self.train_args['inputs']['warmup'] = 0.1
        self.train_args['outputs']['lb'] = '%0.3f'
        self.train_args['outputs']['lb-l'] = '%0.3f'
        self.train_args['outputs']['lb-u'] = '%0.3f'
        self.train_args['outputs']['px'] = '%0.3f'
        self.train_args['outputs']['pz'] = '%0.3f'
        self.train_args['outputs']['qz'] = '%0.3f'
        self.train_args['outputs']['pa'] = '%0.3f'
        self.train_args['outputs']['qa'] = '%0.3f'

        # Validation and test function
        # Both report the classification error in percent, not the bound.
        y = get_output(self.l_qy, self.sym_x_l,
                       deterministic=True).mean(axis=(1, 2))
        class_err = (1. - categorical_accuracy(y, self.sym_t_l).mean()) * 100
        givens = {self.sym_x_l: self.sh_test_x, self.sym_t_l: self.sh_test_t}
        f_test = theano.function(inputs=[self.sym_samples],
                                 outputs=[class_err],
                                 givens=givens)

        # Test args.  Note that these can be changed during or prior to training.
        self.test_args['inputs']['samples'] = 1
        self.test_args['outputs']['test'] = '%0.2f%%'

        f_validate = None
        if validation_set is not None:
            givens = {
                self.sym_x_l: self.sh_valid_x,
                self.sym_t_l: self.sh_valid_t
            }
            f_validate = theano.function(inputs=[self.sym_samples],
                                         outputs=[class_err],
                                         givens=givens)
            # Default validation args. Note that these can be changed during or prior to training.
            self.validate_args['inputs']['samples'] = 1
            self.validate_args['outputs']['validation'] = '%0.2f%%'

        return f_train, f_test, f_validate, self.train_args, self.test_args, self.validate_args
示例#4
0
    def build_model(self, train_set, test_set, validation_set=None):
        """
        Build the auxiliary deep generative model from the initialized hyperparameters.
        Define the lower bound term and compile it into a training function.
        :param train_set: Train set containing variables x, t.
        for the unlabeled data in the train set, we define 0's in t.
        :param test_set: Test set containing variables x, t.
        :param validation_set: Validation set containing variables x, t.
            When None, no validation function is compiled.
        :return: train, test, validation function and dicts of arguments.
        """
        super(ADGMSSL, self).build_model(train_set, test_set, validation_set)

        # Define the layers for the density estimation used in the lower bound.
        l_log_pa = GaussianMarginalLogDensityLayer(self.l_a_mu,
                                                   self.l_a_logvar)
        l_log_pz = GaussianMarginalLogDensityLayer(self.l_z_mu,
                                                   self.l_z_logvar)
        l_log_qa_x = GaussianMarginalLogDensityLayer(1, self.l_a_logvar)
        l_log_qz_xy = GaussianMarginalLogDensityLayer(1, self.l_z_logvar)
        l_log_qy_ax = MultinomialLogDensityLayer(self.l_y,
                                                 self.l_y_in,
                                                 eps=1e-8)
        if self.x_dist == 'bernoulli':
            l_px_zy = BernoulliLogDensityLayer(self.l_xhat, self.l_x_in)
        elif self.x_dist == 'multinomial':
            l_px_zy = MultinomialLogDensityLayer(self.l_xhat, self.l_x_in)
        elif self.x_dist == 'gaussian':
            l_px_zy = GaussianLogDensityLayer(self.l_x_in, self.l_xhat_mu,
                                              self.l_xhat_logvar)

        ### Compute lower bound for labeled data ###
        out_layers = [
            l_log_pa, l_log_pz, l_log_qa_x, l_log_qz_xy, l_px_zy, l_log_qy_ax
        ]
        inputs = {self.l_x_in: self.sym_x_l, self.l_y_in: self.sym_t_l}
        log_pa_l, log_pz_l, log_qa_x_l, log_qz_axy_l, log_px_zy_l, log_qy_ax_l = get_output(
            out_layers, inputs)
        py_l = softmax(T.zeros(
            (self.sym_x_l.shape[0], self.n_y)))  # non-informative prior
        log_py_l = -categorical_crossentropy(py_l, self.sym_t_l).reshape(
            (-1, 1)).dimshuffle((0, 'x', 'x', 1))
        lb_l = log_pa_l + log_pz_l + log_py_l + log_px_zy_l - log_qa_x_l - log_qz_axy_l
        # Upscale the discriminative term with a weight.
        log_qy_ax_l *= self.sym_beta
        # Gradients are taken separately for the two parameter groups:
        # xhat_params from the bound, y_params from the discriminative term.
        xhat_grads_l = T.grad(lb_l.mean(axis=(1, 2)).sum(), self.xhat_params)
        y_grads_l = T.grad(log_qy_ax_l.mean(axis=(1, 2)).sum(), self.y_params)
        lb_l += log_qy_ax_l
        lb_l = lb_l.mean(axis=(1, 2))

        ### Compute lower bound for unlabeled data ###
        bs_u = self.sym_x_u.shape[0]  # size of the unlabeled batch.
        t_eye = T.eye(self.n_y,
                      k=0)  # identity matrix (n_y x n_y), one row per class.
        # repeat unlabeled t the number of classes for integration (bs * n_y) x n_y.
        t_u = t_eye.reshape((self.n_y, 1, self.n_y)).repeat(bs_u,
                                                            axis=1).reshape(
                                                                (-1, self.n_y))
        # repeat unlabeled x the number of classes for integration (bs * n_y) x n_x
        x_u = self.sym_x_u.reshape(
            (1, bs_u, self.n_x)).repeat(self.n_y, axis=0).reshape(
                (-1, self.n_x))
        out_layers = [l_log_pa, l_log_pz, l_log_qa_x, l_log_qz_xy, l_px_zy]
        inputs = {self.l_x_in: x_u, self.l_y_in: t_u}
        log_pa_u, log_pz_u, log_qa_x_u, log_qz_axy_u, log_px_zy_u = get_output(
            out_layers, inputs)
        py_u = softmax(T.zeros(
            (bs_u * self.n_y, self.n_y)))  # non-informative prior.
        log_py_u = -categorical_crossentropy(py_u, t_u).reshape(
            (-1, 1)).dimshuffle((0, 'x', 'x', 1))
        lb_u = log_pa_u + log_pz_u + log_py_u + log_px_zy_u - log_qa_x_u - log_qz_axy_u
        lb_u = lb_u.reshape(
            (self.n_y, self.sym_samples, 1,
             bs_u)).transpose(3, 1, 2,
                              0).mean(axis=(1, 2))  # mean over samples.
        y_ax_u = get_output(self.l_y, self.sym_x_u)
        y_ax_u = y_ax_u.mean(axis=(1, 2))  # bs x n_y
        y_ax_u += 1e-8  # ensure that we get no NANs.
        y_ax_u /= T.sum(y_ax_u, axis=1, keepdims=True)
        xhat_grads_u = T.grad((y_ax_u * lb_u).sum(axis=1).sum(),
                              self.xhat_params)
        # Weight the per-class bounds by q(y) and add the entropy of q(y).
        lb_u = (y_ax_u * (lb_u - T.log(y_ax_u))).sum(axis=1)
        y_grads_u = T.grad(lb_u.sum(), self.y_params)

        # Loss - regularizing with weight priors p(theta|N(0,1)) and clipping gradients
        y_weight_priors = 0.0
        for p in self.y_params:
            # Only parameters whose string form contains 'W' get a prior.
            if 'W' not in str(p):
                continue
            y_weight_priors += log_normal(p, 0, 1).sum()
        y_weight_priors_grad = T.grad(y_weight_priors,
                                      self.y_params,
                                      disconnected_inputs='ignore')

        xhat_weight_priors = 0.0
        for p in self.xhat_params:
            if 'W' not in str(p):
                continue
            xhat_weight_priors += log_normal(p, 0, 1).sum()
        xhat_weight_priors_grad = T.grad(xhat_weight_priors,
                                         self.xhat_params,
                                         disconnected_inputs='ignore')

        n = self.sh_train_x.shape[0].astype(
            theano.config.floatX
        )  # no. of data points in train set
        n_b = n / self.sym_batchsize.astype(
            theano.config.floatX)  # no. of batches in train set

        def _combine(grads_l, grads_u, grads_prior):
            # Per parameter: (labeled + unlabeled) * n_b + prior, over -n.
            return [((g_l + g_u) * n_b + g_prior) / -n
                    for g_l, g_u, g_prior in zip(grads_l, grads_u,
                                                 grads_prior)]

        y_grads = _combine(y_grads_l, y_grads_u, y_weight_priors_grad)
        xhat_grads = _combine(xhat_grads_l, xhat_grads_u,
                              xhat_weight_priors_grad)

        params = self.y_params + self.xhat_params
        grads = y_grads + xhat_grads

        # Collect the lower bound and scale it with the weight priors.
        elbo = ((lb_l.sum() + lb_u.sum()) * n_b + y_weight_priors +
                xhat_weight_priors) / -n

        # Avoid vanishing and exploding gradients.
        clip_grad, max_norm = 1, 5
        mgrads = total_norm_constraint(grads, max_norm=max_norm)
        mgrads = [T.clip(g, -clip_grad, clip_grad) for g in mgrads]
        sym_beta1 = T.scalar('beta1')
        sym_beta2 = T.scalar('beta2')
        updates = adam(mgrads, params, self.sym_lr, sym_beta1, sym_beta2)

        ### Compile training function ###
        # The first sym_bs_l rows of the batch slice are the labeled part,
        # the remaining rows the unlabeled part.
        x_batch_l = self.sh_train_x[self.batch_slice][:self.sym_bs_l]
        x_batch_u = self.sh_train_x[self.batch_slice][self.sym_bs_l:]
        t_batch_l = self.sh_train_t[self.batch_slice][:self.sym_bs_l]
        if self.x_dist == 'bernoulli':  # Sample bernoulli input.
            x_batch_u = self._srng.binomial(size=x_batch_u.shape,
                                            n=1,
                                            p=x_batch_u,
                                            dtype=theano.config.floatX)
            x_batch_l = self._srng.binomial(size=x_batch_l.shape,
                                            n=1,
                                            p=x_batch_l,
                                            dtype=theano.config.floatX)
        givens = {
            self.sym_x_l: x_batch_l,
            self.sym_x_u: x_batch_u,
            self.sym_t_l: t_batch_l
        }
        inputs = [
            self.sym_index, self.sym_batchsize, self.sym_bs_l, self.sym_beta,
            self.sym_lr, sym_beta1, sym_beta2, self.sym_samples
        ]
        f_train = theano.function(inputs=inputs,
                                  outputs=[elbo],
                                  givens=givens,
                                  updates=updates)
        # Default training args. Note that these can be changed during or prior to training.
        self.train_args['inputs']['batchsize'] = 200
        self.train_args['inputs']['batchsize_labeled'] = 100
        self.train_args['inputs']['beta'] = 1200.
        self.train_args['inputs']['learningrate'] = 3e-4
        self.train_args['inputs']['beta1'] = 0.9
        self.train_args['inputs']['beta2'] = 0.999
        self.train_args['inputs']['samples'] = 1
        self.train_args['outputs']['lb'] = '%0.4f'

        ### Compile testing function ###
        class_err_test = self._classification_error(self.sym_x_l, self.sym_t_l)
        givens = {self.sym_x_l: self.sh_test_x, self.sym_t_l: self.sh_test_t}
        f_test = theano.function(inputs=[self.sym_samples],
                                 outputs=[class_err_test],
                                 givens=givens)
        # Testing args.  Note that these can be changed during or prior to training.
        self.test_args['inputs']['samples'] = 1
        self.test_args['outputs']['err'] = '%0.2f%%'

        ### Compile validation function ###
        f_validate = None
        if validation_set is not None:
            class_err_valid = self._classification_error(
                self.sym_x_l, self.sym_t_l)
            givens = {
                self.sym_x_l: self.sh_valid_x,
                self.sym_t_l: self.sh_valid_t
            }
            f_validate = theano.function(inputs=[self.sym_samples],
                                         outputs=[class_err_valid],
                                         givens=givens)
        # Default validation args. Note that these can be changed during or prior to training.
        # They are registered even when no validation function is compiled.
        self.validate_args['inputs']['samples'] = 1
        self.validate_args['outputs']['err'] = '%0.2f%%'

        return f_train, f_test, f_validate, self.train_args, self.test_args, self.validate_args