Example #1
def O_mean(forward_fn, params, samples, holomorphic=True):
    r"""
    compute \langle O \rangle
    i.e. the mean of the rows of the jacobian of forward_fn
    """

    # determine the output type of the forward pass
    dtype = jax.eval_shape(forward_fn, params, samples).dtype
    w = jnp.ones(samples.shape[0],
                 dtype=dtype) * (1.0 / (samples.shape[0] * mpi.n_nodes))

    homogeneous = nkjax.tree_ishomogeneous(params)
    real_params = not nkjax.tree_leaf_iscomplex(params)
    real_out = not nkjax.is_complex(jax.eval_shape(forward_fn, params,
                                                   samples))

    if homogeneous and (real_params or holomorphic):
        if real_params and not real_out:
            # R->C
            return O_vjp_rc(forward_fn, params, samples, w)
        else:
            # R->R and holomorphic C->C
            return O_vjp(forward_fn, params, samples, w)
    else:
        # R&C -> C, non-holomorphic C->C, and C->R
        # are not handled
        raise NotImplementedError
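The essential trick here is that `jax.eval_shape` reveals the forward pass's output dtype without spending any FLOPs or memory. A minimal self-contained sketch of that probe (the toy `forward_fn` and shapes are illustrative, not NetKet's):

import jax
import jax.numpy as jnp

def forward_fn(params, samples):
    # toy stand-in for a log-wavefunction: one output per sample
    return samples @ params

params = jnp.ones((4,), dtype=jnp.complex64)
samples = jnp.ones((8, 4))
# Abstract trace only: the forward pass is never executed, but the
# output's shape and dtype are known.
out = jax.eval_shape(forward_fn, params, samples)
print(out.shape, out.dtype)  # (8,) complex64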
Example #2
  def wrapper(*args, **kwargs):
    base.assert_context("optimize_rng_use")

    # Extract all current state.
    frame = base.current_frame()
    params = frame.params or None
    if params is not None:
      params = data_structures.to_haiku_dict(params)
    state = frame.state or None
    if state is not None:
      state = base.extract_state(state, initial=True)
    rng = frame.rng_stack.peek()
    if rng is not None:
      rng = rng.internal_state

    def pure_fun(params, state, rng, *args, **kwargs):
      with base.new_context(params=params, state=state, rng=rng):
        return fun(*args, **kwargs)

    with count_hk_rngs_requested() as rng_count_f:
      jax.eval_shape(pure_fun, params, state, rng, *args, **kwargs)
    rng_count = rng_count_f()

    if rng_count:
      base.current_frame().rng_stack.peek().reserve(rng_count)
    return fun(*args, **kwargs)
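This wrapper is essentially the machinery behind Haiku's `hk.experimental.optimize_rng_use`: the function is traced once under `jax.eval_shape` (so nothing is actually computed), the number of RNG keys it requests is counted, and that many keys are reserved up front for the real run. A hedged usage sketch, assuming a recent Haiku version:

import haiku as hk
import jax
import jax.numpy as jnp

def net(x):
    x = hk.Linear(8)(x)
    # each hk.next_rng_key() request is counted during the abstract trace
    return hk.dropout(hk.next_rng_key(), 0.5, x)

f = hk.transform(hk.experimental.optimize_rng_use(net))
params = f.init(jax.random.PRNGKey(0), jnp.ones([2, 4]))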
Example #3
    def __init__(
        self,
        stages: tp.List[int],
        block_type: tp.Union[tp.Type[ResNetBlock],
                             tp.Type[BottleneckResNetBlock]],
        lowres: bool = False,
        weights: tp.Optional[str] = None,
        dtype: tp.Optional[tp.Any] = jnp.float32,
        *args,
        **kwargs,
    ):
        """
        Arguments:
            stages: A list of integers representing the number of blocks in each stage.
                    e.g: [3, 4, 6, 3] for a ResNet50
            block_type: Which ResNet block type to use.
            lowres: Optional, whether to use the low-resolution version
                    as described in subsection 4.2 of the original paper.
                    This version is better suited for datasets like CIFAR10. (Default: False)
            weights: One of None (random initialization) or a path to a weights file
            dtype: Optional dtype of the convolutions and linear operations,
                    either jnp.float32 (default) or jnp.float16 for mixed precision.
        """

        super().__init__(*args, **kwargs)
        self.stages = stages
        self.block_type = block_type
        self.lowres = lowres

        if weights is not None:
            if weights.endswith(".pkl"):
                with open(weights, "rb") as f:
                    collections = pickle.load(f)
            elif weights == "imagenet":
                clsname = self.__class__.__name__
                urldict = PRETRAINED_URLS.get(clsname, None)
                if urldict is None:
                    raise ValueError(
                        f"No pretrained weights for {clsname} available")
                fname = utils.download_file(urldict["url"],
                                            sha256=urldict["sha256"])
                with open(fname, "rb") as f:
                    collections = pickle.load(f)
            else:
                raise ValueError("Unknown weights value: ", weights)

            if isinstance(collections, tuple):
                parameters, collections = collections
            elif "parameters" in collections:
                parameters = collections.pop("parameters")
            else:
                raise ValueError(
                    "Unknown parameters structure, expected either tuple (parameters, collections) or a collections dict with a 'parameters' field."
                )

            x = np.empty([0, 224, 224, 3], dtype=self.dtype)
            # quick but dirty module initialization
            jax.eval_shape(self.init(rng=types.RNGSeq(42)), x)

            self.set_default_parameters(parameters, collections)
Example #4
 def apply(self, input_values, state=None):
     if state is None:
         output_values, _ = jax.eval_shape(self.model.init_with_output,
                                           jax.random.PRNGKey(0),
                                           input_values)
     else:
         output_values = jax.eval_shape(self.model.apply, state,
                                        input_values)
     return output_values
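Both branches lean on the fact that `jax.eval_shape` composes with Flax's `init_with_output`/`apply`: the returned pytrees carry `jax.ShapeDtypeStruct` leaves instead of arrays. A minimal sketch with a toy Flax module (module and shapes are illustrative):

import jax
import jax.numpy as jnp
import flax.linen as nn

model = nn.Dense(features=3)
x = jnp.ones((2, 5))
# Shapes of the output and of the freshly initialized variables,
# without allocating any parameter memory.
out_shape, vars_shape = jax.eval_shape(
    model.init_with_output, jax.random.PRNGKey(0), x)
print(out_shape.shape)                       # (2, 3)
print(vars_shape["params"]["kernel"].shape)  # (5, 3)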
Example #5
def eval_shape(fun, *args, has_aux=False, **kwargs):
    """
    Returns the dtype of forward_fn(pars, v)
    """
    if has_aux:
        out, _ = jax.eval_shape(fun, *args, **kwargs)
    else:
        out = jax.eval_shape(fun, *args, **kwargs)
    return out
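The `has_aux` convention mirrors `jax.grad` and `jax.vjp`: when the wrapped function returns `(out, aux)`, only `out`'s abstract value is kept. A small usage example, assuming the wrapper above is in scope:

import jax.numpy as jnp

def fun_with_aux(x):
    return jnp.sum(x), {"residual": x}   # (out, aux)

out = eval_shape(fun_with_aux, jnp.ones((3,)), has_aux=True)
print(out.shape, out.dtype)  # () float32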
Example #6
    def test_strict_promotion(self, module_fn: ModuleFn, shape, dtype):
        if descriptors.module_type(module_fn) in (hk.nets.VectorQuantizer,
                                                  hk.nets.VectorQuantizerEMA):
            self.skipTest('Requires: https://github.com/google/jax/pull/2901')

        f = hk.transform_with_state(lambda x: module_fn()(x))  # pylint: disable=unnecessary-lambda
        rng = jax.random.PRNGKey(42)
        x = np.ones(shape, dtype)
        params, state = jax.eval_shape(f.init, rng, x)
        self.assertIsNotNone(jax.eval_shape(f.apply, params, state, rng, x))
Example #7
    def test_run(self, module_fn: ModuleFn, shape, dtype):
        def g(x):
            return module_fn()(x)

        f = hk.transform_with_state(g)

        def run():
            rng = jax.random.PRNGKey(42)
            x = jnp.zeros(shape, dtype)
            params, state = f.init(rng, x)
            return f.apply(params, state, rng, x)

        jax.eval_shape(run)
Example #8
    def __init__(
        self,
        config: PretrainedConfig,
        module: nn.Module,
        input_shape: Tuple = (1, 1),
        seed: int = 0,
        dtype: jnp.dtype = jnp.float32,
        _do_init: bool = True,
    ):
        if config is None:
            raise ValueError("config cannot be None")

        if module is None:
            raise ValueError("module cannot be None")

        # These are private; they are exposed as typed properties on derived classes.
        self._config = config
        self._module = module

        # These are public, as their types are generic across all derived classes.
        self.key = PRNGKey(seed)
        self.dtype = dtype
        self.input_shape = input_shape

        # To check if the model was initialized automatically.
        self._is_initialized = _do_init

        if _do_init:
            # randomly initialized parameters
            random_params = self.init_weights(self.key, input_shape)
            params_shape_tree = jax.eval_shape(lambda params: params,
                                               random_params)
        else:
            init_fn = partial(self.init_weights, input_shape=input_shape)
            params_shape_tree = jax.eval_shape(init_fn, self.key)

            logger.info(
                "Model weights are not initialized as `_do_init` is set to `False`. "
                f"Make sure to call `{self.__class__.__name__}.init_weights` manually to initialize the weights."
            )

        # get the shape of the parameters
        self._params_shape_tree = params_shape_tree

        # save required_params as set
        self._required_params = set(
            flatten_dict(unfreeze(params_shape_tree)).keys())

        # initialize the parameters
        if _do_init:
            self.params = random_params
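The `_do_init=False` branch is the interesting one: `jax.eval_shape(init_fn, self.key)` traces the full weight initialization abstractly, so the parameter shape tree is available while no device memory is ever allocated for the weights. A stripped-down sketch of that idea (toy initializer, not the Transformers API):

import jax
import jax.numpy as jnp

def init_weights(rng):
    # stands in for a large model's initializer
    return {"dense": {"kernel": jax.random.normal(rng, (4096, 4096))}}

shape_tree = jax.eval_shape(init_weights, jax.random.PRNGKey(0))
# The 64 MiB kernel was never materialized; only its metadata exists.
print(shape_tree["dense"]["kernel"].shape)  # (4096, 4096)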
Example #9
    def assert_dtype(self, test_dtype, module_fn: ModuleFn, shape,
                     input_dtype):
        """Checks that modules accepting float32 input_dtype output test_dtype."""

        if input_dtype != jnp.float32:
            self.skipTest('Skipping module with non-f32 input')

        def ones_creator(next_creator, shape, dtype, init, context):
            if context.full_name == 'vector_quantizer/embeddings':
                # NOTE: vector_quantizer/embeddings is created using a ctor argument
                # so dtype is not expected to follow input to __call__.
                dtype = test_dtype
            else:
                self.assertEqual(dtype, test_dtype, msg=context.full_name)

            # NOTE: We need to do this since some initializers (e.g. random.uniform)
            # do not support <32bit dtypes. This also makes the test run a bit faster.
            init = jnp.ones
            return next_creator(shape, dtype, init)

        def g(x):
            with hk.custom_creator(ones_creator):
                mod = module_fn()
                return mod(x)

        g = hk.transform_with_state(g)

        # No custom creator for state so we need to do this manually.
        def cast_if_floating(x):
            if jnp.issubdtype(x.dtype, jnp.floating):
                x = x.astype(test_dtype)
            return x

        def init_fn(rng, x):
            params, state = g.init(rng, x)
            state = jax.tree_map(cast_if_floating, state)
            return params, state

        x = np.ones(shape, test_dtype)
        rng = jax.random.PRNGKey(42)
        params, state = jax.eval_shape(init_fn, rng, x)

        for _ in range(2):
            y, state = jax.eval_shape(g.apply, params, state, rng, x)

            def assert_dtype(path, v):
                if jnp.issubdtype(v.dtype, jnp.floating):
                    self.assertEqual(v.dtype, test_dtype, msg=path)

            tree.map_structure_with_path(assert_dtype, y)
            tree.map_structure_with_path(assert_dtype, state)
Example #10
    def wrapped_fun(*args) -> str:
        dot_out = ''

        # eval_shape cannot evaluate functions that return str, as str is not a
        # valid JAX type.
        # The following function extracts the created dot string during the
        # abstract evaluation.
        def dot_extractor_fn(*inner_args):
            nonlocal dot_out
            dot_out = to_dot(fun)(*inner_args)

        jax.eval_shape(dot_extractor_fn, *args)
        assert dot_out, 'Failed to extract dot graph from abstract evaluation'
        return dot_out
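The `nonlocal` side channel is needed because `jax.eval_shape` can only return pytrees of array-like values; anything else produced during tracing has to escape through Python state. The same trick in isolation, with a toy payload:

import jax
import jax.numpy as jnp

def traced_repr(fun):
    def wrapper(*args):
        captured = None
        def extractor(*inner_args):
            nonlocal captured
            out = fun(*inner_args)
            captured = f"out: {out.shape} {out.dtype}"  # a str, not a JAX type
            return out
        jax.eval_shape(extractor, *args)
        return captured
    return wrapper

print(traced_repr(jnp.sin)(jnp.ones((2, 3))))  # out: (2, 3) float32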
Example #11
 def param(self, name: str, init_fn: Callable[..., T], *init_args) -> T:
     """Create a parameter."""
     self.reserve(name)
     if self.has_variable('params', name):
         abs_rng = jax.ShapeDtypeStruct((2, ), jnp.uint32)
         value = self.get_variable('params', name)
         # Validate that the shape of init_fn's output matches the shape of
         # the existing parameter.
         abs_value = jax.eval_shape(lambda rng: init_fn(rng, *init_args),
                                    abs_rng)
         abs_value_flat = jax.tree_leaves(abs_value)
         value_flat = jax.tree_leaves(value)
         for val, abs_val in zip(value_flat, abs_value_flat):
             # NOTE: we could check dtype consistency here as well, but its
             # usefulness is less obvious: we might intentionally change the
             # dtype for inference, e.g. to a half-precision float type.
             if jnp.shape(val) != jnp.shape(abs_val):
                 raise ValueError(
                     'Inconsistent shapes between value and initializer '
                     f'for parameter "{name}" in "{self.path_text}": {jnp.shape(val)}, {jnp.shape(abs_val)}'
                 )
         return value
     else:
         if not self.is_mutable_collection('params'):
             raise ValueError(
                 f'No parameter named "{name}" exists in "{self.path_text}".'
             )
         value = init_fn(self.make_rng('params'), *init_args)
         self.put_variable('params', name, value)
         return value
Example #12
def test_abstract_eval_simple():
    add_two = primitive(
        dex.eval(r'\x:((Fin 10)=>Float). for i. FToI $ x.i + 2.0'))
    x = jax.ShapeDtypeStruct((10, ), np.float32)
    output_shape = jax.eval_shape(add_two, x)
    assert output_shape.shape == (10, )
    assert output_shape.dtype == np.int32
Example #13
 def forward_event_shape_tensor(self, event_shape) -> Array:
     """Returns the shape of the output of `forward` as a `jnp.array`."""
     self._check_shape("Forward", event_shape,
                       base_bijector.event_ndims_in)
     forward_event_shape = jax.eval_shape(base_bijector.forward,
                                          jnp.zeros(event_shape)).shape
     return jnp.array(forward_event_shape, dtype=jnp.int32)
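This example and Examples 16, 18 and 20 all use the same probe: build `jnp.zeros(event_shape)` as a template input and let `jax.eval_shape` report what the bijector would do to it, without ever evaluating the bijector. The pattern in isolation (a reshape stands in for the bijector):

import jax
import jax.numpy as jnp

forward = lambda x: x.reshape(x.shape[0], -1)
out_shape = jax.eval_shape(forward, jnp.zeros((4, 2, 3))).shape
print(out_shape)  # (4, 6)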
Example #14
def fast_eval_shape(fun, *args, **kwargs):
    """Equivalent to ``eval_shape`` in JAX.

  This utility is equivalent to ``eval_shape`` in JAX except that it avoids
  running Haiku functions whose shapes are trivially known. This can avoid some
  Python overheads in JAX which can accumulate for very large models.

  Optimizations:

  * All parameter/state initialisers replaced with zeros.
  * ``hk.dropout`` replaced with identity.
  * ``jax.random.fold_in`` replaced with identity.

  Args:
    fun: The function to trace.
    *args: Positional arguments to ``fun``.
    **kwargs: Keyword arguments to ``fun``.

  Returns:
    The shape produced by ``fun`` for the given args/kwargs.
  """
    with base.custom_creator_unsafe(zeros_creator), \
         mock.patch.object(basic, 'dropout_impl', noop_dropout), \
         mock.patch.object(jax.random, 'fold_in', lambda key, data: key):
        if base.inside_transform():
            return stateful.eval_shape(fun, *args, **kwargs)
        else:
            return jax.eval_shape(fun, *args, **kwargs)
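Call it exactly like `jax.eval_shape`; inside a `hk.transform` the stateful variant is picked automatically. A hedged usage sketch, assuming this is exposed as `hk.experimental.fast_eval_shape`:

import haiku as hk
import jax
import jax.numpy as jnp

f = hk.transform(lambda x: hk.nets.MLP([300, 100, 10])(x))
rng = jax.random.PRNGKey(0)
x = jnp.ones([1, 28 * 28])
# Same result as jax.eval_shape(f.init, rng, x), but initialisers and
# dropout are skipped during the trace.
params_shape = hk.experimental.fast_eval_shape(f.init, rng, x)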
Example #15
 def test_fast_eval_shape_fold_in(self):
     f = lambda rng, x: jax.random.fold_in(rng, 1)
     rng = jax.random.PRNGKey(0)
     x = jnp.ones([1])
     y_slow = jax.eval_shape(f, rng, x)
     y_fast = eval_shape.fast_eval_shape(f, rng, x)
     self.assertEqual(y_slow, y_fast)
Example #16
 def inverse_event_shape(self, event_shape) -> tfp.tf2jax.TensorShape:
     """Returns the shape of the output of `inverse` as a `TensorShape`."""
     self._check_shape("Inverse", event_shape,
                       base_bijector.event_ndims_out)
     inverse_event_shape = jax.eval_shape(base_bijector.inverse,
                                          jnp.zeros(event_shape)).shape
     return tfp.tf2jax.TensorShape(inverse_event_shape)
Example #17
def get_output_spec(fn, *args, **kwargs):
    """Traces a callable to determine shape and dtype of its return value(s).

  Args:
    fn: Python `callable` accepting (structures of) `Tensor` arguments and
      returning (structures) of `Tensor`s.
    *args: `Tensor` and/or `tf.TensorSpec` instances representing positional
      arguments to `fn`.
    **kwargs: `Tensor` and/or `tf.TensorSpec` instances representing named
      arguments to `fn`.
  Returns:
    structured_outputs: Object or structure of objects corresponding to the
      value(s) returned by `fn`. These objects have `.shape` and
      `.dtype` attributes; nothing else about them is guaranteed by the API.
  """

    if NUMPY_MODE:
        raise NotImplementedError(
            'Either TensorFlow or JAX is required in order '
            'to trace a function without executing it.')

    if JAX_MODE:
        import jax  # pylint: disable=g-import-not-at-top
        return jax.eval_shape(fn, *args, **kwargs)

    def _as_tensor_spec(t):
        if isinstance(t, tf.TensorSpec):
            return t
        return tf.TensorSpec.from_tensor(tf.convert_to_tensor(t))

    return tf.function(fn, autograph=False).get_concrete_function(
        *tf.nest.map_structure(_as_tensor_spec, args),
        **tf.nest.map_structure(_as_tensor_spec, kwargs)).structured_outputs
Example #18
 def inverse_event_shape_tensor(self, event_shape) -> Array:
     """Returns the shape of the output of `inverse` as a `jnp.array`."""
     self._check_shape("Inverse", event_shape,
                       base_bijector.event_ndims_out)
     inverse_event_shape = jax.eval_shape(base_bijector.inverse,
                                          jnp.zeros(event_shape)).shape
     return jnp.array(inverse_event_shape, dtype=jnp.int32)
Example #19
def get_output_tree(
    jax_function,
    *args,
    **kwargs,
):

    # we need to remove the static arguments first;
    # we first do it for the kwargs
    static_kwargs = {}
    var_kwargs = {}
    for name, arg in list(kwargs.items()):
        if not isvar(arg):
            static_kwargs.update({name: arg})
        else:
            var_kwargs.update({name: arg})

    # we need to do the same for the args
    who_static = [int(not isvar(arg)) for arg in args]
    static_args = [arg for i, arg in zip(who_static, args) if i]
    var_args = [arg for i, arg in zip(who_static, args) if not i]

    # we need to define an abstract function that only takes the
    # non-static arguments as input, internally joins them with the static
    # ones, and returns the output. This is because JAX's shape inference
    # does not work with static arguments (such as the axes of a
    # transpose)

    def abstract_func(*args, **kwargs):
        all_args = _args_formatting(args, static_args, who_static)
        return jax_function(*all_args, **kwargs, **static_kwargs)

    # now we evaluate the shape with the JAX built-in function
    tree = jax.eval_shape(abstract_func, *var_args, **var_kwargs)
    return tree
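The motivating failure is concrete: arguments such as transpose axes must be Python values at trace time, so they cannot be fed to `jax.eval_shape` as abstract inputs and have to be closed over instead. A minimal illustration:

import jax
import jax.numpy as jnp

axes = (1, 0)  # static: must stay a concrete Python value
x = jax.ShapeDtypeStruct((2, 3), jnp.float32)
# Close over the static argument; only the array argument is abstract.
out = jax.eval_shape(lambda a: jnp.transpose(a, axes), x)
print(out.shape)  # (3, 2)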
Example #20
 def forward_event_shape(self, event_shape) -> tfp.tf2jax.TensorShape:
     """Returns the shape of the output of `forward` as a `TensorShape`."""
     self._check_shape("Forward", event_shape,
                       base_bijector.event_ndims_in)
     forward_event_shape = jax.eval_shape(base_bijector.forward,
                                          jnp.zeros(event_shape)).shape
     return tfp.tf2jax.TensorShape(forward_event_shape)
Example #21
 def test_fast_eval_shape_dropout(self):
     f = lambda rng, x: basic.dropout(rng, 0.5, x)
     rng = jax.random.PRNGKey(0)
     x = jnp.ones([1])
     y_slow = jax.eval_shape(f, rng, x)
     y_fast = eval_shape.fast_eval_shape(f, rng, x)
     self.assertEqual(y_slow, y_fast)
Example #22
 def batch_shape(self) -> Tuple[int, ...]:
   """Shape of batch of distribution samples."""
   sample_shape = jax.eval_shape(
       lambda: self.sample(seed=jax.random.PRNGKey(0), sample_shape=())).shape
   if not self.event_shape:
     return sample_shape
   return sample_shape[:-len(self.event_shape)]
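Note the zero-argument lambda: `jax.eval_shape` happily traces a closure with no inputs, which is convenient when the shape depends only on captured state (here the distribution and a dummy seed). In isolation:

import jax
import jax.numpy as jnp

mean = jnp.zeros((5, 3))
sample = lambda: mean + jax.random.normal(jax.random.PRNGKey(0), mean.shape)
print(jax.eval_shape(sample).shape)  # (5, 3)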
Example #23
    def _init_state(sampler, machine, parameters, key):
        rgen = np.random.default_rng(np.asarray(key))

        σ = np.zeros((sampler.n_batches, sampler.hilbert.size), dtype=sampler.dtype)

        ma_out = jax.eval_shape(machine.apply, parameters, σ)

        state = MetropolisNumpySamplerState(
            σ=σ,
            σ1=np.copy(σ),
            log_values=np.zeros(sampler.n_batches, dtype=ma_out.dtype),
            log_values_1=np.zeros(sampler.n_batches, dtype=ma_out.dtype),
            log_prob_corr=np.zeros(
                sampler.n_batches, dtype=nkjax.dtype_real(ma_out.dtype)
            ),
            rng=rgen,
            rule_state=sampler.rule.init_state(sampler, machine, parameters, rgen),
        )

        if not sampler.reset_chains:
            key = jnp.asarray(
                state.rng.integers(0, 1 << 32, size=2, dtype=np.uint32), dtype=np.uint32
            )

            state.σ = np.copy(
                sampler.rule.random_state(sampler, machine, parameters, state, key)
            )

        return state
Example #24
 def test_abstract_to_dot(self, module_fn: ModuleFn, shape, dtype):
   f = hk.transform_with_state(lambda x: module_fn()(x))  # pylint: disable=unnecessary-lambda
   rng = jax.random.PRNGKey(42)
   x = np.ones(shape, dtype)
   params, state = jax.eval_shape(f.init, rng, x)
   self.assertIsNotNone(
       hk.experimental.abstract_to_dot(f.apply)(params, state, rng, x))
Example #25
    def assert_dtype(
        self,
        test_dtype: DType,
        module_fn: descriptors.ModuleFn,
        shape: Shape,
        input_dtype: DType,
    ):
        """Checks that modules accepting float32 input_dtype output test_dtype."""
        if jax.local_devices()[0].platform != 'tpu':
            self.skipTest('bfloat16 only supported on TPU')

        if input_dtype != jnp.float32:
            self.skipTest('Skipping module without float32 input')

        rng = jax.random.PRNGKey(42)

        def g(x):
            mod = module_fn()
            return mod(x)

        init_fn, apply_fn = hk.transform_with_state(g)

        # Create state in f32 to start.
        # NOTE: We need to do this since some initializers (e.g. random.uniform) do
        # not support <32bit dtypes.
        x = jax.random.uniform(rng, shape)
        params, state = jax.eval_shape(init_fn, rng, x)

        # Cast f32 to test_dtype.
        def make_param(v):
            dtype = test_dtype if v.dtype == jnp.float32 else v.dtype
            return jnp.ones(v.shape, dtype)

        params, state = jax.tree_map(make_param, (params, state))

        # test_dtype in should result in test_dtype out.
        x = x.astype(test_dtype)

        for _ in range(2):
            y, state = jax.eval_shape(apply_fn, params, state, rng, x)

            def assert_dtype(path, v):
                if v.dtype != jnp.int32:
                    self.assertEqual(v.dtype, test_dtype, msg=path)

            tree.map_structure_with_path(assert_dtype, y)
            tree.map_structure_with_path(assert_dtype, state)
Example #26
 def shape_dtype(self):
     if self._shape_dtype is not None:
         return self._shape_dtype
     fun = functools.partial(
         self.primitive.bind,
         **self.params)  # type: ignore  # bind-properties
     self._shape_dtype = jax.eval_shape(fun, *self.operands)
     return self._shape_dtype
Example #27
def transform_and_run_once(f, *args, **kwargs):
  f = transform.transform(f)
  def g(*args, **kwargs):
    rng = jax.random.PRNGKey(28)
    params = f.init(rng, *args, **kwargs)
    out = f.apply(params, None, *args, **kwargs)
    return params, out
  return jax.tree_map(lambda x: x.dtype, jax.eval_shape(g, *args, **kwargs))
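Because the whole init-plus-apply round trip runs under `jax.eval_shape`, this reports the dtypes of both params and outputs at essentially zero cost. A hedged usage sketch, assuming `transform_and_run_once` above is in scope:

import haiku as hk
import jax.numpy as jnp

param_dtypes, out_dtype = transform_and_run_once(
    lambda x: hk.Linear(4)(x), jnp.ones([2, 3]))
# param_dtypes is a nested mapping of dtypes; out_dtype mirrors the output.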
Example #28
def O_mean(samples, params, forward_fn, **kwargs):
    r"""
    compute \langle O \rangle
    i.e. the mean of the rows of the jacobian of forward_fn
    """
    dtype = jax.eval_shape(forward_fn, params, samples).dtype
    v = jnp.ones(samples.shape[0], dtype=dtype) * (1.0 / (samples.shape[0] * n_nodes))
    return O_vjp(samples, params, v, forward_fn, **kwargs)
Example #29
 def eval_shape(self, feat: features.FeatureDict) -> jax.ShapeDtypeStruct:
     self.init_params(feat)
     logging.info('Running eval_shape with shape(feat) = %s',
                  tree.map_structure(lambda x: x.shape, feat))
     shape = jax.eval_shape(self.apply, self.params, jax.random.PRNGKey(0),
                            feat)
     logging.info('Output shape was %s', shape)
     return shape
Example #30
 def force_fn(R, **kwargs):
   nonlocal _force_fn
   if _force_fn is None:
     out_shape = eval_shape(energy_or_force_fn, R, **kwargs).shape
     if out_shape == ():
       _force_fn = force(energy_or_force_fn)
     else:
       _force_fn = energy_or_force_fn
   return _force_fn(R, **kwargs)
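The dispatch is the usual jax-md idiom: a scalar output means `energy_or_force_fn` is an energy (so forces come from its gradient via `force`), while a non-scalar output means it already returns forces. The shape check itself, in isolation:

import jax
import jax.numpy as jnp

R = jnp.ones((5, 3))
energy_fn = lambda R: jnp.sum(R ** 2)   # scalar -> an energy
force_fn_ = lambda R: -2.0 * R          # (5, 3) -> already forces
print(jax.eval_shape(energy_fn, R).shape == ())  # True
print(jax.eval_shape(force_fn_, R).shape == ())  # False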