예제 #1
0
    def __init__(self, config, dataset, sampler, kg_sampler, neg_sample_args,
                 batch_size=1, dl_format=InputType.POINTWISE, shuffle=False):
        """Build the two wrapped dataloaders, then initialise the base dataloader.

        Args:
            config (Config): The config of dataloader.
            dataset (Dataset): The dataset of dataloader.
            sampler (Sampler): The sampler handed to the negative-sampling dataloader.
            kg_sampler (KGSampler): The sampler handed to the knowledge graph dataloader.
            neg_sample_args (dict): The neg_sample_args of the negative-sampling dataloader.
            batch_size (int, optional): The batch_size of dataloader. Defaults to ``1``.
            dl_format (InputType, optional): The input type of dataloader. Defaults to
                ``InputType.POINTWISE``.
            shuffle (bool, optional): Forwarded to both wrapped dataloaders and the base class.
                Defaults to ``False``.
        """
        # Options passed unchanged to the negative-sampling loader and to the base class.
        shared_options = dict(batch_size=batch_size,
                              dl_format=dl_format,
                              shuffle=shuffle)

        # Loader driven by ``sampler``.
        self.general_dataloader = GeneralNegSampleDataLoader(
            config=config,
            dataset=dataset,
            sampler=sampler,
            neg_sample_args=neg_sample_args,
            **shared_options)

        # Loader driven by ``kg_sampler``; its format is always pairwise,
        # regardless of the ``dl_format`` argument.
        self.kg_dataloader = KGDataLoader(
            config, dataset, kg_sampler,
            batch_size=batch_size,
            dl_format=InputType.PAIRWISE,
            shuffle=shuffle)

        # The negative-sampling loader is the active one until changed elsewhere.
        self.main_dataloader = self.general_dataloader

        super().__init__(config, dataset, **shared_options)
예제 #2
0
class KnowledgeBasedDataLoader(AbstractDataLoader):
    """:class:`KnowledgeBasedDataLoader` is used for knowledge based model.

    It wraps two dataloaders, :attr:`general_dataloader` and :attr:`kg_dataloader`,
    and delegates to one or both of them depending on :attr:`state`.

    It has three states, which is saved in :attr:`state`.
    In different states, :meth:`~_next_batch_data` will return different :class:`~recbole.data.interaction.Interaction`.
    Detailed, please see :attr:`~state`.

    Note:
        :attr:`state` is not assigned by ``__init__``; :meth:`set_mode` must be called
        before iterating, otherwise :meth:`__iter__` raises ``ValueError``.

    Args:
        config (Config): The config of dataloader.
        dataset (Dataset): The dataset of dataloader.
        sampler (Sampler): The sampler of dataloader.
        kg_sampler (KGSampler): The knowledge graph sampler of dataloader.
        neg_sample_args (dict): The neg_sample_args of dataloader.
        batch_size (int, optional): The batch_size of dataloader. Defaults to ``1``.
        dl_format (InputType, optional): The input type of dataloader. Defaults to
            :obj:`~recbole.utils.enum_type.InputType.POINTWISE`.
        shuffle (bool, optional): Whether the dataloader will be shuffle after a round. Defaults to ``False``.

    Attributes:
        state (KGDataLoaderState):
            This dataloader has three states:

                - :obj:`~recbole.utils.enum_type.KGDataLoaderState.RS`
                - :obj:`~recbole.utils.enum_type.KGDataLoaderState.KG`
                - :obj:`~recbole.utils.enum_type.KGDataLoaderState.RSKG`

            In the first state (``RS``), this dataloader would only return the user-item interaction.

            In the second state (``KG``), this dataloader would only return the triplets with
            negative examples in a knowledge graph.

            In the last state (``RSKG``), this dataloader would return both knowledge graph information
            and user-item interaction information.
    """
    def __init__(self,
                 config,
                 dataset,
                 sampler,
                 kg_sampler,
                 neg_sample_args,
                 batch_size=1,
                 dl_format=InputType.POINTWISE,
                 shuffle=False):

        # using sampler
        self.general_dataloader = GeneralNegSampleDataLoader(
            config=config,
            dataset=dataset,
            sampler=sampler,
            neg_sample_args=neg_sample_args,
            batch_size=batch_size,
            dl_format=dl_format,
            shuffle=shuffle)

        # using kg_sampler and dl_format is pairwise
        self.kg_dataloader = KGDataLoader(config,
                                          dataset,
                                          kg_sampler,
                                          batch_size=batch_size,
                                          dl_format=InputType.PAIRWISE,
                                          shuffle=shuffle)

        # Default to the recommendation loader until set_mode() chooses otherwise.
        self.main_dataloader = self.general_dataloader

        super().__init__(config,
                         dataset,
                         batch_size=batch_size,
                         dl_format=dl_format,
                         shuffle=shuffle)

    @property
    def pr(self):
        """Pointer of :class:`KnowledgeBasedDataLoader`. It would be affect by self.state.

        Reads and writes are forwarded to :attr:`main_dataloader`.
        """
        return self.main_dataloader.pr

    @pr.setter
    def pr(self, value):
        self.main_dataloader.pr = value

    def __iter__(self):
        """Return the iterator of the base class.

        Raises:
            ValueError: If :attr:`state` or :attr:`main_dataloader` has not been set.
        """
        if not hasattr(self, 'state') or not hasattr(self, 'main_dataloader'):
            raise ValueError(
                'The dataloader\'s state and main_dataloader must be set '
                'when using the kg based dataloader')
        return super().__iter__()

    def _shuffle(self):
        """Shuffle both wrapped dataloaders in ``RSKG`` state, otherwise only the main one."""
        if self.state == KGDataLoaderState.RSKG:
            self.general_dataloader._shuffle()
            self.kg_dataloader._shuffle()
        else:
            self.main_dataloader._shuffle()

    def __next__(self):
        """Return the next batch, or reset the pointer(s) and stop at the end of a round.

        Raises:
            StopIteration: When the main dataloader's pointer has reached :attr:`pr_end`.
        """
        if self.pr >= self.pr_end:
            # In RSKG state both loaders were advanced, so both pointers are rewound.
            if self.state == KGDataLoaderState.RSKG:
                self.general_dataloader.pr = 0
                self.kg_dataloader.pr = 0
            else:
                self.pr = 0
            raise StopIteration()
        return self._next_batch_data()

    def __len__(self):
        """Return the length of the main dataloader."""
        return len(self.main_dataloader)

    @property
    def pr_end(self):
        """End position of the pointer, taken from :attr:`main_dataloader`."""
        return self.main_dataloader.pr_end

    def _next_batch_data(self):
        """Return the next batch according to :attr:`state`.

        - ``KG``: the next batch of :attr:`kg_dataloader`.
        - ``RS``: the next batch of :attr:`general_dataloader`.
        - ``RSKG``: the next batch of :attr:`general_dataloader`, updated in place
          with the next batch of :attr:`kg_dataloader`.

        For any other state value, nothing is returned (``None``).
        """
        if self.state == KGDataLoaderState.KG:
            return self.kg_dataloader._next_batch_data()
        elif self.state == KGDataLoaderState.RS:
            return self.general_dataloader._next_batch_data()
        elif self.state == KGDataLoaderState.RSKG:
            # Wrap the KG pointer so the KG loader keeps producing batches
            # when it is exhausted before the round ends.
            if self.kg_dataloader.pr >= self.kg_dataloader.pr_end:
                self.kg_dataloader.pr = 0
            kg_data = self.kg_dataloader._next_batch_data()
            rec_data = self.general_dataloader._next_batch_data()
            rec_data.update(kg_data)
            return rec_data

    def set_mode(self, state):
        """Set the mode of :class:`KnowledgeBasedDataLoader`, it can be set to three states:

            - KGDataLoaderState.RS
            - KGDataLoaderState.KG
            - KGDataLoaderState.RSKG

        The state of :class:`KnowledgeBasedDataLoader` would affect the result of _next_batch_data().
        It also selects :attr:`main_dataloader`: the general dataloader for ``RS``, the KG
        dataloader for ``KG``, and for ``RSKG`` the general dataloader when its ``pr_end`` is
        strictly smaller than the KG dataloader's, otherwise the KG dataloader.

        Args:
            state (KGDataLoaderState): the state of :class:`KnowledgeBasedDataLoader`.

        Raises:
            NotImplementedError: If ``state`` is not a member of ``KGDataLoaderState``.
        """
        if state not in set(KGDataLoaderState):
            # Report the rejected argument. ``self.state`` may not exist yet and,
            # when it does, it holds the previous valid state, not the bad value.
            raise NotImplementedError(
                'kg data loader has no state named [{}]'.format(state))
        self.state = state
        if self.state == KGDataLoaderState.RS:
            self.main_dataloader = self.general_dataloader
        elif self.state == KGDataLoaderState.KG:
            self.main_dataloader = self.kg_dataloader
        else:  # RSKG
            kgpr = self.kg_dataloader.pr_end
            rspr = self.general_dataloader.pr_end
            self.main_dataloader = self.general_dataloader if rspr < kgpr else self.kg_dataloader