def test_ddpg_trainer(self):
    environment = GridworldContinuous()
    samples = environment.generate_samples(200000, 1.0)
    epochs = 3
    trainer = DDPGTrainer(
        self.get_ddpg_parameters(),
        environment.normalization,
        environment.normalization_action,
    )
    evaluator = GridworldDDPGEvaluator(environment, True)
    tdps = environment.preprocess_samples(samples, self.minibatch_size)

    for epoch in range(epochs):
        print("On epoch {} of {}".format(epoch + 1, epochs))
        # Evaluate the critic before each training epoch
        critic_predictor = trainer.predictor()
        evaluator.evaluate_critic(critic_predictor)
        for tdp in tdps:
            training_samples = [
                tdp.states,
                tdp.actions,
                tdp.rewards.flatten(),
                tdp.next_states,
                None,
                1 - tdp.not_terminals.flatten(),  # done
                None,
                None,
                [1 for i in range(len(tdp.states))],  # time diff
            ]
            trainer.train(training_samples)

    # Final evaluation of the trained critic
    critic_predictor = trainer.predictor()
    error = evaluator.evaluate_critic(critic_predictor)
    print("gridworld MAE: {0:.3f}".format(error))
def test_ddpg_trainer(self):
    environment = GridworldContinuous()
    samples = environment.generate_samples(500000, 0.25)
    trainer = DDPGTrainer(
        self.get_ddpg_parameters(),
        environment.normalization,
        environment.normalization_action,
        environment.min_action_range,
        environment.max_action_range,
    )
    evaluator = GridworldDDPGEvaluator(environment, True, DISCOUNT, False, samples)
    tdps = environment.preprocess_samples(samples, self.minibatch_size)

    # Evaluate the critic before training
    critic_predictor = trainer.predictor(actor=False)
    evaluator.evaluate_critic(critic_predictor)

    for tdp in tdps:
        tdp.rewards = tdp.rewards.flatten()
        tdp.not_terminals = tdp.not_terminals.flatten()
        trainer.train(tdp)

    # Make sure the actor predictor works
    actor = trainer.predictor(actor=True)
    evaluator.evaluate_actor(actor)

    # Evaluate the critic predictor for correctness
    critic_predictor = trainer.predictor(actor=False)
    error = evaluator.evaluate_critic(critic_predictor)
    print("gridworld MAE: {0:.3f}".format(error))