def __init__(
    self, config: Config, dataset: Dataset, parent_job: Job = None
) -> None:
    from kge.job import EvaluationJob

    super().__init__(config, dataset, parent_job)
    self.model: KgeModel = KgeModel.create(config, dataset)
    self.optimizer = KgeOptimizer.create(config, self.model)
    self.lr_scheduler, self.metric_based_scheduler = KgeLRScheduler.create(
        config, self.optimizer
    )
    self.loss = KgeLoss.create(config)
    self.batch_size: int = config.get("train.batch_size")
    self.device: str = self.config.get("job.device")

    valid_conf = config.clone()
    valid_conf.set("job.type", "eval")
    valid_conf.set("eval.data", "valid")
    valid_conf.set("eval.trace_level", self.config.get("valid.trace_level"))
    self.valid_job = EvaluationJob.create(
        valid_conf, dataset, parent_job=self, model=self.model
    )

    self.config.check("train.trace_level", ["batch", "epoch"])
    self.trace_batch: bool = self.config.get("train.trace_level") == "batch"
    self.epoch: int = 0
    self.valid_trace: List[Dict[str, Any]] = []
    self.is_prepared = False
    self.model.train()

    # attributes filled in by implementing classes
    self.loader = None
    self.num_examples = None
    self.type_str: Optional[str] = None

    #: Hooks run after training for an epoch.
    #: Signature: job, trace_entry
    self.post_epoch_hooks: List[Callable[[Job, Dict[str, Any]], Any]] = []

    #: Hooks run before starting a batch.
    #: Signature: job
    self.pre_batch_hooks: List[Callable[[Job], Any]] = []

    #: Hooks run before outputting the trace of a batch. Can modify trace entry.
    #: Signature: job, trace_entry
    self.post_batch_trace_hooks: List[Callable[[Job, Dict[str, Any]], Any]] = []

    #: Hooks run before outputting the trace of an epoch. Can modify trace entry.
    #: Signature: job, trace_entry
    self.post_epoch_trace_hooks: List[Callable[[Job, Dict[str, Any]], Any]] = []

    #: Hooks run after a validation job.
    #: Signature: job, trace_entry
    self.post_valid_hooks: List[Callable[[Job, Dict[str, Any]], Any]] = []

    #: Hooks run after training.
    #: Signature: job, trace_entry
    self.post_train_hooks: List[Callable[[Job, Dict[str, Any]], Any]] = []

    if self.__class__ == TrainingJob:
        for f in Job.job_created_hooks:
            f(self)
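# Illustrative sketch (not part of the class above): hooks are plain callables
# appended to the lists initialized in __init__. The function below and the
# "avg_loss" trace key are assumptions for demonstration; the actual keys in a
# trace entry depend on the implementing training job.
def log_epoch_loss(job: "TrainingJob", trace_entry: Dict[str, Any]) -> None:
    # post-epoch hooks receive the job and the epoch's trace entry (a dict)
    print(f"epoch {job.epoch}: avg_loss={trace_entry.get('avg_loss')}")


# hypothetical usage after constructing a job:
# job.post_epoch_hooks.append(log_epoch_loss)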
def __init__(
    self,
    config: Config,
    dataset: Dataset,
    parent_job: Job = None,
    model=None,
    forward_only=False,
) -> None:
    from kge.job import EvaluationJob

    super().__init__(config, dataset, parent_job)
    if model is None:
        self.model: KgeModel = KgeModel.create(config, dataset)
    else:
        self.model: KgeModel = model
    self.loss = KgeLoss.create(config)
    self.abort_on_nan: bool = config.get("train.abort_on_nan")
    self.batch_size: int = config.get("train.batch_size")
    self._subbatch_auto_tune: bool = config.get("train.subbatch_auto_tune")
    self._max_subbatch_size: int = config.get("train.subbatch_size")
    self.device: str = self.config.get("job.device")
    self.train_split = config.get("train.split")

    self.config.check("train.trace_level", ["batch", "epoch"])
    self.trace_batch: bool = self.config.get("train.trace_level") == "batch"
    self.epoch: int = 0
    self.is_forward_only = forward_only

    if not self.is_forward_only:
        self.model.train()
        self.optimizer = KgeOptimizer.create(config, self.model)
        self.kge_lr_scheduler = KgeLRScheduler(config, self.optimizer)

        self.valid_trace: List[Dict[str, Any]] = []
        valid_conf = config.clone()
        valid_conf.set("job.type", "eval")
        if self.config.get("valid.split") != "":
            valid_conf.set("eval.split", self.config.get("valid.split"))
        valid_conf.set(
            "eval.trace_level", self.config.get("valid.trace_level")
        )
        self.valid_job = EvaluationJob.create(
            valid_conf, dataset, parent_job=self, model=self.model
        )

    # attributes filled in by implementing classes
    self.loader = None
    self.num_examples = None
    self.type_str: Optional[str] = None

    #: Hooks run after validation. The corresponding valid trace entry can be
    #: found in self.valid_trace[-1]. Signature: job
    self.post_valid_hooks: List[Callable[[Job], Any]] = []

    if self.__class__ == TrainingJob:
        for f in Job.job_created_hooks:
            f(self)
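# Illustrative sketch (not part of the class above) of the two constructor
# arguments this version adds: `model=` reuses an existing KgeModel instead of
# creating a fresh one, and `forward_only=True` skips the optimizer, LR
# scheduler, and validation-job setup, which is useful when the job is only
# used to compute losses. `pretrained_model` is a placeholder assumption.
#
# pretrained_model = KgeModel.create(config, dataset)  # or loaded elsewhere
# job = TrainingJob(
#     config, dataset, parent_job=None, model=pretrained_model, forward_only=True
# )
# assert job.is_forward_only and not hasattr(job, "optimizer")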