def Run(self) -> None:
  """Drive the model over every batch of the epoch, then record epoch results.

  For each batch yielded by ``self.batch_iterator.batches``:
    * ``self.batch_count`` is incremented and ``self.ctx.i`` is advanced by
      the batch's ``graph_count``. The batch's graph IDs are added to
      ``self.graph_ids``. All of this happens before the end-of-batches
      check, so the terminating batch is counted as well.
    * Iteration stops at the first batch whose ``end_of_batches`` is set.
    * Batches with a falsy ``graph_count`` are skipped.
    * Otherwise the batch is run through ``self.model.RunBatch`` inside
      ``self.ctx.Profile``. It is then reported via
      ``self.logger.OnBatchEnd`` and folded into a rolling aggregate,
      weighted by ``target_count``. The rolling loss, accuracy, precision
      and recall are shown on ``self.ctx.bar``.

  Afterwards ``self.results`` is built from the rolling aggregate and handed
  to ``self.logger.OnEpochEnd``.
  """
  running = batches.RollingResults()

  for batch_num, batch in enumerate(self.batch_iterator.batches, start=1):
    self.batch_count += 1
    self.ctx.i += batch.graph_count

    # Remember every graph ID seen this epoch.
    for graph_id in batch.graph_ids:
      self.graph_ids.add(graph_id)

    # The iterator signals exhaustion with a sentinel batch.
    if batch.end_of_batches:
      break
    # Nothing to run for a batch with no graphs.
    if not batch.graph_count:
      continue

    # The message lambda is only evaluated by the profiler after the body has
    # run, so it sees this iteration's ``batch_results``.
    with self.ctx.Profile(
        3,
        lambda t: (
            f"Batch {batch_num} with {batch.graph_count} graphs: "
            f"{batch_results}"
        ),
    ) as batch_timer:
      batch_results = self.model.RunBatch(self.epoch_type, batch)

    self.logger.OnBatchEnd(
        run_id=self.model.run_id,
        epoch_type=self.epoch_type,
        epoch_num=self.model.epoch_num,
        batch_num=batch_num,
        timer=batch_timer,
        data=batch,
        results=batch_results,
    )

    # Fold this batch into the running aggregate and refresh the progress bar.
    running.Update(batch, batch_results, weight=batch_results.target_count)
    self.ctx.bar.set_postfix(
        loss=running.loss,
        acc=running.accuracy,
        prec=running.precision,
        rec=running.recall,
    )

  self.results = epoch.Results.FromRollingResults(running)
  self.logger.OnEpochEnd(
      self.model.run_id, self.epoch_type, self.model.epoch_num, self.results
  )
def test_RollingResults_iteration_count(weight: float):
  """Check aggregation of model iteration count and convergence.

  The same single-graph results (one iteration, converged) are folded into a
  RollingResults ten times with the given ``weight``. Afterwards both the
  aggregated ``iteration_count`` and ``model_converged`` must equal 1.
  """
  aggregate = batches.RollingResults()
  one_graph_batch = batches.Data(graph_ids=[1], data=None)
  single_step_results = batches.Results.Create(
      np.random.rand(1, 10),
      np.random.rand(1, 10),
      iteration_count=1,
      model_converged=True,
  )

  for _step in range(10):
    aggregate.Update(one_graph_batch, single_step_results, weight=weight)

  assert aggregate.iteration_count == 1
  assert aggregate.model_converged == 1