def __call__(self, x):
    # Floatify the image.
    x = x.astype(jnp.float32) / 255.0
    # Apply CNN.
    w_init = DeltaOrthogonal(scale=np.sqrt(2 / (1 + self.negative_slope ** 2)))
    x = hk.Conv2D(self.num_filters, kernel_shape=4, stride=2, padding="VALID", w_init=w_init)(x)
    x = nn.leaky_relu(x, self.negative_slope)
    for _ in range(self.num_layers - 1):
        x = hk.Conv2D(self.num_filters, kernel_shape=3, stride=1, padding="VALID", w_init=w_init)(x)
        x = nn.leaky_relu(x, self.negative_slope)
    # Flatten the feature map.
    return hk.Flatten()(x)
def __call__(self, x):
    # Apply linear layer.
    w_init = hk.initializers.Orthogonal(scale=np.sqrt(2 / (1 + self.negative_slope ** 2)))
    x = hk.Linear(self.last_conv_dim, w_init=w_init)(x)
    x = nn.leaky_relu(x, self.negative_slope)
    x = x.reshape(-1, self.map_size, self.map_size, self.num_filters)
    # Apply Transposed CNN.
    w_init = DeltaOrthogonal(scale=np.sqrt(2 / (1 + self.negative_slope ** 2)))
    for _ in range(self.num_layers - 1):
        x = hk.Conv2DTranspose(self.num_filters, kernel_shape=3, stride=1, padding="VALID", w_init=w_init)(x)
        x = nn.leaky_relu(x, self.negative_slope)
    # Apply output layer.
    w_init = DeltaOrthogonal(scale=1.0)
    x = hk.Conv2DTranspose(self.state_space.shape[2], kernel_shape=4, stride=2, padding="VALID", w_init=w_init)(x)
    return x
def __call__(self, x):
    B, S, H, W, C = x.shape
    # Floatify the image.
    x = x.astype(jnp.float32) / 255.0
    # Reshape.
    x = x.reshape([B * S, H, W, C])
    # Apply CNN.
    w_init = DeltaOrthogonal(scale=1.0)
    depth = [32, 64, 128, 256, self.output_dim]
    kernel = [5, 3, 3, 3, 4]
    stride = [2, 2, 2, 2, 1]
    padding = ["SAME", "SAME", "SAME", "SAME", "VALID"]
    for i in range(5):
        x = hk.Conv2D(
            depth[i],
            kernel_shape=kernel[i],
            stride=stride[i],
            padding=padding[i],
            w_init=w_init,
        )(x)
        x = nn.leaky_relu(x, self.negative_slope)
    return x.reshape([B, S, -1])
def __call__(self, x):
    B, S, latent_dim = x.shape
    # Reshape.
    x = x.reshape([B * S, 1, 1, latent_dim])
    # Apply CNN.
    w_init = DeltaOrthogonal(scale=1.0)
    depth = [256, 128, 64, 32, self.state_space.shape[2]]
    kernel = [4, 3, 3, 3, 5]
    stride = [1, 2, 2, 2, 2]
    padding = ["VALID", "SAME", "SAME", "SAME", "SAME"]
    for i in range(4):
        x = hk.Conv2DTranspose(
            depth[i],
            kernel_shape=kernel[i],
            stride=stride[i],
            padding=padding[i],
            w_init=w_init,
        )(x)
        x = nn.leaky_relu(x, self.negative_slope)
    # Apply output layer (no activation).
    x = hk.Conv2DTranspose(
        depth[-1],
        kernel_shape=kernel[-1],
        stride=stride[-1],
        padding=padding[-1],
        w_init=w_init,
    )(x)
    _, H, W, C = x.shape
    x = x.reshape([B, S, H, W, C])
    # Return the mean image and a fixed standard deviation.
    return x, jax.lax.stop_gradient(jnp.ones_like(x) * self.std)
def apply_fun(params, x, adj, rng, activation=nn.elu, is_training=False, **kwargs):
    # Graph attention layer: linear transform, pairwise attention logits,
    # masked softmax over neighbours, then weighted aggregation.
    W, a1, a2 = params
    k1, k2, k3 = random.split(rng, 3)
    x = drop_fun(None, x, is_training=is_training, rng=k1)
    x = np.dot(x, W)
    # Additive attention: logits[i, j] = a1.h_i + a2.h_j via broadcasting.
    f_1 = np.dot(x, a1)
    f_2 = np.dot(x, a2)
    logits = f_1 + f_2.T
    # Mask out non-edges with a large negative value before the softmax.
    coefs = nn.softmax(
        nn.leaky_relu(logits, negative_slope=0.2) + np.where(adj, 0., -1e9))
    coefs = drop_fun(None, coefs, is_training=is_training, rng=k2)
    x = drop_fun(None, x, is_training=is_training, rng=k3)
    ret = np.matmul(coefs, x)
    return activation(ret)
def invertible_mlp_fwd(params, x, slope=0.1):
    """Forward pass through the invertible MLP used as the mixing function.

    Args:
        params (list): list where each element is a list of layer weight and
            bias [W, b]. len(params) is the number of layers.
        x (vector): input data, here the independent components at a specific time.
        slope (float): slope for the leaky ReLU activation.

    Returns:
        Output of the MLP, here the observed data of mixed independent components.
    """
    z = x
    for W, b in params[:-1]:
        z = jnp.matmul(z, W) + b
        z = jnn.leaky_relu(z, slope)
    # Final layer is affine only (no activation).
    final_W, final_b = params[-1]
    z = jnp.dot(z, final_W) + final_b
    return z
def invertible_mlp_inverse(params, x, lrelu_slope=0.1):
    """Inverse of the invertible MLP defined above.

    Args:
        params (list): list where each element is a list of layer weight and
            bias [W, b]. len(params) is the number of layers.
        x (vector): output of the forward MLP, here the observed data.
        lrelu_slope (float): slope for the leaky ReLU activation.

    Returns:
        Inputs to the MLP, here the independent components.
    """
    z = x
    params_rev = params[::-1]
    # Undo the final affine layer (it has no activation in the forward pass).
    final_W, final_b = params_rev[0]
    z = z - final_b
    z = jnp.dot(z, jnp.linalg.inv(final_W))
    # Undo the remaining layers: invert the leaky ReLU (slope 1/s), then the affine map.
    for W, b in params_rev[1:]:
        z = jnn.leaky_relu(z, 1. / lrelu_slope)
        z = z - b
        z = jnp.dot(z, jnp.linalg.inv(W))
    return z
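# A minimal round-trip check (a sketch, not part of the original code): it assumes
# jax, jax.numpy as jnp, and jax.nn as jnn are available, and builds square,
# well-conditioned weights so the inverse exists. It verifies that
# invertible_mlp_inverse recovers the input of invertible_mlp_fwd.
def _check_invertible_mlp_roundtrip():
    import jax
    import jax.numpy as jnp

    dim, n_layers = 4, 3
    params = []
    for k in jax.random.split(jax.random.PRNGKey(0), n_layers):
        # Gaussian matrix shifted by 2*I: almost surely invertible and well conditioned.
        W = jax.random.normal(k, (dim, dim)) + 2.0 * jnp.eye(dim)
        b = jnp.zeros(dim)
        params.append([W, b])
    x = jax.random.normal(jax.random.PRNGKey(1), (dim,))
    y = invertible_mlp_fwd(params, x, slope=0.1)
    x_rec = invertible_mlp_inverse(params, y, lrelu_slope=0.1)
    assert jnp.max(jnp.abs(x - x_rec)) < 1e-4  # equal up to float32 round-off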
def tailored_lrelu(negative_slope, x):
    # Leaky ReLU scaled by the gain sqrt(2 / (1 + slope^2)) (as in He-style
    # initialization), so the activation roughly preserves the signal's
    # mean-square magnitude for zero-mean Gaussian inputs.
    return math.sqrt(2.0 / (1 + negative_slope**2)) * nn.leaky_relu(
        x, negative_slope=negative_slope)
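# A quick numerical sanity check (a sketch, assuming jax is available): for
# standard-normal inputs, E[leaky_relu(x)^2] = (1 + slope^2) / 2, so the
# sqrt(2 / (1 + slope^2)) gain in tailored_lrelu brings the second moment back to ~1.
def _check_tailored_lrelu_gain():
    import jax
    import jax.numpy as jnp

    x = jax.random.normal(jax.random.PRNGKey(0), (100_000,))
    y = tailored_lrelu(0.2, x)
    assert jnp.abs(jnp.mean(y ** 2) - 1.0) < 0.05  # mean square close to 1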
def onnx_leaky_relu(x, alpha=0.01):
    # ONNX LeakyRelu op: `alpha` is the slope used for negative inputs.
    return leaky_relu(x, alpha)
def activate_layer_static(params, inputs):
    # Affine transform (bias in column 0, weights in the remaining columns)
    # followed by a leaky ReLU with the default negative slope.
    s = jnp.dot(params[:, 1:], inputs) + params[:, 0]
    return nn.leaky_relu(s)