import logging
import os
import time
import traceback
from datetime import datetime
from typing import Union

# Project-specific names (BaseModel, Proposal, TrialResult, ParamsType,
# ParamCache, ParamStore, FileParamStore, TrainCache, _SubTrainJobMonitor,
# superadmin_client, LoggerUtilsHandler, model_logger, LOOP_SLEEP_SECS,
# MAX_CONSEC_TRIAL_ERRORS) are assumed to be imported from the surrounding package.

logger = logging.getLogger(__name__)


class TrainWorker:

    def __init__(self, service_id, worker_id):
        self._worker_id = worker_id
        self._monitor: _SubTrainJobMonitor = _SubTrainJobMonitor(service_id)
        self._redis_host = os.environ['REDIS_HOST']
        self._redis_port = os.environ['REDIS_PORT']
        self._param_store: ParamStore = FileParamStore()
        self._trial_id = None  # ID of currently running trial
        self._train_cache: TrainCache = None
        self._param_cache: ParamCache = None
        self._trial_errors = 0  # Consecutive trial errors

    def start(self):
        self._monitor.pull_job_info()
        self._train_cache = TrainCache(self._monitor.sub_train_job_id,
                                       self._redis_host, self._redis_port)
        self._param_cache = ParamCache(self._monitor.sub_train_job_id,
                                       self._redis_host, self._redis_port)

        logger.info(
            f'Starting worker for sub train job "{self._monitor.sub_train_job_id}"...'
        )
        self._notify_start()

        # Poll for proposals and run a trial for each one received
        while True:
            proposal = self._fetch_proposal()
            if proposal is not None:
                result = self._perform_trial(proposal)
                self._submit_result(result)
            time.sleep(LOOP_SLEEP_SECS)

    def stop(self):
        self._notify_stop()

        # If worker is currently running a trial, mark it as errored
        try:
            if self._trial_id is not None:
                self._monitor.mark_trial_as_errored(self._trial_id)
        except Exception:
            logger.error('Error marking trial as errored:')
            logger.error(traceback.format_exc())

        # Run model class teardown
        try:
            self._monitor.model_class.teardown()
        except Exception:
            logger.error('Error tearing down model class:')
            logger.error(traceback.format_exc())

    def _notify_start(self):
        superadmin_client().send_event(
            'train_job_worker_started',
            sub_train_job_id=self._monitor.sub_train_job_id)
        self._train_cache.add_worker(self._worker_id)

    def _fetch_proposal(self):
        proposal = self._train_cache.get_proposal(self._worker_id)
        return proposal

    def _perform_trial(self, proposal: Proposal) -> TrialResult:
        self._trial_id = proposal.trial_id
        logger.info(
            f'Starting trial {self._trial_id} with proposal {proposal}...')

        # Initialize before `try` so the `finally` block can't hit an
        # unbound name if logging setup itself raises
        logger_info = None
        try:
            # Set up logging so that trial logs are forwarded to the monitor
            logger_info = self._start_logging_to_trial(
                lambda log_line, log_lvl: self._monitor.log_to_trial(
                    self._trial_id, log_line, log_lvl))

            self._monitor.mark_trial_as_running(self._trial_id, proposal)

            shared_params = self._pull_shared_params(proposal)
            model_inst = self._load_model(proposal)
            self._train_model(model_inst, proposal, shared_params)
            result = self._evaluate_model(model_inst, proposal)
            store_params_id = self._save_model(model_inst, proposal, result)
            model_inst.destroy()

            self._monitor.mark_trial_as_completed(self._trial_id,
                                                  result.score,
                                                  store_params_id)
            self._trial_errors = 0
            return result
        except Exception as e:
            logger.error('Error while running trial:')
            logger.error(traceback.format_exc())
            self._monitor.mark_trial_as_errored(self._trial_id)

            # Ensure that trial doesn't error too many times consecutively
            self._trial_errors += 1
            if self._trial_errors > MAX_CONSEC_TRIAL_ERRORS:
                logger.error(
                    f'Reached {MAX_CONSEC_TRIAL_ERRORS} consecutive errors - raising exception'
                )
                raise e

            return TrialResult(proposal)
        finally:
            if logger_info is not None:
                self._stop_logging_to_trial(logger_info)

            # Untie from done trial
            self._trial_id = None

    def _notify_stop(self):
        self._train_cache.delete_worker(self._worker_id)
        superadmin_client().send_event(
            'train_job_worker_stopped',
            sub_train_job_id=self._monitor.sub_train_job_id)

    def _start_logging_to_trial(self, handle_log):
        # Add log handlers for trial, including adding handler to root logger
        # to capture any logs emitted at level INFO or above during model
        # training & evaluation
        log_handler = LoggerUtilsHandler(handle_log)
        py_model_logger = logging.getLogger('{}.trial'.format(__name__))
        py_model_logger.setLevel(logging.INFO)
        py_model_logger.propagate = False  # Avoid duplicate logs in root logger
        py_model_logger.addHandler(log_handler)
        model_logger.set_logger(py_model_logger)

        root_logger = logging.getLogger()
        root_logger.addHandler(log_handler)

        return (root_logger, py_model_logger, log_handler)

    def _load_model(self, proposal: Proposal):
        logger.info('Creating model instance...')
        py_model_class = self._monitor.model_class
        model_inst = py_model_class(**proposal.knobs)
        return model_inst

    def _pull_shared_params(self, proposal: Proposal):
        if proposal.params_type == ParamsType.NONE:
            return None

        logger.info('Retrieving shared params from cache...')
        shared_params = self._param_cache.retrieve_params(proposal.params_type)
        return shared_params

    def _train_model(self, model_inst: BaseModel, proposal: Proposal,
                     shared_params: Union[dict, None]):
        train_dataset_path = self._monitor.train_dataset_path
        train_args = self._monitor.train_args
        logger.info('Training model...')
        model_inst.train(train_dataset_path,
                         shared_params=shared_params,
                         **(train_args or {}))

    def _evaluate_model(self, model_inst: BaseModel,
                        proposal: Proposal) -> TrialResult:
        val_dataset_path = self._monitor.val_dataset_path
        if not proposal.to_eval:
            return TrialResult(proposal)

        logger.info('Evaluating model...')
        score = model_inst.evaluate(val_dataset_path)
        logger.info(f'Score on validation dataset: {score}')
        return TrialResult(proposal, score=score)

    def _save_model(self, model_inst: BaseModel, proposal: Proposal,
                    result: TrialResult):
        if not proposal.to_cache_params and not proposal.to_save_params:
            return None

        logger.info('Dumping model parameters...')
        params = model_inst.dump_parameters()

        if proposal.to_cache_params:
            logger.info('Storing shared params in cache...')
            self._param_cache.store_params(params,
                                           score=result.score,
                                           time=datetime.now())

        store_params_id = None
        if proposal.to_save_params:
            logger.info('Saving shared params...')
            store_params_id = self._param_store.save(params)

        return store_params_id

    def _submit_result(self, result: TrialResult):
        self._train_cache.create_result(self._worker_id, result)
        self._train_cache.delete_proposal(self._worker_id)

    def _stop_logging_to_trial(self, logger_info):
        (root_logger, py_model_logger, log_handler) = logger_info

        # Remove log handlers from loggers for this trial
        root_logger.removeHandler(log_handler)
        py_model_logger.removeHandler(log_handler)
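

# A minimal usage sketch (not part of the original module): how a deployment
# script might run this worker. The CLI arguments and signal handling here are
# hypothetical; in practice the service/worker IDs would be supplied by the
# cluster orchestrator, and REDIS_HOST/REDIS_PORT must be set in the environment.
if __name__ == '__main__':
    import sys

    # Hypothetical invocation: python train_worker.py <service_id> <worker_id>
    service_id, worker_id = sys.argv[1], sys.argv[2]

    worker = TrainWorker(service_id, worker_id)
    try:
        # Blocks forever, polling the train cache for proposals
        worker.start()
    except KeyboardInterrupt:
        # Deregister the worker and run model class teardown on shutdown
        worker.stop()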