コード例 #1
0
def get_transition(time_step, next_time_step, action, next_action):
    """Pack one step-to-step move into a ``dataset.Transition``.

    Observations are taken from the two time steps, the two actions are
    stored unchanged, and reward and discount are read from
    ``next_time_step``.

    Args:
      time_step: Step whose ``observation`` becomes ``s1``.
      next_time_step: Following step whose ``observation``, ``reward`` and
        ``discount`` become ``s2``, ``reward`` and ``discount``.
      action: Value stored as ``a1``.
      next_action: Value stored as ``a2``.

    Returns:
      The ``dataset.Transition`` built from those fields.
    """
    fields = {
        "s1": time_step.observation,
        "s2": next_time_step.observation,
        "a1": action,
        "a2": next_action,
        "reward": next_time_step.reward,
        "discount": next_time_step.discount,
    }
    return dataset.Transition(**fields)
コード例 #2
0
ファイル: train_eval_offline.py プロジェクト: mihdalal/d4rl
def get_offline_data(tf_env):
    """Load the environment's offline dataset into a ``dataset.Dataset``.

    The flat arrays returned by ``get_dataset()`` are turned into transitions
    where, for every kept step ``i``::

        s1 = observations[i]      s2 = observations[i + 1]
        a1 = actions[i]           a2 = actions[i + 1]
        reward = rewards[i]       discount = 1 - terminals[i + 1]

    Step ``i`` is kept only when row ``i + 1`` exists and continues the same
    episode: ``i`` is not the last row, ``terminals[i]`` is false, and, if
    the dataset provides a ``"timeouts"`` entry, ``timeouts[i]`` is false.
    Assuming episodes are stored back to back (as in D4RL), skipping timeout
    steps prevents transitions whose ``s2``/``a2`` belong to the next episode.

    Args:
      tf_env: Wrapped environment; ``tf_env.pyenv.envs[0].get_dataset()``
        must return a mapping with ``"observations"``, ``"actions"``,
        ``"rewards"`` and ``"terminals"`` (optionally ``"timeouts"``).

    Returns:
      A ``dataset.Dataset`` with capacity for every row of the offline
      dataset, populated with the kept transitions.
    """
    gym_env = tf_env.pyenv.envs[0]
    offline_dataset = gym_env.get_dataset()
    observations = offline_dataset["observations"]
    actions = offline_dataset["actions"]
    dataset_size = len(observations)
    tf_dataset = dataset.Dataset(
        tf_env.observation_spec(), tf_env.action_spec(), size=dataset_size
    )
    observation_dtype = tf_env.observation_spec().dtype
    action_dtype = tf_env.action_spec().dtype

    # Squeeze (N, 1) columns to (N,). Kept in locals rather than written back
    # so the mapping returned by get_dataset() is left untouched.
    terminals = np.squeeze(offline_dataset["terminals"])
    rewards = np.squeeze(offline_dataset["rewards"])

    # Row i + 1 is a valid successor of row i only if step i does not end its
    # episode, either by termination or (when recorded) by the time limit.
    episode_ends = terminals
    timeouts = offline_dataset.get("timeouts")
    if timeouts is not None:
        episode_ends = np.logical_or(episode_ends, np.squeeze(timeouts))
    (nonterminal_steps,) = np.where(
        np.logical_and(
            np.logical_not(episode_ends),
            np.arange(dataset_size) < dataset_size - 1,
        )
    )
    next_steps = nonterminal_steps + 1
    logging.info(
        "Found %d non-terminal steps out of a total of %d steps.",
        len(nonterminal_steps),
        dataset_size,
    )

    s1 = tf.convert_to_tensor(observations[nonterminal_steps], dtype=observation_dtype)
    s2 = tf.convert_to_tensor(observations[next_steps], dtype=observation_dtype)
    a1 = tf.convert_to_tensor(actions[nonterminal_steps], dtype=action_dtype)
    a2 = tf.convert_to_tensor(actions[next_steps], dtype=action_dtype)
    # NOTE(review): the discount is read from terminals[i + 1], so it is zero
    # when the step *after* s2 terminates; steps with terminals[i] set are
    # dropped above and never appear with a zero discount. Confirm this
    # one-step offset is the intended bootstrapping convention.
    discount = tf.convert_to_tensor(1.0 - terminals[next_steps], dtype=tf.float32)
    reward = tf.convert_to_tensor(rewards[nonterminal_steps], dtype=tf.float32)

    # Keyword arguments so the result does not depend on the field order of
    # dataset.Transition.
    transitions = dataset.Transition(
        s1=s1, s2=s2, a1=a1, a2=a2, discount=discount, reward=reward
    )

    tf_dataset.add_transitions(transitions)
    return tf_dataset