import time

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter

# Project-local helpers. NOTE: the module paths below are assumptions; point
# them at wherever WideResNet, WeightEMA, get_dataloaders_with_index, Augment,
# MixMatch and Loss actually live in this repo.
from model import WideResNet, WeightEMA
from datasets import get_dataloaders_with_index
from mixmatch import MixMatch, Augment, Loss


class FullySupervisedTrainer:

    def __init__(self, batch_size, model_params, n_steps, optimizer, adam, sgd,
                 steps_validation, steps_checkpoint, dataset, save_path):
        self.n_steps = n_steps
        self.start_step = 0
        self.steps_validation = steps_validation
        self.steps_checkpoint = steps_checkpoint
        self.num_labeled = 50000

        self.train_loader, _, self.val_loader, self.test_loader, self.lbl_idx, _, self.val_idx = \
            get_dataloaders_with_index(path='../data',
                                       batch_size=batch_size,
                                       num_labeled=self.num_labeled,
                                       which_dataset=dataset,
                                       validation=False)
        print('Labeled samples: ' + str(len(self.train_loader.sampler)))

        self.batch_size = self.train_loader.batch_size
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        print(self.device)

        depth, k, n_out = model_params
        self.model = WideResNet(depth=depth, k=k, n_out=n_out, bias=True).to(self.device)
        self.ema_model = WideResNet(depth=depth, k=k, n_out=n_out, bias=True).to(self.device)
        for param in self.ema_model.parameters():
            param.detach_()

        if optimizer == 'adam':
            self.lr, self.weight_decay = adam
            self.momentum, self.lr_decay = None, None
            self.optimizer = optim.Adam(self.model.parameters(), lr=self.lr)
            self.ema_optimizer = WeightEMA(self.model, self.ema_model, self.lr, alpha=0.999)
        else:
            self.lr, self.momentum, self.weight_decay, self.lr_decay = sgd
            self.optimizer = optim.SGD(self.model.parameters(), lr=self.lr, momentum=self.momentum,
                                       weight_decay=self.weight_decay, nesterov=True)
            self.ema_optimizer = None

        self.criterion = nn.CrossEntropyLoss()

        self.train_accuracies, self.train_losses = [], []
        self.val_accuracies, self.val_losses = [], []
        self.best_acc = 0

        self.augment = Augment(K=1)
        self.path = save_path
        self.writer = SummaryWriter()

    def train(self):
        iter_train_loader = iter(self.train_loader)

        for step in range(self.n_steps):
            self.model.train()

            # Get the next batch of data, restarting the iterator when the
            # loader is exhausted or the last batch was cropped.
            try:
                x_input, x_labels, _ = next(iter_train_loader)
                if x_input.shape[0] < self.batch_size:
                    iter_train_loader = iter(self.train_loader)
                    x_input, x_labels, _ = next(iter_train_loader)
            except StopIteration:
                iter_train_loader = iter(self.train_loader)
                x_input, x_labels, _ = next(iter_train_loader)

            # Send to GPU
            x_input = x_input.to(self.device)
            x_labels = x_labels.to(self.device)

            # Augment
            x_input = self.augment(x_input)
            x_input = x_input.reshape((-1, 3, 32, 32))

            # Compute X' predictions
            x_output = self.model(x_input)

            # Compute loss
            loss = self.criterion(x_output, x_labels)

            # Step
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
            if self.ema_optimizer:
                self.ema_optimizer.step()

            # Decay the learning rate. Only used with the SGD (Nesterov) optimizer.
            if not self.ema_optimizer and step in self.lr_decay:
                for g in self.optimizer.param_groups:
                    g['lr'] *= 0.2

            # Evaluate model
            self.model.eval()
            if step > 0 and not step % self.steps_validation:
                val_acc, is_best = self.evaluate_loss_acc(step)
                if is_best:
                    self.save_model(step=step, path=f'{self.path}/best_checkpoint.pt')

            # Save checkpoint
            if step > 10000 and not step % self.steps_checkpoint:
                self.save_model(step=step, path=f'{self.path}/checkpoint_{step}.pt')

        # --- Training finished ---
        test_val, test_acc = self.evaluate(self.test_loader)
        print("Training done!!\t Test loss: %.3f \t Test accuracy: %.3f" % (test_val, test_acc))

        self.writer.flush()

    # --- Support functions ---

    def evaluate_loss_acc(self, step):
        val_loss, val_acc = self.evaluate(self.val_loader)
        self.val_losses.append(val_loss)
        self.val_accuracies.append(val_acc)

        train_loss, train_acc = self.evaluate(self.train_loader)
        self.train_losses.append(train_loss)
        self.train_accuracies.append(train_acc)

        is_best = False
        if val_acc > self.best_acc:
            self.best_acc = val_acc
            is_best = True

        print("Step %d.\tLoss train_lbl/valid %.2f %.2f\t Accuracy train_lbl/valid %.2f %.2f \tBest acc %.2f \t%s" %
              (step, train_loss, val_loss, train_acc, val_acc, self.best_acc, time.ctime()))

        self.writer.add_scalar("Loss train_label", train_loss, step)
        self.writer.add_scalar("Loss validation", val_loss, step)
        self.writer.add_scalar("Accuracy train_label", train_acc, step)
        self.writer.add_scalar("Accuracy validation", val_acc, step)

        return val_acc, is_best

    def evaluate(self, dataloader):
        correct, total, loss = 0, 0, 0
        with torch.no_grad():
            for i, data in enumerate(dataloader, 0):
                inputs, labels = data[0].to(self.device), data[1].to(self.device)
                # Evaluate with the EMA model when EMA is in use (Adam path).
                if self.ema_optimizer:
                    outputs = self.ema_model(inputs)
                else:
                    outputs = self.model(inputs)
                loss += self.criterion(outputs, labels).item()
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

        loss /= len(dataloader)
        acc = correct / total * 100
        return loss, acc

    def save_model(self, step=None, path='../models/model.pt'):
        if not step:
            step = self.n_steps  # Training finished
        torch.save({
            'step': step,
            'model_state_dict': self.model.state_dict(),
            'ema_state_dict': self.ema_model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'loss_train': self.train_losses,
            'loss_val': self.val_losses,
            'acc_train': self.train_accuracies,
            'acc_val': self.val_accuracies,
            'steps': self.n_steps,
            'batch_size': self.batch_size,
            'num_labels': self.num_labeled,
            'lr': self.lr,
            'weight_decay': self.weight_decay,
            'momentum': self.momentum,
            'lr_decay': self.lr_decay,
            'lbl_idx': self.lbl_idx,
            'val_idx': self.val_idx,
        }, path)

    def load_checkpoint(self, model_name):
        saved_model = torch.load(f'../models/{model_name}')
        self.model.load_state_dict(saved_model['model_state_dict'])
        self.ema_model.load_state_dict(saved_model['ema_state_dict'])
        self.optimizer.load_state_dict(saved_model['optimizer_state_dict'])
        self.start_step = saved_model['step']

        self.train_loader, _, self.val_loader, self.test_loader, self.lbl_idx, _, self.val_idx = \
            get_dataloaders_with_index(path='../data',
                                       batch_size=self.batch_size,
                                       num_labeled=self.num_labeled,
                                       which_dataset='cifar10',
                                       lbl_idxs=saved_model['lbl_idx'],
                                       unlbl_idxs=[],
                                       valid_idxs=saved_model['val_idx'])
        print('Model ' + model_name + ' loaded.')
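
# For reference, a minimal sketch of the exponential-moving-average update that
# the imported WeightEMA helper is assumed to perform (alpha = 0.999 as used
# above; the pattern follows common MixMatch reference implementations -- the
# authoritative version is the import at the top of this file). It is given a
# different name so it does not shadow the real class.
class _WeightEMASketch:

    def __init__(self, model, ema_model, lr, alpha=0.999):
        self.params = list(model.parameters())
        self.ema_params = list(ema_model.parameters())
        self.alpha = alpha
        # lr is accepted only to mirror the real call signature; some reference
        # implementations use it for a coupled weight-decay term (assumption).
        self.lr = lr
        # Start the EMA weights from the current model weights.
        for param, ema_param in zip(self.params, self.ema_params):
            ema_param.data.copy_(param.data)

    def step(self):
        # In-place update: ema <- alpha * ema + (1 - alpha) * current.
        with torch.no_grad():
            for param, ema_param in zip(self.params, self.ema_params):
                ema_param.mul_(self.alpha).add_(param, alpha=1.0 - self.alpha)
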
class MixMatchTrainer:

    def __init__(self, batch_size, num_lbls, model_params, n_steps, K, lambda_u, optimizer, adam, sgd,
                 steps_validation, steps_checkpoint, dataset, save_path, use_pseudo, tau):
        self.validation_set = False
        self.n_steps = n_steps
        self.start_step = 0
        self.K = K
        self.steps_validation = steps_validation
        self.steps_checkpoint = steps_checkpoint
        self.num_labeled = num_lbls

        self.labeled_loader, self.unlabeled_loader, self.val_loader, self.test_loader, \
            self.lbl_idx, self.unlbl_idx, self.val_idx = \
            get_dataloaders_with_index(path='../data',
                                       batch_size=batch_size,
                                       num_labeled=num_lbls,
                                       which_dataset=dataset,
                                       validation=self.validation_set)
        print('Labeled samples: ' + str(len(self.labeled_loader.sampler)) +
              '\tUnlabeled samples: ' + str(len(self.unlabeled_loader.sampler)))

        self.targets_list = np.array(self.labeled_loader.dataset.targets)
        self.batch_size = self.labeled_loader.batch_size
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        print(self.device)

        # -- Model --
        depth, k, n_out = model_params
        self.model = WideResNet(depth=depth, k=k, n_out=n_out, bias=False).to(self.device)
        self.ema_model = WideResNet(depth=depth, k=k, n_out=n_out, bias=False).to(self.device)
        for param in self.ema_model.parameters():
            param.detach_()

        if optimizer == 'adam':
            self.lr, self.weight_decay = adam
            self.momentum, self.lr_decay = None, None
            self.optimizer = optim.Adam(self.model.parameters(), lr=self.lr)
            self.ema_optimizer = WeightEMA(self.model, self.ema_model, self.lr, alpha=0.999)
        else:
            self.lr, self.momentum, self.weight_decay, self.lr_decay = sgd
            self.optimizer = optim.SGD(self.model.parameters(), lr=self.lr, momentum=self.momentum,
                                       weight_decay=self.weight_decay, nesterov=True)
            self.ema_optimizer = None

        self.lambda_u_max, self.step_top_up = lambda_u
        self.loss_mixmatch = Loss(self.lambda_u_max, self.step_top_up)
        self.criterion = nn.CrossEntropyLoss()

        self.train_accuracies, self.train_losses = [], []
        self.val_accuracies, self.val_losses = [], []
        self.best_acc = 0

        self.mixmatch = MixMatch(self.model, self.batch_size, self.device)

        self.writer = SummaryWriter()
        self.path = save_path

        # -- Pseudo-labels --
        self.use_pseudo = use_pseudo
        self.steps_pseudo_lbl = 5000
        self.tau = tau  # confidence threshold
        self.min_unlbl_samples = 1000

        # Rebuild the dataloaders to keep an untouched copy of the original
        # unlabeled loader for pseudo-labeling.
        _, self.unlabeled_loader_original, _, _, _, _, _ = \
            get_dataloaders_with_index(path='../data',
                                       batch_size=batch_size,
                                       num_labeled=num_lbls,
                                       which_dataset=dataset,
                                       lbl_idxs=self.lbl_idx,
                                       unlbl_idxs=self.unlbl_idx,
                                       valid_idxs=self.val_idx,
                                       validation=self.validation_set)
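    # Note on the unsupervised-loss weight: Loss is constructed above from
    # (lambda_u_max, step_top_up). Following the MixMatch paper it is assumed
    # to ramp the weight linearly over the first step_top_up steps, roughly
    #     lambda_u(step) = lambda_u_max * min(step / step_top_up, 1.0),
    # but the authoritative schedule is whatever the imported Loss implements.
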
    def train(self):
        iter_labeled_loader = iter(self.labeled_loader)
        iter_unlabeled_loader = iter(self.unlabeled_loader)

        for step in range(self.start_step, self.n_steps):
            self.model.train()

            # Get the next labeled batch, restarting the iterator when the
            # loader is exhausted or the last batch was cropped.
            try:
                x_imgs, x_labels, _ = next(iter_labeled_loader)
                if x_imgs.shape[0] < self.batch_size:
                    iter_labeled_loader = iter(self.labeled_loader)
                    x_imgs, x_labels, _ = next(iter_labeled_loader)
            except StopIteration:
                iter_labeled_loader = iter(self.labeled_loader)
                x_imgs, x_labels, _ = next(iter_labeled_loader)

            # Same for the unlabeled batch.
            try:
                u_imgs, _, _ = next(iter_unlabeled_loader)
                if u_imgs.shape[0] < self.batch_size:
                    iter_unlabeled_loader = iter(self.unlabeled_loader)
                    u_imgs, _, _ = next(iter_unlabeled_loader)
            except StopIteration:
                iter_unlabeled_loader = iter(self.unlabeled_loader)
                u_imgs, _, _ = next(iter_unlabeled_loader)

            # Send to GPU
            x_imgs = x_imgs.to(self.device)
            x_labels = x_labels.to(self.device)
            u_imgs = u_imgs.to(self.device)

            # MixMatch algorithm
            x, u = self.mixmatch.run(x_imgs, x_labels, u_imgs)
            x_input, x_targets = x
            u_input, u_targets = u
            u_targets.detach_()  # stop gradients from propagating into label guessing

            # Compute X' predictions
            x_output = self.model(x_input)

            # Compute U' predictions, one chunk of batch_size per augmentation
            u_batch_outs = []
            for k in range(self.K):
                u_batch = u_input[k * self.batch_size:(k + 1) * self.batch_size]
                u_batch_outs.append(self.model(u_batch))
            u_outputs = torch.cat(u_batch_outs, dim=0)

            # Compute loss
            loss = self.loss_mixmatch(x_output, x_targets, u_outputs, u_targets, step)

            # Step
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
            if self.ema_optimizer:
                self.ema_optimizer.step()

            # Decay the learning rate. Only used with the SGD (Nesterov) optimizer.
            if not self.ema_optimizer and step in self.lr_decay:
                for g in self.optimizer.param_groups:
                    g['lr'] *= 0.2

            # Evaluate model
            self.model.eval()
            if step > 0 and not step % self.steps_validation:
                val_acc, is_best = self.evaluate_loss_acc(step)
                if is_best and step > 10000:
                    self.save_model(step=step, path=f'{self.path}/best_checkpoint.pt')

            # Save checkpoint
            if step > 0 and not step % self.steps_checkpoint:
                self.save_model(step=step, path=f'{self.path}/checkpoint_{step}.pt')

            # Generate pseudo-labels
            if self.use_pseudo and step >= 50000 and not step % self.steps_pseudo_lbl:
                # matrix columns: [index, confidence, pseudo_label, true_label, is_ground_truth]
                matrix = self.get_pseudo_labels()
                self.print_threshold_comparison(matrix)

                if self.tau != -1:
                    # Generate the pseudo set with a single confidence threshold (same for all classes)
                    matrix = self.generate_pseudo_set(matrix)
                else:
                    # Generate a class-balanced pseudo set (top 10% most confident guesses per class)
                    matrix = self.generate_pseudo_set_balanced(matrix)

                iter_labeled_loader = iter(self.labeled_loader)
                iter_unlabeled_loader = iter(self.unlabeled_loader)

                # Save
                torch.save(matrix, f'{self.path}/pseudo_matrix_balanced_{step}.pt')

        # --- Training finished ---
        test_val, test_acc = self.evaluate(self.test_loader)
        print("Training done!!\t Test loss: %.3f \t Test accuracy: %.3f" % (test_val, test_acc))

        self.writer.flush()

    # --- Support functions ---

    def evaluate_loss_acc(self, step):
        val_loss, val_acc = self.evaluate(self.val_loader)
        self.val_losses.append(val_loss)
        self.val_accuracies.append(val_acc)

        train_loss, train_acc = self.evaluate(self.labeled_loader)
        self.train_losses.append(train_loss)
        self.train_accuracies.append(train_acc)

        is_best = False
        if val_acc > self.best_acc:
            self.best_acc = val_acc
            is_best = True

        print("Step %d.\tLoss train_lbl/valid %.2f %.2f\t Accuracy train_lbl/valid %.2f %.2f \tBest acc %.2f \t%s" %
              (step, train_loss, val_loss, train_acc, val_acc, self.best_acc, time.ctime()))

        self.writer.add_scalar("Loss train_label", train_loss, step)
        self.writer.add_scalar("Loss validation", val_loss, step)
        self.writer.add_scalar("Accuracy train_label", train_acc, step)
        self.writer.add_scalar("Accuracy validation", val_acc, step)

        return val_acc, is_best

    def evaluate(self, dataloader):
        correct, total, loss = 0, 0, 0
        with torch.no_grad():
            for i, data in enumerate(dataloader, 0):
                inputs, labels = data[0].to(self.device), data[1].to(self.device)
                # Evaluate with the EMA model when EMA is in use (Adam path).
                if self.ema_optimizer:
                    outputs = self.ema_model(inputs)
                else:
                    outputs = self.model(inputs)
                loss += self.criterion(outputs, labels).item()
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

        loss /= len(dataloader)
        acc = correct / total * 100
        return loss, acc

    def get_losses(self):
        return (self.loss_mixmatch.loss_list, self.loss_mixmatch.lx_list,
                self.loss_mixmatch.lu_list, self.loss_mixmatch.lu_weighted_list)

    def save_model(self, step=None, path='../model.pt'):
        loss_list, lx, lu, lu_weighted = self.get_losses()
        if not step:
            step = self.n_steps  # Training finished
        torch.save({
            'step': step,
            'model_state_dict': self.model.state_dict(),
            'ema_state_dict': self.ema_model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'loss_train': self.train_losses,
            'loss_val': self.val_losses,
            'acc_train': self.train_accuracies,
            'acc_val': self.val_accuracies,
            'loss_batch': loss_list,
            'lx': lx,
            'lu': lu,
            'lu_weighted': lu_weighted,
            'steps': self.n_steps,
            'batch_size': self.batch_size,
            'num_labels': self.num_labeled,
            'lambda_u_max': self.lambda_u_max,
            'step_top_up': self.step_top_up,
            'lr': self.lr,
            'weight_decay': self.weight_decay,
            'momentum': self.momentum,
            'lr_decay': self.lr_decay,
            'lbl_idx': self.lbl_idx,
            'unlbl_idx': self.unlbl_idx,
            'val_idx': self.val_idx,
        }, path)

    def load_checkpoint(self, path_checkpoint):
        saved_model = torch.load(path_checkpoint)
        self.model.load_state_dict(saved_model['model_state_dict'])
        self.ema_model.load_state_dict(saved_model['ema_state_dict'])
        self.optimizer.load_state_dict(saved_model['optimizer_state_dict'])
        self.start_step = saved_model['step']
        self.train_losses = saved_model['loss_train']
        self.val_losses = saved_model['loss_val']
        self.train_accuracies = saved_model['acc_train']
        self.val_accuracies = saved_model['acc_val']
        self.batch_size = saved_model['batch_size']
        self.num_labeled = saved_model['num_labels']

        self.labeled_loader, self.unlabeled_loader, self.val_loader, self.test_loader, \
            self.lbl_idx, self.unlbl_idx, self.val_idx = \
            get_dataloaders_with_index(path='../data',
                                       batch_size=self.batch_size,
                                       num_labeled=self.num_labeled,
                                       which_dataset='cifar10',
                                       lbl_idxs=saved_model['lbl_idx'],
                                       unlbl_idxs=saved_model['unlbl_idx'],
                                       valid_idxs=saved_model['val_idx'],
                                       validation=self.validation_set)
        self.unlabeled_loader_original = self.unlabeled_loader
        print('Model ' + path_checkpoint + ' loaded.')

    def get_pseudo_labels(self):
        matrix = torch.tensor([], device=self.device)
        # Iterate through the original unlabeled loader
        for batch_idx, (data, target, idx) in enumerate(self.unlabeled_loader_original):
            with torch.no_grad():
                # Get predictions for the unlabeled samples
                out = self.model(data.to(self.device))
                p_out = torch.softmax(out, dim=1)  # turn logits into a probability distribution
                confidence, pseudo_lbl = torch.max(p_out, dim=1)
                pseudo_lbl_batch = torch.vstack((idx.to(self.device), confidence, pseudo_lbl)).T
                # Append to matrix
                matrix = torch.cat((matrix, pseudo_lbl_batch), dim=0)  # (n_unlabeled, 3)

        n_unlabeled = matrix.shape[0]
        indices = matrix[:, 0].cpu().numpy().astype(int)
        ground_truth = self.targets_list[indices]
        matrix = torch.vstack((matrix.T, torch.tensor(ground_truth, device=self.device))).T  # (n_unlabeled, 4)
        matrix = torch.vstack((matrix.T, torch.zeros(n_unlabeled, device=self.device))).T    # (n_unlabeled, 5)

        # matrix columns: [index, confidence, pseudo_label, true_label, is_ground_truth]
        # Check whether each pseudo-label matches the ground truth
        for i in range(n_unlabeled):
            if matrix[i, 2] == matrix[i, 3]:
                matrix[i, 4] = 1
        return matrix

    def generate_pseudo_set(self, matrix):
        # Samples below the confidence threshold stay unlabeled; the rest become pseudo-labeled.
        unlbl_mask1 = (matrix[:, 1] < self.tau)
        # unlbl_mask2 = (matrix[:, 1] >= 0.99)
        pseudo_mask = (matrix[:, 1] >= self.tau)
        # pseudo_mask = (matrix[:, 1] >= self.tau) & (matrix[:, 1] < 0.99)
        # unlbl_indices = torch.cat((matrix[unlbl_mask1, 0], matrix[unlbl_mask2, 0]))
        unlbl_indices = matrix[unlbl_mask1, 0]
        matrix = matrix[pseudo_mask, :]
        indices = matrix[:, 0]

        new_lbl_idx = np.int_(torch.cat((torch.tensor(self.lbl_idx, device=self.device), indices)).tolist())
        new_unlbl_idx = np.int_(unlbl_indices.tolist())

        self.labeled_loader, self.unlabeled_loader, self.val_loader, self.test_loader, _, _, new_val_idx = \
            get_dataloaders_with_index(path='../data',
                                       batch_size=self.batch_size,
                                       num_labeled=self.num_labeled,
                                       which_dataset='cifar10',
                                       lbl_idxs=new_lbl_idx,
                                       unlbl_idxs=new_unlbl_idx,
                                       valid_idxs=self.val_idx,
                                       validation=self.validation_set)
        assert np.allclose(self.val_idx, new_val_idx), 'validation split changed'
        assert (len(self.labeled_loader.sampler) + len(self.unlabeled_loader.sampler) == 50000), \
            'labeled + unlabeled must cover the full training set'

        # Replace the real labels with the pseudo-labels
        for i in range(matrix.shape[0]):
            index = int(matrix[i, 0].item())
            assert int(matrix[i, 3]) == self.labeled_loader.dataset.targets[index]
            pseudo_label = int(matrix[i, 2].item())
            self.labeled_loader.dataset.targets[index] = pseudo_label

        correct = torch.sum(matrix[:, 4]).item()
        pseudo_acc = correct / matrix.shape[0] * 100 if matrix.shape[0] > 0 else 0
        print('Generated labels: %d\t Correct: %d\t Accuracy: %.2f' % (matrix.shape[0], correct, pseudo_acc))
        print('Training with Labeled / Unlabeled / Validation samples\t %d %d %d' %
              (len(new_lbl_idx), len(new_unlbl_idx), len(self.val_idx)))
        return matrix

    def generate_pseudo_set_balanced(self, matrix_all):
        unlbl_indices = torch.tensor([], device=self.device)
        # Keep the top 10% most confident guesses for each class
        matrix = torch.tensor([], device=self.device)
        for i in range(10):
            matrix_label = matrix_all[matrix_all[:, 2] == i, :]
            threshold = torch.quantile(matrix_label[:, 1], 0.9)  # confidence at the 90th percentile for this class
            unlbl_idxs = matrix_label[matrix_label[:, 1] < threshold, 0]
            matrix_label = matrix_label[matrix_label[:, 1] >= threshold, :]
            matrix = torch.cat((matrix, matrix_label), dim=0)
            unlbl_indices = torch.cat((unlbl_indices, unlbl_idxs))
        indices = matrix[:, 0]

        new_lbl_idx = np.int_(torch.cat((torch.tensor(self.lbl_idx, device=self.device), indices)).tolist())
        new_unlbl_idx = np.int_(unlbl_indices.tolist())

        self.labeled_loader, self.unlabeled_loader, self.val_loader, self.test_loader, _, _, new_val_idx = \
            get_dataloaders_with_index(path='../data',
                                       batch_size=self.batch_size,
                                       num_labeled=self.num_labeled,
                                       which_dataset='cifar10',
                                       lbl_idxs=new_lbl_idx,
                                       unlbl_idxs=new_unlbl_idx,
                                       valid_idxs=self.val_idx,
                                       validation=self.validation_set)
        assert np.allclose(self.val_idx, new_val_idx), 'validation split changed'
        assert (len(self.labeled_loader.sampler) + len(self.unlabeled_loader.sampler) == 50000), \
            'labeled + unlabeled must cover the full training set'

        # Replace the real labels with the pseudo-labels
        for i in range(matrix.shape[0]):
            index = int(matrix[i, 0].item())
            assert int(matrix[i, 3]) == self.labeled_loader.dataset.targets[index]
            pseudo_label = int(matrix[i, 2].item())
            self.labeled_loader.dataset.targets[index] = pseudo_label

        correct = torch.sum(matrix[:, 4]).item()
        pseudo_acc = correct / matrix.shape[0] * 100 if matrix.shape[0] > 0 else 0
        print('Generated labels: %d\t Correct: %d\t Accuracy: %.2f' % (matrix.shape[0], correct, pseudo_acc))
        print('Training with Labeled / Unlabeled / Validation samples\t %d %d %d' %
              (len(new_lbl_idx), len(new_unlbl_idx), len(self.val_idx)))
        return matrix

    def print_threshold_comparison(self, matrix):
        m2 = matrix[matrix[:, 1] >= 0.9, :]
        for tau in [0.9, 0.95, 0.97, 0.99, 0.999]:
            pseudo_labels = m2[m2[:, 1] >= tau, :]
            total = pseudo_labels.shape[0]
            correct = torch.sum(pseudo_labels[:, 4]).item()
            print('Confidence threshold %.3f\t Generated / Correct / Precision\t %d\t%d\t%.2f ' %
                  (tau, total, correct, correct / (total + np.finfo(float).eps) * 100))
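
# Example usage: a minimal sketch showing how a trainer is wired together.
# All hyperparameter values below are illustrative assumptions, not the
# settings used for any reported results.
if __name__ == '__main__':
    trainer = MixMatchTrainer(batch_size=64,
                              num_lbls=250,
                              model_params=(28, 2, 10),  # WideResNet depth, widening factor k, n_out classes
                              n_steps=100000,
                              K=2,                       # augmentations per unlabeled image
                              lambda_u=(75, 16000),      # (lambda_u_max, step_top_up)
                              optimizer='adam',
                              adam=(0.002, 0.0),         # (lr, weight_decay)
                              sgd=(0.1, 0.9, 5e-4, [60000, 80000]),  # (lr, momentum, weight_decay, lr_decay steps)
                              steps_validation=500,
                              steps_checkpoint=10000,
                              dataset='cifar10',
                              save_path='../models',
                              use_pseudo=False,
                              tau=0.95)
    trainer.train()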