Example #1
def feedforward_and_save_mean_var(net, dataloader, task_no, num_ens=1):
    cfg.mean_var_dir = "Mixtures/mean_vars/task-{}/".format(task_no)
    os.makedirs(cfg.mean_var_dir, exist_ok=True)
    cfg.record_mean_var = True
    cfg.record_layers = None  # All layers
    cfg.record_now = True
    cfg.curr_batch_no = 0  # Not required
    cfg.curr_epoch_no = 0  # Not required

    net.train()  # To get distribution of mean and var
    accs = []

    for i, (inputs, labels) in enumerate(dataloader):
        inputs, labels = inputs.to(device), labels.to(device)
        outputs = torch.zeros(inputs.shape[0], net.num_classes,
                              num_ens).to(device)
        kl = 0.0
        for j in range(num_ens):
            net_out, _kl = net(inputs)
            kl += _kl
            outputs[:, :, j] = F.log_softmax(net_out, dim=1).data

        log_outputs = utils.logmeanexp(outputs, dim=2)
        accs.append(metrics.acc(log_outputs, labels))
    return np.mean(accs)
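Every example on this page relies on a utils.logmeanexp helper whose exact implementation varies between projects. A minimal, numerically stable PyTorch sketch of what such a helper is assumed to do (the DBN examples further down use a NumPy equivalent operating on matrices):

import math
import torch

def logmeanexp(x, dim=None, keepdim=False):
    """Numerically stable log(mean(exp(x))) along `dim` (over all elements if dim is None)."""
    if dim is None:
        x, dim = x.reshape(-1), 0
    # log-mean-exp = log-sum-exp - log(number of averaged elements)
    return torch.logsumexp(x, dim=dim, keepdim=keepdim) - math.log(x.size(dim))

Subtracting log N outside the exponential keeps the computation stable even when the averaged log-probabilities are large in magnitude.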
Example #2
def train_model(net, optimizer, criterion, trainloader, num_ens=1):
    net.train()
    training_loss = 0.0
    accs = []
    kl_list = []
    for i, (inputs, labels) in enumerate(trainloader, 0):
        optimizer.zero_grad()

        inputs, labels = inputs.to(device), labels.to(device)
        outputs = torch.zeros(inputs.shape[0], net.num_classes,
                              num_ens).to(device)

        kl = 0.0
        for j in range(num_ens):
            net_out, _kl = net(inputs)
            kl += _kl
            outputs[:, :, j] = F.log_softmax(net_out, dim=1)

        kl = kl / num_ens
        kl_list.append(kl.item())
        log_outputs = utils.logmeanexp(outputs, dim=2)

        loss = criterion(log_outputs, labels, kl)
        loss.backward()
        optimizer.step()

        accs.append(metrics.acc(log_outputs.data, labels))
        training_loss += loss.cpu().data.numpy()
    return training_loss / len(trainloader), np.mean(accs), np.mean(kl_list)
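The log_softmax-then-logmeanexp pattern used in these training loops averages the ensemble members' predictive distributions in probability space while staying in the log domain, which is what an NLL-style criterion expects. A small self-contained check of that equivalence (using torch.logsumexp directly instead of the project's utils):

import math
import torch
import torch.nn.functional as F

logits = torch.randn(4, 10, 3)                 # (batch, classes, ensemble members)
log_probs = F.log_softmax(logits, dim=1)       # per-member log-probabilities

lme = torch.logsumexp(log_probs, dim=2) - math.log(log_probs.size(2))
direct = torch.log(F.softmax(logits, dim=1).mean(dim=2))

print(torch.allclose(lme, direct, atol=1e-6))  # True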
Example #3
def validate_model(net,
                   criterion,
                   validloader,
                   num_ens=1,
                   beta_type=0.1,
                   epoch=None,
                   num_epochs=None):
    """Calculate ensemble accuracy and NLL Loss"""
    net.train()
    valid_loss = 0.0
    accs = []

    for i, (inputs, labels) in enumerate(validloader, 1):
        inputs, labels = inputs.to(device), labels.to(device)
        outputs = torch.zeros(inputs.shape[0], net.num_classes,
                              num_ens).to(device)
        kl = 0.0
        for j in range(num_ens):
            net_out, _kl = net(inputs)
            kl += _kl
            outputs[:, :, j] = F.log_softmax(net_out, dim=1).data

        log_outputs = utils.logmeanexp(outputs, dim=2)

        beta = metrics.get_beta(i - 1, len(validloader), beta_type, epoch,
                                num_epochs)
        valid_loss += criterion(log_outputs, labels, kl, beta).item()
        accs.append(metrics.acc(log_outputs, labels))

    return valid_loss / len(validloader), np.mean(accs)
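The criterion(log_outputs, labels, kl, beta) calls in these snippets assume an ELBO-style loss that adds a weighted KL term to the negative log-likelihood of the averaged predictions. The actual metrics module is not shown on this page; a plausible minimal sketch (signature and scaling details may differ from the real one) is:

import torch.nn as nn
import torch.nn.functional as F

class ELBO(nn.Module):
    def __init__(self, train_size):
        super().__init__()
        self.train_size = train_size

    def forward(self, log_outputs, target, kl, beta=1.0):
        # negative log-likelihood of the ensemble-averaged log-probabilities,
        # scaled to the dataset size, plus the weighted KL divergence term
        nll = F.nll_loss(log_outputs, target, reduction='mean') * self.train_size
        return nll + beta * kl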
Example #4
	def _ulogprob_hid(self, Y, num_is_samples=100):
		"""
		Estimates the unnormalized marginal log-probabilities of hidden states.
		
		Use this method only if you know what you are doing.
		"""

		# approximate this SRBM with an RBM
		rbm = RBM(self.X.shape[0], self.Y.shape[0])
		rbm.W = self.W
		rbm.b = self.b
		rbm.c = self.c

		# allocate memory
		Q = np.asmatrix(np.zeros([num_is_samples, Y.shape[1]]))

		for k in range(num_is_samples):
			# draw importance samples
			X = rbm.backward(Y)

			# store importance weights
			Q[k, :] = self._ulogprob(X, Y) - rbm._clogprob_vis_hid(X, Y)

		# average importance weights to get estimates
		return utils.logmeanexp(Q, 0)
Example #5
    def _ulogprob_hid(self, Y, num_is_samples=100):
        """
		Estimates the unnormalized marginal log-probabilities of hidden states.
		
		Use this method only if you know what you are doing.
		"""

        # approximate this SRBM with an RBM
        rbm = RBM(self.X.shape[0], self.Y.shape[0])
        rbm.W = self.W
        rbm.b = self.b
        rbm.c = self.c

        # allocate memory
        Q = np.asmatrix(np.zeros([num_is_samples, Y.shape[1]]))

        for k in range(num_is_samples):
            # draw importance samples
            X = rbm.backward(Y)

            # store importance weights
            Q[k, :] = self._ulogprob(X, Y) - rbm._clogprob_vis_hid(X, Y)

        # average importance weights to get estimates
        return utils.logmeanexp(Q, 0)
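_ulogprob_hid estimates a marginal by drawing X from the approximating RBM's conditional and averaging the importance weights p~(x, y) / q(x | y) in the log domain. The same estimator on a toy bivariate Gaussian, where the true marginal is known (purely illustrative, unrelated to the SRBM code):

import numpy as np

rng = np.random.default_rng(0)
rho = 0.8     # correlation of the target joint N(0, [[1, rho], [rho, 1]])
y = 0.5       # fixed "hidden" value whose marginal we estimate

def log_joint(x, y):      # log p(x, y) for the correlated bivariate normal
    quad = (x**2 - 2 * rho * x * y + y**2) / (1 - rho**2)
    return -0.5 * quad - np.log(2 * np.pi) - 0.5 * np.log(1 - rho**2)

def log_proposal(x):      # q(x | y) = N(0, 1), a deliberately crude proposal
    return -0.5 * x**2 - 0.5 * np.log(2 * np.pi)

x = rng.standard_normal(100_000)             # importance samples from q
logw = log_joint(x, y) - log_proposal(x)     # log importance weights

# logmeanexp of the weights estimates log p(y)
log_marginal = np.log(np.mean(np.exp(logw - logw.max()))) + logw.max()
exact = -0.5 * y**2 - 0.5 * np.log(2 * np.pi)   # true marginal is N(y; 0, 1)
print(log_marginal, exact)                      # the two values agree closely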
Example #6
    def estimate_log_probability(self, X, num_samples=200):
        """
        Estimates the log-probability in nats.

        This method returns two values: Optimistic but consistent estimates of
        the log probability of the given data samples and estimated lower bounds.
        The parameter C{num_samples} is only relevant for DBNs with at least 2
        layers.  L{estimate_log_partition_function}() should be run with
        appropriate parameters beforehand, otherwise the probability estimates
        will be very poor.
        @type  X: array_like
        @param X: the data points for which to estimate the log-probability
        @type  num_samples: integer
        @param num_samples: the number of Monte Carlo samples used to estimate the
        unnormalized probability of the data samples
        @rtype:  tuple
        @return: a tuple consisting of the estimated log-probabilities (first entry)
        and estimated lower bounds (second entry)
        """

        # estimate partition function if not done yet
        if not self.dbn[-1]._ais_logz:
            self.estimate_log_partition_function()

        if len(self.dbn) > 1:
            for l in range(len(self.dbn) - 1):
                if isinstance(self.dbn[l], SemiRBM):
                    # needed for estimating SemiRBM marginals
                    if not self.dbn[l]._ais_logz:
                        self.dbn[l]._ais_logz = self.estimate_log_partition_function(
                            layer=l)

            # allocate (shared) memory for log importance weights
            logiws = shmarray.zeros([num_samples, X.shape[1]])

            # Monte Carlo estimation of unnormalized probability
            def parfor(i):
                samples = X

                for l in range(len(self.dbn) - 1):
                    logiws[i, :] += self.dbn[l]._ulogprob_vis(samples).A[0]
                    samples = self.dbn[l].forward(samples)
                    logiws[i, :] -= self.dbn[l]._ulogprob_hid(samples).A[0]
                logiws[i, :] += self.dbn[-1]._ulogprob_vis(samples).A[0]

            # evaluate eagerly (a bare map() is lazy in Python 3)
            for i in range(num_samples):
                parfor(i)

            # averaging weights yields unnormalized probability
            ulogprob = utils.logmeanexp(np.asmatrix(logiws), 0)
            ubound = logiws.mean(0)

        else:
            ulogprob = self.dbn[0]._ulogprob_vis(X)
            ubound = ulogprob.copy()

        # return normalized log probability
        return (ulogprob - self.dbn[-1]._ais_logz,
                ubound - self.dbn[-1]._ais_logz)
Example #7
	def estimate_log_probability(self, X, num_samples=200):
		"""
		Estimates the log-probability in nats.
		
		This method returns two values: Optimistic but consistent estimates of
		the log probability of the given data samples and estimated lower bounds.
		The parameter C{num_samples} is only relevant for DBNs with at least 2
		layers.  L{estimate_log_partition_function}() should be run with
		appropriate parameters beforehand, otherwise the probability estimates
		will be very poor.

		@type  X: array_like
		@param X: the data points for which to estimate the log-probability

		@type  num_samples: integer
		@param num_samples: the number of Monte Carlo samples used to estimate the
		unnormalized probability of the data samples

		@rtype:  tuple
		@return: a tuple consisting of the estimated log-probabilities (first entry)
		and estimated lower bounds (second entry)
		"""

		# estimate partition function if not done yet
		if not self.dbn[-1]._ais_logz:
			self.estimate_log_partition_function()

		if len(self.dbn) > 1:
			for l in range(len(self.dbn) - 1):
				if isinstance(self.dbn[l], SemiRBM):
					# needed for estimating SemiRBM marginals
					if not self.dbn[l]._ais_logz:
						self.dbn[l]._ais_logz = self.estimate_log_partition_function(layer=l)

			# allocate (shared) memory for log importance weights
			logiws = shmarray.zeros([num_samples, X.shape[1]])

			# Monte Carlo estimation of unnormalized probability
			def parfor(i):
				samples = X

				for l in range(len(self.dbn) - 1):
					logiws[i, :] += self.dbn[l]._ulogprob_vis(samples).A[0]
					samples = self.dbn[l].forward(samples)
					logiws[i, :] -= self.dbn[l]._ulogprob_hid(samples).A[0]
				logiws[i, :] += self.dbn[-1]._ulogprob_vis(samples).A[0]

			# evaluate eagerly (a bare map() is lazy in Python 3)
			for i in range(num_samples):
				parfor(i)

			# averaging weights yields unnormalized probability
			ulogprob = utils.logmeanexp(np.asmatrix(logiws), 0)
			ubound = logiws.mean(0)

		else:
			ulogprob = self.dbn[0]._ulogprob_vis(X)
			ubound = ulogprob.copy()

		# return normalized log probability
		return (ulogprob - self.dbn[-1]._ais_logz, ubound - self.dbn[-1]._ais_logz)
Example #8
    def estimate_log_partition_function(self,
                                        num_ais_samples=100,
                                        beta_weights=[],
                                        layer=-1):
        """
		Estimate the log of the partition function.

		C{beta_weights} should be a list of monotonically increasing values ranging
		from 0 to 1. See Salakhutdinov & Murray (2008) for details on how to set
		the parameters.

		@type  num_ais_samples: integer
		@param num_ais_samples: number of samples used to estimate the partition
		function

		@type  beta_weights: array_like
		@param beta_weights: annealing weights ranging from zero to one

		@type  layer: integer
		@param layer: can be used to estimate the partition function of one
		of the lower layers

		@rtype:  real
		@return: the estimated log partition function
		"""

        bsbm = BaseBM(self.dbn[layer])
        mxbm = MixBM(bsbm, self.dbn[layer])

        # settings relevant only for SemiRBM
        bsbm.sampling_method = AbstractBM.GIBBS
        mxbm.sampling_method = AbstractBM.GIBBS
        mxbm.num_lateral_updates = 5

        # draw (independent) samples from the base model
        X = bsbm.sample(num_ais_samples, 0, 1)

        # compute importance weights
        logweights = bsbm._free_energy(X)

        for beta in beta_weights:
            mxbm.tune(beta)

            logweights -= mxbm._free_energy(X)
            Y = mxbm.forward(X)
            X = mxbm.backward(Y, X)
            logweights += mxbm._free_energy(X)

        logweights -= self.dbn[layer]._free_energy(X)

        # store results for later use
        self.dbn[layer]._ais_logweights = logweights + bsbm.logz
        self.dbn[layer]._ais_logz = utils.logmeanexp(logweights) + bsbm.logz
        self.dbn[layer]._ais_samples = X

        return self.dbn[layer]._ais_logz
Example #9
	def estimate_log_partition_function(self, num_ais_samples=100, beta_weights=[], layer=-1):
		"""
		Estimate the log of the partition function.

		C{beta_weights} should be a list of monotonically increasing values ranging
		from 0 to 1. See Salakhutdinov & Murray (2008) for details on how to set
		the parameters.

		@type  num_ais_samples: integer
		@param num_ais_samples: number of samples used to estimate the partition
		function

		@type  beta_weights: array_like
		@param beta_weights: annealing weights ranging from zero to one

		@type  layer: integer
		@param layer: can be used to estimate the partition function of one
		of the lower layers

		@rtype:  real
		@return: the estimated log partition function
		"""

		bsbm = BaseBM(self.dbn[layer])
		mxbm = MixBM(bsbm, self.dbn[layer])

		# settings relevant only for SemiRBM
		bsbm.sampling_method = AbstractBM.GIBBS
		mxbm.sampling_method = AbstractBM.GIBBS
		mxbm.num_lateral_updates = 5

		# draw (independent) samples from the base model
		X = bsbm.sample(num_ais_samples, 0, 1)

		# compute importance weights
		logweights = bsbm._free_energy(X)

		for beta in beta_weights:
			mxbm.tune(beta)

			logweights -= mxbm._free_energy(X)
			Y = mxbm.forward(X)
			X = mxbm.backward(Y, X)
			logweights += mxbm._free_energy(X)

		logweights -= self.dbn[layer]._free_energy(X)

		# store results for later use
		self.dbn[layer]._ais_logweights = logweights + bsbm.logz
		self.dbn[layer]._ais_logz = utils.logmeanexp(logweights) + bsbm.logz
		self.dbn[layer]._ais_samples = X

		return self.dbn[layer]._ais_logz
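With the default empty beta_weights the annealing loop above is skipped and the estimate reduces to simple importance sampling from the base model: since an RBM's free energy is the negative unnormalized log-probability of a visible state, logweights = log p~_target(x) - log p~_base(x) and log Z ~= logmeanexp(logweights) + log Z_base. A toy 1-D Gaussian version of that identity, where the exact partition function is known (not part of the DBN code):

import numpy as np

rng = np.random.default_rng(0)

# base model: unnormalized density exp(-x^2/2), log Z_base = 0.5*log(2*pi)
# target model: unnormalized density exp(-x^2/(2*sigma^2)), true log Z = 0.5*log(2*pi*sigma^2)
sigma = 2.0
logz_base = 0.5 * np.log(2 * np.pi)

x = rng.standard_normal(200_000)                       # samples from the normalized base model
logweights = -x**2 / (2 * sigma**2) + x**2 / 2         # log ptilde_target(x) - log ptilde_base(x)

lme = np.log(np.mean(np.exp(logweights - logweights.max()))) + logweights.max()
logz_est = lme + logz_base

print(logz_est, 0.5 * np.log(2 * np.pi * sigma**2))    # estimate vs. exact value

When base and target are far apart these one-step weights become heavy-tailed, which is exactly what the intermediate beta_weights distributions of annealed importance sampling (Salakhutdinov & Murray, 2008) are there to prevent.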
Example #10
File: vae.py Project: lxuechen/BDMC
    def forward(self, x, k=1, warmup_const=1.):
        x = x.repeat(k, 1)
        mu, logvar = self.encode(x)
        z, logpz, logqz = self.sample(mu, logvar)
        x_logits = self.decode(z)

        logpx = utils.log_bernoulli(x_logits, x)
        elbo = logpx + logpz - warmup_const * logqz

        # need correction for Tensor.repeat
        elbo = utils.logmeanexp(elbo.view(k, -1).transpose(0, 1))
        elbo = torch.mean(elbo)

        logpx = torch.mean(logpx)
        logpz = torch.mean(logpz)
        logqz = torch.mean(logqz)

        return elbo, logpx, logpz, logqz
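The view(k, -1).transpose(0, 1) correction undoes the earlier x.repeat(k, 1): repeat stacks k full copies of the batch, so reshaping to (k, batch) and transposing puts the k samples of each datapoint into one row before the log-mean-exp. A quick shape check (using torch.logsumexp explicitly rather than assuming the helper's default reduction dimension):

import math
import torch

k, batch, dim = 3, 4, 5
x = torch.arange(batch).float().unsqueeze(1).expand(batch, dim)  # row i is datapoint i

x_rep = x.repeat(k, 1)            # (k * batch, dim): the whole batch stacked k times
elbo = x_rep.mean(dim=1)          # stand-in per-sample ELBO, shape (k * batch,)

grouped = elbo.view(k, -1).transpose(0, 1)   # (batch, k): row i holds the k samples of datapoint i
print(grouped)                               # rows are constant: [[0,0,0],[1,1,1],[2,2,2],[3,3,3]]

iwae = torch.logsumexp(grouped, dim=1) - math.log(k)   # log-mean-exp over the k samples per datapoint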
Example #11
def train_model(net,
                optimizer,
                criterion,
                trainloader,
                num_ens=1,
                beta_type=0.1):
    net.train()
    training_loss = 0.0
    accs = []
    kl_list = []
    freq = cfg.recording_freq_per_epoch
    freq = len(trainloader) // freq
    for i, (inputs, labels) in enumerate(trainloader, 1):
        cfg.curr_batch_no = i
        if i % freq == 0:
            cfg.record_now = True
        else:
            cfg.record_now = False

        optimizer.zero_grad()

        inputs, labels = inputs.to(device), labels.to(device)
        outputs = torch.zeros(inputs.shape[0], net.num_classes,
                              num_ens).to(device)

        kl = 0.0
        for j in range(num_ens):
            net_out, _kl = net(inputs)
            kl += _kl
            outputs[:, :, j] = F.log_softmax(net_out, dim=1)

        kl = kl / num_ens
        kl_list.append(kl.item())
        log_outputs = utils.logmeanexp(outputs, dim=2)

        beta = metrics.get_beta(i - 1, len(trainloader), beta_type)
        loss = criterion(log_outputs, labels, kl, beta)
        loss.backward()
        optimizer.step()

        accs.append(metrics.acc(log_outputs.data, labels))
        training_loss += loss.cpu().data.numpy()
    return training_loss / len(trainloader), np.mean(accs), np.mean(kl_list)
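metrics.get_beta decides how much of the KL term each minibatch carries. Its exact behaviour is not shown on this page; a hedged sketch of the usual options (a constant weight, and the per-batch weighting of Blundell et al., which also reappears verbatim in the fit() method of a later example) might look like:

def get_beta(batch_idx, m, beta_type, epoch=None, num_epochs=None):
    """Weight for the KL term of minibatch `batch_idx` out of `m` batches.

    Sketch of the common schemes, not the exact project code.
    """
    if beta_type == "Blundell":
        # "Weight Uncertainty in Neural Networks": pi_i = 2^(M-i) / (2^M - 1)
        return 2 ** (m - (batch_idx + 1)) / (2 ** m - 1)
    if beta_type == "Soenderby" and epoch is not None and num_epochs is not None:
        # linear warm-up of the KL term over the first quarter of training
        return min(epoch / (num_epochs // 4), 1.0)
    if beta_type == "Standard":
        return 1.0 / m
    return beta_type  # a numeric beta_type (e.g. 0.1) acts as a constant weight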
Example #12
def test_model(net, criterion, testloader, num_ens=10):
    """Calculate ensemble accuracy and NLL Loss"""
    net.eval()
    test_loss = 0.0
    accs = []

    for i, (inputs, labels) in enumerate(testloader):
        inputs, labels = inputs.to(device), labels.to(device)
        outputs = torch.zeros(inputs.shape[0], net.num_classes,
                              num_ens).to(device)
        kl = 0.0
        for j in range(num_ens):
            net_out, _kl = net(inputs)
            kl += _kl
            outputs[:, :, j] = F.log_softmax(net_out, dim=1).data

        log_outputs = utils.logmeanexp(outputs, dim=2)
        test_loss += criterion(log_outputs, labels, kl).item()
        accs.append(metrics.acc(log_outputs, labels))

    return test_loss / len(testloader), np.mean(accs)
Example #13
def predict_regular(net, validloader, bayesian=True, num_ens=10):
    """
    For both Bayesian and Frequentist models
    """
    net.eval()
    accs = []

    for i, (inputs, labels) in enumerate(validloader):
        inputs, labels = inputs.to(device), labels.to(device)
        if bayesian:
            outputs = torch.zeros(inputs.shape[0], net.num_classes,
                                  num_ens).to(device)
            for j in range(num_ens):
                net_out, _ = net(inputs)
                outputs[:, :, j] = F.log_softmax(net_out, dim=1).data

            log_outputs = utils.logmeanexp(outputs, dim=2)
            accs.append(metrics.acc(log_outputs, labels))
        else:
            output = net(inputs)
            accs.append(metrics.acc(output.detach(), labels))

    return np.mean(accs)
Example #14
File: ais.py Project: lxuechen/BDMC
def ais_trajectory(
    model,
    loader,
    forward: bool,
    schedule: Union[torch.Tensor, List],
    n_sample: Optional[int] = 100,
    initial_step_size: Optional[float] = 0.01,
    device: Optional[torch.device] = None,
):
    """Compute annealed importance sampling trajectories for a batch of data.

    Could be used for *both* forward and reverse chain in BDMC.

    Args:
      model (vae.VAE): VAE model
      loader (iterator): iterator that returns pairs, with first component
        being `x`, second would be `z` or label (will not be used)
      forward: indicate forward/backward chain
      schedule: temperature schedule, i.e. `p(z)p(x|z)^t`
      n_sample: number of importance samples
      device: device to run all computation on
      initial_step_size: initial step size for leap-frog integration;
        the actual step size is adapted online based on accept-reject ratios

    Returns:
        a list where each element is a torch.Tensor that contains the
        log importance weights for a single batch of data
    """
    def log_f_i(z, data, t, log_likelihood_fn=utils.log_bernoulli):
        """Unnormalized density for intermediate distribution `f_i`:
            f_i = p(z)^(1-t) p(x,z)^(t) = p(z) p(x|z)^t
        =>  log f_i = log p(z) + t * log p(x|z)
        """
        zeros = torch.zeros_like(z)
        log_prior = utils.log_normal(z, zeros, zeros)
        log_likelihood = log_likelihood_fn(model.decode(z), data)

        return log_prior + log_likelihood.mul_(t)

    logws = []
    for i, (batch, post_z) in enumerate(loader):
        B = batch.size(0) * n_sample
        batch = batch.to(device)
        batch = utils.safe_repeat(batch, n_sample)

        epsilon = torch.full(size=(B, ),
                             device=device,
                             fill_value=initial_step_size)
        accept_hist = torch.zeros(size=(B, ), device=device)
        logw = torch.zeros(size=(B, ), device=device)

        # initial sample of z
        if forward:
            current_z = torch.randn(size=(B, model.latent_dim), device=device)
        else:
            current_z = utils.safe_repeat(post_z, n_sample).to(device)

        for j, (t0, t1) in tqdm(enumerate(zip(schedule[:-1], schedule[1:]),
                                          1)):
            # update log importance weight
            log_int_1 = log_f_i(current_z, batch, t0)
            log_int_2 = log_f_i(current_z, batch, t1)
            logw += log_int_2 - log_int_1

            def U(z):
                return -log_f_i(z, batch, t1)

            @torch.enable_grad()
            def grad_U(z):
                z = z.clone().requires_grad_(True)
                grad, = torch.autograd.grad(U(z).sum(), z)
                max_ = B * model.latent_dim * 100.
                grad = torch.clamp(grad, -max_, max_)
                return grad

            def normalized_kinetic(v):
                zeros = torch.zeros_like(v)
                return -utils.log_normal(v, zeros, zeros)

            # resample velocity
            current_v = torch.randn_like(current_z)
            z, v = hmc.hmc_trajectory(current_z, current_v, grad_U, epsilon)
            current_z, epsilon, accept_hist = hmc.accept_reject(
                current_z,
                current_v,
                z,
                v,
                epsilon,
                accept_hist,
                j,
                U=U,
                K=normalized_kinetic,
            )

        logw = utils.logmeanexp(logw.view(n_sample, -1).transpose(0, 1))
        if not forward:
            logw = -logw
        logws.append(logw)
        print('Last batch stats %.4f' % (logw.mean().cpu().item()))

    return logws
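Run forward (prior towards posterior), the averaged log weights returned above are stochastic lower bounds on log p(x); run in reverse from posterior samples, the sign flip turns them into stochastic upper bounds, and the gap between the two is the usual BDMC diagnostic. A minimal sketch of combining the two runs (the variable names are placeholders, not part of the project's API):

import torch

# forward_logws / backward_logws: lists of per-batch tensors as returned by
# ais_trajectory(..., forward=True) and ais_trajectory(..., forward=False)
def bdmc_gap(forward_logws, backward_logws):
    lower = torch.cat(forward_logws).mean()    # stochastic lower bound on E[log p(x)]
    upper = torch.cat(backward_logws).mean()   # stochastic upper bound (sign already flipped)
    return lower.item(), upper.item(), (upper - lower).item()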
Example #15
    def fit(self,
            trainloader,
            validloader,
            num_train_ensemble=1,
            num_valid_ensemble=1,
            scheduler=None,
            num_epochs=1000):
        # . . set the scheduler
        if scheduler is not None:
            self.scheduler = scheduler

        # . . logs
        logs = {}

        # . . number of batches
        num_train_batch = len(trainloader)
        num_valid_batch = len(validloader)

        # . . register num batches to logs
        logs['num_train_batch'] = num_train_batch
        logs['num_valid_batch'] = num_valid_batch

        # . .
        # . . set the callback handler
        callback_handler = CallbackHandler(self.callbacks)

        # . . keep track of the losses
        train_losses = []
        valid_losses = []
        kldiv_losses = []

        # . . call the callback function on_train_begin(): load the best model
        callback_handler.on_train_begin(logs=logs, model=self.model)

        for epoch in range(num_epochs):
            # . . call the callback functions on_epoch_begin()
            callback_handler.on_epoch_begin(epoch, logs=logs, model=self.model)

            train_loss = 0.
            valid_loss = 0.
            kldiv_loss = 0.

            # . . activate the training mode
            self.model.train()

            # . . the training and validation accuracy
            train_accuracy = []
            valid_accuracy = []
            # . . get the next batch of training data
            for batch, (inputs, targets) in enumerate(trainloader):

                # . . the training loss for the current batch
                batch_loss = 0.

                # . . send the batch to GPU
                inputs, targets = inputs.to(self.device), targets.to(
                    self.device)

                # . . get the batch size
                batch_size = inputs.size(0)

                # . . prepare the outputs for multiple ensembles
                outputs = torch.zeros(inputs.shape[0], self.model.num_classes,
                                      num_train_ensemble).to(self.device)

                # . . zero the parameter gradients
                self.optimizer.zero_grad()

                # . . feed-forward network: multiple ensembles
                kl_div = 0.0
                for ens in range(num_train_ensemble):
                    outputs_, kl_div_ = self.model(inputs)
                    # . . accumulate the kl div loss
                    kl_div += kl_div_
                    # . . keep the outputs
                    outputs[:, :, ens] = F.log_softmax(outputs_, dim=1)

                # . . normalise the kl div loss over ensembles
                kl_div /= num_train_ensemble

                # . . average the ensemble predictions in log space
                log_outputs = utils.logmeanexp(outputs, dim=2)

                # . . compute the beta for the kl div loss
                beta_scl = 0.01
                beta = 2**(num_train_batch -
                           (batch + 1)) / (2**num_train_batch - 1)
                beta *= beta_scl

                # . . calculate the loss function
                loss = self.criterion(log_outputs, targets, kl_div, beta,
                                      batch_size)

                # . . backpropagate the scaled loss
                # loss.backward()
                with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                    scaled_loss.backward()

                # . . update weights
                self.optimizer.step()

                # . . training loss for the current batch
                batch_loss += loss.item()

                # . . accumulate the training loss
                train_loss += loss.item()

                # . . accumulate the KL divergence loss
                kldiv_loss += kl_div.item()

                # . . register the batch training loss
                logs['batch_loss'] = batch_loss

                # . . compute the accuracy
                train_accuracy.append(utils.accuracy(log_outputs, targets))

                # . . call the callback functions on_epoch_end()
                callback_handler.on_batch_end(batch,
                                              logs=logs,
                                              model=self.model)

            # . . activate the evaluation (validation) mode
            self.model.eval()
            # . . turn off the gradient for performance
            with torch.set_grad_enabled(False):
                # . . get the next batch of validation data
                for batch, (inputs, targets) in enumerate(validloader):

                    # . . send the batch to GPU
                    inputs, targets = inputs.to(self.device), targets.to(
                        self.device)

                    # . . get the batch size
                    batch_size = inputs.size(0)

                    # . . prepare the outputs for multiple ensembles
                    outputs = torch.zeros(inputs.shape[0],
                                          self.model.num_classes,
                                          num_valid_ensemble).to(self.device)

                    # . . feed-forward network: multiple ensembles
                    kl_div = 0.0
                    for ens in range(num_valid_ensemble):
                        outputs_, kl_div_ = self.model(inputs)
                        # . . accumulate the kl div loss
                        kl_div += kl_div_
                        # . . keep the outputs
                        outputs[:, :, ens] = F.log_softmax(outputs_, dim=1)

                    # . . normalise the kl div loss over ensembles
                    kl_div /= num_valid_ensemble

                    # . . average the ensemble predictions in log space
                    log_outputs = utils.logmeanexp(outputs, dim=2)

                    # . . compute the beta for the kl div loss
                    beta_scl = 0.01
                    beta = 2**(num_valid_batch -
                               (batch + 1)) / (2**num_valid_batch - 1)
                    beta *= beta_scl

                    # . . calculate the loss function
                    loss = self.criterion(log_outputs, targets, kl_div, beta,
                                          batch_size)

                    # . . accumulate the validation loss
                    valid_loss += loss.item()

                    # . . compute the accuracy
                    valid_accuracy.append(utils.accuracy(log_outputs, targets))

            # . . call the learning-rate scheduler
            if self.scheduler is not None:
                self.scheduler.step(valid_loss)

            # . . normalize the training and validation losses
            train_loss /= num_train_batch
            valid_loss /= num_valid_batch

            # . . compute the mean accuracy
            logs['train_acc'] = np.mean(train_accuracy)
            logs['valid_acc'] = np.mean(valid_accuracy)

            # . . on epoch end
            train_losses.append(train_loss)
            valid_losses.append(valid_loss)
            kldiv_losses.append(kldiv_loss)

            # . . update the epoch statistics (logs)
            logs["train_loss"] = train_loss
            logs["valid_loss"] = valid_loss
            logs["kldiv_loss"] = kldiv_loss * beta

            # . . call the callback functions on_epoch_end()
            callback_handler.on_epoch_end(epoch, logs=logs, model=self.model)

            # . . check if the training should continue
            if self.model._stop_training:
                break

        # . . call the callback function on_train_end(): load the best model
        callback_handler.on_train_end(logs=logs, model=self.model)

        return train_losses, valid_losses
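The per-batch factor 2**(num_train_batch - (batch + 1)) / (2**num_train_batch - 1) used above (before the extra beta_scl scaling) is the minibatch weighting of Blundell et al.; over one epoch these weights sum to one, so the full KL divergence is counted exactly once per epoch. A two-line check:

num_batches = 10
weights = [2 ** (num_batches - (b + 1)) / (2 ** num_batches - 1) for b in range(num_batches)]
print(sum(weights))   # 1.0 up to floating-point rounding (geometric series)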
Example #16
    def evaluate(self, trainloader, testloader, num_eval_ensemble=1):
        # . . activate the validation (evaluation) mode
        self.model.eval()

        # . . training accuracy
        num_correct = 0
        num_predictions = 0

        # . . iterate over batches
        for inputs, targets in trainloader:
            # . . move to device
            inputs, targets = inputs.to(self.device), targets.to(self.device)

            # . . prepare the outputs for multiple ensembles
            outputs = torch.zeros(inputs.shape[0], self.model.num_classes,
                                  num_eval_ensemble).to(self.device)

            # . . feed-forward network: multiple ensembles
            kl_div = 0.0
            for ens in range(num_eval_ensemble):
                outputs_, kl_div_ = self.model(inputs)
                # . . accumulate the kl div loss
                kl_div += kl_div_
                # . . keep the outputs
                outputs[:, :, ens] = F.log_softmax(outputs_, dim=1)
                #outputs[:,:,ens] = F.softmax(outputs_, dim=1)

            # . . normalise the kl div loss over ensembles
            kl_div /= num_eval_ensemble

            # . . average the ensemble predictions in log space
            log_outputs = utils.logmeanexp(outputs, dim=2)
            #log_outputs = torch.mean(outputs, dim=2)

            # . . network predictions
            _, predictions = torch.max(log_outputs, 1)

            # . . update statistics
            # . . number of correct predictions
            num_correct += (predictions == targets).sum().item()
            # . . number of predictions
            num_predictions += targets.shape[0]

        # . . compute the training accuracy
        training_accuracy = num_correct / num_predictions

        # . . test accuracy: preferably not computed on the validation dataset
        num_correct = 0
        num_predictions = 0

        # . . iterate over batches
        for inputs, targets in testloader:
            # . . move to device
            inputs, targets = inputs.to(self.device), targets.to(self.device)

            # . . prepare the outputs for multiple ensembles
            outputs = torch.zeros(inputs.shape[0], self.model.num_classes,
                                  num_eval_ensemble).to(self.device)

            # . . feed-forward network: multiple ensembles
            kl_div = 0.0
            for ens in range(num_eval_ensemble):
                outputs_, kl_div_ = self.model(inputs)
                # . . accumulate the kl div loss
                kl_div += kl_div_
                # . . keep the outputs
                outputs[:, :, ens] = F.log_softmax(outputs_, dim=1)
                #outputs[:,:,ens] = F.softmax(outputs_, dim=1)

            # . . normalise the kl div loss over ensembles
            kl_div /= num_eval_ensemble

            # . . average the ensemble predictions in log space
            log_outputs = utils.logmeanexp(outputs, dim=2)
            #log_outputs = torch.mean(outputs, dim=2)

            # . . network predictions
            _, predictions = torch.max(log_outputs, 1)

            # . . update statistics
            # . . number of correct predictions
            num_correct += (predictions == targets).sum().item()
            # . . number of predictions
            num_predictions += targets.shape[0]

        # . . compute the test accuracy
        test_accuracy = num_correct / num_predictions

        # . . INFO
        print(
            f"Training accuracy: {training_accuracy:.4f}, Test accuracy: {test_accuracy:.4f}"
        )

        return training_accuracy, test_accuracy