def __call__(self, batch: Dict[str, torch.Tensor]) -> rlt.PolicyNetworkInput:
    batch = batch_to_device(batch, self.device)
    # Densify presence-encoded features via the fitted preprocessors.
    preprocessed_state = self.state_preprocessor(
        batch["state_features"], batch["state_features_presence"]
    )
    preprocessed_next_state = self.state_preprocessor(
        batch["next_state_features"], batch["next_state_features_presence"]
    )
    preprocessed_action = self.action_preprocessor(
        batch["action"], batch["action_presence"]
    )
    preprocessed_next_action = self.action_preprocessor(
        batch["next_action"], batch["next_action_presence"]
    )
    return rlt.PolicyNetworkInput(
        state=rlt.FeatureData(preprocessed_state),
        next_state=rlt.FeatureData(preprocessed_next_state),
        action=rlt.FeatureData(preprocessed_action),
        next_action=rlt.FeatureData(preprocessed_next_action),
        # Scalar fields arrive as 1-D tensors; unsqueeze to (batch_size, 1).
        reward=batch["reward"].unsqueeze(1),
        time_diff=batch["time_diff"].unsqueeze(1),
        step=batch["step"].unsqueeze(1),
        not_terminal=batch["not_terminal"].unsqueeze(1),
        extras=rlt.ExtraData(
            mdp_id=batch["mdp_id"].unsqueeze(1),
            sequence_number=batch["sequence_number"].unsqueeze(1),
            action_probability=batch["action_probability"].unsqueeze(1),
        ),
    )
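
# A hedged sketch of the presence-encoded batch this __call__ expects. The key
# names are taken from the dict lookups above; the shapes (batch of 2, three
# state features, one action feature) are illustrative assumptions only.
#
#   import torch
#
#   example_batch = {
#       "state_features": torch.rand(2, 3),
#       "state_features_presence": torch.ones(2, 3),  # 1.0 = feature present
#       "next_state_features": torch.rand(2, 3),
#       "next_state_features_presence": torch.ones(2, 3),
#       "action": torch.rand(2, 1),
#       "action_presence": torch.ones(2, 1),
#       "next_action": torch.rand(2, 1),
#       "next_action_presence": torch.ones(2, 1),
#       "reward": torch.rand(2),        # 1-D; unsqueezed to (2, 1) above
#       "time_diff": torch.ones(2),
#       "step": torch.ones(2),
#       "not_terminal": torch.ones(2),
#       "mdp_id": torch.arange(2),
#       "sequence_number": torch.zeros(2),
#       "action_probability": torch.full((2,), 0.5),
#   }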
def as_policy_network_training_batch(self):
    return rlt.PolicyNetworkInput(
        state=rlt.FeatureData(float_features=self.states),
        action=rlt.FeatureData(float_features=self.actions),
        next_state=rlt.FeatureData(float_features=self.next_states),
        next_action=rlt.FeatureData(float_features=self.next_actions),
        reward=self.rewards,
        not_terminal=self.not_terminal,
        step=self.step,
        time_diff=self.time_diffs,
        extras=rlt.ExtraData(),
    )
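
# Hedged sketch of the container as_policy_network_training_batch is assumed
# to live on, reconstructed from the attributes it reads above; the name
# TrainingBatchSketch and the plain-dataclass layout are hypothetical.
#
#   from dataclasses import dataclass
#   import torch
#
#   @dataclass
#   class TrainingBatchSketch:
#       states: torch.Tensor        # (batch_size, state_dim)
#       actions: torch.Tensor       # (batch_size, action_dim)
#       next_states: torch.Tensor
#       next_actions: torch.Tensor
#       rewards: torch.Tensor       # (batch_size, 1)
#       not_terminal: torch.Tensor
#       step: torch.Tensor
#       time_diffs: torch.Tensor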
def __call__(self, batch):
    not_terminal = 1.0 - batch.terminal.float()
    # TODO: We need to normalize the action in here
    return rlt.PolicyNetworkInput(
        state=rlt.FeatureData(float_features=batch.state),
        action=rlt.FeatureData(float_features=batch.action),
        next_state=rlt.FeatureData(float_features=batch.next_state),
        next_action=rlt.FeatureData(float_features=batch.next_action),
        reward=batch.reward,
        not_terminal=not_terminal,
        step=None,
        time_diff=None,
        extras=rlt.ExtraData(
            mdp_id=None,
            sequence_number=None,
            # The batch stores log-probabilities; exponentiate to recover
            # the behavior policy's action probability.
            action_probability=batch.log_prob.exp(),
            max_num_actions=None,
            metrics=None,
        ),
    )
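
# Worked example of the two conversions above; the values are illustrative.
#
#   import torch
#
#   terminal = torch.tensor([[0.0], [1.0]])          # 1.0 marks end of episode
#   not_terminal = 1.0 - terminal.float()            # -> [[1.0], [0.0]]
#
#   log_prob = torch.tensor([[-0.6931], [-2.3026]])  # stored log-probabilities
#   action_probability = log_prob.exp()              # -> ~[[0.5], [0.1]]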
def __call__(self, batch):
    not_terminal = 1.0 - batch.terminal.float()
    # Rescale actions from the environment's range into the training range.
    (train_low, train_high) = CONTINUOUS_TRAINING_ACTION_RANGE
    action = torch.tensor(
        rescale_actions(
            batch.action.numpy(),
            new_min=train_low,
            new_max=train_high,
            prev_min=self.action_low,
            prev_max=self.action_high,
        )
    )
    # Only rescale next_action for non-terminal transitions; terminal rows
    # have no valid next action and are left as zeros.
    non_terminal_indices = (batch.terminal == 0).squeeze(1)
    next_action = torch.zeros_like(action)
    next_action[non_terminal_indices] = torch.tensor(
        rescale_actions(
            batch.next_action[non_terminal_indices].numpy(),
            new_min=train_low,
            new_max=train_high,
            prev_min=self.action_low,
            prev_max=self.action_high,
        )
    )
    return rlt.PolicyNetworkInput(
        state=rlt.FeatureData(float_features=batch.state),
        action=rlt.FeatureData(float_features=action),
        next_state=rlt.FeatureData(float_features=batch.next_state),
        next_action=rlt.FeatureData(float_features=next_action),
        reward=batch.reward,
        not_terminal=not_terminal,
        step=None,
        time_diff=None,
        extras=rlt.ExtraData(
            mdp_id=None,
            sequence_number=None,
            action_probability=batch.log_prob.exp(),
            max_num_actions=None,
            metrics=None,
        ),
    )
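
# Hedged sketch of what rescale_actions is assumed to do: an element-wise
# linear map from [prev_min, prev_max] onto [new_min, new_max]. The real
# helper may differ in signature or validation; this only illustrates the math.
#
#   import numpy as np
#
#   def rescale_actions_sketch(actions, new_min, new_max, prev_min, prev_max):
#       prev_range = prev_max - prev_min
#       new_range = new_max - new_min
#       return ((actions - prev_min) / prev_range) * new_range + new_min
#
#   # e.g. mapping environment actions in [-2, 2] onto a training range [-1, 1]:
#   acts = np.array([[-2.0], [0.0], [2.0]])
#   rescale_actions_sketch(acts, -1.0, 1.0, -2.0, 2.0)  # [[-1.], [0.], [1.]]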