def get_action(self, x):
        """Sample a Gaussian action from the policy network.

        Args:
            x: Input forwarded unchanged to ``self.pi``, which must return the
                ``(mean, log_std)`` pair parameterizing the action distribution.

        Returns:
            A tuple ``(action, prob)``. ``action`` is the sampled action
            multiplied by ``self.action_range``, detached and converted to a
            NumPy array. ``prob`` is a tensor holding the density of the
            unscaled sample under ``Normal(mean, std)``, multiplied over the
            last dimension (which is kept with size 1).
        """
        mean, log_std = self.pi(x)
        std = log_std.exp()
        dist = Normal(mean, std)

        # Draw one independent standard-normal value per action element.
        # A bare ``Normal(0, 1).sample()`` yields a single scalar, which would
        # perturb every element with the same noise even though the
        # log-probability below treats the elements as independent.
        z = Normal(0, 1).sample(dist.batch_shape).to(mean)
        action = mean + std * z

        # NOTE(review): ``action`` is built from ``mean`` and ``std``, so this
        # log-probability equals -z**2 / 2 - log_std - const and carries no
        # gradient with respect to ``mean``. Confirm that callers only use
        # ``prob`` as a value.
        log_prob = dist.log_prob(action)
        log_prob = log_prob.sum(dim=-1, keepdim=True)  # joint over the last dim
        prob = log_prob.exp()

        # The density above refers to the unscaled sample; scaling happens after.
        action = self.action_range * action

        # ``.cpu()`` is a no-op on CPU tensors and is required before
        # ``.numpy()`` when the policy runs on another device.
        return action.detach().cpu().numpy(), prob
Ejemplo n.º 2
0
    def sample(
        self,
        dist_params: Dict[str, Tensor],
        n_samples: Union[int, None],
        n_source_weights: Optional[Tensor] = None,
    ) -> Dict[str, Tensor]:
        """Sample from the encoded variational distribution.

        Args:
            dist_params: The output of `self.encode(image_ptiles)`,
                which is the distributional parameters in matrix form.
                The keys `"n_source_log_probs"` and `"per_source_params"`
                are read here.
            n_samples:
                The number of samples to draw. If None, the variational mode
                is taken instead (argmax over the number of sources, and the
                means for locations and log-fluxes).
            n_source_weights:
                If specified, adjusts the sampling probabilities of n_sources.
                Defaults to uniform weights of length `self.max_detections + 1`.

        Returns:
            A dictionary of tensors with shape `n_samples * n_ptiles * max_sources * ...`.
            Consists of `"n_sources", "locs", "log_fluxes", and "fluxes"`.
            `"locs"` and `"fluxes"` are multiplied by the on/off indicator of
            each source slot; `"log_fluxes"` is returned unmasked.
        """
        if n_source_weights is None:
            n_source_weights = torch.ones(self.max_detections + 1,
                                          device=self.device)
        n_source_weights = n_source_weights.reshape(1, -1)

        # Reweight the log-probabilities of the source counts and renormalize
        # in log space. The subtraction is out-of-place: `logsumexp` keeps its
        # input for the backward pass, so overwriting that input in place
        # would invalidate the autograd graph.
        ns_log_probs_adj = (dist_params["n_source_log_probs"] +
                            n_source_weights.log())
        ns_log_probs_adj = ns_log_probs_adj - ns_log_probs_adj.logsumexp(
            dim=-1, keepdim=True)

        draw_samples = n_samples is not None

        if draw_samples:
            n_source_probs = ns_log_probs_adj.exp()
            tile_n_sources = Categorical(probs=n_source_probs).sample(
                (n_samples, ))
        else:
            # Mode of the count distribution, with a leading axis of size 1.
            tile_n_sources = torch.argmax(ns_log_probs_adj,
                                          dim=-1).unsqueeze(0)

        # get distributional parameters conditioned on the sampled numbers of light sources
        dist_params_n_src = self._encode_for_n_sources(
            dist_params["per_source_params"], tile_n_sources)

        # Trailing axis added so the indicator broadcasts over the last
        # dimension of the per-source quantities.
        tile_is_on_array = get_is_on_from_n_sources(tile_n_sources,
                                                    self.max_detections)
        tile_is_on_array = tile_is_on_array.unsqueeze(-1)

        if draw_samples:
            tile_locs = Normal(dist_params_n_src["loc_mean"],
                               dist_params_n_src["loc_sd"]).rsample()
            tile_log_fluxes = Normal(
                dist_params_n_src["log_flux_mean"],
                dist_params_n_src["log_flux_sd"]).rsample()
        else:
            tile_locs = dist_params_n_src["loc_mean"]
            tile_log_fluxes = dist_params_n_src["log_flux_mean"]

        # Mask out-of-place. In the mode branch `tile_locs` is the very tensor
        # stored under "loc_mean", and an in-place multiply would overwrite it;
        # an in-place multiply on the result of `exp()` would also clobber the
        # value autograd saves for the backward pass.
        # NOTE(review): the original author questioned whether masking the
        # locations is helpful/necessary; the masking is kept as-is.
        tile_locs = tile_locs * tile_is_on_array
        tile_fluxes = tile_log_fluxes.exp() * tile_is_on_array

        return {
            "locs": tile_locs,
            "log_fluxes": tile_log_fluxes,
            "fluxes": tile_fluxes,
            "n_sources": tile_n_sources,
        }