コード例 #1
0
ファイル: leam.py プロジェクト: yyht/classifynet
    def build_loss(self, *args, **kargs):
        """Build the model's loss and store it on ``self.loss``.

        The base loss is chosen by ``self.config.loss``:

        * ``"softmax_loss"`` -> ``point_wise_loss.softmax_loss``
        * ``"sparse_amsoftmax_loss"`` -> ``point_wise_loss.sparse_amsoftmax_loss``
        * ``"focal_loss_multi_v1"`` -> ``point_wise_loss.focal_loss_multi_v1``

        Optional terms are then added on top:

        * when ``self.config.with_center_loss`` is truthy, a center loss on
          ``self.sent_repres`` (with ``self.memory`` as ``centers``) is stored
          on ``self.center_loss`` and added, scaled by
          ``self.config.center_gamma``;
        * when the config mode (``self.config.get("mode", "train")``) is
          ``"train"`` and ``self.config.with_label_regularization`` is truthy,
          a focal loss on ``self.class_logits`` is stored on
          ``self.class_loss`` and added, scaled by ``self.config.class_penalty``.

        Finally, the names of the trainable variables collected under
        ``self.scope`` are printed.

        Args:
            *args: Extra positional arguments forwarded unchanged to every
                ``point_wise_loss`` function called here.
            **kargs: Extra keyword arguments forwarded likewise.

        Raises:
            ValueError: If ``self.config.loss`` is not one of the supported
                names listed above.
        """
        loss_name = self.config.loss
        if loss_name == "softmax_loss":
            self.loss, _ = point_wise_loss.softmax_loss(
                self.logits, self.gold_label, *args, **kargs)
        elif loss_name == "sparse_amsoftmax_loss":
            self.loss, _ = point_wise_loss.sparse_amsoftmax_loss(
                self.logits, self.gold_label, self.config, *args, **kargs)
        elif loss_name == "focal_loss_multi_v1":
            self.loss, _ = point_wise_loss.focal_loss_multi_v1(
                self.logits, self.gold_label, self.config, *args, **kargs)
        else:
            # Fail fast: otherwise self.loss stays unset and the model breaks
            # later with an unrelated AttributeError (or trains without a loss).
            raise ValueError(
                "unsupported loss {!r}; expected one of 'softmax_loss', "
                "'sparse_amsoftmax_loss', 'focal_loss_multi_v1'".format(loss_name))

        if self.config.with_center_loss:
            # *args is written before the keywords on purpose: Python always
            # binds *args positionally right after the explicit positional
            # arguments, regardless of where it appears in the call.
            self.center_loss, _ = point_wise_loss.center_loss_v2(
                self.sent_repres, self.gold_label, *args,
                centers=self.memory, config=self.config, **kargs)
            self.loss = self.loss + self.config.center_gamma * self.center_loss

        # Label regularization is only added when the config mode is "train"
        # (the default when no mode is set).
        if (self.config.get("mode", "train") == "train"
                and self.config.with_label_regularization):
            print("===with class regularization===")
            self.class_loss, _ = point_wise_loss.focal_loss_multi_v1(
                self.class_logits, self.gold_label, self.config, *args, **kargs)
            self.loss += self.config.class_penalty * self.class_loss

        trainable_vars = tf.get_collection(
            tf.GraphKeys.TRAINABLE_VARIABLES, self.scope)
        print("List of Variables:")
        for var in trainable_vars:
            print(var.name)
コード例 #2
0
ファイル: base_transformer.py プロジェクト: yyht/classifynet
 def build_loss(self, *args, **kargs):
     """Build the model's loss and store it on ``self.loss``.

     The base loss is chosen by ``self.config.loss``:

     * ``"softmax_loss"`` -> ``point_wise_loss.softmax_loss``
     * ``"sparse_amsoftmax_loss"`` -> ``point_wise_loss.sparse_amsoftmax_loss``
     * ``"focal_loss_multi_v1"`` -> ``point_wise_loss.focal_loss_multi_v1``

     When ``self.config.with_center_loss`` is truthy, a center loss on
     ``self.sent_repres`` is stored on ``self.center_loss`` and added to
     ``self.loss``, scaled by ``self.config.center_gamma``.

     Args:
         *args: Extra positional arguments forwarded unchanged to every
             ``point_wise_loss`` function called here.
         **kargs: Extra keyword arguments forwarded likewise.

     Raises:
         ValueError: If ``self.config.loss`` is not one of the supported
             names listed above.
     """
     loss_name = self.config.loss
     if loss_name == "softmax_loss":
         self.loss, _ = point_wise_loss.softmax_loss(
             self.logits, self.gold_label, *args, **kargs)
     elif loss_name == "sparse_amsoftmax_loss":
         self.loss, _ = point_wise_loss.sparse_amsoftmax_loss(
             self.logits, self.gold_label, self.config, *args, **kargs)
     elif loss_name == "focal_loss_multi_v1":
         self.loss, _ = point_wise_loss.focal_loss_multi_v1(
             self.logits, self.gold_label, self.config, *args, **kargs)
     else:
         # Fail fast: otherwise self.loss stays unset and the model breaks
         # later with an unrelated AttributeError (or trains without a loss).
         raise ValueError(
             "unsupported loss {!r}; expected one of 'softmax_loss', "
             "'sparse_amsoftmax_loss', 'focal_loss_multi_v1'".format(loss_name))

     if self.config.with_center_loss:
         # NOTE(review): config is passed positionally as the third argument
         # and no centers are supplied, while another caller in this project
         # passes centers=... and config=... by keyword — confirm this matches
         # center_loss_v2's signature.
         self.center_loss, _ = point_wise_loss.center_loss_v2(
             self.sent_repres, self.gold_label, self.config, *args, **kargs)
         self.loss = self.loss + self.config.center_gamma * self.center_loss