Example #1
File: rans.py Project: j-towns/crayjax
import jax.numpy as jnp

# tail_prec, head_prec, tail_dtype and stack_push are module-level
# definitions in rans.py (omitted here).
def push(m, starts, freqs, precs):
    head, tail = m
    # Renormalize: where the head is too large for the update below,
    # spill its low bits onto the tail (checked over three chunks).
    for i in reversed(range(3)):
        idxs = (head >> (i * tail_prec + head_prec - precs)) >= freqs
        tail = stack_push(tail, idxs, head.astype(tail_dtype))
        head = jnp.where(idxs, head >> tail_prec, head)
    # Core rANS encode step: head' = (head // freq) << prec | (head % freq + start).
    head_div_freqs, head_mod_freqs = jnp.divmod(head, freqs)
    return (head_div_freqs << precs) | (head_mod_freqs + starts), tail
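For intuition, here is a scalar sketch of the divmod-based update together with the decode that inverts it; the constants are made up for illustration, not crayjax's actual settings:

import jax.numpy as jnp

prec = 16                          # hypothetical precision
start, freq = 1000, 50             # hypothetical symbol interval in the CDF
head = jnp.uint32(1 << 20)

d, m = jnp.divmod(head, freq)
encoded = (d << prec) | (m + start)
# The decoder inverts this exactly:
#   freq * (encoded >> prec) + (encoded & ((1 << prec) - 1)) - start == head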
Example #2
File: jaxpile.py Project: dpiponi/pile
import jax.numpy as np   # assuming np is aliased to jax.numpy here
from jax import lax

# neighbours (topple threshold) and kernel (neighbour stencil) are
# module-level definitions in jaxpile.py (omitted here).
def reduce(pile):
    # Each cell topples pile // neighbours times, keeping pile % neighbours.
    topples, remainders = np.divmod(pile, neighbours)
    # Convolve topple counts with the neighbour kernel (NHWC -> NCHW for
    # lax.conv) to collect the grains received from adjacent cells.
    collected = lax.conv(np.transpose(topples, [0, 3, 1, 2]),
                         np.transpose(kernel, [2, 3, 0, 1]),
                         window_strides=(1, 1),
                         padding="SAME")
    result = remainders + np.transpose(collected, [0, 2, 3, 1])
    return result
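A hypothetical driver for reduce, with assumed values for the module-level neighbours and kernel (a 4-neighbour cross in HWIO layout); iterating reduce reaches the stable configuration:

neighbours = 4                         # topple threshold on a square grid
kernel = np.array([[0., 1., 0.],
                   [1., 0., 1.],
                   [0., 1., 0.]]).reshape(3, 3, 1, 1)

pile = np.zeros((1, 64, 64, 1)).at[0, 32, 32, 0].set(2.0 ** 14)
while np.any(pile >= neighbours):      # topple until no cell can topple
    pile = reduce(pile)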
Example #3
def body(args):
    I, B, b, k = args
    # divmod turns the flat argmax into the (row, col) position of the
    # largest-magnitude entry of the N x N matrix B.
    i, j = jnp.divmod(jnp.abs(B).argmax(), N)
    b = B[i, j]
    # Apply `step` (defined elsewhere) only while the pivot exceeds the
    # tolerance e; otherwise pass the state through unchanged.
    I, B, i, j = jax.lax.cond(jnp.abs(b) > e,
                              step,
                              lambda args: args,
                              operand=(I, B, i, j))
    return I, B, b, k + 1
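The divmod-on-argmax idiom above recovers (row, col) indices from a flat argmax; a standalone illustration:

import jax.numpy as jnp

B = jnp.array([[1., -7., 3.],
               [2.,  5., 0.],
               [4.,  1., 6.]])
N = B.shape[0]
i, j = jnp.divmod(jnp.abs(B).argmax(), N)   # flat index 1 -> (0, 1), the entry -7.0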
Example #4
        def split_top_k(split_queries: Array) -> Tuple[Array, Array, Array]:
            # Find the most similar clusters: scores have shape
            # [queries_per_split, num_prototypes].
            prototype_scores = jnp.einsum('qd,pd->qp', split_queries,
                                          prototypes)
            top_indices = jax.lax.top_k(prototype_scores, self.n_search)[1]
            # Perform approximate top-k similarity search over the most
            # similar clusters; selected_data has shape
            # [queries_per_split, n_search, rows_per_cluster, values_per_row, dim].
            selected_data = table[top_indices]
            split_scores = jnp.einsum('qd,qcrvd->qcrv', split_queries,
                                      selected_data)

            # Find highest scoring vector for each row.
            top_id_by_row = jnp.argmax(split_scores, axis=-1)
            top_score_by_row = jnp.max(split_scores, axis=-1)

            top_id_by_row = top_id_by_row.reshape(
                queries_per_split, self.n_search * rows_per_cluster)
            top_score_by_row = top_score_by_row.reshape(
                queries_per_split, self.n_search * rows_per_cluster)

            # Take k highest scores among all rows.
            top_row_idx = jnp.argsort(top_score_by_row,
                                      axis=-1)[:, :-self.k_top - 1:-1]

            # Sub-select best indices for k best rows.
            ids_by_topk_row = jut.matmul_slice(top_id_by_row, top_row_idx)

            # Gather highest scoring vectors for k best rows.
            query_index = jnp.tile(
                jnp.arange(queries_per_split).reshape(-1, 1), [1, self.k_top])
            top_cluster_idx, top_cluster_row_idx = jnp.divmod(
                top_row_idx, rows_per_cluster)
            split_topk_values = selected_data[query_index, top_cluster_idx,
                                              top_cluster_row_idx,
                                              ids_by_topk_row]

            row_offset = jnp.mod(
                jnp.arange(0, self.n_search * values_per_cluster,
                           values_per_row), values_per_cluster)
            cluster_offset = jnp.arange(0, table_size, values_per_cluster)

            # Convert row indices to indices into flattened table.
            top_table_id_by_row = top_id_by_row + row_offset.reshape(
                1, -1) + cluster_offset[top_indices].repeat(rows_per_cluster,
                                                            axis=-1)
            # Get best ids into flattened table.
            split_topk_ids = jut.matmul_slice(top_table_id_by_row, top_row_idx)

            split_topk_scores = jut.matmul_slice(top_score_by_row, top_row_idx)

            return split_topk_values, split_topk_scores, split_topk_ids
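jut.matmul_slice comes from the project's utility module and is not shown here; a plausible sketch (an assumption, not the project's code) gathers one entry per row via a one-hot matmul, a pattern that lowers to TPU-friendly ops:

import jax
import jax.numpy as jnp

def matmul_slice(array, idx):
    # array: [q, n], idx: [q, k]  ->  out[q, k] = array[q, idx[q, k]]
    one_hot = jax.nn.one_hot(idx, array.shape[-1], dtype=array.dtype)
    return jnp.einsum('qn,qkn->qk', array, one_hot)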
Example #5
    def matrix_power_multiply(self,
                              x,
                              power,
                              transpose=False,
                              precision=jax.lax.Precision.HIGHEST):
        """Computes matrix vector product jnp.linalg.matrix_power(M, power) @ x.

    Args:
      x: the matrix or vector to multiply with the matrix M.
      power: the power to raise M to. Note that this power must be less than or
        equal to max_power given to init_matrix_power_state.
      transpose: if True, computes the product with the transposed matrix,
        i.e. jnp.linalg.matrix_power(M.T, power) @ x.
      precision: precision with which matrix multiplications are performed.

    Returns:
      the matrix-vector product jnp.linalg.matrix_power(M, power) @ x.
    """
        chex.assert_rank(x, {1, 2})

        if transpose:
            base_matrix = self.base_matrix.T
        else:
            base_matrix = self.base_matrix

        # Square-and-multiply: consume the bits of `power` from least to
        # most significant, squaring z each step and folding the square
        # into r whenever the current bit is set.
        z = base_matrix
        n, bit = jnp.divmod(power, 2)
        r = jnp.where(bit, jnp.dot(z, x, precision=precision), x)

        def cond(state):
            n, _, _ = state
            return n > 0

        def body(state):
            n, z, r = state
            z = jnp.dot(z, z, precision=precision)
            n, bit = jnp.divmod(n, 2)
            r = jnp.where(bit, jnp.dot(z, r, precision=precision), r)
            return n, z, r

        _, _, result = jax.lax.while_loop(cond, body, (n, z, r))
        return result
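A quick sanity check of the square-and-multiply loop against the dense reference; M is a hypothetical stand-in for self.base_matrix:

import jax.numpy as jnp

M = jnp.array([[0., 1.],
               [1., 1.]])
x = jnp.array([1., 0.])
reference = jnp.linalg.matrix_power(M, 5) @ x
# matrix_power_multiply(x, 5) should agree with `reference`, using about
# log2(5) squarings rather than four sequential matmuls.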
Example #6
    def __iter__(self):
        num_states = self.env.num_states
        states = jax.nn.one_hot(jnp.arange(num_states), num_states)
        num_complete_batches, leftover = jnp.divmod(num_states,
                                                    self.config.batch_size)
        num_batches = num_complete_batches + bool(leftover)
        assert num_batches > 0

        while True:
            self.key, subkey = jax.random.split(self.key)
            perms = jax.random.permutation(subkey, num_states)

            for i in range(num_batches):
                batch_idx_states = perms[i * self.config.batch_size:(i + 1) *
                                         self.config.batch_size]
                # Skip the final short batch so every yielded batch has
                # the same shape.
                if len(batch_idx_states) != self.config.batch_size:
                    break

                inputs = states[batch_idx_states]
                targets = self.aux_task_matrix[batch_idx_states]

                yield inputs, targets
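Concretely, the divmod bookkeeping above with hypothetical sizes:

import jax.numpy as jnp

full, leftover = jnp.divmod(10, 4)           # -> (2, 2)
num_batches = int(full) + bool(leftover)     # 3; the short final batch is skipped above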
Example #7
def divmod(x1, x2):
  if isinstance(x1, JaxArray): x1 = x1.value
  if isinstance(x2, JaxArray): x2 = x2.value
  # jnp.divmod returns a (quotient, remainder) pair; wrap each part.
  quotient, remainder = jnp.divmod(x1, x2)
  return JaxArray(quotient), JaxArray(remainder)
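Assuming JaxArray is a thin wrapper exposing its payload as .value (as the isinstance checks above suggest), usage looks like:

x = JaxArray(jnp.arange(5))
q, r = divmod(x, 2)    # q.value: [0 0 1 1 2], r.value: [0 1 0 1 0]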
Example #8
    def fit_sgd(self,
                observations,
                targets,
                batch_size,
                rng_key=None,
                optimizer=None,
                num_epochs=1):
        '''
        Fits the class-conditional Bernoulli mixture model with gradient descent, using the given hyperparameters.
        Parameters
        ----------
        observations : array
            The observations on which the Bernoulli mixture model is trained
        targets : array
            The ground-truth classes
        batch_size : int
            The size of the batch
        rng_key : array
            Random key of shape (2,) and dtype uint32
        optimizer : jax.experimental.optimizers.Optimizer
            Optimizer to be used
        num_epochs : int
            The number of epochs the training runs for

        Returns
        -------
        * array
            Mean loss values found per epoch
        '''
        global opt_init, opt_update, get_params

        if rng_key is None:
            rng_key = PRNGKey(0)

        if optimizer is not None:
            opt_init, opt_update, get_params = optimizer

        opt_state = opt_init((logit(self.mixing_coeffs), logit(self.probs)))
        itercount = itertools.count()

        # Batches per epoch: complete batches plus one for any leftover.
        num_complete_batches, leftover = jnp.divmod(len(observations),
                                                    batch_size)
        num_batches = int(num_complete_batches) + (int(leftover) != 0)

        def epoch_step(opt_state, key):
            perm = permutation(key, len(observations))
            _observations, _targets = observations[perm], targets[perm]
            sample_generator = self._sample_minibatches(
                (_observations, _targets), batch_size)

            # lax.scan traces its body exactly once, so a Python generator
            # cannot be advanced step by step inside the scanned function;
            # unroll the per-batch loop in Python instead.
            losses = []
            for _ in range(num_batches):
                opt_state, loss = self.update(next(itercount), opt_state,
                                              next(sample_generator))
                losses.append(loss)
            return opt_state, jnp.stack(losses).mean()

        epochs = split(rng_key, num_epochs)
        opt_state, history = scan(epoch_step, opt_state, epochs)
        params = get_params(opt_state)
        mixing_coeffs_logits, probs_logits = params

        self.model = (mixing_coeffs_logits, probs_logits)
        self.mixing_coeffs = expit(mixing_coeffs_logits)
        self.probs = expit(probs_logits)
        return history
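The helper self._sample_minibatches is not shown; a minimal sketch consistent with how it is consumed above (an assumption, not the project's code):

    def _sample_minibatches(self, data, batch_size):
        # Yield successive (observations, targets) slices; the final batch
        # is shorter when the dataset size is not a multiple of batch_size.
        observations, targets = data
        for i in range(0, len(observations), batch_size):
            yield observations[i:i + batch_size], targets[i:i + batch_size]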