Example #1
def ConvBlock(kernel_size,
              filters,
              strides=(2, 2),
              batchnorm=True,
              parameterization='standard',
              nonlin=Relu):
    ks = kernel_size
    filters1, filters2, filters3 = filters
    if parameterization == 'standard':

        def MyConv(*args, **kwargs):
            return Conv(*args, **kwargs)
    elif parameterization == 'ntk':

        def MyConv(*args, **kwargs):
            return stax.Conv(*args, **kwargs)[:2]
    else:
        raise ValueError(f'Unknown parameterization: {parameterization}')

    if batchnorm:
        Main = jax_stax.serial(MyConv(filters1, (1, 1), strides), BatchNorm(),
                               nonlin,
                               MyConv(filters2, (ks, ks), padding='SAME'),
                               BatchNorm(), nonlin, MyConv(filters3, (1, 1)),
                               BatchNorm())
        Shortcut = jax_stax.serial(MyConv(filters3, (1, 1), strides),
                                   BatchNorm())
    else:
        Main = jax_stax.serial(MyConv(filters1, (1, 1), strides), nonlin,
                               MyConv(filters2, (ks, ks), padding='SAME'),
                               nonlin, MyConv(filters3, (1, 1)))
        Shortcut = jax_stax.serial(MyConv(filters3, (1, 1), strides))
    return jax_stax.serial(FanOut(2), jax_stax.parallel(Main, Shortcut),
                           FanInSum, nonlin)
Example #2
def ConvBlock(kernel_size, filters, strides=(2, 2)):
    ks = kernel_size
    filters1, filters2, filters3 = filters
    Main = stax.serial(Conv(filters1, (1, 1), strides), BatchNorm(), Relu,
                       Conv(filters2, (ks, ks), padding='SAME'), BatchNorm(),
                       Relu, Conv(filters3, (1, 1)), BatchNorm())
    Shortcut = stax.serial(Conv(filters3, (1, 1), strides), BatchNorm())
    return stax.serial(FanOut(2), stax.parallel(Main, Shortcut), FanInSum,
                       Relu)
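In jax.example_libraries.stax, serial and the other combinators return an (init_fun, apply_fun) pair, so a block like the one above is used in two steps: initialize parameters for a concrete input shape, then apply. A minimal usage sketch, assuming the ConvBlock from Example #2 and its stax imports are in scope; the NHWC shapes and batch size are illustrative, not taken from the original code:

# Usage sketch (assumed shapes): stax.Conv defaults to NHWC inputs.
import jax
import jax.numpy as jnp

init_fun, apply_fun = ConvBlock(3, [64, 64, 256], strides=(1, 1))
rng = jax.random.PRNGKey(0)
out_shape, params = init_fun(rng, (8, 56, 56, 64))   # (N, H, W, C)
y = apply_fun(params, jnp.zeros((8, 56, 56, 64)))    # y.shape == (8, 56, 56, 256)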
Example #3
 def construct_main(input_shape):
     return stax.serial(
         Conv(filters[0], (1, 1), strides=(1, 1)),
         BatchNorm(),
         Relu,
         Conv(filters[1], (ks, ks), padding="SAME"),
         BatchNorm(),
         Relu,
         Conv(input_shape[3], (1, 1)),
         BatchNorm(),
     )
Example #4
 def make_main(input_shape):
     # the number of output channels depends on the number of input channels
     return stax.serial(
         Conv(filters1, (1, 1)),
         BatchNorm(),
         Relu,
         Conv(filters2, (ks, ks), padding="SAME"),
         BatchNorm(),
         Relu,
         Conv(input_shape[3], (1, 1)),
         BatchNorm(),
     )
Example #5
 def make_main(input_shape):
     # the number of output channels depends on the number of input channels
     if batchnorm:
         return jax_stax.serial(MyConv(filters1, (1, 1)), BatchNorm(),
                                nonlin,
                                MyConv(filters2, (ks, ks), padding='SAME'),
                                BatchNorm(), nonlin,
                                MyConv(input_shape[3], (1, 1)), BatchNorm())
     else:
         return jax_stax.serial(MyConv(filters1, (1, 1)), nonlin,
                                MyConv(filters2, (ks, ks), padding='SAME'),
                                nonlin, MyConv(input_shape[3], (1, 1)))
Example #6
def convBlock(ks, filters, stride=(1, 1)):
    Main = stax.serial(Conv(filters[0], (1, 1), strides=(1, 1)), BatchNorm(),
                       Relu, Conv(filters[1], (ks, ks), strides=stride),
                       BatchNorm(), Relu,
                       Conv(filters[2], (1, 1),
                            strides=(1, 1)), BatchNorm(), Relu)

    Shortcut = stax.serial(
        Conv(filters[2], (1, 1), strides=stride),
        BatchNorm(),
    )

    fullInternal = stax.parallel(Main, Shortcut)

    return stax.serial(FanOut(2), fullInternal, FanInSum, Relu)
Example #7
def maybe_use_normalization(normalization_method=None):
    if normalization_method == "batch_norm":
        return BatchNorm()
    elif normalization_method == "group_norm":
        return GroupNorm()
    else:
        return Identity
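Because the return value is itself a stax (init_fun, apply_fun) layer (including the Identity case), it can be dropped straight into stax.serial. A short sketch, assuming the stax layers used elsewhere in these examples; the filter sizes are placeholders:

# Sketch: the selected normalization layer composes like any other stax layer.
block = stax.serial(
    Conv(64, (3, 3), padding="SAME"),
    maybe_use_normalization("batch_norm"),  # or "group_norm" / None
    Relu,
)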
Example #8
def ResNet50(num_classes):
    return stax.serial(
        GeneralConv(("HWCN", "OIHW", "NHWC"), 64, (7, 7), (2, 2), "SAME"),
        BatchNorm(),
        Relu,
        MaxPool((3, 3), strides=(2, 2)),
        ConvBlock(3, [64, 64, 256], strides=(1, 1)),
        IdentityBlock(3, [64, 64]),
        IdentityBlock(3, [64, 64]),
        ConvBlock(3, [128, 128, 512]),
        IdentityBlock(3, [128, 128]),
        IdentityBlock(3, [128, 128]),
        IdentityBlock(3, [128, 128]),
        ConvBlock(3, [256, 256, 1024]),
        IdentityBlock(3, [256, 256]),
        IdentityBlock(3, [256, 256]),
        IdentityBlock(3, [256, 256]),
        IdentityBlock(3, [256, 256]),
        IdentityBlock(3, [256, 256]),
        ConvBlock(3, [512, 512, 2048]),
        IdentityBlock(3, [512, 512]),
        IdentityBlock(3, [512, 512]),
        AvgPool((7, 7)),
        Flatten,
        Dense(num_classes),
        LogSoftmax,
    )
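Given the GeneralConv dimension spec ("HWCN", "OIHW", "NHWC"), the first layer expects inputs in HWCN layout (height, width, channels, batch) and produces NHWC activations for the rest of the stack. A minimal sketch of initialization and a forward pass; the 224x224 resolution and batch size of 8 are illustrative, not from the original code:

# Usage sketch (assumed input size): HWCN inputs, log-probabilities out.
import jax
import jax.numpy as jnp

init_fun, predict = ResNet50(num_classes=1000)
rng = jax.random.PRNGKey(0)
out_shape, params = init_fun(rng, (224, 224, 3, 8))    # (H, W, C, N)
logits = predict(params, jnp.zeros((224, 224, 3, 8)))  # shape (8, 1000)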
Example #9
def LeNet5(num_classes):
    return stax.serial(
        GeneralConv(('HWCN', 'OIHW', 'NHWC'), 64, (7, 7), (2, 2), 'SAME'),
        BatchNorm(),
        Relu,
        AvgPool((3, 3)),

        Conv(16, (5, 5), strides=(1, 1), padding="SAME"),
        BatchNorm(),
        Relu,
        AvgPool((3, 3)),

        Flatten,
        Dense(num_classes*10),
        Dense(num_classes*5),
        Dense(num_classes),
        LogSoftmax
    )
Example #10
def denseActivationNormLayer(n_hidden_unit, bias_coef, activation, norm):
    activation = select_activation(activation)
    layer = stax.serial(Dense(n_hidden_unit, b_gain=bias_coef), activation)
    if norm == 'batch_norm':
        layer = stax.serial(layer, BatchNorm(axis=0))
    elif norm is None or norm == 'none':
        pass
    else:
        raise ValueError(f'unknown norm: {norm!r}')

    return layer
Example #11
def ResNet(num_classes):
    return stax.serial(
        GeneralConv(('HWCN', 'OIHW', 'NHWC'), 64, (7, 7), (2, 2), 'SAME'),
        BatchNorm(), Relu, MaxPool((3, 3), strides=(2, 2)),
        convBlock(3, [64, 64, 256]), identityBlock(3, [64, 64]),
        identityBlock(3, [64, 64]), convBlock(3, [128, 128, 512]),
        identityBlock(3, [128, 128]), identityBlock(3, [128, 128]),
        identityBlock(3, [128, 128]), convBlock(3, [256, 256, 1024]),
        identityBlock(3, [256, 256]), identityBlock(3, [256, 256]),
        identityBlock(3, [256, 256]), identityBlock(3, [256, 256]),
        identityBlock(3, [256, 256]), convBlock(3, [512, 512, 2048]),
        identityBlock(3, [512, 512]), identityBlock(3, [512, 512]),
        AvgPool((7, 7)), Flatten, Dense(num_classes), LogSoftmax)
Example #12
def ResNet(num_classes):
    return stax.serial(
        GeneralConv(("HWCN", "OIHW", "NHWC"), 64, (7, 7), (2, 2), "SAME"),
        BatchNorm(),
        Relu,
        MaxPool((3, 3), strides=(2, 2)),
        ConvBlock(3, [4, 4, 4], strides=(1, 1)),
        IdentityBlock(3, [4, 4]),
        AvgPool((3, 3)),
        Flatten,
        Dense(num_classes),
        LogSoftmax,
    )
Example #13
def ResNet(block,
           expansion,
           layers,
           normalization_method=None,
           width_per_group=64,
           actfn=stax.Relu):
    norm_layer = Identity
    if normalization_method == "group_norm":
        norm_layer = GroupNorm(32)
    elif normalization_method == "batch_norm":
        norm_layer = BatchNorm()
    base_width = width_per_group

    def _make_layer(block, planes, blocks, stride=1):
        downsample = None
        if stride != 1:
            downsample = stax.serial(
                Conv(planes * expansion, (1, 1),
                     strides=(stride, stride),
                     bias=False),
                norm_layer,
            )
        layers = []
        layers.append(block(planes, stride, downsample, base_width,
                            norm_layer))
        for _ in range(1, blocks):
            layers.append(
                block(planes,
                      base_width=base_width,
                      norm_layer=norm_layer,
                      actfn=actfn))
        return stax.serial(*layers)

    return [
        Conv(64, (3, 3), strides=(1, 1), padding="SAME", bias=False),
        norm_layer,
        actfn,
        # MaxPool((3, 3), strides=(2, 2), padding="SAME"),
        _make_layer(block, 64, layers[0]),
        _make_layer(block, 128, layers[1], stride=2),
        _make_layer(block, 256, layers[2], stride=2),
        _make_layer(block, 512, layers[3], stride=2),
        AvgPool((4, 4)),
        Flatten,
    ]
Example #14
def convActivationNormLayer(n_filter, filter_shape, strides, padding,
                            bias_coef, activation, norm):
    activation = select_activation(activation)
    layer = stax.serial(
        Conv(out_chan=n_filter,
             filter_shape=filter_shape,
             strides=strides,
             padding=padding,
             b_gain=bias_coef), activation)
    if norm == 'batch_norm':
        layer = stax.serial(
            layer,
            BatchNorm(axis=(0, 1, 2))  # normalize over N, H, W
        )
    elif norm is None or norm == 'None':
        pass
    else:
        raise ValueError(f'unknown norm: {norm!r}')

    return layer
Example #15
def ResNet50(num_classes,
             batchnorm=True,
             parameterization='standard',
             nonlinearity='relu'):
    # Define layer constructors
    if parameterization == 'standard':

        def MyGeneralConv(*args, **kwargs):
            return GeneralConv(*args, **kwargs)

        def MyDense(*args, **kwargs):
            return Dense(*args, **kwargs)
    elif parameterization == 'ntk':

        def MyGeneralConv(*args, **kwargs):
            return stax._GeneralConv(*args, **kwargs)[:2]

        def MyDense(*args, **kwargs):
            return stax.Dense(*args, **kwargs)[:2]
    else:
        raise ValueError(f'Unknown parameterization: {parameterization}')

    # Define nonlinearity
    if nonlinearity == 'relu':
        nonlin = Relu
    elif nonlinearity == 'swish':
        nonlin = Swish
    elif nonlinearity == 'swishten':
        nonlin = Swishten
    elif nonlinearity == 'softplus':
        nonlin = Softplus
    else:
        raise ValueError(f'Unknown nonlinearity: {nonlinearity}')
    return jax_stax.serial(
        MyGeneralConv(('NHWC', 'HWIO', 'NHWC'),
                      64, (7, 7),
                      strides=(2, 2),
                      padding='SAME'),
        BatchNorm() if batchnorm else Identity, nonlin,
        MaxPool((3, 3), strides=(2, 2)),
        ConvBlock(3, [64, 64, 256],
                  strides=(1, 1),
                  batchnorm=batchnorm,
                  parameterization=parameterization,
                  nonlin=nonlin),
        IdentityBlock(3, [64, 64],
                      batchnorm=batchnorm,
                      parameterization=parameterization,
                      nonlin=nonlin),
        IdentityBlock(3, [64, 64],
                      batchnorm=batchnorm,
                      parameterization=parameterization,
                      nonlin=nonlin),
        ConvBlock(3, [128, 128, 512],
                  batchnorm=batchnorm,
                  parameterization=parameterization,
                  nonlin=nonlin),
        IdentityBlock(3, [128, 128],
                      batchnorm=batchnorm,
                      parameterization=parameterization,
                      nonlin=nonlin),
        IdentityBlock(3, [128, 128],
                      batchnorm=batchnorm,
                      parameterization=parameterization,
                      nonlin=nonlin),
        IdentityBlock(3, [128, 128],
                      batchnorm=batchnorm,
                      parameterization=parameterization,
                      nonlin=nonlin),
        ConvBlock(3, [256, 256, 1024],
                  batchnorm=batchnorm,
                  parameterization=parameterization,
                  nonlin=nonlin),
        IdentityBlock(3, [256, 256],
                      batchnorm=batchnorm,
                      parameterization=parameterization,
                      nonlin=nonlin),
        IdentityBlock(3, [256, 256],
                      batchnorm=batchnorm,
                      parameterization=parameterization,
                      nonlin=nonlin),
        IdentityBlock(3, [256, 256],
                      batchnorm=batchnorm,
                      parameterization=parameterization,
                      nonlin=nonlin),
        IdentityBlock(3, [256, 256],
                      batchnorm=batchnorm,
                      parameterization=parameterization,
                      nonlin=nonlin),
        IdentityBlock(3, [256, 256],
                      batchnorm=batchnorm,
                      parameterization=parameterization,
                      nonlin=nonlin),
        ConvBlock(3, [512, 512, 2048],
                  batchnorm=batchnorm,
                  parameterization=parameterization,
                  nonlin=nonlin),
        IdentityBlock(3, [512, 512],
                      batchnorm=batchnorm,
                      parameterization=parameterization,
                      nonlin=nonlin),
        IdentityBlock(3, [512, 512],
                      batchnorm=batchnorm,
                      parameterization=parameterization,
                      nonlin=nonlin),
        stax.GlobalAvgPool()[:-1], MyDense(num_classes))
Example #16
def GCNLayer(out_dim,
             activation=relu,
             bias=True,
             normalize=True,
             batch_norm=False,
             dropout=0.0,
             W_init=he_normal(),
             b_init=normal()):
    r"""Single GCN layer from `Semi-Supervised Classification with Graph Convolutional Networks
    <https://arxiv.org/abs/1609.02907>`

    Parameters
    ----------
    out_dim : int
        Number of output node features.
    activation : Function
        Activation function, defaults to relu.
    bias : bool
        Whether to add a bias after the affine transformation, defaults to True.
    normalize : bool
        Whether to normalize the adjacency matrix, defaults to True.
    batch_norm : bool
        Whether to use batch normalization, defaults to False.
    dropout : float
        The probability for dropout, defaults to 0.0.
    W_init : initializer function for the weight
        Defaults to He normal distribution.
    b_init : initializer function for the bias
        Defaults to normal distribution.

    Returns
    -------
    init_fun : Function
        Initializes the parameters of the layer.
    apply_fun : Function
        Defines the forward computation function.
    """

    _, drop_fun = Dropout(dropout)
    batch_norm_init, batch_norm_fun = BatchNorm()

    def init_fun(rng, input_shape):
        """Initialize parameters.

        Parameters
        ----------
        rng : PRNGKey
            rng is a value for generating random values.
        input_shape : (batch_size, N, M1)
            The shape of input (input node features).
            N is the total number of nodes in the batch of graphs.
            M1 is the input node feature size.

        Returns
        -------
        output_shape : (batch_size, N, M2)
            The shape of output (new node features).
            M2 is the new node feature size and equal to out_dim.
        params : Tuple (W, b, batch_norm_param)
            W is a weight and b is a bias.
            W : ndarray of shape (M1, M2)
            b : ndarray of shape (M2,) or None
            batch_norm_param : Tuple (beta, gamma) or None
        """
        output_shape = input_shape[:-1] + (out_dim, )
        k1, k2, k3 = random.split(rng, 3)
        W = W_init(k1, (input_shape[-1], out_dim))
        b = b_init(k2, (out_dim, )) if bias else None
        batch_norm_param = None
        if batch_norm:
            output_shape, batch_norm_param = batch_norm_init(k3, output_shape)
        return output_shape, (W, b, batch_norm_param)

    def apply_fun(params, node_feats, adj, rng, is_train):
        """Update node representations.

        Parameters
        ----------
        node_feats : ndarray of shape (batch_size, N, M1)
            Batched input node features.
            N is the total number of nodes in the batch of graphs.
            M1 is the input node feature size.
        adj : ndarray of shape (batch_size, N, N)
            Batched adjacency matrix.
        rng : PRNGKey
            rng is a value for generating random values
        is_train : bool
            Whether the model is training or not.

        Returns
        -------
        new_node_feats : ndarray of shape (batch_size, N, M2)
            Batched new node features.
            M2 is the new node feature size and equal to out_dim.
        """
        W, b, batch_norm_param = params

        if normalize:
            # A' = A + I, where I is the identity matrix
            # D': diagonal node degree matrix of A'
            # H' = D'^(-1/2) × A' × D'^(-1/2) × H × W
            def node_update_func(node_feats, adj):
                adj = adj + jnp.eye(len(adj))
                deg = jnp.sum(adj, axis=1)
                deg_mat = jnp.diag(jnp.where(deg > 0, deg**(-0.5), 0))
                normalized_adj = jnp.dot(deg_mat, jnp.dot(adj, deg_mat))
                return jnp.dot(normalized_adj, jnp.dot(node_feats, W))
        else:
            # H' = A × H × W
            def node_update_func(node_feats, adj):
                return jnp.dot(adj, jnp.dot(node_feats, W))

        # batched operation for updating node features
        new_node_feats = vmap(node_update_func)(node_feats, adj)

        if bias:
            new_node_feats += b
        new_node_feats = activation(new_node_feats)
        if dropout != 0.0:
            rng, key = random.split(rng)
            new_node_feats = drop_fun(None, new_node_feats, is_train, rng=key)
        if batch_norm:
            new_node_feats = batch_norm_fun(batch_norm_param, new_node_feats)
        return new_node_feats

    return init_fun, apply_fun
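A minimal usage sketch for the layer above, following its own docstring: node features of shape (batch_size, N, M1) and adjacency matrices of shape (batch_size, N, N); the sizes below are illustrative:

# Sketch: initialize and apply one GCN layer (shapes per the docstring).
import jax.numpy as jnp
from jax import random

init_fun, apply_fun = GCNLayer(out_dim=64, batch_norm=True)
rng = random.PRNGKey(0)
out_shape, params = init_fun(rng, (4, 32, 16))  # (batch_size, N, M1)

node_feats = jnp.zeros((4, 32, 16))
adj = jnp.tile(jnp.eye(32), (4, 1, 1))          # (batch_size, N, N)
new_feats = apply_fun(params, node_feats, adj, rng, is_train=False)  # (4, 32, 64)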
Example #17
    def initialize_parametric_nonlinearity(self,
                                           init_to='exponential',
                                           method=None,
                                           params_dict=None):

        if method is None:  # if no methods specified, use defaults.
            # this piece of code is quite redundant.
            # need to refactor.
            if hasattr(self, 'nonlinearity'):
                method = self.nonlinearity
            else:
                method = self.filter_nonlinearity
        else:  # overwrite the default nonlinearity
            if hasattr(self, 'nonlinearity'):
                self.nonlinearity = method
            else:
                self.filter_nonlinearity = method
                self.output_nonlinearity = method

        # prepare data
        if params_dict is None:
            params_dict = {}
        xrange = params_dict.get('xrange', 5)
        nx = params_dict.get('nx', 1000)
        x0 = np.linspace(-xrange, xrange, nx)
        if init_to == 'exponential':
            y0 = np.exp(x0)

        elif init_to == 'softplus':
            y0 = softplus(x0)

        elif init_to == 'relu':
            y0 = relu(x0)

        elif init_to == 'nonparametric':
            y0 = self.fnl_nonparametric(x0)

        elif init_to == 'gaussian':
            import scipy.signal
            y0 = scipy.signal.gaussian(nx, nx / 10)

        # fit nonlin
        if method == 'spline':
            smooth = params_dict.get('smooth', 'cr')
            df = params_dict.get('df', 7)
            if smooth == 'cr':
                X = cr(x0, df)
            elif smooth == 'cc':
                X = cc(x0, df)
            elif smooth == 'bs':
                deg = params_dict.get('degree', 3)
                X = bs(x0, df, deg)

            opt_params = np.linalg.pinv(X.T @ X) @ X.T @ y0

            self.nl_basis = X

            def _nl(opt_params, x_new):
                return np.maximum(interp1d(x0, X @ opt_params)(x_new), 0)

        elif method == 'nn':

            def loss(params, data):
                x = data['x']
                y = data['y']
                yhat = _predict(params, x)
                return np.mean((y - yhat)**2)

            @jit
            def step(i, opt_state, data):
                p = get_params(opt_state)
                g = grad(loss)(p, data)
                return opt_update(i, g, opt_state)

            random_seed = params_dict.get('random_seed', 2046)
            key = random.PRNGKey(random_seed)

            step_size = params_dict.get('step_size', 0.01)
            layer_sizes = params_dict.get('layer_sizes', [10, 10, 1])
            layers = []
            for layer_size in layer_sizes:
                layers.append(Dense(layer_size))
                layers.append(BatchNorm(axis=(0, 1)))
                layers.append(Relu)
            # drop the trailing Relu added in the last loop iteration
            layers.pop(-1)

            init_random_params, _predict = stax.serial(*layers)

            num_subunits = params_dict.get('num_subunits', 1)
            _, init_params = init_random_params(key, (-1, num_subunits))

            opt_init, opt_update, get_params = optimizers.adam(step_size)
            opt_state = opt_init(init_params)

            num_iters = params_dict.get('num_iters', 1000)
            if num_subunits == 1:
                data = {'x': x0.reshape(-1, 1), 'y': y0.reshape(-1, 1)}
            else:
                data = {
                    'x': np.vstack([x0 for i in range(num_subunits)]).T,
                    'y': y0.reshape(-1, 1)
                }

            for i in range(num_iters):
                opt_state = step(i, opt_state, data)
            opt_params = get_params(opt_state)

            def _nl(opt_params, x_new):
                if len(x_new.shape) == 1:
                    x_new = x_new.reshape(-1, 1)
                return np.maximum(_predict(opt_params, x_new), 0)

        self.nl_xrange = x0
        self.nl_params = opt_params
        self.fnl_fitted = _nl
Example #18
# inputs = next(batches)

# train_images, labels, _, _ = mnist(permute_train=True, resize=True)
# del _
# inputs = train_images[:data_size]
#
# del train_images

# u, s, v_t = onp.linalg.svd(inputs, full_matrices=False)
# I = np.eye(v_t.shape[-1])
# I_add = npr.normal(0.0, 0.002, size=I.shape)
# noisy_I = I + I_add

init_fun, conv_net = stax.serial(
    Conv(32, (5, 5), (2, 2), padding="SAME"),
    BatchNorm(),
    Relu,
    Conv(10, (3, 3), (2, 2), padding="SAME"),
    Relu,
    Flatten,
    Dense(num_classes),
    LogSoftmax,
)
_, key = random.split(random.PRNGKey(0))


class DataTopologyAE(AbstractProblem):
    def __init__(self):
        self.HPARAMS_PATH = "hparams.json"

    @staticmethod