def test_comet_version_without_experiment(comet):
    """Test that CometLogger.version does not create an Experiment."""
    api_key = "key"
    experiment_name = "My Name"
    # Any version generated without a live experiment comes from generate_guid.
    comet.generate_guid.return_value = "1234"

    with patch('pytorch_lightning.loggers.comet.CometExperiment'):
        logger = CometLogger(api_key=api_key, experiment_name=experiment_name)
        assert logger._experiment is None

        first_version = logger.version
        assert first_version is not None
        # Repeated access is stable and still does not create an Experiment.
        assert logger.version == first_version
        assert logger._experiment is None

        _ = logger.experiment

        logger.reset_experiment()

        # BUG FIX: the original read `second_version = logger.version == "1234"`,
        # binding a boolean, which made the following assertions vacuous
        # (`bool is not None` and `bool != <guid string>` are always true).
        # Bind the actual version and pin it to the mocked guid instead.
        second_version = logger.version
        assert second_version is not None
        assert second_version == "1234"
def test_comet_metrics_safe(comet, tmpdir, monkeypatch):
    """Ensure log_metrics leaves the caller's metrics dict untouched."""
    _patch_comet_atexit(monkeypatch)

    logger = CometLogger(project_name="test", save_dir=tmpdir)
    payload = {
        "tensor": tensor([[1.0, 0.0], [0.0, 1.0]], requires_grad=True),
        "epoch": 1,
    }
    logger.log_metrics(payload)

    # The logger must not detach the tensor in place.
    assert payload["tensor"].requires_grad
def test_comet_epoch_logging(comet, comet_experiment, tmpdir, monkeypatch):
    """CometLogger must pop the 'epoch' key out of the metrics dict and pass it as a keyword argument instead."""
    _patch_comet_atexit(monkeypatch)

    comet_logger = CometLogger(project_name="test", save_dir=tmpdir)
    comet_logger.log_metrics({"test": 1, "epoch": 1}, step=123)

    expected_metrics = {"test": 1}
    comet_logger.experiment.log_metrics.assert_called_once_with(expected_metrics, epoch=1, step=123)
def test_comet_logger_online():
    """Test comet online with mocks."""
    # api_key only
    with patch('pytorch_lightning.loggers.comet.CometExperiment') as mock_experiment:
        logger = CometLogger(api_key='key', workspace='dummy-test', project_name='general')
        _ = logger.experiment
        mock_experiment.assert_called_once_with(api_key='key', workspace='dummy-test', project_name='general')

    # api_key and save_dir together: the online experiment still wins
    with patch('pytorch_lightning.loggers.comet.CometExperiment') as mock_experiment:
        logger = CometLogger(save_dir='test', api_key='key', workspace='dummy-test', project_name='general')
        _ = logger.experiment
        mock_experiment.assert_called_once_with(api_key='key', workspace='dummy-test', project_name='general')

    # neither api_key nor save_dir is a misconfiguration
    with pytest.raises(MisconfigurationException):
        CometLogger(workspace='dummy-test', project_name='general')

    # resuming an existing experiment key
    with patch('pytorch_lightning.loggers.comet.CometExistingExperiment') as mock_existing:
        logger = CometLogger(
            experiment_key='test',
            experiment_name='experiment',
            api_key='key',
            workspace='dummy-test',
            project_name='general',
        )
        _ = logger.experiment
        mock_existing.assert_called_once_with(
            api_key='key', workspace='dummy-test', project_name='general', previous_experiment='test'
        )
        mock_existing().set_name.assert_called_once_with('experiment')

    # a rest_api_key triggers creation of a REST API client
    with patch('pytorch_lightning.loggers.comet.API') as mock_api:
        CometLogger(api_key='key', workspace='dummy-test', project_name='general', rest_api_key='rest')
        mock_api.assert_called_once_with('rest')
def test_comet_epoch_logging(tmpdir, monkeypatch):
    """CometLogger should strip the 'epoch' key from the metrics dict and forward it as a keyword argument."""
    _patch_comet_atexit(monkeypatch)

    target = "pytorch_lightning.loggers.comet.CometOfflineExperiment.log_metrics"
    with patch(target) as log_metrics:
        offline_logger = CometLogger(project_name="test", save_dir=tmpdir)
        offline_logger.log_metrics({"test": 1, "epoch": 1}, step=123)
        log_metrics.assert_called_once_with({"test": 1}, epoch=1, step=123)
def test_comet_logger_online(comet):
    """Test comet online with mocks."""
    # api_key only
    with patch("pytorch_lightning.loggers.comet.CometExperiment") as mock_online:
        logger = CometLogger(api_key="key", workspace="dummy-test", project_name="general")
        _ = logger.experiment
        mock_online.assert_called_once_with(api_key="key", workspace="dummy-test", project_name="general")

    # api_key and save_dir together: online mode still wins
    with patch("pytorch_lightning.loggers.comet.CometExperiment") as mock_online:
        logger = CometLogger(save_dir="test", api_key="key", workspace="dummy-test", project_name="general")
        _ = logger.experiment
        mock_online.assert_called_once_with(api_key="key", workspace="dummy-test", project_name="general")

    # resuming a previous experiment key
    with patch("pytorch_lightning.loggers.comet.CometExistingExperiment") as mock_existing:
        logger = CometLogger(
            experiment_key="test",
            experiment_name="experiment",
            api_key="key",
            workspace="dummy-test",
            project_name="general",
        )
        _ = logger.experiment
        mock_existing.assert_called_once_with(
            api_key="key", workspace="dummy-test", project_name="general", previous_experiment="test"
        )
        mock_existing().set_name.assert_called_once_with("experiment")

    # a rest_api_key triggers creation of a REST API client
    with patch("pytorch_lightning.loggers.comet.API") as mock_api:
        CometLogger(api_key="key", workspace="dummy-test", project_name="general", rest_api_key="rest")
        mock_api.assert_called_once_with("rest")
def main():
    """Train and test the retrieval model, logging to Comet when it is installed."""
    parser = argparse.ArgumentParser()
    parser = Trainer.add_argparse_args(parser)
    parser.add_argument("--batch-size", type=int, default=2)
    args = parser.parse_args()

    datamodule = IRModule.load()
    datamodule.batch_size = args.batch_size
    model = Model.from_tinybert()

    if COMET_INSTALLED:
        comet_logger = CometLogger(
            api_key=os.environ.get("COMET_API_KEY"),
            experiment_name="mtg-search",
            log_graph=False,
            log_code=False,
            log_env_details=False,
            disabled=True,
        )
        comet_logger.log_hyperparams(asdict(model.config))
        # Reuse the Comet experiment key as the run identifier.
        key = comet_logger.experiment.get_key()
    else:
        key = uuid.uuid4().hex
        comet_logger = True  # to pass logger=True to Trainer
    model.config.key = key

    checkpoint = ModelCheckpoint(
        dirpath=MODELS_DIR,
        save_top_k=1,
        monitor="val_acc",
        filename=key,
    )
    trainer = Trainer.from_argparse_args(
        args,
        logger=comet_logger,
        callbacks=[checkpoint],
        num_sanity_val_steps=0,
        val_check_interval=10,
    )
    trainer.fit(model, datamodule=datamodule)
    trainer.test(model, datamodule=datamodule)
def run():
    """Run a Comet optimizer sweep, training one model per sampled parameter set."""
    # torch.multiprocessing.freeze_support()
    optimizer = Optimizer(optimizer_config)

    for trial in optimizer.get_parameters():
        trial_hparams = Namespace(**trial["parameters"])
        model = PyTorchLightningModel(hparams=trial_hparams)

        comet_logger = CometLogger(
            api_key=get_config("comet.api_key"),
            rest_api_key=get_config("comet.api_key"),
            optimizer_data=trial,
        )

        trainer = Trainer(
            max_epochs=1,
            # early_stop_callback=True,  # requires val_loss be logged
            logger=[comet_logger],
            # num_processes=2,
            # distributed_backend='ddp_cpu'
        )
        trainer.fit(model)
def create_logger(cli_args=None, data_module=None, runner_type=None):
    """Build a CometLogger when COMET_API_KEY is set, else fall back to TensorBoard.

    Args:
        cli_args: parsed CLI namespace; ``tag`` is used as an extra experiment tag.
        data_module: LightningDataModule whose class name and data are logged.
        runner_type: e.g. "mtl"; used as a tag and to pick which model file to log.

    Returns:
        A CometLogger if an API key is available, otherwise a TensorBoardLogger.
    """
    api_key = os.environ.get("COMET_API_KEY")
    workspace = os.environ.get("COMET_WORKSPACE")
    logger = TensorBoardLogger("lightning_logs")
    today = datetime.today()
    if api_key:
        tags = []
        if cli_args:
            extra_tag = cli_args.tag
            data_module_name = data_module.__class__.__name__
            tags = [runner_type, extra_tag, data_module_name]
            # BUG FIX: compare to None by identity (was `tag != None`);
            # also dropped the unused `model_type` local.
            tags = [tag for tag in tags if tag is not None]
        logger = CometLogger(
            api_key=api_key,
            workspace=workspace,
            project_name="master-jk-pl",
            experiment_name=today.strftime("%y/%m/%d - %H:%M"),
        )
        logger.experiment.add_tags(tags)
        logger.experiment.log_table("data.csv", tabular_data=data_module.data)
        logger.experiment.log_code(file_name="config.yaml")
        if runner_type:
            file_name = ("multiTaskLearner.py" if runner_type == "mtl" else "singleTaskLearner.py")
            logger.experiment.log_code(file_name=f"models/{file_name}")
    else:
        print("No Comet-API-key found, defaulting to Tensorboard", flush=True)
    return logger
def create_logger(experiment_name):
    """Create a CometLogger configured from the [comet] table of config.toml."""
    comet_cfg = toml.load('config.toml')['comet']
    print('Running experiment', experiment_name)
    return CometLogger(
        api_key=comet_cfg["api_key"],
        project_name=comet_cfg["project_name"],
        experiment_name=experiment_name,
    )
def test_comet_logger(tmpdir, monkeypatch):
    """Verify that basic functionality of Comet.ml logger works."""
    # Comet registers an atexit printer that breaks under pytest's
    # stdout/stderr capture; disable it.
    import atexit
    monkeypatch.setattr(atexit, 'register', lambda _: None)

    tutils.reset_seed()

    hparams = tutils.get_default_hparams()
    model = LightningTestModel(hparams)

    # Offline mode with local saves keeps the test hermetic.
    comet_dir = os.path.join(tmpdir, 'cometruns')
    logger = CometLogger(
        save_dir=comet_dir,
        project_name='general',
        workspace='dummy-test',
    )

    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        train_percent_check=0.05,
        logger=logger,
    )
    result = trainer.fit(model)

    trainer.logger.log_metrics({'acc': torch.ones(1)})

    assert result == 1, 'Training failed'
def test_comet_logger(tmpdir, monkeypatch):
    """Verify that basic functionality of Comet.ml logger works."""
    # Comet registers an atexit printer that breaks under pytest's
    # stdout/stderr capture; disable it.
    import atexit
    monkeypatch.setattr(atexit, "register", lambda _: None)

    tutils.reset_seed()

    hparams = tutils.get_hparams()
    model = LightningTestModel(hparams)

    # Offline mode with local saves keeps the test hermetic.
    comet_dir = os.path.join(tmpdir, "cometruns")
    logger = CometLogger(
        save_dir=comet_dir,
        project_name="general",
        workspace="dummy-test",
    )

    trainer = Trainer(
        default_save_path=tmpdir,
        max_epochs=1,
        train_percent_check=0.01,
        logger=logger,
    )
    result = trainer.fit(model)
    print('result finished')

    assert result == 1, "Training failed"
def test_comet_logger_dirs_creation(comet, comet_experiment, tmpdir, monkeypatch):
    """Test that the logger creates the folders and files in the right place."""
    _patch_comet_atexit(monkeypatch)

    # No API key -> offline mode; the version comes from the mocked guid.
    comet.config.get_api_key.return_value = None
    comet.generate_guid.return_value = "4321"

    logger = CometLogger(project_name="test", save_dir=tmpdir)
    assert not os.listdir(tmpdir)
    assert logger.mode == "offline"
    assert logger.save_dir == tmpdir
    assert logger.name == "test"
    assert logger.version == "4321"

    _ = logger.experiment
    comet_experiment.assert_called_once_with(offline_directory=tmpdir, project_name="test")

    # mock return values of experiment
    logger.experiment.id = "1"
    logger.experiment.project_name = "test"

    model = BoringModel()
    trainer = Trainer(
        default_root_dir=tmpdir,
        logger=logger,
        max_epochs=1,
        limit_train_batches=3,
        limit_val_batches=3,
    )
    assert trainer.log_dir == logger.save_dir
    trainer.fit(model)

    expected_dirpath = tmpdir / "test" / "1" / "checkpoints"
    assert trainer.checkpoint_callback.dirpath == expected_dirpath
    assert set(os.listdir(trainer.checkpoint_callback.dirpath)) == {"epoch=0-step=3.ckpt"}
    assert trainer.log_dir == logger.save_dir
def test_comet_pickle(tmpdir, monkeypatch):
    """Verify that pickling trainer with comet logger works."""
    # Comet registers an atexit printer that breaks under pytest's
    # stdout/stderr capture; disable it.
    import atexit
    monkeypatch.setattr(atexit, "register", lambda _: None)

    tutils.reset_seed()

    # hparams = tutils.get_hparams()
    # model = LightningTestModel(hparams)

    # Offline mode with local saves keeps the test hermetic.
    comet_dir = os.path.join(tmpdir, "cometruns")
    logger = CometLogger(
        save_dir=comet_dir,
        project_name="general",
        workspace="dummy-test",
    )

    trainer = Trainer(default_save_path=tmpdir, max_epochs=1, logger=logger)

    # A round-tripped trainer must still be able to log metrics.
    restored = pickle.loads(pickle.dumps(trainer))
    restored.logger.log_metrics({"acc": 1.0})
def test_comet_logger_dirs_creation(comet, comet_experiment, tmpdir, monkeypatch):
    """Test that the logger creates the folders and files in the right place."""
    _patch_comet_atexit(monkeypatch)

    # Without an API key the logger must fall back to offline mode.
    comet.config.get_api_key.return_value = None
    comet.generate_guid.return_value = "4321"

    logger = CometLogger(project_name='test', save_dir=tmpdir)
    assert not os.listdir(tmpdir)
    assert logger.mode == 'offline'
    assert logger.save_dir == tmpdir
    assert logger.name == 'test'
    assert logger.version == "4321"

    _ = logger.experiment
    comet_experiment.assert_called_once_with(offline_directory=tmpdir, project_name='test')

    # mock return values of experiment
    logger.experiment.id = '1'
    logger.experiment.project_name = 'test'

    model = EvalModelTemplate()
    trainer = Trainer(default_root_dir=tmpdir, logger=logger, max_epochs=1, limit_val_batches=3)
    trainer.fit(model)

    ckpt_dir = tmpdir / 'test' / "1" / 'checkpoints'
    assert trainer.checkpoint_callback.dirpath == ckpt_dir
    assert set(os.listdir(trainer.checkpoint_callback.dirpath)) == {'epoch=0.ckpt'}
def test_comet_logger_dirs_creation(tmpdir, monkeypatch):
    """Test that the logger creates the folders and files in the right place."""
    # Comet registers an atexit printer that breaks under pytest's
    # stdout/stderr capture; disable it.
    import atexit
    monkeypatch.setattr(atexit, 'register', lambda _: None)

    logger = CometLogger(project_name='test', save_dir=tmpdir)
    assert not os.listdir(tmpdir)
    assert logger.mode == 'offline'
    assert logger.save_dir == tmpdir

    _ = logger.experiment
    version = logger.version
    # An offline experiment writes a single <id>.zip archive into save_dir.
    assert set(os.listdir(tmpdir)) == {f'{logger.experiment.id}.zip'}

    model = EvalModelTemplate()
    trainer = Trainer(default_root_dir=tmpdir, logger=logger, max_epochs=1, limit_val_batches=3)
    trainer.fit(model)

    expected_ckpt_dir = tmpdir / 'test' / version / 'checkpoints'
    assert trainer.ckpt_path == trainer.weights_save_path == expected_ckpt_dir
    assert set(os.listdir(trainer.ckpt_path)) == {'epoch=0.ckpt'}
def test_comet_logger_manual_experiment_key(comet):
    """Test that Comet Logger respects manually set COMET_EXPERIMENT_KEY."""
    api_key = "key"
    experiment_key = "96346da91469407a85641afe5766b554"

    # Snapshot os.environ at the moment CometExperiment is instantiated, so we
    # can prove the key was still exported at construction time.
    instantiation_environ = {}

    def save_os_environ(*args, **kwargs):
        nonlocal instantiation_environ
        instantiation_environ = os.environ.copy()
        return DEFAULT

    with patch.dict(os.environ, {"COMET_EXPERIMENT_KEY": experiment_key}):
        with patch('pytorch_lightning.loggers.comet.CometExperiment',
                   side_effect=save_os_environ) as comet_experiment:
            logger = CometLogger(api_key=api_key)
            assert logger.version == experiment_key
            assert logger._experiment is None

            _ = logger.experiment
            comet_experiment.assert_called_once_with(api_key=api_key, project_name=None)

    assert instantiation_environ["COMET_EXPERIMENT_KEY"] == experiment_key
def main(args, model_name: str, reproducible: bool, comet: bool, wandb: bool):
    """Train, reload, and test a RightLaneModule, optionally logging to Comet or W&B."""
    if reproducible:
        seed_everything(42)
        args.deterministic = True
        args.benchmark = True

    if comet:
        from pytorch_lightning.loggers import CometLogger
        comet_logger = CometLogger(
            api_key=os.environ.get('COMET_API_KEY'),
            workspace=os.environ.get('COMET_WORKSPACE'),  # Optional
            project_name=os.environ.get('COMET_PROJECT_NAME'),  # Optional
            experiment_name=model_name,  # Optional
        )
        args.logger = comet_logger

    if wandb:
        from pytorch_lightning.loggers import WandbLogger
        args.logger = WandbLogger(
            project=os.environ.get('WANDB_PROJECT_NAME'),
            log_model=True,
            sync_step=True,
        )

    if args.default_root_dir is None:
        args.default_root_dir = 'results'

    # Keep only the best checkpoint by validation IoU.
    model_checkpoint = ModelCheckpoint(
        filename=model_name + '_{epoch}',
        save_top_k=1,
        monitor='val_iou',
        mode='max',
    )
    args.checkpoint_callback = model_checkpoint

    data = SimulatorDataModule(
        dataPath=args.dataPath,
        augment=args.augment,
        batch_size=args.batch_size,
        num_workers=8,
    )
    model = RightLaneModule(lr=args.learningRate, lrRatio=args.lrRatio, decay=args.decay, num_cls=4)

    # Parse all trainer options available from the command line.
    trainer = Trainer.from_argparse_args(args)
    trainer.fit(model, datamodule=data)

    # Reload the best checkpoint before testing.
    model = RightLaneModule.load_from_checkpoint(
        model_checkpoint.best_model_path, dataPath=args.dataPath, num_cls=4)

    # Upload the best weights to Comet.
    if comet:
        comet_logger.experiment.log_model(model_name + '_weights', model_checkpoint.best_model_path)

    trainer.test(model, datamodule=data)
def _get_comet_logger(exp_name: str, project_name: str) -> LightningLoggerBase:
    """Return a CometLogger for the given experiment and project.

    The API key is read from the COMET_API_KEY environment variable; keyword
    arguments are forwarded to the underlying comet_ml.Experiment class.
    """
    return CometLogger(
        api_key=os.environ.get('COMET_API_KEY'),
        experiment_name=exp_name,
        project_name=project_name,
    )
def build_comet_logger(save_dir: str, config: Bunch) -> CometLogger:
    """Create a CometLogger from the experiment config.

    When ``config.use_comet_experiments`` is false the API key is None, so the
    logger falls back to offline mode under ``save_dir``.
    """
    api_key = config.comet_api_key if config.use_comet_experiments else None
    return CometLogger(
        save_dir=save_dir,
        workspace=config.comet_workspace,
        project_name=config.comet_project_name,
        api_key=api_key,
        experiment_name=config.experiment_name,
    )
def train_mult(config, checkpoint_dir=None):
    """Ray Tune trainable: train and test the MulT model with sampled hyper-parameters.

    Args:
        config: hyper-parameter sample drawn from the Tune search space.
        checkpoint_dir: if given, resume model weights and epoch from it.
    """
    # Copy the sampled values onto the global hyper-parameter namespace.
    hyp_params.attn_dropout = config["attn_dropout"]
    hyp_params.attn_dropout_a = config["attn_dropout_a"]
    hyp_params.attn_dropout_v = config["attn_dropout_v"]
    hyp_params.embed_dropout = config["embed_dropout"]
    hyp_params.out_dropout = config["out_dropout"]
    hyp_params.relu_dropout = config["relu_dropout"]
    hyp_params.res_dropout = config["res_dropout"]
    # hyp_params.layers = int(config["layers"])
    # hyp_params.num_heads = int(config["num_heads"])
    # hyp_params.project_dim = int(config["num_heads"]) * int(config["head_dim"])
    hyp_params.lr = config["lr"]
    hyp_params.weight_decay = config["weight_decay"]

    # NOTE(review): the Comet API key is hard-coded in source; consider reading
    # it from an environment variable instead.
    comet_logger = CometLogger(
        api_key="cgss7piePhyFPXRw1J2uUEjkQ",
        workspace="transformer",
        project_name=hyp_params.project_name,
        save_dir="logs/comet_ml",
    )
    experiment_key = comet_logger.experiment.get_key()
    csv_logger = CSVLogger("logs/csv", name=experiment_key)

    early_stopping = EarlyStopping(
        monitor="valid_1mae", patience=10, verbose=True, mode="max"
    )
    checkpoint = ModelCheckpoint(save_top_k=1, monitor="valid_1mae", mode="max")
    # tune_reporter = TuneReportCallback(["valid_loss", "valid_1mae"])
    tune_checkpoint_reporter = TuneReportCheckpointCallback(
        metrics=["valid_loss", "valid_1mae"]
    )

    model = MULTModelWarpedAll(hyp_params, early_stopping=early_stopping)
    trainer = pl.Trainer(
        gpus=1,
        max_epochs=hyp_params.num_epochs,
        log_every_n_steps=1,
        callbacks=[early_stopping, checkpoint, tune_checkpoint_reporter],
        logger=[csv_logger, comet_logger],
        limit_train_batches=hyp_params.limit,
        limit_val_batches=hyp_params.limit,
        weights_summary="full",
        weights_save_path="logs/weights",
        progress_bar_refresh_rate=0,
    )

    # Resume from a Tune checkpoint when one is handed to us.
    if checkpoint_dir is not None:
        resume_state = th.load(os.path.join(checkpoint_dir, "checkpoint"))
        model.load_state_dict(resume_state["state_dict"])
        trainer.current_epoch = resume_state["epoch"]

    trainer.fit(model)

    # Evaluate the best checkpoint found during training.
    best_state = th.load(checkpoint.best_model_path)
    model.load_state_dict(best_state["state_dict"])
    trainer.test(model)
def test_comet_name_default(comet):
    """CometLogger.name must return its default without creating an Experiment."""
    api_key = "key"

    with patch('pytorch_lightning.loggers.comet.CometExperiment'):
        logger = CometLogger(api_key=api_key)

        assert logger._experiment is None
        assert logger.name == "comet-default"
        # Reading .name must not have lazily created an experiment.
        assert logger._experiment is None
def test_comet_name_project_name(comet):
    """CometLogger.name must echo the project name without creating an Experiment."""
    api_key = "key"
    project_name = "My Project Name"

    with patch('pytorch_lightning.loggers.comet.CometExperiment'):
        logger = CometLogger(api_key=api_key, project_name=project_name)

        assert logger._experiment is None
        assert logger.name == project_name
        # Reading .name must not have lazily created an experiment.
        assert logger._experiment is None
def test_comet_logger_experiment_name(comet):
    """Test that Comet Logger experiment name works correctly."""
    api_key = "key"
    experiment_name = "My Name"

    with patch("pytorch_lightning.loggers.comet.CometExperiment") as mock_experiment:
        logger = CometLogger(api_key=api_key, experiment_name=experiment_name)
        # The experiment is created lazily...
        assert logger._experiment is None
        _ = logger.experiment
        # ...and the name is applied via set_name, not the constructor.
        mock_experiment.assert_called_once_with(api_key=api_key, project_name=None)
        mock_experiment().set_name.assert_called_once_with(experiment_name)
def main(args):
    """Train a UNet segmentation system configured by an input JSON file."""
    torch.manual_seed(0)

    with open(args.input_json, "r") as f:
        input_json = json.load(f)

    # Unpack the whole configuration up front so a missing key fails early.
    dataset_path = input_json["dataset_path"]
    criteria = input_json["criteria"]
    in_channel = input_json["in_channel"]
    num_class = input_json["num_class"]
    epoch = input_json["epoch"]
    batch_size = input_json["batch_size"]
    num_workers = input_json["num_workers"]
    model_savepath = input_json["model_savepath"]
    learning_rate = input_json["learning_rate"]
    gpu_ids = input_json["gpu_ids"]
    api_key = input_json["api_key"]
    project_name = input_json["project_name"]
    experiment_name = input_json["experiment_name"]
    log = input_json["log"]

    comet_logger = CometLogger(
        api_key=api_key,
        project_name=project_name,
        experiment_name=experiment_name,
        save_dir=log,
    )

    #torch.manual_seed(0)
    system = UNetSystem(
        dataset_path=dataset_path,
        criteria=criteria,
        in_channel=in_channel,
        num_class=num_class,
        learning_rate=learning_rate,
        batch_size=batch_size,
        num_workers=num_workers,
        checkpoint=BestAndLatestModelCheckpoint(model_savepath),
    )

    trainer = pl.Trainer(
        num_sanity_val_steps=0,
        max_epochs=epoch,
        checkpoint_callback=None,
        logger=comet_logger,
        gpus=gpu_ids,
    )
    trainer.fit(system)
def main(args):
    """Train a multi-input UNet; log to Comet when an API key is supplied."""
    criteria = {"train": args.train_list, "val": args.val_list}
    image_path_list = [
        args.image_path_layer_1,
        args.image_path_layer_2,
        args.image_path_thin,
    ]

    system = UNetSystem(
        image_path_list=image_path_list,
        label_path=args.label_path,
        criteria=criteria,
        in_channel_1=args.in_channel_1,
        in_channel_2=args.in_channel_2,
        in_channel_thin=args.in_channel_thin,
        out_channel_thin=args.out_channel_thin,
        num_class=args.num_class,
        learning_rate=args.lr,
        batch_size=args.batch_size,
        checkpoint=BestAndLatestModelCheckpoint(args.model_savepath),
        num_workers=args.num_workers,
    )

    # The sentinel string "No" disables Comet logging.
    if args.api_key != "No":
        from pytorch_lightning.loggers import CometLogger
        comet_logger = CometLogger(
            api_key=args.api_key,
            project_name=args.project_name,
            experiment_name=args.experiment_name,
            save_dir=args.log,
        )
        trainer = pl.Trainer(
            num_sanity_val_steps=0,
            max_epochs=args.epoch,
            checkpoint_callback=None,
            logger=comet_logger,
            gpus=args.gpu_ids,
        )
    else:
        trainer = pl.Trainer(
            num_sanity_val_steps=0,
            max_epochs=args.epoch,
            checkpoint_callback=None,
            gpus=args.gpu_ids,
        )

    trainer.fit(system)

    # Make modeleweight read-only
    if not args.overwrite:
        for f in Path(args.model_savepath).glob("*.pkl"):
            print(f)
            os.chmod(f, 0o444)
def init_loggers(cfg: DictConfig):
    """Initialise the global Comet and TensorBoard loggers from a hydra config.

    Returns:
        Tuple of (tensorboard_logger, comet_logger).
    """
    comet_cfg = cfg.get('comet')
    tensorboard_cfg = cfg.get('tensorboard')

    global comet_logger, tensorboard_logger
    comet_logger = CometLogger(
        api_key=comet_cfg.get('COMET_API_KEY'),
        workspace=comet_cfg.get('workspace'),
        project_name=comet_cfg.get('project_name'),  # Optional
        experiment_name=comet_cfg.get('experiment_prefix_name') + comet_cfg.get('experiment_fixed_name'),  # Optional
        experiment_key=comet_cfg.get('experiment_key'),  # restore previous experiment
    )
    tensorboard_logger = TensorBoardLogger(
        save_dir=tensorboard_cfg.get("save_dir"),
        name=tensorboard_cfg.get("name"),
    )

    # Snapshot the source tree into the Comet experiment.
    comet_logger.experiment.log_code(file_name=None, folder='../../../../')

    return tensorboard_logger, comet_logger
def main():
    """Train the CLEVR classifier, logging to Comet."""
    # NOTE(review): the API key is committed in source; prefer an env var.
    comet_logger = CometLogger(
        api_key='5zqkkwKFbkhDgnFn7Alsby6py',
        workspace='clrkwng',
        project_name='clevr-network',
        experiment_name='lightning',
    )

    data_module = CLEVRDataModule()
    model = LightningCLEVRClassifier([1, 1, 1, 1], 3)

    trainer = pl.Trainer(
        gpus=1,
        profiler=True,
        logger=comet_logger,
        check_val_every_n_epoch=5,
        max_epochs=100,
    )
    trainer.fit(model, data_module)
def run(cfg: DictConfig) -> None:
    """
    Run pytorch-lightning model

    Args:
        cfg: hydra config
    """
    set_seed(cfg.training.seed)
    hparams = flatten_omegaconf(cfg)
    model = LitM5NBeats(hparams=hparams, cfg=cfg)

    early_stopping = pl.callbacks.EarlyStopping(**cfg.callbacks.early_stopping.params)
    model_checkpoint = pl.callbacks.ModelCheckpoint(**cfg.callbacks.model_checkpoint.params)
    lr_logger = pl.callbacks.LearningRateLogger()

    loggers = []
    if cfg.logging.log:
        tb_logger = TensorBoardLogger(save_dir=cfg.general.save_dir)
        # The experiment name is the current working directory's leaf folder.
        comet_logger = CometLogger(
            save_dir=cfg.general.save_dir,
            workspace=cfg.general.workspace,
            project_name=cfg.general.project_name,
            api_key=cfg.private.comet_api,
            experiment_name=os.getcwd().split('\\')[-1],
        )
        # wandb_logger = WandbLogger(name=os.getcwd().split('\\')[-1],
        #                            save_dir=cfg.general.save_dir,
        #                            project=cfg.general.project_name
        #                            )
        loggers = [tb_logger, comet_logger]

    trainer = pl.Trainer(
        logger=loggers,
        early_stop_callback=early_stopping,
        checkpoint_callback=model_checkpoint,
        callbacks=[lr_logger],
        gradient_clip_val=0.5,
        **cfg.trainer,
    )
    trainer.fit(model)
def main(args):
    """Dynamically import a UNet system and checkpoint class, then train."""
    criteria = {"train": args.train_list, "val": args.val_list}

    # Resolve the system and checkpoint classes from the requested module.
    system_path = "." + args.module_name + ".system"
    checkpoint_path = "." + args.module_name + ".modelCheckpoint"
    system_module = import_module(system_path, "model")
    checkpoint_module = import_module(checkpoint_path, "model")
    UNetSystem = getattr(system_module, args.system_name)
    checkpoint = getattr(checkpoint_module, args.checkpoint_name)

    system = UNetSystem(
        dataset_path=args.dataset_path,
        criteria=criteria,
        in_channel=args.in_channel,
        num_class=args.num_class,
        learning_rate=args.lr,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        checkpoint=checkpoint(args.model_savepath),
    )

    # The sentinel string "No" disables Comet logging.
    if args.api_key != "No":
        from pytorch_lightning.loggers import CometLogger
        comet_logger = CometLogger(
            api_key=args.api_key,
            project_name=args.project_name,
            experiment_name=args.experiment_name,
            save_dir=args.log,
        )
        trainer = pl.Trainer(
            num_sanity_val_steps=0,
            max_epochs=args.epoch,
            checkpoint_callback=None,
            logger=comet_logger,
            gpus=args.gpu_ids,
        )
    else:
        trainer = pl.Trainer(
            num_sanity_val_steps=0,
            max_epochs=args.epoch,
            checkpoint_callback=None,
            gpus=args.gpu_ids,
        )

    trainer.fit(system)

    # Make modeleweight read-only
    if not args.overwrite:
        for f in Path(args.model_savepath).glob("*.pkl"):
            print(f)
            os.chmod(f, 0o444)