Example #1
    def loss_fcn(self, rollout_real: StepSequence,
                 rollout_sim: StepSequence) -> float:
        """
        Compute the discrepancy between two time series of observations using the given metric.
        Be sure to align and truncate the rollouts beforehand.

        :param rollout_real: (concatenated) real-world rollout containing the observations
        :param rollout_sim: (concatenated) simulated rollout containing the observations
        :return: discrepancy cost summed over the observation dimensions
        """
        if len(rollout_real) != len(rollout_sim):
            raise pyrado.ShapeErr(given=rollout_real,
                                  expected_match=rollout_sim)

        # Extract the observations
        real_obs = rollout_real.get_data_values("observations",
                                                truncate_last=True)
        sim_obs = rollout_sim.get_data_values("observations",
                                              truncate_last=True)

        # Filter the observations
        real_obs = gaussian_filter1d(real_obs, self.std_obs_filt, axis=0)
        sim_obs = gaussian_filter1d(sim_obs, self.std_obs_filt, axis=0)

        # Normalize the signals
        real_obs_norm = self.obs_normalizer.project_to(real_obs)
        sim_obs_norm = self.obs_normalizer.project_to(sim_obs)

        # Compute loss based on the error
        loss_per_obs_dim = self.metric(real_obs_norm - sim_obs_norm)
        assert len(loss_per_obs_dim) == real_obs.shape[1]
        assert all(loss_per_obs_dim >= 0)
        return sum(loss_per_obs_dim)
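
The method above relies on three attributes of the enclosing class: `std_obs_filt`, `obs_normalizer`, and `metric`. A minimal sketch of what they could look like, assuming a min-max normalizer and a squared-error metric; both classes and all concrete values below are illustrative assumptions, not Pyrado's actual implementation:

import numpy as np


class _MinMaxNormalizer:
    """Hypothetical normalizer exposing the project_to() interface used by loss_fcn()."""

    def __init__(self, bound_lo: np.ndarray, bound_up: np.ndarray):
        self.bound_lo, self.bound_up = bound_lo, bound_up

    def project_to(self, data: np.ndarray) -> np.ndarray:
        # Scale every observation dimension to the unit interval
        return (data - self.bound_lo) / (self.bound_up - self.bound_lo)


class _LossFcnOwnerSketch:
    """Hypothetical owner of loss_fcn(), shown only to make the assumed interface explicit."""

    def __init__(self, bound_lo: np.ndarray, bound_up: np.ndarray):
        self.std_obs_filt = 5  # std of the Gaussian filter in time steps (assumed value)
        self.obs_normalizer = _MinMaxNormalizer(bound_lo, bound_up)
        # Reduce the (num_steps x dim_obs) error signal to one non-negative value per obs dim
        self.metric = lambda err: np.sum(np.power(err, 2), axis=0)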
Example #2
def convert_step_sequence(traj: StepSequence):
    """
    Convert a `StepSequence` to a tensor which can be fed through a network.

    :param traj: a step sequence containing a trajectory
    :return: a tensor containing the trajectory
    """
    assert isinstance(traj, StepSequence)
    traj.torch()

    # Extract the data; the observations hold one more entry than the actions
    state = traj.get_data_values('observations')[:-1].double()
    next_state = traj.get_data_values('observations')[1:].double()
    action = traj.get_data_values('actions').narrow(
        0, 0, next_state.shape[0]).double()

    # Stack the states, next states, and actions column-wise
    traj = to.cat((state, next_state, action), 1).cpu().double()
    return traj
Example #3
def preprocess_rollout(rollout: StepSequence) -> StepSequence:
    """
    Extracts observations and actions from a `StepSequence` and packs them into a PyTorch tensor which can be fed
    through a network.

    :param rollout: a `StepSequence` instance containing a trajectory
    :return: a PyTorch tensor containing the trajectory
    """
    if not isinstance(rollout, StepSequence):
        raise pyrado.TypeErr(given=rollout, expected_type=StepSequence)

    # Convert data type
    rollout.torch(to.get_default_dtype())

    # Extract the data
    state = rollout.get_data_values("observations")[:-1]
    next_state = rollout.get_data_values("observations")[1:]
    action = rollout.get_data_values("actions").narrow(0, 0,
                                                       next_state.shape[0])

    # Stack the states, next states, and actions column-wise
    rollout = to.cat((state, next_state, action), 1)
    return rollout
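
For illustration, a hedged usage sketch of `preprocess_rollout()` (the same layout applies to `convert_step_sequence()` from Example #2). Assuming Pyrado's convention that a rollout of `T` steps stores `T + 1` observations (including the terminal one) and `T` actions, the returned tensor has `T` rows whose columns are laid out as `[s_t | s_{t+1} | a_t]`:

data = preprocess_rollout(ro)  # ro: a recorded StepSequence
dim_obs = ro.get_data_values("observations").shape[1]
dim_act = ro.get_data_values("actions").shape[1]
assert data.shape[1] == 2*dim_obs + dim_act

# Recover the column blocks, e.g. to train a dynamics model on (s_t, a_t) -> s_{t+1}
state, next_state, action = data.split([dim_obs, dim_obs, dim_act], dim=1)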
Example #4
    def evaluate(self,
                 rollout: StepSequence,
                 hidden_states_name: str = 'hidden_states') -> to.Tensor:
        """
        Re-evaluate the given rollout and return a differentiable action tensor.
        The default implementation simply calls `forward()`.

        :param rollout: recorded, complete rollout
        :param hidden_states_name: name of hidden states rollout entry, used for recurrent networks.
                                   Defaults to 'hidden_states'. Change for value functions.
        :return: actions with gradient data
        """
        # Set policy, i.e. PyTorch nn.Module, to evaluation mode
        self.eval()
        return self(rollout.get_data_values(
            'observations', truncate_last=True))  # all observations at once
Example #5
    def evaluate(self,
                 rollout: StepSequence,
                 hidden_states_name: str = 'hidden_states') -> to.Tensor:
        """
        Re-evaluate the given rollout and return a differentiable action tensor.
        The default implementation simply calls `forward()`.

        :param rollout: complete rollout
        :param hidden_states_name: name of hidden states rollout entry, used for recurrent networks.
                                   Defaults to 'hidden_states'. Change for value functions.
        :return: actions with gradient data
        """
        # Set policy, i.e. PyTorch nn.Module, to evaluation mode
        self.eval()

        res = self(rollout.get_data_values(
            'observations', truncate_last=True))  # all observations at once

        # Set policy, i.e. PyTorch nn.Module, back to training mode
        self.train()

        return res
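
A hedged usage sketch for both `evaluate()` variants above: the forward pass runs in evaluation mode, so modules like dropout act deterministically, while the returned actions still carry gradient information through the policy parameters. The recording helper and the objective below are placeholders:

ro = rollout(env, policy, eval=True)  # record a rollout, e.g. with Pyrado's sampling helper
acts = policy.evaluate(ro)            # one action per observation, terminal observation truncated
assert acts.requires_grad             # gradients flow back into the policy parameters

loss = -acts.sum()  # placeholder objective, only for illustration
loss.backward()

The second variant additionally restores training mode afterwards, which matters when the policy is updated right after the evaluation: a subsequent `forward()` call would otherwise still run with e.g. dropout disabled.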