Example #1
import os
from typing import Any, Union

import numpy as np
import torch
import pytorch_lightning as pl

# DatasetPL and CometLightningLogger are project-specific imports.
def lightning_test(model: pl.LightningModule, checkpoint_path: str,
                   test_output: str,
                   datamodule: Union[pl.LightningDataModule, DatasetPL],
                   experiment_key: str, logger: Any):
    if experiment_key:
        logger = CometLightningLogger(experiment_key=experiment_key,
                                      experiment_name=model.name)

    # trainer = pl.Trainer(logger=logger)
    # load_from_checkpoint is a classmethod; it returns a restored instance
    model = model.load_from_checkpoint(checkpoint_path=checkpoint_path)
    model.eval()

    datamodule.setup('test')
    ix_to_filename = datamodule.test_dataset.ix_to_filename
    with torch.no_grad():
        for i, (genome, target_weights) in enumerate(datamodule.test_dataset):
            # shape the raw sample into (batch, seq_len, 1)
            genome = torch.from_numpy(genome)
            genome = genome.unsqueeze(0)
            genome = genome.unsqueeze(2)

            # pad_input returns a numpy array here, so convert back to a tensor
            genome = model.pad_input(genome, model.sqz_seq_len)
            genome = torch.from_numpy(genome.astype('float32'))

            distribution = model(genome)
            distribution = distribution.squeeze()

            # write the predicted distribution to disk, one file per sample
            filename = ix_to_filename[i]
            np.save(os.path.join(test_output, filename), distribution)
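
A minimal driver sketch for this helper; MyModel, MyDataModule, and every argument value below are placeholders, not names from the original project:

# All names and paths here are hypothetical stand-ins.
model = MyModel()
dm = MyDataModule()
lightning_test(model, checkpoint_path="checkpoints/best.ckpt",
               test_output="predictions/", datamodule=dm,
               experiment_key="", logger=None)  # empty key keeps the given logger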
Example #2
from typing import Callable, Optional

import pytorch_lightning as pl
from pytorch_lightning.loggers import TensorBoardLogger
from yaml import safe_load

# parser_args, parser_extra_args, CusModelCheckpoint, and CALLBACKDICT are
# project-specific helpers.
def run_common(
    model_class: pl.LightningModule,
    datamodule_class: pl.LightningDataModule,
    infer_fn: Callable = lambda x, arg: print("infer_fn do nothing~"),
    export_fn: Callable = lambda x, arg: print("export_fn do nothing~")):
    args = parser_args()
    model: Optional[pl.LightningModule] = None
    # build the model for the requested stage
    if args.stage in ['fit', 'test']:
        with open(args.config, 'r') as f:
            config: dict = safe_load(f)
        if args.ckpt:
            print("Load from checkpoint", args.ckpt)
            model = model_class.load_from_checkpoint(args.ckpt,
                                                     strict=False,
                                                     **config['model'])
        else:
            model = model_class(**config['model'])
        datamodule = datamodule_class(**config['dataset'])
        ckpt_callback = CusModelCheckpoint(**config['checkpoint'])
        logger = TensorBoardLogger(**config['logger'])
        callbacks = None
        if config.get('callbacks') is not None:
            # instantiate each configured callback through the registry
            callbacks = [CALLBACKDICT[k](**v)
                         for k, v in config['callbacks'].items()]

        trainer = pl.Trainer(checkpoint_callback=ckpt_callback,
                             logger=logger,
                             callbacks=callbacks,
                             **config['trainer'])
        if args.stage == 'fit':
            trainer.fit(model, datamodule)
        elif args.stage == 'test':
            trainer.test(model, datamodule)
    elif args.stage == 'infer':
        print("Load from checkpoint", args.ckpt)
        model = model_class.load_from_checkpoint(args.ckpt, strict=False)
        args_list = parser_extra_args(args.extra)
        infer_fn(model, args_list)
    elif args.stage == 'export':
        print("Load from checkpoint", args.ckpt)
        model = model_class.load_from_checkpoint(args.ckpt, strict=False)
        args_list = parser_extra_args(args.extra)
        export_fn(model, args_list)
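
The YAML file given as --config must parse into a dict with the section names unpacked above. A minimal sketch of the parsed result; the section keys come from the code, every inner value here is hypothetical:

# Sketch of the dict safe_load() must return; inner values are placeholders.
config = {
    'model': {'lr': 1e-3},                  # kwargs for model_class(...)
    'dataset': {'batch_size': 8},           # kwargs for datamodule_class(...)
    'checkpoint': {'monitor': 'val_loss'},  # kwargs for CusModelCheckpoint(...)
    'logger': {'save_dir': 'logs'},         # kwargs for TensorBoardLogger(...)
    'trainer': {'max_epochs': 10},          # extra kwargs for pl.Trainer(...)
    'callbacks': None,      # optional: {name: kwargs}, resolved via CALLBACKDICT
}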
Example #3
from pathlib import Path
import torch
from pytorch_lightning import LightningModule

# UNet is the project's own default model class.
def load_model_from_checkpoint(ckpt_path, model_class: LightningModule = UNet):
    # restore the model only when the checkpoint file actually exists
    if not Path(ckpt_path).exists():
        return None
    model = model_class.load_from_checkpoint(ckpt_path)
    if torch.cuda.is_available():
        model.to(torch.device("cuda"))  # move to the GPU when available
    return model
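
A hedged usage sketch; the checkpoint path is a placeholder:

# Returns None when the checkpoint file does not exist.
model = load_model_from_checkpoint("checkpoints/unet.ckpt")
if model is None:
    print("no checkpoint found")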
Example #4
import pytorch_lightning as pl

def load_model(module: pl.LightningModule, checkpoint_path: str):
    """Load a PyTorch Lightning module from a checkpoint."""
    model = module.load_from_checkpoint(checkpoint_path)
    model.eval()  # switch to inference mode
    return model
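
A short usage sketch, assuming MyModel is the caller's own LightningModule subclass; the class name, path, and input shape are all placeholders:

import torch

# MyModel, the path, and the tensor shape are hypothetical.
model = load_model(MyModel, "checkpoints/last.ckpt")
with torch.no_grad():
    prediction = model(torch.randn(1, 3, 224, 224))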