Example 1
from pathlib import Path

scenario_paths = [(
    # Alternative scenario datasets (commented out in the source):
    # Path(__file__).parent / "../dataset/intersection_4lane_sv"
    # Path(__file__).parent / "../dataset/intersection_4lane_sv_up"
    # Path(__file__).parent / "../dataset_public/mixed_loop/its_merge_a"
    # Path(__file__).parent / "../dataset/intersection_4lane_sv_right"
    # Path(__file__).parent / "../dataset_public/mixed_loop/roundabout_its_a"
    # Path(__file__).parent / "../dataset_public/mixed_loop/roundabout_merge_a"
    Path(__file__).parent / "../dataset/simple"
).resolve()]
print(f"training on {scenario_paths}")

import argparse

from ray.rllib.agents.dqn.dqn import DEFAULT_CONFIG, DQNTrainer
from ray.rllib.agents.dqn.dqn_torch_policy import DQNTorchPolicy

config = DEFAULT_CONFIG.copy()
# config["seed_global"] = 0

# Rebuild the DQN trainer so it always uses the Torch policy.
DQN = DQNTrainer.with_updates(
    name="DQN_TORCH",
    default_policy=DQNTorchPolicy,
    default_config=DEFAULT_CONFIG,
    get_policy_class=None,
)

def parse_args():
    parser = argparse.ArgumentParser("train on multiple scenarios")

    # env setting
    parser.add_argument("--scenario", type=str, default=None, help="Scenario name")
    parser.add_argument("--exper", type=str, default="multi_scenarios")
    parser.add_argument(
        "--headless", default=False, action="store_true", help="Turn on headless mode"
    )
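
A minimal usage sketch of the rebuilt trainer (an assumption, not shown in the example): Ray must be initialized first, and "CartPole-v0" stands in for the registered SMARTS environment the script actually trains on.

import ray

ray.init()
trainer = DQN(config=config, env="CartPole-v0")
for i in range(3):
    result = trainer.train()
    print(i, result["episode_reward_mean"])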

Example 2
            original_traj['dones'][-1] = True  # mark the relabeled goal step as terminal
            original_traj['rewards'] = np.array(original_traj['rewards'],
                                                dtype=np.float64)  # np.float is removed in NumPy >= 1.24
            # Shaped terminal reward: higher the sooner the relabeled goal is reached.
            original_traj['rewards'][-1] = \
                1 - 0.9 * goal_step / self.policy.config['horizon']
            # code.interact(local=locals())  # interactive debugging hook; disabled
        except Exception:  # avoid a bare except; still swallow relabeling failures
            pass
        return original_traj
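
For intuition (illustrative numbers, not from the source): with config['horizon'] = 100 and a relabeled goal reached at goal_step = 40, the terminal reward becomes 1 - 0.9 * 40 / 100 = 0.64, so trajectories that reach the relabeled goal sooner are rewarded more.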


# build postprocess_fn using SamplingStrategy
postprocess_with_HER = build_DQN_HER_postprocess_fn(MiniGridSamplingStrategy)
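
A hedged sketch of how the generated postprocess function might be attached (an assumption; the source does not show this step, and the HERDQNTorchPolicy name is illustrative):

from ray.rllib.agents.dqn.dqn_torch_policy import DQNTorchPolicy

# Override only trajectory postprocessing; the rest stays stock Torch DQN.
HERDQNTorchPolicy = DQNTorchPolicy.with_updates(
    name="HERDQNTorchPolicy",
    postprocess_fn=postprocess_with_HER,
)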

# Trainer config using Rainbow DQN with HER
HER_RAINBOW_DQN_CONFIG = DEFAULT_CONFIG.copy()
HER_RAINBOW_DQN_CONFIG.update({
    # Common
    "framework": "torch",
    "num_gpus": 1,
    # Model
    # "model": {
    #     "dim": 5,
    #     "conv_filters": [
    #         [2, [5, 5], 1]
    #     ],
    #     "conv_activation": "relu",
    #     "fcnet_hiddens": [4,4],
    #     "max_seq_len": 100
    # },
    # Hindsight Experience Replay