예제 #1
0
 def update(self, loss, preds, golds):
     """Accumulate the batch loss and record its predictions and gold labels.

     ``torch.exp`` is applied to ``preds``, which presumes the network
     outputs log-probabilities (TODO confirm against the model).  Positions
     where ``golds`` equals ``self.mask_id`` are flagged in a mask that is
     handed to ``unmask``; the results are stored both per sentence (the
     ``_..._sent`` lists) and after a further ``unroll``.
     """
     probs = torch.exp(preds)
     classes = probs.argmax(dim=-1)
     keep = golds != self.mask_id

     self.loss += loss
     self.pred_probs.append(unroll(unmask(probs, keep)))

     # Sentence-level predicted classes, then the unrolled view of the same data.
     sent_classes = unmask(classes, keep)
     self._pred_classes_sent.append(sent_classes)
     self.pred_classes.append(unroll(sent_classes))

     # Same two-step storage for the gold labels.
     sent_golds = unmask(golds, keep)
     self._golds_sent.append(sent_golds)
     self.golds.append(unroll(sent_golds))
예제 #2
0
    def update(self, loss, preds, golds):
        """Add ``loss`` to the running total and store this batch's outputs.

        Predictions and gold labels are passed through ``unmask`` using a
        mask of the positions whose gold label differs from ``self.mask_id``.
        They are kept per sentence in the ``_..._sent`` lists and, after
        ``unroll``, in ``pred_probs`` / ``pred_classes`` / ``golds``.
        """
        # exp() turns log-probabilities into probabilities; this assumes a
        # log softmax at the network output.
        probs = torch.exp(preds)
        classes = probs.argmax(dim=-1)
        self.loss += loss

        # Positions carrying the mask id are excluded from everything stored below.
        keep = golds != self.mask_id
        self.pred_probs.append(unroll(unmask(probs, keep)))

        sent_classes = unmask(classes, keep)
        self._pred_classes_sent.append(sent_classes)
        self.pred_classes.append(unroll(sent_classes))

        sent_golds = unmask(golds, keep)
        self._golds_sent.append(sent_golds)
        self.golds.append(unroll(sent_golds))
예제 #3
0
    def calc(self, current_epoch, words):
        """Update the best-so-far records with the metrics of the current epoch.

        The special tokens PAD, START and STOP are removed from the unrolled
        ``words`` before the word-dependent accuracies are requested.  The
        best loss is replaced when the current loss is strictly lower; each
        best accuracy is replaced when the current accuracy is strictly
        higher.  A replaced record stores the new value and ``current_epoch``.
        """
        ignored = (constants.PAD, constants.START, constants.STOP)
        words = [token for token in unroll(words) if token not in ignored]

        # Every metric is read before any record is modified.
        current_loss = self.get_loss()
        accuracies = (
            self.get_acc(),
            self.get_acc_oov(words),
            self.get_acc_emb(words),
            self.get_acc_sentence(),
        )

        # Loss is the only metric where lower is better.
        if current_loss < self.best_loss.value:
            self.best_loss.value = current_loss
            self.best_loss.epoch = current_epoch

        records = (
            self.best_acc,
            self.best_acc_oov,
            self.best_acc_emb,
            self.best_acc_sent,
        )
        for record, accuracy in zip(records, accuracies):
            if accuracy > record.value:
                record.value = accuracy
                record.epoch = current_epoch
예제 #4
0
 def get_acc_emb(self, words=None):
     """Return the accuracy over words missing from the embedding vocabulary.

     The value is computed once and cached in ``self.acc_emb``; later calls
     return the cached value and ignore ``words``.  A word is selected when
     it is not in ``self.emb_vocab`` and differs from ``constants.UNK``.
     When no word is selected the accuracy is set to 1.0.
     """
     if self.acc_emb is not None:
         return self.acc_emb

     selected = []
     for position, word in enumerate(unroll(words)):
         if word in self.emb_vocab:
             continue
         if word != constants.UNK:
             selected.append(position)

     if selected:
         # Mean of the prediction == gold comparison at the selected positions.
         self.acc_emb = self._get_bins()[selected].mean()
     else:
         self.acc_emb = 1.0
     return self.acc_emb
예제 #5
0
 def get_acc_oov(self, words=None):
     """Return the accuracy over out-of-vocabulary words.

     The value is computed once and cached in ``self.acc_oov``; later calls
     return the cached value and ignore ``words``.  A word is selected when
     it is not in ``self.train_vocab`` or equals ``constants.UNK``.  When no
     word is selected the accuracy is set to 1.0.
     """
     if self.acc_oov is not None:
         return self.acc_oov

     selected = []
     for position, word in enumerate(unroll(words)):
         if word not in self.train_vocab or word == constants.UNK:
             selected.append(position)

     if selected:
         # Mean of the prediction == gold comparison at the selected positions.
         self.acc_oov = self._get_bins()[selected].mean()
     else:
         self.acc_oov = 1.0
     return self.acc_oov
예제 #6
0
 def _get_bins(self):
     """Compare the stored predicted classes with the stored gold labels.

     ``self.pred_classes`` and ``self.golds`` are each passed through
     ``unroll`` and converted to a NumPy array the first time they are
     needed; the arrays are cached in ``self._flattened_preds`` and
     ``self._flattened_golds``.  Returns the result of ``preds == golds``.
     """
     preds = self._flattened_preds
     if preds is None:
         preds = self._flattened_preds = np.array(unroll(self.pred_classes))

     golds = self._flattened_golds
     if golds is None:
         golds = self._flattened_golds = np.array(unroll(self.golds))

     return preds == golds