def grad_fn(g):
    g_logits = jnp.expand_dims(g, axis=-1) * (exp_shifted / sum_exp - targets)
    return jnp.asarray(g_logits, logits.dtype), jnp.asarray(g, targets.dtype)
def concrete_bound(relax_params):
    return self._bind(
        scanner.node_relaxations, relax_params).concrete_bound_chunk(
            graph, inputs, env, node_ref, jnp.expand_dims(one_obj, 0))
def dot_product_attention( self, query, key, value, dtype=jnp.float32, bias=None, axis=None, broadcast_dropout=True, dropout_rng=None, dropout_rate=0.0, deterministic=False, precision=None, ): assert key.shape[:-1] == value.shape[:-1] assert query.shape[0:1] == key.shape[0:1] and query.shape[ -1] == key.shape[-1] if axis is None: axis = tuple(range(1, key.ndim - 2)) if not isinstance(axis, Iterable): axis = (axis, ) assert key.ndim == query.ndim assert key.ndim == value.ndim for ax in axis: if not (query.ndim >= 3 and 1 <= ax < query.ndim - 2): raise ValueError( "Attention axis must be between the batch axis and the last-two axes." ) n = key.ndim # Constructing projection tensor. if self.redraw_features: # TODO(kchoro): Get rid of the constant below. query_seed = lax.convert_element_type( jnp.ceil(jnp.sum(query) * 10000000.0), jnp.int32) rng = random.PRNGKey(query_seed) self.projection_matrix = self.draw_weights(rng) # batch_dims is <bs, <non-attention dims>, num_heads> batch_dims = tuple(onp.delete(range(n), axis + (n - 1, ))) # q & k -> (bs, <non-attention dims>, num_heads, <attention dims>, channels) qk_perm = batch_dims + axis + (n - 1, ) k_extra_perm = axis + batch_dims + (n - 1, ) key_extra = key.transpose(k_extra_perm) key = key.transpose(qk_perm) query = query.transpose(qk_perm) # v -> (bs, <non-attention dims>, num_heads, <attention dims>, channels) v_perm = batch_dims + axis + (n - 1, ) value = value.transpose(v_perm) batch_dims_t = tuple(range(len(batch_dims))) attention_dims_t = tuple( range(len(batch_dims), len(batch_dims) + len(axis))) # Constructing tensors Q^{'} and K^{'}. query_prime = self.kernel_feature_creator(query, self.projection_matrix, attention_dims_t, batch_dims_t, precision, True) key_prime = self.kernel_feature_creator(key, self.projection_matrix, attention_dims_t, batch_dims_t, precision, False) if self.unidirectional: index = attention_dims_t[0] z_slice_shape = key_prime.shape[0:len(batch_dims_t)] + ( key_prime.shape[-1], ) + (value.shape[-1], ) numerator_fn = _numerator(z_slice_shape, precision, self.lax_scan_unroll) W = numerator_fn(jnp.moveaxis(query_prime, index, 0), jnp.moveaxis(key_prime, index, 0), jnp.moveaxis(value, index, 0)) # Constructing W = (Q^{'}(K^{'})^{T})_{masked}V W = jnp.moveaxis(W, 0, index) if not self.renormalize_attention: # Unidirectional, not-normalized attention. perm_inv = _invert_perm(qk_perm) result = W.transpose(perm_inv) return result else: # Unidirectional, normalized attention. 
thick_all_ones = jnp.zeros(key.shape[0:-1]) + jnp.ones( key_extra.shape[0:len(axis)]) index = attention_dims_t[0] t_slice_shape = key_prime.shape[0:len(batch_dims_t)] + ( key_prime.shape[-1], ) denominator_fn = _denominator(t_slice_shape, precision, self.lax_scan_unroll) R = denominator_fn(jnp.moveaxis(query_prime, index, 0), jnp.moveaxis(key_prime, index, 0)) R = jnp.moveaxis(R, 0, index) else: contract_query = tuple( range( len(batch_dims) + len(axis), len(batch_dims) + len(axis) + 1)) contract_z = tuple(range(len(batch_dims), len(batch_dims) + 1)) # Constructing Z = (K^{'})^{T}V # Z (bs, <non-attention dims>, num_heads, channels_m, channels_v) Z = lax.dot_general( key_prime, value, ((attention_dims_t, attention_dims_t), (batch_dims_t, batch_dims_t)), precision=precision, ) # Constructing W = Q^{'}Z = Q^{'}(K^{'})^{T}V # q (bs, <non-attention dims>, num_heads, <attention dims>, channels_m) # Z (bs, <non-attention dims>, num_heads, channels_m, channels_v) # W (bs, <non-attention dims>, num_heads, <attention dims>, channels_v) W = lax.dot_general(query_prime, Z, ((contract_query, contract_z), (batch_dims_t, batch_dims_t)), precision=precision) if not self.renormalize_attention: # Bidirectional, not-normalized attention. perm_inv = _invert_perm(qk_perm) result = W.transpose(perm_inv) return result else: # Bidirectional, normalized attention. thick_all_ones = jnp.zeros(key.shape[0:-1]) + jnp.ones( key_extra.shape[0:len(axis)]) contract_key = tuple( range(len(batch_dims), len(batch_dims) + len(axis))) contract_thick_all_ones = tuple( range(thick_all_ones.ndim - len(axis), thick_all_ones.ndim)) # Construct T = (K^{'})^{T} 1_L # k (bs, <non-attention dims>, num_heads, <attention dims>, channels) T = lax.dot_general( key_prime, thick_all_ones, ((contract_key, contract_thick_all_ones), (batch_dims_t, batch_dims_t)), precision=precision, ) # Construct partition function: R = Q^{'} T = Q^{'}(K^{'})^{T} 1_L # q_p (bs, <non-attention dims>, num_heads, <attention dims>, channs_m) # T (bs, <non-attention dims>, num_heads, channels_m) R = lax.dot_general( query_prime, T, (((query_prime.ndim - 1, ), (T.ndim - 1, )), (batch_dims_t, range(0, len(T.shape) - 1))), precision=precision, ) R = R + 2 * self.numerical_stabilizer * (jnp.abs(R) <= self.numerical_stabilizer) R = jnp.reciprocal(R) R = jnp.expand_dims(R, len(R.shape)) # W (bs, <non-attention dims>, num_heads, <attention dims>, channels_v) # R (bs, <non-attention dims>, num_heads, <attention dims>, extra_channel) result = W * R # back to (bs, dim1, dim2, ..., dimN, num_heads, channels) perm_inv = _invert_perm(qk_perm) result = result.transpose(perm_inv) return result
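# A minimal, self-contained sketch of the bidirectional linear-attention path
# above: given feature maps Q' and K', attention is computed as Q'((K')^T V),
# normalized by Q'((K')^T 1_L), so the L x L attention matrix is never formed.
# The exp-based positive feature map below is an illustrative stand-in, not the
# kernel_feature_creator used by the class; names here are hypothetical.
import jax
import jax.numpy as jnp

def toy_linear_attention(q, k, v, key=jax.random.PRNGKey(0), num_features=64):
    # q, k: [L, d]; v: [L, d_v]
    d = q.shape[-1]
    proj = jax.random.normal(key, (num_features, d)) / jnp.sqrt(d)

    def features(x):
        # Positive random features (softmax-kernel style, up to constants).
        return jnp.exp(x @ proj.T - jnp.sum(x**2, axis=-1, keepdims=True) / 2)

    q_prime = features(q)                      # [L, m]
    k_prime = features(k)                      # [L, m]
    z = k_prime.T @ v                          # [m, d_v]  == (K')^T V
    t = k_prime.T @ jnp.ones(k.shape[0])       # [m]       == (K')^T 1_L
    numerator = q_prime @ z                    # [L, d_v]
    denominator = q_prime @ t                  # [L]
    return numerator / denominator[:, None]

out = toy_linear_attention(jnp.ones((8, 16)), jnp.ones((8, 16)), jnp.ones((8, 4)))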
def apply(self, inputs, info, config, train=False, cache=None): """Apply the full IPAGNN model to a batch of input programs. Args: inputs: A dictionary with the following fields, each with a leading batch dimension. - true_branch_nodes: For each node in the statement-level control flow graph, the index of the node that would be reached if the true branch were followed. If not a branch node, this is simply the index of the next node and matches the index given by false_indexes. - false_branch_nodes: For each node in the statement-level control flow graph, the index of the node that would be reached if the false branch were followed. If not a branch node, this is simply the index of the next node and matches the index given by true_indexes. - start_index: The node index where the function starts. - exit_index: The node index of the exit-node. Both the true- and false- index of the exit node are the exit node itself. - steps: The maximum number of model steps to take for a particular program. - data: Has shape (4, number of nodes). Each 4-tuple represents a single statement in the program. The meaning of each entry in a 4-tuple is described in Figure 1 of the paper. info: Information about the dataset. config: The experimental config. train: (bool) Whether the model is being trained. cache: Unused. Returns: The logits predicted from each program in the batch's output nodes. """ # Inputs true_indexes = inputs['true_branch_nodes'] false_indexes = inputs['false_branch_nodes'] start_indexes = inputs['start_index'] # pylint: disable=unused-variable exit_indexes = inputs['exit_index'] steps_all = inputs['steps'] vocab_size = info.features[info._builder.key('statements')].vocab_size # pylint: disable=protected-access output_token_vocabulary_size = info.output_vocab_size hidden_size = config.model.hidden_size data = inputs['data'].astype('int32') batch_size, num_nodes, unused_statement_length = data.shape # An upper bound on the number of steps to take. max_steps = int(1.5 * info.max_diameter) # Init parameters def emb_init(key, shape, dtype=jnp.float32): return jax.random.uniform(key, shape, dtype, -config.initialization.maxval, config.initialization.maxval) embed = Embed.shared(num_embeddings=vocab_size, features=hidden_size, emb_init=emb_init, name='embed') branch_decide_dense = nn.Dense.shared( name='branch_decide_dense', features=2, kernel_init=nn.initializers.xavier_uniform(), bias_init=nn.initializers.normal(stddev=1e-6)) cells = create_lstm_cells(config.model.rnn_cell.layers) lstm = StackedRNNCell.shared(cells=cells) output_dense = nn.Dense.shared( name='output_dense', features=output_token_vocabulary_size, kernel_init=nn.initializers.xavier_uniform(), bias_init=nn.initializers.normal(stddev=1e-6)) # Init state def _create_hidden_states(): rng = jax.random.PRNGKey(0) return StackedRNNCell.initialize_carry(rng, cells, ( batch_size, num_nodes, ), hidden_size) def _create_instruction_pointer(): return jax.ops.index_add( jnp.zeros(( batch_size, num_nodes, )), jax.ops. index[:, 0], # TODO(dbieber): Use "start_index" instead of 0. 
1) hidden_states = _create_hidden_states() # leaves(hidden_states).shape: batch_size, num_nodes, hidden_size instruction_pointer = _create_instruction_pointer() # instruction_pointer.shape: batch_size, num_nodes, node_embeddings = embed(data) # node_embeddings.shape: # batch_size, num_nodes, statement_length, hidden_size # Apply def execute_single_node(hidden_state, node_embedding): carry, _ = lax.scan(lstm, hidden_state, node_embedding) return carry execute = jax.vmap(execute_single_node) def branch_decide_single_node(hidden_state): # leaves(hidden_state).shape: hidden_size hidden_state_concat = jnp.concatenate( jax.tree_leaves(hidden_state), axis=0) return branch_decide_dense(hidden_state_concat) branch_decide = jax.vmap(branch_decide_single_node) def update_instruction_pointer(instruction_pointer, branch_decisions, true_indexes, false_indexes): # instruction_pointer.shape: num_nodes, # branch_decisions: num_nodes, 2, # true_indexes: num_nodes, # false_indexes: num_nodes p_true = branch_decisions[:, 0] p_false = branch_decisions[:, 1] true_contributions = jax.ops.segment_sum(p_true * instruction_pointer, true_indexes, num_segments=num_nodes) false_contributions = jax.ops.segment_sum(p_false * instruction_pointer, false_indexes, num_segments=num_nodes) return true_contributions + false_contributions def aggregate(hidden_states, instruction_pointer, branch_decisions, true_indexes, false_indexes): # leaves(hidden_states).shape: num_nodes, hidden_size # instruction_pointer.shape: num_nodes, # branch_decisions: num_nodes, 2, # true_indexes: num_nodes, # false_indexes: num_nodes p_true = branch_decisions[:, 0] p_false = branch_decisions[:, 1] denominators = update_instruction_pointer(instruction_pointer, branch_decisions, true_indexes, false_indexes) denominators += 1e-7 # denominator.shape: num_nodes, def aggregate_component(h): # h.shape: num_nodes # p_true.shape: num_nodes # instruction_pointer.shape: num_nodes true_contributions = jax.ops.segment_sum( h * p_true * instruction_pointer, true_indexes, num_segments=num_nodes) false_contributions = jax.ops.segment_sum( h * p_false * instruction_pointer, false_indexes, num_segments=num_nodes) # *_contributions.shape: num_nodes, hidden_size return (true_contributions + false_contributions) / denominators aggregate_component = jax.vmap(aggregate_component, in_axes=1, out_axes=1) return jax.tree_map(aggregate_component, hidden_states) def step_single_example(hidden_states, instruction_pointer, node_embeddings, true_indexes, false_indexes, exit_index): # Execution (e.g. apply RNN) # leaves(hidden_states).shape: num_nodes, hidden_size # instruction_pointer.shape: num_nodes, # node_embeddings.shape: num_nodes, statement_length, hidden_size hidden_state_contributions = execute(hidden_states, node_embeddings) # leaves(hidden_state_contributions).shape: num_nodes, hidden_size # Use the exit node's hidden state as it's hidden state contribution # to avoid "executing" the exit node. def mask_h(h_contribution, h): return h_contribution.at[exit_index, :].set(h[exit_index, :]) hidden_state_contributions = jax.tree_multimap( mask_h, hidden_state_contributions, hidden_states) # Branch decisions (e.g. 
Dense layer) branch_decision_logits = branch_decide(hidden_state_contributions) branch_decisions = nn.softmax(branch_decision_logits, axis=-1) # Update state instruction_pointer_new = update_instruction_pointer( instruction_pointer, branch_decisions, true_indexes, false_indexes) hidden_states_new = aggregate(hidden_state_contributions, instruction_pointer, branch_decisions, true_indexes, false_indexes) to_tag = { 'branch_decisions': branch_decisions, 'hidden_state_contributions': hidden_state_contributions, 'hidden_states_before': hidden_states, 'hidden_states': hidden_states_new, 'instruction_pointer_before': instruction_pointer, 'instruction_pointer': instruction_pointer_new, 'true_indexes': true_indexes, 'false_indexes': false_indexes, } return hidden_states_new, instruction_pointer_new, to_tag def compute_logits_single_example(hidden_states, instruction_pointer, exit_index, steps, node_embeddings, true_indexes, false_indexes): """single_example refers to selecting a single exit node hidden state.""" # leaves(hidden_states).shape: num_nodes, hidden_size def step_(carry, _): hidden_states, instruction_pointer, index = carry hidden_states_new, instruction_pointer_new, to_tag = ( step_single_example(hidden_states, instruction_pointer, node_embeddings, true_indexes, false_indexes, exit_index)) carry = jax.tree_multimap( lambda new, old, index=index: jnp.where( index < steps, new, old), (hidden_states_new, instruction_pointer_new, index + 1), (hidden_states, instruction_pointer, index + 1), ) return carry, to_tag if config.model.ipagnn.checkpoint and not self.is_initializing(): step_ = jax.checkpoint(step_) carry = (hidden_states, instruction_pointer, jnp.array([0])) (hidden_states, instruction_pointer, _), to_tag = lax.scan(step_, carry, None, length=max_steps) final_state = jax.tree_map(lambda hs: hs[exit_index], hidden_states) # leaves(final_state).shape: hidden_size final_state_concat = jnp.concatenate(jax.tree_leaves(final_state), axis=0) logits = output_dense(final_state_concat) to_tag.update({ 'instruction_pointer_final': instruction_pointer, 'hidden_states_final': hidden_states, }) return logits, to_tag compute_logits = jax.vmap(compute_logits_single_example, in_axes=(0, 0, 0, 0, 0, 0, 0)) logits, to_tag = compute_logits(hidden_states, instruction_pointer, exit_indexes, steps_all, node_embeddings, true_indexes, false_indexes) for key, value in to_tag.items(): value = Tag(value, name=key) logits = jnp.expand_dims(logits, axis=1) return logits
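# A small, self-contained sketch (not taken from the model above) of the soft
# instruction-pointer update performed in update_instruction_pointer: each
# node's probability mass is split between its true- and false-branch
# successors and re-aggregated with segment_sum, so total mass is conserved.
import jax.numpy as jnp
from jax import ops

def soft_ip_update(instruction_pointer, p_true, true_indexes, false_indexes):
    # instruction_pointer, p_true: [num_nodes]; *_indexes: [num_nodes] int
    num_nodes = instruction_pointer.shape[0]
    p_false = 1.0 - p_true
    true_contrib = ops.segment_sum(p_true * instruction_pointer, true_indexes,
                                   num_segments=num_nodes)
    false_contrib = ops.segment_sum(p_false * instruction_pointer, false_indexes,
                                    num_segments=num_nodes)
    return true_contrib + false_contrib

# Toy 3-node graph: node 0 branches to node 1 (true) or node 2 (false);
# nodes 1 and 2 loop on themselves. Mass 0.7 flows to node 1, 0.3 to node 2.
ip = jnp.array([1.0, 0.0, 0.0])
p_true = jnp.array([0.7, 1.0, 1.0])
new_ip = soft_ip_update(ip, p_true, jnp.array([1, 1, 2]), jnp.array([2, 1, 2]))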
def _cov_helper_with_p(data, p):
    return jnp.expand_dims(
        jnp.matmul(jnp.conj(jnp.transpose(data)), jnp.multiply(p[:, None], data)),
        axis=0)
def expand_dims(self: TensorType, axis: int) -> TensorType:
    return type(self)(np.expand_dims(self.raw, axis=axis))
def test_input_admin(t, y, r, t_test, y_test, r_test):
    """
    TODO: tidy this function up
    Order the inputs, remove duplicates, and index the train and test input locations.
    :param t: training inputs [N, 1]
    :param y: observations at the training inputs [N, 1]
    :param r: training spatial inputs
    :param t_test: testing inputs [N*, 1]
    :param y_test: observations at the test inputs [N*, 1]
    :param r_test: test spatial inputs
    :return:
        t_all: the combined and sorted training and test inputs [N + N*, 1]
        y_all: an array of observations y augmented with nans at test locations [N + N*, R]
        r_all: spatial inputs with nans at test locations [N + N*, R]
        r_test: the sorted test spatial inputs [N*, R]
        dt_all: combined training and test step sizes, Δtₙ = tₙ - tₙ₋₁ [N + N*, 1]
        train_id: an array of indices corresponding to the training inputs [N, 1]
        test_id: an array of indices corresponding to the test inputs [N*, 1]
        mask: boolean array to signify training locations [N + N*, 1]
    """
    assert t.shape[0] == y.shape[0]
    if t.ndim < 2:
        t = np.expand_dims(t, 1)  # make 2-D
    if y.ndim < 2:
        y = np.expand_dims(y, 1)  # make 2-D
    if r is None:
        r = np.nan * t  # np.empty((1,) + x.shape[1:]) * np.nan
    if r.ndim < 2:
        r = np.expand_dims(r, 1)  # make 2-D
    ind = np.argsort(t[:, 0], axis=0)
    t_train = t[ind, ...]
    y_train = y[ind, ...]
    r_train = r[ind, ...]
    if t_test is None:
        t_test = np.empty((1,) + t_train.shape[1:]) * np.nan
        r_test = np.empty((1,) + t_train.shape[1:]) * np.nan
    else:
        if t_test.ndim < 2:
            t_test = np.expand_dims(t_test, 1)  # make 2-D
        test_sort_ind = np.argsort(t_test[:, 0], axis=0)
        t_test = t_test[test_sort_ind, ...]
        if y_test is not None:
            y_test = y_test[test_sort_ind, ...].reshape((-1,) + y.shape[1:])
        if r_test is not None:
            r_test = r_test[test_sort_ind, ...]
        else:
            r_test = np.nan * t_test
    if not (t_test.shape[1] == t_train.shape[1]):
        t_test = np.concatenate([
            t_test[:, 0][:, None],
            np.nan * np.empty([t_test.shape[0], t_train.shape[1] - 1])
        ], axis=1)
    # here we use non-JAX numpy to sort out indexing of these static arrays
    t_train_test = np.concatenate([t_train, t_test])
    keep_ind = ~np.isnan(t_train_test[:, 0])
    t_train_test = t_train_test[keep_ind, ...]
    # do spatial test points have different dimensionality to training points?
    if r_test.shape[1] != r_train.shape[1]:
        r_test_nan = np.nan * np.zeros([r_test.shape[0], r_train.shape[1]])
    else:
        r_test_nan = r_test
    r_train_test = np.concatenate([r_train, r_test_nan])
    r_train_test = r_train_test[keep_ind, ...]
    t_ind = np.argsort(t_train_test[:, 0])
    t_all = t_train_test[t_ind]
    r_all = r_train_test[t_ind]
    reverse_ind = np.argsort(t_ind)
    n_train = t_train.shape[0]
    train_id = reverse_ind[:n_train]  # index the training locations
    test_id = reverse_ind[n_train:]  # index the test locations
    # observation vector with nans at test locations
    y_all = np.nan * np.zeros([t_all.shape[0], y_train.shape[1]])
    # y_all[reverse_ind[:n_train], ...] = y_train  # and the data at the train locations
    y_all = index_update(y_all, index[reverse_ind[:n_train]], y_train)  # and the data at the train locations
    if y_test is not None:
        # y_all[reverse_ind[n_train:], ...] = y_test  # and the data at the test locations
        y_all = index_update(y_all, index[reverse_ind[n_train:]], y_test)  # and the data at the test locations
    mask = np.ones_like(y_all, dtype=bool)
    # mask[train_id] = False
    mask = index_update(mask, index[train_id], False)
    dt_all = np.concatenate([np.array([0.0]), np.diff(t_all[:, 0])])
    return (np.array(t_all, dtype=np.float64),
            np.array(y_all, dtype=np.float64),
            np.array(r_all, dtype=np.float64),
            np.array(r_test, dtype=np.float64),
            np.array(dt_all, dtype=np.float64),
            np.array(train_id, dtype=np.int64),
            np.array(test_id, dtype=np.int64),
            np.array(mask, dtype=bool))
def add_batch_dim(values: types.Nest) -> types.NestedArray:
    return tree_util.tree_map(lambda x: jnp.expand_dims(x, axis=0), values)
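# Minimal usage sketch for the pattern above (toy nest; the types module is
# assumed, so plain dicts are used here): every leaf gains a leading axis of
# size 1.
import jax
import jax.numpy as jnp

example = {'obs': jnp.zeros((3,)), 'reward': jnp.zeros(())}
batched = jax.tree_util.tree_map(lambda x: jnp.expand_dims(x, axis=0), example)
assert batched['obs'].shape == (1, 3) and batched['reward'].shape == (1,)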
Hs0 = Hs(stfc.z[0:1])
Hsf = Hs(stfc.z[-1:])
pHs0 = pHs(stfc.z[0:1])
pHsf = pHs(stfc.z[-1:])
Hc = ctfc.H

## DEFINE THE ASSUMED SOLUTION **************************************
z = stfc.z
z0 = z[0]
zf = z[-1]

## DEFINE SWITCHING FUNCTIONS ***************************************
phi1 = lambda a: np.expand_dims(
    1./(zf-z0)**3 * (-zf**2*(3.*z0-zf) + 6.*z0*zf*a - 3.*(z0+zf)*a**2 + 2.*a**3), 1)
phi2 = lambda a: np.expand_dims(
    1./(zf-z0)**3 * (-z0**2*(z0-3.*zf) - 6.*z0*zf*a + 3.*(z0+zf)*a**2 - 2.*a**3), 1)
phi3 = lambda a: np.expand_dims(
    1./(zf-z0)**2 * (-z0*zf**2 + zf*(2.*z0+zf)*a - (z0+2.*zf)*a**2 + a**3), 1)
phi4 = lambda a: np.expand_dims(
    1./(zf-z0)**2 * (-z0**2*zf + z0*(z0+2.*zf)*a - (2.*z0+zf)*a**2 + a**3), 1)

## DEFINE CONSTRAINED EXPRESSION *************************************
r = lambda z, xi, IC: np.dot(Hs(z), xi['xis']) \
    + phi1(z)*(IC['R0'] - np.dot(Hs0, xi['xis'])) \
    + phi2(z)*(        - np.dot(Hsf, xi['xis'])) \
    + phi3(z)*(IC['V0']/IC['c'] - np.dot(pHs0, xi['xis'])) \
    + phi4(z)*(        - np.dot(pHsf, xi['xis']))
v = egrad(r)
def __call__(self, x: jnp.ndarray):
    initial_state = jax.tree_map(
        lambda v: v.astype(x.dtype),
        self.wrapped.initial_state(batch_size=x.shape[0]))
    x = jnp.expand_dims(x, axis=0)
    return self.unroller(self.wrapped, x, initial_state)
def __call__( self, pose_coeffs, betas=np.zeros(1), ): batch_size = pose_coeffs.shape[0] if self.use_pca or self.joint_rot_mode == "axisang": # Get axis angle from PCA components and coefficients # Remove global rot coeffs hand_pose_coeffs = pose_coeffs[:, self.rot:self.rot + self.ncomps] if self.use_pca: full_hand_pose = hand_pose_coeffs @ self.selected_comps else: full_hand_pose = hand_pose_coeffs # Concatenate back global rot full_pose = np.concatenate( (pose_coeffs[:, :self.rot], self.hands_mean + full_hand_pose), 1) if self.root_rot_mode == "axisang": # compute rotation matrixes from axis-angle while skipping global rotation pose_map, rot_map = self._posemap_axisang(full_pose) root_rot = rot_map[:, :9].reshape(batch_size, 3, 3) rot_map = rot_map[:, 9:] pose_map = pose_map[:, 9:] else: # th_posemap offsets by 3, so add offset or 3 to get to self.rot=6 pose_map, rot_map = self._posemap_axisang(full_pose[:, 6:]) if self.robust_rot: root_rot = self._robust_compute_rotation_matrix_from_ortho6d( full_pose[:, :6]) else: root_rot = self._compute_rotation_matrix_from_ortho6d( full_pose[:, :6]) elif self.joint_rot_mode == "rotmat": full_pose = pose_coeffs # ! Dummy Assignment pose_rots = self._batch_rotprojs(pose_coeffs) rot_map = pose_rots[:, 1:].reshape((batch_size, -1)) pose_map = self._subtract_flat_id(rot_map) root_rot = pose_rots[:, 0] elif self.joint_rot_mode == "quat": # we need th_rot_map, th_pose_map, root_rot # though do no assertion # th_pose_coeffs should be [B, 4 + 15 * 4] = [B, 64] full_pose = pose_coeffs # ! Dummy Assignment batch_size = pose_coeffs.shape[0] pose_coeffs = pose_coeffs.reshape( (batch_size, 16, 4)) # [B. 16, 4] all_rots = quaternion_to_rotation_matrix( pose_coeffs) # [B, 16, 3, 3] # flatten things out root_rot = all_rots[:, 0, :, :] # [B, 3, 3] rot_map = all_rots[:, 1:, :].reshape( (batch_size, -1)) # [B, 15 * 9] pose_map = self._subtract_flat_id(rot_map) else: raise KeyError( "joint_rot_mode not found. shoule be one of 'axisang' or 'rotmat' or 'quat'. got {}" .format(self.joint_rot_mode)) # Full axis angle representation with root joint if betas is None or betas.size == 1: v_shaped = np.matmul(self.shapedirs, self.betas.transpose( 1, 0)).transpose((2, 0, 1)) + self.v_template j = np.matmul(self.J_regressor, v_shaped).tile((batch_size, 1, 1)) else: v_shaped = np.matmul(self.shapedirs, betas.transpose( (1, 0))).transpose((2, 0, 1)) + self.v_template j = np.matmul(self.J_regressor, v_shaped) # th_pose_map should have shape 20x135 v_posed = v_shaped + np.matmul( self.posedirs, pose_map.transpose((1, 0))[np.newaxis, ...]).transpose((2, 0, 1)) # Final T pose with transformation done ! 
# Global rigid transformation root_j = j[:, 0, :].reshape(batch_size, 3, 1) root_trans = self._with_zeros(np.concatenate((root_rot, root_j), 2)) all_rots = rot_map.reshape(rot_map.shape[0], 15, 3, 3) lev1_idxs = [1, 4, 7, 10, 13] lev2_idxs = [2, 5, 8, 11, 14] lev3_idxs = [3, 6, 9, 12, 15] lev1_rots = all_rots[:, [idx - 1 for idx in lev1_idxs]] lev2_rots = all_rots[:, [idx - 1 for idx in lev2_idxs]] lev3_rots = all_rots[:, [idx - 1 for idx in lev3_idxs]] lev1_j = j[:, lev1_idxs] lev2_j = j[:, lev2_idxs] lev3_j = j[:, lev3_idxs] # From base to tips # Get lev1 results all_transforms = [root_trans[:, np.newaxis, ...]] lev1_j_rel = lev1_j - root_j.transpose((0, 2, 1)) lev1_rel_transform_flt = self._with_zeros( np.concatenate((lev1_rots, lev1_j_rel[..., np.newaxis]), 3).reshape(-1, 3, 4)) root_trans_flt = np.tile(root_trans[:, np.newaxis, ...], (1, 5, 1, 1)).reshape(root_trans.shape[0] * 5, 4, 4) lev1_flt = np.matmul(root_trans_flt, lev1_rel_transform_flt) all_transforms.append(lev1_flt.reshape(all_rots.shape[0], 5, 4, 4)) # Get lev2 results lev2_j_rel = lev2_j - lev1_j lev2_rel_transform_flt = self._with_zeros( np.concatenate((lev2_rots, lev2_j_rel[..., np.newaxis]), 3).reshape(-1, 3, 4)) lev2_flt = np.matmul(lev1_flt, lev2_rel_transform_flt) all_transforms.append(lev2_flt.reshape(all_rots.shape[0], 5, 4, 4)) # Get lev3 results lev3_j_rel = lev3_j - lev2_j lev3_rel_transform_flt = self._with_zeros( np.concatenate((lev3_rots, lev3_j_rel[..., np.newaxis]), 3).reshape(-1, 3, 4)) lev3_flt = np.matmul(lev2_flt, lev3_rel_transform_flt) all_transforms.append(lev3_flt.reshape(all_rots.shape[0], 5, 4, 4)) reorder_idxs = [0, 1, 6, 11, 2, 7, 12, 3, 8, 13, 4, 9, 14, 5, 10, 15] results = np.concatenate(all_transforms, 1)[:, reorder_idxs] results_global = results joint_js = np.concatenate((j, np.zeros((j.shape[0], 16, 1))), 2) tmp2 = np.matmul(results, joint_js[..., np.newaxis]) results2 = (results - np.concatenate([np.zeros( (*tmp2.shape[:2], 4, 3)), tmp2], 3)).transpose( (0, 2, 3, 1)) T = np.matmul(results2, self.weights.transpose((1, 0))) rest_shape_h = np.concatenate( (v_posed.transpose( (0, 2, 1)), np.ones((batch_size, 1, v_posed.shape[1]))), 1, ) verts = (T * rest_shape_h[:, np.newaxis, ...]).sum(2).transpose( (0, 2, 1)) verts = verts[:, :, :3] jtr = results_global[:, :, :3, 3] # In addition to MANO reference joints we sample vertices on each finger # to serve as finger tips if self.side == "right": tips = verts[:, [745, 317, 444, 556, 673]] else: tips = verts[:, [745, 317, 445, 556, 673]] jtr = np.concatenate((jtr, tips), 1) # Reorder joints to match visualization utilities jtr = jtr[:, [ 0, 13, 14, 15, 16, 1, 2, 3, 17, 4, 5, 6, 18, 10, 11, 12, 19, 7, 8, 9, 20 ], ] # deal with center joint if self.center_idx is not None: center_joint = jtr[:, self.center_idx][:, np.newaxis, ...] else: # ! 
Dummy Center Joint (B, 1, 3) center_joint = np.zeros_like(np.expand_dims(jtr[:, 0], 1)) jtr = jtr - center_joint verts = verts - center_joint global_rot = results_global[:, :, :3, :3] # (B, 16, 3, 3) global_t = results_global[:, :, :3, 3:] # (B, 16, 3, 1) global_t = global_t - np.expand_dims(center_joint, -1) # (B, [16], 3, 1) transf_global = np.concatenate([global_rot, global_t], axis=3) # (B, 16, 3, 4) transf_global = self._with_zeros(transf_global.reshape((-1, 3, 4))) transf_global = transf_global.reshape((batch_size, 16, 4, 4)) # Scale to milimeters # th_verts = th_verts * 1000 # th_jtr = th_jtr * 1000 results = [verts, jtr] # (V, J) if self.return_transf: results = results + [transf_global] # (V, J, T) if self.return_full_pose: results = results + [full_pose] # (V, J, T, so3) elif self.return_full_pose: results = results + [full_pose] # (V, J, so3) return tuple(results)
def nonbonded_v3( conf, params, box, lamb, charge_rescale_mask, lj_rescale_mask, beta, cutoff, lambda_plane_idxs, lambda_offset_idxs, runtime_validate=True, ): """Lennard-Jones + Coulomb, with a few important twists: * distances are computed in 4D, controlled by lambda, lambda_plane_idxs, lambda_offset_idxs * each pairwise LJ and Coulomb term can be multiplied by an adjustable rescale_mask parameter * Coulomb terms are multiplied by erfc(beta * distance) Parameters ---------- conf : (N, 3) or (N, 4) np.array 3D or 4D coordinates if 3D, will be converted to 4D using (x,y,z) -> (x,y,z,w) where w = cutoff * (lambda_plane_idxs + lambda_offset_idxs * lamb) params : (N, 3) np.array columns [charges, sigmas, epsilons], one row per particle box : Optional 3x3 np.array lamb : float charge_rescale_mask : (N, N) np.array the Coulomb contribution of pair (i,j) will be multiplied by charge_rescale_mask[i,j] lj_rescale_mask : (N, N) np.array the Lennard-Jones contribution of pair (i,j) will be multiplied by lj_rescale_mask[i,j] beta : float the charge product q_ij will be multiplied by erfc(beta*d_ij) cutoff : Optional float a pair of particles (i,j) will be considered non-interacting if the distance d_ij between their 4D coordinates exceeds cutoff lambda_plane_idxs : Optional (N,) np.array lambda_offset_idxs : Optional (N,) np.array runtime_validate: bool check whether beta is compatible with cutoff (if True, this function will currently not play nice with Jax JIT) TODO: is there a way to conditionally print a runtime warning inside of a Jax JIT-compiled function, without triggering a Jax ConcretizationTypeError? Returns ------- energy : float References ---------- * Rodinger, Howell, Pomès, 2005, J. Chem. Phys. "Absolute free energy calculations by thermodynamic integration in four spatial dimensions" https://aip.scitation.org/doi/abs/10.1063/1.1946750 * Darden, York, Pedersen, 1993, J. Chem. Phys. 
"Particle mesh Ewald: An N log(N) method for Ewald sums in large systems" https://aip.scitation.org/doi/abs/10.1063/1.470117 * Coulomb interactions are treated using the direct-space contribution from eq 2 """ if runtime_validate: assert (charge_rescale_mask == charge_rescale_mask.T).all() assert (lj_rescale_mask == lj_rescale_mask.T).all() N = conf.shape[0] if conf.shape[-1] == 3: conf = convert_to_4d(conf, lamb, lambda_plane_idxs, lambda_offset_idxs, cutoff) # make 4th dimension of box large enough so its roughly aperiodic if box is not None: if box.shape[-1] == 3: box_4d = np.eye(4) * 1000 box_4d = index_update(box_4d, index[:3, :3], box) else: box_4d = box else: box_4d = None box = box_4d charges = params[:, 0] sig = params[:, 1] eps = params[:, 2] sig_i = np.expand_dims(sig, 0) sig_j = np.expand_dims(sig, 1) sig_ij = sig_i + sig_j eps_i = np.expand_dims(eps, 0) eps_j = np.expand_dims(eps, 1) eps_ij = eps_i * eps_j dij = distance(conf, box) keep_mask = np.ones((N, N)) - np.eye(N) keep_mask = np.where(eps_ij != 0, keep_mask, 0) if cutoff is not None: if runtime_validate: validate_coulomb_cutoff(cutoff, beta, threshold=1e-2) eps_ij = np.where(dij < cutoff, eps_ij, 0) # (ytz): this avoids a nan in the gradient in both jax and tensorflow sig_ij = np.where(keep_mask, sig_ij, 0) eps_ij = np.where(keep_mask, eps_ij, 0) inv_dij = 1 / dij inv_dij = np.where(np.eye(N), 0, inv_dij) sig2 = sig_ij * inv_dij sig2 *= sig2 sig6 = sig2 * sig2 * sig2 eij_lj = 4 * eps_ij * (sig6 - 1.0) * sig6 eij_lj = np.where(keep_mask, eij_lj, 0) qi = np.expand_dims(charges, 0) # (1, N) qj = np.expand_dims(charges, 1) # (N, 1) qij = np.multiply(qi, qj) # (ytz): trick used to avoid nans in the diagonal due to the 1/dij term. keep_mask = 1 - np.eye(N) qij = np.where(keep_mask, qij, 0) dij = np.where(keep_mask, dij, 0) # funny enough lim_{x->0} erfc(x)/x = 0 eij_charge = np.where(keep_mask, qij * erfc(beta * dij) * inv_dij, 0) # zero out diagonals if cutoff is not None: eij_charge = np.where(dij > cutoff, 0, eij_charge) eij_total = eij_lj * lj_rescale_mask + eij_charge * charge_rescale_mask return np.sum(eij_total / 2)
def __call__(
    self,
    hidden_states,
    attention_mask=None,
    deterministic: bool = True,
    output_attentions: bool = False,
):
    query = self.q_proj(hidden_states)
    key = self.k_proj(hidden_states)
    value = self.v_proj(hidden_states)

    query = self._split_heads(query)
    key = self._split_heads(key)
    value = self._split_heads(value)

    causal_attention_mask = None
    if self.causal:
        query_length, key_length = query.shape[1], key.shape[1]
        causal_attention_mask = self.causal_mask[
            :, :, key_length - query_length:key_length, :key_length]

    if attention_mask is not None and causal_attention_mask is not None:
        attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))
        attention_mask = combine_masks(attention_mask, causal_attention_mask, dtype="i4")
    elif causal_attention_mask is not None:
        attention_mask = causal_attention_mask
    elif attention_mask is not None:
        attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))

    if attention_mask is not None:
        attention_bias = lax.select(
            attention_mask > 0,
            jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
            jnp.full(attention_mask.shape, -1e4).astype(self.dtype),
        )
    else:
        attention_bias = None

    dropout_rng = None
    if not deterministic and self.dropout > 0.0:
        dropout_rng = self.make_rng("dropout")

    attn_weights = dot_product_attention_weights(
        query,
        key,
        bias=attention_bias,
        dropout_rng=dropout_rng,
        dropout_rate=self.dropout,
        deterministic=deterministic,
        dtype=self.dtype,
        precision=None,
    )

    attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value)
    attn_output = self._merge_heads(attn_output)
    attn_output = self.out_proj(attn_output)

    outputs = (attn_output, attn_weights) if output_attentions else (attn_output,)
    return outputs
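# Stand-alone illustration of the mask-to-bias conversion used above: a 0/1
# attention mask is broadcast to [batch, 1, 1, kv_len] and turned into an
# additive bias that is 0 where attention is allowed and a large negative
# value where it is masked out (toy numbers, outside any module).
import jax.numpy as jnp
from jax import lax

mask = jnp.array([[1, 1, 0, 0]])                 # [batch, kv_len]
mask = jnp.expand_dims(mask, axis=(-3, -2))      # [batch, 1, 1, kv_len]
bias = lax.select(mask > 0,
                  jnp.full(mask.shape, 0.0),
                  jnp.full(mask.shape, -1e4))
# Adding `bias` to attention scores before the softmax suppresses masked positions.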
H0 = H(tfc.z[0:1])
Hf = H(tfc.z[-1:])
Hp0 = pH(tfc.z[0:1])
Hpf = pH(tfc.z[-1:])

## DEFINE THE ASSUMED SOLUTION: *****************************************************************************
z = tfc.z
z0 = z[0]
zf = z[-1]

R0 = lambda xi: np.array([xi['X'], xi['Y'], xi['Z']]).flatten()
V0 = lambda xi: np.array([xi['dX'], xi['dY'], xi['dZ']]).flatten()/xi['b']**2

phi1 = lambda a: \
    np.expand_dims(1./(zf-z0)**3 * (-zf**2*(3.*z0-zf) + 6.*z0*zf*a - 3.*(z0+zf)*a**2 + 2.*a**3), 1)
phi2 = lambda a: \
    np.expand_dims(1./(zf-z0)**3 * (-z0**2*(z0-3.*zf) - 6.*z0*zf*a + 3.*(z0+zf)*a**2 - 2.*a**3), 1)
phi3 = lambda a: \
    np.expand_dims(1./(zf-z0)**2 * (-z0*zf**2 + zf*(2.*z0+zf)*a - (z0 + 2.*zf)*a**2 + a**3), 1)
phi4 = lambda a: \
    np.expand_dims(1./(zf-z0)**2 * (-z0**2*zf + z0*(z0+2.*zf)*a - (2.*z0 + zf)*a**2 + a**3), 1)

## CONSTRUCT THE CONSTRAINED EXPRESSION *********************************************************************
r = lambda z, xi: np.dot(H(z), xi['xis']) + phi1(z)*(R0(xi) - np.dot(H0, xi['xis'])) \
                                          + phi2(z)*(R0(xi) - np.dot(Hf, xi['xis'])) \
                                          + phi3(z)*(V0(xi) - np.dot(Hp0, xi['xis'])) \
                                          + phi4(z)*(V0(xi) - np.dot(Hpf, xi['xis']))

r1 = lambda z, xi: np.sqrt((r(z, xi)[:, 0] + mu)**2 + r(z, xi)[:, 1]**2 + r(z, xi)[:, 2]**2)       # m1 to (x,y,z)
r2 = lambda z, xi: np.sqrt((r(z, xi)[:, 0] + mu - 1.)**2 + r(z, xi)[:, 1]**2 + r(z, xi)[:, 2]**2)  # m2 to (x,y,z)
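# Quick numerical sanity check (illustrative, not from the source) that the
# switching functions above satisfy their boundary conditions:
# phi1(z0) = 1, phi1(zf) = 0, and phi1'(z0) = phi1'(zf) = 0, which is what lets
# the constrained expression r(z, xi) meet its position/velocity constraints
# exactly for any xi.
z0_, zf_ = 0.0, 2.0
phi1_ = lambda a: 1./(zf_-z0_)**3 * (-zf_**2*(3.*z0_-zf_) + 6.*z0_*zf_*a
                                     - 3.*(z0_+zf_)*a**2 + 2.*a**3)
dphi1_ = lambda a: 1./(zf_-z0_)**3 * (6.*z0_*zf_ - 6.*(z0_+zf_)*a + 6.*a**2)
assert abs(phi1_(z0_) - 1.0) < 1e-12 and abs(phi1_(zf_)) < 1e-12
assert abs(dphi1_(z0_)) < 1e-12 and abs(dphi1_(zf_)) < 1e-12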
def __call__( self, hidden_states: jnp.ndarray, key_value_states: Optional[jnp.ndarray] = None, attention_mask: Optional[jnp.ndarray] = None, init_cache: bool = False, deterministic: bool = True, ) -> Tuple[jnp.ndarray]: """Input shape: Batch x Time x Channel""" # if key_value_states are provided this layer is used as a cross-attention layer # for the decoder is_cross_attention = key_value_states is not None batch_size = hidden_states.shape[0] # get query proj query_states = self.q_proj(hidden_states) # get key, value proj if is_cross_attention: # cross_attentions key_states = self.k_proj(key_value_states) value_states = self.v_proj(key_value_states) else: # self_attention key_states = self.k_proj(hidden_states) value_states = self.v_proj(hidden_states) query_states = self._split_heads(query_states) key_states = self._split_heads(key_states) value_states = self._split_heads(value_states) # handle cache prepare causal attention mask if self.causal: query_length, key_length = query_states.shape[1], key_states.shape[ 1] if self.has_variable("cache", "cached_key"): mask_shift = self.variables["cache"]["cache_index"] max_decoder_length = self.variables["cache"][ "cached_key"].shape[1] causal_mask = lax.dynamic_slice( self.causal_mask, (0, 0, mask_shift, 0), (1, 1, query_length, max_decoder_length)) else: causal_mask = self.causal_mask[:, :, :query_length, : key_length] causal_mask = jnp.broadcast_to(causal_mask, (batch_size, ) + causal_mask.shape[1:]) # combine masks if needed if attention_mask is not None and self.causal: attention_mask = jnp.broadcast_to( jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_mask.shape) attention_mask = combine_masks(attention_mask, causal_mask) elif self.causal: attention_mask = causal_mask elif attention_mask is not None: attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2)) # During fast autoregressive decoding, we feed one position at a time, # and cache the keys and values step by step. if self.causal and (self.has_variable("cache", "cached_key") or init_cache): key_states, value_states, attention_mask = self._concatenate_to_cache( key_states, value_states, query_states, attention_mask) # Convert the boolean attention mask to an attention bias. if attention_mask is not None: # attention mask in the form of attention bias attention_bias = lax.select( attention_mask > 0, jnp.full(attention_mask.shape, 0.0).astype(self.dtype), jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype), ) else: attention_bias = None dropout_rng = None if not deterministic and self.dropout > 0.0: dropout_rng = self.make_rng("dropout") attn_weights = dot_product_attention_weights( query_states, key_states, bias=attention_bias, dropout_rng=dropout_rng, dropout_rate=self.dropout, broadcast_dropout=True, deterministic=deterministic, dtype=self.dtype, precision=None, ) attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states) attn_output = self._merge_heads(attn_output) attn_output = self.out_proj(attn_output) return attn_output, attn_weights
def scan_enum(f, init, xs, length, reverse, rng_key=None, substitute_stack=None): from numpyro.contrib.funsor import enum, config_enumerate, markov, trace as packed_trace # XXX: This implementation only works for history size=1 but can be # extended to history size > 1 by running `f` `history_size` times # for initialization. However, `sequential_sum_product` does not # support history size > 1, so we skip supporting it here. # Note that `funsor.sum_product.sarkka_bilmes_product` does support history > 1. if reverse: x0 = tree_map(lambda x: x[-1], xs) xs_ = tree_map(lambda x: x[:-1], xs) else: x0 = tree_map(lambda x: x[0], xs) xs_ = tree_map(lambda x: x[1:], xs) carry_shape_at_t1 = None def body_fn(wrapped_carry, x, prefix=None): i, rng_key, carry = wrapped_carry init = True if (not_jax_tracer(i) and i == 0) else False rng_key, subkey = random.split(rng_key) if rng_key is not None else ( None, None) seeded_fn = handlers.seed(f, subkey) if subkey is not None else f for subs_type, subs_map in substitute_stack: subs_fn = partial(_subs_wrapper, subs_map, i, length) if subs_type == 'condition': seeded_fn = handlers.condition(seeded_fn, condition_fn=subs_fn) elif subs_type == 'substitute': seeded_fn = handlers.substitute(seeded_fn, substitute_fn=subs_fn) if init: with handlers.scope(prefix="_init"): new_carry, y = seeded_fn(carry, x) trace = {} else: with handlers.block(), packed_trace() as trace, promote_shapes( ), enum(), markov(): # Like scan_wrapper, we collect the trace of scan's transition function # `seeded_fn` here. To put time dimension to the correct position, we need to # promote shapes to make `fn` and `value` # at each site have the same batch dims (e.g. if `fn.batch_shape = (2, 3)`, # and value's batch_shape is (3,), then we promote shape of # value so that its batch shape is (1, 3)). new_carry, y = config_enumerate(seeded_fn)(carry, x) # store shape of new_carry at a global variable nonlocal carry_shape_at_t1 carry_shape_at_t1 = [ jnp.shape(x) for x in tree_flatten(new_carry)[0] ] # make new_carry have the same shape as carry # FIXME: is this rigorous? new_carry = tree_multimap( lambda a, b: jnp.reshape(a, jnp.shape(b)), new_carry, carry) return (i + jnp.array(1), rng_key, new_carry), (PytreeTrace(trace), y) with markov(): wrapped_carry = (0, rng_key, init) wrapped_carry, (_, y0) = body_fn(wrapped_carry, x0) if length == 1: ys = tree_map(lambda x: jnp.expand_dims(x, 0), y0) return wrapped_carry, (PytreeTrace({}), ys) wrapped_carry, (pytree_trace, ys) = lax.scan(body_fn, wrapped_carry, xs_, length - 1, reverse) first_var = None for name, site in pytree_trace.trace.items(): # add `time` dimension, the name will be '_time_{first variable in the trace}' if first_var is None: first_var = name leftmost_dim = min(site['infer']['dim_to_name']) site['infer']['dim_to_name'][leftmost_dim - 1] = '_time_{}'.format(first_var) # similar to carry, we need to reshape due to shape alternating in markov ys = tree_multimap( lambda z0, z: jnp.reshape(z, z.shape[:1] + jnp.shape(z0)), y0, ys) # we also need to reshape `carry` to match sequential behavior if length % 2 == 0: t, rng_key, carry = wrapped_carry flatten_carry, treedef = tree_flatten(carry) flatten_carry = [ jnp.reshape(x, t1_shape) for x, t1_shape in zip(flatten_carry, carry_shape_at_t1) ] carry = tree_unflatten(treedef, flatten_carry) wrapped_carry = (t, rng_key, carry) return wrapped_carry, (pytree_trace, ys)
def step(self, s, a):
    """Apply control, damping, boundary, and collision forces.

    Args:
      s: (p, v, misc), where p and v are [n_entities,2] jnp.float32, and misc is child defined
      a: [n_agents, dim_a] jnp.float32

    Returns:
      A state tuple (p, v, misc)
    """
    p, v, misc = s  # [n,2], [n,2], [a_shape]
    f = jnp.zeros_like(p)  # [n,2]
    n = p.shape[0]  # number of entities

    # Calculate control forces
    f_control = jnp.pad(a, ((0, n - a.shape[0]), (0, 0)), mode="constant")  # [n, dim_a]
    f += f_control

    # Calculate damping forces
    f_damping = -1.0 * self.damping * v  # [n,2]
    f = f + f_damping

    # Calculate boundary forces
    bounce = (((p + self.radius >= self.max_p) & (v >= 0.0)) |
              ((p - self.radius <= self.min_p) & (v <= 0.0)))  # [n,2]
    v_new = (-1.0 * bounce + 1.0 * ~bounce) * v  # [n,2]
    f_boundary = self.mass * (v_new - v) / self.dt  # [n,2]
    f = f + f_boundary

    # Calculate shared quantities for later calculations
    # same: [n,n,1], True if i==j
    same = jnp.expand_dims(jnp.eye(n, dtype=jnp.bool_), axis=-1)
    # p2p: [n,n,2], p2p[i,j,:] is the vector from entity i to entity j
    p2p = p - jnp.expand_dims(p, axis=1)
    # dist: [n,n,1], dist[i,j,0] is the distance between i and j
    dist = jnp.linalg.norm(p2p, axis=-1, keepdims=True)
    # overlap: [n,n,1], overlap[i,j,0] is the overlap between i and j
    overlap = ((jnp.expand_dims(self.radius, axis=1) +
                jnp.expand_dims(self.radius, axis=0)) - dist)

    if self.same_position_check:
        # ontop: [n,n,1], ontop[i,j,0] = True if i is at the exact location of j
        ontop = (dist == 0.0)
        # ontop_dir: [n,n,1], (1,0) above diagonal, (-1,0) below diagonal
        ontop_dir = jnp.stack(
            [jnp.triu(jnp.ones((n, n))) * 2 - 1, jnp.zeros((n, n))], axis=-1)
        # contact_dir: [n,n,2], contact_dir[i,j,:] is the unit vector in the
        # direction of j from i
        contact_dir = ((~ontop * p2p + (ontop * ontop_dir)) /
                       (~ontop * dist + ontop * 1.0))
    else:
        # contact_dir: [n,n,2], contact_dir[i,j,:] is the unit vector in the
        # direction of j from i
        contact_dir = p2p / (dist + same)

    # collideable: [n,n,1], True if i and j are collideable
    collideable = (jnp.expand_dims(self.collideable, axis=1) &
                   jnp.expand_dims(self.collideable, axis=0))
    # overlapping: [n,n,1], True if i,j overlap
    overlapping = overlap > 0

    # Calculate collision forces
    # Assume all entities collide with all entities, then mask out
    # non-collisions.
    #
    # For approaching, colliding entities, apply forces
    # along the direction of collision that result in
    # relative velocities consistent with the coefficient of
    # restitution (c) and preservation of momentum in that
    # direction.
    #   momentum: m_a*v_a + m_b*v_b = m_a*v'_a + m_b*v'_b
    #   restitution: v'_b - v'_a = -c*(v_b-v_a)
    #   solve for v'_a:
    #     v'_a = [m_a*v_a + m_b*v_b + m_b*c*(v_b-v_a)]/(m_a + m_b)
    #
    # v_contact_dir: [n,n] speed of i in dir of j
    v_contact_dir = jnp.sum(jnp.expand_dims(v, axis=-2) * contact_dir, axis=-1)
    # v_approach: [n,n] speed that i,j are approaching each other
    v_approach = jnp.transpose(v_contact_dir) + v_contact_dir
    # momentum: [n,n] joint momentum in direction of contact (i->j)
    momentum = self.mass * v_contact_dir - jnp.transpose(self.mass * v_contact_dir)
    # v_result: [n,n] speed of i in dir of j after collision
    v_result = ((momentum + self.restitution * jnp.transpose(self.mass) *
                 (-v_approach)) / (self.mass + jnp.transpose(self.mass)))
    # f_collision: [n,n] force on i in dir of j to realize acceleration
    f_collision = self.mass * (v_result - v_contact_dir) / self.dt
    # f_collision: [n,n,2] force on i to realize acceleration due to
    # collision with j
    f_collision = jnp.expand_dims(f_collision, axis=-1) * contact_dir
    # collision_mask: [n,n,1]
    collision_mask = (collideable & overlapping & ~same &
                      (jnp.expand_dims(v_approach, axis=-1) > 0))
    # f_collision: [n,2], sum of collision forces on i
    f_collision = jnp.sum(f_collision * collision_mask, axis=-2)
    f = f + f_collision

    # Calculate overlapping spring forces
    # This corrects for any overlap due to discrete steps.
    # f_overlap: [n,n,2], force in the negative contact dir due to overlap
    f_overlap = -1.0 * contact_dir * overlap * self.overlap_spring_constant
    # overlapping_mask: [n,n,1], True if i,j are collideable, overlap,
    # and i != j
    overlapping_mask = collideable & overlapping & ~same
    # f_overlap: [n,2], sum of spring forces on i
    f_overlap = jnp.sum(f_overlap * overlapping_mask, axis=-2)
    f = f + f_overlap

    # apply forces
    v = v + (f / self.mass) * self.dt
    p = p + v * self.dt

    # update misc
    misc = self._update_misc((p, v, misc), a)  # pylint: disable=assignment-from-none

    return (p, v, misc)
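# Quick numerical check (not part of the class above) of the collision relation
# quoted in the comments: v'_a = [m_a*v_a + m_b*v_b + m_b*c*(v_b - v_a)] / (m_a + m_b).
# With equal masses and c = 1 the velocities are simply exchanged.
m_a, m_b, c = 1.0, 1.0, 1.0
v_a, v_b = 2.0, -1.0
v_a_new = (m_a * v_a + m_b * v_b + m_b * c * (v_b - v_a)) / (m_a + m_b)
v_b_new = (m_a * v_a + m_b * v_b + m_a * c * (v_a - v_b)) / (m_a + m_b)
assert abs(v_a_new - (-1.0)) < 1e-9 and abs(v_b_new - 2.0) < 1e-9
# Momentum is conserved (m_a*v'_a + m_b*v'_b == m_a*v_a + m_b*v_b) and
# v'_b - v'_a == -c * (v_b - v_a), as stated in the derivation.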
def loss_fn(
    model_config: ml_collections.FrozenConfigDict,
    model_params: Dict[Text, Any],
    model_vars: Dict[Text, Any],
    batch: Dict[Text, Any],
    deterministic: bool,
    dropout_rng: Optional[Dict[Text, Array]] = None,
) -> Tuple[float, MetricGroups, Dict[str, Any]]:
    """Loss function used by Ultra Fine Entity Typing task. See BaseTask."""
    variable_dict = {'params': model_params}
    variable_dict.update(model_vars)
    loss_helpers, _ = cls.build_model(model_config).apply(
        variable_dict, batch, deterministic=deterministic, rngs=dropout_rng)
    classifier_logits = loss_helpers['classifier_logits'].astype(jnp.float32)

    log_prob = jax.nn.log_sigmoid(classifier_logits)
    # log(1 - sigmoid(x)) = log_sigmoid(-x)
    # We use the latter since it is more numerically stable and denote it
    # as `log_comp_prob` (log of probability of the complementary event).
    log_comp_prob = jax.nn.log_sigmoid(-classifier_logits)

    # batch['classifier_target'] has shape [batch_size, max_labels_per_sample]
    # and contains all labels in a sparse format. The code below converts
    # this to a dense format.
    classifier_labels = jax.nn.one_hot(
        batch['classifier_target'], NUM_CLASSES, dtype=jnp.float32)
    classifier_labels *= jnp.expand_dims(batch['classifier_target_mask'], -1)
    # Labels in a dense format with a shape [batch_size, NUM_CLASSES]
    classifier_labels = classifier_labels.sum(axis=1)
    loss_per_label = -log_prob * classifier_labels - log_comp_prob * (
        1.0 - classifier_labels)

    coarse_grained_weight = get_weight_per_group(
        classifier_labels, COARSE_CLASSES_START, COARSE_CLASSES_END)
    fine_grained_weight = get_weight_per_group(
        classifier_labels, FINE_CLASSES_START, FINE_CLASSES_END)
    ultra_fine_grained_weight = get_weight_per_group(
        classifier_labels, ULTRA_FINE_CLASSES_START, ULTRA_FINE_CLASSES_END)

    coarse_grained_loss = get_loss_per_group(
        loss_per_label, coarse_grained_weight, COARSE_CLASSES_START,
        COARSE_CLASSES_END)
    fine_grained_loss = get_loss_per_group(
        loss_per_label, fine_grained_weight, FINE_CLASSES_START,
        FINE_CLASSES_END)
    ultra_fine_grained_loss = get_loss_per_group(
        loss_per_label, ultra_fine_grained_weight, ULTRA_FINE_CLASSES_START,
        ULTRA_FINE_CLASSES_END)

    loss_per_sample = (
        coarse_grained_loss + fine_grained_loss + ultra_fine_grained_loss)
    loss = loss_per_sample.sum()

    metrics = {
        'agg': {
            'loss': loss,
            'denominator': loss_per_sample.shape[0],
        },
        'coarse_grained': {
            'loss': coarse_grained_loss.sum(),
            'denominator': coarse_grained_weight.sum(),
        },
        'fine_grained': {
            'loss': fine_grained_loss.sum(),
            'denominator': fine_grained_weight.sum(),
        },
        'ultra_fine_grained': {
            'loss': ultra_fine_grained_loss.sum(),
            'denominator': ultra_fine_grained_weight.sum(),
        },
    }
    metrics.update(get_eval_metrics(classifier_labels, classifier_logits))
    return loss, metrics, {}
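# Self-contained illustration (toy sizes, hypothetical values) of the
# sparse-to-dense label conversion performed above: per-sample lists of label
# ids plus a validity mask become a dense multi-hot matrix.
import jax
import jax.numpy as jnp

num_classes = 5
target = jnp.array([[1, 3, 0], [2, 0, 0]])                   # [batch, max_labels]
target_mask = jnp.array([[1.0, 1.0, 0.0], [1.0, 0.0, 0.0]])  # 1 = real label, 0 = padding
dense = jax.nn.one_hot(target, num_classes, dtype=jnp.float32)
dense = dense * jnp.expand_dims(target_mask, -1)
dense = dense.sum(axis=1)                                    # [batch, num_classes]
# dense == [[0, 1, 0, 1, 0], [0, 0, 1, 0, 0]]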
def beam_search_body_fn(state, input_ids_length=1): """beam search state update fn.""" # 1. Forward current tokens # Collect the current position slice along length to feed the fast # autoregressive decoder model. Flatten the beam dimension into batch # dimension for feeding into the model. # unflatten beam dimension # Unflatten beam dimension in attention cache arrays input_token = flatten_beam_dim( lax.dynamic_slice( state.running_sequences, (0, 0, state.cur_len - input_ids_length), (batch_size, num_beams, input_ids_length), ) ) model_outputs = model(input_token, params=params, **state.model_kwargs) logits = unflatten_beam_dim(model_outputs.logits[:, -1], batch_size, num_beams) cache = jax.tree_map( lambda tensor: unflatten_beam_dim(tensor, batch_size, num_beams), model_outputs.past_key_values ) # adapt logits for FlaxMarianMTModel logits = self._adapt_logits_for_beam_search(logits) # 2. Compute log probs # get log probabilities from logits, # process logits with processors (*e.g.* min_length, ...), and # add new logprobs to existing running logprobs scores. log_probs = jax.nn.log_softmax(logits) log_probs = logits_processor( flatten_beam_dim(running_sequences), flatten_beam_dim(log_probs), state.cur_len ) log_probs = unflatten_beam_dim(log_probs, batch_size, num_beams) log_probs = log_probs + jnp.expand_dims(state.running_scores, axis=2) vocab_size = log_probs.shape[2] log_probs = log_probs.reshape((batch_size, num_beams * vocab_size)) # 3. Retrieve top-K # Each item in batch has num_beams * vocab_size candidate sequences. # For each item, get the top 2*k candidates with the highest log- # probabilities. We gather the top 2*K beams here so that even if the best # K sequences reach EOS simultaneously, we have another K sequences # remaining to continue the live beam search. # Gather the top 2*K scores from _all_ beams. # Gather 2*k top beams. # Recover the beam index by floor division. # Recover token id by modulo division and expand Id array for broadcasting. # Update sequences for the 2*K top-k new sequences. beams_to_keep = 2 * num_beams topk_log_probs, topk_indices = lax.top_k(log_probs, k=beams_to_keep) topk_beam_indices = topk_indices // vocab_size topk_running_sequences = gather_beams( state.running_sequences, topk_beam_indices, batch_size, beams_to_keep ) topk_ids = jnp.expand_dims(topk_indices % vocab_size, axis=2) topk_sequences = lax.dynamic_update_slice(topk_running_sequences, topk_ids, (0, 0, state.cur_len)) # 4. Check which sequences have ended # Update current sequences: # Did any of these sequences reach an end marker? # To prevent these just finished sequences from being added to the current sequences # set of active beam search sequences, set their log probs to a very large # negative value. did_topk_just_finished = topk_sequences[:, :, state.cur_len] == eos_token_id running_topk_log_probs = topk_log_probs + did_topk_just_finished * np.array(-1.0e7) # 5. Get running sequences scores for next # Determine the top k beam indices (from top 2*k beams) from log probs # and gather top k beams (from top 2*k beams). next_topk_indices = jnp.flip(lax.top_k(running_topk_log_probs, k=num_beams)[1], axis=1) next_running_sequences, next_running_scores = gather_beams( [topk_sequences, running_topk_log_probs], next_topk_indices, batch_size, num_beams ) # 6. 
Process topk logits # Further process log probs: # - add length penalty # - make sure no scores can be added anymore if beam is full # - make sure still running sequences cannot be chosen as finalized beam topk_log_probs = topk_log_probs / (state.cur_len**length_penalty) beams_in_batch_are_full = ( jnp.broadcast_to(state.is_sent_finished.all(axis=-1, keepdims=True), did_topk_just_finished.shape) & early_stopping ) add_penalty = ~did_topk_just_finished | beams_in_batch_are_full topk_log_probs += add_penalty * np.array(-1.0e7) # 7. Get scores, sequences, is sentence finished for next. # Combine sequences, scores, and flags along the beam dimension and compare # new finished sequence scores to existing finished scores and select the # best from the new set of beams merged_sequences = jnp.concatenate([state.sequences, topk_sequences], axis=1) merged_scores = jnp.concatenate([state.scores, topk_log_probs], axis=1) merged_is_sent_finished = jnp.concatenate([state.is_sent_finished, did_topk_just_finished], axis=1) topk_merged_indices = jnp.flip(lax.top_k(merged_scores, k=num_beams)[1], axis=1) next_sequences, next_scores, next_is_sent_finished = gather_beams( [merged_sequences, merged_scores, merged_is_sent_finished], topk_merged_indices, batch_size, num_beams ) # 8. Update model kwargs. # Determine the top k beam indices from the original set of all beams. # With these, gather the top k beam-associated caches. next_running_indices = gather_beams(topk_beam_indices, next_topk_indices, batch_size, num_beams) next_cache = gather_beams(cache, next_running_indices, batch_size, num_beams) model_outputs["past_key_values"] = jax.tree_map(lambda x: flatten_beam_dim(x), next_cache) next_model_kwargs = self.update_inputs_for_generation(model_outputs, state.model_kwargs) return BeamSearchState( cur_len=state.cur_len + 1, running_scores=next_running_scores, running_sequences=next_running_sequences, scores=next_scores, sequences=next_sequences, is_sent_finished=next_is_sent_finished, model_kwargs=next_model_kwargs, )
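# Minimal illustration (toy numbers) of the top-2K bookkeeping above: after
# flattening the num_beams * vocab_size candidates per batch item, the beam
# index is recovered with floor division and the token id with modulo.
import jax.numpy as jnp
from jax import lax

num_beams, vocab_size = 2, 5
log_probs = jnp.arange(num_beams * vocab_size, dtype=jnp.float32).reshape(1, -1)
topk_log_probs, topk_indices = lax.top_k(log_probs, k=2 * num_beams)
topk_beam_indices = topk_indices // vocab_size  # which beam each candidate came from
topk_token_ids = topk_indices % vocab_size      # which token extends that beam
# topk_indices == [[9, 8, 7, 6]] -> beams [[1, 1, 1, 1]], tokens [[4, 3, 2, 1]]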
def _zeros_like(input, dtype=None, name=None):  # pylint: disable=redefined-builtin
    s = _shape(input)
    if isinstance(s, (np.ndarray, onp.generic)):
        return np.zeros(s, utils.numpy_dtype(dtype or input.dtype))
    return tf.zeros(s, dtype or s.dtype, name)


# --- Begin Public Functions --------------------------------------------------

concat = utils.copy_docstring(
    tf.concat,
    lambda values, axis, name='concat': (  # pylint: disable=g-long-lambda
        np.concatenate([ops.convert_to_tensor(v) for v in values], axis)))

expand_dims = utils.copy_docstring(
    tf.expand_dims,
    lambda input, axis, name=None: np.expand_dims(input, axis))

fill = utils.copy_docstring(
    tf.fill,
    lambda dims, value, name=None: value * np.ones(dims, np.array(value).dtype))

gather = utils.copy_docstring(tf.gather, _gather)

gather_nd = utils.copy_docstring(tf.gather_nd, _gather_nd)

reverse = utils.copy_docstring(tf.reverse, _reverse)

linspace = utils.copy_docstring(
    tf.linspace,
    lambda start, stop, num, name=None: (  # pylint: disable=g-long-lambda
def apply(self, x, batch_stats=None, use_running_average=False, axis=-1, momentum=0.99, epsilon=1e-5, dtype=jnp.float32, bias=True, scale=True, bias_init=initializers.zeros, scale_init=initializers.ones, axis_name=None, axis_index_groups=None, virtual_batch_size=None, data_format=None): """Normalizes the input using batch statistics. Forked from the original flax nn.BatchNorm layer, this allows users to have multiple EMAs per device, one for each virtual batch size. For example, if the per-device batch size is 128 and the user specifies `virtual_batch_size=32`, 4 EMAs will be created on each device, each updated with 1/4 of the per-device batch on each forward pass. WARNING: the multiple per-device EMAs this creates need to be manually synchronized within each device before being used for evaluation, or when synchronizing batch norm statistic across devices. Args: x: the input to be normalized. batch_stats: a `flax.nn.Collection` used to store an exponential moving average of the batch statistics (default: None). use_running_average: if true, the statistics stored in batch_stats will be used instead of computing the batch statistics on the input. axis: the feature or non-batch axis of the input. momentum: decay rate for the exponential moving average of the batch statistics. epsilon: a small float added to variance to avoid dividing by zero. dtype: the dtype of the computation (default: float32). bias: if True, bias (beta) is added. scale: if True, multiply by scale (gamma). When the next layer is linear (also e.g. nn.relu), this can be disabled since the scaling will be done by the next layer. bias_init: initializer for bias, by default, zero. scale_init: initializer for scale, by default, one. axis_name: the axis name used to combine batch statistics from multiple devices. See `jax.pmap` for a description of axis names (default: None). axis_index_groups: groups of axis indices within that named axis representing subsets of devices to reduce over (default: None). For example, `[[0, 1], [2, 3]]` would independently batch-normalize over the examples on the first two and last two devices. See `jax.lax.psum` for more details. virtual_batch_size: the size of the virtual batches to construct on each device, which will be used to normalize sub-batches of each per-device batch. Will create a running average with a leading dim of size `x.shape[batch_axis] // virtual_batch_size`, one for each sub-batch. Note that the first dim of each state must be synchronized whenever synchronizing batch norm running averages. Must evenly divide the per-device batch size (as determined by `x`), and cannot be combined with `axis_index_groups`. Passing the default value of None will replicate the existing nn.BatchNorm behavior without virtual batches. data_format: only used when `virtual_batch_size` is set, to determine the batch axis. Returns: Normalized inputs (the same shape as inputs). """ batch_axis = _get_batch_axis(data_format, x, virtual_batch_size, use_running_average, axis_index_groups) if virtual_batch_size is None: virtual_batch_size = x.shape[batch_axis] if use_running_average: # Virtual batch norm is not used during evaluation, and we cannot # guarantee the train and eval batch sizes are the same, so we use a # single virtual batch of size batch_size, and take the first element in # the running average array, assuming they have been properly synced # across their first dim. 
virtual_batch_size = x.shape[batch_axis] x = jnp.asarray(x, jnp.float32) num_sub_batches = x.shape[batch_axis] // virtual_batch_size input_shape = x.shape axis = axis if isinstance(axis, tuple) else (axis, ) axis = _absolute_dims(x.ndim, axis) feature_shape = tuple(d if i in axis else 1 for i, d in enumerate(x.shape)) reduced_feature_shape = tuple(d for i, d in enumerate(x.shape) if i in axis) # Add an additional axis because we are going to reshape `x` to have a # leading dim of size `virtual_batch_size`. reduction_axis = tuple(i + 1 for i in range(x.ndim) if i not in axis) sub_batched_shape = (num_sub_batches, *x.shape[:batch_axis], virtual_batch_size, *x.shape[batch_axis + 1:]) x = jnp.reshape(x, sub_batched_shape) if self.is_stateful() or batch_stats: ra_means = self.state('batch_norm_running_mean', (num_sub_batches, *reduced_feature_shape), initializers.zeros, collection=batch_stats) ra_vars = self.state('batch_norm_running_var', (num_sub_batches, *reduced_feature_shape), initializers.ones, collection=batch_stats) else: ra_means = None ra_vars = None if use_running_average: if ra_means is None: raise ValueError( 'when use_running_averages is True ' 'either use a stateful context or provide batch_stats') # Note that we assume that the values across the first axis have been # properly synchronized. mean = jnp.expand_dims(ra_means.value[0], 0) var = jnp.expand_dims(ra_vars.value[0], 0) else: mean = jnp.mean(x, axis=reduction_axis, keepdims=False) mean2 = jnp.mean(lax.square(x), axis=reduction_axis, keepdims=False) if axis_name is not None and not self.is_initializing(): concatenated_mean = jnp.concatenate([mean, mean2]) mean, mean2 = jnp.split( lax.pmean(concatenated_mean, axis_name=axis_name, axis_index_groups=axis_index_groups), 2) var = mean2 - lax.square(mean) if ra_means and not self.is_initializing(): ra_means.value = momentum * ra_means.value + (1 - momentum) * mean ra_vars.value = momentum * ra_vars.value + (1 - momentum) * var y = x - mean.reshape((num_sub_batches, *feature_shape)) mul = lax.rsqrt( var.reshape((num_sub_batches, *feature_shape)) + epsilon) if scale: mul = mul * self.param('scale', reduced_feature_shape, scale_init).reshape((1, *feature_shape)) y = y * mul if bias: y = y + self.param('bias', reduced_feature_shape, bias_init).reshape((1, *feature_shape)) y = jnp.reshape(y, input_shape) return jnp.asarray(y, dtype)
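# Stand-alone sketch of the virtual-batch reshape performed above: a per-device
# batch is split into num_sub_batches groups and normalization statistics are
# computed per group (toy shapes; no EMAs, scale/bias, or cross-device syncing).
import jax.numpy as jnp

per_device_batch, features, virtual_batch_size = 8, 4, 4
x = jnp.arange(per_device_batch * features, dtype=jnp.float32).reshape(
    per_device_batch, features)
num_sub_batches = per_device_batch // virtual_batch_size
x_virtual = x.reshape(num_sub_batches, virtual_batch_size, features)
mean = x_virtual.mean(axis=1, keepdims=True)            # [num_sub_batches, 1, features]
var = ((x_virtual - mean) ** 2).mean(axis=1, keepdims=True)
y = (x_virtual - mean) / jnp.sqrt(var + 1e-5)
y = y.reshape(per_device_batch, features)               # back to the input shape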
def funx(x):
    if isinstance(x, RelaxVariable):
        return jnp.zeros_like(jnp.expand_dims(x.lower[0, ...], 0))
    else:
        return x
def __call__(
    self,
    hidden_states,
    key_value_states: Optional[jnp.ndarray] = None,
    attention_mask=None,
    deterministic: bool = True,
    init_cache: bool = False,
    output_attentions: bool = False,
):
    # if key_value_states are provided this layer is used as a cross-attention
    # layer for the decoder
    is_cross_attention = key_value_states is not None
    batch_size = hidden_states.shape[0]

    if not is_cross_attention:
        qkv_out = self.c_attn(hidden_states)
        query, key, value = jnp.split(qkv_out, 3, axis=2)
    else:
        q_out = self.q_attn(hidden_states)
        (query,) = jnp.split(q_out, 1, axis=2)
        kv_out = self.c_attn(key_value_states)
        key, value = jnp.split(kv_out, 2, axis=2)

    query = self._split_heads(query)
    key = self._split_heads(key)
    value = self._split_heads(value)

    query_length, key_length = query.shape[1], key.shape[1]

    if self.causal:
        if self.has_variable("cache", "cached_key"):
            mask_shift = self.variables["cache"]["cache_index"]
            max_decoder_length = self.variables["cache"]["cached_key"].shape[1]
            causal_mask = lax.dynamic_slice(
                self.causal_mask, (0, 0, mask_shift, 0),
                (1, 1, query_length, max_decoder_length))
        else:
            causal_mask = self.causal_mask[:, :, :query_length, :key_length]
        causal_mask = jnp.broadcast_to(causal_mask,
                                       (batch_size,) + causal_mask.shape[1:])

    # combine masks if needed
    if attention_mask is not None and self.causal:
        attention_mask = jnp.broadcast_to(
            jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_mask.shape)
        attention_mask = combine_masks(attention_mask, causal_mask)
    elif self.causal:
        attention_mask = causal_mask
    elif attention_mask is not None:
        attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))

    dropout_rng = None
    if not deterministic and self.config.attn_pdrop > 0.0:
        dropout_rng = self.make_rng("dropout")

    # During fast autoregressive decoding, we feed one position at a time,
    # and cache the keys and values step by step.
    if self.causal and (self.has_variable("cache", "cached_key") or init_cache):
        key, value, attention_mask = self._concatenate_to_cache(
            key, value, query, attention_mask)

    # transform boolean mask into float mask
    if attention_mask is not None:
        attention_bias = lax.select(
            attention_mask > 0,
            jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
            jnp.full(attention_mask.shape, -1e4).astype(self.dtype),
        )
    else:
        attention_bias = None

    # usual dot product attention
    attn_weights = dot_product_attention_weights(
        query,
        key,
        bias=attention_bias,
        dropout_rng=dropout_rng,
        dropout_rate=self.config.attn_pdrop,
        deterministic=deterministic,
        dtype=self.dtype,
        precision=None,
    )

    attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value)
    attn_output = self._merge_heads(attn_output)
    attn_output = self.c_proj(attn_output)
    attn_output = self.resid_dropout(attn_output, deterministic=deterministic)

    outputs = (attn_output, attn_weights) if output_attentions else (attn_output,)
    return outputs
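The boolean-mask-to-bias conversion used above is a small, self-contained trick: positions that may be attended to receive a bias of 0, masked positions receive a large negative value that vanishes under softmax. A minimal sketch with made-up shapes follows; the helper name mask_to_bias is illustrative and not part of the module.

# Sketch only: turning a 0/1 attention mask into an additive bias.
import jax
import jax.numpy as jnp
from jax import lax

def mask_to_bias(attention_mask, dtype=jnp.float32, neg=-1e4):
    # attention_mask: 1 = attend, 0 = do not attend.
    return lax.select(
        attention_mask > 0,
        jnp.full(attention_mask.shape, 0.0).astype(dtype),
        jnp.full(attention_mask.shape, neg).astype(dtype),
    )

mask = jnp.array([[1, 1, 0, 0]])           # (batch, key_length)
bias = mask_to_bias(mask)                  # [[0, 0, -1e4, -1e4]]
scores = jnp.zeros((1, 4)) + bias
print(jax.nn.softmax(scores, axis=-1))     # masked keys get ~0 probability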
def apply(self, inputs, info, config, train=False, cache=None):
  # Inputs
  output_token_vocabulary_size = info.output_vocab_size
  true_indexes = inputs['true_branch_nodes']
  false_indexes = inputs['false_branch_nodes']
  start_indexes = inputs['start_index']  # pylint: disable=unused-variable
  exit_indexes = inputs['exit_index']
  steps_all = inputs['steps']
  vocab_size = info.features[info._builder.key('statements')].vocab_size  # pylint: disable=protected-access
  hidden_size = config.model.hidden_size
  data = inputs['data'].astype('int32')
  batch_size, num_nodes, unused_statement_length = data.shape
  # An upper bound on the number of steps to take.
  max_steps = int(1.5 * info.max_diameter)

  # Init parameters
  def emb_init(key, shape, dtype=jnp.float32):
    return jax.random.uniform(key, shape, dtype,
                              -config.initialization.maxval,
                              config.initialization.maxval)

  embed = Embed.shared(num_embeddings=vocab_size,
                       features=hidden_size,
                       emb_init=emb_init,
                       name='embed')
  branch_decide_dense = nn.Dense.shared(
      name='branch_decide_dense',
      features=2,
      kernel_init=nn.initializers.xavier_uniform(),
      bias_init=nn.initializers.normal(stddev=1e-6))
  cells = create_lstm_cells(config.model.rnn_cell.layers)
  lstm = StackedRNNCell.shared(cells=cells)
  if config.model.interpolant.apply_dense:
    dense_parent_to_true_child = nn.Dense.shared(
        name='dense_parent_to_true_child',
        features=hidden_size,
        kernel_init=nn.initializers.xavier_uniform(),
        bias_init=nn.initializers.normal(stddev=1e-6))
    dense_parent_to_false_child = nn.Dense.shared(
        name='dense_parent_to_false_child',
        features=hidden_size,
        kernel_init=nn.initializers.xavier_uniform(),
        bias_init=nn.initializers.normal(stddev=1e-6))
    dense_true_child_to_parent = nn.Dense.shared(
        name='dense_true_child_to_parent',
        features=hidden_size,
        kernel_init=nn.initializers.xavier_uniform(),
        bias_init=nn.initializers.normal(stddev=1e-6))
    dense_false_child_to_parent = nn.Dense.shared(
        name='dense_false_child_to_parent',
        features=hidden_size,
        kernel_init=nn.initializers.xavier_uniform(),
        bias_init=nn.initializers.normal(stddev=1e-6))
  if config.model.interpolant.apply_gru:
    gru_cell = nn.recurrent.GRUCell.shared(name='gru_cell')
  output_dense = nn.Dense.shared(
      name='output_dense',
      features=output_token_vocabulary_size,
      kernel_init=nn.initializers.xavier_uniform(),
      bias_init=nn.initializers.normal(stddev=1e-6))

  # Apply
  def execute_single_node(hidden_state, node_embedding):
    carry, _ = lax.scan(lstm, hidden_state, node_embedding)
    return carry
  execute = jax.vmap(execute_single_node)

  # Single example.
  def branch_decide_single_node(hidden_state):
    # leaves(hidden_state).shape: hidden_size
    hidden_state_concat = jnp.concatenate(
        jax.tree_leaves(hidden_state), axis=0)
    return branch_decide_dense(hidden_state_concat)
  branch_decide = jax.vmap(branch_decide_single_node)

  def update_instruction_pointer(instruction_pointer, branch_decisions,
                                 true_indexes, false_indexes):
    # instruction_pointer.shape: num_nodes,
    # branch_decisions.shape: num_nodes, 2
    # true_indexes.shape: num_nodes,
    # false_indexes.shape: num_nodes,
    p_true = branch_decisions[:, 0]
    p_false = branch_decisions[:, 1]
    if not config.model.interpolant.use_b:
      p_true = jnp.ones_like(p_true)
      p_false = jnp.ones_like(p_false)
    if not config.model.interpolant.use_p:
      instruction_pointer = jnp.ones_like(instruction_pointer)
    true_contributions = jax.ops.segment_sum(
        p_true * instruction_pointer, true_indexes, num_segments=num_nodes)
    false_contributions = jax.ops.segment_sum(
        p_false * instruction_pointer, false_indexes, num_segments=num_nodes)
    return true_contributions + false_contributions

  def aggregate(hidden_states, instruction_pointer, branch_decisions,
                true_indexes, false_indexes):
    # leaves(hidden_states).shape: num_nodes, hidden_size
    # instruction_pointer.shape: num_nodes,
    # branch_decisions.shape: num_nodes, 2
    # true_indexes.shape: num_nodes,
    # false_indexes.shape: num_nodes,
    p_true = branch_decisions[:, 0]
    p_false = branch_decisions[:, 1]
    if not config.model.interpolant.use_b:
      p_true = jnp.ones_like(p_true)
      p_false = jnp.ones_like(p_false)
    if not config.model.interpolant.use_p:
      instruction_pointer = jnp.ones_like(instruction_pointer)
    denominators = update_instruction_pointer(instruction_pointer,
                                              branch_decisions,
                                              true_indexes, false_indexes)
    denominators += 1e-7
    # denominators.shape: num_nodes,
    if not config.model.interpolant.normalize:
      denominators = jnp.ones_like(denominators)

    def aggregate_component(h):
      # h.shape: num_nodes
      # p_true.shape: num_nodes
      # instruction_pointer.shape: num_nodes
      true_contributions = jax.ops.segment_sum(
          h * p_true * instruction_pointer, true_indexes,
          num_segments=num_nodes)
      false_contributions = jax.ops.segment_sum(
          h * p_false * instruction_pointer, false_indexes,
          num_segments=num_nodes)
      # *_contributions.shape: num_nodes, hidden_size
      return (true_contributions + false_contributions) / denominators
    aggregate_component = jax.vmap(aggregate_component, in_axes=1, out_axes=1)

    return jax.tree_map(aggregate_component, hidden_states)

  def step_single_example(hidden_states, instruction_pointer, node_embeddings,
                          true_indexes, false_indexes, exit_index):
    # Execution (e.g. apply RNN)
    # leaves(hidden_states).shape: num_nodes, hidden_size
    # instruction_pointer.shape: num_nodes,
    # node_embeddings.shape: num_nodes, statement_length, hidden_size
    if config.model.interpolant.apply_code_rnn:
      hidden_state_contributions = execute(hidden_states, node_embeddings)
      # leaves(hidden_state_contributions).shape: num_nodes, hidden_size
    else:
      hidden_state_contributions = hidden_states

    if config.model.interpolant.apply_dense:
      parent_to_true_child = jax.tree_map(dense_parent_to_true_child,
                                          hidden_state_contributions)
      parent_to_false_child = jax.tree_map(dense_parent_to_false_child,
                                           hidden_state_contributions)
      true_child_to_parent = jax.tree_map(dense_true_child_to_parent,
                                          hidden_state_contributions)
      false_child_to_parent = jax.tree_map(dense_false_child_to_parent,
                                           hidden_state_contributions)
    else:
      parent_to_true_child = hidden_state_contributions
      parent_to_false_child = hidden_state_contributions
      true_child_to_parent = hidden_state_contributions
      false_child_to_parent = hidden_state_contributions

    # Use the exit node's hidden state as its hidden state contribution
    # to avoid "executing" the exit node.
    def mask_h(h_contribution, h):
      return h_contribution.at[exit_index, :].set(h[exit_index, :])
    hidden_state_contributions = jax.tree_multimap(
        mask_h, hidden_state_contributions, hidden_states)

    # Branch decisions (e.g. Dense layer)
    branch_decision_logits = branch_decide(hidden_state_contributions)
    branch_decisions = nn.softmax(branch_decision_logits, axis=-1)

    # Update state
    if config.model.interpolant.use_ipa:
      instruction_pointer_new = update_instruction_pointer(
          instruction_pointer, branch_decisions, true_indexes, false_indexes)
      hidden_states_new = aggregate(hidden_state_contributions,
                                    instruction_pointer, branch_decisions,
                                    true_indexes, false_indexes)
    else:
      assert config.model.interpolant.use_parent_embeddings
      assert config.model.interpolant.use_child_embeddings
      instruction_pointer_new = instruction_pointer
      normalization = jnp.sqrt(
          2 + (  # Each node has a true and false child.
              # jnp.bincount(true_indexes, minlength=num_nodes)
              jax.ops.segment_sum(jnp.ones_like(true_indexes), true_indexes,
                                  num_segments=num_nodes)
              # + jnp.bincount(false_indexes, minlength=num_nodes)
              + jax.ops.segment_sum(jnp.ones_like(false_indexes),
                                    false_indexes, num_segments=num_nodes)))
      # normalization.shape: num_nodes,

      def aggregate_parent_and_child_contributions(p1, p2, c3, c4):
        return (jax.ops.segment_sum(p1, true_indexes, num_segments=num_nodes)
                + jax.ops.segment_sum(p2, false_indexes,
                                      num_segments=num_nodes)
                + c3[true_indexes]
                + c4[false_indexes]) / normalization[:, None]
      hidden_states_new = jax.tree_multimap(
          aggregate_parent_and_child_contributions,
          parent_to_true_child,
          parent_to_false_child,
          true_child_to_parent,  # true_child_to_parent[child] -> parent
          false_child_to_parent)

    if config.model.interpolant.apply_gru:
      def apply_gru(h2, h1):
        output, _ = gru_cell(h2, h1)
        return output
      hidden_states_new = jax.tree_multimap(apply_gru, hidden_states_new,
                                            hidden_states)

    to_tag = {
        'branch_decisions': branch_decisions,
        'hidden_state_contributions': hidden_state_contributions,
        'hidden_states_before': hidden_states,
        'hidden_states': hidden_states_new,
        'instruction_pointer_before': instruction_pointer,
        'instruction_pointer': instruction_pointer_new,
        'true_indexes': true_indexes,
        'false_indexes': false_indexes,
    }
    return hidden_states_new, instruction_pointer_new, to_tag

  def compute_logits_single_example(hidden_states, instruction_pointer,
                                    exit_index, steps, node_embeddings,
                                    true_indexes, false_indexes):
    """single_example refers to selecting a single exit node hidden state."""
    # leaves(hidden_states).shape: num_nodes, hidden_size

    def step_(carry, _):
      hidden_states, instruction_pointer, index = carry
      hidden_states_new, instruction_pointer_new, to_tag = (
          step_single_example(hidden_states, instruction_pointer,
                              node_embeddings, true_indexes, false_indexes,
                              exit_index))
      carry = jax.tree_multimap(
          lambda new, old, index=index: jnp.where(index < steps, new, old),
          (hidden_states_new, instruction_pointer_new, index + 1),
          (hidden_states, instruction_pointer, index + 1),
      )
      return carry, to_tag

    if config.model.ipagnn.checkpoint and not self.is_initializing():
      step_ = jax.checkpoint(step_)

    carry = (hidden_states, instruction_pointer, jnp.array([0]))
    (hidden_states, instruction_pointer, _), to_tag = lax.scan(
        step_, carry, None, length=max_steps)
    final_state = jax.tree_map(lambda hs: hs[exit_index], hidden_states)
    # leaves(final_state).shape: hidden_size
    final_state_concat = jnp.concatenate(jax.tree_leaves(final_state), axis=0)
    logits = output_dense(final_state_concat)
    to_tag.update({
        'instruction_pointer_final': instruction_pointer,
        'hidden_states_final': hidden_states,
    })
    return logits, to_tag

  compute_logits = jax.vmap(compute_logits_single_example,
                            in_axes=(0, 0, 0, 0, 0, 0, 0))

  # Init state
  node_embeddings = embed(data)
  # node_embeddings.shape:
  #     batch_size, num_nodes, statement_length, hidden_size
  hidden_states = StackedRNNCell.initialize_carry(
      jax.random.PRNGKey(0), cells, (batch_size, num_nodes), hidden_size)
  if config.model.interpolant.init_with_code_embeddings:
    hidden_states = jax.vmap(execute)(hidden_states, node_embeddings)
  # leaves(hidden_states).shape: batch_size, num_nodes, hidden_size
  instruction_pointer = jax.ops.index_add(
      jnp.zeros((batch_size, num_nodes)),
      jax.ops.index[:, 0],  # TODO(dbieber): Use "start_index" instead of 0.
      1)
  # instruction_pointer.shape: batch_size, num_nodes,

  logits, to_tag = compute_logits(hidden_states, instruction_pointer,
                                  exit_indexes, steps_all, node_embeddings,
                                  true_indexes, false_indexes)
  for key, value in to_tag.items():
    value = Tag(value, name=key)
  logits = jnp.expand_dims(logits, axis=1)
  return logits
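The core of the instruction-pointer update above is a soft redistribution of probability mass: each node sends p_true of its mass to its true successor and p_false to its false successor, and segment_sum re-gathers the incoming mass per node. A standalone sketch on a toy four-node graph with hand-picked numbers (the graph and values are illustrative, not from the model):

# Sketch only: one soft instruction-pointer step via segment_sum.
import jax.numpy as jnp
from jax import ops

num_nodes = 4
# Node i's true/false successors; the exit node 3 points to itself.
true_indexes = jnp.array([1, 3, 3, 3])
false_indexes = jnp.array([2, 3, 3, 3])

instruction_pointer = jnp.array([1.0, 0.0, 0.0, 0.0])  # all mass at node 0
p_true = jnp.array([0.7, 1.0, 1.0, 1.0])                # branch decisions
p_false = 1.0 - p_true

new_pointer = (
    ops.segment_sum(p_true * instruction_pointer, true_indexes,
                    num_segments=num_nodes)
    + ops.segment_sum(p_false * instruction_pointer, false_indexes,
                      num_segments=num_nodes))
print(new_pointer)  # [0.0, 0.7, 0.3, 0.0]: mass is conserved across the step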
def _cov_helper_without_p(data):
  return jnp.expand_dims(
      jnp.matmul(jnp.conj(jnp.transpose(data)), data), axis=0)
def apply(rs: JaxArray, ids: JaxArray, **kwargs):
    ξ = np.expand_dims(f(rs, ids, **kwargs).flatten(), 0)
    Jξ = np.expand_dims(Jf(rs, ids, **kwargs).flatten(), 0)
    return ξ, Jξ
def __render_comps(model, has_bulge, has_bar, n_spirals, shape,
                   oversample_n, base_roll):
    """Render the components of a galaxy builder model.

    Arguments:
    model -- The model to render. This should be a dictionary like
        {(component, param): value} (i.e. what you would get from
        pd.Series(...).to_dict())
    has_bulge -- Whether the model contains a bulge, needed for Jax
        compilation
    has_bar -- Whether the model contains a bar
    n_spirals -- The number of spiral arms present in the model
    shape -- The desired output shape to render
    oversample_n -- The factor to which Sersic oversampling will be done
    base_roll -- The roll parameter of the original model. This is needed to
        preserve the location of spiral arms as best as possible
    """
    P, P_super = _make_xy_arrays(shape, oversample_n)

    out = {}
    disk_I = sersic_I(
        model[('disk', 'L')], model[('disk', 'Re')], model[('disk', 'q')],
        1, 2
    )
    disk_super = sersic(
        *P_super,
        mux=model[('disk', 'mux')], muy=model[('disk', 'muy')],
        roll=model[('disk', 'roll')], q=model[('disk', 'q')],
        Re=model[('disk', 'Re')], I=disk_I,
        n=1.0, c=2.0,
    )
    out['disk'] = jnp.squeeze(downsample(disk_super, oversample_n))

    # next add spirals to the disk
    if n_spirals > 0:
        spirals = get_spirals(model, n_spirals, base_roll)
        spiral_distances = jnp.stack([
            vmap_polyline_distance(s, *P)
            for s in spirals
        ], axis=-1)
        Is = jnp.array([
            model[('spiral', 'I.{}'.format(i))]
            for i in range(n_spirals)
        ])
        spreads = jnp.array([
            model[('spiral', 'spread.{}'.format(i))]
            for i in range(n_spirals)
        ])
        spirals = jnp.sum(
            Is * jnp.exp(-spiral_distances**2 / (2 * spreads**2))
            * jnp.expand_dims(out['disk'], -1),
            axis=-1
        )
        out['spiral'] = spirals
    else:
        spirals = jnp.zeros(shape)

    # calculate the luminosity of the disk and spirals together (the bulge and
    # bar fractions are calculated relative to this)
    disk_spiral_L = model[('disk', 'L')] + spirals.sum()

    # if we have a bulge, render it
    if has_bulge:
        # bulge_frac assumes we don't have a bar
        bulge_L = (
            model[('bulge', 'frac')] * disk_spiral_L
            / (1 - model[('bulge', 'frac')])
        )
        bulge_Re = model[('disk', 'Re')] * model[('bulge', 'scale')]
        bulge_I = sersic_I(
            bulge_L, bulge_Re, model[('bulge', 'q')], model[('bulge', 'n')]
        )
        bulge_super = sersic(
            *P_super,
            mux=model[('centre', 'mux')], muy=model[('centre', 'muy')],
            roll=model[('bulge', 'roll')], q=model[('bulge', 'q')],
            Re=bulge_Re, I=bulge_I,
            n=model[('bulge', 'n')], c=2.0
        )
        out['bulge'] = jnp.squeeze(downsample(bulge_super, oversample_n))

    # if we have a bar, render it
    if has_bar:
        # bar_frac assumes we don't have a bulge
        bar_L = (
            model[('bar', 'frac')] * disk_spiral_L
            / (1 - model[('bar', 'frac')])
        )
        bar_Re = model[('disk', 'Re')] * model[('bar', 'scale')]
        bar_I = sersic_I(
            bar_L, bar_Re, model[('bar', 'q')], model[('bar', 'n')]
        )
        bar_super = sersic(
            *P_super,
            mux=model[('centre', 'mux')], muy=model[('centre', 'muy')],
            roll=model[('bar', 'roll')], q=model[('bar', 'q')],
            Re=bar_Re, I=bar_I,
            n=model[('bar', 'n')], c=model[('bar', 'c')],
        )
        out['bar'] = jnp.squeeze(downsample(bar_super, oversample_n))

    # return the dictionary of rendered components
    return out
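The bulge and bar luminosities above are recovered from fractions defined relative to the combined disk-plus-spiral light, i.e. frac = L_comp / (L_comp + L_disk+spiral), which inverts to L_comp = frac * L_disk+spiral / (1 - frac). A tiny worked check with made-up numbers (not taken from any real model):

# Sketch only: fraction-to-luminosity inversion used for the bulge/bar.
disk_spiral_L = 10.0
bulge_frac = 0.2   # bulge light as a fraction of (bulge + disk + spiral) light
bulge_L = bulge_frac * disk_spiral_L / (1 - bulge_frac)   # = 2.5
assert abs(bulge_L / (bulge_L + disk_spiral_L) - bulge_frac) < 1e-12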
def call(
    self,
    query: jnp.ndarray,
    key: tp.Optional[jnp.ndarray] = None,
    value: tp.Optional[jnp.ndarray] = None,
    mask=None,
    training=None,
):
    """
    Arguments:
        query: np.ndarray of shape `(..., query_elements, query_depth)`.
        key: np.ndarray of shape `(..., key_elements, key_depth)`, optional;
            if not given, `query` will be used.
        value: np.ndarray of shape `(..., key_elements, value_depth)`,
            optional; if not given, `key` will be used.
        mask: a binary np.ndarray of shape
            `[batch_size?, num_heads?, query_elements, key_elements]`
            which specifies which query elements can attend to which key
            elements; `1` indicates attention and `0` indicates no attention.

    Output shape:
        * `(..., query_elements, output_size)` if `output_size` is given, else
        * `(..., query_elements, value_depth)` if `value` is given, else
        * `(..., query_elements, key_depth)`
    """
    # einsum nomenclature
    # ------------------------
    # N = query elements
    # M = key/value elements
    # H = heads
    # I = input features
    # O = output features

    if key is None:
        key = query
    if value is None:
        value = key

    output_size = (
        self.output_size if self.output_size is not None else value.shape[-1]
    )

    # verify shapes
    if key.shape[-2] != value.shape[-2]:
        raise ValueError(
            "the number of elements in 'key' must be equal to "
            "the number of elements in 'value'"
        )

    if mask is not None:
        if len(mask.shape) < 2:
            raise ValueError("'mask' must have at least 2 dimensions")
        if query.shape[-2] != mask.shape[-2]:
            raise ValueError(
                "mask's second to last dimension must be equal to "
                "the number of elements in 'query'"
            )
        if key.shape[-2] != mask.shape[-1]:
            raise ValueError(
                "mask's last dimension must be equal to the number of "
                "elements in 'key'"
            )

    # get weights
    query_kernel = hooks.get_parameter(
        "query_kernel",
        [self.num_heads, query.shape[-1], self.head_size],
        jnp.float32,
        initializer=self.kernel_initializer,
    )
    key_kernel = hooks.get_parameter(
        "key_kernel",
        [self.num_heads, key.shape[-1], self.head_size],
        jnp.float32,
        initializer=self.kernel_initializer,
    )
    value_kernel = hooks.get_parameter(
        "value_kernel",
        [self.num_heads, value.shape[-1], self.head_size],
        jnp.float32,
        initializer=self.kernel_initializer,
    )
    projection_kernel = hooks.get_parameter(
        "projection_kernel",
        [self.num_heads, self.head_size, output_size],
        jnp.float32,
        initializer=self.kernel_initializer,
    )

    # Linear transformations
    query = jnp.einsum("...NI , HIO -> ...NHO", query, query_kernel)
    key = jnp.einsum("...MI , HIO -> ...MHO", key, key_kernel)
    value = jnp.einsum("...MI , HIO -> ...MHO", value, value_kernel)

    # Scale the dot product: applying the division to `query` (or `key`)
    # instead of to their product saves some computation.
    query /= jnp.sqrt(self.head_size)

    # Calculate dot product attention
    logits = jnp.einsum("...NHO,...MHO->...HNM", query, key)

    # apply mask
    if mask is not None:
        mask = mask.astype(jnp.float32)

        # possibly expand on the head dimension so broadcasting works
        if len(mask.shape) != len(logits.shape):
            mask = jnp.expand_dims(mask, -3)

        logits += -10e9 * (1.0 - mask)

    attn_coef = jax.nn.softmax(logits)

    # attention dropout
    attn_coef_dropout = Dropout(self.dropout_rate)(attn_coef, training=training)

    # attention * value
    multihead_output = jnp.einsum("...HNM,...MHI->...NHI",
                                  attn_coef_dropout, value)

    # Run the outputs through another linear projection layer. Recombining
    # heads is automatically done.
    output = jnp.einsum("...NHI,HIO->...NO", multihead_output,
                        projection_kernel)

    if self.use_projection_bias:
        output += hooks.get_parameter(
            "projection_bias",
            [output_size],
            jnp.float32,
            initializer=self.bias_initializer,
        )

    if self.return_attn_coef:
        return output, attn_coef
    else:
        return output
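The einsum bookkeeping above (N/M/H/I/O) can be exercised end to end with random arrays standing in for the learned parameters. The sketch below is illustrative only: it reuses one kernel for query, key, and value to stay short, and the sizes are made up.

# Sketch only: shape walk-through of the multi-head einsum attention.
import jax
import jax.numpy as jnp

batch, n_query, n_key = 2, 5, 7
num_heads, in_features, head_size, output_size = 4, 16, 8, 16
rngs = jax.random.split(jax.random.PRNGKey(0), 5)
query = jax.random.normal(rngs[0], (batch, n_query, in_features))
key = jax.random.normal(rngs[1], (batch, n_key, in_features))
value = jax.random.normal(rngs[2], (batch, n_key, in_features))
qkv_kernel = jax.random.normal(rngs[3], (num_heads, in_features, head_size))
proj_kernel = jax.random.normal(rngs[4], (num_heads, head_size, output_size))

q = jnp.einsum("...NI,HIO->...NHO", query, qkv_kernel) / jnp.sqrt(head_size)
k = jnp.einsum("...MI,HIO->...MHO", key, qkv_kernel)
v = jnp.einsum("...MI,HIO->...MHO", value, qkv_kernel)
logits = jnp.einsum("...NHO,...MHO->...HNM", q, k)         # (batch, heads, N, M)
attn = jax.nn.softmax(logits)
heads = jnp.einsum("...HNM,...MHI->...NHI", attn, v)       # (batch, N, heads, head_size)
out = jnp.einsum("...NHI,HIO->...NO", heads, proj_kernel)  # (batch, N, output_size)
print(logits.shape, heads.shape, out.shape)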
def wrapper(*args, **kwargs):
    expand = lambda t: jnp.expand_dims(t, axis=axis)
    args = jax.tree_map(expand, args)
    kwargs = jax.tree_map(expand, kwargs)
    outputs = f(*args, **kwargs)
    return jax.tree_map(lambda t: jnp.squeeze(t, axis=axis), outputs)
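The wrapper above implements a common pattern: temporarily add a leading axis so a function written for batched inputs can be called on unbatched pytrees, then squeeze the axis back out. A minimal self-contained sketch of that pattern; the names add_singleton_batch and batched_mean are illustrative and not from the source.

# Sketch only: expand-on-the-way-in, squeeze-on-the-way-out.
import jax
import jax.numpy as jnp

def add_singleton_batch(f, axis=0):
    # Wrap `f`, which expects batched inputs, so it accepts unbatched pytrees.
    def wrapper(*args, **kwargs):
        expand = lambda t: jnp.expand_dims(t, axis=axis)
        args = jax.tree_map(expand, args)
        kwargs = jax.tree_map(expand, kwargs)
        outputs = f(*args, **kwargs)
        return jax.tree_map(lambda t: jnp.squeeze(t, axis=axis), outputs)
    return wrapper

batched_mean = lambda x: jnp.mean(x, axis=-1, keepdims=True)  # expects (batch, d)
single = jnp.arange(4.0)                                      # shape (4,)
print(add_singleton_batch(batched_mean)(single).shape)        # (1,) after the squeeze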
def test_mention_memory_layer(self, separate_memory_values):
  """Testing memory attention layer."""
  test_utils.force_multi_devices(self.n_devices)
  devices = jax.local_devices()

  model = memory_attention_layer.MemoryAttentionLayer(
      memory_key_dim=self.memory_key_dim,
      input_dim=self.input_dim,
      memory_update_type=self.memory_update_type,
      memory_update_config=self.memory_update_config,
      k_top_device=self.k_top_device,
      k_top_post_selection=self.k_top_post_selection,
      splits=self.splits,
      dtype=self.dtype)

  static_argnums = (9,) if separate_memory_values else (9, 10)
  pinit_with_output = jax.pmap(
      model.init_with_output,
      axis_name='batch',
      static_broadcasted_argnums=static_argnums)

  rng = jax.random.PRNGKey(0)
  split_rng = jax.random.split(rng, self.n_devices)
  encoded_input = jnp.ones(
      shape=(self.bsz, self.seq_len, self.input_dim), dtype=self.dtype)
  encoded_input = jax.device_put_replicated(encoded_input, devices)

  mention_batch_positions = jnp.tile(
      jnp.arange(self.bsz).reshape(-1, 1), (1, 3)).reshape(-1)
  mention_batch_positions = jax.device_put_replicated(mention_batch_positions,
                                                      devices)
  mention_start_positions = jnp.tile(jnp.asarray([0, 5, 10]), (self.bsz))
  mention_start_positions = jax.device_put_replicated(mention_start_positions,
                                                      devices)
  mention_end_positions = jnp.tile(jnp.asarray([2, 7, 12]), (self.bsz))
  mention_end_positions = jax.device_put_replicated(mention_end_positions,
                                                    devices)
  n_mentions = mention_start_positions.shape[-1]

  mention_mask = jnp.tile(jnp.asarray([1, 1, 1]), (self.bsz))
  mention_mask = jax.device_put_replicated(mention_mask, devices)

  memory_table = np.ones(
      (self.n_devices * self.table_size, self.memory_key_dim),
      dtype=self.dtype)
  # Make sure id 0 or 1 will be highest scoring
  memory_table[0] = memory_table[0] * 2.0
  memory_table[1] = memory_table[1] * -2.0
  memory_table = jnp.asarray(memory_table, dtype=self.dtype)

  memory_keys = memory_table.reshape(self.n_devices, self.rows,
                                     self.table_size // self.rows,
                                     self.memory_key_dim)
  memory_keys_sharded = jax.device_put_sharded(list(memory_keys), devices)
  if separate_memory_values:
    memory_values = memory_table.reshape(self.n_devices, self.table_size,
                                         self.memory_key_dim)
    memory_values = jax.device_put_sharded(list(memory_values), devices)
  else:
    memory_values = None

  memory_entity_ids = np.arange(self.n_devices * self.table_size).reshape(
      self.n_devices, self.table_size)
  memory_entity_ids = jax.device_put_sharded(list(memory_entity_ids), devices)

  # Use entity id as identifier here
  memory_identifiers = memory_entity_ids

  (encoded_output, loss_helpers, _), _ = pinit_with_output(
      split_rng,
      encoded_input,
      mention_batch_positions,
      mention_start_positions,
      mention_end_positions,
      mention_mask,
      memory_keys_sharded,
      memory_identifiers,
      memory_entity_ids,
      True,  # deterministic
      memory_values,
      text_identifiers=None,
  )
  attention_weights = loss_helpers['memory_attention_weights']
  entity_ids = loss_helpers['top_entity_ids']

  normed_input = encoded_input - 1.0

  # Check input was changed
  self.assertFalse(jnp.allclose(encoded_output, normed_input))

  # Check input was not changed where it should not be
  all_indices = set(
      itertools.product(np.arange(self.bsz), np.arange(self.seq_len)))
  # Note that mention positions is the same across all of the devices
  start_indices = set(
      zip(mention_batch_positions[0].tolist(),
          mention_start_positions[0].tolist()))
  non_start_indices = all_indices.difference(start_indices)
  non_start_indices_1, non_start_indices_2 = zip(*non_start_indices)
  non_start_indices_1 = jnp.asarray(non_start_indices_1)
  non_start_indices_2 = jnp.asarray(non_start_indices_2)

  non_start_outputs = encoded_output[:, non_start_indices_1,
                                     non_start_indices_2]
  non_start_inputs = normed_input[:, non_start_indices_1,
                                  non_start_indices_2]
  self.assertTrue(jnp.allclose(non_start_outputs, non_start_inputs))

  # Check shapes as expected
  self.assertSequenceEqual(
      encoded_output.shape,
      (self.n_devices, self.bsz, self.seq_len, self.input_dim))
  self.assertSequenceEqual(
      attention_weights.shape,
      (self.n_devices, n_mentions, self.k_top_post_selection))
  self.assertSequenceEqual(
      entity_ids.shape,
      (self.n_devices, n_mentions, self.k_top_post_selection))

  # Check id 0 or 1 retrieved
  self.assertTrue(
      jnp.all((entity_ids[..., 0] == 0) + (entity_ids[..., 0] == 1)))

  # Set some text identifiers to 0 and others to 1 so that some are binding
  text_identifiers = np.zeros((n_mentions), dtype=np.int32)
  text_identifiers[:n_mentions // 2] = 1
  text_identifiers = jax.device_put_replicated(text_identifiers, devices)

  # Initialize and run one forward pass of model
  (_, loss_helpers, logging_helpers), _ = pinit_with_output(
      split_rng,
      encoded_input,
      mention_batch_positions,
      mention_start_positions,
      mention_end_positions,
      mention_mask,
      memory_keys_sharded,
      memory_identifiers,
      memory_entity_ids,
      True,  # deterministic
      memory_values,  # memory_values
      text_identifiers=text_identifiers,
  )
  attention_weights_wid = loss_helpers['memory_attention_weights']
  entity_ids_wid = loss_helpers['top_entity_ids']
  n_disallowed = logging_helpers['n_disallowed'][0]

  # Check no effect on ids
  self.assertTrue(jnp.all(entity_ids == entity_ids_wid))

  # Check id 0 or 1 have 0 scores
  text_identifiers = jnp.expand_dims(text_identifiers, -1)
  score_masked = (text_identifiers == entity_ids_wid) * attention_weights_wid
  self.assertAlmostEqual(score_masked.sum(), 0.0)

  # Check number disallowed as expected
  self.assertEqual(n_disallowed, n_mentions // 2)