Example #1
 def step(self, observation, prev_action, prev_reward, device="cpu"):
     """
     Compute policy's option and action distributions from inputs.
     Calls model to get mean, std for all pi_w, q, beta for all options, pi over options
     Moves inputs to device and returns outputs back to CPU, for the
     sampler.  (no grad)
     """
     model_inputs = buffer_to((observation, prev_action, prev_reward),
                              device=self.device)
     mu, log_std, beta, q, pi = self.model(*model_inputs)
     dist_info_omega = DistInfo(prob=pi)
     new_o, terminations = self.sample_option(
         beta, dist_info_omega)  # Sample terminations and options
     dist_info = DistInfoStd(mean=mu, log_std=log_std)
     dist_info_o = DistInfoStd(mean=select_at_indexes(new_o, mu),
                               log_std=select_at_indexes(new_o, log_std))
     action = self.distribution.sample(dist_info_o)
     agent_info = AgentInfoOC(dist_info=dist_info,
                              dist_info_o=dist_info_o,
                              q=q,
                              value=(pi * q).sum(-1),
                              termination=terminations,
                              dist_info_omega=dist_info_omega,
                              prev_o=self._prev_option,
                              o=new_o)
     action, agent_info = buffer_to((action, agent_info), device=device)
     self.advance_oc_state(new_o)
     return AgentStep(action=action, agent_info=agent_info)
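
Two details in this step are easy to miss: the per-option indexing done by select_at_indexes(new_o, mu), and the option-value bootstrap value=(pi * q).sum(-1). Below is a minimal stand-in, assuming select_at_indexes gathers along the option dimension for each batch element; shapes and names are illustrative only, not rlpyt's implementation.

import torch

B, n_options, act_dim = 5, 3, 2
mu = torch.randn(B, n_options, act_dim)
q = torch.randn(B, n_options)
pi = torch.softmax(torch.randn(B, n_options), dim=-1)
new_o = torch.randint(n_options, (B,))
mu_o = mu[torch.arange(B), new_o]   # what select_at_indexes(new_o, mu) yields here, shape [B, act_dim]
value = (pi * q).sum(-1)            # expected option value, as stored in agent_info above
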
Example #2
 def beta_dist_infos(self, observation, prev_action, prev_reward,
                     init_rnn_state):
     model_inputs = buffer_to(
         (observation, prev_action, prev_reward, init_rnn_state),
         device=self.device)
     r_mu, r_log_std, _, _ = self.beta_r_model(*model_inputs)
     c_mu, c_log_std, _, _ = self.beta_c_model(*model_inputs)
     return buffer_to((DistInfoStd(mean=r_mu, log_std=r_log_std),
                       DistInfoStd(mean=c_mu, log_std=c_log_std)),
                      device="cpu")
Example #3
    def sample(self, dist_info):
        logits, delta_dist_info = dist_info.cat_dist, dist_info.delta_dist
        u = torch.rand_like(logits)
        u = torch.clamp(u, 1e-5, 1 - 1e-5)
        gumbel = -torch.log(-torch.log(u))
        prob = F.softmax((logits + gumbel) / 10, dim=-1)  # Gumbel-softmax, temperature 10.

        cat_sample = torch.argmax(prob, dim=-1)
        one_hot = to_onehot(cat_sample, 4, dtype=torch.float32)

        if len(prob.shape) == 1: # Edge case for when it gets buffer shapes
            cat_sample = cat_sample.unsqueeze(0)

        if self._all_corners:
            mu, log_std = delta_dist_info.mean, delta_dist_info.log_std
            mu, log_std = mu.view(-1, 4, 3), log_std.view(-1, 4, 3)
            mu = select_at_indexes(cat_sample, mu)
            log_std = select_at_indexes(cat_sample, log_std)

            if len(prob.shape) == 1: # Edge case for when it gets buffer shapes
                mu, log_std = mu.squeeze(0), log_std.squeeze(0)

            new_dist_info = DistInfoStd(mean=mu, log_std=log_std)
        else:
            new_dist_info = delta_dist_info

        if self.training:
            self.delta_distribution.set_std(None)
        else:
            self.delta_distribution.set_std(0)
        delta_sample = self.delta_distribution.sample(new_dist_info)
        return torch.cat((one_hot, delta_sample), dim=-1)
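
The sampler above relies on the Gumbel-max trick: adding Gumbel noise to the logits and taking the argmax draws from Categorical(softmax(logits)). Dividing by 10 only rescales the argmax input, so the hard sample is unchanged; the temperature matters for the soft prob used downstream. A small self-contained check of the property, with illustrative values not taken from the source:

import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.tensor([1.0, 0.0, -1.0, 0.5])
u = torch.clamp(torch.rand(100000, 4), 1e-5, 1 - 1e-5)
gumbel = -torch.log(-torch.log(u))
samples = torch.argmax(logits + gumbel, dim=-1)
empirical = torch.bincount(samples, minlength=4).float() / samples.numel()
print(empirical)                  # ~ [0.47, 0.17, 0.06, 0.29]
print(F.softmax(logits, dim=-1))  # matches closely
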
Example #4
 def step(self, observation, prev_action, prev_reward, device="cpu"):
     """
     Compute policy's action distribution from inputs, and sample an
     action. Calls the model to produce mean, log_std, value estimate, and
     next recurrent state.  Moves inputs to device and returns outputs back
     to CPU, for the sampler.  Advances the recurrent state of the agent.
     (no grad)
     """
     agent_inputs = buffer_to((observation, prev_action, prev_reward),
                              device=self.device)
     mu, log_std, value, rnn_state = self.model(*agent_inputs,
                                                self.prev_rnn_state)
     dist_info = DistInfoStd(mean=mu, log_std=log_std)
     action = self.distribution.sample(dist_info)
     # Model handles None, but Buffer does not, make zeros if needed:
     prev_rnn_state = self.prev_rnn_state if self.prev_rnn_state is not None else buffer_func(
         rnn_state, torch.zeros_like)
     # Transpose the rnn_state from [N,B,H] --> [B,N,H] for storage.
     # (Special case: model should always leave B dimension in.)
     prev_rnn_state = buffer_method(prev_rnn_state, "transpose", 0, 1)
     agent_info = AgentInfoRnn(dist_info=dist_info,
                               value=value,
                               prev_rnn_state=prev_rnn_state)
     action, agent_info = buffer_to((action, agent_info), device=device)
     self.advance_rnn_state(rnn_state)  # Keep on device.
     return AgentStep(action=action, agent_info=agent_info)
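
A note on the transpose: PyTorch recurrent modules keep hidden state as [N,B,H] (layers, batch, hidden), while the sampler's buffers are indexed batch-first, hence [B,N,H] for storage. A minimal sketch of that reshaping, assuming the state is a namedtuple of tensors (as with an LSTM's (h, c)); RnnState here is a hypothetical stand-in:

import torch
from collections import namedtuple

RnnState = namedtuple("RnnState", ["h", "c"])   # hypothetical stand-in for the model's state
N, B, H = 2, 4, 8                               # layers, batch, hidden
state = RnnState(h=torch.zeros(N, B, H), c=torch.zeros(N, B, H))
stored = RnnState(*(s.transpose(0, 1) for s in state))  # roughly the effect of buffer_method(..., "transpose", 0, 1)
assert stored.h.shape == (B, N, H)
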
Example #5
 def __call__(self, observation, prev_action, prev_reward, device='cpu'):
     """Performs forward pass on training data, for algorithm."""
     model_inputs = buffer_to((observation, prev_action, prev_reward),
                              device=self.device)
     mu, log_std, value = self.model(*model_inputs)
     return buffer_to((DistInfoStd(mean=mu, log_std=log_std), value),
                      device=device)
Example #6
 def __call__(self, observation, prev_action, prev_reward, init_rnn_state):
     # Assume init_rnn_state already shaped: [N,B,H]
     model_inputs = buffer_to((observation, prev_action, prev_reward,
         init_rnn_state), device=self.device)
     mu, log_std, value, next_rnn_state = self.model(*model_inputs)
     dist_info, value = buffer_to(
         (DistInfoStd(mean=mu, log_std=log_std), value), device="cpu")
     return dist_info, value, next_rnn_state  # Leave rnn_state on device.
Example #7
    def __call__(self, observation, prev_action, prev_reward):
        model_inputs = buffer_to((observation, prev_action, prev_reward),
                                 device=self.device)
        mu, log_std, value = self.model(*model_inputs)

        samples = (DistInfoStd(mean=mu, log_std=log_std), value)
        return buffer_to(samples, device="cpu")
Example #8
    def sample_loglikelihood(self, dist_info):
        logits, delta_dist_info = dist_info.cat_dist, dist_info.delta_dist

        u = torch.rand_like(logits)
        u = torch.clamp(u, 1e-5, 1 - 1e-5)
        gumbel = -torch.log(-torch.log(u))
        prob = F.softmax((logits + gumbel) / 10, dim=-1)  # Gumbel-softmax, temperature 10.

        cat_sample = torch.argmax(prob, dim=-1)
        cat_loglikelihood = select_at_indexes(cat_sample, prob)

        one_hot = to_onehot(cat_sample, 4, dtype=torch.float32)
        one_hot = (one_hot - prob).detach() + prob # Make action differentiable through prob

        if self._all_corners:
            mu, log_std = delta_dist_info.mean, delta_dist_info.log_std
            mu, log_std = mu.view(-1, 4, 3), log_std.view(-1, 4, 3)
            mu = mu[torch.arange(len(cat_sample)), cat_sample.squeeze(-1)]
            log_std = log_std[torch.arange(len(cat_sample)), cat_sample.squeeze(-1)]
            new_dist_info = DistInfoStd(mean=mu, log_std=log_std)
        else:
            new_dist_info = delta_dist_info

        delta_sample, delta_loglikelihood = self.delta_distribution.sample_loglikelihood(new_dist_info)
        action = torch.cat((one_hot, delta_sample), dim=-1)
        log_likelihood = cat_loglikelihood + delta_loglikelihood
        return action, log_likelihood
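
The line (one_hot - prob).detach() + prob is a straight-through estimator: the forward value is the hard one-hot sample, but gradients flow back through the soft Gumbel-softmax probabilities. A minimal check, with shapes and temperature chosen to mirror the code above:

import torch
import torch.nn.functional as F

logits = torch.randn(8, 4, requires_grad=True)
u = torch.clamp(torch.rand_like(logits), 1e-5, 1 - 1e-5)
gumbel = -torch.log(-torch.log(u))
prob = F.softmax((logits + gumbel) / 10, dim=-1)
hard = F.one_hot(prob.argmax(dim=-1), num_classes=4).float()
sample = (hard - prob).detach() + prob   # forward: hard one-hot; backward: grads via prob
sample.sum().backward()
assert logits.grad is not None           # gradients reach the logits
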
Example #9
 def pi(self, observation, prev_action, prev_reward):
     model_inputs = buffer_to((observation, prev_action, prev_reward),
                              device=self.device)
     mean, log_std = self.model(*model_inputs)
     dist_info = DistInfoStd(mean=mean, log_std=log_std)
     action, log_pi = self.distribution.sample_loglikelihood(dist_info)
     log_pi, dist_info = buffer_to((log_pi, dist_info), device="cpu")
     return action, log_pi, dist_info  # Action stays on device for q models.
Example #10
 def step(self, observation, prev_action, prev_reward):
     model_inputs = buffer_to((observation, prev_action, prev_reward),
                              device=self.device)
     mu, log_std, value = self.model(*model_inputs)
     dist_info = DistInfoStd(mean=mu, log_std=log_std)
     action = self.distribution.sample(dist_info)
     agent_info = AgentInfo(dist_info=dist_info, value=value)
     action, agent_info = buffer_to((action, agent_info), device="cpu")
     return AgentStep(action=action, agent_info=agent_info)
Example #11
 def __call__(self, observation, prev_action, prev_reward, init_rnn_state):
     """Performs forward pass on training data, for algorithm (requires
     recurrent state input)."""
     # Assume init_rnn_state already shaped: [N,B,H]
     model_inputs = buffer_to(
         (observation, prev_action, prev_reward, init_rnn_state),
         device=self.device)
     mu, log_std, value, next_rnn_state = self.model(*model_inputs)
     dist_info, value = buffer_to(
         (DistInfoStd(mean=mu, log_std=log_std), value), device="cpu")
     return dist_info, value, next_rnn_state  # Leave rnn_state on device.
Example #12
 def pi(self, observation, prev_action, prev_reward):
     """Compute action log-probabilities for state/observation, and
     sample new action (with grad).  Uses special ``sample_loglikelihood()``
     method of Gaussian distribution, which handles action squashing
     through this process."""
     model_inputs = buffer_to(observation, device=self.device)
     mean, log_std, _ = self.model(model_inputs, "pi")
     dist_info = DistInfoStd(mean=mean, log_std=log_std)
     action, log_pi = self.distribution.sample_loglikelihood(dist_info)
     log_pi, dist_info = buffer_to((log_pi, dist_info), device="cpu")
     return action, log_pi, dist_info  # Action stays on device for q models.
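
The "action squashing" mentioned in the docstring refers to the tanh-squashed Gaussian used in SAC-style agents. A hedged sketch of what sample_loglikelihood() is assumed to compute, with the standard change-of-variables correction (the library version may differ in epsilon and reparameterization details):

import torch

mean = torch.zeros(4, 2)
log_std = torch.zeros(4, 2)
std = log_std.exp()
u = mean + std * torch.randn_like(mean)          # pre-squash (reparameterized) sample
action = torch.tanh(u)                           # squashed action in (-1, 1)
normal = torch.distributions.Normal(mean, std)
# Correction term for the tanh squashing:
log_pi = normal.log_prob(u).sum(-1) - torch.log(1 - action.pow(2) + 1e-6).sum(-1)
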
Example #13
 def step(self, observation, prev_action, prev_reward):
     model_inputs = buffer_to(observation, device=self.device)
     mean, log_std, sym_features = self.model(model_inputs,
                                              "pi",
                                              extract_sym_features=True)
     dist_info = DistInfoStd(mean=mean, log_std=log_std)
     action = self.distribution.sample(dist_info)
     agent_info = SafeSacAgentInfo(dist_info=dist_info,
                                   sym_features=sym_features)
     action, agent_info = buffer_to((action, agent_info), device="cpu")
     return AgentStep(action=action, agent_info=agent_info)
Example #14
    def forward(self, observation, prev_action, prev_reward):
        if isinstance(observation, tuple):
            observation = torch.cat(observation, dim=-1)

        lead_dim, T, B, _ = infer_leading_dims(observation,
            self._obs_ndim)
        output = self.mlp(observation.view(T * B, -1))
        logits = output[:, :4]
        mu, log_std = output[:, 4:4 + self._delta_dim], output[:, 4 + self._delta_dim:]
        logits, mu, log_std = restore_leading_dims((logits, mu, log_std), lead_dim, T, B)
        return GumbelDistInfo(cat_dist=logits, delta_dist=DistInfoStd(mean=mu, log_std=log_std))
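
This forward pass flattens any leading time/batch dimensions before the MLP and restores them afterwards, which is what the infer_leading_dims / restore_leading_dims helpers take care of. A hand-rolled equivalent for the common [T,B,...] case, with a dummy linear layer standing in for self.mlp and _delta_dim assumed to be 3 for illustration:

import torch

obs = torch.randn(7, 3, 16)                 # [T, B, obs_dim]
T, B = obs.shape[:2]
flat = obs.view(T * B, -1)                  # MLP sees a flat batch
out = flat @ torch.randn(16, 4 + 2 * 3)     # stand-in for self.mlp: 4 logits + mu/log_std
logits, mu, log_std = out[:, :4], out[:, 4:7], out[:, 7:]
logits, mu, log_std = (x.view(T, B, -1) for x in (logits, mu, log_std))   # restore [T, B, ...]
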
Example #15
 def step(self, observation, prev_action, prev_reward):
     model_inputs = buffer_to((observation, prev_action, prev_reward),
                              device=self.device)
     mean, log_std = self.pi_model(*model_inputs)
     dist_info = DistInfoStd(mean=mean, log_std=log_std)
     action = self.distribution.sample(dist_info)
     agent_info = AgentInfo(dist_info=dist_info)
     action, agent_info = buffer_to((action, agent_info), device="cpu")
     if np.any(np.isnan(action.numpy())):  # Debug guard: halt if any action is NaN.
         breakpoint()
     return AgentStep(action=action, agent_info=agent_info)
Example #16
 def step(self, observation, prev_action, prev_reward):
     observation, prev_action, prev_reward = buffer_to(
         (observation, prev_action, prev_reward), device=self.device)
     # self.model includes encoder + actor MLP.
     mean, log_std, latent, conv = self.model(observation, prev_action,
                                              prev_reward)
     dist_info = DistInfoStd(mean=mean, log_std=log_std)
     action = self.distribution.sample(dist_info)
     agent_info = AgentInfo(dist_info=dist_info,
                            conv=conv if self.store_latent else None)
     action, agent_info = buffer_to((action, agent_info), device="cpu")
     return AgentStep(action=action, agent_info=agent_info)
Example #17
 def step(self, observation, prev_action, prev_reward):
     model_inputs = buffer_to((observation, ), device=self.device)[0]
     mu, log_std, value, sym_features = self.model(
         model_inputs, extract_sym_features=True)
     dist_info = DistInfoStd(mean=mu, log_std=log_std)
     action = self.distribution.sample(dist_info)
     action = action.clamp(-1, 1)
     agent_info = SafeAgentInfo(dist_info=dist_info,
                                value=value,
                                sym_features=sym_features)
     action, agent_info = buffer_to((action, agent_info), device="cpu")
     return AgentStep(action=action, agent_info=agent_info)
Example #18
    def step(self, observation, prev_action, prev_reward):
        model_inputs = buffer_to((observation, prev_action, prev_reward),
                                 device=self.device)

        actions, means, log_stds = [], [], []
        self.model.start()
        while self.model.has_next():
            mean, log_std = self.model.next(actions, *model_inputs)
            dist_info = DistInfoStd(mean=mean, log_std=log_std)
            action = self.distribution.sample(dist_info)

            actions.append(action)
            means.append(mean)
            log_stds.append(log_std)

        mean, log_std = torch.cat(means, dim=-1), torch.cat(log_stds, dim=-1)
        dist_info = DistInfoStd(mean=mean, log_std=log_std)
        agent_info = AgentInfo(dist_info=dist_info)

        action = torch.cat(actions, dim=-1)
        action, agent_info = buffer_to((action, agent_info), device="cpu")
        return AgentStep(action=action, agent_info=agent_info)
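
The model here exposes an autoregressive interface (start / has_next / next), where each stage can condition on the actions sampled so far and the stage outputs are concatenated into the final action. A toy model driven by the same loop shape; all names and dimensions below are illustrative, not from the source:

import torch

class ToyAutoregressiveModel:
    n_stages = 2
    def start(self):
        self._i = 0
    def has_next(self):
        return self._i < self.n_stages
    def next(self, actions, obs):
        # A real model would condition on torch.cat(actions, dim=-1); this toy one ignores it.
        self._i += 1
        mean = torch.zeros(obs.shape[0], 2)
        log_std = torch.full_like(mean, -1.0)
        return mean, log_std

model = ToyAutoregressiveModel()
obs, actions = torch.randn(4, 8), []
model.start()
while model.has_next():
    mean, log_std = model.next(actions, obs)
    actions.append(mean + log_std.exp() * torch.randn_like(mean))
action = torch.cat(actions, dim=-1)   # final concatenated action, shape [4, 4]
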
Example #19
 def step(self, observation, prev_action, prev_reward, device="cpu"):
     """
     Compute policy's action distribution from inputs, and sample an
     action. Calls the model to produce mean, log_std, value estimate, and
     next recurrent state.  Moves inputs to device and returns outputs back
     to CPU, for the sampler.  Advances the recurrent state of the agent.
     (no grad)
     """
     agent_inputs = buffer_to((observation, prev_action, prev_reward),
                              device=self.device)
     mu, log_std, beta, q, pi, rnn_state = self.model(
         *agent_inputs, self.prev_rnn_state)
     terminations = torch.bernoulli(beta).bool()  # Sample terminations
     dist_info_omega = DistInfo(prob=pi)
     new_o = self.sample_option(terminations, dist_info_omega)
     dist_info = DistInfoStd(mean=mu, log_std=log_std)
     dist_info_o = DistInfoStd(mean=select_at_indexes(new_o, mu),
                               log_std=select_at_indexes(new_o, log_std))
     action = self.distribution.sample(dist_info_o)
     # Model handles None, but Buffer does not, make zeros if needed:
     prev_rnn_state = self.prev_rnn_state if self.prev_rnn_state is not None else buffer_func(
         rnn_state, torch.zeros_like)
     # Transpose the rnn_state from [N,B,H] --> [B,N,H] for storage.
     # (Special case: model should always leave B dimension in.)
     prev_rnn_state = buffer_method(prev_rnn_state, "transpose", 0, 1)
     agent_info = AgentInfoOCRnn(dist_info=dist_info,
                                 dist_info_o=dist_info_o,
                                 q=q,
                                 value=(pi * q).sum(-1),
                                 termination=terminations,
                                 inter_option_dist_info=dist_info_omega,
                                 prev_o=self._prev_option,
                                 o=new_o,
                                 prev_rnn_state=prev_rnn_state)
     action, agent_info = buffer_to((action, agent_info), device=device)
     self.advance_rnn_state(rnn_state)  # Keep on device.
     self.advance_oc_state(new_o)
     return AgentStep(action=action, agent_info=agent_info)
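
For readers unfamiliar with option-critic agents, the option switch implied by sample_option roughly works as follows: sample a termination for the currently held option from beta and, wherever it terminates, draw a fresh option from pi_omega, otherwise keep the old one. A hedged sketch only; the agent's actual sample_option, and its handling of the first step, may differ:

import torch

B, n_opt = 4, 3
beta = torch.rand(B, n_opt)                                   # termination prob per option
pi_omega = torch.softmax(torch.randn(B, n_opt), dim=-1)       # policy over options
prev_o = torch.randint(n_opt, (B,))
terminate = torch.bernoulli(beta[torch.arange(B), prev_o]).bool()
resampled = torch.multinomial(pi_omega, 1).squeeze(-1)
new_o = torch.where(terminate, resampled, prev_o)
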
Example #20
    def pi(self, observation, prev_action, prev_reward):
        model_inputs = buffer_to((observation, prev_action, prev_reward),
                                 device=self.device)

        actions, means, log_stds = [], [], []
        log_pi_total = 0
        self.model.start()
        while self.model.has_next():
            mean, log_std = self.model.next(actions, *model_inputs)
            dist_info = DistInfoStd(mean=mean, log_std=log_std)
            action, log_pi = self.distribution.sample_loglikelihood(dist_info)

            log_pi_total += log_pi
            actions.append(action)
            means.append(mean)
            log_stds.append(log_std)

        mean, log_std = torch.cat(means, dim=-1), torch.cat(log_stds, dim=-1)
        dist_info = DistInfoStd(mean=mean, log_std=log_std)

        log_pi_total, dist_info = buffer_to((log_pi_total, dist_info),
                                            device="cpu")
        action = torch.cat(actions, dim=-1)
        return action, log_pi_total, dist_info  # Action stays on device for q models.
Example #21
 def pi(self, conv_out, prev_action, prev_reward):
     """Compute action log-probabilities for state/observation, and
     sample new action (with grad).  Uses special ``sample_loglikelihood()``
     method of Gaussian distribution, which handles action squashing
     through this process.
     Assume variables already on device."""
     # Call just the actor mlp, not the encoder.
     latent = self.pi_fc1(conv_out)
     mean, log_std = self.pi_mlp(latent, prev_action, prev_reward)
     dist_info = DistInfoStd(mean=mean, log_std=log_std)
     action, log_pi = self.distribution.sample_loglikelihood(dist_info)
     # action = self.distribution.sample(dist_info)
     # log_pi = self.distribution.log_likelihood(action, dist_info)
     log_pi, dist_info = buffer_to((log_pi, dist_info), device="cpu")
     return action, log_pi, dist_info  # Action stays on device for q models.
Example #22
 def step(self, observation, prev_action, prev_reward, device="cpu"):
     """
     Compute policy's action distribution from inputs, and sample an
     action. Calls the model to produce mean, log_std, and value estimate.
     Moves inputs to device and returns outputs back to CPU, for the
     sampler.  (no grad)
     """
     model_inputs = buffer_to((observation, prev_action, prev_reward),
                              device=self.device)
     mu, log_std, value = self.model(*model_inputs)
     dist_info = DistInfoStd(mean=mu, log_std=log_std)
     action = self.distribution.sample(dist_info)
     agent_info = AgentInfo(dist_info=dist_info, value=value)
     action, agent_info = buffer_to((action, agent_info), device=device)
     return AgentStep(action=action, agent_info=agent_info)
Example #23
 def step(self, observation, prev_action, prev_reward):
     agent_inputs = buffer_to((observation, prev_action, prev_reward),
         device=self.device)
     mu, log_std, value, rnn_state = self.model(*agent_inputs, self.prev_rnn_state)
     dist_info = DistInfoStd(mean=mu, log_std=log_std)
     action = self.distribution.sample(dist_info)
     # Model handles None, but Buffer does not, make zeros if needed:
     prev_rnn_state = self.prev_rnn_state or buffer_func(rnn_state, torch.zeros_like)
     # Transpose the rnn_state from [N,B,H] --> [B,N,H] for storage.
     # (Special case: model should always leave B dimension in.)
     prev_rnn_state = buffer_method(prev_rnn_state, "transpose", 0, 1)
     agent_info = AgentInfoRnn(dist_info=dist_info, value=value,
         prev_rnn_state=prev_rnn_state)
     action, agent_info = buffer_to((action, agent_info), device="cpu")
     self.advance_rnn_state(rnn_state)  # Keep on device.
     return AgentStep(action=action, agent_info=agent_info)
Example #24
 def __call__(self,
              observation,
              prev_action,
              prev_reward,
              sampled_option,
              device="cpu"):
     """Performs forward pass on training data, for algorithm. Returns sampled distinfo, q, beta, and piomega distinfo"""
     model_inputs = buffer_to(
         (observation, prev_action, prev_reward, sampled_option),
         device=self.device)
     mu, log_std, beta, q, pi = self.model(*model_inputs[:-1])
     # Need gradients from intra-option (DistInfoStd), q_o (q), termination (beta), and pi_omega (DistInfo)
     return buffer_to(
         (DistInfoStd(mean=select_at_indexes(sampled_option, mu),
                      log_std=select_at_indexes(sampled_option, log_std)),
          q, beta, DistInfo(prob=pi)),
         device=device)
Example #25
 def __call__(self,
              observation,
              prev_action,
              prev_reward,
              sampled_option,
              init_rnn_state,
              device="cpu"):
     """Performs forward pass on training data, for algorithm (requires
     recurrent state input). Returnssampled distinfo, q, beta, and piomega distinfo"""
     # Assume init_rnn_state already shaped: [N,B,H]
     model_inputs = buffer_to((observation, prev_action, prev_reward,
                               init_rnn_state, sampled_option),
                              device=self.device)
     mu, log_std, beta, q, pi, next_rnn_state = self.model(
         *model_inputs[:-1])
     # Need gradients from intra-option (DistInfoStd), q_o (q), termination (beta), and pi_omega (DistInfo)
     dist_info, q, beta, dist_info_omega = buffer_to(
         (DistInfoStd(mean=select_at_indexes(sampled_option, mu),
                      log_std=select_at_indexes(sampled_option, log_std)),
          q, beta, DistInfo(prob=pi)),
         device=device)
     return dist_info, q, beta, dist_info_omega, next_rnn_state  # Leave rnn_state on device.
Example #26
    def next(self, actions, observation, prev_action, prev_reward):
        if isinstance(observation, tuple):
            observation = torch.cat(observation, dim=-1)

        lead_dim, T, B, _ = infer_leading_dims(observation,
                                               self._obs_ndim)
        input_obs = observation.view(T * B, -1)
        if self._counter == 0:
            logits = self.mlp_loc(input_obs)
            logits = restore_leading_dims(logits, lead_dim, T, B)
            self._counter += 1
            return logits

        elif self._counter == 1:
            assert len(actions) == 1
            action_loc = actions[0].view(T * B, -1)
            model_input = torch.cat((input_obs, action_loc.repeat((1, self._n_tile))), dim=-1)
            output = self.mlp_delta(model_input)
            mu, log_std = output.chunk(2, dim=-1)
            mu, log_std = restore_leading_dims((mu, log_std), lead_dim, T, B)
            self._counter += 1
            return DistInfoStd(mean=mu, log_std=log_std)
        else:
            raise Exception('Invalid self._counter', self._counter)
Example #27
    def step(self, observation, prev_action, prev_reward):
        threshold = 0.2
        model_inputs = buffer_to((observation, prev_action, prev_reward),
                                 device=self.device)

        if self._max_q_eval_mode == 'none':
            mean, log_std = self.model(*model_inputs)
            dist_info = DistInfoStd(mean=mean, log_std=log_std)
            action = self.distribution.sample(dist_info)
            agent_info = AgentInfo(dist_info=dist_info)
            action, agent_info = buffer_to((action, agent_info), device="cpu")
            return AgentStep(action=action, agent_info=agent_info)
        else:
            global MaxQInput
            observation, prev_action, prev_reward = model_inputs
            fields = observation._fields
            if 'position' in fields:
                no_batch = len(observation.position.shape) == 1
            else:
                no_batch = len(observation.pixels.shape) == 3
            if no_batch:
                if 'state' in self._max_q_eval_mode:
                    observation = [observation.position.unsqueeze(0)]
                else:
                    observation = [observation.pixels.unsqueeze(0)]
            else:
                if 'state' in self._max_q_eval_mode:
                    observation = [observation.position]
                else:
                    observation = [observation.pixels]

            if self._max_q_eval_mode == 'state_rope':
                locations = np.arange(25).astype('float32')
                locations = locations[:, None]
                locations = np.tile(locations, (1, 50)) / 24
            elif self._max_q_eval_mode == 'state_cloth_corner':
                locations = np.array(
                    [[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]],
                    dtype='float32')
                locations = np.tile(locations, (1, 50))
            elif self._max_q_eval_mode == 'state_cloth_point':
                locations = np.mgrid[0:9, 0:9].reshape(2,
                                                       81).T.astype('float32')
                locations = np.tile(locations, (1, 50)) / 8
            elif self._max_q_eval_mode == 'pixel_rope':
                image = observation[0].squeeze(0).cpu().numpy()
                locations = np.transpose(np.where(np.all(
                    image > 150, axis=2))).astype('float32')
                if locations.shape[0] == 0:
                    locations = np.array([[-1, -1]], dtype='float32')
                locations = np.tile(locations, (1, 50)) / 63
            elif self._max_q_eval_mode == 'pixel_cloth':
                image = observation[0].squeeze(0).cpu().numpy()
                locations = np.transpose(np.where(np.any(
                    image < 100, axis=-1))).astype('float32')
                locations = np.tile(locations, (1, 50)) / 63
            else:
                raise Exception(f'Unknown _max_q_eval_mode: {self._max_q_eval_mode}')

            observation_pi = self.model.forward_embedding(observation)
            observation_qs = [
                q.forward_embedding(observation) for q in self.q_models
            ]

            n_locations = len(locations)
            observation_pi_i = [
                repeat(o[[i]], [n_locations] + [1] * len(o.shape[1:]))
                for o in observation_pi
            ]
            observation_qs_i = [[
                repeat(o, [n_locations] + [1] * len(o.shape[1:]))
                for o in observation_q
            ] for observation_q in observation_qs]
            locations = torch.from_numpy(locations).to(self.device)

            if MaxQInput is None:
                MaxQInput = namedtuple('MaxQPolicyInput', fields)

            aug_observation_pi = [locations] + list(observation_pi_i)
            aug_observation_pi = MaxQInput(*aug_observation_pi)
            aug_observation_qs = [[locations] + list(observation_q_i)
                                  for observation_q_i in observation_qs_i]
            aug_observation_qs = [
                MaxQInput(*aug_observation_q)
                for aug_observation_q in aug_observation_qs
            ]

            mean, log_std = self.model.forward_output(
                aug_observation_pi)  #, prev_action, prev_reward)

            qs = [
                q.forward_output(aug_obs, mean)
                for q, aug_obs in zip(self.q_models, aug_observation_qs)
            ]
            q = torch.min(torch.stack(qs, dim=0), dim=0)[0]
            #q = q.view(batch_size, n_locations)

            values, indices = torch.topk(q,
                                         math.ceil(threshold * n_locations),
                                         dim=-1)

            # vmin, vmax = values.min(dim=-1, keepdim=True)[0], values.max(dim=-1, keepdim=True)[0]
            # values = (values - vmin) / (vmax - vmin)
            # values = F.log_softmax(values, -1)
            #
            # uniform = torch.rand_like(values)
            # uniform = torch.clamp(uniform, 1e-5, 1 - 1e-5)
            # gumbel = -torch.log(-torch.log(uniform))

            #sampled_idx = torch.argmax(values + gumbel, dim=-1)
            sampled_idx = torch.randint(high=math.ceil(threshold *
                                                       n_locations),
                                        size=(1, )).to(self.device)

            actual_idxs = indices[sampled_idx]
            #actual_idxs += (torch.arange(batch_size) * n_locations).to(self.device)

            location = locations[actual_idxs][:, :1]
            location = (location - 0.5) / 0.5
            delta = torch.tanh(mean[actual_idxs])
            action = torch.cat((location, delta), dim=-1)

            mean, log_std = mean[actual_idxs], log_std[actual_idxs]

            if no_batch:
                action = action.squeeze(0)
                mean = mean.squeeze(0)
                log_std = log_std.squeeze(0)

            dist_info = DistInfoStd(mean=mean, log_std=log_std)
            agent_info = AgentInfo(dist_info=dist_info)

            action, agent_info = buffer_to((action, agent_info), device="cpu")
            return AgentStep(action=action, agent_info=agent_info)
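
Stripped of the observation plumbing, the action selection at the end of this step amounts to: score every candidate pick location with the minimum over the Q heads, keep the top `threshold` fraction, and choose one of those survivors uniformly at random. A condensed sketch of just that part, with an illustrative number of candidates:

import math
import torch

threshold = 0.2
q = torch.randn(25)                              # min-over-heads Q value per candidate
k = math.ceil(threshold * q.numel())
values, indices = torch.topk(q, k, dim=-1)       # keep the best 20% of candidates
sampled_idx = torch.randint(high=k, size=(1,))
chosen_candidate = indices[sampled_idx]          # index into the original candidate list
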
Example #28
 def __call__(self, observation, prev_action, prev_reward):
     model_inputs = buffer_to((observation, prev_action, prev_reward),
                              device=self.device)
     mu, log_std, ev, iv = self.model(*model_inputs)
     return buffer_to((DistInfoStd(mean=mu, log_std=log_std), ev, iv),
                      device="cpu")