# Example #1 (snippet separator; score "0" from the original scrape)
    "tf1": transforms.Compose([
        pil_augment.Img2Tensor(),
    ]),
    "tf2": transforms.Compose([
        pil_augment.RandomCrop(size=32, padding=2, ),
        pil_augment.Img2Tensor(),
    ]
    ),
    "tf3": transforms.Compose([
        pil_augment.Img2Tensor(),
    ]),
}
svhn_strong_transform = {
    # output size 32*32
    "tf1": transforms.Compose([
        pil_augment.CenterCrop(size=(28, 28)),
        pil_augment.Resize(size=32, interpolation=PIL.Image.BILINEAR),
        pil_augment.Img2Tensor()]),
    "tf2": transforms.Compose([pil_augment.RandomApply(
        transforms=[transforms.RandomRotation(degrees=(-25.0, 25.0), resample=False, expand=False)],
        p=0.5),
        pil_augment.RandomChoice(transforms=[
            pil_augment.RandomCrop(size=(20, 20), padding=None),
            pil_augment.RandomCrop(size=(24, 24), padding=None),
            pil_augment.RandomCrop(size=(28, 28), padding=None)]),
        pil_augment.Resize(size=32, interpolation=PIL.Image.BILINEAR),
        transforms.ColorJitter(
            brightness=[0.6, 1.4],
            contrast=[0.6, 1.4],
            saturation=[0.6, 1.4],
            hue=[-0.125, 0.125]),
# Example #2 (snippet separator; score "0" from the original scrape)
        pil_augment.RandomHorizontalFlip(),
        pil_augment.RandomRotation(degrees=10),
        pil_augment.ToTensor(),
    ]),
    target_transform=pil_augment.Compose([
        pil_augment.Resize((256, 256)),
        pil_augment.RandomCrop((224, 224)),
        pil_augment.RandomHorizontalFlip(),
        pil_augment.RandomRotation(degrees=10),
        pil_augment.ToLabel(),
    ]),
    if_is_target=(False, True),
)
# Validation-time wrapper: deterministic 224x224 center crop applied to both
# the image and its label map; only the final conversion differs (tensor vs label).
val_transform = SequentialWrapper(
    img_transform=pil_augment.Compose([
        pil_augment.CenterCrop((224, 224)),
        pil_augment.ToTensor(),
    ]),
    target_transform=pil_augment.Compose([
        pil_augment.CenterCrop((224, 224)),
        pil_augment.ToLabel(),
    ]),
    # (False, True): second element of each sample is treated as the target.
    # NOTE(review): semantics of this flag come from SequentialWrapper — confirm.
    if_is_target=(False, True),
)
# Build the three semi-supervised loaders in one call; the labeled and
# unlabeled streams share the same training-time transform, while validation
# uses the deterministic pipeline built above.
_loaders = data_handler.SemiSupervisedDataLoaders(
    labeled_transform=train_transforms,
    unlabeled_transform=train_transforms,
    val_transform=val_transform,
    group_labeled=True,
    group_unlabeled=False,
    group_val=True,
)
labeled_loader, unlabeled_loader, val_loader = _loaders
model = Model(
# Example #3 (snippet separator; score "0" from the original scrape)
# ============================== public transform interface ================================
stl10_strong_transform = {
    # tf1: random 64x64 crop, resize (interpolation=0, i.e. PIL NEAREST),
    # then convert to a single-channel (grey) tensor.
    "tf1": transforms.Compose([
        pil_augment.RandomCrop(size=(64, 64), padding=None),
        pil_augment.Resize(size=(64, 64), interpolation=0),
        pil_augment.Img2Tensor(include_grey=True, include_rgb=False),
    ]),
    # tf2: same geometry as tf1 plus horizontal flips and photometric jitter —
    # the "strong" augmentation branch.
    "tf2": transforms.Compose([
        pil_augment.RandomCrop(size=(64, 64), padding=None),
        pil_augment.Resize(size=(64, 64), interpolation=0),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.ColorJitter(
            brightness=[0.6, 1.4],
            contrast=[0.6, 1.4],
            saturation=[0.6, 1.4],
            hue=[-0.125, 0.125],
        ),
        pil_augment.Img2Tensor(include_grey=True, include_rgb=False),
    ]),
    # tf3: deterministic center crop — the evaluation-time counterpart.
    "tf3": transforms.Compose([
        pil_augment.CenterCrop(size=(64, 64)),
        pil_augment.Resize(size=(64, 64), interpolation=0),
        pil_augment.Img2Tensor(include_grey=True, include_rgb=False),
    ]),
}
# ==========================================================================================
    def supervised_training(self, use_pretrain: bool = True, lr: float = 1e-3, data_aug: bool = False) -> None:
        """Fine-tune the checkpointed network with supervised cross-entropy-style training.

        Loads the best checkpoint, optionally re-initializes weights, then trains
        head B on a train/val split carved out of ``self.val_loader``'s composed
        datasets, logging metrics to CSV and a PNG curve.

        :param use_pretrain: if True, keep the pretrained trunk and re-initialize
            only ``head_B``; if False, re-initialize the whole network.
        :param lr: Adam learning rate for the fine-tuning phase.
        :param data_aug: if True, swap in the crop/flip augmentation pipelines
            defined below for the train and val datasets.
        """
        # load the best checkpoint
        self.load_checkpoint(
            torch.load(str(Path(self.checkpoint) / self.checkpoint_identifier),
                       map_location=torch.device("cpu")))
        self.model.to(self.device)

        # Local import keeps torchvision out of module scope.
        from torchvision import transforms
        # Training pipeline: center-crop to 20x20, upsample to 32x32 (NEAREST),
        # then random crop with padding + horizontal flip for augmentation.
        # Only used when data_aug=True (see below).
        transform_train = transforms.Compose([
            pil_augment.CenterCrop(size=(20, 20)),
            pil_augment.Resize(size=(32, 32), interpolation=PIL.Image.NEAREST),
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            pil_augment.Img2Tensor()
        ])
        # Validation pipeline: same geometry, no randomness.
        transform_val = transforms.Compose([
            pil_augment.CenterCrop(size=(20, 20)),
            pil_augment.Resize(size=(32, 32), interpolation=PIL.Image.NEAREST),
            pil_augment.Img2Tensor()
        ])

        # KL divergence against one-hot targets acts as the supervised loss.
        self.kl = KL_div(reduce=True)

        def _sup_train_loop(train_loader, epoch):
            # One supervised training epoch over `train_loader`.
            # NOTE: `linear_meters` is resolved lazily via closure — it is
            # defined further down in this method, before these loops run.
            self.model.train()
            train_loader_ = tqdm_(train_loader)
            for batch_num, (image_gt) in enumerate(train_loader_):
                # Each batch item appears to be a sequence of (image, gt) pairs;
                # only the first view is used — TODO(review) confirm loader format.
                image, gt = zip(*image_gt)
                image = image[0].to(self.device)
                gt = gt[0].to(self.device)

                if self.use_sobel:
                    image = self.sobel(image)

                # torchnet returns multiple heads; head 0 is trained here.
                pred = self.model.torchnet(image)[0]
                # NOTE(review): number of classes is hard-coded to 10.
                loss = self.kl(pred, class2one_hot(gt, 10).float())
                self.model.zero_grad()
                loss.backward()
                self.model.step()
                linear_meters["train_loss"].add(loss.item())
                linear_meters["train_acc"].add(pred.max(1)[1], gt)
                report_dict = {
                    "tra_acc": linear_meters["train_acc"].summary()["acc"],
                    "loss": linear_meters["train_loss"].summary()["mean"],
                }
                train_loader_.set_postfix(report_dict)

            # NOTE(review): `report_dict` is unbound if the loader is empty.
            print(f"  Training epoch {epoch}: {nice_dict(report_dict)} ")

        def _sup_eval_loop(val_loader, epoch) -> Tensor:
            # One evaluation pass; returns the accumulated validation accuracy.
            self.model.eval()
            val_loader_ = tqdm_(val_loader)
            for batch_num, (image_gt) in enumerate(val_loader_):
                image, gt = zip(*image_gt)
                image = image[0].to(self.device)
                gt = gt[0].to(self.device)

                if self.use_sobel:
                    image = self.sobel(image)

                pred = self.model.torchnet(image)[0]
                linear_meters["val_acc"].add(pred.max(1)[1], gt)
                report_dict = {
                    "val_acc": linear_meters["val_acc"].summary()["acc"]
                }
                val_loader_.set_postfix(report_dict)
            print(f"Validating epoch {epoch}: {nice_dict(report_dict)} ")
            return linear_meters["val_acc"].summary()["acc"]

            # building training and validation set based on extracted features

        # Split the composed val_loader into a train and a val loader by picking
        # sub-datasets out of the nested ConcatDataset-like structure.
        # NOTE(review): relies on self.val_loader.dataset.datasets[0].datasets
        # holding (train_subset, val_subset) in that order — confirm upstream.
        train_loader = dcp(self.val_loader)
        train_loader.dataset.datasets = (
            train_loader.dataset.datasets[0].datasets[0], )
        val_loader = dcp(self.val_loader)
        val_loader.dataset.datasets = (
            val_loader.dataset.datasets[0].datasets[1], )

        if data_aug:
            train_loader.dataset.datasets[0].transform = transform_train
            val_loader.dataset.datasets[0].transform = transform_val

        # network and optimization
        if not use_pretrain:
            self.model.torchnet.apply(weights_init)
        else:
            self.model.torchnet.head_B.apply(weights_init)
            # wipe out the initialization
        # Fresh optimizer/scheduler for the fine-tuning phase; LR decays x0.2
        # every 50 epochs.
        self.model.optimizer = torch.optim.Adam(
            self.model.torchnet.parameters(), lr=lr)
        self.model.scheduler = torch.optim.lr_scheduler.StepLR(
            self.model.optimizer, step_size=50, gamma=0.2)

        # meters
        meter_config = {
            "train_loss": AverageValueMeter(),
            "train_acc": ConfusionMatrix(self.model.arch_dict["output_k_B"]),
            "val_acc": ConfusionMatrix(self.model.arch_dict["output_k_B"])
        }
        linear_meters = MeterInterface(meter_config)
        drawer = DrawCSV2(
            save_dir=self.save_dir,
            save_name=
            f"supervised_from_checkpoint_{use_pretrain}_data_aug_{data_aug}.png",
            columns_to_draw=[
                "train_loss_mean", "train_acc_acc", "val_acc_acc"
            ])
        for epoch in range(self.max_epoch):
            _sup_train_loop(train_loader, epoch)
            with torch.no_grad():
                _ = _sup_eval_loop(val_loader, epoch)
            # NOTE(review): model.step() is also called per batch in the train
            # loop; this per-epoch call presumably advances the scheduler —
            # confirm Model.step semantics.
            self.model.step()
            linear_meters.step()
            # Persist the full metric history after every epoch.
            linear_meters.summary().to_csv(
                self.save_dir /
                f"supervised_from_checkpoint_{use_pretrain}_data_aug_{data_aug}.csv"
            )
            drawer.draw(linear_meters.summary())
# Example #5 (snippet separator; score "0" from the original scrape)
    transforms.Compose([
        pil_augment.RandomCrop((24, 24), padding=0),
        transforms.ToTensor(),
    ]),
    "tf3":
    transforms.Compose(
        [transforms.CenterCrop((24, 24)),
         transforms.ToTensor()]),
}
mnist_strong_transform = {
    # output shape would be 24*24
    "tf1":
    transforms.Compose([
        pil_augment.RandomChoice(transforms=[
            pil_augment.RandomCrop(size=(20, 20), padding=None),
            pil_augment.CenterCrop(size=(20, 20))
        ]),
        pil_augment.Resize(size=24, interpolation=PIL.Image.BILINEAR),
        transforms.ToTensor()
    ]),
    "tf2":
    transforms.Compose([
        pil_augment.RandomApply(transforms=[
            transforms.RandomRotation(degrees=(-25.0, 25.0),
                                      resample=False,
                                      expand=False)
        ],
                                p=0.5),
        pil_augment.RandomChoice(transforms=[
            pil_augment.RandomCrop(size=(16, 16), padding=None),
            pil_augment.RandomCrop(size=(20, 20), padding=None),