Example #1
    def test_complex_dtype(self):
        if jax.local_devices()[0].platform == "tpu":
            self.skipTest("Complex dtype not supported by TPU")
        # This just makes sure we can call the initializers in accordance with
        # the API and get the right shapes and dtypes out.
        inits = [
            initializers.Constant(42. + 1j * 1729.),
            initializers.RandomNormal(),
            initializers.RandomNormal(2.0),
            initializers.RandomNormal(2. - 3j),
            initializers.TruncatedNormal(),
            initializers.TruncatedNormal(2.),
            initializers.TruncatedNormal(2., 1. - 1j),

            # Users are supposed to be able to use these.
            jnp.zeros,
            jnp.ones,
        ]

        shape = (5, 13, 17)

        dtype = jnp.complex64
        for init in inits:
            generated = init(shape, dtype)
            self.assertEqual(generated.shape, shape)
            self.assertEqual(generated.dtype, dtype)
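For reference, a minimal sketch of how a complex-valued initializer like the ones tested above gets consumed, assuming the public dm-haiku API (hk.get_parameter, hk.transform); the 2. - 3j stddev mirrors the test, everything else is illustrative:

import jax
import jax.numpy as jnp
import haiku as hk

def forward_fn(x):
    # get_parameter invokes the initializer as init(shape, dtype); a complex
    # dtype here yields a complex-valued weight matrix.
    w = hk.get_parameter(
        "w", shape=[x.shape[-1], 4], dtype=jnp.complex64,
        init=hk.initializers.RandomNormal(stddev=2. - 3j))
    return x @ w

forward = hk.transform(forward_fn)
x = jnp.ones([2, 3], dtype=jnp.complex64)
params = forward.init(jax.random.PRNGKey(0), x)  # skipped on TPU, as above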
Example #2
    def __call__(self, inputs):

        channel_index = utils.get_channel_index(self._data_format)
        weight_shape = self._kernel_shape + (1, self._channel_multiplier *
                                             inputs.shape[channel_index])
        fan_in_shape = np.prod(weight_shape[:-1])
        stddev = 1. / np.sqrt(fan_in_shape)
        w_init = self._w_init or initializers.TruncatedNormal(stddev=stddev)
        w = base.get_parameter("w", weight_shape, inputs.dtype, init=w_init)
        if channel_index == -1:
            dn = DIMENSION_NUMBERS[self._num_spatial_dims]
        else:
            dn = DIMENSION_NUMBERS_NCSPATIAL[self._num_spatial_dims]
        result = lax.conv_general_dilated(
            inputs,
            w,
            self._stride,
            self._padding,
            self._lhs_dilation,
            self._rhs_dilation,
            dn,
            feature_group_count=inputs.shape[channel_index])
        if self._with_bias:
            if channel_index == -1:
                bias_shape = (self._channel_multiplier *
                              inputs.shape[channel_index], )
            else:
                bias_shape = (self._channel_multiplier *
                              inputs.shape[channel_index], 1, 1)
            b = base.get_parameter("b", bias_shape, init=self._b_init)
            result = result + b
        return result
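This is a depthwise convolution: feature_group_count equals the input channel count, so lax.conv_general_dilated convolves each channel separately. A minimal usage sketch, assuming the public hk.DepthwiseConv2D wrapper around this implementation:

import jax
import jax.numpy as jnp
import haiku as hk

def forward_fn(images):
    # channel_multiplier=2 gives each input channel two 3x3 filters,
    # so 3 input channels produce 6 output channels.
    net = hk.DepthwiseConv2D(channel_multiplier=2, kernel_shape=3,
                             stride=1, padding="SAME")
    return net(images)

forward = hk.transform(forward_fn)
images = jnp.ones([8, 28, 28, 3])  # NHWC
params = forward.init(jax.random.PRNGKey(0), images)
out = forward.apply(params, None, images)
print(out.shape)  # (8, 28, 28, 6)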
Example #3
    def test_initializers(self):
        # This just makes sure we can call the initializers in accordance with
        # the API and get the right shapes and dtypes out.
        inits = [
            initializers.Constant(42.0),
            initializers.RandomNormal(),
            initializers.RandomNormal(2.0),
            initializers.RandomUniform(),
            initializers.RandomUniform(3.0),
            initializers.VarianceScaling(),
            initializers.VarianceScaling(2.0),
            initializers.VarianceScaling(2.0, mode="fan_in"),
            initializers.VarianceScaling(2.0, mode="fan_out"),
            initializers.VarianceScaling(2.0, mode="fan_avg"),
            initializers.VarianceScaling(2.0, distribution="truncated_normal"),
            initializers.VarianceScaling(2.0, distribution="normal"),
            initializers.VarianceScaling(2.0, distribution="uniform"),
            initializers.UniformScaling(),
            initializers.UniformScaling(2.0),
            initializers.TruncatedNormal(),
            initializers.Orthogonal(),

            # Users are supposed to be able to use these.
            jnp.zeros,
            jnp.ones,
        ]

        # TODO(ibab): Test other shapes as well.
        shape = (20, 42)

        dtype = jnp.float32
        for init in inits:
            generated = init(shape, dtype)
            self.assertEqual(generated.shape, shape)
            self.assertEqual(generated.dtype, dtype)
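Each entry above is a callable invoked as init(shape, dtype). Note that most haiku initializers draw randomness via hk.next_rng_key(), so calling them requires a transform context (the test presumably runs under one). A minimal sketch of invoking one directly, assuming the public API:

import jax
import jax.numpy as jnp
import haiku as hk

def sample_fn(shape, dtype):
    # fan_avg + uniform with scale=1.0 is Glorot uniform; scale=2.0
    # doubles the variance.
    init = hk.initializers.VarianceScaling(2.0, mode="fan_avg",
                                           distribution="uniform")
    return init(shape, dtype)

# hk.transform supplies the RNG context the initializer draws from.
sample = hk.transform(sample_fn)
params = sample.init(jax.random.PRNGKey(0), (20, 42), jnp.float32)
values = sample.apply(params, jax.random.PRNGKey(1), (20, 42), jnp.float32)
print(values.shape, values.dtype)  # (20, 42) float32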
Example #4
File: basic.py  Project: ml-lab/dm-haiku
  def __call__(self, inputs):
    if not inputs.shape:
      raise ValueError("Input must not be scalar.")

    self.input_size = inputs.shape[-1]
    default_stddev = 1. / jnp.sqrt(self.input_size)
    w_init = self.w_init or initializers.TruncatedNormal(stddev=default_stddev)

    w = base.get_parameter("w", [self.input_size, self.output_size],
                           inputs.dtype, init=w_init)
    out = jnp.dot(inputs, w)
    if self.with_bias:
      out += base.get_parameter("b", [self.output_size], inputs.dtype,
                                init=self.b_init)
    return out
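A minimal usage sketch for this __call__ (the body of haiku's Linear module), assuming the public hk.Linear and hk.transform:

import jax
import jax.numpy as jnp
import haiku as hk

def forward_fn(x):
    # w defaults to TruncatedNormal(stddev=1/sqrt(input_size)), as above.
    return hk.Linear(output_size=10)(x)

forward = hk.transform(forward_fn)
x = jnp.ones([4, 784])
params = forward.init(jax.random.PRNGKey(0), x)
out = forward.apply(params, None, x)  # rng=None: the layer is deterministic
print(out.shape)  # (4, 10)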
Example #5
File: conv.py  Project: ibab/haiku
    def __call__(self, inputs):
        """Connects `ConvND` layer.

    Args:
      inputs: A rank-N+2 array with shape [N, spatial_dims, C].

    Returns:
      A rank-N+2 array with shape [N, spatial_dims, output_channels].
    """
        if len(inputs.shape) != self._num_spatial_dims + 2:
            raise ValueError(
                "Input to ConvND needs to have rank {}, but input "
                "has shape {}.".format(self._num_spatial_dims + 2,
                                       inputs.shape))
        weight_shape = self._kernel_shape + (inputs.shape[self._channel_index],
                                             self._output_channels)

        fan_in_shape = np.prod(weight_shape[:-1])
        stddev = 1. / np.sqrt(fan_in_shape)
        w_init = self._w_init or initializers.TruncatedNormal(stddev=stddev)
        w = base.get_parameter("w", weight_shape, inputs.dtype, init=w_init)

        if self._mask is not None:
            if self._mask.shape != w.shape:
                raise ValueError(
                    "Mask needs to have the same shape as weights. "
                    "Shapes are: {}, {}".format(self._mask.shape, w.shape))
            w *= self._mask
        result = lax.conv_general_dilated(inputs,
                                          w,
                                          self._stride,
                                          self._padding,
                                          lhs_dilation=self._lhs_dilation,
                                          rhs_dilation=self._kernel_dilation,
                                          dimension_numbers=self._dn)
        if self._with_bias:
            if self._channel_index == -1:
                bias_shape = (self._output_channels, )
            else:
                bias_shape = (
                    self._output_channels, ) + (1, ) * self._num_spatial_dims
            b = base.get_parameter("b",
                                   bias_shape,
                                   inputs.dtype,
                                   init=self._b_init)
            result = result + b
        return result
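A usage sketch, assuming the public hk.Conv2D subclass of this ConvND (num_spatial_dims is 2, so inputs must be rank 4):

import jax
import jax.numpy as jnp
import haiku as hk

def forward_fn(images):
    # 16 output channels, 3x3 kernels, stride 2 halves each spatial dim.
    net = hk.Conv2D(output_channels=16, kernel_shape=3,
                    stride=2, padding="SAME")
    return net(images)

forward = hk.transform(forward_fn)
images = jnp.ones([1, 32, 32, 3])  # NHWC, rank num_spatial_dims + 2
params = forward.init(jax.random.PRNGKey(0), images)
out = forward.apply(params, None, images)
print(out.shape)  # (1, 16, 16, 16)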
Example #6
    def __init__(self,
                 vocab_size=None,
                 embed_dim=None,
                 embedding_matrix=None,
                 w_init=None,
                 lookup_style=EmbedLookupStyle.ARRAY_INDEX.name,
                 name=None):
        """Constructs an Embed module.

    Args:
      vocab_size: int or None: the number of unique tokens to embed. If not
        provided, an existing vocabulary matrix from which vocab_size can be
        inferred must be provided as `embedding_matrix`.
      embed_dim: int or None. Number of dimensions to assign to each embedding.
        If an existing vocabulary matrix initializes the module, this should not
        be provided as it will be inferred.
      embedding_matrix: A matrix-like object equivalent in size to
        [vocab_size, embed_dim]. If given, it is used as the initial value for
        the embedding matrix and neither vocab_size nor embed_dim need be given.
        If they are given, their values are checked to be consistent with the
        dimensions of embedding_matrix.
      w_init: An initializer for the embeddings matrix. As a default,
        embeddings are initialized via a truncated normal distribution.
      lookup_style: One of the enum values of EmbedLookupStyle determining how
        to access the value of the embeddings given an ID. Regardless, the
        input should be a dense array of integer values representing ids. This
        setting changes how the module internally maps those ids to
        embeddings. The result is the same, but the speed and memory tradeoffs
        differ. It defaults to numpy-style array indexing. This value is only
        the default for the module; it can be overridden at any given
        invocation via the __call__ method.
      name: string. Name for this module.

    Raises:
      ValueError: If none of embed_dim, embedding_matrix and vocab_size are
        supplied, or if embedding_matrix is supplied and embed_dim or vocab_size
        is not consistent with the supplied matrix.
    """
        super(Embed, self).__init__(name=name)
        # Compare against None: truth-testing an array raises an error.
        if embedding_matrix is None and not (vocab_size and embed_dim):
            raise ValueError(
                "hk.Embed must be supplied either with an initial `embedding_matrix` "
                "or with `embed_dim` and `vocab_size`.")
        if embedding_matrix is not None:
            embedding_matrix = jnp.asarray(embedding_matrix)
            if vocab_size and embedding_matrix.shape[0] != vocab_size:
                raise ValueError(
                    "An `embedding_matrix` was supplied but the `vocab_size` of {vs} "
                    "was not consistent with its shape {emb_shape}.".format(
                        vs=vocab_size, emb_shape=embedding_matrix.shape))
            if embed_dim and embedding_matrix.shape[1] != embed_dim:
                raise ValueError(
                    "An `embedding_matrix` was supplied but the `embed_dim` of {ed} "
                    "was not consistent with its shape {emb_shape}.".format(
                        ed=embed_dim, emb_shape=embedding_matrix.shape))
            self._embedding = base.get_parameter(
                "embeddings",
                shape=embedding_matrix.shape,
                init=lambda _, __: embedding_matrix)
        else:
            w_init = w_init or hk_init.TruncatedNormal()
            self._embedding = base.get_parameter("embeddings",
                                                 shape=[vocab_size, embed_dim],
                                                 init=w_init)

        self._vocab_size = vocab_size or embedding_matrix.shape[0]
        self._embed_dim = embed_dim or embedding_matrix.shape[1]
        self._lookup_style = lookup_style
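A usage sketch for this constructor, assuming the public hk.Embed; with vocab_size and embed_dim given (and no embedding_matrix), the table is initialized from the default TruncatedNormal:

import jax
import jax.numpy as jnp
import haiku as hk

def forward_fn(ids):
    # Lookup maps each integer id to a row of the [100, 16] embedding table.
    return hk.Embed(vocab_size=100, embed_dim=16)(ids)

forward = hk.transform(forward_fn)
ids = jnp.array([[1, 5, 42], [7, 0, 99]])  # dense integer ids
params = forward.init(jax.random.PRNGKey(0), ids)
out = forward.apply(params, None, ids)
print(out.shape)  # (2, 3, 16): ids.shape + (embed_dim,)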