示例#1
0
def init_data_input(params, transformer_path):
    """Build a training-ready DataInput using a previously pickled transformer.

    Args:
        params: parsed parameter dict. Reads ``params['data_params']`` and the
            ``'cells_to_subsample'`` / ``'num_cells_for_transformer'`` entries
            of ``params['transform_params']``.
        transformer_path: path to a pickled data transformer object.

    Returns:
        The DataInput after split_data(), embed_data() with the unpickled
        transformer, normalize_data() and prepare_data_for_training().
    """
    data_input = DataInput(params['data_params'])
    data_input.split_data()

    # NOTE(review): pickle.load can execute arbitrary code stored in the file;
    # only point transformer_path at transformers saved by a trusted run.
    with open(transformer_path, 'rb') as transformer_file:
        loaded_transformer = pickle.load(transformer_file)
    print(loaded_transformer)

    transform_params = params['transform_params']
    data_input.embed_data(
        loaded_transformer,
        transform_params['cells_to_subsample'],
        transform_params['num_cells_for_transformer'],
    )

    data_input.normalize_data()
    data_input.prepare_data_for_training()
    return data_input
def main(path_to_params):
    """Fit a transformer, train a gating model, and save all results to disk.

    Parses the parameter file with TransformParameterParser, then writes into
    ``params['save_dir']``: ``params.pkl`` (the parsed params), the fitted
    transformer (via ``DataInput.save_transformer``), ``model.pkl`` (the
    model's ``state_dict``), ``tracker.pkl`` (the tracker returned by
    ``run_train_model``) and ``final_gates.png`` (the gates plotted over the
    training data).

    Args:
        path_to_params: path to a parameter file understood by
            TransformParameterParser.
    """
    start_time = time.time()

    params = TransformParameterParser(path_to_params).parse_params()
    print(params)

    save_dir = params['save_dir']
    # exist_ok=True avoids the race in "check os.path.exists, then makedirs".
    os.makedirs(save_dir, exist_ok=True)

    with open(os.path.join(save_dir, 'params.pkl'), 'wb') as f:
        pickle.dump(params, f)

    data_input = DataInput(params['data_params'])
    data_input.split_data()

    transform_params = params['transform_params']
    data_transformer = DataTransformerFactory(
        transform_params,
        params['random_seed']).manufacture_transformer()
    data_input.embed_data_and_fit_transformer(
        data_transformer,
        transform_params['cells_to_subsample'],
        transform_params['num_cells_for_transformer'])
    data_input.save_transformer(save_dir)
    data_input.normalize_data()
    init_gate_tree = init_plot_and_save_gates(data_input, params)

    model = initialize_model(params['model_params'], init_gate_tree)
    data_input.prepare_data_for_training()
    performance_tracker = run_train_model(model, params['train_params'],
                                          data_input)

    model_save_path = os.path.join(save_dir, 'model.pkl')
    torch.save(model.state_dict(), model_save_path)

    tracker_save_path = os.path.join(save_dir, 'tracker.pkl')
    with open(tracker_save_path, 'wb') as f:
        pickle.dump(performance_tracker, f)

    results_plotter = DataAndGatesPlotterDepthOne(
        model, np.concatenate(data_input.x_tr))
    # One label row per cell: sample i's label y_tr[i] is repeated for each of
    # the x_tr[i].shape[0] rows of that sample (indexing y_tr keeps the
    # original IndexError if y_tr is shorter than x_tr).
    cell_labels = np.concatenate([
        data_input.y_tr[i] * torch.ones([x_sample.shape[0], 1])
        for i, x_sample in enumerate(data_input.x_tr)
    ])
    results_plotter.plot_data_with_gates(cell_labels)

    plt.savefig(os.path.join(save_dir, 'final_gates.png'))
    elapsed = time.time() - start_time
    print(f'Complete main loop took {elapsed:.4f} seconds')
def main(path_to_params):
    """Train a gating model, adding heuristic gate initializations one at a time.

    Parses the parameter file with TransformParameterParser, fits the data
    transformer, then alternates training and adding the next gate proposed
    by MultipleGateInitializerHeuristic. The loop runs up to
    ``params['gate_init_multi_heuristic_params']['num_gates']`` training
    rounds and stops early when no new gate can be initialized. Results are
    written into ``params['save_dir']``: ``params.pkl``, the fitted
    transformer, ``model.pkl`` (``state_dict``), ``tracker.pkl`` (tracker from
    the last training round) and ``final_gates.png``.

    Args:
        path_to_params: path to a parameter file understood by
            TransformParameterParser.

    Raises:
        ValueError: if ``num_gates`` is less than 1. Without at least one
            training round there is no tracker to save.
    """
    start_time = time.time()

    params = TransformParameterParser(path_to_params).parse_params()
    print(params)
    check_consistency_of_params(params)

    # Validate before any expensive work or file creation; previously
    # num_gates < 1 surfaced as an UnboundLocalError on performance_tracker
    # only after data loading and transformer fitting.
    num_gates = params['gate_init_multi_heuristic_params']['num_gates']
    if num_gates < 1:
        raise ValueError(
            "gate_init_multi_heuristic_params['num_gates'] must be >= 1, "
            f"got {num_gates!r}")

    set_random_seeds(params)

    save_dir = params['save_dir']
    # exist_ok=True avoids the race in "check os.path.exists, then makedirs".
    os.makedirs(save_dir, exist_ok=True)

    with open(os.path.join(save_dir, 'params.pkl'), 'wb') as f:
        pickle.dump(params, f)

    data_input = DataInput(params['data_params'])
    data_input.split_data()

    transform_params = params['transform_params']
    data_transformer = DataTransformerFactory(
        transform_params,
        params['random_seed']).manufacture_transformer()
    data_input.embed_data_and_fit_transformer(
        data_transformer,
        transform_params['cells_to_subsample'],
        transform_params['num_cells_for_transformer'])
    data_input.save_transformer(save_dir)
    data_input.normalize_data()
    # NOTE: everything below this point differs from the other main_UMAP script.

    multi_gate_initializer = MultipleGateInitializerHeuristic(
        data_input, params['model_params']['node_type'],
        params['gate_init_multi_heuristic_params'])
    init_gate_tree = [multi_gate_initializer.init_next_gate()]

    model = initialize_model(params['model_params'], init_gate_tree)
    data_input.prepare_data_for_training()

    for step in range(num_gates):
        performance_tracker = run_train_model(model, params['train_params'],
                                              data_input)
        # Keep the initializer aware of the trained gates so that its next
        # proposal is based on the current model.
        multi_gate_initializer.gates = model.get_gates()
        if step == num_gates - 1:
            break  # final round: no further gate is added
        print(model.get_gates())
        next_gate = multi_gate_initializer.init_next_gate()
        if next_gate is None:
            print('There are no non-overlapping initializations left to try!')
            break
        model.add_node(next_gate)

    model_save_path = os.path.join(save_dir, 'model.pkl')
    torch.save(model.state_dict(), model_save_path)

    tracker_save_path = os.path.join(save_dir, 'tracker.pkl')
    with open(tracker_save_path, 'wb') as f:
        pickle.dump(performance_tracker, f)

    results_plotter = DataAndGatesPlotterDepthOne(
        model, np.concatenate(data_input.x_tr))
    # One label row per cell: sample i's label y_tr[i] is repeated for each of
    # the x_tr[i].shape[0] rows of that sample (indexing y_tr keeps the
    # original IndexError if y_tr is shorter than x_tr).
    cell_labels = np.concatenate([
        data_input.y_tr[i] * torch.ones([x_sample.shape[0], 1])
        for i, x_sample in enumerate(data_input.x_tr)
    ])
    results_plotter.plot_data_with_gates(cell_labels)

    plt.savefig(os.path.join(save_dir, 'final_gates.png'))
    elapsed = time.time() - start_time
    print(f'Complete main loop took {elapsed:.4f} seconds')