def scale(self, state, action):
    """Get epistemic scale of model."""
    none = torch.tensor(0)
    obs = Observation(state, action, none, none, none, none, none, none, none, none)
    for transformation in self.forward_transformations:
        obs = transformation(obs)

    # Predict next-state
    scale = self.base_model.scale(obs.state, obs.action)

    # Back-transform
    obs = Observation(
        obs.state,
        obs.action,
        reward=none,
        done=none,
        next_action=none,
        log_prob_action=none,
        entropy=none,
        state_scale_tril=none,
        next_state=obs.state,
        next_state_scale_tril=scale,
    )
    for transformation in self.reverse_transformations:
        obs = transformation.inverse(obs)

    return obs.next_state_scale_tril
def test_get_item(self, discrete, max_len, num_steps):
    num_episodes = 3
    episode_length = 200
    memory = create_er_from_episodes(
        discrete, max_len, num_steps, num_episodes, episode_length
    )
    memory.end_episode()

    observation, idx, weight = memory[0]
    for attribute in Observation(**observation):
        assert attribute.shape[0] == max(1, num_steps)
    assert idx == 0
    assert weight == 1.0

    for i in range(len(memory)):
        observation, idx, weight = memory[i]
        for attribute in Observation(**observation):
            assert attribute.shape[0] == max(1, num_steps)
        if memory.valid[i]:
            assert idx == i
        else:
            assert idx != i
        assert weight == 1.0

    i = np.random.choice(memory.valid_indexes).item()
    observation, idx, weight = memory[i]
    for attribute in Observation(**observation):
        assert attribute.shape[0] == max(1, num_steps)
    assert idx == i
    assert weight == 1.0
def test_equality(self):
    state, action, reward, next_state, done = self.init()
    o = Observation(state, action, reward, next_state, done).to_torch()
    for x, y in zip(
        o, Observation(state, action, reward, next_state, done).to_torch()
    ):
        torch.testing.assert_allclose(x, y)
    assert o is not Observation(state, action, reward, next_state, done)
def _train_model_step(model, observation, optimizer, mask, logger):
    if not isinstance(observation, Observation):
        observation = Observation(**observation)
    observation.action = observation.action[..., : model.dim_action[0]]

    if isinstance(model, EnsembleModel):
        loss = train_ensemble_step(model, observation, optimizer, mask)
    elif isinstance(model, NNModel):
        loss = train_nn_step(model, observation, optimizer)
    elif isinstance(model, ExactGPModel):
        loss = train_exact_gp_type2mll_step(model, observation, optimizer)
    else:
        raise TypeError("Only Implemented for Ensembles and GP Models.")
    logger.update(**{f"{model.model_kind[:3]}-loss": loss.item()})
def test_example(self, discrete, dim_state, dim_action, kind):
    if discrete:
        num_states, num_actions = dim_state, dim_action
        dim_state, dim_action = (), ()
    else:
        num_states, num_actions = -1, -1
        dim_state, dim_action = (dim_state,), (dim_action,)

    if kind == "nan":
        o = Observation.nan_example(
            dim_state=dim_state,
            dim_action=dim_action,
            num_states=num_states,
            num_actions=num_actions,
        )
    elif kind == "zero":
        o = Observation.zero_example(
            dim_state=dim_state,
            dim_action=dim_action,
            num_states=num_states,
            num_actions=num_actions,
        )
    elif kind == "random":
        o = Observation.random_example(
            dim_state=dim_state,
            dim_action=dim_action,
            num_states=num_states,
            num_actions=num_actions,
        )
    else:
        with pytest.raises(ValueError):
            Observation.get_example(
                dim_state=dim_state,
                dim_action=dim_action,
                num_states=num_states,
                num_actions=num_actions,
                kind=kind,
            )
        return

    if discrete:
        torch.testing.assert_allclose(o.state.shape, torch.Size([]))
        torch.testing.assert_allclose(o.action.shape, torch.Size([]))
        torch.testing.assert_allclose(o.next_state.shape, torch.Size([]))
        torch.testing.assert_allclose(o.log_prob_action, torch.tensor(1.0))
    else:
        torch.testing.assert_allclose(o.state.shape, torch.Size(dim_state))
        torch.testing.assert_allclose(o.action.shape, torch.Size(dim_action))
        torch.testing.assert_allclose(o.next_state.shape, torch.Size(dim_state))
        torch.testing.assert_allclose(o.log_prob_action, torch.tensor(1.0))
def predict(self, state, action, next_state=None):
    """Get next_state distribution."""
    none = torch.tensor(0)
    if next_state is None:
        next_state = none
    obs = Observation(
        state, action, none, next_state, none, none, none, none, none, none
    )
    for transformation in self.forward_transformations:
        obs = transformation(obs)

    # Predict next-state
    if self.model_kind == "dynamics":
        reward, done = (none, none), none
        next_state = self.base_model(obs.state, obs.action, obs.next_state)
    elif self.model_kind == "rewards":
        reward = self.base_model(obs.state, obs.action, obs.next_state)
        next_state, done = (none, none), none
    elif self.model_kind == "termination":
        done = self.base_model(obs.state, obs.action, obs.next_state)
        next_state, reward = (none, none), (none, none)
    else:
        raise ValueError(f"{self.model_kind} not in {self.allowed_model_kind}")

    # Back-transform
    obs = Observation(
        obs.state,
        obs.action,
        reward=reward[0],
        done=done,
        next_action=none,
        log_prob_action=none,
        entropy=none,
        state_scale_tril=none,
        next_state=next_state[0],
        next_state_scale_tril=next_state[1],
        reward_scale_tril=reward[1],
    )
    for transformation in self.reverse_transformations:
        obs = transformation.inverse(obs)

    if self.model_kind == "dynamics":
        return obs.next_state, obs.next_state_scale_tril
    elif self.model_kind == "rewards":
        return obs.reward, obs.reward_scale_tril
    elif self.model_kind == "termination":
        return obs.done
def __getitem__(self, idx):
    """Return any desired sub-trajectory.

    Parameters
    ----------
    idx: int

    Returns
    -------
    sub-trajectory: Observation
    """
    if self.sequence_length is None:  # Get full trajectory.
        observation = self._trajectories[idx]
    else:  # Get sub-trajectory.
        trajectory_idx, start = self._sub_trajectory_indexes[idx]
        end = start + self._sequence_length
        trajectory = self._trajectories[trajectory_idx]
        observation = Observation(
            **{key: val[start:end] for key, val in asdict(trajectory).items()}
        )
    for transform in self.transformations:
        observation = transform(observation)
    return observation
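# Usage sketch for __getitem__ above. The import paths are assumptions about the
# surrounding package layout, not taken from this section: build a dataset with a
# fixed sequence length and fetch one sub-trajectory by index.
import numpy as np

from rllib.dataset import TrajectoryDataset
from rllib.dataset.datatypes import Observation

dataset = TrajectoryDataset(sequence_length=4)
trajectory = [
    Observation(
        state=np.random.randn(3),
        action=np.random.randn(2),
        reward=np.random.randn(),
        next_state=np.random.randn(3),
        done=i == 9,
    ).to_torch()
    for i in range(10)
]
dataset.append(trajectory)
sub_trajectory = dataset[0]  # Observation whose tensors have a leading dimension of 4.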
def init_er_from_rollout(target_er, agent, environment, max_steps=1000):
    """Initialize an Experience Replay from agent rollouts in an environment.

    Roll out the agent and append the observed transitions until the
    experience replay is full.

    Parameters
    ----------
    target_er: Experience Replay.
        Experience replay to be filled.
    agent: Agent.
        Agent to act in environment.
    environment: Environment.
        Environment to roll out.
    max_steps: int.
        Maximum number of steps before truncating an episode.
    """
    while not target_er.is_full:
        state = environment.reset()
        done = False
        while not done:
            action = agent.act(state)
            next_state, reward, done, _ = environment.step(action)
            observation = Observation(
                state=state,
                action=action,
                reward=reward,
                next_state=next_state,
                done=done,
            ).to_torch()
            state = next_state
            target_er.append(observation)
            if max_steps <= environment.time:
                break
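# Usage sketch for init_er_from_rollout. The import path is an assumption, and
# `agent` / `environment` are placeholders for any agent-environment pair;
# nothing here is prescribed by the function beyond its signature.
from rllib.dataset.experience_replay import ExperienceReplay

memory = ExperienceReplay(max_len=10000, num_steps=0)
init_er_from_rollout(memory, agent, environment, max_steps=1000)
assert memory.is_full  # The rollout loop only stops once the buffer is full.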
def _validate_model_step(model, observation, logger):
    if not isinstance(observation, Observation):
        observation = Observation(**observation)
    observation.action = observation.action[..., : model.dim_action[0]]

    mse = model_mse(model, observation).item()
    sharpness_ = sharpness(model, observation).item()
    calibration_score_ = calibration_score(model, observation).item()

    logger.update(
        **{
            f"{model.model_kind[:3]}-val-mse": mse,
            f"{model.model_kind[:3]}-sharp": sharpness_,
            f"{model.model_kind[:3]}-calib": calibration_score_,
        }
    )
    return mse
def step_env(environment, state, action, action_scale, pi=None, render=False):
    """Perform a single step in an environment."""
    try:
        next_state, reward, done, info = environment.step(action)
    except TypeError:
        next_state, reward, done, info = environment.step(action.item())

    if not isinstance(action, torch.Tensor):
        action = torch.tensor(action, dtype=torch.get_default_dtype())

    if pi is not None:
        try:
            with torch.no_grad():
                entropy, log_prob_action = get_entropy_and_log_p(
                    pi, action, action_scale
                )
        except RuntimeError:
            entropy, log_prob_action = 0.0, 1.0
    else:
        entropy, log_prob_action = 0.0, 1.0

    observation = Observation(
        state=state,
        action=action,
        reward=reward,
        next_state=next_state,
        done=done,
        entropy=entropy,
        log_prob_action=log_prob_action,
    ).to_torch()
    state = next_state

    if render:
        environment.render()

    return observation, state, done, info
def get_primal_q_reps(
    prior, eta, value_function, q_function, transitions, rewards, gamma
):
    """Get Primal for Q-REPS."""
    num_states, num_actions = prior.shape
    td = torch.zeros(num_states, num_actions)
    states = torch.arange(num_states)
    for action in range(num_actions):
        action = torch.tensor(action)
        observation = Observation(state=states, action=action)
        td[:, action] = _compute_exact_td(
            value_function, observation, transitions, rewards, gamma
        )
    td = td + value_function(states).unsqueeze(-1) - q_function(states)

    max_td = torch.max(eta * td)
    mu_log = torch.log(prior * torch.exp(eta * td - max_td)) + max_td
    mu_dist = torch.distributions.Categorical(logits=mu_log.reshape(-1))
    mu = mu_dist.probs.reshape(num_states, num_actions)

    adv = q_function(states) - value_function(states).unsqueeze(-1)
    max_adv = torch.max(eta * adv)
    d_log = torch.log(prior * torch.exp(eta * adv - max_adv)) + max_adv
    d_dist = torch.distributions.Categorical(logits=d_log.reshape(-1))
    d = d_dist.probs.reshape(num_states, num_actions)

    return mu, d
def closure():
    """Gradient calculation."""
    states = Observation(
        state=self.sim_dataset.sample_batch(self.policy_opt_batch_size)
    )
    self.optimizer.zero_grad()
    losses = self.algorithm(states)
    losses.combined_loss.backward()
    return losses
def test_append_error():
    dataset = TrajectoryDataset(sequence_length=10)
    trajectory = [
        Observation(
            np.random.randn(4), np.random.randn(2), 1, np.random.randn(4), True
        ).to_torch()
    ]
    with pytest.raises(ValueError):
        dataset.append(trajectory)
def sample_batch(self, batch_size):
    """Sample a batch of observations."""
    indices = np.random.choice(self.valid_indexes, batch_size)
    if self.num_steps == 0:
        obs = self._get_observation(indices)
        return obs, torch.tensor(indices), self.weights[indices]
    else:
        obs, idx, weight = default_collate([self[i] for i in indices])
        return Observation(**obs), idx, weight
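# Usage sketch for sample_batch (import paths assumed): fill a small replay
# buffer with random transitions and draw a batch. With num_steps == 0 the
# method returns the stacked observation directly; otherwise the sampled items
# are collated through default_collate, as in the else-branch above.
from rllib.dataset.datatypes import Observation
from rllib.dataset.experience_replay import ExperienceReplay

memory = ExperienceReplay(max_len=100, num_steps=0)
for _ in range(20):
    memory.append(
        Observation.random_example(
            dim_state=(3,), dim_action=(2,), num_states=-1, num_actions=-1
        )
    )
observations, indexes, weights = memory.sample_batch(batch_size=8)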
def get_observation(reward=None):
    return Observation(
        state=torch.randn(4),
        action=torch.randn(4),
        reward=reward if reward is not None else torch.randn(1),
        next_state=torch.randn(4),
        done=False,
        state_scale_tril=torch.randn(4, 4),
        next_state_scale_tril=torch.randn(4, 4),
    ).to_torch()
def forward(self, state):
    """Return policy logits."""
    td = compute_exact_td(
        value_function=self.value_function,
        observation=Observation(state=state),
        transitions=self.transitions,
        rewards=self.rewards,
        gamma=self.gamma,
        support="state",
    )
    return self.eta * td * self.counter
def test_clone(self, discrete, dim_state, dim_action):
    if discrete:
        num_states, num_actions = dim_state, dim_action
        dim_state, dim_action = (), ()
    else:
        num_states, num_actions = -1, -1
        dim_state, dim_action = (dim_state,), (dim_action,)

    o = Observation.random_example(
        dim_state=dim_state,
        dim_action=dim_action,
        num_states=num_states,
        num_actions=num_actions,
    )
    o1 = o.clone()

    assert o is not o1
    assert o == o1
    for x, x1 in zip(o, o1):
        assert Observation._is_equal_nan(x, x1)
        assert x is not x1
def get_trajectory():
    t = []
    for reward in [3.0, -2.0, 0.5]:
        t.append(
            Observation(
                state=torch.randn(4),
                action=torch.randn(2),
                reward=reward,
                next_state=torch.randn(4),
                done=False,
            )
        )
    return t
def _get_consecutive_observations(self, start_idx, num_steps):
    if num_steps == 0 and not isinstance(start_idx, (int, np.integer)):
        observation = stack_list_of_tuples(self.memory[start_idx])
        return Observation(*map(lambda x: x.unsqueeze(1), observation))

    num_steps = max(1, num_steps)
    if start_idx + num_steps <= self.max_len:
        obs_list = self.memory[start_idx : start_idx + num_steps]
    else:  # The trajectory is split by the circular buffer.
        delta_idx = start_idx + num_steps - self.max_len
        obs_list = np.concatenate(
            (self.memory[start_idx : self.max_len], self.memory[:delta_idx])
        )
    return stack_list_of_tuples(obs_list)
def test_correctness(self, gamma, value_function, entropy_reg):
    trajectory = [
        Observation(0, 0, reward=1, done=False, entropy=0.2).to_torch(),
        Observation(0, 0, reward=0.5, done=False, entropy=0.3).to_torch(),
        Observation(0, 0, reward=2, done=False, entropy=0.5).to_torch(),
        Observation(0, 0, reward=-0.2, done=False, entropy=-0.2).to_torch(),
    ]
    r0 = 1 + entropy_reg * 0.2
    r1 = 0.5 + entropy_reg * 0.3
    r2 = 2 + entropy_reg * 0.5
    r3 = -0.2 - entropy_reg * 0.2
    v = 0.01 if value_function is not None else 0

    reward = mc_return(
        stack_list_of_tuples(trajectory, -2),
        gamma,
        value_function=value_function,
        entropy_regularization=entropy_reg,
        reduction="min",
    )
    torch.testing.assert_allclose(
        reward,
        torch.tensor(
            [r0 + r1 * gamma + r2 * gamma ** 2 + r3 * gamma ** 3 + v * gamma ** 4]
        ),
    )

    assert (
        mc_return(
            Observation(state=0, reward=0).to_torch(),
            gamma,
            value_function,
            entropy_reg,
        )
        == 0
    )
def init_er_from_er(target_er, source_er):
    """Initialize an Experience Replay from another Experience Replay.

    Copy all the transitions in the source ER to the target ER.

    Parameters
    ----------
    target_er: Experience Replay.
        Experience replay to be filled.
    source_er: Experience Replay.
        Experience replay to be copied.
    """
    for i in range(len(source_er)):
        observation, idx, weight = source_er[i]
        target_er.append(Observation(**observation))
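# Usage sketch for init_er_from_er (import paths assumed): warm-start a larger
# replay buffer with the contents of a smaller, already populated one.
from rllib.dataset.datatypes import Observation
from rllib.dataset.experience_replay import ExperienceReplay

source = ExperienceReplay(max_len=100, num_steps=0)
for _ in range(10):
    source.append(
        Observation.random_example(
            dim_state=(4,), dim_action=(2,), num_states=-1, num_actions=-1
        )
    )
target = ExperienceReplay(max_len=1000, num_steps=0)
init_er_from_er(target, source)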
def _init_observation(self, observation):
    if observation.state.ndim == 0:
        dim_state, num_states = 1, 1
    else:
        dim_state, num_states = observation.state.shape[-1], -1
    if observation.action.ndim == 0:
        dim_action, num_actions = 1, 1
    else:
        dim_action, num_actions = observation.action.shape[-1], -1

    self.zero_observation = Observation.zero_example(
        dim_state=dim_state,
        dim_action=dim_action,
        num_states=num_states,
        num_actions=num_actions,
    )
def test_iter(self, discrete, max_len, num_steps):
    num_episodes = 3
    episode_length = 200
    memory = create_er_from_episodes(
        discrete, max_len, num_steps, num_episodes, episode_length
    )
    for idx, (observation, idx_, weight) in enumerate(memory):
        if idx >= len(memory):
            continue
        if memory.valid[idx] == 1:
            assert idx == idx_
        else:
            assert idx != idx_
        assert weight == 1.0
        for attribute in Observation(**observation):
            assert attribute.shape[0] == max(1, num_steps)
def init_er_from_environment(target_er, environment):
    """Initialize an Experience Replay from an environment.

    Initialize an observation per state in the environment.
    The environment must have discrete states.

    Parameters
    ----------
    target_er: Experience Replay.
        Experience replay to be filled.
    environment: Environment.
        Discrete environment.
    """
    assert environment.num_states is not None
    for state in range(environment.num_states):
        observation = Observation(state, torch.empty(0)).to_torch()
        target_er.append(observation)
def collect_model_transitions(
    state_dist, policy, dynamical_model, reward_model, num_samples
):
    """Collect transitions by interacting with the dynamical and reward models.

    Parameters
    ----------
    state_dist: Distribution.
        State distribution.
    policy: AbstractPolicy or Distribution.
        Policy with which to select actions.
    dynamical_model: AbstractModel.
        Model with which to interact.
    reward_model: AbstractReward.
        Reward model with which to interact.
    num_samples: int.
        Number of transitions.

    Returns
    -------
    transitions: List[Observation]
        List of 1-step transitions.
    """
    state = state_dist.sample((num_samples,))
    if isinstance(policy, AbstractPolicy):
        action_dist = tensor_to_distribution(policy(state), **policy.dist_params)
        action = action_dist.sample()
    else:  # action distribution
        action_dist = policy
        action = action_dist.sample((num_samples,))

    next_state = tensor_to_distribution(dynamical_model(state, action)).sample()
    reward = tensor_to_distribution(reward_model(state, action, next_state)).sample()

    transitions = []
    for state_, action_, reward_, next_state_ in zip(
        state, action, reward, next_state
    ):
        transitions.append(
            Observation(state_, action_, reward_, next_state_).to_torch()
        )
    return transitions
def create_er_from_transitions(
    discrete, dim_state, dim_action, max_len, num_steps, num_transitions
):
    """Create a memory with `num_transitions` transitions."""
    if discrete:
        num_states, num_actions = dim_state, dim_action
        dim_state, dim_action = (), ()
    else:
        num_states, num_actions = -1, -1
        dim_state, dim_action = (dim_state,), (dim_action,)

    memory = ExperienceReplay(max_len, num_steps=num_steps)
    for _ in range(num_transitions):
        observation = Observation.random_example(
            dim_state=dim_state,
            dim_action=dim_action,
            num_states=num_states,
            num_actions=num_actions,
        )
        memory.append(observation)
    return memory
def get_primal_reps(eta, value_function, transitions, rewards, gamma, prior=None):
    """Get Primal for REPS."""
    num_states, num_actions = rewards.shape
    td = torch.zeros(num_states, num_actions)
    states = torch.arange(num_states)
    for action in range(num_actions):
        action = torch.tensor(action)
        observation = Observation(state=states, action=action)
        td[:, action] = _compute_exact_td(
            value_function, observation, transitions, rewards, gamma
        )

    max_td = torch.max(eta * td)
    if prior is not None:
        mu_log = torch.log(prior * torch.exp(eta * td - max_td)) + max_td
        mu_dist = torch.distributions.Categorical(logits=mu_log.reshape(-1))
    else:
        mu_dist = torch.distributions.Categorical(logits=(eta * td).reshape(-1))
    mu = mu_dist.probs.reshape(num_states, num_actions)
    return mu
def dataset(request):
    num_episodes = request.param[0]
    episode_length = request.param[1]
    state_dim = request.param[2]
    action_dim = request.param[3]
    batch_size = request.param[4]
    sequence_length = request.param[5]

    transforms = [
        MeanFunction(lambda state, action: state),
        StateNormalizer(),
        ActionNormalizer(),
    ]
    dataset = TrajectoryDataset(
        sequence_length=sequence_length, transformations=transforms
    )
    for _ in range(num_episodes):
        trajectory = []
        for i in range(episode_length):
            trajectory.append(
                Observation(
                    state=np.random.randn(state_dim),
                    action=np.random.randn(action_dim),
                    reward=np.random.randn(),
                    next_state=np.random.randn(state_dim),
                    done=i == (episode_length - 1),
                ).to_torch()
            )
        dataset.append(trajectory)

    return (
        dataset,
        num_episodes,
        episode_length,
        state_dim,
        action_dim,
        batch_size,
        sequence_length,
    )
def collect_environment_transitions(state_dist, policy, environment, num_samples):
    """Collect transitions by interacting with an environment.

    Parameters
    ----------
    state_dist: Distribution.
        State distribution.
    policy: AbstractPolicy or Distribution.
        Policy to interact with the environment.
    environment: AbstractEnvironment.
        Environment with which to interact.
    num_samples: int.
        Number of transitions.

    Returns
    -------
    transitions: List[Observation]
        List of 1-step transitions.
    """
    transitions = []
    for _ in range(num_samples):
        state = state_dist.sample()
        if isinstance(policy, AbstractPolicy):
            action_dist = tensor_to_distribution(policy(state), **policy.dist_params)
        else:  # random action distribution
            action_dist = policy
        action = action_dist.sample()

        state = state.numpy()
        action = action.numpy()
        environment.state = state
        next_state, reward, done, _ = environment.step(action)
        transitions.append(Observation(state, action, reward, next_state).to_torch())

    return transitions
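# Usage sketch for collect_environment_transitions. Both distributions below are
# plain torch distributions, so the non-AbstractPolicy branch above is used;
# `environment` is a placeholder for an environment whose state can be set
# directly (as done with `environment.state = state` above).
import torch

state_dist = torch.distributions.Uniform(-torch.ones(3), torch.ones(3))
action_dist = torch.distributions.Uniform(-torch.ones(1), torch.ones(1))
transitions = collect_environment_transitions(
    state_dist, action_dist, environment, num_samples=16
)
# `transitions` is a list of 16 one-step Observations.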
def test_append(self, discrete, dim_state, dim_action, max_len, num_steps):
    num_transitions = 200
    memory = create_er_from_transitions(
        discrete, dim_state, dim_action, max_len, num_steps, num_transitions
    )
    if discrete:
        num_states, num_actions = dim_state, dim_action
        dim_state, dim_action = (), ()
    else:
        num_states, num_actions = -1, -1
        dim_state, dim_action = (dim_state,), (dim_action,)

    observation = Observation.random_example(
        dim_state=dim_state,
        dim_action=dim_action,
        num_states=num_states,
        num_actions=num_actions,
    )
    memory.append(observation)

    assert memory.valid[(memory.ptr - 1) % max_len] == 1
    assert memory.valid[(memory.ptr - 2) % max_len] == 1
    for i in range(num_steps):
        assert memory.valid[(memory.ptr + i) % max_len] == 0
    assert memory.memory[(memory.ptr - 1) % max_len] is not observation