Code example #1
def gen_cnn_conv4(output_units=10,
                  W_initializers_str='glorot_normal()',
                  b_initializers_str='normal()'):
    # This is an up-scaled version of the CNN in keras tutorial: https://keras.io/examples/cifar10_cnn/
    return stax.serial(
        stax.Conv(out_chan=64,
                  filter_shape=(3, 3),
                  W_init=eval(W_initializers_str),
                  b_init=eval(b_initializers_str)), stax.Relu,
        stax.Conv(out_chan=64,
                  filter_shape=(3, 3),
                  W_init=eval(W_initializers_str),
                  b_init=eval(b_initializers_str)), stax.Relu,
        stax.MaxPool((2, 2), strides=(2, 2)),
        stax.Conv(out_chan=128,
                  filter_shape=(3, 3),
                  W_init=eval(W_initializers_str),
                  b_init=eval(b_initializers_str)), stax.Relu,
        stax.Conv(out_chan=128,
                  filter_shape=(3, 3),
                  W_init=eval(W_initializers_str),
                  b_init=eval(b_initializers_str)), stax.Relu,
        stax.MaxPool((2, 2), strides=(2, 2)), stax.Flatten,
        stax.Dense(512,
                   W_init=eval(W_initializers_str),
                   b_init=eval(b_initializers_str)), stax.Relu,
        stax.Dense(output_units,
                   W_init=eval(W_initializers_str),
                   b_init=eval(b_initializers_str)))
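Since gen_cnn_conv4 returns a stax (init_fun, apply_fun) pair, a minimal usage sketch follows. This is an illustration, not part of the original snippet: it assumes the function is defined in a module where glorot_normal and normal are in scope (so the eval'd initializer strings resolve) and a CIFAR-10-sized NHWC input.

from jax import numpy as jnp, random
from jax.experimental import stax
from jax.nn.initializers import glorot_normal, normal  # needed by the eval'd initializer strings

init_fun, apply_fun = gen_cnn_conv4(output_units=10)
out_shape, params = init_fun(random.PRNGKey(0), (-1, 32, 32, 3))  # NHWC, CIFAR-10 sized (assumption)
logits = apply_fun(params, jnp.ones((2, 32, 32, 3)))
print(out_shape, logits.shape)  # (-1, 10) (2, 10)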
Code example #2
def gen_cnn_lenet_caffe(output_units=10,
                        W_initializers_str='glorot_normal()',
                        b_initializers_str='normal()'):
    return stax.serial(
        stax.Conv(out_chan=20,
                  filter_shape=(5, 5),
                  W_init=eval(W_initializers_str),
                  b_init=eval(b_initializers_str)), stax.Relu,
        stax.MaxPool((2, 2), strides=(2, 2)),
        stax.Conv(out_chan=50,
                  filter_shape=(5, 5),
                  W_init=eval(W_initializers_str),
                  b_init=eval(b_initializers_str)), stax.Relu,
        stax.MaxPool((2, 2), strides=(2, 2)), stax.Flatten,
        stax.Dense(500,
                   W_init=eval(W_initializers_str),
                   b_init=eval(b_initializers_str)), stax.Relu,
        stax.Dense(output_units,
                   W_init=eval(W_initializers_str),
                   b_init=eval(b_initializers_str)))
Code example #3
File: train_dpsgd_svhn.py Project: adp-anonymous/adp
def cnn(num_classes=10):
    return stax.serial(
        stax.Conv(16, (8, 8), padding='SAME', strides=(2, 2)),
        stax.Tanh,
        stax.MaxPool((2, 2), (1, 1)),
        stax.Conv(32, (4, 4), padding='VALID', strides=(2, 2)),
        stax.Tanh,
        stax.MaxPool((2, 2), (1, 1)),
        stax.Flatten,  # (-1, 800)
        stax.Dense(64),
        stax.Tanh,  # embeddings
        stax.Dense(num_classes),  # logits
    )
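The "# (-1, 800)" comment can be checked from the init function. A minimal sketch follows; the 32x32x3 NHWC input shape is an assumption (SVHN images), matching the (-1, 32, 32, 3) init call in code example #14.

from jax import random
from jax.experimental import stax

# Front of the network up to Flatten, to confirm the "# (-1, 800)" comment.
init_fun, _ = stax.serial(
    stax.Conv(16, (8, 8), padding='SAME', strides=(2, 2)), stax.Tanh,
    stax.MaxPool((2, 2), (1, 1)),
    stax.Conv(32, (4, 4), padding='VALID', strides=(2, 2)), stax.Tanh,
    stax.MaxPool((2, 2), (1, 1)),
    stax.Flatten)
out_shape, _ = init_fun(random.PRNGKey(0), (-1, 32, 32, 3))
print(out_shape)  # (-1, 800)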
Code example #4
File: train_svhn.py Project: adp-anonymous/adp
def cnn():
    return stax.serial(
        stax.Conv(16, (8, 8), padding='SAME', strides=(2, 2)),
        stax.Tanh,
        stax.MaxPool((2, 2), (1, 1)),
        stax.Conv(32, (4, 4), padding='VALID', strides=(2, 2)),
        stax.Tanh,
        stax.MaxPool((2, 2), (1, 1)),
        stax.Flatten,
        stax.Dense(32),
        stax.Tanh,  # embeddings
        stax.Dense(10),  # logits
    )
Code example #5
File: resnet.py Project: skhong0831/tensor2tensor
def Resnet50(hidden_size=64, num_output_classes=1001):
    """ResNet.

  Args:
    hidden_size: the size of the first hidden layer (multiplied later).
    num_output_classes: how many classes to distinguish.

  Returns:
    The ResNet model with the given layer and output sizes.
  """
    return stax.serial(
        stax.Conv(hidden_size, (7, 7), (2, 2),
                  'SAME'), stax.BatchNorm(), stax.Relu,
        stax.MaxPool((3, 3), strides=(2, 2)),
        ConvBlock(3, [hidden_size, hidden_size, 4 * hidden_size], (1, 1)),
        IdentityBlock(3, [hidden_size, hidden_size]),
        IdentityBlock(3, [hidden_size, hidden_size]),
        ConvBlock(3,
                  [2 * hidden_size, 2 * hidden_size, 8 * hidden_size], (2, 2)),
        IdentityBlock(3, [2 * hidden_size, 2 * hidden_size]),
        IdentityBlock(3, [2 * hidden_size, 2 * hidden_size]),
        IdentityBlock(3, [2 * hidden_size, 2 * hidden_size]),
        ConvBlock(3, [4 * hidden_size, 4 * hidden_size, 16 * hidden_size],
                  (2, 2)), IdentityBlock(3,
                                         [4 * hidden_size, 4 * hidden_size]),
        IdentityBlock(3, [4 * hidden_size, 4 * hidden_size]),
        IdentityBlock(3, [4 * hidden_size, 4 * hidden_size]),
        IdentityBlock(3, [4 * hidden_size, 4 * hidden_size]),
        IdentityBlock(3, [4 * hidden_size, 4 * hidden_size]),
        ConvBlock(3, [8 * hidden_size, 8 * hidden_size, 32 * hidden_size],
                  (2, 2)), IdentityBlock(3,
                                         [8 * hidden_size, 8 * hidden_size]),
        IdentityBlock(3, [8 * hidden_size, 8 * hidden_size]),
        stax.AvgPool((7, 7)), stax.Flatten, stax.Dense(num_output_classes),
        stax.LogSoftmax)
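ConvBlock and IdentityBlock are helpers defined elsewhere in resnet.py and are not shown here. As a rough guide only, a residual identity block in the style of the stax ResNet-50 example could look like the sketch below; this is an assumption, not the project's own definition.

from jax.experimental import stax

def IdentityBlock(kernel_size, filters):
    # Residual block with an identity shortcut: the main branch ends in a 1x1
    # convolution back to the input channel count so the two paths can be summed.
    ks = kernel_size
    filters1, filters2 = filters

    def make_main(input_shape):
        return stax.serial(
            stax.Conv(filters1, (1, 1)), stax.BatchNorm(), stax.Relu,
            stax.Conv(filters2, (ks, ks), padding='SAME'), stax.BatchNorm(), stax.Relu,
            stax.Conv(input_shape[3], (1, 1)), stax.BatchNorm())

    Main = stax.shape_dependent(make_main)
    return stax.serial(stax.FanOut(2), stax.parallel(Main, stax.Identity),
                       stax.FanInSum, stax.Relu)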
Code example #6
    def __init__(self, kw, kh, name=None):
        super(MaxPool2d, self).__init__(name)

        # MaxPool returns an (init_fun, apply_fun) pair; pooling has no
        # parameters, so only the apply function is kept.
        _, self.maxpool = jexp.MaxPool((kw, kh))

        if name is None:
            self.name = F'MaxPool2d+{rand_string()}'
Code example #7
    def __init__(self, num_classes=100, encoding=True):

        blocks = [
            stax.GeneralConv(('HWCN', 'OIHW', 'NHWC'), 64, (7, 7), (2, 2),
                             'SAME'),
            stax.BatchNorm(), stax.Relu,
            stax.MaxPool((3, 3), strides=(2, 2)),
            self.ConvBlock(3, [64, 64, 256], strides=(1, 1)),
            self.IdentityBlock(3, [64, 64]),
            self.IdentityBlock(3, [64, 64]),
            self.ConvBlock(3, [128, 128, 512]),
            self.IdentityBlock(3, [128, 128]),
            self.IdentityBlock(3, [128, 128]),
            self.IdentityBlock(3, [128, 128]),
            self.ConvBlock(3, [256, 256, 1024]),
            self.IdentityBlock(3, [256, 256]),
            self.IdentityBlock(3, [256, 256]),
            self.IdentityBlock(3, [256, 256]),
            self.IdentityBlock(3, [256, 256]),
            self.IdentityBlock(3, [256, 256]),
            self.ConvBlock(3, [512, 512, 2048]),
            self.IdentityBlock(3, [512, 512]),
            self.IdentityBlock(3, [512, 512]),
            stax.AvgPool((7, 7))
        ]

        if not encoding:
            blocks.append(stax.Flatten)
            blocks.append(stax.Dense(num_classes))

        self.model = stax.serial(*blocks)
Code example #8
def conv():
    init_fun, predict = stax.serial(
        stax.Conv(16, (8, 8), padding='SAME', strides=(2, 2)),
        stax.Relu,
        stax.MaxPool((2, 2), (1, 1)),
        stax.Conv(32, (4, 4), padding='VALID', strides=(2, 2)),
        stax.Relu,
        stax.MaxPool((2, 2), (1, 1)),
        stax.Flatten,
        stax.Dense(32),
        stax.Relu,
        stax.Dense(10),
    )

    def init_params(rng):
        return init_fun(rng, (-1, 28, 28, 1))[1]

    return init_params, predict
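A minimal usage sketch for the pair returned above (illustrative; assumes MNIST-shaped inputs, as the (-1, 28, 28, 1) init shape suggests):

from jax import numpy as jnp, random

init_params, predict_fn = conv()
params = init_params(random.PRNGKey(0))
logits = predict_fn(params, jnp.ones((8, 28, 28, 1)))
print(logits.shape)  # (8, 10)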
Code example #9
# %%
def data_stream():
    rng = npr.RandomState(0)
    while True:
        perm = rng.permutation(num_train)
        for i in range(num_batches):
            batch_idx = perm[i * batch_size:(i + 1) * batch_size]
            yield train_images[batch_idx], train_labels[batch_idx]


batches = data_stream()

init_fun, net = stax.serial(stax.Conv(16, (3, 3), (1, 1),
                                      padding="SAME"), stax.Relu,
                            stax.MaxPool((2, 2), (2, 2), padding="SAME"),
                            stax.Conv(32, (3, 3), (1, 1),
                                      padding="SAME"), stax.Relu,
                            stax.MaxPool((2, 2), (2, 2), padding="SAME"),
                            stax.Flatten, stax.Dense(10), stax.LogSoftmax)

_, params = init_fun(key, (64, 1, 28, 28))


def loss(params, batch):
    inputs, targets = batch
    preds = net(params, inputs)
    return -jnp.mean(jnp.sum(preds * targets, axis=1))


def accuracy(params, batch):
    # Top-1 accuracy. The original snippet is truncated here, so this body
    # follows the usual JAX example pattern rather than the author's code.
    inputs, targets = batch
    predicted_class = jnp.argmax(net(params, inputs), axis=1)
    target_class = jnp.argmax(targets, axis=1)
    return jnp.mean(predicted_class == target_class)
Code example #10
#test_newx = computation(params, (train_images, train_labels))

#print(test_newx)
"""# **Problem 2**

Before we get started, we need to import two small libraries that contain boilerplate code for common neural network layer types and for optimizers like mini-batch SGD.
"""

from jax.experimental import optimizers
from jax.experimental import stax
import matplotlib.pyplot as plt
"""Here is a fully-connected neural network architecture, like the one of Problem 1, but this time defined with `stax`"""

init_random_params, predict = stax.serial(
    stax.Conv(64, (8, 8), padding='SAME', strides=(2, 2)), stax.Relu,
    stax.MaxPool((2, 2), (1, 1)),
    stax.Conv(128, (4, 4), padding='VALID', strides=(2, 2)), stax.Relu,
    stax.MaxPool((2, 2), (1, 1)), stax.Flatten, stax.Dense(128), stax.Relu,
    stax.Dense(10), stax.LogSoftmax)
"""We redefine the cross-entropy loss for this model. As done in Problem 1, complete the return line below (it's identical)."""


def loss(params, batch):
    inputs, targets = batch
    logits = predict(params, inputs)
    preds = stax.logsoftmax(logits)
    return -np.mean(preds * targets)


"""Next, we define the mini-batch SGD optimizer, this time with the optimizers library in JAX."""
Code example #11
flags.DEFINE_integer('epochs', 100, 'Number of finetuning epochs')
flags.DEFINE_integer('seed', 0, 'Seed for jax PRNG')
flags.DEFINE_integer(
    'uncertain', 0, '0: entropy, '
    '1: difference between 1st_prob and 2nd_prob, '
    '2: random')
flags.DEFINE_integer('n_extra', 3000, 'number of extra points')
flags.DEFINE_bool(
    'show_label', True,
    'visualize predicted label at top/left, true at bottom/right')

# BEGIN: define the classifier model
init_fn_0, apply_fn_0 = stax.serial(
    stax.Conv(16, (8, 8), padding='SAME', strides=(2, 2)),
    stax.Tanh,
    stax.MaxPool((2, 2), (1, 1)),
    stax.Conv(32, (4, 4), padding='VALID', strides=(2, 2)),
    stax.Tanh,
    stax.MaxPool((2, 2), (1, 1)),
    stax.Flatten,  # (-1, 800)
    stax.Dense(32),
    stax.Tanh,  # embeddings
)

init_fn_1, apply_fn_1 = stax.serial(
    stax.Dense(10),  # logits
)


def predict(params, inputs):
    params_0 = params[:-1]
Code example #12
Before we get started, we need to import two small libraries that contain boilerplate code for common neural network layer types and for optimizers like mini-batch SGD.
"""

from jax.experimental import optimizers
from jax.experimental import stax

"""Here is a fully-connected neural network architecture, like the one of Problem 1, but this time defined with `stax`"""

init_random_params, predict = stax.serial(
    stax.Conv(256, (5, 5), strides=(2, 2)),
    stax.Relu,
    stax.Conv(128, (3, 3)),
    stax.Relu,
    stax.Conv(32, (3, 3)),
    stax.Relu,
    stax.MaxPool((2, 2)),
    stax.Flatten,
    stax.Dense(1024),
    stax.Relu,
    stax.Dense(128),
    stax.Relu,
    stax.Dense(10),
)

"""We redefine the cross-entropy loss for this model. As done in Problem 1, complete the return line below (it's identical)."""

def loss(params, batch):
    inputs, targets = batch
    logits = predict(params, inputs)
    preds = stax.logsoftmax(logits)
    return -np.sum(targets * preds) / len(targets)
Code example #13
File: stax_test.py Project: heldyyusliar/jax
    def testPoolingShape(self, window_shape, padding, strides, input_shape):
        init_fun, apply_fun = stax.MaxPool(window_shape,
                                           padding=padding,
                                           strides=strides)
        _CheckShapeAgreement(self, init_fun, apply_fun, input_shape)
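A standalone version of the same check, without the test harness; the window, strides, and input shape below are illustrative.

import jax.numpy as jnp
from jax import random
from jax.experimental import stax

init_fun, apply_fun = stax.MaxPool((2, 2), padding='VALID', strides=(2, 2))
reported_shape, params = init_fun(random.PRNGKey(0), (4, 28, 28, 3))
out = apply_fun(params, jnp.ones((4, 28, 28, 3)))
assert out.shape == reported_shape  # (4, 14, 14, 3)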
Code example #14
File: train_dpsgd_svhn.py Project: adp-anonymous/adp
def main(_):

    logging.info('Starting experiment.')
    configs = FLAGS.config

    # Create model folder for outputs
    try:
        gfile.MakeDirs(FLAGS.exp_dir)
    except gfile.GOSError:
        pass
    stdout_log = gfile.Open('{}/stdout.log'.format(FLAGS.exp_dir), 'w+')

    if configs.optimization == 'sgd':
        lr_schedule = optimizers.make_schedule(configs.learning_rate)
        opt_init, opt_update, get_params = optimizers.sgd(lr_schedule)
    elif configs.optimization == 'momentum':
        lr_schedule = cosine(configs.learning_rate, configs.train_steps)
        opt_init, opt_update, get_params = optimizers.momentum(
            lr_schedule, 0.9)
    else:
        raise ValueError('Optimizer not implemented.')

    with gfile.Open(FLAGS.pretrained_dir, 'rb') as fpre:
        pretrained_opt_state = optimizers.pack_optimizer_state(
            pickle.load(fpre))
    fixed_params = get_params(pretrained_opt_state)[:7]

    # BEGIN: define the classifier model
    init_fn_0, apply_fn_0 = stax.serial(
        stax.Conv(16, (8, 8), padding='SAME', strides=(2, 2)),
        stax.Tanh,
        stax.MaxPool((2, 2), (1, 1)),
        stax.Conv(32, (4, 4), padding='VALID', strides=(2, 2)),
        stax.Tanh,
        stax.MaxPool((2, 2), (1, 1)),
        stax.Flatten,  # representations
    )

    init_fn_1, apply_fn_1 = stax.serial(
        stax.Dense(64),
        stax.Tanh,  # embeddings
        stax.Dense(10),  # logits
    )

    def predict(params, inputs):
        representations = apply_fn_0(fixed_params,
                                     inputs)  # use pretrained params
        logits = apply_fn_1(params, representations)
        return logits

    # END: define the classifier model

    if configs.seed is not None:
        key = random.PRNGKey(configs.seed)
    else:
        key = random.PRNGKey(int(time.time()))
    _, _ = init_fn_0(key, (-1, 32, 32, 3))
    _, params = init_fn_1(key, (-1, 800))
    opt_state = opt_init(params)

    logging.info('Loading data.')
    tic = time.time()

    train_images, train_labels, _ = datasets.get_dataset_split(
        FLAGS.dataset, 'train')
    train_mu, train_std = onp.mean(train_images), onp.std(train_images)
    n_train = len(train_images)
    train = data.DataChunk(X=(train_images - train_mu) / train_std,
                           Y=train_labels,
                           image_size=32,
                           image_channels=3,
                           label_dim=1,
                           label_format='numeric')

    test_images, test_labels, _ = datasets.get_dataset_split(
        FLAGS.dataset, 'test')
    test = data.DataChunk(
        X=(test_images - train_mu) / train_std,  # normalize w train mean/std
        Y=test_labels,
        image_size=32,
        image_channels=3,
        label_dim=1,
        label_format='numeric')

    # Data augmentation
    if configs.augment_data:
        augmentation = data.chain_transforms(data.RandomHorizontalFlip(0.5),
                                             data.RandomCrop(4), data.ToDevice)
    else:
        augmentation = None
    batch = data.minibatcher(train, configs.batch_size, transform=augmentation)

    # count params of JAX model
    def count_parameters(params):
        return tree_util.tree_reduce(
            operator.add, tree_util.tree_map(lambda x: np.prod(x.shape),
                                             params))

    logging.info('Number of parameters: %d', count_parameters(params))
    stdout_log.write('Number of params: {}\n'.format(count_parameters(params)))

    # loss functions
    def cross_entropy_loss(params, x_img, y_lbl):
        return -np.mean(stax.logsoftmax(predict(params, x_img)) * y_lbl)

    def mse_loss(params, x_img, y_lbl):
        return 0.5 * np.mean((y_lbl - predict(params, x_img))**2)

    def accuracy(y_lbl_hat, y_lbl):
        target_class = np.argmax(y_lbl, axis=1)
        predicted_class = np.argmax(y_lbl_hat, axis=1)
        return np.mean(predicted_class == target_class)

    # Loss and gradient
    if configs.loss == 'xent':
        loss = cross_entropy_loss
    elif configs.loss == 'mse':
        loss = mse_loss
    else:
        raise ValueError('Loss function not implemented.')
    grad_loss = jit(grad(loss))

    # learning rate schedule and optimizer
    def cosine(initial_step_size, train_steps):
        k = np.pi / (2.0 * train_steps)

        def schedule(i):
            return initial_step_size * np.cos(k * i)

        return schedule

    def private_grad(params, batch, rng, l2_norm_clip, noise_multiplier,
                     batch_size):
        """Return differentially private gradients of params, evaluated on batch."""
        def _clipped_grad(params, single_example_batch):
            """Evaluate gradient for a single-example batch and clip its grad norm."""
            grads = grad_loss(params, single_example_batch[0].reshape(
                (-1, 32, 32, 3)), single_example_batch[1])

            nonempty_grads, tree_def = tree_util.tree_flatten(grads)
            total_grad_norm = np.linalg.norm(
                [np.linalg.norm(neg.ravel()) for neg in nonempty_grads])
            divisor = stop_gradient(
                np.amax((total_grad_norm / l2_norm_clip, 1.)))
            normalized_nonempty_grads = [
                neg / divisor for neg in nonempty_grads
            ]
            return tree_util.tree_unflatten(tree_def,
                                            normalized_nonempty_grads)

        px_clipped_grad_fn = vmap(partial(_clipped_grad, params))
        std_dev = l2_norm_clip * noise_multiplier
        noise_ = lambda n: n + std_dev * random.normal(rng, n.shape)
        normalize_ = lambda n: n / float(batch_size)
        sum_ = lambda n: np.sum(n, 0)  # aggregate
        aggregated_clipped_grads = tree_util.tree_map(
            sum_, px_clipped_grad_fn(batch))
        noised_aggregated_clipped_grads = tree_util.tree_map(
            noise_, aggregated_clipped_grads)
        normalized_noised_aggregated_clipped_grads = (tree_util.tree_map(
            normalize_, noised_aggregated_clipped_grads))
        return normalized_noised_aggregated_clipped_grads

    # summarize measurements
    steps_per_epoch = n_train // configs.batch_size

    def summarize(step, params):
        """Compute measurements in a zipped way."""
        set_entries = [train, test]
        set_bsizes = [configs.train_eval_bsize, configs.test_eval_bsize]
        set_names, loss_dict, acc_dict = ['train', 'test'], {}, {}

        for set_entry, set_bsize, set_name in zip(set_entries, set_bsizes,
                                                  set_names):
            temp_loss, temp_acc, points = 0.0, 0.0, 0
            for b in data.batch(set_entry, set_bsize):
                temp_loss += loss(params, b.X, b.Y) * b.X.shape[0]
                temp_acc += accuracy(predict(params, b.X), b.Y) * b.X.shape[0]
                points += b.X.shape[0]
            loss_dict[set_name] = temp_loss / float(points)
            acc_dict[set_name] = temp_acc / float(points)

        logging.info('Step: %s', str(step))
        logging.info('Train acc : %.4f', acc_dict['train'])
        logging.info('Train loss: %.4f', loss_dict['train'])
        logging.info('Test acc  : %.4f', acc_dict['test'])
        logging.info('Test loss : %.4f', loss_dict['test'])

        stdout_log.write('Step: {}\n'.format(step))
        stdout_log.write('Train acc : {}\n'.format(acc_dict['train']))
        stdout_log.write('Train loss: {}\n'.format(loss_dict['train']))
        stdout_log.write('Test acc  : {}\n'.format(acc_dict['test']))
        stdout_log.write('Test loss : {}\n'.format(loss_dict['test']))
        stdout_log.flush()

        return acc_dict['test']

    toc = time.time()
    logging.info('Elapsed SETUP time: %s', str(toc - tic))
    stdout_log.write('Elapsed SETUP time: {}\n'.format(toc - tic))

    # BEGIN: training steps
    logging.info('Training network.')
    tic = time.time()
    t = time.time()

    for s in range(configs.train_steps):
        b = next(batch)
        params = get_params(opt_state)

        # t0 = time.time()
        if FLAGS.dpsgd:
            key = random.fold_in(key, s)  # get new key for new random numbers
            opt_state = opt_update(
                s,
                private_grad(params, (b.X.reshape(
                    (-1, 1, 32, 32, 3)), b.Y), key, configs.l2_norm_clip,
                             configs.noise_multiplier, configs.batch_size),
                opt_state)
        else:
            opt_state = opt_update(s, grad_loss(params, b.X, b.Y), opt_state)
        # t1 = time.time()
        # logging.info('batch update time: %s', str(t1 - t0))

        if s % steps_per_epoch == 0:
            with gfile.Open(
                    '{}/ckpt_{}'.format(FLAGS.exp_dir,
                                        int(s / steps_per_epoch)),
                    'wb') as fckpt:
                pickle.dump(optimizers.unpack_optimizer_state(opt_state),
                            fckpt)

            if FLAGS.dpsgd:
                eps = compute_epsilon(s, configs.batch_size, n_train,
                                      configs.target_delta,
                                      configs.noise_multiplier)
                stdout_log.write(
                    'For delta={:.0e}, current epsilon is: {:.2f}\n'.format(
                        configs.target_delta, eps))

            logging.info('Elapsed EPOCH time: %s', str(time.time() - t))
            stdout_log.write('Elapsed EPOCH time: {}\n'.format(time.time() - t))
            stdout_log.flush()
            t = time.time()

    toc = time.time()
    summarize(configs.train_steps, params)
    logging.info('Elapsed TRAIN time: %s', str(toc - tic))
    stdout_log.write('Elapsed TRAIN time: {}\n'.format(toc - tic))
    stdout_log.close()
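compute_epsilon, used above when checkpointing under DP-SGD, is not defined in this snippet. A sketch along the lines of the standard JAX/TensorFlow-Privacy DP-SGD example, matching the call signature used above, might look like this; it is an assumption, not the project's own accounting code.

import numpy as onp
from tensorflow_privacy.privacy.analysis.rdp_accountant import compute_rdp, get_privacy_spent

def compute_epsilon(steps, batch_size, num_examples, target_delta, noise_multiplier):
    # Renyi-DP accounting for the subsampled Gaussian mechanism, converted to
    # an (epsilon, delta) guarantee at the given delta.
    q = batch_size / float(num_examples)  # sampling probability
    orders = list(onp.linspace(1.1, 10.9, 99)) + list(range(11, 64))
    rdp = compute_rdp(q, noise_multiplier, steps, orders)
    eps, _, _ = get_privacy_spent(orders, rdp, target_delta=target_delta)
    return eps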
Code example #15
File: nets.py Project: andresgm/PhiFlow
def u_net(in_channels: int,
          out_channels: int,
          levels: int = 4,
          filters: int or tuple or list = 16,
          batch_norm: bool = True,
          activation='ReLU',
          in_spatial: tuple or int = 2) -> StaxNet:
    if isinstance(filters, (tuple, list)):
        assert len(
            filters
        ) == levels, f"List of filters has length {len(filters)} but u-net has {levels} levels."
    else:
        filters = (filters, ) * levels
    activation = ACTIVATIONS[activation]
    if isinstance(in_spatial, int):
        d = in_spatial
        in_spatial = (-1, ) * d
    else:
        assert isinstance(in_spatial, tuple)
        d = len(in_spatial)
    # Create layers
    inc_init, inc_apply = create_double_conv(d, filters[0], filters[0],
                                             batch_norm, activation)
    init_functions, apply_functions = {}, {}
    for i in range(1, levels):
        init_functions[f'down{i}'], apply_functions[
            f'down{i}'] = create_double_conv(d, filters[i], filters[i],
                                             batch_norm, activation)
        init_functions[f'up{i}'], apply_functions[
            f'up{i}'] = create_double_conv(d, filters[i - 1], filters[i - 1],
                                           batch_norm, activation)
    outc_init, outc_apply = CONV[d](out_channels, (1, ) * d, padding='same')
    max_pool_init, max_pool_apply = stax.MaxPool((2, ) * d,
                                                 padding='same',
                                                 strides=(2, ) * d)
    _, up_apply = create_upsample()

    def net_init(rng, input_shape):
        params = {}
        rngs = random.split(rng, 2)
        shape = input_shape
        # Layers
        shape, params['inc'] = inc_init(rngs[0], shape)
        shapes = [shape]
        for i in range(1, levels):
            shape, _ = max_pool_init(None, shape)
            shape, params[f'down{i}'] = init_functions[f'down{i}'](rngs[i],
                                                                   shape)
            shapes.insert(0, shape)
        for i in range(1, levels):
            shape = shapes[i][:-1] + (shapes[i][-1] + shape[-1], )
            shape, params[f'up{i}'] = init_functions[f'up{i}'](rngs[levels +
                                                                    i], shape)
        shape, params['outc'] = outc_init(rngs[-1], shape)
        return shape, params

    # no @jax.jit needed here since the user can jit this in the loss_function
    def net_apply(params, inputs, **kwargs):
        x = inputs
        x = inc_apply(params['inc'], x, **kwargs)
        xs = [x]
        for i in range(1, levels):
            x = max_pool_apply(None, x, **kwargs)
            x = apply_functions[f'down{i}'](params[f'down{i}'], x, **kwargs)
            xs.insert(0, x)
        for i in range(1, levels):
            x = up_apply(None, x, **kwargs)
            x = jnp.concatenate([x, xs[i]], axis=-1)
            x = apply_functions[f'up{i}'](params[f'up{i}'], x, **kwargs)
        x = outc_apply(params['outc'], x, **kwargs)
        return x

    net = StaxNet(net_init, net_apply, (-1, ) + in_spatial + (in_channels, ))
    net.initialize()
    return net
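In the down path, each level halves the spatial resolution with stax.MaxPool((2,) * d, strides=(2,) * d, padding='same'). A minimal 2-D shape check of that pooling step (sizes are illustrative, not from the original):

from jax import random
from jax.experimental import stax

pool_init, pool_apply = stax.MaxPool((2, 2), padding='SAME', strides=(2, 2))
out_shape, _ = pool_init(random.PRNGKey(0), (-1, 64, 64, 16))
print(out_shape)  # (-1, 32, 32, 16): each spatial dimension is halved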