def save(training_job_in: TrainingJobIn) -> str:
    """Persist a new training job document in the collection.

    Args:
        training_job_in: Input payload for the training job; its ``model``
            field must reference an existing model document.

    Returns:
        The ``_id`` of the inserted document.
        NOTE(review): pymongo's ``inserted_id`` is typically an ``ObjectId``,
        not ``str`` — annotation kept for interface compatibility; confirm
        against callers.

    Raises:
        ServiceException: If the referenced model does not exist.
    """
    model_id = training_job_in.model
    # Validate the foreign-key reference before inserting the job.
    if not ModelDAO.exists_by_id(ObjectId(model_id)):
        # Fix: error message grammar ("not exist." -> "does not exist.").
        raise ServiceException(f'Model with ID {model_id} does not exist.')
    training_job = TrainingJob(**training_job_in.dict(exclude_none=True))
    return _collection.insert_one(
        training_job.dict(exclude_none=True)).inserted_id
def get_by_id(id: str) -> TrainingJob:
    """Fetch a single training job by its document ID.

    Args:
        id: Hex string of the document's ``ObjectId``.

    Returns:
        The matching :class:`TrainingJob`.

    Raises:
        ValueError: If no document with ``id`` exists.
    """
    # EAFP, single round-trip: the original issued a count_documents
    # existence check followed by find_one — two queries with a race
    # window in between (the document could be deleted after the count).
    document = _collection.find_one(filter={'_id': ObjectId(id)})
    if document is None:
        raise ValueError(f'id {id} not found.')
    return TrainingJob(**document)
def from_training_job(cls, training_job: TrainingJob) -> 'PyTorchTrainer':
    """Build a ``PyTorchTrainer`` from a stored :class:`TrainingJob`.

    Loads the referenced model's cached weights, wraps them in a
    fine-tuning LightningModule together with the job's data module, and
    constructs the trainer via ``cls(...)``.

    Args:
        training_job: Job record holding the model reference, data-module
            settings, loss/optimizer/LR-scheduler properties, and epoch
            bounds.

    Returns:
        A configured ``PyTorchTrainer`` instance.

    Raises:
        ValueError: If the referenced model's engine is not PyTorch.
    """
    # TODO: only support fine-tune
    model_bo = ModelService.get_model_by_id(training_job.model)
    if model_bo.engine != Engine.PYTORCH:
        raise ValueError(
            f'Model engine expected `{Engine.PYTORCH}`, but got {model_bo.engine}.'
        )
    # download local cache
    # NOTE(review): presumably returns a local filesystem path to the
    # fetched weights — confirm against get_remote_model_weight.
    cache_path = get_remote_model_weight(model_bo)
    net = torch.load(cache_path)
    # n=-1 with train_bn=True: freezes layers while keeping batch-norm
    # trainable — exact layer selection depends on the project's freeze().
    freeze(module=net, n=-1, train_bn=True)
    # build pytorch lightning module
    fine_tune_module_kwargs = {
        'net': net,
        # SECURITY NOTE(review): eval() on the stored loss-function name
        # executes arbitrary code if the job record is attacker-controlled;
        # a whitelist lookup would be safer. Marked nosec by the author.
        'loss': eval(str(training_job.loss_function))(),  # nosec
        'batch_size': training_job.data_module.batch_size,
        'num_workers': training_job.data_module.num_workers,
    }
    # Optional hyper-parameters: only forwarded when set (truthy) on the job.
    if training_job.optimizer_property.lr:
        fine_tune_module_kwargs['lr'] = training_job.optimizer_property.lr
    if training_job.lr_scheduler_property.gamma:
        fine_tune_module_kwargs[
            'lr_scheduler_gamma'] = training_job.lr_scheduler_property.gamma
    if training_job.lr_scheduler_property.step_size:
        fine_tune_module_kwargs[
            'step_size'] = training_job.lr_scheduler_property.step_size
    model = FineTuneModule(**fine_tune_module_kwargs)
    data_module = PyTorchDataModule(**training_job.data_module.dict(
        exclude_none=True))
    # Epoch bounds are the only trainer-level fields lifted from the job.
    trainer_kwargs = training_job.dict(
        exclude_none=True, include={'min_epochs', 'max_epochs'})
    trainer = cls(
        id=training_job.id,
        model=model,
        data_loader_kwargs={'datamodule': data_module},
        trainer_kwargs={
            # Fall back to the module-level OUTPUT_DIR when the job has no
            # data_dir configured.
            'default_root_dir': training_job.data_module.data_dir
                                or OUTPUT_DIR,
            'weights_summary': None,
            'progress_bar_refresh_rate': 1,
            'num_sanity_val_steps': 0,
            'gpus': 1,  # TODO: set GPU number
            **trainer_kwargs,
        })
    return trainer
def delete_by_id(id: str) -> TrainingJob:
    """Delete a training job by ID and return the deleted document.

    Args:
        id: Hex string of the document's ``ObjectId``.

    Returns:
        The :class:`TrainingJob` that was removed.

    Raises:
        ValueError: If no document with ``id`` exists. (Previously a
            missing ID crashed with ``TypeError`` because
            ``find_one_and_delete`` returns ``None`` and the code did
            ``TrainingJob(**None)``.)
    """
    document = _collection.find_one_and_delete(
        filter={'_id': ObjectId(id)})
    # Fix: guard the None case; message mirrors get_by_id for consistency.
    if document is None:
        raise ValueError(f'id {id} not found.')
    return TrainingJob(**document)
def get_all() -> List[TrainingJob]:
    """Return every training job stored in the collection."""
    documents = _collection.find({})
    # Materialize each raw document into its model object.
    return [TrainingJob(**document) for document in documents]