Example #1
def main():
  parser = argparse.ArgumentParser(
    description='LOL HI THERE',
    formatter_class=argparse.ArgumentDefaultsHelpFormatter
  )
  parser.add_argument('ssm_fn')
  parser.add_argument('params1_fn')
  parser.add_argument('params2_fn')
  args = parser.parse_args()

  variants = inputparser.load_ssms(args.ssm_fn)
  vids = common.extract_vids(variants)
  V, T, T_prime, omega = inputparser.load_read_counts(variants)
  M, S = V.shape

  C1, vids1, assign1 = extract_assignment(args.params1_fn)
  C2, vids2, assign2 = extract_assignment(args.params2_fn)
  assert vids1 == vids2 == vids

  hparams = {
    'phi_alpha0': 1.,
    'phi_beta0': 1.,
    'conc': 1e-2,
  }
  llh1 = clustermaker.calc_llh(V, T_prime, assign1, hparams['phi_alpha0'], hparams['phi_beta0'], hparams['conc'])
  llh2 = clustermaker.calc_llh(V, T_prime, assign2, hparams['phi_alpha0'], hparams['phi_beta0'], hparams['conc'])
  nlglh1 = -llh1 / (M*S*np.log(2))
  nlglh2 = -llh2 / (M*S*np.log(2))

  homog, comp, vm = sklearn.metrics.homogeneity_completeness_v_measure(assign1, assign2)
  ami = sklearn.metrics.adjusted_mutual_info_score(assign1, assign2)
  print(C1, C2, llh1, llh2, nlglh1, nlglh2, homog, comp, vm, ami, sep=',')
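
A note on the metrics used above: `homogeneity_completeness_v_measure` and `adjusted_mutual_info_score` compare two flat clusterings and are invariant to cluster relabeling. A minimal, self-contained illustration with toy labels (not data from this script):

import sklearn.metrics

# Two toy assignments of six mutations to clusters. Only the induced
# partitions matter, not the specific cluster IDs.
assign1 = [0, 0, 1, 1, 2, 2]
assign2 = [1, 1, 0, 0, 0, 2]

homog, comp, vm = sklearn.metrics.homogeneity_completeness_v_measure(assign1, assign2)
ami = sklearn.metrics.adjusted_mutual_info_score(assign1, assign2)
print(homog, comp, vm, ami)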
Example #2
def main():
  parser = argparse.ArgumentParser(
    description='LOL HI THERE',
    formatter_class=argparse.ArgumentDefaultsHelpFormatter
  )
  parser.add_argument('--use-supervars', action='store_true')
  parser.add_argument('ssm_fn')
  parser.add_argument('params_fn')
  parser.add_argument('citup_snv_fn')
  parser.add_argument('citup_vid_fn')
  parser.add_argument('citup_clusters_fn')
  args = parser.parse_args()

  variants = inputparser.load_ssms(args.ssm_fn)
  params = inputparser.load_params(args.params_fn)
  clusters = params['clusters']

  if args.use_supervars:
    supervars = clustermaker.make_cluster_supervars(clusters, variants)
    superclusters = clustermaker.make_superclusters(supervars)
    garbage = set()
    write_snvs(supervars, garbage, args.citup_snv_fn, args.citup_vid_fn)
    write_clusters(supervars, garbage, superclusters, args.citup_clusters_fn)
  else:
    garbage = set(params['garbage'])
    write_snvs(variants, garbage, args.citup_snv_fn, args.citup_vid_fn)
    write_clusters(variants, garbage, clusters, args.citup_clusters_fn)
Example #3
def main():
    parser = argparse.ArgumentParser(
        description='LOL HI THERE',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--uniform-proposal', action='store_true')
    parser.add_argument('ssm_fn')
    parser.add_argument('params_fn')
    parser.add_argument('pastri_allele_counts_fn')
    parser.add_argument('pastri_proposal_fn')
    args = parser.parse_args()

    variants = inputparser.load_ssms(args.ssm_fn)
    params = inputparser.load_params(args.params_fn)
    clusters = params['clusters']
    supervars = clustermaker.make_cluster_supervars(clusters, variants)

    matrices = {
        'var_reads': extract_matrix(supervars, 'var_reads'),
        'total_reads': extract_matrix(supervars, 'total_reads'),
        'alpha': extract_matrix(supervars, 'var_reads'),
        'beta': extract_matrix(supervars, 'total_reads'),
    }
    if args.uniform_proposal:
        matrices['alpha'][:] = 1
        matrices['beta'][:] = 2

    C_max = 15
    # Truncate the proposal matrices to at most C_max clusters.
    matrices['alpha'] = matrices['alpha'][:C_max, :]
    matrices['beta'] = matrices['beta'][:C_max, :]

    write_matrices(('A', matrices['var_reads']),
                   ('D', matrices['total_reads']),
                   outfn=args.pastri_allele_counts_fn)
    write_matrices(('Alpha', matrices['alpha']), ('Beta', matrices['beta']),
                   outfn=args.pastri_proposal_fn)
Example #4
def main():
    parser = argparse.ArgumentParser(
        description='LOL HI THERE',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('ssm_fn')
    parser.add_argument('params_fn')
    args = parser.parse_args()

    variants = inputparser.load_ssms(args.ssm_fn)
    params = inputparser.load_params(args.params_fn)
    clusters = params['clusters']

    supervars = clustermaker.make_cluster_supervars(clusters, variants)
    superclusters = clustermaker.make_superclusters(supervars)
    # Add empty initial cluster, which serves as tree root.
    superclusters.insert(0, [])
    M = len(superclusters)

    iterations = 1000
    parallel = 0

    parents = [[0, 0, 0], [0, 1, 2]]
    for P in parents:
        adj = _parents2adj(P)
        print_init(supervars, adj)
        for method in ('projection', 'rprop', 'graddesc'):
            phi, eta = phi_fitter._fit_phis(adj, superclusters, supervars,
                                            method, iterations, parallel)
            # Sometimes the `projection` fitter will return zeros, which result in an
            # LLH of -inf if the number of variant reads `V` is non-zero, since
            # `Binom(X=V > 0 | N=V+R, p=0) = 0`. To avoid this, set a floor of 1e-6
            # on phi values.
            phi = np.maximum(1e-6, phi)
            print_method(method, phi, supervars)
            print()
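
`_parents2adj` is defined elsewhere in the script. A plausible sketch, assuming the usual pairtree convention that node 0 is the root and `P[i]` gives the parent of node `i+1` (the function name and the diagonal convention here are assumptions):

import numpy as np

def parents_to_adj(parents):
    # Hypothetical reconstruction of `_parents2adj`: adj[i, j] == 1 means
    # node i is the parent of node j; the diagonal marks self-loops.
    K = len(parents) + 1
    adj = np.eye(K)
    adj[np.array(parents), np.arange(1, K)] = 1
    return adj

# parents=[0, 0, 0] is a star rooted at node 0; parents=[0, 1, 2] is a chain.
print(parents_to_adj([0, 1, 2]))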
Example #5
def main():
  parser = argparse.ArgumentParser(
    description='LOL HI THERE',
    formatter_class=argparse.ArgumentDefaultsHelpFormatter
  )
  parser.add_argument('--counts', required=True)
  parser.add_argument('in_ssm_fn')
  parser.add_argument('in_params_fn')
  parser.add_argument('out_base')
  args = parser.parse_args()

  random.seed(1337)

  counts = [int(C) for C in args.counts.split(',')]
  assert len(counts) == len(set(counts))
  ssms = inputparser.load_ssms(args.in_ssm_fn)
  params = inputparser.load_params(args.in_params_fn)
  sampnames = params['samples']

  # Always include the diagnosis sample, on the assumption that we're working
  # with SJbALL022609 from Steph for the paper's congraph figure.
  subsets = _select_samp_subsets(sampnames, counts, all_must_include=['D'])
  for subset in subsets:
    idxs = _find_idxs(sampnames, subset)
    new_ssms = _filter_ssms(ssms, idxs)
    new_params = dict(params)
    new_params['samples'] = subset

    out_base = '%s_S%s' % (args.out_base, len(subset))
    inputparser.write_ssms(new_ssms, out_base + '.ssm')
    with open(out_base + '.params.json', 'w') as F:
      json.dump(new_params, F)
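
`_select_samp_subsets`, `_find_idxs`, and `_filter_ssms` are helpers defined elsewhere. A minimal sketch of the latter two, assuming sample names are unique and the per-variant fields are NumPy arrays (these names are hypothetical reconstructions):

import copy

def find_idxs(sampnames, subset):
  # Map each selected sample name back to its column index in the
  # original sample list.
  return [sampnames.index(S) for S in subset]

def filter_ssms(ssms, idxs):
  # Keep only the chosen sample columns in each per-variant vector.
  new_ssms = copy.deepcopy(ssms)
  for vid in new_ssms:
    for K in ('var_reads', 'ref_reads', 'total_reads', 'vaf', 'omega_v'):
      new_ssms[vid][K] = new_ssms[vid][K][idxs]
  return new_ssms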
Example #6
def main():
  ssmfns = (sys.argv[1], sys.argv[3])
  paramfns = (sys.argv[2], sys.argv[4])

  ssms = [inputparser.load_ssms(F) for F in ssmfns]
  params = [inputparser.load_params(F) for F in paramfns]
  samps = [P['samples'] for P in params]

  samps_to_rename = (0,)
  for idx in samps_to_rename:
    samps[idx] = _rename(samps[idx])

  _compare(ssms, samps)
Example #7
def main():
    parser = argparse.ArgumentParser(
        description='LOL HI THERE',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('ssm_fn')
    parser.add_argument('params_fn')
    parser.add_argument('out_dir')
    args = parser.parse_args()

    variants = inputparser.load_ssms(args.ssm_fn)
    params = inputparser.load_params(args.params_fn)
    sampnames = params['samples']

    convert(variants, sampnames, args.out_dir)
Example #8
def main():
  parser = argparse.ArgumentParser(
    description='LOL HI THERE',
    formatter_class=argparse.ArgumentDefaultsHelpFormatter
  )
  parser.add_argument('ssm_fn')
  parser.add_argument('scresults_fn')
  parser.add_argument('params_fn_orig')
  parser.add_argument('params_fn_modified')
  args = parser.parse_args()

  variants = inputparser.load_ssms(args.ssm_fn)
  varid_map = build_variant_to_varid_map(variants)
  clusters, garbage = convert_clusters(args.scresults_fn, varid_map)
  add_missing_sex_variants_to_garbage(variants, clusters, garbage)
  write_results(clusters, garbage, args.params_fn_orig, args.params_fn_modified)
Example #9
def main():
  parser = argparse.ArgumentParser(
    description='LOL HI THERE',
    formatter_class=argparse.ArgumentDefaultsHelpFormatter
  )
  parser.add_argument('in_ssm_fn')
  parser.add_argument('out_ssm_fn')
  args = parser.parse_args()

  np.set_printoptions(linewidth=400, precision=3, threshold=sys.maxsize, suppress=True)
  np.seterr(divide='raise', invalid='raise', over='raise')

  ssms = inputparser.load_ssms(args.in_ssm_fn)
  fixed_prop = _fix_omegas(ssms, print_bad=False)
  print('fixed_omegas=%s' % fixed_prop)
  inputparser.write_ssms(ssms, args.out_ssm_fn)
Example #10
def _process(ssmfn, jsonfn, order):
    params = inputparser.load_params(jsonfn)
    ssms = inputparser.load_ssms(ssmfn)

    order = [int(idx) for idx in order.split(',')]
    N = len(params['samples'])
    assert set(range(N)) == set(order)
    assert len(list(ssms.values())[0]['var_reads']) == N

    params['samples'] = [params['samples'][idx] for idx in order]
    for vid in ssms.keys():
        for K in ('var_reads', 'ref_reads', 'total_reads', 'vaf', 'omega_v'):
            ssms[vid][K] = ssms[vid][K][order]

    with open(jsonfn, 'w') as F:
        json.dump(params, F)
    inputparser.write_ssms(ssms, ssmfn)
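
The per-variant reordering above works because indexing a NumPy array with a permutation list returns a copy with elements in that order:

import numpy as np

var_reads = np.array([10, 20, 30])  # one read count per sample
order = [2, 0, 1]                   # new sample order
print(var_reads[order])             # -> [30 10 20]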
Example #11
def main():
    parser = argparse.ArgumentParser(
        description='LOL HI THERE',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--impute-garbage', action='store_true')
    parser.add_argument('neutree_fn')
    parser.add_argument('pairtree_ssm_fn')
    parser.add_argument('mutphi_fn')
    args = parser.parse_args()

    ntree = neutree.load(args.neutree_fn)
    mphi = mutphi.calc_mutphi(ntree.phis, ntree.logscores, ntree.clusterings,
                              args.pairtree_ssm_fn, ntree.counts)
    variants = inputparser.load_ssms(args.pairtree_ssm_fn)
    if args.impute_garbage:
        mphi = mutstat.impute_garbage(mphi, ntree.garbage,
                                      lambda vid: _impute(vid, variants))
    mutstat.write(mphi, args.mutphi_fn)
Example #12
def _process(ssmfn, jsonfn, to_remove):
  params = inputparser.load_params(jsonfn)
  ssms = inputparser.load_ssms(ssmfn)

  to_remove = set([int(idx) for idx in to_remove.split(',')])
  N = len(params['samples'])
  all_samps = set(range(N))
  assert to_remove.issubset(all_samps)
  to_keep = sorted(all_samps - to_remove)
  assert len(to_keep) > 0

  params['samples'] = [params['samples'][idx] for idx in to_keep]
  for vid in ssms.keys():
    for K in ('var_reads', 'ref_reads', 'total_reads', 'vaf', 'omega_v'):
      ssms[vid][K] = ssms[vid][K][to_keep]

  with open(jsonfn, 'w') as F:
    json.dump(params, F)
  inputparser.write_ssms(ssms, ssmfn)
Example #13
def main():
    parser = argparse.ArgumentParser(
        description='LOL HI THERE',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--uniform-proposal', action='store_true')
    parser.add_argument('ssm_fn')
    parser.add_argument('params_fn')
    parser.add_argument('lichee_snv_fn')
    parser.add_argument('lichee_cluster_fn')
    args = parser.parse_args()

    variants = inputparser.load_ssms(args.ssm_fn)
    params = inputparser.load_params(args.params_fn)
    sampnames = params['samples']
    clusters = params['clusters']
    garbage = set(params['garbage'])

    snv_indices = write_snvs(variants, sampnames, garbage, args.lichee_snv_fn)
    write_clusters(variants, clusters, snv_indices, args.lichee_cluster_fn)
Example #14
def impute(ssmfn, params, mphi):
  clustered = set([V for C in params['clusters'] for V in C])
  mphi_vids = set(mphi.vids)
  missing = list(clustered - mphi_vids)
  if len(missing) == 0:
    sys.exit()

  variants = inputparser.load_ssms(ssmfn)
  missing_reads = np.array([variants[V]['total_reads'] for V in missing]).astype(float)
  assert np.all(missing_reads >= 1)
  # Assign uniform probability based on total read count.
  missing_logprobs = np.log(1 / missing_reads)

  combined = mutphi.Mutphi(
    vids = list(mphi.vids) + missing,
    assays = mphi.assays,
    logprobs = np.vstack((mphi.logprobs, missing_logprobs)),
  )
  return combined
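
One clarification of the "uniform probability" comment: a variant with `T` total reads gets probability `1/T` in each sample, i.e. a log-probability of `-log(T)`. A quick check of that equivalence:

import numpy as np

missing_reads = np.array([[100., 50.], [200., 10.]])  # T per variant and sample
missing_logprobs = np.log(1 / missing_reads)
assert np.allclose(missing_logprobs, -np.log(missing_reads))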
Example #15
def main():
    parser = argparse.ArgumentParser(
        description='LOL HI THERE',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--use-supervars',
                        dest='use_supervars',
                        action='store_true')
    parser.add_argument('ssm_fn')
    parser.add_argument('params_fn')
    parser.add_argument('pwgs_ssm_fn')
    parser.add_argument('pwgs_params_fn')
    args = parser.parse_args()

    variants = inputparser.load_ssms(args.ssm_fn)
    params = inputparser.load_params(args.params_fn)

    if args.use_supervars:
        variants = clustermaker.make_cluster_supervars(params['clusters'],
                                                       variants)
    write_ssms(variants, args.pwgs_ssm_fn)
    write_params(params['samples'], args.pwgs_params_fn)
Example #16
def main():
    parser = argparse.ArgumentParser(
        description='LOL HI THERE',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('ssm_fn')
    parser.add_argument('params_fn')
    parser.add_argument('calder_input_fn')
    args = parser.parse_args()

    variants = inputparser.load_ssms(args.ssm_fn)
    params = inputparser.load_params(args.params_fn)
    clusters = params['clusters']
    supervars = clustermaker.make_cluster_supervars(clusters, variants)

    vids1, var_reads = extract_matrix(supervars, 'var_reads')
    vids2, ref_reads = extract_matrix(supervars, 'ref_reads')
    assert vids1 == vids2
    vids = vids1

    _write_inputs(vids, params['samples'], var_reads, ref_reads,
                  args.calder_input_fn)
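
`extract_matrix` is not shown; in this script it evidently returns the variant IDs alongside an (M variants x S samples) matrix. A plausible sketch, assuming rows are keyed by a stable variant-ID order (the ordering is an assumption; the real helper may use common.sort_vids):

import numpy as np

def extract_matrix(variants, key):
    # Hypothetical reconstruction: stack the per-variant vectors for `key`
    # into one matrix, one row per variant.
    vids = sorted(variants.keys())
    return vids, np.array([variants[vid][key] for vid in vids])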
Example #17
def main():
    parser = argparse.ArgumentParser(
        description='LOL HI THERE',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--phi-hat-threshold',
                        type=float,
                        default=1 - 1e-2,
                        help='Blah')
    parser.add_argument('--quantile', type=float, default=0.5, help='Blah')
    parser.add_argument('--print-bad-data', action='store_true')
    parser.add_argument('in_ssm_fn')
    parser.add_argument('in_params_fn')
    parser.add_argument('out_params_fn')
    args = parser.parse_args()

    np.set_printoptions(linewidth=400,
                        precision=3,
                        threshold=sys.maxsize,
                        suppress=True)
    np.seterr(divide='raise', invalid='raise', over='raise')

    ssms = inputparser.load_ssms(args.in_ssm_fn)
    params = inputparser.load_params(args.in_params_fn)
    ssms = inputparser.remove_garbage(ssms, params['garbage'])

    bad_vids, bad_samp_prop = _remove_bad(ssms, args.phi_hat_threshold,
                                          args.quantile, args.print_bad_data)
    bad_ssm_prop = len(bad_vids) / len(ssms)
    if len(bad_vids) > 0:
        params['garbage'] = common.sort_vids(params['garbage'] + bad_vids)
        with open(args.out_params_fn, 'w') as F:
            json.dump(params, F)

    stats = {
        'bad_ssms': common.sort_vids(bad_vids),
        'bad_samp_prop': '%.3f' % bad_samp_prop,
        'bad_ssm_prop': '%.3f' % bad_ssm_prop,
    }
    for K, V in stats.items():
        print('%s=%s' % (K, V))
Example #18
def main():
    parser = argparse.ArgumentParser(
        description='LOL HI THERE',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--tree-index', type=int, default=0)
    parser.add_argument('--truth', dest='truth_fn')
    parser.add_argument('ssm_fn')
    parser.add_argument('results_fn')
    args = parser.parse_args()

    variants = inputparser.load_ssms(args.ssm_fn)
    if args.truth_fn:
        truth = _parse_truth(args.truth_fn)
    else:
        truth = {}

    results = resultserializer.Results(args.results_fn)
    sampnames = results.get('sampnames')
    clusters = results.get('clusters')
    garbage = results.get('garbage')
    variants = inputparser.remove_garbage(variants, garbage)

    phi = results.get('phi')[args.tree_index]
    struct = results.get('struct')[args.tree_index]
    K, S = phi.shape
    assert len(sampnames) == S
    eta = util.calc_eta(struct, phi)

    cns_pairs = stephutil.find_samp_pairs(sampnames, ' BM', ' CNS')
    spleen_pairs = stephutil.find_samp_pairs(sampnames, ' BM', ' Spleen')
    all_pairs = cns_pairs + spleen_pairs

    concord = _calc_concord(variants, clusters, eta, sampnames, all_pairs,
                            truth)

    results = {
        'concord': concord,
    }
    print(json.dumps(results))
Example #19
def main():
    parser = argparse.ArgumentParser(
        description='LOL HI THERE',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('ssm_fn')
    parser.add_argument('mutphi_fn')
    args = parser.parse_args()

    variants = inputparser.load_ssms(args.ssm_fn)
    vids = common.extract_vids(variants)
    var_reads = np.array([variants[V]['var_reads'] for V in vids])
    total_reads = np.array([variants[V]['total_reads'] for V in vids])
    omega_v = np.array([variants[V]['omega_v'] for V in vids])

    mle_phi = (1 / omega_v) * (var_reads / total_reads)
    assert np.all(0 <= mle_phi)
    mle_phi = np.minimum(1, mle_phi)

    clusters = [[V] for V in vids]
    llhs = [0]
    counts = [1]
    mphi = mutphi.calc_mutphi([mle_phi], llhs, [clusters], args.ssm_fn, counts)
    mutstat.write(mphi, args.mutphi_fn)
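
The estimator above inverts the expected relation `VAF ≈ omega_v * phi`, so the maximum-likelihood phi is `(V/T) / omega_v`, clipped to [0, 1]. A worked toy case:

import numpy as np

var_reads = np.array([30.])
total_reads = np.array([100.])
omega_v = np.array([0.5])  # diploid heterozygous SNV
mle_phi = np.minimum(1, (1 / omega_v) * (var_reads / total_reads))
print(mle_phi)  # -> [0.6]: a VAF of 0.3 implies 60% of cells carry the mutation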
Example #20
def calc_mutphi(cluster_phis, llhs, clusterings, ssmsfn, counts):
  assert len(cluster_phis) == len(llhs) == len(clusterings)
  variants = inputparser.load_ssms(ssmsfn)
  weights = util.softmax(llhs + np.log(counts))

  vids = None
  # TODO: make assays meaningful, rather than just always setting it to None.
  assays = None
  logprobs = None

  assert not np.allclose(0, weights)
  for (cluster_phi, clustering, weight) in zip(cluster_phis, clusterings, weights):
    # Note: 0*-np.inf is NaN. So, if we have a weight of zero (because the
    # tree's LLH was really bad) and a logprob of -inf for a mutation in a
    # sample (because the phi assigned there was super bad), we get NaNs in our
    # output. Avoid this by skipping trees with zero weight.
    if np.isclose(0, weight):
      continue

    cluster_phi = evalutil.fix_rounding_errors(cluster_phi)
    assert np.all(0 <= cluster_phi) and np.all(cluster_phi <= 1)
    V, membership = util.make_membership_mat(clustering)
    mphi = np.dot(membership, cluster_phi)

    if vids is None:
      vids = V
    assert V == vids
    if logprobs is None:
      logprobs = np.zeros(mphi.shape)

    # TODO: should I be doing something like logsumexp?
    weighted = weight * _calc_logprob(mphi, vids, variants)
    assert not np.any(np.isnan(weighted)) and not np.any(np.isinf(weighted))
    logprobs += weighted

  return mutstat.Mutstat(vids=vids, assays=assays, stats=logprobs)
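
`util.softmax` converts the log tree scores `llhs + np.log(counts)` into normalized weights. A numerically stable sketch of that helper, assuming the standard formulation:

import numpy as np
from scipy.special import logsumexp

def softmax(logx):
  # Subtracting logsumexp in the log domain avoids overflow for large
  # magnitudes, e.g. tree LLHs in the thousands.
  return np.exp(logx - logsumexp(logx))

weights = softmax(np.array([-1000., -1001., -1005.]) + np.log([2, 1, 1]))
assert np.isclose(weights.sum(), 1.)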
Example #21
def main():
    all_plot_choices = set((
        'tree',
        'pairwise_separate',
        'pairwise_mle',
        'vaf_matrix',
        'phi',
        'phi_hat',
        'phi_interleaved',
        'cluster_stats',
        'eta',
        'diversity_indices',
    ))
    parser = argparse.ArgumentParser(
        description='LOL HI THERE',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--seed', type=int)
    parser.add_argument('--tree-index', type=int, default=0)
    parser.add_argument('--plot',
                        dest='plot_choices',
                        type=lambda s: set(s.split(',')),
                        help='Things to plot; by default, plot everything')
    parser.add_argument('--omit-plots',
                        dest='omit_plots',
                        type=lambda s: set(s.split(',')),
                        help='Things to omit from plotting; overrides --plot')
    parser.add_argument('--runid')
    parser.add_argument(
        '--reorder-subclones',
        action='store_true',
        help=
        'Reorder subclones according to depth-first search through tree structure'
    )
    parser.add_argument(
        '--tree-json',
        dest='tree_json_fn',
        help=
        'Additional external file in which to store JSON, which is already stored statically in the HTML file'
    )
    parser.add_argument('--phi-orientation',
                        choices=('samples_as_rows', 'populations_as_rows'),
                        default='populations_as_rows')
    parser.add_argument(
        '--remove-normal',
        action='store_true',
        help=
        'Remove normal (non-cancerous) population 0 from tree, phi, and eta plots.'
    )
    parser.add_argument('ssm_fn')
    parser.add_argument('params_fn')
    parser.add_argument('results_fn')
    parser.add_argument('discord_fn')
    parser.add_argument('html_out_fn')
    args = parser.parse_args()

    np.seterr(divide='raise', invalid='raise', over='raise')

    if args.seed is not None:
        random.seed(args.seed)
        np.random.seed(args.seed)

    plot_choices = _choose_plots(args.plot_choices, args.omit_plots,
                                 all_plot_choices)

    results = resultserializer.Results(args.results_fn)
    variants = inputparser.load_ssms(args.ssm_fn)
    params = inputparser.load_params(args.params_fn)
    discord = _parse_discord(args.discord_fn)

    data = {
        K: results.get(K)[args.tree_index]
        for K in (
            'struct',
            'count',
            'llh',
            'prob',
            'phi',
        )
    }
    data['garbage'] = results.get('garbage')
    data['clusters'] = results.get('clusters')
    data['samples'] = params['samples']
    data['clustrel_posterior'] = results.get_mutrel('clustrel_posterior')
    if args.reorder_subclones:
        data, params = _reorder_subclones(data, params)

    if 'hidden_samples' in params:
        hidden = set(params['hidden_samples'])
        assert hidden.issubset(set(
            data['samples'])) and len(hidden) < len(data['samples'])
        visible_sampidxs = [
            idx for idx, samp in enumerate(data['samples'])
            if samp not in hidden
        ]
    else:
        visible_sampidxs = None

    samp_colours = params.get('samp_colours', None)
    pop_colours = params.get('pop_colours', None)
    if samp_colours is not None:
        assert set([S[0] for S in samp_colours]).issubset(data['samples'])
    if pop_colours is not None:
        assert len(pop_colours) == len(data['struct']) + 1

    supervars = clustermaker.make_cluster_supervars(data['clusters'], variants)
    supervars = [supervars[vid] for vid in common.sort_vids(supervars.keys())]

    with open(args.html_out_fn, 'w') as outf:
        write_header(args.runid, args.tree_index, outf)

        if 'tree' in plot_choices:
            tree_struct = util.make_tree_struct(
                data['struct'],
                data['count'],
                data['llh'],
                data['prob'],
                data['phi'],
                supervars,
                data['clusters'],
                data['samples'],
            )
            tree_struct['discord'] = discord

            _write_tree_html(
                tree_struct,
                args.tree_index,
                visible_sampidxs,
                samp_colours,
                pop_colours,
                'eta' in plot_choices,
                'diversity_indices' in plot_choices,
                'phi' in plot_choices,
                'phi_hat' in plot_choices,
                'phi_interleaved' in plot_choices,
                args.phi_orientation,
                args.remove_normal,
                outf,
            )
            if args.tree_json_fn is not None:
                _write_tree_json(tree_struct, args.tree_json_fn)

        if 'vaf_matrix' in plot_choices:
            vaf_plotter.plot_vaf_matrix(
                data['clusters'],
                variants,
                supervars,
                data['garbage'],
                data['phi'],
                data['samples'],
                should_correct_vaf=True,
                outf=outf,
            )

        if 'pairwise_mle' in plot_choices:
            relation_plotter.plot_ml_relations(data['clustrel_posterior'],
                                               outf)
        if 'pairwise_separate' in plot_choices:
            relation_plotter.plot_separate_relations(
                data['clustrel_posterior'], outf)
        if 'cluster_stats' in plot_choices:
            write_cluster_stats(data['clusters'], data['garbage'], supervars,
                                variants, outf)

        write_footer(outf)
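
`_choose_plots` is not shown; a sketch consistent with the CLI help text ("by default, plot everything"; `--omit-plots` "overrides --plot"):

def choose_plots(to_plot, to_omit, all_choices):
    # Hypothetical reconstruction: start from --plot (defaulting to every
    # known plot), then remove anything named in --omit-plots.
    chosen = set(to_plot) if to_plot is not None else set(all_choices)
    assert chosen.issubset(all_choices)
    if to_omit is not None:
        chosen -= to_omit
    return chosen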
Example #22
def main():
    parser = argparse.ArgumentParser(
        description='LOL HI THERE',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument(
        '--concentration',
        dest='logconc',
        type=float,
        default=-2,
        help=
        'log10(alpha) for Chinese restaurant process. The larger this is, the stronger the preference for more clusters.'
    )
    parser.add_argument('--parallel',
                        dest='parallel',
                        type=int,
                        default=1,
                        help='Number of tasks to run in parallel')
    parser.add_argument(
        '--prior',
        type=float,
        default=0.25,
        help=
        'Pairwise coclustering prior probability. Used only for --model=pairwise.'
    )
    parser.add_argument('--model',
                        choices=('pairwise', 'linfreq'),
                        required=True,
                        help='Clustering model to use')
    parser.add_argument('ssm_fn')
    parser.add_argument('params_fn')
    args = parser.parse_args()

    variants = inputparser.load_ssms(args.ssm_fn)
    params = inputparser.load_params(args.params_fn)
    clusters = params['clusters']
    garbage = params.get('garbage', [])
    variants = inputparser.remove_garbage(variants, garbage)

    M = len(variants)
    S = len(list(variants.values())[0]['var_reads'])
    logconc = _normalize_logconc(args.logconc, S)

    if args.model == 'pairwise':
        vids, Z = cluster_pairwise._convert_clustering_to_assignment(clusters)
        logprior = _make_coclust_logprior(args.prior, S)
        mutrel_posterior, mutrel_evidence = pairwise.calc_posterior(
            variants, logprior, 'mutation', args.parallel)
        assert vids == mutrel_posterior.vids
        log_clust_probs, log_notclust_probs = cluster_pairwise._make_coclust_probs(
            mutrel_posterior)
        llh = cluster_pairwise._calc_llh(Z, log_clust_probs,
                                         log_notclust_probs, logconc)
    elif args.model == 'linfreq':
        vids1, V, T, T_prime, omega = inputparser.load_read_counts(variants)
        vids2, Z = cluster_pairwise._convert_clustering_to_assignment(clusters)
        assert vids1 == vids2

        # Beta distribution prior for phi
        phi_alpha0 = 1.
        phi_beta0 = 1.
        llh = cluster_linfreq._calc_llh(V, T_prime, Z, phi_alpha0, phi_beta0,
                                        logconc)
    else:
        raise Exception('Unknown model')

    nlglh = -llh / (M * S * np.log(2))
    print(llh, nlglh)
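
As in Example #1, `nlglh` rescales the natural-log LLH into bits per mutation per sample, which makes scores comparable across datasets of different size. A quick sanity check of the units:

import numpy as np

M, S = 10, 4
llh = -(M * S) * np.log(2)            # exactly 1 bit of surprise per entry
nlglh = -llh / (M * S * np.log(2))
assert np.isclose(nlglh, 1.0)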