def get_input_data(eval_root):
    """Instantiate the dataset described by the meta data in ``eval_root``.

    The folder is expected to contain either ``model_output.csv`` (with its
    meta configuration embedded as ``# ``-prefixed yaml header lines) or a
    plain ``meta.yaml`` file.

    Parameters
    ----------
    eval_root : str
        Path to an evaluation folder.

    Returns
    -------
    The dataset instance, built by resolving ``config["dataset"]`` via
    ``get_obj_from_str`` and calling it with the loaded config.

    Raises
    ------
    ValueError
        If neither ``model_output.csv`` nor ``meta.yaml`` exists in
        ``eval_root``.
    """
    csv_path = os.path.join(eval_root, 'model_output.csv')
    meta_path = os.path.join(eval_root, 'meta.yaml')
    if os.path.exists(csv_path):
        # The csv carries its meta config as a block of "# " comment lines
        # at the top of the file; collect them until the first non-comment
        # line.
        yaml_string = ''
        with open(csv_path, 'r') as cf:
            for line in cf:
                # startswith instead of substring match: only true header
                # lines begin with the comment marker.
                if line.startswith("# "):
                    # line keeps its own trailing newline, so no extra "\n"
                    # is appended (the original doubled every newline).
                    yaml_string += line[2:]
                else:
                    break
        config = yaml.full_load(yaml_string)
    elif os.path.exists(meta_path):
        # Bug fix: load the file's *contents*, not the path string itself.
        with open(meta_path, 'r') as mf:
            config = yaml.full_load(mf)
    else:
        raise ValueError(
            'eval_root must point to a folder containing a csv or meta')
    impl = get_obj_from_str(config["dataset"])
    in_data = impl(config)
    return in_data
def __init__(self, config):
    """Set up the wrapped dataset and the label-based join index.

    Reads all ``RandomlyJoinedDataset/*`` options from ``config``,
    instantiates the inner dataset, and precomputes, per join label, the
    example indices sharing that label.
    """
    dataset_import_path = retrieve(config, "RandomlyJoinedDataset/dataset")
    dataset_cls = get_obj_from_str(dataset_import_path)
    self.dataset = dataset_cls(config)

    self.key = retrieve(config, "RandomlyJoinedDataset/key")
    self.n_joins = retrieve(config, "RandomlyJoinedDataset/n_joins", default=2)
    self.test_mode = retrieve(config, "test_mode", default=False)
    self.avoid_identity = retrieve(
        config, "RandomlyJoinedDataset/avoid_identity", default=True
    )
    self.balance = retrieve(config, "RandomlyJoinedDataset/balance", default=False)

    # self.index_map is used to select a partner for each example. In
    # test_mode it is a list containing a single partner index choice for
    # each example; otherwise it is a dict mapping each join label to all
    # indices carrying that label.
    self.join_labels = np.asarray(self.dataset.labels[self.key])
    self.index_map = {
        label: np.nonzero(self.join_labels == label)[0]
        for label in np.unique(self.join_labels)
    }
    if self.test_mode:
        # Fixed seed so the test-mode partner assignment is reproducible.
        prng = np.random.RandomState(0)
        self.index_map = [
            prng.choice(self.index_map[self.join_labels[idx]], self.n_joins - 1)
            for idx in range(len(self.dataset))
        ]
def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) # wrap save and restore into a LambdaCheckpointHook self.ckpthook = LambdaCheckpointHook( root_path=ProjectManager.checkpoints, global_step_getter=self.get_global_step, global_step_setter=self.set_global_step, save=self.save, restore=self.restore, interval=set_default(self.config, "ckpt_freq", None), ) if not self.config.get("test_mode", False): # in training, excute train ops and add logginghook self._train_ops = set_default( self.config, "train_ops", ["step_ops/train_op"] ) self._log_ops = set_default(self.config, "log_ops", ["step_ops/log_op"]) # logging self.loghook = LoggingHook( paths=self._log_ops, root_path=ProjectManager.train, interval=1 ) # wrap it in interval hook self.ihook = IntervalHook( [self.loghook], interval=set_default(self.config, "start_log_freq", 1), modify_each=1, max_interval=set_default(self.config, "log_freq", 1000), get_step=self.get_global_step, ) self.hooks.append(self.ihook) # write checkpoints after epoch or when interrupted self.hooks.append(self.ckpthook) else: # evaluate self._eval_op = set_default( self.config, "eval_hook/eval_op", "step_ops/eval_op" ) self._eval_callbacks = set_default( self.config, "eval_hook/eval_callbacks", list() ) if not isinstance(self._eval_callbacks, list): self._eval_callbacks = [self._eval_callbacks] self._eval_callbacks = [ get_obj_from_str(name) for name in self._eval_callbacks ] label_key = set_default( self.config, "eval_hook/label_key", "step_ops/eval_op/labels" ) self.evalhook = TemplateEvalHook( dataset=self.dataset, step_getter=self.get_global_step, keypath=self._eval_op, meta=self.config, callbacks=self._eval_callbacks, label_key=label_key, ) self.hooks.append(self.evalhook) self._train_ops = [] self._log_ops = []
def __init__(self, *args, **kwargs):
    """Install checkpointing, train/validation logging, optional wandb and
    tensorboardX integrations, and the evaluation hook.

    In training mode this registers a train logging hook (interval-wrapped),
    a validation logging hook, the checkpoint hook, and — when enabled via
    ``integrations/wandb`` or ``integrations/tensorboardX`` — appends extra
    scalar handlers to both logging hooks. In test mode it registers a
    :class:`TemplateEvalHook` instead and clears the train/log op lists.
    """
    super().__init__(*args, **kwargs)
    # wrap save and restore into a LambdaCheckpointHook
    self.ckpthook = LambdaCheckpointHook(
        root_path=ProjectManager.checkpoints,
        global_step_getter=self.get_global_step,
        global_step_setter=self.set_global_step,
        save=self.save,
        restore=self.restore,
        interval=set_default(self.config, "ckpt_freq", None),
    )
    if not self.config.get("test_mode", False):
        # in training, execute train ops and add logging hooks for train and
        # validation batches
        self._train_ops = set_default(
            self.config, "train_ops", ["step_ops/train_op"]
        )
        self._log_ops = set_default(self.config, "log_ops", ["step_ops/log_op"])
        # logging
        self.loghook = LoggingHook(
            paths=self._log_ops,
            root_path=ProjectManager.train,
            interval=1,
            name="train",
        )
        # wrap it in interval hook so the logging frequency grows from
        # start_log_freq up to log_freq
        self.ihook = IntervalHook(
            [self.loghook],
            interval=set_default(self.config, "start_log_freq", 1),
            modify_each=1,
            max_interval=set_default(self.config, "log_freq", 1000),
            get_step=self.get_global_step,
        )
        self.hooks.append(self.ihook)
        # validation logging goes into its own subfolder of the train root
        self._validation_log_ops = set_default(
            self.config, "validation_log_ops", ["validation_ops/log_op"]
        )
        self._validation_root = os.path.join(ProjectManager.train, "validation")
        os.makedirs(self._validation_root, exist_ok=True)
        # logging
        self.validation_loghook = LoggingHook(
            paths=self._validation_log_ops,
            root_path=self._validation_root,
            interval=1,
            name="validation",
        )
        self.hooks.append(self.validation_loghook)
        # write checkpoints after epoch or when interrupted
        self.hooks.append(self.ckpthook)
        wandb_logging = set_default(self.config, "integrations/wandb", False)
        if wandb_logging:
            # imported lazily so wandb is only required when enabled
            import wandb
            from edflow.hooks.logging_hooks.wandb_handler import log_wandb

            # derive a stable run id from the project root so interrupted
            # runs resume into the same wandb run
            os.environ["WANDB_RESUME"] = "allow"
            os.environ["WANDB_RUN_ID"] = ProjectManager.root.replace(
                "/", "-")
            wandb.init(name=ProjectManager.root, config=self.config)
            self.loghook.handlers["scalars"].append(log_wandb)
            # validation scalars are logged under a "validation" prefix
            self.validation_loghook.handlers["scalars"].append(
                lambda *args, **kwargs: log_wandb(
                    *args, **kwargs, prefix="validation"))
        tensorboardX_logging = set_default(
            self.config, "integrations/tensorboardX", False
        )
        if tensorboardX_logging:
            # imported lazily so tensorboardX is only required when enabled
            from tensorboardX import SummaryWriter
            from edflow.hooks.logging_hooks.tensorboardX_handler import (
                log_tensorboard_config,
                log_tensorboard_scalars,
            )

            self.tensorboardX_writer = SummaryWriter(ProjectManager.root)
            # record the config once at the current step
            log_tensorboard_config(
                self.tensorboardX_writer, self.config, self.get_global_step())
            self.loghook.handlers["scalars"].append(
                lambda *args, **kwargs: log_tensorboard_scalars(
                    self.tensorboardX_writer, *args, **kwargs))
            self.validation_loghook.handlers["scalars"].append(
                lambda *args, **kwargs: log_tensorboard_scalars(
                    self.tensorboardX_writer, *args, **kwargs,
                    prefix="validation"))
    else:
        # evaluate: resolve the eval op keypath and any callbacks given as
        # import strings; callbacks are normalized to a dict keyed by name
        self._eval_op = set_default(
            self.config, "eval_hook/eval_op", "step_ops/eval_op"
        )
        self._eval_callbacks = set_default(
            self.config, "eval_hook/eval_callbacks", dict()
        )
        if not isinstance(self._eval_callbacks, dict):
            self._eval_callbacks = {"cb": self._eval_callbacks}
        for k in self._eval_callbacks:
            self._eval_callbacks[k] = get_obj_from_str(
                self._eval_callbacks[k])
        label_key = set_default(
            self.config, "eval_hook/label_key", "step_ops/eval_op/labels"
        )
        # NOTE(review): this variant passes config=/labels_key= while the
        # other constructor in this file passes meta=/label_key= to
        # TemplateEvalHook — confirm which keyword signature is current.
        self.evalhook = TemplateEvalHook(
            dataset=self.dataset,
            step_getter=self.get_global_step,
            keypath=self._eval_op,
            config=self.config,
            callbacks=self._eval_callbacks,
            labels_key=label_key,
        )
        self.hooks.append(self.evalhook)
        # no training in test mode
        self._train_ops = []
        self._log_ops = []