Example #1
def metrics(logits, targets):
    preds = torch.argmax(logits, dim=1)
    cm = ConfusionMatrix()(preds, targets)
    if len(cm.size()) == 0:
        idx = preds[0].item()
        n = cm.item()
        cm = torch.zeros((2, 2))
        cm[idx, idx] = n
    # cm_{i,j} is the number of observations in group i that were predicted in group j
    tp, tn, fn, fp = cm[1, 1], cm[0, 0], cm[1, 0], cm[0, 1]
    metrics = {'tp': tp, 'fp': fp, 'fn': fn, 'tn': tn}
    return metrics
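
A minimal usage sketch for the helper above, assuming the older pl.metrics-style ConfusionMatrix that infers the number of classes from the data; the logits and targets are made-up values:

import torch

logits = torch.tensor([[0.2, 0.8],   # predicted class 1
                       [0.9, 0.1],   # predicted class 0
                       [0.3, 0.7]])  # predicted class 1
targets = torch.tensor([1, 0, 0])
counts = metrics(logits, targets)    # e.g. {'tp': 1, 'fp': 1, 'fn': 0, 'tn': 1}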
Example #2
    def print_pycm(self):
        cm = ConfusionMatrix(self.gts, self.preds)

        for cls_name in cm.classes:
            print('============' * 5)
            print('Class Name : [{}]'.format(cls_name))  # treat this class as the positive class and compute TP, FN, FP, TN
            TP = cm.TP[cls_name]
            TN = cm.TN[cls_name]
            FP = cm.FP[cls_name]
            FN = cm.FN[cls_name]
            acc = cm.ACC[cls_name]
            pre = cm.PPV[cls_name]
            rec = cm.TPR[cls_name]
            spec = cm.TNR[cls_name]

            if acc == 'None':
                acc = 0.0
            if pre == 'None':
                pre = 0.0
            if rec == 'None':
                rec = 0.0
            if spec == 'None':
                spec = 0.0

            print('TP : {}, FN : {}, FP : {}, TN : {}'.format(TP, FN, FP, TN))
            print('Accuracy : {:.4f}, Precision : {:.4f}, Recall(Sensitivity) : {:.4f}, Specificity : {:.4f}'.
                format(acc, pre, rec, spec))
            print('============' * 5)
        cm.print_matrix()
        auc_list = list(cm.AUC.values())
        print('AUROC : ', auc_list)
        
        auroc_mean = 0
        for auc in auc_list:
            if auc == 'None':
                auroc_mean += 0
            else:
                auroc_mean += auc
        auroc_mean = auroc_mean / len(auc_list)
        print("AUROC mean: {:.4f}".format(auroc_mean))
        
        self.gts = []
        self.preds = []
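
A hedged sketch of how self.gts and self.preds might be accumulated before print_pycm() is called; pycm's ConfusionMatrix(actual_vector, predict_vector) constructor is its documented API, while the stand-in batches below are purely illustrative:

import torch
from pycm import ConfusionMatrix

gts, preds = [], []
batches = [(torch.randn(8, 3), torch.randint(0, 3, (8,)))]  # stand-in (logits, targets) batches
for logits, targets in batches:
    preds.extend(torch.argmax(logits, dim=1).tolist())
    gts.extend(targets.tolist())

cm = ConfusionMatrix(gts, preds)  # same call as in the example above
cm.print_matrix()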
Example #3
def test_v1_5_metric_classif_mix():
    ConfusionMatrix.__init__._warned = False
    with pytest.deprecated_call(match="It will be removed in v1.5.0"):
        ConfusionMatrix(num_classes=1)

    FBeta.__init__._warned = False
    with pytest.deprecated_call(match="It will be removed in v1.5.0"):
        FBeta(num_classes=1)

    F1.__init__._warned = False
    with pytest.deprecated_call(match="It will be removed in v1.5.0"):
        F1(num_classes=1)

    HammingDistance.__init__._warned = False
    with pytest.deprecated_call(match="It will be removed in v1.5.0"):
        HammingDistance()

    StatScores.__init__._warned = False
    with pytest.deprecated_call(match="It will be removed in v1.5.0"):
        StatScores()

    target = torch.tensor([1, 1, 0, 0])
    preds = torch.tensor([0, 1, 0, 0])
    confusion_matrix._warned = False
    with pytest.deprecated_call(match="It will be removed in v1.5.0"):
        assert torch.equal(
            confusion_matrix(preds, target, num_classes=2).float(),
            torch.tensor([[2.0, 0.0], [1.0, 1.0]]))

    target = torch.tensor([0, 1, 2, 0, 1, 2])
    preds = torch.tensor([0, 2, 1, 0, 0, 1])
    fbeta._warned = False
    with pytest.deprecated_call(match="It will be removed in v1.5.0"):
        assert torch.allclose(fbeta(preds, target, num_classes=3, beta=0.5),
                              torch.tensor(0.3333),
                              atol=1e-4)

    f1._warned = False
    with pytest.deprecated_call(match="It will be removed in v1.5.0"):
        assert torch.allclose(f1(preds, target, num_classes=3),
                              torch.tensor(0.3333),
                              atol=1e-4)

    target = torch.tensor([[0, 1], [1, 1]])
    preds = torch.tensor([[0, 1], [0, 1]])
    hamming_distance._warned = False
    with pytest.deprecated_call(match="It will be removed in v1.5.0"):
        assert hamming_distance(preds, target) == torch.tensor(0.25)

    preds = torch.tensor([1, 0, 2, 1])
    target = torch.tensor([1, 1, 2, 0])
    stat_scores._warned = False
    with pytest.deprecated_call(match="It will be removed in v1.5.0"):
        assert torch.equal(stat_scores(preds, target, reduce="micro"),
                           torch.tensor([2, 2, 6, 2, 4]))
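
For reference, a hedged sketch of the non-deprecated equivalent of the confusion-matrix assertion above, using the class-based API of recent torchmetrics releases (names as of torchmetrics 0.11+, which is an assumption about the installed version):

import torch
from torchmetrics.classification import MulticlassConfusionMatrix

target = torch.tensor([1, 1, 0, 0])
preds = torch.tensor([0, 1, 0, 0])
confmat = MulticlassConfusionMatrix(num_classes=2)
print(confmat(preds, target))  # expected tensor([[2, 0], [1, 1]])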
Example #4
    def __init__(self, lr: float, num_classes: int, *args, **kwargs):
        super().__init__()
        self.lr = lr
        self.num_classes = num_classes
        self.train_acc = pl.metrics.Accuracy()
        self.test_acc = pl.metrics.Accuracy()
        self.val_acc = pl.metrics.Accuracy()
        self.confusion = ConfusionMatrix(self.num_classes)

        self.init_layers(*args, **kwargs)
        self.save_hyperparameters()
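
A hedged sketch of a validation_step that would feed the metrics created above; the batch layout, the self(x) forward call, and the logged names are assumptions, not part of the original snippet:

    def validation_step(self, batch, batch_idx):
        # illustrative only: assumes (inputs, integer labels) batches
        x, y = batch
        logits = self(x)
        loss = torch.nn.functional.cross_entropy(logits, y)
        preds = torch.argmax(logits, dim=1)
        self.val_acc(preds, y)
        self.confusion(preds, y)  # accumulates the running confusion matrix
        self.log("val_loss", loss)
        return loss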
Example #5
    class __GlobalConfusionMatrix:
        def __init__(self):
            self.confusion_matrix = ConfusionMatrix(7)
            self.has_data = False
            self.is_logging = False

        def enable_logging(self):
            self.is_logging = True

        def disable_logging(self):
            self.is_logging = False

        def update(self, predictions: Tensor, targets: Tensor):
            if not self.is_logging:
                return
            self.has_data = True
            preds, y = predictions.cpu(), targets.cpu()
            self.confusion_matrix.update(preds, y)

        def compute(self):
            self.has_data = False
            return self.confusion_matrix.compute()
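
A hedged usage sketch; the 7-class ConfusionMatrix comes from the snippet, but instantiating the private class directly (rather than through whatever singleton accessor normally wraps it) is only for illustration:

import torch

gcm = __GlobalConfusionMatrix()        # normally reached via an enclosing accessor, not directly
gcm.enable_logging()
preds = torch.randint(0, 7, (32,))     # stand-in predictions for the 7 classes
targets = torch.randint(0, 7, (32,))
gcm.update(preds, targets)
if gcm.has_data:
    print(gcm.compute())               # 7x7 count matrix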
Example #6
def accuracy_test(classifiers: T.Dict[str, T.Callable[[int],
                                                      FewshotClassifier]],
                  collections: T.Dict[str, T.Dict[str, U.data.Dataset]],
                  as_bit=True,
                  as_half=False,
                  as_cuda=True,
                  n_support=10):

    cast_tensor = (lambda t: t.cuda()) if as_cuda else (lambda t: t)
    cast_integer = ((lambda t: torch.tensor(t, device="cuda"))
                    if as_cuda else (lambda t: t))

    for classifier_name, classifier_constructor in classifiers.items():
        classifier = classifier_constructor(416 if as_bit else 52)

        if as_half:
            classifier.half()
        if as_cuda:
            classifier.cuda()

        for collection_name, datasets in collections.items():

            confusion_matrix = ConfusionMatrix(num_classes=len(datasets))

            key_list = list(datasets.keys())
            value_list = [datasets[k] for k in key_list]

            queries_ds = U.data.dmap(sum(value_list[1:], value_list[0]),
                                     transform=cast_tensor)
            labels_ds = U.data.dmap(
                [x for i, ds in enumerate(value_list) for x in [i] * len(ds)],
                transform=cast_integer)

            support_ds_list = []
            for dataset in datasets:
                support_ds = U.data.dconst()
Example #7
class PrototypicalCnnLstmNet(pl.LightningModule):
    def __init__(self,in_channel_cnn,out_feature_dim_cnn,out_feature_dim_lstm=52,num_lstm_layer=1,
                 metric_method="Euclidean",k_shot=1,num_class_linear_flag=None,combine=False):
        """
        :param in_channel_cnn: the input channel of CNN
        :param out_feature_dim_cnn: the output feature dimension of CNN
        :param out_feature_dim_lstm: the output feature dimension of LSTM
        :param num_lstm_layer: the number of LSTM layers
        :param metric_method: the similarity metric: Euclidean or cosine
        :param k_shot: the number of samples per class in the support set
        :param num_class_linear_flag: the number of classes for the linear classifier; if set, the linear dual path is used
        :param combine: whether to combine the results of the two paths
        """
        super().__init__()
        self.alpha = 0.02

        self.in_channel_cnn = in_channel_cnn
        self.out_feature_dim_cnn = out_feature_dim_cnn
        self.out_feature_dim_lstm = out_feature_dim_lstm
        self.num_lstm_layer = num_lstm_layer

        self.metric_method = metric_method
        self.k_shot = k_shot
        self.combine = combine  # whether to combine the two paths
        self.num_class_linear_flag = num_class_linear_flag  # only used when the linear classifier is added

        self.CNN_encoder = RawCSIEncoder(in_channel_cnn=self.in_channel_cnn,
                                         out_feature_dim_cnn=self.out_feature_dim_cnn)
        self.LSTM_encoder = SeqFeatureEncoder(in_feature_dim_lstm=self.out_feature_dim_cnn,
                                              out_feature_dim_lstm=self.out_feature_dim_lstm,
                                              num_lstm_layer=self.num_lstm_layer)

        if self.num_class_linear_flag is not None:
            self.linear_classifier = LinearClassifier(in_feature_dim_linear=self.out_feature_dim_lstm,num_class=self.num_class_linear_flag)
            self.train_acc_linear = pl.metrics.Accuracy()
            self.val_acc_linear = pl.metrics.Accuracy()

        self.similarity_metric = similarity.Pair_metric(metric_method=self.metric_method)
        self.similarity = similarity.Pair_metric(metric_method="cosine")
        # for calculating the cosine distance between support set feature and linear layer weights W

        self.criterion = nn.CrossEntropyLoss(reduction='sum')
        self.train_acc = pl.metrics.Accuracy()  # training accuracy of the metric classifier
        self.val_acc = pl.metrics.Accuracy()  # validation accuracy of the metric classifier

        self.confmat_linear_all = []  # stores the confusion matrices of the linear classifier
        self.comfmat_metric_all = []  # stores the confusion matrices of the metric classifier

    def training_step(self, batch, batch_idx):
        batch = batch[0]
        query_dataset, support_set = batch

        query_data, query_activity_label, = query_dataset   # the data list: [num_sample1,(in_channel,time).tensor]
        qu_activity_label = torch.stack(query_activity_label)

        support_data, support_activity_label, = support_set  # the data list:[num_sample2,(in_channel,time).tensor]
        su_activity_label = torch.stack(support_activity_label)

        # extracting the features
        qu_data = torch.cat(query_data)  # [batch*time,in_channel,length,width]
        su_data = torch.cat(support_data)  # [batch*time,in_channel,length,width]

        qu_feature_out = self.CNN_encoder(qu_data)
        qu_seq_in = torch.split(qu_feature_out, [len(x) for x in query_data])  # [batch,time,feature_dim]
        qu_feature = self.LSTM_encoder(qu_seq_in)

        su_feature_out = self.CNN_encoder(su_data)
        su_seq_in = torch.split(su_feature_out, [len(x) for x in support_data])  # [batch,time,feature_dim]
        su_feature = self.LSTM_encoder(su_seq_in)

        # if num_class_linear_flag is not None, we are using the dual path.
        if self.num_class_linear_flag is not None:
            pre_gesture_linear_qu = self.linear_classifier(qu_feature)
            pre_gesture_linear_su = self.linear_classifier(su_feature)
            pre_gesture_linear = torch.cat([pre_gesture_linear_su,pre_gesture_linear_qu])
            gesture_label = torch.cat([su_activity_label , qu_activity_label])
            linear_classifier_loss = self.criterion(pre_gesture_linear,gesture_label.long().squeeze())
            self.log("GesTr_loss_linear", linear_classifier_loss)
            self.train_acc_linear(pre_gesture_linear, gesture_label.long().squeeze())
        else:
            linear_classifier_loss = 0

        # for few-shot, we use the average of all support-set sample features as the final feature.
        if self.k_shot != 1:
            su_feature_temp1 = su_feature.reshape(-1,self.k_shot,su_feature.size()[1])
            su_feature_k_shot = su_feature_temp1.mean(1,keepdim=False)
        else:
            su_feature_k_shot = su_feature

        # combine the dual path knowledge
        if self.combine:
            # su_feature_final = torch.ones_like(su_feature_k_shot)
            # w = self.linear_classifier.decoder.weight.detach()
            # cosine_distance = self.similarity(su_feature_k_shot, w)
            # for i in range(len(w)):
            #     cosine = cosine_distance[i][i]
            #     t = 0
            #     for j in range(len(w)):
            #         t += cosine_distance[i][j]
            #     if cosine / t > self.alpha:
            #         replace_support = cosine * w[i] * (torch.norm(su_feature_k_shot[i], p=2) / torch.norm(w[i], p=2))
            #         su_feature_final[i] = replace_support

            su_feature_final = su_feature_k_shot
            w = self.linear_classifier.decoder.weight
            cosine_distance = self.similarity(w, w)
            zero = torch.zeros_like(cosine_distance)
            constraint_element1 = torch.where(cosine_distance < self.alpha, zero, cosine_distance)
            constraint_element = torch.where(constraint_element1 == 1, zero,
                                             constraint_element1)
            loss_orthogonal_constraint = constraint_element.sum() / 2
            linear_classifier_loss += loss_orthogonal_constraint
        else:
            su_feature_final = su_feature_k_shot

        predict_label = self.similarity_metric(qu_feature,su_feature_final)
        loss = self.criterion(predict_label, qu_activity_label.long().squeeze())
        self.log("GesTr_loss", loss)
        self.train_acc(predict_label, qu_activity_label.long().squeeze())

        loss += linear_classifier_loss
        return loss

    def validation_step(self,  batch, batch_idx):
        batch = batch[0]
        query_dataset, support_set = batch

        query_data, query_activity_label, = query_dataset  # the data list: [num_sample1,(in_channel,time).tensor]
        qu_activity_label = torch.stack(query_activity_label)

        support_data, support_activity_label, = support_set  # the data list:[num_sample2,(in_channel,time).tensor]
        su_activity_label = torch.stack(support_activity_label)

        # extracting the features
        qu_data = torch.cat(query_data)  # [batch*time,in_channel,length,width]
        su_data = torch.cat(support_data)  # [batch*time,in_channel,length,width]

        qu_feature_out = self.CNN_encoder(qu_data)
        qu_seq_in = torch.split(qu_feature_out, [len(x) for x in query_data])  # [batch,time,feature_dim]
        qu_feature = self.LSTM_encoder(qu_seq_in)

        su_feature_out = self.CNN_encoder(su_data)
        su_seq_in = torch.split(su_feature_out, [len(x) for x in support_data])  # [batch,time,feature_dim]
        su_feature = self.LSTM_encoder(su_seq_in)

        # if num_class_linear_flag is not None, we are using the dual path.
        if self.num_class_linear_flag is not None:
            pre_gesture_linear_qu = self.linear_classifier(qu_feature)
            pre_gesture_linear_su = self.linear_classifier(su_feature)
            pre_gesture_linear = torch.cat([pre_gesture_linear_su, pre_gesture_linear_qu])
            gesture_label = torch.cat([su_activity_label, qu_activity_label])
            linear_classifier_loss = self.criterion(pre_gesture_linear, gesture_label.long().squeeze())
            self.log("GesVa_loss_linear", linear_classifier_loss)
            self.val_acc_linear(pre_gesture_linear, gesture_label.long().squeeze())
            self.confmat_linear.update(pre_gesture_linear.cpu(), gesture_label.long().squeeze().cpu())
        else:
            linear_classifier_loss = 0

        # for few-shot, we use the average of all support-set sample features as the final feature.
        if self.k_shot != 1:
            su_feature_temp1 = su_feature.reshape(-1, self.k_shot, su_feature.size()[1])
            su_feature_k_shot = su_feature_temp1.mean(1, keepdim=False)
        else:
            su_feature_k_shot = su_feature

        # combine the dual path knowledge or add the orthogonal constraint
        if self.combine:
            # su_feature_final = torch.ones_like(su_feature_k_shot)
            # w = self.linear_classifier.decoder.weight.detach()
            # cosine_distance = self.similarity(su_feature_k_shot, w)
            # for i in range(len(w)):
            #     cosine = cosine_distance[i][i]
            #     t = 0
            #     for j in range(len(w)):
            #         t += cosine_distance[i][j]
            #     if cosine / t > self.alpha:
            #         replace_support = cosine * w[i] * (torch.norm(su_feature_k_shot[i], p=2) / torch.norm(w[i], p=2))
            #         su_feature_final[i] = replace_support

            su_feature_final = su_feature_k_shot
            w = self.linear_classifier.decoder.weight
            cosine_distance = self.similarity(w, w)
            zero = torch.zeros_like(cosine_distance)
            # constraint_element = torch.where((cosine_distance < self.alpha) or (cosine_distance == 1), zero, cosine_distance)
            # loss_orthogonal_constraint = constraint_element.sum() / 2
            constraint_element1 = torch.where(cosine_distance < self.alpha, zero, cosine_distance)
            constraint_element = torch.where(constraint_element1 == 1, zero,
                                             constraint_element1)
            loss_orthogonal_constraint = constraint_element.sum() / 2
            linear_classifier_loss += loss_orthogonal_constraint
        else:
            su_feature_final = su_feature_k_shot

        predict_label = self.similarity_metric(qu_feature, su_feature_final)
        loss = self.criterion(predict_label, qu_activity_label.long().squeeze())
        self.log("GesVa_loss", loss)
        self.val_acc(predict_label, qu_activity_label.long().squeeze())
        self.confmat_metric.update(predict_label.cpu(), qu_activity_label.long().squeeze().cpu())

        loss += linear_classifier_loss
        return loss

    def on_validation_epoch_start(self):
        self.confmat_metric = ConfusionMatrix(num_classes=6)
        if self.num_class_linear_flag is not None:
            self.confmat_linear = ConfusionMatrix(num_classes=6)

    def validation_epoch_end(self, val_step_outputs):
        self.log('GesVa_Acc', self.val_acc.compute())
        self.comfmat_metric_all.append(self.confmat_metric.compute())

        if self.num_class_linear_flag is not None:
            self.log('GesVa_Acc_linear', self.val_acc_linear.compute())
            self.confmat_linear_all.append(self.confmat_linear.compute())

    def training_epoch_end(self, training_step_outputs):
        self.log('GesTr_Acc', self.train_acc.compute())
        if self.num_class_linear_flag is not None:
            self.log('train_acc_linear', self.train_acc_linear.compute())

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=0.0005)
        scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
                                                         milestones=[80, 160, 240, 320,400],
                                                         gamma=0.5)
        return [optimizer, ], [scheduler, ]
Example #8
 def on_validation_epoch_start(self):
     self.confmat_metric = ConfusionMatrix(num_classes=6)
     if self.num_class_linear_flag is not None:
         self.confmat_linear = ConfusionMatrix(num_classes=6)
Example #9
class PrototypicalResNet(pl.LightningModule):
    def __init__(self,
                 layers,
                 strides,
                 inchannel=52,
                 groups=1,
                 align=False,
                 metric_method="Euclidean",
                 k_shot=1,
                 num_class_linear_flag=None,
                 combine=False):
        """
        :param layers: a list defining the number of blocks in each layer of the ResNet encoder
        :param strides: the convolution strides of the layers
        :param inchannel: the number of input channels
        :param groups: the number of convolution groups
        :param align: whether the input series all have the same length
        :param metric_method: the similarity metric: Euclidean or cosine
        :param k_shot: the number of samples per class in the support set
        :param num_class_linear_flag: the number of classes for the linear classifier; if set, the linear dual path is used
        :param combine: whether to combine the results of the two paths
        """
        super().__init__()
        self.alpha = 0.02

        self.layers = layers
        self.strides = strides
        self.inchannel = inchannel
        self.groups = groups
        self.align = align

        self.metric_method = metric_method
        self.k_shot = k_shot
        self.combine = combine  # whether to combine the two paths
        self.num_class_linear_flag = num_class_linear_flag  # only used when the linear classifier is added

        self.ResNet_encoder = ResNet_CSI(
            block=BasicBlock,
            layers=self.layers,
            strides=self.strides,
            inchannel=self.inchannel,
            groups=self.groups)  # output shape [feature_dim, length]
        self.feature_dim = self.ResNet_encoder.out_dim

        if self.num_class_linear_flag is not None:
            self.linear_classifier = LinearClassifier(
                in_channel=self.feature_dim,
                num_class=self.num_class_linear_flag)
            self.train_acc_linear = pl.metrics.Accuracy()
            self.val_acc_linear = pl.metrics.Accuracy()

        self.similarity_metric = similarity.Pair_metric(
            metric_method=self.metric_method, inchannel=self.feature_dim * 2)
        self.similarity = similarity.Pair_metric(metric_method="cosine")
        # for calculating the cosine distance between support set feature and linear layer weights W

        self.criterion = nn.CrossEntropyLoss(reduction='sum')
        self.train_acc = pl.metrics.Accuracy()  # training accuracy of the metric classifier
        self.val_acc = pl.metrics.Accuracy()  # validation accuracy of the metric classifier

        self.confmat_linear_all = []  # stores the confusion matrices of the linear classifier
        self.comfmat_metric_all = []  # stores the confusion matrices of the metric classifier

    def training_step(self, batch, batch_idx):
        batch = batch[0]
        query_dataset, support_set = batch

        query_data, query_activity_label, = query_dataset  # the data list: [num_sample1,(in_channel,time).tensor]
        qu_activity_label = torch.stack(query_activity_label)

        support_data, support_activity_label, = support_set  # the data list:[num_sample2,(in_channel,time).tensor]
        su_activity_label = torch.stack(support_activity_label)

        # extracting the features
        if self.align:
            qu_data = torch.stack(query_data)  # [num_sample1,in_channel,time]
            su_data = torch.stack(
                support_data)  # [num_sample2,in_channel,time]
            qu_feature = self.ResNet_encoder(qu_data)
            su_feature = self.ResNet_encoder(su_data)
        else:
            qu_data = custom_stack(query_data,
                                   time_dim=1)  # [num_sample1,in_channel,time]
            su_data = custom_stack(support_data,
                                   time_dim=1)  # [num_sample2,in_channel,time]
            qu_feature = self.ResNet_encoder(qu_data)
            su_feature = self.ResNet_encoder(su_data)

            # qu_feature_temp = []
            # su_feature_temp = []
            # for i,x in enumerate(query_data):
            #     feature = self.ResNet_encoder(x)
            #     qu_feature_temp.append(feature)
            # for j,x in enumerate(support_data):
            #     feature = self.ResNet_encoder(x)
            #     su_feature_temp.append(feature)
            # qu_feature = torch.stack(qu_feature_temp)  # [num_sample1,out_channel,length]
            # su_feature = torch.stack(su_feature_temp)  # [num_sample2,out_channel,length]

        # if num_class_linear_flag is not None, we are using the dual path.
        if self.num_class_linear_flag is not None:
            pre_gesture_linear_qu = self.linear_classifier(qu_feature)
            pre_gesture_linear_su = self.linear_classifier(su_feature)
            pre_gesture_linear = torch.cat(
                [pre_gesture_linear_su, pre_gesture_linear_qu])
            gesture_label = torch.cat([su_activity_label, qu_activity_label])
            linear_classifier_loss = self.criterion(
                pre_gesture_linear,
                gesture_label.long().squeeze())
            self.log("GesTr_loss_linear", linear_classifier_loss)
            self.train_acc_linear(pre_gesture_linear,
                                  gesture_label.long().squeeze())
        else:
            linear_classifier_loss = 0

        # for few-shot, we use the average of all support-set sample features as the final feature.
        if self.k_shot != 1:
            su_feature_temp1 = su_feature.reshape(-1, self.k_shot,
                                                  su_feature.size()[1],
                                                  su_feature.size()[2])
            su_feature_k_shot = su_feature_temp1.mean(1, keepdim=False)
        else:
            su_feature_k_shot = su_feature

        # combine the dual path knowledge
        if self.combine:
            su_feature_final = su_feature_k_shot
            w = self.linear_classifier.decoder.weight
            cosine_distance = self.similarity(w, w)
            zero = torch.zeros_like(cosine_distance)
            # constraint_element = torch.where((cosine_distance < self.alpha) or (cosine_distance == 1), zero, cosine_distance)
            constraint_element1 = torch.where(cosine_distance < self.alpha,
                                              zero, cosine_distance)
            constraint_element = torch.where(constraint_element1 == 1, zero,
                                             constraint_element1)
            loss_orthogonal_constraint = constraint_element.sum() / 2
            linear_classifier_loss += loss_orthogonal_constraint
        else:
            su_feature_final = su_feature_k_shot

        predict_label = self.similarity_metric(qu_feature, su_feature_final)
        loss = self.criterion(predict_label,
                              qu_activity_label.long().squeeze())
        self.log("GesTr_loss", loss)
        self.train_acc(predict_label, qu_activity_label.long().squeeze())

        loss += linear_classifier_loss
        return loss

    def validation_step(self, batch, batch_idx):
        batch = batch[0]
        query_dataset, support_set = batch

        query_data, query_activity_label, = query_dataset  # the data list: [num_sample1,(in_channel,time).tensor]
        qu_activity_label = torch.stack(query_activity_label)

        support_data, support_activity_label, = support_set  # the data list:[num_sample2,(in_channel,time).tensor]
        su_activity_label = torch.stack(support_activity_label)

        # extracting the features
        if self.align:
            qu_data = torch.stack(
                query_data)  # [num_sample1,in_channel,time]
            su_data = torch.stack(
                support_data)  # [num_sample2,in_channel,time]
            qu_feature = self.ResNet_encoder(qu_data)
            su_feature = self.ResNet_encoder(su_data)
        else:
            qu_data = custom_stack(query_data,
                                   time_dim=1)  # [num_sample1,in_channel,time]
            su_data = custom_stack(support_data,
                                   time_dim=1)  # [num_sample2,in_channel,time]
            qu_feature = self.ResNet_encoder(qu_data)
            su_feature = self.ResNet_encoder(su_data)
            # qu_feature_temp = []
            # su_feature_temp = []
            # for i, x in enumerate(query_data):
            #     feature = self.ResNet_encoder(x)
            #     qu_feature_temp.append(feature)
            # for j, x in enumerate(support_data):
            #     feature = self.ResNet_encoder(x)
            #     su_feature_temp.append(feature)
            # qu_feature = torch.stack(qu_feature_temp)
            # su_feature = torch.stack(su_feature_temp)

        # if num_class_linear_flag is not None, we are using the dual path.
        if self.num_class_linear_flag is not None:
            pre_gesture_linear_qu = self.linear_classifier(qu_feature)
            pre_gesture_linear_su = self.linear_classifier(su_feature)
            pre_gesture_linear = torch.cat(
                [pre_gesture_linear_su, pre_gesture_linear_qu])
            gesture_label = torch.cat([su_activity_label, qu_activity_label])
            linear_classifier_loss = self.criterion(
                pre_gesture_linear,
                gesture_label.long().squeeze())
            self.log("GesVa_loss_linear", linear_classifier_loss)
            self.val_acc_linear(pre_gesture_linear,
                                gesture_label.long().squeeze())
            self.confmat_linear.update(pre_gesture_linear.cpu(),
                                       gesture_label.long().squeeze().cpu())
        else:
            linear_classifier_loss = 0

        # for few-shot, we use the average of all support-set sample features as the final feature.
        if self.k_shot != 1:
            su_feature_temp1 = su_feature.reshape(-1, self.k_shot,
                                                  su_feature.size()[1],
                                                  su_feature.size()[2])
            su_feature_k_shot = su_feature_temp1.mean(1, keepdim=False)

        else:
            su_feature_k_shot = su_feature

        # combine the dual path knowledge or add the orthogonal constraint
        if self.combine:
            su_feature_final = su_feature_k_shot
            w = self.linear_classifier.decoder.weight
            cosine_distance = self.similarity(w, w)
            zero = torch.zeros_like(cosine_distance)
            # constraint_element = torch.where((cosine_distance < self.alpha) or (cosine_distance == 1), zero, cosine_distance)
            constraint_element1 = torch.where(cosine_distance < self.alpha,
                                              zero, cosine_distance)
            constraint_element = torch.where(constraint_element1 == 1, zero,
                                             constraint_element1)
            loss_orthogonal_constraint = constraint_element.sum() / 2
            linear_classifier_loss += loss_orthogonal_constraint
        else:
            su_feature_final = su_feature_k_shot

        predict_label = self.similarity_metric(qu_feature, su_feature_final)
        loss = self.criterion(predict_label,
                              qu_activity_label.long().squeeze())
        self.log("GesVa_loss", loss)
        self.val_acc(predict_label, qu_activity_label.long().squeeze())
        self.confmat_metric.update(predict_label.cpu(),
                                   qu_activity_label.long().squeeze().cpu())

        loss += linear_classifier_loss
        return loss

    def on_validation_epoch_start(self):
        self.confmat_metric = ConfusionMatrix(num_classes=6)
        if self.num_class_linear_flag is not None:
            self.confmat_linear = ConfusionMatrix(num_classes=6)

    def validation_epoch_end(self, val_step_outputs):
        self.log('GesVa_Acc', self.val_acc.compute())
        self.comfmat_metric_all.append(self.confmat_metric.compute())

        if self.num_class_linear_flag is not None:
            self.log('GesVa_Acc_linear', self.val_acc_linear.compute())
            self.confmat_linear_all.append(self.confmat_linear.compute())

    def training_epoch_end(self, training_step_outputs):
        self.log('GesTr_Acc', self.train_acc.compute())
        if self.num_class_linear_flag is not None:
            self.log('train_acc_linear', self.train_acc_linear.compute())

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=0.0005)
        scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer, milestones=[80, 160, 240, 320, 400], gamma=0.5)
        return [
            optimizer,
        ], [
            scheduler,
        ]
Example #10
 def __init__(self, config, trial=None):
     if hasattr(config.net_config, "use_detector_number"):
         self.use_detector_number = config.net_config.use_detector_number
         if self.use_detector_number:
             if not hasattr(config.net_config, "num_detectors"):
                 raise IOError(
                     "net config must contain 'num_detectors' property if 'use_detector_number' set to true"
                 )
             config.system_config.n_samples = config.system_config.n_samples + 3
             if config.net_config.num_detectors == 308:
                 self.detector_num_factor_x = 1. / 13
                 self.detector_num_factor_y = 1. / 10
             else:
                 raise IOError("num detectors " +
                               str(config.net_config.num_detector) +
                               " not supported")
     else:
         self.use_detector_number = False
     super(LitWaveform, self).__init__(config, trial)
     if config.net_config.net_class.endswith("RecurrentWaveformNet"):
         self.squeeze_index = 2
     else:
         self.squeeze_index = 1
     self.test_has_phys = False
     if hasattr(self.config.dataset_config, "test_dataset_params"):
         if self.config.dataset_config.test_dataset_params.label_name == "phys" and not hasattr(
                 self.config.dataset_config.test_dataset_params,
                 "label_index"):
             self.test_has_phys = True
     if hasattr(self.config.dataset_config, "calgroup"):
         calgroup = self.config.dataset_config.calgroup
     else:
         calgroup = None
     if hasattr(self.config.dataset_config.dataset_params, "label_index"):
         self.target_index = self.config.dataset_config.dataset_params.label_index
     else:
         self.target_index = None
     self.use_accuracy = False
     if config.net_config.criterion_class == "L1Loss":
         metric_name = "mean absolute error"
     elif config.net_config.criterion_class == "MSELoss":
         metric_name = "mean squared error"
     elif config.net_config.criterion_class.startswith(
             "BCE") or config.net_config.criterion_class.startswith(
                 "CrossEntropy"):
         self.use_accuracy = True
         metric_name = "Accuracy"
     else:
         metric_name = "?"
     eval_params = {}
     if hasattr(config, "evaluation_config"):
         eval_params = DictionaryUtility.to_dict(config.evaluation_config)
     self.evaluator = TensorEvaluator(self.logger,
                                      calgroup=calgroup,
                                      target_has_phys=self.test_has_phys,
                                      target_index=self.target_index,
                                      metric_name=metric_name,
                                      **eval_params)
     self.loss_no_reduce = self.criterion_class(
         *config.net_config.criterion_params, reduction="none")
     if self.use_accuracy:
         self.accuracy = Accuracy()
         self.confusion = ConfusionMatrix(2)
         self.softmax = Softmax(dim=1)
Example #11
 def __init__(self):
     self.confusion_matrix = ConfusionMatrix(7)
     self.has_data = False
     self.is_logging = False