Example #1
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # wrap save and restore into a LambdaCheckpointHook
        self.ckpthook = LambdaCheckpointHook(
            root_path=ProjectManager.checkpoints,
            global_step_getter=self.get_global_step,
            global_step_setter=self.set_global_step,
            save=self.save,
            restore=self.restore,
            interval=set_default(self.config, "ckpt_freq", None),
        )
        if not self.config.get("test_mode", False):
            # in training, execute train ops and add a logging hook
            self._train_ops = set_default(
                self.config, "train_ops", ["step_ops/train_op"]
            )
            self._log_ops = set_default(self.config, "log_ops", ["step_ops/log_op"])
            # logging
            self.loghook = LoggingHook(
                paths=self._log_ops, root_path=ProjectManager.train, interval=1
            )
            # wrap it in an interval hook
            self.ihook = IntervalHook(
                [self.loghook],
                interval=set_default(self.config, "start_log_freq", 1),
                modify_each=1,
                max_interval=set_default(self.config, "log_freq", 1000),
                get_step=self.get_global_step,
            )
            self.hooks.append(self.ihook)
            # write checkpoints after each epoch or when interrupted
            self.hooks.append(self.ckpthook)
        else:
            # evaluate
            self._eval_op = set_default(
                self.config, "eval_hook/eval_op", "step_ops/eval_op"
            )
            self._eval_callbacks = set_default(
                self.config, "eval_hook/eval_callbacks", list()
            )
            if not isinstance(self._eval_callbacks, list):
                self._eval_callbacks = [self._eval_callbacks]
            self._eval_callbacks = [
                get_obj_from_str(name) for name in self._eval_callbacks
            ]
            label_key = set_default(
                self.config, "eval_hook/label_key", "step_ops/eval_op/labels"
            )
            self.evalhook = TemplateEvalHook(
                dataset=self.dataset,
                step_getter=self.get_global_step,
                keypath=self._eval_op,
                meta=self.config,
                callbacks=self._eval_callbacks,
                label_key=label_key,
            )
            self.hooks.append(self.evalhook)
            self._train_ops = []
            self._log_ops = []
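
All of the keys read above via set_default come from self.config. A minimal, hypothetical config for the training branch, using only key names that appear in the code (the values shown are illustrative, not verified defaults):

config = {
    "test_mode": False,               # take the training branch above
    "ckpt_freq": 1000,                # LambdaCheckpointHook save interval
    "train_ops": ["step_ops/train_op"],
    "log_ops": ["step_ops/log_op"],
    "start_log_freq": 1,              # initial LoggingHook interval
    "log_freq": 1000,                 # upper bound for the growing interval
}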
Example #2
    def _init_step_ops(self):
        # additional inputs
        self.pid_placeholder = tf.placeholder(tf.string, shape=[None])

        # loss
        endpoints = self.model.embeddings
        dists = loss.cdist(endpoints['emb'],
                           endpoints['emb'],
                           metric=self.config.get("metric", "euclidean"))
        losses, train_top1, prec_at_k, _, neg_dists, pos_dists = (
            loss.LOSS_CHOICES["batch_hard"](
                dists,
                self.pid_placeholder,
                self.config.get("margin", "soft"),
                batch_precision_at_k=self.config.get("n_views", 4) - 1))

        # compute the mean loss over the batch
        loss_mean = tf.reduce_mean(losses)

        # train op
        learning_rate = self.config.get("learning_rate", 3e-4)
        self.logger.info(
            "Training with learning rate: {}".format(learning_rate))
        optimizer = tf.train.AdamOptimizer(learning_rate)
        with tf.control_dependencies(tf.get_collection(
                tf.GraphKeys.UPDATE_OPS)):
            train_op = optimizer.minimize(loss_mean)
        self._step_ops = train_op

        tolog = {
            "loss": loss_mean,
            "top1": train_top1,
            "prec@{}".format(self.config.get("n_views", 4) - 1): prec_at_k
        }
        loghook = LoggingHook(logs=tolog,
                              scalars=tolog,
                              images={"image": self.model.inputs["image"]},
                              root_path=ProjectManager().train,
                              interval=1)
        ckpt_hook = CheckpointHook(root_path=ProjectManager().checkpoints,
                                   variables=tf.global_variables(),
                                   modelname=self.model.name,
                                   step=self.get_global_step,
                                   interval=self.config.get("ckpt_freq", 1000),
                                   max_to_keep=None)
        self.hooks.append(ckpt_hook)
        ihook = IntervalHook([loghook],
                             interval=1,
                             modify_each=1,
                             max_interval=self.config.get("log_freq", 1000))
        self.hooks.append(ihook)
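
The hyperparameters this step-op setup reads through self.config.get can be collected in one place; a sketch using the fallback defaults visible in the code above (the "soft" margin selects the soft-margin variant of the batch-hard triplet loss):

config = {
    "metric": "euclidean",      # distance metric passed to loss.cdist
    "margin": "soft",           # soft-margin batch-hard triplet loss
    "n_views": 4,               # views per identity; prec@k uses n_views - 1
    "learning_rate": 3e-4,      # Adam learning rate
    "ckpt_freq": 1000,          # CheckpointHook interval
    "log_freq": 1000,           # max interval for the IntervalHook
}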
Example #3
File: iterators.py  Project: mritv/edflow
    def __init__(self, *args, **kwargs):
        super(Trainer, self).__init__(*args, **kwargs)

        overviewHook = ImageOverviewHook(images=self.img_ops,
                                         root_path=ProjectManager.train,
                                         interval=1)
        ihook = IntervalHook(
            [overviewHook],
            interval=1,
            modify_each=1,
            max_interval=self.config.get("log_freq", 1000),
            get_step=self.get_global_step,
        )
        self.hooks.append(ihook)
Example #4
    def setup(self):
        """Init train_placeholders, log_ops and img_ops which can be added
        to."""
        self.train_placeholders = dict()
        self.log_ops = dict()
        self.img_ops = dict()
        self.update_ops = list()
        self.create_train_op()

        ckpt_hook = CheckpointHook(
            root_path=ProjectManager.checkpoints,
            variables=self.get_checkpoint_variables(),
            modelname="model",
            step=self.get_global_step,
            interval=self.config.get("ckpt_freq", None),
            max_to_keep=self.config.get("ckpt_keep", None),
        )
        self.hooks.append(ckpt_hook)

        loghook = LoggingHook(
            logs=self.log_ops,
            scalars=self.log_ops,
            images=self.img_ops,
            root_path=ProjectManager.train,
            interval=1,
            log_images_to_tensorboard=self.config.get(
                "log_images_to_tensorboard", False
            ),
        )
        ihook = IntervalHook(
            [loghook],
            interval=self.config.get("start_log_freq", 1),
            modify_each=1,
            max_interval=self.config.get("log_freq", 1000),
            get_step=self.get_global_step,
        )
        self.hooks.append(ihook)
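
Since setup initializes empty train_placeholders, log_ops and img_ops and then calls create_train_op, a subclass is expected to populate these dicts there so the LoggingHook above picks them up. A minimal, hypothetical sketch of such a subclass (the base-class import path, the model attributes and the toy loss are illustrative assumptions, not taken from the source):

import tensorflow as tf
from edflow.iterators.tf_trainer import TFBaseTrainer  # assumed location

class MyTrainer(TFBaseTrainer):
    def create_train_op(self):
        # toy scalar loss; a real trainer would derive this from self.model
        loss = tf.reduce_mean(tf.square(self.model.outputs["pred"]))
        self.log_ops["loss"] = loss                          # scalar logging
        self.img_ops["input"] = self.model.inputs["image"]   # image logging
        optimizer = tf.train.AdamOptimizer(
            self.config.get("learning_rate", 1e-4))
        self.train_op = optimizer.minimize(loss)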
Example #5
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # wrap save and restore into a LambdaCheckpointHook
        self.ckpthook = LambdaCheckpointHook(
            root_path=ProjectManager.checkpoints,
            global_step_getter=self.get_global_step,
            global_step_setter=self.set_global_step,
            save=self.save,
            restore=self.restore,
            interval=set_default(self.config, "ckpt_freq", None),
            ckpt_zero=set_default(self.config, "ckpt_zero", False),
        )
        # write checkpoints after epoch or when interrupted during training
        if not self.config.get("test_mode", False):
            self.hooks.append(self.ckpthook)

        ## hooks - disabled unless -t is specified

        # execute train ops
        self._train_ops = set_default(self.config, "train_ops",
                                      ["train/train_op"])
        train_hook = ExpandHook(paths=self._train_ops, interval=1)
        self.hooks.append(train_hook)

        # log train/step_ops/log_ops in increasing intervals
        self._log_ops = set_default(self.config, "log_ops",
                                    ["train/log_op", "validation/log_op"])
        self.loghook = LoggingHook(paths=self._log_ops,
                                   root_path=ProjectManager.train,
                                   interval=1)
        self.ihook = IntervalHook(
            [self.loghook],
            interval=set_default(self.config, "start_log_freq", 1),
            modify_each=1,
            max_interval=set_default(self.config, "log_freq", 1000),
            get_step=self.get_global_step,
        )
        self.hooks.append(self.ihook)

        # setup logging integrations
        if not self.config.get("test_mode", False):
            default_wandb_logging = {
                "active": False,
                "handlers": ["scalars", "images"]
            }
            wandb_logging = set_default(self.config, "integrations/wandb",
                                        default_wandb_logging)
            if wandb_logging["active"]:
                import wandb
                from edflow.hooks.logging_hooks.wandb_handler import (
                    log_wandb,
                    log_wandb_images,
                )

                os.environ["WANDB_RESUME"] = "allow"
                os.environ["WANDB_RUN_ID"] = ProjectManager.root.strip(
                    "/").replace("/", "-")
                wandb_project = set_default(self.config,
                                            "integrations/wandb/project", None)
                wandb_entity = set_default(self.config,
                                           "integrations/wandb/entity", None)
                wandb.init(
                    name=ProjectManager.root,
                    config=self.config,
                    project=wandb_project,
                    entity=wandb_entity,
                )

                handlers = set_default(
                    self.config,
                    "integrations/wandb/handlers",
                    default_wandb_logging["handlers"],
                )
                if "scalars" in handlers:
                    self.loghook.handlers["scalars"].append(log_wandb)
                if "images" in handlers:
                    self.loghook.handlers["images"].append(log_wandb_images)

            default_tensorboard_logging = {
                "active": False,
                "handlers": ["scalars", "images", "figures"],
            }
            tensorboard_logging = set_default(self.config,
                                              "integrations/tensorboard",
                                              default_tensorboard_logging)
            if tensorboard_logging["active"]:
                try:
                    from torch.utils.tensorboard import SummaryWriter
                except ImportError:
                    from tensorboardX import SummaryWriter

                from edflow.hooks.logging_hooks.tensorboard_handler import (
                    log_tensorboard_config,
                    log_tensorboard_scalars,
                    log_tensorboard_images,
                    log_tensorboard_figures,
                )

                self.tensorboard_writer = SummaryWriter(ProjectManager.root)
                log_tensorboard_config(self.tensorboard_writer, self.config,
                                       self.get_global_step())
                handlers = set_default(
                    self.config,
                    "integrations/tensorboard/handlers",
                    default_tensorboard_logging["handlers"],
                )
                if "scalars" in handlers:
                    self.loghook.handlers["scalars"].append(
                        lambda *args, **kwargs: log_tensorboard_scalars(
                            self.tensorboard_writer, *args, **kwargs))
                if "images" in handlers:
                    self.loghook.handlers["images"].append(
                        lambda *args, **kwargs: log_tensorboard_images(
                            self.tensorboard_writer, *args, **kwargs))
                if "figures" in handlers:
                    self.loghook.handlers["figures"].append(
                        lambda *args, **kwargs: log_tensorboard_figures(
                            self.tensorboard_writer, *args, **kwargs))
        ## epoch hooks

        # evaluate validation/step_ops/eval_op after each epoch
        self._eval_op = set_default(self.config, "eval_hook/eval_op",
                                    "validation/eval_op")
        _eval_callbacks = set_default(self.config, "eval_hook/eval_callbacks",
                                      dict())
        if not isinstance(_eval_callbacks, dict):
            _eval_callbacks = {"cb": _eval_callbacks}
        eval_callbacks = dict()
        for k in _eval_callbacks:
            eval_callbacks[k] = _eval_callbacks[k]
        if hasattr(self, "callbacks"):
            iterator_callbacks = retrieve(self.callbacks,
                                          "eval_op",
                                          default=dict())
            for k in iterator_callbacks:
                import_path = get_str_from_obj(iterator_callbacks[k])
                set_value(self.config, "eval_hook/eval_callbacks/{}".format(k),
                          import_path)
                eval_callbacks[k] = import_path
        if hasattr(self.model, "callbacks"):
            model_callbacks = retrieve(self.model.callbacks,
                                       "eval_op",
                                       default=dict())
            for k in model_callbacks:
                import_path = get_str_from_obj(model_callbacks[k])
                set_value(self.config, "eval_hook/eval_callbacks/{}".format(k),
                          import_path)
                eval_callbacks[k] = import_path
        callback_handler = None
        if not self.config.get("test_mode", False):
            callback_handler = lambda results, paths: self.loghook(
                results=results,
                step=self.get_global_step(),
                paths=paths,
            )

        # offer option to run eval functor:
        # overwrite step op to only include the evaluation of the functor and
        # overwrite callbacks to only include the callbacks of the functor
        if self.config.get("test_mode",
                           False) and "eval_functor" in self.config:
            # offer option to use eval functor for evaluation
            eval_functor = get_obj_from_str(
                self.config["eval_functor"])(config=self.config)
            self.step_ops = lambda: {"eval_op": eval_functor}
            eval_callbacks = dict()
            if hasattr(eval_functor, "callbacks"):
                for k in eval_functor.callbacks:
                    eval_callbacks[k] = get_str_from_obj(
                        eval_functor.callbacks[k])
            set_value(self.config, "eval_hook/eval_callbacks", eval_callbacks)
        self.evalhook = TemplateEvalHook(
            datasets=self.datasets,
            step_getter=self.get_global_step,
            keypath=self._eval_op,
            config=self.config,
            callbacks=eval_callbacks,
            callback_handler=callback_handler,
        )
        self.epoch_hooks.append(self.evalhook)
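
Both logging integrations above are switched on purely through the nested integrations config; a hypothetical configuration enabling wandb and tensorboard at once (project and entity values are placeholders):

config = {
    "integrations": {
        "wandb": {
            "active": True,
            "handlers": ["scalars", "images"],
            "project": "my-project",   # placeholder
            "entity": "my-team",       # placeholder
        },
        "tensorboard": {
            "active": True,
            "handlers": ["scalars", "images", "figures"],
        },
    },
}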
Example #6
    def __init__(self,
                 config,
                 root,
                 model,
                 hook_freq=1,
                 num_epochs=100,
                 hooks=[],
                 bar_position=0):

        super().__init__(config,
                         root,
                         model,
                         hook_freq,
                         num_epochs,
                         hooks,
                         bar_position,
                         desc='Train')

        image_names = [
                'step_ops/0/generated',
                'step_ops/0/images',
                'step_ops/0/label',
                'step_ops/0/real_A',
                'step_ops/0/real_B'
                ]

        loss_names = ['G_VGG',
                      'G_GAN',
                      'G_GAN_Feat',
                      'D_real',
                      'D_fake',
                      'G_Warp',
                      'F_Flow',
                      'F_Warp',
                      'W']

        loss_names_T = ['G_T_GAN',
                        'G_T_GAN_Feat',
                        'D_T_real',
                        'D_T_fake',
                        'G_T_Warp']

        opt = model.opt

        prefix = 'step_ops/0/losses/per_frame/'
        scalar_names = [prefix + n for n in loss_names]

        prefix_t = 'step_ops/0/losses/temporal/'
        for s in range(opt.n_frames_G - 1):
            s = '{}/'.format(s)
            scalar_names += [prefix_t + s + n for n in loss_names_T]

        ImPlotHook = IntervalHook([PlotImageBatch(P.latest_eval,
                                                  image_names,
                                                  time_axis=1)],
                                  interval=10,
                                  max_interval=500,
                                  modify_each=10)

        checks = []
        models = [model.G, model.D, model.F]
        names = ['gen', 'discr', 'flow']
        for n, m in zip(names, models):
            checks += [PyCheckpointHook(P.checkpoints,
                                        m,
                                        'v2v_{}'.format(n),
                                        interval=config['ckpt_freq'])]

        self.hook_freq = 1
        self.hooks += [PrepareV2VDataHook()]
        self.hooks += checks
        self.hooks += [ToNumpyHook(),
                       ImPlotHook,
                       PyLoggingHook(scalar_keys=scalar_names,
                                     log_keys=scalar_names,
                                     root_path=P.latest_eval,
                                     interval=config['log_freq']),
                       IncreaseLearningRate(model, opt.niter)]
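
As a standalone check of the scalar-name construction above, assuming an illustrative opt.n_frames_G of 3, the temporal loss keys expand to one copy per frame offset:

prefix_t = 'step_ops/0/losses/temporal/'
loss_names_T = ['G_T_GAN', 'G_T_GAN_Feat', 'D_T_real', 'D_T_fake', 'G_T_Warp']
scalar_names_t = []
for s in range(3 - 1):  # opt.n_frames_G - 1
    scalar_names_t += [prefix_t + '{}/'.format(s) + n for n in loss_names_T]
print(scalar_names_t[0])   # step_ops/0/losses/temporal/0/G_T_GAN
print(scalar_names_t[-1])  # step_ops/0/losses/temporal/1/G_T_Warp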
Example #7
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        PM = PrintMemHook(self.model, self.get_global_step)
        self.hooks += [IntervalHook([PM], 1)]
Example #8
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # wrap save and restore into a LambdaCheckpointHook
        self.ckpthook = LambdaCheckpointHook(
            root_path=ProjectManager.checkpoints,
            global_step_getter=self.get_global_step,
            global_step_setter=self.set_global_step,
            save=self.save,
            restore=self.restore,
            interval=set_default(self.config, "ckpt_freq", None),
        )
        if not self.config.get("test_mode", False):
            # in training, execute train ops and add a logging hook for train
            # and validation batches
            self._train_ops = set_default(self.config, "train_ops",
                                          ["step_ops/train_op"])
            self._log_ops = set_default(self.config, "log_ops",
                                        ["step_ops/log_op"])
            # logging
            self.loghook = LoggingHook(
                paths=self._log_ops,
                root_path=ProjectManager.train,
                interval=1,
                name="train",
            )
            # wrap it in interval hook
            self.ihook = IntervalHook(
                [self.loghook],
                interval=set_default(self.config, "start_log_freq", 1),
                modify_each=1,
                max_interval=set_default(self.config, "log_freq", 1000),
                get_step=self.get_global_step,
            )
            self.hooks.append(self.ihook)
            # validation logging
            self._validation_log_ops = set_default(self.config,
                                                   "validation_log_ops",
                                                   ["validation_ops/log_op"])
            self._validation_root = os.path.join(ProjectManager.train,
                                                 "validation")
            os.makedirs(self._validation_root, exist_ok=True)
            # logging
            self.validation_loghook = LoggingHook(
                paths=self._validation_log_ops,
                root_path=self._validation_root,
                interval=1,
                name="validation",
            )
            self.hooks.append(self.validation_loghook)
            # write checkpoints after epoch or when interrupted
            self.hooks.append(self.ckpthook)
            wandb_logging = set_default(self.config, "integrations/wandb",
                                        False)
            if wandb_logging:
                import wandb
                from edflow.hooks.logging_hooks.wandb_handler import log_wandb

                os.environ["WANDB_RESUME"] = "allow"
                os.environ["WANDB_RUN_ID"] = ProjectManager.root.replace(
                    "/", "-")
                wandb.init(name=ProjectManager.root, config=self.config)
                self.loghook.handlers["scalars"].append(log_wandb)
                self.validation_loghook.handlers["scalars"].append(
                    lambda *args, **kwargs: log_wandb(
                        *args, **kwargs, prefix="validation"))
            tensorboardX_logging = set_default(self.config,
                                               "integrations/tensorboardX",
                                               False)
            if tensorboardX_logging:
                from tensorboardX import SummaryWriter
                from edflow.hooks.logging_hooks.tensorboardX_handler import (
                    log_tensorboard_config,
                    log_tensorboard_scalars,
                )

                self.tensorboardX_writer = SummaryWriter(ProjectManager.root)
                log_tensorboard_config(self.tensorboardX_writer, self.config,
                                       self.get_global_step())
                self.loghook.handlers["scalars"].append(
                    lambda *args, **kwargs: log_tensorboard_scalars(
                        self.tensorboardX_writer, *args, **kwargs))
                self.validation_loghook.handlers["scalars"].append(
                    lambda *args, **kwargs: log_tensorboard_scalars(
                        self.tensorboardX_writer,
                        *args,
                        **kwargs,
                        prefix="validation"))

        else:
            # evaluate
            self._eval_op = set_default(self.config, "eval_hook/eval_op",
                                        "step_ops/eval_op")
            self._eval_callbacks = set_default(self.config,
                                               "eval_hook/eval_callbacks",
                                               dict())
            if not isinstance(self._eval_callbacks, dict):
                self._eval_callbacks = {"cb": self._eval_callbacks}
            for k in self._eval_callbacks:
                self._eval_callbacks[k] = get_obj_from_str(
                    self._eval_callbacks[k])
            label_key = set_default(self.config, "eval_hook/label_key",
                                    "step_ops/eval_op/labels")
            self.evalhook = TemplateEvalHook(
                dataset=self.dataset,
                step_getter=self.get_global_step,
                keypath=self._eval_op,
                config=self.config,
                callbacks=self._eval_callbacks,
                labels_key=label_key,
            )
            self.hooks.append(self.evalhook)
            self._train_ops = []
            self._log_ops = []
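
For the evaluation branch of this iterator, a hypothetical test-mode config, written out as the nested structure the slash-separated set_default key paths resolve to (the callback import path is a placeholder):

config = {
    "test_mode": True,
    "eval_hook": {
        "eval_op": "step_ops/eval_op",
        "eval_callbacks": {"cb": "mypackage.callbacks.my_cb"},  # placeholder
        "label_key": "step_ops/eval_op/labels",
    },
}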