Пример #1
0
    def _update_vf(self, dataset):
        """Fit the value function to the teacher values found in ``dataset``.

        Each minibatch drawn from ``dataset`` triggers one optimizer step that
        minimizes the mean squared error between ``self.vf(states)`` and the
        ``"v_teacher"`` entries (the TD(lambda) regression targets, which are
        presumably precomputed by the caller).

        Args:
            dataset: Indexable collection of mappings; the first entry must
                contain the keys ``"state"`` and ``"v_teacher"``.
        """

        assert "state" in dataset[0]
        assert "v_teacher" in dataset[0]

        minibatches = _yield_minibatches(
            dataset, minibatch_size=self.vf_batch_size, num_epochs=self.vf_epochs
        )
        for minibatch in minibatches:
            raw_states = [transition["state"] for transition in minibatch]
            states = batch_states(raw_states, self.device, self.phi)
            if self.obs_normalizer:
                # Normalize only; the normalizer statistics are not updated here.
                states = self.obs_normalizer(states, update=False)

            targets = torch.as_tensor(
                [transition["v_teacher"] for transition in minibatch],
                device=self.device,
                dtype=torch.float,
            )
            predictions = self.vf(states)
            # Append a trailing axis to the targets before comparing them
            # with the value-function output.
            loss = F.mse_loss(predictions, targets.unsqueeze(-1))

            self.vf.zero_grad()
            loss.backward()
            if self.max_grad_norm is not None:
                clip_l2_grad_norm_(self.vf.parameters(), self.max_grad_norm)
            self.vf_optimizer.step()
Пример #2
0
def test_yield_minibatches_smaller_dataset():
    """Check minibatching when the dataset is smaller than one minibatch."""
    dataset = [1, 2]
    batches = list(ppo._yield_minibatches(dataset, minibatch_size=4, num_epochs=3))
    assert len(batches) == 2

    # Concatenate the minibatches into one flat list of samples.
    flat = []
    for batch in batches:
        flat = flat + batch
    assert len(flat) == 8

    # Every element is expected to appear equally often.
    for value in (1, 2):
        assert flat.count(value) == 4
Пример #3
0
def test_yield_minibatches_divisible():
    """Check minibatching when the minibatch size divides the dataset size."""
    dataset = [1, 2, 3, 4]
    batches = list(ppo._yield_minibatches(dataset, minibatch_size=2, num_epochs=3))
    assert len(batches) == 6

    # Concatenate the minibatches into one flat list of samples.
    flat = []
    for batch in batches:
        flat = flat + batch
    assert len(flat) == 12

    # Each consecutive window of four samples must contain every element.
    expected = {1, 2, 3, 4}
    for start in (0, 4, 8):
        assert expected == set(flat[start : start + 4])
Пример #4
0
def test_yield_minibatches_indivisible():
    """Check minibatching when the minibatch size does not divide the dataset size."""
    dataset = [1, 2, 3]
    batches = list(ppo._yield_minibatches(dataset, minibatch_size=2, num_epochs=3))
    assert len(batches) == 5

    # Concatenate the minibatches into one flat list of samples.
    flat = []
    for batch in batches:
        flat = flat + batch
    assert len(flat) == 10

    # The first six samples come from the first two epochs, so each
    # element must appear exactly twice among them.
    first_two_epochs = flat[:6]
    for value in (1, 2, 3):
        assert first_two_epochs.count(value) == 2

    # The remaining samples come from the final epoch, where each element
    # may appear once or twice.
    final_epoch = flat[6:]
    for value in (1, 2, 3):
        assert 1 <= final_epoch.count(value) <= 2