Example #1
def replace_jax_transformer_ffwd_vars_to_tf(
        jax_initial_vars: NestedMap) -> NestedMap:
    """Replaces JAX TransformerFeedForward vars to TF compatible vars.

  Args:
    jax_initial_vars: JAX TransformerFeedforward layer vars.

  Returns:
    tf_initial_vars which is TF compatible.
  """
    tf_initial_vars = jax_initial_vars.copy()
    tf_initial_vars = py_utils.NestedMap()
    tf_initial_vars.fflayer = py_utils.NestedMap(
        fc=[
            py_utils.NestedMap(w=jax_initial_vars.ffn_layer1.linear.w,
                               b=jax_initial_vars.ffn_layer1.bias.b),
            py_utils.NestedMap(w=jax_initial_vars.ffn_layer2.linear.w,
                               b=jax_initial_vars.ffn_layer2.bias.b),
        ],
        dropout=[py_utils.NestedMap(),
                 py_utils.NestedMap()],
    )
    tf_initial_vars.layer_norm = py_utils.NestedMap(
        bias=jax_initial_vars.layer_norm.bias,
        scale=jax_initial_vars.layer_norm.scale)
    tf_initial_vars.residual_dropout = py_utils.NestedMap()
    tf_initial_vars.residual_droppath = py_utils.NestedMap()
    return tf_initial_vars
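
A minimal usage sketch for the converter above, assuming the same
`py_utils.NestedMap` used throughout these examples; the `jax_vars` fixture
and its shapes are illustrative only.

import numpy as np

jax_vars = py_utils.NestedMap(
    ffn_layer1=py_utils.NestedMap(
        linear=py_utils.NestedMap(w=np.zeros([8, 32], np.float32)),
        bias=py_utils.NestedMap(b=np.zeros([32], np.float32))),
    ffn_layer2=py_utils.NestedMap(
        linear=py_utils.NestedMap(w=np.zeros([32, 8], np.float32)),
        bias=py_utils.NestedMap(b=np.zeros([8], np.float32))),
    layer_norm=py_utils.NestedMap(
        scale=np.ones([8], np.float32), bias=np.zeros([8], np.float32)))
tf_vars = replace_jax_transformer_ffwd_vars_to_tf(jax_vars)
assert tf_vars.fflayer.fc[0].w.shape == (8, 32)
assert tf_vars.fflayer.fc[1].b.shape == (8,)
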
Example #2
    def test_extract_prefixed_keys_from_dataclass(self):
        @struct.dataclass
        class GlobalShardedParameterStats:
            statistics: np.ndarray  # Statistics
            preconditioners: np.ndarray  # Preconditioners
            exponents: np.ndarray  # Exponents
            index_start: int = struct.field(pytree_node=False)
            sizes: Any = struct.field(pytree_node=False)

        stats0 = GlobalShardedParameterStats(
            statistics=np.array([0], dtype=np.float32),
            preconditioners=np.array([1, 1], dtype=np.float32),
            exponents=np.array([2, 2, 2], dtype=np.float32),
            index_start=0,
            sizes=0,
        )
        # Even though `preconditioners` is assigned first here, the flattened
        # order follows the field order of `GlobalShardedParameterStats`.
        stats1 = GlobalShardedParameterStats(
            preconditioners=np.array([5, 5], dtype=np.float32),
            statistics=np.array([4], dtype=np.float32),
            exponents=np.array([6, 6, 6], dtype=np.float32),
            index_start=1,
            sizes=1,
        )

        nested_data = py_utils.NestedMap(stats0=stats0, stats1=stats1)
        nested_names = py_utils.extract_prefixed_keys_from_nested_map(
            nested_data)
        flattened_nested_names, _ = jax.tree_util.tree_flatten(nested_names)

        self.assertListEqual([
            'stats0/statistics', 'stats0/preconditioners', 'stats0/exponents',
            'stats1/statistics', 'stats1/preconditioners', 'stats1/exponents'
        ], flattened_nested_names)
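
The same ordering property can be checked directly with `jax.tree_util` on a
tiny flax `struct.dataclass`; this sketch is standalone and only assumes jax,
numpy and flax.

import jax
import numpy as np
from flax import struct

@struct.dataclass
class Pair:
  first: np.ndarray
  second: np.ndarray

leaves, _ = jax.tree_util.tree_flatten(
    Pair(second=np.array([2]), first=np.array([1])))
# leaves == [array([1]), array([2])]: field-definition order wins.
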
Example #3
 def test_feedforward_layer_no_bias(self, activation):
   p = linears.FeedForward.Params().Set(
       name='jax_ffn',
       input_dims=3,
       output_dims=20,
       has_bias=False,
       activation=activation)
   ffn = p.Instantiate()
   prng_key = jax.random.PRNGKey(seed=123)
   initial_vars = ffn.instantiate_variables(prng_key)
   npy_input = np.random.normal(1.0, 0.5,
                                [10, 10, p.input_dims]).astype('float32')
   inputs = jnp.asarray(npy_input)
   outputs = test_utils.apply(ffn, initial_vars, ffn.fprop, inputs)
   logging.info('initial_vars in ffn = %s', initial_vars)
   # Test whether the TF projection layer returns the same output.
   # Modify initial_vars to use TF-compatible params.
   tf_initial_vars = py_utils.NestedMap()
   tf_initial_vars.w = initial_vars.linear.w
   tf_initial_vars = to_tf_nmap(tf_initial_vars)
   tf_p = lingvo_layers.ProjectionLayer.Params().Set(
       name='tf_ffn',
       input_dim=p.input_dims,
       output_dim=p.output_dims,
       batch_norm=False,
       has_bias=False,
       activation=activation)
   tf_ffn = tf_p.Instantiate()
   tf_output = tf_ffn.FProp(tf_initial_vars,
                            tf.constant(inputs, dtype=tf.float32))
   np_outputs = to_np(outputs)
   tf_np_outputs = to_np(tf_output)
   self.assertAllClose(tf_np_outputs, np_outputs, atol=1e-6)
Example #4
def replace_jax_single_shard_full_softmax_vars_to_tf(
        jax_initial_vars: NestedMap) -> NestedMap:
    """Replaces JAX Single Shard Full Softmax vars to TF compatible vars.

  Args:
    jax_initial_vars: JAX ConvBNAct layer vars.

  Returns:
    tf_initial_vars which is TF compatible with ConvBNAct.
  """
    tf_initial_vars = jax_initial_vars.copy()
    tf_initial_vars.linear = py_utils.NestedMap(
        w=jax_initial_vars.logits_ffn.linear.w)
    tf_initial_vars.bias = py_utils.NestedMap(
        b=jax_initial_vars.logits_ffn.bias.b)
    del tf_initial_vars.logits_ffn
    return tf_initial_vars
Example #5
def replace_jax_attention_vars_to_tf(
        jax_initial_vars: NestedMap,
        cross_attention: Optional[bool] = False) -> NestedMap:
    """Replaces JAX attention vars to TF compatible vars.

  Args:
    jax_initial_vars: JAX attention layer vars.
    cross_attention: Whether cross attention is involved.

  Returns:
    tf_initial_vars which is TF compatible.
  """
    tf_initial_vars = jax_initial_vars.copy()
    tf_initial_vars.fflayer = jax_initial_vars.ff_layer
    tf_initial_vars.fflayer.fflayer = py_utils.NestedMap()
    is_moe = 'gate' in jax_initial_vars.ff_layer
    if is_moe:
        tf_initial_vars.fflayer.fflayer.layer_norm = py_utils.NestedMap(
            scale=jax_initial_vars.ff_layer.layer_norm.scale,
            bias=jax_initial_vars.ff_layer.layer_norm.bias)
        tf_initial_vars.fflayer.fflayer.gate = jax_initial_vars.ff_layer.gate
        tf_initial_vars.fflayer.fflayer.wi_0 = jax_initial_vars.ff_layer.wi_0
        tf_initial_vars.fflayer.fflayer.wo_0 = jax_initial_vars.ff_layer.wo_0
    else:
        tf_initial_vars.fflayer.fflayer.dropout = [1.0, 1.0]
        tf_initial_vars.fflayer.fflayer.fc = [NestedMap(), NestedMap()]
        tf_initial_vars.fflayer.fflayer.fc[0].w = (
            jax_initial_vars.ff_layer.ffn_layer1.linear.w)
        tf_initial_vars.fflayer.fflayer.fc[0].b = (
            jax_initial_vars.ff_layer.ffn_layer1.bias.b)
        tf_initial_vars.fflayer.fflayer.fc[1].w = (
            jax_initial_vars.ff_layer.ffn_layer2.linear.w)
        tf_initial_vars.fflayer.fflayer.fc[1].b = (
            jax_initial_vars.ff_layer.ffn_layer2.bias.b)
    tf_initial_vars.self_atten = NestedMap()
    tf_initial_vars.self_atten.layer_norm = jax_initial_vars.layer_norm
    tf_initial_vars.self_atten.atten = jax_initial_vars.self_attention
    tf_initial_vars.self_atten.residual_dropout = 1.0
    if cross_attention:
        tf_initial_vars.cross_atten = NestedMap()
        tf_initial_vars.cross_atten.layer_norm = jax_initial_vars.layer_norm
        tf_initial_vars.cross_atten.atten = jax_initial_vars.cross_attention
        tf_initial_vars.cross_atten.residual_dropout = 1.0
    return tf_initial_vars
Example #6
    def test_LSTMSimple(self, jax_cell_class, cifg, output_nonlinearity):
        np.random.seed(_NUMPY_RANDOM_SEED)
        inputs = py_utils.NestedMap(act=[np.random.uniform(size=(3, 2))],
                                    padding=jnp.zeros([3, 1]))
        state0 = py_utils.NestedMap(c=np.random.uniform(size=(3, 2)),
                                    m=np.random.uniform(size=(3, 2)))
        tf_inputs = py_utils.NestedMap(
            act=[tf.constant(inputs.act[0], tf.float32)],
            padding=tf.zeros([3, 1]))
        tf_state0 = py_utils.NestedMap(c=tf.constant(state0.c, tf.float32),
                                       m=tf.constant(state0.m, tf.float32))

        params = rnn_cell.LSTMCellSimple.Params().Set(
            name='lstm',
            params_init=py_utils.WeightInit.Uniform(1.24, _INIT_RANDOM_SEED),
            bias_init=py_utils.WeightInit.Uniform(1.24, _INIT_RANDOM_SEED),
            num_input_nodes=2,
            num_output_nodes=2,
            couple_input_forget_gates=cifg,
            enable_lstm_bias=True,
            output_nonlinearity=output_nonlinearity)
        lstm = rnn_cell.LSTMCellSimple(params)
        res, _ = lstm.FPropDefaultTheta(tf_state0, tf_inputs)
        m_expected = res.m.numpy()
        c_expected = res.c.numpy()

        p = jax_cell_class.Params().Set(
            num_input_nodes=2,
            num_output_nodes=2,
            name='lstm',
            output_nonlinearity=output_nonlinearity,
        )
        model = p.Instantiate()

        theta = model.instantiate_variables(jax.random.PRNGKey(5678))
        theta.wm = lstm.vars['wm'].numpy()
        theta.b = lstm.vars['b'].numpy()

        output, _ = test_utils.apply(model, model.vars_to_flax_vars(theta),
                                     model.fprop, state0, inputs)
        self.assertAllClose(m_expected, output.m)
        self.assertAllClose(c_expected, output.c)
Example #7
  def _parse_record(self, record) -> NestedMap:
    """Reads and parses a single record."""
    p = self.params
    name_to_features = {
        'input_ids':
            tf.io.FixedLenFeature([p.max_sequence_length], tf.int64),
        'input_mask':
            tf.io.FixedLenFeature([p.max_sequence_length], tf.int64),
        'masked_lm_positions':
            tf.io.FixedLenFeature([p.max_predictions_per_seq], tf.int64),
        'masked_lm_ids':
            tf.io.FixedLenFeature([p.max_predictions_per_seq], tf.int64),
        'masked_lm_weights':
            tf.io.FixedLenFeature([p.max_predictions_per_seq], tf.float32),
    }
    example = tf.io.parse_single_example(record, name_to_features)
    mask_length = tf.cast(
        tf.reduce_sum(example['masked_lm_weights']), dtype=tf.int32)
    masked_lm_positions = tf.slice(example['masked_lm_positions'], [0],
                                   [mask_length])
    masked_lm_ids = tf.cast(
        tf.slice(example['masked_lm_ids'], [0], [mask_length]), dtype=tf.int32)
    ret = py_utils.NestedMap()
    ret.masked_ids = tf.cast(example['input_ids'], dtype=tf.int32)
    # Get back non-masked, original ids.
    ret.labels = tf.tensor_scatter_nd_update(
        tensor=ret.masked_ids,
        indices=tf.reshape(masked_lm_positions, [-1, 1]),
        updates=masked_lm_ids)
    ret.masked_pos = tf.tensor_scatter_nd_update(
        tensor=tf.zeros_like(ret.masked_ids, dtype=tf.float32),
        indices=tf.reshape(masked_lm_positions, [-1, 1]),
        updates=tf.ones_like(masked_lm_ids, dtype=tf.float32))
    ret.segment_ids = tf.cast(example['input_mask'], dtype=tf.float32)

    first_eos_idx = tf.where(tf.math.equal(ret.labels, p.eos_token_id))[0][0]

    def remove_first_eos(x):
      # We remove the element at position `first_eos_idx`, and pad with 0
      # to keep length unchanged.
      zero = tf.constant(0, shape=(1,), dtype=x.dtype)
      return tf.concat([x[:first_eos_idx], x[first_eos_idx + 1:], zero], axis=0)

    ret = ret.Transform(remove_first_eos)
    ret.paddings = 1.0 - ret.segment_ids
    pos = tf.cast(tf.range(p.max_sequence_length), dtype=tf.float32)
    ret.segment_pos = tf.cast(ret.segment_ids * pos, dtype=tf.int32)

    if p.remask:
      new_masked_ids, new_masked_pos = self.mlm.FProp(None, ret.labels,
                                                      ret.paddings)
      ret.masked_ids = new_masked_ids
      ret.masked_pos = new_masked_pos
    return ret
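
For reference, a record that `_parse_record` can consume could be built as in
this sketch; the feature lengths are hypothetical (max_sequence_length=5,
max_predictions_per_seq=2), and 2 stands in for `p.eos_token_id`.

import tensorflow as tf

def _int64_feature(values):
  return tf.train.Feature(int64_list=tf.train.Int64List(value=values))

def _float_feature(values):
  return tf.train.Feature(float_list=tf.train.FloatList(value=values))

example = tf.train.Example(features=tf.train.Features(feature={
    'input_ids': _int64_feature([5, 6, 7, 2, 0]),
    'input_mask': _int64_feature([1, 1, 1, 1, 0]),
    'masked_lm_positions': _int64_feature([1, 0]),
    'masked_lm_ids': _int64_feature([9, 0]),
    'masked_lm_weights': _float_feature([1.0, 0.0]),
}))
record = example.SerializeToString()
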
Example #8
        def _scan_fn(layer_in, layer_vars):
            jax_context = base_layer.cur_jax_context()
            flax_variables = self.sub.vars_to_flax_vars(layer_vars)
            # Properly set up the scope.
            jax_context.bind(self.sub, flax_variables,
                             [base_layer.SCOPE_AUX_LOSS])

            layer_out, extra = fprop_fn(self.sub, layer_in.carry, *args,
                                        **kwargs)
            tf.nest.assert_same_structure(layer_in.carry, layer_out)
            return NestedMap(carry=layer_out), py_utils.NestedMap(extra=extra)
Example #9
 def _all_paddings_batch(self) -> NestedMap:
   p = self.params
   shape = [p.batch_size, p.max_sequence_length]
   ret = py_utils.NestedMap()
   ret.labels = tf.zeros(shape, dtype=tf.int32)
   ret.masked_ids = ret.labels
   ret.segment_pos = ret.labels
   ret.masked_pos = tf.zeros(shape, dtype=tf.float32)
   ret.segment_ids = ret.masked_pos
   ret.paddings = 1.0 - ret.segment_ids
   return ret
Example #10
    def extend_step(self, extend_fn, cached_states: NestedMap,
                    step_inputs: NestedJTensor, *args: Any,
                    **kwargs: Any) -> Any:
        """Extends decoder states by one step.

    extend_fn should have the following signature.

    extended_states, step_out = extend_fn(self.sub, states, step_input,
                                          *args, **kwargs)
    extended_states should have the same structure as states
    step_out should have the same structure as step_input

    Args:
      extend_fn: fn to extend cached_states for one step. It should be of the
        expected signature as described above.
      cached_states: The combined states for all sub-layers.
      step_inputs: Input to the bottom decoder layer.
      *args: Additional positional input.
      **kwargs: Additional keyword input.

    Returns:
      new_states, top_decoder_out, where new_states is the updated decoder
      states, and top_decoder_out is the output from the top decoder layer.
    """
        # Wrap inputs in a NestedMap to conform to recurrent.scan interface.
        step_inputs_mp = NestedMap(carry=step_inputs)

        def _scan_fn(layer_in, vars_and_states):
            layer_vars = vars_and_states.layer_vars
            layer_states = vars_and_states.layer_states
            # Properly set up the context.
            jax_context = base_layer.cur_jax_context()
            flax_variables = self.sub.vars_to_flax_vars(layer_vars)
            jax_context.bind(self.sub, flax_variables,
                             [base_layer.SCOPE_AUX_LOSS])

            extended_states, layer_out = extend_fn(self.sub, layer_states,
                                                   layer_in.carry, *args,
                                                   **kwargs)
            tf.nest.assert_same_structure(layer_in.carry, layer_out)
            tf.nest.assert_same_structure(extended_states, layer_states)
            return NestedMap(carry=layer_out), extended_states

        vars_and_states = py_utils.NestedMap(layer_vars=self.sub.local_theta(),
                                             layer_states=cached_states)

        final_out, new_states, summaries = recurrent.scan(step_inputs_mp,
                                                          vars_and_states,
                                                          _scan_fn,
                                                          root_layer=self)
        # Forward summaries to the outer context.
        self._forward_summary(summaries)
        return new_states, final_out.carry
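
A sketch of an `extend_fn` that satisfies the contract documented above; the
`extend_step` method on the wrapped sub-layer is hypothetical and stands in
for whatever per-step decode function that layer provides.

def my_extend_fn(layer, states, step_input, *args, **kwargs):
  # Must return (extended_states, step_out), whose structures match
  # (states, step_input) respectively.
  return layer.extend_step(states, step_input, *args, **kwargs)

# new_states, top_out = wrapper.extend_step(my_extend_fn, cached_states,
#                                           step_inputs)
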
Example #11
    def compute_loss(self, predictions: NestedMap,
                     input_batch: NestedMap) -> Tuple[Metrics, Dict[str, Any]]:
        """Computes the loss and other metrics for the given predictions.

    Args:
      predictions: The output of `compute_predictions`.
      input_batch: A `.NestedMap` object containing input tensors to this tower.

    Returns:
      - A dict or NestedMap containing str keys and (metric, weight) pairs as
        values, where one of the entries is expected to corresponds to the loss.
      - A dict containing arbitrary tensors describing something about each
        training example, where the first dimension of each tensor is the batch
        index.
    """
        labels = input_batch.labels
        num_tokens = jnp.sum(1.0 - input_batch.paddings.astype(jnp.float32))
        num_seqs = jnp.sum(
            jnp.amax(input_batch.segment_ids.astype(jnp.float32), axis=1))
        weights = predictions.augmented_pos.astype(jnp.float32)
        predicted_labels = predictions.per_example_argmax.astype(labels.dtype)
        num_preds = predictions.total_weight.astype(jnp.float32)
        mean_acc = jnp.sum(
            (labels == predicted_labels) * weights) / jnp.maximum(
                num_preds, 1)
        metric_weight = jnp.array(num_preds, predictions.avg_xent.dtype)
        metrics = py_utils.NestedMap(
            total_loss=(predictions.total_loss, metric_weight),
            avg_xent=(predictions.avg_xent, metric_weight),
            aux_loss=(predictions.aux_loss, metric_weight),
            log_pplx=(predictions.avg_xent, metric_weight),
            fraction_of_correct_preds=(mean_acc,
                                       jnp.array(num_preds, mean_acc.dtype)),
            num_predictions=(num_preds, jnp.array(1.0, num_preds.dtype)),
            num_tokens=(num_tokens, jnp.array(1.0, num_tokens.dtype)),
            num_seqs=(num_seqs, jnp.array(1.0, num_seqs.dtype)),
        )

        per_example_output = py_utils.NestedMap()
        return metrics, per_example_output
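
Every metric above follows the (value, weight) convention; a weighted mean
over several such pairs can be computed as in this small sketch (the
`aggregate` helper is hypothetical, not a library function).

import jax.numpy as jnp

def aggregate(pairs):
  values = jnp.stack([v for v, _ in pairs])
  weights = jnp.stack([w for _, w in pairs])
  return jnp.sum(values * weights) / jnp.maximum(jnp.sum(weights), 1e-8)

# aggregate([(jnp.asarray(0.5), jnp.asarray(2.0)),
#            (jnp.asarray(1.0), jnp.asarray(1.0))])  # == 2/3
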
Example #12
 def _to_nested_map(self, text) -> py_utils.NestedMap:
   ids, labels, paddings = self.tokenizer.StringsToIds(
       text, max_length=self.params.max_sequence_length)
   # Unfortunately some tokenizers don't return the correct paddings. We
   # recompute them by looking at where the labels sequence terminates.
   indices = tf.where(tf.math.equal(labels, self.tokenizer.eos_id))
   lengths = tf.math.segment_min(indices[:, 1], indices[:, 0]) + 1
   new_paddings = tf.cast(
       1.0 - tf.sequence_mask(
           lengths,
           maxlen=self.params.max_sequence_length,
           dtype=paddings.dtype),
       dtype=paddings.dtype)
   return py_utils.NestedMap(ids=ids, labels=labels, paddings=new_paddings)
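
A toy numeric check (pure TF) of the padding recomputation above: with
eos_id == 2, find each row's first EOS and pad everything after it. Like the
method above, this assumes every row contains at least one EOS.

import tensorflow as tf

labels = tf.constant([[5, 2, 0, 0], [7, 8, 2, 2]])
indices = tf.where(tf.math.equal(labels, 2))  # [[0, 1], [1, 2], [1, 3]]
lengths = tf.math.segment_min(indices[:, 1], indices[:, 0]) + 1  # [2, 3]
paddings = 1.0 - tf.sequence_mask(lengths, maxlen=4, dtype=tf.float32)
# paddings == [[0., 0., 1., 1.], [0., 0., 0., 1.]]
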
Example #13
def replace_jax_light_conv_vars_to_tf(
        jax_initial_vars: NestedMap) -> NestedMap:
    """Replace the JAX LightConv vars to TF compatible vars.

  Args:
    jax_initial_vars: JAX LightConv vars.

  Returns:
    tf_initial_vars which is TF compatible with LightConv.
  """
    tf_initial_vars = py_utils.NestedMap()

    tf_initial_vars.ln = py_utils.NestedMap()
    tf_initial_vars.ln.bias = jax_initial_vars.ln.bias
    tf_initial_vars.ln.scale = jax_initial_vars.ln.scale

    tf_initial_vars.norm = py_utils.NestedMap()
    tf_initial_vars.norm.beta = jax_initial_vars.conv_norm.beta
    tf_initial_vars.norm.gamma = jax_initial_vars.conv_norm.gamma
    tf_initial_vars.norm.moving_mean = jax_initial_vars.conv_norm.moving_mean
    tf_initial_vars.norm.moving_variance = (
        jax_initial_vars.conv_norm.moving_variance)

    tf_initial_vars.dropout = [py_utils.NestedMap(), py_utils.NestedMap()]

    tf_initial_vars.depthwise_conv1d = py_utils.NestedMap()
    tf_initial_vars.depthwise_conv1d.w = np.expand_dims(
        jax_initial_vars.depthwise_conv1d.w, axis=-1)

    tf_initial_vars.linear_end = py_utils.NestedMap()
    tf_initial_vars.linear_end.w = jax_initial_vars.linear_end.linear.w
    tf_initial_vars.linear_end.b = jax_initial_vars.linear_end.bias.b

    tf_initial_vars.linear_start = py_utils.NestedMap()
    tf_initial_vars.linear_start.w = np.concatenate([
        jax_initial_vars.linear_start_gated.linear.w,
        jax_initial_vars.linear_start_act.linear.w
    ], axis=-1)
    tf_initial_vars.linear_start.b = np.concatenate([
        jax_initial_vars.linear_start_gated.bias.b,
        jax_initial_vars.linear_start_act.bias.b
    ], axis=-1)

    tf_initial_vars = to_tf_nmap(tf_initial_vars)
    return tf_initial_vars
Example #14
 def _InputBatch(self):
   p = self.params
   targets = tf.ones([p.batch_size, p.seq_len], dtype=tf.int32)
   input_batch = py_utils.NestedMap()
   # With constant targets, this equals tf.roll(targets, 1, axis=1).
   input_batch.ids = targets
   input_batch.paddings = tf.zeros_like(targets)
   input_batch.weights = tf.ones_like(targets)
   input_batch.labels = targets
   # segment_id == 0 marks padded tokens.
   # E.g., two segments packed into one sentence with paddings:
   # segment_ids = 1, 1, 1, 1, 2, 2, 2, 2, 0, 0
   # segment_pos = 0, 1, 2, 3, 0, 1, 2, 3, 0, 0
   input_batch.segment_ids = targets
   input_batch.segment_pos = tf.tile(
       tf.range(0, p.seq_len)[tf.newaxis, :], [p.batch_size, 1])
   return input_batch
Example #15
    def test_transformer_bert(self, trainable_position_emb):
        seq_len = 512
        if trainable_position_emb:
            position_emb_tpl = embedding_softmax.TrainablePositionalEmbedding.Params(
            )
            position_emb_tpl.max_seq_length = seq_len
        else:
            position_emb_tpl = embedding_softmax.PositionalEmbedding.Params()
        p = transformer_models.TransformerLm.Params().Set(
            name='bert_lm',
            model_dims=32,
            vocab_size=52,
            position_emb_tpl=position_emb_tpl)
        stacked_transformer_tpl = p.stacked_transformer_tpl
        stacked_transformer_tpl.model_dims = 32
        stacked_transformer_tpl.hidden_dims = 4 * 32
        stacked_transformer_tpl.num_heads = 4
        stacked_transformer_tpl.num_layers = 1
        p.softmax_tpl.scale_sqrt_depth = True
        batch_size = 8
        bert_lm = p.Instantiate()
        prng_key = jax.random.PRNGKey(seed=123)
        initial_vars = bert_lm.instantiate_variables(prng_key)
        input_ids = jax.random.randint(jax.random.PRNGKey(1234),
                                       [batch_size, seq_len], 0, 51)
        input_paddings = jnp.zeros([batch_size, seq_len])
        input_weights = jnp.ones([batch_size, seq_len])
        input_segment_ids = jnp.ones([batch_size, seq_len])
        input_segment_pos = jnp.tile(
            jnp.arange(0, seq_len)[jnp.newaxis, :], [batch_size, 1])

        labels = py_utils.NestedMap()
        labels.class_ids = input_ids
        labels.class_weights = input_weights
        outputs = test_utils.apply(bert_lm,
                                   initial_vars,
                                   bert_lm.fprop,
                                   input_ids,
                                   input_paddings,
                                   labels=labels,
                                   segment_ids=input_segment_ids,
                                   segment_pos=input_segment_pos)
        logging.info('outputs: %s', outputs)
Example #16
def replace_jax_conformer_layer_vars_to_tf(
        jax_initial_vars: NestedMap) -> NestedMap:
    """Replace the JAX conformer layer vars to TF compatible vars.

  Args:
    jax_initial_vars: JAX conformer layer vars.

  Returns:
    tf_initial_vars which is TF compatible with ConformerLayer.
  """

    tf_initial_vars = py_utils.NestedMap()

    tf_initial_vars.lconv = replace_jax_light_conv_vars_to_tf(
        jax_initial_vars.lconv)

    tf_initial_vars.final_ln = py_utils.NestedMap()
    tf_initial_vars.final_ln.bias = jax_initial_vars.final_ln.bias
    tf_initial_vars.final_ln.scale = jax_initial_vars.final_ln.scale

    tf_initial_vars.fflayer_start = py_utils.NestedMap()
    tf_initial_vars.fflayer_start.residual_dropout = (
        jax_initial_vars.fflayer_start.residual_dropout)
    tf_initial_vars.fflayer_start.layer_norm = (
        jax_initial_vars.fflayer_start.layer_norm)
    tf_initial_vars.fflayer_start.fflayer = py_utils.NestedMap()
    tf_initial_vars.fflayer_start.fflayer.dropout = [
        jax_initial_vars.fflayer_start.relu_dropout, {}
    ]
    tf_initial_vars.fflayer_start.fflayer.fc = [
        py_utils.NestedMap(), py_utils.NestedMap()
    ]
    tf_initial_vars.fflayer_start.fflayer.fc[0].w = (
        jax_initial_vars.fflayer_start.ffn_layer1.linear.w)
    tf_initial_vars.fflayer_start.fflayer.fc[0].b = (
        jax_initial_vars.fflayer_start.ffn_layer1.bias.b)
    tf_initial_vars.fflayer_start.fflayer.fc[1].w = (
        jax_initial_vars.fflayer_start.ffn_layer2.linear.w)
    tf_initial_vars.fflayer_start.fflayer.fc[1].b = (
        jax_initial_vars.fflayer_start.ffn_layer2.bias.b)

    tf_initial_vars.fflayer_end = py_utils.NestedMap()
    tf_initial_vars.fflayer_end.layer_norm = (
        jax_initial_vars.fflayer_end.layer_norm)
    tf_initial_vars.fflayer_end.residual_dropout = (
        jax_initial_vars.fflayer_end.residual_dropout)
    tf_initial_vars.fflayer_end.fflayer = py_utils.NestedMap()
    tf_initial_vars.fflayer_end.fflayer.dropout = [
        jax_initial_vars.fflayer_end.relu_dropout, {}
    ]
    tf_initial_vars.fflayer_end.fflayer.fc = [
        py_utils.NestedMap(), py_utils.NestedMap()
    ]
    tf_initial_vars.fflayer_end.fflayer.fc[0].w = (
        jax_initial_vars.fflayer_end.ffn_layer1.linear.w)
    tf_initial_vars.fflayer_end.fflayer.fc[0].b = (
        jax_initial_vars.fflayer_end.ffn_layer1.bias.b)
    tf_initial_vars.fflayer_end.fflayer.fc[1].w = (
        jax_initial_vars.fflayer_end.ffn_layer2.linear.w)
    tf_initial_vars.fflayer_end.fflayer.fc[1].b = (
        jax_initial_vars.fflayer_end.ffn_layer2.bias.b)

    tf_initial_vars.trans_atten = py_utils.NestedMap()
    tf_initial_vars.trans_atten.layer_norm = jax_initial_vars.trans_atten.norm
    tf_initial_vars.trans_atten.residual_dropout = (
        jax_initial_vars.trans_atten.residual_dropout)
    tf_initial_vars.trans_atten.atten = jax_initial_vars.trans_atten.self_atten
    tf_initial_vars = to_tf_nmap(tf_initial_vars)
    return tf_initial_vars
Example #17
    def test_stacked_transformer_layer(self, mask_self_attention, packed_input,
                                       cross_attention):
        p = transformers.StackedTransformer.Params().Set(
            name='jax_stacked_transformer_layer',
            model_dims=16,
            hidden_dims=64,
            num_heads=8,
            mask_self_attention=mask_self_attention,
            num_layers=4,
            packed_input=packed_input,
            cross_attention=cross_attention)
        seq_len = np.random.randint(10, 32)
        batch_size = 10
        stacked_transformer_layer = p.Instantiate()
        prng_key = jax.random.PRNGKey(seed=123)
        initial_vars = stacked_transformer_layer.instantiate_variables(
            prng_key)

        # Test conversion between vars and flax vars.
        pax_vars = stacked_transformer_layer.vars
        flax_vars = stacked_transformer_layer.flax_vars
        tf.nest.assert_same_structure(
            pax_vars, stacked_transformer_layer.flax_vars_to_vars(flax_vars))
        tf.nest.assert_same_structure(
            flax_vars, stacked_transformer_layer.vars_to_flax_vars(pax_vars))

        npy_inputs = np.random.normal(
            1.0, 0.5, [batch_size, seq_len, p.model_dims]).astype('float32')
        inputs = jnp.asarray(npy_inputs)
        npy_paddings = np.random.randint(
            0, 1, [batch_size, seq_len]).astype('float32')
        paddings = jnp.asarray(npy_paddings)
        segment_mask = None
        tf_segment_mask = None
        if packed_input:
            segment_ids = np.random.randint(0, 3, [batch_size, seq_len])
            segment_mask = attentions.segment_mask(segment_ids,
                                                   dtype=np.float32)
            if mask_self_attention:
                tf_segment_mask = batch_major_attention.CausalSegmentMask(
                    segment_ids, tf.float32)
            else:
                tf_segment_mask = batch_major_attention.SegmentMask(
                    segment_ids, segment_ids)

        cross_inputs = None
        cross_paddings = None
        cross_segment_mask = None
        tf_cross_inputs = None
        tf_cross_paddings = None
        tf_cross_segment_mask = None
        if cross_attention:
            cross_seq_len = np.random.randint(10, 64)
            npy_cross_inputs = np.random.normal(
                1.0, 0.5,
                [batch_size, cross_seq_len, p.model_dims]).astype('float32')
            cross_inputs = jnp.asarray(npy_cross_inputs)
            tf_cross_inputs = tf.constant(npy_cross_inputs, dtype=tf.float32)
            npy_cross_paddings = np.random.randint(
                0, 1, [batch_size, cross_seq_len]).astype('float32')
            cross_paddings = jnp.asarray(npy_cross_paddings)
            tf_cross_paddings = tf.constant(npy_cross_paddings,
                                            dtype=tf.float32)
            if packed_input:
                source_segment_ids = np.random.randint(
                    0, 3, [batch_size, cross_seq_len])
                cross_segment_mask = attentions.segment_mask(
                    segment_ids, source_segment_ids, dtype=np.float32)
                tf_cross_segment_mask = batch_major_attention.SegmentMask(
                    segment_ids, source_segment_ids)

        outputs = test_utils.apply(stacked_transformer_layer,
                                   initial_vars,
                                   stacked_transformer_layer.fprop,
                                   inputs,
                                   paddings,
                                   context_p=None,
                                   segment_mask=segment_mask,
                                   cross_inputs=cross_inputs,
                                   cross_paddings=cross_paddings,
                                   cross_segment_mask=cross_segment_mask)
        logging.info('initial_vars in transformer layer = %s', initial_vars)

        # Test whether the TF Transformer layer returns the same output.
        # Modify initial_vars to use TF-compatible params.
        tf_initial_vars = py_utils.NestedMap()
        tf_initial_vars.x_layers = []
        for jax_initial_vars in initial_vars.x_layers:
            tf_layer_vars = test_utils.replace_jax_attention_vars_to_tf(
                jax_initial_vars, cross_attention)
            tf_initial_vars.x_layers.append(tf_layer_vars)
        tf_initial_vars = test_utils.to_tf_nmap(tf_initial_vars)
        logging.info('tf_initial_vars in transformer layer = %s',
                     tf_initial_vars)
        tf_p = batch_major_attention.StackedTransformerLayers.Params().Set(
            name='tf_transformer_layer',
            mdl_dim=p.model_dims,
            hidden_dim=p.hidden_dims,
            num_atten_heads=p.num_heads,
            mask_self_atten=mask_self_attention,
            num_layers=p.num_layers,
            packed_input=packed_input,
            has_aux_atten=cross_attention)
        tf_p.transformer_layer_params_tpl.tr_fflayer_tpl.fflayer_tpl.batch_norm = (
            False)
        tf_p.transformer_layer_params_tpl.tr_fflayer_tpl.fflayer_tpl.has_bias = True
        tf_stacked_transformer_layer = tf_p.Instantiate()
        tf_output, _ = tf_stacked_transformer_layer.FProp(
            tf_initial_vars,
            test_utils.to_tf_nmap(npy_inputs),
            paddings=test_utils.to_tf_nmap(npy_paddings),
            segment_mask=test_utils.to_tf_nmap(tf_segment_mask),
            aux_vec=test_utils.to_tf_nmap(tf_cross_inputs),
            aux_paddings=test_utils.to_tf_nmap(tf_cross_paddings),
            aux_segment_mask=test_utils.to_tf_nmap(tf_cross_segment_mask))
        np_outputs = test_utils.to_np(outputs)
        tf_np_outputs = test_utils.to_np(tf_output)
        self.assertAllClose(tf_np_outputs, np_outputs, atol=1e-5)
Example #18
 def test_single_sharded_softmax_layer(self, soft_cap_logits, use_class_ids,
                                       use_class_probabilities,
                                       label_smoothing_prob):
     if use_class_ids:
         class_ids = np.random.randint(0, 50, [8, 10, 1])
     else:
         class_ids = None
     if use_class_probabilities:
         class_probabilities = np.random.normal(1.5, 2.0, [8, 10, 50])
     else:
         class_probabilities = None
     p = embedding_softmax.SingleShardFullSoftmax.Params().Set(
         name='jax_softmax',
         num_classes=50,
         input_dims=40,
         soft_cap_logits=soft_cap_logits,
         label_smoothing_prob=label_smoothing_prob)
     softmax_layer = p.Instantiate()
     prng_key = jax.random.PRNGKey(seed=1234)
     initial_vars = softmax_layer.instantiate_variables(prng_key)
     npy_input = np.random.normal(1.5, 2.0, [8, 10, p.input_dims])
     inputs = jnp.asarray(npy_input)
     class_weights = np.random.normal(1.5, 2.0, [8, 10, 1])
     if class_probabilities is not None:
         class_probabilities /= np.sum(class_probabilities,
                                       axis=-1,
                                       keepdims=True)
     logits = test_utils.apply(softmax_layer, initial_vars,
                               softmax_layer.get_logits, inputs)
     outputs = test_utils.apply(softmax_layer,
                                initial_vars,
                                softmax_layer.fprop,
                                inputs,
                                class_weights,
                                class_ids=class_ids,
                                class_probabilities=class_probabilities)
      # Test whether the TF softmax layer returns the same output.
      # Modify initial_vars to use TF-compatible params.
     tf_initial_vars = initial_vars
     tf_initial_vars.linear = py_utils.NestedMap()
     tf_initial_vars.linear.w = initial_vars.logits_ffn.linear.w
     tf_initial_vars.bias = py_utils.NestedMap()
     tf_initial_vars.bias.b = initial_vars.logits_ffn.bias.b
     tf_p = lingvo_layers.SingleShardFullSoftmax.Params().Set(
         name='tf_softmax',
         num_classes=p.num_classes,
         input_dim=p.input_dims,
         logits_soft_max=soft_cap_logits)
     tf_softmax_layer = tf_p.Instantiate()
     tf_logits = tf_softmax_layer.Logits(
         tf_initial_vars, tf.constant(inputs, dtype=tf.float32))
     if use_class_ids and label_smoothing_prob > 0:
         class_probabilities = np.zeros([8, 10, 50])
         index = np.indices([8, 10])
         class_probabilities[index[0], index[1],
                             np.squeeze(class_ids, 2)] = 1
         class_probabilities = (
             class_probabilities * (1 - label_smoothing_prob) +
             (1 - class_probabilities) * label_smoothing_prob /
             (p.num_classes - 1))
         class_ids = None
     tf_output = tf_softmax_layer.FProp(
         tf_initial_vars,
         tf.constant(inputs, dtype=tf.float32),
         class_weights,
         class_ids=class_ids,
         class_probabilities=class_probabilities)
      # Check all entries in the NestedMap and ensure they match TF.
     np_get_logits = to_np(logits)
     tf_np_get_logits = to_np(tf_logits)
     self.assertAllClose(np_get_logits, tf_np_get_logits, atol=1e-6)
     # Note: The argmax-related values are very sensitive to numerical errors.
     for k in outputs.keys():
         self.assertAllClose(to_np(outputs[k]),
                             to_np(tf_output[k]),
                             atol=1e-6)
Example #19
    def test_single_sharded_shared_embedding_softmax_layer(
            self, soft_cap_logits, lookup_style, scale_sqrt_depth):
        class_ids = np.random.randint(1, 50, [8, 10, 1])
        p = embedding_softmax.SingleShardSharedEmbeddingSoftmax.Params().Set(
            name='jax_softmax',
            num_classes=50,
            input_dims=40,
            soft_cap_logits=soft_cap_logits,
            lookup_style=lookup_style,
            scale_sqrt_depth=scale_sqrt_depth)
        softmax_layer = p.Instantiate()
        prng_key = jax.random.PRNGKey(seed=123)
        initial_vars = softmax_layer.instantiate_variables(prng_key)
        npy_input = np.random.normal(1.5, 2.0, [8, 10, p.input_dims])
        inputs = jnp.asarray(npy_input)
        class_weights = np.random.normal(1.5, 2.0, [8, 10, 1])
        outputs = test_utils.apply(softmax_layer,
                                   initial_vars,
                                   softmax_layer.fprop,
                                   inputs,
                                   class_weights,
                                   class_ids=class_ids)
        ids = np.squeeze(class_ids, axis=-1)
        emb_lookup_outputs = test_utils.apply(softmax_layer,
                                              initial_vars,
                                              softmax_layer.emb_lookup,
                                              ids=jnp.asarray(ids))
        # Test whether the TF softmax layer returns the same output.
        # Modify initial_vars to use TF-compatible params.
        tf_initial_vars = initial_vars
        tf_initial_vars.linear = py_utils.NestedMap()
        tf_initial_vars.linear.w = initial_vars.logits_ffn.linear.w
        tf_initial_vars.bias = py_utils.NestedMap()
        tf_initial_vars.bias.b = initial_vars.logits_ffn.bias.b
        tf_p = lingvo_layers.SingleShardSharedEmbeddingSoftmax.Params().Set(
            name='tf_softmax',
            num_classes=p.num_classes,
            input_dim=p.input_dims,
            vocab_size=p.num_classes,
            embedding_dim=p.input_dims,
            logits_soft_max=soft_cap_logits,
            scale_sqrt_depth=scale_sqrt_depth)
        tf_softmax_layer = tf_p.Instantiate()
        tf_output = tf_softmax_layer.FProp(tf_initial_vars,
                                           tf.constant(inputs,
                                                       dtype=tf.float32),
                                           class_weights,
                                           class_ids=class_ids)
        tf_emb_lookup_output = tf_softmax_layer.EmbLookup(tf_initial_vars,
                                                          ids=tf.constant(ids))

        # Check all entries in the NestedMap and ensure they match TF.
        np_logits = to_np(outputs.logits)
        tf_np_logits = to_np(tf_output.logits)
        self.assertAllClose(np_logits, tf_np_logits, atol=1e-6)
        for k in outputs.keys():
            self.assertAllClose(to_np(outputs[k]),
                                to_np(tf_output[k]),
                                atol=1e-6)
        np_emb_lookup_output = to_np(emb_lookup_outputs)
        tf_np_emb_lookup_output = to_np(tf_emb_lookup_output)
        self.assertAllClose(tf_np_emb_lookup_output,
                            np_emb_lookup_output,
                            atol=1e-6)
Example #20
 def _process(source_id, record):
     del source_id
     num = tf.strings.to_number(record, tf.int32)
     if not tf_py_utils.use_tpu():
         num = num * num
     return py_utils.NestedMap(num=num), 1
Example #21
    def test_repeated_stacked_xformer_layer(self, mask_self_attention,
                                            packed_input, cross_attention):
        model_dims = 16
        p1 = transformers.StackedTransformer.Params().Set(
            name='jax_stacked_transformer_layer',
            model_dims=model_dims,
            hidden_dims=64,
            num_heads=8,
            mask_self_attention=mask_self_attention,
            num_layers=4,
            packed_input=packed_input,
            cross_attention=cross_attention)
        p1_one_layer = p1.Copy()
        p1_one_layer.num_layers = 1
        p2 = transformers.StackedTransformerRepeated.Params().Set(
            name='jax_stacked_transformer_layer_repeated',
            block=p1_one_layer,
            x_times=p1.num_layers)
        seq_len = np.random.randint(10, 32)
        batch_size = 10
        stacked_transformer_layer = p1.Instantiate()
        repeated_transformer_layer = p2.Instantiate()
        prng_key = jax.random.PRNGKey(seed=123)

        initial_vars = stacked_transformer_layer.instantiate_variables(
            prng_key)
        repeated_transformer_layer.instantiate_variable_configs()

        def _stack_vars(*args):
            args = [x[jnp.newaxis, :] for x in args]
            return jnp.vstack(args)

        stacked_vars = tf.nest.map_structure(_stack_vars,
                                             *initial_vars.x_layers)
        repeated_vars = py_utils.NestedMap(repeat=py_utils.NestedMap(
            sub=py_utils.NestedMap(x_layers=[stacked_vars])))

        tf.nest.assert_same_structure(
            repeated_vars,
            repeated_transformer_layer.instantiate_variables(prng_key))

        npy_inputs = np.random.normal(
            1.0, 0.5, [batch_size, seq_len, model_dims]).astype('float32')
        inputs = jnp.asarray(npy_inputs)
        npy_paddings = np.random.randint(
            0, 1, [batch_size, seq_len]).astype('float32')
        paddings = jnp.asarray(npy_paddings)
        segment_mask = None
        if packed_input:
            segment_ids = np.random.randint(0, 3, [batch_size, seq_len])
            segment_mask = attentions.segment_mask(segment_ids,
                                                   dtype=np.float32)

        cross_inputs = None
        cross_paddings = None
        cross_segment_mask = None
        if cross_attention:
            cross_seq_len = np.random.randint(10, 64)
            npy_cross_inputs = np.random.normal(
                1.0, 0.5,
                [batch_size, cross_seq_len, model_dims]).astype('float32')
            cross_inputs = jnp.asarray(npy_cross_inputs)
            npy_cross_paddings = np.random.randint(
                0, 1, [batch_size, cross_seq_len]).astype('float32')
            cross_paddings = jnp.asarray(npy_cross_paddings)
            if packed_input:
                source_segment_ids = np.random.randint(
                    0, 3, [batch_size, cross_seq_len])
                cross_segment_mask = attentions.segment_mask(
                    segment_ids, source_segment_ids, dtype=np.float32)

        outputs = test_utils.apply(stacked_transformer_layer,
                                   initial_vars,
                                   stacked_transformer_layer.fprop,
                                   inputs,
                                   paddings,
                                   context_p=None,
                                   segment_mask=segment_mask,
                                   cross_inputs=cross_inputs,
                                   cross_paddings=cross_paddings,
                                   cross_segment_mask=cross_segment_mask)

        outputs_repeated = test_utils.apply(
            repeated_transformer_layer,
            repeated_vars,
            repeated_transformer_layer.fprop,
            inputs,
            paddings,
            context_p=None,
            segment_mask=segment_mask,
            cross_inputs=cross_inputs,
            cross_paddings=cross_paddings,
            cross_segment_mask=cross_segment_mask)
        self.assertAllClose(outputs, outputs_repeated, atol=1e-5)
Example #22
 def _to_nested_map(self, x) -> py_utils.NestedMap:
     t = tf.ones(shape=[4], dtype=tf.int32) * tf.cast(x, dtype=tf.int32)
     return py_utils.NestedMap(data=t)
Example #23
    def test_glam_unitransformer(self):
        batch = 2
        length = 3
        d_model = 6
        num_heads = 2
        vocab_size = 16
        ff_dim = 8
        c_dim = 3
        e_dim = 2
        num_layers = 4
        # Build jax layer
        jax_p = transformer_models.TransformerLm.GLaMUniTransformerParams(
            name='model',
            vocab_size=vocab_size,
            num_transformer_layers=num_layers,
            moe=True,
            model_dim=d_model,
            ff_dim=ff_dim,
            moe_hidden_dim=ff_dim,
            attention_num_heads=num_heads,
            attention_key_value_dim=d_model // num_heads,
            attention_extra_logit=0.0,
            use_tgt_labels_size_as_loss_denominator=True,
            moe_load_balance_loss_weight=0.01,
            z_loss_weight=1e-4,
            c_dim=c_dim,
            e_dim=e_dim)
        assert jax_p.packed_input
        jax_layer = jax_p.Instantiate()
        prng_key = jax.random.PRNGKey(seed=42)
        jax_vars = jax_layer.instantiate_variables(prng_key)

        builder_p = gshard_builder.DenseBuilder.Params().Set(
            num_groups=1,
            second_expert_policy='all',
            relative_attention_type='bias',
            model_dim=d_model,
            attention_key_value_dim=d_model // num_heads,
            attention_num_heads=num_heads,
            attention_combine_dims=True,
            c_dim=c_dim,
            capacity_factor=None,
            attention_extra_logit=0.0,
            e_dim=e_dim,
            moe_hidden_dim=ff_dim,
            ff_dim=ff_dim)
        tf_layer = gshard_builder.UniTransformer.Params().Set(
            name='model',
            num_transformer_layers=num_layers,
            builder=builder_p,
            vocab_size=vocab_size,
            sequence_length=length,
            label_smoothing=0,
            aux_loss_coef=0.01,
            z_loss=1e-4,
            use_tgt_labels_size_as_loss_denominator=True,
            positional_embedding=False,
            gated_gelu=True,
            moe=True).Instantiate()

        # Build Jax Inputs
        np.random.seed(42)
        npy_ids = np.random.randint(0, vocab_size - 1, [batch, length])
        jax_ids = jnp.asarray(npy_ids)
        npy_paddings = np.array([[0, 0, 1], [0, 0, 1]], dtype=np.float32)

        jax_paddings = jnp.asarray(npy_paddings)
        npy_segment_ids = np.array([[1, 2, 0], [1, 1, 0]], dtype=np.int32)
        npy_segment_pos = np.array([[0, 0, 0], [0, 1, 0]], dtype=np.int32)
        npy_labels = np.roll(npy_ids, -1, axis=1)
        jax_labels = jnp.asarray(npy_labels)
        jax_seg_ids = jnp.asarray(npy_segment_ids)
        jax_seg_pos = jnp.asarray(npy_segment_pos)
        jax_label_weights = jnp.asarray([[1, 1, 0], [1, 1, 0]])

        # Build TF Inputs
        tf_tgt_inputs = py_utils.NestedMap(
            ids=tf.convert_to_tensor(npy_ids, dtype=tf.int32),
            labels=tf.convert_to_tensor(npy_labels, dtype=tf.int32),
            segment_ids=tf.convert_to_tensor(npy_segment_ids, dtype=tf.int32),
            segment_pos=tf.convert_to_tensor(npy_segment_pos, dtype=tf.int32))
        tf_inputs = py_utils.NestedMap(tgt=tf_tgt_inputs)

        # Compute jax outputs
        jax_outputs = test_utils.apply(jax_layer,
                                       jax_vars,
                                       jax_layer.fprop,
                                       jax_ids,
                                       jax_paddings,
                                       context_p=None,
                                       labels=py_utils.NestedMap(
                                           class_ids=jax_labels,
                                            class_weights=jax_label_weights,
                                       ),
                                       segment_ids=jax_seg_ids,
                                       segment_pos=jax_seg_pos)

        # Copy jax vars to tf ones.
        tf_theta = tf_layer.theta.DeepCopy()

        # GShardBuilder softmax weights use self.vars rather than theta.
        tf_layer.vars.dec_emb.w.embedding.assign(jax_vars.softmax.embedding.w)
        tf_theta.dec_emb.w.embedding = jax_vars.softmax.embedding.w
        tf_theta.dec.final_layer_norm.w.scale = jax_vars.final_ln.scale
        jax_layer_0_var = tf.nest.map_structure(
            lambda v: jnp.squeeze(jnp.split(v, 2)[0], axis=0),
            jax_vars.transformer.repeat.sub.x_layers[0])
        tf_theta.dec.layer_000.ln.w.scale = jax_layer_0_var.layer_norm.scale
        jax_atten_var = jax_layer_0_var.self_attention
        tf_atten_var = tf_theta.dec.layer_000.dec_self_attention
        tf_atten_var.w.wk = jax_atten_var.key.w
        tf_atten_var.w.wq = jax_atten_var.query.w
        tf_atten_var.w.wv = jax_atten_var.value.w
        tf_atten_var.w.wo = jax_atten_var.post.w
        tf_atten_var.wrb.wrb = jax_atten_var.relative_bias.wrb

        jax_moe_var = jax_layer_0_var.ff_layer
        tf_theta.dec.layer_001.ln.w.scale = jax_moe_var.layer_norm.scale
        tf_theta.dec.layer_001.moe.ffw.top_2_gating.w = jax_moe_var.gate
        tf_theta.dec.layer_001.moe.moe.wi = jax_moe_var.wi_0
        tf_theta.dec.layer_001.moe.moe.wo = jax_moe_var.wo_0

        jax_layer_1_var = tf.nest.map_structure(
            lambda v: jnp.squeeze(jnp.split(v, 2)[0], axis=0),
            jax_vars.transformer.repeat.sub.x_layers[1])
        tf_theta.dec.layer_002.ln.w.scale = jax_layer_1_var.layer_norm.scale
        jax_atten_var = jax_layer_1_var.self_attention
        tf_atten_var = tf_theta.dec.layer_002.dec_self_attention
        tf_atten_var.w.wk = jax_atten_var.key.w
        tf_atten_var.w.wq = jax_atten_var.query.w
        tf_atten_var.w.wv = jax_atten_var.value.w
        tf_atten_var.w.wo = jax_atten_var.post.w
        tf_atten_var.wrb.wrb = jax_atten_var.relative_bias.wrb

        jax_ffn_var = jax_layer_1_var.ff_layer
        tf_ffn_var = tf_theta.dec.layer_003.dense_relu_dense
        tf_ffn_var.w.wi_0 = jax_ffn_var.ffn_layer1_gate.linear.w
        tf_ffn_var.w.wi_1 = jax_ffn_var.ffn_layer1.linear.w
        tf_ffn_var.w.wo = jax_ffn_var.ffn_layer2.linear.w
        tf_theta.dec.layer_003.ln.w.scale = jax_ffn_var.layer_norm.scale

        jax_layer_2_var = tf.nest.map_structure(
            lambda v: jnp.squeeze(jnp.split(v, 2)[1], axis=0),
            jax_vars.transformer.repeat.sub.x_layers[0])
        tf_theta.dec.layer_004.ln.w.scale = jax_layer_2_var.layer_norm.scale
        jax_atten_var = jax_layer_2_var.self_attention
        tf_atten_var = tf_theta.dec.layer_004.dec_self_attention
        tf_atten_var.w.wk = jax_atten_var.key.w
        tf_atten_var.w.wq = jax_atten_var.query.w
        tf_atten_var.w.wv = jax_atten_var.value.w
        tf_atten_var.w.wo = jax_atten_var.post.w
        tf_atten_var.wrb.wrb = jax_atten_var.relative_bias.wrb

        jax_moe_var = jax_layer_2_var.ff_layer
        tf_theta.dec.layer_005.ln.w.scale = jax_moe_var.layer_norm.scale
        tf_theta.dec.layer_005.moe.ffw.top_2_gating.w = jax_moe_var.gate
        tf_theta.dec.layer_005.moe.moe.wi = jax_moe_var.wi_0
        tf_theta.dec.layer_005.moe.moe.wo = jax_moe_var.wo_0

        jax_layer_3_var = tf.nest.map_structure(
            lambda v: jnp.squeeze(jnp.split(v, 2)[1], axis=0),
            jax_vars.transformer.repeat.sub.x_layers[1])
        tf_theta.dec.layer_006.ln.w.scale = jax_layer_3_var.layer_norm.scale
        jax_atten_var = jax_layer_3_var.self_attention
        tf_atten_var = tf_theta.dec.layer_006.dec_self_attention
        tf_atten_var.w.wk = jax_atten_var.key.w
        tf_atten_var.w.wq = jax_atten_var.query.w
        tf_atten_var.w.wv = jax_atten_var.value.w
        tf_atten_var.w.wo = jax_atten_var.post.w
        tf_atten_var.wrb.wrb = jax_atten_var.relative_bias.wrb

        jax_ffn_var = jax_layer_3_var.ff_layer
        tf_ffn_var = tf_theta.dec.layer_007.dense_relu_dense
        tf_ffn_var.w.wi_0 = jax_ffn_var.ffn_layer1_gate.linear.w
        tf_ffn_var.w.wi_1 = jax_ffn_var.ffn_layer1.linear.w
        tf_ffn_var.w.wo = jax_ffn_var.ffn_layer2.linear.w
        tf_theta.dec.layer_007.ln.w.scale = jax_ffn_var.layer_norm.scale

        tf_theta = test_utils.to_tf_nmap(tf_theta)

        # Compute TF outputs
        tf_out, _ = tf_layer.FProp(tf_theta, tf_inputs)
        self.assertAllClose(test_utils.to_np(jax_outputs.total_loss),
                            test_utils.to_np(tf_out['loss'][0]))
Example #24
  def fprop(self, z: JTensor, paddings: JTensor) -> NestedMap:
    """Quantizes 'z' of shape [B, T, D].

    The z_codes of padded locations are 0.

    Args:
      z:        [B, T, D].
      paddings: [B, T].

    Returns:
      A NestedMap of
        - z_q:               [B, T, D].
        - z_codes:           [B, T, G].
        - z_onehot:          [B, T, G, C].
        - loss:              [], weighted sum of quantization loss and
          commitment loss.
        - codebook_coverage: [], a float scalar tensor between [0, 1].
        - pplx:              [], pplx of quantized distribution over the
          codebook.
        - entropy:           [], log(pplx).
    """
    p = self.params
    theta = self.local_theta()
    b, t, d = z.shape
    g, c = p.num_groups, p.num_latent_classes

    mask = 1.0 - paddings
    num_frames = jnp.sum(mask)
    z = self._apply_mask(z, mask)

    if p.normalize_latent_vector:
      z = self._l2_normalize(z, axis=-1)

    # [b * t, d], [b * t, g], [b * t, g, c]
    z_q, z_codes, z_onehot = quantize_vector(
        jnp.reshape(z, [b * t, d]), self._get_latent_embedding(theta))

    z_q = jnp.reshape(z_q, [b, t, d])
    z_codes = jnp.reshape(z_codes, [b, t, g])
    z_onehot = jnp.reshape(z_onehot, [b, t, g, c])

    # Padded locations are all 0s without any 1.
    z_q = self._apply_mask(z_q, mask)
    # [b, t, g]
    z_codes = self._apply_mask(z_codes, mask)
    # [b, t, g, c]
    z_onehot = self._apply_mask(z_onehot, mask)

    # Move z towards z_q.
    normalizer = 1e-7 + num_frames
    # [b, t, d]
    loss_c = (z - jax.lax.stop_gradient(z_q))**2
    # [b, t, d] -> [b, t] -> []
    loss_c = jnp.sum(jnp.mean(loss_c, -1)) / normalizer
    # loss_c = py_utils.check_numerics(loss_c, 'loss_c has NaN.')

    # Move z_q towards z.
    loss_z = (z_q - jax.lax.stop_gradient(z))**2
    loss_z = jnp.sum(jnp.mean(loss_z, -1)) / normalizer
    # loss_z = py_utils.check_numerics(loss_z, 'loss_z has NaN.')
    loss = loss_z + p.beta * loss_c

    # Straight-through estimator: the forward pass uses z_q, while the
    # gradient flows straight through to z.
    z_q = z + jax.lax.stop_gradient(z_q - z)

    # [], []
    pplx, entropy, _ = objectives.batch_pplx_entropy_from_codes(
        z_codes, c, paddings=paddings)
    # pplx = py_utils.check_numerics(pplx, f'{p.name} perplexity NaN')

    codebook_coverage = objectives.batch_codebook_coverage(
        z_codes, c, paddings=paddings)
    codebook_num_covered_words = codebook_coverage * c**g

    return py_utils.NestedMap(
        z_q=z_q,
        z_codes=z_codes,
        z_onehot=z_onehot,
        loss=loss,
        codebook_coverage=codebook_coverage,
        codebook_num_covered_words=codebook_num_covered_words,
        pplx=pplx,
        entropy=entropy)
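
A toy check (pure JAX) of the straight-through estimator used above: the
forward value equals z_q, while the gradient behaves as if quantization were
the identity on z.

import jax
import jax.numpy as jnp

def ste(z, z_q):
  return z + jax.lax.stop_gradient(z_q - z)

z, z_q = jnp.array(1.0), jnp.array(5.0)
print(ste(z, z_q))                       # 5.0: forward pass uses z_q.
print(jax.grad(ste, argnums=0)(z, z_q))  # 1.0: gradient flows to z.
print(jax.grad(ste, argnums=1)(z, z_q))  # 0.0: no gradient to z_q.
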
Example #25
    def test_conformer_layer(self, batch_size, seq_len, kernel_size,
                             input_dims, model_dims, atten_num_heads,
                             dropout_prob):
        # Lingvo TF layers only use dropout on FF and Attention layers
        p = conformers.Conformer.Params().Set(
            name='jax_conformer_layer',
            input_dims=input_dims,
            conv_residual_dropout=0.0,
            atten_residual_dropout=dropout_prob,
            ffn_residual_dropout=dropout_prob,
            atten_dropout=dropout_prob,
            ffn_relu_dropout=dropout_prob,
            kernel_size=kernel_size,
            model_dims=model_dims,
            atten_num_heads=atten_num_heads)
        conformer = p.Instantiate()
        prng_key = jax.random.PRNGKey(seed=123)
        initial_vars = conformer.instantiate_variables(prng_key)
        npy_inputs = np.random.normal(
            1.0, 0.5, [batch_size, seq_len, input_dims]).astype('float32')
        inputs = jnp.asarray(npy_inputs)

        def GetPaddingFromLength(length):
            idx = np.tile(np.arange(seq_len), [batch_size, 1])
            return (idx >= np.expand_dims(length, -1)).astype('float32')

        length = np.random.randint(seq_len // 2, seq_len, (batch_size, ))
        npy_paddings = GetPaddingFromLength(length).astype('float32')
        paddings = jnp.asarray(npy_paddings)

        context_p = base_layer.JaxContext.Params().Set(do_eval=True)

        output = test_utils.apply(conformer,
                                  initial_vars,
                                  conformer.fprop,
                                  inputs,
                                  paddings,
                                  context_p=context_p)
        # Test whether the TF Conformer layer returns the same output.
        # Modify initial_vars to use TF-compatible params.
        tf_initial_vars = test_utils.replace_jax_conformer_layer_vars_to_tf(
            initial_vars)

        tf_p = conformer_layer.ConformerLayer.CommonParams(
            input_dim=input_dims,
            dropout_prob=dropout_prob,
            atten_num_heads=atten_num_heads,
            kernel_size=kernel_size,
            fflayer_hidden_dim=model_dims * p.ffn_dim_multiplier,
            use_relative_atten=False,
            fflayer_residual_weight=0.5).Set(name='tf_conformer')
        tf_p.trans_atten_tpl = tf_p.trans_atten_tpl.Set(hidden_dim=model_dims)

        tf_conformer = tf_p.Instantiate()
        with cluster_factory.SetEval(True):
            tf_output = tf_conformer.FProp(
                tf_initial_vars,
                py_utils.NestedMap(features=tf.constant(inputs,
                                                        dtype=tf.float32),
                                   paddings=tf.constant(npy_paddings,
                                                        dtype=tf.float32)))
        np_output = to_np(output)
        tf_np_output = to_np(tf_output.features)
        self.assertAllClose(tf_np_output, np_output, atol=1e-5)