Example #1
0
def test_initialization():
    """Estimator must reject a missing config file and accept a valid one."""
    # A path that does not exist on disk should surface as FileNotFoundError.
    with pytest.raises(FileNotFoundError):
        Estimator(base_config='non_existent.yaml')

    # Building from the shared test yaml must succeed without raising.
    Estimator(base_config=test_base_yaml)
Example #2
0
def main(argv):
    """Run a PZ estimator end to end from command-line arguments.

    Accepted forms of ``argv``:

    * ``[prog, input_yaml]`` -- uses ``'base.yaml'`` as the base config.
    * ``[prog, base_config, input_yaml]``

    The estimator subclass is looked up from the input yaml's file stem
    (the part of the last path component before the first ``.``).

    Raises:
        SystemExit: with status 1 when ``argv`` has the wrong length.
        ValueError: if no Estimator subclass matches the file stem.
    """
    if len(argv) == 2:
        # this is in case hiding the base yaml is wanted
        base_config = 'base.yaml'
        input_yaml = argv[1]
    elif len(argv) == 3:
        base_config = argv[1]
        input_yaml = argv[2]
    else:
        print("Usage: main [base config yaml] <yaml file>")
        # Non-zero status so shells/CI treat a bad invocation as a failure.
        sys.exit(1)
    # The file stem of the input yaml selects the estimator subclass.
    name = input_yaml.split("/")[-1].split(".")[0]

    with open(input_yaml, 'r') as f:
        run_dict = yaml.safe_load(f)

    print(run_dict)

    # Resolve the subclass once and reuse it for the config and construction.
    try:
        code = Estimator._find_subclass(name)
    except KeyError as err:
        raise ValueError(
            f"Class name {name} for PZ code is not defined") from err
    run_dict['class_name'] = code
    print(f"code name: {code}")

    pz = code(base_config, run_dict)
    pz.inform()

    outf = initialize_writeout(pz.saveloc, pz.num_rows, pz.nzbins)

    for start, end, data in iter_chunk_hdf5_data(pz.testfile, pz._chunk_size,
                                                 'photometry'):
        pz_dict = pz.estimate(data)
        write_out_chunk(outf, pz_dict, start, end)
        # Per-chunk progress; the final "finished" is printed once below.
        print("writing " + name + f"[{start}:{end}]")

    finalize_writeout(outf, pz.zgrid)

    print("finished")
Example #3
0
def test_writing(tmpdir):
    """write_output_file must create the target file on disk."""
    estimator = Estimator(test_base_yaml)
    estimator.zmode = 0
    estimator.zgrid = np.arange(0, 1, 0.2)
    estimator.pz_pdf = np.ones(5)
    estimator.saveloc = tmpdir.join("test.hdf5")
    estimator.nzbins = len(estimator.zgrid)
    payload = dict(zmode=estimator.zmode, pz_pdf=estimator.pz_pdf)
    write_output_file(
        estimator.saveloc,
        estimator.num_rows,
        estimator.nzbins,
        payload,
        estimator.zgrid,
    )

    assert os.path.exists(estimator.saveloc)
Example #4
0
def main(argv):
    """Run a PZ estimator end to end from command-line arguments.

    Accepted forms of ``argv``:

    * ``[prog, config_yaml]`` -- uses ``'base.yaml'`` as the base config.
    * ``[prog, config_yaml, base_config]``

    The config yaml must contain ``run_params`` with ``class_name`` and
    ``inform_options`` (whose ``load_model`` flag picks between loading a
    pretrained model and training). ``run_params['run_name']`` is optional
    and defaults to ``"output"`` for naming the output files.

    Raises:
        SystemExit: with status 1 when ``argv`` has the wrong length.
        ValueError: if ``class_name`` does not name an Estimator subclass.
    """
    if len(argv) == 2:
        # this is in case hiding the base yaml is wanted
        input_yaml = argv[1]
        base_config = 'base.yaml'
    elif len(argv) == 3:
        input_yaml = argv[1]
        base_config = argv[2]
    else:
        print("Usage: main <config yaml file> [base config yaml]")
        # Non-zero status so shells/CI treat a bad invocation as a failure.
        sys.exit(1)

    with open(input_yaml, 'r') as f:
        run_dict = yaml.safe_load(f)

    run_params = run_dict['run_params']
    name = run_params['class_name']

    # Resolve the subclass once; a KeyError means the name is unknown.
    try:
        code = Estimator._find_subclass(name)
    except KeyError as err:
        raise ValueError(
            f"Class name {name} for PZ code is not defined") from err
    print(f"code name: {name}")

    pz = code(base_config, run_dict)

    pz.inform_dict = run_params['inform_options']
    if pz.inform_dict['load_model']:
        # note: specific options set in subclasss func def
        pz.load_pretrained_model()
    else:
        trainfile = pz.trainfile
        train_fmt = trainfile.split(".")[-1]
        training_data = load_training_data(trainfile, train_fmt, pz.groupname)
        pz.inform(training_data)

    # Fall back to "output" when no run_name is given. The qp branch then
    # still gets a temp file and an output name instead of crashing.
    run_name = run_params.get('run_name', 'output')
    use_qp = pz.output_format == 'qp'
    outfile = run_name + ('_qp.hdf5' if use_qp else '.hdf5')
    saveloc = os.path.join(pz.outpath, name, outfile)

    if use_qp:
        # Per-chunk qp output goes to this temp file before reformatting.
        tmploc = os.path.join(pz.outpath, name, "temp_" + run_name + '.hdf5')
        initialize_qp_output(saveloc)
    else:
        outf = initialize_writeout(saveloc, pz.num_rows, pz.nzbins)

    # Count chunks as they are processed. This also stays valid (0) when the
    # iterator yields nothing, instead of reading an unbound ``end``.
    num_chunks = 0
    for chunk, (start, end, data) in enumerate(
            iter_chunk_hdf5_data(pz.testfile, pz._chunk_size, 'photometry')):
        pz_data_chunk = pz.estimate(data)
        if use_qp:
            write_qp_output_chunk(tmploc, saveloc, pz_data_chunk, chunk)
        else:
            write_out_chunk(outf, pz_data_chunk, start, end)
        print("writing " + name + f"[{start}:{end}]")
        num_chunks = chunk + 1

    if use_qp:
        qp_reformat_output(tmploc, saveloc, num_chunks)
    else:
        finalize_writeout(outf, pz.zgrid)

    print("finished")
Example #5
0
def test_find_subclass():
    """Looking up the 'randomPZ' subclass must not raise."""
    Estimator._find_subclass('randomPZ')
Example #6
0
def test_estimate_not_implemented():
    """The base Estimator.estimate must raise NotImplementedError."""
    fake_data = {'u': 99., 'g': 99., 'r': 99.}
    # Construct outside pytest.raises, so a NotImplementedError from the
    # constructor cannot make this test pass by accident.
    instance = Estimator(base_config=test_base_yaml)
    with pytest.raises(NotImplementedError):
        instance.estimate(fake_data)
Example #7
0
def test_init_with_dict():
    """Estimator can be initialized from an already-loaded config dict."""
    # Use a context manager so the yaml file handle is closed, not leaked.
    with open(test_base_yaml) as f:
        d = yaml.safe_load(f)['base_config']
    _ = Estimator(d)