def apply(
    self,
    x,
    action_dim,
    max_action,
    key=None,
    MPO=False,
    sample=False,
    log_sig_min=-20,
    log_sig_max=2,
):
  x = nn.Dense(x, features=200)
  x = nn.LayerNorm(x)
  x = nn.tanh(x)
  x = nn.Dense(x, features=200)
  x = nn.elu(x)
  x = nn.Dense(x, features=2 * action_dim)
  mu, log_sig = jnp.split(x, 2, axis=-1)
  log_sig = nn.softplus(log_sig)
  log_sig = jnp.clip(log_sig, log_sig_min, log_sig_max)

  if MPO:
    return mu, log_sig

  if not sample:
    return max_action * nn.tanh(mu), log_sig
  else:
    pi = mu + random.normal(key, mu.shape) * jnp.exp(log_sig)
    log_pi = gaussian_likelihood(pi, mu, log_sig)
    pi = nn.tanh(pi)
    # Correct the log-density for the tanh squashing of the action.
    log_pi -= jnp.sum(jnp.log(nn.relu(1 - pi ** 2) + 1e-6), axis=1)
    return max_action * pi, log_pi
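# The actor above calls a `gaussian_likelihood` helper that is not defined in
# this listing. A minimal sketch of what such a helper usually computes (the
# diagonal-Gaussian log-density reduced over the action dimension) is given
# below; the exact definition in the original codebase may differ.
def gaussian_likelihood(sample, mu, log_sig):
  # log N(sample | mu, diag(exp(log_sig))^2), summed over the last axis.
  return jnp.sum(
      -0.5 * ((sample - mu) / (jnp.exp(log_sig) + 1e-6)) ** 2
      - log_sig
      - 0.5 * jnp.log(2.0 * jnp.pi),
      axis=-1)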
def apply(self,
          inputs,
          mlp_dim,
          dtype=jnp.float32,
          out_dim=None,
          dropout_rate=0.1,
          deterministic=True,
          kernel_init=nn.initializers.xavier_uniform(),
          bias_init=nn.initializers.normal(stddev=1e-6)):
  """Applies Transformer MlpBlock module."""
  actual_out_dim = inputs.shape[-1] if out_dim is None else out_dim
  x = nn.Dense(
      inputs,
      mlp_dim,
      dtype=dtype,
      kernel_init=kernel_init,
      bias_init=bias_init)
  x = nn.gelu(x)
  x = nn.dropout(x, rate=dropout_rate, deterministic=deterministic)
  output = nn.Dense(
      x,
      actual_out_dim,
      dtype=dtype,
      kernel_init=kernel_init,
      bias_init=bias_init)
  output = nn.dropout(output, rate=dropout_rate, deterministic=deterministic)
  return output
def apply(self, x, num_actions):
  initializer = nn.initializers.xavier_uniform()
  # We need to add a "batch dimension" as nn.Conv expects it, yet vmap will
  # have removed the true batch dimension.
  x = x[None, ...]
  x = x.astype(jnp.float32) / 255.
  x = nn.Conv(x, features=32, kernel_size=(8, 8), strides=(4, 4),
              kernel_init=initializer)
  x = jax.nn.relu(x)
  x = nn.Conv(x, features=64, kernel_size=(4, 4), strides=(2, 2),
              kernel_init=initializer)
  x = jax.nn.relu(x)
  x = nn.Conv(x, features=64, kernel_size=(3, 3), strides=(1, 1),
              kernel_init=initializer)
  x = jax.nn.relu(x)
  x = x.reshape((x.shape[0], -1))  # flatten
  x = nn.Dense(x, features=512, kernel_init=initializer)
  x = jax.nn.relu(x)
  q_values = nn.Dense(x, features=num_actions, kernel_init=initializer)
  return atari_lib.DQNNetworkType(q_values)
def apply(self,
          inputs,
          mlp_dim,
          dtype=jnp.float32,
          out_dim=None,
          dropout_rate=0.1,
          deterministic=False,
          kernel_init=nn.initializers.xavier_uniform(),
          bias_init=nn.initializers.normal(stddev=1e-6),
          num_partitions=2):
  """Applies Transformer MlpBlock module."""
  actual_out_dim = inputs.shape[-1] if out_dim is None else out_dim
  inputs_shape = inputs.shape
  inputs = inputs.reshape((-1, inputs_shape[-1]))
  x = nn.Dense(inputs, mlp_dim, dtype=dtype, kernel_init=kernel_init,
               bias_init=bias_init)
  x = nn.relu(x)
  if num_partitions > 1:
    x = with_sharding_constraint(x, P(1, num_partitions))
  x = nn.dropout(x, rate=dropout_rate, deterministic=deterministic)
  output = nn.Dense(x, actual_out_dim, dtype=dtype, kernel_init=kernel_init,
                    bias_init=bias_init)
  output = nn.dropout(output, rate=dropout_rate, deterministic=deterministic)
  output = output.reshape(inputs_shape[:-1] + (actual_out_dim,))
  return output
def apply(self, x):
  x = nn.Dense(x, features=32)
  x = nn.sigmoid(x)
  x = nn.Dense(x, features=32)
  x = nn.sigmoid(x)
  x = nn.Dense(x, features=1)
  return nn.sigmoid(x)
def apply(self, x, rep_size, m_layers, m_features, m_kernel_sizes,
          conv_rep_size, padding_mask=None):
  H_0 = nn.relu(nn.Dense(x, conv_rep_size))
  G_0 = nn.relu(nn.Dense(x, conv_rep_size))
  H, G = jnp.expand_dims(H_0, axis=2), jnp.expand_dims(G_0, axis=2)
  for layer in range(1, m_layers + 1):
    if layer < m_layers:
      H_features, G_features = m_features[layer - 1]
    else:
      H_features, G_features = conv_rep_size, conv_rep_size
    H_kernel_size, G_kernel_size = m_kernel_sizes[layer - 1]
    H = nn.Conv(H, features=H_features, kernel_size=(H_kernel_size, 1))
    G = nn.Conv(G, features=G_features, kernel_size=(G_kernel_size, 1))
    if layer < m_layers:
      H = nn.relu(H)
      G = nn.relu(G)
    else:
      H = nn.tanh(H)
      G = nn.sigmoid(G)
  H, G = jnp.squeeze(H, axis=2), jnp.squeeze(G, axis=2)
  F = H * G + G_0
  rep = linear_max_pool(F, padding_mask=padding_mask, rep_size=rep_size)
  return rep
def apply(self, s, layers=[10], bias=False, actFun=[jax.nn.elu, ]):
  for l in range(len(actFun), len(layers) + 1):
    actFun.append(actFun[-1])

  s = 2 * s - 1
  for l, fun in zip(layers, actFun[:-1]):
    s = fun(
        nn.Dense(s,
                 features=l,
                 bias=bias,
                 dtype=global_defs.tReal,
                 kernel_init=jax.nn.initializers.lecun_normal(
                     dtype=global_defs.tReal),
                 bias_init=partial(jax.nn.initializers.zeros,
                                   dtype=global_defs.tReal)))

  return jnp.sum(actFun[-1](
      nn.Dense(s,
               features=1,
               bias=bias,
               dtype=global_defs.tReal,
               kernel_init=jax.nn.initializers.lecun_normal(
                   dtype=global_defs.tReal),
               bias_init=partial(jax.nn.initializers.zeros,
                                 dtype=global_defs.tReal))))
def classifier_head(encoded, num_classes, mlp_dim, pooling_mode='MEAN'):
  """Classifier head.

  We put this here just so that all models consistently call the same function.

  Args:
    encoded: tensor inputs are shape of [bs, len, dim].
    num_classes: int, number of classes.
    mlp_dim: int, dim of intermediate MLP.
    pooling_mode: str, string dictating pooling op {MEAN, SUM, FLATTEN, CLS}.

  Returns:
    tensor of shape [bs, num_classes]
  """
  if pooling_mode == 'MEAN':
    encoded = jnp.mean(encoded, axis=1)
  elif pooling_mode == 'SUM':
    encoded = jnp.sum(encoded, axis=1)
  elif pooling_mode == 'FLATTEN':
    encoded = encoded.reshape((encoded.shape[0], -1))
  elif pooling_mode == 'CLS':
    encoded = encoded[:, 0]
  else:
    raise NotImplementedError('Pooling not supported yet.')
  encoded = nn.Dense(encoded, mlp_dim, name='mlp')
  encoded = nn.relu(encoded)
  encoded = nn.Dense(encoded, num_classes, name='logits')
  return encoded
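# Example of how `classifier_head` might be invoked from a model's apply
# function. This is a sketch; the encoder call and shapes are illustrative
# rather than taken from the original codebase:
#
#   encoded = encoder(tokens)                       # [batch, seq_len, emb_dim]
#   logits = classifier_head(
#       encoded, num_classes=2, mlp_dim=256, pooling_mode='CLS')
#   # logits: [batch, 2]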
def apply(self, x, num_actions):
  initializer = nn.initializers.xavier_uniform()
  # We need to add a "batch dimension" as nn.Conv expects it, yet vmap will
  # have removed the true batch dimension.
  x = x[None, ...]
  x = x.astype(jnp.float32)
  x = x.reshape((x.shape[0], -1))  # flatten
  # x -= gym_lib.CARTPOLE_MIN_VALS
  # x /= gym_lib.CARTPOLE_MAX_VALS - gym_lib.CARTPOLE_MIN_VALS
  # x = 2.0 * x - 1.0  # Rescale in range [-1, 1].
  x = nn.Dense(x, features=512, kernel_init=initializer)
  x = jax.nn.relu(x)
  x = nn.Dense(x, features=512, kernel_init=initializer)
  x = jax.nn.relu(x)
  # Dueling heads: separate advantage and state-value streams.
  adv = nn.Dense(x, features=num_actions, kernel_init=initializer)
  val = nn.Dense(x, features=1, kernel_init=initializer)
  q_values = val + (adv - jnp.mean(adv, -1, keepdims=True))
  return atari_lib.DQNNetworkType(q_values)
def apply(self, x, action_dim, max_action):
  x = nn.Dense(x, features=256)
  x = nn.relu(x)
  x = nn.Dense(x, features=256)
  x = nn.relu(x)
  x = nn.Dense(x, features=action_dim)
  return max_action * nn.tanh(x)
def apply(self,
          hidden_states,
          *,
          d_ff: int,
          dropout_rate: float = 0.0,
          intermediate_activation=nn.gelu,
          # TODO(kitaev): chunk_size hparam for chunking
          kernel_init=nn.initializers.xavier_uniform(),
          deterministic: bool = False):
  """Applies FeedForward module."""
  d_model = hidden_states.shape[-1]
  hidden_states = nn.Dense(
      hidden_states, d_ff, kernel_init=kernel_init, name='intermediate')
  hidden_states = intermediate_activation(hidden_states)
  hidden_states = nn.Dense(
      hidden_states, d_model, kernel_init=kernel_init, name='output')
  hidden_states = nn.dropout(
      hidden_states, rate=dropout_rate, deterministic=deterministic)
  return hidden_states
def apply(self, x, num_actions, num_atoms):
  initializer = jax.nn.initializers.variance_scaling(
      scale=1.0 / jnp.sqrt(3.0), mode='fan_in', distribution='uniform')
  # We need to add a "batch dimension" as nn.Conv expects it, yet vmap will
  # have removed the true batch dimension.
  x = x[None, ...]
  x = x.astype(jnp.float32) / 255.
  x = nn.Conv(x, features=32, kernel_size=(8, 8), strides=(4, 4),
              kernel_init=initializer)
  x = jax.nn.relu(x)
  x = nn.Conv(x, features=64, kernel_size=(4, 4), strides=(2, 2),
              kernel_init=initializer)
  x = jax.nn.relu(x)
  x = nn.Conv(x, features=64, kernel_size=(3, 3), strides=(1, 1),
              kernel_init=initializer)
  x = jax.nn.relu(x)
  x = x.reshape((x.shape[0], -1))  # flatten
  x = nn.Dense(x, features=512, kernel_init=initializer)
  x = jax.nn.relu(x)
  x = nn.Dense(x, features=num_actions * num_atoms, kernel_init=initializer)
  logits = x.reshape((x.shape[0], num_actions, num_atoms))
  probabilities = nn.softmax(logits)
  q_values = jnp.mean(logits, axis=2)
  return atari_lib.RainbowNetworkType(q_values, logits, probabilities)
def apply(self, x, hidden_layers, hidden_dim, n_classes):
  x = jnp.reshape(x, (x.shape[0], -1))
  for layer in range(hidden_layers):
    x = nn.Dense(x, hidden_dim, name=f'fc{layer}')
    x = nn.relu(x)
  x = nn.Dense(x, n_classes, name=f'fc{hidden_layers}')
  preds = nn.log_softmax(x)
  return preds
def apply(self, x, reduction=16):
  num_channels = x.shape[-1]
  y = x.mean(axis=(1, 2))
  y = nn.Dense(y, features=num_channels // reduction, bias=False)
  y = nn.relu(y)
  y = nn.Dense(y, features=num_channels, bias=False)
  y = nn.sigmoid(y)
  return x * y[:, None, None, :]
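# The module above is a squeeze-and-excitation (SE) block: it averages each
# channel over the spatial dimensions, passes the result through a reduction
# bottleneck, and rescales the input channel-wise. A hypothetical use inside a
# convolutional apply function (the module name `SqueezeExcite` is assumed for
# illustration):
#
#   x = nn.Conv(x, features=64, kernel_size=(3, 3))
#   x = nn.relu(x)
#   x = SqueezeExcite(x, reduction=16)   # x keeps its [N, H, W, 64] shape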
def apply(self,
          x,
          num_classes=1000,
          train=False,
          resnet=None,
          patches=None,
          hidden_size=None,
          transformer=None,
          representation_size=None,
          classifier='gap'):
  n, h, w, c = x.shape

  # Embed the grid or patches of the grid.
  fh, fw = patches.size
  gh, gw = h // fh, w // fw
  if hidden_size:
    # We can merge s2d+emb into a single conv; it's the same.
    x = nn.Conv(
        x,
        hidden_size, (fh, fw),
        strides=(fh, fw),
        padding='VALID',
        name='embedding')
  else:
    # This path often results in excessive padding.
    x = jnp.reshape(x, [n, gh, fh, gw, fw, c])
    x = jnp.transpose(x, [0, 1, 3, 2, 4, 5])
    x = jnp.reshape(x, [n, gh, gw, -1])

  # Here, x is a grid of embeddings.

  # (Possibly partial) Transformer.
  if transformer is not None:
    n, h, w, c = x.shape
    x = jnp.reshape(x, [n, h * w, c])

    # If we want to add a class token, add it here.
    if classifier == 'token':
      cls = self.param('cls', (1, 1, c), nn.initializers.zeros)
      cls = jnp.tile(cls, [n, 1, 1])
      x = jnp.concatenate([cls, x], axis=1)

    x = Encoder(x, train=train, name='Transformer', **transformer)

  if classifier == 'token':
    x = x[:, 0]
  elif classifier == 'gap':
    x = jnp.mean(x, axis=list(range(1, x.ndim - 1)))  # (1,) or (1,2)

  if representation_size is not None:
    x = nn.Dense(x, representation_size, name='pre_logits')
    x = nn.tanh(x)
  else:
    x = IdentityLayer(x, name='pre_logits')

  x = nn.Dense(x, num_classes, name='head', kernel_init=nn.initializers.zeros)
  return x
def apply(self, x):
  real = nn.Dense(x, 25)
  real = jnp.sin(real)
  real = nn.Dense(real, 2)

  imag = nn.Dense(x, 25)
  imag = jnp.sin(imag)
  imag = nn.Dense(imag, 2)
  imag = jnp.pi * nn.soft_sign(imag)

  return real * jnp.exp(1j * imag)
def apply(self, x):
  net = nn.Dense(x, 500, name='fc1')
  net = nn.leaky_relu(net)
  net = nn.BatchNorm(net)
  net = nn.Dense(net, 500, name='fc2')
  net = nn.leaky_relu(net)
  net = nn.BatchNorm(net)
  net = nn.Dense(net, 500, name='fc3')
  net = nn.leaky_relu(net)
  net = nn.BatchNorm(net)
  return nn.softmax(nn.Dense(net, n_bin))
def classifier(x, num_outputs, dropout_rate, deterministic):
  """Implements the classification portion of the network."""
  x = nn.dropout(x, rate=dropout_rate, deterministic=deterministic)
  x = nn.Dense(x, 512)
  x = nn.relu(x)
  x = nn.dropout(x, rate=dropout_rate, deterministic=deterministic)
  x = nn.Dense(x, 512)
  x = nn.relu(x)
  x = nn.Dense(x, num_outputs)
  return x
def apply(self, x, num_actions):
  initializer = nn.initializers.xavier_uniform()
  x = x[None, ...]
  x = x.astype(jnp.float32)
  x = x.reshape((x.shape[0], -1))  # flatten
  x = nn.Dense(x, features=512, kernel_init=initializer)
  x = jax.nn.relu(x)
  x = nn.Dense(x, features=512, kernel_init=initializer)
  x = jax.nn.relu(x)
  q_values = nn.Dense(x, features=num_actions, kernel_init=initializer)
  return atari_lib.DQNNetworkType(q_values)
def apply(self, x, L=10, hiddenSize=10, inputDim=2, actFun=nn.elu,
          initScale=1.0, logProbFactor=0.5):
  rnnCell = RNNCell.shared(hiddenSize=hiddenSize, outDim=hiddenSize,
                           actFun=actFun, initScale=initScale, name="myCell")
  probDense = nn.Dense.shared(
      features=inputDim,
      name="probDense",
      dtype=global_defs.tReal,
      kernel_init=jax.nn.initializers.lecun_normal(dtype=global_defs.tReal),
      bias_init=partial(jax.nn.initializers.zeros, dtype=global_defs.tReal))

  state = jnp.zeros((hiddenSize,))

  def rnn_cell(carry, x):
    newCarry, out = rnnCell(carry[0], carry[1])
    logProb = nn.log_softmax(actFun(probDense(out)))
    logProb = jnp.sum(logProb * x, axis=-1)
    return (newCarry, x), (jnp.nan_to_num(logProb, nan=-35), out)

  _, (probs, phaseOut) = jax.lax.scan(rnn_cell,
                                      (state, jnp.zeros(inputDim)),
                                      jax.nn.one_hot(x, inputDim))

  phase = nn.Dense(phaseOut,
                   features=6,
                   dtype=global_defs.tReal,
                   kernel_init=jax.nn.initializers.lecun_normal(
                       dtype=global_defs.tReal),
                   bias_init=partial(jax.nn.initializers.zeros,
                                     dtype=global_defs.tReal))
  phase = actFun(phase)
  phase = nn.Dense(phaseOut,
                   features=4,
                   dtype=global_defs.tReal,
                   kernel_init=jax.nn.initializers.lecun_normal(
                       dtype=global_defs.tReal),
                   bias_init=partial(jax.nn.initializers.zeros,
                                     dtype=global_defs.tReal))

  return logProbFactor * jnp.sum(probs, axis=0) + 1.j * jnp.mean(
      actFun(phase))
def apply(self, x):
  x = nn.Dense(x, features=50)
  x = nn.tanh(x)
  x = nn.Dense(x, features=50)
  x = nn.tanh(x)
  x = nn.Dense(x, features=50)
  x = nn.tanh(x)
  x = nn.Dense(x, features=50)
  x = nn.tanh(x)
  x = nn.Dense(x, features=1)
  return x
def apply(self, actions, num_layers, hidden_dims):
  timesteps = actions.shape[1]
  # flatten time into batch
  actions = jnp.reshape(actions, (-1,) + actions.shape[2:])
  # embed actions
  x = nn.Dense(actions, hidden_dims)
  for _ in range(num_layers):
    x = nn.Dense(x, hidden_dims)
    x = nn.LayerNorm(x)
    x = nn.relu(x)
  x = nn.Dense(x, 1)
  x = jnp.reshape(x, (-1, timesteps, 1))
  return x
def apply(self, x):
  x = nn.Conv(x, features=32, kernel_size=(3, 3))
  x = nn.relu(x)
  x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
  x = nn.Conv(x, features=64, kernel_size=(3, 3))
  x = nn.relu(x)
  x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
  x = x.reshape((x.shape[0], -1))  # flatten
  x = nn.Dense(x, features=256)
  x = nn.relu(x)
  x = nn.Dense(x, features=10)
  x = nn.log_softmax(x)
  return x
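# A sketch of how a module built from the apply function above might be
# initialized and evaluated under the pre-Linen flax.nn API these snippets use.
# The class name `CNN` and the MNIST-like input shape are assumptions for
# illustration only:
#
#   rng = jax.random.PRNGKey(0)
#   _, initial_params = CNN.init_by_shape(rng, [((1, 28, 28, 1), jnp.float32)])
#   model = nn.Model(CNN, initial_params)          # bundles module and params
#   log_probs = model(jnp.zeros((8, 28, 28, 1)))   # [8, 10] log-probabilities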
def apply(self, x, num_actions):
  initializer = nn.initializers.xavier_uniform()
  x = x[None, ...]
  x = x.astype(jnp.float32)
  x = x.reshape((x.shape[0], -1))  # flatten
  x -= gym_lib.ACROBOT_MIN_VALS
  x /= gym_lib.ACROBOT_MAX_VALS - gym_lib.ACROBOT_MIN_VALS
  x = 2.0 * x - 1.0  # Rescale in range [-1, 1].
  x = nn.Dense(x, features=512, kernel_init=initializer)
  x = jax.nn.relu(x)
  x = nn.Dense(x, features=512, kernel_init=initializer)
  x = jax.nn.relu(x)
  q_values = nn.Dense(x, features=num_actions, kernel_init=initializer)
  return atari_lib.DQNNetworkType(q_values)
def apply(self,
          inputs: jnp.ndarray,
          hidden_size: int = None,
          output_size: int = None,
          output_bias: bool = False,
          dropout: float = None,
          train: bool = None):
  # inputs.shape = <float32>[batch_size, seq_length, hidden_size]
  hidden = nn.Dense(inputs, hidden_size, name='hidden')
  hidden = nn.tanh(hidden)
  if train:
    hidden = nn.dropout(hidden, rate=dropout)
  output = nn.Dense(hidden, output_size, bias=output_bias, name='output')
  return output
def classifier_head_dual(encoded1,
                         encoded2,
                         num_classes,
                         mlp_dim,
                         pooling_mode='MEAN',
                         interaction=None):
  """Classifier head for dual encoding or pairwise problem.

  We put this here just so that all models consistently call the same function.

  Args:
    encoded1: tensor inputs are shape of [bs, len, dim].
    encoded2: tensor inputs are shape of [bs, len, dim].
    num_classes: int, number of classes.
    mlp_dim: int, dim of intermediate MLP.
    pooling_mode: str, string dictating pooling op {MEAN, SUM, FLATTEN, CLS}.
    interaction: str, string dictating interaction between e1, e2.

  Returns:
    tensor of shape [bs, num_classes]
  """
  if pooling_mode == 'MEAN':
    encoded1 = jnp.mean(encoded1, axis=1)
    encoded2 = jnp.mean(encoded2, axis=1)
  elif pooling_mode == 'SUM':
    encoded1 = jnp.sum(encoded1, axis=1)
    encoded2 = jnp.sum(encoded2, axis=1)
  elif pooling_mode == 'FLATTEN':
    encoded1 = encoded1.reshape((encoded1.shape[0], -1))
    encoded2 = encoded2.reshape((encoded2.shape[0], -1))
  elif pooling_mode == 'CLS':
    encoded1 = encoded1[:, 0]
    encoded2 = encoded2[:, 0]
  else:
    raise NotImplementedError('Pooling not supported yet.')
  if interaction == 'NLI':
    # NLI interaction style
    encoded = jnp.concatenate(
        [encoded1, encoded2, encoded1 * encoded2, encoded1 - encoded2], 1)
  else:
    encoded = jnp.concatenate([encoded1, encoded2], 1)
  encoded = nn.Dense(encoded, mlp_dim, name='mlp')
  encoded = nn.relu(encoded)
  encoded = nn.Dense(encoded, int(mlp_dim // 2), name='mlp2')
  encoded = nn.relu(encoded)
  encoded = nn.Dense(encoded, num_classes, name='logits')
  return encoded
def apply(self, x, num_actions, quantile_embedding_dim, num_quantiles, rng):
  initializer = jax.nn.initializers.variance_scaling(
      scale=1.0 / jnp.sqrt(3.0), mode='fan_in', distribution='uniform')
  # We need to add a "batch dimension" as nn.Conv expects it, yet vmap will
  # have removed the true batch dimension.
  x = x[None, ...]
  x = x.astype(jnp.float32) / 255.
  x = nn.Conv(x, features=32, kernel_size=(8, 8), strides=(4, 4),
              kernel_init=initializer)
  x = jax.nn.relu(x)
  x = nn.Conv(x, features=64, kernel_size=(4, 4), strides=(2, 2),
              kernel_init=initializer)
  x = jax.nn.relu(x)
  x = nn.Conv(x, features=64, kernel_size=(3, 3), strides=(1, 1),
              kernel_init=initializer)
  x = jax.nn.relu(x)
  x = x.reshape((x.shape[0], -1))  # flatten
  state_vector_length = x.shape[-1]
  state_net_tiled = jnp.tile(x, [num_quantiles, 1])
  quantiles_shape = [num_quantiles, 1]
  quantiles = jax.random.uniform(rng, shape=quantiles_shape)
  quantile_net = jnp.tile(quantiles, [1, quantile_embedding_dim])
  quantile_net = (
      jnp.arange(1, quantile_embedding_dim + 1, 1).astype(jnp.float32)
      * onp.pi * quantile_net)
  quantile_net = jnp.cos(quantile_net)
  quantile_net = nn.Dense(quantile_net, features=state_vector_length,
                          kernel_init=initializer)
  quantile_net = jax.nn.relu(quantile_net)
  x = state_net_tiled * quantile_net
  x = nn.Dense(x, features=512, kernel_init=initializer)
  x = jax.nn.relu(x)
  quantile_values = nn.Dense(x, features=num_actions, kernel_init=initializer)
  return atari_lib.ImplicitQuantileNetworkType(quantile_values, quantiles)
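# Downstream, an IQN-style agent typically averages the returned quantile
# values over the quantile dimension to form Q-value estimates. A sketch,
# assuming `quantile_values` has shape [num_quantiles, num_actions]:
#
#   q_values = jnp.mean(quantile_values, axis=0)   # [num_actions]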
def apply(self, x, act, normalize, temb=None, out_ch=None, conv_shortcut=False,
          dropout=0.1, train=True, skip_rescale=False, init_scale=0.):
  B, H, W, C = x.shape
  out_ch = out_ch if out_ch else C

  h = act(normalize(x, num_groups=min(x.shape[-1] // 4, 32)))
  h = conv3x3(h, out_ch)
  # Add bias to each feature map conditioned on the time embedding
  if temb is not None:
    h += nn.Dense(act(temb), out_ch,
                  kernel_init=default_init())[:, None, None, :]

  h = act(normalize(h, num_groups=min(h.shape[-1] // 4, 32)))
  h = nn.dropout(h, dropout, deterministic=not train)
  h = conv3x3(h, out_ch, init_scale=init_scale)

  if C != out_ch:
    if conv_shortcut:
      x = conv3x3(x, out_ch)
    else:
      x = NIN(x, out_ch)

  if not skip_rescale:
    return x + h
  else:
    return (x + h) / np.sqrt(2.)
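# The residual block above relies on helpers (`conv3x3`, `NIN`, `default_init`)
# defined elsewhere in the original codebase. Below is a rough sketch of what
# `default_init` and `conv3x3` commonly look like in score-model code of this
# style; treat both as assumptions rather than the original definitions.
def default_init(scale=1.):
  # Variance-scaling init; a scale of exactly 0 is mapped to a tiny value.
  scale = 1e-10 if scale == 0 else scale
  return jax.nn.initializers.variance_scaling(scale, 'fan_avg', 'uniform')


def conv3x3(x, out_planes, stride=1, bias=True, init_scale=1.):
  # 3x3 'SAME' convolution whose kernel init is scaled by `init_scale`.
  return nn.Conv(x, features=out_planes, kernel_size=(3, 3),
                 strides=(stride, stride), padding='SAME', bias=bias,
                 kernel_init=default_init(init_scale))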
def apply(self, x, num_actions):
  initializer = nn.initializers.xavier_uniform()
  # We need to add a "batch dimension" as nn.Conv expects it, yet vmap will
  # have removed the true batch dimension.
  x = x[None, ...]
  x = x.astype(jnp.float32)
  x = x.reshape((x.shape[0], -1))  # flatten
  x -= gym_lib.CARTPOLE_MIN_VALS
  x /= gym_lib.CARTPOLE_MAX_VALS - gym_lib.CARTPOLE_MIN_VALS
  x = 2.0 * x - 1.0  # Rescale in range [-1, 1].
  x = nn.Dense(x, features=512, kernel_init=initializer)
  x = jax.nn.relu(x)
  x = nn.Dense(x, features=512, kernel_init=initializer)
  x = jax.nn.relu(x)
  q_values = nn.Dense(x, features=num_actions, kernel_init=initializer)
  return atari_lib.DQNNetworkType(q_values)
def apply(self, x, blocks_per_group, channel_multiplier, num_outputs,
          dropout_rate=0.0, train=True):
  x = nn.Conv(x, 16, (3, 3), padding='SAME', name='init_conv')
  x = WideResnetGroup(x, blocks_per_group, 16 * channel_multiplier,
                      dropout_rate=dropout_rate, train=train)
  x = WideResnetGroup(x, blocks_per_group, 32 * channel_multiplier, (2, 2),
                      dropout_rate=dropout_rate, train=train)
  x = WideResnetGroup(x, blocks_per_group, 64 * channel_multiplier, (2, 2),
                      dropout_rate=dropout_rate, train=train)
  x = nn.BatchNorm(x, use_running_average=not train, momentum=0.9,
                   epsilon=1e-5)
  x = jax.nn.relu(x)
  x = nn.avg_pool(x, (8, 8))
  x = x.reshape((x.shape[0], -1))
  x = nn.Dense(x, num_outputs)
  return x