Esempio n. 1
0
def get_number_of_elements(object_):
    """ Get the sum of the number of elements in all tensors stored in `object_`.

    This is particularly useful for sampling the largest objects based on tensor size like in:
    `OomBatchSampler.__init__.get_item_size`.

    Args:
        object_ (any): Object searched with `get_tensors`; every tensor found contributes its
            `numel()` to the total.

    Returns:
        (int): The number of elements in the `object_` (``0`` when no tensors are found).
    """
    # Feed `sum` a generator: no intermediate list is needed just to add the counts up.
    return sum(t.numel() for t in get_tensors(object_))
Esempio n. 2
0
    def __init__(
        self,
        data,
        batch_size,
        drop_last,
        sort_key=lambda e: e,
        biggest_batches_first=lambda o: sum(
            t.numel() for t in get_tensors(o)),
        bucket_size_multiplier=100,
        shuffle=True,
    ):
        """Store the sampling configuration and build the bucket sampler.

        Args:
            data: Dataset whose indices are drawn by a ``RandomSampler``.
            batch_size (int): Number of examples per batch.
            drop_last (bool): Stored on the instance; presumably controls whether a final
                incomplete batch is dropped — confirm against ``__iter__``.
            sort_key (callable): Stored on the instance; identity by default.
            biggest_batches_first (callable or None): Stored on the instance. The default
                returns the total ``numel()`` of every tensor ``get_tensors`` finds in an object.
            bucket_size_multiplier (int): Each bucket holds ``batch_size * bucket_size_multiplier``
                indices.
            shuffle (bool): Stored on the instance.
        """
        self.biggest_batches_first = biggest_batches_first
        self.sort_key = sort_key
        self.bucket_size_multiplier = bucket_size_multiplier
        self.batch_size = batch_size
        self.drop_last = drop_last
        self.data = data
        self.shuffle = shuffle

        # Buckets are drawn in random order and never drop their last (partial) bucket.
        # NOTE(review): `shuffle` is not consulted here; presumably it is used when iterating.
        self.bucket_sampler = BatchSampler(RandomSampler(data),
                                           batch_size * bucket_size_multiplier,
                                           False)
Esempio n. 3
0
def test_get_tensors_object():
    """`get_tensors` should discover all five tensors attached to the mock object."""
    found = get_tensors(GetTensorsObjectMock())
    assert len(found) == 5
Esempio n. 4
0
def test_get_tensors_tuple():
    """`get_tensors` collects tensors from a tuple, including one nested inside a dict."""
    # A tuple literal instead of wrapping a throwaway list in `tuple(...)`.
    tuple_ = ({'t': torch.tensor([1, 2])}, torch.tensor([2, 3]))
    tensors = get_tensors(tuple_)
    assert len(tensors) == 2
Esempio n. 5
0
def test_get_tensors_dict():
    """`get_tensors` collects tensors from the values of a top-level dict, including nested dicts."""
    # The top-level object is a mapping here; previously this test passed a list, which
    # duplicated the tuple test and never exercised a dict at the top level.
    dict_ = {'t': torch.tensor([1, 2]), 'nested': {'u': torch.tensor([2, 3])}}
    tensors = get_tensors(dict_)
    assert len(tensors) == 2
Esempio n. 6
0
def test_get_tensors_list():
    """`get_tensors` returns both tensors held in a flat list."""
    pair = [torch.tensor([1, 2]), torch.tensor([2, 3])]
    assert len(get_tensors(pair)) == 2
def _biggest_batches_first(o):
    """Return the total number of elements across every tensor `get_tensors` finds in `o`.

    Returns ``0`` when `o` contains no tensors. Presumably intended as a size key for
    ordering batches biggest-first — confirm against callers.
    """
    # A generator avoids building an intermediate list only to sum it.
    return sum(t.numel() for t in get_tensors(o))