Code Example #1
               policy_noise_func=policy_noise)

    if c.restart_from_trial is not None:
        ddpg.load(save_env.get_trial_model_dir())
    logger.info("DDPG framework initialized")

    # training
    # preparations
    env = BipedalWalker()

    # begin training
    episode = Counter()
    episode_finished = False
    global_step = Counter()
    local_step = Counter()
    timer = Timer()

    while episode < c.max_episodes:
        episode.count()
        logger.info("Begin episode {} at {}".format(
            episode,
            dt.now().strftime("%m/%d-%H:%M:%S")))

        # environment initialization
        env.reset()

        # render configuration
        if (episode.get() % c.profile_int == 0
                and global_step.get() > c.ddpg_warmup_steps):
            render = True
        else:
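
These examples lean on two small utilities, Counter and Timer, whose definitions are not shown on this page. The following is a minimal sketch, inferred purely from how they are used above (count(), get(), comparison with an int, a step argument, and str formatting); the project's real implementations may differ.

import time

class Counter:
    """Minimal counter matching the usage in the examples (assumed)."""
    def __init__(self, start=0, step=1):
        self._value = start
        self._step = step

    def count(self):
        # advance by the configured step; the PPO examples pass
        # step=c.ppo_update_int, so one count() spans a whole update batch
        self._value += self._step

    def get(self):
        return self._value

    def __lt__(self, other):
        # enables `while episode < c.max_episodes:`
        return self._value < other

    def __str__(self):
        # enables `"{}".format(episode)` in the log messages
        return str(self._value)

class Timer:
    """Minimal wall-clock timer (assumed interface)."""
    def __init__(self):
        self._start = time.time()

    def begin(self):
        self._start = time.time()

    def end(self):
        return time.time() - self._start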
Code Example #2
              batch_size=c.ppo_update_batch_size,
              learning_rate=c.learning_rate)

    if c.restart_from_trial is not None:
        ppo.load(save_env.get_trial_model_dir())
    logger.info("PPO framework initialized")

    # training
    # preparations
    ctx = get_context("spawn")
    pool = Pool(processes=c.workers, context=ctx)
    pool.enable_global_find(True)

    # begin training
    episode = Counter(step=c.ppo_update_int)
    timer = Timer()

    while episode < c.max_episodes:
        first_episode = episode.get()
        episode.count()
        last_episode = episode.get() - 1
        logger.info("Begin episode {}-{} at {}".format(
            first_episode, last_episode,
            dt.now().strftime("%m/%d-%H:%M:%S")))

        # begin trials
        def run_trial(episode_num):
            # TODO: agent_num cannot be pickled?
            env = BipedalMultiCarrier(agent_num=c.agent_num)

            # render configuration
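
The excerpt cuts off inside run_trial, and the dispatch of the trials over the worker pool is not shown. The project's Pool (note enable_global_find, which apparently lets workers resolve the closure by name) is not the stdlib class; below is a rough stdlib-only sketch of the fan-out pattern, with run_trial reduced to a module-level placeholder so the spawn context can pickle it.

from multiprocessing import get_context

def run_trial(episode_num):
    # placeholder body: the real trial builds a BipedalMultiCarrier
    # environment, rolls out one episode, and returns collected samples
    return episode_num

if __name__ == "__main__":
    ctx = get_context("spawn")
    first_episode, last_episode = 0, 3
    with ctx.Pool(processes=4) as pool:
        # one trial per episode number in the current update batch
        results = pool.map(run_trial, range(first_episode, last_episode + 1))
    print(results)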
Code Example #3
                batch_size=c.ddpg_update_batch_size,
                learning_rate=0.001)

    if c.restart_from_trial is not None:
        ddpg.load(save_env.get_trial_model_dir())
    logger.info("DDPG framework initialized")

    # training
    # preparations
    env = BipedalWalker()

    # begin training
    episode = Counter()
    episode_finished = False
    local_step = Counter()
    timer = Timer()

    while episode < c.max_episodes:
        episode.count()
        logger.info("Begin episode {} at {}".format(episode, dt.now().strftime("%m/%d-%H:%M:%S")))

        # environment initialization
        # reset the episode flag each episode so later episodes still
        # enter the interaction loop below
        episode_finished = False

        # batch size = 1
        total_reward = 0
        state, reward = t.tensor(env.reset(), dtype=t.float32,
                                 device=c.device), 0

        while not episode_finished and local_step.get() <= c.max_steps:
            local_step.count()
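
The excerpt stops right after local_step.count(), so the body of the interaction step is missing. As a hedged sketch, one step of this kind of loop typically looks like the following under the classic gym step API; the bare actor call stands in for whatever action-selection method the DDPG framework actually exposes, and the exploration noise is omitted.

import torch as t

def run_step(env, actor, state, device="cpu"):
    # pick an action from the current policy; the real loop would also
    # apply exploration noise through the DDPG framework
    with t.no_grad():
        action = actor(state.unsqueeze(0)).squeeze(0)
    # classic gym API: step returns (observation, reward, done, info)
    next_obs, reward, done, _ = env.step(action.cpu().numpy())
    next_state = t.tensor(next_obs, dtype=t.float32, device=device)
    return next_state, reward, done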
Code Example #4
    writer = global_board.writer
    logger.info("Directories prepared.")

    ppo = create_models()
    logger.info("PPO framework initialized")

    # training
    # preparations
    ctx = get_context("spawn")
    pool = Pool(processes=c.workers, context=ctx)
    pool.enable_global_find(True)
    pool.enable_copy_tensors(False)

    # begin training
    episode = Counter(step=c.ppo_update_int)
    timer = Timer()

    while episode < c.max_episodes:
        first_episode = episode.get()
        episode.count()
        last_episode = episode.get() - 1
        logger.info("Begin episode {}-{} at {}".format(first_episode, last_episode,
                                                       dt.now().strftime("%m/%d-%H:%M:%S")))

        # begin trials
        def run_trial(episode_num):
            config = generate_combat_config(c.map_size)
            env = magent.GridWorld(config, map_size=c.map_size)
            env.reset()

            group_handles = env.get_handles()
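
generate_combat_config is a project helper whose source is not shown here. A plausible sketch, modeled on MAgent's standard two-team battle configuration; all attribute values below are illustrative, not the project's actual settings.

from magent import gridworld as gw

def generate_combat_config(map_size):
    cfg = gw.Config()
    cfg.set({"map_width": map_size, "map_height": map_size})

    # one shared agent type for both teams; numbers are illustrative
    small = cfg.register_agent_type(
        "small",
        {"width": 1, "length": 1, "hp": 10, "speed": 2,
         "view_range": gw.CircleRange(6),
         "attack_range": gw.CircleRange(1.5),
         "damage": 2, "step_recover": 0.1,
         "step_reward": -0.005, "kill_reward": 5,
         "dead_penalty": -0.1, "attack_penalty": -0.1})

    g0 = cfg.add_group(small)
    g1 = cfg.add_group(small)

    # reward members of each group for attacking the other group
    a = gw.AgentSymbol(g0, index="any")
    b = gw.AgentSymbol(g1, index="any")
    cfg.add_reward_rule(gw.Event(a, "attack", b), receiver=a, value=0.2)
    cfg.add_reward_rule(gw.Event(b, "attack", a), receiver=b, value=0.2)
    return cfg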
Code Example #5
    ddpg = create_models()
    logger.info("DDPG framework initialized")

    # training
    # preparations
    config = generate_combat_config(c.map_size)
    env = magent.GridWorld(config, map_size=c.map_size)
    env.reset()

    # begin training
    episode = Counter()
    episode_finished = False
    global_step = Counter()
    local_step = Counter()
    timer = Timer()

    while episode < c.max_episodes:
        episode.count()
        logger.info("Begin episode {} at {}".format(
            episode,
            dt.now().strftime("%m/%d-%H:%M:%S")))

        # environment initialization
        env.reset()

        group_handles = env.get_handles()
        generate_combat_map(env, c.map_size, c.agent_ratio, group_handles[0],
                            group_handles[1])

        # render configuration
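
Similarly, generate_combat_map is not shown; its signature can be read off the call above (env, map size, agent ratio, two group handles). A minimal sketch using MAgent's random placement API, with the count formulas as illustrative guesses:

def generate_combat_map(env, map_size, agent_ratio, group1, group2):
    # counts below are illustrative guesses scaled with map area
    wall_num = int(map_size * map_size * 0.015)
    agent_num = int(map_size * map_size * 0.02 * agent_ratio)
    env.add_walls(method="random", n=wall_num)
    env.add_agents(group1, method="random", n=agent_num)
    env.add_agents(group2, method="random", n=agent_num)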