def apply(
    self,
    x,
    action_dim,
    max_action,
    key=None,
    MPO=False,
    sample=False,
    log_sig_min=-20,
    log_sig_max=2,
):
    """Gaussian policy network with optional tanh-squashed sampling.

    Trunk: Dense(200) -> LayerNorm -> tanh -> Dense(200) -> elu ->
    Dense(2 * action_dim). The last layer is split along the last axis
    into a mean and a log-std.

    Args:
      x: network input (presumably a batch of states -- confirm with caller).
      action_dim: number of action dimensions.
      max_action: scale applied to the tanh-squashed action.
      key: JAX PRNG key; only used when ``sample`` is true.
      MPO: if true, return the raw ``(mu, log_sig)`` with no squashing.
      sample: if true, draw a reparameterized sample and return it with its
        tanh-corrected log-likelihood.
      log_sig_min: lower clip bound for log_sig.
      log_sig_max: upper clip bound for log_sig.

    Returns:
      ``(mu, log_sig)`` if MPO; ``(max_action * tanh(mu), log_sig)`` if not
      sampling; otherwise ``(max_action * tanh(pi), log_pi)``.
    """
    hidden = nn.tanh(nn.LayerNorm(nn.Dense(x, features=200)))
    hidden = nn.elu(nn.Dense(hidden, features=200))
    head = nn.Dense(hidden, features=2 * action_dim)
    mu, log_sig = jnp.split(head, 2, axis=-1)

    # NOTE(review): softplus output is >= 0, so log_sig_min never binds and
    # the std exp(log_sig) is always >= 1. Confirm this floor is intended.
    log_sig = jnp.clip(nn.softplus(log_sig), log_sig_min, log_sig_max)

    if MPO:
        return mu, log_sig
    if not sample:
        return max_action * nn.tanh(mu), log_sig

    noise = random.normal(key, mu.shape)
    pre_tanh = mu + noise * jnp.exp(log_sig)
    log_pi = gaussian_likelihood(pre_tanh, mu, log_sig)
    pi = nn.tanh(pre_tanh)
    # Change-of-variables correction for the tanh squashing. relu keeps
    # rounding from producing a negative argument; 1e-6 avoids log(0).
    log_pi -= jnp.sum(jnp.log(nn.relu(1 - pi ** 2) + 1e-6), axis=1)
    return max_action * pi, log_pi
def apply(self, x):
    """Scalar-output MLP: four tanh hidden layers of width 50, then Dense(1).

    Args:
      x: network input.

    Returns:
      The output of the final Dense layer with a single feature.
    """
    # Modules are created in the same order as the unrolled version, so
    # auto-generated submodule names are unchanged.
    for _ in range(4):
        x = nn.tanh(nn.Dense(x, features=50))
    return nn.Dense(x, features=1)
def apply(self, x, action_dim, max_action):
    """Deterministic actor: two relu layers of width 256, tanh-bounded output.

    Args:
      x: network input.
      action_dim: number of output features.
      max_action: scale applied to the tanh output.

    Returns:
      ``max_action * tanh(Dense(action_dim)(hidden))``.
    """
    hidden = x
    for width in (256, 256):
        hidden = nn.relu(nn.Dense(hidden, features=width))
    raw_action = nn.Dense(hidden, features=action_dim)
    return max_action * nn.tanh(raw_action)
def apply(self, x, rep_size, m_layers, m_features, m_kernel_sizes,
          conv_rep_size, padding_mask=None):
    """Gated convolutional encoder followed by ``linear_max_pool``.

    Two parallel branches H (value) and G (gate) each start from a relu
    Dense projection of ``x`` and then pass through ``m_layers`` convolutions
    with (k, 1) kernels. Hidden layers use relu. The last layer uses tanh
    for H and sigmoid for G, with ``conv_rep_size`` features.

    Args:
      x: input features (presumably (batch, length, dim) -- confirm).
      rep_size: passed to ``linear_max_pool`` as ``rep_size``.
      m_layers: number of conv layers per branch.
      m_features: (H_features, G_features) pairs for layers 1..m_layers-1.
      m_kernel_sizes: (H_kernel, G_kernel) pairs for layers 1..m_layers.
      conv_rep_size: width of the initial projections and of the last layer.
      padding_mask: forwarded to ``linear_max_pool``.

    Returns:
      The result of ``linear_max_pool``.
    """
    h0 = nn.relu(nn.Dense(x, conv_rep_size))
    g0 = nn.relu(nn.Dense(x, conv_rep_size))

    # Insert a singleton axis 2 so the (k, 1) kernels convolve along axis 1.
    h = jnp.expand_dims(h0, axis=2)
    g = jnp.expand_dims(g0, axis=2)

    for idx in range(m_layers):
        is_last = idx == m_layers - 1
        if is_last:
            h_feat = g_feat = conv_rep_size
        else:
            h_feat, g_feat = m_features[idx]
        h_kernel, g_kernel = m_kernel_sizes[idx]

        # H conv is created before G conv each layer, as in the original.
        h = nn.Conv(h, features=h_feat, kernel_size=(h_kernel, 1))
        g = nn.Conv(g, features=g_feat, kernel_size=(g_kernel, 1))

        if is_last:
            h, g = nn.tanh(h), nn.sigmoid(g)
        else:
            h, g = nn.relu(h), nn.relu(g)

    h = jnp.squeeze(h, axis=2)
    g = jnp.squeeze(g, axis=2)
    # NOTE(review): the residual term is the gate branch's projection g0,
    # not h0 -- confirm this is intended.
    gated = h * g + g0
    return linear_max_pool(gated, padding_mask=padding_mask, rep_size=rep_size)
def apply(self, x, num_classes=1000, train=False, resnet=None, patches=None,
          hidden_size=None, transformer=None, representation_size=None,
          classifier='gap'):
    """Vision Transformer forward pass with a patch (space-to-depth) embedding.

    Args:
      x: input images, unpacked as (n, h, w, c).
      num_classes: number of outputs of the zero-initialized 'head' Dense.
      train: forwarded to ``Encoder``.
      resnet: accepted for interface compatibility; not used in this body.
      patches: config whose ``.size`` is the (fh, fw) patch size.
      hidden_size: if truthy, patches are embedded with a strided VALID conv
        of this width. Otherwise each patch is flattened by reshaping.
      transformer: kwargs for ``Encoder``; if None, no transformer is applied.
      representation_size: if not None, adds a Dense+tanh 'pre_logits' layer.
      classifier: 'token' prepends a learned class token and reads position 0.
        'gap' averages over all non-batch, non-feature axes.

    Returns:
      The output of the 'head' Dense layer.
    """
    n, h, w, c = x.shape

    # Embed the grid or patches of the grid.
    fh, fw = patches.size
    gh, gw = h // fh, w // fw

    if hidden_size:
        # We can merge s2d+emb into a single conv; it's the same.
        x = nn.Conv(
            x, hidden_size, (fh, fw), strides=(fh, fw), padding='VALID',
            name='embedding')
    else:
        # This path often results in excessive padding.
        # Drop trailing rows/columns that do not fill a whole patch, matching
        # the VALID conv branch above. Without this the reshape below fails
        # whenever h or w is not a multiple of the patch size. This is a
        # no-op when they divide evenly.
        x = x[:, :gh * fh, :gw * fw, :]
        x = jnp.reshape(x, [n, gh, fh, gw, fw, c])
        x = jnp.transpose(x, [0, 1, 3, 2, 4, 5])
        x = jnp.reshape(x, [n, gh, gw, -1])

    # Here, x is a grid of embeddings.

    # (Possibly partial) Transformer.
    if transformer is not None:
        n, h, w, c = x.shape
        x = jnp.reshape(x, [n, h * w, c])

        # If we want to add a class token, add it here.
        if classifier == 'token':
            cls = self.param('cls', (1, 1, c), nn.initializers.zeros)
            cls = jnp.tile(cls, [n, 1, 1])
            x = jnp.concatenate([cls, x], axis=1)

        x = Encoder(x, train=train, name='Transformer', **transformer)

    if classifier == 'token':
        x = x[:, 0]
    elif classifier == 'gap':
        x = jnp.mean(x, axis=list(range(1, x.ndim - 1)))  # (1,) or (1,2)

    if representation_size is not None:
        x = nn.Dense(x, representation_size, name='pre_logits')
        x = nn.tanh(x)
    else:
        x = IdentityLayer(x, name='pre_logits')

    x = nn.Dense(x, num_classes, name='head', kernel_init=nn.initializers.zeros)
    return x
def apply(self, state, action, Q1=False):
    """Twin Q critic over the concatenated (state, action) input.

    Each head is Dense(500) -> LayerNorm -> tanh -> Dense(500) -> elu ->
    Dense(1).

    Args:
      state: state batch, concatenated with ``action`` along axis 1.
      action: action batch.
      Q1: if true, build and return only the first head.

    Returns:
      ``q1`` if ``Q1`` else ``(q1, q2)``.
    """
    state_action = jnp.concatenate([state, action], axis=1)

    def q_head(inputs):
        # Module creation order matches the original unrolled code, so
        # auto-generated submodule names are unchanged.
        out = nn.tanh(nn.LayerNorm(nn.Dense(inputs, features=500)))
        out = nn.elu(nn.Dense(out, features=500))
        return nn.Dense(out, features=1)

    q1 = q_head(state_action)
    if Q1:
        return q1
    q2 = q_head(state_action)
    return q1, q2
def loss_fn(mlo, slo, actor):
    """MPO actor loss with separate KL constraints on the mean and the std.

    Reads from the enclosing scope (not shown here): state, actor_target,
    sampled_actions, weights, eps_mu, eps_sig, mu_lagrange_optimizer,
    sig_lagrange_optimizer, and the helpers gaussian_likelihood,
    kl_mvg_diag and lagrange_step.

    Args:
      mlo: optimizer state passed through ``lagrange_step`` for the mean
        constraint (presumably the mean Lagrange multiplier -- confirm).
      slo: same for the std constraint.
      actor: policy being optimized; called with MPO=True -> (mu, log_sig).

    Returns:
      ``(loss, (mlo, slo))`` where mlo/slo are the outputs of lagrange_step.
    """
    mu, log_sig = actor(state, MPO=True)
    sig = jnp.exp(log_sig)
    target_mu, target_log_sig = actor_target(state, MPO=True)
    target_sig = jnp.exp(target_log_sig)

    # Decoupled log-likelihood: the online std with the target mean, plus the
    # online mean with the target std. gaussian_likelihood is called with a
    # log-std elsewhere in this file (gaussian_likelihood(pi, mu, log_sig)),
    # so pass log-stds here rather than the exponentiated sig values.
    actor_log_prob = gaussian_likelihood(sampled_actions, target_mu, log_sig)
    actor_log_prob += gaussian_likelihood(sampled_actions, mu, target_log_sig)
    # NOTE(review): (0, 1) is the identity permutation for a 2-D array;
    # confirm whether a (1, 0) swap was intended.
    actor_log_prob = actor_log_prob.transpose((0, 1))

    # Squash BOTH means. The previous code assigned tanh(mu) to target_mu,
    # which made the mean KL below compare a distribution with itself (= 0).
    mu, target_mu = nn.tanh(mu), nn.tanh(target_mu)
    reg_mu = eps_mu - kl_mvg_diag(target_mu, target_sig, mu, target_sig).mean()
    reg_sig = eps_sig - kl_mvg_diag(target_mu, target_sig, target_mu, sig).mean()

    mlo = lagrange_step(mlo, reg_mu)
    slo = lagrange_step(slo, reg_sig)

    actor_loss = -(actor_log_prob[:, None] * weights).sum(axis=1).mean()
    # NOTE(review): the penalty terms read the enclosing-scope optimizers, not
    # the stepped mlo/slo above -- confirm this is intended.
    actor_loss -= mu_lagrange_optimizer.target() * reg_mu
    actor_loss -= sig_lagrange_optimizer.target() * reg_sig
    return actor_loss.mean(), (mlo, slo)
def apply(self,
          inputs: jnp.ndarray,
          hidden_size: int = None,
          output_size: int = None,
          output_bias: bool = False,
          dropout: float = None,
          train: bool = None):
    """Two-layer head: Dense('hidden') -> tanh -> optional dropout -> Dense('output').

    Args:
      inputs: input features; per the original note,
        <float32>[batch_size, seq_length, hidden_size].
      hidden_size: width of the 'hidden' Dense layer.
      output_size: width of the 'output' Dense layer.
      output_bias: whether the 'output' Dense layer has a bias.
      dropout: dropout rate applied after tanh when training. None (the
        default) or 0 disables dropout.
      train: enables dropout when truthy.

    Returns:
      The output of the 'output' Dense layer.
    """
    hidden = nn.Dense(inputs, hidden_size, name='hidden')
    hidden = nn.tanh(hidden)
    # Guard on the rate as well as on train. Previously train=True with the
    # default dropout=None passed rate=None into nn.dropout.
    if train and dropout:
        hidden = nn.dropout(hidden, rate=dropout)
    output = nn.Dense(hidden, output_size, bias=output_bias, name='output')
    return output
def apply(self, x):
    """Elementwise activation ``sigmoid(x) * sigmoid(-x) * tanh(x) / 0.15``.

    sigmoid(x) * sigmoid(-x) equals the derivative of the sigmoid. The
    constant 1 / 0.15 is presumably a normalizing scale -- confirm its origin.

    Args:
      x: input array.

    Returns:
      The activation applied elementwise to ``x``.
    """
    # Same left-to-right multiplication order as the original expression.
    sigmoid_slope = nn.sigmoid(x) * nn.sigmoid(-x)
    shaped = sigmoid_slope * nn.tanh(x)
    return shaped * (1 / 0.15)
def apply(self, x, num_classes=1000, train=False, resnet=None, patches=None,
          hidden_size=None, transformer=None, representation_size=None,
          classifier='gap'):
    """Vision Transformer forward pass with an optional ResNet root (hybrid).

    Args:
      x: input images (4-D; later unpacked as (n, h, w, c)).
      num_classes: number of outputs of the zero-initialized 'head' Dense.
      train: forwarded to ``Encoder``.
      resnet: if not None, config with ``width_factor`` and ``num_layers``
        used to build a ResNet root before the patch embedding.
      patches: config whose ``.size`` is the embedding conv kernel/stride.
      hidden_size: feature width of the 'embedding' conv.
      transformer: kwargs for ``Encoder``; if None, no transformer is applied.
      representation_size: if not None, adds a Dense+tanh 'pre_logits' layer.
      classifier: 'token' prepends a learned class token and reads position 0.
        'gap' averages over all non-batch, non-feature axes.

    Returns:
      The output of the 'head' Dense layer.
    """
    # (Possibly partial) ResNet root.
    if resnet is not None:
        width = int(64 * resnet.width_factor)

        # Root block.
        x = models_resnet.StdConv(
            x, width, (7, 7), (2, 2), bias=False, name='conv_root')
        x = nn.GroupNorm(x, name='gn_root')
        x = nn.relu(x)
        x = nn.max_pool(x, (3, 3), strides=(2, 2), padding='SAME')

        # ResNet stages; the width argument doubles for each later stage.
        x = models_resnet.ResNetStage(
            x, resnet.num_layers[0], width, first_stride=(1, 1), name='block1')
        for i, block_size in enumerate(resnet.num_layers[1:], 1):
            x = models_resnet.ResNetStage(
                x, block_size, width * 2**i, first_stride=(2, 2),
                name=f'block{i + 1}')
        # (A dead `n, h, w, c = x.shape` used to sit here; those names are
        # always reassigned before use.)

    # We can merge s2d+emb into a single conv; it's the same.
    x = nn.Conv(
        x, hidden_size, patches.size, strides=patches.size, padding='VALID',
        name='embedding')

    # Here, x is a grid of embeddings.

    # (Possibly partial) Transformer.
    if transformer is not None:
        n, h, w, c = x.shape
        x = jnp.reshape(x, [n, h * w, c])

        # If we want to add a class token, add it here.
        if classifier == 'token':
            cls = self.param('cls', (1, 1, c), nn.initializers.zeros)
            cls = jnp.tile(cls, [n, 1, 1])
            x = jnp.concatenate([cls, x], axis=1)

        x = Encoder(x, train=train, name='Transformer', **transformer)

    if classifier == 'token':
        x = x[:, 0]
    elif classifier == 'gap':
        x = jnp.mean(x, axis=list(range(1, x.ndim - 1)))  # (1,) or (1,2)

    if representation_size is not None:
        x = nn.Dense(x, representation_size, name='pre_logits')
        x = nn.tanh(x)
    else:
        x = IdentityLayer(x, name='pre_logits')

    x = nn.Dense(x, num_classes, name='head', kernel_init=nn.initializers.zeros)
    return x
def apply(self,
          x,
          num_classes=1,
          train=False,
          hidden_size=None,
          transformer=None,
          resnet_emb=None,
          representation_size=None):
    """Apply model on inputs.

    Args:
      x: the processed input patches and position annotations. Either 3-D
        (batch size, seq len, channel) or 4-D (batch size, crops, seq len,
        channel). The last three channels hold spatial positions, scale
        positions and input masks.
      num_classes: the number of output classes. 1 for single model.
      train: train or eval.
      hidden_size: the hidden dimension for patch embedding tokens.
      transformer: the model config for Transformer backbone.
      resnet_emb: the config for patch embedding w/ small resnet.
      representation_size: size of the last FC before prediction.

    Returns:
      Model prediction output. For 4-D input the crop axis is restored:
      (batch size, crops, num_classes).

    Raises:
      ValueError: if ``transformer`` is None or ``x`` is not 3-D or 4-D.
    """
    # Explicit checks instead of `assert`, which is stripped under `python -O`.
    if transformer is None:
        raise ValueError("transformer config must be provided.")
    # Either 3: (batch size, seq len, channel) or
    # 4: (batch size, crops, seq len, channel)
    if len(x.shape) not in (3, 4):
        raise ValueError(f"Expected a 3-D or 4-D input, got shape {x.shape}.")

    multi_crops_input = False
    if len(x.shape) == 4:
        # Fold crops into the batch axis; unfolded again after the head.
        multi_crops_input = True
        batch_size, num_crops, l, channel = x.shape
        x = jnp.reshape(x, [batch_size * num_crops, l, channel])

    # We concat (x, spatial_positions, scale_positions, input_masks)
    # when preprocessing.
    inputs_spatial_positions = x[:, :, -3]
    inputs_spatial_positions = inputs_spatial_positions.astype(jnp.int32)
    inputs_scale_positions = x[:, :, -2]
    inputs_scale_positions = inputs_scale_positions.astype(jnp.int32)
    inputs_masks = x[:, :, -1]
    inputs_masks = inputs_masks.astype(jnp.bool_)
    x = x[:, :, :-3]
    n, l, channel = x.shape

    if hidden_size:
        if resnet_emb:
            # channel = patch_size * patch_size * 3
            patch_size = int(np.sqrt(channel // 3))
            x = jnp.reshape(x, [-1, patch_size, patch_size, 3])
            x = resnet.StdConv(
                x, RESNET_TOKEN_DIM, (7, 7), (2, 2), bias=False,
                name="conv_root")
            x = nn.GroupNorm(x, name="gn_root")
            x = nn.relu(x)
            x = nn.max_pool(x, (3, 3), strides=(2, 2), padding="SAME")
            if resnet_emb.num_layers > 0:
                blocks, bottleneck = resnet.get_block_desc(resnet_emb.num_layers)
                if blocks:
                    x = resnet.ResNetStage(
                        x, blocks[0], RESNET_TOKEN_DIM, first_stride=(1, 1),
                        bottleneck=bottleneck, name="block1")
                    for i, block_size in enumerate(blocks[1:], 1):
                        x = resnet.ResNetStage(
                            x, block_size, RESNET_TOKEN_DIM * 2**i,
                            first_stride=(2, 2), bottleneck=bottleneck,
                            name=f"block{i + 1}")
            # Back to one feature vector per token.
            x = jnp.reshape(x, [n, l, -1])

        x = nn.Dense(x, hidden_size, name="embedding")

    # Here, x is a list of embeddings.
    x = utils.Encoder(
        x,
        inputs_spatial_positions,
        inputs_scale_positions,
        inputs_masks,
        train=train,
        name="Transformer",
        **transformer)
    # Read the first token (presumably a class token added by the Encoder --
    # confirm).
    x = x[:, 0]

    if representation_size:
        x = nn.Dense(x, representation_size, name="pre_logits")
        x = nn.tanh(x)
    else:
        x = resnet.IdentityLayer(x, name="pre_logits")

    x = nn.Dense(x, num_classes, name="head", kernel_init=nn.initializers.zeros)

    if multi_crops_input:
        _, channel = x.shape
        x = jnp.reshape(x, [batch_size, num_crops, channel])
    return x