コード例 #1
0
def _save_list_of_estimators(
    hdf_file: tables.File,
    group: tables.Group,
    estimator_list: List[BaseEstimator],
    fitted: bool,
):
    """Write a plain (unnamed) list of estimators into ``group``.

    ``group`` is tagged with ``GroupType.LIST_OF_ESTIMATORS`` and a ``len``
    attribute holding the number of elements. Each element is then saved into
    its own child group named ``item_<index>``, in list order.

    Args:
        hdf_file: PyTables file being written.
        group: destination group.
        estimator_list: estimators to store.
        fitted: forwarded unchanged to ``_save_estimator_to_group`` for each
            element.
    """
    set_attr = hdf_file.set_node_attr
    set_attr(group, "__type__", GroupType.LIST_OF_ESTIMATORS.name)
    set_attr(group, "len", len(estimator_list))

    for position, member in enumerate(estimator_list):
        member_group = hdf_file.create_group(group, "item_{}".format(position))
        _save_estimator_to_group(hdf_file, member_group, member, fitted)
コード例 #2
0
def _save_params_to_group(hdf_file: tables.File, group: tables.Group,
                          params_dict: dict, fitted: bool):
    """Write every entry of ``params_dict`` under ``group``.

    Each value is dispatched by the first check that matches, in this order:

    * a single estimator (``is_estimator``): saved by
      ``_save_estimator_to_group``;
    * a list of named estimators (``is_list_of_named_estimators``): saved by
      ``_save_list_of_named_estimators``;
    * a list of estimators (``is_list_of_estimators``): saved by
      ``_save_list_of_estimators``;
    * anything else: stored as the attribute ``<param name>`` on ``group``.

    For the first three cases a child group named after the parameter is
    created, and ``fitted`` is forwarded to the chosen saver.
    """
    for key, value in params_dict.items():
        if is_estimator(value):
            saver = _save_estimator_to_group
        elif is_list_of_named_estimators(value):
            saver = _save_list_of_named_estimators
        elif is_list_of_estimators(value):
            saver = _save_list_of_estimators
        else:
            # Plain value: no child group, just an attribute on ``group``.
            hdf_file.set_node_attr(group, key, value)
            continue
        saver(hdf_file, hdf_file.create_group(group, key), value, fitted)
コード例 #3
0
def _save_list_of_named_estimators(
    hdf_file: tables.File,
    group: tables.Group,
    estimator_list: List[Tuple[str, BaseEstimator, Any]],
    fitted: bool,
):
    """Write a list of ``(name, estimator, *extra)`` entries into ``group``.

    ``group`` is tagged with ``GroupType.LIST_OF_NAMED_ESTIMATORS`` and gets
    two attributes:

    * ``names``: the first element of every entry, in order;
    * ``rests``: for every entry, a list of the elements after the first two
      (empty when the entry has exactly two elements).

    Each estimator (the second element) is then saved into a child group
    named after its entry's name, with ``fitted`` forwarded.
    """
    hdf_file.set_node_attr(group, "__type__",
                           GroupType.LIST_OF_NAMED_ESTIMATORS.name)

    # Unpack every entry once; a malformed entry raises here, before either
    # ``names`` or ``rests`` is written.
    entries = [(label, model, extras)
               for label, model, *extras in estimator_list]
    hdf_file.set_node_attr(group, "names",
                           [label for label, _, _ in entries])
    hdf_file.set_node_attr(group, "rests",
                           [extras for _, _, extras in entries])

    for label, model, _ in entries:
        _save_estimator_to_group(hdf_file,
                                 hdf_file.create_group(group, label),
                                 model, fitted)
コード例 #4
0
def _save_validation_to_group(
    hdf_file: tables.File,
    group: tables.Group,
    estimator: BaseEstimator,
    validation_func: str,
    validation_data: Any,
    is_validation_array: bool,
):
    """Store validation input and the estimator's output for it in ``group``.

    The method name is recorded as the ``validation_func`` attribute. Then
    ``getattr(estimator, validation_func)`` is called on the input, and the
    result is stored as ``y`` next to the input ``X``.

    Args:
        hdf_file: PyTables file being written.
        group: destination group.
        estimator: estimator whose ``validation_func`` method is called.
        validation_func: name of the estimator method to call on the input.
        validation_data: validation input.
        is_validation_array: if true, ``X`` and ``y`` are written as array
            nodes through ``_save_array_to_group``; otherwise both are stored
            as attributes of ``group``.

    Raises:
        AttributeError: if ``estimator`` has no method named
            ``validation_func``.
    """
    hdf_file.set_node_attr(group, "validation_func", validation_func)
    if is_validation_array:
        # This mode handles large inputs well, but storing may cast the
        # array, and it does not work with mixed-type arrays.
        _save_array_to_group(hdf_file, group, "X", "input", validation_data)
        # Call the estimator on the stored (possibly cast) copy of X rather
        # than on ``validation_data``. Read it into memory explicitly instead
        # of passing the PyTables node object itself to the estimator.
        # Assumes _save_array_to_group stores the data as child node "X" of
        # ``group``.
        stored_x = hdf_file.get_node(group, "X").read()
        y = getattr(estimator, validation_func)(stored_x)
        _save_array_to_group(hdf_file, group, "y", "expected_output", y)
    else:
        hdf_file.set_node_attr(group, "X", validation_data)
        y = getattr(estimator, validation_func)(validation_data)
        hdf_file.set_node_attr(group, "y", y)
コード例 #5
0
def _save_estimator_to_group(hdf_file: tables.File, group: tables.Group,
                             estimator: BaseEstimator, fitted: bool):
    """Serialize one estimator: metadata, params and, optionally, fit state.

    Attributes written on ``group``:

    * ``__class_name__``: ``"<module>.<ClassName>"`` of the estimator class;
    * ``__module_version__``: ``__version__`` of the package returned by
      ``__import__(<module>)``, i.e. the top-level package of the module;
    * ``__type__``: ``GroupType.ESTIMATOR.name``.

    The values returned by ``get_params_dict`` are then written via
    ``_save_params_to_group``. When ``fitted`` is true, a child group
    ``FIT_GROUP`` is created, tagged ``GroupType.FITTED_ATTRIBUTES``, and
    filled with the values returned by ``get_fit_params_dict``.

    Raises:
        AttributeError: if that top-level package has no ``__version__``.
    """
    estimator_class = estimator.__class__
    module_name = estimator_class.__module__
    qualified_name = ".".join((module_name, estimator_class.__name__))
    package_version = __import__(module_name).__version__

    hdf_file.set_node_attr(group, "__class_name__", qualified_name)
    hdf_file.set_node_attr(group, "__module_version__", package_version)
    hdf_file.set_node_attr(group, "__type__", GroupType.ESTIMATOR.name)

    # Params are forwarded with the caller's ``fitted`` flag rather than
    # False. Most params are not fitted, but some of them contain fitted
    # estimators (e.g. ``pipeline.Pipeline.steps``).
    _save_params_to_group(hdf_file, group, get_params_dict(estimator),
                          fitted=fitted)

    if not fitted:
        return

    fit_group = hdf_file.create_group(group, FIT_GROUP)
    hdf_file.set_node_attr(fit_group, "__type__",
                           GroupType.FITTED_ATTRIBUTES.name)
    _save_params_to_group(hdf_file, fit_group,
                          get_fit_params_dict(estimator), fitted)