def main():
    """Train a TopologyVAE on the (optionally binarized/augmented) topology dataset.

    Parses CLI arguments, generates the dataset files on first run, builds the
    datamodule and model, fits with PyTorch Lightning, and copies the best
    checkpoint to ``best.ckpt`` next to where it was saved.
    """
    # Create arg parser; each component registers its own argument group.
    parser = argparse.ArgumentParser()
    parser = TopologyVAE.add_model_specific_args(parser)
    parser = WeightedNumpyDataset.add_model_specific_args(parser)
    parser = utils.DataWeighter.add_weight_args(parser)
    utils.add_default_trainer_args(parser, default_root="")

    parser.add_argument(
        "--augment_dataset",
        action='store_true',
        help="Use data augmentation or not",
    )
    parser.add_argument(
        "--use_binary_data",
        action='store_true',
        help="Binarize images in the dataset",
    )

    # Parse arguments; the result directory name encodes the hyperparameters.
    hparams = parser.parse_args()
    hparams.root_dir = topology_get_path(
        k=hparams.rank_weight_k,
        n_max_epochs=hparams.max_epochs,
        predict_target=hparams.predict_target,
        hdims=hparams.target_predictor_hdims,
        metric_loss=hparams.metric_loss,
        metric_loss_kw=hparams.metric_loss_kw,
        beta_target_pred_loss=hparams.beta_target_pred_loss,
        beta_metric_loss=hparams.beta_metric_loss,
        latent_dim=hparams.latent_dim,
        beta_final=hparams.beta_final,
        use_binary_data=hparams.use_binary_data)
    print_flush(' '.join(sys.argv[1:]))
    print_flush(hparams.root_dir)
    pl.seed_everything(hparams.seed)

    # Create data: generate the .npy/.npz files on first use, then pick the
    # binary or continuous dataset path accordingly.
    if hparams.use_binary_data:
        if not os.path.exists(
                os.path.join(get_data_root(), 'topology_data/target_bin.npy')):
            gen_binary_dataset_from_all_files(get_data_root())
        hparams.dataset_path = os.path.join(ROOT_PROJECT,
                                            get_topology_binary_dataset_path())
    else:
        if not os.path.exists(
                os.path.join(get_data_root(), 'topology_data/target.npy')):
            gen_dataset_from_all_files(get_data_root())
        hparams.dataset_path = os.path.join(ROOT_PROJECT,
                                            get_topology_dataset_path())

    # Optional image augmentation pipeline (rotations + vertical flip).
    if hparams.augment_dataset:
        aug = transforms.Compose([
            transforms.RandomRotation(45),
            transforms.RandomRotation(90),
            transforms.RandomRotation(180),
            transforms.RandomVerticalFlip(0.5)
        ])
    else:
        aug = None
    datamodule = WeightedNumpyDataset(hparams,
                                      utils.DataWeighter(hparams),
                                      transform=aug)

    # Load model
    model = TopologyVAE(hparams)

    # Save a checkpoint every ~10% of training, keep all of them, and track
    # the best validation loss.
    checkpoint_callback = pl.callbacks.ModelCheckpoint(period=max(
        1, hparams.max_epochs // 10),
                                                       monitor="loss/val",
                                                       save_top_k=-1,
                                                       save_last=True,
                                                       mode='min')

    if hparams.load_from_checkpoint is not None:
        # Resume: restore weights and trainer state from the checkpoint.
        model = TopologyVAE.load_from_checkpoint(hparams.load_from_checkpoint)
        utils.update_hparams(hparams, model)
        trainer = pl.Trainer(
            # NOTE: `is not None` (not truthiness) so that GPU index 0 is
            # honored instead of silently falling back to CPU.
            gpus=[hparams.cuda] if hparams.cuda is not None else 0,
            default_root_dir=hparams.root_dir,
            max_epochs=hparams.max_epochs,
            callbacks=[
                checkpoint_callback,
                LearningRateMonitor(logging_interval='step')
            ],
            resume_from_checkpoint=hparams.load_from_checkpoint)

        print('Load from checkpoint')
    else:
        # Main trainer
        trainer = pl.Trainer(
            gpus=[hparams.cuda] if hparams.cuda is not None else 0,
            default_root_dir=hparams.root_dir,
            max_epochs=hparams.max_epochs,
            checkpoint_callback=True,
            callbacks=[
                checkpoint_callback,
                LearningRateMonitor(logging_interval='step')
            ],
            terminate_on_nan=True,
            progress_bar_refresh_rate=5,
        )

    # Fit
    trainer.fit(model, datamodule=datamodule)

    print(
        f"Training finished; end of script: rename {checkpoint_callback.best_model_path}"
    )

    # Copy the best checkpoint to a stable, predictable filename.
    shutil.copyfile(
        checkpoint_callback.best_model_path,
        os.path.join(os.path.dirname(checkpoint_callback.best_model_path),
                     'best.ckpt'))
# --- Example #2 ---
def get_topology_binary_target_path():
    """Return the path to the binarized topology target scores file."""
    # Plain literal: the former f-string had no placeholders (lint F541).
    return os.path.join(get_data_root(), 'topology_data/target_bin.npy')
# --- Example #3 ---
def get_topology_binary_start_score_path():
    """Return the path to the binarized topology starting-scores file."""
    # Plain literal: the former f-string had no placeholders (lint F541).
    return os.path.join(get_data_root(), 'topology_data/start_score_bin.npy')
# --- Example #4 ---
def get_topology_binary_dataset_path():
    """Return the path to the binarized topology dataset archive (.npz)."""
    # Plain literal: the former f-string had no placeholders (lint F541).
    return os.path.join(get_data_root(), 'topology_data/topology_data_bin.npz')