Code example #1
def dae_joint_loss(features, labels, mems, n_token, is_training):
    """DAE loss with generator."""
    del mems
    del labels

    initializer = _get_initializer()
    monitor_dict = {}

    ##### unpack input
    gen_inp = features["gen_inp"]
    gen_tgt = features["gen_tgt"]
    gen_mask_map = features["gen_mask_map"]
    gen_tgt_mask = features["gen_tgt_mask"]

    enc_mask = features["enc_mask"]
    enc_type = features["enc_type"]
    # enc_edit_label = features["enc_edit_label"]

    dec_inp = features["dec_inp"]
    dec_tgt = features["dec_tgt"]
    dec_mask = features["dec_mask"]
    dec_type = features["dec_type"]
    edit_label = features["edit_label"]
    dec_mask_map = features["dec_mask_map"]
    dec_masked_tgt = features["dec_masked_tgt"]
    dec_lm_tgt_mask = features["dec_lm_tgt_mask"]
    rep_enc2dec_full = features["rep_enc2dec_full"]
    rep_enc2dec_part = features["rep_enc2dec_part"]

    enc_pos = features["enc_pos"]
    dec_pos = features["dec_pos"]

    if FLAGS.double_type:
        # offer an indicator to differentiate encoder and decoder
        dec_type = dec_type + FLAGS.n_type

    if FLAGS.use_bfloat16:
        tf_float = tf.bfloat16
    else:
        tf_float = tf.float32

    #### Shared input embedding (for generator)
    with tf.variable_scope("model", reuse=tf.AUTO_REUSE):
        inp_func = _get_inp_func(n_token, FLAGS.d_model, initializer,
                                 is_training)
        gen_embed, shared_embed_table = inp_func(inputs=gen_inp,
                                                 type_id=enc_type,
                                                 return_embed_table=True)

    #### Generator TFM
    with tf.variable_scope("generator", reuse=tf.AUTO_REUSE):
        gen_tfm_func = _get_tfm_func(initializer,
                                     is_training,
                                     "pretrain",
                                     shrink=FLAGS.gen_shrink)
        gen_output, _ = gen_tfm_func(inputs=gen_embed,
                                     input_mask=enc_mask,
                                     perm_mask=None)

        gen_lm_loss, _, logits = model.lm_loss(hidden=gen_output,
                                               target=gen_tgt,
                                               n_token=n_token,
                                               d_model=FLAGS.d_model,
                                               initializer=initializer,
                                               lookup_table=shared_embed_table,
                                               tie_weight=FLAGS.tie_weight,
                                               target_mapping=None,
                                               hidden_mapping=gen_mask_map,
                                               return_logits=True,
                                               use_tpu=FLAGS.use_tpu)

        if gen_lm_loss.dtype != tf.float32:
            gen_lm_loss = tf.cast(gen_lm_loss, tf.float32)
        gen_tgt_mask = tf.cast(gen_tgt_mask, gen_lm_loss.dtype)
        gen_loss = (tf.reduce_sum(gen_lm_loss * gen_tgt_mask) /
                    tf.reduce_sum(gen_tgt_mask))
        monitor_dict["gen_loss"] = gen_loss

        total_loss = gen_loss

    #### Sample from generator
    uniform = tf.random.uniform(minval=0,
                                maxval=1,
                                shape=logits.shape,
                                dtype=tf.float32)
    gumbel = -tf.log(-tf.log(uniform + 1e-9) + 1e-9)
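    # Gumbel-max trick: adding Gumbel(0, 1) noise to the logits and taking the
    # argmax draws an exact sample from softmax(logits) without a multinomial op.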
    samples = tf.argmax(logits + gumbel, -1)
    gen_tokens = tf.cast(samples, tf.int32)

    # map `num_predict` samples to full length
    samples = tf.einsum("bm,bml->bl", tf.cast(samples, tf.float32),
                        tf.cast(gen_mask_map, tf.float32))
    samples = tf.cast(samples, tf.int32)
    enc_inp = tf.where(tf.equal(gen_inp, FLAGS.mask_id), samples, gen_inp)
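    # Corrupted encoder input: positions that were `FLAGS.mask_id` in `gen_inp`
    # now carry the generator's sampled tokens (ELECTRA-style replacement).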

    #### Mask marking positions where the generated token equals the target
    same_mask = tf.equal(gen_tokens, gen_tgt)
    same_mask = (tf.cast(same_mask, rep_enc2dec_full.dtype) *
                 tf.cast(gen_tgt_mask, rep_enc2dec_full.dtype))

    # monitor how many generated tokens are the same as the real ones
    same_prec = (tf.reduce_sum(tf.cast(same_mask, gen_tgt_mask.dtype)) /
                 tf.reduce_sum(gen_tgt_mask))
    monitor_dict["same_percent"] = same_prec

    # If same, change the edit_label to original (0)
    dec_same_mask_full = tf.einsum("bi,bil->bl", same_mask, rep_enc2dec_full)
    edit_label = tf.where(
        tf.cast(dec_same_mask_full, tf.bool),
        tf.zeros(tf.shape(dec_same_mask_full), dtype=edit_label.dtype),
        edit_label)

    # If same, exclude from LM loss
    dec_same_mask_part = tf.einsum("bi,bij->bj", same_mask, rep_enc2dec_part)
    dec_diff_mask_part = 1.0 - tf.cast(dec_same_mask_part,
                                       dec_lm_tgt_mask.dtype)
    dec_lm_tgt_mask = dec_lm_tgt_mask * dec_diff_mask_part
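    # (By the einsum shapes, `rep_enc2dec_full` / `rep_enc2dec_part` are 0/1
    # matrices mapping each generator prediction slot to decoder positions:
    # `_full` over the full decoder sequence, `_part` over the decoder's
    # LM-prediction slots.)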

    # shapes
    bsz = tf.shape(enc_inp)[0]
    src_len = tf.shape(enc_inp)[1]
    tgt_len = tf.shape(dec_inp)[1]

    ##### format joint model inputs
    inputs = tf.concat([enc_inp, dec_inp], axis=1)
    type_id = tf.concat([enc_type, dec_type], axis=1)
    source_seg = tf.zeros([bsz, src_len], dtype=inputs.dtype)
    target_seg = tf.ones([bsz, tgt_len], dtype=inputs.dtype)

    ##### attention mask: note that `1` indicates CANNOT attend
    # src mask
    src_to_src = tf.not_equal(source_seg[:, :, None], source_seg[:, None, :])
    src_to_tgt = tf.ones([bsz, src_len, tgt_len], dtype=src_to_src.dtype)
    src_mask = tf.concat([src_to_src, src_to_tgt], axis=2)

    # tgt mask
    tgt_to_src = tf.not_equal(target_seg[:, :, None], source_seg[:, None, :])
    tgt_to_tgt = tf.not_equal(target_seg[:, :, None], target_seg[:, None, :])
    causal_mask = tf.cast(causal_attn_mask(qlen=tgt_len), tgt_to_tgt.dtype)
    # If either mask is `1` (cannot attend), the combined mask must also be `1`,
    # hence `logical_or`
    tgt_to_tgt = tf.logical_or(tgt_to_tgt, causal_mask)
    tgt_mask = tf.concat([tgt_to_src, tgt_to_tgt], axis=2)

    # concat
    perm_mask = tf.concat([src_mask, tgt_mask], axis=1)
    perm_mask = tf.cast(perm_mask, tf_float)
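    # Net effect with source_seg == 0 and target_seg == 1: full self-attention
    # inside the encoder, causal self-attention inside the decoder, and no
    # attention in either direction between the two (the segment ids never match).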

    ##### Position sequence
    pos_seq = tf.concat([enc_pos, dec_pos], axis=1)

    #### Transformer Model
    with tf.variable_scope("model", reuse=tf.AUTO_REUSE):
        input_embed = inp_func(inputs=inputs, pos_seq=pos_seq, type_id=type_id)
        tfm_func = _get_tfm_func(initializer, is_training, phase="pretrain")
        output, _ = tfm_func(inputs=input_embed,
                             input_mask=None,
                             perm_mask=perm_mask)

        #### edit loss
        tgt_out = output[:, src_len:]

        edit_loss, edit_logits = model.cls_loss(hidden=tgt_out,
                                                target=edit_label,
                                                n_cls=4,
                                                d_model=FLAGS.d_model,
                                                initializer=initializer,
                                                target_mapping=None,
                                                hidden_mapping=None,
                                                return_logits=True,
                                                scope="edit_type_loss")

        dec_tgt_mask = tf.cast(1.0 - dec_mask, tf.float32)
        edit_loss = (tf.reduce_sum(edit_loss * dec_tgt_mask) /
                     tf.reduce_sum(dec_tgt_mask))

        edit_pred = tf.cast(tf.argmax(edit_logits, axis=-1),
                            dtype=edit_label.dtype)
        edit_corr = tf.cast(tf.equal(edit_pred, edit_label), dtype=tf.float32)
        edit_accu = (tf.reduce_sum(edit_corr * dec_tgt_mask) /
                     tf.reduce_sum(dec_tgt_mask))

        if FLAGS.edit_weight > 0:
            # monitor
            monitor_dict["edit_loss"] = edit_loss
            monitor_dict["accu_edit"] = edit_accu

            def get_class_acc(label_id):
                mask = tf.ones(tf.shape(edit_label),
                               dtype=edit_label.dtype) * label_id
                mask = tf.equal(edit_label, mask)
                mask = tf.cast(mask, dtype=dec_tgt_mask.dtype) * dec_tgt_mask
                accu = tf.reduce_sum(edit_corr * mask) / tf.reduce_sum(mask)
                return accu

            monitor_dict["accu_orig"] = get_class_acc(0)
            if FLAGS.del_ratio > 0:
                monitor_dict["accu_del"] = get_class_acc(FLAGS.del_label)
            if FLAGS.ins_ratio > 0:
                monitor_dict["accu_ins"] = get_class_acc(FLAGS.ins_label)
            if FLAGS.rep_ratio > 0:
                monitor_dict["accu_rep"] = get_class_acc(FLAGS.rep_label)

            # accumulate total loss
            total_loss += FLAGS.edit_weight * edit_loss

        #### LM loss
        #### Only predict the target part
        tgt_out = output[:, src_len:]
        tf.logging.info("Output: %s, target output: %s", output.shape,
                        tgt_out.shape)
        lm_loss, _ = model.lm_loss(hidden=tgt_out,
                                   target=dec_masked_tgt,
                                   n_token=n_token,
                                   d_model=FLAGS.d_model,
                                   initializer=initializer,
                                   lookup_table=shared_embed_table,
                                   tie_weight=FLAGS.tie_weight,
                                   target_mapping=None,
                                   hidden_mapping=dec_mask_map,
                                   use_tpu=FLAGS.use_tpu)

        if lm_loss.dtype != tf.float32:
            lm_loss = tf.cast(lm_loss, tf.float32)
        dec_lm_tgt_mask = tf.cast(dec_lm_tgt_mask, lm_loss.dtype)
        lm_loss = (tf.reduce_sum(lm_loss * dec_lm_tgt_mask) /
                   tf.reduce_sum(dec_lm_tgt_mask))

        if FLAGS.lm_weight > 0:
            # monitor
            monitor_dict["lm_loss"] = lm_loss

            # accumulate total loss
            total_loss += FLAGS.lm_weight * lm_loss

    return total_loss, {}, monitor_dict
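The `*_mask_map` / `target_mapping` features used above (`gen_mask_map`, `dec_mask_map`, and `target_mapping` in the MASS loss below) are, by the einsum in the sampling step, [bsz, num_predict, seq_len] matrices, presumably one-hot, that connect the `num_predict` prediction slots to their positions in the full sequence. A minimal NumPy sketch of the gather that `hidden_mapping` presumably performs inside `model.lm_loss` (the shapes and the einsum equation here are illustrative assumptions):

import numpy as np

bsz, seq_len, num_predict, d_model = 2, 6, 3, 4

# One-hot map [bsz, num_predict, seq_len]: slot i selects the i-th masked position.
mask_map = np.zeros((bsz, num_predict, seq_len))
mask_map[:, 0, 1] = mask_map[:, 1, 3] = mask_map[:, 2, 5] = 1.0

hidden = np.random.randn(bsz, seq_len, d_model)

# Gather only the hidden states of the masked positions before the LM softmax.
masked_hidden = np.einsum("bml,bld->bmd", mask_map, hidden)
print(masked_hidden.shape)  # (2, 3, 4)

# The reverse direction ("bm,bml->bl", as used after sampling above) scatters
# per-slot values back to full sequence length.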
Code example #2
def joint_loss(features, labels, n_token, is_training):
    """Decoder only seq2seq."""
    del labels

    initializer = _get_initializer()

    #### Unpack input
    source = features["source"]
    source_pos = features["source_position"]
    source_seg = features["source_segmentation"]

    target = features["target"]
    target_pos = features["target_position"]
    target_seg = features["target_segmentation"]

    # shapes
    bsz = tf.shape(source)[0]
    src_len = tf.shape(source)[1]
    tgt_len = tf.shape(target)[1]

    if FLAGS.use_bfloat16:
        tf_float = tf.bfloat16
    else:
        tf_float = tf.float32

    ##### format inputs
    inputs = tf.concat([source, target], axis=1)
    position = tf.concat([source_pos, target_pos], axis=1)

    src_type_id = features.get("source_type",
                               tf.zeros([bsz, src_len], dtype=inputs.dtype))
    tgt_type_id = features.get("target_type",
                               tf.ones([bsz, tgt_len], dtype=inputs.dtype))

    if FLAGS.double_type:
        # offer an indicator to differentiate source and target
        tgt_type_id = tgt_type_id + FLAGS.n_type
    type_id = tf.concat([src_type_id, tgt_type_id], axis=1)

    ##### attention mask: note that `1` indicates CANNOT attend
    # src mask
    src_to_src = tf.not_equal(source_seg[:, :, None], source_seg[:, None, :])
    src_to_tgt = tf.ones([bsz, src_len, tgt_len], dtype=src_to_src.dtype)
    src_mask = tf.concat([src_to_src, src_to_tgt], axis=2)

    # tgt mask
    tgt_to_src = tf.not_equal(target_seg[:, :, None], source_seg[:, None, :])
    tgt_to_tgt = tf.not_equal(target_seg[:, :, None], target_seg[:, None, :])
    causal_mask = tf.cast(causal_attn_mask(qlen=tgt_len), tgt_to_tgt.dtype)
    # If either mask is `1` (cannot attend), the combined mask must also be `1`,
    # hence `logical_or`
    tgt_to_tgt = tf.logical_or(tgt_to_tgt, causal_mask)
    tgt_mask = tf.concat([tgt_to_src, tgt_to_tgt], axis=2)

    # concat
    perm_mask = tf.concat([src_mask, tgt_mask], axis=1)
    perm_mask = tf.cast(perm_mask, tf_float)

    # padding
    non_pad_mask = tf.not_equal(target_seg, 0)
    all_eos = tf.constant(FLAGS.eos_id, shape=target.shape, dtype=target.dtype)
    # Replace all <pad> (/P) with <eos> (/S)
    #   - target : /S a1 a2 a3 /S b1 b2 /S c1 c2 /P /P
    #   - tmptgt : /S a1 a2 a3 /S b1 b2 /S c1 c2 /S /S
    tmptgt = tf.where(non_pad_mask, target, all_eos)
    # Shift the `tmptgt` to form the (next-step) prediction target
    #   - target   : /S a1 a2 a3 /S b1 b2 /S c1 c2 /P /P
    #   - pred_tgt : a1 a2 a3 /S b1 b2 /S c1 c2 /S /S /S
    pred_tgt = tf.concat([tmptgt[:, 1:], tmptgt[:, :1]], axis=1)
    loss_mask = tf.cast(non_pad_mask, tf.float32)
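    # `loss_mask` keys off the *input* positions, so the padded tail (whose
    # `pred_tgt` entries are just the wrapped/filled /S tokens) adds nothing to
    # the loss.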

    #### Transformer Model
    with tf.variable_scope("model", reuse=tf.AUTO_REUSE):
        inp_func = _get_inp_func(n_token, FLAGS.d_model, initializer,
                                 is_training)
        input_embed, word_embed_table = inp_func(inputs=inputs,
                                                 type_id=type_id,
                                                 pos_seq=position,
                                                 return_embed_table=True)

        tfm_func = _get_tfm_func(initializer, is_training, phase="pretrain")
        output, _ = tfm_func(inputs=input_embed,
                             input_mask=None,
                             perm_mask=perm_mask,
                             pos_seq=position)

        #### Only predict the target part
        tgt_out = output[:, src_len:]
        tf.logging.info("Output: %s, target output: %s", output.shape,
                        tgt_out.shape)

        lm_loss, nll_loss = model.lm_loss(hidden=tgt_out,
                                          target=pred_tgt,
                                          n_token=n_token,
                                          d_model=FLAGS.d_model,
                                          initializer=initializer,
                                          lookup_table=word_embed_table,
                                          tie_weight=FLAGS.tie_weight,
                                          label_smooth=FLAGS.label_smooth,
                                          use_tpu=FLAGS.use_tpu)

        if lm_loss.dtype != tf.float32:
            lm_loss = tf.cast(lm_loss, tf.float32)

        num_loss = tf.reduce_sum(loss_mask)
        total_loss = tf.reduce_sum(lm_loss * loss_mask) / num_loss
        nll = tf.reduce_sum(nll_loss * loss_mask) / num_loss

    # To be compatible with fairseq, convert to base 2 for logging
    monitor_dict = {
        "loss": total_loss / math.log(2),
        "nll": nll / math.log(2),
        "num_loss": num_loss,
    }

    return total_loss, monitor_dict
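To visualize the mask construction above, here is the same logic in NumPy for a toy packed batch (two source documents plus padding). This assumes `causal_attn_mask` marks strictly-future positions with 1, which is only implied by this listing:

import numpy as np

# Toy packed batch: segment ids tie each target document to its source document;
# 0 marks padding.
source_seg = np.array([[1, 1, 2, 2, 0]])
target_seg = np.array([[1, 1, 2, 0]])
src_len, tgt_len = source_seg.shape[1], target_seg.shape[1]

# Same convention as above: 1 means CANNOT attend.
src_to_src = source_seg[:, :, None] != source_seg[:, None, :]
src_to_tgt = np.ones((1, src_len, tgt_len), dtype=bool)
tgt_to_src = target_seg[:, :, None] != source_seg[:, None, :]
causal = np.triu(np.ones((tgt_len, tgt_len), dtype=bool), k=1)
tgt_to_tgt = (target_seg[:, :, None] != target_seg[:, None, :]) | causal

perm_mask = np.concatenate(
    [np.concatenate([src_to_src, src_to_tgt], axis=2),
     np.concatenate([tgt_to_src, tgt_to_tgt], axis=2)], axis=1)
# Each target row can see its own source segment and the causal prefix of its
# own target segment.
print(perm_mask.astype(int)[0])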
Code example #3
def transformer(inputs,
                n_layer,
                d_model,
                n_head,
                d_head,
                d_inner,
                dropout,
                dropatt,
                dropact,
                initializer,
                is_training,
                context=None,
                context_mask=None,
                input_mask=None,
                perm_mask=None,
                ff_activation="relu",
                causal=False,
                rel_attn=False,
                pos_seq=None,
                clamp_len=-1,
                return_all_hidden=False,
                scope="transformer"):
    """Transformer model."""

    monitor_dict = {}
    if input_mask is not None:
        monitor_dict["inp_mask"] = input_mask

    tf.logging.info("===== Transformer =====")
    tf.logging.info("Input related:")
    tf.logging.info("  - inputs %s", inputs)
    tf.logging.info("  - input_mask %s", input_mask)
    tf.logging.info("  - perm_mask %s", perm_mask)
    tf.logging.info("  - context %s", context)
    tf.logging.info("  - context_mask %s", context_mask)
    tf.logging.info("Hparam related:")
    tf.logging.info("  - initializer %s", initializer)
    tf.logging.info("  - ff_activation %s", ff_activation)
    tf.logging.info("  - causal %s", causal)
    tf.logging.info("============================")

    hiddens = []
    with tf.variable_scope(scope):
        ##### Attention mask
        if causal:
            causal_mask = causal_attn_mask(tf.shape(inputs)[1],
                                           dtype=inputs.dtype)
            causal_mask = causal_mask[None, None]
        else:
            causal_mask = None
        attn_mask = merge_attn_masks(causal_mask, input_mask, perm_mask)

        ##### Input projection
        if inputs.shape.as_list()[-1] != d_model:
            tf.logging.info("Project input embedding: %d -> %d",
                            inputs.shape.as_list()[-1], d_model)
            output = tf.layers.dense(inputs,
                                     d_model,
                                     activation=None,
                                     kernel_initializer=initializer,
                                     name="input_projection")
        else:
            output = inputs

        hiddens.append(output)

        ##### Get relative attention bias
        if rel_attn:
            tf.logging.info("Use relative attention")
            if pos_seq is None:
                seq_len = tf.shape(output)[1]
                attn_bias = consecutive_rel_encoding(seq_len, d_model, n_head,
                                                     clamp_len, dropout,
                                                     is_training, initializer,
                                                     output.dtype)
            else:
                attn_bias = rel_encoding(pos_seq, pos_seq, d_model, n_head,
                                         clamp_len, dropout, is_training,
                                         initializer, output.dtype)
        else:
            attn_bias = None

        ##### Attention layers
        for i in range(n_layer):
            with tf.variable_scope("layer_{}".format(i)):
                output, attn_dict = multihead_attn(
                    q=output,
                    k=output,
                    v=output,
                    attn_mask=attn_mask,
                    d_model=d_model,
                    n_head=n_head,
                    d_head=d_head,
                    dropout=dropout,
                    dropatt=dropatt,
                    is_training=is_training,
                    kernel_initializer=initializer,
                    attn_bias=attn_bias,
                    scope="self_attn")

                if context is not None:
                    output, _ = multihead_attn(q=output,
                                               k=context,
                                               v=context,
                                               attn_mask=context_mask,
                                               d_model=d_model,
                                               n_head=n_head,
                                               d_head=d_head,
                                               dropout=dropout,
                                               dropatt=dropatt,
                                               is_training=is_training,
                                               kernel_initializer=initializer,
                                               scope="cross_attn")

                output, ffn_dict = positionwise_ffn(
                    inp=output,
                    d_model=d_model,
                    d_inner=d_inner,
                    dropout=dropout,
                    dropact=dropact,
                    initializer=initializer,
                    activation_type=ff_activation,
                    is_training=is_training)

                hiddens.append(output)

                # Update monitor dict
                monitor_dict = update_monitor_dict(
                    monitor_dict, attn_dict, prefix="layer_{}_attn".format(i))
                monitor_dict = update_monitor_dict(
                    monitor_dict, ffn_dict, prefix="layer_{}_ff".format(i))

        if return_all_hidden:
            return hiddens, monitor_dict
        else:
            return output, monitor_dict
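For orientation, a hypothetical invocation of `transformer`. This is a sketch only: the placeholder shapes and hyperparameters are illustrative, TF 1.x graph mode is assumed, and the helpers it calls (`multihead_attn`, `positionwise_ffn`, `merge_attn_masks`, `causal_attn_mask`, `consecutive_rel_encoding`, `rel_encoding`, `update_monitor_dict`) must come from the same module:

import tensorflow as tf

# Illustrative shapes: [batch, seq_len, d_embed] embeddings and a 0/1 padding mask.
inp_embed = tf.placeholder(tf.float32, [8, 128, 768])
pad_mask = tf.placeholder(tf.float32, [8, 128])  # 1 marks padding positions

output, monitor_dict = transformer(
    inputs=inp_embed,
    n_layer=12,
    d_model=768,
    n_head=12,
    d_head=64,
    d_inner=3072,
    dropout=0.1,
    dropatt=0.1,
    dropact=0.0,
    initializer=tf.truncated_normal_initializer(stddev=0.02),
    is_training=True,
    input_mask=pad_mask,
    causal=False)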
Code example #4
def mass_loss(features, labels, mems, n_token, is_training):
    """MASS pretraining loss."""
    del labels
    del mems

    initializer = _get_initializer()

    # Type
    if FLAGS.use_bfloat16:
        tf_float = tf.bfloat16
    else:
        tf_float = tf.float32

    #### Unpack input
    target = features["target"]
    target_mask = features["target_mask"]
    target_mapping = features["target_mapping"]

    enc_inp = features["enc_inp"]
    enc_type = features["type_id"]
    enc_pos = features["enc_pos"]

    dec_inp = features["dec_inp"]
    dec_type = features["dec_type"]
    dec_pos = features["dec_pos"]
    dec_seg = features["dec_seg"]

    # shapes
    bsz = tf.shape(enc_inp)[0]
    enc_len = tf.shape(enc_inp)[1]
    dec_len = tf.shape(dec_inp)[1]

    ##### format inputs
    inputs = tf.concat([enc_inp, dec_inp], axis=1)
    position = tf.concat([enc_pos, dec_pos], axis=1)
    type_id = tf.concat([enc_type, dec_type], axis=1)

    ##### attention mask: note that `1` indicates CANNOT attend
    # enc mask
    enc_to_enc = tf.zeros([bsz, enc_len, enc_len], dtype=tf_float)
    enc_to_dec = tf.ones([bsz, enc_len, dec_len], dtype=tf_float)
    enc_mask = tf.concat([enc_to_enc, enc_to_dec], axis=2)

    # dec mask
    dec_to_enc = tf.zeros([bsz, dec_len, enc_len], dtype=tf.bool)
    dec_to_dec = tf.not_equal(dec_seg[:, :, None], dec_seg[:, None, :])
    causal_mask = tf.cast(causal_attn_mask(qlen=dec_len), tf.bool)
    # If either mask is `1` (cannot attend), the combined mask must also be `1`,
    # hence `logical_or`
    dec_to_dec = tf.logical_or(dec_to_dec, causal_mask)
    dec_mask = tf.cast(tf.concat([dec_to_enc, dec_to_dec], axis=2), tf_float)
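    # The decoder attends to the entire encoder (`dec_to_enc` is all zeros) and
    # causally to decoder positions within the same segment.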

    # concat
    perm_mask = tf.concat([enc_mask, dec_mask], axis=1)
    perm_mask = tf.cast(perm_mask, tf_float)

    #### Transformer Model
    with tf.variable_scope("model", reuse=tf.AUTO_REUSE):
        inp_func = _get_inp_func(n_token, FLAGS.d_model, initializer,
                                 is_training)
        input_embed, word_embed_table = inp_func(inputs=inputs,
                                                 type_id=type_id,
                                                 pos_seq=position,
                                                 return_embed_table=True)

        tfm_func = _get_tfm_func(initializer, is_training, phase="pretrain")
        output, _ = tfm_func(inputs=input_embed,
                             input_mask=None,
                             perm_mask=perm_mask)

        #### Only predict the target part
        enc_out = output[:, :enc_len]
        dec_out = output[:, enc_len:]
        tf.logging.info("Output: %s, enc output: %s, dec output: %s",
                        output.shape, enc_out.shape, dec_out.shape)

        enc_loss, _ = model.lm_loss(hidden=enc_out,
                                    target=target,
                                    n_token=n_token,
                                    d_model=FLAGS.d_model,
                                    initializer=initializer,
                                    lookup_table=word_embed_table,
                                    tie_weight=FLAGS.tie_weight,
                                    hidden_mapping=target_mapping,
                                    use_tpu=FLAGS.use_tpu)

        dec_loss, _ = model.lm_loss(hidden=dec_out,
                                    target=target,
                                    n_token=n_token,
                                    d_model=FLAGS.d_model,
                                    initializer=initializer,
                                    lookup_table=word_embed_table,
                                    tie_weight=FLAGS.tie_weight,
                                    use_tpu=FLAGS.use_tpu)

        if dec_loss.dtype != tf.float32:
            dec_loss = tf.cast(dec_loss, tf.float32)
        if enc_loss.dtype != tf.float32:
            enc_loss = tf.cast(enc_loss, tf.float32)
        loss_mask = tf.cast(target_mask, tf.float32)

        num_loss = tf.reduce_sum(loss_mask)
        avg_dec_loss = tf.reduce_sum(dec_loss * loss_mask) / num_loss
        avg_enc_loss = tf.reduce_sum(enc_loss * loss_mask) / num_loss

    monitor_dict = {}
    total_loss = 0
    if FLAGS.enc_weight > 0:
        total_loss += FLAGS.enc_weight * avg_enc_loss
        monitor_dict["loss_enc"] = avg_enc_loss
    if FLAGS.dec_weight > 0:
        total_loss += FLAGS.dec_weight * avg_dec_loss
        monitor_dict["loss_dec"] = avg_dec_loss

    monitor_dict["loss"] = total_loss

    return total_loss, {}, monitor_dict
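`causal_attn_mask` is used by several of the losses above but is not part of this listing; under the `1 = cannot attend` convention it presumably returns a strictly upper-triangular matrix (each query may attend to itself and to earlier positions only). A minimal sketch of that assumed behaviour:

import tensorflow as tf

def causal_attn_mask(qlen, dtype=tf.float32):
    """Assumed behaviour: mask[i, j] = 1 where j > i, i.e. future positions."""
    ones = tf.ones([qlen, qlen], dtype=dtype)
    # band_part(ones, -1, 0) keeps the lower triangle (incl. diagonal);
    # subtracting it leaves 1s strictly above the diagonal.
    return ones - tf.linalg.band_part(ones, -1, 0)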