Exemple #1
0
    def setUpClass(cls):
        """Build shared fixtures and publish them to a local redis server.

        A short random rollout on ``Pendulum-v0`` is sampled and registered
        in a ``Traj``; its step count is kept on ``cls.num_step``. The
        environment, the trajectory and one instance of each policy /
        value-function type are then stored in redis under fixed keys, all
        serialized with ``cloudpickle``.

        Expects a redis server reachable at ``localhost:6379``.
        """
        # NOTE(review): takes ``cls``, so this is presumably wrapped with
        # ``@classmethod`` outside the visible lines -- confirm.
        env = GymEnv('Pendulum-v0')

        # Nothing has been attached to ``cls`` yet, so the rollout must be
        # driven by the local environment and the local random policy.
        random_pol = RandomPol(env.observation_space, env.action_space)
        sampler = EpiSampler(env, random_pol, num_parallel=1)
        epis = sampler.sample(random_pol, max_steps=32)
        traj = Traj()
        traj.add_epis(epis)
        traj.register_epis()

        cls.num_step = traj.num_step

        make_redis('localhost', '6379')
        cls.r = get_redis()

        # Redis values must be bytes/str/numbers; arbitrary objects are
        # serialized the same way as every other fixture below.
        cls.r.set('env', cloudpickle.dumps(env))
        cls.r.set('traj', cloudpickle.dumps(traj))

        ob_space = env.observation_space
        ac_space = env.action_space

        # Stochastic Gaussian policy.
        pol_net = PolNet(ob_space, ac_space)
        gpol = GaussianPol(ob_space, ac_space, pol_net)

        # Deterministic policy with action noise.
        pol_net = PolNet(ob_space, ac_space, deterministic=True)
        dpol = DeterministicActionNoisePol(ob_space, ac_space, pol_net)

        # Model-predictive-control policy on top of a dynamics model.
        model_net = ModelNet(ob_space, ac_space)
        mpcpol = MPCPol(ob_space, ac_space, model_net, rew_func)

        # State-action value function and the greedy policy derived from it.
        q_net = QNet(ob_space, ac_space)
        qfunc = DeterministicSAVfunc(ob_space, ac_space, q_net)
        aqpol = ArgmaxQfPol(ob_space, ac_space, qfunc)

        # State value function.
        v_net = VNet(ob_space)
        vfunc = DeterministicSVfunc(ob_space, v_net)

        fixtures = (
            ('gpol', gpol),
            ('dpol', dpol),
            ('mpcpol', mpcpol),
            ('qfunc', qfunc),
            ('aqpol', aqpol),
            ('vfunc', vfunc),
        )
        for key, obj in fixtures:
            cls.r.set(key, cloudpickle.dumps(obj))

        # The multi-categorical policy is built from the same C2DEnv spaces
        # as its network, so policy and network agree on the action space.
        c2d = C2DEnv(env)
        pol_net = PolNet(c2d.observation_space, c2d.action_space)
        mcpol = MultiCategoricalPol(
            c2d.observation_space, c2d.action_space, pol_net)

        cls.r.set('mcpol', cloudpickle.dumps(mcpol))
Exemple #2
0
 def test_distributed_epi_sampler(self):
     """Sample episodes through a DistributedEpiSampler using real processes.

     Starts a ``redis-server`` and one sampler worker (world size 1,
     rank 0) as subprocesses, then builds a sampler with rank -1 in this
     process (presumably the master role -- confirm) and asks it for two
     episodes. Every descendant process of the test runner is sent
     SIGTERM afterwards, whether or not sampling and the assertion
     succeeded, so no server or worker is left running.
     """
     try:
         # The handles are kept referenced for the duration of the test;
         # shutdown itself goes through psutil in the ``finally`` clause.
         proc_redis = subprocess.Popen(['redis-server'])
         proc_slave = subprocess.Popen([
             'python', '-m', 'machina.samplers.distributed_epi_sampler',
             '--world_size', '1', '--rank', '0', '--redis_host', 'localhost',
             '--redis_port', '6379'
         ])
         make_redis('localhost', '6379')
         sampler = DistributedEpiSampler(1,
                                         -1,
                                         self.env,
                                         self.pol,
                                         num_parallel=1)
         epis = sampler.sample(self.pol, max_epis=2)
         assert len(epis) >= 2
     finally:
         # recursive=True also reaches grandchildren spawned by the worker.
         children = psutil.Process(os.getpid()).children(recursive=True)
         for child in children:
             try:
                 child.send_signal(SIGTERM)
             except psutil.NoSuchProcess:
                 # The child exited between enumeration and signalling.
                 pass
        """
        This method should be called in master node.
        """
        self.pol = pol
        self.max_epis = max_epis // self.world_size if max_epis is not None else None
        self.max_steps = max_steps // self.world_size if max_steps is not None else None
        self.deterministic = deterministic

        self.scatter_from_master('pol')
        self.scatter_from_master('max_epis')
        self.scatter_from_master('max_steps')
        self.scatter_from_master('deterministic')

        self.gather_to_master('epis')

        return self.epis


if __name__ == '__main__':
    # Command-line entry point: read the sampler layout and the redis
    # location, connect to redis, then construct the sampler with the
    # given world size and rank.
    cli = argparse.ArgumentParser()
    for flag in ('--world_size', '--rank'):
        cli.add_argument(flag, type=int)
    for flag, fallback in (('--redis_host', 'localhost'),
                           ('--redis_port', '6379')):
        cli.add_argument(flag, type=str, default=fallback)
    args = cli.parse_args()

    make_redis(args.redis_host, args.redis_port)

    sampler = DistributedEpiSampler(args.world_size, args.rank)