def __call__(self, x, key=None, sample=False, MPO=False):
    x = nn.Dense(features=200)(x)
    x = nn.LayerNorm()(x)
    x = nn.tanh(x)
    x = nn.Dense(features=200)(x)
    x = nn.elu(x)
    x = nn.Dense(features=2 * self.action_dim)(x)

    mu, log_sig = jnp.split(x, 2, axis=-1)
    log_sig = jnp.clip(log_sig, self.log_sig_min, self.log_sig_max)

    if MPO:
        return mu, log_sig

    if not sample:
        return self.max_action * nn.tanh(mu), log_sig
    else:
        sig = jnp.exp(log_sig)
        pi = mu + random.normal(key, mu.shape) * sig
        log_pi = gaussian_likelihood(pi, mu, log_sig)
        pi = nn.tanh(pi)
        # Tanh-squash correction to the Gaussian log-likelihood.
        log_pi -= jnp.sum(
            jnp.log(nn.relu(1 - pi ** 2) + 1e-6),
            axis=1,
            keepdims=True,
        )
        return self.max_action * pi, log_pi
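# `gaussian_likelihood` above is referenced but not defined in this snippet.
# A minimal sketch under the standard diagonal-Gaussian assumption (the name
# and signature follow the call site; the body is an assumption, not the
# original helper):
import jax.numpy as jnp

def gaussian_likelihood(sample, mu, log_sig):
    # Log-density of a diagonal Gaussian, summed over the action dimension.
    pre_sum = -0.5 * (
        ((sample - mu) / (jnp.exp(log_sig) + 1e-6)) ** 2
        + 2.0 * log_sig
        + jnp.log(2.0 * jnp.pi)
    )
    return jnp.sum(pre_sum, axis=1, keepdims=True)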
def __call__(self, inputs):
    x = inputs
    for feature in self.shared_features:
        x = nn.tanh(nn.Dense(feature)(x))

    x = jnp.repeat(jnp.expand_dims(x, axis=0), repeats=self.n_tasks, axis=0)
    # If we batch, can we do without copying data?

    for feature in self.specific_features[:-1]:
        x = nn.tanh(MultiTaskDense(feature, self.n_tasks)(x))
    x = MultiTaskDense(self.specific_features[-1], self.n_tasks)(x)
    return x.squeeze().T
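# `MultiTaskDense` is used above but not defined here. A minimal sketch that
# matches the call shape [n_tasks, batch, features] (the class name comes from
# the call site; parameter names and initializers are assumptions):
import jax.numpy as jnp
from flax import linen as nn

class MultiTaskDense(nn.Module):
    features: int
    n_tasks: int

    @nn.compact
    def __call__(self, x):
        # x: [n_tasks, batch, in_features]; one independent kernel per task.
        kernel = self.param('kernel', nn.initializers.lecun_normal(),
                            (self.n_tasks, x.shape[-1], self.features))
        bias = self.param('bias', nn.initializers.zeros,
                          (self.n_tasks, 1, self.features))
        return jnp.einsum('tbi,tif->tbf', x, kernel) + bias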
def __call__(self, inputs, *, train):
    """Function of shapes [B*R,h,w,c*E] -> [E*B*R,num_classes]."""
    out = {}
    x = inputs

    # We can merge s2d+emb into a single conv; it's the same.
    x = nn.Conv(features=self.hidden_size,
                kernel_size=self.patches.size,
                strides=self.patches.size,
                padding='VALID',
                name='embedding')(x)
    # Here, x is a grid of embeddings.

    # TODO(dusenberrymw): Switch to self.sow(.).
    out['stem'] = x

    # Transformer.
    n, h, w, c = x.shape
    x = jnp.reshape(x, [n, h * w, c])

    # If we want to add a class token, add it here.
    if self.classifier == 'token':
        cls = self.param('cls', nn.initializers.zeros, (1, 1, c))
        cls = jnp.tile(cls, [n, 1, 1])
        x = jnp.concatenate([cls, x], axis=1)

    x = vit.Encoder(name='Transformer', **self.transformer)(x, train=train)
    out['transformed'] = x

    if self.classifier == 'token':
        x = x[:, 0]
    elif self.classifier == 'gap':
        x = jnp.mean(x, axis=list(range(1, x.ndim - 1)))  # (1,) or (1,2)
    else:
        raise ValueError(f'Invalid classifier={self.classifier}')
    out['head_input'] = x

    if self.representation_size is not None:
        x = nn.Dense(features=self.representation_size, name='pre_logits')(x)
        out['pre_logits'] = x
        x = nn.tanh(x)
    else:
        x = vit.IdentityLayer(name='pre_logits')(x)
        out['pre_logits'] = x

    # TODO(markcollier): Fix base model without using stop_gradient.
    if self.fix_base_model:
        x = jax.lax.stop_gradient(x)

    # Shape: (batch_size, num_classes * ensemble_size).
    x = nn.Dense(self.num_classes * self.ensemble_size,
                 name='head',
                 kernel_init=nn.initializers.zeros)(x)
    # Shape: (batch_size * ensemble_size, num_classes).
    x = jnp.concatenate(jnp.split(x, self.ensemble_size, axis=-1))

    out['logits'] = x
    return x, out
def __call__(self, x):
    x = nn.Dense(features=256)(x)
    x = nn.relu(x)
    x = nn.Dense(features=256)(x)
    x = nn.relu(x)
    x = nn.Dense(features=self.action_dim)(x)
    return self.max_action * nn.tanh(x)
def get_td_target(
    rng: PRNGSequence,
    state: jnp.ndarray,
    action: jnp.ndarray,
    next_state: jnp.ndarray,
    reward: jnp.ndarray,
    not_done: jnp.ndarray,
    discount: float,
    max_action: float,
    action_dim: int,
    actor_target_params: FrozenDict,
    critic_target_params: FrozenDict,
) -> jnp.ndarray:
    mu, log_sig = apply_gaussian_policy_model(
        actor_target_params, action_dim, max_action, next_state, None, False, True
    )
    next_action = mu + jnp.exp(log_sig) * random.normal(rng, mu.shape)
    next_action = max_action * nn.tanh(next_action)

    target_Q1, target_Q2 = apply_double_critic_model(
        critic_target_params, next_state, next_action, False
    )
    target_Q = jnp.minimum(target_Q1, target_Q2)
    target_Q = reward + not_done * discount * target_Q

    return target_Q
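# In equation form, the target above is the clipped double-Q bootstrap used by
# TD3/SAC-style critics, with a tanh-squashed Gaussian next action:
#   a' = max_action * tanh(mu(s') + sigma(s') * eps),  eps ~ N(0, I)
#   y  = r + discount * not_done * min(Q1'(s', a'), Q2'(s', a'))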
def __call__(self, α, β):
    h1 = nn.tanh(
        nn.Dense(self.hiddendim, name="encoder_layer_1")(jnp.hstack([α, β])))
    μ = nn.Dense(self.latentdim, name="encoder_μ_layer_1")(h1)
    logσ2 = nn.Dense(self.latentdim, name="encoder_logσ_layer_1")(h1)
    return μ, logσ2
def __call__(self, x):
    # Encoder path.
    out_l1 = nn.softplus(self.group_l1(self.layer1(x)))
    # Assumed fix: the original read `self.group_l1(self.layer12(x))`, which
    # discarded the first block's output and reused its norm; every other
    # block pair follows the groupX2(layerX2(prev_out)) pattern.
    out_l1 = nn.softplus(self.group_l12(self.layer12(out_l1)))

    out_1 = nn.softplus(self.group1(self.down1(out_l1)))
    out_1 = nn.softplus(self.group12(self.down12(out_1)))
    out_2 = nn.softplus(self.group2(self.down2(out_1)))
    out_2 = nn.softplus(self.group22(self.down22(out_2)))
    out_3 = nn.softplus(self.group3(self.down3(out_2)))
    out_3 = nn.softplus(self.group32(self.down32(out_3)))
    out_4 = nn.softplus(self.group4(self.down4(out_3)))
    out_4 = nn.softplus(self.group42(self.down42(out_4)))

    out_latent = nn.softplus(self.group_latent(self.latent(out_4)))

    # Decoder path with skip connections.
    in_up4 = jnp.concatenate((out_4, out_latent), axis=-1)
    out_up4 = nn.softplus(self.group_up4(self.up4(self.deconv(in_up4))))
    out_up4 = nn.softplus(self.group_up42(self.up42(out_up4)))

    in_up3 = jnp.concatenate((out_3, out_up4), axis=-1)
    out_up3 = nn.softplus(self.group_up3(self.up3(self.deconv(in_up3))))
    out_up3 = nn.softplus(self.group_up32(self.up32(out_up3)))

    in_up2 = jnp.concatenate((out_2, out_up3), axis=-1)
    out_up2 = nn.softplus(self.group_up2(self.up2(self.deconv(in_up2))))
    out_up2 = nn.softplus(self.group_up22(self.up22(out_up2)))

    in_up1 = jnp.concatenate((out_1, out_up2), axis=-1)
    out_up1 = nn.softplus(self.group_up1(self.up1(self.deconv(in_up1))))
    out_up1 = nn.softplus(self.group_up12(self.up12(out_up1)))

    in_straight1 = jnp.concatenate((out_l1, out_up1), axis=-1)
    out_straight1 = nn.softplus(
        self.group_straight1(self.straight1(in_straight1)))
    out_straight1 = nn.softplus(
        self.group_straight12(self.straight12(out_straight1)))

    return nn.tanh(self.group_straight2(self.straight2(out_straight1)))
def __call__(self, keys: Array, mask: Array) -> Array:
    """Applies model to the input keys and mask.

    Args:
      keys: The inputs for which to compute an attention score. Shape:
        <float32>[batch_size, seq_length, embeddings_size].
      mask: A mask that determines which values in `keys` are valid. Only
        values for which the mask is True will get non-zero attention scores.
        <bool>[batch_size, seq_length].

    Returns:
      The normalized attention scores. <float32>[batch_size, seq_length].
    """
    hidden = nn.Dense(self.hidden_size, name='keys', use_bias=False)(keys)
    energy = nn.tanh(hidden)
    scores = nn.Dense(1, name='energy', use_bias=False)(energy)
    scores = scores.squeeze(-1)  # New shape: <float32>[batch_size, seq_len].
    scores = jnp.where(mask, scores, -jnp.inf)  # Using exp(-inf) = 0 below.
    scores = nn.softmax(scores, axis=-1)

    # Captures the scores if 'intermediates' is mutable, otherwise does nothing.
    self.sow('intermediates', 'attention', scores)

    return scores
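# A minimal usage sketch for the additive attention-scoring method above,
# assuming it lives in a linen module here called `AttentionScores` (the class
# name and constructor are assumptions; the body condenses the method above):
import jax
import jax.numpy as jnp
from flax import linen as nn

class AttentionScores(nn.Module):
    hidden_size: int

    @nn.compact
    def __call__(self, keys, mask):
        hidden = nn.tanh(nn.Dense(self.hidden_size, use_bias=False)(keys))
        scores = nn.Dense(1, use_bias=False)(hidden).squeeze(-1)
        scores = jnp.where(mask, scores, -jnp.inf)
        return nn.softmax(scores, axis=-1)

keys = jnp.ones((2, 4, 8))  # [batch, seq, emb]
mask = jnp.array([[True, True, True, True],
                  [True, True, False, False]])
model = AttentionScores(hidden_size=16)
scores = model.apply(model.init(jax.random.PRNGKey(0), keys, mask), keys, mask)
# Each row sums to 1; masked positions receive exactly zero attention.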
def __call__(self, x):
    # x = nn.tanh(nn.Dense(features=128)(x))
    x = nn.tanh(nn.Dense(features=64)(x))
    x = nn.Dense(features=1)(x)
    # Numerically stable log-probabilities:
    # -softplus(x) = log(sigmoid(-x)) and -softplus(x) + x = log(sigmoid(x)).
    sp = -nn.softplus(x)
    return jnp.concatenate([sp, sp + x], -1)  # log p(z|x), log(1 - p(z|x))
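# A quick, self-contained check of the identity used above (illustrative only):
import jax.numpy as jnp
from jax import nn as jnn

x = jnp.linspace(-5.0, 5.0, 11)
assert jnp.allclose(-jnn.softplus(x), jnp.log(jnn.sigmoid(-x)), atol=1e-5)
assert jnp.allclose(-jnn.softplus(x) + x, jnp.log(jnn.sigmoid(x)), atol=1e-5)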
def __call__(self, hidden_states, deterministic=True):
    hidden_states = hidden_states[:, 0, :]  # take <s> token (equiv. to [CLS])
    hidden_states = self.dropout(hidden_states, deterministic=deterministic)
    hidden_states = self.dense(hidden_states)
    hidden_states = nn.tanh(hidden_states)
    hidden_states = self.dropout(hidden_states, deterministic=deterministic)
    hidden_states = self.out_proj(hidden_states)
    return hidden_states
def __call__(self, x):
    l1 = nn.relu(self.group_l1(self.layer1(x)))
    unet = self.mid(l1)
    cat = jnp.concatenate((l1, unet), axis=-1)
    l2 = nn.relu(self.group_straight1(self.straight1(cat)))
    out = nn.tanh(self.straight2(l2))
    return out
def __call__(self, hidden_states):
    cls_token = hidden_states[:, 0]
    out = nn.Dense(
        hidden_states.shape[-1],
        kernel_init=jax.nn.initializers.normal(self.kernel_init_scale,
                                               self.dtype),
        name="dense",
        dtype=self.dtype,
    )(cls_token)
    return nn.tanh(out)
def __call__(self, x):
    # Images are stored in the replay buffer as uint8.
    x = x.astype(jnp.float32) / 255.0

    # Flatten the last dimension (normally to deal with stacked rgb frames)
    if len(x.shape) > 3:
        x = x.reshape((*x.shape[:2], -1))

    kernel_init = nn.initializers.orthogonal()
    x = nn.Conv(features=32, kernel_size=(3, 3), strides=(2, 2),
                kernel_init=kernel_init)(x)
    x = nn.relu(x)
    x = nn.Conv(features=32, kernel_size=(3, 3), strides=(1, 1),
                kernel_init=kernel_init)(x)
    x = nn.relu(x)
    x = nn.Conv(features=32, kernel_size=(3, 3), strides=(1, 1),
                kernel_init=kernel_init)(x)
    x = nn.relu(x)
    x = nn.Conv(features=32, kernel_size=(3, 3), strides=(1, 1),
                kernel_init=kernel_init)(x)
    x = nn.relu(x)
    x = jnp.reshape(x, -1)  # Flatten

    critic_z = nn.Dense(features=50, kernel_init=kernel_init)(x)
    critic_z = nn.LayerNorm()(critic_z)
    critic_z = nn.tanh(critic_z)

    # Only the critic should train the convolution layers, so stop the
    # gradients from the actor.
    actor_z = nn.Dense(features=50,
                       kernel_init=kernel_init)(jax.lax.stop_gradient(x))
    actor_z = nn.LayerNorm()(actor_z)
    actor_z = nn.tanh(actor_z)

    return SACEncoderOutputs(critic_z, actor_z)
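# `SACEncoderOutputs` is returned above but not defined in the snippet. A
# minimal sketch (the field names are assumptions inferred from the call site):
from typing import NamedTuple
import jax.numpy as jnp

class SACEncoderOutputs(NamedTuple):
    # One embedding for the critic (trains the conv trunk) and one for the
    # actor (gradients stopped at the trunk).
    critic_z: jnp.ndarray
    actor_z: jnp.ndarray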
def __call__(self, state, action, Q1=False):
    state_action = jnp.concatenate([state, action], axis=-1)

    q1 = nn.Dense(features=500)(state_action)
    q1 = nn.LayerNorm()(q1)
    q1 = nn.tanh(q1)
    q1 = nn.Dense(features=500)(q1)
    q1 = nn.elu(q1)
    q1 = nn.Dense(features=1)(q1)

    if Q1:
        return q1

    q2 = nn.Dense(features=500)(state_action)
    q2 = nn.LayerNorm()(q2)
    q2 = nn.tanh(q2)
    q2 = nn.Dense(features=500)(q2)
    q2 = nn.elu(q2)
    q2 = nn.Dense(features=1)(q2)

    return q1, q2
def __call__(self, x):
    # need to flatten extra dimensions required by CNN and LSTM
    x = x.squeeze()
    x = nn.Dense(
        features=self.hidden_dim,
        use_bias=False,
        name="shallow_fc1_model" + str(self.model_num),
    )(x)
    x = nn.tanh(x)
    x = nn.Dense(features=self.out_dim,
                 use_bias=True,
                 name="shallow_fc2_model" + str(self.model_num))(x)
    # squeeze for consistent shape w/ boundary model output
    return x.squeeze()
def __call__(self, batch: Dict[str, Array], deterministic: bool):
    encoding, loss_helpers, logging_helpers = self.encoder.forward(
        batch, deterministic)
    cls_encoding = encoding[:, 0, ...]

    if self.apply_mlp:
        cls_encoding = self.mlp(cls_encoding)
        cls_encoding = nn.tanh(cls_encoding)

    cls_encoding = self.dropout(cls_encoding, deterministic=deterministic)
    classifier_logits = self.linear_classifier(cls_encoding)
    loss_helpers['classifier_logits'] = classifier_logits

    return loss_helpers, logging_helpers
def __call__(self, images: jnp.ndarray, train: Optional[bool] = None):
    train = nn.module.merge_param("train", self.train, train)
    transformer = self.transformer or {}

    # Convert images to patches.
    x = self.patches(images, self.hidden_size, self.patch_size,
                     self.patch_grid)

    # Add "class" token if necessary.
    n, _, c = x.shape
    if self.classifier == "token":
        cls = self.param("cls", nn.initializers.zeros,
                         (1, 1, self.hidden_size))
        cls = jnp.tile(cls, [n, 1, 1])
        x = jnp.concatenate([cls, x], axis=1)

    # Encode tokens.
    x, extra_info = BatchEnsembleEncoder(
        train=train, name="BatchEnsembleTransformer", **transformer)(x)

    # Reduce tokens to a single vector representation.
    if self.classifier == "token":
        # Take the first token's output as representation as in BERT.
        x = x[:, 0]
    elif self.classifier == "gap":
        # Average all tokens.
        x = jnp.mean(x, axis=tuple(range(1, x.ndim - 1)))  # (1,) or (1, 2)
    elif self.classifier == "map":
        probe = self.param("probe", nn.initializers.xavier_uniform(),
                           (1, 1, c))
        probe = jnp.tile(probe, [n, 1, 1])
        attention = nn.MultiHeadDotProductAttention(
            deterministic=not train,
            num_heads=transformer.get("attention", {}).get("num_heads", 1),
            kernel_init=nn.initializers.xavier_uniform())
        x = attention(inputs_q=probe, inputs_kv=x)
        y = nn.LayerNorm()(x)
        y = patch_transformer_lib.MlpBlock(
            mlp_dim=transformer["mlp_dim"],
            dropout_rate=0,
            deterministic=not train)(y)
        x = (x + y)[:, 0]
    else:
        raise ValueError(f"Unknown classifier: {self.classifier}")

    if self.representation_size is None:
        x = identity.IdentityLayer(name="pre_logits")(x)
    else:
        x = nn.Dense(self.representation_size, name="pre_logits")(x)
        x = nn.tanh(x)

    x = nn.Dense(self.num_classes,
                 kernel_init=self.head_kernel_init,
                 name="head")(x)
    return x, extra_info
def __call__(self, x, *, train, debug=False):
    fh, fw = self.patches.size
    # Extracting patches and then embedding is in fact a single convolution.
    x = nn.Conv(self.hidden_size, (fh, fw),
                strides=(fh, fw),
                padding='VALID',
                name='embedding')(x)
    n, h, w, c = x.shape
    x = jnp.reshape(x, [n, h * w, c])

    # If we want to add a class token, add it here.
    if self.classifier == 'token':
        cls = self.param('cls', nn.initializers.zeros, (1, 1, c), x.dtype)
        cls = jnp.tile(cls, [n, 1, 1])
        x = jnp.concatenate([cls, x], axis=1)

    x = Encoder(mlp_dim=self.mlp_dim,
                num_layers=self.num_layers,
                num_heads=self.num_heads,
                dropout_rate=self.dropout_rate,
                attention_dropout_rate=self.attention_dropout_rate,
                stochastic_depth=self.stochastic_depth,
                dtype=self.dtype,
                nb_x_patches=h,
                nb_y_patches=w,
                name='Transformer')(x, train=train)

    if self.classifier in ('token', '0'):
        x = x[:, 0]
    elif self.classifier in ('gap', 'gmp', 'gsp'):
        fn = {'gap': jnp.mean, 'gmp': jnp.max, 'gsp': jnp.sum}[self.classifier]
        x = fn(x, axis=1)

    if self.representation_size is not None:
        x = nn.Dense(self.representation_size, name='pre_logits')(x)
        x = nn.tanh(x)
    else:
        x = nn_layers.IdentityLayer(name='pre_logits')(x)

    x = nn.Dense(self.num_classes,
                 kernel_init=nn.initializers.zeros,
                 name='output_projection')(x)
    return x
def sample_actions_and_evaluate(
    rng: PRNGSequence,
    actor_target_params: FrozenDict,
    critic_target_params: FrozenDict,
    max_action: float,
    action_dim: int,
    state: jnp.ndarray,
    batch_size: int,
    action_sample_size: int,
) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """
    To build our nonparametric policy, q(s, a), we sample `action_sample_size`
    actions from each policy in the batch and evaluate their Q-values.
    """
    # Get the policy distribution for each state and sample
    # `action_sample_size` actions from each.
    mu, log_sig = apply_gaussian_policy_model(
        actor_target_params, action_dim, max_action, state, None, False, True
    )
    mu = jnp.expand_dims(mu, axis=1)
    sig = jnp.expand_dims(jnp.exp(log_sig), axis=1)

    sampled_actions = (
        mu + random.normal(rng, (batch_size, action_sample_size, action_dim)) * sig
    )
    sampled_actions = sampled_actions.reshape(
        (batch_size * action_sample_size, action_dim)
    )
    sampled_actions = jax.lax.stop_gradient(sampled_actions)

    states_repeated = jnp.repeat(state, action_sample_size, axis=0)

    # Evaluate each of the sampled actions at their corresponding state. We
    # keep the `sampled_actions` array unsquashed because we need to calculate
    # the log probabilities using it, but we pass the squashed actions to the
    # critic.
    Q1 = apply_double_critic_model(
        critic_target_params,
        states_repeated,
        max_action * nn.tanh(sampled_actions),
        True,
    )
    Q1 = Q1.reshape((batch_size, action_sample_size))
    Q1 = jax.lax.stop_gradient(Q1)

    return Q1, sampled_actions
def sample_actions_and_evaluate(
    rng: PRNGSequence,
    actor_target_params: FrozenDict,
    critic_target_params: FrozenDict,
    max_action: float,
    action_dim: int,
    state: jnp.ndarray,
    batch_size: int,
    action_sample_size: int,
) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """
    To build our nonparametric policy, q(s, a), we sample `action_sample_size`
    actions from each policy in the batch and evaluate their Q-values.
    """
    # Get the policy distribution for each state and sample
    # `action_sample_size` actions from each.
    # Assumed fix: the original passed the state dimension as the policy's
    # action dimension and drew noise of shape (batch_size,
    # action_sample_size), which only broadcasts for 1-D action spaces; both
    # look like copy-paste bugs relative to the variant above.
    mu, log_sig = apply_gaussian_policy_model(
        actor_target_params, action_dim, max_action, state, None, False, True
    )
    mu = jnp.expand_dims(mu, axis=1)
    sig = jnp.expand_dims(jnp.exp(log_sig), axis=1)

    sampled_actions = (
        mu + random.normal(rng, (batch_size, action_sample_size, action_dim)) * sig
    )
    sampled_actions = max_action * nn.tanh(sampled_actions)
    sampled_actions = sampled_actions.reshape(
        (batch_size * action_sample_size, action_dim)
    )
    sampled_actions = jax.lax.stop_gradient(sampled_actions)

    states_repeated = jnp.repeat(state, action_sample_size, axis=0)

    # Evaluate each of the sampled (already squashed) actions at their
    # corresponding state.
    Q1 = apply_double_critic_model(
        critic_target_params, states_repeated, sampled_actions, True
    )
    Q1 = Q1.reshape((batch_size, action_sample_size))
    Q1 = jax.lax.stop_gradient(Q1)

    return Q1, sampled_actions
def __call__(self, hidden_states):
    cls_hidden_state = hidden_states[:, 0]
    cls_hidden_state = self.dense(cls_hidden_state)
    return nn.tanh(cls_hidden_state)
def __call__(self,
             images: jnp.ndarray,
             train: Optional[bool] = None,
             mean_field_factor: float = -1.,
             **gp_kwargs):
    train = nn.module.merge_param("train", self.train, train)
    transformer = self.transformer or {}

    # Convert images to patches.
    x = self.patches(images, self.hidden_size, self.patch_size,
                     self.patch_grid)

    # Add "class" token if necessary.
    n, _, c = x.shape
    if self.classifier == "token":
        cls = self.param("cls", nn.initializers.zeros,
                         (1, 1, self.hidden_size))
        cls = jnp.tile(cls, [n, 1, 1])
        x = jnp.concatenate([cls, x], axis=1)

    # Encode tokens.
    x, extra_info = vit_batchensemble.BatchEnsembleEncoder(
        train=train, name="Transformer", **transformer)(x)

    # Reduce tokens to a single vector representation.
    if self.classifier == "token":
        # Take the first token's output as representation as in BERT.
        x = x[:, 0]
    elif self.classifier == "gap":
        # Average all tokens.
        x = jnp.mean(x, axis=tuple(range(1, x.ndim - 1)))  # (1,) or (1, 2)
    elif self.classifier == "map":
        probe = self.param("probe", nn.initializers.xavier_uniform(),
                           (1, 1, c))
        # x may have been subject to tiling, n can be different from x.shape[0].
        probe = jnp.tile(probe, [x.shape[0], 1, 1])
        attention = nn.MultiHeadDotProductAttention(
            deterministic=not train,
            num_heads=transformer.get("attention", {}).get("num_heads", 1),
            kernel_init=nn.initializers.xavier_uniform())
        x = attention(inputs_q=probe, inputs_kv=x)
        y = nn.LayerNorm()(x)
        y = vit.MlpBlock(mlp_dim=transformer["mlp_dim"],
                         dropout_rate=0)(y, deterministic=not train)
        x = (x + y)[:, 0]
    else:
        raise ValueError(f"Unknown classifier: {self.classifier}")

    if self.representation_size is None:
        x = vit.IdentityLayer(name="pre_logits")(x)
        extra_info["pre_logits"] = x
    else:
        x = nn.Dense(self.representation_size, name="pre_logits")(x)
        extra_info["pre_logits"] = x
        x = nn.tanh(x)

    if self.use_gp_layer:
        x_gp = self.gp_layer(x, **gp_kwargs)
        # Gaussian process layer output: a tuple of logits, covmat, and
        # optionally random features.
        extra_info["covmat"] = x_gp[1]
        if len(x_gp) > 2:
            extra_info["random_features"] = x_gp[2]
        if train:
            x = x_gp[0]
        else:
            # During inference, compute posterior mean by adjusting the
            # original logits with predictive uncertainty.
            x = ed.nn.utils.mean_field_logits(
                logits=x_gp[0],
                covmat=x_gp[1],
                mean_field_factor=mean_field_factor)
    else:
        x = nn.Dense(self.num_classes,
                     kernel_init=self.head_kernel_init,
                     name="batchensemble_head")(x)
    return x, extra_info
def __call__(self, inputs, *, train):
    x = inputs

    # (Possibly partial) ResNet root.
    if self.resnet is not None:
        width = int(64 * self.resnet.width_factor)

        # Root block.
        x = models_resnet.StdConv(features=width,
                                  kernel_size=(7, 7),
                                  strides=(2, 2),
                                  use_bias=False,
                                  name='conv_root')(x)
        x = nn.GroupNorm(name='gn_root')(x)
        x = nn.relu(x)
        x = nn.max_pool(x, window_shape=(3, 3), strides=(2, 2), padding='SAME')

        # ResNet stages.
        if self.resnet.num_layers:
            x = models_resnet.ResNetStage(
                block_size=self.resnet.num_layers[0],
                nout=width,
                first_stride=(1, 1),
                name='block1')(x)
            for i, block_size in enumerate(self.resnet.num_layers[1:], 1):
                x = models_resnet.ResNetStage(block_size=block_size,
                                              nout=width * 2**i,
                                              first_stride=(2, 2),
                                              name=f'block{i + 1}')(x)

    # We can merge s2d+emb into a single conv; it's the same.
    x = nn.Conv(features=self.hidden_size,
                kernel_size=self.patches.size,
                strides=self.patches.size,
                padding='VALID',
                name='embedding')(x)
    # Here, x is a grid of embeddings.

    # Transformer.
    n, h, w, c = x.shape
    x = jnp.reshape(x, [n, h * w, c])

    # If we want to add a class token, add it here.
    if self.classifier == 'token':
        cls = self.param('cls', nn.initializers.zeros, (1, 1, c))
        cls = jnp.tile(cls, [n, 1, 1])
        x = jnp.concatenate([cls, x], axis=1)

    x = Encoder(name='Transformer', **self.transformer)(x, train=train)

    if self.classifier == 'token':
        x = x[:, 0]
    elif self.classifier == 'gap':
        x = jnp.mean(x, axis=list(range(1, x.ndim - 1)))  # (1,) or (1,2)
    else:
        raise ValueError(f'Invalid classifier={self.classifier}')

    if self.representation_size is not None:
        x = nn.Dense(features=self.representation_size, name='pre_logits')(x)
        x = nn.tanh(x)
    else:
        x = IdentityLayer(name='pre_logits')(x)

    if self.num_classes:
        x = nn.Dense(features=self.num_classes,
                     name='head',
                     kernel_init=nn.initializers.zeros)(x)
    return x
def __call__(self,
             inputs: Array,
             train: bool,
             mean_field_factor: float = -1.,
             **gp_kwargs) -> Tuple[Array, Mapping[str, Any]]:
    out = {}
    x = inputs

    # We can merge s2d+emb into a single conv; it's the same.
    x = nn.Conv(features=self.hidden_size,
                kernel_size=self.patches.size,
                strides=self.patches.size,
                padding='VALID',
                name='embedding')(x)
    # Here, x is a grid of embeddings.

    # TODO(dusenberrymw): Switch to self.sow(.).
    out['stem'] = x

    # Transformer.
    n, h, w, c = x.shape
    x = jnp.reshape(x, [n, h * w, c])

    # If we want to add a class token, add it here.
    if self.classifier == 'token':
        cls = self.param('cls', nn.initializers.zeros, (1, 1, c))
        cls = jnp.tile(cls, [n, 1, 1])
        x = jnp.concatenate([cls, x], axis=1)

    x = vit.Encoder(name='Transformer', **self.transformer)(x, train=train)
    out['transformed'] = x

    if self.classifier == 'token':
        x = x[:, 0]
    elif self.classifier == 'gap':
        x = jnp.mean(x, axis=list(range(1, x.ndim - 1)))  # (1,) or (1,2)
    else:
        raise ValueError(f'Invalid classifier={self.classifier}')
    out['head_input'] = x

    if self.representation_size is not None:
        x = nn.Dense(features=self.representation_size, name='pre_logits')(x)
        out['pre_logits'] = x
        x = nn.tanh(x)
    else:
        x = vit.IdentityLayer(name='pre_logits')(x)
        out['pre_logits'] = x

    if not self.use_gp_layer:
        logits = nn.Dense(features=self.num_classes,
                          name='head',
                          kernel_init=nn.initializers.zeros)(x)
        out['logits'] = logits
    else:
        # Using Gaussian process output layer. This is the only place that
        # ViT-GP differs from deterministic ViT.
        x_gp = self.gp_layer(x, **gp_kwargs)
        # Gaussian process layer output: a tuple of logits, covmat, and
        # optionally random features.
        out['logits'] = x_gp[0]
        out['covmat'] = x_gp[1]
        if len(x_gp) > 2:
            out['random_features'] = x_gp[2]

        if not train:
            # During inference, compute posterior mean by adjusting the
            # original logits with predictive uncertainty.
            logits = ed.nn.utils.mean_field_logits(
                logits=x_gp[0],
                covmat=x_gp[1],
                mean_field_factor=mean_field_factor)
        else:
            logits = x_gp[0]

    return logits, out
def __call__(self, inputs, *, train):
    out = {}
    x = inputs

    # We can merge s2d+emb into a single conv; it's the same.
    x = nn.Conv(features=self.hidden_size,
                kernel_size=self.patches.size,
                strides=self.patches.size,
                padding='VALID',
                name='embedding')(x)
    # Here, x is a grid of embeddings.

    # TODO(dusenberrymw): Switch to self.sow(.).
    out['stem'] = x

    # Transformer.
    n, h, w, c = x.shape
    x = jnp.reshape(x, [n, h * w, c])

    # If we want to add a class token, add it here.
    if self.classifier == 'token':
        cls = self.param('cls', nn.initializers.zeros, (1, 1, c))
        cls = jnp.tile(cls, [n, 1, 1])
        x = jnp.concatenate([cls, x], axis=1)

    x = Encoder(name='Transformer', **self.transformer)(x, train=train)
    out['transformed'] = x

    if self.classifier == 'token':
        x = x[:, 0]
    elif self.classifier == 'gap':
        x = jnp.mean(x, axis=list(range(1, x.ndim - 1)))  # (1,) or (1,2)
    else:
        raise ValueError(f'Invalid classifier={self.classifier}')
    out['head_input'] = x

    if self.representation_size is not None:
        x = nn.Dense(features=self.representation_size, name='pre_logits')(x)
        out['pre_logits'] = x
        x = nn.tanh(x)
    else:
        x = IdentityLayer(name='pre_logits')(x)
        out['pre_logits'] = x

    if self.multiclass:
        output_layer = ed.nn.MCSoftmaxDenseFA(
            self.num_classes,
            self.num_factors,
            self.temperature,
            self.param_efficient,
            self.mc_samples,
            self.mc_samples,
            logits_only=True,
            return_locs=self.return_locs,
            name='multiclass_head')
    else:
        output_layer = ed.nn.MCSigmoidDenseFA(
            self.num_classes,
            self.num_factors,
            self.temperature,
            self.param_efficient,
            self.mc_samples,
            self.mc_samples,
            logits_only=True,
            return_locs=self.return_locs,
            name='multilabel_head')

    # TODO(markcollier): Fix base model without using stop_gradient.
    if self.fix_base_model:
        x = jax.lax.stop_gradient(x)

    x = output_layer(x)
    out['logits'] = x
    return x, out
def __call__(self, x):
    h = nn.Dense(self.hidden_dim, use_bias=self.use_bias)(x)
    h = nn.tanh(h)
    return nn.tanh(nn.Dense(self.output_dim, use_bias=self.use_bias)(h))
def __call__(self, images: jnp.ndarray, train: Optional[bool] = None):
    train = nn.module.merge_param("train", self.train, train)
    transformer = self.transformer or {}

    # Convert images to patches.
    x = self.embed(images, self.hidden_size, self.patches.size)

    # Add "class" token if necessary.
    n, _, c = x.shape
    if self.classifier == "token":
        cls = self.param("cls", nn.initializers.zeros,
                         (1, 1, self.hidden_size))
        cls = jnp.tile(cls, [n, 1, 1])
        x = jnp.concatenate([cls, x], axis=1)

    # Encode tokens.
    x, extra_info = BatchEnsembleEncoder(
        train=train, name="Transformer", **transformer)(x)

    # Reduce tokens to a single vector representation.
    if self.classifier == "token":
        # Take the first token's output as representation as in BERT.
        x = x[:, 0]
    elif self.classifier == "gap":
        # Average all tokens.
        x = jnp.mean(x, axis=tuple(range(1, x.ndim - 1)))  # (1,) or (1, 2)
    elif self.classifier == "map":
        probe = self.param("probe", nn.initializers.xavier_uniform(),
                           (1, 1, c))
        # x may have been subject to tiling, n can be different from x.shape[0].
        probe = jnp.tile(probe, [x.shape[0], 1, 1])
        attention = nn.MultiHeadDotProductAttention(
            deterministic=not train,
            num_heads=transformer.get("attention", {}).get("num_heads", 1),
            kernel_init=nn.initializers.xavier_uniform())
        x = attention(inputs_q=probe, inputs_kv=x)
        y = nn.LayerNorm()(x)
        y = vit.MlpBlock(mlp_dim=transformer["mlp_dim"],
                         dropout_rate=0)(y, deterministic=not train)
        x = (x + y)[:, 0]
    else:
        raise ValueError(f"Unknown classifier: {self.classifier}")

    if self.representation_size is None:
        x = IdentityLayer(name="pre_logits")(x)
        extra_info["pre_logits"] = x
    else:
        x = ed.nn.DenseBatchEnsemble(
            self.representation_size,
            self.transformer.get("ens_size"),
            activation=None,
            alpha_init=ed.nn.utils.make_sign_initializer(
                self.transformer.get("random_sign_init")),
            gamma_init=ed.nn.utils.make_sign_initializer(
                self.transformer.get("random_sign_init")),
            name="pre_logits")(x)
        extra_info["pre_logits"] = x
        x = nn.tanh(x)

    x = ed.nn.DenseBatchEnsemble(
        self.num_classes,
        self.transformer.get("ens_size"),
        activation=None,
        alpha_init=ed.nn.utils.make_sign_initializer(
            self.transformer.get("random_sign_init")),
        gamma_init=ed.nn.utils.make_sign_initializer(
            self.transformer.get("random_sign_init")),
        kernel_init=self.head_kernel_init,
        name="batchensemble_head")(x)
    return x, extra_info
def __call__(self, inputs):
    x = inputs
    for feature in self.features[:-1]:
        x = nn.tanh(nn.Dense(feature)(x))
    x = nn.Dense(self.features[-1])(x)
    return x
def __call__(self, inputs: Array, train: bool,
             **kwargs) -> Tuple[Array, Mapping[str, Any]]:
    out = {}
    x = inputs

    # We can merge s2d+emb into a single conv; it's the same.
    x = nn.Conv(features=self.hidden_size,
                kernel_size=self.patches.size,
                strides=self.patches.size,
                padding='VALID',
                name='embedding')(x)
    # Here, x is a grid of embeddings.

    # TODO(dusenberrymw): Switch to self.sow(.).
    out['stem'] = x

    # Transformer.
    n, h, w, c = x.shape
    x = jnp.reshape(x, [n, h * w, c])

    # If we want to add a class token, add it here.
    if self.classifier == 'token':
        cls = self.param('cls', nn.initializers.zeros, (1, 1, c))
        cls = jnp.tile(cls, [n, 1, 1])
        x = jnp.concatenate([cls, x], axis=1)

    x, _ = vit_batchensemble.BatchEnsembleEncoder(
        name='Transformer', **self.transformer)(x, train=train)
    out['transformed'] = x

    if self.classifier == 'token':
        x = x[:, 0]
    elif self.classifier == 'gap':
        x = jnp.mean(x, axis=list(range(1, x.ndim - 1)))  # (1,) or (1,2)
    else:
        raise ValueError(f'Invalid classifier={self.classifier}')
    out['head_input'] = x

    if self.representation_size is not None:
        x = ed.nn.DenseBatchEnsemble(
            self.representation_size,
            self.transformer.get('ens_size'),
            activation=None,
            alpha_init=ed.nn.utils.make_sign_initializer(
                self.transformer.get('random_sign_init')),
            gamma_init=ed.nn.utils.make_sign_initializer(
                self.transformer.get('random_sign_init')),
            name='pre_logits')(x)
        out['pre_logits'] = x
        x = nn.tanh(x)
    else:
        x = vit.IdentityLayer(name='pre_logits')(x)
        out['pre_logits'] = x

    # TODO(markcollier): Fix base model without using stop_gradient.
    if self.fix_base_model:
        x = jax.lax.stop_gradient(x)

    if self.use_gp:
        if self.covmat_momentum < 0.:
            gp_layer_kwargs = {'covmat_kwargs': {'momentum': None}}
        else:
            gp_layer_kwargs = {
                'covmat_kwargs': {'momentum': self.covmat_momentum}
            }

        if self.multiclass:
            raise NotImplementedError('Multi-class HetSNGP layer not available.')
        else:
            gp_layer = ed.nn.MCSigmoidDenseFASNGPBE(
                num_outputs=self.num_classes,
                num_factors=self.num_factors,
                temperature=self.temperature,
                parameter_efficient=self.param_efficient,
                train_mc_samples=self.mc_samples,
                test_mc_samples=self.mc_samples,
                ens_size=self.transformer.get('ens_size'),
                logits_only=True,
                name='head',
                **gp_layer_kwargs)
        x_gp = gp_layer(x, training=train, **kwargs)
        # Gaussian process layer output: a tuple of logits, covmat, and
        # optionally random features.
        out['logits'] = x_gp[0]
        out['covmat'] = x_gp[1]
        logits = x_gp[0]
    else:
        # Note we're using non-BE layers.
        if self.multiclass:
            output_layer = ed.nn.MCSoftmaxDenseFA(
                self.num_classes,
                self.num_factors,
                self.temperature,
                self.param_efficient,
                self.mc_samples,
                self.mc_samples,
                logits_only=True,
                return_locs=self.return_locs,
                name='head')
        else:
            output_layer = ed.nn.MCSigmoidDenseFA(
                num_outputs=self.num_classes,
                num_factors=self.num_factors,
                temperature=self.temperature,
                parameter_efficient=self.param_efficient,
                train_mc_samples=self.mc_samples,
                test_mc_samples=self.mc_samples,
                logits_only=True,
                return_locs=self.return_locs,
                name='head')
        logits = output_layer(x)
        out['logits'] = logits

    if not train:
        if self.multiclass:
            logits = log_average_softmax_probs(
                jnp.asarray(
                    jnp.split(logits, self.transformer.get('ens_size'))))
            out['pre_ens_logits'] = out['pre_logits']
            out['pre_logits'] = log_average_softmax_probs(
                jnp.asarray(
                    jnp.split(out['pre_logits'],
                              self.transformer.get('ens_size'))))
        else:
            logits = log_average_sigmoid_probs(
                jnp.asarray(
                    jnp.split(logits, self.transformer.get('ens_size'))))
            out['pre_ens_logits'] = out['pre_logits']
            out['pre_logits'] = log_average_sigmoid_probs(
                jnp.asarray(
                    jnp.split(out['pre_logits'],
                              self.transformer.get('ens_size'))))

    return logits, out
def __call__(self, x, *, train=False):
    out = {}

    # Patch extraction.
    x = out['stem'] = nn.Conv(self.width,
                              self.patch_size,
                              strides=self.patch_size,
                              padding='VALID',
                              name='embedding')(x)

    n, h, w, c = x.shape
    x = jnp.reshape(x, [n, h * w, c])

    # Add posemb before adding extra token.
    x = out['with_posemb'] = x + get_posemb(self, self.posemb, (h, w), c,
                                            'pos_embedding', x.dtype)

    if self.pool_type == 'tok':
        cls = self.param('cls', nn.initializers.zeros, (1, 1, c), x.dtype)
        x = jnp.concatenate([jnp.tile(cls, [n, 1, 1]), x], axis=1)

    n, l, c = x.shape  # pylint: disable=unused-variable
    x = nn.Dropout(rate=self.dropout)(x, not train)

    # NB: `train=not train` is kept as in the source; if this Encoder's flag
    # means "training mode", the polarity may be inverted (big_vision's
    # Encoder takes `deterministic=not train` instead).
    x, out['encoder'] = Encoder(depth=self.depth,
                                mlp_dim=self.mlp_dim,
                                num_heads=self.num_heads,
                                dropout=self.dropout,
                                name='Transformer')(x, train=not train)
    encoded = out['encoded'] = x

    if self.pool_type == 'map':
        x = out['head_input'] = MAPHead(num_heads=self.num_heads,
                                        mlp_dim=self.mlp_dim)(x)
    elif self.pool_type == 'gap':
        x = out['head_input'] = jnp.mean(x, axis=1)
    elif self.pool_type == '0':
        x = out['head_input'] = x[:, 0]
    elif self.pool_type == 'tok':
        x = out['head_input'] = x[:, 0]
        encoded = encoded[:, 1:]
    else:
        raise ValueError(f'Unknown pool type: "{self.pool_type}"')

    x_2d = jnp.reshape(encoded, [n, h, w, -1])

    if self.rep_size:
        rep_size = self.width if self.rep_size is True else self.rep_size  # pylint: disable=g-bool-id-comparison
        hid = nn.Dense(rep_size, name='pre_logits')
        # NOTE: In the past we did not include tanh in pre_logits.
        # For few-shot, it should not matter much, as it whitens anyways.
        x_2d = nn.tanh(hid(x_2d))
        x = nn.tanh(hid(x))

    out['pre_logits_2d'] = x_2d
    out['pre_logits'] = x

    if self.num_classes:
        kw = {'kernel_init': nn.initializers.zeros} if self.head_zeroinit else {}
        head = nn.Dense(self.num_classes, name='head', **kw)
        x_2d = out['logits_2d'] = head(x_2d)
        x = out['logits'] = head(x)

    # TODO(dsuo): this used to be `return x, out`. Do we need out?
    return x