def _jit_predict_fn(model_predict, metric_fn, n_devices, jit=True):
  """Returns a JIT-compiled predict function (unless jit=False).

  `model_predict` and `metric_fn` are chained with `layers.Serial` and the
  chain's `apply_forward` is used as the forward function.

  * jit=False: the forward function is returned unchanged.
  * jit=True, n_devices == 1: the result of `backend.accelerate` is returned.
  * jit=True, n_devices > 1: the accelerated function is wrapped so that the
    input is split with `backend.reshape_by_device`, one rng per device is
    derived from the given rng, outputs are merged with
    `backend.combine_devices`, and every result leaf is averaged over its
    leading axis.

  Args:
    model_predict: Layer whose output is fed into `metric_fn`.
    metric_fn: Layer applied after `model_predict`.
    n_devices: Number of devices passed to `backend.accelerate` and used to
      shard inputs / split the rng.
    jit: Whether to accelerate the forward function at all.

  Returns:
    A predict function. In the multi-device case it has the signature
    `(x, params, state, rng) -> (averaged_result, state)`.
  """
  forward = layers.Serial(model_predict, metric_fn).apply_forward
  if not jit:
    return forward

  accelerated = backend.accelerate(forward, n_devices)
  if n_devices == 1:
    return accelerated

  def _mean_over_devices(leaf):
    # Average each result leaf over its leading axis.
    return np.mean(leaf, axis=0)

  def predict(x, params, state, rng):
    """Runs the accelerated forward pass with inputs split across devices."""
    sharded_x = backend.reshape_by_device(x, n_devices)
    # One independent rng per device, stacked into a single array.
    device_rngs = np.stack(jax_random.split(rng, n_devices))
    outputs = accelerated(sharded_x, params, state, device_rngs)
    res, new_state = backend.combine_devices(outputs)
    return layers.nested_map(_mean_over_devices, res), new_state

  return predict
def _jit_predict_fn(model_predict, metric_fn, n_devices, jit=True):
  """Returns a JIT-compiled predict function (unless jit=False).

  `model_predict` and `metric_fn` are chained with `tl.Serial`, and the chain's
  (protected) `_forward_internal` method is used as the forward function.

  * jit=False: the forward function is returned unchanged.
  * jit=True, n_devices == 1: the result of `backend.accelerate` is returned.
  * jit=True, n_devices > 1: the accelerated function is wrapped so that the
    input is split with `_reshape_by_device`, one rng per device is derived
    from the given rng, outputs are merged with `_combine_devices`, and every
    result leaf is averaged over its leading axis.

  Args:
    model_predict: Layer whose output is fed into `metric_fn`.
    metric_fn: Layer applied after `model_predict`.
    n_devices: Number of devices passed to `backend.accelerate` and used to
      shard inputs / split the rng.
    jit: Whether to accelerate the forward function at all.

  Returns:
    A predict function. In the multi-device case it has the signature
    `(x, weights, state, rng) -> (averaged_result, state)`.
  """
  serial = tl.Serial(model_predict, metric_fn)
  forward = serial._forward_internal  # pylint: disable=protected-access
  if not jit:
    return forward

  accelerated = backend.accelerate(forward, n_devices)
  if n_devices == 1:
    return accelerated

  def _mean_over_devices(leaf):
    # Average each result leaf over its leading axis.
    return np.mean(leaf, axis=0)

  def predict(x, weights, state, rng):
    """Runs the accelerated forward pass with inputs split across devices."""
    sharded_x = _reshape_by_device(x, n_devices)
    # One independent rng per device, stacked into a single array.
    device_rngs = np.stack(jax_random.split(rng, n_devices))
    outputs = accelerated(sharded_x, weights, state, device_rngs)
    res, new_state = _combine_devices(outputs)
    return backend.nested_map(_mean_over_devices, res), new_state

  return predict