def lightning_test(model: pl.LightningModule, checkpoint_path: str, test_output: str,
                   datamodule: Union[pl.LightningDataModule, DatasetPL], experiment_key: str,
                   logger: Any):
    """Run a checkpointed model over the test split and save one ``.npy`` file per sample.

    Args:
        model: LightningModule whose ``load_from_checkpoint`` restores the weights. The restored
            model's ``pad_input`` method and ``sqz_seq_len`` attribute are used to pad each input.
        checkpoint_path: path of the checkpoint to restore.
        test_output: directory that receives the predictions; it is created if missing.
        datamodule: object providing ``setup('test')`` and a ``test_dataset`` that yields
            ``(genome, target_weights)`` pairs and exposes an ``ix_to_filename`` mapping
            from sample index to output filename.
        experiment_key: Comet experiment key; when non-empty a ``CometLightningLogger`` is built.
        logger: logger object. It is replaced when ``experiment_key`` is set, but is not
            otherwise used by this function.
    """
    if experiment_key:  # also tolerates None, which len() would reject
        # NOTE(review): this logger is never consumed below (no Trainer is created here);
        # it is still constructed so any side effects of creating it are preserved.
        logger = CometLightningLogger(experiment_key=experiment_key, experiment_name=model.name)

    model = model.load_from_checkpoint(checkpoint_path=checkpoint_path)
    # Switch to inference mode so layers such as dropout / batch norm do not behave as in training.
    model.eval()

    os.makedirs(test_output, exist_ok=True)
    datamodule.setup('test')
    test_dataset = datamodule.test_dataset
    ix_to_filename = test_dataset.ix_to_filename

    with torch.no_grad():
        for i, (genome, _target_weights) in enumerate(test_dataset):
            # Add a leading batch dim and a dim at position 2.
            # Assumes genome is a numpy array (torch.from_numpy requires it) — exact rank TODO confirm.
            genome = torch.from_numpy(genome).unsqueeze(0).unsqueeze(2)
            # pad_input presumably returns a numpy array, since it is cast with .astype — confirm.
            genome = model.pad_input(genome, model.sqz_seq_len)
            genome = torch.from_numpy(genome.astype('float32'))

            distribution = model(genome).squeeze()
            # Convert explicitly to a host numpy array before saving.
            if isinstance(distribution, torch.Tensor):
                distribution = distribution.cpu().numpy()

            np.save(os.path.join(test_output, ix_to_filename[i]), distribution)
def run_common(
        model_class: pl.LightningModule,
        datamodule_class: pl.LightningDataModule,
        infer_fn: callable = lambda x, arg: print("infer_fn do nothing~"),
        export_fn: callable = lambda x, arg: print("export_fn do nothing~")):
    """Shared command-line entry point: parse arguments and run the requested stage.

    Dispatches on ``args.stage`` (from ``parser_args()``):

    * ``'fit'`` / ``'test'``: read the YAML config at ``args.config``. Build the model (restored
      from ``args.ckpt`` when given, otherwise from ``config['model']``), the datamodule
      (``config['dataset']``), the checkpoint callback (``config['checkpoint']``) and the
      TensorBoard logger (``config['logger']``). Also build any extra callbacks listed under
      ``config['callbacks']``. Then run ``trainer.fit`` or ``trainer.test``.
    * ``'infer'`` / ``'export'``: restore ``model_class`` from ``args.ckpt``. Call
      ``infer_fn(model, extra_args)`` or ``export_fn(model, extra_args)``, where ``extra_args``
      comes from ``parser_extra_args(args.extra)``.

    Args:
        model_class: LightningModule subclass to build or restore.
        datamodule_class: LightningDataModule subclass, built from ``config['dataset']``.
        infer_fn: callback for the ``'infer'`` stage; by default it only prints a message.
        export_fn: callback for the ``'export'`` stage; by default it only prints a message.
    """
    args = parser_args()
    model: pl.LightningModule = None

    # build model
    if args.stage in ['fit', 'test']:
        with open(args.config, 'r') as f:
            config: dict = safe_load(f)
        if args.ckpt:
            print("Load from checkpoint", args.ckpt)
            model = model_class.load_from_checkpoint(args.ckpt, strict=False, **config['model'])
        else:
            model = model_class(**config['model'])
        datamodule = datamodule_class(**config['dataset'])
        ckpt_callback = CusModelCheckpoint(**config['checkpoint'])
        logger = TensorBoardLogger(**config['logger'])

        # Optional extra callbacks: each entry maps a CALLBACKDICT key to that callback's kwargs.
        callbacks = None
        if config.get('callbacks') is not None:
            callbacks = [CALLBACKDICT[name](**kwargs) for name, kwargs in config['callbacks'].items()]

        trainer_kwargs = dict(config['trainer'])
        if callbacks is not None:
            # Hand the configured callbacks to the Trainer; otherwise they are built and then ignored.
            trainer_kwargs['callbacks'] = callbacks
        # NOTE(review): passing a callback instance as `checkpoint_callback` relies on the pinned
        # PyTorch Lightning version accepting it — confirm before upgrading Lightning.
        trainer = pl.Trainer(checkpoint_callback=ckpt_callback, logger=logger, **trainer_kwargs)

        if args.stage == 'fit':
            trainer.fit(model, datamodule)
        elif args.stage == 'test':
            trainer.test(model, datamodule)
    elif args.stage in ('infer', 'export'):
        print("Load from checkpoint", args.ckpt)
        model = model_class.load_from_checkpoint(args.ckpt, strict=False)
        args_list = parser_extra_args(args.extra)
        stage_fn = infer_fn if args.stage == 'infer' else export_fn
        stage_fn(model, args_list)
def load_model_from_checkpoint(ckpt_path, model_class: LightningModule = UNet):
    """Restore ``model_class`` from ``ckpt_path``, moving it to CUDA when a GPU is available.

    Args:
        ckpt_path: path to the checkpoint file.
        model_class: LightningModule subclass used to restore the checkpoint (defaults to ``UNet``).

    Returns:
        The restored model, or ``None`` when ``ckpt_path`` does not exist.
    """
    if not Path(ckpt_path).exists():
        return None

    restored = model_class.load_from_checkpoint(ckpt_path)
    if restored is None or not torch.cuda.is_available():
        return restored

    restored.to(torch.device("cuda"))
    return restored
def load_model(module: pl.LightningModule, checkpoint_path: str):
    """Restore a PyTorch Lightning module from a checkpoint, ready for inference.

    Args:
        module: LightningModule class (or instance) whose ``load_from_checkpoint`` is called.
        checkpoint_path: location of the checkpoint file.

    Returns:
        The restored model, after ``eval()`` has been called on it.
    """
    restored = module.load_from_checkpoint(checkpoint_path)
    restored.eval()  # leave training mode before handing the model back
    return restored