Example #1
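# Trains an MNIST autoencoder; the reconstruction loss (plain MSE or a
# gradient-difference loss) is selected from the parsed config.
# The imports below are assumed from context; `autoencoder`,
# `gradient_difference_loss` and `MNISTTrainer` are project-local helpers
# not shown in this excerpt.
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST
from deepclustering.manager import ConfigManger
from deepclustering.model import Model
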
img_transform = transforms.Compose(
    [
        transforms.ToTensor(),
        # transforms.Normalize((0.5), (0.5))
    ]
)
dataset = MNIST(DATA_PATH, transform=img_transform)
dataloader = DataLoader(dataset, batch_size=128, shuffle=True)

model = Model()
model.torchnet = autoencoder()
model.optimizer = torch.optim.Adam(
    model.torchnet.parameters(), lr=1e-3, weight_decay=1e-5
)

config = ConfigManger().parsed_args
if config["loss"] == "mse":
    criterion = nn.MSELoss()
elif config["loss"] == "gdl":
    criterion = gradient_difference_loss(config["weight"])

trainer = MNISTTrainer(
    model=model,
    train_loader=dataloader,
    val_loader=dataloader,
    criterion=criterion,  # use the criterion selected from the config above
    device="cuda",
    **config["Trainer"]
)
trainer.start_training()
Example #2
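# Tail of a get_trainer helper that maps the "Trainer.name" config entry to a
# trainer class, followed by dataloader and model construction for a CIFAR
# clustering run. `trainer` and `get_dataloader` are project-local modules;
# the typing imports below are assumed:
from typing import Dict, Type, Union
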
def get_trainer(  # opening line restored; the excerpt began mid-signature
    config: Dict[str, Union[float, int, dict]]
) -> Type[trainer.ClusteringGeneralTrainer]:
    assert config.get("Trainer").get("name"), config.get("Trainer").get("name")
    trainer_mapping: Dict[str, Type[trainer.ClusteringGeneralTrainer]] = {
        "imsat":
        trainer.IMSATAbstractTrainer,  # imsat without any regularization
        "imsatvat": trainer.IMSATVATTrainer,
    }
    Trainer = trainer_mapping.get(config.get("Trainer").get("name").lower())
    assert Trainer, config.get("Trainer").get("name")
    return Trainer


DEFAULT_CONFIG = "config_CIFAR.yaml"

merged_config = ConfigManger(DEFAULT_CONFIG_PATH=DEFAULT_CONFIG,
                             verbose=True,
                             integrality_check=True).config
train_loader_A, train_loader_B, val_loader = get_dataloader(
    **merged_config["DataLoader"])

# create model:
model = Model(
    arch_dict=merged_config["Arch"],
    optim_dict=merged_config["Optim"],
    scheduler_dict=merged_config["Scheduler"],
)
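# to_Apex presumably wraps the model with NVIDIA Apex mixed-precision (AMP)
# support; opt_level=None here appears to leave training in full precision.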
model = to_Apex(model, opt_level=None, verbosity=0)

Trainer = get_trainer(merged_config)

clusteringTrainer = Trainer(model=model,
                            # ... remaining constructor arguments are
                            # omitted in this excerpt
                            )
Example #3

        # (this excerpt begins mid-method: `pred` and `adv_images` are
        # computed from `images` in code cut off above)
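        # Tail of a VAT-style training step: the adversarially perturbed
        # batch is passed through the model, the per-subhead losses are
        # averaged, and the negated value is logged as mutual information.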
        adv_pred = self.model(adv_images)
        assert simplex(adv_pred[0])

        batch_loss: List[torch.Tensor] = []  # type: ignore
        for subhead in range(len(pred)):
            _loss, _loss_no_lambda = self.criterion(pred[subhead],
                                                    adv_pred[subhead])
            batch_loss.append(_loss)
        batch_loss: torch.Tensor = sum(batch_loss) / len(batch_loss)  # type: ignore
        self.METERINTERFACE["train_mi"].add(-batch_loss.item())
        return batch_loss


config = ConfigManger(DEFAULT_CONFIG_PATH="./config.yml", verbose=False).config

datainterface = MNISTClusteringDatasetInterface(split_partitions=["train"],
                                                **config["DataLoader"])
datainterface.drop_last = True
train_loader = datainterface.ParallelDataLoader(
    default_mnist_img_transform["tf1"],
    default_mnist_img_transform["tf2"],
    default_mnist_img_transform["tf2"],
    default_mnist_img_transform["tf2"],
    default_mnist_img_transform["tf2"],
    default_mnist_img_transform["tf2"],
)
datainterface.split_partitions = ["val"]
datainterface.drop_last = False
val_loader = datainterface.ParallelDataLoader(
    # ... validation transforms omitted in this excerpt
)

Example #4

def get_trainer(  # opening line restored; the source lost the header and the start of the signature here
    config: Dict[str, Union[float, int, dict]]
) -> Type[trainer.ClusteringGeneralTrainer]:
    assert config.get("Trainer").get("name"), config.get("Trainer").get("name")
    trainer_mapping: Dict[str, Type[trainer.ClusteringGeneralTrainer]] = {
        "iicgeo": trainer.IICGeoTrainer,  # the basic iic
        "imsatvat": trainer.IMSATVATTrainer,  # imsat with vat
    }
    Trainer = trainer_mapping.get(config.get("Trainer").get("name").lower())
    assert Trainer, config.get("Trainer").get("name")
    return Trainer


if __name__ == '__main__':
    DEFAULT_CONFIG = "config_MNIST.yaml"
    merged_config = ConfigManger(DEFAULT_CONFIG_PATH=DEFAULT_CONFIG,
                                 verbose=False,
                                 integrality_check=True).config
    pprint(merged_config)
    # for reproducibility
    if merged_config.get("Seed"):
        fix_all_seed(merged_config.get("Seed"))

    # get train loaders and validation loader
    train_loader_A, train_loader_B, val_loader = get_dataloader(merged_config)

    # create model:
    model = Model(
        arch_dict=merged_config["Arch"],
        optim_dict=merged_config["Optim"],
        scheduler_dict=merged_config["Scheduler"],
    )
Example #5
##############################
#   This experiment is to verify whether VAT can further help IIC.
#   This experiment can be long and a pretrained checkpoint can be used to reduce the time.
##############################
from pathlib import Path

from deepclustering.dataset.classification import (
    Cifar10ClusteringDatasetInterface,
    default_cifar10_img_transform,
)
from deepclustering.manager import ConfigManger
from deepclustering.model import Model
from playground.IIC_VAT.VATIICTrainer import IMSATIICTrainer

DEFAULT_CONFIG = str(Path(__file__).parent / "config.yaml")

config = ConfigManger(DEFAULT_CONFIG_PATH=DEFAULT_CONFIG, verbose=True).config
# create model:
model = Model(
    arch_dict=config["Arch"],
    optim_dict=config["Optim"],
    scheduler_dict=config["Scheduler"],
)

train_loader_A = Cifar10ClusteringDatasetInterface(
    **config["DataLoader"]).ParallelDataLoader(
        default_cifar10_img_transform["tf1"],
        default_cifar10_img_transform["tf2"],
        default_cifar10_img_transform["tf2"],
        default_cifar10_img_transform["tf2"],
        default_cifar10_img_transform["tf2"],
    )
Example #6
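# Semi-supervised ACDC segmentation setup: ACDCSemiInterface splits the data
# into labeled and unlabeled pools, and SequentialWrapper applies the spatial
# augmentations below to the images (the matching target transform is cut off
# in this excerpt).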
from deepclustering.augment import SequentialWrapper, pil_augment
from deepclustering.dataset.segmentation.acdc_dataset import ACDCSemiInterface
from deepclustering.manager import ConfigManger
from deepclustering.model import Model
from deepclustering.utils import fix_all_seed
from playground.PaNN.trainer import SemiSegTrainer

config = ConfigManger(DEFAULT_CONFIG_PATH="config.yaml",
                      integrality_check=True,
                      verbose=True).merged_config

fix_all_seed(config.get("Seed", 0))

data_handler = ACDCSemiInterface(labeled_data_ratio=0.99,
                                 unlabeled_data_ratio=0.01)
data_handler.compile_dataloader_params(
    labeled_batch_size=4,
    unlabeled_batch_size=8,
    val_batch_size=1,
    shuffle=True,
    num_workers=2,
)
# transformations
train_transforms = SequentialWrapper(
    img_transform=pil_augment.Compose([
        pil_augment.Resize((256, 256)),
        pil_augment.RandomCrop((224, 224)),
        pil_augment.RandomHorizontalFlip(),
        pil_augment.RandomRotation(degrees=10),
        pil_augment.ToTensor(),
    ]),
    # ... target transform and remaining arguments omitted in this excerpt
)
Example #7
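# Toy semi-supervised MNIST experiment: 10 labeled samples plus a deliberately
# class-imbalanced unlabeled pool. `transforms` is assumed to be
# torchvision.transforms (the import is not shown in the source excerpt):
from torchvision import transforms
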
try:
    from .trainer import (
        SemiTrainer,
        SemiEntropyTrainer,
        SemiPrimalDualTrainer,
        SemiWeightedIICTrainer,
        SemiUDATrainer,
    )
except ImportError:  # assumed fallback, mirroring the dataset import below
    from toy_example.trainer import (SemiTrainer, SemiEntropyTrainer, SemiPrimalDualTrainer, SemiWeightedIICTrainer, SemiUDATrainer)
try:
    from .dataset import get_mnist_dataloaders
except ImportError:
    from toy_example.dataset import get_mnist_dataloaders
from deepclustering.optim import RAdam
from deepclustering.utils import fix_all_seed
from deepclustering.manager import ConfigManger
from torch.optim.lr_scheduler import MultiStepLR

config = ConfigManger("./config.yaml", integrality_check=False).config

fix_all_seed(0)
## dataloader part
unlabeled_class_sample_nums = {0: 10000, 1: 1000, 2: 2000, 3: 3000, 4: 4000}
dataloader_params = {
    "batch_size": 64,
    "num_workers": 1,
    "drop_last": True,
    "pin_memory": True,
}
train_transform = transforms.Compose([transforms.ToTensor()])
val_transform = transforms.Compose([transforms.ToTensor()])
labeled_loader, unlabeled_loader, val_loader = get_mnist_dataloaders(
    labeled_sample_num=10,
    unlabeled_class_sample_nums=unlabeled_class_sample_nums,
    # ... remaining arguments (presumably the transforms and dataloader
    # params defined above) omitted in this excerpt
)
Example #8
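# Semi-supervised VAT / AdaNet training: a custom architecture is registered
# with the framework, then the model, the three dataloaders, and a custom
# scheduler are built from the merged config.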
from deepclustering.manager import ConfigManger
from deepclustering.model import Model, to_Apex

from arch import _register_arch
from data import get_dataloader
from scheduler import CustomScheduler
from trainer import AdaNetTrainer, VAT_Trainer

_ = _register_arch  # to enable the network registration

DEFAULT_CONFIG_PATH = "config.yaml"
config = ConfigManger(DEFAULT_CONFIG_PATH,
                      verbose=True,
                      integrality_check=False).config
model = Model(config.get("Arch"), config.get("Optim"), config.get("Scheduler"))
model = to_Apex(model, opt_level=None)

label_loader, unlabel_loader, val_loader = get_dataloader(
    config["DataLoader"].get("name"),
    config["DataLoader"].get("aug", False),
    config.get("DataLoader"),
)
scheduler = CustomScheduler(max_epoch=config["Trainer"]["max_epoch"])
assert config["Trainer"].get("name") in ("vat", "ada")

Trainer = VAT_Trainer if config["Trainer"]["name"].lower() == "vat" else AdaNetTrainer
trainer = Trainer(model=model,
                  labeled_loader=label_loader,
                  unlabeled_loader=unlabel_loader,
                  val_loader=val_loader,
                  # ... remaining arguments omitted in this excerpt
                  )

Example #9

def get_config(checkpoint_path):
    # (the lost example header, this signature, and the "cifar" branch are
    # reconstructed from the parallel "svhn" branch and the call site below)
    DEFAULT_CONFIG = None
    if "cifar" in checkpoint_path.lower():
        DEFAULT_CONFIG = os.path.join(checkpoint_path, "config_CIFAR.yaml")
    if "svhn" in checkpoint_path.lower():
        DEFAULT_CONFIG = os.path.join(checkpoint_path, "config_SVHN.yaml")
    assert DEFAULT_CONFIG
    assert Path(checkpoint_path, "config.yaml").exists()
    copyfile(os.path.join(checkpoint_path, "config.yaml"), DEFAULT_CONFIG)
    assert Path(DEFAULT_CONFIG).exists() and Path(DEFAULT_CONFIG).is_file()
    config = yaml_load(DEFAULT_CONFIG)
    return config, DEFAULT_CONFIG


if __name__ == '__main__':
    # interface: python analyze_main checkpoint_path=runs/07_15_benchmark/strong/cifar_2/iicgeo
    DEFAULT_CONFIG = None
    checkpoint_path = ConfigManger(
        DEFAULT_CONFIG_PATH=DEFAULT_CONFIG,
        verbose=True,
        integrality_check=True).config["checkpoint_path"]
    merged_config, DEFAULT_CONFIG = get_config(checkpoint_path)
    # correct save_dir and checkpoint_dir
    merged_config["Trainer"]["save_dir"] = checkpoint_path
    merged_config["Trainer"]["checkpoint_path"] = checkpoint_path
    merged_config["Trainer"]["max_epoch"] = 300

    # for reproducibility
    fix_all_seed(merged_config.get("Seed", 0))

    # get train loaders and validation loader
    train_loader_A, train_loader_B, val_loader = get_dataloader(
        merged_config, DEFAULT_CONFIG)

    # create model: