def test_ddp_sgd_comm_hook(tmpdir):
    """Test for DDP FP16 compress hook."""
    model = BoringModel()
    training_type_plugin = DDPPlugin(
        ddp_comm_state=powerSGD.PowerSGDState(process_group=None),
        ddp_comm_hook=powerSGD.powerSGD_hook,
        sync_batchnorm=True,
    )
    trainer = Trainer(
        max_epochs=1,
        gpus=2,
        plugins=[training_type_plugin],
        default_root_dir=tmpdir,
        sync_batchnorm=True,
        fast_dev_run=True,
    )
    trainer.fit(model)
    trainer_comm_hook = (trainer.accelerator.training_type_plugin._model.get_ddp_logging_data().comm_hook)
    expected_comm_hook = powerSGD.powerSGD_hook.__qualname__
    assert trainer_comm_hook == expected_comm_hook
    assert trainer.state.finished, f"Training failed with {trainer.state}"


def test_ddp_fp16_compress_comm_hook(tmpdir):
    """Test for DDP FP16 compress hook."""
    model = BoringModel()
    training_type_plugin = DDPPlugin(
        ddp_comm_hook=default.fp16_compress_hook,
        sync_batchnorm=True,
    )
    trainer = Trainer(
        max_epochs=1,
        gpus=2,
        plugins=[training_type_plugin],
        default_root_dir=tmpdir,
        sync_batchnorm=True,
        fast_dev_run=True,
    )
    trainer.fit(model)
    trainer_comm_hook = trainer.accelerator.training_type_plugin._model.get_ddp_logging_data().comm_hook
    expected_comm_hook = default.fp16_compress_hook.__qualname__
    assert trainer_comm_hook == expected_comm_hook
    assert trainer.state.finished, f"Training failed with {trainer.state}"


def test_ddp_fp16_compress_wrap_sgd_comm_hook(tmpdir):
    """Test for DDP FP16 compress wrapper for SGD hook."""
    model = BoringModel()
    training_type_plugin = DDPPlugin(
        ddp_comm_state=powerSGD.PowerSGDState(process_group=None),
        ddp_comm_hook=powerSGD.powerSGD_hook,
        ddp_comm_wrapper=default.fp16_compress_wrapper,
        sync_batchnorm=True,
    )
    trainer = Trainer(
        max_epochs=1,
        gpus=2,
        strategy=training_type_plugin,
        default_root_dir=tmpdir,
        sync_batchnorm=True,
        fast_dev_run=True,
    )
    trainer.fit(model)
    trainer_comm_hook = trainer.accelerator.training_type_plugin._model.get_ddp_logging_data().comm_hook
    expected_comm_hook = default.fp16_compress_wrapper(powerSGD.powerSGD_hook).__qualname__
    assert trainer_comm_hook == expected_comm_hook
    assert trainer.state.finished, f"Training failed with {trainer.state}"
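
The ddp_comm_state, ddp_comm_hook and ddp_comm_wrapper arguments in these tests are forwarded to the DistributedDataParallel module that the plugin wraps. A minimal sketch of the underlying torch.distributed call that DDPPlugin configures, assuming an initialized process group and an existing nn.Module named module on the correct device:

from torch.nn.parallel import DistributedDataParallel
from torch.distributed.algorithms.ddp_comm_hooks import default_hooks as default
from torch.distributed.algorithms.ddp_comm_hooks import powerSGD_hook as powerSGD

# `module` and `local_rank` are placeholders for this sketch.
ddp_model = DistributedDataParallel(module, device_ids=[local_rank])
state = powerSGD.PowerSGDState(process_group=None, matrix_approximation_rank=1)
# Register PowerSGD gradient compression wrapped in FP16 compression
# (the same combination exercised by test_ddp_fp16_compress_wrap_sgd_comm_hook).
ddp_model.register_comm_hook(state, default.fp16_compress_wrapper(powerSGD.powerSGD_hook))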
Example #4
def get_trainer(wandb_logger, callbacks, config):

    gpus = []
    if config.gpu0:
        gpus.append(0)
    if config.gpu1:
        gpus.append(1)
    logging.info("gpus active", gpus)
    if len(gpus) >= 2:
        distributed_backend = "ddp"
        accelerator = "dpp"
        plugins = DDPPlugin(find_unused_parameters=False)
    else:
        distributed_backend = None
        accelerator = None
        plugins = None

    trainer = pl.Trainer(
        logger=wandb_logger,
        gpus=gpus,
        max_epochs=config.NUM_EPOCHS,
        precision=config.precision_compute,
        #    limit_train_batches=0.1, #only to debug
        #    limit_val_batches=0.1, #only to debug
        #    limit_test_batches=0.1,
        #    val_check_interval=1,
        auto_lr_find=config.AUTO_LR,
        log_gpu_memory=True,
        #    distributed_backend=distributed_backend,
        #    accelerator=accelerator,
        #    plugins=plugins,
        callbacks=callbacks,
        progress_bar_refresh_rate=5,
    )

    return trainer
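
Note that the distributed arguments are commented out above, so the active Trainer runs without DDP. For reference, later PyTorch Lightning releases (roughly 1.6 onward) express the same two-GPU selection with accelerator/devices/strategy instead of distributed_backend and plugins; a hedged sketch, not taken from this repository:

import pytorch_lightning as pl
from pytorch_lightning.strategies import DDPStrategy  # newer replacement for DDPPlugin

gpus = [0, 1]  # devices chosen from config.gpu0 / config.gpu1 as above
trainer = pl.Trainer(
    accelerator="gpu",
    devices=gpus,
    strategy=DDPStrategy(find_unused_parameters=False),  # only needed for multi-GPU runs
)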
Example #5
def main():
    system = configure_system(
        hyperparameter_defaults["system"])(hyperparameter_defaults)
    logger = TensorBoardLogger(
        'experiments_logs',
        name=str(hyperparameter_defaults['system']) + "_" +
        str(system.model.__class__.__name__) + "_" +
        str(hyperparameter_defaults['criterion']) + "_" +
        str(hyperparameter_defaults['scheduler']))

    early_stop = EarlyStopping(monitor="valid_iou",
                               mode="max",
                               verbose=True,
                               patience=hyperparameter_defaults["patience"])
    model_checkpoint = ModelCheckpoint(
        monitor="valid_iou",
        mode="max",
        verbose=True,
        filename='Model-{epoch:02d}-{valid_iou:.5f}',
        save_top_k=3,
        save_last=True)
    trainer = pl.Trainer(
        gpus=[0, 1],
        plugins=DDPPlugin(find_unused_parameters=True),
        max_epochs=hyperparameter_defaults['epochs'],
        logger=logger,
        check_val_every_n_epoch=1,
        accelerator='ddp',
        callbacks=[early_stop, model_checkpoint],
        num_sanity_val_steps=0,
        limit_train_batches=1.0,
        deterministic=True,
    )

    trainer.fit(system)
    trainer.test(system)
Example #6
def cli_main():
    parser = ArgumentParser()
    parser.add_argument("--batch-size", default=6, type=int)
    parser.add_argument("--dataset", default="librimix", type=str, choices=["wsj0-mix", "librimix"])
    parser.add_argument(
        "--root-dir",
        type=Path,
        help="The path to the directory where the directory ``Libri2Mix`` or ``Libri3Mix`` is stored.",
    )
    parser.add_argument(
        "--librimix-tr-split",
        default="train-360",
        choices=["train-360", "train-100"],
        help="The training partition of librimix dataset. (default: ``train-360``)",
    )
    parser.add_argument(
        "--librimix-task",
        default="sep_clean",
        type=str,
        choices=["sep_clean", "sep_noisy", "enh_single", "enh_both"],
        help="The task to perform (separation or enhancement, noisy or clean). (default: ``sep_clean``)",
    )
    parser.add_argument(
        "--num-speakers", default=2, type=int, help="The number of speakers in the mixture. (default: 2)"
    )
    parser.add_argument(
        "--sample-rate",
        default=8000,
        type=int,
        help="Sample rate of audio files in the given dataset. (default: 8000)",
    )
    parser.add_argument(
        "--exp-dir",
        default=Path("./exp"),
        type=Path,
        help="The directory to save checkpoints and logs."
    )
    parser.add_argument(
        "--epochs",
        metavar="NUM_EPOCHS",
        default=200,
        type=int,
        help="The number of epochs to train. (default: 200)",
    )
    parser.add_argument(
        "--learning-rate",
        default=1e-3,
        type=float,
        help="Initial learning rate. (default: 1e-3)",
    )
    parser.add_argument(
        "--num-gpu",
        default=1,
        type=int,
        help="The number of GPUs for training. (default: 1)",
    )
    parser.add_argument(
        "--num-node",
        default=1,
        type=int,
        help="The number of nodes in the cluster for training. (default: 1)",
    )
    parser.add_argument(
        "--num-workers",
        default=4,
        type=int,
        help="The number of workers for dataloader. (default: 4)",
    )

    args = parser.parse_args()

    model = _get_model(num_sources=args.num_speakers)

    optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
    lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode="min", factor=0.5, patience=5
    )
    train_loader, valid_loader, eval_loader = _get_dataloader(
        args.dataset,
        args.root_dir,
        args.num_speakers,
        args.sample_rate,
        args.batch_size,
        args.num_workers,
        args.librimix_task,
        args.librimix_tr_split,
    )
    loss = si_sdr_loss
    metric_dict = {
        "sdri": sdri_metric,
        "sisdri": sisdri_metric,
    }
    model = ConvTasNetModule(
        model=model,
        train_loader=train_loader,
        val_loader=valid_loader,
        loss=loss,
        optim=optimizer,
        metrics=metric_dict,
        lr_scheduler=lr_scheduler,
    )
    checkpoint_dir = args.exp_dir / "checkpoints"
    checkpoint = ModelCheckpoint(
        checkpoint_dir,
        monitor="Losses/val_loss",
        mode="min",
        save_top_k=5,
        save_weights_only=True,
        verbose=True
    )
    callbacks = [
        checkpoint,
        EarlyStopping(monitor="Losses/val_loss", mode="min", patience=30, verbose=True),
    ]
    trainer = Trainer(
        default_root_dir=args.exp_dir,
        max_epochs=args.epochs,
        gpus=args.num_gpu,
        num_nodes=args.num_node,
        accelerator="ddp",
        plugins=DDPPlugin(find_unused_parameters=False),  # make sure there is no unused params
        limit_train_batches=1.0,  # Useful for fast experiment
        gradient_clip_val=5.0,
        callbacks=callbacks,
    )
    trainer.fit(model)
    model.load_from_checkpoint(checkpoint.best_model_path)
    state_dict = torch.load(checkpoint.best_model_path, map_location="cpu")
    state_dict = {k.replace("model.", ""): v for k, v in state_dict["state_dict"].items()}
    torch.save(state_dict, args.exp_dir / "best_model.pth")
    trainer.test(model, eval_loader)
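
The state_dict keys carry a "model." prefix because ConvTasNetModule stores the network as an attribute of the LightningModule; once the prefix is stripped and the weights are saved, they can be loaded into a bare network. A small sketch reusing _get_model from this example (the path is an assumption):

import torch

# Load the exported plain state_dict back into an un-wrapped ConvTasNet.
model = _get_model(num_sources=2)
state_dict = torch.load("exp/best_model.pth", map_location="cpu")
model.load_state_dict(state_dict)
model.eval()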
Example #7
def main():
    # parse arguments
    args = parse_args()
    rank_zero_only(pprint.pprint)(vars(args))

    # init default-cfg and merge it with the main- and data-cfg
    config = get_cfg_defaults()
    config.merge_from_file(args.main_cfg_path)
    config.merge_from_file(args.data_cfg_path)
    pl.seed_everything(config.TRAINER.SEED)  # reproducibility
    # TODO: Use different seeds for each dataloader workers
    # This is needed for data augmentation

    # scale lr and warmup-step automatically
    args.gpus = _n_gpus = setup_gpus(args.gpus)
    config.TRAINER.WORLD_SIZE = _n_gpus * args.num_nodes
    config.TRAINER.TRUE_BATCH_SIZE = config.TRAINER.WORLD_SIZE * args.batch_size
    _scaling = config.TRAINER.TRUE_BATCH_SIZE / config.TRAINER.CANONICAL_BS
    config.TRAINER.SCALING = _scaling
    config.TRAINER.TRUE_LR = config.TRAINER.CANONICAL_LR * _scaling
    config.TRAINER.WARMUP_STEP = math.floor(config.TRAINER.WARMUP_STEP /
                                            _scaling)

    # lightning module
    profiler = build_profiler(args.profiler_name)
    model = PL_LoFTR(config, pretrained_ckpt=args.ckpt_path, profiler=profiler)
    loguru_logger.info(f"LoFTR LightningModule initialized!")

    # lightning data
    data_module = MultiSceneDataModule(args, config)
    loguru_logger.info(f"LoFTR DataModule initialized!")

    # TensorBoard Logger
    logger = TensorBoardLogger(save_dir='logs/tb_logs',
                               name=args.exp_name,
                               default_hp_metric=False)
    ckpt_dir = Path(logger.log_dir) / 'checkpoints'

    # Callbacks
    # TODO: update ModelCheckpoint to monitor multiple metrics
    ckpt_callback = ModelCheckpoint(
        monitor='auc@10',
        verbose=True,
        save_top_k=5,
        mode='max',
        save_last=True,
        dirpath=str(ckpt_dir),
        filename='{epoch}-{auc@5:.3f}-{auc@10:.3f}-{auc@20:.3f}')
    lr_monitor = LearningRateMonitor(logging_interval='step')
    callbacks = [lr_monitor]
    if not args.disable_ckpt:
        callbacks.append(ckpt_callback)

    # Lightning Trainer
    trainer = pl.Trainer.from_argparse_args(
        args,
        plugins=DDPPlugin(find_unused_parameters=False,
                          num_nodes=args.num_nodes,
                          sync_batchnorm=config.TRAINER.WORLD_SIZE > 0),
        gradient_clip_val=config.TRAINER.GRADIENT_CLIPPING,
        callbacks=callbacks,
        logger=logger,
        sync_batchnorm=config.TRAINER.WORLD_SIZE > 0,
        replace_sampler_ddp=False,  # use custom sampler
        reload_dataloaders_every_epoch=False,  # avoid repeated samples!
        weights_summary='full',
        profiler=profiler)
    loguru_logger.info(f"Trainer initialized!")
    loguru_logger.info(f"Start training!")
    trainer.fit(model, datamodule=data_module)
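
The "scale lr and warmup-step automatically" block applies the usual linear scaling rule: the learning rate grows, and the warmup length shrinks, in proportion to the true (global) batch size. A tiny numeric sketch with made-up canonical values, only to illustrate the arithmetic:

import math

canonical_bs, canonical_lr, warmup_step = 64, 6e-3, 4800  # hypothetical canonical settings
world_size = 4 * 2                                   # 4 GPUs per node x 2 nodes
true_batch_size = world_size * 2                     # batch size 2 per process -> 16
scaling = true_batch_size / canonical_bs             # 0.25
true_lr = canonical_lr * scaling                     # 1.5e-3
scaled_warmup = math.floor(warmup_step / scaling)    # 19200 steps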
Example #8
def cli_main():
    # ------------
    # args
    # ------------
    parser = argparse.ArgumentParser()
    parser.add_argument('--config',
                        action='store',
                        dest='config',
                        help='config.yaml',
                        required=True)
    parser.add_argument('--ckpt',
                        action='store',
                        dest='ckpt',
                        help='checkpoint to load',
                        required=True)
    args = parser.parse_args()

    with open(args.config, 'r') as ymlfile:
        config = yaml.load(ymlfile, Loader=yaml.FullLoader)
        config = DotMap(config)

    assert (config.name in [
        "lstur", "nrms", "naml", "naml_simple", "sentirec", "robust_sentirec"
    ])

    pl.seed_everything(1234)

    # ------------
    # logging
    # ------------
    logger = TensorBoardLogger(**config.logger)

    # ------------
    # data
    # ------------

    test_dataset = BaseDataset(path.join(config.test_behavior),
                               path.join(config.test_news), config)
    test_loader = DataLoader(test_dataset, **config.test_dataloader)

    #print(len(dataset), len(train_dataset), len(val_dataset))
    # ------------
    # init model
    # ------------
    # load embedding pre-trained embedding weights
    embedding_weights = []
    with open(config.embedding_weights, 'r') as file:
        lines = file.readlines()
        for line in tqdm(lines):
            weights = [float(w) for w in line.split(" ")]
            embedding_weights.append(weights)
    pretrained_word_embedding = torch.from_numpy(
        np.array(embedding_weights, dtype=np.float32))

    if config.name == "lstur":
        model = LSTUR.load_from_checkpoint(
            args.ckpt,
            config=config,
            pretrained_word_embedding=pretrained_word_embedding)
    elif config.name == "nrms":
        model = NRMS.load_from_checkpoint(
            args.ckpt,
            config=config,
            pretrained_word_embedding=pretrained_word_embedding)
    elif config.name == "naml":
        model = NAML.load_from_checkpoint(
            args.ckpt,
            config=config,
            pretrained_word_embedding=pretrained_word_embedding)
    elif config.name == "naml_simple":
        model = NAML_Simple.load_from_checkpoint(
            args.ckpt,
            config=config,
            pretrained_word_embedding=pretrained_word_embedding)
    elif config.name == "sentirec":
        model = SENTIREC.load_from_checkpoint(
            args.ckpt,
            config=config,
            pretrained_word_embedding=pretrained_word_embedding)
    elif config.name == "robust_sentirec":
        model = ROBUST_SENTIREC.load_from_checkpoint(
            args.ckpt,
            config=config,
            pretrained_word_embedding=pretrained_word_embedding)
    # elif:
    # UPCOMING MODELS

    # ------------
    # Test
    # ------------
    trainer = Trainer(**config.trainer,
                      logger=logger,
                      plugins=DDPPlugin(find_unused_parameters=False))

    trainer.test(model=model, test_dataloaders=test_loader)
Example #9
    checkpoint_callback = ModelCheckpoint(
        dirpath=args.save_dir,
        filename='{epoch}-{val_loss:.3f}-{train_loss:.3f}',
        save_top_k=-1)
    logger = CometLogger(
        api_key="YOUR-API-KEY",
        project_name=proj_name,
    )

    model = lit_gazetrack_model(args.dataset_dir, args.save_dir,
                                args.batch_size, logger)
    if (args.checkpoint):
        if (args.gpus == 0):
            w = torch.load(args.checkpoint,
                           map_location=torch.device('cpu'))['state_dict']
        else:
            w = torch.load(args.checkpoint)['state_dict']
        model.load_state_dict(w)
        print("Loaded checkpoint")

    trainer = pl.Trainer(gpus=args.gpus,
                         logger=logger,
                         accelerator="ddp",
                         max_epochs=args.epochs,
                         default_root_dir=args.save_dir,
                         progress_bar_refresh_rate=1,
                         callbacks=[checkpoint_callback],
                         plugins=DDPPlugin(find_unused_parameters=False))
    trainer.fit(model)
    print("DONE")
Example #10
def process(args):
    torch.multiprocessing.set_sharing_strategy('file_system')
    # Pretraining data
    if args.dataset == "ZINC5k":
        dataset = ZINC5K("../data/torchdrug/molecule-datasets/",
                         node_feature="pretrain",
                         edge_feature="pretrain",
                         lazy=True)
    elif args.dataset == "ZINC250k":
        dataset = datasets.ZINC250k("../data/torchdrug/molecule-datasets/",
                                    node_feature="pretrain",
                                    edge_feature="pretrain",
                                    lazy=True)
    elif args.dataset == "ZINC2m":
        # defaults to lazy load
        dataset = datasets.ZINC2m("../data/torchdrug/molecule-datasets/",
                                  node_feature="pretrain",
                                  edge_feature="pretrain")

    # CTRP smiles to embed
    ctrp = pd.read_csv("../data/drug_screens/CTRP/v20.meta.per_compound.txt",
                       sep="\t")
    ctrp_ds = MoleculeDataset()
    ctrp_ds.load_smiles(smiles_list=ctrp['cpd_smiles'],
                        targets=dict(),
                        node_feature='pretrain',
                        edge_feature='pretrain')

    # Self-supervised pretraining
    dm = ChemGraphDataModule.from_argparse_args(args,
                                                train=dataset,
                                                predict=ctrp_ds)
    model = ChemGraphEmbeddingNetwork(task=args.task,
                                      input_dim=dataset.node_feature_dim,
                                      hidden_dims=[512] * 5,
                                      edge_input_dim=dataset.edge_feature_dim,
                                      batch_norm=True,
                                      readout="mean",
                                      mask_rate=0.15)
    # Callbacks
    fname = f"{args.name}_{args.task}_{args.dataset}"
    logger = TensorBoardLogger(save_dir=args.default_root_dir,
                               version=fname,
                               name='lightning_logs')
    early_stop = EarlyStopping(monitor='accuracy',
                               min_delta=0.001,
                               patience=5,
                               verbose=False,
                               mode='max')
    checkpoint_callback = ModelCheckpoint(monitor='accuracy', mode='max')
    trainer = Trainer.from_argparse_args(
        args,
        default_root_dir=logger.log_dir,
        logger=logger,
        callbacks=[early_stop, checkpoint_callback],
        strategy=DDPPlugin(find_unused_parameters=False),
        profiler='simple')
    trainer.fit(model, dm)

    # Generate CTRP embeddings
    model.to('cpu')
    model.eval()
    dl = DataLoader(ctrp_ds, batch_size=len(ctrp_ds))
    graph_embeds = []
    node_embeds = []
    for batch in dl:
        graph_feature, node_feature = model(batch)
        graph_embeds.append(graph_feature.detach())
        node_embeds.append(node_feature.detach())
    graph_embeds = torch.cat(graph_embeds).numpy()
    node_embeds = torch.cat(node_embeds).numpy()

    # Write out
    node_cpd_ids = [
        np.repeat(cpd_id, n['graph'].num_node)
        for n, cpd_id in zip(ctrp_ds, ctrp['broad_cpd_id'])
    ]
    node_cpd_ids = np.concatenate(node_cpd_ids)
    node_embeds = pd.DataFrame(node_embeds, index=node_cpd_ids)
    node_embeds['atom_type'] = np.concatenate(
        [[ATOM_SYMBOL[a] for a in n['graph'].atom_type] for n in ctrp_ds])
    graph_embeds = pd.DataFrame(graph_embeds, index=ctrp['broad_cpd_id'])
    node_embeds.to_csv(
        f"../data/torchdrug/molecule-datasets/{fname}_ctrp_node_embeds.csv",
        sep=",")
    graph_embeds.to_csv(
        f"../data/torchdrug/molecule-datasets/{fname}_ctrp_graph_embeds.csv",
        sep=",")
Example #11
def train_model(model, model_dir):
    # Setup trainer
    tb_logger = pl_loggers.TensorBoardLogger('{}/logs/'.format(model_dir))

    chkpt1 = ModelCheckpoint(save_last=True) 
    chkpt2 = ModelCheckpoint(every_n_train_steps=10000) # save every 10000 steps

    if Constants.n_gpus != 0:
        trainer = Trainer(gpus=Constants.n_gpus,
                          callbacks=[chkpt1, chkpt2],
                          accelerator='ddp_spawn',
                          plugins=DDPPlugin(find_unused_parameters=False),
                          precision=16,
                          logger=tb_logger,
                          default_root_dir=model_dir,
                          max_epochs=n_epochs)
    else:
        trainer = Trainer(gpus=0,
                          default_root_dir=model_dir,
                          logger=tb_logger,
                          callbacks=[chkpt1, chkpt2],
                          max_epochs=n_epochs)

    trainer.fit(model)
Example #12
def train_model(
    train_config: TrainConfig,
    video_loader_config: Optional[VideoLoaderConfig] = None,
):
    """Trains a model.

    Args:
        train_config (TrainConfig): Pydantic config for training.
        video_loader_config (VideoLoaderConfig, optional): Pydantic config for preprocessing videos.
            If None, will use default for model specified in TrainConfig.
    """
    # get default VLC for model if not specified
    if video_loader_config is None:
        video_loader_config = ModelConfig(
            train_config=train_config, video_loader_config=video_loader_config
        ).video_loader_config

    # set up model
    model = instantiate_model(
        checkpoint=train_config.checkpoint,
        scheduler_config=train_config.scheduler_config,
        weight_download_region=train_config.weight_download_region,
        model_cache_dir=train_config.model_cache_dir,
        labels=train_config.labels,
        from_scratch=train_config.from_scratch,
        model_name=train_config.model_name,
        predict_all_zamba_species=train_config.predict_all_zamba_species,
    )

    data_module = ZambaDataModule(
        video_loader_config=video_loader_config,
        transform=MODEL_MAPPING[model.__class__.__name__]["transform"],
        train_metadata=train_config.labels,
        batch_size=train_config.batch_size,
        num_workers=train_config.num_workers,
    )

    validate_species(model, data_module)

    train_config.save_dir.mkdir(parents=True, exist_ok=True)

    # add folder version_n that auto increments if we are not overwriting
    tensorboard_version = train_config.save_dir.name if train_config.overwrite else None
    tensorboard_save_dir = (
        train_config.save_dir.parent if train_config.overwrite else train_config.save_dir
    )

    tensorboard_logger = TensorBoardLogger(
        save_dir=tensorboard_save_dir,
        name=None,
        version=tensorboard_version,
        default_hp_metric=False,
    )

    logging_and_save_dir = (
        tensorboard_logger.log_dir if not train_config.overwrite else train_config.save_dir
    )

    model_checkpoint = ModelCheckpoint(
        dirpath=logging_and_save_dir,
        filename=train_config.model_name,
        monitor=train_config.early_stopping_config.monitor
        if train_config.early_stopping_config is not None
        else None,
        mode=train_config.early_stopping_config.mode
        if train_config.early_stopping_config is not None
        else "min",
    )

    callbacks = [model_checkpoint]

    if train_config.early_stopping_config is not None:
        callbacks.append(EarlyStopping(**train_config.early_stopping_config.dict()))

    if train_config.backbone_finetune_config is not None:
        callbacks.append(BackboneFinetuning(**train_config.backbone_finetune_config.dict()))

    trainer = pl.Trainer(
        gpus=train_config.gpus,
        max_epochs=train_config.max_epochs,
        auto_lr_find=train_config.auto_lr_find,
        logger=tensorboard_logger,
        callbacks=callbacks,
        fast_dev_run=train_config.dry_run,
        accelerator="ddp" if data_module.multiprocessing_context is not None else None,
        plugins=DDPPlugin(find_unused_parameters=False)
        if data_module.multiprocessing_context is not None
        else None,
    )

    if video_loader_config.cache_dir is None:
        logger.info("No cache dir is specified. Videos will not be cached.")
    else:
        logger.info(f"Videos will be cached to {video_loader_config.cache_dir}.")

    if train_config.auto_lr_find:
        logger.info("Finding best learning rate.")
        trainer.tune(model, data_module)

    try:
        git_hash = git.Repo(search_parent_directories=True).head.object.hexsha
    except git.exc.InvalidGitRepositoryError:
        git_hash = None

    configuration = {
        "git_hash": git_hash,
        "model_class": model.model_class,
        "species": model.species,
        "starting_learning_rate": model.lr,
        "train_config": json.loads(train_config.json(exclude={"labels"})),
        "training_start_time": datetime.utcnow().isoformat(),
        "video_loader_config": json.loads(video_loader_config.json()),
    }

    if not train_config.dry_run:
        config_path = Path(logging_and_save_dir) / "train_configuration.yaml"
        config_path.parent.mkdir(exist_ok=True, parents=True)
        logger.info(f"Writing out full configuration to {config_path}.")
        with config_path.open("w") as fp:
            yaml.dump(configuration, fp)

    logger.info("Starting training...")
    trainer.fit(model, data_module)

    if not train_config.dry_run:
        if trainer.datamodule.test_dataloader() is not None:
            logger.info("Calculating metrics on holdout set.")
            test_metrics = trainer.test(dataloaders=trainer.datamodule.test_dataloader())[0]
            with (Path(logging_and_save_dir) / "test_metrics.json").open("w") as fp:
                json.dump(test_metrics, fp, indent=2)

        if trainer.datamodule.val_dataloader() is not None:
            logger.info("Calculating metrics on validation set.")
            val_metrics = trainer.validate(dataloaders=trainer.datamodule.val_dataloader())[0]
            with (Path(logging_and_save_dir) / "val_metrics.json").open("w") as fp:
                json.dump(val_metrics, fp, indent=2)

    return trainer
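
A hypothetical call site for train_model, just to show the shape of the configuration it expects; the field names below appear in the function body, but the values are placeholders rather than library defaults, and additional required fields may exist:

train_config = TrainConfig(
    labels="labels.csv",
    model_name="time_distributed",
    batch_size=2,
    num_workers=4,
    max_epochs=10,
)
trainer = train_model(train_config)  # video_loader_config falls back to the model's default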
Example #13
def cli_main():
    # ------------
    # args
    # ------------
    parser = argparse.ArgumentParser()
    parser.add_argument('--config',
                        action='store',
                        dest='config',
                        help='config.yaml',
                        required=True)
    parser.add_argument('--resume',
                        action='store',
                        dest='resume',
                        help='resume training form ckpt',
                        required=False)
    args = parser.parse_args()

    with open(args.config, 'r') as ymlfile:
        config = yaml.load(ymlfile, Loader=yaml.FullLoader)
        config = DotMap(config)

    assert (config.name in [
        "lstur", "nrms", "naml", "naml_simple", "sentirec", "robust_sentirec"
    ])

    pl.seed_everything(1234)

    # ------------
    # init callbacks & logging
    # ------------
    checkpoint_callback = ModelCheckpoint(**config.checkpoint)
    logger = TensorBoardLogger(**config.logger)

    # ------------
    # data
    # ------------
    train_dataset = BaseDataset(path.join(config.train_behavior),
                                path.join(config.train_news), config)
    val_dataset = BaseDataset(path.join(config.val_behavior),
                              path.join(config.train_news), config)
    train_loader = DataLoader(train_dataset, **config.train_dataloader)
    val_loader = DataLoader(val_dataset, **config.val_dataloader)

    # ------------
    # init model
    # ------------
    # load embedding pre-trained embedding weights
    embedding_weights = []
    with open(config.embedding_weights, 'r') as file:
        lines = file.readlines()
        for line in tqdm(lines):
            weights = [float(w) for w in line.split(" ")]
            embedding_weights.append(weights)
    pretrained_word_embedding = torch.from_numpy(
        np.array(embedding_weights, dtype=np.float32))

    if config.name == "lstur":
        model = LSTUR(config, pretrained_word_embedding)
    elif config.name == "nrms":
        model = NRMS(config, pretrained_word_embedding)
    elif config.name == "naml":
        model = NAML(config, pretrained_word_embedding)
    elif config.name == "naml_simple":
        model = NAML_Simple(config, pretrained_word_embedding)
    elif config.name == "sentirec":
        model = SENTIREC(config, pretrained_word_embedding)
    elif config.name == "robust_sentirec":
        model = ROBUST_SENTIREC(config, pretrained_word_embedding)
    # elif:
    # UPCOMING MODELS

    # ------------
    # training
    # ------------
    early_stop_callback = EarlyStopping(**config.early_stop)
    if args.resume is not None:
        model = model.load_from_checkpoint(
            args.resume,
            config=config,
            pretrained_word_embedding=pretrained_word_embedding)
        trainer = Trainer(**config.trainer,
                          callbacks=[early_stop_callback, checkpoint_callback],
                          logger=logger,
                          plugins=DDPPlugin(
                              find_unused_parameters=config.find_unused_parameters),
                          resume_from_checkpoint=args.resume)
    else:
        trainer = Trainer(**config.trainer,
                          callbacks=[early_stop_callback, checkpoint_callback],
                          logger=logger,
                          plugins=DDPPlugin(
                              find_unused_parameters=config.find_unused_parameters))
    trainer.fit(model=model,
                train_dataloader=train_loader,
                val_dataloaders=val_loader)
Example #14
def run(config):
    # build hooks
    loss_fn = build_loss(config)
    metric_fn = build_metrics(config)
    hooks = build_hooks(config)
    hooks.update({"loss_fn": loss_fn, "metric_fn": metric_fn})

    # build model
    model = build_model(config)

    # build callbacks
    callbacks = build_callbacks(config)

    # build logger
    logger = build_logger(config)

    # debug
    if config.debug:
        logger = None
        OmegaConf.set_struct(config, True)
        with open_dict(config):
            config.trainer.trainer.max_epochs = None
            config.trainer.trainer.max_steps = 10

    # logging for wandb or mlflow
    if hasattr(logger, "log_hyperparams"):
        for k, v in config.trainer.items():
            if k not in ("metrics", "inference"):
                logger.log_hyperparams(params=v)
        logger.log_hyperparams(params=config.dataset)
        logger.log_hyperparams(params=config.augmentation)

    # last linear training
    if (hasattr(config.trainer.model, "last_linear")
            and (config.trainer.model.last_linear.training)
            and (config.trainer.model.params.pretrained)):
        model = train_last_linear(config, model, hooks, logger)

    # initialize model
    model, params = kvt.utils.initialize_model(config, model)

    # build optimizer
    optimizer = build_optimizer(config, model=model, params=params)

    # build scheduler
    scheduler = build_scheduler(config, optimizer=optimizer)

    # build dataloaders
    dataloaders = build_dataloaders(config)

    # build strong transform
    strong_transform, storong_transform_p = build_strong_transform(config)

    # build lightning module
    lightning_module = build_lightning_module(
        config,
        model=model,
        optimizer=optimizer,
        scheduler=scheduler,
        hooks=hooks,
        dataloaders=dataloaders,
        strong_transform=strong_transform,
        storong_transform_p=storong_transform_p,
    )

    # build plugins
    # workaround for the DDP plugin issue discussed in:
    # https://github.com/PyTorchLightning/pytorch-lightning/discussions/6219
    plugins = []
    if hasattr(config.trainer.trainer,
               "accelerator") and (config.trainer.trainer.accelerator
                                   in ("ddp", "ddp2")):
        if hasattr(config.trainer, "find_unused_parameters"):
            plugins.append(
                DDPPlugin(find_unused_parameters=config.trainer.find_unused_parameters))
        else:
            plugins.append(DDPPlugin(find_unused_parameters=False))

    # best model path
    dir_path = config.trainer.callbacks.ModelCheckpoint.dirpath
    if isinstance(OmegaConf.to_container(config.dataset.dataset), list):
        idx_fold = config.dataset.dataset[0].params.idx_fold
    else:
        idx_fold = config.dataset.dataset.params.idx_fold
    filename = f"fold_{idx_fold}_best.ckpt"
    best_model_path = os.path.join(dir_path, filename)

    # train loop
    trainer = pl.Trainer(
        logger=logger,
        callbacks=callbacks,
        plugins=plugins,
        **config.trainer.trainer,
    )
    if not config.trainer.skip_training:
        trainer.fit(lightning_module)
        path = trainer.checkpoint_callback.best_model_path
        if path:
            print(f"Best model: {path}")
            print("Renaming...")
            # copy best model
            subprocess.run(f"mv {path} {best_model_path}",
                           shell=True,
                           stdout=PIPE,
                           stderr=PIPE)
        # if there is no best_model_path
        # e.g. no valid dataloader
        else:
            print("Saving current trainer...")
            trainer.save_checkpoint(best_model_path)

    # log best model
    if hasattr(logger, "log_hyperparams"):
        logger.log_hyperparams(params={"best_model_path": best_model_path})

    # load best checkpoint
    if os.path.exists(best_model_path):
        print(f"Loading best model: {best_model_path}")
        state_dict = torch.load(best_model_path)["state_dict"]

        # if using dp, it is necessary to fix state dict keys
        if (hasattr(config.trainer.trainer, "sync_batchnorm")
                and config.trainer.trainer.sync_batchnorm):
            state_dict = kvt.utils.fix_dp_model_state_dict(state_dict)

        lightning_module.model.load_state_dict(state_dict)
    else:
        print(f"Best model {best_model_path} does not exist.")

    # evaluate
    metric_dict = evaluate(lightning_module,
                           hooks,
                           config,
                           mode=["validation"])
    print("Result:")
    print(metric_dict)

    if hasattr(logger, "log_metrics"):
        logger.log_metrics(metric_dict)
Example #15
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer


dataset = MNIST(os.getcwd(), download=False, transform=transforms.ToTensor())
train_loader = DataLoader(dataset)

# init model
autoencoder = LitAutoEncoder()

# most basic trainer, uses good defaults (auto-tensorboard, checkpoints, logs, and more)

parallel_devices = [torch.device(i) for i in range(torch.cuda.device_count())]
acc = GPUAccelerator(precision_plugin=NativeMixedPrecisionPlugin(),
                     training_type_plugin=DDPPlugin(
                         parallel_devices=parallel_devices,
                         cluster_environment=LSFEnvironment()))

targs = {
    'max_epochs': 1,
    'num_nodes': 2,
    'accumulate_grad_batches': 1,
    'gpus': 6,
    'accelerator': acc,
    'limit_train_batches': 10,
    'limit_val_batches': 5,
    'log_every_n_steps': 1
}

# trainer = pl.Trainer(gpus=8) (if you have GPUs)
trainer = pl.Trainer(**targs)
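
The excerpt stops after constructing the Trainer; with the LitAutoEncoder and train_loader defined above, the training call would simply be (sketch):

# Fit the model through the manually assembled GPUAccelerator / DDPPlugin stack.
trainer.fit(autoencoder, train_loader)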

Example #16
@RunIf(min_gpus=2)
@mock.patch.dict(
    os.environ,
    {
        "CUDA_VISIBLE_DEVICES": "0,1",
        "SLURM_NTASKS": "2",
        "SLURM_JOB_NAME": "SOME_NAME",
        "SLURM_NODEID": "0",
        "SLURM_PROCID": "1",
        "SLURM_LOCALID": "1",
    },
)
@mock.patch("pytorch_lightning.plugins.DDPPlugin.setup_distributed", autospec=True)
@pytest.mark.parametrize("strategy", ["ddp", DDPPlugin()])
def test_strategy_choice_ddp_slurm(setup_distributed_mock, strategy):
    class CB(Callback):
        def on_fit_start(self, trainer, pl_module):
            assert trainer._accelerator_connector._is_slurm_managing_tasks()
            assert isinstance(trainer.accelerator, GPUAccelerator)
            assert isinstance(trainer.training_type_plugin, DDPPlugin)
            assert isinstance(trainer.training_type_plugin.cluster_environment, SLURMEnvironment)
            assert trainer.training_type_plugin.cluster_environment.local_rank() == 1
            assert trainer.training_type_plugin.local_rank == 1
            raise SystemExit()

    model = BoringModel()
    trainer = Trainer(fast_dev_run=True, strategy=strategy, gpus=2, callbacks=[CB()])

    with pytest.raises(SystemExit):
        trainer.fit(model)
Example #17
def create_lightning_trainer(container: LightningContainer,
                             resume_from_checkpoint: Optional[Path] = None,
                             num_nodes: int = 1,
                             multiple_trainloader_mode: str = "max_size_cycle") -> \
        Tuple[Trainer, StoringLogger]:
    """
    Creates a Pytorch Lightning Trainer object for the given model configuration. It creates checkpoint handlers
    and loggers. That includes a diagnostic logger for use in unit tests, that is also returned as the second
    return value.
    :param container: The container with model and data.
    :param resume_from_checkpoint: If provided, training resumes from this checkpoint.
    :param num_nodes: The number of nodes to use in distributed training.
    :param multiple_trainloader_mode: The mode for iterating over multiple training dataloaders, passed to the Trainer.
    :return: A tuple [Trainer object, diagnostic logger]
    """
    logging.debug(f"resume_from_checkpoint: {resume_from_checkpoint}")
    num_gpus = container.num_gpus_per_node()
    effective_num_gpus = num_gpus * num_nodes
    strategy = None
    if effective_num_gpus == 0:
        accelerator = "cpu"
        devices = 1
        message = "CPU"
    else:
        accelerator = "gpu"
        devices = num_gpus
        message = f"{devices} GPU"
        if effective_num_gpus > 1:
            # Accelerator should be "ddp" when running large models in AzureML (when using DDP_spawn, we get out of
            # GPU memory).
            # Initialize the DDP plugin. The default for pl_find_unused_parameters is False. If True, the plugin
            # prints out lengthy warnings about the performance impact of find_unused_parameters.
            strategy = DDPPlugin(find_unused_parameters=container.pl_find_unused_parameters)
            message += "s per node with DDP"
    logging.info(f"Using {message}")
    tensorboard_logger = TensorBoardLogger(save_dir=str(container.logs_folder), name="Lightning", version="")
    loggers = [tensorboard_logger, AzureMLLogger(False)]
    storing_logger = StoringLogger()
    loggers.append(storing_logger)
    # Use 32bit precision when running on CPU. Otherwise, make it depend on use_mixed_precision flag.
    precision = 32 if num_gpus == 0 else 16 if container.use_mixed_precision else 32
    # The next two flags control the settings in torch.backends.cudnn.deterministic and torch.backends.cudnn.benchmark
    # https://pytorch.org/docs/stable/notes/randomness.html
    # Note that switching to deterministic models can have large performance downside.
    if container.pl_deterministic:
        deterministic = True
        benchmark = False
    else:
        deterministic = False
        benchmark = True

    # The last checkpoint is considered the "best" checkpoint. For large segmentation
    # models, this still appears to be the best way of choosing them because validation loss on the relatively small
    # training patches is not stable enough. Going by the validation loss somehow works for the Prostate model, but
    # not for the HeadAndNeck model.
    # Note that "last" is somehow a misnomer, it should rather be "latest". There is a "last" checkpoint written in
    # every epoch. We could use that for recovery too, but it could happen that the job gets preempted right during
    # writing that file, and we would end up with an invalid file.
    last_checkpoint_callback = ModelCheckpoint(dirpath=str(container.checkpoint_folder),
                                               save_last=True,
                                               save_top_k=0)
    recovery_checkpoint_callback = ModelCheckpoint(dirpath=str(container.checkpoint_folder),
                                                   filename=AUTOSAVE_CHECKPOINT_FILE_NAME,
                                                   every_n_val_epochs=container.autosave_every_n_val_epochs,
                                                   save_last=False)
    callbacks: List[Callback] = [
        last_checkpoint_callback,
        recovery_checkpoint_callback,
    ]
    if container.monitor_loading:
        # TODO antonsc: Remove after fixing the callback.
        raise NotImplementedError("Monitoring batch loading times has been temporarily disabled.")
        # callbacks.append(BatchTimeCallback())
    if num_gpus > 0 and container.monitor_gpu:
        logging.info("Adding monitoring for GPU utilization")
        callbacks.append(GPUStatsMonitor(intra_step_time=True, inter_step_time=True))
    # Add the additional callbacks that were specified in get_trainer_arguments for LightningContainers
    additional_args = container.get_trainer_arguments()
    # Callbacks can be specified via the "callbacks" argument (the legacy behaviour) or the new get_callbacks method
    if "callbacks" in additional_args:
        more_callbacks = additional_args.pop("callbacks")
        if isinstance(more_callbacks, list):
            callbacks.extend(more_callbacks)  # type: ignore
        else:
            callbacks.append(more_callbacks)  # type: ignore
    callbacks.extend(container.get_callbacks())
    is_azureml_run = not is_offline_run_context(RUN_CONTEXT)
    progress_bar_refresh_rate = container.pl_progress_bar_refresh_rate
    if progress_bar_refresh_rate is None:
        progress_bar_refresh_rate = 50
        logging.info(f"The progress bar refresh rate is not set. Using a default of {progress_bar_refresh_rate}. "
                     f"To change, modify the pl_progress_bar_refresh_rate field of the container.")
    if is_azureml_run:
        callbacks.append(AzureMLProgressBar(refresh_rate=progress_bar_refresh_rate,
                                            write_to_logging_info=True,
                                            print_timestamp=False))
    else:
        callbacks.append(TQDMProgressBar(refresh_rate=progress_bar_refresh_rate))
    # Read out additional model-specific args here.
    # We probably want to keep essential ones like numgpu and logging.
    trainer = Trainer(default_root_dir=str(container.outputs_folder),
                      deterministic=deterministic,
                      benchmark=benchmark,
                      accelerator=accelerator,
                      strategy=strategy,
                      max_epochs=container.num_epochs,
                      # Both these arguments can be integers or floats. If integers, it is the number of batches.
                      # If float, it's the fraction of batches. We default to 1.0 (processing all batches).
                      limit_train_batches=container.pl_limit_train_batches or 1.0,
                      limit_val_batches=container.pl_limit_val_batches or 1.0,
                      num_sanity_val_steps=container.pl_num_sanity_val_steps,
                      check_val_every_n_epoch=container.pl_check_val_every_n_epoch,
                      callbacks=callbacks,
                      logger=loggers,
                      num_nodes=num_nodes,
                      devices=devices,
                      precision=precision,
                      sync_batchnorm=True,
                      detect_anomaly=container.detect_anomaly,
                      profiler=container.pl_profiler,
                      resume_from_checkpoint=str(resume_from_checkpoint) if resume_from_checkpoint else None,
                      multiple_trainloader_mode=multiple_trainloader_mode,
                      **additional_args)
    return trainer, storing_logger
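
A hypothetical call site, assuming container is an already-configured LightningContainer:

# Hypothetical usage; the LightningModule to fit is provided by the container itself.
trainer, storing_logger = create_lightning_trainer(container, num_nodes=1)
trainer.fit(container.model)  # `container.model` is assumed here, not shown in the snippet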
Example #18
def main(args):
    backbone = "bert-base-uncased-itokens"
    tokenizer = BertTokenizerFast.from_pretrained(backbone)

    # encoder_decoder_config = EncoderDecoderConfig.from_pretrained("bert-base-uncased-itokens")
    # model = EncoderDecoderModel.from_pretrained(
    #     "bert-base-uncased-itokens", config=encoder_decoder_config
    # )

    # model = EncoderDecoderModel.from_encoder_decoder_pretrained(
    #     "bert-base-uncased-itokens", "bert-base-uncased-itokens", tie_encoder_decoder=True
    # )

    # generator = Generator(model)

    # discriminator = Discriminator(
    #     AutoModel.from_pretrained("bert-base-uncased-itokens")
    # )

    if args.test:
        model = GAN.load_from_checkpoint(args.load_checkpoint,
                                         args=args,
                                         tokenizer=tokenizer,
                                         backbone=backbone)
        model.cuda()
        model.eval()

        model.inference(args.scene_graphs_json)

        return

    # train
    if args.gpus > 1:
        dm = VGDataModule(args, tokenizer, 2)
    else:
        dm = VGDataModule(args, tokenizer)

    if args.load_checkpoint != "":
        model = GAN.load_from_checkpoint(args.load_checkpoint,
                                         args=args,
                                         tokenizer=tokenizer,
                                         backbone=backbone)
    else:
        model = GAN(args, tokenizer, backbone)

    training_args = {
        "gpus": args.gpus,
        "fast_dev_run": False,
        "max_steps": args.num_iterations,
        "precision": 32,
        "gradient_clip_val": 1,
    }

    if args.gpus > 1:
        additional_args = {
            "accelerator": "ddp",
            "plugins": [DDPPlugin(find_unused_parameters=True)]
            # "plugins": [my_ddp]
        }

        training_args.update(additional_args)

    trainer = pl.Trainer(**training_args)
    trainer.fit(model, dm)
Example #19
def main():
    """
    Main training loop.
    """
    parser = ArgumentParser()

    parser = UNet.add_model_specific_args(parser)
    parser = Trainer.add_argparse_args(parser)

    args = parser.parse_args()

    prod = bool(os.getenv("PROD"))
    logging.getLogger(__name__).setLevel(logging.INFO)

    if prod:
        logging.info(
            "Training in production mode, disabling all debugging APIs")
        torch.autograd.set_detect_anomaly(False)
        torch.autograd.profiler.profile(enabled=False)
        torch.autograd.profiler.emit_nvtx(enabled=False)
    else:
        logging.info("Training i development mode, debugging APIs active.")
        torch.autograd.set_detect_anomaly(True)
        torch.autograd.profiler.profile(enabled=True,
                                        use_cuda=True,
                                        record_shapes=True,
                                        profile_memory=True)
        torch.autograd.profiler.emit_nvtx(enabled=True, record_shapes=True)

    model = UNet(**vars(args))

    logging.info(
        f"Network:\n"
        f"\t{model.hparams.n_channels} input channels\n"
        f"\t{model.hparams.n_classes} output channels (classes)\n"
        f'\t{"Bilinear" if model.hparams.bilinear else "Transposed conv"} upscaling'
    )

    cudnn.benchmark = True  # cudnn Autotuner
    cudnn.enabled = True  # look for optimal algorithms

    early_stop_callback = EarlyStopping(
        monitor="val_loss",
        min_delta=0.00,
        mode="min",
        patience=10 if not os.getenv("EARLY_STOP") else int(
            os.getenv("EARLY_STOP")),
        verbose=True,
    )

    lr_monitor = LearningRateMonitor()

    run_name = "{}_LR{}_BS{}_IS{}".format(
        datetime.now().strftime("%d-%m-%Y-%H-%M-%S"),
        args.lr,
        args.batch_size,
        args.image_size,
    ).replace(".", "_")

    log_folder = ("./logs" if not os.getenv("DIR_ROOT_DIR") else
                  os.getenv("DIR_ROOT_DIR"))
    if not os.path.isdir(log_folder):
        os.mkdir(log_folder)
    logger = TensorBoardLogger(log_folder, name=run_name)

    checkpoint_callback = ModelCheckpoint(
        monitor='val_loss',
        dirpath='./checkpoints',
        filename='unet-{epoch:02d}-{val_loss:.2f}',
        save_top_k=3,
        mode='min',
    )

    try:
        trainer = Trainer.from_argparse_args(
            args,
            gpus=-1,
            accelerator="ddp",
            plugins=DDPPlugin(find_unused_parameters=False),
            precision=16,
            auto_lr_find="learning_rate"
            if float(os.getenv("LRN_RATE")) == 0.0 else False,
            logger=logger,
            callbacks=[early_stop_callback, lr_monitor, checkpoint_callback],
            accumulate_grad_batches=1 if not os.getenv("ACC_GRAD") else int(
                os.getenv("ACC_GRAD")),
            gradient_clip_val=0.0 if not os.getenv("GRAD_CLIP") else float(
                os.getenv("GRAD_CLIP")),
            max_epochs=100 if not os.getenv("EPOCHS") else int(
                os.getenv("EPOCHS")),
            val_check_interval=0.1 if not os.getenv("VAL_INT_PER") else float(
                os.getenv("VAL_INT_PER")),
            default_root_dir=os.getcwd()
            if not os.getenv("DIR_ROOT_DIR") else os.getenv("DIR_ROOT_DIR"),
            fast_dev_run=True
            if os.getenv("FAST_DEV_RUN") == "True" else False,
        )
        if float(os.getenv("LRN_RATE")) == 0.0:
            trainer.tune(model)
        trainer.fit(model)
        trainer.test(model)
    except KeyboardInterrupt:
        torch.save(model.state_dict(), "INTERRUPTED.pth")
        logging.info("Saved interrupt")
        try:
            sys.exit(0)
        except SystemExit:
            os._exit(0)
Example #20
def run(cfg: DictConfig):
    local_rank = int(os.environ.get('LOCAL_RANK', 0))

    # The logs & checkpoints are dumped in: ${cfg.output_dir}/${cfg.experiment_name}/vN, where vN
    # is v0, v1, .... The version number increases automatically.
    script_dir = Path.cwd()
    experiment_dir = script_dir / cfg.output_dir / cfg.experiment_name
    experiment_dir.mkdir(parents=True, exist_ok=True)
    existing_ver = list()
    for d in experiment_dir.iterdir():
        if d.name.startswith('v') and d.name[1:].isdecimal() and d.is_dir():
            existing_ver.append(int(d.name[1:]))
    if local_rank == 0:
        current_ver = max(existing_ver) + 1 if existing_ver else 0
        output_dir = experiment_dir / f'v{current_ver}'
        output_dir.mkdir()
    else:
        # Use the same directory for output with the main process.
        current_ver = max(existing_ver)
        output_dir = experiment_dir / f'v{current_ver}'

    pl_logger = logging.getLogger('lightning')
    logging.config.fileConfig(
        script_dir / 'logging.conf',
        disable_existing_loggers=False,
        defaults={'log_filename': output_dir / f'run_rank{local_rank}.log'})
    # Only the process with LOCAL_RANK = 0 will print logs on the console.
    # And all the processes will print logs in their own log files.
    if local_rank != 0:
        root_logger = logging.getLogger()
        root_logger.removeHandler(root_logger.handlers[0])

    pl_logger.info(f'Output logs & checkpoints in: {output_dir}')
    # Dump experiment configurations for reproducibility
    if local_rank == 0:
        with open(output_dir / 'cfg.yaml', 'w') as yaml_file:
            yaml_file.write(OmegaConf.to_yaml(cfg))
    pl_logger.info('The final experiment setup is dumped as: ./cfg.yaml')

    pl.seed_everything(cfg.seed, workers=True)

    # Create model
    net = load_obj(cfg.model.class_name,
                   'torchvision.models')(**cfg.model.params)
    pl_logger.info(
        f'Create model "{type(net)}". You can view its graph using TensorBoard.'
    )

    # Inject quantizers into the model
    net = nz.quantizer_inject(net, cfg.quan)
    quan_cnt, quan_dict = nz.quantizer_stat(net)
    msg = f'Inject {quan_cnt} quantizers into the model:'
    for k, v in quan_dict.items():
        msg += f'\n                {k} = {len(v)}'
    yaml.safe_dump(quan_dict, open(output_dir / 'quan_stat.yaml', 'w'))
    pl_logger.info(msg)
    pl_logger.info(
        'A complete list of injected quantizers is dumped as: ./quan_stat.yaml'
    )

    # Prepare the dataset
    dm = apputil.get_datamodule(cfg)
    pl_logger.info(
        f'Prepare the "{cfg.dataset.name}" dataset from: {cfg.dataset.data_dir}'
    )
    msg = f'The dataset samples are split into three sets:' \
          f'\n         Train = {len(dm.train_dataloader())} batches (batch size = {dm.train_dataloader().batch_size})' \
          f'\n           Val = {len(dm.val_dataloader())} batches (batch size = {dm.val_dataloader().batch_size})' \
          f'\n          Test = {len(dm.test_dataloader())} batches (batch size = {dm.test_dataloader().batch_size})'
    pl_logger.info(msg)

    progressbar_cb = apputil.ProgressBar(pl_logger)
    # gpu_stats_cb = pl.callbacks.GPUStatsMonitor()

    if cfg.checkpoint.path:
        assert Path(cfg.checkpoint.path).is_file(), \
            f'Checkpoint path is not a file: {cfg.checkpoint.path}'
        pl_logger.info(
            f'Resume training checkpoint from: {cfg.checkpoint.path}')

    if cfg.eval:
        pl_logger.info('Training process skipped. Evaluate the resumed model.')
        assert cfg.checkpoint.path is not None, 'Try to evaluate the model resumed from the checkpoint, but got None'

        # Initialize the Trainer
        trainer = pl.Trainer(callbacks=[progressbar_cb], **cfg.trainer)
        pl_logger.info(
            f'The model is distributed to {trainer.num_gpus} GPUs with {cfg.trainer.accelerator} backend.'
        )

        pretrained_lit = LitModuleWrapper.load_from_checkpoint(
            checkpoint_path=cfg.checkpoint.path, model=net, cfg=cfg)
        trainer.test(pretrained_lit, datamodule=dm, verbose=False)
    else:  # train + eval
        tb_logger = TensorBoardLogger(output_dir / 'tb_runs',
                                      name=cfg.experiment_name,
                                      log_graph=True)
        pl_logger.info('Tensorboard logger initialized in: ./tb_runs')

        lr_monitor_cb = pl.callbacks.LearningRateMonitor()
        checkpoint_cb = pl.callbacks.ModelCheckpoint(
            dirpath=output_dir / 'checkpoints',
            filename='{epoch}-{val_loss_epoch:.4f}-{val_acc_epoch:.4f}',
            monitor='val_loss_epoch',
            mode='min',
            save_top_k=3,
            save_last=True)
        pl_logger.info(
            'Checkpoints of the best 3 models as well as the last one will be saved to: ./checkpoints'
        )

        # Wrap model with LightningModule
        lit = LitModuleWrapper(net, cfg)
        # A fake input array for TensorBoard to generate graph
        lit.example_input_array = t.rand(dm.size()).unsqueeze(dim=0)

        # Initialize the Trainer
        trainer = pl.Trainer(
            logger=[tb_logger],
            callbacks=[checkpoint_cb, lr_monitor_cb, progressbar_cb],
            resume_from_checkpoint=cfg.checkpoint.path,
            plugins=DDPPlugin(find_unused_parameters=False),
            **cfg.trainer)
        pl_logger.info(
            f'The model is distributed to {trainer.num_gpus} GPUs with {cfg.trainer.accelerator} backend.'
        )

        pl_logger.info('Training process begins.')
        trainer.fit(model=lit, datamodule=dm)

        pl_logger.info('Evaluate the best trained model.')
        trainer.test(datamodule=dm, ckpt_path='best', verbose=False)

    pl_logger.info('Program completed successfully. Exiting...')
    pl_logger.info(
        'If you have any questions or suggestions, please visit: github.com/zhutmost/neuralzip'
    )
Example #21
def main(conf):
    train_set = PodcastMixDataloader(
        csv_dir=conf["data"]["train_dir"],
        sample_rate=conf["data"]["sample_rate"],
        original_sample_rate=conf["data"]["original_sample_rate"],
        segment=conf["data"]["segment"],
        shuffle_tracks=True,
        multi_speakers=conf["training"]["multi_speakers"])
    val_set = PodcastMixDataloader(
        csv_dir=conf["data"]["valid_dir"],
        sample_rate=conf["data"]["sample_rate"],
        original_sample_rate=conf["data"]["original_sample_rate"],
        segment=conf["data"]["segment"],
        shuffle_tracks=True,
        multi_speakers=conf["training"]["multi_speakers"])
    train_loader = DataLoader(train_set,
                              shuffle=True,
                              batch_size=conf["training"]["batch_size"],
                              num_workers=conf["training"]["num_workers"],
                              drop_last=True,
                              pin_memory=True)
    val_loader = DataLoader(val_set,
                            shuffle=False,
                            batch_size=conf["training"]["batch_size"],
                            num_workers=conf["training"]["num_workers"],
                            drop_last=True,
                            pin_memory=True)

    if (conf["model"]["name"] == "ConvTasNet"):
        sys.path.append('ConvTasNet_model')
        from conv_tasnet_norm import ConvTasNetNorm
        conf["masknet"].update({"n_src": conf["data"]["n_src"]})
        model = ConvTasNetNorm(**conf["filterbank"],
                               **conf["masknet"],
                               sample_rate=conf["data"]["sample_rate"])
        loss_func = LogL2Time()
        plugins = None
    elif (conf["model"]["name"] == "UNet"):
        # UNet with logl2 time loss and normalization inside model
        sys.path.append('UNet_model')
        from unet_model import UNet
        model = UNet(conf["data"]["sample_rate"], conf["data"]["fft_size"],
                     conf["data"]["hop_size"], conf["data"]["window_size"],
                     conf["convolution"]["kernel_size"],
                     conf["convolution"]["stride"])
        loss_func = LogL2Time()
        plugins = DDPPlugin(find_unused_parameters=False)
    optimizer = make_optimizer(model.parameters(), **conf["optim"])
    if conf["training"]["half_lr"]:
        scheduler = ReduceLROnPlateau(optimizer=optimizer,
                                      factor=0.5,
                                      patience=5)

    # Just after instantiating, save the args. Easy loading in the future.
    exp_dir = conf["model"]["name"] + "_model/" + conf["main_args"]["exp_dir"]
    os.makedirs(exp_dir, exist_ok=True)
    conf_path = os.path.join(exp_dir, "conf.yml")
    with open(conf_path, "w") as outfile:
        yaml.safe_dump(conf, outfile)

    system = System(model=model,
                    loss_func=loss_func,
                    optimizer=optimizer,
                    train_loader=train_loader,
                    val_loader=val_loader,
                    scheduler=scheduler,
                    config=conf)

    # Define callbacks
    callbacks = []
    checkpoint_dir = os.path.join(exp_dir, "checkpoints/")
    checkpoint = ModelCheckpoint(checkpoint_dir,
                                 monitor="val_loss",
                                 mode="min",
                                 save_top_k=5,
                                 verbose=True)
    callbacks.append(checkpoint)
    if conf["training"]["early_stop"]:
        callbacks.append(
            EarlyStopping(monitor="val_loss",
                          mode="min",
                          patience=100,
                          verbose=True))

    # Don't ask GPU if they are not available.
    gpus = -1 if torch.cuda.is_available() else None
    distributed_backend = "ddp" if torch.cuda.is_available() else None
    trainer = pl.Trainer(
        max_epochs=conf["training"]["epochs"],
        callbacks=callbacks,
        default_root_dir=exp_dir,
        gpus=gpus,
        distributed_backend=distributed_backend,
        gradient_clip_val=5.0,
        resume_from_checkpoint=conf["main_args"]["resume_from"],
        precision=32,
        plugins=plugins)
    trainer.fit(system)

    best_k = {k: v.item() for k, v in checkpoint.best_k_models.items()}
    with open(os.path.join(exp_dir, "best_k_models.json"), "w") as f:
        print(best_k)
        json.dump(best_k, f, indent=0)
    print(checkpoint.best_model_path)
    state_dict = torch.load(checkpoint.best_model_path)
    system.load_state_dict(state_dict=state_dict["state_dict"])
    system.cpu()

    to_save = system.model.serialize()
    to_save.update(train_set.get_infos())
    torch.save(to_save, os.path.join(exp_dir, "best_model.pth"))