Example 1
import os

import yaml

from edflow.util import get_obj_from_str


def get_input_data(eval_root):
    csv_path = os.path.join(eval_root, 'model_output.csv')
    meta_path = os.path.join(eval_root, 'meta.yaml')

    if os.path.exists(csv_path):
        # The csv carries its config as a block of leading comment lines,
        # each prefixed with '# '. Collect them until the first data row.
        with open(csv_path, 'r') as cf:
            yaml_string = ''
            for line in cf:
                if line.startswith("# "):
                    yaml_string += line[2:]
                else:
                    break
        config = yaml.full_load(yaml_string)

    elif os.path.exists(meta_path):
        with open(meta_path, 'r') as mf:
            config = yaml.full_load(mf)

    else:
        raise ValueError(
            'eval_root must point to a folder containing a csv or meta.yaml')

    # Resolve the dataset class from its import path and instantiate it.
    impl = get_obj_from_str(config["dataset"])
    in_data = impl(config)

    return in_data
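A minimal usage sketch, not part of the example above: it writes a csv whose leading '# '-prefixed lines carry the YAML config, then lets get_input_data reconstruct the dataset. MyDataset and the csv contents are illustrative assumptions.

import os
import tempfile

class MyDataset:
    """Hypothetical dataset that just keeps the parsed config."""
    def __init__(self, config):
        self.config = config

root = tempfile.mkdtemp()
with open(os.path.join(root, "model_output.csv"), "w") as f:
    f.write("# dataset: __main__.MyDataset\n")  # YAML header line
    f.write("index,output\n0,out_000.png\n")    # regular csv rows

in_data = get_input_data(root)  # instantiates __main__.MyDataset(config)
assert isinstance(in_data, MyDataset)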
Example 2
    def __init__(self, config):
        self.dataset = retrieve(config, "RandomlyJoinedDataset/dataset")
        self.dataset = get_obj_from_str(self.dataset)
        self.dataset = self.dataset(config)
        self.key = retrieve(config, "RandomlyJoinedDataset/key")
        self.n_joins = retrieve(config,
                                "RandomlyJoinedDataset/n_joins",
                                default=2)

        self.test_mode = retrieve(config, "test_mode", default=False)
        self.avoid_identity = retrieve(config,
                                       "RandomlyJoinedDataset/avoid_identity",
                                       default=True)
        self.balance = retrieve(config,
                                "RandomlyJoinedDataset/balance",
                                default=False)

        # self.index_map is used to select a partner for each example.
        # In test_mode it is a list containing a single partner index for each
        # example, otherwise it is a dict containing all indices for a given
        # join label
        self.join_labels = np.asarray(self.dataset.labels[self.key])
        unique_labels = np.unique(self.join_labels)
        self.index_map = dict()
        for value in unique_labels:
            self.index_map[value] = np.nonzero(self.join_labels == value)[0]
        if self.test_mode:
            prng = np.random.RandomState(0)
            self.index_map = [
                prng.choice(self.index_map[self.join_labels[i]],
                            self.n_joins - 1) for i in range(len(self.dataset))
            ]
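For reference, a config sketch that would satisfy the retrieve calls above; the dataset import path and label key are placeholders, values otherwise show the defaults.

# Hypothetical config consumed by RandomlyJoinedDataset.__init__;
# "my_package.MyDataset" and "character_id" are placeholders.
config = {
    "RandomlyJoinedDataset": {
        "dataset": "my_package.MyDataset",  # resolved via get_obj_from_str
        "key": "character_id",              # label key grouping join partners
        "n_joins": 2,                       # examples per joined group (default)
        "avoid_identity": True,             # don't pair an example with itself (default)
        "balance": False,                   # no label balancing (default)
    },
    "test_mode": False,  # True fixes each example's partners with a seeded RNG
}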
Example 3
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # wrap save and restore into a LambdaCheckpointHook
        self.ckpthook = LambdaCheckpointHook(
            root_path=ProjectManager.checkpoints,
            global_step_getter=self.get_global_step,
            global_step_setter=self.set_global_step,
            save=self.save,
            restore=self.restore,
            interval=set_default(self.config, "ckpt_freq", None),
        )
        if not self.config.get("test_mode", False):
            # in training, execute train ops and add a LoggingHook
            self._train_ops = set_default(
                self.config, "train_ops", ["step_ops/train_op"]
            )
            self._log_ops = set_default(self.config, "log_ops", ["step_ops/log_op"])
            # logging
            self.loghook = LoggingHook(
                paths=self._log_ops, root_path=ProjectManager.train, interval=1
            )
            # wrap it in an interval hook
            self.ihook = IntervalHook(
                [self.loghook],
                interval=set_default(self.config, "start_log_freq", 1),
                modify_each=1,
                max_interval=set_default(self.config, "log_freq", 1000),
                get_step=self.get_global_step,
            )
            self.hooks.append(self.ihook)
            # write checkpoints after each epoch or when interrupted
            self.hooks.append(self.ckpthook)
        else:
            # evaluate
            self._eval_op = set_default(
                self.config, "eval_hook/eval_op", "step_ops/eval_op"
            )
            self._eval_callbacks = set_default(
                self.config, "eval_hook/eval_callbacks", list()
            )
            if not isinstance(self._eval_callbacks, list):
                self._eval_callbacks = [self._eval_callbacks]
            self._eval_callbacks = [
                get_obj_from_str(name) for name in self._eval_callbacks
            ]
            label_key = set_default(
                self.config, "eval_hook/label_key", "step_ops/eval_op/labels"
            )
            self.evalhook = TemplateEvalHook(
                dataset=self.dataset,
                step_getter=self.get_global_step,
                keypath=self._eval_op,
                meta=self.config,
                callbacks=self._eval_callbacks,
                label_key=label_key,
            )
            self.hooks.append(self.evalhook)
            self._train_ops = []
            self._log_ops = []
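The constructor above reads the following config keys; a sketch with the defaults from the code, other values illustrative.

# Config keys read by the template above.
config = {
    "ckpt_freq": 1000,                    # LambdaCheckpointHook interval (default None)
    "test_mode": False,                   # False: training hooks; True: evaluation hook
    "train_ops": ["step_ops/train_op"],   # default
    "log_ops": ["step_ops/log_op"],       # default
    "start_log_freq": 1,                  # initial logging interval (default)
    "log_freq": 1000,                     # maximum logging interval (default)
    # read only when test_mode is True; callbacks are import paths
    # resolved with get_obj_from_str
    "eval_hook": {
        "eval_op": "step_ops/eval_op",
        "eval_callbacks": [],
        "label_key": "step_ops/eval_op/labels",
    },
}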
Example 4
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # wrap save and restore into a LambdaCheckpointHook
        self.ckpthook = LambdaCheckpointHook(
            root_path=ProjectManager.checkpoints,
            global_step_getter=self.get_global_step,
            global_step_setter=self.set_global_step,
            save=self.save,
            restore=self.restore,
            interval=set_default(self.config, "ckpt_freq", None),
        )
        if not self.config.get("test_mode", False):
            # in training, execute train ops and add logging hooks for train and
            # validation batches
            self._train_ops = set_default(self.config, "train_ops",
                                          ["step_ops/train_op"])
            self._log_ops = set_default(self.config, "log_ops",
                                        ["step_ops/log_op"])
            # logging
            self.loghook = LoggingHook(
                paths=self._log_ops,
                root_path=ProjectManager.train,
                interval=1,
                name="train",
            )
            # wrap it in interval hook
            self.ihook = IntervalHook(
                [self.loghook],
                interval=set_default(self.config, "start_log_freq", 1),
                modify_each=1,
                max_interval=set_default(self.config, "log_freq", 1000),
                get_step=self.get_global_step,
            )
            self.hooks.append(self.ihook)
            # validation logging
            self._validation_log_ops = set_default(self.config,
                                                   "validation_log_ops",
                                                   ["validation_ops/log_op"])
            self._validation_root = os.path.join(ProjectManager.train,
                                                 "validation")
            os.makedirs(self._validation_root, exist_ok=True)
            # logging
            self.validation_loghook = LoggingHook(
                paths=self._validation_log_ops,
                root_path=self._validation_root,
                interval=1,
                name="validation",
            )
            self.hooks.append(self.validation_loghook)
            # write checkpoints after epoch or when interrupted
            self.hooks.append(self.ckpthook)
            wandb_logging = set_default(self.config, "integrations/wandb",
                                        False)
            if wandb_logging:
                import wandb
                from edflow.hooks.logging_hooks.wandb_handler import log_wandb

                os.environ["WANDB_RESUME"] = "allow"
                os.environ["WANDB_RUN_ID"] = ProjectManager.root.replace(
                    "/", "-")
                wandb.init(name=ProjectManager.root, config=self.config)
                self.loghook.handlers["scalars"].append(log_wandb)
                self.validation_loghook.handlers["scalars"].append(
                    lambda *args, **kwargs: log_wandb(
                        *args, **kwargs, prefix="validation"))
            tensorboardX_logging = set_default(self.config,
                                               "integrations/tensorboardX",
                                               False)
            if tensorboardX_logging:
                from tensorboardX import SummaryWriter
                from edflow.hooks.logging_hooks.tensorboardX_handler import (
                    log_tensorboard_config,
                    log_tensorboard_scalars,
                )

                self.tensorboardX_writer = SummaryWriter(ProjectManager.root)
                log_tensorboard_config(self.tensorboardX_writer, self.config,
                                       self.get_global_step())
                self.loghook.handlers["scalars"].append(
                    lambda *args, **kwargs: log_tensorboard_scalars(
                        self.tensorboardX_writer, *args, **kwargs))
                self.validation_loghook.handlers["scalars"].append(
                    lambda *args, **kwargs: log_tensorboard_scalars(
                        self.tensorboardX_writer,
                        *args,
                        **kwargs,
                        prefix="validation"))

        else:
            # evaluate
            self._eval_op = set_default(self.config, "eval_hook/eval_op",
                                        "step_ops/eval_op")
            self._eval_callbacks = set_default(self.config,
                                               "eval_hook/eval_callbacks",
                                               dict())
            if not isinstance(self._eval_callbacks, dict):
                self._eval_callbacks = {"cb": self._eval_callbacks}
            for k in self._eval_callbacks:
                self._eval_callbacks[k] = get_obj_from_str(
                    self._eval_callbacks[k])
            label_key = set_default(self.config, "eval_hook/label_key",
                                    "step_ops/eval_op/labels")
            self.evalhook = TemplateEvalHook(
                dataset=self.dataset,
                step_getter=self.get_global_step,
                keypath=self._eval_op,
                config=self.config,
                callbacks=self._eval_callbacks,
                labels_key=label_key,
            )
            self.hooks.append(self.evalhook)
            self._train_ops = []
            self._log_ops = []
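Relative to Example 3, this variant adds validation logging, dict-valued eval callbacks, and optional wandb/tensorboardX integrations. A config sketch covering the extra keys; values are illustrative and the callback import path is hypothetical.

# Keys this variant reads beyond Example 3.
config = {
    "test_mode": False,  # training branch; True switches to evaluation
    # training branch:
    "validation_log_ops": ["validation_ops/log_op"],  # default
    "integrations": {
        "wandb": False,          # True mirrors scalar logs to Weights & Biases
        "tensorboardX": False,   # True mirrors scalar logs to tensorboardX
    },
    # evaluation branch (test_mode: True); callbacks are a name -> import
    # path dict, each entry resolved with get_obj_from_str
    "eval_hook": {
        "eval_callbacks": {"acc": "my_package.accuracy_callback"},
    },
}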