def sample(self, states, macro=None, burn_in=0):
    """Autoregressively roll out trajectories for all agents with one shared RNN.

    states: (T, batch, 2*n_agents) tensor (layout implied by the
    states[t+1, :, 2*i:2*i+2] indexing); entries after `burn_in` steps are
    overwritten in-place with decoder samples. `macro` is unused here (kept
    for a uniform sample() interface across models). Returns (states, None).
    """
    n_agents = self.params['n_agents']
    # Shared recurrent state; h[-1] is the top layer's output.
    h = torch.zeros(self.params['n_layers'], states.size(1), self.params['rnn_dim'])
    for t in range(states.size(0) - 1):
        y_t = states[t].clone()
        z_list = []
        for i in range(n_agents):
            # Per-agent latent drawn from the learned prior conditioned on the
            # current joint state and the shared hidden state.
            prior_t = self.prior[i](torch.cat([y_t, h[-1]], 1))
            prior_mean_t = self.prior_mean[i](prior_t)
            prior_std_t = self.prior_std[i](prior_t)
            z_t = sample_gauss(prior_mean_t, prior_std_t)
            z_list.append(z_t)
            dec_t = self.dec[i](torch.cat([y_t, z_t, h[-1]], 1))
            dec_mean_t = self.dec_mean[i](dec_t)
            dec_std_t = self.dec_std[i](dec_t)
            # During burn-in the ground-truth next state is kept; afterwards
            # agent i's 2-D slot is replaced with a decoder sample.
            if t >= burn_in:
                states[t + 1, :, 2 * i:2 * i + 2] = sample_gauss(
                    dec_mean_t, dec_std_t)
        # Advance the shared RNN on the current state plus all agents' latents.
        z_concat = torch.cat(z_list, -1)
        _, h = self.rnn(torch.cat([y_t, z_concat], 1).unsqueeze(0), h)
    return states, None
def sample(self, states, macro, burn_in=0, fix_m=None):
    """Roll out trajectories with a single macro-intent shared by all agents.

    states : (T, batch, 2*n_agents) tensor; entries after `burn_in` steps are
        overwritten in-place with decoder samples.
    macro : ground-truth macro-intent tensor, one-hot encoded via
        get_macro_ohe; used verbatim during burn-in, sampled afterwards.
    burn_in : number of initial steps conditioned on ground truth.
    fix_m : optional per-agent list of fixed macro-intents.
        NOTE(review): this argument is normalized but never consulted below
        (kept for interface parity with the per-agent sampler) — confirm
        whether fixing intents should be supported in this model.

    Returns (states, macro_intents), where macro_intents records the argmax
    macro-intent index at each step.

    Fixes vs. original: no mutable default argument; the macro-intent decode
    is done once per time step instead of inside a per-agent loop whose index
    was unused (only the last of the i.i.d. draws survived, so a single draw
    is distributionally equivalent); the no-op torch.cat([m_t], 1) is gone.
    """
    n_agents = self.params['n_agents']
    macro_shared = get_macro_ohe(macro, 1, self.params['m_dim']).squeeze()
    if fix_m is None or len(fix_m) == 0:
        fix_m = [-1] * n_agents
    # One micro GRU state per agent, one shared macro GRU state.
    h_micro = [
        torch.zeros(self.params['n_layers'], states.size(1),
                    self.params['rnn_micro_dim']) for i in range(n_agents)
    ]
    h_macro = torch.zeros(self.params['n_layers'], states.size(1),
                          self.params['rnn_macro_dim'])
    macro_intents = torch.zeros(macro.size())
    for t in range(states.size(0) - 1):
        y_t = states[t].clone()
        m_t = macro_shared[t].clone()
        # Sample the shared macro-intent once per step after burn-in
        # (dec_macro output is exponentiated, so it is treated as log-probs).
        if t >= burn_in:
            dec_macro_t = self.dec_macro(torch.cat([y_t, h_macro[-1]], 1))
            m_t = sample_multinomial(torch.exp(dec_macro_t))
        macro_intents[t] = torch.max(m_t, -1)[1].unsqueeze(-1)
        # Advance the macro RNN on the chosen macro-intent.
        _, h_macro = self.gru_macro(m_t.unsqueeze(0), h_macro)
        for i in range(n_agents):
            prior_t = self.prior[i](torch.cat([y_t, m_t, h_micro[i][-1]], 1))
            prior_mean_t = self.prior_mean[i](prior_t)
            prior_std_t = self.prior_std[i](prior_t)
            z_t = sample_gauss(prior_mean_t, prior_std_t)
            dec_t = self.dec[i](torch.cat([y_t, m_t, z_t, h_micro[i][-1]], 1))
            dec_mean_t = self.dec_mean[i](dec_t)
            dec_std_t = self.dec_std[i](dec_t)
            # During burn-in, keep ground truth for agent i's 2-D slot.
            if t >= burn_in:
                states[t + 1, :, 2 * i:2 * i + 2] = sample_gauss(
                    dec_mean_t, dec_std_t)
            _, h_micro[i] = self.gru_micro[i](torch.cat(
                [states[t + 1, :, 2 * i:2 * i + 2], z_t], 1).unsqueeze(0),
                h_micro[i])
    # The last step never receives a sampled intent; copy the penultimate one.
    macro_intents.data[-1] = macro_intents.data[-2]
    return states, macro_intents
def sample(self, states, macro, burn_in=0, fix_m=None):
    """Roll out trajectories with an independent macro-intent per agent.

    states : (T, batch, 2*n_agents) tensor; entries after `burn_in` steps are
        overwritten in-place with decoder samples.
    macro : ground-truth macro-intents, one-hot encoded per agent via
        get_macro_ohe; used verbatim during burn-in, resampled afterwards.
    burn_in : number of initial steps conditioned on ground truth.
    fix_m : optional per-agent list of fixed macro-intents.
        NOTE(review): fix_m is normalized but never consulted below — confirm
        whether per-agent intent fixing should gate the resampling.

    Returns (states, macro_intents), where macro_intents[t, b, i] is agent i's
    argmax macro-intent index at step t.

    Fixes vs. original: no mutable default argument; removed the dead local
    `states_single` (computed but never read); removed the no-op
    torch.cat([m_t_concat], 1).
    """
    n_agents = self.params['n_agents']
    macro_single = get_macro_ohe(macro, n_agents, self.params['m_dim'])
    if fix_m is None or len(fix_m) == 0:
        fix_m = [-1] * n_agents
    # One micro GRU state per agent, one shared macro GRU state.
    h_micro = [torch.zeros(self.params['n_layers'], states.size(1),
                           self.params['rnn_micro_dim']) for i in range(n_agents)]
    h_macro = torch.zeros(self.params['n_layers'], states.size(1),
                          self.params['rnn_macro_dim'])
    macro_intents = torch.zeros(states.size(0), states.size(1), n_agents)
    for t in range(states.size(0) - 1):
        y_t = states[t].clone()
        m_t = macro_single[t].clone()
        # Resample each agent's macro-intent after burn-in (dec_macro output
        # is exponentiated, so it is treated as log-probabilities).
        for i in range(n_agents):
            if t >= burn_in:
                dec_macro_t = self.dec_macro[i](torch.cat([y_t, h_macro[-1]], 1))
                m_t[i] = sample_multinomial(torch.exp(dec_macro_t))
        # Record argmax intents; flatten all agents' one-hot intents into a
        # single (batch, n_agents*m_dim) input for the macro RNN.
        macro_intents[t] = torch.max(m_t, 2)[1].transpose(0, 1)
        m_t_concat = m_t.transpose(0, 1).contiguous().view(states.size(1), -1)
        _, h_macro = self.gru_macro(m_t_concat.unsqueeze(0), h_macro)
        for i in range(n_agents):
            prior_t = self.prior[i](torch.cat([m_t[i], h_micro[i][-1]], 1))
            prior_mean_t = self.prior_mean[i](prior_t)
            prior_std_t = self.prior_std[i](prior_t)
            z_t = sample_gauss(prior_mean_t, prior_std_t)
            dec_t = self.dec[i](torch.cat([y_t, m_t[i], z_t, h_micro[i][-1]], 1))
            dec_mean_t = self.dec_mean[i](dec_t)
            dec_std_t = self.dec_std[i](dec_t)
            # During burn-in, keep ground truth for agent i's 2-D slot.
            if t >= burn_in:
                states[t + 1, :, 2 * i:2 * i + 2] = sample_gauss(
                    dec_mean_t, dec_std_t)
            _, h_micro[i] = self.gru_micro[i](
                torch.cat([states[t + 1, :, 2 * i:2 * i + 2], z_t],
                          1).unsqueeze(0), h_micro[i])
    # The last step never receives a sampled intent; copy the penultimate one.
    macro_intents.data[-1] = macro_intents.data[-2]
    return states, macro_intents
def forward(self, states, macro=None, hp=None):
    """Training pass for the hierarchical macro-intent / micro-action model.

    hp['pretrain'] selects the mode:
      * pretrain: fit only the per-agent macro-intent decoders with a
        cross-entropy loss against the ground-truth one-hot macro-intents.
      * full: accumulate per-agent VRNN terms (KL divergence between the
        posterior and learned prior, plus reconstruction NLL), conditioned
        on the ground-truth macro-intents.

    Returns a dict with 'crossentropy_loss' (pretrain) or
    'kl_loss'/'recon_loss' (full).
    """
    n_agents = self.params['n_agents']
    # Per-agent re-indexing of the joint tensors (presumably
    # states_single[t][i] is agent i's slice — see index_by_agent).
    states_single = index_by_agent(states, n_agents)
    macro_single = get_macro_ohe(macro, n_agents, self.params['m_dim'])
    out = {}
    if hp['pretrain']:
        out['crossentropy_loss'] = 0
    else:
        out['kl_loss'] = 0
        out['recon_loss'] = 0
    # One micro GRU state per agent, one shared macro GRU state.
    h_micro = [torch.zeros(self.params['n_layers'], states.size(1), self.params['rnn_micro_dim']) for i in range(n_agents)]
    h_macro = torch.zeros(self.params['n_layers'], states.size(1), self.params['rnn_macro_dim'])
    if self.params['cuda']:
        h_macro = h_macro.cuda()
        h_micro = cudafy_list(h_micro)
    for t in range(states.size(0)-1):
        x_t = states_single[t].clone()
        y_t = states[t].clone()
        m_t = macro_single[t].clone()
        if hp['pretrain']:
            for i in range(n_agents):
                # Dot product of the one-hot target with the decoder output
                # (assumed log-probabilities: the samplers apply torch.exp to
                # dec_macro's output) gives the log-likelihood of the target.
                dec_macro_t = self.dec_macro[i](torch.cat([y_t, h_macro[-1]], 1))
                out['crossentropy_loss'] -= torch.sum(m_t[i]*dec_macro_t)
            # Flatten all agents' one-hot intents into one macro-RNN input.
            m_t_concat = m_t.transpose(0,1).contiguous().view(states.size(1), -1).clone()
            _, h_macro = self.gru_macro(torch.cat([m_t_concat], 1).unsqueeze(0), h_macro)
        else:
            for i in range(n_agents):
                # Posterior over z sees the agent's true next-step state x_t.
                enc_t = self.enc[i](torch.cat([x_t[i], m_t[i], h_micro[i][-1]], 1))
                enc_mean_t = self.enc_mean[i](enc_t)
                enc_std_t = self.enc_std[i](enc_t)
                # Learned prior conditions only on the intent and hidden state.
                prior_t = self.prior[i](torch.cat([m_t[i], h_micro[i][-1]], 1))
                prior_mean_t = self.prior_mean[i](prior_t)
                prior_std_t = self.prior_std[i](prior_t)
                z_t = sample_gauss(enc_mean_t, enc_std_t)
                dec_t = self.dec[i](torch.cat([y_t, m_t[i], z_t, h_micro[i][-1]], 1))
                dec_mean_t = self.dec_mean[i](dec_t)
                dec_std_t = self.dec_std[i](dec_t)
                _, h_micro[i] = self.gru_micro[i](torch.cat([x_t[i], z_t], 1).unsqueeze(0), h_micro[i])
                out['kl_loss'] += kld_gauss(enc_mean_t, enc_std_t, prior_mean_t, prior_std_t)
                out['recon_loss'] += nll_gauss(dec_mean_t, dec_std_t, x_t[i])
    return out
def forward(self, states, macro=None, hp=None):
    """VRAE-style training pass: encode the whole sequence into one latent z,
    then decode every step per agent; a discriminator predicts z from the
    decoder-side hidden state (a mutual-information-style term).

    Returns a dict with 'kl_loss', 'recon_loss', 'z_entropy', 'discrim_loss'.
    `macro` and `hp` are unused here.
    """
    out = {}
    out['kl_loss'] = 0
    out['recon_loss'] = 0
    out['z_entropy'] = 0
    out['discrim_loss'] = 0
    n_agents = self.params['n_agents']
    h = torch.zeros(self.params['n_layers'], states.size(1), self.params['rnn_dim'])
    if self.params['cuda']:
        h = h.cuda()
    # Pass 1: run the RNN over the full sequence; the final hidden state
    # parameterizes the sequence-level posterior over z.
    for t in range(states.size(0)):
        y_t = states[t].clone()
        _, h = self.rnn(y_t.unsqueeze(0), h)
    enc = self.enc(h[-1])
    enc_mean = self.enc_mean(enc)
    enc_std = self.enc_std(enc)
    z = sample_gauss(enc_mean, enc_std)
    # KL against a standard-normal prior, plus a (negative) entropy term on
    # the posterior.
    prior_mean = torch.zeros(enc_mean.size()).to(enc_mean.device)
    prior_std = torch.ones(enc_std.size()).to(enc_std.device)
    out['kl_loss'] += kld_gauss(enc_mean, enc_std, prior_mean, prior_std)
    out['z_entropy'] -= entropy_gauss(enc_std)
    # Pass 2: reset the hidden state and decode each agent's next state,
    # every step conditioned on the single latent z.
    h = torch.zeros(self.params['n_layers'], states.size(1), self.params['rnn_dim'])
    if self.params['cuda']:
        h = h.cuda()
    for t in range(states.size(0) - 1):
        y_t = states[t].clone()
        _, h = self.rnn(y_t.unsqueeze(0), h)
        for i in range(n_agents):
            # Agent i's ground-truth next state (2-D slice).
            x_t = states[t + 1][:, 2 * i:2 * i + 2].clone()
            dec_t = self.dec[i](torch.cat([z, h[-1]], 1))
            dec_mean_t = self.dec_mean[i](dec_t)
            dec_std_t = self.dec_std[i](dec_t)
            out['recon_loss'] += nll_gauss(dec_mean_t, dec_std_t, x_t)
        # Discriminator tries to recover z from the hidden state at each step.
        discrim = self.discrim(h[-1])
        discrim_mean = self.discrim_mean(discrim)
        discrim_std = self.discrim_std(discrim)
        out['discrim_loss'] += nll_gauss(discrim_mean, discrim_std, z)
    return out
def sample(self, states, macro=None, burn_in=0):
    """Autoregressive VRNN rollout over the joint state vector.

    At each step a latent is drawn from the learned prior and the next state
    is decoded from it; after `burn_in` steps the decoded sample overwrites
    states[t+1] in-place, otherwise ground truth is kept. `macro` is unused.
    Returns (states, None).
    """
    hidden = torch.zeros(self.params['n_layers'], states.size(1),
                         self.params['rnn_dim'])
    horizon = states.size(0) - 1
    for step in range(horizon):
        current = states[step].clone()
        # Latent from the learned prior p(z | state, hidden).
        prior_feat = self.prior(torch.cat([current, hidden[-1]], 1))
        latent = sample_gauss(self.prior_mean(prior_feat),
                              self.prior_std(prior_feat))
        # Decode the next-state distribution.
        dec_feat = self.dec(torch.cat([current, latent, hidden[-1]], 1))
        next_mean = self.dec_mean(dec_feat)
        next_std = self.dec_std(dec_feat)
        if step >= burn_in:
            states[step + 1] = sample_gauss(next_mean, next_std)
        # Feed whichever next state was kept (sampled or ground truth),
        # together with the latent, back into the RNN.
        rnn_input = torch.cat([states[step + 1], latent], 1).unsqueeze(0)
        _, hidden = self.rnn(rnn_input, hidden)
    return states, None
def sample(self, states, macro=None, burn_in=0):
    """Unconditional rollout: each step is decoded from the RNN hidden state
    alone (no latent variable).

    After `burn_in` steps, states[t] is replaced in-place with a Gaussian
    sample from the decoder; earlier steps keep their ground-truth values.
    `macro` is unused. Returns (states, None).

    Fix vs. original: the decoder input was wrapped in torch.cat([h[-1]], 1),
    a no-op concatenation of a single tensor — it is now passed directly.
    """
    h = torch.zeros(self.params['n_layers'], states.size(1),
                    self.params['rnn_dim'])
    for t in range(states.size(0)):
        dec_t = self.dec(h[-1])
        dec_mean_t = self.dec_mean(dec_t)
        dec_std_t = self.dec_std(dec_t)
        if t >= burn_in:
            states[t] = sample_gauss(dec_mean_t, dec_std_t)
        # Advance the RNN on whichever state was kept (sampled or ground truth).
        _, h = self.rnn(states[t].unsqueeze(0), h)
    return states, None
def sample(self, states, macro=None, burn_in=0):
    """Generate a rollout conditioned on a single latent code.

    One z is drawn from a standard normal up front; every step's per-agent
    next state is then decoded from (z, hidden). After `burn_in` steps the
    samples overwrite states in-place. `macro` is unused. Returns
    (states, None).
    """
    batch = states.size(1)
    hidden = torch.zeros(self.params['n_layers'], batch,
                         self.params['rnn_dim'])
    # Sequence-level latent from N(0, I), on the same device as the states.
    zero_mean = torch.zeros(batch, self.params['z_dim']).to(states.device)
    unit_std = torch.ones(batch, self.params['z_dim']).to(states.device)
    z = sample_gauss(zero_mean, unit_std)
    for step in range(states.size(0) - 1):
        current = states[step].clone()
        _, hidden = self.rnn(current.unsqueeze(0), hidden)
        for agent in range(self.params['n_agents']):
            feats = self.dec[agent](torch.cat([z, hidden[-1]], 1))
            mean = self.dec_mean[agent](feats)
            std = self.dec_std[agent](feats)
            # During burn-in, keep the ground-truth 2-D slot for this agent.
            if step >= burn_in:
                lo = 2 * agent
                states[step + 1, :, lo:lo + 2] = sample_gauss(mean, std)
    return states, None
def forward(self, states, macro=None, hp=None):
    """VRNN training pass with per-agent encoder/prior/decoder heads and a
    single shared recurrent state.

    Accumulates KL and reconstruction-NLL terms over all steps and agents.
    Returns {'kl_loss', 'recon_loss'}. `macro` and `hp` are unused here.
    """
    out = {}
    out['kl_loss'] = 0
    out['recon_loss'] = 0
    n_agents = self.params['n_agents']
    h = torch.zeros(self.params['n_layers'], states.size(1), self.params['rnn_dim'])
    if self.params['cuda']:
        h = h.cuda()
    for t in range(states.size(0) - 1):
        y_t = states[t].clone()
        z_list = []
        for i in range(n_agents):
            # Agent i's ground-truth next state (2-D slice).
            x_t = states[t + 1][:, 2 * i:2 * i + 2].clone()
            # Posterior sees the true next state; the learned prior does not.
            enc_t = self.enc[i](torch.cat([x_t, y_t, h[-1]], 1))
            enc_mean_t = self.enc_mean[i](enc_t)
            enc_std_t = self.enc_std[i](enc_t)
            prior_t = self.prior[i](torch.cat([y_t, h[-1]], 1))
            prior_mean_t = self.prior_mean[i](prior_t)
            prior_std_t = self.prior_std[i](prior_t)
            # Latent drawn from the posterior (reparameterized in sample_gauss,
            # presumably — confirm against its definition).
            z_t = sample_gauss(enc_mean_t, enc_std_t)
            z_list.append(z_t)
            dec_t = self.dec[i](torch.cat([y_t, z_t, h[-1]], 1))
            dec_mean_t = self.dec_mean[i](dec_t)
            dec_std_t = self.dec_std[i](dec_t)
            out['kl_loss'] += kld_gauss(enc_mean_t, enc_std_t, prior_mean_t, prior_std_t)
            out['recon_loss'] += nll_gauss(dec_mean_t, dec_std_t, x_t)
        # The shared RNN consumes the current state plus all agents' latents.
        z_concat = torch.cat(z_list, -1)
        _, h = self.rnn(torch.cat([y_t, z_concat], 1).unsqueeze(0), h)
    return out
def forward(self, states, macro=None, hp=None):
    """Single-stream VRNN training pass over the joint state vector.

    Accumulates the KL divergence between posterior and learned prior plus
    the reconstruction NLL at every step. Returns {'kl_loss', 'recon_loss'}.
    `macro` and `hp` are unused here.
    """
    losses = {'kl_loss': 0, 'recon_loss': 0}
    hidden = torch.zeros(self.params['n_layers'], states.size(1),
                         self.params['rnn_dim'])
    if self.params['cuda']:
        hidden = hidden.cuda()
    for step in range(states.size(0) - 1):
        prev = states[step].clone()
        nxt = states[step + 1].clone()
        # Posterior q(z | next, prev, hidden) — conditions on the true target.
        enc_feat = self.enc(torch.cat([nxt, prev, hidden[-1]], 1))
        enc_mu = self.enc_mean(enc_feat)
        enc_sigma = self.enc_std(enc_feat)
        # Learned prior p(z | prev, hidden) — no access to the target.
        prior_feat = self.prior(torch.cat([prev, hidden[-1]], 1))
        prior_mu = self.prior_mean(prior_feat)
        prior_sigma = self.prior_std(prior_feat)
        latent = sample_gauss(enc_mu, enc_sigma)
        dec_feat = self.dec(torch.cat([prev, latent, hidden[-1]], 1))
        dec_mu = self.dec_mean(dec_feat)
        dec_sigma = self.dec_std(dec_feat)
        losses['kl_loss'] += kld_gauss(enc_mu, enc_sigma, prior_mu, prior_sigma)
        losses['recon_loss'] += nll_gauss(dec_mu, dec_sigma, nxt)
        # Advance the recurrence on the true next state plus the latent.
        _, hidden = self.rnn(torch.cat([nxt, latent], 1).unsqueeze(0), hidden)
    return losses