def _Sources(self):
  """Generates a random source batch together with its padding mask.

  Reads `cur_iter_in_seed`, `integer_source_max`, `float_source_max`,
  `source_shape` and `random_seed` from `self.params`. When
  `cur_iter_in_seed` is set, `self._cur_iter` is advanced first so that the
  seeds (and therefore the generated values) differ between calls.

  Returns:
    A tuple `(inputs, paddings)`:
    - inputs: random tensor of shape `source_shape`; int32 uniform in
      `[0, integer_source_max)` if that param is set, else float uniform in
      `[0, float_source_max)` if that param is set, else standard normal.
    - paddings: float32 tensor of shape `source_shape[:2]`, passed through
      `self._check_paddings`.
  """
  params = self.params
  if params.cur_iter_in_seed:
    self._cur_iter += 1

  shape = params.source_shape
  input_seed = params.random_seed + 1000 * self._cur_iter
  if params.integer_source_max:
    inputs = tf.random_uniform(
        shape,
        maxval=params.integer_source_max,
        dtype=tf.int32,
        seed=input_seed)
  elif params.float_source_max:
    inputs = tf.random_uniform(
        shape, maxval=params.float_source_max, seed=input_seed)
  else:
    inputs = tf.random_normal(shape, seed=input_seed)

  # A position is padded once the running sum of uniform noise along axis 1
  # exceeds half of the axis-1 length, so every row is a run of 0s followed
  # by a run of 1s.
  padding_seed = params.random_seed + 1001 * self._cur_iter
  noise = tf.random_uniform(shape[:2], seed=padding_seed)
  is_padded = tf.cumsum(noise, axis=1) > 0.5 * shape[1]
  paddings = self._check_paddings(tf.cast(is_padded, tf.float32))
  return inputs, paddings
def dec_callback(self, tgt_id, tgt_pos, tgt_segment_id, tgt_mask, dec_state,
                 t):
  """Synthetic decoder step: records `tgt_id` and scores the next token.

  The "model" is a deterministic rule selected by `self.rule`:
    - '+1':  next token is `tgt_id + 1`.
    - 'sum': next token is the sum of all previous tokens selected by
      `tgt_mask`.
    - 'fib': next token is `tgt_id` plus the token at the last position
      selected by `tgt_mask`.
  All results are taken modulo `self.vocab_size`.

  Args:
    tgt_id: int tensor of current token ids. Either of shape
      `(batch_size, beam_size)` or wider along axis 1 by a multiple of
      `beam_size`, in which case it is split into several consecutive steps.
    tgt_pos: Unused.
    tgt_segment_id: Unused.
    tgt_mask: mask over previous positions, contracted as 'BKT' against the
      flattened token buffer.
    dec_state: single-element list `[buf]` holding the token buffer.
      Presumably shaped `[max_steps, batch_size, beam_size]` (it is
      transposed with [1, 0, 2] and reshaped to
      `[batch_size, max_steps * beam_size]`) -- TODO confirm with caller.
    t: step index at which `tgt_id` is written into the buffer.

  Returns:
    A tuple `(logits, dec_state)` where `logits` are log-softmax scores with
    a trailing `vocab_size` dimension and `dec_state` is `[buf]` with the
    updated buffer.

  Raises:
    ValueError: If `self.rule` is not one of '+1', 'sum' or 'fib'.
  """
  del tgt_pos, tgt_segment_id

  # Fail fast on an unsupported rule; otherwise `next_tgt_id` would be left
  # unbound below and surface as an UnboundLocalError after graph ops have
  # already been created.
  if self.rule not in ('+1', 'sum', 'fib'):
    raise ValueError('Unknown rule: %r' % (self.rule,))

  [buf] = dec_state
  if tgt_id.shape == (self.batch_size, self.beam_size):
    buf = inplace_ops.alias_inplace_update(buf, t, tgt_id)
  else:
    # Several steps were passed at once: write each beam_size-wide slice to
    # consecutive time steps.
    div = int(tgt_id.shape[1] // self.beam_size)
    for i, x_i in enumerate(tf.split(tgt_id, div, 1)):
      buf = inplace_ops.alias_inplace_update(buf, t + i, x_i)

  buf1 = tf.transpose(buf, [1, 0, 2])
  buf1 = tf.reshape(buf1, [self.batch_size, self.max_steps * self.beam_size])

  # select next_tgt_id as a function of previous target tokens
  if self.rule == '+1':
    next_tgt_id = (tgt_id + 1)
    next_tgt_id %= self.vocab_size
  elif self.rule == 'sum':
    # sum over all previous tokens in tgt_mask
    next_tgt_id = tf.einsum('BT,BKT->BK', buf1, tf.cast(tgt_mask, tf.int32))
    next_tgt_id %= self.vocab_size
  else:  # self.rule == 'fib'
    # select last token according to tgt_mask
    m = tgt_mask
    m *= tf.cast(
        tf.equal(tf.cumsum(m, -1),
                 tf.reduce_sum(m, -1, keepdims=True) - 1), m.dtype)
    last_tgt_id = tf.einsum('BT,BKT->BK', buf1, tf.cast(m, tf.int32))
    next_tgt_id = (last_tgt_id + tgt_id) % self.vocab_size

  # The correct next_tgt_id gets the highest score; next_tgt_id + 1 .. + 4
  # get progressively lower scores so that beams have alternatives to explore.
  n = self.vocab_size
  logits = 5 * tf.one_hot(next_tgt_id % n, n)
  logits += 4 * tf.one_hot((next_tgt_id + 1) % n, n)
  logits += 3 * tf.one_hot((next_tgt_id + 2) % n, n)
  logits += 2 * tf.one_hot((next_tgt_id + 3) % n, n)
  logits += 1 * tf.one_hot((next_tgt_id + 4) % n, n)

  # increase eos_score if current tgt_id contains 9 (in its ones or tens
  # decimal digit)
  eos_id = 0
  tgt_id_contains_9 = tf.logical_or(
      tf.equal(tgt_id % 10, 9), tf.equal((tgt_id // 10) % 10, 9))
  logits += 9 * tf.einsum('V,BK->BKV', tf.one_hot(eos_id, self.vocab_size),
                          tf.cast(tgt_id_contains_9, tf.float32))

  # tie-breaking -- lower token id wins a little bit
  tie = np.arange(0., 1., 1. / n)
  tie /= tie.sum()
  logits -= tie
  logits = tf.nn.log_softmax(logits)

  dec_state = [buf]
  return logits, dec_state
def unmask(h, m):
  """Packs the entries of `h` selected by mask `m` to the front of a new axis.

  Relies on `einsum_i32`, `output_len` and `fprop_dtype` from the enclosing
  scope.

  Args:
    h: tensor contracted as 'bt'.
    m: mask tensor contracted as 'bkt'; assumed to hold 0/1 values.

  Returns:
    A 'bkT' tensor with T = `output_len`, cast to `h.dtype`, where the i-th
    selected entry along `t` is placed at index i along `T`.
  """
  with tf.name_scope('unmask'):
    tpu_summary.tensor('unmask_h', h)
    tpu_summary.tensor('unmask_m', m)
    # 0-based rank of every selected position along the last axis; positions
    # with m == 0 become -1, which tf.one_hot maps to an all-zero row.
    slot = tf.cumsum(m, -1) * m - 1
    masked_h = einsum_i32('bkt,bt->bkt', m, h)
    slot_one_hot = tf.one_hot(
        tf.cast(slot, tf.int32), output_len, dtype=fprop_dtype)
    packed = einsum_i32('bkt,bktT->bkT', masked_h, slot_one_hot)
    return tf.cast(packed, h.dtype)
def _Targets(self, target_shape):
  """Generates a random (or fixed) target batch.

  Reads `cur_iter_in_seed`, `random_seed`, `fixed_target_ids`,
  `fixed_target_labels`, `tokenizer.vocab_size`, `for_mt`,
  `target_transcript`, `align_label_with_frame` and `source_shape` from
  `self.params`. When `cur_iter_in_seed` is set, `self._cur_iter` is advanced
  first.

  Args:
    target_shape: shape (sequence of ints) of the ids/labels tensors.

  Returns:
    A dict with keys:
    - 'ids': int32 target ids, random in `[0, vocab_size)` unless
      `fixed_target_ids` is given.
    - 'labels': int32 target labels, random in `[0, vocab_size)` unless
      `fixed_target_labels` is given.
    - 'paddings': float32 paddings; random (via `self._check_paddings`) when
      labels are random, all zeros when labels are fixed.
    - 'weights': `1.0 - paddings`.
    - 'transcripts': only when `for_mt` is false.
    - 'alignments': only when `align_label_with_frame` is true; random int32
      values in `[0, source_shape[1])`.
  """
  p = self.params
  if p.cur_iter_in_seed:
    self._cur_iter += 1
  # NOTE(review): this seed is a product, so it is 0 whenever _cur_iter is 0;
  # the source seeds use `random_seed + k * _cur_iter` instead. Left
  # unchanged to keep generated values stable -- confirm it is intended.
  random_seed = p.random_seed * 2000 * self._cur_iter

  if p.fixed_target_ids is None:
    tids = tf.cast(
        tf.random_uniform(target_shape, seed=random_seed) *
        p.tokenizer.vocab_size, tf.int32)
  else:
    tids = p.fixed_target_ids
    # tf.Tensor has no `shape_as_list()`; the static shape is available as
    # `.shape.as_list()`. `list(...)` lets target_shape be a tuple or a list.
    assert tids.shape.as_list() == list(target_shape)

  if p.fixed_target_labels is None:
    tlabels = tf.cast(
        tf.random_uniform(target_shape, seed=random_seed + 1) *
        p.tokenizer.vocab_size, tf.int32)
    # A position is padded once the running sum of uniform noise along axis 1
    # exceeds 0.4 of the axis-1 length.
    tpaddings = tf.cast(
        tf.cumsum(
            tf.random_uniform(
                target_shape[:2], seed=p.random_seed + 1001 * self._cur_iter),
            axis=1) > 0.4 * target_shape[1], tf.float32)
    tpaddings = self._check_paddings(tpaddings)
  else:
    tlabels = p.fixed_target_labels
    assert tlabels.shape.as_list() == list(target_shape)
    tpaddings = tf.constant(0.0, shape=target_shape)
  tweights = 1.0 - tpaddings

  d = {
      'ids': tids,
      'labels': tlabels,
      'weights': tweights,
      'paddings': tpaddings
  }
  if not p.for_mt:
    d['transcripts'] = tf.constant(
        p.target_transcript, shape=[target_shape[0]])
  if p.align_label_with_frame:
    source_len = p.source_shape[1]
    d['alignments'] = tf.cast(
        tf.random_uniform(target_shape, seed=p.random_seed) * source_len,
        tf.int32)
  return d
def _ComputePaddings(ids, eos_id):
  """Marks every position strictly after the first `eos_id` as padding.

  Args:
    ids: integer tensor of token ids; the sequence axis is the last one.
    eos_id: the end-of-sequence token id.

  Returns:
    A tensor with the same shape and dtype as `ids`: 0 for positions up to and
    including the first `eos_id` along the last axis, 1 for every position
    after it.
  """
  # tf.cast replaces the deprecated tf.to_int32; the result is identical.
  is_eos = tf.cast(tf.equal(ids, eos_id), tf.int32)
  # eos_in_prefix[i, j] = number of k in range(j) with ids[i, k] == eos_id,
  # i.e. it is non-zero iff an <eos> occurs strictly before position j.
  eos_in_prefix = tf.cumsum(is_eos, axis=-1, exclusive=True)
  return tf.where(
      tf.equal(eos_in_prefix, 0), tf.zeros_like(ids), tf.ones_like(ids))
def FProp(self,
          theta,
          x,
          x_paddings=None,
          eos_id=1,
          force_sample_last_token=True):
  """Applies SymbolInsertionLayer.

  We take in a `x`, which represents the groundtruth sequence (i.e., English
  sequence). We return a sampled rollin (observed) canvas (i.e., random subset
  of the English sequence), as well as the target (indices) for an
  insertion-based model (i.e., the targets given the random observed subset).

  Args:
    theta: Ignored, this can be None.
    x: The symbol ids of shape `[batch_size, time_dim]`.
    x_paddings: The paddings (1 or 0) of shape `[batch_size, time_dim]` where
      0 is valid and 1 is invalid. If None, every position is treated as
      valid.
    eos_id: The <eos> token id to represent end-of-slot.
    force_sample_last_token: Set True to force sample the last token of `x`.

  Returns:
    A `NestedMap`.
      - canvas: The canvas (based off of the `rollin_policy`) of shape
        [batch_size, c_dim]. Note that, `c_dim` <= `time_dim` but need not be
        equal.
      - canvas_indices: The canvas indices (into `x`).
      - canvas_paddings: The paddings of `canvas_indices`.
      - target_indices: The target indices of shape [num_targets, 3].
        `num_targets` is the number of total targets in the entire batch.
        [:, 0] captures the batch, [:, 1] captures the slot, and [:, 2]
        captures the token. Each row [batch, slot, vocab] represents the
        indices of the target -- i.e., the batch, slot and vocab combination
        of the target. Typical usage of these indices is to tf.gather_nd
        the log-probs (from the softmax layer).
      - target_weights: The target weights.

  Raises:
    ValueError: If the effective rollin policy or the oracle policy is
      anything other than 'uniform' (a rollin policy of 'oracle' resolves to
      the oracle policy).
  """
  p = self.params

  batch_size = py_utils.GetShape(x)[0]
  time_dim = py_utils.GetShape(x)[1]

  if x_paddings is None:
    x_paddings = tf.zeros([batch_size, time_dim], tf.float32)

  oracle_policy = p.oracle_policy
  rollin_policy = (
      oracle_policy if p.rollin_policy == 'oracle' else p.rollin_policy)

  if rollin_policy != 'uniform':
    raise ValueError('Unknown or unsupported rollin policy: %s' %
                     rollin_policy)
  if oracle_policy != 'uniform':
    raise ValueError('Unknown or unsupported oracle policy: %s' %
                     oracle_policy)

  # Number of valid (non-padded) tokens per example.
  x_len = tf.cast(tf.round(tf.reduce_sum(1 - x_paddings, 1)), tf.int32)

  # Compute the desired length per example in the batch.
  ratio = tf.random.uniform([batch_size], 0.0, 1.0, seed=p.random_seed)
  if force_sample_last_token:
    # At least one token (the forced last one), at most x_len.
    c_len = tf.minimum(
        tf.cast(ratio * tf.cast(x_len, tf.float32), tf.int32), x_len - 1) + 1
  else:
    # Anywhere between 0 and x_len tokens.
    c_len = tf.minimum(
        tf.cast(ratio * tf.cast(x_len + 1, tf.float32), tf.int32), x_len)
  # Compute the maximum length across the batch.
  c_len_max = tf.reduce_max(c_len)

  # Grab subset of random valid indices per example.
  # Positions at or beyond x_len get a -1e9 logit so they sort last.
  z_logits = tf.cast(
      tf.expand_dims(tf.range(time_dim), 0) >= tf.expand_dims(x_len, 1),
      tf.float32) * -1e9
  if force_sample_last_token:
    # Force sample the last token -- i.e., as indexed by `x_len - 1`. We can
    # accomplish this by add +LARGE_NUMBER to the logits.
    z_logits += tf.cast(
        tf.equal(
            tf.expand_dims(tf.range(time_dim), 0),
            tf.expand_dims(x_len - 1, 1)), tf.float32) * 1e9

  # Gumbel-max trick to sample (we only sample valid positions per sample in
  # the batch).
  z = -tf.math.log(-tf.math.log(
      tf.random.uniform([batch_size, time_dim], seed=p.random_seed)))
  unused_c_values, c_indices = tf.nn.top_k(z_logits + z, time_dim)

  # Trim everything > c_len_max.
  c_indices = c_indices[:, :c_len_max]

  # Invalidate any indices >= c_len, we use the last index as the default
  # invalid index.
  # NOTE(review): every invalidated slot points at `time_dim - 1`, so the
  # scatter_nd below also marks that position as observed for rows with
  # c_len < c_len_max -- confirm this is intended when
  # force_sample_last_token is False and the row has no padding.
  c_indices = tf.where(
      tf.expand_dims(tf.range(c_len_max), 0) < tf.expand_dims(c_len, 1),
      c_indices, tf.fill(py_utils.GetShape(c_indices), time_dim - 1))

  # Materialize the canvas.
  # Sorting restores the original left-to-right token order (and pushes the
  # `time_dim - 1` filler indices to the end of each row).
  c_indices = tf.sort(c_indices)
  c = tf.gather_nd(
      x,
      tf.stack([
          tf.reshape(
              tf.tile(tf.expand_dims(tf.range(batch_size), 1), [1, c_len_max]),
              [-1]),
          tf.reshape(c_indices, [-1])
      ], 1))
  c = tf.reshape(c, [batch_size, c_len_max])

  # Compute the paddings.
  c_paddings = 1 - tf.sequence_mask(
      c_len, c_len_max, dtype=x_paddings.dtype)
  # Zero out the canvas entries that are padding. The int32 cast means `c`
  # (and hence `x`) is expected to be int32 here.
  c *= tf.cast(1 - c_paddings, tf.int32)

  # (batch, position) pairs of every canvas entry, used to scatter an
  # "is observed" indicator back onto the shape of `x`.
  indices = tf.concat([
      tf.reshape(
          tf.tile(tf.expand_dims(tf.range(batch_size), 1), [1, c_len_max]),
          [batch_size * c_len_max, 1]),
      tf.reshape(c_indices, [batch_size * c_len_max, 1])
  ], 1)
  x_token_is_observed = tf.scatter_nd(
      indices, tf.ones([batch_size * c_len_max], tf.int32),
      py_utils.GetShape(x))
  # `x_segments` captures which slot each `x` belongs to (both observed and
  # tokens that need to be observed).
  x_segments = tf.cumsum(x_token_is_observed, 1, exclusive=True)

  x_token_is_observed = tf.cast(x_token_is_observed, tf.bool)
  # Shift right by one along time; the (non-existent) token before position 0
  # counts as observed.
  prev_x_token_is_observed = tf.pad(
      x_token_is_observed[:, :-1], [[0, 0], [1, 0]], constant_values=True)
  x_token_is_observed = tf.reshape(x_token_is_observed, [-1])
  prev_x_token_is_observed = tf.reshape(prev_x_token_is_observed, [-1])
  x_is_valid = tf.cast(1 - x_paddings, tf.bool)
  x_is_valid = tf.reshape(x_is_valid, [-1])

  # Remap all the observed to <eos>, note some of these need a zero weight
  # (or else there would be <eos> and valid token in the same slot).
  target_indices = tf.cast(tf.reshape(x, [-1, 1]), tf.int32)
  target_indices = tf.where(
      x_token_is_observed,
      tf.fill(py_utils.GetShape(target_indices), eos_id), target_indices)

  # TODO(williamchan): We give uniform 1.0 weight, however, math suggests
  # we may want to weigh this term by the original sequence length.
  target_weights = tf.ones_like(target_indices, tf.float32)

  # We need to set all the weights for <eos> which actually have valid tokens
  # in the slot to zero.
  # An observed token whose predecessor is unobserved closes a slot that
  # still has tokens to insert, so its <eos> target must not count.
  target_weights = tf.where(
      x_token_is_observed & ~prev_x_token_is_observed,
      tf.zeros_like(target_weights), target_weights)

  # TODO(williamchan): Consider dropping the entries w/ weight zero.

  # Add the batch and slot indices.
  target_indices = tf.concat([
      tf.reshape(
          tf.tile(tf.expand_dims(tf.range(batch_size), 1), [1, time_dim]),
          [batch_size * time_dim, 1]),
      tf.reshape(x_segments, [-1, 1]), target_indices
  ], 1)

  # Select only the valid indices. The selected valid ones include slots w/
  # <eos>.
  target_indices = target_indices[x_is_valid]
  target_weights = target_weights[x_is_valid]

  return py_utils.NestedMap(
      canvas=c,
      canvas_indices=c_indices,
      canvas_paddings=c_paddings,
      target_indices=target_indices,
      target_weights=target_weights)
def Top2GatingOnLogits(inputs,
                       paddings,
                       logits,
                       num_devices,
                       experts_dim,
                       expert_capacity_dim,
                       fprop_dtype,
                       use_xla_sharding=True,
                       second_expert_policy='all',
                       second_expert_threshold=0.0,
                       legacy_mtf_behavior=True,
                       capacity_factor=None):
  """Computes Top-2 gating for Mixture-of-Experts.

  There are two expected usages of this function:

  1. used with xla_sharding. In this case, 'inputs' corresponds to a sharded
     tensor across multiple tpu cores. The operations within this function are
     automatically sharded/replicated across tpu cores.
  2. used within ML-Pathways. In this case, 'inputs' is always local to one tpu
     core. All computations below are carried out on one tpu core only. This
     function tries to dispatch examples across tpu cores in such a way that
     each expert is assigned no more than 'expert_capacity_dim' number of
     examples.

  Below ` indicates common way of splitting along mesh dimension.

  Dimensions cheat sheet:

    G: group_dim
    S: group_size_dim
    E: number of experts
    C: capacity per expert
    M: model_dim (same as input_dim, same as output_dim)
    B: original batch_dim
    L: original sequence_length_dim

  Note that for local_dispatch original batch BLM is reshaped into GSM, each
  group `g = 0...G-1` is being dispatched independently.

  Args:
    inputs: G`SM Tensor. Currently unused.
    paddings: G`S Tensor, or None if no position is padded.
    logits: G`SE Tensor.
    num_devices: number of MoE devices for local dispatch
    experts_dim: number of experts.
    expert_capacity_dim: number of examples per minibatch(group) per expert.
      Each example is typically a vector of size input_dim, representing
      embedded token or an element of Transformer layer output.
    fprop_dtype: activations datatype to use.
    use_xla_sharding: bool, True if this function is used for the xla_sharding
      case.
    second_expert_policy: 'all', 'sampling' or 'random'.

      - 'all': we greedily pick the 2nd expert.
      - 'sampling': we sample the 2nd expert from the softmax.
      - 'random': we optionally 'random'-ize dispatch to second-best expert
        proportional to (weight / second_expert_threshold).

    second_expert_threshold: threshold for probability normalization for
      second_expert_policy == 'random'. Values below 1e-9 are treated as 1e-9.
    legacy_mtf_behavior: bool, True if to match legacy mtf behavior exactly.
    capacity_factor: if set, increases expert_capacity_dim to at least
      (group_size * capacity_factor) / experts_dim
      where `group_size` is the size of the S dimension of `logits`. When the
      capacity is increased it is also rounded up to a multiple of 4. If the
      value of expert_capacity_dim is already big enough no change is made.

  TODO(lepikhin): get rid of the legacy_mtf_behavior flag.

  Returns:
    A tuple (aux_loss, combine_tensor, dispatch_tensor).

    - aux_loss: auxiliary loss, for equalizing the expert assignment ratios.
    - combine_tensor: G`SEC Tensor for combining expert outputs.
    - dispatch_tensor: G`SEC Tensor, scattering/dispatching inputs to experts.

  Raises:
    ValueError: If `second_expert_policy` is not one of 'all', 'sampling' or
      'random'.
  """
  # Validate up front so that an unsupported policy fails before any graph ops
  # or summaries are created.
  if second_expert_policy not in ('all', 'sampling', 'random'):
    raise ValueError(second_expert_policy)

  del inputs  # inputs is currently not used.
  raw_gates = tf.nn.softmax(logits)  # along E dim

  if capacity_factor is not None:
    # Determine expert capacity automatically depending on the input size.
    group_size_dim = int(logits.shape[1])
    auto_expert_capacity = int((group_size_dim * capacity_factor) / experts_dim)
    if expert_capacity_dim < auto_expert_capacity:
      expert_capacity_dim = auto_expert_capacity
      # Round up to a multiple of 4 to avoid possible padding.
      while expert_capacity_dim % 4:
        expert_capacity_dim += 1
      tf.logging.info(
          'Setting expert_capacity_dim=%r (capacity_factor=%r '
          'group_size_dim=%r experts_dim=%r name_scope=%r)',
          expert_capacity_dim, capacity_factor, group_size_dim, experts_dim,
          tf.get_default_graph().get_name_scope())
    tpu_summary.scalar('expert_capacity', expert_capacity_dim)

  # top first and second gate value and expert index for each input
  #
  # GSK Tensors, K=2
  def _MaybeSplit(x):
    """Shards `x` along dim 0 across devices when xla sharding is enabled."""
    if use_xla_sharding:
      return Split(x, 0, num_devices)
    else:
      return x

  def _CreateOverCapacityRatioSummary(mask, position_in_expert, capacity, name):
    """Emits the fraction of assignments in `mask` that exceed `capacity`."""
    over_capacity = tf.reduce_sum(
        tf.cast(
            tf.greater_equal(mask * position_in_expert, capacity), mask.dtype))
    # Clamp the denominator to at least 1 so that an empty mask (e.g. all
    # positions padded, or no second expert sampled) reports a ratio of 0
    # instead of NaN from 0 / 0.
    num_assigned = tf.maximum(
        tf.reduce_sum(mask), tf.constant(1.0, dtype=mask.dtype))
    over_capacity_ratio = over_capacity / num_assigned
    py_utils.AddTpuSummaryTensor(name, over_capacity_ratio)
    tpu_summary.scalar(name, over_capacity_ratio, while_loop_reduce='mean')

  # As pointed out by zhifengc@ this method needs to be refactored. lepikhin@
  # and krikun@ will:
  #   - expand moe_spmd_test to compare Adafactor updates, slots on TPU
  #   including 2x2 with sharding
  #
  #   - add more tests for policy="random"
  #
  #   - add single step test for full size WMT model on CPU
  #
  # and then break this function into modules.
  #
  # GS
  index_1 = tf.math.argmax(raw_gates, axis=-1, output_type=tf.int32)
  index_1 = _MaybeSplit(index_1)
  tpu_summary.tensor('index_1', index_1)

  # GSE
  mask_1 = tf.one_hot(index_1, experts_dim, dtype=fprop_dtype)
  mask_1 = _MaybeSplit(mask_1)
  density_1_proxy = raw_gates

  # GS Tensor: 1 for valid positions, 0 for padded ones.
  importance = tf.ones_like(mask_1[:, :, 0])
  if paddings is not None:
    importance = 1.0 - paddings
    mask_1 *= tf.expand_dims(importance, -1)
    density_1_proxy *= tf.expand_dims(importance, -1)
  gate_1 = tf.einsum('GSE,GSE->GS', raw_gates, mask_1)
  gates_without_top_1 = raw_gates * (1.0 - mask_1)

  if second_expert_policy == 'sampling':
    # We directly sample the 2nd expert index from the softmax over of the 2nd
    # expert by getting rid of the 1st expert already selected above. To do so,
    # we set a very negative value to the logit corresponding to the 1st expert.
    # Then we sample from the softmax (categorical) distribution using the
    # Gumbel max trick.
    noise = _MaybeSplit(tf.random.uniform(logits.shape, dtype=logits.dtype))
    # Generates standard Gumbel(0, 1) noise, GSE Tensors
    noise = -tf.math.log(-tf.math.log(noise))
    very_negative_logits = _MaybeSplit(
        (tf.ones_like(logits) * logits.dtype.max *
         tf.constant(-0.7, dtype=logits.dtype)))
    # Gets rid of the first expert by setting its logit to be very negative
    updated_logits = _MaybeSplit(
        tf.where(mask_1 > 0.0, very_negative_logits, logits))
    # Adds the Gumbel noise to the updated logits
    noised_logits = _MaybeSplit(updated_logits + noise)
    # Picks the index of the largest noised logit as the 2nd expert. This is
    # equivalent to sampling from the softmax over the 2nd experts.
    index_2 = tf.math.argmax(noised_logits, axis=-1, output_type=tf.int32)
  else:
    index_2 = tf.math.argmax(gates_without_top_1, axis=-1, output_type=tf.int32)

  index_2 = _MaybeSplit(index_2)
  mask_2 = tf.one_hot(index_2, experts_dim, dtype=fprop_dtype)
  mask_2 = _MaybeSplit(mask_2)
  if paddings is not None:
    mask_2 *= tf.expand_dims(importance, -1)
  gate_2 = tf.einsum('GSE,GSE->GS', gates_without_top_1, mask_2)

  if legacy_mtf_behavior:
    # cl/298510175 moved this branch for gate_{1,2} denom calculation here.
    #
    # For policy=random, it's better to normalize gate_{1,2} before taking
    # capacity into account and before potentially dropping second expert.
    #
    # According to mean_xent (http://short/_NzbZ5rINr5):
    #   MoE_512_102xen_PolicyAll_298510175
    #   MoE_512_102xen_PolicyRandom_298510175
    #
    # vs pre-cl/298510175
    #   MoE_512_102xen_PolicyRandom
    #   MoE_512_102xen_PolicyAll
    #
    # it substantially improves policy=random with threshold=0.5 which
    # historically was better than policy="all"
    #
    # Also confirmed this by decoding
    #   nmt_train/m4/data/es_en/test.txt
    #   nmt_train/m4/data/ru_en/test.txt
    #   nmt_train/m4/data/zh_en/test.txt
    # and improving BLEU
    #
    # moe_decode.MoE_512_102xen_PolicyRandom_298510175-160000.batch1024.beam4.c_dim4.ln0.8.rkv.mteval102
    #   0.421443
    #   0.327102
    #   0.315693
    # vs
    # moe_decode.feb18_non_fig_snapshot_2626_MoE_512_102xen_PolicyRandom-190000.batch1024.beam4.c_dim4.ln0.8.rkv.mteval102
    #   0.399232
    #   0.310606
    #   0.288229
    #
    # Additional comparison, see mean_xent http://short/_YHccOhQtdu with
    # legacy_mtf_behavior=False models
    #   3 - MoE_512_102xen_PolicyAll_LegacyFalse
    #   6 - MoE_512_102xen_PolicyRandom_LegacyFalse
    # shows that policy="random" gets worse with legacy_mtf_behavior=False, and
    # is similar to pre-cl/298510175
    #   4 - MoE_512_102xen_PolicyRandom
    #
    # gate_1 can become 0 due to Expert being out of capacity.
    #
    # gate_2 can become 0 due to
    #   second_expert_policy == 'random'
    # or "out of capacity" scenario.
    #
    # Here we renormalize regardless of cases above.
    denom = gate_1 + gate_2 + 1e-9
    gate_1 /= denom
    gate_2 /= denom

  # We reshape the mask as [X*S, E], and compute cumulative sums of
  # assignment indicators for each expert index e \in 0..E-1 independently.
  # First occurrence of assignment indicator is excluded, see exclusive=True
  # flag below.
  # GSE Tensor (reduced to GS further below).
  position_in_expert_1 = tf.cumsum(mask_1, exclusive=True, axis=1)

  # Scalar capacity in the dtype of the positions, for the comparisons below.
  capacity = tf.cast(expert_capacity_dim, dtype=position_in_expert_1.dtype)

  # GE Tensor (reducing S out of GSE tensor mask_1)
  # density_1[:, e] represents assignment ratio (num assigned / total) to
  # expert e as top_1 expert without taking capacity into account.
  if legacy_mtf_behavior:
    density_denom = 1.0
  else:
    density_denom = tf.reduce_mean(
        importance, axis=(1))[:, tf.newaxis] + 1e-6
  density_1 = tf.reduce_mean(mask_1, axis=(1)) / density_denom
  # density_1_proxy[:, e] represents mean of raw_gates for expert e, including
  # those of examples not assigned to e with top_k.
  density_1_proxy = tf.reduce_mean(density_1_proxy, axis=1) / density_denom

  # The MoE paper (https://arxiv.org/pdf/1701.06538.pdf) uses an aux loss of
  # reduce_mean(density_1_proxy * density_1_proxy). Here we replace one of
  # the density_1_proxy with the discrete density_1 following
  # mesh_tensorflow/transformer/moe.py?rcl=283569345.
  aux_loss = tf.reduce_mean(density_1_proxy * density_1)  # element-wise
  aux_loss *= experts_dim * experts_dim  # const coefficient

  # Add the over capacity ratio for expert 1
  _CreateOverCapacityRatioSummary(mask_1, position_in_expert_1, capacity,
                                  'over_capacity_1_ratio')

  # Drop assignments that do not fit into the expert's capacity.
  mask_1 *= tf.cast(tf.less(position_in_expert_1, capacity), dtype=mask_1.dtype)
  position_in_expert_1 = tf.einsum('GSE,GSE->GS', position_in_expert_1, mask_1)

  # How many examples in this sequence go to this expert
  mask_1_count = tf.einsum('GSE->GE', mask_1)
  # [batch, group] - mostly ones, but zeros where something didn't fit
  mask_1_flat = tf.einsum('GSE->GS', mask_1)

  # Policies 'all' and 'sampling' always dispatch to the second expert;
  # the policy value itself was validated at the top of the function.
  if second_expert_policy == 'random':
    # gate_2 is between 0 and 1, reminder:
    #
    #   raw_gates = tf.nn.softmax(logits)
    #   index_1 = tf.math.argmax(raw_gates, axis=-1, output_type=tf.int32)
    #   mask_1 = tf.one_hot(index_1, experts_dim, dtype=fprop_dtype)
    #   gate_1 = tf.einsum('GSE,GSE->GS', raw_gates, mask_1)
    #
    # E.g. if gate_2 exceeds second_expert_threshold, then we definitely
    # dispatch to second-best expert. Otherwise we dispatch with probability
    # proportional to (gate_2 / threshold).
    #
    sampled_2 = tf.less(
        _MaybeSplit(tf.random.uniform(gate_2.shape, dtype=gate_2.dtype)),
        (gate_2 / max(second_expert_threshold, 1e-9)))
    gate_2 *= tf.cast(sampled_2, gate_2.dtype)
    mask_2 *= tf.cast(tf.expand_dims(sampled_2, -1), mask_2.dtype)

  # Second-choice assignments queue up behind all first-choice assignments of
  # the same expert.
  position_in_expert_2 = tf.cumsum(
      mask_2, exclusive=True, axis=1) + tf.expand_dims(mask_1_count, 1)

  # Add the over capacity ratio for expert 2
  _CreateOverCapacityRatioSummary(mask_2, position_in_expert_2, capacity,
                                  'over_capacity_2_ratio')

  mask_2 *= tf.cast(tf.less(position_in_expert_2, capacity), mask_2.dtype)
  position_in_expert_2 = tf.einsum('GSE,GSE->GS', position_in_expert_2, mask_2)
  mask_2_flat = tf.reduce_sum(mask_2, axis=-1)

  # Equivalent non-einsum implementation:
  #
  # position_in_expert_2 *= mask_2
  # position_in_expert_2 = tf.reduce_sum(
  #     position_in_expert_2, axis=-1, name='position_in_expert_2')

  gate_1 *= mask_1_flat
  gate_2 *= mask_2_flat

  if not legacy_mtf_behavior:
    denom = gate_1 + gate_2
    # To avoid divide by 0.
    denom = tf.where(denom > 0, denom, tf.ones_like(denom))
    gate_1 /= denom
    gate_2 /= denom

  # GSC Tensor
  b = tf.one_hot(
      tf.cast(position_in_expert_1, dtype=tf.int32),
      expert_capacity_dim,
      dtype=fprop_dtype,
      name='one_hot_b_0')
  # GSE Tensor
  a = tf.expand_dims(
      gate_1 * mask_1_flat, -1) * tf.one_hot(
          index_1, experts_dim, dtype=fprop_dtype)
  # GSEC Tensor
  first_part_of_combine_tensor = tf.einsum(
      'GSE,GSC->GSEC', a, b, name='first_part_of_combine_tensor')

  # GSC Tensor
  b = tf.one_hot(
      tf.cast(position_in_expert_2, dtype=tf.int32),
      expert_capacity_dim,
      dtype=fprop_dtype,
      name='one_hot_b_1')
  # GSE Tensor
  a = tf.expand_dims(
      gate_2 * mask_2_flat, -1) * tf.one_hot(
          index_2, experts_dim, dtype=fprop_dtype)
  # GSEC Tensor
  second_part_of_combine_tensor = tf.einsum(
      'GSE,GSC->GSEC', a, b, name='second_part_of_combine_tensor')

  # GSEC Tensor
  combine_tensor = (
      first_part_of_combine_tensor + second_part_of_combine_tensor)
  combine_tensor = _MaybeSplit(combine_tensor)

  # GSEC Tensor
  # 1 wherever the combine weight is non-zero.
  dispatch_tensor = tf.cast(tf.cast(combine_tensor, tf.bool), fprop_dtype)
  dispatch_tensor = _MaybeSplit(dispatch_tensor)

  # TODO(yonghui): compute and return per-group aux_loss.
  return aux_loss, combine_tensor, dispatch_tensor