def __init__(
    self, config: Config, dataset: Dataset, parent_job: Job = None, model=None
) -> None:
    from kge.job import EvaluationJob

    super().__init__(config, dataset, parent_job)
    if model is None:
        self.model: KgeModel = KgeModel.create(config, dataset)
    else:
        self.model: KgeModel = model
    self.optimizer = KgeOptimizer.create(config, self.model)
    self.kge_lr_scheduler = KgeLRScheduler(config, self.optimizer)
    self.loss = KgeLoss.create(config)
    self.abort_on_nan: bool = config.get("train.abort_on_nan")
    self.batch_size: int = config.get("train.batch_size")
    self.device: str = self.config.get("job.device")
    self.train_split = config.get("train.split")

    self.config.check("train.trace_level", ["batch", "epoch"])
    self.trace_batch: bool = self.config.get("train.trace_level") == "batch"
    self.epoch: int = 0
    self.valid_trace: List[Dict[str, Any]] = []
    valid_conf = config.clone()
    valid_conf.set("job.type", "eval")
    if self.config.get("valid.split") != "":
        valid_conf.set("eval.split", self.config.get("valid.split"))
    valid_conf.set("eval.trace_level", self.config.get("valid.trace_level"))
    self.valid_job = EvaluationJob.create(
        valid_conf, dataset, parent_job=self, model=self.model
    )

    # attributes filled in by implementing classes
    self.loader = None
    self.num_examples = None
    self.type_str: Optional[str] = None

    #: Hooks run after training for an epoch.
    #: Signature: job, trace_entry
    self.post_epoch_hooks: List[Callable[[Job, Dict[str, Any]], Any]] = []

    #: Hooks run before starting a batch.
    #: Signature: job
    self.pre_batch_hooks: List[Callable[[Job], Any]] = []

    #: Hooks run before outputting the trace of a batch. Can modify trace entry.
    #: Signature: job, trace_entry
    self.post_batch_trace_hooks: List[Callable[[Job, Dict[str, Any]], Any]] = []

    #: Hooks run before outputting the trace of an epoch. Can modify trace entry.
    #: Signature: job, trace_entry
    self.post_epoch_trace_hooks: List[Callable[[Job, Dict[str, Any]], Any]] = []

    #: Hooks run after a validation job.
    #: Signature: job, trace_entry
    self.post_valid_hooks: List[Callable[[Job, Dict[str, Any]], Any]] = []

    #: Hooks run after training.
    #: Signature: job, trace_entry
    self.post_train_hooks: List[Callable[[Job, Dict[str, Any]], Any]] = []

    if self.__class__ == TrainingJob:
        for f in Job.job_created_hooks:
            f(self)

    self.model.train()
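# --- Hedged usage sketch (illustration only, not part of the original class) ---
# The hook lists initialized above let callers observe or extend training
# without subclassing. The hook signatures follow the comments in __init__
# (post-epoch hooks receive the job and the epoch's trace entry).
# `training_job` below stands in for any already-constructed TrainingJob and is
# an assumption of this example; a real hook would likely use the job's own
# logging/tracing facilities instead of print.

def _example_register_post_epoch_hook(training_job) -> None:
    """Illustrative only: attach a hook that reports each finished epoch."""

    def report_epoch(job, trace_entry):
        # trace_entry is the dict the job is about to trace for this epoch
        print("finished epoch", trace_entry.get("epoch"))

    training_job.post_epoch_hooks.append(report_epoch)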
def __init__(
    self,
    config,
    dataset,
    parent_job=None,
    model=None,
    optimizer=None,
    forward_only=False,
    parameter_client=None,
    init_for_load_only=False,
):
    self.parameter_client = parameter_client
    self.min_rank = get_min_rank(config)
    self.work_scheduler_client = SchedulerClient(config)
    (
        max_partition_entities,
        max_partition_relations,
    ) = self.work_scheduler_client.get_init_info()
    if model is None:
        model: KgeModel = KgeModel.create(
            config,
            dataset,
            parameter_client=parameter_client,
            max_partition_entities=max_partition_entities,
        )
        model.get_s_embedder().to_device()
        model.get_p_embedder().to_device()
    lapse_indexes = [
        torch.arange(dataset.num_entities(), dtype=torch.int),
        torch.arange(dataset.num_relations(), dtype=torch.int)
        + dataset.num_entities(),
    ]
    if optimizer is None:
        optimizer = KgeOptimizer.create(
            config,
            model,
            parameter_client=parameter_client,
            lapse_indexes=lapse_indexes,
        )
    # barrier to wait for loading of pretrained embeddings
    self.parameter_client.barrier()
    super().__init__(
        config,
        dataset,
        parent_job,
        model=model,
        optimizer=optimizer,
        forward_only=forward_only,
        parameter_client=parameter_client,
    )
    self.type_str = "negative_sampling"

    self.load_batch = self.config.get("job.distributed.load_batch")
    self.entity_localize = self.config.get("job.distributed.entity_localize")
    self.relation_localize = self.config.get("job.distributed.relation_localize")
    self.entity_partition_localized = False
    self.relation_partition_localized = False
    self.entity_async_write_back = self.config.get(
        "job.distributed.entity_async_write_back"
    )
    self.relation_async_write_back = self.config.get(
        "job.distributed.relation_async_write_back"
    )
    self.entity_sync_level = self.config.get("job.distributed.entity_sync_level")
    self.relation_sync_level = self.config.get(
        "job.distributed.relation_sync_level"
    )
    self.entity_pre_pull = self.config.get("job.distributed.entity_pre_pull")
    self.relation_pre_pull = self.config.get("job.distributed.relation_pre_pull")
    self.pre_localize_batch = int(
        self.config.get("job.distributed.pre_localize_batch")
    )

    self.entity_mapper_tensors = deque()
    for _ in range(self.config.get("train.num_workers") + 1):
        self.entity_mapper_tensors.append(
            torch.full((self.dataset.num_entities(),), -1, dtype=torch.long)
        )

    self._initialize_parameter_server(init_for_load_only=init_for_load_only)

    def stop_and_wait(job):
        job.parameter_client.stop()
        job.parameter_client.barrier()

    self.early_stop_hooks.append(stop_and_wait)

    def check_stopped(job):
        print("checking for", job.parameter_client.rank)
        job.parameter_client.barrier()
        return job.parameter_client.is_stopped()

    self.early_stop_conditions.append(check_stopped)

    self.work_pre_localized = False
    if self.config.get("job.distributed.pre_localize_partition"):
        self.pre_localized_entities = None
        self.pre_localized_relations = None
        self.pre_batch_hooks.append(self._pre_localize_work)

    if self.__class__ == TrainingJobNegativeSamplingDistributed:
        for f in Job.job_created_hooks:
            f(self)
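# --- Hedged illustration (not part of the original file) ---
# entity_mapper_tensors above pre-allocates one num_entities-sized scratch
# tensor per dataloader worker (train.num_workers) plus one for the main
# process. A plausible use, sketched only as an assumption about how such a
# buffer is typically employed: mapping a batch's global entity ids to dense
# local ids without allocating a fresh buffer on every batch. The helper name
# below is hypothetical; `torch` is the same module used above.

def _example_map_to_local_ids(job, global_entity_ids: torch.Tensor) -> torch.Tensor:
    """Illustrative only: translate global entity ids to local ids 0..n-1."""
    mapper = job.entity_mapper_tensors.popleft()  # borrow a scratch tensor
    unique_ids = torch.unique(global_entity_ids)
    mapper[unique_ids] = torch.arange(len(unique_ids), dtype=torch.long)
    local_ids = mapper[global_entity_ids]
    mapper[unique_ids] = -1  # reset the touched entries so the buffer can be reused
    job.entity_mapper_tensors.append(mapper)  # return the tensor to the pool
    return local_ids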