Example #1
0
    def sample(self, states, macro=None, burn_in=0):
        """Autoregressively sample multi-agent trajectories, overwriting ``states``.

        At every step ``t`` each agent ``i`` draws a latent ``z_t`` from its
        prior, decodes it, and (once ``t >= burn_in``) writes a sampled 2-D
        slice into ``states[t + 1, :, 2*i:2*i+2]``. A single recurrent state
        shared by all agents is then advanced with the current observation
        and the concatenation of every agent's latent.

        Args:
            states: ``(seq, batch, features)`` tensor; modified in place.
                ``states[:burn_in + 1]`` is left as given.
            macro: unused; accepted for interface compatibility.
            burn_in: first step index whose successor is generated.

        Returns:
            ``(states, None)``.
        """
        n_agents = self.params['n_agents']

        # Allocate the hidden state on the input's device so sampling also
        # works when ``states`` (and the model) live on the GPU.
        h = torch.zeros(self.params['n_layers'], states.size(1),
                        self.params['rnn_dim'], device=states.device)

        for t in range(states.size(0) - 1):
            y_t = states[t].clone()

            z_list = []

            for i in range(n_agents):
                prior_t = self.prior[i](torch.cat([y_t, h[-1]], 1))
                prior_mean_t = self.prior_mean[i](prior_t)
                prior_std_t = self.prior_std[i](prior_t)

                z_t = sample_gauss(prior_mean_t, prior_std_t)
                z_list.append(z_t)

                dec_t = self.dec[i](torch.cat([y_t, z_t, h[-1]], 1))
                dec_mean_t = self.dec_mean[i](dec_t)
                dec_std_t = self.dec_std[i](dec_t)

                if t >= burn_in:
                    states[t + 1, :, 2 * i:2 * i + 2] = sample_gauss(
                        dec_mean_t, dec_std_t)

            # The shared RNN consumes the current observation plus all agents'
            # latents; the newly sampled positions are read at the next step.
            z_concat = torch.cat(z_list, -1)
            _, h = self.rnn(torch.cat([y_t, z_concat], 1).unsqueeze(0), h)

        return states, None
    def sample(self, states, macro, burn_in=0, fix_m=None):
        """Sample trajectories conditioned on one macro-intent shared by all agents.

        For every step ``t`` the shared intent ``m_t`` comes from ``macro``
        while ``t < burn_in`` and is sampled from ``self.dec_macro`` afterwards.
        The macro GRU is advanced with ``m_t``. Each agent then draws a latent
        from its prior and decodes it. Once ``t >= burn_in``, the agent writes
        a sampled 2-D slice into ``states[t + 1, :, 2*i:2*i+2]``. Finally it
        advances its own micro GRU with that next position and its latent.

        Args:
            states: ``(seq, batch, features)`` tensor; modified in place.
                ``states[:burn_in + 1]`` is left as given.
            macro: macro-intent labels, one-hot encoded via ``get_macro_ohe``.
            burn_in: steps ``t < burn_in`` use the given intents.
            fix_m: accepted for interface compatibility; not used. Defaults to
                ``None`` rather than a shared mutable ``[]``.

        Returns:
            ``(states, macro_intents)``. ``macro_intents`` has ``macro``'s
            shape; entry ``t`` holds the argmax index of the intent used at
            step ``t``. The final step repeats the previous one.
        """
        n_agents = self.params['n_agents']
        n_layers = self.params['n_layers']
        batch = states.size(1)
        device = states.device

        # get_macro_ohe(macro, 1, m_dim) carries a singleton agent axis at
        # dim 1 (assumed layout (seq, agents, batch, m_dim), matching how the
        # per-agent sampler indexes it — confirm). Squeeze only that axis: a
        # bare squeeze() would also drop the batch axis when batch == 1 and
        # break the torch.cat(..., 1) calls below.
        macro_shared = get_macro_ohe(macro, 1, self.params['m_dim']).squeeze(1)

        # Recurrent states live on the input's device so GPU inputs work.
        h_micro = [
            torch.zeros(n_layers, batch, self.params['rnn_micro_dim'],
                        device=device)
            for _ in range(n_agents)
        ]
        h_macro = torch.zeros(n_layers, batch, self.params['rnn_macro_dim'],
                              device=device)
        macro_intents = torch.zeros(macro.size())

        for t in range(states.size(0) - 1):
            y_t = states[t].clone()
            m_t = macro_shared[t].clone()

            if t >= burn_in:
                # The intent is shared by every agent, so decode and sample it
                # once per step (rather than once per agent, discarding all
                # but the last draw).
                dec_macro_t = self.dec_macro(torch.cat([y_t, h_macro[-1]], 1))
                m_t = sample_multinomial(torch.exp(dec_macro_t))

            macro_intents[t] = torch.max(m_t, -1)[1].unsqueeze(-1)
            _, h_macro = self.gru_macro(m_t.unsqueeze(0), h_macro)

            for i in range(n_agents):
                h_i = h_micro[i][-1]

                prior_t = self.prior[i](torch.cat([y_t, m_t, h_i], 1))
                prior_mean_t = self.prior_mean[i](prior_t)
                prior_std_t = self.prior_std[i](prior_t)

                z_t = sample_gauss(prior_mean_t, prior_std_t)

                dec_t = self.dec[i](torch.cat([y_t, m_t, z_t, h_i], 1))
                dec_mean_t = self.dec_mean[i](dec_t)
                dec_std_t = self.dec_std[i](dec_t)

                if t >= burn_in:
                    states[t + 1, :, 2 * i:2 * i + 2] = sample_gauss(
                        dec_mean_t, dec_std_t)

                # The micro GRU consumes the (given or sampled) next position.
                next_pos = states[t + 1, :, 2 * i:2 * i + 2]
                _, h_micro[i] = self.gru_micro[i](
                    torch.cat([next_pos, z_t], 1).unsqueeze(0), h_micro[i])

        # The loop produces no intent for the last step; repeat the previous one.
        macro_intents.data[-1] = macro_intents.data[-2]

        return states, macro_intents
    def sample(self, states, macro, burn_in=0, fix_m=None):
        """Sample trajectories where every agent has its own macro-intent.

        For every step ``t`` the per-agent intents ``m_t`` come from ``macro``
        while ``t < burn_in``. Afterwards each agent's intent is sampled from
        its own ``self.dec_macro[i]``. All intents are flattened per batch
        element to advance the shared macro GRU. Each agent then samples a
        latent from its prior (conditioned on its own intent) and decodes it.
        Once ``t >= burn_in``, it writes ``states[t + 1, :, 2*i:2*i+2]``.
        It then advances its micro GRU.

        Args:
            states: ``(seq, batch, features)`` tensor; modified in place.
                ``states[:burn_in + 1]`` is left as given.
            macro: macro-intent labels, one-hot encoded via ``get_macro_ohe``
                into an ``(agents, batch, m_dim)`` slice per step.
            burn_in: steps ``t < burn_in`` use the given intents.
            fix_m: accepted for interface compatibility; not used. Defaults to
                ``None`` rather than a shared mutable ``[]``.

        Returns:
            ``(states, macro_intents)`` with ``macro_intents`` of shape
            ``(seq, batch, n_agents)`` holding the argmax intent index per
            agent. The final step repeats the previous one.
        """
        n_agents = self.params['n_agents']
        n_layers = self.params['n_layers']
        batch = states.size(1)
        device = states.device

        macro_single = get_macro_ohe(macro, n_agents, self.params['m_dim'])

        # Recurrent states live on the input's device so GPU inputs work.
        h_micro = [
            torch.zeros(n_layers, batch, self.params['rnn_micro_dim'],
                        device=device)
            for _ in range(n_agents)
        ]
        h_macro = torch.zeros(n_layers, batch, self.params['rnn_macro_dim'],
                              device=device)
        macro_intents = torch.zeros(states.size(0), batch, n_agents)

        for t in range(states.size(0) - 1):
            y_t = states[t].clone()
            m_t = macro_single[t].clone()

            if t >= burn_in:
                for i in range(n_agents):
                    dec_macro_t = self.dec_macro[i](
                        torch.cat([y_t, h_macro[-1]], 1))
                    m_t[i] = sample_multinomial(torch.exp(dec_macro_t))

            macro_intents[t] = torch.max(m_t, 2)[1].transpose(0, 1)
            # (agents, batch, m_dim) -> (batch, agents * m_dim)
            m_t_concat = m_t.transpose(0, 1).contiguous().view(batch, -1)
            _, h_macro = self.gru_macro(m_t_concat.unsqueeze(0), h_macro)

            for i in range(n_agents):
                h_i = h_micro[i][-1]

                prior_t = self.prior[i](torch.cat([m_t[i], h_i], 1))
                prior_mean_t = self.prior_mean[i](prior_t)
                prior_std_t = self.prior_std[i](prior_t)

                z_t = sample_gauss(prior_mean_t, prior_std_t)

                dec_t = self.dec[i](torch.cat([y_t, m_t[i], z_t, h_i], 1))
                dec_mean_t = self.dec_mean[i](dec_t)
                dec_std_t = self.dec_std[i](dec_t)

                if t >= burn_in:
                    states[t + 1, :, 2 * i:2 * i + 2] = sample_gauss(
                        dec_mean_t, dec_std_t)

                # The micro GRU consumes the (given or sampled) next position.
                next_pos = states[t + 1, :, 2 * i:2 * i + 2]
                _, h_micro[i] = self.gru_micro[i](
                    torch.cat([next_pos, z_t], 1).unsqueeze(0), h_micro[i])

        # The loop produces no intent for the last step; repeat the previous one.
        macro_intents.data[-1] = macro_intents.data[-2]

        return states, macro_intents
    def forward(self, states, macro=None, hp=None):
        """Compute training losses for the per-agent macro-intent VRNN.

        Two modes, selected by ``hp['pretrain']``:

        * pretrain: only the per-agent macro-intent decoders are trained.
          Returns ``{'crossentropy_loss': ...}``, accumulated as
          ``-sum(m_t[i] * dec_macro_t)`` over agents and steps.
        * otherwise: each agent's VRNN is trained with the ground-truth
          intents. Returns ``{'kl_loss': ..., 'recon_loss': ...}``.

        Args:
            states: ``(seq, batch, features)`` tensor.
            macro: macro-intent labels, one-hot encoded via ``get_macro_ohe``.
            hp: mapping with a ``'pretrain'`` flag. ``None`` (the default) is
                treated as ``pretrain=False`` instead of raising ``TypeError``.

        Returns:
            dict of scalar loss tensors (see modes above).
        """
        n_agents = self.params['n_agents']
        pretrain = hp is not None and hp['pretrain']

        # NOTE(review): states_single[t][i] is used as agent i's target at
        # step t; its exact time alignment depends on index_by_agent (defined
        # elsewhere) — confirm.
        states_single = index_by_agent(states, n_agents)
        macro_single = get_macro_ohe(macro, n_agents, self.params['m_dim'])

        out = {}
        if pretrain:
            out['crossentropy_loss'] = 0
        else:
            out['kl_loss'] = 0
            out['recon_loss'] = 0

        h_micro = [torch.zeros(self.params['n_layers'], states.size(1),
                               self.params['rnn_micro_dim'])
                   for _ in range(n_agents)]
        h_macro = torch.zeros(self.params['n_layers'], states.size(1),
                              self.params['rnn_macro_dim'])
        if self.params['cuda']:
            h_macro = h_macro.cuda()
            h_micro = cudafy_list(h_micro)

        for t in range(states.size(0) - 1):
            x_t = states_single[t].clone()
            y_t = states[t].clone()
            m_t = macro_single[t].clone()

            if pretrain:
                for i in range(n_agents):
                    dec_macro_t = self.dec_macro[i](
                        torch.cat([y_t, h_macro[-1]], 1))
                    # Cross-entropy against the one-hot intent; dec_macro_t is
                    # presumably log-probabilities (the samplers exponentiate
                    # it) — confirm.
                    out['crossentropy_loss'] -= torch.sum(m_t[i] * dec_macro_t)

                # (agents, batch, m_dim) -> (batch, agents * m_dim)
                m_t_concat = m_t.transpose(0, 1).contiguous().view(
                    states.size(1), -1).clone()
                _, h_macro = self.gru_macro(m_t_concat.unsqueeze(0), h_macro)

            else:
                for i in range(n_agents):
                    h_i = h_micro[i][-1]

                    enc_t = self.enc[i](torch.cat([x_t[i], m_t[i], h_i], 1))
                    enc_mean_t = self.enc_mean[i](enc_t)
                    enc_std_t = self.enc_std[i](enc_t)

                    prior_t = self.prior[i](torch.cat([m_t[i], h_i], 1))
                    prior_mean_t = self.prior_mean[i](prior_t)
                    prior_std_t = self.prior_std[i](prior_t)

                    z_t = sample_gauss(enc_mean_t, enc_std_t)

                    dec_t = self.dec[i](torch.cat([y_t, m_t[i], z_t, h_i], 1))
                    dec_mean_t = self.dec_mean[i](dec_t)
                    dec_std_t = self.dec_std[i](dec_t)

                    _, h_micro[i] = self.gru_micro[i](
                        torch.cat([x_t[i], z_t], 1).unsqueeze(0), h_micro[i])

                    out['kl_loss'] += kld_gauss(enc_mean_t, enc_std_t,
                                                prior_mean_t, prior_std_t)
                    out['recon_loss'] += nll_gauss(dec_mean_t, dec_std_t,
                                                   x_t[i])

        return out
Example #5
0
    def forward(self, states, macro=None, hp=None):
        """Compute the losses of a model with one latent per trajectory.

        The first RNN pass runs over every step of ``states``. Its final hidden
        state parameterises a Gaussian posterior from which a single ``z`` is
        drawn. The KL is taken against a standard-normal prior, and the
        negated ``entropy_gauss`` of the posterior is recorded.

        The second pass restarts the RNN from zeros. It decodes each agent's
        next 2-D slice from ``z`` and the running hidden state. A discriminator
        head then scores ``z`` under a Gaussian predicted from the final hidden
        state.

        Args:
            states: ``(seq, batch, features)`` tensor.
            macro, hp: unused; accepted for interface compatibility.

        Returns:
            dict with ``'kl_loss'``, ``'recon_loss'``, ``'z_entropy'`` and
            ``'discrim_loss'``.
        """
        out = dict.fromkeys(
            ['kl_loss', 'recon_loss', 'z_entropy', 'discrim_loss'], 0)

        n_agents = self.params['n_agents']
        seq_len = states.size(0)
        batch = states.size(1)

        def fresh_hidden():
            # Zero hidden state, moved to the GPU when the model runs there.
            hidden = torch.zeros(self.params['n_layers'], batch,
                                 self.params['rnn_dim'])
            return hidden.cuda() if self.params['cuda'] else hidden

        # Pass 1: encode the whole trajectory.
        hidden = fresh_hidden()
        for step in range(seq_len):
            frame = states[step].clone()
            _, hidden = self.rnn(frame.unsqueeze(0), hidden)

        encoded = self.enc(hidden[-1])
        post_mean = self.enc_mean(encoded)
        post_std = self.enc_std(encoded)

        z = sample_gauss(post_mean, post_std)
        std_normal_mean = torch.zeros(post_mean.size()).to(post_mean.device)
        std_normal_std = torch.ones(post_std.size()).to(post_std.device)

        out['kl_loss'] += kld_gauss(post_mean, post_std,
                                    std_normal_mean, std_normal_std)
        out['z_entropy'] -= entropy_gauss(post_std)

        # Pass 2: decode next-step positions from z and the running state.
        hidden = fresh_hidden()
        for step in range(seq_len - 1):
            frame = states[step].clone()
            _, hidden = self.rnn(frame.unsqueeze(0), hidden)

            for agent in range(n_agents):
                lo, hi = 2 * agent, 2 * agent + 2
                target = states[step + 1][:, lo:hi].clone()

                decoded = self.dec[agent](torch.cat([z, hidden[-1]], 1))
                out['recon_loss'] += nll_gauss(self.dec_mean[agent](decoded),
                                               self.dec_std[agent](decoded),
                                               target)

        disc = self.discrim(hidden[-1])
        out['discrim_loss'] += nll_gauss(self.discrim_mean(disc),
                                         self.discrim_std(disc), z)

        return out
    def sample(self, states, macro=None, burn_in=0):
        """Autoregressively sample a trajectory from a single VRNN, in place.

        At each step a latent is drawn from the prior (conditioned on the
        current observation and hidden state) and decoded. Once
        ``t >= burn_in``, the decoded Gaussian's sample overwrites
        ``states[t + 1]``. The RNN is then advanced with ``(states[t + 1], z_t)``.

        Args:
            states: ``(seq, batch, features)`` tensor; modified in place.
                ``states[:burn_in + 1]`` is left as given.
            macro: unused; accepted for interface compatibility.
            burn_in: first step index whose successor is generated.

        Returns:
            ``(states, None)``.
        """
        # Allocate the hidden state on the input's device so sampling also
        # works when ``states`` (and the model) live on the GPU.
        h = torch.zeros(self.params['n_layers'], states.size(1),
                        self.params['rnn_dim'], device=states.device)

        for t in range(states.size(0) - 1):
            y_t = states[t].clone()

            prior_t = self.prior(torch.cat([y_t, h[-1]], 1))
            prior_mean_t = self.prior_mean(prior_t)
            prior_std_t = self.prior_std(prior_t)

            z_t = sample_gauss(prior_mean_t, prior_std_t)

            dec_t = self.dec(torch.cat([y_t, z_t, h[-1]], 1))
            dec_mean_t = self.dec_mean(dec_t)
            dec_std_t = self.dec_std(dec_t)

            if t >= burn_in:
                states[t + 1] = sample_gauss(dec_mean_t, dec_std_t)

            # Advance on the (given or sampled) next observation.
            _, h = self.rnn(torch.cat([states[t + 1], z_t], 1).unsqueeze(0), h)

        return states, None
Example #7
0
    def sample(self, states, macro=None, burn_in=0):
        """Sample a trajectory from an RNN with a Gaussian output head, in place.

        For each step ``t`` the decoder reads only the last-layer hidden
        state. Once ``t >= burn_in``, a sample from the decoded Gaussian
        overwrites ``states[t]``. The RNN is then advanced with ``states[t]``.

        Args:
            states: ``(seq, batch, features)`` tensor; modified in place.
                ``states[:burn_in]`` is left as given.
            macro: unused; accepted for interface compatibility.
            burn_in: first step that is sampled.

        Returns:
            ``(states, None)``.
        """
        # Allocate the hidden state on the input's device so sampling also
        # works when ``states`` (and the model) live on the GPU.
        h = torch.zeros(self.params['n_layers'], states.size(1),
                        self.params['rnn_dim'], device=states.device)

        for t in range(states.size(0)):
            dec_t = self.dec(h[-1])
            dec_mean_t = self.dec_mean(dec_t)
            dec_std_t = self.dec_std(dec_t)

            if t >= burn_in:
                states[t] = sample_gauss(dec_mean_t, dec_std_t)

            _, h = self.rnn(states[t].unsqueeze(0), h)

        return states, None
Example #8
0
    def sample(self, states, macro=None, burn_in=0):
        """Sample multi-agent trajectories from one trajectory-level latent.

        A single ``z`` is drawn from a standard normal prior for the whole
        sequence. At each step the RNN is advanced with the current
        observation. Each agent's decoder maps ``(z, h[-1])`` to a Gaussian.
        Once ``t >= burn_in``, a sample from it is written into
        ``states[t + 1, :, 2*i:2*i+2]``.

        Args:
            states: ``(seq, batch, features)`` tensor; modified in place.
                ``states[:burn_in + 1]`` is left as given.
            macro: unused; accepted for interface compatibility.
            burn_in: first step index whose successor is generated.

        Returns:
            ``(states, None)``.
        """
        device = states.device
        batch = states.size(1)

        # Keep the hidden state on the same device as the prior tensors and
        # ``states``; previously it was always on the CPU, which failed as
        # soon as ``states`` was on the GPU.
        h = torch.zeros(self.params['n_layers'], batch,
                        self.params['rnn_dim'], device=device)

        prior_mean = torch.zeros(batch, self.params['z_dim'], device=device)
        prior_std = torch.ones(batch, self.params['z_dim'], device=device)
        z = sample_gauss(prior_mean, prior_std)

        for t in range(states.size(0) - 1):
            y_t = states[t].clone()
            _, h = self.rnn(y_t.unsqueeze(0), h)

            for i in range(self.params['n_agents']):
                dec_t = self.dec[i](torch.cat([z, h[-1]], 1))
                dec_mean_t = self.dec_mean[i](dec_t)
                dec_std_t = self.dec_std[i](dec_t)

                if t >= burn_in:
                    states[t + 1, :, 2 * i:2 * i + 2] = sample_gauss(
                        dec_mean_t, dec_std_t)

        return states, None
Example #9
0
    def forward(self, states, macro=None, hp=None):
        """Compute KL and reconstruction losses for a multi-agent VRNN.

        One RNN state is shared by all agents. At step ``t``, agent ``i``'s
        posterior over ``z`` sees its next 2-D slice ``states[t + 1][:, 2i:2i+2]``,
        the current observation and ``h[-1]``. Its prior sees only the current
        observation and ``h[-1]``. Its decoder reconstructs that next slice.
        After all agents, the RNN advances with the current observation and
        the concatenated latents.

        Args:
            states: ``(seq, batch, features)`` tensor.
            macro, hp: unused; accepted for interface compatibility.

        Returns:
            dict with ``'kl_loss'`` and ``'recon_loss'`` summed over steps and
            agents.
        """
        out = {'kl_loss': 0, 'recon_loss': 0}
        n_agents = self.params['n_agents']

        h = torch.zeros(self.params['n_layers'], states.size(1),
                        self.params['rnn_dim'])
        if self.params['cuda']:
            h = h.cuda()

        for t in range(states.size(0) - 1):
            current = states[t].clone()
            context = h[-1]
            latents = []

            for agent in range(n_agents):
                lo, hi = 2 * agent, 2 * agent + 2
                target = states[t + 1][:, lo:hi].clone()

                enc = self.enc[agent](torch.cat([target, current, context], 1))
                post_mean = self.enc_mean[agent](enc)
                post_std = self.enc_std[agent](enc)

                pri = self.prior[agent](torch.cat([current, context], 1))
                pri_mean = self.prior_mean[agent](pri)
                pri_std = self.prior_std[agent](pri)

                z = sample_gauss(post_mean, post_std)
                latents.append(z)

                dec = self.dec[agent](torch.cat([current, z, context], 1))
                dec_mean = self.dec_mean[agent](dec)
                dec_std = self.dec_std[agent](dec)

                out['kl_loss'] += kld_gauss(post_mean, post_std,
                                            pri_mean, pri_std)
                out['recon_loss'] += nll_gauss(dec_mean, dec_std, target)

            rnn_in = torch.cat([current, torch.cat(latents, -1)], 1)
            _, h = self.rnn(rnn_in.unsqueeze(0), h)

        return out
    def forward(self, states, macro=None, hp=None):
        """Compute KL and reconstruction losses for a single VRNN.

        At each step the posterior over ``z_t`` is conditioned on the next
        observation ``states[t + 1]``, the current one and ``h[-1]``. The prior
        is conditioned only on the current observation and ``h[-1]``. The
        decoder reconstructs ``states[t + 1]``, and the RNN is then advanced
        with ``(states[t + 1], z_t)``.

        Args:
            states: ``(seq, batch, features)`` tensor.
            macro, hp: unused; accepted for interface compatibility.

        Returns:
            dict with ``'kl_loss'`` and ``'recon_loss'`` summed over steps.
        """
        out = {'kl_loss': 0, 'recon_loss': 0}

        h = torch.zeros(self.params['n_layers'], states.size(1),
                        self.params['rnn_dim'])
        if self.params['cuda']:
            h = h.cuda()

        for t in range(states.size(0) - 1):
            current = states[t].clone()
            upcoming = states[t + 1].clone()
            context = h[-1]

            enc = self.enc(torch.cat([upcoming, current, context], 1))
            post_mean, post_std = self.enc_mean(enc), self.enc_std(enc)

            pri = self.prior(torch.cat([current, context], 1))
            pri_mean, pri_std = self.prior_mean(pri), self.prior_std(pri)

            z = sample_gauss(post_mean, post_std)

            dec = self.dec(torch.cat([current, z, context], 1))
            dec_mean, dec_std = self.dec_mean(dec), self.dec_std(dec)

            _, h = self.rnn(torch.cat([upcoming, z], 1).unsqueeze(0), h)

            out['kl_loss'] += kld_gauss(post_mean, post_std, pri_mean, pri_std)
            out['recon_loss'] += nll_gauss(dec_mean, dec_std, upcoming)

        return out