Exemplo n.º 1
0
    def test_synchronize(self):
        """Check the filter counts before and after FilterManager.synchronize.

        The local filter sees 10 values and has its buffer cleared; a remote
        ``_MockWorker`` is asked for 10 samples. After synchronizing, the test
        expects the local running count to be 20 with an empty buffer, and the
        remote observation filter to report the same counts.
        """
        local_filter = MeanStdFilter(())
        for value in range(10):
            local_filter(value)
        self.assertEqual(local_filter.rs.n, 10)
        local_filter.clear_buffer()
        self.assertEqual(local_filter.buffer.n, 0)

        # Wrap the mock as a Ray actor and queue one sampling call on it.
        remote_worker = ray.remote(_MockWorker).remote(sample_count=10)
        remote_worker.sample.remote()

        local_filters = {
            "obs_filter": local_filter,
            "rew_filter": local_filter.copy()
        }
        FilterManager.synchronize(local_filters, [remote_worker])

        remote_filters = ray.get(remote_worker.get_filters.remote())
        remote_obs = remote_filters["obs_filter"]
        self.assertEqual(local_filter.rs.n, 20)
        self.assertEqual(local_filter.buffer.n, 0)
        self.assertEqual(remote_obs.rs.n, local_filter.rs.n)
        self.assertEqual(remote_obs.buffer.n, local_filter.buffer.n)
Exemplo n.º 2
0
class _MockWorker(object):
    """Minimal worker stand-in for tests.

    Holds a weight vector, a fixed gradient, and two ``MeanStdFilter``
    instances (observations and rewards) that are also exposed by name
    through ``self.filters``.
    """

    def __init__(self, sample_count=10):
        """Set up weights, gradient and filters.

        Args:
            sample_count: Number of observation/reward pairs generated by
                each call to ``sample``.
        """
        self._weights = np.array([-10, -10, -10, -10])
        self._grad = np.array([1, 1, 1, 1])
        self._sample_count = sample_count
        self.obs_filter = MeanStdFilter(())
        self.rew_filter = MeanStdFilter(())
        self.filters = {
            "obs_filter": self.obs_filter,
            "rew_filter": self.rew_filter
        }

    def sample(self):
        """Generate ``sample_count`` random observations and rewards.

        Every random draw is passed through the matching filter before being
        stored, observation first and reward second within each step.

        Returns:
            A ``SampleBatch`` built from the "observations" and "rewards"
            lists.
        """
        samples_dict = {"observations": [], "rewards": []}
        for _ in range(self._sample_count):
            samples_dict["observations"].append(
                self.obs_filter(np.random.randn()))
            samples_dict["rewards"].append(self.rew_filter(np.random.randn()))
        return SampleBatch(samples_dict)

    def compute_gradients(self, samples):
        """Return the fixed gradient scaled by ``samples.count``.

        Returns:
            A ``(gradient, info)`` tuple where ``info`` is
            ``{"batch_count": samples.count}``.
        """
        return self._grad * samples.count, {"batch_count": samples.count}

    def apply_gradients(self, grads):
        """Advance the weights by the fixed mock gradient.

        ``grads`` is accepted but not used; the stored gradient is always
        applied.
        """
        # Bind a fresh array instead of using ``+=``: an in-place add would
        # modify an array the caller handed to ``set_weights`` (or received
        # from ``get_weights``), and raises on read-only arrays.
        self._weights = self._weights + self._grad

    def get_weights(self):
        """Return the current weight array."""
        return self._weights

    def set_weights(self, weights):
        """Replace the current weights with ``weights``."""
        self._weights = weights

    def get_filters(self, flush_after=False):
        """Return copies of the observation and reward filters.

        Args:
            flush_after: When true, clear the buffers of this worker's own
                filters after the copies have been taken.

        Returns:
            Dict with keys "obs_filter" and "rew_filter" holding the copies.
        """
        obs_filter = self.obs_filter.copy()
        rew_filter = self.rew_filter.copy()
        if flush_after:
            self.obs_filter.clear_buffer()
            self.rew_filter.clear_buffer()

        return {"obs_filter": obs_filter, "rew_filter": rew_filter}

    def sync_filters(self, new_filters):
        """Sync each of this worker's filters with the same-named new filter.

        Args:
            new_filters: Dict that must contain every key in
                ``self.filters``.
        """
        assert all(k in new_filters for k in self.filters)
        for k in self.filters:
            self.filters[k].sync(new_filters[k])
Exemplo n.º 3
0
class _MockEvaluator(object):
    """Minimal evaluator stand-in for tests.

    Holds a weight vector, a fixed gradient, and two ``MeanStdFilter``
    instances (observations and rewards) that are also exposed by name
    through ``self.filters``.
    """

    def __init__(self, sample_count=10):
        """Set up weights, gradient and filters.

        Args:
            sample_count: Number of observation/reward pairs generated by
                each call to ``sample``.
        """
        self._weights = np.array([-10, -10, -10, -10])
        self._grad = np.array([1, 1, 1, 1])
        self._sample_count = sample_count
        self.obs_filter = MeanStdFilter(())
        self.rew_filter = MeanStdFilter(())
        self.filters = {
            "obs_filter": self.obs_filter,
            "rew_filter": self.rew_filter
        }

    def sample(self):
        """Generate ``sample_count`` random observations and rewards.

        Every random draw is passed through the matching filter before being
        stored, observation first and reward second within each step.

        Returns:
            A ``SampleBatch`` built from the "observations" and "rewards"
            lists.
        """
        samples_dict = {"observations": [], "rewards": []}
        for _ in range(self._sample_count):
            samples_dict["observations"].append(
                self.obs_filter(np.random.randn()))
            samples_dict["rewards"].append(self.rew_filter(np.random.randn()))
        return SampleBatch(samples_dict)

    def compute_gradients(self, samples):
        """Return the fixed gradient scaled by ``samples.count``.

        Returns:
            A ``(gradient, info)`` tuple where ``info`` is
            ``{"batch_count": samples.count}``.
        """
        return self._grad * samples.count, {"batch_count": samples.count}

    def apply_gradients(self, grads):
        """Advance the weights by the fixed mock gradient.

        ``grads`` is accepted but not used; the stored gradient is always
        applied.
        """
        # Bind a fresh array instead of using ``+=``: an in-place add would
        # modify an array the caller handed to ``set_weights`` (or received
        # from ``get_weights``), and raises on read-only arrays.
        self._weights = self._weights + self._grad

    def get_weights(self):
        """Return the current weight array."""
        return self._weights

    def set_weights(self, weights):
        """Replace the current weights with ``weights``."""
        self._weights = weights

    def get_filters(self, flush_after=False):
        """Return copies of the observation and reward filters.

        Args:
            flush_after: When true, clear the buffers of this evaluator's own
                filters after the copies have been taken.

        Returns:
            Dict with keys "obs_filter" and "rew_filter" holding the copies.
        """
        obs_filter = self.obs_filter.copy()
        rew_filter = self.rew_filter.copy()
        if flush_after:
            self.obs_filter.clear_buffer()
            self.rew_filter.clear_buffer()

        return {"obs_filter": obs_filter, "rew_filter": rew_filter}

    def sync_filters(self, new_filters):
        """Sync each of this evaluator's filters with the same-named new filter.

        Args:
            new_filters: Dict that must contain every key in
                ``self.filters``.
        """
        assert all(k in new_filters for k in self.filters)
        for k in self.filters:
            self.filters[k].sync(new_filters[k])
Exemplo n.º 4
0
    def testBasic(self):
        """Check rs.n and buffer.n through sync, clear_buffer and apply_changes.

        The same sequence of checks is repeated for a scalar shape, a vector
        shape and a 3-D shape.
        """
        for dims in [(), (3, ), (3, 4, 4)]:
            # Feed five all-ones inputs into a fresh filter.
            primary = MeanStdFilter(dims)
            for _ in range(5):
                primary(np.ones(dims))
            self.assertEqual(primary.rs.n, 5)
            self.assertEqual(primary.buffer.n, 5)

            # A second filter synced from the first reports the same counts.
            mirror = MeanStdFilter(dims)
            mirror.sync(primary)
            self.assertEqual(mirror.rs.n, 5)
            self.assertEqual(mirror.buffer.n, 5)

            # Clearing the first filter's buffer leaves the second untouched.
            primary.clear_buffer()
            self.assertEqual(primary.buffer.n, 0)
            self.assertEqual(mirror.buffer.n, 5)

            primary.apply_changes(mirror, with_buffer=False)
            self.assertEqual(primary.buffer.n, 0)
            self.assertEqual(primary.rs.n, 10)

            primary.apply_changes(mirror, with_buffer=True)
            self.assertEqual(primary.buffer.n, 5)
            self.assertEqual(primary.rs.n, 15)
Exemplo n.º 5
0
    def testSynchronize(self):
        """Check the filter counts before and after FilterManager.synchronize.

        The local filter sees 10 values and has its buffer cleared; a remote
        ``_MockEvaluator`` is asked for 10 samples. After synchronizing, the
        test expects the local running count to be 20 with an empty buffer,
        and the remote observation filter to report the same counts.
        """
        local_filter = MeanStdFilter(())
        for value in range(10):
            local_filter(value)
        self.assertEqual(local_filter.rs.n, 10)
        local_filter.clear_buffer()
        self.assertEqual(local_filter.buffer.n, 0)

        # Wrap the mock as a Ray actor and queue one sampling call on it.
        remote_evaluator = ray.remote(_MockEvaluator).remote(sample_count=10)
        remote_evaluator.sample.remote()

        local_filters = {
            "obs_filter": local_filter,
            "rew_filter": local_filter.copy()
        }
        FilterManager.synchronize(local_filters, [remote_evaluator])

        remote_filters = ray.get(remote_evaluator.get_filters.remote())
        remote_obs = remote_filters["obs_filter"]
        self.assertEqual(local_filter.rs.n, 20)
        self.assertEqual(local_filter.buffer.n, 0)
        self.assertEqual(remote_obs.rs.n, local_filter.rs.n)
        self.assertEqual(remote_obs.buffer.n, local_filter.buffer.n)
Exemplo n.º 6
0
    def testBasic(self):
        """Verify the running and buffered counts of MeanStdFilter step by step.

        Runs the identical scenario for each of three input shapes: fill one
        filter, sync a second from it, clear the first one's buffer, then
        apply the second's changes without and with its buffer.
        """
        shapes = [(), (3, ), (3, 4, 4)]
        for shape in shapes:
            first = MeanStdFilter(shape)
            for _ in range(5):
                first(np.ones(shape))
            self.assertEqual(first.rs.n, 5)
            self.assertEqual(first.buffer.n, 5)

            second = MeanStdFilter(shape)
            second.sync(first)
            self.assertEqual(second.rs.n, 5)
            self.assertEqual(second.buffer.n, 5)

            # Only the first filter's buffer is expected to be emptied.
            first.clear_buffer()
            self.assertEqual(first.buffer.n, 0)
            self.assertEqual(second.buffer.n, 5)

            # Without the buffer: running count grows, buffer stays empty.
            first.apply_changes(second, with_buffer=False)
            self.assertEqual(first.buffer.n, 0)
            self.assertEqual(first.rs.n, 10)

            # With the buffer: both the running count and the buffer grow.
            first.apply_changes(second, with_buffer=True)
            self.assertEqual(first.buffer.n, 5)
            self.assertEqual(first.rs.n, 15)