Example #1
0
    def get_collector(self,
                      process_list=None,
                      rec_key_func=None,
                      rec_filter_func=None,
                      artifacts_key=None):
        """
        Get the instance of `Collector <../advanced/task_management.html#Task Collecting>`_ to collect results. The returned collector must distinguish results in different models.

        Assumption: the models can be distinguished based on the model name and rolling test segments.
        If you do not want this assumption, please implement your method or use another rec_key_func.

        Args:
            process_list (list, optional): processors handed to ``RecorderCollector``.
                Defaults to None, which means ``[RollingGroup()]``. A fresh list (and a
                fresh ``RollingGroup``) is built on every call, so collectors never share
                processor objects.
            rec_key_func (Callable, optional): a function to get the key of a recorder.
                If None, the key is ``(model class name, test segment)`` read from the
                recorder's "task" object.
            rec_filter_func (Callable, optional): filter the recorder by return True or False. Defaults to None.
            artifacts_key (List[str], optional): the artifacts key you want to get. If None, get all artifacts.

        Returns:
            RecorderCollector: a collector over the recorders of the experiment ``self.exp_name``.
        """
        # Build the default lazily: a list literal in the signature would be created once
        # at definition time and shared (together with its RollingGroup) across all calls.
        if process_list is None:
            process_list = [RollingGroup()]

        def rec_key(recorder):
            # Distinguish results by model class and by the rolling test segment.
            task_config = recorder.load_object("task")
            model_key = task_config["model"]["class"]
            rolling_key = task_config["dataset"]["kwargs"]["segments"]["test"]
            return model_key, rolling_key

        if rec_key_func is None:
            rec_key_func = rec_key

        artifacts_collector = RecorderCollector(
            experiment=self.exp_name,
            process_list=process_list,
            rec_key_func=rec_key_func,
            rec_filter_func=rec_filter_func,
            artifacts_key=artifacts_key,
        )

        return artifacts_collector
Example #2
0
    def task_collecting(self):
        """Collect and print the results of the "LGBModel" recorders in ``self.experiment_name``.

        Each recorder is keyed by ``(model class name, test segment)``, both read from
        its "task" object. The collected results are processed with ``RollingGroup()``,
        and the collector's output is printed.
        """
        print("========== task_collecting ==========")

        def rec_key(recorder):
            # Key each recorder by its model class and its rolling test segment.
            task_config = recorder.load_object("task")
            model_key = task_config["model"]["class"]
            rolling_key = task_config["dataset"]["kwargs"]["segments"]["test"]
            return model_key, rolling_key

        def my_filter(recorder):
            # Only choose the results of "LGBModel". Read just the model class: going
            # through rec_key would also require a test segment, and would raise
            # KeyError on recorders that are about to be filtered out anyway.
            task_config = recorder.load_object("task")
            return task_config["model"]["class"] == "LGBModel"

        collector = RecorderCollector(
            experiment=self.experiment_name,
            process_list=RollingGroup(),
            rec_key_func=rec_key,
            rec_filter_func=my_filter,
        )
        print(collector())