Exemplo n.º 1
0
    def print_card(self, file):
        """Write this design variable as a single ``DESVAR`` input card line.

        The line consists of the ``DESVAR`` keyword padded to 8 columns, the
        integer ``id``, ``label``, the ``ff``-formatted ``xinit``, ``xlb`` and
        ``xub`` values, and finally ``delx`` and ``ddval`` converted with
        ``str``, terminated by a newline.

        Parameters
        ----------
        file : file
            File object with a :meth:`write` method.

        """
        fields = (
            'DESVAR'.ljust(8),
            self.id,
            self.label,
            ff(self.xinit),
            ff(self.xlb),
            ff(self.xub),
            str(self.delx),
            str(self.ddval),
        )
        card = '%s% 8d% 8s%s%s%s% 8s% 8s\n' % fields
        file.write(card)
Exemplo n.º 2
0
def evaluate(model, eval_dataset, eval_dataset_video, batch_size, criterion):
    """Evaluate ``model`` on the evaluation split and summarise per-head metrics.

    Only the video features are passed to ``model.predict``. Loss, accuracy,
    confusion-matrix counts, labels and predictions are accumulated for each
    of the ``nc`` output heads over every batch yielded by ``read_batches``.

    Parameters
    ----------
    model
        Object exposing ``eval()`` (its return value is used as the model)
        and ``predict(video_features)`` returning ``(out, pred_classes)``.
    eval_dataset, eval_dataset_video
        Datasets forwarded to ``read_batches`` together with the
        module-level ``device``.
    batch_size : int
        Batch size forwarded to ``read_batches``.
    criterion
        Loss callable applied as ``criterion(out, labels)``.

    Returns
    -------
    tuple
        ``(loss, bal_acc_dict, fscore_dict, acc_dict)``. Each dict maps the
        head index to a value formatted with ``ff``. ``loss`` and the
        accuracies are accumulated totals divided by the number of samples
        seen. The F-score is ``-1`` when nothing was predicted positive and
        ``0`` when there were positive predictions but no true positives.
    """
    model = model.eval()
    epoch_loss = 0
    nc = 1  # number of output heads evaluated
    epoch_acc_dict = {i: 0 for i in range(nc)}
    all_tp = {i: 0 for i in range(nc)}
    all_fp = {i: 0 for i in range(nc)}
    all_tn = {i: 0 for i in range(nc)}
    all_fn = {i: 0 for i in range(nc)}
    # NOTE(review): despite its name this counts samples (len of each batch),
    # not batches; loss and accuracy below are divided by it. Confirm intent.
    num_batches = 0
    all_labels = {i: [] for i in range(nc)}
    all_preds = {i: [] for i in range(nc)}
    for batch_features, video_batch_features, batch_labels in read_batches(
            eval_dataset, eval_dataset_video, batch_size, False, device,
            is_multi=False):
        batch_features = batch_features.permute((0, 2, 1))
        out, pred_classes = model.predict(video_batch_features)
        outs = [out]
        pred_classes = [pred_classes]
        losses = []
        for i, (out, pred_class) in enumerate(zip(outs, pred_classes)):
            labels_i = batch_labels[:, i:i + 1]  # column for head i
            losses.append(criterion(
                out,  # (batch_size , num_classes)
                labels_i.view(-1)  # (batch_size * 1)
            ))
            epoch_acc_dict[i] += categorical_accuracy(pred_class, labels_i)
            tn, fp, fn, tp = conf_matrix(pred_class, labels_i)
            all_tp[i] += tp
            all_tn[i] += tn
            all_fp[i] += fp
            all_fn[i] += fn
            all_labels[i] += list(labels_i.view(-1).cpu().numpy())
            all_preds[i] += list(pred_class)
        loss = sum(losses)
        epoch_loss += loss.item()
        num_batches += len(batch_features)
    bal_acc_dict = {}
    fscore_dict = {}
    for i in range(nc):
        bal_acc = balanced_accuracy_score(all_labels[i], all_preds[i])
        tp, fp, fn = all_tp[i], all_fp[i], all_fn[i]
        if tp + fp == 0:
            # Nothing predicted positive: precision is undefined; keep the
            # -1 sentinel (recall is not computed, so no 0/0 here either).
            fscore = -1
        elif tp == 0:
            # Precision is 0, so F1 is 0; this avoids 0/0 when recall is
            # also 0 or undefined (no actual positives).
            fscore = 0.0
        else:
            # tp > 0 makes both denominators and precision + recall non-zero.
            precision = tp / (tp + fp)
            recall = tp / (tp + fn)
            fscore = (2 * precision * recall) / (precision + recall)
        bal_acc_dict[i] = ff(bal_acc)
        fscore_dict[i] = ff(fscore)
        epoch_acc_dict[i] = ff(epoch_acc_dict[i] / num_batches)
    return epoch_loss / num_batches, bal_acc_dict, fscore_dict, epoch_acc_dict
Exemplo n.º 3
0
def evaluate(model, eval_dataset, eval_dataset_video, batch_size, criterion):
    """Evaluate ``model`` on the evaluation split (single output head).

    ``model.predict`` receives the video features and the permuted
    ``batch_features``. Loss, accuracy and confusion-matrix counts are
    accumulated over every batch yielded by ``read_batches``.

    Parameters
    ----------
    model
        Object exposing ``eval()`` (its return value is used as the model)
        and ``predict(video_features, features)`` returning
        ``(out, pred_classes)``.
    eval_dataset, eval_dataset_video
        Datasets forwarded to ``read_batches`` together with the
        module-level ``device``.
    batch_size : int
        Batch size forwarded to ``read_batches``.
    criterion
        Loss callable applied as ``criterion(out, batch_labels.view(-1))``.

    Returns
    -------
    tuple
        ``(loss, accuracy, balanced_accuracy, fscore)``. The last three are
        formatted with ``ff``. ``loss`` and ``accuracy`` are accumulated
        totals divided by the number of samples seen.
        ``balanced_accuracy`` averages the recall of each class that has at
        least one true sample (``-1`` if neither has any). ``fscore`` is
        ``-1`` when nothing was predicted positive and ``0`` when there were
        positive predictions but no true positives.
    """
    model = model.eval()
    epoch_loss = 0
    epoch_acc_dict = {i: 0 for i in range(1)}
    all_tp, all_fp, all_tn, all_fn = 0, 0, 0, 0
    # NOTE(review): despite its name this counts samples (len of each batch),
    # not batches; loss and accuracy below are divided by it. Confirm intent.
    num_batches = 0
    for batch_features, video_batch_features, batch_labels in read_batches(
            eval_dataset, eval_dataset_video, batch_size, False, device):
        batch_features = batch_features.permute((0, 2, 1))
        out, pred_classes = model.predict(video_batch_features, batch_features)
        outs = [out]
        losses = []
        for i, out in enumerate(outs):
            losses.append(
                criterion(
                    out,  # (batch_size , num_classes)
                    batch_labels.view(-1)  # (batch_size * 1)
                ))
            epoch_acc_dict[i] += categorical_accuracy(pred_classes,
                                                      batch_labels)
        loss = sum(losses)
        epoch_loss += loss.item()
        tn, fp, fn, tp = conf_matrix(pred_classes, batch_labels)
        all_tp += tp
        all_tn += tn
        all_fp += fp
        all_fn += fn
        num_batches += len(batch_features)
    epoch_acc_dict = {k: v / num_batches for k, v in epoch_acc_dict.items()}

    # Balanced accuracy = mean per-class recall. A class with no true samples
    # has an undefined recall and is skipped (as sklearn's
    # balanced_accuracy_score does) instead of dividing by zero.
    class_recalls = []
    if all_tn + all_fp > 0:
        class_recalls.append(all_tn / (all_tn + all_fp))
    if all_tp + all_fn > 0:
        class_recalls.append(all_tp / (all_tp + all_fn))
    bal_acc = sum(class_recalls) / len(class_recalls) if class_recalls else -1

    if all_tp + all_fp == 0:
        # Nothing predicted positive: precision is undefined; keep sentinel.
        fscore = -1
    elif all_tp == 0:
        # Precision is 0, so F1 is 0; avoids 0/0 when recall is 0/undefined.
        fscore = 0.0
    else:
        # tp > 0 makes both denominators and precision + recall non-zero.
        precision = all_tp / (all_tp + all_fp)
        recall = all_tp / (all_tp + all_fn)
        fscore = (2 * precision * recall) / (precision + recall)
    return epoch_loss / num_batches, ff(
        epoch_acc_dict[0]), ff(bal_acc), ff(fscore)
Exemplo n.º 4
0
    def print_card(self, file):
        """Write this linkage as a ``DLINK`` input card.

        The first line holds the ``DLINK`` keyword, ``id``, ``ddvid`` and the
        ``ff``-formatted ``c0`` and ``cmult`` (four data fields). Each
        ``(idvs[i], cs[i])`` pair takes two more fields. A line holds at most
        eight data fields, so the first line fits two pairs. Each
        continuation line, which starts with ``+``, fits four pairs.

        Parameters
        ----------
        file : file
            File object with a :meth:`write` method.

        """
        line = '%s% 8d% 8d%s%s' % ('DLINK'.ljust(8), self.id, self.ddvid,
                                   ff(self.c0), ff(self.cmult))
        used = 4  # data fields already occupied on the current line
        for idx, dvid in enumerate(self.idvs):
            if used + 2 == 10:
                # No room for another pair: flush and open a continuation.
                file.write(line + '\n')
                line = '+'.ljust(8)
                used = 0
            used += 2
            line += '% 8d%s' % (dvid, ff(self.cs[idx]))
        file.write(line + '\n')
Exemplo n.º 5
0
    def print_card(self, file):
        """Write this property relation as a ``DVPREL1`` input card.

        The first line holds the keyword, ``id``, ``type``, ``pid``,
        ``pname``, two blank fields and the ``ff``-formatted ``c0``. The
        ``(dvids[i], coeffs[i])`` pairs follow on continuation lines that
        start with ``+``, four pairs per line. At least one continuation
        line is always written, even when ``dvids`` is empty.

        Parameters
        ----------
        file : file
            File object with a :meth:`write` method.

        """
        header = '%s% 8d% 8s% 8d% 8s% 8s% 8s%s\n' % (
            'DVPREL1'.ljust(8), self.id, self.type, self.pid,
            self.pname, '', '', ff(self.c0))
        file.write(header)

        pairs_per_line = 4
        continuation = '+'.ljust(8)
        for idx, dvid in enumerate(self.dvids):
            if idx and idx % pairs_per_line == 0:
                # Current continuation line is full: flush and start anew.
                file.write(continuation + '\n')
                continuation = '+'.ljust(8)
            continuation += '% 8d%s' % (dvid, ff(self.coeffs[idx]))
        file.write(continuation + '\n')
Exemplo n.º 6
0
    def print_card(self, file):
        """Write the ``DTABLE`` input card built from ``input_dict``.

        Entries are emitted in sorted key order as ``str(key)`` /
        ``ff(value)`` field pairs. Four pairs go on the ``DTABLE`` line and
        four on each ``+`` continuation line. Nothing is written when
        ``input_dict`` is empty.

        Parameters
        ----------
        file : file
            File object with a :meth:`write` method.

        """
        if len(self.input_dict) == 0:
            return
        line = 'DTABLE'.ljust(8)
        for pos, key in enumerate(sorted(self.input_dict.keys())):
            value = self.input_dict[key]
            if pos and pos % 4 == 0:
                # Four pairs already on this line: flush and continue.
                file.write(line + '\n')
                line = '+'.ljust(8)
            line += '% 8s%s' % (str(key), ff(value))
        file.write(line + '\n')
Exemplo n.º 7
0
 def show_card(self):
     """Print this design variable as a one-line ``DESVAR`` card to stdout.

     The line holds the keyword padded to 8 columns, ``id``, ``label``, the
     ``ff``-formatted ``xinit``, ``xlb`` and ``xub``, then ``str(delx)`` and
     ``str(ddval)``. The format string already ends with a newline, so
     ``end=''`` stops ``print`` from adding a second one, which would leave
     an extra blank line after every card.
     """
     print('%s% 8d% 8s%s%s%s% 8s% 8s\n' % ('DESVAR'.ljust(8),
           self.id, self.label, ff(self.xinit), ff(self.xlb), ff(self.xub),
           str(self.delx), str(self.ddval)), end='')