Exemple #1
0
 def grad_fn(g):
     """Map the upstream cotangent `g` to (logits, targets) cotangents.

     Uses `exp_shifted`, `sum_exp`, `targets` and `logits` captured from the
     enclosing scope. The logits cotangent is
     ``g[..., None] * (exp_shifted / sum_exp - targets)``; the ratio
     ``exp_shifted / sum_exp`` is presumably the softmax of the logits.

     NOTE(review): the targets cotangent returned here is ``g`` itself cast to
     the targets dtype, not ``g`` times d(loss)/d(targets); confirm this is
     intentional (e.g. targets are treated as non-differentiable).
     """
     probs_minus_targets = exp_shifted / sum_exp - targets
     g_logits = jnp.expand_dims(g, axis=-1) * probs_minus_targets
     return (jnp.asarray(g_logits, logits.dtype),
             jnp.asarray(g, targets.dtype))
Exemple #2
0
 def concrete_bound(relax_params):
   """Evaluate the concrete bound chunk for the single objective `one_obj`.

   `relax_params` is bound to `scanner.node_relaxations` through `self._bind`;
   the resulting object's `concrete_bound_chunk` is then evaluated on the
   captured `graph`, `inputs`, `env` and `node_ref`, with `one_obj` given a
   new leading axis of size 1.
   """
   bound_algorithm = self._bind(scanner.node_relaxations, relax_params)
   batched_obj = jnp.expand_dims(one_obj, 0)
   return bound_algorithm.concrete_bound_chunk(
       graph, inputs, env, node_ref, batched_obj)
    def dot_product_attention(
        self,
        query,
        key,
        value,
        dtype=jnp.float32,
        bias=None,
        axis=None,
        broadcast_dropout=True,
        dropout_rng=None,
        dropout_rate=0.0,
        deterministic=False,
        precision=None,
    ):
        """Computes kernelized (feature-map) dot-product attention.

        Queries and keys are mapped to feature tensors Q' and K' by
        `self.kernel_feature_creator` using `self.projection_matrix`. The
        output is then formed from Q', K' and V without building the full
        attention matrix:

        * bidirectional: W = Q'((K')^T V), contracting K' with V first;
        * unidirectional (`self.unidirectional`): a masked prefix computation
          along the first attention axis via the `_numerator` scan helper.

        If `self.renormalize_attention` is set, W is divided element-wise by
        the partition function R = Q'((K')^T 1_L) (or its prefix-scan
        analogue from `_denominator`). Values of R with magnitude at most
        `self.numerical_stabilizer` are shifted by twice that amount before
        taking the reciprocal.

        Args:
          query: query array. The inline comments describe the layout as
            (bs, dim1, ..., dimN, num_heads, channels) — TODO confirm with
            callers.
          key: key array; same rank as `query`, same leading (batch) size and
            the same last (channel) size.
          value: value array; same rank as `key` and the same shape except
            for the last axis.
          dtype: not read by this implementation.
          bias: not read by this implementation.
          axis: attention axis or axes (int or iterable of ints). Defaults to
            every axis strictly between the batch axis and the last two axes.
          broadcast_dropout: not read by this implementation.
          dropout_rng: not read by this implementation.
          dropout_rate: not read by this implementation.
          deterministic: not read by this implementation.
          precision: forwarded to `lax.dot_general` and the scan helpers.

        Returns:
          The attention output, transposed back to the input axis order, with
          the last axis sized like `value`'s last axis.

        Raises:
          AssertionError: if the shape/rank checks on query/key/value fail.
          ValueError: if an attention axis is the batch axis or one of the
            last two axes.

        Side effects:
          If `self.redraw_features` is set, `self.projection_matrix` is
          replaced with `self.draw_weights(rng)`, where `rng` is seeded from
          the contents of `query`.
        """
        assert key.shape[:-1] == value.shape[:-1]
        assert query.shape[0:1] == key.shape[0:1] and query.shape[
            -1] == key.shape[-1]
        if axis is None:
            axis = tuple(range(1, key.ndim - 2))
        if not isinstance(axis, Iterable):
            axis = (axis, )
        # Normalize to a tuple: the permutation arithmetic below concatenates
        # `axis` with other tuples. That raises for a list, broadcast-adds for
        # a NumPy array, and a one-shot iterator would be consumed by the
        # validation loop.
        axis = tuple(axis)
        assert key.ndim == query.ndim
        assert key.ndim == value.ndim
        for ax in axis:
            if not (query.ndim >= 3 and 1 <= ax < query.ndim - 2):
                raise ValueError(
                    "Attention axis must be between the batch axis and the last-two axes."
                )
        n = key.ndim

        # Constructing projection tensor.
        if self.redraw_features:
            # TODO(kchoro): Get rid of the constant below.
            query_seed = lax.convert_element_type(
                jnp.ceil(jnp.sum(query) * 10000000.0), jnp.int32)
            rng = random.PRNGKey(query_seed)
            self.projection_matrix = self.draw_weights(rng)

        # batch_dims is  <bs, <non-attention dims>, num_heads>
        batch_dims = tuple(onp.delete(range(n), axis + (n - 1, )))
        # q & k -> (bs, <non-attention dims>, num_heads, <attention dims>, channels)
        qk_perm = batch_dims + axis + (n - 1, )
        k_extra_perm = axis + batch_dims + (n - 1, )
        key_extra = key.transpose(k_extra_perm)
        key = key.transpose(qk_perm)
        query = query.transpose(qk_perm)
        # v -> (bs, <non-attention dims>, num_heads, <attention dims>, channels)
        v_perm = batch_dims + axis + (n - 1, )
        value = value.transpose(v_perm)
        batch_dims_t = tuple(range(len(batch_dims)))
        attention_dims_t = tuple(
            range(len(batch_dims),
                  len(batch_dims) + len(axis)))

        # Constructing tensors Q^{'} and K^{'}.
        query_prime = self.kernel_feature_creator(query,
                                                  self.projection_matrix,
                                                  attention_dims_t,
                                                  batch_dims_t, precision,
                                                  True)
        key_prime = self.kernel_feature_creator(key, self.projection_matrix,
                                                attention_dims_t, batch_dims_t,
                                                precision, False)

        if self.unidirectional:
            # The prefix scan runs along the first attention axis only.
            index = attention_dims_t[0]
            z_slice_shape = key_prime.shape[0:len(batch_dims_t)] + (
                key_prime.shape[-1], ) + (value.shape[-1], )

            numerator_fn = _numerator(z_slice_shape, precision,
                                      self.lax_scan_unroll)
            W = numerator_fn(jnp.moveaxis(query_prime, index, 0),
                             jnp.moveaxis(key_prime, index, 0),
                             jnp.moveaxis(value, index, 0))

            # Constructing W = (Q^{'}(K^{'})^{T})_{masked}V
            W = jnp.moveaxis(W, 0, index)

            if not self.renormalize_attention:
                # Unidirectional, not-normalized attention.
                perm_inv = _invert_perm(qk_perm)
                result = W.transpose(perm_inv)
                return result
            else:
                # Unidirectional, normalized attention: the partition
                # function R comes from the prefix-scan denominator helper.
                t_slice_shape = key_prime.shape[0:len(batch_dims_t)] + (
                    key_prime.shape[-1], )
                denominator_fn = _denominator(t_slice_shape, precision,
                                              self.lax_scan_unroll)
                R = denominator_fn(jnp.moveaxis(query_prime, index, 0),
                                   jnp.moveaxis(key_prime, index, 0))

                R = jnp.moveaxis(R, 0, index)
        else:
            contract_query = tuple(
                range(
                    len(batch_dims) + len(axis),
                    len(batch_dims) + len(axis) + 1))
            contract_z = tuple(range(len(batch_dims), len(batch_dims) + 1))
            # Constructing Z = (K^{'})^{T}V
            # Z (bs, <non-attention dims>, num_heads, channels_m, channels_v)
            Z = lax.dot_general(
                key_prime,
                value,
                ((attention_dims_t, attention_dims_t),
                 (batch_dims_t, batch_dims_t)),
                precision=precision,
            )
            # Constructing W = Q^{'}Z = Q^{'}(K^{'})^{T}V
            # q (bs, <non-attention dims>, num_heads, <attention dims>, channels_m)
            # Z (bs, <non-attention dims>, num_heads, channels_m, channels_v)
            # W (bs,  <non-attention dims>, num_heads, <attention dims>, channels_v)
            W = lax.dot_general(query_prime,
                                Z, ((contract_query, contract_z),
                                    (batch_dims_t, batch_dims_t)),
                                precision=precision)
            if not self.renormalize_attention:
                # Bidirectional, not-normalized attention.
                perm_inv = _invert_perm(qk_perm)
                result = W.transpose(perm_inv)
                return result
            else:
                # Bidirectional, normalized attention.
                # Ones shaped like `key` minus its channel axis, broadcast
                # against the attention-axis extents of `key_extra`.
                thick_all_ones = jnp.zeros(key.shape[0:-1]) + jnp.ones(
                    key_extra.shape[0:len(axis)])
                contract_key = tuple(
                    range(len(batch_dims),
                          len(batch_dims) + len(axis)))
                contract_thick_all_ones = tuple(
                    range(thick_all_ones.ndim - len(axis),
                          thick_all_ones.ndim))
                # Construct T = (K^{'})^{T} 1_L
                # k (bs, <non-attention dims>, num_heads, <attention dims>, channels)
                T = lax.dot_general(
                    key_prime,
                    thick_all_ones,
                    ((contract_key, contract_thick_all_ones),
                     (batch_dims_t, batch_dims_t)),
                    precision=precision,
                )

                # Construct partition function: R = Q^{'} T = Q^{'}(K^{'})^{T} 1_L
                # q_p (bs, <non-attention dims>, num_heads, <attention dims>, channs_m)
                # T   (bs, <non-attention dims>, num_heads, channels_m)
                R = lax.dot_general(
                    query_prime,
                    T,
                    (((query_prime.ndim - 1, ), (T.ndim - 1, )),
                     (batch_dims_t, range(0,
                                          len(T.shape) - 1))),
                    precision=precision,
                )

        # Push |R| <= stabilizer away from zero before taking the reciprocal.
        R = R + 2 * self.numerical_stabilizer * (jnp.abs(R) <=
                                                 self.numerical_stabilizer)
        R = jnp.reciprocal(R)
        R = jnp.expand_dims(R, len(R.shape))
        # W (bs, <non-attention dims>, num_heads, <attention dims>, channels_v)
        # R (bs, <non-attention dims>, num_heads, <attention dims>, extra_channel)
        result = W * R
        # back to (bs, dim1, dim2, ..., dimN, num_heads, channels)
        perm_inv = _invert_perm(qk_perm)
        result = result.transpose(perm_inv)
        return result
Exemple #4
0
    def apply(self, inputs, info, config, train=False, cache=None):
        """Apply the full IPAGNN model to a batch of input programs.

        Each program is run with a soft "instruction pointer" over its
        statement-level control-flow graph. On every step, all nodes are
        executed by a stacked LSTM over their statement tokens. A soft
        true/false branch decision is predicted per node. Instruction-pointer
        mass and hidden states are then redistributed along the true/false
        edges. The exit node's final hidden state is projected to logits.

        Args:
          inputs: A dictionary with the following fields, each with a leading
            batch dimension.
            - true_branch_nodes: For each node in the statement-level control
                flow graph, the index of the node that would be reached if the
                true branch were followed. If not a branch node, this is simply
                the index of the next node and matches the corresponding entry
                of false_branch_nodes.
            - false_branch_nodes: For each node in the statement-level control
                flow graph, the index of the node that would be reached if the
                false branch were followed. If not a branch node, this is
                simply the index of the next node and matches the
                corresponding entry of true_branch_nodes.
            - start_index: The node index where the function starts. Currently
                not used: execution always starts at node 0.
            - exit_index: The node index of the exit-node. Both the true- and
                false- index of the exit node are the exit node itself.
            - steps: The maximum number of model steps to take for a particular
                program. Steps beyond `int(1.5 * info.max_diameter)` are never
                taken.
            - data: Statement tokens, cast to int32, with shape
                (batch_size, num_nodes, statement_length). Each node's row
                represents a single statement; the meaning of each entry is
                described in Figure 1 of the paper.
          info: Information about the dataset (reads `features`, `_builder`,
            `output_vocab_size` and `max_diameter`).
          config: The experimental config.
          train: (bool) Whether the model is being trained. Currently unused.
          cache: Unused.

        Returns:
          The logits predicted from each program's exit node, with shape
          (batch_size, 1, info.output_vocab_size).
        """
        # Inputs
        true_indexes = inputs['true_branch_nodes']
        false_indexes = inputs['false_branch_nodes']
        start_indexes = inputs['start_index']  # pylint: disable=unused-variable
        exit_indexes = inputs['exit_index']
        steps_all = inputs['steps']
        vocab_size = info.features[info._builder.key('statements')].vocab_size  # pylint: disable=protected-access
        output_token_vocabulary_size = info.output_vocab_size
        hidden_size = config.model.hidden_size
        data = inputs['data'].astype('int32')
        batch_size, num_nodes, unused_statement_length = data.shape

        # An upper bound on the number of steps to take.
        max_steps = int(1.5 * info.max_diameter)

        # Init parameters
        def emb_init(key, shape, dtype=jnp.float32):
            return jax.random.uniform(key, shape, dtype,
                                      -config.initialization.maxval,
                                      config.initialization.maxval)

        embed = Embed.shared(num_embeddings=vocab_size,
                             features=hidden_size,
                             emb_init=emb_init,
                             name='embed')
        branch_decide_dense = nn.Dense.shared(
            name='branch_decide_dense',
            features=2,
            kernel_init=nn.initializers.xavier_uniform(),
            bias_init=nn.initializers.normal(stddev=1e-6))
        cells = create_lstm_cells(config.model.rnn_cell.layers)
        lstm = StackedRNNCell.shared(cells=cells)
        output_dense = nn.Dense.shared(
            name='output_dense',
            features=output_token_vocabulary_size,
            kernel_init=nn.initializers.xavier_uniform(),
            bias_init=nn.initializers.normal(stddev=1e-6))

        # Init state
        def _create_hidden_states():
            # A fixed PRNG key is used; presumably the carry initializer is
            # deterministic (e.g. zeros) — TODO confirm.
            rng = jax.random.PRNGKey(0)
            return StackedRNNCell.initialize_carry(rng, cells, (
                batch_size,
                num_nodes,
            ), hidden_size)

        def _create_instruction_pointer():
            # All pointer mass starts on node 0 of every program.
            # TODO(dbieber): Use "start_index" instead of 0.
            # `.at[...].add` replaces the removed `jax.ops.index_add`.
            return jnp.zeros((batch_size, num_nodes)).at[:, 0].add(1)

        hidden_states = _create_hidden_states()
        # leaves(hidden_states).shape: batch_size, num_nodes, hidden_size
        instruction_pointer = _create_instruction_pointer()
        # instruction_pointer.shape: batch_size, num_nodes
        node_embeddings = embed(data)

        # node_embeddings.shape:
        #     batch_size, num_nodes, statement_length, hidden_size

        # Apply
        def execute_single_node(hidden_state, node_embedding):
            # Run the LSTM over one statement's tokens; keep the final carry.
            carry, _ = lax.scan(lstm, hidden_state, node_embedding)
            return carry

        # Vectorized over nodes.
        execute = jax.vmap(execute_single_node)

        def branch_decide_single_node(hidden_state):
            # leaves(hidden_state).shape: hidden_size
            hidden_state_concat = jnp.concatenate(
                jax.tree_util.tree_leaves(hidden_state), axis=0)
            return branch_decide_dense(hidden_state_concat)

        branch_decide = jax.vmap(branch_decide_single_node)

        def update_instruction_pointer(instruction_pointer, branch_decisions,
                                       true_indexes, false_indexes):
            """Move each node's pointer mass to its true/false successors."""
            # instruction_pointer.shape: num_nodes,
            # branch_decisions: num_nodes, 2,
            # true_indexes: num_nodes,
            # false_indexes: num_nodes
            p_true = branch_decisions[:, 0]
            p_false = branch_decisions[:, 1]
            true_contributions = jax.ops.segment_sum(p_true *
                                                     instruction_pointer,
                                                     true_indexes,
                                                     num_segments=num_nodes)
            false_contributions = jax.ops.segment_sum(p_false *
                                                      instruction_pointer,
                                                      false_indexes,
                                                      num_segments=num_nodes)
            return true_contributions + false_contributions

        def aggregate(hidden_states, instruction_pointer, branch_decisions,
                      true_indexes, false_indexes):
            """Pointer-weighted average of the hidden states flowing into each node."""
            # leaves(hidden_states).shape: num_nodes, hidden_size
            # instruction_pointer.shape: num_nodes,
            # branch_decisions: num_nodes, 2,
            # true_indexes: num_nodes,
            # false_indexes: num_nodes
            p_true = branch_decisions[:, 0]
            p_false = branch_decisions[:, 1]
            denominators = update_instruction_pointer(instruction_pointer,
                                                      branch_decisions,
                                                      true_indexes,
                                                      false_indexes)
            # Guard against division by zero for nodes receiving no mass.
            denominators += 1e-7

            # denominators.shape: num_nodes,

            def aggregate_component(h):
                # h.shape: num_nodes
                # p_true.shape: num_nodes
                # instruction_pointer.shape: num_nodes
                true_contributions = jax.ops.segment_sum(
                    h * p_true * instruction_pointer,
                    true_indexes,
                    num_segments=num_nodes)
                false_contributions = jax.ops.segment_sum(
                    h * p_false * instruction_pointer,
                    false_indexes,
                    num_segments=num_nodes)
                # *_contributions.shape: num_nodes (one hidden unit at a time)
                return (true_contributions +
                        false_contributions) / denominators

            # Map over the hidden dimension (axis 1) of each leaf.
            aggregate_component = jax.vmap(aggregate_component,
                                           in_axes=1,
                                           out_axes=1)

            return jax.tree_util.tree_map(aggregate_component, hidden_states)

        def step_single_example(hidden_states, instruction_pointer,
                                node_embeddings, true_indexes, false_indexes,
                                exit_index):
            # Execution (e.g. apply RNN)
            # leaves(hidden_states).shape: num_nodes, hidden_size
            # instruction_pointer.shape: num_nodes,
            # node_embeddings.shape: num_nodes, statement_length, hidden_size
            hidden_state_contributions = execute(hidden_states,
                                                 node_embeddings)

            # leaves(hidden_state_contributions).shape: num_nodes, hidden_size

            # Use the exit node's hidden state as its hidden state contribution
            # to avoid "executing" the exit node.
            def mask_h(h_contribution, h):
                return h_contribution.at[exit_index, :].set(h[exit_index, :])

            # NOTE: jax.tree_multimap is removed in newer JAX releases; switch
            # to jax.tree_util.tree_map(f, *trees) once the pinned JAX version
            # supports multi-tree tree_map.
            hidden_state_contributions = jax.tree_multimap(
                mask_h, hidden_state_contributions, hidden_states)

            # Branch decisions (e.g. Dense layer)
            branch_decision_logits = branch_decide(hidden_state_contributions)
            branch_decisions = nn.softmax(branch_decision_logits, axis=-1)

            # Update state
            instruction_pointer_new = update_instruction_pointer(
                instruction_pointer, branch_decisions, true_indexes,
                false_indexes)
            hidden_states_new = aggregate(hidden_state_contributions,
                                          instruction_pointer,
                                          branch_decisions, true_indexes,
                                          false_indexes)

            to_tag = {
                'branch_decisions': branch_decisions,
                'hidden_state_contributions': hidden_state_contributions,
                'hidden_states_before': hidden_states,
                'hidden_states': hidden_states_new,
                'instruction_pointer_before': instruction_pointer,
                'instruction_pointer': instruction_pointer_new,
                'true_indexes': true_indexes,
                'false_indexes': false_indexes,
            }
            return hidden_states_new, instruction_pointer_new, to_tag

        def compute_logits_single_example(hidden_states, instruction_pointer,
                                          exit_index, steps, node_embeddings,
                                          true_indexes, false_indexes):
            """single_example refers to selecting a single exit node hidden state."""

            # leaves(hidden_states).shape: num_nodes, hidden_size

            def step_(carry, _):
                hidden_states, instruction_pointer, index = carry
                hidden_states_new, instruction_pointer_new, to_tag = (
                    step_single_example(hidden_states, instruction_pointer,
                                        node_embeddings, true_indexes,
                                        false_indexes, exit_index))
                # Commit the new state only while index < steps; afterwards
                # the state is frozen while the counter keeps advancing.
                carry = jax.tree_multimap(
                    lambda new, old, index=index: jnp.where(
                        index < steps, new, old),
                    (hidden_states_new, instruction_pointer_new, index + 1),
                    (hidden_states, instruction_pointer, index + 1),
                )
                return carry, to_tag

            if config.model.ipagnn.checkpoint and not self.is_initializing():
                # Rematerialize each step during the backward pass.
                step_ = jax.checkpoint(step_)

            carry = (hidden_states, instruction_pointer, jnp.array([0]))
            (hidden_states, instruction_pointer,
             _), to_tag = lax.scan(step_, carry, None, length=max_steps)

            final_state = jax.tree_util.tree_map(lambda hs: hs[exit_index],
                                                 hidden_states)
            # leaves(final_state).shape: hidden_size
            final_state_concat = jnp.concatenate(
                jax.tree_util.tree_leaves(final_state), axis=0)
            logits = output_dense(final_state_concat)
            to_tag.update({
                'instruction_pointer_final': instruction_pointer,
                'hidden_states_final': hidden_states,
            })
            return logits, to_tag

        compute_logits = jax.vmap(compute_logits_single_example,
                                  in_axes=(0, 0, 0, 0, 0, 0, 0))

        logits, to_tag = compute_logits(hidden_states, instruction_pointer,
                                        exit_indexes, steps_all,
                                        node_embeddings, true_indexes,
                                        false_indexes)
        for key, value in to_tag.items():
            # Called only for its side effect (presumably recording `value`
            # under `key`); the return value is not needed.
            Tag(value, name=key)
        logits = jnp.expand_dims(logits, axis=1)
        return logits
def _cov_helper_with_p(data, p):
    """Weighted Gram matrix ``data^H diag(p) data`` with a leading unit axis.

    Args:
      data: 2-D array; each row is weighted by the matching entry of `p`.
      p: 1-D weights, one per row of `data`.

    Returns:
      Array of shape (1, data.shape[1], data.shape[1]).
    """
    weighted_rows = jnp.multiply(p[:, None], data)
    data_conj_t = jnp.conj(jnp.transpose(data))
    gram = jnp.matmul(data_conj_t, weighted_rows)
    return jnp.expand_dims(gram, axis=0)
Exemple #6
0
 def expand_dims(self: TensorType, axis: int) -> TensorType:
     """Return a tensor of the same wrapper type with a size-1 axis inserted at `axis`."""
     expanded_raw = np.expand_dims(self.raw, axis=axis)
     wrapper_cls = type(self)
     return wrapper_cls(expanded_raw)
Exemple #7
0
def test_input_admin(t, y, r, t_test, y_test, r_test):
    """
    TODO: tidy this function up
    Sort the training and test inputs, merge them into one time-ordered grid,
    and record where the training and test points sit within that grid.
    Rows whose first time coordinate is NaN (e.g. the placeholder test row
    created when t_test is None) are dropped. Duplicate times are NOT removed.
    :param t: training inputs [N, 1] (1-D input is promoted to [N, 1])
    :param y: observations at the training inputs [N, R] (1-D input is promoted to [N, 1])
    :param r: training spatial inputs, or None (then all-NaN, shaped like t)
    :param t_test: testing inputs [N*, 1], or None for no test points
    :param y_test: observations at the test inputs [N*, R], or None
    :param r_test: test spatial inputs, or None (then all-NaN, shaped like t_test)
    :return: a tuple of eight arrays, in this order:
        t_all: the combined and sorted training and test inputs [N + N*, D], D = t's column count
        y_all: y_train at training rows; y_test (or NaN when y_test is None) at test rows [N + N*, R]
        r_all: spatial inputs in the same row order as t_all; test rows are NaN only when
            r_test's column count differs from r's
        r_test: the test spatial inputs after sorting (a NaN placeholder when t_test or r_test is None)
        dt_all: step sizes of the combined grid, Δtₙ = tₙ - tₙ₋₁, with Δt₀ = 0 [N + N*]
        train_id: positions of the training inputs within t_all [N]
        test_id: positions of the test inputs within t_all [N*]
        mask: boolean array shaped like y_all; False at training rows, True elsewhere
    """
    assert t.shape[0] == y.shape[0]
    if t.ndim < 2:
        t = np.expand_dims(t, 1)  # make 2-D
    if y.ndim < 2:
        y = np.expand_dims(y, 1)  # make 2-D
    if r is None:
        r = np.nan * t  # np.empty((1,) + x.shape[1:]) * np.nan
    if r.ndim < 2:
        r = np.expand_dims(r, 1)  # make 2-D
    # Sort the training set by its first time coordinate.
    ind = np.argsort(t[:, 0], axis=0)
    t_train = t[ind, ...]
    y_train = y[ind, ...]
    r_train = r[ind, ...]
    if t_test is None:
        # Single all-NaN placeholder test row; it is dropped by keep_ind below.
        t_test = np.empty((1, ) + t_train.shape[1:]) * np.nan
        r_test = np.empty((1, ) + t_train.shape[1:]) * np.nan
    else:
        if t_test.ndim < 2:
            t_test = np.expand_dims(t_test, 1)  # make 2-D
        test_sort_ind = np.argsort(t_test[:, 0], axis=0)
        t_test = t_test[test_sort_ind, ...]
        if y_test is not None:
            y_test = y_test[test_sort_ind, ...].reshape((-1, ) + y.shape[1:])
        if r_test is not None:
            r_test = r_test[test_sort_ind, ...]
        else:
            r_test = np.nan * t_test
    # Pad t_test with NaN columns so it can be stacked under t_train
    # (keeps only its first column).
    if not (t_test.shape[1] == t_train.shape[1]):
        t_test = np.concatenate([
            t_test[:, 0][:, None],
            np.nan * np.empty([t_test.shape[0], t_train.shape[1] - 1])
        ],
                                axis=1)
    # here we use non-JAX numpy to sort out indexing of these static arrays
    t_train_test = np.concatenate([t_train, t_test])
    # Drop rows with a NaN time (e.g. the placeholder test row).
    keep_ind = ~np.isnan(t_train_test[:, 0])
    t_train_test = t_train_test[keep_ind, ...]
    if r_test.shape[1] != r_train.shape[
            1]:  # do spatial test points have different dimensionality to training points?
        r_test_nan = np.nan * np.zeros([r_test.shape[0], r_train.shape[1]])
    else:
        r_test_nan = r_test
    r_train_test = np.concatenate([r_train, r_test_nan])
    r_train_test = r_train_test[keep_ind, ...]
    t_ind = np.argsort(t_train_test[:, 0])
    t_all = t_train_test[t_ind]
    r_all = r_train_test[t_ind]
    # Inverse permutation: reverse_ind[k] is the row of t_all holding row k of t_train_test.
    reverse_ind = np.argsort(t_ind)
    # NOTE(review): assumes no training row was dropped by keep_ind, i.e. t has no NaN times.
    n_train = t_train.shape[0]
    train_id = reverse_ind[:n_train]  # index the training locations
    test_id = reverse_ind[n_train:]  # index the test locations
    y_all = np.nan * np.zeros([
        t_all.shape[0], y_train.shape[1]
    ])  # observation vector with nans at test locations
    # NOTE(review): if index_update is jax.ops.index_update, it returns a JAX array
    # (float32 unless x64 is enabled) and is removed in recent JAX; confirm
    # precision and the JAX version before relying on the float64 outputs below.
    # y_all[reverse_ind[:n_train], ...] = y_train  # and the data at the train locations
    y_all = index_update(y_all, index[reverse_ind[:n_train]],
                         y_train)  # and the data at the train locations
    if y_test is not None:
        # y_all[reverse_ind[n_train:], ...] = y_test  # and the data at the test locations
        y_all = index_update(y_all, index[reverse_ind[n_train:]],
                             y_test)  # and the data at the test locations
    # True everywhere except the training rows.
    mask = np.ones_like(y_all, dtype=bool)
    # mask[train_id] = False
    mask = index_update(mask, index[train_id], False)
    dt_all = np.concatenate([np.array([0.0]), np.diff(t_all[:, 0])])
    return (np.array(t_all, dtype=np.float64), np.array(y_all,
                                                        dtype=np.float64),
            np.array(r_all,
                     dtype=np.float64), np.array(r_test, dtype=np.float64),
            np.array(dt_all,
                     dtype=np.float64), np.array(train_id, dtype=np.int64),
            np.array(test_id, dtype=np.int64), np.array(mask, dtype=bool))
Exemple #8
0
def add_batch_dim(values: types.Nest) -> types.NestedArray:
    """Prepend a size-1 batch axis to every leaf of a nested structure.

    Args:
      values: an arbitrarily nested structure of array-likes.

    Returns:
      The same structure, with each leaf replaced by
      ``jnp.expand_dims(leaf, axis=0)``.
    """

    def _prepend_axis(leaf):
        return jnp.expand_dims(leaf, axis=0)

    return tree_util.tree_map(_prepend_axis, values)
Exemple #9
0
# Basis (and derivative basis) evaluated at the first and last collocation points.
Hs0 = Hs(stfc.z[0:1])
Hsf = Hs(stfc.z[-1:])

pHs0 = pHs(stfc.z[0:1])
pHsf = pHs(stfc.z[-1:])

Hc = ctfc.H

# --- Assumed solution: collocation grid and its end points -----------------
z = stfc.z
z0 = z[0]
zf = z[-1]


# --- Switching functions ----------------------------------------------------
# Cubic polynomials in `a`, returned with a trailing axis added by
# np.expand_dims(..., 1). Checked algebraically from the formulas:
#   phi1: value 1 at z0, 0 at zf, zero slope at both ends
#   phi2: value 0 at z0, 1 at zf, zero slope at both ends (phi1 + phi2 == 1)
#   phi3: value 0 at both ends, slope 1 at z0, slope 0 at zf
#   phi4: value 0 at both ends, slope 0 at z0, slope 1 at zf
def phi1(a):
    return np.expand_dims(
        1./(zf-z0)**3 * (-zf**2*(3.*z0-zf) + 6.*z0*zf*a
                         - 3.*(z0+zf)*a**2 + 2.*a**3), 1)


def phi2(a):
    return np.expand_dims(
        1./(zf-z0)**3 * (-z0**2*(z0-3.*zf) - 6.*z0*zf*a
                         + 3.*(z0+zf)*a**2 - 2.*a**3), 1)


def phi3(a):
    return np.expand_dims(
        1./(zf-z0)**2 * (-z0*zf**2 + zf*(2.*z0+zf)*a
                         - (z0+2.*zf)*a**2 + a**3), 1)


def phi4(a):
    return np.expand_dims(
        1./(zf-z0)**2 * (-z0**2*zf + z0*(z0+2.*zf)*a
                         - (2.*z0+zf)*a**2 + a**3), 1)


# --- Constrained expression -------------------------------------------------
def r(z, xi, IC):
    """Free function ``Hs(z) @ xi['xis']`` corrected by the switching functions.

    At the grid ends this gives r(z0) = IC['R0'] and r(zf) = 0. Assuming
    `pHs` is the derivative of `Hs`, it also gives r'(z0) = IC['V0'] / IC['c']
    and r'(zf) = 0.
    """
    return (np.dot(Hs(z), xi['xis'])
            + phi1(z) * (IC['R0'] - np.dot(Hs0, xi['xis']))
            + phi2(z) * (-np.dot(Hsf, xi['xis']))
            + phi3(z) * (IC['V0'] / IC['c'] - np.dot(pHs0, xi['xis']))
            + phi4(z) * (-np.dot(pHsf, xi['xis'])))


# Derivative of r with respect to z (egrad is defined elsewhere).
v = egrad(r)
Exemple #10
0
 def __call__(self, x: jnp.ndarray):
     """Unroll the wrapped core once over `x`.

     The wrapped core's initial state is created for batch size ``x.shape[0]``,
     and every leaf is cast to ``x.dtype``. `x` then gets a new leading axis of
     size 1 (presumably a length-1 time axis) and is passed to
     ``self.unroller`` together with that state.
     """
     batch_size = x.shape[0]
     raw_state = self.wrapped.initial_state(batch_size=batch_size)
     initial_state = jax.tree_map(lambda leaf: leaf.astype(x.dtype), raw_state)
     single_step_inputs = jnp.expand_dims(x, axis=0)
     return self.unroller(self.wrapped, single_step_inputs, initial_state)
Exemple #11
0
    def __call__(
            self,
            pose_coeffs,
            betas=np.zeros(1),
    ):
        """Run the MANO hand layer on a batch of pose coefficients.

        Args:
          pose_coeffs: (B, D) pose parameters; the layout depends on the mode.
            If `self.use_pca` is set or `self.joint_rot_mode == "axisang"`, the
            first `self.rot` columns are the global rotation, followed by
            `self.ncomps` hand-pose (PCA or axis-angle) coefficients.
            "rotmat" passes `pose_coeffs` to `self._batch_rotprojs`.
            "quat" expects (B, 64) values reshaped to (B, 16, 4) quaternions.
          betas: Shape coefficients. If None or of size 1, `self.betas` is
            used instead and the joints are tiled over the batch. Otherwise
            the value is transposed as a 2-D (B, n_betas) array — TODO confirm
            with callers.

        Returns:
          A tuple starting with (verts, jtr): vertices (B, V, 3) and 21 joints
          (B, 21, 3). Both are translated so that joint `self.center_idx` is at
          the origin (no translation when center_idx is None). The (B, 16, 4, 4)
          global joint transforms are appended if `self.return_transf`, and
          `full_pose` is appended if `self.return_full_pose`.

        Raises:
          KeyError: if `self.use_pca` is false and `self.joint_rot_mode` is not
            "axisang", "rotmat" or "quat".
        """
        batch_size = pose_coeffs.shape[0]
        if self.use_pca or self.joint_rot_mode == "axisang":
            # Get axis angle from PCA components and coefficients
            # Remove global rot coeffs
            hand_pose_coeffs = pose_coeffs[:, self.rot:self.rot + self.ncomps]
            if self.use_pca:
                full_hand_pose = hand_pose_coeffs @ self.selected_comps
            else:
                full_hand_pose = hand_pose_coeffs

            # Concatenate back global rot
            full_pose = np.concatenate(
                (pose_coeffs[:, :self.rot], self.hands_mean + full_hand_pose),
                1)
            if self.root_rot_mode == "axisang":
                # compute rotation matrices from axis-angle while skipping global rotation
                pose_map, rot_map = self._posemap_axisang(full_pose)
                root_rot = rot_map[:, :9].reshape(batch_size, 3, 3)
                rot_map = rot_map[:, 9:]
                pose_map = pose_map[:, 9:]
            else:
                # th_posemap offsets by 3, so add offset or 3 to get to self.rot=6
                pose_map, rot_map = self._posemap_axisang(full_pose[:, 6:])
                if self.robust_rot:
                    root_rot = self._robust_compute_rotation_matrix_from_ortho6d(
                        full_pose[:, :6])
                else:
                    root_rot = self._compute_rotation_matrix_from_ortho6d(
                        full_pose[:, :6])
        elif self.joint_rot_mode == "rotmat":
            full_pose = pose_coeffs  # ! Dummy Assignment
            pose_rots = self._batch_rotprojs(pose_coeffs)
            rot_map = pose_rots[:, 1:].reshape((batch_size, -1))
            pose_map = self._subtract_flat_id(rot_map)
            root_rot = pose_rots[:, 0]
        elif self.joint_rot_mode == "quat":
            # we need th_rot_map, th_pose_map, root_rot
            # though do no assertion
            # th_pose_coeffs should be [B, 4 + 15 * 4] = [B, 64]
            full_pose = pose_coeffs  # ! Dummy Assignment
            pose_coeffs = pose_coeffs.reshape(
                (batch_size, 16, 4))  # [B, 16, 4]
            all_rots = quaternion_to_rotation_matrix(
                pose_coeffs)  # [B, 16, 3, 3]
            # flatten things out
            root_rot = all_rots[:, 0, :, :]  # [B, 3, 3]
            rot_map = all_rots[:, 1:, :].reshape(
                (batch_size, -1))  # [B, 15 * 9]
            pose_map = self._subtract_flat_id(rot_map)
        else:
            raise KeyError(
                "joint_rot_mode not found. shoule be one of 'axisang' or 'rotmat' or 'quat'. got {}"
                .format(self.joint_rot_mode))

        # Full axis angle representation with root joint
        if betas is None or betas.size == 1:
            # Shape from the layer's own `self.betas`. The joints are tiled
            # explicitly to the batch size. numpy.ndarray has no `.tile`
            # method, so use the `np.tile` function (as done for
            # root_trans_flt below).
            v_shaped = np.matmul(self.shapedirs, self.betas.transpose(
                1, 0)).transpose((2, 0, 1)) + self.v_template
            j = np.tile(np.matmul(self.J_regressor, v_shaped),
                        (batch_size, 1, 1))
        else:
            v_shaped = np.matmul(self.shapedirs, betas.transpose(
                (1, 0))).transpose((2, 0, 1)) + self.v_template
            j = np.matmul(self.J_regressor, v_shaped)
            # th_pose_map should have shape 20x135

        v_posed = v_shaped + np.matmul(
            self.posedirs,
            pose_map.transpose((1, 0))[np.newaxis, ...]).transpose((2, 0, 1))
        # Final T pose with transformation done !

        # Global rigid transformation

        root_j = j[:, 0, :].reshape(batch_size, 3, 1)
        root_trans = self._with_zeros(np.concatenate((root_rot, root_j), 2))

        all_rots = rot_map.reshape(rot_map.shape[0], 15, 3, 3)
        # Joint indices per kinematic level (5 fingers each), base to tip.
        lev1_idxs = [1, 4, 7, 10, 13]
        lev2_idxs = [2, 5, 8, 11, 14]
        lev3_idxs = [3, 6, 9, 12, 15]
        lev1_rots = all_rots[:, [idx - 1 for idx in lev1_idxs]]
        lev2_rots = all_rots[:, [idx - 1 for idx in lev2_idxs]]
        lev3_rots = all_rots[:, [idx - 1 for idx in lev3_idxs]]
        lev1_j = j[:, lev1_idxs]
        lev2_j = j[:, lev2_idxs]
        lev3_j = j[:, lev3_idxs]

        # From base to tips
        # Get lev1 results
        all_transforms = [root_trans[:, np.newaxis, ...]]
        lev1_j_rel = lev1_j - root_j.transpose((0, 2, 1))
        lev1_rel_transform_flt = self._with_zeros(
            np.concatenate((lev1_rots, lev1_j_rel[..., np.newaxis]),
                           3).reshape(-1, 3, 4))
        root_trans_flt = np.tile(root_trans[:, np.newaxis, ...],
                                 (1, 5, 1, 1)).reshape(root_trans.shape[0] * 5,
                                                       4, 4)
        lev1_flt = np.matmul(root_trans_flt, lev1_rel_transform_flt)
        all_transforms.append(lev1_flt.reshape(all_rots.shape[0], 5, 4, 4))

        # Get lev2 results
        lev2_j_rel = lev2_j - lev1_j
        lev2_rel_transform_flt = self._with_zeros(
            np.concatenate((lev2_rots, lev2_j_rel[..., np.newaxis]),
                           3).reshape(-1, 3, 4))
        lev2_flt = np.matmul(lev1_flt, lev2_rel_transform_flt)
        all_transforms.append(lev2_flt.reshape(all_rots.shape[0], 5, 4, 4))

        # Get lev3 results
        lev3_j_rel = lev3_j - lev2_j
        lev3_rel_transform_flt = self._with_zeros(
            np.concatenate((lev3_rots, lev3_j_rel[..., np.newaxis]),
                           3).reshape(-1, 3, 4))
        lev3_flt = np.matmul(lev2_flt, lev3_rel_transform_flt)
        all_transforms.append(lev3_flt.reshape(all_rots.shape[0], 5, 4, 4))

        # Interleave [root, lev1 x5, lev2 x5, lev3 x5] back into joint order.
        reorder_idxs = [0, 1, 6, 11, 2, 7, 12, 3, 8, 13, 4, 9, 14, 5, 10, 15]
        results = np.concatenate(all_transforms, 1)[:, reorder_idxs]
        results_global = results

        joint_js = np.concatenate((j, np.zeros((j.shape[0], 16, 1))), 2)

        # Remove each joint's rest-pose position from its transform's translation.
        tmp2 = np.matmul(results, joint_js[..., np.newaxis])
        results2 = (results -
                    np.concatenate([np.zeros(
                        (*tmp2.shape[:2], 4, 3)), tmp2], 3)).transpose(
                            (0, 2, 3, 1))

        # Per-vertex blended transforms from the skinning weights.
        T = np.matmul(results2, self.weights.transpose((1, 0)))

        rest_shape_h = np.concatenate(
            (v_posed.transpose(
                (0, 2, 1)), np.ones((batch_size, 1, v_posed.shape[1]))),
            1,
        )

        verts = (T * rest_shape_h[:, np.newaxis, ...]).sum(2).transpose(
            (0, 2, 1))
        verts = verts[:, :, :3]
        jtr = results_global[:, :, :3, 3]
        # In addition to MANO reference joints we sample vertices on each finger
        # to serve as finger tips
        if self.side == "right":
            tips = verts[:, [745, 317, 444, 556, 673]]
        else:
            tips = verts[:, [745, 317, 445, 556, 673]]
        jtr = np.concatenate((jtr, tips), 1)

        # Reorder joints to match visualization utilities
        jtr = jtr[:, [
            0, 13, 14, 15, 16, 1, 2, 3, 17, 4, 5, 6, 18, 10, 11, 12, 19, 7, 8,
            9, 20
        ], ]

        # deal with center joint
        if self.center_idx is not None:
            center_joint = jtr[:, self.center_idx][:, np.newaxis, ...]
        else:  # ! Dummy Center Joint (B, 1, 3)
            center_joint = np.zeros_like(np.expand_dims(jtr[:, 0], 1))
        jtr = jtr - center_joint
        verts = verts - center_joint

        global_rot = results_global[:, :, :3, :3]  # (B, 16, 3, 3)
        global_t = results_global[:, :, :3, 3:]  # (B, 16, 3, 1)
        global_t = global_t - np.expand_dims(center_joint,
                                             -1)  # (B, [16], 3, 1)
        transf_global = np.concatenate([global_rot, global_t],
                                       axis=3)  # (B, 16, 3, 4)
        transf_global = self._with_zeros(transf_global.reshape((-1, 3, 4)))
        transf_global = transf_global.reshape((batch_size, 16, 4, 4))

        # Scale to millimeters
        # th_verts = th_verts * 1000
        # th_jtr = th_jtr * 1000
        results = [verts, jtr]  # (V, J)

        if self.return_transf:
            results = results + [transf_global]  # (V, J, T)
            if self.return_full_pose:
                results = results + [full_pose]  # (V, J, T, so3)
        elif self.return_full_pose:
            results = results + [full_pose]  # (V, J, so3)

        return tuple(results)
Exemple #12
0
def nonbonded_v3(
    conf,
    params,
    box,
    lamb,
    charge_rescale_mask,
    lj_rescale_mask,
    beta,
    cutoff,
    lambda_plane_idxs,
    lambda_offset_idxs,
    runtime_validate=True,
):
    """Lennard-Jones + Coulomb, with a few important twists:
    * distances are computed in 4D, controlled by lambda, lambda_plane_idxs, lambda_offset_idxs
    * each pairwise LJ and Coulomb term can be multiplied by an adjustable rescale_mask parameter
    * Coulomb terms are multiplied by erfc(beta * distance)

    Parameters
    ----------
    conf : (N, 3) or (N, 4) np.array
        3D or 4D coordinates
        if 3D, will be converted to 4D using (x,y,z) -> (x,y,z,w)
            where w = cutoff * (lambda_plane_idxs + lambda_offset_idxs * lamb)
    params : (N, 3) np.array
        columns [charges, sigmas, epsilons], one row per particle.
        Pair parameters are formed as sigma_i + sigma_j and eps_i * eps_j
        (no halving, no square root); presumably the per-particle values are
        already pre-scaled for these combining rules -- TODO confirm.
    box : Optional (3, 3) or (4, 4) np.array
        Periodic box. A 3x3 box is embedded in a 4x4 box whose 4th-dimension
        extent is 1000, so the w direction is effectively aperiodic.
        None disables periodic boundary handling.
    lamb : float
    charge_rescale_mask : (N, N) np.array
        the Coulomb contribution of pair (i,j) will be multiplied by charge_rescale_mask[i,j];
        must be symmetric (asserted when runtime_validate is True)
    lj_rescale_mask : (N, N) np.array
        the Lennard-Jones contribution of pair (i,j) will be multiplied by lj_rescale_mask[i,j];
        must be symmetric (asserted when runtime_validate is True)
    beta : float
        the charge product q_ij will be multiplied by erfc(beta*d_ij)
    cutoff : Optional float
        a pair of particles (i,j) will be considered non-interacting if the distance d_ij
        between their 4D coordinates exceeds cutoff
    lambda_plane_idxs : Optional (N,) np.array
    lambda_offset_idxs : Optional (N,) np.array
    runtime_validate: bool
        check that both rescale masks are symmetric and whether beta is compatible with cutoff
        (if True, this function will currently not play nice with Jax JIT)
        TODO: is there a way to conditionally print a runtime warning inside
            of a Jax JIT-compiled function, without triggering a Jax ConcretizationTypeError?

    Returns
    -------
    energy : float
        Sum of the rescaled LJ and Coulomb pair energies over all pairs i != j,
        each unordered pair counted once.

    References
    ----------
    * Rodinger, Howell, Pomès, 2005, J. Chem. Phys. "Absolute free energy calculations by thermodynamic integration in four spatial
        dimensions" https://aip.scitation.org/doi/abs/10.1063/1.1946750
    * Darden, York, Pedersen, 1993, J. Chem. Phys. "Particle mesh Ewald: An N log(N) method for Ewald sums in large
        systems" https://aip.scitation.org/doi/abs/10.1063/1.470117
        * Coulomb interactions are treated using the direct-space contribution from eq 2
    """
    # The full N x N energy matrix is summed and halved at the end, which is
    # only correct if every (i, j) and (j, i) entry carries the same scale.
    if runtime_validate:
        assert (charge_rescale_mask == charge_rescale_mask.T).all()
        assert (lj_rescale_mask == lj_rescale_mask.T).all()

    N = conf.shape[0]

    if conf.shape[-1] == 3:
        conf = convert_to_4d(conf, lamb, lambda_plane_idxs, lambda_offset_idxs,
                             cutoff)

    # make 4th dimension of box large enough so its roughly aperiodic
    if box is not None:
        if box.shape[-1] == 3:
            box_4d = np.eye(4) * 1000
            box_4d = index_update(box_4d, index[:3, :3], box)
        else:
            box_4d = box
    else:
        box_4d = None

    box = box_4d

    charges = params[:, 0]
    sig = params[:, 1]
    eps = params[:, 2]

    # Pairwise sigma: sum of per-particle values (see params note in docstring).
    sig_i = np.expand_dims(sig, 0)
    sig_j = np.expand_dims(sig, 1)
    sig_ij = sig_i + sig_j

    # Pairwise epsilon: product of per-particle values.
    eps_i = np.expand_dims(eps, 0)
    eps_j = np.expand_dims(eps, 1)

    eps_ij = eps_i * eps_j

    dij = distance(conf, box)

    # LJ mask: drop self-pairs and pairs whose combined epsilon is zero.
    keep_mask = np.ones((N, N)) - np.eye(N)
    keep_mask = np.where(eps_ij != 0, keep_mask, 0)

    if cutoff is not None:
        if runtime_validate:
            validate_coulomb_cutoff(cutoff, beta, threshold=1e-2)
        # Switch off LJ for pairs at or beyond the cutoff.
        eps_ij = np.where(dij < cutoff, eps_ij, 0)

    # (ytz): this avoids a nan in the gradient in both jax and tensorflow
    sig_ij = np.where(keep_mask, sig_ij, 0)
    eps_ij = np.where(keep_mask, eps_ij, 0)

    # Self-pairs have dij == 0; force their inverse distance to 0.
    inv_dij = 1 / dij
    inv_dij = np.where(np.eye(N), 0, inv_dij)

    # sig2 = (sigma/d)^2, sig6 = (sigma/d)^6
    sig2 = sig_ij * inv_dij
    sig2 *= sig2
    sig6 = sig2 * sig2 * sig2

    # 4 * eps * ((sigma/d)^12 - (sigma/d)^6)
    eij_lj = 4 * eps_ij * (sig6 - 1.0) * sig6
    eij_lj = np.where(keep_mask, eij_lj, 0)

    qi = np.expand_dims(charges, 0)  # (1, N)
    qj = np.expand_dims(charges, 1)  # (N, 1)
    qij = np.multiply(qi, qj)

    # (ytz): trick used to avoid nans in the diagonal due to the 1/dij term.
    # Note: keep_mask is redefined here and now excludes only self-pairs, so
    # Coulomb terms are kept even where epsilon is zero.
    keep_mask = 1 - np.eye(N)
    qij = np.where(keep_mask, qij, 0)
    dij = np.where(keep_mask, dij, 0)

    # funny enough lim_{x->0} erfc(x)/x = 0
    # erfc-damped Coulomb (direct-space Ewald term).
    eij_charge = np.where(keep_mask,
                          qij * erfc(beta * dij) * inv_dij,
                          0)  # zero out diagonals
    if cutoff is not None:
        # NOTE(review): LJ keeps pairs with dij < cutoff while this keeps
        # dij <= cutoff, so the two terms disagree at exactly dij == cutoff.
        eij_charge = np.where(dij > cutoff, 0, eij_charge)

    eij_total = eij_lj * lj_rescale_mask + eij_charge * charge_rescale_mask

    # Each unordered pair appears twice in the symmetric N x N matrix.
    return np.sum(eij_total / 2)
Exemple #13
0
    def __call__(
        self,
        hidden_states,
        attention_mask=None,
        deterministic: bool = True,
        output_attentions: bool = False,
    ):
        """Multi-head (optionally causal) self-attention over `hidden_states`.

        Args:
          hidden_states: input sequence; queries, keys and values are all
            projected from it.
          attention_mask: optional mask; positions with a value > 0 are
            attended to. Two singleton axes are inserted at -3 and -2 before
            it is used.
          deterministic: when False and `self.dropout > 0`, attention dropout
            is applied using the "dropout" RNG stream.
          output_attentions: also return the attention weights.

        Returns:
          `(attn_output, attn_weights)` if `output_attentions`, otherwise
          `(attn_output,)`.
        """
        projected = [
            projection(hidden_states)
            for projection in (self.q_proj, self.k_proj, self.v_proj)
        ]
        query, key, value = [self._split_heads(t) for t in projected]

        # Slice the precomputed causal mask so the last `q_len` query rows
        # line up with the first `k_len` key columns.
        causal_attention_mask = None
        if self.causal:
            q_len = query.shape[1]
            k_len = key.shape[1]
            causal_attention_mask = self.causal_mask[:, :, k_len - q_len:k_len,
                                                     :k_len]

        if attention_mask is not None:
            attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))
            if causal_attention_mask is not None:
                attention_mask = combine_masks(attention_mask,
                                               causal_attention_mask,
                                               dtype="i4")
        elif causal_attention_mask is not None:
            attention_mask = causal_attention_mask

        # Convert the mask into an additive bias: 0 where attended, -1e4 where masked.
        attention_bias = None
        if attention_mask is not None:
            attention_bias = lax.select(
                attention_mask > 0,
                jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
                jnp.full(attention_mask.shape, -1e4).astype(self.dtype),
            )

        apply_dropout = not deterministic and self.dropout > 0.0
        dropout_rng = self.make_rng("dropout") if apply_dropout else None

        attn_weights = dot_product_attention_weights(
            query,
            key,
            bias=attention_bias,
            dropout_rng=dropout_rng,
            dropout_rate=self.dropout,
            deterministic=deterministic,
            dtype=self.dtype,
            precision=None,
        )

        heads_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value)
        attn_output = self.out_proj(self._merge_heads(heads_output))

        if output_attentions:
            return (attn_output, attn_weights)
        return (attn_output, )
Exemple #14
0
# Basis functions H (and pH, presumably their z-derivatives -- confirm) evaluated
# at the first and last points of tfc.z; used below to enforce the boundary
# conditions of the constrained expression.
H0 = H(tfc.z[0:1])
Hf = H(tfc.z[-1:])

Hp0 = pH(tfc.z[0:1])
Hpf = pH(tfc.z[-1:])

## DEFINE THE ASSUMED SOLUTION: *****************************************************************************
z = tfc.z
z0 = z[0]
zf = z[-1]

# Boundary position and boundary velocity read from the parameter dict `xi`.
# NOTE(review): the velocity is divided by xi['b']**2, presumably to map physical
# time onto the z domain -- confirm against where 'b' is defined.
R0 = lambda xi: np.array([xi['X'],xi['Y'],xi['Z']]).flatten()
V0 = lambda xi: np.array([xi['dX'],xi['dY'],xi['dZ']]).flatten()/xi['b']**2

# Cubic switching functions on [z0, zf], returned as column vectors (len(a), 1)
# so they broadcast against 3-vectors:
#   phi1: value 1 at z0, 0 at zf, zero slope at both ends
#   phi2: value 0 at z0, 1 at zf, zero slope at both ends
#   phi3: value 0 at both ends, slope 1 at z0, slope 0 at zf
#   phi4: value 0 at both ends, slope 0 at z0, slope 1 at zf
phi1 = lambda a:\
    np.expand_dims(1./(zf-z0)**3 * (-zf**2*(3.*z0-zf) + 6.*z0*zf*a - 3.*(z0+zf)*a**2 + 2.*a**3),1)
phi2 = lambda a:\
    np.expand_dims(1./(zf-z0)**3 * (-z0**2*(z0-3.*zf) - 6.*z0*zf*a + 3.*(z0+zf)*a**2 - 2.*a**3),1)
phi3 = lambda a:\
    np.expand_dims(1./(zf-z0)**2 * (-z0*zf**2 + zf*(2.*z0+zf)*a - (z0 + 2.*zf)*a**2 + a**3),1)
phi4 = lambda a:\
    np.expand_dims(1./(zf-z0)**2 * (-z0**2*zf + z0*(z0+2.*zf)*a - (2.*z0 + zf)*a**2 + a**3),1)

## CONSTRUCT THE CONSTRAINED EXPRESSION *********************************************************************
# With the switching functions above, r(z0) = r(zf) = R0(xi) and, if pH is the
# derivative of H, r'(z0) = r'(zf) = V0(xi) for any free coefficients xi['xis'].
r = lambda z, xi: np.dot(H(z),xi['xis']) + phi1(z)*(R0(xi) - np.dot(H0, xi['xis'])) \
                                         + phi2(z)*(R0(xi) - np.dot(Hf, xi['xis'])) \
                                         + phi3(z)*(V0(xi) - np.dot(Hp0,xi['xis'])) \
                                         + phi4(z)*(V0(xi) - np.dot(Hpf,xi['xis']))


def r1(z, xi):
    """Distance from the point (-mu, 0, 0) to each row of r(z, xi) (m1 to (x,y,z))."""
    # Evaluate the constrained expression once instead of once per coordinate.
    rz = r(z, xi)
    return np.sqrt( (rz[:,0]+mu   )**2 + rz[:,1]**2 + rz[:,2]**2)


def r2(z, xi):
    """Distance from the point (1 - mu, 0, 0) to each row of r(z, xi) (m2 to (x,y,z))."""
    rz = r(z, xi)
    return np.sqrt( (rz[:,0]+mu-1.)**2 + rz[:,1]**2 + rz[:,2]**2)
    def __call__(
        self,
        hidden_states: jnp.ndarray,
        key_value_states: Optional[jnp.ndarray] = None,
        attention_mask: Optional[jnp.ndarray] = None,
        init_cache: bool = False,
        deterministic: bool = True,
    ) -> Tuple[jnp.ndarray]:
        """Input shape: Batch x Time x Channel

        Acts as self-attention when `key_value_states` is None, and as
        cross-attention (keys/values projected from `key_value_states`)
        otherwise.

        Args:
          hidden_states: sequence the queries are projected from.
          key_value_states: optional sequence for cross-attention keys/values.
          attention_mask: optional mask; positions with a value > 0 are
            attended to.
          init_cache: when causal, initialise the key/value cache even if it
            does not exist yet.
          deterministic: when False and `self.dropout > 0`, attention dropout
            is applied using the "dropout" RNG stream.

        Returns:
          `(attn_output, attn_weights)`.
        """
        batch_size = hidden_states.shape[0]

        # Keys/values come from `key_value_states` for cross-attention and from
        # the query sequence itself for self-attention.
        kv_source = hidden_states if key_value_states is None else key_value_states

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(kv_source)
        value_states = self.v_proj(kv_source)

        query_states, key_states, value_states = (
            self._split_heads(t)
            for t in (query_states, key_states, value_states))

        if self.causal:
            query_length = query_states.shape[1]
            key_length = key_states.shape[1]
            if self.has_variable("cache", "cached_key"):
                # With a cache: take `query_length` rows of the causal mask
                # starting at the current cache index, across the whole cache.
                cache_vars = self.variables["cache"]
                causal_mask = lax.dynamic_slice(
                    self.causal_mask,
                    (0, 0, cache_vars["cache_index"], 0),
                    (1, 1, query_length, cache_vars["cached_key"].shape[1]),
                )
            else:
                causal_mask = self.causal_mask[:, :, :query_length, :key_length]
            causal_mask = jnp.broadcast_to(
                causal_mask, (batch_size, ) + causal_mask.shape[1:])

        if attention_mask is not None:
            attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))
            if self.causal:
                attention_mask = combine_masks(
                    jnp.broadcast_to(attention_mask, causal_mask.shape),
                    causal_mask)
        elif self.causal:
            attention_mask = causal_mask

        # Fast autoregressive decoding feeds one position at a time and
        # appends the new keys/values to the cache.
        uses_cache = self.causal and (self.has_variable("cache", "cached_key")
                                      or init_cache)
        if uses_cache:
            key_states, value_states, attention_mask = self._concatenate_to_cache(
                key_states, value_states, query_states, attention_mask)

        # Additive bias: 0 where attended, the dtype's most negative value where masked.
        attention_bias = None
        if attention_mask is not None:
            attention_bias = lax.select(
                attention_mask > 0,
                jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
                jnp.full(attention_mask.shape,
                         jnp.finfo(self.dtype).min).astype(self.dtype),
            )

        apply_dropout = not deterministic and self.dropout > 0.0
        dropout_rng = self.make_rng("dropout") if apply_dropout else None

        attn_weights = dot_product_attention_weights(
            query_states,
            key_states,
            bias=attention_bias,
            dropout_rng=dropout_rng,
            dropout_rate=self.dropout,
            broadcast_dropout=True,
            deterministic=deterministic,
            dtype=self.dtype,
            precision=None,
        )

        heads_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights,
                                  value_states)
        attn_output = self.out_proj(self._merge_heads(heads_output))

        return attn_output, attn_weights
Exemple #16
0
def scan_enum(f,
              init,
              xs,
              length,
              reverse,
              rng_key=None,
              substitute_stack=None):
    """Scan `f` over `xs` while enumerating discrete sites (history size 1).

    The first step runs eagerly (outside `lax.scan`) under the scope prefix
    "_init". The remaining `length - 1` steps run inside `lax.scan`. Each of
    those steps is traced under funsor's `enum`/`markov` handlers, so the
    collected trace can be given a `_time_<first site>` dimension afterwards.

    Args:
      f: transition function, called as `f(carry, x)` and returning
        `(new_carry, y)`.
      init: initial carry.
      xs: pytree of per-step inputs with the time axis leading.
      length: number of steps.
      reverse: scan from the end of `xs` when True.
      rng_key: optional PRNG key, split once per step to seed `f`.
      substitute_stack: optional sequence of `(subs_type, subs_map)` pairs,
        where `subs_type` is 'condition' or 'substitute'. `None` (the default)
        means no conditioning/substitution is applied.

    Returns:
      `((t, rng_key, carry), (pytree_trace, ys))`.
    """
    from numpyro.contrib.funsor import enum, config_enumerate, markov, trace as packed_trace

    # XXX: This implementation only works for history size=1 but can be
    # extended to history size > 1 by running `f` `history_size` times
    # for initialization. However, `sequential_sum_product` does not
    # support history size > 1, so we skip supporting it here.
    # Note that `funsor.sum_product.sarkka_bilmes_product` does support history > 1.

    # body_fn iterates over substitute_stack, so iterating the default `None`
    # would raise TypeError; treat it as "no handlers" instead.
    if substitute_stack is None:
        substitute_stack = []

    if reverse:
        x0 = tree_map(lambda x: x[-1], xs)
        xs_ = tree_map(lambda x: x[:-1], xs)
    else:
        x0 = tree_map(lambda x: x[0], xs)
        xs_ = tree_map(lambda x: x[1:], xs)

    carry_shape_at_t1 = None

    def body_fn(wrapped_carry, x, prefix=None):
        i, rng_key, carry = wrapped_carry
        # Only the eagerly executed first step has a concrete i == 0; inside
        # lax.scan `i` is a tracer.
        is_init_step = bool(not_jax_tracer(i) and i == 0)
        rng_key, subkey = random.split(rng_key) if rng_key is not None else (
            None, None)

        seeded_fn = handlers.seed(f, subkey) if subkey is not None else f
        for subs_type, subs_map in substitute_stack:
            subs_fn = partial(_subs_wrapper, subs_map, i, length)
            if subs_type == 'condition':
                seeded_fn = handlers.condition(seeded_fn, condition_fn=subs_fn)
            elif subs_type == 'substitute':
                seeded_fn = handlers.substitute(seeded_fn,
                                                substitute_fn=subs_fn)

        if is_init_step:
            with handlers.scope(prefix="_init"):
                new_carry, y = seeded_fn(carry, x)
                trace = {}
        else:
            with handlers.block(), packed_trace() as trace, \
                    promote_shapes(), enum(), markov():
                # Like scan_wrapper, we collect the trace of scan's transition function
                # `seeded_fn` here. To put time dimension to the correct position, we need to
                # promote shapes to make `fn` and `value`
                # at each site have the same batch dims (e.g. if `fn.batch_shape = (2, 3)`,
                # and value's batch_shape is (3,), then we promote shape of
                # value so that its batch shape is (1, 3)).
                new_carry, y = config_enumerate(seeded_fn)(carry, x)

            # store shape of new_carry at a global variable
            nonlocal carry_shape_at_t1
            carry_shape_at_t1 = [
                jnp.shape(x) for x in tree_flatten(new_carry)[0]
            ]
            # make new_carry have the same shape as carry
            # FIXME: is this rigorous?
            new_carry = tree_multimap(
                lambda a, b: jnp.reshape(a, jnp.shape(b)), new_carry, carry)
        return (i + jnp.array(1), rng_key, new_carry), (PytreeTrace(trace), y)

    with markov():
        wrapped_carry = (0, rng_key, init)
        wrapped_carry, (_, y0) = body_fn(wrapped_carry, x0)
        if length == 1:
            ys = tree_map(lambda x: jnp.expand_dims(x, 0), y0)
            return wrapped_carry, (PytreeTrace({}), ys)
        wrapped_carry, (pytree_trace, ys) = lax.scan(body_fn, wrapped_carry,
                                                     xs_, length - 1, reverse)

    first_var = None
    for name, site in pytree_trace.trace.items():
        # add `time` dimension, the name will be '_time_{first variable in the trace}'
        if first_var is None:
            first_var = name
        leftmost_dim = min(site['infer']['dim_to_name'])
        site['infer']['dim_to_name'][leftmost_dim -
                                     1] = '_time_{}'.format(first_var)

    # similar to carry, we need to reshape due to shape alternating in markov
    ys = tree_multimap(
        lambda z0, z: jnp.reshape(z, z.shape[:1] + jnp.shape(z0)), y0, ys)
    # we also need to reshape `carry` to match sequential behavior
    if length % 2 == 0:
        t, rng_key, carry = wrapped_carry
        flatten_carry, treedef = tree_flatten(carry)
        flatten_carry = [
            jnp.reshape(x, t1_shape)
            for x, t1_shape in zip(flatten_carry, carry_shape_at_t1)
        ]
        carry = tree_unflatten(treedef, flatten_carry)
        wrapped_carry = (t, rng_key, carry)
    return wrapped_carry, (pytree_trace, ys)
Exemple #17
0
    def step(self, s, a):
        """Apply control, damping, boundary, and collision forces.

        Sums all forces, then integrates one step of size `self.dt`: the
        velocity is updated from f / mass first, and the position is then
        advanced with the *updated* velocity.

        Args:
          s: (p, v, misc), where p and v are [n_entities,2] jnp.float32,
             and misc is child defined
          a: [n_agents, dim_a] jnp.float32. Rows are zero-padded up to
             n_entities and added directly to the force, so dim_a presumably
             must equal 2 -- confirm.

        Returns:
          A state tuple (p, v, misc), where misc is produced by
          `self._update_misc((p_new, v_new, misc_old), a)`.

        NOTE(review): the [n,n,1] shape comments below imply that self.radius,
        self.mass and self.collideable are [n,1] arrays -- confirm.
        """
        p, v, misc = s  # [n,2], [n,2], [a_shape]
        f = jnp.zeros_like(p)  # [n,2]
        n = p.shape[0]  # number of entities

        # Calculate control forces
        # Entities beyond the first a.shape[0] (the agents) get zero control.
        f_control = jnp.pad(a, ((0, n - a.shape[0]), (0, 0)),
                            mode="constant")  # [n, dim_a]
        f += f_control

        # Calculate damping forces
        f_damping = -1.0 * self.damping * v  # [n,2]
        f = f + f_damping

        # Calculate boundary forces
        # A velocity component is reflected when the entity touches a wall and
        # is moving into it; the force is the one that produces that reflection
        # within a single dt.
        bounce = (((p + self.radius >= self.max_p) & (v >= 0.0)) |
                  ((p - self.radius <= self.min_p) & (v <= 0.0)))  # [n,2]
        v_new = (-1.0 * bounce + 1.0 * ~bounce) * v  # [n,2]
        f_boundary = self.mass * (v_new - v) / self.dt  # [n,2]
        f = f + f_boundary

        # Calculate shared quantities for later calculations
        # same: [n,n,1], True if i==j
        same = jnp.expand_dims(jnp.eye(n, dtype=jnp.bool_), axis=-1)
        # p2p: [n,n,2], p2p[i,j,:] is the vector from entity i to entity j
        p2p = p - jnp.expand_dims(p, axis=1)
        # dist: [n,n,1], p2p[i,j,0] is the distance between i and j
        dist = jnp.linalg.norm(p2p, axis=-1, keepdims=True)
        # overlap: [n,n,1], overlap[i,j,0] is the overlap between i and j
        overlap = ((jnp.expand_dims(self.radius, axis=1) +
                    jnp.expand_dims(self.radius, axis=0)) - dist)
        if self.same_position_check:
            # Guards against 0/0 when two distinct entities share a position.
            # ontop: [n,n,1], ontop[i,j,0] = True if i is at the exact location of j
            ontop = (dist == 0.0)
            # ontop_dir: [n,n,1], (1,0) above diagonal, (-1,0) below diagonal
            ontop_dir = jnp.stack(
                [jnp.triu(jnp.ones((n, n))) * 2 - 1,
                 jnp.zeros((n, n))],
                axis=-1)
            # contact_dir: [n,n,2], contact_dir[i,j,:] is the unit vector in the
            # direction of j from i
            contact_dir = (~ontop * p2p +
                           (ontop * ontop_dir)) / (~ontop * dist + ontop * 1.0)
        else:
            # contact_dir: [n,n,2], contact_dir[i,j,:] is the unit vector in the
            # direction of j from i
            # Adding `same` avoids 0/0 only on the diagonal; distinct entities at
            # the same position would still give nan here.
            contact_dir = p2p / (dist + same)
        # collideable: [n,n,1], True if i and j are collideable
        collideable = (jnp.expand_dims(self.collideable, axis=1)
                       & jnp.expand_dims(self.collideable, axis=0))
        # overlap: [n,n,1], True if i,j overlap
        overlapping = overlap > 0

        # Calculate collision forces
        # Assume all entities collide with all entities, then mask out
        # non-collisions.
        #
        # For approaching, colliding entities, apply a force
        # along the direction of collision that results in
        # relative velocities consistent with the coefficient of
        # restitution (c) and preservation of momentum in that
        # direction.
        # momentum: m_a*v_a + m_b*v_b = m_a*v'_a + m_b*v'_b
        # restitution: v'_b - v'_a = -c*(v_b-v_a)
        # solve for v'_a:
        #  v'_a = [m_a*v_a + m_b*v_b + m_b*c*(v_b-v_a)]/(m_a + m_b)
        #
        # v_contact_dir: [n,n] speed of i in dir of j
        v_contact_dir = jnp.sum(jnp.expand_dims(v, axis=-2) * contact_dir,
                                axis=-1)
        # v_approach: [n,n] speed that i,j are approaching each other
        v_approach = jnp.transpose(v_contact_dir) + v_contact_dir
        # momentum: [n,n] joint momentum in direction of contact (i->j)
        # (contact_dir[j,i] == -contact_dir[i,j], hence the subtraction)
        momentum = self.mass * v_contact_dir - jnp.transpose(
            self.mass * v_contact_dir)
        # v_result: [n,n] speed of i in dir of j after collision
        # (v_b - v_a along the i->j direction equals -v_approach)
        v_result = ((momentum + self.restitution * jnp.transpose(self.mass) *
                     (-v_approach)) / (self.mass + jnp.transpose(self.mass)))
        # f_collision: [n,n] force on i in dir of j to realize acceleration
        f_collision = self.mass * (v_result - v_contact_dir) / self.dt
        # f_collision: [n,n,2] force on i to realize acceleration due to
        # collision with j
        f_collision = jnp.expand_dims(f_collision, axis=-1) * contact_dir
        # collision_mask: [n,n,1]
        collision_mask = (collideable & overlapping & ~same &
                          (jnp.expand_dims(v_approach, axis=-1) > 0))
        # f_collision: [n,2], sum of collision forces on i
        f_collision = jnp.sum(f_collision * collision_mask, axis=-2)
        f = f + f_collision

        # Calculate overlapping spring forces
        # This corrects for any overlap due to discrete steps.
        # f_overlap: [n,n,2], force in the negative contact dir due to overlap
        f_overlap = -1.0 * contact_dir * overlap * self.overlap_spring_constant
        # overlapping_mask: [n,n,1], True if i,j are collideable, overlap,
        # and i != j
        overlapping_mask = collideable & overlapping & ~same
        # f_overlap: [n,2], sum of spring forces on i
        f_overlap = jnp.sum(f_overlap * overlapping_mask, axis=-2)
        f = f + f_overlap

        # apply forces
        # Velocity first, then position from the updated velocity.
        v = v + (f / self.mass) * self.dt
        p = p + v * self.dt

        # update misc
        misc = self._update_misc((p, v, misc), a)  # pylint: disable=assignment-from-none

        return (p, v, misc)
        def loss_fn(
            model_config: ml_collections.FrozenConfigDict,
            model_params: Dict[Text, Any],
            model_vars: Dict[Text, Any],
            batch: Dict[Text, Any],
            deterministic: bool,
            dropout_rng: Optional[Dict[Text, Array]] = None,
        ) -> Tuple[float, MetricGroups, Dict[str, Any]]:
            """Loss function used by Ultra Fine Entity Typing task. See BaseTask.

            Computes a multi-label sigmoid cross-entropy over NUM_CLASSES labels.
            The labels are split into coarse / fine / ultra-fine class ranges,
            and each range is weighted by `get_weight_per_group`.

            Returns:
              `(loss, metrics, {})`. `loss` is the summed per-sample loss.
              `metrics` holds an 'agg' entry plus one entry per class range,
              merged with `get_eval_metrics`.
            """
            variables = {'params': model_params, **model_vars}
            model = cls.build_model(model_config)
            loss_helpers, _ = model.apply(
                variables,
                batch,
                deterministic=deterministic,
                rngs=dropout_rng)

            logits = loss_helpers['classifier_logits'].astype(jnp.float32)
            # log(1 - sigmoid(x)) == log_sigmoid(-x); the latter is the
            # numerically stable way to get the log probability of the
            # complementary event.
            log_prob = jax.nn.log_sigmoid(logits)
            log_comp_prob = jax.nn.log_sigmoid(-logits)

            # 'classifier_target' is [batch_size, max_labels_per_sample] sparse
            # label ids. One-hot them, zero out padded slots with the mask, and
            # sum over the slots to get dense [batch_size, NUM_CLASSES] targets.
            sparse_one_hot = jax.nn.one_hot(batch['classifier_target'],
                                            NUM_CLASSES,
                                            dtype=jnp.float32)
            slot_mask = jnp.expand_dims(batch['classifier_target_mask'], -1)
            dense_labels = (sparse_one_hot * slot_mask).sum(axis=1)
            loss_per_label = -log_prob * dense_labels - log_comp_prob * (
                1.0 - dense_labels)

            class_ranges = (
                ('coarse_grained', COARSE_CLASSES_START, COARSE_CLASSES_END),
                ('fine_grained', FINE_CLASSES_START, FINE_CLASSES_END),
                ('ultra_fine_grained', ULTRA_FINE_CLASSES_START,
                 ULTRA_FINE_CLASSES_END),
            )
            range_losses = []
            range_metrics = {}
            for range_name, start, end in class_ranges:
                weight = get_weight_per_group(dense_labels, start, end)
                range_loss = get_loss_per_group(loss_per_label, weight, start,
                                                end)
                range_losses.append(range_loss)
                range_metrics[range_name] = {
                    'loss': range_loss.sum(),
                    'denominator': weight.sum(),
                }

            coarse_loss, fine_loss, ultra_fine_loss = range_losses
            loss_per_sample = coarse_loss + fine_loss + ultra_fine_loss
            loss = loss_per_sample.sum()

            metrics = {
                'agg': {
                    'loss': loss,
                    'denominator': loss_per_sample.shape[0],
                },
                **range_metrics,
            }
            metrics.update(get_eval_metrics(dense_labels, logits))
            return loss, metrics, {}
Exemple #19
0
        def beam_search_body_fn(state, input_ids_length=1):
            """Advance the beam search by one decoding step.

            Args:
              state: the current ``BeamSearchState``.
              input_ids_length: number of trailing tokens (the positions just
                before ``state.cur_len``) fed to the model this step.

            Returns:
              A new ``BeamSearchState`` with ``cur_len`` advanced by one.
            """
            # 1. Forward current tokens
            # Slice the last `input_ids_length` tokens of every running beam and
            # flatten the beam dimension into the batch dimension for the model.
            input_token = flatten_beam_dim(
                lax.dynamic_slice(
                    state.running_sequences,
                    (0, 0, state.cur_len - input_ids_length),
                    (batch_size, num_beams, input_ids_length),
                )
            )
            model_outputs = model(input_token, params=params, **state.model_kwargs)

            # Unflatten the beam dimension of the logits and of the attention cache.
            logits = unflatten_beam_dim(model_outputs.logits[:, -1], batch_size, num_beams)
            cache = jax.tree_map(
                lambda tensor: unflatten_beam_dim(tensor, batch_size, num_beams), model_outputs.past_key_values
            )

            # adapt logits for FlaxMarianMTModel
            logits = self._adapt_logits_for_beam_search(logits)

            # 2. Compute log probs
            # get log probabilities from logits,
            # process logits with processors (*e.g.* min_length, ...), and
            # add new logprobs to existing running logprobs scores.
            log_probs = jax.nn.log_softmax(logits)
            # Processors must see the sequences of *this* iteration. Only `state`
            # is threaded through the loop, so the enclosing-scope
            # `running_sequences` would be stale here.
            log_probs = logits_processor(
                flatten_beam_dim(state.running_sequences), flatten_beam_dim(log_probs), state.cur_len
            )
            log_probs = unflatten_beam_dim(log_probs, batch_size, num_beams)
            log_probs = log_probs + jnp.expand_dims(state.running_scores, axis=2)
            vocab_size = log_probs.shape[2]
            log_probs = log_probs.reshape((batch_size, num_beams * vocab_size))

            # 3. Retrieve top-K
            # Each item in batch has num_beams * vocab_size candidate sequences.
            # Take the top 2*K candidates across all beams, so that even if the
            # best K sequences reach EOS simultaneously, another K sequences
            # remain to continue the live beam search. The beam index is
            # recovered by floor division and the token id by modulo.
            beams_to_keep = 2 * num_beams
            topk_log_probs, topk_indices = lax.top_k(log_probs, k=beams_to_keep)
            topk_beam_indices = topk_indices // vocab_size
            topk_running_sequences = gather_beams(
                state.running_sequences, topk_beam_indices, batch_size, beams_to_keep
            )
            topk_ids = jnp.expand_dims(topk_indices % vocab_size, axis=2)
            topk_sequences = lax.dynamic_update_slice(topk_running_sequences, topk_ids, (0, 0, state.cur_len))

            # 4. Check which sequences have ended
            # Sequences that just produced EOS get a very large negative log
            # prob so they are not selected as running beams again.
            did_topk_just_finished = topk_sequences[:, :, state.cur_len] == eos_token_id
            running_topk_log_probs = topk_log_probs + did_topk_just_finished * np.array(-1.0e7)
            # 5. Get running sequences scores for next
            # Determine the top k beam indices (from top 2*k beams) from log probs
            # and gather top k beams (from top 2*k beams).
            next_topk_indices = jnp.flip(lax.top_k(running_topk_log_probs, k=num_beams)[1], axis=1)
            next_running_sequences, next_running_scores = gather_beams(
                [topk_sequences, running_topk_log_probs], next_topk_indices, batch_size, num_beams
            )

            # 6. Process topk logits
            # Further process log probs:
            # - add length penalty
            # - make sure no scores can be added anymore if beam is full
            # - make sure still running sequences cannot be chosen as finalized beam
            topk_log_probs = topk_log_probs / (state.cur_len**length_penalty)
            beams_in_batch_are_full = (
                jnp.broadcast_to(state.is_sent_finished.all(axis=-1, keepdims=True), did_topk_just_finished.shape)
                & early_stopping
            )
            add_penalty = ~did_topk_just_finished | beams_in_batch_are_full
            topk_log_probs += add_penalty * np.array(-1.0e7)

            # 7. Get scores, sequences, is sentence finished for next.
            # Combine sequences, scores, and flags along the beam dimension and compare
            # new finished sequence scores to existing finished scores and select the
            # best from the new set of beams
            merged_sequences = jnp.concatenate([state.sequences, topk_sequences], axis=1)
            merged_scores = jnp.concatenate([state.scores, topk_log_probs], axis=1)
            merged_is_sent_finished = jnp.concatenate([state.is_sent_finished, did_topk_just_finished], axis=1)
            topk_merged_indices = jnp.flip(lax.top_k(merged_scores, k=num_beams)[1], axis=1)
            next_sequences, next_scores, next_is_sent_finished = gather_beams(
                [merged_sequences, merged_scores, merged_is_sent_finished], topk_merged_indices, batch_size, num_beams
            )

            # 8. Update model kwargs.
            # Determine the top k beam indices from the original set of all beams.
            # With these, gather the top k beam-associated caches.
            next_running_indices = gather_beams(topk_beam_indices, next_topk_indices, batch_size, num_beams)
            next_cache = gather_beams(cache, next_running_indices, batch_size, num_beams)
            model_outputs["past_key_values"] = jax.tree_map(flatten_beam_dim, next_cache)
            next_model_kwargs = self.update_inputs_for_generation(model_outputs, state.model_kwargs)

            return BeamSearchState(
                cur_len=state.cur_len + 1,
                running_scores=next_running_scores,
                running_sequences=next_running_sequences,
                scores=next_scores,
                sequences=next_sequences,
                is_sent_finished=next_is_sent_finished,
                model_kwargs=next_model_kwargs,
            )
def _zeros_like(input, dtype=None, name=None):  # pylint: disable=redefined-builtin
    """Returns an array of zeros with the shape of `input`.

    Args:
      input: array-like whose shape (obtained via `_shape`) and dtype are
        mirrored by the result.
      dtype: optional dtype override; defaults to `input.dtype`.
      name: optional op name, forwarded only on the `tf.zeros` path.

    Returns:
      A zero-filled array shaped like `input`.
    """
    s = _shape(input)
    if isinstance(s, (np.ndarray, onp.generic)):
        # Statically known shape: build the zeros directly with the
        # NumPy-compatible version of the requested/input dtype.
        return np.zeros(s, utils.numpy_dtype(dtype or input.dtype))
    # Otherwise `s` is whatever `_shape` returned for a non-static shape. Its
    # own dtype describes the shape's elements, not `input`, so the result
    # dtype must come from `input` (matching the branch above).
    return tf.zeros(s, dtype or input.dtype, name)


# --- Begin Public Functions --------------------------------------------------

# Each public op below pairs a TensorFlow function (passed first to
# `utils.copy_docstring`, presumably as the docstring/signature source) with a
# NumPy-backed implementation. The `name` arguments are accepted only for
# signature compatibility with TF; none of the implementations below use them.

# `concat` converts every element of `values` with `ops.convert_to_tensor`
# before concatenating along `axis`.
concat = utils.copy_docstring(
    tf.concat,
    lambda values, axis, name='concat': (  # pylint: disable=g-long-lambda
        np.concatenate([ops.convert_to_tensor(v) for v in values], axis)))

expand_dims = utils.copy_docstring(
    tf.expand_dims, lambda input, axis, name=None: np.expand_dims(input, axis))

# `fill` broadcasts `value` by multiplying it into a ones-array of shape `dims`
# whose dtype is inferred from `value` itself.
fill = utils.copy_docstring(
    tf.fill,
    lambda dims, value, name=None: value * np.ones(dims,
                                                   np.array(value).dtype))

# The gather/reverse ops delegate to private helpers (`_gather`, `_gather_nd`,
# `_reverse`) that are not shown in this chunk.
gather = utils.copy_docstring(tf.gather, _gather)

gather_nd = utils.copy_docstring(tf.gather_nd, _gather_nd)

reverse = utils.copy_docstring(tf.reverse, _reverse)

linspace = utils.copy_docstring(
    tf.linspace,
    lambda start, stop, num, name=None: (  # pylint: disable=g-long-lambda
Exemple #21
0
    def apply(self,
              x,
              batch_stats=None,
              use_running_average=False,
              axis=-1,
              momentum=0.99,
              epsilon=1e-5,
              dtype=jnp.float32,
              bias=True,
              scale=True,
              bias_init=initializers.zeros,
              scale_init=initializers.ones,
              axis_name=None,
              axis_index_groups=None,
              virtual_batch_size=None,
              data_format=None):
        """Normalizes the input using batch statistics.

        Forked from the original flax nn.BatchNorm layer, this allows users to
        have multiple EMAs per device, one for each virtual batch size. For
        example, if the per-device batch size is 128 and the user specifies
        `virtual_batch_size=32`, 4 EMAs will be created on each device, each
        updated with 1/4 of the per-device batch on each forward pass.

        WARNING: the multiple per-device EMAs this creates need to be manually
        synchronized within each device before being used for evaluation, or
        when synchronizing batch norm statistics across devices.

        Args:
          x: the input to be normalized.
          batch_stats: a `flax.nn.Collection` used to store an exponential
            moving average of the batch statistics (default: None).
          use_running_average: if true, the statistics stored in batch_stats
            will be used instead of computing the batch statistics on the input.
          axis: the feature or non-batch axis of the input.
          momentum: decay rate for the exponential moving average of
            the batch statistics.
          epsilon: a small float added to variance to avoid dividing by zero.
          dtype: the dtype of the computation (default: float32).
          bias:  if True, bias (beta) is added.
          scale: if True, multiply by scale (gamma).
            When the next layer is linear (also e.g. nn.relu), this can be
            disabled since the scaling will be done by the next layer.
          bias_init: initializer for bias, by default, zero.
          scale_init: initializer for scale, by default, one.
          axis_name: the axis name used to combine batch statistics from
            multiple devices. See `jax.pmap` for a description of axis names
            (default: None).
          axis_index_groups: groups of axis indices within that named axis
            representing subsets of devices to reduce over (default: None).
            For example, `[[0, 1], [2, 3]]` would independently batch-normalize
            over the examples on the first two and last two devices. See
            `jax.lax.psum` for more details.
          virtual_batch_size: the size of the virtual batches to construct on
            each device, which will be used to normalize sub-batches of each
            per-device batch. Will create a running average
            with a leading dim of size
            `x.shape[batch_axis] // virtual_batch_size`, one for each
            sub-batch. Note that the first dim of each state must be
            synchronized whenever synchronizing batch norm running averages.
            Must evenly divide the per-device batch size (as determined by
            `x`), and cannot be combined with `axis_index_groups`. Passing the
            default value of None will replicate the existing nn.BatchNorm
            behavior without virtual batches.
          data_format: only used when `virtual_batch_size` is set, to determine
            the batch axis.

        Returns:
          Normalized inputs (the same shape as inputs), cast to `dtype`.
        """
        batch_axis = _get_batch_axis(data_format, x, virtual_batch_size,
                                     use_running_average, axis_index_groups)
        if virtual_batch_size is None:
            virtual_batch_size = x.shape[batch_axis]

        if use_running_average:
            # Virtual batch norm is not used during evaluation, and we cannot
            # guarantee the train and eval batch sizes are the same, so we use a
            # single virtual batch of size batch_size, and take the first
            # element in the running average array, assuming they have been
            # properly synced across their first dim.
            virtual_batch_size = x.shape[batch_axis]

        # Statistics are always computed in float32; the result is cast to
        # `dtype` only at the very end.
        x = jnp.asarray(x, jnp.float32)
        num_sub_batches = x.shape[batch_axis] // virtual_batch_size
        input_shape = x.shape
        axis = axis if isinstance(axis, tuple) else (axis, )
        axis = _absolute_dims(x.ndim, axis)
        feature_shape = tuple(d if i in axis else 1
                              for i, d in enumerate(x.shape))
        reduced_feature_shape = tuple(d for i, d in enumerate(x.shape)
                                      if i in axis)
        # Add an additional axis because we are going to reshape `x` to have a
        # leading dim of size `num_sub_batches`.
        reduction_axis = tuple(i + 1 for i in range(x.ndim) if i not in axis)
        sub_batched_shape = (num_sub_batches, *x.shape[:batch_axis],
                             virtual_batch_size, *x.shape[batch_axis + 1:])
        x = jnp.reshape(x, sub_batched_shape)
        if self.is_stateful() or batch_stats:
            # One running mean/var row per virtual sub-batch.
            ra_means = self.state('batch_norm_running_mean',
                                  (num_sub_batches, *reduced_feature_shape),
                                  initializers.zeros,
                                  collection=batch_stats)
            ra_vars = self.state('batch_norm_running_var',
                                 (num_sub_batches, *reduced_feature_shape),
                                 initializers.ones,
                                 collection=batch_stats)
        else:
            ra_means = None
            ra_vars = None

        if use_running_average:
            if ra_means is None:
                raise ValueError(
                    'when use_running_averages is True '
                    'either use a stateful context or provide batch_stats')
            # Note that we assume that the values across the first axis have
            # been properly synchronized.
            mean = jnp.expand_dims(ra_means.value[0], 0)
            var = jnp.expand_dims(ra_vars.value[0], 0)
        else:
            mean = jnp.mean(x, axis=reduction_axis, keepdims=False)
            mean2 = jnp.mean(lax.square(x),
                             axis=reduction_axis,
                             keepdims=False)
            if axis_name is not None and not self.is_initializing():
                # Average E[x] and E[x^2] across devices in a single collective.
                concatenated_mean = jnp.concatenate([mean, mean2])
                mean, mean2 = jnp.split(
                    lax.pmean(concatenated_mean,
                              axis_name=axis_name,
                              axis_index_groups=axis_index_groups), 2)
            # E[x^2] - E[x]^2 can come out slightly negative due to
            # floating-point cancellation; clamp at zero so that
            # rsqrt(var + epsilon) below never sees a value under epsilon
            # (or a negative one, which would produce NaN) and the running
            # variance is never corrupted.
            var = jnp.maximum(0., mean2 - lax.square(mean))

            if ra_means and not self.is_initializing():
                ra_means.value = momentum * ra_means.value + (1 -
                                                              momentum) * mean
                ra_vars.value = momentum * ra_vars.value + (1 - momentum) * var

        y = x - mean.reshape((num_sub_batches, *feature_shape))
        mul = lax.rsqrt(
            var.reshape((num_sub_batches, *feature_shape)) + epsilon)
        if scale:
            mul = mul * self.param('scale', reduced_feature_shape,
                                   scale_init).reshape((1, *feature_shape))
        y = y * mul
        if bias:
            y = y + self.param('bias', reduced_feature_shape,
                               bias_init).reshape((1, *feature_shape))
        y = jnp.reshape(y, input_shape)
        return jnp.asarray(y, dtype)
Exemple #22
0
 def funx(x):
     """Replace a `RelaxVariable` with zeros; return any other value unchanged.

     For a `RelaxVariable`, the zeros have the shape and dtype of
     ``x.lower[0, ...]`` with an extra leading axis of size 1.
     """
     if not isinstance(x, RelaxVariable):
         return x
     first_lower = x.lower[0, ...]
     return jnp.zeros_like(first_lower[None, ...])
    def __call__(
        self,
        hidden_states,
        key_value_states: Optional[jnp.ndarray] = None,
        attention_mask=None,
        deterministic: bool = True,
        init_cache: bool = False,
        output_attentions: bool = False,
    ):
        """Multi-head attention over `hidden_states` (self- or cross-attention).

        Without `key_value_states`, `self.c_attn` projects `hidden_states` into
        query/key/value. With `key_value_states`, this acts as cross-attention:
        queries come from `self.q_attn(hidden_states)` and keys/values from
        `self.c_attn(key_value_states)`.

        In causal mode a causal mask is built (offset by the cache index when a
        "cache"/"cached_key" variable exists) and combined with
        `attention_mask`, and keys/values are appended to the cache via
        `_concatenate_to_cache`. Masked positions receive a bias of -1e4.

        Returns:
          `(attn_output,)`, or `(attn_output, attn_weights)` when
          `output_attentions` is True.
        """
        n_batch = hidden_states.shape[0]

        # Project to query/key/value; cross-attention takes K/V from the other
        # sequence.
        if key_value_states is None:
            query, key, value = jnp.split(self.c_attn(hidden_states), 3, axis=2)
        else:
            (query, ) = jnp.split(self.q_attn(hidden_states), 1, axis=2)
            key, value = jnp.split(self.c_attn(key_value_states), 2, axis=2)

        query, key, value = map(self._split_heads, (query, key, value))
        q_len, k_len = query.shape[1], key.shape[1]

        if self.causal:
            if self.has_variable("cache", "cached_key"):
                # Incremental decoding: take the mask rows for the positions
                # currently being decoded, spanning the full cache length.
                cache_vars = self.variables["cache"]
                shift = cache_vars["cache_index"]
                max_len = cache_vars["cached_key"].shape[1]
                causal_mask = lax.dynamic_slice(self.causal_mask,
                                                (0, 0, shift, 0),
                                                (1, 1, q_len, max_len))
            else:
                causal_mask = self.causal_mask[:, :, :q_len, :k_len]
            causal_mask = jnp.broadcast_to(causal_mask, (n_batch, ) +
                                           causal_mask.shape[1:])

            # Merge the padding mask (if any) into the causal mask.
            if attention_mask is None:
                attention_mask = causal_mask
            else:
                padding_mask = jnp.broadcast_to(
                    jnp.expand_dims(attention_mask, axis=(-3, -2)),
                    causal_mask.shape)
                attention_mask = combine_masks(padding_mask, causal_mask)
        elif attention_mask is not None:
            attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))

        dropout_rng = (self.make_rng("dropout") if not deterministic
                       and self.config.attn_pdrop > 0.0 else None)

        # Fast autoregressive decoding feeds one position at a time and
        # accumulates keys/values in the cache.
        if self.causal and (self.has_variable("cache", "cached_key")
                            or init_cache):
            key, value, attention_mask = self._concatenate_to_cache(
                key, value, query, attention_mask)

        # Boolean mask -> additive bias (0 where attended, -1e4 where masked).
        attention_bias = None
        if attention_mask is not None:
            keep_bias = jnp.full(attention_mask.shape, 0.0).astype(self.dtype)
            drop_bias = jnp.full(attention_mask.shape, -1e4).astype(self.dtype)
            attention_bias = lax.select(attention_mask > 0, keep_bias,
                                        drop_bias)

        attn_weights = dot_product_attention_weights(
            query,
            key,
            bias=attention_bias,
            dropout_rng=dropout_rng,
            dropout_rate=self.config.attn_pdrop,
            deterministic=deterministic,
            dtype=self.dtype,
            precision=None,
        )

        per_head = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value)
        projected = self.c_proj(self._merge_heads(per_head))
        attn_output = self.resid_dropout(projected,
                                         deterministic=deterministic)

        if output_attentions:
            return attn_output, attn_weights
        return (attn_output, )
    def apply(self, inputs, info, config, train=False, cache=None):
        """Runs the instruction-pointer graph model over a batch of programs.

        Each program is a graph of `num_nodes` statement nodes, each with a
        true and a false successor. For `max_steps` scan iterations (each
        example stops updating after its own `steps` count), the model updates
        per-node hidden states and a per-node instruction-pointer weight. It
        then reads the hidden state at the example's exit node to produce
        logits.

        Args:
          inputs: dict with 'data' (token ids, unpacked below as
            batch_size x num_nodes x statement_length), 'true_branch_nodes',
            'false_branch_nodes', 'start_index' (currently unused),
            'exit_index' and 'steps'.
          info: dataset info providing `output_vocab_size`, `max_diameter` and
            the 'statements' feature vocabulary.
          config: model configuration (`config.model.*`,
            `config.initialization.maxval`).
          train: not used by this method.
          cache: not used by this method.

        Returns:
          Logits of shape (batch_size, 1, info.output_vocab_size).
        """
        # Inputs
        output_token_vocabulary_size = info.output_vocab_size
        true_indexes = inputs['true_branch_nodes']
        false_indexes = inputs['false_branch_nodes']
        start_indexes = inputs['start_index']  # pylint: disable=unused-variable
        exit_indexes = inputs['exit_index']
        steps_all = inputs['steps']
        vocab_size = info.features[info._builder.key('statements')].vocab_size  # pylint: disable=protected-access
        hidden_size = config.model.hidden_size
        data = inputs['data'].astype('int32')
        batch_size, num_nodes, unused_statement_length = data.shape

        # An upper bound on the number of steps to take.
        max_steps = int(1.5 * info.max_diameter)

        # Init parameters
        def emb_init(key, shape, dtype=jnp.float32):
            return jax.random.uniform(key, shape, dtype,
                                      -config.initialization.maxval,
                                      config.initialization.maxval)

        embed = Embed.shared(num_embeddings=vocab_size,
                             features=hidden_size,
                             emb_init=emb_init,
                             name='embed')
        branch_decide_dense = nn.Dense.shared(
            name='branch_decide_dense',
            features=2,
            kernel_init=nn.initializers.xavier_uniform(),
            bias_init=nn.initializers.normal(stddev=1e-6))
        cells = create_lstm_cells(config.model.rnn_cell.layers)
        lstm = StackedRNNCell.shared(cells=cells)
        if config.model.interpolant.apply_dense:
            dense_parent_to_true_child = nn.Dense.shared(
                name='dense_parent_to_true_child',
                features=hidden_size,
                kernel_init=nn.initializers.xavier_uniform(),
                bias_init=nn.initializers.normal(stddev=1e-6))
            dense_parent_to_false_child = nn.Dense.shared(
                name='dense_parent_to_false_child',
                features=hidden_size,
                kernel_init=nn.initializers.xavier_uniform(),
                bias_init=nn.initializers.normal(stddev=1e-6))
            dense_true_child_to_parent = nn.Dense.shared(
                name='dense_true_child_to_parent',
                features=hidden_size,
                kernel_init=nn.initializers.xavier_uniform(),
                bias_init=nn.initializers.normal(stddev=1e-6))
            dense_false_child_to_parent = nn.Dense.shared(
                name='dense_false_child_to_parent',
                features=hidden_size,
                kernel_init=nn.initializers.xavier_uniform(),
                bias_init=nn.initializers.normal(stddev=1e-6))
        if config.model.interpolant.apply_gru:
            gru_cell = nn.recurrent.GRUCell.shared(name='gru_cell')
        output_dense = nn.Dense.shared(
            name='output_dense',
            features=output_token_vocabulary_size,
            kernel_init=nn.initializers.xavier_uniform(),
            bias_init=nn.initializers.normal(stddev=1e-6))

        # Apply
        def execute_single_node(hidden_state, node_embedding):
            # Run the stacked LSTM over the node's statement tokens and keep
            # only the final carry.
            carry, _ = lax.scan(lstm, hidden_state, node_embedding)
            return carry

        execute = jax.vmap(execute_single_node)  # Single example.

        def branch_decide_single_node(hidden_state):
            # leaves(hidden_state).shape: hidden_size
            hidden_state_concat = jnp.concatenate(
                jax.tree_leaves(hidden_state), axis=0)
            return branch_decide_dense(hidden_state_concat)

        branch_decide = jax.vmap(branch_decide_single_node)

        def update_instruction_pointer(instruction_pointer, branch_decisions,
                                       true_indexes, false_indexes):
            # instruction_pointer.shape: num_nodes,
            # branch_decisions: num_nodes, 2,
            # true_indexes: num_nodes,
            # false_indexes: num_nodes
            p_true = branch_decisions[:, 0]
            p_false = branch_decisions[:, 1]
            if not config.model.interpolant.use_b:
                p_true = jnp.ones_like(p_true)
                p_false = jnp.ones_like(p_false)
            if not config.model.interpolant.use_p:
                instruction_pointer = jnp.ones_like(instruction_pointer)
            # Push each node's pointer weight to its true/false successors.
            true_contributions = jax.ops.segment_sum(p_true *
                                                     instruction_pointer,
                                                     true_indexes,
                                                     num_segments=num_nodes)
            false_contributions = jax.ops.segment_sum(p_false *
                                                      instruction_pointer,
                                                      false_indexes,
                                                      num_segments=num_nodes)
            return true_contributions + false_contributions

        def aggregate(hidden_states, instruction_pointer, branch_decisions,
                      true_indexes, false_indexes):
            # leaves(hidden_states).shape: num_nodes, hidden_size
            # instruction_pointer.shape: num_nodes,
            # branch_decisions: num_nodes, 2,
            # true_indexes: num_nodes,
            # false_indexes: num_nodes
            p_true = branch_decisions[:, 0]
            p_false = branch_decisions[:, 1]
            if not config.model.interpolant.use_b:
                p_true = jnp.ones_like(p_true)
                p_false = jnp.ones_like(p_false)
            if not config.model.interpolant.use_p:
                instruction_pointer = jnp.ones_like(instruction_pointer)
            denominators = update_instruction_pointer(instruction_pointer,
                                                      branch_decisions,
                                                      true_indexes,
                                                      false_indexes)
            # Avoid division by zero for nodes that receive no pointer mass.
            denominators += 1e-7
            # denominator.shape: num_nodes,
            if not config.model.interpolant.normalize:
                denominators = jnp.ones_like(denominators)

            def aggregate_component(h):
                # h.shape: num_nodes (one hidden-size column, via vmap below)
                # p_true.shape: num_nodes
                # instruction_pointer.shape: num_nodes
                true_contributions = jax.ops.segment_sum(
                    h * p_true * instruction_pointer,
                    true_indexes,
                    num_segments=num_nodes)
                false_contributions = jax.ops.segment_sum(
                    h * p_false * instruction_pointer,
                    false_indexes,
                    num_segments=num_nodes)
                # *_contributions.shape: num_nodes
                return (true_contributions +
                        false_contributions) / denominators

            aggregate_component = jax.vmap(aggregate_component,
                                           in_axes=1,
                                           out_axes=1)

            return jax.tree_map(aggregate_component, hidden_states)

        def step_single_example(hidden_states, instruction_pointer,
                                node_embeddings, true_indexes, false_indexes,
                                exit_index):
            # Execution (e.g. apply RNN)
            # leaves(hidden_states).shape: num_nodes, hidden_size
            # instruction_pointer.shape: num_nodes,
            # node_embeddings.shape: num_nodes, statement_length, hidden_size
            if config.model.interpolant.apply_code_rnn:
                hidden_state_contributions = execute(hidden_states,
                                                     node_embeddings)
                # leaves(hidden_state_contributions).shape: num_nodes, hidden_size
            else:
                hidden_state_contributions = hidden_states

            if config.model.interpolant.apply_dense:
                parent_to_true_child = jax.tree_map(
                    dense_parent_to_true_child, hidden_state_contributions)
                parent_to_false_child = jax.tree_map(
                    dense_parent_to_false_child, hidden_state_contributions)
                true_child_to_parent = jax.tree_map(
                    dense_true_child_to_parent, hidden_state_contributions)
                false_child_to_parent = jax.tree_map(
                    dense_false_child_to_parent, hidden_state_contributions)
            else:
                parent_to_true_child = hidden_state_contributions
                parent_to_false_child = hidden_state_contributions
                true_child_to_parent = hidden_state_contributions
                false_child_to_parent = hidden_state_contributions

            # Use the exit node's hidden state as its hidden state contribution
            # to avoid "executing" the exit node.
            def mask_h(h_contribution, h):
                return h_contribution.at[exit_index, :].set(h[exit_index, :])

            hidden_state_contributions = jax.tree_multimap(
                mask_h, hidden_state_contributions, hidden_states)

            # Branch decisions (e.g. Dense layer)
            branch_decision_logits = branch_decide(hidden_state_contributions)
            branch_decisions = nn.softmax(branch_decision_logits, axis=-1)

            # Update state
            if config.model.interpolant.use_ipa:
                instruction_pointer_new = update_instruction_pointer(
                    instruction_pointer, branch_decisions, true_indexes,
                    false_indexes)
                hidden_states_new = aggregate(hidden_state_contributions,
                                              instruction_pointer,
                                              branch_decisions, true_indexes,
                                              false_indexes)
            else:
                assert config.model.interpolant.use_parent_embeddings
                assert config.model.interpolant.use_child_embeddings
                instruction_pointer_new = instruction_pointer
                normalization = jnp.sqrt(
                    2 + (  # Each node has a true and false child.
                        # jnp.bincount(true_indexes, minlength=num_nodes)
                        jax.ops.segment_sum(jnp.ones_like(true_indexes),
                                            true_indexes,
                                            num_segments=num_nodes)
                        # + jnp.bincount(false_indexes, minlength=num_nodes)
                        + jax.ops.segment_sum(jnp.ones_like(false_indexes),
                                              false_indexes,
                                              num_segments=num_nodes)))

                # normalization.shape: num_nodes,
                def aggregate_parent_and_child_contributions(p1, p2, c3, c4):
                    return (jax.ops.segment_sum(
                        p1, true_indexes, num_segments=num_nodes) +
                            jax.ops.segment_sum(
                                p2, false_indexes, num_segments=num_nodes) +
                            c3[true_indexes] +
                            c4[false_indexes]) / normalization[:, None]

                hidden_states_new = jax.tree_multimap(
                    aggregate_parent_and_child_contributions,
                    parent_to_true_child,
                    parent_to_false_child,
                    true_child_to_parent,  # true_child_to_parent[child] -> parent
                    false_child_to_parent)
            if config.model.interpolant.apply_gru:

                def apply_gru(h2, h1):
                    output, _ = gru_cell(h2, h1)
                    return output

                hidden_states_new = (jax.tree_multimap(apply_gru,
                                                       hidden_states_new,
                                                       hidden_states))

            to_tag = {
                'branch_decisions': branch_decisions,
                'hidden_state_contributions': hidden_state_contributions,
                'hidden_states_before': hidden_states,
                'hidden_states': hidden_states_new,
                'instruction_pointer_before': instruction_pointer,
                'instruction_pointer': instruction_pointer_new,
                'true_indexes': true_indexes,
                'false_indexes': false_indexes,
            }
            return hidden_states_new, instruction_pointer_new, to_tag

        def compute_logits_single_example(hidden_states, instruction_pointer,
                                          exit_index, steps, node_embeddings,
                                          true_indexes, false_indexes):
            """single_example refers to selecting a single exit node hidden state."""

            # leaves(hidden_states).shape: num_nodes, hidden_size

            def step_(carry, _):
                hidden_states, instruction_pointer, index = carry
                hidden_states_new, instruction_pointer_new, to_tag = (
                    step_single_example(hidden_states, instruction_pointer,
                                        node_embeddings, true_indexes,
                                        false_indexes, exit_index))
                # Once this example has used its `steps` budget, keep the old
                # state; the scan still runs for `max_steps` iterations.
                carry = jax.tree_multimap(
                    lambda new, old, index=index: jnp.where(
                        index < steps, new, old),
                    (hidden_states_new, instruction_pointer_new, index + 1),
                    (hidden_states, instruction_pointer, index + 1),
                )
                return carry, to_tag

            if config.model.ipagnn.checkpoint and not self.is_initializing():
                step_ = jax.checkpoint(step_)

            carry = (hidden_states, instruction_pointer, jnp.array([0]))
            (hidden_states, instruction_pointer,
             _), to_tag = lax.scan(step_, carry, None, length=max_steps)

            final_state = jax.tree_map(lambda hs: hs[exit_index],
                                       hidden_states)
            # leaves(final_state).shape: hidden_size
            final_state_concat = jnp.concatenate(jax.tree_leaves(final_state),
                                                 axis=0)
            logits = output_dense(final_state_concat)
            to_tag.update({
                'instruction_pointer_final': instruction_pointer,
                'hidden_states_final': hidden_states,
            })
            return logits, to_tag

        compute_logits = jax.vmap(compute_logits_single_example,
                                  in_axes=(0, 0, 0, 0, 0, 0, 0))

        # Init state
        node_embeddings = embed(data)
        # node_embeddings.shape:
        #     batch_size, num_nodes, statement_length, hidden_size
        hidden_states = StackedRNNCell.initialize_carry(
            jax.random.PRNGKey(0), cells, (
                batch_size,
                num_nodes,
            ), hidden_size)
        if config.model.interpolant.init_with_code_embeddings:
            hidden_states = jax.vmap(execute)(hidden_states, node_embeddings)
        # leaves(hidden_states).shape: batch_size, num_nodes, hidden_size
        # All instruction-pointer mass starts on node 0. Uses the `.at[]`
        # indexed-update API (as `mask_h` above does) because
        # `jax.ops.index_add` was removed from JAX.
        # TODO(dbieber): Use "start_index" instead of 0.
        instruction_pointer = jnp.zeros((
            batch_size,
            num_nodes,
        )).at[:, 0].add(1)
        # instruction_pointer.shape: batch_size, num_nodes,

        logits, to_tag = compute_logits(hidden_states, instruction_pointer,
                                        exit_indexes, steps_all,
                                        node_embeddings, true_indexes,
                                        false_indexes)
        for key, value in to_tag.items():
            # Called for its effect; the returned value was previously bound to
            # the loop variable and discarded.
            Tag(value, name=key)
        logits = jnp.expand_dims(logits, axis=1)
        return logits
def _cov_helper_without_p(data):
    return jnp.expand_dims(jnp.matmul(jnp.conj(jnp.transpose(data)), data),
                           axis=0)
Exemple #26
0
 def apply(rs: JaxArray, ids: JaxArray, **kwargs):
     """Evaluate `f` and `Jf` on the inputs and return each as a single row.

     Each result is flattened and given a leading axis of size 1. `f` is
     evaluated before `Jf`.
     """
     def _as_row(arr):
         return np.expand_dims(arr.flatten(), 0)

     return _as_row(f(rs, ids, **kwargs)), _as_row(Jf(rs, ids, **kwargs))
Exemple #27
0
def __render_comps(model, has_bulge, has_bar, n_spirals, shape,
                   oversample_n, base_roll):
    """Render the components of a galaxy builder model

    Arguments:
    model -- The model to render. This should be a dictionary like
        {(component, param): value} (i.e. what you would get from
        pd.Series(...).to_dict()), e.g. model[('disk', 'L')]
    has_bulge -- Whether the model contains a bulge, needed for Jax compilation
    has_bar -- Whether the model contains a bar
    n_spirals -- The number of spiral arms present in the model
    shape -- The desired output shape to render
    oversample_n -- The factor to which Sersic oversampling will be done
    base_roll -- The roll parameter of the original model. This is needed to
        preserve the location of spiral arms as best as possible

    Returns:
    dict -- Rendered component images keyed by name. 'disk' is always
        present; 'spiral' is added when n_spirals > 0, 'bulge' when has_bulge
        and 'bar' when has_bar. (Presumably each image has shape `shape`;
        the no-spiral fallback uses jnp.zeros(shape).)
    """
    P, P_super = _make_xy_arrays(shape, oversample_n)

    out = {}

    # Disk: exponential profile (n=1) with elliptical isophotes (c=2). The
    # intensity is normalised with the same n and c used to render it.
    disk_I = sersic_I(
        model[('disk', 'L')], model[('disk', 'Re')],
        model[('disk', 'q')], 1, 2
    )
    disk_super = sersic(
        *P_super,
        mux=model[('disk', 'mux')],
        muy=model[('disk', 'muy')],
        roll=model[('disk', 'roll')],
        q=model[('disk', 'q')],
        Re=model[('disk', 'Re')],
        I=disk_I,
        n=1.0,
        c=2.0,
    )

    out['disk'] = jnp.squeeze(downsample(disk_super, oversample_n))

    # next add spirals to the disk
    if n_spirals > 0:
        spirals = get_spirals(model, n_spirals, base_roll)
        # One distance map per arm, stacked along the last axis.
        spiral_distances = jnp.stack([
            vmap_polyline_distance(s, *P)
            for s in spirals
        ], axis=-1)

        Is = jnp.array([
            model[('spiral', 'I.{}'.format(i))]
            for i in range(n_spirals)
        ])
        spreads = jnp.array([
            model[('spiral', 'spread.{}'.format(i))] for i in range(n_spirals)
        ])
        # Each arm is a Gaussian ridge around its polyline, modulated by the
        # disk brightness; the arms are then summed into one image.
        spirals = jnp.sum(
            Is
            * jnp.exp(-spiral_distances**2 / (2*spreads**2))
            * jnp.expand_dims(out['disk'], -1),
            axis=-1
        )
        out['spiral'] = spirals
    else:
        spirals = jnp.zeros(shape)

    # calculate the luminosity of the disk and spirals together (the bulge and
    # bar fractions are calculated relative to this)
    disk_spiral_L = model[('disk', 'L')] + spirals.sum()

    # if we have a bulge, render it
    if has_bulge:
        # bulge_frac assumes we don't have a bar; L is chosen so that
        # bulge_L / (bulge_L + disk_spiral_L) == frac
        bulge_L = (
            model[('bulge', 'frac')] * (disk_spiral_L)
            / (1 - model[('bulge', 'frac')])
        )
        bulge_Re = model[('disk', 'Re')] * model[('bulge', 'scale')]
        bulge_I = sersic_I(
            bulge_L, bulge_Re, model[('bulge', 'q')], model[('bulge', 'n')]
        )
        bulge_super = sersic(
            *P_super,
            mux=model[('centre', 'mux')],
            muy=model[('centre', 'muy')],
            roll=model[('bulge', 'roll')],
            q=model[('bulge', 'q')],
            Re=bulge_Re,
            I=bulge_I,
            n=model[('bulge', 'n')],
            c=2.0
        )
        out['bulge'] = jnp.squeeze(downsample(bulge_super, oversample_n))

    # if we have a bar, render it
    if has_bar:
        # bar_frac assumes we don't have a bulge; L is chosen so that
        # bar_L / (bar_L + disk_spiral_L) == frac
        bar_L = (
            model[('bar', 'frac')] * (disk_spiral_L)
            / (1 - model[('bar', 'frac')])
        )
        bar_Re = model[('disk', 'Re')] * model[('bar', 'scale')]
        # The bar is rendered with its own boxiness c, so the intensity must
        # be normalised with that same c (as the disk call does explicitly);
        # otherwise the rendered bar luminosity would not equal bar_L.
        bar_I = sersic_I(
            bar_L, bar_Re, model[('bar', 'q')], model[('bar', 'n')],
            model[('bar', 'c')]
        )
        bar_super = sersic(
            *P_super,
            mux=model[('centre', 'mux')],
            muy=model[('centre', 'muy')],
            roll=model[('bar', 'roll')],
            q=model[('bar', 'q')],
            Re=bar_Re,
            I=bar_I,
            n=model[('bar', 'n')],
            c=model[('bar', 'c')],
        )
        out['bar'] = jnp.squeeze(downsample(bar_super, oversample_n))

    # return the dictionary of rendered components
    return out
Exemple #28
0
    def call(
        self,
        query: jnp.ndarray,
        key: tp.Optional[jnp.ndarray] = None,
        value: tp.Optional[jnp.ndarray] = None,
        mask=None,
        training=None,
    ):
        """Apply multi-head dot-product attention.

        Arguments:
            query: array of shape `(..., query_elements, query_depth)`.
            key: optional array of shape `(..., key_elements, key_depth)`;
                `query` is used when it is omitted.
            value: optional array of shape `(..., key_elements, value_depth)`;
                `key` (after its own default) is used when it is omitted.
            mask: optional binary array of shape
                `[batch_size?, num_heads?, query_elements, key_elements]`
                specifying which query elements can attend to which key
                elements: `1` allows attention and `0` blocks it.
            training: passed through to the attention-dropout layer.

        Returns:
            The output of shape `(..., query_elements, output_size)`, where
            `output_size` is `self.output_size` when set and the last
            dimension of `value` otherwise. If `self.return_attn_coef` is
            true, the tuple `(output, attn_coef)` is returned instead, where
            `attn_coef` are the softmax weights before dropout.

        Raises:
            ValueError: if `key` and `value` have different element counts,
                or `mask` has fewer than 2 dimensions or does not match the
                query/key element counts.
        """
        # Einsum labels: N = query elements, M = key/value elements,
        # H = heads, I = input features, O = output features.
        key = query if key is None else key
        value = key if value is None else value

        if self.output_size is not None:
            output_size = self.output_size
        else:
            output_size = value.shape[-1]

        # Shape checks: key/value agreement first, then the mask.
        if key.shape[-2] != value.shape[-2]:
            raise ValueError(
                "the number of elements in 'key' must be equal to the same as the number of elements in 'value'"
            )
        if mask is not None:
            if len(mask.shape) < 2:
                raise ValueError("'mask' must have atleast 2 dimensions")
            if query.shape[-2] != mask.shape[-2]:
                raise ValueError(
                    "mask's second to last dimension must be equal to the number of elements in 'query'"
                )
            if key.shape[-2] != mask.shape[-1]:
                raise ValueError(
                    "mask's last dimension must be equal to the number of elements in 'key'"
                )

        def kernel(name, shape):
            return hooks.get_parameter(
                name, shape, jnp.float32, initializer=self.kernel_initializer
            )

        # Parameters are created in a fixed order: query, key, value, projection.
        w_query = kernel(
            "query_kernel", [self.num_heads, query.shape[-1], self.head_size]
        )
        w_key = kernel(
            "key_kernel", [self.num_heads, key.shape[-1], self.head_size]
        )
        w_value = kernel(
            "value_kernel", [self.num_heads, value.shape[-1], self.head_size]
        )
        w_proj = kernel(
            "projection_kernel", [self.num_heads, self.head_size, output_size]
        )

        # Per-head linear projections.
        q = jnp.einsum("...NI , HIO -> ...NHO", query, w_query)
        k = jnp.einsum("...MI , HIO -> ...MHO", key, w_key)
        v = jnp.einsum("...MI , HIO -> ...MHO", value, w_value)

        # Scaling only q (rather than q·k) is cheaper and gives the same logits.
        q = q / jnp.sqrt(self.head_size)

        logits = jnp.einsum("...NHO,...MHO->...HNM", q, k)

        if mask is not None:
            mask = mask.astype(jnp.float32)
            # Insert a heads axis when the mask lacks one, so it broadcasts.
            if len(mask.shape) != len(logits.shape):
                mask = jnp.expand_dims(mask, -3)
            # Push blocked positions to a large negative logit.
            logits = logits + -10e9 * (1.0 - mask)

        attn_coef = jax.nn.softmax(logits)
        dropped_coef = Dropout(self.droput_rate)(attn_coef, training=training)

        heads = jnp.einsum("...HNM,...MHI->...NHI", dropped_coef, v)
        # Projecting with an H-indexed kernel also merges the heads.
        output = jnp.einsum("...NHI,HIO->...NO", heads, w_proj)

        if self.use_projection_bias:
            output = output + hooks.get_parameter(
                "projection_bias",
                [output_size],
                jnp.float32,
                initializer=self.bias_initializer,
            )

        return (output, attn_coef) if self.return_attn_coef else output
Exemple #29
0
 def wrapper(*args, **kwargs):
     """Call ``f`` with a size-1 axis inserted into every input leaf.

     Every leaf in ``args``/``kwargs`` gets a singleton dimension inserted at
     ``axis`` (both ``f`` and ``axis`` come from the enclosing scope). Every
     leaf of ``f``'s output then has that same ``axis`` squeezed away.
     """
     # ``jax.tree_map`` is deprecated and removed in recent JAX releases;
     # ``jax.tree_util.tree_map`` is the stable spelling with the same behavior.
     def expand(t):
         return jnp.expand_dims(t, axis=axis)

     args = jax.tree_util.tree_map(expand, args)
     kwargs = jax.tree_util.tree_map(expand, kwargs)
     outputs = f(*args, **kwargs)
     return jax.tree_util.tree_map(lambda t: jnp.squeeze(t, axis=axis), outputs)
Exemple #30
0
  def test_mention_memory_layer(self, separate_memory_values):
    """Testing memory attention layer.

    Runs ``init_with_output`` under ``jax.pmap`` twice. The first run has no
    text identifiers; it checks output shapes and that only mention start
    positions are modified. The second run passes text identifiers; it checks
    that the ids are unchanged, that matching ids get zero attention, and the
    reported number of disallowed entries.
    """

    test_utils.force_multi_devices(self.n_devices)
    devices = jax.local_devices()

    model = memory_attention_layer.MemoryAttentionLayer(
        memory_key_dim=self.memory_key_dim,
        input_dim=self.input_dim,
        memory_update_type=self.memory_update_type,
        memory_update_config=self.memory_update_config,
        k_top_device=self.k_top_device,
        k_top_post_selection=self.k_top_post_selection,
        splits=self.splits,
        dtype=self.dtype)

    # Positional arg 9 is the `deterministic` bool. Arg 10 is `memory_values`,
    # which is None (so must also be static) unless values are separate.
    # `(9,)` is an explicit tuple; the original `(9)` was just the int 9.
    static_argnums = (9,) if separate_memory_values else (9, 10)
    pinit_with_output = jax.pmap(
        model.init_with_output,
        axis_name='batch',
        static_broadcasted_argnums=static_argnums)

    rng = jax.random.PRNGKey(0)
    split_rng = jax.random.split(rng, self.n_devices)
    encoded_input = jnp.ones(
        shape=(self.bsz, self.seq_len, self.input_dim), dtype=self.dtype)
    encoded_input = jax.device_put_replicated(encoded_input, devices)

    # Three mentions per batch element: [0, 0, 0, 1, 1, 1, ...].
    mention_batch_positions = jnp.tile(
        jnp.arange(self.bsz).reshape(-1, 1), (1, 3)).reshape(-1)
    mention_batch_positions = jax.device_put_replicated(mention_batch_positions,
                                                        devices)

    mention_start_positions = jnp.tile(jnp.asarray([0, 5, 10]), (self.bsz))
    mention_start_positions = jax.device_put_replicated(mention_start_positions,
                                                        devices)

    mention_end_positions = jnp.tile(jnp.asarray([2, 7, 12]), (self.bsz))
    mention_end_positions = jax.device_put_replicated(mention_end_positions,
                                                      devices)

    # Mentions per device (last axis after replication).
    n_mentions = mention_start_positions.shape[-1]

    mention_mask = jnp.tile(jnp.asarray([1, 1, 1]), (self.bsz))
    mention_mask = jax.device_put_replicated(mention_mask, devices)

    # Built with numpy first because rows are modified in place below.
    memory_table = np.ones(
        (self.n_devices * self.table_size, self.memory_key_dim),
        dtype=self.dtype)
    # Make sure id 0 or 1 will be highest scoring
    memory_table[0] = memory_table[0] * 2.0
    memory_table[1] = memory_table[1] * -2.0
    memory_table = jnp.asarray(memory_table, dtype=self.dtype)

    memory_keys = memory_table.reshape(self.n_devices, self.rows,
                                       self.table_size // self.rows,
                                       self.memory_key_dim)

    memory_keys_sharded = jax.device_put_sharded(list(memory_keys), devices)
    if separate_memory_values:
      memory_values = memory_table.reshape(self.n_devices, self.table_size,
                                           self.memory_key_dim)
      memory_values = jax.device_put_sharded(list(memory_values), devices)
    else:
      memory_values = None

    memory_entity_ids = np.arange(self.n_devices * self.table_size).reshape(
        self.n_devices, self.table_size)
    memory_entity_ids = jax.device_put_sharded(list(memory_entity_ids), devices)

    # Use entity id as identifier here
    memory_identifiers = memory_entity_ids

    (encoded_output, loss_helpers, _), _ = pinit_with_output(
        split_rng,
        encoded_input,
        mention_batch_positions,
        mention_start_positions,
        mention_end_positions,
        mention_mask,
        memory_keys_sharded,
        memory_identifiers,
        memory_entity_ids,
        True,  # deterministic
        memory_values,
        text_identifiers=None,
    )

    attention_weights = loss_helpers['memory_attention_weights']
    entity_ids = loss_helpers['top_entity_ids']

    normed_input = encoded_input - 1.0

    # Check input was changed
    self.assertFalse(jnp.allclose(encoded_output, normed_input))

    # Check input was not changed where it should not be
    all_indices = set(
        itertools.product(np.arange(self.bsz), np.arange(self.seq_len)))
    # Note that mention positions is the same across all of the devices
    start_indices = set(
        zip(mention_batch_positions[0].tolist(),
            mention_start_positions[0].tolist()))
    non_start_indices = all_indices.difference(start_indices)
    non_start_indices_1, non_start_indices_2 = zip(*non_start_indices)
    non_start_indices_1 = jnp.asarray(non_start_indices_1)
    non_start_indices_2 = jnp.asarray(non_start_indices_2)

    non_start_outputs = encoded_output[:, non_start_indices_1,
                                       non_start_indices_2]
    non_start_inputs = normed_input[:, non_start_indices_1, non_start_indices_2]
    self.assertTrue(jnp.allclose(non_start_outputs, non_start_inputs))

    # Check shapes as expected
    self.assertSequenceEqual(
        encoded_output.shape,
        (self.n_devices, self.bsz, self.seq_len, self.input_dim))

    self.assertSequenceEqual(
        attention_weights.shape,
        (self.n_devices, n_mentions, self.k_top_post_selection))

    self.assertSequenceEqual(
        entity_ids.shape,
        (self.n_devices, n_mentions, self.k_top_post_selection))

    # Check id 0 or 1 retrieved (element-wise logical OR of the two tests)
    self.assertTrue(
        jnp.all((entity_ids[..., 0] == 0) | (entity_ids[..., 0] == 1)))

    # Set some text identifiers to 0 and others to 1 so that some are binding
    text_identifiers = np.zeros((n_mentions), dtype=np.int32)
    text_identifiers[:n_mentions // 2] = 1
    text_identifiers = jax.device_put_replicated(text_identifiers, devices)

    # Initialize and run one forward pass of model
    (_, loss_helpers, logging_helpers), _ = pinit_with_output(
        split_rng,
        encoded_input,
        mention_batch_positions,
        mention_start_positions,
        mention_end_positions,
        mention_mask,
        memory_keys_sharded,
        memory_identifiers,
        memory_entity_ids,
        True,  # deterministic
        memory_values,  # memory_values
        text_identifiers=text_identifiers,
    )
    attention_weights_wid = loss_helpers['memory_attention_weights']
    entity_ids_wid = loss_helpers['top_entity_ids']
    n_disallowed = logging_helpers['n_disallowed'][0]

    # Check no effect on ids
    self.assertTrue(jnp.all(entity_ids == entity_ids_wid))

    # Check id 0 or 1 have 0 scores
    text_identifiers = jnp.expand_dims(text_identifiers, -1)
    score_masked = (text_identifiers == entity_ids_wid) * attention_weights_wid
    self.assertAlmostEqual(score_masked.sum(), 0.0)

    # Check number disallowed as expected
    self.assertEqual(n_disallowed, n_mentions // 2)