class TrainingSMNBNM(TrainingSoftMax):
    """Variant of ``TrainingSoftMax`` built on ``PseudoMultiTaskNet(no_svd=True)``.

    Differences from the base configuration:
      * the loss has no sparsity term (only NLL plus head disagreement);
      * epochs run through ``epoch_classification``;
      * the augmentation loss, optimizer and scheduler set up by the base
        class are removed from ``epoch_args``.
    """

    # NOTE(review): in the visible source this class statement precedes the
    # definition of TrainingSoftMax. The base class must already be defined
    # (or imported) when this line executes, otherwise importing the module
    # raises NameError. Confirm the file order.

    def __init__(self):
        super().__init__()
        self.name = "class"

        # Replace the base model with the no-SVD variant.
        net = PseudoMultiTaskNet(no_svd=True)
        self.model = net
        if CUDA:
            net.cuda(0)

        self.lr = 0.01
        # Optimizer/scheduler are rebuilt so they track the new model's
        # parameters rather than the discarded base model's.
        sgd = optim.SGD(
            net.parameters(),
            lr=10 * self.lr,
            momentum=self.momentum,
            weight_decay=self.weight_decay,
        )
        self.epoch_args["optimizer"] = sgd
        self.epoch_args["scheduler"] = StepLR(
            sgd, step_size=self.step_size, gamma=self.gamma)

        def loss(outputs, labels):
            # Log of the average of the first three head outputs.
            mean_out = (outputs[0] + outputs[1] + outputs[2]) / 3
            likelihoods = torch.log(mean_out)
            # Summed absolute difference between heads 0 and 1.
            disagreement = torch.abs(outputs[0] - outputs[1]).sum()
            predicted = torch.max(likelihoods.data, dim=1)[1]
            nll = F.nll_loss(likelihoods, labels)
            return nll + 0.01 * disagreement, predicted

        self.epoch_args["loss_fn"] = loss
        self.epoch_function = epoch_classification

        # This configuration has no augmentation objective.
        for unused in ("aug_loss_fn", "aug_optimizer", "aug_scheduler"):
            del self.epoch_args[unused]
class TrainingSoftMax:
    """Training configuration for ``PseudoMultiTaskNet`` on MNIST.

    The instance bundles:
      * the model, moved to GPU 0 when ``CUDA`` is set;
      * ``epoch_function`` and its keyword arguments ``epoch_args``, which
        hold the loss functions, optimizers and schedulers for the main
        objective and for the augmentation objective;
      * hyper-parameters and the train/validation data loaders.

    Calling the instance first writes every instance attribute to the
    TensorBoard writer as text under ``config/<name>``, then runs
    ``train(self)``. Presumably ``train`` consumes ``epoch_function`` and
    ``epoch_args``; that is not visible here.
    """

    def __init__(self):
        self.name = "class_sp_inter_ba"
        self.model = PseudoMultiTaskNet()
        if CUDA:
            self.model.cuda(0)
        self.writer = SummaryWriter()
        self.epoch_function = epoch_mixed
        self.epoch_args = {}
        self.epochs = 20
        self.lr = 0.1
        self.momentum = 0.9
        self.weight_decay = 1e-5
        # StepLR documents ``step_size`` as an int; np.ceil returns a float.
        self.step_size = int(np.ceil(self.epochs / 3))
        self.gamma = 0.1

        # Main objective: NLL + head disagreement + sparsity penalty.
        self.epoch_args["loss_fn"] = self._loss
        # Augmentation objective: L1 distance between x, zero-padded by 2 on
        # both sides of its last two dimensions, and y.
        self.epoch_args["aug_loss_fn"] = lambda x, y: F.l1_loss(
            F.pad(x, (2, 2, 2, 2)), y)
        # NOTE(review): the effective SGD learning rate is 10 * self.lr
        # (1.0 with the values above); confirm the factor is intended.
        self.epoch_args["optimizer"] = optim.SGD(
            self.model.parameters(), lr=10 * self.lr,
            momentum=self.momentum, weight_decay=self.weight_decay)
        self.epoch_args["scheduler"] = StepLR(
            self.epoch_args["optimizer"], step_size=self.step_size,
            gamma=self.gamma)

        self.aug_batch_size = 16
        self.batch_size = 16 * self.aug_batch_size
        self.aug_lr = 0.01
        # A second optimizer over the same parameters, used for the
        # augmentation objective.
        self.epoch_args["aug_optimizer"] = optim.SGD(
            self.model.parameters(), lr=10 * self.aug_lr,
            momentum=self.momentum, weight_decay=self.weight_decay)
        self.epoch_args["aug_scheduler"] = StepLR(
            self.epoch_args["aug_optimizer"],
            step_size=int(np.ceil(self.epochs / 2)), gamma=self.gamma)

        # self.train_loader = load_augmented(self.batch_size,
        #                                    self.aug_batch_size)
        self.train_loader = load_mnist(self.batch_size, train=True)
        self.val_loader = load_mnist(self.batch_size, train=False)
        # The config is logged once, in __call__. Logging here as well wrote
        # duplicate config/* entries and, for subclasses that override
        # attributes after super().__init__(), recorded stale values.

    @staticmethod
    def _likelihoods(outputs):
        """Return the log of the mean of ``outputs[0]``, ``outputs[1]``, ``outputs[2]``.

        Assumes those three entries are per-class probabilities of the same
        shape, (batch, classes). TODO: confirm against PseudoMultiTaskNet.
        """
        return torch.log((outputs[0] + outputs[1] + outputs[2]) / 3)

    @staticmethod
    def _loss(outputs, labels):
        """Compute the training loss and the predicted labels.

        Returns ``(loss, y_pred)``. ``loss`` is the NLL of the averaged
        likelihoods, plus 0.01 times the summed absolute difference between
        ``outputs[0]`` and ``outputs[1]``, plus 0.0001 times the L1 norm of
        ``outputs[3]``. ``y_pred`` is the argmax over dimension 1.
        """
        likelihoods = TrainingSoftMax._likelihoods(outputs)
        disagreement = torch.sum(torch.abs(outputs[0] - outputs[1]))
        sparsity = torch.norm(outputs[3], 1)
        _, y_pred = torch.max(likelihoods.detach(), 1)
        return (F.nll_loss(likelihoods, labels) + 0.01 * disagreement
                + 0.0001 * sparsity, y_pred)

    def pred_fn(self, outputs):
        """Return predicted labels: the argmax over dim 1 of the averaged likelihoods."""
        _, y_pred = torch.max(self._likelihoods(outputs).detach(), 1)
        return y_pred

    def __call__(self):
        """Log every instance attribute as TensorBoard text, then call ``train(self)``."""
        for key, value in self.__dict__.items():
            self.writer.add_text(f"config/{key}", str(value))
        train(self)