def test_getitem(numpy_replay: NumpyReplayBuffer, sample_batch: SampleBatch, idx):
    """Check buffer indexing before and after observation stats are updated.

    First, ``numpy_replay[idx]`` must be a dict whose entries are close to
    the matching rows of ``sample_batch``. Then, once ``update_obs_stats()``
    has been called, the current/next observation entries must equal the
    sample rows standardized with the per-feature mean and std of the
    batch's current observations.
    """
    # NOTE(review): another ``test_getitem`` appears later in this source; if
    # both live in the same module the later one shadows this one - confirm.
    raw = numpy_replay[idx]
    assert isinstance(raw, dict)
    for name in sample_batch.keys():
        assert np.allclose(raw[name], sample_batch[name][idx])

    # Reference statistics taken over the batch axis of the current observations.
    cur_obs = sample_batch[SampleBatch.CUR_OBS]
    obs_mean = np.mean(cur_obs, axis=0)
    obs_std = np.std(cur_obs, axis=0)

    numpy_replay.update_obs_stats()
    normalized = numpy_replay[idx]

    # Small epsilon keeps the expected value finite for zero-variance features.
    denom = obs_std + 1e-7
    for obs_key in (SampleBatch.CUR_OBS, SampleBatch.NEXT_OBS):
        target = (sample_batch[obs_key][idx] - obs_mean) / denom
        assert np.allclose(normalized[obs_key], target)
def test_getitem(filled_replay: NumpyReplayBuffer, sample_batch: SampleBatch, idx):
    """Check buffer indexing with and without ``compute_stats`` enabled.

    First, ``filled_replay[idx]`` must be a dict whose entries are close to
    the matching rows of ``sample_batch``. Then, after setting
    ``compute_stats = True`` on the buffer, the current/next observation
    entries must equal the sample rows standardized with the per-feature mean
    and std of the batch's current observations, where std values below
    1e-12 are replaced by 1.0.
    """
    raw = filled_replay[idx]
    assert isinstance(raw, dict)
    for name in sample_batch.keys():
        assert np.allclose(raw[name], sample_batch[name][idx])

    # Reference statistics taken over the batch axis of the current observations.
    cur_obs = sample_batch[SampleBatch.CUR_OBS]
    obs_mean = np.mean(cur_obs, axis=0)
    obs_std = np.std(cur_obs, axis=0)
    # Avoid dividing by (near-)zero for features that do not vary.
    obs_std[obs_std < 1e-12] = 1.0

    filled_replay.compute_stats = True
    normalized = filled_replay[idx]

    for obs_key in (SampleBatch.CUR_OBS, SampleBatch.NEXT_OBS):
        target = (sample_batch[obs_key][idx] - obs_mean) / obs_std
        assert np.allclose(normalized[obs_key], target)