Example #1
import logging

import numpy as np
import torch
import torchvision
from omegaconf import OmegaConf
from torch.utils.data import DataLoader

# Repository-local helpers used below (their module paths are not shown in this example):
#   check_hydra_conf, init_ddp, create_simclr_data_augmentation, SupervisedModel, learning


def main(cfg: OmegaConf):
    logger = logging.getLogger(__name__)
    logger.setLevel(logging.INFO)
    stream_handler = logging.StreamHandler()
    stream_handler.setLevel(logging.INFO)
    stream_handler.terminator = ""  # suppress the handler's automatic trailing newline
    logger.addHandler(stream_handler)

    check_hydra_conf(cfg)
    init_ddp(cfg)
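    # check_hydra_conf and init_ddp are repository helpers: presumably they validate
    # the Hydra config and initialise the torch.distributed process group.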

    # fix seed
    seed = cfg["parameter"]["seed"]
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    rank = cfg["distributed"]["local_rank"]
    logger.info("Using {}".format(rank))

    root = "~/pytorch_datasets"
    if cfg["experiment"]["name"].lower() == "cifar10":
        transform = create_simclr_data_augmentation(
            cfg["experiment"]["strength"], size=32)
        training_dataset = torchvision.datasets.CIFAR10(root=root,
                                                        train=True,
                                                        download=True,
                                                        transform=transform)
        validation_dataset = torchvision.datasets.CIFAR10(
            root=root,
            train=False,
            download=True,
            transform=torchvision.transforms.Compose([
                torchvision.transforms.ToTensor(),
            ]))
        num_classes = 10
    elif cfg["experiment"]["name"].lower() == "cifar100":
        transform = create_simclr_data_augmentation(
            cfg["experiment"]["strength"], size=32)
        training_dataset = torchvision.datasets.CIFAR100(root=root,
                                                         train=True,
                                                         download=True,
                                                         transform=transform)
        validation_dataset = torchvision.datasets.CIFAR100(
            root=root,
            train=False,
            download=True,
            transform=torchvision.transforms.Compose([
                torchvision.transforms.ToTensor(),
            ]))
        num_classes = 100
    else:
        # unsupported dataset name: fail loudly and show the valid options
        assert cfg["experiment"]["name"].lower() in {"cifar10", "cifar100"}

    sampler = torch.utils.data.distributed.DistributedSampler(training_dataset,
                                                              shuffle=True)
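    # Each process draws a distinct shard of the training set; learning() is
    # presumably responsible for calling sampler.set_epoch() once per epoch.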
    training_data_loader = DataLoader(
        dataset=training_dataset,
        sampler=sampler,
        num_workers=cfg["parameter"]["num_workers"],
        batch_size=cfg["experiment"]["batches"],
        pin_memory=True,
        drop_last=True,
    )

    validation_sampler = torch.utils.data.distributed.DistributedSampler(
        validation_dataset, shuffle=False)
    validation_data_loader = DataLoader(
        dataset=validation_dataset,
        sampler=validation_sampler,
        num_workers=cfg["parameter"]["num_workers"],
        batch_size=cfg["experiment"]["batches"],
        pin_memory=True,
        drop_last=False,
    )

    model = SupervisedModel(base_cnn=cfg["experiment"]["base_cnn"],
                            num_classes=num_classes)
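    # Convert BatchNorm layers to SyncBatchNorm so batch statistics are synchronised
    # across processes, then move the model to this rank's GPU and wrap it in DDP.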
    model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
    model = model.to(rank)
    model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[rank])

    learning(cfg, training_data_loader, validation_data_loader, model)
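
The function above is only the body of the entry point. In a Hydra project it is typically decorated with @hydra.main and started with one process per GPU; the sketch below shows one common wiring. The config location and the torchrun launch line are assumptions rather than part of this example, and how local_rank ends up in cfg["distributed"] depends on the repository's init_ddp and config.

# Hypothetical wiring for the entry point above (config location and launch
# command are assumptions, not taken from this example's repository).
import hydra
from omegaconf import DictConfig


@hydra.main(config_path="conf", config_name="config")
def main(cfg: DictConfig):
    ...  # training body shown above


if __name__ == "__main__":
    main()

# Launched with one process per GPU, for example:
#   torchrun --nproc_per_node=4 main.py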