def train():
    """Train an MNIST classifier and log metrics/checkpoints to W&B."""
    hp = {
        "epochs": 10,
        "lr_initial": 0.001,
        "lr_decay_every": 30,
        "lr_decay_by": 0.3,
    }
    config = {
        "data_path": "../data",
        "val_split": 0.05,
        "batch_size": 64,
        "manual_seed": 2,
        "output_path": "./output",
        "model_save_frequency": 5,
        "dataloader_num_workers": 0,
    }

    dataset = MnistDataset(**config)
    model = MnistModel(**hp, **config)

    # log_model=True uploads checkpoints to W&B as artifacts.
    logger = WandbLogger(project="classification_test", log_model=True)
    trainer = pl.Trainer(
        gpus=0,  # CPU-only run
        max_epochs=hp["epochs"],
        default_root_dir=config["output_path"],
        logger=logger,
    )
    logger.watch(model)  # track gradients/parameters
    trainer.fit(model, datamodule=dataset)
def main():
    """Train Cosmoflow with Horovod, early-stopping on validation loss."""
    data_module = CFDataModule(batch_size=2)
    wandb_logger = WandbLogger(project="cosmoflow")

    early_stop_callback = EarlyStopping(
        monitor="val_loss",
        min_delta=0.0001,
        patience=2,
        verbose=True,
        mode="min",
    )

    print("creating trainer")  # fixed typo: was "create tainer"
    trainer = pl.Trainer(
        gpus=1,
        num_sanity_val_steps=0,
        max_epochs=20,
        distributed_backend="horovod",
        replace_sampler_ddp=False,  # Horovod manages its own sampler
        early_stop_callback=early_stop_callback,
        logger=wandb_logger,
        progress_bar_refresh_rate=0,  # suppress progress bar in distributed logs
    )

    model = Cosmoflow()
    trainer.fit(model, data_module)
def main():
    """Smoke-test W&B 'service' mode together with DDP-spawn on 2 GPUs."""
    # Opt in to the concurrency-safe wandb service experiment.
    wandb.require(experiment="service")
    print("PIDPID", os.getpid())

    # Synthetic data: three identical random loaders.
    n_samples = 100000
    train_loader = DataLoader(RandomDataset(32, n_samples), batch_size=32)
    val_loader = DataLoader(RandomDataset(32, n_samples), batch_size=32)
    test_loader = DataLoader(RandomDataset(32, n_samples), batch_size=32)

    model = BoringModel()

    # Hyperparameters are logged before the Trainer spawns DDP workers.
    wandb_logger = WandbLogger(
        log_model=True,
        config=dict(some_hparam="Logged Before Trainer starts DDP"),
        save_code=True,
    )

    trainer = Trainer(
        max_epochs=1,
        gpus=2,
        strategy="ddp_spawn",
        logger=wandb_logger,
    )
    trainer.fit(model, train_loader, val_loader)
    trainer.test(test_dataloaders=test_loader)
def main(hparams):
    """Train the segmentation model described by ``hparams``, logging to W&B."""
    # 1. Lightning model
    model = SegModel(hparams)

    # 2. W&B logger; watch() logs model topology and gradients.
    logger = WandbLogger()
    logger.watch(model.net)

    # 3. Trainer (checkpointing disabled)
    trainer = pl.Trainer(
        gpus=hparams.gpus,
        logger=logger,
        max_epochs=hparams.epochs,
        accumulate_grad_batches=hparams.grad_batches,
        checkpoint_callback=False,
    )

    # 4. Start training
    trainer.fit(model)
def main(hparams: Namespace):
    """Train the segmentation model; optionally log to W&B and use AMP."""
    # 1. Lightning model built from the parsed CLI namespace.
    model = SegModel(**vars(hparams))

    # 2. Logger: False disables logging entirely in Lightning.
    logger = False
    if hparams.log_wandb:
        logger = WandbLogger()
        logger.watch(model.net)  # optional: log model topology

    # 3. Trainer
    trainer = pl.Trainer(
        gpus=hparams.gpus,
        logger=logger,
        max_epochs=hparams.epochs,
        accumulate_grad_batches=hparams.grad_batches,
        accelerator=hparams.accelerator,
        precision=16 if hparams.use_amp else 32,  # AMP = half precision
    )

    # 4. Start training
    trainer.fit(model)
def test_wandb_pickle(wandb):
    """A Trainer carrying a WandbLogger must survive a pickle round-trip.

    ``wandb`` is mocked out because it does not play well with pytest.
    """
    tutils.reset_seed()

    class Experiment:
        id = 'the_id'

    wandb.init.return_value = Experiment()

    logger = WandbLogger(id='the_id', offline=True)
    trainer = Trainer(max_epochs=1, logger=logger)

    restored = pickle.loads(pickle.dumps(trainer))

    # offline=True puts wandb in dryrun mode via the environment.
    assert os.environ['WANDB_MODE'] == 'dryrun'
    assert restored.logger.__class__.__name__ == WandbLogger.__name__

    # Touching .experiment lazily triggers wandb.init with the saved id.
    _ = restored.logger.experiment
    wandb.init.assert_called()
    assert 'id' in wandb.init.call_args[1]
    assert wandb.init.call_args[1]['id'] == 'the_id'

    del os.environ['WANDB_MODE']  # clean up env for subsequent tests
def main():
    """Load a YAML config and train CheckpointedPyramid on 8 nodes x 4 GPUs."""
    print("Running main")
    print(time.ctime())

    args = parse_args()
    with open(args.config) as file:
        default_configs = yaml.load(file, Loader=yaml.FullLoader)

    print("Initialising model")
    print(time.ctime())
    model = CheckpointedPyramid(default_configs)

    logger = WandbLogger(
        project=default_configs["project"],
        group="InitialTest",
        save_dir=default_configs["artifacts"],
    )

    trainer = Trainer(
        gpus=4,
        num_nodes=8,
        # find_unused_parameters=False avoids the DDP parameter scan overhead.
        strategy=CustomDDPPlugin(find_unused_parameters=False),
        max_epochs=default_configs["max_epochs"],
        logger=logger,
    )
    trainer.fit(model)
def test_wandb_logger_offline_log_model(wandb, tmpdir):
    """``log_model=True`` must be rejected when the logger runs offline."""
    with pytest.raises(MisconfigurationException,
                       match='checkpoints cannot be uploaded in offline mode'):
        WandbLogger(save_dir=str(tmpdir), offline=True, log_model=True)
def main(args) -> None:
    """Entry point that launches model training."""
    config = load_cfg(args.config)
    pprint.PrettyPrinter(indent=2).pprint(config)

    model = BaselineLearner(config)

    # Lightning treats logger=False as "no logging".
    logger = False
    if args.use_logger:
        logger = WandbLogger(name=config.name)
        logger.watch(model.net)

    checkpoint = ModelCheckpoint(
        monitor='valid_loss',
        dirpath=config.sources.ckpt_path,
        filename=config.name,
    )
    trainer = pl.Trainer(
        gpus=args.gpus,
        logger=logger,
        callbacks=[checkpoint],
        max_epochs=config.training.epochs,
        distributed_backend=args.distributed_backend,
        precision=16 if args.use_amp else 32,
    )
    trainer.fit(model)
    print('Model training completed!')
def cli_main():
    """CLI entry: parse trainer/model args, wire callbacks, train and test."""
    parser = argparse.ArgumentParser()
    parser = pl.Trainer.add_argparse_args(parser)           # trainer args
    parser = SLExperiment.add_model_specific_args(parser)   # model args
    args = parser.parse_args()

    model = SLExperiment(**args.__dict__)
    if args.resume:
        model.resume(args.resume)

    logger = None
    callbacks = []
    if not args.fast_dev_run:
        logger = WandbLogger(
            project="argumentation",
            save_dir=str(config.root_dir),
            tags=[args.tag],
        )
        logger.log_hyperparams(args)

        # Keep only the single best checkpoint by validation loss.
        checkpoint_callback = ModelCheckpoint(
            dirpath=logger.experiment.dir,
            save_top_k=1,
            monitor="validation/loss",
            mode="min",
            save_weights_only=True,
            verbose=True,
        )
        callbacks.append(checkpoint_callback)

        if args.tag:
            # Second copy of the checkpointer writing to the tagged model dir.
            tag_checkpoint_callback = copy.deepcopy(checkpoint_callback)
            tag_checkpoint_callback.dirpath = model.model_dir
            tag_checkpoint_callback.filename = model.model_name
            callbacks.append(tag_checkpoint_callback)

        if args.patience:
            callbacks.append(EarlyStopping(
                monitor="validation/loss",
                patience=args.patience,
                mode="min",
                verbose=True,
            ))

    pl.seed_everything(123)
    trainer = pl.Trainer.from_argparse_args(
        args,
        logger=logger,
        callbacks=callbacks,
        track_grad_norm=2,  # log the 2-norm of gradients
    )

    if args.train_ds and args.val_ds:
        trainer.fit(model)
    if args.test_ds:
        trainer.test(model)
def get_logger(model_config):
    """Build a logger from ``model_config["logger"]``.

    Supported choices: ``"wandb"``, ``"tb"`` (TensorBoard), or ``None``
    (no logger). A default project name is injected when missing.

    Raises:
        ValueError: for an unrecognized logger choice. (Previously an unknown
            choice fell through every branch and raised NameError at return.)
    """
    logger_choice = model_config["logger"]
    # Provide a default project name when the config omits one.
    model_config.setdefault("project", "my_project")

    if logger_choice == "wandb":
        logger = WandbLogger(
            project=model_config["project"],
            save_dir=model_config["artifact_library"],
            id=model_config["resume_id"],
        )
    elif logger_choice == "tb":
        logger = TensorBoardLogger(
            name=model_config["project"],
            save_dir=model_config["artifact_library"],
            version=model_config["resume_id"],
        )
    elif logger_choice is None:  # was `== None` — identity check is idiomatic
        logger = None
    else:
        raise ValueError(f"Unknown logger choice: {logger_choice!r}")

    logging.info("Logger retrieved")
    return logger
def train(config):
    """Train the KWS attention model from an OmegaConf-style config."""
    fix_seeds(seed=config.train.seed)

    # Encoder / attention head / Lightning wrapper, all with config defaults.
    encoder = CRNNEncoder(
        in_channels=config.model.get('in_channels', 42),
        hidden_size=config.model.get('hidden_size', 16),
        dropout=config.model.get('dropout', 0.1),
        cnn_layers=config.model.get('cnn_layers', 2),
        rnn_layers=config.model.get('rnn_layers', 2),
        kernel_size=config.model.get('kernel_size', 9),
    )
    net = AttentionNet(
        encoder,
        hidden_size=config.model.get('hidden_size', 16),
        num_classes=config.model.get('num_classes', 3),
    )
    pl_model = KWSModel(
        net,
        lr=config.train.get('lr', 4e-5),
        in_channels=config.model.get('in_channels', 42),
        batch_size=config.train.get('batch_size', 32),
    )

    logger = WandbLogger(
        name=config.train.get('experiment_name', 'final_run'),
        project='kws-attention',
        log_model=True,
    )
    logger.log_hyperparams(config)
    logger.watch(net, log='all', log_freq=100)  # gradients + parameters

    trainer = pl.Trainer(
        max_epochs=config.train.get('max_epochs', 15),
        logger=logger,
        gpus=config.train.get('gpus', 1),
    )
    trainer.fit(pl_model)
def train(hparams):
    """Train and test the LATTE link predictor on the configured dataset."""
    num_gpus = hparams.num_gpus
    use_amp = False  # AMP disabled; was conditionally NUM_GPUS > 1
    max_epochs = 50

    dataset = load_link_dataset(hparams.dataset, hparams=hparams)
    hparams.n_classes = dataset.n_classes

    model = LATTELinkPredictor(
        hparams, dataset,
        collate_fn="triples_batch",
        metrics=[hparams.dataset],
    )

    logger = WandbLogger(
        name=model.name(),
        tags=[dataset.name()],
        project="multiplex-comparison",
    )

    trainer = Trainer(
        gpus=num_gpus,
        distributed_backend='ddp' if num_gpus > 1 else None,
        auto_lr_find=False,
        max_epochs=max_epochs,
        early_stop_callback=EarlyStopping(
            monitor='val_loss', patience=10, min_delta=0.01, strict=False),
        logger=logger,
        weights_summary='top',
        amp_level='O1' if use_amp else None,
        precision=16 if use_amp else 32,
    )
    trainer.fit(model)
    trainer.test(model)
def main(cfg: DictConfig):
    """Instantiate data module and task from the config and run training."""
    datamodule = instantiate(cfg.data)
    task = instantiate(cfg.task)
    logger = WandbLogger(**cfg.logger)
    trainer = Trainer(**cfg.trainer, logger=logger)
    trainer.fit(model=task, datamodule=datamodule)
def main(config):
    """Train the KITTI segmentation model on all available GPUs."""
    # 1. Lightning model
    model = SegModel(config)

    # 2. Data pipeline
    kitti_data = KittiDataModule(config)

    # 3. W&B logger; watch() logs model topology and gradients.
    logger = WandbLogger()
    logger.watch(model.net)

    # 4. Trainer (gpus=-1 selects every visible GPU)
    trainer = pl.Trainer(
        gpus=-1,
        logger=logger,
        max_epochs=config.epochs,
        accumulate_grad_batches=config.grad_batches,
    )

    # 5. Start training
    trainer.fit(model, kitti_data)
def main(hparams: Namespace):
    """Train the segmentation model using trainer flags parsed from the CLI."""
    # 1. Lightning model from the parsed namespace.
    model = SegModel(**vars(hparams))

    # 2. Logger: False disables logging in Lightning.
    logger = False
    if hparams.log_wandb:
        logger = WandbLogger()
        logger.watch(model.net)  # optional: log model topology

    # 3. Trainer configured entirely from the argparse namespace.
    # NOTE(review): `logger` is built but never handed to the Trainer here —
    # confirm whether from_argparse_args is expected to pick it up.
    trainer = pl.Trainer.from_argparse_args(hparams)

    # 4. Start training
    trainer.fit(model)
def test_wandb_pickle(wandb, tmpdir):
    """A Trainer carrying a WandbLogger must survive a pickle round-trip.

    ``wandb`` is mocked out because it does not play well with pytest.
    """
    class Experiment:
        """Minimal stand-in for a wandb run object."""
        id = 'the_id'

        def project_name(self):
            return 'the_project_name'

    wandb.init.return_value = Experiment()

    logger = WandbLogger(id='the_id', offline=True)
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        logger=logger,
    )

    # Touch .experiment so wandb.init has run before pickling.
    assert trainer.logger.experiment, 'missing experiment'

    restored = pickle.loads(pickle.dumps(trainer))

    # offline=True puts wandb in dryrun mode via the environment.
    assert os.environ['WANDB_MODE'] == 'dryrun'
    assert restored.logger.__class__.__name__ == WandbLogger.__name__
    assert restored.logger.experiment, 'missing experiment'

    wandb.init.assert_called()
    assert 'id' in wandb.init.call_args[1]
    assert wandb.init.call_args[1]['id'] == 'the_id'

    del os.environ['WANDB_MODE']  # clean up env for subsequent tests
def test_wandb_logger_dirs_creation(wandb, tmpdir):
    """Checkpoint folders/files must be created under save_dir/name/version."""
    logger = WandbLogger(save_dir=str(tmpdir), offline=True)

    # Before the experiment exists, version and name are unknown.
    assert logger.version is None
    assert logger.name is None

    # Mock the lazily created experiment's identity.
    logger.experiment.id = '1'
    logger.experiment.project_name.return_value = 'project'
    for _ in range(2):  # repeated access must be idempotent
        _ = logger.experiment

    assert logger.version == '1'
    assert logger.name == 'project'
    assert str(tmpdir) == logger.save_dir
    assert not os.listdir(tmpdir)  # nothing written yet

    version = logger.version
    model = EvalModelTemplate()
    trainer = Trainer(
        default_root_dir=tmpdir,
        logger=logger,
        max_epochs=1,
        limit_val_batches=3,
    )
    trainer.fit(model)

    expected_dir = str(tmpdir / 'project' / version / 'checkpoints')
    assert trainer.checkpoint_callback.dirpath == expected_dir
    assert set(os.listdir(trainer.checkpoint_callback.dirpath)) == {'epoch=0-step=9.ckpt'}
def objective(trial):
    """Optuna objective: train an SMP Unet with a sampled loss, return val IoU."""
    pl.seed_everything(42, workers=True)

    transforms = {'Resize': {'width': 224, 'height': 224}}
    dm = DataModule(
        batch_size=64,
        num_workers=24,
        pin_memory=True,
        train_trans=transforms,
        val_trans=transforms,
        shuffle_train=False,
    )

    # The loss function is the only hyperparameter being searched.
    loss = trial.suggest_categorical(
        "loss", ["bce", "dice", "jaccard", "focal", "log_cosh_dice"])
    model = SMP({
        'optimizer': 'Adam',
        'lr': 0.0003,
        'loss': loss,
        'model': 'Unet',
        'backbone': 'resnet18',
        'pretrained': 'imagenet',
    })

    logger = WandbLogger(project="MnMs2-opt", name=loss)
    trainer = pl.Trainer(
        gpus=1,
        precision=16,
        logger=logger,
        max_epochs=10,
        # Pruner aborts unpromising trials based on val_iou.
        callbacks=[PyTorchLightningPruningCallback(trial, monitor="val_iou")],
        checkpoint_callback=False,
        limit_train_batches=1.,
        limit_val_batches=1.,
        deterministic=True,
    )
    trainer.fit(model, dm)

    score = trainer.test(model, dm.val_dataloader())
    logger.experiment.finish()  # close the run so the next trial gets a fresh one
    return score[0]['test_iou']
def test_wandb_logger(wandb):
    """Basic WandbLogger behavior: metrics, step offsets, hparams, watch.

    ``wandb`` is mocked out because it does not play well with pytest.
    """
    logger = WandbLogger(anonymous=True, offline=True)

    # Metrics are forwarded verbatim, with or without an explicit step.
    logger.log_metrics({'acc': 1.0})
    wandb.init().log.assert_called_once_with({'acc': 1.0}, step=None)

    wandb.init().log.reset_mock()
    logger.log_metrics({'acc': 1.0}, step=3)
    wandb.init().log.assert_called_once_with({'acc': 1.0}, step=3)

    # After finalize, continuing on the same run offsets steps by the
    # run's last step (3 + 3 = 6).
    wandb.init().step = 3
    logger.finalize('success')
    logger.log_metrics({'acc': 1.0}, step=3)
    wandb.init().log.assert_called_with({'acc': 1.0}, step=6)

    # Hyperparams are flattened ('nested/a') and None becomes the string 'None'.
    logger.log_hyperparams({'test': None, 'nested': {'a': 1}, 'b': [2, 3, 4]})
    wandb.init().config.update.assert_called_once_with(
        {'test': 'None', 'nested/a': 1, 'b': [2, 3, 4]},
        allow_val_change=True,
    )

    # watch() proxies straight through to wandb.
    logger.watch('model', 'log', 10)
    wandb.init().watch.assert_called_once_with('model', log='log', log_freq=10)

    assert logger.name == wandb.init().project_name()
    assert logger.version == wandb.init().id
def main():
    """Smoke-test W&B 'service' mode with ddp_cpu across 2 processes."""
    # Opt in to the concurrency-safe wandb service experiment.
    wandb.require(experiment="service")
    print("PIDPID", os.getpid())

    # Synthetic data: three identical random loaders.
    n_samples = 100000
    train_loader = DataLoader(RandomDataset(32, n_samples), batch_size=32)
    val_loader = DataLoader(RandomDataset(32, n_samples), batch_size=32)
    test_loader = DataLoader(RandomDataset(32, n_samples), batch_size=32)

    model = BoringModel()

    # Hyperparameters are logged before the Trainer starts DDP workers.
    wandb_logger = WandbLogger(
        log_model=True,
        config=dict(some_hparam="Logged Before Trainer starts DDP"),
        save_code=True,
    )

    trainer = pl.Trainer(
        max_epochs=1,
        progress_bar_refresh_rate=20,
        num_processes=2,
        accelerator="ddp_cpu",
        logger=wandb_logger,
    )
    trainer.fit(model, train_loader, val_loader)
    trainer.test(dataloaders=test_loader)
def test_multi_gpu_wandb_ddp_spawn(tmpdir):
    """Make sure DP/DDP + AMP work with a WandbLogger attached."""
    from pytorch_lightning.loggers import WandbLogger

    tutils.set_random_master_port()
    model = EvalModelTemplate()

    # Mock the active run so no real W&B session is opened.
    wandb.run = MagicMock()
    wandb.init(name='name', project='project')
    logger = WandbLogger(name='name', offline=True)

    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        gpus=2,
        distributed_backend='ddp_spawn',
        precision=16,  # AMP
        logger=logger,
    )
    result = trainer.fit(model)
    assert result
    trainer.test(model)
def test_wandb_pickle(tmpdir):
    """Constructing a WandbLogger with a save_dir must succeed.

    NOTE(review): despite the name, no pickling is exercised here —
    confirm whether the round-trip assertions were dropped intentionally.
    """
    tutils.reset_seed()
    logger = WandbLogger(save_dir=str(tmpdir), anonymous=True)
    assert logger is not None
def main():
    """W&B sweep entry point: merge sweep params into the config, train VQVAE."""
    import sys
    import wandb
    from config import setSeed, getConfig
    from main.vqvae import VQVAE
    from pytorch_lightning.loggers import WandbLogger
    import pytorch_lightning as pl
    from IPython import embed

    run = wandb.init()

    # Load the base config and overlay the sweep-provided values.
    conf = getConfig(sys.argv[1])
    conf = update_custom(conf, run.config)

    # NOTE(review): `alg` is not defined in this function — presumably a
    # module-level name; confirm before relying on the tag.
    logger = WandbLogger(
        project='mineRL',
        name=conf['experiment'],
        tags=[alg, 'sweep'],
    )
    logger.log_hyperparams(conf)

    vqvae = VQVAE(conf)
    trainer = pl.Trainer(
        gpus=1,
        max_epochs=conf['epochs'],
        progress_bar_refresh_rate=20,
        weights_summary='full',
        logger=logger,
        default_root_dir=f"./results/{conf['experiment']}",
    )
    trainer.fit(vqvae)
def main(args, model=None) -> SummarizationModule:
    """Train (and optionally test) a SummarizationModule.

    Args:
        args: parsed CLI namespace (output_dir, logger, do_train, do_predict, ...).
        model: optional pre-built model; constructed from ``args`` when None.

    Returns:
        The (possibly trained) SummarizationModule.

    Raises:
        ValueError: if the output directory is non-empty and training is requested,
            or if ``args.logger`` names an unsupported logger.
    """
    Path(args.output_dir).mkdir(exist_ok=True)
    if len(os.listdir(args.output_dir)) > 3 and args.do_train:
        raise ValueError(
            "Output directory ({}) already exists and is not empty.".format(
                args.output_dir))

    if model is None:
        model: BaseTransformer = SummarizationModule(args)

    # Pick a logger. True = Lightning's default; wandb variants are lazy-imported.
    if (args.logger == "default" or args.fast_dev_run
            or str(args.output_dir).startswith("/tmp")
            or str(args.output_dir).startswith("/var")):
        logger = True  # don't pollute wandb logs unnecessarily
    elif args.logger == "wandb":
        from pytorch_lightning.loggers import WandbLogger
        logger = WandbLogger(name=model.output_dir.name)
    elif args.logger == "wandb_shared":
        from pytorch_lightning.loggers import WandbLogger
        # TODO: separate LB for CNN, we should use Path(args.data_dir).name
        # to determine the correct LB.
        logger = WandbLogger(name=model.output_dir.name, project="hf_summarization")
    else:
        # Previously an unknown value fell through and raised NameError below.
        raise ValueError(f"Unsupported logger: {args.logger!r}")

    trainer: pl.Trainer = generic_train(
        model,
        args,
        logging_callback=Seq2SeqLoggingCallback(),
        checkpoint_callback=get_rouge2_checkpoint_callback(args.output_dir),
        logger=logger,
        # TODO: early stopping callback seems messed up
    )
    pickle_save(model.hparams, model.output_dir / "hparams.pkl")

    if not args.do_predict:
        return model

    # Resume from the most recent checkpoint, if any, then evaluate.
    model.hparams.test_checkpoint = ""
    checkpoints = list(sorted(
        glob.glob(os.path.join(args.output_dir, "*.ckpt"), recursive=True)))
    if checkpoints:
        model.hparams.test_checkpoint = checkpoints[-1]
        trainer.resume_from_checkpoint = checkpoints[-1]
    trainer.logger.log_hyperparams(model.hparams)
    # This breaks in DDP (known lightning issue); see evaluate_checkpoint
    # to recover metrics.
    trainer.test(model)
    return model
def main(
    checkpoint: str,
    test: bool = False,
    overfit: float = 0,
    max_epochs: int = 1000,
):
    """Fine-tune the ELECTRA discriminator for captioning on COCO."""
    # Base BERT tokenizer/config.
    config: VisualElectraConfig = VisualElectraConfig()
    config.tokenizer = AutoTokenizer.from_pretrained(
        "google/bert_uncased_L-4_H-512_A-8"
    )
    config.hidden_size = 512

    # Generator: small BERT; discriminator: larger BERT used as a decoder.
    gen_model_name = "google/bert_uncased_L-2_H-512_A-8"
    disc_model_name = "google/bert_uncased_L-8_H-768_A-12"

    gen_conf = AutoConfig.from_pretrained(gen_model_name)
    config.generator_model = AutoModelForMaskedLM.from_config(gen_conf)
    config.generator_hidden_size = 512

    disc_conf = AutoConfig.from_pretrained(disc_model_name)
    disc_conf.is_decoder = True
    config.discriminator_model = AutoModel.from_config(disc_conf)
    config.discriminator_hidden_size = 768

    # Load pretraining checkpoint, keep only the discriminator, retarget it.
    full_model = VisualElectra.load_from_checkpoint(checkpoint, config=config)
    model = full_model.discriminator
    model.training_objective = TrainingObjective.Captioning
    model.add_lm_head()

    data = CocoCaptions()
    data.prepare_data()
    data.setup()

    # fast_dev_run only when testing and not overfitting (original used `&`).
    fast_dev_run = test & (overfit == 0)

    logger = None
    if test is not True:
        logger = WandbLogger(
            project="final-year-project",
            offline=False,
            log_model=True,
            save_dir=work_dir,
            config={'checkpoint': checkpoint},
            tags=['electra-finetune'],
        )

    callbacks = [CheckpointEveryNSteps(50000)]
    trainer = pl.Trainer(
        gpus=1,
        fast_dev_run=fast_dev_run,
        default_root_dir=work_dir,
        log_every_n_steps=10,
        logger=logger,
        max_epochs=max_epochs,
        overfit_batches=overfit,
        callbacks=callbacks,
    )
    trainer.fit(model, data)
def train_classifier(logging=False, train=True):
    """Build (and optionally train) the LAVA video classifier.

    Args:
        logging: if True, attach a W&B logger with hparams and gradient watch.
        train: if False, return the constructed model without training.
    """
    hparams = {
        'gpus': [1],
        'max_epochs': 25,
        'num_classes': 700,
        'feature_dimension': 512,
        'model_dimension': 1024,
        'pretrained_text': False,
        'num_modalities': 1,
        'batch_size': 32,
        'learning_rate': 1e-3,
        'model_path': "/home/sgurram/Projects/aai/aai/experimental/sgurram/lava/src/wandb/run-20210626_215155-yqwe58z7/files/lava/yqwe58z7/checkpoints/epoch=6-step=12529.ckpt",
        'model_descriptor': 'lava timesformer 1/3 kinetics data, unshuffled',
        'accumulate_grad_batches': 2,
        'overfit_batches': 0,
        'type_modalities': 'av',
        'modality_fusion': 'concat',
        'loss_funtions': ['cross_entropy'],
        'metrics': None,
        'optimizer': 'adam',
        'scheduler': 'n/a',
        'profiler': 'simple',
        'default_root_dir': '/home/sgurram/Desktop/video_lava_classifer',
    }

    model = EvalLightning(
        num_classes=hparams['num_classes'],
        feature_dimension=hparams['feature_dimension'],
        model_dimension=hparams['model_dimension'],
        num_modalities=hparams['num_modalities'],
        batch_size=hparams['batch_size'],
        learning_rate=hparams['learning_rate'],
        model_path=hparams['model_path'],
        model=LAVALightning,
        pretrained_text=hparams['pretrained_text'],
    )

    if logging:
        wandb_logger = WandbLogger(name='run', project='lava')
        wandb_logger.log_hyperparams(hparams)
        wandb_logger.watch(model, log='gradients', log_freq=10)
    else:
        wandb_logger = None

    if not train:
        return model

    trainer = pl.Trainer(
        default_root_dir=hparams['default_root_dir'],
        gpus=hparams['gpus'],
        max_epochs=hparams['max_epochs'],
        accumulate_grad_batches=hparams['accumulate_grad_batches'],
        overfit_batches=hparams['overfit_batches'],
        logger=wandb_logger,
        profiler=hparams['profiler'],
    )
    trainer.fit(model)
def train(dataset_name: str, model_name: str, expt_dir: str, data_folder: str,
          num_workers: int = 0, is_test: bool = False,
          resume_from_checkpoint: str = None):
    """Train a code model (currently code2seq) and save checkpoints.

    Args:
        dataset_name: name used for the W&B project tag.
        model_name: which model to build; only "code2seq" is supported.
        expt_dir: directory for checkpoints.
        data_folder: dataset root containing vocabulary.pkl.
        num_workers: dataloader workers.
        is_test: use the smaller test config when True.
        resume_from_checkpoint: optional checkpoint path to resume from.
    """
    seed_everything(SEED)
    dataset_main_folder = data_folder
    vocab = Vocabulary.load(join(dataset_main_folder, "vocabulary.pkl"))

    if model_name == "code2seq":
        config_function = (get_code2seq_test_config if is_test
                           else get_code2seq_default_config)
        config = config_function(dataset_main_folder)
        model = Code2Seq(config, vocab, num_workers)
        # NOTE(review): half() casts the model to fp16 up front — confirm this
        # is intended alongside the default full-precision Trainer.
        model.half()
    # TODO: code2class support was disabled here (see history).
    else:
        raise ValueError(f"Model {model_name} is not supported")

    # Logger: offline W&B run with checkpoint logging and gradient watch.
    wandb_logger = WandbLogger(
        project=f"{model_name}-{dataset_name}", log_model=True, offline=True)
    wandb_logger.watch(model)

    # Keep the 3 best checkpoints by val_loss, saved periodically.
    model_checkpoint_callback = ModelCheckpoint(
        filepath=join(expt_dir, "{epoch:02d}-{val_loss:.4f}"),
        period=config.hyperparams.save_every_epoch,
        save_top_k=3,
    )
    early_stopping_callback = EarlyStopping(
        patience=config.hyperparams.patience, verbose=True, mode="min")

    gpu = 1 if torch.cuda.is_available() else None  # use a GPU if present
    lr_logger = LearningRateLogger()

    trainer = Trainer(
        max_epochs=20,
        gradient_clip_val=config.hyperparams.clip_norm,
        deterministic=True,
        check_val_every_n_epoch=config.hyperparams.val_every_epoch,
        row_log_interval=config.hyperparams.log_every_epoch,
        logger=wandb_logger,
        checkpoint_callback=model_checkpoint_callback,
        early_stop_callback=early_stopping_callback,
        resume_from_checkpoint=resume_from_checkpoint,
        gpus=gpu,
        callbacks=[lr_logger],
        reload_dataloaders_every_epoch=True,
    )
    trainer.fit(model)
    trainer.save_checkpoint(join(expt_dir, 'Latest.ckpt'))
    trainer.test()
def main(args, model_name: str, reproducible: bool, comet: bool, wandb: bool):
    """Train and test the right-lane model; optionally log to Comet and/or W&B."""
    if reproducible:
        seed_everything(42)
        args.deterministic = True
        args.benchmark = True

    # Optional experiment trackers, configured from environment variables.
    # NOTE: when both flags are set, the W&B logger wins (last assignment).
    if comet:
        from pytorch_lightning.loggers import CometLogger
        comet_logger = CometLogger(
            api_key=os.environ.get('COMET_API_KEY'),
            workspace=os.environ.get('COMET_WORKSPACE'),      # Optional
            project_name=os.environ.get('COMET_PROJECT_NAME'),  # Optional
            experiment_name=model_name,                        # Optional
        )
        args.logger = comet_logger
    if wandb:
        from pytorch_lightning.loggers import WandbLogger
        wandb_logger = WandbLogger(
            project=os.environ.get('WANDB_PROJECT_NAME'),
            log_model=True,
            sync_step=True,
        )
        args.logger = wandb_logger

    if args.default_root_dir is None:
        args.default_root_dir = 'results'

    # Keep only the checkpoint with the best validation IoU.
    model_checkpoint = ModelCheckpoint(
        filename=model_name + '_{epoch}',
        save_top_k=1,
        monitor='val_iou',
        mode='max',
    )
    args.checkpoint_callback = model_checkpoint

    data = SimulatorDataModule(
        dataPath=args.dataPath,
        augment=args.augment,
        batch_size=args.batch_size,
        num_workers=8,
    )
    model = RightLaneModule(
        lr=args.learningRate,
        lrRatio=args.lrRatio,
        decay=args.decay,
        num_cls=4,
    )

    # All remaining trainer options come from the command line.
    trainer = Trainer.from_argparse_args(args)
    trainer.fit(model, datamodule=data)

    # Reload the best checkpoint before evaluation.
    model = RightLaneModule.load_from_checkpoint(
        model_checkpoint.best_model_path, dataPath=args.dataPath, num_cls=4)

    if comet:
        comet_logger.experiment.log_model(
            model_name + '_weights', model_checkpoint.best_model_path)

    trainer.test(model, datamodule=data)
def get_logger(model_config):
    """Return a WandbLogger configured from ``model_config``.

    Reads the keys "project", "wandb_save_dir" and "resume_id".
    """
    logger = WandbLogger(
        project=model_config["project"],
        save_dir=model_config["wandb_save_dir"],
        id=model_config["resume_id"],
    )
    logging.info("Logger retrieved")
    return logger