Example #1
    def train_network(self, train_dataset, eval_dataset, epochs: int):
        num_batches = int(len(train_dataset) / self.minibatch_size)
        logger.info(
            "Read in batch data set of size {} examples. Data split "
            "into {} batches of size {}.".format(
                len(train_dataset), num_batches, self.minibatch_size
            )
        )

        start_time = time.time()
        for epoch in range(epochs):
            train_dataset.reset_iterator()
            data_streamer = DataStreamer(train_dataset, pin_memory=self.trainer.use_gpu)
            # Bind the preprocessing handler and dtype locally for use inside the
            # per-epoch preprocessing closure.
            preprocess_handler = self.preprocess_handler
            dtype = self.trainer.dtype

            def preprocess(batch):
                # Convert a raw batch into a typed training page before it is
                # handed to the page handler.
                tdp = preprocess_handler.preprocess(batch)
                tdp.set_type(dtype)
                return tdp

            feed_pages(
                data_streamer,
                len(train_dataset),
                epoch,
                self.minibatch_size,
                self.trainer.use_gpu,
                TrainingPageHandler(self.trainer),
                batch_preprocessor=preprocess,
            )

            # Run counterfactual policy evaluation (CPE) on the eval dataset when
            # the trainer supports it.
            if hasattr(self.trainer, "q_network_cpe"):
                # TODO: Add CPE support to DDPG/SAC, Parametric DQN (once moved to modular)
                eval_dataset.reset_iterator()
                data_streamer = DataStreamer(
                    eval_dataset, pin_memory=self.trainer.use_gpu
                )
                eval_page_handler = EvaluationPageHandler(
                    self.trainer, self.evaluator, self
                )
                feed_pages(
                    data_streamer,
                    len(eval_dataset),
                    epoch,
                    self.minibatch_size,
                    self.trainer.use_gpu,
                    eval_page_handler,
                    batch_preprocessor=preprocess,
                )

                # Advance the global step used by the summary writer after each
                # evaluation pass.
                SummaryWriterContext.increase_global_step()

        throughput = (len(train_dataset) * epochs) / (time.time() - start_time)
        logger.info(
            "Training finished. Processed ~{} examples / s.".format(round(throughput))
        )
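
# --- Illustrative sketch (not part of the code above) ---
# The epoch loop in train_network wires together a dataset exposing
# reset_iterator(), a DataStreamer, a per-epoch preprocessing closure, and a
# page handler driven by feed_pages. The toy code below mirrors only that
# shape so the data flow is easy to follow; ToyDataset, toy_preprocess, and
# run_epoch are hypothetical stand-ins, not library APIs.


class ToyDataset:
    def __init__(self, rows):
        self.rows = rows
        self._pos = 0

    def __len__(self):
        return len(self.rows)

    def reset_iterator(self):
        # Mirrors the reset done at the top of each epoch above.
        self._pos = 0

    def next_batch(self, size):
        batch = self.rows[self._pos:self._pos + size]
        self._pos += size
        return batch


def toy_preprocess(batch):
    # Stand-in for preprocess_handler.preprocess(...) followed by set_type(dtype).
    return [float(x) for x in batch]


def run_epoch(dataset, minibatch_size, handle_batch):
    # Stand-in for feed_pages: stream whole minibatches through the
    # preprocessor and into the handler.
    dataset.reset_iterator()
    for _ in range(len(dataset) // minibatch_size):
        handle_batch(toy_preprocess(dataset.next_batch(minibatch_size)))


if __name__ == "__main__":
    dataset = ToyDataset(list(range(1000)))
    seen = []
    run_epoch(dataset, minibatch_size=100, handle_batch=seen.append)
    print("Handled {} batches of {} examples.".format(len(seen), len(seen[0])))
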
    def test_basic_two_workers(self):
        data = self.get_test_data(1000)
        batch_size = 100
        num_shards = 2
        reader = NpArrayReader(data,
                               batch_size=batch_size,
                               num_shards=num_shards)

        splits = [
            reader._get_split(reader.data, i, batch_size) for i in range(10)
        ]

        streamer = DataStreamer(reader, num_workers=num_shards)
        # With two workers, batches may arrive in any order, so each streamed
        # batch only needs to match one of the expected splits.
        for i, batch in enumerate(streamer):
            match = False
            for split in splits:
                try:
                    self.assert_batch_equal(split, batch, 0, batch_size)
                except Exception:
                    pass
                else:
                    match = True
                    break
            self.assertTrue(match)
        self.assertEqual(9, i)

    def test_drop_small_one_worker(self):
        data = self.get_test_data(999)
        batch_size = 100
        reader = NpArrayReader(data, batch_size=batch_size, num_shards=1)
        streamer = DataStreamer(reader, num_workers=1)
        for i, batch in enumerate(streamer):
            self.assert_batch_equal(data, batch, i * batch_size, batch_size)
        self.assertEqual(8, i)

    def test_basic(self):
        data = self.get_test_data(1000)
        batch_size = 100
        reader = NpArrayReader(data, batch_size=batch_size)
        streamer = DataStreamer(reader)
        for i, batch in enumerate(streamer):
            self.assert_batch_equal(data, batch, i * batch_size, batch_size)
        self.assertEqual(9, i)

    def test_not_drop_small(self):
        data = self.get_test_data(999)
        batch_size = 100
        reader = NpArrayReader(data, batch_size=batch_size, drop_small=False)
        streamer = DataStreamer(reader)
        for i, batch in enumerate(streamer):
            self.assert_batch_equal(data, batch, i * batch_size,
                                    batch_size if i != 9 else 99)
        self.assertEqual(9, i)
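
# --- Illustrative sketch (not part of the tests above) ---
# The asserted loop indices follow from simple batching arithmetic: 1000
# examples at batch size 100 give 10 batches (final index 9); 999 examples
# give 9 full batches when a small trailing batch is dropped, and 10 batches
# (the last of size 99) when it is kept. expected_batches below is a
# hypothetical helper, not part of NpArrayReader.
def expected_batches(num_examples, batch_size, drop_small=True):
    full, remainder = divmod(num_examples, batch_size)
    return full if (drop_small or remainder == 0) else full + 1


assert expected_batches(1000, 100) == 10                    # test_basic, test_basic_two_workers
assert expected_batches(999, 100) == 9                      # test_drop_small_one_worker
assert expected_batches(999, 100, drop_small=False) == 10   # test_not_drop_small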