예제 #1
0
    def __init__(self, config, dataset, sampler, kg_sampler, shuffle=False):
        """Create the two wrapped dataloaders, then initialise the base class.

        Args:
            config: Configuration object forwarded to both wrapped dataloaders
                and to the base class.
            dataset: Dataset forwarded to both wrapped dataloaders and to the
                base class.
            sampler: Sampler used by the general (training) dataloader and by
                the base class.
            kg_sampler: Sampler used by the knowledge-graph dataloader.
            shuffle (bool, optional): Shuffle flag for the general dataloader
                and the base class. Defaults to ``False``.
        """
        # Loader built on ``sampler``; it honours the caller's shuffle flag.
        self.general_dataloader = TrainDataLoader(
            config, dataset, sampler, shuffle=shuffle
        )

        # Loader built on ``kg_sampler``; it is always shuffled.
        self.kg_dataloader = KGDataLoader(
            config, dataset, kg_sampler, shuffle=True
        )

        # The state starts out unset.
        self.state = None

        # The base initialiser runs last, after the wrapped loaders exist.
        super().__init__(config, dataset, sampler, shuffle=shuffle)
예제 #2
0
class KnowledgeBasedDataLoader(AbstractDataLoader):
    """:class:`KnowledgeBasedDataLoader` is used for knowledge based models.

    It wraps a general (user-item) dataloader and a knowledge-graph dataloader
    and has three states, stored in :attr:`state`. Depending on the state,
    iteration yields different :class:`~recbole.data.interaction.Interaction`
    batches. For details, see :attr:`~state`.

    Args:
        config (Config): The config of dataloader.
        dataset (Dataset): The dataset of dataloader.
        sampler (Sampler): The sampler of dataloader.
        kg_sampler (KGSampler): The knowledge graph sampler of dataloader.
        shuffle (bool, optional): Whether the general dataloader will be shuffled after a round.
            The knowledge-graph dataloader is always created with ``shuffle=True``.
            Defaults to ``False``.

    Attributes:
        state (KGDataLoaderState):
            This dataloader has three states:

                - :obj:`~recbole.utils.enum_type.KGDataLoaderState.RS`
                - :obj:`~recbole.utils.enum_type.KGDataLoaderState.KG`
                - :obj:`~recbole.utils.enum_type.KGDataLoaderState.RSKG`

            In the ``RS`` state, iteration is delegated to the general dataloader,
            so only user-item interaction batches are returned.

            In the ``KG`` state, iteration is delegated to the knowledge-graph
            dataloader, so only knowledge-graph batches are returned.

            In the ``RSKG`` state, each batch is a general-dataloader batch updated
            with the contents of a knowledge-graph batch.

            The state is ``None`` until :meth:`set_mode` is called.
        general_dataloader (TrainDataLoader): The wrapped loader built on ``sampler``.
        kg_dataloader (KGDataLoader): The wrapped loader built on ``kg_sampler``.
    """
    def __init__(self, config, dataset, sampler, kg_sampler, shuffle=False):

        # using sampler
        self.general_dataloader = TrainDataLoader(config,
                                                  dataset,
                                                  sampler,
                                                  shuffle=shuffle)

        # using kg_sampler
        self.kg_dataloader = KGDataLoader(config,
                                          dataset,
                                          kg_sampler,
                                          shuffle=True)

        # No state is selected yet; __iter__ refuses to run until set_mode() is called.
        self.state = None

        # NOTE(review): the wrapped loaders and ``state`` are created before the
        # base initialiser runs, presumably because the base class may call
        # hooks/properties defined here (e.g. ``pr_end``) -- confirm against
        # AbstractDataLoader before reordering.
        super().__init__(config, dataset, sampler, shuffle=shuffle)

    def _init_batch_size_and_step(self):
        """No-op override; this class does not set a batch size or step of its own."""
        pass

    def __iter__(self):
        """Return the iterator matching the current :attr:`state`.

        Returns:
            The knowledge-graph loader's iterator in the ``KG`` state, the general
            loader's iterator in the ``RS`` state, or ``self`` in the ``RSKG`` state.

        Raises:
            ValueError: If :attr:`state` is still ``None``.
        """
        if self.state is None:
            raise ValueError(
                'The dataloader\'s state must be set when using the kg based dataloader, '
                'you should call set_mode() before __iter__()')
        if self.state == KGDataLoaderState.KG:
            return self.kg_dataloader.__iter__()
        elif self.state == KGDataLoaderState.RS:
            return self.general_dataloader.__iter__()
        elif self.state == KGDataLoaderState.RSKG:
            # Start a round on both wrapped loaders (return values are ignored),
            # then iterate through this object's own __next__.
            self.kg_dataloader.__iter__()
            self.general_dataloader.__iter__()
            return self

    def _shuffle(self):
        """No-op override; this class performs no shuffling of its own."""
        pass

    def __next__(self):
        """Return the next combined batch.

        ``__iter__`` only returns ``self`` in the ``RSKG`` state, so this method
        drives iteration in that state. The round ends when the general loader's
        pointer reaches its end; both wrapped loaders' pointers are then reset.

        Raises:
            StopIteration: When the general dataloader is exhausted.
        """
        if self.general_dataloader.pr >= self.general_dataloader.pr_end:
            self.general_dataloader.pr = 0
            self.kg_dataloader.pr = 0
            raise StopIteration()
        return self._next_batch_data()

    def __len__(self):
        """Return the length of the knowledge-graph loader in the ``KG`` state,
        otherwise the length of the general loader."""
        if self.state == KGDataLoaderState.KG:
            return len(self.kg_dataloader)
        else:
            return len(self.general_dataloader)

    @property
    def pr_end(self):
        """``pr_end`` of the knowledge-graph loader in the ``KG`` state,
        otherwise ``pr_end`` of the general loader."""
        if self.state == KGDataLoaderState.KG:
            return self.kg_dataloader.pr_end
        else:
            return self.general_dataloader.pr_end

    def _next_batch_data(self):
        """Fetch one batch from each wrapped loader and merge them.

        Returns:
            The general loader's batch after ``update()`` has been called on it
            with the knowledge-graph loader's batch.
        """
        try:
            kg_data = self.kg_dataloader.__next__()
        except StopIteration:
            # The knowledge-graph loader ran out before the general loader did.
            # NOTE(review): the retry presumably relies on the loader resetting
            # itself when it raises StopIteration -- confirm against KGDataLoader.
            kg_data = self.kg_dataloader.__next__()
        rec_data = self.general_dataloader.__next__()
        rec_data.update(kg_data)
        return rec_data

    def set_mode(self, state):
        """Set the mode of :class:`KnowledgeBasedDataLoader`, it can be set to three states:

            - KGDataLoaderState.RS
            - KGDataLoaderState.KG
            - KGDataLoaderState.RSKG

        The state of :class:`KnowledgeBasedDataLoader` determines which batches iteration yields.

        Args:
            state (KGDataLoaderState): the state of :class:`KnowledgeBasedDataLoader`.

        Raises:
            NotImplementedError: If ``state`` is not a member of ``KGDataLoaderState``.
        """
        if state not in set(KGDataLoaderState):
            # Report the rejected argument, not the currently stored state.
            raise NotImplementedError(
                f'Kg data loader has no state named [{state}].')
        self.state = state