def train_and_save_model(df, count_train, save_path, num_epochs, val_interval, only_evaluate):
    train_loader, val_loader, class_weights, sizes = create_train_and_test_data_loaders(
        df, count_train)

    # Resume from the saved model when only evaluating; otherwise prefer a
    # pretrained checkpoint in the working directory, falling back to a
    # freshly initialized model.
    pretrained_path = os.path.join(os.getcwd(), 'pretrained.pth')
    if os.path.exists(save_path) and only_evaluate:
        model = get_model(save_path)
    elif os.path.exists(pretrained_path):
        model = get_model(pretrained_path)
    else:
        model = get_model()

    loss_function = CombinedLoss(class_weights)
    wandb.config.learning_rate = 9e-5
    optimizer = torch.optim.AdamW(model.parameters(), wandb.config.learning_rate)
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.90)
    wandb.watch(model)

    # start a typical PyTorch training
    best_metric = float('-inf')
    best_metric_epoch = -1
    writer = SummaryWriter(log_dir=wandb.run.dir)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    if only_evaluate:
        logger.info('Evaluating NN model on validation data')
        evaluate_model(model, val_loader, device, writer, 0, 'val')
        logger.info('Evaluating NN model on training data')
        evaluate_model(model, train_loader, device, writer, 0, 'train')
        return sizes

    _, file_name = os.path.split(save_path)
    for epoch in range(num_epochs):
        logger.info('-' * 25)
        logger.info(f'epoch {epoch + 1}/{num_epochs}')
        model.train()
        epoch_loss = 0
        step = 0
        epoch_len = len(train_loader.dataset) // train_loader.batch_size
        logger.info(f'epoch_len: {epoch_len}')
        y_true = []
        y_pred = []
        for batch_data in train_loader:
            step += 1
            inputs = batch_data['img'][torchio.DATA].to(device)
            info = batch_data['info'].to(device)
            optimizer.zero_grad()
            outputs = model(inputs)

            # Track rounded predictions for the epoch-level confusion matrix.
            y_true.extend(info[..., 0].cpu().tolist())
            y = outputs[..., 0].detach().cpu().tolist()
            y = [max(0, min(int(round(v)), 10)) for v in y]  # round and clamp to 0 - 10 range
            y_pred.extend(y)

            loss = loss_function(outputs, info)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
            logger.debug(f'{step}:{loss.item():.4f}')
            print('.', end='', flush=True)
            if step % 100 == 0:
                print('', flush=True)  # new line
            writer.add_scalar('train_loss', loss.item(), epoch_len * epoch + step)
            wandb.log({'train_loss': loss.item()})

        epoch_loss /= step
        logger.info(f'\nepoch {epoch + 1} average loss: {epoch_loss:.4f}')
        wandb.log({'epoch average loss': epoch_loss})
        epoch_cm = confusion_matrix(y_true, y_pred)
        logger.info(f'confusion matrix:\n{epoch_cm}')
        wandb.log({'confusion matrix': epoch_cm})

        # Periodically evaluate on the validation set and checkpoint the best model.
        if (epoch + 1) % val_interval == 0:
            logger.info('Evaluating on validation set')
            metric = evaluate_model(model, val_loader, device, writer, epoch, 'val')
            if metric >= best_metric:
                best_metric = metric
                best_metric_epoch = epoch + 1
                torch.save(model.state_dict(), save_path)
                torch.save(model.state_dict(), os.path.join(wandb.run.dir, file_name))
                logger.info(f'saved new best metric model as {save_path}')
            logger.info(
                f'current epoch: {epoch + 1} current metric: {metric:.2f} '
                f'best metric: {best_metric:.2f} at epoch {best_metric_epoch}')

        scheduler.step()
        logger.info(
            f'Learning rate after epoch {epoch + 1}: {optimizer.param_groups[0]["lr"]}')
        wandb.log({'learn_rate': optimizer.param_groups[0]['lr']})

    # Also save the final-epoch weights alongside the best-metric checkpoint.
    epoch_suffix = '.epoch' + str(num_epochs)
    torch.save(model.state_dict(), save_path + epoch_suffix)
    torch.save(model.state_dict(), os.path.join(wandb.run.dir, file_name + epoch_suffix))
    logger.info(
        f'train completed, best_metric: {best_metric:.2f} at epoch: {best_metric_epoch}')
    writer.close()
    return sizes
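# Usage sketch (hypothetical values; `df` is the normalized QC data frame
# produced by read_and_normalize_data_frame, and the paths are examples only):
#
#     sizes = train_and_save_model(
#         df,
#         count_train=800,              # rows assigned to the training split
#         save_path='models/best.pth',  # best-metric checkpoint destination
#         num_epochs=50,
#         val_interval=5,               # validate every 5 epochs
#         only_evaluate=False,          # True skips training, only reports metrics
#     )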
args = parser.parse_args()
logger.info(args)
monai.config.print_config()

if args.all:
    # Train every fold first, then evaluate them all afterwards.
    logger.info(f'Training {args.nfolds} folds')
    for f in range(args.nfolds):
        process_folds(args.folds, f, False, args.nfolds)
    # evaluate all at the end, so results are easy to pick up from the log
    for f in range(args.nfolds):
        process_folds(args.folds, f, True, args.nfolds)
elif args.folds is not None:
    process_folds(args.folds, args.vfold, args.evaluate, args.nfolds)
elif args.modelfile is not None and args.evaluate1 is not None:
    # Evaluate a single image with a previously trained model.
    evaluate1(get_model(args.modelfile), args.evaluate1)
elif args.predicthd is not None:
    # Normalize the PredictHD QC spreadsheet and write it out as CSV.
    predict_hd_data_root = args.predicthd
    df = read_and_normalize_data_frame(
        predict_hd_data_root + r'phenotype/bids_image_qc_information.tsv')
    logger.info(df)
    full_path = Path('bids_image_qc_information-customized.csv').absolute()
    df.to_csv(full_path, index=False)
    logger.info(f'CSV file written: {full_path}')
elif args.ncanda is not None:
    logger.info('Adding support for NCANDA data is a TODO')
else:
    logger.info('Not enough arguments specified')
    logger.info(parser.format_help())
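# Example invocations (flag spellings inferred from the argparse attributes
# used above; the actual option names are defined where `parser` is built,
# and the file/directory paths are placeholders):
#
#     python this_script.py --all --folds folds.csv --nfolds 5
#     python this_script.py --folds folds.csv --vfold 0 --nfolds 5 --evaluate
#     python this_script.py --modelfile miqa01.pth --evaluate1 image.nii.gz
#     python this_script.py --predicthd /data/PredictHD/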