Exemplo n.º 1
0
  def __call__(self, x):
    """Residual multi-head self-attention over all H*W positions.

    Args:
      x: 4-D array unpacked as (B, H, W, C); C must be divisible by
        ``self.num_heads``.

    Returns:
      ``x`` plus the projected attention output; same shape as ``x``.
    """
    B, H, W, C = x.shape  # pylint: disable=invalid-name,unused-variable
    assert C % self.num_heads == 0
    n_heads = self.num_heads
    per_head = C // n_heads
    seq_len = H * W
    head_shape = (n_heads, per_head)

    normed = Normalize(name='norm')(x)
    assert normed.shape == (B, H, W, C)
    # Flatten the spatial grid into a single sequence axis.
    tokens = normed.reshape(B, seq_len, C)

    # Build the q, k and v projections (created in that order).
    query, key, value = [
        nn.DenseGeneral(features=head_shape, name=proj_name)(tokens)
        for proj_name in ('q', 'k', 'v')
    ]
    assert query.shape == key.shape == value.shape == (B, seq_len) + head_shape

    attended = nn.dot_product_attention(query=query, key=key, value=value)
    assert attended.shape == (B, seq_len) + head_shape

    # Merge the (heads, head_dim) axes back into C channels. The kernel is
    # zero-initialised, so at init the residual branch contributes only the
    # projection bias.
    projected = nn.DenseGeneral(
        features=C,
        axis=(-2, -1),
        kernel_init=nn.initializers.zeros,
        name='proj_out')(attended)
    assert projected.shape == (B, seq_len, C)
    residual = projected.reshape(B, H, W, C)
    assert residual.shape == x.shape
    return x + residual
Exemplo n.º 2
0
    def __call__(self, language, vision, hidden):
        """Multi-head cross-attention between language and vision features.

        Args:
            language: features used as keys, and (concatenated with
                ``hidden`` on the last axis) as the query input.
            vision: features used as values.
            hidden: extra features appended to ``language`` for the queries.

        Returns:
            Attention output projected to ``self.out_features`` over the
            (heads, head_dim) axes.
        """
        head_shape = (self.num_heads, self.head_dim)
        query_input = jnp.concatenate([language, hidden], axis=-1)

        q = nn.DenseGeneral(features=head_shape, name="query")(query_input)
        k = nn.DenseGeneral(features=head_shape, name="key")(language)
        # NOTE(review): keys come from `language` but values from `vision`;
        # they must share the same key-sequence length for the attention to
        # work — confirm callers guarantee this.
        v = nn.DenseGeneral(features=head_shape, name="memory_value")(vision)

        attended = nn.dot_product_attention(q, k, v)

        projection = nn.DenseGeneral(features=self.out_features,
                                     axis=(-2, -1),
                                     name="out")
        return projection(attended)
Exemplo n.º 3
0
 def __call__(self,
              inputs_q: Array,
              inputs_kv: Array,
              mask: Optional[Array] = None):
     """Multi-head attention from ``inputs_q`` onto ``inputs_kv``.

     Args:
       inputs_q: query inputs, projected by ``self.dense(name="query")``.
       inputs_kv: key/value inputs, projected by the "key" and "value" layers.
       mask: optional array; entries with ``mask > 0`` get a 0 additive bias,
         all others get -1e10 (cast to ``self.dtype``).

     Returns:
       Attention result projected to ``self.out_features`` over the last two
       (heads, head_dim) axes.
     """
     q = self.dense(name="query")(inputs_q)
     k = self.dense(name="key")(inputs_kv)
     v = self.dense(name="value")(inputs_kv)

     bias = None
     if mask is not None:
         keep = mask > 0
         allowed = jnp.full(mask.shape, 0).astype(self.dtype)
         # NOTE(review): -1e10 exceeds the float16 range and becomes -inf
         # after the cast; a fully masked row would then softmax to NaN.
         blocked = jnp.full(mask.shape, -1e10).astype(self.dtype)
         bias = lax.select(keep, allowed, blocked)

     use_dropout = (not self.deterministic) and self.dropout_rate > 0
     rng = self.make_rng("dropout") if use_dropout else None

     attended = nn.attention.dot_product_attention(
         q,
         k,
         v,
         bias=bias,
         dropout_rng=rng,
         dropout_rate=self.dropout_rate,
         deterministic=self.deterministic,
         dtype=self.dtype)
     out_proj = nn.DenseGeneral(features=self.out_features,
                                axis=(-2, -1),
                                dtype=self.dtype,
                                name="out")
     return out_proj(attended)
Exemplo n.º 4
0
 def setup(self):
     """Create the hidden Dense stack, the output head and an optional prior."""
     stack = []
     for width in self.hidden_features:
         stack.append(nn.Dense(width))
     self.hidden_layers = stack
     # Output head emits a (Z_dim, S_dim) block of features.
     self.output_layer = nn.DenseGeneral((self.Z_dim, self.S_dim))
     if not self.learn_prior:
         return
     # Learnable vector of length Z_dim, zero-initialised.
     self.prior = self.param("prior", nn.zeros, (self.Z_dim, ))
Exemplo n.º 5
0
    def test_dense_general_batch_dim(self):
        """Checks DenseGeneral with ``batch_dims=0`` against hand-computed values.

        The kernel initializer returns a constant equal to how many times it
        has been called. With 5 * 3 = 15 contracted ones and a bias of one, the
        expected values 15 * 1 + 1 = 16 and 15 * 2 + 1 = 31 assume it is
        invoked once per batch entry.
        """
        init_rngs = dict(params=random.PRNGKey(0))
        inputs = jnp.ones((2, 1, 3, 5))
        calls = {'counter': 0.}

        def _counter_init(rng, shape, dtype, state):
            del rng, dtype
            state['counter'] += 1.
            return jnp.full(shape, state['counter'])

        module = nn.DenseGeneral(
            features=7,
            axis=(3, -2),
            batch_dims=0,
            bias_init=initializers.ones,
            kernel_init=functools.partial(_counter_init, state=calls),
        )
        outputs, _ = module.init_with_output(init_rngs, inputs)
        expected = np.concatenate(
            [np.full((1, 1, 7), value) for value in (16., 31.)], axis=0)
        np.testing.assert_allclose(outputs, expected)
Exemplo n.º 6
0
    def __call__(self,
                 input_ids,
                 type_ids,
                 masked_lm_positions,
                 masked_lm_labels,
                 masked_lm_weights,
                 next_sentence_labels,
                 deterministic=False):
        """Applies pre-training model on inputs.

        Runs the encoder, then a masked-LM head and a next-sentence head, and
        scores both with ``_compute_pretraining_metrics``.

        Args:
          input_ids: <int>[BATCH_SIZE, MAX_SEQ_LENGTH] tokenized inputs.
          type_ids: <int>[BATCH_SIZE, MAX_SEQ_LENGTH] Ids partitioning input
            into different types.
          masked_lm_positions: <int>[BATCH_SIZE, MAX_PREDICTIONS_PER_SEQ]
            indices indicating which inputs are masked.
          masked_lm_labels: <int>[BATCH_SIZE, MAX_PREDICTIONS_PER_SEQ] true
            labels for masked inputs.
          masked_lm_weights: <float>[BATCH_SIZE, MAX_PREDICTIONS_PER_SEQ]
            relative weighting for masked inputs.
          next_sentence_labels: <int>[BATCH_SIZE] Labels for next sentence
            prediction task.
          deterministic: Whether to apply dropout to input.

        Returns:
          Loss and metrics for given inputs.
        """
        cfg = self.config
        encoded = EncoderModel(cfg, name="encoder")(
            input_ids, type_ids, deterministic=deterministic)

        # Masked-LM head: gather the masked positions, then
        # dense -> gelu -> layer norm -> projection with the embedding table.
        mlm_hidden = layers.gather(encoded.sequence_output,
                                   masked_lm_positions)
        mlm_hidden = nn.DenseGeneral(
            cfg.d_emb,
            use_bias=True,
            dtype=cfg.dtype,
            kernel_init=default_kernel_init,
            name="predictions_dense")(mlm_hidden)
        mlm_hidden = nn.gelu(mlm_hidden)
        mlm_hidden = nn.LayerNorm(
            epsilon=LAYER_NORM_EPSILON,
            dtype=cfg.dtype,
            name="predictions_layer_norm")(mlm_hidden)
        embedding_table = self._get_embedding_table()
        mlm_logits = layers.OutputProjection(
            kernel=embedding_table,
            name="predictions_output")(mlm_hidden)

        # Next-sentence head: two-way projection of the pooled output.
        nsp_logits = layers.OutputProjection(
            n_out=2, kernel_init=default_kernel_init,
            name="classification")(encoded.pooled_output)

        return _compute_pretraining_metrics(mlm_logits,
                                            nsp_logits,
                                            masked_lm_labels,
                                            masked_lm_weights,
                                            next_sentence_labels)
Exemplo n.º 7
0
 def test_dense_general_two_out(self):
   """A tuple of features gives a multi-axis output block."""
   inputs = jnp.ones((1, 3))
   layer = nn.DenseGeneral(
       features=(2, 2),
       kernel_init=initializers.ones,
       bias_init=initializers.ones,
   )
   outputs, _ = layer.init_with_output(dict(params=random.PRNGKey(0)), inputs)
   # Three ones times a ones kernel, plus a bias of one, gives 4 everywhere.
   expected = np.full((1, 2, 2), 4.)
   np.testing.assert_allclose(outputs, expected)
Exemplo n.º 8
0
    def __call__(self, x):
        """Residual self-attention along rows, columns, or the full H x W grid.

        Args:
            x: 4-D array unpacked as (B, H, W, C); C must be divisible by
                ``self.num_heads``.

        Returns:
            ``x`` plus the projected attention output (same shape as ``x``).

        Raises:
            ValueError: if ``self.mode`` is not 'row', 'column' or 'full'.
        """
        B, H, W, C = x.shape  # pylint: disable=invalid-name,unused-variable
        assert C % self.num_heads == 0

        # Axes of the (B, H, W, ...) tensor that attention mixes over.
        if self.mode == 'row':
            axis = (2, )  # Select only width axis.
        elif self.mode == 'column':
            axis = (1, )  # Select only height axis.
        elif self.mode == 'full':
            axis = (1, 2)  # Select both axes.
        else:
            raise ValueError(
                f"Unsupported attention mode {self.mode!r}; expected 'row', "
                "'column' or 'full'.")

        h = Normalize(name='norm')(x)
        if self.num_heads == 1:
            q = nn.Dense(features=C, name='q')(h)
            k = nn.Dense(features=C, name='k')(h)
            v = nn.Dense(features=C, name='v')(h)
            # Insert a singleton head axis for the attention helper, then
            # drop it again.
            h = unet_utils.dot_product_attention(q[:, :, :, None, :],
                                                 k[:, :, :, None, :],
                                                 v[:, :, :, None, :],
                                                 axis=axis)[:, :, :, 0, :]
            # Zero-initialised kernel keeps the residual branch small at init.
            h = nn.Dense(features=C,
                         kernel_init=nn.initializers.zeros,
                         name='proj_out')(h)
        else:
            head_dim = C // self.num_heads
            q = nn.DenseGeneral(features=(self.num_heads, head_dim),
                                name='q')(h)
            k = nn.DenseGeneral(features=(self.num_heads, head_dim),
                                name='k')(h)
            v = nn.DenseGeneral(features=(self.num_heads, head_dim),
                                name='v')(h)
            assert q.shape == k.shape == v.shape == (B, H, W, self.num_heads,
                                                     head_dim)
            h = unet_utils.dot_product_attention(q, k, v, axis=axis)
            # Merge (heads, head_dim) back into C channels.
            h = nn.DenseGeneral(features=C,
                                axis=(-2, -1),
                                kernel_init=nn.initializers.zeros,
                                name='proj_out')(h)

        assert h.shape == x.shape
        return x + h
Exemplo n.º 9
0
 def test_dense_general_batch_dim_raises(self):
   """batch_dims=(0, 2), which skips axis 1, is expected to raise ValueError."""
   inputs = jnp.ones((1, 3, 2, 5))
   init_rngs = dict(params=random.PRNGKey(0))
   with self.assertRaises(ValueError):
     nn.DenseGeneral(
         features=4,
         batch_dims=(0, 2),
         kernel_init=initializers.ones,
         bias_init=initializers.ones,
     ).init_with_output(init_rngs, inputs)
Exemplo n.º 10
0
    def __call__(self, inputs, is_training: bool):
        """Self-attention over a token sequence with convolutional Q/K/V projections.

        The ``l`` tokens are zero-padded up to the next perfect square and
        reshaped into a square grid, so that ``ConvProjectionBlock`` can compute
        queries, keys and values with per-projection strides.

        Args:
            inputs: 3-D array unpacked as (b, l, c).
            is_training: forwarded to every ``ConvProjectionBlock`` call.

        Returns:
            Array of shape (b, n_q, out_ch). n_q is the number of query grid
            positions produced by the query projection. ``out_ch`` is
            ``self.out_ch``, or ``c`` when ``self.out_ch`` is None.
        """
        import math

        assert len(self.strides) == 3
        assert inputs.ndim == 3
        q_strides, k_strides, v_strides = self.strides
        _, l, c = inputs.shape
        # Output width; falls back to the input width when unset. Used below
        # instead of self.out_ch so that out_ch=None is handled.
        out_ch = self.out_ch if self.out_ch is not None else c

        # Side of the smallest square grid holding all l tokens, i.e.
        # ceil(sqrt(l)), in exact integer arithmetic. A jnp-based version is
        # traced under jit (int() on a tracer fails) and is subject to float
        # rounding.
        spatial_ch = math.isqrt(l)
        if spatial_ch * spatial_ch < l:
            spatial_ch += 1
        inputs = jnp.pad(inputs, ((0, 0), (0, spatial_ch**2 - l), (0, 0)))
        inputs = rearrange(inputs, 'b (H W) c -> b H W c', W=spatial_ch)

        conv_proj = partial(ConvProjectionBlock,
                            out_ch=self.num_heads * self.head_ch,
                            kernel_size=self.kernel_size,
                            use_bias=self.use_bias,
                            bn_momentum=self.bn_momentum,
                            bn_epsilon=self.bn_epsilon,
                            dtype=self.dtype,
                            precision=self.precision,
                            kernel_init=self.kernel_init,
                            bias_init=self.bias_init)

        query = conv_proj(strides=q_strides)(inputs, is_training=is_training)
        key = conv_proj(strides=k_strides)(inputs, is_training=is_training)
        value = conv_proj(strides=v_strides)(inputs, is_training=is_training)

        # Flatten each grid into a sequence and split channels into heads.
        query = rearrange(query, 'b H W (h d) -> b (H W) h d', h=self.num_heads)
        key = rearrange(key, 'b H W (h d) -> b (H W) h d', h=self.num_heads)
        value = rearrange(value, 'b H W (h d) -> b (H W) h d', h=self.num_heads)

        query = query / jnp.sqrt(self.head_ch)

        attn_weights = jnp.einsum('... q h d, ... k h d -> ... h q k',
                                  query,
                                  key,
                                  precision=self.precision)

        attn_weights = nn.softmax(attn_weights)

        attn_scores = jnp.einsum('... h q k, ... k h d -> ... q h d',
                                 attn_weights,
                                 value,
                                 precision=self.precision)

        if (self.num_heads * self.head_ch) == out_ch:
            # Heads already add up to the requested width: just merge them.
            output = rearrange(attn_scores, '... q h d -> ... q (h d)')
        else:
            output = nn.DenseGeneral(features=out_ch,
                                     axis=(-2, -1),
                                     use_bias=self.use_bias,
                                     dtype=self.dtype,
                                     precision=self.precision,
                                     kernel_init=self.kernel_init,
                                     bias_init=self.bias_init)(attn_scores)
        return output
Exemplo n.º 11
0
  def test_dense_general_vs_numpy(self, axis, batch_dims, einsum_expr):
    """DenseGeneral must match an explicit numpy einsum plus a unit bias."""
    inputs = jnp.ones((16, 8, 9, 10))
    layer = nn.DenseGeneral(
        features=(11, 12),
        axis=axis,
        batch_dims=batch_dims,
        bias_init=initializers.ones,
        kernel_init=initializers.normal(),
    )
    outputs, variables = layer.init_with_output(
        dict(params=random.PRNGKey(0)), inputs)
    kernel = variables['params']['kernel']
    # bias_init is ones, so the reference is the contraction plus 1.
    expected = np.einsum(einsum_expr, inputs, kernel) + 1.
    np.testing.assert_allclose(outputs, expected, atol=1e-6)
Exemplo n.º 12
0
  def __call__(self, x):
    """Residual multi-head self-attention with a configurable head layout.

    Exactly one of ``self.num_heads`` / ``self.head_dim`` must be set; the
    other is derived from the channel count C.

    Args:
      x: 4-D array unpacked as (B, H, W, C).

    Returns:
      ``x`` plus the projected attention output; same shape as ``x``.
    """
    B, H, W, C = x.shape  # pylint: disable=invalid-name,unused-variable

    if self.head_dim is not None:
      assert self.num_heads is None
      assert C % self.head_dim == 0
      head_dim = self.head_dim
      num_heads = C // head_dim
    else:
      assert self.num_heads is not None
      assert C % self.num_heads == 0
      num_heads = self.num_heads
      head_dim = C // num_heads
    head_shape = (num_heads, head_dim)
    seq_len = H * W

    normed = Normalize(name='norm')(x)
    assert normed.shape == (B, H, W, C)
    # Flatten the spatial grid into one sequence axis.
    seq = normed.reshape(B, seq_len, C)

    projections = {}
    for proj in ('q', 'k', 'v'):
      projections[proj] = nn.DenseGeneral(features=head_shape, name=proj)(seq)
    q, k, v = projections['q'], projections['k'], projections['v']
    assert q.shape == k.shape == v.shape == (B, seq_len) + head_shape

    mixed = nn.dot_product_attention(query=q, key=k, value=v)
    assert mixed.shape == (B, seq_len) + head_shape

    # Zero-initialised kernel: the residual branch starts near zero.
    merged = nn.DenseGeneral(
        features=C,
        axis=(-2, -1),
        kernel_init=nn.initializers.zeros,
        name='proj_out')(mixed)
    assert merged.shape == (B, seq_len, C)
    out = merged.reshape(B, H, W, C)
    assert out.shape == x.shape
    logging.info(
        '%s: x=%r num_heads=%d head_dim=%d',
        self.name, x.shape, num_heads, head_dim)
    return x + out
Exemplo n.º 13
0
 def __call__(self, x):
     """Gate one half of the channels by a projection of the other half.

     ``x`` is split in two along the last axis into ``u`` and ``v``. ``v`` is
     layer-normalised and permuted so that its middle axes (presumably
     spatial) come last. A DenseGeneral then mixes across those axes, the
     permutation is undone, and the result multiplies ``u`` elementwise.
     """
     u, v = np.split(x, 2, axis=-1)
     v = nn.normalization.LayerNorm()(v)
     rank = v.ndim
     middle_shape = v.shape[1:-1]
     middle_axes = np.arange(rank)[1:-1]
     # Axis 0 stays first, the last axis moves to position 1, and the middle
     # axes become trailing so DenseGeneral can contract over them.
     source_order = [0, rank - 1, *middle_axes]
     v = np.moveaxis(v, source_order, np.arange(rank))
     mixer = nn.DenseGeneral(
         features=middle_shape,
         axis=tuple(np.arange(rank)[2:]),
         kernel_init=nn.initializers.variance_scaling(
             0.1, 'fan_in',
             'truncated_normal'),  # near-zero projection matrix
         bias_init=nn.initializers.ones)
     v = mixer(v)
     # Undo the permutation applied above.
     v = np.moveaxis(v, np.arange(rank), source_order)
     return u * v
Exemplo n.º 14
0
  def test_dense_is_dense_general(self):
    """Dense and DenseGeneral with equal config and RNG give equal outputs."""
    inputs = jax.random.normal(random.PRNGKey(0), (5, 3))

    def _run(module_cls):
      module = module_cls(
          features=4,
          use_bias=True,
          bias_init=initializers.normal(),
      )
      outputs, _ = module.init_with_output(
          dict(params=random.PRNGKey(1)), inputs)
      return outputs

    np.testing.assert_allclose(_run(nn.Dense), _run(nn.DenseGeneral))
    def __call__(self, inputs_q, inputs_kv, is_training: bool):
        """Multi-head cross-attention from ``inputs_q`` onto ``inputs_kv``.

        Args:
            inputs_q: 3-D query-side array; its last axis is the channel axis.
            inputs_kv: 3-D key/value-side array.
            is_training: when True, both dropout layers are active.

        Returns:
            Array with ``out_ch`` channels (``self.out_ch`` or the query
            channel count), after output dropout.
        """
        assert inputs_q.ndim == inputs_kv.ndim == 3

        in_ch = inputs_q.shape[-1]
        assert in_ch % self.num_heads == 0
        # Divisibility is asserted above, so floor division is exact.
        head_ch = self.head_ch or in_ch // self.num_heads
        out_ch = self.out_ch or in_ch
        deterministic = not is_training

        def project(name, source):
            # Per-head projection: last axis -> (num_heads, head_ch).
            layer = nn.DenseGeneral(axis=-1,
                                    features=(self.num_heads, head_ch),
                                    use_bias=self.use_bias,
                                    dtype=self.dtype,
                                    name=name)
            return layer(source)

        q = project('queries', inputs_q)
        k = project('keys', inputs_kv)
        v = project('values', inputs_kv)
        q = q / jnp.sqrt(head_ch)

        logits = jnp.einsum('... q h d, ... k h d -> ... h q k', q, k)
        if self.talking_heads:
            logits = TalkingHeadsBlock(num_heads=self.num_heads)(logits)

        weights = nn.softmax(logits)
        if self.talking_heads:
            weights = TalkingHeadsBlock(num_heads=self.num_heads)(weights)

        weights = nn.Dropout(rate=self.attn_dropout_rate)(
            weights, deterministic=deterministic)

        mixed = jnp.einsum('... h q k, ... k h d -> ... q h d', weights, v)

        merged = nn.DenseGeneral(features=out_ch,
                                 axis=(-2, -1),
                                 use_bias=self.use_bias,
                                 dtype=self.dtype)(mixed)
        return nn.Dropout(rate=self.out_dropout_rate)(
            merged, deterministic=deterministic)
Exemplo n.º 16
0
    def __call__(self, inputs):
        """Self-attention whose weights are re-mixed before being applied.

        Queries, keys and values come from three bias-free DenseGeneral
        projections of ``inputs``. The softmax attention weights then pass
        through a ThetaTransform and a LayerNorm before being applied to the
        values. The ThetaTransform is the shared ``self.theta_transform`` when
        ``cfg.shared_theta`` is set, otherwise a new ``ThetaTransform``.

        Args:
            inputs: 3-D array, indexed as (b, n, emb) by the einsums below.

        Returns:
            (b, n, cfg.emb_dim) array. Heads are merged directly when
            ``num_heads * dim_head == emb_dim``; otherwise a DenseGeneral
            projection is used.
        """
        cfg = self.config
        assert inputs.ndim == 3

        dense = partial(nn.DenseGeneral,
                        axis=-1,
                        features=(cfg.num_heads, cfg.dim_head),
                        use_bias=False,
                        kernel_init=cfg.kernel_init,
                        precision=cfg.precision)

        query, key, value = (dense(dtype=cfg.dtype)(inputs),
                             dense(dtype=cfg.dtype)(inputs),
                             dense(dtype=cfg.dtype)(inputs))

        query = query / jnp.sqrt(cfg.dim_head).astype(cfg.dtype)

        attn_weights = jnp.einsum('b q h d, b k h d -> b h q k',
                                  query,
                                  key,
                                  precision=cfg.precision)
        # Softmax over the last (key) axis: each query's weights sum to one.
        attn_weights = nn.softmax(attn_weights).astype(cfg.dtype)

        if cfg.shared_theta:
            attn_weights = self.theta_transform(attn_weights)
        else:
            attn_weights = ThetaTransform(config=cfg)(attn_weights)

        attn_weights = nn.LayerNorm()(attn_weights)

        # Weighted sum of the values over the key axis k, giving one output
        # per query position. Contracting over q instead would apply the
        # transposed attention matrix, whose columns were never normalised.
        out = jnp.einsum('b h q k, b k h d -> b q h d',
                         attn_weights,
                         value,
                         precision=cfg.precision)

        if (cfg.num_heads * cfg.dim_head) != cfg.emb_dim:
            out = nn.DenseGeneral(features=cfg.emb_dim,
                                  axis=(-2, -1),
                                  dtype=cfg.dtype,
                                  precision=cfg.precision,
                                  kernel_init=cfg.kernel_init,
                                  bias_init=cfg.bias_init)(out)
        else:
            out = rearrange(out, 'b q h d -> b q (h d)')

        return out
Exemplo n.º 17
0
    def setup(self):
        """Setup Model.

        Builds the light-field object, ray projector, three self-attention
        transformers, correspondence and key/query projection layers, the
        split-conv feature extractor, and the per-channel mean/std constants.
        """
        precision = self.epipolar_config.precision
        epi_tf_cfg = self.epipolar_transformer_config
        view_tf_cfg = self.view_transformer_config

        # Light field object and ray projector.
        self.lightfield = lf_utils.get_lightfield_obj(self.lf_config)
        self.projector = projector.RayProjector(self.epipolar_config)

        # Transformers. The misspelled attribute `vision_transfromer` is kept
        # as-is: renaming it would change the submodule name.
        self.vision_transfromer = transformer.SelfAttentionTransformer(
            epi_tf_cfg)
        self.epipolar_transformer = transformer.SelfAttentionTransformer(
            epi_tf_cfg)
        self.view_transformer = transformer.SelfAttentionTransformer(
            view_tf_cfg)

        # Single-output layers predicting an attention weight for each point
        # on the epipolar line (and across views).
        self.epipolar_correspondence = nn.DenseGeneral(1)
        self.view_correspondence = nn.DenseGeneral(1)

        # Color prediction (optional).
        self.rgb_dense = nn.Dense(self.render_config.num_rgb_channels,
                                  precision=precision)

        # Project keys and queries to the same size so they can be
        # concatenated.
        self.key_transform = nn.DenseGeneral(epi_tf_cfg.qkv_params,
                                             precision=precision)
        self.query_transform = nn.DenseGeneral(epi_tf_cfg.qkv_params,
                                               precision=precision)
        # NOTE(review): key_transform2 is built without precision=precision
        # while query_transform2 gets it — confirm this is intentional.
        self.key_transform2 = nn.DenseGeneral(view_tf_cfg.qkv_params)
        self.query_transform2 = nn.DenseGeneral(view_tf_cfg.qkv_params,
                                                precision=precision)

        self.conv_layer = efficient_conv.SplitConvModel(
            features=self.epipolar_config.conv_feature_dim,
            kernel_size=self.epipolar_config.patch_size,
        )
        self.feature_activation = nn.elu

        # Per-channel constants reshaped by einshape "x->111x" (presumably
        # image normalisation statistics — confirm).
        self.mean = einshape("x->111x", jnp.array([0.485, 0.456, 0.406]))
        self.std = einshape("x->111x", jnp.array([0.229, 0.224, 0.225]))
Exemplo n.º 18
0
  def __call__(self, token_inputs,
               num_experts):
    """Applies RouterWeights module.

    A single DenseGeneral maps each token's hidden vector to one logit per
    expert.

    Args:
      token_inputs: Flattened batch of tokens with shape <float>[NUM_GROUPS,
        TOKENS_PER_GROUP, HIDDEN_DIM].
      num_experts: Number of experts.

    Returns:
      Router logits with shape <float>[NUM_GROUPS, TOKENS_PER_GROUP,
      NUM_EXPERTS].
    """
    layer_options = dict(
        use_bias=self.use_bias,
        dtype=self.dtype,
        kernel_init=self.kernel_init,
        bias_init=self.bias_init,
        precision=self.precision,
    )
    router = nn.DenseGeneral(num_experts, **layer_options)
    logits = router(token_inputs)
    return logits
Exemplo n.º 19
0
    def setup(self):
        """Setup Model.

        Builds the light-field object, ray projector, epipolar and view
        transformers, and the correspondence and key/query projection layers.
        Depending on ``epipolar_config`` flags, it also adds a learned camera
        embedding and a split-conv feature extractor. Finally it sets the
        background fill value.
        """
        epi_cfg = self.epipolar_config
        precision = epi_cfg.precision
        epi_tf_cfg = self.epipolar_transformer_config
        view_tf_cfg = self.view_transformer_config

        # Light field object and ray projector.
        self.lightfield = lf_utils.get_lightfield_obj(self.lf_config)
        self.projector = projector.RayProjector(epi_cfg)

        # Transformers.
        self.epipolar_transformer = transformer.SelfAttentionTransformer(
            epi_tf_cfg)
        self.view_transformer = transformer.SelfAttentionTransformer(
            view_tf_cfg)

        # Single-output layers predicting an attention weight for each point
        # on the epipolar line (and across views).
        self.epipolar_correspondence = nn.DenseGeneral(1)
        self.view_correspondence = nn.DenseGeneral(1)

        self.rgb_dense = nn.Dense(self.render_config.num_rgb_channels,
                                  precision=precision)

        # Project keys and queries to the same size so they can be
        # concatenated.
        self.key_transform = nn.DenseGeneral(epi_tf_cfg.qkv_params,
                                             precision=precision)
        self.query_transform = nn.DenseGeneral(epi_tf_cfg.qkv_params,
                                               precision=precision)
        # NOTE(review): key_transform2 is built without precision=precision
        # while query_transform2 gets it — confirm this is intentional.
        self.key_transform2 = nn.DenseGeneral(view_tf_cfg.qkv_params)
        self.query_transform2 = nn.DenseGeneral(view_tf_cfg.qkv_params,
                                                precision=precision)

        # Optionally, a learned embedding per camera view.
        if epi_cfg.use_learned_embedding:
            self.camera_embedding = transformer.LearnedPositionEmbs(
                max_length=epi_cfg.num_train_views, )

        # feature_activation only exists when conv features are enabled.
        if epi_cfg.use_conv_features:
            self.conv_layer1 = efficient_conv.SplitConvModel(
                features=epi_cfg.conv_feature_dim,
                kernel_size=epi_cfg.ksize1,
            )
            self.feature_activation = nn.elu

        # Background rays are filled with 1. when white_bkgd is set, else 0.
        self.fill_value = 1. if self.render_config.white_bkgd else 0.
Exemplo n.º 20
0
    def setup(self):
        """Initializes encoder with config-dependent mixing layer.

        For each layer index this builds either an attention sublayer or a
        mixing sublayer (chosen by ``self._is_attention_layer``), plus a
        feed-forward sublayer, and wraps them in an ``EncoderBlock``. It also
        creates the embedder and the pooler projection.
        """
        cfg = self.config

        def _make_block(idx):
            # Exactly one of attention/mixing is set for each block.
            if self._is_attention_layer(idx):
                attention = layers.AttentionLayer(
                    num_heads=cfg.num_heads,
                    d_model=cfg.d_model,
                    dtype=cfg.dtype,
                    kernel_init=default_kernel_init,
                    bias_init=default_bias_init,
                    dropout_rate=cfg.mixing_dropout_rate,
                    pad_id=cfg.pad_id,
                    name=f"attention_{idx}")
                mixing = None
            else:
                attention = None
                mixing = self._init_mixing_sublayer(idx)

            feed_forward = self._init_feed_forward_sublayer(idx)

            return layers.EncoderBlock(
                mixing_sublayer=mixing,
                attention_sublayer=attention,
                feed_forward_sublayer=feed_forward,
                name=f"encoder_{idx}")

        # Module attributes are immutable, so build the whole list first and
        # assign it once.
        self.encoder_blocks = [
            _make_block(idx) for idx in range(cfg.num_layers)
        ]

        self.embedder = layers.EmbeddingLayer(config=cfg, name="embedder")

        self.pooler = nn.DenseGeneral(cfg.d_model,
                                      use_bias=True,
                                      dtype=cfg.dtype,
                                      kernel_init=default_kernel_init,
                                      name="pooler")
    def __call__(self, inputs, is_training: bool):
        """Multi-head self-attention with optional talking heads and dropout.

        When ``self.is_lca`` is set, only the last position along axis 1 of
        ``inputs`` issues a query (query length 1); otherwise every position
        does. Keys and values always use all of ``inputs``.

        Args:
            inputs: array with at least 3 axes; axis 1 is the token axis.
            is_training: together with ``self.dropout_rate > 0``, enables
                dropout on the attention weights.

        Returns:
            Attention output with heads merged into ``self.out_ch`` channels.
            The heads are reshaped when ``num_heads * head_ch == out_ch``, and
            projected with a DenseGeneral otherwise.
        """
        def project(name, source):
            # Per-head projection: last axis -> (num_heads, head_ch).
            layer = nn.DenseGeneral(axis=-1,
                                    features=(self.num_heads, self.head_ch),
                                    use_bias=self.use_bias,
                                    dtype=self.dtype,
                                    precision=self.precision,
                                    kernel_init=self.kernel_init,
                                    bias_init=self.bias_init,
                                    name=name)
            return layer(source)

        q_source = inputs[:, -1:, :] if self.is_lca else inputs

        query = project('queries', q_source)
        key = project('keys', inputs)
        value = project('values', inputs)
        query = query / jnp.sqrt(self.head_ch)

        head_mix_shape = (self.num_heads, self.num_heads)

        logits = jnp.einsum('... q h d, ... k h d -> ... h q k',
                            query,
                            key,
                            precision=self.precision)
        if self.talking_heads:
            # Learned head-mixing matrix applied before the softmax.
            mix_in = self.param('pre_softmax', self.kernel_init,
                                head_mix_shape)
            logits = jnp.einsum('... h q k, h i -> ... i q k',
                                logits,
                                mix_in,
                                precision=self.precision)

        weights = nn.softmax(logits)

        if self.talking_heads:
            # Second head-mixing matrix applied after the softmax.
            mix_out = self.param('post_softmax', self.kernel_init,
                                 head_mix_shape)
            weights = jnp.einsum('... i q k, i h -> ... h q k',
                                 weights,
                                 mix_out,
                                 precision=self.precision)

        if is_training and self.dropout_rate > 0.:
            keep_prob = 1.0 - self.dropout_rate
            keep_mask = random.bernoulli(self.make_rng('dropout'), keep_prob,
                                         weights.shape)
            # Inverted dropout: surviving weights are rescaled by 1/keep_prob.
            scale = (keep_mask.astype(weights.dtype) /
                     jnp.asarray(keep_prob, dtype=self.dtype))
            weights = weights * scale

        mixed = jnp.einsum('... h q k, ... k h d -> ... q h d',
                           weights,
                           value,
                           precision=self.precision)

        if (self.num_heads * self.head_ch) != self.out_ch:
            return nn.DenseGeneral(features=self.out_ch,
                                   axis=(-2, -1),
                                   use_bias=self.use_bias,
                                   dtype=self.dtype,
                                   precision=self.precision,
                                   kernel_init=self.kernel_init,
                                   bias_init=self.bias_init)(mixed)
        return rearrange(mixed, '... q h d -> ... q (h d)')
    def __call__(self, inputs_q, inputs_kv, is_training: bool):
        """Cross-attention with strided convolutional Q/K/V projections.

        Args:
            inputs_q: 4-D query-side feature map; last axis is channels.
            inputs_kv: 4-D key/value-side feature map.
            is_training: forwarded to the conv projections; when True, both
                dropout layers are active.

        Returns:
            Array of shape (b, n_q, out_ch) after output dropout. n_q is the
            number of query grid positions produced by the query projection,
            and ``out_ch`` is ``self.out_ch`` or the query channel count.
        """
        assert inputs_q.ndim == 4
        assert inputs_kv.ndim == 4
        assert len(self.strides) == 3
        q_strides, k_strides, v_strides = self.strides

        in_ch = inputs_q.shape[-1]
        assert in_ch % self.num_heads == 0
        # Divisibility is asserted above, so floor division is exact.
        head_ch = self.head_ch or in_ch // self.num_heads
        out_ch = self.out_ch or in_ch
        deterministic = not is_training

        def project(strides, source):
            block = ConvProjectionBlock(out_ch=self.num_heads * head_ch,
                                        kernel_size=self.kernel_size,
                                        use_bias=self.use_bias,
                                        bn_momentum=self.bn_momentum,
                                        bn_epsilon=self.bn_epsilon,
                                        dtype=self.dtype,
                                        strides=strides)
            grid = block(source, is_training=is_training)
            # Flatten the grid into a sequence and split channels into heads.
            return rearrange(grid,
                             'b H W (h d) -> b (H W) h d',
                             h=self.num_heads)

        q = project(q_strides, inputs_q)
        k = project(k_strides, inputs_kv)
        v = project(v_strides, inputs_kv)
        q = q / jnp.sqrt(head_ch)

        logits = jnp.einsum('... q h d, ... k h d -> ... h q k', q, k)
        if self.talking_heads:
            logits = TalkingHeadsBlock(num_heads=self.num_heads)(logits)

        weights = nn.softmax(logits)
        if self.talking_heads:
            weights = TalkingHeadsBlock(num_heads=self.num_heads)(weights)

        weights = nn.Dropout(rate=self.attn_dropout_rate)(
            weights, deterministic=deterministic)

        mixed = jnp.einsum('... h q k, ... k h d -> ... q h d', weights, v)

        merged = nn.DenseGeneral(features=out_ch,
                                 axis=(-2, -1),
                                 use_bias=self.use_bias,
                                 dtype=self.dtype)(mixed)
        return nn.Dropout(rate=self.out_dropout_rate)(
            merged, deterministic=deterministic)
    def __call__(self, inputs_q, inputs_kv, is_training: bool):
        """Multi-head attention over convolutionally projected q/k/v.

        Args:
          inputs_q: query input; its last axis is read as the channel count.
          inputs_kv: input from which keys and values are projected.
          is_training: forwarded to every projection block; when False both
            dropout layers run deterministically.

        Returns:
          The attention output after the output projection and dropout, with
          ``self.out_ch`` features (falling back to the channel count of
          ``inputs_q``).
        """
        # self.strides carries one entry each for the q, k and v projections.
        assert len(self.strides) == 3
        q_strides, k_strides, v_strides = self.strides

        channels = inputs_q.shape[-1]
        assert channels % self.num_heads == 0
        head_ch = self.head_ch or channels // self.num_heads
        out_ch = self.out_ch or channels

        padded_q = zero_pad_and_reshape(inputs_q)
        padded_kv = zero_pad_and_reshape(inputs_kv)

        make_projection = partial(ConvProjectionBlock,
                                  out_ch=self.num_heads * head_ch,
                                  kernel_size=self.kernel_size,
                                  use_bias=self.use_bias,
                                  bn_momentum=self.bn_momentum,
                                  bn_epsilon=self.bn_epsilon,
                                  dtype=self.dtype,
                                  precision=self.precision,
                                  kernel_init=self.kernel_init,
                                  bias_init=self.bias_init)

        # NOTE: the projections are unnamed submodules, so keep building them
        # in query, key, value order (their parameter names depend on it).
        heads = []
        for strides, source in ((q_strides, padded_q),
                                (k_strides, padded_kv),
                                (v_strides, padded_kv)):
            projected = make_projection(strides=strides)(
                source, is_training=is_training)
            # Flatten the spatial grid into a sequence and split channels
            # into (heads, head_ch).
            heads.append(
                rearrange(projected,
                          'b H W (h d) -> b (H W) h d',
                          h=self.num_heads))
        query, key, value = heads

        query = query / jnp.sqrt(head_ch)

        logits = jnp.einsum('... q h d, ... k h d -> ... h q k',
                            query,
                            key,
                            precision=self.precision)

        if self.talking_heads:
            # Learned (heads x heads) mixing applied before the softmax.
            pre_mix = self.param('pre_softmax', self.kernel_init,
                                 (self.num_heads, self.num_heads))
            logits = jnp.einsum('... h q k, h i -> ... i q k',
                                logits,
                                pre_mix,
                                precision=self.precision)

        probs = nn.softmax(logits)

        if self.talking_heads:
            # Learned (heads x heads) mixing applied after the softmax.
            post_mix = self.param('post_softmax', self.kernel_init,
                                  (self.num_heads, self.num_heads))
            probs = jnp.einsum('... i q k, i h -> ... h q k',
                               probs,
                               post_mix,
                               precision=self.precision)

        probs = nn.Dropout(rate=self.attn_dropout_rate)(
            probs, deterministic=not is_training)

        context = jnp.einsum('... h q k, ... k h d -> ... q h d',
                             probs,
                             value,
                             precision=self.precision)

        # Contract the trailing (heads, head_ch) axes into out_ch features.
        merged = nn.DenseGeneral(features=out_ch,
                                 axis=(-2, -1),
                                 use_bias=self.use_bias,
                                 dtype=self.dtype,
                                 precision=self.precision,
                                 kernel_init=self.kernel_init,
                                 bias_init=self.bias_init)(context)

        return nn.Dropout(rate=self.out_drop_rate)(
            merged, deterministic=not is_training)
    def __call__(self, inputs_q, inputs_kv, is_training: bool):
        """Multi-head attention with dense query/key/value projections.

        Args:
          inputs_q: rank-3 query input; its last axis is the channel count.
          inputs_kv: rank-3 input from which keys and values are projected.
          is_training: when False both dropout layers run deterministically.

        Returns:
          The attention output after the output projection and dropout, with
          ``self.out_ch`` features (falling back to the channel count of
          ``inputs_q``).
        """
        assert inputs_q.ndim == inputs_kv.ndim == 3

        channels = inputs_q.shape[-1]
        assert channels % self.num_heads == 0
        head_ch = self.head_ch or channels // self.num_heads
        out_ch = self.out_ch or channels

        # Each projection maps the channel axis to (num_heads, head_ch).
        project = partial(nn.DenseGeneral,
                          axis=-1,
                          features=(self.num_heads, head_ch),
                          use_bias=self.use_bias,
                          dtype=self.dtype,
                          precision=self.precision,
                          kernel_init=self.kernel_init,
                          bias_init=self.bias_init)

        projected = []
        for layer_name, source in (('queries', inputs_q),
                                   ('keys', inputs_kv),
                                   ('values', inputs_kv)):
            projected.append(project(name=layer_name)(source))
        query, key, value = projected

        query = query / jnp.sqrt(head_ch)

        logits = jnp.einsum('... q h d, ... k h d -> ... h q k',
                            query,
                            key,
                            precision=self.precision)
        if self.talking_heads:
            # Learned (heads x heads) mixing applied before the softmax.
            pre_mix = self.param('pre_softmax', self.kernel_init,
                                 (self.num_heads, self.num_heads))
            logits = jnp.einsum('... h q k, h i -> ... i q k',
                                logits,
                                pre_mix,
                                precision=self.precision)

        probs = nn.softmax(logits)

        if self.talking_heads:
            # Learned (heads x heads) mixing applied after the softmax.
            post_mix = self.param('post_softmax', self.kernel_init,
                                  (self.num_heads, self.num_heads))
            probs = jnp.einsum('... i q k, i h -> ... h q k',
                               probs,
                               post_mix,
                               precision=self.precision)

        probs = nn.Dropout(rate=self.attn_dropout_rate)(
            probs, deterministic=not is_training)

        context = jnp.einsum('... h q k, ... k h d -> ... q h d',
                             probs,
                             value,
                             precision=self.precision)

        # Contract the trailing (heads, head_ch) axes into out_ch features.
        merged = nn.DenseGeneral(features=out_ch,
                                 axis=(-2, -1),
                                 use_bias=self.use_bias,
                                 dtype=self.dtype,
                                 precision=self.precision,
                                 kernel_init=self.kernel_init,
                                 bias_init=self.bias_init)(context)

        return nn.Dropout(rate=self.out_drop_rate)(
            merged, deterministic=not is_training)
Exemplo n.º 25
0
    def exec_op(self, op, input_values, deterministic, training, **_):
        """Executes an op according to the normal concrete semantics.

        Args:
          op: the op to run; its ``type`` selects the branch below and its
            ``input_kwargs`` / ``op_kwargs`` parameterize it. ``op_kwargs``
            must contain a ``"name"`` entry.
          input_values: list of input values consumed by the op.
          deterministic: forwarded to dropout and self-attention, and disables
            stochastic depth when true.
          training: for batch norm without an explicit
            ``use_running_average``, running averages are used iff not
            training.
          **_: ignored.

        Returns:
          A single-element list holding the op's output.

        Raises:
          ValueError: if ``op_kwargs`` lacks ``"name"`` or the op type is not
            supported.
        """
        # Work on shallow copies: several branches below pop or overwrite
        # entries (RESHAPE pops "new_shape", PARAM pops "init_fn", CONV and
        # pooling rewrite kernel/window/stride entries). Mutating the op's
        # own dicts would make a second execution of the same op fail with
        # KeyError or reuse values resolved for a previous input shape.
        input_kwargs: Dict[str, Any] = dict(op.input_kwargs or {})
        op_kwargs: Dict[str, Any] = dict(op.op_kwargs)
        op_type = op.type
        if "name" not in op_kwargs:
            raise ValueError("Op kwargs must contain a name.")
        op_name = op_kwargs["name"]

        if op_type == OpType.NONE:
            # Passes the input through but blocks gradients.
            assert len(input_values) == 1
            input_value = input_values[0]
            assert not input_kwargs
            assert len(op_kwargs) == 1
            output_values = [lax.stop_gradient(input_value)]

        elif op_type == OpType.IDENTITY:
            assert len(input_values) == 1
            input_value = input_values[0]
            assert not input_kwargs
            assert len(op_kwargs) == 1
            output_values = [input_value]

        # nn.linear

        elif op_type == OpType.DENSE:
            assert len(input_values) == 1
            input_value = input_values[0]
            assert not input_kwargs
            output_values = [nn.Dense(**op_kwargs)(input_value)]

        elif op_type == OpType.DENSE_GENERAL:
            assert len(input_values) == 1
            input_value = input_values[0]
            assert not input_kwargs
            assert 2 <= len(op_kwargs) <= 7
            output_values = [nn.DenseGeneral(**op_kwargs)(input_value)]

        elif op_type == OpType.CONV:
            assert len(input_values) == 1
            input_value = input_values[0]
            assert not input_kwargs

            # An int kernel size is expanded to one entry per non-batch,
            # non-channel dimension.
            ks = op_kwargs["kernel_size"]
            if isinstance(ks, int):
                op_kwargs["kernel_size"] = (ks, ) * (input_value.ndim - 2)

            output_values = [nn.Conv(**op_kwargs)(input_value)]

        # others

        elif op_type == OpType.MUL:
            assert len(input_values) == 2
            assert not input_kwargs
            assert len(op_kwargs) == 1  # name
            output_values = [input_values[0] * input_values[1]]

        elif op_type in [OpType.ADD, OpType.STOCH_DEPTH]:
            assert len(op_kwargs) == 1  # name

            input_value = input_values[0]
            if "layer_drop_rate" in input_kwargs:
                assert len(input_kwargs) == 1
                survival_rate = 1 - input_kwargs["layer_drop_rate"]
                if survival_rate == 1.0 or deterministic:
                    pass
                else:
                    # Stochastic depth: drop the whole input per example
                    # (mask varies only along axis 0) and rescale survivors
                    # by 1 / survival_rate. Reuse dropout's rng stream.
                    rng = self.make_rng("dropout")
                    mask_shape = [input_value.shape[0]
                                  ] + [1] * (input_value.ndim - 1)
                    mask = random.bernoulli(rng,
                                            p=survival_rate,
                                            shape=mask_shape)
                    mask = jnp.tile(mask, [1] + list(input_value.shape[1:]))
                    input_value = lax.select(mask, input_value / survival_rate,
                                             jnp.zeros_like(input_value))
            else:
                assert not input_kwargs
                assert op_type == OpType.ADD

            if op_type == OpType.ADD:
                assert len(input_values) == 2
                output_values = [input_value + input_values[1]]
            else:
                assert len(input_values) == 1
                output_values = [input_value]

        elif op_type == OpType.SCALAR_MUL:
            assert len(input_values) == 1
            input_value = input_values[0]
            assert len(input_kwargs) <= 1
            assert len(op_kwargs) == 1  # name
            if "const" in input_kwargs:
                c = input_kwargs["const"]
            else:
                # Default scale is 1 / sqrt(size of the last axis).
                c = 1 / jnp.sqrt(input_values[0].shape[-1])
            output_values = [input_values[0] * c]

        elif op_type == OpType.SCALAR_ADD:
            assert len(input_values) == 1
            input_value = input_values[0]
            assert len(input_kwargs) <= 1
            assert len(op_kwargs) == 1  # name
            assert "const" in input_kwargs
            c = input_kwargs["const"]
            output_values = [input_values[0] + c]

        elif op_type == OpType.DOT_GENERAL:
            assert len(input_values) == 2
            assert 0 < len(input_kwargs) <= 3
            assert len(op_kwargs) == 1  # name
            output_values = [
                lax.dot_general(input_values[0], input_values[1],
                                **input_kwargs)
            ]

        elif op_type == OpType.EINSUM:
            assert len(input_values) == 2
            assert len(input_kwargs) == 1
            assert "sum" in input_kwargs
            output_values = [
                jnp.einsum(input_kwargs["sum"], input_values[0],
                           input_values[1])
            ]

        # nn.attention

        elif op_type == OpType.SELF_ATTENTION:
            assert len(input_values) == 1
            input_value = input_values[0]
            assert not input_kwargs
            output_values = [
                nn.SelfAttention(**op_kwargs,
                                 deterministic=deterministic)(input_value)
            ]

        # nn.activation

        elif op_type in [
                OpType.RELU, OpType.GELU, OpType.SWISH, OpType.SIGMOID
        ]:
            assert len(input_values) == 1
            input_value = input_values[0]
            assert not input_kwargs
            fn = {
                OpType.RELU: nn.relu,
                OpType.GELU: nn.gelu,
                OpType.SWISH: nn.swish,
                OpType.SIGMOID: nn.sigmoid
            }[op_type]
            output_values = [fn(input_value)]

        elif op_type == OpType.SOFTMAX:
            assert len(input_values) == 1
            input_value = input_values[0]
            assert len(input_kwargs) <= 1
            output_values = [nn.softmax(input_value, **input_kwargs)]

        # nn.normalization

        elif op_type == OpType.BATCH_NORM:
            assert len(input_values) == 1
            input_value = input_values[0]
            assert len(input_kwargs) <= 1
            # Unless the op pins use_running_average, use running statistics
            # exactly when not training.
            add_kwargs = {}
            if "use_running_average" not in input_kwargs:
                add_kwargs["use_running_average"] = not training
            output_values = [
                nn.BatchNorm(**op_kwargs)(input_value, **input_kwargs,
                                          **add_kwargs)
            ]

        elif op_type == OpType.LAYER_NORM:
            assert len(input_values) == 1
            input_value = input_values[0]
            assert not input_kwargs
            output_values = [nn.LayerNorm(**op_kwargs)(input_value)]

        elif op_type == OpType.GROUP_NORM:
            assert len(input_values) == 1
            input_value = input_values[0]
            assert not input_kwargs
            output_values = [nn.GroupNorm(**op_kwargs)(input_value)]

        # reshape operators

        elif op_type == OpType.RESHAPE:
            assert len(input_values) == 1
            input_value = input_values[0]
            assert 0 < len(input_kwargs) < 3
            # Pops from the local copy only, so the op stays reusable.
            new_shape = input_kwargs.pop("new_shape")
            # A leading "B" placeholder stands for the input's batch size.
            if new_shape[0] == "B":
                new_shape = (input_value.shape[0], ) + new_shape[1:]
            output_values = [
                jnp.reshape(input_value, new_shape, **input_kwargs)
            ]

        elif op_type == OpType.FLATTEN:
            assert len(input_values) == 1
            input_value = input_values[0]
            assert not input_kwargs
            new_shape = (input_value.shape[0], -1)
            output_values = [jnp.reshape(input_value, new_shape)]

        elif op_type == OpType.TRANSPOSE:
            assert len(input_values) == 1
            input_value = input_values[0]
            assert len(input_kwargs) == 1
            assert len(op_kwargs) == 1  # name
            output_values = [jnp.transpose(input_value, **input_kwargs)]

        # nn.stochastic

        elif op_type == OpType.DROPOUT:
            assert len(input_values) == 1
            input_value = input_values[0]
            assert len(input_kwargs) <= 1
            output_values = [
                nn.Dropout(**op_kwargs)(input_value,
                                        deterministic=deterministic,
                                        **input_kwargs)
            ]

        # nn.pooling

        elif op_type == OpType.AVG_POOL or op_type == OpType.MAX_POOL:
            op_fn = nn.avg_pool if op_type == OpType.AVG_POOL else nn.max_pool
            assert len(input_values) == 1
            input_value = input_values[0]
            assert input_kwargs

            ws = input_kwargs["window_shape"]
            if isinstance(ws, int):
                ws = [ws] * (input_value.ndim - 2)
            # A window entry of 0 means "span the whole dimension": it is
            # replaced by the matching input dim (counting from axis 1).
            new_ws = []
            for window_dim_shape, dim_shape in zip(ws, input_value.shape[1:]):
                if window_dim_shape == 0:
                    new_ws.append(dim_shape)
                else:
                    new_ws.append(window_dim_shape)
            input_kwargs["window_shape"] = tuple(new_ws)

            if "strides" in input_kwargs:
                s = input_kwargs["strides"]
                if isinstance(s, int):
                    input_kwargs["strides"] = (s, ) * (input_value.ndim - 2)

            output_values = [op_fn(input_value, **input_kwargs)]

        elif op_type == OpType.MEAN:
            assert len(input_values) == 1
            input_value = input_values[0]
            assert input_kwargs
            output_values = [jnp.mean(input_value, **input_kwargs)]

        # new param

        elif op_type == OpType.PARAM:
            assert not input_values
            assert 0 < len(input_kwargs) <= 2
            # Pops from the local copy only, so the op stays reusable.
            init_fn = input_kwargs.pop("init_fn")

            # Any remaining input kwargs are bound into the initializer.
            init_fn_with_kwargs = functools.partial(init_fn, **input_kwargs)
            output_values = [self.param(op_name, init_fn_with_kwargs)]

        else:
            raise ValueError(f"op_type {op_type} not supported...")

        return output_values
Exemplo n.º 26
0
 def setup(self):
     """Builds the hidden Dense stack and the square output projection.

     One ``nn.Dense`` layer is created per entry of ``self.hidden_features``,
     in order; the output layer has feature shape ``(S_dim, S_dim)``.
     """
     hidden = []
     for width in self.hidden_features:
         hidden.append(nn.Dense(width))
     self.hidden_layers = hidden
     self.output_layer = nn.DenseGeneral((self.S_dim, self.S_dim))
Exemplo n.º 27
0
 def __call__(self, x):
     """Applies a 6-feature DenseGeneral that contracts axes 0 and 1 of ``x``."""
     contract_first_two = nn.DenseGeneral(
         features=6,
         axis=(0, 1),
         name='dense',
     )
     return contract_first_two(x)