def main_clustering_exapnsion(path_to_params):
    """Expand the learned gate via clustering and plot the expanded cells.

    Loads the saved data and model from ``path_to_params``, expands the gate
    with ``expand_gate_from_saved_results``, and writes per-sample expansion
    plots. It then overlays every expanded cell (in red) on the full training
    data drawn with the learned gates. Cells with label 1 go on the first axis
    and cells with label 0 on the second. The figure is saved as
    ``expanded_data_with_all_data.png`` in the current working directory.
    """
    # TODO: raise to 200 once debugging is finished.
    num_expansion_clusters = 5
    samples_to_plot = [0, 1, 2, 3, 4]

    data_input, model = load_saved_results(path_to_params)
    expander = expand_gate_from_saved_results(data_input, model,
                                              num_expansion_clusters)
    plot_and_save_expanded_data_for_samples(expander, model, data_input,
                                            samples_to_plot)

    expanded_cells = np.concatenate(expander.expanded_data_per_sample)
    expanded_labels = expander.get_catted_cell_level_labels_of_expanded_data()

    tr_samples = data_input.x_tr
    tr_cells = np.concatenate(tr_samples)
    plotter = DataAndGatesPlotterDepthOne(model, tr_cells)
    # Repeat each sample-level label once per cell of that sample.
    per_cell_labels = np.array(np.concatenate([
        data_input.y_tr[k] * torch.ones([sample.shape[0], 1])
        for k, sample in enumerate(tr_samples)
    ]))
    _fig, axes = plotter.plot_data_with_gates(per_cell_labels)

    marker_size = 1000 / tr_cells.shape[0]
    positive = expanded_cells[expanded_labels == 1]
    negative = expanded_cells[expanded_labels == 0]
    axes[0].scatter(positive[:, 0], positive[:, 1], color='r', s=marker_size)
    axes[1].scatter(negative[:, 0], negative[:, 1], color='r', s=marker_size)
    plt.savefig('expanded_data_with_all_data.png')
# Example #2
# 0
def single_run_single_gate(params):
    """Run one embed -> initialize-gate -> train cycle using a single gate.

    Splits the data described by ``params`` using ``params['random_seed']``,
    then fits the data transformer and embeds the data, and normalizes it. It
    builds one gate tree from the cluster-based initializations and trains the
    model. Artifacts are written into ``params['save_dir']``, which is created
    if missing: params.pkl, the transformer, model.pkl,
    last_CV_rounds_tracker.pkl, final_gates.png and configs.pkl.

    Args:
        params: an already-parsed parameter dict. This function does not parse
            a params file itself.

    Returns:
        (performance_tracker, model): the tracker returned by
        ``run_train_model`` and the trained model.
    """
    start_time = time.time()

    # Eventually uncomment this. It is left commented out so results stay
    # identical to earlier runs for comparison.
    #set_random_seeds(params)

    if not os.path.exists(params['save_dir']):
        os.makedirs(params['save_dir'])

    with open(os.path.join(params['save_dir'], 'params.pkl'), 'wb') as f:
        pickle.dump(params, f)

    data_input = DataInput(params['data_params'])
    # Unlike the other entry points, the split is seeded explicitly here.
    data_input.split_data(split_seed=params['random_seed'])

    data_transformer = DataTransformerFactory(
        params['transform_params'],
        params['random_seed']).manufacture_transformer()

    data_input.embed_data_and_fit_transformer(\
        data_transformer,
        cells_to_subsample=params['transform_params']['cells_to_subsample'],
        num_cells_for_transformer=params['transform_params']['num_cells_for_transformer']
    )
    data_input.save_transformer(params['save_dir'])
    data_input.normalize_data()
    # Presumably the cluster-based gate initializations (per the name);
    # get_next_gate_tree consumes one and hands back the remainder.
    unused_cluster_gate_inits = init_plot_and_save_gates(data_input, params)
    #everything below differs from the other main_UMAP
    data_input.convert_all_data_to_tensors()
    init_gate_tree, unused_cluster_gate_inits = get_next_gate_tree(
        unused_cluster_gate_inits, data_input, params, model=None)
    model = initialize_model(params['model_params'], [init_gate_tree])
    performance_tracker = run_train_model(model, params['train_params'],
                                          data_input)

    model_save_path = os.path.join(params['save_dir'], 'model.pkl')
    torch.save(model.state_dict(), model_save_path)

    trackers_save_path = os.path.join(params['save_dir'],
                                      'last_CV_rounds_tracker.pkl')
    with open(trackers_save_path, 'wb') as f:
        pickle.dump(performance_tracker, f)
    # Plot all training cells with the learned gates. Each sample's label is
    # repeated once per cell so every cell row carries its sample's label.
    results_plotter = DataAndGatesPlotterDepthOne(
        model, np.concatenate(data_input.x_tr))
    #fig, axes = plt.subplots(params['gate_init_params']['n_clusters'], figsize=(1 * params['gate_init_params']['n_clusters'], 3 * params['gate_init_params']['n_clusters']))
    results_plotter.plot_data_with_gates(
        np.array(
            np.concatenate([
                data_input.y_tr[i] *
                torch.ones([data_input.x_tr[i].shape[0], 1])
                for i in range(len(data_input.x_tr))
            ])))

    plt.savefig(os.path.join(params['save_dir'], 'final_gates.png'))

    # Dumps params again (it was also written to params.pkl above), this time
    # after training.
    with open(os.path.join(params['save_dir'], 'configs.pkl'), 'wb') as f:
        pickle.dump(params, f)

    print('Complete main loop took %.4f seconds' % (time.time() - start_time))
    return performance_tracker, model
def get_data_inside_both_visually_correct_and_learned_gate(
        path_to_params_learned):
    """Collect the cells inside the learned gate and inside a hand-set gate.

    The hand-set ("visually correct") gate keeps the learned gate's D1 upper
    bound and D2 lower bound. It pins the D1 lower bound to 0 and the D2 upper
    bound to .75. Both results are pickled to
    ``model_feat_diff_data_inside_gate.pkl`` and
    ``vis_corr_data_inside_gate.pkl`` in the working directory.

    Returns:
        (model_learned_data_inside_gate, vis_corr_data_inside_gate)
    """
    matplotlib.rcParams.update({'font.size': 22})

    # NOTE(review): this unpacks 3 values from the loader, while
    # plot_data_inside_semi_synth_gate_in_real_UMAP_space unpacks 4
    # (including a umapper). Confirm which signature is current.
    params_learned, model_learned, data_input = \
        load_saved_model_and_matching_data_input(path_to_params_learned)

    d1_upper = model_learned.get_gate_tree()[0][0][2]
    d2_lower = model_learned.get_gate_tree()[0][1][1]
    vis_correct_model = DepthOneModel(
        [[['D1', 0., d1_upper], ['D2', d2_lower, .75]]],
        params_learned['model_params'])

    learned_plotter = DataAndGatesPlotterDepthOne(
        model_learned, np.concatenate(data_input.x_tr))
    vis_plotter = DataAndGatesPlotterDepthOne(
        vis_correct_model, np.concatenate(data_input.x_tr))

    learned_inside = get_data_inside_gate(learned_plotter, data_input,
                                          model_learned)
    vis_inside = get_data_inside_gate(vis_plotter, data_input,
                                      vis_correct_model)

    for out_name, payload in (('model_feat_diff_data_inside_gate.pkl',
                               learned_inside),
                              ('vis_corr_data_inside_gate.pkl', vis_inside)):
        with open(out_name, 'wb') as handle:
            pickle.dump(payload, handle)

    return learned_inside, vis_inside
def plot_data_inside_semi_synth_gate_in_real_UMAP_space(
        path_to_semi_synth_params, path_to_params_real):
    """Plot the cells inside the semi-synthetic gate in the real data's UMAP space.

    Finds the training cells inside the semi-synthetic model's gate. It
    embeds them with the real run's umapper, normalizes them, and scatters
    them over the real model's gates. The figure is saved as
    ``semi_synth_inside_gate_in_real_umapper_space.png`` in the working
    directory.
    """
    _, model_semi_synth, data_input_semi_synth, _ = \
        load_saved_model_and_matching_data_input(path_to_semi_synth_params)
    _, model_real, _, umapper_real = \
        load_saved_model_and_matching_data_input(path_to_params_real)

    plotter_semi_synth = DataAndGatesPlotterDepthOne(
        model_semi_synth, np.concatenate(data_input_semi_synth.x_tr))
    semi_synth_data_inside_gate = get_data_inside_gate(plotter_semi_synth,
                                                       data_input_semi_synth,
                                                       model_semi_synth)

    embedded = umapper_real.transform(semi_synth_data_inside_gate)
    embedded = normalize_data(embedded)

    # `embedded` is a single (cells x dims) array, as the [:, 0] / [:, 1]
    # indexing below requires. np.concatenate on it would flatten the rows
    # into a 1-D array, so it is passed to the plotter unchanged.
    plotter_real_data = DataAndGatesPlotterDepthOne(model_real, embedded)
    plotter_real_data.plot_all_gates(plt.gca())
    size = 1000 / embedded.shape[0]
    plt.scatter(embedded[:, 0], embedded[:, 1], s=size)
    plt.savefig('semi_synth_inside_gate_in_real_umapper_space.png')
# Example #5
# 0
def make_and_save_plot_to_check_umap_stays_same(model, data_input, run,
                                                params):
    """Plot all training cells with the model's gates and save the figure.

    The figure is written to ``<params['save_dir']>/test<run>.png``. Judging
    by the function name, it is meant for visually checking that the UMAP
    embedding stays the same across runs.
    """
    tr_samples = data_input.x_tr
    plotter = DataAndGatesPlotterDepthOne(model, np.concatenate(tr_samples))
    # Repeat each sample-level label once per cell of that sample.
    cell_labels = [
        data_input.y_tr[k] * torch.ones([sample.shape[0], 1])
        for k, sample in enumerate(tr_samples)
    ]
    plotter.plot_data_with_gates(np.array(np.concatenate(cell_labels)))

    out_path = os.path.join(params['save_dir'], 'test%d.png' % run)
    plt.savefig(out_path)
def main(path_to_params):
    """Run the single-stage gating pipeline configured by ``path_to_params``.

    Parses the params and splits, embeds and normalizes the data. It then
    initializes the gates, trains the model, and writes params.pkl, the
    transformer, model.pkl, tracker.pkl and final_gates.png into
    ``params['save_dir']``.
    """
    t0 = time.time()

    params = TransformParameterParser(path_to_params).parse_params()
    print(params)

    out_dir = params['save_dir']
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)
    with open(os.path.join(out_dir, 'params.pkl'), 'wb') as fh:
        pickle.dump(params, fh)

    data_input = DataInput(params['data_params'])
    data_input.split_data()

    transform_params = params['transform_params']
    transformer = DataTransformerFactory(
        transform_params, params['random_seed']).manufacture_transformer()
    data_input.embed_data_and_fit_transformer(
        transformer,
        transform_params['cells_to_subsample'],
        transform_params['num_cells_for_transformer'])
    data_input.save_transformer(out_dir)
    data_input.normalize_data()
    init_gate_tree = init_plot_and_save_gates(data_input, params)

    model = initialize_model(params['model_params'], init_gate_tree)
    data_input.prepare_data_for_training()
    tracker = run_train_model(model, params['train_params'], data_input)

    torch.save(model.state_dict(), os.path.join(out_dir, 'model.pkl'))
    with open(os.path.join(out_dir, 'tracker.pkl'), 'wb') as fh:
        pickle.dump(tracker, fh)

    plotter = DataAndGatesPlotterDepthOne(model,
                                          np.concatenate(data_input.x_tr))
    # Repeat each sample-level label once per cell of that sample.
    cell_labels = np.array(np.concatenate([
        data_input.y_tr[k] * torch.ones([sample.shape[0], 1])
        for k, sample in enumerate(data_input.x_tr)
    ]))
    plotter.plot_data_with_gates(cell_labels)

    plt.savefig(os.path.join(out_dir, 'final_gates.png'))
    print('Complete main loop took %.4f seconds' % (time.time() - t0))
def plot_and_save_expanded_data_for_samples(gate_expander, model, data_input,
                                            sample_idxs_to_plot):
    """Save one figure per requested training sample showing its gate expansion.

    ``sample_idxs_to_plot`` holds positions into ``data_input.x_tr``. For each
    position, the sample is drawn with the model's gate. If the expander added
    any cells, they are overlaid in red. The expander's cluster centers are
    overlaid in blue. The figure is saved as ``expanded_sample<id>.png``,
    where ``<id>`` comes from ``data_input.idxs_tr``, and then cleared.
    """
    for position in sample_idxs_to_plot:
        sample_id = data_input.idxs_tr[position]
        sample_label = data_input.y_tr[position]

        plotter = DataAndGatesPlotterDepthOne(model,
                                              np.concatenate(data_input.x_tr))
        # plt.gca() is evaluated after each plotter/plot call, as before,
        # in case they change the current figure.
        plotter.plot_single_sample_with_gate(data_input.x_tr[position],
                                             sample_id,
                                             sample_label,
                                             plt.gca(),
                                             include_diagnostics=False)

        expanded = gate_expander.expanded_data_per_sample[position]
        if expanded.shape[0] != 0:
            print(expanded.shape)
            plotter.plot_single_sample_with_gate(expanded,
                                                 sample_id,
                                                 sample_label,
                                                 plt.gca(),
                                                 color='r',
                                                 include_diagnostics=False)

        centers = gate_expander.clusterers_per_sample[position].cluster_centers_
        plt.gca().scatter(centers[:, 0], centers[:, 1], color='b', s=2)
        plt.savefig('expanded_sample%d.png' % sample_id)
        plt.clf()
# Example #8
# 0
def make_umap_plots_per_sample(model, data_input, sample_idxs_to_plot, plots_per_row=5, figlen=7, savename='plots_per_sample.png', background_data_to_plot=None, color='b', expanded_data_per_sample=None, sample_names_to_true_features=None, BALL=False):
    """Plot each requested sample (train or test) with the model's gates in a grid.

    Args:
        model: model whose gates are drawn on every panel.
        data_input: provides ``sample_names_all``, ``idxs_tr``/``idxs_te`` and
            the matching ``x_tr``/``y_tr``/``x_te``/``y_te``.
        sample_idxs_to_plot: sample ids to plot, laid out row-major. Ids not in
            ``data_input.sample_names_all`` are reported and skipped. The
            caller's list is not modified.
        plots_per_row: number of panels per grid row.
        figlen: edge length of each panel, in inches.
        savename: path the figure is saved to.
        background_data_to_plot: optional array of points drawn in light grey
            behind every panel.
        color: colour used for each sample's cells.
        expanded_data_per_sample: optional sequence of arrays drawn in red on
            the matching panel.
        sample_names_to_true_features: optional mapping from sample id to the
            ``true_feature`` passed to the plotter.
        BALL: when true, ``BALL=True`` is forwarded to the plotter.

    Returns:
        None. Returns early, without plotting, when there is nothing to plot.
    """
    if len(sample_idxs_to_plot) == 0:
        print('idxs are empty!')
        return None

    # Build a filtered copy instead of deleting from the caller's list.
    kept_idxs = []
    for idx in sample_idxs_to_plot:
        if idx in data_input.sample_names_all:
            kept_idxs.append(idx)
        else:
            print('Sample %d not in training data' % idx)
    if len(kept_idxs) == 0:
        # Every requested id was unknown; plt.subplots(0, ...) would fail.
        print('idxs are empty!')
        return None

    plotter = DataAndGatesPlotterDepthOne(model, [])
    # (is_train, position within idxs_tr or idxs_te) for every kept sample.
    locations = [
        (True, data_input.idxs_tr.index(idx)) if idx in data_input.idxs_tr
        else (False, data_input.idxs_te.index(idx))
        for idx in kept_idxs
    ]

    n_rows, remainder = divmod(len(kept_idxs), plots_per_row)
    if remainder:
        n_rows += 1

    print(n_rows, plots_per_row)
    # squeeze=False always returns a 2-D array of Axes, including 1xN, Nx1
    # and 1x1 grids, so axes[row][col] indexing is always valid.
    fig, axes = plt.subplots(n_rows, plots_per_row, figsize=(figlen * plots_per_row, figlen * n_rows), sharex=True, sharey=True, squeeze=False)

    #fig.suptitle('UMAP Embedding and Learned Gates per Sample')

    if background_data_to_plot is not None:
        for axis in axes.flat:
            axis.scatter(
                background_data_to_plot[:, 0],
                background_data_to_plot[:, 1],
                c='lightgrey', s=1/100, alpha=.5
            )

    axes[0][0].set_xlim(0, 1)
    axes[0][0].set_ylim(0, 1)
    fig.tight_layout(pad=1.3)

    plot_kwargs = {'BALL': True} if BALL else {}
    for flat_idx, (name, (is_train, data_idx)) in enumerate(zip(kept_idxs, locations)):
        row, col = divmod(flat_idx, plots_per_row)
        cur_axis = axes[row][col]
        # Use this sample's own train/test flag (the old code read the flag
        # of the sample at the same column in the first row).
        if is_train:
            sample = data_input.x_tr[data_idx]
            label = data_input.y_tr[data_idx]
        else:
            sample = data_input.x_te[data_idx]
            label = data_input.y_te[data_idx]
        true_feature = None
        if sample_names_to_true_features:
            true_feature = sample_names_to_true_features[name]
        plotter.plot_single_sample_with_gate(
            sample, name, label,
            cur_axis, size=1, color=color,
            true_feature=true_feature,
            **plot_kwargs
        )
        if expanded_data_per_sample is not None:
            # NOTE(review): indexed by the sample's position within its own
            # split (train or test). Confirm this matches how
            # expanded_data_per_sample is built when test samples are plotted.
            expanded_data = expanded_data_per_sample[data_idx]
            cur_axis.scatter(expanded_data[:, 0], expanded_data[:, 1], s=1, color='r')

    plt.savefig(savename)
def cross_validate(path_to_params, n_runs, start_seed=0):
    """Repeat the split -> embed -> train pipeline over several data splits.

    Each run ``run`` in ``range(start_seed, n_runs)`` writes its transformer,
    model.pkl, tracker.pkl, final_gates.png and configs.pkl into
    ``<save_dir>/run<run>``. The run's final train/test accuracies are
    printed, followed by their mean and standard deviation across runs.

    Args:
        path_to_params: params file parsed by ``TransformParameterParser``.
        n_runs: exclusive END index of runs, not a count. Runs are numbered
            ``start_seed .. n_runs - 1``.
        start_seed: first run index. ``split_data()`` is called this many
            times first so that run ``start_seed`` sees the same split it
            would have seen in a full sequence of runs (presumably
            ``split_data`` draws from the RNG seeded below).
    """
    start_time = time.time()

    params = TransformParameterParser(path_to_params).parse_params()
    print(params)
    check_consistency_of_params(params)

    # Seed the RNGs from params. An older comment here said to "eventually
    # uncomment" this call; it is already active.
    set_random_seeds(params)

    if not os.path.exists(params['save_dir']):
        os.makedirs(params['save_dir'])

    with open(os.path.join(params['save_dir'], 'params.pkl'), 'wb') as f:
        pickle.dump(params, f)

    data_input = DataInput(params['data_params'])
    te_accs = []
    tr_accs = []
    # Replay the splits of the skipped runs so that run `start_seed` starts
    # from the correct split. The order of these calls matters.
    for i in range(start_seed):
        data_input.split_data()

    for run in range(start_seed, n_runs):
        if not os.path.exists(os.path.join(params['save_dir'], 'run%d' % run)):
            os.makedirs(os.path.join(params['save_dir'], 'run%d' % run))
        savepath = os.path.join(params['save_dir'], 'run%d' % run)
        data_input.split_data()
        print(data_input.idxs_tr)

        # The transformer gets the same params['random_seed'] on every run;
        # only the data split differs between runs.
        data_transformer = DataTransformerFactory(
            params['transform_params'],
            params['random_seed']).manufacture_transformer()

        data_input.embed_data_and_fit_transformer(\
            data_transformer,
            cells_to_subsample=params['transform_params']['cells_to_subsample'],
            num_cells_for_transformer=params['transform_params']['num_cells_for_transformer'],
            use_labels_to_transform_data=params['transform_params']['use_labels_to_transform_data']
        )
        data_input.save_transformer(savepath)
        data_input.normalize_data()
        unused_cluster_gate_inits = init_plot_and_save_gates(
            data_input, params)
        #everything below differs from the other main_UMAP
        data_input.convert_all_data_to_tensors()
        init_gate_tree, unused_cluster_gate_inits = get_next_gate_tree(
            unused_cluster_gate_inits, data_input, params, model=None)
        model = initialize_model(params['model_params'], [init_gate_tree])
        performance_tracker = run_train_model(model, params['train_params'],
                                              data_input)

        model_save_path = os.path.join(savepath, 'model.pkl')
        torch.save(model.state_dict(), model_save_path)

        tracker_save_path = os.path.join(savepath, 'tracker.pkl')
        with open(tracker_save_path, 'wb') as f:
            pickle.dump(performance_tracker, f)
        # Plot all training cells with the learned gates, one label row per
        # cell (each sample's label repeated for its cells).
        results_plotter = DataAndGatesPlotterDepthOne(
            model, np.concatenate(data_input.x_tr))
        #fig, axes = plt.subplots(params['gate_init_params']['n_clusters'], figsize=(1 * params['gate_init_params']['n_clusters'], 3 * params['gate_init_params']['n_clusters']))
        results_plotter.plot_data_with_gates(
            np.array(
                np.concatenate([
                    data_input.y_tr[i] *
                    torch.ones([data_input.x_tr[i].shape[0], 1])
                    for i in range(len(data_input.x_tr))
                ])))

        plt.savefig(os.path.join(savepath, 'final_gates.png'))

        with open(os.path.join(savepath, 'configs.pkl'), 'wb') as f:
            pickle.dump(params, f)

        print('Complete main loop for run %d took %.4f seconds' %
              (run, time.time() - start_time))
        # Restart the timer so each run's reported time covers only that run.
        start_time = time.time()
        print('Accuracy tr %.3f, te %.3f' %
              (performance_tracker.metrics['tr_acc'][-1],
               performance_tracker.metrics['te_acc'][-1]))
        te_accs.append(performance_tracker.metrics['te_acc'][-1])
        tr_accs.append(performance_tracker.metrics['tr_acc'][-1])
    # NOTE(review): if start_seed >= n_runs no run executes, and the
    # mean/std below are taken over empty arrays (numpy returns nan).
    tr_accs = np.array(tr_accs)
    te_accs = np.array(te_accs)
    print('Average tr acc: %.3f, te acc %.3f' %
          (np.mean(tr_accs), np.mean(te_accs)))
    print('Std dev tr acc: %.3f, te_acc %.3f' %
          (np.std(tr_accs), np.std(te_accs)))
# Example #10
# 0
def main(path_to_params):
    """Train a model over all potential gates and save and plot the results.

    Parses the params, splits, embeds and normalizes the data, and builds the
    model from ``get_all_potential_gates``. If ``train_params['fix_gates']``
    is set, the gate parameters are frozen before training. Writes
    params.pkl, the transformer, model.pkl, tracker.pkl and configs.pkl into
    ``params['save_dir']``. For 2-D embeddings it also writes
    final_gates.png; for other embedding dimensions it pickles the positive
    and negative figures instead. When ``params['plot_umap_reflection']`` is
    set, the training data and gates are mirrored about x = .5 before
    plotting.
    """
    start_time = time.time()

    params = TransformParameterParser(path_to_params).parse_params()
    print(params)
    check_consistency_of_params(params)

    # Seed the RNGs from params. An older comment here said to "eventually
    # uncomment" this call; it is already active.
    set_random_seeds(params)

    if not os.path.exists(params['save_dir']):
        os.makedirs(params['save_dir'])

    with open(os.path.join(params['save_dir'], 'params.pkl'), 'wb') as f:
        pickle.dump(params, f)

    data_input = DataInput(params['data_params'])
    data_input.split_data()
    print('%d samples in the training data' % len(data_input.x_tr))
    data_transformer = DataTransformerFactory(
        params['transform_params'],
        params['random_seed']).manufacture_transformer()

    data_input.embed_data_and_fit_transformer(\
        data_transformer,
        cells_to_subsample=params['transform_params']['cells_to_subsample'],
        num_cells_for_transformer=params['transform_params']['num_cells_for_transformer'],
        use_labels_to_transform_data=params['transform_params']['use_labels_to_transform_data']
    )
    # can't pickle opentsne objects
    # NOTE(review): this compares the whole transform_params mapping (it is
    # indexed by string keys above) with the string 'tsne'. That comparison
    # is never equal, so the transformer is always saved, including for
    # t-SNE. It probably should test the transform-type key inside
    # transform_params; confirm the key name.
    if not params['transform_params'] == 'tsne':
        data_input.save_transformer(params['save_dir'])
    data_input.normalize_data()

    potential_gates = get_all_potential_gates(data_input, params)
    data_input.convert_all_data_to_tensors()
    model = initialize_model(params['model_params'], potential_gates)

    if params['train_params']['fix_gates']:
        model.freeze_gate_params()
    tracker = run_train_model(\
        model, params['train_params'], data_input
    )

    # Older incremental gate-by-gate training loop, kept for reference only
    # (not executed):
    #   if params['transform_params']['embed_dim'] == 3:
    #       unused_cluster_gate_inits = init_gates(data_input, params)
    #   else:
    #       unused_cluster_gate_inits = init_plot_and_save_gates(data_input, params)
    #   #everything below differs from the other main_UMAP
    #   data_input.convert_all_data_to_tensors()
    #   init_gate_tree, unused_cluster_gate_inits = get_next_gate_tree(unused_cluster_gate_inits, data_input, params, model=None)
    #   model = initialize_model(params['model_params'], [init_gate_tree])
    #   trackers_per_round = []
    #   num_gates_left = len(unused_cluster_gate_inits)
    #   #print(num_gates_left, 'asdfasdfasdfasdfasdfasdfas')
    #   for i in range(num_gates_left + 1):
    #       performance_tracker = run_train_model(model, params['train_params'], data_input)
    #       trackers_per_round.append(performance_tracker.get_named_tuple_rep())
    #       if i == params['train_params']['num_gates_to_learn'] - 1:
    #           break
    #       if not i == num_gates_left:
    #           next_gate_tree, unused_cluster_gate_inits = get_next_gate_tree(unused_cluster_gate_inits, data_input, params, model=model)
    #           model.add_node(next_gate_tree)

    model_save_path = os.path.join(params['save_dir'], 'model.pkl')
    torch.save(model.state_dict(), model_save_path)

    tracker_save_path = os.path.join(params['save_dir'], 'tracker.pkl')
    #    trackers_per_round = [tracker.get_named_tuple_rep() for tracker in trackers_per_round]
    with open(tracker_save_path, 'wb') as f:
        pickle.dump(tracker, f)
    # The reflection below happens after model.pkl and tracker.pkl are saved,
    # so it only affects the plots.
    if params['plot_umap_reflection']:
        # reflection is about x=.5 since the data is already in umap space here
        reflected_data = []
        for data in data_input.x_tr:
            # In-place write: this also mutates the arrays/tensors held by
            # data_input.x_tr.
            data[:, 0] = 1 - data[:, 0]
            reflected_data.append(data)
        data_input.x_tr = reflected_data
        gate_tree = model.get_gate_tree()
        reflected_gates = []
        for gate in gate_tree:
            print(gate)
            #order switches since reflected over x=.5
            # gate[0] is indexed as [name, low, high]; low and high swap roles
            # after mirroring.
            low_reflected = 1 - gate[0][2]
            high_reflected = 1 - gate[0][1]
            gate[0][1] = low_reflected
            gate[0][2] = high_reflected
            print(gate)

            reflected_gates.append(gate)
        model.init_nodes(reflected_gates)
        # NOTE(review): this prints the init_nodes method object, not the nodes.
        print(model.init_nodes)
        print(model.get_gates())
    results_plotter = DataAndGatesPlotterDepthOne(
        model, np.concatenate(data_input.x_tr))
    #fig, axes = plt.subplots(params['gate_init_params']['n_clusters'], figsize=(1 * params['gate_init_params']['n_clusters'], 3 * params['gate_init_params']['n_clusters']))

    if params['transform_params']['embed_dim'] == 2:
        results_plotter.plot_data_with_gates(
            np.array(
                np.concatenate([
                    data_input.y_tr[i] *
                    torch.ones([data_input.x_tr[i].shape[0], 1])
                    for i in range(len(data_input.x_tr))
                ])))
        plt.savefig(os.path.join(params['save_dir'], 'final_gates.png'))
    else:
        # Non-2-D case: the plotter returns separate positive and negative
        # figures, which are pickled instead of being saved as PNGs.
        fig_pos, ax_pos, fig_neg, ax_neg = results_plotter.plot_data_with_gates(
            np.array(
                np.concatenate([
                    data_input.y_tr[i] *
                    torch.ones([data_input.x_tr[i].shape[0], 1])
                    for i in range(len(data_input.x_tr))
                ])))
        with open(os.path.join(params['save_dir'], 'final_gates_pos_3d.pkl'),
                  'wb') as f:
            pickle.dump(fig_pos, f)

        with open(os.path.join(params['save_dir'], 'final_gates_neg_3d.pkl'),
                  'wb') as f:
            pickle.dump(fig_neg, f)

    with open(os.path.join(params['save_dir'], 'configs.pkl'), 'wb') as f:
        pickle.dump(params, f)

    print('Learned weights:', model.linear.weight)
    print('Complete main loop took %.4f seconds' % (time.time() - start_time))
def main(path_to_params):
    """Two-stage gating: learn a gate, keep only the cells inside it, then re-embed and learn again.

    Stage 1 embeds and normalizes the data and trains ``model1`` from one gate
    tree, saving model1.pkl and tracker1.pkl. Stage 2 filters the data to the
    cells inside model1's gate, refits a fresh transformer on it, trains
    ``model2``, and saves model2.pkl, tracker2.pkl, final_gates.png and
    configs.pkl. All files go into ``params['save_dir']``.
    """
    start_time = time.time()

    params = TransformParameterParser(path_to_params).parse_params()
    print(params)
    check_consistency_of_params(params)

    # Seed the RNGs from params. An older comment here said to "eventually
    # uncomment" this call; it is already active.
    set_random_seeds(params)

    if not os.path.exists(params['save_dir']):
        os.makedirs(params['save_dir'])

    with open(os.path.join(params['save_dir'], 'params.pkl'), 'wb') as f:
        pickle.dump(params, f)

    data_input = DataInput(params['data_params'])
    data_input.split_data()

    # ---- Stage 1: embed all data and learn the first gate ----
    data_transformer = DataTransformerFactory(
        params['transform_params'],
        params['random_seed']).manufacture_transformer()

    data_input.embed_data_and_fit_transformer(\
        data_transformer,
        cells_to_subsample=params['transform_params']['cells_to_subsample'],
        num_cells_for_transformer=params['transform_params']['num_cells_for_transformer'],
        use_labels_to_transform_data=params['transform_params']['use_labels_to_transform_data']
    )
    data_input.save_transformer(params['save_dir'])
    data_input.normalize_data()
    unused_cluster_gate_inits = init_plot_and_save_gates(data_input, params)

    data_input.convert_all_data_to_tensors()

    init_gate_tree, unused_cluster_gate_inits = get_next_gate_tree(
        unused_cluster_gate_inits, data_input, params, model=None)
    model1 = initialize_model(params['model_params'], [init_gate_tree])

    performance_tracker1 = run_train_model(model1, params['train_params'],
                                           data_input)

    model1_save_path = os.path.join(params['save_dir'], 'model1.pkl')
    torch.save(model1.state_dict(), model1_save_path)

    tracker1_save_path = os.path.join(params['save_dir'], 'tracker1.pkl')
    with open(tracker1_save_path, 'wb') as f:
        pickle.dump(performance_tracker1, f)

    # ---- Stage 2: keep cells inside model1's gate, re-embed, learn again ----
    # now select the data inside the learned model1 gate and re-run umap
    data_input.filter_data_inside_first_model_gate(model1)
    # NOTE(review): the gates are initialized here, BEFORE the re-embedding
    # below, so the initializations are computed on the pre-re-embedding
    # data. Confirm whether this call should come after
    # embed_data_and_fit_transformer.
    unused_cluster_gate_inits = init_plot_and_save_gates(data_input, params)

    data_transformer = DataTransformerFactory(
        params['transform_params'],
        params['random_seed']).manufacture_transformer()

    data_input.embed_data_and_fit_transformer(\
        data_transformer,
        cells_to_subsample=params['transform_params']['cells_to_subsample'],
        num_cells_for_transformer=params['transform_params']['num_cells_for_transformer'],
        use_labels_to_transform_data=params['transform_params']['use_labels_to_transform_data']
    )
    # NOTE(review): this saves into the same save_dir as stage 1, which may
    # overwrite the stage-1 transformer depending on save_transformer's
    # filename. Also, unlike stage 1, normalize_data() is not called after
    # this re-embedding. Confirm both are intended.
    data_input.save_transformer(params['save_dir'])
    data_input.convert_all_data_to_tensors()

    init_gate_tree, _ = get_next_gate_tree(unused_cluster_gate_inits,
                                           data_input,
                                           params,
                                           model=None)
    model2 = initialize_model(params['model_params'], [init_gate_tree])

    performance_tracker2 = run_train_model(model2, params['train_params'],
                                           data_input)

    model2_save_path = os.path.join(params['save_dir'], 'model2.pkl')
    torch.save(model2.state_dict(), model2_save_path)

    tracker2_save_path = os.path.join(params['save_dir'], 'tracker2.pkl')
    with open(tracker2_save_path, 'wb') as f:
        pickle.dump(performance_tracker2, f)

    results_plotter = DataAndGatesPlotterDepthOne(
        model2, np.concatenate(data_input.x_tr))
    #fig, axes = plt.subplots(params['gate_init_params']['n_clusters'], figsize=(1 * params['gate_init_params']['n_clusters'], 3 * params['gate_init_params']['n_clusters']))
    results_plotter.plot_data_with_gates(
        np.array(
            np.concatenate([
                data_input.y_tr[i] *
                torch.ones([data_input.x_tr[i].shape[0], 1])
                for i in range(len(data_input.x_tr))
            ])))

    plt.savefig(os.path.join(params['save_dir'], 'final_gates.png'))

    with open(os.path.join(params['save_dir'], 'configs.pkl'), 'wb') as f:
        pickle.dump(params, f)

    print('Complete main loop took %.4f seconds' % (time.time() - start_time))
def main(path_to_params):
    """Train a model by adding heuristically-initialized gates one at a time.

    After the data is embedded and normalized, a
    ``MultipleGateInitializerHeuristic`` proposes gates. The model is trained,
    the initializer is told about the current gates, and a new gate is added.
    This repeats until ``gate_init_multi_heuristic_params['num_gates']`` gates
    have been trained or the initializer has nothing left. Writes params.pkl,
    the transformer, model.pkl, tracker.pkl (the tracker from the last
    training step) and final_gates.png into ``params['save_dir']``.

    Raises:
        ValueError: if ``num_gates`` is less than 1. No training step would
            run, so there would be no tracker to save.
    """
    start_time = time.time()

    params = TransformParameterParser(path_to_params).parse_params()
    print(params)
    check_consistency_of_params(params)

    # Validate before the expensive embedding and training steps. With zero
    # gates the loop below never runs and no tracker would exist.
    num_gates = params['gate_init_multi_heuristic_params']['num_gates']
    if num_gates < 1:
        raise ValueError('num_gates must be at least 1, got %r' % (num_gates,))

    set_random_seeds(params)

    if not os.path.exists(params['save_dir']):
        os.makedirs(params['save_dir'])

    with open(os.path.join(params['save_dir'], 'params.pkl'), 'wb') as f:
        pickle.dump(params, f)

    data_input = DataInput(params['data_params'])
    data_input.split_data()

    data_transformer = DataTransformerFactory(
        params['transform_params'],
        params['random_seed']).manufacture_transformer()
    data_input.embed_data_and_fit_transformer(
        data_transformer,
        params['transform_params']['cells_to_subsample'],
        params['transform_params']['num_cells_for_transformer']
    )
    data_input.save_transformer(params['save_dir'])
    data_input.normalize_data()
    #everything below differs from the other main_UMAP

    multi_gate_initializer = MultipleGateInitializerHeuristic(
        data_input, params['model_params']['node_type'],
        params['gate_init_multi_heuristic_params'])
    init_gate_tree = [multi_gate_initializer.init_next_gate()]

    model = initialize_model(params['model_params'], init_gate_tree)
    data_input.prepare_data_for_training()
    performance_tracker = None
    for i in range(num_gates):
        performance_tracker = run_train_model(model, params['train_params'],
                                              data_input)
        # Tell the initializer about the trained gates so the next proposal
        # can take them into account.
        multi_gate_initializer.gates = model.get_gates()
        if i == num_gates - 1:
            break
        print(model.get_gates())
        next_gate = multi_gate_initializer.init_next_gate()
        if next_gate is None:
            print(
                'There are no non-overlapping initializations left to try!'
            )
            break
        model.add_node(next_gate)

    model_save_path = os.path.join(params['save_dir'], 'model.pkl')
    torch.save(model.state_dict(), model_save_path)

    tracker_save_path = os.path.join(params['save_dir'], 'tracker.pkl')
    with open(tracker_save_path, 'wb') as f:
        pickle.dump(performance_tracker, f)
    results_plotter = DataAndGatesPlotterDepthOne(
        model, np.concatenate(data_input.x_tr))
    # Repeat each sample-level label once per cell of that sample.
    results_plotter.plot_data_with_gates(
        np.array(
            np.concatenate([
                data_input.y_tr[i] *
                torch.ones([data_input.x_tr[i].shape[0], 1])
                for i in range(len(data_input.x_tr))
            ])))

    plt.savefig(os.path.join(params['save_dir'], 'final_gates.png'))
    print('Complete main loop took %.4f seconds' % (time.time() - start_time))
def main_kde_expansion(path_to_params):
    """Expand the first learned gate with a KDE-based expander and re-evaluate.

    Loads the saved results from ``path_to_params`` and expands the model's
    first gate with ``KDEGateExpander``. It re-initializes the model with the
    expanded D1/D2 bounds, refits the classifier, and prints the train and
    test accuracy and log loss. Finally it saves
    ``expanded_data_with_all_data.png`` in the working directory: expanded
    cells in red over the training data and gates, label-1 cells on the first
    axis and label-0 cells on the second.
    """
    start_time = time.time()
    step_size = .001
    sigma_thresh_factor = .5

    data_input, model, params = load_saved_results(path_to_params,
                                                   ret_params_too=True)
    init_gate = model.get_gates()[0]
    gate_expander = KDEGateExpander(data_input.x_tr,
                                    init_gate,
                                    step_size=step_size,
                                    sigma_thresh_factor=sigma_thresh_factor)
    # Indexed below as [D1 low, D1 high, D2 low, D2 high].
    expanded_gates = gate_expander.expand_gates()

    gate_expander.collect_expanded_cells_per_sample()

    # rerun logistic regressor to get the new accuracy after expansion
    expanded_gate_tree = [[['D1', expanded_gates[0], expanded_gates[1]],
                           ['D2', expanded_gates[2], expanded_gates[3]]]]
    model.init_nodes(expanded_gate_tree)
    train_params = params['train_params']
    fit_classifier_params(model,
                          data_input,
                          train_params['learning_rate_classifier'],
                          l1_reg_strength=train_params['l1_reg_strength'])

    acc_tr, loss_tr = _accuracy_and_log_loss(model, data_input.x_tr,
                                             data_input.y_tr)
    print('Tr acc/loss after expansion: %.4f, %.4f' % (acc_tr, loss_tr))

    acc_te, loss_te = _accuracy_and_log_loss(model, data_input.x_te,
                                             data_input.y_te)
    print('Te acc/loss after expansion: %.4f, %.4f' % (acc_te, loss_te))

    all_expanded_data = np.concatenate(gate_expander.expanded_data_per_sample)
    print(all_expanded_data.shape, 'all_expanded_data')

    cell_level_labels = get_catted_cell_level_labels_of_expanded_data(
        gate_expander.expanded_data_per_sample, data_input.y_tr)
    catted_tr_data = np.concatenate(data_input.x_tr)
    plotter = DataAndGatesPlotterDepthOne(model, catted_tr_data)
    # Repeat each sample-level label once per cell of that sample.
    _, axes = plotter.plot_data_with_gates(
        np.array(
            np.concatenate([
                data_input.y_tr[i] *
                torch.ones([data_input.x_tr[i].shape[0], 1])
                for i in range(len(data_input.x_tr))
            ])))

    size = 1000 / catted_tr_data.shape[0]
    pos_cells = all_expanded_data[cell_level_labels == 1, :]
    neg_cells = all_expanded_data[cell_level_labels == 0, :]
    axes[0].scatter(pos_cells[:, 0], pos_cells[:, 1], color='r', s=size)
    axes[1].scatter(neg_cells[:, 0], neg_cells[:, 1], color='r', s=size)
    plt.savefig('expanded_data_with_all_data.png')
    print('total time %.3f' % (time.time() - start_time))


def _accuracy_and_log_loss(model, x, y):
    """Return ``(accuracy, log_loss)`` for ``model`` on inputs ``x`` and labels ``y``.

    ``model(x, y)`` must return a dict with 'y_pred' and 'log_loss'.
    Predictions are thresholded at 0.5, and the accuracy is the fraction of
    predictions equal to ``y``. ``y`` must support ``.cpu()`` (a torch tensor).
    """
    output = model(x, y)
    y_true = y.cpu().numpy()
    y_pred = (output['y_pred'].cpu().detach().numpy() >= 0.5) * 1.0
    y_pred = y_pred.reshape(y_true.shape)
    accuracy = sum(y_pred == y_true) * 1.0 / y.shape[0]
    return accuracy, output['log_loss']