def make_toy_mr(env_id, env_size=None, max_episode_steps=300):
    from toy_mr import ToyMR
    from gym import wrappers
    from chain_env import ChainEnvironment
    if env_id == 'toy_mr':
        env = ToyMR()
    else:
        assert isinstance(env_size, int), f'got {env_size}'
        env = ChainEnvironment(N=env_size)
    env = wrappers.TimeLimit(env, max_episode_steps=max_episode_steps)
    return env
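# A minimal usage sketch of make_toy_mr (not part of the original snippet). Any env_id
# other than 'toy_mr' selects ChainEnvironment, so 'chain' below is illustrative; it
# assumes toy_mr/chain_env are importable and follow the standard gym Env interface.
def _smoke_test_make_toy_mr():
    env = make_toy_mr('chain', env_size=10, max_episode_steps=50)
    observation = env.reset()
    done = False
    while not done:
        # Random policy, just to confirm the TimeLimit wrapper eventually ends the episode.
        observation, reward, done, info = env.step(env.action_space.sample())
    env.close()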
def test(agent):
    EPISODE_NUM = 1000
    env = wrappers.TimeLimit(toy_text.FrozenLakeEnv(map_name='8x8'),
                             max_episode_steps=1000)
    score_sum = 0.0
    for _ in range(EPISODE_NUM):
        observation = env.reset()
        while True:
            observation, reward, done, _ = env.step(
                agent.best_action(observation))
            score_sum += reward
            if done:
                break
    assert score_sum / EPISODE_NUM >= 0.99
def train():
    UPDATE_PERIOD = 10000
    EPSILON_DECAY = 0.7
    env = wrappers.TimeLimit(toy_text.FrozenLakeEnv(map_name='8x8'),
                             max_episode_steps=1000)

    # TODO (exercise): construct your agent here; the placeholder below aborts
    # until an agent with best_action/update_q_values methods is provided.
    print('agent is not created')
    exit(1)
    agent = None

    check_agent_actions(agent)
    check_agent_update()

    epsilon = 1.0
    last_scores_sum = 0.0
    for episode_id in itertools.count():
        score = 0.0
        observation = env.reset()
        while True:

            # TODO (exercise): with probability epsilon take a random action;
            # otherwise ask the agent for its best action.
            print(
                'Perform a random action with probability epsilon; '
                'otherwise request the best action from the agent.'
            )
            exit(1)
            action = None

            next_observation, reward, done, _ = env.step(action)
            agent.update_q_values(observation, action, reward,
                                  next_observation, done)
            score += reward
            if done:
                break
            observation = next_observation

        last_scores_sum += score
        if episode_id > 0 and episode_id % UPDATE_PERIOD == 0:
            epsilon *= EPSILON_DECAY
            last_mean_score = last_scores_sum / UPDATE_PERIOD
            print('last_mean_score:', last_mean_score)
            if last_mean_score > 0.999:
                print('You won!')
                break
            else:
                last_scores_sum = 0.0
    return agent
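# One possible agent for the placeholders above (a sketch, not the course's reference
# solution): a tabular Q-learning agent exposing the best_action / update_q_values
# interface that train() and test() call on FrozenLake. Hyperparameters are illustrative.
import numpy as np

class TabularQLearningAgent:
    def __init__(self, n_states, n_actions, learning_rate=0.1, gamma=0.99):
        self.q = np.zeros((n_states, n_actions))
        self.lr = learning_rate
        self.gamma = gamma

    def best_action(self, observation):
        # Greedy action with respect to the current Q-table.
        return int(np.argmax(self.q[observation]))

    def update_q_values(self, observation, action, reward, next_observation, done):
        # One-step Q-learning update; the bootstrap term is dropped on terminal transitions.
        target = reward if done else reward + self.gamma * np.max(self.q[next_observation])
        self.q[observation, action] += self.lr * (target - self.q[observation, action])

# For the epsilon-greedy placeholder inside train(), something along these lines:
#     if random.random() < epsilon:
#         action = env.action_space.sample()
#     else:
#         action = agent.best_action(observation)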
Example #4
def try_gym():
    # Jump to the FrozenLakeEnv definition to see its description
    env = wrappers.TimeLimit(toy_text.FrozenLakeEnv(map_name='8x8'), max_episode_steps=1000)
    print('action_space size:', env.action_space.n)
    print('random action sample:', env.action_space.sample())
    print('observation_space size:', env.observation_space.n)
    for episode_id in range(1):
        score = 0.0
        observation = env.reset()
        print('observation:', observation)
        # If render() prints color codes instead of coloring a symbol on Windows and you are a perfectionist, you can use
        # https://aka.ms/terminal if your Windows version is high enough
        env.render()
        for step_id in itertools.count():
            observation, reward, done, diagnostic_info = env.step(env.action_space.sample())
            # env.render()
            score += reward
            if done: break
        print('score:', score)  # It will probably be 0
Example #5
File: ddpg.py  Project: jtib/chi-rl-alg
def test_ddpg():
    import gym_mix
    env = gym.make('ContinuousCopyRand-v0')
    env = wrappers.TimeLimit(env, max_episode_steps=0)

    @model(optimizer=tf.train.AdamOptimizer(0.0001),
           tracker=tf.train.ExponentialMovingAverage(1 - 0.001))
    def actor(x):
        x = layers.fully_connected(
            x, 50, biases_initializer=layers.xavier_initializer())
        a = layers.fully_connected(
            x,
            env.action_space.shape[0],
            None,
            weights_initializer=tf.random_normal_initializer(0, 1e-4))
        return a

    @model(optimizer=tf.train.AdamOptimizer(.001),
           tracker=tf.train.ExponentialMovingAverage(1 - 0.001))
    def critic(x, a):
        x = layers.fully_connected(
            x, 300, biases_initializer=layers.xavier_initializer())
        x = tf.concat([x, a], axis=1)
        x = layers.fully_connected(
            x, 300, biases_initializer=layers.xavier_initializer())
        x = layers.fully_connected(
            x, 300, biases_initializer=layers.xavier_initializer())
        q = layers.fully_connected(
            x,
            1,
            None,
            weights_initializer=tf.random_normal_initializer(0, 1e-4))
        return tf.squeeze(q, 1)

    agent = DdpgAgent(env, actor, critic)

    for ep in range(10000):
        R, _ = agent.play_episode()

        if ep % 100 == 0:
            print(f'Return after episode {ep} is {R}')
Example #6
        env_tst = environ.StocksEnv(stock_data,
                                    bars_count=BARS_COUNT,
                                    reset_on_close=True,
                                    state_1d=True)
    elif os.path.isdir(args.data):
        env = environ.StocksEnv.from_dir(args.data,
                                         bars_count=BARS_COUNT,
                                         reset_on_close=True,
                                         state_1d=True)
        env_tst = environ.StocksEnv.from_dir(args.data,
                                             bars_count=BARS_COUNT,
                                             reset_on_close=True,
                                             state_1d=True)
    else:
        raise RuntimeError("No data to train on")
    env = wrappers.TimeLimit(env, max_episode_steps=1000)

    val_data = {"BANK": data.load_relative(args.valdata)}
    env_val = environ.StocksEnv(val_data,
                                bars_count=BARS_COUNT,
                                reset_on_close=True,
                                state_1d=True)

    writer = SummaryWriter(comment="-conv-" + args.run)
    net = models.DQNConv1D(env.observation_space.shape,
                           env.action_space.n).to(device)
    print(net)
    # Target network trick, as mentioned in the corresponding notebook
    tgt_net = DeepTrader.agent.TargetNet(net)
    selector = DeepTrader.actions.EpsilonGreedyActionSelector(EPSILON_START)
    agent = DeepTrader.agent.DQNAgent(net, selector, device=device)
def train_agent(
    run_name,
    data_paths=conf.default_data_paths,
    validation_paths=conf.default_validation_paths,
    model=models.DQNConv1D,
    large=False,
    load_checkpoint=None,
    saves_path=None,
    eps_steps=None,
):
    """
    Main function for training the agents

    :run_name: a string of choice that dictates where to save
    :data_paths: dict specifying what data to train with
    :validation_paths: dict specifying what data to validate with
    :model: what model to use
    :large: whether or not to use the large feature set
    :load_checkpoint: an optional path to a checkpoint to load from
    :saves_path: an optional directory to save checkpoints in (defaults to "saves/<run_name>")
    :eps_steps: an optional override for the number of epsilon-decay steps
    """

    print("=" * 80)
    print("Training starting".rjust(40 + 17 // 2))
    print("=" * 80)

    # Get training data
    stock_data = data.get_data_as_dict(data_paths, large=large)
    val_data = data.get_data_as_dict(validation_paths, large=large)

    # Setup before training can begin
    step_idx = 0
    eval_states = None
    best_mean_val = None
    EPSILON_STEPS = eps_steps if eps_steps is not None else conf.EPSILON_STEPS

    # Use GPU if available, else fall back on CPU
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"[Info] Using device: {device}")

    # Set up the path to save the checkpoints to
    if saves_path is None:
        saves_path = os.path.join("saves", run_name)
    else:
        saves_path = os.path.join(saves_path, run_name)

    print(f"[Info] Saving to path: {saves_path}")

    os.makedirs(saves_path, exist_ok=True)

    # Create the gym-environment that the agent will interact with during training
    env = environ.StocksEnv(
        stock_data,
        bars_count=conf.BARS_COUNT,
        reset_on_close=conf.RESET_ON_CLOSE,
        random_ofs_on_reset=conf.RANDOM_OFS_ON_RESET,
        reward_on_close=conf.REWARD_ON_CLOSE,
        large=large,
    )

    env = wrappers.TimeLimit(env, max_episode_steps=1000)

    # Create the gym-environment that the agent will interact with when validating
    env_val = environ.StocksEnv(
        val_data,
        bars_count=conf.BARS_COUNT,
        reset_on_close=conf.RESET_ON_CLOSE,
        random_ofs_on_reset=conf.RANDOM_OFS_ON_RESET,
        reward_on_close=conf.REWARD_ON_CLOSE,
        large=large,
    )

    # Create the model
    net = model(env.observation_space.shape, env.action_space.n).to(device)

    print("Using network:".rjust(40 + 14 // 2))
    print("=" * 80)
    print(net)

    # Initialize agent and epsilon-greedy action-selector from the ptan package
    # The ptan package provides some helper and wrapper functions for ease of
    # use of reinforcement learning
    tgt_net = ptan.agent.TargetNet(net)
    selector = ptan.actions.EpsilonGreedyActionSelector(conf.EPSILON_START)
    agent = ptan.agent.DQNAgent(net, selector, device=device)
    exp_source = ptan.experience.ExperienceSourceFirstLast(
        env, agent, conf.GAMMA, steps_count=conf.REWARD_STEPS)
    buffer = ptan.experience.ExperienceReplayBuffer(exp_source,
                                                    conf.REPLAY_SIZE)
    optimizer = optim.Adam(net.parameters(), lr=conf.LEARNING_RATE)

    # If a checkpoint is supplied to the function -> resume the training from there
    if load_checkpoint is not None:
        state = torch.load(load_checkpoint)
        net.load_state_dict(state["model_state_dict"])
        optimizer.load_state_dict(state["optimizer_state_dict"])
        step_idx = state["step_idx"]
        best_mean_val = state["best_mean_val"]
        print(
            f"State loaded -> step index: {step_idx}, best mean val: {best_mean_val}"
        )

        net.train()

    # Create a reward tracker, i.e. an object that keeps track of the
    # rewards the agent gets during training
    reward_tracker = common.RewardTracker(np.inf, group_rewards=100)

    # The main training loop
    print("Training loop starting".rjust(40 + 22 // 2))
    print("=" * 80)

    # Run the main training loop
    while True:
        step_idx += 1
        buffer.populate(1)

        # Get current epsilon for epsilon-greedy action-selection
        selector.epsilon = max(conf.EPSILON_STOP,
                               conf.EPSILON_START - step_idx / EPSILON_STEPS)

        # Take a step and get rewards
        new_rewards = exp_source.pop_rewards_steps()
        if new_rewards:
            reward_tracker.reward(new_rewards[0], step_idx, selector.epsilon)

        # As long as there is not enough data in the buffer, go back to the top and keep collecting
        if len(buffer) < conf.REPLAY_INITIAL:
            continue

        if eval_states is None:
            print("Initial buffer populated, start training")
            eval_states = buffer.sample(conf.STATES_TO_EVALUATE)
            eval_states = [
                np.array(transition.state, copy=False)
                for transition in eval_states
            ]
            eval_states = np.array(eval_states, copy=False)

        # Evaluate the model every EVAL_EVERY_STEP steps
        # and update the best performance so far if a better value is obtained
        if step_idx % conf.EVAL_EVERY_STEP == 0:
            mean_val = common.calc_values_of_states(eval_states,
                                                    net,
                                                    device=device)
            # If a new best value -> save the model, both with metadata for resuming training
            # and as the full object for use in testing
            if best_mean_val is None or best_mean_val < mean_val:
                if best_mean_val is not None:
                    print(
                        f"{step_idx}: Best mean value updated {best_mean_val:.3f} -> {mean_val:.3f}"
                    )
                best_mean_val = mean_val
                # Save checkpoint with meta data
                torch.save(
                    {
                        "model_state_dict": net.state_dict(),
                        "optimizer_state_dict": optimizer.state_dict(),
                        "step_idx": step_idx,
                        "best_mean_val": best_mean_val,
                    },
                    os.path.join(saves_path, f"mean_val-{mean_val:.3f}.data"),
                )
                # Save full object for testing
                torch.save(
                    net,
                    os.path.join(saves_path,
                                 f"mean_val-{mean_val:.3f}-fullmodel.data"),
                )

        # Reset optimizer's gradients before optimization step
        optimizer.zero_grad()
        batch = buffer.sample(conf.BATCH_SIZE)
        # Calculate the loss
        loss_v = common.calc_loss(
            batch,
            net,
            tgt_net.target_model,
            conf.GAMMA**conf.REWARD_STEPS,
            device=device,
        )
        # Calculate the gradient
        loss_v.backward()
        # Do one step of gradient descent
        optimizer.step()

        # Sync up the two networks we're using
        # Using two networks in this manner should improve the agent's ability to converge
        if step_idx % conf.TARGET_NET_SYNC == 0:
            tgt_net.sync()

        # Every CHECKPOINT_EVERY_STEP steps, save the model in case something happens,
        # so we can resume training from the latest checkpoint
        if step_idx % conf.CHECKPOINT_EVERY_STEP == 0:
            idx = step_idx // conf.CHECKPOINT_EVERY_STEP
            torch.save(
                {
                    "model_state_dict": net.state_dict(),
                    "optimizer_state_dict": optimizer.state_dict(),
                    "step_idx": step_idx,
                    "best_mean_val": best_mean_val,
                },
                os.path.join(saves_path, f"checkpoint-{idx}.data"),
            )
            torch.save(net, os.path.join(saves_path, f"fullmodel-{idx}.data"))

    print("Training done")
Example #8
def make(env_name,
         render=False,
         figID=0,
         record=False,
         ros=False,
         directory='',
         T_steps=None,
         num_targets=1,
         **kwargs):
    """
    Parameters:
    ----------
    env_name : str
        name of an environment. (e.g. 'TargetTracking-v0')
    render : bool
        whether to render.
    figID : int
        figure ID for rendering and/or recording.
    record : bool
        whether to record a video.
    ros : bool
        whether to use ROS.
    directory : str
        a path to store a video file if record is True.
    T_steps : int
        the number of steps per episode.
    num_targets : int
        the number of targets
    """
    # Default episode length when the caller does not specify one.
    if T_steps is None:
        T_steps = 100

    local_view = 0
    if env_name == 'TargetTracking-v0':
        env0 = target_tracking.TargetTrackingEnv0(num_targets=num_targets,
                                                  **kwargs)
    elif env_name == 'TargetTracking-v1':
        env0 = target_tracking.TargetTrackingEnv1(num_targets=num_targets,
                                                  **kwargs)
    elif env_name == 'TargetTracking-v2':
        env0 = target_tracking.TargetTrackingEnv2(num_targets=num_targets,
                                                  **kwargs)
    elif env_name == 'TargetTracking-v3':
        env0 = target_tracking.TargetTrackingEnv3(num_targets=num_targets,
                                                  **kwargs)
    elif env_name == 'TargetTracking-v4':
        env0 = target_tracking.TargetTrackingEnv4(num_targets=num_targets,
                                                  **kwargs)
    elif env_name == 'TargetTracking-v5':
        local_view = 1
        env0 = target_imtracking.TargetTrackingEnv5(num_targets=num_targets,
                                                    **kwargs)
    elif env_name == 'TargetTracking-v6':
        local_view = 1
        env0 = target_imtracking.TargetTrackingEnv6(num_targets=num_targets,
                                                    **kwargs)
    elif env_name == 'TargetTracking-v7':
        local_view = 5
        env0 = target_imtracking.TargetTrackingEnv7(num_targets=num_targets,
                                                    **kwargs)
    elif env_name == 'TargetTracking-v8':
        local_view = 5
        env0 = target_imtracking.TargetTrackingEnv8(num_targets=num_targets,
                                                    **kwargs)
    elif env_name == 'TargetTracking-v9':
        local_view = 5
        env0 = target_imtracking.TargetTrackingEnv9(num_targets=num_targets,
                                                    **kwargs)
    elif env_name == 'TargetTracking-v10':
        local_view = 1
        env0 = target_imtracking.TargetTrackingEnv10(num_targets=num_targets,
                                                     **kwargs)
    elif env_name == 'TargetTracking-v1_SEQ':
        env0 = target_seq_tracking.TargetTrackingEnv1_SEQ(
            num_targets=num_targets, **kwargs)
    elif env_name == 'TargetTracking-v5_SEQ':
        local_view = 1
        env0 = target_seq_tracking.TargetTrackingEnv5_SEQ(
            num_targets=num_targets, **kwargs)
    elif env_name == 'TargetTracking-v7_SEQ':
        local_view = 5
        env0 = target_seq_tracking.TargetTrackingEnv7_SEQ(
            num_targets=num_targets, **kwargs)
    elif env_name == 'TargetTracking-v8_SEQ':
        local_view = 5
        env0 = target_seq_tracking.TargetTrackingEnv8_SEQ(
            num_targets=num_targets, **kwargs)
    elif env_name == 'TargetTracking-v9_SEQ':
        local_view = 5
        env0 = target_seq_tracking.TargetTrackingEnv9_SEQ(
            num_targets=num_targets, **kwargs)
    elif env_name == 'TargetTracking-info1':
        from ttenv.infoplanner_python.target_tracking_infoplanner import TargetTrackingInfoPlanner1
        env0 = TargetTrackingInfoPlanner1(num_targets=num_targets, **kwargs)
    elif env_name == 'TargetTracking-info2':
        from ttenv.infoplanner_python.target_tracking_infoplanner import TargetTrackingInfoPlanner2
        env0 = TargetTrackingInfoPlanner2(num_targets=num_targets, **kwargs)
    else:
        raise ValueError('No such environment exists.')

    env = wrappers.TimeLimit(env0, max_episode_steps=T_steps)
    if ros:
        from ttenv.ros_wrapper import Ros
        env = Ros(env)
    if render:
        from ttenv.display_wrapper import Display2D
        env = Display2D(env, figID=figID, local_view=local_view)
    if record:
        from ttenv.display_wrapper import Video2D
        env = Video2D(env, dirname=directory, local_view=local_view)

    return env
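# A short rollout sketch for make() above (illustrative, not from the original repo);
# it assumes 'TargetTracking-v1' and its kwargs behave like standard gym environments.
def _rollout_example():
    env = make('TargetTracking-v1', num_targets=2)
    observation = env.reset()
    done = False
    while not done:
        observation, reward, done, info = env.step(env.action_space.sample())
    env.close()
    return reward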
Example #9
def make(env_name,
         render=False,
         figID=0,
         record=False,
         ros=False,
         directory='',
         T_steps=None,
         num_agents=2,
         num_targets=1,
         **kwargs):
    """
    env_name : str
        name of an environment. (e.g. 'Cartpole-v0')
    type : str
        type of an environment. One of ['atari', 'classic_control',
        'classic_mdp','target_tracking']
    """
    if T_steps is None:
        T_steps = 200
    if env_name == 'maTracking-v0':
        from maTTenv.env.maTracking_v0 import maTrackingEnv0
        env0 = maTrackingEnv0(num_agents=num_agents,
                              num_targets=num_targets,
                              **kwargs)
    elif env_name == 'maTracking-v1':
        from maTTenv.env.maTracking_v1 import maTrackingEnv1
        env0 = maTrackingEnv1(num_agents=num_agents,
                              num_targets=num_targets,
                              **kwargs)
    elif env_name == 'maTracking-v2':
        from maTTenv.env.maTracking_v2 import maTrackingEnv2
        env0 = maTrackingEnv2(num_agents=num_agents,
                              num_targets=num_targets,
                              **kwargs)
    elif env_name == 'maTracking-v3':
        from maTTenv.env.maTracking_v3 import maTrackingEnv3
        env0 = maTrackingEnv3(num_agents=num_agents,
                              num_targets=num_targets,
                              **kwargs)
    elif env_name == 'maTracking-v4':
        from maTTenv.env.maTracking_v4 import maTrackingEnv4
        env0 = maTrackingEnv4(num_agents=num_agents,
                              num_targets=num_targets,
                              **kwargs)

    elif env_name == 'setTracking-v1':
        from maTTenv.env.setTracking_v1 import setTrackingEnv1
        env0 = setTrackingEnv1(num_agents=num_agents,
                               num_targets=num_targets,
                               **kwargs)
    elif env_name == 'setTracking-v2':
        from maTTenv.env.setTracking_v2 import setTrackingEnv2
        env0 = setTrackingEnv2(num_agents=num_agents,
                               num_targets=num_targets,
                               **kwargs)
    elif env_name == 'setTracking-v3':
        from maTTenv.env.setTracking_v3 import setTrackingEnv3
        env0 = setTrackingEnv3(num_agents=num_agents,
                               num_targets=num_targets,
                               **kwargs)
    elif env_name == 'setTracking-v4':
        from maTTenv.env.setTracking_v4 import setTrackingEnv4
        env0 = setTrackingEnv4(num_agents=num_agents,
                               num_targets=num_targets,
                               **kwargs)
    elif env_name == 'setTracking-v5':
        from maTTenv.env.setTracking_v5 import setTrackingEnv5
        env0 = setTrackingEnv5(num_agents=num_agents,
                               num_targets=num_targets,
                               **kwargs)
    elif env_name == 'setTracking-v6':
        from maTTenv.env.setTracking_v6 import setTrackingEnv6
        env0 = setTrackingEnv6(num_agents=num_agents,
                               num_targets=num_targets,
                               **kwargs)
    elif env_name == 'setTracking-v7':
        from maTTenv.env.setTracking_v7 import setTrackingEnv7
        env0 = setTrackingEnv7(num_agents=num_agents,
                               num_targets=num_targets,
                               **kwargs)

    else:
        raise ValueError('No such environment exists.')

    env = wrappers.TimeLimit(env0, max_episode_steps=T_steps)
    if ros:
        from ttenv.ros_wrapper import Ros
        env = Ros(env)
    if render:
        from maTTenv.display_wrapper import Display2D
        env = Display2D(env, figID=figID)
    if record:
        from maTTenv.display_wrapper import Video2D
        env = Video2D(env, dirname=directory)

    return env
Example #10
def make(env_name,
         render=False,
         figID=0,
         record=False,
         ros=False,
         dirname='',
         map_name="empty",
         is_training=True,
         num_targets=1,
         T_steps=None,
         im_size=None):
    if False:
        env = gym.make(env_name)
        if record:
            env = Monitor(env, directory=args.log_dir)
    else:
        if 'Target' in env_name:
            from gym import wrappers
            import envs.target_tracking.target_tracking as ttenv
            from envs.target_tracking.target_tracking_advanced import TargetTrackingEnvRNN
            # from envs.target_tracking.target_tracking_infoplanner import TargetTrackingInfoPlanner1, TargetTrackingInfoPlanner2
            from envs.target_tracking import display_wrapper
            if T_steps is None:
                if num_targets > 1:
                    T_steps = 150
                else:
                    T_steps = 100
            if env_name == 'TargetTracking-v0':
                env0 = ttenv.TargetTrackingEnv0(map_name=map_name,
                                                is_training=is_training,
                                                num_targets=num_targets)
            elif env_name == 'TargetTracking-v1':
                env0 = ttenv.TargetTrackingEnv1(map_name=map_name,
                                                is_training=is_training,
                                                num_targets=num_targets)
            elif env_name == 'TargetTracking-v2':
                env0 = ttenv.TargetTrackingEnv2(map_name=map_name,
                                                is_training=is_training,
                                                num_targets=num_targets)
            elif env_name == 'TargetTracking-v3':
                env0 = ttenv.TargetTrackingEnv3(map_name=map_name,
                                                is_training=is_training,
                                                num_targets=num_targets)
            elif env_name == 'TargetTracking-v4':
                env0 = ttenv.TargetTrackingEnv4(map_name=map_name,
                                                is_training=is_training,
                                                num_targets=num_targets)
            elif env_name == 'TargetTracking-v5':
                env0 = ttenv.TargetTrackingEnv5(map_name=map_name,
                                                is_training=is_training,
                                                num_targets=num_targets,
                                                im_size=im_size)
            elif env_name == 'TargetTracking-vRNN':
                env0 = TargetTrackingEnvRNN(map_name=map_name,
                                            is_training=is_training,
                                            num_targets=num_targets)
                T_steps = 200
            elif env_name == 'TargetTracking-info1':
                from envs.target_tracking.target_tracking_infoplanner import TargetTrackingInfoPlanner1
                env0 = TargetTrackingInfoPlanner1(map_name=map_name,
                                                  is_training=is_training,
                                                  num_targets=num_targets)
            elif env_name == 'TargetTracking-info2':
                from envs.target_tracking.target_tracking_infoplanner import TargetTrackingInfoPlanner2
                env0 = TargetTrackingInfoPlanner2(map_name=map_name,
                                                  is_training=is_training,
                                                  num_targets=num_targets)
            else:
                raise ValueError('No such environment exists.')

            env = wrappers.TimeLimit(env0, max_episode_steps=T_steps)
            if ros:
                from envs.ros_wrapper import Ros
                env = Ros(env)
            if render:
                env = display_wrapper.Display2D(env, figID=figID)
            if record:
                env = display_wrapper.Video2D(env, dirname=dirname)
        else:
            from envs import classic_mdp
            env = classic_mdp.model_assign(env_name)
    return env