def step2_export_to_elasticsearch(hc, vds, args):
    """Pipeline step 2: export the dataset (including genotypes) to elasticsearch.

    The step does nothing when ``args.start_with_step`` / ``args.stop_after_step``
    exclude step 2, or when ``args.only_export_to_elasticsearch_at_the_end`` is set.

    Args:
        hc: current hail context; it is stopped and replaced by a new one if
            the dataset has to be re-read.
        vds: dataset handed over by the previous step, or None to read it from
            ``args.step1_output_vds``.
        args: parsed pipeline arguments. ``args.start_with_step`` is set to 3
            after the export call returns.

    Returns:
        The ``(hc, vds)`` pair to pass on to the next step.
    """
    step_is_selected = args.start_with_step <= 2 and args.stop_after_step >= 2
    if not step_is_selected or args.only_export_to_elasticsearch_at_the_end:
        return hc, vds

    logger.info(
        "\n\n=============================== pipeline - step 2 - export to elasticsearch ==============================="
    )

    reload_dataset = vds is None or not args.skip_writing_intermediate_vds
    if reload_dataset:
        # restart the hail context and continue from the dataset stored at args.step1_output_vds
        stop_hail_context(hc)
        hc = create_hail_context()
        vds = read_in_dataset(
            hc,
            args.step1_output_vds,
            dataset_type=args.dataset_type,
            skip_summary=True,
            num_partitions=args.cpu_limit,
        )

    # when VEP consequences are not stored as nested objects, both doc values
    # and indexing are turned off for the sortedTranscriptConsequences field
    if bool(args.use_nested_objects_for_vep):
        fields_without_doc_values_or_index = ()
    else:
        fields_without_doc_values_or_index = ("sortedTranscriptConsequences", )

    def route_to_temp_loading_nodes():
        return route_index_to_temp_es_cluster(True, args)

    export_to_elasticsearch(
        vds,
        args,
        operation=ELASTICSEARCH_UPSERT,
        delete_index_before_exporting=True,
        export_genotypes=True,
        disable_doc_values_for_fields=fields_without_doc_values_or_index,
        disable_index_for_fields=fields_without_doc_values_or_index,
        run_after_index_exists=(route_to_temp_loading_nodes
                                if args.use_temp_loading_nodes else None),
    )

    # step 2 finished, so, if an error occurs and it goes to retry, start with the next step
    args.start_with_step = 3

    return hc, vds
def step3_add_reference_datasets(hc, vds, args):
    """Pipeline step 3: annotate the dataset with reference data.

    For ``args.dataset_type == "VARIANTS"`` the variant-level reference data is
    added either from the single combined reference dataset (when none of the
    corresponding ``args.exclude_*`` flags is set) or one dataset at a time.
    clinvar and hgmd are then added for every dataset type unless
    ``args.skip_annotations`` or the matching exclude flag is set. The result
    is written to ``args.step3_output_vds`` unless running locally or
    intermediate writes are skipped.

    Args:
        hc: current hail context; it is stopped and replaced by a new one if
            the dataset has to be re-read.
        vds: dataset handed over by the previous step, or None to read it from
            ``args.step1_output_vds``.
        args: parsed pipeline arguments. ``args.start_with_step`` is set to 4
            at the end of this step.

    Returns:
        The ``(hc, vds)`` pair to pass on to the next step.
    """
    step_is_selected = args.start_with_step <= 3 and args.stop_after_step >= 3
    if not step_is_selected:
        return hc, vds

    logger.info(
        "\n\n=============================== pipeline - step 3 - add reference datasets ==============================="
    )

    reload_dataset = vds is None or not args.skip_writing_intermediate_vds
    if reload_dataset:
        # restart the hail context and continue from the dataset stored at args.step1_output_vds
        stop_hail_context(hc)
        hc = create_hail_context()
        vds = read_in_dataset(
            hc,
            args.step1_output_vds,
            dataset_type=args.dataset_type,
            skip_summary=True,
        )

    if not args.only_export_to_elasticsearch_at_the_end:
        # NOTE(review): presumably the schema can be reduced here because the
        # full dataset was already exported in step 2 - confirm
        vds = compute_minimal_schema(vds, args.dataset_type)

    def annotate(dataset, log_message, add_to_vds, **extra_kwargs):
        # log, then call one add_*_to_vds function with the arguments they all share
        logger.info(log_message)
        return add_to_vds(hc,
                          dataset,
                          args.genome_version,
                          subset=args.filter_interval,
                          **extra_kwargs)

    # exclude flags of the datasets that are contained in the combined reference data file
    variant_level_exclude_flags = (
        "exclude_dbnsfp",
        "exclude_cadd",
        "exclude_1kg",
        "exclude_exac",
        "exclude_topmed",
        "exclude_mpc",
        "exclude_gnomad",
        "exclude_eigen",
        "exclude_primate_ai",
        "exclude_splice_ai",
    )

    if args.dataset_type == "VARIANTS":
        if not any(getattr(args, flag) for flag in variant_level_exclude_flags):
            # annotate with the combined reference data file which was generated using
            # ../download_and_create_reference_datasets/v01/hail_scripts/combine_all_variant_level_reference_data.py
            # and contains all these annotations in one .vds
            # NOTE(review): this branch does not check args.skip_annotations - confirm that is intended
            vds = annotate(vds,
                           "\n==> add combined variant-level reference data",
                           add_combined_reference_data_to_vds)
        else:
            # at least one dataset is excluded: annotate with each reference data file - one-by-one
            if not (args.skip_annotations or args.exclude_dbnsfp):
                vds = annotate(vds,
                               "\n==> add dbnsfp",
                               add_dbnsfp_to_vds,
                               root="va.dbnsfp")

            if not (args.skip_annotations or args.exclude_cadd):
                vds = annotate(vds,
                               "\n==> add cadd",
                               add_cadd_to_vds,
                               root="va.cadd")

            if not (args.skip_annotations or args.exclude_1kg):
                vds = annotate(vds,
                               "\n==> add 1kg",
                               add_1kg_phase3_to_vds,
                               root="va.g1k")

            if not (args.skip_annotations or args.exclude_exac):
                vds = annotate(vds,
                               "\n==> add exac",
                               add_exac_to_vds,
                               root="va.exac")

            if not (args.skip_annotations or args.exclude_topmed):
                vds = annotate(vds,
                               "\n==> add topmed",
                               add_topmed_to_vds,
                               root="va.topmed")

            if not (args.skip_annotations or args.exclude_mpc):
                vds = annotate(vds,
                               "\n==> add mpc",
                               add_mpc_to_vds,
                               root="va.mpc")

            if not (args.skip_annotations or args.exclude_gnomad):
                # a single exclude flag controls both the exomes and the genomes annotations
                vds = annotate(vds,
                               "\n==> add gnomad exomes",
                               add_gnomad_to_vds,
                               exomes_or_genomes="exomes",
                               root="va.gnomad_exomes")
                vds = annotate(vds,
                               "\n==> add gnomad genomes",
                               add_gnomad_to_vds,
                               exomes_or_genomes="genomes",
                               root="va.gnomad_genomes")

            if not (args.skip_annotations or args.exclude_eigen):
                vds = annotate(vds,
                               "\n==> add eigen",
                               add_eigen_to_vds,
                               root="va.eigen")

            # NOTE(review): unlike the datasets above, primate_ai and splice_ai
            # are not gated by args.skip_annotations - confirm that is intended
            if not args.exclude_primate_ai:
                vds = annotate(vds,
                               "\n==> add primate_ai",
                               add_primate_ai_to_vds,
                               root="va.primate_ai")

            if not args.exclude_splice_ai:
                vds = annotate(vds,
                               "\n==> add splice_ai",
                               add_splice_ai_to_vds,
                               root="va.splice_ai")

    # clinvar and hgmd are added regardless of args.dataset_type
    if not (args.skip_annotations or args.exclude_clinvar):
        vds = annotate(vds,
                       "\n==> add clinvar",
                       add_clinvar_to_vds,
                       root="va.clinvar")

    if not (args.skip_annotations or args.exclude_hgmd):
        vds = annotate(vds, "\n==> add hgmd", add_hgmd_to_vds, root="va.hgmd")

    if not (args.is_running_locally or args.skip_writing_intermediate_vds):
        write_vds(vds, args.step3_output_vds)

    # step 3 finished, so, if an error occurs and it goes to retry, start with the next step
    args.start_with_step = 4

    return hc, vds
def step1_compute_derived_fields(hc, vds, args):
    """Pipeline step 1: compute derived variant fields and reduce ``va`` to a fixed schema.

    The step adds the computed annotations (ids, positions, alleles,
    VEP-derived fields and, for ``VARIANTS`` datasets, ``va.FAF``), then
    builds ``va.clean`` from the schema listed below and replaces ``va`` with
    it. Unless ``args.skip_writing_intermediate_vds`` is set, the result is
    written to ``args.step1_output_vds``.

    The step does nothing when ``args.start_with_step`` / ``args.stop_after_step``
    exclude step 1.

    Args:
        hc: current hail context; it is stopped and replaced by a new one if
            the dataset has to be re-read.
        vds: dataset handed over by the previous step, or None to read it from
            ``args.step0_output_vds``.
        args: parsed pipeline arguments. ``args.start_with_step`` is set to 2
            at the end of this step.

    Returns:
        The ``(hc, vds)`` pair to pass on to the next step.

    Raises:
        ValueError: if ``args.dataset_type`` is neither ``"VARIANTS"`` nor ``"SV"``.
    """
    if args.start_with_step > 1 or args.stop_after_step < 1:
        return hc, vds

    logger.info(
        "\n\n=============================== pipeline - step 1 - compute derived fields ==============================="
    )

    if vds is None or not args.skip_writing_intermediate_vds:
        # restart the hail context and continue from the dataset stored at args.step0_output_vds
        stop_hail_context(hc)
        hc = create_hail_context()
        vds = read_in_dataset(hc,
                              args.step0_output_vds,
                              dataset_type=args.dataset_type,
                              skip_summary=True,
                              num_partitions=args.cpu_limit)

    use_nested_objects_for_vep = bool(args.use_nested_objects_for_vep)
    sorted_consequences_root = "va.sortedTranscriptConsequences"

    # these expressions are all passed to a single annotate_variants_expr call
    parallel_computed_annotation_exprs = [
        "va.docId = %s" % get_expr_for_variant_id(512),
        "va.variantId = %s" % get_expr_for_variant_id(),
        "va.variantType= %s" % get_expr_for_variant_type(),
        "va.contig = %s" % get_expr_for_contig(),
        "va.pos = %s" % get_expr_for_start_pos(),
        "va.start = %s" % get_expr_for_start_pos(),
        "va.end = %s" % get_expr_for_end_pos(),
        "va.ref = %s" % get_expr_for_ref_allele(),
        "va.alt = %s" % get_expr_for_alt_allele(),
        "va.xpos = %s" % get_expr_for_xpos(pos_field="start"),
        "va.xstart = %s" % get_expr_for_xpos(pos_field="start"),
        "va.sortedTranscriptConsequences = %s" %
        get_expr_for_vep_sorted_transcript_consequences_array(
            vep_root="va.vep",
            include_coding_annotations=True,
            add_transcript_rank=use_nested_objects_for_vep),
    ]

    if args.dataset_type == "VARIANTS":
        # based on https://macarthurlab.slack.com/archives/C027LHMPP/p1528132141000430
        faf_confidence_interval = 0.95

        parallel_computed_annotation_exprs.append(
            "va.FAF = %s" % get_expr_for_filtering_allele_frequency(
                "va.info.AC[va.aIndex - 1]", "va.info.AN",
                faf_confidence_interval))

    # these expressions are applied one at a time, after the ones above; most of
    # them are built from the va.sortedTranscriptConsequences annotation added above
    serial_computed_annotation_exprs = [
        "va.xstop = %s" %
        get_expr_for_xpos(field_prefix="va.", pos_field="end"),
        "va.transcriptIds = %s" % get_expr_for_vep_transcript_ids_set(
            vep_transcript_consequences_root=sorted_consequences_root),
        "va.domains = %s" % get_expr_for_vep_protein_domains_set(
            vep_transcript_consequences_root=sorted_consequences_root),
        "va.transcriptConsequenceTerms = %s" %
        get_expr_for_vep_consequence_terms_set(
            vep_transcript_consequences_root=sorted_consequences_root),
        "va.mainTranscript = %s" %
        get_expr_for_worst_transcript_consequence_annotations_struct(
            sorted_consequences_root),
        "va.geneIds = %s" % get_expr_for_vep_gene_ids_set(
            vep_transcript_consequences_root=sorted_consequences_root),
        "va.codingGeneIds = %s" % get_expr_for_vep_gene_ids_set(
            vep_transcript_consequences_root=sorted_consequences_root,
            only_coding_genes=True),
    ]

    if not use_nested_objects_for_vep:
        # must stay last: it overwrites the array with its json string, and the
        # expressions above take va.sortedTranscriptConsequences as their root
        serial_computed_annotation_exprs.append(
            "va.sortedTranscriptConsequences = json(va.sortedTranscriptConsequences)"
        )

    vds = vds.annotate_variants_expr(parallel_computed_annotation_exprs)

    for serial_expr in serial_computed_annotation_exprs:
        vds = vds.annotate_variants_expr(serial_expr)

    pprint(vds.variant_schema)

    # Schema strings handed to convert_vds_schema_string_to_annotate_variants_expr.
    # NOTE(review): lines prefixed with '---' are presumably treated as excluded fields - confirm
    input_schema = {}
    if args.dataset_type == "VARIANTS":
        input_schema["top_level_fields"] = """
            docId: String,
            variantId: String,
            originalAltAlleles: Set[String],

            contig: String,
            start: Int,
            pos: Int,
            end: Int,
            ref: String,
            alt: String,

            xpos: Long,
            xstart: Long,
            xstop: Long,

            rsid: String,
            --- qual: Double,
            filters: Set[String],
            aIndex: Int,

            geneIds: Set[String],
            transcriptIds: Set[String],
            codingGeneIds: Set[String],
            domains: Set[String],
            transcriptConsequenceTerms: Set[String],
            sortedTranscriptConsequences: String,
            mainTranscript: Struct,
        """

        # the same INFO fields are used whether or not args.not_gatk_genotypes is set
        input_schema["info_fields"] = """
                AC: Array[Int],
                AF: Array[Double],
                AN: Int,
                --- BaseQRankSum: Double,
                --- ClippingRankSum: Double,
                --- DP: Int,
                --- FS: Double,
                --- InbreedingCoeff: Double,
                --- MQ: Double,
                --- MQRankSum: Double,
                --- QD: Double,
                --- ReadPosRankSum: Double,
                --- VQSLOD: Double,
                --- culprit: String,
            """
    elif args.dataset_type == "SV":
        input_schema["top_level_fields"] = """
            docId: String,
            variantId: String,

            contig: String,
            start: Int,
            pos: Int,
            end: Int,
            ref: String,
            alt: String,

            xpos: Long,
            xstart: Long,
            xstop: Long,

            rsid: String,
            --- qual: Double,
            filters: Set[String],
            aIndex: Int,
            
            geneIds: Set[String],
            transcriptIds: Set[String],
            codingGeneIds: Set[String],
            domains: Set[String],
            transcriptConsequenceTerms: Set[String],
            sortedTranscriptConsequences: String,
            mainTranscript: Struct,
        """

        # example record: END=100371979;SVTYPE=DEL;SVLEN=-70;CIGAR=1M70D	GT:FT:GQ:PL:PR:SR
        input_schema["info_fields"] = """
            IMPRECISE: Boolean,
            SVTYPE: String,
            SVLEN: Int,
            END: Int,
            --- OCC: Int,
            --- FRQ: Double,
        """
    else:
        raise ValueError("Unexpected dataset_type: %s" % args.dataset_type)

    if args.exclude_vcf_info_field:
        input_schema["info_fields"] = ""

    # build va.clean from the schema above, then replace va with it
    clean_expr = convert_vds_schema_string_to_annotate_variants_expr(
        root="va.clean", **input_schema)

    vds = vds.annotate_variants_expr(expr=clean_expr)
    vds = vds.annotate_variants_expr("va = va.clean")

    if not args.skip_writing_intermediate_vds:
        write_vds(vds, args.step1_output_vds)

    # step 1 finished, so, if an error occurs and it goes to retry, start with the next step
    args.start_with_step = 2

    return hc, vds