Example #1
    def test_normalize_phases(self):
        model = KgeModel.create(self.config, self.dataset)
        model.eval()
        num_entities = self.dataset.num_entities()
        num_relations = self.dataset.num_relations()

        # start with embeddings outside of [-pi,pi]
        data = model.get_p_embedder()._embeddings.weight.data
        data[:] = (torch.rand(data.shape) - 0.5) * 100

        # perform initial predictions
        s = torch.arange(num_entities).repeat_interleave(num_relations *
                                                         num_entities)
        p = (torch.arange(num_relations).repeat_interleave(
            num_entities).repeat(num_entities))
        o = torch.arange(num_entities).repeat(num_relations * num_entities)
        scores_org = model.score_spo(s, p, o)

        # now normalize phases
        model.normalize_phases()

        # check if predictions are unaffected
        scores_new = model.score_spo(s, p, o)
        self.assertTrue(
            torch.allclose(scores_org, scores_new),
            msg="test that normalizing phases does not change predictions",
        )

        # check that phases are normalized
        data = model.get_p_embedder()._embeddings.weight.data
        self.assertTrue(
            torch.all((data >= -math.pi) & (data < math.pi)),
            msg="check that phases are normalized",
        )
Example #2
    def __init__(
        self, config: Config, dataset: Dataset, parent_job: Job = None
    ) -> None:
        from kge.job import EvaluationJob

        super().__init__(config, dataset, parent_job)
        self.model: KgeModel = KgeModel.create(config, dataset)
        self.optimizer = KgeOptimizer.create(config, self.model)
        self.lr_scheduler, self.metric_based_scheduler = KgeLRScheduler.create(
            config, self.optimizer
        )
        self.loss = KgeLoss.create(config)
        self.batch_size: int = config.get("train.batch_size")
        self.device: str = self.config.get("job.device")
        valid_conf = config.clone()
        valid_conf.set("job.type", "eval")
        valid_conf.set("eval.data", "valid")
        valid_conf.set("eval.trace_level", self.config.get("valid.trace_level"))
        self.valid_job = EvaluationJob.create(
            valid_conf, dataset, parent_job=self, model=self.model
        )
        self.config.check("train.trace_level", ["batch", "epoch"])
        self.trace_batch: bool = self.config.get("train.trace_level") == "batch"
        self.epoch: int = 0
        self.valid_trace: List[Dict[str, Any]] = []
        self.is_prepared = False
        self.model.train()

        # attributes filled in by implementing classes
        self.loader = None
        self.num_examples = None
        self.type_str: Optional[str] = None

        #: Hooks run after training for an epoch.
        #: Signature: job, trace_entry
        self.post_epoch_hooks: List[Callable[[Job, Dict[str, Any]], Any]] = []

        #: Hooks run before starting a batch.
        #: Signature: job
        self.pre_batch_hooks: List[Callable[[Job], Any]] = []

        #: Hooks run before outputting the trace of a batch. Can modify trace entry.
        #: Signature: job, trace_entry
        self.post_batch_trace_hooks: List[Callable[[Job, Dict[str, Any]], Any]] = []

        #: Hooks run before outputting the trace of an epoch. Can modify trace entry.
        #: Signature: job, trace_entry
        self.post_epoch_trace_hooks: List[Callable[[Job, Dict[str, Any]], Any]] = []

        #: Hooks run after a validation job.
        #: Signature: job, trace_entry
        self.post_valid_hooks: List[Callable[[Job, Dict[str, Any]], Any]] = []

        #: Hooks run after training.
        #: Signature: job, trace_entry
        self.post_train_hooks: List[Callable[[Job, Dict[str, Any]], Any]] = []

        if self.__class__ == TrainingJob:
            for f in Job.job_created_hooks:
                f(self)
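As a usage note, these hooks are plain callables appended to the lists above. A hedged sketch (the job variable and the "avg_loss" trace key are assumptions for illustration):

def log_epoch_loss(job, trace_entry):
    # trace_entry is the epoch's trace dict; "avg_loss" is an assumed key
    job.config.log("epoch {}: loss {}".format(job.epoch, trace_entry.get("avg_loss")))

job.post_epoch_hooks.append(log_epoch_loss)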
Example #3
    def __init__(
        self,
        config: Config,
        dataset: Dataset,
        parent_job: Job = None,
        model=None,
        forward_only=False,
    ) -> None:
        from kge.job import EvaluationJob

        super().__init__(config, dataset, parent_job)
        if model is None:
            self.model: KgeModel = KgeModel.create(config, dataset)
        else:
            self.model: KgeModel = model
        self.loss = KgeLoss.create(config)
        self.abort_on_nan: bool = config.get("train.abort_on_nan")
        self.batch_size: int = config.get("train.batch_size")
        self._subbatch_auto_tune: bool = config.get("train.subbatch_auto_tune")
        self._max_subbatch_size: int = config.get("train.subbatch_size")
        self.device: str = self.config.get("job.device")
        self.train_split = config.get("train.split")

        self.config.check("train.trace_level", ["batch", "epoch"])
        self.trace_batch: bool = self.config.get(
            "train.trace_level") == "batch"
        self.epoch: int = 0
        self.is_forward_only = forward_only

        if not self.is_forward_only:
            self.model.train()
            self.optimizer = KgeOptimizer.create(config, self.model)
            self.kge_lr_scheduler = KgeLRScheduler(config, self.optimizer)

            self.valid_trace: List[Dict[str, Any]] = []
            valid_conf = config.clone()
            valid_conf.set("job.type", "eval")
            if self.config.get("valid.split") != "":
                valid_conf.set("eval.split", self.config.get("valid.split"))
            valid_conf.set("eval.trace_level",
                           self.config.get("valid.trace_level"))
            self.valid_job = EvaluationJob.create(valid_conf,
                                                  dataset,
                                                  parent_job=self,
                                                  model=self.model)

        # attributes filled in by implementing classes
        self.loader = None
        self.num_examples = None
        self.type_str: Optional[str] = None

        # Hooks run after validation. The corresponding valid trace entry can
        # be found in self.valid_trace[-1]. Signature: job
        self.post_valid_hooks: List[Callable[[Job], Any]] = []

        if self.__class__ == TrainingJob:
            for f in Job.job_created_hooks:
                f(self)
Example #4
 def setUp(self):
     self.config = create_config(self.dataset_name, model=self.model_name)
     self.config.set_all({"lookup_embedder.dim": 32})
     self.config.set_all(self.options)
     self.dataset_folder = get_dataset_folder(self.dataset_name)
     self.dataset = Dataset.create(self.config, folder=self.dataset_folder)
     self.model = KgeModel.create(self.config, self.dataset)
Example #5
 def setUp(self):
     self.config = create_config(self.dataset_name)
     self.config.set_all({"lookup_embedder.dim": 32})
     self.config.set("job.type", "train")
     self.config.set("train.type", self.train_type)
     self.config.set_all(self.options)
     self.dataset_folder = get_dataset_folder(self.dataset_name)
     self.dataset = Dataset.create(self.config, folder=self.dataset_folder)
     self.model = KgeModel.create(self.config, self.dataset)
Example #6
    def handle_validation(self, metric_name):
        # move all models to cpu and store as tmp model
        tmp_model = self.model.cpu()
        #self.valid_job.model = tmp_model
        del self.model
        if hasattr(self.valid_job, "model"):
            del self.valid_job.model
        gc.collect()
        with torch.cuda.device(self.device):
            torch.cuda.empty_cache()
        self.parameter_client.barrier()
        if self.parameter_client.rank == self.min_rank:
            # create a model for validation with entity embedder size
            #  batch_size x 2 + eval.chunk_size
            self.config.set(self.config.get("model") + ".create_eval", True)

            tmp_pretrain_model_filename = self.config.get("lookup_embedder.pretrain.model_filename")
            self.config.set("lookup_embedder.pretrain.model_filename", "")
            self.model = KgeModel.create(
                self.config, self.dataset, parameter_client=self.parameter_client
            )
            self.model.get_s_embedder().to_device(move_optim_data=False)
            self.model.get_p_embedder().to_device(move_optim_data=False)
            self.config.set("lookup_embedder.pretrain.model_filename", tmp_pretrain_model_filename)
            self.config.set(self.config.get("model") + ".create_eval", False)

            self.valid_job.model = self.model
            # validate and update learning rate
            super(TrainingJobNegativeSamplingDistributed, self).handle_validation(
                metric_name
            )

            # clean up valid model
            del self.model
            del self.valid_job.model
            gc.collect()
            with torch.cuda.device(self.device):
                torch.cuda.empty_cache()
        else:
            self.kge_lr_scheduler.step()
        self.parameter_client.barrier()
        self.model = tmp_model.to(self.device)
        del tmp_model
        gc.collect()
Example #7
    @classmethod
    def create_from(cls,
                    checkpoint: Dict,
                    new_config: Config = None,
                    dataset: Dataset = None,
                    parent_job=None,
                    parameter_client=None) -> Job:
        """
        Creates a Job based on a checkpoint
        Args:
            checkpoint: loaded checkpoint
            new_config: optional config object - overwrites options of config
                              stored in checkpoint
            dataset: dataset object
            parent_job: parent job (e.g. search job)

        Returns: Job based on checkpoint

        """
        from kge.model import KgeModel

        model: KgeModel = None
        # search jobs don't have a model
        if "model" in checkpoint and checkpoint["model"] is not None:
            model = KgeModel.create_from(checkpoint,
                                         new_config=new_config,
                                         dataset=dataset,
                                         parameter_client=parameter_client)
            config = model.config
            dataset = model.dataset
        else:
            config = Config.create_from(checkpoint)
            if new_config:
                config.load_config(new_config)
            dataset = Dataset.create_from(checkpoint, config, dataset)
        job = Job.create(config,
                         dataset,
                         parent_job,
                         model,
                         parameter_client=parameter_client,
                         init_for_load_only=True)
        job._load(checkpoint)
        job.config.log("Loaded checkpoint from {}...".format(
            checkpoint["file"]))
        return job
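A hedged usage sketch for create_from, assuming a checkpoint produced by an earlier training run (the path is illustrative):

from kge.util.io import load_checkpoint

checkpoint = load_checkpoint("checkpoint_best.pt")  # illustrative path
job = Job.create_from(checkpoint)  # restores config, dataset, and model
job.run()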
Example #8
    def load_distributed(self, checkpoint_name):
        """
        Separate function for loading distributed checkpoints.
        The main worker iterates over all checkpoint shards in the directory,
        loads them, and pushes them to the parameter server.
        Args:
            checkpoint_name: Path to the checkpoint

        Returns:
            None
        """
        from kge.distributed.misc import get_min_rank
        self.parameter_client.barrier()
        if self.parameter_client.rank == get_min_rank(self.config):
            if self.model is None:
                from kge.model import KgeModel
                self.model = KgeModel.create(
                    config=self.config,
                    dataset=self.dataset,
                    parameter_client=self.parameter_client)
            checkpoint_name, file_ending = checkpoint_name.rsplit(".", 1)
            entities_dir = checkpoint_name + "_entities"
            entities_ps_offset = self.model.get_s_embedder().lapse_offset
            for file in os.listdir(entities_dir):
                entity_start, entity_end = (
                    os.path.basename(file).split(".")[0].split("-"))
                push_tensor = torch.load(os.path.join(entities_dir, file))
                entity_ids = torch.arange(int(entity_start),
                                          int(entity_end),
                                          dtype=torch.long)
                self.parameter_client.push(entity_ids + entities_ps_offset,
                                           push_tensor)
            relations_ps_offset = self.model.get_p_embedder().lapse_offset
            push_tensor = torch.load(
                f"{checkpoint_name}_relations.{file_ending}")
            relation_ids = torch.arange(self.dataset.num_relations(),
                                        dtype=torch.long)
            self.parameter_client.push(relation_ids + relations_ps_offset,
                                       push_tensor)
        self.parameter_client.barrier()
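Reconstructed from the parsing above, the on-disk layout this loader assumes for a checkpoint named ckpt.pt (file names illustrative):

# ckpt_entities/0-500000.pt        embeddings for entity ids [0, 500000)
# ckpt_entities/500000-1000000.pt  embeddings for entity ids [500000, 1000000)
# ckpt_relations.pt                a single tensor with all relation embeddings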
Example #9
import torch
from kge.model import KgeModel
from kge.util.io import load_checkpoint
import numpy as np
# Link prediction performance of RESCAL, ComplEx, ConvE, DistMult and TransE;
# the checkpoints below are pretrained on FB15K-237 (out-of-vocabulary entities are removed)
models = ['rescal', 'complex', 'conve', 'distmult', 'transe']

for m in models:
    if m == 'conex':
        # unreachable branch: 'conex' is not in the models list above
        raise NotImplementedError()
    else:
        # 1. Load pretrained model via LibKGE
        checkpoint = load_checkpoint(
            f'pretrained_models/FB15K-237/fb15k-237-{m}.pt')
        model = KgeModel.create_from(checkpoint)

        # 3. Create mappings.
        # 3.1 Entity index mapping.
        entity_idxs = {
            e: e_idx
            for e_idx, e in enumerate(model.dataset.entity_ids())
        }
        # 3.2 Relation index mapping.
        relation_idxs = {
            r: r_idx
            for r_idx, r in enumerate(model.dataset.relation_ids())
        }
    # 2. Load dataset (code not included in this snippet)
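With the mappings in place, a hedged sketch of scoring a single triple (the identifiers are illustrative placeholders, not guaranteed to be in the vocabulary):

s = torch.tensor([entity_idxs["/m/02mjmr"]])                        # head (illustrative)
p = torch.tensor([relation_idxs["/people/person/place_of_birth"]])  # relation (illustrative)
o = torch.tensor([entity_idxs["/m/02hrh0_"]])                       # tail (illustrative)
score = model.score_spo(s, p, o, direction="o")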
Example #10
def train(data_path,
          neg_batch_size,
          batch_size,
          shuffle,
          num_workers,
          nb_epochs,
          embedding_dim,
          hidden_dim,
          relation_dim,
          gpu,
          use_cuda,
          patience,
          freeze,
          validate_every,
          hops,
          lr,
          entdrop,
          reldrop,
          scoredrop,
          l3_reg,
          model_name,
          decay,
          ls,
          load_from,
          outfile,
          do_batch_norm,
          valid_data_path=None):
    print('Loading entities and relations')
    kg_type = 'full'
    if 'half' in hops:
        kg_type = 'half'
    checkpoint_file = '../../pretrained_models/embeddings/ComplEx_fbwq_' + kg_type + '/checkpoint_best.pt'
    print('Loading kg embeddings from', checkpoint_file)
    kge_checkpoint = load_checkpoint(checkpoint_file)
    kge_model = KgeModel.create_from(kge_checkpoint)
    kge_model.eval()
    e = getEntityEmbeddings(kge_model, hops)

    print('Loaded entities and relations')

    entity2idx, idx2entity, embedding_matrix = prepare_embeddings(e)
    data = process_text_file(data_path, split=False)
    print('Train file processed, making dataloader')
    # word2ix,idx2word, max_len = get_vocab(data)
    # hops = str(num_hops)
    device = torch.device(gpu if use_cuda else "cpu")
    dataset = DatasetMetaQA(data, e, entity2idx)
    data_loader = DataLoader(dataset,
                             batch_size=batch_size,
                             shuffle=True,
                             num_workers=num_workers)
    print('Creating model...')
    model = RelationExtractor(embedding_dim=embedding_dim,
                              num_entities=len(idx2entity),
                              relation_dim=relation_dim,
                              pretrained_embeddings=embedding_matrix,
                              freeze=freeze,
                              device=device,
                              entdrop=entdrop,
                              reldrop=reldrop,
                              scoredrop=scoredrop,
                              l3_reg=l3_reg,
                              model=model_name,
                              ls=ls,
                              do_batch_norm=do_batch_norm)
    print('Model created!')
    if load_from != '':
        # model.load_state_dict(torch.load("checkpoints/roberta_finetune/" + load_from + ".pt"))
        fname = "checkpoints/roberta_finetune/" + load_from + ".pt"
        model.load_state_dict(
            torch.load(fname, map_location=lambda storage, loc: storage))
    model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    scheduler = ExponentialLR(optimizer, decay)
    optimizer.zero_grad()
    best_score = -float("inf")
    best_model = model.state_dict()
    no_update = 0
    # time.sleep(10)
    for epoch in range(nb_epochs):
        phases = []
        for i in range(validate_every):
            phases.append('train')
        phases.append('valid')
        for phase in phases:
            if phase == 'train':
                model.train()
                # model.apply(set_bn_eval)
                loader = tqdm(data_loader,
                              total=len(data_loader),
                              unit="batches")
                running_loss = 0
                for i_batch, a in enumerate(loader):
                    model.zero_grad()
                    question_tokenized = a[0].to(device)
                    attention_mask = a[1].to(device)
                    positive_head = a[2].to(device)
                    positive_tail = a[3].to(device)
                    loss = model(question_tokenized=question_tokenized,
                                 attention_mask=attention_mask,
                                 p_head=positive_head,
                                 p_tail=positive_tail)
                    loss.backward()
                    optimizer.step()
                    running_loss += loss.item()
                    loader.set_postfix(Loss=running_loss /
                                       ((i_batch + 1) * batch_size),
                                       Epoch=epoch)
                    loader.set_description('{}/{}'.format(epoch, nb_epochs))
                    loader.update()

                scheduler.step()

            elif phase == 'valid':
                model.eval()
                eps = 0.0001
                answers, score = validate_v2(model=model,
                                             data_path=valid_data_path,
                                             entity2idx=entity2idx,
                                             train_dataloader=dataset,
                                             device=device,
                                             model_name=model_name)
                if score > best_score + eps:
                    best_score = score
                    no_update = 0
                    best_model = model.state_dict()
                    print(
                        hops +
                        " hop Validation accuracy (no relation scoring) increased from previous epoch",
                        score)
                    # writeToFile(answers, 'results_' + model_name + '_' + hops + '.txt')
                    # torch.save(best_model, "checkpoints/roberta_finetune/best_score_model.pt")
                    # torch.save(best_model, "checkpoints/roberta_finetune/" + outfile + ".pt")
                elif (score < best_score + eps) and (no_update < patience):
                    no_update += 1
                    print(
                        "Validation accuracy decreased to %f from %f, %d more epochs to check"
                        % (score, best_score, patience - no_update))
                elif no_update == patience:
                    print(
                        "Model has exceeded patience. Saving best model and exiting."
                    )
                    # torch.save(best_model, "checkpoints/roberta_finetune/best_score_model.pt")
                    # torch.save(best_model, "checkpoints/roberta_finetune/" + outfile + ".pt")
                    exit()
                if epoch == nb_epochs - 1:
                    print("Final Epoch has reached. Stoping and saving model.")
                    # torch.save(best_model, "checkpoints/roberta_finetune/best_score_model.pt")
                    # torch.save(best_model, "checkpoints/roberta_finetune/" + outfile + ".pt")
                    exit()
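A small design note: the phase schedule built at the top of each epoch is equivalent to the one-liner below, running validate_every training passes before each validation pass.

phases = ["train"] * validate_every + ["valid"]  # equivalent schedule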
Example #11
def perform_experiment(data_path,
                       mode,
                       neg_batch_size,
                       batch_size,
                       shuffle,
                       num_workers,
                       nb_epochs,
                       embedding_dim,
                       hidden_dim,
                       relation_dim,
                       gpu,
                       use_cuda,
                       patience,
                       freeze,
                       validate_every,
                       hops,
                       lr,
                       entdrop,
                       reldrop,
                       scoredrop,
                       l3_reg,
                       model_name,
                       decay,
                       ls,
                       load_from,
                       outfile,
                       do_batch_norm,
                       que_embedding_model,
                       valid_data_path=None,
                       test_data_path=None):
    webqsp_checkpoint_folder = f"../../checkpoints/WebQSP/{model_name}_{que_embedding_model}_{outfile}/"
    if not os.path.exists(webqsp_checkpoint_folder):
        os.makedirs(webqsp_checkpoint_folder)

    print('Loading entities and relations')
    kg_type = 'full'
    if 'half' in hops:
        kg_type = 'half'

    checkpoint_file = f"../../pretrained_models/embeddings/{model_name}_fbwq_{kg_type}/checkpoint_best.pt"

    print('Loading kg embeddings from', checkpoint_file)
    kge_checkpoint = load_checkpoint(checkpoint_file)
    kge_model = KgeModel.create_from(kge_checkpoint)
    kge_model.eval()
    e = getEntityEmbeddings(model_name, kge_model, hops)

    print('Loaded entities and relations')

    entity2idx, idx2entity, embedding_matrix = prepare_embeddings(e)

    # word2ix,idx2word, max_len = get_vocab(data)
    # hops = str(num_hops)
    device = torch.device(gpu if use_cuda else "cpu")
    model = RelationExtractor(embedding_dim=embedding_dim,
                              num_entities=len(idx2entity),
                              relation_dim=relation_dim,
                              pretrained_embeddings=embedding_matrix,
                              freeze=freeze,
                              device=device,
                              entdrop=entdrop,
                              reldrop=reldrop,
                              scoredrop=scoredrop,
                              l3_reg=l3_reg,
                              model=model_name,
                              que_embedding_model=que_embedding_model,
                              ls=ls,
                              do_batch_norm=do_batch_norm)

    # time.sleep(10)
    if mode == 'train':
        data = process_text_file(data_path)
        dataset = DatasetWebQSP(data, e, entity2idx, que_embedding_model,
                                model_name)

        # if model_name=="ComplEx":
        #     data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
        # else:
        #     data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers, collate_fn=custom_collate_fn)

        data_loader = DataLoader(dataset,
                                 batch_size=batch_size,
                                 shuffle=True,
                                 num_workers=num_workers)

        if load_from != '':
            # model.load_state_dict(torch.load("checkpoints/roberta_finetune/" + load_from + ".pt"))
            fname = f"checkpoints/{que_embedding_model}_finetune/{load_from}.pt"
            model.load_state_dict(
                torch.load(fname, map_location=lambda storage, loc: storage))
        model.to(device)
        optimizer = torch.optim.Adam(model.parameters(), lr=lr)
        scheduler = ExponentialLR(optimizer, decay)
        optimizer.zero_grad()
        best_score = -float("inf")
        best_model = model.state_dict()
        no_update = 0
        for epoch in range(nb_epochs):
            phases = []
            for i in range(validate_every):
                phases.append('train')
            phases.append('valid')
            for phase in phases:
                if phase == 'train':
                    model.train()
                    # model.apply(set_bn_eval)
                    loader = tqdm(data_loader,
                                  total=len(data_loader),
                                  unit="batches")
                    running_loss = 0
                    for i_batch, a in enumerate(loader):
                        model.zero_grad()
                        question_tokenized = a[0].to(device)
                        attention_mask = a[1].to(device)
                        positive_head = a[2].to(device)
                        positive_tail = a[3].to(device)
                        loss = model(question_tokenized=question_tokenized,
                                     attention_mask=attention_mask,
                                     p_head=positive_head,
                                     p_tail=positive_tail)
                        loss.backward()
                        optimizer.step()
                        running_loss += loss.item()
                        loader.set_postfix(Loss=running_loss /
                                           ((i_batch + 1) * batch_size),
                                           Epoch=epoch)
                        loader.set_description('{}/{}'.format(
                            epoch, nb_epochs))
                        loader.update()

                    scheduler.step()

                elif phase == 'valid':
                    model.eval()
                    eps = 0.0001
                    answers, score = test(model=model,
                                          data_path=valid_data_path,
                                          entity2idx=entity2idx,
                                          dataloader=dataset,
                                          device=device,
                                          model_name=model_name,
                                          return_hits_at_k=False)
                    if score > best_score + eps:
                        best_score = score
                        no_update = 0
                        best_model = model.state_dict()
                        print(
                            hops +
                            " hop Validation accuracy (no relation scoring) increased from previous epoch",
                            score)
                        writeToFile(
                            answers,
                            f'results/{model_name}_{que_embedding_model}_{outfile}.txt'
                        )
                        torch.save(
                            best_model,
                            get_chkpt_path(model_name, que_embedding_model,
                                           outfile))
                    elif (score < best_score + eps) and (no_update < patience):
                        no_update += 1
                        print(
                            "Validation accuracy decreased to %f from %f, %d more epochs to check"
                            % (score, best_score, patience - no_update))
                    elif no_update == patience:
                        print(
                            "Model has exceeded patience. Saving best model and exiting."
                        )
                        torch.save(
                            best_model,
                            get_chkpt_path(model_name, que_embedding_model,
                                           outfile))
                        exit(0)
                    if epoch == nb_epochs - 1:
                        print(
                            "Final epoch reached. Stopping and saving model."
                        )
                        torch.save(
                            best_model,
                            get_chkpt_path(model_name, que_embedding_model,
                                           outfile))
                        exit()
                    # torch.save(model.state_dict(), "checkpoints/roberta_finetune/"+str(epoch)+".pt")
                    # torch.save(model.state_dict(), "checkpoints/roberta_finetune/x.pt")

    elif mode == 'test':
        data = process_text_file(test_data_path)
        dataset = DatasetWebQSP(data, e, entity2idx, que_embedding_model,
                                model_name)
        model_chkpt_file_path = get_chkpt_path(model_name, que_embedding_model,
                                               outfile)
        model.load_state_dict(
            torch.load(model_chkpt_file_path,
                       map_location=lambda storage, loc: storage))
        model.to(device)
        for parameter in model.parameters():
            parameter.requires_grad = False
        model.eval()
        answers, accuracy, hits_at_1, hits_at_5, hits_at_10 = test(
            model=model,
            data_path=test_data_path,
            entity2idx=entity2idx,
            dataloader=dataset,
            device=device,
            model_name=model_name,
            return_hits_at_k=True)

        d = {
            'KG-Model': model_name,
            'KG-Type': kg_type,
            'Que-Embedding-Model': que_embedding_model,
            'Accuracy': [accuracy],
            'Hits@1': [hits_at_1],
            'Hits@5': [hits_at_5],
            'Hits@10': [hits_at_10]
        }
        df = pd.DataFrame(data=d)
        df.to_csv(f"final_results.csv", mode='a', index=False, header=False)
Example #12
    def __init__(
        self,
        config,
        dataset,
        parent_job=None,
        model=None,
        optimizer=None,
        forward_only=False,
        parameter_client=None,
        init_for_load_only=False,
    ):
        self.parameter_client = parameter_client
        self.min_rank = get_min_rank(config)

        self.work_scheduler_client = SchedulerClient(config)
        (
            max_partition_entities,
            max_partition_relations,
        ) = self.work_scheduler_client.get_init_info()
        if model is None:
            model: KgeModel = KgeModel.create(
                config,
                dataset,
                parameter_client=parameter_client,
                max_partition_entities=max_partition_entities,
            )
        model.get_s_embedder().to_device()
        model.get_p_embedder().to_device()
        lapse_indexes = [
            torch.arange(dataset.num_entities(), dtype=torch.int),
            torch.arange(dataset.num_relations(), dtype=torch.int)
            + dataset.num_entities(),
        ]
        if optimizer is None:
            optimizer = KgeOptimizer.create(
                config,
                model,
                parameter_client=parameter_client,
                lapse_indexes=lapse_indexes,
            )
        # barrier to wait for loading of pretrained embeddings
        self.parameter_client.barrier()
        super().__init__(
            config,
            dataset,
            parent_job,
            model=model,
            optimizer=optimizer,
            forward_only=forward_only,
            parameter_client=parameter_client,
        )
        self.type_str = "negative_sampling"
        self.load_batch = self.config.get("job.distributed.load_batch")
        self.entity_localize = self.config.get("job.distributed.entity_localize")
        self.relation_localize = self.config.get("job.distributed.relation_localize")
        self.entity_partition_localized = False
        self.relation_partition_localized = False
        self.entity_async_write_back = self.config.get(
            "job.distributed.entity_async_write_back"
        )
        self.relation_async_write_back = self.config.get(
            "job.distributed.relation_async_write_back"
        )
        self.entity_sync_level = self.config.get("job.distributed.entity_sync_level")
        self.relation_sync_level = self.config.get(
            "job.distributed.relation_sync_level"
        )
        self.entity_pre_pull = self.config.get("job.distributed.entity_pre_pull")
        self.relation_pre_pull = self.config.get("job.distributed.relation_pre_pull")
        self.pre_localize_batch = int(
            self.config.get("job.distributed.pre_localize_batch")
        )
        self.entity_mapper_tensors = deque()
        for i in range(self.config.get("train.num_workers") + 1):
            self.entity_mapper_tensors.append(
                torch.full((self.dataset.num_entities(),), -1, dtype=torch.long)
            )

        self._initialize_parameter_server(init_for_load_only=init_for_load_only)

        def stop_and_wait(job):
            job.parameter_client.stop()
            job.parameter_client.barrier()
        self.early_stop_hooks.append(stop_and_wait)

        def check_stopped(job):
            print("checking for", job.parameter_client.rank)
            job.parameter_client.barrier()
            return job.parameter_client.is_stopped()
        self.early_stop_conditions.append(check_stopped)
        self.work_pre_localized = False
        if self.config.get("job.distributed.pre_localize_partition"):
            self.pre_localized_entities = None
            self.pre_localized_relations = None
            self.pre_batch_hooks.append(self._pre_localize_work)

        if self.__class__ == TrainingJobNegativeSamplingDistributed:
            for f in Job.job_created_hooks:
                f(self)