Example #1
    def __init__(
        self,
        count: int = 1000,
        max_object_per_img: int = 4,
        max_object_scale: float = 0.25,
        height: int = 256,
        width: int = 256,
        transform: Callable = None,
        target_transform: Callable = None,
        seed: int = 0,
    ) -> None:
        """
        Interface for ShapesDataset
        :param count: how many samples to generate
        :param max_object_per_img: how many objects to show in an image
        :param height: image height
        :param width: image height
        """
        super().__init__()
        fix_all_seed(seed)
        assert (
            max_object_per_img >= 1
        ), f"max_object_per_img should be at least 1, given {max_object_per_img}."
        self.max_object_per_img = max_object_per_img
        assert (
            0 < max_object_scale <= 0.75
        ), f"max_object_scale should be in (0, 0.75], given {max_object_scale}."
        self.max_object_scale = max_object_scale
        self.height = height
        self.width = width
        self._image_ids = []
        self.image_info = []
        # Background is always the first class
        self.class_info = [{"source": "", "id": 0, "name": "BG"}]
        self.source_class_ids = {}
        # self.transform: Callable = transform
        # self.target_transform: Callable = target_transform
        self.squential_transform = SequentialWrapper(
            img_transform=transform,
            target_transform=target_transform,
            if_is_target=[False, True, True],
        )
        for _type in self.types:
            self.add_class("shapes", self.map_id[_type], _type)

        # Add images
        # Generate random specifications of images (i.e. color and
        # list of shapes sizes and locations). This is more compact than
        # actual images. Images are generated on the fly in load_image().
        for i in range(count):
            bg_color, shapes = self.random_image(height, width)
            self.add_image(
                "shapes",
                image_id=i,
                path=None,
                width=width,
                height=height,
                bg_color=bg_color,
                shapes=shapes,
            )

        self.prepare()
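The inline comment above explains the design: only image specifications (background colour plus a list of shapes) are stored, and the pixels are synthesized lazily in load_image(). A rough sketch of what such a method could look like, assuming the image_info keys written by add_image() above and a hypothetical draw_shape helper (this is not the library's actual implementation):

import numpy as np

    def load_image(self, image_id: int) -> np.ndarray:
        """Sketch: synthesize the RGB image from its stored specification."""
        info = self.image_info[image_id]                   # dict written by add_image() above
        bg_color = np.array(info["bg_color"], dtype=np.uint8)
        image = np.ones((info["height"], info["width"], 3), dtype=np.uint8) * bg_color
        for shape, color, dims in info["shapes"]:          # assumed (shape, color, dims) triples
            image = self.draw_shape(image, shape, dims, color)  # hypothetical drawing helper
        return image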
        "imsatvat": trainer.IMSATVATTrainer,  # imsat with vat
    }
    Trainer = trainer_mapping.get(config.get("Trainer").get("name").lower())
    assert Trainer, f"Unknown trainer name: {config.get('Trainer').get('name')}"
    return Trainer


if __name__ == '__main__':
    DEFAULT_CONFIG = "config_MNIST.yaml"
    merged_config = ConfigManger(DEFAULT_CONFIG_PATH=DEFAULT_CONFIG,
                                 verbose=False,
                                 integrality_check=True).config
    pprint(merged_config)
    # for reproducibility
    if merged_config.get("Seed"):
        fix_all_seed(merged_config.get("Seed"))

    # get train loaders and validation loader
    train_loader_A, train_loader_B, val_loader = get_dataloader(merged_config)

    # create model:
    model = Model(
        arch_dict=merged_config["Arch"],
        optim_dict=merged_config["Optim"],
        scheduler_dict=merged_config["Scheduler"],
    )
    # wrap the model for automatic mixed precision (Apex) training, if enabled
    model = to_Apex(model, opt_level=None, verbosity=0)

    # get specific trainer
    Trainer = get_trainer(merged_config)
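This script, like the others below, calls fix_all_seed before building loaders and models. A helper of this kind typically pins every RNG involved; a minimal sketch of such a function (not necessarily identical to deepclustering.utils.fix_all_seed):

import random

import numpy as np
import torch


def fix_all_seed(seed: int) -> None:
    """Pin the Python, NumPy and PyTorch RNGs for reproducible runs."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)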
Example #3
from torchvision.transforms import Compose, RandomCrop, ToTensor

from deepclustering.manager import ConfigManger
from deepclustering.model import Model
from deepclustering.utils import fix_all_seed
from playground.IMSAT.IMSATTrainer import IMSATTrainer
from playground.IMSAT.mnist_helper import MNISTClusteringDatasetInterface

fix_all_seed(3)

DEFAULT_CONFIG = "./IMSAT.yaml"

merged_config = ConfigManger(
    DEFAULT_CONFIG_PATH=DEFAULT_CONFIG, verbose=True, integrality_check=True
).config

tf1 = Compose([ToTensor()])
tf2 = Compose([RandomCrop((28, 28), padding=2), ToTensor()])

# create model:
model = Model(
    arch_dict=merged_config["Arch"],
    optim_dict=merged_config["Optim"],
    scheduler_dict=merged_config["Scheduler"],
)

train_loader_A = MNISTClusteringDatasetInterface(
    **merged_config["DataLoader"]
).ParallelDataLoader(tf1, tf2, tf2, tf2, tf2)

val_loader = MNISTClusteringDatasetInterface(
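In the snippet above, tf1 leaves each digit untouched while tf2 yields a randomly cropped view; ParallelDataLoader is given both, presumably so every batch arrives in several augmented views of the same images for IMSAT's consistency term. A standalone sketch of the two transforms on a single image (the blank PIL image is just a stand-in for an MNIST digit):

import torch
from PIL import Image
from torchvision.transforms import Compose, RandomCrop, ToTensor

tf1 = Compose([ToTensor()])
tf2 = Compose([RandomCrop((28, 28), padding=2), ToTensor()])

img = Image.new("L", (28, 28))            # stand-in for one 28x28 MNIST digit
original, perturbed = tf1(img), tf2(img)
assert original.shape == perturbed.shape == torch.Size([1, 28, 28])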
Example #4
from deepclustering.augment import SequentialWrapper, pil_augment
from deepclustering.dataset.segmentation.acdc_dataset import ACDCSemiInterface
from deepclustering.manager import ConfigManger
from deepclustering.model import Model
from deepclustering.utils import fix_all_seed
from playground.PaNN.trainer import SemiSegTrainer

config = ConfigManger(DEFAULT_CONFIG_PATH="config.yaml",
                      integrality_check=True,
                      verbose=True).merged_config

fix_all_seed(config.get("Seed", 0))

data_handler = ACDCSemiInterface(labeled_data_ratio=0.99,
                                 unlabeled_data_ratio=0.01)
data_handler.compile_dataloader_params(
    labeled_batch_size=4,
    unlabeled_batch_size=8,
    val_batch_size=1,
    shuffle=True,
    num_workers=2,
)
# transformations
train_transforms = SequentialWrapper(
    img_transform=pil_augment.Compose([
        pil_augment.Resize((256, 256)),
        pil_augment.RandomCrop((224, 224)),
        pil_augment.RandomHorizontalFlip(),
        pil_augment.RandomRotation(degrees=10),
        pil_augment.ToTensor(),
    ]),
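As in Example #1, SequentialWrapper pairs an img_transform with a target_transform (and if_is_target flags) so that random geometric augmentations stay aligned between an image and its segmentation targets. A generic sketch of that idea, independent of the library's actual API:

import random

from PIL import Image
from torchvision.transforms import functional as F


def paired_random_crop(img: Image.Image, target: Image.Image, size=(224, 224)):
    """Crop the image and its mask at the same location so they stay aligned."""
    top = random.randint(0, img.height - size[0])
    left = random.randint(0, img.width - size[1])
    return (F.crop(img, top, left, size[0], size[1]),
            F.crop(target, top, left, size[0], size[1]))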
Example #5
        SemiPrimalDualTrainer,
        SemiWeightedIICTrainer,
        SemiUDATrainer,
    )
try:
    from .dataset import get_mnist_dataloaders
except ImportError:
    from toy_example.dataset import get_mnist_dataloaders
from deepclustering.optim import RAdam
from deepclustering.utils import fix_all_seed
from deepclustering.manager import ConfigManger
from torch.optim.lr_scheduler import MultiStepLR
from torchvision import transforms  # needed for the Compose/ToTensor calls below

config = ConfigManger("./config.yaml", integrality_check=False).config

fix_all_seed(0)
## dataloader part
unlabeled_class_sample_nums = {0: 10000, 1: 1000, 2: 2000, 3: 3000, 4: 4000}
dataloader_params = {
    "batch_size": 64,
    "num_workers": 1,
    "drop_last": True,
    "pin_memory": True,
}
train_transform = transforms.Compose([transforms.ToTensor()])
val_transform = transforms.Compose([transforms.ToTensor()])
labeled_loader, unlabeled_loader, val_loader = get_mnist_dataloaders(
    labeled_sample_num=10,
    unlabeled_class_sample_nums=unlabeled_class_sample_nums,
    train_transform=train_transform,
    val_transform=val_transform,
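unlabeled_class_sample_nums above requests a deliberately imbalanced unlabeled pool (10000 zeros, 1000 ones, and so on). A minimal sketch of how such per-class subsampling could be built on top of torchvision's MNIST (the helper name is made up here; it is not the toy_example.dataset implementation):

import torch
from torch.utils.data import Subset
from torchvision.datasets import MNIST


def subsample_per_class(dataset: MNIST, class_sample_nums: dict) -> Subset:
    """Keep at most class_sample_nums[c] examples of every requested class c."""
    keep = []
    for cls, num in class_sample_nums.items():
        idx = torch.nonzero(dataset.targets == cls, as_tuple=False).flatten()
        keep.extend(idx[:num].tolist())
    return Subset(dataset, keep)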
Example #6
    assert simplex(x_out), "x_out not normalized."
    assert simplex(x_tf_out), "x_tf_out not normalized."

    bn, k = x_out.shape
    assert x_tf_out.size(0) == bn and x_tf_out.size(1) == k

    p_i_j = x_out.unsqueeze(2) * x_tf_out.unsqueeze(1)  # bn, k, k
    p_i_j = p_i_j.sum(dim=0)  # k, k aggregated over one batch
    p_i_j = (p_i_j + p_i_j.t()) / 2.0  # symmetric
    p_i_j /= p_i_j.sum()  # normalise

    return p_i_j
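# The IIC objective maximizes the mutual information of this joint distribution,
#     I(z, z') = sum_{c, c'} p_cc' * (log p_cc' - log p_c - log p_c'),
# where the marginals p_c and p_c' are the row and column sums of p_i_j;
# IIDLoss in the demo below returns the corresponding loss to minimize.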


if __name__ == '__main__':
    fix_all_seed(0)

    logit1 = torch.randn(1000, 3, requires_grad=True)
    logit2 = torch.randn(1000, 3, requires_grad=True)

    optim = torch.optim.Adam((logit1, logit2))
    # criterion = CustomizedIICLossDual()
    # criterion = CustomizedIICLoss()
    criterion = IIDLoss()

    for i in range(1000000000):
        optim.zero_grad()
        p1 = torch.softmax(logit1, 1)
        p2 = torch.softmax(logit2, 1)
        loss2, _ = criterion(p1, p2)
        loss2.backward()