def __init_meters__(self) -> List[Union[str, List[str]]]:
    """Create the segmentation meters (loss, slice/batch dice, lr) and
    return the column layout for the drawer."""
    # hoist the repeated arch_dict lookup; all dice meters share C and axes
    num_classes = self.model.arch_dict["num_classes"]
    meter_config = {
        "lr": AverageValueMeter(),
        "trloss": AverageValueMeter(),
        "trdice": SliceDiceMeter(C=num_classes, report_axises=self.axis),
        "valloss": AverageValueMeter(),
        "valdice": SliceDiceMeter(C=num_classes, report_axises=self.axis),
        "valbdice": BatchDiceMeter(C=num_classes, report_axises=self.axis),
    }
    self.METERINTERFACE = MeterInterface(meter_config)
    columns: List[Union[str, List[str]]] = [
        "trloss_mean",
        ["trdice_DSC1", "trdice_DSC2", "trdice_DSC3"],
        "valloss_mean",
        ["valdice_DSC1", "valdice_DSC2", "valdice_DSC3"],
        ["valbdice_DSC1", "valbdice_DSC2", "valbdice_DSC3"],
        "lr_mean",
    ]
    return columns
示例#2
0
 def __init_meters__(self) -> List[Union[str, List[str]]]:
     """Extend the parent meters with the UDA-related meters."""
     columns = super().__init_meters__()
     for meter_name in ("uda_reg", "entropy", "marginal"):
         self.METERINTERFACE.register_meter(meter_name, AverageValueMeter())
     # unlabeled-set accuracy is tracked but not drawn
     self.METERINTERFACE.register_meter("unl_acc", ConfusionMatrix(5))
     return columns + ["uda_reg_mean", "marginal_mean", "entropy_mean"]
 def __init_meters__(self) -> List[Union[str, List[str]]]:
     """Build the classification meters (losses + confusion matrices)
     and return the summary columns to draw."""
     n_cls = self.model.torchnet.num_classes
     self.METERINTERFACE = MeterInterface({
         "lr": AverageValueMeter(),
         "traloss": AverageValueMeter(),
         "traconf": ConfusionMatrix(n_cls),
         "valloss": AverageValueMeter(),
         "valconf": ConfusionMatrix(n_cls),
     })
     return ["traloss_mean", "traconf_acc", "valloss_mean", "valconf_acc", "lr_mean"]
示例#4
0
 def __init_meters__(self) -> List[Union[str, List[str]]]:
     """Extend the parent meters with SWA loss/accuracy meters.

     Bug fix: ``train_swa_acc`` was registered twice — first as an
     ``AverageValueMeter`` and then again as a ``ConfusionMatrix``.
     The drawn column ``train_swa_acc_acc`` requires a ConfusionMatrix
     (an ``_acc`` summary key), so it is now registered exactly once
     with the correct meter type.
     """
     _ = super().__init_meters__()
     self.METERINTERFACE.register_new_meter("train_swa_loss", AverageValueMeter())
     self.METERINTERFACE.register_new_meter("val_swa_loss", AverageValueMeter())
     # accuracy meters: the "_acc" columns come from ConfusionMatrix summaries
     self.METERINTERFACE.register_new_meter("train_swa_acc", ConfusionMatrix(10))
     self.METERINTERFACE.register_new_meter("val_swa_acc", ConfusionMatrix(10))
     return [
         ["train_loss_mean", "val_loss_mean"],
         ["train_swa_loss_mean", "val_swa_loss_mean"],
         ["train_acc_acc", "val_acc_acc"],
         ["train_swa_acc_acc", "val_swa_acc_acc"],
         "lr_value",
     ]
示例#5
0
 def __init_meters__(self) -> List[Union[str, List[str]]]:
     """Set up classification meters and the paired columns to plot."""
     self.METERINTERFACE = MeterInterface({  # type:ignore
         "lr": InstanceValue(),
         "train_loss": AverageValueMeter(),
         "val_loss": AverageValueMeter(),
         "train_acc": ConfusionMatrix(10),
         "val_acc": ConfusionMatrix(10),
     })
     columns: List[Union[str, List[str]]] = [
         ["train_loss_mean", "val_loss_mean"],
         ["train_acc_acc", "val_acc_acc"],
         "lr_value",
     ]
     return columns
示例#6
0
 def __init_meters__(self) -> List[str]:
     """Register adversarial-training meters and return their summary columns."""
     meter_config = {
         'tra_reg_total': AverageValueMeter(),
         'tra_sup_label': AverageValueMeter(),
         'tra_adv': AverageValueMeter(),
         'tra_entropy': AverageValueMeter(),
         'tra_conf': ConfusionMatrix(num_classes=10),
         'val_conf': ConfusionMatrix(num_classes=10),
     }
     self.METERINTERFACE = MeterInterface(meter_config)
     mean_columns = [
         f'{name}_mean'
         for name in ('tra_reg_total', 'tra_sup_label', 'tra_adv', 'tra_entropy')
     ]
     return mean_columns + ['tra_conf_acc', 'val_conf_acc']
示例#7
0
 def __init_meters__(self) -> List[Union[str, List[str]]]:
     """
     basic meters to record clustering results, specifically for multi-subheads.
     :return:
     """
     acc_names = ("val_average_acc", "val_best_acc", "val_worst_acc")
     self.METERINTERFACE = MeterInterface(
         {name: AverageValueMeter() for name in acc_names}
     )
     # all three accuracy curves share a single sub-plot
     return [[f"{name}_mean" for name in acc_names]]
示例#8
0
 def __init_meters__(self) -> List[str]:
     """Create MI/SAT training meters plus validation accuracy meters."""
     meter_names = ("train_mi", "train_sat", "val_best_acc", "val_average_acc")
     self.METERINTERFACE = MeterInterface(
         {name: AverageValueMeter() for name in meter_names}
     )
     return [
         "train_mi_mean",
         "train_sat_mean",
         "val_average_acc_mean",
         "val_best_acc_mean",
     ]
 def __init_meters__(self) -> List[Union[str, List[str]]]:
     """
     Initialize the meters by extending the father class with MI related meters.
     :return: [ "train_mi_mean", "train_entropy_mean", "train_centropy_mean", validation meters]
     """
     parent_columns = super().__init_meters__()
     for meter_name in ("train_mi", "train_entropy", "train_centropy"):
         self.METERINTERFACE.register_new_meter(meter_name, AverageValueMeter())
     # MI-related columns are drawn first, then the parent's columns
     return [
         "train_mi_mean",
         "train_entropy_mean",
         "train_centropy_mean",
     ] + parent_columns  # type: ignore
示例#10
0
 def setUp(self) -> None:
     """Build the small meter configuration shared by the tests."""
     super().setUp()
     config = {"avg1": AverageValueMeter()}
     for dice_name in ("dice1", "dice2"):
         config[dice_name] = SliceDiceMeter(C=2)
     self._meter_config = config
     self.meters = MeterInterface(self._meter_config)
 def test_meter_interface(self):
     """Meters can be registered mid-training and stepped once per epoch.

     Fix: the bare ``except:`` also swallowed ``KeyboardInterrupt`` and
     ``SystemExit``; narrowed to ``except Exception`` — the intent is only
     to tolerate "avg2" being absent before epoch 2.
     """
     meterinterface = MeterInterface(
         {"avg1": AverageValueMeter(), "dice1": SliceDiceMeter()}
     )
     print(meterinterface.summary())
     for epoch in range(10):
         if epoch == 2:
             # late registration: "avg2" only exists from epoch 2 onwards
             meterinterface.register_meter("avg2", AverageValueMeter())
         for i in range(10):
             meterinterface["avg1"].add(1)
             meterinterface["dice1"].add(
                 torch.randn(1, 4, 224, 224), torch.randint(0, 4, size=(1, 224, 224))
             )
             try:
                 meterinterface["avg2"].add(2)
             except Exception:  # "avg2" missing before epoch 2 — best-effort add
                 pass
         meterinterface.step()
     print(meterinterface.summary())
    def test_resume(self):
        """Resuming a MeterInterface from a ``state_dict`` into an interface
        with a superset of meters keeps summaries consistent.

        Fix: the bare ``except:`` also swallowed ``KeyboardInterrupt`` and
        ``SystemExit``; narrowed to ``except Exception`` — the intent is only
        to tolerate "avg2" being absent before epoch 2.
        """
        meterinterface = MeterInterface(
            {"avg1": AverageValueMeter(), "dice1": SliceDiceMeter()}
        )
        # simulate three already-elapsed epochs before training starts
        meterinterface.step()
        meterinterface.step()
        meterinterface.step()

        for epoch in range(10):
            if epoch == 2:
                # late registration: "avg2" only exists from epoch 2 onwards
                meterinterface.register_meter("avg2", AverageValueMeter())
            for i in range(10):
                meterinterface["avg1"].add(1)
                meterinterface["dice1"].add(
                    torch.randn(1, 4, 224, 224), torch.randint(0, 4, size=(1, 224, 224))
                )
                try:
                    meterinterface["avg2"].add(2)
                except Exception:  # "avg2" missing before epoch 2 — best-effort add
                    pass
            meterinterface.step()
        print(meterinterface.summary())
        state_dict = meterinterface.state_dict()

        # resume into an interface that declares an extra meter ("avg3")
        meterinterface2 = MeterInterface(
            {
                "avg1": AverageValueMeter(),
                "avg2": AverageValueMeter(),
                "dice1": SliceDiceMeter(),
                "avg3": AverageValueMeter(),
            }
        )
        meterinterface2.load_state_dict(state_dict)

        for epoch in range(10):
            for i in range(10):
                meterinterface2["avg3"].add(1)
                meterinterface2["dice1"].add(
                    torch.randn(1, 4, 224, 224), torch.randint(0, 4, size=(1, 224, 224))
                )
            meterinterface2.step()
        print(meterinterface2.summary())
示例#13
0
    def __init__(self,
                 VAT_params: Dict[str, Union[str, float]] = None,
                 MeterInterface=None) -> None:
        """Set up the VAT module and optionally register its meter.

        :param VAT_params: VAT hyper-parameters; defaults to ``{"eps": 10}``.
        :param MeterInterface: optional meter interface; when given, a
            ``train_adv`` AverageValueMeter is registered on it.
        """
        # super().__init__()

        # Fix: the previous default ``{"eps": 10}`` was a mutable default
        # argument — a single dict shared across every instance. Use a None
        # sentinel and build a fresh dict per call (same effective default).
        if VAT_params is None:
            VAT_params = {"eps": 10}
        self.VAT_params = VAT_params
        self.vat_module = VATModuleInterface(VAT_params)
        self.MeterInterface = MeterInterface
        if self.MeterInterface:
            self.MeterInterface.register_new_meter("train_adv",
                                                   AverageValueMeter())
 def __init_meters__(self) -> List[str]:
     """Register adversarial/SAT/MI loss meters plus validation accuracy meters."""
     loss_names = ("train_adv_loss", "train_sat_loss", "train_mi_loss")
     acc_names = ("val_avg_acc", "val_best_acc", "val_worst_acc")
     self.METERINTERFACE = MeterInterface(
         {name: AverageValueMeter() for name in (*loss_names, *acc_names)}
     )
     # the three validation accuracies share one sub-plot
     return [
         "train_mi_loss_mean",
         "train_sat_loss_mean",
         "train_adv_loss_mean",
         ["val_avg_acc_mean", "val_best_acc_mean", "val_worst_acc_mean"],
     ]
示例#15
0
 def test_averagevalueMeter(self):
     """Adding many identical values must not break the running average."""
     meter = AverageValueMeter()
     for _ in range(10000):
         meter.add(1)
     print(meter)
 def __init_meters__(self) -> List[Union[str, List[str]]]:
     """Register both head meters; only head B's mean is drawn (first)."""
     parent_columns = super().__init_meters__()
     for head in ("A", "B"):
         self.METERINTERFACE.register_new_meter(f"train_head_{head}",
                                                AverageValueMeter())
     return ["train_head_B_mean"] + parent_columns
示例#17
0
 def __init_meters__(self) -> List[Union[str, List[str]]]:
     """Extend the parent meters with marginal and conditional-entropy meters."""
     columns = super().__init_meters__()
     for meter_name in ("marginal", "centropy"):
         self.METERINTERFACE.register_meter(meter_name, AverageValueMeter())
         columns.append(f"{meter_name}_mean")
     return columns
示例#18
0
 def __init_meters__(self) -> List[Union[str, List[str]]]:
     """Add a residual meter on top of the parent meters."""
     self.METERINTERFACE.register_meter("residual", AverageValueMeter())
     return super().__init_meters__() + ["residual_mean"]
 def __init_meters__(self) -> List[Union[str, List[str]]]:
     """Prepend geometric/adversarial meter columns to the parent columns."""
     parent_columns = super().__init_meters__()
     for meter_name in ("train_geo", "train_adv"):
         self.METERINTERFACE.register_new_meter(meter_name, AverageValueMeter())
     return ["train_geo_mean", "train_adv_mean"] + parent_columns
 def setUp(self) -> None:
     """Prepare a loss meter, a dice meter and the criterion for the tests."""
     self.meter_config = dict(
         loss=AverageValueMeter(),
         tra_dice=SliceDiceMeter(C=5),
     )
     self.criterion = nn.CrossEntropyLoss()
    def supervised_training(self, use_pretrain=True, lr=1e-3, data_aug=False):
        """Fine-tune the checkpointed network with full supervision.

        :param use_pretrain: if True keep the checkpoint weights and only
            re-initialize ``head_B``; otherwise re-initialize the whole net.
        :param lr: learning rate for the Adam optimizer.
        :param data_aug: if True install the train/val augmentation pipelines
            on the derived loaders.
        """
        # load the best checkpoint
        self.load_checkpoint(
            torch.load(str(Path(self.checkpoint) / self.checkpoint_identifier),
                       map_location=torch.device("cpu")))
        self.model.to(self.device)

        from torchvision import transforms
        # training pipeline: deterministic crop/resize, then random crop + flip
        transform_train = transforms.Compose([
            pil_augment.CenterCrop(size=(20, 20)),
            pil_augment.Resize(size=(32, 32), interpolation=PIL.Image.NEAREST),
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            pil_augment.Img2Tensor()
        ])
        # validation pipeline: deterministic crop/resize only
        transform_val = transforms.Compose([
            pil_augment.CenterCrop(size=(20, 20)),
            pil_augment.Resize(size=(32, 32), interpolation=PIL.Image.NEAREST),
            pil_augment.Img2Tensor()
        ])

        self.kl = KL_div(reduce=True)

        def _sup_train_loop(train_loader, epoch):
            # One supervised training epoch: KL(pred, one-hot gt) loss.
            self.model.train()
            train_loader_ = tqdm_(train_loader)
            for batch_num, (image_gt) in enumerate(train_loader_):
                image, gt = zip(*image_gt)
                image = image[0].to(self.device)
                gt = gt[0].to(self.device)

                if self.use_sobel:
                    image = self.sobel(image)

                # only the first head's prediction is used for supervision
                pred = self.model.torchnet(image)[0]
                loss = self.kl(pred, class2one_hot(gt, 10).float())
                self.model.zero_grad()
                loss.backward()
                self.model.step()
                linear_meters["train_loss"].add(loss.item())
                linear_meters["train_acc"].add(pred.max(1)[1], gt)
                report_dict = {
                    "tra_acc": linear_meters["train_acc"].summary()["acc"],
                    "loss": linear_meters["train_loss"].summary()["mean"],
                }
                train_loader_.set_postfix(report_dict)

            print(f"  Training epoch {epoch}: {nice_dict(report_dict)} ")

        def _sup_eval_loop(val_loader, epoch) -> Tensor:
            # One evaluation epoch; returns the epoch accuracy.
            self.model.eval()
            val_loader_ = tqdm_(val_loader)
            for batch_num, (image_gt) in enumerate(val_loader_):
                image, gt = zip(*image_gt)
                image = image[0].to(self.device)
                gt = gt[0].to(self.device)

                if self.use_sobel:
                    image = self.sobel(image)

                pred = self.model.torchnet(image)[0]
                linear_meters["val_acc"].add(pred.max(1)[1], gt)
                report_dict = {
                    "val_acc": linear_meters["val_acc"].summary()["acc"]
                }
                val_loader_.set_postfix(report_dict)
            print(f"Validating epoch {epoch}: {nice_dict(report_dict)} ")
            return linear_meters["val_acc"].summary()["acc"]

            # building training and validation set based on extracted features

        # NOTE(review): both splits are carved out of self.val_loader —
        # datasets[0] for training and datasets[1] for validation; confirm
        # this matches the intended dataset layout.
        train_loader = dcp(self.val_loader)
        train_loader.dataset.datasets = (
            train_loader.dataset.datasets[0].datasets[0], )
        val_loader = dcp(self.val_loader)
        val_loader.dataset.datasets = (
            val_loader.dataset.datasets[0].datasets[1], )

        if data_aug:
            train_loader.dataset.datasets[0].transform = transform_train
            val_loader.dataset.datasets[0].transform = transform_val

        # network and optimization
        if not use_pretrain:
            self.model.torchnet.apply(weights_init)
        else:
            self.model.torchnet.head_B.apply(weights_init)
            # wipe out the initialization
        self.model.optimizer = torch.optim.Adam(
            self.model.torchnet.parameters(), lr=lr)
        self.model.scheduler = torch.optim.lr_scheduler.StepLR(
            self.model.optimizer, step_size=50, gamma=0.2)

        # meters
        meter_config = {
            "train_loss": AverageValueMeter(),
            "train_acc": ConfusionMatrix(self.model.arch_dict["output_k_B"]),
            "val_acc": ConfusionMatrix(self.model.arch_dict["output_k_B"])
        }
        linear_meters = MeterInterface(meter_config)
        drawer = DrawCSV2(
            save_dir=self.save_dir,
            save_name=
            f"supervised_from_checkpoint_{use_pretrain}_data_aug_{data_aug}.png",
            columns_to_draw=[
                "train_loss_mean", "train_acc_acc", "val_acc_acc"
            ])
        # train / evaluate / step schedulers and meters; persist CSV + plot
        for epoch in range(self.max_epoch):
            _sup_train_loop(train_loader, epoch)
            with torch.no_grad():
                _ = _sup_eval_loop(val_loader, epoch)
            self.model.step()
            linear_meters.step()
            linear_meters.summary().to_csv(
                self.save_dir /
                f"supervised_from_checkpoint_{use_pretrain}_data_aug_{data_aug}.csv"
            )
            drawer.draw(linear_meters.summary())
    def linear_retraining(self, conv_name: str, lr=1e-3):
        """
        Calling point to execute retraining of a linear probe on features
        extracted from a given conv layer.
        :param conv_name: name of the conv layer whose features are extracted
        :param lr: learning rate for the linear probe's Adam optimizer
        :return:
        """
        print(f"conv_name: {conv_name}, feature extracting..")

        def _linear_train_loop(train_loader, epoch):
            # One epoch of linear-probe training on pre-extracted features.
            train_loader_ = tqdm_(train_loader)
            for batch_num, (feature, gt) in enumerate(train_loader_):
                feature, gt = feature.to(self.device), gt.to(self.device)
                pred = linearnet(feature)
                loss = self.criterion(pred, gt)
                linearOptim.zero_grad()
                loss.backward()
                linearOptim.step()
                linear_meters["train_loss"].add(loss.item())
                linear_meters["train_acc"].add(pred.max(1)[1], gt)
                report_dict = {
                    "tra_acc": linear_meters["train_acc"].summary()["acc"],
                    "loss": linear_meters["train_loss"].summary()["mean"],
                }
                train_loader_.set_postfix(report_dict)

            print(f"  Training epoch {epoch}: {nice_dict(report_dict)} ")

        def _linear_eval_loop(val_loader, epoch) -> Tensor:
            # One evaluation epoch; returns the probe's accuracy.
            val_loader_ = tqdm_(val_loader)
            for batch_num, (feature, gt) in enumerate(val_loader_):
                feature, gt = feature.to(self.device), gt.to(self.device)
                pred = linearnet(feature)
                linear_meters["val_acc"].add(pred.max(1)[1], gt)
                report_dict = {
                    "val_acc": linear_meters["val_acc"].summary()["acc"]
                }
                val_loader_.set_postfix(report_dict)
            print(f"Validating epoch {epoch}: {nice_dict(report_dict)} ")
            return linear_meters["val_acc"].summary()["acc"]

        # building training and validation set based on extracted features
        # NOTE(review): both splits are carved out of self.val_loader —
        # datasets[0] for training and datasets[1] for validation; confirm
        # this matches the intended dataset layout.
        train_loader = dcp(self.val_loader)
        train_loader.dataset.datasets = (
            train_loader.dataset.datasets[0].datasets[0], )
        val_loader = dcp(self.val_loader)
        val_loader.dataset.datasets = (
            val_loader.dataset.datasets[0].datasets[1], )
        _, train_features, train_targets = self.feature_exactor(
            conv_name, train_loader)
        print(f"training_feature_shape: {train_features.shape}")
        # flatten features to (N, D) for the linear classifier
        train_features = train_features.view(train_features.size(0), -1)
        _, val_features, val_targets = self.feature_exactor(
            conv_name, val_loader)
        val_features = val_features.view(val_features.size(0), -1)
        print(f"val_feature_shape: {val_features.shape}")

        train_dataset = TensorDataset(train_features, train_targets)
        val_dataset = TensorDataset(val_features, val_targets)
        Train_DataLoader = DataLoader(train_dataset,
                                      batch_size=100,
                                      shuffle=True)
        Val_DataLoader = DataLoader(val_dataset, batch_size=100, shuffle=False)

        # network and optimization
        linearnet = LinearNet(num_features=train_features.size(1),
                              num_classes=self.model.arch_dict["output_k_B"])
        linearOptim = torch.optim.Adam(linearnet.parameters(), lr=lr)
        linearnet.to(self.device)

        # meters
        meter_config = {
            "train_loss": AverageValueMeter(),
            "train_acc": ConfusionMatrix(self.model.arch_dict["output_k_B"]),
            "val_acc": ConfusionMatrix(self.model.arch_dict["output_k_B"])
        }
        linear_meters = MeterInterface(meter_config)
        drawer = DrawCSV2(save_dir=self.save_dir,
                          save_name=f"retraining_from_{conv_name}.png",
                          columns_to_draw=[
                              "train_loss_mean", "train_acc_acc", "val_acc_acc"
                          ])
        # train / evaluate / step meters; persist CSV + plot every epoch
        for epoch in range(self.max_epoch):
            _linear_train_loop(Train_DataLoader, epoch)
            # evaluation result is only used for logging inside the loop
            _ = _linear_eval_loop(Val_DataLoader, epoch)
            linear_meters.step()
            linear_meters.summary().to_csv(self.save_dir /
                                           f"retraining_from_{conv_name}.csv")
            drawer.draw(linear_meters.summary())
 def __init_meters__(self) -> List[Union[str, List[str]]]:
     """Register the geometric-transform meter; its column is drawn third."""
     columns = super().__init_meters__()  # renamed from misspelled local "cloumns"
     self.METERINTERFACE.register_new_meter("train_geo", AverageValueMeter())
     columns.insert(2, "train_geo_mean")
     return columns
 def __init_meters__(self) -> List[Union[str, List[str]]]:
     """Register the cutout meter; its column goes just before the last one."""
     parent_columns = super().__init_meters__()
     self.METERINTERFACE.register_new_meter("train_cutout", AverageValueMeter())
     # equivalent to parent_columns.insert(-1, "train_cutout_mean")
     return parent_columns[:-1] + ["train_cutout_mean"] + parent_columns[-1:]
示例#25
0
 def __init_meters__(self) -> List[str]:
     """A single reconstruction-loss meter."""
     self.METERINTERFACE = MeterInterface({"rec_loss": AverageValueMeter()})
     return ["rec_loss_mean"]
示例#26
0
 def __init_meters__(self) -> List[Union[str, List[str]]]:
     """Prepend the gaussian-noise meter column to the parent columns."""
     parent_columns = super().__init_meters__()
     self.METERINTERFACE.register_new_meter("train_gaussian",
                                            AverageValueMeter())
     return ["train_gaussian_mean"] + parent_columns