def apply(self, x, num_filters=64, block_sizes=(3, 4, 6, 3), train=True,
          block=BottleneckBlock, small_inputs=False):
  if small_inputs:
    x = nn.Conv(x, num_filters, kernel_size=(3, 3), strides=(1, 1),
                bias=False, name="init_conv")
  else:
    x = nn.Conv(x, num_filters, kernel_size=(7, 7), strides=(2, 2),
                bias=False, name="init_conv")
  x = nn.BatchNorm(x, use_running_average=not train, epsilon=1e-5,
                   name="init_bn")
  if not small_inputs:
    x = nn.max_pool(x, (3, 3), strides=(2, 2), padding="SAME")
  for i, block_size in enumerate(block_sizes):
    for j in range(block_size):
      strides = (2, 2) if i > 0 and j == 0 else (1, 1)
      x = block(x, num_filters * 2**i, strides=strides, train=train)
  return x

def apply(self, x, num_actions, num_atoms):
  initializer = jax.nn.initializers.variance_scaling(
      scale=1.0 / jnp.sqrt(3.0), mode='fan_in', distribution='uniform')
  # We need to add a "batch dimension" as nn.Conv expects it, yet vmap will
  # have removed the true batch dimension.
  x = x[None, ...]
  x = x.astype(jnp.float32) / 255.
  x = nn.Conv(x, features=32, kernel_size=(8, 8), strides=(4, 4),
              kernel_init=initializer)
  x = jax.nn.relu(x)
  x = nn.Conv(x, features=64, kernel_size=(4, 4), strides=(2, 2),
              kernel_init=initializer)
  x = jax.nn.relu(x)
  x = nn.Conv(x, features=64, kernel_size=(3, 3), strides=(1, 1),
              kernel_init=initializer)
  x = jax.nn.relu(x)
  x = x.reshape((x.shape[0], -1))  # flatten
  x = nn.Dense(x, features=512, kernel_init=initializer)
  x = jax.nn.relu(x)
  x = nn.Dense(x, features=num_actions * num_atoms, kernel_init=initializer)
  logits = x.reshape((x.shape[0], num_actions, num_atoms))
  probabilities = nn.softmax(logits)
  q_values = jnp.mean(logits, axis=2)
  return atari_lib.RainbowNetworkType(q_values, logits, probabilities)

def apply(self, x, rep_size, m_layers, m_features, m_kernel_sizes,
          conv_rep_size, padding_mask=None):
  H_0 = nn.relu(nn.Dense(x, conv_rep_size))
  G_0 = nn.relu(nn.Dense(x, conv_rep_size))
  H, G = jnp.expand_dims(H_0, axis=2), jnp.expand_dims(G_0, axis=2)
  for layer in range(1, m_layers + 1):
    if layer < m_layers:
      H_features, G_features = m_features[layer - 1]
    else:
      H_features, G_features = conv_rep_size, conv_rep_size
    H_kernel_size, G_kernel_size = m_kernel_sizes[layer - 1]
    H = nn.Conv(H, features=H_features, kernel_size=(H_kernel_size, 1))
    G = nn.Conv(G, features=G_features, kernel_size=(G_kernel_size, 1))
    if layer < m_layers:
      H = nn.relu(H)
      G = nn.relu(G)
    else:
      H = nn.tanh(H)
      G = nn.sigmoid(G)
  H, G = jnp.squeeze(H, axis=2), jnp.squeeze(G, axis=2)
  F = H * G + G_0
  rep = linear_max_pool(F, padding_mask=padding_mask, rep_size=rep_size)
  return rep

def apply(self, x, num_actions):
  initializer = nn.initializers.xavier_uniform()
  # We need to add a "batch dimension" as nn.Conv expects it, yet vmap will
  # have removed the true batch dimension.
  x = x[None, ...]
  x = x.astype(jnp.float32) / 255.
  x = nn.Conv(x, features=32, kernel_size=(8, 8), strides=(4, 4),
              kernel_init=initializer)
  x = jax.nn.relu(x)
  x = nn.Conv(x, features=64, kernel_size=(4, 4), strides=(2, 2),
              kernel_init=initializer)
  x = jax.nn.relu(x)
  x = nn.Conv(x, features=64, kernel_size=(3, 3), strides=(1, 1),
              kernel_init=initializer)
  x = jax.nn.relu(x)
  x = x.reshape((x.shape[0], -1))  # flatten
  x = nn.Dense(x, features=512, kernel_init=initializer)
  x = jax.nn.relu(x)
  q_values = nn.Dense(x, features=num_actions, kernel_init=initializer)
  return atari_lib.DQNNetworkType(q_values)

def apply(self, x, filters, strides=(1, 1), groups=1, base_width=64,
          train=True):
  needs_projection = x.shape[-1] != filters * 4 or strides != (1, 1)
  width = int(filters * (base_width / 64.)) * groups
  batch_norm = nn.BatchNorm.partial(use_running_average=not train,
                                    momentum=0.9, epsilon=1e-5)
  y = nn.Conv(x, width, (1, 1), (1, 1), bias=False, name='conv1')
  y = batch_norm(y, name='bn1')
  y = jax.nn.relu(y)
  y = nn.Conv(y, width, (3, 3), strides, bias=False,
              feature_group_count=groups, name='conv2')
  y = batch_norm(y, name='bn2')
  y = jax.nn.relu(y)
  y = nn.Conv(y, filters * 4, (1, 1), (1, 1), bias=False, name='conv3')
  y = batch_norm(y, name='bn3', scale_init=initializers.zeros)
  if needs_projection:
    x = nn.Conv(x, filters * 4, (1, 1), strides, bias=False, name='proj_conv')
    x = batch_norm(x, name='proj_bn')
  return jax.nn.relu(x + y)

def apply(self, x, channels, strides=(1, 1), train=True):
  batch_norm = nn.BatchNorm.partial(use_running_average=not train,
                                    momentum=0.9, epsilon=1e-5)
  a = b = residual = x

  a = jax.nn.relu(a)
  a = nn.Conv(a, channels, (3, 3), strides, padding='SAME', name='conv_a_1')
  a = batch_norm(a, name='bn_a_1')
  a = jax.nn.relu(a)
  a = nn.Conv(a, channels, (3, 3), padding='SAME', name='conv_a_2')
  a = batch_norm(a, name='bn_a_2')

  b = jax.nn.relu(b)
  b = nn.Conv(b, channels, (3, 3), strides, padding='SAME', name='conv_b_1')
  b = batch_norm(b, name='bn_b_1')
  b = jax.nn.relu(b)
  b = nn.Conv(b, channels, (3, 3), padding='SAME', name='conv_b_2')
  b = batch_norm(b, name='bn_b_2')

  if train and not self.is_initializing():
    ab = utils.shake_shake_train(a, b)
  else:
    ab = utils.shake_shake_eval(a, b)

  # Apply an up projection in case of channel mismatch
  if (residual.shape[-1] != channels) or strides != (1, 1):
    residual = nn.Conv(residual, channels, (3, 3), strides, padding='SAME',
                       name='conv_residual')
    residual = batch_norm(residual, name='bn_residual')

  return residual + ab

def apply(self, x, channels, strides, prob, alpha_min, alpha_max, beta_min,
          beta_max, train=True):
  """Implements the forward pass in the module.

  Args:
    x: Input to the module. Should have shape [batch_size, dim, dim, features]
      where dim is the resolution (width and height if the input is an image).
    channels: How many channels to use in the convolutional layers.
    strides: Strides for the pooling.
    prob: Probability of dropping the block (see paper for details).
    alpha_min: See paper.
    alpha_max: See paper.
    beta_min: See paper.
    beta_max: See paper.
    train: If False, will use the moving average for batch norm statistics.
      Else, will use statistics computed on the batch.

  Returns:
    The output of the bottleneck block.
  """
  y = utils.activation(x, apply_relu=False, train=train, name='bn_1_pre')
  y = nn.Conv(y, channels, (1, 1), padding='SAME', bias=False,
              kernel_init=utils.conv_kernel_init_fn, name='1x1_conv_contract')
  y = utils.activation(y, train=train, name='bn_1_post')
  y = nn.Conv(y, channels, (3, 3), strides, padding='SAME', bias=False,
              kernel_init=utils.conv_kernel_init_fn, name='3x3')
  y = utils.activation(y, train=train, name='bn_2')
  y = nn.Conv(y, channels * 4, (1, 1), padding='SAME', bias=False,
              kernel_init=utils.conv_kernel_init_fn, name='1x1_conv_expand')
  y = utils.activation(y, apply_relu=False, train=train, name='bn_3')
  if train:
    y = utils.shake_drop_train(y, prob, alpha_min, alpha_max,
                               beta_min, beta_max)
  else:
    y = utils.shake_drop_eval(y, prob, alpha_min, alpha_max)
  x = _shortcut(x, channels * 4, strides)
  return x + y

def apply(self, x):
  x = nn.Conv(x, features=32, kernel_size=(3, 3))
  x = nn.relu(x)
  x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
  x = nn.Conv(x, features=64, kernel_size=(3, 3))
  x = nn.relu(x)
  x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
  x = x.reshape((x.shape[0], -1))  # flatten
  x = nn.Dense(x, features=256)
  x = nn.relu(x)
  x = nn.Dense(x, features=10)
  x = nn.log_softmax(x)
  return x

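# Hedged usage sketch: the apply methods in this file follow the deprecated
# pre-Linen flax.nn API, where a module is a class with an `apply` method and
# submodules are applied functionally (nn.Conv(x, ...)). The class name
# TinyCNN, the PRNG seed and the MNIST-like input shape below are illustrative
# assumptions, not taken from the original code.
import jax
import jax.numpy as jnp
from flax import nn


class TinyCNN(nn.Module):
  """Minimal stand-in module used only to demonstrate init/apply."""

  def apply(self, x):
    x = nn.Conv(x, features=8, kernel_size=(3, 3))
    x = nn.relu(x)
    x = x.reshape((x.shape[0], -1))
    return nn.Dense(x, features=10)


rng = jax.random.PRNGKey(0)
# init_by_shape builds the parameters from an input spec without real data.
_, initial_params = TinyCNN.init_by_shape(rng, [((1, 28, 28, 1), jnp.float32)])
model = nn.Model(TinyCNN, initial_params)
logits = model(jnp.ones((1, 28, 28, 1)))  # -> shape (1, 10)
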
def apply(self, x, channels, strides=(1, 1), train=True):
  """Implements the forward pass in the module.

  Args:
    x: Input to the module. Should have shape [batch_size, dim, dim, features]
      where dim is the resolution (width and height if the input is an image).
    channels: How many channels to use in the convolutional layers.
    strides: Strides for the pooling.
    train: If False, will use the moving average for batch norm statistics.

  Returns:
    The output of the resnet block. Will have shape
    [batch_size, dim, dim, channels] if strides = (1, 1) or
    [batch_size, dim/2, dim/2, channels] if strides = (2, 2).
  """
  if x.shape[-1] == channels:
    return x

  # Skip path 1
  h1 = nn.avg_pool(x, (1, 1), strides=strides, padding='VALID')
  h1 = nn.Conv(h1, channels // 2, (1, 1), strides=(1, 1), padding='SAME',
               bias=False, kernel_init=utils.conv_kernel_init_fn,
               name='conv_h1')

  # Skip path 2
  # The next two lines offset the "image" by one pixel on the right and one
  # down (see Shake-Shake regularization, Xavier Gastaldi for details)
  pad_arr = [[0, 0], [0, 1], [0, 1], [0, 0]]
  h2 = jnp.pad(x, pad_arr)[:, 1:, 1:, :]
  h2 = nn.avg_pool(h2, (1, 1), strides=strides, padding='VALID')
  h2 = nn.Conv(h2, channels // 2, (1, 1), strides=(1, 1), padding='SAME',
               bias=False, kernel_init=utils.conv_kernel_init_fn,
               name='conv_h2')
  merged_branches = jnp.concatenate([h1, h2], axis=3)
  return utils.activation(merged_branches, apply_relu=False, train=train,
                          name='bn_residual')

def apply(self, x, use_squeeze_excite=False):
  x = nn.Conv(x, features=8, kernel_size=(3, 3), padding="VALID")
  x = nn.relu(x)
  x = nn.Conv(x, features=16, kernel_size=(3, 3), padding="VALID")
  x = nn.relu(x)
  if use_squeeze_excite:
    x = SqueezeExciteLayer(x)
  x = nn.Conv(x, features=32, kernel_size=(3, 3), padding="VALID")
  x = nn.relu(x)
  if use_squeeze_excite:
    x = SqueezeExciteLayer(x)
  x = nn.Conv(x, features=1, kernel_size=(3, 3), padding="VALID")
  scores = nn.max_pool(x, window_shape=(8, 8), strides=(8, 8))[Ellipsis, 0]
  return scores

def apply(self,
          x: jnp.ndarray,
          channels: int,
          strides: Tuple[int, int] = (1, 1),
          activate_before_residual: bool = False,
          train: bool = True) -> jnp.ndarray:
  """Implements the forward pass in the module.

  Args:
    x: Input to the module. Should have shape [batch_size, dim, dim, features]
      where dim is the resolution (width and height if the input is an image).
    channels: How many channels to use in the convolutional layers.
    strides: Strides for the pooling.
    activate_before_residual: True if the batch norm and relu should be
      applied before the residual branches out (should be True only for the
      first block of the model).
    train: If False, will use the moving average for batch norm statistics.
      Else, will use statistics computed on the batch.

  Returns:
    The output of the resnet block.
  """
  if activate_before_residual:
    x = utils.activation(x, train, name='init_bn')
    orig_x = x
  else:
    orig_x = x

  block_x = x
  if not activate_before_residual:
    block_x = utils.activation(block_x, train, name='init_bn')

  block_x = nn.Conv(block_x, channels, (3, 3), strides, padding='SAME',
                    bias=False, kernel_init=utils.conv_kernel_init_fn,
                    name='conv1')
  block_x = utils.activation(block_x, train=train, name='bn_2')
  block_x = nn.Conv(block_x, channels, (3, 3), padding='SAME', bias=False,
                    kernel_init=utils.conv_kernel_init_fn, name='conv2')

  return _output_add(block_x, orig_x)

def apply(self, x, num_actions, quantile_embedding_dim, num_quantiles, rng):
  initializer = jax.nn.initializers.variance_scaling(
      scale=1.0 / jnp.sqrt(3.0), mode='fan_in', distribution='uniform')
  # We need to add a "batch dimension" as nn.Conv expects it, yet vmap will
  # have removed the true batch dimension.
  x = x[None, ...]
  x = x.astype(jnp.float32) / 255.
  x = nn.Conv(x, features=32, kernel_size=(8, 8), strides=(4, 4),
              kernel_init=initializer)
  x = jax.nn.relu(x)
  x = nn.Conv(x, features=64, kernel_size=(4, 4), strides=(2, 2),
              kernel_init=initializer)
  x = jax.nn.relu(x)
  x = nn.Conv(x, features=64, kernel_size=(3, 3), strides=(1, 1),
              kernel_init=initializer)
  x = jax.nn.relu(x)
  x = x.reshape((x.shape[0], -1))  # flatten
  state_vector_length = x.shape[-1]
  state_net_tiled = jnp.tile(x, [num_quantiles, 1])
  quantiles_shape = [num_quantiles, 1]
  quantiles = jax.random.uniform(rng, shape=quantiles_shape)
  quantile_net = jnp.tile(quantiles, [1, quantile_embedding_dim])
  quantile_net = (
      jnp.arange(1, quantile_embedding_dim + 1, 1).astype(jnp.float32)
      * onp.pi * quantile_net)
  quantile_net = jnp.cos(quantile_net)
  quantile_net = nn.Dense(quantile_net, features=state_vector_length,
                          kernel_init=initializer)
  quantile_net = jax.nn.relu(quantile_net)
  x = state_net_tiled * quantile_net
  x = nn.Dense(x, features=512, kernel_init=initializer)
  x = jax.nn.relu(x)
  quantile_values = nn.Dense(x, features=num_actions, kernel_init=initializer)
  return atari_lib.ImplicitQuantileNetworkType(quantile_values, quantiles)

def apply(self, x):
  x = nn.Conv(x, features=32, kernel_size=(3, 3), name="conv")
  x = nn.relu(x)
  x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
  x = x.reshape((x.shape[0], -1))  # flatten.
  x = nn.Dense(x, 128, name="fc")
  return x

def apply(self, x, num_actions, noisy, dueling):
  # We need to add a "batch dimension" as nn.Conv expects it, yet vmap will
  # have removed the true batch dimension.
  initializer_conv = nn.initializers.xavier_uniform()
  x = x[None, ...]
  x = x.astype(jnp.float32)
  x = nn.Conv(x, features=16, kernel_size=(3, 3, 3), strides=(1, 1, 1),
              kernel_init=initializer_conv)
  x = jax.nn.relu(x)
  x = x.reshape((x.shape[0], -1))  # flatten

  if noisy:
    print('InvadersDDQNNetwork-Noisy[Johan]')
    initializer = None
    bias = True

    def net(x, features, bias, kernel_init):
      return NoisyNetwork(x, features, bias, kernel_init)
  else:
    initializer = nn.initializers.xavier_uniform()
    bias = None

    def net(x, features, bias, kernel_init):
      # Pass the initializer by keyword so it is not interpreted as the
      # `bias` argument of nn.Dense.
      return nn.Dense(x, features, kernel_init=kernel_init)

  if dueling:
    print('InvadersDDQNNetwork-Dueling[Johan]')
    adv = net(x, features=num_actions, bias=bias, kernel_init=initializer)
    val = net(x, features=1, bias=bias, kernel_init=initializer)
    q_values = val + (adv - jnp.mean(adv, -1, keepdims=True))
  else:
    q_values = net(x, features=num_actions, bias=bias, kernel_init=initializer)
  return atari_lib.DQNNetworkType(q_values)

def apply(self,
          x,
          num_outputs,
          num_filters=64,
          block_sizes=[3, 4, 6, 3],  # pylint: disable=dangerous-default-value
          train=True):
  x = nn.Conv(x, num_filters, (7, 7), (2, 2), bias=False, name='init_conv')
  x = nn.BatchNorm(x, use_running_average=not train, epsilon=1e-5,
                   name='init_bn')
  x = nn.max_pool(x, (3, 3), strides=(2, 2), padding='SAME')
  for i, block_size in enumerate(block_sizes):
    for j in range(block_size):
      strides = (2, 2) if i > 0 and j == 0 else (1, 1)
      x = BottleneckBlock(x, num_filters * 2**i, strides=strides, train=train)
  x = jnp.mean(x, axis=(1, 2))
  x_clf = nn.Dense(x, num_outputs, name='clf')
  # We return both the outputs from the dense layer *and* the features
  # that go into it.
  return x_clf, x

def apply(self, x, blocks_per_group, channel_multiplier, num_outputs,
          dropout_rate=0.0, train=True):
  x = nn.Conv(x, 16, (3, 3), padding='SAME', name='init_conv')
  x = WideResnetGroup(x, blocks_per_group, 16 * channel_multiplier,
                      dropout_rate=dropout_rate, train=train)
  x = WideResnetGroup(x, blocks_per_group, 32 * channel_multiplier, (2, 2),
                      dropout_rate=dropout_rate, train=train)
  x = WideResnetGroup(x, blocks_per_group, 64 * channel_multiplier, (2, 2),
                      dropout_rate=dropout_rate, train=train)
  x = nn.BatchNorm(x, use_running_average=not train, momentum=0.9,
                   epsilon=1e-5)
  x = jax.nn.relu(x)
  x = nn.avg_pool(x, (8, 8))
  x = x.reshape((x.shape[0], -1))
  x = nn.Dense(x, num_outputs)
  return x

def apply(self, x, blocks_per_group, channel_multiplier, num_outputs,
          train=True):
  x = nn.Conv(x, 16, (3, 3), padding='SAME', name='init_conv')
  x = WideResnetShakeShakeGroup(x, blocks_per_group, 16 * channel_multiplier,
                                train=train)
  x = WideResnetShakeShakeGroup(x, blocks_per_group, 32 * channel_multiplier,
                                (2, 2), train=train)
  x = WideResnetShakeShakeGroup(x, blocks_per_group, 64 * channel_multiplier,
                                (2, 2), train=train)
  x = jax.nn.relu(x)
  x = nn.avg_pool(x, (8, 8))
  x = x.reshape((x.shape[0], -1))
  x = nn.Dense(x, num_outputs)
  return x

def apply(self, x, num_classes, num_filters=64, num_layers=50, train=True,
          dtype=jnp.float32):
  if num_layers not in _block_size_options:
    raise ValueError('Please provide a valid number of layers')
  block_sizes = _block_size_options[num_layers]
  x = nn.Conv(x, num_filters, (7, 7), (2, 2), padding=[(3, 3), (3, 3)],
              bias=False, dtype=dtype, name='init_conv')
  x = nn.BatchNorm(x, use_running_average=not train, momentum=0.9,
                   epsilon=1e-5, dtype=dtype, name='init_bn')
  x = nn.max_pool(x, (3, 3), strides=(2, 2), padding='SAME')
  for i, block_size in enumerate(block_sizes):
    for j in range(block_size):
      strides = (2, 2) if i > 0 and j == 0 else (1, 1)
      x = ResidualBlock(x, num_filters * 2**i, strides=strides, train=train,
                        dtype=dtype)
  x = jnp.mean(x, axis=(1, 2))
  x = nn.Dense(x, num_classes)
  x = nn.log_softmax(x)
  return x

def apply(self, x):
  b = x.shape[0]
  x = nn.Conv(x, features=128, kernel_size=(4,), padding='SAME')
  x = nn.BatchNorm(x)
  x = nn.leaky_relu(x)
  x = nn.avg_pool(x, window_shape=(2,), padding='SAME')
  x = nn.Conv(x, features=256, kernel_size=(4,), padding='SAME')
  x = nn.BatchNorm(x)
  x = nn.leaky_relu(x)
  x = nn.avg_pool(x, window_shape=(2,), padding='SAME')
  x = x.reshape(b, -1)
  x = nn.Dense(x, features=128)
  x = nn.BatchNorm(x)
  x = nn.leaky_relu(x)
  x = nn.Dense(x, features=n_bins)
  x = nn.softmax(x)
  return x

def apply(self, x, channels, strides=(1, 1), dropout_rate=0.0, train=True):
  batch_norm = nn.BatchNorm.partial(use_running_average=not train,
                                    momentum=0.9, epsilon=1e-5)
  y = batch_norm(x, name='bn1')
  y = jax.nn.relu(y)
  y = nn.Conv(y, channels, (3, 3), strides, padding='SAME', name='conv1')
  y = batch_norm(y, name='bn2')
  y = jax.nn.relu(y)
  if dropout_rate > 0.0:
    y = nn.dropout(y, dropout_rate, deterministic=not train)
  y = nn.Conv(y, channels, (3, 3), padding='SAME', name='conv2')

  # Apply an up projection in case of channel mismatch
  if (x.shape[-1] != channels) or strides != (1, 1):
    x = nn.Conv(x, channels, (3, 3), strides, padding='SAME')

  return x + y

def apply(self,
          x,
          num_classes=1000,
          train=False,
          resnet=None,
          patches=None,
          hidden_size=None,
          transformer=None,
          representation_size=None,
          classifier='gap'):
  n, h, w, c = x.shape

  # Embed the grid or patches of the grid.
  fh, fw = patches.size
  gh, gw = h // fh, w // fw
  if hidden_size:
    # We can merge s2d+emb into a single conv; it's the same.
    x = nn.Conv(x, hidden_size, (fh, fw), strides=(fh, fw), padding='VALID',
                name='embedding')
  else:
    # This path often results in excessive padding.
    x = jnp.reshape(x, [n, gh, fh, gw, fw, c])
    x = jnp.transpose(x, [0, 1, 3, 2, 4, 5])
    x = jnp.reshape(x, [n, gh, gw, -1])

  # Here, x is a grid of embeddings.

  # (Possibly partial) Transformer.
  if transformer is not None:
    n, h, w, c = x.shape
    x = jnp.reshape(x, [n, h * w, c])

    # If we want to add a class token, add it here.
    if classifier == 'token':
      cls = self.param('cls', (1, 1, c), nn.initializers.zeros)
      cls = jnp.tile(cls, [n, 1, 1])
      x = jnp.concatenate([cls, x], axis=1)

    x = Encoder(x, train=train, name='Transformer', **transformer)

  if classifier == 'token':
    x = x[:, 0]
  elif classifier == 'gap':
    x = jnp.mean(x, axis=list(range(1, x.ndim - 1)))  # (1,) or (1,2)

  if representation_size is not None:
    x = nn.Dense(x, representation_size, name='pre_logits')
    x = nn.tanh(x)
  else:
    x = IdentityLayer(x, name='pre_logits')

  x = nn.Dense(x, num_classes, name='head', kernel_init=nn.initializers.zeros)
  return x

def apply(self, x, num_actions, minatar, env, normalize_obs, noisy, dueling,
          num_atoms, hidden_layer=2, neurons=512):
  del normalize_obs
  if minatar:
    x = x.squeeze(3)
    x = x[None, ...]
    x = x.astype(jnp.float32)
    x = nn.Conv(x, features=16, kernel_size=(3, 3, 3), strides=(1, 1, 1),
                kernel_init=nn.initializers.xavier_uniform())
    x = jax.nn.relu(x)
    x = x.reshape((x.shape[0], -1))
  else:
    x = x[None, ...]
    x = x.astype(jnp.float32)
    x = x.reshape((x.shape[0], -1))
    if env is not None:
      x = x - env_inf[env]['MIN_VALS']
      x /= env_inf[env]['MAX_VALS'] - env_inf[env]['MIN_VALS']
      x = 2.0 * x - 1.0

  if noisy:
    def net(x, features):
      return NoisyNetwork(x, features)
  else:
    def net(x, features):
      return nn.Dense(x, features,
                      kernel_init=nn.initializers.xavier_uniform())

  for _ in range(hidden_layer):
    x = net(x, features=neurons)
    x = jax.nn.relu(x)

  if dueling:
    print('dueling')
    adv = net(x, features=num_actions * num_atoms)
    value = net(x, features=num_atoms)
    adv = adv.reshape((adv.shape[0], num_actions, num_atoms))
    value = value.reshape((value.shape[0], 1, num_atoms))
    # Dueling aggregation: subtract the mean advantage over the action axis.
    logits = value + (adv - jnp.mean(adv, axis=1, keepdims=True))
    probabilities = nn.softmax(logits)
    q_values = jnp.mean(logits, axis=2)
  else:
    x = net(x, features=num_actions * num_atoms)
    logits = x.reshape((x.shape[0], num_actions, num_atoms))
    probabilities = nn.softmax(logits)
    q_values = jnp.mean(logits, axis=2)
  return atari_lib.RainbowNetworkType(q_values, logits, probabilities)

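# Hedged sketch (pure jnp, toy shapes; names are illustrative, not from the
# original code) of the standard dueling aggregation used above,
#   logits(s, a, z) = value(s, z) + adv(s, a, z) - mean_a adv(s, a, z),
# where the mean runs over the action axis of the (batch, actions, atoms)
# advantage tensor.
import jax.numpy as jnp

batch, num_actions, num_atoms = 1, 4, 51
value = jnp.zeros((batch, 1, num_atoms))
adv = jnp.ones((batch, num_actions, num_atoms))
logits = value + (adv - jnp.mean(adv, axis=1, keepdims=True))
print(logits.shape)  # (1, 4, 51)
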
def apply(self,
          x,
          channels,
          strides=(1, 1),
          conv_kernel_init=initializers.lecun_normal(),
          normalizer='batch_norm',
          train=True):
  maybe_normalize = model_utils.get_normalizer(normalizer, train)
  y = maybe_normalize(x, name='bn1')
  y = jax.nn.relu(y)

  # Apply an up projection in case of channel mismatch
  if (x.shape[-1] != channels) or strides != (1, 1):
    x = nn.Conv(y, channels, (1, 1), strides, padding='SAME',
                kernel_init=conv_kernel_init, bias=False)

  y = nn.Conv(y, channels, (3, 3), strides, padding='SAME', name='conv1',
              kernel_init=conv_kernel_init, bias=False)
  y = maybe_normalize(y, name='bn2')
  y = jax.nn.relu(y)
  y = nn.Conv(y, channels, (3, 3), padding='SAME', name='conv2',
              kernel_init=conv_kernel_init, bias=False)

  if normalizer == 'none':
    y = model_utils.ScalarMultiply(y)

  return x + y

def apply(self,
          x: jnp.ndarray,
          blocks_per_group: int,
          channel_multiplier: int,
          num_outputs: int,
          train: bool = True,
          true_gradient: bool = False) -> jnp.ndarray:
  """Implements a WideResnet with ShakeShake regularization module.

  Args:
    x: Input to the module. Should have shape [batch_size, dim, dim, 3] where
      dim is the resolution of the image.
    blocks_per_group: How many resnet blocks to add to each group (should be
      4 blocks for a WRN26 as per standard shake shake implementation).
    channel_multiplier: The multiplier to apply to the number of filters in
      the model (1 is classical resnet, 6 for WRN26-2x6, etc...).
    num_outputs: Dimension of the output of the model (ie number of classes
      for a classification problem).
    train: If False, will use the moving average for batch norm statistics.
      Else, will use statistics computed on the batch.
    true_gradient: If true, the same mixing parameter will be used for the
      forward and backward pass (see paper for more details).

  Returns:
    The output of the WideResnet with ShakeShake regularization, a tensor of
    shape [batch_size, num_classes].
  """
  x = nn.Conv(x, 16, (3, 3), padding='SAME',
              kernel_init=utils.conv_kernel_init_fn, bias=False,
              name='init_conv')
  x = utils.activation(x, apply_relu=False, train=train, name='init_bn')
  x = WideResnetShakeShakeGroup(x, blocks_per_group, 16 * channel_multiplier,
                                train=train, true_gradient=true_gradient)
  x = WideResnetShakeShakeGroup(x, blocks_per_group, 32 * channel_multiplier,
                                (2, 2), train=train,
                                true_gradient=true_gradient)
  x = WideResnetShakeShakeGroup(x, blocks_per_group, 64 * channel_multiplier,
                                (2, 2), train=train,
                                true_gradient=true_gradient)
  x = jax.nn.relu(x)
  x = nn.avg_pool(x, x.shape[1:3])
  x = x.reshape((x.shape[0], -1))
  return nn.Dense(x, num_outputs, kernel_init=utils.dense_layer_init_fn)

def apply(self, inputs, output_dim, kernel_size=3, biases=True):
  output = nn.Conv(inputs, features=output_dim,
                   kernel_size=(kernel_size, kernel_size), strides=(1, 1),
                   padding='SAME', bias=biases)
  output = sum([
      output[:, ::2, ::2, :],
      output[:, 1::2, ::2, :],
      output[:, ::2, 1::2, :],
      output[:, 1::2, 1::2, :]
  ]) / 4.
  return output

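# Side note (illustrative sketch, not from the original code): averaging the
# four stride-2 offsets of the convolution output, as done above, is the same
# as a 2x2 average pool with stride 2, i.e. a mean-pool downsampling.
import jax.numpy as jnp
from flax import nn

x = jnp.arange(1 * 4 * 4 * 2, dtype=jnp.float32).reshape(1, 4, 4, 2)
manual = sum([x[:, ::2, ::2, :], x[:, 1::2, ::2, :],
              x[:, ::2, 1::2, :], x[:, 1::2, 1::2, :]]) / 4.
pooled = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
assert jnp.allclose(manual, pooled)
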
def apply(self, x, num_outputs):
  """Define the convolutional network architecture.

  Architecture originates from "Human-level control through deep
  reinforcement learning.", Nature 518, no. 7540 (2015): 529-533.
  Note that this is different than the one from "Playing atari with deep
  reinforcement learning." arxiv.org/abs/1312.5602 (2013)
  """
  dtype = jnp.float32
  x = x.astype(dtype) / 255.
  x = nn.Conv(x, features=32, kernel_size=(8, 8), strides=(4, 4),
              name='conv1', dtype=dtype)
  x = nn.relu(x)
  x = nn.Conv(x, features=64, kernel_size=(4, 4), strides=(2, 2),
              name='conv2', dtype=dtype)
  x = nn.relu(x)
  x = nn.Conv(x, features=64, kernel_size=(3, 3), strides=(1, 1),
              name='conv3', dtype=dtype)
  x = nn.relu(x)
  x = x.reshape((x.shape[0], -1))  # flatten
  x = nn.Dense(x, features=512, name='hidden', dtype=dtype)
  x = nn.relu(x)
  # Network used to both estimate policy (logits) and expected state value.
  # See github.com/openai/baselines/blob/master/baselines/ppo1/cnn_policy.py
  logits = nn.Dense(x, features=num_outputs, name='logits', dtype=dtype)
  policy_log_probabilities = nn.log_softmax(logits)
  value = nn.Dense(x, features=1, name='value', dtype=dtype)
  return policy_log_probabilities, value

def apply(self, x, n_layers, n_features, n_kernel_sizes):
  x = jnp.expand_dims(x, axis=2)
  for layer in range(n_layers):
    features = n_features[layer]
    kernel_size = (n_kernel_sizes[layer], 1)
    x = nn.Conv(x, features=features, kernel_size=kernel_size)
    x = nn.relu(x)
  x = jnp.squeeze(x, axis=2)
  return x

def apply(self, x, channels, strides, prob, alpha_min, alpha_max, beta_min,
          beta_max, train=True):
  batch_norm = nn.BatchNorm.partial(use_running_average=not train,
                                    momentum=0.9, epsilon=1e-5)
  y = batch_norm(x, name='bn_1_pre')
  y = nn.Conv(y, channels, (1, 1), padding='SAME', name='1x1_conv_contract')
  y = batch_norm(y, name='bn_1_post')
  y = jax.nn.relu(y)
  y = nn.Conv(y, channels, (3, 3), strides, padding='SAME', name='3x3')
  y = batch_norm(y, name='bn_2')
  y = jax.nn.relu(y)
  y = nn.Conv(y, channels * 4, (1, 1), padding='SAME', name='1x1_conv_expand')
  y = batch_norm(y, name='bn_3')
  if train:
    y = shake.shake_drop_train(y, prob, alpha_min, alpha_max,
                               beta_min, beta_max)
  else:
    y = shake.shake_drop_eval(y, prob, alpha_min, alpha_max)
  x = shortcut(x, channels * 4, strides)
  return x + y

def apply(self,
          x: jnp.ndarray,
          blocks_per_group: int,
          channel_multiplier: int,
          num_outputs: int,
          train: bool = True) -> jnp.ndarray:
  """Implements a WideResnet module.

  Args:
    x: Input to the module. Should have shape [batch_size, dim, dim, 3] where
      dim is the resolution of the image.
    blocks_per_group: How many resnet blocks to add to each group (should be
      4 blocks for a WRN28, and 6 for a WRN40).
    channel_multiplier: The multiplier to apply to the number of filters in
      the model (1 is classical resnet, 10 for WRN28-10, etc...).
    num_outputs: Dimension of the output of the model (ie number of classes
      for a classification problem).
    train: If False, will use the moving average for batch norm statistics.

  Returns:
    The output of the WideResnet, a tensor of shape
    [batch_size, num_classes].
  """
  first_x = x
  x = nn.Conv(x, 16, (3, 3), padding='SAME', name='init_conv',
              kernel_init=utils.conv_kernel_init_fn, bias=False)
  x = WideResnetGroup(x, blocks_per_group, 16 * channel_multiplier,
                      activate_before_residual=True, train=train)
  x = WideResnetGroup(x, blocks_per_group, 32 * channel_multiplier, (2, 2),
                      train=train)
  x = WideResnetGroup(x, blocks_per_group, 64 * channel_multiplier, (2, 2),
                      train=train)
  if FLAGS.use_additional_skip_connections:
    x = _output_add(x, first_x)
  x = utils.activation(x, train=train, name='pre-pool-bn')
  x = nn.avg_pool(x, x.shape[1:3])
  x = x.reshape((x.shape[0], -1))
  x = nn.Dense(x, num_outputs, kernel_init=utils.dense_layer_init_fn)
  return x

def apply(self,
          x,
          num_outputs,
          num_filters=64,
          num_layers=50,
          train=True,
          batch_stats=None,
          dtype=jnp.float32,
          batch_norm_momentum=0.9,
          batch_norm_epsilon=1e-5,
          virtual_batch_size=None,
          data_format=None):
  if num_layers not in _block_size_options:
    raise ValueError('Please provide a valid number of layers')
  block_sizes = _block_size_options[num_layers]
  x = nn.Conv(x, num_filters, (3, 3), (1, 1), 'SAME', bias=False, dtype=dtype,
              name='init_conv')
  x = normalization.VirtualBatchNorm(
      x,
      batch_stats=batch_stats,
      use_running_average=not train,
      momentum=batch_norm_momentum,
      epsilon=batch_norm_epsilon,
      dtype=dtype,
      name='init_bn',
      virtual_batch_size=virtual_batch_size,
      data_format=data_format)
  x = nn.relu(x)
  residual_block = block_type_options[num_layers]
  for i, block_size in enumerate(block_sizes):
    for j in range(block_size):
      strides = (2, 2) if i > 0 and j == 0 else (1, 1)
      x = residual_block(x, num_filters * 2**i, strides=strides, train=train,
                         batch_stats=batch_stats, dtype=dtype,
                         batch_norm_momentum=batch_norm_momentum,
                         batch_norm_epsilon=batch_norm_epsilon,
                         virtual_batch_size=virtual_batch_size,
                         data_format=data_format)
  x = jnp.mean(x, axis=(1, 2))
  x = nn.Dense(x, num_outputs, dtype=dtype)
  return x