Esempio n. 1
0
def eval_model(test_data, models, args):
    '''
        Evaluate the model on the test split and return the test stats dictionary
        (filled by compute_eval_metrics under the 'test' prefix: loss, accuracy, etc.).

        `models` may be a single model or a dict of models; a bare model is wrapped as
        {'model': models}. When a dict is passed, its 'model' entry is replaced in place
        with the CUDA copy if args.cuda is set. The DataLoader batch size is
        args.batch_size // args.batch_splits; data is not shuffled and no batch is dropped.
    '''
    model_dict = models if isinstance(models, dict) else {'model': models}
    if args.cuda:
        model_dict['model'] = model_dict['model'].cuda()

    loader_batch_size = args.batch_size // args.batch_splits
    test_stats = init_metrics_dictionary(modes=['test'])
    test_loader = torch.utils.data.DataLoader(
        test_data,
        batch_size=loader_batch_size,
        shuffle=False,
        num_workers=args.num_workers,
        collate_fn=ignore_None_collate,
        pin_memory=True,
        drop_last=False,
    )

    # Single evaluation pass: no optimizer and no epoch truncation.
    (epoch_loss, epoch_accuracy, epoch_confusion, epoch_golds, epoch_preds,
     epoch_probs, epoch_auc, epoch_exams, epoch_meta_loss, epoch_reg_loss,
     epoch_precision, epoch_recall, epoch_f1) = run_epoch(
        test_loader,
        train_model=False,
        truncate_epoch=False,
        models=model_dict,
        optimizers=None,
        args=args,
    )

    log_statement, test_stats = compute_eval_metrics(
        args,
        epoch_loss, epoch_accuracy, epoch_confusion, epoch_golds, epoch_preds,
        epoch_probs, epoch_auc, epoch_exams, epoch_meta_loss, epoch_reg_loss,
        epoch_precision, epoch_recall, epoch_f1,
        test_stats, 'test',
    )
    print(log_statement)
    return test_stats
Esempio n. 2
0
def compute_threshold_and_dev_stats(dev_data, models, args):
    '''
    Run the model over the dev split, pick a decision threshold from the dev
    predictions, store it in args.threshold, and return the dev stats dictionary.

    The threshold and its (lower, upper) interval come from
    stats.get_thresholds_interval, fed with the dev probabilities, gold labels and
    the human predictions looked up from dev_data.metadata_json. A bare model is
    wrapped as {'model': models}; a passed-in dict has its 'model' entry replaced
    in place with the CUDA copy when args.cuda is set.
    '''
    model_dict = models if isinstance(models, dict) else {'model': models}
    if args.cuda:
        model_dict['model'] = model_dict['model'].cuda()

    dev_stats = init_metrics_dictionary(modes=['dev'])

    dev_loader = torch.utils.data.DataLoader(
        dev_data,
        batch_size=args.batch_size // args.batch_splits,
        shuffle=False,
        num_workers=args.num_workers,
        collate_fn=ignore_None_collate,
        pin_memory=True,
        drop_last=False,
    )
    (epoch_loss, epoch_accuracy, epoch_confusion, epoch_golds, epoch_preds,
     epoch_probs, epoch_auc, epoch_exams, epoch_reg_loss, epoch_precision,
     epoch_recall, epoch_f1) = run_epoch(
        dev_loader,
        train_model=False,
        truncate_epoch=False,
        models=model_dict,
        optimizers=None,
        args=args,
    )

    human_preds = get_human_preds(epoch_exams, dev_data.metadata_json)

    threshold, threshold_bounds = stats.get_thresholds_interval(
        epoch_probs, epoch_golds, human_preds,
        rebalance_eval_cancers=args.rebalance_eval_cancers,
    )
    lower_bound, upper_bound = threshold_bounds
    # Later evaluation code reads the chosen threshold from args.
    args.threshold = threshold
    print(' Dev Threshold: {:.8f} ({:.8f} - {:.8f})'.format(
        threshold, lower_bound, upper_bound))

    log_statement, dev_stats = compute_eval_metrics(
        args,
        epoch_loss, epoch_accuracy, epoch_confusion, epoch_golds, epoch_preds,
        epoch_probs, epoch_auc, epoch_exams, epoch_reg_loss, epoch_precision,
        epoch_recall, epoch_f1,
        dev_stats, 'dev',
    )
    print(log_statement)
    return dev_stats
Esempio n. 3
0
def compute_threshold_and_dev_stats(dev_data, models, args):
    '''
    Run the model over the dev split, set args.threshold, and return the dev stats.

    A threshold is only computed for datasets whose name contains '1year' and
    either 'detection' or 'risk', and only when args.survival_analysis_setup is
    off. In that case it comes from stats.get_thresholds_interval, using the dev
    probabilities, gold labels and human predictions. In every other case
    args.threshold is set to None. A bare model is wrapped as {'model': models}
    and moved to args.device; a passed-in dict has its 'model' entry replaced in place.
    '''
    model_dict = models if isinstance(models, dict) else {'model': models}
    model_dict['model'] = model_dict['model'].to(args.device)

    dev_stats = init_metrics_dictionary(modes=['dev'])

    dev_loader = torch.utils.data.DataLoader(
        dev_data,
        batch_size=args.batch_size // args.batch_splits,
        shuffle=False,
        num_workers=args.num_workers,
        collate_fn=ignore_None_collate,
        pin_memory=True,
        drop_last=False,
    )
    (epoch_loss, epoch_golds, epoch_preds, epoch_probs, epoch_exams,
     epoch_reg_loss, epoch_censor_times, epoch_adv_loss) = run_epoch(
        dev_loader,
        train_model=False,
        truncate_epoch=False,
        models=model_dict,
        optimizers=None,
        args=args,
    )

    dataset_name = args.dataset
    needs_threshold = (
        any(tag in dataset_name for tag in ('detection', 'risk'))
        and '1year' in dataset_name
        and not args.survival_analysis_setup
    )

    if not needs_threshold:
        args.threshold = None
    else:
        human_preds = get_human_preds(epoch_exams, dev_data.metadata_json)
        threshold, threshold_bounds = stats.get_thresholds_interval(
            epoch_probs, epoch_golds, human_preds,
            rebalance_eval_cancers=args.rebalance_eval_cancers,
        )
        lower_bound, upper_bound = threshold_bounds
        args.threshold = threshold
        print(' Dev Threshold: {:.8f} ({:.8f} - {:.8f})'.format(
            threshold, lower_bound, upper_bound))

    log_statement, dev_stats = compute_eval_metrics(
        args, epoch_loss,
        epoch_golds, epoch_preds, epoch_probs, epoch_exams,
        epoch_reg_loss, epoch_censor_times, epoch_adv_loss,
        dev_stats, 'dev',
    )
    print(log_statement)
    return dev_stats
Esempio n. 4
0
def train_model(train_data, dev_data, model, args):
    '''
        Train model and tune on dev set.

        The learning rate of every optimizer is multiplied by args.lr_decay, and
        training continues, when either:
          - the tuning metric hasn't improved for args.patience epochs, or
          - (when not tuning on dev) args.lr_reduction_interval epochs have passed
            since the last reduction.
        Unless args.turn_off_model_reset is set, the models and optimizer states are
        first restored via state_keeper.load().

        A checkpoint is saved whenever the tuning metric improves (ties count), or every
        epoch when not tuning on dev. At the end of training, the models are restored
        from the last saved checkpoint: the best dev version, or the last epoch when not
        tuning on dev.

        returns epoch_stats: a dictionary of epoch level metrics for train and dev
        returns models : dict of models, containing best performing model setting from this call to train
    '''
    (start_epoch, epoch_stats, state_keeper, batch_size, models, optimizers, tuning_key,
     num_epoch_sans_improvement, num_epoch_since_reducing_lr,
     no_tuning_on_dev) = get_train_variables(args, model)

    train_data_loader, dev_data_loader = get_train_and_dev_dataset_loaders(
        args,
        train_data,
        dev_data,
        batch_size)

    for epoch in range(start_epoch, args.epochs + 1):

        print("-------------\nEpoch {}:\n".format(epoch))

        for mode, data_loader in [('Train', train_data_loader), ('Dev', dev_data_loader)]:
            # Named is_training (not train_model) so the enclosing function isn't shadowed.
            is_training = mode == 'Train'
            key_prefix = mode.lower()
            loss, golds, preds, probs, exams, reg_loss, censor_times, adv_loss = run_epoch(
                data_loader,
                train_model=is_training,
                truncate_epoch=True,
                models=models,
                optimizers=optimizers,
                args=args)

            log_statement, epoch_stats = compute_eval_metrics(
                args, loss, golds, preds, probs, exams, reg_loss, censor_times, adv_loss,
                epoch_stats, key_prefix)

            if mode == 'Dev' and 'mammo_1year' in args.dataset:
                # Extra dev-only metrics: fnr/tnr/tpr at a threshold chosen on this epoch's
                # dev predictions. NUM_RESAMPLES_DURING_TRAIN is presumably a reduced
                # resample count to keep per-epoch cost down (confirm against its definition).
                dev_human_preds = get_human_preds(exams, dev_data.metadata_json)
                threshold, _ = stats.get_thresholds_interval(
                    probs, golds, dev_human_preds,
                    rebalance_eval_cancers=args.rebalance_eval_cancers,
                    num_resamples=NUM_RESAMPLES_DURING_TRAIN)
                print(' Dev Threshold: {:.8f} '.format(threshold))
                (fnr, _), (tpr, _), (tnr, _) = stats.get_rates_intervals(
                    probs, golds, threshold,
                    rebalance_eval_cancers=args.rebalance_eval_cancers,
                    num_resamples=NUM_RESAMPLES_DURING_TRAIN)
                epoch_stats['{}_fnr'.format(key_prefix)].append(fnr)
                epoch_stats['{}_tnr'.format(key_prefix)].append(tnr)
                epoch_stats['{}_tpr'.format(key_prefix)].append(tpr)
                log_statement = "{} fnr: {:.3f} tnr: {:.3f} tpr: {:.3f}".format(
                    log_statement, fnr, tnr, tpr)

            print(log_statement)

        # Save model if it beats best dev, or if not tuning on dev.
        # dev_loss is minimized; every other tuning metric is maximized.
        best_func, arg_best = (min, np.argmin) if tuning_key == 'dev_loss' else (max, np.argmax)
        # A tie with the previous best also counts as an improvement.
        improved = best_func(epoch_stats[tuning_key]) == epoch_stats[tuning_key][-1]
        if improved or no_tuning_on_dev:
            num_epoch_sans_improvement = 0
            os.makedirs(args.save_dir, exist_ok=True)
            epoch_stats['best_epoch'] = arg_best(epoch_stats[tuning_key])
            state_keeper.save(models, optimizers, epoch, args.lr, epoch_stats)

        num_epoch_since_reducing_lr += 1
        if improved:
            num_epoch_sans_improvement = 0
        else:
            num_epoch_sans_improvement += 1
        print('---- Best Dev {} is {} at epoch {}'.format(
            args.tuning_metric,
            epoch_stats[tuning_key][epoch_stats['best_epoch']],
            epoch_stats['best_epoch'] + 1))

        if num_epoch_sans_improvement >= args.patience or \
                (no_tuning_on_dev and num_epoch_since_reducing_lr >= args.lr_reduction_interval):
            print("Reducing learning rate")
            num_epoch_sans_improvement = 0
            num_epoch_since_reducing_lr = 0
            if not args.turn_off_model_reset:
                # Roll models and optimizer states back to the last saved checkpoint.
                models, optimizer_states, _, _, _ = state_keeper.load()
                for name, optimizer in optimizers.items():
                    optimizers[name] = state_keeper.load_optimizer(
                        optimizer, optimizer_states[name])
            # Reduce LR
            for optimizer in optimizers.values():
                for param_group in optimizer.param_groups:
                    param_group['lr'] *= args.lr_decay

            # Keep args.lr in sync with the optimizers for resumable usage (it is the LR that
            # state_keeper.save records). Use the same decay factor, not a hard-coded 0.5,
            # so the two cannot drift apart when lr_decay != 0.5.
            args.lr *= args.lr_decay

    # Restore model to best dev performance, or last epoch when not tuning on dev
    models, _, _, _, _ = state_keeper.load()

    return epoch_stats, models