Example 1
def test_dali(dataset_config, transform):
    augmentations = [
        EasyDict({
            'module': 'nvidia_dali',
            'args': {
                'transforms': [transform]
            }
        })
    ]
    dataset_config.train.augmentations = augmentations

    if dataset_config == lndmrks_dataset_config and transform[
            'transform'] in excpected_raise_error_transform:
        with pytest.raises(RuntimeError):
            dataloader = create_dataloader(
                dataloader_config=dali_loader,
                dataset_config=dataset_config,
                preprocess_config=preprocess_args,
                collate_fn=dataset_config.collate_fn)
    else:
        dataloader = create_dataloader(dataloader_config=dali_loader,
                                       dataset_config=dataset_config,
                                       preprocess_config=preprocess_args,
                                       collate_fn=dataset_config.collate_fn)

        for data in dataloader:
            fetched_data = data
            break
Example 2
def test_dataloader(dataloader_cfg):
    dataloader = create_dataloader(dataloader_config=dataloader_cfg,
                                   dataset_config=classification_config,
                                   preprocess_config=preprocess_args,
                                   collate_fn=None)
    for data in dataloader:
        fetched_data = data
        break
    assert isinstance(fetched_data[0], torch.Tensor)
    assert len(fetched_data[0].shape) == 4  # N,C,H,W
    assert fetched_data[0].shape[2] == preprocess_args.input_size  # Assume square input
    assert isinstance(len(dataloader), int)
    assert hasattr(dataloader, 'dataset')
    assert isinstance(dataloader.dataset, BasicDatasetWrapper)
Example 3
def test_dali_loader_with_no_dali_aug():
    augmentations = [
        EasyDict({
            'module': 'albumentations',
            'args': {
                'transforms': [transforms[0]]
            }
        })
    ]
    class_dataset_config.train.augmentations = augmentations

    dataloader = create_dataloader(dataloader_config=dali_loader,
                                   dataset_config=class_dataset_config,
                                   preprocess_config=preprocess_args,
                                   collate_fn=class_dataset_config.collate_fn)
Example 4
def test_neg_dali_aug_and_pytorch_data_loader():
    augmentations = [
        EasyDict({
            'module': 'nvidia_dali',
            'args': {
                'transforms': [transforms[0]]
            }
        })
    ]
    class_dataset_config.train.augmentations = augmentations

    # Expect RuntimeError if the nvidia_dali augmentation module is used with the PyTorch dataloader
    with pytest.raises(RuntimeError):
        dataloader = create_dataloader(
            dataloader_config=pytorch_loader,
            dataset_config=class_dataset_config,
            preprocess_config=preprocess_args,
            collate_fn=class_dataset_config.collate_fn)
Example 5
def test_neg_dali_aug_not_the_first_aug():
    augmentations = [
        EasyDict({
            'module': 'albumentations',
            'args': {
                'transforms': [transforms[0]]
            }
        }),
        EasyDict({
            'module': 'nvidia_dali',
            'args': {
                'transforms': [transforms[0]]
            }
        })
    ]
    class_dataset_config.train.augmentations = augmentations

    # Expect RuntimeError if the nvidia_dali augmentation is not the first entry in the augmentation list
    with pytest.raises(RuntimeError):
        dataloader = create_dataloader(
            dataloader_config=dali_loader,
            dataset_config=class_dataset_config,
            preprocess_config=preprocess_args,
            collate_fn=class_dataset_config.collate_fn)
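Taken together, Examples 3-5 imply two constraints: an 'nvidia_dali' augmentation entry only works with the DALI dataloader, and it must be the first entry in the augmentation list. The sketch below reuses the `transforms`, `dali_loader`, and `class_dataset_config` fixtures referenced by the tests and shows a setup that should satisfy both constraints; whether non-DALI augmentations may follow the DALI entry is an assumption for illustration.

# Sketch only: nvidia_dali augmentation first, paired with the DALI loader fixture.
augmentations = [
    EasyDict({
        'module': 'nvidia_dali',        # must be the first augmentation entry
        'args': {'transforms': [transforms[0]]}
    }),
    EasyDict({
        'module': 'albumentations',     # non-DALI augmentations may follow (assumption)
        'args': {'transforms': [transforms[0]]}
    })
]
class_dataset_config.train.augmentations = augmentations
dataloader = create_dataloader(dataloader_config=dali_loader,
                               dataset_config=class_dataset_config,
                               preprocess_config=preprocess_args,
                               collate_fn=class_dataset_config.collate_fn)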
Example 6
        try:
            accumulation_step = config.trainer.driver.args.accumulation_step
        except:
            pass

    # Get steps per epoch

    steps_per_epoch = 100

    ## Try to get steps_per_epoch from argparse or experiment config
    if args.steps_per_epoch is not None:
        steps_per_epoch = args.steps_per_epoch
    else:
        try:
            dataloader = create_dataloader(
                dataloader_config=config.dataloader,
                dataset_config=config.dataset,
                preprocess_config=config.model.preprocess_args,
                collate_fn=None)
            steps_per_epoch = len(dataloader)
        except:
            pass

    lrs = calculate_lr(lr_scheduler_cfg,
                       epochs=epochs,
                       optimizer_config=optim_cfg,
                       accumulation_step=accumulation_step,
                       steps_per_epoch=steps_per_epoch)

    plt.plot(lrs)
    plt.grid(True, linestyle='--', alpha=0.2)
    plt.ylabel("learning rate")
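For reference, `calculate_lr` is expected to return one learning-rate value per optimizer step so the curve above can be plotted. A minimal stand-in using plain PyTorch (the scheduler choice and hyperparameters below are assumptions for illustration, not the vortex helper) looks like this:

import torch
import matplotlib.pyplot as plt

# Simulate epochs * steps_per_epoch optimizer steps and record the lr at each step.
model = torch.nn.Linear(1, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10 * 100)

lrs = []
for _ in range(10 * 100):
    lrs.append(optimizer.param_groups[0]['lr'])
    optimizer.step()      # step the optimizer first ...
    scheduler.step()      # ... then advance the schedule

plt.plot(lrs)
plt.ylabel("learning rate")
plt.show()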
Example 7
    def __init__(self,
                 config:EasyDict,
                 config_path: Union[str,Path,None] = None,
                 hypopt: bool = False,
                 resume: bool = False):
        """Class initialization

        Args:
            config (EasyDict): dictionary parsed from Vortex experiment file
            config_path (Union[str,Path,None], optional): path to the experiment file. 
                Must be provided to back up the **experiment file**. 
                Defaults to None.
            hypopt (bool, optional): flag for hyperparameter optimization; disables 
                several pipeline processes. Defaults to False.
            resume (bool, optional): flag to resume training. 
                Defaults to False.

        Raises:
            Exception: re-raised for any otherwise undocumented error

        Example:
            ```python
            from vortex.development.utils.parser import load_config
            from vortex.development.core.pipelines import TrainingPipeline
            
            # Parse config
            config_path = 'experiments/config/example.yml'
            config = load_config(config_path)
            train_executor = TrainingPipeline(config=config,
                                              config_path=config_path,
                                              hypopt=False)
            ```
        """

        self.start_epoch = 0
        checkpoint, state_dict = None, None
        if resume or ('checkpoint' in config and config.checkpoint is not None):
            if 'checkpoint' not in config:
                raise RuntimeError("You specify to resume but 'checkpoint' is not configured "
                    "in the config file. Please specify 'checkpoint' option in the top level "
                    "of your config file pointing to model path used for resume.")
            if resume or os.path.exists(config.checkpoint):
                checkpoint = torch.load(config.checkpoint, map_location=torch.device('cpu'))
                state_dict = checkpoint['state_dict']

            if resume:
                self.start_epoch = checkpoint['epoch']
                model_config = EasyDict(checkpoint['config'])
                if config.model.name != model_config.model.name:
                    raise RuntimeError("Model name configuration specified in config file ({}) is not "
                        "the same as saved in model checkpoint ({}).".format(config.model.name,
                        model_config.model.name))
                if config.model.network_args != model_config.model.network_args:
                    raise RuntimeError("'network_args' configuration specified in config file ({}) is "
                        "not the same as saved in model checkpoint ({}).".format(config.model.network_args, 
                        model_config.model.network_args))

                if 'name' in config.dataset.train:
                    cfg_dataset_name = config.dataset.train.name
                elif 'dataset' in config.dataset.train:
                    cfg_dataset_name = config.dataset.train.dataset
                else:
                    raise RuntimeError("dataset name is not found in config. Please specify in "
                        "'config.dataset.train.name'.")

                model_dataset_name = None
                if 'name' in model_config.dataset.train:
                    model_dataset_name = model_config.dataset.train.name
                elif 'dataset' in model_config.dataset.train:
                    model_dataset_name = model_config.dataset.train.dataset
                if cfg_dataset_name != model_dataset_name:
                    raise RuntimeError("Dataset specified in config file ({}) is not the same as saved "
                        "in model checkpoint ({}).".format(cfg_dataset_name, model_dataset_name))

                if ('n_classes' in config.model.network_args and 
                        (config.model.network_args.n_classes != model_config.model.network_args.n_classes)):
                    raise RuntimeError("Number of classes configuration specified in config file ({}) "
                        "is not the same as saved in model checkpoint ({}).".format(
                        config.model.network_args.n_classes, model_config.model.network_args.n_classes))

        self.config = config
        self.hypopt = hypopt

        # Check experiment config validity
        self._check_experiment_config(config)

        if not self.hypopt:
            # Create experiment logger
            self.experiment_logger = create_experiment_logger(config)

            # Output directory creation
            # If config_path is provided, it will duplicate the experiment file into the run directory
            self.experiment_directory, self.run_directory = check_and_create_output_dir(
                config, self.experiment_logger, config_path)

            # Create local experiments run log file
            self._create_local_runs_log(self.config,
                                        self.experiment_logger,
                                        self.experiment_directory,
                                        self.run_directory)
        else:
            self.experiment_logger = None

        # Training components creation

        if 'device' in config:
            self.device = config.device
        elif 'device' in config.trainer:
            self.device = config.trainer.device
        else:
            raise RuntimeError("'device' field not found in config. Please specify properly in main level.")

        model_components = create_model(model_config=config.model, state_dict=state_dict)
        if not isinstance(model_components, EasyDict):
            model_components = EasyDict(model_components)
        # setdefault does not work for EasyDict
        # model_components.setdefault('collate_fn', None)
        if 'collate_fn' not in model_components:
            model_components.collate_fn = None
        self.model_components = model_components

        self.model_components.network = self.model_components.network.to(self.device)
        self.criterion = self.model_components.loss.to(self.device)

        param_groups = None
        if 'param_groups' in self.model_components:
            param_groups = self.model_components.param_groups

        if 'dataloader' in config:
            dataloader_config = config.dataloader
        elif 'dataloader' in config.dataset:
            dataloader_config = config.dataset.dataloader
        else:
            raise RuntimeError("Dataloader config field not found in config.")

        self.dataloader = create_dataloader(dataloader_config=dataloader_config,
                                            dataset_config=config.dataset,
                                            preprocess_config=config.model.preprocess_args,
                                            collate_fn=self.model_components.collate_fn,
                                            stage='train')
        self.trainer = engine.create_trainer(
            config.trainer, criterion=self.criterion,
            model=self.model_components.network,
            experiment_logger=self.experiment_logger,
            param_groups=param_groups
        )
        if resume:
            self.trainer.optimizer.load_state_dict(checkpoint['optimizer_state'])
            if self.trainer.scheduler is not None:
                scheduler_args = self.config.trainer.lr_scheduler.args
                if isinstance(scheduler_args, dict):
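                    # Override the saved scheduler state with args from the current config,
                    # so scheduler hyperparameters can be adjusted when resuming training.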
                    for name, v in scheduler_args.items():
                        if name in checkpoint["scheduler_state"]:
                            checkpoint["scheduler_state"][name] = v
                self.trainer.scheduler.load_state_dict(checkpoint["scheduler_state"])

        has_save = False
        self.save_best_metrics, self.save_best_type = None, None
        self.best_metrics = None
        if 'save_best_metrics' in self.config.trainer and self.config.trainer.save_best_metrics is not None:
            has_save = self.config.trainer.save_best_metrics is not None
            self.save_best_metrics = self.config.trainer.save_best_metrics
            if not isinstance(self.save_best_metrics, (list, tuple)):
                self.save_best_metrics = [self.save_best_metrics]

            # 'loss' is tracked as a minimized quantity; other entries are validation metrics to maximize
            self.save_best_type = list({'loss' if m == 'loss' else 'val_metric' for m in self.save_best_metrics})
            self.best_metrics = {name: float('inf') if name == 'loss' else float('-inf') for name in self.save_best_metrics}
            if 'loss' in self.save_best_metrics:
                self.save_best_metrics.remove('loss')

            if resume:
                best_metrics_ckpt = checkpoint['best_metrics']
                if isinstance(best_metrics_ckpt, dict):
                    self.best_metrics.update(best_metrics_ckpt)

        self.save_epoch, self.save_last_epoch = None, None
        if 'save_epoch' in self.config.trainer and self.config.trainer.save_epoch is not None:
            self.save_epoch = self.config.trainer.save_epoch
            has_save = has_save or self.config.trainer.save_epoch is not None
        if not has_save:
            warnings.warn("No model checkpoint saving configuration is specified, the training would still "
                "work but will only save the last epoch model.\nYou can configure either one of "
                "'config.trainer.save_epoch' or 'config.trainer.save_best_metric")

        # Validation components creation
        try:
            if 'validator' in config:
                validator_cfg = config.validator
            elif 'validator' in config.trainer:
                validator_cfg = config.trainer.validator
            else:
                raise RuntimeError("'validator' field not found in config. Please specify it properly at the top level.")

            val_dataset = create_dataset(config.dataset, config.model.preprocess_args, stage='validate')
            
            ## use same batch-size as training by default
            validation_args = EasyDict({'batch_size' : self.dataloader.batch_size})
            validation_args.update(validator_cfg.args)
            self.validator = engine.create_validator(
                self.model_components, 
                val_dataset, validation_args, 
                device=self.device
            )
            
            self.val_epoch = validator_cfg.val_epoch
            self.valid_for_validation = True
        except AttributeError as e:
            warnings.warn('validation step not properly configured, will be skipped')
            self.valid_for_validation = False
        except Exception as e:
            raise Exception(str(e))

        # Reproducibility settings check
        if hasattr(config, 'seed'):
            _set_seed(config.seed)

        if not self.hypopt:
            print("\nexperiment directory:", self.run_directory)
        self._has_cls_names = hasattr(self.dataloader.dataset, "class_names")