Example 1
    def train_network(self, train_dataset, eval_dataset, epochs: int):
        """Train for `epochs` passes over `train_dataset`, running evaluation
        after each epoch when the trainer supports CPE."""
        num_batches = len(train_dataset) // self.minibatch_size
        logger.info(
            "Read in batch data set of size {} examples. Data split "
            "into {} batches of size {}.".format(
                len(train_dataset), num_batches, self.minibatch_size
            )
        )

        start_time = time.time()
        for epoch in range(epochs):
            train_dataset.reset_iterator()
            data_streamer = DataStreamer(train_dataset, pin_memory=self.trainer.use_gpu)
            # Bind to locals; the closure below then avoids capturing `self`.
            preprocess_handler = self.preprocess_handler
            dtype = self.trainer.dtype

            def preprocess(batch):
                tdp = preprocess_handler.preprocess(batch)
                tdp.set_type(dtype)
                return tdp

            feed_pages(
                data_streamer,
                len(train_dataset),
                epoch,
                self.minibatch_size,
                self.trainer.use_gpu,
                TrainingPageHandler(self.trainer),
                batch_preprocessor=preprocess,
            )

            if hasattr(self.trainer, "q_network_cpe"):
                # TODO: Add CPE support to DDPG/SAC, Parametric DQN (once moved to modular)
                eval_dataset.reset_iterator()
                data_streamer = DataStreamer(
                    eval_dataset, pin_memory=self.trainer.use_gpu
                )
                eval_page_handler = EvaluationPageHandler(
                    self.trainer, self.evaluator, self
                )
                feed_pages(
                    data_streamer,
                    len(eval_dataset),
                    epoch,
                    self.minibatch_size,
                    self.trainer.use_gpu,
                    eval_page_handler,
                    batch_preprocessor=preprocess,
                )

                SummaryWriterContext.increase_global_step()

        throughput = (len(train_dataset) * epochs) / (time.time() - start_time)
        logger.info(
            "Training finished. Processed ~{} examples/s.".format(round(throughput))
        )
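
Example 1 delegates the per-batch loop to `feed_pages`, which this listing does not show. As a rough mental model only (the body and the non-call-site parameter names below are assumptions, not the actual Horizon/ReAgent implementation), it drains the streamer, applies the batch preprocessor, and hands each page to the handler's `handle()` method, which Examples 3 and 4 below implement:

    # Sketch of the feed_pages contract as used above. Progress logging,
    # GPU transfer, and partial-batch handling are omitted; this body is
    # an assumption, not the library's implementation.
    def feed_pages(
        data_streamer,
        dataset_num_rows,
        epoch,
        minibatch_size,
        use_gpu,
        page_handler,
        batch_preprocessor=None,
    ):
        for batch in data_streamer:
            if batch_preprocessor is not None:
                batch = batch_preprocessor(batch)
            page_handler.handle(batch)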
Example 2
    def test_global_step(self):
        with TemporaryDirectory() as tmp_dir:
            writer = SummaryWriter(tmp_dir)
            writer.add_scalar = MagicMock()
            with summary_writer_context(writer):
                SummaryWriterContext.add_scalar("test", torch.ones(1))
                SummaryWriterContext.increase_global_step()
                SummaryWriterContext.add_scalar("test", torch.zeros(1))
            writer.add_scalar.assert_has_calls(
                [
                    call("test", torch.ones(1), global_step=0),
                    call("test", torch.zeros(1), global_step=1),
                ]
            )
            self.assertEqual(2, len(writer.add_scalar.mock_calls))
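
The test pins down the contract: scalars logged before the first `increase_global_step()` call land at step 0, each call advances the step for subsequent writes, and nothing is forwarded outside the `summary_writer_context` block. A minimal sketch that satisfies exactly this test follows; the single-writer class attribute and the reset-to-zero behavior are assumptions, and the real class carries more state and error handling.

    from contextlib import contextmanager

    class SummaryWriterContext:
        _writer = None
        _global_step = 0

        @classmethod
        def add_scalar(cls, tag, value):
            # Forward to the active writer, stamping the current global step.
            if cls._writer is not None:
                cls._writer.add_scalar(tag, value, global_step=cls._global_step)

        @classmethod
        def increase_global_step(cls):
            cls._global_step += 1

    @contextmanager
    def summary_writer_context(writer):
        # Activate `writer` for the duration of the block, starting at step 0.
        SummaryWriterContext._writer = writer
        SummaryWriterContext._global_step = 0
        try:
            yield
        finally:
            SummaryWriterContext._writer = None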
Example 3
    def handle(self, tdp: PreprocessedTrainingBatch) -> None:
        # Advance the TensorBoard global step once per batch, then train on it.
        SummaryWriterContext.increase_global_step()
        self.trainer_or_evaluator.train(tdp)
Example 4
    def handle(self, tdp: TrainingDataPage) -> None:
        # Same pattern as Example 3, typed against TrainingDataPage.
        SummaryWriterContext.increase_global_step()
        self.trainer.train(tdp)
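
Examples 3 and 4 are the page handlers that Example 1 passes to `feed_pages`: `handle()` runs once per preprocessed batch, so the global step advances per minibatch and any scalar the trainer logs inside `train()` attaches to the current step. A custom handler following the same pattern might look like the sketch below; the class is hypothetical, and the assumption that `train()` returns a loss value does not come from this listing.

    class LossLoggingPageHandler:
        def __init__(self, trainer):
            self.trainer = trainer

        def handle(self, tdp) -> None:
            # Same step-then-train ordering as Examples 3 and 4.
            SummaryWriterContext.increase_global_step()
            loss = self.trainer.train(tdp)  # assumes train() returns a loss
            SummaryWriterContext.add_scalar("training/loss", loss)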