import os

import torch

# `dataset`, `model`, `utils`, and `evaluate_model` are project-local
# modules/functions defined elsewhere in this repository.


def main(args):
    """Main function for the testing pipeline

    :args: commandline arguments
    :returns: None
    """
    ##########################################################################
    #                            Basic settings                              #
    ##########################################################################
    exp_dir = 'experiments'
    model_dir = os.path.join(exp_dir, 'models')
    model_file = os.path.join(model_dir, 'best.pth')

    val_dataset = dataset.NCovDataset('data/', stage='val')
    val_loader = torch.utils.data.DataLoader(
        val_dataset, batch_size=1, shuffle=False, num_workers=11,
        drop_last=False)

    cov_net = model.COVNet(n_classes=args.n_classes)
    if torch.cuda.is_available():
        cov_net.cuda()

    # best.pth stores the full serialized model (saved via torch.save(model, ...)),
    # so load it and copy its parameters into the freshly built network.
    state = torch.load(model_file)
    cov_net.load_state_dict(state.state_dict())

    with torch.no_grad():
        val_loss, metric_collects = evaluate_model(cov_net, val_loader)

    prefix = '******Evaluate******'
    utils.print_progress(mean_loss=val_loss, metric_collects=metric_collects,
                         prefix=prefix)
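# ---------------------------------------------------------------------------
# A minimal sketch of a CLI entry point that could drive this testing
# pipeline. The parser below is an assumption added for illustration (the
# flag name and default are not from the original file); only
# `args.n_classes` is actually consumed by main() above.
# ---------------------------------------------------------------------------
import argparse

if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description='Evaluate COVNet on the validation split')
    parser.add_argument('--n_classes', type=int, default=3,
                        help='number of output classes of COVNet')
    # e.g. python <this script> --n_classes 3
    main(parser.parse_args())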
import os
import time

import torch
import torch.optim as optim
# In older setups this may instead be `from tensorboardX import SummaryWriter`.
from torch.utils.tensorboard import SummaryWriter

# `dataset`, `model`, `utils`, `train_model`, `evaluate_model`, and `get_lr`
# are project-local modules/functions defined elsewhere in this repository.


def main(args):
    """Main function for the training pipeline

    :args: commandline arguments
    :returns: None
    """
    ##########################################################################
    #                            Basic settings                              #
    ##########################################################################
    exp_dir = 'experiments'
    log_dir = os.path.join(exp_dir, 'logs')
    model_dir = os.path.join(exp_dir, 'models')
    os.makedirs(model_dir, exist_ok=True)

    ##########################################################################
    # Define all the necessary variables for model training and evaluation  #
    ##########################################################################
    writer = SummaryWriter(log_dir)
    train_dataset = dataset.NCovDataset('data/', stage='train')
    # Oversample under-represented classes so each class is drawn with
    # roughly equal probability during training.
    weights = train_dataset.make_weights_for_balanced_classes()
    weights = torch.DoubleTensor(weights)
    sampler = torch.utils.data.sampler.WeightedRandomSampler(
        weights, len(train_dataset.case_ids))

    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=1, num_workers=20, drop_last=False,
        sampler=sampler)

    val_dataset = dataset.NCovDataset('data/', stage='val')
    val_loader = torch.utils.data.DataLoader(
        val_dataset, batch_size=1, shuffle=False, num_workers=11,
        drop_last=False)

    cov_net = model.COVNet(n_classes=3)
    if torch.cuda.is_available():
        cov_net = cov_net.cuda()

    optimizer = optim.Adam(cov_net.parameters(), lr=args.lr, weight_decay=0.1)

    if args.lr_scheduler == "plateau":
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, patience=3, factor=0.3, threshold=1e-4, verbose=True)
    elif args.lr_scheduler == "step":
        scheduler = torch.optim.lr_scheduler.StepLR(
            optimizer, step_size=3, gamma=args.gamma)

    best_val_loss = float('inf')
    best_val_accu = 0.0
    iteration_change_loss = 0
    t_start_training = time.time()

    ##########################################################################
    #                          Main training loop                            #
    ##########################################################################
    for epoch in range(args.epochs):
        current_lr = get_lr(optimizer)
        t_start = time.time()

        ############################################################
        # The actual training and validation step for each epoch   #
        ############################################################
        train_loss, train_metric = train_model(
            cov_net, train_loader, epoch, args.epochs, optimizer, writer,
            current_lr, args.log_every)

        with torch.no_grad():
            val_loss, val_metric = evaluate_model(
                cov_net, val_loader, epoch, args.epochs, writer, current_lr)

        ##############################
        #  Adjust the learning rate  #
        ##############################
        if args.lr_scheduler == 'plateau':
            scheduler.step(val_loss)
        elif args.lr_scheduler == 'step':
            scheduler.step()

        t_end = time.time()
        delta = t_end - t_start

        utils.print_epoch_progress(train_loss, val_loss, delta, train_metric,
                                   val_metric)
        iteration_change_loss += 1
        print('-' * 30)

        # Checkpoint every epoch, tagging the file with its train/val accuracy.
        train_acc, val_acc = train_metric['accuracy'], val_metric['accuracy']
        file_name = ('train_acc_{}_val_acc_{}_epoch_{}.pth'.format(
            train_acc, val_acc, epoch))
        torch.save(cov_net, os.path.join(model_dir, file_name))

        if val_acc > best_val_accu:
            best_val_accu = val_acc
            if bool(args.save_model):
                torch.save(cov_net, os.path.join(model_dir, 'best.pth'))

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            iteration_change_loss = 0

        if iteration_change_loss == args.patience:
            print(('Early stopping after {0} epochs without a decrease '
                   'in the val loss').format(iteration_change_loss))
            break

    t_end_training = time.time()
    print('training took {}s'.format(t_end_training - t_start_training))
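# ---------------------------------------------------------------------------
# Two sketches added for illustration, not part of the original file:
# (1) a common implementation of the `get_lr` helper referenced above, which
#     reads the current learning rate from the optimizer's parameter groups;
# (2) a hypothetical argparse entry point covering every `args` attribute the
#     training main() consumes. All defaults below are illustrative guesses.
# ---------------------------------------------------------------------------
import argparse


def get_lr(optimizer):
    # All parameter groups share one learning rate here, so report the first.
    for param_group in optimizer.param_groups:
        return param_group['lr']


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Train COVNet')
    parser.add_argument('--lr', type=float, default=1e-4,
                        help='initial learning rate for Adam')
    parser.add_argument('--lr_scheduler', type=str, default='plateau',
                        choices=['plateau', 'step'])
    parser.add_argument('--gamma', type=float, default=0.5,
                        help='decay factor for the step scheduler')
    parser.add_argument('--epochs', type=int, default=50)
    parser.add_argument('--log_every', type=int, default=10,
                        help='log training progress every N batches')
    parser.add_argument('--patience', type=int, default=5,
                        help='epochs without val-loss improvement before '
                             'early stopping')
    parser.add_argument('--save_model', type=int, default=1,
                        help='1 to checkpoint the best model as best.pth')
    main(parser.parse_args())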