Пример #1
0
 def test_dimensions(self):
     """The sampled index tensor must have the same size as its input."""
     # Same three shapes, in the same order, as separate assertions would use.
     for shape in ((2, 3), (1, 2), (2, 1)):
         log_weight = torch.rand(*shape)
         sampled = inference.sample_ancestral_index(log_weight)
         self.assertEqual(sampled.size(), torch.Size(shape))
Пример #2
0
    def test_sampler(self):
        """Sampled ancestor indices should be distributed like the weights.

        The same log-weight row is repeated ``num_trials`` times, and the
        relative frequency of every index in the sampler's output is
        compared with the weight it was drawn from.
        """
        weight = [0.2, 0.3, 0.5]
        num_trials = 10000
        num_weights = len(weight)

        # One identical row of log weights per trial.
        log_weight = torch.log(torch.Tensor(weight))
        batched_log_weight = log_weight.unsqueeze(0).expand(
            num_trials, num_weights)
        ancestral_indices = inference.sample_ancestral_index(
            batched_log_weight)

        # The denominator assumes the sampler returns one index per input
        # entry, i.e. num_trials * num_weights indices in total.
        total_samples = num_trials * num_weights
        empirical_probabilities = [
            torch.sum((ancestral_indices == index).float()).item() /
            total_samples
            for index in range(num_weights)
        ]

        np.testing.assert_allclose(
            np.array(empirical_probabilities),
            np.array(weight),
            atol=1e-2,  # 2 decimal places
        )
Пример #3
0
    def encode(self, observation, reward, actions, previous_latent_state,
               predicted_times):
        """
        Run one particle-filter update step; this is where the core of the
        DVRL algorithm is happening.

        The step samples ancestral indices from the previous log weights,
        optionally resamples the previous latent state (when self.resample is
        set), draws a new stochastic latent state from the proposal network,
        adds the deterministic state, and stores the new particle log weights
        (transition_logpdf - proposal_logpdf + emission_logpdf) on the
        returned latent state.

        Args:
            observation, reward: Last observation and reward received from all n_e environments.
                                 The first dimension of observation is used as the batch size.
            actions: Action vector (oneHot for discrete actions)
            previous_latent_state: previous latent state of type state.State; its log_weight
                                   is used to sample the ancestral indices
            predicted_times (list of ints): List of timesteps into the future for which predictions
                                            should be returned. Only makes sense if
                                            encoding_loss_coef != 0 and obs_loss_coef != 0.
                                            If None, no predictions are computed.

        Returns:
            A 6-tuple (the loss pair counts as one element):
            latent_state: New latent state, with log_weight set to the new log weights
            - encoding_logli = encoding_loss: Negated log of the weights averaged (in log space)
                                              over particles
            (- transition_logpdf + proposal_logpdf, - emission_logpdf): KL divergence loss and
                                                                        reconstruction loss
            avg_num_killed_particles: Average number of killed particles in particle filter
                                      (0 when self.resample is falsy)
            predicted_observations: Predicted observations (depending on timesteps specified in
                                    predicted_times); None if predicted_times is None
            particle_observations: Second value returned by self.predict_observations;
                                   None if predicted_times is None

        Note:
            The method asserts that self.prior_loss_coef == 1 and
            self.obs_loss_coef == 1; other coefficients are not supported here.
        """
        batch_size, *rest = observation.size()

        # Total observation dim to normalise the likelihood
        # obs_dim = reduce(mul, rest, 1)

        # Needed for legacy AESMC code
        # NOTE(review): presumably switches the legacy utilities to CUDA mode
        # when the observation lives on the GPU - confirm in ae_util.
        ae_util.init(observation.is_cuda)

        # Legacy code: We need to pass in a (time) sequence of observations
        # With dim=0 for time
        img_observation = observation.unsqueeze(0)
        actions = actions.unsqueeze(0)
        reward = reward.unsqueeze(0)

        # Legacy code: All values are wrapped in state.State (which can contain more than one value)
        observation_states = st.State(all_x=img_observation.contiguous(),
                                      all_a=actions.contiguous(),
                                      r=reward.contiguous())

        # Keep the weights of the previous step; they drive the resampling below
        old_log_weight = previous_latent_state.log_weight

        # Encoding the actions and observations (nothing recurrent yet)
        observation_states = self.encoding_network(observation_states)

        # Expand the particle dimension
        observation_states.unsequeeze_and_expand_all_(dim=2,
                                                      size=self.num_particles)

        # Indices are sampled even when self.resample is off, so the random
        # sampler is invoked on every call regardless of that flag.
        ancestral_indices = sample_ancestral_index(old_log_weight)

        # How many particles were killed?
        # List over batch size
        num_killed_particles = list(
            tu.num_killed_particles(ancestral_indices.data.cpu()))
        if self.resample:
            previous_latent_state = previous_latent_state.resample(
                ancestral_indices)
        else:
            # Without resampling no particle is dropped
            num_killed_particles = [0] * batch_size

        avg_num_killed_particles = sum(num_killed_particles) / len(
            num_killed_particles)

        # Legacy code: Select first (and only) time index
        current_observation = observation_states.index_elements(0)

        # Sample stochastic latent state z from proposal
        proposal_state_random_variable = self.proposal_network(
            previous_latent_state=previous_latent_state,
            observation_states=current_observation,
            time=0)
        latent_state = self.sample_from(proposal_state_random_variable)

        # Compute deterministic state h and add to the latent state
        latent_state = self.deterministic_transition_network(
            previous_latent_state=previous_latent_state,
            latent_state=latent_state,
            observation_states=current_observation,
            time=0)

        # Compute prior probability over z
        transition_state_random_variable = self.transition_network(
            previous_latent_state, current_observation)

        # Compute probability over observation
        emission_state_random_variable = self.emission_network(
            previous_latent_state, latent_state, current_observation
            # observation_states
        )

        emission_logpdf = emission_state_random_variable.logpdf(
            current_observation, batch_size, self.num_particles)

        proposal_logpdf = proposal_state_random_variable.logpdf(
            latent_state, batch_size, self.num_particles)
        transition_logpdf = transition_state_random_variable.logpdf(
            latent_state, batch_size, self.num_particles)

        # The unweighted formula below is only valid for unit coefficients;
        # the general (weighted) form is kept commented out underneath.
        assert (self.prior_loss_coef == 1)
        assert (self.obs_loss_coef == 1)
        new_log_weight = transition_logpdf - proposal_logpdf + emission_logpdf
        # new_log_weight = (self.prior_loss_coef * (transition_logpdf - proposal_logpdf)
        #                   + self.obs_loss_coef * emission_logpdf)

        latent_state.log_weight = new_log_weight

        # Average (in log space) over particles
        # NOTE(review): `math` must be a project helper module here (the
        # standard-library math has no logsumexp); dim=1 is assumed to be the
        # particle dimension - confirm.
        encoding_logli = math.logsumexp(
            # torch.stack(log_weights, dim=0), dim=2
            new_log_weight,
            dim=1) - np.log(self.num_particles)

        # inference_result.latent_states = latent_states

        # Predictions are optional and only computed when timesteps are given
        predicted_observations = None
        particle_observations = None
        if predicted_times is not None:
            predicted_observations, particle_observations = self.predict_observations(
                latent_state=latent_state,
                current_observation=current_observation,
                actions=actions,
                emission_state_random_variable=emission_state_random_variable,
                predicted_times=predicted_times)

        # NOTE(review): presumably restores the legacy utilities to their
        # non-CUDA default after this step - confirm in ae_util.
        ae_util.init(False)

        return latent_state, \
            - encoding_logli, \
            (- transition_logpdf + proposal_logpdf, - emission_logpdf),\
            avg_num_killed_particles,\
            predicted_observations, particle_observations
Пример #4
0
 def test_type(self):
     """The sampler must return a ``torch.LongTensor`` of indices."""
     sampled = inference.sample_ancestral_index(torch.rand(1, 1))
     self.assertIsInstance(sampled, torch.LongTensor)