コード例 #1
0
ファイル: export_swarm.py プロジェクト: Guillemdb/fragile
    def get_empty_export_walkers(self) -> ExportedWalkers:
        """
        Build an :class:`ExportedWalkers` container that holds zero walkers.

        Used to initialize the algorithm.

        Returns:
            A new :class:`ExportedWalkers` created with size ``0``.

        """
        empty_walkers = ExportedWalkers(0)
        return empty_walkers
コード例 #2
0
 def test_cross_fai_iteration(self, export_swarm):
     """Check that ``_cross_fai_iteration`` returns ``n_import`` companion indexes."""
     imported = ExportedWalkers(export_swarm.n_import)
     merge_indexes = export_swarm._get_merge_indexes(imported)
     local_ix, import_ix = merge_indexes
     iteration_result = export_swarm._cross_fai_iteration(
         local_ix=local_ix,
         import_ix=import_ix,
         walkers=imported,
     )
     # Only the companion indexes are checked; the clone mask is not used here.
     compas_ix, _will_clone = iteration_result
     assert len(compas_ix) == export_swarm.n_import
コード例 #3
0
 def reset(self, *args, **kwargs):
     """
     Reset the wrapped swarm and refresh the parameter server's best entry.

     All positional and keyword arguments are forwarded unchanged to
     ``self.swarm.reset``. The best state, id, reward and observation read
     from the swarm's walkers states are then passed to
     ``self.param_server.best.update``, and ``self._exchange_next`` is
     replaced by a new ``ExportedWalkers(0)``.
     """
     self.swarm.reset(*args, **kwargs)
     walker_states = self.swarm.walkers.states
     self.param_server.best.update(
         states=walker_states.best_state,
         id_walkers=walker_states.best_id,
         rewards=walker_states.best_reward,
         observs=walker_states.best_obs,
     )
     self._exchange_next = ExportedWalkers(0)
コード例 #4
0
 def __init__(
     self,
     swarm,
     max_len: int = 20,
     add_global_best: bool = True,
     *args,
     **kwargs
 ):
     """
     Initialize the instance around ``swarm``.

     Args:
         swarm: Forwarded to the parent initializer as the ``swarm`` keyword.
         max_len: Passed to :class:`ParamServer` as ``max_len``.
         add_global_best: Passed to :class:`ParamServer` as ``add_global_best``.
         *args: Extra positional arguments forwarded to the parent initializer.
         **kwargs: Extra keyword arguments forwarded to the parent initializer.

     """
     # NOTE(review): ``swarm`` is passed by keyword together with ``*args``; if
     # extra positional arguments are ever supplied this may collide with the
     # parent's first positional parameter - confirm against the parent class.
     super(ExportDummy, self).__init__(swarm=swarm, *args, **kwargs)
     minimize = self.swarm.walkers.minimize
     self.param_server = ParamServer(
         max_len=max_len,
         add_global_best=add_global_best,
         minimize=minimize,
     )
     self._exchange_next = ExportedWalkers(0)
コード例 #5
0
    def test_clone_to_imported(self, export_swarm):
        """Clone imported walkers into the swarm and check walker 0 received the imported data."""
        values = [999, 777, 333]
        imported = ExportedWalkers(3)
        imported.rewards = tensor(values, dtype=dtype.float)
        imported.states = tensor(values, dtype=dtype.float)
        imported.id_walkers = tensor(values, dtype=dtype.float)
        # One observation row per walker: the walker's value repeated four times.
        imported.observs = tensor([[value] * 4 for value in values], dtype=dtype.float)

        # Only the first local walker is flagged to clone.
        clone_mask = tensor([True, False])
        export_swarm._clone_to_imported(
            compas_ix=tensor([0, 1]),
            will_clone=clone_mask,
            local_ix=tensor([0, 1]),
            import_ix=tensor([0, 1]),
            walkers=imported,
        )

        assert export_swarm.walkers.states.cum_rewards[0] == 999.0
        assert export_swarm.walkers.env_states.states[0] == 999.0
        first_obs_matches = export_swarm.walkers.env_states.observs[0] == judo.ones(4) * 999
        assert first_obs_matches.all()
コード例 #6
0
    def test_clone_to_imported(self, export_swarm):
        """Clone imported walkers into the swarm and check walker 0 received the imported data."""
        values = [999, 777, 333]
        imported = ExportedWalkers(3)
        imported.rewards = numpy.array(values)
        imported.states = numpy.array(values)
        imported.id_walkers = numpy.array(values)
        # One observation row per walker: the walker's value repeated four times.
        imported.observs = numpy.array([[value] * 4 for value in values])

        # Only the first local walker is flagged to clone.
        clone_mask = numpy.array([True, False])
        export_swarm._clone_to_imported(
            compas_ix=numpy.array([0, 1]),
            will_clone=clone_mask,
            local_ix=numpy.array([0, 1]),
            import_ix=numpy.array([0, 1]),
            walkers=imported,
        )

        assert export_swarm.walkers.states.cum_rewards[0] == 999
        assert export_swarm.walkers.states.id_walkers[0] == 999
        assert export_swarm.walkers.env_states.states[0] == 999
        first_obs_matches = export_swarm.walkers.env_states.observs[0] == numpy.ones(4) * 999
        assert first_obs_matches.all()
コード例 #7
0
 def test_import_best(self, export_swarm):
     """Import two walkers and check the swarm's best data equals the first one (reward 999)."""
     imported = ExportedWalkers(2)
     imported.rewards = numpy.array([999, 2])
     imported.states = numpy.array([0, 1])
     imported.id_walkers = numpy.array([10, 11])
     imported.observs = numpy.array([[0, 0, 0, 0], [2, 3, 1, 2]])

     export_swarm.import_best(imported)

     # Index of the imported walker holding the 999 reward.
     best_ix = 0
     assert export_swarm.best_reward == 999
     assert export_swarm.walkers.states.best_state == imported.states[best_ix]
     obs_matches = export_swarm.walkers.states.best_obs == imported.observs[best_ix]
     assert obs_matches.all()
     assert export_swarm.walkers.states.best_id == imported.id_walkers[best_ix]
コード例 #8
0
 def test_import_best(self, export_swarm):
     """Import two walkers and check the swarm's best data equals the first one (reward 999)."""
     imported = ExportedWalkers(2)
     imported.rewards = tensor([999.0, 2.0])
     imported.states = tensor([0.0, 1.0])
     imported.id_walkers = tensor([10.0, 11.0])
     imported.observs = tensor([[0, 0, 0, 0], [2, 3, 1, 2]], dtype=dtype.float)

     export_swarm.import_best(imported)

     # Index of the imported walker holding the 999 reward.
     best_ix = 0
     assert export_swarm.best_reward == 999
     assert export_swarm.walkers.states.best_state == imported.states[best_ix]
     obs_matches = export_swarm.walkers.states.best_obs == imported.observs[best_ix]
     assert obs_matches.all()
     assert export_swarm.walkers.states.best_id == imported.id_walkers[best_ix]
コード例 #9
0
 def test_imported_best_is_better(self, export_swarm):
     """
     Check that an infinitely good imported reward is reported as better.

     First a ``+inf`` reward is imported with the fixture's current
     ``minimize`` setting (presumably ``False`` - confirm against the
     fixture). Then ``minimize`` is set to ``True`` and a ``-inf`` reward is
     imported. In both cases ``_imported_best_is_better`` must be truthy.
     """
     export_swarm.reset()
     export_swarm.run_step()
     walkers = ExportedWalkers(1)
     walkers.rewards = tensor([1]) * numpy.inf
     new_is_better = export_swarm._imported_best_is_better(walkers)
     assert new_is_better, export_swarm.best_reward
     walkers = ExportedWalkers(1)
     export_swarm.walkers.minimize = True
     try:
         # ``numpy.NINF`` was removed in NumPy 2.0; ``-numpy.inf`` works on
         # every NumPy version.
         walkers.rewards = tensor([1]) * -numpy.inf
         new_is_better = export_swarm._imported_best_is_better(walkers)
         assert new_is_better, export_swarm.best_reward
     finally:
         # Restore the flag even when the assertion fails, so the mutated
         # fixture state cannot leak into other tests sharing it.
         export_swarm.walkers.minimize = False
コード例 #10
0
 def test_imported_best_is_better(self, export_swarm):
     """
     Check that an infinitely good imported reward is reported as better.

     First a ``+inf`` reward is imported with the fixture's current
     ``minimize`` setting (presumably ``False`` - confirm against the
     fixture). Then ``minimize`` is set to ``True`` and a ``-inf`` reward is
     imported. In both cases ``_imported_best_is_better`` must be truthy.
     """
     export_swarm.reset()
     export_swarm.run_step()
     walkers = ExportedWalkers(1)
     walkers.rewards = numpy.array([numpy.inf])
     new_is_better = export_swarm._imported_best_is_better(walkers)
     assert new_is_better, export_swarm.best_reward
     walkers = ExportedWalkers(1)
     export_swarm.walkers.minimize = True
     try:
         walkers.rewards = numpy.array([-numpy.inf])
         new_is_better = export_swarm._imported_best_is_better(walkers)
         assert new_is_better, export_swarm.best_reward
     finally:
         # Restore the flag even when the assertion fails, so the mutated
         # fixture state cannot leak into other tests sharing it.
         export_swarm.walkers.minimize = False
コード例 #11
0
    def test_run_exchange_step(self, export_swarm):
        """Run an exchange step with an empty and a populated import; check export size and best reward."""
        export_swarm.reset()

        # Importing nothing must still export ``n_export`` walkers.
        empty_import = ExportedWalkers(0)
        exported = export_swarm.run_exchange_step(empty_import)
        assert len(exported) == export_swarm.n_export

        values = [999, 777, 333]
        # One row per walker: the walker's value repeated four times.
        rows = [[value] * 4 for value in values]
        imported = ExportedWalkers(3)
        imported.rewards = tensor(values, dtype=dtype.float)
        imported.states = tensor(rows, dtype=dtype.float)
        imported.id_walkers = tensor(values, dtype=dtype.float)
        imported.observs = tensor(rows, dtype=dtype.float)

        export_swarm.reset()
        exported = export_swarm.run_exchange_step(imported)
        assert len(exported) == export_swarm.n_export
        assert export_swarm.best_reward == 999.0
コード例 #12
0
    def test_run_exchange_step(self, export_swarm):
        """Run an exchange step with an empty and a populated import; check export size and best reward."""
        export_swarm.reset()

        # Importing nothing must still export ``n_export`` walkers.
        empty_import = ExportedWalkers(0)
        exported = export_swarm.run_exchange_step(empty_import)
        assert len(exported) == export_swarm.n_export

        values = [999, 777, 333]
        # One row per walker: the walker's value repeated four times.
        rows = [[value] * 4 for value in values]
        imported = ExportedWalkers(3)
        imported.rewards = numpy.array(values)
        imported.states = numpy.array(rows)
        imported.id_walkers = numpy.array(values)
        imported.observs = numpy.array(rows)

        export_swarm.reset()
        exported = export_swarm.run_exchange_step(imported)
        assert len(exported) == export_swarm.n_export
        assert export_swarm.best_reward == 999
コード例 #13
0
 def test_get_merge_indexes(self, export_swarm):
     """Check both merge index arrays have the same length, equal to ``n_import``."""
     imported = ExportedWalkers(2)
     local_ix, import_ix = export_swarm._get_merge_indexes(imported)
     n_local = len(local_ix)
     assert n_local == len(import_ix)
     assert n_local == export_swarm.n_import