Example #1
    def _prepare(self):
        super()._prepare()
        # select negative sampling implementation
        self._implementation = self.config.check(
            "negative_sampling.implementation",
            ["triple", "all", "batch", "auto"],
        )
        if self._implementation == "auto":
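            # heuristic: use "batch" scoring when negatives are shared or
            # numerous, and "triple" scoring when there are only a few
            # negatives per positive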
            max_nr_of_negs = max(self._sampler.num_samples)
            if self._sampler.shared:
                self._implementation = "batch"
            elif max_nr_of_negs <= 30:
                self._implementation = "triple"
            else:
                self._implementation = "batch"
            self.config.set("negative_sampling.implementation",
                            self._implementation,
                            log=True)

        self.config.log("Preparing negative sampling training job with "
                        "'{}' scoring function ...".format(
                            self._implementation))

        # construct dataloader
        self.num_examples = self.dataset.split(self.train_split).size(0)
        self.loader = torch.utils.data.DataLoader(
            range(self.num_examples),
            collate_fn=self._get_collate_fun(),
            shuffle=True,
            batch_size=self.batch_size,
            num_workers=self.config.get("train.num_workers"),
            worker_init_fn=_generate_worker_init_fn(self.config),
            pin_memory=self.config.get("train.pin_memory"),
        )
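
All of the examples here pass worker_init_fn=_generate_worker_init_fn(self.config) to the DataLoader without showing the helper itself. Below is a minimal sketch of such a factory, assuming its job is simply to give every DataLoader worker process its own NumPy seed; the random_seed.numpy config key and the negative-means-unseeded convention are assumptions, not the library's exact implementation.

import numpy as np


def _generate_worker_init_fn(config):
    """Return a worker_init_fn that gives each DataLoader worker its own NumPy seed."""
    # assumed config key; a negative value is taken to mean "no fixed seed"
    base_seed = config.get("random_seed.numpy")

    def worker_init_fn(worker_num):
        if base_seed >= 0:
            # deterministic, but different for every worker
            np.random.seed(base_seed + worker_num)
        else:
            # fresh, OS-provided entropy per worker
            np.random.seed()

    return worker_init_fn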
Example #2
    def _prepare(self):
        """Construct dataloader"""
        super()._prepare()

        self.num_examples = self.dataset.split(self.train_split).size(0)
        self.loader = torch.utils.data.DataLoader(
            range(self.num_examples),
            collate_fn=self._get_collate_fun(),
            shuffle=True,
            batch_size=self.batch_size,
            num_workers=self.config.get("train.num_workers"),
            worker_init_fn=_generate_worker_init_fn(self.config),
            pin_memory=self.config.get("train.pin_memory"),
        )
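
The DataLoader in this job iterates over plain example indices (range(self.num_examples)), so the collate function returned by _get_collate_fun() is what turns a list of indices into an actual batch. A minimal sketch of what such a collate function could look like, assuming the training split is an n x 3 tensor of (subject, predicate, object) ids and that batches are returned as dictionaries; the exact batch layout is an assumption.

    def _get_collate_fun(self):
        # assumed layout: one row per training triple
        triples = self.dataset.split(self.train_split)

        def collate(batch):
            # batch is the list of integer example indices picked by the DataLoader
            indexes = torch.tensor(batch, dtype=torch.long)
            return {"triples": triples[indexes, :].long()}

        return collate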
Example #3
    def _init_dataloader(self):
        mp_context = (
            torch.multiprocessing.get_context("fork")
            if self.config.get("train.num_workers") > 0
            else None
        )
        self.loader = torch.utils.data.DataLoader(
            self.dataloader_dataset,
            sampler=InfiniteSequentialSampler(self.dataloader_dataset),
            collate_fn=self._get_collate_fun(),
            # shuffle must be False and batch_size is left at its default,
            # since both shuffling and batching are handled by the dataset object
            shuffle=False,
            num_workers=self.config.get("train.num_workers"),
            worker_init_fn=_generate_worker_init_fn(self.config),
            pin_memory=self.config.get("train.pin_memory"),
        )
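
InfiniteSequentialSampler is not a stock PyTorch sampler but comes from the surrounding codebase. The sketch below shows a sampler with the behavior the arguments suggest: it yields dataset indices in order and never runs out, so that batching and shuffling can be handled entirely by the dataset object. This is an assumption about its contract, not the original implementation.

import itertools

import torch.utils.data


class InfiniteSequentialSampler(torch.utils.data.Sampler):
    """Yields indices 0, 1, ..., len(data_source) - 1 over and over again."""

    def __init__(self, data_source):
        self.data_source = data_source

    def __iter__(self):
        # cycle endlessly; the training loop decides when to stop consuming batches
        return itertools.cycle(range(len(self.data_source)))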
Example #4
    def _prepare(self):
        super()._prepare()
        # determine enabled query types
        self.query_types = [
            key
            for key, enabled in self.config.get("KvsAll.query_types").items()
            if enabled
        ]

        # corresponding indexes
        self.query_indexes: List[KvsAllIndex] = []

        #: for each query type (ordered as in self.query_types), index right after last
        #: example of that type in the list of all examples (over all query types)
        self.query_last_example = []

        # construct relevant data structures
        self.num_examples = 0
        for query_type in self.query_types:
            index_type = ("sp_to_o" if query_type == "sp_" else
                          ("so_to_p" if query_type == "s_o" else "po_to_s"))
            index = self.dataset.index(f"{self.train_split}_{index_type}")
            self.query_indexes.append(index)
            self.num_examples += len(index)
            self.query_last_example.append(self.num_examples)

        # create dataloader
        self.loader = torch.utils.data.DataLoader(
            range(self.num_examples),
            collate_fn=self._get_collate_fun(),
            shuffle=True,
            batch_size=self.batch_size,
            num_workers=self.config.get("train.num_workers"),
            worker_init_fn=_generate_worker_init_fn(self.config),
            pin_memory=self.config.get("train.pin_memory"),
        )
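
Because all enabled query types are concatenated into one global example index range, a batch element has to be mapped back to its query type and to its position within that type's index before the corresponding queries and labels can be looked up. query_last_example stores the exclusive end offset of each type, so a binary search is enough. A minimal sketch with a made-up helper name; it is not part of the original class.

    def _lookup_query_type(self, example_index):
        # hypothetical helper: map a global example index to (query type, local index)
        import bisect

        # query_last_example[i] is the index right after the last example of query
        # type i, so bisect_right finds the type whose range contains example_index
        type_pos = bisect.bisect_right(self.query_last_example, example_index)
        query_type = self.query_types[type_pos]
        start = self.query_last_example[type_pos - 1] if type_pos > 0 else 0
        return query_type, example_index - start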