def test_distributed_length_grouped(self):
    # Get some inputs of random lengths
    lengths = torch.randint(0, 25, (100,)).tolist()
    # Put one bigger than the others to check it ends up in first position
    lengths[32] = 50
    indices_process_0 = list(DistributedLengthGroupedSampler(lengths, 4, 2, 0, lengths=lengths))
    indices_process_1 = list(DistributedLengthGroupedSampler(lengths, 4, 2, 1, lengths=lengths))
    # The biggest element should be first
    self.assertEqual(lengths[indices_process_0[0]], 50)
    # The indices should be a permutation of range(100)
    self.assertEqual(list(sorted(indices_process_0 + indices_process_1)), list(range(100)))
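
# A companion sanity check for the single-process sampler: a minimal sketch,
# assuming LengthGroupedSampler accepts (dataset, batch_size, lengths=...) as it
# does in _get_train_sampler below; adjust the call if the installed version differs.
def test_group_by_length(self):
    # Same setup: random lengths with one outlier that should be scheduled first.
    lengths = torch.randint(0, 25, (100,)).tolist()
    lengths[32] = 50
    indices = list(LengthGroupedSampler(lengths, 4, lengths=lengths))
    # The longest example should come first.
    self.assertEqual(lengths[indices[0]], 50)
    # The sampler should visit every example exactly once.
    self.assertEqual(list(sorted(indices)), list(range(100)))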
def _get_train_sampler(self) -> Optional[torch.utils.data.sampler.Sampler]:
    if isinstance(self.train_dataset, torch.utils.data.IterableDataset) or not isinstance(
        self.train_dataset, collections.abc.Sized
    ):
        return None

    # Build the sampler.
    if self.args.group_by_length:
        # lengths = self.train_dataset[self.length_field_name] if self.length_field_name is not None else None
        model_input_name = self.tokenizer.model_input_names[0] if self.tokenizer is not None else None
        if self.args.world_size <= 1:
            return LengthGroupedSampler(
                self.train_dataset,
                self.args.train_batch_size,
                lengths=self.train_seq_lengths,
                model_input_name=model_input_name,
            )
        else:
            return DistributedLengthGroupedSampler(
                self.train_dataset,
                self.args.train_batch_size,
                num_replicas=self.args.world_size,
                rank=self.args.process_index,
                lengths=self.train_seq_lengths,
                model_input_name=model_input_name,
            )
    else:
        return super()._get_train_sampler()
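
# The method above relies on self.train_seq_lengths and self.length_field_name,
# which are not defined in this snippet. Below is a minimal sketch of how such
# per-example lengths might be precomputed once; the helper name and the dataset
# layout (an iterable of dicts keyed by the tokenizer's model input name) are assumptions.
def _compute_train_seq_lengths(self):
    if self.train_dataset is None:
        return None
    if self.length_field_name is not None:
        # Prefer a precomputed length column when the dataset provides one.
        return [example[self.length_field_name] for example in self.train_dataset]
    model_input_name = self.tokenizer.model_input_names[0] if self.tokenizer is not None else "input_ids"
    # Fall back to the tokenized input length of each example.
    return [len(example[model_input_name]) for example in self.train_dataset]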