def expand(observed_value):
    """Replicate ``observed_value`` ``num_prediction_samples`` times.

    ``num_prediction_samples`` is read from the enclosing scope. The call is
    ``jnp.tile`` with reps ``(num_prediction_samples, 1)``:

    - a 1-D input of length L becomes ``(num_prediction_samples, L)``;
    - an ``(a, b)`` input becomes ``(num_prediction_samples * a, b)``.
    """
    reps = (num_prediction_samples, 1)
    return jnp.tile(observed_value, reps)
def apply(self,
          inputs,
          vocab_size,
          inputs_positions=None,
          inputs_segmentation=None,
          shared_embedding=None,
          use_bfloat16=False,
          emb_dim=512,
          num_heads=8,
          dtype=jnp.float32,
          num_layers=6,
          qkv_dim=512,
          mlp_dim=2048,
          max_len=512,
          train=True,
          dropout_rate=0.1,
          attention_dropout_rate=0.1,
          learn_pos_emb=False,
          classifier=False,
          classifier_pool='CLS',
          num_classes=10):
    """Applies a stack of LinearTransformerBlock encoder layers to the inputs.

    Args:
      inputs: integer token ids of shape (batch, len); id 0 is treated as
        padding (see ``src_padding_mask``).
      vocab_size: size of the vocabulary.
      inputs_positions: input subsequence positions for packed examples.
        NOTE(review): not shifted when a CLS token is prepended — confirm.
      inputs_segmentation: input segmentation info for packed examples.
      shared_embedding: a shared embedding layer to use instead of creating a
        new ``nn.Embed``.
      use_bfloat16: bool: whether to run the encoder in bfloat16.
      emb_dim: dimension of the embedding.
      num_heads: number of attention heads.
      dtype: NOTE(review): currently ignored — it is overwritten below with
        bfloat16 or float32 depending on ``use_bfloat16``.
      num_layers: number of encoder layers.
      qkv_dim: dimension of the query/key/value.
      mlp_dim: dimension of the MLP on top of the attention block (also passed
        to the classifier head).
      max_len: maximum length (incremented by one when a CLS token is added).
      train: if True, dropout is active.
      dropout_rate: dropout rate.
      attention_dropout_rate: dropout rate for attention weights.
      learn_pos_emb: if True, positional embeddings are learned (normal init,
        stddev 0.02); otherwise ``posemb_init`` is None, presumably meaning
        sinusoidal embeddings inside AddPositionEmbs.
      classifier: if True, return N-class logits from
        ``common_layers.classifier_head``.
      classifier_pool: pooling mode. 'CLS' prepends a learned, zero-initialized
        token; other values (e.g. "MEAN", "MAX") are only forwarded to the
        classifier head as ``pooling_mode``.
      num_classes: number of classification classes.

    Returns:
      The layer-normalized encoder output, or classifier logits when
      ``classifier`` is True.
    """
    assert inputs.ndim == 2  # (batch, len)

    # Padding Masks
    src_padding_mask = (inputs > 0)[..., None]

    # Input Embedding
    if shared_embedding is None:
        input_embed = nn.Embed.partial(
            num_embeddings=vocab_size,
            features=emb_dim,
            embedding_init=nn.initializers.normal(stddev=1.0))
    else:
        input_embed = shared_embedding
    x = inputs.astype('int32')
    x = input_embed(x)

    if classifier and classifier_pool == 'CLS':
        # Prepend one learned CLS vector per example.
        cls = self.param('cls', (1, 1, emb_dim), nn.initializers.zeros)
        cls = jnp.tile(cls, [x.shape[0], 1, 1])
        x = jnp.concatenate([cls, x], axis=1)
        # The sequence is now one position longer.
        max_len += 1
        # The CLS position reuses the padding-mask value of the first token.
        src_padding_mask = jnp.concatenate(
            [src_padding_mask[:, :1], src_padding_mask], axis=1)

    pe_init = nn.initializers.normal(
        stddev=0.02) if learn_pos_emb else None
    x = common_layers.AddPositionEmbs(
        x,
        inputs_positions=inputs_positions,
        posemb_init=pe_init,
        max_len=max_len,
        name='posembed_input')
    x = nn.dropout(x, rate=dropout_rate, deterministic=not train)

    # The compute dtype is decided solely by use_bfloat16 (the dtype argument
    # is overwritten here).
    if use_bfloat16:
        x = x.astype(jnp.bfloat16)
        dtype = jnp.bfloat16
    else:
        dtype = jnp.float32

    # Input Encoder
    for lyr in range(num_layers):
        x = LinearTransformerBlock(
            x,
            qkv_dim=qkv_dim,
            mlp_dim=mlp_dim,
            num_heads=num_heads,
            dtype=dtype,
            padding_mask=src_padding_mask,
            inputs_segmentation=inputs_segmentation,
            dropout_rate=dropout_rate,
            attention_dropout_rate=attention_dropout_rate,
            deterministic=not train,
            name=f'encoderblock_{lyr}')
    encoded = nn.LayerNorm(x, dtype=dtype, name='encoder_norm')

    if classifier:
        encoded = common_layers.classifier_head(
            encoded, num_classes, mlp_dim, pooling_mode=classifier_pool)
    return encoded
import pickle
import copy

# Scratch script: evaluate the jax gradient of platoonobjfn_obj_jax on
# recorded data for one platoon.

# NOTE(security): pickle.load can execute arbitrary code embedded in the
# file; only load 'dataautodiff2.pkl' if it comes from a trusted source.
with open('dataautodiff2.pkl', 'rb') as f:
    # meas is stored in jax device array format
    meas, platooninfo = pickle.load(f)

platoonobjfn_grad = grad(platoonobjfn_obj_jax)
sim_jax = copy.deepcopy(meas)
pguess_jax = jnp.array([10 * 3.3, .086 / 3.3, 1.545, 2, .175, 5.01])
args_jax = (True, 6)
curplatoon_jax = [[], 581, 611]
# Number of entries after the leading list in curplatoon_jax (len() is
# already an int, so no conversion is needed).
n = len(curplatoon_jax) - 1
leadinfo_jax, folinfo_jax, rinfo_jax = makeleadfolinfo_r3(
    curplatoon_jax, platooninfo, meas)
# Repeat the parameter guess once per entry (presumably one parameter set
# per platoon member — confirm against platoonobjfn_obj_jax).
p2 = jnp.tile(pguess_jax, n)

#testobj = platoonobjfn_obj_jax(p2,OVM,OVMadjsys,OVMadj,meas,sim_jax,platooninfo,curplatoon_jax,leadinfo_jax,folinfo_jax,rinfo_jax,*args_jax)
testgrad = platoonobjfn_grad(p2, OVM, OVMadjsys, OVMadj, meas, sim_jax,
                             platooninfo, curplatoon_jax, leadinfo_jax,
                             folinfo_jax, rinfo_jax, *args_jax)

# TODO:
# Get the gradient of platoonobjfn_obj using jax. Compare it with the
# gradient from finite differences and the gradient from the adjoint method.
# Time for running finite differences is fdertime.
# Time for running the adjoint method is dertime.
# Accuracy of the adjoint method is acc; relative accuracy of the adjoint
# method is acc2.
# For jax, compute the run time to get the gradient. Get the accuracy of jax
# and the relative accuracy of jax.
def exec_op(self, op, input_values, deterministic, training, **_):
    """Executes an op according to the normal concrete semantics.

    Args:
      op: the op to execute. Its ``type``, ``input_kwargs`` and ``op_kwargs``
        are read; ``op_kwargs`` must contain a ``"name"`` entry. The op's own
        kwargs dicts are never modified.
      input_values: list of input values consumed by the op.
      deterministic: if True, stochastic ops (stochastic depth, dropout,
        self-attention) run deterministically.
      training: batch norm uses ``use_running_average = not training`` unless
        the op specifies ``use_running_average`` itself.
      **_: ignored extra keyword arguments.

    Returns:
      A one-element list holding the op's output.

    Raises:
      ValueError: if ``op_kwargs`` has no name or the op type is unsupported.
    """
    # Work on shallow copies. Several branches below normalize entries in place
    # (kernel_size, window_shape, strides) or pop them (new_shape, init_fn).
    # Editing the op's own dicts would leak those edits into the op, so
    # executing the same op a second time (e.g. init followed by apply) would
    # raise KeyError or see already-normalized values.
    input_kwargs: Dict[str, Any] = dict(op.input_kwargs or {})
    op_kwargs: Dict[str, Any] = dict(op.op_kwargs or {})
    op_type = op.type
    if "name" not in op_kwargs:
        raise ValueError("Op kwargs must contain a name.")
    op_name = op_kwargs["name"]

    if op_type == OpType.NONE:
        assert len(input_values) == 1
        input_value = input_values[0]
        assert not input_kwargs
        assert len(op_kwargs) == 1
        output_values = [lax.stop_gradient(input_value)]

    elif op_type == OpType.IDENTITY:
        assert len(input_values) == 1
        input_value = input_values[0]
        assert not input_kwargs
        assert len(op_kwargs) == 1
        output_values = [input_value]

    # nn.linear

    elif op_type == OpType.DENSE:
        assert len(input_values) == 1
        input_value = input_values[0]
        assert not input_kwargs
        output_values = [nn.Dense(**op_kwargs)(input_value)]

    elif op_type == OpType.DENSE_GENERAL:
        assert len(input_values) == 1
        input_value = input_values[0]
        assert not input_kwargs
        assert 2 <= len(op_kwargs) <= 7
        output_values = [nn.DenseGeneral(**op_kwargs)(input_value)]

    elif op_type == OpType.CONV:
        assert len(input_values) == 1
        input_value = input_values[0]
        assert not input_kwargs

        ks = op_kwargs["kernel_size"]
        if isinstance(ks, int):
            # Expand a scalar kernel size over every axis except the first and
            # last of the input.
            op_kwargs["kernel_size"] = (ks,) * (input_value.ndim - 2)

        output_values = [nn.Conv(**op_kwargs)(input_value)]

    # others

    elif op_type == OpType.MUL:
        assert len(input_values) == 2
        assert not input_kwargs
        assert len(op_kwargs) == 1  # name
        output_values = [input_values[0] * input_values[1]]

    elif op_type in [OpType.ADD, OpType.STOCH_DEPTH]:
        assert len(op_kwargs) == 1  # name

        input_value = input_values[0]
        if "layer_drop_rate" in input_kwargs:
            assert len(input_kwargs) == 1
            survival_rate = 1 - input_kwargs["layer_drop_rate"]
            if survival_rate != 1.0 and not deterministic:
                # Reuse dropout's rng stream.
                rng = self.make_rng("dropout")
                # One Bernoulli draw per example (leading axis), broadcast over
                # the remaining axes; survivors are rescaled by 1/survival_rate.
                mask_shape = [input_value.shape[0]] + [1] * (input_value.ndim - 1)
                mask = random.bernoulli(rng, p=survival_rate, shape=mask_shape)
                mask = jnp.tile(mask, [1] + list(input_value.shape[1:]))
                input_value = lax.select(mask, input_value / survival_rate,
                                         jnp.zeros_like(input_value))
        else:
            # Without a drop rate only a plain ADD is meaningful.
            assert not input_kwargs
            assert op_type == OpType.ADD

        if op_type == OpType.ADD:
            assert len(input_values) == 2
            output_values = [input_value + input_values[1]]
        else:
            assert len(input_values) == 1
            output_values = [input_value]

    elif op_type == OpType.SCALAR_MUL:
        assert len(input_values) == 1
        input_value = input_values[0]
        assert len(input_kwargs) <= 1
        assert len(op_kwargs) == 1  # name
        if "const" in input_kwargs:
            c = input_kwargs["const"]
        else:
            # Default scale: 1 / sqrt(size of the last axis).
            c = 1 / jnp.sqrt(input_values[0].shape[-1])
        output_values = [input_values[0] * c]

    elif op_type == OpType.SCALAR_ADD:
        assert len(input_values) == 1
        input_value = input_values[0]
        assert len(input_kwargs) <= 1
        assert len(op_kwargs) == 1  # name
        assert "const" in input_kwargs
        c = input_kwargs["const"]
        output_values = [input_values[0] + c]

    elif op_type == OpType.DOT_GENERAL:
        assert len(input_values) == 2
        assert 0 < len(input_kwargs) <= 3
        assert len(op_kwargs) == 1  # name
        output_values = [
            lax.dot_general(input_values[0], input_values[1], **input_kwargs)
        ]

    elif op_type == OpType.EINSUM:
        assert len(input_values) == 2
        assert len(input_kwargs) == 1
        assert "sum" in input_kwargs
        output_values = [
            jnp.einsum(input_kwargs["sum"], input_values[0], input_values[1])
        ]

    # nn.attention

    elif op_type == OpType.SELF_ATTENTION:
        assert len(input_values) == 1
        input_value = input_values[0]
        assert not input_kwargs
        output_values = [
            nn.SelfAttention(**op_kwargs,
                             deterministic=deterministic)(input_value)
        ]

    # nn.activation

    elif op_type in [OpType.RELU, OpType.GELU, OpType.SWISH, OpType.SIGMOID]:
        assert len(input_values) == 1
        input_value = input_values[0]
        assert not input_kwargs
        fn = {
            OpType.RELU: nn.relu,
            OpType.GELU: nn.gelu,
            OpType.SWISH: nn.swish,
            OpType.SIGMOID: nn.sigmoid
        }[op_type]
        output_values = [fn(input_value)]

    elif op_type == OpType.SOFTMAX:
        assert len(input_values) == 1
        input_value = input_values[0]
        assert len(input_kwargs) <= 1
        output_values = [nn.softmax(input_value, **input_kwargs)]

    # nn.normalization

    elif op_type == OpType.BATCH_NORM:
        assert len(input_values) == 1
        input_value = input_values[0]
        assert len(input_kwargs) <= 1
        # Use running statistics outside training unless the op pins it.
        if "use_running_average" in input_kwargs:
            add_kwargs = {}
        else:
            add_kwargs = {"use_running_average": not training}
        output_values = [
            nn.BatchNorm(**op_kwargs)(input_value, **input_kwargs, **add_kwargs)
        ]

    elif op_type == OpType.LAYER_NORM:
        assert len(input_values) == 1
        input_value = input_values[0]
        assert not input_kwargs
        output_values = [nn.LayerNorm(**op_kwargs)(input_value)]

    elif op_type == OpType.GROUP_NORM:
        assert len(input_values) == 1
        input_value = input_values[0]
        assert not input_kwargs
        output_values = [nn.GroupNorm(**op_kwargs)(input_value)]

    # reshape operators

    elif op_type == OpType.RESHAPE:
        assert len(input_values) == 1
        input_value = input_values[0]
        assert 0 < len(input_kwargs) < 3
        new_shape = input_kwargs.pop("new_shape")
        # A leading "B" stands for the input's (batch) leading dimension.
        if new_shape[0] == "B":
            new_shape = (input_value.shape[0],) + new_shape[1:]
        output_values = [jnp.reshape(input_value, new_shape, **input_kwargs)]

    elif op_type == OpType.FLATTEN:
        assert len(input_values) == 1
        input_value = input_values[0]
        assert not input_kwargs
        new_shape = (input_value.shape[0], -1)
        output_values = [jnp.reshape(input_value, new_shape)]

    elif op_type == OpType.TRANSPOSE:
        assert len(input_values) == 1
        input_value = input_values[0]
        assert len(input_kwargs) == 1
        assert len(op_kwargs) == 1  # name
        output_values = [jnp.transpose(input_value, **input_kwargs)]

    # nn.stochastic

    elif op_type == OpType.DROPOUT:
        assert len(input_values) == 1
        input_value = input_values[0]
        assert len(input_kwargs) <= 1
        output_values = [
            nn.Dropout(**op_kwargs)(input_value,
                                    deterministic=deterministic,
                                    **input_kwargs)
        ]

    # nn.pooling

    elif op_type == OpType.AVG_POOL or op_type == OpType.MAX_POOL:
        op_fn = nn.avg_pool if op_type == OpType.AVG_POOL else nn.max_pool
        assert len(input_values) == 1
        input_value = input_values[0]
        assert input_kwargs

        ws = input_kwargs["window_shape"]
        if isinstance(ws, int):
            ws = [ws] * (input_value.ndim - 2)
        # A window size of 0 means "span the whole corresponding input axis".
        input_kwargs["window_shape"] = tuple(
            dim_shape if window_dim_shape == 0 else window_dim_shape
            for window_dim_shape, dim_shape in zip(ws, input_value.shape[1:]))

        if "strides" in input_kwargs:
            s = input_kwargs["strides"]
            if isinstance(s, int):
                input_kwargs["strides"] = (s,) * (input_value.ndim - 2)

        output_values = [op_fn(input_value, **input_kwargs)]

    elif op_type == OpType.MEAN:
        assert len(input_values) == 1
        input_value = input_values[0]
        assert input_kwargs
        output_values = [jnp.mean(input_value, **input_kwargs)]

    # new param

    elif op_type == OpType.PARAM:
        assert not input_values
        assert 0 < len(input_kwargs) <= 2
        init_fn = input_kwargs.pop("init_fn")

        # Remaining input_kwargs are bound into the initializer.
        init_fn_with_kwargs = functools.partial(init_fn, **input_kwargs)
        output_values = [self.param(op_name, init_fn_with_kwargs)]

    else:
        raise ValueError(f"op_type {op_type} not supported...")

    return output_values
def cartesian_prod(x, y):
    """Pair every element of ``x`` with every element of ``y``.

    For 1-D ``x`` and ``y`` the result has shape ``(len(x) * len(y), 2)``.
    ``x`` varies fastest: rows are (x0, y0), (x1, y0), ..., (x0, y1), ...
    """
    # x repeated as a whole block once per element of y ...
    x_column = jnp.tile(x, len(y))
    # ... and each element of y repeated len(x) times in a row.
    y_column = jnp.repeat(y, len(x))
    return jnp.transpose(jnp.stack((x_column, y_column)))
def naive_upsample_2d(x, factor=2):
    """Nearest-neighbour upsampling of a 4-D (N, H, W, C) array.

    Every element is copied into a ``factor x factor`` patch, giving an output
    of shape ``(N, H * factor, W * factor, C)``.
    """
    _, height, width, channels = x.shape
    # Add singleton axes right after H and W ...
    expanded = jnp.reshape(x, [-1, height, 1, width, 1, channels])
    # ... replicate along those axes ...
    replicated = jnp.tile(expanded, [1, 1, factor, 1, factor, 1])
    # ... and fold each (H, factor) / (W, factor) pair back into one axis.
    return jnp.reshape(
        replicated, [-1, height * factor, width * factor, channels])
def __call__(
    self,
    encoding: Array,
    mention_batch_positions: Array,
    mention_start_positions: Array,
    mention_end_positions: Array,
    mention_mask: Array,
    memory_keys: Array,
    memory_values: Array,
    memory_mask: Array,
    memory_entity_ids: Array,
    deterministic: bool,
) -> Tuple[Array, Dict[str, Array], Dict[str, Array]]:
    """Perform attention update over memory table.

    Either attends over the entire memory table (``self.k_top is None``) or
    over the top-k memories returned by ``self.topk_similarity``, then passes
    the retrieved values and weights to ``self.update_layer``.

    Args:
      encoding: [batch_size, n_tokens, hidden_size] input representation.
      mention_batch_positions: [n_mentions] mention sample position in batch.
      mention_start_positions: [n_mentions] mention start position in input.
      mention_end_positions: [n_mentions] mention end position in input.
      mention_mask: [n_mentions] attention mask to prevent updates from padding.
      memory_keys: [memory_size, memory_key_dim] mention memory keys.
      memory_values: [memory_size, memory_value_dim] mention memory values.
      memory_mask: [memory_size] mask for valid mentions in memory; entries
        equal to 0 get a score penalty of ``_LARGE_NUMBER``.
      memory_entity_ids: [memory_size] mention memory entity ids.
      deterministic: don't apply dropout if true.

    Returns:
      Tuple of (encoding updated by ``self.update_layer``, loss helpers,
      logging helpers). Loss helpers contain 'top_entity_ids' and
      'memory_attention_weights'; the logging helpers dict is returned empty.
    """
    loss_helpers, logging_helpers = {}, {}

    # We generate mention representations to use as queries for similarity
    # search by concatenating start and end tokens for each mention and
    # projecting the concatenation with a dense layer.
    mention_start_encodings = jut.matmul_2d_index_select(
        encoding, (mention_batch_positions, mention_start_positions))
    mention_end_encodings = jut.matmul_2d_index_select(
        encoding, (mention_batch_positions, mention_end_positions))

    queries = self.query_projector(
        jnp.concatenate((mention_start_encodings, mention_end_encodings),
                        axis=-1))
    n_queries = queries.shape[0]

    # For attention over entire memory table, we do not want to duplicate the
    # entire memory table for each query. Instead, we perform an
    # attention-weighted sum to produce a single value. We then feed this value
    # to the update layer as a set of retrieved values of size 1, with score 1.
    if self.k_top is None:
        # Every query "retrieves" every memory, so the entity ids are the full
        # id vector repeated per query: (n_queries, memory_size).
        loss_helpers['top_entity_ids'] = jnp.tile(memory_entity_ids,
                                                  (n_queries, 1))
        scores = jnp.einsum('qd,md->qm', queries, memory_keys)
        # Push scores of masked-out memories towards -inf before the softmax.
        scores = scores - (1 - memory_mask) * _LARGE_NUMBER
        true_attention_weights = nn.softmax(scores, axis=-1)
        loss_helpers['memory_attention_weights'] = true_attention_weights
        top_values = jnp.einsum('qm,md->qd', true_attention_weights,
                                memory_values)
        # Expand value as though it were a set of retrieved values for each query.
        # Shape (n_queries, 1, memory_value_dim)
        top_values = jnp.expand_dims(top_values, axis=1)
        # Generate pseudo-score (n_queries, 1).
        attention_weights = jnp.ones_like(top_values, shape=(n_queries, 1))
    else:
        # Reshape memory keys for use in approximate top-k similarity layer.
        memory_keys = memory_keys.reshape(self.rows, -1, self.memory_key_dim)
        # We generate a version of the queries with stop gradient to use as input
        # to the topk similarity layer. We actually do want gradient to flow to
        # the queries, but backward differentiation over the topk layer yields
        # inefficient HLO ops. Instead we use queries with gradient to recompute
        # attention scores later.
        queries_sg = jax.lax.stop_gradient(queries)
        # Perform top-k similarity search over queries, yielding
        #   top_keys: presumably (queries, k_top, memory_key_dim) — matches the
        #     'qd,qkd->qk' einsum below
        #   top_ids: (queries, k_top) after the reshape below
        top_keys, _, top_ids = self.topk_similarity(queries_sg, memory_keys)
        top_ids = top_ids.reshape(n_queries, self.k_top)
        top_values = memory_values[top_ids]
        loss_helpers['top_entity_ids'] = memory_entity_ids[top_ids]

        # We re-compute top scores using the queries with gradient (wg) to make
        # sure the query projector and the rest of the model receives gradient.
        top_scores_wg = jnp.einsum('qd,qkd->qk', queries, top_keys)
        top_mask = memory_mask[top_ids]
        top_scores_wg = top_scores_wg - (1 - top_mask) * _LARGE_NUMBER

        # Dot-product attention over the retrieved memories: mention
        # representations are the queries, retrieved memory keys are the keys,
        # and the retrieved memory values (top_values) are the values.
        attention_weights = nn.softmax(top_scores_wg, axis=-1)
        loss_helpers['memory_attention_weights'] = attention_weights

    encoding = self.update_layer(
        encoded_input=encoding,
        retrieval_values=top_values,
        retrieval_scores=attention_weights,
        mention_batch_positions=mention_batch_positions,
        mention_start_positions=mention_start_positions,
        mention_end_positions=mention_end_positions,
        mention_mask=mention_mask,
        deterministic=deterministic,
    )

    return encoding, loss_helpers, logging_helpers
def generate_triplets(key,
                      inputs,
                      n_inliers,
                      n_outliers,
                      n_random,
                      weight_temp=0.5,
                      distance='euclidean',
                      verbose=False):
    """Build index triplets over ``inputs`` together with per-triplet weights.

    Args:
      key: JAX PRNG key; split once for kNN triplet sampling and, when
        ``n_random > 0``, once more for random triplet sampling.
      inputs: input points, one per row.
      n_inliers: number of nearest neighbours kept per point.
      n_outliers: number of outliers, forwarded to ``sample_knn_triplets``.
      n_random: number of random triplets per point (none when <= 0).
      weight_temp: temperature of the tempered-log transform on the weights.
      distance: metric name for the neighbour search and ``get_distance_fn``.
      verbose: whether to log progress.

    Returns:
      ``(triplets, weights)``. For every kNN triplet with a negative raw
      weight, the two non-anchor columns are swapped. Random-triplet weights
      are scaled by 0.1. All weights are shifted so the minimum is 0, then
      passed through ``tempered_log(1 + w, weight_temp)``.
    """
    num_points = inputs.shape[0]
    num_candidates = min(n_inliers + 50, num_points)

    # Approximate nearest neighbours; each point is then prepended as column 0.
    nn_index = pynndescent.NNDescent(inputs, metric=distance)
    nn_index.prepare()
    knn = nn_index.query(inputs, num_candidates)[0]
    self_column = np.arange(num_points).reshape([-1, 1])
    knn = np.concatenate((self_column, knn), 1)
    if verbose:
        logging.info('found nearest neighbors')

    dist_fn = get_distance_fn(distance)

    # compute scaled neighbors and the scale parameter
    knn_dists, knn, sig = find_scaled_neighbors(inputs, knn, dist_fn)
    keep = n_inliers + 1
    knn = knn[:, :keep]
    knn_dists = knn_dists[:, :keep]

    key, subkey = random.split(key)
    triplets = sample_knn_triplets(subkey, knn, n_inliers, n_outliers)
    weights = find_triplet_weights(
        inputs,
        triplets,
        knn[:, 1:keep],
        dist_fn,
        sig,
        distances=knn_dists[:, 1:keep])

    # Swap the two non-anchor columns of every triplet whose weight is negative.
    swap = (weights < 0).reshape([-1, 1])
    anchors = triplets[:, 0].reshape([-1, 1])
    pairs = triplets[:, 1:]
    pairs = jnp.where(jnp.tile(swap, [1, 2]), jnp.fliplr(pairs), pairs)
    triplets = jnp.concatenate((anchors, pairs), 1)

    if n_random > 0:
        key, subkey = random.split(key)
        rand_triplets, rand_weights = sample_random_triplets(
            subkey, inputs, n_random, dist_fn, sig)
        triplets = jnp.concatenate((triplets, rand_triplets), 0)
        weights = jnp.concatenate((weights, 0.1 * rand_weights))

    weights -= jnp.min(weights)
    weights = tempered_log(1. + weights, weight_temp)
    return triplets, weights
def __call__(self,
             inputs,
             decoder_mask=None,
             encoder_decoder_mask=None,
             inputs_kv=None):
    """Applies EncoderDecoder1DBlock module.

    An attention sub-block (self-attention when ``self.is_self_att``, otherwise
    attention from ``inputs`` to ``inputs_kv``) is followed by dropout, a
    residual add and an MLP sub-block. The input is repeated
    ``self.num_repeat`` times, so the output's last dimension is
    ``inputs.shape[-1] * self.num_repeat``.

    Args:
      inputs: input data for the decoder; must be 3-D (asserted), presumably
        (batch, length, features).
      decoder_mask: attention mask. It is passed to self-attention AND, in the
        cross-attention branch, to the cross-attention as well.
      encoder_decoder_mask: encoder-decoder attention mask.
        NOTE(review): currently unused — the cross-attention branch uses
        ``decoder_mask`` instead; confirm this is intended.
      inputs_kv: key/value inputs for the cross-attention branch (only read
        when ``self.is_self_att`` is False).

    Returns:
      Output of the block, reshaped to ``inputs.shape`` with the last
      dimension multiplied by ``self.num_repeat``.
    """
    cfg = self.config

    # Decoder block.
    assert inputs.ndim == 3
    input_shape = list(inputs.shape)
    input_shape[-1] *= self.num_repeat
    if cfg.use_layernorm:
        x = nn.LayerNorm(dtype=cfg.dtype)(inputs)
    else:
        x = inputs
    if self.is_self_att:
        x = MultiDimSelfAttention(
            num_heads=cfg.num_heads,
            num_repeat=self.num_repeat,
            out_features=self.out_features,
            dtype=cfg.dtype,
            qkv_features=cfg.qkv_dim,
            kernel_init=cfg.kernel_init,
            bias_init=cfg.bias_init,
            use_bias=False,
            broadcast_dropout=False,
            dropout_rate=cfg.attention_dropout_rate,
            deterministic=cfg.deterministic,
            decode=cfg.decode)(x, decoder_mask)
    else:
        if cfg.use_layernorm:
            x_kv = nn.LayerNorm(dtype=cfg.dtype)(inputs_kv)
        else:
            x_kv = inputs_kv
        x = MultiDimMultiHeadDotProductAttention(
            num_heads=cfg.num_heads,
            num_repeat=self.num_repeat,
            out_features=self.out_features,
            dtype=cfg.dtype,
            qkv_features=cfg.qkv_dim,
            kernel_init=cfg.kernel_init,
            bias_init=cfg.bias_init,
            use_bias=False,
            broadcast_dropout=False,
            dropout_rate=cfg.attention_dropout_rate,
            deterministic=cfg.deterministic,
            decode=cfg.decode)(x, x_kv, mask=decoder_mask)
    x = nn.Dropout(rate=cfg.dropout_rate)(x, deterministic=cfg.deterministic)
    # Residual connection: give the original inputs a repeat axis at -2 and
    # tile them (by self.to_tile_shape, presumably num_repeat along that axis —
    # TODO confirm) so they match the attention output.
    inputs = jnp.expand_dims(inputs, -2)
    inputs = jnp.tile(inputs, self.to_tile_shape)
    x = x + inputs

    # MLP block.
    if cfg.use_layernorm:
        z = nn.LayerNorm(dtype=cfg.dtype)(x)
    else:
        z = x
    z = MlpBlock(config=cfg)(z)

    # Fold the repeat axis back into the feature axis.
    return jnp.reshape(x + z, input_shape)
def tmp_potential(geom, basis, charges):
    """Build the one-electron potential integral matrix ``V``.

    Only the potential integrals are computed here (no overlap or kinetic
    integrals). For every pair of primitives the contribution
    ``potential(...) * coef`` is accumulated into ``V[i, j]`` for each
    angular-momentum component pair.

    NOTE(review): relies on ``jax.experimental.loops`` and ``jax.ops.index_add``,
    both removed from newer JAX releases (``s.V.at[i, j].add(...)`` is the
    replacement for the latter).

    Args:
      geom: atomic coordinates, indexed by atom (``geom[atom1]``); presumably
        (n_atoms, 3) — TODO confirm.
      basis: basis-set description, unpacked by ``flatten_basis_data`` and
        ``get_nbf``.
      charges: forwarded unchanged to ``potential``.

    Returns:
      The ``(nbf, nbf)`` potential integral matrix.
    """
    coeffs, exps, atoms, ams, indices, dims = flatten_basis_data(basis)
    nbf = get_nbf(basis)
    nprim = coeffs.shape[0]
    max_am = np.max(ams)
    # Scratch buffer passed to potential(), sized 2 * max_am + 1.
    A_vals = np.zeros(2 * max_am + 1)

    # Save various AM distributions for indexing
    # Obtain all possible primitive duet index combinations
    primitive_duets = cartesian_product(np.arange(nprim), np.arange(nprim))

    with loops.Scope() as s:
        s.V = np.zeros((nbf, nbf))
        s.a = 0  # center A angular momentum iterator
        s.b = 0  # center B angular momentum iterator

        for prim_duet in s.range(primitive_duets.shape[0]):
            p1, p2 = primitive_duets[prim_duet]
            coef = coeffs[p1] * coeffs[p2]
            aa, bb = exps[p1], exps[p2]
            atom1, atom2 = atoms[p1], atoms[p2]
            am1, am2 = ams[p1], ams[p2]
            A, B = geom[atom1], geom[atom2]
            ld1, ld2 = am_leading_indices[am1], am_leading_indices[am2]

            # Gaussian product: combined exponent, prefactor and center P.
            gamma = aa + bb
            prefactor = np.exp(-aa * bb * np.dot(A - B, A - B) / gamma)
            P = (aa * A + bb * B) / gamma
            # Maximum angular momentum: hard coded
            # Precompute all powers up to 2+max_am of Pi-Ai, Pi-Bi.
            # We need 2+max_am since kinetic requires incrementing angluar momentum by +2
            PA_pow = np.power(
                np.broadcast_to(P - A, (max_am + 3, 3)).T,
                np.arange(max_am + 3))
            PB_pow = np.power(
                np.broadcast_to(P - B, (max_am + 3, 3)).T,
                np.arange(max_am + 3))

            # For potential integrals, we need the difference between
            # the gaussian product center P and ALL atoms in the molecule,
            # and then take all possible powers up to 2*max_am.
            # We pre-collect this into a 3d array, and then just pull out what we need via indexing in the loops, so they need not be recomputed.
            # The resulting array has dimensions (atom, cartesian component, power) so index (0, 1, 3) would return (Py - atom0_y)^3
            P_minus_geom = np.broadcast_to(P, geom.shape) - geom
            Pgeom_pow = np.power(
                np.transpose(
                    np.broadcast_to(
                        P_minus_geom,
                        (2 * max_am + 1, geom.shape[0], geom.shape[1])),
                    (1, 2, 0)),
                np.arange(2 * max_am + 1))
            # All possible np.dot(P-atom,P-atom)
            rcp2 = np.einsum('ij,ij->i', P_minus_geom, P_minus_geom)
            # All needed (and unneeded, for am < max_am) boys function evaluations
            boys_arg = np.broadcast_to(rcp2 * gamma,
                                       (2 * max_am + 1, geom.shape[0]))
            boys_nu = np.tile(np.arange(2 * max_am + 1), (geom.shape[0], 1)).T
            boys_eval = boys(boys_nu, boys_arg)

            # Loop over the dims[p1] x dims[p2] angular-momentum components of
            # this primitive pair.
            s.a = 0
            for _ in s.while_range(lambda: s.a < dims[p1]):
                s.b = 0
                for _ in s.while_range(lambda: s.b < dims[p2]):
                    # Gather angular momentum and index
                    la, ma, na = angular_momentum_combinations[s.a + ld1]
                    lb, mb, nb = angular_momentum_combinations[s.b + ld2]
                    # To only create unique indices, need to have separate indices arrays for i and j.
                    i = indices[p1] + s.a
                    j = indices[p2] + s.b
                    # Compute one electron integrals and add to appropriate index
                    potential_int = potential(la, ma, na, lb, mb, nb, aa, bb,
                                              PA_pow, PB_pow, Pgeom_pow,
                                              boys_eval, prefactor, charges,
                                              A_vals) * coef
                    s.V = jax.ops.index_add(s.V, jax.ops.index[i, j],
                                            potential_int)

                    s.b += 1
                s.a += 1
        return s.V
def __call__(self, inputs_q, inputs_kv, mask=None, deterministic=None):
    """Applies multi-head dot product attention, repeated ``num_repeat`` times.

    Projects the inputs into multi-headed query, key, and value vectors. The
    query projection carries an extra ``num_repeat`` axis; keys and values are
    tiled along a matching axis. The layer then applies dot-product attention
    and projects the results to an output vector per repeat.

    Args:
      inputs_q: input queries; must be 3-D (asserted), i.e.
        `[batch, q_length, features]`.
      inputs_kv: key/values; must be 3-D (asserted), i.e.
        `[batch, kv_length, features]`.
      mask: attention mask; entries > 0 are attended, all others get a bias of
        -1e10. It must broadcast against the attention logits, presumably
        `[batch, num_repeat, num_heads, q_length, kv_length]` — TODO confirm
        against ``self.attention_fn``.
      deterministic: if false, the attention weight is masked randomly using
        dropout, whereas if true, the attention weights are deterministic.
        Only merged with ``self.deterministic`` when ``dropout_rate > 0``.

    Returns:
      Output of shape `[batch, q_length, num_repeat, features]`, where
      ``features`` is ``self.out_features`` or the last dim of ``inputs_q``.
    """
    assert inputs_q.ndim == 3 and inputs_kv.ndim == 3
    if self.dropout_rate > 0.:  # Require `deterministic` only if using dropout.
        deterministic = merge_param('deterministic', self.deterministic,
                                    deterministic)
    features = self.out_features or inputs_q.shape[-1]
    qkv_features = self.qkv_features or inputs_q.shape[-1]
    assert qkv_features % self.num_heads == 0, (
        'Memory dimension must be divisible by number of heads.')
    head_dim = qkv_features // self.num_heads

    dense = partial(DenseGeneral,
                    axis=-1,
                    features=(self.num_heads, head_dim),
                    kernel_init=self.kernel_init,
                    bias_init=self.bias_init,
                    use_bias=self.use_bias,
                    precision=self.precision)
    # project inputs_q to multi-headed q/k/v
    # dimensions are then [batch..., length, n_heads, n_features_per_head]
    # (the query additionally has a num_repeat axis before n_heads)
    query, key, value = (dense(dtype=self.dtype,
                               name='query',
                               features=(self.num_repeat, self.num_heads,
                                         head_dim))(inputs_q),
                         dense(dtype=self.dtype, name='key')(inputs_kv),
                         dense(dtype=self.dtype, name='value')(inputs_kv))
    # Give key/value a repeat axis at -3 and tile it by self.to_tile_shape
    # (presumably num_repeat along that axis — TODO confirm) to match query.
    key = jnp.expand_dims(key, -3)
    value = jnp.expand_dims(value, -3)
    key = jnp.tile(key, self.to_tile_shape)
    value = jnp.tile(value, self.to_tile_shape)
    # Move the repeat axis in front of the length axis.
    query = jnp.swapaxes(query, -3, -4)
    key = jnp.swapaxes(key, -3, -4)
    value = jnp.swapaxes(value, -3, -4)
    '''
    query shape: (batch_size, num_repeat, query_seq_len, num_head, emb_dim)
    kv shape: (batch_size, num_repeat, kv_seq_len, num_head, emb_dim)
    '''
    # Convert the boolean attention mask to an attention bias.
    if mask is not None:
        # attention mask in the form of attention bias
        # NOTE(review): -1e10 overflows to -inf when self.dtype is float16.
        attention_bias = lax.select(
            mask > 0,
            jnp.full(mask.shape, 0.).astype(self.dtype),
            jnp.full(mask.shape, -1e10).astype(self.dtype))
    else:
        attention_bias = None

    dropout_rng = None
    if not deterministic and self.dropout_rate > 0.:
        dropout_rng = self.make_rng('dropout')

    # apply attention
    x = self.attention_fn(
        query,
        key,
        value,
        bias=attention_bias,
        dropout_rng=dropout_rng,
        dropout_rate=self.dropout_rate,
        broadcast_dropout=self.broadcast_dropout,
        deterministic=deterministic,
        dtype=self.dtype,
        precision=self.precision)  # pytype: disable=wrong-keyword-args
    # back to the original inputs dimensions
    out = DenseGeneral(features=features,
                       axis=(-2, -1),
                       kernel_init=self.kernel_init,
                       bias_init=self.bias_init,
                       use_bias=self.use_bias,
                       dtype=self.dtype,
                       precision=self.precision,
                       name='out')(x)
    out = jnp.swapaxes(out, -2, -3)
    '''
    swap out from (batch_size, num_repeat, seq_len, emb_dim) to
    (batch_size, seq_len, num_repeat, emb_dim)
    '''
    return out
def predict(self, params, logits, context, target=None):
    """Compute the layer's output logits and, given a target, update weights.

    Each (class, neuron) pair selects one weight vector. The selection is
    binary-encoded from which context maps fire, i.e. where
    ``(context_maps * context).sum(-1) > context_bias``.

    Args:
      params: parameter dict with 'context_maps', 'weights', optional 'bias'
        and 'context_bias', and (when training) 'lr_step'. When ``target`` is
        given, ``params`` is MUTATED in place ('lr_step' and 'weights') and
        also returned.
      logits: input logits; ``logits.shape[0]`` is the batch size.
      context: side-information vectors, one per batch element.
      target: optional targets; when given, a weight update is performed.

    Returns:
      Output logits (trailing singleton axis squeezed) when ``target`` is
      None; otherwise ``(params, output_logits)``.
    """
    # Expand context so it broadcasts against params['context_maps'] (assumes
    # context is (batch, context_dim) — TODO confirm).
    context = jnp.expand_dims(jnp.expand_dims(jnp.expand_dims(context, axis=1),
                                              axis=1),
                              axis=1)
    context_bias = params.get('context_bias', 0.0)
    context_index = (params['context_maps'] *
                     context).sum(axis=-1) > context_bias

    # Binary-encode the fired context maps: map k contributes 1 << k.
    context_map_values = jnp.asarray(
        [[[[1 << n for n in range(self.context_map_size)]]]])
    context_index = jnp.where(context_index, context_map_values, 0)
    context_index = context_index.sum(axis=-1, keepdims=True)

    # Prefix every context index with its (class, neuron) coordinates so that
    # it addresses params['weights'][class, neuron, context].
    batch_size = logits.shape[0]
    class_neuron_index = jnp.asarray([[[[c, n] for n in range(self.size)]
                                       for c in range(self.num_classes)]])
    class_neuron_index = jnp.tile(class_neuron_index,
                                  reps=(batch_size, 1, 1, 1))
    context_index = jnp.concatenate([class_neuron_index, context_index],
                                    axis=-1)

    # Gather one weight vector of length input_size (+1 with bias) per
    # (batch, class, neuron).
    dims = lax.GatherDimensionNumbers(offset_dims=(3,),
                                      collapsed_slice_dims=(0, 1, 2),
                                      start_index_map=(0, 1, 2))
    weights = lax.gather(operand=params['weights'],
                         start_indices=context_index,
                         dimension_numbers=dims,
                         slice_sizes=(1, 1, 1,
                                      self.input_size + int(self.bias)))

    if self.bias:
        bias = jnp.tile(params['bias'], reps=(batch_size, 1, 1))
        logits = jnp.concatenate([logits, bias], axis=-1)
    logits = jnp.expand_dims(logits, axis=-1)

    output_logits = jnp.matmul(weights, logits)
    # Clip to logit(pred_clipping) .. logit(1 - pred_clipping).
    output_logits = jnp.clip(output_logits,
                             a_min=jsp.special.logit(self.pred_clipping),
                             a_max=jsp.special.logit(1.0 -
                                                     self.pred_clipping))

    if target is None:
        return jnp.squeeze(output_logits, axis=-1)

    else:
        # Learning step: delta = lr * (target - sigmoid(output)) * input.
        logits = jnp.expand_dims(jnp.squeeze(logits, axis=-1), axis=-2)
        output_preds = jnn.sigmoid(output_logits)
        target = jnp.expand_dims(jnp.expand_dims(target, axis=-1), axis=-1)
        params['lr_step'], learning_rate = self.learning_rate.value(
            params['lr_step'])
        delta = learning_rate * (target - output_preds) * logits

        dims = lax.ScatterDimensionNumbers(
            update_window_dims=(3,),
            inserted_window_dims=(0, 1, 2),
            scatter_dims_to_operand_dims=(0, 1, 2))

        if self.weight_clipping is None:
            # Updates from batch elements that select the same weight vector
            # are summed.
            params['weights'] = lax.scatter_add(operand=params['weights'],
                                                scatter_indices=context_index,
                                                updates=delta,
                                                dimension_numbers=dims)
        else:
            # NOTE(review): with lax.scatter, if several batch elements select
            # the same weight vector only one of their updates survives (which
            # one is unspecified) — confirm this is acceptable.
            weights = jnp.clip(weights + delta,
                               a_min=-self.weight_clipping,
                               a_max=self.weight_clipping)
            params['weights'] = lax.scatter(operand=params['weights'],
                                            scatter_indices=context_index,
                                            updates=weights,
                                            dimension_numbers=dims)

        return params, jnp.squeeze(output_logits, axis=-1)
def onnx_tile(a, b):
    """Tile ``a`` according to the repetition counts ``b`` via ``jnp.tile``.

    Args:
      a: array to tile.
      b: repetitions per axis, passed unchanged as ``reps``.

    Returns:
      The tiled array.
    """
    repeated = jnp.tile(a, b)
    return repeated
def __call__(
    self,
    pose_coeffs,
    betas=None,
):
    """Run the hand-model forward pass from pose coefficients and shape betas.

    Args:
      pose_coeffs: 2-D array of pose coefficients (``pose_coeffs.shape[0]`` is
        the batch size). Columns ``[:self.rot]`` hold the global rotation and
        columns ``[self.rot:self.rot + self.ncomps]`` the hand pose.
      betas: shape parameters. ``None`` (the default) means all zeros. A 1-D
        array or a single row is broadcast to every batch element; otherwise
        it must be ``(batch_size, n_shape)``, where ``n_shape`` is the last
        dimension of ``self.shapedirs``.

    Returns:
      ``(verts, jtr, full_pose)``: vertex positions (last dim 3), 21 joint
      positions (reference joints plus 5 fingertip vertices), both translated
      so that joint ``self.center_idx`` is at the origin, and the
      concatenated global-rotation + hand pose used.
    """
    batch_size = pose_coeffs.shape[0]

    # The shape blend below needs betas of shape (batch_size, n_shape). The
    # previous default, np.zeros(1), is 1-D and made betas.transpose((1, 0))
    # raise, so the default path always crashed; normalize betas here instead.
    n_shape = self.shapedirs.shape[-1]
    if betas is None:
        betas = np.zeros((batch_size, n_shape))
    else:
        betas = np.broadcast_to(np.atleast_2d(betas), (batch_size, n_shape))

    # Get axis angle from PCA components and coefficients
    # Remove global rot coeffs
    hand_pose_coeffs = pose_coeffs[:, self.rot:self.rot + self.ncomps]
    full_hand_pose = hand_pose_coeffs

    # Concatenate back global rot
    full_pose = np.concatenate(
        (pose_coeffs[:, :self.rot], self.hands_mean + full_hand_pose), 1)

    # compute rotation matrixes from axis-angle while skipping global rotation
    pose_map, rot_map = self._posemap_axisang(full_pose)
    root_rot = rot_map[:, :9].reshape(batch_size, 3, 3)
    rot_map = rot_map[:, 9:]
    pose_map = pose_map[:, 9:]

    # Full axis angle representation with root joint
    v_shaped = np.matmul(self.shapedirs, betas.transpose(
        (1, 0))).transpose((2, 0, 1)) + self.v_template
    j = np.matmul(self.J_regressor, v_shaped)
    # th_pose_map should have shape 20x135
    v_posed = v_shaped + np.matmul(
        self.posedirs, pose_map.transpose(
            (1, 0))[np.newaxis, ...]).transpose((2, 0, 1))
    # Final T pose with transformation done !

    # Global rigid transformation
    root_j = j[:, 0, :].reshape(batch_size, 3, 1)
    root_trans = self._with_zeros(np.concatenate((root_rot, root_j), 2))

    all_rots = rot_map.reshape(rot_map.shape[0], 15, 3, 3)
    # Joint indices at each level of the kinematic chain (5 fingers, 3 levels).
    lev1_idxs = [1, 4, 7, 10, 13]
    lev2_idxs = [2, 5, 8, 11, 14]
    lev3_idxs = [3, 6, 9, 12, 15]
    lev1_rots = all_rots[:, [idx - 1 for idx in lev1_idxs]]
    lev2_rots = all_rots[:, [idx - 1 for idx in lev2_idxs]]
    lev3_rots = all_rots[:, [idx - 1 for idx in lev3_idxs]]
    lev1_j = j[:, lev1_idxs]
    lev2_j = j[:, lev2_idxs]
    lev3_j = j[:, lev3_idxs]

    # From base to tips
    # Get lev1 results
    all_transforms = [root_trans[:, np.newaxis, ...]]
    lev1_j_rel = lev1_j - root_j.transpose((0, 2, 1))
    lev1_rel_transform_flt = self._with_zeros(
        np.concatenate((lev1_rots, lev1_j_rel[..., np.newaxis]),
                       3).reshape(-1, 3, 4))
    # One copy of the root transform per finger, flattened with the batch.
    root_trans_flt = np.tile(root_trans[:, np.newaxis, ...],
                             (1, 5, 1, 1)).reshape(root_trans.shape[0] * 5, 4,
                                                   4)
    lev1_flt = np.matmul(root_trans_flt, lev1_rel_transform_flt)
    all_transforms.append(lev1_flt.reshape(all_rots.shape[0], 5, 4, 4))

    # Get lev2 results
    lev2_j_rel = lev2_j - lev1_j
    lev2_rel_transform_flt = self._with_zeros(
        np.concatenate((lev2_rots, lev2_j_rel[..., np.newaxis]),
                       3).reshape(-1, 3, 4))
    lev2_flt = np.matmul(lev1_flt, lev2_rel_transform_flt)
    all_transforms.append(lev2_flt.reshape(all_rots.shape[0], 5, 4, 4))

    # Get lev3 results
    lev3_j_rel = lev3_j - lev2_j
    lev3_rel_transform_flt = self._with_zeros(
        np.concatenate((lev3_rots, lev3_j_rel[..., np.newaxis]),
                       3).reshape(-1, 3, 4))
    lev3_flt = np.matmul(lev2_flt, lev3_rel_transform_flt)
    all_transforms.append(lev3_flt.reshape(all_rots.shape[0], 5, 4, 4))

    # Restore joint order 0..15 from the (root, level1, level2, level3) layout.
    reorder_idxs = [0, 1, 6, 11, 2, 7, 12, 3, 8, 13, 4, 9, 14, 5, 10, 15]

    results = np.concatenate(all_transforms, 1)[:, reorder_idxs]
    results_global = results

    # Remove each joint's rest position from its transform so the transforms
    # can be applied directly to rest-pose vertices (skinning).
    joint_js = np.concatenate((j, np.zeros((j.shape[0], 16, 1))), 2)
    tmp2 = np.matmul(results, joint_js[..., np.newaxis])
    results2 = (results - np.concatenate((np.zeros(
        (*tmp2.shape[:2], 4, 3)), tmp2), 3)).transpose((0, 2, 3, 1))

    T = np.matmul(results2, self.weights.transpose((1, 0)))

    rest_shape_h = np.concatenate((v_posed.transpose(
        (0, 2, 1)), np.ones((batch_size, 1, v_posed.shape[1]))), 1)

    verts = (T * rest_shape_h[:, np.newaxis, ...]).sum(2).transpose(
        (0, 2, 1))
    verts = verts[:, :, :3]
    jtr = results_global[:, :, :3, 3]

    # In addition to MANO reference joints we sample vertices on each finger
    # to serve as finger tips
    tips = verts[:, [745, 317, 444, 556, 673]]
    jtr = np.concatenate((jtr, tips), 1)

    # Reorder joints to match visualization utilities
    jtr = jtr[:, [
        0, 13, 14, 15, 16, 1, 2, 3, 17, 4, 5, 6, 18, 10, 11, 12, 19, 7, 8, 9,
        20
    ]]

    center_joint = jtr[:, self.center_idx][:, np.newaxis, ...]
    jtr = jtr - center_joint
    verts = verts - center_joint
    return verts, jtr, full_pose
def __call__(self, x, num_quantiles, rng):
  """Compute implicit-quantile action values for one observation.

  Submodules are constructed inline, so this presumably runs under flax
  `nn.compact` (decorator not visible here). Parameter names then depend on
  the order in which the layers below are created.

  Args:
    x: the observation. Every `net_conf` branch flattens it to 1-D, so it
      appears to be a single unbatched observation (TODO confirm with caller).
    num_quantiles: number of quantile fractions (taus) to sample.
    rng: PRNG key. It is used to sample the taus AND is passed to every
      `net(...)` call (consumed by NoisyNetwork when `self.noisy`).
      NOTE(review): the same key is reused for all of these.

  Returns:
    atari_lib.ImplicitQuantileNetworkType(quantile_values, quantiles), where
    `quantiles` has shape (num_quantiles, 1). `quantile_values` presumably has
    shape (num_quantiles, num_actions) when `x` is 1-D after preprocessing.
  """
  if self.net_conf == 'minatar':
    x = x.squeeze(3)
    x = x.astype(jnp.float32)
    x = nn.Conv(features=16, kernel_size=(3, 3), strides=(1, 1),
                kernel_init=self.initzer)(x)
    x = jax.nn.relu(x)
    x = x.reshape((-1))

  elif self.net_conf == 'atari':
    # No batch dimension is added here: this relies on nn.Conv accepting an
    # unbatched input (presumably vmap has removed the true batch axis).
    # TODO confirm against the flax version in use.
    # Pixel inputs are rescaled from [0, 255] to [0, 1].
    x = x.astype(jnp.float32) / 255.
    x = nn.Conv(features=32, kernel_size=(8, 8), strides=(4, 4),
                kernel_init=self.initzer)(x)
    x = jax.nn.relu(x)
    x = nn.Conv(features=64, kernel_size=(4, 4), strides=(2, 2),
                kernel_init=self.initzer)(x)
    x = jax.nn.relu(x)
    x = nn.Conv(features=64, kernel_size=(3, 3), strides=(1, 1),
                kernel_init=self.initzer)(x)
    x = jax.nn.relu(x)
    x = x.reshape((-1))  # flatten

  elif self.net_conf == 'classic':
    # classic environments: no feature extractor, just flatten
    x = x.astype(jnp.float32)
    x = x.reshape((-1))

  # Applied for any net_conf, but only when self.env has an entry in env_inf:
  # min-max scale with that entry, then map [0, 1] to [-1, 1].
  if self.env is not None and self.env in env_inf:
    x = x - env_inf[self.env]['MIN_VALS']
    x /= env_inf[self.env]['MAX_VALS'] - env_inf[self.env]['MIN_VALS']
    x = 2.0 * x - 1.0

  # `net` builds a fresh dense-like layer on every call (noisy or plain).
  if self.noisy:
    def net(x, features, rng):
      return NoisyNetwork(features, rng=rng, bias_in=True)(x)
  else:
    def net(x, features, rng):
      return nn.Dense(features, kernel_init=self.initzer)(x)

  for _ in range(self.hidden_layer):
    x = net(x, features=self.neurons, rng=rng)
    x = jax.nn.relu(x)

  # Repeat the state embedding once per sampled quantile:
  # shape (num_quantiles, state_vector_length) for a 1-D x.
  state_vector_length = x.shape[-1]
  state_net_tiled = jnp.tile(x, [num_quantiles, 1])
  quantiles_shape = [num_quantiles, 1]
  # taus ~ Uniform[0, 1), one per row.
  quantiles = jax.random.uniform(rng, shape=quantiles_shape)
  # Cosine embedding: cos(pi * i * tau) for i = 1..quantile_embedding_dim.
  quantile_net = jnp.tile(quantiles, [1, self.quantile_embedding_dim])
  quantile_net = (
      jnp.arange(1, self.quantile_embedding_dim + 1, 1).astype(jnp.float32)
      * onp.pi
      * quantile_net)
  quantile_net = jnp.cos(quantile_net)
  # Project the tau embedding to the state width so the two can be multiplied
  # elementwise. This layer is always nn.Dense, even when self.noisy.
  quantile_net = nn.Dense(features=state_vector_length,
                          kernel_init=self.initzer)(quantile_net)
  quantile_net = jax.nn.relu(quantile_net)
  x = state_net_tiled * quantile_net

  # Both the dueling heads (adv, val) and the non-dueling head are ALWAYS
  # built, so all of their parameters exist regardless of self.dueling.
  # jnp.where then selects which output is returned.
  adv = net(x, features=self.num_actions, rng=rng)
  val = net(x, features=1, rng=rng)
  dueling_q = val + (adv - (jnp.mean(adv, -1, keepdims=True)))
  non_dueling_q = net(x, features=self.num_actions, rng=rng)

  quantile_values = jnp.where(self.dueling, dueling_q, non_dueling_q)

  return atari_lib.ImplicitQuantileNetworkType(quantile_values, quantiles)
def _beam_search(
    self,
    input_ids: None,
    max_length: Optional[int] = None,
    pad_token_id: Optional[int] = None,
    eos_token_id: Optional[int] = None,
    length_penalty: Optional[float] = None,
    early_stopping: Optional[bool] = None,
    logits_processor: Optional[FlaxLogitsProcessorList] = None,
    trace: bool = True,
    params: Optional[Dict[str, jax_xla.DeviceArray]] = None,
    model_kwargs: Optional[Dict[str, jax_xla.DeviceArray]] = None,
):
    """
    Run beam search decoding.

    This beam search function is heavily inspired by Flax's official example:
    https://github.com/google/flax/blob/master/examples/wmt/train.py#L254

    Args:
        input_ids: prompt token ids of shape ``(batch_size, num_beams, cur_len)``.
        max_length, pad_token_id, eos_token_id, length_penalty, early_stopping:
            fall back to the matching ``self.config`` attribute when ``None``.
        logits_processor: called as ``logits_processor(sequences, log_probs, cur_len)``
            on the flattened ``(batch_size * num_beams, ...)`` arrays. ``None`` means
            "no processing".
        trace: when ``False``, the loop runs via ``self._run_loop_in_debug`` instead of
            ``lax.while_loop``.
        params: model parameters forwarded to every model call.
        model_kwargs: extra model inputs. If present, ``"attention_mask"`` and
            ``encoder_outputs["last_hidden_state"]`` get their beam dimension flattened.
            The nested ``encoder_outputs`` object is updated in place.

    Returns:
        ``FlaxBeamSearchOutput`` holding the single best sequence per batch item
        (``(batch_size, max_length)``) and its score (``(batch_size,)``).

    Beams are kept in ascending score order along axis 1 (``top_k`` results are
    flipped), so index ``-1`` is always the best beam.
    """

    def flatten_beam_dim(tensor):
        """Flattens the first two dimensions of a non-scalar array."""
        # ignore scalars (e.g. cache index)
        if tensor.ndim == 0:
            return tensor
        return tensor.reshape((tensor.shape[0] * tensor.shape[1],) + tensor.shape[2:])

    def unflatten_beam_dim(tensor, batch_size, num_beams):
        """Unflattens the first, flat batch*beam dimension of a non-scalar array."""
        # ignore scalars (e.g. cache index)
        if tensor.ndim == 0:
            return tensor
        return tensor.reshape((batch_size, num_beams) + tensor.shape[1:])

    def gather_beams(nested, beam_indices, batch_size, new_num_beams):
        """
        Gathers the beam slices indexed by beam_indices into new beam array.
        """
        batch_indices = jnp.reshape(
            jnp.arange(batch_size * new_num_beams) // new_num_beams, (batch_size, new_num_beams)
        )

        def gather_fn(tensor):
            # ignore scalars (e.g. cache index)
            if tensor.ndim == 0:
                return tensor
            return tensor[batch_indices, beam_indices]

        return jax.tree_util.tree_map(gather_fn, nested)

    # init values
    max_length = max_length if max_length is not None else self.config.max_length
    pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
    eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
    length_penalty = length_penalty if length_penalty is not None else self.config.length_penalty
    early_stopping = early_stopping if early_stopping is not None else self.config.early_stopping
    # Both are annotated Optional but are used unconditionally below.
    if logits_processor is None:
        logits_processor = FlaxLogitsProcessorList()
    if model_kwargs is None:
        model_kwargs = {}

    batch_size, num_beams, cur_len = input_ids.shape
    # Static (Python int) prompt length; needed as a slice size for the first step.
    prompt_length = input_ids.shape[-1]

    eos_token_id = jnp.array(eos_token_id)
    pad_token_id = jnp.array(pad_token_id)
    cur_len = jnp.array(cur_len)

    # per batch,beam-item holding current token in loop.
    sequences = jnp.full((batch_size, num_beams, max_length), pad_token_id, dtype=jnp.int32)
    # JAX arrays are immutable: this returns a new array and leaves `sequences` unchanged.
    running_sequences = lax.dynamic_update_slice(sequences, input_ids, (0, 0, 0))

    # per batch,beam-item state bit indicating if sentence has finished.
    is_sent_finished = jnp.zeros((batch_size, num_beams), dtype=jnp.bool_)

    # per batch,beam-item score, logprobs
    # Only the first beam starts "alive" so identical beams are not expanded num_beams times.
    running_scores = jnp.tile(jnp.array([0.0] + [np.array(-1.0e7)] * (num_beams - 1)), [batch_size, 1])
    scores = jnp.ones((batch_size, num_beams)) * np.array(-1.0e7)

    # For Seq2Seq generation, we only need to use the decoder instead of the whole model in generation loop
    # and pass it the `encoder_outputs`, which are part of the `model_kwargs`.
    model = self.decode if self.config.is_encoder_decoder else self

    # flatten beam dim
    if "encoder_outputs" in model_kwargs:
        model_kwargs["encoder_outputs"]["last_hidden_state"] = flatten_beam_dim(
            model_kwargs["encoder_outputs"]["last_hidden_state"]
        )
    if "attention_mask" in model_kwargs:
        model_kwargs["attention_mask"] = flatten_beam_dim(model_kwargs["attention_mask"])

    # initialize model specific kwargs
    model_kwargs = self.prepare_inputs_for_generation(flatten_beam_dim(input_ids), max_length, **model_kwargs)

    # initialize state
    state = BeamSearchState(
        cur_len=cur_len,
        running_sequences=running_sequences,
        running_scores=running_scores,
        sequences=sequences,
        scores=scores,
        is_sent_finished=is_sent_finished,
        model_kwargs=model_kwargs,
    )

    def beam_search_cond_fn(state):
        """beam search state termination condition fn."""

        # 1. is less than max length?
        not_max_length_yet = state.cur_len < max_length

        # 2. can the new beams still improve?
        best_running_score = state.running_scores[:, -1:] / (max_length ** length_penalty)
        worst_finished_score = jnp.where(
            state.is_sent_finished, jnp.min(state.scores, axis=1, keepdims=True), np.array(-1.0e7)
        )
        improvement_still_possible = jnp.all(worst_finished_score < best_running_score)

        # 3. is there still a beam that has not finished?
        still_open_beam = ~(jnp.all(state.is_sent_finished) & early_stopping)

        return not_max_length_yet & still_open_beam & improvement_still_possible

    def beam_search_body_fn(state, input_ids_length=1):
        """beam search state update fn.

        `input_ids_length` is how many trailing tokens to feed the model. It is 1
        inside the loop and the full prompt length on the very first step.
        """
        # 1. Forward current tokens
        # Collect the current position slice along length to feed the fast
        # autoregressive decoder model. Flatten the beam dimension into batch
        # dimension for feeding into the model.
        input_token = flatten_beam_dim(
            lax.dynamic_slice(
                state.running_sequences,
                (0, 0, state.cur_len - input_ids_length),
                (batch_size, num_beams, input_ids_length),
            )
        )
        model_outputs = model(input_token, params=params, **state.model_kwargs)

        # Unflatten beam dimension; only the last fed position predicts the next token.
        logits = unflatten_beam_dim(model_outputs.logits[:, -1], batch_size, num_beams)
        # Unflatten beam dimension in attention cache arrays
        cache = jax.tree_util.tree_map(
            lambda tensor: unflatten_beam_dim(tensor, batch_size, num_beams), model_outputs.past_key_values
        )

        # adapt logits for FlaxMarianMTModel
        logits = self._adapt_logits_for_beam_search(logits)

        # 2. Compute log probs
        # get log probabilities from logits,
        # process logits with processors (*e.g.* min_length, ...), and
        # add new logprobs to existing running logprobs scores.
        log_probs = jax.nn.log_softmax(logits)
        # Processors must see the sequences generated so far (the loop state),
        # not the initial prompt-only array captured from the outer scope.
        log_probs = logits_processor(
            flatten_beam_dim(state.running_sequences), flatten_beam_dim(log_probs), state.cur_len
        )
        log_probs = unflatten_beam_dim(log_probs, batch_size, num_beams)
        log_probs = log_probs + jnp.expand_dims(state.running_scores, axis=2)
        vocab_size = log_probs.shape[2]
        log_probs = log_probs.reshape((batch_size, num_beams * vocab_size))

        # 3. Retrieve top-K
        # Each item in batch has num_beams * vocab_size candidate sequences.
        # For each item, get the top 2*k candidates with the highest log-
        # probabilities. We gather the top 2*K beams here so that even if the best
        # K sequences reach EOS simultaneously, we have another K sequences
        # remaining to continue the live beam search.
        beams_to_keep = 2 * num_beams
        topk_log_probs, topk_indices = lax.top_k(log_probs, k=beams_to_keep)
        # Recover the beam index by floor division.
        topk_beam_indices = topk_indices // vocab_size
        # Gather 2*k top beams.
        topk_running_sequences = gather_beams(
            state.running_sequences, topk_beam_indices, batch_size, beams_to_keep
        )
        # Recover token id by modulo division and expand Id array for broadcasting.
        topk_ids = jnp.expand_dims(topk_indices % vocab_size, axis=2)
        # Update sequences for the 2*K top-k new sequences.
        topk_sequences = lax.dynamic_update_slice(topk_running_sequences, topk_ids, (0, 0, state.cur_len))

        # 4. Check which sequences have ended
        # To prevent just-finished sequences from staying in the set of active
        # beam search sequences, set their log probs to a very large negative value.
        did_topk_just_finished = topk_sequences[:, :, state.cur_len] == eos_token_id
        topk_log_probs = topk_log_probs + did_topk_just_finished * np.array(-1.0e7)

        # 5. Get running sequences scores for next
        # Determine the top k beam indices (from top 2*k beams) from log probs
        # and gather top k beams (from top 2*k beams). The flip keeps beams in
        # ascending order so index -1 is the best.
        next_topk_indices = jnp.flip(lax.top_k(topk_log_probs, k=num_beams)[1], axis=1)
        next_running_sequences, next_running_scores = gather_beams(
            [topk_sequences, topk_log_probs], next_topk_indices, batch_size, num_beams
        )

        # 6. Process topk logits
        # Further process log probs:
        # - add length penalty
        # - make sure no scores can be added anymore if beam is full
        # - make sure still running sequences cannot be chosen as finalized beam
        topk_log_probs = topk_log_probs / (state.cur_len ** length_penalty)
        beams_in_batch_are_full = (
            jnp.broadcast_to(state.is_sent_finished.all(axis=-1, keepdims=True), did_topk_just_finished.shape)
            & early_stopping
        )
        add_penalty = ~did_topk_just_finished | beams_in_batch_are_full
        topk_log_probs += add_penalty * np.array(-1.0e7)

        # 7. Get scores, sequences, is sentence finished for next.
        # Combine sequences, scores, and flags along the beam dimension and compare
        # new finished sequence scores to existing finished scores and select the
        # best from the new set of beams
        merged_sequences = jnp.concatenate([state.sequences, topk_sequences], axis=1)
        merged_scores = jnp.concatenate([state.scores, topk_log_probs], axis=1)
        merged_is_sent_finished = jnp.concatenate([state.is_sent_finished, did_topk_just_finished], axis=1)
        topk_merged_indices = jnp.flip(lax.top_k(merged_scores, k=num_beams)[1], axis=1)
        next_sequences, next_scores, next_is_sent_finished = gather_beams(
            [merged_sequences, merged_scores, merged_is_sent_finished], topk_merged_indices, batch_size, num_beams
        )

        # 8. Update model kwargs.
        # Determine the top k beam indices from the original set of all beams.
        # With these, gather the top k beam-associated caches.
        next_running_indices = gather_beams(topk_beam_indices, next_topk_indices, batch_size, num_beams)
        next_cache = gather_beams(cache, next_running_indices, batch_size, num_beams)
        model_outputs["past_key_values"] = jax.tree_util.tree_map(flatten_beam_dim, next_cache)
        next_model_kwargs = self.update_inputs_for_generation(model_outputs, state.model_kwargs)

        return BeamSearchState(
            cur_len=state.cur_len + 1,
            running_scores=next_running_scores,
            running_sequences=next_running_sequences,
            scores=next_scores,
            sequences=next_sequences,
            is_sent_finished=next_is_sent_finished,
            model_kwargs=next_model_kwargs,
        )

    # The very first prompt often has sequence length > 1, so run outside of `lax.while_loop` to comply with TPU.
    # Feed the whole prompt here so the model's cache covers every prompt token.
    state = beam_search_body_fn(state, input_ids_length=prompt_length)

    if not trace:
        state = self._run_loop_in_debug(beam_search_cond_fn, beam_search_body_fn, state)
    else:
        state = lax.while_loop(beam_search_cond_fn, beam_search_body_fn, state)

    # Account for the edge-case where there are no finished sequences for a
    # particular batch item. If so, return running sequences for that batch item.
    any_finished = jnp.any(state.is_sent_finished, axis=1)
    sequences = jnp.where(any_finished[:, None, None], state.sequences, state.running_sequences)
    scores = jnp.where(any_finished[:, None], state.scores, state.running_scores)

    # take best beam for each batch (beams are in ascending order, so -1 is best)
    sequences = sequences[:, -1]
    scores = scores[:, -1]

    return FlaxBeamSearchOutput(sequences=sequences, scores=scores)
def init_fn(position: Array) -> PHopState:
  """Build the initial PHopState from a single `position` array.

  The position buffer stacks `window_size` copies of `position` along a new
  leading axis. `window_size` is read from the enclosing scope. The shape
  assertion only holds when `position` is 2-D (presumably (N, dim); confirm
  with the caller). The second PHopState field starts as a zero vector of
  length `position.shape[0]`.
  """
  stacked_shape = (window_size,) + position.shape
  position_buffer = jnp.tile(position, (window_size, 1, 1))
  assert position_buffer.shape == stacked_shape
  zeros_per_row = jnp.zeros((position.shape[0],))
  return PHopState(position_buffer, zeros_per_row)  # pytype: disable=wrong-arg-count