コード例 #1
0
ファイル: regressions.py プロジェクト: lindermanlab/jxf
def preprocess_autoregression_data(data,
                                   num_lags=1,
                                   covariates=None,
                                   fit_intercept=True,
                                   **kwargs):
    """Prepare a dataset for fitting with an AR model.

    Builds a covariate (design) array whose columns are, in order: the
    ``num_lags`` lagged copies of ``data`` (lag 1 first, zero-filled at the
    start of the time axis), the optional user ``covariates``, and an
    optional column of ones for the intercept. The first ``num_lags`` time
    steps, whose lagged values are only padding, are then dropped from both
    the data and the covariates.

    Args:
      data: array of shape ``(..., num_timesteps, data_dim)``; time runs
        along the second-to-last axis and any leading axes are batch axes.
      num_lags: number of autoregressive lags to include.
      covariates: optional array of shape ``(..., num_timesteps, cov_dim)``
        appended after the lagged data.
      fit_intercept: if True, append a column of ones.
      **kwargs: extra entries copied unchanged into the returned dict.

    Returns:
      dict with ``data`` of shape ``(..., num_timesteps - num_lags,
      data_dim)`` (empty along time if ``num_lags >= num_timesteps``),
      ``covariates`` with the same leading shape and one column per
      regressor, plus every entry of ``kwargs``.
    """
    num_timesteps, data_dim = data.shape[-2], data.shape[-1]
    batch_shape = data.shape[:-2]

    # prepare the covariates
    all_covariates = []
    for lag in range(1, num_lags + 1):
        # Shift data forward in time by `lag` steps, zero-filling the start.
        # Clamp the padding so lags longer than the series still yield
        # exactly `num_timesteps` rows (otherwise the concatenation below
        # would fail on mismatched lengths).
        pad = min(lag, num_timesteps)
        zeros = np.zeros(batch_shape + (pad, data_dim))
        all_covariates.append(
            np.concatenate([zeros, data[..., :num_timesteps - pad, :]],
                           axis=-2))
    if covariates is not None:
        all_covariates.append(covariates)
    if fit_intercept:
        all_covariates.append(np.ones(data.shape[:-1] + (1, )))

    if all_covariates:
        all_covariates = np.concatenate(all_covariates, axis=-1)
    else:
        # No regressors requested: return a zero-width covariate array
        # instead of failing inside np.concatenate on an empty list.
        all_covariates = np.zeros(data.shape[:-1] + (0, ))

    # Extract only the valid slice (along the time axis) of the data and
    # covariates: the first `num_lags` steps have padded lag values.
    valid_data = data[..., num_lags:, :]
    valid_covariates = all_covariates[..., num_lags:, :]

    return dict(data=valid_data, covariates=valid_covariates, **kwargs)
コード例 #2
0
ファイル: kmeans.py プロジェクト: benman1/jax-ml
 def initialize_centers(self, X):
     '''Choose ``self.k`` initial centers, roughly following kmeans++.

     The first center is ``self._get_center(X)``. Each further center is
     ``self._get_center(X, weights)``, where ``weights`` comes from
     ``self.dist_fun(X, self.centers)``. Once more than one center exists,
     those distances are averaged over the last axis (presumably one entry
     per existing center; confirm against ``dist_fun``). True kmeans++
     instead weights by the distance to the nearest center.

     The result is stored in ``self.centers``.
     '''
     self.centers = self._get_center(X)
     for c in range(1, self.k):
         weights = self.dist_fun(X, self.centers)
         if c > 1:
             # Collapse the per-center distances into one weight per point.
             # harmonic mean gives error
             weights = jnp.mean(weights, axis=-1)
         new_center = self._get_center(X, weights)
         # jnp.row_stack is a deprecated alias of jnp.vstack (same semantics).
         self.centers = jnp.vstack((self.centers, new_center))
コード例 #3
0
    def __call__(self, node_features, node_ids):
        """Embeds nodes by features and node_id.

    Args:
      node_features: float or int tensor representing the current node's fixed
        features. These features are not learned.
      node_ids: id of the node in the image. Used in place of the position in
        the image.

    Returns:
      logits: float tensor of shape (num_classes,)
    """
        cfg = self.config

        num_nodes = len(node_ids)

        # Embed nodes
        node_embs = self.node_embedding(node_features)
        node_embs = node_embs.reshape(num_nodes, -1)
        node_hiddens = self.node_hidden_layer(node_embs)
        graph_hidden = self.graph_embedding(jnp.zeros(1, dtype=int))
        node_hiddens = jnp.row_stack((node_hiddens, graph_hidden))

        # Embed positions
        # TODO(gnegiar): We need to clip the "not a node" node to make sure it
        # propagates gradients correctly. jax.experimental.sparse uses an out of
        # bounds index to encode elements with 0 value.
        # See https://github.com/google/jax/issues/5760
        node_ids = jnp.clip(node_ids, a_max=cfg.image_size - 1)
        position_embs = self.position_embedding(node_ids + 1)
        position_hiddens = self.position_hidden_layer(position_embs)
        # The graph node has no position.
        position_hiddens = jnp.row_stack(
            (position_hiddens, jnp.zeros(position_hiddens.shape[-1])))

        return node_hiddens, position_hiddens
コード例 #4
0
    def _add_supernode(self, node_features, dense_submat, dense_q):
        """Adds supernode with full incoming and outgoing connectivity.

    Adds a row and column of 1s to `dense_submat`, and normalizes. Also adds a
      row to `node_features`, containing the average of the other node features.
      Adds a weight of 1 at the end of `dense_q`.

    Args:
      node_features: Shape (num_nodes, feature_dim) Matrix of node features.
      dense_submat: Shape (num_nodes, num_nodes) Adjacency matrix.
      dense_q: Shape (num_nodes,) Node weights.

    Returns:
      node_features: Shape (num_nodes + 1, feature_dim) Matrix of node features.
      dense_submat: Shape (num_nodes + 1, num_nodes + 1) Adjacency matrix.
      dense_q: Shape (num_nodes + 1,) Node weights.
    """
        dense_submat = jnp.row_stack(
            (dense_submat, jnp.ones(dense_submat.shape[1])))
        dense_submat = jnp.column_stack(
            (dense_submat, jnp.ones(dense_submat.shape[0])))
        # Normalize nonzero elements
        # The sum is bounded away from 0, so this is always differentiable
        # TODO(gnegiar): Do we want this? It means the supernode gets half the
        # outgoing weights
        dense_submat = dense_submat / dense_submat.sum(axis=-1, keepdims=True)
        # Add a weight to the supernode
        dense_q = jnp.append(dense_q, jnp.mean(dense_q))
        # We embed the supernode using a distinct value.
        # TODO(gnegiar): Should we use another embedding?
        node_features = jnp.append(node_features,
                                   jnp.full((1, node_features.shape[1]),
                                            2,
                                            dtype=int),
                                   axis=0)
        return node_features, dense_submat, dense_q
コード例 #5
0
ファイル: matrixops.py プロジェクト: Techtonique/nnetsauce
def rbind(x, y, backend="cpu"):
    """Stack ``x`` and ``y`` vertically (row-wise), named after R's ``rbind``.

    1-D inputs are treated as single rows. When ``backend`` is "gpu" or
    "tpu" and the host OS is Linux or Darwin, the result is built with
    ``jax.numpy``. Otherwise it is built with NumPy.
    """
    # row_stack is a deprecated alias of vstack (NumPy 2.0; followed by JAX);
    # vstack has identical semantics. platform.system() is only consulted
    # when an accelerator backend is requested.
    if backend in ("gpu", "tpu") and platform.system() in ("Linux", "Darwin"):
        return jnp.vstack((x, y))
    return np.vstack((x, y))
コード例 #6
0
ファイル: kmeans.py プロジェクト: benman1/jax-ml
 def adjust_centers(self, X):
     '''Adjust centers given cluster assignments.

     For each cluster id ``c`` in ``range(self.k)``, the new center is
     ``self._mean`` of the rows of ``X`` with ``self.clusters == c``. A
     cluster with no assigned points keeps its previous center. The result
     is stored in ``self.centers``. This assumes cluster ids are
     0..k-1 and index the rows of ``self.centers``; confirm against the
     assignment step.
     '''
     new_centers = []
     # Iterate over cluster ids, not over the per-point labels in
     # self.clusters (which would yield one duplicate row per point).
     for c in range(self.k):
         members = X[self.clusters == c]
         if members.shape[0] == 0:
             # The mean of an empty set is undefined; keep the old center.
             new_centers.append(self.centers[c])
         else:
             new_centers.append(self._mean(members, axis=0))
     # Store the result: the stacked centers must not be discarded.
     self.centers = jnp.vstack(new_centers)