Example #1
else:
    ray.init(num_gpus=args.num_gpus)

# Create a debugging friendly instance
if args.debug:
    from tqdm import tqdm
    from pprint import pprint
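    # In this old RLlib multiagent API, each "policy_graphs" entry maps a policy
    # id to a (policy_graph_class, observation_space, action_space, config) tuple.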
    trainer = impala.ImpalaAgent(env="dm-" + env_name,
                                 config={
                                     "multiagent": {
                                         "policy_graphs": {
                                             "def_policy":
                                             (PPOPolicyGraph,
                                              Box(0.0,
                                                  255.0,
                                                  shape=(84, 84, 3)),
                                              Discrete(9), {
                                                  "gamma": 0.99
                                              })
                                         },
                                         "policy_mapping_fn":
                                         lambda agent_id: "def_policy",
                                     },
                                 })
    for iter in tqdm(range(args.num_steps), desc="Iters"):
        results = trainer.train()
        if iter % 500 == 0:
            trainer.save("saved_models/multi-carla/" + args.model_arch)
        pprint(results)
else:
    config.update({
Example #2
            })


# Create a debugging friendly instance
if args.debug:
    from tqdm import tqdm
    from pprint import pprint
    trainer = impala.ImpalaAgent(
        env="dm-" + env_name,
        # Use independent policy graphs for each agent
        config={
            "multiagent": {
                "policy_graphs": {
                    id: default_policy()
                    for id in env_actor_configs["actors"].keys()
                },
                "policy_mapping_fn": lambda agent_id: agent_id,
            },
            "env_config": env_actor_configs,
            "num_workers": args.num_workers,
            "num_envs_per_worker": args.envs_per_worker,
            "sample_batch_size": args.sample_bs_per_worker,
            "train_batch_size": args.train_bs
        })
    if args.checkpoint_path and os.path.isfile(args.checkpoint_path):
        trainer.restore(args.checkpoint_path)
        print("Loaded checkpoint from:{}".format(args.checkpoint_path))

    for iter in tqdm(range(args.num_iters), desc="Iters"):
        results = trainer.train()
        if iter % 500 == 0:
Example #3
    # Use independent policy graphs for each agent
    config = {
        "env": env_name,
        "multiagent": {
            "policy_graphs":
            {id: gen_policy()
             for id in env_actor_configs["actors"].keys()},
            "policy_mapping_fn": tune.function(lambda agent_id: agent_id),
        },
    }
    # Create a debugging friendly instance
    if args.debug:
        from pprint import pprint
        from tqdm import tqdm
        trainer = impala.ImpalaAgent(env=env_name, config=config)
        if args.checkpoint_path and os.path.isfile(args.checkpoint_path):
            trainer.restore(args.checkpoint_path)
            print("Loaded checkpoint from:{}".format(args.checkpoint_path))

        for iter in tqdm(range(args.num_iters), desc="Iters"):
            results = trainer.train()
            if iter % 500 == 0:
                trainer.save("saved_models/multi-carla/" + args.model_arch)
            pprint(results)
    else:
        # Unused exp_spec
        experiment_spec = tune.Experiment(
            "multi-carla/" + args.model_arch,
            "IMPALA",
            # timesteps_total is initialized to None (not 0), which causes an issue
Example #4
    "num_gpus": 2,
    "sample_async": False,
    "sample_batch_size": 20,
    # "use_pytorch": False,
    # "vf_loss_coeff": 0.5,
    # "entropy_coeff": -0.01,
    "gamma": 0.99,
    # "grad_clip": 40.0,
    # "lambda": 1.0,
    "lr": 0.0001,
    "observation_filter": "NoFilter",
    "preprocessor_pref": "rllib",
    "model": model,
    "log_level": "DEBUG"
}
agent = impala.ImpalaAgent(config=config, env="PongDeterministic-v4")
#agent = pg.PGAgent(config=config, env="PongDeterministic-v4")
#agent = a3c.A3CAgent(config=config, env="PongDeterministic-v4")
#agent = dqn.DQNAgent(config=config, env="PongDeterministic-v4")
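# Export the default policy's TensorFlow graph to the agent's log directory so
# it can be inspected in TensorBoard.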
policy_graph = agent.local_evaluator.policy_map["default"].sess.graph
writer = tf.summary.FileWriter(agent._result_logger.logdir, policy_graph)
writer.close()

while True:
    result = agent.train()
    print(result)
    print("training_iteration", result["training_iteration"])
    print("timesteps this iter", result["timesteps_this_iter"])
    print("timesteps_total", result["timesteps_total"])
    print("time_this_iter_s", result["time_this_iter_s"])
    print("time_total_s", result["time_total_s"])
Example #5
# Create a debugging friendly instance
if args.debug:
    from tqdm import tqdm
    from pprint import pprint
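    # Same setup as Example #1, but with IMPALA's VTracePolicyGraph as the
    # single shared policy.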
    trainer = impala.ImpalaAgent(
        env="dm-" + env_name,
        config={
            "multiagent": {
                "policy_graphs": {
                    "def_policy": (VTracePolicyGraph,
                                   Box(0.0, 255.0, shape=(84, 84, 3)),
                                   Discrete(9),
                                   {"gamma": 0.99}),
                },
                "policy_mapping_fn": lambda agent_id: "def_policy",
            },
            "env_config": env_actor_configs,
            "num_workers": args.num_workers,
            "num_envs_per_worker": args.envs_per_worker,
            "sample_batch_size": args.sample_bs_per_worker,
            "train_batch_size": args.train_bs,
        })
    if args.checkpoint_path and os.path.isfile(args.checkpoint_path):
        trainer.restore(args.checkpoint_path)
        print("Loaded checkpoint from:{}".format(args.checkpoint_path))