def batch(func, initial_batchsize: int, *args, **kwargs):
    """Call ``func(batchsize, *args, **kwargs)``, halving the batch size on failure.

    ``gc_cuda()`` is called once before the first attempt. ``func`` is then
    invoked with the current batch size. If it raises a ``RuntimeError`` for
    which ``should_reduce_batch_size(exception)`` is true while the batch size
    is still greater than 1, the batch size is halved (floor division),
    ``gc_cuda()`` is called again, and ``func`` is retried.

    Args:
        func: Callable receiving the batch size as its first positional
            argument, followed by ``*args`` and ``**kwargs``.
        initial_batchsize: Batch size used for the first attempt.
        *args: Forwarded unchanged to ``func``.
        **kwargs: Forwarded unchanged to ``func``.

    Returns:
        The return value of the first call to ``func`` that does not raise.

    Raises:
        RuntimeError: Re-raised unchanged when ``should_reduce_batch_size``
            rejects it or the batch size has already reached 1.
        Exception: Any non-``RuntimeError`` raised by ``func`` propagates
            immediately.
    """
    gc_cuda()
    batchsize = initial_batchsize
    while True:
        try:
            return func(batchsize, *args, **kwargs)
        except RuntimeError as exception:
            if not (batchsize > 1 and should_reduce_batch_size(exception)):
                raise
        # Recover only after the ``except`` clause has exited. While it is
        # active, the exception's traceback keeps the failed call's frames
        # (and everything they allocated) alive, so ``gc_cuda()`` could not
        # release that memory.
        batchsize //= 2
        gc_cuda()
def range(func, start: int, end: int, initial_step: int, *args, **kwargs):
    """Process ``[start, end)`` in chunks via ``func(lo, hi, *args, **kwargs)``.

    ``gc_cuda()`` is called once up front. ``func`` is called on consecutive
    half-open sub-ranges ``[current, min(current + step, end))``. If a call
    raises a ``RuntimeError`` for which ``should_reduce_batch_size(exception)``
    is true while the step is still greater than 1, the step is halved (floor
    division), ``gc_cuda()`` is called, and the same sub-range start is retried
    with the smaller step. The reduced step is kept for the remaining chunks.

    Args:
        func: Callable receiving the chunk bounds ``(lo, hi)`` followed by
            ``*args`` and ``**kwargs``. Its return value is ignored.
        start: First index of the range (inclusive).
        end: End of the range (exclusive).
        initial_step: Chunk size for the first attempt. Must be >= 1 when the
            range is non-empty.
        *args: Forwarded unchanged to ``func``.
        **kwargs: Forwarded unchanged to ``func``.

    Raises:
        ValueError: If ``initial_step < 1`` and ``start < end``; previously
            this looped forever.
        RuntimeError: Re-raised unchanged when ``should_reduce_batch_size``
            rejects it or the step has already reached 1.
    """
    if start < end and initial_step < 1:
        raise ValueError(
            "initial_step must be >= 1, got {!r}".format(initial_step)
        )
    gc_cuda()
    stepsize = initial_step
    current = start
    while current < end:
        try:
            func(current, min(current + stepsize, end), *args, **kwargs)
        except RuntimeError as exception:
            if not (stepsize > 1 and should_reduce_batch_size(exception)):
                raise
        else:
            current += stepsize
            continue
        # Recover only after the ``except`` clause has exited. While it is
        # active, the exception's traceback keeps the failed call's frames
        # (and everything they allocated) alive, so ``gc_cuda()`` could not
        # release that memory.
        stepsize //= 2
        gc_cuda()
def range(
    func,
    start: int,
    end: int,
    initial_step: int,
    *args,
    toma_context=None,
    toma_cache_type: Type = DEFAULT_CACHE_TYPE,
    **kwargs,
):
    """Process ``[start, end)`` in chunks, using a cached, shrinkable step size.

    The step size is obtained from ``get_cache_for_context(toma_cache_type,
    toma_context or func).get_batchsize(initial_step)``. Presumably this lets
    repeated calls with the same context reuse a previously reduced step
    (TODO confirm against the cache implementation). ``func`` is called on
    consecutive half-open sub-ranges ``[current, min(current + step, end))``.
    If a call raises a ``RuntimeError`` for which
    ``should_reduce_batch_size(exception)`` is true while the step is greater
    than 1, ``decrease_batchsize()`` is called on the cached value,
    ``gc_cuda()`` is called, and the same sub-range start is retried.

    Args:
        func: Callable receiving the chunk bounds ``(lo, hi)`` followed by
            ``*args`` and ``**kwargs``. Its return value is ignored.
        start: First index of the range (inclusive).
        end: End of the range (exclusive).
        initial_step: Step size handed to the cache's ``get_batchsize``.
        *args: Forwarded unchanged to ``func``.
        toma_context: Key for the cache lookup. Falls back to ``func`` when
            falsy.
        toma_cache_type: Cache type passed to ``get_cache_for_context``.
        **kwargs: Forwarded unchanged to ``func``.

    Raises:
        ValueError: If the step read from the cache is < 1; previously this
            looped forever.
        RuntimeError: Re-raised unchanged when ``should_reduce_batch_size``
            rejects it or the step has already reached 1.
    """
    gc_cuda()
    cache = get_cache_for_context(toma_cache_type, toma_context or func)
    batchsize = cache.get_batchsize(initial_step)
    current = start
    while current < end:
        # Read the step once per iteration so that the chunk bound, the
        # advance of ``current`` and the retry check all use the same value,
        # even if the shared cached value is changed meanwhile. Otherwise
        # sub-ranges could be skipped or processed twice.
        step = batchsize.get()
        if step < 1:
            raise ValueError("step size must be >= 1, got {!r}".format(step))
        try:
            func(current, min(current + step, end), *args, **kwargs)
        except RuntimeError as exception:
            if not (step > 1 and should_reduce_batch_size(exception)):
                raise
        else:
            current += step
            continue
        # Recover only after the ``except`` clause has exited. While it is
        # active, the exception's traceback keeps the failed call's frames
        # (and everything they allocated) alive, so ``gc_cuda()`` could not
        # release that memory.
        batchsize.decrease_batchsize()
        gc_cuda()
def batch(func, initial_batchsize: int, *args, toma_context=None, toma_cache_type: Type = DEFAULT_CACHE_TYPE, **kwargs):
    """Call ``func(batchsize, *args, **kwargs)`` with a cached, shrinkable batch size.

    The batch size is obtained from ``get_cache_for_context(toma_cache_type,
    toma_context or func).get_batchsize(initial_batchsize)``. Presumably this
    lets repeated calls with the same context reuse a previously reduced size
    (TODO confirm against the cache implementation). If ``func`` raises a
    ``RuntimeError`` for which ``should_reduce_batch_size(exception)`` is true
    while the batch size used for that attempt is greater than 1,
    ``decrease_batchsize()`` is called on the cached value, ``gc_cuda()`` is
    called, and ``func`` is retried.

    Args:
        func: Callable receiving the batch size as its first positional
            argument, followed by ``*args`` and ``**kwargs``.
        initial_batchsize: Batch size handed to the cache's ``get_batchsize``.
        *args: Forwarded unchanged to ``func``.
        toma_context: Key for the cache lookup. Falls back to ``func`` when
            falsy.
        toma_cache_type: Cache type passed to ``get_cache_for_context``.
        **kwargs: Forwarded unchanged to ``func``.

    Returns:
        The return value of the first call to ``func`` that does not raise.

    Raises:
        RuntimeError: Re-raised unchanged when ``should_reduce_batch_size``
            rejects it or the batch size has already reached 1.
    """
    gc_cuda()
    cache = get_cache_for_context(toma_cache_type, toma_context or func)
    batchsize = cache.get_batchsize(initial_batchsize)
    while True:
        # Read outside ``try`` so the value is always bound in the handler,
        # and the retry check uses the size this attempt actually ran with.
        current_batchsize = batchsize.get()
        try:
            return func(current_batchsize, *args, **kwargs)
        except RuntimeError as exception:
            if not (current_batchsize > 1 and should_reduce_batch_size(exception)):
                raise
        # Recover only after the ``except`` clause has exited. While it is
        # active, the exception's traceback keeps the failed call's frames
        # (and everything they allocated) alive, so ``gc_cuda()`` could not
        # release that memory.
        batchsize.decrease_batchsize()
        gc_cuda()