# Example no. 1 (score: 0)
def eval_model(test_data, models, args):
    '''
        Evaluate the model(s) on the test set and return the collected
        test statistics (loss, accuracy, AUC, etc.).
    '''
    # Normalize to a dict of models so downstream code sees one shape.
    if not isinstance(models, dict):
        models = {'model': models}
    if args.cuda:
        models['model'] = models['model'].cuda()

    test_stats = init_metrics_dictionary(modes=['test'])

    # Effective per-step batch size after gradient-accumulation splitting.
    effective_batch = args.batch_size // args.batch_splits
    loader = torch.utils.data.DataLoader(
        test_data,
        batch_size=effective_batch,
        shuffle=False,
        num_workers=args.num_workers,
        collate_fn=ignore_None_collate,
        pin_memory=True,
        drop_last=False)

    # Single pass over the test set; no training, no optimizer.
    (loss, accuracy, confusion_matrix, golds, preds, probs, auc, exams,
     meta_loss, reg_loss, precision, recall, f1) = run_epoch(
        loader,
        train_model=False,
        truncate_epoch=False,
        models=models,
        optimizers=None,
        args=args)

    log_statement, test_stats = compute_eval_metrics(
        args, loss, accuracy, confusion_matrix, golds, preds, probs, auc,
        exams, meta_loss, reg_loss, precision, recall, f1, test_stats, 'test')
    print(log_statement)

    return test_stats
# Example no. 2 (score: 0)
def get_train_variables(args, model):
    '''
        Given args, and whether or not resuming training, return the
        relevant training variables.

        returns:
        - start_epoch: index of the first epoch to run
        - epoch_stats: dict summarizing epoch-by-epoch results
        - state_keeper: object responsible for saving and restoring training state
        - batch_size: sampling batch_size
        - models: dict of models
        - optimizers: dict of optimizers, one for each model
        - tuning_key: name of the epoch_stats key that controls the learning rate
        - num_epoch_sans_improvement: epochs since the last dev improvement, as measured by tuning_key
        - num_epoch_since_reducing_lr: epochs since the last lr reduction
        - no_tuning_on_dev: True when training does not adapt based on dev performance
    '''
    # Resume from a saved epoch index when one is present on args.
    start_epoch = 1 if args.current_epoch is None else args.current_epoch
    if args.lr is None:
        args.lr = args.init_lr
    epoch_stats = (args.epoch_stats
                   if args.epoch_stats is not None
                   else init_metrics_dictionary(modes=['train', 'dev']))

    state_keeper = state.StateKeeper(args)
    batch_size = args.batch_size // args.batch_splits

    # Wrap a bare model in a dict so every code path shares one shape.
    models = model if isinstance(model, dict) else {'model': model}

    # Build one optimizer per model, moving each model to GPU first
    # when requested (nn.Module.cuda() converts in place).
    optimizers = {}
    for name, net in models.items():
        if args.cuda:
            net = net.cuda()
        optimizers[name] = model_factory.get_optimizer(net, args)

    # When resuming, restore each optimizer's saved state.
    if args.optimizer_state is not None:
        for opt_name in args.optimizer_state:
            saved = args.optimizer_state[opt_name]
            optimizers[opt_name] = state_keeper.load_optimizer(
                optimizers[opt_name], saved)

    num_epoch_sans_improvement = 0
    num_epoch_since_reducing_lr = 0

    no_tuning_on_dev = args.no_tuning_on_dev or args.ten_fold_cross_val

    tuning_key = "dev_{}".format(args.tuning_metric)

    return (start_epoch, epoch_stats, state_keeper, batch_size, models,
            optimizers, tuning_key, num_epoch_sans_improvement,
            num_epoch_since_reducing_lr, no_tuning_on_dev)
# Example no. 3 (score: 0)
def compute_threshold_and_dev_stats(dev_data, models, args):
    '''
    Run the model(s) over the dev set, derive a decision threshold from
    the dev probabilities, store it on args, and return the dev stats.
    '''
    # Normalize to a dict of models so downstream code sees one shape.
    if not isinstance(models, dict):
        models = {'model': models}
    if args.cuda:
        models['model'] = models['model'].cuda()

    dev_stats = init_metrics_dictionary(modes=['dev'])

    # Effective per-step batch size after gradient-accumulation splitting.
    effective_batch = args.batch_size // args.batch_splits
    loader = torch.utils.data.DataLoader(
        dev_data,
        batch_size=effective_batch,
        shuffle=False,
        num_workers=args.num_workers,
        collate_fn=ignore_None_collate,
        pin_memory=True,
        drop_last=False)

    # Single evaluation pass over the dev set.
    (loss, accuracy, confusion_matrix, golds, preds, probs, auc, exams,
     reg_loss, precision, recall, f1) = run_epoch(
        loader,
        train_model=False,
        truncate_epoch=False,
        models=models,
        optimizers=None,
        args=args)

    human_preds = get_human_preds(exams, dev_data.metadata_json)

    # Choose a threshold (with a confidence interval) calibrated against
    # the human predictions for these exams.
    threshold, (th_lb, th_ub) = stats.get_thresholds_interval(
        probs,
        golds,
        human_preds,
        rebalance_eval_cancers=args.rebalance_eval_cancers)
    args.threshold = threshold
    print(' Dev Threshold: {:.8f} ({:.8f} - {:.8f})'.format(
        threshold, th_lb, th_ub))

    log_statement, dev_stats = compute_eval_metrics(
        args, loss, accuracy, confusion_matrix, golds, preds, probs, auc,
        exams, reg_loss, precision, recall, f1, dev_stats, 'dev')
    print(log_statement)
    return dev_stats
# Example no. 4 (score: 0)
def compute_threshold_and_dev_stats(dev_data, models, args):
    '''
    Run the model(s) over the dev set and, for 1-year detection/risk
    datasets outside the survival-analysis setup, derive a decision
    threshold from the dev probabilities; return the dev stats.
    '''
    # Normalize to a dict of models so downstream code sees one shape.
    if not isinstance(models, dict):
        models = {'model': models}
    models['model'] = models['model'].to(args.device)

    dev_stats = init_metrics_dictionary(modes=['dev'])

    # Effective per-step batch size after gradient-accumulation splitting.
    effective_batch = args.batch_size // args.batch_splits
    loader = torch.utils.data.DataLoader(
        dev_data,
        batch_size=effective_batch,
        shuffle=False,
        num_workers=args.num_workers,
        collate_fn=ignore_None_collate,
        pin_memory=True,
        drop_last=False)

    # Single evaluation pass over the dev set.
    (loss, golds, preds, probs, exams, reg_loss, censor_times,
     adv_loss) = run_epoch(
        loader,
        train_model=False,
        truncate_epoch=False,
        models=models,
        optimizers=None,
        args=args)

    # Thresholding only makes sense for binary 1-year detection/risk
    # tasks, not for the survival-analysis setup.
    is_one_year_task = (('detection' in args.dataset
                         or 'risk' in args.dataset)
                        and '1year' in args.dataset
                        and not args.survival_analysis_setup)
    if is_one_year_task:
        human_preds = get_human_preds(exams, dev_data.metadata_json)
        # Choose a threshold (with a confidence interval) calibrated
        # against the human predictions for these exams.
        threshold, (th_lb, th_ub) = stats.get_thresholds_interval(
            probs, golds, human_preds,
            rebalance_eval_cancers=args.rebalance_eval_cancers)
        args.threshold = threshold
        print(' Dev Threshold: {:.8f} ({:.8f} - {:.8f})'.format(
            threshold, th_lb, th_ub))
    else:
        args.threshold = None

    log_statement, dev_stats = compute_eval_metrics(
        args, loss,
        golds, preds, probs, exams,
        reg_loss, censor_times, adv_loss, dev_stats, 'dev')
    print(log_statement)
    return dev_stats