예제 #1
0
    def new_latent_state(self):
        """
        Return a new (all-zero) latent state.

        This is a method because the latent state is different for DVRL and RNN.

        Returns:
            st.State holding
                h: zeros of size (batch_size, num_particles, h_dim)
                log_weight: zeros of size (batch_size, num_particles)
            Both tensors live on the device of the module's first parameter.
        """
        device = next(self.parameters()).device

        # Allocate directly on the target device rather than building the
        # tensors on the CPU and copying them over with `.to(device)`.
        initial_state = st.State(h=torch.zeros(
            self.batch_size, self.num_particles, self.h_dim, device=device))

        initial_state.log_weight = torch.zeros(
            self.batch_size, self.num_particles, device=device)

        return initial_state
예제 #2
0
    def predict_observations(self, latent_state, current_observation, actions,
                             emission_state_random_variable, predicted_times):
        """
        Roll the model forward and collect predicted observations.

        Assumes that the current encoded action (saved in 'current_observation') is
        repeated into the future.

        Every particle average uses the log-weights of the `latent_state`
        passed in; the weights are not updated during the roll-out.

        Args:
            latent_state: Latent state to start from. Must provide
                `log_weight`, and 3-dimensional `z` and `h` whose sizes are
                read as (batch_size, num_particles, z_dim / h_dim).
            current_observation: Encoded observation fed to the transition
                and emission networks at the first predicted step.
            actions: Actions stored as `all_a` next to every predicted
                observation before it is re-encoded.
            emission_state_random_variable: Emission distribution of the
                current step; only read when 0 is in `predicted_times`.
            predicted_times: Iterable of non-negative ints, the number of
                steps into the future to predict. May be empty.

        Returns:
            (predicted_observations, particle_observations): two lists with
            one entry per distinct requested time, in increasing time order.
            The first holds the weighted mean over particles, the second the
            per-particle values.
        """

        # `default=0` keeps an empty request from raising ValueError: nothing
        # is rolled out and two empty lists are returned.
        max_distance = max(predicted_times, default=0)
        old_log_weight = latent_state.log_weight
        predicted_observations = []
        particle_observations = []

        if 0 in predicted_times:
            # Time 0 is the current step: reuse the emission that the caller
            # already computed instead of rolling the model forward.
            x = emission_state_random_variable.all_x._probability

            averaged_obs = stats.empirical_mean(x, old_log_weight)
            predicted_observations.append(averaged_obs)
            particle_observations.append(x)

        batch_size, num_particles, z_dim = latent_state.z.size()
        batch_size, num_particles, h_dim = latent_state.h.size()
        for dt in range(max_distance):
            old_observation = current_observation
            previous_latent_state = latent_state

            # Get next state
            transition_state_random_variable = self.transition_network(
                previous_latent_state, old_observation)
            latent_state = self.sample_from(transition_state_random_variable)

            # Hack. This is usually done in det_transition
            # NOTE(review): the result is viewed with h_dim as last dimension,
            # so phi_z presumably outputs h_dim features - confirm.
            latent_state.phi_z = self.deterministic_transition_network.phi_z(
                latent_state.z.view(-1, z_dim)).view(batch_size, num_particles,
                                                     h_dim)

            # Draw observation
            emission_state_random_variable = self.emission_network(
                previous_latent_state, latent_state, old_observation)
            x = emission_state_random_variable.all_x._probability
            averaged_obs = stats.empirical_mean(x, old_log_weight)

            # Encode observation
            # Unsqueeze time dimension
            current_observation = st.State(all_x=averaged_obs.unsqueeze(0),
                                           all_a=actions.contiguous())
            current_observation = self.encoding_network(current_observation)
            current_observation.unsequeeze_and_expand_all_(
                dim=2, size=self.num_particles)
            # Select the first (and only) time index again
            current_observation = current_observation.index_elements(0)

            # Deterministic update
            latent_state = self.deterministic_transition_network(
                previous_latent_state=previous_latent_state,
                latent_state=latent_state,
                observation_states=current_observation,
                time=0)

            # After `dt + 1` transitions we are `dt + 1` steps in the future
            if dt + 1 in predicted_times:
                predicted_observations.append(averaged_obs)
                particle_observations.append(x)

        return predicted_observations, particle_observations
예제 #3
0
    def encode(self, observation, reward, actions, previous_latent_state,
               predicted_times):
        """
        This is where the core of the DVRL algorithm is happening.

        One update step: optionally resample the previous particles, sample a
        new stochastic state from the proposal, compute the deterministic
        state, and re-weight the particles with
        transition_logpdf - proposal_logpdf + emission_logpdf.

        Args:
            observation, reward: Last observation and reward received from all n_e environments
            actions: Action vector (oneHot for discrete actions)
            previous_latent_state: previous latent state of type state.State
            predicted_times (list of ints): List of timesteps into the future for which predictions
                                            should be returned. Only makes sense if
                                            encoding_loss_coef != 0 and obs_loss_coef != 0.
                                            Pass None to skip prediction.

        Returns:
            A 6-tuple
            (latent_state,
             - encoding_logli,
             (- transition_logpdf + proposal_logpdf, - emission_logpdf),
             avg_num_killed_particles,
             predicted_observations,
             particle_observations)
            where
            latent_state: New latent state (its log_weight holds the new log-weights)
            - encoding_logli = encoding_loss: Negative log of the particle-averaged weights
            - transition_logpdf + proposal_logpdf: KL divergence loss
            - emission_logpdf: Reconstruction loss
            avg_num_killed_particles: Average number of killed particles in particle filter
                                      (0 when self.resample is False)
            predicted_observations: Predicted observations (depending on timesteps specified
                                    in predicted_times), or None if predicted_times is None
            particle_observations: Per-particle predicted observations, or None if
                                   predicted_times is None

        """
        batch_size, *_ = observation.size()

        # Needed for legacy AESMC code
        ae_util.init(observation.is_cuda)

        # Legacy code: We need to pass in a (time) sequence of observations
        # With dim=0 for time
        img_observation = observation.unsqueeze(0)
        actions = actions.unsqueeze(0)
        reward = reward.unsqueeze(0)

        # Legacy code: All values are wrapped in state.State (which can contain more than one value)
        observation_states = st.State(all_x=img_observation.contiguous(),
                                      all_a=actions.contiguous(),
                                      r=reward.contiguous())

        old_log_weight = previous_latent_state.log_weight

        # Encoding the actions and observations (nothing recurrent yet)
        observation_states = self.encoding_network(observation_states)

        # Expand the particle dimension
        observation_states.unsequeeze_and_expand_all_(dim=2,
                                                      size=self.num_particles)

        # Drawn unconditionally, exactly as before, so that any random-number
        # consumption does not depend on self.resample.
        ancestral_indices = sample_ancestral_index(old_log_weight)

        if self.resample:
            # How many particles were killed?
            # List over batch size
            # Only computed when resampling: it needs a device-to-host copy
            # and its result is unused otherwise.
            num_killed_particles = list(
                tu.num_killed_particles(ancestral_indices.data.cpu()))
            previous_latent_state = previous_latent_state.resample(
                ancestral_indices)
        else:
            num_killed_particles = [0] * batch_size

        avg_num_killed_particles = sum(num_killed_particles) / len(
            num_killed_particles)

        # Legacy code: Select first (and only) time index
        current_observation = observation_states.index_elements(0)

        # Sample stochastic latent state z from proposal
        proposal_state_random_variable = self.proposal_network(
            previous_latent_state=previous_latent_state,
            observation_states=current_observation,
            time=0)
        latent_state = self.sample_from(proposal_state_random_variable)

        # Compute deterministic state h and add to the latent state
        latent_state = self.deterministic_transition_network(
            previous_latent_state=previous_latent_state,
            latent_state=latent_state,
            observation_states=current_observation,
            time=0)

        # Compute prior probability over z
        transition_state_random_variable = self.transition_network(
            previous_latent_state, current_observation)

        # Compute probability over observation
        emission_state_random_variable = self.emission_network(
            previous_latent_state, latent_state, current_observation)

        emission_logpdf = emission_state_random_variable.logpdf(
            current_observation, batch_size, self.num_particles)

        proposal_logpdf = proposal_state_random_variable.logpdf(
            latent_state, batch_size, self.num_particles)
        transition_logpdf = transition_state_random_variable.logpdf(
            latent_state, batch_size, self.num_particles)

        # The weight update below does not apply the loss coefficients, so it
        # is only valid when both of them are 1.
        assert (self.prior_loss_coef == 1)
        assert (self.obs_loss_coef == 1)
        new_log_weight = transition_logpdf - proposal_logpdf + emission_logpdf

        latent_state.log_weight = new_log_weight

        # Average (in log space) over particles
        encoding_logli = math.logsumexp(
            new_log_weight,
            dim=1) - np.log(self.num_particles)

        predicted_observations = None
        particle_observations = None
        if predicted_times is not None:
            predicted_observations, particle_observations = self.predict_observations(
                latent_state=latent_state,
                current_observation=current_observation,
                actions=actions,
                emission_state_random_variable=emission_state_random_variable,
                predicted_times=predicted_times)

        ae_util.init(False)

        return latent_state, \
            - encoding_logli, \
            (- transition_logpdf + proposal_logpdf, - emission_logpdf),\
            avg_num_killed_particles,\
            predicted_observations, particle_observations
예제 #4
0
    def reconstruct_predict(self,
                            observation_states,
                            num_particles,
                            resample,
                            reconstruction_length,
                            prediction_length,
                            summarize_function):
        """
        Reconstruct the given observations and predict the following ones.

        The first `reconstruction_length` observations are run through
        `self.forward` to infer latent states. The remaining
        `prediction_length` latent states are sampled from the transition
        network. Observations are then sampled from the emission network for
        every timestep.

        input:
            observation_states: State whose `all_x` has size
                                [num_timesteps, batch_size, nr_channels, w, h].
                                `unsequeeze_and_expand_all_(dim=2, ...)` is called on it
                                (presumably in place, given the trailing underscore).
            resample: bool. True: smc, False: is.
            num_particles: number. number of particles for posterior approximation.
            prediction_length: number. length of prediction.
            reconstruction_length: number. length of reconstruction.
                                   Must be num_timesteps - prediction_length
            summarize_function: Function. Takes in an ensemble of states at a certain time
                                and corresponding weights. Outputs a single (combined) state.
        output:
            all_obs:
                Tensor [num_timesteps, batch_size, num_particles, nr_channels, w, h]
            averaged_obs:
                Tensor [num_timesteps, batch_size, nr_channels, w, h]
            log_weight:
                `inference_result.log_weight.data`
        """

        # `w` and `h` are simply dims 3 and 4 of all_x, in that order
        num_timesteps, batch_size, nr_channels, w, h = observation_states.all_x.size()

        assert(num_timesteps == reconstruction_length + prediction_length)

        # Only regress model on 'known' states
        given_obs = st.State(
            all_x=observation_states.all_x[:reconstruction_length]
        )

        inference_result = self.forward(
            observation_states=given_obs,
            num_particles=num_particles,
            resample=resample,
            return_inference_results=True
            )

        latent_states = inference_result.latent_states
        latent_state = latent_states[-1]

        for t in range(reconstruction_length, num_timesteps):
            previous_latent_state = latent_state

            # a) Prior: Draw latent_state.z
            transition_state_random_variable = self.transition_network(
                previous_latent_state
                )
            latent_state = transition_state_random_variable.sample(
                batch_size, num_particles
                )
            batch_size, num_particles, z_dim = latent_state.z.size()

            # TODO: This is highly specific for VRNN
            latent_state.phi_z = self.deterministic_transition_network.phi_z(
                latent_state.z.view(-1, z_dim)
                ).view(
                    batch_size, num_particles, -1
                    )

            # NOTE(review): the state is stored before the deterministic
            # update below; this relies on that update modifying the same
            # object (or on the stored state not needing it) - confirm.
            latent_states.append(latent_state)

            # b) Generation: Draw x and compute phi_x
            emission_state_random_variable = self.emission_network(
                previous_latent_state, latent_state
                )
            x = emission_state_random_variable.sample(
                        batch_size, num_particles).all_x

            # This is usually done in the init-network
            # num_timesteps = 1 for one image
            # TODO: This is highly specific for VRNN
            phi_x = self.encoding_network.phi_x(
                x.view(-1, nr_channels, w, h)
                ).view(
                    1, batch_size, num_particles, -1
                    ).contiguous()

            current_observation = st.State(
                all_phi_x=phi_x,
                x=x  # Just in case, not used in VRNN
                )

            # Set time=0 because we have only the last observation
            if self.deterministic_transition_network is not None:
                latent_state = self.deterministic_transition_network(
                    previous_latent_state=previous_latent_state,
                    latent_state=latent_state,
                    observation_states=current_observation,
                    time=0
                )

        # Ok, at this point we have all the latent states
        # Compute reconstructed/predicted observations from given latents
        # The trailing dims use the same (w, h) order as the input and as the
        # `x.view(-1, nr_channels, w, h)` above, so non-square images fit.
        averaged_obs = torch.zeros(num_timesteps, batch_size, nr_channels, w, h)
        all_obs = torch.zeros(num_timesteps, batch_size, num_particles, nr_channels, w, h)
        observation_states.unsequeeze_and_expand_all_(dim=2, size=num_particles)
        initial_state_random_variable = self.initial_network(
            observation_states
        )

        initial_state = initial_state_random_variable.sample(
            batch_size, num_particles
            )

        for t in range(num_timesteps):
            # The emission at time t is conditioned on the state at t - 1;
            # for t == 0 that is a sample from the initial network.
            if t == 0:
                previous_latent_state = initial_state
            else:
                previous_latent_state = latent_states[t-1]
            latent_state = latent_states[t]
            emission_state_random_variable = self.emission_network(
                previous_latent_state, latent_state
                )
            x = emission_state_random_variable.sample(
                        batch_size, num_particles).all_x
            all_obs[t] = x.data
            averaged_obs[t] = summarize_function(
                x,
                inference_result.log_weight).data

        return all_obs, averaged_obs, inference_result.log_weight.data