def train(
    self,
    tasks: list,
    train_func: Callable = None,
    experiment_name: str = None,
    before_status: str = TaskManager.STATUS_WAITING,
    after_status: str = TaskManager.STATUS_DONE,
    **kwargs,
) -> List[Recorder]:
    """
    Register `tasks` in the task pool, run them and return their Recorders.

    The tasks are created through a TaskManager, then `run_task` is invoked with a
    query restricted to the ids that were just created. Unless `self.is_delay()` is
    true, the method waits on that query before collecting the results.
    The returned list follows the order of the ids returned by `create_task`.

    Args:
        tasks (list): a list of `task` definitions; a single dict is wrapped into a list.
        train_func (Callable): the function handed to `run_task`.
            None means `self.train_func`.
        experiment_name (str): the experiment name. None means `self.experiment_name`.
            It is also used as the task pool name when `self.task_pool` is None.
        before_status (str): forwarded to `run_task` as `before_status`.
            Defaults to TaskManager.STATUS_WAITING.
        after_status (str): forwarded to `run_task` as `after_status`.
            Defaults to TaskManager.STATUS_DONE.
        kwargs: extra keyword arguments forwarded to `run_task`.

    Returns:
        List[Recorder]: one entry per created task (the "res" field of the re-queried
        task), tagged with the begin status and the task id. Empty if `tasks` is empty.
    """
    if isinstance(tasks, dict):
        tasks = [tasks]
    if not len(tasks):
        # Nothing to register or run.
        return []

    train_func = self.train_func if train_func is None else train_func
    experiment_name = self.experiment_name if experiment_name is None else experiment_name
    pool_name = self.task_pool
    pool_name = experiment_name if pool_name is None else pool_name

    manager = TaskManager(task_pool=pool_name)
    # The original author's note: all tasks will be saved to MongoDB here.
    task_ids = manager.create_task(tasks)
    selector = {"_id": {"$in": task_ids}}  # only train the tasks created above

    run_task(
        train_func,
        pool_name,
        query=selector,
        experiment_name=experiment_name,
        before_status=before_status,
        after_status=after_status,
        **kwargs,
    )

    if not self.is_delay():
        manager.wait(query=selector)

    recorders = []
    for task_id in task_ids:
        recorder = manager.re_query(task_id)["res"]
        recorder.set_tags(**{self.STATUS_KEY: self.STATUS_BEGIN})
        # Remember which task produced this recorder so it can be looked up later.
        recorder.set_tags(**{self.TM_ID: task_id})
        recorders.append(recorder)
    return recorders
def worker(self):
    """
    Call `run_task` with `task_train` on `self.task_pool` under `self.experiment_name`.

    NOTE: this is only used for TrainerRM. It lets another process or machine train
    the tasks (multiprocessing) and is the same as TrainerRM.worker.
    """
    print("========== worker ==========")
    pool_name = self.task_pool
    exp_name = self.experiment_name
    run_task(task_train, pool_name, experiment_name=exp_name)
def end_train(self, recs, end_train_func=None, experiment_name: str = None, **kwargs) -> List[Recorder]:
    """
    Finish the training of the tasks behind `recs` and return the same Recorders.

    The task ids are read from each Recorder's tags (key `self.TM_ID`). Unless
    `self.skip_run_task` is set, `run_task` is invoked for exactly those ids with
    `before_status=TaskManager.STATUS_PART_DONE`. The method then waits on those
    tasks and finally tags every Recorder with the end status.

    Args:
        recs (list): a list of Recorder (a single Recorder is wrapped into a list);
            each one must carry the `self.TM_ID` tag.
        end_train_func (Callable, optional): the function handed to `run_task`.
            None means `self.end_train_func`.
        experiment_name (str): the experiment name. None means `self.experiment_name`.
            It is also used as the task pool name when `self.task_pool` is None.
        kwargs: extra keyword arguments forwarded to `run_task`.

    Returns:
        List[Recorder]: the given Recorders, as a list.
    """
    if isinstance(recs, Recorder):
        recs = [recs]

    end_train_func = self.end_train_func if end_train_func is None else end_train_func
    experiment_name = self.experiment_name if experiment_name is None else experiment_name
    pool_name = self.task_pool
    pool_name = experiment_name if pool_name is None else pool_name

    task_ids = [recorder.list_tags()[self.TM_ID] for recorder in recs]
    selector = {"_id": {"$in": task_ids}}  # only train these tasks

    if not self.skip_run_task:
        run_task(
            end_train_func,
            pool_name,
            query=selector,
            experiment_name=experiment_name,
            before_status=TaskManager.STATUS_PART_DONE,
            **kwargs,
        )

    # Wait even when run_task is skipped here.
    TaskManager(task_pool=pool_name).wait(query=selector)

    for recorder in recs:
        recorder.set_tags(**{self.STATUS_KEY: self.STATUS_END})
    return recs
def end_train(self, recs, end_train_func=None, experiment_name: str = None, **kwargs) -> List[Recorder]:
    """
    Finish the training of the tasks stored in `recs` and return the same Recorders.

    The "task" object is loaded from every Recorder and used to build the query
    passed to `run_task`, which is invoked with
    `before_status=TaskManager.STATUS_PART_DONE`. Afterwards every Recorder is
    tagged with the end status.

    NOTE(review): the query is meant to restrict training to the tasks of `recs`;
    whether other STATUS_PART_DONE tasks in the pool can still be picked up depends
    on how `run_task` interprets `query` -- confirm against its implementation.

    Args:
        recs (list): a list of Recorder, the tasks have been saved to them.
            A single Recorder is also accepted and is wrapped into a list.
        end_train_func (Callable, optional): the function handed to `run_task`.
            None means `self.end_train_func`.
        experiment_name (str): the experiment name. None means `self.experiment_name`.
            It is also used as the task pool name when `self.task_pool` is None.
        kwargs: extra keyword arguments forwarded to `run_task`.

    Returns:
        List[Recorder]: the given Recorders, as a list.
    """
    # Accept a single Recorder so this method handles a lone item the same way
    # the other trainer entry points do.
    if isinstance(recs, Recorder):
        recs = [recs]

    if end_train_func is None:
        end_train_func = self.end_train_func
    if experiment_name is None:
        experiment_name = self.experiment_name
    task_pool = self.task_pool
    if task_pool is None:
        task_pool = experiment_name

    tasks = [rec.load_object("task") for rec in recs]

    run_task(
        end_train_func,
        task_pool,
        query={"filter": {"$in": tasks}},  # only train these tasks
        experiment_name=experiment_name,
        before_status=TaskManager.STATUS_PART_DONE,
        **kwargs,
    )

    for rec in recs:
        rec.set_tags(**{self.STATUS_KEY: self.STATUS_END})
    return recs
def worker(
    self,
    train_func: Callable = None,
    experiment_name: str = None,
):
    """
    The multiprocessing counterpart of `train`.

    Calls `run_task` on the task pool without a query, so it can share the same
    task_pool with `train` and be started from another process or machine.

    Args:
        train_func (Callable): the function handed to `run_task`.
            None means `self.train_func`.
        experiment_name (str): the experiment name. None means `self.experiment_name`.
            It is also used as the task pool name when `self.task_pool` is None.
    """
    train_func = self.train_func if train_func is None else train_func
    experiment_name = self.experiment_name if experiment_name is None else experiment_name
    pool_name = self.task_pool
    pool_name = experiment_name if pool_name is None else pool_name
    run_task(train_func, task_pool=pool_name, experiment_name=experiment_name)
def worker(self, end_train_func=None, experiment_name: str = None):
    """
    The multiprocessing counterpart of `end_train`.

    Calls `run_task` on the task pool with
    `before_status=TaskManager.STATUS_PART_DONE` and no query, so it can share the
    same task_pool with `end_train` and be started from another process or machine.

    Args:
        end_train_func (Callable, optional): the function handed to `run_task`.
            None means `self.end_train_func`.
        experiment_name (str): the experiment name. None means `self.experiment_name`.
            It is also used as the task pool name when `self.task_pool` is None.
    """
    end_train_func = self.end_train_func if end_train_func is None else end_train_func
    experiment_name = self.experiment_name if experiment_name is None else experiment_name
    pool_name = self.task_pool
    pool_name = experiment_name if pool_name is None else pool_name
    run_task(
        end_train_func,
        task_pool=pool_name,
        experiment_name=experiment_name,
        before_status=TaskManager.STATUS_PART_DONE,
    )