def compute_logits_single_example(hidden_states, instruction_pointer,
                                  exit_index, steps, node_embeddings,
                                  true_indexes, false_indexes):
  """single_example refers to selecting a single exit node hidden state."""
  # leaves(hidden_states).shape: num_nodes, hidden_size

  def step_(carry, _):
    hidden_states, instruction_pointer, index = carry
    hidden_states_new, instruction_pointer_new, to_tag = (
        step_single_example(hidden_states, instruction_pointer,
                            node_embeddings, true_indexes, false_indexes,
                            exit_index))
    # Only apply the update while index < steps; afterwards keep the old state.
    carry = jax.tree_map(
        lambda new, old, index=index: jnp.where(index < steps, new, old),
        (hidden_states_new, instruction_pointer_new, index + 1),
        (hidden_states, instruction_pointer, index + 1),
    )
    return carry, to_tag

  if config.model.ipagnn.checkpoint and not self.is_initializing():
    step_ = jax.checkpoint(step_)

  carry = (hidden_states, instruction_pointer, jnp.array([0]))
  (hidden_states, instruction_pointer, _), to_tag = lax.scan(
      step_, carry, None, length=max_steps)
  final_state = jax.tree_map(lambda hs: hs[exit_index], hidden_states)
  # leaves(final_state).shape: hidden_size
  final_state_concat = jnp.concatenate(jax.tree_leaves(final_state), axis=0)
  logits = output_dense(final_state_concat)
  to_tag.update({
      'instruction_pointer_final': instruction_pointer,
      'hidden_states_final': hidden_states,
  })
  return logits, to_tag
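# Hedged sketch (not part of the original model code): the key pattern above is a
# fixed-length `lax.scan` that runs for `max_steps` iterations but freezes each
# example's state once its own `steps` budget is used up, via
# `jnp.where(index < steps, new, old)`. A minimal standalone illustration with a
# made-up update rule:
import jax.numpy as jnp
from jax import lax


def run_frozen_scan(state, steps, max_steps):
  """Applies a toy "+1" update max_steps times, freezing state after `steps`."""

  def step_(carry, _):
    state, index = carry
    new_state = state + 1.0  # Stand-in for the real per-step update.
    state = jnp.where(index < steps, new_state, state)
    return (state, index + 1), None

  (state, _), _ = lax.scan(step_, (state, jnp.array(0)), None, length=max_steps)
  return state


# run_frozen_scan(jnp.zeros(3), steps=2, max_steps=5) -> [2., 2., 2.]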
def make_layer(input_shape):
    fx = build_fx(fx_block_type, input_shape, fx_dim, fx_actfn)

    # Create the unflatten_w function from a throwaway initialization.
    rng = jax.random.PRNGKey(0)  # temp; not used.
    x_shape, tmp_w = fx.init(rng, input_shape)
    assert input_shape == x_shape, (
        "fx needs to have the same input and output shapes "
        f"but got {input_shape} and {x_shape}")
    flat_w, unflatten_w = ravel_pytree(tmp_w)
    w_shape = flat_w.shape
    del tmp_w

    x_dim = int(jnp.prod(jnp.array(x_shape)))
    w_dim = int(jnp.prod(jnp.array(w_shape)))

    def f_aug(y, t, args):
        # Drift of the augmented state [x, w, kl].
        x = y[:x_dim].reshape(x_shape)
        flat_w = y[x_dim:x_dim + w_dim].reshape(w_shape)
        dx = (fx.apply(unflatten_w(flat_w), (x, t))[0]
              if xt else fx.apply(unflatten_w(flat_w), x))
        if w_drift:
            fw_params = args
            dw = (fw.apply(fw_params, (flat_w, t))[0]
                  if xt else fw.apply(fw_params, flat_w))
        else:
            dw = jnp.zeros(w_shape)

        # Hardcoded OU process prior: the prior drift on w is -w.
        u = (dw - (-flat_w)) / diff_coef if diff_coef != 0 else jnp.zeros(w_shape)
        dkl = u**2
        return jnp.concatenate(
            [dx.reshape(-1), dw.reshape(-1), dkl.reshape(-1)])

    def g_aug(y, t, args):
        # Diffusion of the augmented state; only the weights are stochastic.
        dx = jnp.zeros(x_shape)
        diff_w = jnp.ones(w_shape) * diff_coef

        if w_drift:
            fw_params = tree_util.tree_map(stop_gradient, args)
            drift_w = (fw.apply(fw_params, (flat_w, t))[0]
                       if xt else fw.apply(fw_params, flat_w))
        else:
            drift_w = jnp.zeros(w_shape)

        # Hardcoded OU process prior: the prior drift on w is -w.
        u = (drift_w - (-flat_w)) / diff_coef if diff_coef != 0 else jnp.zeros(w_shape)
        dkl = u if stl else jnp.zeros(w_shape)
        return jnp.concatenate(
            [dx.reshape(-1), diff_w.reshape(-1), dkl.reshape(-1)])

    def init_fun(rng, input_shape):
        output_shape, w0 = fx.init(rng, input_shape)
        flat_w0, unflatten_w = ravel_pytree(w0)
        if w_drift:
            output_shape, fw_params = fw.init(rng, flat_w0.shape)
            assert flat_w0.shape == output_shape, (
                "fw needs to have the same input and output shapes")
        else:
            fw_params = ()
        return input_shape, (flat_w0, fw_params)

    def apply_fun(params, inputs, rng, full_output=False, fixed_grid=True,
                  **kwargs):
        flat_w0, fw_params = params
        x = inputs
        y0 = jnp.concatenate([
            x.reshape(-1),
            flat_w0.reshape(-1),
            jnp.zeros(flat_w0.shape).reshape(-1)
        ])
        rep = w_dim if stl else 0  # STL
        if fixed_grid:
            ys = sdeint_ito_fixed_grid(f_aug, g_aug, y0, ts, rng, fw_params,
                                       method="euler_maruyama", rep=rep)
        else:
            print("using stochastic adjoint")
            ys = sdeint_ito(f_aug, g_aug, y0, ts, rng, fw_params,
                            method="euler_maruyama", rep=rep)
        y = ys[-1]  # Take last time value.

        x = y[:x_dim].reshape(x_shape)
        kl = jnp.sum(y[x_dim + w_dim:])

        # Hack to turn this into a stax.layer API when deterministic.
        if stax_api:
            return x

        if full_output:
            infodict = {
                name + "_w": ys[:, x_dim:x_dim + w_dim].reshape(-1, *w_shape)
            }
            return x, kl, infodict

        return x, kl

    if remat:
        apply_fun = jax.checkpoint(apply_fun, concrete=True)

    return init_fun, apply_fun
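# Hedged sketch (assumption, not part of the layer above): f_aug/g_aug integrate an
# augmented state laid out as [x, w, kl] in a single flat vector, so one SDE solver
# call evolves activations, weights, and the running KL estimate together. A minimal
# pack/unpack illustration using jax.flatten_util.ravel_pytree:
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

params = {"dense": {"w": jnp.ones((4, 4)), "b": jnp.zeros(4)}}
flat_w, unflatten_w = ravel_pytree(params)  # flat_w: <float32[20]>

x = jnp.ones(8)
x_dim, w_dim = x.size, flat_w.size
y0 = jnp.concatenate([x.reshape(-1), flat_w, jnp.zeros_like(flat_w)])

x_part = y0[:x_dim]                            # Activations.
w_part = unflatten_w(y0[x_dim:x_dim + w_dim])  # Back to the original weight pytree.
kl_part = y0[x_dim + w_dim:]                   # Per-weight KL accumulator.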
def apply(self,
          edge_embeddings,
          node_embeddings,
          out_dim,
          heads=gin.REQUIRED,
          query_key_dim=None,
          value_dim=None,
          mask=None,
          like_great=False):
  """Apply the attention layer.

  Args:
    edge_embeddings: <float32[num_nodes, num_nodes, edge_embedding_dim]> dense
      edge embedding matrix, where zeros indicate no edge.
    node_embeddings: <float32[num_nodes, node_embedding_dim]> node embedding
      matrix.
    out_dim: Output dimension.
    heads: How many attention heads to use.
    query_key_dim: Dimension of the queries and keys. If not provided, assumed
      to be node_embedding_dim / heads.
    value_dim: Dimension of the values. If not provided, assumed to be the
      same as query_key_dim.
    mask: <float32[num_nodes, num_nodes]> mask determining which other nodes a
      given node is allowed to attend to.
    like_great: Whether to use GREAT-style key-bias attention instead of the
      more powerful vector attention (as in RAT).

  Returns:
    <float32[num_nodes, out_dim]> softmax-weighted value sums over nodes in
    the graph.
  """

  def inner_apply(edge_embeddings, node_embeddings):
    # Einsum letters:
    # n: querying node, attends to others.
    # m: queried node, is attended to.
    # h: attention head
    # d: node embedding dimension
    # e: edge embedding dimension
    # q: query/key dimension
    # v: value dimension
    edge_embedding_dim = edge_embeddings.shape[-1]
    node_embedding_dim = node_embeddings.shape[-1]

    nonlocal query_key_dim, value_dim
    if query_key_dim is None:
      if node_embedding_dim % heads != 0:
        raise ValueError(
            "No query_key_dim provided, but node embedding dim "
            f"({node_embedding_dim}) was not divisible by head count "
            f"({heads})")
      query_key_dim = node_embedding_dim // heads

    if value_dim is None:
      value_dim = query_key_dim

    # Compute queries.
    query_tensor = self.param(
        "query_tensor",
        shape=(heads, node_embedding_dim, query_key_dim),
        initializer=initializers.xavier_normal())
    query = jnp.einsum("nd,hdq->hnq", node_embeddings, query_tensor)

    # Dot-product the queries with the node and edge keys.
    node_key_tensor = self.param(
        "node_key_tensor",
        shape=(heads, node_embedding_dim, query_key_dim),
        initializer=initializers.xavier_normal())
    dot_product_logits = jnp.einsum("md,hdq,hnq->hnm", node_embeddings,
                                    node_key_tensor, query)

    if like_great:
      # Edges contribute based on key sums, as in Hellendoorn et al.
      # edge_biases: <float32[num_nodes, num_nodes]>, computed as `w^T e + b`
      edge_biases = flax.nn.Dense(edge_embeddings, features=1).squeeze(-1)
      # Einsum sums keys over `q` dim (equivalent to broadcasting out biases).
      edge_logits = jnp.einsum("md,hdq,nm->hnm", node_embeddings,
                               node_key_tensor, edge_biases)
    else:
      # Queries attend to edge keys, as in Wang et al.
      edge_key_tensor = self.param(
          "edge_key_tensor",
          shape=(edge_embedding_dim, query_key_dim),
          initializer=initializers.xavier_normal())
      edge_logits = jnp.einsum("nme,eq,hnq->hnm", edge_embeddings,
                               edge_key_tensor, query)

    # Combine, normalize, and possibly mask.
    attention_logits = ((dot_product_logits + edge_logits) /
                        jnp.sqrt(query_key_dim))
    if mask is not None:
      attention_logits = attention_logits + jnp.log(mask)[None, :, :]

    attention_weights = jax.nn.softmax(attention_logits, axis=2)
    # Wrap attention weights with a Flax module so we can extract intermediate
    # outputs.
    attention_weights = jax_util.flax_tag(
        attention_weights, name="attention_weights")

    # Compute values.
    node_value_tensor = self.param(
        "node_value_tensor",
        shape=(heads, node_embedding_dim, value_dim),
        initializer=initializers.xavier_normal())
    attention_node_value = jnp.einsum(
        "hnm,hmv->hnv", attention_weights,
        jnp.einsum("md,hdv->hmv", node_embeddings, node_value_tensor))

    if like_great:
      # Only nodes contribute to values, as in Hellendoorn et al.
      attention_value = attention_node_value
    else:
      # Edges also contribute to values, as in Wang et al.
      edge_value_tensor = self.param(
          "edge_value_tensor",
          shape=(edge_embedding_dim, value_dim),
          initializer=initializers.xavier_normal())
      attention_edge_value = jnp.einsum("hnm,nme,ev->hnv", attention_weights,
                                        edge_embeddings, edge_value_tensor)
      attention_value = attention_node_value + attention_edge_value

    # Project them back.
    output_tensor = self.param(
        "output_tensor",
        shape=(heads, value_dim, out_dim),
        initializer=initializers.xavier_normal())
    output = jnp.einsum("hnv,hvo->no", attention_value, output_tensor)
    return output

  # Make sure we don't keep any of the intermediate hidden matrices around
  # any longer than we have to.
  if not self.is_initializing():
    inner_apply = jax.checkpoint(inner_apply)

  return inner_apply(edge_embeddings, node_embeddings)
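# Hedged sketch (illustration only, not part of the layer above): the masking step
# adds log(mask) to the logits, so positions with mask == 0 become -inf and receive
# exactly zero weight after the softmax.
import jax
import jax.numpy as jnp

logits = jnp.array([[1.0, 2.0, 3.0]])
mask = jnp.array([[1.0, 0.0, 1.0]])  # The second node may not be attended to.
weights = jax.nn.softmax(logits + jnp.log(mask), axis=-1)
# weights ~= [[0.12, 0.0, 0.88]]; the masked column gets zero attention weight.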
def apply(
    self,
    edge_embeddings,
    node_embeddings,
    mask,
    mlp_vtoe_dims,
    allow_non_adjacent,
    message_passing,
):
  """Apply the NRIEdgeLayer.

  Args:
    edge_embeddings: <float32[num_nodes, num_nodes, edge_embedding_dim]> dense
      edge embedding matrix, where zeros indicate no edge.
    node_embeddings: <float32[num_nodes, node_embedding_dim]> node embedding
      matrix.
    mask: <float32[num_nodes, num_nodes]> mask determining which other nodes a
      given node is allowed to send messages to.
    mlp_vtoe_dims: List of hidden and output dimension sizes; determines the
      depth and width of the MLP.
    allow_non_adjacent: Compute messages even for non-adjacent nodes.
    message_passing: If True, accumulate messages to destination nodes.

  Returns:
    If message_passing=True: <float32[num_nodes, out_dim]> messages.
    If message_passing=False: <float32[num_nodes, num_nodes, out_dim]> edge
      activations.
  """
  if not mlp_vtoe_dims:
    raise ValueError("Must have a nonempty sequence for mlp_vtoe_dims")

  def inner_apply(edge_embeddings, node_embeddings):
    first_layer_dim = mlp_vtoe_dims[0]
    additional_layer_dims = mlp_vtoe_dims[1:]

    if allow_non_adjacent and edge_embeddings is not None:
      num_separate_mlps = 1 + edge_embeddings.shape[-1]
    elif allow_non_adjacent:
      num_separate_mlps = 1
    elif edge_embeddings is not None:
      num_separate_mlps = edge_embeddings.shape[-1]
    else:
      raise ValueError("Either allow_non_adjacent should be True, or "
                       "edge_embeddings should be provided")

    node_embedding_dim = node_embeddings.shape[-1]

    # First layer: process each node embedding.
    weight_from_source = self.param(
        "l0_weight_from_source",
        shape=(num_separate_mlps, node_embedding_dim, first_layer_dim),
        initializer=initializers.xavier_normal())
    weight_from_dest = self.param(
        "l0_weight_from_dest",
        shape=(num_separate_mlps, node_embedding_dim, first_layer_dim),
        initializer=initializers.xavier_normal())
    bias = self.param(
        "l0_bias",
        shape=(num_separate_mlps, first_layer_dim),
        initializer=initializers.zeros)
    from_source = jnp.einsum("sx,kxy->sky", node_embeddings,
                             weight_from_source)
    from_dest = jnp.einsum("dx,kxy->dky", node_embeddings, weight_from_dest)
    activations = jax.nn.relu(from_source[:, None, :, :] +
                              from_dest[None, :, :, :] +
                              bias[None, None, :, :])

    # Additional layers: MLP for each edge type.
    for i, layer_dim in enumerate(additional_layer_dims):
      weight = self.param(
          f"l{i+1}_weight",
          shape=(num_separate_mlps, activations.shape[-1], layer_dim),
          initializer=initializers.xavier_normal())
      bias = self.param(
          f"l{i+1}_bias",
          shape=(num_separate_mlps, layer_dim),
          initializer=initializers.zeros)
      activations = jax.nn.relu(
          jnp.einsum("sdkx,kxy->sdky", activations, weight) +
          bias[None, None, :, :])

    # Sum over edge types and possibly over source nodes.
    if edge_embeddings is None:
      result = activations.squeeze(axis=2)
      if mask is not None:
        result = jnp.where(mask[:, :, None], result, jnp.zeros_like(result))
      if message_passing:
        result = jnp.sum(result, axis=0)
    else:
      if allow_non_adjacent:
        # Append a pairwise "any pair" weight so the extra MLP fires even for
        # non-adjacent node pairs (optionally restricted by the mask).
        if mask is None:
          pairwise = jnp.ones(edge_embeddings.shape[:2] + (1,))
        else:
          pairwise = mask.astype("float32")[:, :, None]
        mlp_weights = jnp.concatenate([edge_embeddings, pairwise], -1)
      else:
        mlp_weights = edge_embeddings
      if message_passing:
        result = jnp.einsum("sdky,sdk->dy", activations, mlp_weights)
      else:
        result = jnp.einsum("sdky,sdk->sdy", activations, mlp_weights)

    return result

  # Make sure we don't keep any of the intermediate hidden matrices around
  # any longer than we have to.
  if not self.is_initializing():
    inner_apply = jax.checkpoint(inner_apply)

  return inner_apply(edge_embeddings, node_embeddings)
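# Hedged sketch (illustration only, not part of the layer above): the NRI layer runs
# one small MLP per edge type by stacking per-type weight matrices along a leading
# `k` axis, contracting with einsum, and then weighting each type's output by the
# dense edge-type embeddings before summing messages into destination nodes.
import jax
import jax.numpy as jnp

num_nodes, num_types, in_dim, out_dim = 3, 2, 4, 5
activations = jnp.ones((num_nodes, num_nodes, num_types, in_dim))  # [s, d, k, x]
weights = jnp.ones((num_types, in_dim, out_dim))                   # [k, x, y]
edge_types = jnp.ones((num_nodes, num_nodes, num_types))           # [s, d, k]

hidden = jax.nn.relu(jnp.einsum("sdkx,kxy->sdky", activations, weights))
messages = jnp.einsum("sdky,sdk->dy", hidden, edge_types)  # Summed into dest nodes.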
def loss(params, batch, rng, kl_coef):
    # Rematerialize the NLL computation under jax.checkpoint to save memory
    # during backprop.
    _, nll, kl, _ = jax.checkpoint(_nll)(params, batch, rng)
    if kl_coef > 0:
        return nll + kl * kl_coef
    else:
        return nll
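# Hedged sketch (illustration only): wrapping a sub-computation in jax.checkpoint,
# as `loss` does with `_nll`, recomputes its intermediates during the backward pass
# instead of storing them, trading compute for memory. Minimal standalone example:
import jax
import jax.numpy as jnp


def expensive(x):
    for _ in range(4):
        x = jnp.tanh(x)  # Intermediates here are rematerialized, not stored.
    return x


def toy_loss(x):
    return jnp.sum(jax.checkpoint(expensive)(x)**2)


grads = jax.grad(toy_loss)(jnp.ones(3))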