def feature_extractor(rng, dim):
    """Feature extraction network.

    Builds a two-stage conv net and returns its initialized parameters
    together with the forward (apply) function.

    Args:
        rng: JAX PRNG key; one subkey is consumed for initialization.
        dim: size of the final dense embedding.

    Returns:
        (params, forward): initialized parameters and the stax apply function.
    """
    layers = [
        Conv(16, (8, 8), padding='SAME', strides=(2, 2)),
        Relu,
        MaxPool((2, 2), (1, 1)),
        Conv(32, (4, 4), padding='VALID', strides=(2, 2)),
        Relu,
        MaxPool((2, 2), (1, 1)),
        Flatten,
        Dense(dim),
    ]
    init_params, forward = stax.serial(*layers)
    # Split the incoming key; initialization consumes the first subkey.
    init_key, rng = random.split(rng)
    _, params = init_params(init_key, (-1, 28, 28, 1))
    return params, forward
def ResNet50(num_classes):
    """ResNet-50 classifier: conv stem, four residual stages, dense head.

    Args:
        num_classes: number of output classes for the final dense layer.

    Returns:
        A stax (init_fun, apply_fun) pair for the full network.
    """
    # (bottleneck widths, stage output channels, number of identity blocks)
    stages = [
        ([64, 64], 256, 2),
        ([128, 128], 512, 3),
        ([256, 256], 1024, 5),
        ([512, 512], 2048, 2),
    ]
    layers = [
        GeneralConv(("HWCN", "OIHW", "NHWC"), 64, (7, 7), (2, 2), "SAME"),
        BatchNorm(),
        Relu,
        MaxPool((3, 3), strides=(2, 2)),
    ]
    for stage_idx, (widths, out_channels, n_identity) in enumerate(stages):
        if stage_idx == 0:
            # The first stage keeps stride 1 (spatial size already reduced by the stem).
            layers.append(ConvBlock(3, widths + [out_channels], strides=(1, 1)))
        else:
            layers.append(ConvBlock(3, widths + [out_channels]))
        layers.extend(IdentityBlock(3, widths) for _ in range(n_identity))
    layers += [AvgPool((7, 7)), Flatten, Dense(num_classes), LogSoftmax]
    return stax.serial(*layers)
def __init__(self, pool_size: Tuple[int, int], padding: str = "valid") -> None:
    """Wrap a single stax MaxPool layer (NHWC layout).

    Args:
        pool_size: pooling window as (height, width).
        padding: "valid" or "same" (case-insensitive; upper-cased for stax).
    """
    pool = MaxPool(window_shape=pool_size, padding=padding.upper(), spec="NHWC")
    self.layer: List = [pool]
def conv():
    """Build a small MNIST-style CNN classifier.

    Returns:
        (init_params, predict): `init_params(rng)` produces the parameter
        pytree for 28x28x1 inputs; `predict(params, x)` returns 10 log-probs.
    """
    network = stax.serial(
        Conv(16, (8, 8), padding='SAME', strides=(2, 2)),
        Relu,
        MaxPool((2, 2), (1, 1)),
        Conv(32, (4, 4), padding='VALID', strides=(2, 2)),
        Relu,
        MaxPool((2, 2), (1, 1)),
        Flatten,
        Dense(32),
        Relu,
        Dense(10),
    )
    init_fun, predict = network

    def init_params(rng):
        # Drop the output shape; callers only need the parameters.
        _, params = init_fun(rng, (-1, 28, 28, 1))
        return params

    return init_params, predict
def ResNet(num_classes):
    """ResNet-50-style network built from project conv/identity blocks.

    Args:
        num_classes: number of output classes.

    Returns:
        A stax (init_fun, apply_fun) pair for the full network.
    """
    blocks = [
        GeneralConv(('HWCN', 'OIHW', 'NHWC'), 64, (7, 7), (2, 2), 'SAME'),
        BatchNorm(),
        Relu,
        MaxPool((3, 3), strides=(2, 2)),
    ]
    # Each stage: one projection conv block followed by `reps` identity blocks.
    for width, out_channels, reps in (
        (64, 256, 2),
        (128, 512, 3),
        (256, 1024, 5),
        (512, 2048, 2),
    ):
        blocks.append(convBlock(3, [width, width, out_channels]))
        blocks.extend(identityBlock(3, [width, width]) for _ in range(reps))
    blocks += [AvgPool((7, 7)), Flatten, Dense(num_classes), LogSoftmax]
    return stax.serial(*blocks)
def LeNet5(batch_size, num_particles):
    """LeNet-5 architecture wrapped via the project's `make_model`.

    Args:
        batch_size: forwarded to `_input_shape` to build the input spec.
        num_particles: forwarded to `make_model` (ensemble size — TODO confirm).

    Returns:
        Whatever `make_model` returns for this stax network.
    """
    shape = _input_shape(batch_size)
    net = stax.serial(
        # NCHW input transposed to NHWC by the GeneralConv spec.
        GeneralConv(('NCHW', 'OIHW', 'NHWC'), out_chan=6, filter_shape=(5, 5),
                    strides=(1, 1), padding="VALID"),
        Relu,
        MaxPool(window_shape=(2, 2), strides=(2, 2), padding="VALID"),
        Conv(out_chan=16, filter_shape=(5, 5), strides=(1, 1), padding="SAME"),
        Relu,
        MaxPool(window_shape=(2, 2), strides=(2, 2), padding="SAME"),
        Conv(out_chan=120, filter_shape=(5, 5), strides=(1, 1), padding="VALID"),
        Relu,
        MaxPool(window_shape=(2, 2), strides=(2, 2), padding="SAME"),
        Flatten,
        Dense(84),
        Relu,
        Dense(10),
        LogSoftmax,
    )
    return make_model(net, shape, num_particles)
def ResNet(num_classes):
    """Miniature residual network: stem, one conv+identity stage, dense head.

    Args:
        num_classes: number of output classes.

    Returns:
        A stax (init_fun, apply_fun) pair.
    """
    stem = [
        GeneralConv(("HWCN", "OIHW", "NHWC"), 64, (7, 7), (2, 2), "SAME"),
        BatchNorm(),
        Relu,
        MaxPool((3, 3), strides=(2, 2)),
    ]
    body = [
        ConvBlock(3, [4, 4, 4], strides=(1, 1)),
        IdentityBlock(3, [4, 4]),
    ]
    head = [
        AvgPool((3, 3)),
        Flatten,
        Dense(num_classes),
        LogSoftmax,
    ]
    return stax.serial(*stem, *body, *head)
def loss(params, batch):
    """Negative mean log-likelihood of the true class.

    `preds * targets` with one-hot `targets` selects the predicted
    log-probability of the correct class (network ends in LogSoftmax).
    """
    inputs, targets = batch
    preds = predict(params, inputs)
    return -np.mean(np.sum(preds * targets, axis=1))


def accuracy(params, batch):
    """Fraction of examples whose argmax prediction matches the one-hot target."""
    inputs, targets = batch
    target_class = np.argmax(targets, axis=1)
    predicted_class = np.argmax(predict(params, inputs), axis=1)
    return np.mean(predicted_class == target_class)


# Small CNN: conv -> Activator -> 4x4 max-pool -> dense(10) -> log-softmax.
# NOTE(review): `Activator` is defined elsewhere in the project — verify its import.
init_random_params, predict = stax.serial(
    Conv(10, (5, 5), (1, 1)),
    Activator,
    MaxPool((4, 4)),
    Flatten,
    Dense(10),
    LogSoftmax)

if __name__ == "__main__":
    rng = random.PRNGKey(0)
    step_size = 0.001
    num_epochs = 10
    batch_size = 128
    momentum_mass = 0.9
    # input shape for CNN
    input_shape = (-1, 28, 28, 1)
    # training/test split
    # NOTE(review): the training loop presumably follows; not visible in this chunk.
    (train_images, train_labels), (test_images, test_labels) = mnist_data.tiny_mnist(flatten=False)
def loss(params, batch):
    """Negative mean log-likelihood of the true class.

    `preds * targets` with one-hot `targets` selects the predicted
    log-probability of the correct class (network ends in LogSoftmax).
    """
    inputs, targets = batch
    preds = predict(params, inputs)
    return -np.mean(np.sum(preds * targets, axis=1))


def accuracy(params, batch):
    """Fraction of examples whose argmax prediction matches the one-hot target."""
    inputs, targets = batch
    target_class = np.argmax(targets, axis=1)
    predicted_class = np.argmax(predict(params, inputs), axis=1)
    return np.mean(predicted_class == target_class)


# Small CNN: conv -> Activator -> 4x4 max-pool -> dense(24) -> log-softmax.
# NOTE(review): `Activator` is defined elsewhere in the project — verify its import.
init_random_params, predict = stax.serial(Conv(10, (5, 5), (1, 1)),
                                          Activator,
                                          MaxPool((4, 4)),
                                          Flatten,
                                          Dense(24),
                                          LogSoftmax)

if __name__ == "__main__":
    rng = random.PRNGKey(0)
    step_size = 0.001
    num_epochs = 10
    batch_size = 128
    momentum_mass = 0.9
    # input shape for CNN
    input_shape = (-1, 28, 28, 1)
    # training/test split
    # NOTE(review): snippet is truncated mid-statement here.
    (train_images,
def ResNet50(num_classes, batchnorm=True, parameterization='standard', nonlinearity='relu'):
    """ResNet-50 built from project conv/identity block combinators.

    Args:
        num_classes: size of the final dense layer.
        batchnorm: if True, insert BatchNorm after the stem conv (and the flag
            is forwarded into every residual block).
        parameterization: 'standard' (plain stax layers) or 'ntk'
            (neural-tangents layers with the kernel_fn dropped).
        nonlinearity: one of 'relu', 'swish', 'swishten', 'softplus'.

    Returns:
        A (init_fun, apply_fun) pair from `jax_stax.serial`.

    Raises:
        ValueError: on an unrecognized `parameterization` or `nonlinearity`.
            (Previously an unknown value fell through silently and surfaced
            later as a confusing NameError.)
    """
    # Layer constructors for the chosen parameterization.
    if parameterization == 'standard':
        def MyGeneralConv(*args, **kwargs):
            return GeneralConv(*args, **kwargs)

        def MyDense(*args, **kwargs):
            return Dense(*args, **kwargs)
    elif parameterization == 'ntk':
        # NOTE(review): relies on a private neural-tangents helper
        # (`stax._GeneralConv`); the [:2] keeps (init_fun, apply_fun) and
        # drops the kernel_fn so the layers compose with plain stax.
        def MyGeneralConv(*args, **kwargs):
            return stax._GeneralConv(*args, **kwargs)[:2]

        def MyDense(*args, **kwargs):
            return stax.Dense(*args, **kwargs)[:2]
    else:
        raise ValueError(f"unknown parameterization: {parameterization!r}")

    # Nonlinearity dispatch table instead of an if/elif chain.
    nonlinearities = {
        'relu': Relu,
        'swish': Swish,
        'swishten': Swishten,
        'softplus': Softplus,
    }
    try:
        nonlin = nonlinearities[nonlinearity]
    except KeyError:
        raise ValueError(f"unknown nonlinearity: {nonlinearity!r}") from None

    # Keyword bundle shared by every residual block.
    blk = dict(batchnorm=batchnorm, parameterization=parameterization,
               nonlin=nonlin)

    return jax_stax.serial(
        MyGeneralConv(('NHWC', 'HWIO', 'NHWC'), 64, (7, 7), strides=(2, 2),
                      padding='SAME'),
        BatchNorm() if batchnorm else Identity,
        nonlin,
        MaxPool((3, 3), strides=(2, 2)),
        # Stage 1: 3 blocks, 256 output channels.
        ConvBlock(3, [64, 64, 256], strides=(1, 1), **blk),
        IdentityBlock(3, [64, 64], **blk),
        IdentityBlock(3, [64, 64], **blk),
        # Stage 2: 4 blocks, 512 output channels.
        ConvBlock(3, [128, 128, 512], **blk),
        IdentityBlock(3, [128, 128], **blk),
        IdentityBlock(3, [128, 128], **blk),
        IdentityBlock(3, [128, 128], **blk),
        # Stage 3: 6 blocks, 1024 output channels.
        ConvBlock(3, [256, 256, 1024], **blk),
        IdentityBlock(3, [256, 256], **blk),
        IdentityBlock(3, [256, 256], **blk),
        IdentityBlock(3, [256, 256], **blk),
        IdentityBlock(3, [256, 256], **blk),
        IdentityBlock(3, [256, 256], **blk),
        # Stage 4: 3 blocks, 2048 output channels.
        ConvBlock(3, [512, 512, 2048], **blk),
        IdentityBlock(3, [512, 512], **blk),
        IdentityBlock(3, [512, 512], **blk),
        # GlobalAvgPool from neural-tangents: drop the kernel_fn here too.
        stax.GlobalAvgPool()[:-1],
        MyDense(num_classes))
def loss(params, batch, rng):
    """Negative mean log-likelihood of the true class.

    `rng` is forwarded to `predict` — needed because the network contains a
    stochastic Dropout layer.
    """
    inputs, targets = batch
    preds = predict(params, inputs, rng=rng)
    return -np.mean(np.sum(preds * targets, axis=1))


def accuracy(params, batch, rng):
    """Fraction of examples whose argmax prediction matches the one-hot target."""
    inputs, targets = batch
    target_class = np.argmax(targets, axis=1)
    predicted_class = np.argmax(predict(params, inputs, rng=rng), axis=1)
    return np.mean(predicted_class == target_class)


# Small CNN with dropout between the activation and the pooling layer.
# NOTE(review): `Activator` and `dropout_rate` are defined elsewhere in the
# project — verify their imports/values.
init_random_params, predict = stax.serial(Conv(10, (5, 5), (1, 1)),
                                          Activator,
                                          Dropout(dropout_rate),
                                          MaxPool((4, 4)),
                                          Flatten,
                                          Dense(10),
                                          LogSoftmax)

if __name__ == "__main__":
    rng = random.PRNGKey(0)
    step_size = 0.001
    num_epochs = 10
    batch_size = 128
    momentum_mass = 0.9
    # input shape for CNN
    input_shape = (-1, 28, 28, 1)
    # training/test split
    # NOTE(review): snippet is truncated mid-statement here.
    (train_images,