Example #1
    dm = WikiDataModule(tokenizer_name_or_path='wiki-test',
                        files=["wikitext-103-raw/wiki.train.raw"],
                        max_vocab_size=30000,
                        min_frequency=2,
                        special_tokens=[
                            "<s>",
                            "<pad>",
                            "</s>",
                            "<unk>",
                            "<mask>",
                        ],
                        batch_size=args.batch_size)
    dm.prepare_data()
    dm.setup('fit')
    tokenizer = dm.tokenizer
    config = RobertaConfig(
        hidden_size=768,
        intermediate_size=3072,
        num_attention_heads=12,
        num_hidden_layers=12,
        vocab_size=len(tokenizer.get_vocab()),
    )
    hf_model = RobertaForMaskedLM(config)
    model = Pretrainer(hf_model, config, tokenizer)
    trainer = pl.Trainer.from_argparse_args(args, logger=WandbLogger())
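    # Rough estimate of total optimizer steps: samples per epoch divided by the
    # effective batch size (per-GPU batch size times GPU count), divided by
    # gradient accumulation, times the number of epochs.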
    model.hparams.total_steps = (
        (len(dm.ds['train']) //
         (args.batch_size * max(1, (trainer.num_gpus or 0)))) //
        trainer.accumulate_grad_batches * float(trainer.max_epochs))
    trainer.fit(model, dm)
Example #2
def test_wandb_logger_init(wandb):
    """Verify that basic functionality of wandb logger works.

    Wandb doesn't work well with pytest so we have to mock it out here.
    """

    # test wandb.init called when there is no W&B run
    wandb.run = None
    logger = WandbLogger(
        name="test_name", save_dir="test_save_dir", version="test_id", project="test_project", resume="never"
    )
    logger.log_metrics({"acc": 1.0})
    wandb.init.assert_called_once_with(
        name="test_name", dir="test_save_dir", id="test_id", project="test_project", resume="never", anonymous=None
    )
    wandb.init().log.assert_called_once_with({"acc": 1.0})

    # test wandb.init and setting logger experiment externally
    wandb.run = None
    run = wandb.init()
    logger = WandbLogger(experiment=run)
    assert logger.experiment

    # test wandb.init not called if there is a W&B run
    wandb.init().log.reset_mock()
    wandb.init.reset_mock()
    wandb.run = wandb.init()
    logger = WandbLogger()

    # verify default resume value
    assert logger._wandb_init["resume"] == "allow"

    with pytest.warns(UserWarning, match="There is a wandb run already in progress"):
        _ = logger.experiment

    logger.log_metrics({"acc": 1.0}, step=3)
    wandb.init.assert_called_once()
    wandb.init().log.assert_called_once_with({"acc": 1.0, "trainer/global_step": 3})

    # continue training on same W&B run and offset step
    logger.finalize("success")
    logger.log_metrics({"acc": 1.0}, step=6)
    wandb.init().log.assert_called_with({"acc": 1.0, "trainer/global_step": 6})

    # log hyper parameters
    logger.log_hyperparams({"test": None, "nested": {"a": 1}, "b": [2, 3, 4]})
    wandb.init().config.update.assert_called_once_with(
        {"test": "None", "nested/a": 1, "b": [2, 3, 4]}, allow_val_change=True
    )

    # watch a model
    logger.watch("model", "log", 10, False)
    wandb.init().watch.assert_called_once_with("model", log="log", log_freq=10, log_graph=False)

    assert logger.name == wandb.init().project_name()
    assert logger.version == wandb.init().id
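# Note: the `wandb` argument above is a mock injected by the test harness. A
# minimal, self-contained sketch of how such a test can be wired up follows;
# the patch target is an assumption and may differ across pytorch_lightning
# versions.
from unittest import mock

from pytorch_lightning.loggers import WandbLogger


@mock.patch("pytorch_lightning.loggers.wandb.wandb")  # assumed patch target
def test_wandb_logger_smoke(wandb):
    wandb.run = None                  # pretend no W&B run is active yet
    logger = WandbLogger(name="test")
    logger.log_metrics({"acc": 1.0})
    wandb.init.assert_called_once()   # the logger lazily calls wandb.init()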
Example #3
def main():

    df_train = pd.read_csv('../data/train.csv')
    df_train['type'] = df_train['before_file_path'].apply(
        lambda x: 'BC' if 'BC' in x else 'LT')
    # df_train['before_file_path'] = df_train['before_file_path'].apply(lambda x: x.replace('.png', '_resize256.png'))
    # df_train['after_file_path'] = df_train['after_file_path'].apply(lambda x: x.replace('.png', '_resize256.png'))
    df_train['splits'] = df_train['before_file_path'].apply(
        lambda x: x.split('adjust/')[-1][:5]
    )  # + '_' + df_train['time_delta'].astype(str)

    train1 = df_train[df_train['type'] == 'BC'].reset_index(drop=True)
    train2 = df_train[df_train['type'] == 'LT'].reset_index(drop=True)

    # df_train = df_train[~df_train['splits'].isin(['BC_03', 'BC_04', 'LT_08', 'LT_05'])].reset_index(drop=True)
    print(df_train.splits.value_counts())

    if config.type == 'BC':
        df_train = df_train[df_train['type'] == 'BC'].reset_index(drop=True)
    elif config.type == 'LT':
        df_train = df_train[df_train['type'] == 'LT'].reset_index(drop=True)

    # skf = StratifiedKFold(n_splits=config.k, random_state=config.seed, shuffle=True)
    # n_splits = list(skf.split(df_train, df_train['splits']))

    # df_train['time_delta'] = np.log(df_train['time_delta'])

    gk = GroupKFold(n_splits=config.k)
    # n_splits = list(gk.split(df_train, y=df_train['time_delta'], groups=df_train['splits']))
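    # GroupKFold keeps every 'splits' group within a single fold, so rows from
    # the same group never appear in both the training and validation split.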
    n_splits = list(
        gk.split(train1, y=train1['time_delta'], groups=train1['splits']))
    n_splits2 = list(
        gk.split(train2, y=train2['time_delta'], groups=train2['splits']))
    train1['n_fold'] = -1
    train2['n_fold'] = -1
    for i in range(config.k):
        train1.loc[n_splits[i][1], 'n_fold'] = i
        train2.loc[n_splits2[i][1], 'n_fold'] = i
    # df_train['n_fold'] = -1
    # for i in range(config.k):
    #     df_train.loc[n_splits[i][1], 'n_fold'] = i
    # print(df_train['n_fold'].value_counts())

    for fold in config.training_folds:
        config.start_time = time.strftime('%Y-%m-%d %H:%M',
                                          time.localtime(time.time())).replace(
                                              ' ', '_')

        logger = WandbLogger(
            name=f"{config.start_time}_{config.version}_{config.k}fold_{fold}",
            project='dacon-plant',
            config={
                key: config.__dict__[key]
                for key in config.__dict__.keys() if '__' not in key
            },
        )

        tt = pd.concat([
            train1[train1['n_fold'] != fold], train2[train2['n_fold'] != fold]
        ]).reset_index(drop=True)
        vv = pd.concat([
            train1[train1['n_fold'] == fold], train2[train2['n_fold'] == fold]
        ]).reset_index(drop=True)
        # tt = df_train.loc[df_train['n_fold']!=fold].reset_index(drop=True)#.iloc[:1000]
        # vv = df_train.loc[df_train['n_fold']==fold].reset_index(drop=True)
        print(vv['splits'].value_counts())

        train_transforms = train_get_transforms()
        valid_transforms = valid_get_transforms()

        config.train_dataset = PlantDataset(config,
                                            tt,
                                            mode='train',
                                            transforms=train_transforms)
        config.valid_dataset = PlantDataset(config,
                                            vv,
                                            mode='valid',
                                            transforms=valid_transforms)

        print('train_dataset input shape, label : ',
              config.train_dataset[0]['be_img'].shape,
              config.train_dataset[0]['af_img'].shape,
              config.train_dataset[0]['label'])
        print('valid_dataset input shape, label : ',
              config.valid_dataset[0]['be_img'].shape,
              config.valid_dataset[0]['af_img'].shape,
              config.valid_dataset[0]['label'])

        lr_monitor = LearningRateMonitor(
            logging_interval='epoch')  # ['epoch', 'step']
        checkpoints = ModelCheckpoint(
            'model/' + config.version,
            save_top_k=1,
            monitor='total_val_mse',
            mode='min',
            filename=f'{config.k}fold_{fold}__' +
            '{epoch}_{total_val_loss:.4f}_{total_val_mse:.4f}')

        model = plModel(config)
        trainer = pl.Trainer(
            max_epochs=config.epochs,
            gpus=1,
            log_every_n_steps=50,
            # gradient_clip_val=1000, gradient_clip_algorithm='value', # default: [norm, value]
            # amp_backend='native', precision=16, # amp_backend default: native
            callbacks=[checkpoints, lr_monitor],
            logger=logger)

        trainer.fit(model)
        del model, trainer
        wandb.finish()
Example #4
def main(args=None, model=None) -> GenerativeQAModule:
    parser = argparse.ArgumentParser()
    parser = pl.Trainer.add_argparse_args(parser)
    parser = GenerativeQAModule.add_model_specific_args(parser, os.getcwd())
    parser = GenerativeQAModule.add_retriever_specific_args(parser)
    args = args or parser.parse_args()

    Path(args.output_dir).mkdir(exist_ok=True)
    Path(args.output_dir + "/dpr_ctx_checkpoint").mkdir(
        exist_ok=True)  # save the dpr_context encoder separately for future use
    print(args.shard_dir)
    if os.path.exists(
            args.shard_dir
    ):  # we do not need previous kb shards used in dataset re-encoding and re-indexing
        shutil.rmtree(args.shard_dir)
    Path(args.shard_dir).mkdir(exist_ok=True)

    if os.path.exists(
            args.cache_dir
    ):  # we do not need previous cache files used in dataset re-encoding and re-indexing
        shutil.rmtree(args.cache_dir)
    Path(args.cache_dir).mkdir(exist_ok=True)

    named_actors = []
    if args.distributed_retriever == "ray" and args.gpus > 1:
        if not is_ray_available():
            raise RuntimeError("Please install Ray to use the Ray "
                               "distributed retriever.")
        # Connect to an existing Ray cluster.
        try:
            ray.init(address=args.ray_address)
        except (ConnectionError, ValueError):
            logger.warning(
                "Connection to Ray cluster failed. Make sure a Ray "
                "cluster is running by either using Ray's cluster "
                "launcher (`ray up`) or by manually starting Ray on "
                "each node via `ray start --head` for the head node "
                "and `ray start --address='<ip address>:6379'` for "
                "additional nodes. See "
                "https://docs.ray.io/en/master/cluster/index.html "
                "for more info.")
            raise

        # Create Ray actors only for rank 0.
        # environment variables are strings, so cast before comparing to 0
        if ("LOCAL_RANK" not in os.environ
                or int(os.environ["LOCAL_RANK"]) == 0) and (
                    "NODE_RANK" not in os.environ
                    or int(os.environ["NODE_RANK"]) == 0):
            remote_cls = ray.remote(RayRetriever)
            named_actors = [
                remote_cls.options(
                    name="retrieval_worker_{}".format(i)).remote()
                for i in range(args.num_retrieval_workers)
            ]
        else:
            logger.info(
                "Getting named actors for NODE_RANK {}, LOCAL_RANK {}".format(
                    os.environ["NODE_RANK"], os.environ["LOCAL_RANK"]))
            named_actors = [
                ray.get_actor("retrieval_worker_{}".format(i))
                for i in range(args.num_retrieval_workers)
            ]
    args.actor_handles = named_actors
    assert args.actor_handles == named_actors

    if model is None:
        model: GenerativeQAModule = GenerativeQAModule(args)

    dataset = Path(args.data_dir).name
    if (args.logger_name == "default" or args.fast_dev_run
            or str(args.output_dir).startswith("/tmp")
            or str(args.output_dir).startswith("/var")):
        training_logger = True  # don't pollute wandb logs unnecessarily
    elif args.logger_name == "wandb":
        from pytorch_lightning.loggers import WandbLogger

        project = os.environ.get("WANDB_PROJECT", dataset)
        training_logger = WandbLogger(name=model.output_dir.name,
                                      project=project)

    elif args.logger_name == "wandb_shared":
        from pytorch_lightning.loggers import WandbLogger

        training_logger = WandbLogger(name=model.output_dir.name,
                                      project=f"hf_{dataset}")

    es_callback = (get_early_stopping_callback(model.val_metric,
                                               args.early_stopping_patience)
                   if args.early_stopping_patience >= 0 else False)

    trainer: pl.Trainer = generic_train(
        model,
        args,
        logging_callback=Seq2SeqLoggingCallback(),
        checkpoint_callback=get_checkpoint_callback(args.output_dir,
                                                    model.val_metric),
        early_stopping_callback=es_callback,
        logger=training_logger,
        profiler=pl.profiler.AdvancedProfiler() if args.profile else None,
    )

    pickle_save(model.hparams, model.output_dir / "hparams.pkl")
    if not args.do_predict:
        return model

    # test() without a model tests using the best checkpoint automatically
    trainer.test()
    return model
Example #5
def test_wandb_log_model(wandb, tmpdir):
    """Test that the logger creates the folders and files in the right place."""

    wandb.run = None
    model = BoringModel()

    # test log_model=True
    logger = WandbLogger(log_model=True)
    logger.experiment.id = "1"
    logger.experiment.project_name.return_value = "project"
    trainer = Trainer(default_root_dir=tmpdir, logger=logger, max_epochs=2, limit_train_batches=3, limit_val_batches=3)
    trainer.fit(model)
    wandb.init().log_artifact.assert_called_once()

    # test log_model='all'
    wandb.init().log_artifact.reset_mock()
    wandb.init.reset_mock()
    logger = WandbLogger(log_model="all")
    logger.experiment.id = "1"
    logger.experiment.project_name.return_value = "project"
    trainer = Trainer(default_root_dir=tmpdir, logger=logger, max_epochs=2, limit_train_batches=3, limit_val_batches=3)
    trainer.fit(model)
    assert wandb.init().log_artifact.call_count == 2

    # test log_model=False
    wandb.init().log_artifact.reset_mock()
    wandb.init.reset_mock()
    logger = WandbLogger(log_model=False)
    logger.experiment.id = "1"
    logger.experiment.project_name.return_value = "project"
    trainer = Trainer(default_root_dir=tmpdir, logger=logger, max_epochs=2, limit_train_batches=3, limit_val_batches=3)
    trainer.fit(model)
    assert not wandb.init().log_artifact.called

    # test correct metadata
    import pytorch_lightning.loggers.wandb as pl_wandb

    pl_wandb._WANDB_GREATER_EQUAL_0_10_22 = True
    wandb.init().log_artifact.reset_mock()
    wandb.init.reset_mock()
    wandb.Artifact.reset_mock()
    logger = pl_wandb.WandbLogger(log_model=True)
    logger.experiment.id = "1"
    logger.experiment.project_name.return_value = "project"
    trainer = Trainer(default_root_dir=tmpdir, logger=logger, max_epochs=2, limit_train_batches=3, limit_val_batches=3)
    trainer.fit(model)
    wandb.Artifact.assert_called_once_with(
        name="model-1",
        type="model",
        metadata={
            "score": None,
            "original_filename": "epoch=1-step=5-v3.ckpt",
            "ModelCheckpoint": {
                "monitor": None,
                "mode": "min",
                "save_last": None,
                "save_top_k": 1,
                "save_weights_only": False,
                "_every_n_train_steps": 0,
            },
        },
    )
Example #6
        super().__init__()

    def train_dataloader(self):
        return DataLoader(train,
                          batch_size=32,
                          shuffle=True,
                          num_workers=cpu_count(),
                          collate_fn=blocked_collate,
                          pin_memory=True)

    def val_dataloader(self):
        return DataLoader(val,
                          batch_size=1,
                          num_workers=cpu_count(),
                          pin_memory=True)


data = BlockedMusicDataModule()
model = BlockedPhononet(num_out=max(fcd.y) + 1)

logger = WandbLogger(project='CarnaticPhononet', name='First Attempt')
trainer = Trainer(gpus=1,
                  logger=logger,
                  max_epochs=100000,
                  num_sanity_val_steps=2,
                  deterministic=True,
                  val_check_interval=0.1,
                  auto_lr_find=False)
model.lr = 0.000912
trainer.fit(model, data)
Example #7
def test_wandb_logger(wandb):
    """Verify that basic functionality of wandb logger works.
    Wandb doesn't work well with pytest so we have to mock it out here."""
    logger = WandbLogger(anonymous=True, offline=True)

    logger.log_metrics({'acc': 1.0})
    wandb.init().log.assert_called_once_with({'acc': 1.0}, step=None)

    wandb.init().log.reset_mock()
    logger.log_metrics({'acc': 1.0}, step=3)
    wandb.init().log.assert_called_once_with({'acc': 1.0}, step=3)

    # continue training on same W&B run
    wandb.init().step = 3
    logger.finalize('success')
    logger.log_metrics({'acc': 1.0}, step=3)
    wandb.init().log.assert_called_with({'acc': 1.0}, step=6)

    logger.log_hyperparams({'test': None, 'nested': {'a': 1}, 'b': [2, 3, 4]})
    wandb.init().config.update.assert_called_once_with(
        {
            'test': 'None',
            'nested/a': 1,
            'b': [2, 3, 4]
        },
        allow_val_change=True,
    )

    logger.watch('model', 'log', 10)
    wandb.init().watch.assert_called_once_with('model', log='log', log_freq=10)

    assert logger.name == wandb.init().project_name()
    assert logger.version == wandb.init().id
Example #8
    },
    # training params
    'precision': 16,
    'max_epochs': 50,
    'val_batches': 10
}

dm = DataModule(
    file = 'data_extra' if config['extra_data'] else 'data_old', 
    **config
)


model = Resnet(config)

wandb_logger = WandbLogger(project="cassava", config=config)

es = EarlyStopping(monitor='val_acc', mode='max', patience=3)
checkpoint = ModelCheckpoint(dirpath='./', filename=f'{config["backbone"]}-{config["size"]}-{{val_acc:.5f}}', save_top_k=1, monitor='val_acc', mode='max')

trainer = pl.Trainer(
    gpus=1,
    precision=config['precision'],
    logger= wandb_logger,
    max_epochs=config['max_epochs'],
    callbacks=[es, checkpoint],
    limit_val_batches=config['val_batches']
)

trainer.fit(model, dm)
Example #9
def main(args):
    if args.seed:
        set_seed(args.seed)

    if args.mode == "abstractive":
        summarizer = AbstractiveSummarizer
    else:
        summarizer = ExtractiveSummarizer

    if args.load_weights:
        model = summarizer(hparams=args)
        checkpoint = torch.load(
            args.load_weights, map_location=lambda storage, loc: storage
        )
        model.load_state_dict(checkpoint["state_dict"], strict=args.no_strict)
    elif args.load_from_checkpoint:
        try:
            model = summarizer.load_from_checkpoint(
                args.load_from_checkpoint, strict=args.no_strict
            )
        except RuntimeError as e:
            e_str = str(e)
            if (
                "Missing key(s) in state_dict" in e_str
                or "word_embedding_model.embeddings.position_ids" in e_str
            ):
                print(
                    (
                        "The below is a common issue. Due to the `transformers` update "
                        "from 3.0.2 to 3.1.0, models trained in versions <3.0.2 need to be "
                        "loaded with the `--no_strict` argument. More details can be found at "
                        "huggingface/transformers#6882."
                    )
                )
            raise e
        
        # The model is loaded with self.hparams.data_path set to the directory where the data
        # was located during training. When loading the model, it may be desired to change
        # the data path, which the below line accomplishes.
        if args.data_path:
            model.hparams.data_path = args.data_path
        # Same as above but for `test_use_pyrouge`
        if args.test_use_pyrouge:
            model.hparams.test_use_pyrouge = args.test_use_pyrouge
    else:
        model = summarizer(hparams=args)

    # Create learning rate logger
    lr_logger = LearningRateMonitor()
    args.callbacks = [lr_logger]

    if args.use_logger == "wandb":
        wandb_logger = WandbLogger(
            project=args.wandb_project, log_model=(not args.no_wandb_logger_log_model)
        )
        args.logger = wandb_logger

    if args.use_custom_checkpoint_callback:
        args.checkpoint_callback = ModelCheckpoint(
            save_top_k=-1, period=1, verbose=True,
        )
    if args.custom_checkpoint_every_n:
        custom_checkpoint_callback = StepCheckpointCallback(
            step_interval=args.custom_checkpoint_every_n,
            save_path=args.weights_save_path,
        )
        args.callbacks.append(custom_checkpoint_callback)

    trainer = Trainer.from_argparse_args(args)

    if args.lr_find:
        lr_finder = trainer.lr_find(model)
        fig = lr_finder.plot(suggest=True)
        fig.show()
        new_lr = lr_finder.suggestion()
        logger.info("Recommended Learning Rate: %s", new_lr)

    # remove `args.callbacks` if it exists so it does not get saved with the model (would result in crash)
    if args.custom_checkpoint_every_n:
        del args.callbacks

    if args.do_train:
        trainer.fit(model)
    if args.do_test:
        trainer.test(model)
Example #10
    DataModule = CIFAR10DataModule
elif data == 'CELEBA':
    Generator = GeneratorDCGAN_CELEBA
    Discriminator = DiscriminatorDCGAN_CELEBA
    DataModule = CelebaDataModule
    path_pytorch_data = path_pytorch_data / 'celeba'

dm = DataModule(path_pytorch_data)
latent_dim = 100
img_shape = dm.size()

generator = Generator(latent_dim=latent_dim, img_shape=img_shape)
discriminator = Discriminator(img_shape=img_shape)
model = GAN(*dm.size(), latent_dim=latent_dim, generator=generator, discriminator=discriminator)

logger = WandbLogger(project='gan_celeba_2')

dm.prepare_data()
dm.setup()
dataloader = dm.train_dataloader()
real_batch = next(iter(dataloader))

real_images = np.transpose(vutils.make_grid(real_batch[0][:6], padding=2, normalize=True).numpy(), (1, 2, 0))
logger.experiment.log({'real_sample': [wandb.Image(real_images, caption='Real Images')]})

gpus = 0
if torch.cuda.is_available():
    print('GPU Available')
    device = 'cuda:0'
    gpus = 1
else:
Example #11
else:
    raise NotImplementedError(f"Dataset {conf.dataset.name} is not supported!")

dm = DMLDataModule(
    name=DataSetType.name,
    DataSetType=DataSetType,
    root=conf.dataset.root,
    classes=classes,
    eval_classes=eval_classes,
    batch_size=conf.model.batch_size,
    train_transform=make_transform_inception_v3(augment=True),
    eval_transform=make_transform_inception_v3(augment=False),
)

wandb_logger = WandbLogger(name=dm.name,
                           project=conf.experiment.name,
                           save_dir="/mnt/vol_b/models/few-shot")
dm.setup(project_name=conf.experiment.name)

model = DML(
    val_dataset=dm.val_dataset,
    num_classes=dm.num_classes,
    pooling=conf.model.pooling,
    pretrained=conf.model.pretrained,
    lr_backbone=conf.model.lr_backbone,
    weight_decay_backbone=conf.model.weight_decay_backbone,
    lr_embedding=conf.model.lr_embedding,
    weight_decay_embedding=conf.model.weight_decay_embedding,
    lr=conf.model.lr,
    weight_decay_proxynca=conf.model.weight_decay_proxynca,
    dataloader=dm.train_dataloader(),
Example #12
def main(args):
    if args.seed:
        set_seed(args.seed)

    if args.mode == "abstractive":
        summarizer = AbstractiveSummarizer
    else:
        summarizer = ExtractiveSummarizer

    if args.load_weights:
        model = summarizer(hparams=args)
        checkpoint = torch.load(args.load_weights,
                                map_location=lambda storage, loc: storage)
        model.load_state_dict(checkpoint["state_dict"])
    elif args.load_from_checkpoint:
        model = summarizer.load_from_checkpoint(args.load_from_checkpoint)
        # The model is loaded with self.hparams.data_path set to the directory where the data
        # was located during training. When loading the model, it may be desired to change
        # the data path, which the below line accomplishes.
        if args.data_path:
            model.hparams.data_path = args.data_path
        # Same as above but for `test_use_pyrouge`
        if args.test_use_pyrouge:
            model.hparams.test_use_pyrouge = args.test_use_pyrouge
    else:
        model = summarizer(hparams=args)

    # Create learning rate logger
    lr_logger = LearningRateMonitor()
    args.callbacks = [lr_logger]

    if args.use_logger == "wandb":
        wandb_logger = WandbLogger(
            project=args.wandb_project,
            log_model=(not args.no_wandb_logger_log_model))
        args.logger = wandb_logger

    if args.use_custom_checkpoint_callback:
        args.checkpoint_callback = ModelCheckpoint(
            save_top_k=-1,
            period=1,
            verbose=True,
        )
    if args.custom_checkpoint_every_n:
        custom_checkpoint_callback = StepCheckpointCallback(
            step_interval=args.custom_checkpoint_every_n,
            save_path=args.weights_save_path,
        )
        args.callbacks.append(custom_checkpoint_callback)

    trainer = Trainer.from_argparse_args(args)

    if args.lr_find:
        lr_finder = trainer.lr_find(model)
        fig = lr_finder.plot(suggest=True)
        fig.show()
        new_lr = lr_finder.suggestion()
        logger.info("Recommended Learning Rate: %s", new_lr)

    # remove `args.callbacks` if it exists so it does not get saved with the model (would result in crash)
    if args.custom_checkpoint_every_n:
        del args.callbacks

    if args.do_train:
        trainer.fit(model)
    if args.do_test:
        trainer.test(model)
Example #13
        return [optim], [scheduler]


# %%
# Train the MoCo model
# ---------------
#
# We can instantiate the model and train it using the
# lightning trainer.
from pytorch_lightning.loggers import WandbLogger
import datetime
# use a GPU if available
gpus = 1 if torch.cuda.is_available() else 0
wandb_logger = WandbLogger(
    project='vit-self-train',
    entity='dcastf01',
    name="mocogeneral" + datetime.datetime.utcnow().strftime("%Y-%m-%d %X"),
    # offline=True, #to debug
)
model = MocoModel()
trainer = pl.Trainer(logger=wandb_logger,
                     max_epochs=max_epochs,
                     gpus=gpus,
                     progress_bar_refresh_rate=100)
trainer.fit(model, dataloader_train_moco)

# %%
# Train the Classifier
model.eval()
classifier = Classifier(model.resnet_moco)
wandb_logger = WandbLogger(
    project='vit-self-train',
Example #14
                          lr=args.learning_rate,
                          total_steps=len(dataset.train_dataloader()) * epochs,
                          concat=args.inject_concat,
                          **train_p)
            checkpoint = torch.load(path_to_pretrained, map_location="cpu")
            model.load_state_dict(checkpoint)

        # Additional trainer params
        accumulate_grad_batches = args.gradient_acc_batches if args.gradient_acc_batches is not None else None

        print("Training", model_name)
        print("Initializing the trainer")
        exp = wandb.init(project="knowledgeinjection", name=model_name)
        wandb.watch(model, log="all")
        logger = WandbLogger(name=model_name,
                             project="knowledgeinjection",
                             experiment=exp)
        trainer = pl.Trainer(callbacks=callbacks,
                             gpus=args.gpus if args.use_gpu else None,
                             auto_select_gpus=args.use_gpu,
                             max_epochs=epochs,
                             val_check_interval=0.25,
                             logger=logger,
                             precision=fp16,
                             log_every_n_steps=10)
        print("Fitting...")

        trainer.fit(model, datamodule=dataset)
        print("Testing...")
        trainer.test(datamodule=dataset)
        print("Done!")
Example #15
def main(args, model=None) -> SummarizationModule:
    Path(args.output_dir).mkdir(exist_ok=True)
    if len(os.listdir(args.output_dir)) > 3 and args.do_train:
        raise ValueError(
            "Output directory ({}) already exists and is not empty.".format(args.output_dir))
    if model is None:
        if args.task == "summarization":
            model: SummarizationModule = SummarizationModule(args)
        else:
            model: SummarizationModule = TranslationModule(args)

    # add atomic relation tokens
    if args.atomic:
        print("Special tokens are added.")

        additional_tokens_list = [
            "AtLocation",
            "CapableOf",
            "Causes",
            "CausesDesire",
            "CreatedBy",
            "DefinedAs",
            "DesireOf",
            "Desires",
            "HasA",
            "HasFirstSubevent",
            "HasLastSubevent",
            "HasPainCharacter",
            "HasPainIntensity",
            "HasPrerequisite",
            "HasProperty",
            "HasSubEvent",
            "HasSubevent",
            "HinderedBy",
            "InheritsFrom",
            "InstanceOf",
            "IsA",
            "LocatedNear",
            "LocationOfAction",
            "MadeOf",
            "MadeUpOf",
            "MotivatedByGoal",
            "NotCapableOf",
            "NotDesires",
            "NotHasA",
            "NotHasProperty",
            "NotIsA",
            "NotMadeOf",
            "ObjectUse",
            "PartOf",
            "ReceivesAction",
            "RelatedTo",
            "SymbolOf",
            "UsedFor",
            "isAfter",
            "isBefore",
            "isFilledBy",
            "oEffect",
            "oReact",
            "oWant",
            "xAttr",
            "xEffect",
            "xIntent",
            "xNeed",
            "xReact",
            "xReason",
            "xWant",
        ]

        num_added_toks = model.tokenizer.add_tokens(additional_tokens_list)
        model.model.resize_token_embeddings(len(model.tokenizer))

    dataset = Path(args.data_dir).name
    if (
        args.logger_name == "default"
        or args.fast_dev_run
        or str(args.output_dir).startswith("/tmp")
        or str(args.output_dir).startswith("/var")
    ):
        logger = True  # don't pollute wandb logs unnecessarily
    elif args.logger_name == "wandb":
        from pytorch_lightning.loggers import WandbLogger

        logger = WandbLogger(name=model.output_dir.name, project=dataset)

    elif args.logger_name == "wandb_shared":
        from pytorch_lightning.loggers import WandbLogger

        logger = WandbLogger(
            name=model.output_dir.name, project=f"hf_{dataset}")

    trainer: pl.Trainer = generic_train(
        model,
        args,
        logging_callback=Seq2SeqLoggingCallback(),
        checkpoint_callback=get_checkpoint_callback(
            args.output_dir, model.val_metric),
        logger=logger,
    )
    pickle_save(model.hparams, model.output_dir / "hparams.pkl")
    if not args.do_predict:
        return model

    model.hparams.test_checkpoint = ""
    checkpoints = list(
        sorted(glob.glob(os.path.join(args.output_dir, "*.ckpt"), recursive=True)))
    if checkpoints:
        model.hparams.test_checkpoint = checkpoints[-1]
        trainer.resume_from_checkpoint = checkpoints[-1]
    trainer.logger.log_hyperparams(model.hparams)

    trainer.test(model)
    return model
Example #16
                                           pin_memory=True)

    def val_dataloader(self):
        return torch.utils.data.DataLoader(self.val_ds,
                                           batch_size=batch_size,
                                           drop_last=True,
                                           shuffle=False,
                                           num_workers=8,
                                           pin_memory=True)

    def configure_optimizers(self):
        return torch.optim.SGD(
            self.parameters(),
            lr=lr,
            momentum=momentum,
        )


model = WikitextLM()
trainer = pl.Trainer(
    default_root_dir='logs',
    gpus=(1 if torch.cuda.is_available() else 0),
    max_epochs=epochs,
    fast_dev_run=debug,
    logger=WandbLogger(save_dir='logs/',
                       name='wikitext-no-kb',
                       project='experiment-1'),
)

trainer.fit(model)
Example #17
def main(args):
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.seed)

    model = TriviaQA(args)

    logger = TestTubeLogger(
        save_dir=args.save_dir,
        name=args.save_prefix,
        #        version=0  # always use version=0
    )

    checkpoint_callback = ModelCheckpoint(
        dirpath=os.path.join(args.save_dir, args.save_prefix),
        filename="checkpoints",
        save_top_k=5,
        verbose=True,
        monitor='avg_val_loss',
        # save_last=True,
        mode='min',
        period=-1,
    )

    wandb_logger = WandbLogger(name=args.run_name, project=args.project_name)

    print(args)
    train_set_size = 110648  # hardcode dataset size. Needed to compute number of steps for the lr scheduler
    args.steps = args.epochs * train_set_size / (args.batch_size *
                                                 max(args.gpus, 1))
    print(
        f'>>>>>>> #steps: {args.steps}, #epochs: {args.epochs}, batch_size: {args.batch_size * args.gpus} <<<<<<<'
    )

    trainer = pl.Trainer(
        gpus=args.gpus,
        distributed_backend='ddp' if args.gpus and args.gpus > 1 else None,
        track_grad_norm=-1,
        max_epochs=args.epochs,
        replace_sampler_ddp=False,
        accumulate_grad_batches=args.batch_size,
        val_check_interval=args.val_every,
        num_sanity_val_steps=2,
        # check_val_every_n_epoch=2,
        limit_val_batches=args.val_percent_check,
        limit_test_batches=args.val_percent_check,
        logger=wandb_logger if not args.disable_checkpointing else False,
        checkpoint_callback=checkpoint_callback
        if not args.disable_checkpointing else False,
        amp_level='O2',
        resume_from_checkpoint=args.resume_ckpt,
    )
    if not args.test:
        trainer.fit(model)
        os.path.join(args.save_dir, args.save_prefix)
        now = datetime.now()
        save_string = now.strftime("final-model-%m%d%Y-%H:%M:%S")
        trainer.save_checkpoint(
            os.path.join(args.save_dir, args.save_prefix, save_string))
        martifact = wandb.Artifact('final_model.ckpt', type='model')
        martifact.add_file(
            os.path.join(args.save_dir, args.save_prefix, save_string))
        wandb_logger.experiment.log_artifact(martifact)

    trainer.test(model)
Example #18
def init_wandb_logger(project_config: dict,
                      run_config: dict,
                      lit_model: pl.LightningModule,
                      datamodule: pl.LightningDataModule,
                      log_path: str = "logs/") -> pl.loggers.WandbLogger:
    """Initialize Weights&Biases logger."""

    # with this line wandb will throw an error if the run to be resumed does not exist yet
    # instead of auto-creating a new run
    os.environ["WANDB_RESUME"] = "must"

    resume_from_checkpoint = run_config.get("resume_training",
                                            {}).get("resume_from_checkpoint",
                                                    None)
    wandb_run_id = run_config.get("resume_training",
                                  {}).get("wandb_run_id", None)

    wandb_logger = WandbLogger(
        project=project_config["loggers"]["wandb"]["project"],
        entity=project_config["loggers"]["wandb"]["entity"],
        log_model=project_config["loggers"]["wandb"]["log_model"],
        offline=project_config["loggers"]["wandb"]["offline"],
        group=run_config.get("wandb", {}).get("group", None),
        job_type=run_config.get("wandb", {}).get("job_type", "train"),
        tags=run_config.get("wandb", {}).get("tags", []),
        notes=run_config.get("wandb", {}).get("notes", ""),

        # resume run only if ckpt was set in the run config
        id=wandb_run_id if resume_from_checkpoint != "None"
        and wandb_run_id != "None" and resume_from_checkpoint is not None
        and resume_from_checkpoint is not False and wandb_run_id is not False
        else None,
        save_dir=log_path,
        save_code=False)

    if not os.path.exists(log_path):
        os.makedirs(log_path)

    if hasattr(lit_model, 'model'):
        wandb_logger.watch(lit_model.model, log=None)
    else:
        wandb_logger.watch(lit_model, log=None)

    wandb_logger.log_hyperparams({
        "model":
        lit_model.model.__class__.__name__,
        "optimizer":
        lit_model.configure_optimizers().__class__.__name__,
        "train_size":
        len(datamodule.data_train) if hasattr(datamodule, 'data_train')
        and datamodule.data_train is not None else 0,
        "val_size":
        len(datamodule.data_val) if hasattr(datamodule, 'data_val')
        and datamodule.data_val is not None else 0,
        "test_size":
        len(datamodule.data_test) if hasattr(datamodule, 'data_test')
        and datamodule.data_test is not None else 0,
    })
    wandb_logger.log_hyperparams(run_config["trainer"])
    wandb_logger.log_hyperparams(run_config["model"])
    wandb_logger.log_hyperparams(run_config["dataset"])

    return wandb_logger
Example #19
def train(param):
    if not isinstance(param, dict):
        args = vars(param)
    else:
        args = param

    framework = get_class_by_name('conditioned_separation', args['model'])
    if args['spec_type'] != 'magnitude':
        args['input_channels'] = 4

    if args['resume_from_checkpoint'] is None:
        if args['seed'] is not None:
            seed_everything(args['seed'])

    model = framework(**args)

    if args['last_activation'] != 'identity' and args[
            'spec_est_mode'] != 'masking':
        warn(
            'Please check if you really want to use a mapping-based spectrogram estimation method '
            'with a final activation function. ')
    ##########################################################

    # -- checkpoint
    ckpt_path = Path(args['ckpt_root_path'])
    mkdir_if_not_exists(ckpt_path)
    ckpt_path = ckpt_path.joinpath(args['model'])
    mkdir_if_not_exists(ckpt_path)
    run_id = args['run_id']
    ckpt_path = ckpt_path.joinpath(run_id)
    mkdir_if_not_exists(ckpt_path)
    save_top_k = args['save_top_k']

    checkpoint_callback = ModelCheckpoint(
        filepath=ckpt_path,
        save_top_k=save_top_k,
        verbose=False,
        monitor='val_loss',
        save_last=False,
        save_weights_only=args['save_weights_only'])
    args['checkpoint_callback'] = checkpoint_callback

    # -- early stop
    patience = args['patience']
    early_stop_callback = EarlyStopping(monitor='val_loss',
                                        min_delta=0.0,
                                        patience=patience,
                                        verbose=False)
    args['early_stop_callback'] = early_stop_callback

    if args['resume_from_checkpoint'] is not None:
        run_id = run_id + "_resume_" + args['resume_from_checkpoint']
        args['resume_from_checkpoint'] = Path(args['ckpt_root_path']).joinpath(
            args['model']).joinpath(args['run_id']).joinpath(
                args['resume_from_checkpoint'])
        args['resume_from_checkpoint'] = str(args['resume_from_checkpoint'])

    model_name = model.spec2spec.__class__.__name__

    # -- logger setting
    log = args['log']
    if log == 'False':
        args['logger'] = False
    elif log == 'wandb':
        args['logger'] = WandbLogger(project='lasaft_exp',
                                     tags=[model_name],
                                     offline=False,
                                     name=run_id)
        args['logger'].log_hyperparams(model.hparams)
        args['logger'].watch(model, log='all')
    elif log == 'tensorboard':
        raise NotImplementedError
    else:
        args['logger'] = True  # default
        default_save_path = 'etc/lightning_logs'
        mkdir_if_not_exists(default_save_path)

    # keep only the entries of `args` that are valid pl.Trainer constructor arguments
    valid_kwargs = inspect.signature(Trainer.__init__).parameters
    trainer_kwargs = dict(
        (name, args[name]) for name in valid_kwargs if name in args)

    # Trainer
    trainer = Trainer(**trainer_kwargs)
    dataset_args = {
        'musdb_root': args['musdb_root'],
        'batch_size': args['batch_size'],
        'num_workers': args['num_workers'],
        'pin_memory': args['pin_memory'],
        'num_frame': args['num_frame'],
        'hop_length': args['hop_length'],
        'n_fft': args['n_fft']
    }

    dp = DataProvider(**dataset_args)
    train_dataset, training_dataloader = dp.get_training_dataset_and_loader()
    valid_dataset, validation_dataloader = dp.get_validation_dataset_and_loader(
    )

    for key in sorted(args.keys()):
        print('{}:{}'.format(key, args[key]))

    if args['auto_lr_find']:
        lr_find = trainer.tuner.lr_find(model,
                                        training_dataloader,
                                        validation_dataloader,
                                        early_stop_threshold=None,
                                        min_lr=1e-5)

        print(f"Found lr: {lr_find.suggestion()}")
        return None

    if args['resume_from_checkpoint'] is not None:
        print('resume from the checkpoint')

    trainer.fit(model, training_dataloader, validation_dataloader)

    return None
Example #20
import torchvision.transforms.functional as TF
import torchvision.models as models
from pytorch_lightning.loggers import WandbLogger


from torch import nn
from torch import optim
from torch.utils.data import Dataset
from torch.utils.data import Subset
from torch.utils.data import DataLoader
from torch.utils.data import random_split
from torchvision import transforms
from pytorch_lightning.callbacks import ModelCheckpoint

# Base Model, Dataset, Batch Size, Learning Rate
wandb_logger = WandbLogger(name='Adjusted AttParseNet Unaligned Mult hflip 40 0.01 Mworks08', project='attparsenet', entity='unr-mpl')

activation = None

if __name__=="__main__":
    args = attparsenet_utils.get_args()

    pl.seed_everything(args.random_seed)

    # Initialize the model
    if args.model == "attparsenet":

        if args.load == True:
            net = attparsenet.AttParseNet.load_from_checkpoint(args.load_path + args.load_file, hparams=args)
        else:
            net = attparsenet.AttParseNet(args)
Example #21
def main(args, model=None) -> SummarizationModule:
    #Path(args.output_dir).mkdir(exist_ok=True)
    #if len(os.listdir(args.output_dir)) > 3 and args.do_train:
    #    raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir))

    # Model definition: SummarizationModule for summarization tasks, TranslationModule otherwise
    if model is None:
        if "summarization" in args.task:
            model: SummarizationModule = SummarizationModule(args)
        else:
            model: SummarizationModule = TranslationModule(args)
    dataset = Path(args.data_dir).name
    if (args.logger_name == "default" or args.fast_dev_run
            or str(args.output_dir).startswith("/tmp")
            or str(args.output_dir).startswith("/var")):
        logger = True  # don't pollute wandb logs unnecessarily
    elif args.logger_name == "wandb":
        from pytorch_lightning.loggers import WandbLogger

        project = os.environ.get("WANDB_PROJECT", dataset)
        logger = WandbLogger(name=model.output_dir.name, project=project)

    elif args.logger_name == "wandb_shared":
        from pytorch_lightning.loggers import WandbLogger

        logger = WandbLogger(name=model.output_dir.name,
                             project=f"hf_{dataset}")

    if args.early_stopping_patience >= 0:
        es_callback = get_early_stopping_callback(model.val_metric,
                                                  args.early_stopping_patience)
    else:
        es_callback = False

    lower_is_better = args.val_metric == "loss"
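    # checkpoint selection minimizes a loss metric and maximizes anything else (e.g. ROUGE/BLEU)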

    # Hand everything to the trainer; training itself runs inside generic_train
    trainer: pl.Trainer = generic_train(
        model,
        args,
        logging_callback=Seq2SeqLoggingCallback(),
        checkpoint_callback=get_checkpoint_callback(args.output_dir,
                                                    model.val_metric,
                                                    args.save_top_k,
                                                    lower_is_better),
        early_stopping_callback=es_callback,
        logger=logger,
    )
    pickle_save(model.hparams, model.output_dir / "hparams.pkl")
    if not args.do_predict:
        return model

    model.hparams.test_checkpoint = ""
    checkpoints = list(
        sorted(
            glob.glob(os.path.join(args.output_dir, "*.ckpt"),
                      recursive=True)))
    if checkpoints:
        model.hparams.test_checkpoint = checkpoints[-1]
        trainer.resume_from_checkpoint = checkpoints[-1]
    trainer.logger.log_hyperparams(model.hparams)

    # test() without a model tests using the best checkpoint automatically
    trainer.test()
    return model
Example #22
    if not os.path.exists(DIR):
        os.mkdir(DIR)

    # logging using Weight&Biaises
    experiment = wandb.init(project="cassava_land",
                            tags=[selected_model],
                            config=hyperparameter_defaults)

    # make sure config dict come from
    # the set of values in experiment
    config = experiment.config

    wandb_logger = WandbLogger(save_dir=f'./{TODAY}',
                               tags=[selected_model,
                                     str(config["img_size"])],
                               log_model=not DEV_MODE,
                               experiment=experiment,
                               offline=DEV_MODE)

    # checkpoint
    checkpoint = ModelCheckpoint(
        dirpath=DIR,
        filename=f'{selected_model}-' + f'subset={config["subset"]}-'
        f'img_size={config["img_size"]}-' + f'fold={FOLD}-' +
        '{train_acc:.5f}-' + '{val_acc:.5f}-' + '{epoch: 02d}-',
        monitor='val_acc',
        mode='max')

    # early stopping
    es = EarlyStopping(monitor='val_acc', patience=3, mode='max')
Example #23
def test_v1_5_0_wandb_unused_sync_step(tmpdir):
    with pytest.deprecated_call(match=r"v1.2.1 and will be removed in v1.5"):
        WandbLogger(sync_step=True)
Example #24
from pytorch_lightning.loggers import WandbLogger

if __name__ == "__main__":
    parser = get_arguments()
    opt = parser.parse_args()
    print(opt)
    # model
    model = myModel(opt)
    # data load
    train_loader, valid_loader = load_horse2zebra(opt)

    # logger
    if opt.use_wandb:
        if not opt.wandb_name == "":
            logger = WandbLogger(name=opt.wandb_name,
                                 project=opt.wandb_project,
                                 offline=opt.offline)
        else:
            logger = WandbLogger(project=opt.wandb_project,
                                 offline=opt.offline)
        logger.watch(model)
        logger.log_hyperparams(opt)

    else:
        logger = None

    # trainer
    trainer = Trainer(logger=logger,
                      gpus=1,
                      max_epochs=opt.epochs,
                      callbacks=[ImagePredictionLogger(valid_loader)])
Example #25
def test_wandb_logger_offline_log_model(wandb, tmpdir):
    """Test that log_model=True raises an error in offline mode."""
    with pytest.raises(MisconfigurationException, match="checkpoints cannot be uploaded in offline mode"):
        _ = WandbLogger(save_dir=str(tmpdir), offline=True, log_model=True)
Example #26
def train(hparams):
    EMBEDDING_DIM = 128
    NUM_GPUS = hparams.num_gpus
    batch_order = 11
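    # per-GPU batch size of 2 ** 11 = 2048; multiplied by NUM_GPUS below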

    dataset = load_node_dataset(hparams.dataset, hparams.method, hparams=hparams, train_ratio=hparams.train_ratio)

    METRICS = ["precision", "recall", "f1", "accuracy", "top_k" if dataset.multilabel else "ogbn-mag", ]

    if hparams.method == "HAN":
        USE_AMP = True
        model_hparams = {
            "embedding_dim": EMBEDDING_DIM,
            "batch_size": 2 ** batch_order * NUM_GPUS,
            "num_layers": 2,
            "collate_fn": "HAN_batch",
            "train_ratio": dataset.train_ratio,
            "loss_type": "BINARY_CROSS_ENTROPY" if dataset.multilabel else "SOFTMAX_CROSS_ENTROPY",
            "n_classes": dataset.n_classes,
            "lr": 0.0005 * NUM_GPUS,
        }
        model = HAN(Namespace(**model_hparams), dataset=dataset, metrics=METRICS)
    elif hparams.method == "GTN":
        USE_AMP = True
        model_hparams = {
            "embedding_dim": EMBEDDING_DIM,
            "num_channels": len(dataset.metapaths),
            "num_layers": 2,
            "batch_size": 2 ** batch_order * NUM_GPUS,
            "collate_fn": "HAN_batch",
            "train_ratio": dataset.train_ratio,
            "loss_type": "BINARY_CROSS_ENTROPY" if dataset.multilabel else "SOFTMAX_CROSS_ENTROPY",
            "n_classes": dataset.n_classes,
            "lr": 0.0005 * NUM_GPUS,
        }
        model = GTN(Namespace(**model_hparams), dataset=dataset, metrics=METRICS)
    elif hparams.method == "MetaPath2Vec":
        USE_AMP = True
        model_hparams = {
            "embedding_dim": EMBEDDING_DIM,
            "walk_length": 50,
            "context_size": 7,
            "walks_per_node": 5,
            "num_negative_samples": 5,
            "sparse": True,
            "batch_size": 400 * NUM_GPUS,
            "train_ratio": dataset.train_ratio,
            "n_classes": dataset.n_classes,
            "lr": 0.01 * NUM_GPUS,
        }
        model = MetaPath2Vec(Namespace(**model_hparams), dataset=dataset, metrics=METRICS)
    elif "LATTE" in hparams.method:
        USE_AMP = False
        num_gpus = 1

        if "-1" in hparams.method:
            t_order = 1
        elif "-2" in hparams.method:
            t_order = 2
        elif "-3" in hparams.method:
            t_order = 3
        else:
            t_order = 2

        model_hparams = {
            "embedding_dim": EMBEDDING_DIM,
            "t_order": t_order,
            "batch_size": 2 ** batch_order * max(num_gpus, 1),
            "nb_cls_dense_size": 0,
            "nb_cls_dropout": 0.4,
            "activation": "relu",
            "attn_heads": 2,
            "attn_activation": "sharpening",
            "attn_dropout": 0.2,
            "loss_type": "BCE" if dataset.multilabel else "SOFTMAX_CROSS_ENTROPY",
            "use_proximity": True if "proximity" in hparams.method else False,
            "neg_sampling_ratio": 2.0,
            "n_classes": dataset.n_classes,
            "use_class_weights": False,
            "lr": 0.001 * num_gpus,
            "momentum": 0.9,
            "weight_decay": 1e-2,
        }

        metrics = ["precision", "recall", "micro_f1",
                   "accuracy" if dataset.multilabel else "ogbn-mag", "top_k"]

        model = LATTENodeClassifier(Namespace(**model_hparams), dataset, collate_fn="neighbor_sampler", metrics=metrics)

    MAX_EPOCHS = 250
    wandb_logger = WandbLogger(name=model.name(),
                               tags=[dataset.name()],
                               project="multiplex-comparison")
    wandb_logger.log_hyperparams(model_hparams)

    trainer = Trainer(
        gpus=NUM_GPUS, auto_select_gpus=True,
        distributed_backend='dp' if NUM_GPUS > 1 else None,
        max_epochs=MAX_EPOCHS,
        callbacks=[EarlyStopping(monitor='val_loss', patience=10, min_delta=0.0001, strict=False)],
        logger=wandb_logger,
        weights_summary='top',
        amp_level='O1' if USE_AMP else None,
        precision=16 if USE_AMP else 32
    )

    # trainer.fit(model)
    trainer.fit(model, train_dataloader=model.valtrain_dataloader(), val_dataloaders=model.test_dataloader())
    trainer.test(model)
Example #27
def test_wandb_logger(wandb):
    """Verify that basic functionality of wandb logger works.
    Wandb doesn't work well with pytest so we have to mock it out here."""
    tutils.reset_seed()

    logger = WandbLogger(anonymous=True, offline=True)

    logger.log_metrics({'acc': 1.0})
    wandb.init().log.assert_called_once_with({'acc': 1.0})

    wandb.init().log.reset_mock()
    logger.log_metrics({'acc': 1.0}, step=3)
    wandb.init().log.assert_called_once_with({'global_step': 3, 'acc': 1.0})

    logger.log_hyperparams({'test': None})
    wandb.init().config.update.assert_called_once_with({'test': None})

    logger.watch('model', 'log', 10)
    wandb.watch.assert_called_once_with('model', log='log', log_freq=10)

    logger.finalize('fail')
    wandb.join.assert_called_once_with(1)

    wandb.join.reset_mock()
    logger.finalize('success')
    wandb.join.assert_called_once_with(0)

    wandb.join.reset_mock()
    wandb.join.side_effect = TypeError
    with pytest.raises(TypeError):
        logger.finalize('any')

    wandb.join.assert_called()

    assert logger.name == wandb.init().project_name()
    assert logger.version == wandb.init().id
Example #28
# Dataloaders
train_dataloader = data.DataLoader(test_dataset, batch_size=256)
validation_dataloader = data.DataLoader(validation_dataset, batch_size=256)
test_dataloader = data.DataLoader(train_dataset, batch_size=256)

# Logging
model.model_tags.append(split)
model.model_tags.append(band_type)
model.model_tags.append("train:" + str(len(train_dataset)))
model.model_tags.append("validation:" + str(len(validation_dataset)))
model.model_tags.append("test:" + str(len(test_dataset)))
model.model_tags.append("seed:" + str(model.seed))

wandb_logger = WandbLogger(
    name=model.model_name,
    tags=model.model_tags,
    project="eeg-connectome-analysis",
    save_dir="/content/drive/Shared drives/EEG_Aditya/model-results/wandb",
    log_model=True)
wandb_logger.watch(model, log='gradients', log_freq=100)

# Checkpoints
val_loss_cp = pl.callbacks.ModelCheckpoint(monitor='validation-loss')

trainer = pl.Trainer(max_epochs=1000,
                     gpus=1,
                     logger=wandb_logger,
                     precision=16,
                     fast_dev_run=False,
                     auto_lr_find=True,
                     auto_scale_batch_size=True,
                     log_every_n_steps=1,
Example #29
    # Load config from checkpoint
    ckpt = torch.load(ckpt_path + ckpt_name[0], map_location='cpu')
    config = BonzConfig(**ckpt['hyper_parameters'])
    config.num_label = 1

    # Train and test dataloader
    data_module = BonzDataModule(task_name=args.task_name)

    # Bonz Model with Pytorch-lightning
    model = BonzModel_PL(config=config, task_name=args.task_name, lr=args.lr)

    # Load pre-trained weights
    model.load_state_dict(ckpt['state_dict'], strict=False)

    wandb_logger = WandbLogger(
        name=f'{args.version}',
        project='BenchmarkLM',
    )

    trainer = pl.Trainer(
        logger=wandb_logger,
        checkpoint_callback=False,
        benchmark=True,
        log_every_n_steps=100,
        check_val_every_n_epoch=1,
        #accelerator='ddp',
        amp_level='native',
        precision=16,
        gpus=1,
        profiler='simple',
        max_epochs=3,
        reload_dataloaders_every_epoch=True,
Example #30
 num_subjects = len(dataset)
 print('Num Subjects: ',num_subjects)
 num_training_subjects = int(training_split_ratio * num_subjects)
 num_validation_subjects = num_subjects - num_training_subjects
 num_split_subjects = num_training_subjects, num_validation_subjects
 generator=torch.Generator().manual_seed(seed)
 training_subjects, validation_subjects = torch.utils.data.random_split(subjects, num_split_subjects,generator)
 training_set = tio.SubjectsDataset(training_subjects, training_transform)
 validation_set = tio.SubjectsDataset(validation_subjects, validation_transform)
 print('Training set:', len(training_set), 'subjects')
 print('Validation set:', len(validation_set), 'subjects')
 
 #TRAINER RUNNING
 wandb_logger = None
 if wandb_logging:
     wandb_logger = WandbLogger(project=wandb_project_name,name=wandb_run_name, offline = False)
     
 model = TumourSegmentation(
     train_dataset=training_set,
     val_dataset=validation_set,
     col_fn=col_fn,
     batch_size=batch_size,
     num_loading_cpus=num_loading_cpus,
     in_channels=len(input_channels_list),
     classes=seg_channels,
     learning_rate=learning_rate
 )
 
 trainer = pl.Trainer(
     default_root_dir=default_root_dir,
     accumulate_grad_batches=accumulate_grad_batches,