Example #1
    def step(self, observation, prev_action, prev_reward):
        prev_action = self.distribution.to_onehot(prev_action)
        model_inputs = buffer_to((observation, prev_action, prev_reward),
                                 device=self.device)

        pi, value = self.model(*model_inputs)
        dist_info = DistInfo(prob=pi)

        if self.dual_model:
            # Intrinsic head is only queried in dual-model mode.
            int_pi, int_value = self.model_int(*model_inputs)
            dist_int_info = DistInfo(prob=int_pi)
            if self._mode == "eval":
                action = self.distribution.sample(dist_info)
            else:
                action = self.distribution.sample(dist_int_info)
        else:
            action = self.distribution.sample(dist_info)

        if self.dual_model:
            agent_info = AgentInfoTwin(dist_info=dist_info,
                                       value=value,
                                       dist_int_info=dist_int_info,
                                       int_value=int_value)
        else:
            agent_info = AgentInfo(dist_info=dist_info, value=value)

        action, agent_info = buffer_to((action, agent_info), device="cpu")
        return AgentStep(action=action, agent_info=agent_info)
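A quick aside on the recurring buffer_to pattern: it moves an arbitrary (possibly nested) tuple of tensors to the agent's device before the model call, and the outputs back to CPU for the sampler. Below is a minimal standalone sketch of that round trip; the shapes and the stand-in model outputs are dummy values assumed for illustration, not part of the example above.

import torch
from rlpyt.utils.buffer import buffer_to

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
observation = torch.zeros(4, 3, 84, 84)   # [B,C,H,W], dummy frames
prev_action = torch.zeros(4, 6)           # already one-hot, [B,A]
prev_reward = torch.zeros(4)              # [B]
# Move all leading inputs to the agent's device in one call:
model_inputs = buffer_to((observation, prev_action, prev_reward), device=device)
# ... the model forward pass would go here; stand-in outputs instead:
pi = torch.softmax(torch.randn(4, 6, device=device), dim=-1)
value = torch.zeros(4, device=device)
# Return outputs to the CPU for the sampler:
pi, value = buffer_to((pi, value), device="cpu")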
Example #2
 def __call__(self,
              observation,
              prev_action,
              prev_reward,
              sampled_option,
              device="cpu"):
     prev_action = self.distribution.to_onehot(prev_action)
     model_inputs = buffer_to(
         (observation, prev_action, prev_reward, sampled_option),
         device=self.device)
     pi, beta, q, pi_omega = self.model(*model_inputs[:-1])
     return buffer_to(
         (DistInfo(prob=select_at_indexes(sampled_option, pi)), q, beta,
          DistInfo(prob=pi_omega)),
         device=device)
Example #3
 def __call__(self, observation, prev_action, prev_reward, init_rnn_state):
     # Assume init_rnn_state already shaped: [N,B,H]
     model_inputs = buffer_to((observation, prev_action, prev_reward,
                               init_rnn_state), device=self.device)
     pi, value, next_rnn_state = self.model(*model_inputs)
     dist_info, value = buffer_to((DistInfo(prob=pi), value), device="cpu")
     return dist_info, value, next_rnn_state  # Leave rnn_state on device.
Example #4
 def step(self, observation, prev_action, prev_reward, device="cpu"):
     """
     Compute policy's option and action distributions from inputs.
     Calls model to get mean, std for all pi_w, q, beta for all options, pi over options
     Moves inputs to device and returns outputs back to CPU, for the
     sampler.  (no grad)
     """
     model_inputs = buffer_to((observation, prev_action, prev_reward),
                              device=self.device)
     mu, log_std, beta, q, pi = self.model(*model_inputs)
     dist_info_omega = DistInfo(prob=pi)
     new_o, terminations = self.sample_option(
         beta, dist_info_omega)  # Sample terminations and options
     dist_info = DistInfoStd(mean=mu, log_std=log_std)
     dist_info_o = DistInfoStd(mean=select_at_indexes(new_o, mu),
                               log_std=select_at_indexes(new_o, log_std))
     action = self.distribution.sample(dist_info_o)
     agent_info = AgentInfoOC(dist_info=dist_info,
                              dist_info_o=dist_info_o,
                              q=q,
                              value=(pi * q).sum(-1),
                              termination=terminations,
                              dist_info_omega=dist_info_omega,
                              prev_o=self._prev_option,
                              o=new_o)
     action, agent_info = buffer_to((action, agent_info), device=device)
     self.advance_oc_state(new_o)
     return AgentStep(action=action, agent_info=agent_info)
Example #5
 def step(self, observation, prev_action, prev_reward):
     prev_action = self.distribution.to_onehot(prev_action)
     agent_inputs = buffer_to((observation, prev_action, prev_reward),
                              device=self.device)
     probs, value, rnn_state = self.model(*agent_inputs,
                                          self.prev_rnn_state)
     dist_info = DistInfo(prob=probs)
     if self._mode == 'sample':
         action = self.distribution.sample(dist_info)
     elif self._mode == 'eval':
         action = torch.argmax(probs, dim=-1)
     # Model handles None, but Buffer does not, make zeros if needed:
     if self.prev_rnn_state is None:
         prev_rnn_state = buffer_func(rnn_state, torch.zeros_like)
     else:
         prev_rnn_state = self.prev_rnn_state
     # Transpose the rnn_state from [N,B,H] --> [B,N,H] for storage.
     # (Special case: model should always leave B dimension in.)
     prev_rnn_state = buffer_method(prev_rnn_state, "transpose", 0, 1)
     agent_info = AgentInfoRnn(dist_info=dist_info,
                               value=value,
                               prev_rnn_state=prev_rnn_state)
     action, agent_info = buffer_to((action, agent_info), device="cpu")
     self.advance_rnn_state(rnn_state)  # Keep on device.
     return AgentStep(action=action, agent_info=agent_info)
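The [N,B,H] -> [B,N,H] transpose above reflects the storage convention: the samples buffer keeps the batch dimension first, while the model keeps the layer-first layout expected by cuDNN. A minimal sketch of that reshaping on a dummy LSTM state; the field names and sizes here are assumptions for illustration.

import torch
from rlpyt.utils.buffer import buffer_method
from rlpyt.utils.collections import namedarraytuple

RnnState = namedarraytuple("RnnState", ["h", "c"])  # assumed LSTM state fields
N, B, H = 1, 4, 256  # layers, batch, hidden size (dummy values)
rnn_state = RnnState(h=torch.zeros(N, B, H), c=torch.zeros(N, B, H))
# Put the batch dimension first for storage in the samples buffer:
stored = buffer_method(rnn_state, "transpose", 0, 1)
print(stored.h.shape)  # torch.Size([4, 1, 256])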
Example #6
 def __call__(self, observation, prev_action, prev_reward):
     prev_action = self.format_actions(prev_action)
     model_inputs = buffer_to((observation, prev_action, prev_reward),
                              device=self.device)
     pi, ext_value, int_value = self.model(*model_inputs)
     return buffer_to((DistInfo(prob=pi), ext_value, int_value),
                      device="cpu")
Example #7
    def step(self, observation, prev_action, prev_reward):
        prev_action = self.distribution.to_onehot(prev_action)
        agent_inputs = buffer_to((observation, prev_action, prev_reward),
                                 device=self.device)

        pi, value, rnn_state = self.model(*agent_inputs, self.prev_rnn_state)
        dist_info = DistInfo(prob=pi)

        if self.dual_model:
            int_pi, int_value, int_rnn_state = self.model_int(
                *agent_inputs, self.prev_int_rnn_state)
            dist_int_info = DistInfo(prob=int_pi)
            if self._mode == "eval":
                action = self.distribution.sample(dist_info)
            else:
                action = self.distribution.sample(dist_int_info)
        else:
            action = self.distribution.sample(dist_info)

        # Model handles None, but Buffer does not, make zeros if needed:
        prev_rnn_state = self.prev_rnn_state or buffer_func(
            rnn_state, torch.zeros_like)
        # Transpose the rnn_state from [N,B,H] --> [B,N,H] for storage.
        # (Special case: model should always leave B dimension in.)
        prev_rnn_state = buffer_method(prev_rnn_state, "transpose", 0, 1)

        if self.dual_model:
            prev_int_rnn_state = self.prev_int_rnn_state or buffer_func(
                int_rnn_state, torch.zeros_like)
            prev_int_rnn_state = buffer_method(prev_int_rnn_state, "transpose",
                                               0, 1)
            agent_info = AgentInfoRnnTwin(
                dist_info=dist_info,
                value=value,
                prev_rnn_state=prev_rnn_state,
                dist_int_info=dist_int_info,
                int_value=int_value,
                prev_int_rnn_state=prev_int_rnn_state)
        else:
            agent_info = AgentInfoRnn(dist_info=dist_info,
                                      value=value,
                                      prev_rnn_state=prev_rnn_state)
        action, agent_info = buffer_to((action, agent_info), device="cpu")
        self.advance_rnn_state(rnn_state)  # Keep on device.
        if self.dual_model:
            self.advance_int_rnn_state(int_rnn_state)
        return AgentStep(action=action, agent_info=agent_info)
Example #8
    def inverse_loss(self, samples):
        observation = samples.observation[0]  # [T,B,C,H,W]->[B,C,H,W]
        last_observation = samples.observation[-1]

        if self.random_shift_prob > 0.:
            observation = random_shift(
                imgs=observation,
                pad=self.random_shift_pad,
                prob=self.random_shift_prob,
            )
            last_observation = random_shift(
                imgs=last_observation,
                pad=self.random_shift_pad,
                prob=self.random_shift_prob,
            )

        action = samples.action  # [T,B,A]
        # if self.onehot_actions:
        #     action = to_onehot(action, self._act_dim, dtype=torch.float)
        observation, last_observation, action = buffer_to(
            (observation, last_observation, action), device=self.device)

        _, conv_obs = self.encoder(observation)
        _, conv_last = self.encoder(last_observation)

        valid = valid_from_done(samples.done).type(torch.bool)  # [T,B]
        # All timesteps are invalid if the last_observation is invalid:
        valid = valid[-1].repeat(self.delta_T, 1).transpose(1, 0)  # [B,T-1]

        if self.onehot_actions:
            logits = self.inverse_model(conv_obs, conv_last)  # [B,T-1,A]
            labels = action[:-1].transpose(1,
                                           0)  # [B,T-1], not the last action
            labels[~valid] = IGNORE_INDEX

            b, t, a = logits.shape
            logits = logits.view(b * t, a)
            labels = labels.reshape(b * t)
            logits = logits - torch.max(logits, dim=1, keepdim=True)[0]
            inv_loss = self.c_e_loss(logits, labels)

            valid = valid.reshape(b * t).to(self.device)
            dist_info = DistInfo(prob=F.softmax(logits, dim=1))
            entropy = self.distribution.mean_entropy(
                dist_info=dist_info,
                valid=valid,
            )
            entropy_loss = -self.entropy_loss_coeff * entropy

            correct = torch.argmax(logits.detach(), dim=1) == labels
            accuracy = torch.mean(correct[valid].float())

        else:
            raise NotImplementedError

        perplexity = self.distribution.mean_perplexity(dist_info,
                                                       valid.to(self.device))

        return inv_loss, entropy_loss, accuracy, perplexity, conv_obs
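One detail of the inverse loss worth spelling out: invalid timesteps are masked by writing IGNORE_INDEX into the labels, so the cross-entropy skips them entirely. The standalone sketch below assumes IGNORE_INDEX is -100, the default ignore_index of cross_entropy; the actual value used by self.c_e_loss is not shown in the example.

import torch
import torch.nn.functional as F

IGNORE_INDEX = -100  # assumed to match the ignore_index of self.c_e_loss
logits = torch.randn(6, 4)             # [B*T, A], dummy values
labels = torch.randint(0, 4, (6,))     # [B*T]
valid = torch.tensor([1, 1, 0, 1, 0, 1], dtype=torch.bool)
labels[~valid] = IGNORE_INDEX          # masked positions are ignored by the loss
loss = F.cross_entropy(logits, labels)  # averages over valid labels only
accuracy = (logits.argmax(dim=1) == labels)[valid].float().mean()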
Example #9
 def __call__(self, observation, prev_action, prev_reward, init_rnn_state):
     prev_action = self.distribution.to_onehot(prev_action)
     model_inputs = buffer_to(
         (observation, prev_action, prev_reward, init_rnn_state), device=self.device
     )
     pi, value, next_rnn_state, _ = self.model(*model_inputs)  # Ignore conv out
     dist_info, value = buffer_to((DistInfo(prob=pi), value), device="cpu")
     return dist_info, value, next_rnn_state
Example #10
 def step(self, observation, prev_action, prev_reward):
     model_inputs = buffer_to((observation, prev_action, prev_reward),
         device=self.device)
     action, action_probs, log_action_probs = self.model(*model_inputs)
     dist_info = DistInfo(prob=action_probs)
     agent_info = AgentInfo(dist_info=dist_info)
     action, agent_info = buffer_to((action, agent_info), device="cpu")
     return AgentStep(action=action, agent_info=agent_info)
Example #11
 def pi(self, observation, prev_action, prev_reward):
     """Compute action log-probabilities for state/observation, and
     sample new action (with grad).  Uses special ``sample_loglikelihood()``
     method of Gaussian distriution, which handles action squashing
     through this process."""
     model_inputs = buffer_to((observation, prev_action, prev_reward),
         device=self.device)
     action, action_probs, log_action_probs = self.model(*model_inputs)
     dist_info = DistInfo(prob=action_probs)
     action_probs, log_pi, dist_info = buffer_to((action_probs, log_action_probs, dist_info), device="cpu")
     return action, action_probs, log_pi, dist_info  # Action stays on device for q models.
Example #12
 def step(self, observation, prev_action, prev_reward):
     prev_action = self.format_actions(prev_action)
     model_inputs = buffer_to((observation, prev_action, prev_reward),
                              device=self.device)
     pi, ext_value, int_value = self.model(*model_inputs)
     dist_info = DistInfo(prob=pi)
     action = self.distribution.sample(dist_info)
     agent_info = IntAgentInfo(dist_info=dist_info,
                               ext_value=ext_value,
                               int_value=int_value)
     action, agent_info = buffer_to((action, agent_info), device="cpu")
     return AgentStep(action=action, agent_info=agent_info)
Example #13
 def step(self, observation, prev_action, prev_reward, device="cpu"):
     prev_action = self.distribution.to_onehot(prev_action)
     model_inputs = buffer_to((observation, prev_action, prev_reward),
                              device=self.device)
     pi, beta, q, pi_omega = self.model(*model_inputs)
     dist_info_omega = DistInfo(prob=pi_omega)
     new_o, terminations = self.sample_option(
         beta, dist_info_omega)  # Sample terminations and options
     dist_info = DistInfo(prob=pi)
     dist_info_o = DistInfo(prob=select_at_indexes(new_o, pi))
     action = self.distribution.sample(dist_info_o)
     agent_info = AgentInfoOC(dist_info=dist_info,
                              dist_info_o=dist_info_o,
                              q=q,
                              value=(pi_omega * q).sum(-1),
                              termination=terminations,
                              dist_info_omega=dist_info_omega,
                              prev_o=self._prev_option,
                              o=new_o)
     action, agent_info = buffer_to((action, agent_info), device=device)
     self.advance_oc_state(new_o)
     return AgentStep(action=action, agent_info=agent_info)
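Several of the option-critic examples rely on select_at_indexes to pull out, for each batch element, the action distribution of the option that was just sampled. A tiny sketch of its behavior with assumed shapes:

import torch
from rlpyt.utils.tensor import select_at_indexes

B, O, A = 3, 4, 6  # batch, options, actions (dummy sizes)
pi = torch.softmax(torch.randn(B, O, A), dim=-1)  # per-option action policies
new_o = torch.tensor([0, 2, 1])                   # sampled option per batch row
pi_o = select_at_indexes(new_o, pi)               # [B, A] option-conditional policy
print(pi_o.shape)  # torch.Size([3, 6])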
Example #14
 def step(self, observation, prev_action, prev_reward):
     prev_action = self.distribution.to_onehot(prev_action)
     model_inputs = buffer_to((observation, prev_action, prev_reward),
         device=self.device)
     pi, value, conv = self.model(*model_inputs)
     if self._act_uniform:
         pi[:] = 1. / pi.shape[-1]  # uniform
     dist_info = DistInfo(prob=pi)
     action = self.distribution.sample(dist_info)
     agent_info = AgentInfoConv(dist_info=dist_info, value=value,
         conv=conv if self.store_latent else None)  # Don't write extra data.
     action, agent_info = buffer_to((action, agent_info), device="cpu")
     return AgentStep(action=action, agent_info=agent_info)
Example #15
 def step(self, observation, prev_action=None, prev_reward=None):
     pi, value, sym_features = self.model(
         observation.to(device=self.device), extract_sym_features=True)
     dist_info = DistInfo(prob=pi)
     action = self.distribution.sample(dist_info)
     # sym_features should either always be given or never.
     if sym_features is not None:
         agent_info = SafeAgentInfo(dist_info=dist_info,
                                    value=value,
                                    sym_features=sym_features)
     else:
         agent_info = AgentInfo(dist_info=dist_info, value=value)
     action, agent_info = buffer_to((action, agent_info), device="cpu")
     return AgentStep(action=action, agent_info=agent_info)
Example #16
 def step(self, observation, prev_action, prev_reward):
     prev_action = self.distribution.to_onehot(prev_action)
     observation = observation.type(torch.float)  # Expect torch.uint8 inputs.
     observation = observation.mul_(1. / 255)  # From [0-255] to [0-1], in place.
     model_inputs = buffer_to((observation, prev_action, prev_reward),
                              device=self.device)
     pi, value = self.model(*model_inputs)
     dist_info = DistInfo(prob=pi)
     action = self.distribution.sample(dist_info)
     agent_info = AgentInfo(dist_info=dist_info, value=value)
     action, agent_info = buffer_to((action, agent_info), device="cpu")
     return AgentStep(action=action, agent_info=agent_info)
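Example #16 additionally normalizes raw uint8 frames before the device transfer. A standalone sketch of just that cast-and-scale step, with an assumed dummy observation shape:

import torch

observation = torch.randint(0, 256, (4, 3, 84, 84), dtype=torch.uint8)
observation = observation.type(torch.float)  # uint8 -> float32
observation = observation.mul_(1. / 255)     # [0, 255] -> [0, 1], in place
print(observation.min().item(), observation.max().item())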
Example #17
 def step(self, observation, prev_action, prev_reward, device="cpu"):
     prev_option_input = self._prev_option
     if prev_option_input is None:  # Hack to extract previous option
         prev_option_input = torch.full_like(prev_action, -1)
     prev_action = self.distribution.to_onehot(prev_action)
     prev_option_input = self.distribution_omega.to_onehot_with_invalid(
         prev_option_input)
     model_inputs = buffer_to(
         (observation, prev_action, prev_reward, prev_option_input),
         device=self.device)
     pi, beta, q, pi_omega, rnn_state = self.model(*model_inputs,
                                                   self.prev_rnn_state)
     dist_info_omega = DistInfo(prob=pi_omega)
     new_o, terminations = self.sample_option(
         beta, dist_info_omega)  # Sample terminations and options
     # Model handles None, but Buffer does not, make zeros if needed:
     prev_rnn_state = self.prev_rnn_state if self.prev_rnn_state is not None else buffer_func(
         rnn_state, torch.zeros_like)
     # Transpose the rnn_state from [N,B,H] --> [B,N,H] for storage.
     # (Special case: model should always leave B dimension in.)
     prev_rnn_state = buffer_method(prev_rnn_state, "transpose", 0, 1)
     dist_info = DistInfo(prob=pi)
     dist_info_o = DistInfo(prob=select_at_indexes(new_o, pi))
     action = self.distribution.sample(dist_info_o)
     agent_info = AgentInfoOCRnn(dist_info=dist_info,
                                 dist_info_o=dist_info_o,
                                 q=q,
                                 value=(pi_omega * q).sum(-1),
                                 termination=terminations,
                                 dist_info_omega=dist_info_omega,
                                 prev_o=self._prev_option,
                                 o=new_o,
                                 prev_rnn_state=prev_rnn_state)
     action, agent_info = buffer_to((action, agent_info), device=device)
     self.advance_oc_state(new_o)
     self.advance_rnn_state(rnn_state)
     return AgentStep(action=action, agent_info=agent_info)
Example #18
 def __call__(self,
              observation,
              prev_action,
              prev_reward,
              sampled_option,
              device="cpu"):
     """Performs forward pass on training data, for algorithm. Returns sampled distinfo, q, beta, and piomega distinfo"""
     model_inputs = buffer_to(
         (observation, prev_action, prev_reward, sampled_option),
         device=self.device)
     mu, log_std, beta, q, pi = self.model(*model_inputs[:-1])
     # Need gradients from intra-option (DistInfoStd), q_o (q), termination (beta), and pi_omega (DistInfo)
     return buffer_to(
         (DistInfoStd(mean=select_at_indexes(sampled_option, mu),
                      log_std=select_at_indexes(sampled_option, log_std)),
          q, beta, DistInfo(prob=pi)),
         device=device)
Example #19
    def step(self, observation, prev_action, prev_reward):
        model_inputs = buffer_to((observation, prev_action, prev_reward),
                                 device=self.device)
        # mean, log_std = self.model(*model_inputs)
        # dist_info = DistInfoStd(mean=mean, log_std=log_std)
        # action = self.distribution.sample(dist_info)
        if self.random_actions_for_pretraining:
            action = torch.randint_like(prev_action, 15)
            action = buffer_to(action, device="cpu")
            return AgentStep(action=action,
                             agent_info=AgentInfo(dist_info=None))

        pi, _, _ = self.model(*model_inputs)
        dist_info = DistInfo(prob=pi)
        action = self.distribution.sample(dist_info)
        agent_info = AgentInfo(dist_info=dist_info)
        action, agent_info = buffer_to((action, agent_info), device="cpu")
        return AgentStep(action=action, agent_info=agent_info)
Example #20
 def __call__(self,
              observation,
              prev_action,
              prev_reward,
              sampled_option,
              init_rnn_state,
              device="cpu"):
     """Performs forward pass on training data, for algorithm (requires
     recurrent state input). Returnssampled distinfo, q, beta, and piomega distinfo"""
     # Assume init_rnn_state already shaped: [N,B,H]
     model_inputs = buffer_to((observation, prev_action, prev_reward,
                               init_rnn_state, sampled_option),
                              device=self.device)
     mu, log_std, beta, q, pi, next_rnn_state = self.model(
         *model_inputs[:-1])
     # Need gradients from intra-option (DistInfoStd), q_o (q), termination (beta), and pi_omega (DistInfo)
     dist_info, q, beta, dist_info_omega = buffer_to(
         (DistInfoStd(mean=select_at_indexes(sampled_option, mu),
                      log_std=select_at_indexes(sampled_option, log_std)),
          q, beta, DistInfo(prob=pi)),
         device=device)
     return dist_info, q, beta, dist_info_omega, next_rnn_state  # Leave rnn_state on device.
Example #21
 def step(self, observation, prev_action, prev_reward, device="cpu"):
     """
     Compute policy's action distribution from inputs, and sample an
     action. Calls the model to produce mean, log_std, value estimate, and
     next recurrent state.  Moves inputs to device and returns outputs back
     to CPU, for the sampler.  Advances the recurrent state of the agent.
     (no grad)
     """
     agent_inputs = buffer_to((observation, prev_action, prev_reward),
                              device=self.device)
     mu, log_std, beta, q, pi, rnn_state = self.model(
         *agent_inputs, self.prev_rnn_state)
     terminations = torch.bernoulli(beta).bool()  # Sample terminations
     dist_info_omega = DistInfo(prob=pi)
     new_o = self.sample_option(terminations, dist_info_omega)
     dist_info = DistInfoStd(mean=mu, log_std=log_std)
     dist_info_o = DistInfoStd(mean=select_at_indexes(new_o, mu),
                               log_std=select_at_indexes(new_o, log_std))
     action = self.distribution.sample(dist_info_o)
     # Model handles None, but Buffer does not, make zeros if needed:
     prev_rnn_state = self.prev_rnn_state if self.prev_rnn_state is not None else buffer_func(
         rnn_state, torch.zeros_like)
     # Transpose the rnn_state from [N,B,H] --> [B,N,H] for storage.
     # (Special case: model should always leave B dimension in.)
     prev_rnn_state = buffer_method(prev_rnn_state, "transpose", 0, 1)
     agent_info = AgentInfoOCRnn(dist_info=dist_info,
                                 dist_info_o=dist_info_o,
                                 q=q,
                                 value=(pi * q).sum(-1),
                                 termination=terminations,
                                 inter_option_dist_info=dist_info_omega,
                                 prev_o=self._prev_option,
                                 o=new_o,
                                 prev_rnn_state=prev_rnn_state)
     action, agent_info = buffer_to((action, agent_info), device=device)
     self.advance_rnn_state(rnn_state)  # Keep on device.
     self.advance_oc_state(new_o)
     return AgentStep(action=action, agent_info=agent_info)
Example #22
 def step(self, observation, prev_action, prev_reward):
     prev_action = self.distribution.to_onehot(prev_action)
     model_inputs = buffer_to(
         (observation, prev_action, prev_reward), device=self.device
     )
     pi, value, rnn_state, conv = self.model(*model_inputs, self.prev_rnn_state)
     if self._act_uniform:
         pi[:] = 1.0 / pi.shape[-1]  # uniform
     dist_info = DistInfo(prob=pi)
     action = self.distribution.sample(dist_info)
     # Model handles None, but Buffer does not, make zeros if needed:
     prev_rnn_state = self.prev_rnn_state or buffer_func(rnn_state, torch.zeros_like)
     # Transpose the rnn_state from [N,B,H] --> [B,N,H] for storage.
     # (Special case: model should always leave B dimension in.)
     prev_rnn_state = buffer_method(prev_rnn_state, "transpose", 0, 1)
     agent_info = AgentInfoRnnConv(
         dist_info=dist_info,
         value=value,
         prev_rnn_state=prev_rnn_state,
         conv=conv if self.store_latent else None,
     )  # Don't write the extra data.
     action, agent_info = buffer_to((action, agent_info), device="cpu")
     self.advance_rnn_state(rnn_state)
     return AgentStep(action=action, agent_info=agent_info)
Example #23
 def __call__(self, observation, prev_action=None, prev_reward=None):
     pi, value, _ = self.model(observation.to(device=self.device))
     return buffer_to((DistInfo(prob=pi), value), device="cpu")
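All of the step() methods above package their outputs into namedarraytuple containers so the sampler can index and slice them like tensors. A minimal construction sketch, assuming the standard rlpyt policy-gradient AgentInfo fields (dist_info, value) and dummy tensors:

import torch
from rlpyt.agents.base import AgentStep
from rlpyt.agents.pg.base import AgentInfo
from rlpyt.distributions.categorical import DistInfo

pi = torch.softmax(torch.randn(4, 6), dim=-1)  # [B, A] action probabilities
agent_info = AgentInfo(dist_info=DistInfo(prob=pi), value=torch.zeros(4))
step = AgentStep(action=torch.zeros(4, dtype=torch.long), agent_info=agent_info)
print(step.action.shape, step.agent_info.value.shape)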