def __init__(self, *args, **kwargs):
    super().__init__(*args, **kwargs)
    # wrap save and restore into a LambdaCheckpointHook
    self.ckpthook = LambdaCheckpointHook(
        root_path=ProjectManager.checkpoints,
        global_step_getter=self.get_global_step,
        global_step_setter=self.set_global_step,
        save=self.save,
        restore=self.restore,
        interval=set_default(self.config, "ckpt_freq", None),
    )
    if not self.config.get("test_mode", False):
        # in training, execute train ops and add a LoggingHook
        self._train_ops = set_default(
            self.config, "train_ops", ["step_ops/train_op"]
        )
        self._log_ops = set_default(self.config, "log_ops", ["step_ops/log_op"])
        # logging
        self.loghook = LoggingHook(
            paths=self._log_ops, root_path=ProjectManager.train, interval=1
        )
        # wrap it in an interval hook
        self.ihook = IntervalHook(
            [self.loghook],
            interval=set_default(self.config, "start_log_freq", 1),
            modify_each=1,
            max_interval=set_default(self.config, "log_freq", 1000),
            get_step=self.get_global_step,
        )
        self.hooks.append(self.ihook)
        # write checkpoints after each epoch or when interrupted
        self.hooks.append(self.ckpthook)
    else:
        # evaluate
        self._eval_op = set_default(
            self.config, "eval_hook/eval_op", "step_ops/eval_op"
        )
        self._eval_callbacks = set_default(
            self.config, "eval_hook/eval_callbacks", list()
        )
        if not isinstance(self._eval_callbacks, list):
            self._eval_callbacks = [self._eval_callbacks]
        self._eval_callbacks = [
            get_obj_from_str(name) for name in self._eval_callbacks
        ]
        label_key = set_default(
            self.config, "eval_hook/label_key", "step_ops/eval_op/labels"
        )
        self.evalhook = TemplateEvalHook(
            dataset=self.dataset,
            step_getter=self.get_global_step,
            keypath=self._eval_op,
            meta=self.config,
            callbacks=self._eval_callbacks,
            label_key=label_key,
        )
        self.hooks.append(self.evalhook)
        self._train_ops = []
        self._log_ops = []
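# Usage sketch (illustrative, not from the original source): a minimal config
# exercising the keys read above via set_default / config.get. All values are
# assumptions; omitted keys fall back to the defaults shown in __init__.
example_config = {
    "test_mode": False,                  # True switches to the evaluation branch
    "ckpt_freq": 1000,                   # LambdaCheckpointHook save interval
    "train_ops": ["step_ops/train_op"],
    "log_ops": ["step_ops/log_op"],
    "start_log_freq": 1,                 # initial logging interval
    "log_freq": 1000,                    # cap on the growing logging interval
    # read only when test_mode is True; the callback import path is hypothetical
    "eval_hook": {
        "eval_op": "step_ops/eval_op",
        "eval_callbacks": ["mypkg.callbacks.my_callback"],
        "label_key": "step_ops/eval_op/labels",
    },
}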
def _init_step_ops(self):
    # additional inputs
    self.pid_placeholder = tf.placeholder(tf.string, shape=[None])

    # loss
    endpoints = self.model.embeddings
    dists = loss.cdist(
        endpoints["emb"],
        endpoints["emb"],
        metric=self.config.get("metric", "euclidean"),
    )
    losses, train_top1, prec_at_k, _, neg_dists, pos_dists = loss.LOSS_CHOICES[
        "batch_hard"
    ](
        dists,
        self.pid_placeholder,
        self.config.get("margin", "soft"),
        batch_precision_at_k=self.config.get("n_views", 4) - 1,
    )
    # compute the mean batch loss
    loss_mean = tf.reduce_mean(losses)

    # train op
    learning_rate = self.config.get("learning_rate", 3e-4)
    self.logger.info("Training with learning rate: {}".format(learning_rate))
    optimizer = tf.train.AdamOptimizer(learning_rate)
    with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
        train_op = optimizer.minimize(loss_mean)
    self._step_ops = train_op

    tolog = {
        "loss": loss_mean,
        "top1": train_top1,
        "prec@{}".format(self.config.get("n_views", 4) - 1): prec_at_k,
    }
    loghook = LoggingHook(
        logs=tolog,
        scalars=tolog,
        images={"image": self.model.inputs["image"]},
        root_path=ProjectManager().train,
        interval=1,
    )
    ckpt_hook = CheckpointHook(
        root_path=ProjectManager().checkpoints,
        variables=tf.global_variables(),
        modelname=self.model.name,
        step=self.get_global_step,
        interval=self.config.get("ckpt_freq", 1000),
        max_to_keep=None,
    )
    self.hooks.append(ckpt_hook)
    ihook = IntervalHook(
        [loghook],
        interval=1,
        modify_each=1,
        max_interval=self.config.get("log_freq", 1000),
    )
    self.hooks.append(ihook)
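# Config sketch (assumed values) for the step ops above. "n_views" is the
# number of images per identity in a batch, so precision is reported at
# k = n_views - 1; the "soft" margin selects the soft-margin variant of the
# batch-hard loss, while a float selects the hinge variant.
example_config = {
    "metric": "euclidean",   # distance metric passed to loss.cdist
    "margin": "soft",        # or e.g. 0.2 for a hard margin
    "n_views": 4,            # images per identity -> logs prec@3
    "learning_rate": 3e-4,   # Adam learning rate
    "ckpt_freq": 1000,
    "log_freq": 1000,
}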
def __init__(self, *args, **kwargs):
    super(Trainer, self).__init__(*args, **kwargs)
    overviewHook = ImageOverviewHook(
        images=self.img_ops, root_path=ProjectManager.train, interval=1
    )
    ihook = IntervalHook(
        [overviewHook],
        interval=1,
        modify_each=1,
        max_interval=self.config.get("log_freq", 1000),
        get_step=self.get_global_step,
    )
    self.hooks.append(ihook)
def setup(self):
    """Init train_placeholders, log_ops and img_ops which can be added to."""
    self.train_placeholders = dict()
    self.log_ops = dict()
    self.img_ops = dict()
    self.update_ops = list()
    self.create_train_op()
    ckpt_hook = CheckpointHook(
        root_path=ProjectManager.checkpoints,
        variables=self.get_checkpoint_variables(),
        modelname="model",
        step=self.get_global_step,
        interval=self.config.get("ckpt_freq", None),
        max_to_keep=self.config.get("ckpt_keep", None),
    )
    self.hooks.append(ckpt_hook)
    loghook = LoggingHook(
        logs=self.log_ops,
        scalars=self.log_ops,
        images=self.img_ops,
        root_path=ProjectManager.train,
        interval=1,
        log_images_to_tensorboard=self.config.get(
            "log_images_to_tensorboard", False
        ),
    )
    ihook = IntervalHook(
        [loghook],
        interval=self.config.get("start_log_freq", 1),
        modify_each=1,
        max_interval=self.config.get("log_freq", 1000),
        get_step=self.get_global_step,
    )
    self.hooks.append(ihook)
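# Sketch (assumptions labeled) of how a subclass could fill the containers
# initialized in setup(). create_train_op is expected to populate log_ops and
# img_ops before the logging hooks pick them up. MyTrainer, make_loss and the
# "image" input key are hypothetical; the Trainer base class name is assumed
# from the surrounding code.
import tensorflow as tf

class MyTrainer(Trainer):
    def create_train_op(self):
        loss = self.make_loss()  # hypothetical loss-building helper
        self.log_ops["loss"] = loss  # scalar picked up by LoggingHook
        self.img_ops["image"] = self.model.inputs["image"]  # image summary
        optimizer = tf.train.AdamOptimizer(self.config.get("learning_rate", 3e-4))
        self.train_op = optimizer.minimize(loss)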
def __init__(self, *args, **kwargs):
    super().__init__(*args, **kwargs)
    # wrap save and restore into a LambdaCheckpointHook
    self.ckpthook = LambdaCheckpointHook(
        root_path=ProjectManager.checkpoints,
        global_step_getter=self.get_global_step,
        global_step_setter=self.set_global_step,
        save=self.save,
        restore=self.restore,
        interval=set_default(self.config, "ckpt_freq", None),
        ckpt_zero=set_default(self.config, "ckpt_zero", False),
    )
    # write checkpoints after epoch or when interrupted during training
    if not self.config.get("test_mode", False):
        self.hooks.append(self.ckpthook)

    ## hooks - disabled unless -t is specified

    # execute train ops
    self._train_ops = set_default(self.config, "train_ops", ["train/train_op"])
    train_hook = ExpandHook(paths=self._train_ops, interval=1)
    self.hooks.append(train_hook)
    # log train/step_ops/log_ops in increasing intervals
    self._log_ops = set_default(
        self.config, "log_ops", ["train/log_op", "validation/log_op"]
    )
    self.loghook = LoggingHook(
        paths=self._log_ops, root_path=ProjectManager.train, interval=1
    )
    self.ihook = IntervalHook(
        [self.loghook],
        interval=set_default(self.config, "start_log_freq", 1),
        modify_each=1,
        max_interval=set_default(self.config, "log_freq", 1000),
        get_step=self.get_global_step,
    )
    self.hooks.append(self.ihook)

    # set up logging integrations
    if not self.config.get("test_mode", False):
        default_wandb_logging = {
            "active": False,
            "handlers": ["scalars", "images"],
        }
        wandb_logging = set_default(
            self.config, "integrations/wandb", default_wandb_logging
        )
        if wandb_logging["active"]:
            import wandb
            from edflow.hooks.logging_hooks.wandb_handler import (
                log_wandb,
                log_wandb_images,
            )

            os.environ["WANDB_RESUME"] = "allow"
            os.environ["WANDB_RUN_ID"] = ProjectManager.root.strip("/").replace(
                "/", "-"
            )
            wandb_project = set_default(
                self.config, "integrations/wandb/project", None
            )
            wandb_entity = set_default(
                self.config, "integrations/wandb/entity", None
            )
            wandb.init(
                name=ProjectManager.root,
                config=self.config,
                project=wandb_project,
                entity=wandb_entity,
            )

            handlers = set_default(
                self.config,
                "integrations/wandb/handlers",
                default_wandb_logging["handlers"],
            )
            if "scalars" in handlers:
                self.loghook.handlers["scalars"].append(log_wandb)
            if "images" in handlers:
                self.loghook.handlers["images"].append(log_wandb_images)

        default_tensorboard_logging = {
            "active": False,
            "handlers": ["scalars", "images", "figures"],
        }
        tensorboard_logging = set_default(
            self.config, "integrations/tensorboard", default_tensorboard_logging
        )
        if tensorboard_logging["active"]:
            try:
                from torch.utils.tensorboard import SummaryWriter
            except ImportError:
                from tensorboardX import SummaryWriter
            from edflow.hooks.logging_hooks.tensorboard_handler import (
                log_tensorboard_config,
                log_tensorboard_scalars,
                log_tensorboard_images,
                log_tensorboard_figures,
            )

            self.tensorboard_writer = SummaryWriter(ProjectManager.root)
            log_tensorboard_config(
                self.tensorboard_writer, self.config, self.get_global_step()
            )
            handlers = set_default(
                self.config,
                "integrations/tensorboard/handlers",
                default_tensorboard_logging["handlers"],
            )
            if "scalars" in handlers:
                self.loghook.handlers["scalars"].append(
                    lambda *args, **kwargs: log_tensorboard_scalars(
                        self.tensorboard_writer, *args, **kwargs
                    )
                )
            if "images" in handlers:
                self.loghook.handlers["images"].append(
                    lambda *args, **kwargs: log_tensorboard_images(
                        self.tensorboard_writer, *args, **kwargs
                    )
                )
            if "figures" in handlers:
                self.loghook.handlers["figures"].append(
                    lambda *args, **kwargs: log_tensorboard_figures(
                        self.tensorboard_writer, *args, **kwargs
                    )
                )

    ## epoch hooks

    # evaluate validation/step_ops/eval_op after each epoch
    self._eval_op = set_default(
        self.config, "eval_hook/eval_op", "validation/eval_op"
    )
    _eval_callbacks = set_default(self.config, "eval_hook/eval_callbacks", dict())
    if not isinstance(_eval_callbacks, dict):
        _eval_callbacks = {"cb": _eval_callbacks}
    eval_callbacks = dict()
    for k in _eval_callbacks:
        eval_callbacks[k] = _eval_callbacks[k]
    if hasattr(self, "callbacks"):
        iterator_callbacks = retrieve(self.callbacks, "eval_op", default=dict())
        for k in iterator_callbacks:
            import_path = get_str_from_obj(iterator_callbacks[k])
            set_value(
                self.config, "eval_hook/eval_callbacks/{}".format(k), import_path
            )
            eval_callbacks[k] = import_path
    if hasattr(self.model, "callbacks"):
        model_callbacks = retrieve(self.model.callbacks, "eval_op", default=dict())
        for k in model_callbacks:
            import_path = get_str_from_obj(model_callbacks[k])
            set_value(
                self.config, "eval_hook/eval_callbacks/{}".format(k), import_path
            )
            eval_callbacks[k] = import_path
    callback_handler = None
    if not self.config.get("test_mode", False):
        callback_handler = lambda results, paths: self.loghook(
            results=results,
            step=self.get_global_step(),
            paths=paths,
        )

    # offer the option to run an eval functor: overwrite the step op to only
    # evaluate the functor, and the callbacks to only those of the functor
    if self.config.get("test_mode", False) and "eval_functor" in self.config:
        eval_functor = get_obj_from_str(self.config["eval_functor"])(
            config=self.config
        )
        self.step_ops = lambda: {"eval_op": eval_functor}
        eval_callbacks = dict()
        if hasattr(eval_functor, "callbacks"):
            for k in eval_functor.callbacks:
                eval_callbacks[k] = get_str_from_obj(eval_functor.callbacks[k])
        set_value(self.config, "eval_hook/eval_callbacks", eval_callbacks)

    self.evalhook = TemplateEvalHook(
        datasets=self.datasets,
        step_getter=self.get_global_step,
        keypath=self._eval_op,
        config=self.config,
        callbacks=eval_callbacks,
        callback_handler=callback_handler,
    )
    self.epoch_hooks.append(self.evalhook)
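# Config sketch (assumed values) for the evaluation path above. With
# test_mode enabled and an "eval_functor" import path set, the step ops are
# replaced by the functor and the callbacks by the functor's own; both import
# paths below are hypothetical.
example_eval_config = {
    "test_mode": True,
    "eval_functor": "mypkg.eval.MyEvalFunctor",  # instantiated with config=...
    "eval_hook": {
        "eval_op": "validation/eval_op",
        "eval_callbacks": {"fid": "mypkg.callbacks.fid_callback"},
    },
}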
def __init__(self, config, root, model, hook_freq=1, num_epochs=100,
             hooks=None, bar_position=0):
    # avoid a shared mutable default argument for hooks
    hooks = hooks if hooks is not None else []
    super().__init__(config, root, model, hook_freq, num_epochs, hooks,
                     bar_position, desc='Train')

    image_names = [
        'step_ops/0/generated',
        'step_ops/0/images',
        'step_ops/0/label',
        'step_ops/0/real_A',
        'step_ops/0/real_B',
    ]

    loss_names = ['G_VGG', 'G_GAN', 'G_GAN_Feat',
                  'D_real', 'D_fake',
                  'G_Warp', 'F_Flow', 'F_Warp', 'W']
    loss_names_T = ['G_T_GAN', 'G_T_GAN_Feat',
                    'D_T_real', 'D_T_fake', 'G_T_Warp']

    opt = model.opt

    prefix = 'step_ops/0/losses/per_frame/'
    scalar_names = [prefix + n for n in loss_names]

    prefix_t = 'step_ops/0/losses/temporal/'
    for s in range(opt.n_frames_G - 1):
        sub = '{}/'.format(s)
        scalar_names += [prefix_t + sub + n for n in loss_names_T]

    ImPlotHook = IntervalHook(
        [PlotImageBatch(P.latest_eval, image_names, time_axis=1)],
        interval=10,
        max_interval=500,
        modify_each=10,
    )

    checks = []
    models = [model.G, model.D, model.F]
    names = ['gen', 'discr', 'flow']
    for n, m in zip(names, models):
        checks += [
            PyCheckpointHook(P.checkpoints, m, 'v2v_{}'.format(n),
                             interval=config['ckpt_freq'])
        ]

    self.hook_freq = 1
    self.hooks += [PrepareV2VDataHook()]
    self.hooks += checks
    self.hooks += [
        ToNumpyHook(),
        ImPlotHook,
        PyLoggingHook(scalar_keys=scalar_names,
                      log_keys=scalar_names,
                      root_path=P.latest_eval,
                      interval=config['log_freq']),
        IncreaseLearningRate(model, opt.niter),
    ]
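# Quick sanity check (illustrative): print the temporal scalar key paths that
# the loop above generates. n_frames_G = 3 is an assumed value.
n_frames_G = 3
prefix_t = 'step_ops/0/losses/temporal/'
loss_names_T = ['G_T_GAN', 'G_T_GAN_Feat', 'D_T_real', 'D_T_fake', 'G_T_Warp']
for s in range(n_frames_G - 1):
    for n in loss_names_T:
        print(prefix_t + '{}/'.format(s) + n)
# -> step_ops/0/losses/temporal/0/G_T_GAN ... step_ops/0/losses/temporal/1/G_T_Warp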
def __init__(self, *args, **kwargs):
    super().__init__(*args, **kwargs)
    # periodically print the model's memory usage
    PM = PrintMemHook(self.model, self.get_global_step)
    self.hooks += [IntervalHook([PM], 1)]
def __init__(self, *args, **kwargs):
    super().__init__(*args, **kwargs)
    # wrap save and restore into a LambdaCheckpointHook
    self.ckpthook = LambdaCheckpointHook(
        root_path=ProjectManager.checkpoints,
        global_step_getter=self.get_global_step,
        global_step_setter=self.set_global_step,
        save=self.save,
        restore=self.restore,
        interval=set_default(self.config, "ckpt_freq", None),
    )
    if not self.config.get("test_mode", False):
        # in training, execute train ops and add logging hooks for train
        # and validation batches
        self._train_ops = set_default(
            self.config, "train_ops", ["step_ops/train_op"]
        )
        self._log_ops = set_default(self.config, "log_ops", ["step_ops/log_op"])
        # logging
        self.loghook = LoggingHook(
            paths=self._log_ops,
            root_path=ProjectManager.train,
            interval=1,
            name="train",
        )
        # wrap it in an interval hook
        self.ihook = IntervalHook(
            [self.loghook],
            interval=set_default(self.config, "start_log_freq", 1),
            modify_each=1,
            max_interval=set_default(self.config, "log_freq", 1000),
            get_step=self.get_global_step,
        )
        self.hooks.append(self.ihook)

        # validation logging
        self._validation_log_ops = set_default(
            self.config, "validation_log_ops", ["validation_ops/log_op"]
        )
        self._validation_root = os.path.join(ProjectManager.train, "validation")
        os.makedirs(self._validation_root, exist_ok=True)
        self.validation_loghook = LoggingHook(
            paths=self._validation_log_ops,
            root_path=self._validation_root,
            interval=1,
            name="validation",
        )
        self.hooks.append(self.validation_loghook)

        # write checkpoints after each epoch or when interrupted
        self.hooks.append(self.ckpthook)

        wandb_logging = set_default(self.config, "integrations/wandb", False)
        if wandb_logging:
            import wandb
            from edflow.hooks.logging_hooks.wandb_handler import log_wandb

            os.environ["WANDB_RESUME"] = "allow"
            os.environ["WANDB_RUN_ID"] = ProjectManager.root.replace("/", "-")
            wandb.init(name=ProjectManager.root, config=self.config)
            self.loghook.handlers["scalars"].append(log_wandb)
            self.validation_loghook.handlers["scalars"].append(
                lambda *args, **kwargs: log_wandb(
                    *args, **kwargs, prefix="validation"
                )
            )

        tensorboardX_logging = set_default(
            self.config, "integrations/tensorboardX", False
        )
        if tensorboardX_logging:
            from tensorboardX import SummaryWriter
            from edflow.hooks.logging_hooks.tensorboardX_handler import (
                log_tensorboard_config,
                log_tensorboard_scalars,
            )

            self.tensorboardX_writer = SummaryWriter(ProjectManager.root)
            log_tensorboard_config(
                self.tensorboardX_writer, self.config, self.get_global_step()
            )
            self.loghook.handlers["scalars"].append(
                lambda *args, **kwargs: log_tensorboard_scalars(
                    self.tensorboardX_writer, *args, **kwargs
                )
            )
            self.validation_loghook.handlers["scalars"].append(
                lambda *args, **kwargs: log_tensorboard_scalars(
                    self.tensorboardX_writer, *args, **kwargs, prefix="validation"
                )
            )
    else:
        # evaluate
        self._eval_op = set_default(
            self.config, "eval_hook/eval_op", "step_ops/eval_op"
        )
        self._eval_callbacks = set_default(
            self.config, "eval_hook/eval_callbacks", dict()
        )
        if not isinstance(self._eval_callbacks, dict):
            self._eval_callbacks = {"cb": self._eval_callbacks}
        for k in self._eval_callbacks:
            self._eval_callbacks[k] = get_obj_from_str(self._eval_callbacks[k])
        label_key = set_default(
            self.config, "eval_hook/label_key", "step_ops/eval_op/labels"
        )
        self.evalhook = TemplateEvalHook(
            dataset=self.dataset,
            step_getter=self.get_global_step,
            keypath=self._eval_op,
            config=self.config,
            callbacks=self._eval_callbacks,
            labels_key=label_key,
        )
        self.hooks.append(self.evalhook)
        self._train_ops = []
        self._log_ops = []
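# Config sketch (assumed values) for the training branch above, including the
# optional wandb / tensorboardX integrations toggled via set_default. Both
# integrations require the respective packages to be installed.
example_config = {
    "test_mode": False,
    "ckpt_freq": 1000,
    "train_ops": ["step_ops/train_op"],
    "log_ops": ["step_ops/log_op"],
    "validation_log_ops": ["validation_ops/log_op"],
    "start_log_freq": 1,
    "log_freq": 1000,
    "integrations": {
        "wandb": True,
        "tensorboardX": True,
    },
}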