def replace_jax_transformer_ffwd_vars_to_tf(
    jax_initial_vars: NestedMap) -> NestedMap:
  """Replaces JAX TransformerFeedForward vars with TF compatible vars.

  Args:
    jax_initial_vars: JAX TransformerFeedForward layer vars.

  Returns:
    tf_initial_vars which is TF compatible.
  """
  tf_initial_vars = py_utils.NestedMap()
  tf_initial_vars.fflayer = py_utils.NestedMap(
      fc=[
          py_utils.NestedMap(
              w=jax_initial_vars.ffn_layer1.linear.w,
              b=jax_initial_vars.ffn_layer1.bias.b),
          py_utils.NestedMap(
              w=jax_initial_vars.ffn_layer2.linear.w,
              b=jax_initial_vars.ffn_layer2.bias.b),
      ],
      dropout=[py_utils.NestedMap(), py_utils.NestedMap()],
  )
  tf_initial_vars.layer_norm = py_utils.NestedMap(
      bias=jax_initial_vars.layer_norm.bias,
      scale=jax_initial_vars.layer_norm.scale)
  tf_initial_vars.residual_dropout = py_utils.NestedMap()
  tf_initial_vars.residual_droppath = py_utils.NestedMap()
  return tf_initial_vars
def test_extract_prefixed_keys_from_dataclass(self):

  @struct.dataclass
  class GlobalShardedParameterStats:
    statistics: np.ndarray  # Statistics
    preconditioners: np.ndarray  # Preconditioners
    exponents: np.ndarray  # Exponents
    index_start: int = struct.field(pytree_node=False)
    sizes: Any = struct.field(pytree_node=False)

  stats0 = GlobalShardedParameterStats(
      statistics=np.array([0], dtype=np.float32),
      preconditioners=np.array([1, 1], dtype=np.float32),
      exponents=np.array([2, 2, 2], dtype=np.float32),
      index_start=0,
      sizes=0,
  )
  # Even though `preconditioners` is passed first here, the flattening order
  # is decided by the field order in the GlobalShardedParameterStats class.
  stats1 = GlobalShardedParameterStats(
      preconditioners=np.array([5, 5], dtype=np.float32),
      statistics=np.array([4], dtype=np.float32),
      exponents=np.array([6, 6, 6], dtype=np.float32),
      index_start=1,
      sizes=1,
  )
  nested_data = py_utils.NestedMap(stats0=stats0, stats1=stats1)
  nested_names = py_utils.extract_prefixed_keys_from_nested_map(nested_data)
  flattened_nested_names, _ = jax.tree_util.tree_flatten(nested_names)
  self.assertListEqual([
      'stats0/statistics', 'stats0/preconditioners', 'stats0/exponents',
      'stats1/statistics', 'stats1/preconditioners', 'stats1/exponents'
  ], flattened_nested_names)
def test_feedforward_layer_no_bias(self, activation):
  p = linears.FeedForward.Params().Set(
      name='jax_ffn',
      input_dims=3,
      output_dims=20,
      has_bias=False,
      activation=activation)
  ffn = p.Instantiate()
  prng_key = jax.random.PRNGKey(seed=123)
  initial_vars = ffn.instantiate_variables(prng_key)
  npy_input = np.random.normal(1.0, 0.5,
                               [10, 10, p.input_dims]).astype('float32')
  inputs = jnp.asarray(npy_input)
  outputs = test_utils.apply(ffn, initial_vars, ffn.fprop, inputs)
  logging.info('initial_vars in ffn = %s', initial_vars)

  # Test whether the TF projection layer returns the same output.
  # Modify initial_vars to use TF compatible params.
  tf_initial_vars = py_utils.NestedMap()
  tf_initial_vars.w = initial_vars.linear.w
  tf_initial_vars = to_tf_nmap(tf_initial_vars)
  tf_p = lingvo_layers.ProjectionLayer.Params().Set(
      name='tf_ffn',
      input_dim=p.input_dims,
      output_dim=p.output_dims,
      batch_norm=False,
      has_bias=False,
      activation=activation)
  tf_ffn = tf_p.Instantiate()
  tf_output = tf_ffn.FProp(tf_initial_vars,
                           tf.constant(inputs, dtype=tf.float32))
  np_outputs = to_np(outputs)
  tf_np_outputs = to_np(tf_output)
  self.assertAllClose(tf_np_outputs, np_outputs, atol=1e-6)
def replace_jax_single_shard_full_softmax_vars_to_tf(
    jax_initial_vars: NestedMap) -> NestedMap:
  """Replaces JAX SingleShardFullSoftmax vars with TF compatible vars.

  Args:
    jax_initial_vars: JAX SingleShardFullSoftmax layer vars.

  Returns:
    tf_initial_vars which is TF compatible with SingleShardFullSoftmax.
  """
  tf_initial_vars = jax_initial_vars.copy()
  tf_initial_vars.linear = py_utils.NestedMap(
      w=jax_initial_vars.logits_ffn.linear.w)
  tf_initial_vars.bias = py_utils.NestedMap(
      b=jax_initial_vars.logits_ffn.bias.b)
  del tf_initial_vars.logits_ffn
  return tf_initial_vars
def replace_jax_attention_vars_to_tf(
    jax_initial_vars: NestedMap,
    cross_attention: Optional[bool] = False) -> NestedMap:
  """Replaces JAX attention vars with TF compatible vars.

  Args:
    jax_initial_vars: JAX attention layer vars.
    cross_attention: Whether cross attention is involved.

  Returns:
    tf_initial_vars which is TF compatible.
  """
  tf_initial_vars = jax_initial_vars.copy()
  tf_initial_vars.fflayer = jax_initial_vars.ff_layer
  tf_initial_vars.fflayer.fflayer = py_utils.NestedMap()
  is_moe = 'gate' in jax_initial_vars.ff_layer
  if is_moe:
    tf_initial_vars.fflayer.fflayer.layer_norm = py_utils.NestedMap()
    tf_initial_vars.fflayer.fflayer.layer_norm.scale = (
        jax_initial_vars.ff_layer.layer_norm.scale)
    tf_initial_vars.fflayer.fflayer.layer_norm.bias = (
        jax_initial_vars.ff_layer.layer_norm.bias)
    tf_initial_vars.fflayer.fflayer.gate = jax_initial_vars.ff_layer.gate
    tf_initial_vars.fflayer.fflayer.wi_0 = jax_initial_vars.ff_layer.wi_0
    tf_initial_vars.fflayer.fflayer.wo_0 = jax_initial_vars.ff_layer.wo_0
  else:
    tf_initial_vars.fflayer.fflayer.dropout = [1.0, 1.0]
    tf_initial_vars.fflayer.fflayer.fc = [NestedMap(), NestedMap()]
    tf_initial_vars.fflayer.fflayer.fc[0].w = (
        jax_initial_vars.ff_layer.ffn_layer1.linear.w)
    tf_initial_vars.fflayer.fflayer.fc[0].b = (
        jax_initial_vars.ff_layer.ffn_layer1.bias.b)
    tf_initial_vars.fflayer.fflayer.fc[1].w = (
        jax_initial_vars.ff_layer.ffn_layer2.linear.w)
    tf_initial_vars.fflayer.fflayer.fc[1].b = (
        jax_initial_vars.ff_layer.ffn_layer2.bias.b)
  tf_initial_vars.self_atten = NestedMap()
  tf_initial_vars.self_atten.layer_norm = jax_initial_vars.layer_norm
  tf_initial_vars.self_atten.atten = jax_initial_vars.self_attention
  tf_initial_vars.self_atten.residual_dropout = 1.0
  if cross_attention:
    tf_initial_vars.cross_atten = NestedMap()
    tf_initial_vars.cross_atten.layer_norm = jax_initial_vars.layer_norm
    tf_initial_vars.cross_atten.atten = jax_initial_vars.cross_attention
    tf_initial_vars.cross_atten.residual_dropout = 1.0
  return tf_initial_vars
def test_LSTMSimple(self, jax_cell_class, cifg, output_nonlinearity):
  np.random.seed(_NUMPY_RANDOM_SEED)
  inputs = py_utils.NestedMap(
      act=[np.random.uniform(size=(3, 2))], padding=jnp.zeros([3, 1]))
  state0 = py_utils.NestedMap(
      c=np.random.uniform(size=(3, 2)), m=np.random.uniform(size=(3, 2)))
  tf_inputs = py_utils.NestedMap(
      act=[tf.constant(inputs.act[0], tf.float32)], padding=tf.zeros([3, 1]))
  tf_state0 = py_utils.NestedMap(
      c=tf.constant(state0.c, tf.float32),
      m=tf.constant(state0.m, tf.float32))

  params = rnn_cell.LSTMCellSimple.Params().Set(
      name='lstm',
      params_init=py_utils.WeightInit.Uniform(1.24, _INIT_RANDOM_SEED),
      bias_init=py_utils.WeightInit.Uniform(1.24, _INIT_RANDOM_SEED),
      num_input_nodes=2,
      num_output_nodes=2,
      couple_input_forget_gates=cifg,
      enable_lstm_bias=True,
      output_nonlinearity=output_nonlinearity)
  lstm = rnn_cell.LSTMCellSimple(params)
  res, _ = lstm.FPropDefaultTheta(tf_state0, tf_inputs)
  m_expected = res.m.numpy()
  c_expected = res.c.numpy()

  p = jax_cell_class.Params().Set(
      num_input_nodes=2,
      num_output_nodes=2,
      name='lstm',
      output_nonlinearity=output_nonlinearity,
  )
  model = p.Instantiate()
  theta = model.instantiate_variables(jax.random.PRNGKey(5678))
  theta.wm = lstm.vars['wm'].numpy()
  theta.b = lstm.vars['b'].numpy()

  output, _ = test_utils.apply(model, model.vars_to_flax_vars(theta),
                               model.fprop, state0, inputs)
  self.assertAllClose(m_expected, output.m)
  self.assertAllClose(c_expected, output.c)
def _parse_record(self, record) -> NestedMap:
  """Reads and parses a single record."""
  p = self.params
  name_to_features = {
      'input_ids':
          tf.io.FixedLenFeature([p.max_sequence_length], tf.int64),
      'input_mask':
          tf.io.FixedLenFeature([p.max_sequence_length], tf.int64),
      'masked_lm_positions':
          tf.io.FixedLenFeature([p.max_predictions_per_seq], tf.int64),
      'masked_lm_ids':
          tf.io.FixedLenFeature([p.max_predictions_per_seq], tf.int64),
      'masked_lm_weights':
          tf.io.FixedLenFeature([p.max_predictions_per_seq], tf.float32),
  }
  example = tf.io.parse_single_example(record, name_to_features)
  mask_length = tf.cast(
      tf.reduce_sum(example['masked_lm_weights']), dtype=tf.int32)
  masked_lm_positions = tf.slice(example['masked_lm_positions'], [0],
                                 [mask_length])
  masked_lm_ids = tf.cast(
      tf.slice(example['masked_lm_ids'], [0], [mask_length]), dtype=tf.int32)

  ret = py_utils.NestedMap()
  ret.masked_ids = tf.cast(example['input_ids'], dtype=tf.int32)
  # Get back non-masked, original ids.
  ret.labels = tf.tensor_scatter_nd_update(
      tensor=ret.masked_ids,
      indices=tf.reshape(masked_lm_positions, [-1, 1]),
      updates=masked_lm_ids)
  ret.masked_pos = tf.tensor_scatter_nd_update(
      tensor=tf.zeros_like(ret.masked_ids, dtype=tf.float32),
      indices=tf.reshape(masked_lm_positions, [-1, 1]),
      updates=tf.ones_like(masked_lm_ids, dtype=tf.float32))
  ret.segment_ids = tf.cast(example['input_mask'], dtype=tf.float32)

  first_eos_idx = tf.where(tf.math.equal(ret.labels, p.eos_token_id))[0][0]

  def remove_first_eos(x):
    # We remove the element at position `first_eos_idx`, and pad with 0
    # to keep the length unchanged.
    zero = tf.constant(0, shape=(1,), dtype=x.dtype)
    return tf.concat([x[:first_eos_idx], x[first_eos_idx + 1:], zero], axis=0)

  ret = ret.Transform(remove_first_eos)
  ret.paddings = 1.0 - ret.segment_ids
  pos = tf.cast(tf.range(p.max_sequence_length), dtype=tf.float32)
  ret.segment_pos = tf.cast(ret.segment_ids * pos, dtype=tf.int32)

  if p.remask:
    new_masked_ids, new_masked_pos = self.mlm.FProp(None, ret.labels,
                                                    ret.paddings)
    ret.masked_ids = new_masked_ids
    ret.masked_pos = new_masked_pos
  return ret
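# A minimal sketch (hypothetical toy values, not from the dataset) of how the
# tf.tensor_scatter_nd_update call above recovers the original ids: the masked
# positions in `masked_ids` are overwritten with the true `masked_lm_ids`.
def _example_restore_labels():
  masked_ids = tf.constant([7, 103, 9, 103])  # 103 marks masked tokens.
  positions = tf.constant([[1], [3]])         # Indices of the masked tokens.
  true_ids = tf.constant([5, 2])              # Original ids at those spots.
  labels = tf.tensor_scatter_nd_update(masked_ids, positions, true_ids)
  return labels  # -> [7, 5, 9, 2]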
def _scan_fn(layer_in, layer_vars):
  jax_context = base_layer.cur_jax_context()
  flax_variables = self.sub.vars_to_flax_vars(layer_vars)
  # Properly set up the scope.
  jax_context.bind(self.sub, flax_variables, [base_layer.SCOPE_AUX_LOSS])
  layer_out, extra = fprop_fn(self.sub, layer_in.carry, *args, **kwargs)
  tf.nest.assert_same_structure(layer_in.carry, layer_out)
  return NestedMap(carry=layer_out), py_utils.NestedMap(extra=extra)
def _all_paddings_batch(self) -> NestedMap:
  p = self.params
  shape = [p.batch_size, p.max_sequence_length]
  ret = py_utils.NestedMap()
  ret.labels = tf.zeros(shape, dtype=tf.int32)
  ret.masked_ids = ret.labels
  ret.segment_pos = ret.labels
  ret.masked_pos = tf.zeros(shape, dtype=tf.float32)
  ret.segment_ids = ret.masked_pos
  ret.paddings = 1.0 - ret.segment_ids
  return ret
def extend_step(self, extend_fn, cached_states: NestedMap,
                step_inputs: NestedJTensor, *args: Any, **kwargs: Any) -> Any:
  """Extends decoder states by one step.

  extend_fn should have the following signature:

    extended_states, step_out = extend_fn(self.sub, states, step_input,
                                          *args, **kwargs)

  extended_states should have the same structure as states, and step_out
  should have the same structure as step_input.

  Args:
    extend_fn: fn to extend cached_states for one step. It should have the
      expected signature as described above.
    cached_states: The combined states for all sub-layers.
    step_inputs: Input to the bottom decoder layer.
    *args: Additional positional input.
    **kwargs: Additional keyword input.

  Returns:
    new_states, top_decoder_out, where new_states is the updated decoder
    states, and top_decoder_out is the output from the top decoder layer.
  """
  # Wrap inputs in a NestedMap to conform to recurrent.scan interface.
  step_inputs_mp = NestedMap(carry=step_inputs)

  def _scan_fn(layer_in, vars_and_states):
    layer_vars = vars_and_states.layer_vars
    layer_states = vars_and_states.layer_states
    # Properly set up the context.
    jax_context = base_layer.cur_jax_context()
    flax_variables = self.sub.vars_to_flax_vars(layer_vars)
    jax_context.bind(self.sub, flax_variables, [base_layer.SCOPE_AUX_LOSS])
    extended_states, layer_out = extend_fn(self.sub, layer_states,
                                           layer_in.carry, *args, **kwargs)
    tf.nest.assert_same_structure(layer_in.carry, layer_out)
    tf.nest.assert_same_structure(extended_states, layer_states)
    return NestedMap(carry=layer_out), extended_states

  vars_and_states = py_utils.NestedMap(
      layer_vars=self.sub.local_theta(), layer_states=cached_states)

  final_out, new_states, summaries = recurrent.scan(
      step_inputs_mp, vars_and_states, _scan_fn, root_layer=self)
  # Forward summaries to the out-context.
  self._forward_summary(summaries)
  return new_states, final_out.carry
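# A minimal sketch of an `extend_fn` satisfying the contract documented above.
# The sub-layer call and the `time_step` state field are hypothetical; the
# hard requirements are only structural: the returned states must mirror
# `states` and the returned output must mirror `step_input`.
def _example_extend_fn(sub_layer, states, step_input, *args, **kwargs):
  del args, kwargs  # Unused in this sketch.
  # Run one decoding step of the sub-layer (hypothetical API).
  step_out = sub_layer.fprop(step_input)
  # Update the cached states without changing their structure.
  extended_states = states.copy()
  extended_states.time_step = states.time_step + 1
  return extended_states, step_out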
def compute_loss(self, predictions: NestedMap,
                 input_batch: NestedMap) -> Tuple[Metrics, Dict[str, Any]]:
  """Computes the loss and other metrics for the given predictions.

  Args:
    predictions: The output of `compute_predictions`.
    input_batch: A `.NestedMap` object containing input tensors to this
      tower.

  Returns:
    - A dict or NestedMap containing str keys and (metric, weight) pairs as
      values, where one of the entries is expected to correspond to the loss.
    - A dict containing arbitrary tensors describing something about each
      training example, where the first dimension of each tensor is the
      batch index.
  """
  labels = input_batch.labels
  num_tokens = jnp.sum(1.0 - input_batch.paddings.astype(jnp.float32))
  num_seqs = jnp.sum(
      jnp.amax(input_batch.segment_ids.astype(jnp.float32), axis=1))
  weights = predictions.augmented_pos.astype(jnp.float32)
  predicted_labels = predictions.per_example_argmax.astype(labels.dtype)
  num_preds = predictions.total_weight.astype(jnp.float32)
  mean_acc = jnp.sum(
      (labels == predicted_labels) * weights) / jnp.maximum(num_preds, 1)
  metric_weight = jnp.array(num_preds, predictions.avg_xent.dtype)
  metrics = py_utils.NestedMap(
      total_loss=(predictions.total_loss, metric_weight),
      avg_xent=(predictions.avg_xent, metric_weight),
      aux_loss=(predictions.aux_loss, metric_weight),
      log_pplx=(predictions.avg_xent, metric_weight),
      fraction_of_correct_preds=(mean_acc,
                                 jnp.array(num_preds, mean_acc.dtype)),
      num_predictions=(num_preds, jnp.array(1.0, num_preds.dtype)),
      num_tokens=(num_tokens, jnp.array(1.0, num_tokens.dtype)),
      num_seqs=(num_seqs, jnp.array(1.0, num_seqs.dtype)),
  )
  per_example_output = py_utils.NestedMap()
  return metrics, per_example_output
def _to_nested_map(self, text) -> py_utils.NestedMap:
  ids, labels, paddings = self.tokenizer.StringsToIds(
      text, max_length=self.params.max_sequence_length)
  # Unfortunately some tokenizers don't return the correct paddings.
  # We recompute them by looking at where the labels sequence terminates.
  indices = tf.where(tf.math.equal(labels, self.tokenizer.eos_id))
  lengths = tf.math.segment_min(indices[:, 1], indices[:, 0]) + 1
  new_paddings = tf.cast(
      1.0 - tf.sequence_mask(
          lengths,
          maxlen=self.params.max_sequence_length,
          dtype=paddings.dtype),
      dtype=paddings.dtype)
  return py_utils.NestedMap(ids=ids, labels=labels, paddings=new_paddings)
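# A small worked example (hypothetical values) of the paddings recomputation
# above: with eos_id = 2 and max_sequence_length = 5, the labels row
# [4, 7, 2, 0, 0] has its first eos at column 2, so length = 3 and the
# recomputed paddings row is [0, 0, 0, 1, 1].
def _example_recompute_paddings():
  labels = tf.constant([[4, 7, 2, 0, 0]])
  indices = tf.where(tf.math.equal(labels, 2))  # -> [[0, 2]]
  # Per-row position of the first eos, plus one, gives the sequence length.
  lengths = tf.math.segment_min(indices[:, 1], indices[:, 0]) + 1  # -> [3]
  return 1.0 - tf.sequence_mask(lengths, maxlen=5, dtype=tf.float32)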
def replace_jax_light_conv_vars_to_tf(
    jax_initial_vars: NestedMap) -> NestedMap:
  """Replaces the JAX LightConv vars with TF compatible vars.

  Args:
    jax_initial_vars: JAX LightConv vars.

  Returns:
    tf_initial_vars which is TF compatible with LightConv.
  """
  tf_initial_vars = py_utils.NestedMap()
  tf_initial_vars.ln = py_utils.NestedMap()
  tf_initial_vars.ln.bias = jax_initial_vars.ln.bias
  tf_initial_vars.ln.scale = jax_initial_vars.ln.scale
  tf_initial_vars.norm = py_utils.NestedMap()
  tf_initial_vars.norm.beta = jax_initial_vars.conv_norm.beta
  tf_initial_vars.norm.gamma = jax_initial_vars.conv_norm.gamma
  tf_initial_vars.norm.moving_mean = jax_initial_vars.conv_norm.moving_mean
  tf_initial_vars.norm.moving_variance = (
      jax_initial_vars.conv_norm.moving_variance)
  tf_initial_vars.dropout = [py_utils.NestedMap(), py_utils.NestedMap()]
  tf_initial_vars.depthwise_conv1d = py_utils.NestedMap()
  tf_initial_vars.depthwise_conv1d.w = np.expand_dims(
      jax_initial_vars.depthwise_conv1d.w, axis=-1)
  tf_initial_vars.linear_end = py_utils.NestedMap()
  tf_initial_vars.linear_end.w = jax_initial_vars.linear_end.linear.w
  tf_initial_vars.linear_end.b = jax_initial_vars.linear_end.bias.b
  tf_initial_vars.linear_start = py_utils.NestedMap()
  # The TF layer uses a single linear_start projection, so the JAX gated and
  # act projections are concatenated along the output dimension.
  tf_initial_vars.linear_start.w = np.concatenate([
      jax_initial_vars.linear_start_gated.linear.w,
      jax_initial_vars.linear_start_act.linear.w
  ], axis=-1)
  tf_initial_vars.linear_start.b = np.concatenate([
      jax_initial_vars.linear_start_gated.bias.b,
      jax_initial_vars.linear_start_act.bias.b
  ], axis=-1)
  tf_initial_vars = to_tf_nmap(tf_initial_vars)
  return tf_initial_vars
def _InputBatch(self):
  p = self.params
  targets = tf.ones([p.batch_size, p.seq_len], dtype=tf.int32)
  input_batch = py_utils.NestedMap()
  # For this all-ones synthetic batch, ids is equivalent to
  # tf.roll(targets, 1, axis=1).
  input_batch.ids = targets
  input_batch.paddings = tf.zeros_like(targets)
  input_batch.weights = tf.ones_like(targets)
  input_batch.labels = targets
  # segment_id = 0 means padded tokens.
  # e.g., if we have two segments packed into one sentence with paddings:
  #   segment_ids = 1, 1, 1, 1, 2, 2, 2, 2, 0, 0
  #   segment_pos = 0, 1, 2, 3, 0, 1, 2, 3, 0, 0
  input_batch.segment_ids = targets
  input_batch.segment_pos = tf.tile(
      tf.range(0, p.seq_len)[tf.newaxis, :], [p.batch_size, 1])
  return input_batch
def test_transformer_bert(self, trainable_position_emb):
  seq_len = 512
  if trainable_position_emb:
    position_emb_tpl = embedding_softmax.TrainablePositionalEmbedding.Params()
    position_emb_tpl.max_seq_length = seq_len
  else:
    position_emb_tpl = embedding_softmax.PositionalEmbedding.Params()
  p = transformer_models.TransformerLm.Params().Set(
      name='bert_lm',
      model_dims=32,
      vocab_size=52,
      position_emb_tpl=position_emb_tpl)
  stacked_transformer_tpl = p.stacked_transformer_tpl
  stacked_transformer_tpl.model_dims = 32
  stacked_transformer_tpl.hidden_dims = 4 * 32
  stacked_transformer_tpl.num_heads = 4
  stacked_transformer_tpl.num_layers = 1
  p.softmax_tpl.scale_sqrt_depth = True
  batch_size = 8
  bert_lm = p.Instantiate()
  prng_key = jax.random.PRNGKey(seed=123)
  initial_vars = bert_lm.instantiate_variables(prng_key)
  input_ids = jax.random.randint(
      jax.random.PRNGKey(1234), [batch_size, seq_len], 0, 51)
  input_paddings = jnp.zeros([batch_size, seq_len])
  input_weights = jnp.ones([batch_size, seq_len])
  input_segment_ids = jnp.ones([batch_size, seq_len])
  input_segment_pos = jnp.tile(
      jnp.arange(0, seq_len)[jnp.newaxis, :], [batch_size, 1])

  labels = py_utils.NestedMap()
  labels.class_ids = input_ids
  labels.class_weights = input_weights

  outputs = test_utils.apply(
      bert_lm,
      initial_vars,
      bert_lm.fprop,
      input_ids,
      input_paddings,
      labels=labels,
      segment_ids=input_segment_ids,
      segment_pos=input_segment_pos)
  logging.info('outputs: %s', outputs)
def replace_jax_conformer_layer_vars_to_tf(
    jax_initial_vars: NestedMap) -> NestedMap:
  """Replaces the JAX conformer layer vars with TF compatible vars.

  Args:
    jax_initial_vars: JAX conformer layer vars.

  Returns:
    tf_initial_vars which is TF compatible with ConformerLayer.
  """
  tf_initial_vars = py_utils.NestedMap()
  tf_initial_vars.lconv = replace_jax_light_conv_vars_to_tf(
      jax_initial_vars.lconv)
  tf_initial_vars.final_ln = py_utils.NestedMap()
  tf_initial_vars.final_ln.bias = jax_initial_vars.final_ln.bias
  tf_initial_vars.final_ln.scale = jax_initial_vars.final_ln.scale

  tf_initial_vars.fflayer_start = py_utils.NestedMap()
  tf_initial_vars.fflayer_start.residual_dropout = (
      jax_initial_vars.fflayer_start.residual_dropout)
  tf_initial_vars.fflayer_start.layer_norm = (
      jax_initial_vars.fflayer_start.layer_norm)
  tf_initial_vars.fflayer_start.fflayer = py_utils.NestedMap()
  tf_initial_vars.fflayer_start.fflayer.dropout = [
      jax_initial_vars.fflayer_start.relu_dropout, {}
  ]
  tf_initial_vars.fflayer_start.fflayer.fc = [
      py_utils.NestedMap(), py_utils.NestedMap()
  ]
  tf_initial_vars.fflayer_start.fflayer.fc[0].w = (
      jax_initial_vars.fflayer_start.ffn_layer1.linear.w)
  tf_initial_vars.fflayer_start.fflayer.fc[0].b = (
      jax_initial_vars.fflayer_start.ffn_layer1.bias.b)
  tf_initial_vars.fflayer_start.fflayer.fc[1].w = (
      jax_initial_vars.fflayer_start.ffn_layer2.linear.w)
  tf_initial_vars.fflayer_start.fflayer.fc[1].b = (
      jax_initial_vars.fflayer_start.ffn_layer2.bias.b)

  tf_initial_vars.fflayer_end = py_utils.NestedMap()
  tf_initial_vars.fflayer_end.layer_norm = (
      jax_initial_vars.fflayer_end.layer_norm)
  tf_initial_vars.fflayer_end.residual_dropout = (
      jax_initial_vars.fflayer_end.residual_dropout)
  tf_initial_vars.fflayer_end.fflayer = py_utils.NestedMap()
  tf_initial_vars.fflayer_end.fflayer.dropout = [
      jax_initial_vars.fflayer_end.relu_dropout, {}
  ]
  tf_initial_vars.fflayer_end.fflayer.fc = [
      py_utils.NestedMap(), py_utils.NestedMap()
  ]
  tf_initial_vars.fflayer_end.fflayer.fc[0].w = (
      jax_initial_vars.fflayer_end.ffn_layer1.linear.w)
  tf_initial_vars.fflayer_end.fflayer.fc[0].b = (
      jax_initial_vars.fflayer_end.ffn_layer1.bias.b)
  tf_initial_vars.fflayer_end.fflayer.fc[1].w = (
      jax_initial_vars.fflayer_end.ffn_layer2.linear.w)
  tf_initial_vars.fflayer_end.fflayer.fc[1].b = (
      jax_initial_vars.fflayer_end.ffn_layer2.bias.b)

  tf_initial_vars.trans_atten = py_utils.NestedMap()
  tf_initial_vars.trans_atten.layer_norm = jax_initial_vars.trans_atten.norm
  tf_initial_vars.trans_atten.residual_dropout = (
      jax_initial_vars.trans_atten.residual_dropout)
  tf_initial_vars.trans_atten.atten = jax_initial_vars.trans_atten.self_atten
  tf_initial_vars = to_tf_nmap(tf_initial_vars)
  return tf_initial_vars
def test_stacked_transformer_layer(self, mask_self_attention, packed_input,
                                   cross_attention):
  p = transformers.StackedTransformer.Params().Set(
      name='jax_stacked_transformer_layer',
      model_dims=16,
      hidden_dims=64,
      num_heads=8,
      mask_self_attention=mask_self_attention,
      num_layers=4,
      packed_input=packed_input,
      cross_attention=cross_attention)
  seq_len = np.random.randint(10, 32)
  batch_size = 10
  stacked_transformer_layer = p.Instantiate()
  prng_key = jax.random.PRNGKey(seed=123)
  initial_vars = stacked_transformer_layer.instantiate_variables(prng_key)

  # Test conversion between vars and flax vars.
  pax_vars = stacked_transformer_layer.vars
  flax_vars = stacked_transformer_layer.flax_vars
  tf.nest.assert_same_structure(
      pax_vars, stacked_transformer_layer.flax_vars_to_vars(flax_vars))
  tf.nest.assert_same_structure(
      flax_vars, stacked_transformer_layer.vars_to_flax_vars(pax_vars))

  npy_inputs = np.random.normal(
      1.0, 0.5, [batch_size, seq_len, p.model_dims]).astype('float32')
  inputs = jnp.asarray(npy_inputs)
  npy_paddings = np.random.randint(0, 1,
                                   [batch_size, seq_len]).astype('float32')
  paddings = jnp.asarray(npy_paddings)
  segment_mask = None
  tf_segment_mask = None
  if packed_input:
    # randint's high bound is exclusive, so segment ids are in {0, 1, 2}.
    segment_ids = np.random.randint(0, 3, [batch_size, seq_len])
    segment_mask = attentions.segment_mask(segment_ids, dtype=np.float32)
    if mask_self_attention:
      tf_segment_mask = batch_major_attention.CausalSegmentMask(
          segment_ids, tf.float32)
    else:
      tf_segment_mask = batch_major_attention.SegmentMask(
          segment_ids, segment_ids)

  cross_inputs = None
  cross_paddings = None
  cross_segment_mask = None
  tf_cross_inputs = None
  tf_cross_paddings = None
  tf_cross_segment_mask = None
  if cross_attention:
    cross_seq_len = np.random.randint(10, 64)
    npy_cross_inputs = np.random.normal(
        1.0, 0.5, [batch_size, cross_seq_len, p.model_dims]).astype('float32')
    cross_inputs = jnp.asarray(npy_cross_inputs)
    tf_cross_inputs = tf.constant(npy_cross_inputs, dtype=tf.float32)
    npy_cross_paddings = np.random.randint(
        0, 1, [batch_size, cross_seq_len]).astype('float32')
    cross_paddings = jnp.asarray(npy_cross_paddings)
    tf_cross_paddings = tf.constant(npy_cross_paddings, dtype=tf.float32)
    if packed_input:
      source_segment_ids = np.random.randint(0, 3,
                                             [batch_size, cross_seq_len])
      cross_segment_mask = attentions.segment_mask(
          segment_ids, source_segment_ids, dtype=np.float32)
      tf_cross_segment_mask = batch_major_attention.SegmentMask(
          segment_ids, source_segment_ids)

  outputs = test_utils.apply(
      stacked_transformer_layer,
      initial_vars,
      stacked_transformer_layer.fprop,
      inputs,
      paddings,
      context_p=None,
      segment_mask=segment_mask,
      cross_inputs=cross_inputs,
      cross_paddings=cross_paddings,
      cross_segment_mask=cross_segment_mask)
  logging.info('initial_vars in transformer layer = %s', initial_vars)

  # Test whether the TF Transformer layer returns the same output.
  # Modify initial_vars to use TF compatible params.
  tf_initial_vars = py_utils.NestedMap()
  tf_initial_vars.x_layers = []
  for jax_initial_vars in initial_vars.x_layers:
    tf_layer_vars = test_utils.replace_jax_attention_vars_to_tf(
        jax_initial_vars, cross_attention)
    tf_initial_vars.x_layers.append(tf_layer_vars)
  tf_initial_vars = test_utils.to_tf_nmap(tf_initial_vars)
  logging.info('tf_initial_vars in transformer layer = %s', tf_initial_vars)
  tf_p = batch_major_attention.StackedTransformerLayers.Params().Set(
      name='tf_transformer_layer',
      mdl_dim=p.model_dims,
      hidden_dim=p.hidden_dims,
      num_atten_heads=p.num_heads,
      mask_self_atten=mask_self_attention,
      num_layers=p.num_layers,
      packed_input=packed_input,
      has_aux_atten=cross_attention)
  tf_p.transformer_layer_params_tpl.tr_fflayer_tpl.fflayer_tpl.batch_norm = (
      False)
  tf_p.transformer_layer_params_tpl.tr_fflayer_tpl.fflayer_tpl.has_bias = True
  tf_stacked_transformer_layer = tf_p.Instantiate()
  tf_output, _ = tf_stacked_transformer_layer.FProp(
      tf_initial_vars,
      test_utils.to_tf_nmap(npy_inputs),
      paddings=test_utils.to_tf_nmap(npy_paddings),
      segment_mask=test_utils.to_tf_nmap(tf_segment_mask),
      aux_vec=test_utils.to_tf_nmap(tf_cross_inputs),
      aux_paddings=test_utils.to_tf_nmap(tf_cross_paddings),
      aux_segment_mask=test_utils.to_tf_nmap(tf_cross_segment_mask))
  np_outputs = test_utils.to_np(outputs)
  tf_np_outputs = test_utils.to_np(tf_output)
  self.assertAllClose(tf_np_outputs, np_outputs, atol=1e-5)
def test_single_sharded_softmax_layer(self, soft_cap_logits, use_class_ids,
                                      use_class_probabilities,
                                      label_smoothing_prob):
  if use_class_ids:
    class_ids = np.random.randint(0, 50, [8, 10, 1])
  else:
    class_ids = None
  if use_class_probabilities:
    class_probabilities = np.random.normal(1.5, 2.0, [8, 10, 50])
  else:
    class_probabilities = None
  p = embedding_softmax.SingleShardFullSoftmax.Params().Set(
      name='jax_softmax',
      num_classes=50,
      input_dims=40,
      soft_cap_logits=soft_cap_logits,
      label_smoothing_prob=label_smoothing_prob)
  softmax_layer = p.Instantiate()
  prng_key = jax.random.PRNGKey(seed=1234)
  initial_vars = softmax_layer.instantiate_variables(prng_key)
  npy_input = np.random.normal(1.5, 2.0, [8, 10, p.input_dims])
  inputs = jnp.asarray(npy_input)
  class_weights = np.random.normal(1.5, 2.0, [8, 10, 1])
  if class_probabilities is not None:
    class_probabilities /= np.sum(class_probabilities, axis=-1, keepdims=True)
  logits = test_utils.apply(softmax_layer, initial_vars,
                            softmax_layer.get_logits, inputs)
  outputs = test_utils.apply(
      softmax_layer,
      initial_vars,
      softmax_layer.fprop,
      inputs,
      class_weights,
      class_ids=class_ids,
      class_probabilities=class_probabilities)

  # Test whether the TF Softmax layer returns the same output.
  # Modify initial_vars to use TF compatible params.
  tf_initial_vars = initial_vars
  tf_initial_vars.linear = py_utils.NestedMap()
  tf_initial_vars.linear.w = initial_vars.logits_ffn.linear.w
  tf_initial_vars.bias = py_utils.NestedMap()
  tf_initial_vars.bias.b = initial_vars.logits_ffn.bias.b
  tf_p = lingvo_layers.SingleShardFullSoftmax.Params().Set(
      name='tf_softmax',
      num_classes=p.num_classes,
      input_dim=p.input_dims,
      logits_soft_max=soft_cap_logits)
  tf_softmax_layer = tf_p.Instantiate()
  tf_logits = tf_softmax_layer.Logits(tf_initial_vars,
                                      tf.constant(inputs, dtype=tf.float32))
  if use_class_ids and label_smoothing_prob > 0:
    class_probabilities = np.zeros([8, 10, 50])
    index = np.indices([8, 10])
    class_probabilities[index[0], index[1], np.squeeze(class_ids, 2)] = 1
    class_probabilities = (
        class_probabilities * (1 - label_smoothing_prob) +
        (1 - class_probabilities) * label_smoothing_prob /
        (p.num_classes - 1))
    class_ids = None
  tf_output = tf_softmax_layer.FProp(
      tf_initial_vars,
      tf.constant(inputs, dtype=tf.float32),
      class_weights,
      class_ids=class_ids,
      class_probabilities=class_probabilities)

  # Check all entries in the NestedMap and ensure they match TF.
  np_get_logits = to_np(logits)
  tf_np_get_logits = to_np(tf_logits)
  self.assertAllClose(np_get_logits, tf_np_get_logits, atol=1e-6)
  # Note: The argmax-related values are very sensitive to numerical errors.
  for k in outputs.keys():
    self.assertAllClose(to_np(outputs[k]), to_np(tf_output[k]), atol=1e-6)
def test_single_sharded_shared_embedding_softmax_layer(
    self, soft_cap_logits, lookup_style, scale_sqrt_depth):
  class_ids = np.random.randint(1, 50, [8, 10, 1])
  p = embedding_softmax.SingleShardSharedEmbeddingSoftmax.Params().Set(
      name='jax_softmax',
      num_classes=50,
      input_dims=40,
      soft_cap_logits=soft_cap_logits,
      lookup_style=lookup_style,
      scale_sqrt_depth=scale_sqrt_depth)
  softmax_layer = p.Instantiate()
  prng_key = jax.random.PRNGKey(seed=123)
  initial_vars = softmax_layer.instantiate_variables(prng_key)
  npy_input = np.random.normal(1.5, 2.0, [8, 10, p.input_dims])
  inputs = jnp.asarray(npy_input)
  class_weights = np.random.normal(1.5, 2.0, [8, 10, 1])
  outputs = test_utils.apply(
      softmax_layer,
      initial_vars,
      softmax_layer.fprop,
      inputs,
      class_weights,
      class_ids=class_ids)
  ids = np.squeeze(class_ids, axis=-1)
  emb_lookup_outputs = test_utils.apply(
      softmax_layer,
      initial_vars,
      softmax_layer.emb_lookup,
      ids=jnp.asarray(ids))

  # Test whether the TF Softmax layer returns the same output.
  # Modify initial_vars to use TF compatible params.
  tf_initial_vars = initial_vars
  tf_initial_vars.linear = py_utils.NestedMap()
  tf_initial_vars.linear.w = initial_vars.logits_ffn.linear.w
  tf_initial_vars.bias = py_utils.NestedMap()
  tf_initial_vars.bias.b = initial_vars.logits_ffn.bias.b
  tf_p = lingvo_layers.SingleShardSharedEmbeddingSoftmax.Params().Set(
      name='tf_softmax',
      num_classes=p.num_classes,
      input_dim=p.input_dims,
      vocab_size=p.num_classes,
      embedding_dim=p.input_dims,
      logits_soft_max=soft_cap_logits,
      scale_sqrt_depth=scale_sqrt_depth)
  tf_softmax_layer = tf_p.Instantiate()
  tf_output = tf_softmax_layer.FProp(
      tf_initial_vars,
      tf.constant(inputs, dtype=tf.float32),
      class_weights,
      class_ids=class_ids)
  tf_emb_lookup_output = tf_softmax_layer.EmbLookup(
      tf_initial_vars, ids=tf.constant(ids))

  # Check all entries in the NestedMap and ensure they match TF.
  np_logits = to_np(outputs.logits)
  tf_np_logits = to_np(tf_output.logits)
  self.assertAllClose(np_logits, tf_np_logits, atol=1e-6)
  for k in outputs.keys():
    self.assertAllClose(to_np(outputs[k]), to_np(tf_output[k]), atol=1e-6)
  np_emb_lookup_output = to_np(emb_lookup_outputs)
  tf_np_emb_lookup_output = to_np(tf_emb_lookup_output)
  self.assertAllClose(
      tf_np_emb_lookup_output, np_emb_lookup_output, atol=1e-6)
def _process(source_id, record):
  del source_id
  num = tf.strings.to_number(record, tf.int32)
  if not tf_py_utils.use_tpu():
    num = num * num
  return py_utils.NestedMap(num=num), 1
def test_repeated_stacked_xformer_layer(self, mask_self_attention,
                                        packed_input, cross_attention):
  model_dims = 16
  p1 = transformers.StackedTransformer.Params().Set(
      name='jax_stacked_transformer_layer',
      model_dims=model_dims,
      hidden_dims=64,
      num_heads=8,
      mask_self_attention=mask_self_attention,
      num_layers=4,
      packed_input=packed_input,
      cross_attention=cross_attention)
  p1_one_layer = p1.Copy()
  p1_one_layer.num_layers = 1
  p2 = transformers.StackedTransformerRepeated.Params().Set(
      name='jax_stacked_transformer_layer_repeated',
      block=p1_one_layer,
      x_times=p1.num_layers)
  seq_len = np.random.randint(10, 32)
  batch_size = 10
  stacked_transformer_layer = p1.Instantiate()
  repeated_transformer_layer = p2.Instantiate()
  prng_key = jax.random.PRNGKey(seed=123)
  initial_vars = stacked_transformer_layer.instantiate_variables(prng_key)
  repeated_transformer_layer.instantiate_variable_configs()

  def _stack_vars(*args):
    args = [x[jnp.newaxis, :] for x in args]
    return jnp.vstack(args)

  stacked_vars = tf.nest.map_structure(_stack_vars, *initial_vars.x_layers)
  repeated_vars = py_utils.NestedMap(
      repeat=py_utils.NestedMap(
          sub=py_utils.NestedMap(x_layers=[stacked_vars])))

  tf.nest.assert_same_structure(
      repeated_vars,
      repeated_transformer_layer.instantiate_variables(prng_key))

  npy_inputs = np.random.normal(
      1.0, 0.5, [batch_size, seq_len, model_dims]).astype('float32')
  inputs = jnp.asarray(npy_inputs)
  npy_paddings = np.random.randint(0, 1,
                                   [batch_size, seq_len]).astype('float32')
  paddings = jnp.asarray(npy_paddings)
  segment_mask = None
  if packed_input:
    # randint's high bound is exclusive, so segment ids are in {0, 1, 2}.
    segment_ids = np.random.randint(0, 3, [batch_size, seq_len])
    segment_mask = attentions.segment_mask(segment_ids, dtype=np.float32)
  cross_inputs = None
  cross_paddings = None
  cross_segment_mask = None
  if cross_attention:
    cross_seq_len = np.random.randint(10, 64)
    npy_cross_inputs = np.random.normal(
        1.0, 0.5, [batch_size, cross_seq_len, model_dims]).astype('float32')
    cross_inputs = jnp.asarray(npy_cross_inputs)
    npy_cross_paddings = np.random.randint(
        0, 1, [batch_size, cross_seq_len]).astype('float32')
    cross_paddings = jnp.asarray(npy_cross_paddings)
    if packed_input:
      source_segment_ids = np.random.randint(0, 3,
                                             [batch_size, cross_seq_len])
      cross_segment_mask = attentions.segment_mask(
          segment_ids, source_segment_ids, dtype=np.float32)

  outputs = test_utils.apply(
      stacked_transformer_layer,
      initial_vars,
      stacked_transformer_layer.fprop,
      inputs,
      paddings,
      context_p=None,
      segment_mask=segment_mask,
      cross_inputs=cross_inputs,
      cross_paddings=cross_paddings,
      cross_segment_mask=cross_segment_mask)
  outputs_repeated = test_utils.apply(
      repeated_transformer_layer,
      repeated_vars,
      repeated_transformer_layer.fprop,
      inputs,
      paddings,
      context_p=None,
      segment_mask=segment_mask,
      cross_inputs=cross_inputs,
      cross_paddings=cross_paddings,
      cross_segment_mask=cross_segment_mask)
  self.assertAllClose(outputs, outputs_repeated, atol=1e-5)
def _to_nested_map(self, x) -> py_utils.NestedMap:
  t = tf.ones(shape=[4], dtype=tf.int32) * tf.cast(x, dtype=tf.int32)
  return py_utils.NestedMap(data=t)
def test_glam_unitransformer(self):
  batch = 2
  length = 3
  d_model = 6
  num_heads = 2
  vocab_size = 16
  ff_dim = 8
  c_dim = 3
  e_dim = 2
  num_layers = 4
  # Build the JAX layer.
  jax_p = transformer_models.TransformerLm.GLaMUniTransformerParams(
      name='model',
      vocab_size=vocab_size,
      num_transformer_layers=num_layers,
      moe=True,
      model_dim=d_model,
      ff_dim=ff_dim,
      moe_hidden_dim=ff_dim,
      attention_num_heads=num_heads,
      attention_key_value_dim=d_model // num_heads,
      attention_extra_logit=0.0,
      use_tgt_labels_size_as_loss_denominator=True,
      moe_load_balance_loss_weight=0.01,
      z_loss_weight=1e-4,
      c_dim=c_dim,
      e_dim=e_dim)
  assert jax_p.packed_input
  jax_layer = jax_p.Instantiate()
  prng_key = jax.random.PRNGKey(seed=42)
  jax_vars = jax_layer.instantiate_variables(prng_key)

  # Build the TF layer.
  builder_p = gshard_builder.DenseBuilder.Params().Set(
      num_groups=1,
      second_expert_policy='all',
      relative_attention_type='bias',
      model_dim=d_model,
      attention_key_value_dim=d_model // num_heads,
      attention_num_heads=num_heads,
      attention_combine_dims=True,
      c_dim=c_dim,
      capacity_factor=None,
      attention_extra_logit=0.0,
      e_dim=e_dim,
      moe_hidden_dim=ff_dim,
      ff_dim=ff_dim)
  tf_layer = gshard_builder.UniTransformer.Params().Set(
      name='model',
      num_transformer_layers=num_layers,
      builder=builder_p,
      vocab_size=vocab_size,
      sequence_length=length,
      label_smoothing=0,
      aux_loss_coef=0.01,
      z_loss=1e-4,
      use_tgt_labels_size_as_loss_denominator=True,
      positional_embedding=False,
      gated_gelu=True,
      moe=True).Instantiate()

  # Build JAX inputs.
  np.random.seed(42)
  npy_ids = np.random.randint(0, vocab_size - 1, [batch, length])
  jax_ids = jnp.asarray(npy_ids)
  npy_paddings = np.array([[0, 0, 1], [0, 0, 1]], dtype=np.float32)
  jax_paddings = jnp.asarray(npy_paddings)
  npy_segment_ids = np.array([[1, 2, 0], [1, 1, 0]], dtype=np.int32)
  npy_segment_pos = np.array([[0, 0, 0], [0, 1, 0]], dtype=np.int32)
  npy_labels = np.roll(npy_ids, -1, axis=1)
  jax_labels = jnp.asarray(npy_labels)
  jax_seg_ids = jnp.asarray(npy_segment_ids)
  jax_seg_pos = jnp.asarray(npy_segment_pos)
  jax_label_weights = jnp.asarray([[1, 1, 0], [1, 1, 0]])

  # Build TF inputs.
  tf_tgt_inputs = py_utils.NestedMap(
      ids=tf.convert_to_tensor(npy_ids, dtype=tf.int32),
      labels=tf.convert_to_tensor(npy_labels, dtype=tf.int32),
      segment_ids=tf.convert_to_tensor(npy_segment_ids, dtype=tf.int32),
      segment_pos=tf.convert_to_tensor(npy_segment_pos, dtype=tf.int32))
  tf_inputs = py_utils.NestedMap(tgt=tf_tgt_inputs)

  # Compute JAX outputs.
  jax_outputs = test_utils.apply(
      jax_layer,
      jax_vars,
      jax_layer.fprop,
      jax_ids,
      jax_paddings,
      context_p=None,
      labels=py_utils.NestedMap(
          class_ids=jax_labels,
          class_weights=jax_label_weights,
      ),
      segment_ids=jax_seg_ids,
      segment_pos=jax_seg_pos)

  # Copy the JAX vars to the TF ones.
  tf_theta = tf_layer.theta.DeepCopy()

  # GShardBuilder softmax weight uses self.vars rather than theta.
  tf_layer.vars.dec_emb.w.embedding.assign(jax_vars.softmax.embedding.w)
  tf_theta.dec_emb.w.embedding = jax_vars.softmax.embedding.w
  tf_theta.dec.final_layer_norm.w.scale = jax_vars.final_ln.scale

  jax_layer_0_var = tf.nest.map_structure(
      lambda v: jnp.squeeze(jnp.split(v, 2)[0], axis=0),
      jax_vars.transformer.repeat.sub.x_layers[0])
  tf_theta.dec.layer_000.ln.w.scale = jax_layer_0_var.layer_norm.scale
  jax_atten_var = jax_layer_0_var.self_attention
  tf_atten_var = tf_theta.dec.layer_000.dec_self_attention
  tf_atten_var.w.wk = jax_atten_var.key.w
  tf_atten_var.w.wq = jax_atten_var.query.w
  tf_atten_var.w.wv = jax_atten_var.value.w
  tf_atten_var.w.wo = jax_atten_var.post.w
  tf_atten_var.wrb.wrb = jax_atten_var.relative_bias.wrb

  jax_moe_var = jax_layer_0_var.ff_layer
  tf_theta.dec.layer_001.ln.w.scale = jax_moe_var.layer_norm.scale
  tf_theta.dec.layer_001.moe.ffw.top_2_gating.w = jax_moe_var.gate
  tf_theta.dec.layer_001.moe.moe.wi = jax_moe_var.wi_0
  tf_theta.dec.layer_001.moe.moe.wo = jax_moe_var.wo_0

  jax_layer_1_var = tf.nest.map_structure(
      lambda v: jnp.squeeze(jnp.split(v, 2)[0], axis=0),
      jax_vars.transformer.repeat.sub.x_layers[1])
  tf_theta.dec.layer_002.ln.w.scale = jax_layer_1_var.layer_norm.scale
  jax_atten_var = jax_layer_1_var.self_attention
  tf_atten_var = tf_theta.dec.layer_002.dec_self_attention
  tf_atten_var.w.wk = jax_atten_var.key.w
  tf_atten_var.w.wq = jax_atten_var.query.w
  tf_atten_var.w.wv = jax_atten_var.value.w
  tf_atten_var.w.wo = jax_atten_var.post.w
  tf_atten_var.wrb.wrb = jax_atten_var.relative_bias.wrb

  jax_ffn_var = jax_layer_1_var.ff_layer
  tf_ffn_var = tf_theta.dec.layer_003.dense_relu_dense
  tf_ffn_var.w.wi_0 = jax_ffn_var.ffn_layer1_gate.linear.w
  tf_ffn_var.w.wi_1 = jax_ffn_var.ffn_layer1.linear.w
  tf_ffn_var.w.wo = jax_ffn_var.ffn_layer2.linear.w
  tf_theta.dec.layer_003.ln.w.scale = jax_ffn_var.layer_norm.scale

  jax_layer_2_var = tf.nest.map_structure(
      lambda v: jnp.squeeze(jnp.split(v, 2)[1], axis=0),
      jax_vars.transformer.repeat.sub.x_layers[0])
  tf_theta.dec.layer_004.ln.w.scale = jax_layer_2_var.layer_norm.scale
  jax_atten_var = jax_layer_2_var.self_attention
  tf_atten_var = tf_theta.dec.layer_004.dec_self_attention
  tf_atten_var.w.wk = jax_atten_var.key.w
  tf_atten_var.w.wq = jax_atten_var.query.w
  tf_atten_var.w.wv = jax_atten_var.value.w
  tf_atten_var.w.wo = jax_atten_var.post.w
  tf_atten_var.wrb.wrb = jax_atten_var.relative_bias.wrb

  jax_moe_var = jax_layer_2_var.ff_layer
  tf_theta.dec.layer_005.ln.w.scale = jax_moe_var.layer_norm.scale
  tf_theta.dec.layer_005.moe.ffw.top_2_gating.w = jax_moe_var.gate
  tf_theta.dec.layer_005.moe.moe.wi = jax_moe_var.wi_0
  tf_theta.dec.layer_005.moe.moe.wo = jax_moe_var.wo_0

  jax_layer_3_var = tf.nest.map_structure(
      lambda v: jnp.squeeze(jnp.split(v, 2)[1], axis=0),
      jax_vars.transformer.repeat.sub.x_layers[1])
  tf_theta.dec.layer_006.ln.w.scale = jax_layer_3_var.layer_norm.scale
  jax_atten_var = jax_layer_3_var.self_attention
  tf_atten_var = tf_theta.dec.layer_006.dec_self_attention
  tf_atten_var.w.wk = jax_atten_var.key.w
  tf_atten_var.w.wq = jax_atten_var.query.w
  tf_atten_var.w.wv = jax_atten_var.value.w
  tf_atten_var.w.wo = jax_atten_var.post.w
  tf_atten_var.wrb.wrb = jax_atten_var.relative_bias.wrb

  jax_ffn_var = jax_layer_3_var.ff_layer
  tf_ffn_var = tf_theta.dec.layer_007.dense_relu_dense
  tf_ffn_var.w.wi_0 = jax_ffn_var.ffn_layer1_gate.linear.w
  tf_ffn_var.w.wi_1 = jax_ffn_var.ffn_layer1.linear.w
  tf_ffn_var.w.wo = jax_ffn_var.ffn_layer2.linear.w
  tf_theta.dec.layer_007.ln.w.scale = jax_ffn_var.layer_norm.scale

  tf_theta = test_utils.to_tf_nmap(tf_theta)

  # Compute TF outputs.
  tf_out, _ = tf_layer.FProp(tf_theta, tf_inputs)
  self.assertAllClose(
      test_utils.to_np(jax_outputs.total_loss),
      test_utils.to_np(tf_out['loss'][0]))
def fprop(self, z: JTensor, paddings: JTensor) -> NestedMap:
  """Quantizes 'z' of shape [B, T, D].

  The z_codes of padded locations are 0.

  Args:
    z: [B, T, D].
    paddings: [B, T].

  Returns:
    A NestedMap of
      - z_q: [B, T, D].
      - z_codes: [B, T, G].
      - z_onehot: [B, T, G, C].
      - loss: [], weighted sum of quantization loss and commitment loss.
      - codebook_coverage: [], a float scalar tensor between [0, 1].
      - pplx: [], pplx of quantized distribution over the codebook.
      - entropy: [], log(pplx).
  """
  p = self.params
  theta = self.local_theta()
  b, t, d = z.shape
  g, c = p.num_groups, p.num_latent_classes

  mask = 1.0 - paddings
  num_frames = jnp.sum(mask)
  z = self._apply_mask(z, mask)

  if p.normalize_latent_vector:
    z = self._l2_normalize(z, axis=-1)

  # [b * t, d], [b * t, g], [b * t, g, c]
  z_q, z_codes, z_onehot = quantize_vector(
      jnp.reshape(z, [b * t, d]), self._get_latent_embedding(theta))

  z_q = jnp.reshape(z_q, [b, t, d])
  z_codes = jnp.reshape(z_codes, [b, t, g])
  z_onehot = jnp.reshape(z_onehot, [b, t, g, c])

  # Padded locations are all 0s without any 1.
  z_q = self._apply_mask(z_q, mask)
  # [b, t, g]
  z_codes = self._apply_mask(z_codes, mask)
  # [b, t, g, c]
  z_onehot = self._apply_mask(z_onehot, mask)

  # Move z towards z_q.
  normalizer = 1e-7 + num_frames
  # [b, t, d]
  loss_c = (z - jax.lax.stop_gradient(z_q))**2
  # [b, t, d] -> [b, t] -> []
  loss_c = jnp.sum(jnp.mean(loss_c, -1)) / normalizer
  # loss_c = py_utils.check_numerics(loss_c, 'loss_c has NaN.')

  # Move z_q towards z.
  loss_z = (z_q - jax.lax.stop_gradient(z))**2
  loss_z = jnp.sum(jnp.mean(loss_z, -1)) / normalizer
  # loss_z = py_utils.check_numerics(loss_z, 'loss_z has NaN.')
  loss = loss_z + p.beta * loss_c

  # Straight-through estimator: the forward value equals z_q, but the
  # gradient flows to z, bypassing the non-differentiable quantization.
  z_q = z + jax.lax.stop_gradient(z_q - z)

  # [], []
  pplx, entropy, _ = objectives.batch_pplx_entropy_from_codes(
      z_codes, c, paddings=paddings)
  # pplx = py_utils.check_numerics(pplx, f'{p.name} perplexity NaN')

  codebook_coverage = objectives.batch_codebook_coverage(
      z_codes, c, paddings=paddings)
  codebook_num_covered_words = codebook_coverage * c**g

  return py_utils.NestedMap(
      z_q=z_q,
      z_codes=z_codes,
      z_onehot=z_onehot,
      loss=loss,
      codebook_coverage=codebook_coverage,
      codebook_num_covered_words=codebook_num_covered_words,
      pplx=pplx,
      entropy=entropy)
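# A minimal standalone sketch of the straight-through identity used above:
# the forward value equals the quantized result, while the gradient is that
# of the identity function, bypassing the non-differentiable rounding (a
# stand-in here for the codebook lookup).
def _example_straight_through():
  def quantize(z):
    z_q = jnp.round(z)  # Zero gradient almost everywhere.
    return z + jax.lax.stop_gradient(z_q - z)

  value = quantize(jnp.float32(0.7))           # -> 1.0 (quantized forward).
  grad = jax.grad(quantize)(jnp.float32(0.7))  # -> 1.0 (identity gradient).
  return value, grad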
def test_conformer_layer(self, batch_size, seq_len, kernel_size, input_dims,
                         model_dims, atten_num_heads, dropout_prob):
  # Lingvo TF layers only use dropout on FF and Attention layers.
  p = conformers.Conformer.Params().Set(
      name='jax_conformer_layer',
      input_dims=input_dims,
      conv_residual_dropout=0.0,
      atten_residual_dropout=dropout_prob,
      ffn_residual_dropout=dropout_prob,
      atten_dropout=dropout_prob,
      ffn_relu_dropout=dropout_prob,
      kernel_size=kernel_size,
      model_dims=model_dims,
      atten_num_heads=atten_num_heads)
  conformer = p.Instantiate()
  prng_key = jax.random.PRNGKey(seed=123)
  initial_vars = conformer.instantiate_variables(prng_key)
  npy_inputs = np.random.normal(
      1.0, 0.5, [batch_size, seq_len, input_dims]).astype('float32')
  inputs = jnp.asarray(npy_inputs)

  def GetPaddingfromLength(length):
    idx = np.tile(np.arange(seq_len), [batch_size, 1])
    return (idx >= np.expand_dims(length, -1)).astype('float32')

  length = np.random.randint(seq_len // 2, seq_len, (batch_size,))
  npy_paddings = GetPaddingfromLength(length).astype('float32')
  paddings = jnp.asarray(npy_paddings)

  context_p = base_layer.JaxContext.Params().Set(do_eval=True)
  output = test_utils.apply(
      conformer,
      initial_vars,
      conformer.fprop,
      inputs,
      paddings,
      context_p=context_p)

  # Test whether the TF Conformer layer returns the same output.
  # Modify initial_vars to use TF compatible params.
  tf_initial_vars = test_utils.replace_jax_conformer_layer_vars_to_tf(
      initial_vars)

  tf_p = conformer_layer.ConformerLayer.CommonParams(
      input_dim=input_dims,
      dropout_prob=dropout_prob,
      atten_num_heads=atten_num_heads,
      kernel_size=kernel_size,
      fflayer_hidden_dim=model_dims * p.ffn_dim_multiplier,
      use_relative_atten=False,
      fflayer_residual_weight=0.5).Set(name='tf_conformer')
  tf_p.trans_atten_tpl = tf_p.trans_atten_tpl.Set(hidden_dim=model_dims)

  tf_conformer = tf_p.Instantiate()
  with cluster_factory.SetEval(True):
    tf_output = tf_conformer.FProp(
        tf_initial_vars,
        py_utils.NestedMap(
            features=tf.constant(inputs, dtype=tf.float32),
            paddings=tf.constant(npy_paddings, dtype=tf.float32)))

  np_output = to_np(output)
  tf_np_output = to_np(tf_output.features)
  self.assertAllClose(tf_np_output, np_output, atol=1e-5)