Example No. 1
    def __init__(self, config, root, model, **kwargs):
        config["eval_forever"] = True
        kwargs["num_epochs"] = 1
        super().__init__(config, root, model, **kwargs)

        restorer = RestoreTFModelHook(
            variables=tf.global_variables(),
            checkpoint_path=ProjectManager().checkpoints,
            global_step_setter=self.set_global_step)
        self.restorer = restorer

        def has_eval(checkpoint):
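            # Note: returns True for checkpoints that do *not* yet have an
            # eval file, i.e. checkpoints that still need to be evaluated.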
            global_step = restorer.parse_global_step(checkpoint)
            eval_file = os.path.join(ProjectManager().latest_eval,
                                     "{:06}_metrics.npz".format(global_step))
            return not os.path.exists(eval_file)

        waiter = WaitForCheckpointHook(
            checkpoint_root=ProjectManager().checkpoints,
            filter_cond=has_eval,
            callback=restorer)
        evaluation = EvalHook(self)

        manager = KeepBestCheckpoints(
            checkpoint_root=ProjectManager().checkpoints,
            metric_template=os.path.join(ProjectManager().latest_eval,
                                         "{:06}_metrics.npz"),
            metric_key="-mAP",
            n_keep=2)
        self.hooks += [waiter, evaluation]
        if config.get("manage_checkpoints", False):
            self.hooks += [manager]
        self.initialize()
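
Both the has_eval filter and KeepBestCheckpoints above rely on each evaluation run writing a "{:06}_metrics.npz" file into latest_eval. A minimal sketch of the producing side, assuming the "-mAP" key is stored verbatim (negating a higher-is-better metric so that smaller values rank first):

import os
import numpy as np

def write_metrics(eval_root, global_step, mAP):
    # Hypothetical helper: stores the negated mAP under the "-mAP" key
    # that KeepBestCheckpoints reads, so that ranking ascending keeps
    # the checkpoints with the highest mAP.
    path = os.path.join(eval_root, "{:06}_metrics.npz".format(global_step))
    np.savez(path, **{"-mAP": -mAP})
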
Example No. 2
    def _init_step_ops(self):
        # additional inputs
        self.pid_placeholder = tf.placeholder(tf.string, shape=[None])

        # loss
        endpoints = self.model.embeddings
        dists = loss.cdist(endpoints['emb'],
                           endpoints['emb'],
                           metric=self.config.get("metric", "euclidean"))
        losses, train_top1, prec_at_k, _, neg_dists, pos_dists = (
            loss.LOSS_CHOICES["batch_hard"](
                dists,
                self.pid_placeholder,
                self.config.get("margin", "soft"),
                batch_precision_at_k=self.config.get("n_views", 4) - 1))

        # Compute the mean loss over the batch.
        loss_mean = tf.reduce_mean(losses)

        # train op
        learning_rate = self.config.get("learning_rate", 3e-4)
        self.logger.info(
            "Training with learning rate: {}".format(learning_rate))
        optimizer = tf.train.AdamOptimizer(learning_rate)
        with tf.control_dependencies(tf.get_collection(
                tf.GraphKeys.UPDATE_OPS)):
            train_op = optimizer.minimize(loss_mean)
        self._step_ops = train_op

        tolog = {
            "loss": loss_mean,
            "top1": train_top1,
            "prec@{}".format(self.config.get("n_views", 4) - 1): prec_at_k
        }
        loghook = LoggingHook(logs=tolog,
                              scalars=tolog,
                              images={"image": self.model.inputs["image"]},
                              root_path=ProjectManager().train,
                              interval=1)
        ckpt_hook = CheckpointHook(root_path=ProjectManager().checkpoints,
                                   variables=tf.global_variables(),
                                   modelname=self.model.name,
                                   step=self.get_global_step,
                                   interval=self.config.get("ckpt_freq", 1000),
                                   max_to_keep=None)
        self.hooks.append(ckpt_hook)
        ihook = IntervalHook([loghook],
                             interval=1,
                             modify_each=1,
                             max_interval=self.config.get("log_freq", 1000))
        self.hooks.append(ihook)
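
Here loss.cdist builds the all-pairs distance matrix between the batch embeddings, which the batch_hard loss then mines for the hardest positive and negative per anchor. For reference, a minimal NumPy sketch of the Euclidean case (the actual op is implemented in TensorFlow and supports other metrics):

import numpy as np

def cdist_np(a, b):
    # All-pairs Euclidean distances between rows of a (N, d) and b (M, d),
    # via ||x - y||^2 = ||x||^2 - 2 x.y + ||y||^2, clamped against small
    # negative values caused by floating-point cancellation.
    sq_a = np.sum(a ** 2, axis=1, keepdims=True)       # (N, 1)
    sq_b = np.sum(b ** 2, axis=1, keepdims=True).T     # (1, M)
    sq = np.maximum(sq_a - 2.0 * (a @ b.T) + sq_b, 0.0)
    return np.sqrt(sq)                                 # (N, M)
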
Example No. 3
    def __init__(self, config, root, model, **kwargs):
        config["eval_forever"] = True
        kwargs["num_epochs"] = 1
        super().__init__(config, root, model, **kwargs)

        restorer = RestoreTFModelHook(
            variables=tf.global_variables(),
            checkpoint_path=ProjectManager().checkpoints,
            global_step_setter=self.set_global_step)
        self.restorer = restorer
        waiter = WaitForCheckpointHook(
            checkpoint_root=ProjectManager().checkpoints, callback=restorer)
        evaluation = ExtractHook(self, model)

        self.hooks += [waiter, evaluation]
        self.initialize()
Example No. 4
    def __init__(self, config, root, model, **kwargs):
        unstackhook = Unstack(self)
        kwargs["hook_freq"] = 1
        kwargs["hooks"] = [unstackhook]
        if "num_epochs" not in kwargs:
            kwargs["num_epochs"] = config["num_epochs"]
        super().__init__(config, root, model, **kwargs)
        self._init_step_ops()
        restorer = RestoreTFModelHook(
            variables=self.model.variables,
            checkpoint_path=ProjectManager().checkpoints,
            global_step_setter=self.set_global_step)
        self.restorer = restorer
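
Unstack is installed as a hook before the base constructor runs. edflow hooks implement a small callback interface; a minimal sketch of a custom hook, assuming the standard edflow.hooks.hook.Hook base class (the real Unstack implementation is not part of this example):

from edflow.hooks.hook import Hook

class PrintStepHook(Hook):
    # Hypothetical example hook: logs the step index around each iteration.
    def before_step(self, step, fetches, feeds, batch):
        print("starting step", step)

    def after_step(self, step, last_results):
        print("finished step", step)
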
Example No. 5
import tensorflow as tf
import os

from edflow.iterators.tf_iterator import TFHookedModelIterator
from edflow.hooks.checkpoint_hooks.common import WaitForCheckpointHook
from edflow.hooks.checkpoint_hooks.tf_checkpoint_hook import (
    RestoreModelHook,
    RestoreTFModelHook,
    RestoreCurrentCheckpointHook,
)
from edflow.project_manager import ProjectManager

P = ProjectManager()


class TFBaseEvaluator(TFHookedModelIterator):
    def __init__(self, *args, desc="Eval", hook_freq=1, num_epochs=1, **kwargs):
        """
        The base evaluator restores the given checkpoint path if provided;
        otherwise it scans the checkpoint directory for the latest
        checkpoint and uses that.

        Parameters
        ----------
        desc : str
            A description for the evaluator, used during logging.
        hook_freq : int
            Frequency at which hooks are evaluated.
        num_epochs : int
            Number of times to iterate over the data.
        """
        kwargs.update({"desc": desc, "hook_freq": hook_freq, "num_epochs": num_epochs})
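
The snippet is cut off after updating the keyword arguments; presumably the constructor continues by forwarding them to the base iterator:

        # Presumed continuation (not shown in this snippet):
        super().__init__(*args, **kwargs)
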
Example No. 6
    def has_eval(checkpoint):
        global_step = restorer.parse_global_step(checkpoint)
        eval_file = os.path.join(ProjectManager().latest_eval,
                                 "{:06}_metrics.npz".format(global_step))
        return not os.path.exists(eval_file)
Example No. 7
    def __init__(self, iterator):
        self.iterator = iterator
        self.logger = get_logger(self, "latest_eval")
        self.global_step = self.iterator.get_global_step
        self.root = ProjectManager().latest_eval
        self.tb_saver = tf.summary.FileWriter(self.root)
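
With the tf.summary.FileWriter created above, scalar results can be pushed to TensorBoard by building summary protos directly; a minimal sketch using the TF1 API (the method name log_scalar is illustrative):

    def log_scalar(self, tag, value):
        # Hypothetical helper: writes one scalar via the FileWriter from
        # __init__; self.global_step is the step getter bound above.
        summary = tf.Summary(
            value=[tf.Summary.Value(tag=tag, simple_value=value)])
        self.tb_saver.add_summary(summary, self.global_step())
        self.tb_saver.flush()
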
Example No. 8
def _init_project(out_base_dir):
    """Sets up subdirectories given a base directory and copies all scripts."""

    P = ProjectManager(out_base_dir)

    return P.root
Example No. 9
def use_project(project_dir, postfix=None):
    """Must be called at the very beginning of a script."""
    P = ProjectManager(given_directory=project_dir, postfix=postfix)
    LogSingleton(P.root)
    return P
Example No. 10
def init_project(base_dir, code_root=".", postfix=None):
    """Must be called at the very beginning of a script."""
    P = ProjectManager(base_dir, code_root=code_root, postfix=postfix)
    LogSingleton(P.root)
    return P
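
A hedged usage sketch of the two entry points above; directory names are illustrative:

# Fresh run: creates a new project directory under "logs".
P = init_project("logs", code_root=".", postfix="myrun")
print(P.root)

# Later work inside an already existing project directory.
P = use_project("path/to/existing/project")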