def __init_meters__(self) -> List[Union[str, List[str]]]:
     meter_config = {
         "lr": AverageValueMeter(),
         "trloss": AverageValueMeter(),
         "trdice": SliceDiceMeter(C=self.model.arch_dict["num_classes"],
                                  report_axises=self.axis),
         "valloss": AverageValueMeter(),
         "valdice": SliceDiceMeter(C=self.model.arch_dict["num_classes"],
                                   report_axises=self.axis),
         "valbdice": BatchDiceMeter(C=self.model.arch_dict["num_classes"],
                                    report_axises=self.axis),
     }
     self.METERINTERFACE = MeterInterface(meter_config)
     return [
         "trloss_mean",
         ["trdice_DSC1", "trdice_DSC2", "trdice_DSC3"],
         "valloss_mean",
         ["valdice_DSC1", "valdice_DSC2", "valdice_DSC3"],
         ["valbdice_DSC1", "valbdice_DSC2", "valbdice_DSC3"],
         "lr_mean",
     ]
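
The list returned by __init_meters__ names the summary() columns to plot: each entry follows the "<meter>_<statistic>" pattern, and a nested list groups related columns (here, per-class dice), as in the DrawCSV2 examples below. A minimal sketch of the same convention passed straight to a drawer, assuming DrawCSV2 is constructed as in Examples #4 and #11:

from pathlib import Path

# DrawCSV2 comes from the same toolbox as the snippets on this page (import path omitted).
columns_to_draw = [
    "trloss_mean",                                  # one curve on its own axis
    ["trdice_DSC1", "trdice_DSC2", "trdice_DSC3"],  # grouped per-class curves
]
drawer = DrawCSV2(columns_to_draw=columns_to_draw, save_dir=Path("."))
# drawer.draw(METERINTERFACE.summary()) would then be called once per epoch.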
Example #2
class TestHaussdorffDistance(TestCase):
    def setUp(self) -> None:
        super().setUp()
        C = 3
        meter_config = {
            "hd_meter": HaussdorffDistance(C=C),
            "s_dice": SliceDiceMeter(C=C, report_axises=[1, 2]),
            "b_dice": BatchDiceMeter(C=C, report_axises=[1, 2]),
        }
        self.meter = MeterInterface(meter_config)

    def test_batch_case(self):
        print(self.meter.summary())

        for _ in range(5):
            for _ in range(10):
                pred = torch.randn(4, 3, 256, 256)
                label = torch.randint(0, 3, (4, 256, 256))

                pred_onehot = logit2one_hot(pred)
                label_onehot = class2one_hot(label, C=3)
                self.meter.hd_meter.add(pred_onehot, label_onehot)
                self.meter.s_dice.add(pred, label)
                self.meter.b_dice.add(pred, label)
            self.meter.step()
            print(self.meter.summary())
Example #3
 def setUp(self) -> None:
     super().setUp()
     self._meter_config = {
         "avg1": AverageValueMeter(),
         "dice1": SliceDiceMeter(C=2),
         "dice2": SliceDiceMeter(C=2),
     }
     self.meters = MeterInterface(self._meter_config)
Example #4
    def setUp(self) -> None:
        config = {"avg": AveragewithStd()}
        self.METER = MeterInterface(config)
        columns_to_draw = [["avg_mean", "avg_lstd", "avg_hstd"]]
        from pathlib import Path

        self.drawer = DrawCSV2(columns_to_draw=columns_to_draw,
                               save_dir=Path(__file__).parent)
Example #5
 def setUp(self) -> None:
     super().setUp()
     C = 3
     meter_config = {
         "hd_meter": HaussdorffDistance(C=C),
         "s_dice": SliceDiceMeter(C=C, report_axises=[1, 2]),
         "b_dice": BatchDiceMeter(C=C, report_axises=[1, 2]),
     }
     self.meter = MeterInterface(meter_config)
 def __init_meters__(self) -> List[Union[str, List[str]]]:
     meter_config = {
         "lr": AverageValueMeter(),
         "traloss": AverageValueMeter(),
         "traconf": ConfusionMatrix(self.model.torchnet.num_classes),
         "valloss": AverageValueMeter(),
         "valconf": ConfusionMatrix(self.model.torchnet.num_classes),
     }
     self.METERINTERFACE = MeterInterface(meter_config)
     return ["traloss_mean", "traconf_acc", "valloss_mean", "valconf_acc", "lr_mean"]
Example #7
 def __init_meters__(self) -> List[Union[str, List[str]]]:
     """
      basic meters to record clustering results, specifically for multiple sub-heads.
     :return:
     """
     METER_CONFIG = {
         "val_average_acc": AverageValueMeter(),
         "val_best_acc": AverageValueMeter(),
         "val_worst_acc": AverageValueMeter(),
     }
     self.METERINTERFACE = MeterInterface(METER_CONFIG)
     return [[
         "val_average_acc_mean", "val_best_acc_mean", "val_worst_acc_mean"
     ]]
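
These three meters are typically filled once per validation epoch, one accuracy per sub-head; a hedged sketch, assuming subhead_accs was collected inside the evaluation loop as in Example #18:

# inside an evaluation loop, after computing one accuracy per sub-head
for _acc in subhead_accs:
    self.METERINTERFACE.val_average_acc.add(_acc)   # epoch mean over sub-heads -> "average" acc
self.METERINTERFACE.val_best_acc.add(max(subhead_accs))
self.METERINTERFACE.val_worst_acc.add(min(subhead_accs))
self.METERINTERFACE.step()  # called once per epoch (in start_training in the trainers below)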
Example #8
 def __init_meters__(self) -> List[str]:
     METER_CONFIG = {
         'tra_reg_total': AverageValueMeter(),
         'tra_sup_label': AverageValueMeter(),
         'tra_sup_mixup': AverageValueMeter(),
         'grl': AverageValueMeter(),
         'tra_cls': AverageValueMeter(),
         'tra_conf': ConfusionMatrix(num_classes=10),
         'val_conf': ConfusionMatrix(num_classes=10)
     }
     self.METERINTERFACE = MeterInterface(METER_CONFIG)
     return [
         'tra_reg_total_mean', 'tra_sup_label_mean', 'tra_sup_mixup_mean',
         'tra_cls_mean', 'tra_conf_acc', 'val_conf_acc', 'grl_mean'
     ]
 def __init_meters__(self) -> List[str]:
     METER_CONFIG = {
         "train_mi": AverageValueMeter(),
         "train_entropy": AverageValueMeter(),
         "train_centropy": AverageValueMeter(),
         "train_sat": AverageValueMeter(),
         "val_best_acc": AverageValueMeter(),
         "val_average_acc": AverageValueMeter(),
     }
     self.METERINTERFACE = MeterInterface(METER_CONFIG)
     return [
         "train_mi_mean",
         "train_entropy_mean",
         "train_centropy_mean",
         "train_sat_mean",
         "val_average_acc_mean",
         "val_best_acc_mean",
     ]
Example #10
 def __init_meters__(self) -> List[Union[str, List[str]]]:
     METER_CONFIG = {
         "lr": InstanceValue(),
         "train_loss": AverageValueMeter(),
         "val_loss": AverageValueMeter(),
         "train_acc": ConfusionMatrix(10),
         "val_acc": ConfusionMatrix(10),
     }
     self.METERINTERFACE = MeterInterface(METER_CONFIG)  # type:ignore
     return [
         ["train_loss_mean", "val_loss_mean"],
         ["train_acc_acc", "val_acc_acc"],
         "lr_value",
     ]
Example #11
class TestDrawAverageWithSTD(TestCase):
    """
    This is to test the plotting of mean and std of a list of varying scalars
    """
    def setUp(self) -> None:
        config = {"avg": AveragewithStd()}
        self.METER = MeterInterface(config)
        columns_to_draw = [["avg_mean", "avg_lstd", "avg_hstd"]]
        from pathlib import Path

        self.drawer = DrawCSV2(columns_to_draw=columns_to_draw,
                               save_dir=Path(__file__).parent)

    def _train_loop(self, data, epoch):
        for i in data:
            self.METER["avg"].add(i)

        time.sleep(0.1)

    def test_torch(self):
        for i in range(100):
            data = torch.randn(10, 1) / (i + 1)
            self._train_loop(data, i)
            self.METER.step()
            summary = self.METER.summary()
            self.drawer.draw(summary)

    def test_numpy(self):
        for i in range(100):
            data = np.random.randn(10, 1) / (i + 1)
            self._train_loop(data, i)
            self.METER.step()
            summary = self.METER.summary()
            self.drawer.draw(summary)

    def test_list(self):
        for i in range(100):
            data = (np.random.randn(10, 1) / (i + 1)).squeeze().tolist()
            self._train_loop(data, i)
            self.METER.step()
            summary = self.METER.summary()
            self.drawer.draw(summary)
 def __init_meters__(self) -> List[str]:
     METER_CONFIG = {
         "train_adv_loss": AverageValueMeter(),
         "train_sat_loss": AverageValueMeter(),
         "train_mi_loss": AverageValueMeter(),
         "val_avg_acc": AverageValueMeter(),
         "val_best_acc": AverageValueMeter(),
         "val_worst_acc": AverageValueMeter(),
     }
     self.METERINTERFACE = MeterInterface(METER_CONFIG)
     return [
         "train_mi_loss_mean",
         "train_sat_loss_mean",
         "train_adv_loss_mean",
         ["val_avg_acc_mean", "val_best_acc_mean", "val_worst_acc_mean"],
     ]
 def test_meter_interface(self):
     meterinterface = MeterInterface(
         {"avg1": AverageValueMeter(), "dice1": SliceDiceMeter()}
     )
     print(meterinterface.summary())
     for epoch in range(10):
         if epoch == 2:
             meterinterface.register_meter("avg2", AverageValueMeter())
         for i in range(10):
             meterinterface["avg1"].add(1)
             meterinterface["dice1"].add(
                 torch.randn(1, 4, 224, 224), torch.randint(0, 4, size=(1, 224, 224))
             )
             try:
                 meterinterface["avg2"].add(2)
             except:
                 pass
         meterinterface.step()
     print(meterinterface.summary())
Example #14
 def __init_meters__(self) -> List[str]:
     METER_CONFIG = {
         "train_head_A": AverageValueMeter(),
         "train_head_B": AverageValueMeter(),
         "train_adv_A": AverageValueMeter(),
         "train_adv_B": AverageValueMeter(),
         "val_average_acc": AverageValueMeter(),
         "val_best_acc": AverageValueMeter(),
     }
     self.METERINTERFACE = MeterInterface(METER_CONFIG)
     return [
         "train_head_A_mean",
         "train_head_B_mean",
         "train_adv_A_mean",
         "train_adv_B_mean",
         "val_average_acc_mean",
         "val_best_acc_mean",
     ]
Example #15
class TestDataFrameDrawer(TestCase):
    def setUp(self) -> None:
        super().setUp()
        self._meter_config = {
            "avg1": AverageValueMeter(),
            "dice1": SliceDiceMeter(C=2),
            "dice2": SliceDiceMeter(C=2),
        }
        self.meters = MeterInterface(self._meter_config)

    def _train_loop(self):
        for i in range(2):
            scalar1 = np.random.rand()
            self.meters["avg1"].add(scalar1)
            img_pred = torch.randn(1, 2, 100, 100)
            img_gt = torch.randint(0, 2, (1, 100, 100))
            self.meters["dice1"].add(img_pred, img_gt)
            self.meters["dice2"].add(img_pred, img_gt)

    def test_plot(self):
        self.drawer = DataFrameDrawer(self.meters, save_dir="./", save_name=save_name1)
        self.meters.reset()
        for epoch in range(5):
            self._train_loop()
            self.meters.step()
            with TimeBlock() as timer:
                self.drawer()
            print(timer.cost)

    def test_single_line_plot(self):
        self.drawer = DataFrameDrawer(self.meters, save_dir="./", save_name=save_name2)
        self.meters.reset()
        self.drawer.set_callback("dice2", singleline_plot())
        for epoch in range(5):
            self._train_loop()
            self.meters.step()
            with TimeBlock() as timer:
                self.drawer()
            print(timer.cost)
class IMSAT_Trainer(_Trainer):
    """
    MI(x,p) + CE(p,adv(p)) or MI(x,p) + CE(p,geom(p))
    """
    def __init__(
        self,
        model: Model,
        train_loader: DataLoader,
        val_loader: DataLoader,
        max_epoch: int = 100,
        save_dir: str = "IMSAT",
        use_vat: bool = False,
        sat_weight: float = 0.1,
        checkpoint_path: str = None,
        device="cpu",
        config: dict = None,
        **kwargs,
    ) -> None:
        super().__init__(
            model,
            train_loader,
            val_loader,
            max_epoch,
            save_dir,
            checkpoint_path,
            device,
            config,
            **kwargs,
        )
        self.use_vat = use_vat
        self.sat_weight = float(sat_weight)
        self.criterion = MultualInformaton_IMSAT(mu=4, separate_return=True)
        self.jsd = JSD_div()
        self.kl = KL_div()
        plt.ion()

    def __init_meters__(self) -> List[str]:
        METER_CONFIG = {
            "train_mi": AverageValueMeter(),
            "train_entropy": AverageValueMeter(),
            "train_centropy": AverageValueMeter(),
            "train_sat": AverageValueMeter(),
            "val_best_acc": AverageValueMeter(),
            "val_average_acc": AverageValueMeter(),
        }
        self.METERINTERFACE = MeterInterface(METER_CONFIG)
        return [
            "train_mi_mean",
            "train_entropy_mean",
            "train_centropy_mean",
            "train_sat_mean",
            "val_average_acc_mean",
            "val_best_acc_mean",
        ]

    @property
    def _training_report_dict(self):
        report_dict = dict_filter(
            flatten_dict({
                "train_MI": self.METERINTERFACE.train_mi.summary()["mean"],
                "train_entropy": self.METERINTERFACE.train_entropy.summary()["mean"],
                "train_centropy": self.METERINTERFACE.train_centropy.summary()["mean"],
                "train_sat": self.METERINTERFACE.train_sat.summary()["mean"],
            }),
            lambda k, v: v != 0.0,
        )
        return report_dict

    @property
    def _eval_report_dict(self):
        report_dict = flatten_dict({
            "average_acc": self.METERINTERFACE.val_average_acc.summary()["mean"],
            "best_acc": self.METERINTERFACE.val_best_acc.summary()["mean"],
        })
        return report_dict

    def start_training(self):
        for epoch in range(self._start_epoch, self.max_epoch):
            self._train_loop(train_loader=self.train_loader, epoch=epoch)
            with torch.no_grad():
                current_score = self._eval_loop(self.val_loader, epoch)
            self.METERINTERFACE.step()
            self.model.schedulerStep()
            # save meters and checkpoints
            for k, v in self.METERINTERFACE.aggregated_meter_dict.items():
                v.summary().to_csv(self.save_dir / f"meters/{k}.csv")
            self.METERINTERFACE.summary().to_csv(self.save_dir /
                                                 self.wholemeter_filename)
            self.writer.add_scalars(
                "Scalars",
                self.METERINTERFACE.summary().iloc[-1].to_dict(),
                global_step=epoch,
            )
            self.drawer.call_draw()
            self.save_checkpoint(self.state_dict, epoch, current_score)

    def _train_loop(
        self,
        train_loader: DataLoader,
        epoch: int,
        mode=ModelMode.TRAIN,
        *args,
        **kwargs,
    ):
        super()._train_loop(*args, **kwargs)
        self.model.set_mode(mode)
        assert self.model.training
        _train_loader: tqdm = tqdm_(train_loader)
        for _batch_num, images_labels_indices in enumerate(_train_loader):
            images, labels, *_ = zip(*images_labels_indices)
            tf1_images = torch.cat(tuple(
                [images[0] for _ in range(images.__len__() - 1)]),
                                   dim=0).to(self.device)
            tf2_images = torch.cat(tuple(images[1:]), dim=0).to(self.device)
            pred_tf1_simplex = self.model(tf1_images)
            pred_tf2_simplex = self.model(tf2_images)
            assert simplex(pred_tf1_simplex[0]), pred_tf1_simplex
            assert simplex(pred_tf2_simplex[0]), pred_tf2_simplex
            total_loss = self._trainer_specific_loss(tf1_images, tf2_images,
                                                     pred_tf1_simplex,
                                                     pred_tf2_simplex)
            self.model.zero_grad()
            total_loss.backward()
            self.model.step()
            report_dict = self._training_report_dict
            _train_loader.set_postfix(report_dict)

        report_dict_str = ", ".join(
            [f"{k}:{v:.3f}" for k, v in report_dict.items()])
        print(f"  Training epoch: {epoch} : {report_dict_str}")

    def _eval_loop(self,
                   val_loader: DataLoader,
                   epoch: int,
                   mode=ModelMode.EVAL,
                   *args,
                   **kwargs) -> float:
        super(IMSAT_Trainer, self)._eval_loop(*args, **kwargs)
        self.model.set_mode(mode)
        assert not self.model.training
        _val_loader = tqdm_(val_loader)
        preds = torch.zeros(
            self.model.arch_dict["num_sub_heads"],
            val_loader.dataset.__len__(),
            dtype=torch.long,
            device=self.device,
        )
        probas = torch.zeros(
            self.model.arch_dict["num_sub_heads"],
            val_loader.dataset.__len__(),
            self.model.arch_dict["output_k"],
            dtype=torch.float,
            device=self.device,
        )
        gts = torch.zeros(val_loader.dataset.__len__(),
                          dtype=torch.long,
                          device=self.device)
        _batch_done = 0
        for _batch_num, images_labels_indices in enumerate(_val_loader):
            images, labels, *_ = zip(*images_labels_indices)
            images, labels = images[0].to(self.device), labels[0].to(
                self.device)
            pred = self.model(images)
            _bSlice = slice(_batch_done, _batch_done + images.shape[0])
            gts[_bSlice] = labels
            for subhead in range(pred.__len__()):
                preds[subhead][_bSlice] = pred[subhead].max(1)[1]
                probas[subhead][_bSlice] = pred[subhead]
            _batch_done += images.shape[0]
        assert _batch_done == val_loader.dataset.__len__(), _batch_done

        # record
        subhead_accs = []
        for subhead in range(self.model.arch_dict["num_sub_heads"]):
            reorder_pred, remap = hungarian_match(
                flat_preds=preds[subhead],
                flat_targets=gts,
                preds_k=self.model.arch_dict["output_k"],
                targets_k=self.model.arch_dict["output_k"],
            )
            _acc = flat_acc(reorder_pred, gts)
            subhead_accs.append(_acc)
            # record average acc
            self.METERINTERFACE.val_average_acc.add(_acc)
        self.METERINTERFACE.val_best_acc.add(max(subhead_accs))
        report_dict = self._eval_report_dict

        report_dict_str = ", ".join(
            [f"{k}:{v:.3f}" for k, v in report_dict.items()])
        print(f"Validating epoch: {epoch} : {report_dict_str}")
        return self.METERINTERFACE.val_best_acc.summary()["mean"]

    def _trainer_specific_loss(
        self,
        images: torch.Tensor,
        images_tf: torch.Tensor,
        pred: List[torch.Tensor],
        pred_tf: List[torch.Tensor],
    ) -> torch.Tensor:
        """
        to override
        :param pred:
        :param pred_tf:
        :return:
        """
        assert simplex(pred[0]), pred
        mi_losses, entropy_losses, centropy_losses = [], [], []
        for subhead_num in range(self.model.arch_dict["num_sub_heads"]):
            _mi_loss, (_entropy_loss,
                       _centropy_loss) = self.criterion(pred[subhead_num])
            mi_losses.append(-_mi_loss)
            entropy_losses.append(_entropy_loss)
            centropy_losses.append(_centropy_loss)
        mi_loss = sum(mi_losses) / len(mi_losses)
        entrop_loss = sum(entropy_losses) / len(entropy_losses)
        centropy_loss = sum(centropy_losses) / len(centropy_losses)

        self.METERINTERFACE["train_mi"].add(-mi_loss.item())
        self.METERINTERFACE["train_entropy"].add(entrop_loss.item())
        self.METERINTERFACE["train_centropy"].add(centropy_loss.item())

        sat_loss = torch.Tensor([0]).to(self.device)
        if self.sat_weight > 0:
            if not self.use_vat:
                # use transformation
                _sat_loss = list(
                    map(lambda p1, p2: self.kl(p2, p1.detach()), pred,
                        pred_tf))
                sat_loss = sum(_sat_loss) / len(_sat_loss)
            else:
                sat_loss, *_ = VATLoss_Multihead(xi=1, eps=10, prop_eps=0.1)(
                    self.model.torchnet, images)

        self.METERINTERFACE["train_sat"].add(sat_loss.item())

        total_loss = mi_loss + self.sat_weight * sat_loss
        return total_loss
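
For reference, the per-batch objective assembled in _trainer_specific_loss above can be written compactly (a reading of the code, not an official formulation; H is num_sub_heads, \lambda is sat_weight, and sg(.) denotes .detach()):

\mathcal{L} = -\frac{1}{H}\sum_{h=1}^{H}\mathrm{MI}(x, p_h) + \lambda\,\mathcal{L}_{\mathrm{SAT}},
\qquad
\mathcal{L}_{\mathrm{SAT}} =
\begin{cases}
\frac{1}{H}\sum_{h}\mathrm{KL\_div}\!\left(p_h^{\mathrm{tf}},\, \mathrm{sg}(p_h)\right) & \text{if not use\_vat},\\
\mathrm{VATLoss\_Multihead}(\xi{=}1,\ \varepsilon{=}10,\ \mathrm{prop\_eps}{=}0.1) & \text{if use\_vat},
\end{cases}

where the KL\_div argument order mirrors the call self.kl(pred_tf, pred.detach()).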
Example #17
class AdaNetTrainer(_Trainer):
    RUN_PATH = str(Path(PROJECT_PATH) / 'runs')
    ARCHIVE_PATH = str(Path(PROJECT_PATH) / 'archives')

    def __init__(self,
                 model: Model,
                 labeled_loader: DataLoader,
                 unlabeled_loader: DataLoader,
                 val_loader: DataLoader,
                 max_epoch: int = 100,
                 grl_scheduler: CustomScheduler = None,
                 epoch_decay_start: int = None,
                 save_dir: str = 'adanet',
                 checkpoint_path: str = None,
                 device='cpu',
                 config: dict = None,
                 **kwargs) -> None:
        super().__init__(model, None, val_loader, max_epoch, save_dir,
                         checkpoint_path, device, config, **kwargs)
        self.labeled_loader = labeled_loader
        self.unlabeled_loader = unlabeled_loader
        self.kl_criterion = KL_div()
        self.ce_loss = nn.CrossEntropyLoss()
        self.beta_distr: Beta = Beta(torch.tensor([1.0]), torch.tensor([1.0]))
        self.grl_scheduler = grl_scheduler
        self.grl_scheduler.epoch = self._start_epoch
        self.epoch_decay_start = int(
            epoch_decay_start) if epoch_decay_start else None

    def __init_meters__(self) -> List[str]:
        METER_CONFIG = {
            'tra_reg_total': AverageValueMeter(),
            'tra_sup_label': AverageValueMeter(),
            'tra_sup_mixup': AverageValueMeter(),
            'grl': AverageValueMeter(),
            'tra_cls': AverageValueMeter(),
            'tra_conf': ConfusionMatrix(num_classes=10),
            'val_conf': ConfusionMatrix(num_classes=10)
        }
        self.METERINTERFACE = MeterInterface(METER_CONFIG)
        return [
            'tra_reg_total_mean', 'tra_sup_label_mean', 'tra_sup_mixup_mean',
            'tra_cls_mean', 'tra_conf_acc', 'val_conf_acc', 'grl_mean'
        ]

    @property
    def _training_report_dict(self):
        return {
            'tra_sup_l': self.METERINTERFACE.tra_sup_label.summary()['mean'],
            'tra_sup_m': self.METERINTERFACE.tra_sup_mixup.summary()['mean'],
            'tra_cls': self.METERINTERFACE.tra_cls.summary()['mean'],
            'tra_acc': self.METERINTERFACE.tra_conf.summary()['acc']
        }

    @property
    def _eval_report_dict(self):
        return flatten_dict({'val': self.METERINTERFACE.val_conf.summary()},
                            sep='_')

    def start_training(self):
        for epoch in range(self._start_epoch, self.max_epoch):

            # learning-rate decay copied from the original work
            if self.epoch_decay_start:
                if epoch > self.epoch_decay_start:
                    decayed_lr = (self.max_epoch -
                                  epoch) * self.model.optim_dict['lr'] / (
                                      self.max_epoch - self.epoch_decay_start)
                    self.model.optimizer.lr = decayed_lr
                    self.model.optimizer.betas = (0.5, 0.999)

            self._train_loop(
                labeled_loader=self.labeled_loader,
                unlabeled_loader=self.unlabeled_loader,
                epoch=epoch,
            )
            with torch.no_grad():
                current_score = self._eval_loop(self.val_loader, epoch)
            self.METERINTERFACE.step()
            self.model.schedulerStep()
            self.grl_scheduler.step()
            # save meters and checkpoints
            Summary = self.METERINTERFACE.summary()
            Summary.to_csv(self.save_dir / self.wholemeter_filename)
            self.drawer.draw(Summary)
            self.model.torchnet.lambd = self.grl_scheduler.value
            self.save_checkpoint(self.state_dict, epoch, current_score)

    def _train_loop(self,
                    labeled_loader: DataLoader = None,
                    unlabeled_loader: DataLoader = None,
                    epoch: int = 0,
                    mode=ModelMode.TRAIN,
                    *args,
                    **kwargs):
        super(AdaNetTrainer, self)._train_loop(*args, **kwargs)  # warnings
        self.model.set_mode(mode)
        assert self.model.training

        labeled_loader_ = DataIter(labeled_loader)
        unlabeled_loader_ = DataIter(unlabeled_loader)
        batch_num: tqdm = tqdm_(range(unlabeled_loader.__len__()))

        for _batch_num, ((label_img, label_gt), (unlabel_img, _),
                         _) in enumerate(
                             zip(labeled_loader_, unlabeled_loader_,
                                 batch_num)):
            label_img, label_gt, unlabel_img = label_img.to(self.device), \
                                               label_gt.to(self.device), unlabel_img.to(self.device)

            label_pred, _ = self.model(label_img)
            self.METERINTERFACE.tra_conf.add(label_pred.max(1)[1], label_gt)
            sup_loss = self.ce_loss(label_pred, label_gt.squeeze())
            self.METERINTERFACE.tra_sup_label.add(sup_loss.item())

            reg_loss = self._trainer_specific_loss(label_img, label_gt,
                                                   unlabel_img)
            self.METERINTERFACE.tra_reg_total.add(reg_loss.item())
            with ZeroGradientBackwardStep(sup_loss + reg_loss,
                                          self.model) as loss:
                loss.backward()
            report_dict = self._training_report_dict
            batch_num.set_postfix(report_dict)
        print(f'  Training epoch {epoch}: {nice_dict(report_dict)}')

    def _eval_loop(self,
                   val_loader: DataLoader = None,
                   epoch: int = 0,
                   mode=ModelMode.EVAL,
                   *args,
                   **kwargs) -> float:
        self.model.set_mode(mode)
        assert not self.model.training
        val_loader_: tqdm = tqdm_(val_loader)

        for _batch_num, (img, label) in enumerate(val_loader_):
            img, label = img.to(self.device), label.to(self.device)
            pred, _ = self.model(img)
            self.METERINTERFACE.val_conf.add(pred.max(1)[1], label)
            report_dict = self._eval_report_dict
            val_loader_.set_postfix(report_dict)
        print(f'Validating epoch {epoch}: {nice_dict(report_dict)}')

        return self.METERINTERFACE.val_conf.summary()['acc']

    def _trainer_specific_loss(self, label_img, label_gt, unlab_img, *args,
                               **kwargs):
        super(AdaNetTrainer, self)._trainer_specific_loss(*args,
                                                          **kwargs)  # warning
        assert label_img.shape == unlab_img.shape, f"Shapes of labeled and unlabeled images should be the same," \
            f"given {label_img.shape} and {unlab_img.shape}."
        self.model.eval()
        with torch.no_grad():
            pseudo_label = self.model.torchnet(unlab_img)[0]
        self.model.train()
        mixup_img, mixup_label, mix_indice = self._mixup(
            label_img,
            class2one_hot(label_gt.unsqueeze(dim=1).unsqueeze(dim=2),
                          10).squeeze().float(), unlab_img, pseudo_label)

        pred, cls = self.model(mixup_img)
        assert simplex(pred) and simplex(cls)
        reg_loss1 = self.kl_criterion(pred, mixup_label)
        adv_loss = self.kl_criterion(cls, mix_indice)
        self.METERINTERFACE.tra_sup_mixup.add(reg_loss1.item())
        self.METERINTERFACE.tra_cls.add(adv_loss.item())
        self.METERINTERFACE.grl.add(self.grl_scheduler.value)

        # Discriminator
        return (reg_loss1 + adv_loss) * 0.1

    def _mixup(self, label_img: torch.Tensor, label_onehot: torch.Tensor,
               unlab_img: torch.Tensor, unlabeled_pred: torch.Tensor):
        assert label_img.shape == unlab_img.shape
        assert label_img.shape.__len__() == 4
        assert one_hot(label_onehot) and simplex(unlabeled_pred)
        assert label_onehot.shape == unlabeled_pred.shape
        bn, *shape = label_img.shape
        alpha = self.beta_distr.sample((bn, )).squeeze(1).to(self.device)
        _alpha = alpha.view(bn, 1, 1, 1).repeat(1, *shape)
        assert _alpha.shape == label_img.shape
        mixup_img = label_img * _alpha + unlab_img * (1 - _alpha)
        mixup_label = label_onehot * alpha.view(bn, 1) \
                      + unlabeled_pred * (1 - alpha).view(bn, 1)
        mixup_index = torch.stack([alpha, 1 - alpha], dim=1).to(self.device)

        assert mixup_img.shape == label_img.shape
        assert mixup_label.shape == label_onehot.shape
        assert mixup_index.shape[0] == bn
        assert simplex(mixup_index)

        return mixup_img, mixup_label, mixup_index
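
In formula form, _mixup above draws one \alpha per sample from Beta(1, 1) (i.e. uniform on [0, 1]) and mixes labeled and unlabeled pairs; this is a direct transcription of the code:

\alpha \sim \mathrm{Beta}(1,1), \qquad
\tilde{x} = \alpha\, x_{\mathrm{lab}} + (1-\alpha)\, x_{\mathrm{unlab}}, \qquad
\tilde{y} = \alpha\, y_{\mathrm{lab}} + (1-\alpha)\, \hat{p}_{\mathrm{unlab}}, \qquad
m = (\alpha,\ 1-\alpha),

with y_{\mathrm{lab}} the one-hot label, \hat{p}_{\mathrm{unlab}} the pseudo-label simplex, and m the mix index that the cls head is trained toward in _trainer_specific_loss.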
Example #18
class ClusteringGeneralTrainer(_Trainer):
    # project save dirs for training statistics
    RUN_PATH = str(Path(__file__).parent.parent / "runs")
    ARCHIVE_PATH = str(Path(__file__).parent.parent / "archives")

    def __init__(
        self,
        model: Model,
        train_loader_A: DataLoader,
        train_loader_B: DataLoader,
        val_loader: DataLoader,
        criterion: nn.Module = None,
        max_epoch: int = 100,
        save_dir: str = "ClusteringGeneralTrainer",
        checkpoint_path: str = None,
        device="cpu",
        head_control_params: Dict[str, int] = {"B": 1},
        use_sobel: bool = False,  # both IIC and IMSAT may need this sobel filter
        config: dict = None,
        **kwargs,
    ) -> None:
        super().__init__(
            model,
            None,
            val_loader,
            max_epoch,
            save_dir,
            checkpoint_path,
            device,
            config,
            **kwargs,
        )  # type: ignore
        assert (self.train_loader is None
                ), self.train_loader  # discard the original self.train_loader
        self.train_loader_A = train_loader_A  # trainer for head_A
        self.train_loader_B = train_loader_B  # trainer for head B
        self.head_control_params: OrderedDict = OrderedDict(
            head_control_params)
        assert criterion, criterion
        self.criterion = criterion
        self.criterion.to(self.device)
        self.use_sobel = use_sobel
        if self.use_sobel:
            self.sobel = SobelProcess(include_origin=False)
            self.sobel.to(
                self.device)  # sobel filter returns a tensor (bn, 1, w, h)

    def __init_meters__(self) -> List[Union[str, List[str]]]:
        """
        basic meters to record clustering results, specifically for multiple sub-heads.
        :return:
        """
        METER_CONFIG = {
            "val_average_acc": AverageValueMeter(),
            "val_best_acc": AverageValueMeter(),
            "val_worst_acc": AverageValueMeter(),
        }
        self.METERINTERFACE = MeterInterface(METER_CONFIG)
        return [[
            "val_average_acc_mean", "val_best_acc_mean", "val_worst_acc_mean"
        ]]

    @property
    def _training_report_dict(self) -> Dict[str, float]:
        return {}  # to override

    @property
    def _eval_report_dict(self) -> Dict[str, float]:
        """
        return validation report dict
        :return:
        """
        report_dict = {
            "average_acc": self.METERINTERFACE.val_average_acc.summary()["mean"],
            "best_acc": self.METERINTERFACE.val_best_acc.summary()["mean"],
            "worst_acc": self.METERINTERFACE.val_worst_acc.summary()["mean"],
        }
        report_dict = dict_filter(report_dict)
        return report_dict

    def start_training(self):
        """
        main function to call for training
        :return:
        """
        for epoch in range(self._start_epoch, self.max_epoch):
            self._train_loop(
                train_loader_A=self.train_loader_A,
                train_loader_B=self.train_loader_B,
                epoch=epoch,
                head_control_param=self.head_control_params,
            )
            with torch.no_grad():
                current_score = self._eval_loop(self.val_loader, epoch)

            # update meters
            self.METERINTERFACE.step()
            # update model scheduler
            self.model.schedulerStep()
            # save meters and checkpoints
            SUMMARY = self.METERINTERFACE.summary()
            SUMMARY.to_csv(self.save_dir / f"wholeMeter.csv")
            # draw training curves
            self.drawer.draw(SUMMARY)
            # save last.pth and/or best.pth based on current_score
            self.save_checkpoint(self.state_dict(), epoch, current_score)
        # close tf.summary_writer
        time.sleep(3)
        self.writer.close()

    def _train_loop(
        self,
        train_loader_A: DataLoader = None,
        train_loader_B: DataLoader = None,
        epoch: int = None,
        mode: ModelMode = ModelMode.TRAIN,
        head_control_param: OrderedDict = None,
        *args,
        **kwargs,
    ) -> None:
        """
        :param train_loader_A:
        :param train_loader_B:
        :param epoch:
        :param mode:
        :param head_control_param:
        :param args:
        :param kwargs:
        :return: None
        """
        # robustness asserts
        assert isinstance(train_loader_B, DataLoader) and isinstance(
            train_loader_A, DataLoader)
        assert (head_control_param and head_control_param.__len__() > 0), \
            f"`head_control_param` must be provided, given {head_control_param}."
        assert set(head_control_param.keys()) <= {"A", "B", }, \
            f"`head_control_param` key must be in `A` or `B`, given {set(head_control_param.keys())}"
        for k, v in head_control_param.items():
            assert k in ("A", "B"), (
                f"`head_control_param` key must be in `A` or `B`,"
                f" given{set(head_control_param.keys())}")
            assert isinstance(
                v, int) and v >= 0, f"Iteration for {k} must be >= 0."
        # set training mode
        self.model.set_mode(mode)
        assert (
            self.model.training
        ), f"Model should be in train() model, given {self.model.training}."
        assert len(train_loader_B) == len(train_loader_A), (
            f"The length of the train_loaders should be the same, "
            f"given `len(train_loader_A)`:{len(train_loader_A)} and `len(train_loader_B)`:{len(train_loader_B)}."
        )

        for head_name, head_iterations in head_control_param.items():
            assert head_name in ("A", "B"), head_name
            train_loader = eval(f"train_loader_{head_name}"
                                )  # change the dataset for different head
            for head_epoch in range(head_iterations):
                # given one head, one iteration in this head, and one train_loader.
                train_loader_: tqdm = tqdm_(
                    train_loader)  # reinitialize the train_loader
                train_loader_.set_description(
                    f"Training epoch: {epoch} head:{head_name}, head_epoch:{head_epoch + 1}/{head_iterations}"
                )
                for batch, image_labels in enumerate(train_loader_):
                    images, *_ = list(zip(*image_labels))
                    # extract tf1_images and tf2_images and put them on self.device
                    tf1_images = torch.cat(tuple(
                        [images[0] for _ in range(len(images) - 1)]),
                                           dim=0).to(self.device)
                    tf2_images = torch.cat(tuple(images[1:]),
                                           dim=0).to(self.device)
                    assert tf1_images.shape == tf2_images.shape, f"`tf1_images` should have the same size as `tf2_images`," \
                        f"given {tf1_images.shape} and {tf2_images.shape}."
                    # if images are processed with sobel filters
                    if self.use_sobel:
                        tf1_images = self.sobel(tf1_images)
                        tf2_images = self.sobel(tf2_images)
                        assert tf1_images.shape == tf2_images.shape
                    # Here you have two kinds of geometric transformations
                    # todo: functions to be overridden
                    batch_loss = self._trainer_specific_loss(
                        tf1_images, tf2_images, head_name)
                    # update the model with a self-defined context manager that supports the Apex module
                    with ZeroGradientBackwardStep(batch_loss,
                                                  self.model) as loss:
                        loss.backward()
                    # write value to tqdm module for system monitoring
                    report_dict = self._training_report_dict
                    train_loader_.set_postfix(report_dict)
        # for tensorboard recording
        self.writer.add_scalar_with_tag("train", report_dict, epoch)
        # for std recording
        print(f"Training epoch: {epoch} : {nice_dict(report_dict)}")

    def _eval_loop(
        self,
        val_loader: DataLoader = None,
        epoch: int = 0,
        mode: ModelMode = ModelMode.EVAL,
        return_soft_predict=False,
        *args,
        **kwargs,
    ) -> float:
        assert isinstance(
            val_loader, DataLoader)  # make sure a validation loader is passed.
        self.model.set_mode(mode)  # set model to be eval mode, by default.
        # make sure the model is in eval mode.
        assert (
            not self.model.training
        ), f"Model should be in eval model in _eval_loop, given {self.model.training}."
        val_loader_: tqdm = tqdm_(val_loader)
        # prediction initialization with shape: (num_sub_heads, num_samples)
        preds = torch.zeros(self.model.arch_dict["num_sub_heads"],
                            val_loader.dataset.__len__(),
                            dtype=torch.long,
                            device=self.device)
        # soft_prediction initialization with shape (num_sub_heads, num_sample, num_classes)
        if return_soft_predict:
            soft_preds = torch.zeros(
                self.model.arch_dict["num_sub_heads"],
                val_loader.dataset.__len__(),
                self.model.arch_dict["output_k_B"],
                dtype=torch.float,
                device=torch.device("cpu"))  # I put it into cpu
        # target initialization with shape: (num_samples)
        target = torch.zeros(val_loader.dataset.__len__(),
                             dtype=torch.long,
                             device=self.device)
        # begin index
        slice_done = 0
        subhead_accs = []
        val_loader_.set_description(f"Validating epoch: {epoch}")
        for batch, image_labels in enumerate(val_loader_):
            images, gt, *_ = list(zip(*image_labels))
            # only take the tf3 image and gts, put them to self.device
            images, gt = images[0].to(self.device), gt[0].to(self.device)
            # if use sobel filter
            if self.use_sobel:
                images = self.sobel(images)
            # using default head_B for inference, _pred should be a list of simplex by default.
            _pred = self.model.torchnet(images, head="B")
            assert assert_list(simplex,
                               _pred), "pred should be a list of simplexes."
            assert _pred.__len__() == self.model.arch_dict["num_sub_heads"]
            # slice window definition
            bSlicer = slice(slice_done, slice_done + images.shape[0])
            for subhead in range(self.model.arch_dict["num_sub_heads"]):
                # save predictions for each subhead for each batch
                preds[subhead][bSlicer] = _pred[subhead].max(1)[1]
                if return_soft_predict:
                    soft_preds[subhead][bSlicer] = _pred[subhead]
            # save target for each batch
            target[bSlicer] = gt
            # update slice index
            slice_done += gt.shape[0]
        # make sure the whole dataset has been processed; this raises if dataloader.drop_last=True
        assert slice_done == val_loader.dataset.__len__(
        ), "Slice not completed."
        for subhead in range(self.model.arch_dict["num_sub_heads"]):
            # remap pred for each head and compare with target to get subhead_acc
            reorder_pred, remap = hungarian_match(
                flat_preds=preds[subhead],
                flat_targets=target,
                preds_k=self.model.arch_dict["output_k_B"],
                targets_k=self.model.arch_dict["output_k_B"],
            )
            _acc = flat_acc(reorder_pred, target)
            subhead_accs.append(_acc)
            # record average acc
            self.METERINTERFACE.val_average_acc.add(_acc)

            if return_soft_predict:
                soft_preds[subhead][:, list(remap.values(
                ))] = soft_preds[subhead][:, list(remap.keys())]
                assert torch.allclose(soft_preds[subhead].max(1)[1],
                                      reorder_pred.cpu())

        # record best acc
        self.METERINTERFACE.val_best_acc.add(max(subhead_accs))
        # record worst acc
        self.METERINTERFACE.val_worst_acc.add(min(subhead_accs))
        report_dict = self._eval_report_dict
        # record results for std
        print(f"Validating epoch: {epoch} : {nice_dict(report_dict)}")
        # record results for tensorboard
        self.writer.add_scalar_with_tag("val", report_dict, epoch)
        # use multithreading to call the tensorboard histogram interface.
        pred_histgram(self.writer, preds, epoch=epoch)
        # return the current score to save the best checkpoint.
        if return_soft_predict:
            return self.METERINTERFACE.val_best_acc.summary()["mean"], (
                target.cpu(), soft_preds[np.argmax(subhead_accs)]
            )  # type ignore

        return self.METERINTERFACE.val_best_acc.summary()["mean"]

    def _trainer_specific_loss(self, tf1_images: Tensor, tf2_images: Tensor,
                               head_name: str):
        """
        function to be overridden by subclasses
        :param tf1_images: basic transformed images with device = self.device
        :param tf2_images: advanced transformed image with device = self.device
        :param head_name: head name for model inference
        :return: loss tensor to call .backward()
        """

        raise NotImplementedError
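
A hedged sketch of such an override, assuming self.criterion accepts two simplex predictions from the same head and returns a scalar loss (the concrete subclasses may differ):

    def _trainer_specific_loss(self, tf1_images: Tensor, tf2_images: Tensor,
                               head_name: str) -> Tensor:
        # forward both transformed batches through the chosen head
        tf1_preds = self.model.torchnet(tf1_images, head=head_name)
        tf2_preds = self.model.torchnet(tf2_images, head=head_name)
        assert assert_list(simplex, tf1_preds) and assert_list(simplex, tf2_preds)
        # average the (assumed pairwise) criterion over all sub-heads
        losses = [self.criterion(p1, p2) for p1, p2 in zip(tf1_preds, tf2_preds)]
        return sum(losses) / len(losses)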
 def test_Initialize_MeterInterface(self):
     Meter = MeterInterface(meter_config=self.meter_config)
     for epoch in range(3):
         self._training_loop(Meter)
         Meter.step()
         print(Meter.summary())
class SemiSegTrainer(_Trainer):
    @lazy_load_checkpoint
    def __init__(
            self,
            model: Model,
            labeled_loader: DataLoader,
            unlabeled_loader: DataLoader,
            val_loader: DataLoader,
            max_epoch: int = 100,
            save_dir: str = "base",
            checkpoint_path: str = None,
            device="cpu",
            config: dict = None,
            max_iter: int = 100,
            axis=(1, 2, 3),
            **kwargs,
    ) -> None:
        self.axis = axis
        super().__init__(
            model,
            None,
            val_loader,
            max_epoch,
            save_dir,
            checkpoint_path,
            device,
            config,
            **kwargs,
        )
        assert self.train_loader is None
        self.labeled_loader = labeled_loader
        self.unlabeled_loader = unlabeled_loader
        self.kl_criterion = KL_div()
        self.max_iter = max_iter

    def __init_meters__(self) -> List[Union[str, List[str]]]:
        meter_config = {
            "lr": AverageValueMeter(),
            "trloss": AverageValueMeter(),
            "trdice": SliceDiceMeter(C=self.model.arch_dict["num_classes"],
                                     report_axises=self.axis),
            "valloss": AverageValueMeter(),
            "valdice": SliceDiceMeter(C=self.model.arch_dict["num_classes"],
                                      report_axises=self.axis),
            "valbdice": BatchDiceMeter(C=self.model.arch_dict["num_classes"],
                                       report_axises=self.axis),
        }
        self.METERINTERFACE = MeterInterface(meter_config)
        return [
            "trloss_mean",
            ["trdice_DSC1", "trdice_DSC2", "trdice_DSC3"],
            "valloss_mean",
            ["valdice_DSC1", "valdice_DSC2", "valdice_DSC3"],
            ["valbdice_DSC1", "valbdice_DSC2", "valbdice_DSC3"],
            "lr_mean",
        ]

    def start_training(self):
        for epoch in range(self._start_epoch, self.max_epoch):
            self._train_loop(
                labeled_loader=self.labeled_loader,
                unlabeled_loader=self.unlabeled_loader,
                epoch=epoch,
            )
            with torch.no_grad():
                current_score = self._eval_loop(self.val_loader, epoch)
            self.METERINTERFACE.step()
            self.model.schedulerStep()
            # save meters and checkpoints
            SUMMARY = self.METERINTERFACE.summary()
            SUMMARY.to_csv(self.save_dir / self.wholemeter_filename)
            self.drawer.draw(SUMMARY)
            self.save_checkpoint(self.state_dict(), epoch, current_score)
        self.writer.close()

    def _train_loop(
        self,
        labeled_loader: DataLoader = None,
        unlabeled_loader: DataLoader = None,
        epoch: int = 0,
        mode=ModelMode.TRAIN,
        *args,
        **kwargs,
    ):
        self.model.set_mode(mode)
        labeled_loader = DataIter(labeled_loader)
        unlabeled_loader = DataIter(unlabeled_loader)
        _max_iter = tqdm_(range(self.max_iter))
        _max_iter.set_description(f"Training Epoch {epoch}")
        self.METERINTERFACE["lr"].add(self.model.get_lr()[0])
        for (
                batch_num,
            ((lab_img, lab_gt), lab_path),
            ((unlab_img, _), unlab_path),
        ) in zip(_max_iter, labeled_loader, unlabeled_loader):
            lab_img, lab_gt = lab_img.to(self.device), lab_gt.to(self.device)
            lab_preds = self.model(lab_img, force_simplex=True)
            sup_loss = self.kl_criterion(
                lab_preds,
                class2one_hot(lab_gt.squeeze(1),
                              C=self.model.arch_dict["num_classes"]).float(),
            )
            reg_loss = self._trainer_specific_loss(unlab_img)
            self.METERINTERFACE["trloss"].add(sup_loss.item())
            self.METERINTERFACE["trdice"].add(lab_preds, lab_gt)

            with ZeroGradientBackwardStep(sup_loss + reg_loss,
                                          self.model) as total_loss:
                total_loss.backward()
            report_dict = self._training_report_dict
            _max_iter.set_postfix(report_dict)
        print(f"Training Epoch {epoch}: {nice_dict(report_dict)}")
        self.writer.add_scalar_with_tag("train",
                                        report_dict,
                                        global_step=epoch)

    def _trainer_specific_loss(self, unlab_img: Tensor, **kwargs) -> Tensor:
        return torch.tensor(0, dtype=torch.float32, device=self.device)

    def _eval_loop(
        self,
        val_loader: DataLoader = None,
        epoch: int = 0,
        mode=ModelMode.EVAL,
        *args,
        **kwargs,
    ) -> float:
        self.model.set_mode(mode)
        _val_loader = tqdm_(val_loader)
        _val_loader.set_description(f"Validating Epoch {epoch}")
        for batch_num, ((val_img, val_gt), val_path) in enumerate(_val_loader):
            val_img, val_gt = val_img.to(self.device), val_gt.to(self.device)
            val_preds = self.model(val_img, force_simplex=True)
            val_loss = self.kl_criterion(
                val_preds,
                class2one_hot(val_gt.squeeze(1),
                              C=self.model.arch_dict["num_classes"]).float(),
                disable_assert=True,
            )
            self.METERINTERFACE["valloss"].add(val_loss.item())
            self.METERINTERFACE["valdice"].add(val_preds, val_gt)
            self.METERINTERFACE["valbdice"].add(val_preds, val_gt)
            report_dict = self._eval_report_dict
            _val_loader.set_postfix(report_dict)
        print(f"Validating Epoch {epoch}: {nice_dict(report_dict)}")
        self.writer.add_scalar_with_tag(tag="eval",
                                        tag_scalar_dict=report_dict,
                                        global_step=epoch)
        return self.METERINTERFACE["valbdice"].value()[0][0].item()

    @property
    def _training_report_dict(self):
        return flatten_dict(
            {
                "tra_loss": self.METERINTERFACE["trloss"].summary()["mean"],
                "": self.METERINTERFACE["trdice"].summary(),
                "lr": self.METERINTERFACE["lr"].summary()["mean"],
            },
            sep="_",
        )

    @property
    def _eval_report_dict(self):
        return flatten_dict(
            {
                "val_loss": self.METERINTERFACE["valloss"].summary()["mean"],
                "": self.METERINTERFACE["valdice"].summary(),
                "b": self.METERINTERFACE["valbdice"].summary(),
            },
            sep="",
        )
 def test_save_checkpoint_and_load(self):
     Meter1 = MeterInterface(meter_config=self.meter_config)
     for epoch in range(3):
         self._training_loop(Meter1)
         Meter1.step()
         print(Meter1.summary())
     meter1_dict = Meter1.state_dict()
     # print("Meter1 saved.")
     Meter2 = MeterInterface(meter_config=self.meter_config)
     Meter2.load_state_dict(meter1_dict)
     # print("Meter2 loaded")
     print(Meter2.summary())
     for epoch in range(5):
         self._training_loop(Meter2)
         Meter2.step()
         print(Meter2.summary())
Example #22
    def test_resume(self):

        meterinterface = MeterInterface({
            "avg1": AverageValueMeter(),
            "dice1": SliceDiceMeter()
        })
        meterinterface.step()
        meterinterface.step()
        meterinterface.step()

        for epoch in range(10):
            if epoch == 2:
                meterinterface.register_new_meter("avg2", AverageValueMeter())
            for i in range(10):
                meterinterface["avg1"].add(1)
                meterinterface["dice1"].add(
                    torch.randn(1, 4, 224, 224),
                    torch.randint(0, 4, size=(1, 224, 224)))
                try:
                    meterinterface["avg2"].add(2)
                except:
                    pass
            meterinterface.step()
        print(meterinterface.summary())
        state_dict = meterinterface.state_dict()

        meterinterface2 = MeterInterface({
            "avg1": AverageValueMeter(),
            "avg2": AverageValueMeter(),
            "dice1": SliceDiceMeter(),
            "avg3": AverageValueMeter(),
        })
        meterinterface2.load_state_dict(state_dict)

        for epoch in range(10):
            for i in range(10):
                meterinterface2["avg3"].add(1)
                meterinterface2["dice1"].add(
                    torch.randn(1, 4, 224, 224),
                    torch.randint(0, 4, size=(1, 224, 224)))
            meterinterface2.step()
        print(meterinterface2.summary())
    def linear_retraining(self, conv_name: str, lr=1e-3):
        """
        Calling point to execute retraining
        :param conv_name:
        :return:
        """
        print(f"conv_name: {conv_name}, feature extracting..")

        def _linear_train_loop(train_loader, epoch):
            train_loader_ = tqdm_(train_loader)
            for batch_num, (feature, gt) in enumerate(train_loader_):
                feature, gt = feature.to(self.device), gt.to(self.device)
                pred = linearnet(feature)
                loss = self.criterion(pred, gt)
                linearOptim.zero_grad()
                loss.backward()
                linearOptim.step()
                linear_meters["train_loss"].add(loss.item())
                linear_meters["train_acc"].add(pred.max(1)[1], gt)
                report_dict = {
                    "tra_acc": linear_meters["train_acc"].summary()["acc"],
                    "loss": linear_meters["train_loss"].summary()["mean"],
                }
                train_loader_.set_postfix(report_dict)

            print(f"  Training epoch {epoch}: {nice_dict(report_dict)} ")

        def _linear_eval_loop(val_loader, epoch) -> Tensor:
            val_loader_ = tqdm_(val_loader)
            for batch_num, (feature, gt) in enumerate(val_loader_):
                feature, gt = feature.to(self.device), gt.to(self.device)
                pred = linearnet(feature)
                linear_meters["val_acc"].add(pred.max(1)[1], gt)
                report_dict = {
                    "val_acc": linear_meters["val_acc"].summary()["acc"]
                }
                val_loader_.set_postfix(report_dict)
            print(f"Validating epoch {epoch}: {nice_dict(report_dict)} ")
            return linear_meters["val_acc"].summary()["acc"]

        # building training and validation set based on extracted features
        train_loader = dcp(self.val_loader)
        train_loader.dataset.datasets = (
            train_loader.dataset.datasets[0].datasets[0], )
        val_loader = dcp(self.val_loader)
        val_loader.dataset.datasets = (
            val_loader.dataset.datasets[0].datasets[1], )
        _, train_features, train_targets = self.feature_exactor(
            conv_name, train_loader)
        print(f"training_feature_shape: {train_features.shape}")
        train_features = train_features.view(train_features.size(0), -1)
        _, val_features, val_targets = self.feature_exactor(
            conv_name, val_loader)
        val_features = val_features.view(val_features.size(0), -1)
        print(f"val_feature_shape: {val_features.shape}")

        train_dataset = TensorDataset(train_features, train_targets)
        val_dataset = TensorDataset(val_features, val_targets)
        Train_DataLoader = DataLoader(train_dataset,
                                      batch_size=100,
                                      shuffle=True)
        Val_DataLoader = DataLoader(val_dataset, batch_size=100, shuffle=False)

        # network and optimization
        linearnet = LinearNet(num_features=train_features.size(1),
                              num_classes=self.model.arch_dict["output_k_B"])
        linearOptim = torch.optim.Adam(linearnet.parameters(), lr=lr)
        linearnet.to(self.device)

        # meters
        meter_config = {
            "train_loss": AverageValueMeter(),
            "train_acc": ConfusionMatrix(self.model.arch_dict["output_k_B"]),
            "val_acc": ConfusionMatrix(self.model.arch_dict["output_k_B"])
        }
        linear_meters = MeterInterface(meter_config)
        drawer = DrawCSV2(save_dir=self.save_dir,
                          save_name=f"retraining_from_{conv_name}.png",
                          columns_to_draw=[
                              "train_loss_mean", "train_acc_acc", "val_acc_acc"
                          ])
        for epoch in range(self.max_epoch):
            _linear_train_loop(Train_DataLoader, epoch)
            _ = _linear_eval_loop(Val_DataLoader, epoch)
            linear_meters.step()
            linear_meters.summary().to_csv(self.save_dir /
                                           f"retraining_from_{conv_name}.csv")
            drawer.draw(linear_meters.summary())
Example #24
 def __init_meters__(self) -> List[str]:
     METER_CONFIG = {"rec_loss": AverageValueMeter()}
     self.METERINTERFACE = MeterInterface(METER_CONFIG)
     return ["rec_loss_mean"]
    def supervised_training(self, use_pretrain=True, lr=1e-3, data_aug=False):
        # load the best checkpoint
        self.load_checkpoint(
            torch.load(str(Path(self.checkpoint) / self.checkpoint_identifier),
                       map_location=torch.device("cpu")))
        self.model.to(self.device)

        from torchvision import transforms
        transform_train = transforms.Compose([
            pil_augment.CenterCrop(size=(20, 20)),
            pil_augment.Resize(size=(32, 32), interpolation=PIL.Image.NEAREST),
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            pil_augment.Img2Tensor()
        ])
        transform_val = transforms.Compose([
            pil_augment.CenterCrop(size=(20, 20)),
            pil_augment.Resize(size=(32, 32), interpolation=PIL.Image.NEAREST),
            pil_augment.Img2Tensor()
        ])

        self.kl = KL_div(reduce=True)

        def _sup_train_loop(train_loader, epoch):
            self.model.train()
            train_loader_ = tqdm_(train_loader)
            for batch_num, (image_gt) in enumerate(train_loader_):
                image, gt = zip(*image_gt)
                image = image[0].to(self.device)
                gt = gt[0].to(self.device)

                if self.use_sobel:
                    image = self.sobel(image)

                pred = self.model.torchnet(image)[0]
                loss = self.kl(pred, class2one_hot(gt, 10).float())
                self.model.zero_grad()
                loss.backward()
                self.model.step()
                linear_meters["train_loss"].add(loss.item())
                linear_meters["train_acc"].add(pred.max(1)[1], gt)
                report_dict = {
                    "tra_acc": linear_meters["train_acc"].summary()["acc"],
                    "loss": linear_meters["train_loss"].summary()["mean"],
                }
                train_loader_.set_postfix(report_dict)

            print(f"  Training epoch {epoch}: {nice_dict(report_dict)} ")

        def _sup_eval_loop(val_loader, epoch) -> Tensor:
            self.model.eval()
            val_loader_ = tqdm_(val_loader)
            for batch_num, (image_gt) in enumerate(val_loader_):
                image, gt = zip(*image_gt)
                image = image[0].to(self.device)
                gt = gt[0].to(self.device)

                if self.use_sobel:
                    image = self.sobel(image)

                pred = self.model.torchnet(image)[0]
                linear_meters["val_acc"].add(pred.max(1)[1], gt)
                report_dict = {
                    "val_acc": linear_meters["val_acc"].summary()["acc"]
                }
                val_loader_.set_postfix(report_dict)
            print(f"Validating epoch {epoch}: {nice_dict(report_dict)} ")
            return linear_meters["val_acc"].summary()["acc"]

        # building training and validation set based on extracted features

        train_loader = dcp(self.val_loader)
        train_loader.dataset.datasets = (
            train_loader.dataset.datasets[0].datasets[0], )
        val_loader = dcp(self.val_loader)
        val_loader.dataset.datasets = (
            val_loader.dataset.datasets[0].datasets[1], )

        if data_aug:
            train_loader.dataset.datasets[0].transform = transform_train
            val_loader.dataset.datasets[0].transform = transform_val

        # network and optimization
        if not use_pretrain:
            self.model.torchnet.apply(weights_init)
        else:
            self.model.torchnet.head_B.apply(weights_init)
            # wipe out the initialization
        self.model.optimizer = torch.optim.Adam(
            self.model.torchnet.parameters(), lr=lr)
        self.model.scheduler = torch.optim.lr_scheduler.StepLR(
            self.model.optimizer, step_size=50, gamma=0.2)

        # meters
        meter_config = {
            "train_loss": AverageValueMeter(),
            "train_acc": ConfusionMatrix(self.model.arch_dict["output_k_B"]),
            "val_acc": ConfusionMatrix(self.model.arch_dict["output_k_B"])
        }
        linear_meters = MeterInterface(meter_config)
        drawer = DrawCSV2(
            save_dir=self.save_dir,
            save_name=
            f"supervised_from_checkpoint_{use_pretrain}_data_aug_{data_aug}.png",
            columns_to_draw=[
                "train_loss_mean", "train_acc_acc", "val_acc_acc"
            ])
        for epoch in range(self.max_epoch):
            _sup_train_loop(train_loader, epoch)
            with torch.no_grad():
                _ = _sup_eval_loop(val_loader, epoch)
            self.model.step()
            linear_meters.step()
            linear_meters.summary().to_csv(
                self.save_dir /
                f"supervised_from_checkpoint_{use_pretrain}_data_aug_{data_aug}.csv"
            )
            drawer.draw(linear_meters.summary())