Beispiel #1
0
    def __call__(self, x, key=None, sample=False, MPO=False):
        """Evaluate the Gaussian policy network.

        The trunk is Dense(200) -> LayerNorm -> tanh -> Dense(200) -> ELU. A
        final Dense layer of width ``2 * self.action_dim`` is split in half
        along the last axis into a mean and a log standard deviation; the log
        standard deviation is clipped to
        ``[self.log_sig_min, self.log_sig_max]``.

        Args:
          x: Input features.
          key: PRNG key handed to ``random.normal``; only read when ``sample``
            is true and ``MPO`` is false.
          sample: If true, draw an action and return it with its
            log-likelihood instead of the squashed mean.
          MPO: If true, return the raw mean and clipped log standard deviation
            without any squashing. Takes precedence over ``sample``.

        Returns:
          ``(mu, log_sig)`` when ``MPO`` is true;
          ``(max_action * tanh(mu), log_sig)`` when neither flag is set;
          ``(max_action * tanh(pi), log_pi)`` when sampling.
        """
        hidden = nn.Dense(features=200)(x)
        hidden = nn.LayerNorm()(hidden)
        hidden = nn.tanh(hidden)
        hidden = nn.Dense(features=200)(hidden)
        hidden = nn.elu(hidden)
        head = nn.Dense(features=2 * self.action_dim)(hidden)

        mu, raw_log_sig = jnp.split(head, 2, axis=-1)
        log_sig = jnp.clip(raw_log_sig, self.log_sig_min, self.log_sig_max)

        # Raw distribution parameters requested: nothing else to compute.
        if MPO:
            return mu, log_sig

        # Deterministic action: squash the mean only.
        if not sample:
            return self.max_action * nn.tanh(mu), log_sig

        noise = random.normal(key, mu.shape)
        pre_squash = mu + noise * jnp.exp(log_sig)
        log_pi = gaussian_likelihood(pre_squash, mu, log_sig)
        squashed = nn.tanh(pre_squash)
        # Change-of-variables term for the tanh squashing. The relu and 1e-6
        # keep the log argument strictly positive. The sum runs over axis 1,
        # so this path presumably expects 2-D input -- TODO confirm.
        squash_correction = jnp.sum(
            jnp.log(nn.relu(1 - squashed ** 2) + 1e-6), axis=1, keepdims=True,
        )
        return self.max_action * squashed, log_pi - squash_correction
Beispiel #2
0
 def __call__(self, inputs):
     """Apply the shared trunk, then the per-task layers.

     The input goes through one tanh-activated ``nn.Dense`` per entry of
     ``self.shared_features``. The result is repeated ``self.n_tasks`` times
     along a new leading axis and passed through ``MultiTaskDense`` layers
     sized by ``self.specific_features``; all but the last of those are
     followed by tanh.

     Returns:
       The final layer's output with size-1 axes squeezed out, transposed.
     """
     shared = inputs
     for width in self.shared_features:
         shared = nn.tanh(nn.Dense(width)(shared))

     # Materialise one copy of the shared representation per task.
     # If we batch, can we do without copying data?
     per_task = jnp.repeat(shared[None], repeats=self.n_tasks, axis=0)

     for width in self.specific_features[:-1]:
         layer = MultiTaskDense(width, self.n_tasks)
         per_task = nn.tanh(layer(per_task))

     # Final task-specific layer has no activation.
     output_layer = MultiTaskDense(self.specific_features[-1], self.n_tasks)
     return output_layer(per_task).squeeze().T
Beispiel #3
0
    def __call__(self, inputs, *, train):
        """Function of shapes [B*R,h,w,c*E] -> [E*B*R,num_classes].

        Embeds the input with a strided convolution, runs the transformer
        encoder, pools according to ``self.classifier`` and applies a single
        Dense head of width ``num_classes * ensemble_size`` whose output is
        split per ensemble member and stacked along the leading axis.

        Args:
          inputs: Rank-4 input array.
          train: Forwarded to the encoder as ``train``.

        Returns:
          ``(logits, out)`` where ``out`` maps ``'stem'``, ``'transformed'``,
          ``'head_input'``, ``'pre_logits'`` and ``'logits'`` to the
          corresponding intermediate arrays.

        Raises:
          ValueError: If ``self.classifier`` is neither ``'token'`` nor
            ``'gap'``.
        """
        out = {}

        # We can merge s2d+emb into a single conv; it's the same.
        grid = nn.Conv(features=self.hidden_size,
                       kernel_size=self.patches.size,
                       strides=self.patches.size,
                       padding='VALID',
                       name='embedding')(inputs)

        # Here, grid is a grid of embeddings.
        # TODO(dusenberrymw): Switch to self.sow(.).
        out['stem'] = grid

        # Flatten the spatial grid into a token sequence for the transformer.
        batch, rows, cols, width = grid.shape
        tokens = jnp.reshape(grid, [batch, rows * cols, width])

        use_cls_token = self.classifier == 'token'
        if use_cls_token:
            # Prepend a learned class token to every sequence.
            cls = self.param('cls', nn.initializers.zeros, (1, 1, width))
            tokens = jnp.concatenate(
                [jnp.tile(cls, [batch, 1, 1]), tokens], axis=1)

        encoder = vit.Encoder(name='Transformer', **self.transformer)
        encoded = encoder(tokens, train=train)
        out['transformed'] = encoded

        if use_cls_token:
            pooled = encoded[:, 0]
        elif self.classifier == 'gap':
            # Average over every axis between batch and features: (1,) or (1,2).
            pooled = jnp.mean(encoded, axis=list(range(1, encoded.ndim - 1)))
        else:
            raise ValueError(f'Invalid classifier={self.classifier}')

        out['head_input'] = pooled

        if self.representation_size is None:
            head_in = vit.IdentityLayer(name='pre_logits')(pooled)
            out['pre_logits'] = head_in
        else:
            pre_activation = nn.Dense(features=self.representation_size,
                                      name='pre_logits')(pooled)
            # The recorded pre_logits are taken before the tanh.
            out['pre_logits'] = pre_activation
            head_in = nn.tanh(pre_activation)

        # TODO(markcollier): Fix base model without using stop_gradient.
        if self.fix_base_model:
            head_in = jax.lax.stop_gradient(head_in)

        # Shape: (batch_size, num_classes * ensemble_size).
        stacked = nn.Dense(self.num_classes * self.ensemble_size,
                           name='head',
                           kernel_init=nn.initializers.zeros)(head_in)
        # Shape: (batch_size * ensemble_size, num_classes).
        member_logits = jnp.split(stacked, self.ensemble_size, axis=-1)
        logits = jnp.concatenate(member_logits)
        out['logits'] = logits
        return logits, out
Beispiel #4
0
 def __call__(self, x):
     """Map ``x`` to a bounded action.

     Two ReLU-activated Dense(256) layers feed a Dense layer of width
     ``self.action_dim``. Its output is squashed with tanh and scaled by
     ``self.max_action``.
     """
     hidden = nn.relu(nn.Dense(features=256)(x))
     hidden = nn.relu(nn.Dense(features=256)(hidden))
     pre_squash = nn.Dense(features=self.action_dim)(hidden)
     return self.max_action * nn.tanh(pre_squash)
Beispiel #5
0
def get_td_target(
    rng: PRNGSequence,
    state: jnp.ndarray,
    action: jnp.ndarray,
    next_state: jnp.ndarray,
    reward: jnp.ndarray,
    not_done: jnp.ndarray,
    discount: float,
    max_action: float,
    action_dim: int,
    actor_target_params: FrozenDict,
    critic_target_params: FrozenDict,
) -> jnp.ndarray:
    """Compute the TD target ``reward + not_done * discount * min(Q1, Q2)``.

    The next action is sampled from the target policy at ``next_state``,
    squashed with tanh and scaled by ``max_action``. Both target critics are
    evaluated on it and the element-wise minimum is used.

    ``state`` and ``action`` are accepted but not read by this function.

    Returns:
        The TD target array.
    """
    # The trailing (None, False, True) arguments presumably select the
    # helper's "return raw mean / log-std" mode -- confirm against
    # apply_gaussian_policy_model.
    policy_mean, policy_log_std = apply_gaussian_policy_model(
        actor_target_params, action_dim, max_action, next_state, None, False, True
    )
    noise = random.normal(rng, policy_mean.shape)
    raw_next_action = policy_mean + jnp.exp(policy_log_std) * noise
    next_action = max_action * nn.tanh(raw_next_action)

    q_first, q_second = apply_double_critic_model(
        critic_target_params, next_state, next_action, False
    )
    # Take the smaller of the two target estimates before bootstrapping.
    return reward + not_done * discount * jnp.minimum(q_first, q_second)
Beispiel #6
0
 def __call__(self, α, β):
     """Encode the pair ``(α, β)`` into two latent-sized outputs.

     The inputs are joined with ``jnp.hstack`` and passed through one
     tanh-activated Dense layer of width ``self.hiddendim``. Two separate
     Dense layers of width ``self.latentdim`` are then applied to that
     hidden vector.

     Returns:
       ``(μ, logσ2)``: the outputs of the ``encoder_μ_layer_1`` and
       ``encoder_logσ_layer_1`` layers. The second is presumably a
       log-variance, going by its original name -- confirm with the caller.
     """
     joined = jnp.hstack([α, β])
     hidden_layer = nn.Dense(self.hiddendim, name="encoder_layer_1")
     hidden = nn.tanh(hidden_layer(joined))
     mean = nn.Dense(self.latentdim, name="encoder_μ_layer_1")(hidden)
     log_sigma_sq = nn.Dense(self.latentdim, name="encoder_logσ_layer_1")(hidden)
     return mean, log_sigma_sq
Beispiel #7
0
 def __call__(self, x):
     """Run the layer stack with skip connections and return a tanh output.

     Almost every stage is ``softplus(group(layer(input)))``. The ``down*``
     stages are applied in sequence; each ``up*`` stage consumes the
     concatenation (last axis) of an earlier stage's output with the previous
     stage's output, after passing it through ``self.deconv``. The final
     stage uses tanh instead of softplus.
     """
     # NOTE(review): the value computed on the first line below is
     # overwritten by the second, which feeds `x` (not `out_l1`) to `layer12`
     # and reuses `group_l1`. This looks unintended, but both calls are kept
     # because removing or rewiring either would change which submodules run
     # -- confirm the intended wiring.
     out_l1 = nn.softplus(self.group_l1(self.layer1(x)))
     out_l1 = nn.softplus(self.group_l1(self.layer12(x)))
     # `down*` stages: two layer/group pairs per level.
     out_1 = nn.softplus(self.group1(self.down1(out_l1)))
     out_1 = nn.softplus(self.group12(self.down12(out_1)))
     out_2 = nn.softplus(self.group2(self.down2(out_1)))
     out_2 = nn.softplus(self.group22(self.down22(out_2)))
     out_3 = nn.softplus(self.group3(self.down3(out_2)))
     out_3 = nn.softplus(self.group32(self.down32(out_3)))
     out_4 = nn.softplus(self.group4(self.down4(out_3)))
     out_4 = nn.softplus(self.group42(self.down42(out_4)))
     out_latent = nn.softplus(self.group_latent(self.latent(out_4)))
     # `up*` stages: each input joins a skip tensor with the previous output
     # along the last axis, then goes through the shared `self.deconv`.
     in_up4 = jnp.concatenate((out_4, out_latent), axis=-1)
     # out_up4 = nn.softplus(self.group_up4(self.up4(self.deconv(out_4))))
     out_up4 = nn.softplus(self.group_up4(self.up4(self.deconv(in_up4))))
     out_up4 = nn.softplus(self.group_up42(self.up42(out_up4)))
     in_up3 = jnp.concatenate((out_3, out_up4), axis=-1)
     out_up3 = nn.softplus(self.group_up3(self.up3(self.deconv(in_up3))))
     out_up3 = nn.softplus(self.group_up32(self.up32(out_up3)))
     in_up2 = jnp.concatenate((out_2, out_up3), axis=-1)
     out_up2 = nn.softplus(self.group_up2(self.up2(self.deconv(in_up2))))
     out_up2 = nn.softplus(self.group_up22(self.up22(out_up2)))
     in_up1 = jnp.concatenate((out_1, out_up2), axis=-1)
     out_up1 = nn.softplus(self.group_up1(self.up1(self.deconv(in_up1))))
     out_up1 = nn.softplus(self.group_up12(self.up12(out_up1)))
     # Final stages at the first level: no `self.deconv` is applied here.
     in_straight1 = jnp.concatenate((out_l1, out_up1), axis=-1)
     out_straight1 = nn.softplus(
         self.group_straight1(self.straight1(in_straight1)))
     out_straight1 = nn.softplus(
         self.group_straight12(self.straight12(out_straight1)))
     return nn.tanh(self.group_straight2(self.straight2(out_straight1)))
Beispiel #8
0
    def __call__(self, keys: Array, mask: Array) -> Array:
        """Applies the model to the input keys and mask.

        Args:
          keys: The inputs for which to compute an attention score. Shape:
            <float32>[batch_size, seq_length, embeddings_size].
          mask: A mask that determines which values in `keys` are valid. Only
            values for which the mask is True will get non-zero attention
            scores. <bool>[batch_size, seq_length].

        Returns:
          The normalized attention scores. <float32>[batch_size, seq_length].
        """
        projected = nn.Dense(self.hidden_size, name='keys', use_bias=False)(keys)
        energy = nn.tanh(projected)
        # Score layer emits one value per position; drop the trailing unit
        # axis. New shape: <float32>[batch_size, seq_len].
        raw_scores = nn.Dense(1, name='energy', use_bias=False)(energy).squeeze(-1)
        # Masked-out positions get -inf so that exp(-inf) = 0 in the softmax.
        # NOTE(review): a row whose mask is entirely False feeds only -inf to
        # the softmax, which presumably yields NaN -- confirm callers never
        # pass such rows.
        masked_scores = jnp.where(mask, raw_scores, -jnp.inf)
        attention = nn.softmax(masked_scores, axis=-1)

        # Captures the scores if 'intermediates' is mutable, otherwise does nothing.
        self.sow('intermediates', 'attention', attention)

        return attention
Beispiel #9
0
    def __call__(self, x):
        """Return two log-probabilities derived from a single logit.

        A tanh-activated Dense(64) layer feeds a Dense(1) layer producing a
        logit ``l``. The output concatenates ``-softplus(l)`` and
        ``-softplus(l) + l`` along the last axis, which equal
        ``log(sigmoid(-l))`` and ``log(sigmoid(l))`` respectively. The
        original inline comment labelled the pair "p(z|x), 1-p(z|x)".
        """
        hidden = nn.tanh(nn.Dense(features=64)(x))
        logit = nn.Dense(features=1)(hidden)
        neg_softplus = -nn.softplus(logit)

        return jnp.concatenate([neg_softplus, neg_softplus + logit], -1)
 def __call__(self, hidden_states, deterministic=True):
     """Classify from the first token of ``hidden_states``.

     Applies dropout, ``self.dense``, tanh, dropout again and finally
     ``self.out_proj``.

     Args:
       hidden_states: Array indexed as ``[:, 0, :]`` to select the first
         token.
       deterministic: Forwarded to both dropout calls.

     Returns:
       The output of ``self.out_proj``.
     """
     # take <s> token (equiv. to [CLS])
     first_token = hidden_states[:, 0, :]
     dropped = self.dropout(first_token, deterministic=deterministic)
     activated = nn.tanh(self.dense(dropped))
     activated = self.dropout(activated, deterministic=deterministic)
     return self.out_proj(activated)
Beispiel #11
0
    def __call__(self, x):
        """Apply the stem, the ``mid`` sub-network and the output layers.

        The ReLU-activated stem output is concatenated (last axis) with the
        result of ``self.mid`` applied to it, passed through one more
        ReLU-activated stage and finally through ``self.straight2`` with a
        tanh.
        """
        stem = nn.relu(self.group_l1(self.layer1(x)))
        # Skip connection: keep the stem features alongside mid's output.
        joined = jnp.concatenate((stem, self.mid(stem)), axis=-1)

        fused = nn.relu(self.group_straight1(self.straight1(joined)))
        return nn.tanh(self.straight2(fused))
 def __call__(self, hidden_states):
     """Pool ``hidden_states`` by projecting its first token through tanh.

     The Dense layer keeps the feature width of the input (its last axis) and
     is initialised from a normal distribution parameterised by
     ``self.kernel_init_scale`` and ``self.dtype``.
     """
     width = hidden_states.shape[-1]
     init_fn = jax.nn.initializers.normal(self.kernel_init_scale, self.dtype)
     pooler = nn.Dense(
         width,
         kernel_init=init_fn,
         name="dense",
         dtype=self.dtype,
     )
     first_token = hidden_states[:, 0]
     return nn.tanh(pooler(first_token))
Beispiel #13
0
    def __call__(self, x):
        """Encode an image into separate critic and actor latents.

        The input is scaled to float32 in [0, 1] under the assumption stated
        below (uint8 storage), run through four 3x3 convolutions with 32
        features (first stride 2, the rest stride 1) each followed by ReLU,
        and flattened with ``jnp.reshape(x, -1)``. Two independent
        Dense(50) -> LayerNorm -> tanh heads produce the latents.

        Returns:
          ``SACEncoderOutputs(critic_z, actor_z)``.
        """
        # Images are stored in the replay buffer as uint8.
        pixels = x.astype(jnp.float32) / 255.0

        # Flatten the last dimension (normally to deal with stacked rgb frames)
        if pixels.ndim > 3:
            pixels = pixels.reshape((*pixels.shape[:2], -1))

        ortho_init = nn.initializers.orthogonal()
        feature_map = pixels
        for stride in ((2, 2), (1, 1), (1, 1), (1, 1)):
            conv = nn.Conv(features=32,
                           kernel_size=(3, 3),
                           strides=stride,
                           kernel_init=ortho_init)
            feature_map = nn.relu(conv(feature_map))
        # Flatten to 1-D. No leading batch axis survives this reshape, so the
        # module is presumably applied per example -- TODO confirm.
        flat = jnp.reshape(feature_map, -1)

        critic_z = nn.Dense(features=50, kernel_init=ortho_init)(flat)
        critic_z = nn.tanh(nn.LayerNorm()(critic_z))

        # Only the critic should train the convolution layers, so stop the
        # gradients from the actor.
        detached = jax.lax.stop_gradient(flat)
        actor_z = nn.Dense(features=50, kernel_init=ortho_init)(detached)
        actor_z = nn.tanh(nn.LayerNorm()(actor_z))

        return SACEncoderOutputs(critic_z, actor_z)
Beispiel #14
0
    def __call__(self, state, action, Q1=False):
        """Evaluate one or two Q-heads on the concatenated state and action.

        Each head is Dense(500) -> LayerNorm -> tanh -> Dense(500) -> ELU ->
        Dense(1), with its own submodules.

        Args:
          state: State array.
          action: Action array; concatenated to ``state`` on the last axis.
          Q1: If true, only the first head is built and returned.

        Returns:
          The first head's output when ``Q1`` is true, otherwise the tuple
          ``(q1, q2)``.
        """
        state_action = jnp.concatenate([state, action], axis=-1)

        def q_head(inputs):
            # Every call creates fresh submodules, so the two heads do not
            # share parameters.
            hidden = nn.Dense(features=500)(inputs)
            hidden = nn.LayerNorm()(hidden)
            hidden = nn.tanh(hidden)
            hidden = nn.Dense(features=500)(hidden)
            hidden = nn.elu(hidden)
            return nn.Dense(features=1)(hidden)

        first_q = q_head(state_action)
        if Q1:
            return first_q

        second_q = q_head(state_action)
        return first_q, second_q
Beispiel #15
0
 def __call__(self, x):
     """Apply a one-hidden-layer tanh MLP to the squeezed input.

     The hidden layer has ``self.hidden_dim`` units and no bias; the output
     layer has ``self.out_dim`` units with bias. Layer names embed
     ``self.model_num``.

     Returns:
       The output layer's result with size-1 axes squeezed out.
     """
     model_tag = str(self.model_num)

     # need to flatten extra dimensions required by CNN and LSTM
     flat = x.squeeze()

     hidden_layer = nn.Dense(
         features=self.hidden_dim,
         use_bias=False,
         name=f"shallow_fc{1}_model" + model_tag,
     )
     hidden = nn.tanh(hidden_layer(flat))

     output_layer = nn.Dense(
         features=self.out_dim,
         use_bias=True,
         name=f"shallow_fc{2}_model" + model_tag,
     )
     # squeeze for consistent shape w/ boundary model output
     return output_layer(hidden).squeeze()
Beispiel #16
0
    def __call__(self, batch: Dict[str, Array], deterministic: bool):
        """Encode ``batch`` and attach classifier logits to the loss helpers.

        The encoder's first-position encoding is optionally passed through
        ``self.mlp`` -> tanh -> dropout (when ``self.apply_mlp`` is set) and
        then through ``self.linear_classifier``.

        Args:
          batch: Passed unchanged to ``self.encoder.forward``.
          deterministic: Passed to the encoder and to dropout.

        Returns:
          ``(loss_helpers, logging_helpers)`` as returned by the encoder,
          with ``loss_helpers['classifier_logits']`` added.
        """
        encoder_outputs = self.encoder.forward(batch, deterministic)
        encoding, loss_helpers, logging_helpers = encoder_outputs

        # Use the encoding at position 0 of the second axis.
        features = encoding[:, 0, ...]
        if self.apply_mlp:
            features = nn.tanh(self.mlp(features))
            features = self.dropout(features, deterministic=deterministic)

        loss_helpers['classifier_logits'] = self.linear_classifier(features)
        return loss_helpers, logging_helpers
  def __call__(self, images: jnp.ndarray, train: Optional[bool] = None):
    """Classify ``images`` with the BatchEnsemble transformer.

    Images are turned into patch tokens by ``self.patches``, optionally
    prefixed with a learned class token, encoded, reduced to one vector per
    example according to ``self.classifier`` and projected to
    ``self.num_classes`` logits.

    Args:
      images: Input images.
      train: Overrides ``self.train`` through ``nn.module.merge_param``.

    Returns:
      ``(logits, extra_info)`` where ``extra_info`` is whatever the encoder
      returned alongside its output.

    Raises:
      ValueError: If ``self.classifier`` is not ``"token"``, ``"gap"`` or
        ``"map"``.
    """
    train = nn.module.merge_param("train", self.train, train)
    transformer = self.transformer or {}
    # Convert images to patches.
    x = self.patches(images, self.hidden_size, self.patch_size, self.patch_grid)
    # Add "class" token if necessary.
    n, _, c = x.shape
    if self.classifier == "token":
      cls = self.param("cls", nn.initializers.zeros, (1, 1, self.hidden_size))
      cls = jnp.tile(cls, [n, 1, 1])
      x = jnp.concatenate([cls, x], axis=1)
    # Encode tokens.
    x, extra_info = BatchEnsembleEncoder(
        train=train, name="BatchEnsembleTransformer", **transformer)(
            x)
    # Reduce tokens to a single vector representation.
    if self.classifier == "token":
      # Take the first token's output as representation as in BERT.
      x = x[:, 0]
    elif self.classifier == "gap":
      # Average all tokens.
      x = jnp.mean(x, axis=tuple(range(1, x.ndim - 1)))  # (1,) or (1, 2)
    elif self.classifier == "map":
      probe = self.param("probe", nn.initializers.xavier_uniform(), (1, 1, c))
      # x may have been subject to tiling inside the encoder, so its batch
      # size can differ from the pre-encoder n; size the probe from x itself
      # so the query and key/value batches always agree.
      probe = jnp.tile(probe, [x.shape[0], 1, 1])
      attention = nn.MultiHeadDotProductAttention(
          deterministic=not train,
          num_heads=transformer.get("attention", {}).get("num_heads", 1),
          kernel_init=nn.initializers.xavier_uniform())
      x = attention(inputs_q=probe, inputs_kv=x)
      y = nn.LayerNorm()(x)
      y = patch_transformer_lib.MlpBlock(
          mlp_dim=transformer["mlp_dim"],
          dropout_rate=0,
          deterministic=not train)(y)
      # Residual connection, then drop the length-1 probe axis.
      x = (x + y)[:, 0]
    else:
      raise ValueError(f"Unknown classifier: {self.classifier}")

    if self.representation_size is None:
      x = identity.IdentityLayer(name="pre_logits")(x)
    else:
      x = nn.Dense(self.representation_size, name="pre_logits")(x)
      x = nn.tanh(x)

    x = nn.Dense(self.num_classes, kernel_init=self.head_kernel_init,
                 name="head")(x)
    return x, extra_info
Beispiel #18
0
    def __call__(self, x, *, train, debug=False):
        """Embed patches, encode them and project to class logits.

        Args:
          x: Rank-4 input array (the embedding output is unpacked into four
            dimensions).
          train: Forwarded to the encoder as ``train``.
          debug: Accepted but not read by this method.

        Returns:
          The output of the ``output_projection`` Dense layer. When
          ``self.classifier`` matches none of ``'token'``, ``'0'``, ``'gap'``,
          ``'gmp'`` or ``'gsp'`` the encoder output is used without pooling.
        """
        patch_h, patch_w = self.patches.size
        # Extracting patches and then embedding is in fact a single convolution.
        embedded = nn.Conv(self.hidden_size, (patch_h, patch_w),
                           strides=(patch_h, patch_w),
                           padding='VALID',
                           name='embedding')(x)
        batch, grid_h, grid_w, width = embedded.shape
        tokens = jnp.reshape(embedded, [batch, grid_h * grid_w, width])

        # If we want to add a class token, add it here.
        if self.classifier == 'token':
            cls = self.param('cls', nn.initializers.zeros, (1, 1, width),
                             tokens.dtype)
            tokens = jnp.concatenate(
                [jnp.tile(cls, [batch, 1, 1]), tokens], axis=1)

        encoder = Encoder(mlp_dim=self.mlp_dim,
                          num_layers=self.num_layers,
                          num_heads=self.num_heads,
                          dropout_rate=self.dropout_rate,
                          attention_dropout_rate=self.attention_dropout_rate,
                          stochastic_depth=self.stochastic_depth,
                          dtype=self.dtype,
                          nb_x_patches=grid_h,
                          nb_y_patches=grid_w,
                          name='Transformer')
        encoded = encoder(tokens, train=train)

        # Reduce the token axis according to the configured classifier. Any
        # other value leaves the encoder output untouched.
        pooling = self.classifier
        if pooling in ('token', '0'):
            encoded = encoded[:, 0]
        elif pooling == 'gap':
            encoded = jnp.mean(encoded, axis=1)
        elif pooling == 'gmp':
            encoded = jnp.max(encoded, axis=1)
        elif pooling == 'gsp':
            encoded = jnp.sum(encoded, axis=1)

        if self.representation_size is None:
            features = nn_layers.IdentityLayer(name='pre_logits')(encoded)
        else:
            features = nn.Dense(self.representation_size,
                                name='pre_logits')(encoded)
            features = nn.tanh(features)

        return nn.Dense(self.num_classes,
                        kernel_init=nn.initializers.zeros,
                        name='output_projection')(features)
Beispiel #19
0
def sample_actions_and_evaluate(
    rng: PRNGSequence,
    actor_target_params: FrozenDict,
    critic_target_params: FrozenDict,
    max_action: float,
    action_dim: int,
    state: jnp.ndarray,
    batch_size: int,
    action_sample_size: int,
) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """
    To build our nonparametric policy, q(s, a), we sample `action_sample_size`
    actions from each policy in the batch and evaluate their Q-values.

    Returns:
        Q1: critic values reshaped to (batch_size, action_sample_size), with
            gradients stopped.
        sampled_actions: the unsquashed samples, reshaped to
            (batch_size * action_sample_size, action_dim), with gradients
            stopped.
    """
    # Get the policy distribution for each state.
    policy_mean, policy_log_std = apply_gaussian_policy_model(
        actor_target_params, action_dim, max_action, state, None, False, True
    )

    # Insert a sample axis so one mean / std broadcasts over all draws made
    # for the same state.
    mean = jnp.expand_dims(policy_mean, axis=1)
    std = jnp.expand_dims(jnp.exp(policy_log_std), axis=1)
    noise = random.normal(rng, (batch_size, action_sample_size, action_dim))
    raw_actions = (mean + noise * std).reshape(
        (batch_size * action_sample_size, action_dim)
    )
    raw_actions = jax.lax.stop_gradient(raw_actions)

    # Repeat each state once per sampled action, matching the flattened order.
    tiled_states = jnp.repeat(state, action_sample_size, axis=0)

    # Evaluate each sampled action at its corresponding state. The returned
    # actions stay unsquashed because the log probabilities are calculated
    # from them; only the critic sees the squashed version.
    squashed_actions = max_action * nn.tanh(raw_actions)
    q_values = apply_double_critic_model(
        critic_target_params,
        tiled_states,
        squashed_actions,
        True,
    )
    q_values = q_values.reshape((batch_size, action_sample_size))
    q_values = jax.lax.stop_gradient(q_values)

    return q_values, raw_actions
Beispiel #20
0
def sample_actions_and_evaluate(
    rng: PRNGSequence,
    actor_target_params: FrozenDict,
    critic_target_params: FrozenDict,
    max_action: float,
    action_dim: int,
    state: jnp.ndarray,
    batch_size: int,
    action_sample_size: int,
) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """
    To build our nonparametric policy, q(s, a), we sample `action_sample_size`
    actions from each policy in the batch and evaluate their Q-values.

    Noise is drawn with shape (batch_size, action_sample_size, action_dim) and
    broadcast against the per-state mean and standard deviation. The previous
    (batch_size, action_sample_size) noise shape only broadcast correctly when
    `action_dim` was 1; for that case the output shapes are unchanged.

    Returns:
        Q1: critic values reshaped to (batch_size, action_sample_size), with
            gradients stopped.
        sampled_actions: the squashed samples (max_action * tanh), reshaped to
            (batch_size * action_sample_size, action_dim), with gradients
            stopped.
    """
    state_dim = state.shape[-1]
    # get the policy distribution for each state and sample `action_sample_size`
    # actions from each
    # NOTE(review): `state_dim` is passed as the helper's second argument;
    # confirm against apply_gaussian_policy_model's signature that this
    # should not be `action_dim`. Left unchanged here.
    mu, log_sig = apply_gaussian_policy_model(
        actor_target_params, state_dim, max_action, state, None, False, True
    )
    # Add a sample axis so each state's mean / std broadcasts over its draws.
    mu = jnp.expand_dims(mu, axis=1)
    sig = jnp.expand_dims(jnp.exp(log_sig), axis=1)
    sampled_actions = (
        mu + random.normal(rng, (batch_size, action_sample_size, action_dim)) * sig
    )
    sampled_actions = max_action * nn.tanh(sampled_actions)
    sampled_actions = sampled_actions.reshape(
        (batch_size * action_sample_size, action_dim)
    )

    sampled_actions = jax.lax.stop_gradient(sampled_actions)

    # One copy of each state per sampled action, in the same flattened order.
    states_repeated = jnp.repeat(state, action_sample_size, axis=0)

    # evaluate each of the sampled actions at their corresponding state
    Q1 = apply_double_critic_model(
        critic_target_params, states_repeated, sampled_actions, True
    )
    Q1 = Q1.reshape((batch_size, action_sample_size))

    Q1 = jax.lax.stop_gradient(Q1)

    return Q1, sampled_actions
Beispiel #21
0
 def __call__(self, hidden_states):
     """Pool by passing the first token through ``self.dense`` and tanh."""
     first_token = hidden_states[:, 0]
     return nn.tanh(self.dense(first_token))
    def __call__(self,
                 images: jnp.ndarray,
                 train: Optional[bool] = None,
                 mean_field_factor: float = -1.,
                 **gp_kwargs):
        """Classify ``images`` with a BatchEnsemble encoder and optional GP head.

        Args:
          images: Input images.
          train: Overrides ``self.train`` through ``nn.module.merge_param``.
          mean_field_factor: Passed to ``ed.nn.utils.mean_field_logits`` when
            the GP head is used outside training.
          **gp_kwargs: Forwarded to ``self.gp_layer``.

        Returns:
          ``(logits, extra_info)``. ``extra_info`` comes from the encoder and
          gains ``"pre_logits"`` plus, with the GP head, ``"covmat"`` and
          possibly ``"random_features"``.

        Raises:
          ValueError: If ``self.classifier`` is not ``"token"``, ``"gap"`` or
            ``"map"``.
        """
        train = nn.module.merge_param("train", self.train, train)
        encoder_kwargs = self.transformer or {}
        pooling = self.classifier

        # Convert images to patches.
        tokens = self.patches(images, self.hidden_size, self.patch_size,
                              self.patch_grid)
        batch, _, width = tokens.shape

        # Add "class" token if necessary.
        if pooling == "token":
            cls = self.param("cls", nn.initializers.zeros,
                             (1, 1, self.hidden_size))
            tokens = jnp.concatenate(
                [jnp.tile(cls, [batch, 1, 1]), tokens], axis=1)

        # Encode tokens.
        encoder = vit_batchensemble.BatchEnsembleEncoder(
            train=train, name="Transformer", **encoder_kwargs)
        encoded, extra_info = encoder(tokens)

        # Reduce tokens to a single vector representation.
        if pooling == "token":
            # Take the first token's output as representation as in BERT.
            pooled = encoded[:, 0]
        elif pooling == "gap":
            # Average all tokens: axes (1,) or (1, 2).
            pooled = jnp.mean(encoded, axis=tuple(range(1, encoded.ndim - 1)))
        elif pooling == "map":
            probe = self.param("probe", nn.initializers.xavier_uniform(),
                               (1, 1, width))
            # The encoder output may have been tiled, so its batch size can
            # differ from the pre-encoder one; size the probe from it.
            probe = jnp.tile(probe, [encoded.shape[0], 1, 1])
            attention = nn.MultiHeadDotProductAttention(
                deterministic=not train,
                num_heads=encoder_kwargs.get("attention", {}).get(
                    "num_heads", 1),
                kernel_init=nn.initializers.xavier_uniform())
            attended = attention(inputs_q=probe, inputs_kv=encoded)
            normed = nn.LayerNorm()(attended)
            mlp_out = vit.MlpBlock(mlp_dim=encoder_kwargs["mlp_dim"],
                                   dropout_rate=0)(normed,
                                                   deterministic=not train)
            pooled = (attended + mlp_out)[:, 0]
        else:
            raise ValueError(f"Unknown classifier: {self.classifier}")

        if self.representation_size is None:
            features = vit.IdentityLayer(name="pre_logits")(pooled)
            extra_info["pre_logits"] = features
        else:
            pre_activation = nn.Dense(self.representation_size,
                                      name="pre_logits")(pooled)
            # The recorded pre_logits are taken before the tanh.
            extra_info["pre_logits"] = pre_activation
            features = nn.tanh(pre_activation)

        if not self.use_gp_layer:
            logits = nn.Dense(self.num_classes,
                              kernel_init=self.head_kernel_init,
                              name="batchensemble_head")(features)
            return logits, extra_info

        # Gaussian process layer output: a tuple of logits, covmat, and
        # optionally random features.
        gp_out = self.gp_layer(features, **gp_kwargs)
        extra_info["covmat"] = gp_out[1]
        if len(gp_out) > 2:
            extra_info["random_features"] = gp_out[2]

        if train:
            return gp_out[0], extra_info

        # During inference, compute posterior mean by adjusting the original
        # logits with predictive uncertainty.
        logits = ed.nn.utils.mean_field_logits(
            logits=gp_out[0],
            covmat=gp_out[1],
            mean_field_factor=mean_field_factor)
        return logits, extra_info
Beispiel #23
0
    def __call__(self, inputs, *, train):
        """Run the (optionally ResNet-rooted) vision transformer.

        When ``self.resnet`` is set, the input first goes through a ResNet
        root block and one stage per entry of ``self.resnet.num_layers``.
        The result is patch-embedded, encoded, pooled according to
        ``self.classifier`` and, if ``self.num_classes`` is truthy, projected
        to logits.

        Args:
          inputs: Rank-4 input array.
          train: Forwarded to the encoder as ``train``.

        Returns:
          Logits when ``self.num_classes`` is truthy, otherwise the
          pre-logits representation.

        Raises:
          ValueError: If ``self.classifier`` is neither ``'token'`` nor
            ``'gap'``.
        """
        x = inputs

        # (Possibly partial) ResNet root.
        if self.resnet is not None:
            base_width = int(64 * self.resnet.width_factor)

            # Root block.
            x = models_resnet.StdConv(features=base_width,
                                      kernel_size=(7, 7),
                                      strides=(2, 2),
                                      use_bias=False,
                                      name='conv_root')(x)
            x = nn.GroupNorm(name='gn_root')(x)
            x = nn.relu(x)
            x = nn.max_pool(x,
                            window_shape=(3, 3),
                            strides=(2, 2),
                            padding='SAME')

            # ResNet stages: the first keeps resolution and width; every
            # later one strides by 2 and doubles the width again.
            if self.resnet.num_layers:
                for stage, block_size in enumerate(self.resnet.num_layers):
                    first_stride = (1, 1) if stage == 0 else (2, 2)
                    x = models_resnet.ResNetStage(
                        block_size=block_size,
                        nout=base_width * 2**stage,
                        first_stride=first_stride,
                        name=f'block{stage + 1}')(x)

        # Unpacking four values fails early unless x is rank 4; the values
        # themselves are re-read after the embedding.
        n, h, w, c = x.shape

        # We can merge s2d+emb into a single conv; it's the same.
        grid = nn.Conv(features=self.hidden_size,
                       kernel_size=self.patches.size,
                       strides=self.patches.size,
                       padding='VALID',
                       name='embedding')(x)

        # Here, grid is a grid of embeddings; flatten it into tokens.
        n, h, w, c = grid.shape
        tokens = jnp.reshape(grid, [n, h * w, c])

        # If we want to add a class token, add it here.
        use_cls_token = self.classifier == 'token'
        if use_cls_token:
            cls = self.param('cls', nn.initializers.zeros, (1, 1, c))
            tokens = jnp.concatenate(
                [jnp.tile(cls, [n, 1, 1]), tokens], axis=1)

        encoded = Encoder(name='Transformer', **self.transformer)(tokens,
                                                                 train=train)

        if use_cls_token:
            pooled = encoded[:, 0]
        elif self.classifier == 'gap':
            # Average over axes (1,) or (1,2).
            pooled = jnp.mean(encoded, axis=list(range(1, encoded.ndim - 1)))
        else:
            raise ValueError(f'Invalid classifier={self.classifier}')

        if self.representation_size is None:
            features = IdentityLayer(name='pre_logits')(pooled)
        else:
            features = nn.Dense(features=self.representation_size,
                                name='pre_logits')(pooled)
            features = nn.tanh(features)

        if not self.num_classes:
            return features
        return nn.Dense(features=self.num_classes,
                        name='head',
                        kernel_init=nn.initializers.zeros)(features)
Beispiel #24
0
  def __call__(self,
               inputs: Array,
               train: bool,
               mean_field_factor: float = -1.,
               **gp_kwargs) -> Tuple[Array, Mapping[str, Any]]:
    """Run the vision transformer with a Dense or Gaussian-process head.

    Args:
      inputs: Rank-4 input array.
      train: Forwarded to the encoder; also selects whether GP logits are
        returned raw (training) or mean-field adjusted (otherwise).
      mean_field_factor: Passed to ``ed.nn.utils.mean_field_logits``.
      **gp_kwargs: Forwarded to ``self.gp_layer``.

    Returns:
      ``(logits, out)`` where ``out`` holds ``'stem'``, ``'transformed'``,
      ``'head_input'``, ``'pre_logits'``, ``'logits'`` and, with the GP
      head, ``'covmat'`` and possibly ``'random_features'``.

    Raises:
      ValueError: If ``self.classifier`` is neither ``'token'`` nor ``'gap'``.
    """
    out = {}

    # Unpacking four values fails early unless the input is rank 4; the
    # values themselves are re-read after the embedding.
    n, h, w, c = inputs.shape

    # We can merge s2d+emb into a single conv; it's the same.
    grid = nn.Conv(
        features=self.hidden_size,
        kernel_size=self.patches.size,
        strides=self.patches.size,
        padding='VALID',
        name='embedding')(
            inputs)

    # Here, grid is a grid of embeddings.
    # TODO(dusenberrymw): Switch to self.sow(.).
    out['stem'] = grid

    # Flatten the grid into a token sequence for the transformer.
    n, h, w, c = grid.shape
    tokens = jnp.reshape(grid, [n, h * w, c])

    # If we want to add a class token, add it here.
    use_cls_token = self.classifier == 'token'
    if use_cls_token:
      cls = self.param('cls', nn.initializers.zeros, (1, 1, c))
      tokens = jnp.concatenate([jnp.tile(cls, [n, 1, 1]), tokens], axis=1)

    encoder = vit.Encoder(name='Transformer', **self.transformer)
    encoded = encoder(tokens, train=train)
    out['transformed'] = encoded

    if use_cls_token:
      pooled = encoded[:, 0]
    elif self.classifier == 'gap':
      # Average over axes (1,) or (1,2).
      pooled = jnp.mean(encoded, axis=list(range(1, encoded.ndim - 1)))
    else:
      raise ValueError(f'Invalid classifier={self.classifier}')

    out['head_input'] = pooled

    if self.representation_size is None:
      features = vit.IdentityLayer(name='pre_logits')(pooled)
      out['pre_logits'] = features
    else:
      pre_activation = nn.Dense(
          features=self.representation_size, name='pre_logits')(pooled)
      # The recorded pre_logits are taken before the tanh.
      out['pre_logits'] = pre_activation
      features = nn.tanh(pre_activation)

    if self.use_gp_layer:
      # Using Gaussian process output layer.
      # This is the only place that ViT-GP differs from deterministic ViT.
      gp_out = self.gp_layer(features, **gp_kwargs)

      # Gaussian process layer output: a tuple of logits, covmat, and
      # optionally random features.
      out['logits'] = gp_out[0]
      out['covmat'] = gp_out[1]
      if len(gp_out) > 2:
        out['random_features'] = gp_out[2]

      if train:
        logits = gp_out[0]
      else:
        # During inference, compute posterior mean by adjusting the original
        # logits with predictive uncertainty.
        logits = ed.nn.utils.mean_field_logits(
            logits=gp_out[0], covmat=gp_out[1],
            mean_field_factor=mean_field_factor)
    else:
      logits = nn.Dense(
          features=self.num_classes,
          name='head',
          kernel_init=nn.initializers.zeros)(
              features)
      out['logits'] = logits

    return logits, out
Beispiel #25
0
    def __call__(self, inputs, *, train):
        """Run the vision transformer with a Monte-Carlo factor-analysis head.

        Args:
          inputs: Rank-4 input array.
          train: Forwarded to the encoder as ``train``.

        Returns:
          ``(logits, out)`` where ``out`` holds ``'stem'``, ``'transformed'``,
          ``'head_input'``, ``'pre_logits'`` and ``'logits'``.

        Raises:
          ValueError: If ``self.classifier`` is neither ``'token'`` nor
            ``'gap'``.
        """
        out = {}

        # Unpacking four values fails early unless the input is rank 4; the
        # values themselves are re-read after the embedding.
        n, h, w, c = inputs.shape

        # We can merge s2d+emb into a single conv; it's the same.
        grid = nn.Conv(features=self.hidden_size,
                       kernel_size=self.patches.size,
                       strides=self.patches.size,
                       padding='VALID',
                       name='embedding')(inputs)

        # Here, grid is a grid of embeddings.
        # TODO(dusenberrymw): Switch to self.sow(.).
        out['stem'] = grid

        # Flatten the grid into a token sequence for the transformer.
        n, h, w, c = grid.shape
        tokens = jnp.reshape(grid, [n, h * w, c])

        # If we want to add a class token, add it here.
        use_cls_token = self.classifier == 'token'
        if use_cls_token:
            cls = self.param('cls', nn.initializers.zeros, (1, 1, c))
            tokens = jnp.concatenate(
                [jnp.tile(cls, [n, 1, 1]), tokens], axis=1)

        encoded = Encoder(name='Transformer', **self.transformer)(tokens,
                                                                 train=train)
        out['transformed'] = encoded

        if use_cls_token:
            pooled = encoded[:, 0]
        elif self.classifier == 'gap':
            # Average over axes (1,) or (1,2).
            pooled = jnp.mean(encoded, axis=list(range(1, encoded.ndim - 1)))
        else:
            raise ValueError(f'Invalid classifier={self.classifier}')

        out['head_input'] = pooled

        if self.representation_size is None:
            features = IdentityLayer(name='pre_logits')(pooled)
            out['pre_logits'] = features
        else:
            pre_activation = nn.Dense(features=self.representation_size,
                                      name='pre_logits')(pooled)
            # The recorded pre_logits are taken before the tanh.
            out['pre_logits'] = pre_activation
            features = nn.tanh(pre_activation)

        # Both heads take identical arguments; only the class and the module
        # name depend on the multiclass / multilabel setting.
        if self.multiclass:
            head_name = 'multiclass_head'
            head_cls = ed.nn.MCSoftmaxDenseFA
        else:
            head_name = 'multilabel_head'
            head_cls = ed.nn.MCSigmoidDenseFA
        # self.mc_samples fills two consecutive positional slots, as in the
        # original call.
        output_layer = head_cls(self.num_classes,
                                self.num_factors,
                                self.temperature,
                                self.param_efficient,
                                self.mc_samples,
                                self.mc_samples,
                                logits_only=True,
                                return_locs=self.return_locs,
                                name=head_name)

        # TODO(markcollier): Fix base model without using stop_gradient.
        if self.fix_base_model:
            features = jax.lax.stop_gradient(features)

        logits = output_layer(features)

        out['logits'] = logits
        return logits, out
 def __call__(self, x):
     """Apply two Dense layers, each followed by tanh.

     The hidden layer has ``self.hidden_dim`` units and the output layer
     ``self.output_dim``; both use ``self.use_bias``.
     """
     hidden_layer = nn.Dense(self.hidden_dim, use_bias=self.use_bias)
     hidden = nn.tanh(hidden_layer(x))
     output_layer = nn.Dense(self.output_dim, use_bias=self.use_bias)
     return nn.tanh(output_layer(hidden))
Beispiel #27
0
  def __call__(self, images: jnp.ndarray, train: Optional[bool] = None):
    """Classify `images` with a BatchEnsemble vision transformer.

    Args:
      images: Input batch; handed to `self.embed` to be turned into patch
        tokens.
      train: Training-mode flag, merged with the module attribute
        `self.train` through `nn.module.merge_param`.

    Returns:
      A tuple `(logits, extra_info)`. `logits` is the output of the
      "batchensemble_head" layer; `extra_info` is the auxiliary mapping
      returned by the encoder with a "pre_logits" entry added.

    Raises:
      ValueError: If `self.classifier` is not "token", "gap" or "map".
    """
    train = nn.module.merge_param("train", self.train, train)
    # `self.transformer` may be unset. Every config lookup below goes through
    # this local so that an unset config is treated like an empty one
    # everywhere, instead of failing on `None.get(...)` in the head layers.
    transformer = self.transformer or {}
    ens_size = transformer.get("ens_size")
    random_sign_init = transformer.get("random_sign_init")

    # Convert images to patches.
    x = self.embed(images, self.hidden_size, self.patches.size)
    # Add "class" token if necessary.
    n, _, c = x.shape
    if self.classifier == "token":
      cls = self.param("cls", nn.initializers.zeros, (1, 1, self.hidden_size))
      cls = jnp.tile(cls, [n, 1, 1])
      x = jnp.concatenate([cls, x], axis=1)
    # Encode tokens.
    x, extra_info = BatchEnsembleEncoder(
        train=train, name="Transformer", **transformer)(
            x)
    # Reduce tokens to a single vector representation.
    if self.classifier == "token":
      # Take the first token's output as representation as in BERT.
      x = x[:, 0]
    elif self.classifier == "gap":
      # Average all tokens.
      x = jnp.mean(x, axis=tuple(range(1, x.ndim - 1)))  # (1,) or (1, 2)
    elif self.classifier == "map":
      probe = self.param("probe", nn.initializers.xavier_uniform(), (1, 1, c))
      # x may have been subject to tiling, n can be different from x.shape[0].
      probe = jnp.tile(probe, [x.shape[0], 1, 1])
      attention = nn.MultiHeadDotProductAttention(
          deterministic=not train,
          num_heads=transformer.get("attention", {}).get("num_heads", 1),
          kernel_init=nn.initializers.xavier_uniform())
      x = attention(inputs_q=probe, inputs_kv=x)
      y = nn.LayerNorm()(x)
      y = vit.MlpBlock(
          mlp_dim=transformer["mlp_dim"], dropout_rate=0)(
              y, deterministic=not train)
      # Residual connection, then keep the single probe position.
      x = (x + y)[:, 0]
    else:
      raise ValueError(f"Unknown classifier: {self.classifier}")

    if self.representation_size is None:
      x = IdentityLayer(name="pre_logits")(x)
      extra_info["pre_logits"] = x
    else:
      x = ed.nn.DenseBatchEnsemble(
          self.representation_size,
          ens_size,
          activation=None,
          alpha_init=ed.nn.utils.make_sign_initializer(random_sign_init),
          gamma_init=ed.nn.utils.make_sign_initializer(random_sign_init),
          name="pre_logits")(x)
      # The recorded pre-logits are the values before the tanh.
      extra_info["pre_logits"] = x
      x = nn.tanh(x)

    x = ed.nn.DenseBatchEnsemble(
        self.num_classes,
        ens_size,
        activation=None,
        alpha_init=ed.nn.utils.make_sign_initializer(random_sign_init),
        gamma_init=ed.nn.utils.make_sign_initializer(random_sign_init),
        kernel_init=self.head_kernel_init,
        name="batchensemble_head")(x)
    return x, extra_info
Beispiel #28
0
 def __call__(self, inputs):
     """Run `inputs` through a stack of Dense layers.

     Every width in `self.features` except the last gets a Dense layer
     followed by tanh. The last width gets a Dense layer with no
     activation, and its output is returned.
     """
     hidden_widths = self.features[:-1]
     activations = inputs
     for width in hidden_widths:
         pre_activation = nn.Dense(width)(activations)
         activations = nn.tanh(pre_activation)
     output_layer = nn.Dense(self.features[-1])
     return output_layer(activations)
Beispiel #29
0
    def __call__(self, inputs: Array, train: bool,
                 **kwargs) -> Tuple[Array, Mapping[str, Any]]:
        """Forward pass of the BatchEnsemble ViT with a heteroscedastic head.

        Args:
          inputs: Image batch; its shape must unpack into four entries
            (n, h, w, c).
          train: Mode flag. Forwarded to the encoder and to the GP head; when
            it is falsy the logits and pre-logits are additionally passed
            through the `log_average_*_probs` helper after being split into
            `ens_size` pieces.
          **kwargs: Extra keyword arguments forwarded to the GP head. Only
            used when `self.use_gp` is set.

        Returns:
          A `(logits, aux)` tuple, where `aux` maps names of intermediate
          results ('stem', 'transformed', 'head_input', 'pre_logits',
          'logits', and depending on the path 'covmat' and 'pre_ens_logits')
          to their values.

        Raises:
          ValueError: If `self.classifier` is neither 'token' nor 'gap'.
          NotImplementedError: If both `self.use_gp` and `self.multiclass`
            are set.
        """
        aux = {}

        # These four names are rebound after the patch embedding; unpacking
        # here fails early when the input shape does not have four entries.
        n, h, w, c = inputs.shape

        # We can merge s2d+emb into a single conv; it's the same.
        patch_embedding = nn.Conv(features=self.hidden_size,
                                  kernel_size=self.patches.size,
                                  strides=self.patches.size,
                                  padding='VALID',
                                  name='embedding')
        tokens = patch_embedding(inputs)

        # At this point `tokens` is still a grid of embeddings.
        # TODO(dusenberrymw): Switch to self.sow(.).
        aux['stem'] = tokens

        # Flatten the grid into a token sequence for the transformer.
        n, h, w, c = tokens.shape
        tokens = jnp.reshape(tokens, [n, h * w, c])

        # Prepend a learned class token when the classifier asks for one.
        if self.classifier == 'token':
            cls = self.param('cls', nn.initializers.zeros, (1, 1, c))
            cls_tokens = jnp.tile(cls, [n, 1, 1])
            tokens = jnp.concatenate([cls_tokens, tokens], axis=1)

        encoder = vit_batchensemble.BatchEnsembleEncoder(name='Transformer',
                                                         **self.transformer)
        tokens, _ = encoder(tokens, train=train)
        aux['transformed'] = tokens

        # Reduce the token sequence to one vector per example.
        if self.classifier == 'token':
            pooled = tokens[:, 0]
        elif self.classifier == 'gap':
            token_axes = list(range(1, tokens.ndim - 1))  # (1,) or (1,2)
            pooled = jnp.mean(tokens, axis=token_axes)
        else:
            raise ValueError(f'Invalid classifier={self.classifier}')

        aux['head_input'] = pooled

        if self.representation_size is None:
            features = vit.IdentityLayer(name='pre_logits')(pooled)
            aux['pre_logits'] = features
        else:
            random_sign_init = self.transformer.get('random_sign_init')
            pre_logits_layer = ed.nn.DenseBatchEnsemble(
                self.representation_size,
                self.transformer.get('ens_size'),
                activation=None,
                alpha_init=ed.nn.utils.make_sign_initializer(
                    random_sign_init),
                gamma_init=ed.nn.utils.make_sign_initializer(
                    random_sign_init),
                name='pre_logits')
            features = pre_logits_layer(pooled)
            # The recorded pre-logits are the values before the tanh.
            aux['pre_logits'] = features
            features = nn.tanh(features)

        # TODO(markcollier): Fix base model without using stop_gradient.
        if self.fix_base_model:
            features = jax.lax.stop_gradient(features)

        if self.use_gp:
            # A negative momentum setting is passed on as `None`.
            if self.covmat_momentum < 0.:
                momentum = None
            else:
                momentum = self.covmat_momentum
            gp_layer_kwargs = {'covmat_kwargs': {'momentum': momentum}}

            if self.multiclass:
                raise NotImplementedError(
                    'Multi-class HetSNGP layer not available.')

            gp_layer = ed.nn.MCSigmoidDenseFASNGPBE(
                num_outputs=self.num_classes,
                num_factors=self.num_factors,
                temperature=self.temperature,
                parameter_efficient=self.param_efficient,
                train_mc_samples=self.mc_samples,
                test_mc_samples=self.mc_samples,
                ens_size=self.transformer.get('ens_size'),
                logits_only=True,
                name='head',
                **gp_layer_kwargs)
            gp_outputs = gp_layer(features, training=train, **kwargs)

            # Gaussian process layer output: a tuple of logits, covmat, and
            # optionally random features.
            aux['logits'] = gp_outputs[0]
            aux['covmat'] = gp_outputs[1]

            logits = gp_outputs[0]
        else:
            # Note we're using non-BE layers.
            if self.multiclass:
                output_layer = ed.nn.MCSoftmaxDenseFA(
                    self.num_classes,
                    self.num_factors,
                    self.temperature,
                    self.param_efficient,
                    self.mc_samples,
                    self.mc_samples,
                    logits_only=True,
                    return_locs=self.return_locs,
                    name='head')
            else:
                output_layer = ed.nn.MCSigmoidDenseFA(
                    num_outputs=self.num_classes,
                    num_factors=self.num_factors,
                    temperature=self.temperature,
                    parameter_efficient=self.param_efficient,
                    train_mc_samples=self.mc_samples,
                    test_mc_samples=self.mc_samples,
                    logits_only=True,
                    return_locs=self.return_locs,
                    name='head')
            logits = output_layer(features)
            aux['logits'] = logits

        if not train:
            ens_size = self.transformer.get('ens_size')
            # Pick the averaging helper that matches the output activation.
            if self.multiclass:
                average_members = log_average_softmax_probs
            else:
                average_members = log_average_sigmoid_probs

            logits = average_members(
                jnp.asarray(jnp.split(logits, ens_size)))
            # Keep the un-averaged pre-logits under a separate key before
            # overwriting 'pre_logits' with the averaged version.
            aux['pre_ens_logits'] = aux['pre_logits']
            aux['pre_logits'] = average_members(
                jnp.asarray(jnp.split(aux['pre_logits'], ens_size)))

        return logits, aux
Beispiel #30
0
    def __call__(self, x, *, train=False):
        """Run the vision transformer on a batch of images.

        Args:
          x: Input batch; the output of the patch-embedding convolution must
            have a shape that unpacks into four entries (n, h, w, c).
          train: Keyword-only mode flag. Its negation is what gets passed to
            the dropout layer and to the encoder.

        Returns:
          The 'head' layer output when `self.num_classes` is truthy, otherwise
          the pooled representation (projected through 'pre_logits' and tanh
          when `self.rep_size` is truthy). The dict of intermediate results
          collected along the way is not returned.

        Raises:
          ValueError: If `self.pool_type` is not one of 'map', 'gap', '0' or
            'tok'.
        """
        out = {}

        # Patch extraction
        patch_embedding = nn.Conv(self.width,
                                  self.patch_size,
                                  strides=self.patch_size,
                                  padding='VALID',
                                  name='embedding')
        grid = patch_embedding(x)
        out['stem'] = grid

        n, h, w, c = grid.shape
        tokens = jnp.reshape(grid, [n, h * w, c])

        # Add posemb before adding extra token.
        posemb = get_posemb(self, self.posemb, (h, w), c, 'pos_embedding',
                            tokens.dtype)
        tokens = tokens + posemb
        out['with_posemb'] = tokens

        if self.pool_type == 'tok':
            cls = self.param('cls', nn.initializers.zeros, (1, 1, c),
                             tokens.dtype)
            cls_tokens = jnp.tile(cls, [n, 1, 1])
            tokens = jnp.concatenate([cls_tokens, tokens], axis=1)

        # Re-read the shape after the optional class token; this unpacking
        # also requires the token array to have exactly three dimensions.
        n, _, c = tokens.shape
        tokens = nn.Dropout(rate=self.dropout)(tokens, not train)

        # NOTE(review): the encoder is called with `train=not train`. Confirm
        # that Encoder's `train` parameter really expects the negated flag
        # (e.g. because it is used as a `deterministic` switch internally).
        encoder = Encoder(depth=self.depth,
                          mlp_dim=self.mlp_dim,
                          num_heads=self.num_heads,
                          dropout=self.dropout,
                          name='Transformer')
        tokens, encoder_out = encoder(tokens, train=not train)
        out['encoder'] = encoder_out
        out['encoded'] = tokens
        encoded = tokens

        # Pool the token sequence into one vector per example.
        if self.pool_type == 'map':
            pooled = MAPHead(num_heads=self.num_heads,
                             mlp_dim=self.mlp_dim)(tokens)
        elif self.pool_type == 'gap':
            pooled = jnp.mean(tokens, axis=1)
        elif self.pool_type in ('0', 'tok'):
            pooled = tokens[:, 0]
            if self.pool_type == 'tok':
                # Drop the prepended class token so that only the h * w
                # patch tokens remain for the 2D reshape below.
                encoded = encoded[:, 1:]
        else:
            raise ValueError(f'Unknown pool type: "{self.pool_type}"')
        out['head_input'] = pooled

        # Restore the spatial layout of the encoded patch tokens.
        feature_map = jnp.reshape(encoded, [n, h, w, -1])

        if self.rep_size:
            # `rep_size=True` means "use the model width".
            if self.rep_size is True:  # pylint: disable=g-bool-id-comparison
                rep_width = self.width
            else:
                rep_width = self.rep_size
            pre_logits = nn.Dense(rep_width, name='pre_logits')
            # NOTE: In the past we did not include tanh in pre_logits.
            # For few-shot, it should not matter much, as it whitens anyways.
            # The same layer instance is applied to both representations.
            feature_map = nn.tanh(pre_logits(feature_map))
            pooled = nn.tanh(pre_logits(pooled))

        out['pre_logits_2d'] = feature_map
        out['pre_logits'] = pooled

        if self.num_classes:
            head_kwargs = {}
            if self.head_zeroinit:
                head_kwargs['kernel_init'] = nn.initializers.zeros
            head = nn.Dense(self.num_classes, name='head', **head_kwargs)
            # The same head instance is applied to both representations.
            feature_map = head(feature_map)
            out['logits_2d'] = feature_map
            pooled = head(pooled)
            out['logits'] = pooled

        # TODO(dsuo): this used to be `return x, out`. Do we need out?
        return pooled