Ejemplo n.º 1
0
    def test_config_init(_):
        """Build ARS from a generated config and run one act/reward/update round.

        The generated config's ``frame_config`` is pointed at a single
        ``ActorDiscrete`` model whose state/action sizes come from
        ``TestARS.c``. Every actor type reported by ``get_actor_types()``
        then acts once on an all-zero ``(1, observe_dim)`` float32 state and
        stores a reward. The reward is 1.0 for types whose name starts with
        "neg" and 0.0 for all others. A single ``update()`` follows.

        Args:
            _: Unused positional argument (presumably the test-class instance
                or a harness-supplied value -- TODO confirm).

        Returns:
            True once ``update()`` returns without raising.
        """
        cfg = TestARS.c
        config = ARS.generate_config({})
        frame_config = config["frame_config"]
        frame_config["models"] = ["ActorDiscrete"]
        frame_config["model_kwargs"] = [
            {"state_dim": cfg.observe_dim, "action_dim": cfg.action_num}
        ]
        ars = ARS.init_from_config(config)

        for actor_type in ars.get_actor_types():
            # Acting once is what initializes the filters; the chosen action
            # itself is irrelevant here. A fresh zero state is built on every
            # iteration in case ``act`` modifies its input.
            ars.act(
                {"state": t.zeros([1, cfg.observe_dim], dtype=t.float32)},
                actor_type,
            )
            reward = 1.0 if actor_type.startswith("neg") else 0.0
            ars.store_reward(reward, actor_type)

        ars.update()
        return True
Ejemplo n.º 2
0
 def ars_lr(device, dtype):
     """Create an ARS framework whose SGD optimizer is driven by ``LambdaLR``.

     The actor is an ``ActorDiscrete`` sized from ``TestARS.c``. It is cast to
     ``dtype``, moved to ``device``, and wrapped with ``smw`` using ``device``
     for both of its device arguments. The learning-rate function is produced
     by ``gen_learning_rate_func`` from the breakpoints
     ``[(0, 1e-3), (200000, 3e-4)]``. It is handed to ``LambdaLR`` as the single
     positional scheduler argument. ARS runs over an RPC group named "ars"
     with members "0", "1" and "2", using one model server from
     ``model_server_helper``.

     Args:
         device: Target device for the actor.
         dtype: Tensor dtype the actor is cast to.

     Returns:
         The constructed ``ARS`` instance.
     """
     cfg = TestARS.c
     model = ActorDiscrete(cfg.observe_dim, cfg.action_num)
     actor = smw(model.type(dtype).to(device), device, device)
     schedule = gen_learning_rate_func(
         [(0, 1e-3), (200000, 3e-4)], logger=default_logger
     )
     # The server helper and the RPC group are created in the same order as
     # before, because both presumably touch distributed state.
     servers = model_server_helper(model_num=1)
     group = get_world().create_rpc_group("ars", ["0", "1", "2"])
     return ARS(
         actor,
         t.optim.SGD,
         group,
         servers,
         noise_size=1000000,
         lr_scheduler=LambdaLR,
         lr_scheduler_args=((schedule,),),
     )
Ejemplo n.º 3
0
 def ars(device, dtype):
     """Create an ARS framework with explicit noise/rollout hyper-parameters.

     The actor is an ``ActorDiscrete`` sized from ``TestARS.c``. It is cast to
     ``dtype``, moved to ``device``, and wrapped with ``smw`` using ``device``
     for both of its device arguments. ARS is built with an SGD optimizer, an
     RPC group named "ars" (members "0", "1", "2") and one model server from
     ``model_server_helper``. The keyword hyper-parameters passed are listed
     in ``hyper_params`` below.

     Args:
         device: Target device for the actor.
         dtype: Tensor dtype the actor is cast to.

     Returns:
         The constructed ``ARS`` instance.
     """
     cfg = TestARS.c
     model = ActorDiscrete(cfg.observe_dim, cfg.action_num)
     actor = smw(model.type(dtype).to(device), device, device)
     # The server helper and the RPC group are created in the same order as
     # before, because both presumably touch distributed state.
     servers = model_server_helper(model_num=1)
     group = get_world().create_rpc_group("ars", ["0", "1", "2"])
     hyper_params = dict(
         noise_std_dev=0.1,
         learning_rate=0.1,
         noise_size=1000000,
         rollout_num=6,
         used_rollout_num=6,
         normalize_state=True,
     )
     return ARS(actor, t.optim.SGD, group, servers, **hyper_params)