def dae_joint_loss(features, labels, mems, n_token, is_training):
  """DAE loss with generator."""
  del mems
  del labels

  initializer = _get_initializer()

  monitor_dict = {}

  ##### unpack input
  gen_inp = features["gen_inp"]
  gen_tgt = features["gen_tgt"]
  gen_mask_map = features["gen_mask_map"]
  gen_tgt_mask = features["gen_tgt_mask"]

  enc_mask = features["enc_mask"]
  enc_type = features["enc_type"]
  # enc_edit_label = features["enc_edit_label"]

  dec_inp = features["dec_inp"]
  dec_tgt = features["dec_tgt"]
  dec_mask = features["dec_mask"]
  dec_type = features["dec_type"]
  edit_label = features["edit_label"]
  dec_mask_map = features["dec_mask_map"]
  dec_masked_tgt = features["dec_masked_tgt"]
  dec_lm_tgt_mask = features["dec_lm_tgt_mask"]

  rep_enc2dec_full = features["rep_enc2dec_full"]
  rep_enc2dec_part = features["rep_enc2dec_part"]

  enc_pos = features["enc_pos"]
  dec_pos = features["dec_pos"]

  if FLAGS.double_type:
    # offer an indicator to differentiate encoder and decoder
    dec_type = dec_type + FLAGS.n_type

  initializer = _get_initializer()
  if FLAGS.use_bfloat16:
    tf_float = tf.bfloat16
  else:
    tf_float = tf.float32

  #### Shared input embedding (for generator)
  with tf.variable_scope("model", reuse=tf.AUTO_REUSE):
    inp_func = _get_inp_func(n_token, FLAGS.d_model, initializer, is_training)
    gen_embed, shared_embed_table = inp_func(inputs=gen_inp,
                                             type_id=enc_type,
                                             return_embed_table=True)

  #### Generator TFM
  with tf.variable_scope("generator", reuse=tf.AUTO_REUSE):
    gen_tfm_func = _get_tfm_func(initializer,
                                 is_training,
                                 "pretrain",
                                 shrink=FLAGS.gen_shrink)
    gen_output, _ = gen_tfm_func(inputs=gen_embed,
                                 input_mask=enc_mask,
                                 perm_mask=None)

    gen_lm_loss, _, logits = model.lm_loss(
        hidden=gen_output,
        target=gen_tgt,
        n_token=n_token,
        d_model=FLAGS.d_model,
        initializer=initializer,
        lookup_table=shared_embed_table,
        tie_weight=FLAGS.tie_weight,
        target_mapping=None,
        hidden_mapping=gen_mask_map,
        return_logits=True,
        use_tpu=FLAGS.use_tpu)

    if gen_lm_loss.dtype != tf.float32:
      gen_lm_loss = tf.cast(gen_lm_loss, tf.float32)
    gen_tgt_mask = tf.cast(gen_tgt_mask, gen_lm_loss.dtype)
    gen_loss = (tf.reduce_sum(gen_lm_loss * gen_tgt_mask) /
                tf.reduce_sum(gen_tgt_mask))
    monitor_dict["gen_loss"] = gen_loss
    total_loss = gen_loss

  #### Sample from generator
  uniform = tf.random.uniform(minval=0, maxval=1, shape=logits.shape,
                              dtype=tf.float32)
  gumbel = -tf.log(-tf.log(uniform + 1e-9) + 1e-9)
  samples = tf.argmax(logits + gumbel, -1)
  gen_tokens = tf.cast(samples, tf.int32)

  # map `num_predict` samples to full length
  samples = tf.einsum("bm,bml->bl",
                      tf.cast(samples, tf.float32),
                      tf.cast(gen_mask_map, tf.float32))
  samples = tf.cast(samples, tf.int32)
  enc_inp = tf.where(tf.equal(gen_inp, FLAGS.mask_id), samples, gen_inp)

  #### get the mask for generated tokens that equal the target
  same_mask = tf.equal(gen_tokens, gen_tgt)
  same_mask = (tf.cast(same_mask, rep_enc2dec_full.dtype) *
               tf.cast(gen_tgt_mask, rep_enc2dec_full.dtype))

  # monitor how many generated tokens are the same as the real ones
  same_prec = (tf.reduce_sum(tf.cast(same_mask, gen_tgt_mask.dtype)) /
               tf.reduce_sum(gen_tgt_mask))
  monitor_dict["same_percent"] = same_prec

  # If same, change the edit_label to original (0)
  dec_same_mask_full = tf.einsum("bi,bil->bl", same_mask, rep_enc2dec_full)
  edit_label = tf.where(
      tf.cast(dec_same_mask_full, tf.bool),
      tf.zeros(tf.shape(dec_same_mask_full), dtype=edit_label.dtype),
      edit_label)

  # If same, exclude from LM loss
  dec_same_mask_part = tf.einsum("bi,bij->bj", same_mask, rep_enc2dec_part)
  dec_diff_mask_part = 1.0 - tf.cast(dec_same_mask_part,
                                     dec_lm_tgt_mask.dtype)
  dec_lm_tgt_mask = dec_lm_tgt_mask * dec_diff_mask_part

  # shapes
  bsz = tf.shape(enc_inp)[0]
  src_len = tf.shape(enc_inp)[1]
  tgt_len = tf.shape(dec_inp)[1]

  ##### format joint model inputs
  inputs = tf.concat([enc_inp, dec_inp], axis=1)
  type_id = tf.concat([enc_type, dec_type], axis=1)

  source_seg = tf.zeros([bsz, src_len], dtype=inputs.dtype)
  target_seg = tf.ones([bsz, tgt_len], dtype=inputs.dtype)

  ##### attention mask: note that `1` indicates CANNOT attend
  # src mask
  src_to_src = tf.not_equal(source_seg[:, :, None], source_seg[:, None, :])
  src_to_tgt = tf.ones([bsz, src_len, tgt_len], dtype=src_to_src.dtype)
  src_mask = tf.concat([src_to_src, src_to_tgt], axis=2)

  # tgt mask
  tgt_to_src = tf.not_equal(target_seg[:, :, None], source_seg[:, None, :])
  tgt_to_tgt = tf.not_equal(target_seg[:, :, None], target_seg[:, None, :])
  causal_mask = tf.cast(causal_attn_mask(qlen=tgt_len), tgt_to_tgt.dtype)
  # If any one of them is `1` (indicating cannot attend), i.e. `logical_or`,
  # then the model should NOT attend
  tgt_to_tgt = tf.logical_or(tgt_to_tgt, causal_mask)
  tgt_mask = tf.concat([tgt_to_src, tgt_to_tgt], axis=2)

  # concat
  perm_mask = tf.concat([src_mask, tgt_mask], axis=1)
  perm_mask = tf.cast(perm_mask, tf_float)

  ##### Position sequence
  pos_seq = tf.concat([enc_pos, dec_pos], axis=1)

  #### Transformer Model
  total_loss = 0
  with tf.variable_scope("model", reuse=tf.AUTO_REUSE):
    input_embed = inp_func(inputs=inputs, pos_seq=pos_seq, type_id=type_id)

    tfm_func = _get_tfm_func(initializer,
                             is_training,
                             phase="pretrain")
    output, _ = tfm_func(inputs=input_embed,
                         input_mask=None,
                         perm_mask=perm_mask)

    #### edit loss
    tgt_out = output[:, src_len:]
    edit_loss, edit_logits = model.cls_loss(
        hidden=tgt_out,
        target=edit_label,
        n_cls=4,
        d_model=FLAGS.d_model,
        initializer=initializer,
        target_mapping=None,
        hidden_mapping=None,
        return_logits=True,
        scope="edit_type_loss")

    dec_tgt_mask = tf.cast(1.0 - dec_mask, tf.float32)
    edit_loss = (tf.reduce_sum(edit_loss * dec_tgt_mask) /
                 tf.reduce_sum(dec_tgt_mask))

    edit_pred = tf.cast(tf.argmax(edit_logits, axis=-1),
                        dtype=edit_label.dtype)
    edit_corr = tf.cast(tf.equal(edit_pred, edit_label), dtype=tf.float32)
    edit_accu = (tf.reduce_sum(edit_corr * dec_tgt_mask) /
                 tf.reduce_sum(dec_tgt_mask))

    if FLAGS.edit_weight > 0:
      # monitor
      monitor_dict["edit_loss"] = edit_loss
      monitor_dict["accu_edit"] = edit_accu

      def get_class_acc(label_id):
        mask = tf.ones(tf.shape(edit_label), dtype=edit_label.dtype) * label_id
        mask = tf.equal(edit_label, mask)
        mask = tf.cast(mask, dtype=dec_tgt_mask.dtype) * dec_tgt_mask
        accu = tf.reduce_sum(edit_corr * mask) / tf.reduce_sum(mask)
        return accu

      monitor_dict["accu_orig"] = get_class_acc(0)
      if FLAGS.del_ratio > 0:
        monitor_dict["accu_del"] = get_class_acc(FLAGS.del_label)
      if FLAGS.ins_ratio > 0:
        monitor_dict["accu_ins"] = get_class_acc(FLAGS.ins_label)
      if FLAGS.rep_ratio > 0:
        monitor_dict["accu_rep"] = get_class_acc(FLAGS.rep_label)

      # accumulate total loss
      total_loss += FLAGS.edit_weight * edit_loss

    #### LM loss
    #### Only predict the target part
    tgt_out = output[:, src_len:]
    tf.logging.info("Output: %s, target output: %s",
                    output.shape, tgt_out.shape)

    lm_loss, _ = model.lm_loss(
        hidden=tgt_out,
        target=dec_masked_tgt,
        n_token=n_token,
        d_model=FLAGS.d_model,
        initializer=initializer,
        lookup_table=shared_embed_table,
        tie_weight=FLAGS.tie_weight,
        target_mapping=None,
        hidden_mapping=dec_mask_map,
        use_tpu=FLAGS.use_tpu)

    if lm_loss.dtype != tf.float32:
      lm_loss = tf.cast(lm_loss, tf.float32)

    dec_lm_tgt_mask = tf.cast(dec_lm_tgt_mask, lm_loss.dtype)
    lm_loss = (tf.reduce_sum(lm_loss * dec_lm_tgt_mask) /
               tf.reduce_sum(dec_lm_tgt_mask))

    if FLAGS.lm_weight > 0:
      # monitor
      monitor_dict["lm_loss"] = lm_loss
      # accumulate total loss
      total_loss += FLAGS.lm_weight * lm_loss

  return total_loss, {}, monitor_dict
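
# -----------------------------------------------------------------------------
# Illustrative sketch (not used by the training pipeline): the Gumbel-max trick
# employed above to sample generator tokens. Adding i.i.d. Gumbel(0, 1) noise to
# the logits and taking the argmax draws from Categorical(softmax(logits)), so
# no explicit softmax or sampling op is needed. The helper below is a
# hypothetical NumPy restatement of that step, kept separate from the TF graph.
# -----------------------------------------------------------------------------
def _gumbel_max_sample_example():
  """Samples token ids from logits via the Gumbel-max trick (NumPy sketch)."""
  import numpy as np
  rng = np.random.default_rng(0)
  logits = rng.normal(size=(2, 3, 5))            # [bsz, num_predict, vocab]
  uniform = rng.uniform(size=logits.shape)
  gumbel = -np.log(-np.log(uniform + 1e-9) + 1e-9)
  samples = np.argmax(logits + gumbel, axis=-1)  # [bsz, num_predict]
  return samples
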
def joint_loss(features, labels, n_token, is_training):
  """Decoder only seq2seq."""
  del labels

  initializer = _get_initializer()

  #### Unpack input
  source = features["source"]
  source_pos = features["source_position"]
  source_seg = features["source_segmentation"]

  target = features["target"]
  target_pos = features["target_position"]
  target_seg = features["target_segmentation"]

  # shapes
  bsz = tf.shape(source)[0]
  src_len = tf.shape(source)[1]
  tgt_len = tf.shape(target)[1]

  if FLAGS.use_bfloat16:
    tf_float = tf.bfloat16
  else:
    tf_float = tf.float32

  ##### format inputs
  inputs = tf.concat([source, target], axis=1)
  position = tf.concat([source_pos, target_pos], axis=1)

  src_type_id = features.get("source_type",
                             tf.zeros([bsz, src_len], dtype=inputs.dtype))
  tgt_type_id = features.get("target_type",
                             tf.ones([bsz, tgt_len], dtype=inputs.dtype))
  if FLAGS.double_type:
    tgt_type_id = tgt_type_id + FLAGS.n_token
  type_id = tf.concat([src_type_id, tgt_type_id], axis=1)

  ##### attention mask: note that `1` indicates CANNOT attend
  # src mask
  src_to_src = tf.not_equal(source_seg[:, :, None], source_seg[:, None, :])
  src_to_tgt = tf.ones([bsz, src_len, tgt_len], dtype=src_to_src.dtype)
  src_mask = tf.concat([src_to_src, src_to_tgt], axis=2)

  # tgt mask
  tgt_to_src = tf.not_equal(target_seg[:, :, None], source_seg[:, None, :])
  tgt_to_tgt = tf.not_equal(target_seg[:, :, None], target_seg[:, None, :])
  causal_mask = tf.cast(causal_attn_mask(qlen=tgt_len), tgt_to_tgt.dtype)
  # If any one of them is `1` (indicating cannot attend), i.e. `logical_or`,
  # then the model should NOT attend
  tgt_to_tgt = tf.logical_or(tgt_to_tgt, causal_mask)
  tgt_mask = tf.concat([tgt_to_src, tgt_to_tgt], axis=2)

  # concat
  perm_mask = tf.concat([src_mask, tgt_mask], axis=1)
  perm_mask = tf.cast(perm_mask, tf_float)

  # padding
  non_pad_mask = tf.not_equal(target_seg, 0)
  all_eos = tf.constant(FLAGS.eos_id, shape=target.shape, dtype=target.dtype)
  # Replace all <pad> (/P) with <eos> (/S)
  # - target   : /S a1 a2 a3 /S b1 b2 /S c1 c2 /P /P
  # - tmptgt   : /S a1 a2 a3 /S b1 b2 /S c1 c2 /S /S
  tmptgt = tf.where(non_pad_mask, target, all_eos)
  # Shift the `tmptgt` to form the (next-step) prediction target
  # - target   : /S a1 a2 a3 /S b1 b2 /S c1 c2 /P /P
  # - pred_tgt : a1 a2 a3 /S b1 b2 /S c1 c2 /S /S /S
  pred_tgt = tf.concat([tmptgt[:, 1:], tmptgt[:, :1]], axis=1)
  loss_mask = tf.cast(non_pad_mask, tf.float32)

  #### Transformer Model
  with tf.variable_scope("model", reuse=tf.AUTO_REUSE):
    inp_func = _get_inp_func(n_token, FLAGS.d_model, initializer, is_training)
    input_embed, word_embed_table = inp_func(inputs=inputs,
                                             type_id=type_id,
                                             pos_seq=position,
                                             return_embed_table=True)

    tfm_func = _get_tfm_func(initializer,
                             is_training,
                             phase="pretrain")
    output, _ = tfm_func(inputs=input_embed,
                         input_mask=None,
                         perm_mask=perm_mask,
                         pos_seq=position)

    #### Only predict the target part
    tgt_out = output[:, src_len:]
    tf.logging.info("Output: %s, target output: %s",
                    output.shape, tgt_out.shape)

    lm_loss, nll_loss = model.lm_loss(
        hidden=tgt_out,
        target=pred_tgt,
        n_token=n_token,
        d_model=FLAGS.d_model,
        initializer=initializer,
        lookup_table=word_embed_table,
        tie_weight=FLAGS.tie_weight,
        label_smooth=FLAGS.label_smooth,
        use_tpu=FLAGS.use_tpu)

    if lm_loss.dtype != tf.float32:
      lm_loss = tf.cast(lm_loss, tf.float32)

  num_loss = tf.reduce_sum(loss_mask)
  total_loss = tf.reduce_sum(lm_loss * loss_mask) / num_loss
  nll = tf.reduce_sum(nll_loss * loss_mask) / num_loss

  # To be compatible with fairseq, convert to base 2 for logging
  monitor_dict = {
      "loss": total_loss / math.log(2),
      "nll": nll / math.log(2),
      "num_loss": num_loss,
  }

  return total_loss, monitor_dict
def transformer(inputs, n_layer, d_model, n_head, d_head, d_inner,
                dropout, dropatt, dropact, initializer, is_training,
                context=None, context_mask=None, input_mask=None,
                perm_mask=None, ff_activation="relu", causal=False,
                rel_attn=False, pos_seq=None, clamp_len=-1,
                return_all_hidden=False, scope="transformer"):
  """Transformer model."""
  monitor_dict = {}
  if input_mask is not None:
    monitor_dict["inp_mask"] = input_mask

  tf.logging.info("===== Transformer =====")
  tf.logging.info("Input related:")
  tf.logging.info(" - inputs %s", inputs)
  tf.logging.info(" - input_mask %s", input_mask)
  tf.logging.info(" - perm_mask %s", perm_mask)
  tf.logging.info(" - context %s", context)
  tf.logging.info(" - context_mask %s", context_mask)
  tf.logging.info("Hparam related:")
  tf.logging.info(" - initializer %s", initializer)
  tf.logging.info(" - ff_activation %s", ff_activation)
  tf.logging.info(" - causal %s", causal)
  tf.logging.info("============================")

  hiddens = []
  with tf.variable_scope(scope):
    ##### Attention mask
    if causal:
      causal_mask = causal_attn_mask(tf.shape(inputs)[1], dtype=inputs.dtype)
      causal_mask = causal_mask[None, None]
    else:
      causal_mask = None

    attn_mask = merge_attn_masks(causal_mask, input_mask, perm_mask)

    ##### Input projection
    if inputs.shape.as_list()[-1] != d_model:
      tf.logging.info("Project input embedding: %d -> %d",
                      inputs.shape.as_list()[-1], d_model)
      output = tf.layers.dense(inputs, d_model, activation=None,
                               kernel_initializer=initializer,
                               name="input_projection")
    else:
      output = inputs
    hiddens.append(output)

    ##### Get relative attention bias
    if rel_attn:
      tf.logging.info("Use relative attention")
      if pos_seq is None:
        seq_len = tf.shape(output)[1]
        attn_bias = consecutive_rel_encoding(seq_len, d_model, n_head,
                                             clamp_len, dropout, is_training,
                                             initializer, output.dtype)
      else:
        attn_bias = rel_encoding(pos_seq, pos_seq, d_model, n_head, clamp_len,
                                 dropout, is_training, initializer,
                                 output.dtype)
    else:
      attn_bias = None

    ##### Attention layers
    for i in range(n_layer):
      with tf.variable_scope("layer_{}".format(i)):
        output, attn_dict = multihead_attn(
            q=output,
            k=output,
            v=output,
            attn_mask=attn_mask,
            d_model=d_model,
            n_head=n_head,
            d_head=d_head,
            dropout=dropout,
            dropatt=dropatt,
            is_training=is_training,
            kernel_initializer=initializer,
            attn_bias=attn_bias,
            scope="self_attn")

        if context is not None:
          output, _ = multihead_attn(
              q=output,
              k=context,
              v=context,
              attn_mask=context_mask,
              d_model=d_model,
              n_head=n_head,
              d_head=d_head,
              dropout=dropout,
              dropatt=dropatt,
              is_training=is_training,
              kernel_initializer=initializer,
              scope="cross_attn")

        output, ffn_dict = positionwise_ffn(
            inp=output,
            d_model=d_model,
            d_inner=d_inner,
            dropout=dropout,
            dropact=dropact,
            initializer=initializer,
            activation_type=ff_activation,
            is_training=is_training)
        hiddens.append(output)

        # Update monitor dict
        monitor_dict = update_monitor_dict(
            monitor_dict, attn_dict, prefix="layer_{}_attn".format(i))
        monitor_dict = update_monitor_dict(
            monitor_dict, ffn_dict, prefix="layer_{}_ff".format(i))

  if return_all_hidden:
    return hiddens, monitor_dict
  else:
    return output, monitor_dict
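
# -----------------------------------------------------------------------------
# Illustrative usage sketch (not part of the training pipeline): wiring a tiny
# causal Transformer through the function above with random inputs. It assumes
# a TF1-style runtime and that this module's helpers (`causal_attn_mask`,
# `merge_attn_masks`, `multihead_attn`, `positionwise_ffn`) accept the default
# `None` masks; treat it as a sketch rather than a tested entry point.
# -----------------------------------------------------------------------------
def _toy_transformer_example(is_training=True):
  """Runs `transformer` on random embeddings and returns output + monitors."""
  inputs = tf.random.normal([2, 8, 32])  # [bsz, seq_len, d_embed == d_model]
  return transformer(
      inputs=inputs,
      n_layer=2, d_model=32, n_head=4, d_head=8, d_inner=64,
      dropout=0.1, dropatt=0.1, dropact=0.0,
      initializer=tf.random_normal_initializer(stddev=0.02),
      is_training=is_training,
      causal=True,
      scope="toy_transformer")
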
def mass_loss(features, labels, mems, n_token, is_training):
  """MASS pretraining loss."""
  del labels
  del mems

  initializer = _get_initializer()

  # Type
  if FLAGS.use_bfloat16:
    tf_float = tf.bfloat16
  else:
    tf_float = tf.float32

  #### Unpack input
  target = features["target"]
  target_mask = features["target_mask"]
  target_mapping = features["target_mapping"]

  enc_inp = features["enc_inp"]
  enc_type = features["type_id"]
  enc_pos = features["enc_pos"]

  dec_inp = features["dec_inp"]
  dec_type = features["dec_type"]
  dec_pos = features["dec_pos"]
  dec_seg = features["dec_seg"]

  # shapes
  bsz = tf.shape(enc_inp)[0]
  enc_len = tf.shape(enc_inp)[1]
  dec_len = tf.shape(dec_inp)[1]

  ##### format inputs
  inputs = tf.concat([enc_inp, dec_inp], axis=1)
  position = tf.concat([enc_pos, dec_pos], axis=1)
  type_id = tf.concat([enc_type, dec_type], axis=1)

  ##### attention mask: note that `1` indicates CANNOT attend
  # enc mask
  enc_to_enc = tf.zeros([bsz, enc_len, enc_len], dtype=tf_float)
  enc_to_dec = tf.ones([bsz, enc_len, dec_len], dtype=tf_float)
  enc_mask = tf.concat([enc_to_enc, enc_to_dec], axis=2)

  # dec mask
  dec_to_enc = tf.zeros([bsz, dec_len, enc_len], dtype=tf.bool)
  dec_to_dec = tf.not_equal(dec_seg[:, :, None], dec_seg[:, None, :])
  causal_mask = tf.cast(causal_attn_mask(qlen=dec_len), tf.bool)
  # If any one of them is `1` (indicating cannot attend), i.e. `logical_or`,
  # then the model should NOT attend
  dec_to_dec = tf.logical_or(dec_to_dec, causal_mask)
  dec_mask = tf.cast(tf.concat([dec_to_enc, dec_to_dec], axis=2), tf_float)

  # concat
  perm_mask = tf.concat([enc_mask, dec_mask], axis=1)
  perm_mask = tf.cast(perm_mask, tf_float)

  #### Transformer Model
  with tf.variable_scope("model", reuse=tf.AUTO_REUSE):
    inp_func = _get_inp_func(n_token, FLAGS.d_model, initializer, is_training)
    input_embed, word_embed_table = inp_func(inputs=inputs,
                                             type_id=type_id,
                                             pos_seq=position,
                                             return_embed_table=True)

    tfm_func = _get_tfm_func(initializer,
                             is_training,
                             phase="pretrain")
    output, _ = tfm_func(inputs=input_embed,
                         input_mask=None,
                         perm_mask=perm_mask)

    #### Only predict the target part
    enc_out = output[:, :enc_len]
    dec_out = output[:, enc_len:]
    tf.logging.info("Output: %s, enc output: %s, dec output: %s",
                    output.shape, enc_out.shape, dec_out.shape)

    enc_loss, _ = model.lm_loss(
        hidden=enc_out,
        target=target,
        n_token=n_token,
        d_model=FLAGS.d_model,
        initializer=initializer,
        lookup_table=word_embed_table,
        tie_weight=FLAGS.tie_weight,
        hidden_mapping=target_mapping,
        use_tpu=FLAGS.use_tpu)

    dec_loss, _ = model.lm_loss(
        hidden=dec_out,
        target=target,
        n_token=n_token,
        d_model=FLAGS.d_model,
        initializer=initializer,
        lookup_table=word_embed_table,
        tie_weight=FLAGS.tie_weight,
        use_tpu=FLAGS.use_tpu)

    if dec_loss.dtype != tf.float32:
      dec_loss = tf.cast(dec_loss, tf.float32)
    if enc_loss.dtype != tf.float32:
      enc_loss = tf.cast(enc_loss, tf.float32)

  loss_mask = tf.cast(target_mask, tf.float32)
  num_loss = tf.reduce_sum(loss_mask)
  avg_dec_loss = tf.reduce_sum(dec_loss * loss_mask) / num_loss
  avg_enc_loss = tf.reduce_sum(enc_loss * loss_mask) / num_loss

  monitor_dict = {}
  total_loss = 0
  if FLAGS.enc_weight > 0:
    total_loss += FLAGS.enc_weight * avg_enc_loss
    monitor_dict["loss_enc"] = avg_enc_loss
  if FLAGS.dec_weight > 0:
    total_loss += FLAGS.dec_weight * avg_dec_loss
    monitor_dict["loss_dec"] = avg_dec_loss
  monitor_dict["loss"] = total_loss

  return total_loss, {}, monitor_dict
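
# -----------------------------------------------------------------------------
# Illustrative sketch (not used by the training pipeline): the role of the
# one-hot `target_mapping` passed as `hidden_mapping` above. Assuming a mapping
# of shape [bsz, num_predict, seq_len] (the same layout as `gen_mask_map` /
# `dec_mask_map` in `dae_joint_loss`), it gathers the hidden states of only the
# masked positions so the LM loss is computed just there. A hypothetical NumPy
# restatement of that gather:
# -----------------------------------------------------------------------------
def _hidden_mapping_gather_example():
  """Gathers masked-position hidden states with a one-hot mapping (NumPy)."""
  import numpy as np
  # [bsz=2, seq_len=4, d_model=3]
  hidden = np.arange(2 * 4 * 3, dtype=np.float32).reshape(2, 4, 3)
  # Predict positions 1 and 3 of each sequence.
  mapping = np.zeros((2, 2, 4), dtype=np.float32)  # [bsz, num_predict, seq_len]
  mapping[:, 0, 1] = 1.0
  mapping[:, 1, 3] = 1.0
  return np.einsum("bml,bld->bmd", mapping, hidden)  # [bsz, num_predict, d]
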