def __call__(self, inputs):
  """Applies a depthwise convolution to ``inputs``.

  Every input channel is convolved with its own ``self._channel_multiplier``
  filters (``feature_group_count`` equals the number of input channels).

  Args:
    inputs: Array laid out according to ``self._data_format``.

  Returns:
    The convolved array, with a bias added when ``self._with_bias`` is set.
  """
  channel_index = utils.get_channel_index(self._data_format)
  in_channels = inputs.shape[channel_index]
  out_channels = self._channel_multiplier * in_channels

  weight_shape = self._kernel_shape + (1, out_channels)
  # Each filter sees a single input channel, so fan-in is the kernel volume
  # (all axes except the trailing output axis).
  fan_in_shape = np.prod(weight_shape[:-1])
  stddev = 1. / np.sqrt(fan_in_shape)
  w_init = self._w_init or initializers.TruncatedNormal(stddev=stddev)
  w = base.get_parameter("w", weight_shape, inputs.dtype, init=w_init)

  # Derive the dimension numbers from the same channel index that is used for
  # the weights and the bias, so all three agree on the layout.
  if channel_index == -1:
    dn = DIMENSION_NUMBERS[self._num_spatial_dims]
  else:
    dn = DIMENSION_NUMBERS_NCSPATIAL[self._num_spatial_dims]

  result = lax.conv_general_dilated(
      inputs, w, self._stride, self._padding, self._lhs_dilation,
      self._rhs_dilation, dn, feature_group_count=in_channels)

  if self._with_bias:
    if channel_index == -1:
      bias_shape = (out_channels,)
    else:
      # Channels-first: broadcast over every spatial axis, not just two.
      bias_shape = (out_channels,) + (1,) * self._num_spatial_dims
    # Match the weight dtype instead of silently defaulting.
    b = base.get_parameter("b", bias_shape, inputs.dtype, init=self._b_init)
    result = result + b
  return result
def __call__(self, inputs, state):
  """Runs a single GRU step.

  Args:
    inputs: Rank-1 or rank-2 array whose last axis holds the input features.
    state: Previous hidden state. It is multiplied by a
      ``[hidden_size, ...]`` weight, so its last axis is presumably
      ``hidden_size`` (TODO confirm with callers).

  Returns:
    ``(next_state, next_state)``: the output and the new state are the same
    array.

  Raises:
    ValueError: If ``inputs`` is a scalar or has rank greater than 2.
  """
  if len(inputs.shape) > 2 or not inputs.shape:
    raise ValueError("GRU input must be rank-1 or rank-2.")

  input_size = inputs.shape[-1]
  hidden_size = self.hidden_size
  # Create every parameter in the input dtype. Previously only `b` did, so the
  # weights silently used the default dtype for e.g. bfloat16 inputs.
  w_i = base.get_parameter(
      name="w_i", shape=[input_size, 3 * hidden_size], dtype=inputs.dtype,
      init=self._w_i_init)
  w_h = base.get_parameter(
      name="w_h", shape=[hidden_size, 3 * hidden_size], dtype=inputs.dtype,
      init=self._w_h_init)
  b = base.get_parameter(
      name="b", shape=[3 * hidden_size], dtype=inputs.dtype,
      init=self._b_init)

  # The first 2 * hidden_size columns feed the z/r gates; the rest feed the
  # candidate activation.
  w_h_z, w_h_a = jnp.split(w_h, indices_or_sections=[2 * hidden_size], axis=1)
  b_z, b_a = jnp.split(b, indices_or_sections=[2 * hidden_size], axis=0)

  gates_x = jnp.matmul(inputs, w_i)
  zr_x, a_x = jnp.split(
      gates_x, indices_or_sections=[2 * hidden_size], axis=-1)
  zr_h = jnp.matmul(state, w_h_z)
  zr = zr_x + zr_h + jnp.broadcast_to(b_z, zr_h.shape)
  z, r = jnp.split(jax.nn.sigmoid(zr), indices_or_sections=2, axis=-1)

  # The reset gate scales the previous state before the candidate projection.
  a_h = jnp.matmul(r * state, w_h_a)
  a = jnp.tanh(a_x + a_h + jnp.broadcast_to(b_a, a_h.shape))

  next_state = (1 - z) * state + z * a
  return next_state, next_state
def test_unable_to_mutate_name(self):
  """A custom creator that renames the parameter must be rejected."""

  def renaming_creator(next_creator, name, shape, dtype, init):
    next_creator(name + "_foo", shape, dtype, init)

  pattern = "Modifying .*name.* not supported"
  with base.new_context(), base.custom_creator(renaming_creator):
    with self.assertRaisesRegex(ValueError, pattern):
      base.get_parameter("w", [], init=jnp.ones)
def test_context_copies_input(self):
  """new_context must copy the params/state it is given, not alias them."""
  initial = {"~": {"w": jnp.array(1.)}}
  with base.new_context(params=initial, state=initial) as ctx:
    base.get_parameter("w", [], init=jnp.ones)
    base.set_state("w", jnp.array(2.))

  expected_params = {"~": {"w": jnp.array(1.)}}
  self.assertEqual(ctx.collect_params(), expected_params)
  # The initial state must be equal to, but a different object from, the input.
  self.assertIsNot(ctx.collect_initial_state(), initial)
  self.assertEqual(ctx.collect_initial_state(), initial)
  expected_state = {"~": {"w": jnp.array(2.)}}
  self.assertEqual(ctx.collect_state(), expected_state)
  # The caller's dict must be left untouched.
  untouched = {"~": {"w": jnp.array(1.)}}
  self.assertEqual(initial, untouched)
def __call__(self, inputs, scale=None, offset=None):
  """Normalizes ``inputs`` over the configured axes and applies scale/offset.

  Args:
    inputs: Array whose trailing axis holds the features, i.e. ``[N, ..., C]``.
    scale: Optional array broadcastable to ``inputs.shape`` that multiplies
      the normalized values. It must be omitted when the module was built
      with ``create_scale=True``, since a learned ``scale`` parameter is used
      then. Otherwise it defaults to ``1.``.
    offset: Optional array broadcastable to ``inputs.shape`` that is added to
      the normalized values. It must be omitted when the module was built
      with ``create_offset=True``, since a learned ``offset`` parameter is
      used then. Otherwise it defaults to ``0.``.

  Returns:
    The normalized array, with the same shape as ``inputs``.

  Raises:
    ValueError: If ``scale`` or ``offset`` is passed while the module creates
      the corresponding parameter itself.
  """
  if isinstance(self._axis, slice):
    # Resolve the slice against this input's rank into explicit axis indices.
    axis = tuple(range(len(inputs.shape)))[self._axis]
  else:
    axis = self._axis

  mean = jnp.mean(inputs, axis=axis, keepdims=True)
  variance = jnp.var(inputs, axis=axis, keepdims=True)

  # Learned parameters cover only the trailing (feature) axis.
  param_shape = inputs.shape[-1:]

  if not self._create_scale:
    scale = 1. if scale is None else scale
  elif scale is not None:
    raise ValueError(
        "Cannot pass `scale` at call time if `create_scale=True`.")
  else:
    scale = base.get_parameter("scale", param_shape, init=self._scale_init)

  if not self._create_offset:
    offset = 0. if offset is None else offset
  elif offset is not None:
    raise ValueError(
        "Cannot pass `offset` at call time if `create_offset=True`.")
  else:
    offset = base.get_parameter("offset", param_shape, init=self._offset_init)

  scale = jnp.broadcast_to(scale, inputs.shape)
  offset = jnp.broadcast_to(offset, inputs.shape)
  mean = jnp.broadcast_to(mean, inputs.shape)

  inv = scale * jax.lax.rsqrt(variance + self._eps)
  return inv * (inputs - mean) + offset
def test_init_custom_creator(self):
  """A creator can swap the initializer before the parameter is created."""

  def force_zeros(next_creator, name, shape, dtype, init):
    self.assertEqual(name, "~/w")
    self.assertEqual(shape, [])
    self.assertEqual(dtype, jnp.float32)
    self.assertEqual(init, jnp.ones)
    return next_creator(name, shape, dtype, jnp.zeros)

  with base.new_context() as ctx, base.custom_creator(force_zeros):
    base.get_parameter("w", [], init=jnp.ones)

  expected = {"~": {"w": jnp.zeros([])}}
  self.assertEqual(ctx.collect_params(), expected)
def __call__(self, inputs):
  """Applies the affine map ``inputs @ w`` (plus ``b`` when enabled).

  Args:
    inputs: Non-scalar array whose last axis is the input feature size.

  Returns:
    An array whose last axis has size ``self.output_size``.

  Raises:
    ValueError: If ``inputs`` is a scalar.
  """
  if not inputs.shape:
    raise ValueError("Input must not be scalar.")

  # Side effect: the inferred input size is stored on the module.
  self.input_size = inputs.shape[-1]
  fallback_stddev = 1. / jnp.sqrt(self.input_size)
  w_init = self.w_init or initializers.TruncatedNormal(stddev=fallback_stddev)
  weights = base.get_parameter(
      "w", [self.input_size, self.output_size], inputs.dtype, init=w_init)

  outputs = jnp.dot(inputs, weights)
  if not self.with_bias:
    return outputs

  bias = base.get_parameter(
      "b", [self.output_size], inputs.dtype, init=self.b_init)
  return outputs + bias
def test_parameter_in_apply(self, params):
  """Creating a new parameter during apply must be an error."""

  def create_w():
    return base.get_parameter("w", [], init=jnp.zeros)

  _, apply_fn = transform.transform(create_w)
  expected_error = "parameters must be created as part of `init`"
  with self.assertRaisesRegex(ValueError, expected_error):
    apply_fn(params)
def __call__(self, inputs: jnp.ndarray, multiplier: FloatLike = None):
  """Adds a learned bias to ``inputs``, optionally scaling it first.

  Args:
    inputs: Array of shape ``[batch_size, input_size1, ...]`` (rank >= 2).
    multiplier: Optional scalar or array. When given, the bias is multiplied
      by it before being added; anything valid in ``bias * multiplier``
      works. For example, ``multiplier=-1`` subtracts the same bias that was
      added elsewhere.

  Returns:
    An array with the same shape as ``inputs``.

  Raises:
    ValueError: If ``self.output_size`` is set and does not match
      ``inputs.shape[1:]``.
  """
  utils.assert_minimum_rank(inputs, 2)

  input_shape = inputs.shape
  self.bias_shape = calculate_bias_shape(input_shape, self.bias_dims)

  # Every axis except the leading batch axis.
  input_size = input_shape[1:]
  if self.output_size is not None and self.output_size != input_size:
    expected_shape = (-1,) + self.output_size
    raise ValueError(f"Input shape must be {expected_shape} not {input_shape}")
  self.input_size = input_size

  bias = base.get_parameter("b", self.bias_shape, inputs.dtype,
                            init=self.b_init)
  bias = jnp.broadcast_to(bias, inputs.shape)
  if multiplier is None:
    return inputs + bias
  return inputs + (bias * multiplier)
def test_do_not_store(self):
  """Interceptors returning DO_NOT_STORE keep values out of the context."""

  def skip_creator(next_creator, shape, dtype, init, context):
    del next_creator, shape, dtype, init, context
    return base.DO_NOT_STORE

  def reinit_getter(next_getter, value, context):
    assert value is base.DO_NOT_STORE
    # Nothing was stored, so rebuild the value from the original request.
    fresh = context.original_init(context.original_shape,
                                  context.original_dtype)
    return next_getter(fresh)

  def skip_setter(next_setter, value, context):
    del next_setter, value, context
    return base.DO_NOT_STORE

  with base.new_context() as ctx:
    with base.custom_creator(skip_creator, state=True):
      with base.custom_getter(reinit_getter, state=True):
        with base.custom_setter(skip_setter):
          self.assertEqual(base.get_parameter("w", [], init=jnp.ones), 1)
          self.assertEqual(base.get_state("s1", [], init=jnp.ones), 1)
          base.set_state("s2", jnp.ones([]))

  self.assertEmpty(ctx.collect_params())
  self.assertEmpty(ctx.collect_state())
def __call__(self, *args, **kwargs):
  """Returns lifted ``(params, state)`` for the wrapped init function.

  If a sentinel parameter is already present in
  ``frame.params[self.module_name]``, the params/state stored under the
  ``self.module_name + "/"`` prefix are unpacked from the current frame and
  returned, and ``self._init_fn`` is not run. Otherwise a sentinel parameter
  is created, ``self._init_fn(*args, **kwargs)`` is run, and its params/state
  are packed into the current frame under ``self.module_name``.

  Args:
    *args: Forwarded to ``self._init_fn`` (only in the non-sentinel branch).
    **kwargs: Forwarded to ``self._init_fn`` (only in the non-sentinel
      branch).

  Returns:
    A ``(params, state)`` pair.
  """
  frame = base.current_frame()
  bundle_name = self.module_name
  # NOTE(review): indexes frame.params[bundle_name] directly; presumably the
  # frame's params mapping tolerates/creates missing keys — confirm.
  if _SENTINEL_NAME in frame.params[bundle_name]:
    prefix = bundle_name + "/"
    lifted_params = unpack_from_dict(frame.params, prefix)
    lifted_state = unpack_from_dict(frame.state, prefix)
    return lifted_params, lifted_state
  else:
    # Ensure sentinel is set for apply.
    # Order matters: the sentinel is registered before the inner init runs.
    base.get_parameter(_SENTINEL_NAME, (), init=jnp.zeros)
    # Lift parameters into this transform's params_dict.
    params, state = self._init_fn(*args, **kwargs)
    pack_into_dict(params, frame.params, bundle_name)
    pack_into_dict(state, frame.state, bundle_name)
    return params, state
def __init__(self,
             embedding_dim: int,
             num_embeddings: int,
             commitment_cost: float,
             dtype: DType = jnp.float32,
             name: str = None):
  """Initializes a VQ-VAE module.

  Args:
    embedding_dim: dimensionality of the tensors in the quantized space.
      Inputs to the modules must be in this format as well.
    num_embeddings: number of vectors in the quantized space.
    commitment_cost: scalar which controls the weighting of the loss terms
      (see equation 4 in the paper - this variable is Beta).
    dtype: dtype for the embeddings variable, defaults to jnp.float32.
    name: name of the module.
  """
  super(VectorQuantizer, self).__init__(name=name)
  self.embedding_dim = embedding_dim
  self.num_embeddings = num_embeddings
  self.commitment_cost = commitment_cost

  # Codebook stored as [embedding_dim, num_embeddings]: each of the
  # num_embeddings vectors is a column.
  embedding_shape = [embedding_dim, num_embeddings]
  initializer = initializers.VarianceScaling(distribution='uniform')
  self.embeddings = base.get_parameter(
      'embeddings', embedding_shape, dtype, init=initializer)
def test_getter_types(self, params, state):
  """A getter only intercepts the collections it was registered for."""
  seen = []

  def recording_getter(next_getter, value, context):
    seen.append(context.full_name)
    return next_getter(value)

  getter_ctx = base.custom_getter(recording_getter, params=params, state=state)
  with base.new_context(), getter_ctx:
    base.get_parameter("params", [], init=jnp.zeros)
    base.get_state("state", [], init=jnp.zeros)

  self.assertLen(seen, int(params) + int(state))
  if params:
    self.assertIn("~/params", seen)
  if state:
    self.assertIn("~/state", seen)
def __call__(self, inputs):
  """Applies the N-D convolution to ``inputs``.

  Args:
    inputs: Array of rank ``self._num_spatial_dims + 2``. The channel axis is
      ``self._channel_index``, e.g. ``[N, spatial_dims, C]`` for
      channels-last.

  Returns:
    The convolution result of the same rank, with ``self._output_channels``
    channels.

  Raises:
    ValueError: If ``inputs`` has the wrong rank, or if a mask is configured
      whose shape differs from the weights.
  """
  expected_rank = self._num_spatial_dims + 2
  if len(inputs.shape) != expected_rank:
    raise ValueError(
        "Input to ConvND needs to have rank {}, but input "
        "has shape {}.".format(expected_rank, inputs.shape))

  in_channels = inputs.shape[self._channel_index]
  weight_shape = self._kernel_shape + (in_channels, self._output_channels)
  # Fan-in covers every weight axis except the output-channel axis.
  stddev = 1. / np.sqrt(np.prod(weight_shape[:-1]))
  w_init = self._w_init or initializers.TruncatedNormal(stddev=stddev)
  w = base.get_parameter("w", weight_shape, inputs.dtype, init=w_init)

  if self._mask is not None:
    if self._mask.shape != w.shape:
      raise ValueError(
          "Mask needs to have the same shape as weights. "
          "Shapes are: {}, {}".format(self._mask.shape, w.shape))
    w *= self._mask

  # NOTE(review): attribute is spelled `_kernal_dilation`; confirm it matches
  # the constructor.
  result = lax.conv_general_dilated(
      inputs,
      w,
      self._stride,
      self._padding,
      lhs_dilation=self._lhs_dilation,
      rhs_dilation=self._kernal_dilation,
      dimension_numbers=self._dn)

  if not self._with_bias:
    return result

  if self._channel_index == -1:
    bias_shape = (self._output_channels,)
  else:
    # Channels-first: one bias per channel, broadcast over spatial axes.
    bias_shape = (self._output_channels,) + (1,) * self._num_spatial_dims
  b = base.get_parameter("b", bias_shape, inputs.dtype, init=self._b_init)
  return result + b
def test_custom_getter_bf16(self):
  """A getter can downcast float32 params without changing stored values."""

  def to_bf16(next_getter, value, context):
    del context
    if value.dtype != jnp.float32:
      return next_getter(value)
    return next_getter(value.astype(jnp.bfloat16))

  with base.new_context() as ctx:
    with base.custom_getter(to_bf16):
      f = base.get_parameter("f", [], jnp.float32, init=jnp.ones)
      i = base.get_parameter("i", [], jnp.int32, init=jnp.ones)

  stored = ctx.collect_params()["~"]
  # Stored parameters keep their creation dtype...
  self.assertEqual(stored["f"].dtype, jnp.float32)
  # ...while the getter's view is cast for floats only.
  self.assertEqual(f.dtype, jnp.bfloat16)
  self.assertEqual(stored["i"].dtype, jnp.int32)
  self.assertEqual(i.dtype, jnp.int32)
def __call__(self, x):
  """Raises scalar ``x`` to a learned int32 power and records it as state."""
  assert x.ndim == 0

  def two(*_):
    return jnp.array(2)

  exponent = base.get_parameter("p", [], jnp.int32, init=two)
  powered = x**exponent
  base.set_state("y", powered)
  return powered
def test_nested_creators(self):
  """Nested creators run outermost-first."""
  log = []

  def make_creator(tag):
    def creator(next_creator, name, shape, dtype, init):
      log.append(tag)
      return next_creator(name, shape, dtype, init)
    return creator

  with base.new_context():
    with base.custom_creator(make_creator("a")):
      with base.custom_creator(make_creator("b")):
        with base.custom_creator(make_creator("c")):
          base.get_parameter("w", [], init=jnp.ones)

  self.assertEqual(log, ["a", "b", "c"])
def __call__(self, carry, x):
  """Adds a learned offset to ``x``, then applies a lifted inner transform.

  Returns:
    ``(carry, y)``, where ``carry`` is passed through unchanged.
  """
  x += base.get_parameter("w", shape=[], init=jnp.zeros)
  inner = transform.transform(inner_fn)

  # An RNG key is only drawn while initializing.
  if transform.running_init():
    keys = base.next_rng_key()
  else:
    keys = None

  lifted_init = lift.lift(inner.init, allow_reuse=self._allow_reuse)
  params = lifted_init(keys, x)
  return carry, inner.apply(params, None, x)
def test_unable_to_mutate_name(self):
  """Renaming a parameter inside a creator fails under transform."""

  def renaming_creator(next_creator, name, shape, dtype, init):
    next_creator(name + "_foo", shape, dtype, init)

  def make_w():
    return base.get_parameter("w", [], init=jnp.ones)

  init_fn, _ = base.transform(make_w)
  pattern = "Modifying .*name.* not supported"
  with self.assertRaisesRegex(ValueError, pattern):
    with base.custom_creator(renaming_creator):
      init_fn(None)
def __call__(self, x):
  """Adds a learned [10, 10] offset, then runs a transparently lifted module."""
  x += base.get_parameter("a", shape=[10, 10], init=jnp.zeros)

  def inner_fn(x):
    return InnerModule(name="inner")(x)

  inner = transform.transform(inner_fn)
  lifted_init = lift.transparent_lift(inner.init)
  inner_params = lifted_init(base.next_rng_key(), x)
  return inner.apply(inner_params, base.next_rng_key(), x)
def test_init_custom_creator(self):
  """A creator swaps the initializer of a transformed function's parameter."""

  def force_zeros(next_creator, name, shape, dtype, init):
    self.assertEqual(name, "~/w")
    self.assertEqual(shape, [])
    self.assertEqual(dtype, jnp.float32)
    self.assertEqual(init, jnp.ones)
    return next_creator(name, shape, dtype, jnp.zeros)

  def make_w():
    return base.get_parameter("w", [], init=jnp.ones)

  init_fn, _ = base.transform(make_w)
  with base.custom_creator(force_zeros):
    params = init_fn(None)

  self.assertEqual(params, {"~": {"w": jnp.zeros([])}})
def f():
  """Builds an ``(init, (add, sub))`` bundle sharing one parameter and state.

  ``add`` and ``sub`` each re-fetch ``s`` and then ``w`` and return them
  alongside the values fetched here.
  """
  w = base.get_parameter('w', [], init=jnp.zeros)
  s = base.get_state('s', [], init=jnp.zeros)
  init = lambda: None

  def refetch():
    # State is fetched before the parameter, as in both entry points.
    s_again = base.get_state('s', [], init=jnp.zeros)
    w_again = base.get_parameter('w', [], init=jnp.zeros)
    return w, w_again, s, s_again

  def add():
    return refetch()

  def sub():
    return refetch()

  return init, (add, sub)
def test_nested_creators(self):
  """Nested creators run outermost-first for transformed functions."""
  log = []

  def make_creator(tag):
    def creator(next_creator, name, shape, dtype, init):
      log.append(tag)
      return next_creator(name, shape, dtype, init)
    return creator

  def make_w():
    return base.get_parameter("w", [], init=jnp.ones)

  init_fn, _ = base.transform(make_w)
  outer, middle, inner = [make_creator(tag) for tag in ("a", "b", "c")]
  with base.custom_creator(outer):
    with base.custom_creator(middle):
      with base.custom_creator(inner):
        init_fn(None)

  self.assertEqual(log, ["a", "b", "c"])
def test_original_shape(self):
  """context.original_shape survives an outer creator changing the shape."""

  def reshaping_creator(next_creator, shape, dtype, init, context):
    del shape, context
    return next_creator((1, 2, 3), dtype, init)

  def restoring_creator(next_creator, shape, dtype, init, context):
    assert shape == (1, 2, 3)
    return next_creator(context.original_shape, dtype, init)

  with base.new_context():
    with base.custom_creator(reshaping_creator):
      with base.custom_creator(restoring_creator):
        param = base.get_parameter("w", [5], jnp.bfloat16, jnp.ones)

  assert param.shape == (5,)
def test_assert_no_new_parameters(self):
  """assert_no_new_parameters ignores reused params and flags new ones."""
  with base.new_context():
    base.get_parameter("w", [], init=jnp.zeros)

    # Re-fetching the existing "w" must not trip the guard.
    with base.assert_no_new_parameters():
      base.get_parameter("w", [], init=jnp.zeros)

    # "x" has not been created yet, so the guard must report it.
    expected = "New parameters were created: .*x"
    with self.assertRaisesRegex(AssertionError, expected):
      with base.assert_no_new_parameters():
        base.get_parameter("x", [], init=jnp.zeros)
def test_original_dtype(self):
  """bfloat16 params are stored as float32 but returned as bfloat16."""

  def dtype_cast_creator(next_creator, shape, dtype, init, context):
    # Store parameters requested as bfloat16 in float32.
    if context.original_dtype == jnp.bfloat16:
      dtype = jnp.float32
    return next_creator(shape, dtype, init)

  def dtype_recast_getter(next_getter, value, context):
    # Hand the caller back the dtype it originally requested.
    if context.original_dtype == jnp.bfloat16:
      self.assertEqual(value.dtype, jnp.float32)
      value = value.astype(jnp.bfloat16)
    return next_getter(value)

  with base.new_context() as ctx:
    with base.custom_creator(dtype_cast_creator), \
        base.custom_getter(dtype_recast_getter):
      param = base.get_parameter("w", [], jnp.bfloat16, jnp.ones)

  # `jax.tree_leaves` is a deprecated top-level alias; use jax.tree_util.
  orig_param = jax.tree_util.tree_leaves(ctx.collect_params())[0]
  self.assertEqual(param.dtype, jnp.bfloat16)
  self.assertEqual(orig_param.dtype, jnp.float32)
def test_nested_getters(self):
  """Nested getters run outermost-first, each seeing the previous output."""
  log = []

  def make_getter(tag, dtype_in, dtype_out):
    def getter(next_getter, value, context):
      del context
      log.append(tag)
      self.assertEqual(value.dtype, dtype_in)
      return next_getter(value.astype(dtype_out))
    return getter

  with base.new_context():
    with base.custom_getter(make_getter("a", jnp.float32, jnp.bfloat16)):
      with base.custom_getter(make_getter("b", jnp.bfloat16, jnp.int32)):
        with base.custom_getter(make_getter("c", jnp.int32, jnp.int8)):
          w = base.get_parameter("w", [], init=jnp.ones)

  self.assertEqual(w.dtype, jnp.int8)
  self.assertEqual(log, ["a", "b", "c"])
def bias_fn(x):
  """Returns ``x`` plus a learned scalar ``b`` (initialized to ones)."""
  bias = base.get_parameter("b", [], init=jnp.ones)
  return bias + x if False else x + bias
def __init__(self,
             vocab_size=None,
             embed_dim=None,
             embedding_matrix=None,
             w_init=None,
             lookup_style=EmbedLookupStyle.ARRAY_INDEX.name,
             name=None):
  """Constructs an Embed module.

  Args:
    vocab_size: int or None: the number of unique tokens to embed. If not
      provided, an existing vocabulary matrix from which vocab_size can be
      inferred must be provided as `embedding_matrix`.
    embed_dim: int or None. Number of dimensions to assign to each embedding.
      If an existing vocabulary matrix initializes the module, this should not
      be provided as it will be inferred.
    embedding_matrix: A matrix-like object equivalent in size to
      [vocab_size, embed_dim]. If given, it is used as the initial value for
      the embedding matrix and neither vocab_size or embed_dim need be given.
      If they are given, their values are checked to be consistent with the
      dimensions of embedding_matrix.
    w_init: An initializer for the embeddings matrix. As a default,
      embeddings are initialized via a truncated normal distribution.
    lookup_style: One of the enum values of EmbedLookupStyle determining how
      to access the value of the embeddings given an ID. Regardless, the
      input should be a dense array of integer values representing ids. This
      setting changes how internally this module maps those ids to
      embeddings. The result is the same, but the speed and memory tradeoffs
      are different. It defaults to using numpy-style array indexing. This
      value is only the default for the module, and at any given invocation
      can be overridden in the __call__ method.
    name: string. Name for this module.

  Raises:
    ValueError: If none of embed_dim, embedding_matrix and vocab_size are
      supplied, if embedding_matrix is not rank 2, or if embedding_matrix is
      supplied and embed_dim or vocab_size is not consistent with the
      supplied matrix.
  """
  super(Embed, self).__init__(name=name)
  # Test against None explicitly: truth-testing a multi-element array raises
  # ("truth value of an array ... is ambiguous").
  if embedding_matrix is None and not (vocab_size and embed_dim):
    raise ValueError(
        "hk.Embed must be supplied either with an initial `embedding_matrix` "
        "or with `embed_dim` and `vocab_size`.")

  if embedding_matrix is not None:
    embedding_matrix = jnp.asarray(embedding_matrix)
    if embedding_matrix.ndim != 2:
      raise ValueError(
          "`embedding_matrix` must be rank 2 ([vocab_size, embed_dim]), but "
          "has shape {}.".format(embedding_matrix.shape))
    if vocab_size and embedding_matrix.shape[0] != vocab_size:
      raise ValueError(
          "An `embedding_matrix` was supplied but the `vocab_size` of {vs} "
          "was not consistent with its shape {emb_shape}.".format(
              vs=vocab_size, emb_shape=embedding_matrix.shape))
    if embed_dim and embedding_matrix.shape[1] != embed_dim:
      raise ValueError(
          "An `embedding_matrix` was supplied but the `embed_dim` of {ed} "
          "was not consistent with its shape {emb_shape}.".format(
              ed=embed_dim, emb_shape=embedding_matrix.shape))
    # The initializer ignores shape/dtype and returns the supplied matrix.
    self._embedding = base.get_parameter(
        "embeddings", shape=embedding_matrix.shape,
        init=lambda _, __: embedding_matrix)
  else:
    w_init = w_init or hk_init.TruncatedNormal()
    self._embedding = base.get_parameter(
        "embeddings", shape=[vocab_size, embed_dim], init=w_init)

  self._vocab_size = vocab_size or embedding_matrix.shape[0]
  self._embed_dim = embed_dim or embedding_matrix.shape[1]
  self._lookup_style = lookup_style
def __call__(self):
  """Creates "w" at the current scope and again inside a "foo" name scope."""
  outer_w = base.get_parameter("w", [], init=jnp.zeros)
  with module.name_scope("foo"):
    scoped_w = base.get_parameter("w", [], init=jnp.zeros)
  return outer_w, scoped_w