Example #1
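# Third-party imports used below; helpers such as get_model, CombinedLoss,
# evaluate_model, create_train_and_test_data_loaders, and logger are defined
# elsewhere in the module.
import os

import torch
import torchio
import wandb
from sklearn.metrics import confusion_matrix
from torch.utils.tensorboard import SummaryWriter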
def train_and_save_model(df, count_train, save_path, num_epochs, val_interval,
                         only_evaluate):
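    """Train the model, checkpointing the best validation result to save_path.

    When only_evaluate is set, the saved model is scored on the validation
    and training sets and no training takes place.
    """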
    train_loader, val_loader, class_weights, sizes = create_train_and_test_data_loaders(
        df, count_train)

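    # Resume the saved model when only evaluating; otherwise start from local
    # pretrained weights if they exist, else build a fresh model.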
    pretrained_path = os.path.join(os.getcwd(), 'pretrained.pth')
    if os.path.exists(save_path) and only_evaluate:
        model = get_model(save_path)
    elif os.path.exists(pretrained_path):
        model = get_model(pretrained_path)
    else:
        model = get_model()

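    # Class-weighted loss (weights come from the data loaders above); the
    # learning rate is recorded in wandb.config so runs stay comparable.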
    loss_function = CombinedLoss(class_weights)
    wandb.config.learning_rate = 9e-5
    optimizer = torch.optim.AdamW(model.parameters(),
                                  wandb.config.learning_rate)
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.90)
    wandb.watch(model)

    # start a typical PyTorch training
    best_metric = float('-inf')
    best_metric_epoch = -1
    writer = SummaryWriter(log_dir=wandb.run.dir)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)  # no-op if get_model already placed it there

    if only_evaluate:
        logger.info('Evaluating NN model on validation data')
        evaluate_model(model, val_loader, device, writer, 0, 'val')
        logger.info('Evaluating NN model on training data')
        evaluate_model(model, train_loader, device, writer, 0, 'train')
        writer.close()
        return sizes

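    # Keep only the file name so checkpoints can also be mirrored into the
    # wandb run directory below.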
    _, file_name = os.path.split(save_path)

    for epoch in range(num_epochs):
        logger.info('-' * 25)
        logger.info(f'epoch {epoch + 1}/{num_epochs}')
        model.train()
        epoch_loss = 0
        step = 0
        epoch_len = len(train_loader.dataset) // train_loader.batch_size
        logger.info(f'epoch_len: {epoch_len}')
        y_true = []
        y_pred = []

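        # Accumulate rounded predictions alongside the loss so a per-epoch
        # confusion matrix can be reported.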
        for batch_data in train_loader:
            step += 1
            inputs = batch_data['img'][torchio.DATA].to(device)
            info = batch_data['info'].to(device)
            optimizer.zero_grad()
            outputs = model(inputs)

            y_true.extend(info[..., 0].cpu().tolist())
            # Round and clamp raw outputs to the 0-10 rating range
            y = [max(0, min(int(round(v)), 10))
                 for v in outputs[..., 0].cpu().tolist()]
            y_pred.extend(y)

            loss = loss_function(outputs, info)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
            logger.debug(f'{step}:{loss.item():.4f}')
            print('.', end='', flush=True)
            if step % 100 == 0:
                print('', flush=True)  # new line
            writer.add_scalar('train_loss', loss.item(),
                              epoch_len * epoch + step)
            wandb.log({'train_loss': loss.item()})
        epoch_loss /= step
        logger.info(f'\nepoch {epoch + 1} average loss: {epoch_loss:.4f}')
        wandb.log({'epoch average loss': epoch_loss})
        epoch_cm = confusion_matrix(y_true, y_pred)
        logger.info(f'confusion matrix:\n{epoch_cm}')
        wandb.log({'confusion matrix': epoch_cm})

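        # Every val_interval epochs: validate, checkpoint on improvement, and
        # decay the learning rate.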
        if (epoch + 1) % val_interval == 0:
            logger.info('Evaluating on validation set')
            metric = evaluate_model(model, val_loader, device, writer, epoch,
                                    'val')

            if metric >= best_metric:
                best_metric = metric
                best_metric_epoch = epoch + 1
                torch.save(model.state_dict(), save_path)

                torch.save(model.state_dict(),
                           os.path.join(wandb.run.dir, file_name))
                logger.info(f'saved new best metric model as {save_path}')

            logger.info(
                f'current epoch: {epoch + 1} current metric: {metric:.2f} '
                f'best metric: {best_metric:.2f} at epoch {best_metric_epoch}')

            scheduler.step()
            logger.info(
                f'Learning rate after epoch {epoch + 1}: {optimizer.param_groups[0]["lr"]}'
            )
            wandb.log({'learn_rate': optimizer.param_groups[0]['lr']})

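    # Save a final snapshot regardless of the validation metric, tagged with
    # the epoch count.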
    epoch_suffix = '.epoch' + str(num_epochs)
    torch.save(model.state_dict(), save_path + epoch_suffix)
    torch.save(model.state_dict(),
               os.path.join(wandb.run.dir, file_name + epoch_suffix))

    logger.info(
        f'train completed, best_metric: {best_metric:.2f} at epoch: {best_metric_epoch}'
    )
    writer.close()
    return sizes
Example #2
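# Fragment of the CLI entry point; parser, logger, and helpers such as
# process_folds, evaluate1, and read_and_normalize_data_frame are defined
# earlier in the module.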
    args = parser.parse_args()
    logger.info(args)

    monai.config.print_config()

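    # Dispatch on whichever mutually exclusive CLI mode was requested.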
    if args.all:
        logger.info(f'Training {args.nfolds} folds')
        for f in range(args.nfolds):
            process_folds(args.folds, f, False, args.nfolds)
        # evaluate all at the end, so results are easy to pick up from the log
        for f in range(args.nfolds):
            process_folds(args.folds, f, True, args.nfolds)
    elif args.folds is not None:
        process_folds(args.folds, args.vfold, args.evaluate, args.nfolds)
    elif args.modelfile is not None and args.evaluate1 is not None:
        evaluate1(get_model(args.modelfile), args.evaluate1)
    elif args.predicthd is not None:
        predict_hd_data_root = args.predicthd
        df = read_and_normalize_data_frame(
            os.path.join(predict_hd_data_root, 'phenotype',
                         'bids_image_qc_information.tsv'))
        logger.info(df)
        full_path = Path('bids_image_qc_information-customized.csv').absolute()
        df.to_csv(full_path, index=False)
        logger.info(f'CSV file written: {full_path}')
    elif args.ncanda is not None:
        logger.info('Adding support for NCANDA data is a TODO')
    else:
        logger.info('Not enough arguments specified')
        logger.info(parser.format_help())