def train_network(self, train_dataset, eval_dataset, epochs: int):
    num_batches = len(train_dataset) // self.minibatch_size
    logger.info(
        "Read in batch data set of size {} examples. Data split "
        "into {} batches of size {}.".format(
            len(train_dataset), num_batches, self.minibatch_size
        )
    )

    start_time = time.time()
    for epoch in range(epochs):
        # One full pass over the training data per epoch.
        train_dataset.reset_iterator()
        data_streamer = DataStreamer(train_dataset, pin_memory=self.trainer.use_gpu)
        preprocess_handler = self.preprocess_handler
        dtype = self.trainer.dtype

        def preprocess(batch):
            tdp = preprocess_handler.preprocess(batch)
            tdp.set_type(dtype)
            return tdp

        feed_pages(
            data_streamer,
            len(train_dataset),
            epoch,
            self.minibatch_size,
            self.trainer.use_gpu,
            TrainingPageHandler(self.trainer),
            batch_preprocessor=preprocess,
        )

        if hasattr(self.trainer, "q_network_cpe"):
            # Counterfactual policy evaluation (CPE) pass, only for trainers
            # that expose a CPE Q-network.
            # TODO: Add CPE support to DDPG/SAC, Parametric DQN (once moved
            # to modular)
            eval_dataset.reset_iterator()
            data_streamer = DataStreamer(
                eval_dataset, pin_memory=self.trainer.use_gpu
            )
            eval_page_handler = EvaluationPageHandler(
                self.trainer, self.evaluator, self
            )
            feed_pages(
                data_streamer,
                len(eval_dataset),
                epoch,
                self.minibatch_size,
                self.trainer.use_gpu,
                eval_page_handler,
                batch_preprocessor=preprocess,
            )

        SummaryWriterContext.increase_global_step()

    throughput = (len(train_dataset) * epochs) / (time.time() - start_time)
    logger.info(
        "Training finished. Processed ~{} examples / s.".format(round(throughput))
    )
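# For context, a minimal sketch of the contract train_network assumes of
# feed_pages: drain the streamer, run the optional batch_preprocessor, and
# hand each preprocessed batch to the page handler. The handle()/finish()
# hooks below are assumptions for illustration, not the real feed_pages
# implementation.
def feed_pages_sketch(
    data_streamer,
    num_examples,
    epoch,
    minibatch_size,
    use_gpu,  # device transfer / memory pinning omitted in this sketch
    page_handler,
    batch_preprocessor=None,
):
    processed = 0
    for batch in data_streamer:
        if batch_preprocessor is not None:
            batch = batch_preprocessor(batch)
        page_handler.handle(batch)  # hypothetical per-batch hook
        processed += minibatch_size
    page_handler.finish()  # hypothetical end-of-epoch hook
    logger.info(
        "Epoch {}: processed ~{}/{} examples".format(epoch, processed, num_examples)
    )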
def test_basic_two_workers(self):
    data = self.get_test_data(1000)
    batch_size = 100
    num_shards = 2
    reader = NpArrayReader(data, batch_size=batch_size, num_shards=num_shards)
    splits = [reader._get_split(reader.data, i, batch_size) for i in range(10)]
    streamer = DataStreamer(reader, num_workers=num_shards)
    for i, batch in enumerate(streamer):
        # With two workers the arrival order of batches is nondeterministic,
        # so accept a batch if it matches any of the ten expected splits.
        match = False
        for split in splits:
            try:
                self.assert_batch_equal(split, batch, 0, batch_size)
            except Exception:
                pass
            else:
                match = True
                break
        self.assertTrue(match)
    self.assertEqual(9, i)
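# The test above checks streamed batches against reader._get_split(...).
# A plausible reading of that helper, assumed here for illustration (the
# real NpArrayReader may slice differently): split i is the contiguous
# batch_size-sized slice starting at i * batch_size.
def get_split_sketch(data, split_index, batch_size):
    start = split_index * batch_size
    return data[start : start + batch_size]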
def test_drop_small_one_worker(self):
    data = self.get_test_data(999)
    batch_size = 100
    reader = NpArrayReader(data, batch_size=batch_size, num_shards=1)
    streamer = DataStreamer(reader, num_workers=1)
    for i, batch in enumerate(streamer):
        self.assert_batch_equal(data, batch, i * batch_size, batch_size)
    # The trailing 99-example remainder is dropped, so only 9 full batches
    # (indices 0-8) are streamed.
    self.assertEqual(8, i)
def test_basic(self):
    data = self.get_test_data(1000)
    batch_size = 100
    reader = NpArrayReader(data, batch_size=batch_size)
    streamer = DataStreamer(reader)
    for i, batch in enumerate(streamer):
        self.assert_batch_equal(data, batch, i * batch_size, batch_size)
    self.assertEqual(9, i)
def test_not_drop_small(self):
    data = self.get_test_data(999)
    batch_size = 100
    reader = NpArrayReader(data, batch_size=batch_size, drop_small=False)
    streamer = DataStreamer(reader)
    for i, batch in enumerate(streamer):
        # With drop_small=False the final batch holds the 99-example remainder.
        self.assert_batch_equal(
            data, batch, i * batch_size, batch_size if i != 9 else 99
        )
    self.assertEqual(9, i)
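# The tests above lean on two fixtures defined elsewhere on the test case.
# A minimal sketch of plausible implementations, assuming the data is a
# NumPy array with one feature column per example; the real get_test_data /
# assert_batch_equal may differ.
import numpy as np

def get_test_data(num_examples):
    # num_examples rows, one feature column, values 0..num_examples-1.
    return np.arange(num_examples, dtype=np.float32).reshape(num_examples, 1)

def assert_batch_equal(expected, batch, offset, batch_size):
    # The streamed batch should equal the contiguous slice
    # expected[offset : offset + batch_size].
    np.testing.assert_array_equal(expected[offset : offset + batch_size], batch)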