def ConvBlock(kernel_size, filters, strides=(2, 2), batchnorm=True,
              parameterization='standard', nonlin=Relu):
    ks = kernel_size
    filters1, filters2, filters3 = filters
    if parameterization == 'standard':
        def MyConv(*args, **kwargs):
            return Conv(*args, **kwargs)
    elif parameterization == 'ntk':
        # neural_tangents layers return (init_fn, apply_fn, kernel_fn);
        # keep only init/apply so they compose with jax stax combinators
        def MyConv(*args, **kwargs):
            return stax.Conv(*args, **kwargs)[:2]
    if batchnorm:
        Main = jax_stax.serial(
            MyConv(filters1, (1, 1), strides), BatchNorm(), nonlin,
            MyConv(filters2, (ks, ks), padding='SAME'), BatchNorm(), nonlin,
            MyConv(filters3, (1, 1)), BatchNorm())
        Shortcut = jax_stax.serial(MyConv(filters3, (1, 1), strides), BatchNorm())
    else:
        Main = jax_stax.serial(
            MyConv(filters1, (1, 1), strides), nonlin,
            MyConv(filters2, (ks, ks), padding='SAME'), nonlin,
            MyConv(filters3, (1, 1)))
        Shortcut = jax_stax.serial(MyConv(filters3, (1, 1), strides))
    return jax_stax.serial(FanOut(2), jax_stax.parallel(Main, Shortcut),
                           FanInSum, nonlin)
def ConvBlock(kernel_size, filters, strides=(2, 2)):
    ks = kernel_size
    filters1, filters2, filters3 = filters
    Main = stax.serial(
        Conv(filters1, (1, 1), strides), BatchNorm(), Relu,
        Conv(filters2, (ks, ks), padding='SAME'), BatchNorm(), Relu,
        Conv(filters3, (1, 1)), BatchNorm())
    Shortcut = stax.serial(Conv(filters3, (1, 1), strides), BatchNorm())
    return stax.serial(FanOut(2), stax.parallel(Main, Shortcut), FanInSum, Relu)
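# A minimal usage sketch for ConvBlock (hedged: assumes the
# jax.example_libraries.stax combinators used above; shapes are illustrative).
from jax import random

init_fun, apply_fun = ConvBlock(3, [64, 64, 256])
# default Conv spec is NHWC; strides=(2, 2) halves the spatial dimensions
out_shape, params = init_fun(random.PRNGKey(0), (8, 56, 56, 64))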
def construct_main(inp_shape):
    # the number of output channels matches the number of input channels
    return stax.serial(
        Conv(filters[0], (1, 1), strides=(1, 1)),
        BatchNorm(), Relu,
        Conv(filters[1], (ks, ks), padding="SAME"),
        BatchNorm(), Relu,
        Conv(inp_shape[3], (1, 1)),
        BatchNorm(),
    )
def make_main(input_shape):
    # the number of output channels depends on the number of input channels
    return stax.serial(
        Conv(filters1, (1, 1)),
        BatchNorm(), Relu,
        Conv(filters2, (ks, ks), padding="SAME"),
        BatchNorm(), Relu,
        Conv(input_shape[3], (1, 1)),
        BatchNorm(),
    )
def make_main(input_shape):
    # the number of output channels depends on the number of input channels
    if batchnorm:
        return jax_stax.serial(
            MyConv(filters1, (1, 1)), BatchNorm(), nonlin,
            MyConv(filters2, (ks, ks), padding='SAME'), BatchNorm(), nonlin,
            MyConv(input_shape[3], (1, 1)), BatchNorm())
    else:
        return jax_stax.serial(
            MyConv(filters1, (1, 1)), nonlin,
            MyConv(filters2, (ks, ks), padding='SAME'), nonlin,
            MyConv(input_shape[3], (1, 1)))
def convBlock(ks, filters, stride=(1, 1)):
    Main = stax.serial(
        Conv(filters[0], (1, 1), strides=(1, 1)), BatchNorm(), Relu,
        Conv(filters[1], (ks, ks), strides=stride), BatchNorm(), Relu,
        Conv(filters[2], (1, 1), strides=(1, 1)), BatchNorm(), Relu)
    # project the shortcut to the main branch's output width, filters[2]
    Shortcut = stax.serial(
        Conv(filters[2], (1, 1), strides=stride),
        BatchNorm(),
    )
    fullInternal = stax.parallel(Main, Shortcut)
    return stax.serial(FanOut(2), fullInternal, FanInSum, Relu)
def maybe_use_normalization(normalization_method=None):
    if normalization_method == "batch_norm":
        return BatchNorm()
    elif normalization_method == "group_norm":
        return GroupNorm()
    else:
        return Identity
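# Hedged usage sketch: each branch of maybe_use_normalization returns a
# stax-style (init_fun, apply_fun) pair, so it can be spliced directly into
# stax.serial. The surrounding Conv/Relu layers are illustrative, and
# GroupNorm is assumed to come from the same codebase.
net_init, net_apply = stax.serial(
    Conv(64, (3, 3), padding="SAME"),
    maybe_use_normalization("batch_norm"),  # or "group_norm", or None
    Relu,
)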
def ResNet50(num_classes):
    return stax.serial(
        GeneralConv(("HWCN", "OIHW", "NHWC"), 64, (7, 7), (2, 2), "SAME"),
        BatchNorm(), Relu,
        MaxPool((3, 3), strides=(2, 2)),
        ConvBlock(3, [64, 64, 256], strides=(1, 1)),
        IdentityBlock(3, [64, 64]),
        IdentityBlock(3, [64, 64]),
        ConvBlock(3, [128, 128, 512]),
        IdentityBlock(3, [128, 128]),
        IdentityBlock(3, [128, 128]),
        IdentityBlock(3, [128, 128]),
        ConvBlock(3, [256, 256, 1024]),
        IdentityBlock(3, [256, 256]),
        IdentityBlock(3, [256, 256]),
        IdentityBlock(3, [256, 256]),
        IdentityBlock(3, [256, 256]),
        IdentityBlock(3, [256, 256]),
        ConvBlock(3, [512, 512, 2048]),
        IdentityBlock(3, [512, 512]),
        IdentityBlock(3, [512, 512]),
        AvgPool((7, 7)),
        Flatten,
        Dense(num_classes),
        LogSoftmax,
    )
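# A minimal usage sketch for ResNet50 (hedged: assumes the
# jax.example_libraries.stax API; resolution and batch size are illustrative).
from jax import random

init_fun, apply_fun = ResNet50(num_classes=1000)
# the stem's GeneralConv uses an 'HWCN' input spec: (height, width, channels, batch)
out_shape, params = init_fun(random.PRNGKey(0), (224, 224, 3, 8))
# logits = apply_fun(params, images)  # images shaped (224, 224, 3, 8)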
def LeNet5(num_classes):
    return stax.serial(
        GeneralConv(('HWCN', 'OIHW', 'NHWC'), 64, (7, 7), (2, 2), 'SAME'),
        BatchNorm(), Relu,
        AvgPool((3, 3)),
        Conv(16, (5, 5), strides=(1, 1), padding="SAME"),
        BatchNorm(), Relu,
        AvgPool((3, 3)),
        Flatten,
        Dense(num_classes * 10),
        Dense(num_classes * 5),
        Dense(num_classes),
        LogSoftmax,
    )
def denseActivationNormLayer(n_hidden_unit, bias_coef, activation, norm):
    activation = select_activation(activation)
    layer = stax.serial(Dense(n_hidden_unit, b_gain=bias_coef), activation)
    if norm == 'batch_norm':
        layer = stax.serial(layer, BatchNorm(axis=0))
    elif norm is None or norm == 'none':
        pass
    else:
        raise ValueError(f"unknown norm: {norm!r}")
    return layer
def ResNet(num_classes):
    return stax.serial(
        GeneralConv(('HWCN', 'OIHW', 'NHWC'), 64, (7, 7), (2, 2), 'SAME'),
        BatchNorm(), Relu,
        MaxPool((3, 3), strides=(2, 2)),
        convBlock(3, [64, 64, 256]),
        identityBlock(3, [64, 64]),
        identityBlock(3, [64, 64]),
        convBlock(3, [128, 128, 512]),
        identityBlock(3, [128, 128]),
        identityBlock(3, [128, 128]),
        identityBlock(3, [128, 128]),
        convBlock(3, [256, 256, 1024]),
        identityBlock(3, [256, 256]),
        identityBlock(3, [256, 256]),
        identityBlock(3, [256, 256]),
        identityBlock(3, [256, 256]),
        identityBlock(3, [256, 256]),
        convBlock(3, [512, 512, 2048]),
        identityBlock(3, [512, 512]),
        identityBlock(3, [512, 512]),
        AvgPool((7, 7)),
        Flatten,
        Dense(num_classes),
        LogSoftmax)
def ResNet(num_classes):
    return stax.serial(
        GeneralConv(("HWCN", "OIHW", "NHWC"), 64, (7, 7), (2, 2), "SAME"),
        BatchNorm(), Relu,
        MaxPool((3, 3), strides=(2, 2)),
        ConvBlock(3, [4, 4, 4], strides=(1, 1)),
        IdentityBlock(3, [4, 4]),
        AvgPool((3, 3)),
        Flatten,
        Dense(num_classes),
        LogSoftmax,
    )
def ResNet(block, expansion, layers, normalization_method=None,
           width_per_group=64, actfn=stax.Relu):
    norm_layer = Identity
    if normalization_method == "group_norm":
        norm_layer = GroupNorm(32)
    elif normalization_method == "batch_norm":
        norm_layer = BatchNorm()
    base_width = width_per_group

    def _make_layer(block, planes, blocks, stride=1):
        downsample = None
        if stride != 1:
            downsample = stax.serial(
                Conv(planes * expansion, (1, 1), strides=(stride, stride),
                     bias=False),
                norm_layer,
            )
        block_layers = [block(planes, stride, downsample, base_width, norm_layer)]
        for _ in range(1, blocks):
            block_layers.append(
                block(planes, base_width=base_width, norm_layer=norm_layer,
                      actfn=actfn))
        return stax.serial(*block_layers)

    return [
        Conv(64, (3, 3), strides=(1, 1), padding="SAME", bias=False),
        norm_layer,
        actfn,
        # MaxPool((3, 3), strides=(2, 2), padding="SAME"),
        _make_layer(block, 64, layers[0]),
        _make_layer(block, 128, layers[1], stride=2),
        _make_layer(block, 256, layers[2], stride=2),
        _make_layer(block, 512, layers[3], stride=2),
        AvgPool((4, 4)),
        Flatten,
    ]
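# Hedged wiring sketch: this constructor returns a plain list of stax layers
# rather than an (init_fun, apply_fun) pair, so a caller presumably splices it
# into stax.serial with a classifier head. `Bottleneck` and `num_classes` are
# stand-ins for a block constructor and head size from the same codebase.
net_init, net_apply = stax.serial(
    *ResNet(Bottleneck, expansion=4, layers=[3, 4, 6, 3],
            normalization_method="batch_norm"),
    Dense(num_classes),
    LogSoftmax,
)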
def convActivationNormLayer(n_filter, filter_shape, strides, padding,
                            bias_coef, activation, norm):
    activation = select_activation(activation)
    layer = stax.serial(
        Conv(out_chan=n_filter, filter_shape=filter_shape, strides=strides,
             padding=padding, b_gain=bias_coef),
        activation)
    if norm == 'batch_norm':
        layer = stax.serial(
            layer,
            BatchNorm(axis=(0, 1, 2))  # normalize over N, H, W
        )
    elif norm is None or norm == 'None':
        pass
    else:
        raise ValueError(f"unknown norm: {norm!r}")
    return layer
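# Hedged composition sketch for the two factory helpers above, assuming the
# codebase's custom Conv/Dense (with b_gain) and its select_activation helper;
# all hyperparameters and shapes are illustrative.
net = stax.serial(
    convActivationNormLayer(32, (3, 3), (1, 1), 'SAME', 0.1, 'relu', 'batch_norm'),
    Flatten,
    denseActivationNormLayer(128, 0.1, 'relu', 'batch_norm'),
)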
def ResNet50(num_classes, batchnorm=True, parameterization='standard',
             nonlinearity='relu'):
    # Define layer constructors
    if parameterization == 'standard':
        def MyGeneralConv(*args, **kwargs):
            return GeneralConv(*args, **kwargs)

        def MyDense(*args, **kwargs):
            return Dense(*args, **kwargs)
    elif parameterization == 'ntk':
        # neural_tangents layers return (init_fn, apply_fn, kernel_fn);
        # keep only init/apply so they compose with jax stax combinators
        def MyGeneralConv(*args, **kwargs):
            return stax._GeneralConv(*args, **kwargs)[:2]

        def MyDense(*args, **kwargs):
            return stax.Dense(*args, **kwargs)[:2]

    # Define nonlinearity
    if nonlinearity == 'relu':
        nonlin = Relu
    elif nonlinearity == 'swish':
        nonlin = Swish
    elif nonlinearity == 'swishten':
        nonlin = Swishten
    elif nonlinearity == 'softplus':
        nonlin = Softplus
    else:
        raise ValueError(f"unknown nonlinearity: {nonlinearity!r}")

    # shared keyword arguments for every residual block
    block_kwargs = dict(batchnorm=batchnorm, parameterization=parameterization,
                        nonlin=nonlin)
    return jax_stax.serial(
        MyGeneralConv(('NHWC', 'HWIO', 'NHWC'), 64, (7, 7), strides=(2, 2),
                      padding='SAME'),
        BatchNorm() if batchnorm else Identity,
        nonlin,
        MaxPool((3, 3), strides=(2, 2)),
        ConvBlock(3, [64, 64, 256], strides=(1, 1), **block_kwargs),
        IdentityBlock(3, [64, 64], **block_kwargs),
        IdentityBlock(3, [64, 64], **block_kwargs),
        ConvBlock(3, [128, 128, 512], **block_kwargs),
        IdentityBlock(3, [128, 128], **block_kwargs),
        IdentityBlock(3, [128, 128], **block_kwargs),
        IdentityBlock(3, [128, 128], **block_kwargs),
        ConvBlock(3, [256, 256, 1024], **block_kwargs),
        IdentityBlock(3, [256, 256], **block_kwargs),
        IdentityBlock(3, [256, 256], **block_kwargs),
        IdentityBlock(3, [256, 256], **block_kwargs),
        IdentityBlock(3, [256, 256], **block_kwargs),
        IdentityBlock(3, [256, 256], **block_kwargs),
        ConvBlock(3, [512, 512, 2048], **block_kwargs),
        IdentityBlock(3, [512, 512], **block_kwargs),
        IdentityBlock(3, [512, 512], **block_kwargs),
        stax.GlobalAvgPool()[:-1],
        MyDense(num_classes))
def GCNLayer(out_dim,
             activation=relu,
             bias=True,
             normalize=True,
             batch_norm=False,
             dropout=0.0,
             W_init=he_normal(),
             b_init=normal()):
    r"""Single GCN layer from `Semi-Supervised Classification with Graph
    Convolutional Networks <https://arxiv.org/abs/1609.02907>`

    Parameters
    ----------
    out_dim : int
        Number of output node features.
    activation : Function
        Activation function, default to be relu function.
    bias : bool
        Whether to add bias after affine transformation, default to be True.
    normalize : bool
        Whether to normalize the adjacency matrix or not, default to be True.
    batch_norm : bool
        Whether to use BatchNormalization or not, default to be False.
    dropout : float
        The probability for dropout, default to 0.0.
    W_init : initialize function for weight
        Default to be He normal distribution.
    b_init : initialize function for bias
        Default to be normal distribution.

    Returns
    -------
    init_fun : Function
        Initializes the parameters of the layer.
    apply_fun : Function
        Defines the forward computation function.
    """
    _, drop_fun = Dropout(dropout)
    batch_norm_init, batch_norm_fun = BatchNorm()

    def init_fun(rng, input_shape):
        """Initialize parameters.

        Parameters
        ----------
        rng : PRNGKey
            rng is a value for generating random values.
        input_shape : (batch_size, N, M1)
            The shape of input (input node features).
            N is the total number of nodes in the batch of graphs.
            M1 is the input node feature size.

        Returns
        -------
        output_shape : (batch_size, N, M2)
            The shape of output (new node features).
            M2 is the new node feature size and equal to out_dim.
        params : Tuple (W, b, batch_norm_param)
            W is a weight and b is a bias.
            W : ndarray of shape (N, M2) or None
            b : ndarray of shape (M2,)
            batch_norm_param : Tuple (beta, gamma) or None
        """
        output_shape = input_shape[:-1] + (out_dim, )
        k1, k2, k3 = random.split(rng, 3)
        W = W_init(k1, (input_shape[-1], out_dim))
        b = b_init(k2, (out_dim, )) if bias else None
        batch_norm_param = None
        if batch_norm:
            output_shape, batch_norm_param = batch_norm_init(k3, output_shape)
        return output_shape, (W, b, batch_norm_param)

    def apply_fun(params, node_feats, adj, rng, is_train):
        """Update node representations.

        Parameters
        ----------
        node_feats : ndarray of shape (batch_size, N, M1)
            Batched input node features.
            N is the total number of nodes in the batch of graphs.
            M1 is the input node feature size.
        adj : ndarray of shape (batch_size, N, N)
            Batched adjacency matrix.
        rng : PRNGKey
            rng is a value for generating random values.
        is_train : bool
            Whether the model is training or not.

        Returns
        -------
        new_node_feats : ndarray of shape (batch_size, N, M2)
            Batched new node features.
            M2 is the new node feature size and equal to out_dim.
        """
        W, b, batch_norm_param = params
        if normalize:
            # A' = A + I, where I is the identity matrix
            # D': diagonal node degree matrix of A'
            # H' = D'^(-1/2) × A' × D'^(-1/2) × H × W
            def node_update_func(node_feats, adj):
                adj = adj + jnp.eye(len(adj))
                deg = jnp.sum(adj, axis=1)
                deg_mat = jnp.diag(jnp.where(deg > 0, deg**(-0.5), 0))
                normalized_adj = jnp.dot(deg_mat, jnp.dot(adj, deg_mat))
                return jnp.dot(normalized_adj, jnp.dot(node_feats, W))
        else:
            # H' = A × H × W
            def node_update_func(node_feats, adj):
                return jnp.dot(adj, jnp.dot(node_feats, W))

        # batched operation for updating node features
        new_node_feats = vmap(node_update_func)(node_feats, adj)
        if bias:
            new_node_feats += b
        new_node_feats = activation(new_node_feats)
        if dropout != 0.0:
            rng, key = random.split(rng)
            new_node_feats = drop_fun(None, new_node_feats, is_train, rng=key)
        if batch_norm:
            new_node_feats = batch_norm_fun(batch_norm_param, new_node_feats)
        return new_node_feats

    return init_fun, apply_fun
def initialize_parametric_nonlinearity(self, init_to='exponential',
                                       method=None, params_dict=None):
    if method is None:
        # if no method is specified, use defaults.
        # this piece of code is quite redundant; needs refactoring.
        if hasattr(self, 'nonlinearity'):
            method = self.nonlinearity
        else:
            method = self.filter_nonlinearity
    else:
        # overwrite the default nonlinearity
        if hasattr(self, 'nonlinearity'):
            self.nonlinearity = method
        else:
            self.filter_nonlinearity = method
            self.output_nonlinearity = method

    # prepare data
    if params_dict is None:
        params_dict = {}
    xrange = params_dict.get('xrange', 5)
    nx = params_dict.get('nx', 1000)
    x0 = np.linspace(-xrange, xrange, nx)

    if init_to == 'exponential':
        y0 = np.exp(x0)
    elif init_to == 'softplus':
        y0 = softplus(x0)
    elif init_to == 'relu':
        y0 = relu(x0)
    elif init_to == 'nonparametric':
        y0 = self.fnl_nonparametric(x0)
    elif init_to == 'gaussian':
        import scipy.signal
        y0 = scipy.signal.gaussian(nx, nx / 10)

    # fit nonlinearity
    if method == 'spline':
        smooth = params_dict.get('smooth', 'cr')
        df = params_dict.get('df', 7)
        if smooth == 'cr':
            X = cr(x0, df)
        elif smooth == 'cc':
            X = cc(x0, df)
        elif smooth == 'bs':
            deg = params_dict.get('degree', 3)
            X = bs(x0, df, deg)

        # least-squares fit of the spline basis to the target curve
        opt_params = np.linalg.pinv(X.T @ X) @ X.T @ y0
        self.nl_basis = X

        def _nl(opt_params, x_new):
            return np.maximum(interp1d(x0, X @ opt_params)(x_new), 0)

    elif method == 'nn':
        def loss(params, data):
            x = data['x']
            y = data['y']
            yhat = _predict(params, x)
            return np.mean((y - yhat)**2)

        @jit
        def step(i, opt_state, data):
            p = get_params(opt_state)
            g = grad(loss)(p, data)
            return opt_update(i, g, opt_state)

        random_seed = params_dict.get('random_seed', 2046)
        key = random.PRNGKey(random_seed)

        step_size = params_dict.get('step_size', 0.01)
        layer_sizes = params_dict.get('layer_sizes', [10, 10, 1])
        layers = []
        for layer_size in layer_sizes:
            layers.append(Dense(layer_size))
            layers.append(BatchNorm(axis=(0, 1)))
            layers.append(Relu)
        layers.pop(-1)  # drop the trailing Relu

        init_random_params, _predict = stax.serial(*layers)

        num_subunits = params_dict.get('num_subunits', 1)
        _, init_params = init_random_params(key, (-1, num_subunits))

        opt_init, opt_update, get_params = optimizers.adam(step_size)
        opt_state = opt_init(init_params)

        num_iters = params_dict.get('num_iters', 1000)
        if num_subunits == 1:
            data = {'x': x0.reshape(-1, 1), 'y': y0.reshape(-1, 1)}
        else:
            data = {'x': np.vstack([x0 for i in range(num_subunits)]).T,
                    'y': y0.reshape(-1, 1)}

        for i in range(num_iters):
            opt_state = step(i, opt_state, data)
        opt_params = get_params(opt_state)

        def _nl(opt_params, x_new):
            if len(x_new.shape) == 1:
                x_new = x_new.reshape(-1, 1)
            return np.maximum(_predict(opt_params, x_new), 0)

    self.nl_xrange = x0
    self.nl_params = opt_params
    self.fnl_fitted = _nl
# inputs = next(batches)
# train_images, labels, _, _ = mnist(permute_train=True, resize=True)
# del _
# inputs = train_images[:data_size]
# # del train_images
# u, s, v_t = onp.linalg.svd(inputs, full_matrices=False)
# I = np.eye(v_t.shape[-1])
# I_add = npr.normal(0.0, 0.002, size=I.shape)
# noisy_I = I + I_add

init_fun, conv_net = stax.serial(
    Conv(32, (5, 5), (2, 2), padding="SAME"),
    BatchNorm(), Relu,
    Conv(10, (3, 3), (2, 2), padding="SAME"),
    Relu,
    Flatten,
    Dense(num_classes),
    LogSoftmax,
)
_, key = random.split(random.PRNGKey(0))


class DataTopologyAE(AbstractProblem):
    def __init__(self):
        self.HPARAMS_PATH = "hparams.json"

    @staticmethod