Esempio n. 1
0
    def scale(self, state, action):
        """Get the epistemic scale of the model in the original coordinates.

        The state-action pair is passed through the forward transformations,
        the scale of the base model is evaluated on the transformed pair, and
        the result is mapped back with the inverse of the reverse
        transformations.
        """
        none = torch.tensor(0)  # Placeholder for every field that is not used.
        obs = Observation(state, action, none, none, none, none, none, none, none, none)
        for forward in self.forward_transformations:
            obs = forward(obs)

        # Evaluate the scale on the transformed state-action pair.
        transformed_state, transformed_action = obs.state, obs.action
        epistemic_scale = self.base_model.scale(transformed_state, transformed_action)

        # Store the scale as the next-state scale so the inverse
        # transformations can map it back.
        obs = Observation(
            transformed_state,
            transformed_action,
            reward=none,
            done=none,
            next_action=none,
            log_prob_action=none,
            entropy=none,
            state_scale_tril=none,
            next_state=transformed_state,
            next_state_scale_tril=epistemic_scale,
        )

        for backward in self.reverse_transformations:
            obs = backward.inverse(obs)
        return obs.next_state_scale_tril
Esempio n. 2
0
    def test_get_item(self, discrete, max_len, num_steps):
        """Check the items returned when indexing an experience replay.

        Every attribute of a returned observation must have a leading
        dimension of ``max(1, num_steps)``, the weight must be 1.0 and the
        returned index must equal the requested one exactly when that
        position is marked as valid.
        """
        num_episodes = 3
        episode_length = 200
        memory = create_er_from_episodes(
            discrete, max_len, num_steps, num_episodes, episode_length
        )
        memory.end_episode()
        expected_len = max(1, num_steps)

        def assert_leading_dim(observation):
            """Assert that every attribute has the expected leading dimension."""
            for attribute in Observation(**observation):
                assert attribute.shape[0] == expected_len

        # The index and weight checks do not depend on the attribute, so they
        # are done once per item instead of once per attribute.
        observation, idx, weight = memory[0]
        assert_leading_dim(observation)
        assert idx == 0
        assert weight == 1.0

        for i in range(len(memory)):
            observation, idx, weight = memory[i]
            assert_leading_dim(observation)
            if memory.valid[i]:
                assert idx == i
            else:
                assert idx != i
            assert weight == 1.0

        i = np.random.choice(memory.valid_indexes).item()
        observation, idx, weight = memory[i]
        assert_leading_dim(observation)
        assert idx == i
        assert weight == 1.0
Esempio n. 3
0
 def test_equality(self):
     """Two observations built from the same data must hold equal values."""
     state, action, reward, next_state, done = self.init()
     first = Observation(state, action, reward, next_state, done).to_torch()
     second = Observation(state, action, reward, next_state, done).to_torch()
     for x, y in zip(first, second):
         torch.testing.assert_allclose(x, y)
     # A freshly built observation is a distinct object.
     assert first is not Observation(state, action, reward, next_state, done)
Esempio n. 4
0
def _train_model_step(model, observation, optimizer, mask, logger):
    """Run one training step of `model` on `observation` and log the loss.

    Parameters
    ----------
    model:
        Model to train; must be an EnsembleModel, NNModel or ExactGPModel.
    observation: Observation or mapping
        A mapping is converted with ``Observation(**observation)``.
        Its ``action`` attribute is overwritten in place with the first
        ``model.dim_action[0]`` components of the last axis.
    optimizer:
        Optimizer forwarded to the training routine.
    mask:
        Only forwarded to the ensemble training routine.
    logger:
        Receives the scalar loss under the key ``"<kind>-loss"``, where
        ``<kind>`` is the first three characters of ``model.model_kind``.

    Raises
    ------
    TypeError
        If `model` is none of the supported model types.
    """
    if not isinstance(observation, Observation):
        observation = Observation(**observation)
    observation.action = observation.action[..., : model.dim_action[0]]

    # The order of the isinstance checks is kept from the original dispatch.
    if isinstance(model, EnsembleModel):
        loss = train_ensemble_step(model, observation, optimizer, mask)
    elif isinstance(model, NNModel):
        loss = train_nn_step(model, observation, optimizer)
    elif isinstance(model, ExactGPModel):
        loss = train_exact_gp_type2mll_step(model, observation, optimizer)
    else:
        # The message lists every supported type (NN models were missing).
        raise TypeError("Only Implemented for Ensembles, NN and Exact GP Models.")
    logger.update(**{f"{model.model_kind[:3]}-loss": loss.item()})
Esempio n. 5
0
    def test_example(self, discrete, dim_state, dim_action, kind):
        """Check the example observations built for each supported `kind`.

        For "nan", "zero" and "random" the shapes of state, action and next
        state and the value of ``log_prob_action`` are checked.  Any other
        kind must make ``Observation.get_example`` raise ``ValueError``.
        """
        if discrete:
            num_states, num_actions = dim_state, dim_action
            dim_state, dim_action = (), ()
        else:
            num_states, num_actions = -1, -1
            dim_state, dim_action = (dim_state,), (dim_action,)

        example_kwargs = dict(
            dim_state=dim_state,
            dim_action=dim_action,
            num_states=num_states,
            num_actions=num_actions,
        )

        if kind == "nan":
            o = Observation.nan_example(**example_kwargs)
        elif kind == "zero":
            o = Observation.zero_example(**example_kwargs)
        elif kind == "random":
            o = Observation.random_example(**example_kwargs)
        else:
            with pytest.raises(ValueError):
                Observation.get_example(kind=kind, **example_kwargs)
            return

        # Discrete examples are scalars; continuous ones have the given dims.
        state_shape = torch.Size([]) if discrete else torch.Size(dim_state)
        action_shape = torch.Size([]) if discrete else torch.Size(dim_action)

        torch.testing.assert_allclose(o.state.shape, state_shape)
        torch.testing.assert_allclose(o.action.shape, action_shape)
        torch.testing.assert_allclose(o.next_state.shape, state_shape)
        torch.testing.assert_allclose(o.log_prob_action, torch.tensor(1.0))
Esempio n. 6
0
    def predict(self, state, action, next_state=None):
        """Get the prediction of the wrapped model in the original coordinates.

        Depending on ``self.model_kind`` the base model predicts the next state
        ("dynamics"), the reward ("rewards") or the termination signal
        ("termination").  The inputs go through the forward transformations
        and the prediction goes through the inverse of the reverse
        transformations.

        Returns
        -------
        For "dynamics": ``(next_state, next_state_scale_tril)``.
        For "rewards": ``(reward, reward_scale_tril)``.
        For "termination": ``done``.

        Raises
        ------
        ValueError
            If ``self.model_kind`` is none of the kinds above.
        """
        none = torch.tensor(0)  # Placeholder for every unused field.
        if next_state is None:
            next_state = none
        obs = Observation(
            state, action, none, next_state, none, none, none, none, none, none
        )
        for forward in self.forward_transformations:
            obs = forward(obs)

        # Outputs that the selected model does not produce stay placeholders.
        reward, done = (none, none), none
        predicted_next_state = (none, none)
        if self.model_kind == "dynamics":
            predicted_next_state = self.base_model(
                obs.state, obs.action, obs.next_state
            )
        elif self.model_kind == "rewards":
            reward = self.base_model(obs.state, obs.action, obs.next_state)
        elif self.model_kind == "termination":
            done = self.base_model(obs.state, obs.action, obs.next_state)
        else:
            raise ValueError(f"{self.model_kind} not in {self.allowed_model_kind}")

        # Pack the prediction so the inverse transformations can be applied.
        obs = Observation(
            obs.state,
            obs.action,
            reward=reward[0],
            done=done,
            next_action=none,
            log_prob_action=none,
            entropy=none,
            state_scale_tril=none,
            next_state=predicted_next_state[0],
            next_state_scale_tril=predicted_next_state[1],
            reward_scale_tril=reward[1],
        )
        for backward in self.reverse_transformations:
            obs = backward.inverse(obs)

        if self.model_kind == "dynamics":
            return obs.next_state, obs.next_state_scale_tril
        if self.model_kind == "rewards":
            return obs.reward, obs.reward_scale_tril
        if self.model_kind == "termination":
            return obs.done
Esempio n. 7
0
    def __getitem__(self, idx):
        """Return the (sub-)trajectory stored at position `idx`.

        When ``self.sequence_length`` is None the whole trajectory at `idx` is
        used.  Otherwise `idx` selects a ``(trajectory index, start)`` pair and
        every field of that trajectory is sliced to
        ``[start, start + sequence_length)``.  All transformations are applied
        to the result before it is returned.

        Parameters
        ----------
        idx: int

        Returns
        -------
        sub-trajectory: Observation
        """
        if self.sequence_length is None:  # Whole trajectory.
            observation = self._trajectories[idx]
        else:  # Window of a trajectory.
            trajectory_idx, start = self._sub_trajectory_indexes[idx]
            window = slice(start, start + self._sequence_length)
            fields = asdict(self._trajectories[trajectory_idx])
            observation = Observation(
                **{name: value[window] for name, value in fields.items()}
            )

        for transform in self.transformations:
            observation = transform(observation)

        return observation
Esempio n. 8
0
def init_er_from_rollout(target_er, agent, environment, max_steps=1000):
    """Initialize an Experience Replay from rollouts of an agent.

    Episodes are rolled out until `target_er` reports that it is full.  Every
    transition is appended to `target_er` as soon as it is observed.

    Parameters
    ----------
    target_er: Experience Replay.
        Experience replay to be filled.
    agent: Agent.
        Agent to act in environment.
    environment: Environment.
        Environment in which the agent is rolled out.
    max_steps: int.
        An episode is cut once ``environment.time`` reaches this value.
    """
    while not target_er.is_full:
        state = environment.reset()
        done = False
        while not done:
            action = agent.act(state)
            next_state, reward, done, _ = environment.step(action)
            transition = Observation(
                state=state,
                action=action,
                reward=reward,
                next_state=next_state,
                done=done,
            ).to_torch()
            target_er.append(transition)

            state = next_state
            if max_steps <= environment.time:
                break
Esempio n. 9
0
def _validate_model_step(model, observation, logger):
    """Evaluate `model` on `observation`, log the metrics and return the MSE.

    A mapping is converted with ``Observation(**observation)``; the ``action``
    attribute is overwritten in place with the first ``model.dim_action[0]``
    components of the last axis.  The logged keys are prefixed with the first
    three characters of ``model.model_kind``.
    """
    if not isinstance(observation, Observation):
        observation = Observation(**observation)
    observation.action = observation.action[..., : model.dim_action[0]]

    mse = model_mse(model, observation).item()
    sharpness_value = sharpness(model, observation).item()
    calibration_value = calibration_score(model, observation).item()

    prefix = model.model_kind[:3]
    metrics = {
        f"{prefix}-val-mse": mse,
        f"{prefix}-sharp": sharpness_value,
        f"{prefix}-calib": calibration_value,
    }
    logger.update(**metrics)
    return mse
Esempio n. 10
0
def step_env(environment, state, action, action_scale, pi=None, render=False):
    """Perform a single step in an environment.

    Returns
    -------
    observation: Observation
        Transition converted with ``to_torch()``.
    state:
        The next state returned by the environment.
    done, info:
        As returned by ``environment.step``.
    """
    try:
        next_state, reward, done, info = environment.step(action)
    except TypeError:
        # Retry with the plain Python value of the action.
        next_state, reward, done, info = environment.step(action.item())

    if not isinstance(action, torch.Tensor):
        action = torch.tensor(action, dtype=torch.get_default_dtype())

    # Defaults used without a policy or when the computation below fails.
    entropy, log_prob_action = 0.0, 1.0
    if pi is not None:
        try:
            with torch.no_grad():
                entropy, log_prob_action = get_entropy_and_log_p(
                    pi, action, action_scale
                )
        except RuntimeError:
            # Keep the defaults assigned above.
            entropy, log_prob_action = 0.0, 1.0

    transition = Observation(
        state=state,
        action=action,
        reward=reward,
        next_state=next_state,
        done=done,
        entropy=entropy,
        log_prob_action=log_prob_action,
    ).to_torch()
    if render:
        environment.render()
    return transition, next_state, done, info
Esempio n. 11
0
def get_primal_q_reps(
    prior, eta, value_function, q_function, transitions, rewards, gamma
):
    """Get Primal for Q-REPS.

    Returns
    -------
    mu:
        Distribution of shape ``prior.shape`` built from the exact TD errors
        plus ``V(s) - Q(s, a)``.
    d:
        Distribution of shape ``prior.shape`` built from ``Q(s, a) - V(s)``.
    """
    num_states, num_actions = prior.shape
    td = torch.zeros(num_states, num_actions)
    states = torch.arange(num_states)

    def _to_distribution(signal):
        """Turn `signal` into a joint distribution weighted by the prior."""
        shift = torch.max(eta * signal)  # Subtracted inside the exponential.
        log_weights = torch.log(prior * torch.exp(eta * signal - shift)) + shift
        dist = torch.distributions.Categorical(logits=log_weights.reshape(-1))
        return dist.probs.reshape(num_states, num_actions)

    for action_idx in range(num_actions):
        action = torch.tensor(action_idx)
        observation = Observation(state=states, action=action)
        td[:, action] = _compute_exact_td(
            value_function, observation, transitions, rewards, gamma,
        )

    td = td + value_function(states).unsqueeze(-1) - q_function(states)
    mu = _to_distribution(td)

    adv = q_function(states) - value_function(states).unsqueeze(-1)
    d = _to_distribution(adv)

    return mu, d
Esempio n. 12
0
 def closure():
     """Sample a batch of states, back-propagate the loss and return it."""
     batch = self.sim_dataset.sample_batch(self.policy_opt_batch_size)
     states = Observation(state=batch)
     self.optimizer.zero_grad()
     losses = self.algorithm(states)
     losses.combined_loss.backward()
     return losses
Esempio n. 13
0
def test_append_error():
    """Appending a one-step trajectory with sequence_length=10 must fail."""
    dataset = TrajectoryDataset(sequence_length=10)
    state = np.random.randn(4)
    action = np.random.randn(2)
    next_state = np.random.randn(4)
    single_step = Observation(state, action, 1, next_state, True).to_torch()
    with pytest.raises(ValueError):
        dataset.append([single_step])
Esempio n. 14
0
 def sample_batch(self, batch_size):
     """Sample a batch of observations.

     Indexes are drawn with ``np.random.choice`` from ``self.valid_indexes``.
     Returns a tuple ``(observation, indexes, weights)``.
     """
     indices = np.random.choice(self.valid_indexes, batch_size)
     if self.num_steps != 0:
         # Collate the individual items into a single batch.
         obs, idx, weight = default_collate([self[i] for i in indices])
         return Observation(**obs), idx, weight

     obs = self._get_observation(indices)
     return obs, torch.tensor(indices), self.weights[indices]
Esempio n. 15
0
def get_observation(reward=None):
    """Build a random observation, optionally with a fixed reward.

    Parameters
    ----------
    reward: optional
        Reward to store in the observation.  A random reward of shape (1,) is
        drawn only when `reward` is None, so falsy values such as 0 are kept.

    Returns
    -------
    Observation
        Observation converted with ``to_torch()``.
    """
    return Observation(
        state=torch.randn(4),
        action=torch.randn(4),
        # Compare against None: a truthiness test would discard a reward of 0.
        reward=torch.randn(1) if reward is None else reward,
        next_state=torch.randn(4),
        done=False,
        state_scale_tril=torch.randn(4, 4),
        next_state_scale_tril=torch.randn(4, 4),
    ).to_torch()
Esempio n. 16
0
 def forward(self, state):
     """Return policy logits.

     The logits are the exact TD errors for `state`, scaled by ``self.eta``
     and ``self.counter``.
     """
     observation = Observation(state=state)
     td = compute_exact_td(
         value_function=self.value_function,
         observation=observation,
         transitions=self.transitions,
         rewards=self.rewards,
         gamma=self.gamma,
         support="state",
     )
     return self.eta * td * self.counter
Esempio n. 17
0
    def test_clone(self, discrete, dim_state, dim_action):
        """A clone must compare equal to, but share nothing with, the original."""
        if discrete:
            num_states, num_actions = dim_state, dim_action
            dim_state, dim_action = (), ()
        else:
            num_states, num_actions = -1, -1
            dim_state, dim_action = (dim_state,), (dim_action,)

        original = Observation.random_example(
            dim_state=dim_state,
            dim_action=dim_action,
            num_states=num_states,
            num_actions=num_actions,
        )
        cloned = original.clone()

        assert original is not cloned
        assert original == cloned
        for field, cloned_field in zip(original, cloned):
            assert Observation._is_equal_nan(field, cloned_field)
            assert field is not cloned_field
Esempio n. 18
0
def get_trajectory():
    """Return three random transitions with rewards 3.0, -2.0 and 0.5."""
    return [
        Observation(
            state=torch.randn(4),
            action=torch.randn(2),
            reward=reward,
            next_state=torch.randn(4),
            done=False,
        )
        for reward in (3.0, -2.0, 0.5)
    ]
Esempio n. 19
0
    def _get_consecutive_observations(self, start_idx, num_steps):
        """Return the observations starting at `start_idx`, stacked together.

        Parameters
        ----------
        start_idx:
            Position in ``self.memory``.  With ``num_steps == 0`` a non-int
            value is used directly to index the memory (presumably an array
            of indexes -- confirm against the callers).
        num_steps: int
            Number of consecutive observations; values below 1 are treated
            as 1.

        Returns
        -------
        The stacked observations.  In the ``num_steps == 0`` non-int case each
        stacked field additionally gets a new axis at position 1.
        """
        # `np.int` was a plain alias of the builtin `int` and no longer exists
        # in NumPy >= 1.24, so checking against `int` alone is equivalent and
        # does not raise AttributeError.
        if num_steps == 0 and not isinstance(start_idx, int):
            observation = stack_list_of_tuples(self.memory[start_idx])
            return Observation(*(x.unsqueeze(1) for x in observation))

        num_steps = max(1, num_steps)
        end_idx = start_idx + num_steps
        if end_idx <= self.max_len:
            obs_list = self.memory[start_idx:end_idx]
        else:  # The trajectory is split by the circular buffer.
            delta_idx = end_idx - self.max_len
            obs_list = np.concatenate(
                (self.memory[start_idx : self.max_len], self.memory[:delta_idx])
            )

        return stack_list_of_tuples(obs_list)
Esempio n. 20
0
    def test_correctness(self, gamma, value_function, entropy_reg):
        """Compare `mc_return` against a hand-computed discounted return."""
        rewards_and_entropies = [(1, 0.2), (0.5, 0.3), (2, 0.5), (-0.2, -0.2)]
        trajectory = [
            Observation(
                0, 0, reward=reward, done=False, entropy=entropy
            ).to_torch()
            for reward, entropy in rewards_and_entropies
        ]

        # Entropy-regularized rewards of the four steps.
        r0 = 1 + entropy_reg * 0.2
        r1 = 0.5 + entropy_reg * 0.3
        r2 = 2 + entropy_reg * 0.5
        r3 = -0.2 - entropy_reg * 0.2

        # Bootstrap value used after the last step.
        v = 0.01 if value_function is not None else 0

        expected = r0 + r1 * gamma + r2 * gamma**2 + r3 * gamma**3 + v * gamma**4
        reward = mc_return(
            stack_list_of_tuples(trajectory, -2),
            gamma,
            value_function=value_function,
            entropy_regularization=entropy_reg,
            reduction="min",
        )
        torch.testing.assert_allclose(reward, torch.tensor([expected]))

        empty_return = mc_return(
            Observation(state=0, reward=0).to_torch(),
            gamma,
            value_function,
            entropy_reg,
        )
        assert (empty_return == 0)
Esempio n. 21
0
def init_er_from_er(target_er, source_er):
    """Initialize an Experience Replay from an Experience Replay.

    Every position ``0 .. len(source_er) - 1`` of the source is read and its
    observation is appended to the target; the returned index and weight are
    discarded.

    Parameters
    ----------
    target_er: Experience Replay
        Experience replay to be filled.
    source_er: Experience Replay
        Experience replay to be used.
    """
    for position in range(len(source_er)):
        fields, _, _ = source_er[position]
        target_er.append(Observation(**fields))
Esempio n. 22
0
    def _init_observation(self, observation):
        """Create ``self.zero_observation`` matching the given observation.

        A zero-dimensional state (action) is described with dimension 1 and
        one state (action); otherwise the size of the last axis is used as
        the dimension and the number of states (actions) is -1.
        """
        state, action = observation.state, observation.action

        if state.ndim == 0:
            state_spec = (1, 1)
        else:
            state_spec = (state.shape[-1], -1)
        if action.ndim == 0:
            action_spec = (1, 1)
        else:
            action_spec = (action.shape[-1], -1)

        dim_state, num_states = state_spec
        dim_action, num_actions = action_spec
        self.zero_observation = Observation.zero_example(
            dim_state=dim_state,
            dim_action=dim_action,
            num_states=num_states,
            num_actions=num_actions,
        )
Esempio n. 23
0
    def test_iter(self, discrete, max_len, num_steps):
        """Check the items produced when iterating over an experience replay."""
        num_episodes = 3
        episode_length = 200
        memory = create_er_from_episodes(
            discrete, max_len, num_steps, num_episodes, episode_length
        )
        expected_len = max(1, num_steps)

        for position, item in enumerate(memory):
            observation, returned_idx, weight = item
            # Positions past the end of the memory are not checked.
            if position >= len(memory):
                continue

            if memory.valid[position] == 1:
                assert position == returned_idx
            else:
                assert position != returned_idx

            assert weight == 1.0
            for attribute in Observation(**observation):
                assert attribute.shape[0] == expected_len
Esempio n. 24
0
def init_er_from_environment(target_er, environment):
    """Initialize an Experience Replay from the states of an environment.

    One observation is appended per state index
    ``0 .. environment.num_states - 1``; its action is an empty tensor.
    The environment must have discrete states.

    Parameters
    ----------
    target_er: Experience Replay.
        Experience replay to be filled.
    environment: Environment.
        Discrete environment.
    """
    assert environment.num_states is not None

    for state_index in range(environment.num_states):
        empty_action = torch.empty(0)
        target_er.append(Observation(state_index, empty_action).to_torch())
Esempio n. 25
0
def collect_model_transitions(state_dist, policy, dynamical_model,
                              reward_model, num_samples):
    """Collect transitions by interacting with a model.

    States are sampled in one batch, actions are sampled from the policy (or
    directly from the given action distribution), and the next states and
    rewards are sampled from the distributions predicted by the models.

    Parameters
    ----------
    state_dist: Distribution.
        State distribution.
    policy: AbstractPolicy or Distribution.
        Policy to interact with the model.
    dynamical_model: AbstractModel.
        Model with which to interact.
    reward_model: AbstractReward.
        Reward model with which to interact.
    num_samples: int.
        Number of transitions.

    Returns
    -------
    transitions: List[Observation]
        List of 1-step transitions.

    """
    batch_shape = (num_samples, )
    state = state_dist.sample(batch_shape)
    if isinstance(policy, AbstractPolicy):
        policy_out = policy(state)
        action = tensor_to_distribution(policy_out, **policy.dist_params).sample()
    else:  # `policy` already is an action distribution.
        action = policy.sample(batch_shape)

    next_state_dist = tensor_to_distribution(dynamical_model(state, action))
    next_state = next_state_dist.sample()
    reward_dist = tensor_to_distribution(reward_model(state, action, next_state))
    reward = reward_dist.sample()

    return [
        Observation(state_, action_, reward_, next_state_).to_torch()
        for state_, action_, reward_, next_state_ in zip(
            state, action, reward, next_state
        )
    ]
Esempio n. 26
0
def create_er_from_transitions(discrete, dim_state, dim_action, max_len,
                               num_steps, num_transitions):
    """Create a memory filled with `num_transitions` random transitions."""
    if discrete:
        num_states, num_actions = dim_state, dim_action
        dim_state, dim_action = (), ()
    else:
        num_states, num_actions = -1, -1
        dim_state, dim_action = (dim_state, ), (dim_action, )

    example_kwargs = dict(
        dim_state=dim_state,
        dim_action=dim_action,
        num_states=num_states,
        num_actions=num_actions,
    )
    memory = ExperienceReplay(max_len, num_steps=num_steps)
    for _ in range(num_transitions):
        memory.append(Observation.random_example(**example_kwargs))
    return memory
Esempio n. 27
0
def get_primal_reps(eta, value_function, transitions, rewards, gamma, prior=None):
    """Get Primal for REPS.

    Returns a distribution of shape ``rewards.shape`` whose logits are the
    exact TD errors scaled by `eta`, weighted by `prior` when one is given.
    """
    num_states, num_actions = rewards.shape
    td = torch.zeros(num_states, num_actions)
    states = torch.arange(num_states)
    for action_idx in range(num_actions):
        action = torch.tensor(action_idx)
        observation = Observation(state=states, action=action)
        td[:, action] = _compute_exact_td(
            value_function, observation, transitions, rewards, gamma,
        )

    max_td = torch.max(eta * td)
    if prior is None:
        logits = eta * td
    else:
        logits = torch.log(prior * torch.exp(eta * td - max_td)) + max_td
    mu_dist = torch.distributions.Categorical(logits=logits.reshape(-1))
    return mu_dist.probs.reshape(num_states, num_actions)
Esempio n. 28
0
def dataset(request):
    """Build a trajectory dataset of random episodes from `request.param`.

    The first six entries of `request.param` are, in order: number of
    episodes, episode length, state dimension, action dimension, batch size
    and sequence length.  Returns the dataset followed by those six values.
    """
    (
        num_episodes,
        episode_length,
        state_dim,
        action_dim,
        batch_size,
        sequence_length,
    ) = (request.param[i] for i in range(6))

    transforms = [
        MeanFunction(lambda state, action: state),
        StateNormalizer(),
        ActionNormalizer(),
    ]
    trajectories = TrajectoryDataset(
        sequence_length=sequence_length, transformations=transforms
    )

    last_step = episode_length - 1
    for _ in range(num_episodes):
        trajectory = [
            Observation(
                state=np.random.randn(state_dim),
                action=np.random.randn(action_dim),
                reward=np.random.randn(),
                next_state=np.random.randn(state_dim),
                done=step == last_step,  # Only the last step is terminal.
            ).to_torch()
            for step in range(episode_length)
        ]
        trajectories.append(trajectory)

    return (
        trajectories,
        num_episodes,
        episode_length,
        state_dim,
        action_dim,
        batch_size,
        sequence_length,
    )
Esempio n. 29
0
def collect_environment_transitions(state_dist, policy, environment,
                                    num_samples):
    """Collect transitions by interacting with an environment.

    For every sample a state is drawn, the environment state is set to it and
    a single step is taken with an action sampled from the policy.

    Parameters
    ----------
    state_dist: Distribution.
        State distribution.
    policy: AbstractPolicy or Distribution.
        Policy to interact with the environment.
    environment: AbstractEnvironment.
        Environment with which to interact.
    num_samples: int.
        Number of transitions.

    Returns
    -------
    transitions: List[Observation]
        List of 1-step transitions.

    """
    use_policy = isinstance(policy, AbstractPolicy)
    transitions = []
    for _ in range(num_samples):
        sampled_state = state_dist.sample()
        if use_policy:
            action_dist = tensor_to_distribution(
                policy(sampled_state), **policy.dist_params
            )
        else:  # `policy` already is an action distribution.
            action_dist = policy

        state = sampled_state.numpy()
        action = action_dist.sample().numpy()
        environment.state = state
        next_state, reward, _, _ = environment.step(action)
        transitions.append(
            Observation(state, action, reward, next_state).to_torch()
        )

    return transitions
Esempio n. 30
0
    def test_append(self, discrete, dim_state, dim_action, max_len, num_steps):
        """Check the validity flags and stored object after an append."""
        num_transitions = 200
        memory = create_er_from_transitions(
            discrete, dim_state, dim_action, max_len, num_steps, num_transitions
        )
        if discrete:
            num_states, num_actions = dim_state, dim_action
            dim_state, dim_action = (), ()
        else:
            num_states, num_actions = -1, -1
            dim_state, dim_action = (dim_state, ), (dim_action, )
        observation = Observation.random_example(
            dim_state=dim_state,
            dim_action=dim_action,
            num_states=num_states,
            num_actions=num_actions,
        )

        memory.append(observation)
        ptr = memory.ptr
        # The two most recent positions are valid ...
        for steps_back in (1, 2):
            assert memory.valid[(ptr - steps_back) % max_len] == 1
        # ... and the next `num_steps` positions are not.
        for steps_ahead in range(num_steps):
            assert memory.valid[(ptr + steps_ahead) % max_len] == 0
        # The memory does not store the very same object that was appended.
        assert memory.memory[(ptr - 1) % max_len] is not observation