Example #1
    def __init__(self,
                 indim,
                 outdim,
                 topology,
                 omega=1.0,
                 transform=None,
                 seed=0):
        """
        Arguments
        ---------

        indim:
            Dimensionality of a single data input.

        outdim:
            Dimensionality of a single data output.

        topology: Tuple
            Defines the structure of the inner layers for the network.

        omega:
            Weight distribution factor ω₀ for the first layer (as described in [1]).

        transform: Optional[Callable]
            Optional pre-network transformation function.

        seed: Optional[int]
            Initial seed for weight initialization.
        """

        tlayer = build_transform_layer(transform)
        # Weight initialization for Sirens
        pdf_in = variance_scaling(1.0 / 3, "fan_in", "uniform")
        pdf = variance_scaling(2.0 / omega**2, "fan_in", "uniform")
        # Sine activation function
        σ_in = stax.elementwise(lambda x: np.sin(omega * x))
        σ = stax.elementwise(lambda x: np.sin(x))
        # Build layers
        layer_in = [
            stax.Flatten, *tlayer,
            stax.Dense(topology[0], pdf_in), σ_in
        ]
        layers = list(
            chain.from_iterable((stax.Dense(i, pdf), σ) for i in topology[1:]))
        layers = layer_in + layers + [stax.Dense(outdim, pdf)]
        #
        super().__init__(indim, layers, seed)
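A minimal standalone sketch of the same SIREN-style construction, assuming stax is jax.example_libraries.stax (jax.experimental.stax in older releases) and variance_scaling comes from jax.nn.initializers; the values of omega, topology and outdim are illustrative, not taken from the snippet:

import jax.numpy as jnp
from jax import random
from jax.example_libraries import stax   # jax.experimental.stax in older JAX versions
from jax.nn.initializers import variance_scaling

omega = 30.0                              # illustrative value
topology = (64, 64)
indim, outdim = 2, 1

# variance_scaling(scale, "fan_in", "uniform") samples U(-limit, limit) with
# limit = sqrt(3 * scale / fan_in), so these give 1/sqrt(fan_in) and
# sqrt(6/fan_in)/omega respectively.
pdf_in = variance_scaling(1.0 / 3, "fan_in", "uniform")
pdf = variance_scaling(2.0 / omega**2, "fan_in", "uniform")
sine_in = stax.elementwise(lambda x: jnp.sin(omega * x))
sine = stax.elementwise(jnp.sin)

init_fn, apply_fn = stax.serial(
    stax.Dense(topology[0], pdf_in), sine_in,
    stax.Dense(topology[1], pdf), sine,
    stax.Dense(outdim, pdf))

_, params = init_fn(random.PRNGKey(0), (-1, indim))
out = apply_fn(params, jnp.ones((8, indim)))   # shape (8, 1)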
Example #2
  def build_flax_module(self):
    kernel_inits = [
        initializers.variance_scaling(scale, 'fan_in', 'truncated_normal')
        for scale in self.hps.kernel_scales
    ]
    return FullyConnected.partial(
        num_outputs=self.hps['output_shape'][-1],
        hid_sizes=self.hps.hid_sizes,
        activation_function=self.hps.activation_function,
        kernel_inits=kernel_inits)
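A hedged sketch of how such a per-layer list of variance_scaling kernel initializers could be consumed, rewritten against the newer flax.linen API (the snippet above uses the older flax.nn module API with .partial); only hid_sizes and kernel_scales are names taken from the snippet, everything else is assumed:

import jax
import jax.numpy as jnp
import flax.linen as nn
from jax.nn import initializers

class FullyConnected(nn.Module):
    hid_sizes: tuple
    kernel_inits: tuple

    @nn.compact
    def __call__(self, x):
        # One Dense layer per hidden size, each with its own kernel initializer.
        for size, init in zip(self.hid_sizes, self.kernel_inits):
            x = nn.relu(nn.Dense(size, kernel_init=init)(x))
        return x

kernel_scales = (1.0, 2.0)                # illustrative values
kernel_inits = tuple(
    initializers.variance_scaling(s, 'fan_in', 'truncated_normal')
    for s in kernel_scales)
model = FullyConnected(hid_sizes=(32, 16), kernel_inits=kernel_inits)
params = model.init(jax.random.PRNGKey(0), jnp.ones((4, 8)))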
Example #3
    def apply(self, node_types, num_node_types, embedding_dim):
        """Compute initial node embeddings.

        Args:
          node_types: <int32[num_nodes]> giving node type indices of each node.
          num_node_types: Total number of node types.
          embedding_dim: Dimensionality of the embedding space.

        Returns:
          <float32[num_nodes, embedding_dim]> embedding array.
        """
        node_type_embeddings = self.param(
            "node_type_embeddings",
            shape=(num_node_types, embedding_dim),
            initializer=initializers.variance_scaling(1.0, "fan_out",
                                                      "truncated_normal"))
        return node_type_embeddings[node_types]
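The same lookup without the module machinery, as a plain-JAX sketch with made-up shapes: the parameter is just a (num_node_types, embedding_dim) table drawn with variance_scaling and indexed by the per-node type ids:

import jax
import jax.numpy as jnp
from jax.nn import initializers

num_node_types, embedding_dim = 4, 8
init = initializers.variance_scaling(1.0, "fan_out", "truncated_normal")
table = init(jax.random.PRNGKey(0), (num_node_types, embedding_dim))  # float32[4, 8]

node_types = jnp.array([0, 2, 2, 3], dtype=jnp.int32)  # one type id per node
embeddings = table[node_types]                          # float32[num_nodes, 8]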
Example #4
    def apply(self,
              operator,
              vocab_size,
              num_nodes,
              embedding_dim,
              bottleneck_dim=None):
        """Compute token node embeddings.

        Args:
          operator: Operator from tokens to nodes.
          vocab_size: How many tokens there are in the vocabulary.
          num_nodes: How many nodes there are in the graph.
          embedding_dim: Dimensionality of the embedding space.
          bottleneck_dim: Optional initial dimension of the embedding space, which
            will be projected up to embedding_dim.

        Returns:
          <float32[num_nodes, embedding_dim]> embedding array.
        """
        param_dim = (bottleneck_dim
                     if bottleneck_dim is not None else embedding_dim)
        token_embeddings = self.param(
            "token_embeddings",
            shape=(vocab_size, param_dim),
            initializer=initializers.variance_scaling(1.0, "fan_out",
                                                      "truncated_normal"))

        node_token_embeddings = operator.apply_add(token_embeddings,
                                                   jnp.zeros(
                                                       (num_nodes, param_dim)),
                                                   in_dims=(0, ),
                                                   out_dims=(0, ))

        if bottleneck_dim is not None:
            node_token_embeddings = flax.nn.Dense(node_token_embeddings,
                                                  features=embedding_dim,
                                                  bias=False)

        return node_token_embeddings
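A hedged sketch of the bottleneck idea in isolation, leaving out the sparse token-to-node operator (which is specific to the surrounding codebase) and using flax.linen rather than the older flax.nn API of the snippet; the module name is hypothetical:

import jax
import jax.numpy as jnp
import flax.linen as nn
from jax.nn import initializers

class BottleneckTokenEmbed(nn.Module):    # hypothetical name
    vocab_size: int
    embedding_dim: int
    bottleneck_dim: int

    @nn.compact
    def __call__(self, token_ids):
        # Embed into the smaller bottleneck space, then project up without a bias.
        table = self.param(
            "token_embeddings",
            initializers.variance_scaling(1.0, "fan_out", "truncated_normal"),
            (self.vocab_size, self.bottleneck_dim))
        return nn.Dense(self.embedding_dim, use_bias=False)(table[token_ids])

tokens = jnp.array([1, 5, 7])
module = BottleneckTokenEmbed(vocab_size=100, embedding_dim=16, bottleneck_dim=4)
params = module.init(jax.random.PRNGKey(0), tokens)
out = module.apply(params, tokens)        # float32[3, 16]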
Example #5
    def apply(self,
              edges,
              num_nodes,
              num_edge_types,
              forward_edge_type_indices,
              reverse_edge_type_indices,
              embedding_dim=gin.REQUIRED):
        """Compute multi-hot binary edge embeddings.

        Args:
          edges: Edges, represented as a sparse operator from a vector indexed by
            edge type to an adjacency matrix.
          num_nodes: Number of nodes in the graph.
          num_edge_types: How many total edge types there are.
          forward_edge_type_indices: Indices of the edge types to embed in the
            forward direction.
          reverse_edge_type_indices: Indices of the edge types to embed in the
            reverse direction.
          embedding_dim: Dimension of the learned embedding.

        Returns:
          <float32[num_nodes, num_nodes, embedding_dim]> embedding array.
        """
        total_edge_count = (len(forward_edge_type_indices) +
                            len(reverse_edge_type_indices))
        edge_type_embeddings = self.param(
            "edge_type_embeddings",
            shape=(total_edge_count, embedding_dim),
            initializer=initializers.variance_scaling(1.0, "fan_out",
                                                      "truncated_normal"))

        # Build new operators that include only the desired edge types, mapping
        # edge-type indices in [0, num_edge_types) to [0, total_edge_count).
        (forward_index_map, forward_values, reverse_index_map,
         reverse_values) = (_forward_and_reverse_subsets(
             num_edge_types, forward_edge_type_indices,
             reverse_edge_type_indices))

        e_in_flat = edges.input_indices.squeeze(1)

        forward_operator = sparse_operator.SparseCoordOperator(
            input_indices=forward_index_map[edges.input_indices],
            output_indices=edges.output_indices,
            values=edges.values * forward_values[e_in_flat])

        reverse_operator = sparse_operator.SparseCoordOperator(
            input_indices=reverse_index_map[edges.input_indices],
            output_indices=edges.output_indices,
            values=edges.values * reverse_values[e_in_flat])

        # Apply our adjusted operators, gathering from our extended embeddings
        # array.
        result = jnp.zeros([embedding_dim, num_nodes, num_nodes])
        result = forward_operator.apply_add(in_array=edge_type_embeddings,
                                            out_array=result,
                                            in_dims=[0],
                                            out_dims=[1, 2])
        result = reverse_operator.apply_add(in_array=edge_type_embeddings,
                                            out_array=result,
                                            in_dims=[0],
                                            out_dims=[2, 1])

        # Force it to actually be materialized as
        # [(batch,) embedding_dim, num_nodes, num_nodes] to reduce downstream
        # effects of the bad padding required by the above.
        result = jax_util.force_physical_layout(result)

        return result.transpose((1, 2, 0))
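A heavily simplified sketch of the idea behind this example, using a dense per-type adjacency tensor in place of the project-specific SparseCoordOperator; all shapes and values are made up:

import jax
import jax.numpy as jnp
from jax.nn import initializers

num_nodes, embedding_dim = 3, 4
# adjacency[t, i, j] == 1.0 if there is an edge of selected type t from node i to node j.
adjacency = (jnp.zeros((2, num_nodes, num_nodes))
             .at[0, 0, 1].set(1.0)
             .at[1, 2, 0].set(1.0))

init = initializers.variance_scaling(1.0, "fan_out", "truncated_normal")
edge_type_embeddings = init(jax.random.PRNGKey(0), (2, embedding_dim))

# Forward direction: an edge (i, j) of type t contributes edge_type_embeddings[t]
# at position [i, j, :]; the reverse direction would add the transposed adjacency.
result = jnp.einsum("tij,td->ijd", adjacency, edge_type_embeddings)  # [num_nodes, num_nodes, dim]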
Example #6
  def init(key, shape, dtype=np.float32):
    scale = np.sqrt(2 /
                    (shape[0] * np.prod(shape[2:]))) * num_layers**(-0.5)
    return variance_scaling(scale=scale,
                            mode="fan_in",
                            distribution="normal")(key, shape, dtype)
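A usage sketch under stated assumptions: num_layers is taken to be a variable of the enclosing scope (set to 8 here), np is NumPy, and variance_scaling is jax.nn.initializers.variance_scaling, for which mode="fan_in" with distribution="normal" draws weights with standard deviation sqrt(scale / fan_in), so the num_layers**(-0.5) factor shrinks the initialization with depth:

import jax
import numpy as np
from jax.nn.initializers import variance_scaling

num_layers = 8   # assumed value for the closure variable

def init(key, shape, dtype=np.float32):
    scale = np.sqrt(2 / (shape[0] * np.prod(shape[2:]))) * num_layers**(-0.5)
    return variance_scaling(scale=scale, mode="fan_in",
                            distribution="normal")(key, shape, dtype)

w = init(jax.random.PRNGKey(0), (3, 3, 16, 32))   # e.g. a 3x3 convolution kernel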