Example #1
import datetime

from sacred.metrics_logger import ScalarMetricLogEntry


def logged_metrics():
    # A flat, interleaved stream of scalar metric entries
    # (name, step, timestamp, value).
    return [
        ScalarMetricLogEntry("training.loss", 10, datetime.datetime.utcnow(), 1),
        ScalarMetricLogEntry("training.loss", 20, datetime.datetime.utcnow(), 2),
        ScalarMetricLogEntry("training.loss", 30, datetime.datetime.utcnow(), 3),
        ScalarMetricLogEntry("training.accuracy", 10, datetime.datetime.utcnow(), 100),
        ScalarMetricLogEntry("training.accuracy", 20, datetime.datetime.utcnow(), 200),
        ScalarMetricLogEntry("training.accuracy", 30, datetime.datetime.utcnow(), 300),
        ScalarMetricLogEntry("training.loss", 40, datetime.datetime.utcnow(), 10),
        ScalarMetricLogEntry("training.loss", 50, datetime.datetime.utcnow(), 20),
        ScalarMetricLogEntry("training.loss", 60, datetime.datetime.utcnow(), 30),
    ]
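For reference, these entries can be grouped per metric with sacred's linearize_metrics, which returns a dict of parallel "steps"/"values"/"timestamps" lists per metric name. A minimal usage sketch; the printed values assume entry order is preserved within each metric, as the test in Example #2 below expects:

from sacred.metrics_logger import linearize_metrics

linearized = linearize_metrics(logged_metrics())
print(linearized["training.loss"]["steps"])       # [10, 20, 30, 40, 50, 60]
print(linearized["training.loss"]["values"])      # [1, 2, 3, 10, 20, 30]
print(linearized["training.accuracy"]["values"])  # [100, 200, 300]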
Example #2
import datetime

from sacred.metrics_logger import ScalarMetricLogEntry, linearize_metrics


def test_linearize_metrics():
    entries = [
        ScalarMetricLogEntry("training.loss", 10, datetime.datetime.utcnow(),
                             100),
        ScalarMetricLogEntry("training.accuracy", 5,
                             datetime.datetime.utcnow(), 50),
        ScalarMetricLogEntry("training.loss", 20, datetime.datetime.utcnow(),
                             200),
        ScalarMetricLogEntry("training.accuracy", 10,
                             datetime.datetime.utcnow(), 100),
        ScalarMetricLogEntry("training.accuracy", 15,
                             datetime.datetime.utcnow(), 150),
        ScalarMetricLogEntry("training.accuracy", 30,
                             datetime.datetime.utcnow(), 300),
    ]
    linearized = linearize_metrics(entries)
    assert isinstance(linearized, dict)
    assert len(linearized) == 2
    assert "training.loss" in linearized
    assert "training.accuracy" in linearized
    assert len(linearized["training.loss"]["steps"]) == 2
    assert len(linearized["training.loss"]["values"]) == 2
    assert len(linearized["training.loss"]["timestamps"]) == 2
    assert len(linearized["training.accuracy"]["steps"]) == 4
    assert len(linearized["training.accuracy"]["values"]) == 4
    assert len(linearized["training.accuracy"]["timestamps"]) == 4
    assert linearized["training.accuracy"]["steps"] == [5, 10, 15, 30]
    assert linearized["training.accuracy"]["values"] == [50, 100, 150, 300]
    assert linearized["training.loss"]["steps"] == [10, 20]
    assert linearized["training.loss"]["values"] == [100, 200]
Example #3
    def export(self, observer, base_dir, remove_sources=False,
               overwrite=None):
        """
        Exports the file log into another observer.
        Requires sacred to be installed.
        Args:
            observer: Observer to export to
            base_dir: root path to sources
            remove_sources: if sources are too complicated to match
            overwrite: whether to overwrite an experiment
        """
        from sacred.metrics_logger import ScalarMetricLogEntry, linearize_metrics

        # Start simulation of run
        experiment = self.experiment.copy()
        experiment['base_dir'] = base_dir
        # FIXME
        experiment['sources'] = (
            [] if remove_sources
            else update_source_path_prefix(base_dir, experiment['sources'])
        )
        try:
            observer.started_event(
                experiment,
                self.command,
                self.host,
                datetime.datetime.fromisoformat(self.start_time),
                self.config,
                self.meta,
                _id=overwrite
            )
        except FileNotFoundError as e:
            raise FileNotFoundError(
                "The sources are incorrect. Try fixing the paths or use "
                f"`remove_sources=True`. Original error: {e}"
            ) from e

        # Add artifacts
        for artifact_name in self.artifacts:
            observer.artifact_event(
                name=artifact_name,
                filename=(self.path / artifact_name)
            )

        # Add resources
        for resource in self.resources:
            observer.resource_event(resource[0])

        # Add metrics
        size_metrics = {}
        # If overwrite, get the already added metrics.
        # FIXME: issue if steps are not increasing
        if overwrite is not None:
            metrics = observer.metrics.find({"run_id": overwrite})
            for metric in metrics:
                size_metrics[metric['name']] = len(metric['steps'])

        log_metrics = []
        for metric_name, metric in self.metrics.items():
            # Skip entries already present on the overwritten run.
            offset = size_metrics.get(metric_name, 0)
            steps = metric['steps'][offset:]
            timestamps = metric['timestamps'][offset:]
            values = metric['values'][offset:]
            for step, timestamp, value in zip(steps, timestamps, values):
                metric_log_entry = ScalarMetricLogEntry(metric_name, step,
                                                        datetime.datetime.fromisoformat(timestamp), value)
                log_metrics.append(metric_log_entry)
        observer.log_metrics(linearize_metrics(log_metrics), {})

        observer.heartbeat_event(
            info=self.info if 'info' in self.run else None,
            captured_out=self.cout,
            beat_time=datetime.datetime.fromisoformat(self.heartbeat),
            result=self.result
        )

        # End simulation
        if self.status != "RUNNING":
            stop_time = datetime.datetime.fromisoformat(self.stop_time)

            if self.status in ["COMPLETED", "RUNNING"]:  # If still running we force it as a finished experiment
                observer.completed_event(stop_time, self.result)
            elif self.status == "INTERRUPTED":
                observer.interrupted_event(stop_time, 'INTERRUPTED')
            elif self.status == "FAILED":
                observer.failed_event(stop_time, self.fail_trace)
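For illustration, a hedged usage sketch of export. FileLogRun is an assumed name for the class that defines this method (it is not shown in the snippet); the MongoObserver arguments follow sacred's documented constructor:

from sacred.observers import MongoObserver

# Assumed: FileLogRun loads a run directory written by a file observer.
run = FileLogRun("experiments/my_run/1")
observer = MongoObserver(url="localhost:27017", db_name="sacred")

# Replay the stored run into MongoDB. Passing an existing run id as
# `overwrite` appends only the metric entries logged since that export.
run.export(observer, base_dir="/home/user/project", remove_sources=False)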