def gen_cnn_conv4(output_units=10, W_initializers_str='glorot_normal()',
                  b_initializers_str='normal()'):
    # This is an up-scaled version of the CNN in the Keras tutorial:
    # https://keras.io/examples/cifar10_cnn/
    return stax.serial(
        stax.Conv(out_chan=64, filter_shape=(3, 3),
                  W_init=eval(W_initializers_str), b_init=eval(b_initializers_str)),
        stax.Relu,
        stax.Conv(out_chan=64, filter_shape=(3, 3),
                  W_init=eval(W_initializers_str), b_init=eval(b_initializers_str)),
        stax.Relu,
        stax.MaxPool((2, 2), strides=(2, 2)),
        stax.Conv(out_chan=128, filter_shape=(3, 3),
                  W_init=eval(W_initializers_str), b_init=eval(b_initializers_str)),
        stax.Relu,
        stax.Conv(out_chan=128, filter_shape=(3, 3),
                  W_init=eval(W_initializers_str), b_init=eval(b_initializers_str)),
        stax.Relu,
        stax.MaxPool((2, 2), strides=(2, 2)),
        stax.Flatten,
        stax.Dense(512, W_init=eval(W_initializers_str), b_init=eval(b_initializers_str)),
        stax.Relu,
        stax.Dense(output_units, W_init=eval(W_initializers_str), b_init=eval(b_initializers_str)))

def gen_cnn_lenet_caffe(output_units=10, W_initializers_str='glorot_normal()',
                        b_initializers_str='normal()'):
    return stax.serial(
        stax.Conv(out_chan=20, filter_shape=(5, 5),
                  W_init=eval(W_initializers_str), b_init=eval(b_initializers_str)),
        stax.Relu,
        stax.MaxPool((2, 2), strides=(2, 2)),
        stax.Conv(out_chan=50, filter_shape=(5, 5),
                  W_init=eval(W_initializers_str), b_init=eval(b_initializers_str)),
        stax.Relu,
        stax.MaxPool((2, 2), strides=(2, 2)),
        stax.Flatten,
        stax.Dense(500, W_init=eval(W_initializers_str), b_init=eval(b_initializers_str)),
        stax.Relu,
        stax.Dense(output_units, W_init=eval(W_initializers_str), b_init=eval(b_initializers_str)))

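# Usage sketch (an assumption, not part of the snippets above): both generators
# return the usual stax (init_fun, apply_fun) pair. The eval'd initializer
# defaults require `glorot_normal` and `normal` to be in scope in the defining
# module (e.g. imported from jax.nn.initializers).
import jax.numpy as jnp
from jax import random

init_fun, apply_fun = gen_cnn_lenet_caffe(output_units=10)
out_shape, params = init_fun(random.PRNGKey(0), (-1, 28, 28, 1))  # NHWC inputs
logits = apply_fun(params, jnp.zeros((2, 28, 28, 1)))             # shape (2, 10)
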
def cnn(num_classes=10):
    return stax.serial(
        stax.Conv(16, (8, 8), padding='SAME', strides=(2, 2)),
        stax.Tanh,
        stax.MaxPool((2, 2), (1, 1)),
        stax.Conv(32, (4, 4), padding='VALID', strides=(2, 2)),
        stax.Tanh,
        stax.MaxPool((2, 2), (1, 1)),
        stax.Flatten,  # (-1, 800)
        stax.Dense(64),
        stax.Tanh,  # embeddings
        stax.Dense(num_classes),  # logits
    )

def cnn():
    return stax.serial(
        stax.Conv(16, (8, 8), padding='SAME', strides=(2, 2)),
        stax.Tanh,
        stax.MaxPool((2, 2), (1, 1)),
        stax.Conv(32, (4, 4), padding='VALID', strides=(2, 2)),
        stax.Tanh,
        stax.MaxPool((2, 2), (1, 1)),
        stax.Flatten,
        stax.Dense(32),
        stax.Tanh,  # embeddings
        stax.Dense(10),  # logits
    )

def Resnet50(hidden_size=64, num_output_classes=1001):
    """ResNet.

    Args:
      hidden_size: the size of the first hidden layer (multiplied later).
      num_output_classes: how many classes to distinguish.

    Returns:
      The ResNet model with the given layer and output sizes.
    """
    return stax.serial(
        stax.Conv(hidden_size, (7, 7), (2, 2), 'SAME'),
        stax.BatchNorm(), stax.Relu,
        stax.MaxPool((3, 3), strides=(2, 2)),
        ConvBlock(3, [hidden_size, hidden_size, 4 * hidden_size], (1, 1)),
        IdentityBlock(3, [hidden_size, hidden_size]),
        IdentityBlock(3, [hidden_size, hidden_size]),
        ConvBlock(3, [2 * hidden_size, 2 * hidden_size, 8 * hidden_size], (2, 2)),
        IdentityBlock(3, [2 * hidden_size, 2 * hidden_size]),
        IdentityBlock(3, [2 * hidden_size, 2 * hidden_size]),
        IdentityBlock(3, [2 * hidden_size, 2 * hidden_size]),
        ConvBlock(3, [4 * hidden_size, 4 * hidden_size, 16 * hidden_size], (2, 2)),
        IdentityBlock(3, [4 * hidden_size, 4 * hidden_size]),
        IdentityBlock(3, [4 * hidden_size, 4 * hidden_size]),
        IdentityBlock(3, [4 * hidden_size, 4 * hidden_size]),
        IdentityBlock(3, [4 * hidden_size, 4 * hidden_size]),
        IdentityBlock(3, [4 * hidden_size, 4 * hidden_size]),
        ConvBlock(3, [8 * hidden_size, 8 * hidden_size, 32 * hidden_size], (2, 2)),
        IdentityBlock(3, [8 * hidden_size, 8 * hidden_size]),
        IdentityBlock(3, [8 * hidden_size, 8 * hidden_size]),
        stax.AvgPool((7, 7)),
        stax.Flatten,
        stax.Dense(num_output_classes),
        stax.LogSoftmax)

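# ConvBlock and IdentityBlock are assumed to be defined elsewhere in this file.
# As a rough sketch of the residual IdentityBlock, following the pattern of the
# resnet50 example shipped with JAX (not necessarily the exact definition used
# here), it can be built from stax combinators:
def IdentityBlock(kernel_size, filters):
    ks = kernel_size
    filters1, filters2 = filters

    def make_main(input_shape):
        # the final 1x1 conv restores the number of input channels so the
        # residual sum is well defined
        return stax.serial(
            stax.Conv(filters1, (1, 1)), stax.BatchNorm(), stax.Relu,
            stax.Conv(filters2, (ks, ks), padding='SAME'), stax.BatchNorm(), stax.Relu,
            stax.Conv(input_shape[3], (1, 1)), stax.BatchNorm())

    Main = stax.shape_dependent(make_main)
    return stax.serial(stax.FanOut(2),
                       stax.parallel(Main, stax.Identity),
                       stax.FanInSum, stax.Relu)
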
def __init__(self, kw, kh, name=None):
    super(MaxPool2d, self).__init__(name)
    _, self.maxpool = jexp.MaxPool((kw, kh))
    if name is None:
        self.name = f'MaxPool2d+{rand_string()}'

def __init__(self, num_classes=100, encoding=True):
    blocks = [
        stax.GeneralConv(('HWCN', 'OIHW', 'NHWC'), 64, (7, 7), (2, 2), 'SAME'),
        stax.BatchNorm(), stax.Relu,
        stax.MaxPool((3, 3), strides=(2, 2)),
        self.ConvBlock(3, [64, 64, 256], strides=(1, 1)),
        self.IdentityBlock(3, [64, 64]),
        self.IdentityBlock(3, [64, 64]),
        self.ConvBlock(3, [128, 128, 512]),
        self.IdentityBlock(3, [128, 128]),
        self.IdentityBlock(3, [128, 128]),
        self.IdentityBlock(3, [128, 128]),
        self.ConvBlock(3, [256, 256, 1024]),
        self.IdentityBlock(3, [256, 256]),
        self.IdentityBlock(3, [256, 256]),
        self.IdentityBlock(3, [256, 256]),
        self.IdentityBlock(3, [256, 256]),
        self.IdentityBlock(3, [256, 256]),
        self.ConvBlock(3, [512, 512, 2048]),
        self.IdentityBlock(3, [512, 512]),
        self.IdentityBlock(3, [512, 512]),
        stax.AvgPool((7, 7))
    ]
    if not encoding:
        blocks.append(stax.Flatten)
        blocks.append(stax.Dense(num_classes))
    self.model = stax.serial(*blocks)

def conv():
    init_fun, predict = stax.serial(
        stax.Conv(16, (8, 8), padding='SAME', strides=(2, 2)),
        stax.Relu,
        stax.MaxPool((2, 2), (1, 1)),
        stax.Conv(32, (4, 4), padding='VALID', strides=(2, 2)),
        stax.Relu,
        stax.MaxPool((2, 2), (1, 1)),
        stax.Flatten,
        stax.Dense(32),
        stax.Relu,
        stax.Dense(10),
    )

    def init_params(rng):
        return init_fun(rng, (-1, 28, 28, 1))[1]

    return init_params, predict

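# Usage sketch (an assumption, not from the snippet): the closure fixes the
# MNIST-style NHWC input shape, so init_params only needs a PRNG key.
import jax.numpy as jnp
from jax import random

init_params, predict = conv()
params = init_params(random.PRNGKey(0))
logits = predict(params, jnp.ones((8, 28, 28, 1)))  # -> shape (8, 10)
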
# %%
def data_stream():
    rng = npr.RandomState(0)
    while True:
        perm = rng.permutation(num_train)
        for i in range(num_batches):
            batch_idx = perm[i * batch_size:(i + 1) * batch_size]
            yield train_images[batch_idx], train_labels[batch_idx]

batches = data_stream()

# stax.Conv defaults to NHWC dimension ordering, so a (64, 1, 28, 28) input is
# interpreted as 64 images of height 1, width 28 and 28 channels.
init_fun, net = stax.serial(stax.Conv(16, (3, 3), (1, 1), padding="SAME"),
                            stax.Relu,
                            stax.MaxPool((2, 2), (2, 2), padding="SAME"),
                            stax.Conv(32, (3, 3), (1, 1), padding="SAME"),
                            stax.Relu,
                            stax.MaxPool((2, 2), (2, 2), padding="SAME"),
                            stax.Flatten,
                            stax.Dense(10),
                            stax.LogSoftmax)
_, params = init_fun(key, (64, 1, 28, 28))

def loss(params, batch):
    inputs, targets = batch
    preds = net(params, inputs)
    return -jnp.mean(jnp.sum(preds * targets, axis=1))

def accuracy(params, batch):

# test_newx = computation(params, (train_images, train_labels))
# print(test_newx)

"""# **Problem 2**

Before we get started, we need to import two small libraries that contain
boilerplate code for common neural network layer types and for optimizers like
mini-batch SGD.
"""

from jax.experimental import optimizers
from jax.experimental import stax
import matplotlib.pyplot as plt

"""Here is a convolutional neural network architecture, analogous to the one of
Problem 1, but this time defined with `stax`."""

init_random_params, predict = stax.serial(
    stax.Conv(64, (8, 8), padding='SAME', strides=(2, 2)),
    stax.Relu,
    stax.MaxPool((2, 2), (1, 1)),
    stax.Conv(128, (4, 4), padding='VALID', strides=(2, 2)),
    stax.Relu,
    stax.MaxPool((2, 2), (1, 1)),
    stax.Flatten,
    stax.Dense(128),
    stax.Relu,
    stax.Dense(10),
    stax.LogSoftmax)

"""We redefine the cross-entropy loss for this model. As done in Problem 1,
complete the return line below (it's identical)."""

def loss(params, batch):
    inputs, targets = batch
    logits = predict(params, inputs)
    preds = stax.logsoftmax(logits)
    return -np.mean(preds * targets)

"""Next, we define the mini-batch SGD optimizer, this time with the optimizers
library in JAX."""

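# The optimizer cell itself is cut off above; a rough sketch of the
# jax.experimental.optimizers API it announces (the step size 0.01 and the
# input shape are assumptions, not values from the notebook):
from jax import grad, jit, random

opt_init, opt_update, get_params = optimizers.sgd(step_size=0.01)

@jit
def update(i, opt_state, batch):
    # one mini-batch SGD step: evaluate grads at the current params, then apply
    params = get_params(opt_state)
    return opt_update(i, grad(loss)(params, batch), opt_state)

_, init_params = init_random_params(random.PRNGKey(0), (-1, 28, 28, 1))
opt_state = opt_init(init_params)
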
flags.DEFINE_integer('epochs', 100, 'Number of finetuning epochs')
flags.DEFINE_integer('seed', 0, 'Seed for jax PRNG')
flags.DEFINE_integer(
    'uncertain', 0,
    '0: entropy'
    '1: difference between 1st_prob and 2nd_prob'
    '2: random')
flags.DEFINE_integer('n_extra', 3000, 'number of extra points')
flags.DEFINE_bool(
    'show_label', True,
    'visualize predicted label at top/left, true at bottom/right')

# BEGIN: define the classifier model
init_fn_0, apply_fn_0 = stax.serial(
    stax.Conv(16, (8, 8), padding='SAME', strides=(2, 2)),
    stax.Tanh,
    stax.MaxPool((2, 2), (1, 1)),
    stax.Conv(32, (4, 4), padding='VALID', strides=(2, 2)),
    stax.Tanh,
    stax.MaxPool((2, 2), (1, 1)),
    stax.Flatten,  # (-1, 800)
    stax.Dense(32),
    stax.Tanh,  # embeddings
)
init_fn_1, apply_fn_1 = stax.serial(
    stax.Dense(10),  # logits
)

def predict(params, inputs):
    params_0 = params[:-1]

Before we get started, we need to import two small libraries that contain
boilerplate code for common neural network layer types and for optimizers like
mini-batch SGD.
"""

from jax.experimental import optimizers
from jax.experimental import stax

"""Here is a convolutional neural network architecture, analogous to the one of
Problem 1, but this time defined with `stax`."""

init_random_params, predict = stax.serial(
    stax.Conv(256, (5, 5), strides=(2, 2)),
    stax.Relu,
    stax.Conv(128, (3, 3)),
    stax.Relu,
    stax.Conv(32, (3, 3)),
    stax.Relu,
    stax.MaxPool((2, 2)),
    stax.Flatten,
    stax.Dense(1024),
    stax.Relu,
    stax.Dense(128),
    stax.Relu,
    stax.Dense(10),
)

"""We redefine the cross-entropy loss for this model. As done in Problem 1,
complete the return line below (it's identical)."""

def loss(params, batch):
    inputs, targets = batch
    logits = predict(params, inputs)
    preds = stax.logsoftmax(logits)
    return -np.sum(targets * preds) / len(targets)

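# Sanity-check sketch (shapes and values are assumptions, not from the
# notebook): the loss above expects one-hot targets and normalizes by the
# batch size via len(targets).
import jax.numpy as jnp
from jax import random

_, params = init_random_params(random.PRNGKey(0), (-1, 28, 28, 1))
x = jnp.zeros((4, 28, 28, 1))
y = jnp.eye(10)[jnp.array([3, 1, 4, 1])]   # one-hot labels for 4 examples
print(loss(params, (x, y)))                # scalar, roughly log(10) at init
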
def testPoolingShape(self, window_shape, padding, strides, input_shape):
    init_fun, apply_fun = stax.MaxPool(window_shape, padding=padding,
                                       strides=strides)
    _CheckShapeAgreement(self, init_fun, apply_fun, input_shape)

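# _CheckShapeAgreement is assumed to be a helper defined elsewhere in the test
# suite; a minimal sketch of the property it checks (init_fun's reported output
# shape matches what apply_fun actually produces, assuming a concrete
# input_shape with no -1 entries) could look like:
import numpy as onp
from jax import random

def _CheckShapeAgreement(test_case, init_fun, apply_fun, input_shape):
    rng = random.PRNGKey(0)
    result_shape, params = init_fun(rng, input_shape)
    inputs = onp.random.RandomState(0).normal(size=input_shape).astype(onp.float32)
    result = apply_fun(params, inputs)
    test_case.assertEqual(result.shape, result_shape)
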
def main(_):
    logging.info('Starting experiment.')
    configs = FLAGS.config

    # Create model folder for outputs
    try:
        gfile.MakeDirs(FLAGS.exp_dir)
    except gfile.GOSError:
        pass
    stdout_log = gfile.Open('{}/stdout.log'.format(FLAGS.exp_dir), 'w+')

    if configs.optimization == 'sgd':
        lr_schedule = optimizers.make_schedule(configs.learning_rate)
        opt_init, opt_update, get_params = optimizers.sgd(lr_schedule)
    elif configs.optimization == 'momentum':
        lr_schedule = cosine(configs.learning_rate, configs.train_steps)
        opt_init, opt_update, get_params = optimizers.momentum(lr_schedule, 0.9)
    else:
        raise ValueError('Optimizer not implemented.')

    with gfile.Open(FLAGS.pretrained_dir, 'rb') as fpre:
        pretrained_opt_state = optimizers.pack_optimizer_state(pickle.load(fpre))
    fixed_params = get_params(pretrained_opt_state)[:7]

    # BEGIN: define the classifier model
    init_fn_0, apply_fn_0 = stax.serial(
        stax.Conv(16, (8, 8), padding='SAME', strides=(2, 2)),
        stax.Tanh,
        stax.MaxPool((2, 2), (1, 1)),
        stax.Conv(32, (4, 4), padding='VALID', strides=(2, 2)),
        stax.Tanh,
        stax.MaxPool((2, 2), (1, 1)),
        stax.Flatten,  # representations
    )
    init_fn_1, apply_fn_1 = stax.serial(
        stax.Dense(64),
        stax.Tanh,  # embeddings
        stax.Dense(10),  # logits
    )

    def predict(params, inputs):
        representations = apply_fn_0(fixed_params, inputs)  # use pretrained params
        logits = apply_fn_1(params, representations)
        return logits
    # END: define the classifier model

    if configs.seed is not None:
        key = random.PRNGKey(configs.seed)
    else:
        key = random.PRNGKey(int(time.time()))

    _, _ = init_fn_0(key, (-1, 32, 32, 3))
    _, params = init_fn_1(key, (-1, 800))
    opt_state = opt_init(params)

    logging.info('Loading data.')
    tic = time.time()

    train_images, train_labels, _ = datasets.get_dataset_split(
        FLAGS.dataset, 'train')
    train_mu, train_std = onp.mean(train_images), onp.std(train_images)
    n_train = len(train_images)
    train = data.DataChunk(
        X=(train_images - train_mu) / train_std,
        Y=train_labels,
        image_size=32,
        image_channels=3,
        label_dim=1,
        label_format='numeric')

    test_images, test_labels, _ = datasets.get_dataset_split(
        FLAGS.dataset, 'test')
    test = data.DataChunk(
        X=(test_images - train_mu) / train_std,  # normalize w train mean/std
        Y=test_labels,
        image_size=32,
        image_channels=3,
        label_dim=1,
        label_format='numeric')

    # Data augmentation
    if configs.augment_data:
        augmentation = data.chain_transforms(
            data.RandomHorizontalFlip(0.5), data.RandomCrop(4), data.ToDevice)
    else:
        augmentation = None
    batch = data.minibatcher(train, configs.batch_size, transform=augmentation)

    # count params of JAX model
    def count_parameters(params):
        return tree_util.tree_reduce(
            operator.add, tree_util.tree_map(lambda x: np.prod(x.shape), params))

    logging.info('Number of parameters: %d', count_parameters(params))
    stdout_log.write('Number of params: {}\n'.format(count_parameters(params)))

    # loss functions
    def cross_entropy_loss(params, x_img, y_lbl):
        return -np.mean(stax.logsoftmax(predict(params, x_img)) * y_lbl)

    def mse_loss(params, x_img, y_lbl):
        return 0.5 * np.mean((y_lbl - predict(params, x_img))**2)

    def accuracy(y_lbl_hat, y_lbl):
        target_class = np.argmax(y_lbl, axis=1)
        predicted_class = np.argmax(y_lbl_hat, axis=1)
        return np.mean(predicted_class == target_class)

    # Loss and gradient
    if configs.loss == 'xent':
        loss = cross_entropy_loss
    elif configs.loss == 'mse':
        loss = mse_loss
    else:
        raise ValueError('Loss function not implemented.')
    grad_loss = jit(grad(loss))

    # learning rate schedule and optimizer
    def cosine(initial_step_size, train_steps):
        k = np.pi / (2.0 * train_steps)

        def schedule(i):
            return initial_step_size * np.cos(k * i)

        return schedule

    def private_grad(params, batch, rng, l2_norm_clip, noise_multiplier,
                     batch_size):
        """Return differentially private gradients of params, evaluated on batch."""

        def _clipped_grad(params, single_example_batch):
            """Evaluate gradient for a single-example batch and clip its grad norm."""
            grads = grad_loss(params,
                              single_example_batch[0].reshape((-1, 32, 32, 3)),
                              single_example_batch[1])
            nonempty_grads, tree_def = tree_util.tree_flatten(grads)
            total_grad_norm = np.linalg.norm(
                [np.linalg.norm(neg.ravel()) for neg in nonempty_grads])
            divisor = stop_gradient(np.amax((total_grad_norm / l2_norm_clip, 1.)))
            normalized_nonempty_grads = [neg / divisor for neg in nonempty_grads]
            return tree_util.tree_unflatten(tree_def, normalized_nonempty_grads)

        px_clipped_grad_fn = vmap(partial(_clipped_grad, params))
        std_dev = l2_norm_clip * noise_multiplier
        noise_ = lambda n: n + std_dev * random.normal(rng, n.shape)
        normalize_ = lambda n: n / float(batch_size)
        sum_ = lambda n: np.sum(n, 0)  # aggregate
        aggregated_clipped_grads = tree_util.tree_map(sum_, px_clipped_grad_fn(batch))
        noised_aggregated_clipped_grads = tree_util.tree_map(
            noise_, aggregated_clipped_grads)
        normalized_noised_aggregated_clipped_grads = (
            tree_util.tree_map(normalize_, noised_aggregated_clipped_grads))
        return normalized_noised_aggregated_clipped_grads

    # summarize measurements
    steps_per_epoch = n_train // configs.batch_size

    def summarize(step, params):
        """Compute measurements in a zipped way."""
        set_entries = [train, test]
        set_bsizes = [configs.train_eval_bsize, configs.test_eval_bsize]
        set_names, loss_dict, acc_dict = ['train', 'test'], {}, {}

        for set_entry, set_bsize, set_name in zip(set_entries, set_bsizes,
                                                  set_names):
            temp_loss, temp_acc, points = 0.0, 0.0, 0
            for b in data.batch(set_entry, set_bsize):
                temp_loss += loss(params, b.X, b.Y) * b.X.shape[0]
                temp_acc += accuracy(predict(params, b.X), b.Y) * b.X.shape[0]
                points += b.X.shape[0]
            loss_dict[set_name] = temp_loss / float(points)
            acc_dict[set_name] = temp_acc / float(points)

        logging.info('Step: %s', str(step))
        logging.info('Train acc : %.4f', acc_dict['train'])
        logging.info('Train loss: %.4f', loss_dict['train'])
        logging.info('Test acc : %.4f', acc_dict['test'])
        logging.info('Test loss : %.4f', loss_dict['test'])

        stdout_log.write('Step: {}\n'.format(step))
        stdout_log.write('Train acc : {}\n'.format(acc_dict['train']))
        stdout_log.write('Train loss: {}\n'.format(loss_dict['train']))
        stdout_log.write('Test acc : {}\n'.format(acc_dict['test']))
        stdout_log.write('Test loss : {}\n'.format(loss_dict['test']))
        stdout_log.flush()

        return acc_dict['test']

    toc = time.time()
    logging.info('Elapsed SETUP time: %s', str(toc - tic))
    stdout_log.write('Elapsed SETUP time: {}\n'.format(toc - tic))

    # BEGIN: training steps
    logging.info('Training network.')
    tic = time.time()
    t = time.time()

    for s in range(configs.train_steps):
        b = next(batch)
        params = get_params(opt_state)

        # t0 = time.time()
        if FLAGS.dpsgd:
            key = random.fold_in(key, s)  # get new key for new random numbers
            opt_state = opt_update(
                s,
                private_grad(params, (b.X.reshape((-1, 1, 32, 32, 3)), b.Y), key,
                             configs.l2_norm_clip, configs.noise_multiplier,
                             configs.batch_size), opt_state)
        else:
            opt_state = opt_update(s, grad_loss(params, b.X, b.Y), opt_state)
        # t1 = time.time()
        # logging.info('batch update time: %s', str(t1 - t0))

        if s % steps_per_epoch == 0:
            with gfile.Open(
                    '{}/ckpt_{}'.format(FLAGS.exp_dir, int(s / steps_per_epoch)),
                    'wb') as fckpt:
                pickle.dump(optimizers.unpack_optimizer_state(opt_state), fckpt)

            if FLAGS.dpsgd:
                eps = compute_epsilon(s, configs.batch_size, n_train,
                                      configs.target_delta,
                                      configs.noise_multiplier)
                stdout_log.write(
                    'For delta={:.0e}, current epsilon is: {:.2f}\n'.format(
                        configs.target_delta, eps))

            logging.info('Elapsed EPOCH time: %s', str(time.time() - t))
            stdout_log.write('Elapsed EPOCH time: {}'.format(time.time() - t))
            stdout_log.flush()
            t = time.time()

    toc = time.time()
    summarize(configs.train_steps, params)
    logging.info('Elapsed TRAIN time: %s', str(toc - tic))
    stdout_log.write('Elapsed TRAIN time: {}'.format(toc - tic))
    stdout_log.close()

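# compute_epsilon is assumed to be defined elsewhere in this file; a sketch in
# the style of the DP-SGD examples, using the tensorflow_privacy RDP accountant
# (the grid of Renyi orders below is an assumption):
import numpy as onp
from tensorflow_privacy.privacy.analysis.rdp_accountant import (
    compute_rdp, get_privacy_spent)

def compute_epsilon(steps, batch_size, num_examples, target_delta,
                    noise_multiplier):
    q = batch_size / float(num_examples)  # Poisson subsampling ratio
    orders = list(onp.linspace(1.1, 10.9, 99)) + list(range(11, 64))
    rdp = compute_rdp(q, noise_multiplier, steps, orders)
    eps, _, _ = get_privacy_spent(orders, rdp, target_delta=target_delta)
    return eps
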
def u_net(in_channels: int,
          out_channels: int,
          levels: int = 4,
          filters: int or tuple or list = 16,
          batch_norm: bool = True,
          activation='ReLU',
          in_spatial: tuple or int = 2) -> StaxNet:
    if isinstance(filters, (tuple, list)):
        assert len(filters) == levels, \
            f"List of filters has length {len(filters)} but u-net has {levels} levels."
    else:
        filters = (filters,) * levels
    activation = ACTIVATIONS[activation]
    if isinstance(in_spatial, int):
        d = in_spatial
        in_spatial = (-1,) * d
    else:
        assert isinstance(in_spatial, tuple)
        d = len(in_spatial)
    # Create layers
    inc_init, inc_apply = create_double_conv(d, filters[0], filters[0],
                                             batch_norm, activation)
    init_functions, apply_functions = {}, {}
    for i in range(1, levels):
        init_functions[f'down{i}'], apply_functions[f'down{i}'] = create_double_conv(
            d, filters[i], filters[i], batch_norm, activation)
        init_functions[f'up{i}'], apply_functions[f'up{i}'] = create_double_conv(
            d, filters[i - 1], filters[i - 1], batch_norm, activation)
    outc_init, outc_apply = CONV[d](out_channels, (1,) * d, padding='same')
    max_pool_init, max_pool_apply = stax.MaxPool((2,) * d, padding='same',
                                                 strides=(2,) * d)
    _, up_apply = create_upsample()

    def net_init(rng, input_shape):
        params = {}
        # one key per down/up block, plus the input and output convolutions
        rngs = random.split(rng, 2 * levels)
        shape = input_shape
        # Layers
        shape, params['inc'] = inc_init(rngs[0], shape)
        shapes = [shape]
        for i in range(1, levels):
            shape, _ = max_pool_init(None, shape)
            shape, params[f'down{i}'] = init_functions[f'down{i}'](rngs[i], shape)
            shapes.insert(0, shape)
        for i in range(1, levels):
            shape = shapes[i][:-1] + (shapes[i][-1] + shape[-1],)
            shape, params[f'up{i}'] = init_functions[f'up{i}'](rngs[levels + i], shape)
        shape, params['outc'] = outc_init(rngs[-1], shape)
        return shape, params

    # no @jax.jit needed here since the user can jit this in the loss_function
    def net_apply(params, inputs, **kwargs):
        x = inputs
        x = inc_apply(params['inc'], x, **kwargs)
        xs = [x]
        for i in range(1, levels):
            x = max_pool_apply(None, x, **kwargs)
            x = apply_functions[f'down{i}'](params[f'down{i}'], x, **kwargs)
            xs.insert(0, x)
        for i in range(1, levels):
            x = up_apply(None, x, **kwargs)
            x = jnp.concatenate([x, xs[i]], axis=-1)
            x = apply_functions[f'up{i}'](params[f'up{i}'], x, **kwargs)
        x = outc_apply(params['outc'], x, **kwargs)
        return x

    net = StaxNet(net_init, net_apply, (-1,) + in_spatial + (in_channels,))
    net.initialize()
    return net