Example 1
def run_evaluation(log_folder, config, ckpt, calgroup=None):
    config = get_config(config)
    if calgroup:
        if hasattr(config.dataset_config, "calgroup"):
            print(
                "Warning: overriding calgroup {0} with user supplied calgroup {1}"
                .format(config.dataset_config.calgroup, calgroup))
        config.dataset_config.calgroup = calgroup
    vnum = get_tb_logdir_version(str(ckpt))
    logger = TensorBoardLogger(dirname(dirname(log_folder)),
                               name=basename(dirname(log_folder)),
                               version=vnum)
    print("Creating new log file in directory {}".format(logger.log_dir))
    modules = ModuleUtility(config.run_config.imports)
    runner = modules.retrieve_class(
        config.run_config.run_class).load_from_checkpoint(ckpt, config=config)
    trainer_args = {"logger": logger}
    trainer_args["callbacks"] = [LoggingCallback()]
    set_default_trainer_args(trainer_args, config)
    model = LitPSD.load_from_checkpoint(ckpt, config=config)
    #model.set_logger(logger)
    data_module = PSDDataModule(config, runner.device)

    trainer = Trainer(**trainer_args)
    trainer.test(model, datamodule=data_module)
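
A minimal usage sketch, assuming this module and its project imports are available; the log folder, config path, checkpoint path, and calgroup name below are placeholders, not values from the source:

run_evaluation("runs/exp/version_0",
               "runs/exp/config.json",
               "runs/exp/checkpoints/last.ckpt",
               calgroup="ExampleCalGroup")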
Example 2
    def __init__(self, config):
        super().__init__()
        self.log = logging.getLogger(__name__)
        if config.net_config.net_type != "2DConvolution":
            raise IOError("config.net_config.net_type must be 2DConvolution")
        self.system_config = config.system_config
        self.net_config = config.net_config
        self.nfeatures = self.system_config.n_features
        self.modules = ModuleUtility(self.net_config.imports)
        self.spatial_size = array([14, 11])
        size = [14, 11, self.system_config.n_features]
        self.model = ExtractedFeatureConv(
            self.nfeatures, self.net_config.hparams.out_planes,
            self.net_config.hparams.n_conv, size,
            **DictionaryUtility.to_dict(self.net_config.hparams.conv))
        hparams = self.net_config.hparams
        flat_size = 1
        for s in self.model.out_size:
            flat_size *= s
        self.n_linear = flat_size
        self.log.debug(
            "Flattened size of the SCN network output is {}".format(flat_size))
        self.linear = LinearBlock(flat_size, self.system_config.n_type,
                                  hparams.n_lin).func
        self.permute_tensor = LongTensor(
            [2, 0, 1])  # needed because spconv requires batch index first
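
The flattened-size loop above is just a product over the output dimensions; a standalone sketch of the equivalence (the shape is a made-up example, not from the source):

from math import prod

out_size = [4, 3, 16]  # hypothetical model.out_size
flat_size = 1
for s in out_size:
    flat_size *= s
assert flat_size == prod(out_size) == 192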
Example 3
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("config", help="path to config file")
    parser.add_argument("checkpoint", help="path to checkpoint file")
    parser.add_argument("--calgroup",
                        "-c",
                        help="calibration group entry in PROSPECT_CALDB",
                        type=str)
    parser.add_argument("--verbosity",
                        "-v",
                        help="Set the verbosity for this run.",
                        type=int,
                        default=0)
    args = parser.parse_args()
    main_logger = setup_logger(args)
    config = get_config(args.config)
    if args.calgroup:
        if hasattr(config.dataset_config, "calgroup"):
            print(
                "Warning: overriding calgroup {0} with user supplied calgroup {1}"
                .format(config.dataset_config.calgroup, args.calgroup))
        config.dataset_config.calgroup = args.calgroup
    log_folder = dirname(args.config)
    p = Path(log_folder)
    cp = list(p.glob('*.tfevents.*'))  # materialize: a generator is always truthy
    logger = None
    if cp:
        for event_file in cp:
            print("Using existing log file {}".format(event_file))
            vnum = get_tb_logdir_version(str(event_file))
            logger = TensorBoardLogger(dirname(dirname(log_folder)),
                                       name=basename(dirname(log_folder)),
                                       version=vnum)
            break
    else:
        logger = TensorBoardLogger(log_folder, name=config.run_config.exp_name)
        print("Creating new log file in directory {}".format(logger.log_dir))
    modules = ModuleUtility(config.run_config.imports)
    runner = modules.retrieve_class(
        config.run_config.run_class).load_from_checkpoint(args.checkpoint,
                                                          config=config)
    trainer_args = {"logger": logger, "callbacks": [LoggingCallback()]}
    set_default_trainer_args(trainer_args, config)
    #model.set_logger(logger)
    data_module = PSDDataModule(config, runner.device)
    trainer = Trainer(**trainer_args)
    trainer.test(runner, datamodule=data_module)
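
Note the list() around p.glob in main above: Path.glob returns a generator, which is always truthy, so testing the raw generator would never reach the else branch. A standalone demonstration (the directory name is a placeholder):

from pathlib import Path

p = Path("empty_or_missing_dir")
print(bool(p.glob("*.tfevents.*")))        # True even with no matches
print(bool(list(p.glob("*.tfevents.*"))))  # False when there are no matches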
Example 4
    def __init__(self, config, device):
        super().__init__()
        self.log = logging.getLogger(__name__)
        self.config = config
        self.device = device
        if hasattr(self.config.system_config, "half_precision"):
            self.half_precision = self.config.system_config.half_precision
            self.log.debug("Half precision set to {}".format(
                self.half_precision))
            if not hasattr(self.config.dataset_config.dataset_params,
                           "use_half"):
                setattr(self.config.dataset_config.dataset_params, "use_half",
                        self.half_precision)
        else:
            self.half_precision = False
        self.ntype = len(self.config.dataset_config.paths)
        self.total_train = self.config.dataset_config.n_train * self.ntype
        self.modules = ModuleUtility(self.config.dataset_config.imports)
        self.dataset_class = self.modules.retrieve_class(
            self.config.dataset_config.dataset_class)
        self.dataset_shuffle_map = {}
Example 5
    def __init__(self, config):
        super().__init__()
        self.log = logging.getLogger(__name__)
        if config.net_config.net_type != "2DConvolution":
            raise IOError("config.net_config.net_type must be 2DConvolution")
        self.system_config = config.system_config
        self.net_config = config.net_config
        self.nsamples = self.system_config.n_samples
        self.modules = ModuleUtility(self.net_config.imports)
        if hasattr(self.net_config, "z_weights"):
            self.use_z_model = True
            if not hasattr(self.net_config, "z_config"):
                raise ValueError(
                    "if specifying z_weights, you must also specify a corresponding z_config"
                )
            z_config = get_config(self.net_config.z_config)
            #setattr(z_config.net_config.hparams.conv, "todense", False)
            self.log.info("Using Z model from {}".format(
                self.net_config.z_weights))
            self.z_model = LitZ.load_from_checkpoint(self.net_config.z_weights,
                                                     config=z_config)
            self.z_model.freeze()
        else:
            self.use_z_model = False
        if not hasattr(self.net_config, "algorithm"):
            setattr(self.net_config, "algorithm", "conv")
        if self.use_z_model:
            if self.net_config.algorithm == "conv":
                self.model = SparseConv2DForEZ(self.nsamples * 2,
                                               out_planes=1,
                                               **DictionaryUtility.to_dict(
                                                   self.net_config.hparams))
            elif self.net_config.algorithm == "features":
                self.model = SparseConv2DForEZ(self.nsamples,
                                               out_planes=1,
                                               **DictionaryUtility.to_dict(
                                                   self.net_config.hparams))
        else:
            if self.net_config.algorithm == "conv":
                self.model = SparseConv2DForEZ(
                    self.nsamples * 2,
                    **DictionaryUtility.to_dict(self.net_config.hparams))
            elif self.net_config.algorithm == "features":
                self.model = SparseConv2DForEZ(
                    self.nsamples,
                    **DictionaryUtility.to_dict(self.net_config.hparams))
        self.spatial_size = array([14, 11])
        self.permute_tensor = LongTensor(
            [2, 0, 1])  # needed because spconv requires batch index first
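
A sketch of the input-plane convention visible in the dispatch above: the "conv" algorithm feeds twice n_samples input channels, while "features" feeds n_samples. The helper name is illustrative, not from the source:

def input_planes(n_samples, algorithm):
    # mirrors the branches above
    return n_samples * 2 if algorithm == "conv" else n_samples

assert input_planes(150, "conv") == 300
assert input_planes(150, "features") == 150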
Example 6
    def __init__(self, optuna_config, config, model_dir, trainer_args):
        self.optuna_config = optuna_config
        self.model_dir = model_dir
        self.config = config
        self.hyperparameters = {}
        self.log = logging.getLogger(__name__)
        base_dir = os.path.join(model_dir, "studies")
        if not os.path.exists(base_dir):
            os.makedirs(base_dir, exist_ok=True)
        self.study_dir = os.path.join(
            model_dir, "studies/{}".format(config.run_config.exp_name))
        self.study_name = self.config.run_config.exp_name if not hasattr(
            optuna_config, "name") else self.optuna_config.name
        self.trainer_args = trainer_args
        if not os.path.exists(self.study_dir):
            os.makedirs(self.study_dir, exist_ok=True)
        self.connstr = "sqlite:///" + os.path.join(self.study_dir, "study.db")
        write_run_info(self.study_dir)
        self.hyperparameters_bounds = DictionaryUtility.to_dict(
            self.optuna_config.hyperparameters)
        self.log.debug("hyperparameters bounds set to {0}".format(
            self.hyperparameters_bounds))
        self.modules = ModuleUtility(["optuna.pruners", "optuna.samplers"])
        self.parse_config()
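
The storage string assembled above points Optuna at a SQLite file inside the per-experiment study directory; a standalone sketch with placeholder paths:

import os

model_dir = "model_dir"     # placeholder
exp_name = "my_experiment"  # placeholder
study_dir = os.path.join(model_dir, "studies", exp_name)
connstr = "sqlite:///" + os.path.join(study_dir, "study.db")
print(connstr)  # sqlite:///model_dir/studies/my_experiment/study.db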
Example 7
    def __init__(self, config):
        super().__init__()
        self.log = logging.getLogger(__name__)
        if config.net_config.net_type != "2DConvolution":
            raise IOError("config.net_config.net_type must be 2DConvolution")
        self.system_config = config.system_config
        self.net_config = config.net_config
        self.nsamples = self.system_config.n_samples
        self.modules = ModuleUtility(self.net_config.imports)
        if not hasattr(self.net_config, "algorithm"):
            setattr(self.net_config, "algorithm", "conv")
        if hasattr(self.net_config, "version"):
            self.version = self.net_config.version
        else:
            self.version = 0
        if self.net_config.algorithm == "conv":
            if self.version == 0:
                self.model = SparseConv2DForZ(
                    self.nsamples * 2,
                    **DictionaryUtility.to_dict(self.net_config.hparams.conv))
            else:
                self.model = SparseConv2DForEZ(self.nsamples * 2,
                                               out_planes=1,
                                               **DictionaryUtility.to_dict(
                                                   self.net_config.hparams))
        elif self.net_config.algorithm == "point":
            self.model = Pointwise2DForZ(
                self.nsamples * 2,
                **DictionaryUtility.to_dict(self.net_config.hparams.point))
        elif self.net_config.algorithm == "features":
            if self.version == 0:
                self.model = SparseConv2DForZ(
                    self.nsamples,
                    **DictionaryUtility.to_dict(self.net_config.hparams.conv))
            else:
                self.model = SparseConv2DForEZ(self.nsamples,
                                               out_planes=1,
                                               **DictionaryUtility.to_dict(
                                                   self.net_config.hparams))
        self.spatial_size = array([14, 11])
        self.permute_tensor = LongTensor(
            [2, 0, 1])  # needed because spconv requires batch index first
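
The hasattr/setattr defaulting used above (and throughout these examples) can be exercised on its own with a namespace object standing in for the parsed config:

from types import SimpleNamespace

net_config = SimpleNamespace()  # no "algorithm" attribute set
if not hasattr(net_config, "algorithm"):
    setattr(net_config, "algorithm", "conv")
assert net_config.algorithm == "conv"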
Example 8
class PSDDataModule(pl.LightningDataModule):
    def __init__(self, config, device):
        super().__init__()
        self.log = logging.getLogger(__name__)
        self.config = config
        self.device = device
        if hasattr(self.config.system_config, "half_precision"):
            self.half_precision = self.config.system_config.half_precision
            self.log.debug("Half precision set to {}".format(
                self.half_precision))
            if not hasattr(self.config.dataset_config.dataset_params,
                           "use_half"):
                setattr(self.config.dataset_config.dataset_params, "use_half",
                        self.half_precision)
        else:
            self.half_precision = False
        self.ntype = len(self.config.dataset_config.paths)
        self.total_train = self.config.dataset_config.n_train * self.ntype
        self.modules = ModuleUtility(self.config.dataset_config.imports)
        self.dataset_class = self.modules.retrieve_class(
            self.config.dataset_config.dataset_class)
        self.dataset_shuffle_map = {}

    def prepare_data(self):
        # called only on 1 GPU
        pass

    def setup(self, stage=None):
        # called on every GPU
        if stage == 'fit' or stage is None:
            if not hasattr(self, "train_dataset"):
                if hasattr(self.config.dataset_config, "train_config"):
                    self.train_dataset = self.dataset_class.retrieve_config(
                        self.config.dataset_config.train_config, self.device,
                        self.half_precision)
                    self.log.info("Using train dataset from {}.".format(
                        self.config.dataset_config.train_config))
                else:
                    self.train_dataset = self.dataset_class(
                        self.config, "train",
                        self.config.dataset_config.n_train, self.device,
                        **DictionaryUtility.to_dict(
                            self.config.dataset_config.dataset_params))
                    self.log.info("Training dataset generated.")
                self.train_excludes = self.train_dataset.get_file_list()
            worker_info = get_worker_info()
            if hasattr(self.config.dataset_config, "data_prep"):
                if self.config.dataset_config.data_prep == "shuffle":
                    if hasattr(self.config.dataset_config, "train_config"):
                        self.log.warning(
                            "You specified a training dataset and shuffling data prep. Data shuffling is "
                            "only supported when specifying a dataset via a directory list. Skipping "
                            "shuffle.")
                    else:
                        if worker_info is None:
                            self.log.info(
                                "Main process beginning to shuffle dataset.")
                        else:
                            self.log.info(
                                "Worker process {} beginning to shuffle dataset."
                                .format(worker_info.id))
                        self.train_dataset.write_shuffled(
                        )  # might need to make this call configurable
        if stage == 'test' or stage is None:
            if not hasattr(self, "val_dataset"):
                if hasattr(self.config.dataset_config, "val_config"):
                    self.val_dataset = self.dataset_class.retrieve_config(
                        self.config.dataset_config.val_config, self.device,
                        self.half_precision)
                    self.log.info("Using validation dataset from {}.".format(
                        self.config.dataset_config.val_config))
                else:
                    if hasattr(self.config.dataset_config, "n_validate"):
                        n_validate = self.config.dataset_config.n_validate
                    else:
                        n_validate = self.config.dataset_config.n_test
                    if hasattr(self, "train_excludes"):
                        par = {"file_excludes": self.train_excludes}
                    else:
                        par = {}
                    self.val_dataset = self.dataset_class(
                        self.config, "validate", n_validate, self.device,
                        **par,
                        **DictionaryUtility.to_dict(
                            self.config.dataset_config.dataset_params))
                    self.log.info("Validation dataset generated.")

            if not hasattr(self, "test_dataset"):
                if hasattr(self.config.dataset_config, "val_config"):
                    self.test_dataset = self.dataset_class.retrieve_config(
                        self.config.dataset_config.test_config, self.device,
                        self.half_precision)
                    self.log.info("Using test dataset from {}.".format(
                        self.config.dataset_config.test_config))
                else:
                    if hasattr(self, "train_excludes"):
                        par = {
                            "file_excludes":
                            self.train_excludes +
                            self.val_dataset.get_file_list()
                        }
                    else:
                        par = {
                            "file_excludes": self.val_dataset.get_file_list()
                        }

                    if hasattr(self.config.dataset_config,
                               "test_dataset_params"):
                        self.test_dataset = self.dataset_class(
                            self.config, "test",
                            self.config.dataset_config.n_test, self.device,
                            **par,
                            **DictionaryUtility.to_dict(
                                self.config.dataset_config.test_dataset_params
                            ))
                    else:
                        self.test_dataset = self.dataset_class(
                            self.config, "test",
                            self.config.dataset_config.n_test, self.device,
                            **par,
                            **DictionaryUtility.to_dict(
                                self.config.dataset_config.dataset_params))
                    self.log.info("Test dataset generated.")

    def train_dataloader(self):
        if not hasattr(self, "train_dataset"):
            self.setup("train")
        return DataLoader(self.train_dataset,
                          shuffle=True,
                          collate_fn=collate_fn,
                          **DictionaryUtility.to_dict(
                              self.config.dataset_config.dataloader_params))

    def val_dataloader(self):
        if not hasattr(self, "val_dataset"):
            self.setup("test")
        return DataLoader(self.val_dataset,
                          shuffle=False,
                          collate_fn=collate_fn,
                          **DictionaryUtility.to_dict(
                              self.config.dataset_config.dataloader_params))

    def test_dataloader(self):
        if not hasattr(self, "test_dataset"):
            self.setup("test")
        return DataLoader(self.test_dataset,
                          shuffle=False,
                          collate_fn=collate_fn,
                          **DictionaryUtility.to_dict(
                              self.config.dataset_config.dataloader_params))
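
A sketch of the file-exclusion chaining in setup above: the validation set excludes the training files, and the test set excludes both. Plain lists stand in for get_file_list() results:

train_excludes = ["run1.h5", "run2.h5"]  # placeholder file lists
val_files = ["run3.h5"]
test_excludes = train_excludes + val_files
assert test_excludes == ["run1.h5", "run2.h5", "run3.h5"]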
Example 9
class ModelOptimization:
    """
    hyperparameter optimization class
    """
    def __init__(self, optuna_config, config, model_dir, trainer_args):
        self.optuna_config = optuna_config
        self.model_dir = model_dir
        self.config = config
        self.hyperparameters = {}
        self.log = logging.getLogger(__name__)
        base_dir = os.path.join(model_dir, "studies")
        if not os.path.exists(base_dir):
            os.makedirs(base_dir, exist_ok=True)
        self.study_dir = os.path.join(
            model_dir, "studies/{}".format(config.run_config.exp_name))
        self.study_name = self.config.run_config.exp_name if not hasattr(
            optuna_config, "name") else self.optuna_config.name
        self.trainer_args = trainer_args
        if not os.path.exists(self.study_dir):
            os.makedirs(self.study_dir, exist_ok=True)
        self.connstr = "sqlite:///" + os.path.join(self.study_dir, "study.db")
        write_run_info(self.study_dir)
        self.hyperparameters_bounds = DictionaryUtility.to_dict(
            self.optuna_config.hyperparameters)
        self.log.debug("hyperparameters bounds set to {0}".format(
            self.hyperparameters_bounds))
        self.modules = ModuleUtility(["optuna.pruners", "optuna.samplers"])
        self.parse_config()

    def parse_config(self):
        if not hasattr(self.optuna_config, "hyperparameters"):
            raise IOError(
                "No hyperparameters found in optuna config. You must set hyperparameters to a dictionary of key: "
                "value pairs where key is the path to the hyperparameter in the config file, and value is an array of two "
                "elements bounding the range of the parameter")
        for h in self.hyperparameters_bounds.keys():
            # walk the "/"-separated path down to the parent of the final key
            path_list = [p for p in h.split("/") if p]
            plen = len(path_list)
            myobj = None
            for j, name in enumerate(path_list):
                if j == plen - 1:
                    break
                if j > 0:
                    myobj = get_from_path(myobj, name)
                else:
                    myobj = get_from_path(self.config, name)
            if myobj:
                self.hyperparameters[h] = myobj

    def modify_config(self, trial):
        for hp in self.hyperparameters.keys():
            name = hp.split("/")[-1]
            bounds = self.hyperparameters_bounds[hp]
            if isinstance(bounds, dict):
                if "val" in bounds.keys():
                    setattr(self.hyperparameters[hp], name,
                            trial.suggest_categorical(name, bounds["val"]))
                else:
                    raise ValueError(
                        "Invalid format for hyperparameter key {0}. Specify category with \"val\":[list "
                        "of values]".format(hp))
            elif len(bounds) > 2:
                setattr(self.hyperparameters[hp], name,
                        trial.suggest_categorical(name, bounds))
            elif isinstance(bounds[0], bool):
                # must precede the int check: bool is a subclass of int
                setattr(self.hyperparameters[hp], name,
                        trial.suggest_int(name, 0, 1))
            elif isinstance(bounds[0], int):
                setattr(self.hyperparameters[hp], name,
                        trial.suggest_int(name, bounds[0], bounds[1]))
            elif isinstance(bounds[0], float):
                t = None
                if bounds[0] != 0 and bounds[1] != 0:
                    if bounds[1] / bounds[0] > 100 or bounds[0] / bounds[
                            1] > 100:
                        t = trial.suggest_loguniform(name, bounds[0],
                                                     bounds[1])
                if t is None:
                    t = trial.suggest_float(name, bounds[0], bounds[1])
                setattr(self.hyperparameters[hp], name, t)
            self.log.info("setting {0} to {1}".format(
                hp, getattr(self.hyperparameters[hp], name)))

    def objective(self, trial):
        self.modify_config(trial)
        if not os.path.exists(self.study_dir):
            os.mkdir(self.study_dir)
        if not os.path.exists(
                os.path.join(self.study_dir, "trial_{}".format(trial.number))):
            os.mkdir(
                os.path.join(self.study_dir, "trial_{}".format(trial.number)))
        logger = TensorBoardLogger(self.study_dir,
                                   name="trial_{}".format(trial.number),
                                   default_hp_metric=False)
        log_folder = logger.log_dir
        if not os.path.exists(log_folder):
            os.makedirs(log_folder, exist_ok=True)
        trainer_args = dict(self.trainer_args)  # copy so per-trial settings don't leak across trials
        checkpoint_callback = \
            ModelCheckpoint(
                dirpath=log_folder, filename='{epoch}-{val_loss:.2f}',
                monitor="val_loss")
        trainer_args["logger"] = logger
        trainer_args["default_root_dir"] = self.study_dir
        set_default_trainer_args(trainer_args, self.config)
        if trainer_args["profiler"]:
            profiler = SimpleProfiler(output_filename=os.path.join(
                log_folder, "profile_results.txt"))
            trainer_args["profiler"] = profiler
        save_config(self.config, log_folder, "trial_{}".format(trial.number),
                    "config")
        # save_config(DictionaryUtility.to_object(trainer_args), log_folder,
        #        "trial_{}".format(trial.number), "train_args")
        cbs = [LoggingCallback(), PruningCallback(), checkpoint_callback]
        # trainer_args["early_stop_callback"] = PyTorchLightningPruningCallback(trial, monitor="val_early_stop_on")
        if self.config.run_config.run_class == "LitZ":
            cbs.append(
                EarlyStopping(monitor='val_loss',
                              min_delta=.00,
                              verbose=True,
                              mode="min",
                              patience=5))
        else:
            cbs.append(
                EarlyStopping(monitor='val_loss',
                              min_delta=.00,
                              verbose=True,
                              mode="min",
                              patience=4))

        trainer = pl.Trainer(**trainer_args, callbacks=cbs)
        modules = ModuleUtility(self.config.run_config.imports)
        model = modules.retrieve_class(self.config.run_config.run_class)(
            self.config, trial)
        data_module = PSDDataModule(self.config, model.device)
        try:
            trainer.fit(model, datamodule=data_module)
            loss = trainer.checkpoint_callback.best_model_score
            self.log.info("best loss found for trial {0} is {1}".format(
                trial.number, loss))
        except RuntimeError as e:
            print(
                "Caught error during trial {0}, moving to next trial. Error message below."
                .format(trial.number))
            print(e)
            self.log.info("Trial {0} failed with error {1}".format(
                trial.number, e))
            gc.collect()
            loss = None
        return loss

    def run_study(self, pruning=False):
        pruner = optuna.pruners.MedianPruner(
            n_warmup_steps=10,
            interval_steps=3) if pruning else optuna.pruners.NopPruner()
        if hasattr(self.optuna_config, "pruner"):
            if hasattr(self.optuna_config, "pruner_params"):
                pruner = self.modules.retrieve_class(
                    "pruners." +
                    self.optuna_config.pruner)(**DictionaryUtility.to_dict(
                        self.optuna_config.pruner_params))
            else:
                pruner = self.modules.retrieve_class(
                    "pruners." + self.optuna_config.pruner)()
        opt_dict = {}
        if hasattr(self.optuna_config, "sampler"):
            if hasattr(self.optuna_config, "sampler_params"):
                opt_dict["sampler"] = self.modules.retrieve_class(
                    "samplers." +
                    self.optuna_config.sampler)(**DictionaryUtility.to_dict(
                        self.optuna_config.sampler_params))
            else:
                opt_dict["sampler"] = self.modules.retrieve_class(
                    "samplers." + self.optuna_config.sampler)()

        study = optuna.create_study(study_name=self.study_name,
                                    direction="minimize",
                                    pruner=pruner,
                                    storage=self.connstr,
                                    load_if_exists=True,
                                    **opt_dict)
        self.log.debug("optimize parameters: \n{}".format(
            DictionaryUtility.to_dict(self.optuna_config.optimize_args)))
        study.optimize(self.objective,
                       **DictionaryUtility.to_dict(
                           self.optuna_config.optimize_args),
                       show_progress_bar=True,
                       gc_after_trial=True)
        output = {}
        trial = study.best_trial
        self.log.info("Number of finished trials: {}".format(len(
            study.trials)))
        output["n_finished_trials"] = len(study.trials)
        self.log.info("Best trial:")
        output["best_trial"] = trial.value
        self.log.info("  Value: {}".format(trial.value))
        output["best_trial_params"] = trial.params
        self.log.info("  Params: ")
        for key, value in trial.params.items():
            self.log.info("    {}: {}".format(key, value))
        save_config(output, self.study_dir, "trial", "results", True)
Example 10
    def objective(self, trial):
        self.modify_config(trial)
        if not os.path.exists(self.study_dir):
            os.mkdir(self.study_dir)
        if not os.path.exists(
                os.path.join(self.study_dir, "trial_{}".format(trial.number))):
            os.mkdir(
                os.path.join(self.study_dir, "trial_{}".format(trial.number)))
        logger = TensorBoardLogger(self.study_dir,
                                   name="trial_{}".format(trial.number),
                                   default_hp_metric=False)
        log_folder = logger.log_dir
        if not os.path.exists(log_folder):
            os.makedirs(log_folder, exist_ok=True)
        trainer_args = dict(self.trainer_args)  # copy so per-trial settings don't leak across trials
        checkpoint_callback = \
            ModelCheckpoint(
                dirpath=log_folder, filename='{epoch}-{val_loss:.2f}',
                monitor="val_loss")
        trainer_args["logger"] = logger
        trainer_args["default_root_dir"] = self.study_dir
        set_default_trainer_args(trainer_args, self.config)
        if trainer_args["profiler"]:
            profiler = SimpleProfiler(output_filename=os.path.join(
                log_folder, "profile_results.txt"))
            trainer_args["profiler"] = profiler
        save_config(self.config, log_folder, "trial_{}".format(trial.number),
                    "config")
        # save_config(DictionaryUtility.to_object(trainer_args), log_folder,
        #        "trial_{}".format(trial.number), "train_args")
        cbs = [LoggingCallback(), PruningCallback(), checkpoint_callback]
        # trainer_args["early_stop_callback"] = PyTorchLightningPruningCallback(trial, monitor="val_early_stop_on")
        if self.config.run_config.run_class == "LitZ":
            cbs.append(
                EarlyStopping(monitor='val_loss',
                              min_delta=.00,
                              verbose=True,
                              mode="min",
                              patience=5))
        else:
            cbs.append(
                EarlyStopping(monitor='val_loss',
                              min_delta=.00,
                              verbose=True,
                              mode="min",
                              patience=4))

        trainer = pl.Trainer(**trainer_args, callbacks=cbs)
        modules = ModuleUtility(self.config.run_config.imports)
        model = modules.retrieve_class(self.config.run_config.run_class)(
            self.config, trial)
        data_module = PSDDataModule(self.config, model.device)
        try:
            trainer.fit(model, datamodule=data_module)
            loss = trainer.checkpoint_callback.best_model_score
            self.log.info("best loss found for trial {0} is {1}".format(
                trial.number, loss))
        except RuntimeError as e:
            print(
                "Caught error during trial {0}, moving to next trial. Error message below."
                .format(trial.number))
            print(e)
            self.log.info("Trial {0} failed with error {1}".format(
                trial.number, e))
            gc.collect()
            loss = None
        return loss
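
For reference, a sketch of how the bounds formats accepted by modify_config map onto Optuna suggest calls; the parameter names are illustrative only:

import optuna

# {"val": [...]} or a list longer than 2 -> suggest_categorical
# two ints                               -> suggest_int
# two bools                              -> suggest_int over {0, 1}
# two floats spanning > 100x             -> suggest_loguniform
# two floats otherwise                   -> suggest_float

def objective(trial):
    lr = trial.suggest_loguniform("lr", 1e-5, 1e-1)
    n_conv = trial.suggest_int("n_conv", 1, 4)
    act = trial.suggest_categorical("act", ["relu", "elu", "tanh"])
    return lr * n_conv

optuna.create_study(direction="minimize").optimize(objective, n_trials=3)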