コード例 #1
0
    def attach(self, engine, bundle):
        """Attach the evaluation metrics appropriate for the current mode.

        Calls the parent ``attach`` first, then registers metrics on
        ``engine`` depending on ``bundle[l.MODE]``:

        * ``l.TRAIN`` / ``l.VAL``: ``AveragePeriodicMetric`` instances for
          ``meu.REP``, ``meu.MS`` and ``meu.MMA``, logged every
          ``bundle.get(exp.METRIC_LOG_ITER)`` iterations; in ``l.VAL`` mode
          an ``meu.EMS`` metric is attached as well.
        * ``l.TEST``: ``DetailedMetric`` instances with one category per
          entry of ``bundle[exp.PX_THRESH]``. If ``du.MEGADEPTH`` is in the
          dataset name: REP, MS, MMA, EMS, REL_POSE and PARAM_REL_POSE.
          If ``du.HPATCHES_VIEW`` or ``du.HPATCHES_ILLUM`` is in the dataset
          name: REP, MS and MMA.

        Any other mode, or a test dataset matching none of the above,
        attaches nothing beyond what the parent class attaches.

        Args:
            engine: Engine the metrics are attached to.
            bundle: Mapping holding the run configuration (mode, pixel
                thresholds, dataset name, logging interval).
        """
        super().attach(engine, bundle)
        mode = bundle[l.MODE]

        if mode in (l.TRAIN, l.VAL):
            px_thresh = bundle[exp.PX_THRESH]
            metric_log_iter = bundle.get(exp.METRIC_LOG_ITER)

            AveragePeriodicMetric(RepTransformer(px_thresh),
                                  metric_log_iter).attach(engine, meu.REP)
            AveragePeriodicMetric(
                MSTransformer(px_thresh, DescriptorDistance.INV_COS_SIM),
                metric_log_iter).attach(engine, meu.MS)
            AveragePeriodicMetric(
                MMATransformer(px_thresh, DescriptorDistance.INV_COS_SIM),
                metric_log_iter).attach(engine, meu.MMA)

            if mode == l.VAL:
                # NOTE(review): unlike the metrics above, EMS is attached
                # without ``metric_log_iter`` -- confirm this is intentional.
                AveragePeriodicMetric(
                    EMSTransformer(px_thresh, self.device)).attach(
                        engine, meu.EMS)

        elif mode == l.TEST:
            dataset_name = bundle[du.DATASET_NAME]

            if du.MEGADEPTH in dataset_name:
                px_thresh = bundle[exp.PX_THRESH]
                num_cat = self._attach_detailed_matching_metrics(
                    engine, px_thresh)

                DetailedMetric(EMSTransformer(px_thresh, self.device, True),
                               num_cat).attach(engine, meu.EMS)
                DetailedMetric(PoseTransformer(px_thresh, self.device),
                               num_cat).attach(engine, meu.REL_POSE)
                DetailedMetric(ParamPoseTransformer(px_thresh, self.device),
                               num_cat).attach(engine, meu.PARAM_REL_POSE)

            elif (du.HPATCHES_VIEW in dataset_name
                  or du.HPATCHES_ILLUM in dataset_name):
                self._attach_detailed_matching_metrics(
                    engine, bundle[exp.PX_THRESH])

    def _attach_detailed_matching_metrics(self, engine, px_thresh):
        """Attach test-time REP, MS and MMA ``DetailedMetric``s to ``engine``.

        Shared by every test dataset branch of ``attach`` so the three
        metrics are always configured identically.

        Args:
            engine: Engine the metrics are attached to.
            px_thresh: Sized collection of pixel thresholds; its length is
                used as the number of categories of each metric.

        Returns:
            The number of categories (``len(px_thresh)``), so callers can
            size further ``DetailedMetric``s consistently.
        """
        num_cat = len(px_thresh)

        DetailedMetric(RepTransformer(px_thresh, True),
                       num_cat).attach(engine, meu.REP)
        DetailedMetric(
            MSTransformer(px_thresh, DescriptorDistance.INV_COS_SIM, True),
            num_cat).attach(engine, meu.MS)
        DetailedMetric(
            MMATransformer(px_thresh, DescriptorDistance.INV_COS_SIM, True),
            num_cat).attach(engine, meu.MMA)

        return num_cat
コード例 #2
0
    def attach(self, engine, bundle):
        """Attach the averaged ``LOSS`` metric, then delegate to each criterion.

        The logging interval comes from ``bundle.get(exp.LOSS_LOG_ITER)`` and
        is therefore ``None`` when the bundle has no such entry. Afterwards,
        every wrapper in ``self.criterion_wrappers`` receives the same
        ``engine`` and ``bundle`` so it can attach its own metrics.

        Args:
            engine: Engine the metrics are attached to.
            bundle: Mapping holding the run configuration.
        """
        log_every = bundle.get(exp.LOSS_LOG_ITER)
        loss_metric = AveragePeriodicMetric(KeyTransformer(LOSS), log_every)
        loss_metric.attach(engine, LOSS)

        for wrapper in self.criterion_wrappers:
            wrapper.attach(engine, bundle)
コード例 #3
0
 def attach(self, engine, bundle):
     """Attach a periodically averaged metric for ``cu.DET_CONF_LOSS``.

     The metric reads the ``cu.DET_CONF_LOSS`` entry via ``KeyTransformer``
     and is registered on ``engine`` under that same name. The logging
     interval comes from ``bundle.get(exp.LOSS_LOG_ITER)``, so it is ``None``
     when the bundle has no such entry.

     Args:
         engine: Engine the metric is attached to.
         bundle: Mapping holding the run configuration.
     """
     loss_log_iter = bundle.get(exp.LOSS_LOG_ITER)
     AveragePeriodicMetric(KeyTransformer(cu.DET_CONF_LOSS), loss_log_iter).attach(engine, cu.DET_CONF_LOSS)