from typing import Type

# Assumed to be defined elsewhere in the package (this is an excerpt):
# `gc_cuda()` runs garbage collection and empties the CUDA cache,
# `should_reduce_batch_size(exc)` detects out-of-memory RuntimeErrors, and
# `get_cache_for_context()` / `DEFAULT_CACHE_TYPE` provide the batch-size
# cache used by the cached variants below.


class simple:
    """Uncached variants: every call starts over from the initial size."""

    @staticmethod
    def batch(func, initial_batchsize: int, *args, **kwargs):
        """Call `func(batchsize, ...)`, halving the batch size until the
        call succeeds or fails at batch size 1."""
        gc_cuda()

        batchsize = initial_batchsize
        while True:
            try:
                return func(batchsize, *args, **kwargs)
            except RuntimeError as exception:
                if batchsize > 1 and should_reduce_batch_size(exception):
                    batchsize //= 2
                    gc_cuda()
                else:
                    raise
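
    # Usage sketch (`evaluate`, `model`, and `data` are hypothetical, not
    # part of this file):
    #
    #     def evaluate(batchsize):
    #         return model(data[:batchsize])
    #
    #     result = simple.batch(evaluate, 512)
    #
    # An out-of-memory RuntimeError retries at 256, 128, ..., 1; any other
    # RuntimeError propagates unchanged.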
    @staticmethod
    def range(func, start: int, end: int, initial_step: int, *args, **kwargs):
        """Call `func(current, min(current + step, end), ...)` over
        [start, end), halving the step whenever a call runs out of memory."""
        gc_cuda()

        stepsize = initial_step
        current = start
        while current < end:
            try:
                func(current, min(current + stepsize, end), *args, **kwargs)
                current += stepsize
            except RuntimeError as exception:
                if stepsize > 1 and should_reduce_batch_size(exception):
                    stepsize //= 2
                    gc_cuda()
                else:
                    raise
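
# Usage sketch for `simple.range` (`process_chunk`, `model`, and `dataset`
# are hypothetical):
#
#     def process_chunk(start, end):
#         model(dataset[start:end])  # `end` is exclusive
#
#     simple.range(process_chunk, 0, len(dataset), 256)
#
# The step halves on out-of-memory errors, and finished chunks are never
# re-run, because `current` only advances after a successful call.
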
def range(
    func,
    start: int,
    end: int,
    initial_step: int,
    *args,
    toma_context=None,
    toma_cache_type: Type = DEFAULT_CACHE_TYPE,
    **kwargs,
):
    """Like `simple.range`, but the step size is cached across calls, keyed
    on `toma_context` (or, failing that, on `func` itself)."""
    gc_cuda()

    cache = get_cache_for_context(toma_cache_type, toma_context or func)
    batchsize = cache.get_batchsize(initial_step)

    current = start
    while current < end:
        # Re-read the step each iteration: it may shrink after an OOM.
        step = batchsize.get()
        try:
            func(current, min(current + step, end), *args, **kwargs)
            current += step
        except RuntimeError as exception:
            if step > 1 and should_reduce_batch_size(exception):
                batchsize.decrease_batchsize()
                gc_cuda()
            else:
                raise
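
# Usage sketch: the cached step survives across calls, so the second call
# below resumes from the last successful step rather than from 1024 again
# (`score_pairs` and `n` are hypothetical):
#
#     range(score_pairs, 0, n, 1024)
#     range(score_pairs, 0, n, 1024)  # resumes at the cached step size
#
# Assumed interface of the object returned by `cache.get_batchsize(...)`:
# `get()` returns the current value, `decrease_batchsize()` shrinks it.
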
def batch(func,
          initial_batchsize: int,
          *args,
          toma_context=None,
          toma_cache_type: Type = DEFAULT_CACHE_TYPE,
          **kwargs):
    """Like `simple.batch`, but the batch size is cached across calls, keyed
    on `toma_context` (or, failing that, on `func` itself)."""
    gc_cuda()

    cache = get_cache_for_context(toma_cache_type, toma_context or func)
    batchsize = cache.get_batchsize(initial_batchsize)

    while True:
        # Fetch the batch size before the `try` so `value` is always bound
        # when the `except` clause runs.
        value = batchsize.get()
        try:
            return func(value, *args, **kwargs)
        except RuntimeError as exception:
            if value > 1 and should_reduce_batch_size(exception):
                batchsize.decrease_batchsize()
                gc_cuda()
            else:
                raise
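
# Usage sketch for the cached `batch` (hypothetical names; `range` in the
# example is the builtin, since user code would import these functions as
# e.g. `toma.batch` / `toma.range`):
#
#     import torch
#
#     def run_inference(batchsize):
#         outs = [model(inputs[i : i + batchsize])
#                 for i in range(0, len(inputs), batchsize)]
#         return torch.cat(outs)
#
#     result = batch(run_inference, 1024)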