def ResidualBlock(out_channels, kernel_size, stride, padding, input_format):
    """Build a residual block: conv -> Elu -> conv summed with its own input.

    The input is duplicated with ``FanOut(2)``; one copy runs through the
    two-convolution branch, the other through ``Identity``, and ``FanInSum``
    adds them back together. The serial combination is unpacked into
    ``Module``.

    Args:
        out_channels: Output channel count passed to both convolutions.
        kernel_size: Filter shape passed to both convolutions.
        stride: Stride passed to both convolutions.
        padding: Padding mode passed to both convolutions.
        input_format: Dimension-numbers argument passed to ``GeneralConv``.
    """

    def make_conv():
        # Both convolutions of the branch share exactly the same settings.
        return GeneralConv(input_format, out_channels, kernel_size, stride, padding)

    conv_branch = stax.serial(make_conv(), Elu, make_conv())
    summed = stax.serial(
        FanOut(2),
        stax.parallel(conv_branch, Identity),
        FanInSum,
    )
    return Module(*summed)
def ResNet(
    hidden_channels, out_channels, kernel_size, strides, padding, depth, input_format
):
    """Build a residual network wrapped in an outer skip connection.

    The inner stack is: an input convolution to ``hidden_channels``, ``depth``
    ``ResidualBlock`` layers, and an output convolution to ``out_channels``.
    The network input is fanned out to that stack and to ``Identity``, and
    the two results are merged by ``AddLastItem(1)``.

    Args:
        hidden_channels: Channel count used by the input conv and every block.
        out_channels: Channel count of the final convolution.
        kernel_size: Filter shape shared by all convolutions.
        strides: Strides shared by all convolutions.
        padding: Padding mode shared by all convolutions.
        depth: Number of ``ResidualBlock`` layers.
        input_format: Dimension-numbers argument passed to ``GeneralConv``.
    """
    layers = [GeneralConv(input_format, hidden_channels, kernel_size, strides, padding)]
    for _ in range(depth):
        layers.append(
            ResidualBlock(hidden_channels, kernel_size, strides, padding, input_format)
        )
    layers.append(GeneralConv(input_format, out_channels, kernel_size, strides, padding))
    trunk = stax.serial(*layers)

    with_skip = stax.serial(
        FanOut(2),
        stax.parallel(trunk, Identity),
        AddLastItem(1),
    )
    return Module(*with_skip)
def RevNet(
    hidden_channels, out_channels, kernel_size, strides, padding, depth, input_format
):
    """Build a stack of reversible blocks wrapped in an outer skip connection.

    The inner stack is: an input convolution to ``hidden_channels``, ``depth``
    ``ReversibleBlock`` layers, and an output convolution to ``out_channels``.
    The network input is fanned out to that stack and to ``Identity``, and
    the two results are merged by ``AddLastItem(1)``.

    Args:
        hidden_channels: Channel count used by the input conv and every block.
        out_channels: Channel count of the final convolution.
        kernel_size: Filter shape shared by the convolutions and the blocks.
        strides: Strides of the input and output convolutions.
        padding: Padding mode of the input and output convolutions.
        depth: Number of ``ReversibleBlock`` layers.
        input_format: Dimension-numbers argument passed to ``GeneralConv``
            and ``ReversibleBlock``.
    """
    # NOTE: a leading Split(input_format[0].lower().index("c")) layer was
    # disabled in the original and is intentionally still not part of the stack.
    layers = [GeneralConv(input_format, hidden_channels, kernel_size, strides, padding)]
    for _ in range(depth):
        layers.append(ReversibleBlock(hidden_channels, kernel_size, input_format))
    layers.append(GeneralConv(input_format, out_channels, kernel_size, strides, padding))
    trunk = stax.serial(*layers)

    with_skip = stax.serial(
        FanOut(2),
        stax.parallel(trunk, Identity),
        AddLastItem(1),
    )
    return Module(*with_skip)
def ResNet50(num_classes):
    """Assemble the ResNet50 layer stack as a single ``stax.serial`` model.

    The stem is a 7x7 stride-2 convolution (dimension numbers
    ``("HWCN", "OIHW", "NHWC")``), batch norm, ReLU and a 3x3 max pool. Four
    stages follow, each one ``ConvBlock`` plus a number of ``IdentityBlock``
    layers, and the head is average pooling, flatten, a dense layer with
    ``num_classes`` outputs and ``LogSoftmax``.

    Args:
        num_classes: Width of the final ``Dense`` layer.
    """
    # (base width, number of identity blocks, extra ConvBlock keyword args).
    # Only the first stage passes an explicit stride; the others rely on
    # ConvBlock's own default.
    stages = (
        (64, 2, {"strides": (1, 1)}),
        (128, 3, {}),
        (256, 5, {}),
        (512, 2, {}),
    )

    layers = [
        GeneralConv(("HWCN", "OIHW", "NHWC"), 64, (7, 7), (2, 2), "SAME"),
        BatchNorm(),
        Relu,
        MaxPool((3, 3), strides=(2, 2)),
    ]
    for width, num_identity, conv_kwargs in stages:
        # The ConvBlock's last filter count is four times the stage width.
        layers.append(ConvBlock(3, [width, width, 4 * width], **conv_kwargs))
        for _ in range(num_identity):
            layers.append(IdentityBlock(3, [width, width]))
    layers += [AvgPool((7, 7)), Flatten, Dense(num_classes), LogSoftmax]

    return stax.serial(*layers)
def constructDuelNetwork(n_actions, seed, input_shape):
    """Create a ``DDQN`` model with a dueling (advantage + state-value) head.

    A shared convolutional trunk feeds two dense streams: one with
    ``n_actions`` outputs (advantages) and one with a single output (state
    value). A jitted mapping function combines the two stream outputs into
    Q-values and is handed to ``DDQN`` as ``mappingFunction``.

    Args:
        n_actions: Number of actions; width of the advantage stream.
        seed: Seed forwarded to ``DDQN``.
        input_shape: Input shape forwarded to ``DDQN``.

    Returns:
        The constructed ``DDQN`` model.
    """

    def duelingNetworkMapping(inputs):
        # inputs[0] is the advantage stream output, inputs[1] the state-value
        # stream output (the order used in `parallel` below).
        advantages, state_value = inputs[0], inputs[1]
        # Per-row average advantage, shaped as a column so it broadcasts.
        mean_advantage = (jnp.sum(advantages, axis=1) / float(n_actions)).reshape(-1, 1)
        return state_value + (advantages - mean_advantage)

    advantage_head = stax.serial(Dense(512), Relu, Dense(n_actions))
    value_head = stax.serial(Dense(512), Relu, Dense(1))

    network = stax.serial(
        elementwise(lambda x: x / 255.0),
        GeneralConv(dim_nums, 32, (8, 8), strides=(4, 4)),
        Relu,
        GeneralConv(dim_nums, 64, (4, 4), strides=(2, 2)),
        Relu,
        GeneralConv(dim_nums, 64, (3, 3), strides=(1, 1)),
        Relu,
        Flatten,
        FanOut(2),
        parallel(advantage_head, value_head),
    )

    # Create the deep neural net.
    return DDQN(
        n_actions,
        input_shape,
        adam_params,
        architecture=network,
        seed=seed,
        mappingFunction=jit(duelingNetworkMapping),
    )
def constructSingleStreamNetwork(n_actions, seed, input_shape):
    """Create a ``DDQN`` model with a single-stream convolutional network.

    The network divides its input by 255, applies three convolution + ReLU
    stages, flattens, and finishes with ``Dense(1024)``, ReLU and a dense
    layer with ``n_actions`` outputs.

    Args:
        n_actions: Number of actions; width of the output layer.
        seed: Seed forwarded to ``DDQN``.
        input_shape: Input shape forwarded to ``DDQN``.

    Returns:
        The constructed ``DDQN`` model.
    """
    # (output channels, filter shape, strides) for each convolution.
    conv_specs = (
        (32, (8, 8), (4, 4)),
        (64, (4, 4), (2, 2)),
        (64, (3, 3), (1, 1)),
    )

    layers = [elementwise(lambda x: x / 255.0)]  # normalize
    for out_chan, filter_shape, conv_strides in conv_specs:
        layers.append(GeneralConv(dim_nums, out_chan, filter_shape, strides=conv_strides))
        layers.append(Relu)
    layers.extend([Flatten, Dense(1024), Relu, Dense(n_actions)])

    network = stax.serial(*layers)
    return DDQN(n_actions, input_shape, adam_params, architecture=network, seed=seed)
def ResNet(num_classes):
    """Assemble a four-stage residual network as one ``stax.serial`` model.

    Stem: 7x7 stride-2 convolution, batch norm, ReLU, 3x3 max pool. Each
    stage is one ``convBlock`` followed by several ``identityBlock`` layers.
    Head: 7x7 average pool, flatten, ``Dense(num_classes)``, ``LogSoftmax``.

    Args:
        num_classes: Width of the final ``Dense`` layer.
    """
    stem = [
        GeneralConv(('HWCN', 'OIHW', 'NHWC'), 64, (7, 7), (2, 2), 'SAME'),
        BatchNorm(),
        Relu,
        MaxPool((3, 3), strides=(2, 2)),
    ]

    body = []
    # (stage width, number of identity blocks after the conv block).
    for width, repeats in ((64, 2), (128, 3), (256, 5), (512, 2)):
        body.append(convBlock(3, [width, width, 4 * width]))
        body.extend(identityBlock(3, [width, width]) for _ in range(repeats))

    head = [AvgPool((7, 7)), Flatten, Dense(num_classes), LogSoftmax]

    return stax.serial(*stem, *body, *head)
def ResNet(num_classes):
    """Assemble a small residual network as one ``stax.serial`` model.

    Stem: 7x7 stride-2 convolution, batch norm, ReLU, 3x3 max pool. Body: a
    single ``ConvBlock`` and a single ``IdentityBlock``, both 4 filters wide.
    Head: 3x3 average pool, flatten, ``Dense(num_classes)``, ``LogSoftmax``.

    Args:
        num_classes: Width of the final ``Dense`` layer.
    """
    stem = (
        GeneralConv(("HWCN", "OIHW", "NHWC"), 64, (7, 7), (2, 2), "SAME"),
        BatchNorm(),
        Relu,
        MaxPool((3, 3), strides=(2, 2)),
    )
    body = (
        ConvBlock(3, [4, 4, 4], strides=(1, 1)),
        IdentityBlock(3, [4, 4]),
    )
    head = (AvgPool((3, 3)), Flatten, Dense(num_classes), LogSoftmax)

    return stax.serial(*stem, *body, *head)
def LeNet5(num_classes):
    """Assemble a two-convolution classifier as one ``stax.serial`` model.

    Two convolution -> batch norm -> ReLU -> 3x3 average-pool stages are
    followed by a flatten and three dense layers of width
    ``num_classes * 10``, ``num_classes * 5`` and ``num_classes``, ending in
    ``LogSoftmax``. No activation is placed between the dense layers.

    Args:
        num_classes: Width of the final ``Dense`` layer; the two preceding
            dense layers are 10x and 5x this width.
    """
    first_stage = [
        GeneralConv(('HWCN', 'OIHW', 'NHWC'), 64, (7, 7), (2, 2), 'SAME'),
        BatchNorm(),
        Relu,
        AvgPool((3, 3)),
    ]
    second_stage = [
        Conv(16, (5, 5), strides=(1, 1), padding="SAME"),
        BatchNorm(),
        Relu,
        AvgPool((3, 3)),
    ]
    classifier = [
        Flatten,
        Dense(num_classes * 10),
        Dense(num_classes * 5),
        Dense(num_classes),
        LogSoftmax,
    ]
    return stax.serial(*first_stage, *second_stage, *classifier)
def _create_network_architecture(self, action_dim):
    """Build the network and return its stax ``(init, predict)`` pair.

    The stack is one 8x8 stride-4 convolution with 32 filters, ReLU, flatten,
    ``Dense(512)``, ReLU and a dense layer with ``action_dim`` outputs.

    Args:
        action_dim: Width of the output layer.

    Returns:
        The two functions produced by ``stax.serial``.
    """
    # NOTE: input normalization (x / 255.0) and two further convolution + ReLU
    # stages (64 filters, 4x4 stride 2 and 3x3 stride 1) are disabled here.
    conv_dimension_numbers = ('NHWC', 'HWIO', 'NHWC')
    layers = [
        GeneralConv(conv_dimension_numbers, 32, (8, 8), strides=(4, 4)),
        Relu,
        Flatten,  # flatten output
        Dense(512),
        Relu,
        Dense(action_dim),
    ]
    init_fn, predict_fn = stax.serial(*layers)
    return init_fn, predict_fn
def LeNet5(batch_size, num_particles):
    """Build a three-convolution network and wrap it with ``make_model``.

    Each convolution is followed by ReLU and a 2x2 stride-2 max pool; the
    result is flattened and passed through ``Dense(84)``, ReLU, ``Dense(10)``
    and ``LogSoftmax``.

    Args:
        batch_size: Forwarded to ``_input_shape`` to obtain the input shape.
        num_particles: Forwarded unchanged to ``make_model``.

    Returns:
        Whatever ``make_model`` returns for the network, input shape and
        particle count.
    """
    input_shape = _input_shape(batch_size)

    feature_layers = [
        GeneralConv(
            ('NCHW', 'OIHW', 'NHWC'),
            out_chan=6,
            filter_shape=(5, 5),
            strides=(1, 1),
            padding="VALID",
        ),
        Relu,
        MaxPool(window_shape=(2, 2), strides=(2, 2), padding="VALID"),
        Conv(out_chan=16, filter_shape=(5, 5), strides=(1, 1), padding="SAME"),
        Relu,
        MaxPool(window_shape=(2, 2), strides=(2, 2), padding="SAME"),
        Conv(out_chan=120, filter_shape=(5, 5), strides=(1, 1), padding="VALID"),
        Relu,
        MaxPool(window_shape=(2, 2), strides=(2, 2), padding="SAME"),
    ]
    classifier_layers = [Flatten, Dense(84), Relu, Dense(10), LogSoftmax]

    network = stax.serial(*feature_layers, *classifier_layers)
    return make_model(network, input_shape, num_particles)
def MyGeneralConv(*args, **kwargs):
    """Forward every positional and keyword argument to ``GeneralConv``.

    Returns:
        Exactly what ``GeneralConv`` returns for the same arguments.
    """
    layer = GeneralConv(*args, **kwargs)
    return layer
def ConvBlock(out_channels, kernel_size, input_format):
    """Build a conv -> Elu -> conv stack as a ``stax.serial`` combination.

    Both convolutions use stride ``1`` and ``"SAME"`` padding.

    Args:
        out_channels: Output channel count of both convolutions.
        kernel_size: Filter shape of both convolutions.
        input_format: Dimension-numbers argument passed to ``GeneralConv``.
    """

    def same_conv():
        # Stride 1 with "SAME" padding, identical for both layers.
        return GeneralConv(input_format, out_channels, kernel_size, 1, "SAME")

    first_conv = same_conv()
    second_conv = same_conv()
    return stax.serial(first_conv, Elu, second_conv)