def gen_cnn_conv4(output_units=10,
                  W_initializers_str='glorot_normal()',
                  b_initializers_str='normal()'):
    # This is an up-scaled version of the CNN in the Keras tutorial: https://keras.io/examples/cifar10_cnn/
    return stax.serial(
        stax.Conv(out_chan=64,
                  filter_shape=(3, 3),
                  W_init=eval(W_initializers_str),
                  b_init=eval(b_initializers_str)), stax.Relu,
        stax.Conv(out_chan=64,
                  filter_shape=(3, 3),
                  W_init=eval(W_initializers_str),
                  b_init=eval(b_initializers_str)), stax.Relu,
        stax.MaxPool((2, 2), strides=(2, 2)),
        stax.Conv(out_chan=128,
                  filter_shape=(3, 3),
                  W_init=eval(W_initializers_str),
                  b_init=eval(b_initializers_str)), stax.Relu,
        stax.Conv(out_chan=128,
                  filter_shape=(3, 3),
                  W_init=eval(W_initializers_str),
                  b_init=eval(b_initializers_str)), stax.Relu,
        stax.MaxPool((2, 2), strides=(2, 2)), stax.Flatten,
        stax.Dense(512,
                   W_init=eval(W_initializers_str),
                   b_init=eval(b_initializers_str)), stax.Relu,
        stax.Dense(output_units,
                   W_init=eval(W_initializers_str),
                   b_init=eval(b_initializers_str)))
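A quick usage sketch (an addition, not from the source): the initializer strings are resolved with eval(), so glorot_normal and normal from jax.nn.initializers must be importable in the module where gen_cnn_conv4 is defined; the input shape below assumes NHWC CIFAR-10 images.

from jax import random
from jax.nn.initializers import glorot_normal, normal  # resolved via the eval() calls

init_fun, apply_fun = gen_cnn_conv4(output_units=10)
out_shape, params = init_fun(random.PRNGKey(0), (-1, 32, 32, 3))  # NHWC CIFAR-10
logits = apply_fun(params, random.normal(random.PRNGKey(1), (8, 32, 32, 3)))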
Example 2
def MakeMain(input_shape):
    # The number of output channels depends on the number of input channels;
    # filters1, filters2 and ks are captured from the enclosing block builder
    # (compare ConvBlock in Example 7).
    return stax.serial(stax.Conv(filters1, (1, 1)), stax.BatchNorm(),
                       stax.Relu,
                       stax.Conv(filters2, (ks, ks), padding='SAME'),
                       stax.BatchNorm(), stax.Relu,
                       stax.Conv(input_shape[3], (1, 1)), stax.BatchNorm())
Example 3
def WideResnetBlock(channels, strides=(1, 1), channel_mismatch=False):
    """WideResnet convolutational block."""
    main = stax.serial(stax.BatchNorm(), stax.Relu,
                       stax.Conv(channels, (3, 3), strides, padding='SAME'),
                       stax.BatchNorm(), stax.Relu,
                       stax.Conv(channels, (3, 3), padding='SAME'))
    shortcut = stax.Identity if not channel_mismatch else stax.Conv(
        channels, (3, 3), strides, padding='SAME')
    return stax.serial(stax.FanOut(2), stax.parallel(main, shortcut),
                       stax.FanInSum)
Example 4
def ConvBlock(self, kernel_size, filters, strides=(2, 2)):
    filters1, filters2, filters3 = filters
    Main = stax.serial(
        stax.Conv(filters1, (1, 1), strides), stax.BatchNorm(), stax.Relu,
        stax.Conv(filters2, (kernel_size, kernel_size), padding='SAME'),
        stax.BatchNorm(), stax.Relu, stax.Conv(filters3, (1, 1)),
        stax.BatchNorm())
    Shortcut = stax.serial(stax.Conv(filters3, (1, 1), strides),
                           stax.BatchNorm())
    return stax.serial(stax.FanOut(2), stax.parallel(Main, Shortcut),
                       stax.FanInSum, stax.Relu)
Example 5
def gen_cnn_lenet_caffe(output_units=10,
                        W_initializers_str='glorot_normal()',
                        b_initializers_str='normal()'):
    return stax.serial(
        stax.Conv(out_chan=20, filter_shape=(5, 5),
                  W_init=eval(W_initializers_str),
                  b_init=eval(b_initializers_str)),
        stax.Relu,
        stax.MaxPool((2, 2), strides=(2, 2)),
        stax.Conv(out_chan=50, filter_shape=(5, 5),
                  W_init=eval(W_initializers_str),
                  b_init=eval(b_initializers_str)),
        stax.Relu,
        stax.MaxPool((2, 2), strides=(2, 2)),
        stax.Flatten,
        stax.Dense(500,
                   W_init=eval(W_initializers_str),
                   b_init=eval(b_initializers_str)),
        stax.Relu,
        stax.Dense(output_units,
                   W_init=eval(W_initializers_str),
                   b_init=eval(b_initializers_str)))
Example 6
def wide_resnet_block(num_channels, strides=(1, 1), channel_mismatch=False):
    """Wide ResNet block."""
    pre = stax.serial(stax.BatchNorm(), stax.Relu)
    mid = stax.serial(
        pre, stax.Conv(num_channels, (3, 3), strides, padding='SAME'),
        stax.BatchNorm(), stax.Relu,
        stax.Conv(num_channels, (3, 3), strides=(1, 1), padding='SAME'))
    if channel_mismatch:
        cut = stax.serial(
            pre, stax.Conv(num_channels, (3, 3), strides, padding='SAME'))
    else:
        cut = stax.Identity
    return stax.serial(stax.FanOut(2), stax.parallel(mid, cut), stax.FanInSum)
Example 7
def ConvBlock(kernel_size, filters, strides):
    """ResNet convolutional striding block."""
    ks = kernel_size
    filters1, filters2, filters3 = filters
    main = stax.serial(stax.Conv(filters1, (1, 1),
                                 strides), stax.BatchNorm(), stax.Relu,
                       stax.Conv(filters2, (ks, ks), padding='SAME'),
                       stax.BatchNorm(), stax.Relu,
                       stax.Conv(filters3, (1, 1)), stax.BatchNorm())
    shortcut = stax.serial(stax.Conv(filters3, (1, 1), strides),
                           stax.BatchNorm())
    return stax.serial(stax.FanOut(2), stax.parallel(main, shortcut),
                       stax.FanInSum, stax.Relu)
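Resnet50 in Example 11 below also calls an IdentityBlock that is not shown on this page. A minimal sketch consistent with the ConvBlock above, using the shape-dependent main path from Example 2 (stax.shape_dependent lets the final 1x1 conv restore the incoming channel count):

def IdentityBlock(kernel_size, filters):
    """ResNet identity block: conv main path plus a pass-through shortcut (sketch)."""
    ks = kernel_size
    filters1, filters2 = filters

    def make_main(input_shape):
        # The final 1x1 conv must emit as many channels as the block receives.
        return stax.serial(
            stax.Conv(filters1, (1, 1)), stax.BatchNorm(), stax.Relu,
            stax.Conv(filters2, (ks, ks), padding='SAME'), stax.BatchNorm(),
            stax.Relu, stax.Conv(input_shape[3], (1, 1)), stax.BatchNorm())

    Main = stax.shape_dependent(make_main)
    return stax.serial(stax.FanOut(2), stax.parallel(Main, stax.Identity),
                       stax.FanInSum, stax.Relu)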
Example 8
def cnn(num_classes=10):
    return stax.serial(
        stax.Conv(16, (8, 8), padding='SAME', strides=(2, 2)),
        stax.Tanh,
        stax.MaxPool((2, 2), (1, 1)),
        stax.Conv(32, (4, 4), padding='VALID', strides=(2, 2)),
        stax.Tanh,
        stax.MaxPool((2, 2), (1, 1)),
        stax.Flatten,  # (-1, 800)
        stax.Dense(64),
        stax.Tanh,  # embeddings
        stax.Dense(num_classes),  # logits
    )
Example 9
def cnn():
    return stax.serial(
        stax.Conv(16, (8, 8), padding='SAME', strides=(2, 2)),
        stax.Tanh,
        stax.MaxPool((2, 2), (1, 1)),
        stax.Conv(32, (4, 4), padding='VALID', strides=(2, 2)),
        stax.Tanh,
        stax.MaxPool((2, 2), (1, 1)),
        stax.Flatten,
        stax.Dense(32),
        stax.Tanh,  # embeddings
        stax.Dense(10),  # logits
    )
Example 10
    def test_conv(self):
        order = 3
        input_shape = (1, 5, 5, 1)
        key = random.PRNGKey(0)
        # TODO(duvenaud): Check all types of padding
        init_fun, apply_fun = stax.Conv(3, (2, 2), padding='VALID')
        _, (W, b) = init_fun(key, input_shape)

        rng = onp.random.RandomState(0)

        x = rng.randn(*input_shape).astype("float32")
        primals = (W, b, x)

        series_in1 = [
            rng.randn(*W.shape).astype("float32") for _ in range(order)
        ]
        series_in2 = [
            rng.randn(*b.shape).astype("float32") for _ in range(order)
        ]
        series_in3 = [
            rng.randn(*x.shape).astype("float32") for _ in range(order)
        ]

        series_in = (series_in1, series_in2, series_in3)

        def f(W, b, x):
            return apply_fun((W, b), x)

        self.check_jet(f, primals, series_in, check_dtypes=False)
Example 11
def Resnet50(hidden_size=64, num_output_classes=1001):
    """ResNet.

  Args:
    hidden_size: the size of the first hidden layer (multiplied later).
    num_output_classes: how many classes to distinguish.

  Returns:
    The ResNet model with the given layer and output sizes.
  """
    return stax.serial(
        stax.Conv(hidden_size, (7, 7), (2, 2),
                  'SAME'), stax.BatchNorm(), stax.Relu,
        stax.MaxPool((3, 3), strides=(2, 2)),
        ConvBlock(3, [hidden_size, hidden_size, 4 * hidden_size], (1, 1)),
        IdentityBlock(3, [hidden_size, hidden_size]),
        IdentityBlock(3, [hidden_size, hidden_size]),
        ConvBlock(3,
                  [2 * hidden_size, 2 * hidden_size, 8 * hidden_size], (2, 2)),
        IdentityBlock(3, [2 * hidden_size, 2 * hidden_size]),
        IdentityBlock(3, [2 * hidden_size, 2 * hidden_size]),
        IdentityBlock(3, [2 * hidden_size, 2 * hidden_size]),
        ConvBlock(3, [4 * hidden_size, 4 * hidden_size, 16 * hidden_size],
                  (2, 2)), IdentityBlock(3,
                                         [4 * hidden_size, 4 * hidden_size]),
        IdentityBlock(3, [4 * hidden_size, 4 * hidden_size]),
        IdentityBlock(3, [4 * hidden_size, 4 * hidden_size]),
        IdentityBlock(3, [4 * hidden_size, 4 * hidden_size]),
        IdentityBlock(3, [4 * hidden_size, 4 * hidden_size]),
        ConvBlock(3, [8 * hidden_size, 8 * hidden_size, 32 * hidden_size],
                  (2, 2)), IdentityBlock(3,
                                         [8 * hidden_size, 8 * hidden_size]),
        IdentityBlock(3, [8 * hidden_size, 8 * hidden_size]),
        stax.AvgPool((7, 7)), stax.Flatten, stax.Dense(num_output_classes),
        stax.LogSoftmax)
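A usage sketch for the model above (assumes the ConvBlock of Example 7 and an IdentityBlock like the sketch after it, plus ImageNet-sized NHWC inputs):

from jax import random

init_fun, apply_fun = Resnet50(hidden_size=64, num_output_classes=1001)
out_shape, params = init_fun(random.PRNGKey(0), (-1, 224, 224, 3))
# out_shape == (-1, 1001); apply_fun(params, images) yields log-probabilities.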
Example 12
def testConvShape(self, channels, filter_shape, padding, strides,
                  input_shape):
    init_fun, apply_fun = stax.Conv(channels,
                                    filter_shape,
                                    strides=strides,
                                    padding=padding)
    _CheckShapeAgreement(self, init_fun, apply_fun, input_shape)
Example 13
def MyConv(*args, parameterization='standard', order=None, **kwargs):
    """Wrapper for convolutional layer with different parameterizations."""
    if parameterization == 'standard':
        return jax_stax.Conv(*args, **kwargs)
    elif parameterization == 'ntk':
        # neural_tangents' stax.Conv returns (init_fn, apply_fn, kernel_fn);
        # [:2] keeps only the init/apply pair to match the jax stax interface.
        return stax.Conv(*args, b_std=1.0, **kwargs)[:2]
    elif parameterization == 'taylor':
        return TaylorConv(*args, b_std=1.0, order=order, **kwargs)
Example 14
def wide_resnet(n, k, num_classes):
    """Original WRN from paper and previous experiments."""
    return stax.serial(stax.Conv(16, (3, 3), padding='SAME'),
                       wide_resnet_group(n, 16 * k, strides=(1, 1)),
                       wide_resnet_group(n, 32 * k, strides=(2, 2)),
                       wide_resnet_group(n, 64 * k, strides=(2, 2)),
                       stax.BatchNorm(), stax.Relu, stax.AvgPool((8, 8)),
                       stax.Flatten, stax.Dense(num_classes))
Example 15
def conv():
    init_fun, predict = stax.serial(
        stax.Conv(16, (8, 8), padding='SAME', strides=(2, 2)),
        stax.Relu,
        stax.MaxPool((2, 2), (1, 1)),
        stax.Conv(32, (4, 4), padding='VALID', strides=(2, 2)),
        stax.Relu,
        stax.MaxPool((2, 2), (1, 1)),
        stax.Flatten,
        stax.Dense(32),
        stax.Relu,
        stax.Dense(10),
    )

    def init_params(rng):
        return init_fun(rng, (-1, 28, 28, 1))[1]

    return init_params, predict
Example 16
def Cnn(n_actions: int, hidden_size: int = 512) -> Module:
    return stax.serial(
        stax.Conv(32, (8, 8), (4, 4), "VALID"),
        stax.Relu,
        stax.Conv(64, (4, 4), (2, 2), "VALID"),
        stax.Relu,
        stax.Conv(64, (3, 3), (1, 1), "VALID"),
        stax.Relu,
        stax.Flatten,
        stax.Dense(hidden_size),
        stax.Relu,
        stax.FanOut(2),
        stax.parallel(
            stax.serial(
                stax.Dense(n_actions),
                stax.Softmax,
            ),  #  actor
            stax.serial(stax.Dense(1), ),  # critic
        ),
    )
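A usage sketch (the observation shape is an assumption, typical Atari-style 84x84x4 frames): FanOut plus parallel makes apply_fun return a pair, action probabilities from the actor head and a scalar value from the critic head.

import jax.numpy as jnp
from jax import random

init_fun, apply_fun = Cnn(n_actions=6)
out_shape, params = init_fun(random.PRNGKey(0), (-1, 84, 84, 4))
probs, value = apply_fun(params, jnp.zeros((1, 84, 84, 4)))  # shapes (1, 6) and (1, 1)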
Example 17
def transform(rng, input_dim, output_dim):
    init_fun, apply_fun = stax.serial(
        Reshape(),
        stax.Conv(8,
                  filter_shape=(3, 3),
                  W_init=weight_initializer,
                  b_init=weight_initializer),
        act,
        stax.Conv(16,
                  filter_shape=(3, 3),
                  W_init=weight_initializer,
                  b_init=weight_initializer),
        act,
        stax.Flatten,
        stax.Dense(output_dim,
                   W_init=weight_initializer,
                   b_init=weight_initializer),
    )
    _, params = init_fun(rng, (input_dim, ))
    return params, apply_fun
Example 18
def make_conv(
    strides=None,
    num_channels=256,
):
    return stax.Conv(
        out_chan=num_channels,
        filter_shape=(3, 3),
        padding="VALID",
        strides=strides,
        W_init=nn.initializers.he_normal(),
        b_init=nn.initializers.zeros,
    )
Example 19
def conv2d(num_classes, layers=((32, 5, 2), (16, 3, 2), (16, 3, 2))):
    """Builds a simple convolutional neural network."""
    stack = []

    # Concatenate convolutional layers.
    for num_units, kernel_size, stride in layers:
        stack += [
            stax.Conv(num_units, (kernel_size, kernel_size), (stride, stride),
                      padding='SAME'), stax.Relu
        ]

    # Output layer.
    stack += [stax.Flatten, stax.Dense(num_classes), stax.LogSoftmax]

    return stax.serial(*stack)
Example 20
def WideResnet(num_blocks=3, hidden_size=64, num_output_classes=10):
    """WideResnet from https://arxiv.org/pdf/1605.07146.pdf.

  Args:
    num_blocks: int, number of blocks in a group.
    hidden_size: the size of the first hidden layer (multiplied later).
    num_output_classes: int, number of classes to distinguish.

  Returns:
    The WideResnet model with given layer and output sizes.
  """
    return stax.serial(stax.Conv(hidden_size, (3, 3), padding='SAME'),
                       WideResnetGroup(num_blocks, hidden_size),
                       WideResnetGroup(num_blocks, hidden_size * 2, (2, 2)),
                       WideResnetGroup(num_blocks, hidden_size * 4, (2, 2)),
                       stax.BatchNorm(), stax.Relu,
                       stax.AvgPool((8, 8)), stax.Flatten,
                       stax.Dense(num_output_classes), stax.LogSoftmax)
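WideResnetGroup is not shown on this page; a minimal sketch built from the WideResnetBlock of Example 3, where only the first block in a group strides and resolves the channel mismatch:

def WideResnetGroup(n, channels, strides=(1, 1)):
    """Stack of n residual blocks; blocks after the first keep shape (sketch)."""
    blocks = [WideResnetBlock(channels, strides, channel_mismatch=True)]
    for _ in range(n - 1):
        blocks.append(WideResnetBlock(channels, (1, 1)))
    return stax.serial(*blocks)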
Example 21
def cnn(conv_depth=300,
        kernel_size=5,
        n_conv_layers=2,
        across_batch=False,
        add_pos_encoding=False):
  """Build convolutional neural net."""
  # Input shape: [batch x length x depth]
  if across_batch:
    extra_dim = 0
  else:
    extra_dim = 1
  layers = [ExpandDims(axis=extra_dim)]
  if add_pos_encoding:
    layers.append(positional_encoding())

  for _ in range(n_conv_layers):
    layers.append(
        stax.Conv(conv_depth, (1, kernel_size), padding="SAME", strides=(1, 1)))
    layers.append(stax.Relu)
  layers.append(AssertNonZeroShape())
  layers.append(squeeze_layer(axis=extra_dim))
  return stax.serial(*layers)
Example 22
num_batches = num_complete_batches + bool(leftover)


# %%
def data_stream():
    rng = npr.RandomState(0)
    while True:
        perm = rng.permutation(num_train)
        for i in range(num_batches):
            batch_idx = perm[i * batch_size:(i + 1) * batch_size]
            yield train_images[batch_idx], train_labels[batch_idx]


batches = data_stream()

init_fun, net = stax.serial(stax.Conv(16, (3, 3), (1, 1),
                                      padding="SAME"), stax.Relu,
                            stax.MaxPool((2, 2), (2, 2), padding="SAME"),
                            stax.Conv(32, (3, 3), (1, 1),
                                      padding="SAME"), stax.Relu,
                            stax.MaxPool((2, 2), (2, 2), padding="SAME"),
                            stax.Flatten, stax.Dense(10), stax.LogSoftmax)

_, params = init_fun(key, (64, 1, 28, 28))


def loss(params, batch):
    inputs, targets = batch
    preds = net(params, inputs)
    return -jnp.mean(jnp.sum(preds * targets, axis=1))

# test over here for adversarial examples
#test_newx = computation(params, (train_images, train_labels))

#print(test_newx)
"""# **Problem 2**

Before we get started, we need to import two small libraries that contain boilerplate code for common neural network layer types and for optimizers like mini-batch SGD.
"""

from jax.experimental import optimizers
from jax.experimental import stax
import matplotlib.pyplot as plt
"""Here is a fully-connected neural network architecture, like the one of Problem 1, but this time defined with `stax`"""

init_random_params, predict = stax.serial(
    stax.Conv(64, (8, 8), padding='SAME', strides=(2, 2)), stax.Relu,
    stax.MaxPool((2, 2), (1, 1)),
    stax.Conv(128, (4, 4), padding='VALID', strides=(2, 2)), stax.Relu,
    stax.MaxPool((2, 2), (1, 1)), stax.Flatten, stax.Dense(128), stax.Relu,
    stax.Dense(10), stax.LogSoftmax)
"""We redefine the cross-entropy loss for this model. As done in Problem 1, complete the return line below (it's identical)."""


def loss(params, batch):
    inputs, targets = batch
    logits = predict(params, inputs)
    preds = stax.logsoftmax(logits)
    return -np.mean(preds * targets)


"""Next, we define the mini-batch SGD optimizer, this time with the optimizers library in JAX."""
Example 24
flags.DEFINE_float('learning_rate', .10, 'Learning rate for finetuning.')
flags.DEFINE_integer('batch_size', 256, 'Batch size')
flags.DEFINE_integer('epochs', 100, 'Number of finetuning epochs')
flags.DEFINE_integer('seed', 0, 'Seed for jax PRNG')
flags.DEFINE_integer(
    'uncertain', 0, '0: entropy'
    '1: difference between 1st_prob and 2nd_prob'
    '2: random')
flags.DEFINE_integer('n_extra', 3000, 'number of extra points')
flags.DEFINE_bool(
    'show_label', True,
    'visualize predicted label at top/left, true at bottom/right')

# BEGIN: define the classifier model
init_fn_0, apply_fn_0 = stax.serial(
    stax.Conv(16, (8, 8), padding='SAME', strides=(2, 2)),
    stax.Tanh,
    stax.MaxPool((2, 2), (1, 1)),
    stax.Conv(32, (4, 4), padding='VALID', strides=(2, 2)),
    stax.Tanh,
    stax.MaxPool((2, 2), (1, 1)),
    stax.Flatten,  # (-1, 800)
    stax.Dense(32),
    stax.Tanh,  # embeddings
)

init_fn_1, apply_fn_1 = stax.serial(
    stax.Dense(10),  # logits
)

Example 25
def main(_):

    logging.info('Starting experiment.')
    configs = FLAGS.config

    # Create model folder for outputs
    try:
        gfile.MakeDirs(FLAGS.exp_dir)
    except gfile.GOSError:
        pass
    stdout_log = gfile.Open('{}/stdout.log'.format(FLAGS.exp_dir), 'w+')

    # Cosine learning-rate schedule (defined before the optimizer selection uses it).
    def cosine(initial_step_size, train_steps):
        k = np.pi / (2.0 * train_steps)

        def schedule(i):
            return initial_step_size * np.cos(k * i)

        return schedule

    if configs.optimization == 'sgd':
        lr_schedule = optimizers.make_schedule(configs.learning_rate)
        opt_init, opt_update, get_params = optimizers.sgd(lr_schedule)
    elif configs.optimization == 'momentum':
        lr_schedule = cosine(configs.learning_rate, configs.train_steps)
        opt_init, opt_update, get_params = optimizers.momentum(
            lr_schedule, 0.9)
    else:
        raise ValueError('Optimizer not implemented.')

    with gfile.Open(FLAGS.pretrained_dir, 'rb') as fpre:
        pretrained_opt_state = optimizers.pack_optimizer_state(
            pickle.load(fpre))
    fixed_params = get_params(pretrained_opt_state)[:7]

    # BEGIN: define the classifier model
    init_fn_0, apply_fn_0 = stax.serial(
        stax.Conv(16, (8, 8), padding='SAME', strides=(2, 2)),
        stax.Tanh,
        stax.MaxPool((2, 2), (1, 1)),
        stax.Conv(32, (4, 4), padding='VALID', strides=(2, 2)),
        stax.Tanh,
        stax.MaxPool((2, 2), (1, 1)),
        stax.Flatten,  # representations
    )

    init_fn_1, apply_fn_1 = stax.serial(
        stax.Dense(64),
        stax.Tanh,  # embeddings
        stax.Dense(10),  # logits
    )

    def predict(params, inputs):
        representations = apply_fn_0(fixed_params,
                                     inputs)  # use pretrained params
        logits = apply_fn_1(params, representations)
        return logits

    # END: define the classifier model

    if configs.seed is not None:
        key = random.PRNGKey(configs.seed)
    else:
        key = random.PRNGKey(int(time.time()))
    _, _ = init_fn_0(key, (-1, 32, 32, 3))
    _, params = init_fn_1(key, (-1, 800))
    opt_state = opt_init(params)

    logging.info('Loading data.')
    tic = time.time()

    train_images, train_labels, _ = datasets.get_dataset_split(
        FLAGS.dataset, 'train')
    train_mu, train_std = onp.mean(train_images), onp.std(train_images)
    n_train = len(train_images)
    train = data.DataChunk(X=(train_images - train_mu) / train_std,
                           Y=train_labels,
                           image_size=32,
                           image_channels=3,
                           label_dim=1,
                           label_format='numeric')

    test_images, test_labels, _ = datasets.get_dataset_split(
        FLAGS.dataset, 'test')
    test = data.DataChunk(
        X=(test_images - train_mu) / train_std,  # normalize w train mean/std
        Y=test_labels,
        image_size=32,
        image_channels=3,
        label_dim=1,
        label_format='numeric')

    # Data augmentation
    if configs.augment_data:
        augmentation = data.chain_transforms(data.RandomHorizontalFlip(0.5),
                                             data.RandomCrop(4), data.ToDevice)
    else:
        augmentation = None
    batch = data.minibatcher(train, configs.batch_size, transform=augmentation)

    # count params of JAX model
    def count_parameters(params):
        return tree_util.tree_reduce(
            operator.add, tree_util.tree_map(lambda x: np.prod(x.shape),
                                             params))

    logging.info('Number of parameters: %d', count_parameters(params))
    stdout_log.write('Number of params: {}\n'.format(count_parameters(params)))

    # loss functions
    def cross_entropy_loss(params, x_img, y_lbl):
        return -np.mean(stax.logsoftmax(predict(params, x_img)) * y_lbl)

    def mse_loss(params, x_img, y_lbl):
        return 0.5 * np.mean((y_lbl - predict(params, x_img))**2)

    def accuracy(y_lbl_hat, y_lbl):
        target_class = np.argmax(y_lbl, axis=1)
        predicted_class = np.argmax(y_lbl_hat, axis=1)
        return np.mean(predicted_class == target_class)

    # Loss and gradient
    if configs.loss == 'xent':
        loss = cross_entropy_loss
    elif configs.loss == 'mse':
        loss = mse_loss
    else:
        raise ValueError('Loss function not implemented.')
    grad_loss = jit(grad(loss))


    def private_grad(params, batch, rng, l2_norm_clip, noise_multiplier,
                     batch_size):
        """Return differentially private gradients of params, evaluated on batch."""
        def _clipped_grad(params, single_example_batch):
            """Evaluate gradient for a single-example batch and clip its grad norm."""
            grads = grad_loss(params, single_example_batch[0].reshape(
                (-1, 32, 32, 3)), single_example_batch[1])

            nonempty_grads, tree_def = tree_util.tree_flatten(grads)
            total_grad_norm = np.linalg.norm(
                [np.linalg.norm(neg.ravel()) for neg in nonempty_grads])
            divisor = stop_gradient(
                np.amax((total_grad_norm / l2_norm_clip, 1.)))
            normalized_nonempty_grads = [
                neg / divisor for neg in nonempty_grads
            ]
            return tree_util.tree_unflatten(tree_def,
                                            normalized_nonempty_grads)

        px_clipped_grad_fn = vmap(partial(_clipped_grad, params))
        std_dev = l2_norm_clip * noise_multiplier
        noise_ = lambda n: n + std_dev * random.normal(rng, n.shape)
        normalize_ = lambda n: n / float(batch_size)
        sum_ = lambda n: np.sum(n, 0)  # aggregate
        aggregated_clipped_grads = tree_util.tree_map(
            sum_, px_clipped_grad_fn(batch))
        noised_aggregated_clipped_grads = tree_util.tree_map(
            noise_, aggregated_clipped_grads)
        normalized_noised_aggregated_clipped_grads = (tree_util.tree_map(
            normalize_, noised_aggregated_clipped_grads))
        return normalized_noised_aggregated_clipped_grads

    # summarize measurements
    steps_per_epoch = n_train // configs.batch_size

    def summarize(step, params):
        """Compute measurements in a zipped way."""
        set_entries = [train, test]
        set_bsizes = [configs.train_eval_bsize, configs.test_eval_bsize]
        set_names, loss_dict, acc_dict = ['train', 'test'], {}, {}

        for set_entry, set_bsize, set_name in zip(set_entries, set_bsizes,
                                                  set_names):
            temp_loss, temp_acc, points = 0.0, 0.0, 0
            for b in data.batch(set_entry, set_bsize):
                temp_loss += loss(params, b.X, b.Y) * b.X.shape[0]
                temp_acc += accuracy(predict(params, b.X), b.Y) * b.X.shape[0]
                points += b.X.shape[0]
            loss_dict[set_name] = temp_loss / float(points)
            acc_dict[set_name] = temp_acc / float(points)

        logging.info('Step: %s', str(step))
        logging.info('Train acc : %.4f', acc_dict['train'])
        logging.info('Train loss: %.4f', loss_dict['train'])
        logging.info('Test acc  : %.4f', acc_dict['test'])
        logging.info('Test loss : %.4f', loss_dict['test'])

        stdout_log.write('Step: {}\n'.format(step))
        stdout_log.write('Train acc : {}\n'.format(acc_dict['train']))
        stdout_log.write('Train loss: {}\n'.format(loss_dict['train']))
        stdout_log.write('Test acc  : {}\n'.format(acc_dict['test']))
        stdout_log.write('Test loss : {}\n'.format(loss_dict['test']))
        stdout_log.flush()

        return acc_dict['test']

    toc = time.time()
    logging.info('Elapsed SETUP time: %s', str(toc - tic))
    stdout_log.write('Elapsed SETUP time: {}\n'.format(toc - tic))

    # BEGIN: training steps
    logging.info('Training network.')
    tic = time.time()
    t = time.time()

    for s in range(configs.train_steps):
        b = next(batch)
        params = get_params(opt_state)

        # t0 = time.time()
        if FLAGS.dpsgd:
            key = random.fold_in(key, s)  # get new key for new random numbers
            opt_state = opt_update(
                s,
                private_grad(params, (b.X.reshape(
                    (-1, 1, 32, 32, 3)), b.Y), key, configs.l2_norm_clip,
                             configs.noise_multiplier, configs.batch_size),
                opt_state)
        else:
            opt_state = opt_update(s, grad_loss(params, b.X, b.Y), opt_state)
        # t1 = time.time()
        # logging.info('batch update time: %s', str(t1 - t0))

        if s % steps_per_epoch == 0:
            with gfile.Open(
                    '{}/ckpt_{}'.format(FLAGS.exp_dir,
                                        int(s / steps_per_epoch)),
                    'wb') as fckpt:
                pickle.dump(optimizers.unpack_optimizer_state(opt_state),
                            fckpt)

            if FLAGS.dpsgd:
                eps = compute_epsilon(s, configs.batch_size, n_train,
                                      configs.target_delta,
                                      configs.noise_multiplier)
                stdout_log.write(
                    'For delta={:.0e}, current epsilon is: {:.2f}\n'.format(
                        configs.target_delta, eps))

            logging.info('Elapsed EPOCH time: %s', str(time.time() - t))
            stdout_log.write('Elapsed EPOCH time: {}\n'.format(time.time() - t))
            stdout_log.flush()
            t = time.time()

    toc = time.time()
    summarize(configs.train_steps, params)
    logging.info('Elapsed TRAIN time: %s', str(toc - tic))
    stdout_log.write('Elapsed TRAIN time: {}\n'.format(toc - tic))
    stdout_log.close()
Example 26
def main(_):
    rng = random.PRNGKey(0)

    # Load MNIST dataset
    train_images, train_labels, test_images, test_labels = datasets.mnist()

    batch_size = 128
    batch_shape = (-1, 28, 28, 1)
    num_train = train_images.shape[0]
    num_complete_batches, leftover = divmod(num_train, batch_size)
    num_batches = num_complete_batches + bool(leftover)

    train_images = np.reshape(train_images, batch_shape)
    test_images = np.reshape(test_images, batch_shape)

    def data_stream():
        rng = npr.RandomState(0)
        while True:
            perm = rng.permutation(num_train)
            for i in range(num_batches):
                batch_idx = perm[i * batch_size:(i + 1) * batch_size]
                yield train_images[batch_idx], train_labels[batch_idx]

    def save(fn, opt_state):
        params = deepcopy(get_params(opt_state))
        save_dict = {}
        for idx, p in enumerate(params):
            if (p != ()):
                pp = (p[0].tolist(), p[1].tolist())
                params[idx] = pp
        save_dict["params"] = params
        with open(fn, "w") as f:
            json.dump(save_dict, f)

    def load(fn):
        with open(fn, "r") as f:
            params = json.load(f)
        params = params["params"]
        for idx, p in enumerate(params):
            if (p != []):
                pp = (np.array(p[0]), np.array(p[1]))
                params[idx] = pp
            else:
                params[idx] = ()
        return opt_init(params)

    batches = data_stream()

    # Model, loss, and accuracy functions
    init_random_params, predict = stax.serial(
        stax.Conv(32, (8, 8), strides=(2, 2), padding="SAME"),
        stax.Relu,
        stax.Conv(128, (6, 6), strides=(2, 2), padding="VALID"),
        stax.Relu,
        stax.Conv(128, (5, 5), strides=(1, 1), padding="VALID"),
        stax.Flatten,
        stax.Dense(128),
        stax.Relu,
        stax.Dense(10),
    )

    def loss(params, batch):
        inputs, targets = batch
        preds = predict(params, inputs)
        return -np.mean(logsoftmax(preds) * targets)

    def accuracy(params, batch):
        inputs, targets = batch
        target_class = np.argmax(targets, axis=1)
        predicted_class = np.argmax(predict(params, inputs), axis=1)
        return np.mean(predicted_class == target_class)

    def gen_ellipsoid(X, zeta_rel, zeta_const, alpha, N_steps):
        zeta = (np.abs(X).T * zeta_rel).T + zeta_const
        if (alpha is None):
            alpha = 1 / N_steps * zeta
        else:
            assert isinstance(alpha, float), "Alpha must be float"
            alpha = alpha * np.ones_like(X)
        return zeta, alpha

    def gen_ellipsoid_match_volume(X, zeta_const, eps, alpha, N_steps):
        x_norms = np.linalg.norm(np.reshape(X, (X.shape[0], -1)),
                                 ord=1,
                                 axis=1)
        N = np.prod(X.shape[1:])
        zeta_rel = N * (eps - zeta_const) / x_norms
        assert (zeta_rel <= 1.0).all(
        ), "Zeta rel cannot be larger than 1. Please increase zeta const or reduce eps"
        zeta_rel = np.clip(zeta_rel, 0.0, 1.0)
        return gen_ellipsoid(X, zeta_rel, zeta_const, alpha, N_steps)

    # Instantiate an optimizer
    opt_init, opt_update, get_params = optimizers.adam(0.001)

    @jit
    def update(i, opt_state, batch):
        params = get_params(opt_state)
        return opt_update(i, grad(loss)(params, batch), opt_state)

    # Initialize model
    _, init_params = init_random_params(rng, batch_shape)
    opt_state = opt_init(init_params)
    itercount = itertools.count()

    try:
        opt_state = load("tutorials/jax/test_model.json")
    except (IOError, ValueError):  # no saved model yet: train from scratch
        # Training loop
        print("\nStarting training...")
        start_time = time.time()
        for _ in range(num_batches):
            opt_state = update(next(itercount), opt_state, next(batches))
        epoch_time = time.time() - start_time
        save("tutorials/jax/test_model.json", opt_state)

    # Evaluate model on clean data
    params = get_params(opt_state)

    # Evaluate model on adversarial data
    model_fn = lambda images: predict(params, images)
    # Generate single attacking test image
    idx = 0
    plt.figure(figsize=(15, 6), constrained_layout=True)

    zeta, alpha = gen_ellipsoid(X=test_images[idx].reshape((1, 28, 28, 1)),
                                zeta_rel=FLAGS.zeta_rel,
                                zeta_const=FLAGS.zeta_const,
                                alpha=None,
                                N_steps=40)
    # zeta, alpha = gen_ellipsoid_match_volume(X=test_images[idx].reshape((1,28,28,1)), zeta_const=FLAGS.zeta_const, eps=FLAGS.eps, alpha=None, N_steps=40)
    test_images_pgd_ellipsoid = projected_gradient_descent(
        model_fn, test_images[idx].reshape((1, 28, 28, 1)), zeta, alpha, 40,
        np.inf)
    predict_pgd_ellipsoid = np.argmax(predict(params,
                                              test_images_pgd_ellipsoid),
                                      axis=1)

    test_images_fgm = fast_gradient_method(
        model_fn, test_images[idx].reshape((1, 28, 28, 1)), 0.075, np.inf)
    predict_fgm = np.argmax(predict(params, test_images_fgm), axis=1)

    test_images_pgd = projected_gradient_descent(
        model_fn, test_images[idx].reshape((1, 28, 28, 1)), FLAGS.eps, 0.01,
        40, 2)
    predict_pgd = np.argmax(predict(params, test_images_pgd), axis=1)

    base = 100
    f_ = lambda x: np.log(x) / np.log(base)
    a = base - 1
    transform = 1 + a * test_images[idx].reshape((1, 28, 28, 1))  # [1,base]

    # test_images_pgd_transform = projected_gradient_descent(model_fn, f_(np.where(transform > base,base,transform)), FLAGS.zeta_rel, 0.01, 40, np.inf)
    test_images_pgd_transform = projected_gradient_descent(
        model_fn, f_(np.where(transform > base, base, transform)), 1.8, 0.01,
        40, 2)
    test_images_pgd_transform = np.clip(test_images_pgd_transform, 0.0, 1.0)
    test_images_pgd_transform = (base**test_images_pgd_transform - 1) / a
    predict_transform = np.argmax(predict(params, test_images_pgd_transform),
                                  axis=1)

    plt.subplot(151)
    plt.imshow(np.squeeze(test_images[idx]), cmap='gray')
    plt.title("Original")
    plt.subplot(152)
    plt.imshow(np.squeeze(test_images_fgm), cmap='gray')
    plt.title(f"FGM L-Inf Pred: {predict_fgm}")
    plt.subplot(153)
    plt.imshow(np.squeeze(test_images_pgd), cmap='gray')
    plt.title(f"PGD L2 {predict_pgd}")
    plt.subplot(154)
    plt.imshow(np.squeeze(test_images_pgd_ellipsoid), cmap='gray')
    plt.title(f"PGD Ellipsoid L-Inf Pred: {predict_pgd_ellipsoid}")
    plt.subplot(155)
    plt.imshow(np.squeeze(test_images_pgd_transform), cmap='gray')
    plt.title(f"PGD log{base} L2 Pred: {predict_transform}")

    plt.show()

    transform = 1 + a * test_images
    test_images_pgd_transform = projected_gradient_descent(
        model_fn, f_(np.where(transform > base, base, transform)),
        FLAGS.zeta_rel, 0.01, 40, np.inf)
    test_images_pgd_transform = np.clip(test_images_pgd_transform, 0.0, 1.0)
    test_images_pgd_transform = (base**test_images_pgd_transform - 1) / a
    test_acc_pgd_transform = accuracy(params,
                                      (test_images_pgd_transform, test_labels))

    # Generate whole attacking test images
    # zeta, alpha = gen_ellipsoid(X=test_images, zeta_rel=FLAGS.zeta_rel, zeta_const=FLAGS.zeta_const, alpha=None, N_steps=40)
    zeta, alpha = gen_ellipsoid_match_volume(X=test_images,
                                             zeta_const=FLAGS.zeta_const,
                                             eps=FLAGS.eps,
                                             alpha=None,
                                             N_steps=40)
    test_images_pgd_ellipsoid = projected_gradient_descent(
        model_fn, test_images, zeta, alpha, 40, np.inf)
    test_acc_pgd_ellipsoid = accuracy(params,
                                      (test_images_pgd_ellipsoid, test_labels))

    test_images_fgm = fast_gradient_method(model_fn, test_images, FLAGS.eps,
                                           np.inf)
    test_images_pgd = projected_gradient_descent(model_fn, test_images,
                                                 FLAGS.eps, 0.01, 40, np.inf)

    test_acc_fgm = accuracy(params, (test_images_fgm, test_labels))
    test_acc_pgd = accuracy(params, (test_images_pgd, test_labels))

    train_acc = accuracy(params, (train_images, train_labels))
    test_acc = accuracy(params, (test_images, test_labels))

    print("Training set accuracy: {}".format(train_acc))
    print("Test set accuracy on clean examples: {}".format(test_acc))
    print("Test set accuracy on FGM adversarial examples: {}".format(
        test_acc_fgm))
    print("Test set accuracy on PGD adversarial examples: {}".format(
        test_acc_pgd))
    print("Test set accuracy on PGD Ellipsoid adversarial examples: {}".format(
        test_acc_pgd_ellipsoid))
    print(
        "Test set accuracy on PGD Ellipsoid via transform adversarial examples: {}"
        .format(test_acc_pgd_transform))
Example 27
class StaxTest(jtu.JaxTestCase):
    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name": "_shape={}".format(shape),
            "shape": shape
        } for shape in [(2, 3), (5, )]))
    def testRandnInitShape(self, shape):
        out = stax.randn()(shape)
        self.assertEqual(out.shape, shape)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name": "_shape={}".format(shape),
            "shape": shape
        } for shape in [(2, 3), (2, 3, 4)]))
    def testGlorotInitShape(self, shape):
        out = stax.glorot()(shape)
        self.assertEqual(out.shape, shape)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_channels={}_filter_shape={}_padding={}_strides={}_input_shape={}"
            .format(channels, filter_shape, padding, strides, input_shape),
            "channels":
            channels,
            "filter_shape":
            filter_shape,
            "padding":
            padding,
            "strides":
            strides,
            "input_shape":
            input_shape
        } for channels in [2, 3] for filter_shape in [(1, 1), (2, 3)]
                            for padding in ["SAME", "VALID"]
                            for strides in [None, (2, 1)]
                            for input_shape in [(2, 10, 11, 1)]))
    def testConvShape(self, channels, filter_shape, padding, strides,
                      input_shape):
        init_fun, apply_fun = stax.Conv(channels,
                                        filter_shape,
                                        strides=strides,
                                        padding=padding)
        _CheckShapeAgreement(self, init_fun, apply_fun, input_shape)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_out_dim={}_input_shape={}".format(out_dim, input_shape),
            "out_dim":
            out_dim,
            "input_shape":
            input_shape
        } for out_dim in [3, 4] for input_shape in [(2, 3), (3, 4)]))
    def testDenseShape(self, out_dim, input_shape):
        init_fun, apply_fun = stax.Dense(out_dim)
        _CheckShapeAgreement(self, init_fun, apply_fun, input_shape)

    @parameterized.named_parameters(
        jtu.cases_from_list(
            {
                "testcase_name": "_input_shape={}".format(input_shape),
                "input_shape": input_shape
            } for input_shape in [(2, 3), (2, 3, 4)]))
    def testReluShape(self, input_shape):
        init_fun, apply_fun = stax.Relu
        _CheckShapeAgreement(self, init_fun, apply_fun, input_shape)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_window_shape={}_padding={}_strides={}_input_shape={}".format(
                window_shape, padding, strides, input_shape),
            "window_shape":
            window_shape,
            "padding":
            padding,
            "strides":
            strides,
            "input_shape":
            input_shape
        } for window_shape in [(1, 1), (2, 3)] for padding in ["VALID"]
                            for strides in [None, (2, 1)]
                            for input_shape in [(2, 5, 6, 1)]))
    def testPoolingShape(self, window_shape, padding, strides, input_shape):
        init_fun, apply_fun = stax.MaxPool(window_shape,
                                           padding=padding,
                                           strides=strides)
        _CheckShapeAgreement(self, init_fun, apply_fun, input_shape)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name": "_shape={}".format(input_shape),
            "input_shape": input_shape
        } for input_shape in [(2, 3), (2, 3, 4)]))
    def testFlattenShape(self, input_shape):
        init_fun, apply_fun = stax.Flatten
        _CheckShapeAgreement(self, init_fun, apply_fun, input_shape)

    @parameterized.named_parameters(
        jtu.cases_from_list(
            {
                "testcase_name": "_input_shape={}_spec={}".format(
                    input_shape, i),
                "input_shape": input_shape,
                "spec": spec
            } for input_shape in [(2, 5, 6, 1)]
            for i, spec in enumerate([
                [stax.Conv(3, (2, 2))],
                [stax.Conv(3, (2, 2)), stax.Flatten, stax.Dense(4)],
            ])))
    def testSerialComposeLayersShape(self, input_shape, spec):
        init_fun, apply_fun = stax.serial(*spec)
        _CheckShapeAgreement(self, init_fun, apply_fun, input_shape)

    @parameterized.named_parameters(
        jtu.cases_from_list(
            {
                "testcase_name": "_input_shape={}".format(input_shape),
                "input_shape": input_shape
            } for input_shape in [(3, 4), (2, 5, 6, 1)]))
    def testDropoutShape(self, input_shape):
        init_fun, apply_fun = stax.Dropout(0.9)
        _CheckShapeAgreement(self, init_fun, apply_fun, input_shape)

    @parameterized.named_parameters(
        jtu.cases_from_list(
            {
                "testcase_name": "_input_shape={}".format(input_shape),
                "input_shape": input_shape
            } for input_shape in [(3, 4), (2, 5, 6, 1)]))
    def testFanInSum(self, input_shape):
        init_fun, apply_fun = stax.FanInSum
        _CheckShapeAgreement(self, init_fun, apply_fun,
                             [input_shape, input_shape])

    @parameterized.named_parameters(
        jtu.cases_from_list(
            {
                "testcase_name": "_inshapes={}_axis={}".format(
                    input_shapes, axis),
                "input_shapes": input_shapes,
                "axis": axis
            } for input_shapes, axis in [
                ([(2, 3), (2, 1)], 1),
                ([(2, 3), (2, 1)], -1),
                ([(1, 2, 4), (1, 1, 4)], 1),
            ]))
    def testFanInConcat(self, input_shapes, axis):
        init_fun, apply_fun = stax.FanInConcat(axis)
        _CheckShapeAgreement(self, init_fun, apply_fun, input_shapes)
Example 28
class StaxTest(jtu.JaxTestCase):
    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name": "_shape={}".format(shape),
            "shape": shape
        } for shape in [(2, 3), (5, )]))
    def testRandnInitShape(self, shape):
        key = random.PRNGKey(0)
        out = stax.randn()(key, shape)
        self.assertEqual(out.shape, shape)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name": "_shape={}".format(shape),
            "shape": shape
        } for shape in [(2, 3), (2, 3, 4)]))
    def testGlorotInitShape(self, shape):
        key = random.PRNGKey(0)
        out = stax.glorot()(key, shape)
        self.assertEqual(out.shape, shape)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_channels={}_filter_shape={}_padding={}_strides={}_input_shape={}"
            .format(channels, filter_shape, padding, strides, input_shape),
            "channels":
            channels,
            "filter_shape":
            filter_shape,
            "padding":
            padding,
            "strides":
            strides,
            "input_shape":
            input_shape
        } for channels in [2, 3] for filter_shape in [(1, 1), (2, 3)]
                            for padding in ["SAME", "VALID"]
                            for strides in [None, (2, 1)]
                            for input_shape in [(2, 10, 11, 1)]))
    def testConvShape(self, channels, filter_shape, padding, strides,
                      input_shape):
        init_fun, apply_fun = stax.Conv(channels,
                                        filter_shape,
                                        strides=strides,
                                        padding=padding)
        _CheckShapeAgreement(self, init_fun, apply_fun, input_shape)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_out_dim={}_input_shape={}".format(out_dim, input_shape),
            "out_dim":
            out_dim,
            "input_shape":
            input_shape
        } for out_dim in [3, 4] for input_shape in [(2, 3), (3, 4)]))
    def testDenseShape(self, out_dim, input_shape):
        init_fun, apply_fun = stax.Dense(out_dim)
        _CheckShapeAgreement(self, init_fun, apply_fun, input_shape)

    @parameterized.named_parameters(
        jtu.cases_from_list(
            {
                "testcase_name": "_input_shape={}".format(input_shape),
                "input_shape": input_shape
            } for input_shape in [(2, 3), (2, 3, 4)]))
    def testReluShape(self, input_shape):
        init_fun, apply_fun = stax.Relu
        _CheckShapeAgreement(self, init_fun, apply_fun, input_shape)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_window_shape={}_padding={}_strides={}_input_shape={}"
            "_maxpool={}".format(window_shape, padding, strides, input_shape,
                                 max_pool),
            "window_shape":
            window_shape,
            "padding":
            padding,
            "strides":
            strides,
            "input_shape":
            input_shape,
            "max_pool":
            max_pool
        } for window_shape in [(1, 1), (2, 3)] for padding in ["VALID"]
                            for strides in [None, (2, 1)]
                            for input_shape in [(2, 5, 6, 1)]
                            for max_pool in [False, True]))
    def testPoolingShape(self, window_shape, padding, strides, input_shape,
                         max_pool):
        layer = stax.MaxPool if max_pool else stax.AvgPool
        init_fun, apply_fun = layer(window_shape,
                                    padding=padding,
                                    strides=strides)
        _CheckShapeAgreement(self, init_fun, apply_fun, input_shape)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name": "_shape={}".format(input_shape),
            "input_shape": input_shape
        } for input_shape in [(2, 3), (2, 3, 4)]))
    def testFlattenShape(self, input_shape):
        init_fun, apply_fun = stax.Flatten
        _CheckShapeAgreement(self, init_fun, apply_fun, input_shape)

    @parameterized.named_parameters(
        jtu.cases_from_list(
            {
                "testcase_name": "_input_shape={}_spec={}".format(
                    input_shape, i),
                "input_shape": input_shape,
                "spec": spec
            } for input_shape in [(2, 5, 6, 1)]
            for i, spec in enumerate([
                [stax.Conv(3, (2, 2))],
                [stax.Conv(3, (2, 2)), stax.Flatten, stax.Dense(4)],
            ])))
    def testSerialComposeLayersShape(self, input_shape, spec):
        init_fun, apply_fun = stax.serial(*spec)
        _CheckShapeAgreement(self, init_fun, apply_fun, input_shape)

    @parameterized.named_parameters(
        jtu.cases_from_list(
            {
                "testcase_name": "_input_shape={}".format(input_shape),
                "input_shape": input_shape
            } for input_shape in [(3, 4), (2, 5, 6, 1)]))
    def testDropoutShape(self, input_shape):
        init_fun, apply_fun = stax.Dropout(0.9)
        _CheckShapeAgreement(self, init_fun, apply_fun, input_shape)

    @parameterized.named_parameters(
        jtu.cases_from_list(
            {
                "testcase_name": "_input_shape={}".format(input_shape),
                "input_shape": input_shape
            } for input_shape in [(3, 4), (2, 5, 6, 1)]))
    def testFanInSum(self, input_shape):
        init_fun, apply_fun = stax.FanInSum
        _CheckShapeAgreement(self, init_fun, apply_fun,
                             [input_shape, input_shape])

    @parameterized.named_parameters(
        jtu.cases_from_list(
            {
                "testcase_name": "_inshapes={}_axis={}".format(
                    input_shapes, axis),
                "input_shapes": input_shapes,
                "axis": axis
            } for input_shapes, axis in [
                ([(2, 3), (2, 1)], 1),
                ([(2, 3), (2, 1)], -1),
                ([(1, 2, 4), (1, 1, 4)], 1),
            ]))
    def testFanInConcat(self, input_shapes, axis):
        init_fun, apply_fun = stax.FanInConcat(axis)
        _CheckShapeAgreement(self, init_fun, apply_fun, input_shapes)

    def testIssue182(self):
        key = random.PRNGKey(0)
        init_fun, apply_fun = stax.Softmax
        input_shape = (10, 3)
        inputs = onp.arange(30.).astype("float32").reshape(input_shape)

        out_shape, params = init_fun(key, input_shape)
        out = apply_fun(params, inputs)

        assert out_shape == out.shape
        assert onp.allclose(onp.sum(onp.asarray(out), -1), 1.)

    def testBatchNormShapeNHWC(self):
        key = random.PRNGKey(0)
        init_fun, apply_fun = stax.BatchNorm(axis=(0, 1, 2))
        input_shape = (4, 5, 6, 7)
        inputs = random_inputs(onp.random.RandomState(0), input_shape)

        out_shape, params = init_fun(key, input_shape)
        out = apply_fun(params, inputs)

        self.assertEqual(out_shape, input_shape)
        beta, gamma = params
        self.assertEqual(beta.shape, (7, ))
        self.assertEqual(gamma.shape, (7, ))
        self.assertEqual(out_shape, out.shape)

    def testBatchNormShapeNCHW(self):
        key = random.PRNGKey(0)
        # Regression test for https://github.com/google/jax/issues/461
        init_fun, apply_fun = stax.BatchNorm(axis=(0, 2, 3))
        input_shape = (4, 5, 6, 7)
        inputs = random_inputs(onp.random.RandomState(0), input_shape)

        out_shape, params = init_fun(key, input_shape)
        out = apply_fun(params, inputs)

        self.assertEqual(out_shape, input_shape)
        beta, gamma = params
        self.assertEqual(beta.shape, (5, ))
        self.assertEqual(gamma.shape, (5, ))
        self.assertEqual(out_shape, out.shape)
Example 29
def main(_):
    rng = random.PRNGKey(0)

    # Load MNIST dataset
    train_images, train_labels, test_images, test_labels = datasets.mnist()

    batch_size = 128
    batch_shape = (-1, 28, 28, 1)
    num_train = train_images.shape[0]
    num_complete_batches, leftover = divmod(num_train, batch_size)
    num_batches = num_complete_batches + bool(leftover)

    train_images = np.reshape(train_images, batch_shape)
    test_images = np.reshape(test_images, batch_shape)

    def data_stream():
        rng = npr.RandomState(0)
        while True:
            perm = rng.permutation(num_train)
            for i in range(num_batches):
                batch_idx = perm[i * batch_size:(i + 1) * batch_size]
                yield train_images[batch_idx], train_labels[batch_idx]

    batches = data_stream()

    # Model, loss, and accuracy functions
    init_random_params, predict = stax.serial(
        stax.Conv(32, (8, 8), strides=(2, 2), padding='SAME'), stax.Relu,
        stax.Conv(128, (6, 6), strides=(2, 2), padding='VALID'), stax.Relu,
        stax.Conv(128, (5, 5), strides=(1, 1), padding='VALID'), stax.Flatten,
        stax.Dense(128), stax.Relu, stax.Dense(10))

    def loss(params, batch):
        inputs, targets = batch
        preds = predict(params, inputs)
        return -np.mean(logsoftmax(preds) * targets)

    def accuracy(params, batch):
        inputs, targets = batch
        target_class = np.argmax(targets, axis=1)
        predicted_class = np.argmax(predict(params, inputs), axis=1)
        return np.mean(predicted_class == target_class)

    # Instantiate an optimizer
    opt_init, opt_update, get_params = optimizers.adam(0.001)

    @jit
    def update(i, opt_state, batch):
        params = get_params(opt_state)
        return opt_update(i, grad(loss)(params, batch), opt_state)

    # Initialize model
    _, init_params = init_random_params(rng, batch_shape)
    opt_state = opt_init(init_params)
    itercount = itertools.count()

    # Training loop
    print("\nStarting training...")
    for epoch in range(FLAGS.nb_epochs):
        start_time = time.time()
        for _ in range(num_batches):
            opt_state = update(next(itercount), opt_state, next(batches))
        epoch_time = time.time() - start_time

        # Evaluate model on clean data
        params = get_params(opt_state)
        train_acc = accuracy(params, (train_images, train_labels))
        test_acc = accuracy(params, (test_images, test_labels))

        # Evaluate model on adversarial data
        model_fn = lambda images: predict(params, images)
        test_images_fgm = fast_gradient_method(model_fn, test_images,
                                               FLAGS.eps, np.inf)
        test_images_pgd = projected_gradient_descent(model_fn, test_images,
                                                     FLAGS.eps, 0.01, 40,
                                                     np.inf)
        test_acc_fgm = accuracy(params, (test_images_fgm, test_labels))
        test_acc_pgd = accuracy(params, (test_images_pgd, test_labels))

        print("Epoch {} in {:0.2f} sec".format(epoch, epoch_time))
        print("Training set accuracy: {}".format(train_acc))
        print("Test set accuracy on clean examples: {}".format(test_acc))
        print("Test set accuracy on FGM adversarial examples: {}".format(
            test_acc_fgm))
        print("Test set accuracy on PGD adversarial examples: {}".format(
            test_acc_pgd))
Example 30
  print("Epoch {} in {:0.2f} sec".format(epoch, epoch_time))
  print("Training set accuracy {}".format(train_acc))
  print("Test set accuracy {}".format(test_acc))

"""# **Problem 2**

Before we get started, we need to import two small libraries that contain boilerplate code for common neural network layer types and for optimizers like mini-batch SGD.
"""

from jax.experimental import optimizers
from jax.experimental import stax

"""Here is a fully-connected neural network architecture, like the one of Problem 1, but this time defined with `stax`"""

init_random_params, predict = stax.serial(
    stax.Conv(256, (5, 5), strides=(2, 2)),
    stax.Relu,
    stax.Conv(128, (3, 3)),
    stax.Relu,
    stax.Conv(32, (3, 3)),
    stax.Relu,
    stax.MaxPool((2, 2)),
    stax.Flatten,
    stax.Dense(1024),
    stax.Relu,
    stax.Dense(128),
    stax.Relu,
    stax.Dense(10),
)

"""We redefine the cross-entropy loss for this model. As done in Problem 1, complete the return line below (it's identical)."""