def configure_optimizers(self):
        if 'decoder_lr' in self.cfg.optimizer.params.keys():
            # give the decoder its own learning rate; the encoder keeps the base lr
            params = [
                {
                    'params': self.model.decoder.parameters(),
                    'lr': self.cfg.optimizer.params.decoder_lr
                },
                {
                    'params': self.model.encoder.parameters(),
                    'lr': self.cfg.optimizer.params.lr
                },
            ]
            optimizer = load_obj(self.cfg.optimizer.class_name)(params)

        else:
            optimizer = load_obj(self.cfg.optimizer.class_name)(
                self.model.parameters(), **self.cfg.optimizer.params)
        scheduler = load_obj(self.cfg.scheduler.class_name)(
            optimizer, **self.cfg.scheduler.params)

        return (
            [optimizer],
            [{
                'scheduler': scheduler,
                'interval': self.cfg.scheduler.step,
                'monitor': self.cfg.scheduler.monitor
            }],
        )
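This configure_optimizers variant reads a base learning rate and an optional decoder_lr from the Hydra optimizer node, plus a scheduler node with class_name, params, step and monitor. A minimal sketch of the config shape it assumes, written with OmegaConf (the class names and values are placeholders, not the project's defaults):

from omegaconf import OmegaConf

# Hypothetical config fragment matching the attribute access above:
# cfg.optimizer.class_name, cfg.optimizer.params.lr / .decoder_lr,
# cfg.scheduler.class_name / .params / .step / .monitor.
cfg = OmegaConf.create({
    'optimizer': {
        'class_name': 'torch.optim.AdamW',
        'params': {'lr': 1e-3, 'decoder_lr': 1e-4},
    },
    'scheduler': {
        'class_name': 'torch.optim.lr_scheduler.ReduceLROnPlateau',
        'params': {'mode': 'min', 'factor': 0.1, 'patience': 3},
        'step': 'epoch',
        'monitor': 'valid_loss',
    },
})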
def test_schedulers(sch_name: str) -> None:
    scheduler_name = sch_name.split('.')[0]
    with initialize(config_path='../conf'):
        cfg = compose(
            config_name='config', overrides=[f'scheduler={scheduler_name}', 'optimizer=sgd', 'private=default']
        )
        optimizer = load_obj(cfg.optimizer.class_name)(torch.nn.Linear(1, 1).parameters(), **cfg.optimizer.params)
        load_obj(cfg.scheduler.class_name)(optimizer, **cfg.scheduler.params)
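test_schedulers receives the scheduler config file name as an argument, so it is presumably driven by pytest parametrization over the files in conf/scheduler. A hedged sketch of such a driver (the listed file names are placeholders):

import pytest

# Hypothetical parametrization: run the test once per scheduler config file.
@pytest.mark.parametrize('sch_name', ['plateau.yaml', 'cosine.yaml', 'step.yaml'])
def test_schedulers_all(sch_name: str) -> None:
    test_schedulers(sch_name)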
Example #3
0
def run(cfg: DictConfig) -> None:
    """
    Run pytorch-lightning model

    Args:
        cfg: hydra config

    """
    set_seed(cfg.training.seed)
    run_name = os.path.basename(os.getcwd())
    hparams = flatten_omegaconf(cfg)

    cfg.callbacks.model_checkpoint.params.filepath = os.getcwd() + cfg.callbacks.model_checkpoint.params.filepath
    callbacks = []
    for callback in cfg.callbacks.other_callbacks:
        if callback.params:
            callback_instance = load_obj(callback.class_name)(**callback.params)
        else:
            callback_instance = load_obj(callback.class_name)()
        callbacks.append(callback_instance)

    loggers = []
    if cfg.logging.log:
        for logger in cfg.logging.loggers:
            if 'experiment_name' in logger.params.keys():
                logger.params['experiment_name'] = run_name
            loggers.append(load_obj(logger.class_name)(**logger.params))

    callbacks.append(EarlyStopping(**cfg.callbacks.early_stopping.params))

    trainer = pl.Trainer(
        logger=loggers,
        # early_stop_callback=EarlyStopping(**cfg.callbacks.early_stopping.params),
        checkpoint_callback=ModelCheckpoint(**cfg.callbacks.model_checkpoint.params),
        callbacks=callbacks,
        **cfg.trainer,
    )

    model = load_obj(cfg.training.lightning_module_name)(hparams=hparams, cfg=cfg)
    dm = load_obj(cfg.datamodule.data_module_name)(hparams=hparams, cfg=cfg)
    trainer.fit(model, dm)

    if cfg.general.save_pytorch_model and cfg.general.save_best:
        if os.path.exists(trainer.checkpoint_callback.best_model_path):  # type: ignore
            best_path = trainer.checkpoint_callback.best_model_path  # type: ignore
            # extract file name without folder
            save_name = os.path.basename(os.path.normpath(best_path))
            model = model.load_from_checkpoint(best_path, hparams=hparams, cfg=cfg, strict=False)
            model_name = f'saved_models/best_{save_name}'.replace('.ckpt', '.pth')
            torch.save(model.model.state_dict(), model_name)
        else:
            os.makedirs('saved_models', exist_ok=True)
            model_name = 'saved_models/last.pth'
            torch.save(model.model.state_dict(), model_name)

    if cfg.general.convert_to_jit and os.path.exists(trainer.checkpoint_callback.best_model_path):  # type: ignore
        # define save_name here as well, in case the best-checkpoint branch above did not run
        save_name = os.path.basename(os.path.normpath(trainer.checkpoint_callback.best_model_path))  # type: ignore
        convert_to_jit(model, save_name, cfg)
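load_obj appears in every snippet here but its definition is not shown; it resolves a dotted 'module.ClassName' string into the object itself. A minimal importlib-based sketch of such a helper (the real implementation may differ):

import importlib
from typing import Any

def load_obj(obj_path: str, default_obj_path: str = '') -> Any:
    """Resolve a dotted path such as 'torch.optim.AdamW' into the object it names."""
    obj_path_list = obj_path.rsplit('.', 1)
    module_path = obj_path_list.pop(0) if len(obj_path_list) > 1 else default_obj_path
    obj_name = obj_path_list[0]
    module = importlib.import_module(module_path)
    if not hasattr(module, obj_name):
        raise AttributeError(f'Object `{obj_name}` not found in `{module_path}`.')
    return getattr(module, obj_name)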
Example #4
    def __init__(self, hparams: Dict[str, float], cfg: DictConfig):
        super(LitImageClassification, self).__init__()
        self.cfg = cfg
        self.hparams: Dict[str, float] = hparams
        self.model = load_obj(cfg.model.class_name)(cfg=cfg)
        if not cfg.metric.params:
            self.metric = load_obj(cfg.metric.class_name)()
        else:
            self.metric = load_obj(cfg.metric.class_name)(**cfg.metric.params)
    def __init__(self, hparams: DictConfig, cfg: DictConfig, tag_to_idx: Dict):
        super(LitNER, self).__init__()
        self.cfg = cfg
        self.hparams: Dict[str, float] = hparams
        self.tag_to_idx = tag_to_idx
        self.model = load_obj(cfg.model.class_name)(
            embeddings_dim=cfg.datamodule.embeddings_dim,
            tag_to_idx=tag_to_idx,
            **cfg.model.params)
        if not cfg.metric.params:
            self.metric = load_obj(cfg.metric.class_name)()
        else:
            self.metric = load_obj(cfg.metric.class_name)(**cfg.metric.params)
    def __init__(self, cfg: DictConfig) -> None:
        """
        Model class.

        Args:
            cfg: main config
        """
        super().__init__()
        self.encoder = load_obj(
            cfg.model.encoder.class_name)(**cfg.model.encoder.params)
        self.decoder = load_obj(cfg.model.decoder.class_name)(
            output_dimension=self.encoder.output_dimension,
            **cfg.model.decoder.params)
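The snippet above only wires the encoder and decoder together; the forward pass is not shown. A hedged sketch of what it could look like, assuming the decoder simply consumes the encoder output:

    def forward(self, x):
        # hypothetical forward: encode the input, then map features to the output space
        features = self.encoder(x)
        return self.decoder(features)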
    def configure_optimizers(self):
        optimizer = load_obj(self.cfg.optimizer.class_name)(
            self.model.parameters(), **self.cfg.optimizer.params)
        scheduler = load_obj(self.cfg.scheduler.class_name)(
            optimizer, **self.cfg.scheduler.params)

        return (
            [optimizer],
            [{
                'scheduler': scheduler,
                'interval': self.cfg.scheduler.step,
                'monitor': self.cfg.scheduler.monitor
            }],
        )
Example #8
    def setup(self, stage=None):
        mapping_dict = {
            'n01440764': 0,
            'n02102040': 1,
            'n02979186': 2,
            'n03000684': 3,
            'n03028079': 4,
            'n03394916': 5,
            'n03417042': 6,
            'n03425413': 7,
            'n03445777': 8,
            'n03888257': 9,
        }
        train_labels = []
        train_images = []
        for folder in glob.glob(f'{self.cfg.datamodule.path}/train/*'):
            class_name = os.path.basename(os.path.normpath(folder))
            for filename in glob.glob(f'{folder}/*'):
                train_labels.append(mapping_dict[class_name])
                train_images.append(filename)

        val_labels = []
        val_images = []

        for folder in glob.glob(f'{self.cfg.datamodule.path}/val/*'):
            class_name = os.path.basename(os.path.normpath(folder))
            for filename in glob.glob(f'{folder}/*'):
                val_labels.append(mapping_dict[class_name])
                val_images.append(filename)

        if self.cfg.training.debug:
            train_labels = train_labels[:1000]
            train_images = train_images[:1000]
            val_labels = val_labels[:1000]
            val_images = val_images[:1000]

        # train dataset
        dataset_class = load_obj(self.cfg.datamodule.class_name)

        # initialize augmentations
        train_augs = load_augs(self.cfg['augmentation']['train']['augs'])
        valid_augs = load_augs(self.cfg['augmentation']['valid']['augs'])

        self.train_dataset = dataset_class(
            image_names=train_images,
            labels=train_labels,
            transforms=train_augs,
            mode='train',
            labels_to_ohe=self.cfg.datamodule.labels_to_ohe,
            n_classes=self.cfg.training.n_classes,
        )
        self.valid_dataset = dataset_class(
            image_names=val_images,
            labels=val_labels,
            transforms=valid_augs,
            mode='valid',
            labels_to_ohe=self.cfg.datamodule.labels_to_ohe,
            n_classes=self.cfg.training.n_classes,
        )
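setup only builds the datasets; the matching dataloaders are created elsewhere (a train_dataloader example appears further below). A hedged sketch of the corresponding val_dataloader, reusing the same cfg.datamodule fields:

    def val_dataloader(self):
        # sketch: mirrors the train loader but keeps the order fixed
        return torch.utils.data.DataLoader(
            self.valid_dataset,
            batch_size=self.cfg.datamodule.batch_size,
            num_workers=self.cfg.datamodule.num_workers,
            pin_memory=self.cfg.datamodule.pin_memory,
            shuffle=False,
        )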
Example #9
def run(cfg: DictConfig) -> None:
    """
    Run pytorch-lightning model

    Args:
        cfg: hydra config

    """
    set_seed(cfg.training.seed)
    hparams = flatten_omegaconf(cfg)

    cfg.callbacks.model_checkpoint.params.filepath = os.getcwd() + cfg.callbacks.model_checkpoint.params.filepath
    callbacks = []
    for callback in cfg.callbacks.other_callbacks:
        if callback.params:
            callback_instance = load_obj(callback.class_name)(**callback.params)
        else:
            callback_instance = load_obj(callback.class_name)()
        callbacks.append(callback_instance)

    loggers = []
    if cfg.logging.log:
        for logger in cfg.logging.loggers:
            loggers.append(load_obj(logger.class_name)(**logger.params))

    trainer = pl.Trainer(
        logger=loggers,
        early_stop_callback=EarlyStopping(**cfg.callbacks.early_stopping.params),
        checkpoint_callback=ModelCheckpoint(**cfg.callbacks.model_checkpoint.params),
        callbacks=callbacks,
        **cfg.trainer,
    )

    model = load_obj(cfg.training.lightning_module_name)(hparams=hparams, cfg=cfg)
    dm = load_obj(cfg.datamodule.data_module_name)(hparams=hparams, cfg=cfg)
    trainer.fit(model, dm)

    if cfg.general.save_pytorch_model:
        # save as a simple torch model
        # TODO save not last, but best - for this load the checkpoint and save pytorch model from it
        os.makedirs('saved_models', exist_ok=True)
        model_name = 'saved_models/best.pth'
        print(model_name)
        torch.save(model.model.state_dict(), model_name)
    def __init__(self, cfg: DictConfig):
        super(LitImageClassification, self).__init__()
        self.cfg = cfg
        self.model = load_obj(cfg.model.class_name)(cfg=cfg)
        if 'params' in self.cfg.loss:
            self.loss = load_obj(cfg.loss.class_name)(**self.cfg.loss.params)
        else:
            self.loss = load_obj(cfg.loss.class_name)()
        self.metrics = torch.nn.ModuleDict(
            {
                self.cfg.metric.metric.metric_name: load_obj(self.cfg.metric.metric.class_name)(
                    **cfg.metric.metric.params
                )
            }
        )
        if 'other_metrics' in self.cfg.metric.keys():
            for metric in self.cfg.metric.other_metrics:
                self.metrics.update({metric.metric_name: load_obj(metric.class_name)(**metric.params)})
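The cfg.metric node read here nests a main metric entry plus an optional other_metrics list. A hedged OmegaConf sketch of that shape (the metric names and torchmetrics classes are placeholders):

from omegaconf import OmegaConf

# Hypothetical cfg.metric fragment matching the attribute access above.
metric_cfg = OmegaConf.create({
    'metric': {
        'metric': {
            'metric_name': 'accuracy',
            'class_name': 'torchmetrics.Accuracy',
            'params': {'task': 'multiclass', 'num_classes': 10},
        },
        'other_metrics': [
            {
                'metric_name': 'f1',
                'class_name': 'torchmetrics.F1Score',
                'params': {'task': 'multiclass', 'num_classes': 10},
            },
        ],
    },
})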
Example #11
    def __init__(self, cfg: DictConfig, tag_to_idx: Dict):
        super(LitNER, self).__init__()
        self.cfg = cfg
        self.tag_to_idx = tag_to_idx
        self.model = load_obj(cfg.model.class_name)(
            embeddings_dim=cfg.datamodule.embeddings_dim,
            tag_to_idx=tag_to_idx,
            **cfg.model.params)
        self.metrics = torch.nn.ModuleDict({
            self.cfg.metric.metric.metric_name:
            load_obj(self.cfg.metric.metric.class_name)(**cfg.metric.metric.params)
        })
        if 'other_metrics' in self.cfg.metric.keys():
            for metric in self.cfg.metric.other_metrics:
                self.metrics.update({
                    metric.metric_name: load_obj(metric.class_name)(**metric.params)
                })
Example #12
    def setup(self, stage=None):

        train = pd.read_csv(self.cfg.datamodule.train_path)
        train = train.rename(columns={'image_id': 'image_name'})
        train['image_name'] = train['image_name'] + '.jpg'

        # for fast training
        if self.cfg.training.debug:
            train, valid = train_test_split(train, test_size=0.1, random_state=self.cfg.training.seed)
            train = train[:1000]
            valid = valid[:1000]

        else:

            folds = list(
                stratified_group_k_fold(
                    X=train.index, y=train['target'], groups=train['patient_id'], k=self.cfg.datamodule.n_folds
                )
            )
            train_idx, valid_idx = folds[self.cfg.datamodule.fold_n]

            valid = train.iloc[valid_idx]
            train = train.iloc[train_idx]

        # train dataset
        dataset_class = load_obj(self.cfg.datamodule.class_name)

        # initialize augmentations
        train_augs = load_augs(self.cfg['augmentation']['train']['augs'])
        valid_augs = load_augs(self.cfg['augmentation']['valid']['augs'])

        # self.train_dataset = dataset_class(df=train, mode='train', img_path=self.cfg.datamodule.train_image_path,
        #                                    transforms=train_augs)
        #
        # self.valid_dataset = dataset_class(df=valid, mode='valid', img_path=self.cfg.datamodule.train_image_path,
        #                                    transforms=valid_augs)
        self.train_dataset = dataset_class(
            image_names=train['image_name'].values,
            transforms=train_augs,
            labels=train['target'].values,
            img_path=self.cfg.datamodule.train_image_path,
            mode='train',
            labels_to_ohe=False,
            n_classes=self.cfg.training.n_classes,
        )
        self.valid_dataset = dataset_class(
            image_names=valid['image_name'].values,
            transforms=valid_augs,
            labels=valid['target'].values,
            img_path=self.cfg.datamodule.train_image_path,
            mode='valid',
            labels_to_ohe=False,
            n_classes=self.cfg.training.n_classes,
        )
    def train_dataloader(self):
        train_loader = torch.utils.data.DataLoader(
            self.train_dataset,
            batch_size=self.cfg.datamodule.batch_size,
            num_workers=self.cfg.datamodule.num_workers,
            pin_memory=self.cfg.datamodule.pin_memory,
            shuffle=True,
            collate_fn=load_obj(self.cfg.datamodule.collate_fn)(**self.cfg.datamodule.mix_params)
            if self.cfg.datamodule.collate_fn else None,
        )
        return train_loader
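The collate_fn above is a class instantiated with mix_params, which points at a batch-level augmentation applied during collation. A heavily hedged sketch of such a callable collator built around mixup (this is an illustration, not the project's actual class):

import numpy as np
import torch

class MixCollator:
    """Hypothetical collate callable: stacks a batch, then mixes samples."""

    def __init__(self, alpha: float = 0.4):
        self.alpha = alpha

    def __call__(self, batch):
        images = torch.stack([item[0] for item in batch])
        labels = torch.stack([torch.as_tensor(item[1]) for item in batch])
        lam = float(np.random.beta(self.alpha, self.alpha))
        perm = torch.randperm(images.size(0))
        mixed = lam * images + (1 - lam) * images[perm]
        # return both label sets and the mixing coefficient for the loss
        return mixed, labels, labels[perm], lam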
Example #14
def get_training_datasets(cfg: DictConfig) -> Dict:
    """
    Get datasets for modelling

    Args:
        cfg: config

    Returns:
        dict with datasets
    """

    train = pd.read_csv(cfg.data.train_path)
    train = train.rename(columns={'image_id': 'image_name'})

    # for fast training
    if cfg.training.debug:
        train, valid = train_test_split(train,
                                        test_size=0.1,
                                        random_state=cfg.training.seed)
        train = train[:100]
        valid = valid[:100]

    else:

        folds = list(
            stratified_group_k_fold(X=train.index,
                                    y=train['target'],
                                    groups=train['patient_id'],
                                    k=cfg.data.n_folds))
        train_idx, valid_idx = folds[cfg.data.fold_n]

        valid = train.iloc[valid_idx]
        train = train.iloc[train_idx]

    # train dataset
    dataset_class = load_obj(cfg.dataset.class_name)

    # initialize augmentations
    train_augs = load_augs(cfg['augmentation']['train']['augs'])
    valid_augs = load_augs(cfg['augmentation']['valid']['augs'])

    train_dataset = dataset_class(df=train,
                                  mode='train',
                                  img_path=cfg.data.train_image_path,
                                  transforms=train_augs)

    valid_dataset = dataset_class(df=valid,
                                  mode='valid',
                                  img_path=cfg.data.train_image_path,
                                  transforms=valid_augs)

    return {'train': train_dataset, 'valid': valid_dataset}
Example #15
def load_augs(cfg: DictConfig) -> A.Compose:
    """
    Load albumentations

    Args:
        cfg: list-style config describing the augmentations

    Returns:
        compose object
    """
    augs = []
    for a in cfg:
        if a['class_name'] == 'albumentations.OneOf':
            small_augs = []
            for small_aug in a['params']:
                # yaml can't contain tuples, so we need to convert manually
                params = {
                    k: (v if not isinstance(v, omegaconf.listconfig.ListConfig)
                        else tuple(v))
                    for k, v in small_aug['params'].items()
                }
                aug = load_obj(small_aug['class_name'])(**params)
                small_augs.append(aug)
            aug = load_obj(a['class_name'])(small_augs)
            augs.append(aug)

        else:
            params = {
                k:
                (v if type(v) != omegaconf.listconfig.ListConfig else tuple(v))
                for k, v in a['params'].items()
            }
            aug = load_obj(a['class_name'])(**params)
            augs.append(aug)

    return A.Compose(augs)
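A short usage sketch of load_augs with a list-style config; list values such as hue_shift_limit become tuples before reaching albumentations (the transform parameters are placeholders):

from omegaconf import OmegaConf

augs_cfg = OmegaConf.create([
    {'class_name': 'albumentations.Resize', 'params': {'height': 224, 'width': 224}},
    {'class_name': 'albumentations.HorizontalFlip', 'params': {'p': 0.5}},
    {'class_name': 'albumentations.HueSaturationValue', 'params': {'hue_shift_limit': [-20, 20]}},
])
transforms = load_augs(augs_cfg)  # -> A.Compose([...])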
Example #16
def get_vectorizer(cfg: DictConfig, word_to_idx: Dict) -> nn.Module:
    """
    Get model

    Args:
        word_to_idx:
        cfg: config

    Returns:
         initialized model
    """

    vectorizer = load_obj(cfg.datamodule.vectorizer_class_name)
    vectorizer = vectorizer(
        word_to_idx=word_to_idx,
        embeddings_path=cfg.datamodule.embeddings_path,
        embeddings_type=cfg.datamodule.embeddings_type,
        embeddings_dim=cfg.datamodule.embeddings_dim,
    )

    return vectorizer
Example #17
    def setup(self, stage=None):
        # with open(f'{self.cfg.datamodule.folder_path}{self.cfg.datamodule.file_name}', 'r', encoding='utf-8') as f:
        ner_data = self.load_sentences(f'{self.cfg.datamodule.folder_path}{self.cfg.datamodule.file_name}')

        # generate tag_to_idx
        labels = [sample['labels'] for sample in ner_data]
        flat_labels = list({label for sublist in labels for label in sublist})
        entities_names = sorted({label.split('-')[1] for label in flat_labels if label != 'O'})
        if self.cfg.datamodule.tag_to_idx_from_labels:
            # build contiguous, deterministic indices from the observed tags; 'O' and 'PAD' go last
            self.tag_to_idx = {tag: i for i, tag in enumerate(sorted(tag for tag in flat_labels if tag != 'O'))}
            for special_tag in ['O', 'PAD']:
                self.tag_to_idx[special_tag] = len(self.tag_to_idx)
        else:
            self.tag_to_idx = _generate_tag_to_idx(self.cfg, entities_names)

        # load or generate word_to_idx
        if self.cfg.datamodule.word_to_idx_name:
            with open(
                f'{self.cfg.datamodule.folder_path}{self.cfg.datamodule.word_to_idx_name}', 'r', encoding='utf-8'
            ) as f:
                self.word_to_idx = json.load(f)
        else:
            self.word_to_idx = _generate_word_to_idx(ner_data)

        train_data, valid_data = train_test_split(
            ner_data, random_state=self.cfg.training.seed, test_size=self.cfg.datamodule.valid_size
        )

        dataset_class = load_obj(self.cfg.datamodule.class_name)

        self.train_dataset = dataset_class(
            ner_data=train_data, cfg=self.cfg, word_to_idx=self.word_to_idx, tag_to_idx=self.tag_to_idx
        )
        self.valid_dataset = dataset_class(
            ner_data=valid_data, cfg=self.cfg, word_to_idx=self.word_to_idx, tag_to_idx=self.tag_to_idx
        )

        self._vectorizer = get_vectorizer(self.cfg, self.word_to_idx)
        self.collate = Collator(percentile=100, pad_value=self.tag_to_idx['PAD'])
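_generate_word_to_idx is referenced above but not shown. A hedged sketch of such a helper, assuming each sample stores its tokens under a 'text' key (the actual key is not visible in the snippet):

from typing import Dict, List

def _generate_word_to_idx(ner_data: List[Dict]) -> Dict[str, int]:
    """Hypothetical helper: build a vocabulary with PAD/UNK reserved up front."""
    word_to_idx = {'PAD': 0, 'UNK': 1}
    for sample in ner_data:
        for token in sample['text']:
            if token not in word_to_idx:
                word_to_idx[token] = len(word_to_idx)
    return word_to_idx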
Example #18
def get_test_dataset(cfg: DictConfig) -> object:
    """
    Get test dataset

    Args:
        cfg: config

    Returns:
        test dataset
    """

    test_df = pd.read_csv(cfg.data.test_path)

    # valid_augs_list = [load_obj(i['class_name'])(**i['params']) for i in cfg['augmentation']['valid']['augs']]
    # valid_augs = A.Compose(valid_augs_list)
    valid_augs = load_augs(cfg['augmentation']['valid']['augs'])
    dataset_class = load_obj(cfg.dataset.class_name)

    test_dataset = dataset_class(df=test_df,
                                 mode='test',
                                 img_path=cfg.data.test_image_path,
                                 transforms=valid_augs)

    return test_dataset
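A hedged usage sketch: the returned test dataset is typically wrapped in a plain DataLoader for inference (batch size and worker count are placeholders):

test_dataset = get_test_dataset(cfg)
test_loader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=32,     # placeholder
    num_workers=4,     # placeholder
    shuffle=False,
    pin_memory=True,
)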
Example #19
def run(cfg: DictConfig) -> None:
    """
    Run pytorch-lightning model

    Args:
        cfg: hydra config

    """
    set_seed(cfg.training.seed)
    run_name = os.path.basename(os.getcwd())

    cfg.callbacks.model_checkpoint.params.dirpath = Path(
        os.getcwd(), cfg.callbacks.model_checkpoint.params.dirpath).as_posix()
    callbacks = []
    for callback in cfg.callbacks.other_callbacks:
        if callback.params:
            callback_instance = load_obj(
                callback.class_name)(**callback.params)
        else:
            callback_instance = load_obj(callback.class_name)()
        callbacks.append(callback_instance)

    loggers = []
    if cfg.logging.log:
        for logger in cfg.logging.loggers:
            if 'experiment_name' in logger.params.keys():
                logger.params['experiment_name'] = run_name
            loggers.append(load_obj(logger.class_name)(**logger.params))

    callbacks.append(EarlyStopping(**cfg.callbacks.early_stopping.params))
    callbacks.append(ModelCheckpoint(**cfg.callbacks.model_checkpoint.params))

    trainer = pl.Trainer(
        logger=loggers,
        callbacks=callbacks,
        **cfg.trainer,
    )

    dm = load_obj(cfg.datamodule.data_module_name)(cfg=cfg)
    dm.setup()
    model = load_obj(cfg.training.lightning_module_name)(
        cfg=cfg, tag_to_idx=dm.tag_to_idx)
    model._vectorizer = dm._vectorizer
    trainer.fit(model, dm)

    if cfg.general.save_pytorch_model:
        if cfg.general.save_best:
            best_path = trainer.checkpoint_callback.best_model_path  # type: ignore
            # extract file name without folder
            save_name = os.path.basename(os.path.normpath(best_path))
            model = model.load_from_checkpoint(best_path,
                                               cfg=cfg,
                                               tag_to_idx=dm.tag_to_idx,
                                               strict=False)
            model_name = Path(cfg.callbacks.model_checkpoint.params.dirpath,
                              f'best_{save_name}'.replace('.ckpt',
                                                          '.pth')).as_posix()
            torch.save(model.model.state_dict(), model_name)
        else:
            os.makedirs('saved_models', exist_ok=True)
            model_name = 'saved_models/last.pth'
            torch.save(model.model.state_dict(), model_name)
def run(cfg: DictConfig) -> None:
    """
    Run pytorch-lightning model

    Args:
        cfg: hydra config

    """
    set_seed(cfg.training.seed)
    hparams = flatten_omegaconf(cfg)

    cfg.callbacks.model_checkpoint.params.filepath = os.getcwd(
    ) + cfg.callbacks.model_checkpoint.params.filepath
    callbacks = []
    for callback in cfg.callbacks.other_callbacks:
        if callback.params:
            callback_instance = load_obj(
                callback.class_name)(**callback.params)
        else:
            callback_instance = load_obj(callback.class_name)()
        callbacks.append(callback_instance)

    loggers = []
    if cfg.logging.log:
        for logger in cfg.logging.loggers:
            loggers.append(load_obj(logger.class_name)(**logger.params))

    callbacks.append(EarlyStopping(**cfg.callbacks.early_stopping.params))

    trainer = pl.Trainer(
        logger=loggers,
        checkpoint_callback=ModelCheckpoint(
            **cfg.callbacks.model_checkpoint.params),
        callbacks=callbacks,
        **cfg.trainer,
    )

    dm = load_obj(cfg.datamodule.data_module_name)(hparams=hparams, cfg=cfg)
    dm.setup()
    model = load_obj(cfg.training.lightning_module_name)(
        hparams=hparams, cfg=cfg, tag_to_idx=dm.tag_to_idx)
    model._vectorizer = dm._vectorizer
    # dm = load_obj(cfg.datamodule.data_module_name)(hparams=hparams, cfg=cfg)
    trainer.fit(model, dm)

    if cfg.general.save_pytorch_model:
        if cfg.general.save_best:
            best_path = trainer.checkpoint_callback.best_model_path  # type: ignore
            # extract file name without folder and extension
            save_name = best_path.split('/')[-1][:-5]
            model = model.load_from_checkpoint(best_path,
                                               hparams=hparams,
                                               cfg=cfg,
                                               tag_to_idx=dm.tag_to_idx,
                                               strict=False)
            model_name = f'saved_models/{save_name}.pth'
            torch.save(model.model.state_dict(), model_name)
        else:
            os.makedirs('saved_models', exist_ok=True)
            model_name = 'saved_models/last.pth'
            torch.save(model.model.state_dict(), model_name)
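flatten_omegaconf, used to build hparams in several run functions above, is also not shown. A minimal sketch of such a helper, assuming it flattens the nested config into a dot-keyed dict of numeric leaves suitable for hyperparameter logging:

from typing import Any, Dict
from omegaconf import DictConfig, OmegaConf

def flatten_omegaconf(cfg: DictConfig, sep: str = '.') -> Dict[str, Any]:
    """Hypothetical helper: flatten a nested config into {'a.b.c': value} pairs."""
    container = OmegaConf.to_container(cfg, resolve=True)
    flat: Dict[str, Any] = {}

    def _walk(node: Any, prefix: str = '') -> None:
        if isinstance(node, dict):
            for key, value in node.items():
                _walk(value, f'{prefix}{key}{sep}')
        elif isinstance(node, (int, float)):
            # keep only numeric leaves so they can be logged as hparams
            flat[prefix[: -len(sep)]] = node

    _walk(container)
    return flat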