def _update_vf(self, dataset):
    """Fit the value function to TD(lambda) targets by minibatch SGD.

    Iterates over shuffled minibatches of ``dataset`` for
    ``self.vf_epochs`` epochs and minimizes the squared error between
    the value function's predictions and the precomputed teacher
    values stored under each transition's ``"v_teacher"`` key.

    Args:
        dataset: Sequence of dicts; each entry must carry at least the
            keys ``"state"`` and ``"v_teacher"``.
    """
    assert "state" in dataset[0]
    assert "v_teacher" in dataset[0]
    for minibatch in _yield_minibatches(
        dataset, minibatch_size=self.vf_batch_size, num_epochs=self.vf_epochs
    ):
        obs = batch_states(
            [transition["state"] for transition in minibatch],
            self.device,
            self.phi,
        )
        if self.obs_normalizer:
            # Normalization statistics are frozen here; they are only
            # updated elsewhere during data collection.
            obs = self.obs_normalizer(obs, update=False)
        targets = torch.as_tensor(
            [transition["v_teacher"] for transition in minibatch],
            device=self.device,
            dtype=torch.float,
        )
        predictions = self.vf(obs)
        # Targets get a trailing singleton axis to match the network output.
        loss = F.mse_loss(predictions, targets[..., None])
        self.vf.zero_grad()
        loss.backward()
        if self.max_grad_norm is not None:
            clip_l2_grad_norm_(self.vf.parameters(), self.max_grad_norm)
        self.vf_optimizer.step()
def test_yield_minibatches_smaller_dataset():
    """Dataset smaller than the minibatch size is resampled to fill batches."""
    dataset = [1, 2]
    batches = list(ppo._yield_minibatches(dataset, minibatch_size=4, num_epochs=3))
    assert len(batches) == 2
    flattened = [item for batch in batches for item in batch]
    assert len(flattened) == 8
    # Both elements appear equally often across all yielded samples.
    assert flattened.count(1) == 4
    assert flattened.count(2) == 4
def test_yield_minibatches_divisible():
    """Evenly divisible dataset: each epoch is a clean permutation."""
    dataset = [1, 2, 3, 4]
    batches = list(ppo._yield_minibatches(dataset, minibatch_size=2, num_epochs=3))
    # 4 items / minibatches of 2 -> 2 minibatches per epoch, 3 epochs = 6.
    assert len(batches) == 6
    flattened = [item for batch in batches for item in batch]
    assert len(flattened) == 12
    # Every consecutive run of 4 samples (one epoch) covers the dataset once.
    for start in (0, 4, 8):
        assert set(flattened[start : start + 4]) == {1, 2, 3, 4}
def test_yield_minibatches_indivisible():
    """Indivisible dataset: early epochs are exact, the last may resample."""
    dataset = [1, 2, 3]
    batches = list(ppo._yield_minibatches(dataset, minibatch_size=2, num_epochs=3))
    assert len(batches) == 5
    flattened = [item for batch in batches for item in batch]
    assert len(flattened) == 10
    # The first six samples correspond to the first two epochs: each
    # element appears exactly twice.
    for value in (1, 2, 3):
        assert flattened[:6].count(value) == 2
    # The final four samples come from the last epoch, so each element
    # appears at least once and at most twice.
    for value in (1, 2, 3):
        assert 1 <= flattened[6:].count(value) <= 2