def WideResnetBlock(channels, strides=(1, 1), channel_mismatch=False):
    """WideResnet convolutional block.

    Sums a two-convolution main branch with a shortcut branch.

    Args:
      channels: number of output channels of both 3x3 convolutions.
      strides: strides of the first main-branch convolution and, when
        `channel_mismatch` is set, of the shortcut convolution.
      channel_mismatch: if true, the shortcut is a 3x3 convolution instead
        of an identity.

    Returns:
      The layer built by `stax.Serial` from the fan-out, the two parallel
      branches and the summing fan-in.
    """
    # Main branch: (BatchNorm -> Relu -> 3x3 Conv) twice; only the first
    # convolution receives `strides`.
    residual_layers = [
        stax.BatchNorm(),
        stax.Relu(),
        stax.Conv(channels, (3, 3), strides, padding='SAME'),
        stax.BatchNorm(),
        stax.Relu(),
        stax.Conv(channels, (3, 3), padding='SAME'),
    ]
    main = stax.Serial(*residual_layers)

    if channel_mismatch:
        shortcut = stax.Conv(channels, (3, 3), strides, padding='SAME')
    else:
        shortcut = stax.Identity()

    return stax.Serial(
        stax.FanOut(),
        stax.Parallel(main, shortcut),
        stax.FanInSum(),
    )
def IdentityBlock(kernel_size, filters):
    """ResNet block whose shortcut branch is an identity.

    Args:
      kernel_size: side length of the square kernel used by the middle
        convolution.
      filters: sequence of exactly three filter counts, one per convolution
        of the main branch, in order.

    Returns:
      The layer built by `stax.Serial`: main branch and identity shortcut
      summed, followed by a Relu.
    """
    first_filters, second_filters, third_filters = filters
    middle_kernel = (kernel_size, kernel_size)

    # Main branch: 1x1 conv -> kernel_size conv -> 1x1 conv, each followed by
    # BatchNorm; the last one has no Relu before the sum.
    main_layers = [
        stax.Conv(first_filters, (1, 1)),
        stax.BatchNorm(),
        stax.Relu(),
        stax.Conv(second_filters, middle_kernel, padding='SAME'),
        stax.BatchNorm(),
        stax.Relu(),
        stax.Conv(third_filters, (1, 1)),
        stax.BatchNorm(),
    ]
    main = stax.Serial(*main_layers)

    return stax.Serial(
        stax.FanOut(),
        stax.Parallel(main, stax.Identity()),
        stax.FanInSum(),
        stax.Relu(),
    )
def ConvBlock(kernel_size, filters, strides):
    """ResNet block with a strided convolution on both branches.

    Args:
      kernel_size: side length of the square kernel used by the middle
        convolution of the main branch.
      filters: sequence of exactly three filter counts, one per convolution
        of the main branch, in order; the third is also used by the shortcut.
      strides: strides applied by the first main-branch convolution and by
        the shortcut convolution.

    Returns:
      The layer built by `stax.Serial`: main branch and convolutional
      shortcut summed, followed by a Relu.
    """
    first_filters, second_filters, third_filters = filters
    middle_kernel = (kernel_size, kernel_size)

    # Main branch: strided 1x1 conv -> kernel_size conv -> 1x1 conv, each
    # followed by BatchNorm.
    main_layers = [
        stax.Conv(first_filters, (1, 1), strides),
        stax.BatchNorm(),
        stax.Relu(),
        stax.Conv(second_filters, middle_kernel, padding='SAME'),
        stax.BatchNorm(),
        stax.Relu(),
        stax.Conv(third_filters, (1, 1)),
        stax.BatchNorm(),
    ]
    main = stax.Serial(*main_layers)

    # Shortcut: strided 1x1 conv with the same filter count as the last
    # main-branch convolution, then BatchNorm.
    shortcut = stax.Serial(
        stax.Conv(third_filters, (1, 1), strides),
        stax.BatchNorm(),
    )

    return stax.Serial(
        stax.FanOut(),
        stax.Parallel(main, shortcut),
        stax.FanInSum(),
        stax.Relu(),
    )
def test_dense_param_sharing(self):
    """Compares parameters of two distinct Dense layers vs. one reused layer."""
    model1 = stax.Serial(stax.Dense(32), stax.Dense(32))
    shared_dense = stax.Dense(32)
    model2 = stax.Serial(shared_dense, shared_dense)

    init_fun1 = model1[0]
    init_fun2 = model2[0]
    rng = random.get_prng(0)
    input_shape = [-1, 32]
    params1 = init_fun1(rng, input_shape)[1]
    params2 = init_fun2(rng, list(input_shape))[1]

    kernel_shape = (32, 32)
    # With two separate Dense layers, both entries hold a (32, 32) kernel.
    for position in (0, 1):
        self.assertEqual(kernel_shape, params1[position][0].shape)
    # With one reused Dense layer, only the first entry holds a (32, 32)
    # kernel; the second entry is an empty tuple.
    self.assertEqual(kernel_shape, params2[0][0].shape)
    self.assertEqual((), params2[1])
def MLP(num_hidden_layers=2, hidden_size=512, activation_fn=stax.Relu, num_output_classes=10, mode="train"):
    """Multi-layer feed-forward neural network with non-linear activations.

    Args:
      num_hidden_layers: number of (Dense, activation) pairs after the
        flattening layer.
      hidden_size: number of units of every hidden Dense layer.
      activation_fn: zero-argument callable returning an activation layer;
        it is called once per hidden layer.
      num_output_classes: number of units of the final Dense layer.
      mode: unused; discarded immediately.

    Returns:
      The layer built by `stax.Serial` from Flatten, the hidden layers, the
      output Dense layer and a LogSoftmax.
    """
    del mode  # Unused.

    stack = [stax.Flatten()]
    for _ in range(num_hidden_layers):
        stack.append(stax.Dense(hidden_size))
        stack.append(activation_fn())
    stack.append(stax.Dense(num_output_classes))
    stack.append(stax.LogSoftmax())
    return stax.Serial(*stack)
def policy_and_value_net(rng_key, batch_observations_shape, num_actions, bottom_layers=None):
    """Builds and initializes a network with a policy head and a value head.

    Args:
      rng_key: random key passed to the network's init function.
      batch_observations_shape: input shape passed to the network's init
        function.
      num_actions: number of units of the policy head's Dense layer.
      bottom_layers: optional iterable of layers placed before the two heads;
        it is copied, not modified.

    Returns:
      A pair `(net_params, net_apply)`: the initialized parameters and the
      network's apply function.
    """
    # Start from a private list so the caller's `bottom_layers` is untouched.
    layers = []
    if bottom_layers is not None:
        layers.extend(bottom_layers)

    # Fan the bottom output out to two heads: Dense + Softmax over actions,
    # and a single-unit Dense for the value.
    policy_head = stax.Serial(stax.Dense(num_actions), stax.Softmax())
    value_head = stax.Dense(1)
    layers.extend([stax.FanOut(), stax.Parallel(policy_head, value_head)])

    # NOTE(review): the list is passed as one argument (not unpacked) --
    # confirm stax.Serial accepts a single list of layers.
    net_init, net_apply = stax.Serial(layers)
    # The first element returned by net_init is not needed here.
    net_params = net_init(rng_key, batch_observations_shape)[1]
    return net_params, net_apply
def policy_net(rng_key, batch_observations_shape, num_actions, bottom_layers=None):
    """Builds and initializes a policy network.

    The given `bottom_layers` form the bottom of the network; a Dense layer
    with `num_actions` units and a Softmax are placed on top.

    Args:
      rng_key: random key passed to the network's init function.
      batch_observations_shape: input shape passed to the network's init
        function.
      num_actions: number of units of the final Dense layer.
      bottom_layers: optional iterable of layers used as the bottom of the
        network. It is copied, never modified, so the same list can safely
        be reused to build other networks.

    Returns:
      A pair `(net_params, net_apply)`: the initialized parameters and the
      network's apply function.
    """
    # Copy instead of extending in place: mutating the caller's list would
    # leak the policy head into every later use of that list.
    layers = [] if bottom_layers is None else list(bottom_layers)
    layers.extend([stax.Dense(num_actions), stax.Softmax()])

    net_init, net_apply = stax.Serial(layers)
    # The first element returned by net_init is not needed here.
    _, net_params = net_init(rng_key, batch_observations_shape)
    return net_params, net_apply
def value_net(rng_key, batch_observations_shape, num_actions, bottom_layers=None):
    """Builds and initializes a value network.

    The given `bottom_layers` form the bottom of the network; a single-unit
    Dense layer is placed on top.

    Args:
      rng_key: random key passed to the network's init function.
      batch_observations_shape: input shape passed to the network's init
        function.
      num_actions: unused; discarded immediately.
      bottom_layers: optional iterable of layers used as the bottom of the
        network. It is copied, never modified, so the same list can safely
        be reused to build other networks.

    Returns:
      A pair `(net_params, net_apply)`: the initialized parameters and the
      network's apply function.
    """
    del num_actions  # Unused.

    # Copy instead of extending in place: mutating the caller's list would
    # leak the value head into every later use of that list.
    layers = [] if bottom_layers is None else list(bottom_layers)
    layers.append(stax.Dense(1))

    net_init, net_apply = stax.Serial(layers)
    # The first element returned by net_init is not needed here.
    _, net_params = net_init(rng_key, batch_observations_shape)
    return net_params, net_apply
def WideResnet(num_blocks=3, hidden_size=64, num_output_classes=10, mode='train'):
    """WideResnet from https://arxiv.org/pdf/1605.07146.pdf.

    Args:
      num_blocks: int, number of blocks in a group.
      hidden_size: the size of the first hidden layer; the second and third
        groups use 2x and 4x this size.
      num_output_classes: int, number of classes to distinguish.
      mode: is it training or eval; unused and discarded immediately.

    Returns:
      The WideResnet model with given layer and output sizes.
    """
    del mode  # Unused.

    stem = stax.Conv(hidden_size, (3, 3), padding='SAME')
    # Three groups; the second and third use strides (2, 2) and double the
    # channel count each time.
    groups = [
        WideResnetGroup(num_blocks, hidden_size),
        WideResnetGroup(num_blocks, hidden_size * 2, (2, 2)),
        WideResnetGroup(num_blocks, hidden_size * 4, (2, 2)),
    ]
    head = [
        stax.BatchNorm(),
        stax.Relu(),
        stax.AvgPool(pool_size=(8, 8)),
        stax.Flatten(),
        stax.Dense(num_output_classes),
        stax.LogSoftmax(),
    ]
    return stax.Serial(*([stem] + groups + head))
def Resnet50(hidden_size=64, num_output_classes=1001, mode='train'):
    """ResNet.

    Args:
      hidden_size: the size of the first hidden layer (multiplied later).
      num_output_classes: how many classes to distinguish.
      mode: whether we are training or evaluating or doing inference; unused
        and discarded immediately.

    Returns:
      The ResNet model with the given layer and output sizes.
    """
    del mode  # Unused.

    # One entry per stage:
    # (filter multipliers, ConvBlock strides, number of IdentityBlocks).
    stages = (
        ((1, 1, 4), (1, 1), 2),
        ((2, 2, 8), (2, 2), 3),
        ((4, 4, 16), (2, 2), 5),
        ((8, 8, 32), (2, 2), 2),
    )

    def stage_filters(multipliers):
        # A fresh list of filter counts for every block.
        return [factor * hidden_size for factor in multipliers]

    layers = [
        stax.Conv(hidden_size, (7, 7), (2, 2), 'SAME'),
        stax.BatchNorm(),
        stax.Relu(),
        stax.MaxPool(pool_size=(3, 3), strides=(2, 2)),
    ]
    # Each stage is one ConvBlock followed by its IdentityBlocks.
    for multipliers, conv_strides, num_identity in stages:
        layers.append(ConvBlock(3, stage_filters(multipliers), conv_strides))
        for _ in range(num_identity):
            layers.append(IdentityBlock(3, stage_filters(multipliers)))
    layers.extend([
        stax.AvgPool(pool_size=(7, 7)),
        stax.Flatten(),
        stax.Dense(num_output_classes),
        stax.LogSoftmax(),
    ])
    return stax.Serial(*layers)
def WideResnetGroup(n, channels, strides=(1, 1)):
    """Chains WideResnet blocks into one group.

    Args:
      n: number of blocks in the group; the first block is always created,
        followed by `n - 1` further blocks.
      channels: number of channels passed to every block.
      strides: strides of the first block; later blocks use (1, 1).

    Returns:
      The layer built by `stax.Serial` from the blocks, in order.
    """
    # Only the first block is built with channel_mismatch=True and `strides`.
    leading_block = WideResnetBlock(channels, strides, channel_mismatch=True)
    trailing_blocks = [
        WideResnetBlock(channels, (1, 1)) for _ in range(n - 1)
    ]
    return stax.Serial(leading_block, *trailing_blocks)