Example #1
        def compute_logits_single_example(hidden_states, instruction_pointer,
                                          exit_index, steps, node_embeddings,
                                          true_indexes, false_indexes):
            """single_example refers to selecting a single exit node hidden state."""

            # leaves(hidden_states).shape: num_nodes, hidden_size

            def step_(carry, _):
                hidden_states, instruction_pointer, index = carry
                hidden_states_new, instruction_pointer_new, to_tag = (
                    step_single_example(hidden_states, instruction_pointer,
                                        node_embeddings, true_indexes,
                                        false_indexes, exit_index))
                # The step counter always advances; the hidden states and
                # instruction pointer stop updating once `index` reaches `steps`.
                carry = jax.tree_map(
                    lambda new, old, index=index: jnp.where(
                        index < steps, new, old),
                    (hidden_states_new, instruction_pointer_new, index + 1),
                    (hidden_states, instruction_pointer, index + 1),
                )
                return carry, to_tag

            if config.model.ipagnn.checkpoint and not self.is_initializing():
                # Recompute the scan step during backprop to save memory
                # (skipped while initializing parameters).
                step_ = jax.checkpoint(step_)

            carry = (hidden_states, instruction_pointer, jnp.array([0]))
            (hidden_states, instruction_pointer,
             _), to_tag = lax.scan(step_, carry, None, length=max_steps)

            final_state = jax.tree_map(lambda hs: hs[exit_index],
                                       hidden_states)
            # leaves(final_state).shape: hidden_size
            final_state_concat = jnp.concatenate(jax.tree_leaves(final_state),
                                                 axis=0)
            logits = output_dense(final_state_concat)
            to_tag.update({
                'instruction_pointer_final': instruction_pointer,
                'hidden_states_final': hidden_states,
            })
            return logits, to_tag
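
The example above wraps the `lax.scan` step in `jax.checkpoint` so per-step activations are recomputed during the backward pass instead of being stored for all `max_steps` iterations. Below is a minimal, self-contained sketch of that pattern with made-up shapes and a toy update rule; `body` and `rollout` are hypothetical names, not part of the example's codebase.

import jax
import jax.numpy as jnp
from jax import lax


def body(h, _):
    # Toy recurrent update; stands in for step_single_example above.
    h = jnp.tanh(h @ jnp.ones((h.shape[-1], h.shape[-1])))
    return h, h.sum()


# Rematerialize each step's activations on the backward pass.
body = jax.checkpoint(body)


def rollout(h0):
    h_final, per_step_sums = lax.scan(body, h0, None, length=8)
    return h_final.sum()


grads = jax.grad(rollout)(jnp.ones((4,)))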
Example #2
    def make_layer(input_shape):

        fx = build_fx(fx_block_type, input_shape, fx_dim, fx_actfn)

        # Create the unflatten_w function (and the flat weight shape) via a
        # throwaway initialization.
        rng = jax.random.PRNGKey(0)  # Temporary key for shape inference; value unused.
        x_shape, tmp_w = fx.init(rng, input_shape)
        assert input_shape == x_shape, (
            "fx needs to have the same input and output shapes but got "
            f"{input_shape} and {x_shape}")
        flat_w, unflatten_w = ravel_pytree(tmp_w)
        w_shape = flat_w.shape
        del tmp_w

        x_dim = int(jnp.prod(jnp.array(x_shape)))
        w_dim = int(jnp.prod(jnp.array(w_shape)))

        def f_aug(y, t, args):
            x = y[:x_dim].reshape(x_shape)
            flat_w = y[x_dim:x_dim + w_dim].reshape(w_shape)
            dx = fx.apply(unflatten_w(flat_w),
                          (x,
                           t))[0] if xt else fx.apply(unflatten_w(flat_w), x)
            if w_drift:
                fw_params = args
                dw = fw.apply(fw_params,
                              (flat_w,
                               t))[0] if xt else fw.apply(fw_params, flat_w)
            else:
                dw = jnp.zeros(w_shape)
            # Hardcoded OU Process.
            u = (dw - (-flat_w)) / \
                diff_coef if diff_coef != 0 else jnp.zeros(w_shape)
            dkl = u**2
            return jnp.concatenate(
                [dx.reshape(-1),
                 dw.reshape(-1),
                 dkl.reshape(-1)])

        def g_aug(y, t, args):
            # Parse the current weights from the augmented state (mirrors f_aug).
            flat_w = y[x_dim:x_dim + w_dim].reshape(w_shape)
            dx = jnp.zeros(x_shape)
            diff_w = jnp.ones(w_shape) * diff_coef

            if w_drift:
                fw_params = tree_util.tree_map(stop_gradient, args)
                drift_w = fw.apply(fw_params,
                                   (flat_w, t))[0] if xt else fw.apply(
                                       fw_params, flat_w)
            else:
                drift_w = jnp.zeros(w_shape)

            # Hardcoded OU Process.
            u = (drift_w - (-flat_w)) / \
                diff_coef if diff_coef != 0 else jnp.zeros(w_shape)
            dkl = u if stl else jnp.zeros(w_shape)
            return jnp.concatenate(
                [dx.reshape(-1),
                 diff_w.reshape(-1),
                 dkl.reshape(-1)])

        def init_fun(rng, input_shape):
            output_shape, w0 = fx.init(rng, input_shape)
            flat_w0, unflatten_w = ravel_pytree(w0)
            if w_drift:
                output_shape, fw_params = fw.init(rng, flat_w0.shape)
                assert flat_w0.shape == output_shape, "fw needs to have the same input and output shapes"
            else:
                fw_params = ()
            return input_shape, (flat_w0, fw_params)

        def apply_fun(params,
                      inputs,
                      rng,
                      full_output=False,
                      fixed_grid=True,
                      **kwargs):
            flat_w0, fw_params = params
            x = inputs
            y0 = jnp.concatenate([
                x.reshape(-1),
                flat_w0.reshape(-1),
                jnp.zeros(flat_w0.shape).reshape(-1)
            ])
            rep = w_dim if stl else 0  # STL
            if fixed_grid:
                ys = sdeint_ito_fixed_grid(f_aug,
                                           g_aug,
                                           y0,
                                           ts,
                                           rng,
                                           fw_params,
                                           method="euler_maruyama",
                                           rep=rep)
            else:
                print("using stochastic adjoint")
                ys = sdeint_ito(f_aug,
                                g_aug,
                                y0,
                                ts,
                                rng,
                                fw_params,
                                method="euler_maruyama",
                                rep=rep)
            y = ys[-1]  # Take last time value.
            x = y[:x_dim].reshape(x_shape)
            kl = jnp.sum(y[x_dim + w_dim:])

            # Hack to turn this into a stax.layer API when deterministic.
            if stax_api:
                return x

            if full_output:
                infodict = {
                    name + "_w": ys[:,
                                    x_dim:x_dim + w_dim].reshape(-1, *w_shape)
                }
                return x, kl, infodict

            return x, kl

        if remat:
            apply_fun = jax.checkpoint(apply_fun, concrete=True)
        return init_fun, apply_fun
Example #3
    def apply(self,
              edge_embeddings,
              node_embeddings,
              out_dim,
              heads=gin.REQUIRED,
              query_key_dim=None,
              value_dim=None,
              mask=None,
              like_great=False):
        """Apply the attention layer.

    Args:
      edge_embeddings: <float32[num_nodes, num_nodes, edge_embedding_dim]> dense
        edge embedding matrix, where zeros indicate no edge.
      node_embeddings: <float32[num_nodes, node_embedding_dim]> node embedding
        matrix.
      out_dim: Output dimension.
      heads: How many attention heads to use.
      query_key_dim: Dimension of the queries and keys. If not provided, assumed
        to be node_embedding_dim / heads.
      value_dim: Dimension of the values. If not provided, assumed to be the
        same as query_key_dim.
      mask: <float32[num_nodes, num_nodes]> mask determining which other nodes a
        given node is allowed to attend to.
      like_great: Whether to use GREAT-style key-bias attention instead of more
        powerful vector attention (as in RAT).

    Returns:
      <float32[num_nodes, out_dim]> softmax-weighted value sums over nodes in
      the graph.
    """
        def inner_apply(edge_embeddings, node_embeddings):
            # Einsum letters:
            # n: querying node, attends to others.
            # m: queried node, is attended to.
            # h: attention head
            # d: node embedding dimension
            # e: edge embedding dimension
            # q: query/key dimension
            # v: value dimension
            edge_embedding_dim = edge_embeddings.shape[-1]
            node_embedding_dim = node_embeddings.shape[-1]

            nonlocal query_key_dim, value_dim

            if query_key_dim is None:
                if node_embedding_dim % heads != 0:
                    raise ValueError(
                        "No query_key_dim provided, but node embedding dim "
                        f"({node_embedding_dim}) was not divisible by head count "
                        f"({heads})")
                query_key_dim = node_embedding_dim // heads

            if value_dim is None:
                value_dim = query_key_dim

            # Compute queries.
            query_tensor = self.param("query_tensor",
                                      shape=(heads, node_embedding_dim,
                                             query_key_dim),
                                      initializer=initializers.xavier_normal())
            query = jnp.einsum("nd,hdq->hnq", node_embeddings, query_tensor)

            # Dot-product the queries with the node and edge keys.
            node_key_tensor = self.param(
                "node_key_tensor",
                shape=(heads, node_embedding_dim, query_key_dim),
                initializer=initializers.xavier_normal())
            dot_product_logits = jnp.einsum("md,hdq,hnq->hnm", node_embeddings,
                                            node_key_tensor, query)

            if like_great:
                # Edges contribute based on key sums, as in Hellendoorn et al.
                # edge_biases: <float32[num_nodes, num_nodes]>, computed as `w^T e + b`
                edge_biases = flax.nn.Dense(edge_embeddings,
                                            features=1).squeeze(-1)
                # Einsum sums keys over `q` dim (equivalent to broadcasting out biases).
                edge_logits = jnp.einsum("md,hdq,nm->hnm", node_embeddings,
                                         node_key_tensor, edge_biases)
            else:
                # Queries attend to edge keys, as in Wang et al.
                edge_key_tensor = self.param(
                    "edge_key_tensor",
                    shape=(edge_embedding_dim, query_key_dim),
                    initializer=initializers.xavier_normal())
                edge_logits = jnp.einsum("nme,eq,hnq->hnm", edge_embeddings,
                                         edge_key_tensor, query)

            # Combine, normalize, and possibly mask.
            attention_logits = ((dot_product_logits + edge_logits) /
                                jnp.sqrt(query_key_dim))
            if mask is not None:
                attention_logits = attention_logits + jnp.log(mask)[None, :, :]

            attention_weights = jax.nn.softmax(attention_logits, axis=2)

            # Wrap attention weights with a Flax module so we can extract intermediate
            # outputs.
            attention_weights = jax_util.flax_tag(attention_weights,
                                                  name="attention_weights")

            # Compute values.
            node_value_tensor = self.param(
                "node_value_tensor",
                shape=(heads, node_embedding_dim, value_dim),
                initializer=initializers.xavier_normal())
            attention_node_value = jnp.einsum(
                "hnm,hmv->hnv", attention_weights,
                jnp.einsum("md,hdv->hmv", node_embeddings, node_value_tensor))

            if like_great:
                # Only nodes contribute to values, as in Hellendoorn et al.
                attention_value = attention_node_value
            else:
                # Edges also contribute to values, as in Wang et al.
                edge_value_tensor = self.param(
                    "edge_value_tensor",
                    shape=(edge_embedding_dim, value_dim),
                    initializer=initializers.xavier_normal())
                attention_edge_value = jnp.einsum("hnm,nme,ev->hnv",
                                                  attention_weights,
                                                  edge_embeddings,
                                                  edge_value_tensor)

                attention_value = attention_node_value + attention_edge_value

            # Project them back.
            output_tensor = self.param(
                "output_tensor",
                shape=(heads, value_dim, out_dim),
                initializer=initializers.xavier_normal())
            output = jnp.einsum("hnv,hvo->no", attention_value, output_tensor)
            return output

        # Make sure we don't keep any of the intermediate hidden matrices around
        # any longer than we have to.
        if not self.is_initializing():
            inner_apply = jax.checkpoint(inner_apply)

        return inner_apply(edge_embeddings, node_embeddings)
Example #4
    def apply(
        self,
        edge_embeddings,
        node_embeddings,
        mask,
        mlp_vtoe_dims,
        allow_non_adjacent,
        message_passing,
    ):
        """Apply the NRIEdgeLayer.

    Args:
      edge_embeddings: <float32[num_nodes, num_nodes, edge_embedding_dim]> dense
        edge embedding matrix, where zeros indicate no edge.
      node_embeddings: <float32[num_nodes, node_embedding_dim]> node embedding
        matrix.
      mask: <float32[num_nodes, num_nodes]> mask determining which other nodes a
        given node is allowed to send messages to.
      mlp_vtoe_dims: List of hidden and output dimension sizes; determines the
        depth and width of the MLP.
      allow_non_adjacent: Compute messages even for non-adjacent nodes.
      message_passing: If True, accumulate messages to destination nodes.

    Returns:
      If message_passing=True: <float32[num_nodes, out_dim]> messages.
      If message_passing=False: <float32[num_nodes, num_nodes, out_dim]> edge
                                activations.
    """
        if not mlp_vtoe_dims:
            raise ValueError("Must have a nonempty sequence for mlp_vtoe_dims")

        def inner_apply(edge_embeddings, node_embeddings):
            first_layer_dim = mlp_vtoe_dims[0]
            additional_layer_dims = mlp_vtoe_dims[1:]

            if allow_non_adjacent and edge_embeddings is not None:
                num_separate_mlps = 1 + edge_embeddings.shape[-1]
            elif allow_non_adjacent:
                num_separate_mlps = 1
            elif edge_embeddings is not None:
                num_separate_mlps = edge_embeddings.shape[-1]
            else:
                raise ValueError(
                    "Either allow_non_adjacent should be True, or "
                    "edge_embeddings should be provided")

            node_embedding_dim = node_embeddings.shape[-1]

            # First layer: process each node embedding.
            weight_from_source = self.param(
                "l0_weight_from_source",
                shape=(num_separate_mlps, node_embedding_dim, first_layer_dim),
                initializer=initializers.xavier_normal())
            weight_from_dest = self.param(
                "l0_weight_from_dest",
                shape=(num_separate_mlps, node_embedding_dim, first_layer_dim),
                initializer=initializers.xavier_normal())
            bias = self.param("l0_bias",
                              shape=(num_separate_mlps, first_layer_dim),
                              initializer=initializers.zeros)
            from_source = jnp.einsum("sx,kxy->sky", node_embeddings,
                                     weight_from_source)
            from_dest = jnp.einsum("dx,kxy->dky", node_embeddings,
                                   weight_from_dest)
            activations = jax.nn.relu(from_source[:, None, :, :] +
                                      from_dest[None, :, :, :] +
                                      bias[None, None, :, :])

            # Additional layers: MLP for each edge type.
            for i, layer_dim in enumerate(additional_layer_dims):
                weight = self.param(f"l{i+1}_weight",
                                    shape=(num_separate_mlps,
                                           activations.shape[-1], layer_dim),
                                    initializer=initializers.xavier_normal())
                bias = self.param(f"l{i+1}_bias",
                                  shape=(num_separate_mlps, layer_dim),
                                  initializer=initializers.zeros)
                activations = jax.nn.relu(
                    jnp.einsum("sdkx,kxy->sdky", activations, weight) +
                    bias[None, None, :, :])

            # Sum over edge types and possibly over source nodes.
            if edge_embeddings is None:
                result = activations.squeeze(axis=2)
                if mask is not None:
                    result = jnp.where(mask[:, :, None], result,
                                       jnp.zeros_like(result))
                if message_passing:
                    result = jnp.sum(result, axis=0)
            else:
                if allow_non_adjacent:
                    if mask is None:
                        # Pairwise indicator over all node pairs:
                        # <float32[num_nodes, num_nodes]>.
                        pairwise = jnp.ones(edge_embeddings.shape[:2])
                    else:
                        pairwise = mask
                    mlp_weights = jnp.concatenate([
                        edge_embeddings,
                        pairwise.astype("float")[:, :, None]
                    ], -1)
                else:
                    mlp_weights = edge_embeddings

                if message_passing:
                    result = jnp.einsum("sdky,sdk->dy", activations,
                                        mlp_weights)
                else:
                    result = jnp.einsum("sdky,sdk->sdy", activations,
                                        mlp_weights)

            return result

        # Make sure we don't keep any of the intermediate hidden matrices around
        # any longer than we have to.
        if not self.is_initializing():
            inner_apply = jax.checkpoint(inner_apply)

        return inner_apply(edge_embeddings, node_embeddings)
Example #5
def loss(params, batch, rng, kl_coef):  # Backprop goes through this, so checkpoint _nll.
    _, nll, kl, _ = jax.checkpoint(_nll)(params, batch, rng)
    if kl_coef > 0:
        return nll + kl * kl_coef
    else:
        return nll
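
Example #5 applies `jax.checkpoint` to the negative-log-likelihood sub-computation of the loss that gets differentiated. Below is a minimal sketch of the same pattern with a stand-in `_nll` and made-up shapes; the real `_nll`, batch structure, and KL term are defined elsewhere in that codebase.

import jax
import jax.numpy as jnp


def _nll(params, batch):
    # Stand-in negative log-likelihood with arbitrary shapes.
    preds = jnp.tanh(batch @ params)
    return jnp.mean((preds - 1.0) ** 2)


def loss(params, batch):
    # Recompute _nll's intermediates during backprop instead of storing them.
    return jax.checkpoint(_nll)(params, batch)


params = jnp.ones((3, 2))
batch = jnp.ones((5, 3))
grads = jax.grad(loss)(params, batch)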