def train_network(self, train_dataset, eval_dataset, epochs: int):
    """Train ``self.trainer`` for ``epochs`` full passes over ``train_dataset``,
    running an evaluation pass on ``eval_dataset`` after each epoch when the
    trainer supports CPE.

    NOTE(review): assumes both datasets expose ``__len__()`` and
    ``reset_iterator()`` -- confirm against the dataset interface.
    """
    # num_batches is informational only (used in the log line below);
    # feed_pages does the actual batching.
    num_batches = int(len(train_dataset) / self.minibatch_size)
    logger.info(
        "Read in batch data set of size {} examples. Data split "
        "into {} batches of size {}.".format(
            len(train_dataset), num_batches, self.minibatch_size
        )
    )

    start_time = time.time()
    for epoch in range(epochs):
        # Rewind so each epoch sees the full dataset again.
        train_dataset.reset_iterator()
        data_streamer = DataStreamer(train_dataset, pin_memory=self.trainer.use_gpu)

        # Bind to locals so the closure below captures plain values, not `self`.
        preprocess_handler = self.preprocess_handler
        dtype = self.trainer.dtype

        def preprocess(batch):
            # Preprocess the raw batch, then cast it to the trainer's dtype.
            tdp = preprocess_handler.preprocess(batch)
            tdp.set_type(dtype)
            return tdp

        feed_pages(
            data_streamer,
            len(train_dataset),
            epoch,
            self.minibatch_size,
            self.trainer.use_gpu,
            TrainingPageHandler(self.trainer),
            batch_preprocessor=preprocess,
        )

        # Run the evaluation pass only for trainers that expose a CPE network.
        if hasattr(self.trainer, "q_network_cpe"):
            # TODO: Add CPE support to DDPG/SAC, Parametric DQN (once moved to modular)
            eval_dataset.reset_iterator()
            data_streamer = DataStreamer(
                eval_dataset, pin_memory=self.trainer.use_gpu
            )
            eval_page_handler = EvaluationPageHandler(
                self.trainer, self.evaluator, self
            )
            feed_pages(
                data_streamer,
                len(eval_dataset),
                epoch,
                self.minibatch_size,
                self.trainer.use_gpu,
                eval_page_handler,
                batch_preprocessor=preprocess,
            )

        # Advance the global step once per epoch so metrics logged through
        # SummaryWriterContext line up epoch-by-epoch.
        SummaryWriterContext.increase_global_step()

    through_put = (len(train_dataset) * epochs) / (time.time() - start_time)
    logger.info(
        "Training finished. Processed ~{} examples / s.".format(round(through_put))
    )
def test_global_step(self):
    """Scalars logged before/after increase_global_step() get steps 0 and 1."""
    with TemporaryDirectory() as log_dir:
        summary_writer = SummaryWriter(log_dir)
        summary_writer.add_scalar = MagicMock()
        with summary_writer_context(summary_writer):
            SummaryWriterContext.add_scalar("test", torch.ones(1))
            SummaryWriterContext.increase_global_step()
            SummaryWriterContext.add_scalar("test", torch.zeros(1))
        expected_calls = [
            call("test", torch.ones(1), global_step=0),
            call("test", torch.zeros(1), global_step=1),
        ]
        summary_writer.add_scalar.assert_has_calls(expected_calls)
        # Exactly the two calls above -- nothing extra was logged.
        self.assertEqual(2, len(summary_writer.add_scalar.mock_calls))
def test_global_step(self):
    """The global step advances by one between the two logged scalars."""
    with TemporaryDirectory() as tmp:
        w = SummaryWriter(tmp)
        w.add_scalar = MagicMock()
        with summary_writer_context(w):
            for step, value in enumerate((torch.ones(1), torch.zeros(1))):
                if step:
                    # Bump the step before logging the second scalar.
                    SummaryWriterContext.increase_global_step()
                SummaryWriterContext.add_scalar("test", value)
        w.add_scalar.assert_has_calls(
            [
                call("test", torch.ones(1), global_step=0),
                call("test", torch.zeros(1), global_step=1),
            ]
        )
        # No additional add_scalar calls beyond the two asserted above.
        self.assertEqual(2, len(w.add_scalar.mock_calls))
def handle(self, tdp: PreprocessedTrainingBatch) -> None:
    """Advance the global step counter, then process one batch.

    The step is bumped first so anything logged during train() is
    attributed to the new global step.
    """
    SummaryWriterContext.increase_global_step()
    consumer = self.trainer_or_evaluator
    consumer.train(tdp)
def handle(self, tdp: TrainingDataPage) -> None:
    """Advance the global step counter, then train on one data page.

    Bumping the step first means metrics logged inside train() land on
    the new global step.
    """
    SummaryWriterContext.increase_global_step()
    trainer = self.trainer
    trainer.train(tdp)