def _train_single_model(self, model_desc=None, model_id=None, weights_file=None):
    cls_trainer = ClassFactory.get_cls('trainer')
    step_name = self.task.step_name
    if model_desc is not None:
        # Build a report record from the sampled description and create a
        # trainer for it, optionally warm-started from a weights file.
        sample = dict(worker_id=model_id, desc=model_desc, step_name=step_name)
        record = ReportRecord().load_dict(sample)
        logging.debug("Broadcast Record=%s", str(record))
        trainer = cls_trainer(model_desc=model_desc, id=model_id,
                              pretrained_model_file=weights_file)
    else:
        # No description supplied: fall back to a default trainer and
        # derive the record from the trainer itself.
        trainer = cls_trainer(None, 0)
        record = ReportRecord(trainer.step_name, trainer.worker_id,
                              desc=trainer.model_desc)
    ReportClient.broadcast(record)
    ReportServer.add_watched_var(trainer.step_name, trainer.worker_id)
    # Resume training from a checkpoint when the backend is PyTorch
    # and resumption was requested in the general config.
    if vega.is_torch_backend() and General._resume:
        trainer.load_checkpoint = True
        trainer._resume_training = True
    if self._distributed_training:
        self._do_distributed_fully_train(trainer)
    else:
        self._do_single_fully_train(trainer)
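# --- Hypothetical usage sketch (not part of the original source) ---
# A minimal illustration of calling _train_single_model for one sampled
# architecture; the model_desc contents and weights path below are
# assumptions for demonstration only, not Vega's actual schema.
#
#     self._train_single_model(
#         model_desc={"backbone": {"type": "ResNet", "depth": 18}},
#         model_id=0,
#         weights_file="/path/to/pretrained.pth",
#     )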
def _dispatch_trainer(self, samples):
    # Submit one trainer (and its evaluator, if any) per sampled
    # (id, desc, hps) tuple to the master for scheduling.
    for (id, desc, hps) in samples:
        cls_trainer = ClassFactory.get_cls(ClassType.TRAINER)
        TrainerConfig.from_dict(self.user_trainer_config)
        trainer = cls_trainer(id=id, model_desc=desc, hps=hps)
        evaluator = self._get_evaluator(trainer)
        logging.info("submit trainer, id={}".format(id))
        ReportServer.add_watched_var(General.step_name, trainer.worker_id)
        self.master.run(trainer, evaluator)
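# --- Hypothetical usage sketch (not part of the original source) ---
# A minimal illustration of the tuple shape _dispatch_trainer expects.
# The descriptions and hyperparameters below are illustrative
# assumptions; in practice they come from the search algorithm's samples.
#
#     samples = [
#         (0, {"backbone": {"type": "ResNet", "depth": 18}}, {"lr": 0.1}),
#         (1, {"backbone": {"type": "ResNet", "depth": 34}}, {"lr": 0.05}),
#     ]
#     self._dispatch_trainer(samples)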