Exemplo n.º 1
0
def process_catalog_csv_pmf(csv="../data/boss_catalog.csv",
                            model_checkpoint=default_model,
                            output_dir="../tmp/visuals/",
                            kernel_size=400):
    """Run ``process_catalog`` on every sightline listed in a CSV file.

    The file is parsed as comma-separated 64-bit integers.  The first three
    columns of each row are passed, in order, to ``Id_DR12`` (presumably
    plate, MJD and fiber, going by the function name -- confirm against
    ``Id_DR12``).

    :param csv: path of the comma-separated catalog file
    :param model_checkpoint: forwarded to ``process_catalog`` as ``model_path``
    :param output_dir: forwarded to ``process_catalog`` as ``output_dir``
    :param kernel_size: forwarded to ``process_catalog`` as ``kernel_size``
    """
    # ndmin=2 keeps a one-line catalog as a single row.  Without it loadtxt
    # squeezes the result to 1-D, the loop below walks over scalars and
    # ``row[0]`` fails.
    pmf = np.loadtxt(csv, dtype=np.int64, delimiter=',', ndmin=2)
    ids = [Id_DR12(row[0], row[1], row[2]) for row in pmf]
    process_catalog(ids,
                    model_path=model_checkpoint,
                    output_dir=output_dir,
                    kernel_size=kernel_size)
Exemplo n.º 2
0
def preprocess_data_from_dr9(kernel=400,
                             stride=3,
                             pos_sample_kernel_percent=0.3,
                             train_keys_csv="../data/dr9_train_set.csv",
                             test_keys_csv="../data/dr9_test_set.csv"):
    """Build train/test sightline lists from CSV files and prepare the
    localization training set.

    Both files are read with ``np.genfromtxt`` (comma-delimited).  For each
    row, columns 0-2 are passed to ``Id_DR12`` and columns 3-4 to ``Dla``;
    every resulting ``Sightline`` carries exactly one ``Dla``.

    :param kernel: forwarded as the first argument of
        ``prepare_localization_training_set``
    :param stride: forwarded as the second argument
    :param pos_sample_kernel_percent: forwarded as the third argument
    :param train_keys_csv: path of the training-set CSV
    :param test_keys_csv: path of the test-set CSV
    """
    def _rows_to_sightlines(rows):
        # One Sightline per row, each holding a single-element Dla list.
        return [
            Sightline(Id_DR12(row[0], row[1], row[2]), [Dla(row[3], row[4])])
            for row in rows
        ]

    train_rows = np.genfromtxt(train_keys_csv, delimiter=',')
    test_rows = np.genfromtxt(test_keys_csv, delimiter=',')

    # No de-duplication is done here: the original author noted that
    # dr9_train contains no duplicate keys, so the step was skipped.
    sightlines_train = _rows_to_sightlines(train_rows)
    sightlines_test = _rows_to_sightlines(test_rows)

    prepare_localization_training_set(
        kernel,
        stride,
        pos_sample_kernel_percent,
        sightlines_train,
        sightlines_test,
    )
Exemplo n.º 3
0
def add_s2n(outfile='visuals_dr12/predictions_BOSSDR12_s2n.json'):
    """Run ``add_s2n_after`` over the packaged DR12 catalog and save the
    result as JSON.

    The catalog ``catalogs/boss_dr12/dr12_set.csv`` is located inside the
    ``dla_cnn`` package, read as a table, and the first five columns of each
    row are passed to ``Id_DR12``.  The ids are handed to ``add_s2n_after``
    together with the fixed input file ``visuals_dr12/predictions_DR12.json``
    and whatever it returns is written to ``outfile``.

    :param outfile: path of the JSON file to write (opened in ``'w'`` mode)
    """
    from dla_cnn.data_model.Id_DR12 import Id_DR12

    catalog_path = resource_filename('dla_cnn',
                                     "catalogs/boss_dr12/dr12_set.csv")
    catalog = Table.read(catalog_path)
    ids = [
        Id_DR12(row[0], row[1], row[2], row[3], row[4])
        for row in catalog
    ]

    # Existing predictions file that add_s2n_after receives alongside the ids.
    predictions_json = 'visuals_dr12/predictions_DR12.json'
    predictions = add_s2n_after(ids, predictions_json, CHUNK_SIZE=1000)

    # Serialize the returned object, pretty-printed with a 4-space indent.
    with open(outfile, 'w') as out:
        json.dump(predictions, out, indent=4)
Exemplo n.º 4
0
def process_catalog_fits_pmf(fits_dir="../../BOSS_dat_all",
                             model_checkpoint=default_model,
                             output_dir="../tmp/visuals/",
                             kernel_size=400):
    """Run ``process_catalog`` on every ``*.fits`` file found in a directory.

    Each file path must end in three dash-separated integer groups followed
    by an extension (``...-<int>-<int>-<int>.<ext>``).  The three integers
    are passed, in order, to ``Id_DR12`` (presumably plate, MJD and fiber,
    going by the function name -- confirm against ``Id_DR12``).

    :param fits_dir: directory searched (non-recursively) for ``*.fits`` files
    :param model_checkpoint: forwarded to ``process_catalog`` as ``model_path``
    :param output_dir: forwarded to ``process_catalog`` as ``output_dir``
    :param kernel_size: forwarded to ``process_catalog`` as ``kernel_size``
    :raises SystemExit: with status 1 if a file name does not match the
        expected pattern
    """
    import sys

    # Compiled once: the same pattern is applied to every path.
    id_pattern = re.compile(r'.*-(\d+)-(\d+)-(\d+)\..*')

    ids = []
    for f in glob.glob(fits_dir + "/*.fits"):
        match = id_pattern.match(f)
        if not match:
            print("Match failed on: ", f)
            # sys.exit(1) rather than the interactive-only ``exit()`` helper:
            # it is always available and reports the failure with a non-zero
            # status instead of exiting as if everything succeeded.
            sys.exit(1)
        ids.append(
            Id_DR12(int(match.group(1)), int(match.group(2)),
                    int(match.group(3))))

    process_catalog(ids,
                    kernel_size=kernel_size,
                    model_path=model_checkpoint,
                    output_dir=output_dir)
Exemplo n.º 5
0
def process_catalog_dr12(csv_plate_mjd_fiber="../data/dr12_test_set.csv",
                         kernel_size=400,
                         pfiber=None,
                         make_pdf=False,
                         model_checkpoint=default_model,
                         output_dir="../tmp/visuals_dr12"):
    """Run ``process_catalog`` on the sightlines of a DR12 catalog table.

    The table is read with ``Table.read`` and the first five columns of each
    row are passed to ``Id_DR12``.  When ``pfiber`` is given, only the single
    id whose ``plate`` and ``fiber`` attributes equal ``pfiber[0]`` and
    ``pfiber[1]`` is processed.

    :param csv_plate_mjd_fiber: path of the catalog table
    :param kernel_size: forwarded to ``process_catalog`` as ``kernel_size``
    :param pfiber: optional ``(plate, fiber)`` pair selecting one sightline;
        ``None`` processes the whole catalog
    :param make_pdf: forwarded to ``process_catalog`` as ``make_pdf``
    :param model_checkpoint: forwarded to ``process_catalog`` as ``model_path``
    :param output_dir: forwarded to ``process_catalog`` as ``output_dir``
    :raises ValueError: if ``pfiber`` does not match exactly one catalog row
    """
    csv = Table.read(csv_plate_mjd_fiber)
    ids = [Id_DR12(c[0], c[1], c[2], c[3], c[4]) for c in csv]
    if pfiber is not None:
        plates = np.array([iid.plate for iid in ids])
        fibers = np.array([iid.fiber for iid in ids])
        imt = np.where((plates == pfiber[0]) & (fibers == pfiber[1]))[0]
        if len(imt) != 1:
            # Fail loudly instead of dropping into a debugger: continuing
            # from the old breakpoint silently processed the whole catalog.
            raise ValueError(
                "Plate/Fiber not in DR12!! Expected exactly one row for "
                "plate={}, fiber={}; found {}".format(
                    pfiber[0], pfiber[1], len(imt)))
        ids = [ids[imt[0]]]
    # Keyword arguments, matching the other process_catalog callers.
    process_catalog(ids,
                    kernel_size=kernel_size,
                    model_path=model_checkpoint,
                    CHUNK_SIZE=500,
                    make_pdf=make_pdf,
                    output_dir=output_dir)