def test_interpretation_heatmap():
    """Train a DQN agent on 'maze-random-5x5-v0' and plot a heatmapped episode per epoch.

    The train/validate loop is driven by hand so that each of the model's
    callbacks is invoked explicitly around every phase (train begin, epoch
    begin, step end, epoch end, train end). After each epoch's validation
    pass an ``AgentInterpretationAlpha`` is built and the heatmapped episode
    for that epoch is plotted.
    """
    data = MDPDataBunch.from_env('maze-random-5x5-v0', render='human')
    model = DQN(data)
    learn = AgentLearner(data, model)

    epochs = 10

    callbacks = learn.model.callbacks  # type: Collection[LearnerCallback]
    # Hooks are invoked purely for their side effects, so use plain loops
    # instead of building throwaway lists with comprehensions.
    for cb in callbacks:
        cb.on_train_begin(learn=learn, n_epochs=epochs)

    for epoch in range(epochs):
        for cb in callbacks:
            cb.on_epoch_begin(epoch=epoch)

        learn.model.train()
        for element in learn.data.train_dl:
            learn.data.train_ds.actions = learn.predict(element)
            for cb in callbacks:
                cb.on_step_end(learn=learn)

        for cb in callbacks:
            cb.on_epoch_end()

        # For now we are going to avoid executing learner_callbacks here.
        learn.model.eval()
        for element in learn.data.valid_dl:
            learn.data.valid_ds.actions = learn.predict(element)

        # The previous ``if epoch % 1 == 0`` guard was always true, so the
        # heatmap is simply plotted once per epoch.
        interp = AgentInterpretationAlpha(learn)
        interp.plot_heatmapped_episode(epoch)

    for cb in callbacks:
        cb.on_train_end()
def test_interpretation_plot_sequence():
    """Train a DQN agent on 'maze-random-5x5-v0' and plot a training-set heatmap per epoch.

    The environment is created with ``max_steps=1000``. The training loop is
    driven by hand so that each of the model's callbacks is invoked
    explicitly (train begin, epoch begin, step end, epoch end, train end).
    After each epoch's pass over ``train_dl`` an ``AgentInterpretationAlpha``
    over ``DatasetType.Train`` is built and the heatmapped episode for that
    epoch is plotted. No validation pass is run in this test.
    """
    data = MDPDataBunch.from_env('maze-random-5x5-v0', render='human', max_steps=1000)
    model = DQN(data)
    learn = AgentLearner(data, model)

    epochs = 20

    callbacks = learn.model.callbacks  # type: Collection[LearnerCallback]
    # Hooks are invoked purely for their side effects, so use plain loops
    # instead of building throwaway lists with comprehensions.
    for cb in callbacks:
        cb.on_train_begin(learn=learn, n_epochs=epochs)

    for epoch in range(epochs):
        for cb in callbacks:
            cb.on_epoch_begin(epoch=epoch)

        learn.model.train()
        for element in learn.data.train_dl:
            learn.data.train_ds.actions = learn.predict(element)
            for cb in callbacks:
                cb.on_step_end(learn=learn)

        # Plot once per epoch, after all training steps. The previous
        # step counter was only read by a commented-out guard, so it was
        # removed along with that dead code.
        interp = AgentInterpretationAlpha(learn, ds_type=DatasetType.Train)
        interp.plot_heatmapped_episode(epoch)

        for cb in callbacks:
            cb.on_epoch_end()

    for cb in callbacks:
        cb.on_train_end()