Example #1
File: uda.py Project: dptam/uda
def create_model(
    bert_config,
    is_training,
    input_ids,
    input_mask,
    input_type_ids,
    labels,
    num_labels,
    use_one_hot_embeddings,
    tsa,
    unsup_ratio,
    global_step,
    num_train_steps,
    ):

  # When training, each batch packs [supervised | unsup original | unsup augmented]
  # examples, so its size must split evenly into 1 + 2 * unsup_ratio parts.
  num_sample = input_ids.shape[0].value
  if is_training:
    assert num_sample % (1 + 2 * unsup_ratio) == 0
    sup_batch_size = num_sample // (1 + 2 * unsup_ratio)
    unsup_batch_size = sup_batch_size * unsup_ratio
  else:
    sup_batch_size = num_sample
    unsup_batch_size = 0

  pooled = modeling.bert_model(
      config=bert_config,
      is_training=is_training,
      input_ids=input_ids,
      input_mask=input_mask,
      token_type_ids=input_type_ids,
      use_one_hot_embeddings=use_one_hot_embeddings)

  clas_logits = hidden_to_logits(
      hidden=pooled,
      is_training=is_training,
      num_classes=num_labels,
      scope="classifier")

  log_probs = tf.nn.log_softmax(clas_logits, axis=-1)
  correct_label_probs = None

  with tf.variable_scope("sup_loss"):
    sup_log_probs = log_probs[:sup_batch_size]
    one_hot_labels = tf.one_hot(labels, depth=num_labels, dtype=tf.float32)
    tgt_label_prob = one_hot_labels

    per_example_loss = -tf.reduce_sum(tgt_label_prob * sup_log_probs, axis=-1)
    loss_mask = tf.ones_like(per_example_loss, dtype=per_example_loss.dtype)
    correct_label_probs = tf.reduce_sum(
        one_hot_labels * tf.exp(sup_log_probs), axis=-1)

    if tsa:
      # Training Signal Annealing: anneal a probability threshold from
      # 1/num_labels up to 1 and mask out supervised examples the model
      # already predicts above it, so easy examples stop dominating the loss.
      tsa_start = 1. / num_labels
      tsa_threshold = get_tsa_threshold(
          tsa, global_step, num_train_steps,
          tsa_start, end=1)

      larger_than_threshold = tf.greater(
          correct_label_probs, tsa_threshold)
      loss_mask = loss_mask * (1 - tf.cast(larger_than_threshold, tf.float32))
    else:
      tsa_threshold = 1

    loss_mask = tf.stop_gradient(loss_mask)
    per_example_loss = per_example_loss * loss_mask
    sup_loss = (tf.reduce_sum(per_example_loss) /
                tf.maximum(tf.reduce_sum(loss_mask), 1))

  unsup_loss_mask = None
  if is_training and unsup_ratio > 0:
    # UDA consistency loss: KL between the (fixed, optionally sharpened)
    # predictions on original unlabeled examples and the predictions on
    # their augmented counterparts.
    with tf.variable_scope("unsup_loss"):
      ori_start = sup_batch_size
      ori_end = ori_start + unsup_batch_size
      aug_start = sup_batch_size + unsup_batch_size
      aug_end = aug_start + unsup_batch_size

      ori_log_probs = log_probs[ori_start : ori_end]
      aug_log_probs = log_probs[aug_start : aug_end]
      unsup_loss_mask = 1
      if FLAGS.uda_softmax_temp != -1:
        tgt_ori_log_probs = tf.nn.log_softmax(
            clas_logits[ori_start : ori_end] / FLAGS.uda_softmax_temp,
            axis=-1)
        tgt_ori_log_probs = tf.stop_gradient(tgt_ori_log_probs)
      else:
        tgt_ori_log_probs = tf.stop_gradient(ori_log_probs)

      if FLAGS.uda_confidence_thresh != -1:
        largest_prob = tf.reduce_max(tf.exp(ori_log_probs), axis=-1)
        unsup_loss_mask = tf.cast(tf.greater(
            largest_prob, FLAGS.uda_confidence_thresh), tf.float32)
        unsup_loss_mask = tf.stop_gradient(unsup_loss_mask)

      per_example_kl_loss = kl_for_log_probs(
          tgt_ori_log_probs, aug_log_probs) * unsup_loss_mask
      unsup_loss = tf.reduce_mean(per_example_kl_loss)

  else:
    unsup_loss = 0.

  return (sup_loss, unsup_loss, clas_logits[:sup_batch_size],
          per_example_loss, loss_mask,
          tsa_threshold, unsup_loss_mask, correct_label_probs)
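
Both examples call two helpers that the snippets do not define: get_tsa_threshold and kl_for_log_probs. Below is a minimal sketch of plausible TF1 implementations, following the annealing schedules described in the UDA paper; the scale factor of 5 and the exact schedule names are assumptions, not taken from the snippets.

def get_tsa_threshold(schedule, global_step, num_train_steps, start, end):
  # Anneal a threshold from `start` to `end` as training progresses.
  training_progress = tf.to_float(global_step) / tf.to_float(num_train_steps)
  if schedule == "linear_schedule":
    threshold = training_progress
  elif schedule == "exp_schedule":
    scale = 5  # assumed value
    threshold = tf.exp((training_progress - 1) * scale)   # releases signal late
  elif schedule == "log_schedule":
    scale = 5  # assumed value
    threshold = 1 - tf.exp((-training_progress) * scale)  # releases signal early
  else:
    raise ValueError("unknown TSA schedule: %s" % schedule)
  return threshold * (end - start) + start


def kl_for_log_probs(log_p, log_q):
  # KL(p || q) computed from log-probabilities, reduced over the class axis.
  p = tf.exp(log_p)
  neg_ent = tf.reduce_sum(p * log_p, axis=-1)
  neg_cross_ent = tf.reduce_sum(p * log_q, axis=-1)
  return neg_ent - neg_cross_ent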
Example #2
def create_model(
    bert_config,
    is_training,
    input_ids,
    input_mask,
    input_type_ids,
    labels,
    num_labels,
    use_one_hot_embeddings,
    tsa,
    unsup_ratio,
    global_step,
    num_train_steps,
):

    num_sample = input_ids.shape[0].value
    if is_training:
        assert num_sample % (1 + 2 * unsup_ratio) == 0
        sup_batch_size = num_sample // (1 + 2 * unsup_ratio)
        unsup_batch_size = sup_batch_size * unsup_ratio
    else:
        sup_batch_size = num_sample
        unsup_batch_size = 0

    sequence = modeling.bert_model(
        config=bert_config,
        is_training=is_training,
        input_ids=input_ids,
        input_mask=input_mask,
        token_type_ids=input_type_ids,
        use_one_hot_embeddings=use_one_hot_embeddings,
        output_type="sequence")

    clas_logits = hidden_to_logits(hidden=sequence,
                                   is_training=is_training,
                                   num_classes=num_labels,
                                   scope="classifier")

    log_probs = tf.nn.log_softmax(clas_logits, axis=-1)
    correct_label_probs = None

    with tf.variable_scope("sup_loss"):
        sup_log_probs = log_probs[:sup_batch_size]
        tf.logging.info("***************%d", sup_batch_size)
        one_hot_labels = tf.one_hot(labels, depth=num_labels, dtype=tf.float32)
        tgt_label_prob = one_hot_labels

        per_example_loss = -tf.reduce_sum(tgt_label_prob * sup_log_probs,
                                          axis=-1)
        loss_mask = tf.ones_like(per_example_loss,
                                 dtype=per_example_loss.dtype)
        correct_label_probs = tf.reduce_sum(one_hot_labels *
                                            tf.exp(sup_log_probs),
                                            axis=-1)

        if tsa:
            tsa_start = 1. / num_labels
            tsa_threshold = get_tsa_threshold(tsa,
                                              global_step,
                                              num_train_steps,
                                              tsa_start,
                                              end=1)

            larger_than_threshold = tf.greater(correct_label_probs,
                                               tsa_threshold)
            loss_mask = loss_mask * (
                1 - tf.cast(larger_than_threshold, tf.float32))
        else:
            tsa_threshold = 1

        loss_mask = tf.stop_gradient(loss_mask)
        per_example_loss = per_example_loss * loss_mask
        sup_loss = (tf.reduce_sum(per_example_loss) /
                    tf.maximum(tf.reduce_sum(loss_mask), 1))

    unsup_loss_mask = None
    # Also guard on is_training: at eval time unsup_batch_size is 0, so the
    # slices below would be empty and reduce_mean would return NaN.
    if is_training and unsup_ratio > 0:
        with tf.variable_scope("unsup_loss"):
            ori_start = sup_batch_size
            ori_end = ori_start + unsup_batch_size
            aug_start = sup_batch_size + unsup_batch_size
            aug_end = aug_start + unsup_batch_size

            ori_log_probs = log_probs[ori_start:ori_end]
            aug_log_probs_before = log_probs[aug_start:aug_end]
            # Kick out the B-TRI and E-TRI tags and rematch: reorder each
            # augmented sequence so that positions with token type id 1 come
            # before those with type id 0, realigning the augmented tokens
            # with the original sequence.
            # clas_logits has shape [batch, max_seq_length, num_labels].
            _, max_seq_length, hidden_dim = clas_logits.get_shape().as_list()
            aug_input_type_ids = input_type_ids[aug_start:aug_end]
            one = tf.where(tf.equal(aug_input_type_ids, 1))
            zero = tf.where(tf.equal(aug_input_type_ids, 0))
            p = tf.concat([one, zero], 0)

            def my_numpy_func(x):
                # Sort the (batch, position) index pairs by batch index. The
                # sort must be stable so that, within each row, the type-1
                # positions stay ahead of the type-0 positions.
                return x[x[:, 0].argsort(kind="stable")]

            aug_input_type_ids_trans = tf.py_func(my_numpy_func, [p],
                                                  Tout=tf.int64)
            aug_log_probs_middle = tf.gather_nd(
                params=aug_log_probs_before, indices=aug_input_type_ids_trans)
            aug_log_probs = tf.reshape(aug_log_probs_middle,
                                       [-1, max_seq_length, hidden_dim])
            unsup_loss_mask = 1
            if FLAGS.uda_softmax_temp != -1:
                # Sharpen the target distribution. log_softmax is invariant to
                # per-example additive shifts, so applying it to log-probs / T
                # is equivalent to applying it to the logits / T.
                m_aug_log_probs = tf.nn.log_softmax(aug_log_probs /
                                                    FLAGS.uda_softmax_temp,
                                                    axis=-1)
                tgt_aug_log_probs = tf.stop_gradient(m_aug_log_probs)
            else:
                tgt_aug_log_probs = tf.stop_gradient(aug_log_probs)

            if FLAGS.uda_confidence_thresh != -1:
                largest_prob = tf.reduce_max(tf.exp(ori_log_probs), axis=-1)
                unsup_loss_mask = tf.cast(
                    tf.greater(largest_prob, FLAGS.uda_confidence_thresh),
                    tf.float32)
                unsup_loss_mask = tf.stop_gradient(unsup_loss_mask)

            per_example_kl_loss = kl_for_log_probs(
                tgt_aug_log_probs, ori_log_probs) * unsup_loss_mask
            unsup_loss = tf.reduce_mean(per_example_kl_loss)
            tf.logging.info("*** Unsupervised consistency loss enabled ***")
    else:
        tf.logging.info("*** Skipping unsupervised loss ***")
        # Fall back to placeholders so the return signature stays consistent.
        unsup_loss = sup_loss
        tgt_aug_log_probs = log_probs
        ori_log_probs = log_probs
    return (sup_loss, unsup_loss, clas_logits[:sup_batch_size],
            per_example_loss, loss_mask, tsa_threshold, unsup_loss_mask,
            correct_label_probs, tgt_aug_log_probs, ori_log_probs)
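
Note that Example #2 reverses the KL direction relative to Example #1: the augmented predictions become the fixed, sharpened target, and the gradient flows into the predictions on the original sequence. Its rematch block is the other substantive change. The self-contained numpy sketch below (with hypothetical shapes; only the index manipulation is reproduced) shows what the gather does and why the sort over (row, position) pairs must be stable:

import numpy as np

# Hypothetical batch of 2 sequences of length 4; type id 1 marks the
# positions that should be moved to the front of each row.
type_ids = np.array([[0, 1, 1, 0],
                     [1, 0, 0, 1]])

ones = np.argwhere(type_ids == 1)   # (row, col) pairs where type == 1
zeros = np.argwhere(type_ids == 0)  # (row, col) pairs where type == 0
p = np.concatenate([ones, zeros], axis=0)

# A stable sort by row index keeps the type-1 positions ahead of the
# type-0 positions within each row; numpy's default quicksort would not
# guarantee that ordering.
indices = p[p[:, 0].argsort(kind="stable")]

seq = np.arange(8).reshape(2, 4)    # stand-in for per-token predictions
rematched = seq[indices[:, 0], indices[:, 1]].reshape(2, 4)
print(rematched)                    # [[1 2 0 3]
                                    #  [4 7 5 6]]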