def gen_cnn_conv4(output_units=10, W_initializers_str='glorot_normal()',
                  b_initializers_str='normal()'):
  # An up-scaled version of the CNN in the Keras tutorial:
  # https://keras.io/examples/cifar10_cnn/
  W_init, b_init = eval(W_initializers_str), eval(b_initializers_str)
  return stax.serial(
      stax.Conv(out_chan=64, filter_shape=(3, 3), W_init=W_init, b_init=b_init),
      stax.Relu,
      stax.Conv(out_chan=64, filter_shape=(3, 3), W_init=W_init, b_init=b_init),
      stax.Relu,
      stax.MaxPool((2, 2), strides=(2, 2)),
      stax.Conv(out_chan=128, filter_shape=(3, 3), W_init=W_init, b_init=b_init),
      stax.Relu,
      stax.Conv(out_chan=128, filter_shape=(3, 3), W_init=W_init, b_init=b_init),
      stax.Relu,
      stax.MaxPool((2, 2), strides=(2, 2)),
      stax.Flatten,
      stax.Dense(512, W_init=W_init, b_init=b_init),
      stax.Relu,
      stax.Dense(output_units, W_init=W_init, b_init=b_init))
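# A minimal usage sketch for the factory above (an illustrative addition; it
# assumes jax.example_libraries.stax semantics and that the eval()'d
# initializer names, glorot_normal and normal, are in scope in the module
# where gen_cnn_conv4 is defined, e.g. via jax.nn.initializers):
#
#   from jax import random
#   init_fun, apply_fun = gen_cnn_conv4(output_units=10)
#   out_shape, params = init_fun(random.PRNGKey(0), (-1, 32, 32, 3))  # NHWC
#   logits = apply_fun(params, images)  # images: (batch, 32, 32, 3)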
def MakeMain(input_shape):
  # The number of output channels depends on the number of input channels.
  # NOTE: filters1, filters2, and ks are free variables supplied by the
  # enclosing scope; this helper is an excerpt from a larger ConvBlock.
  return stax.serial(
      stax.Conv(filters1, (1, 1)), stax.BatchNorm(), stax.Relu,
      stax.Conv(filters2, (ks, ks), padding='SAME'), stax.BatchNorm(),
      stax.Relu,
      stax.Conv(input_shape[3], (1, 1)), stax.BatchNorm())
def WideResnetBlock(channels, strides=(1, 1), channel_mismatch=False):
  """WideResnet convolutional block."""
  main = stax.serial(
      stax.BatchNorm(), stax.Relu,
      stax.Conv(channels, (3, 3), strides, padding='SAME'),
      stax.BatchNorm(), stax.Relu,
      stax.Conv(channels, (3, 3), padding='SAME'))
  shortcut = stax.Identity if not channel_mismatch else stax.Conv(
      channels, (3, 3), strides, padding='SAME')
  return stax.serial(stax.FanOut(2), stax.parallel(main, shortcut),
                     stax.FanInSum)
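# Shape-check sketch for the FanOut/parallel/FanInSum residual pattern above
# (an illustrative addition, not from the original source): FanInSum requires
# the main and shortcut branches to produce identical output shapes, which
# holds here because the identity shortcut is only used when channels match.
from jax import random

init_fun, apply_fun = WideResnetBlock(channels=16)
out_shape, _ = init_fun(random.PRNGKey(0), (1, 8, 8, 16))  # NHWC input
assert out_shape == (1, 8, 8, 16)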
def ConvBlock(self, kernel_size, filters, strides=(2, 2)):
  filters1, filters2, filters3 = filters
  Main = stax.serial(
      stax.Conv(filters1, (1, 1), strides), stax.BatchNorm(), stax.Relu,
      stax.Conv(filters2, (kernel_size, kernel_size), padding='SAME'),
      stax.BatchNorm(), stax.Relu,
      stax.Conv(filters3, (1, 1)), stax.BatchNorm())
  Shortcut = stax.serial(stax.Conv(filters3, (1, 1), strides),
                         stax.BatchNorm())
  return stax.serial(stax.FanOut(2), stax.parallel(Main, Shortcut),
                     stax.FanInSum, stax.Relu)
def gen_cnn_lenet_caffe(output_units=10, W_initializers_str='glorot_normal()',
                        b_initializers_str='normal()'):
  W_init, b_init = eval(W_initializers_str), eval(b_initializers_str)
  return stax.serial(
      stax.Conv(out_chan=20, filter_shape=(5, 5), W_init=W_init, b_init=b_init),
      stax.Relu,
      stax.MaxPool((2, 2), strides=(2, 2)),
      stax.Conv(out_chan=50, filter_shape=(5, 5), W_init=W_init, b_init=b_init),
      stax.Relu,
      stax.MaxPool((2, 2), strides=(2, 2)),
      stax.Flatten,
      stax.Dense(500, W_init=W_init, b_init=b_init),
      stax.Relu,
      stax.Dense(output_units, W_init=W_init, b_init=b_init))
def wide_resnet_block(num_channels, strides=(1, 1), channel_mismatch=False):
  """Wide ResNet block."""
  pre = stax.serial(stax.BatchNorm(), stax.Relu)
  mid = stax.serial(
      pre,
      stax.Conv(num_channels, (3, 3), strides, padding='SAME'),
      stax.BatchNorm(), stax.Relu,
      stax.Conv(num_channels, (3, 3), strides=(1, 1), padding='SAME'))
  if channel_mismatch:
    cut = stax.serial(
        pre,
        stax.Conv(num_channels, (3, 3), strides, padding='SAME'))
  else:
    cut = stax.Identity
  return stax.serial(stax.FanOut(2), stax.parallel(mid, cut), stax.FanInSum)
def ConvBlock(kernel_size, filters, strides):
  """ResNet convolutional striding block."""
  ks = kernel_size
  filters1, filters2, filters3 = filters
  main = stax.serial(
      stax.Conv(filters1, (1, 1), strides), stax.BatchNorm(), stax.Relu,
      stax.Conv(filters2, (ks, ks), padding='SAME'), stax.BatchNorm(),
      stax.Relu,
      stax.Conv(filters3, (1, 1)), stax.BatchNorm())
  shortcut = stax.serial(stax.Conv(filters3, (1, 1), strides),
                         stax.BatchNorm())
  return stax.serial(stax.FanOut(2), stax.parallel(main, shortcut),
                     stax.FanInSum, stax.Relu)
def cnn(num_classes=10):
  return stax.serial(
      stax.Conv(16, (8, 8), padding='SAME', strides=(2, 2)),
      stax.Tanh,
      stax.MaxPool((2, 2), (1, 1)),
      stax.Conv(32, (4, 4), padding='VALID', strides=(2, 2)),
      stax.Tanh,
      stax.MaxPool((2, 2), (1, 1)),
      stax.Flatten,  # (-1, 800)
      stax.Dense(64), stax.Tanh,  # embeddings
      stax.Dense(num_classes),  # logits
  )
def cnn():
  return stax.serial(
      stax.Conv(16, (8, 8), padding='SAME', strides=(2, 2)),
      stax.Tanh,
      stax.MaxPool((2, 2), (1, 1)),
      stax.Conv(32, (4, 4), padding='VALID', strides=(2, 2)),
      stax.Tanh,
      stax.MaxPool((2, 2), (1, 1)),
      stax.Flatten,
      stax.Dense(32), stax.Tanh,  # embeddings
      stax.Dense(10),  # logits
  )
def test_conv(self):
  order = 3
  input_shape = (1, 5, 5, 1)
  key = random.PRNGKey(0)
  # TODO(duvenaud): Check all types of padding
  init_fun, apply_fun = stax.Conv(3, (2, 2), padding='VALID')
  _, (W, b) = init_fun(key, input_shape)

  rng = onp.random.RandomState(0)
  x = rng.randn(*input_shape).astype("float32")
  primals = (W, b, x)

  series_in1 = [rng.randn(*W.shape).astype("float32") for _ in range(order)]
  series_in2 = [rng.randn(*b.shape).astype("float32") for _ in range(order)]
  series_in3 = [rng.randn(*x.shape).astype("float32") for _ in range(order)]
  series_in = (series_in1, series_in2, series_in3)

  def f(W, b, x):
    return apply_fun((W, b), x)

  self.check_jet(f, primals, series_in, check_dtypes=False)
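# Context note (an assumption about the test helper, not from the original
# source): check_jet typically compares jax.experimental.jet's truncated
# Taylor propagation of f against reference derivatives, along the lines of:
#
#   from jax.experimental import jet
#   primal_out, series_out = jet.jet(f, primals, series_in)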
def Resnet50(hidden_size=64, num_output_classes=1001):
  """ResNet.

  Args:
    hidden_size: the size of the first hidden layer (multiplied later).
    num_output_classes: how many classes to distinguish.

  Returns:
    The ResNet model with the given layer and output sizes.
  """
  return stax.serial(
      stax.Conv(hidden_size, (7, 7), (2, 2), 'SAME'),
      stax.BatchNorm(), stax.Relu,
      stax.MaxPool((3, 3), strides=(2, 2)),
      ConvBlock(3, [hidden_size, hidden_size, 4 * hidden_size], (1, 1)),
      IdentityBlock(3, [hidden_size, hidden_size]),
      IdentityBlock(3, [hidden_size, hidden_size]),
      ConvBlock(3, [2 * hidden_size, 2 * hidden_size, 8 * hidden_size],
                (2, 2)),
      IdentityBlock(3, [2 * hidden_size, 2 * hidden_size]),
      IdentityBlock(3, [2 * hidden_size, 2 * hidden_size]),
      IdentityBlock(3, [2 * hidden_size, 2 * hidden_size]),
      ConvBlock(3, [4 * hidden_size, 4 * hidden_size, 16 * hidden_size],
                (2, 2)),
      IdentityBlock(3, [4 * hidden_size, 4 * hidden_size]),
      IdentityBlock(3, [4 * hidden_size, 4 * hidden_size]),
      IdentityBlock(3, [4 * hidden_size, 4 * hidden_size]),
      IdentityBlock(3, [4 * hidden_size, 4 * hidden_size]),
      IdentityBlock(3, [4 * hidden_size, 4 * hidden_size]),
      ConvBlock(3, [8 * hidden_size, 8 * hidden_size, 32 * hidden_size],
                (2, 2)),
      IdentityBlock(3, [8 * hidden_size, 8 * hidden_size]),
      IdentityBlock(3, [8 * hidden_size, 8 * hidden_size]),
      stax.AvgPool((7, 7)),
      stax.Flatten,
      stax.Dense(num_output_classes),
      stax.LogSoftmax)
def testConvShape(self, channels, filter_shape, padding, strides,
                  input_shape):
  init_fun, apply_fun = stax.Conv(channels, filter_shape, strides=strides,
                                  padding=padding)
  _CheckShapeAgreement(self, init_fun, apply_fun, input_shape)
def MyConv(*args, parameterization='standard', order=None, **kwargs):
  """Wrapper for convolutional layer with different parameterizations."""
  if parameterization == 'standard':
    return jax_stax.Conv(*args, **kwargs)
  elif parameterization == 'ntk':
    # The NTK-parameterized Conv returns an (init_fn, apply_fn, kernel_fn)
    # triple; keep only init/apply for drop-in compatibility.
    return stax.Conv(*args, b_std=1.0, **kwargs)[:2]
  elif parameterization == 'taylor':
    return TaylorConv(*args, b_std=1.0, order=order, **kwargs)
  else:
    raise ValueError('Unknown parameterization: {}'.format(parameterization))
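# Hedged usage sketch (assumptions: jax_stax is jax.example_libraries.stax
# and stax here is neural_tangents.stax, whose Conv returns an
# (init_fn, apply_fn, kernel_fn) triple, hence the [:2] slice above):
init_fn, apply_fn = MyConv(64, (3, 3), padding='SAME',
                           parameterization='ntk')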
def wide_resnet(n, k, num_classes):
  """Original WRN from paper and previous experiments."""
  return stax.serial(
      stax.Conv(16, (3, 3), padding='SAME'),
      wide_resnet_group(n, 16 * k, strides=(1, 1)),
      wide_resnet_group(n, 32 * k, strides=(2, 2)),
      wide_resnet_group(n, 64 * k, strides=(2, 2)),
      stax.BatchNorm(), stax.Relu,
      stax.AvgPool((8, 8)),
      stax.Flatten,
      stax.Dense(num_classes))
def conv():
  init_fun, predict = stax.serial(
      stax.Conv(16, (8, 8), padding='SAME', strides=(2, 2)),
      stax.Relu,
      stax.MaxPool((2, 2), (1, 1)),
      stax.Conv(32, (4, 4), padding='VALID', strides=(2, 2)),
      stax.Relu,
      stax.MaxPool((2, 2), (1, 1)),
      stax.Flatten,
      stax.Dense(32),
      stax.Relu,
      stax.Dense(10),
  )

  def init_params(rng):
    return init_fun(rng, (-1, 28, 28, 1))[1]

  return init_params, predict
def Cnn(n_actions: int, hidden_size: int = 512) -> Module:
  return stax.serial(
      stax.Conv(32, (8, 8), (4, 4), "VALID"),
      stax.Relu,
      stax.Conv(64, (4, 4), (2, 2), "VALID"),
      stax.Relu,
      stax.Conv(64, (3, 3), (1, 1), "VALID"),
      stax.Relu,
      stax.Flatten,
      stax.Dense(hidden_size),
      stax.Relu,
      stax.FanOut(2),
      stax.parallel(
          stax.serial(stax.Dense(n_actions), stax.Softmax),  # actor
          stax.serial(stax.Dense(1)),  # critic
      ),
  )
def transform(rng, input_dim, output_dim):
  init_fun, apply_fun = stax.serial(
      Reshape(),
      stax.Conv(8, filter_shape=(3, 3), W_init=weight_initializer,
                b_init=weight_initializer),
      act,
      stax.Conv(16, filter_shape=(3, 3), W_init=weight_initializer,
                b_init=weight_initializer),
      act,
      stax.Flatten,
      stax.Dense(output_dim, W_init=weight_initializer,
                 b_init=weight_initializer),
  )
  _, params = init_fun(rng, (input_dim,))
  return params, apply_fun
def make_conv(strides=None, num_channels=256):
  return stax.Conv(
      out_chan=num_channels,
      filter_shape=(3, 3),
      padding="VALID",
      strides=strides,
      W_init=nn.initializers.he_normal(),
      b_init=nn.initializers.zeros,
  )
def conv2d(num_classes, layers=((32, 5, 2), (16, 3, 2), (16, 3, 2))):
  """Builds a simple convolutional neural network."""
  stack = []
  # Concatenate convolutional layers.
  for num_units, kernel_size, stride in layers:
    stack += [
        stax.Conv(num_units, (kernel_size, kernel_size), (stride, stride),
                  padding='SAME'),
        stax.Relu
    ]
  # Output layer.
  stack += [stax.Flatten, stax.Dense(num_classes), stax.LogSoftmax]
  return stax.serial(*stack)
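# Usage sketch (an illustrative addition, not from the original source):
# each (num_units, kernel_size, stride) triple becomes one Conv+Relu stage.
from jax import random

init_fun, apply_fun = conv2d(num_classes=10)
out_shape, params = init_fun(random.PRNGKey(0), (-1, 28, 28, 1))  # NHWC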
def WideResnet(num_blocks=3, hidden_size=64, num_output_classes=10):
  """WideResnet from https://arxiv.org/pdf/1605.07146.pdf.

  Args:
    num_blocks: int, number of blocks in a group.
    hidden_size: the size of the first hidden layer (multiplied later).
    num_output_classes: int, number of classes to distinguish.

  Returns:
    The WideResnet model with given layer and output sizes.
  """
  return stax.serial(
      stax.Conv(hidden_size, (3, 3), padding='SAME'),
      WideResnetGroup(num_blocks, hidden_size),
      WideResnetGroup(num_blocks, hidden_size * 2, (2, 2)),
      WideResnetGroup(num_blocks, hidden_size * 4, (2, 2)),
      stax.BatchNorm(), stax.Relu,
      stax.AvgPool((8, 8)),
      stax.Flatten,
      stax.Dense(num_output_classes),
      stax.LogSoftmax)
def cnn(conv_depth=300, kernel_size=5, n_conv_layers=2, across_batch=False,
        add_pos_encoding=False):
  """Build convolutional neural net."""
  # Input shape: [batch x length x depth]
  if across_batch:
    extra_dim = 0
  else:
    extra_dim = 1
  layers = [ExpandDims(axis=extra_dim)]
  if add_pos_encoding:
    layers.append(positional_encoding())
  for _ in range(n_conv_layers):
    layers.append(
        stax.Conv(conv_depth, (1, kernel_size), padding="same",
                  strides=(1, 1)))
    layers.append(stax.Relu)
  layers.append(AssertNonZeroShape())
  layers.append(squeeze_layer(axis=extra_dim))
  return stax.serial(*layers)
num_batches = num_complete_batches + bool(leftover)


# %%
def data_stream():
  rng = npr.RandomState(0)
  while True:
    perm = rng.permutation(num_train)
    for i in range(num_batches):
      batch_idx = perm[i * batch_size:(i + 1) * batch_size]
      yield train_images[batch_idx], train_labels[batch_idx]


batches = data_stream()

init_fun, net = stax.serial(
    stax.Conv(16, (3, 3), (1, 1), padding="SAME"), stax.Relu,
    stax.MaxPool((2, 2), (2, 2), padding="SAME"),
    stax.Conv(32, (3, 3), (1, 1), padding="SAME"), stax.Relu,
    stax.MaxPool((2, 2), (2, 2), padding="SAME"),
    stax.Flatten,
    stax.Dense(10), stax.LogSoftmax)

# stax.Conv defaults to NHWC layout, so MNIST batches should be shaped
# (batch, 28, 28, 1) rather than (batch, 1, 28, 28).
_, params = init_fun(key, (64, 28, 28, 1))


def loss(params, batch):
  inputs, targets = batch
  preds = net(params, inputs)
  return -jnp.mean(jnp.sum(preds * targets, axis=1))
# Test adversarial examples over here:
# test_newx = computation(params, (train_images, train_labels))
# print(test_newx)

"""# **Problem 2**

Before we get started, we need to import two small libraries that contain
boilerplate code for common neural network layer types and for optimizers
like mini-batch SGD.
"""

from jax.experimental import optimizers
from jax.experimental import stax
import matplotlib.pyplot as plt

"""Here is a neural network architecture like the one of Problem 1, but this
time convolutional and defined with `stax`."""

init_random_params, predict = stax.serial(
    stax.Conv(64, (8, 8), padding='SAME', strides=(2, 2)),
    stax.Relu,
    stax.MaxPool((2, 2), (1, 1)),
    stax.Conv(128, (4, 4), padding='VALID', strides=(2, 2)),
    stax.Relu,
    stax.MaxPool((2, 2), (1, 1)),
    stax.Flatten,
    stax.Dense(128),
    stax.Relu,
    stax.Dense(10),
    stax.LogSoftmax)

"""We redefine the cross-entropy loss for this model. As done in Problem 1,
complete the return line below (it's identical)."""


def loss(params, batch):
  inputs, targets = batch
  logits = predict(params, inputs)
  preds = stax.logsoftmax(logits)
  return -np.mean(preds * targets)


"""Next, we define the mini-batch SGD optimizer, this time with the
optimizers library in JAX."""
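# A minimal sketch of the optimizer setup announced above (an assumption
# based on the jax.experimental.optimizers API already imported; the step
# size is illustrative):
from jax import grad, jit

opt_init, opt_update, get_params = optimizers.sgd(step_size=0.1)


@jit
def update(i, opt_state, batch):
  params = get_params(opt_state)
  return opt_update(i, grad(loss)(params, batch), opt_state)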
flags.DEFINE_float('learning_rate', .10, 'Learning rate for finetuning.')
flags.DEFINE_integer('batch_size', 256, 'Batch size')
flags.DEFINE_integer('epochs', 100, 'Number of finetuning epochs')
flags.DEFINE_integer('seed', 0, 'Seed for jax PRNG')
flags.DEFINE_integer(
    'uncertain', 0,
    '0: entropy; '
    '1: difference between 1st_prob and 2nd_prob; '
    '2: random')
flags.DEFINE_integer('n_extra', 3000, 'number of extra points')
flags.DEFINE_bool(
    'show_label', True,
    'visualize predicted label at top/left, true at bottom/right')

# BEGIN: define the classifier model
init_fn_0, apply_fn_0 = stax.serial(
    stax.Conv(16, (8, 8), padding='SAME', strides=(2, 2)),
    stax.Tanh,
    stax.MaxPool((2, 2), (1, 1)),
    stax.Conv(32, (4, 4), padding='VALID', strides=(2, 2)),
    stax.Tanh,
    stax.MaxPool((2, 2), (1, 1)),
    stax.Flatten,  # (-1, 800)
    stax.Dense(32),
    stax.Tanh,  # embeddings
)
init_fn_1, apply_fn_1 = stax.serial(
    stax.Dense(10),  # logits
)
def main(_):
  logging.info('Starting experiment.')
  configs = FLAGS.config

  # Create model folder for outputs
  try:
    gfile.MakeDirs(FLAGS.exp_dir)
  except gfile.GOSError:
    pass
  stdout_log = gfile.Open('{}/stdout.log'.format(FLAGS.exp_dir), 'w+')

  # learning rate schedule and optimizer
  def cosine(initial_step_size, train_steps):
    k = np.pi / (2.0 * train_steps)

    def schedule(i):
      return initial_step_size * np.cos(k * i)

    return schedule

  if configs.optimization == 'sgd':
    lr_schedule = optimizers.make_schedule(configs.learning_rate)
    opt_init, opt_update, get_params = optimizers.sgd(lr_schedule)
  elif configs.optimization == 'momentum':
    lr_schedule = cosine(configs.learning_rate, configs.train_steps)
    opt_init, opt_update, get_params = optimizers.momentum(lr_schedule, 0.9)
  else:
    raise ValueError('Optimizer not implemented.')

  with gfile.Open(FLAGS.pretrained_dir, 'rb') as fpre:
    pretrained_opt_state = optimizers.pack_optimizer_state(pickle.load(fpre))
  fixed_params = get_params(pretrained_opt_state)[:7]

  # BEGIN: define the classifier model
  init_fn_0, apply_fn_0 = stax.serial(
      stax.Conv(16, (8, 8), padding='SAME', strides=(2, 2)),
      stax.Tanh,
      stax.MaxPool((2, 2), (1, 1)),
      stax.Conv(32, (4, 4), padding='VALID', strides=(2, 2)),
      stax.Tanh,
      stax.MaxPool((2, 2), (1, 1)),
      stax.Flatten,  # representations
  )
  init_fn_1, apply_fn_1 = stax.serial(
      stax.Dense(64),
      stax.Tanh,  # embeddings
      stax.Dense(10),  # logits
  )

  def predict(params, inputs):
    representations = apply_fn_0(fixed_params, inputs)  # use pretrained params
    logits = apply_fn_1(params, representations)
    return logits
  # END: define the classifier model

  if configs.seed is not None:
    key = random.PRNGKey(configs.seed)
  else:
    key = random.PRNGKey(int(time.time()))

  _, _ = init_fn_0(key, (-1, 32, 32, 3))
  _, params = init_fn_1(key, (-1, 800))
  opt_state = opt_init(params)

  logging.info('Loading data.')
  tic = time.time()

  train_images, train_labels, _ = datasets.get_dataset_split(
      FLAGS.dataset, 'train')
  train_mu, train_std = onp.mean(train_images), onp.std(train_images)
  n_train = len(train_images)
  train = data.DataChunk(
      X=(train_images - train_mu) / train_std,
      Y=train_labels,
      image_size=32,
      image_channels=3,
      label_dim=1,
      label_format='numeric')

  test_images, test_labels, _ = datasets.get_dataset_split(
      FLAGS.dataset, 'test')
  test = data.DataChunk(
      X=(test_images - train_mu) / train_std,  # normalize w train mean/std
      Y=test_labels,
      image_size=32,
      image_channels=3,
      label_dim=1,
      label_format='numeric')

  # Data augmentation
  if configs.augment_data:
    augmentation = data.chain_transforms(
        data.RandomHorizontalFlip(0.5), data.RandomCrop(4), data.ToDevice)
  else:
    augmentation = None
  batch = data.minibatcher(train, configs.batch_size, transform=augmentation)

  # count params of JAX model
  def count_parameters(params):
    return tree_util.tree_reduce(
        operator.add, tree_util.tree_map(lambda x: np.prod(x.shape), params))

  logging.info('Number of parameters: %d', count_parameters(params))
  stdout_log.write('Number of params: {}\n'.format(count_parameters(params)))

  # loss functions
  def cross_entropy_loss(params, x_img, y_lbl):
    return -np.mean(stax.logsoftmax(predict(params, x_img)) * y_lbl)

  def mse_loss(params, x_img, y_lbl):
    return 0.5 * np.mean((y_lbl - predict(params, x_img))**2)

  def accuracy(y_lbl_hat, y_lbl):
    target_class = np.argmax(y_lbl, axis=1)
    predicted_class = np.argmax(y_lbl_hat, axis=1)
    return np.mean(predicted_class == target_class)

  # Loss and gradient
  if configs.loss == 'xent':
    loss = cross_entropy_loss
  elif configs.loss == 'mse':
    loss = mse_loss
  else:
    raise ValueError('Loss function not implemented.')
  grad_loss = jit(grad(loss))

  def private_grad(params, batch, rng, l2_norm_clip, noise_multiplier,
                   batch_size):
    """Return differentially private gradients of params, evaluated on batch."""

    def _clipped_grad(params, single_example_batch):
      """Evaluate gradient for a single-example batch and clip its grad norm."""
      grads = grad_loss(params,
                        single_example_batch[0].reshape((-1, 32, 32, 3)),
                        single_example_batch[1])
      nonempty_grads, tree_def = tree_util.tree_flatten(grads)
      total_grad_norm = np.linalg.norm(
          [np.linalg.norm(neg.ravel()) for neg in nonempty_grads])
      divisor = stop_gradient(np.amax((total_grad_norm / l2_norm_clip, 1.)))
      normalized_nonempty_grads = [neg / divisor for neg in nonempty_grads]
      return tree_util.tree_unflatten(tree_def, normalized_nonempty_grads)

    px_clipped_grad_fn = vmap(partial(_clipped_grad, params))
    std_dev = l2_norm_clip * noise_multiplier
    noise_ = lambda n: n + std_dev * random.normal(rng, n.shape)
    normalize_ = lambda n: n / float(batch_size)
    sum_ = lambda n: np.sum(n, 0)  # aggregate
    aggregated_clipped_grads = tree_util.tree_map(sum_,
                                                  px_clipped_grad_fn(batch))
    noised_aggregated_clipped_grads = tree_util.tree_map(
        noise_, aggregated_clipped_grads)
    normalized_noised_aggregated_clipped_grads = tree_util.tree_map(
        normalize_, noised_aggregated_clipped_grads)
    return normalized_noised_aggregated_clipped_grads

  # summarize measurements
  steps_per_epoch = n_train // configs.batch_size

  def summarize(step, params):
    """Compute measurements in a zipped way."""
    set_entries = [train, test]
    set_bsizes = [configs.train_eval_bsize, configs.test_eval_bsize]
    set_names, loss_dict, acc_dict = ['train', 'test'], {}, {}

    for set_entry, set_bsize, set_name in zip(set_entries, set_bsizes,
                                              set_names):
      temp_loss, temp_acc, points = 0.0, 0.0, 0
      for b in data.batch(set_entry, set_bsize):
        temp_loss += loss(params, b.X, b.Y) * b.X.shape[0]
        temp_acc += accuracy(predict(params, b.X), b.Y) * b.X.shape[0]
        points += b.X.shape[0]
      loss_dict[set_name] = temp_loss / float(points)
      acc_dict[set_name] = temp_acc / float(points)

    logging.info('Step: %s', str(step))
    logging.info('Train acc : %.4f', acc_dict['train'])
    logging.info('Train loss: %.4f', loss_dict['train'])
    logging.info('Test acc : %.4f', acc_dict['test'])
    logging.info('Test loss : %.4f', loss_dict['test'])

    stdout_log.write('Step: {}\n'.format(step))
    stdout_log.write('Train acc : {}\n'.format(acc_dict['train']))
    stdout_log.write('Train loss: {}\n'.format(loss_dict['train']))
    stdout_log.write('Test acc : {}\n'.format(acc_dict['test']))
    stdout_log.write('Test loss : {}\n'.format(loss_dict['test']))
    stdout_log.flush()

    return acc_dict['test']

  toc = time.time()
  logging.info('Elapsed SETUP time: %s', str(toc - tic))
  stdout_log.write('Elapsed SETUP time: {}\n'.format(toc - tic))

  # BEGIN: training steps
  logging.info('Training network.')
  tic = time.time()
  t = time.time()

  for s in range(configs.train_steps):
    b = next(batch)
    params = get_params(opt_state)

    # t0 = time.time()
    if FLAGS.dpsgd:
      key = random.fold_in(key, s)  # get new key for new random numbers
      opt_state = opt_update(
          s,
          private_grad(params, (b.X.reshape((-1, 1, 32, 32, 3)), b.Y), key,
                       configs.l2_norm_clip, configs.noise_multiplier,
                       configs.batch_size), opt_state)
    else:
      opt_state = opt_update(s, grad_loss(params, b.X, b.Y), opt_state)
    # t1 = time.time()
    # logging.info('batch update time: %s', str(t1 - t0))

    if s % steps_per_epoch == 0:
      with gfile.Open(
          '{}/ckpt_{}'.format(FLAGS.exp_dir, int(s / steps_per_epoch)),
          'wb') as fckpt:
        pickle.dump(optimizers.unpack_optimizer_state(opt_state), fckpt)

      if FLAGS.dpsgd:
        eps = compute_epsilon(s, configs.batch_size, n_train,
                              configs.target_delta, configs.noise_multiplier)
        stdout_log.write(
            'For delta={:.0e}, current epsilon is: {:.2f}\n'.format(
                configs.target_delta, eps))

      logging.info('Elapsed EPOCH time: %s', str(time.time() - t))
      stdout_log.write('Elapsed EPOCH time: {}'.format(time.time() - t))
      stdout_log.flush()
      t = time.time()

  toc = time.time()
  summarize(configs.train_steps, params)
  logging.info('Elapsed TRAIN time: %s', str(toc - tic))
  stdout_log.write('Elapsed TRAIN time: {}'.format(toc - tic))
  stdout_log.close()
def main(_):
  rng = random.PRNGKey(0)

  # Load MNIST dataset
  train_images, train_labels, test_images, test_labels = datasets.mnist()

  batch_size = 128
  batch_shape = (-1, 28, 28, 1)
  num_train = train_images.shape[0]
  num_complete_batches, leftover = divmod(num_train, batch_size)
  num_batches = num_complete_batches + bool(leftover)

  train_images = np.reshape(train_images, batch_shape)
  test_images = np.reshape(test_images, batch_shape)

  def data_stream():
    rng = npr.RandomState(0)
    while True:
      perm = rng.permutation(num_train)
      for i in range(num_batches):
        batch_idx = perm[i * batch_size:(i + 1) * batch_size]
        yield train_images[batch_idx], train_labels[batch_idx]

  def save(fn, opt_state):
    params = deepcopy(get_params(opt_state))
    save_dict = {}
    for idx, p in enumerate(params):
      if p != ():
        pp = (p[0].tolist(), p[1].tolist())
        params[idx] = pp
    save_dict["params"] = params
    with open(fn, "w") as f:
      json.dump(save_dict, f)

  def load(fn):
    with open(fn, "r") as f:
      params = json.load(f)
    params = params["params"]
    for idx, p in enumerate(params):
      if p != []:
        pp = (np.array(p[0]), np.array(p[1]))
        params[idx] = pp
      else:
        params[idx] = ()
    return opt_init(params)

  batches = data_stream()

  # Model, loss, and accuracy functions
  init_random_params, predict = stax.serial(
      stax.Conv(32, (8, 8), strides=(2, 2), padding="SAME"),
      stax.Relu,
      stax.Conv(128, (6, 6), strides=(2, 2), padding="VALID"),
      stax.Relu,
      stax.Conv(128, (5, 5), strides=(1, 1), padding="VALID"),
      stax.Flatten,
      stax.Dense(128),
      stax.Relu,
      stax.Dense(10),
  )

  def loss(params, batch):
    inputs, targets = batch
    preds = predict(params, inputs)
    return -np.mean(logsoftmax(preds) * targets)

  def accuracy(params, batch):
    inputs, targets = batch
    target_class = np.argmax(targets, axis=1)
    predicted_class = np.argmax(predict(params, inputs), axis=1)
    return np.mean(predicted_class == target_class)

  def gen_ellipsoid(X, zeta_rel, zeta_const, alpha, N_steps):
    zeta = (np.abs(X).T * zeta_rel).T + zeta_const
    if alpha is None:
      alpha = 1 / N_steps * zeta
    else:
      assert isinstance(alpha, float), "Alpha must be float"
      alpha = alpha * np.ones_like(X)
    return zeta, alpha

  def gen_ellipsoid_match_volume(X, zeta_const, eps, alpha, N_steps):
    x_norms = np.linalg.norm(np.reshape(X, (X.shape[0], -1)), ord=1, axis=1)
    N = np.prod(X.shape[1:])
    zeta_rel = N * (eps - zeta_const) / x_norms
    assert (zeta_rel <= 1.0).all(), (
        "Zeta rel cannot be larger than 1. "
        "Please increase zeta const or reduce eps")
    zeta_rel = np.clip(zeta_rel, 0.0, 1.0)
    return gen_ellipsoid(X, zeta_rel, zeta_const, alpha, N_steps)

  # Instantiate an optimizer
  opt_init, opt_update, get_params = optimizers.adam(0.001)

  @jit
  def update(i, opt_state, batch):
    params = get_params(opt_state)
    return opt_update(i, grad(loss)(params, batch), opt_state)

  # Initialize model
  _, init_params = init_random_params(rng, batch_shape)
  opt_state = opt_init(init_params)
  itercount = itertools.count()

  try:
    opt_state = load("tutorials/jax/test_model.json")
  except:
    # Training loop
    print("\nStarting training...")
    start_time = time.time()
    for _ in range(num_batches):
      opt_state = update(next(itercount), opt_state, next(batches))
    epoch_time = time.time() - start_time
    save("tutorials/jax/test_model.json", opt_state)

  # Evaluate model on clean data
  params = get_params(opt_state)

  # Evaluate model on adversarial data
  model_fn = lambda images: predict(params, images)

  # Generate single attacking test image
  idx = 0
  plt.figure(figsize=(15, 6), constrained_layout=True)

  zeta, alpha = gen_ellipsoid(
      X=test_images[idx].reshape((1, 28, 28, 1)),
      zeta_rel=FLAGS.zeta_rel,
      zeta_const=FLAGS.zeta_const,
      alpha=None,
      N_steps=40)
  # zeta, alpha = gen_ellipsoid_match_volume(
  #     X=test_images[idx].reshape((1, 28, 28, 1)),
  #     zeta_const=FLAGS.zeta_const, eps=FLAGS.eps, alpha=None, N_steps=40)
  test_images_pgd_ellipsoid = projected_gradient_descent(
      model_fn, test_images[idx].reshape((1, 28, 28, 1)), zeta, alpha, 40,
      np.inf)
  predict_pgd_ellipsoid = np.argmax(
      predict(params, test_images_pgd_ellipsoid), axis=1)

  test_images_fgm = fast_gradient_method(
      model_fn, test_images[idx].reshape((1, 28, 28, 1)), 0.075, np.inf)
  predict_fgm = np.argmax(predict(params, test_images_fgm), axis=1)

  test_images_pgd = projected_gradient_descent(
      model_fn, test_images[idx].reshape((1, 28, 28, 1)), FLAGS.eps, 0.01,
      40, 2)
  predict_pgd = np.argmax(predict(params, test_images_pgd), axis=1)

  base = 100
  f_ = lambda x: np.log(x) / np.log(base)
  a = base - 1
  transform = 1 + a * test_images[idx].reshape((1, 28, 28, 1))  # [1, base]
  # test_images_pgd_transform = projected_gradient_descent(
  #     model_fn, f_(np.where(transform > base, base, transform)),
  #     FLAGS.zeta_rel, 0.01, 40, np.inf)
  test_images_pgd_transform = projected_gradient_descent(
      model_fn, f_(np.where(transform > base, base, transform)), 1.8, 0.01,
      40, 2)
  test_images_pgd_transform = np.clip(test_images_pgd_transform, 0.0, 1.0)
  test_images_pgd_transform = (base**test_images_pgd_transform - 1) / a
  predict_transform = np.argmax(
      predict(params, test_images_pgd_transform), axis=1)

  plt.subplot(151)
  plt.imshow(np.squeeze(test_images[idx]), cmap='gray')
  plt.title("Original")
  plt.subplot(152)
  plt.imshow(np.squeeze(test_images_fgm), cmap='gray')
  plt.title(f"FGM L-Inf Pred: {predict_fgm}")
  plt.subplot(153)
  plt.imshow(np.squeeze(test_images_pgd), cmap='gray')
  plt.title(f"PGD L2 {predict_pgd}")
  plt.subplot(154)
  plt.imshow(np.squeeze(test_images_pgd_ellipsoid), cmap='gray')
  plt.title(f"PGD Ellipsoid L-Inf Pred: {predict_pgd_ellipsoid}")
  plt.subplot(155)
  plt.imshow(np.squeeze(test_images_pgd_transform), cmap='gray')
  plt.title(f"PGD log{base} L2 Pred: {predict_transform}")
  plt.show()

  transform = 1 + a * test_images
  test_images_pgd_transform = projected_gradient_descent(
      model_fn, f_(np.where(transform > base, base, transform)),
      FLAGS.zeta_rel, 0.01, 40, np.inf)
  test_images_pgd_transform = np.clip(test_images_pgd_transform, 0.0, 1.0)
  test_images_pgd_transform = (base**test_images_pgd_transform - 1) / a
  test_acc_pgd_transform = accuracy(params,
                                    (test_images_pgd_transform, test_labels))

  # Generate whole attacking test images
  # zeta, alpha = gen_ellipsoid(X=test_images, zeta_rel=FLAGS.zeta_rel,
  #                             zeta_const=FLAGS.zeta_const, alpha=None,
  #                             N_steps=40)
  zeta, alpha = gen_ellipsoid_match_volume(
      X=test_images,
      zeta_const=FLAGS.zeta_const,
      eps=FLAGS.eps,
      alpha=None,
      N_steps=40)
  test_images_pgd_ellipsoid = projected_gradient_descent(
      model_fn, test_images, zeta, alpha, 40, np.inf)
  test_acc_pgd_ellipsoid = accuracy(params,
                                    (test_images_pgd_ellipsoid, test_labels))

  test_images_fgm = fast_gradient_method(model_fn, test_images, FLAGS.eps,
                                         np.inf)
  test_images_pgd = projected_gradient_descent(model_fn, test_images,
                                               FLAGS.eps, 0.01, 40, np.inf)
  test_acc_fgm = accuracy(params, (test_images_fgm, test_labels))
  test_acc_pgd = accuracy(params, (test_images_pgd, test_labels))

  train_acc = accuracy(params, (train_images, train_labels))
  test_acc = accuracy(params, (test_images, test_labels))

  print("Training set accuracy: {}".format(train_acc))
  print("Test set accuracy on clean examples: {}".format(test_acc))
  print("Test set accuracy on FGM adversarial examples: {}".format(
      test_acc_fgm))
  print("Test set accuracy on PGD adversarial examples: {}".format(
      test_acc_pgd))
  print("Test set accuracy on PGD Ellipsoid adversarial examples: {}".format(
      test_acc_pgd_ellipsoid))
  print(
      "Test set accuracy on PGD Ellipsoid via transform adversarial examples: {}"
      .format(test_acc_pgd_transform))
class StaxTest(jtu.JaxTestCase):

  @parameterized.named_parameters(
      jtu.cases_from_list({
          "testcase_name": "_shape={}".format(shape),
          "shape": shape
      } for shape in [(2, 3), (5,)]))
  def testRandnInitShape(self, shape):
    out = stax.randn()(shape)
    self.assertEqual(out.shape, shape)

  @parameterized.named_parameters(
      jtu.cases_from_list({
          "testcase_name": "_shape={}".format(shape),
          "shape": shape
      } for shape in [(2, 3), (2, 3, 4)]))
  def testGlorotInitShape(self, shape):
    out = stax.glorot()(shape)
    self.assertEqual(out.shape, shape)

  @parameterized.named_parameters(
      jtu.cases_from_list({
          "testcase_name":
              "_channels={}_filter_shape={}_padding={}_strides={}_input_shape={}"
              .format(channels, filter_shape, padding, strides, input_shape),
          "channels": channels,
          "filter_shape": filter_shape,
          "padding": padding,
          "strides": strides,
          "input_shape": input_shape
      } for channels in [2, 3]
        for filter_shape in [(1, 1), (2, 3)]
        for padding in ["SAME", "VALID"]
        for strides in [None, (2, 1)]
        for input_shape in [(2, 10, 11, 1)]))
  def testConvShape(self, channels, filter_shape, padding, strides,
                    input_shape):
    init_fun, apply_fun = stax.Conv(channels, filter_shape, strides=strides,
                                    padding=padding)
    _CheckShapeAgreement(self, init_fun, apply_fun, input_shape)

  @parameterized.named_parameters(
      jtu.cases_from_list({
          "testcase_name":
              "_out_dim={}_input_shape={}".format(out_dim, input_shape),
          "out_dim": out_dim,
          "input_shape": input_shape
      } for out_dim in [3, 4] for input_shape in [(2, 3), (3, 4)]))
  def testDenseShape(self, out_dim, input_shape):
    init_fun, apply_fun = stax.Dense(out_dim)
    _CheckShapeAgreement(self, init_fun, apply_fun, input_shape)

  @parameterized.named_parameters(
      jtu.cases_from_list({
          "testcase_name": "_input_shape={}".format(input_shape),
          "input_shape": input_shape
      } for input_shape in [(2, 3), (2, 3, 4)]))
  def testReluShape(self, input_shape):
    init_fun, apply_fun = stax.Relu
    _CheckShapeAgreement(self, init_fun, apply_fun, input_shape)

  @parameterized.named_parameters(
      jtu.cases_from_list({
          "testcase_name":
              "_window_shape={}_padding={}_strides={}_input_shape={}".format(
                  window_shape, padding, strides, input_shape),
          "window_shape": window_shape,
          "padding": padding,
          "strides": strides,
          "input_shape": input_shape
      } for window_shape in [(1, 1), (2, 3)]
        for padding in ["VALID"]
        for strides in [None, (2, 1)]
        for input_shape in [(2, 5, 6, 1)]))
  def testPoolingShape(self, window_shape, padding, strides, input_shape):
    init_fun, apply_fun = stax.MaxPool(window_shape, padding=padding,
                                       strides=strides)
    _CheckShapeAgreement(self, init_fun, apply_fun, input_shape)

  @parameterized.named_parameters(
      jtu.cases_from_list({
          "testcase_name": "_shape={}".format(input_shape),
          "input_shape": input_shape
      } for input_shape in [(2, 3), (2, 3, 4)]))
  def testFlattenShape(self, input_shape):
    init_fun, apply_fun = stax.Flatten
    _CheckShapeAgreement(self, init_fun, apply_fun, input_shape)

  @parameterized.named_parameters(
      jtu.cases_from_list({
          "testcase_name":
              "_input_shape={}_spec={}".format(input_shape, i),
          "input_shape": input_shape,
          "spec": spec
      } for input_shape in [(2, 5, 6, 1)]
        for i, spec in enumerate([
            [stax.Conv(3, (2, 2))],
            [stax.Conv(3, (2, 2)), stax.Flatten, stax.Dense(4)]
        ])))
  def testSerialComposeLayersShape(self, input_shape, spec):
    init_fun, apply_fun = stax.serial(*spec)
    _CheckShapeAgreement(self, init_fun, apply_fun, input_shape)

  @parameterized.named_parameters(
      jtu.cases_from_list({
          "testcase_name": "_input_shape={}".format(input_shape),
          "input_shape": input_shape
      } for input_shape in [(3, 4), (2, 5, 6, 1)]))
  def testDropoutShape(self, input_shape):
    init_fun, apply_fun = stax.Dropout(0.9)
    _CheckShapeAgreement(self, init_fun, apply_fun, input_shape)

  @parameterized.named_parameters(
      jtu.cases_from_list({
          "testcase_name": "_input_shape={}".format(input_shape),
          "input_shape": input_shape
      } for input_shape in [(3, 4), (2, 5, 6, 1)]))
  def testFanInSum(self, input_shape):
    init_fun, apply_fun = stax.FanInSum
    _CheckShapeAgreement(self, init_fun, apply_fun,
                         [input_shape, input_shape])

  @parameterized.named_parameters(
      jtu.cases_from_list({
          "testcase_name":
              "_inshapes={}_axis={}".format(input_shapes, axis),
          "input_shapes": input_shapes,
          "axis": axis
      } for input_shapes, axis in [
          ([(2, 3), (2, 1)], 1),
          ([(2, 3), (2, 1)], -1),
          ([(1, 2, 4), (1, 1, 4)], 1),
      ]))
  def testFanInConcat(self, input_shapes, axis):
    init_fun, apply_fun = stax.FanInConcat(axis)
    _CheckShapeAgreement(self, init_fun, apply_fun, input_shapes)
class StaxTest(jtu.JaxTestCase):

  @parameterized.named_parameters(
      jtu.cases_from_list({
          "testcase_name": "_shape={}".format(shape),
          "shape": shape
      } for shape in [(2, 3), (5,)]))
  def testRandnInitShape(self, shape):
    key = random.PRNGKey(0)
    out = stax.randn()(key, shape)
    self.assertEqual(out.shape, shape)

  @parameterized.named_parameters(
      jtu.cases_from_list({
          "testcase_name": "_shape={}".format(shape),
          "shape": shape
      } for shape in [(2, 3), (2, 3, 4)]))
  def testGlorotInitShape(self, shape):
    key = random.PRNGKey(0)
    out = stax.glorot()(key, shape)
    self.assertEqual(out.shape, shape)

  @parameterized.named_parameters(
      jtu.cases_from_list({
          "testcase_name":
              "_channels={}_filter_shape={}_padding={}_strides={}_input_shape={}"
              .format(channels, filter_shape, padding, strides, input_shape),
          "channels": channels,
          "filter_shape": filter_shape,
          "padding": padding,
          "strides": strides,
          "input_shape": input_shape
      } for channels in [2, 3]
        for filter_shape in [(1, 1), (2, 3)]
        for padding in ["SAME", "VALID"]
        for strides in [None, (2, 1)]
        for input_shape in [(2, 10, 11, 1)]))
  def testConvShape(self, channels, filter_shape, padding, strides,
                    input_shape):
    init_fun, apply_fun = stax.Conv(channels, filter_shape, strides=strides,
                                    padding=padding)
    _CheckShapeAgreement(self, init_fun, apply_fun, input_shape)

  @parameterized.named_parameters(
      jtu.cases_from_list({
          "testcase_name":
              "_out_dim={}_input_shape={}".format(out_dim, input_shape),
          "out_dim": out_dim,
          "input_shape": input_shape
      } for out_dim in [3, 4] for input_shape in [(2, 3), (3, 4)]))
  def testDenseShape(self, out_dim, input_shape):
    init_fun, apply_fun = stax.Dense(out_dim)
    _CheckShapeAgreement(self, init_fun, apply_fun, input_shape)

  @parameterized.named_parameters(
      jtu.cases_from_list({
          "testcase_name": "_input_shape={}".format(input_shape),
          "input_shape": input_shape
      } for input_shape in [(2, 3), (2, 3, 4)]))
  def testReluShape(self, input_shape):
    init_fun, apply_fun = stax.Relu
    _CheckShapeAgreement(self, init_fun, apply_fun, input_shape)

  @parameterized.named_parameters(
      jtu.cases_from_list({
          "testcase_name":
              "_window_shape={}_padding={}_strides={}_input_shape={}"
              "_maxpool={}".format(window_shape, padding, strides,
                                   input_shape, max_pool),
          "window_shape": window_shape,
          "padding": padding,
          "strides": strides,
          "input_shape": input_shape,
          "max_pool": max_pool
      } for window_shape in [(1, 1), (2, 3)]
        for padding in ["VALID"]
        for strides in [None, (2, 1)]
        for input_shape in [(2, 5, 6, 1)]
        for max_pool in [False, True]))
  def testPoolingShape(self, window_shape, padding, strides, input_shape,
                       max_pool):
    layer = stax.MaxPool if max_pool else stax.AvgPool
    init_fun, apply_fun = layer(window_shape, padding=padding,
                                strides=strides)
    _CheckShapeAgreement(self, init_fun, apply_fun, input_shape)

  @parameterized.named_parameters(
      jtu.cases_from_list({
          "testcase_name": "_shape={}".format(input_shape),
          "input_shape": input_shape
      } for input_shape in [(2, 3), (2, 3, 4)]))
  def testFlattenShape(self, input_shape):
    init_fun, apply_fun = stax.Flatten
    _CheckShapeAgreement(self, init_fun, apply_fun, input_shape)

  @parameterized.named_parameters(
      jtu.cases_from_list({
          "testcase_name":
              "_input_shape={}_spec={}".format(input_shape, i),
          "input_shape": input_shape,
          "spec": spec
      } for input_shape in [(2, 5, 6, 1)]
        for i, spec in enumerate([
            [stax.Conv(3, (2, 2))],
            [stax.Conv(3, (2, 2)), stax.Flatten, stax.Dense(4)]
        ])))
  def testSerialComposeLayersShape(self, input_shape, spec):
    init_fun, apply_fun = stax.serial(*spec)
    _CheckShapeAgreement(self, init_fun, apply_fun, input_shape)

  @parameterized.named_parameters(
      jtu.cases_from_list({
          "testcase_name": "_input_shape={}".format(input_shape),
          "input_shape": input_shape
      } for input_shape in [(3, 4), (2, 5, 6, 1)]))
  def testDropoutShape(self, input_shape):
    init_fun, apply_fun = stax.Dropout(0.9)
    _CheckShapeAgreement(self, init_fun, apply_fun, input_shape)

  @parameterized.named_parameters(
      jtu.cases_from_list({
          "testcase_name": "_input_shape={}".format(input_shape),
          "input_shape": input_shape
      } for input_shape in [(3, 4), (2, 5, 6, 1)]))
  def testFanInSum(self, input_shape):
    init_fun, apply_fun = stax.FanInSum
    _CheckShapeAgreement(self, init_fun, apply_fun,
                         [input_shape, input_shape])

  @parameterized.named_parameters(
      jtu.cases_from_list({
          "testcase_name":
              "_inshapes={}_axis={}".format(input_shapes, axis),
          "input_shapes": input_shapes,
          "axis": axis
      } for input_shapes, axis in [
          ([(2, 3), (2, 1)], 1),
          ([(2, 3), (2, 1)], -1),
          ([(1, 2, 4), (1, 1, 4)], 1),
      ]))
  def testFanInConcat(self, input_shapes, axis):
    init_fun, apply_fun = stax.FanInConcat(axis)
    _CheckShapeAgreement(self, init_fun, apply_fun, input_shapes)

  def testIssue182(self):
    key = random.PRNGKey(0)
    init_fun, apply_fun = stax.Softmax
    input_shape = (10, 3)
    inputs = onp.arange(30.).astype("float32").reshape(input_shape)

    out_shape, params = init_fun(key, input_shape)
    out = apply_fun(params, inputs)

    assert out_shape == out.shape
    assert onp.allclose(onp.sum(onp.asarray(out), -1), 1.)

  def testBatchNormShapeNHWC(self):
    key = random.PRNGKey(0)
    init_fun, apply_fun = stax.BatchNorm(axis=(0, 1, 2))
    input_shape = (4, 5, 6, 7)
    inputs = random_inputs(onp.random.RandomState(0), input_shape)

    out_shape, params = init_fun(key, input_shape)
    out = apply_fun(params, inputs)

    self.assertEqual(out_shape, input_shape)
    beta, gamma = params
    self.assertEqual(beta.shape, (7,))
    self.assertEqual(gamma.shape, (7,))
    self.assertEqual(out_shape, out.shape)

  def testBatchNormShapeNCHW(self):
    key = random.PRNGKey(0)
    # Regression test for https://github.com/google/jax/issues/461
    init_fun, apply_fun = stax.BatchNorm(axis=(0, 2, 3))
    input_shape = (4, 5, 6, 7)
    inputs = random_inputs(onp.random.RandomState(0), input_shape)

    out_shape, params = init_fun(key, input_shape)
    out = apply_fun(params, inputs)

    self.assertEqual(out_shape, input_shape)
    beta, gamma = params
    self.assertEqual(beta.shape, (5,))
    self.assertEqual(gamma.shape, (5,))
    self.assertEqual(out_shape, out.shape)
def main(_):
  rng = random.PRNGKey(0)

  # Load MNIST dataset
  train_images, train_labels, test_images, test_labels = datasets.mnist()

  batch_size = 128
  batch_shape = (-1, 28, 28, 1)
  num_train = train_images.shape[0]
  num_complete_batches, leftover = divmod(num_train, batch_size)
  num_batches = num_complete_batches + bool(leftover)

  train_images = np.reshape(train_images, batch_shape)
  test_images = np.reshape(test_images, batch_shape)

  def data_stream():
    rng = npr.RandomState(0)
    while True:
      perm = rng.permutation(num_train)
      for i in range(num_batches):
        batch_idx = perm[i * batch_size:(i + 1) * batch_size]
        yield train_images[batch_idx], train_labels[batch_idx]

  batches = data_stream()

  # Model, loss, and accuracy functions
  init_random_params, predict = stax.serial(
      stax.Conv(32, (8, 8), strides=(2, 2), padding='SAME'),
      stax.Relu,
      stax.Conv(128, (6, 6), strides=(2, 2), padding='VALID'),
      stax.Relu,
      stax.Conv(128, (5, 5), strides=(1, 1), padding='VALID'),
      stax.Flatten,
      stax.Dense(128),
      stax.Relu,
      stax.Dense(10))

  def loss(params, batch):
    inputs, targets = batch
    preds = predict(params, inputs)
    return -np.mean(logsoftmax(preds) * targets)

  def accuracy(params, batch):
    inputs, targets = batch
    target_class = np.argmax(targets, axis=1)
    predicted_class = np.argmax(predict(params, inputs), axis=1)
    return np.mean(predicted_class == target_class)

  # Instantiate an optimizer
  opt_init, opt_update, get_params = optimizers.adam(0.001)

  @jit
  def update(i, opt_state, batch):
    params = get_params(opt_state)
    return opt_update(i, grad(loss)(params, batch), opt_state)

  # Initialize model
  _, init_params = init_random_params(rng, batch_shape)
  opt_state = opt_init(init_params)
  itercount = itertools.count()

  # Training loop
  print("\nStarting training...")
  for epoch in range(FLAGS.nb_epochs):
    start_time = time.time()
    for _ in range(num_batches):
      opt_state = update(next(itercount), opt_state, next(batches))
    epoch_time = time.time() - start_time

    # Evaluate model on clean data
    params = get_params(opt_state)
    train_acc = accuracy(params, (train_images, train_labels))
    test_acc = accuracy(params, (test_images, test_labels))

    # Evaluate model on adversarial data
    model_fn = lambda images: predict(params, images)
    test_images_fgm = fast_gradient_method(model_fn, test_images, FLAGS.eps,
                                           np.inf)
    test_images_pgd = projected_gradient_descent(model_fn, test_images,
                                                 FLAGS.eps, 0.01, 40, np.inf)
    test_acc_fgm = accuracy(params, (test_images_fgm, test_labels))
    test_acc_pgd = accuracy(params, (test_images_pgd, test_labels))

    print("Epoch {} in {:0.2f} sec".format(epoch, epoch_time))
    print("Training set accuracy: {}".format(train_acc))
    print("Test set accuracy on clean examples: {}".format(test_acc))
    print("Test set accuracy on FGM adversarial examples: {}".format(
        test_acc_fgm))
    print("Test set accuracy on PGD adversarial examples: {}".format(
        test_acc_pgd))
print("Epoch {} in {:0.2f} sec".format(epoch, epoch_time)) print("Training set accuracy {}".format(train_acc)) print("Test set accuracy {}".format(test_acc)) """# **Problem 2** Before we get started, we need to import two small libraries that contain boilerplate code for common neural network layer types and for optimizers like mini-batch SGD. """ from jax.experimental import optimizers from jax.experimental import stax """Here is a fully-connected neural network architecture, like the one of Problem 1, but this time defined with `stax`""" init_random_params, predict = stax.serial( stax.Conv(256, (5,5),strides = (2,2)), stax.Relu, stax.Conv(128, (3,3)), stax.Relu, stax.Conv(32, (3,3)), stax.Relu, stax.MaxPool((2,2)), stax.Flatten, stax.Dense(1024), stax.Relu, stax.Dense(128), stax.Relu, stax.Dense(10), ) """We redefine the cross-entropy loss for this model. As done in Problem 1, complete the return line below (it's identical)."""