Example #1
0
 def build_loss(self, logits, labels):
     """Pick one of three cross-entropy depths (15/7/3) per training step.

     The branch is chosen by ``global_step % 3``, so training cycles
     through the three label depths, one per step.
     """
     step = tf.train.get_or_create_global_step()
     phase = tf.mod(step, 3)
     branches = [
         (tf.equal(phase, 0), lambda: softmax_cross_entropy(labels, 15, logits)),
         (tf.equal(phase, 1), lambda: softmax_cross_entropy(labels, 7, logits)),
         (tf.equal(phase, 2), lambda: softmax_cross_entropy(labels, 3, logits)),
     ]
     # exclusive=True asserts exactly one predicate holds at runtime.
     return tf.case(branches, exclusive=True)
Example #2
0
    def build_loss(self, logits, labels):
        """Build the knowledge-distillation student loss.

        The total is the cross-entropy on the final-layer logits plus an
        MSE distillation loss against the teacher's final-layer logits;
        when probe training is enabled and intermediate logits exist, a
        per-layer probe distillation loss is added as well.
        """
        final_logits = logits[-1]
        cross_entropy = softmax_cross_entropy(
            labels, self.config.num_labels, final_logits)
        final_kd = build_kd_loss(
            teacher_logits=self.teacher_logits[-1],
            student_logits=final_logits,
            task_balance=0.5,
            distill_tempreture=2.0,
            labels=None,
            loss_type='mse')

        total = cross_entropy + final_kd

        # Probe distillation only applies when more than one layer of
        # logits is available.
        if self.config.train_probes and len(logits) > 1:
            total += build_kd_probes_loss(
                teacher_logits=self.teacher_logits,
                student_logits=logits,
                task_balance=0.5,
                distill_tempreture=2.0,
                labels=None,
                loss_type='mse')

        return total
 def build_loss(self, logits, labels):
     """Sum the cross-entropy loss over every layer's logits (KD teacher)."""
     return sum(
         (softmax_cross_entropy(labels, self.config.num_labels, layer_logits)
          for layer_logits in logits),
         0.0)
Example #4
0
 def build_loss(self, logits, labels):
     """Select the text-match training loss.

     With two or more labels this is a softmax cross-entropy over
     ``num_labels`` classes; otherwise the task is treated as regression
     and mean squared error is used.
     """
     if self.config.num_labels >= 2:
         return softmax_cross_entropy(labels,
                                      depth=self.config.num_labels,
                                      logits=logits)
     return mean_square_error(labels, logits)
 def build_loss(self, logits, labels):
     """Pick the text-classification loss based on the model config.

     Multi-label configs use sigmoid cross-entropy, single-output configs
     use mean squared error, and everything else uses softmax
     cross-entropy.
     """
     if getattr(self.config, "multi_label", False):
         return multi_label_sigmoid_cross_entropy(
             labels, self.config.num_labels, logits)
     if self.config.num_labels == 1:
         return mean_square_error(labels, logits)
     return softmax_cross_entropy(labels, self.config.num_labels, logits)
 def build_loss(self, logits, labels):
     """Weighted classification loss plus an averaged per-layer domain loss.

     NOTE(review): the domain labels are shuffled before computing each
     layer's domain loss — presumably for adversarial domain confusion;
     confirm against the training setup.
     """
     classification_loss = weighted_softmax_cross_entropy(
         labels, self.num_labels, logits, self.weights)

     per_layer_losses = []
     for layer_index in layer_indexes:
         shuffled_domains = tf.random_shuffle(self.domains)
         layer_domain_logits = self.domain_logits["domain_logits_" +
                                                  str(layer_index)]
         per_layer_losses.append(
             softmax_cross_entropy(shuffled_domains, num_domains,
                                   layer_domain_logits))

     mean_domain_loss = sum(per_layer_losses) / len(layer_indexes)
     return classification_loss + domain_weight * mean_domain_loss
Example #7
0
    def build_loss(self, logits, labels):
        """Build the pre-training loss selected by ``_APP_FLAGS.loss``.

        Supported modes:
          * ``"mlm"``: masked-LM loss plus a binary (depth-2) auxiliary
            task loss.
          * ``"mlm+nsp"`` / ``"mlm+sop"``: masked-LM loss plus a
            next-sentence (or sentence-order) prediction loss.

        Raises:
            ValueError: if ``_APP_FLAGS.loss`` is none of the supported
                modes (previously this fell through and returned ``None``
                silently, crashing later in the training step).
        """
        if _APP_FLAGS.loss == "mlm":
            lm_logits, task_1_logits = logits
            masked_lm_ids, masked_lm_weights, task_1_label = labels
            masked_lm_loss = masked_language_model_loss(
                lm_logits, masked_lm_ids, masked_lm_weights,
                _APP_FLAGS.vocab_size)

            # Auxiliary binary classification head (depth 2).
            task_1_loss = softmax_cross_entropy(task_1_label, 2, task_1_logits)

            return masked_lm_loss + task_1_loss

        if _APP_FLAGS.loss in ("mlm+nsp", "mlm+sop"):
            lm_logits, nsp_logits = logits
            masked_lm_ids, masked_lm_weights, nx_sent_labels = labels

            masked_lm_loss = masked_language_model_loss(
                lm_logits, masked_lm_ids, masked_lm_weights,
                _APP_FLAGS.vocab_size)
            nsp_loss = next_sentence_prediction_loss(nsp_logits,
                                                     nx_sent_labels)

            return masked_lm_loss + nsp_loss

        # Fail fast on an unrecognized mode instead of returning None.
        raise ValueError("Unsupported loss mode: %s" % _APP_FLAGS.loss)
Example #8
0
 def build_loss(self, logits, labels):
     """Return the softmax cross-entropy over ``self.num_labels`` classes."""
     loss = softmax_cross_entropy(labels, self.num_labels, logits)
     return loss