def new_latent_state(self):
    """
    Return a new latent state.

    This is a function because the latent state differs between DVRL and RNN.
    """
    device = next(self.parameters()).device
    initial_state = st.State(h=torch.zeros(
        self.batch_size, self.num_particles, self.h_dim).to(device))

    log_weight = torch.zeros(self.batch_size, self.num_particles).to(device)
    initial_state.log_weight = log_weight

    return initial_state
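# Usage sketch (hypothetical; assumes a constructed model instance `model`):
# every particle starts with log weight 0, i.e. a uniform particle ensemble.
#
#   latent_state = model.new_latent_state()
#   assert latent_state.h.size() == (model.batch_size, model.num_particles, model.h_dim)
#   assert (latent_state.log_weight == 0).all()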
def predict_observations(self, latent_state, current_observation, actions,
                         emission_state_random_variable, predicted_times):
    """
    Assumes that the current encoded action (saved in 'current_observation')
    is repeated into the future.
    """
    max_distance = max(predicted_times)
    old_log_weight = latent_state.log_weight
    predicted_observations = []
    particle_observations = []
    if 0 in predicted_times:
        x = emission_state_random_variable.all_x._probability
        averaged_obs = stats.empirical_mean(x, old_log_weight)
        predicted_observations.append(averaged_obs)
        particle_observations.append(x)

    batch_size, num_particles, z_dim = latent_state.z.size()
    batch_size, num_particles, h_dim = latent_state.h.size()
    for dt in range(max_distance):
        old_observation = current_observation
        previous_latent_state = latent_state

        # Get next state
        transition_state_random_variable = self.transition_network(
            previous_latent_state, old_observation)
        latent_state = self.sample_from(transition_state_random_variable)

        # Hack: this is usually done in the deterministic transition
        latent_state.phi_z = self.deterministic_transition_network.phi_z(
            latent_state.z.view(-1, z_dim)).view(
                batch_size, num_particles, h_dim)

        # Draw observation
        emission_state_random_variable = self.emission_network(
            previous_latent_state,
            latent_state,
            old_observation  # observation_states
        )
        x = emission_state_random_variable.all_x._probability
        averaged_obs = stats.empirical_mean(x, old_log_weight)

        # Encode observation (unsqueeze the time dimension first)
        current_observation = st.State(all_x=averaged_obs.unsqueeze(0),
                                       all_a=actions.contiguous())
        current_observation = self.encoding_network(current_observation)
        current_observation.unsequeeze_and_expand_all_(
            dim=2, size=self.num_particles)
        current_observation = current_observation.index_elements(0)

        # Deterministic update
        latent_state = self.deterministic_transition_network(
            previous_latent_state=previous_latent_state,
            latent_state=latent_state,
            observation_states=current_observation,
            time=0)

        if dt + 1 in predicted_times:
            predicted_observations.append(averaged_obs)
            particle_observations.append(x)

    return predicted_observations, particle_observations
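# Example (hypothetical): with predicted_times=[0, 1, 5], max_distance is 5,
# so the transition/emission networks are rolled forward five steps (with the
# encoded action held fixed) and three entries are returned: the weighted
# emission mean now (0) and after 1 and 5 steps.
#
#   preds, particles = model.predict_observations(
#       latent_state, current_observation, actions,
#       emission_state_random_variable, predicted_times=[0, 1, 5])
#   assert len(preds) == len(particles) == 3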
def encode(self, observation, reward, actions, previous_latent_state,
           predicted_times):
    """
    This is where the core of the DVRL algorithm happens.

    Args:
        observation, reward: Last observation and reward received from
            all n_e environments
        actions: Action vector (one-hot for discrete actions)
        previous_latent_state: Previous latent state of type state.State
        predicted_times (list of ints): List of timesteps into the future
            for which predictions should be returned. Only makes sense if
            encoding_loss_coef != 0 and obs_loss_coef != 0

    Returns:
        latent_state: New latent state
        - encoding_logli = encoding_loss: Reconstruction loss when
            predicting the current observation, multiplied by obs_loss_coef
        (- transition_logpdf + proposal_logpdf, - emission_logpdf):
            KL divergence loss and reconstruction loss
        avg_num_killed_particles: Average number of killed particles in
            the particle filter
        predicted_observations: Predicted observations (depending on
            timesteps specified in predicted_times)
        particle_observations: Per-particle predicted observations
    """
    batch_size, *rest = observation.size()

    # Total observation dim to normalise the likelihood
    # obs_dim = reduce(mul, rest, 1)

    # Needed for legacy AESMC code
    ae_util.init(observation.is_cuda)

    # Legacy code: we need to pass in a (time) sequence of observations
    # with dim=0 for time
    img_observation = observation.unsqueeze(0)
    actions = actions.unsqueeze(0)
    reward = reward.unsqueeze(0)

    # Legacy code: all values are wrapped in state.State
    # (which can contain more than one value)
    observation_states = st.State(
        all_x=img_observation.contiguous(),
        all_a=actions.contiguous(),
        r=reward.contiguous())

    old_log_weight = previous_latent_state.log_weight

    # Encode the actions and observations (nothing recurrent yet)
    observation_states = self.encoding_network(observation_states)

    # Expand the particle dimension
    observation_states.unsequeeze_and_expand_all_(
        dim=2, size=self.num_particles)

    ancestral_indices = sample_ancestral_index(old_log_weight)

    # How many particles were killed? (List over batch size)
    num_killed_particles = list(
        tu.num_killed_particles(ancestral_indices.data.cpu()))
    if self.resample:
        previous_latent_state = previous_latent_state.resample(
            ancestral_indices)
    else:
        num_killed_particles = [0] * batch_size

    avg_num_killed_particles = sum(num_killed_particles) / len(
        num_killed_particles)

    # Legacy code: select the first (and only) time index
    current_observation = observation_states.index_elements(0)

    # Sample stochastic latent state z from the proposal
    proposal_state_random_variable = self.proposal_network(
        previous_latent_state=previous_latent_state,
        observation_states=current_observation,
        time=0)
    latent_state = self.sample_from(proposal_state_random_variable)

    # Compute deterministic state h and add it to the latent state
    latent_state = self.deterministic_transition_network(
        previous_latent_state=previous_latent_state,
        latent_state=latent_state,
        observation_states=current_observation,
        time=0)

    # Compute prior probability over z
    transition_state_random_variable = self.transition_network(
        previous_latent_state,
        current_observation)

    # Compute probability over observation
    emission_state_random_variable = self.emission_network(
        previous_latent_state,
        latent_state,
        current_observation  # observation_states
    )

    emission_logpdf = emission_state_random_variable.logpdf(
        current_observation, batch_size, self.num_particles)
    proposal_logpdf = proposal_state_random_variable.logpdf(
        latent_state, batch_size, self.num_particles)
    transition_logpdf = transition_state_random_variable.logpdf(
        latent_state, batch_size, self.num_particles)

    assert self.prior_loss_coef == 1
    assert self.obs_loss_coef == 1
    new_log_weight = transition_logpdf - proposal_logpdf + emission_logpdf
    # new_log_weight = (self.prior_loss_coef * (transition_logpdf - proposal_logpdf)
    #                   + self.obs_loss_coef * emission_logpdf)
    latent_state.log_weight = new_log_weight

    # Average (in log space) over particles
    encoding_logli = math.logsumexp(
        # torch.stack(log_weights, dim=0), dim=2
        new_log_weight, dim=1) - np.log(self.num_particles)

    # inference_result.latent_states = latent_states

    predicted_observations = None
    particle_observations = None
    if predicted_times is not None:
        predicted_observations, particle_observations = \
            self.predict_observations(
                latent_state=latent_state,
                current_observation=current_observation,
                actions=actions,
                emission_state_random_variable=emission_state_random_variable,
                predicted_times=predicted_times)

    ae_util.init(False)

    return latent_state, \
        - encoding_logli, \
        (- transition_logpdf + proposal_logpdf, - emission_logpdf), \
        avg_num_killed_particles, \
        predicted_observations, particle_observations
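# Usage sketch (hypothetical; `envs` and `model` are assumptions). encode()
# is called once per environment step. Note that encoding_logli averages the
# particle weights in log space, logsumexp(log w) - log(num_particles), so
# its negation is the per-step ELBO-style loss.
#
#   latent_state = model.new_latent_state()
#   obs, reward = envs.step(actions)
#   latent_state, encoding_loss, (kl_term, recon_term), n_killed, preds, parts = \
#       model.encode(obs, reward, actions, latent_state, predicted_times=None)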
def reconstruct_predict(self, observation_states, num_particles, resample,
                        reconstruction_length, prediction_length,
                        summarize_function):
    """
    Args:
        observation_states: state.State whose all_x is a Variable of size
            [num_timesteps, batch_size, nr_channels, h, w]
        num_particles: Number of particles for the posterior approximation
        resample: bool. True: SMC, False: IS
        reconstruction_length: Length of the reconstruction; should be
            num_timesteps - prediction_length
        prediction_length: Length of the prediction
        summarize_function: Takes an ensemble of states at a certain time
            and the corresponding weights; outputs a single (combined) state

    Returns:
        all_obs: Tensor [num_timesteps, batch_size, num_particles,
            nr_channels, h, w] of per-particle observations
        averaged_obs: Tensor [num_timesteps, batch_size, nr_channels, h, w]
            of summarized observations
        log_weight: Final particle log weights
    """
    num_timesteps, batch_size, nr_channels, h, w = \
        observation_states.all_x.size()
    assert num_timesteps == reconstruction_length + prediction_length

    # Only regress the model on 'known' states
    given_obs = st.State(
        all_x=observation_states.all_x[:reconstruction_length])

    inference_result = self.forward(
        observation_states=given_obs,
        num_particles=num_particles,
        resample=resample,
        return_inference_results=True)

    latent_states = inference_result.latent_states
    latent_state = latent_states[-1]

    for t in range(reconstruction_length, num_timesteps):
        previous_latent_state = latent_state

        # a) Prior: draw latent_state.z
        transition_state_random_variable = self.transition_network(
            previous_latent_state)
        latent_state = transition_state_random_variable.sample(
            batch_size, num_particles)

        batch_size, num_particles, z_dim = latent_state.z.size()

        # TODO: This is highly specific to the VRNN
        latent_state.phi_z = self.deterministic_transition_network.phi_z(
            latent_state.z.view(-1, z_dim)
        ).view(batch_size, num_particles, -1)

        latent_states.append(latent_state)

        # b) Generation: draw x and compute phi_x
        emission_state_random_variable = self.emission_network(
            previous_latent_state,
            latent_state)
        x = emission_state_random_variable.sample(
            batch_size, num_particles).all_x

        # This is usually done in the init-network
        # (num_timesteps = 1 for a single image)
        # TODO: This is highly specific to the VRNN
        phi_x = self.encoding_network.phi_x(
            x.view(-1, nr_channels, h, w)
        ).view(1, batch_size, num_particles, -1).contiguous()

        current_observation = st.State(
            all_phi_x=phi_x,
            x=x)  # Just in case; not used in the VRNN

        # Set time=0 because we have only the last observation
        if self.deterministic_transition_network is not None:
            latent_state = self.deterministic_transition_network(
                previous_latent_state=previous_latent_state,
                latent_state=latent_state,
                observation_states=current_observation,
                time=0)

    # At this point we have all the latent states.
    # Compute reconstructed/predicted observations from the given latents.
    averaged_obs = torch.zeros(num_timesteps, batch_size, nr_channels, h, w)
    all_obs = torch.zeros(
        num_timesteps, batch_size, num_particles, nr_channels, h, w)

    observation_states.unsequeeze_and_expand_all_(dim=2, size=num_particles)
    initial_state_random_variable = self.initial_network(observation_states)
    initial_state = initial_state_random_variable.sample(
        batch_size, num_particles)

    for t in range(num_timesteps):
        if t == 0:
            previous_latent_state = initial_state
        else:
            previous_latent_state = latent_states[t - 1]
        latent_state = latent_states[t]

        emission_state_random_variable = self.emission_network(
            previous_latent_state,
            latent_state)
        x = emission_state_random_variable.sample(
            batch_size, num_particles).all_x

        all_obs[t] = x.data
        averaged_obs[t] = summarize_function(
            x, inference_result.log_weight).data

    return all_obs, averaged_obs, inference_result.log_weight.data
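# Usage sketch (hypothetical; `model` and `obs_states` are assumptions).
# Reconstruct the first 15 of 20 frames and predict the remaining 5,
# summarizing the particle ensemble with the weighted empirical mean used
# elsewhere in this file:
#
#   all_obs, averaged_obs, log_weight = model.reconstruct_predict(
#       observation_states=obs_states,  # all_x: [20, batch_size, C, H, W]
#       num_particles=16,
#       resample=True,                  # True: SMC, False: IS
#       reconstruction_length=15,
#       prediction_length=5,
#       summarize_function=stats.empirical_mean)
#   # all_obs:      [20, batch_size, 16, C, H, W]
#   # averaged_obs: [20, batch_size, C, H, W]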