def ConvBlock(kernel_size, filters, strides=(2, 2)):
    """Residual "conv block": a bottleneck main path plus a strided 1x1
    projection shortcut, summed and passed through Relu.

    Args:
        kernel_size: side length of the middle (square) convolution.
        filters: triple of channel counts for the three main-path convs;
            the shortcut projects to the last of them.
        strides: stride applied by both the first main conv and the shortcut.

    Returns:
        A stax (init_fun, apply_fun) pair.
    """
    ks = kernel_size
    ch_in, ch_mid, ch_out = filters
    # Main path: 1x1 reduce -> ks x ks (SAME) -> 1x1 expand, BN after each conv.
    main_path = stax.serial(
        Conv(ch_in, (1, 1), strides), BatchNorm(), Relu,
        Conv(ch_mid, (ks, ks), padding='SAME'), BatchNorm(), Relu,
        Conv(ch_out, (1, 1)), BatchNorm(),
    )
    # Shortcut path: strided 1x1 projection so shapes match for the sum.
    shortcut_path = stax.serial(Conv(ch_out, (1, 1), strides), BatchNorm())
    return stax.serial(FanOut(2), stax.parallel(main_path, shortcut_path), FanInSum, Relu)
def DeepQNetwork():
    """Atari-style deep Q-network: three convolutions, a 256-unit hidden
    layer, and 6 action-value outputs.

    Returns:
        A stax (init_fun, predict_fun) pair.
    """
    layers = (
        Conv(16, (8, 8), strides=(4, 4)), Relu,
        Conv(32, (4, 4), strides=(2, 2)), Relu,
        Conv(64, (3, 3)), Relu,
        Flatten,
        Dense(256), Relu,
        Dense(6),  # one Q-value per action
    )
    return stax.serial(*layers)
def construct_main(inp_shape):
    """Build the main (residual) path whose final conv emits as many
    channels as the input has, so a residual sum stays shape-compatible.

    Relies on `filters` and `ks` from the enclosing scope.

    Args:
        inp_shape: input shape tuple whose index 3 is the channel axis
            (NHWC layout assumed — confirm against callers).

    Returns:
        A stax (init_fun, apply_fun) pair.
    """
    return stax.serial(
        Conv(filters[0], (1, 1), strides=(1, 1)), BatchNorm(), Relu,
        Conv(filters[1], (ks, ks), padding="SAME"), BatchNorm(), Relu,
        # Bug fix: the parameter is `inp_shape`; the original body read the
        # undefined name `input_shape`, leaving the parameter unused.
        Conv(inp_shape[3], (1, 1)), BatchNorm(),
    )
def make_main(input_shape):
    """Main residual path; the number of output channels depends on the
    number of input channels (index 3 of `input_shape`).

    Relies on `filters1`, `filters2` and `ks` from the enclosing scope.
    """
    out_channels = input_shape[3]
    return stax.serial(
        Conv(filters1, (1, 1)), BatchNorm(), Relu,
        Conv(filters2, (ks, ks), padding="SAME"), BatchNorm(), Relu,
        Conv(out_channels, (1, 1)), BatchNorm(),
    )
def state_encoder(output_num):
    """Small state encoder: three 4-channel SAME convs with tanh, then a
    128-unit tanh layer and a linear head of size `output_num`.

    NOTE: BatchNorm layers were tried after each stage and removed —
    with BatchNorm the output collapsed to a constant for some reason.
    """
    conv_stack = []
    for _ in range(3):
        conv_stack += [Conv(4, (3, 3), (1, 1), "SAME"), Tanh]
    return serial(*conv_stack, Flatten, Dense(128), Tanh, Dense(output_num))
def CnnSmall():
    """CNN used for the Catch and Key-to-Door experiments in:
    https://arxiv.org/abs/2102.12425"""
    conv_layers = (
        Conv(32, (2, 2), (1, 1), "VALID"), Relu,
        Conv(64, (2, 2), (1, 1), "VALID"), Relu,
    )
    dense_head = (Flatten, Dense(256), Relu)
    return serial(*conv_layers, *dense_head)
def Cnn(n_actions: int, hidden_size) -> Module:
    """DQN-style convolutional torso followed by a dense head.

    Args:
        n_actions: number of output units (one per action).
        hidden_size: width of the penultimate dense layer.

    Returns:
        The (init_fun, apply_fun) pair produced by `serial`.
    """
    torso = [
        Conv(32, (8, 8), (4, 4), "VALID"), Relu,
        Conv(64, (4, 4), (2, 2), "VALID"), Relu,
        Conv(64, (3, 3), (1, 1), "VALID"), Relu,
    ]
    head = [Flatten, Dense(hidden_size), Relu, Dense(n_actions)]
    return serial(*torso, *head)
def CnnLarge():
    """CNN used for the Pong and Skiing experiments in:
    https://arxiv.org/abs/2102.12425"""
    downsampling = (
        Conv(32, (3, 3), (2, 2), "VALID"), Relu,
        Conv(64, (3, 3), (2, 2), "VALID"), Relu,
        Conv(64, (3, 3), (2, 2), "VALID"), Relu,
    )
    return serial(*downsampling, Flatten, Dense(256), Relu)
def init_conv_affine_coupling(rng, in_shape, n_channels, flip, sigmoid=True, init_batch=None):
    """Build a RealNVP-style convolutional affine coupling layer.

    Half of the channels pass through unchanged and condition an affine
    (shift, log-scale) transform of the other half.

    Args:
        rng: PRNG key used to initialize the conditioner network.
        in_shape: tuple of (h, w, c); c must be even.
        n_channels: hidden channel count of the conditioner network.
        flip: if True, swap which channel half is transformed (so stacked
            couplings alternate).
        sigmoid: if True, squash the raw log-scale through
            log_sigmoid(. + 2.) for stability.
        init_batch: unused in this implementation.

    Returns:
        (net_params, conv_coupling_forward, conv_coupling_reverse).
    """
    h, w, c = in_shape
    assert c % 2 == 0, "channels must be even doooooooog!"
    half_c = c // 2
    # Conditioner: takes half_c channels, emits c channels that are later
    # split into shift and log_scale (half_c each); SAME padding keeps h, w.
    net_init, net_apply = stax.serial(Conv(n_channels, (3, 3), padding="SAME"), Relu,
                                      Conv(n_channels, (3, 3), padding="SAME"), Relu,
                                      Conv(c, (3, 3), padding="SAME"))
    _, net_params = net_init(rng, (-1, h, w, half_c))

    def shift_and_log_scale_fn(net_params, x1):
        # Split the conditioner's c output channels into (shift, log_scale).
        s = net_apply(net_params, x1)
        return np.split(s, 2, axis=3)

    def conv_coupling_forward(net_params, prev_sample, prev_logp=0.):
        # Forward pass: y2 = x2 * exp(log_scale) + shift, accumulating the
        # log-det-Jacobian sum(log_scale) onto prev_logp.
        x1, x2 = prev_sample[:, :, :, :half_c], prev_sample[:, :, :, half_c:]
        if flip: x2, x1 = x1, x2
        shift, log_scale = shift_and_log_scale_fn(net_params, x1)
        if sigmoid: log_scale = log_sigmoid(log_scale + 2.)
        y2 = x2 * np.exp(log_scale) + shift
        # Swap back so the passthrough half returns to its original slot.
        if flip: x1, y2 = y2, x1
        y = np.concatenate([x1, y2], axis=-1)
        return y, prev_logp + np.sum(log_scale, axis=(1, 2, 3))

    def conv_coupling_reverse(net_params, next_sample, next_logp=0.):
        # Inverse pass: x2 = (y2 - shift) * exp(-log_scale), subtracting the
        # same log-det-Jacobian from next_logp.
        y1, y2 = next_sample[:, :, :, :half_c], next_sample[:, :, :, half_c:]
        if flip: y1, y2 = y2, y1
        shift, log_scale = shift_and_log_scale_fn(net_params, y1)
        if sigmoid: log_scale = log_sigmoid(log_scale + 2.)
        x2 = (y2 - shift) * np.exp(-log_scale)
        if flip: y1, x2 = x2, y1
        x = np.concatenate([y1, x2], axis=-1)
        return x, next_logp - np.sum(log_scale, axis=(1, 2, 3))

    return net_params, conv_coupling_forward, conv_coupling_reverse
def convBlock(ks, filters, stride=(1, 1)):
    """Residual block: bottleneck main path summed with a 1x1 projection
    shortcut via FanInSum, then Relu.

    Bug fixes relative to the original:
      * The shortcut projected to `filters[3]` channels while the main path
        ends with `filters[2]`; FanInSum requires both branches to have the
        same shape, so a differing fourth entry could never have worked.
        The shortcut now emits `filters[2]` channels (and 3-element filter
        lists are accepted).
      * The middle conv now uses padding="SAME" (as in the sibling
        ConvBlock); with VALID padding its spatial size never matched the
        1x1 shortcut's for any ks > 1. For previously-working calls
        (ks == 1) SAME and VALID are identical, so behavior is preserved.

    Args:
        ks: side length of the middle (square) convolution.
        filters: channel counts [reduce, mid, out] for the main path.
        stride: stride shared by the middle conv and the shortcut.

    Returns:
        A stax (init_fun, apply_fun) pair.
    """
    # NOTE(review): the trailing Relu inside Main (before the sum) deviates
    # from the canonical residual layout but is kept to preserve behavior.
    Main = stax.serial(
        Conv(filters[0], (1, 1), strides=(1, 1)), BatchNorm(), Relu,
        Conv(filters[1], (ks, ks), strides=stride, padding="SAME"), BatchNorm(), Relu,
        Conv(filters[2], (1, 1), strides=(1, 1)), BatchNorm(), Relu,
    )
    Shortcut = stax.serial(
        Conv(filters[2], (1, 1), strides=stride),
        BatchNorm(),
    )
    fullInternal = stax.parallel(Main, Shortcut)
    return stax.serial(FanOut(2), fullInternal, FanInSum, Relu)
def feature_extractor(rng, dim):
    """Feature extraction network: two conv+pool stages, flattened into a
    `dim`-dimensional embedding.

    Args:
        rng: PRNG key; one subkey is consumed for initialization.
        dim: size of the output feature vector.

    Returns:
        (params, forward): initialized parameters (for 28x28x1 inputs) and
        the apply function.
    """
    init_params, forward = stax.serial(
        Conv(16, (8, 8), padding='SAME', strides=(2, 2)), Relu,
        MaxPool((2, 2), (1, 1)),
        Conv(32, (4, 4), padding='VALID', strides=(2, 2)), Relu,
        MaxPool((2, 2), (1, 1)),
        Flatten,
        Dense(dim),
    )
    init_key, rng = random.split(rng)
    # Initialize for single-channel 28x28 images; keep only the parameters.
    params = init_params(init_key, (-1, 28, 28, 1))[1]
    return params, forward
def init_nn():
    """Initialize Stax model.

    This function can be customized as needed to define architecture.
    Returns the (init_fun, apply_fun) pair for a small two-conv classifier
    with a 10-way log-softmax head.
    """
    return stax.serial(
        Conv(16, (3, 3)), Relu,
        Conv(16, (3, 3)), Relu,
        Flatten,
        Dense(10),
        LogSoftmax,
    )
def conv():
    """Small conv classifier with a 10-way linear head.

    Returns:
        (init_params, predict): init_params(rng) -> params for 28x28x1
        inputs, and the apply function.
    """
    net_init, predict = stax.serial(
        Conv(16, (8, 8), padding='SAME', strides=(2, 2)), Relu,
        MaxPool((2, 2), (1, 1)),
        Conv(32, (4, 4), padding='VALID', strides=(2, 2)), Relu,
        MaxPool((2, 2), (1, 1)),
        Flatten,
        Dense(32), Relu,
        Dense(10),
    )

    def init_params(rng):
        # Fix the input shape to single-channel 28x28 images; keep params only.
        return net_init(rng, (-1, 28, 28, 1))[1]

    return init_params, predict
def LeNet5(batch_size, num_particles):
    """Classic LeNet-5 architecture (NCHW input, NHWC internally), wrapped
    through the project's make_model with `num_particles` copies.
    """
    architecture = stax.serial(
        GeneralConv(('NCHW', 'OIHW', 'NHWC'), out_chan=6, filter_shape=(5, 5),
                    strides=(1, 1), padding="VALID"), Relu,
        MaxPool(window_shape=(2, 2), strides=(2, 2), padding="VALID"),
        Conv(out_chan=16, filter_shape=(5, 5), strides=(1, 1), padding="SAME"), Relu,
        MaxPool(window_shape=(2, 2), strides=(2, 2), padding="SAME"),
        Conv(out_chan=120, filter_shape=(5, 5), strides=(1, 1), padding="VALID"), Relu,
        MaxPool(window_shape=(2, 2), strides=(2, 2), padding="SAME"),
        Flatten,
        Dense(84), Relu,
        Dense(10),
        LogSoftmax)
    return make_model(architecture, _input_shape(batch_size), num_particles)
def __init__(self, filters: int, kernel_size: Tuple[int, int], strides: Tuple[int, int] = (1, 1), padding: str = "valid", activation: str = "linear") -> None:
    """Keras-like conv layer wrapper: a stax Conv followed by an activation.

    Args:
        filters: number of output channels.
        kernel_size: (h, w) of the convolution filter.
        strides: convolution strides.
        padding: "valid" or "same" (case-insensitive; upper-cased for stax).
        activation: name of an activation factory in `activations`.
    """
    # The activation factory returns a stax-style (init_fun, apply_fun) pair.
    act_layer: Tuple[Callable, Callable] = getattr(activations, activation)()
    conv_layer = Conv(out_chan=filters, filter_shape=kernel_size,
                      strides=strides, padding=padding.upper())
    self.layer: List = [conv_layer, act_layer]
def LeNet5(num_classes):
    """LeNet-style classifier with HWCN inputs (NHWC internally) and a
    three-stage dense head sized by `num_classes`.

    Returns:
        A stax (init_fun, apply_fun) pair.
    """
    layers = [
        GeneralConv(('HWCN', 'OIHW', 'NHWC'), 64, (7, 7), (2, 2), 'SAME'),
        BatchNorm(), Relu,
        AvgPool((3, 3)),
        Conv(16, (5, 5), strides=(1, 1), padding="SAME"),
        BatchNorm(), Relu,
        AvgPool((3, 3)),
        Flatten,
        Dense(num_classes * 10),
        Dense(num_classes * 5),
        Dense(num_classes),
        LogSoftmax,
    ]
    return stax.serial(*layers)
def loss(params, batch):
    """Mean negative log-likelihood: `predict` emits log-probabilities
    (LogSoftmax head) and `targets` are one-hot, so the inner sum picks
    the log-probability of the true class."""
    inputs, targets = batch
    preds = predict(params, inputs)
    return -np.mean(np.sum(preds * targets, axis=1))


def accuracy(params, batch):
    """Fraction of examples whose argmax prediction matches the one-hot target."""
    inputs, targets = batch
    target_class = np.argmax(targets, axis=1)
    predicted_class = np.argmax(predict(params, inputs), axis=1)
    return np.mean(predicted_class == target_class)


# Small CNN classifier with a 10-way log-softmax head.
# NOTE(review): `Activator` is a free name — presumably a stax nonlinearity
# defined elsewhere in the file; confirm before reuse.
init_random_params, predict = stax.serial(
    Conv(10, (5, 5), (1, 1)), Activator, MaxPool((4, 4)), Flatten,
    Dense(10), LogSoftmax)

if __name__ == "__main__":
    rng = random.PRNGKey(0)
    # Optimizer / training-loop hyperparameters.
    step_size = 0.001
    num_epochs = 10
    batch_size = 128
    momentum_mass = 0.9
    # input shape for CNN
    input_shape = (-1, 28, 28, 1)
    # training/test split
def loss(params, batch):
    """Mean negative log-likelihood: `predict` emits log-probabilities
    (LogSoftmax head) and `targets` are one-hot, so the inner sum picks
    the log-probability of the true class."""
    inputs, targets = batch
    preds = predict(params, inputs)
    return -np.mean(np.sum(preds * targets, axis=1))


def accuracy(params, batch):
    """Fraction of examples whose argmax prediction matches the one-hot target."""
    inputs, targets = batch
    target_class = np.argmax(targets, axis=1)
    predicted_class = np.argmax(predict(params, inputs), axis=1)
    return np.mean(predicted_class == target_class)


# Small CNN classifier; note the 24-unit head (unlike the 10-unit sibling).
# NOTE(review): `Activator` is a free name — presumably a stax nonlinearity
# defined elsewhere in the file; confirm before reuse.
init_random_params, predict = stax.serial(Conv(10, (5, 5), (1, 1)), Activator,
                                          MaxPool((4, 4)), Flatten,
                                          Dense(24), LogSoftmax)

if __name__ == "__main__":
    rng = random.PRNGKey(0)
    # Optimizer / training-loop hyperparameters.
    step_size = 0.001
    num_epochs = 10
    batch_size = 128
    momentum_mass = 0.9
    # input shape for CNN
    input_shape = (-1, 28, 28, 1)
    # training/test split
def MyConv(*args, **kwargs):
    """Thin pass-through wrapper around stax's Conv; forwards all
    positional and keyword arguments unchanged."""
    layer = Conv(*args, **kwargs)
    return layer
def conv_net(mode="train"):
    """Build a custom (init_fun, apply_fun) network that, after a primary
    convolution, forms "pairs" of columns via weight-shared dilated
    convolutions at 5 dilation rates, runs a residual conv block over them,
    and finishes with a strided-conv + dense head emitting a scalar.

    Args:
        mode: "train"/"test" flag; only relevant to the commented-out
            Dropout layer, currently unused.

    Returns:
        (init_fun, apply_fun) in the stax convention.
    """
    out_dim = 1
    # Dimension spec for lax.conv_general_dilated: NHWC inputs/outputs, HWIO kernels.
    dim_nums = ("NHWC", "HWIO", "NHWC")
    unit_stride = (1,1)
    zero_pad = ((0,0), (0,0))

    # Primary convolutional layer.
    conv_channels = 32
    conv_init, conv_apply = Conv(out_chan=conv_channels, filter_shape=(3,3),
                                 strides=(1,3), padding=zero_pad)

    # Group all possible pairs: one shared (1, 2) kernel applied at several
    # horizontal dilations in apply_fun.
    pair_channels, filter_shape = 256, (1, 2)

    # Convolutional block with the same number of channels (so the residual
    # sum in apply_fun is shape-compatible).
    block_channels = pair_channels
    conv_block_init, conv_block_apply = serial(Conv(block_channels, (1,3), unit_stride, "SAME"), Relu,  # One block of convolutions.
                                               Conv(block_channels, (1,3), unit_stride, "SAME"), Relu,
                                               Conv(block_channels, (1,3), unit_stride, "SAME"))

    # Forward pass head.
    hidden_size = 2048
    dropout_rate = 0.25
    serial_init, serial_apply = serial(Conv(block_channels, (1,3), (1, 3), zero_pad), Relu,  # Using convolution with strides
                                       Flatten, Dense(hidden_size),                          # instead of pooling for downsampling.
                                       # Dropout(dropout_rate, mode),
                                       Relu, Dense(out_dim))

    def init_fun(rng, input_shape):
        rng, conv_rng, block_rng, serial_rng = jax.random.split(rng, num=4)

        # Primary convolutional layer.
        conv_shape, conv_params = conv_init(conv_rng, (-1,) + input_shape)

        # Grouping all possible pairs: manually initialize the shared
        # (1, 2, conv_channels, pair_channels) kernel and its bias.
        kernel_shape = [filter_shape[0], filter_shape[1], conv_channels, pair_channels]
        bias_shape = [1, 1, 1, pair_channels]
        W_init = glorot_normal(in_axis=2, out_axis=3)
        b_init = normal(1e-6)
        k1, k2 = jax.random.split(rng)
        W, b = W_init(k1, kernel_shape), b_init(k2, bias_shape)
        # NOTE(review): width 15 is hard-coded. The dstack in apply_fun
        # concatenates 5 dilated outputs of widths (w-1)..(w-5), i.e. total
        # 5*w - 15, so this appears to assume the primary conv output width
        # is w == 6 — confirm against the expected input shape.
        pair_shape = conv_shape[:2] + (15,) + (pair_channels,)
        pair_params = (W, b)

        # Convolutional block.
        conv_block_shape, conv_block_params = conv_block_init(block_rng, pair_shape)

        # Forward-pass head.
        serial_shape, serial_params = serial_init(serial_rng, conv_block_shape)

        params = [conv_params, pair_params, conv_block_params, serial_params]
        return serial_shape, params

    def apply_fun(params, inputs):
        conv_params, pair_params, conv_block_params, serial_params = params

        # Apply the primary convolutional layer.
        conv_out = conv_apply(conv_params, inputs)
        conv_out = relu(conv_out)

        # Group all possible pairs: same weights, increasing horizontal
        # (rhs) dilation 1..5, concatenated along the width axis (dstack).
        W, b = pair_params
        pair_1 = conv_general_dilated(conv_out, W, unit_stride, zero_pad, (1,1), (1,1), dim_nums) + b
        pair_2 = conv_general_dilated(conv_out, W, unit_stride, zero_pad, (1,1), (1,2), dim_nums) + b
        pair_3 = conv_general_dilated(conv_out, W, unit_stride, zero_pad, (1,1), (1,3), dim_nums) + b
        pair_4 = conv_general_dilated(conv_out, W, unit_stride, zero_pad, (1,1), (1,4), dim_nums) + b
        pair_5 = conv_general_dilated(conv_out, W, unit_stride, zero_pad, (1,1), (1,5), dim_nums) + b
        pair_out = jnp.dstack([pair_1, pair_2, pair_3, pair_4, pair_5])
        pair_out = relu(pair_out)

        # Convolutional block.
        conv_block_out = conv_block_apply(conv_block_params, pair_out)

        # Residual connection.
        res_out = conv_block_out + pair_out
        res_out = relu(res_out)

        # Forward-pass head.
        out = serial_apply(serial_params, res_out)
        return out

    return init_fun, apply_fun
# batches = synth_batches() # inputs = next(batches) # train_images, labels, _, _ = mnist(permute_train=True, resize=True) # del _ # inputs = train_images[:data_size] # # del train_images # u, s, v_t = onp.linalg.svd(inputs, full_matrices=False) # I = np.eye(v_t.shape[-1]) # I_add = npr.normal(0.0, 0.002, size=I.shape) # noisy_I = I + I_add init_fun, conv_net = stax.serial( Conv(32, (5, 5), (2, 2), padding="SAME"), BatchNorm(), Relu, Conv(10, (3, 3), (2, 2), padding="SAME"), Relu, Flatten, Dense(num_classes), LogSoftmax, ) _, key = random.split(random.PRNGKey(0)) class DataTopologyAE(AbstractProblem): def __init__(self): self.HPARAMS_PATH = "hparams.json"