def __call__(self, z, train: bool = True):
    # Common arguments
    conv_kwargs = {
        'kernel_size': (4, 4),
        'strides': (2, 2),
        'padding': 'SAME',
        'use_bias': False,
        'kernel_init': he_normal()
    }
    norm_kwargs = {
        'use_running_average': not train,
        'momentum': 0.99,
        'epsilon': 0.001,
        'use_scale': True,
        'use_bias': True
    }

    z = jnp.reshape(z, (1, 1, self.zdim))

    # Layer 1
    z = nn.ConvTranspose(features=512, kernel_size=(4, 4), strides=(1, 1),
                         padding='VALID', use_bias=False,
                         kernel_init=he_normal())(z)
    z = nn.BatchNorm(**norm_kwargs)(z)
    z = nn.leaky_relu(z, 0.2)

    # Layer 2
    z = nn.ConvTranspose(features=256, **conv_kwargs)(z)
    z = nn.BatchNorm(**norm_kwargs)(z)
    z = nn.leaky_relu(z, 0.2)

    # Layer 3
    z = nn.ConvTranspose(features=128, **conv_kwargs)(z)
    z = nn.BatchNorm(**norm_kwargs)(z)
    z = nn.leaky_relu(z, 0.2)

    # Layer 4
    z = nn.ConvTranspose(features=64, **conv_kwargs)(z)
    z = nn.BatchNorm(**norm_kwargs)(z)
    z = nn.leaky_relu(z, 0.2)

    # Layer 5
    z = nn.ConvTranspose(features=1, kernel_size=(4, 4), strides=(2, 2),
                         padding='SAME', use_bias=False,
                         kernel_init=nn.initializers.xavier_normal())(z)
    # x = nn.sigmoid(z)
    x = nn.softplus(z)

    return jnp.rot90(jnp.squeeze(x), k=2)  # Rotate to match TF output
def __call__(self, x, train: bool = True):
    # Common arguments
    kwargs = {
        'kernel_size': (4, 4),
        'strides': (2, 2),
        'padding': 'SAME',
        'use_bias': False,
        'kernel_init': he_normal()
    }

    # x = np.reshape(x, (64, 64, 1))
    x = x[..., None]

    # Layer 1
    x = nn.Conv(features=64, **kwargs)(x)
    x = nn.leaky_relu(x, 0.2)

    # Layer 2
    x = nn.Conv(features=128, **kwargs)(x)
    x = nn.BatchNorm(use_running_average=not train)(x)
    x = nn.leaky_relu(x, 0.2)

    # Layer 3
    x = nn.Conv(features=256, **kwargs)(x)
    x = nn.BatchNorm(use_running_average=not train)(x)
    x = nn.leaky_relu(x, 0.2)

    # Layer 4
    x = nn.Conv(features=512, **kwargs)(x)
    x = nn.BatchNorm(use_running_average=not train)(x)
    x = nn.leaky_relu(x, 0.2)

    # Layer 5
    x = nn.Conv(features=4096, kernel_size=(4, 4), strides=(1, 1),
                padding='VALID', use_bias=False, kernel_init=he_normal())(x)
    x = nn.leaky_relu(x, 0.2)

    # Flatten
    x = x.flatten()

    # Predict latent variables
    z_mean = nn.Dense(features=self.zdim)(x)
    z_logvar = nn.Dense(features=self.zdim)(x)

    return z_mean, z_logvar
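# Usage note: both __call__ methods above use nn.BatchNorm, so the enclosing
# Flax modules must be driven with a mutable 'batch_stats' collection.
# Minimal sketch; `Encoder` is a hypothetical flax.linen.Module wrapping the
# encoder __call__ above, and zdim=64 is an assumed latent size.
import jax
import jax.numpy as jnp

enc = Encoder(zdim=64)
x = jnp.zeros((64, 64))  # single 64x64 image; __call__ adds the channel axis

# init returns both the 'params' and BatchNorm's 'batch_stats' collections
variables = enc.init(jax.random.PRNGKey(0), x, train=False)

# During training BatchNorm updates its running statistics, so the
# 'batch_stats' collection has to be marked mutable.
(z_mean, z_logvar), new_state = enc.apply(
    variables, x, train=True, mutable=['batch_stats'])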
def ConcatSquashLinear(out_dim, W_init=he_normal(), b_init=normal()):
    """
    y = Sigmoid(at + c)(Wx + b) + dt.
    Note: he_normal only takes multi dim.
    """
    def init_fun(rng, input_shape):
        output_shape = input_shape[:-1] + (out_dim, )
        k1, k2, k3, k4, k5 = random.split(rng, 5)
        W, b = W_init(k1, (input_shape[-1], out_dim)), b_init(k2, (out_dim, ))
        w_t, w_tb = b_init(k3, (out_dim, )), b_init(k4, (out_dim, ))
        b_t = b_init(k5, (out_dim, ))
        return output_shape, (W, b, w_t, w_tb, b_t)

    def apply_fun(params, inputs, **kwargs):
        x, t = inputs
        W, b, w_t, w_tb, b_t = params
        # (W.x + b)
        out = np.dot(x, W) + b
        # * sigmoid(a.t + c)
        out *= jax.nn.sigmoid(w_t * t + w_tb)
        # + d.t
        out += b_t * t
        return (out, t)

    return init_fun, apply_fun
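# Quick sanity-check of the gating formula in the docstring above.
# A hedged sketch: shapes and key values are arbitrary, and it assumes
# ConcatSquashLinear (with its `import jax.numpy as np`) is in scope.
from jax import random
import jax.numpy as np

init_fun, apply_fun = ConcatSquashLinear(out_dim=32)
out_shape, params = init_fun(random.PRNGKey(0), input_shape=(8, 16))

x = np.ones((8, 16))      # batch of 8 inputs with 16 features each
t = np.array(0.5)         # scalar "time" conditioning input
y, t_out = apply_fun(params, (x, t))
print(out_shape, y.shape)  # (8, 32) for both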
def IgnoreConv2D(out_dim,
                 W_init=he_normal(),
                 b_init=normal(),
                 kernel=3,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 bias=True,
                 transpose=False):
    assert dilation == 1 and groups == 1

    if not transpose:
        init_fun_wrapped, apply_fun_wrapped = stax.GeneralConv(
            dimension_numbers,
            out_chan=out_dim,
            filter_shape=(kernel, kernel),
            strides=(stride, stride),
            padding=padding)
    else:
        init_fun_wrapped, apply_fun_wrapped = stax.GeneralConvTranspose(
            dimension_numbers,
            out_chan=out_dim,
            filter_shape=(kernel, kernel),
            strides=(stride, stride),
            padding=padding)

    def apply_fun(params, inputs, **kwargs):
        x, t = inputs
        out = apply_fun_wrapped(params, x, **kwargs)
        return (out, t)

    # Return the wrapping apply_fun so the (x, t) tuple convention of the
    # other layers is preserved (t is passed through untouched).
    return init_fun_wrapped, apply_fun
def create_q_net(
        obs_dim, action_dim,
        rngkey=jax.random.PRNGKey(0)) -> TT.Tuple[RT.NNParams, RT.NNParamsFn]:
    q_init, q_fn = serial(
        Dense(64, he_normal(), zeros),
        Relu,
        Dense(64, he_normal(), zeros),
        Relu,
        Dense(action_dim, he_normal(), zeros),
    )
    output_shape, q_params = q_init(rngkey, (1, obs_dim + action_dim))

    @jit
    def q_fn2(q, S, A):
        return q_fn(q, jnp.hstack([S, A]))

    return q_params, q_fn2
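# Short usage sketch for create_q_net. The batch size and dimensions are
# illustrative; the RT/TT type aliases and stax imports are assumed to come
# from the surrounding module.
import jax.numpy as jnp

obs_dim, action_dim = 17, 6
q_params, q_fn = create_q_net(obs_dim, action_dim)

S = jnp.zeros((32, obs_dim))      # batch of observations
A = jnp.zeros((32, action_dim))   # batch of actions
q_values = q_fn(q_params, S, A)   # hstack([S, A]) feeds the (obs+act)-dim input
print(q_values.shape)             # (32, 6)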
def create_pi_net(
        obs_dim: int, action_dim: int,
        rngkey=jax.random.PRNGKey(0)) -> TT.Tuple[RT.NNParams, RT.NNParamsFn]:
    pi_init, pi_fn = serial(
        Dense(64, he_normal(), zeros),
        Relu,
        FanOut(2),
        parallel(
            serial(
                Dense(64, he_normal(), zeros),
                Relu,
                Dense(action_dim, he_normal(), zeros),
            ),
            serial(
                Dense(64, he_normal(), zeros),
                Relu,
                Dense(action_dim, he_normal(), zeros),
            ),
        ),
    )
    output_shape, pi_params = pi_init(rngkey, (1, obs_dim))
    pi_fn = jit(pi_fn)
    return pi_params, pi_fn
def IgnoreLinear(out_dim, W_init=he_normal(), b_init=normal()):
    """ y = Wx + b """
    def init_fun(rng, input_shape):
        output_shape = input_shape[:-1] + (out_dim, )
        k1, k2 = random.split(rng)
        W, b = W_init(k1, (input_shape[-1], out_dim)), b_init(k2, (out_dim, ))
        return output_shape, (W, b)

    def apply_fun(params, inputs, **kwargs):
        x, t = inputs
        W, b = params
        return (np.dot(x, W) + b, t)

    return init_fun, apply_fun
def ConcatSquashConv2D(out_dim,
                       W_init=he_normal(),
                       b_init=normal(),
                       kernel=3,
                       stride=1,
                       padding=0,
                       dilation=1,
                       groups=1,
                       bias=True,
                       transpose=False):
    assert dilation == 1 and groups == 1

    if not transpose:
        init_fun_wrapped, apply_fun_wrapped = stax.GeneralConv(
            dimension_numbers,
            out_chan=out_dim,
            filter_shape=(kernel, kernel),
            strides=(stride, stride),
            padding=padding)
    else:
        init_fun_wrapped, apply_fun_wrapped = stax.GeneralConvTranspose(
            dimension_numbers,
            out_chan=out_dim,
            filter_shape=(kernel, kernel),
            strides=(stride, stride),
            padding=padding)

    def init_fun(rng, input_shape):
        k1, k2, k3, k4 = random.split(rng, 4)
        output_shape_conv, params_conv = init_fun_wrapped(k1, input_shape)
        W_hyper_gate, b_hyper_gate = W_init(k2, (1, out_dim)), b_init(
            k3, (out_dim, ))
        W_hyper_bias = W_init(k4, (1, out_dim))
        return output_shape_conv, (params_conv, W_hyper_gate, b_hyper_gate,
                                   W_hyper_bias)

    def apply_fun(params, inputs, **kwargs):
        x, t = inputs
        params_conv, W_hyper_gate, b_hyper_gate, W_hyper_bias = params
        conv_out = apply_fun_wrapped(params_conv, x, **kwargs)
        # use reshape rather than the torch-style .view
        gate_out = jax.nn.sigmoid(
            np.dot(t.reshape(1, 1), W_hyper_gate) + b_hyper_gate).reshape(
                1, 1, 1, -1)
        bias_out = np.dot(t.reshape(1, 1), W_hyper_bias).reshape(1, 1, 1, -1)
        out = conv_out * gate_out + bias_out
        return (out, t)

    return init_fun, apply_fun
def ConcatCoordConv2D(out_dim,
                      W_init=he_normal(),
                      b_init=normal(),
                      kernel=3,
                      stride=1,
                      padding=0,
                      dilation=1,
                      groups=1,
                      bias=True,
                      transpose=False):
    assert dilation == 1 and groups == 1

    if not transpose:
        init_fun_wrapped, apply_fun_wrapped = stax.GeneralConv(
            dimension_numbers,
            out_chan=out_dim,
            filter_shape=(kernel, kernel),
            strides=(stride, stride),
            padding=padding)
    else:
        init_fun_wrapped, apply_fun_wrapped = stax.GeneralConvTranspose(
            dimension_numbers,
            out_chan=out_dim,
            filter_shape=(kernel, kernel),
            strides=(stride, stride),
            padding=padding)

    def init_fun(rng, input_shape):
        concat_input_shape = list(input_shape)
        # add time and coord channels; channel axis moved from 1 (torch) to -1
        concat_input_shape[-1] += 3
        concat_input_shape = tuple(concat_input_shape)
        return init_fun_wrapped(rng, concat_input_shape)

    def apply_fun(params, inputs, **kwargs):
        x, t = inputs
        b, h, w, c = x.shape
        # broadcast_to replaces the torch-style .view(...).expand(...)
        hh = np.broadcast_to(np.arange(h).reshape(1, h, 1, 1), (b, h, w, 1))
        ww = np.broadcast_to(np.arange(w).reshape(1, 1, w, 1), (b, h, w, 1))
        tt = np.broadcast_to(np.reshape(t, (1, 1, 1, 1)), (b, h, w, 1))
        x_aug = np.concatenate([x, hh, ww, tt], axis=-1)
        out = apply_fun_wrapped(params, x_aug, **kwargs)
        return (out, t)

    return init_fun, apply_fun
def Dense(out_dim,
          W_init=he_normal(),
          b_init=normal(),
          rho_init=partial(const, c=-5)):
    """Layer constructor function for a dense (fully-connected) Bayesian
    linear layer."""
    def init_fun(rng, input_shape):
        output_shape = input_shape[:-1] + (out_dim, )
        k1, k2, k3, k4 = random.split(rng, 4)
        W_mu, b_mu = W_init(k1, (input_shape[-1], out_dim)), b_init(
            k2, (out_dim, ))
        W_rho, b_rho = rho_init((input_shape[-1], out_dim)), rho_init(
            (out_dim, ))
        return output_shape, (W_mu, b_mu, W_rho, b_rho)

    def apply_fun(params, inputs, rng, **kwargs):
        # print(inputs[0][0])
        inputs, kl = inputs
        # kl = 0
        subkeys = random.split(rng, 2)
        W_mu, b_mu, W_rho, b_rho = params

        W_eps = random.normal(subkeys[0], W_mu.shape)
        b_eps = random.normal(subkeys[1], b_mu.shape)

        # q dist
        W_std = np.exp(W_rho)
        b_std = np.exp(b_rho)
        W = W_eps * W_std + W_mu
        b = b_eps * b_std + b_mu

        # Bayes by Backprop training
        W_kl = normal_kldiv(W_mu, 0., W_rho, 0.)
        b_kl = normal_kldiv(b_mu, 0., b_rho, 0.)
        W_kl, b_kl = np.sum(W_kl), np.sum(b_kl)
        kl_loss = W_kl + b_kl
        kl_loss = kl_loss + np.array(kl)  # TODO: why do we get compatibility issues?

        # print(W.shape)
        return (np.dot(inputs, W) + b, kl_loss)

    return init_fun, apply_fun
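# Sketch of how the Bayesian Dense layer threads the KL term through its
# (inputs, kl) tuple. Hedged: shapes are illustrative and the module-level
# helpers `const` and `normal_kldiv` are assumed to be in scope.
from jax import random
import jax.numpy as np

init_fun, apply_fun = Dense(out_dim=10)
_, params = init_fun(random.PRNGKey(0), input_shape=(32, 784))

x = np.zeros((32, 784))
# The layer consumes (activations, running KL) and returns the same pair,
# with this layer's weight/bias KL added for the Bayes-by-Backprop ELBO.
out, kl = apply_fun(params, (x, 0.0), rng=random.PRNGKey(1))
print(out.shape, kl.shape)  # (32, 10) ()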
def ConcatConv2D_v2(out_dim,
                    W_init=he_normal(),
                    b_init=normal(),
                    kernel=3,
                    stride=1,
                    padding=0,
                    dilation=1,
                    groups=1,
                    bias=True,
                    transpose=False):
    assert dilation == 1 and groups == 1

    if not transpose:
        init_fun_wrapped, apply_fun_wrapped = stax.GeneralConv(
            dimension_numbers,
            out_chan=out_dim,
            filter_shape=(kernel, kernel),
            strides=(stride, stride),
            padding=padding)
    else:
        init_fun_wrapped, apply_fun_wrapped = stax.GeneralConvTranspose(
            dimension_numbers,
            out_chan=out_dim,
            filter_shape=(kernel, kernel),
            strides=(stride, stride),
            padding=padding)

    def init_fun(rng, input_shape):
        k1, k2 = random.split(rng)
        output_shape_conv, params_conv = init_fun_wrapped(k1, input_shape)
        W_hyper_bias = W_init(k2, (1, out_dim))
        return output_shape_conv, (params_conv, W_hyper_bias)

    def apply_fun(params, inputs, **kwargs):
        x, t = inputs
        params_conv, W_hyper_bias = params
        # reshape replaces the torch-style .view;
        # if NCHW instead of NHWC, reshape the bias to (1, -1, 1, 1)
        out = apply_fun_wrapped(params_conv, x, **kwargs) + np.dot(
            t.reshape(1, 1), W_hyper_bias).reshape(1, 1, 1, -1)
        return (out, t)

    return init_fun, apply_fun
def BlendConv2D(out_dim,
                W_init=he_normal(),
                b_init=normal(),
                kernel=3,
                stride=1,
                padding=0,
                dilation=1,
                groups=1,
                bias=True,
                transpose=False):
    assert dilation == 1 and groups == 1

    if not transpose:
        init_fun_wrapped, apply_fun_wrapped = stax.GeneralConv(
            dimension_numbers,
            out_chan=out_dim,
            filter_shape=(kernel, kernel),
            strides=(stride, stride),
            padding=padding)
    else:
        init_fun_wrapped, apply_fun_wrapped = stax.GeneralConvTranspose(
            dimension_numbers,
            out_chan=out_dim,
            filter_shape=(kernel, kernel),
            strides=(stride, stride),
            padding=padding)

    def init_fun(rng, input_shape):
        k1, k2 = random.split(rng)
        output_shape, params_f = init_fun_wrapped(k1, input_shape)
        _, params_g = init_fun_wrapped(k2, input_shape)
        return output_shape, (params_f, params_g)

    def apply_fun(params, inputs, **kwargs):
        x, t = inputs
        params_f, params_g = params
        f = apply_fun_wrapped(params_f, x)
        g = apply_fun_wrapped(params_g, x)
        out = f + (g - f) * t
        return (out, t)

    return init_fun, apply_fun
def ConcatConv2D(out_dim,
                 W_init=he_normal(),
                 b_init=normal(),
                 kernel=3,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 bias=True,
                 transpose=False):
    assert dilation == 1 and groups == 1

    if not transpose:
        init_fun_wrapped, apply_fun_wrapped = stax.GeneralConv(
            dimension_numbers,
            out_chan=out_dim,
            filter_shape=(kernel, kernel),
            strides=(stride, stride),
            padding=padding)
    else:
        init_fun_wrapped, apply_fun_wrapped = stax.GeneralConvTranspose(
            dimension_numbers,
            out_chan=out_dim,
            filter_shape=(kernel, kernel),
            strides=(stride, stride),
            padding=padding)

    def init_fun(rng, input_shape):
        # note, input shapes only take x
        concat_input_shape = list(input_shape)
        concat_input_shape[-1] += 1  # add time channel dim
        concat_input_shape = tuple(concat_input_shape)
        return init_fun_wrapped(rng, concat_input_shape)

    def apply_fun(params, inputs, **kwargs):
        x, t = inputs
        tt = np.ones_like(x[:, :, :, :1]) * t
        xtt = np.concatenate([x, tt], axis=-1)
        out = apply_fun_wrapped(params, xtt, **kwargs)
        return (out, t)

    return init_fun, apply_fun
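# Hedged usage sketch for ConcatConv2D (and, analogously, the other conv
# variants above). Assumes the module-level dimension_numbers constant is the
# usual ('NHWC', 'HWIO', 'NHWC'), and passes an explicit 'SAME' padding string
# since stax convolutions expect a string or (low, high) pairs rather than the
# torch-style integer default.
from jax import random
import jax.numpy as np

init_fun, apply_fun = ConcatConv2D(out_dim=32, kernel=3, stride=1,
                                   padding='SAME')
out_shape, params = init_fun(random.PRNGKey(0), (1, 28, 28, 4))

x = np.ones((1, 28, 28, 4))
t = np.array(0.1)                # scalar ODE time
out, t_out = apply_fun(params, (x, t))
print(out_shape, out.shape)      # (1, 28, 28, 32) for both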
def ConcatLinear(out_dim, W_init=he_normal(), b_init=normal()):
    """ y = Wx + b + at """
    def init_fun(rng, input_shape):
        output_shape = input_shape[:-1] + (out_dim, )
        k1, k2 = random.split(rng)
        W, b = W_init(k1, (input_shape[-1] + 1, out_dim)), b_init(
            k2, (out_dim, ))
        return output_shape, (W, b)

    def apply_fun(params, inputs, **kwargs):
        x, t = inputs
        W, b = params
        # concatenate t onto the inputs
        tt = t.reshape([-1] * (x.ndim - 1) + [1])  # single batch example
        # i.e. [:, :, ..., :, :1] column vector
        tt = np.tile(tt, x.shape[:-1] + (1, ))
        xtt = np.concatenate([x, tt], axis=-1)
        return (np.dot(xtt, W) + b, t)

    return init_fun, apply_fun
def __call__(self, key, shape, dtype=None):
    if dtype is None:
        dtype = "float32"
    initializer_fn = jax_initializers.he_normal()
    return initializer_fn(key, shape, dtype)
def GCNLayer(out_dim,
             activation=relu,
             bias=True,
             normalize=True,
             batch_norm=False,
             dropout=0.0,
             W_init=he_normal(),
             b_init=normal()):
    r"""Single GCN layer from `Semi-Supervised Classification with Graph
    Convolutional Networks <https://arxiv.org/abs/1609.02907>`

    Parameters
    ----------
    out_dim : int
        Number of output node features.
    activation : Function
        Activation function, default to be relu function.
    bias : bool
        Whether to add bias after affine transformation, default to be True.
    normalize : bool
        Whether to normalize the adjacency matrix or not, default to be True.
    batch_norm : bool
        Whether to use BatchNormalization or not, default to be False.
    dropout : float
        The probability for dropout, default to 0.0.
    W_init : initialize function for weight
        Default to be He normal distribution.
    b_init : initialize function for bias
        Default to be normal distribution.

    Returns
    -------
    init_fun : Function
        Initializes the parameters of the layer.
    apply_fun : Function
        Defines the forward computation function.
    """
    _, drop_fun = Dropout(dropout)
    batch_norm_init, batch_norm_fun = BatchNorm()

    def init_fun(rng, input_shape):
        """Initialize parameters.

        Parameters
        ----------
        rng : PRNGKey
            rng is a value for generating random values.
        input_shape : (batch_size, N, M1)
            The shape of input (input node features).
            N is the total number of nodes in the batch of graphs.
            M1 is the input node feature size.

        Returns
        -------
        output_shape : (batch_size, N, M2)
            The shape of output (new node features).
            M2 is the new node feature size and equal to out_dim.
        params : Tuple (W, b, batch_norm_param)
            W is a weight and b is a bias.
            W : ndarray of shape (M1, M2) or None
            b : ndarray of shape (M2,)
            batch_norm_param : Tuple (beta, gamma) or None
        """
        output_shape = input_shape[:-1] + (out_dim, )
        k1, k2, k3 = random.split(rng, 3)
        W = W_init(k1, (input_shape[-1], out_dim))
        b = b_init(k2, (out_dim, )) if bias else None
        batch_norm_param = None
        if batch_norm:
            output_shape, batch_norm_param = batch_norm_init(k3, output_shape)
        return output_shape, (W, b, batch_norm_param)

    def apply_fun(params, node_feats, adj, rng, is_train):
        """Update node representations.

        Parameters
        ----------
        node_feats : ndarray of shape (batch_size, N, M1)
            Batched input node features.
            N is the total number of nodes in the batch of graphs.
            M1 is the input node feature size.
        adj : ndarray of shape (batch_size, N, N)
            Batched adjacency matrix.
        rng : PRNGKey
            rng is a value for generating random values.
        is_train : bool
            Whether the model is training or not.

        Returns
        -------
        new_node_feats : ndarray of shape (batch_size, N, M2)
            Batched new node features.
            M2 is the new node feature size and equal to out_dim.
        """
        W, b, batch_norm_param = params

        if normalize:
            # A' = A + I, where I is the identity matrix
            # D': diagonal node degree matrix of A'
            # H' = D'^(-1/2) × A' × D'^(-1/2) × H × W
            def node_update_func(node_feats, adj):
                adj = adj + jnp.eye(len(adj))
                deg = jnp.sum(adj, axis=1)
                deg_mat = jnp.diag(jnp.where(deg > 0, deg**(-0.5), 0))
                normalized_adj = jnp.dot(deg_mat, jnp.dot(adj, deg_mat))
                return jnp.dot(normalized_adj, jnp.dot(node_feats, W))
        else:
            # H' = A × H × W
            def node_update_func(node_feats, adj):
                return jnp.dot(adj, jnp.dot(node_feats, W))

        # batched operation for updating node features
        new_node_feats = vmap(node_update_func)(node_feats, adj)

        if bias:
            new_node_feats += b
        new_node_feats = activation(new_node_feats)
        if dropout != 0.0:
            rng, key = random.split(rng)
            new_node_feats = drop_fun(None, new_node_feats, is_train, rng=key)
        if batch_norm:
            new_node_feats = batch_norm_fun(batch_norm_param, new_node_feats)
        return new_node_feats

    return init_fun, apply_fun