def rollout_meta_with_values_on_gym_env(sess,
                                        env,
                                        state_ph,
                                        internal_state_ph,
                                        action_ph,
                                        reward_ph,
                                        deterministic_ph,
                                        action_value_op,
                                        action_op,
                                        value_op,
                                        internal_state_op,
                                        zero_state_fn,
                                        initial_action,
                                        num_episodes=1,
                                        deterministic=False,
                                        stream=None,
                                        save_replay=True):
  """Roll out a recurrent (meta-RL) policy on a Gym env, recording values.

  Each step feeds the previous action and reward plus the recurrent internal
  state back into the graph, then steps the environment with the sampled
  action. One `experience.ReplayWithValues` is built per episode; when
  `save_replay` is set, each replay is also written to `stream`.

  Args:
    sess: TF session used to evaluate the policy ops.
    env: Gym-style environment (`reset()` / `step(action)`).
    state_ph: Placeholder fed `[[state]]` (batch and time dims of 1).
    internal_state_ph: Iterable of placeholders for the recurrent state.
    action_ph: Placeholder fed the previous action.
    reward_ph: Placeholder fed the previous reward.
    deterministic_ph: Placeholder toggling greedy vs. sampled actions.
    action_value_op: Op producing the value of the chosen action.
    action_op: Op producing the action to execute.
    value_op: Op producing the state value estimate.
    internal_state_op: Op producing the next recurrent internal state.
    zero_state_fn: Callable `(batch_size, dtype=...)` returning zero-state ops.
    initial_action: Action fed on the very first step of each episode.
    num_episodes: Number of episodes to roll out.
    deterministic: Value fed to `deterministic_ph`.
    stream: Replay sink with a `.write(replay)` method; required when
      `save_replay` is True.
    save_replay: Whether to write each episode replay to `stream`.

  Returns:
    An `experience.RolloutsWithValues` aggregating all episode replays.

  Raises:
    ValueError: If `save_replay` is True but no `stream` was given.
  """
  if stream is None and save_replay:
    raise ValueError('missing `stream` to `save_replay`.')

  replays = []
  for _ in range(num_episodes):
    # Fresh recurrent state, previous-reward and previous-action per episode.
    internal_state = sess.run(zero_state_fn(1, dtype=state_ph.dtype))
    reward = 0.
    action = initial_action
    transitions = []
    obs = env.reset()
    terminal = False
    while not terminal:
      state = obs
      # Feed recurrent state tensors alongside the scalar step inputs.
      feeds = dict(zip(internal_state_ph, internal_state))
      feeds.update({
          state_ph: [[state]],
          deterministic_ph: deterministic,
          action_ph: [[action]],
          reward_ph: [[reward]],
      })
      action_value, action, value, internal_state = sess.run(
          (action_value_op, action_op, value_op, internal_state_op),
          feed_dict=feeds)
      obs, reward, terminal, _ = env.step(action)
      transitions.append(
          experience.ExperienceWithValues(state, obs, action, action_value,
                                          value, reward, terminal))
    replay = experience.ReplayWithValues(
        *zip(*transitions), sequence_length=len(transitions))
    if save_replay:
      stream.write(replay)
    replays.append(replay)
  return experience.RolloutsWithValues(*zip(*replays))
def rollout_with_values_on_gym_env(sess,
                                   env,
                                   state_ph,
                                   deterministic_ph,
                                   action_value_op,
                                   action_op,
                                   value_op,
                                   num_episodes=1,
                                   deterministic=False,
                                   stream=None,
                                   save_replay=True):
  """Roll out a feed-forward policy on a Gym env, recording values.

  Unlike the meta variant, no recurrent state or previous action/reward is
  fed back; each step only feeds the current observation. When `save_replay`
  is set, each episode's `experience.ReplayWithValues` is written to `stream`.

  Args:
    sess: TF session used to evaluate the policy ops.
    env: Gym-style environment (`reset()` / `step(action)`).
    state_ph: Placeholder fed `[[state]]` (batch and time dims of 1).
    deterministic_ph: Placeholder toggling greedy vs. sampled actions.
    action_value_op: Op producing the value of the chosen action.
    action_op: Op producing the action to execute.
    value_op: Op producing the state value estimate.
    num_episodes: Number of episodes to roll out.
    deterministic: Value fed to `deterministic_ph`.
    stream: Replay sink with a `.write(replay)` method; required when
      `save_replay` is True.
    save_replay: Whether to write each episode replay to `stream`.

  Returns:
    The total reward summed across all episodes (a float).

  Raises:
    ValueError: If `save_replay` is True but no `stream` was given.
  """
  if stream is None and save_replay:
    raise ValueError('missing `stream` to `save_replay`.')

  rewards = 0.
  for _ in range(num_episodes):
    transitions = []
    obs = env.reset()
    terminal = False
    while not terminal:
      state = obs
      action_value, action, value = sess.run(
          (action_value_op, action_op, value_op),
          feed_dict={
              state_ph: [[state]],
              deterministic_ph: deterministic,
          })
      obs, reward, terminal, _ = env.step(action)
      transitions.append(
          experience.ExperienceWithValues(state, obs, action, action_value,
                                          value, reward, terminal))
    replay = experience.ReplayWithValues(
        *zip(*transitions), sequence_length=len(transitions))
    if save_replay:
      stream.write(replay)
    rewards += sum(replay.reward)
  return rewards
def pad_or_truncate_map(replay):
  """Pad or truncate every field of a replay to `max_sequence_length`.

  Rebuilds the incoming feature mapping as an `experience.ReplayWithValues`
  when it carries a `'value'` key, otherwise as a plain `experience.Replay`,
  then pads/truncates each tensor along axis 0 and pins its static shape.

  NOTE(review): this reads `max_sequence_length`, `state_shape`,
  `action_shape`, `action_value_shape` and `reward_shape` as free variables —
  presumably closure variables of an enclosing dataset-map builder; confirm
  against the enclosing scope.
  """
  with_values = 'value' in replay
  if with_values:
    replay = experience.ReplayWithValues(**replay)
  else:
    replay = experience.Replay(**replay)

  sequence_length = math_ops.minimum(max_sequence_length,
                                     replay.sequence_length)
  sequence_length.set_shape([1])

  def _pad(tensor, inner_shape):
    # Zero-pad (or truncate) along time, then pin [T] + inner static shape.
    padded = sequence_utils.pad_or_truncate(
        tensor, max_sequence_length, axis=0, pad_value=0)
    padded.set_shape([max_sequence_length] + inner_shape)
    return padded

  state = _pad(replay.state, state_shape)
  next_state = _pad(replay.next_state, state_shape)
  action = _pad(replay.action, action_shape)
  action_value = _pad(replay.action_value, action_value_shape)
  reward = _pad(replay.reward, reward_shape)

  # Terminal flags are boolean, so they get a False pad value and no
  # trailing feature dims.
  terminal = sequence_utils.pad_or_truncate(
      replay.terminal,
      max_sequence_length,
      axis=0,
      pad_value=ops.convert_to_tensor(False))
  terminal.set_shape([max_sequence_length])

  if with_values:
    # State-value estimates share the reward's trailing shape.
    value = _pad(replay.value, reward_shape)
    return experience.ReplayWithValues(
        state=state,
        next_state=next_state,
        action=action,
        action_value=action_value,
        value=value,
        reward=reward,
        terminal=terminal,
        sequence_length=sequence_length)
  return experience.Replay(
      state=state,
      next_state=next_state,
      action=action,
      action_value=action_value,
      reward=reward,
      terminal=terminal,
      sequence_length=sequence_length)