Example #1
0
def test_subclass(subclass, *args, **kwargs):
    """
    Tests a provided :class:`~prism.modellink.ModelLink` `subclass` by
    initializing it with the given `args` and `kwargs` and checking if all
    required methods can be properly called.

    This function needs to be called by all MPI ranks.

    Parameters
    ----------
    subclass : :class:`~prism.modellink.ModelLink` subclass
        The :class:`~prism.modellink.ModelLink` subclass that requires testing.
    args : positional arguments
        Positional arguments that need to be provided to the constructor of the
        `subclass`.
    kwargs : keyword arguments
        Keyword arguments that need to be provided to the constructor of the
        `subclass`.

    Returns
    -------
    modellink_obj : :obj:`~prism.modellink.ModelLink` object
        Instance of the provided `subclass` if all tests pass successfully.
        Specific exceptions are raised if a test fails.

    Raises
    ------
    InputError
        If `subclass` is not a class, cannot be initialized, was not
        initialized properly, or if its `call_model()`- or `get_md_var()`-method
        does not have a compatible signature.
    TypeError
        If `subclass` is not a subclass of :class:`~prism.modellink.ModelLink`.
    NotImplementedError
        If calling the `call_model()`-method raises :class:`NotImplementedError`
        (e.g., because it was not overridden).

    Warns
    -----
    RequestWarning
        If calling the `get_md_var()`-method raises
        :class:`NotImplementedError`.

    Note
    ----
    Depending on the complexity of the model wrapped in the given `subclass`,
    this function may take a while to execute.

    """

    # Import ModelLink class
    from prism.modellink import ModelLink

    # Check if provided subclass is a class
    if not isclass(subclass):
        raise e13.InputError("Input argument 'subclass' must be a class!")

    # Check if provided subclass is a subclass of ModelLink
    if not issubclass(subclass, ModelLink):
        raise TypeError("Input argument 'subclass' must be a subclass of the "
                        "ModelLink class!")

    # Try to initialize provided subclass
    try:
        modellink_obj = subclass(*args, **kwargs)
    except Exception as error:
        # Chain the original exception so its traceback is preserved
        raise e13.InputError("Input argument 'subclass' cannot be initialized!"
                             " (%s)" % (error)) from error

    # Check if modellink_obj was initialized properly
    if not e13.check_instance(modellink_obj, ModelLink):
        obj_name = modellink_obj.__class__.__name__
        raise e13.InputError("Provided ModelLink subclass %r was not "
                             "initialized properly! Make sure that %r calls "
                             "the super constructor during initialization!"
                             % (obj_name, obj_name))

    # Check if call_model and get_md_var take compatible arguments
    _check_modellink_method_signature(ModelLink, modellink_obj, 'call_model')
    _check_modellink_method_signature(ModelLink, modellink_obj, 'get_md_var')

    # Set MPI intra-communicator
    comm = get_HybridComm_obj()

    # Obtain random sam_set on controller
    if not comm._rank:
        sam_set = modellink_obj._to_par_space(rand(1, modellink_obj._n_par))
    # Workers get dummy sam_set
    else:
        sam_set = []

    # Broadcast random sam_set to workers
    sam_set = comm.bcast(sam_set, 0)

    # Try to evaluate sam_set in the model
    try:
        # Only the controller calls the model, unless the model is MPI-called
        if not comm._rank or modellink_obj._MPI_call:
            # Do multi-call
            if modellink_obj._multi_call:
                mod_set = modellink_obj.call_model(
                    emul_i=0,
                    par_set=modellink_obj._get_sam_dict(sam_set),
                    data_idx=modellink_obj._data_idx)

            # Single-call
            else:
                # Initialize mod_set
                mod_set = np.zeros([sam_set.shape[0], modellink_obj._n_data])

                # Loop over all samples in sam_set
                for i, par_set in enumerate(sam_set):
                    mod_set[i] = modellink_obj.call_model(
                        emul_i=0,
                        par_set=modellink_obj._get_sam_dict(par_set),
                        data_idx=modellink_obj._data_idx)

    # If call_model was not overridden, catch NotImplementedError
    except NotImplementedError as error:
        raise NotImplementedError("Provided ModelLink subclass %r has no "
                                  "user-written 'call_model()'-method!"
                                  % (modellink_obj._name)) from error

    # If successful, check if obtained mod_set has correct shape
    # (mod_set is always defined on the controller)
    if not comm._rank:
        modellink_obj._check_mod_set(mod_set, 'mod_set')

    # Check if the model discrepancy variance can be obtained
    try:
        md_var = modellink_obj.get_md_var(
            emul_i=0,
            par_set=modellink_obj._get_sam_dict(sam_set[0]),
            data_idx=modellink_obj._data_idx)

    # If get_md_var was not overridden, catch NotImplementedError
    except NotImplementedError:
        warn_msg = ("Provided ModelLink subclass %r has no user-written "
                    "'get_md_var()'-method! Default model discrepancy variance"
                    " description would be used instead!"
                    % (modellink_obj._name))
        warnings.warn(warn_msg, RequestWarning, stacklevel=2)

    # If successful, check if obtained md_var has correct shape
    else:
        modellink_obj._check_md_var(md_var, 'md_var')

    # Return modellink_obj
    return(modellink_obj)


def _check_modellink_method_signature(base_cls, modellink_obj, method_name):
    """
    Checks if the `method_name`-method of `modellink_obj` has a signature that
    is compatible with the `method_name`-method of `base_cls`.

    The method must accept every argument that the method of `base_cls` takes
    (besides 'self'). Any additional arguments must either have a default value
    or be variadic (*args/**kwargs).

    Parameters
    ----------
    base_cls : class
        The class defining the reference method, normally
        :class:`~prism.modellink.ModelLink`.
    modellink_obj : :obj:`~prism.modellink.ModelLink` object
        The initialized instance whose method must be checked.
    method_name : str
        Name of the method to check, e.g. ``'call_model'``.

    Raises
    ------
    InputError
        If a required argument is missing, or if an unknown non-optional
        argument is present.

    """

    # Obtain list of arguments the reference method takes
    req_args = list(signature(getattr(base_cls, method_name)).parameters)
    req_args.remove('self')

    # Parameters of the bound method ('self' is already bound, so absent)
    obj_args = dict(signature(getattr(modellink_obj, method_name)).parameters)

    # Check if the method takes all required arguments
    for arg in req_args:
        if arg not in obj_args:
            raise e13.InputError("The '%s()'-method in provided ModelLink "
                                 "subclass %r does not take required input "
                                 "argument %r!"
                                 % (method_name, modellink_obj._name, arg))
        obj_args.pop(arg)

    # Check if the method takes any other non-optional arguments
    for arg, par in obj_args.items():
        # Compare the sentinel by identity: '==' breaks on array-like defaults
        if (par.default is _empty and
                par.kind not in (_VAR_POSITIONAL, _VAR_KEYWORD)):
            raise e13.InputError("The '%s()'-method in provided ModelLink "
                                 "subclass %r takes an unknown non-optional "
                                 "input argument %r!"
                                 % (method_name, modellink_obj._name, arg))
Example #2
0
 def test_invalid_comm(self):
     """
     Check that passing a non-communicator value raises a TypeError.

     """
     not_a_comm = 0
     with pytest.raises(TypeError):
         get_HybridComm_obj(not_a_comm)
Example #3
0
 def test_comm_size_unity(self):
     """
     Check that a communicator of size one is mapped onto `d_comm`.

     """
     # Use every rank as its own color, so each split communicator has size 1
     s_comm = comm.Split(comm.Get_rank(), 0)
     try:
         assert get_HybridComm_obj(s_comm) is d_comm
     finally:
         # Always release the split communicator, even if the assertion fails
         s_comm.Free()
Example #4
0
 def test_d_comm(self):
     """
     Check that passing `d_comm` returns that very same object.

     """
     hybrid_comm = get_HybridComm_obj(d_comm)
     assert hybrid_comm is d_comm
Example #5
0
 def test_h_comm(self):
     """
     Check that passing `h_comm` returns that very same object.

     """
     hybrid_comm = get_HybridComm_obj(h_comm)
     assert hybrid_comm is h_comm
Example #6
0
 def test_default(self):
     """
     Check that calling without arguments returns `h_comm`.

     """
     default_comm = get_HybridComm_obj()
     assert default_comm is h_comm