Ejemplo n.º 1
0
    def dynamics_bms_step(self, x_t, geco, seq_step, global_step, z_seq, current_recurrent_state, actions):
        """One best-of-many-samples (BMS) dynamics step.

        Predicts the next latent distribution with the dynamics model, draws
        `self.stochastic_samples` latent samples, decodes each into image
        means and masks, and keeps — per batch element — the sample with the
        lowest NLL against the observed frame `x_t`.

        Args:
            x_t: observed frame; assumed [batch_size, C, H, W] — TODO confirm.
            geco: GECO constraint object, used after the warm-start phase.
            seq_step: index into the sequence (selects the action).
            global_step: global training step (GECO warm-start gate).
            z_seq: list of past latents, each [batch_size * K, z_size];
                the selected sample is appended in place.
            current_recurrent_state: dict holding the dynamics hidden state
                under ['n_step_dynamics']['h']; updated in place.
            actions: optional action sequence consumed via `get_action`.

        Returns:
            (x_loc, masks, nll, z_seq, current_recurrent_state, loss, [var_z])
        """
        C, H, W = self.input_size[0], self.input_size[1], self.input_size[2]
        log_var = (2 * self.gmm_log_scale).to(x_t.device)

        a_prev = self.get_action(seq_step, actions)

        # Predict next-step latent distribution parameters with the chosen SSM.
        if self.relational_dynamics.ssm == 'Ours':
            lamda, h = self.relational_dynamics(torch.stack(z_seq), current_recurrent_state['n_step_dynamics']['h'], a_prev)
        elif self.relational_dynamics.ssm == 'SSM':
            lamda, h = self.relational_dynamics(z_seq[-1].unsqueeze(0), current_recurrent_state['n_step_dynamics']['h'], a_prev)
        elif self.relational_dynamics.ssm == 'RSSM':
            z_prev = z_seq[-1]  # [N*K, z_size]
            # [N*K, 2*z_size]
            lamda, h = self.relational_dynamics(z_prev, current_recurrent_state['n_step_dynamics']['h'], a_prev)
        else:
            raise ValueError('Unknown SSM type: {}'.format(self.relational_dynamics.ssm))
        current_recurrent_state['n_step_dynamics']['h'] = h.unsqueeze(0)

        loc_z, var_z = lamda.chunk(2, dim=1)
        loc_z, var_z = loc_z.contiguous(), var_z.contiguous()

        p_z = mvn(loc_z, var_z)

        # [S, B*K, z_size] -> [S*B*K, z_size]; rows are sample-major.
        z = p_z.rsample(torch.Size((self.stochastic_samples,)))
        z = z.view(self.stochastic_samples * self.batch_size * self.K, self.z_size)

        x_t = x_t.repeat(self.stochastic_samples, 1, 1, 1)
        x_loc, mask_logits = self.image_decoder(z)  #[S*B*K, C, H, W]
        x_loc = x_loc.view(self.stochastic_samples * self.batch_size, self.K, C, H, W)
        mask_logits = mask_logits.view(self.stochastic_samples * self.batch_size, self.K, 1, H, W)
        mask_logprobs = nn.functional.log_softmax(mask_logits, dim=1).view(self.stochastic_samples * self.batch_size, self.K, 1, H, W)

        # NLL [stochastic_samples * batch_size]
        nll, _ = gmm_loglikelihood(x_t, x_loc, log_var, mask_logprobs)
        nll = nll.view(self.stochastic_samples, self.batch_size)
        # Best sample per batch element; sample_idxs indexes a *batch-major*
        # [B*S, ...] flattening (row b*S + best[b]).
        best = torch.argmax(-nll, 0)
        sample_idxs = torch.arange(self.batch_size).to(x_t.device) * self.stochastic_samples + best

        nll = nll.permute(1,0).contiguous().view(-1)
        nll = nll[sample_idxs]

        # BUG FIX: z was previously indexed in its sample-major layout with the
        # batch-major sample_idxs, selecting the wrong samples whenever
        # stochastic_samples > 1. Permute to batch-major first, consistent with
        # the handling of x_loc and mask_logprobs below.
        z = z.view(self.stochastic_samples, self.batch_size, self.K, self.z_size).permute(1,0,2,3).contiguous()
        z = z.view(-1, self.K, self.z_size)
        z = z[sample_idxs]
        z = z.view(self.batch_size * self.K, self.z_size)

        # Reorder decoded means/masks batch-major and keep only the best samples.
        x_loc = x_loc.view(self.stochastic_samples, self.batch_size, self.K, C, H, W).permute(1,0,2,3,4,5).contiguous()
        x_loc = x_loc.view(-1, self.K, C, H, W)
        x_loc = x_loc[sample_idxs]

        mask_logprobs = mask_logprobs.view(self.stochastic_samples, self.batch_size, self.K, 1, H, W).permute(1,0,2,3,4,5).contiguous()
        mask_logprobs = mask_logprobs.view(-1, self.K, 1, H, W)
        mask_logprobs = mask_logprobs[sample_idxs]

        if self.geco_warm_start > global_step:
            # Plain NLL during warm start; GECO constraint afterwards.
            loss = torch.mean(nll)
        else:
            loss = -geco.constraint(self.geco_C_ema, self.geco_beta, torch.mean(nll))
        z_seq += [z]

        return x_loc, mask_logprobs.exp(), nll, z_seq, current_recurrent_state, loss, [var_z]
Ejemplo n.º 2
0
    def rollout(self, latents, seq_step, current_recurrent_states, actions=None, x_t=None, compute_logprob=False):
        """Roll the dynamics model one step forward and decode the prediction.

        `latents` is a list of Tensors of shape
        [stochastic_samples * batch_size * K, z_size]; the newly sampled
        latent is appended to it in place. The recurrent state dict is also
        updated in place. When `compute_logprob` is set, the decoded frame is
        additionally scored against `x_t` and (discounted) NLLs are returned.
        """
        action = self.get_action(seq_step, actions)
        if action is not None:
            # Tile the action across the stochastic samples.
            action = action.repeat(self.stochastic_samples, 1)

        ssm = self.relational_dynamics.ssm
        hidden = current_recurrent_states['n_step_dynamics']['h']
        if ssm == 'Ours':
            params, hidden = self.relational_dynamics(torch.stack(latents), hidden, action)
        elif ssm == 'SSM':
            params, hidden = self.relational_dynamics(latents[-1].unsqueeze(0), hidden, action)
        elif ssm == 'RSSM':
            params, hidden = self.relational_dynamics(latents[-1], hidden, action)
        hidden = hidden.unsqueeze(0)
        current_recurrent_states['n_step_dynamics']['h'] = hidden

        # TODO: this is actually _softplus_to_std(var) so change to softplus_std
        mu, sigma = params.chunk(2, dim=1)
        mu, sigma = mu.contiguous(), sigma.contiguous()
        prior = mvn(mu, sigma)
        sample = prior.rsample(torch.Size((1,)))
        sample = sample.view(self.stochastic_samples * self.batch_size * self.K, self.z_size)

        # RSSM adds the deterministic recurrent path to the decoder input.
        if ssm == 'RSSM':
            x_loc, mask_logits = self.image_decoder(torch.cat([sample, hidden.view(-1, self.z_size)], 1))
        elif ssm == 'Ours' or ssm == 'SSM':
            x_loc, mask_logits = self.image_decoder(sample)

        _, C, H, W = x_loc.shape
        S, B, K = self.stochastic_samples, self.batch_size, self.K
        x_loc = x_loc.view(S, B, K, C, H, W)
        mask_logits = mask_logits.view(S * B, K, 1, H, W)
        mask_logprobs = nn.functional.log_softmax(mask_logits, dim=1).view(S, B, K, 1, H, W)
        means = [x_loc.permute(1, 0, 2, 3, 4, 5).contiguous()]
        masks = [mask_logprobs.exp().permute(1, 0, 2, 3, 4, 5).contiguous()]
        # Append the new latent to the running sequence (in place).
        latents.append(sample)

        if compute_logprob:
            log_var = (2 * self.gmm_log_scale).to(x_t.device)
            # NLL [samples * batch_size]
            nll, _ = gmm_loglikelihood(
                x_t,
                x_loc.view(S * B, K, C, H, W),
                log_var,
                mask_logprobs.view(S * B, K, 1, H, W))
            nll = nll.view(S, B)
            # Discount later prediction steps.
            nll_discounted = nll * (1 / ((seq_step - self.context_len) + 1.))
            return means, masks, latents, current_recurrent_states, nll, nll_discounted, [sigma]

        return means, masks, latents, current_recurrent_states, [sigma]
Ejemplo n.º 3
0
    def rollout(self, h, c, seq_len, actions=None):
        """Open-loop autoregressive rollout from LSTM state (h, c).

        Rolls out and decodes frames for steps context_len..seq_len-1,
        sampling latents from the learned prior and feeding the decoded
        frame back into the recurrence. `actions` (if given) is
        [seq_len, action_dim].

        Returns lists (length seq_len - context_len) of decoded means
        and channel-0 "masks".
        """
        S = self.stochastic_samples
        means = []
        masks = []
        for t in range(self.context_len, seq_len):
            if t == self.context_len:
                # On the first step, tile state (and actions) across samples:
                # [1, B, hid] -> [1, B * S, hid], batch-major.
                h = h.unsqueeze(2).repeat(1, 1, S, 1).view(1, self.batch_size * S, -1)
                c = c.unsqueeze(2).repeat(1, 1, S, 1).view(1, self.batch_size * S, -1)
                if actions is not None:
                    actions = actions.unsqueeze(1).repeat(1, S, 1, 1)
                    actions = actions.view(self.batch_size * S, -1, self.action_dim)  # [B*S, seq_len, action_dim]

            a_t = self.get_action(t, actions)
            h_flat = h.squeeze(0)

            # Prior p(z_t | h [, a_t]).
            prior_in = h_flat if a_t is None else torch.cat([a_t, h_flat], 1)
            p_t = self.prior(prior_in)
            p_t = mvn(self.p_loc(p_t), self.p_var(p_t))

            # Sample and decode.
            z_t = p_t.sample()
            phi_z_t = self.phi_z(z_t)
            x_means, _ = self.dec(torch.cat([phi_z_t, h_flat], 1))
            _, C, H, W = x_means.shape

            # Autoregressively re-encode our own output.
            phi_x_t = self.phi_x(x_means)

            # Recurrence.
            self.lstm.flatten_parameters()
            _, (h, c) = self.lstm(torch.cat([phi_x_t, phi_z_t], 1).unsqueeze(0), (h, c))

            means.append(x_means.view(self.batch_size, S, C, H, W))
            masks.append(x_means[:, 0])

        return means, masks  # List of length seq_len - context_len of images and masks
Ejemplo n.º 4
0
    def dynamics_step(self, x_t, geco, seq_step, global_step, z_seq, current_recurrent_state, actions):
        """Single-sample dynamics step: predict, sample, decode, and score.

        Advances the dynamics model one step, draws one latent sample,
        decodes means and masks, and computes the NLL-based loss against
        `x_t`. `z_seq` (list of [batch_size * K, z_size] latents) and
        `current_recurrent_state` are updated in place.
        """
        C, H, W = self.input_size[:3]
        log_var = (2 * self.gmm_log_scale).to(x_t.device)

        a_prev = self.get_action(seq_step, actions)

        ssm = self.relational_dynamics.ssm
        state = current_recurrent_state['n_step_dynamics']['h']
        if ssm == 'Ours':
            lamda, h = self.relational_dynamics(torch.stack(z_seq), state, a_prev)
        elif ssm == 'SSM':
            lamda, h = self.relational_dynamics(z_seq[-1].unsqueeze(0), state, a_prev)
        elif ssm == 'RSSM':
            # RSSM conditions only on the most recent latent [N*K, z_size].
            lamda, h = self.relational_dynamics(z_seq[-1], state, a_prev)
        current_recurrent_state['n_step_dynamics']['h'] = h.unsqueeze(0)

        mu, sigma = lamda.chunk(2, dim=1)
        mu, sigma = mu.contiguous(), sigma.contiguous()
        z = mvn(mu, sigma).rsample(torch.Size((1,)))
        z = z.view(self.batch_size * self.K, self.z_size)

        # RSSM adds the deterministic recurrent path to the decoder input.
        if ssm == 'RSSM':
            decoder_in = torch.cat([z, current_recurrent_state['n_step_dynamics']['h'].view(-1, self.z_size)], 1)
        elif ssm == 'Ours' or ssm == 'SSM':
            decoder_in = z

        x_loc, mask_logits = self.image_decoder(decoder_in)  #[N*K, C, H, W]
        x_loc = x_loc.view(self.batch_size, self.K, C, H, W)
        # Softmax over the K slots.
        mask_logprobs = nn.functional.log_softmax(
            mask_logits.view(self.batch_size, self.K, 1, H, W), dim=1)

        # NLL [batch_size, 1, H, W]
        nll, _ = gmm_loglikelihood(x_t, x_loc, log_var, mask_logprobs)
        if self.geco_warm_start > global_step:
            # Plain NLL during warm start; GECO constraint afterwards.
            loss = torch.mean(nll)
        else:
            loss = -geco.constraint(self.geco_C_ema, self.geco_beta, torch.mean(nll))

        z_seq.append(z)

        return x_loc, mask_logprobs.exp(), nll, z_seq, current_recurrent_state, loss, [sigma]
Ejemplo n.º 5
0
    def inference_step(self, x_t, geco, seq_step, global_step, posterior_zs, lambdas, current_recurrent_states, actions):
        """Iterative amortized inference for one timestep.

        Runs `num_iters` refinement iterations: decode the current posterior
        estimate, score it (NLL + KL against the prior), and refine the
        posterior parameters `lamda_0` with the refinement network. The prior
        is a standard normal at the first timestep and the dynamics-model
        prediction afterwards.

        Side effects: `posterior_zs`, `lambdas`, and
        `current_recurrent_states` are updated in place. Assumes
        num_iters >= 1 — TODO confirm the schedule never yields 0.

        Returns:
            (x_loc, masks, nll, mean KL, posterior_zs, lambdas, total_loss,
             current_recurrent_states, last refinement index, dynamics_dist)
        """
        total_loss = 0.
        C, H, W = self.input_size[0], self.input_size[1], self.input_size[2]
        log_var = (2 * self.gmm_log_scale).to(x_t.device)
        dynamics_dist = []

        # Number of refinement iterations for this timestep.
        if len(self.iterative_inference_schedule) == 1:
            num_iters = self.iterative_inference_schedule[0]
        else:
            num_iters = self.iterative_inference_schedule[seq_step]

        if seq_step == 0:
            assert not torch.isnan(self.lamda_0).any(), 'lambda_0 has nan'
            # Expand the learned initial lambda_0 across batch and slots.
            lamda_0 = self.lamda_0.repeat(self.batch_size*self.K,1) # [N*K, 2*z_size]
            deterministic_state = current_recurrent_states['n_step_dynamics']['h']
            # No dynamics yet at the first step: standard-normal prior.
            prior_z = std_mvn(shape=[self.batch_size * self.K, self.z_size], device=x_t.device)
        else:
            a_prev = self.get_action(seq_step, actions)

            # Roll the dynamics forward from the posterior history to obtain
            # this step's prior.
            if self.relational_dynamics.ssm == 'Ours':
                lamda_dynamics, q_h_dyn = self.relational_dynamics(torch.stack(posterior_zs), current_recurrent_states['n_step_dynamics']['h'], a_prev)
            elif self.relational_dynamics.ssm == 'SSM':
                lamda_dynamics, q_h_dyn = self.relational_dynamics(posterior_zs[-1].unsqueeze(0), current_recurrent_states['n_step_dynamics']['h'], a_prev)
            elif self.relational_dynamics.ssm == 'RSSM':
                z_prev = posterior_zs[-1]
                lamda_dynamics, q_h_dyn = self.relational_dynamics(z_prev, current_recurrent_states['n_step_dynamics']['h'], a_prev)
            # For the next time step.
            current_recurrent_states['n_step_dynamics']['h'] = q_h_dyn.unsqueeze(0)
            loc_z, var_z = lamda_dynamics.chunk(2, dim=1)
            loc_z, var_z = loc_z.contiguous(), var_z.contiguous()

            dynamics_prior_z = mvn(loc_z, var_z)
            dynamics_dist += [var_z.detach()]

            deterministic_state = current_recurrent_states['n_step_dynamics']['h']

            if self.separate_variances:
                lamda_0 = self.lamda_0.repeat(self.batch_size*self.K,1) # [N*K, 2*z_size]
                loc_z_, var_z_ = lamda_0.chunk(2, dim=1)
                loc_z_, var_z_ = loc_z_.contiguous(), var_z_.contiguous()

                # Use the learned variance shared across timesteps together
                # with the location predicted by the dynamics.
                dynamics_prior_z = mvn(loc_z, var_z_)
                lamda_0 = torch.cat([loc_z, var_z_],1)
            else:
                lamda_0 = lamda_dynamics

        h = current_recurrent_states['inference_lambda']['h']
        c = current_recurrent_states['inference_lambda']['c']

        for i in range(num_iters):
            loc_z, var_z = lamda_0.chunk(2, dim=1)
            loc_z, var_z = loc_z.contiguous(), var_z.contiguous()
            posterior_z = mvn(loc_z, var_z)
            z = posterior_z.rsample()

            # Decode. RSSM adds the deterministic path from the latent state
            # to the observation here.
            if self.relational_dynamics.ssm == 'RSSM':
                z_ = torch.cat([z, deterministic_state.view(-1, self.z_size)], 1)
                x_loc, mask_logits = self.image_decoder(z_)  #[N*K, C, H, W]
            elif self.relational_dynamics.ssm == 'Ours' or self.relational_dynamics.ssm == 'SSM':
                x_loc, mask_logits = self.image_decoder(z)  #[N*K, C, H, W]

            x_loc = x_loc.view(self.batch_size, self.K, C, H, W)

            # Softmax across slots.
            mask_logits = mask_logits.view(self.batch_size, self.K, 1, H, W)
            mask_logprobs = nn.functional.log_softmax(mask_logits, dim=1)

            # NLL [batch_size, 1, H, W]
            nll, ll_outs = gmm_loglikelihood(x_t, x_loc, log_var, mask_logprobs)

            # KL against the standard-normal prior at t=0, the dynamics prior
            # afterwards. (Removed unused detached_posterior_z and
            # refine_foreground_only dead code from the original.)
            prior = prior_z if seq_step == 0 else dynamics_prior_z
            kl_div = torch.distributions.kl.kl_divergence(posterior_z, prior)
            kl_div = kl_div.view(self.batch_size, self.K).sum(1)

            if self.geco_warm_start > global_step:
                loss = torch.mean(nll + self.kl_beta * kl_div)
            else:
                loss = torch.mean(self.kl_beta * kl_div) - geco.constraint(self.geco_C_ema, self.geco_beta, torch.mean(nll))
            # Later refinement iterations are weighted more heavily.
            scaled_loss = ((i+1.) / num_iters) * loss
            total_loss += scaled_loss

            if i == num_iters-1:
                # After the final refinement step, just output the final loss.
                break

            # Compute refinement-network inputs.
            x_ = x_t.repeat(self.K, 1, 1, 1).view(self.batch_size, self.K, C, H, W)

            img_inps, vec_inps = refinenet_sequential_inputs(x_, x_loc, mask_logprobs,
                    mask_logits, ll_outs['log_p_k'], ll_outs['normal_ll'], lamda_0, loss, self.layer_norms,
                    not self.training)

            delta, (h,c) = self.refine_net(img_inps, vec_inps, h, c)
            lamda_0 = lamda_0 + delta

        posterior_zs += [z]
        lambdas += [lamda_0]

        return x_loc, mask_logprobs.exp(), nll, torch.mean(kl_div), posterior_zs, lambdas, total_loss, current_recurrent_states, i, dynamics_dist
Ejemplo n.º 6
0
    def forward(self, x, actions, geco, step):
        """VRNN-style forward pass over an image sequence.

        During training, runs posterior inference (encoder) on every frame.
        During evaluation, runs inference for `self.context_len` frames and
        then open-loop rollout via `self.rollout`, returning only predicted
        means/masks.

        Args:
            x: input sequence; assumed [batch_size, T, C, H, W] — TODO confirm.
            actions: optional actions consumed via `get_action`.
            geco: GECO constraint object, used after the warm-start phase.
            step: global training step (GECO warm-start gate).

        Returns:
            Training: dict with total_loss / nll / kl / x_means / masks /
            inference_steps. Eval: dict with x_means / masks only.
        """
        C, H, W = self.input_size[0], self.input_size[1], self.input_size[2]
        T = x.shape[1]
        total_loss = 0.

        x_means_t = []
        masks_t = []
        nll_t = []
        kl_t = []

        h, c = self.h_0, self.c_0
        h = h.to(x.device).repeat(1, self.batch_size, 1)
        c = c.to(x.device).repeat(1, self.batch_size, 1)
        log_var = (2 * self.log_scale).to(x.device)

        # Whole sequence is context during training; otherwise roll out after
        # context_len frames.
        if self.training:
            context_len = T
        else:
            context_len = self.context_len

        for t in range(context_len):
            x_t = x[:, t]
            a_t = self.get_action(t, actions)

            phi_x_t = self.phi_x(x_t)
            # Posterior q(z_t | x_t, h [, a_t]).
            if a_t is not None:
                enc_t = self.encoder(torch.cat([a_t, phi_x_t, h.squeeze(0)], 1))
            else:
                enc_t = self.encoder(torch.cat([phi_x_t, h.squeeze(0)], 1))
            q_t = mvn(self.q_loc(enc_t), self.q_var(enc_t))

            # Prior p(z_t | h [, a_t]).
            # N.b. SVG-LP provides phi_x_{t-1} to the prior from ground truth;
            # VRNN passes phi_x_{t-1} into the LSTM, which the prior sees via
            # "h" at time t. Same information.
            if a_t is not None:
                p_t = self.prior(torch.cat([a_t, h.squeeze(0)], 1))
            else:
                p_t = self.prior(h.squeeze(0))
            p_t = mvn(self.p_loc(p_t), self.p_var(p_t))

            # Sample from the posterior and decode.
            z_t = q_t.rsample()
            phi_z_t = self.phi_z(z_t)
            x_means, _ = self.dec(torch.cat([phi_z_t, h.squeeze(0)], 1))

            # Recurrence.
            self.lstm.flatten_parameters()
            _, (h, c) = self.lstm(torch.cat([phi_x_t, phi_z_t], 1).unsqueeze(0), (h, c))

            # Loss: KL(q || p) [batch_size] plus Gaussian NLL.
            kl_div = torch.distributions.kl.kl_divergence(q_t, p_t)
            nll = gaussian_loglikelihood(x_t, x_means, log_var)

            if self.geco_warm_start > step:
                loss = torch.mean(nll + self.kl_beta * kl_div)
            else:
                loss = torch.mean(self.kl_beta * kl_div) - \
                    geco.constraint(self.geco_C_ema, self.geco_beta, torch.mean(nll))
            total_loss += loss

            x_means_t += [x_means]
            masks_t += [x_means[:, 0]]  # channel 0 stands in for a mask
            nll_t += [torch.mean(nll)]
            kl_t += [torch.mean(kl_div)]

        if not self.training:
            with torch.no_grad():
                pred_x_means, pred_x_masks = self.rollout(h, c, T, actions)
                # Tile context-frame outputs across stochastic samples so they
                # line up with the rollout outputs.
                x_means_t = [m.unsqueeze(1).repeat(1, self.stochastic_samples, 1, 1, 1) for m in x_means_t]
                masks_t = [m.unsqueeze(1).repeat(1, self.stochastic_samples, 1, 1, 1) for m in masks_t]
                return {
                    'x_means': x_means_t + pred_x_means,
                    'masks': masks_t + pred_x_masks
                }

        for t in range(context_len, T):
            x_t = x[:, t]
            a_t = self.get_action(t, actions)
            # Prior only (no encoder): autoregressive generation.
            if a_t is not None:
                p_t = self.prior(torch.cat([a_t, h.squeeze(0)], 1))
            else:
                # FIX: was torch.cat([h.squeeze(0)], 1), a no-op single-tensor
                # cat; call directly, consistent with the context loop above.
                p_t = self.prior(h.squeeze(0))
            p_t = mvn(self.p_loc(p_t), self.p_var(p_t))

            # Sample from the prior and decode.
            z_t = p_t.rsample()
            phi_z_t = self.phi_z(z_t)
            x_means, _ = self.dec(torch.cat([phi_z_t, h.squeeze(0)], 1))
            # Autoregressively encode our own output.
            phi_x_t = self.phi_x(x_means)

            # Recurrence.
            self.lstm.flatten_parameters()
            _, (h, c) = self.lstm(torch.cat([phi_x_t, phi_z_t], 1).unsqueeze(0), (h, c))

            # Loss: NLL only (no posterior to compare against).
            nll = gaussian_loglikelihood(x_t, x_means, log_var)
            if self.geco_warm_start > step:
                loss = torch.mean(nll)
            else:
                loss = -geco.constraint(self.geco_C_ema, self.geco_beta, torch.mean(nll))
            total_loss += loss

            x_means_t += [x_means]
            masks_t += [x_means[:, 0]]
            nll_t += [torch.mean(nll)]

        return {
            'total_loss': total_loss,
            'nll': torch.sum(torch.stack(nll_t)),
            'kl': torch.sum(torch.stack(kl_t)),
            'x_means': x_means_t,
            'masks': masks_t,
            'inference_steps': torch.zeros(1).to(x.device)
        }