def dynamics_bms_step(self, x_t, geco, seq_step, global_step, z_seq, current_recurrent_state, actions):
    """One prediction step with best-of-many-samples (BMS) selection.

    Draws `self.stochastic_samples` latent samples from the dynamics prior,
    decodes all of them, and keeps — per batch element — the sample with the
    highest likelihood under the GMM image model.

    Args:
        x_t: target frame, [batch_size, C, H, W].
        geco: GECO constraint helper (provides `.constraint`).
        seq_step: index into the sequence (used to pick the action).
        global_step: global training step (GECO warm-start gate).
        z_seq: list of previous latents, each [batch_size * K, z_size];
            mutated in place (best z appended).
        current_recurrent_state: dict holding the dynamics hidden state under
            ['n_step_dynamics']['h']; mutated in place.
        actions: optional action sequence passed to `self.get_action`.

    Returns:
        (x_loc, masks, nll, z_seq, current_recurrent_state, loss, [var_z])
        where x_loc/masks are the best-sample reconstructions.
    """
    C, H, W = self.input_size[0], self.input_size[1], self.input_size[2]
    log_var = (2 * self.gmm_log_scale).to(x_t.device)
    a_prev = self.get_action(seq_step, actions)
    # Dynamics prior parameters; the input depends on the SSM variant.
    if self.relational_dynamics.ssm == 'Ours':
        lamda, h = self.relational_dynamics(torch.stack(z_seq), current_recurrent_state['n_step_dynamics']['h'], a_prev)
    elif self.relational_dynamics.ssm == 'SSM':
        lamda, h = self.relational_dynamics(z_seq[-1].unsqueeze(0), current_recurrent_state['n_step_dynamics']['h'], a_prev)
    elif self.relational_dynamics.ssm == 'RSSM':
        z_prev = z_seq[-1]  # [N*K, z_size]
        # lamda: [N*K, 2*z_size]
        lamda, h = self.relational_dynamics(z_prev, current_recurrent_state['n_step_dynamics']['h'], a_prev)
    current_recurrent_state['n_step_dynamics']['h'] = h.unsqueeze(0)
    loc_z, var_z = lamda.chunk(2, dim=1)
    loc_z, var_z = loc_z.contiguous(), var_z.contiguous()
    p_z = mvn(loc_z, var_z)
    # [S, N*K, z_size] -> flattened sample-major: row index = s*B*K + b*K + k
    z = p_z.rsample(torch.Size((self.stochastic_samples,)))
    z = z.view(self.stochastic_samples * self.batch_size * self.K, self.z_size)
    # Tile the target the same sample-major way so NLL rows line up.
    x_t = x_t.repeat(self.stochastic_samples, 1, 1, 1)
    # NOTE(review): unlike `dynamics_step`/`inference_step`, the RSSM variant
    # here decodes z without concatenating the deterministic state h —
    # confirm whether that is intentional for BMS.
    x_loc, mask_logits = self.image_decoder(z)  # [S*N*K, C, H, W]
    x_loc = x_loc.view(self.stochastic_samples * self.batch_size, self.K, C, H, W)
    mask_logits = mask_logits.view(self.stochastic_samples * self.batch_size, self.K, 1, H, W)
    mask_logprobs = nn.functional.log_softmax(mask_logits, dim=1).view(self.stochastic_samples * self.batch_size, self.K, 1, H, W)
    # NLL per (sample, batch) element.
    nll, _ = gmm_loglikelihood(x_t, x_loc, log_var, mask_logprobs)
    nll = nll.view(self.stochastic_samples, self.batch_size)
    # Best (lowest-NLL) sample index per batch element, in [0, S).
    best = torch.argmax(-nll, 0)
    # Batch-major flat index b*S + s; every tensor indexed with this must be
    # permuted to [B, S, ...] layout first.
    sample_idxs = torch.arange(self.batch_size).to(x_t.device) * self.stochastic_samples + best
    nll = nll.permute(1, 0).contiguous().view(-1)
    nll = nll[sample_idxs]
    # BUGFIX: z was flattened in sample-major order (row = s*B + b) but
    # indexed with the batch-major `sample_idxs` (b*S + s), selecting latents
    # that did not match the selected reconstructions. Permute to batch-major
    # first, exactly as done for x_loc and mask_logprobs below.
    z = z.view(self.stochastic_samples, self.batch_size, self.K, self.z_size)
    z = z.permute(1, 0, 2, 3).contiguous().view(-1, self.K, self.z_size)
    z = z[sample_idxs]
    z = z.view(self.batch_size * self.K, self.z_size)
    x_loc = x_loc.view(self.stochastic_samples, self.batch_size, self.K, C, H, W).permute(1, 0, 2, 3, 4, 5).contiguous()
    x_loc = x_loc.view(-1, self.K, C, H, W)
    x_loc = x_loc[sample_idxs]
    mask_logprobs = mask_logprobs.view(self.stochastic_samples, self.batch_size, self.K, 1, H, W).permute(1, 0, 2, 3, 4, 5).contiguous()
    mask_logprobs = mask_logprobs.view(-1, self.K, 1, H, W)
    mask_logprobs = mask_logprobs[sample_idxs]
    if self.geco_warm_start > global_step:
        # Plain NLL during warm start; GECO constraint afterwards.
        loss = torch.mean(nll)
    else:
        loss = -geco.constraint(self.geco_C_ema, self.geco_beta, torch.mean(nll))
    z_seq += [z]
    return x_loc, mask_logprobs.exp(), nll, z_seq, current_recurrent_state, loss, [var_z]
def rollout(self, latents, seq_step, current_recurrent_states, actions=None, x_t=None, compute_logprob=False):
    """Single open-loop prediction step for the relational model.

    `latents` is a list of previous latents, each [batch_size * K, z_size]
    (assumed to already carry the stochastic-sample dimension folded into the
    batch — TODO confirm against the caller). Samples the next latent from
    the dynamics prior, decodes it, and optionally scores it against `x_t`.

    Returns (means, masks, latents, current_recurrent_states, [var_z]) plus
    (nll, nll_discounted) inserted before [var_z] when compute_logprob=True.
    """
    a_prev = self.get_action(seq_step, actions)
    if a_prev is not None:
        # Replicate actions once per stochastic sample.
        a_prev = a_prev.repeat(self.stochastic_samples, 1)
    # Dynamics prior parameters; input depends on the SSM variant.
    if self.relational_dynamics.ssm == 'Ours':
        lamda, h = self.relational_dynamics(torch.stack(latents), current_recurrent_states['n_step_dynamics']['h'], a_prev)
    elif self.relational_dynamics.ssm == 'SSM':
        lamda, h = self.relational_dynamics(latents[-1].unsqueeze(0), current_recurrent_states['n_step_dynamics']['h'], a_prev)
    elif self.relational_dynamics.ssm == 'RSSM':
        latents_ = latents[-1]
        lamda, h = self.relational_dynamics(latents_, current_recurrent_states['n_step_dynamics']['h'], a_prev)
    h = h.unsqueeze(0)
    current_recurrent_states['n_step_dynamics']['h'] = h
    # TODO: this is actually _softplus_to_std(var_z) so change to softplus_std
    loc_z, var_z = lamda.chunk(2, dim=1)
    loc_z, var_z = loc_z.contiguous(), var_z.contiguous()
    p_z = mvn(loc_z, var_z)
    # One sample; the stochastic-sample dimension is already in the batch.
    z_joint = p_z.rsample(torch.Size((1,)))
    z_joint = z_joint.view(self.stochastic_samples * self.batch_size * self.K, self.z_size)
    #if (decode_last and t == seq_len-1) or not decode_last:
    if self.relational_dynamics.ssm == 'RSSM':
        # RSSM decodes from [stochastic latent ; deterministic state].
        h_ = h.view(-1, self.z_size)
        z_joint_ = torch.cat([z_joint, h_], 1)
        x_loc, mask_logits = self.image_decoder(z_joint_)
    elif self.relational_dynamics.ssm == 'Ours' or self.relational_dynamics.ssm == 'SSM':
        x_loc, mask_logits = self.image_decoder(z_joint)
    _, C, H, W = x_loc.shape
    x_loc = x_loc.view(self.stochastic_samples, self.batch_size, self.K, C, H, W)
    mask_logits = mask_logits.view(self.stochastic_samples * self.batch_size, self.K, 1, H, W)
    # Softmax over slots (dim=1 of the [S*B, K, 1, H, W] view).
    mask_logprobs = nn.functional.log_softmax(mask_logits, dim=1).view(self.stochastic_samples, self.batch_size, self.K, 1, H, W)
    # Returned batch-major: [B, S, K, C, H, W].
    means = [x_loc.permute(1, 0, 2, 3, 4, 5).contiguous()]
    masks = [torch.exp(mask_logprobs).permute(1, 0, 2, 3, 4, 5).contiguous()]
    # add latents to sequence
    latents += [z_joint]
    if compute_logprob:
        log_var = (2 * self.gmm_log_scale).to(x_t.device)
        x_loc = x_loc.view(self.stochastic_samples * self.batch_size, self.K, C, H, W)
        mask_logprobs = mask_logprobs.view(self.stochastic_samples * self.batch_size, self.K, 1, H, W)
        # NLL [samples * batch_size]
        # NOTE(review): assumes x_t already has the sample dimension folded
        # into its batch (shape [S*B, C, H, W]) — confirm against the caller.
        nll, _ = gmm_loglikelihood(x_t, x_loc, log_var, mask_logprobs)
        nll = nll.view(self.stochastic_samples, self.batch_size)
        # Down-weight later rollout steps by 1/(steps-into-rollout + 1).
        nll_discounted = nll * (1 / ((seq_step - self.context_len) + 1.))
        return means, masks, latents, current_recurrent_states, nll, nll_discounted, [var_z]
    return means, masks, latents, current_recurrent_states, [var_z]
def rollout(self, h, c, seq_len, actions=None):
    """Autoregressively generate frames from the prior for steps
    [context_len, seq_len).

    `h`, `c` are the LSTM state after the context, each [1, batch, hidden];
    `actions` (optional) is [batch?, seq_len, action_dim] and is tiled per
    stochastic sample on the first step. Returns (means, masks): lists of
    length seq_len - context_len.
    """
    means, masks = [], []
    n_samples = self.stochastic_samples

    def _tile_state(state):
        # [1, batch, hidden] -> [1, batch * n_samples, hidden], batch-major.
        tiled = state.unsqueeze(2).repeat(1, 1, n_samples, 1)
        return tiled.view(1, self.batch_size * n_samples, -1)

    for step in range(self.context_len, seq_len):
        if step == self.context_len:
            # First prediction step: replicate recurrent state (and actions)
            # once per stochastic sample.
            h = _tile_state(h)
            c = _tile_state(c)
            if actions is not None:
                actions = actions.unsqueeze(1).repeat(1, n_samples, 1, 1)
                actions = actions.view(self.batch_size * n_samples, -1, self.action_dim)
        action = self.get_action(step, actions)
        h_flat = h.squeeze(0)
        # Prior over z_t, optionally conditioned on the action.
        if action is not None:
            prior_feat = self.prior(torch.cat([action, h_flat], 1))
        else:
            prior_feat = self.prior(h_flat)
        prior_dist = mvn(self.p_loc(prior_feat), self.p_var(prior_feat))
        # Sample (no reparameterization needed at generation time).
        z_t = prior_dist.sample()
        phi_z = self.phi_z(z_t)
        # Decode the frame from (latent features, hidden state).
        x_means, _ = self.dec(torch.cat([phi_z, h_flat], 1))
        _, C, H, W = x_means.shape
        # Feed the model's own output back in (autoregressive generation).
        phi_x = self.phi_x(x_means)
        self.lstm.flatten_parameters()
        _, (h, c) = self.lstm(torch.cat([phi_x, phi_z], 1).unsqueeze(0), (h, c))
        means.append(x_means.view(self.batch_size, n_samples, C, H, W))
        # First channel stands in for a "mask" in this non-slotted model.
        masks.append(x_means[:, 0])
    return means, masks  # List of length seq_len - context_len of images and masks
def dynamics_step(self, x_t, geco, seq_step, global_step, z_seq, current_recurrent_state, actions):
    """Single-sample dynamics prediction step.

    Samples one latent from the dynamics prior, decodes it into per-slot
    means and masks, and scores the reconstruction of `x_t` under the GMM
    likelihood. Appends the sampled latent to `z_seq` (mutated in place) and
    updates the dynamics hidden state in `current_recurrent_state`.

    Returns (x_loc, masks, nll, z_seq, current_recurrent_state, loss, [var_z]).
    """
    C, H, W = self.input_size[0], self.input_size[1], self.input_size[2]
    log_var = (2 * self.gmm_log_scale).to(x_t.device)
    action_prev = self.get_action(seq_step, actions)
    ssm_kind = self.relational_dynamics.ssm
    hidden = current_recurrent_state['n_step_dynamics']['h']
    # The dynamics input depends on the SSM variant: full history, last
    # latent with a time axis, or last latent alone.
    if ssm_kind == 'Ours':
        dyn_in = torch.stack(z_seq)
    elif ssm_kind == 'SSM':
        dyn_in = z_seq[-1].unsqueeze(0)
    elif ssm_kind == 'RSSM':
        dyn_in = z_seq[-1]  # [N*K, z_size]
    # lamda: [N*K, 2*z_size]
    lamda, h_new = self.relational_dynamics(dyn_in, hidden, action_prev)
    current_recurrent_state['n_step_dynamics']['h'] = h_new.unsqueeze(0)
    loc_z, var_z = lamda.chunk(2, dim=1)
    loc_z = loc_z.contiguous()
    var_z = var_z.contiguous()
    prior_dist = mvn(loc_z, var_z)
    z = prior_dist.rsample(torch.Size((1,))).view(self.batch_size * self.K, self.z_size)
    # RSSM decodes from [stochastic latent ; deterministic state].
    if ssm_kind == 'RSSM':
        det_state = current_recurrent_state['n_step_dynamics']['h'].view(-1, self.z_size)
        decoder_in = torch.cat([z, det_state], 1)
    else:
        decoder_in = z
    x_loc, mask_logits = self.image_decoder(decoder_in)  # [N*K, C, H, W]
    x_loc = x_loc.view(self.batch_size, self.K, C, H, W)
    mask_logits = mask_logits.view(self.batch_size, self.K, 1, H, W)
    # Softmax over the K slots.
    mask_logprobs = nn.functional.log_softmax(mask_logits, dim=1).view(self.batch_size, self.K, 1, H, W)
    # NLL per batch element under the slot-mixture likelihood.
    nll, _ = gmm_loglikelihood(x_t, x_loc, log_var, mask_logprobs)
    if self.geco_warm_start > global_step:
        # Plain NLL during warm start; GECO constraint afterwards.
        loss = torch.mean(nll)
    else:
        loss = -geco.constraint(self.geco_C_ema, self.geco_beta, torch.mean(nll))
    z_seq.append(z)
    return x_loc, mask_logprobs.exp(), nll, z_seq, current_recurrent_state, loss, [var_z]
def inference_step(self, x_t, geco, seq_step, global_step, posterior_zs, lambdas, current_recurrent_states, actions):
    """Iterative amortized inference for one time step.

    Runs `num_iters` refinement iterations: decode from the current posterior
    parameters `lamda_0`, compute NLL + KL, and update `lamda_0` with the
    refinement network. At seq_step == 0 the KL is taken against a standard
    normal prior; afterwards against the learned dynamics prior.

    `posterior_zs` and `lambdas` are mutated in place (final z / lamda_0
    appended). Returns the last iteration's reconstruction, masks, nll, mean
    KL, the updated sequences/state, the final iteration index, and the
    collected dynamics variances.
    """
    total_loss = 0.
    C, H, W = self.input_size[0], self.input_size[1], self.input_size[2]
    log_var = (2 * self.gmm_log_scale).to(x_t.device)
    dynamics_dist = []
    # Per-step refinement-iteration count (single shared entry or per step).
    if len(self.iterative_inference_schedule) == 1:
        num_iters = self.iterative_inference_schedule[0]
    else:
        num_iters = self.iterative_inference_schedule[seq_step]
    if seq_step == 0:
        assert not torch.isnan(self.lamda_0).any(), 'lambda_0 has nan'
        # expand lambda_0
        lamda_0 = self.lamda_0.repeat(self.batch_size * self.K, 1)  # [N*K, 2*z_size]
        deterministic_state = current_recurrent_states['n_step_dynamics']['h']
        prior_z = std_mvn(shape=[self.batch_size * self.K, self.z_size], device=x_t.device)
    else:
        # NOTE(review): prior_z is rebuilt here but only used in the
        # seq_step == 0 KL branch below, so this assignment looks unused.
        prior_z = std_mvn(shape=[self.batch_size * self.K, self.z_size], device=x_t.device)
        a_prev = self.get_action(seq_step, actions)
        # Dynamics prior from the previous posterior latents; input depends
        # on the SSM variant.
        if self.relational_dynamics.ssm == 'Ours':
            lamda_dynamics, q_h_dyn = self.relational_dynamics(torch.stack(posterior_zs), current_recurrent_states['n_step_dynamics']['h'], a_prev)
        elif self.relational_dynamics.ssm == 'SSM':
            lamda_dynamics, q_h_dyn = self.relational_dynamics(posterior_zs[-1].unsqueeze(0), current_recurrent_states['n_step_dynamics']['h'], a_prev)
        elif self.relational_dynamics.ssm == 'RSSM':
            z_prev = posterior_zs[-1]
            lamda_dynamics, q_h_dyn = self.relational_dynamics(z_prev, current_recurrent_states['n_step_dynamics']['h'], a_prev)
        # for next time step
        current_recurrent_states['n_step_dynamics']['h'] = q_h_dyn.unsqueeze(0)
        loc_z, var_z = lamda_dynamics.chunk(2, dim=1)
        loc_z, var_z = loc_z.contiguous(), var_z.contiguous()
        dynamics_prior_z = mvn(loc_z, var_z)
        dynamics_dist += [var_z.detach()]
        deterministic_state = current_recurrent_states['n_step_dynamics']['h']
        if self.separate_variances:
            lamda_0 = self.lamda_0.repeat(self.batch_size * self.K, 1)  # [N*K, 2*z_size]
            loc_z_, var_z_ = lamda_0.chunk(2, dim=1)
            loc_z_, var_z_ = loc_z_.contiguous(), var_z_.contiguous()
            # use the learned var shared across timesteps
            # and loc_z from dynamics
            dynamics_prior_z = mvn(loc_z, var_z_)
            # update lamda_dynamics
            lamda_0 = torch.cat([loc_z, var_z_], 1)
        else:
            lamda_0 = lamda_dynamics
    # Refinement LSTM state.
    h = current_recurrent_states['inference_lambda']['h']
    c = current_recurrent_states['inference_lambda']['c']
    for i in range(num_iters):
        loc_z, var_z = lamda_0.chunk(2, dim=1)
        loc_z, var_z = loc_z.contiguous(), var_z.contiguous()
        posterior_z = mvn(loc_z, var_z)
        # NOTE(review): detached copy is built but never used below.
        detached_posterior_z = mvn(loc_z.detach(), var_z.detach())
        z = posterior_z.rsample()
        # Get means and masks based on SSM. RSSM adds the deterministic path
        # from latent state to observation here.
        if self.relational_dynamics.ssm == 'RSSM':
            z_ = torch.cat([z, deterministic_state.view(-1, self.z_size)], 1)
            x_loc, mask_logits = self.image_decoder(z_)  # [N*K, C, H, W]
        elif self.relational_dynamics.ssm == 'Ours' or self.relational_dynamics.ssm == 'SSM':
            x_loc, mask_logits = self.image_decoder(z)  # [N*K, C, H, W]
        x_loc = x_loc.view(self.batch_size, self.K, C, H, W)
        # softmax across slots
        mask_logits = mask_logits.view(self.batch_size, self.K, 1, H, W)
        mask_logprobs = nn.functional.log_softmax(mask_logits, dim=1)
        # NLL [batch_size, 1, H, W]
        nll, ll_outs = gmm_loglikelihood(x_t, x_loc, log_var, mask_logprobs)
        # KL div: against the unit prior at t=0, else the dynamics prior.
        if seq_step == 0:
            kl_div = torch.distributions.kl.kl_divergence(posterior_z, prior_z)
            kl_div = kl_div.view(self.batch_size, self.K).sum(1)
            refine_foreground_only = False
        else:
            kl_div = torch.distributions.kl.kl_divergence(posterior_z, dynamics_prior_z)
            kl_div = kl_div.view(self.batch_size, self.K).sum(1)
            refine_foreground_only = False
        if self.geco_warm_start > global_step:
            # Plain ELBO during warm start; GECO constraint afterwards.
            loss = torch.mean(nll + self.kl_beta * kl_div)
        else:
            loss = torch.mean(self.kl_beta * kl_div) - geco.constraint(self.geco_C_ema, self.geco_beta, torch.mean(nll))
        # Linearly up-weight later refinement iterations.
        scaled_loss = ((i + 1.) / num_iters) * loss
        total_loss += scaled_loss
        # Refinement
        if i == num_iters - 1:
            # after T refinement steps, just output final loss
            #z_seq += [z]
            #break
            continue
        # compute refine inputs
        # NOTE(review): repeat(K, ...) produces slot-major rows (k*B + b) but
        # the view is [batch, K, ...] — confirm this layout is intended.
        x_ = x_t.repeat(self.K, 1, 1, 1).view(self.batch_size, self.K, C, H, W)
        img_inps, vec_inps = refinenet_sequential_inputs(x_, x_loc, mask_logprobs, mask_logits, ll_outs['log_p_k'], ll_outs['normal_ll'], lamda_0, loss, self.layer_norms, not self.training)
        delta, (h, c) = self.refine_net(img_inps, vec_inps, h, c)
        lamda_0 = lamda_0 + delta
    posterior_zs += [z]
    lambdas += [lamda_0]
    return x_loc, mask_logprobs.exp(), nll, torch.mean(kl_div), posterior_zs, lambdas, total_loss, current_recurrent_states, i, dynamics_dist
def forward(self, x, actions, geco, step):
    """VRNN forward pass.

    Training: filter the whole sequence (context_len = T) with the
    posterior and accumulate NLL + KL under the GECO objective.
    Evaluation: filter only the first `self.context_len` frames, then
    roll out the remainder from the prior and return predictions only.

    Args:
        x: frames, [batch, T, C, H, W] (T taken from x.shape[1]).
        actions: optional action sequence for `self.get_action`.
        geco: GECO constraint helper.
        step: global training step (warm-start gate).

    Returns:
        dict with 'x_means' and 'masks' (eval), plus 'total_loss', 'nll',
        'kl', 'inference_steps' (training).
    """
    C, H, W = self.input_size[0], self.input_size[1], self.input_size[2]
    T = x.shape[1]
    total_loss = 0.
    x_means_t = []
    masks_t = []  # per-step pseudo-masks (first channel of the means)
    nll_t = []
    kl_t = []
    # Broadcast the learned initial LSTM state across the batch.
    h, c = self.h_0, self.c_0
    h = h.to(x.device).repeat(1, self.batch_size, 1)
    c = c.to(x.device).repeat(1, self.batch_size, 1)
    log_var = (2 * self.log_scale).to(x.device)
    if self.training:
        context_len = T
    else:
        context_len = self.context_len
    # Filtering phase: condition on ground-truth frames.
    for t in range(context_len):
        x_t = x[:, t]
        a_t = self.get_action(t, actions)
        phi_x_t = self.phi_x(x_t)
        # encoder (posterior q(z_t | x_t, h))
        if a_t is not None:
            enc_t = self.encoder(torch.cat([a_t, phi_x_t, h.squeeze(0)], 1))
        else:
            enc_t = self.encoder(torch.cat([phi_x_t, h.squeeze(0)], 1))
        q_t_loc = self.q_loc(enc_t)
        q_t_var = self.q_var(enc_t)
        q_t = mvn(q_t_loc, q_t_var)
        # prior p(z_t | h)
        # N.b. SVG-LP provides phi_x_t-1 to the prior from ground truth x_t-1
        # VRNN passes phi_x_t-1 into the LSTM which gets processed by prior
        # at time t "h", which gives p_t. Same!
        if a_t is not None:
            p_t = self.prior(torch.cat([a_t, h.squeeze(0)], 1))
        else:
            p_t = self.prior(h.squeeze(0))
        p_t_loc = self.p_loc(p_t)
        p_t_var = self.p_var(p_t)
        p_t = mvn(p_t_loc, p_t_var)
        # sample from the posterior (reparameterized)
        z_t = q_t.rsample()
        phi_z_t = self.phi_z(z_t)
        # decode
        x_means, _ = self.dec(torch.cat([phi_z_t, h.squeeze(0)], 1))
        # recurrence
        self.lstm.flatten_parameters()
        out, (h, c) = self.lstm(torch.cat([phi_x_t, phi_z_t], 1).unsqueeze(0), (h, c))
        # Loss
        # KL
        kl_div = torch.distributions.kl.kl_divergence(q_t, p_t)  # [batch_size]
        # NLL
        nll = gaussian_loglikelihood(x_t, x_means, log_var)
        if self.geco_warm_start > step:
            loss = torch.mean(nll + self.kl_beta * kl_div)
        else:
            loss = torch.mean(self.kl_beta * kl_div) - \
                geco.constraint(self.geco_C_ema, self.geco_beta, torch.mean(nll))
        total_loss += loss
        x_means_t += [x_means]
        masks_t += [x_means[:, 0]]
        nll_t += [torch.mean(nll)]
        kl_t += [torch.mean(kl_div)]
    if not self.training:
        # Evaluation: predict the rest of the sequence from the prior and
        # return early (losses are not reported in eval).
        with torch.no_grad():
            pred_x_means, pred_x_masks = self.rollout(h, c, T, actions)
        # Tile the context-phase outputs per stochastic sample so they
        # concatenate with the rollout outputs.
        x_means_t = [_.unsqueeze(1).repeat(1, self.stochastic_samples, 1, 1, 1) for _ in x_means_t]
        masks_t = [_.unsqueeze(1).repeat(1, self.stochastic_samples, 1, 1, 1) for _ in masks_t]
        x_means_t = x_means_t + pred_x_means
        masks_t = masks_t + pred_x_masks
        return {
            'x_means': x_means_t,
            'masks': masks_t
        }
    # Generation phase with teacher-free recurrence.
    # NOTE(review): under the control flow above this loop never runs —
    # training sets context_len = T and eval returns early; confirm whether
    # this is intentionally dead or context_len was meant to differ.
    for t in range(context_len, T):
        x_t = x[:, t]
        a_t = self.get_action(t, actions)
        # prior
        if a_t is not None:
            p_t = self.prior(torch.cat([a_t, h.squeeze(0)], 1))
        else:
            p_t = self.prior(torch.cat([h.squeeze(0)], 1))
        p_t_loc = self.p_loc(p_t)
        p_t_var = self.p_var(p_t)
        p_t = mvn(p_t_loc, p_t_var)
        # sample
        z_t = p_t.rsample()
        phi_z_t = self.phi_z(z_t)
        # decode
        x_means, _ = self.dec(torch.cat([phi_z_t, h.squeeze(0)], 1))
        # autoregressively encode own output
        phi_x_t = self.phi_x(x_means)
        # recurrence
        self.lstm.flatten_parameters()
        out, (h, c) = self.lstm(torch.cat([phi_x_t, phi_z_t], 1).unsqueeze(0), (h, c))
        # Loss
        # NLL
        nll = gaussian_loglikelihood(x_t, x_means, log_var)
        if self.geco_warm_start > step:
            loss = torch.mean(nll)
        else:
            loss = -geco.constraint(self.geco_C_ema, self.geco_beta, torch.mean(nll))
        total_loss += loss
        x_means_t += [x_means]
        masks_t += [x_means[:, 0]]
        nll_t += [torch.mean(nll)]
    return {
        'total_loss': total_loss,
        'nll': torch.sum(torch.stack(nll_t)),
        'kl': torch.sum(torch.stack(kl_t)),
        'x_means': x_means_t,
        'masks': masks_t,
        'inference_steps': torch.zeros(1).to(x.device)
    }