예제 #1
0
    def test_mini(self):
        """
        Check that MCMC EM runs to completion (one EM iteration) with each
        M-step survival problem: SurvivalProblemLasso,
        SurvivalProblemFusedLassoProximal, SurvivalProblemLassoCVXPY and
        SurvivalProblemFusedLassoCVXPY.
        """
        MOTIF_LEN = 3

        feat_generator = HierarchicalMotifFeatureGenerator(motif_lens=[MOTIF_LEN])
        # NOTE(review): other call sites unpack read_gene_seq_csv_data into
        # (obs_data, metadata); confirm this call matches the API version in use.
        obs_data_raw = read_gene_seq_csv_data(INPUT_GENES, INPUT_SEQS, motif_len=MOTIF_LEN)
        obs_data = []
        for obs_seq_m in obs_data_raw:
            feat_generator.add_base_features(obs_seq_m)
            obs_data.append(obs_seq_m)

        init_theta = np.random.rand(feat_generator.feature_vec_len, 1)
        theta_mask = np.ones((feat_generator.feature_vec_len, 1), dtype=bool)

        # Run one EM iteration per problem class, in this fixed order, all
        # starting from the same init_theta.
        problem_classes = [
            SurvivalProblemLasso,
            SurvivalProblemFusedLassoProximal,
            SurvivalProblemLassoCVXPY,
            SurvivalProblemFusedLassoCVXPY,
        ]
        for problem_cls in problem_classes:
            em_algo = MCMC_EM(
                obs_data,
                feat_generator,
                MutationOrderGibbsSampler,
                problem_cls,
                theta_mask,
                num_jobs=1,
            )
            em_algo.run(init_theta, max_em_iters=1)
예제 #2
0
def _collect_statistics(fitted_models, args, true_thetas, stat, fit_type):
    """
    Evaluate one statistic on every fitted model and gather the results.

    @param fitted_models: fitted model results; None entries are skipped
    @param args: parsed arguments; agg_motif_len and agg_pos_mutating are
        read when stat is not 'discovered'
    @param true_thetas: pair of true thetas; index 1 is used for
        'discovered', index 0 for every other statistic
    @param stat: statistic name, resolved through _get_stat_func
    @param fit_type: either "refit" or "penalized"

    @return: list of the non-None statistic values; a model whose statistic
        raises ValueError has the error printed and is left out
    """
    stat_func = _get_stat_func(stat)
    collected = []
    for fitted in fitted_models:
        if fitted is None:
            continue

        if stat == 'discovered':
            # don't use aggregate fit for false discovery rate
            generator = fitted.refit_feature_generator
            generator.update_feats_after_removing([])
            theta_true = true_thetas[1]
        else:
            generator = HierarchicalMotifFeatureGenerator(
                motif_lens=[args.agg_motif_len],
                left_motif_flank_len_list=[[args.agg_pos_mutating]],
            )
            theta_true = true_thetas[0]
        mask = generator.get_possible_motifs_to_targets(theta_true.shape)

        try:
            value = stat_func(fitted, generator, theta_true, mask, fit_type)
            if value is not None:
                collected.append(value)
        except ValueError as err:
            print(err)
    return collected
예제 #3
0
def create_simulator(args):
    """
    Build a survival-model simulator from a pickled aggregate theta.

    @param args: parsed arguments; reads agg_motif_len, input_model and lambda0

    @return: SurvivalModelSimulatorMultiColumn when the loaded theta has
        NUM_NUCLEOTIDES columns, SurvivalModelSimulatorSingleColumn when it
        has a single column

    @raises ValueError: if the loaded theta has any other number of columns
    """
    feat_generator = HierarchicalMotifFeatureGenerator(
        motif_lens=[args.agg_motif_len],
        # Floor division keeps the flank length an int on Python 3 as well
        # (same result as `/` on ints under Python 2).
        left_motif_flank_len_list=[[args.agg_motif_len // 2]],
    )
    # Pickles must be read in binary mode to load portably.
    # NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
    with open(args.input_model, 'rb') as f:
        agg_theta, _ = pickle.load(f)

    num_cols = agg_theta.shape[1]
    if num_cols == NUM_NUCLEOTIDES:
        simulator = SurvivalModelSimulatorMultiColumn(agg_theta,
                                                      feat_generator,
                                                      lambda0=args.lambda0)
    elif num_cols == 1:
        agg_theta_shape = (agg_theta.size, NUM_NUCLEOTIDES)
        # Every target starts at probability 1/3 (presumably uniform over the
        # three bases a nucleotide can mutate to); impossible motif/target
        # pairs are then zeroed out.
        probability_matrix = np.ones(agg_theta_shape) * 1.0 / 3
        possible_motifs_mask = feat_generator.get_possible_motifs_to_targets(
            agg_theta_shape)
        probability_matrix[~possible_motifs_mask] = 0
        simulator = SurvivalModelSimulatorSingleColumn(agg_theta,
                                                       probability_matrix,
                                                       feat_generator,
                                                       lambda0=args.lambda0)
    else:
        raise ValueError("Aggregate theta shape is wrong")
    return simulator
예제 #4
0
    def setUpClass(cls):
        """Seed numpy and build the shared feature generators, observed mutation and mutation order."""
        np.random.seed(1)
        cls.motif_len = 3

        # Plain 3-mer generator plus a 2-mer/3-mer hierarchy with explicit flanks.
        cls.feat_gen = HierarchicalMotifFeatureGenerator(motif_lens=[cls.motif_len])
        cls.feat_gen_hier = HierarchicalMotifFeatureGenerator(
            motif_lens=[2, 3],
            left_motif_flank_len_list=[[0, 1], [1]],
        )

        start_seq = "agtctggcatcaaagaaagagcgatttag"
        end_seq = "aggctcgtattcgctaaaataagcaccag"
        cls.obs_seq_mut = ObservedSequenceMutations(start_seq, end_seq, cls.motif_len)
        cls.mutation_order = [12, 18, 3, 5, 19, 16, 8, 17, 21, 0, 22, 10, 24, 11, 9, 23]
예제 #5
0
def main(args=sys.argv[1:]):
    """
    Plot the refitted theta of the best model in a fitted pickle, aggregated
    to the full motif length, together with its confidence interval.

    NOTE(review): the `args` parameter is immediately replaced by
    parse_args(), so its value is ignored; kept for interface compatibility.
    """
    args = parse_args()

    # Load fitted theta file (binary mode: it is a pickle).
    # NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
    with open(args.input_pkl, "rb") as f:
        method_results = pickle.load(f)
    method_res = pick_best_model(method_results)
    per_target_model = method_res.refit_theta.shape[1] == NUM_NUCLEOTIDES + 1

    feat_generator = method_res.refit_feature_generator

    max_motif_len = feat_generator.motif_len
    full_feat_generator = HierarchicalMotifFeatureGenerator(
        motif_lens=[max_motif_len],
        left_motif_flank_len_list=feat_generator.left_motif_flank_len,
    )

    theta = method_res.refit_theta
    if args.center_median:
        # Centre a copy rather than method_res.refit_theta in place; the
        # centred theta is passed explicitly below. Non-finite entries (if
        # any, e.g. -inf) are excluded so they cannot drag the median to
        # -inf and turn the whole theta into nan.
        theta = theta - np.median(theta[np.isfinite(theta)])

    num_agg_cols = NUM_NUCLEOTIDES if per_target_model else 1
    # In a per-target model, column 0 is skipped when aggregating.
    agg_start_col = 1 if per_target_model else 0

    agg_shape = (full_feat_generator.feature_vec_len, num_agg_cols)
    full_theta = np.zeros(agg_shape)
    theta_lower = np.zeros(agg_shape)
    theta_upper = np.zeros(agg_shape)
    for col_idx in range(num_agg_cols):
        col_theta, col_lower, col_upper = feat_generator.combine_thetas_and_get_conf_int(
            theta,
            variance_est=method_res.variance_est,
            col_idx=col_idx + agg_start_col,
            add_targets=False,
        )
        full_theta[:, col_idx] = col_theta
        theta_lower[:, col_idx] = col_lower
        theta_upper[:, col_idx] = col_upper

    # Impossible motif/target pairs are plotted as -inf.
    agg_possible_motif_mask = full_feat_generator.get_possible_motifs_to_targets(
        full_theta.shape)
    full_theta[~agg_possible_motif_mask] = -np.inf
    theta_lower[~agg_possible_motif_mask] = -np.inf
    theta_upper[~agg_possible_motif_mask] = -np.inf

    if args.no_conf_int:
        theta_lower = full_theta
        theta_upper = full_theta

    plot_theta(args.output_csv, full_theta, theta_lower, theta_upper,
               args.output_pdf, per_target_model, full_feat_generator,
               max_motif_len, args.center_nucs, args.y_lab)
예제 #6
0
def main(args=sys.argv[1:]):
    """
    Plot the aggregate theta of a logistic-regression or SHazaM model with
    5-mer motifs, centred on the median of its finite entries.

    The model comes from args.logistic_pkl when given, otherwise from the
    SHazaM files args.mut / args.sub.

    NOTE(review): the `args` parameter is replaced by parse_args() and ignored.

    @raises ValueError: if neither args.logistic_pkl nor args.mut is provided
    """
    args = parse_args()

    MOTIF_LEN = 5

    full_feat_generator = HierarchicalMotifFeatureGenerator(
        motif_lens=[MOTIF_LEN],
        # Floor division keeps this an int on Python 3 as well.
        left_motif_flank_len_list=[[MOTIF_LEN // 2]],
    )

    # Load fitted theta file
    if args.logistic_pkl is not None:
        fitted_model = load_logistic_model(args.logistic_pkl)
    elif args.mut is not None:
        # If it came from fit_shumlate_model.py, it's in wide format
        fitted_model = ShazamModel(MOTIF_LEN,
                                   args.mut,
                                   args.sub,
                                   wide_format=True)
    else:
        # Fail with a clear message instead of an UnboundLocalError below.
        raise ValueError(
            "No model given: provide either logistic_pkl or mut (SHazaM mutability file)")

    full_theta = fitted_model.agg_refit_theta
    # center median, ignoring infinite entries
    theta_med = np.median(full_theta[~np.isinf(full_theta)])
    full_theta -= theta_med
    # No intervals are computed here, so both bounds equal the point estimate.
    theta_lower = full_theta
    theta_upper = full_theta

    per_target_model = full_theta.shape[1] > 1
    plot_theta(args.output_csv, full_theta, theta_lower, theta_upper,
               args.output_pdf, per_target_model, full_feat_generator,
               MOTIF_LEN, args.center_nucs, args.y_lab)
예제 #7
0
    def setUpClass(cls):
        """Seed numpy and build a 3-mer feature generator with a random single-column theta."""
        np.random.seed(1)
        cls.burn_in = 10
        cls.motif_len = 3

        feat_gen = HierarchicalMotifFeatureGenerator(motif_lens=[cls.motif_len])
        cls.feat_gen = feat_gen
        # Uniform draws on [0, 2), one row per feature.
        cls.theta = 2 * np.random.rand(feat_gen.feature_vec_len, 1)
예제 #8
0
def likelihood_of_tree_from_shazam(tree, mutability_file, substitution_file=None, num_jobs=1, scratch_dir='_output', num_samples=1000, burn_in=0, num_tries=5):
    """
    Given an ETE tree and a SHazaM fit, compute the log likelihood of that tree

    @param tree: an ETE tree
    @param mutability_file: csv of mutability fit from SHazaM
    @param substitution_file: csv of substitution fit from SHazaM; if not given (None), assume all targets equiprobable
    @param num_jobs: how many jobs to run
    @param scratch_dir: where to put temporary output if running more than one job
    @param num_samples: number of Chib's samples
    @param burn_in: number of burn-in iterations
    @param num_tries: number of tries for Chib's sampler

    @return: log likelihood of a tree given a SHazaM fit
    """

    # Default for SHazaM is S5F: 5-mers with two bases of left flank, i.e. the
    # mutating base in the middle.
    feat_generator = HierarchicalMotifFeatureGenerator(
        motif_lens=[5],
        left_motif_flank_len_list=[[2]],
    )

    theta_ref = get_shazam_theta(mutability_file, substitution_file)

    obs_data = get_sequence_mutations_from_tree(
        tree,
        feat_generator.motif_len,
        feat_generator.max_left_motif_flank_len,
        feat_generator.max_right_motif_flank_len,
    )

    feat_generator.add_base_features_for_list(obs_data)
    log_like_evaluator = LogLikelihoodEvaluator(obs_data, feat_generator, num_jobs, scratch_dir)

    return log_like_evaluator.get_log_lik(
        theta_ref,
        num_samples=num_samples,
        burn_in=burn_in,
        num_tries=num_tries,
    )
예제 #9
0
    def setUpClass(cls):
        """Seed numpy and build plain, hierarchical and offset feature generators with matching observations."""
        np.random.seed(10)
        cls.motif_len = 3
        cls.BURN_IN = 10

        start_seq = "attcaaatgatatac"
        end_seq = "ataaatagggtttac"

        # 3-mers with one base of left flank, and a 1-mer/3-mer hierarchy.
        cls.feat_gen = HierarchicalMotifFeatureGenerator(
            motif_lens=[3],
            left_motif_flank_len_list=[[1]],
        )
        cls.feat_gen_hier = HierarchicalMotifFeatureGenerator(
            motif_lens=[1, 3],
            left_motif_flank_len_list=[[0], [1]],
        )
        cls.obs = ObservedSequenceMutations(
            start_seq,
            end_seq,
            cls.motif_len,
            left_flank_len=1,
            right_flank_len=1,
        )

        # 3-mers at left-flank offsets 0, 1 and 2, paired with an observation
        # padded by two flanking bases on each side.
        cls.feat_gen_off = HierarchicalMotifFeatureGenerator(
            motif_lens=[3],
            left_motif_flank_len_list=[[0, 1, 2]],
        )
        cls.obs_off = ObservedSequenceMutations(
            start_seq,
            end_seq,
            cls.motif_len,
            left_flank_len=2,
            right_flank_len=2,
        )
예제 #10
0
def main(args=sys.argv[1:]):
    args = parse_args()

    # Randomly generate number of mutations or use default
    np.random.seed(args.seed)

    hier_feat_generator = HierarchicalMotifFeatureGenerator(
        motif_lens=args.motif_lens,
        left_motif_flank_len_list=args.positions_mutating,
    )

    if args.use_shmulate_as_truth:
        h5f_theta = _read_mutability_probability_params(args)
        h5f_theta[h5f_theta == 0] = -np.inf
        h5f_theta[h5f_theta != -np.inf] = np.log(
            h5f_theta[h5f_theta != -np.inf])
        agg_h5f_theta = h5f_theta[:, 0:1] + h5f_theta[:, 1:]
        dump_parameters(agg_h5f_theta, h5f_theta, args, hier_feat_generator)
    else:
        theta_sampling_col0, theta_sampling_col_prob = _make_theta_sampling_distribution(
            args)
        avg_sampled_magnitude = np.sqrt(np.var(theta_sampling_col0))
        theta_raw = _generate_true_parameters(
            hier_feat_generator,
            args,
            theta_sampling_col0,
            theta_sampling_col_prob,
        )

        agg_theta_raw = hier_feat_generator.create_aggregate_theta(
            theta_raw, keep_col0=False)

        # Now rescale theta according to effect size
        mult_factor = 1.0 / np.sqrt(
            np.var(agg_theta_raw[agg_theta_raw != -np.inf])
        ) * args.effect_size * avg_sampled_magnitude
        agg_theta = agg_theta_raw * mult_factor
        theta_raw = theta_raw * mult_factor

        dump_parameters(agg_theta, theta_raw, args, hier_feat_generator)
예제 #11
0
 def setUpClass(cls):
     """
     Set up shared state: the numpy seed, a 3-mer feature generator using
     cls.mut_pos_list as its left flank lengths, and the job/sampling settings.
     """
     np.random.seed(1)

     # Job and sampling settings
     cls.num_jobs = 1
     cls.scratch_dir = 'test/_output/'
     cls.num_e_samples = 4
     cls.sampling_rate = 1
     cls.burn_in = 1
     cls.nonzero_ratio = 0.5

     cls.motif_len = 3
     cls.mut_pos_list = [[1]]
     cls.feat_gen = HierarchicalMotifFeatureGenerator(
         motif_lens=[cls.motif_len],
         left_motif_flank_len_list=cls.mut_pos_list,
     )
    def setUpClass(cls):
        """Seed numpy and build a 3-mer/5-mer hierarchical feature generator with one imputed mutation ordering."""
        np.random.seed(10)
        cls.motif_len = 5

        cls.feat_gen_hier = HierarchicalMotifFeatureGenerator(
            motif_lens=[3, 5],
            left_motif_flank_len_list=[[0, 1], [2]],
        )

        naive_seq = "agtctggcatcaaagaaagagcgatttag"
        mutated_seq = "aggctcgtattcgctaaaataagcaccag"
        observed = ObservedSequenceMutations(naive_seq, mutated_seq, cls.motif_len)
        cls.mutation_order = [12, 18, 3, 5, 19, 16, 8, 17, 21, 0, 22, 10, 24, 11, 9, 23]

        # Base features are added to the observation before it is wrapped
        # together with the mutation order.
        cls.feat_gen_hier.add_base_features(observed)
        cls.sample_hier = ImputedSequenceMutations(observed, cls.mutation_order)
예제 #13
0
    def _test_value_calculation_size(self, theta_num_col):
        """
        Check that the CVXPY and custom lasso survival problems agree on the
        per-sample log likelihood and the objective value for a theta with
        `theta_num_col` columns.

        The objective values are compared with opposite signs
        (value_cvx == -value_custom), as in the assertions below.
        """
        np.random.seed(10)
        motif_len = 3
        penalty_param = 0.5

        feat_gen = HierarchicalMotifFeatureGenerator(motif_lens=[motif_len])
        theta = np.random.rand(feat_gen.feature_vec_len, theta_num_col)
        theta_mask = feat_gen.get_possible_motifs_to_targets(theta.shape)
        theta[~theta_mask] = -np.inf

        obs = ObservedSequenceMutations("aggtgggttac", "aggagagttac", motif_len)
        feat_gen.add_base_features(obs)
        # list(...) so the mutation order is a real list on Python 3 too,
        # where dict.keys() returns a view.
        sample = ImputedSequenceMutations(obs, list(obs.mutation_pos_dict.keys()))
        problem_cvx = SurvivalProblemLassoCVXPY(feat_gen, [sample], penalty_param, theta_mask)
        ll_cvx = problem_cvx.calculate_per_sample_log_lik(theta, sample)
        value_cvx = problem_cvx.get_value(theta)

        # NOTE(review): the result of this call was never used; it is kept only
        # in case it has side effects on feat_gen/sample. Confirm and remove.
        feat_gen.create_for_mutation_steps(sample)
        problem_custom = SurvivalProblemLasso(feat_gen, [sample], penalty_param, theta_mask)
        # The custom problem's log likelihood is evaluated on exp(theta) and its
        # precomputed data for the single sample.
        ll_custom = problem_custom.calculate_per_sample_log_lik(np.exp(theta), problem_custom.precalc_data[0])
        value_custom = problem_custom.get_value(theta)
        self.assertTrue(
            np.isclose(ll_cvx.value, ll_custom),
            msg="log lik mismatch: cvx %s vs custom %s" % (ll_cvx.value, ll_custom))
        self.assertTrue(
            np.isclose(value_cvx.value, -value_custom),
            msg="value mismatch: cvx %s vs -custom %s" % (value_cvx.value, -value_custom))
예제 #14
0
        bar.set_hatch(hatch)
    for i, bar in enumerate(sns_plot.axes[0, 0].patches):
        hatch = hatches[i % 3]
        bar.set_hatch(hatch)
    plt.ylabel('Diff. from true theta', fontsize=17)
    plt.xlabel("True theta size", fontsize=17)
    sns_plot.set(ylim=(-5, 5))
    x = sns_plot.axes[0, 0].get_xlim()
    sns_plot.axes[0, 0].plot(x, len(x) * [0], 'k--', alpha=.4)
    plt.legend(loc='upper right', fontsize='large')

    sns_plot.savefig(fname)


# Single-motif-length feature generator used to find which motif/target
# entries of each true theta are possible.
dense_agg_feat_gen = HierarchicalMotifFeatureGenerator(
    motif_lens=[MOTIF_LEN],
    left_motif_flank_len_list=[[MUT_POS]],
)

all_df = pd.DataFrame(
    columns=['theta', 'samm_fit', 'shazam_fit', 'sim_method', 'seed'])
for sim_method in SIM_METHODS:
    for seed in NSEEDS:
        # Pickles must be opened in binary mode to load portably.
        with open(TRUE_MODEL_STR % (sim_method, seed), 'rb') as f:
            theta, _ = pickle.load(f)
        model_shape = theta.shape
        possible_agg_mask = dense_agg_feat_gen.get_possible_motifs_to_targets(
            model_shape)
        # Keep only the possible entries, centred on their median.
        possible_theta = theta[possible_agg_mask]
        theta = possible_theta - np.median(possible_theta)
        tmp_df = pd.DataFrame()
        tmp_df['theta'] = theta
예제 #15
0
def _penalized_theta_to_log_prob(penalized_theta, feat_generator, per_target_model):
    """
    Convert a penalized logistic theta into aggregated log probabilities.

    Column 0 is mapped through 1 / (1 + exp(-theta)). The remaining columns,
    if any, are mapped row-wise through exp(-theta) / sum(exp(-theta)),
    a softmax of -theta. The log of the result is aggregated with
    feat_generator.create_aggregate_theta(..., keep_col0=False).

    @param penalized_theta: fitted theta matrix (not modified)
    @param feat_generator: feature generator the theta was fitted with
    @param per_target_model: if True, impossible motif/target entries are
        forced to probability zero (log probability -inf)

    @return: aggregated log-probability theta
    """
    if per_target_model:
        possible_mask = feat_generator.get_possible_motifs_to_targets(
            penalized_theta.shape)
        # Work on a copy; +inf makes exp(-theta) == 0 for impossible entries.
        penalized_theta = penalized_theta.copy()
        penalized_theta[~possible_mask] = np.inf
    target_weights = np.exp(-penalized_theta[:, 1:])
    target_weight_sums = np.sum(target_weights, axis=1).reshape(
        (penalized_theta.shape[0], 1))
    theta_prob = np.hstack([
        1.0 / (1.0 + np.exp(-penalized_theta[:, 0:1])),
        target_weights / target_weight_sums,
    ])
    theta_log_prob = np.log(theta_prob)
    return feat_generator.create_aggregate_theta(theta_log_prob,
                                                 keep_col0=False)


def main(args=sys.argv[1:]):
    """
    Fit a penalized logistic-regression motif model. The penalty parameter
    is chosen on validation folds, the model is refit on all data, and the
    aggregated log-probability theta is pickled to args.model_pkl.

    NOTE(review): the `args` parameter is replaced by parse_args() and ignored.
    """
    args = parse_args()
    log.basicConfig(format="%(message)s",
                    filename=args.log_file,
                    level=log.DEBUG)
    np.random.seed(args.seed)

    feat_generator = HierarchicalMotifFeatureGenerator(
        motif_lens=args.motif_lens,
        left_motif_flank_len_list=args.positions_mutating,
    )
    theta_shape = (feat_generator.feature_vec_len,
                   NUM_NUCLEOTIDES + 1 if args.per_target_model else 1)

    log.info("Reading data")
    obs_data, metadata = read_gene_seq_csv_data(
        args.input_naive,
        args.input_mutated,
        motif_len=args.max_motif_len,
        left_flank_len=args.max_left_flank,
        right_flank_len=args.max_right_flank,
    )
    log.info("num observations %d", len(obs_data))

    feat_generator.add_base_features_for_list(obs_data)

    # Process data
    fold_indices = data_split.split(
        len(obs_data),
        metadata,
        args.tuning_sample_ratio,
        args.k_folds,
        validation_column=args.validation_col,
    )
    data_folds = []
    for train_idx, val_idx in fold_indices:
        train_set = [obs_data[i] for i in train_idx]
        val_set = [obs_data[i] for i in val_idx]
        train_X, train_y, train_y_orig = get_X_y_matrices(
            train_set, args.per_target_model)
        val_X, val_y, val_y_orig = get_X_y_matrices(val_set,
                                                    args.per_target_model)
        logistic_reg = LogisticRegressionMotif(
            theta_shape,
            train_X,
            train_y,
            train_y_orig,
            per_target_model=args.per_target_model)
        data_folds.append((val_X, val_y, val_y_orig, logistic_reg))

    # Fit the models for each penalty parameter
    best_pen_param = get_best_penalty_param(args.penalty_params, data_folds)
    log.info("best penalty param %f", best_pen_param)

    # Refit penalized with all the data
    penalized_theta = fit_to_data(obs_data, best_pen_param, theta_shape,
                                  args.per_target_model)
    lines = get_nonzero_theta_print_lines(penalized_theta, feat_generator)
    log.info("========penalized==========")
    log.info(lines)

    # Convert theta to log probability of mutation
    theta_est = _penalized_theta_to_log_prob(penalized_theta, feat_generator,
                                             args.per_target_model)

    hier_full_feat_generator = HierarchicalMotifFeatureGenerator(
        motif_lens=[args.max_motif_len],
        left_motif_flank_len_list=[[args.max_left_flank]],
    )
    if args.per_target_model:
        possible_mask = hier_full_feat_generator.get_possible_motifs_to_targets(
            theta_est.shape)
        theta_est[~possible_mask] = -np.inf

    agg_lines = get_nonzero_theta_print_lines(theta_est,
                                              hier_full_feat_generator)
    log.info("===========aggregate=======")
    log.info(agg_lines)

    # Binary mode: pickle output must not go through a text-mode file.
    with open(args.model_pkl, "wb") as f:
        pickle.dump(LogisticModel(theta_est), f)
예제 #16
0
def main(args=sys.argv[1:]):
    """
    Compare the samm, SHazaM and logistic-regression thetas on the same data.

    For every ordered pair (reference, other) of the three models, prints
    the log likelihood ratio of `other` against `reference` with its lower
    and upper bounds.

    NOTE(review): the `args` parameter is replaced by parse_args() and ignored.
    """
    args = parse_args()
    print(args)
    np.random.seed(args.seed)

    # get data
    obs_data, _ = read_gene_seq_csv_data(
        args.input_naive,
        args.input_mutated,
        motif_len=args.motif_len,
        left_flank_len=args.left_flank,
        right_flank_len=args.right_flank,
    )

    ## get samm theta (binary mode: the file is a pickle)
    # NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
    with open(args.input_samm, "rb") as f:
        method_results = pickle.load(f)
    method_res = pick_best_model(method_results)
    feat_generator = method_res.refit_feature_generator

    max_motif_len = args.motif_len
    full_feat_generator = HierarchicalMotifFeatureGenerator(
        motif_lens=[max_motif_len],
        left_motif_flank_len_list=[[args.left_flank]],
    )
    full_feat_generator.add_base_features_for_list(obs_data)

    samm_theta = feat_generator.create_aggregate_theta(method_res.refit_theta, keep_col0=False)
    agg_possible_motif_mask = full_feat_generator.get_possible_motifs_to_targets(samm_theta.shape)
    samm_theta[~agg_possible_motif_mask] = -np.inf
    thetas = [samm_theta]
    labels = ['samm']

    # get shazam theta
    shazam_model = ShazamModel(max_motif_len, args.input_shazam, None, wide_format=True)
    thetas.append(shazam_model.agg_refit_theta)
    labels.append('shazam')

    # get logistic theta
    logistic_model = load_logistic_model(args.input_logistic)
    thetas.append(logistic_model.agg_refit_theta)
    labels.append('logistic')

    # do all comparisons. permutations() yields pairs grouped by their first
    # element, so each reference evaluator is built once and then reused.
    cur_ref_idx = -1
    val_set_evaluator = None
    for ref_idx, other_idx in itertools.permutations(range(len(thetas)), 2):
        if ref_idx != cur_ref_idx:
            # need to update val_set_evaluator
            cur_ref_idx = ref_idx
            val_set_evaluator = LikelihoodComparer(
                obs_data,
                full_feat_generator,
                theta_ref=thetas[ref_idx],
                num_samples=args.num_val_samples,
                burn_in=args.num_val_burnin,
                num_jobs=args.num_jobs,
                scratch_dir=args.scratch_dir,
            )
        log_lik_ratio, lower_bound, upper_bound = val_set_evaluator.get_log_likelihood_ratio(thetas[other_idx])
        # Single-argument print(...) prints the same text on Python 2 and 3.
        print("{} with {} reference:".format(labels[other_idx], labels[ref_idx]))
        print("(lower, ratio, upper) = (%.4f, %.4f, %.4f)" % (lower_bound, log_lik_ratio, upper_bound))
예제 #17
0
def main(args=sys.argv[1:]):
    """
    Fit the samm context model for each penalty parameter in
    args.penalty_params across the data folds. Stop early when the
    validation log likelihood ratios indicate a worse model or when the
    model is saturated. Then refit the best penalty unpenalized on all the
    data.

    The per-parameter results of every fold are pickled to args.out_file
    after each penalty parameter and again after the unpenalized refit.

    NOTE(review): the `args` parameter is replaced by parse_args() and ignored.
    """
    args = parse_args()
    log.basicConfig(format="%(message)s",
                    filename=args.log_file,
                    level=log.DEBUG)
    np.random.seed(args.seed)

    if max(args.k_folds, args.num_cpu_threads) > 1:
        all_runs_pool = Pool(max(args.k_folds, args.num_cpu_threads))
    else:
        all_runs_pool = None

    feat_generator = HierarchicalMotifFeatureGenerator(
        motif_lens=args.motif_lens,
        left_motif_flank_len_list=args.positions_mutating,
    )

    log.info("Reading data")
    obs_data, metadata = read_gene_seq_csv_data(
        args.input_naive,
        args.input_mutated,
        motif_len=args.max_motif_len,
        left_flank_len=args.max_left_flank,
        right_flank_len=args.max_right_flank,
    )

    feat_generator.add_base_features_for_list(obs_data)
    fold_indices = data_split.split(
        len(obs_data),
        metadata,
        args.tuning_sample_ratio,
        args.k_folds,
        validation_column=args.validation_col,
    )
    data_folds = []
    for train_idx, val_idx in fold_indices:
        train_set = [obs_data[i] for i in train_idx]
        val_set = [obs_data[i] for i in val_idx]
        data_folds.append((train_set, val_set))

    st_time = time.time()
    log.info("Data statistics:")
    # NOTE(review): train_idx/val_idx are left over from the last fold of the
    # loop above, so these counts describe that fold only.
    log.info("  Number of sequences: Train %d, Val %d" %
             (len(train_idx), len(val_idx)))
    log.info(get_data_statistics_print_lines(obs_data, feat_generator))
    log.info("Settings %s" % args)
    log.info("Running EM")

    # Run EM on the lasso parameters from largest to smallest
    results_list = []
    # Sized from data_folds (a list) so this still works if fold_indices
    # could only be iterated once.
    val_set_evaluators = [None for _ in data_folds]
    cmodel_algos = [
        ContextModelAlgo(feat_generator, args) for _ in data_folds
    ]
    prev_pen_theta = None
    best_model_idx = 0
    for param_i, penalty_param in enumerate(args.penalty_params):
        penalty_params_prev = None if param_i == 0 else args.penalty_params[
            param_i - 1]
        target_penalty_param = penalty_param if args.per_target_model else 0
        penalty_params = (penalty_param, target_penalty_param)
        log.info("==== Penalty parameters %f, %f ====" % penalty_params)
        workers = []
        for fold_idx, (train_set, val_set) in enumerate(data_folds):
            log.info("========= Fold %d ==============" % fold_idx)
            # Warm-start from this fold's penalized theta for the previous parameter
            prev_pen_theta = results_list[
                param_i - 1][fold_idx].penalized_theta if param_i else None
            val_set_evaluator = val_set_evaluators[fold_idx]
            # Use the same number of order samples as previous validation set if possible
            prev_num_val_samples = val_set_evaluator.num_samples if val_set_evaluator is not None else args.num_val_samples
            samm_worker = SammWorker(fold_idx, cmodel_algos[fold_idx],
                                     train_set, penalty_params,
                                     args.em_max_iters, val_set_evaluator,
                                     prev_pen_theta, penalty_params_prev,
                                     val_set, prev_num_val_samples, args)
            workers.append(samm_worker)
        if args.k_folds > 1:
            # We will be using the MultiprocessingManager handle fitting theta for each fold (so python's multiprocessing lib)
            manager = MultiprocessingManager(all_runs_pool,
                                             workers,
                                             num_approx_batches=len(workers))
            results = manager.run()
        else:
            # We will be using the MultiprocessingManager parallelize computations within the M-step
            results = [w.run(all_runs_pool) for w in workers]

        param_results = [r[0] for r in results]
        results_list.append(param_results)
        val_set_evaluators = [r[1] for r in results]

        # Checkpoint after every penalty parameter (binary mode: pickle output).
        with open(args.out_file, "wb") as f:
            pickle.dump(results_list, f)

        nonzeros = np.array(
            [res.penalized_num_nonzero for res in param_results])
        log_lik_ratios = np.array([r.log_lik_ratio for r in param_results])
        log.info("Log lik ratios %s" % log_lik_ratios)
        if any(nonzeros) and param_i > 0:
            cv_interval = get_interval(log_lik_ratios, zscore=1)
            log.info("log lik interval %s", cv_interval)
            if cv_interval[0] < -ZERO_THRES:
                # Make sure that the penalty isn't so big that theta is empty
                # One std error below the mean for the log lik ratios surrogate is negative
                # Time to stop shrinking penalty param
                # This model is not better than the previous model. Stop trying penalty parameters.
                # Time to refit the model
                log.info(
                    "EM surrogate function is decreasing. Stop trying penalty parameters. ll_ratios %s"
                    % log_lik_ratios)
                break

        best_model_idx = param_i

        if np.mean(nonzeros) == feat_generator.feature_vec_len:
            # Model is saturated so stop fitting new parameters
            log.info("Model is saturated with %d parameters. Stop fitting." %
                     np.mean(nonzeros))
            break

    # Pick out the best model
    # Make sure we have the same support. Otherwise we need to refit
    if _check_same_support(results_list[best_model_idx]):
        # Just use the first fold as template for doing the refitting unpenalized
        method_res = results_list[best_model_idx][0]
    else:
        log.info("Need to refit to get the same support")
        # If support is not the same, refit penalized on all the data and get that support
        # Just use the first fold as template for doing the refitting unpenalized
        method_res_template = results_list[best_model_idx][0]
        prev_pen_theta = results_list[
            best_model_idx - 1][0].penalized_theta if best_model_idx else None
        method_res = cmodel_algos[0].fit_penalized(
            obs_data,
            method_res_template.penalty_params,
            max_em_iters=args.em_max_iters,
            init_theta=prev_pen_theta,
            pool=all_runs_pool,
        )
        results_list[best_model_idx].append(method_res)

    # Finally ready to refit as unpenalized model
    if args.num_cpu_threads > 1 and all_runs_pool is None:
        all_runs_pool = Pool(args.num_cpu_threads)
    cmodel_algos[0].refit_unpenalized(
        obs_data,
        model_result=method_res,
        max_em_iters=args.unpenalized_em_max_iters,
        hessian_check_iter=args.hessian_check_iter,
        get_hessian=not args.omit_hessian,
        pool=all_runs_pool,
    )

    # Pickle the refitted theta (binary mode: pickle output)
    with open(args.out_file, "wb") as f:
        pickle.dump(results_list, f)

    if not args.omit_hessian:
        # NOTE(review): the confidence intervals computed below are discarded;
        # this block only reports (via ValueError) when they cannot be computed.
        full_feat_generator = HierarchicalMotifFeatureGenerator(
            motif_lens=[args.max_motif_len],
            left_motif_flank_len_list=args.max_mut_pos,
        )
        num_agg_cols = NUM_NUCLEOTIDES if args.per_target_model else 1
        agg_start_col = 1 if args.per_target_model else 0

        try:
            feat_generator_stage2 = HierarchicalMotifFeatureGenerator(
                motif_lens=args.motif_lens,
                model_truncation=method_res.model_masks,
                left_motif_flank_len_list=args.positions_mutating,
            )
            for col_idx in range(num_agg_cols):
                full_theta, theta_lower, theta_upper = feat_generator_stage2.combine_thetas_and_get_conf_int(
                    method_res.refit_theta,
                    variance_est=method_res.variance_est,
                    col_idx=col_idx + agg_start_col,
                    add_targets=args.per_target_model,
                )
        except ValueError as e:
            print(e)

            log.info("No fits had positive variance estimates")

    if all_runs_pool is not None:
        all_runs_pool.close()
        # Make sure we don't keep these processes open!
        all_runs_pool.join()
    log.info("Completed! Time: %s" % str(time.time() - st_time))