def get_transition(time_step, next_time_step, action, next_action):
    """Packs two consecutive time steps and their actions into a Transition.

    Args:
        time_step: current time step; supplies the source observation (s1).
        next_time_step: following time step; supplies the next observation
            (s2) plus the reward and discount observed on arrival.
        action: action taken at `time_step`.
        next_action: action taken at `next_time_step`.

    Returns:
        A `dataset.Transition` namedtuple holding the (s1, s2, a1, a2,
        reward, discount) tuple for this step pair.
    """
    fields = {
        "s1": time_step.observation,
        "s2": next_time_step.observation,
        "a1": action,
        "a2": next_action,
        "reward": next_time_step.reward,
        "discount": next_time_step.discount,
    }
    return dataset.Transition(**fields)
def get_offline_data(tf_env):
    """Loads the wrapped gym environment's offline dataset into a `dataset.Dataset`.

    Pulls a D4RL-style dict (keys: "observations", "actions", "rewards",
    "terminals") from the underlying gym environment, drops terminal steps and
    the final step (which have no valid successor), and stores the resulting
    (s1, s2, a1, a2, reward, discount) transitions.

    Args:
        tf_env: a TF environment wrapping a single gym environment that
            exposes `get_dataset()`.

    Returns:
        A populated `dataset.Dataset` sized to the raw offline dataset.
    """
    gym_env = tf_env.pyenv.envs[0]
    # NOTE(review): some wrappers need `gym_env.unwrapped.get_dataset()`;
    # this assumes the wrapper forwards get_dataset() directly — confirm.
    offline_dataset = gym_env.get_dataset()
    dataset_size = len(offline_dataset["observations"])
    tf_dataset = dataset.Dataset(
        tf_env.observation_spec(),
        tf_env.action_spec(),
        size=dataset_size,
    )
    observation_dtype = tf_env.observation_spec().dtype
    action_dtype = tf_env.action_spec().dtype
    # Flatten possible (N, 1)-shaped arrays to (N,) so masking/indexing works.
    offline_dataset["terminals"] = np.squeeze(offline_dataset["terminals"])
    offline_dataset["rewards"] = np.squeeze(offline_dataset["rewards"])
    # Keep only steps that are non-terminal AND have a successor (i + 1 valid).
    (nonterminal_steps,) = np.where(
        np.logical_and(
            np.logical_not(offline_dataset["terminals"]),
            np.arange(dataset_size) < dataset_size - 1,
        )
    )
    # Lazy %-args: formatting is skipped entirely when INFO is disabled.
    logging.info(
        "Found %d non-terminal steps out of a total of %d steps.",
        len(nonterminal_steps),
        dataset_size,
    )
    s1 = tf.convert_to_tensor(
        offline_dataset["observations"][nonterminal_steps],
        dtype=observation_dtype,
    )
    s2 = tf.convert_to_tensor(
        offline_dataset["observations"][nonterminal_steps + 1],
        dtype=observation_dtype,
    )
    a1 = tf.convert_to_tensor(
        offline_dataset["actions"][nonterminal_steps], dtype=action_dtype
    )
    a2 = tf.convert_to_tensor(
        offline_dataset["actions"][nonterminal_steps + 1], dtype=action_dtype
    )
    # Discount is 0 when the *next* step is terminal, else 1.
    discount = tf.convert_to_tensor(
        1.0 - offline_dataset["terminals"][nonterminal_steps + 1],
        dtype=tf.float32,
    )
    reward = tf.convert_to_tensor(
        offline_dataset["rewards"][nonterminal_steps], dtype=tf.float32
    )
    # Use keyword arguments, matching get_transition(): the previous
    # positional call (…, discount, reward) risked swapping the reward and
    # discount fields if Transition's field order is (…, reward, discount).
    transitions = dataset.Transition(
        s1=s1, s2=s2, a1=a1, a2=a2, reward=reward, discount=discount
    )
    tf_dataset.add_transitions(transitions)
    return tf_dataset