def test_get_extension_by_model(self):
    """Check the model-based extension lookup as extensions get registered."""
    no_match_msg = 'No extension registered which can handle model:'
    ambiguous_msg = 'Multiple extensions registered which can handle model:'

    # Expected starting state: the lookup finds nothing for DummyModel, so it
    # returns None by default and raises only when explicitly asked to.
    initial_lookup = get_extension_by_model(DummyModel())
    self.assertIsNone(initial_lookup)
    with self.assertRaisesRegex(ValueError, no_match_msg):
        get_extension_by_model(DummyModel(), raise_if_no_extension=True)

    # After registering DummyExtension1, and again after additionally
    # registering DummyExtension2, the lookup must resolve to DummyExtension1.
    for extension_cls in (DummyExtension1, DummyExtension2):
        register_extension(extension_cls)
        resolved = get_extension_by_model(DummyModel())
        self.assertIsInstance(resolved, DummyExtension1)

    # Registering DummyExtension1 a second time is expected to make the
    # lookup ambiguous, which must surface as a ValueError.
    register_extension(DummyExtension1)
    with self.assertRaisesRegex(ValueError, ambiguous_msg):
        get_extension_by_model(DummyModel())
Ejemplo n.º 2
0
def run_model_on_task(
    model: Any,
    task: OpenMLTask,
    avoid_duplicate_runs: bool = True,
    flow_tags: Union[List[str], None] = None,
    seed: Union[int, None] = None,
    add_local_measures: bool = True,
    upload_flow: bool = False,
    return_flow: bool = False,
) -> Union[OpenMLRun, Tuple[OpenMLRun, OpenMLFlow]]:
    """Run the model on the dataset defined by the task.

    The model is converted to a flow by the extension that handles it, and
    that flow is then executed via `run_flow_on_task`.

    Parameters
    ----------
    model : sklearn model
        A model which has a function fit(X,Y) and predict(X),
        all supervised estimators of scikit learn follow this definition of a model [1]
        [1](http://scikit-learn.org/stable/tutorial/statistical_inference/supervised_learning.html)
        Deprecated: an OpenMLTask may be passed here (with the model given as `task`);
        the two arguments are then swapped and a DeprecationWarning is issued.
    task : OpenMLTask
        Task to perform. This may be a model instead if the first argument is an OpenMLTask.
    avoid_duplicate_runs : bool, optional (default=True)
        If True, the run will throw an error if the setup/task combination is already present on
        the server. This feature requires an internet connection.
    flow_tags : List[str], optional (default=None)
        A list of tags that the flow should have at creation.
    seed: int, optional (default=None)
        Models that are not seeded will get this seed.
    add_local_measures : bool, optional (default=True)
        Determines whether to calculate a set of evaluation measures locally,
        to later verify server behaviour.
    upload_flow : bool (default=False)
        If True, upload the flow to OpenML if it does not exist yet.
        If False, do not upload the flow to OpenML.
    return_flow : bool (default=False)
        If True, returns the OpenMLFlow generated from the model in addition to the OpenMLRun.

    Returns
    -------
    run : OpenMLRun
        Result of the run.
    flow : OpenMLFlow (optional, only if `return_flow` is True).
        Flow generated from the model.

    Raises
    ------
    TypeError
        If the extension lookup returns None even though it was asked to raise.
    Exception
        Whatever `get_extension_by_model` raises when called with
        `raise_if_no_extension=True` and no suitable extension is found.
    """

    # TODO: At some point in the future do not allow for arguments in old order (6-2018).
    # Flexibility currently still allowed due to code-snippet in OpenML100 paper (3-2019).
    # When removing this please also remove the method `is_estimator` from the extension
    # interface as it is only used here (MF, 3-2019)
    if isinstance(model, OpenMLTask):
        # stacklevel=2 attributes the warning to the caller's line. Python's default
        # filters only display a DeprecationWarning triggered by code in `__main__`,
        # so attributing it to this library function would hide it from users.
        warnings.warn(
            "The old argument order (task, model) is deprecated and "
            "will not be supported in the future. Please use the "
            "order (model, task).",
            DeprecationWarning,
            stacklevel=2,
        )
        task, model = model, task

    extension = get_extension_by_model(model, raise_if_no_extension=True)
    if extension is None:
        # This should never happen because the lookup above is asked to raise. The check
        # is only here to please mypy and will be gone once the whole function is removed.
        raise TypeError(extension)

    flow = extension.model_to_flow(model)

    run = run_flow_on_task(
        task=task,
        flow=flow,
        avoid_duplicate_runs=avoid_duplicate_runs,
        flow_tags=flow_tags,
        seed=seed,
        add_local_measures=add_local_measures,
        upload_flow=upload_flow,
    )
    if return_flow:
        return run, flow
    return run