Example #1
def get_trainer(args: argparse.Namespace, config: DictConfig) -> Trainer:

    # amp
    precision = 16 if args.amp is not None else 32

    # logger
    if not args.debug:
        os.makedirs(config.OUTPUT_DIR, exist_ok=True)
        w_logger = WandbLogger(project=WANDB_PJ_NAME,
                               save_dir=config.OUTPUT_DIR,
                               name=config.OUTPUT_DIR)
        w_logger.log_hyperparams(OmegaConf.to_container(config))
    else:
        w_logger = False

    # checkpoint
    ckpt_callback = ModelCheckpoint(
        filename='{epoch:03d}-{rmse:.3f}-{delta1:.3f}',
        save_top_k=1,
        monitor='delta1',
        mode='max')

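    # Build the Trainer from CLI args (GPUs, AMP, profiler, debug) and the experiment config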
    return Trainer(
        max_epochs=config.SOLVER.EPOCH,
        callbacks=ckpt_callback,
        resume_from_checkpoint=args.resume,
        default_root_dir=config.OUTPUT_DIR,
        gpus=get_gpus(args),
        precision=precision,
        amp_level=args.amp,
        profiler=args.profiler,
        logger=w_logger,
        fast_dev_run=args.debug,
    )
Example #2
def main():
    import os
    import sys
    import wandb

    from pathlib import Path
    from config import setSeed, getConfig
    from main.curl import CURL
    import pytorch_lightning as pl
    from pytorch_lightning.loggers import WandbLogger

    from IPython import embed

    run = wandb.init()
    conf = getConfig(sys.argv[1])

    conf = update_custom(conf, run.config)

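    # Create a W&B logger for this sweep run; the merged config is logged below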
    # NOTE: alg (the algorithm name used as a tag) is assumed to be defined elsewhere in this module
    wandb_logger = WandbLogger(project='mineRL',
                               name=run.name,
                               tags=[alg, 'sweep'])

    wandb_logger.log_hyperparams(conf)

    curl = CURL(conf)

    trainer = pl.Trainer(gpus=1,
                         max_epochs=conf['epochs'],
                         progress_bar_refresh_rate=20,
                         weights_summary='full',
                         logger=wandb_logger,
                         default_root_dir=f"./results/{conf['experiment']}")

    trainer.fit(curl)
Example #3
def main(args):

    wand_logger = WandbLogger(offline=False,
                              project='Transformer',
                              save_dir='./lightning_logs/')
    wand_logger.log_hyperparams(params=args)

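    # Keep the two best checkpoints ranked by validation loss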
    checkpoint = ModelCheckpoint(
        filepath='./lightning_logs/checkpoints/checkpoints',
        monitor='val_loss',
        verbose=0,
        save_top_k=2)

    model = TransformerModel(**vars(args))
    trainer = Trainer(
        logger=wand_logger,
        early_stop_callback=False,
        checkpoint_callback=checkpoint,
        # fast_dev_run=True,
        # overfit_pct=0.03,
        # profiler=True,
        auto_lr_find=False,
        # val_check_interval=1.0,
        # log_save_interval=50000,
        # row_log_interval=50000,
        max_epochs=args.epochs,
        min_epochs=1,
    )
    # lr_finder = trainer.lr_find(model)
    # print(lr_finder.results)
    trainer.fit(model)
Example #4
def cli_main():
    parser = argparse.ArgumentParser()

    # trainer args
    parser = pl.Trainer.add_argparse_args(parser)

    # model args
    parser = SLExperiment.add_model_specific_args(parser)
    args = parser.parse_args()

    model = SLExperiment(**args.__dict__)

    if args.resume:
        model.resume(args.resume)

    logger = None
    callbacks = []
    if not args.fast_dev_run:
        logger = WandbLogger(project="argumentation",
                             save_dir=str(config.root_dir),
                             tags=[args.tag])
        logger.log_hyperparams(args)

        # save checkpoints based on avg_reward
        checkpoint_callback = ModelCheckpoint(
            dirpath=logger.experiment.dir,
            save_top_k=1,
            monitor="validation/loss",
            mode="min",
            save_weights_only=True,
            verbose=True,
        )
        callbacks.append(checkpoint_callback)

        if args.tag:
            tag_checkpoint_callback = copy.deepcopy(checkpoint_callback)
            tag_checkpoint_callback.dirpath = model.model_dir
            tag_checkpoint_callback.filename = model.model_name
            callbacks.append(tag_checkpoint_callback)

    # early stopping
    if args.patience:
        early_stop_callback = EarlyStopping(monitor="validation/loss",
                                            patience=args.patience,
                                            mode="min",
                                            verbose=True)
        callbacks.append(early_stop_callback)

    pl.seed_everything(123)
    trainer = pl.Trainer.from_argparse_args(
        args,
        logger=logger,
        callbacks=callbacks,
        track_grad_norm=2,
    )

    if args.train_ds and args.val_ds:
        trainer.fit(model)
    if args.test_ds:
        trainer.test(model)
Example #5
File: train.py Project: oleges1/kws
def train(config):
    fix_seeds(seed=config.train.seed)

    crnn = CRNNEncoder(
        in_channels=config.model.get('in_channels', 42),
        hidden_size=config.model.get('hidden_size', 16),
        dropout=config.model.get('dropout', 0.1),
        cnn_layers=config.model.get('cnn_layers', 2),
        rnn_layers=config.model.get('rnn_layers', 2),
        kernel_size=config.model.get('kernel_size', 9)
    )
    model = AttentionNet(
        crnn,
        hidden_size=config.model.get('hidden_size', 16),
        num_classes=config.model.get('num_classes', 3)
    )
    pl_model = KWSModel(
        model, lr=config.train.get('lr', 4e-5),
        in_channels=config.model.get('in_channels', 42),
        batch_size=config.train.get('batch_size', 32)
    )
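    # W&B logger: log the full config as hyperparameters and watch model weights/gradients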
    wandb_logger = WandbLogger(name=config.train.get('experiment_name', 'final_run'),
                               project='kws-attention',
                               log_model=True)
    wandb_logger.log_hyperparams(config)
    wandb_logger.watch(model, log='all', log_freq=100)
    trainer = pl.Trainer(max_epochs=config.train.get('max_epochs', 15),
                         logger=wandb_logger,
                         gpus=config.train.get('gpus', 1))
    trainer.fit(pl_model)
Example #6
def test_wandb_logger(wandb):
    """Verify that basic functionality of wandb logger works.
    Wandb doesn't work well with pytest so we have to mock it out here."""
    logger = WandbLogger(anonymous=True, offline=True)

    logger.log_metrics({'acc': 1.0})
    wandb.init().log.assert_called_once_with({'acc': 1.0}, step=None)

    wandb.init().log.reset_mock()
    logger.log_metrics({'acc': 1.0}, step=3)
    wandb.init().log.assert_called_once_with({'acc': 1.0}, step=3)

    # continue training on same W&B run
    wandb.init().step = 3
    logger.finalize('success')
    logger.log_metrics({'acc': 1.0}, step=3)
    wandb.init().log.assert_called_with({'acc': 1.0}, step=6)

    logger.log_hyperparams({'test': None, 'nested': {'a': 1}, 'b': [2, 3, 4]})
    wandb.init().config.update.assert_called_once_with(
        {
            'test': 'None',
            'nested/a': 1,
            'b': [2, 3, 4]
        },
        allow_val_change=True,
    )

    logger.watch('model', 'log', 10)
    wandb.init().watch.assert_called_once_with('model', log='log', log_freq=10)

    assert logger.name == wandb.init().project_name()
    assert logger.version == wandb.init().id
Example #7
def train_classifier(logging=False, train=True):

    hparams = {
        'gpus': [1],
        'max_epochs': 25,
        'num_classes': 700,
        'feature_dimension': 512,
        'model_dimension': 1024,
        'pretrained_text': False,
        'num_modalities': 1,
        'batch_size': 32,
        'learning_rate': 1e-3,
        'model_path':
        "/home/sgurram/Projects/aai/aai/experimental/sgurram/lava/src/wandb/run-20210626_215155-yqwe58z7/files/lava/yqwe58z7/checkpoints/epoch=6-step=12529.ckpt",
        'model_descriptor': 'lava timesformer 1/3 kinetics data, unshuffled',
        'accumulate_grad_batches': 2,
        'overfit_batches': 0,
        'type_modalities': 'av',
        'modality_fusion': 'concat',
        'loss_functions': ['cross_entropy'],
        'metrics': None,
        'optimizer': 'adam',
        'scheduler': 'n/a',
        'profiler': 'simple',
        'default_root_dir': '/home/sgurram/Desktop/video_lava_classifer',
    }

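    # Instantiate the evaluation LightningModule from the hyperparameters above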
    model = EvalLightning(
        num_classes=hparams['num_classes'],
        feature_dimension=hparams['feature_dimension'],
        model_dimension=hparams['model_dimension'],
        num_modalities=hparams['num_modalities'],
        batch_size=hparams['batch_size'],
        learning_rate=hparams['learning_rate'],
        model_path=hparams['model_path'],
        model=LAVALightning,
        pretrained_text=hparams['pretrained_text'],
    )

    if logging:
        wandb_logger = WandbLogger(name='run', project='lava')
        wandb_logger.log_hyperparams(hparams)
        wandb_logger.watch(model, log='gradients', log_freq=10)
    else:
        wandb_logger = None

    if not train:
        return model

    trainer = pl.Trainer(
        default_root_dir=hparams['default_root_dir'],
        gpus=hparams['gpus'],
        max_epochs=hparams['max_epochs'],
        accumulate_grad_batches=hparams['accumulate_grad_batches'],
        overfit_batches=hparams['overfit_batches'],
        logger=wandb_logger,
        profiler=hparams['profiler'])

    trainer.fit(model)
Example #8
def main() -> None:
    torch.manual_seed(config_dict["seed"])
    np.random.seed(config_dict["seed"])
    random.seed(config_dict["seed"])  # not sure if actually used
    np.random.seed(config_dict["seed"])

    config = wandb.config
    environment_config = config.env
    hparams = config.hps

    # TODO Hotfix because wandb doesn't support sweeps.
    if "lr" in config:
        hparams["lr"] = config.lr
        hparams["gamma"] = config.gamma

    logging.warning("CONFIG CHECK FOR SWEEP")
    logging.warning(hparams['lr'])  # TODO: left off here; make the sweep work (something with imports)
    logging.warning(hparams['gamma'])

    experiment_name = "dqn_onehot_few_warehouses_bigmreward_allvalid"
    wandb_logger = WandbLogger(
        project="rl_warehouse_assignment",
        name=experiment_name,
        tags=[
            "debug"
            # "experiment"
        ],
        log_model=False,  # TODO: set this to True if you need the checkpoint models at some point
    )

    wandb_logger.log_hyperparams(dict(config))

    environment_parameters = network_flow_env_builder.build_network_flow_env_parameters(
        environment_config, hparams["episode_length"], order_gen="biased")

    model = DQNLightningOneHot(hparams, environment_parameters)

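    # max_epochs is derived from the episode count and the replay buffer size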
    trainer = pl.Trainer(
        max_epochs=hparams["max_episodes"] * hparams["replay_size"],
        early_stop_callback=False,
        val_check_interval=100,
        logger=wandb_logger,
        log_save_interval=1,
        row_log_interval=1,  # the default of this may leave info behind.
        callbacks=[
            MyPrintingCallback(),
            ShippingFacilityEnvironmentStorageCallback(
                experiment_name,
                base="data/results/",
                experiment_uploader=WandbDataUploader(),
            ),
        ],
    )

    trainer.fit(model)
Example #9
def setup_wandb():
    wandb.login()
    experiment = wandb.init(project="real-text",
                            reinit=True,
                            notes=args.wandb_notes,
                            group=args.wandb_group)
    wandb_logger = WandbLogger(save_dir=ROOT_PATH, experiment=experiment)
    wandb_logger.log_hyperparams(args)
    return wandb_logger
Example #10
def test_wandb_logger_init(wandb, recwarn):
    """Verify that basic functionality of wandb logger works.
    Wandb doesn't work well with pytest so we have to mock it out here."""

    # test wandb.init called when there is no W&B run
    wandb.run = None
    logger = WandbLogger()
    logger.log_metrics({'acc': 1.0})
    wandb.init.assert_called_once()
    wandb.init().log.assert_called_once_with({'acc': 1.0}, step=None)

    # mock wandb step
    wandb.init().step = 0

    # test wandb.init not called if there is a W&B run
    wandb.init().log.reset_mock()
    wandb.init.reset_mock()
    wandb.run = wandb.init()
    logger = WandbLogger()
    logger.log_metrics({'acc': 1.0}, step=3)
    wandb.init.assert_called_once()
    wandb.init().log.assert_called_once_with({'acc': 1.0}, step=3)

    # continue training on same W&B run and offset step
    wandb.init().step = 3
    logger.finalize('success')
    logger.log_metrics({'acc': 1.0}, step=3)
    wandb.init().log.assert_called_with({'acc': 1.0}, step=6)

    # log hyper parameters
    logger.log_hyperparams({'test': None, 'nested': {'a': 1}, 'b': [2, 3, 4]})
    wandb.init().config.update.assert_called_once_with(
        {
            'test': 'None',
            'nested/a': 1,
            'b': [2, 3, 4]
        },
        allow_val_change=True,
    )

    # watch a model
    logger.watch('model', 'log', 10)
    wandb.init().watch.assert_called_once_with('model', log='log', log_freq=10)

    # verify warning for logging at a previous step
    assert 'Trying to log at a previous step' not in get_warnings(recwarn)
    # current step from wandb should be 6 (last logged step)
    logger.experiment.step = 6
    # logging at step 2 should raise a warning (step_offset is still 3)
    logger.log_metrics({'acc': 1.0}, step=2)
    assert 'Trying to log at a previous step' in get_warnings(recwarn)
    # logging again at step 2 should not display again the same warning
    logger.log_metrics({'acc': 1.0}, step=2)
    assert 'Trying to log at a previous step' not in get_warnings(recwarn)

    assert logger.name == wandb.init().project_name()
    assert logger.version == wandb.init().id
Example #11
def test_wandb_logger_init(wandb):
    """Verify that basic functionality of wandb logger works.

    Wandb doesn't work well with pytest so we have to mock it out here.
    """

    # test wandb.init called when there is no W&B run
    wandb.run = None
    logger = WandbLogger(
        name="test_name", save_dir="test_save_dir", version="test_id", project="test_project", resume="never"
    )
    logger.log_metrics({"acc": 1.0})
    wandb.init.assert_called_once_with(
        name="test_name", dir="test_save_dir", id="test_id", project="test_project", resume="never", anonymous=None
    )
    wandb.init().log.assert_called_once_with({"acc": 1.0})

    # test wandb.init and setting logger experiment externally
    wandb.run = None
    run = wandb.init()
    logger = WandbLogger(experiment=run)
    assert logger.experiment

    # test wandb.init not called if there is a W&B run
    wandb.init().log.reset_mock()
    wandb.init.reset_mock()
    wandb.run = wandb.init()
    logger = WandbLogger()

    # verify default resume value
    assert logger._wandb_init["resume"] == "allow"

    with pytest.warns(UserWarning, match="There is a wandb run already in progress"):
        _ = logger.experiment

    logger.log_metrics({"acc": 1.0}, step=3)
    wandb.init.assert_called_once()
    wandb.init().log.assert_called_once_with({"acc": 1.0, "trainer/global_step": 3})

    # continue training on same W&B run and offset step
    logger.finalize("success")
    logger.log_metrics({"acc": 1.0}, step=6)
    wandb.init().log.assert_called_with({"acc": 1.0, "trainer/global_step": 6})

    # log hyper parameters
    logger.log_hyperparams({"test": None, "nested": {"a": 1}, "b": [2, 3, 4]})
    wandb.init().config.update.assert_called_once_with(
        {"test": "None", "nested/a": 1, "b": [2, 3, 4]}, allow_val_change=True
    )

    # watch a model
    logger.watch("model", "log", 10, False)
    wandb.init().watch.assert_called_once_with("model", log="log", log_freq=10, log_graph=False)

    assert logger.name == wandb.init().project_name()
    assert logger.version == wandb.init().id
Example #12
def test_wandb_logger_init(wandb):
    """Verify that basic functionality of wandb logger works.
    Wandb doesn't work well with pytest so we have to mock it out here."""

    # test wandb.init called when there is no W&B run
    wandb.run = None
    logger = WandbLogger(
        name='test_name', save_dir='test_save_dir', version='test_id', project='test_project', resume='never'
    )
    logger.log_metrics({'acc': 1.0})
    wandb.init.assert_called_once_with(
        name='test_name', dir='test_save_dir', id='test_id', project='test_project', resume='never', anonymous=None
    )
    wandb.init().log.assert_called_once_with({'acc': 1.0})

    # test wandb.init and setting logger experiment externally
    wandb.run = None
    run = wandb.init()
    logger = WandbLogger(experiment=run)
    assert logger.experiment

    # test wandb.init not called if there is a W&B run
    wandb.init().log.reset_mock()
    wandb.init.reset_mock()
    wandb.run = wandb.init()
    logger = WandbLogger()
    # verify default resume value
    assert logger._wandb_init['resume'] == 'allow'
    logger.log_metrics({'acc': 1.0}, step=3)
    wandb.init.assert_called_once()
    wandb.init().log.assert_called_once_with({'acc': 1.0, 'trainer/global_step': 3})

    # continue training on same W&B run and offset step
    logger.finalize('success')
    logger.log_metrics({'acc': 1.0}, step=6)
    wandb.init().log.assert_called_with({'acc': 1.0, 'trainer/global_step': 6})

    # log hyper parameters
    logger.log_hyperparams({'test': None, 'nested': {'a': 1}, 'b': [2, 3, 4]})
    wandb.init().config.update.assert_called_once_with(
        {
            'test': 'None',
            'nested/a': 1,
            'b': [2, 3, 4]
        },
        allow_val_change=True,
    )

    # watch a model
    logger.watch('model', 'log', 10)
    wandb.init().watch.assert_called_once_with('model', log='log', log_freq=10)

    assert logger.name == wandb.init().project_name()
    assert logger.version == wandb.init().id
Example #13
def main(args):
    dict_args = vars(args)
    model_name = dict_args['model_name']

    if dict_args['dev_mode']:
        warn('You are in a DEVELOPMENT MODE!')

    if dict_args['gpus'] > 1:
        warn(
            'gpus and num_workers should be 1; museval is not implemented for distributed parallel'
        )
        dict_args['gpus'] = 1

    if model_name == 'cunet':
        model = CUNET_Framework(**dict_args)
    else:
        raise NotImplementedError

    if dict_args['log_system'] == 'wandb':
        logger = WandbLogger(project='source_separation',
                             tags=[model_name],
                             offline=False,
                             id=dict_args['run_id'] + 'eval')
        logger.log_hyperparams(model.hparams)

    elif dict_args['log_system'] == 'tensorboard':
        if not os.path.exists(dict_args['tensorboard_path']):
            os.mkdir(dict_args['tensorboard_path'])
        logger = pl_loggers.TensorBoardLogger(dict_args['tensorboard_path'],
                                              name=model_name)
    else:
        logger = True  # default

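    # Checkpoint path: {checkpoints_path}/{model_name}/{run_id}/{model_name}_epoch={epoch}.ckpt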
    ckpt_path = '{}/{}/{}/{}_epoch={}.ckpt'.format(
        dict_args['checkpoints_path'], dict_args['model_name'],
        dict_args['run_id'], dict_args['model_name'], dict_args['epoch'])

    assert (ckpt_path is not None)
    model = model.load_from_checkpoint(ckpt_path)

    data_provider = DataProvider(**dict_args)
    n_fft, hop_length, num_frame = [
        dict_args[key] for key in ['n_fft', 'hop_length', 'num_frame']
    ]
    test_dataloader = data_provider.get_test_dataloader(
        n_fft, hop_length, num_frame)

    trainer = Trainer(gpus=dict_args['gpus'],
                      logger=logger,
                      precision=16 if dict_args['float16'] else 32)

    trainer.test(model, test_dataloader)
Example #14
def build_experiment_trainer(config_dict, mode="experiment"):
    """
    :param config_dict: configuration dictionary for experiment
    :param mode: "experiment" or "debug"
    :return: the configured pl.Trainer
    """
    # Initialize seeds for reproducibility
    torch.manual_seed(config_dict['seed'])
    np.random.seed(config_dict['seed'])
    random.seed(config_dict['seed'])  # not sure if actually used
    np.random.seed(config_dict['seed'])

    # Initialize wandb
    run = wandb.init(config=config_dict)

    # extract config subdictionaries.
    config = wandb.config
    environment_config = config.env
    hparams = config.hps

    # Initialize PTL W&B Logger
    experiment_name = "dqn_few_warehouses_v4__demandgen_biased"
    wandb_logger = WandbLogger(
        project="rl_warehouse_assignment",
        name=experiment_name,
        tags=[
            # "debug"
            "experiment"
        ],
        log_model=True)
    wandb_logger.log_hyperparams(dict(config))

    # Initialize
    environment_parameters = network_flow_env_builder.build_network_flow_env_parameters(
        environment_config, hparams['episode_length'], order_gen='biased')

    model = DQNLightning(hparams, environment_parameters)

    trainer = pl.Trainer(
        max_epochs=hparams['max_epochs'],
        early_stop_callback=False,
        val_check_interval=100,
        logger=wandb_logger,
        log_save_interval=1,
        row_log_interval=1,  # the default of this may leave info behind.
        callbacks=[
            MyPrintingCallback(),
            ShippingFacilityEnvironmentStorageCallback(
                experiment_name,
                base="data/results/",
                experiment_uploader=WandbDataUploader())
        ])

    return trainer
Example #15
def main(arg):
    seed_everything(42)
    model = PLModel(arg)
    wandb_logger = WandbLogger(project="Bachelorarbeit", name=arg.name)
    wandb_logger.watch(model)
    wandb_logger.log_hyperparams(arg)
    trainer = Trainer(gpus=2,
                      logger=wandb_logger,
                      distributed_backend='ddp',
                      deterministic=True,
                      auto_select_gpus=True,
                      num_sanity_val_steps=0)
    trainer.fit(model)
Example #16
def main(cfg: DictConfig):
    seed_everything(42)
    logger = WandbLogger(**cfg.logger)
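    # Resolve OmegaConf interpolations before logging the config as hyperparameters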
    logger.log_hyperparams(OmegaConf.to_container(cfg, resolve=True))
    checkpoint = ModelCheckpoint(**cfg.checkpoint, dirpath=logger.save_dir)
    trainer = Trainer(
        **cfg.trainer,
        logger=logger,
        callbacks=checkpoint,
        plugins=DDPPlugin(find_unused_parameters=True)
    )
    task = instantiate(cfg.task)
    datamodule = instantiate(cfg.data)
    trainer.fit(model=task, datamodule=datamodule)
Example #17
def main(cfg: DictConfig):
    print('Train CycleGAN Model')
    cur_dir = hydra.utils.get_original_cwd()
    os.chdir(cur_dir)
    seed_everything(cfg.train.seed)

    # Init asset dir  --------------------------------------------------
    try:
        # Remove any existing checkpoint/asset folder; ignore if it does not exist
        shutil.rmtree(cfg.data.asset_dir)
    except FileNotFoundError:
        pass

    os.makedirs(cfg.data.asset_dir, exist_ok=True)

    # Logger  --------------------------------------------------
    wandb.login()
    logger = WandbLogger(project='AI-Painter', reinit=True)
    logger.log_hyperparams(dict(cfg.data))
    logger.log_hyperparams(dict(cfg.train))

    # Transforms  --------------------------------------------------
    transform = ImageTransform(cfg.data.img_size)

    # DataModule  --------------------------------------------------
    dm = DataModule(cfg, transform, phase='train')

    # Model Networks  --------------------------------------------------
    nets = {
        'G_basestyle': init_weights(CycleGAN_Unet_Generator(), init_type='normal'),
        'G_stylebase': init_weights(CycleGAN_Unet_Generator(), init_type='normal'),
        'D_base': init_weights(CycleGAN_Discriminator(), init_type='normal'),
        'D_style': init_weights(CycleGAN_Discriminator(), init_type='normal'),
    }

    # Lightning System  --------------------------------------------------
    model = CycleGAN_LightningSystem(cfg, transform, **nets)

    # Train  --------------------------------------------------
    trainer = Trainer(
        logger=logger,
        max_epochs=cfg.train.epoch,
        gpus=1,
        reload_dataloaders_every_epoch=True,
        num_sanity_val_steps=0,  # Skip Sanity Check
    )

    trainer.fit(model, datamodule=dm)
Example #18
def test_wandb_logger_init(wandb, recwarn):
    """Verify that basic functionality of wandb logger works.
    Wandb doesn't work well with pytest so we have to mock it out here."""

    # test wandb.init called when there is no W&B run
    wandb.run = None
    logger = WandbLogger()
    logger.log_metrics({'acc': 1.0})
    wandb.init.assert_called_once()
    wandb.init().log.assert_called_once_with({'acc': 1.0})

    # test wandb.init not called if there is a W&B run
    wandb.init().log.reset_mock()
    wandb.init.reset_mock()
    wandb.run = wandb.init()
    logger = WandbLogger()
    logger.log_metrics({'acc': 1.0}, step=3)
    wandb.init.assert_called_once()
    wandb.init().log.assert_called_once_with({
        'acc': 1.0,
        'trainer/global_step': 3
    })

    # continue training on same W&B run and offset step
    logger.finalize('success')
    logger.log_metrics({'acc': 1.0}, step=6)
    wandb.init().log.assert_called_with({'acc': 1.0, 'trainer/global_step': 6})

    # log hyper parameters
    logger.log_hyperparams({'test': None, 'nested': {'a': 1}, 'b': [2, 3, 4]})
    wandb.init().config.update.assert_called_once_with(
        {
            'test': 'None',
            'nested/a': 1,
            'b': [2, 3, 4]
        },
        allow_val_change=True,
    )

    # watch a model
    logger.watch('model', 'log', 10)
    wandb.init().watch.assert_called_once_with('model', log='log', log_freq=10)

    assert logger.name == wandb.init().project_name()
    assert logger.version == wandb.init().id
Example #19
def main(hparams):
    """
    Main testing routine specific for this project

    :param hparams: Namespace containing configuration values
    :type hparams: Namespace
    """

    # ------------------------
    # 1 INIT MODEL
    # ------------------------

    model = get_model(hparams)
    model.load_state_dict(torch.load(hparams.checkpoint_file)["state_dict"])
    model.eval()

    name = "-".join([hparams.model, hparams.out, "-test"])

    # ------------------------
    # LOGGING SETUP
    # ------------------------

    tb_logger = TensorBoardLogger(save_dir="logs/tb_logs/", name=name)
    tb_logger.experiment.add_graph(model, model.data[0][0].unsqueeze(0))
    wandb_logger = WandbLogger(
        name=hparams.comment if hparams.comment else time.ctime(),
        project=name,
        save_dir="logs",
    )
    wandb_logger.watch(model, log="all", log_freq=200)
    wandb_logger.log_hyperparams(model.hparams)
    for file in [
            i for s in [glob(x) for x in ["*.py", "dataloader/*.py", "model/*.py"]]
            for i in s
    ]:
        shutil.copy(file, wandb.run.dir)

    trainer = pl.Trainer(gpus=hparams.gpus,
                         logger=[wandb_logger])  # , tb_logger],

    # ------------------------
    # 3 START TESTING
    # ------------------------

    trainer.test(model)
Example #20
def main():

    with open("../lightning_modules/GNNEmbedding/train_jet_gnn.yaml") as f:
        hparams = yaml.load(f, Loader=yaml.FullLoader)

    model = LocalAttentionNodeEmbedding(hparams)
    wandb_logger = WandbLogger(project="End2End-ConnectedJetNodeEmbedding")
    wandb_logger.watch(model)
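    # Record the model class as a hyperparameter on the W&B run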
    wandb_logger.log_hyperparams({"model": type(model)})
    trainer = Trainer(
        gpus=1,
        max_epochs=hparams["max_epochs"],
        logger=wandb_logger,
        num_sanity_val_steps=0,
        accumulate_grad_batches=1,
    )

    trainer.fit(model)
Example #21
def test_wandb_logger(wandb):
    """Verify that basic functionality of wandb logger works.
    Wandb doesn't work well with pytest so we have to mock it out here."""
    logger = WandbLogger(anonymous=True, offline=True)

    logger.log_metrics({'acc': 1.0})
    wandb.init().log.assert_called_once_with({'acc': 1.0}, step=None)

    wandb.init().log.reset_mock()
    logger.log_metrics({'acc': 1.0}, step=3)
    wandb.init().log.assert_called_once_with({'acc': 1.0}, step=3)

    logger.log_hyperparams({'test': None})
    wandb.init().config.update.assert_called_once_with({'test': None}, allow_val_change=True)

    logger.watch('model', 'log', 10)
    wandb.init().watch.assert_called_once_with('model', log='log', log_freq=10)

    assert logger.name == wandb.init().project_name()
    assert logger.version == wandb.init().id
Example #22
def run(config):
    pl.seed_everything(config.seed)

    if config.logger:
        from pytorch_lightning.loggers import WandbLogger
        name = f"{config.model_name}-{config.backbone.name}-{config.dataset.src_task}-{config.dataset.tgt_task}"
        logger = WandbLogger(
            project=f"{config.project}",
            name=name,
        )
    else:
        logger = pl.loggers.TestTubeLogger("output", name=f"{config.project}")
        logger.log_hyperparams(config)

    datamodule = get_datamodule(config.dataset,
                                batch_size=config.training.batch_size)
    model = get_model(config, len(datamodule.CLASSES))
    trainer = pl.Trainer(
        precision=16,
        auto_lr_find=True if config.lr_finder else None,
        deterministic=True,
        check_val_every_n_epoch=1,
        gpus=config.gpus,
        logger=logger,
        max_epochs=config.training.epoch,
        weights_summary="top",
    )

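    # Either run the learning-rate finder or train and test, depending on the config flag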
    if config.lr_finder:
        lr_finder = trainer.tuner.lr_find(model,
                                          min_lr=1e-8,
                                          max_lr=1e-1,
                                          num_training=100)
        model.hparams.lr = lr_finder.suggestion()
        print(model.hparams.lr)
    else:
        trainer.fit(model, datamodule=datamodule)
        trainer.test()
Example #23
def test_wandb_logger(wandb):
    """Verify that basic functionality of wandb logger works.
    Wandb doesn't work well with pytest so we have to mock it out here."""
    tutils.reset_seed()

    logger = WandbLogger(anonymous=True, offline=True)

    logger.log_metrics({'acc': 1.0})
    wandb.init().log.assert_called_once_with({'acc': 1.0})

    wandb.init().log.reset_mock()
    logger.log_metrics({'acc': 1.0}, step=3)
    wandb.init().log.assert_called_once_with({'global_step': 3, 'acc': 1.0})

    logger.log_hyperparams({'test': None})
    wandb.init().config.update.assert_called_once_with({'test': None})

    logger.watch('model', 'log', 10)
    wandb.watch.assert_called_once_with('model', log='log', log_freq=10)

    logger.finalize('fail')
    wandb.join.assert_called_once_with(1)

    wandb.join.reset_mock()
    logger.finalize('success')
    wandb.join.assert_called_once_with(0)

    wandb.join.reset_mock()
    wandb.join.side_effect = TypeError
    with pytest.raises(TypeError):
        logger.finalize('any')

    wandb.join.assert_called()

    assert logger.name == wandb.init().project_name()
    assert logger.version == wandb.init().id
Example #24
def main(cfg: DictConfig):
    print('VinBigData Training Classification')
    cur_dir = hydra.utils.get_original_cwd()
    os.chdir(cur_dir)
    # Config  -------------------------------------------------------------------
    data_dir = cfg.data.data_dir
    seed_everything(cfg.data.seed)

    load_dotenv('.env')
    wandb.login()
    wandb_logger = WandbLogger(project='VinBigData-Classification', reinit=True)
    wandb_logger.log_hyperparams(dict(cfg.data))
    wandb_logger.log_hyperparams(dict(cfg.train))
    wandb_logger.log_hyperparams(dict(cfg.aug_kwargs_classification))

    # Data Module  -------------------------------------------------------------------
    transform = ImageTransform(cfg, type='classification')
    cv = StratifiedKFold(n_splits=cfg.data.n_splits)
    dm = ChestXrayDataModule(data_dir, cfg, transform, cv, data_type='classification', sample=False)

    # Model  -----------------------------------------------------------
    net = Timm_model(cfg.train.backbone, out_dim=1)

    # Loss fn  -----------------------------------------------------------
    criterion = nn.BCEWithLogitsLoss()

    # Optimizer, Scheduler  -----------------------------------------------------------
    optimizer = optim.Adam(net.parameters(), lr=cfg.train.lr)
    scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=cfg.train.epoch, eta_min=0)
    # Lightning Module
    model = XrayLightningClassification(net, cfg, criterion, optimizer, scheduler)

    # Trainer  --------------------------------------------------------------------------
    trainer = Trainer(
        logger=wandb_logger,
        log_every_n_steps=100,
        max_epochs=cfg.train.epoch,
        gpus=-1,
        num_sanity_val_steps=0,
        # deterministic=True,
        amp_level='O2',
        amp_backend='apex'
    )

    # Train
    trainer.fit(model, datamodule=dm)

    # Stop Logging
    wandb.finish()

    for p in model.weight_paths:
        os.remove(p)
Example #25
def train(hparams):
    EMBEDDING_DIM = 128
    NUM_GPUS = hparams.num_gpus
    batch_order = 11

    dataset = load_node_dataset(hparams.dataset, hparams.method, hparams=hparams, train_ratio=hparams.train_ratio)

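    # Use top-k accuracy for multilabel datasets, otherwise the ogbn-mag evaluator metric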
    METRICS = ["precision", "recall", "f1", "accuracy", "top_k" if dataset.multilabel else "ogbn-mag", ]

    if hparams.method == "HAN":
        USE_AMP = True
        model_hparams = {
            "embedding_dim": EMBEDDING_DIM,
            "batch_size": 2 ** batch_order * NUM_GPUS,
            "num_layers": 2,
            "collate_fn": "HAN_batch",
            "train_ratio": dataset.train_ratio,
            "loss_type": "BINARY_CROSS_ENTROPY" if dataset.multilabel else "SOFTMAX_CROSS_ENTROPY",
            "n_classes": dataset.n_classes,
            "lr": 0.0005 * NUM_GPUS,
        }
        model = HAN(Namespace(**model_hparams), dataset=dataset, metrics=METRICS)
    elif hparams.method == "GTN":
        USE_AMP = True
        model_hparams = {
            "embedding_dim": EMBEDDING_DIM,
            "num_channels": len(dataset.metapaths),
            "num_layers": 2,
            "batch_size": 2 ** batch_order * NUM_GPUS,
            "collate_fn": "HAN_batch",
            "train_ratio": dataset.train_ratio,
            "loss_type": "BINARY_CROSS_ENTROPY" if dataset.multilabel else "SOFTMAX_CROSS_ENTROPY",
            "n_classes": dataset.n_classes,
            "lr": 0.0005 * NUM_GPUS,
        }
        model = GTN(Namespace(**model_hparams), dataset=dataset, metrics=METRICS)
    elif hparams.method == "MetaPath2Vec":
        USE_AMP = True
        model_hparams = {
            "embedding_dim": EMBEDDING_DIM,
            "walk_length": 50,
            "context_size": 7,
            "walks_per_node": 5,
            "num_negative_samples": 5,
            "sparse": True,
            "batch_size": 400 * NUM_GPUS,
            "train_ratio": dataset.train_ratio,
            "n_classes": dataset.n_classes,
            "lr": 0.01 * NUM_GPUS,
        }
        model = MetaPath2Vec(Namespace(**model_hparams), dataset=dataset, metrics=METRICS)
    elif "LATTE" in hparams.method:
        USE_AMP = False
        num_gpus = 1

        if "-1" in hparams.method:
            t_order = 1
        elif "-2" in hparams.method:
            t_order = 2
        elif "-3" in hparams.method:
            t_order = 3
        else:
            t_order = 2

        model_hparams = {
            "embedding_dim": EMBEDDING_DIM,
            "t_order": t_order,
            "batch_size": 2 ** batch_order * max(num_gpus, 1),
            "nb_cls_dense_size": 0,
            "nb_cls_dropout": 0.4,
            "activation": "relu",
            "attn_heads": 2,
            "attn_activation": "sharpening",
            "attn_dropout": 0.2,
            "loss_type": "BCE" if dataset.multilabel else "SOFTMAX_CROSS_ENTROPY",
            "use_proximity": True if "proximity" in hparams.method else False,
            "neg_sampling_ratio": 2.0,
            "n_classes": dataset.n_classes,
            "use_class_weights": False,
            "lr": 0.001 * num_gpus,
            "momentum": 0.9,
            "weight_decay": 1e-2,
        }

        metrics = ["precision", "recall", "micro_f1",
                   "accuracy" if dataset.multilabel else "ogbn-mag", "top_k"]

        model = LATTENodeClassifier(Namespace(**model_hparams), dataset, collate_fn="neighbor_sampler", metrics=metrics)

    MAX_EPOCHS = 250
    wandb_logger = WandbLogger(name=model.name(),
                               tags=[dataset.name()],
                               project="multiplex-comparison")
    wandb_logger.log_hyperparams(model_hparams)

    trainer = Trainer(
        gpus=NUM_GPUS, auto_select_gpus=True,
        distributed_backend='dp' if NUM_GPUS > 1 else None,
        max_epochs=MAX_EPOCHS,
        callbacks=[EarlyStopping(monitor='val_loss', patience=10, min_delta=0.0001, strict=False)],
        logger=wandb_logger,
        weights_summary='top',
        amp_level='O1' if USE_AMP else None,
        precision=16 if USE_AMP else 32
    )

    # trainer.fit(model)
    trainer.fit(model, train_dataloader=model.valtrain_dataloader(), val_dataloaders=model.test_dataloader())
    trainer.test(model)
Example #26
def main(hparams):
    """
    Main training routine specific for this project

    :param hparams: Namespace containing configuration values
    :type hparams: Namespace
    """

    # ------------------------
    # 1 INIT MODEL
    # ------------------------

    # Prepare model and link it with the data
    model = get_model(hparams)

    # Categorize logging
    name = hparams.model + "-" + hparams.out

    # Callback to save checkpoint of best performing model
    checkpoint_callback = pl.callbacks.model_checkpoint.ModelCheckpoint(
        filepath=f"src/model/checkpoints/{name}/",
        monitor="val_loss",
        verbose=True,
        save_top_k=1,
        save_weights_only=False,
        mode="min",
        period=1,
        prefix="-".join([
            str(x) for x in (
                name,
                hparams.in_days,
                hparams.out_days,
                datetime.now().strftime("-%m/%d-%H:%M"),
            )
        ]),
    )

    # ------------------------
    # LOGGING SETUP
    # ------------------------

    # Enable logging only during training
    if not hparams.dry_run:
        # tb_logger = TensorBoardLogger(save_dir="logs/tb_logs/", name=name)
        # tb_logger.experiment.add_graph(model, model.data[0][0].unsqueeze(0))
        wandb_logger = WandbLogger(
            name=hparams.comment if hparams.comment else time.ctime(),
            project=name,
            save_dir="logs",
        )
        # if not hparams.test:
        #     wandb_logger.watch(model, log="all", log_freq=100)
        wandb_logger.log_hyperparams(model.hparams)
        for file in [
                i for s in [
                    glob(x) for x in
                    ["src/*.py", "src/dataloader/*.py", "src/model/*.py"]
                ] for i in s
        ]:
            shutil.copy(file, wandb.run.dir)

    # ------------------------
    # INIT TRAINER
    # ------------------------

    trainer = pl.Trainer(
        auto_lr_find=False,
        # progress_bar_refresh_rate=0,
        # Profiling the code to find bottlenecks
        # profiler=pl.profiler.AdvancedProfiler('profile'),
        max_epochs=hparams.epochs if not hparams.dry_run else 1,
        # CUDA trick to speed up training after the first epoch
        benchmark=True,
        deterministic=False,
        # Sanity checks
        # fast_dev_run=False,
        # overfit_pct=0.01,
        gpus=hparams.gpus,
        precision=16 if hparams.use_16bit and hparams.gpus else 32,
        # Alternative method for 16-bit training
        # amp_level="O2",
        logger=None if hparams.dry_run else [wandb_logger],  # , tb_logger],
        checkpoint_callback=None if hparams.dry_run else checkpoint_callback,
        # Using maximum GPU memory. NB: Learning rate should be adjusted according to
        # the batch size
        # auto_scale_batch_size='binsearch',
    )

    # ------------------------
    # LR FINDER
    # ------------------------

    if hparams.find_lr:
        # Run learning rate finder
        lr_finder = trainer.lr_find(model)

        # Results can be found in
        lr_finder.results

        # Plot with
        fig = lr_finder.plot(suggest=True)
        fig.show()

        # Pick point based on plot, or get suggestion
        new_lr = lr_finder.suggestion()

        # update hparams of the model
        model.hparams.learning_rate = new_lr

    # ------------------------
    # BATCH SIZE SEARCH
    # ------------------------

    if hparams.search_bs:
        # Invoke the batch size search using a sophisticated algorithm.
        new_batch_size = trainer.scale_batch_size(model,
                                                  mode="binary",
                                                  steps_per_trial=50,
                                                  init_val=1,
                                                  max_trials=10)

        # Override old batch size
        model.hparams.batch_size = new_batch_size

    # ------------------------
    # 3 START TRAINING
    # ------------------------

    # Interrupt training anytime and continue to test
    signal.signal(signal.SIGINT, trainer.test)

    trainer.fit(model)
    results = trainer.test()

    return results
Example #27
def train(hparams):
    EMBEDDING_DIM = 128
    USE_AMP = None
    NUM_GPUS = hparams.num_gpus
    MAX_EPOCHS = 1000
    batch_order = 11

    dataset = load_node_dataset(hparams.dataset,
                                hparams.method,
                                hparams=hparams,
                                train_ratio=hparams.train_ratio)

    METRICS = [
        "precision",
        "recall",
        "f1",
        "accuracy",
        "top_k" if dataset.multilabel else "ogbn-mag",
    ]

    if hparams.method == "HAN":
        USE_AMP = False
        model_hparams = {
            "embedding_dim": EMBEDDING_DIM,
            "batch_size": 2**batch_order,
            "num_layers": 2,
            "collate_fn": "HAN_batch",
            "train_ratio": dataset.train_ratio,
            "loss_type": "BINARY_CROSS_ENTROPY"
            if dataset.multilabel else "SOFTMAX_CROSS_ENTROPY",
            "n_classes": dataset.n_classes,
            "lr": 0.001,
        }
        model = HAN(Namespace(**model_hparams),
                    dataset=dataset,
                    metrics=METRICS)
    elif hparams.method == "GTN":
        USE_AMP = False
        model_hparams = {
            "embedding_dim": EMBEDDING_DIM,
            "num_channels": len(dataset.metapaths),
            "num_layers": 2,
            "batch_size": 2**batch_order,
            "collate_fn": "HAN_batch",
            "train_ratio": dataset.train_ratio,
            "loss_type": "BINARY_CROSS_ENTROPY"
            if dataset.multilabel else "SOFTMAX_CROSS_ENTROPY",
            "n_classes": dataset.n_classes,
            "lr": 0.001,
        }
        model = GTN(Namespace(**model_hparams),
                    dataset=dataset,
                    metrics=METRICS)

    elif hparams.method == "MetaPath2Vec":
        USE_AMP = False
        model_hparams = {
            "embedding_dim": EMBEDDING_DIM,
            "walk_length": 50,
            "context_size": 7,
            "walks_per_node": 5,
            "num_negative_samples": 5,
            "sparse": True,
            "batch_size": 400,
            "train_ratio": dataset.train_ratio,
            "n_classes": dataset.n_classes,
            "lr": 0.01,
        }
        model = MetaPath2Vec(Namespace(**model_hparams),
                             dataset=dataset,
                             metrics=METRICS)

    elif hparams.method == "HGT":
        USE_AMP = False
        model_hparams = {
            "embedding_dim": EMBEDDING_DIM,
            "num_channels": len(dataset.metapaths),
            "n_layers": 2,
            "attn_heads": 8,
            "attn_dropout": 0.2,
            "prev_norm": True,
            "last_norm": True,
            "nb_cls_dense_size": 0,
            "nb_cls_dropout": 0.0,
            "use_class_weights": False,
            "batch_size": 2**batch_order,
            "n_epoch": MAX_EPOCHS,
            "train_ratio": dataset.train_ratio,
            "loss_type":
            "BCE" if dataset.multilabel else "SOFTMAX_CROSS_ENTROPY",
            "n_classes": dataset.n_classes,
            "collate_fn": "collate_HGT_batch",
            "lr": 0.001,  # Not used here, defaults to 1e-3
        }
        model = HGT(Namespace(**model_hparams), dataset, metrics=METRICS)

    elif "LATTE" in hparams.method:
        USE_AMP = False
        num_gpus = 1

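        # Infer the number of LATTE layers from the "-1"/"-2"/"-3" suffix in the method name (default: 2)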
        if "-1" in hparams.method:
            n_layers = 1
        elif "-2" in hparams.method:
            n_layers = 2
        elif "-3" in hparams.method:
            n_layers = 3
        else:
            n_layers = 2

        model_hparams = {
            "embedding_dim": EMBEDDING_DIM,
            "layer_pooling": "concat",
            "n_layers": n_layers,
            "batch_size": 2**batch_order,
            "nb_cls_dense_size": 0,
            "nb_cls_dropout": 0.4,
            "activation": "relu",
            "dropout": 0.2,
            "attn_heads": 2,
            "attn_activation": "sharpening",
            "batchnorm": False,
            "layernorm": False,
            "edge_sampling": False,
            "edge_threshold": 0.5,
            "attn_dropout": 0.2,
            "loss_type":
            "BCE" if dataset.multilabel else "SOFTMAX_CROSS_ENTROPY",
            "use_proximity": True if "proximity" in hparams.method else False,
            "neg_sampling_ratio": 2.0,
            "n_classes": dataset.n_classes,
            "use_class_weights": False,
            "lr": 0.001,
            "momentum": 0.9,
            "weight_decay": 1e-2,
        }

        model_hparams.update(hparams.__dict__)

        metrics = [
            "precision", "recall", "micro_f1", "macro_f1",
            "accuracy" if dataset.multilabel else "ogbn-mag", "top_k"
        ]

        model = LATTENodeClf(Namespace(**model_hparams),
                             dataset,
                             collate_fn="neighbor_sampler",
                             metrics=metrics)

    MAX_EPOCHS = 250
    wandb_logger = WandbLogger(name=model.name(),
                               tags=[dataset.name()],
                               anonymous=True,
                               project="anon-demo")
    wandb_logger.log_hyperparams(model_hparams)

    trainer = Trainer(gpus=NUM_GPUS,
                      distributed_backend='dp' if NUM_GPUS > 1 else None,
                      max_epochs=MAX_EPOCHS,
                      stochastic_weight_avg=True,
                      callbacks=[
                          EarlyStopping(monitor='val_loss',
                                        patience=10,
                                        min_delta=0.0001,
                                        strict=False)
                      ],
                      logger=wandb_logger,
                      weights_summary='top',
                      amp_level='O1' if USE_AMP and NUM_GPUS > 0 else None,
                      precision=16 if USE_AMP else 32)
    trainer.fit(model)

    model.register_hooks()
    trainer.test(model)

    wandb_logger.log_metrics(
        model.clustering_metrics(n_runs=10, compare_node_types=False))
Example #28
    if args.accelerator == "TPU":
        args.global_batch_size = args.batch_size * args.n_tpu_cores

    if args.debug:
        import ptvsd
        ptvsd.enable_attach(address=('localhost', 5678), redirect_output=True)
        ptvsd.wait_for_attach()
        breakpoint()

    if args.accelerator == 'TPU':
        import torch_xla.core.xla_model as xm

    wandb.login()
    experiment = wandb.init(project="lm-finetuning", reinit=True)
    wandb_logger = WandbLogger(experiment=experiment)
    wandb_logger.log_hyperparams(args)

    early_stopping_callback = EarlyStopping(monitor='val_loss', patience=10)
    checkpoint_callback = ModelCheckpoint(wandb.run.dir, save_top_k=-1)

    model = LM(args)
    trainer = pl.Trainer(max_epochs=args.epochs,
                         accumulate_grad_batches=args.grad_steps,
                         gpus=args.n_gpus,
                         num_tpu_cores=args.n_tpu_cores,
                         precision=args.precision,
                         amp_level=args.apex_mode,
                         resume_from_checkpoint=args.checkpoint,
                         logger=wandb_logger,
                         track_grad_norm=args.track_grad_norm,
                         fast_dev_run=args.debug_run)
Example #29
import argparse
import os
import yaml
from easydict import EasyDict as edict
from tacotron2 import trainer
from waveglow import Vocoder
import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Training model.')
    parser.add_argument('--config',
                        default='tacotron2/configs/ljspeech_tacotron.yml',
                        help='path to config file')
    args = parser.parse_args()
    with open(args.config, 'r') as stream:
        config = edict(yaml.safe_load(stream))
    pl_model = getattr(trainer, config.train.trainer)(config, Vocoder=Vocoder)
    wandb_logger = WandbLogger(
        name='final_kiss',
        # project name derived from the config file name
        project=os.path.basename(args.config).split('.')[0],
        log_model=True)
    wandb_logger.log_hyperparams(config)
    wandb_logger.watch(pl_model.model, log='all', log_freq=100)
    # use a distinct name so the imported tacotron2 `trainer` module is not shadowed
    pl_trainer = pl.Trainer(logger=wandb_logger, **config.train.trainer_args)
    pl_trainer.fit(pl_model)
Example #30
def main(cfg: DictConfig):
    # instantiate Wandb Logger
    wandblogger = WandbLogger(project=cfg.general.project_name,
                              log_model=True,
                              name=cfg.training.job_name)
    # Log Hyper-parameters to Wandb
    wandblogger.log_hyperparams(cfg)

    # set random seeds so that results are reproducible
    seed_everything(cfg.training.random_seed)

    # generate a random idx for the job
    if cfg.training.unique_idx is None:
        cfg.training.unique_idx = generate_random_id()

    uq_id = cfg.training.unique_idx
    model_name = f"{cfg.training.encoder}-fold={cfg.training.fold}-{uq_id}"

    # Set up Callbacks to assist in Training
    cbs = [
        WandbTask(),
        DisableValidationBar(),
        LogInformationCallback(),
        LearningRateMonitor(logging_interval="step"),
    ]

    if cfg.training.patience is not None:
        cbs.append(
            EarlyStopping(monitor="valid/acc",
                          patience=cfg.training.patience,
                          mode="max"))

    checkpointCallback = ModelCheckpoint(
        monitor="valid/acc",
        save_top_k=1,
        mode="max",
    )
    # set up trainer kwargs
    kwds = dict(checkpoint_callback=checkpointCallback,
                callbacks=cbs,
                logger=wandblogger)

    trainer = instantiate(cfg.trainer, **kwds)

    # set up cassava image classification Task
    model = Task(cfg)

    trainer.fit(model)

    # Load the best checkpoint and save the model weights
    checkpointPath = checkpointCallback.best_model_path
    # Testing Stage
    _ = trainer.test(verbose=True, ckpt_path=checkpointPath)

    # load in the best model weights
    model = Task.load_from_checkpoint(checkpointPath)

    # create model save dir to save the weights of the
    # vanilla torch-model
    os.makedirs(cfg.general.save_dir, exist_ok=True)
    path = os.path.join(cfg.general.save_dir, f"{model_name}.pt")
    # save the weights of the model
    torch.save(model.model.state_dict(), f=path)
    # upload trained weights to wandb
    wandb.save(path)

    # save the original composed config file to wandb
    conf_path = os.path.join(cfg.general.save_dir, "cfg.yml")
    OmegaConf.save(cfg, f=conf_path)
    wandb.save(conf_path)