from itertools import cycle

import torch
import torch.nn.functional as F


def train_VAE(self, max_epochs, train_dataloader, validation_dataloader):
    early_stopping = EarlyStopping('{}/{}_autoencoder.pt'.format(
        self.state_path, self.model_name), patience=10)

    for epoch in range(max_epochs):
        if early_stopping.early_stop:
            break

        train_loss = 0
        for batch_idx, (data, _) in enumerate(train_dataloader):
            self.VAE.train()
            data = data.to(self.device)

            self.VAE_optim.zero_grad()

            batch_params = self.VAE(data)

            loss = self.VAE_criterion(batch_params, data)
            train_loss += loss.item()

            loss.backward()
            self.VAE_optim.step()

        if validation_dataloader is not None:
            validation_loss = self.unsupervised_validation_loss(
                validation_dataloader)

            # Early stopping monitors the validation reconstruction loss.
            early_stopping(validation_loss, self.VAE)

    # Restore the best weights seen during training.
    if validation_dataloader is not None:
        early_stopping.load_checkpoint(self.VAE)
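# EarlyStopping is used by every routine in this section but is not
# defined here. A minimal sketch consistent with its usage above
# (construction with a checkpoint path and optional patience,
# __call__(metric, model) once per epoch where lower is better, an
# early_stop flag, and load_checkpoint(model)). The default patience
# is an assumption, and the real implementation may differ:
class EarlyStopping:
    def __init__(self, path, patience=7):
        self.path = path
        self.patience = patience
        self.counter = 0
        self.best_score = None
        self.early_stop = False

    def __call__(self, metric, model):
        # metric is a quantity to minimise: a validation loss,
        # or 1 - accuracy.
        if self.best_score is None or metric < self.best_score:
            self.best_score = metric
            self.counter = 0
            torch.save(model.state_dict(), self.path)
        else:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True

    def load_checkpoint(self, model):
        model.load_state_dict(torch.load(self.path))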
def train_classifier(self, max_epochs, train_dataloader,
                     validation_dataloader, comparison):
    epochs = []
    train_losses = []
    validation_accs = []

    early_stopping = EarlyStopping('{}/{}_classifier.pt'.format(
        self.model_name, self.dataset_name))

    for epoch in range(max_epochs):
        if early_stopping.early_stop:
            break

        for batch_idx, (data, labels) in enumerate(train_dataloader):
            self.Classifier.train()
            data = data.float().to(self.device)
            labels = labels.to(self.device)

            self.Classifier_optim.zero_grad()

            # The encoder is frozen: features are extracted without
            # gradient tracking, so only the classifier is updated.
            with torch.no_grad():
                z = self.Encoder(data)

            pred = self.Classifier(z)

            loss = self.Classifier_criterion(pred, labels)

            loss.backward()
            self.Classifier_optim.step()

        val = self.accuracy(validation_dataloader)

        if comparison:
            epochs.append(epoch)
            # Note: this records the loss of the last batch only.
            train_losses.append(loss.item())
            validation_accs.append(val)

        print('Supervised Epoch: {} Validation acc: {}'.format(epoch, val))

        early_stopping(1 - val, self.Classifier)

    # Restore the best weights even when training ran the full epoch
    # budget without triggering early stopping.
    early_stopping.load_checkpoint(self.Classifier)

    return epochs, train_losses, validation_accs
def train_classifier(self, max_epochs, train_dataloader, validation_dataloader):
    epochs = []
    train_losses = []
    validation_accs = []

    early_stopping = EarlyStopping('{}/{}_classifier.pt'.format(
        self.state_path, self.model_name))

    for epoch in range(max_epochs):
        if early_stopping.early_stop:
            break

        train_loss = 0
        for batch_idx, (data, labels) in enumerate(train_dataloader):
            self.Classifier.train()
            data = data.to(self.device)
            labels = labels.to(self.device)

            self.Classifier_optim.zero_grad()

            # The frozen encoder returns a 3-tuple; only its first
            # element (the latent code z) is used as input features.
            with torch.no_grad():
                z, _, _ = self.Encoder(data)

            pred = self.Classifier(z)

            loss = self.Classifier_criterion(pred, labels)

            loss.backward()
            self.Classifier_optim.step()

            train_loss += loss.item()

        if validation_dataloader is not None:
            acc = self.accuracy(validation_dataloader)

            validation_accs.append(acc)

            early_stopping(1 - acc, self.Classifier)

        epochs.append(epoch)
        train_losses.append(train_loss / len(train_dataloader))

    if validation_dataloader is not None:
        early_stopping.load_checkpoint(self.Classifier)

    return epochs, train_losses, validation_accs
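# self.accuracy is assumed to report classification accuracy of the
# frozen encoder + classifier over a dataloader. A plausible sketch
# (not the actual implementation; the 3-tuple unpacking mirrors the
# encoder usage above):
def accuracy(self, dataloader):
    self.Classifier.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for data, labels in dataloader:
            data = data.to(self.device)
            labels = labels.to(self.device)
            z, _, _ = self.Encoder(data)
            pred = self.Classifier(z).argmax(dim=1)
            correct += (pred == labels).sum().item()
            total += labels.size(0)
    return correct / total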
def train_autoencoder(self, max_epochs, train_dataloader, validation_dataloader):
    early_stopping = EarlyStopping('{}/{}_autoencoder.pt'.format(
        self.model_name, self.dataset_name), patience=10)

    for epoch in range(max_epochs):
        if early_stopping.early_stop:
            break

        train_loss = 0
        for batch_idx, (data, _) in enumerate(train_dataloader):
            self.Autoencoder.train()
            data = data.to(self.device)

            self.Autoencoder_optim.zero_grad()

            recons = self.Autoencoder(data)

            loss = self.Autoencoder_criterion(recons, data)
            train_loss += loss.item()

            loss.backward()
            self.Autoencoder_optim.step()

        validation_loss = unsupervised_validation_loss(
            self.Autoencoder, validation_dataloader,
            self.Autoencoder_criterion, self.device)

        early_stopping(validation_loss, self.Autoencoder)

        print('Unsupervised Epoch: {} Loss: {} Validation loss: {}'.format(
            epoch, train_loss, validation_loss))

    # Restore the best weights even when training ran the full epoch
    # budget without triggering early stopping.
    early_stopping.load_checkpoint(self.Autoencoder)
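# unsupervised_validation_loss is called here as a module-level helper
# (the VAE trainer above uses a method of the same name). A minimal
# sketch matching the call signature, averaging the criterion over the
# validation set; the real helper may differ:
def unsupervised_validation_loss(model, dataloader, criterion, device):
    model.eval()
    total_loss = 0
    with torch.no_grad():
        for data, _ in dataloader:
            data = data.to(device)
            output = model(data)
            total_loss += criterion(output, data).item()
    return total_loss / len(dataloader)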
def train_ladder(self, max_epochs, supervised_dataloader,
                 unsupervised_dataloader, validation_dataloader):
    epochs = []
    train_losses = []
    validation_accs = []

    early_stopping = EarlyStopping('{}/{}_inner.pt'.format(
        self.state_path, self.model_name))

    for epoch in range(max_epochs):
        if early_stopping.early_stop:
            break

        train_loss = 0
        # cycle() repeats the small labelled set so that every
        # unlabelled batch is paired with a labelled batch.
        for batch_idx, (labelled_data, unlabelled_data) in enumerate(
                zip(cycle(supervised_dataloader), unsupervised_dataloader)):
            self.ladder.train()

            self.optimizer.zero_grad()

            labelled_images, labels = labelled_data
            labelled_images = labelled_images.to(self.device)
            labels = labels.to(self.device)

            unlabelled_images, _ = unlabelled_data
            unlabelled_images = unlabelled_images.to(self.device)

            inputs = torch.cat((labelled_images, unlabelled_images), 0)

            batch_size = labelled_images.size(0)

            # Corrupted and clean encoder passes.
            y_c, corr = self.ladder.forward_encoders(
                inputs, self.noise_std, True, batch_size)
            y, clean = self.ladder.forward_encoders(
                inputs, 0.0, True, batch_size)

            z_est_bn = self.ladder.forward_decoders(
                F.softmax(y_c, dim=1), corr, clean, batch_size)

            # Supervised cost on the labelled part of the corrupted output.
            cost = self.supervised_cost_function.forward(
                labeled(y_c, batch_size), labels)

            zs = clean['unlabeled']['z']

            # Layer-wise denoising costs, weighted per layer.
            u_cost = 0
            for l in range(self.L, -1, -1):
                u_cost += self.unsupervised_cost_function.forward(
                    z_est_bn[l], zs[l]) * self.denoising_cost[l]

            loss = cost + u_cost

            loss.backward()
            self.optimizer.step()

            train_loss += loss.item()

        if validation_dataloader is not None:
            acc = self.accuracy(validation_dataloader, 0)

            validation_accs.append(acc)

            early_stopping(1 - acc, self.ladder)

        epochs.append(epoch)
        train_losses.append(train_loss / len(unsupervised_dataloader))

    if validation_dataloader is not None:
        early_stopping.load_checkpoint(self.ladder)

    return epochs, train_losses, validation_accs
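# labeled() is not defined in this section. In common ladder network
# implementations the labelled examples occupy the first batch_size
# rows of the concatenated batch, so a plausible sketch is simply:
def labeled(x, batch_size):
    # Slice out the labelled portion of a (labelled + unlabelled) batch.
    return x[:batch_size]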
def train_m2(self, max_epochs, labelled_loader, unlabelled_loader, validation_loader):
    if unlabelled_loader is None:
        alpha = 1
    else:
        alpha = 0.1 * len(unlabelled_loader.dataset) / len(
            labelled_loader.dataset)

    epochs = []
    train_losses = []
    validation_accs = []

    early_stopping = EarlyStopping('{}/{}_inner.pt'.format(
        self.state_path, self.model_name))

    for epoch in range(max_epochs):
        if early_stopping.early_stop:
            break

        if unlabelled_loader is not None:
            data_iterator = zip(cycle(labelled_loader), unlabelled_loader)
        else:
            data_iterator = zip(labelled_loader, cycle([None]))

        train_loss = 0
        for batch_idx, (labelled_data, unlabelled_data) in enumerate(data_iterator):
            self.M2.train()

            self.optimizer.zero_grad()

            labelled_images, labels = labelled_data
            labelled_images = labelled_images.float().to(self.device)
            labels = labels.to(self.device)

            labelled_predictions = self.M2.classify(labelled_images)
            labelled_loss = F.cross_entropy(labelled_predictions, labels)

            # labelled images ELBO
            L = self.elbo(labelled_images, y=labels)

            loss = L + alpha * labelled_loss

            if unlabelled_data is not None:
                unlabelled_images, _ = unlabelled_data
                unlabelled_images = unlabelled_images.float().to(
                    self.device)

                U = self.elbo(unlabelled_images)

                loss += U

            loss.backward()
            self.optimizer.step()

            train_loss += loss.item()

        if validation_loader is not None:
            acc = self.accuracy(validation_loader)

            validation_accs.append(acc)

            early_stopping(1 - acc, self.M2)

        epochs.append(epoch)
        train_losses.append(train_loss)

    if validation_loader is not None:
        early_stopping.load_checkpoint(self.M2)

    return epochs, train_losses, validation_accs
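# Illustrative usage (M2Trainer and the loader names are assumptions,
# not part of this section):
#
#     trainer = M2Trainer(...)
#     epochs, losses, accs = trainer.train_m2(
#         max_epochs=100,
#         labelled_loader=labelled_loader,
#         unlabelled_loader=unlabelled_loader,
#         validation_loader=validation_loader)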