Example #1
    def _get_train_sampler(self) -> Optional[torch.utils.data.sampler.Sampler]:
        # Iterable-style or unsized datasets cannot be sampled by index.
        if isinstance(self.train_dataset, torch.utils.data.IterableDataset) or not isinstance(
                self.train_dataset, collections.abc.Sized):
            return None

        # Build the sampler.
        if self.args.group_by_length:
            # lengths = self.train_dataset[self.length_field_name] if self.length_field_name is not None else None
            model_input_name = (self.tokenizer.model_input_names[0]
                                if self.tokenizer is not None else None)
            if self.args.world_size <= 1:
                return LengthGroupedSampler(
                    self.train_dataset,
                    self.args.train_batch_size,
                    lengths=self.train_seq_lengths,
                    model_input_name=model_input_name,
                )
            else:
                # Distributed training: each replica samples its own shard.
                return DistributedLengthGroupedSampler(
                    self.train_dataset,
                    self.args.train_batch_size,
                    num_replicas=self.args.world_size,
                    rank=self.args.process_index,
                    lengths=self.train_seq_lengths,
                    model_input_name=model_input_name,
                )
        else:
            return super()._get_train_sampler()
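For context, here is roughly how a Trainer subclass consumes this sampler when building its training dataloader. This is a minimal sketch rather than the exact Trainer code: it assumes a `self.data_collator` attribute and omits worker, seeding, and drop_last handling.

    from torch.utils.data import DataLoader

    def get_train_dataloader(self) -> DataLoader:
        # A None sampler falls back to the DataLoader's default ordering.
        train_sampler = self._get_train_sampler()
        return DataLoader(
            self.train_dataset,
            batch_size=self.args.train_batch_size,
            sampler=train_sampler,
            collate_fn=self.data_collator,
        )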
Example #2
    def test_group_by_length(self):
        # Get some inputs of random lengths.
        lengths = torch.randint(0, 25, (100,)).tolist()
        # Make one bigger than the others to check it ends up in first position.
        lengths[32] = 50

        indices = list(LengthGroupedSampler(lengths, 4, lengths=lengths))
        # The biggest element should be first.
        self.assertEqual(lengths[indices[0]], 50)
        # The indices should be a permutation of range(100).
        self.assertEqual(sorted(indices), list(range(100)))
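The same grouping behavior can be observed outside a test class. Below is a minimal sketch assuming the older positional signature LengthGroupedSampler(dataset, batch_size, lengths=...) used in this test; newer transformers releases take batch_size first, as Example #3 shows.

    import torch
    from torch.utils.data import DataLoader
    from transformers.trainer_pt_utils import LengthGroupedSampler

    # The list of lengths doubles as a toy dataset here.
    lengths = torch.randint(1, 25, (100,)).tolist()
    sampler = LengthGroupedSampler(lengths, 4, lengths=lengths)

    loader = DataLoader(lengths, batch_size=4, sampler=sampler)
    for batch in loader:
        # Items in a batch have similar lengths, which keeps padding waste low.
        print(batch.tolist())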
Example #3
    def test_group_by_length_with_batch_encoding(self):
        # Get some inputs of random lengths.
        data = []
        for _ in range(6):
            input_ids = torch.randint(0, 25, (100,)).tolist()
            data.append(BatchEncoding({"input_ids": input_ids}))
        # Make one bigger than the others to check it ends up in first position.
        data[3]["input_ids"] = torch.randint(0, 25, (105,)).tolist()

        indices = list(LengthGroupedSampler(4, dataset=data))
        # The biggest element should be first.
        self.assertEqual(len(data[indices[0]]["input_ids"]), 105)
        # The indices should be a permutation of range(6).
        self.assertEqual(sorted(indices), list(range(6)))
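Note that this test uses the newer keyword-first signature, LengthGroupedSampler(batch_size, dataset=...). When lengths is omitted, the sampler infers each example's length from its features. The sketch below mirrors that inference with a hypothetical infer_lengths helper (the real sampler does this internally via model_input_name):

    from transformers import BatchEncoding

    def infer_lengths(features, model_input_name="input_ids"):
        # Hypothetical helper: reads len(feature[model_input_name]) for each
        # feature, the same fallback the sampler applies when lengths is None.
        return [len(feature[model_input_name]) for feature in features]

    data = [BatchEncoding({"input_ids": list(range(n))}) for n in (3, 7, 5)]
    assert infer_lengths(data) == [3, 7, 5]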