Example #1
0
def feature_extractor(rng, dim, input_shape=(-1, 28, 28, 1)):
  """Build and initialise a small convolutional feature-extraction network.

  Architecture: Conv(16, 8x8, stride 2, SAME) -> Relu -> MaxPool(2x2,
  stride 1) -> Conv(32, 4x4, stride 2, VALID) -> Relu -> MaxPool(2x2,
  stride 1) -> Flatten -> Dense(dim).

  Args:
    rng: PRNG key. It is split once and the first subkey initialises the
      parameters. No updated key is returned, so callers that need more
      randomness should split their own key.
    dim: Width of the final ``Dense`` layer, i.e. the feature size.
    input_shape: Shape handed to the stax initialiser. Defaults to
      ``(-1, 28, 28, 1)``, the value that used to be hard-coded here.

  Returns:
    ``(params, forward)``: the initialised parameters and the stax apply
    function, called as ``forward(params, inputs)``.
  """
  init_params, forward = stax.serial(
    Conv(16, (8, 8), padding='SAME', strides=(2, 2)),
    Relu,
    MaxPool((2, 2), (1, 1)),
    Conv(32, (4, 4), padding='VALID', strides=(2, 2)),
    Relu,
    MaxPool((2, 2), (1, 1)),
    Flatten,
    Dense(dim),
  )
  # Only the first subkey is used. Previously the second one was bound to a
  # local ``rng`` that was never read; keeping the split keeps the exact
  # same initialisation key (and therefore the same parameters) as before.
  init_key, _ = random.split(rng)
  # The initialiser returns a pair; its second element is the parameters.
  _, params = init_params(init_key, input_shape)
  return params, forward
Example #2
0
def ResNet50(num_classes):
    """Assemble a ResNet-50 classifier with ``stax.serial``.

    Layout:
    - Stem: a 7x7 / stride-2 ``GeneralConv`` with dimension numbers
      ``("HWCN", "OIHW", "NHWC")``, then ``BatchNorm``, ``Relu`` and a
      3x3 / stride-2 max pool.
    - Four residual stages. Each is one ``ConvBlock`` followed by 2, 3, 5
      and 2 ``IdentityBlock``s respectively.
    - Head: a 7x7 average pool, ``Flatten``, ``Dense(num_classes)`` and
      ``LogSoftmax``.

    Returns:
        The pair produced by ``stax.serial`` for the assembled layers.
    """
    # Per stage: (ConvBlock filters, IdentityBlock filters, IdentityBlock count).
    stage_specs = (
        ((64, 64, 256), (64, 64), 2),
        ((128, 128, 512), (128, 128), 3),
        ((256, 256, 1024), (256, 256), 5),
        ((512, 512, 2048), (512, 512), 2),
    )

    layers = [
        GeneralConv(("HWCN", "OIHW", "NHWC"), 64, (7, 7), (2, 2), "SAME"),
        BatchNorm(),
        Relu,
        MaxPool((3, 3), strides=(2, 2)),
    ]
    for stage, (conv_filters, id_filters, repeats) in enumerate(stage_specs):
        if stage == 0:
            # Only the first ConvBlock is given explicit unit strides; later
            # ones fall back to ConvBlock's own default.
            layers.append(ConvBlock(3, list(conv_filters), strides=(1, 1)))
        else:
            layers.append(ConvBlock(3, list(conv_filters)))
        for _ in range(repeats):
            # Fresh list per block, exactly as the original literals were.
            layers.append(IdentityBlock(3, list(id_filters)))

    layers.extend([AvgPool((7, 7)), Flatten, Dense(num_classes), LogSoftmax])
    return stax.serial(*layers)
Example #3
0
 def __init__(self,
              pool_size: Tuple[int, int],
              padding: str = "valid") -> None:
     """Create the single stax ``MaxPool`` layer backing this object.

     Args:
         pool_size: Pooling window, forwarded as ``window_shape``.
         padding: Padding mode string. It is upper-cased before being
             handed to ``MaxPool`` (so ``"valid"`` becomes ``"VALID"``).

     The layer uses ``spec="NHWC"`` and is stored as the only element of
     ``self.layer``.
     """
     # NOTE(review): no ``strides`` are forwarded, so MaxPool's own default
     # stride applies rather than one derived from ``pool_size`` -- confirm
     # this is intended.
     upper_padding = padding.upper()
     pool = MaxPool(window_shape=pool_size, padding=upper_padding, spec="NHWC")
     self.layer: List = [pool]
Example #4
0
def conv():
    """Return ``(init_params, predict)`` for a small CNN on 28x28x1 inputs.

    ``init_params(rng)`` initialises the network for input shape
    ``(-1, 28, 28, 1)`` and returns only the parameters. ``predict`` is the
    apply function produced by ``stax.serial``. The network ends in a
    10-unit ``Dense`` layer.
    """
    image_shape = (-1, 28, 28, 1)

    feature_stack = [
        Conv(16, (8, 8), padding='SAME', strides=(2, 2)),
        Relu,
        MaxPool((2, 2), (1, 1)),
        Conv(32, (4, 4), padding='VALID', strides=(2, 2)),
        Relu,
        MaxPool((2, 2), (1, 1)),
        Flatten,
    ]
    head = [Dense(32), Relu, Dense(10)]
    init_fun, predict = stax.serial(*feature_stack, *head)

    def init_params(rng):
        # The initialiser's second result holds the parameters.
        _output_shape, params = init_fun(rng, image_shape)
        return params

    return init_params, predict
Example #5
0
def ResNet(num_classes):
    """Build a ResNet-50-style classifier from ``convBlock``/``identityBlock``.

    The stem is a 7x7 / stride-2 ``GeneralConv`` with dimension numbers
    ``('HWCN', 'OIHW', 'NHWC')``, ``BatchNorm``, ``Relu`` and a 3x3 /
    stride-2 max pool. It is followed by four stages with 2, 3, 5 and 2
    identity blocks, then a 7x7 average pool and a
    ``Dense(num_classes)`` + ``LogSoftmax`` head.

    Returns:
        The pair produced by ``stax.serial``.
    """
    def stage(conv_filters, identity_filters, num_identity):
        # One convBlock followed by ``num_identity`` identityBlocks; each call
        # receives its own fresh list, as the original literals did.
        blocks = [convBlock(3, list(conv_filters))]
        blocks += [identityBlock(3, list(identity_filters))
                   for _ in range(num_identity)]
        return blocks

    stem = [
        GeneralConv(('HWCN', 'OIHW', 'NHWC'), 64, (7, 7), (2, 2), 'SAME'),
        BatchNorm(),
        Relu,
        MaxPool((3, 3), strides=(2, 2)),
    ]
    body = (stage([64, 64, 256], [64, 64], 2)
            + stage([128, 128, 512], [128, 128], 3)
            + stage([256, 256, 1024], [256, 256], 5)
            + stage([512, 512, 2048], [512, 512], 2))
    head = [AvgPool((7, 7)), Flatten, Dense(num_classes), LogSoftmax]
    return stax.serial(*stem, *body, *head)
Example #6
0
def LeNet5(batch_size, num_particles):
    """Build a LeNet-5-style CNN and hand it to ``make_model``.

    The first layer is a ``GeneralConv`` with dimension numbers
    ``('NCHW', 'OIHW', 'NHWC')``. Presumably it takes NCHW inputs and emits
    NHWC activations for the following ``Conv`` layers; confirm against
    ``_input_shape``. Convolution/pool paddings mix VALID and SAME exactly
    as listed below. The head is ``Flatten``, ``Dense(84)``, ``Relu``,
    ``Dense(10)`` and ``LogSoftmax``.

    Args:
        batch_size: Forwarded to ``_input_shape`` to build the input shape.
        num_particles: Forwarded unchanged to ``make_model``.

    Returns:
        Whatever ``make_model(network, input_shape, num_particles)`` returns.
    """
    input_shape = _input_shape(batch_size)

    feature_layers = [
        GeneralConv(('NCHW', 'OIHW', 'NHWC'), out_chan=6,
                    filter_shape=(5, 5), strides=(1, 1), padding="VALID"),
        Relu,
        MaxPool(window_shape=(2, 2), strides=(2, 2), padding="VALID"),
        Conv(out_chan=16, filter_shape=(5, 5), strides=(1, 1),
             padding="SAME"),
        Relu,
        MaxPool(window_shape=(2, 2), strides=(2, 2), padding="SAME"),
        Conv(out_chan=120, filter_shape=(5, 5), strides=(1, 1),
             padding="VALID"),
        Relu,
        MaxPool(window_shape=(2, 2), strides=(2, 2), padding="SAME"),
    ]
    classifier_layers = [Flatten, Dense(84), Relu, Dense(10), LogSoftmax]

    network = stax.serial(*feature_layers, *classifier_layers)
    return make_model(network, input_shape, num_particles)
Example #7
0
def ResNet(num_classes):
    """Build a miniature ResNet: a stem, one residual stage, and a head.

    - Stem: 7x7 / stride-2 ``GeneralConv`` (dimension numbers
      ``("HWCN", "OIHW", "NHWC")``), ``BatchNorm``, ``Relu``, 3x3 / stride-2
      max pool.
    - Residual stage: one ``ConvBlock`` with 4-channel filters and unit
      strides, then one ``IdentityBlock``.
    - Head: 3x3 average pool, ``Flatten``, ``Dense(num_classes)``,
      ``LogSoftmax``.

    Returns:
        The pair produced by ``stax.serial``.
    """
    stem = (
        GeneralConv(("HWCN", "OIHW", "NHWC"), 64, (7, 7), (2, 2), "SAME"),
        BatchNorm(),
        Relu,
        MaxPool((3, 3), strides=(2, 2)),
    )
    residual = (
        ConvBlock(3, [4, 4, 4], strides=(1, 1)),
        IdentityBlock(3, [4, 4]),
    )
    head = (AvgPool((3, 3)), Flatten, Dense(num_classes), LogSoftmax)
    return stax.serial(*stem, *residual, *head)
def loss(params, batch):
    """Return the mean negative log-likelihood of ``batch`` under the model.

    ``batch`` is an ``(inputs, targets)`` pair. ``predict`` is the
    module-level apply function; the model built in this script ends in
    ``LogSoftmax``, so its outputs are log-probabilities. ``targets`` are
    presumably one-hot rows, so the row-wise sum selects the
    log-probability of the true class.
    """
    inputs, targets = batch
    log_probs = predict(params, inputs)
    true_class_log_prob = np.sum(log_probs * targets, axis=1)
    return -np.mean(true_class_log_prob)


def accuracy(params, batch):
    """Return the fraction of examples whose arg-max prediction is correct.

    ``batch`` is an ``(inputs, targets)`` pair. The true class of each row
    is the arg-max of ``targets`` along axis 1, and the predicted class is
    the arg-max of ``predict(params, inputs)`` along axis 1.
    """
    inputs, targets = batch
    labels = np.argmax(targets, axis=1)
    scores = predict(params, inputs)
    correct = labels == np.argmax(scores, axis=1)
    return np.mean(correct)


# Model: a 5x5 convolution with 10 output channels, the ``Activator``
# nonlinearity (defined outside this excerpt), 4x4 max pooling, then a
# 10-unit dense layer with log-softmax outputs. ``loss`` and ``accuracy``
# use the resulting ``predict`` apply function.
init_random_params, predict = stax.serial(
    Conv(10, (5, 5), (1, 1)),
    Activator,
    MaxPool((4, 4)),
    Flatten,
    Dense(10),
    LogSoftmax,
)

# Script entry point: seed the RNG, set hyper-parameters and load the data.
# NOTE(review): the rest of the training script is not part of this excerpt;
# the names bound below are presumably consumed further down.
if __name__ == "__main__":
    # Fixed seed so runs are reproducible.
    rng = random.PRNGKey(0)

    # Training hyper-parameters.
    step_size = 0.001
    num_epochs = 10
    batch_size = 128
    # Presumably the momentum coefficient for a momentum-based optimiser.
    momentum_mass = 0.9

    # input shape for CNN
    input_shape = (-1, 28, 28, 1)

    # training/test split
    # flatten=False presumably keeps images in image form for the CNN above.
    (train_images, train_labels), (test_images, test_labels) = mnist_data.tiny_mnist(flatten=False)
Example #9
0
def loss(params, batch):
    """Compute the average negative log-likelihood over a batch.

    ``batch`` unpacks to ``(inputs, targets)``. The model defined in this
    script ends in ``LogSoftmax``, so ``predict`` yields log-probabilities.
    Multiplying by ``targets`` (presumably one-hot) and summing each row
    keeps only the true-class term.
    """
    inputs, targets = batch
    nll_per_example = -np.sum(predict(params, inputs) * targets, axis=1)
    return np.mean(nll_per_example)


def accuracy(params, batch):
    """Compute classification accuracy over a batch.

    ``batch`` unpacks to ``(inputs, targets)``. A row counts as correct
    when the arg-max of the model output along axis 1 equals the arg-max of
    the corresponding ``targets`` row.
    """
    inputs, targets = batch
    expected = np.argmax(targets, axis=1)
    observed = np.argmax(predict(params, inputs), axis=1)
    matches = observed == expected
    return np.mean(matches)


# Model: a 5x5 convolution with 10 output channels, the ``Activator``
# nonlinearity (defined outside this excerpt), 4x4 max pooling, then a
# 24-unit dense layer with log-softmax outputs.
# NOTE(review): 24 output units presumably matches this dataset's number of
# classes -- confirm it agrees with the width of the one-hot targets.
init_random_params, predict = stax.serial(
    Conv(10, (5, 5), (1, 1)),
    Activator,
    MaxPool((4, 4)),
    Flatten,
    Dense(24),
    LogSoftmax,
)

# Script entry point: seed the RNG, set hyper-parameters and load the data.
if __name__ == "__main__":
    # Fixed seed so runs are reproducible.
    rng = random.PRNGKey(0)

    # Training hyper-parameters.
    step_size = 0.001
    num_epochs = 10
    batch_size = 128
    # Presumably the momentum coefficient for a momentum-based optimiser.
    momentum_mass = 0.9

    # input shape for CNN
    input_shape = (-1, 28, 28, 1)

    # training/test split
    # NOTE(review): this statement is truncated in this excerpt.
    (train_images,
Example #10
0
def ResNet50(num_classes,
             batchnorm=True,
             parameterization='standard',
             nonlinearity='relu'):
    """Build a configurable ResNet-50 with ``jax_stax.serial``.

    Layer stack:
    - a 7x7 / stride-2 stem convolution (``('NHWC', 'HWIO', 'NHWC')``);
    - ``BatchNorm()`` or ``Identity``, then the chosen nonlinearity;
    - a 3x3 / stride-2 max pool;
    - four residual stages, each one ``ConvBlock`` followed by 2, 3, 5 and 2
      ``IdentityBlock``s respectively;
    - ``stax.GlobalAvgPool`` and a final dense layer of width
      ``num_classes``. No softmax is applied.

    Args:
      num_classes: Output width of the final dense layer.
      batchnorm: If true the stem uses ``BatchNorm()``, otherwise
        ``Identity``. Also forwarded to every residual block.
      parameterization: ``'standard'`` builds the stem conv and the final
        dense layer with ``GeneralConv``/``Dense``. ``'ntk'`` builds them
        with ``stax._GeneralConv``/``stax.Dense`` truncated to their first
        two elements. Also forwarded to every residual block.
      nonlinearity: One of ``'relu'``, ``'swish'``, ``'swishten'`` or
        ``'softplus'``.

    Returns:
      Whatever ``jax_stax.serial`` returns for the assembled layers.

    Raises:
      ValueError: If ``parameterization`` or ``nonlinearity`` is not one of
        the supported values.
    """
    # Define layer constructors for the stem convolution and the head.
    if parameterization == 'standard':

        def MyGeneralConv(*args, **kwargs):
            return GeneralConv(*args, **kwargs)

        def MyDense(*args, **kwargs):
            return Dense(*args, **kwargs)
    elif parameterization == 'ntk':
        # Keep only the first two elements of each layer tuple (presumably
        # init/apply) so they compose with ``jax_stax.serial``.

        def MyGeneralConv(*args, **kwargs):
            return stax._GeneralConv(*args, **kwargs)[:2]

        def MyDense(*args, **kwargs):
            return stax.Dense(*args, **kwargs)[:2]
    else:
        # Fail fast: otherwise MyGeneralConv/MyDense stay unbound and the
        # call below dies with an opaque UnboundLocalError.
        raise ValueError(
            f"Unknown parameterization {parameterization!r}; "
            "expected 'standard' or 'ntk'.")

    # Define nonlinearity
    if nonlinearity == 'relu':
        nonlin = Relu
    elif nonlinearity == 'swish':
        nonlin = Swish
    elif nonlinearity == 'swishten':
        nonlin = Swishten
    elif nonlinearity == 'softplus':
        nonlin = Softplus
    else:
        raise ValueError(
            f"Unknown nonlinearity {nonlinearity!r}; expected one of "
            "'relu', 'swish', 'swishten', 'softplus'.")

    # Keyword arguments shared by every residual block.
    block_kwargs = {
        'batchnorm': batchnorm,
        'parameterization': parameterization,
        'nonlin': nonlin,
    }
    # Per stage: (ConvBlock filters, IdentityBlock filters, IdentityBlock count).
    stages = (
        ((64, 64, 256), (64, 64), 2),
        ((128, 128, 512), (128, 128), 3),
        ((256, 256, 1024), (256, 256), 5),
        ((512, 512, 2048), (512, 512), 2),
    )

    layers = [
        MyGeneralConv(('NHWC', 'HWIO', 'NHWC'),
                      64, (7, 7),
                      strides=(2, 2),
                      padding='SAME'),
        BatchNorm() if batchnorm else Identity,
        nonlin,
        MaxPool((3, 3), strides=(2, 2)),
    ]
    for stage_index, (conv_filters, id_filters, num_id) in enumerate(stages):
        if stage_index == 0:
            # Only the first ConvBlock is given explicit unit strides; later
            # ones use ConvBlock's own default.
            conv_block = ConvBlock(3, list(conv_filters), strides=(1, 1),
                                   **block_kwargs)
        else:
            conv_block = ConvBlock(3, list(conv_filters), **block_kwargs)
        layers.append(conv_block)
        layers.extend(IdentityBlock(3, list(id_filters), **block_kwargs)
                      for _ in range(num_id))

    layers.append(stax.GlobalAvgPool()[:-1])
    layers.append(MyDense(num_classes))
    return jax_stax.serial(*layers)
def loss(params, batch, rng):
    """Return the mean negative log-likelihood of ``batch``.

    ``batch`` is an ``(inputs, targets)`` pair. ``rng`` is forwarded to
    ``predict`` as the ``rng`` keyword (the model built in this script
    contains a ``Dropout`` layer). The model ends in ``LogSoftmax``, and
    ``targets`` are presumably one-hot, so the row-wise sum picks out each
    example's true-class log-probability.
    """
    inputs, targets = batch
    log_probs = predict(params, inputs, rng=rng)
    per_example_ll = np.sum(log_probs * targets, axis=1)
    return -np.mean(per_example_ll)


def accuracy(params, batch, rng):
    """Return the fraction of examples whose arg-max prediction is correct.

    ``batch`` is an ``(inputs, targets)`` pair. The true class is the
    arg-max of ``targets`` along axis 1. ``rng`` is forwarded to
    ``predict`` as the ``rng`` keyword.

    NOTE(review): because ``rng`` is passed through, accuracy is measured
    with whatever the ``Dropout`` layer in this script's model does when
    given a key -- confirm that evaluating with dropout active is intended.
    """
    inputs, targets = batch
    labels = np.argmax(targets, axis=1)
    scores = predict(params, inputs, rng=rng)
    correct = labels == np.argmax(scores, axis=1)
    return np.mean(correct)


# Model: a 5x5 convolution with 10 output channels, the ``Activator``
# nonlinearity and ``dropout_rate`` (both defined outside this excerpt),
# dropout, 4x4 max pooling, then a 10-unit dense layer with log-softmax
# outputs.
# NOTE(review): confirm whether stax ``Dropout``'s argument is a keep or a
# drop probability; the name ``dropout_rate`` suggests "drop".
init_random_params, predict = stax.serial(
    Conv(10, (5, 5), (1, 1)),
    Activator,
    Dropout(dropout_rate),
    MaxPool((4, 4)),
    Flatten,
    Dense(10),
    LogSoftmax,
)

# Script entry point: seed the RNG, set hyper-parameters and load the data.
if __name__ == "__main__":
    # Fixed seed so runs are reproducible.
    rng = random.PRNGKey(0)

    # Training hyper-parameters.
    step_size = 0.001
    num_epochs = 10
    batch_size = 128
    # Presumably the momentum coefficient for a momentum-based optimiser.
    momentum_mass = 0.9

    # input shape for CNN
    input_shape = (-1, 28, 28, 1)

    # training/test split
    # NOTE(review): this statement is truncated in this excerpt.
    (train_images,