예제 #1
0
def test_array_param():
    """Array params accept numpy arrays and reject any other value."""
    config_def = {
        'creator': ParamReturnCreator(param.Array),
        'val': np.arange(1, 6),
    }

    result = process_config(config_def)
    assert np.all(result == np.array([1, 2, 3, 4, 5]))

    # Neither a plain scalar nor a Python list counts as an array, so both
    # must make processing fail with ParamError.
    for bad_val in (10, [1, 2, 3, 4, 5]):
        config_def['val'] = bad_val
        try:
            process_config(config_def)
        except param.ParamError:
            pass
        else:
            assert False
예제 #2
0
def test_iterable_param():
    """Iterable params accept lists (optionally with a val_type) and numpy arrays."""
    config_def = {
        'item1': {
            'creator': ParamReturnCreator(param.Iterable),
            'val': [1, 2, 3, 4, 5]
        },
        'item2': {
            # ``np.int`` was only an alias for the builtin ``int``. It was
            # deprecated in NumPy 1.20 and removed in 1.24 (AttributeError),
            # so the builtin is used directly.
            'creator': ParamReturnCreator(param.Iterable, val_type=int),
            'val': [1, 2, 3, 4, 5]
        },
    }

    config_inst = process_config(config_def)

    assert np.all(np.array(config_inst['item1']) == np.arange(1, 6))
    assert np.all(np.array(config_inst['item2']) == np.arange(1, 6))

    # Check that numpy arrays are also considered iterable.
    config_def['item1']['val'] = np.arange(1, 6)

    config_inst = process_config(config_def)

    assert np.all(np.array(config_inst['item1']) == np.arange(1, 6))
예제 #3
0
def test_param_choice():
    """A Choice param accepts any of its alternatives and rejects other values."""
    class ChoiceCreator(creator.base.Creator):
        # Either an int scalar or a float scalar is acceptable.
        some_val = param.Choice(param.Scalar(int), param.Scalar(float))

        def create(self, **kwargs):
            return self.param("some_val")

    config_def = {
        'item1': {'creator': ChoiceCreator, 'some_val': 10},
        'item2': {'creator': ChoiceCreator, 'some_val': 5.0},
    }

    processed = process_config(config_def)

    assert processed['item1'] == 10
    assert processed['item2'] == 5.0

    # A string matches neither alternative, so processing must fail.
    config_def['item1']['some_val'] = "ABC"
    try:
        process_config(config_def)
    except param.ParamError:
        pass
    else:
        assert False
예제 #4
0
def test_nested():
    """A param value that is itself a config definition gets processed too."""
    config_def = {
        'item1': {
            'creator': ParamReturnCreator(param.Scalar),
            'val': 5,
        },
        'item2': {
            'creator': ParamReturnCreator(param.Scalar),
            'val': 10,
        },
        'item3': {
            'creator': ParamReturnCreator(param.Scalar),
            # Nested definition: its processed result becomes item3's value.
            'val': {
                'creator': ParamReturnCreator(param.Scalar),
                'val': 15,
            },
        },
    }

    processed = process_config(config_def)

    expected = (('item1', 5), ('item2', 10), ('item3', 15))
    for key, value in expected:
        assert processed[key] == value
예제 #5
0
def test_array_with_unit_param():
    """Smoke test: an ArrayWithUnit param accepts an rf.ArrayWithUnit value.

    Only checks that processing completes without raising; the processed
    result is not inspected.
    """
    values = np.arange(1, 6)
    config_def = {
        'creator': ParamReturnCreator(param.ArrayWithUnit),
        'val': rf.ArrayWithUnit_double_1(values, "m"),
    }

    process_config(config_def)
예제 #6
0
def test_callable():
    """A callable param value is invoked, and its return value is what gets checked."""
    def make_value(**kwargs):
        return 10

    config_def = {
        'creator': ParamReturnCreator(param.Scalar),
        'val': make_value,
    }

    assert process_config(config_def) == 10
예제 #7
0
def test_scalar_with_type():
    """Scalar params built with ``dtype=str`` accept strings and reject ints."""
    str_scalar = partial(param.Scalar, dtype=str)
    config_def = {
        'creator': ParamReturnCreator(str_scalar),
        'val': "10",
    }

    assert process_config(config_def) == "10"

    # An int is a scalar but has the wrong dtype, so it must be rejected.
    config_def['val'] = 10
    try:
        process_config(config_def)
    except param.ParamError:
        pass
    else:
        assert False
예제 #8
0
def test_scalar_param():
    """Scalar params accept a single value and reject arrays."""
    config_def = {
        'creator': ParamReturnCreator(param.Scalar),
        'val': 10,
    }

    assert process_config(config_def) == 10

    # A numpy array is not a scalar, so processing must fail.
    config_def['val'] = np.arange(1, 6)
    try:
        process_config(config_def)
    except param.ParamError:
        pass
    else:
        assert False
예제 #9
0
    def config(self, sounding_id):
        """Build the processed retrieval configuration for one sounding.

        Args:
            sounding_id: Sounding identifier, passed through to
                ``retrieval_config_definition`` and logged.

        Returns:
            The object returned by ``process_config``, with the raw
            definition dict attached as its ``config_def`` attribute.
        """
        # Pass the argument to logging instead of %-formatting eagerly, so the
        # message is only built when DEBUG output is actually enabled.
        logging.debug("Loading configuration for sounding: %s", sounding_id)

        config_def = retrieval_config_definition(self.l1b_file, self.met_file,
                                                 sounding_id)

        # Set the atmosphere/ground 'child' entry from this object's
        # ground_type (presumably selects which ground model is built).
        config_def['atmosphere']['ground']['child'] = self.ground_type.value

        config_inst = process_config(config_def)
        # Keep the unprocessed definition alongside the result for callers.
        config_inst.config_def = config_def

        return config_inst
예제 #10
0
def test_instanceof_param():
    """InstanceOf params accept instances of ``val_type`` and reject others."""
    class TestClass(object):
        pass

    config_def = {
        'creator': ParamReturnCreator(param.InstanceOf, val_type=TestClass),
        'val': TestClass(),
    }

    assert isinstance(process_config(config_def), TestClass)

    # An int is not a TestClass instance, so processing must fail.
    config_def['val'] = 5
    try:
        process_config(config_def)
    except param.ParamError:
        pass
    else:
        assert False
예제 #11
0
def test_common_store():
    """Values stored via SaveToCommon can be consumed by a later creator."""
    config_def = {
        # Presumably fixes processing order so 'common' runs first.
        'order': ['common', 'use_common'],
        'common': {'creator': creator.base.SaveToCommon, 'x': 5, 'y': 6},
        'use_common': {'creator': AddCreator},
    }

    result = process_config(config_def)

    # 11 == 5 + 6, assuming AddCreator sums the common 'x' and 'y'.
    assert result['use_common'] == 11
예제 #12
0
파일: run_test.py 프로젝트: ReFRACtor/oco
def test_load_example_config():
    """Smoke-test building the OCO retrieval config from the sample inputs.

    The input directory is resolved relative to the current working
    directory, so the test must be run from the expected location.
    """
    input_dir = os.path.realpath('../test/in')
    l1b_name = "oco2_L1bScND_16094a_170711_B7302r_171102090317-selected_ids.h5"
    met_name = "oco2_L2MetND_16094a_170711_B8000r_171017214714-selected_ids.h5"
    l1b_file = os.path.join(input_dir, l1b_name)
    met_file = os.path.join(input_dir, met_name)

    config_def = oco_config.retrieval_config_definition(
        l1b_file, met_file, "2017071110541471")
    config_inst = process_config(config_def)
    pprint(config_inst, indent=2)

    # Access the main components so a missing entry fails the test.
    fm = config_inst.forward_model
    atm = config_inst.atmosphere
    sv = config_inst.retrieval.state_vector
    solver = config_inst.retrieval.solver
예제 #13
0
def test_object_vector():
    """ObjectVector params yield rf.vector_double values, with or without vec_type."""
    config_def = {
        'item1': {
            'creator': ParamReturnCreator(param.ObjectVector),
            'val': rf.vector_double(),
        },
        'item2': {
            'creator': ParamReturnCreator(param.ObjectVector,
                                          vec_type="double"),
            'val': rf.vector_double(),
        },
    }

    processed = process_config(config_def)

    for key in ('item1', 'item2'):
        assert isinstance(processed[key], rf.vector_double)
예제 #14
0
def test_bound_params():
    """Params are reachable as callables on the creator instance."""
    class BoundParamCreator(creator.base.Creator):
        # NOTE(review): both Choice alternatives are Scalar(int); one may
        # have been intended as a different type -- confirm.
        some_val = param.Choice(param.Scalar(int), param.Scalar(int))

        def create(self, **kwargs):
            # Calling the attribute returns the configured value.
            return self.some_val()

    config_def = {
        'item1': {'creator': BoundParamCreator, 'some_val': 10},
        'item2': {'creator': BoundParamCreator, 'some_val': 5},
    }

    processed = process_config(config_def)

    assert processed['item1'] == 10
    assert processed['item2'] == 5
예제 #15
0
                'filename': covariance_file,
            }
        },
        'solver_nlls_gsl': {
            'creator': creator.retrieval.NLLSSolverGSLLMSDER,
            'max_cost_function_calls': 10,
            'dx_tol_abs': 1e-5,
            'dx_tol_rel': 1e-5, 
            'g_tol_abs': 1e-5,
        },
        'solver': {
            'creator': creator.retrieval.ConnorSolverMAP,
            'max_cost_function_calls': 14,
            'threshold': 2.0,
            'max_iteration': 7,
            'max_divergence': 2,
            'max_chisq': 1.4,
            'gamma_initial': 10.0,
        },
    },
}

if __name__ == "__main__":
    from pprint import pprint

    # Run with DEBUG logging, build every object described by config_def,
    # then dump the result.
    logging.basicConfig(level=logging.DEBUG)
    config_inst = process_config(config_def)
    pprint(config_inst, indent=4)