Example #1
class MyConfig:
    """A configuration

    Attributes:
        w: An integer
    """

    __xpmid__ = "annotations.class_variable.config"

    x: Param[int]
    y: Param[float] = 2.3
    y2: Annotated[float, default(2.3)]
    z: Param[Optional[float]]
    t: Param[List[float]]
    w: Param[int]
    opt: Param[Optional[int]]
    path: Annotated[Path, pathgenerator("world")]
    option: Option[str]
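The class above only declares its parameters. Following the keyword-construction pattern used elsewhere on this page (e.g. RandomFold(seed=..., ...) in Example #10), such a config would be instantiated roughly as below; the concrete values are illustrative, and path is not passed because pathgenerator produces it.

# Illustrative instantiation: required Params are passed as keyword arguments,
# y and y2 fall back to their declared defaults, and path is generated.
cfg = MyConfig(
    x=1,
    z=0.5,
    t=[0.1, 0.2, 0.3],
    w=3,
    opt=None,
    option="some-string",
)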
Example #2
class ProgressingTask(Task):
    path: Annotated[Path, pathgenerator("progress.txt")]

    def execute(self):
        _progress = 0.0

        while True:
            time.sleep(1e-4)
            if self.path.exists():
                try:
                    _level, _progress, _desc = self.path.read_text().split(
                        " ", maxsplit=2)
                    _progress = float(_progress)
                    _level = int(_level)
                    progress(_progress, level=_level, desc=_desc or None)
                    if _progress == 1.0 and _level == 0:
                        break
                except Exception:
                    # Ignore partially written progress files
                    pass
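The loop above polls the generated progress.txt and, judging from the split(" ", maxsplit=2) call, expects each update as a single line of the form "<level> <progress> <description>"; it exits once it reads progress 1.0 at level 0. A minimal sketch of a driver feeding that file (the report helper and the progress_path name are illustrative, not part of the original code):

# Illustrative driver for the polling loop above; `report` is a hypothetical
# helper and progress_path is the file produced by pathgenerator("progress.txt").
from pathlib import Path


def report(path: Path, level: int, progress: float, desc: str = "") -> None:
    # One update per write: "<level> <progress> <description>"
    path.write_text(f"{level} {progress} {desc}")


# report(progress_path, 1, 0.5, "halfway through a sub-step")
# report(progress_path, 0, 1.0, "done")  # progress 1.0 at level 0 ends the loop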
Example #3
class NestedProgressingTask(Task):
    PROGRESSES: NestedTasks = (
        "Task 1",
        [
            ("Task 1.1", 2),
            ("Task 1.2", [("Task 1.2.1", 3), ("Task 1.2.2", 4)]),
            ("Task 1.3", 1),
        ],
    )

    path: Annotated[Path, pathgenerator("progress.txt")]

    def execute(self):
        self._execute(NestedProgressingTask.PROGRESSES)

    def wait(self):
        while not self.path.exists():
            time.sleep(1e-4)
        self.path.unlink()

    def _execute(self, tasks: NestedTasks):
        self.wait()
        name, subtasks = tasks
        if isinstance(subtasks, list):
            for subtask in tqdm(subtasks,
                                desc=name,
                                miniters=1,
                                mininterval=0):
                self._execute(subtask)
            self.wait()
        else:
            for _ in tqdm(range(subtasks),
                          desc=name,
                          miniters=1,
                          mininterval=0):
                self.wait()
            self.wait()
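In this example, wait() blocks until the generated progress.txt exists and then removes it, so every tick of the nested progress bars corresponds to one external creation of that file. A sketch of a driver for this handshake (the tick helper is illustrative, not part of the original code):

# Illustrative driver: each touch of the generated progress.txt lets the task
# advance by one tick, since wait() blocks on existence and then unlinks it.
import time
from pathlib import Path


def tick(path: Path, times: int = 1) -> None:
    for _ in range(times):
        path.touch()
        while path.exists():  # wait until the task has consumed the tick
            time.sleep(1e-4)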
Example #4
class ShuffledTrainingTripletsLines(Task):
    data: Param[ir.TrainingTripletsLines]
    path: Annotated[Path, pathgenerator("triplets.lst")]
    seed: Param[int]

    def config(self):
        data = self.data.copy()
        data.path = self.path
        return data

    def execute(self):
        # --- Shuffle using the shuf command with a seed
        r = RandomStream(self.seed)
        command = [
            "shuf",
            f"--random-source={r.filepath}",
            "-o",
            self.path,
            self.data.path,
        ]
        p = subprocess.Popen(command)
        with r:
            p.wait()
            assert p.returncode == 0
Example #5
class FaissIndex(Index):
    normalize: Param[bool]
    faiss_index: Annotated[Path, pathgenerator("faiss.dat")]
Example #6
class A(Config):
    path: Annotated[Path, pathgenerator("test.txt")]
Example #7
class ValidationListener(LearnerListener):
    """Learning validation early-stopping

    Computes a validation metric and stores the best result

    Attributes:
        warmup: Number of warmup epochs
        early_stop: Maximum number of epochs without improvement on validation
        validation: How to compute the validation metric
        validation_interval: Number of epochs between two validation computations
        metrics: Dictionary mapping each metric to record to a boolean
            indicating whether the best-performing checkpoint should be kept
            for that metric
    """

    metrics: Param[Dict[str, bool]] = {"map": True}
    dataset: Param[Adhoc]
    retriever: Param[Retriever]
    validation_interval: Param[int] = 1
    warmup: Param[int] = -1
    bestpath: Annotated[Path, pathgenerator("best")]
    info: Annotated[Path, pathgenerator("info.json")]
    early_stop: Param[int] = 20

    def initialize(self, key: str, learner: "Learner", context: TrainContext):
        super().initialize(key, learner, context)

        self.retriever.initialize()
        self.bestpath.mkdir(exist_ok=True, parents=True)

        # Checkpoint start
        try:
            with self.info.open("rt") as fp:
                self.top = json.load(fp)  # type: Dict[str, Dict[str, float]]
        except Exception:
            self.top = {}

    def update_metrics(self, metrics: Dict[str, float]):
        if self.top:
            # Report the best value so far under a dedicated "final" key
            for metric in self.metrics.keys():
                metrics[f"{self.key}/final/{metric}"] = self.top[metric][
                    "value"]

    def taskoutputs(self, learner: "Learner"):
        """Experimaestro outputs"""
        return {
            key: SerializedConfig(learner.scorer,
                                  SavedScorer(str(self.bestpath / key)))
            for key, store in self.metrics.items() if store
        }

    def __call__(self, state):
        if state.epoch % self.validation_interval == 0:
            # Compute validation metrics
            means, _ = evaluate(None, self.retriever, self.dataset,
                                list(self.metrics.keys()))

            for metric, keep in self.metrics.items():
                value = means[metric]

                self.context.writer.add_scalar(f"{self.key}/{metric}", value,
                                               state.epoch)

                # Update the top validation
                if state.epoch >= self.warmup:
                    topstate = self.top.get(metric, None)
                    if topstate is None or value > topstate["value"]:
                        # Save the new top JSON
                        self.top[metric] = {
                            "value": value,
                            "epoch": self.context.epoch
                        }

                        # Copy in corresponding directory
                        if keep:
                            self.context.copy(self.bestpath / metric)

            # Update information
            with self.info.open("wt") as fp:
                json.dump(self.top, fp)

        # Early stopping?
        if self.early_stop > 0 and self.top:
            epochs_since_imp = self.context.epoch - max(
                info["epoch"] for info in self.top.values())
            if epochs_since_imp >= self.early_stop:
                return False

        # No, proceed...
        return True
Example #8
class Learner(Task, EasyLogger):
    """Learns a model

    The learner task is generic, and takes two main arguments:
    (1) the scorer defines the model (e.g. DRMM), and
    (2) the trainer defines the loss (e.g. pointwise, pairwise, etc.)

    Attributes:

        max_epoch: Maximum training epoch
        checkpoint_interval: Number of epochs between each checkpoint
        scorer: The scorer to learn
        trainer: The trainer used to learn the parameters of the scorer
        listeners: learning process listeners (e.g. validation or other metrics)
        random: Random generator
    """

    # Training
    random: Param[Random]

    max_epoch: Param[int] = 1000
    trainer: Param[Trainer]
    scorer: Param[LearnableScorer]

    listeners: Param[Dict[str, LearnerListener]]

    # Checkpoints
    checkpoint_interval: Param[int] = 1

    # Paths
    logpath: Annotated[Path, pathgenerator("runs")]
    checkpointspath: Annotated[Path, pathgenerator("checkpoints")]

    def taskoutputs(self):
        return {
            "listeners": {
                key: listener.taskoutputs(self)
                for key, listener in self.listeners.items()
            }
        }

    # Training entry point
    def execute(self):
        logger = logging.getLogger()
        logger.setLevel(logging.INFO)
        for handler in logger.handlers:
            handler.setLevel(logging.INFO)

        self.only_cached = False

        # Initialize the scorer and trainer
        self.logger.info("Scorer initialization")
        self.scorer.initialize(self.random.state)

        # Initialize the listeners
        context = TrainContext(self.logpath, self.checkpointspath)
        for key, listener in self.listeners.items():
            listener.initialize(key, self, context)

        self.logger.info("Trainer initialization")
        self.trainer.initialize(self.random.state, self.scorer, context)

        self.logger.info("Starting to train")

        current = 0
        state = None
        with tqdm(self.trainer.iter_train(self.max_epoch),
                  total=self.max_epoch) as states:
            for state in states:
                # Report progress
                states.update(state.epoch - current)

                if state.epoch >= 0 and not self.only_cached:
                    message = f"epoch {state.epoch}"
                    if state.cached:
                        self.logger.debug(f"[train] [cached] {message}")
                    else:
                        self.logger.debug(f"[train] {message}")

                if state.epoch == -1:
                    continue

                if not state.cached and state.epoch % self.checkpoint_interval == 0:
                    # Save checkpoint if needed
                    context.save_checkpoint()

                # Call the listeners; stop when they all ask for it
                decisions = [
                    listener(state) for listener in self.listeners.values()
                ]
                if decisions and not any(decisions):
                    self.logger.warning(
                        "stopping after epoch %d (all listeners asked for it)",
                        state.epoch)
                    break

                # Stop if max epoch is reached
                if context.epoch >= self.max_epoch:
                    self.logger.warning(
                        "stopping after epoch %d (max_epoch)", self.max_epoch)
                    break

            # End of the learning process
            if state is not None and not state.cached:
                # Set the hyper-parameters
                metrics = {}
                for listener in self.listeners.values():
                    listener.update_metrics(metrics)
                context.writer.add_hparams(self.__tags__, metrics)
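Examples #7 and #8 fit together: Learner receives its listeners as a dictionary of LearnerListener configurations, and ValidationListener is one such listener. A sketch of the wiring, where random, trainer, scorer, val_dataset, retriever and the metric names stand for configurations and values chosen elsewhere (their constructors are not shown on this page):

# Illustrative wiring, following the keyword-construction and submit() pattern
# used in the other examples; the right-hand-side configurations are assumed.
validation = ValidationListener(
    dataset=val_dataset,
    retriever=retriever,
    validation_interval=2,
    early_stop=10,
    metrics={"map": True, "ndcg": False},
)

learner = Learner(
    random=random,
    max_epoch=200,
    trainer=trainer,
    scorer=scorer,
    listeners={"validation": validation},
)

learner.submit()  # scheduling, as in RandomFold.folds() in Example #10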
Example #9
class B(Config):
    p: Annotated[Path, pathgenerator("p.txt")]
Example #10
class RandomFold(Task):
    """Extracts a random subset of topics from a dataset

    Attributes:
        seed: Random seed used to compute the fold
        sizes: Relative size of each fold (normalized to sum to one)
        dataset: The Adhoc dataset from which a fold is extracted
        fold: Which fold to take
    """

    seed: Param[int]
    sizes: Param[List[float]]
    dataset: Param[Adhoc]
    fold: Param[int]
    exclude: Param[Optional[AdhocTopics]]

    assessments: Annotated[Path, pathgenerator("assessments.tsv")]
    topics: Annotated[Path, pathgenerator("topics.tsv")]

    def __validate__(self):
        assert self.fold < len(self.sizes)

    @staticmethod
    def folds(
        seed: int,
        sizes: List[float],
        dataset: Param[Adhoc],
        exclude: Param[AdhocTopics] = None,
        submit=True,
    ):
        """Creates folds

        Parameters:

        - submit: if true (default), submits the fold tasks to experimaestro
        """

        folds = []
        for ix in range(len(sizes)):
            fold = RandomFold(seed=seed,
                              sizes=sizes,
                              dataset=dataset,
                              exclude=exclude,
                              fold=ix)
            if submit:
                fold = fold.submit()
            folds.append(fold)

        return folds

    def config(self) -> Adhoc:
        return Adhoc(
            topics=CSVAdhocTopics(path=self.topics),
            assessments=TrecAdhocAssessments(path=self.assessments),
            documents=self.dataset.documents,
        )

    def execute(self):
        import numpy as np

        # Get topics
        badids = (set(topic.qid for topic in self.exclude.iter())
                  if self.exclude else set())
        topics = [
            topic for topic in self.dataset.topics.iter()
            if topic.qid not in badids
        ]
        random = np.random.RandomState(self.seed)
        random.shuffle(topics)

        # Get the fold
        sizes = np.array([0] + self.sizes)
        sizes = np.round(len(topics) * sizes / sizes.sum())
        assert sizes[self.fold + 1] > 0

        indices = sizes.cumsum().astype(int)
        topics = topics[indices[self.fold]:indices[self.fold + 1]]

        # Write topics and assessments
        ids = set()
        self.topics.parent.mkdir(parents=True, exist_ok=True)
        with self.topics.open("wt") as fp:
            for topic in topics:
                ids.add(topic.qid)
                # FIXME: hardcoded title...
                fp.write(f"""{topic.qid}\t{topic.title}\n""")

        with self.assessments.open("wt") as fp:
            for qrels in self.dataset.assessments.iter():
                if qrels.qid in ids:
                    for qrel in qrels.assessments:
                        fp.write(
                            f"""{qrels.qid} 0 {qrel.docno} {qrel.rel}\n""")