Ejemplo n.º 1
0
    def test_len(self):
        """Check the reported length of the replay buffer as it is populated.

        A new buffer built with ``buffer_size=2`` must be empty, both through
        ``len()`` and through iteration. After ``populate(1)`` the test
        expects a length of 1, and after a further ``populate(2)`` a length
        of 2.
        """
        replay = experience.ExperienceReplayBuffer(self.source, buffer_size=2)

        # Nothing has been populated yet.
        self.assertEqual(0, len(replay))
        self.assertEqual([], list(replay))

        # (number of entries to populate, length expected afterwards)
        steps = ((1, 1), (2, 2))
        for to_add, expected_len in steps:
            replay.populate(to_add)
            self.assertEqual(expected_len, len(replay))
Ejemplo n.º 2
0
    def test_sample(self):
        """Check the size and origin of batches returned by ``sample``.

        With 10 populated entries, ``sample(4)`` must return 4 items, each of
        which is (by object identity) one of the objects yielded when
        iterating the buffer. Asking for 20 items is expected to return only
        10.
        """
        replay = experience.ExperienceReplayBuffer(self.source, buffer_size=10)
        replay.populate(10)

        drawn = replay.sample(4)
        self.assertEqual(4, len(drawn))

        # Compare by id(): sampled items must be the very objects the buffer
        # yields on iteration, not merely equal ones.
        stored_ids = {id(entry) for entry in replay}
        found = [id(item) in stored_ids for item in drawn]
        self.assertTrue(all(found))

        # Requesting more than was populated.
        drawn = replay.sample(20)
        self.assertEqual(10, len(drawn))
Ejemplo n.º 3
0
    def test_batches(self):
        """Check how many batches ``batches(batch_size=2)`` yields.

        After populating 10 entries the test expects 5 batches, the first of
        length 2. After one more ``populate(1)`` it still expects 5 batches,
        and after a second ``populate(1)`` it expects 6.
        """
        replay = experience.ExperienceReplayBuffer(self.source)
        replay.populate(10)

        chunks = list(replay.batches(batch_size=2))
        self.assertEqual(5, len(chunks))
        self.assertEqual(2, len(chunks[0]))

        # Add one entry at a time and re-count the batches after each step.
        for expected_batches in (5, 6):
            replay.populate(1)
            chunks = list(replay.batches(batch_size=2))
            self.assertEqual(expected_batches, len(chunks))
Ejemplo n.º 4
0
    # Mean-squared-error loss with averaging disabled (size_average=False).
    # NOTE(review): `size_average` is deprecated in newer PyTorch releases in
    # favor of `reduction="sum"` -- confirm the targeted torch version before
    # changing it.
    loss_fn = nn.MSELoss(size_average=False)
    # Adam over all model parameters; learning rate is read from the
    # "learning" section, key "lr", of `run`.
    # NOTE(review): `run` looks like a configparser-style object -- confirm.
    optimizer = optim.Adam(model.parameters(),
                           lr=run.getfloat("learning", "lr"))

    # Epsilon-greedy action selection; epsilon comes from the "defaults"
    # section, key "epsilon".
    action_selector = ActionSelectorEpsilonGreedy(epsilon=run.getfloat(
        "defaults", "epsilon"),
                                                  params=params)
    # Target-network wrapper built around the same model.
    target_net = agent.TargetNet(model)
    # The agent acts with `model` through the selector created above.
    dqn_agent = agent.DQNAgent(dqn_model=model,
                               action_selector=action_selector)
    # Experience source driven by the agent on `env_pool`; `steps_count` is
    # read from the "defaults" section, key "n_steps".
    exp_source = experience.ExperienceSource(env=env_pool,
                                             agent=dqn_agent,
                                             steps_count=run.getint(
                                                 "defaults", "n_steps"))
    # Replay buffer fed by the experience source; its size is read from the
    # "exp_buffer" section, key "size".
    exp_replay = experience.ExperienceReplayBuffer(exp_source,
                                                   buffer_size=run.getint(
                                                       "exp_buffer", "size"))

    # Optional DQN variants; both default to False when the key is absent.
    use_target_dqn = run.getboolean("dqn", "target_dqn", fallback=False)
    use_double_dqn = run.getboolean("dqn", "double_dqn", fallback=False)

    # Choose which network is bound to `target_model`: the target network's
    # model when target DQN is enabled, otherwise the online model itself.
    if use_target_dqn:
        target_model = target_net.target_model
    else:
        target_model = model

    def batch_to_train(batch):
        """
        Convert batch into training data using bellman's equation
        :param batch: list of tuples with Experience instances 
        :return: