class TrainerVaDE: """This is the trainer for the Variational Deep Embedding (VaDE). """ def __init__(self, args, device, dataloader_sup, dataloader_unsup, dataloader_test, n_classes): if args.dataset == 'mnist': from models import Autoencoder, VaDE self.autoencoder = Autoencoder().to(device) self.autoencoder.apply(weights_init_normal) self.VaDE = VaDE().to(device) elif args.dataset == 'webcam': from models_office import Autoencoder, VaDE, feature_extractor self.autoencoder = Autoencoder().to(device) checkpoint = torch.load('weights/imagenet_params.pth.tar', map_location=device) self.autoencoder.load_state_dict(checkpoint['state_dict'], strict=False) checkpoint = torch.load('weights/feature_extractor_params.pth.tar', map_location=device) self.feature_extractor = feature_extractor().to(device) self.feature_extractor.load_state_dict(checkpoint['state_dict']) self.freeze_extractor() self.VaDE = VaDE().to(device) self.dataloader_sup = dataloader_sup self.dataloader_unsup = dataloader_unsup self.dataloader_test = dataloader_test self.device = device self.args = args self.n_classes = n_classes def pretrain(self): """Here we train an stacked autoencoder which will be used as the initialization for the VaDE. This initialization is usefull because reconstruction in VAEs would be weak at the begining and the models are likely to get stuck in local minima. """ optimizer = optim.Adam(self.autoencoder.parameters(), lr=self.args.lr_ae) self.autoencoder.train() print('Training the autoencoder...') for epoch in range(1500): total_loss = 0 for x, _ in self.dataloader_unsup: optimizer.zero_grad() x = x.to(self.device) if self.args.dataset == 'webcam': x = self.feature_extractor(x) x = x.detach() x_hat = self.autoencoder(x) loss = F.binary_cross_entropy( x_hat, x, reduction='mean') #reconstruction error loss.backward() optimizer.step() total_loss += loss.item() print('Training Autoencoder... Epoch: {}, Loss: {}'.format( epoch, total_loss / len(self.dataloader_unsup))) self.save_weights_ae() #self.train_GMM() #training a GMM for initialize the VaDE #self.predict_GMM() #Predict and assign supervised points to its GMM components #self.save_weights_for_VaDE() #saving weights for the VaDE def train_GMM(self): """It is possible to fit a Gaussian Mixture Model (GMM) using the latent space generated by the stacked autoencoder. This way, we generate an initialization for the priors (pi, mu, var) of the VaDE model. """ print('Fiting Gaussian Mixture Model...') x = torch.cat([data[0] for data in self.dataloader_unsup ]).to(self.device) #all x samples. if self.args.dataset == 'webcam': x = self.feature_extractor(x) x = x.detach() z = self.autoencoder.encode(x) self.gmm = GaussianMixture(n_components=self.n_classes, covariance_type='diag') self.gmm.fit(z.cpu().detach().numpy()) def predict_GMM(self): """It is possible to fit a Gaussian Mixture Model (GMM) using the latent space generated by the stacked autoencoder. This way, we generate an initialization for the priors (pi, mu, var) of the VaDE model. """ print('Predicting over Gaussian Mixture Model...') x, y = torch.cat([(data[0], data[1]) for data in self.dataloader_sup ]).to(self.device) #all x samples. 
        x = x[np.argsort(y.cpu().detach().numpy())]
        if self.args.dataset == 'webcam':
            x = self.feature_extractor(x)
            x = x.detach()
        z = self.autoencoder.encode(x)
        probas = self.gmm.predict_proba(z.cpu().detach().numpy())
        self.assign_GMMS(probas)

    def assign_GMMS(self, probas):
        """Greedily match each class to a distinct GMM component, starting
        from the component with the highest responsibility."""
        assignation = []
        possibilities = np.arange(self.n_classes)
        index = 0
        toselect = 1
        while len(possibilities) > 0:
            sorted_ = np.argsort(probas[index])
            max_ = sorted_[-toselect]
            if max_ in possibilities:
                assignation.append(max_)
                possibilities = np.setdiff1d(possibilities, max_)
                index += 1
                toselect = 1
            else:
                toselect += 1
        self.assignation = assignation

    def save_weights_for_VaDE(self):
        """Save the pretrained weights for the encoder, decoder, pi, mu, var."""
        print('Saving weights.')
        state_dict = self.autoencoder.state_dict()
        self.VaDE.load_state_dict(state_dict, strict=False)
        self.VaDE.pi_prior.data = torch.from_numpy(
            self.gmm.weights_[self.assignation]).float().to(self.device)
        self.VaDE.mu_prior.data = torch.from_numpy(
            self.gmm.means_[self.assignation]).float().to(self.device)
        self.VaDE.log_var_prior.data = torch.log(
            torch.from_numpy(
                self.gmm.covariances_[self.assignation])).float().to(self.device)
        torch.save(
            self.VaDE.state_dict(),
            'weights/pretrained_parameters_{}.pth'.format(self.args.dataset))

    def save_weights_ae(self):
        """Save the pretrained autoencoder weights."""
        print('Saving weights.')
        state = {'state_dict': self.autoencoder.state_dict()}
        torch.save(
            state,
            'weights/autoencoder_parameters_{}.pth.tar'.format(
                self.args.dataset))

    def save_weights_vade(self):
        """Save the current VaDE weights."""
        print('Saving weights.')
        state = {'state_dict': self.VaDE.state_dict()}
        torch.save(
            state,
            'weights/vade_parameters_{}.pth.tar'.format(self.args.dataset))

    def train(self):
        """Train the VaDE model, optionally starting from the pretrained
        autoencoder/GMM initialization."""
        if self.args.pretrain:
            self.VaDE.load_state_dict(
                torch.load('weights/pretrained_parameters_{}.pth'.format(
                    self.args.dataset),
                           map_location=self.device))
        else:
            self.VaDE.apply(weights_init_normal)
        self.optimizer = optim.Adam(self.VaDE.parameters(), lr=self.args.lr)
        lr_scheduler = torch.optim.lr_scheduler.StepLR(self.optimizer,
                                                       step_size=10,
                                                       gamma=0.9)
        self.acc = []
        self.acc_t = []
        self.rec = []
        self.rec_t = []
        self.dkl = []
        self.dkl_t = []
        self.forward_step = ComputeLosses(self.VaDE, self.args)
        print('Training VaDE...')
        self.test_VaDE(-1)
        for epoch in range(self.args.epochs):
            self.train_VaDE(epoch)
            self.test_VaDE(epoch)
            lr_scheduler.step()
            self.save_weights_vade()

    def train_VaDE(self, epoch):
        self.VaDE.train()
        total_loss = 0
        total_dkl = 0
        total_rec = 0
        for (x_s, y_s), (x_u, _) in zip(cycle(self.dataloader_sup),
                                        self.dataloader_unsup):
            self.optimizer.zero_grad()
            x_s, y_s = x_s.to(self.device), y_s.to(self.device)
            x_u = x_u.to(self.device)
            if self.args.dataset == 'webcam':
                x_s = self.feature_extractor(x_s)
                x_s = x_s.detach()
                x_u = self.feature_extractor(x_u)
                x_u = x_u.detach()
            loss, reconst_loss, kl_div, acc = self.forward_step.forward(
                'train', x_s, y_s, x_u)
            loss.backward()
            self.optimizer.step()
            total_loss += loss.item()
            total_dkl += kl_div.item()
            total_rec += reconst_loss.item()
        self.acc.append(acc)
        self.dkl.append(total_dkl / len(self.dataloader_unsup))
        self.rec.append(total_rec / len(self.dataloader_unsup))
        print('Training VaDE... Epoch: {}, Loss: {}, Acc: {}'.format(
            epoch, total_loss / len(self.dataloader_unsup), acc))
    def test_VaDE(self, epoch):
        self.VaDE.eval()
        with torch.no_grad():
            total_loss = 0
            total_acc = 0
            total_dkl = 0
            total_rec = 0
            for x, y in self.dataloader_test:
                x, y = x.to(self.device), y.to(self.device)
                if self.args.dataset == 'webcam':
                    x = self.feature_extractor(x)
                    x = x.detach()
                loss, reconst_loss, kl_div, acc = self.forward_step.forward(
                    'test', x, y)
                total_loss += loss.item()
                total_acc += acc.item()
                total_dkl += kl_div.item()
                total_rec += reconst_loss.item()
            self.acc_t.append(total_acc / len(self.dataloader_test))
            self.dkl_t.append(total_dkl / len(self.dataloader_test))
            self.rec_t.append(total_rec / len(self.dataloader_test))
            print('Testing VaDE... Epoch: {}, Loss: {}, Acc: {}'.format(
                epoch, total_loss / len(self.dataloader_test),
                total_acc / len(self.dataloader_test)))

    def freeze_extractor(self):
        for _, param in self.feature_extractor.named_parameters():
            param.requires_grad = False
        self.feature_extractor.eval()
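
# Usage sketch (illustrative only, not part of the original training script):
# the method names below exist in the class above, but the `args` namespace,
# the dataloaders and `n_classes` are placeholders you would build yourself
# (args is expected to provide dataset, lr, lr_ae, epochs and pretrain).
#
#   trainer = TrainerVaDE(args, device, dataloader_sup, dataloader_unsup,
#                         dataloader_test, n_classes)
#   trainer.pretrain()               # stacked-autoencoder pretraining
#   trainer.train_GMM()              # fit a GMM on the autoencoder latent space
#   trainer.predict_GMM()            # match GMM components to labelled classes
#   trainer.save_weights_for_VaDE()  # copy AE weights and GMM priors into VaDE
#   trainer.train()                  # joint VaDE training and evaluation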
class TrainerVaDE: """This is the trainer for the Variational Deep Embedding (VaDE). """ def __init__(self, args, device, dataloader): self.autoencoder = Autoencoder().to(device) self.VaDE = VaDE().to(device) self.dataloader = dataloader self.device = device self.args = args def pretrain(self): """Here we train an stacked autoencoder which will be used as the initialization for the VaDE. This initialization is usefull because reconstruction in VAEs would be weak at the begining and the models are likely to get stuck in local minima. """ optimizer = optim.Adam(self.autoencoder.parameters(), lr=0.002) self.autoencoder.apply( weights_init_normal ) #intializing weights using normal distribution. self.autoencoder.train() print('Training the autoencoder...') for epoch in range(30): total_loss = 0 for x, _ in self.dataloader: optimizer.zero_grad() x = x.to(self.device) x_hat = self.autoencoder(x) loss = F.binary_cross_entropy( x_hat, x, reduction='mean') # just reconstruction loss.backward() optimizer.step() total_loss += loss.item() print('Training Autoencoder... Epoch: {}, Loss: {}'.format( epoch, total_loss)) self.train_GMM() #training a GMM for initialize the VaDE self.save_weights_for_VaDE() #saving weights for the VaDE def train_GMM(self): """It is possible to fit a Gaussian Mixture Model (GMM) using the latent space generated by the stacked autoencoder. This way, we generate an initialization for the priors (pi, mu, var) of the VaDE model. """ print('Fiting Gaussian Mixture Model...') x = torch.cat([data[0] for data in self.dataloader ]).view(-1, 784).to(self.device) #all x samples. z = self.autoencoder.encode(x) self.gmm = GaussianMixture(n_components=10, covariance_type='diag') self.gmm.fit(z.cpu().detach().numpy()) def save_weights_for_VaDE(self): """Saving the pretrained weights for the encoder, decoder, pi, mu, var. """ print('Saving weights.') state_dict = self.autoencoder.state_dict() self.VaDE.load_state_dict(state_dict, strict=False) self.VaDE.pi_prior.data = torch.from_numpy( self.gmm.weights_).float().to(self.device) self.VaDE.mu_prior.data = torch.from_numpy(self.gmm.means_).float().to( self.device) self.VaDE.log_var_prior.data = torch.log( torch.from_numpy(self.gmm.covariances_)).float().to(self.device) torch.save(self.VaDE.state_dict(), self.args.pretrained_path) def train(self): """ """ if self.args.pretrain == True: self.VaDE.load_state_dict( torch.load(self.args.pretrained_path, map_location=self.device)) else: self.VaDE.apply(weights_init_normal) self.optimizer = optim.Adam(self.VaDE.parameters(), lr=self.args.lr) lr_scheduler = torch.optim.lr_scheduler.StepLR(self.optimizer, step_size=10, gamma=0.9) print('Training VaDE...') for epoch in range(self.args.epochs): self.train_VaDE(epoch) self.test_VaDE(epoch) lr_scheduler.step() def train_VaDE(self, epoch): self.VaDE.train() total_loss = 0 for x, _ in self.dataloader: self.optimizer.zero_grad() x = x.to(self.device) x_hat, mu, log_var, z = self.VaDE(x) #print('Before backward: {}'.format(self.VaDE.pi_prior)) loss = self.compute_loss(x, x_hat, mu, log_var, z) loss.backward() self.optimizer.step() total_loss += loss.item() #print('After backward: {}'.format(self.VaDE.pi_prior)) print('Training VaDE... 
    def test_VaDE(self, epoch):
        self.VaDE.eval()
        with torch.no_grad():
            total_loss = 0
            y_true, y_pred = [], []
            for x, true in self.dataloader:
                x = x.to(self.device)
                x_hat, mu, log_var, z = self.VaDE(x)
                gamma = self.compute_gamma(z, self.VaDE.pi_prior)
                pred = torch.argmax(gamma, dim=1)
                loss = self.compute_loss(x, x_hat, mu, log_var, z)
                total_loss += loss.item()
                y_true.extend(true.numpy())
                y_pred.extend(pred.cpu().detach().numpy())
            acc = self.cluster_acc(np.array(y_true), np.array(y_pred))
            print('Testing VaDE... Epoch: {}, Loss: {}, Acc: {}'.format(
                epoch, total_loss, acc[0]))

    def compute_loss(self, x, x_hat, mu, log_var, z):
        p_c = self.VaDE.pi_prior
        gamma = self.compute_gamma(z, p_c)
        log_p_x_given_z = F.binary_cross_entropy(x_hat, x, reduction='sum')
        h = log_var.exp().unsqueeze(1) + (mu.unsqueeze(1) -
                                          self.VaDE.mu_prior).pow(2)
        h = torch.sum(self.VaDE.log_var_prior +
                      h / self.VaDE.log_var_prior.exp(), dim=2)
        log_p_z_given_c = 0.5 * torch.sum(gamma * h)
        log_p_c = torch.sum(gamma * torch.log(p_c + 1e-9))
        log_q_c_given_x = torch.sum(gamma * torch.log(gamma + 1e-9))
        log_q_z_given_x = 0.5 * torch.sum(1 + log_var)
        loss = (log_p_x_given_z + log_p_z_given_c - log_p_c +
                log_q_c_given_x - log_q_z_given_x)
        loss /= x.size(0)
        return loss

    def compute_gamma(self, z, p_c):
        h = (z.unsqueeze(1) -
             self.VaDE.mu_prior).pow(2) / self.VaDE.log_var_prior.exp()
        h += self.VaDE.log_var_prior
        h += torch.Tensor([np.log(np.pi * 2)]).to(self.device)
        p_z_c = torch.exp(
            torch.log(p_c + 1e-9).unsqueeze(0) -
            0.5 * torch.sum(h, dim=2)) + 1e-9
        gamma = p_z_c / torch.sum(p_z_c, dim=1, keepdim=True)
        return gamma

    def cluster_acc(self, real, pred):
        D = max(pred.max(), real.max()) + 1
        w = np.zeros((D, D), dtype=np.int64)
        for i in range(pred.size):
            w[pred[i], real[i]] += 1
        ind = linear_assignment(w.max() - w)
        return sum([w[i, j] for i, j in ind]) * 1.0 / pred.size * 100, w
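
# Usage sketch (illustrative only): the unsupervised MNIST variant above only
# needs a single dataloader; the fields it reads from `args` are lr, epochs,
# pretrain and pretrained_path. Everything named below is a placeholder.
#
#   trainer = TrainerVaDE(args, device, dataloader)
#   trainer.pretrain()   # autoencoder pretraining, GMM fit, weight transfer
#   trainer.train()      # VaDE training with the ELBO from compute_loss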