Esempio n. 1
0
    def prepare_resume(self):
        """Try to resume the experiment from ``self._resume_path``.

        When ``self._resume_path`` is a string it is handled as follows:

        * ending in ``.pth.tar``: used directly as the checkpoint file; the
          experiment dir is two directory levels above it.
        * ending in ``checkpoint`` or ``checkpoint/``: the file returned by
          ``get_last_file`` for it is used; the experiment dir is two levels
          above that file.
        * a directory containing both ``checkpoint`` and ``config`` entries:
          it is the experiment dir itself.
        * anything else: a warning is emitted and ``get_last_file`` is asked
          to find a checkpoint below it.

        Non-string resume paths (including ``None``) are not handled here.

        If an experiment dir was found and ``self._ignore_resume_config`` is
        falsy, its ``config/config.json`` replaces the current config (and
        ``self.n_epochs`` is taken from it if unset). If a checkpoint file was
        found it is loaded with ``self._resume_save_types`` and, when an
        experiment logger exists, copied into the logger's checkpoint dir as
        ``0_checkpoint.pth.tar``.
        """
        checkpoint_file = ""
        base_dir = ""

        # Remember the caller's setting; it is re-applied after
        # load_checkpoint(), presumably because loading may overwrite it.
        reset_epochs = self._resume_reset_epochs

        resume_path = self._resume_path
        if isinstance(resume_path, str):
            if resume_path.endswith(".pth.tar"):
                checkpoint_file = resume_path
                base_dir = os.path.dirname(os.path.dirname(checkpoint_file))
            elif resume_path.endswith(("checkpoint", "checkpoint/")):
                checkpoint_file = get_last_file(resume_path)
                base_dir = os.path.dirname(os.path.dirname(checkpoint_file))
            else:
                # List the directory once instead of once per membership test.
                entries = os.listdir(resume_path)
                if "checkpoint" in entries and "config" in entries:
                    checkpoint_file = get_last_file(resume_path)
                    base_dir = resume_path
                else:
                    warnings.warn(
                        "You have not selected a valid experiment folder, will search all sub folders",
                        UserWarning)
                    if self.elog is not None:
                        self.elog.text_logger.log_to(
                            "You have not selected a valid experiment folder, will search all "
                            "sub folders", "warnings")
                    checkpoint_file = get_last_file(resume_path)
                    base_dir = os.path.dirname(os.path.dirname(checkpoint_file))

        if base_dir and not self._ignore_resume_config:
            load_config = Config()
            load_config.load(os.path.join(base_dir, "config/config.json"))
            self._config_raw = load_config
            self.config = Config.init_objects(self._config_raw)
            self.print("Loaded existing config from:", base_dir)
            if self.n_epochs is None:
                self.n_epochs = self._config_raw.get("n_epochs")

        if checkpoint_file:
            self.load_checkpoint(name="",
                                 path=checkpoint_file,
                                 save_types=self._resume_save_types)
            self._resume_path = checkpoint_file
            # elog may be None (the warning branch above checks for it); the
            # copy into the new experiment's checkpoint dir needs a logger.
            if self.elog is not None:
                shutil.copyfile(
                    checkpoint_file,
                    os.path.join(self.elog.checkpoint_dir,
                                 "0_checkpoint.pth.tar"))
            self.print("Loaded existing checkpoint from:", checkpoint_file)

            self._resume_reset_epochs = reset_epochs
            if self._resume_reset_epochs:
                self._epoch_idx = 0
Esempio n. 2
0
    def __init__(self,
                 base_dir,
                 exp_dir="",
                 name=None,
                 decode_config_clean_str=True):
        """Open a stored experiment for reading.

        Args:
            base_dir: directory that contains the experiment directory.
            exp_dir: experiment directory, relative to ``base_dir``.
            name: explicit display name. If ``None`` the name is taken from the
                meta info, then from ``exp.json``'s ``name``, then from the
                config's ``exp_name``, and finally defaults to
                ``"experiments"``.
            decode_config_clean_str: if true, ``config/config.json`` is loaded
                with ``StringMultiTypeDecoder``; otherwise no custom decoder
                is passed.
        """
        super(ExperimentReader, self).__init__()

        self.base_dir = base_dir
        self.exp_dir = exp_dir
        self.work_dir = os.path.abspath(
            os.path.join(self.base_dir, self.exp_dir))

        # Standard sub directories of an experiment folder.
        for attr_name, sub_dir in (("config_dir", "config"),
                                   ("log_dir", "log"),
                                   ("checkpoint_dir", "checkpoint"),
                                   ("img_dir", "img"),
                                   ("plot_dir", "plot"),
                                   ("save_dir", "save"),
                                   ("result_dir", "result")):
            setattr(self, attr_name, os.path.join(self.work_dir, sub_dir))

        self.config = Config()
        decoder = StringMultiTypeDecoder if decode_config_clean_str else None
        self.config.load(os.path.join(self.config_dir, "config.json"),
                         decoder_cls_=decoder)

        # exp.json is optional; without it exp_info stays an empty Config.
        self.exp_info = Config()
        exp_info_path = os.path.join(self.config_dir, "exp.json")
        if os.path.exists(exp_info_path):
            self.exp_info.load(exp_info_path)

        self.__results_dict = None

        self.meta_name = None
        self.meta_star = False
        self.meta_ignore = False
        self.read_meta_info()

        if name is None:
            if self.meta_name is not None:
                name = self.meta_name
            elif "name" in self.exp_info:
                name = self.exp_info['name']
            elif "exp_name" in self.config:
                name = self.config['exp_name']
            else:
                name = "experiments"
        self.exp_name = name

        self.ignore = self.meta_ignore
        self.star = self.meta_star
Esempio n. 3
0
def get_config():
    """Build and print the Basic U-Net configuration for the spleen dataset.

    Because ``update_from_argv`` is enabled, every entry can be overridden by
    a command-line parameter.

    Returns:
        Config: the assembled configuration.
    """
    # Adapt these paths to your setup, if needed.
    data_root_dir = os.path.abspath('data')  # where the downloaded dataset is stored
    output_dir = os.path.abspath('output_experiment')  # where experiment output is logged

    settings = {
        # Allow each entry to be overridden from the command line / terminal.
        "update_from_argv": True,

        # Training parameters
        "num_classes": 2,
        "in_channels": 1,
        "batch_size": 3,  # original note: "works with 6 on GB GPU"
        "patch_size": 512,
        "n_epochs": 1,
        "learning_rate": 0.0002,
        # 'splits.pkl' may contain several folds; this selects the one to use.
        "fold": 0,
        # 'cuda' is the default CUDA device; 'cpu' works as well. See
        # https://pytorch.org/docs/stable/notes/cuda.html
        "device": "cuda",

        # Logging parameters
        "name": 'Basic_Unet',
        "author": 'kleina',  # author of this project
        "plot_freq": 10,  # how often results are shown in visdom
        "append_rnd_string": False,  # append a random string to make the experiment name unique
        "start_visdom": True,  # start a visdom server (or start one manually)
        "do_instancenorm": True,  # instance normalization in the UNet's contracting path
        "do_load_checkpoint": False,
        "checkpoint_dir": '',

        # Dataset and directories
        "google_drive_id": '1jzeNU1EKnK81PyTsrx0ujfNl-t0Jo8uE',  # spleen
        "dataset_name": 'Task09_Spleen',
        "base_dir": output_dir,
        "data_root_dir": data_root_dir,
        # training and validation data
        "data_dir": os.path.join(data_root_dir, 'Task09_Spleen/preprocessed'),
        # test data
        "data_test_dir": os.path.join(data_root_dir, 'Task09_Spleen/preprocessed'),
        # location of 'splits.pkl'
        "split_dir": os.path.join(data_root_dir, 'Task09_Spleen'),

        # Model used when segmenting a specific image.
        "model_dir": os.path.join(
            output_dir,
            '20200108-035420_Basic_Unet/checkpoint/checkpoint_current'),
    }
    c = Config(**settings)

    print(c)
    return c
Esempio n. 4
0
def group_experiments_by(exps, group_by_list):
    """Group experiments by the values of selected config keys.

    Only keys from ``group_by_list`` whose values differ between the
    experiments (according to ``Config.difference_config_static``) are used.
    Every combination of those values is a candidate group, and every
    experiment is placed in the group whose values it matches. Empty groups
    are dropped.

    Args:
        exps: experiments exposing ``config.flat()``.
        group_by_list: config keys (in flattened notation) to group by.

    Returns:
        list: one list of experiments per non-empty group, ordered by value
        combination. Values are taken in first-seen order, so the result is
        deterministic across runs.
    """
    # Flatten each config once and reuse it for matching below.
    # TODO(review): original note says "Exclude already combined experiments".
    configs_flat = [e.config.flat() for e in exps]
    config_diff = Config.difference_config_static(*configs_flat)

    group_keys = []
    group_values = []
    for diff_key in group_by_list:
        if diff_key in config_diff:
            group_keys.append(diff_key)
            # De-duplicate like set(), but keep first-seen order so the group
            # order does not depend on hash randomization.
            group_values.append(list(dict.fromkeys(config_diff[diff_key])))

    group_dict = defaultdict(list)
    for val_combi in itertools.product(*group_values):
        for e, e_config in zip(exps, configs_flat):
            if all(e_config[key] == val
                   for key, val in zip(group_keys, val_combi)):
                group_dict[val_combi].append(e)

    return list(group_dict.values())
Esempio n. 5
0
    def get_config(self):
        """Return one config describing all combined experiments.

        A deep copy of the first experiment's config is updated with the
        differences between all experiment configs, as computed by
        ``Config.difference_config_static(..., only_set=True)``.
        """
        all_configs = [exp.config for exp in self.experiments]
        merged = copy.deepcopy(all_configs[0])
        differences = Config.difference_config_static(*all_configs,
                                                      only_set=True)
        merged.update(differences)
        return merged
def gen_config(parser):
    """Parse the command line with ``parser`` and return the values as a Config.

    Every parsed argument becomes one entry, keyed by its destination name.
    """
    parsed = parser.parse_args()
    c = Config()
    for key, value in vars(parsed).items():
        c[key] = value
    return c
Esempio n. 7
0
    def load_config(self, name, **kwargs):
        """Load a JSON config file from the experiment's config dir.

        Args:
            name: file name inside ``self.config_dir``; ``.json`` is appended
                if missing.
            **kwargs: forwarded to ``Config.load``.

        Returns:
            Config: a Config filled with the file's content.
        """
        file_name = name if name.endswith(".json") else name + ".json"
        loaded = Config()
        loaded.load(os.path.join(self.config_dir, file_name), **kwargs)
        return loaded
def get_config():
    """Build and print the Basic U-Net configuration for the hippocampus dataset.

    ``update_from_argv`` is enabled, so each entry can be overridden by a
    command-line parameter.

    Returns:
        Config: the assembled configuration.
    """
    # Adapt these paths to your setup, if needed.
    data_root_dir = os.path.abspath('data')  # where the downloaded dataset is stored
    preprocessed_dir = os.path.join(data_root_dir,
                                    'Task04_Hippocampus/preprocessed')

    c = Config(
        update_from_argv=True,  # allow command-line overrides

        # Training parameters
        num_classes=3,
        in_channels=1,
        batch_size=8,
        patch_size=64,
        n_epochs=10,
        learning_rate=0.0002,
        fold=0,  # which fold of 'splits.pkl' to use
        # 'cuda' is the default CUDA device; 'cpu' works as well. See
        # https://pytorch.org/docs/stable/notes/cuda.html
        device="cuda",

        # Logging parameters
        name='Basic_Unet',
        author='kleina',  # author of this project
        plot_freq=10,  # how often results are shown in visdom
        append_rnd_string=False,  # append a random string for a unique experiment name
        start_visdom=True,  # start a visdom server (or start one manually)
        do_instancenorm=True,  # instance normalization in the UNet's contracting path
        do_load_checkpoint=False,
        checkpoint_dir='',

        # Dataset and directories
        google_drive_id='1RzPB1_bqzQhlWvU-YGvZzhx2omcDh38C',  # used to download the example dataset
        dataset_name='Task04_Hippocampus',
        base_dir=os.path.abspath('output_experiment'),  # where experiment output is logged
        data_root_dir=data_root_dir,
        data_dir=preprocessed_dir,  # training and validation data
        data_test_dir=preprocessed_dir,  # test data
        split_dir=os.path.join(data_root_dir, 'Task04_Hippocampus'),  # location of 'splits.pkl'
    )

    print(c)
    return c
    def _save_exp_config(self):
        """Save a stub "exp" info config through the experiment logger.

        Nothing is written when there is no logger or when it is a ``Mock``.
        """
        if self.elog is None or isinstance(self.elog, Mock):
            return
        now = time.strftime("%y-%m-%d_%H:%M:%S", time.localtime(time.time()))
        exp_info = Config(name=self.name,
                          time=now,
                          state="Stub",
                          current_time=now,
                          epoch=0)
        self.elog.save_config(exp_info, "exp")
Esempio n. 10
0
    def __init__(self,
                 base_dir,
                 exp_dirs=(),
                 name=None,
                 decode_config_clean_str=True):
        """Combine several stored experiments into one virtual experiment.

        The combined experiment is not stored on disk yet, so all directory
        attributes are set to ``"not_saved_yet"``.

        Args:
            base_dir: directory the experiment dirs are relative to. Replaced
                by the common parent of all experiment work dirs when that is
                non-empty.
            exp_dirs: experiment directories (relative to ``base_dir``) to
                combine; must yield at least one directory.
            name: name of the combined experiment; ``None`` or ``""`` falls
                back to ``"combi-experiments"``.
            decode_config_clean_str: forwarded to every ``ExperimentReader``.

        Raises:
            ValueError: if ``exp_dirs`` yields no experiment directory.
        """
        self.base_dir = base_dir

        if name is None or name == "":
            self.exp_name = "combi-experiments"
        else:
            self.exp_name = name

        self.experiments = [
            ExperimentReader(base_dir=base_dir,
                             exp_dir=exp_dir,
                             decode_config_clean_str=decode_config_clean_str)
            for exp_dir in exp_dirs
        ]
        if not self.experiments:
            # Without this check os.path.commonpath([]) below (and
            # self.get_config(), which indexes experiments[0]) would fail
            # with an opaque error.
            raise ValueError(
                "CombiExperimentReader needs at least one experiment directory")

        exp_base_dirs = os.path.commonpath(
            [os.path.dirname(e.work_dir) for e in self.experiments])
        if exp_base_dirs != "":
            self.base_dir = exp_base_dirs

        self.exp_info = Config()

        self.exp_info["epoch"] = -1
        self.exp_info["name"] = self.exp_name
        self.exp_info["state"] = "Combined"
        self.exp_info["time"] = time.strftime("%y-%m-%d_%H:%M:%S",
                                              time.localtime(time.time()))

        self.__results_dict = None

        self.work_dir = None
        self.config_dir = "not_saved_yet"
        self.log_dir = "not_saved_yet"
        self.checkpoint_dir = "not_saved_yet"
        self.img_dir = "not_saved_yet"
        self.plot_dir = "not_saved_yet"
        self.save_dir = "not_saved_yet"
        self.result_dir = "not_saved_yet"
        self.exp_dir = "not_saved_yet"

        self.meta_name = None
        self.meta_star = False
        self.meta_ignore = False

        self.config = self.get_config()

        self.elog = None
def get_config():
    """Build and print the U-Net configuration for the nuclei segmentation task.

    ``update_from_argv`` is enabled, so each entry can be overridden by a
    command-line parameter.

    Returns:
        Config: the assembled configuration.
    """
    # Set your own path, if needed.
    data_root_dir = "/home/ramesh/Desktop/IIITB/experiment/data/NucliiData"
    task_name = "Nuclii Segmentation"

    c = Config(
        update_from_argv=True,  # allow command-line overrides

        # Train parameters
        num_classes=1,
        in_channels=1,
        batch_size=6,
        patch_size=64,
        n_epochs=50,
        learning_rate=0.00001,
        fold=0,  # which fold of 'splits.pkl' to use
        # 'cuda' is the default CUDA device; 'cpu' works as well. See
        # https://pytorch.org/docs/stable/notes/cuda.html
        device="cuda",

        # Logging parameters
        name='Segmentation_Experiment_Unet_Nuclii',
        plot_freq=10,  # how often results are shown in visdom
        append_rnd_string=False,
        start_visdom=True,
        do_instancenorm=True,  # instance normalization in the UNet's contracting path
        do_load_checkpoint=False,
        # Empty means "no checkpoint dir", as in the other example configs; a
        # single space would be a truthy but meaningless path.
        checkpoint_dir='',

        # Dataset and directories (adapt to your own paths, if needed)
        dataset_name=task_name,
        base_dir=os.path.abspath('output_experiment'),  # where experiment output is logged
        data_root_dir=data_root_dir,
        data_dir=os.path.join(data_root_dir, task_name, 'imagesTr'),  # training and validation data
        split_dir=os.path.join(data_root_dir, task_name, 'preprocessed'),  # location of 'splits.pkl'
    )

    print(c)
    return c
def get_config():
    """Build and print the Basic U-Net configuration for the Brats17 dataset.

    ``update_from_argv`` is enabled, so each entry can be overridden by a
    command-line parameter.

    Returns:
        Config: the assembled configuration.
    """
    # Set your own path, if needed.
    data_root_dir = "/home/ramesh/Desktop/WS/Implementation/data/Brats17"
    # Training, test and split files all live in the same folder.
    preprocessed_dir = os.path.join(data_root_dir, 'Brats17_preprocessed')
    previous_run_dir = ('/home/ramesh/Desktop/WS/Implementation/experiment/'
                        'basic_unet_example/output_experiment/'
                        '20190504-180510_Basic_Unet')

    settings = {
        "update_from_argv": True,  # allow command-line overrides

        # Training parameters
        "num_classes": 4,
        "in_channels": 1,
        "batch_size": 8,
        "patch_size": 160,
        "n_epochs": 1,
        "learning_rate": 0.0002,
        "fold": 0,  # which fold of 'splits.pkl' to use
        # 'cuda' is the default CUDA device; 'cpu' works as well. See
        # https://pytorch.org/docs/stable/notes/cuda.html
        "device": "cuda",

        # Logging parameters
        "name": 'Basic_Unet',
        "plot_freq": 1,  # how often results are shown in visdom
        "append_rnd_string": False,
        "start_visdom": True,
        "do_instancenorm": True,  # instance normalization in the UNet's contracting path
        "do_load_checkpoint": False,
        "checkpoint_dir": previous_run_dir,

        # Dataset and directories (adapt to your own paths, if needed)
        # NOTE(review): this drive id is the same one the hippocampus example
        # config uses; confirm it is the intended download for Brats17.
        "google_drive_id": '1RzPB1_bqzQhlWvU-YGvZzhx2omcDh38C',
        "dataset_name": 'Brats17',
        "base_dir": os.path.abspath('output_experiment'),  # where experiment output is logged
        "data_root_dir": data_root_dir,
        "data_dir": preprocessed_dir,  # training and validation data
        "data_test_dir": preprocessed_dir,  # test data
        "split_dir": preprocessed_dir,  # location of 'splits.pkl'
    }
    c = Config(**settings)

    print(c)
    return c
Esempio n. 13
0
    def _save_exp_config(self):
        """Persist the experiment's status info as the "exp" config.

        Records the name, start time, state, current time and epoch index.
        Nothing is written when no experiment logger is attached.
        """
        if self.elog is None:
            return
        now = time.strftime("%y-%m-%d_%H:%M:%S", time.localtime(time.time()))
        status = Config(name=self.exp_name,
                        time=self._time_start,
                        state=self._exp_state,
                        current_time=now,
                        epoch=self._epoch_idx)
        self.elog.save_config(status, "exp")
def get_add_config():
    """Build and print the additional (path-related) config for the CHD dataset.

    Returns:
        Config: dataset, checkpoint and directory settings.
    """
    # Set your own path, if needed.
    data_root_dir = os.path.abspath('data')  # where the downloaded dataset is stored

    def chd_path(sub_path):
        # All CHD data lives below data_root_dir/CHD_interpolation_dataset.
        return os.path.join(data_root_dir,
                            'CHD_interpolation_dataset/' + sub_path)

    preprocessed_dir = chd_path('preprocessed')

    settings = {
        "do_instancenorm": True,  # instance normalization in the UNet's contracting path
        "do_load_checkpoint": False,
        "checkpoint_dir": '',

        # Adapt to your own paths, if needed.
        "google_drive_id": '1RzPB1_bqzQhlWvU-YGvZzhx2omcDh38C',
        # NOTE(review): the name says "segmentation" while every directory
        # below uses "CHD_interpolation_dataset"; confirm this is intended.
        "dataset_name": 'CHD_segmentation_dataset',
        "base_dir": os.path.abspath('output_experiment'),  # where experiment output is logged
        "data_root_dir": data_root_dir,
        "data_dir": preprocessed_dir,  # training and validation data
        "data_test_dir": preprocessed_dir,  # test data
        "split_dir": os.path.join(data_root_dir, 'CHD_interpolation_dataset'),  # location of 'splits.pkl'
    }
    for size in (16, 32, 64):
        settings['scaled_image_{}_dir'.format(size)] = chd_path(
            'scaled_to_{}'.format(size))
    settings["stage_1_dir"] = chd_path('stage_1')
    settings["stage_1_dir_32"] = chd_path('stage_1_32')

    c = Config(**settings)

    print(c)
    return c
def get_config():
    """Build and print the configuration for the pancreas experiment.

    ``use_cuda`` reflects whether PyTorch reports an available CUDA device.

    Returns:
        Config: the assembled configuration.
    """
    c = Config()

    # Training
    c.batch_size = 6
    c.patch_size = 512
    c.n_epochs = 20
    c.learning_rate = 0.0002
    c.do_ce_weighting = True
    c.do_batchnorm = True
    c.use_cuda = bool(torch.cuda.is_available())
    c.rnd_seed = 1
    c.log_interval = 200

    # Paths
    # NOTE(review): data_dir/split_dir point to different pancreas folders,
    # and data_file is a Windows path while the others are Linux paths;
    # confirm these are intended.
    c.base_dir = '/media/kleina/Data2/output/meddec'
    c.data_dir = '/media/kleina/Data2/Data/meddec/Task07_Pancreas_expert_preprocessed'
    c.split_dir = '/media/kleina/Data2/Data/meddec/Task07_Pancreas_preprocessed'
    c.data_file = 'C:/dev/data/Endoviz2018/GIANA/polyp_detection_segmentation/image_gt_data_file_list_all_640x640.csv'

    c.additional_slices = 5
    c.name = ''

    print(c)
    return c
Esempio n. 16
0
def get_config_heart(fine_tune_type='None',
                     exp_name='',
                     checkpoint_filename='',
                     checkpoint_dir='',
                     nr_train_samples=0,
                     download_data_from_drive=False):
    """Build and print the U-Net fine-tuning configuration for the heart dataset.

    Args:
        fine_tune_type: stored as ``fine_tune``.
        exp_name: stored as ``name``.
        checkpoint_filename: checkpoint file to load (``do_load_checkpoint``
            is always enabled).
        checkpoint_dir: directory of that checkpoint.
        nr_train_samples: stored as ``train_samples``; per the original note,
            0 selects the original split (1/2 train, 1/4 val, 1/4 test), and
            otherwise validation gets the same size with the rest used for
            testing.
        download_data_from_drive: stored as ``download_data``.

    Returns:
        Config: the assembled configuration (overridable from the command line).
    """
    # Set your own path, if needed.
    data_root_dir = os.path.abspath('data')  # where the downloaded dataset is stored
    preprocessed_dir = os.path.join(data_root_dir, 'Task02_Heart/preprocessed')

    settings = {
        "update_from_argv": True,

        # Training parameters
        "num_classes": 2,
        "in_channels": 1,
        "batch_size": 8,
        "patch_size": 256,
        "n_epochs": 60,
        "learning_rate": 0.0002,
        "fold": 0,  # which fold of 'splits.pkl' to use
        # 'cuda' is the default CUDA device; 'cpu' works as well. See
        # https://pytorch.org/docs/stable/notes/cuda.html
        "device": "cuda",

        # Logging parameters
        "name": exp_name,
        "author": 'maxi',  # author of this project
        "plot_freq": 10,  # how often results are shown in visdom
        "append_rnd_string": False,
        "start_visdom": False,
        "do_instancenorm": True,  # instance normalization in the UNet's contracting path
        "do_load_checkpoint": True,
        "checkpoint_filename": checkpoint_filename,
        "checkpoint_dir": checkpoint_dir,
        "fine_tune": fine_tune_type,

        # Dataset and directories (adapt to your own paths, if needed)
        "download_data": download_data_from_drive,
        "google_drive_id": '1RzPB1_bqzQhlWvU-YGvZzhx2omcDh38C',
        "dataset_name": 'Task02_Heart',
        "base_dir": os.path.abspath('output_experiment'),  # where experiment output is logged
        "data_root_dir": data_root_dir,
        "data_dir": preprocessed_dir,  # training and validation data
        "data_test_dir": preprocessed_dir,  # test data
        "split_dir": os.path.join(data_root_dir, 'Task02_Heart'),  # location of 'splits.pkl'
        "train_samples": nr_train_samples,

        # Testing
        "visualize_segm": True,
    }
    c = Config(**settings)

    print(c)
    return c
Esempio n. 17
0
def run_experiment(experiment, configs, args, mods=None, **kwargs):
    """Assemble a config and run one experiment per grid-search combination.

    The base config is built from ``args.config`` (if given), then missing
    entries come from ``configs[args.default_config]``, the selected
    ``mods`` are applied, and finally command-line overrides. Each grid
    combination runs on its own copy of that base config. This way values
    set by one combination, or changes an experiment makes to its config,
    do not leak into the next run.

    Args:
        experiment: experiment class/factory, called with ``config``,
            ``base_dir``, ``resume``, ``ignore_resume_config``, ``loggers``
            and ``**kwargs``.
        configs: mapping of named configs. ``configs[args.default_config]``
            provides defaults; ``configs["DEFAULTS"]`` is the reference for
            ``args.automatic_description``.
        args: parsed command-line arguments.
        mods: optional mapping of named config modifications selected by
            ``args.mods``.
        **kwargs: forwarded to ``experiment``. Defaults are filled in for
            ``explogger_kwargs``, ``explogger_freq`` and
            ``resume_save_types``.
    """
    # Set a few defaults without overriding explicit caller choices.
    kwargs.setdefault("explogger_kwargs",
                      dict(folder_format="{experiment_name}_%Y%m%d-%H%M%S"))
    kwargs.setdefault("explogger_freq", 1)
    kwargs.setdefault("resume_save_types",
                      ("model", "simple", "th_vars", "results"))

    config = Config(file_=args.config) if args.config is not None else Config()
    config.update_missing(configs[args.default_config].deepcopy())
    if args.mods is not None and mods is not None:
        for mod in args.mods:
            config.update(mods[mod])
    base_config = Config(config=config, update_from_argv=True)

    # Collect existing experiment configs so already-run combinations can be
    # skipped.
    existing_configs = []
    if args.skip_existing:
        for exp_dir_name in os.listdir(args.base_dir):
            try:
                existing_configs.append(
                    Config(file_=os.path.join(args.base_dir, exp_dir_name,
                                              "config", "config.json")))
            except Exception:
                # Best effort: entries without a readable config are simply
                # not considered for skipping.
                continue

    grid = GridSearch().read(args.grid) if args.grid is not None else [{}]

    for combi in grid:

        # Fresh copy per combination (see docstring).
        config = base_config.deepcopy()
        config.update(combi)

        if args.skip_existing and any(
                existing_config.contains(config)
                for existing_config in existing_configs):
            continue

        if "backup_every" in config:
            kwargs["save_checkpoint_every_epoch"] = config["backup_every"]

        loggers = {}
        if args.visdomlogger:
            loggers["v"] = ("visdom", {}, 1)
        if args.tensorboardxlogger is not None:
            if args.tensorboardxlogger == "same":
                loggers["tx"] = ("tensorboard", {}, 1)
            else:
                loggers["tx"] = ("tensorboard", {
                    "target_dir": args.tensorboardxlogger
                }, 1)

        if args.telegramlogger:
            kwargs["use_telegram"] = True

        if args.automatic_description:
            difference_to_default = Config.difference_config_static(
                config, configs["DEFAULTS"]).flat(keep_lists=True,
                                                  max_split_size=0,
                                                  flatten_int=True)
            # One "key = value" line per differing entry, in reverse iteration
            # order; each difference is a list whose first element belongs to
            # ``config``.
            config.description = "".join(
                "{} = {}\n".format(key, val[0])
                for key, val in reversed(list(difference_to_default.items())))

        exp = experiment(config=config,
                         base_dir=args.base_dir,
                         resume=args.resume,
                         ignore_resume_config=args.ignore_resume_config,
                         loggers=loggers,
                         **kwargs)

        trained = False
        if args.resume is None or args.test is False:
            exp.run()
            trained = True
        if args.test:
            exp.run_test(setup=not trained)
            # Copy test outputs back into the resumed experiment's save dir.
            if isinstance(args.resume,
                          str) and exp.elog is not None and args.copy_test:
                for f in glob.glob(os.path.join(exp.elog.save_dir, "test*")):
                    if os.path.isdir(f):
                        shutil.copytree(
                            f,
                            os.path.join(args.resume, "save",
                                         os.path.basename(f)))
                    else:
                        shutil.copy(f, os.path.join(args.resume, "save"))
Esempio n. 18
0
def experiment(base_dir):
    """Render the comparison page for the experiments selected in the request.

    Query parameters: ``exp`` (repeatable experiment paths), ``name``,
    ``save`` and ``combi``. With ``combi=true`` the experiments are merged
    into one ``CombiExperimentReader`` (saved if ``save=true``); otherwise
    each path is read separately, in sorted order.

    NOTE(review): ``exp`` values come straight from the request and are
    handed to the readers together with ``base_dir``; confirm this server is
    not exposed to untrusted users (path traversal).
    """
    query = request.args
    experiment_paths = query.getlist('exp')
    name = query.get('name', "")
    do_save = query.get('save', "")
    combi = query.get('combi', 'false')

    if combi == "true":
        combi_exp = CombiExperimentReader(base_dir,
                                          experiment_paths,
                                          name=name)
        if do_save == "true":
            combi_exp.save()
        experiments = [combi_exp]
    else:
        experiments = [
            ExperimentReader(base_dir, path)
            for path in sorted(experiment_paths)
        ]

    # Make display names unique by appending the position when they clash.
    if len({exp.exp_name for exp in experiments}) < len(experiments):
        for idx, exp in enumerate(experiments):
            exp.exp_name += str(idx)
    exp_names = [exp.exp_name for exp in experiments]

    missing = "-"

    # Configs: one column per experiment, "-" where a key is absent.
    exp_configs = [exp.config.flat(False) for exp in experiments]
    diff_config_keys = list(
        Config.difference_config_static(*exp_configs).keys())
    config_keys = sorted({k for conf in exp_configs for k in conf.keys()})
    combi_config = {
        k: [conf.get(k, missing) for conf in exp_configs]
        for k in config_keys
    }

    # Results: same layout as the configs.
    exp_results = [exp.get_results() for exp in experiments]
    result_keys = sorted({k for res in exp_results for k in res.keys()})
    combi_results = {
        k: [res.get(k, missing) for res in exp_results]
        for k in result_keys
    }

    # Images, grouped per experiment.
    images = OrderedDict()
    image_path = {}
    image_key_set = set()
    for exp in experiments:
        img_groups = group_images(exp.get_images())
        images[exp.exp_name] = img_groups
        image_path[exp.exp_name] = exp.img_dir
        image_key_set.update(img_groups.keys())
    image_keys = sorted(image_key_set)

    plots = OrderedDict(
        (exp.exp_name, exp.get_plots()) for exp in experiments)

    logs_dict = OrderedDict(
        (exp.exp_name,
         [(os.path.basename(log_file), exp.exp_dir)
          for log_file in exp.get_logs()])
        for exp in experiments)

    content = {
        "title": experiments,
        "images": {
            "img_path": image_path,
            "imgs": images,
            "img_keys": image_keys
        },
        "plots": {"plots": plots},
        "config": {
            "exps": experiments,
            "configs": combi_config,
            "keys": config_keys,
            "diff_keys": diff_config_keys
        },
        "results": {
            "exps": exp_names,
            "results": combi_results,
            "keys": result_keys
        },
        "logs": {"logs_dict": logs_dict},
    }

    return render_template('experiment.html', **content)
def make_defaults(patch_size=112,
                  in_channels=4,
                  latent_size=3,
                  labels=None):
    """Build the default experiment config and a set of named config modifications.

    Args:
        patch_size (int or iterable of int): Spatial patch size. A scalar or a
            length-1 iterable is expanded to a 3-tuple (cube); a longer
            iterable is converted to a tuple unchanged.
        in_channels (int): Number of input channels of the network.
        latent_size (int): Latent dimensionality; the prior/posterior
            encoders are configured with ``latent_size * 2`` output channels.
        labels (list of int): Segmentation label values. Defaults to
            ``[0, 1, 2, 3]``. ``out_channels`` is ``len(labels)`` and the
            posterior encoder receives ``in_channels + len(labels)`` channels.

    Returns:
        tuple: ``({"DEFAULTS": DEFAULTS}, MODS)`` where ``DEFAULTS`` is the
        full :class:`Config` and ``MODS`` maps a modification name (e.g.
        ``"WHOLETUMOR"``) to a partial :class:`Config` meant to be applied
        on top of the defaults.

    Note:
        The evaluator ``label_names`` in ``DEFAULTS`` are hard-coded for the
        label values 0-3 and do not follow a custom ``labels`` argument.
    """

    # Fresh list per call: the list ends up inside the returned Config, so a
    # shared default object could be mutated across calls.
    if labels is None:
        labels = [0, 1, 2, 3]

    # Normalise patch_size to a tuple; scalars become a 3D cube.
    if hasattr(patch_size, "__iter__"):
        if len(patch_size) > 1:
            patch_size = tuple(patch_size)
        else:
            patch_size = patch_size[0]
    if not hasattr(patch_size, "__iter__"):
        patch_size = tuple([
            patch_size,
        ] * 3)

    DEFAULTS = Config(

        # Base
        name=os.path.basename(__file__).split(".")[0],
        description=DESCRIPTION,
        n_epochs=50000,
        batch_size=2,
        batch_size_val=1,
        patch_size=patch_size,
        in_channels=in_channels,
        out_channels=len(labels),
        latent_size=latent_size,
        seed=1,
        device="cuda",

        # Data
        split_val=3,
        split_test=4,
        data_module=data,
        data_dir=None,  # we're setting data_module.data_dir if this is given
        mmap_mode="r",
        npz=False,
        debug=0,  # 1 selects (10, 5, 5) patients, 2 a single batch
        train_on_all=False,  # adds val and test to training set
        generator_train=data.RandomBatchGenerator,
        generator_val=data.LinearBatchGenerator,
        transforms_train={
            0: {
                "type": SpatialTransform,
                "kwargs": {
                    "patch_size": patch_size,
                    "patch_center_dist_from_border": patch_size[0] // 2,
                    "do_elastic_deform": False,
                    "p_el_per_sample": 0.2,
                    "p_rot_per_sample": 0.3,
                    "p_scale_per_sample": 0.3
                },
                "active": True
            },
            1: {
                "type": MirrorTransform,
                "kwargs": {
                    "axes": (0, 1, 2)
                },
                "active": True
            },
            2: {
                "type": SegLabelSelectionBinarizeTransform,
                "kwargs": {
                    "label": [1, 2, 3]
                },
                "active": False
            }
        },
        transforms_val={
            0: {
                "type": CenterCropTransform,
                "kwargs": {
                    "crop_size": patch_size
                },
                "active": False
            },
            1: {
                "type": SegLabelSelectionBinarizeTransform,
                "kwargs": {
                    "label": [1, 2, 3]
                },
                "active": False
            },
            2: {
                "type": SpatialTransform,
                "kwargs": {
                    "patch_size": patch_size,
                    "patch_center_dist_from_border": patch_size[0] // 2,
                    "do_elastic_deform": False,
                    "do_rotation": False,
                    "do_scale": True,
                    "p_scale_per_sample": 1,
                    "scale": (1.25, 1.25)
                },
                "active": False
            }
        },
        augmenter_train=MultiThreadedAugmenter,
        augmenter_train_kwargs={
            "num_processes": 11,
            "num_cached_per_queue": 6,
            "pin_memory": True
        },
        augmenter_val=MultiThreadedAugmenter,
        augmenter_val_kwargs={
            "num_processes": 2,
            "pin_memory": True
        },

        # Model
        model=ProbabilisticSegmentationNet,
        model_kwargs={
            "in_channels": in_channels,
            "out_channels": len(labels),
            "num_feature_maps": 24,
            "latent_size": latent_size,
            "depth": 5,
            "latent_distribution": distributions.Normal,
            "task_op": InjectionUNet3D,
            "task_kwargs": {
                "output_activation_op": nn.LogSoftmax,
                "output_activation_kwargs": {
                    "dim": 1
                },
                "activation_kwargs": {
                    "inplace": True
                }
            },
            "prior_op": InjectionConvEncoder3D,
            "prior_kwargs": {
                "in_channels": in_channels,
                "out_channels": latent_size * 2,
                "depth": 5,
                "block_depth": 2,
                "num_feature_maps": 24,
                "feature_map_multiplier": 2,
                "activation_kwargs": {
                    "inplace": True
                },
                "norm_depth": 2,
            },
            "posterior_op": InjectionConvEncoder3D,
            "posterior_kwargs": {
                # The posterior sees the image plus the one-channel-per-label
                # segmentation.
                "in_channels": in_channels + len(labels),
                "out_channels": latent_size * 2,
                "depth": 5,
                "block_depth": 2,
                "num_feature_maps": 24,
                "feature_map_multiplier": 2,
                "activation_kwargs": {
                    "inplace": True
                },
                "norm_depth": 2,
            },
        },
        model_init_weights_args=[nn.init.kaiming_uniform_, 0],
        model_init_bias_args=[nn.init.constant_, 0],

        # Learning
        optimizer=optim.Adam,
        optimizer_kwargs={"lr": 1e-4},
        scheduler=optim.lr_scheduler.StepLR,
        scheduler_kwargs={
            "step_size": 200,
            "gamma": 0.985
        },
        criterion_segmentation=nn.NLLLoss,
        criterion_segmentation_kwargs={"reduction": "sum"},
        criterion_latent=distributions.kl_divergence,
        criterion_latent_kwargs={},
        criterion_latent_init=False,
        criterion_segmentation_seg_onehot=False,
        criterion_segmentation_weight=1.0,
        criterion_latent_weight=1.0,
        criterion_segmentation_seg_dtype=torch.long,

        # Logging
        backup_every=1000,
        validate_every=1000,
        validate_subset=0.1,  # validate only this percentage randomly
        show_every=10,
        validate_metrics=["Dice"],
        labels=labels,
        evaluator=Evaluator,
        evaluator_kwargs={
            # Every single label plus the union of all non-background labels.
            "label_values": list(labels) + [tuple(labels[1:])],
            "label_names": {
                0: "Background",
                1: "Edema",
                2: "Enhancing",
                3: "Necrosis",
                tuple(labels[1:]): "Whole Tumor"
            },
            "nan_for_nonexisting": True
        },
        val_save_output=False,
        val_example_samples=10,
        val_save_images=False,
        latent_plot_range=[-5, 5],
        test_on_val=True,
        test_save_output=False,
        test_future=True,
        test_std_factor=3,
        test_std_scale=1.)

    TASKMEAN = Config(
        criterion_segmentation_kwargs={"reduction": "elementwise_mean"})

    ELASTIC = Config(
        transforms_train={0: {
            "kwargs": {
                "do_elastic_deform": True
            }
        }})

    NONORM = Config(model_kwargs={
        "prior_kwargs": {
            "norm_depth": 0
        },
        "posterior_kwargs": {
            "norm_depth": 0
        }
    })

    FULLNORM = Config(
        model_kwargs={
            "prior_kwargs": {
                "norm_depth": "full"
            },
            "posterior_kwargs": {
                "norm_depth": "full"
            }
        })

    BATCHNORM = Config(
        model_kwargs={
            "prior_kwargs": {
                "norm_op": nn.BatchNorm3d
            },
            "posterior_kwargs": {
                "norm_op": nn.BatchNorm3d
            },
            "task_kwargs": {
                "norm_op": nn.BatchNorm3d
            }
        })

    # Binary task: labels 1, 2, 3 merged into one foreground class.
    WHOLETUMOR = Config(
        transforms_train={2: {
            "active": True
        }},
        transforms_val={1: {
            "active": True
        }},
        out_channels=2,
        labels=[0, 1],
        model_kwargs={
            "out_channels": 2,
            "posterior_kwargs": {
                "in_channels": in_channels + 2
            }
        },
        evaluator_kwargs={
            "label_values": [0, 1],
            "label_names": {
                0: "Background",
                1: "Whole Tumor"
            }
        })

    # Binary task: only label 2 is kept as foreground.
    ENHANCING = Config(
        transforms_train={2: {
            "kwargs": {
                "label": 2
            },
            "active": True
        }},
        transforms_val={1: {
            "kwargs": {
                "label": 2
            },
            "active": True
        }},
        out_channels=2,
        labels=[0, 1],
        model_kwargs={
            "out_channels": 2,
            "posterior_kwargs": {
                "in_channels": in_channels + 2
            }
        },
        evaluator_kwargs={
            "label_values": [0, 1],
            "label_names": {
                0: "Background",
                # Label 2 is named "Enhancing" in DEFAULTS; this binarized
                # foreground is that class, not the whole tumor.
                1: "Enhancing"
            }
        })

    NOAUGMENT = Config(
        transforms_train={
            0: {
                "kwargs": {
                    "p_el_per_sample": 0,
                    "p_rot_per_sample": 0,
                    "p_scale_per_sample": 0
                }
            },
            1: {
                "active": False
            }
        })

    LOWAUGMENT = Config(
        transforms_train={
            0: {
                "kwargs": {
                    "p_el_per_sample": 0.,
                    "p_rot_per_sample": 0.15,
                    "p_scale_per_sample": 0.15
                }
            }
        })

    NOBG = Config(criterion_segmentation_kwargs={"ignore_index": 0})

    VALIDATEPATCHED = Config(transforms_val={2: {"active": True}})

    MODS = {
        "TASKMEAN": TASKMEAN,
        "ELASTIC": ELASTIC,
        "NONORM": NONORM,
        "FULLNORM": FULLNORM,
        "BATCHNORM": BATCHNORM,
        "WHOLETUMOR": WHOLETUMOR,
        "ENHANCING": ENHANCING,
        "NOAUGMENT": NOAUGMENT,
        "LOWAUGMENT": LOWAUGMENT,
        "NOBG": NOBG,
        "VALIDATEPATCHED": VALIDATEPATCHED
    }

    return {"DEFAULTS": DEFAULTS}, MODS
Esempio n. 20
0
def test_2Dexperiment():
    """Run a SmileyExperiment with a Gaussian prior and plot samples from the model.

    The experiment writes to ``experiment_dir`` and logs to the visdom
    environment ``"myenv"``. After training, 1000 samples are drawn from
    ``exp.model`` and shown as a seaborn joint plot of the first two
    coordinates.
    """
    c = Config()

    c.batch_size = 200
    c.n_epochs = 40
    c.learning_rate = 0.001
    c.use_cuda = torch.cuda.is_available()
    c.rnd_seed = 1
    c.log_interval = 200
    # model-specific
    c.n_coupling = 8
    c.prior = 'gauss'

    exp = SmileyExperiment(
        c,
        name='gauss',
        n_epochs=c.n_epochs,
        seed=42,
        base_dir='experiment_dir',
        loggers={'visdom': ['visdom', {
            "exp_name": "myenv"
        }]})

    exp.run()

    # sampling
    samples = exp.model.sample(1000).cpu().numpy()
    # x/y must be passed as keywords: seaborn >= 0.12 no longer accepts
    # positional data vectors in jointplot.
    sns.jointplot(x=samples[:, 0], y=samples[:, 1])
    plt.show()
Esempio n. 21
0
class ExperimentReader(object):
    """Reader class to read out experiments created by :class:`trixi.experimentlogger.ExperimentLogger`.

    The experiment folder (:attr:`work_dir`) is ``os.path.join(base_dir, exp_dir)``.
    Its sub folders ``config``, ``log``, ``checkpoint``, ``img``, ``plot``,
    ``save`` and ``result`` are exposed as ``*_dir`` attributes, and
    ``config/config.json`` is loaded on construction.

    Args:
        base_dir (str): Base directory of the experiment(s).
        exp_dir (str): Experiment folder relative to ``base_dir``. Together with
                       ``base_dir`` it must point to a directory with the
                       structure defined by
                       :class:`trixi.experimentlogger.ExperimentLogger`.
        name (str): Optional name for the experiment. If None, the name is
                    taken from the ``.exp_info`` meta file, then from
                    ``config/exp.json`` (key ``name``), then from the config
                    (key ``exp_name``), and falls back to ``"experiments"``.
        decode_config_clean_str (bool): If True, load the config with
                                        :class:`StringMultiTypeDecoder`,
                                        otherwise with the default decoder.

    """
    def __init__(self,
                 base_dir,
                 exp_dir="",
                 name=None,
                 decode_config_clean_str=True):

        super(ExperimentReader, self).__init__()

        self.base_dir = base_dir
        self.exp_dir = exp_dir
        self.work_dir = os.path.abspath(
            os.path.join(self.base_dir, self.exp_dir))
        self.config_dir = os.path.join(self.work_dir, "config")
        self.log_dir = os.path.join(self.work_dir, "log")
        self.checkpoint_dir = os.path.join(self.work_dir, "checkpoint")
        self.img_dir = os.path.join(self.work_dir, "img")
        self.plot_dir = os.path.join(self.work_dir, "plot")
        self.save_dir = os.path.join(self.work_dir, "save")
        self.result_dir = os.path.join(self.work_dir, "result")

        self.config = Config()
        decoder_cls = StringMultiTypeDecoder if decode_config_clean_str else None
        self.config.load(os.path.join(self.config_dir, "config.json"),
                         decoder_cls_=decoder_cls)

        # Optional experiment info written next to the config.
        self.exp_info = Config()
        exp_info_file = os.path.join(self.config_dir, "exp.json")
        if os.path.exists(exp_info_file):
            self.exp_info.load(exp_info_file)

        # Lazily filled by get_results().
        self.__results_dict = None

        self.meta_name = None
        self.meta_star = False
        self.meta_ignore = False
        self.read_meta_info()

        if name is not None:
            self.exp_name = name
        elif self.meta_name is not None:
            self.exp_name = self.meta_name
        elif "name" in self.exp_info:
            self.exp_name = self.exp_info['name']
        elif "exp_name" in self.config:
            self.exp_name = self.config['exp_name']
        else:
            self.exp_name = "experiments"

        self.ignore = self.meta_ignore
        self.star = self.meta_star

    @staticmethod
    def get_file_contents(folder):
        """Get all files in a folder.

        Args:
            folder (str): Directory to list.

        Returns:
            list: Paths (joined with ``folder``) of all regular files in
            ``folder``, sorted by name; ``[]`` if ``folder`` is not a directory.
        """

        if os.path.isdir(folder):
            list_ = map(lambda x: os.path.join(folder, x),
                        sorted(os.listdir(folder)))
            return list(filter(lambda x: os.path.isfile(x), list_))
        else:
            return []

    def get_images(self):
        """Get all image files directly in :attr:`img_dir` and one level below.

        Returns:
            list: Files in ``img_dir`` first, followed by the files of each
            sub folder; sub folders are visited in sorted order so the result
            is deterministic.
        """
        imgs = []
        imgs += ExperimentReader.get_file_contents(self.img_dir)
        if os.path.isdir(self.img_dir):
            for f in sorted(os.listdir(self.img_dir)):
                f = os.path.join(self.img_dir, f)
                if os.path.isdir(f):
                    imgs += ExperimentReader.get_file_contents(f)
        return imgs

    def get_plots(self):
        """Return all files in :attr:`plot_dir` (sorted)."""
        return ExperimentReader.get_file_contents(self.plot_dir)

    def get_checkpoints(self):
        """Return all files in :attr:`checkpoint_dir` (sorted)."""
        return ExperimentReader.get_file_contents(self.checkpoint_dir)

    def get_logs(self):
        """Return all files in :attr:`log_dir` (sorted)."""
        return ExperimentReader.get_file_contents(self.log_dir)

    def get_log_file_content(self, file_name):
        """Read out log file and HTMLify.

        Args:
            file_name (str): Name of the log file.

        Returns:
            str: Log file contents with newlines replaced by ``<br>``; an
            empty string if the file does not exist.
        """

        content = ""
        log_file = os.path.join(self.log_dir, file_name)

        if os.path.exists(log_file):
            with open(log_file, 'r') as f:
                content = f.read()
                content = content.replace("\n", "<br>")

        return content

    def _load_results_log(self):
        """Load the raw list of result items from ``results-log.json``.

        The file is first parsed as plain JSON. If that fails (presumably
        because the stream is still open and lacks its closing bracket), the
        text is repaired by turning every ``]`` into ``,`` and appending a
        final ``{}]``, then parsed again.

        Returns:
            list: The result items, or ``[]`` if the log cannot be read.
        """
        log_path = os.path.join(self.result_dir, "results-log.json")
        try:
            with open(log_path, "r") as results_file:
                return json.load(results_file)
        except Exception:
            pass

        try:
            with open(log_path, "r") as results_file:
                lines = results_file.readlines()
            # NOTE(review): this rewrites *every* "]", including any that
            # close arrays inside result values - confirm the log never
            # stores list-valued data.
            if "]" in "".join(lines):
                lines = [line.replace("]", ",") for line in lines]
            # The trailing empty object absorbs a dangling comma.
            lines.append("{}]")
            return json.loads("".join(lines))
        except Exception as ee:
            print("Could not load result log from", self.result_dir)
            print(ee)
            return []

    def get_results_log(self):
        """Build result dictionary.

        During the experiment result items are
        written out as a stream of quasi-atomic units. This reads the stream and
        builds arrays of corresponding items.
        The resulting dict looks like this::

            {
                "result group": {
                    "result": {
                        "counter": x-array,
                        "data": y-array
                    }
                }
            }

        ``"max"`` / ``"min"`` arrays are added as well for items that carry
        both values. The result group is ``str(item["label"])``.

        Returns:
            dict: Result dictionary.

        """

        results_merged = {}

        for result in self._load_results_log():
            for key in result.keys():
                counter = result[key]["counter"]
                data = result[key]["data"]
                label = str(result[key]["label"])
                if label not in results_merged:
                    results_merged[label] = {}
                if key not in results_merged[label]:
                    results_merged[label][key] = defaultdict(list)
                results_merged[label][key]["data"].append(data)
                results_merged[label][key]["counter"].append(counter)
                if "max" in result[key] and "min" in result[key]:
                    results_merged[label][key]["max"].append(
                        result[key]["max"])
                    results_merged[label][key]["min"].append(
                        result[key]["min"])

        return results_merged

    def get_results(self):
        """Get the last result item.

        The content of ``results.json`` is read once and cached.

        Returns:
            dict: The last result item in the experiment; an empty dict if the
            file is missing or unreadable.

        """

        if self.__results_dict is None:

            self.__results_dict = {}
            results_file = os.path.join(self.result_dir, "results.json")

            if os.path.exists(results_file):
                try:
                    with open(results_file, "r") as f:
                        self.__results_dict = json.load(f)
                except Exception:
                    # Best effort: an unreadable results file yields {}.
                    pass

        return self.__results_dict

    def ignore_experiment(self):
        """Create a flag file, so the browser ignores this experiment."""
        self.update_meta_info(ignore=True)

    def read_meta_info(self):
        """Read the meta info (new name, starred, ignored) from ``.exp_info`` in :attr:`work_dir`, if present."""
        meta_dict = {}
        meta_file = os.path.join(self.work_dir, ".exp_info")
        if os.path.exists(meta_file):
            with open(meta_file, "r") as mf:
                meta_dict = json.load(mf)
            self.meta_name = meta_dict.get("name")
            self.meta_star = meta_dict.get("star", False)
            self.meta_ignore = meta_dict.get("ignore", False)

    def update_meta_info(self, name=None, star=None, ignore=None):
        """
        Update the meta info (new name, starred, ignored) and save it to ``.exp_info`` in the experiment folder.

        Arguments left as None keep their current value.

        Args:
            name (str): New name of the experiment
            star (bool): Flag, if experiment is starred / favorited
            ignore (bool): Flag, if experiment should be ignored
        """

        if name is not None:
            self.meta_name = name
        if star is not None:
            self.meta_star = star
        if ignore is not None:
            self.meta_ignore = ignore

        meta_dict = {
            "name": self.meta_name,
            "star": self.meta_star,
            "ignore": self.meta_ignore
        }
        meta_file = os.path.join(self.work_dir, ".exp_info")
        with open(meta_file, "w") as mf:
            json.dump(meta_dict, mf)
Esempio n. 22
0
    def __init__(self,
                 config=None,
                 name=None,
                 n_epochs=None,
                 seed=None,
                 base_dir=None,
                 globs=None,
                 resume=None,
                 ignore_resume_config=False,
                 resume_save_types=("model", "optimizer", "simple", "th_vars",
                                    "results"),
                 resume_reset_epochs=True,
                 parse_sys_argv=False,
                 parse_config_sys_argv=True,
                 checkpoint_to_cpu=True,
                 safe_checkpoint_every_epoch=1,
                 use_visdomlogger=True,
                 visdomlogger_kwargs=None,
                 visdomlogger_c_freq=1,
                 use_explogger=True,
                 explogger_kwargs=None,
                 explogger_c_freq=100,
                 use_telegrammessagelogger=False,
                 telegrammessagelogger_kwargs=None,
                 telegrammessagelogger_c_freq=1000,
                 append_rnd_to_name=False):
        """Set up the config, seed, loggers and resume information of the experiment.

        Args:
            config: A path (str) to a config file, a :class:`Config` or a dict;
                anything else yields an empty :class:`Config`. Stored as
                :attr:`_config_raw`; :attr:`config` holds the initialized one.
            name (str): Experiment name; overridden by the config key ``name``.
            n_epochs (int): Number of epochs; overridden by the config key
                ``n_epochs``. None becomes 0.
            seed (int): Random seed; overridden by the config key ``seed``. If
                still None, a random 32-bit seed is drawn from ``os.urandom``
                and written back into the config.
            base_dir (str): Output directory for the experiment logger;
                overridden by the config key ``base_dir``. Also searched when
                ``resume == "last"``.
            globs: If given, matching sources are zipped into
                ``sources.zip`` in the experiment logger's save dir (requires
                ``use_explogger=True``).
            resume: ``"last"`` resumes the lexicographically last entry in
                ``base_dir``; any other str is used as the resume path; a
                :class:`PytorchExperiment` resumes from its logger's base dir.
            ignore_resume_config (bool): If False, the config stored at
                ``<resume path>/config/config.json`` is merged into the config,
                skipping keys that appear on the command line.
            resume_save_types (tuple): Stored for the later resume step.
            resume_reset_epochs (bool): Stored for the later resume step.
            parse_sys_argv (bool): If True, config and resume paths given on
                the command line (via ``get_vars_from_sys_argv``) take
                precedence over ``config`` and ``resume``.
            parse_config_sys_argv (bool): Passed to :class:`Config` as
                ``update_from_argv``.
            checkpoint_to_cpu (bool): Stored as ``_checkpoint_to_cpu``.
            safe_checkpoint_every_epoch (int): Stored as
                ``_safe_checkpoint_every_epoch``.
            use_visdomlogger, use_explogger, use_telegrammessagelogger (bool):
                Whether to create the respective logger.
            *_kwargs (dict): Extra keyword arguments for the respective logger.
            *_c_freq (int): Frequency with which the respective logger is
                added to the combined logger :attr:`clog`; None or <= 0 keeps
                it out of :attr:`clog`.
            append_rnd_to_name (bool): If True, append ``"_"`` plus five random
                alphanumeric characters to the name (or use the random string
                alone if there is no name).
        """

        Experiment.__init__(self)

        # Command-line paths take precedence over the arguments.
        if parse_sys_argv:
            config_path, resume_path = get_vars_from_sys_argv()
            if config_path:
                config = config_path
            if resume_path:
                resume = resume_path

        self._config_raw = None
        if isinstance(config, str):
            self._config_raw = Config(file_=config,
                                      update_from_argv=parse_config_sys_argv)
        elif isinstance(config, (Config, dict)):
            self._config_raw = Config(config=config,
                                      update_from_argv=parse_config_sys_argv)
        else:
            self._config_raw = Config(update_from_argv=parse_config_sys_argv)

        # Config values win over the constructor arguments.
        self.n_epochs = n_epochs
        if 'n_epochs' in self._config_raw:
            self.n_epochs = self._config_raw["n_epochs"]
        if self.n_epochs is None:
            self.n_epochs = 0

        self._seed = seed
        if 'seed' in self._config_raw:
            self._seed = self._config_raw.seed
        if self._seed is None:
            random_data = os.urandom(4)
            seed = int.from_bytes(random_data, byteorder="big")
            # Record the drawn seed so the run can be reproduced from config.
            self._config_raw.seed = seed
            self._seed = seed

        self.exp_name = name
        if 'name' in self._config_raw:
            self.exp_name = self._config_raw["name"]
        if append_rnd_to_name:
            rnd_str = ''.join(
                random.choice(string.ascii_letters + string.digits)
                for _ in range(5))
            # Without any name the random string becomes the name
            # (None + str would raise a TypeError).
            if self.exp_name is None:
                self.exp_name = rnd_str
            else:
                self.exp_name += "_" + rnd_str

        if 'base_dir' in self._config_raw:
            base_dir = self._config_raw["base_dir"]

        self._checkpoint_to_cpu = checkpoint_to_cpu
        self._safe_checkpoint_every_epoch = safe_checkpoint_every_epoch

        self.results = dict()

        # Init loggers
        logger_list = []
        self.vlog = None
        if use_visdomlogger:
            if visdomlogger_kwargs is None:
                visdomlogger_kwargs = {}
            self.vlog = PytorchVisdomLogger(name=self.exp_name,
                                            **visdomlogger_kwargs)
            if visdomlogger_c_freq is not None and visdomlogger_c_freq > 0:
                logger_list.append((self.vlog, visdomlogger_c_freq))
        self.elog = None
        if use_explogger:
            if explogger_kwargs is None:
                explogger_kwargs = {}
            self.elog = PytorchExperimentLogger(base_dir=base_dir,
                                                experiment_name=self.exp_name,
                                                **explogger_kwargs)
            if explogger_c_freq is not None and explogger_c_freq > 0:
                logger_list.append((self.elog, explogger_c_freq))

            # Set results log dict to the right path
            self.results = ResultLogDict("results-log.json",
                                         base_dir=self.elog.result_dir)
        self.tlog = None
        if use_telegrammessagelogger:
            if telegrammessagelogger_kwargs is None:
                telegrammessagelogger_kwargs = {}
            self.tlog = TelegramMessageLogger(**telegrammessagelogger_kwargs,
                                              exp_name=self.exp_name)
            if telegrammessagelogger_c_freq is not None and telegrammessagelogger_c_freq > 0:
                logger_list.append((self.tlog, telegrammessagelogger_c_freq))

        self.clog = CombinedLogger(*logger_list)

        set_seed(self._seed)

        # Do the resume stuff
        self._resume_path = None
        self._resume_save_types = resume_save_types
        self._ignore_resume_config = ignore_resume_config
        self._resume_reset_epochs = resume_reset_epochs
        if resume is not None:
            if isinstance(resume, str):
                if resume == "last":
                    self._resume_path = os.path.join(
                        base_dir,
                        sorted(os.listdir(base_dir))[-1])
                else:
                    self._resume_path = resume
            elif isinstance(resume, PytorchExperiment):
                self._resume_path = resume.elog.base_dir

        if self._resume_path is not None and not self._ignore_resume_config:
            # Keys given on the command line (leading dashes stripped) keep
            # their current value instead of the resumed one.
            self._config_raw.update(Config(file_=os.path.join(
                self._resume_path, "config", "config.json")),
                                    ignore=list(
                                        map(lambda x: re.sub("^-+", "", x),
                                            sys.argv)))

        # NOTE: requires use_explogger=True, since self.elog provides save_dir.
        if globs is not None:
            zip_name = os.path.join(self.elog.save_dir, "sources.zip")
            SourcePacker.zip_sources(globs, zip_name)

        # Init objects in config
        self.config = Config.init_objects(self._config_raw)

        atexit.register(self.at_exit_func)
Esempio n. 23
0
class PytorchExperiment(Experiment):
    """
    A PytorchExperiment extends the basic
    functionality of the :class:`.Experiment` class with
    convenience features for PyTorch (and general logging) such as creating a folder structure,
    saving, plotting results and checkpointing your experiment.

    The basic life cycle of a PytorchExperiment is the same as
    :class:`.Experiment`::

        setup()
        prepare()

        for epoch in n_epochs:
            train()
            validate()

        end()

    where the distinction between the first two is that between them
    PytorchExperiment will automatically restore checkpoints and save the
    :attr:`_config_raw` in :meth:`._setup_internal`. Please see below for more
    information on this.
    To get your own experiment simply inherit from the PytorchExperiment and
    overwrite the :meth:`.setup`, :meth:`.prepare`, :meth:`.train`,
    :meth:`.validate` method (or you can use the `very` experimental decorator
    :func:`.experimentify` to convert your class into an experiment).
    Then you can run your own experiment by calling the :meth:`.run` method.

    Internally PytorchExperiment will provide a number of member variables which
    you can access.

        - n_epochs
            Number of epochs.
        - exp_name
            Name of your experiment.
        - config
            The (initialized) :class:`.Config` of your experiment. You can
            access the uninitialized one via :attr:`_config_raw`.
        - result
            A dict in which you can store your result values. If a
            :class:`.PytorchExperimentLogger` is used, results will be a
            :class:`.ResultLogDict` that automatically writes to a file
            and also stores the N last entries for each key for quick access
            (e.g. to quickly get the running mean).
        - vlog (if use_visdomlogger is True)
            A :class:`.PytorchVisdomLogger` instance which can log your results
            to a running visdom server. Start the server via
            :code:`python -m visdom.server` or pass :data:`auto_start=True` in
            the :attr:`visdomlogger_kwargs`.
        - elog (if use_explogger is True)
            A :class:`.PytorchExperimentLogger` instance which can log your
            results to a given folder.
        - tlog (if use_telegrammessagelogger is True)
            A :class:`.TelegramMessageLogger` instance which can send the results to
            your telegram account
        - clog
            A :class:`.CombinedLogger` instance which logs to all loggers with
            different frequencies (specified with the :attr:`_c_freq` for each
            logger where 1 means every time and N means every Nth time,
            e.g. if you only want to send stuff to Visdom every 10th time).

    The most important attribute is certainly :attr:`.config`, which is the
    initialized :class:`.Config` for the experiment. To understand how it needs
    to be structured to allow for automatic instantiation of types, please refer
    to its documentation. If you decide not to use this functionality,
    :attr:`config` and :attr:`_config_raw` are identical. **Beware however that
    by default the PytorchExperiment only saves the raw config** after
    :meth:`.setup`. If you modify :attr:`config` during setup, make sure
    to implement :meth:`._setup_internal` yourself should you want the modified
    config to be saved::

        def _setup_internal(self):

            super(YourExperiment, self)._setup_internal() # calls .prepare_resume()
            self.elog.save_config(self.config, "config")

    Args:
        config (dict or Config): Configures your experiment. If :attr:`name`,
            :attr:`n_epochs`, :attr:`seed`, :attr:`base_dir` are given in the
            config, it will automatically
            overwrite the other args/kwargs with the values from the config.
            In addition (defined by :attr:`parse_config_sys_argv`) the config
            automatically parses the argv arguments and updates its values if a
            key matches a console argument.
        name (str):
            The name of the PytorchExperiment.
        n_epochs (int): The number of epochs (number of times the training
            cycle will be executed).
        seed (int): A random seed (which will set the random, numpy and
            torch seed). If no seed is given (neither here nor in the config),
            a random one is drawn and written into the config.
        base_dir (str): A base directory in which the experiment result folder
            will be created.
        globs: The :func:`globals` of the script which is run. This is necessary
            to get and save the executed files in the experiment folder
            (only possible when an experiment logger is used).
        resume (str or PytorchExperiment): Another PytorchExperiment or path to
            the result dir from another PytorchExperiment from which it will
            load the PyTorch modules and other member variables and resume
            the experiment. The path may also point to a ``checkpoint`` folder
            or a ``.pth.tar`` checkpoint file. The special value ``"last"``
            resumes from the lexicographically last entry in :attr:`base_dir`.
        ignore_resume_config (bool): If :obj:`True` it will not resume with the
            config from the resume experiment but take the current/own config.
        resume_save_types (list or tuple): A list which can define which values
            to restore when resuming. Choices are:

                - "model" <-- Pytorch models
                - "optimizer" <-- Optimizers
                - "simple" <-- Simple python variables (basic types and lists/tuples)
                - "th_vars" <-- torch tensors/variables
                - "results" <-- The result dict

        resume_reset_epochs (bool): If :obj:`True`, the epoch counter is reset
            to 0 after a checkpoint has been restored.
        parse_sys_argv (bool): Parsing the console arguments (argv) to get a
            :attr:`config path` and/or :attr:`resume_path`.
        parse_config_sys_argv (bool): Parse argv to update the config
            (if the keys match).
        checkpoint_to_cpu (bool): When checkpointing, transfer all tensors to
            the CPU beforehand.
        safe_checkpoint_every_epoch (int): Determines after how many epochs a
            checkpoint is stored. ``0`` or :obj:`None` disables the periodic
            checkpoints.
        use_visdomlogger (bool): Use a :class:`.PytorchVisdomLogger`. Is
            accessible via the :attr:`vlog` attribute.
        visdomlogger_kwargs (dict): Keyword arguments for :attr:`vlog`
            instantiation.
        visdomlogger_c_freq (int): The frequency x (meaning one in x) with which
            the :attr:`clog` will call the :attr:`vlog`.
        use_explogger (bool): Use a :class:`.PytorchExperimentLogger`. Is
            accessible via the :attr:`elog` attribute. This will create the
            experiment folder structure.
        explogger_kwargs (dict): Keyword arguments for :attr:`elog`
            instantiation.
        explogger_c_freq (int): The frequency x (meaning one in x) with which
            the :attr:`clog` will call the :attr:`elog`.
        use_telegrammessagelogger (bool): Use a :class:`.TelegramMessageLogger`. Is
            accessible via the :attr:`tlog` attribute.
        telegrammessagelogger_kwargs (dict): Keyword arguments for :attr:`tlog`
            instantiation.
        telegrammessagelogger_c_freq (int): The frequency x (meaning one in x) with which
            the :attr:`clog` will call the :attr:`tlog`.
        append_rnd_to_name (bool): If :obj:`True`, will append an underscore and
            a random five-character alphanumeric string to the experiment name.

     """
    def __init__(self,
                 config=None,
                 name=None,
                 n_epochs=None,
                 seed=None,
                 base_dir=None,
                 globs=None,
                 resume=None,
                 ignore_resume_config=False,
                 resume_save_types=("model", "optimizer", "simple", "th_vars",
                                    "results"),
                 resume_reset_epochs=True,
                 parse_sys_argv=False,
                 parse_config_sys_argv=True,
                 checkpoint_to_cpu=True,
                 safe_checkpoint_every_epoch=1,
                 use_visdomlogger=True,
                 visdomlogger_kwargs=None,
                 visdomlogger_c_freq=1,
                 use_explogger=True,
                 explogger_kwargs=None,
                 explogger_c_freq=100,
                 use_telegrammessagelogger=False,
                 telegrammessagelogger_kwargs=None,
                 telegrammessagelogger_c_freq=1000,
                 append_rnd_to_name=False):

        Experiment.__init__(self)

        # Optionally take the config path / resume path from the command line.
        if parse_sys_argv:
            config_path, resume_path = get_vars_from_sys_argv()
            if config_path:
                config = config_path
            if resume_path:
                resume = resume_path

        # Build the raw (not yet object-initialized) config from a file path,
        # an existing Config/dict, or from scratch.
        if isinstance(config, str):
            self._config_raw = Config(file_=config,
                                      update_from_argv=parse_config_sys_argv)
        elif isinstance(config, (Config, dict)):
            self._config_raw = Config(config=config,
                                      update_from_argv=parse_config_sys_argv)
        else:
            self._config_raw = Config(update_from_argv=parse_config_sys_argv)

        # Values present in the config take precedence over the arguments.
        self.n_epochs = n_epochs
        if 'n_epochs' in self._config_raw:
            self.n_epochs = self._config_raw["n_epochs"]
        if self.n_epochs is None:
            self.n_epochs = 0

        self._seed = seed
        if 'seed' in self._config_raw:
            self._seed = self._config_raw.seed
        if self._seed is None:
            # No seed anywhere: draw one and store it in the config so the run
            # can be reproduced from the saved config later.
            seed = int.from_bytes(os.urandom(4), byteorder="big")
            self._config_raw.seed = seed
            self._seed = seed

        self.exp_name = name
        if 'name' in self._config_raw:
            self.exp_name = self._config_raw["name"]
        if append_rnd_to_name:
            alphabet = string.ascii_letters + string.digits
            rnd_str = ''.join(random.choice(alphabet) for _ in range(5))
            self.exp_name += "_" + rnd_str

        if 'base_dir' in self._config_raw:
            base_dir = self._config_raw["base_dir"]

        self._checkpoint_to_cpu = checkpoint_to_cpu
        self._safe_checkpoint_every_epoch = safe_checkpoint_every_epoch

        self.results = {}

        # Init loggers; each enabled logger with a positive frequency is also
        # registered with the combined logger.
        logger_list = []
        self.vlog = None
        if use_visdomlogger:
            if visdomlogger_kwargs is None:
                visdomlogger_kwargs = {}
            self.vlog = PytorchVisdomLogger(name=self.exp_name,
                                            **visdomlogger_kwargs)
            if visdomlogger_c_freq is not None and visdomlogger_c_freq > 0:
                logger_list.append((self.vlog, visdomlogger_c_freq))
        self.elog = None
        if use_explogger:
            if explogger_kwargs is None:
                explogger_kwargs = {}
            self.elog = PytorchExperimentLogger(base_dir=base_dir,
                                                experiment_name=self.exp_name,
                                                **explogger_kwargs)
            if explogger_c_freq is not None and explogger_c_freq > 0:
                logger_list.append((self.elog, explogger_c_freq))

            # Set results log dict to the right path
            self.results = ResultLogDict("results-log.json",
                                         base_dir=self.elog.result_dir)
        self.tlog = None
        if use_telegrammessagelogger:
            if telegrammessagelogger_kwargs is None:
                telegrammessagelogger_kwargs = {}
            self.tlog = TelegramMessageLogger(**telegrammessagelogger_kwargs,
                                              exp_name=self.exp_name)
            if telegrammessagelogger_c_freq is not None and telegrammessagelogger_c_freq > 0:
                logger_list.append((self.tlog, telegrammessagelogger_c_freq))

        self.clog = CombinedLogger(*logger_list)

        set_seed(self._seed)

        # Do the resume stuff
        self._resume_path = None
        self._resume_save_types = resume_save_types
        self._ignore_resume_config = ignore_resume_config
        self._resume_reset_epochs = resume_reset_epochs
        if resume is not None:
            if isinstance(resume, str):
                if resume == "last":
                    self._resume_path = os.path.join(
                        base_dir,
                        sorted(os.listdir(base_dir))[-1])
                else:
                    self._resume_path = resume
            elif isinstance(resume, PytorchExperiment):
                self._resume_path = resume.elog.base_dir

        if self._resume_path is not None and not self._ignore_resume_config:
            # The resume path may be a checkpoint file/folder, so resolve the
            # experiment folder that actually holds config/config.json.
            resume_config_file = os.path.join(
                self._resume_config_dir(self._resume_path), "config",
                "config.json")
            # Command line arguments win over the resumed config, so their keys
            # (leading dashes stripped) are excluded from the update.
            argv_keys = [re.sub("^-+", "", arg) for arg in sys.argv]
            self._config_raw.update(Config(file_=resume_config_file),
                                    ignore=argv_keys)

        if globs is not None:
            if self.elog is not None:
                zip_name = os.path.join(self.elog.save_dir, "sources.zip")
                SourcePacker.zip_sources(globs, zip_name)
            else:
                # Without an experiment logger there is no folder to save into.
                warnings.warn(
                    "globs were given but no experiment logger is used "
                    "(use_explogger=False), so the sources are not saved.",
                    UserWarning)

        # Init objects in config
        self.config = Config.init_objects(self._config_raw)

        atexit.register(self.at_exit_func)

    @staticmethod
    def _resume_config_dir(resume_path):
        """
        Returns the experiment folder that contains ``config/config.json`` for a resume path.

        Mirrors the layout handling of :meth:`prepare_resume`: a ``.pth.tar`` file is expected in
        ``<exp>/checkpoint/`` and a ``checkpoint`` folder directly in ``<exp>/``. Any other path
        is returned unchanged (assumed to already be the experiment folder).

        Args:
            resume_path (str): The path given to resume from.

        Returns:
            str: The folder in which the ``config`` sub folder is looked up.
        """
        if resume_path.endswith(".pth.tar"):
            return os.path.dirname(os.path.dirname(resume_path))
        if resume_path.endswith(("checkpoint", "checkpoint/")):
            return os.path.dirname(resume_path.rstrip("/"))
        return resume_path

    def process_err(self, e):
        """Writes the traceback of exception *e* to the experiment logger's 'err' log (if any)."""
        if self.elog is not None:
            self.elog.text_logger.log_to(
                "\n".join(traceback.format_tb(e.__traceback__)), "err")

    def update_attributes(self, var_dict, ignore=()):
        """
        Updates the member attributes with the attributes given in the var_dict

        Args:
            var_dict (dict): dict in which the update values stored. If a key matches a member attribute name
                the member attribute will be updated. The key "results" is loaded into the
                result dict via its ``load`` method instead of being assigned.
            ignore (list or tuple): iterable of keys to ignore (including "results")

        """
        for key, val in var_dict.items():
            if key in ignore:
                continue
            if key == "results":
                self.results.load(val)
            elif hasattr(self, key):
                setattr(self, key, val)

    def get_pytorch_modules(self, from_config=True):
        """
        Returns all torch.nn.Modules stored in the experiment in a dict.

        Args:
            from_config (bool): Also get modules that are stored in the :attr:`.config` attribute.
                Their string keys are prefixed with "config.".

        Returns:
            dict: Dictionary of PyTorch modules

        """

        pyth_modules = {}
        for key, val in self.__dict__.items():
            if isinstance(val, torch.nn.Module):
                pyth_modules[key] = val
        if from_config:
            for key, val in self.config.items():
                if isinstance(val, torch.nn.Module):
                    if isinstance(key, str):
                        key = "config." + key
                    pyth_modules[key] = val
        return pyth_modules

    def get_pytorch_optimizers(self, from_config=True):
        """
        Returns all torch.optim.Optimizers stored in the experiment in a dict.

        Args:
            from_config (bool): Also get optimizers that are stored in the :attr:`.config`
                attribute. Their string keys are prefixed with "config.".

        Returns:
            dict: Dictionary of PyTorch optimizers

        """

        pyth_optimizers = {}
        for key, val in self.__dict__.items():
            if isinstance(val, torch.optim.Optimizer):
                pyth_optimizers[key] = val
        if from_config:
            for key, val in self.config.items():
                if isinstance(val, torch.optim.Optimizer):
                    if isinstance(key, str):
                        key = "config." + key
                    pyth_optimizers[key] = val
        return pyth_optimizers

    def get_simple_variables(self, ignore=()):
        """
        Returns all standard variables in the experiment in a dict.
        Specifically, this looks for types :class:`int`, :class:`float`, :class:`bytes`,
        :class:`bool`, :class:`str`, :class:`set`, :class:`list`, :class:`tuple`.

        Args:
            ignore (list or tuple): Iterable of names which will be ignored

        Returns:
            dict: Dictionary of variables

        """

        simple_types = (int, float, bytes, bool, str, set, list, tuple)
        return {
            key: val
            for key, val in self.__dict__.items()
            if key not in ignore and isinstance(val, simple_types)
        }

    def get_pytorch_tensors(self, ignore=()):
        """
        Returns all torch.tensors in the experiment in a dict.

        Args:
            ignore (list or tuple): Iterable of names which will be ignored

        Returns:
            dict: Dictionary of PyTorch tensor

        """

        return {
            key: val
            for key, val in self.__dict__.items()
            if key not in ignore and torch.is_tensor(val)
        }

    def get_pytorch_variables(self, ignore=()):
        """Same as :meth:`.get_pytorch_tensors`."""
        return self.get_pytorch_tensors(ignore)

    def save_results(self, name="results.json"):
        """
        Saves the result dict as a json file in the result dir of the experiment logger.

        Does nothing if no experiment logger is used.

        Args:
            name (str): The name of the json file in which the results are written.

        """
        if self.elog is None:
            return
        with open(os.path.join(self.elog.result_dir, name), "w") as file_:
            json.dump(self.results, file_, indent=4)

    def save_pytorch_models(self):
        """Saves all torch.nn.Modules as model files in the experiment checkpoint folder."""

        if self.elog is None:
            return

        for key, val in self.get_pytorch_modules().items():
            self.elog.save_model(val, key)

    def load_pytorch_models(self):
        """Loads all model files from the experiment checkpoint folder."""

        if self.elog is None:
            return
        for key, val in self.get_pytorch_modules().items():
            self.elog.load_model(val, key)

    def log_simple_vars(self):
        """
        Logs all simple python member variables as a json file in the experiment log folder.
        The file will be named 'simple_vars.json'.
        """

        if self.elog is None:
            return
        simple_vars = self.get_simple_variables()
        with open(os.path.join(self.elog.log_dir, "simple_vars.json"),
                  "w") as file_:
            json.dump(simple_vars, file_)

    def load_simple_vars(self):
        """
        Restores all simple python member variables from the 'simple_vars.json' file in the log
        folder.
        """

        if self.elog is None:
            return
        with open(os.path.join(self.elog.log_dir, "simple_vars.json"),
                  "r") as file_:
            simple_vars = json.load(file_)
        self.update_attributes(simple_vars)

    def _collect_checkpoint_dict(self, save_types):
        """
        Gathers the member variables selected by *save_types* into a single dict.

        Later groups override earlier ones on key clashes, in the order
        model, optimizer, simple, th_vars, results.
        """
        checkpoint_dict = {}
        if "model" in save_types:
            checkpoint_dict.update(self.get_pytorch_modules())
        if "optimizer" in save_types:
            checkpoint_dict.update(self.get_pytorch_optimizers())
        if "simple" in save_types:
            checkpoint_dict.update(self.get_simple_variables())
        if "th_vars" in save_types:
            checkpoint_dict.update(self.get_pytorch_variables())
        if "results" in save_types:
            checkpoint_dict["results"] = self.results
        return checkpoint_dict

    def save_checkpoint(self,
                        name="checkpoint",
                        save_types=("model", "optimizer", "simple", "th_vars",
                                    "results"),
                        n_iter=None,
                        iter_format="{:05d}",
                        prefix=False):
        """
        Saves a current model checkpoint from the experiment.

        Does nothing if no experiment logger is used.

        Args:
            name (str): The name of the checkpoint file
            save_types (list or tuple): What kind of member variables should be stored? Choices are:
                "model" <-- Pytorch models,
                "optimizer" <-- Optimizers,
                "simple" <-- Simple python variables (basic types and lists/tuples),
                "th_vars" <-- torch tensors,
                "results" <-- The result dict
            n_iter (int): Number of iterations. Together with the name, defined by the iter_format,
                a file name will be created.
            iter_format (str): Defines how the name and the n_iter will be combined.
            prefix (bool): If True, the formatted n_iter will be prepended, otherwise appended.

        """

        if self.elog is None:
            return

        checkpoint_dict = self._collect_checkpoint_dict(save_types)

        self.elog.save_checkpoint(name=name,
                                  n_iter=n_iter,
                                  iter_format=iter_format,
                                  prefix=prefix,
                                  move_to_cpu=self._checkpoint_to_cpu,
                                  **checkpoint_dict)

    def load_checkpoint(self,
                        name="checkpoint",
                        save_types=("model", "optimizer", "simple", "th_vars",
                                    "results"),
                        n_iter=None,
                        iter_format="{:05d}",
                        prefix=False,
                        path=None):
        """
        Loads a checkpoint and restores the experiment.

        Make sure you have your torch stuff already on the right devices beforehand,
        otherwise this could lead to errors e.g. when making a optimizer step
        (and for some reason the Adam states are not already on the GPU:
        https://discuss.pytorch.org/t/loading-a-saved-model-for-continue-training/17244/3 )

        Does nothing if no experiment logger is used.

        Args:
            name (str): The name of the checkpoint file
            save_types (list or tuple): What kind of member variables should be loaded? Choices are:
                "model" <-- Pytorch models,
                "optimizer" <-- Optimizers,
                "simple" <-- Simple python variables (basic types and lists/tuples),
                "th_vars" <-- torch tensors,
                "results" <-- The result dict
            n_iter (int): Number of iterations. Together with the name, defined by the iter_format,
                a file name will be created and searched for.
            iter_format (str): Defines how the name and the n_iter will be combined.
            prefix (bool): If True, the formatted n_iter will be prepended, otherwise appended.
            path (str): If no path is given then it will take the current experiment dir and formatted
                name, otherwise it will simply use the path and the formatted name to define the
                checkpoint file. With an empty name, path itself is the checkpoint file.

        """
        if self.elog is None:
            return

        checkpoint_dict = self._collect_checkpoint_dict(save_types)

        if n_iter is not None:
            name = name_and_iter_to_filename(name,
                                             n_iter,
                                             ".pth.tar",
                                             iter_format=iter_format,
                                             prefix=prefix)

        if path is None:
            restore_dict = self.elog.load_checkpoint(name=name,
                                                     **checkpoint_dict)
        else:
            # An empty name means `path` already is the checkpoint file; joining
            # it with "" would append a platform-specific separator.
            checkpoint_path = os.path.join(path, name) if name else path
            if checkpoint_path.endswith("/"):
                checkpoint_path = checkpoint_path[:-1]
            restore_dict = self.elog.load_checkpoint_static(
                checkpoint_file=checkpoint_path, **checkpoint_dict)

        self.update_attributes(restore_dict)

    def _end_internal(self):
        """Ends the experiment and stores the final results/checkpoint"""
        if isinstance(self.results, ResultLogDict):
            self.results.close()
        self.save_results()
        self.save_end_checkpoint()
        self._save_exp_config()
        self.print("Experiment ended. Checkpoints stored =)")

    def _end_test_internal(self):
        """Ends the experiment after test and stores the final results and config"""
        self.save_results()
        self._save_exp_config()
        self.print("Testing ended. Results stored =)")

    def at_exit_func(self):
        """
        Stores the results and checkpoint at the end (if not already stored).
        This method is also called if an error occurs.
        """

        if self._exp_state not in ("Ended", "Tested"):
            if isinstance(self.results, ResultLogDict):
                # Close the JSON list the result log file is written as.
                self.results.print_to_file("]")
            self.save_checkpoint(name="checkpoint_exit-" + self._exp_state)
            self.save_results()
            self._save_exp_config()
            self.print("Experiment exited. Checkpoints stored =)")
        # allow checkpoint saving to finish. We need a better solution for this :D
        time.sleep(2)

    def _setup_internal(self):
        """Resumes (if requested) and saves the raw config and the experiment state."""
        self.prepare_resume()

        if self.elog is not None:
            self.elog.save_config(self._config_raw, "config")
        self._save_exp_config()

    def _start_internal(self):
        """Saves the experiment state when the experiment starts."""
        self._save_exp_config()

    def prepare_resume(self):
        """Tries to resume the experiment by using the defined resume path or PytorchExperiment."""

        checkpoint_file = ""
        base_dir = ""

        # load_checkpoint() restores "simple" member variables, which include
        # _resume_reset_epochs; remember the value requested for *this* run so
        # the one stored in the checkpoint cannot override it.
        reset_epochs = self._resume_reset_epochs

        resume_path = self._resume_path
        if isinstance(resume_path, str):
            if resume_path.endswith(".pth.tar"):
                # A checkpoint file inside <exp>/checkpoint/
                checkpoint_file = resume_path
                base_dir = os.path.dirname(os.path.dirname(checkpoint_file))
            elif resume_path.endswith(("checkpoint", "checkpoint/")):
                # The checkpoint folder of an experiment
                checkpoint_file = get_last_file(resume_path)
                base_dir = os.path.dirname(os.path.dirname(checkpoint_file))
            else:
                entries = os.listdir(resume_path)
                if "checkpoint" in entries and "config" in entries:
                    # An experiment folder
                    checkpoint_file = get_last_file(resume_path)
                    base_dir = resume_path
                else:
                    warnings.warn(
                        "You have not selected a valid experiment folder, will search all sub folders",
                        UserWarning)
                    if self.elog is not None:
                        self.elog.text_logger.log_to(
                            "You have not selected a valid experiment folder, will search all "
                            "sub folders", "warnings")
                    checkpoint_file = get_last_file(resume_path)
                    base_dir = os.path.dirname(
                        os.path.dirname(checkpoint_file))

        if base_dir and not self._ignore_resume_config:
            load_config = Config()
            load_config.load(os.path.join(base_dir, "config/config.json"))
            self._config_raw = load_config
            self.config = Config.init_objects(self._config_raw)
            self.print("Loaded existing config from:", base_dir)

        if checkpoint_file:
            self.load_checkpoint(name="",
                                 path=checkpoint_file,
                                 save_types=self._resume_save_types)
            self._resume_path = checkpoint_file
            if self.elog is not None:
                # Keep a copy of the checkpoint resumed from in this experiment's folder.
                shutil.copyfile(
                    checkpoint_file,
                    os.path.join(self.elog.checkpoint_dir,
                                 "0_checkpoint.pth.tar"))
            self.print("Loaded existing checkpoint from:", checkpoint_file)

            self._resume_reset_epochs = reset_epochs
            if self._resume_reset_epochs:
                self._epoch_idx = 0

    def _end_epoch_internal(self, epoch):
        """Saves results, a periodic checkpoint (if due) and the experiment state after an epoch."""
        self.save_results()
        interval = self._safe_checkpoint_every_epoch
        # 0 / None disable periodic checkpoints (and would otherwise crash the modulo).
        if interval and epoch % interval == 0:
            self.save_temp_checkpoint()
        self._save_exp_config()

    def _save_exp_config(self):
        """Saves name, start time, state, current time and epoch as the "exp" config."""

        if self.elog is not None:
            cur_time = time.strftime("%y-%m-%d_%H:%M:%S",
                                     time.localtime(time.time()))
            self.elog.save_config(
                Config(
                    **{
                        'name': self.exp_name,
                        'time': self._time_start,
                        'state': self._exp_state,
                        'current_time': cur_time,
                        'epoch': self._epoch_idx
                    }), "exp")

    def save_temp_checkpoint(self):
        """Saves the current checkpoint as checkpoint_current."""
        self.save_checkpoint(name="checkpoint_current")

    def save_end_checkpoint(self):
        """Saves the current checkpoint as checkpoint_last."""
        self.save_checkpoint(name="checkpoint_last")

    def add_result(self,
                   value,
                   name,
                   counter=None,
                   tag=None,
                   label=None,
                   plot_result=True,
                   plot_running_mean=False):
        """
        Saves a results and add it to the result dict, this is similar to results[key] = val,
        but in addition also logs the value to the combined logger
        (it also stores in the results-logs file).

        **This should be your preferred method to log your numeric values**

        Args:
            value: The value of your variable
            name (str): The name/key of your variable
            counter (int or float): A counter which can be seen as the x-axis of your value.
                Normally you would just use the current epoch for this.
            tag (str): A label/tag which can group similar values and will plot values with the same
                label in the same plot
            label: deprecated label
            plot_result (bool): By default True, will also log all your values to the combined
                logger (with show_value).
            plot_running_mean (bool): If True, the mean of ``results.running_mean_dict[name]`` is
                plotted instead of the raw value (requires a :class:`.ResultLogDict`).

        """

        if label is not None:
            warnings.warn(
                "label in add_result is deprecated, please use tag instead")

            if tag is None:
                tag = label

        tag_name = name if tag is None else tag

        r_elem = ResultElement(data=value,
                               label=tag_name,
                               epoch=self._epoch_idx,
                               counter=counter)

        self.results[name] = r_elem

        if plot_result:
            # Only show a legend when values are explicitly grouped by a tag.
            legend = tag is not None
            if plot_running_mean:
                value = np.mean(self.results.running_mean_dict[name])
            self.clog.show_value(value=value,
                                 name=name,
                                 tag=tag_name,
                                 counter=counter,
                                 show_legend=legend)

    def get_result(self, name):
        """
        Similar to result[key] this will return the values in the results dictionary with the given
        name/key.

        Args:
            name (str): the name/key for which a value is stored.

        Returns:
            The value with the key 'name' in the results dict.

        """
        return self.results.get(name)

    def add_result_without_epoch(self, val, name):
        """
        A faster method to store your results, has less overhead and does not call the combined
        logger. Will only store to the results dictionary.

        Args:
            val: the value you want to add.
            name (str): the name/key of your value.

        """
        self.results[name] = val

    def get_result_without_epoch(self, name):
        """
        Similar to result[key] this will return the values in result with the given name/key.

        Args:
            name (str): the name/ key for which a value is stores.

        Returns:
            The value with the key 'name' in the results dict.

        """
        return self.results.get(name)

    def print(self, *args):
        """
        Calls 'print' on the experiment logger or uses builtin 'print' if former is not
        available.
        """

        if self.elog is None:
            print(*args)
        else:
            self.elog.print(*args)
Esempio n. 24
0
def test_MNIST_experiment():
    """Train ``MNISTExperiment`` and show a grid of 16 samples drawn on the CPU.

    Returns:
        The trained model, moved to the CPU and switched to eval mode.
    """
    cfg = Config()

    # Training hyper-parameters.
    cfg.batch_size = 64
    cfg.n_epochs = 50
    cfg.learning_rate = 0.001
    cfg.weight_decay = 5e-5
    cfg.use_cuda = torch.cuda.is_available()
    cfg.rnd_seed = 1
    cfg.log_interval = 100
    cfg.subset_size = 10

    # Model-specific settings.
    cfg.n_coupling = 8
    cfg.n_filters = 64

    # NOTE(review): seed=42 is passed although cfg.rnd_seed is 1 — confirm which is intended.
    experiment = MNISTExperiment(cfg,
                                 name='mnist_test',
                                 n_epochs=cfg.n_epochs,
                                 seed=42,
                                 base_dir='experiment_dir',
                                 loggers={'visdom': ['visdom', {"exp_name": "myenv"}]})
    experiment.run()

    model = experiment.model
    model.eval()
    model.to('cpu')
    with torch.no_grad():
        samples = model.sample(16, device='cpu')
        # make_grid yields (C, H, W); imshow wants (H, W, C).
        grid = make_grid(samples).permute((1, 2, 0))
    plt.imshow(grid)
    plt.show()
    return model
Esempio n. 25
0
def run_experiment(experiment, configs, args, mods=None, **kwargs):
    """Assemble a config and run (or test) ``experiment`` once per grid combination.

    The config is read from ``args.config`` (if given), completed with
    ``configs[args.default_config]``, updated with every modification named in
    ``args.mods``, and finally wrapped with ``update_from_argv=True``.

    Args:
        experiment: Callable (typically an experiment class) invoked with ``config``,
            ``base_dir``, ``resume``, ``ignore_resume_config``, ``loggers`` and ``**kwargs``.
        configs (dict): Default configs; ``args.default_config`` selects one.
        args: Parsed arguments. Reads ``config``, ``default_config``, ``mods``,
            ``skip_existing``, ``base_dir``, ``grid``, ``visdomlogger``, ``resume``,
            ``ignore_resume_config`` and ``test``.
        mods (dict): Named config modifications; must be provided whenever ``args.mods``
            is set (it is indexed by each name in ``args.mods``).
        **kwargs: Forwarded unchanged to ``experiment``.
    """
    config = Config(file_=args.config) if args.config is not None else Config()
    config.update_missing(configs[args.default_config])
    if args.mods is not None:
        for mod in args.mods:
            config.update(mods[mod])
    config = Config(config=config, update_from_argv=True)

    # GET EXISTING EXPERIMENTS TO BE ABLE TO SKIP CERTAIN CONFIGS
    existing_configs = []
    # A base_dir that does not exist yet just means nothing has been run; os.listdir
    # would otherwise raise FileNotFoundError before any experiment starts.
    if args.skip_existing and os.path.isdir(args.base_dir):
        for exp_dir in os.listdir(args.base_dir):
            try:
                existing_configs.append(
                    Config(file_=os.path.join(args.base_dir, exp_dir, "config",
                                              "config.json")))
            except Exception:
                # Best effort: entries without a loadable config are not experiments.
                continue

    grid = GridSearch().read(args.grid) if args.grid is not None else [{}]

    for combi in grid:
        # NOTE: the same config object is updated in place for every combination.
        config.update(combi)

        if args.skip_existing and any(
                existing.contains(config) for existing in existing_configs):
            continue

        loggers = {}
        if args.visdomlogger:
            loggers["visdom"] = ("visdom", {}, 1)

        exp = experiment(config=config,
                         base_dir=args.base_dir,
                         resume=args.resume,
                         ignore_resume_config=args.ignore_resume_config,
                         loggers=loggers,
                         **kwargs)

        if args.test:
            exp.run_test()
        else:
            exp.run()
Esempio n. 26
0
import numpy as np
import torch
from trixi.util import Config
from experiment import MNISTexperiment
from util import plot_dependency_map
import matplotlib.pyplot as plt

c = Config()

# Training hyper-parameters.
c.batch_size = 128
c.n_epochs = 10
c.learning_rate = 0.001
c.use_cuda = torch.cuda.is_available()
c.rnd_seed = 1
c.log_interval = 100

# Build the experiment; visdom logging uses exp_name "myenv".
exp = MNISTexperiment(
    config=c,
    name='test',
    n_epochs=c.n_epochs,
    seed=c.rnd_seed,
    base_dir='./experiment_dir',
    loggers={"visdom": ["visdom", {"exp_name": "myenv"}]},
)

# # run backpropagation for each dimension to compute what other
# # dimensions it depends on.
# exp.setup()
# d = 28
Esempio n. 27
0
def test_Resnet():
    """Train ``MNIST_classification`` for ``n_epochs`` epochs with visdom logging."""
    cfg = Config()

    # Hyper-parameters and runtime settings.
    cfg.batch_size = 64
    cfg.batch_size_test = 1000
    cfg.n_epochs = 10
    cfg.learning_rate = 0.01
    cfg.momentum = 0.9
    cfg.use_cuda = torch.cuda.is_available()
    cfg.rnd_seed = 1
    cfg.log_interval = 200

    # NOTE(review): seed=42 is passed although cfg.rnd_seed is 1 — confirm which is intended.
    experiment = MNIST_classification(
        config=cfg,
        name='experiment',
        n_epochs=cfg.n_epochs,
        seed=42,
        base_dir='./experiment_dir',
        loggers={"visdom": "visdom"},
    )
    experiment.run()
Esempio n. 28
0
def _info_or_config(exp, key, default_val):
    """Return ``key`` from ``exp.exp_info`` if present there, else from ``exp.config``."""
    if key in exp.exp_info:
        return exp.exp_info.get(key, default_val)
    return exp.config.get(key, default_val)


def _table_cells(lookup, keys, default_val, short_len):
    """Build one ``(full_string, shortened_string)`` cell per key looked up in ``lookup``."""
    cells = []
    for key in keys:
        attr_strng = str(lookup.get(key, default_val))
        cells.append((attr_strng, attr_strng[:short_len]))
    return cells


def process_base_dir(base_dir,
                     view_dir="",
                     default_val="-",
                     short_len=25,
                     ignore_keys=IGNORE_KEYS):
    """Create an overview table of all experiments in ``base_dir/view_dir``.

    Each sub directory that loads as an ``ExperimentReader`` (and is not flagged
    ``ignore``) becomes one row; sub directories that fail to load are reported on
    stdout and listed under ``"noexp"``.

    Args:
        base_dir (str): Root directory; row and ``noexp`` paths are relative to it.
        view_dir (str): Sub directory of ``base_dir`` whose folders are listed.
        default_val (str): Default value if an entry is missing.
        short_len (int): Cut strings to this length. Full string in alt-text.
        ignore_keys (iterable): Config/result keys to leave out of the table.

    Returns:
        dict: {"ccols1": Config columns whose values are not common to all experiments,
               "ccols2": The remaining config columns,
               "rcols": Result columns,
               "rows": One tuple per experiment: (relative path, star, name, time,
                       state, epoch, config cells, result cells); every cell is a
                       (full string, shortened string) pair,
               "noexp": Folders (joined with view_dir) that could not be loaded}

    """

    full_dir = os.path.join(base_dir, view_dir)

    config_keys = set()
    result_keys = set()
    exps = []
    non_exps = []

    ### Load Experiments with keys / different param values
    for sub_dir in sorted(os.listdir(full_dir)):
        dir_path = os.path.join(full_dir, sub_dir)
        if not os.path.isdir(dir_path):
            continue
        try:
            exp = ExperimentReader(full_dir, sub_dir)
            if exp.ignore:
                continue
            config_keys.update(exp.config.flat().keys())
            result_keys.update(exp.get_results().keys())
            exps.append(exp)
        except Exception as e:
            # Best effort: report folders that are not loadable experiments, keep going.
            print("Could not load experiment: ", dir_path)
            print(e)
            print("-" * 20)
            non_exps.append(os.path.join(view_dir, sub_dir))

    ### Get not common val keys (a set, since it is only used for membership tests).
    # With no experiments there are no config keys to classify, so skip the call.
    if exps:
        diff_keys = set(
            Config.difference_config_static(*[xp.config for xp in exps]).flat())
    else:
        diff_keys = set()

    ### Remove unwanted keys
    config_keys -= set(ignore_keys)
    result_keys -= set(ignore_keys)

    ### Generate table rows
    sorted_c_keys1 = sorted((c for c in config_keys if c in diff_keys),
                            key=lambda x: str(x).lower())
    sorted_c_keys2 = sorted((c for c in config_keys if c not in diff_keys),
                            key=lambda x: str(x).lower())
    sorted_r_keys = sorted(result_keys, key=lambda x: str(x).lower())

    rows = []
    for exp in exps:
        # Fetch each mapping once per experiment instead of once per table cell.
        flat_config = exp.config.flat()
        results = exp.get_results()

        config_row = (
            _table_cells(flat_config, sorted_c_keys1, default_val, short_len) +
            _table_cells(flat_config, sorted_c_keys2, default_val, short_len))
        result_row = _table_cells(results, sorted_r_keys, default_val, short_len)

        rows.append((os.path.relpath(exp.work_dir, base_dir),
                     exp.star,
                     str(exp.exp_name),
                     str(_info_or_config(exp, "time", default_val)),
                     str(_info_or_config(exp, "state", default_val)),
                     str(_info_or_config(exp, "epoch", default_val)),
                     config_row,
                     result_row))

    return {
        "ccols1": sorted_c_keys1,
        "ccols2": sorted_c_keys2,
        "rcols": sorted_r_keys,
        "rows": rows,
        "noexp": non_exps
    }
Esempio n. 29
0
def make_defaults():
    """Build the default GQN configuration and its named modifications.

    Returns:
        tuple: ``({"DEFAULTS": defaults}, mods)`` where ``defaults`` is the base
        ``Config`` and ``mods`` maps modification names (currently only
        ``"SHAREDCORES"``) to partial ``Config`` overrides.
    """
    defaults = Config(

        # Base
        name="gqn",
        description=DESCRIPTION,
        n_epochs=1000000,
        batch_size=36,
        batch_size_val=36,
        seed=1,
        device="cuda",

        # Data
        split_val=3,  # index for set of 5
        split_test=4,  # index for set of 5
        data_module=loader,
        dataset="shepard_metzler_5_parts",
        data_dir=None,  # will be set for data_module if not None
        # 1 for single repeating batch, 2 for single viewpoint (i.e. reconstruct known images)
        debug=0,
        generator_train=loader.RandomBatchGenerator,
        generator_val=loader.LinearBatchGenerator,
        num_viewpoints_val=8,  # use this many viewpoints in validation
        shuffle_viewpoints_val=False,
        augmenter=MultiThreadedAugmenter,
        augmenter_kwargs={"num_processes": 8},

        # Model
        model=GenerativeQueryNetwork,
        model_kwargs={
            "in_channels": 3,
            "query_channels": 7,
            "r_channels": 256,
            "encoder_kwargs": {
                "activation_op": nn.ReLU
            },
            "decoder_kwargs": {
                "z_channels": 64,
                "h_channels": 128,
                "scale": 4,
                "core_repeat": 12
            }
        },
        model_init_weights_args=None,  # e.g. [nn.init.kaiming_normal_, 1e-2],
        model_init_bias_args=None,  # e.g. [nn.init.constant_, 0],

        # Learning
        # The *_initial / *_final / *_cutoff triples describe linear schedules over
        # cutoff epochs (the schedules themselves are implemented elsewhere).
        optimizer=optim.Adam,
        optimizer_kwargs={"weight_decay": 1e-5},
        lr_initial=5e-4,
        lr_final=5e-5,
        lr_cutoff=16e4,  # lr is decreased from lr_initial to lr_final over cutoff epochs
        sigma_initial=2.0,
        sigma_final=0.7,
        sigma_cutoff=2e4,  # sigma is decreased from sigma_initial to sigma_final
        kl_weight_initial=0.05,
        kl_weight_final=1.0,
        kl_weight_cutoff=1e5,  # kl_weight is increased from initial to final
        nll_weight=1.0,

        # Logging
        backup_every=10000,
        validate_every=1000,
        validate_subset=0.01,  # validate only this percentage randomly
        show_every=100,
        val_example_samples=10,  # draw this many random samples for last validation item
        test_on_val=True,  # test on the validation set
    )

    # Modification: share the decoder core (decoder_kwargs["core_shared"] = True).
    shared_cores = Config(
        model_kwargs={"decoder_kwargs": {
            "core_shared": True
        }})

    mods = {"SHAREDCORES": shared_cores}

    return {"DEFAULTS": defaults}, mods
Esempio n. 30
0
def get_config():
    """Build the ``Config`` for the basic U-Net experiment on the CHD segmentation dataset.

    All dataset paths are derived from ``os.path.abspath('data')`` (i.e. ``data`` under
    the current working directory). The config is created with ``update_from_argv=True``,
    presumably so command-line arguments can override these values. The assembled config
    is printed before it is returned.

    Returns:
        Config: the experiment configuration.
    """
    # Set your own path, if needed: where the downloaded dataset is stored.
    data_root_dir = os.path.abspath('data')

    def in_data_root(relative_path):
        # Join exactly as '<data_root_dir>' + relative_path (relative_path may contain '/').
        return os.path.join(data_root_dir, relative_path)

    c = Config(
        update_from_argv=True,

        # Train parameters
        num_classes=8,
        in_channels=1,
        batch_size=8,
        patch_size=64,
        n_epochs=50,
        learning_rate=0.0002,
        # 'splits.pkl' may contain multiple folds; this picks the one to use.
        fold=0,
        # 'cuda' is the default CUDA device; 'cpu' also works.
        # See https://pytorch.org/docs/stable/notes/cuda.html
        device="cuda",

        # Logging parameters
        name='Basic_Unet',
        plot_freq=10,  # How often things are shown in visdom.
        append_rnd_string=False,
        start_visdom=True,
        # Whether the UNet applies instance normalization in the contracting path.
        do_instancenorm=True,
        do_load_checkpoint=False,
        checkpoint_dir='',

        # Adapt to your own path, if needed.
        google_drive_id='1RzPB1_bqzQhlWvU-YGvZzhx2omcDh38C',
        dataset_name='CHD_segmentation_dataset',
        # Where the experiment output is logged.
        base_dir=os.path.abspath('output_experiment'),
        data_root_dir=data_root_dir,
        # Training and validation data.
        data_dir=in_data_root('CHD_segmentation_dataset/preprocessed'),
        # Test data.
        data_test_dir=in_data_root('CHD_segmentation_dataset/preprocessed'),
        # Location of the 'splits.pkl' file holding the splits.
        split_dir=in_data_root('CHD_segmentation_dataset'),
        scaled_image_16_dir=in_data_root('CHD_segmentation_dataset/scaled_to_16'),
        scaled_image_32_dir=in_data_root('CHD_segmentation_dataset/scaled_to_32'),
        scaled_image_64_dir=in_data_root('CHD_segmentation_dataset/scaled_to_64'),
        stage_1_dir=in_data_root('CHD_segmentation_dataset/stage_1'),
        stage_1_dir_32=in_data_root('CHD_segmentation_dataset/stage_1_32'),
    )

    print(c)
    return c