Esempio n. 1
0
    def test_ddpg_trainer(self):
        """Train a DDPGTrainer on GridworldContinuous samples for several epochs.

        The critic predictor is evaluated at the start of every epoch and once
        more after training; the final evaluation result is printed (labelled
        "gridworld MAE"). No assertion is made on the resulting value.
        """
        environment = GridworldContinuous()
        samples = environment.generate_samples(200000, 1.0)
        num_epochs = 3
        trainer = DDPGTrainer(
            self.get_ddpg_parameters(),
            environment.normalization,
            environment.normalization_action,
        )
        evaluator = GridworldDDPGEvaluator(environment, True)
        tdps = environment.preprocess_samples(samples, self.minibatch_size)

        for epoch_number in range(1, num_epochs + 1):
            print("On epoch {} of {}".format(epoch_number, num_epochs))
            # Evaluate the critic as it stands before this epoch's updates.
            evaluator.evaluate_critic(trainer.predictor())
            for tdp in tdps:
                # Build the positional sample list handed to trainer.train();
                # the slots passed as None are left unset by this test.
                states = tdp.states
                actions = tdp.actions
                rewards = tdp.rewards.flatten()
                next_states = tdp.next_states
                done = 1 - tdp.not_terminals.flatten()
                time_diff = [1] * len(tdp.states)
                trainer.train(
                    [
                        states,
                        actions,
                        rewards,
                        next_states,
                        None,
                        done,
                        None,
                        None,
                        time_diff,
                    ]
                )

        # Final critic evaluation after all epochs.
        error = evaluator.evaluate_critic(trainer.predictor())
        print("gridworld MAE: {0:.3f}".format(error))
Esempio n. 2
0
    def test_ddpg_trainer(self):
        """Train a DDPGTrainer on GridworldContinuous samples in a single pass.

        The critic predictor is evaluated before training. After one pass over
        the preprocessed batches, the actor predictor is evaluated and the
        critic is evaluated again; that last result is printed (labelled
        "gridworld MAE"). No assertion is made on the resulting value.
        """
        environment = GridworldContinuous()
        samples = environment.generate_samples(500000, 0.25)
        trainer = DDPGTrainer(
            self.get_ddpg_parameters(),
            environment.normalization,
            environment.normalization_action,
            environment.min_action_range,
            environment.max_action_range,
        )
        evaluator = GridworldDDPGEvaluator(
            environment,
            True,
            DISCOUNT,
            False,
            samples,
        )
        batches = environment.preprocess_samples(samples, self.minibatch_size)

        # Evaluate the critic once before any training step.
        evaluator.evaluate_critic(trainer.predictor(actor=False))

        for batch in batches:
            # Flatten rewards / not_terminals in place on the batch before
            # handing it to the trainer.
            batch.rewards = batch.rewards.flatten()
            batch.not_terminals = batch.not_terminals.flatten()
            trainer.train(batch)

        # Make sure the actor predictor works
        evaluator.evaluate_actor(trainer.predictor(actor=True))

        # Evaluate the critic predictor for correctness
        error = evaluator.evaluate_critic(trainer.predictor(actor=False))
        print("gridworld MAE: {0:.3f}".format(error))