import numpy as np
import torch
from torch.optim.lr_scheduler import (CosineAnnealingLR, ReduceLROnPlateau,
                                      StepLR)
from tqdm import tqdm

from adamp import AdamP             # https://github.com/clovaai/AdamP
from madgrad import MADGRAD         # https://github.com/facebookresearch/madgrad
from torchcontrib.optim import SWA  # stochastic weight averaging wrapper
# `EarlyStopping` is a project-local helper; an assumed sketch follows train().


def train(self):
    """Train `self.model` with a configurable optimizer/scheduler pair,
    optional SWA, per-epoch validation, and early stopping on validation
    loss. (A method of a trainer class whose body is not shown here; the
    module-level imports it relies on are listed above.)"""
    device = self.device
    print(f'Running on device: {device}, start training...')
    print(f'Setting - Epochs: {self.num_epochs}, '
          f'Learning rate: {self.learning_rate}')

    train_loader = self.train_loader
    valid_loader = self.valid_loader
    model = self.model.to(device)

    # Optimizer selection by config index.
    if self.optimizer == 0:
        optimizer = torch.optim.Adam(model.parameters(),
                                     lr=self.learning_rate,
                                     weight_decay=1e-5)
    elif self.optimizer == 1:
        optimizer = torch.optim.AdamW(model.parameters(),
                                      lr=self.learning_rate,
                                      weight_decay=1e-5)
    elif self.optimizer == 2:
        optimizer = MADGRAD(model.parameters(),
                            lr=self.learning_rate,
                            weight_decay=1e-5)
    elif self.optimizer == 3:
        optimizer = AdamP(model.parameters(),
                          lr=self.learning_rate,
                          weight_decay=1e-5)

    criterion = torch.nn.CrossEntropyLoss().to(device)

    if self.use_swa:
        optimizer = SWA(optimizer, swa_start=2, swa_freq=2, swa_lr=1e-5)

    # Scheduler selection by config index. Only the selected scheduler is
    # instantiated: constructing all of them eagerly would mutate the
    # optimizer's param groups as a side effect.
    scheduler_dct = {
        0: None,
        1: lambda: StepLR(optimizer, 10, gamma=0.5),
        2: lambda: ReduceLROnPlateau(
            optimizer, 'min', factor=0.4,
            patience=int(0.3 * self.early_stopping_patience)),
        3: lambda: CosineAnnealingLR(optimizer, T_max=5, eta_min=0.),
    }
    make_scheduler = scheduler_dct[self.scheduler]
    scheduler = make_scheduler() if make_scheduler is not None else None

    # Early stopping checkpoints the best model to `path`.
    early_stopping = EarlyStopping(patience=self.early_stopping_patience,
                                   verbose=True,
                                   path=f'checkpoint_{self.job}.pt')

    # Per-epoch training/validation history.
    self.train_loss_lst = list()
    self.train_acc_lst = list()
    self.val_loss_lst = list()
    self.val_acc_lst = list()

    for epoch in range(1, self.num_epochs + 1):
        with tqdm(train_loader, unit='batch') as tepoch:
            avg_val_loss, avg_val_acc = None, None
            for idx, (img, label) in enumerate(tepoch):
                tepoch.set_description(f'Epoch {epoch}')
                model.train()
                optimizer.zero_grad()
                img = img.float().to(device)
                label = label.long().to(device)

                output = model(img)
                loss = criterion(output, label)
                predictions = output.argmax(dim=1, keepdim=True).squeeze()
                correct = (predictions == label).sum().item()
                accuracy = correct / len(img)

                loss.backward()
                optimizer.step()

                # Validate once per epoch, after the last training batch.
                if idx == len(train_loader) - 1:
                    val_loss_lst, val_acc_lst = list(), list()
                    model.eval()
                    with torch.no_grad():
                        for val_img, val_label in valid_loader:
                            val_img = val_img.float().to(device)
                            val_label = val_label.long().to(device)
                            val_out = model(val_img)
                            val_loss = criterion(val_out, val_label)
                            val_pred = val_out.argmax(dim=1,
                                                      keepdim=True).squeeze()
                            val_acc = (val_pred == val_label
                                       ).sum().item() / len(val_img)
                            val_loss_lst.append(val_loss.item())
                            val_acc_lst.append(val_acc)

                    avg_val_loss = np.mean(val_loss_lst)
                    avg_val_acc = np.mean(val_acc_lst) * 100.

                    # Store plain floats, not graph-attached tensors.
                    self.train_loss_lst.append(loss.item())
                    self.train_acc_lst.append(accuracy)
                    self.val_loss_lst.append(avg_val_loss)
                    self.val_acc_lst.append(avg_val_acc)

                if scheduler is not None:
                    current_lr = optimizer.param_groups[0]['lr']
                else:
                    current_lr = self.learning_rate

                # Progress-bar log.
                tepoch.set_postfix(loss=loss.item(),
                                   accuracy=100. * accuracy,
                                   val_loss=avg_val_loss,
                                   val_acc=avg_val_acc,
                                   current_lr=current_lr)

        # Early stopping check (saves a checkpoint on improvement).
        early_stopping(avg_val_loss, model)
        if early_stopping.early_stop:
            print('Early stopping')
            break

        # Scheduler update; ReduceLROnPlateau needs the monitored metric.
        if scheduler is not None:
            if self.scheduler == 2:
                scheduler.step(avg_val_loss)
            else:
                scheduler.step()

    if self.use_swa:
        optimizer.swap_swa_sgd()

    # Restore the best (early-stopping) checkpoint.
    self.model.load_state_dict(torch.load(f'checkpoint_{self.job}.pt'))
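# ---------------------------------------------------------------------------
# The trainer above calls an `EarlyStopping` helper that is defined elsewhere
# in the repo. Judging from its usage (the constructor arguments, the
# (val_loss, model) call, the `early_stop` flag, and the checkpoint that
# train() reloads at the end), it follows the conventional pattern sketched
# below. This is a minimal assumed implementation, not the repo's actual
# class.
# ---------------------------------------------------------------------------
class EarlyStopping:
    """Stop training after `patience` epochs without validation-loss
    improvement, checkpointing the best weights to `path` (assumed sketch)."""

    def __init__(self, patience=7, verbose=False, delta=0.,
                 path='checkpoint.pt'):
        self.patience = patience
        self.verbose = verbose
        self.delta = delta
        self.path = path
        self.counter = 0
        self.best_loss = None
        self.early_stop = False

    def __call__(self, val_loss, model):
        if self.best_loss is None or val_loss < self.best_loss - self.delta:
            # Improvement: checkpoint the weights and reset the counter.
            if self.verbose:
                print(f'Validation loss improved to {val_loss:.6f}; '
                      f'saving model to {self.path}')
            self.best_loss = val_loss
            self.counter = 0
            torch.save(model.state_dict(), self.path)
        else:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True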
import os

import numpy as np
import torch
import torch.nn as nn
import wandb
from torch.cuda.amp import GradScaler, autocast
from tqdm import tqdm

from madgrad import MADGRAD  # https://github.com/facebookresearch/madgrad
# https://github.com/katsura-jp/pytorch-cosine-annealing-with-warmup
from cosine_annealing_warmup import CosineAnnealingWarmupRestarts
# Project-local helpers assumed from elsewhere in the repo: CNNModel,
# get_score, remove_all_file. Sketches of AverageMeter and
# get_learning_rate follow train().


def train(cfg, train_loader, val_loader, val_labels, k):
    """Mixed-precision binary-classification trainer, logged to wandb.
    `k > 0` marks this run as the k-th fold of a k-fold split."""
    # Set config.
    MODEL_ARC = cfg.values.model_arc
    OUTPUT_DIR = cfg.values.output_dir
    NUM_CLASSES = cfg.values.num_classes
    SAVE_PATH = os.path.join(OUTPUT_DIR, MODEL_ARC)

    best_score = 0.
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    os.makedirs(SAVE_PATH, exist_ok=True)
    if k > 0:
        os.makedirs(SAVE_PATH + f'/{k}_fold', exist_ok=True)

    num_epochs = cfg.values.train_args.num_epochs
    max_lr = cfg.values.train_args.max_lr
    min_lr = cfg.values.train_args.min_lr
    weight_decay = cfg.values.train_args.weight_decay
    log_intervals = cfg.values.train_args.log_intervals

    model = CNNModel(model_arc=MODEL_ARC, num_classes=NUM_CLASSES)
    model.to(device)

    # base_optimizer = SGDP
    # optimizer = SAM(model.parameters(), base_optimizer, lr=max_lr, momentum=momentum)
    optimizer = MADGRAD(model.parameters(), lr=max_lr,
                        weight_decay=weight_decay)

    # Two cosine cycles over the run, each starting with a 20% linear warmup.
    first_cycle_steps = len(train_loader) * num_epochs // 2
    scheduler = CosineAnnealingWarmupRestarts(
        optimizer,
        first_cycle_steps=first_cycle_steps,
        cycle_mult=1.0,
        max_lr=max_lr,
        min_lr=min_lr,
        warmup_steps=int(first_cycle_steps * 0.2),
        gamma=0.5)

    criterion = nn.BCEWithLogitsLoss()

    wandb.watch(model)

    # One scaler for the whole run; recreating it every epoch would reset
    # the dynamic loss-scale state.
    scaler = GradScaler()

    for epoch in range(num_epochs):
        model.train()
        loss_values = AverageMeter()

        for step, (images, labels) in enumerate(tqdm(train_loader)):
            images = images.to(device)
            labels = labels.to(device)
            batch_size = labels.size(0)

            # Mixed-precision forward pass.
            with autocast():
                logits = model(images)
                loss = criterion(logits.view(-1), labels)

            loss_values.update(loss.item(), batch_size)

            optimizer.zero_grad()
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            scheduler.step()

            wandb.log({
                'Learning rate': get_learning_rate(optimizer)[0],
                'Train Loss': loss_values.val
            })

            if step % log_intervals == 0:
                tqdm.write(
                    f'Epoch : [{epoch + 1}/{num_epochs}][{step}/{len(train_loader)}] || '
                    f'LR : {get_learning_rate(optimizer)[0]:.6e} || '
                    f'Train Loss : {loss_values.val:.4f} ({loss_values.avg:.4f}) ||')

        # Validation: collect sigmoid probabilities for the ROC AUC score.
        with torch.no_grad():
            model.eval()
            loss_values = AverageMeter()
            preds = []

            for step, (images, labels) in enumerate(tqdm(val_loader)):
                images = images.to(device)
                labels = labels.to(device)
                batch_size = labels.size(0)

                logits = model(images)
                loss = criterion(logits.view(-1), labels)
                preds.append(logits.sigmoid().to('cpu').numpy())
                loss_values.update(loss.item(), batch_size)

        predictions = np.concatenate(preds)
        # f1, roc_auc = get_score(val_labels, predictions)
        roc_auc = get_score(val_labels, predictions)

        is_best = roc_auc >= best_score
        best_score = max(roc_auc, best_score)

        if is_best:
            # Keep only the best checkpoint: clear the directory, then save.
            if k > 0:
                save_dir = SAVE_PATH + f'/{k}_fold'
            else:
                save_dir = SAVE_PATH
            remove_all_file(save_dir)
            ckpt_path = (save_dir +
                         f'/{epoch + 1}_epoch_{best_score * 100.0:.2f}%.pth')
            print(f'Save checkpoints {ckpt_path}...')
            torch.save(model.state_dict(), ckpt_path)

        wandb.log({
            'Validation Loss average': loss_values.avg,
            'ROC AUC Score': roc_auc,
            # 'F1 Score': f1
        })

        tqdm.write(f'Epoch : [{epoch + 1}/{num_epochs}] || '
                   f'Val Loss : {loss_values.avg:.4f} || '
                   f'ROC AUC score : {roc_auc:.4f} ||')
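# ---------------------------------------------------------------------------
# `AverageMeter` and `get_learning_rate` are project-local utilities used by
# train() above but not shown in this file. The sketches below are minimal
# assumed implementations of their conventional forms (matching the `.val`,
# `.avg`, `.update(val, n)`, and `[0]`-indexing usage above), not the repo's
# actual code.
# ---------------------------------------------------------------------------
class AverageMeter:
    """Track the latest value and the running (sample-weighted) average."""

    def __init__(self):
        self.val = 0.
        self.sum = 0.
        self.count = 0
        self.avg = 0.

    def update(self, val, n=1):
        # `val` is a per-batch mean; weight it by the batch size `n`.
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


def get_learning_rate(optimizer):
    # One entry per parameter group, matching the `[0]` indexing above.
    return [group['lr'] for group in optimizer.param_groups]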