def O_mean(forward_fn, params, samples, holomorphic=True):
    r"""
    compute \langle O \rangle
    i.e. the mean of the rows of the jacobian of forward_fn
    """
    # determine the output type of the forward pass
    dtype = jax.eval_shape(forward_fn, params, samples).dtype
    w = jnp.ones(samples.shape[0], dtype=dtype) * (
        1.0 / (samples.shape[0] * mpi.n_nodes)
    )

    homogeneous = nkjax.tree_ishomogeneous(params)
    real_params = not nkjax.tree_leaf_iscomplex(params)
    real_out = not nkjax.is_complex(jax.eval_shape(forward_fn, params, samples))

    if homogeneous and (real_params or holomorphic):
        if real_params and not real_out:
            # R->C
            return O_vjp_rc(forward_fn, params, samples, w)
        else:
            # R->R and holomorphic C->C
            return O_vjp(forward_fn, params, samples, w)
    else:
        # R&C -> C
        # non-holomorphic
        # C->R
        assert False
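# Hedged usage sketch (not part of the snippet above): jax.eval_shape traces a
# forward pass abstractly, so its output dtype and shape can be inspected
# without running the network or allocating device memory. toy_forward_fn is a
# stand-in for the NetKet model, used here only for illustration.
import jax
import jax.numpy as jnp

def toy_forward_fn(params, samples):
    return samples @ params["w"]

params = {"w": jnp.ones((4, 2), dtype=jnp.complex64)}
samples = jnp.ones((8, 4), dtype=jnp.float32)

out = jax.eval_shape(toy_forward_fn, params, samples)
print(out.shape, out.dtype)  # (8, 2) complex64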
def wrapper(*args, **kwargs):
    base.assert_context("optimize_rng_use")

    # Extract all current state.
    frame = base.current_frame()
    params = frame.params or None
    if params is not None:
        params = data_structures.to_haiku_dict(params)
    state = frame.state or None
    if state is not None:
        state = base.extract_state(state, initial=True)
    rng = frame.rng_stack.peek()
    if rng is not None:
        rng = rng.internal_state

    def pure_fun(params, state, rng, *args, **kwargs):
        with base.new_context(params=params, state=state, rng=rng):
            return fun(*args, **kwargs)

    with count_hk_rngs_requested() as rng_count_f:
        jax.eval_shape(pure_fun, params, state, rng, *args, **kwargs)
    rng_count = rng_count_f()

    if rng_count:
        base.current_frame().rng_stack.peek().reserve(rng_count)
    return fun(*args, **kwargs)
def __init__(
    self,
    stages: tp.List[int],
    block_type: tp.Union[tp.Type[ResNetBlock], tp.Type[BottleneckResNetBlock]],
    lowres: bool = False,
    weights: tp.Optional[str] = None,
    dtype: tp.Optional[tp.Any] = jnp.float32,
    *args,
    **kwargs,
):
    """
    Arguments:
        stages: A list of integers representing the number of blocks in each
            stage, e.g. [3, 4, 6, 3] for a ResNet50.
        block_type: Which ResNet block type to use.
        lowres: Optional, whether to use the low resolution version as described
            in subsection 4.2 of the original paper. This version is better
            suited for datasets like CIFAR10. (Default: False)
        weights: One of None (random initialization) or a path to a weights file.
        dtype: Optional dtype of the convolutions and linear operations, either
            jnp.float32 (default) or jnp.float16 for mixed precision.
    """
    super().__init__(*args, **kwargs)
    self.stages = stages
    self.block_type = block_type
    self.lowres = lowres

    if weights is not None:
        if weights.endswith(".pkl"):
            collections = pickle.load(open(weights, "rb"))
        elif weights == "imagenet":
            clsname = self.__class__.__name__
            urldict = PRETRAINED_URLS.get(clsname, None)
            if urldict is None:
                raise ValueError(f"No pretrained weights for {clsname} available")
            fname = utils.download_file(urldict["url"], sha256=urldict["sha256"])
            collections = pickle.load(open(fname, "rb"))
        else:
            raise ValueError("Unknown weights value: ", weights)

        if isinstance(collections, tuple):
            parameters, collections = collections
        elif "parameters" in collections:
            parameters = collections.pop("parameters")
        else:
            raise ValueError(
                "Unknown parameters structure, expected either tuple "
                "(parameters, collections) or a collections dict with a "
                "'parameters' field."
            )

        x = np.empty([0, 224, 224, 3], dtype=self.dtype)
        # quick but dirty module initialization
        jax.eval_shape(self.init(rng=types.RNGSeq(42)), x)
        self.set_default_parameters(parameters, collections)
def apply(self, input_values, state=None):
    if state is None:
        output_values, _ = jax.eval_shape(
            self.model.init_with_output, jax.random.PRNGKey(0), input_values
        )
    else:
        output_values = jax.eval_shape(self.model.apply, state, input_values)
    return output_values
def eval_shape(fun, *args, has_aux=False, **kwargs):
    """
    Returns the shapes and dtypes of the output of ``fun(*args, **kwargs)``
    without evaluating it. If ``has_aux`` is True, the auxiliary output is
    discarded.
    """
    if has_aux:
        out, _ = jax.eval_shape(fun, *args, **kwargs)
    else:
        out = jax.eval_shape(fun, *args, **kwargs)
    return out
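# Minimal usage sketch for the wrapper above, assuming a toy function with an
# auxiliary output: only the primary output's ShapeDtypeStruct is returned and
# the auxiliary pytree is discarded.
import jax.numpy as jnp

def toy_fun(x):
    return jnp.sin(x), {"norm": jnp.linalg.norm(x)}

out = eval_shape(toy_fun, jnp.ones((3, 3)), has_aux=True)
print(out.shape, out.dtype)  # (3, 3) float32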
def test_strict_promotion(self, module_fn: ModuleFn, shape, dtype):
    if descriptors.module_type(module_fn) in (hk.nets.VectorQuantizer,
                                              hk.nets.VectorQuantizerEMA):
        self.skipTest('Requires: https://github.com/google/jax/pull/2901')

    f = hk.transform_with_state(lambda x: module_fn()(x))  # pylint: disable=unnecessary-lambda
    rng = jax.random.PRNGKey(42)
    x = np.ones(shape, dtype)
    params, state = jax.eval_shape(f.init, rng, x)
    self.assertIsNotNone(jax.eval_shape(f.apply, params, state, rng, x))
def test_run(self, module_fn: ModuleFn, shape, dtype):
    def g(x):
        return module_fn()(x)

    f = hk.transform_with_state(g)

    def run():
        rng = jax.random.PRNGKey(42)
        x = jnp.zeros(shape, dtype)
        params, state = f.init(rng, x)
        return f.apply(params, state, rng, x)

    jax.eval_shape(run)
def __init__(
    self,
    config: PretrainedConfig,
    module: nn.Module,
    input_shape: Tuple = (1, 1),
    seed: int = 0,
    dtype: jnp.dtype = jnp.float32,
    _do_init: bool = True,
):
    if config is None:
        raise ValueError("config cannot be None")

    if module is None:
        raise ValueError("module cannot be None")

    # Those are private to be exposed as typed property on derived classes.
    self._config = config
    self._module = module

    # Those are public as their type is generic to every derived classes.
    self.key = PRNGKey(seed)
    self.dtype = dtype
    self.input_shape = input_shape

    # To check if the model was initialized automatically.
    self._is_initialized = _do_init

    if _do_init:
        # randomly initialized parameters
        random_params = self.init_weights(self.key, input_shape)
        params_shape_tree = jax.eval_shape(lambda params: params, random_params)
    else:
        init_fn = partial(self.init_weights, input_shape=input_shape)
        params_shape_tree = jax.eval_shape(init_fn, self.key)

        logger.info(
            "Model weights are not initialized as `_do_init` is set to `False`. "
            f"Make sure to call `{self.__class__.__name__}.init_weights` manually to initialize the weights."
        )

    # get the shape of the parameters
    self._params_shape_tree = params_shape_tree

    # save required_params as set
    self._required_params = set(flatten_dict(unfreeze(params_shape_tree)).keys())

    # initialize the parameters
    if _do_init:
        self.params = random_params
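# Hedged sketch of the `_do_init=False` path above: evaluating the init
# function abstractly yields the parameter tree's shapes and dtypes without
# allocating real weights. The Flax module below is illustrative only.
import jax
import jax.numpy as jnp
import flax.linen as nn

module = nn.Dense(features=8)
init_fn = lambda rng: module.init(rng, jnp.ones((1, 4)))
params_shape_tree = jax.eval_shape(init_fn, jax.random.PRNGKey(0))
# params_shape_tree mirrors the params dict, but every leaf is a
# jax.ShapeDtypeStruct instead of a concrete array.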
def assert_dtype(self, test_dtype, module_fn: ModuleFn, shape, input_dtype):
    """Checks that modules accepting float32 input_dtype output test_dtype."""
    if input_dtype != jnp.float32:
        self.skipTest('Skipping module with non-f32 input')

    def ones_creator(next_creator, shape, dtype, init, context):
        if context.full_name == 'vector_quantizer/embeddings':
            # NOTE: vector_quantizer/embeddings is created using a ctor argument
            # so dtype is not expected to follow input to __call__.
            dtype = test_dtype
        else:
            self.assertEqual(dtype, test_dtype, msg=context.full_name)
        # NOTE: We need to do this since some initializers (e.g. random.uniform)
        # do not support <32bit dtypes. This also makes the test run a bit
        # faster.
        init = jnp.ones
        return next_creator(shape, dtype, init)

    def g(x):
        with hk.custom_creator(ones_creator):
            mod = module_fn()
            return mod(x)

    g = hk.transform_with_state(g)

    # No custom creator for state so we need to do this manually.
    def cast_if_floating(x):
        if jnp.issubdtype(x.dtype, jnp.floating):
            x = x.astype(test_dtype)
        return x

    def init_fn(rng, x):
        params, state = g.init(rng, x)
        state = jax.tree_map(cast_if_floating, state)
        return params, state

    x = np.ones(shape, test_dtype)
    rng = jax.random.PRNGKey(42)
    params, state = jax.eval_shape(init_fn, rng, x)

    for _ in range(2):
        y, state = jax.eval_shape(g.apply, params, state, rng, x)

    def assert_dtype(path, v):
        if jnp.issubdtype(v.dtype, jnp.floating):
            self.assertEqual(v.dtype, test_dtype, msg=path)

    tree.map_structure_with_path(assert_dtype, y)
    tree.map_structure_with_path(assert_dtype, state)
def wrapped_fun(*args) -> str:
    dot_out = ''

    # eval_shape cannot evaluate functions which return str, as str is not a
    # valid JAX type.
    # The following function extracts the created dot string during the
    # abstract evaluation.
    def dot_extractor_fn(*inner_args):
        nonlocal dot_out
        dot_out = to_dot(fun)(*inner_args)

    jax.eval_shape(dot_extractor_fn, *args)
    assert dot_out, 'Failed to extract dot graph from abstract evaluation'
    return dot_out
def param(self, name: str, init_fn: Callable[..., T], *init_args) -> T:
    """Create a parameter."""
    self.reserve(name)
    if self.has_variable('params', name):
        abs_rng = jax.ShapeDtypeStruct((2,), jnp.uint32)
        value = self.get_variable('params', name)
        # Validate that the shape of the init_fn output matches the shape of
        # the existing parameter.
        abs_value = jax.eval_shape(lambda rng: init_fn(rng, *init_args), abs_rng)
        abs_value_flat = jax.tree_leaves(abs_value)
        value_flat = jax.tree_leaves(value)
        for val, abs_val in zip(value_flat, abs_value_flat):
            # NOTE: We could check dtype consistency here as well, but its
            # usefulness is less obvious. We might intentionally change the
            # dtype for inference to a half float type, for example.
            if jnp.shape(val) != jnp.shape(abs_val):
                raise ValueError(
                    'Inconsistent shapes between value and initializer '
                    f'for parameter "{name}" in "{self.path_text}": '
                    f'{jnp.shape(val)}, {jnp.shape(abs_val)}'
                )
        return value
    else:
        if not self.is_mutable_collection('params'):
            raise ValueError(
                f'No parameter named "{name}" exists in "{self.path_text}".'
            )
        value = init_fn(self.make_rng('params'), *init_args)
        self.put_variable('params', name, value)
        return value
def test_abstract_eval_simple():
    add_two = primitive(
        dex.eval(r'\x:((Fin 10)=>Float). for i. FToI $ x.i + 2.0'))
    x = jax.ShapeDtypeStruct((10,), np.float32)
    output_shape = jax.eval_shape(add_two, x)
    assert output_shape.shape == (10,)
    assert output_shape.dtype == np.int32
def forward_event_shape_tensor(self, event_shape) -> Array:
    """Returns the shape of the output of `forward` as a `jnp.array`."""
    self._check_shape("Forward", event_shape, base_bijector.event_ndims_in)
    forward_event_shape = jax.eval_shape(
        base_bijector.forward, jnp.zeros(event_shape)).shape
    return jnp.array(forward_event_shape, dtype=jnp.int32)
def fast_eval_shape(fun, *args, **kwargs):
    """Equivalent to ``eval_shape`` in JAX.

    This utility is equivalent to ``eval_shape`` in JAX except that it avoids
    running Haiku functions whose shapes are trivially known. This can avoid
    some Python overheads in JAX which can accumulate for very large models.

    Optimizations:

    * All parameter/state initialisers replaced with zeros.
    * ``hk.dropout`` replaced with identity.
    * ``jax.random.fold_in`` replaced with identity.

    Args:
        fun: The function to trace.
        *args: Positional arguments to ``fun``.
        **kwargs: Keyword arguments to ``fun``.

    Returns:
        The shape produced by ``fun`` for the given args/kwargs.
    """
    with base.custom_creator_unsafe(zeros_creator), \
         mock.patch.object(basic, 'dropout_impl', noop_dropout), \
         mock.patch.object(jax.random, 'fold_in', lambda key, data: key):
        if base.inside_transform():
            return stateful.eval_shape(fun, *args, **kwargs)
        else:
            return jax.eval_shape(fun, *args, **kwargs)
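# Hedged usage sketch: for a Haiku init function, fast_eval_shape is expected
# to return the same shape tree as jax.eval_shape while skipping the real
# initialisers. The `hk`, `jax` and `jnp` imports are assumed to be in scope.
f = hk.transform(lambda x: hk.Linear(16)(x))
x = jnp.ones([8, 4])
param_shapes = fast_eval_shape(f.init, jax.random.PRNGKey(0), x)
# param_shapes['linear'] holds ShapeDtypeStructs for 'w' (4, 16) and 'b' (16,).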
def test_fast_eval_shape_fold_in(self):
    f = lambda rng, x: jax.random.fold_in(rng, 1)
    rng = jax.random.PRNGKey(0)
    x = jnp.ones([1])
    y_slow = jax.eval_shape(f, rng, x)
    y_fast = eval_shape.fast_eval_shape(f, rng, x)
    self.assertEqual(y_slow, y_fast)
def inverse_event_shape(self, event_shape) -> tfp.tf2jax.TensorShape:
    """Returns the shape of the output of `inverse` as a `TensorShape`."""
    self._check_shape("Inverse", event_shape, base_bijector.event_ndims_out)
    inverse_event_shape = jax.eval_shape(
        base_bijector.inverse, jnp.zeros(event_shape)).shape
    return tfp.tf2jax.TensorShape(inverse_event_shape)
def get_output_spec(fn, *args, **kwargs):
    """Traces a callable to determine shape and dtype of its return value(s).

    Args:
        fn: Python `callable` accepting (structures of) `Tensor` arguments and
            returning (structures of) `Tensor`s.
        *args: `Tensor` and/or `tf.TensorSpec` instances representing positional
            arguments to `fn`.
        **kwargs: `Tensor` and/or `tf.TensorSpec` instances representing named
            arguments to `fn`.

    Returns:
        structured_outputs: Object or structure of objects corresponding to the
            value(s) returned by `fn`. These objects have `.shape` and `.dtype`
            attributes; nothing else about them is guaranteed by the API.
    """
    if NUMPY_MODE:
        raise NotImplementedError(
            'Either TensorFlow or JAX is required in order '
            'to trace a function without executing it.')
    if JAX_MODE:
        import jax  # pylint: disable=g-import-not-at-top
        return jax.eval_shape(fn, *args, **kwargs)

    def _as_tensor_spec(t):
        if isinstance(t, tf.TensorSpec):
            return t
        return tf.TensorSpec.from_tensor(tf.convert_to_tensor(t))

    return tf.function(fn, autograph=False).get_concrete_function(
        *tf.nest.map_structure(_as_tensor_spec, args),
        **tf.nest.map_structure(_as_tensor_spec, kwargs)).structured_outputs
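# Hedged usage sketch of the JAX branch above: jax.eval_shape accepts
# ShapeDtypeStructs in place of concrete tensors, so output specs can be
# derived without executing `fn`. The function below is illustrative only.
import jax
import jax.numpy as jnp

def toy_fn(a, b):
    return jnp.matmul(a, b), jnp.sum(b, axis=-1)

a = jax.ShapeDtypeStruct((3, 4), jnp.float32)
b = jax.ShapeDtypeStruct((4, 5), jnp.float32)
specs = jax.eval_shape(toy_fn, a, b)
print([s.shape for s in specs])  # [(3, 5), (4,)]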
def inverse_event_shape_tensor(self, event_shape) -> Array:
    """Returns the shape of the output of `inverse` as a `jnp.array`."""
    self._check_shape("Inverse", event_shape, base_bijector.event_ndims_out)
    inverse_event_shape = jax.eval_shape(
        base_bijector.inverse, jnp.zeros(event_shape)).shape
    return jnp.array(inverse_event_shape, dtype=jnp.int32)
def get_output_tree(
    jax_function,
    *args,
    **kwargs,
):
    # we need to remove the static arguments first;
    # we first do it for the kwargs
    static_kwargs = {}
    var_kwargs = {}
    for name, arg in list(kwargs.items()):
        if not isvar(arg):
            static_kwargs.update({name: arg})
        else:
            var_kwargs.update({name: arg})

    # we need to do the same for the args
    who_static = [int(not isvar(arg)) for arg in args]
    static_args = [arg for i, arg in zip(who_static, args) if i]
    var_args = [arg for i, arg in zip(who_static, args) if not i]

    # we need to define an abstract function that only takes the non-static
    # arguments as input, internally joins them with the static ones and
    # returns the output. This is because the jax shape inference functions
    # do not work with static arguments (such as the dimensions of the
    # transpose function).
    def abstract_func(*args, **kwargs):
        all_args = _args_formatting(args, static_args, who_static)
        return jax_function(*all_args, **kwargs, **static_kwargs)

    # now we evaluate the shape from the jax built-in function
    tree = jax.eval_shape(abstract_func, *var_args, **var_kwargs)
    return tree
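# Hedged illustration of why static arguments are separated out above:
# arguments like transpose axes must stay concrete Python values, so they are
# closed over (here via functools.partial) rather than passed through
# jax.eval_shape as traced inputs.
import functools
import jax
import jax.numpy as jnp

x = jax.ShapeDtypeStruct((2, 3, 4), jnp.float32)
out = jax.eval_shape(functools.partial(jnp.transpose, axes=(2, 0, 1)), x)
print(out.shape)  # (4, 2, 3)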
def forward_event_shape(self, event_shape) -> tfp.tf2jax.TensorShape:
    """Returns the shape of the output of `forward` as a `TensorShape`."""
    self._check_shape("Forward", event_shape, base_bijector.event_ndims_in)
    forward_event_shape = jax.eval_shape(
        base_bijector.forward, jnp.zeros(event_shape)).shape
    return tfp.tf2jax.TensorShape(forward_event_shape)
def test_fast_eval_shape_dropout(self):
    f = lambda rng, x: basic.dropout(rng, 0.5, x)
    rng = jax.random.PRNGKey(0)
    x = jnp.ones([1])
    y_slow = jax.eval_shape(f, rng, x)
    y_fast = eval_shape.fast_eval_shape(f, rng, x)
    self.assertEqual(y_slow, y_fast)
def batch_shape(self) -> Tuple[int, ...]:
    """Shape of batch of distribution samples."""
    sample_shape = jax.eval_shape(
        lambda: self.sample(seed=jax.random.PRNGKey(0), sample_shape=())).shape
    if not self.event_shape:
        return sample_shape
    return sample_shape[:-len(self.event_shape)]
def _init_state(sampler, machine, parameters, key):
    rgen = np.random.default_rng(np.asarray(key))

    σ = np.zeros((sampler.n_batches, sampler.hilbert.size), dtype=sampler.dtype)

    ma_out = jax.eval_shape(machine.apply, parameters, σ)

    state = MetropolisNumpySamplerState(
        σ=σ,
        σ1=np.copy(σ),
        log_values=np.zeros(sampler.n_batches, dtype=ma_out.dtype),
        log_values_1=np.zeros(sampler.n_batches, dtype=ma_out.dtype),
        log_prob_corr=np.zeros(
            sampler.n_batches, dtype=nkjax.dtype_real(ma_out.dtype)
        ),
        rng=rgen,
        rule_state=sampler.rule.init_state(sampler, machine, parameters, rgen),
    )
    if not sampler.reset_chains:
        key = jnp.asarray(
            state.rng.integers(0, 1 << 32, size=2, dtype=np.uint32),
            dtype=np.uint32,
        )

        state.σ = np.copy(
            sampler.rule.random_state(sampler, machine, parameters, state, key)
        )

    return state
def test_abstract_to_dot(self, module_fn: ModuleFn, shape, dtype):
    f = hk.transform_with_state(lambda x: module_fn()(x))  # pylint: disable=unnecessary-lambda
    rng = jax.random.PRNGKey(42)
    x = np.ones(shape, dtype)
    params, state = jax.eval_shape(f.init, rng, x)
    self.assertIsNotNone(
        hk.experimental.abstract_to_dot(f.apply)(params, state, rng, x))
def assert_dtype(
    self,
    test_dtype: DType,
    module_fn: descriptors.ModuleFn,
    shape: Shape,
    input_dtype: DType,
):
    """Checks that modules accepting float32 input_dtype output test_dtype."""
    if jax.local_devices()[0].platform != 'tpu':
        self.skipTest('bfloat16 only supported on TPU')

    if input_dtype != jnp.float32:
        self.skipTest('Skipping module without float32 input')

    rng = jax.random.PRNGKey(42)

    def g(x):
        mod = module_fn()
        return mod(x)

    init_fn, apply_fn = hk.transform_with_state(g)

    # Create state in f32 to start.
    # NOTE: We need to do this since some initializers (e.g. random.uniform) do
    # not support <32bit dtypes.
    x = jax.random.uniform(rng, shape)
    params, state = jax.eval_shape(init_fn, rng, x)

    # Cast f32 to test_dtype.
    def make_param(v):
        dtype = test_dtype if v.dtype == jnp.float32 else v.dtype
        return jnp.ones(v.shape, dtype)

    params, state = jax.tree_map(make_param, (params, state))

    # test_dtype in should result in test_dtype out.
    x = x.astype(test_dtype)
    for _ in range(2):
        y, state = jax.eval_shape(apply_fn, params, state, rng, x)

    def assert_dtype(path, v):
        if v.dtype != jnp.int32:
            self.assertEqual(v.dtype, test_dtype, msg=path)

    tree.map_structure_with_path(assert_dtype, y)
    tree.map_structure_with_path(assert_dtype, state)
def shape_dtype(self):
    if self._shape_dtype is not None:
        return self._shape_dtype
    fun = functools.partial(
        self.primitive.bind, **self.params)  # type: ignore  # bind-properties
    self._shape_dtype = jax.eval_shape(fun, *self.operands)
    return self._shape_dtype
def transform_and_run_once(f, *args, **kwargs):
    f = transform.transform(f)

    def g(*args, **kwargs):
        rng = jax.random.PRNGKey(28)
        params = f.init(rng, *args, **kwargs)
        out = f.apply(params, None, *args, **kwargs)
        return params, out

    return jax.tree_map(lambda x: x.dtype, jax.eval_shape(g, *args, **kwargs))
def O_mean(samples, params, forward_fn, **kwargs):
    r"""
    compute \langle O \rangle
    i.e. the mean of the rows of the jacobian of forward_fn
    """
    dtype = jax.eval_shape(forward_fn, params, samples).dtype
    v = jnp.ones(samples.shape[0], dtype=dtype) * (
        1.0 / (samples.shape[0] * n_nodes)
    )
    return O_vjp(samples, params, v, forward_fn, **kwargs)
def eval_shape(self, feat: features.FeatureDict) -> jax.ShapeDtypeStruct:
    self.init_params(feat)
    logging.info('Running eval_shape with shape(feat) = %s',
                 tree.map_structure(lambda x: x.shape, feat))
    shape = jax.eval_shape(self.apply, self.params, jax.random.PRNGKey(0), feat)
    logging.info('Output shape was %s', shape)
    return shape
def force_fn(R, **kwargs):
    nonlocal _force_fn
    if _force_fn is None:
        out_shape = eval_shape(energy_or_force_fn, R, **kwargs).shape
        if out_shape == ():
            _force_fn = force(energy_or_force_fn)
        else:
            _force_fn = energy_or_force_fn

    return _force_fn(R, **kwargs)
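# Hedged sketch of the dispatch logic above: a scalar-valued energy function
# has an empty output shape under eval_shape, whereas a force function returns
# an array shaped like R. The toy functions below are illustrative only.
import jax
import jax.numpy as jnp

def toy_energy_fn(R):
    return jnp.sum(R ** 2)  # scalar energy -> shape ()

def toy_force_fn(R):
    return -2.0 * R  # per-particle forces -> shape like R

R = jnp.ones((10, 3))
print(jax.eval_shape(toy_energy_fn, R).shape)  # ()
print(jax.eval_shape(toy_force_fn, R).shape)   # (10, 3)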