import torch

from madgrad import MADGRAD


def test_sparse():
    weight = torch.randn(5, 1).cuda().requires_grad_()
    weight_sparse = weight.detach().clone().requires_grad_()

    optimizer_dense = MADGRAD([weight], lr=1e-3, momentum=0)
    optimizer_sparse = MADGRAD([weight_sparse], lr=1e-3, momentum=0)

    weight.grad = torch.rand_like(weight)
    weight.grad[0] = 0.0  # Add a zero
    weight_sparse.grad = weight.grad.to_sparse()

    optimizer_dense.step()
    optimizer_sparse.step()
    assert torch.allclose(weight, weight_sparse)

    weight.grad = torch.rand_like(weight)
    weight.grad[1] = 0.0  # Add a zero
    weight_sparse.grad = weight.grad.to_sparse()

    optimizer_dense.step()
    optimizer_sparse.step()
    assert torch.allclose(weight, weight_sparse)

    weight.grad = torch.rand_like(weight)
    weight.grad[0] = 0.0  # Add a zero
    weight_sparse.grad = weight.grad.to_sparse()

    optimizer_dense.step()
    optimizer_sparse.step()
    assert torch.allclose(weight, weight_sparse)
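For context, here is a minimal sketch (not part of the original test) of where sparse gradients typically come from in practice: an embedding layer built with sparse=True emits a sparse COO gradient on its weight, which MADGRAD can consume directly, as the dense-vs-sparse equivalence check above verifies. The layer sizes, indices, and learning rate below are illustrative assumptions.

import torch
from madgrad import MADGRAD

# Hypothetical example: Embedding(sparse=True) produces sparse gradients on its weight.
emb = torch.nn.Embedding(10, 4, sparse=True)
opt = MADGRAD(emb.parameters(), lr=1e-3, momentum=0)  # momentum=0, as in the test above (sparse grads may require it)

loss = emb(torch.tensor([1, 3, 3])).sum()
loss.backward()   # emb.weight.grad is now a sparse COO tensor
opt.step()
opt.zero_grad()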
def train(self):
    device = self.device
    print(f'Running on device: {device}, start training...')
    print(f'Setting - Epochs: {self.num_epochs}, Learning rate: {self.learning_rate}')

    train_loader = self.train_loader
    valid_loader = self.valid_loader
    model = self.model.to(device)

    # optimizer selection
    if self.optimizer == 0:
        optimizer = torch.optim.Adam(model.parameters(), lr=self.learning_rate, weight_decay=1e-5)
    elif self.optimizer == 1:
        optimizer = torch.optim.AdamW(model.parameters(), lr=self.learning_rate, weight_decay=1e-5)
    elif self.optimizer == 2:
        optimizer = MADGRAD(model.parameters(), lr=self.learning_rate, weight_decay=1e-5)
    elif self.optimizer == 3:
        optimizer = AdamP(model.parameters(), lr=self.learning_rate, weight_decay=1e-5)

    criterion = torch.nn.CrossEntropyLoss().to(device)

    if self.use_swa:
        optimizer = SWA(optimizer, swa_start=2, swa_freq=2, swa_lr=1e-5)

    # scheduler
    scheduler_dct = {
        0: None,
        1: StepLR(optimizer, 10, gamma=0.5),
        2: ReduceLROnPlateau(optimizer, 'min', factor=0.4,
                             patience=int(0.3 * self.early_stopping_patience)),
        3: CosineAnnealingLR(optimizer, T_max=5, eta_min=0.)
    }
    scheduler = scheduler_dct[self.scheduler]

    # early stopping
    early_stopping = EarlyStopping(patience=self.early_stopping_patience, verbose=True,
                                   path=f'checkpoint_{self.job}.pt')

    # training
    self.train_loss_lst = list()
    self.train_acc_lst = list()
    self.val_loss_lst = list()
    self.val_acc_lst = list()

    for epoch in range(1, self.num_epochs + 1):
        with tqdm(train_loader, unit='batch') as tepoch:
            avg_val_loss, avg_val_acc = None, None
            for idx, (img, label) in enumerate(tepoch):
                tepoch.set_description(f"Epoch {epoch}")

                model.train()
                optimizer.zero_grad()

                img, label = img.float().to(device), label.long().to(device)
                output = model(img)
                loss = criterion(output, label)

                predictions = output.argmax(dim=1, keepdim=True).squeeze()
                correct = (predictions == label).sum().item()
                accuracy = correct / len(img)

                loss.backward()
                optimizer.step()

                # run validation on the last batch of each epoch
                if idx == len(train_loader) - 1:
                    val_loss_lst, val_acc_lst = list(), list()
                    model.eval()
                    with torch.no_grad():
                        for val_img, val_label in valid_loader:
                            val_img, val_label = val_img.float().to(device), val_label.long().to(device)
                            val_out = model(val_img)
                            val_loss = criterion(val_out, val_label)
                            val_pred = val_out.argmax(dim=1, keepdim=True).squeeze()
                            val_acc = (val_pred == val_label).sum().item() / len(val_img)
                            val_loss_lst.append(val_loss.item())
                            val_acc_lst.append(val_acc)

                    avg_val_loss = np.mean(val_loss_lst)
                    avg_val_acc = np.mean(val_acc_lst) * 100.

                    self.train_loss_lst.append(loss.item())  # store scalars, not graph-attached tensors
                    self.train_acc_lst.append(accuracy)
                    self.val_loss_lst.append(avg_val_loss)
                    self.val_acc_lst.append(avg_val_acc)

                if scheduler is not None:
                    current_lr = optimizer.param_groups[0]['lr']
                else:
                    current_lr = self.learning_rate

                # log
                tepoch.set_postfix(loss=loss.item(), accuracy=100. * accuracy,
                                   val_loss=avg_val_loss, val_acc=avg_val_acc,
                                   current_lr=current_lr)

        # early stopping check
        early_stopping(avg_val_loss, model)
        if early_stopping.early_stop:
            print("Early stopping")
            break

        # scheduler update
        if scheduler is not None:
            if self.scheduler == 2:
                scheduler.step(avg_val_loss)
            else:
                scheduler.step()

    if self.use_swa:
        optimizer.swap_swa_sgd()

    self.model.load_state_dict(torch.load(f'checkpoint_{self.job}.pt'))
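A minimal usage sketch, assuming a hypothetical Trainer-style class that exposes the attributes train() reads (device, model, train_loader, valid_loader, num_epochs, learning_rate, optimizer, scheduler, use_swa, early_stopping_patience, job); the constructor and argument values below are illustrative, not part of the original code.

trainer = Trainer(
    model=my_model,                 # any torch.nn.Module classifier
    train_loader=train_loader,
    valid_loader=valid_loader,
    device='cuda',
    num_epochs=50,
    learning_rate=1e-3,
    optimizer=2,                    # 0: Adam, 1: AdamW, 2: MADGRAD, 3: AdamP
    scheduler=2,                    # 0: none, 1: StepLR, 2: ReduceLROnPlateau, 3: CosineAnnealingLR
    use_swa=False,
    early_stopping_patience=10,
    job='baseline',
)
trainer.train()  # best weights are saved to and reloaded from f'checkpoint_{job}.pt'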