Beispiel #1
0
def train_rf_model(
    ht: hl.Table,
    rf_features: List[str],
    tp_expr: hl.expr.BooleanExpression,
    fp_expr: hl.expr.BooleanExpression,
    fp_to_tp: float = 1.0,
    num_trees: int = 500,
    max_depth: int = 5,
    test_expr: hl.expr.BooleanExpression = False,
) -> Tuple[hl.Table, pyspark.ml.PipelineModel]:
    """
    Perform random forest (RF) training using a Table annotated with features and training data.

    .. note::

        This function uses `train_rf` and extends it by:
            - Adding an option to apply the resulting model to test variants which are withheld from training.
            - Uses a false positive (FP) to true positive (TP) ratio to determine what variants to use for RF training.

    The returned Table includes the following annotations:
        - rf_train: indicates if the variant was used for training of the RF model.
        - rf_label: indicates if the variant is a TP or FP.
        - rf_test: indicates if the variant was used in testing of the RF model.
        - features: global annotation of the features used for the RF model.
        - features_importance: global annotation of the importance of each feature in the model.
        - test_results: global annotation with results from testing the model on variants defined by
          `test_expr`. Only present when a test expression was supplied.

    :param ht: Table annotated with features for the RF model and the positive and negative training data.
    :param rf_features: List of column names to use as features in the RF training.
    :param tp_expr: TP training expression.
    :param fp_expr: FP training expression.
    :param fp_to_tp: Ratio of FPs to TPs for creating the RF model. If set to 0, all training examples are used.
    :param num_trees: Number of trees in the RF model.
    :param max_depth: Maximum tree depth in the RF model.
    :param test_expr: An expression specifying variants to hold out for testing and use for evaluation only.
        The default (``False``) or ``None`` means no variants are held out and no testing is performed.
    :return: Table with TP and FP training sets used in the RF training and the resulting RF model.
    """
    # The default ``False`` means "no test set". Treat an explicit ``None`` the
    # same way: Hail cannot annotate a field with an untyped Python None.
    if test_expr is None:
        test_expr = False
    # Previously this was ``test_expr is not None``, which is always true for the
    # default ``False`` and therefore ran testing on an empty table.
    run_test = test_expr is not False

    ht = ht.annotate(_tp=tp_expr, _fp=fp_expr, rf_test=test_expr)

    rf_ht = sample_training_examples(
        ht, tp_expr=ht._tp, fp_expr=ht._fp, fp_to_tp=fp_to_tp, test_expr=ht.rf_test
    )
    # Look up the sampled row once and reuse it for both annotations.
    sampled = rf_ht[ht.key]
    ht = ht.annotate(rf_train=sampled.train, rf_label=sampled.label)

    summary = ht.group_by("_tp", "_fp", "rf_train", "rf_label", "rf_test").aggregate(
        n=hl.agg.count()
    )
    logger.info("Summary of TP/FP and RF training data:")
    summary.show(n=20)

    logger.info(
        "Training RF model:\nfeatures: {}\nnum_tree: {}\nmax_depth:{}".format(
            ",".join(rf_features), num_trees, max_depth
        )
    )

    rf_model = train_rf(
        ht.filter(ht.rf_train),
        features=rf_features,
        label="rf_label",
        num_trees=num_trees,
        max_depth=max_depth,
    )

    test_results = None
    if run_test:
        logger.info("Testing model on specified variants or intervals...")
        test_ht = ht.filter(hl.is_defined(ht.rf_label) & ht.rf_test)
        test_results = test_model(
            test_ht, rf_model, features=rf_features, label="rf_label"
        )

    features_importance = get_features_importance(rf_model)
    global_annotations = {
        "features_importance": features_importance,
        "features": rf_features,
    }
    # A None global cannot be typed by Hail, so only add test_results when testing ran.
    if test_results is not None:
        global_annotations["test_results"] = test_results
    ht = ht.select_globals(**global_annotations)

    return ht.select("rf_train", "rf_label", "rf_test"), rf_model
Beispiel #2
0
def train_rf(data_type, args):
    """
    Train an RF model for ``data_type``, optionally test it, and record the run.

    The model, the annotated training Table, and an entry in the run list are
    written under a freshly generated 8-character run hash.

    :param data_type: Passed through to the path helpers (``rf_annotated_path``,
        ``rf_path``, ``rf_run_hash_path``) and to ``get_rf_runs``/``get_run_data``.
    :param args: Namespace of options. This function reads ``adj``, ``vqsr_training``,
        ``no_transmitted_singletons``, ``test_intervals``, ``fp_to_tp``,
        ``vqsr_features``, ``median_features``, ``num_trees``, ``max_depth``
        and ``overwrite``.
    :return: The run hash under which the model and training Table were saved.
    """
    # Get unique hash for run and load previous runs
    run_hash = str(uuid.uuid4())[:8]
    rf_runs = get_rf_runs(data_type)
    while run_hash in rf_runs:
        run_hash = str(uuid.uuid4())[:8]

    ht = hl.read_table(rf_annotated_path(data_type, args.adj))

    ht = ht.repartition(500, shuffle=False)

    if not args.vqsr_training:
        if args.no_transmitted_singletons:
            tp_expr = ht.omni | ht.mills | ht.info_POSITIVE_TRAIN_SITE
        else:
            tp_expr = ht.omni | ht.mills | ht.info_POSITIVE_TRAIN_SITE | ht.transmitted_singleton

        ht = ht.annotate(tp=tp_expr)

    # ``args.test_intervals`` may be unset, a single interval string, or a list of them.
    if not args.test_intervals:
        test_intervals_str = []
    elif isinstance(args.test_intervals, str):
        test_intervals_str = [args.test_intervals]
    else:
        test_intervals_str = args.test_intervals
    test_intervals_locus = [
        hl.parse_locus_interval(x) for x in test_intervals_str
    ]

    if test_intervals_locus:
        ht = ht.annotate_globals(test_intervals=test_intervals_locus)

    ht = sample_rf_training_examples(
        ht,
        tp_col='info_POSITIVE_TRAIN_SITE' if args.vqsr_training else 'tp',
        fp_col='info_NEGATIVE_TRAIN_SITE'
        if args.vqsr_training else 'fail_hard_filters',
        fp_to_tp=args.fp_to_tp)

    ht = ht.persist()

    rf_ht = ht.filter(ht[TRAIN_COL])
    # Compute the feature list once. Training, testing and the stored "features"
    # global must all use the same list. Previously testing and the global ignored
    # ``args.median_features``, so they did not match the trained model.
    rf_features = get_features_list(
        True, not (args.vqsr_features or args.median_features),
        args.vqsr_features, args.median_features)

    logger.info(
        "Training RF model:\nfeatures: {}\nnum_tree: {}\nmax_depth:{}\nTest intervals: {}"
        .format(",".join(rf_features), args.num_trees, args.max_depth,
                ",".join(test_intervals_str)))

    rf_model = random_forest.train_rf(rf_ht,
                                      features=rf_features,
                                      label=LABEL_COL,
                                      num_trees=args.num_trees,
                                      max_depth=args.max_depth)

    logger.info("Saving RF model")
    random_forest.save_model(rf_model,
                             rf_path(data_type,
                                     data='model',
                                     run_hash=run_hash),
                             overwrite=args.overwrite)

    test_results = None
    if args.test_intervals:
        logger.info("Testing model {} on intervals {}".format(
            run_hash, ",".join(test_intervals_str)))
        test_ht = hl.filter_intervals(ht, test_intervals_locus, keep=True)
        test_ht = test_ht.checkpoint('gs://gnomad-tmp/test_random_forest.ht',
                                     overwrite=True)
        test_ht = test_ht.filter(hl.is_defined(test_ht[LABEL_COL]))
        test_results = random_forest.test_model(
            test_ht,
            rf_model,
            features=rf_features,
            label=LABEL_COL)
        ht = ht.annotate_globals(test_results=test_results)

    logger.info("Writing RF training HT")
    features_importance = random_forest.get_features_importance(rf_model)
    ht = ht.annotate_globals(
        features_importance=features_importance,
        features=rf_features,
        vqsr_training=args.vqsr_training,
        no_transmitted_singletons=args.no_transmitted_singletons,
        adj=args.adj)
    ht.write(rf_path(data_type, data='training', run_hash=run_hash),
             overwrite=args.overwrite)

    logger.info("Adding run to RF run list")
    rf_runs[run_hash] = get_run_data(data_type, args.vqsr_training,
                                     args.no_transmitted_singletons, args.adj,
                                     test_intervals_str, features_importance,
                                     test_results)
    with hl.hadoop_open(rf_run_hash_path(data_type), 'w') as f:
        json.dump(rf_runs, f)

    return run_hash