예제 #1
0
    def Run(self) -> None:
        """Drive one epoch: run every batch through the model and log results.

        Batches come from ``self.batch_iterator.batches``. Each batch
        increments ``self.batch_count``, advances ``self.ctx.i`` by its graph
        count and adds its graph IDs to ``self.graph_ids``. The loop stops at
        the first batch whose ``end_of_batches`` flag is set. Batches with no
        graphs are not sent to the model.

        Every remaining batch goes through these steps:

        - It is run by ``self.model.RunBatch`` under a ``self.ctx.Profile``
          timer.
        - It is reported to ``self.logger.OnBatchEnd``.
        - It is accumulated into a ``batches.RollingResults``. The running
          metrics of that object are shown on the progress bar.

        Finally, the accumulated results are stored in ``self.results`` and
        passed to ``self.logger.OnEpochEnd``.
        """
        running_totals = batches.RollingResults()

        batch_stream = self.batch_iterator.batches
        for batch_num, batch in enumerate(batch_stream, start=1):
            # Bookkeeping happens for every batch, including the terminating
            # one, before deciding whether to stop or skip.
            self.batch_count += 1
            self.ctx.i += batch.graph_count
            for graph_id in batch.graph_ids:
                self.graph_ids.add(graph_id)

            if batch.end_of_batches:
                break  # The iterator has signalled there is nothing more.
            if not batch.graph_count:
                continue  # Empty batch: nothing for the model to process.

            # NOTE(review): the description callback reads `model_results`
            # lazily; presumably Profile invokes it on exit, after the name is
            # bound inside the `with` body. Confirm against Profile.
            with self.ctx.Profile(
                    3,
                    lambda t: (f"Batch {batch_num} with "
                               f"{batch.graph_count} graphs: "
                               f"{model_results}"),
            ) as timer:
                model_results = self.model.RunBatch(self.epoch_type, batch)

            self.logger.OnBatchEnd(
                run_id=self.model.run_id,
                epoch_type=self.epoch_type,
                epoch_num=self.model.epoch_num,
                batch_num=batch_num,
                timer=timer,
                data=batch,
                results=model_results,
            )
            # Weight each batch's contribution by its number of targets.
            running_totals.Update(
                batch, model_results, weight=model_results.target_count)
            self.ctx.bar.set_postfix(
                loss=running_totals.loss,
                acc=running_totals.accuracy,
                prec=running_totals.precision,
                rec=running_totals.recall,
            )

        self.results = epoch.Results.FromRollingResults(running_totals)
        self.logger.OnEpochEnd(
            self.model.run_id,
            self.epoch_type,
            self.model.epoch_num,
            self.results,
        )
예제 #2
0
def test_RollingResults_iteration_count(weight: float):
    """Test aggregation of model iteration count and convergence.

    The same single-graph batch results are fed to
    ``RollingResults.Update`` ten times with the given ``weight``. Those
    results have ``iteration_count=1`` and ``model_converged=True``.

    Every update reports identical values, so the test expects the rolling
    aggregates to equal those values whatever ``weight`` is.
    """
    rolling_results = batches.RollingResults()

    data = batches.Data(graph_ids=[1], data=None)
    # Targets and predictions are random: only the iteration count and the
    # convergence flag are asserted on below.
    results = batches.Results.Create(
        np.random.rand(1, 10),
        np.random.rand(1, 10),
        iteration_count=1,
        model_converged=True,
    )

    for _ in range(10):
        rolling_results.Update(data, results, weight=weight)

    assert rolling_results.iteration_count == 1
    assert rolling_results.model_converged == 1