예제 #1
0
def rollout_meta_with_values_on_gym_env(sess,
                                        env,
                                        state_ph,
                                        internal_state_ph,
                                        action_ph,
                                        reward_ph,
                                        deterministic_ph,
                                        action_value_op,
                                        action_op,
                                        value_op,
                                        internal_state_op,
                                        zero_state_fn,
                                        initial_action,
                                        num_episodes=1,
                                        deterministic=False,
                                        stream=None,
                                        save_replay=True):
    """Roll out a recurrent (meta) policy on a gym-style env and record values.

    Each step feeds the previous action, previous reward and the recurrent
    internal state back into the graph, and fetches the chosen action, its
    action-value, the state value estimate and the next internal state.

    Args:
        sess: TF session used to evaluate the fetched ops.
        env: gym-style environment (`reset()` / `step(action)`).
        state_ph, internal_state_ph, action_ph, reward_ph, deterministic_ph:
            placeholders fed every step; `internal_state_ph` is an iterable
            zipped against the recurrent state.
        action_value_op, action_op, value_op, internal_state_op: fetches.
        zero_state_fn: callable producing the initial recurrent state op
            (batch size 1, dtype of `state_ph`).
        initial_action: action fed on the very first step of each episode.
        num_episodes: number of episodes to roll out.
        deterministic: value fed to `deterministic_ph`.
        stream: sink with a `write(replay)` method; required when
            `save_replay` is True.
        save_replay: whether to persist each finished replay to `stream`.

    Returns:
        An `experience.RolloutsWithValues` aggregating all episode replays.

    Raises:
        ValueError: if `save_replay` is True but `stream` is None.
    """
    if save_replay and stream is None:
        raise ValueError('missing `stream` to `save_replay`.')

    replays = []
    for _ in range(num_episodes):
        # Fresh recurrent state, zero reward and the seed action per episode.
        internal_state = sess.run(zero_state_fn(1, dtype=state_ph.dtype))
        reward, action = 0., initial_action
        experiences = []
        next_state = env.reset()
        terminal = False
        while not terminal:
            state = next_state
            # Inputs are wrapped as [[x]] => (batch=1, time=1) sequences.
            feeds = {
                state_ph: [[state]],
                deterministic_ph: deterministic,
                action_ph: [[action]],
                reward_ph: [[reward]],
            }
            feeds.update(zip(internal_state_ph, internal_state))
            action_value, action, value, internal_state = sess.run(
                (action_value_op, action_op, value_op, internal_state_op),
                feed_dict=feeds)

            next_state, reward, terminal, _ = env.step(action)
            experiences.append(
                experience.ExperienceWithValues(state, next_state, action,
                                                action_value, value, reward,
                                                terminal))
        replay = experience.ReplayWithValues(*zip(*experiences),
                                             sequence_length=len(experiences))
        if save_replay:
            stream.write(replay)
        replays.append(replay)
    return experience.RolloutsWithValues(*zip(*replays))
예제 #2
0
def rollout_with_values_on_gym_env(sess,
                                   env,
                                   state_ph,
                                   deterministic_ph,
                                   action_value_op,
                                   action_op,
                                   value_op,
                                   num_episodes=1,
                                   deterministic=False,
                                   stream=None,
                                   save_replay=True):
    """Roll out a feed-forward policy on a gym-style env and record values.

    Unlike the meta variant, only the current state (and the deterministic
    flag) is fed each step; no recurrent state, previous action or previous
    reward is carried between steps.

    Args:
        sess: TF session used to evaluate the fetched ops.
        env: gym-style environment (`reset()` / `step(action)`).
        state_ph, deterministic_ph: placeholders fed every step.
        action_value_op, action_op, value_op: fetches per step.
        num_episodes: number of episodes to roll out.
        deterministic: value fed to `deterministic_ph`.
        stream: sink with a `write(replay)` method; required when
            `save_replay` is True.
        save_replay: whether to persist each finished replay to `stream`.

    Returns:
        The total (float) reward summed over all episodes.

    Raises:
        ValueError: if `save_replay` is True but `stream` is None.
    """
    if save_replay and stream is None:
        raise ValueError('missing `stream` to `save_replay`.')

    rewards = 0.
    for _ in range(num_episodes):
        experiences = []
        next_state = env.reset()
        terminal = False
        while not terminal:
            state = next_state
            # Inputs are wrapped as [[x]] => (batch=1, time=1) sequences.
            action_value, action, value = sess.run(
                (action_value_op, action_op, value_op),
                feed_dict={
                    state_ph: [[state]],
                    deterministic_ph: deterministic
                })

            next_state, reward, terminal, _ = env.step(action)
            experiences.append(
                experience.ExperienceWithValues(state, next_state, action,
                                                action_value, value, reward,
                                                terminal))
        replay = experience.ReplayWithValues(*zip(*experiences),
                                             sequence_length=len(experiences))
        if save_replay:
            stream.write(replay)
        rewards += sum(replay.reward)
    return rewards
예제 #3
0
파일: dataset.py 프로젝트: wenkesj/alchemy
        def pad_or_truncate_map(replay):
            """Truncate or pad a replay to `max_sequence_length` time steps.

            Args:
                replay: mapping of field name to tensor; treated as a
                    `ReplayWithValues` when it contains a 'value' key,
                    otherwise as a plain `Replay`.

            Returns:
                An `experience.ReplayWithValues` (or `experience.Replay`)
                whose per-step fields are padded/truncated along axis 0 to
                `max_sequence_length`, with static shapes pinned.
            """
            with_values = 'value' in replay

            if with_values:
                replay = experience.ReplayWithValues(**replay)
            else:
                replay = experience.Replay(**replay)

            # Effective (unpadded) length is capped at the padded length.
            sequence_length = math_ops.minimum(max_sequence_length,
                                               replay.sequence_length)
            sequence_length.set_shape([1])

            def _padded(tensor, trailing_shape, pad_value=0):
                """Pad/truncate `tensor` along time and pin its static shape."""
                out = sequence_utils.pad_or_truncate(tensor,
                                                     max_sequence_length,
                                                     axis=0,
                                                     pad_value=pad_value)
                out.set_shape([max_sequence_length] + trailing_shape)
                return out

            state = _padded(replay.state, state_shape)
            next_state = _padded(replay.next_state, state_shape)
            action = _padded(replay.action, action_shape)
            action_value = _padded(replay.action_value, action_value_shape)
            reward = _padded(replay.reward, reward_shape)
            # Terminal flags are booleans, so pad with False, not 0.
            terminal = _padded(replay.terminal, [],
                               pad_value=ops.convert_to_tensor(False))

            if with_values:
                # Value estimates share the reward's trailing shape.
                value = _padded(replay.value, reward_shape)

                return experience.ReplayWithValues(
                    state=state,
                    next_state=next_state,
                    action=action,
                    action_value=action_value,
                    value=value,
                    reward=reward,
                    terminal=terminal,
                    sequence_length=sequence_length)

            return experience.Replay(state=state,
                                     next_state=next_state,
                                     action=action,
                                     action_value=action_value,
                                     reward=reward,
                                     terminal=terminal,
                                     sequence_length=sequence_length)