def _get_or_make_function(model, mode, key_fn, make_fn):
    """Return the cached execution function for `mode`, building it on a miss.

    Args:
        model: Model whose distributed function cache is consulted.
        mode: Execution mode; forwarded to both `key_fn` and `make_fn`.
        key_fn: Callable mapping `mode` to the cache key.
        make_fn: Callable invoked as `make_fn(model, mode)` to build the
            function when no usable cached entry is found.

    Returns:
        The cached function when a truthy entry exists under the key,
        otherwise the newly built function (stored in the cache first).
    """
    # Set up the cache before touching it (per the method name, this only
    # initializes when the model has not been compiled).
    model._init_distributed_function_cache_if_not_compiled()

    cache_key = key_fn(mode)
    cached = dist_utils.get_distributed_function(model, cache_key)
    if not cached:
        # Cache miss: build the function once and remember it for reuse.
        cached = make_fn(model, mode)
        dist_utils.set_distributed_function(model, cache_key, cached)
    return cached
def _get_or_make_execution_function(model, mode):
    """Makes or reuses function to run one step of distributed model execution.

    Delegates the cache lookup/build/store sequence to
    `_get_or_make_function` instead of repeating it here.

    Args:
        model: Model whose distributed function cache is consulted.
        mode: Execution mode used for the cache key and passed to
            `_make_execution_function` when a new function must be built.

    Returns:
        The cached execution function for `(mode, 'v2')` if one exists,
        otherwise a newly made one (cached before being returned).
    """
    return _get_or_make_function(
        model,
        mode,
        # Use a key with 'v2' to distinguish from fall-back execution
        # functions.
        key_fn=lambda m: (m, 'v2'),
        make_fn=_make_execution_function,
    )