def test_trains_pickle(tmpdir):
    """Verify that pickling trainer with TRAINS logger works.

    A Trainer holding a TrainsLogger is round-tripped through ``pickle``.
    The unpickled trainer's logger is then used to log a metric and is
    finalized. The original logger is finalized in a ``finally`` block, so it
    is released even when an earlier step raises.

    Args:
        tmpdir: pytest temporary-directory fixture, used as the trainer's
            ``default_root_dir``.
    """
    tutils.reset_seed()

    # NOTE(review): bypass mode presumably keeps the logger from talking to a
    # live TRAINS server during tests -- confirm against TrainsLogger docs.
    TrainsLogger.set_bypass_mode(True)
    TrainsLogger.set_credentials(
        api_host='http://integration.trains.allegro.ai:8008',
        files_host='http://integration.trains.allegro.ai:8081',
        web_host='http://integration.trains.allegro.ai:8080',
    )
    logger = TrainsLogger(project_name="lightning_log", task_name="pytorch lightning test")

    try:
        trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, logger=logger)

        # Round-trip the whole trainer. The bytes were produced by this
        # process, so unpickling them is not an untrusted-input concern.
        pkl_bytes = pickle.dumps(trainer)
        trainer2 = pickle.loads(pkl_bytes)

        # The restored logger must still be usable after unpickling.
        trainer2.logger.log_metrics({"acc": 1.0})
        trainer2.logger.finalize()
    finally:
        # Release the original logger even if any step above failed.
        logger.finalize()
def test_trains_logger(tmpdir):
    """Verify that basic functionality of TRAINS logger works.

    Fits a small test model for one epoch on a fraction of the training data,
    with a TrainsLogger attached. The test requires ``trainer.fit`` to return
    ``1``. The logger is finalized in a ``finally`` block, so it is released
    even when construction or fitting raises.

    Args:
        tmpdir: pytest temporary-directory fixture, used as the trainer's
            ``default_root_dir``.
    """
    tutils.reset_seed()

    hparams = tutils.get_default_hparams()
    model = LightningTestModel(hparams)

    # NOTE(review): bypass mode presumably keeps the logger from talking to a
    # live TRAINS server during tests -- confirm against TrainsLogger docs.
    TrainsLogger.set_bypass_mode(True)
    TrainsLogger.set_credentials(
        api_host='http://integration.trains.allegro.ai:8008',
        files_host='http://integration.trains.allegro.ai:8081',
        web_host='http://integration.trains.allegro.ai:8080',
    )
    logger = TrainsLogger(project_name="lightning_log", task_name="pytorch lightning test")

    try:
        trainer = Trainer(
            default_root_dir=tmpdir,
            max_epochs=1,
            train_percent_check=0.05,  # train on 5% of the data to keep the test quick
            logger=logger,
        )
        result = trainer.fit(model)
    finally:
        # Release the logger whether or not training succeeded.
        logger.finalize()

    assert result == 1, "Training failed"
def test_trains_logger(tmpdir):
    """Verify that basic functionality of TRAINS logger works.

    Fits a small test model for one epoch on a fraction of the training data,
    with a TrainsLogger attached. The test requires ``trainer.fit`` to return
    ``1``. The logger is finalized in a ``finally`` block, matching the other
    versions of this test, so it is released even when fitting raises.

    Args:
        tmpdir: pytest temporary-directory fixture, used as the trainer's
            ``default_save_path``.
    """
    tutils.reset_seed()

    hparams = tutils.get_hparams()
    model = LightningTestModel(hparams)

    # NOTE(review): no bypass mode / credentials are configured here, so this
    # logger may try to reach a real TRAINS server -- confirm for this API version.
    logger = TrainsLogger(project_name="examples", task_name="pytorch lightning test")

    try:
        trainer = Trainer(
            default_save_path=tmpdir,
            max_epochs=1,
            train_percent_check=0.05,  # train on 5% of the data to keep the test quick
            logger=logger,
        )
        result = trainer.fit(model)
    finally:
        # Release the logger whether or not training succeeded.
        logger.finalize()

    assert result == 1, "Training failed"
def test_trains_pickle(tmpdir):
    """Verify that pickling trainer with TRAINS logger works.

    A Trainer holding a TrainsLogger is round-tripped through ``pickle``.
    The unpickled trainer's logger is then used to log a metric. Both loggers
    are finalized; the original one is finalized in a ``finally`` block, so it
    is released even when an earlier step raises.

    Args:
        tmpdir: pytest temporary-directory fixture, used as the trainer's
            ``default_save_path``.
    """
    tutils.reset_seed()

    # NOTE(review): no bypass mode / credentials are configured here, so this
    # logger may try to reach a real TRAINS server -- confirm for this API version.
    logger = TrainsLogger(project_name="examples", task_name="pytorch lightning test")

    try:
        trainer = Trainer(default_save_path=tmpdir, max_epochs=1, logger=logger)

        # Round-trip the whole trainer. The bytes were produced by this
        # process, so unpickling them is not an untrusted-input concern.
        pkl_bytes = pickle.dumps(trainer)
        trainer2 = pickle.loads(pkl_bytes)

        # The restored logger must still be usable after unpickling.
        trainer2.logger.log_metrics({"acc": 1.0})
        trainer2.logger.finalize()
    finally:
        # Release the original logger even if any step above failed.
        logger.finalize()
def test_trains_logger(tmpdir):
    """Verify that basic functionality of TRAINS logger works.

    Fits the evaluation template model for one epoch on a fraction of the
    training data, with a TrainsLogger attached. The test requires
    ``trainer.fit`` to return ``1``. The logger is finalized in a ``finally``
    block, so it is released even when construction or fitting raises.

    Args:
        tmpdir: pytest temporary-directory fixture, used as the trainer's
            ``default_root_dir``.
    """
    model = EvalModelTemplate()

    # NOTE(review): bypass mode presumably keeps the logger from talking to a
    # live TRAINS server during tests -- confirm against TrainsLogger docs.
    TrainsLogger.set_bypass_mode(True)
    TrainsLogger.set_credentials(
        api_host='http://integration.trains.allegro.ai:8008',
        files_host='http://integration.trains.allegro.ai:8081',
        web_host='http://integration.trains.allegro.ai:8080',
    )
    logger = TrainsLogger(project_name="lightning_log", task_name="pytorch lightning test")

    try:
        trainer = Trainer(
            default_root_dir=tmpdir,
            max_epochs=1,
            train_percent_check=0.05,  # train on 5% of the data to keep the test quick
            logger=logger,
        )
        result = trainer.fit(model)
    finally:
        # Release the logger whether or not training succeeded.
        logger.finalize()

    assert result == 1, "Training failed"