def get_number_of_elements(object_):
    """Count the elements of every tensor stored in ``object_``.

    The tensors are gathered with ``get_tensors`` and their ``numel()`` values
    are added together. The original documentation points to
    ``OomBatchSampler.__init__.get_item_size`` as a use case: ranking objects
    by tensor size so the largest ones can be sampled.

    Args:
        object_ (any): Object to search for tensors.

    Returns:
        (int): Total number of elements over all tensors returned by
            ``get_tensors(object_)``; ``0`` when none are returned.
    """
    total = 0
    for tensor in get_tensors(object_):
        total += tensor.numel()
    return total
def __init__(
        self,
        data,
        batch_size,
        drop_last,
        sort_key=lambda e: e,
        biggest_batches_first=lambda o: sum(t.numel() for t in get_tensors(o)),
        bucket_size_multiplier=100,
        shuffle=True,
):
    """Store the sampler configuration and build the bucket-level sampler.

    Every argument is kept on the instance under the same name. How the
    stored values are consumed is defined by the rest of the class, which is
    not visible from this method.

    Args:
        data: Data source; stored and also handed to ``RandomSampler``.
        batch_size (int): Stored batch size. It is also multiplied by
            ``bucket_size_multiplier`` to size the bucket sampler's batches.
        drop_last (bool): Stored as given. Note that the bucket sampler
            itself is always created with ``drop_last=False``.
        sort_key (callable, optional): Stored as given; defaults to the
            identity function.
        biggest_batches_first (callable, optional): Stored as given; the
            default sums ``numel()`` over the tensors that ``get_tensors``
            finds in its argument.
        bucket_size_multiplier (int, optional): Factor applied to
            ``batch_size`` to obtain the bucket sampler's batch size.
        shuffle (bool, optional): Stored as given.
    """
    self.data = data
    self.batch_size = batch_size
    self.drop_last = drop_last
    self.sort_key = sort_key
    self.biggest_batches_first = biggest_batches_first
    self.bucket_size_multiplier = bucket_size_multiplier
    self.shuffle = shuffle

    # Buckets are random groups of `batch_size * bucket_size_multiplier`
    # indices; a trailing partial bucket is kept (drop_last=False).
    bucket_size = batch_size * bucket_size_multiplier
    self.bucket_sampler = BatchSampler(RandomSampler(data), bucket_size, False)
def test_get_tensors_object():
    """``get_tensors`` returns five tensors for a ``GetTensorsObjectMock`` instance."""
    found = get_tensors(GetTensorsObjectMock())
    assert len(found) == 5
def test_get_tensors_tuple():
    """``get_tensors`` returns two tensors for a tuple holding a dict and a bare tensor."""
    wrapped = {'t': torch.tensor([1, 2])}
    bare = torch.tensor([2, 3])
    found = get_tensors((wrapped, bare))
    assert len(found) == 2
def test_get_tensors_dict():
    """``get_tensors`` returns two tensors for a list holding a dict and a bare tensor.

    The dict is exercised as an element nested inside a list, not as the
    top-level argument.
    """
    wrapped = {'t': torch.tensor([1, 2])}
    bare = torch.tensor([2, 3])
    found = get_tensors([wrapped, bare])
    assert len(found) == 2
def test_get_tensors_list():
    """``get_tensors`` returns two tensors for a flat list of two tensors."""
    first = torch.tensor([1, 2])
    second = torch.tensor([2, 3])
    found = get_tensors([first, second])
    assert len(found) == 2
def _biggest_batches_first(o):
    """Return the summed ``numel()`` of all tensors that ``get_tensors`` finds in ``o``."""
    element_counts = (tensor.numel() for tensor in get_tensors(o))
    return sum(element_counts)