def register_step(self, observation: Observation, action: Action, reward: Reward,
                  next_observation: Observation, done: bool,
                  action_metadata: dict) -> 'Step':
    """Record one environment transition and append it to the trajectory.

    Updates the running trajectory return (and, when a ``'log_pi'`` entry is
    present, the accumulated term stored as ``trajectory_entropy``), wraps the
    transition in a ``Step`` with every field reshaped to a ``(1, -1)`` row,
    and files it into the trajectory buffer and step history.

    Args:
        observation: Current observation; reshaped to a ``(1, -1)`` row.
        action: Action taken; reshaped to a ``(1, -1)`` row.
        reward: Scalar reward; stored as ``float32``.
        next_observation: Observation after the transition; reshaped likewise.
        done: Episode-termination flag; stored as a ``uint8`` mask.
        action_metadata: Extra per-action data. NOTE(review): this dict is
            mutated in place (``'returns'`` and ``'entropy'`` keys are
            added/overwritten) and the same object is stored on the returned
            ``Step`` — a caller reusing one metadata dict across steps would
            share state between steps. TODO confirm this is intended.

    Returns:
        The ``Step`` instance that was recorded.
    """
    self.trajectory_returns += reward
    # 'trajectory_entropy' accumulates the provided log_pi terms; whether it
    # represents a true entropy estimate depends on the caller — verify.
    if 'log_pi' in action_metadata:  # idiomatic membership test (no .keys())
        self.trajectory_entropy += action_metadata['log_pi'].reshape(1, -1)
    action_metadata.update(
        dict(returns=self.trajectory_returns, entropy=self.trajectory_entropy))
    step = Step(
        state=observation.reshape(1, -1),
        action=action.reshape(1, -1),
        reward=np.array(reward, dtype=np.float32).reshape(1, -1),
        next_state=next_observation.reshape(1, -1),
        termination_masks=np.array(done, dtype='uint8').reshape(1, -1),
        metadata=action_metadata)
    if self.step_count == 0:
        # First step of the trajectory: initialize_records sets up the
        # per-field records before anything is appended.
        self.initialize_records(step)
    else:
        # Subsequent steps append each Step field into its per-stat buffer.
        for stat, value in step.asdict().items():
            self.trajectory_buffer[stat].append(value)
    self.step_history.append(step)
    self.step_count += 1
    return step