def test_mini(self):
    """ Check that MCMC-EM runs to completion for each survival problem solver """
    MOTIF_LEN = 3
    feat_generator = HierarchicalMotifFeatureGenerator(motif_lens=[MOTIF_LEN])
    obs_data_raw = read_gene_seq_csv_data(INPUT_GENES, INPUT_SEQS, motif_len=MOTIF_LEN)
    obs_data = []
    for obs_seq_mutation in obs_data_raw:
        feat_generator.add_base_features(obs_seq_mutation)
        obs_data.append(obs_seq_mutation)

    init_theta = np.random.rand(feat_generator.feature_vec_len, 1)
    theta_mask = np.ones((feat_generator.feature_vec_len, 1), dtype=bool)

    # Check that each solver class runs for one EM iteration
    for problem_class in [
            SurvivalProblemLasso,
            SurvivalProblemFusedLassoProximal,
            SurvivalProblemLassoCVXPY,
            SurvivalProblemFusedLassoCVXPY,
    ]:
        em_algo = MCMC_EM(
            obs_data,
            feat_generator,
            MutationOrderGibbsSampler,
            problem_class,
            theta_mask,
            num_jobs=1,
        )
        em_algo.run(init_theta, max_em_iters=1)
def _collect_statistics(fitted_models, args, true_thetas, stat, fit_type):
    """
    @param fit_type: either "refit" or "penalized"
    """
    statistics = []
    stat_func = _get_stat_func(stat)
    for fmodel in fitted_models:
        if fmodel is None:
            continue
        if stat == 'discovered':
            # don't use the aggregate fit for false discovery rate
            feat_gen = fmodel.refit_feature_generator
            feat_gen.update_feats_after_removing([])
            true_theta = true_thetas[1]
        else:
            feat_gen = HierarchicalMotifFeatureGenerator(
                motif_lens=[args.agg_motif_len],
                left_motif_flank_len_list=[[args.agg_pos_mutating]],
            )
            true_theta = true_thetas[0]
        possible_mask = feat_gen.get_possible_motifs_to_targets(true_theta.shape)
        try:
            s = stat_func(fmodel, feat_gen, true_theta, possible_mask, fit_type)
            if s is not None:
                statistics.append(s)
        except ValueError as e:
            print(e)
    return statistics
def create_simulator(args):
    feat_generator = HierarchicalMotifFeatureGenerator(
        motif_lens=[args.agg_motif_len],
        left_motif_flank_len_list=[[args.agg_motif_len // 2]],
    )
    with open(args.input_model, 'r') as f:
        agg_theta, _ = pickle.load(f)

    if agg_theta.shape[1] == NUM_NUCLEOTIDES:
        simulator = SurvivalModelSimulatorMultiColumn(
            agg_theta, feat_generator, lambda0=args.lambda0)
    elif agg_theta.shape[1] == 1:
        # Single theta column: each of the three possible target nucleotides is
        # equally likely; impossible motif-to-target pairs are zeroed out
        agg_theta_shape = (agg_theta.size, NUM_NUCLEOTIDES)
        probability_matrix = np.ones(agg_theta_shape) / 3.0
        possible_motifs_mask = feat_generator.get_possible_motifs_to_targets(agg_theta_shape)
        probability_matrix[~possible_motifs_mask] = 0
        simulator = SurvivalModelSimulatorSingleColumn(
            agg_theta, probability_matrix, feat_generator, lambda0=args.lambda0)
    else:
        raise ValueError("Aggregate theta shape is wrong")
    return simulator
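# Hypothetical usage sketch for create_simulator: `args` only needs the three
# attributes read above. The Namespace fields and pickle path here are
# assumptions for illustration, not part of this codebase's CLI.
from argparse import Namespace
sim_args = Namespace(
    agg_motif_len=5,
    input_model="_output/true_model.pkl",  # placeholder path
    lambda0=0.1,
)
simulator = create_simulator(sim_args)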
def setUpClass(cls):
    np.random.seed(1)
    cls.motif_len = 3
    cls.feat_gen = HierarchicalMotifFeatureGenerator(motif_lens=[3])
    cls.feat_gen_hier = HierarchicalMotifFeatureGenerator(
        motif_lens=[2, 3],
        left_motif_flank_len_list=[[0, 1], [1]])
    cls.obs_seq_mut = ObservedSequenceMutations(
        "agtctggcatcaaagaaagagcgatttag",
        "aggctcgtattcgctaaaataagcaccag",
        cls.motif_len)
    cls.mutation_order = [12, 18, 3, 5, 19, 16, 8, 17, 21, 0, 22, 10, 24, 11, 9, 23]
def main(args=sys.argv[1:]):
    args = parse_args()

    # Load fitted theta file
    with open(args.input_pkl, "r") as f:
        method_results = pickle.load(f)
    method_res = pick_best_model(method_results)

    per_target_model = method_res.refit_theta.shape[1] == NUM_NUCLEOTIDES + 1
    feat_generator = method_res.refit_feature_generator
    max_motif_len = feat_generator.motif_len
    full_feat_generator = HierarchicalMotifFeatureGenerator(
        motif_lens=[max_motif_len],
        left_motif_flank_len_list=feat_generator.left_motif_flank_len,
    )

    theta = method_res.refit_theta
    if args.center_median:
        theta -= np.median(theta)

    num_agg_cols = NUM_NUCLEOTIDES if per_target_model else 1
    agg_start_col = 1 if per_target_model else 0
    full_theta = np.zeros((full_feat_generator.feature_vec_len, num_agg_cols))
    theta_lower = np.zeros((full_feat_generator.feature_vec_len, num_agg_cols))
    theta_upper = np.zeros((full_feat_generator.feature_vec_len, num_agg_cols))
    for col_idx in range(num_agg_cols):
        full_theta[:, col_idx], theta_lower[:, col_idx], theta_upper[:, col_idx] = \
            feat_generator.combine_thetas_and_get_conf_int(
                method_res.refit_theta,
                variance_est=method_res.variance_est,
                col_idx=col_idx + agg_start_col,
                add_targets=False,
            )

    agg_possible_motif_mask = full_feat_generator.get_possible_motifs_to_targets(full_theta.shape)
    full_theta[~agg_possible_motif_mask] = -np.inf
    theta_lower[~agg_possible_motif_mask] = -np.inf
    theta_upper[~agg_possible_motif_mask] = -np.inf

    if args.no_conf_int:
        theta_lower = full_theta
        theta_upper = full_theta

    plot_theta(args.output_csv, full_theta, theta_lower, theta_upper,
               args.output_pdf, per_target_model, full_feat_generator,
               max_motif_len, args.center_nucs, args.y_lab)
def main(args=sys.argv[1:]):
    args = parse_args()

    MOTIF_LEN = 5
    full_feat_generator = HierarchicalMotifFeatureGenerator(
        motif_lens=[MOTIF_LEN],
        left_motif_flank_len_list=[[MOTIF_LEN // 2]],
    )

    # Load the fitted theta file
    if args.logistic_pkl is not None:
        fitted_model = load_logistic_model(args.logistic_pkl)
    elif args.mut is not None:
        # If it came from fit_shmulate_model.py, it's in wide format
        fitted_model = ShazamModel(MOTIF_LEN, args.mut, args.sub, wide_format=True)
    else:
        raise ValueError("No input model was given")

    full_theta = fitted_model.agg_refit_theta

    # Center the median of the finite entries at zero
    theta_med = np.median(full_theta[~np.isinf(full_theta)])
    full_theta -= theta_med

    theta_lower = full_theta
    theta_upper = full_theta
    per_target_model = full_theta.shape[1] > 1

    plot_theta(args.output_csv, full_theta, theta_lower, theta_upper,
               args.output_pdf, per_target_model, full_feat_generator,
               MOTIF_LEN, args.center_nucs, args.y_lab)
def setUpClass(cls):
    np.random.seed(1)
    cls.motif_len = 3
    cls.burn_in = 10
    cls.feat_gen = HierarchicalMotifFeatureGenerator(motif_lens=[cls.motif_len])
    cls.theta = np.random.rand(cls.feat_gen.feature_vec_len, 1) * 2
def likelihood_of_tree_from_shazam(tree, mutability_file, substitution_file=None, num_jobs=1, scratch_dir='_output', num_samples=1000, burn_in=0, num_tries=5):
    """
    Given an ETE tree and a SHazaM fit, compute the likelihood of that tree

    @param tree: an ETE tree
    @param mutability_file: csv of mutability fit from SHazaM
    @param substitution_file: csv of substitution fit from SHazaM; if empty, assume all targets equiprobable
    @param num_jobs: how many jobs to run
    @param scratch_dir: where to put temporary output if running more than one job
    @param num_samples: number of Chibs samples
    @param burn_in: number of burn-in iterations
    @param num_tries: number of tries for the Chibs sampler

    @return: log likelihood of the tree given the SHazaM fit
    """
    # Default for SHazaM is S5F
    feat_generator = HierarchicalMotifFeatureGenerator(
        motif_lens=[5],
        left_motif_flank_len_list=[[2]],
    )
    theta_ref = get_shazam_theta(mutability_file, substitution_file)
    per_target_model = theta_ref.shape[1] == NUM_NUCLEOTIDES + 1
    obs_data = get_sequence_mutations_from_tree(
        tree,
        feat_generator.motif_len,
        feat_generator.max_left_motif_flank_len,
        feat_generator.max_right_motif_flank_len,
    )
    feat_generator.add_base_features_for_list(obs_data)
    log_like_evaluator = LogLikelihoodEvaluator(obs_data, feat_generator, num_jobs, scratch_dir)
    return log_like_evaluator.get_log_lik(
        theta_ref,
        num_samples=num_samples,
        burn_in=burn_in,
        num_tries=num_tries,
    )
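# Hypothetical usage of likelihood_of_tree_from_shazam. The newick string and
# csv paths below are placeholders, and get_sequence_mutations_from_tree
# presumably expects sequences attached to the tree's nodes, which this toy
# tree lacks; treat this as a sketch of the call shape only.
from ete3 import Tree
toy_tree = Tree("((s1:1,s2:1):1,s3:2);")
log_lik = likelihood_of_tree_from_shazam(
    toy_tree,
    mutability_file="_output/mutability.csv",      # placeholder path
    substitution_file="_output/substitution.csv",  # placeholder path
    num_samples=100,
)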
def setUpClass(cls):
    np.random.seed(10)
    cls.motif_len = 3
    cls.BURN_IN = 10
    cls.feat_gen = HierarchicalMotifFeatureGenerator(
        motif_lens=[3],
        left_motif_flank_len_list=[[1]])
    cls.feat_gen_hier = HierarchicalMotifFeatureGenerator(
        motif_lens=[1, 3],
        left_motif_flank_len_list=[[0], [1]])
    cls.obs = ObservedSequenceMutations(
        "attcaaatgatatac",
        "ataaatagggtttac",
        cls.motif_len,
        left_flank_len=1,
        right_flank_len=1)
    cls.feat_gen_off = HierarchicalMotifFeatureGenerator(
        motif_lens=[3],
        left_motif_flank_len_list=[[0, 1, 2]])
    cls.obs_off = ObservedSequenceMutations(
        "attcaaatgatatac",
        "ataaatagggtttac",
        cls.motif_len,
        left_flank_len=2,
        right_flank_len=2)
def main(args=sys.argv[1:]):
    args = parse_args()

    # Randomly generate number of mutations or use default
    np.random.seed(args.seed)

    hier_feat_generator = HierarchicalMotifFeatureGenerator(
        motif_lens=args.motif_lens,
        left_motif_flank_len_list=args.positions_mutating,
    )

    if args.use_shmulate_as_truth:
        h5f_theta = _read_mutability_probability_params(args)
        h5f_theta[h5f_theta == 0] = -np.inf
        h5f_theta[h5f_theta != -np.inf] = np.log(h5f_theta[h5f_theta != -np.inf])
        agg_h5f_theta = h5f_theta[:, 0:1] + h5f_theta[:, 1:]
        dump_parameters(agg_h5f_theta, h5f_theta, args, hier_feat_generator)
    else:
        theta_sampling_col0, theta_sampling_col_prob = _make_theta_sampling_distribution(args)
        avg_sampled_magnitude = np.std(theta_sampling_col0)
        theta_raw = _generate_true_parameters(
            hier_feat_generator,
            args,
            theta_sampling_col0,
            theta_sampling_col_prob,
        )
        agg_theta_raw = hier_feat_generator.create_aggregate_theta(theta_raw, keep_col0=False)
        # Rescale theta according to the requested effect size: after scaling,
        # the std of the finite aggregate entries is effect_size * avg_sampled_magnitude
        mult_factor = args.effect_size * avg_sampled_magnitude / np.std(agg_theta_raw[agg_theta_raw != -np.inf])
        agg_theta = agg_theta_raw * mult_factor
        theta_raw = theta_raw * mult_factor
        dump_parameters(agg_theta, theta_raw, args, hier_feat_generator)
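# A numpy sanity check of the effect-size rescaling above, on made-up values:
# after multiplying by the scale factor, the finite entries have standard
# deviation effect_size * avg_sampled_magnitude.
import numpy as np
toy_agg = np.array([0.5, -1.2, 2.0, -np.inf, 0.3])
effect_size, avg_sampled_magnitude = 2.0, 0.8
finite = toy_agg[toy_agg != -np.inf]
toy_mult = effect_size * avg_sampled_magnitude / np.std(finite)
rescaled = toy_agg * toy_mult
assert np.isclose(np.std(rescaled[rescaled != -np.inf]),
                  effect_size * avg_sampled_magnitude)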
def setUpClass(cls):
    """ Set up state """
    np.random.seed(1)
    cls.motif_len = 3
    cls.mut_pos_list = [[1]]
    cls.feat_gen = HierarchicalMotifFeatureGenerator(
        motif_lens=[cls.motif_len],
        left_motif_flank_len_list=cls.mut_pos_list,
    )
    cls.num_jobs = 1
    cls.scratch_dir = 'test/_output/'
    cls.num_e_samples = 4
    cls.sampling_rate = 1
    cls.burn_in = 1
    cls.nonzero_ratio = 0.5
def setUpClass(cls):
    np.random.seed(10)
    cls.motif_len = 5
    cls.feat_gen_hier = HierarchicalMotifFeatureGenerator(
        motif_lens=[3, 5],
        left_motif_flank_len_list=[[0, 1], [2]])
    obs_seq_mut = ObservedSequenceMutations(
        "agtctggcatcaaagaaagagcgatttag",
        "aggctcgtattcgctaaaataagcaccag",
        cls.motif_len)
    cls.mutation_order = [12, 18, 3, 5, 19, 16, 8, 17, 21, 0, 22, 10, 24, 11, 9, 23]
    cls.feat_gen_hier.add_base_features(obs_seq_mut)
    cls.sample_hier = ImputedSequenceMutations(obs_seq_mut, cls.mutation_order)
def _test_value_calculation_size(self, theta_num_col):
    np.random.seed(10)
    motif_len = 3
    penalty_param = 0.5

    feat_gen = HierarchicalMotifFeatureGenerator(motif_lens=[motif_len])
    motif_list = feat_gen.motif_list
    theta = np.random.rand(feat_gen.feature_vec_len, theta_num_col)
    theta_mask = feat_gen.get_possible_motifs_to_targets(theta.shape)
    theta[~theta_mask] = -np.inf

    obs = ObservedSequenceMutations("aggtgggttac", "aggagagttac", motif_len)
    feat_gen.add_base_features(obs)
    sample = ImputedSequenceMutations(obs, obs.mutation_pos_dict.keys())

    # The CVXPY implementation should agree with the custom solver
    problem_cvx = SurvivalProblemLassoCVXPY(feat_gen, [sample], penalty_param, theta_mask)
    ll_cvx = problem_cvx.calculate_per_sample_log_lik(theta, sample)
    value_cvx = problem_cvx.get_value(theta)

    feature_mut_steps = feat_gen.create_for_mutation_steps(sample)
    problem_custom = SurvivalProblemLasso(feat_gen, [sample], penalty_param, theta_mask)
    ll_custom = problem_custom.calculate_per_sample_log_lik(
        np.exp(theta), problem_custom.precalc_data[0])
    value_custom = problem_custom.get_value(theta)

    self.assertTrue(np.isclose(ll_cvx.value, ll_custom))
    # The two solvers use opposite sign conventions for the objective
    self.assertTrue(np.isclose(value_cvx.value, -value_custom))
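# A toy illustration of the -inf masking convention exercised above: entries
# for impossible motif-to-target pairs are set to -inf so that their hazard
# exp(theta) is exactly zero. The 2x4 mask here is made up.
import numpy as np
toy_theta = np.random.rand(2, 4)
toy_mask = np.array([[True, False, True, True],
                     [True, True, True, False]])
toy_theta[~toy_mask] = -np.inf
assert np.all(np.exp(toy_theta[~toy_mask]) == 0)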
    bar.set_hatch(hatch)

for i, bar in enumerate(sns_plot.axes[0, 0].patches):
    hatch = hatches[i % 3]
    bar.set_hatch(hatch)

plt.ylabel('Diff. from true theta', fontsize=17)
plt.xlabel("True theta size", fontsize=17)
sns_plot.set(ylim=(-5, 5))
x = sns_plot.axes[0, 0].get_xlim()
sns_plot.axes[0, 0].plot(x, len(x) * [0], 'k--', alpha=.4)
plt.legend(loc='upper right', fontsize='large')
sns_plot.savefig(fname)

dense_agg_feat_gen = HierarchicalMotifFeatureGenerator(
    motif_lens=[MOTIF_LEN],
    left_motif_flank_len_list=[[MUT_POS]],
)

all_df = pd.DataFrame(columns=['theta', 'samm_fit', 'shazam_fit', 'sim_method', 'seed'])
for sim_method in SIM_METHODS:
    for seed in NSEEDS:
        with open(TRUE_MODEL_STR % (sim_method, seed), 'r') as f:
            theta, _ = pickle.load(f)
        model_shape = theta.shape
        possible_agg_mask = dense_agg_feat_gen.get_possible_motifs_to_targets(model_shape)
        theta = theta[possible_agg_mask] - np.median(theta[possible_agg_mask])

        tmp_df = pd.DataFrame()
        tmp_df['theta'] = theta
def main(args=sys.argv[1:]):
    args = parse_args()
    log.basicConfig(format="%(message)s", filename=args.log_file, level=log.DEBUG)
    np.random.seed(args.seed)

    feat_generator = HierarchicalMotifFeatureGenerator(
        motif_lens=args.motif_lens,
        left_motif_flank_len_list=args.positions_mutating,
    )
    theta_shape = (feat_generator.feature_vec_len,
                   NUM_NUCLEOTIDES + 1 if args.per_target_model else 1)

    log.info("Reading data")
    obs_data, metadata = read_gene_seq_csv_data(
        args.input_naive,
        args.input_mutated,
        motif_len=args.max_motif_len,
        left_flank_len=args.max_left_flank,
        right_flank_len=args.max_right_flank,
    )
    log.info("num observations %d", len(obs_data))
    feat_generator.add_base_features_for_list(obs_data)

    # Process data
    fold_indices = data_split.split(
        len(obs_data),
        metadata,
        args.tuning_sample_ratio,
        args.k_folds,
        validation_column=args.validation_col,
    )
    data_folds = []
    for train_idx, val_idx in fold_indices:
        train_set = [obs_data[i] for i in train_idx]
        val_set = [obs_data[i] for i in val_idx]
        train_X, train_y, train_y_orig = get_X_y_matrices(train_set, args.per_target_model)
        val_X, val_y, val_y_orig = get_X_y_matrices(val_set, args.per_target_model)
        logistic_reg = LogisticRegressionMotif(
            theta_shape,
            train_X,
            train_y,
            train_y_orig,
            per_target_model=args.per_target_model)
        data_folds.append((val_X, val_y, val_y_orig, logistic_reg))

    # Fit the models for each penalty parameter
    best_pen_param = get_best_penalty_param(args.penalty_params, data_folds)
    log.info("best penalty param %f", best_pen_param)

    # Refit penalized with all the data
    penalized_theta = fit_to_data(obs_data, best_pen_param, theta_shape, args.per_target_model)
    lines = get_nonzero_theta_print_lines(penalized_theta, feat_generator)
    log.info("========penalized==========")
    log.info(lines)

    # Convert theta to log probability of mutation. The model uses negative
    # exponents, so impossible entries are masked with +inf (exp(-inf) = 0).
    if args.per_target_model:
        possible_mask = feat_generator.get_possible_motifs_to_targets(penalized_theta.shape)
        penalized_theta[~possible_mask] = np.inf
    penalized_theta_exp_sum = np.sum(np.exp(-penalized_theta[:, 1:]), axis=1).reshape((penalized_theta.shape[0], 1))
    theta_prob = np.hstack([
        1.0 / (1.0 + np.exp(-penalized_theta[:, 0:1])),
        np.exp(-penalized_theta[:, 1:]) / penalized_theta_exp_sum,
    ])
    theta_log_prob = np.log(theta_prob)

    theta_est = feat_generator.create_aggregate_theta(theta_log_prob, keep_col0=False)
    hier_full_feat_generator = HierarchicalMotifFeatureGenerator(
        motif_lens=[args.max_motif_len],
        left_motif_flank_len_list=[[args.max_left_flank]],
    )
    if args.per_target_model:
        possible_mask = hier_full_feat_generator.get_possible_motifs_to_targets(theta_est.shape)
        theta_est[~possible_mask] = -np.inf

    agg_lines = get_nonzero_theta_print_lines(theta_est, hier_full_feat_generator)
    log.info("===========aggregate=======")
    log.info(agg_lines)

    with open(args.model_pkl, "w") as f:
        pickle.dump(LogisticModel(theta_est), f)
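# A quick numpy check of the per-target normalization above, with made-up
# numbers: the target columns of theta_prob must sum to one in each row.
import numpy as np
toy = np.random.rand(3, 5)  # col 0: mutation parameter, cols 1-4: target parameters
toy_exp_sum = np.sum(np.exp(-toy[:, 1:]), axis=1).reshape((toy.shape[0], 1))
toy_target_prob = np.exp(-toy[:, 1:]) / toy_exp_sum
assert np.allclose(np.sum(toy_target_prob, axis=1), 1.0)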
def main(args=sys.argv[1:]):
    args = parse_args()
    print(args)
    np.random.seed(args.seed)

    thetas = []
    labels = []

    # Get data
    obs_data, metadata = read_gene_seq_csv_data(
        args.input_naive,
        args.input_mutated,
        motif_len=args.motif_len,
        left_flank_len=args.left_flank,
        right_flank_len=args.right_flank,
    )

    # Get samm theta
    with open(args.input_samm, "r") as f:
        method_results = pickle.load(f)
    method_res = pick_best_model(method_results)
    feat_generator = method_res.refit_feature_generator
    per_target_model = method_res.refit_theta.shape[1] == NUM_NUCLEOTIDES + 1
    max_motif_len = args.motif_len
    full_feat_generator = HierarchicalMotifFeatureGenerator(
        motif_lens=[max_motif_len],
        left_motif_flank_len_list=[[args.left_flank]],
    )
    full_feat_generator.add_base_features_for_list(obs_data)
    samm_theta = feat_generator.create_aggregate_theta(method_res.refit_theta, keep_col0=False)
    agg_possible_motif_mask = full_feat_generator.get_possible_motifs_to_targets(samm_theta.shape)
    samm_theta[~agg_possible_motif_mask] = -np.inf
    thetas += [samm_theta]
    labels += ['samm']

    # Get shazam theta
    shazam_model = ShazamModel(max_motif_len, args.input_shazam, None, wide_format=True)
    shazam_theta = shazam_model.agg_refit_theta
    thetas += [shazam_theta]
    labels += ['shazam']

    # Get logistic theta
    logistic_model = load_logistic_model(args.input_logistic)
    logistic_theta = logistic_model.agg_refit_theta
    thetas += [logistic_theta]
    labels += ['logistic']

    # Compare every ordered (reference, alternative) pair of models
    cur_ref_idx = -1
    val_set_evaluator = None
    for idx in itertools.permutations(range(len(thetas)), 2):
        if idx[0] != cur_ref_idx:
            # The reference model changed, so rebuild the validation set evaluator
            cur_ref_idx = idx[0]
            val_set_evaluator = LikelihoodComparer(
                obs_data,
                full_feat_generator,
                theta_ref=thetas[idx[0]],
                num_samples=args.num_val_samples,
                burn_in=args.num_val_burnin,
                num_jobs=args.num_jobs,
                scratch_dir=args.scratch_dir,
            )
        log_lik_ratio, lower_bound, upper_bound = val_set_evaluator.get_log_likelihood_ratio(thetas[idx[1]])
        print("{} with {} reference:".format(labels[idx[1]], labels[idx[0]]))
        print("(lower, ratio, upper) = (%.4f, %.4f, %.4f)" % (lower_bound, log_lik_ratio, upper_bound))
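# The comparison loop above visits every ordered (reference, alternative)
# pair exactly once; for the three models here that is six comparisons:
import itertools
assert list(itertools.permutations(range(3), 2)) == \
    [(0, 1), (0, 2), (1, 0), (1, 2), (2, 0), (2, 1)]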
def main(args=sys.argv[1:]):
    args = parse_args()
    log.basicConfig(format="%(message)s", filename=args.log_file, level=log.DEBUG)
    np.random.seed(args.seed)

    if max(args.k_folds, args.num_cpu_threads) > 1:
        all_runs_pool = Pool(max(args.k_folds, args.num_cpu_threads))
    else:
        all_runs_pool = None

    feat_generator = HierarchicalMotifFeatureGenerator(
        motif_lens=args.motif_lens,
        left_motif_flank_len_list=args.positions_mutating,
    )

    log.info("Reading data")
    obs_data, metadata = read_gene_seq_csv_data(
        args.input_naive,
        args.input_mutated,
        motif_len=args.max_motif_len,
        left_flank_len=args.max_left_flank,
        right_flank_len=args.max_right_flank,
    )
    feat_generator.add_base_features_for_list(obs_data)

    fold_indices = data_split.split(
        len(obs_data),
        metadata,
        args.tuning_sample_ratio,
        args.k_folds,
        validation_column=args.validation_col,
    )
    data_folds = []
    for train_idx, val_idx in fold_indices:
        train_set = [obs_data[i] for i in train_idx]
        val_set = [obs_data[i] for i in val_idx]
        data_folds.append((train_set, val_set))

    st_time = time.time()
    log.info("Data statistics:")
    log.info(" Number of sequences: Train %d, Val %d" % (len(train_idx), len(val_idx)))
    log.info(get_data_statistics_print_lines(obs_data, feat_generator))
    log.info("Settings %s" % args)

    log.info("Running EM")
    # Run EM on the lasso parameters from largest to smallest
    results_list = []
    val_set_evaluators = [None for _ in fold_indices]
    cmodel_algos = [ContextModelAlgo(feat_generator, args) for _ in fold_indices]
    prev_pen_theta = None
    best_model_idx = 0
    for param_i, penalty_param in enumerate(args.penalty_params):
        param_results = []
        penalty_params_prev = None if param_i == 0 else args.penalty_params[param_i - 1]
        target_penalty_param = penalty_param if args.per_target_model else 0
        penalty_params = (penalty_param, target_penalty_param)
        log.info("==== Penalty parameters %f, %f ====" % penalty_params)
        workers = []
        for fold_idx, (train_set, val_set) in enumerate(data_folds):
            log.info("========= Fold %d ==============" % fold_idx)
            prev_pen_theta = results_list[param_i - 1][fold_idx].penalized_theta if param_i else None
            val_set_evaluator = val_set_evaluators[fold_idx]
            # Use the same number of order samples as the previous validation set if possible
            prev_num_val_samples = val_set_evaluator.num_samples if val_set_evaluator is not None else args.num_val_samples
            samm_worker = SammWorker(
                fold_idx,
                cmodel_algos[fold_idx],
                train_set,
                penalty_params,
                args.em_max_iters,
                val_set_evaluator,
                prev_pen_theta,
                penalty_params_prev,
                val_set,
                prev_num_val_samples,
                args)
            workers.append(samm_worker)

        if args.k_folds > 1:
            # Use the MultiprocessingManager to fit theta for each fold in
            # parallel (via python's multiprocessing lib)
            manager = MultiprocessingManager(all_runs_pool, workers, num_approx_batches=len(workers))
            results = manager.run()
        else:
            # Use the pool to parallelize computations within the M-step instead
            results = [w.run(all_runs_pool) for w in workers]

        param_results = [r[0] for r in results]
        results_list.append(param_results)
        val_set_evaluators = [r[1] for r in results]
        with open(args.out_file, "w") as f:
            pickle.dump(results_list, f)

        nonzeros = np.array([res.penalized_num_nonzero for res in param_results])
        log_lik_ratios = np.array([r.log_lik_ratio for r in param_results])
        log.info("Log lik ratios %s" % log_lik_ratios)
        if any(nonzeros) and param_i > 0:
            # Make sure that the penalty isn't so big that theta is empty
            cv_interval = get_interval(log_lik_ratios, zscore=1)
            log.info("log lik interval %s", cv_interval)
            if cv_interval[0] < -ZERO_THRES:
                # One std error below the mean of the log lik ratio surrogate is
                # negative: this model is not better than the previous one, so
                # stop shrinking the penalty parameter and refit the model.
                log.info("EM surrogate function is decreasing. Stop trying penalty parameters. ll_ratios %s" % log_lik_ratios)
                break

        best_model_idx = param_i
        if np.mean(nonzeros) == feat_generator.feature_vec_len:
            # Model is saturated, so stop fitting new parameters
            log.info("Model is saturated with %d parameters. Stop fitting." % np.mean(nonzeros))
            break

    # Pick out the best model.
    # Make sure the folds have the same support; otherwise we need to refit.
    if _check_same_support(results_list[best_model_idx]):
        # Just use the first fold as the template for the unpenalized refit
        method_res = results_list[best_model_idx][0]
    else:
        log.info("Need to refit to get the same support")
        # The support differs across folds, so refit penalized on all the data
        # and take that support. Just use the first fold as the template.
        method_res_template = results_list[best_model_idx][0]
        prev_pen_theta = results_list[best_model_idx - 1][0].penalized_theta if best_model_idx else None
        method_res = cmodel_algos[0].fit_penalized(
            obs_data,
            method_res_template.penalty_params,
            max_em_iters=args.em_max_iters,
            init_theta=prev_pen_theta,
            pool=all_runs_pool,
        )
        results_list[best_model_idx].append(method_res)

    # Finally ready to refit as an unpenalized model
    if args.num_cpu_threads > 1 and all_runs_pool is None:
        all_runs_pool = Pool(args.num_cpu_threads)
    cmodel_algos[0].refit_unpenalized(
        obs_data,
        model_result=method_res,
        max_em_iters=args.unpenalized_em_max_iters,
        hessian_check_iter=args.hessian_check_iter,
        get_hessian=not args.omit_hessian,
        pool=all_runs_pool,
    )

    # Pickle the refitted theta
    with open(args.out_file, "w") as f:
        pickle.dump(results_list, f)

    if not args.omit_hessian:
        full_feat_generator = HierarchicalMotifFeatureGenerator(
            motif_lens=[args.max_motif_len],
            left_motif_flank_len_list=args.max_mut_pos,
        )
        num_agg_cols = NUM_NUCLEOTIDES if args.per_target_model else 1
        agg_start_col = 1 if args.per_target_model else 0
        try:
            feat_generator_stage2 = HierarchicalMotifFeatureGenerator(
                motif_lens=args.motif_lens,
                model_truncation=method_res.model_masks,
                left_motif_flank_len_list=args.positions_mutating,
            )
            for col_idx in range(num_agg_cols):
                full_theta, theta_lower, theta_upper = feat_generator_stage2.combine_thetas_and_get_conf_int(
                    method_res.refit_theta,
                    variance_est=method_res.variance_est,
                    col_idx=col_idx + agg_start_col,
                    add_targets=args.per_target_model,
                )
        except ValueError as e:
            print(e)
            log.info("No fits had positive variance estimates")

    if all_runs_pool is not None:
        # Make sure we don't keep these processes open!
        all_runs_pool.close()
        all_runs_pool.join()

    log.info("Completed! Time: %s" % str(time.time() - st_time))
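# A minimal sketch of the one-standard-error stopping rule used above.
# get_interval's internals are not shown in this file, so this is an
# assumption: mean of the fold log-lik ratios +/- zscore standard errors,
# with the rule firing when the lower endpoint falls below -ZERO_THRES.
import numpy as np

def _interval_sketch(log_lik_ratios, zscore=1):
    ratios = np.asarray(log_lik_ratios, dtype=float)
    se = np.std(ratios) / np.sqrt(ratios.size)
    return np.mean(ratios) - zscore * se, np.mean(ratios) + zscore * se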