Example #1
0
def _main() -> None:
    """Smoke-test entry point: log the resolved data directory paths."""
    log_utils.init_root_logger()

    # Log each configured directory so the resolved paths can be eyeballed.
    for label, path in (("data", get_data()), ("raw", get_raw())):
        logger.info(f"{label} directory: {path}")
Example #2
0
def main() -> None:
    """Download the CelebA dataset and build train/valid file lists.

    Downloads the ``img_align_celeba.zip`` archive into the raw data
    directory (skipped when the archive already exists), extracts it,
    then splits the jpeg paths 80/20 into train/valid CSV file lists
    written to the interim directory.
    """
    log_utils.init_root_logger()

    logger.info("=== download and extract files.")
    filepath = directories.get_raw().joinpath("img_align_celeba.zip")
    # Only download when the archive is not already cached locally.
    if not filepath.exists():
        download(filepath)

    logger.info("=== unzip.")
    extractpath = directories.get_raw()
    with zipfile.ZipFile(str(filepath)) as z:
        z.extractall(str(extractpath))

    logger.info("=== create train and valid file list.")
    # Store paths relative to the extraction root so the CSV lists stay
    # valid if the raw directory is relocated.
    filelist = sorted(
        p.relative_to(extractpath) for p in extractpath.glob("**/*.jpg"))
    train_num = int(len(filelist) * 0.8)
    train_list = filelist[:train_num]
    valid_list = filelist[train_num:]

    # newline="" is required by the csv module; without it every row gets
    # an extra blank line on Windows.
    train_path = directories.get_interim().joinpath("celeba_train.csv")
    with open(str(train_path), "w", newline="") as ft:
        writer = csv.writer(ft)
        writer.writerows([p] for p in train_list)

    valid_path = directories.get_interim().joinpath("celeba_valid.csv")
    with open(str(valid_path), "w", newline="") as fv:
        writer = csv.writer(fv)
        writer.writerows([p] for p in valid_list)
Example #3
0
def _main() -> None:
    """Entry point: parse the config from CLI arguments and run the download."""
    log_utils.init_root_logger()

    config = utils.load_config_from_input_args(lambda kwargs: Config(**kwargs))
    if config is None:
        # Bail out early when the arguments could not be parsed.
        logger.error("config error.")
        return

    saved_path = download(config)
    logger.info(f"download path: {saved_path}")
Example #4
0
    Returns:
        torchvision.transforms.Compose: データ変換コンポーネント
    """
    transforms = tv_transforms.Compose([
        tv_transforms.RandomHorizontalFlip(),
        tv_transforms.CenterCrop(148),
        tv_transforms.Resize(image_size),
        tv_transforms.ToTensor(),
    ])

    return transforms


def train(config: Config):
    """Build the train/valid datasets for the training run.

    Args:
        config: settings holding the image size and file-list paths.
    """
    image_transforms = get_transforms(config.image_size)
    dataset_train = dataset.ImageDataset(
        config.train_list_path, image_transforms, dataset.Mode.TRAIN)
    dataset_valid = dataset.ImageDataset(
        config.valid_list_path, image_transforms, dataset.Mode.VALID)


if __name__ == "__main__":
    try:
        lu.init_root_logger()
        # NOTE(review): train() requires a `config` argument (see its
        # signature above) but is called here without one, so this raises
        # TypeError which is then logged below — confirm how Config is
        # meant to be constructed for this entry point. Also `lu` vs
        # `log_utils` naming is inconsistent with the other scripts.
        train()
    except Exception as e:
        # Top-level boundary: log the error and full traceback instead of
        # letting the script die with an unlogged exception.
        logger.error(e)
        logger.error(traceback.format_exc())
Example #5
0
def _main() -> None:
    """Print a layer-by-layer summary of the TransferVAE network."""
    log_utils.init_root_logger()

    # Build the model on GPU and dump its structure for a 3x64x64 input.
    model = TransferVAE(3, 3, (64, 64)).to("cuda")
    ts.summary(model, input_size=(3, 64, 64))
Example #6
0
File: train.py  Project: samsgood0310/til
def main() -> None:
    """Entry point: train a VAE on the CelebA face images with PyTorch Lightning."""
    log_utils.init_root_logger()

    # Collect every CelebA jpeg and split 80/20 into train/valid.
    # (The commented lines are an alternative MVTec-AD "hazelnut" setup.)
    # train_dir = directories.get_raw().joinpath("hazelnut/train")
    # filelist = sorted(list(train_dir.glob("**/*.png")))
    train_dir = directories.get_raw().joinpath("img_align_celeba")
    filelist = sorted(list(train_dir.glob("**/*.jpg")))
    num_train = int(len(filelist) * 0.8)

    image_size = (64, 64)
    # Augment + crop the face region before downscaling to 64x64;
    # presumably CenterCrop(148) targets the aligned-CelebA face area —
    # TODO confirm if reused with another dataset.
    transforms = tv_transforms.Compose([
        # tv_transforms.Grayscale(num_output_channels=1),
        tv_transforms.RandomHorizontalFlip(),
        tv_transforms.CenterCrop(148),
        tv_transforms.Resize(image_size),
        tv_transforms.ToTensor(),
        # tv_transforms.Lambda(lambda x: 2.0 * x - 1.0),
    ])
    # NOTE(review): the mvtec_ad Dataset class is reused here for CelebA
    # file lists — verify it accepts arbitrary image path lists.
    dataset_train = mvtec_ad.Dataset(filelist[:num_train], transforms,
                                     mvtec_ad.Mode.TRAIN)
    dataset_valid = mvtec_ad.Dataset(filelist[num_train:], transforms,
                                     mvtec_ad.Mode.VALID)

    batch_size = 144
    num_workers = 4
    dataloader_train = torch_data.DataLoader(
        dataset_train,
        batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=True,
        worker_init_fn=worker_init_random,
    )
    dataloader_valid = torch_data.DataLoader(
        dataset_valid,
        batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=True,
        worker_init_fn=worker_init_random,
    )

    # Fix all RNG seeds before building the model for reproducibility.
    random_seed = 42
    pl.seed_everything(random_seed)
    in_channels = 3
    out_channels = 3
    # Hyper-parameters recorded alongside the run by the trainer module.
    hparams = {
        "batch_size": batch_size,
        "num_workers": num_workers,
        "in_channels": in_channels,
        "out_channels": out_channels,
        "random_seed": random_seed,
    }
    # Alternative architectures kept for quick experiment switching.
    # network = cnn_ae.SimpleCBR(in_channels, out_channels)
    # network = vanila_vae.VAE(in_channels, out_channels, image_size)
    network = transfer_vae.TransferVAE(in_channels, out_channels, image_size)
    # generator = vanila_gan.Generator(62)
    # discriminator = vanila_gan.Discriminator(in_channels, 1)
    # model = trainer.AETrainer(network, hparams)
    model = trainer.VAETrainer(network, hparams)
    # model = trainer_gan.GANTrainer(generator, discriminator, hparams)
    model.set_dataloader(dataloader_train, dataloader_valid)

    # NOTE(review): log_dir says "vanila_gan" but the active model above
    # is a VAE — the cache/checkpoint directory name is misleading.
    log_dir = "vanila_gan"
    save_top_k = 5
    early_stop = True
    min_epochs = 30
    max_epochs = 10000
    progress_bar_refresh_rate = 1
    cache_dir = directories.get_interim().joinpath(log_dir)
    profiler = True  # if use detail profiler, pl_profiler.AdvancedProfiler()
    # Keep the 5 best checkpoints by validation loss, evaluated every epoch.
    # NOTE(review): `filepath=`/`period=` and `early_stop_callback=` below
    # are legacy pytorch-lightning (<1.0) APIs — pin the PL version or
    # migrate before upgrading.
    model_checkpoint = pl_callbacks.ModelCheckpoint(
        filepath=str(cache_dir),
        monitor="val_loss",
        save_top_k=save_top_k,
        save_weights_only=False,
        mode="min",
        period=1,
    )
    pl_trainer = pl.Trainer(
        early_stop_callback=early_stop,
        default_root_dir=str(cache_dir),
        fast_dev_run=False,
        min_epochs=min_epochs,
        max_epochs=max_epochs,
        gpus=[0] if torch_cuda.is_available() else None,
        progress_bar_refresh_rate=progress_bar_refresh_rate,
        profiler=profiler,
        checkpoint_callback=model_checkpoint,
    )
    pl_trainer.fit(model)