Example 1
    def forward(self, input, discrete=False):
        # This check was modified: in AutoSpeech each item has 2 dimensions (time, freq),
        # while in my ToyASV2019 items have 3 dimensions (channels, time, freq), with channels always equal to 1.
        if len(input.shape) < 4:
            input = input.unsqueeze(1)
        s0 = s1 = self.stem(input)
        for i, cell in enumerate(self.cells):
            if cell.reduction:
                if discrete:
                    weights = self.alphas_reduce
                else:
                    weights = gumbel_softmax(
                        F.log_softmax(self.alphas_reduce, dim=-1))
            else:
                if discrete:
                    weights = self.alphas_normal
                else:
                    weights = gumbel_softmax(
                        F.log_softmax(self.alphas_normal, dim=-1))
            s0, s1 = s1, cell(s0, s1, weights, self.drop_path_prob)
        v = self.global_pooling(s1)
        v = v.view(v.size(0), -1)

        # This is not needed: in anti-spoofing we always forward through the whole network
        # if not self.training:
        #    return v

        y = self.classifier(v)

        return y
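
Note: every snippet in this listing calls some project-local variant of gumbel_softmax. As a reference point, here is a minimal sketch of straight-through Gumbel-Softmax sampling, written against a PyTorch-style gumbel_softmax(logits, tau, hard, dim) signature (the same behaviour as torch.nn.functional.gumbel_softmax); the helpers used in the individual examples may differ in argument order or in extras such as a device argument or returned log-probabilities.

import torch
import torch.nn.functional as F

def gumbel_softmax(logits, tau=1.0, hard=False, dim=-1):
    # Sample Gumbel(0, 1) noise: -log(-log(U)) with U ~ Uniform(0, 1)
    gumbels = -torch.empty_like(logits).exponential_().log()
    y_soft = F.softmax((logits + gumbels) / tau, dim=dim)
    if not hard:
        return y_soft
    # Straight-through: the forward pass returns a one-hot sample,
    # the backward pass uses the gradient of the soft sample
    index = y_soft.argmax(dim, keepdim=True)
    y_hard = torch.zeros_like(logits).scatter_(dim, index, 1.0)
    return y_hard - y_soft.detach() + y_soft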
Example 2
File: models.py Project: yvlian/GKT
 def forward(self, data, sp_send, sp_rec, sp_send_t, sp_rec_t):
     r"""
     Parameters:
         data: input concept embedding matrix
         sp_send: one-hot encoded send-node index (sparse tensor)
         sp_rec: one-hot encoded receive-node index (sparse tensor)
         sp_send_t: one-hot encoded send-node index (sparse tensor, transposed)
         sp_rec_t: one-hot encoded receive-node index (sparse tensor, transposed)
     Shape:
         data: [concept_num, embedding_dim]
         sp_send: [edge_num, concept_num]
         sp_rec: [edge_num, concept_num]
         sp_send_t: [concept_num, edge_num]
         sp_rec_t: [concept_num, edge_num]
     Return:
         graphs: latent graph list modeled by z which has different edge types
         output: the reconstructed data
         prob: q(z|x) distribution
     """
     logits = self.encoder(
         data, sp_send, sp_rec, sp_send_t,
         sp_rec_t)  # [edge_num, output_dim(edge_type_num)]
     edges = gumbel_softmax(logits, tau=self.tau,
                            dim=-1)  # [edge_num, edge_type_num]
     prob = F.softmax(logits, dim=-1)
     output = self.decoder(data, edges, sp_send, sp_rec, sp_send_t,
                           sp_rec_t)  # [concept_num, embedding_dim]
     graphs = self._get_graph(edges, sp_send, sp_rec)
     return graphs, output, prob
Example 3
    def encode_t2r(self, inputs, traj_emb_pred, n_epoch):
        latent_edge_samples, latent_edge_probs, latent_edge_logits = \
            None, None, None
        if self.opt_dynamics(n_epoch) and hasattr(self, 'enc_t2r'):
            if traj_emb_pred is None:
                nri_input = inputs['traj']
            else:
                nri_input = traj_emb_pred
                if self.enc_t2r_detach_i2t_grad:
                    nri_input = nri_input.detach()

            latent_edge_logits = self.enc_t2r(nri_input)

            latent_edge_samples = gumbel_softmax(latent_edge_logits,
                                                 self.temp,
                                                 hard=not self.training)

            latent_edge_probs = F.softmax(latent_edge_logits, dim=-1)

            # [B * T, n_atoms * (n_atoms - 1), 256] - dynamic graph inference
            # [B * 1, n_atoms * (n_atoms - 1), 256] - static graph inference
            batch_size = nri_input.size(0)
            shp = latent_edge_logits.shape
            shp = [batch_size, -1] + list(shp[1:])
            latent_edge_logits = latent_edge_logits.view(shp)
            latent_edge_samples = latent_edge_samples.view(shp)
            latent_edge_probs = latent_edge_probs.view(shp)

        return latent_edge_samples, latent_edge_probs, latent_edge_logits
Example 4
 def forward(self, x, y, K):
     '''
     :param x:
         encoded utterance in shape (B, 2*H)
     :param y:
         encoded response in shape (B, 2*H) (optional)
     :param K:
         encoded knowledge in shape (B, N, 2*H)
     :return:
         prior, posterior, selected knowledge, selected knowledge logits for BOW_loss
     '''
     if y is not None:
         prior = F.log_softmax(torch.bmm(x.unsqueeze(1), K.transpose(-1, -2)), dim=-1).squeeze(1)
         response = self.mlp(torch.cat((x, y), dim=-1))  # response: [n_batch, 2*n_hidden]
         K = K.transpose(-1, -2)  # K: [n_batch, 2*n_hidden, N]
         posterior_logits = torch.bmm(response.unsqueeze(1), K).squeeze(1)
         posterior = F.softmax(posterior_logits, dim=-1)
         k_idx = gumbel_softmax(posterior_logits, self.temperature)  # k_idx: [n_batch, N(one_hot)]
         k_i = torch.bmm(K, k_idx.unsqueeze(2)).squeeze(2)  # k_i: [n_batch, 2*n_hidden]
         k_logits = F.log_softmax(self.mlp_k(k_i), dim=-1)  # k_logits: [n_batch, n_vocab]
         return prior, posterior, k_i, k_logits  # prior: [n_batch, N], posterior: [n_batch, N]
     else:
         n_batch = K.size(0)
         k_i = torch.Tensor(n_batch, 2*self.n_hidden).cuda()
         prior = torch.bmm(x.unsqueeze(1), K.transpose(-1, -2)).squeeze(1)
         k_idx = prior.max(1)[1].unsqueeze(1)  # k_idx: [n_batch, 1]
         for i in range(n_batch):
             k_i[i] = K[i, k_idx[i]]
         return k_i
Example 5
def code_transformer(codes, step, max_steps):
    codes_prob, codes_logprob, codes_oh, codes_oh_prob, codes_oh_trans, codes_oh_scale, \
    codes_sg, codes_hg_trans, codes_hg_scale = {}, {}, {}, {}, {}, {}, {}, {}, {}
    alpha = np.power(1 - step / max_steps, 2)
    for name in codes:
        logits = codes[name]
        prob = torch.softmax(logits, 1)  # B x n_values
        logprob = torch.log_softmax(logits, 1)

        argmax = torch.argmax(logits, 1, True)  # B
        one_hot = torch.zeros_like(logits).scatter_(1, argmax, 1.0)
        one_hot_prob = prob * one_hot
        one_hot_trans = one_hot - one_hot_prob.detach() + one_hot_prob
        one_hot_scale = one_hot_prob / one_hot_prob.sum(1, keepdim=True).detach()

        soft_gumbel, hard_gumbel_trans, hard_gumbel_scale = U.gumbel_softmax(logits, tau=1, dim=1)

        codes_prob[name] = prob
        codes_logprob[name] = logprob
        codes_oh[name] = one_hot
        codes_oh_prob[name] = one_hot_prob
        codes_oh_trans[name] = one_hot_trans
        codes_oh_scale[name] = one_hot_scale
        codes_sg[name] = soft_gumbel
        codes_hg_trans[name] = hard_gumbel_trans
        codes_hg_scale[name] = hard_gumbel_scale

    return {'logits': codes, 'prob': codes_prob, 'logprob': codes_logprob,
            'one_hot': codes_oh, 'one_hot_prob': codes_oh_prob,
            'one_hot_trans': codes_oh_trans, 'one_hot_scale': codes_oh_scale,
            'soft_gumbel': codes_sg, 'hard_gumbel_trans': codes_hg_trans, 'hard_gumbel_scale': codes_hg_scale}
Example 6
        def nextStep(query, key, value, mask=None, tau=1):
            nonlocal i

            dot_products = (query.unsqueeze(1) * key).sum(
                -1)  # batch x query_len x key_len

            if self.attend_mode == "only_attend_front":
                dot_products[:, i + 1:] -= 1e9

            if self.window > 0:
                dot_products[:, :max(0, i - self.window)] -= 1e9
                dot_products[:, i + self.window + 1:] -= 1e9

            if self.attend_mode == "not_attend_self":
                dot_products[:, i] -= 1e9

            if mask is not None:
                dot_products -= (1 - mask) * 1e9

            logits = dot_products / self.scale
            if self.gumbel_attend and self.training:
                probs = gumbel_softmax(logits, tau, dim=-1)
            else:
                probs = torch.softmax(logits, dim=-1)

            probs = probs * ((dot_products <= -5e8).sum(-1, keepdim=True) <
                             dot_products.shape[-1]).float()

            i += 1
            return torch.einsum("ij, ijk->ik", self.dropout(probs), value)
Example 7
    def _get_sender_lstm_output(self, inputs):
        samples = []
        batch_size = inputs.shape[0]
        sample_loss = torch.zeros(batch_size, device=self.config['device'])
        total_kl = torch.zeros(batch_size, device=self.config['device'])
        hx = torch.zeros(batch_size,
                         self.config['num_lstm_sender'],
                         device=self.config['device'])
        cx = torch.zeros(batch_size,
                         self.config['num_lstm_sender'],
                         device=self.config['device'])

        for num in range(self.config['num_binary_messages']):
            hx, cx = self.sender_cell(inputs, (hx, cx))
            output = self.sender_project(hx)
            pre_logits = self.sender_out(output)

            sample = utils.gumbel_softmax(
                pre_logits,
                self.temperature[num],
                self.config['device'],
            )

            logits_dist = dists.OneHotCategorical(logits=pre_logits)
            prior_logits = self.prior[num].unsqueeze(0)
            prior_logits = prior_logits.expand(batch_size, self.output_size)
            prior_dist = dists.OneHotCategorical(logits=prior_logits)
            kl = dists.kl_divergence(logits_dist, prior_dist)
            total_kl += kl

            samples.append(sample)
        return samples, sample_loss, total_kl
Example 8
def code_transformer(codes):
    codes_prob, codes_oh, codes_oh_prob, codes_oh_trans, codes_oh_scale, codes_sg, codes_hg = {}, {}, {}, {}, {}, {}, {}
    for name in codes:
        logits = codes[name]
        prob = torch.softmax(logits, 1)  # B x n_values

        argmax = torch.argmax(logits, 1, True)  # B
        one_hot = torch.zeros_like(logits).scatter_(1, argmax, 1.0)
        one_hot_prob = prob * one_hot
        one_hot_trans = one_hot - one_hot_prob.detach() + one_hot_prob
        one_hot_scale = one_hot_prob / one_hot_prob.sum(1,
                                                        keepdim=True).detach()

        soft_gumbel, hard_gumbel, _ = U.gumbel_softmax(logits, tau=1, dim=1)

        codes_prob[name] = prob
        codes_oh[name] = one_hot
        codes_oh_prob[name] = one_hot_prob
        codes_oh_trans[name] = one_hot_trans
        codes_oh_scale[name] = one_hot_scale
        codes_sg[name] = soft_gumbel
        codes_hg[name] = hard_gumbel

    return {
        'prob': codes_prob,
        'one_hot': codes_oh,
        'one_hot_prob': codes_oh_prob,
        'one_hot_trans': codes_oh_trans,
        'one_hot_scale': codes_oh_scale,
        'soft_gumbel': codes_sg,
        'hard_gumbel': codes_hg
    }
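
The one_hot_trans entry built in Examples 5 and 8 is the straight-through estimator: its forward value is the hard one-hot vector, while gradients flow through the soft probabilities. A minimal standalone check of that behaviour (illustrative names, plain PyTorch):

import torch

logits = torch.randn(2, 4, requires_grad=True)
prob = torch.softmax(logits, 1)
one_hot = torch.zeros_like(logits).scatter_(1, logits.argmax(1, True), 1.0)
one_hot_prob = prob * one_hot
one_hot_trans = one_hot - one_hot_prob.detach() + one_hot_prob

one_hot_trans.sum().backward()
assert torch.allclose(one_hot_trans, one_hot)  # forward value is the hard one-hot
assert logits.grad is not None                 # gradient reaches the logits via prob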
Example 9
File: modules.py Project: zizai/NRI
    def forward(self,
                data,
                rel_type,
                rel_rec,
                rel_send,
                pred_steps=1,
                burn_in=False,
                burn_in_steps=1,
                dynamic_graph=False,
                encoder=None,
                temp=None):

        inputs = data.transpose(1, 2).contiguous()

        time_steps = inputs.size(1)

        # inputs has shape
        # [batch_size, num_timesteps, num_atoms, num_dims]

        # rel_type has shape:
        # [batch_size, num_atoms*(num_atoms-1), num_edge_types]

        hidden = Variable(
            torch.zeros(inputs.size(0), inputs.size(2), self.msg_out_shape))
        if inputs.is_cuda:
            hidden = hidden.cuda()

        pred_all = []

        for step in range(0, inputs.size(1) - 1):

            if burn_in:
                if step <= burn_in_steps:
                    ins = inputs[:, step, :, :]
                else:
                    ins = pred_all[step - 1]
            else:
                assert (pred_steps <= time_steps)
                # Use ground truth trajectory input vs. last prediction
                if not step % pred_steps:
                    ins = inputs[:, step, :, :]
                else:
                    ins = pred_all[step - 1]

            if dynamic_graph and step >= burn_in_steps:
                # NOTE: Assumes burn_in_steps = args.timesteps
                logits = encoder(
                    data[:, :, step - burn_in_steps:step, :].contiguous(),
                    rel_rec, rel_send)
                rel_type = gumbel_softmax(logits, tau=temp, hard=True)

            pred, hidden = self.single_step_forward(ins, rel_rec, rel_send,
                                                    rel_type, hidden)
            pred_all.append(pred)

        preds = torch.stack(pred_all, dim=1)

        return preds.transpose(1, 2).contiguous()
Example 10
 def route(self, x, hard=True):
     x = x.float()
     logits = self.linear(x)
     if hard and self.training:
         return gumbel_softmax(logits)
     elif hard and not self.training:
         return ohe_from_logits(logits)
     else:
         return logits.softmax(dim=-1)
Example 11
    def update(self, agent_id, obs, acs, rews, next_obs, dones, t_step, logger=None):
    
        obs = torch.from_numpy(obs).float()
        acs = torch.from_numpy(acs).float()
        rews = torch.from_numpy(rews[:,agent_id]).float()
        next_obs = torch.from_numpy(next_obs).float()
        dones = torch.from_numpy(dones[:,agent_id]).float()

        acs = acs.view(-1,2)
                
        # --------- update critic ------------ #        
        self.critic_optimizer.zero_grad()
        
        all_trgt_acs = self.target_policy(next_obs) 
    
        target_value = (rews + self.gamma *
                        self.target_critic(next_obs,all_trgt_acs) *
                        (1 - dones)) 
        
        actual_value = self.critic(obs,acs)
        vf_loss = MSELoss(actual_value, target_value.detach())

        # Minimize the loss
        vf_loss.backward()
        #torch.nn.utils.clip_grad_norm_(self.critic.parameters(), 1)
        self.critic_optimizer.step()

        # --------- update actor --------------- #
        self.policy_optimizer.zero_grad()

        if self.discrete_action:
            curr_pol_out = self.policy(obs)
            curr_pol_vf_in = gumbel_softmax(curr_pol_out, hard=True)
        else:
            curr_pol_out = self.policy(obs)
            curr_pol_vf_in = curr_pol_out


        pol_loss = -self.critic(obs,curr_pol_vf_in).mean()
        #pol_loss += (curr_pol_out**2).mean() * 1e-3
        pol_loss.backward()

        torch.nn.utils.clip_grad_norm_(self.policy.parameters(), 1)
        self.policy_optimizer.step()

        self.update_all_targets()
        self.eps -= self.eps_decay
        self.eps = max(self.eps, 0)
        

        if logger is not None:
            logger.add_scalars('agent%i/losses' % self.agent_name,
                               {'vf_loss': vf_loss,
                                'pol_loss': pol_loss},
                               self.niter)
Example 12
 def forward(self, query, temperature=1.):
     '''
     query: batch_size x timestep x embedding_dim
     embedding: n_latent x embedding_dim
     '''
     unnormalized_logits = query @ self.embedding.weight
     unnormalized_logits = unnormalized_logits / torch.norm(query, p=2, dim=-1, keepdim=True)
     logits = unnormalized_logits / torch.norm(torch.t(self.embedding.weight), p=2, dim=-1)
     proba = gumbel_softmax(logits, temperature=temperature)
     output = self.embedding(proba)
     return proba, output
Example 13
 def forward(self, input, temperature=5.0):
     s0 = s1 = self.stem(input)
     for i, cell in enumerate(self.cells):
         if cell.reduction:
             if self.alphas_reduce.size(1) == 1:
                 # weights = F.softmax(self.alphas_reduce, dim=0)
                 # weights = F.gumbel_softmax(self.alphas_reduce, temperature, dim=0)
                 weights = gumbel_softmax(F.log_softmax(self.alphas_reduce,
                                                        dim=0),
                                          tau=temperature,
                                          hard=True,
                                          dim=0)
             else:
                 # weights = F.softmax(self.alphas_reduce, dim=-1)
                 # weights = F.gumbel_softmax(self.alphas_reduce, temperature, dim=-1)
                 weights = gumbel_softmax(F.log_softmax(self.alphas_reduce,
                                                        dim=-1),
                                          tau=temperature,
                                          hard=True,
                                          dim=-1)
         else:
             if self.alphas_normal.size(1) == 1:
                 # weights = F.softmax(self.alphas_normal, dim=0)
                 # weights = F.gumbel_softmax(self.alphas_normal, temperature, dim=0)
                 weights = gumbel_softmax(F.log_softmax(self.alphas_normal,
                                                        dim=0),
                                          tau=temperature,
                                          hard=True,
                                          dim=0)
             else:
                 # weights = F.softmax(self.alphas_normal, dim=-1)
                 # weights = F.gumbel_softmax(self.alphas_normal, temperature, dim=-1)
                 weights = gumbel_softmax(F.log_softmax(self.alphas_normal,
                                                        dim=-1),
                                          tau=temperature,
                                          hard=True,
                                          dim=-1)
         s0, s1 = s1, cell(s0, s1, weights)
     out = self.global_pooling(s1)
     logits = self.classifier(out.view(out.size(0), -1))
     return logits
Example 14
    def update(self):
        batch_states, batch_actions, batch_rewards, batch_new_states, batch_dones = self.replay_memory.sample_mini_batch(
            batch_size=self.batch_size)
        batch_states = batch_states.to(self.device)
        batch_actions = batch_actions.to(self.device)
        batch_rewards = batch_rewards.to(self.device)
        batch_new_states = batch_new_states.to(self.device)
        batch_dones = batch_dones.to(self.device)
        critic_loss_per_agent = []
        actor_loss_per_agent = []
        for idx in range(len(self.actors)):
            actor = self.actors[idx]
            critic = self.critics[idx]
            old_actor = self.old_actors[idx]
            old_critic = self.old_critics[idx]
            actor_optimizer = self.actor_optimizers[idx]
            critic_optimizer = self.critic_optimizers[idx]

            # update critic
            predict_Q = critic(state=batch_states,
                               actions=batch_actions).squeeze(-1)
            old_actor_actions = old_actor(batch_new_states)

            target_actions = batch_actions.clone().detach()
            target_actions[:, idx, :] = old_actor_actions
            target_actions = convert_to_onehot(target_actions,
                                               epsilon=self.epsilon)
            target_Q = self.gamma * old_critic(
                state=batch_new_states, actions=target_actions).squeeze(-1) * (
                    1 - batch_dones) + batch_rewards
            c_loss = self.critic_loss(input=predict_Q,
                                      target=target_Q.detach())
            c_loss.backward()
            torch.nn.utils.clip_grad_norm_(critic.parameters(), 0.5)
            critic_optimizer.step()
            critic_optimizer.zero_grad()
            critic_loss_per_agent.append(c_loss.item())

            # update actor
            actor_actions = actor(batch_states)
            actor_actions = gumbel_softmax(actor_actions, hard=True)
            predict_actions = batch_actions.clone().detach()
            predict_actions[:, idx, :] = actor_actions
            a_loss = -critic(state=batch_states,
                             actions=predict_actions).squeeze(-1)
            a_loss = a_loss.mean()
            a_loss.backward()
            torch.nn.utils.clip_grad_norm_(actor.parameters(), 0.5)
            actor_optimizer.step()
            actor_optimizer.zero_grad()
            actor_loss_per_agent.append(a_loss.item())
        return sum(actor_loss_per_agent) / len(actor_loss_per_agent), sum(
            critic_loss_per_agent) / len(critic_loss_per_agent)
Example 15
 def select_action(self, state, temperature=None, is_tensor=False, is_target=False):
     # TODO after finished: add temperature to Gumbel sampling
     # __import__('ipdb').set_trace()
     st = state
     if not is_tensor:
         st = torch.from_numpy(state).view(1, -1).float().to(device)
     if is_target:
         action = self.policy_targ(st)
     else:
         # __import__('ipdb').set_trace()
         action = self.policy(st)
     action_with_noise = gumbel_softmax(action, hard=True).detach()
     return action_with_noise
Example 16
 def select_action(self, state, temperature=None, is_tensor=False, is_target=False):
     if self.use_warmup and self.action_count < WARMUP_STEPS:
         self.action_count += 1
         # TODO after finished: add temperature to Gumbel sampling
         # TODO: add warmup steps to SAC
         
     st = state
     if not is_tensor:
         st = torch.from_numpy(state).view(1, -1).float().to(device)
     if is_target:
         action = self.policy_targ(st)
         # action = self.policy_targ(st)
     else:
         # __import__('ipdb').set_trace()
         action = self.policy(st)
     action_with_noise = gumbel_softmax(action, hard=True).detach()
     return action_with_noise
Example 17
    def forward(self, inputs, rel_rec, rel_send, tau, hard, pred_steps):
        # NOTE: Assumes that we have the same graph across all samples.
        #         logits = self.rel_graph # (inputs, rel_rec, rel_send)
        edges = gumbel_softmax(self.rel_graph, tau, hard)

        inputs = inputs.transpose(1, 2).contiguous()

        #         sizes = [rel_type.size(0), inputs.size(1), rel_type.size(1),
        #                  rel_type.size(2)]
        # NOTE: Assumes rel_type is constant (i.e. same across all time steps).
        #         rel_type = rel_type.unsqueeze(1).expand(sizes)

        time_steps = inputs.size(1)
        assert (pred_steps <= time_steps)
        preds = []

        # Only take every n-th timestep as a starting point (n = pred_steps)
        last_pred = inputs[:, 0::pred_steps, :, :]

        #         curr_rel_type = rel_type[:, 0::pred_steps, :, :]
        # NOTE: Assumes rel_type is constant (i.e. same across all time steps).

        # Run n prediction steps
        for step in range(0, pred_steps):
            last_pred = self.single_step_forward(last_pred, rel_rec, rel_send,
                                                 edges)
            preds.append(last_pred)

        sizes = [
            preds[0].size(0), preds[0].size(1) * pred_steps, preds[0].size(2),
            preds[0].size(3)
        ]

        output = Variable(torch.zeros(sizes))
        if inputs.is_cuda:
            output = output.cuda()

        # Re-assemble correct timeline
        for i in range(len(preds)):
            # preds[i] is the i-th prediction step of each pred_steps-long segment
            output[:, i::pred_steps, :, :] = preds[i]
        pred_all = output[:, :(inputs.size(1) - 1), :, :]

        return pred_all.transpose(1, 2).contiguous(), \
    self.rel_graph.squeeze(1).expand([inputs.size(0), self.num_nodes*(self.num_nodes-1), self.edge_types])
Example 18
    def forward(self, input, img_emb, lengths):
        lengths, sorted_idx = torch.sort(lengths, descending=True)

        if len(input.size()) > 2:
            input = input[sorted_idx]
            img_emb = img_emb[sorted_idx]

            if not self.use_gumbel_generator:
                gumbel = torch.zeros_like(input)
                for i in range(input.size(1)):
                    gumbel[:, i] = gumbel_softmax(input[:, i].squeeze(),
                                                  tau=0.5).unsqueeze(1)
                input = gumbel

            input_emb = torch.mm(input.view(-1,input.size(-1)), self.embedding.weight)\
                .view(input.size(0),-1, self.embedding.embedding_dim)
        else:
            input_emb = self.embedding(input)
            # input_emb = add_gaussian(input_emb, std=0.01)

        input_emb = self.word_dropout(input_emb)
        packed_input = rnn_utils.pack_padded_sequence(input_emb,
                                                      lengths.data.tolist(),
                                                      batch_first=True)

        outputs, hidden = self.rnn(packed_input, img_emb.unsqueeze(0))

        # process outputs
        outputs = rnn_utils.pad_packed_sequence(outputs, batch_first=True)[0]
        # get the last time step for each sequence
        idx = (lengths - 1).view(-1, 1).expand(outputs.size(0),
                                               outputs.size(2)).unsqueeze(1)
        decoded = outputs.gather(1, idx).squeeze()
        # lengths, sorted_idx = torch.sort(lengths, descending=True)

        img = self.img_embedding(img_emb[:, :self.masked_size])

        # augmented = torch.cat([decoded ,img],1)
        # augmented = torch.cat([outputs[:,-1,:].squeeze() ,img],1)
        # augmented = torch.cat([hidden.squeeze() ,img],1)
        pred = self.fc(decoded)

        return torch.sigmoid(pred).mean(0), hidden  # .view(1)
Example 19
    def forward(self, image):
        batch_size = image.shape[0]

        h_img = self.beholder(image).detach()

        start = [self.w2i["<BOS>"] for _ in range(batch_size)]
        gen_idx = []
        done = np.array([False for _ in range(batch_size)])

        h_img = h_img.unsqueeze(0).view(1, -1, self.D_hid).repeat(1, 1, 1)
        hid = h_img
        ft = torch.tensor(start, dtype=torch.long,
                          device=device).view(-1).unsqueeze(1)
        input = self.emb(ft)
        msg_lens = [self.seq_len for _ in range(batch_size)]

        for idx in range(self.seq_len):
            input = F.relu(input)
            self.rnn.flatten_parameters()
            output, hid = self.rnn(input, hid)

            output = output.view(-1, self.D_hid)
            output = self.hid_to_voc(output)
            output = output.view(-1, self.vocab_size)

            top1, topi = U.gumbel_softmax(output, self.temp, self.hard)
            gen_idx.append(top1)

            for ii in range(batch_size):
                if topi[ii] == self.w2i["<EOS>"]:
                    done[ii] = True
                    msg_lens[ii] = idx + 1
                if np.array_equal(done,
                                  np.array([True for _ in range(batch_size)])):
                    break

            input = self.emb(topi)

        gen_idx = torch.stack(gen_idx).permute(1, 0, 2)
        msg_lens = torch.tensor(msg_lens, dtype=torch.long, device=device)
        return gen_idx, msg_lens
Example 20
    def test_prior(self, data):
        batch_size = data.shape[0]

        input_embs = self.sender_embedding(data)
        inputs = input_embs.view(
            batch_size,
            self.config['num_digits'] * self.config['embedding_size_sender'])
        hx = torch.zeros(batch_size,
                         self.config['num_lstm_sender'],
                         device=self.config['device'])
        cx = torch.zeros(batch_size,
                         self.config['num_lstm_sender'],
                         device=self.config['device'])

        samples = []
        log_probs = 0
        post_probs = 0
        for num in range(self.config['num_binary_messages']):
            hx, cx = self.sender_cell(inputs, (hx, cx))
            output = self.sender_project(hx)
            pre_logits = self.sender_out(output)
            posterior_prob = torch.log_softmax(pre_logits, -1)
            sample = utils.gumbel_softmax(pre_logits, self.temperature[num],
                                          self.config['device'])
            samples.append(sample)

            maxz = torch.argmax(sample, dim=-1, keepdim=True)
            h_z = torch.zeros(sample.shape,
                              device=self.config['device']).scatter_(
                                  -1, maxz, 1)
            prior_dst = dists.OneHotCategorical(logits=self.prior[num])
            log_prob = prior_dst.log_prob(h_z).detach().cpu().numpy()
            log_probs += log_prob
            post_probs += posterior_prob[torch.arange(batch_size),
                                         maxz.squeeze()]

        samples = torch.stack(samples).permute(1, 0, 2)
        prior_prob = log_probs / self.config['num_binary_messages']
        post_prob = post_probs.detach().cpu().numpy(
        ) / self.config['num_binary_messages']
        return post_prob, prior_prob, samples
Example 21
    def reconstruct(self, gen_images, gen_captions, gen_len):
        sorted_lengths, sorted_idx = torch.sort(gen_len, descending=True)
        gen_captions = gen_captions[sorted_idx]
        gumbel = torch.zeros_like(gen_captions)
        for i in range(gen_captions.size(1)):
            gumbel[:, i] = gumbel_softmax(gen_captions[:, i].squeeze(),
                                          tau=0.5).unsqueeze(1)

        gumbel_emb = torch.mm(gumbel.view(-1,gumbel.size(-1)), self.embedding.weight)\
                .view(gumbel.size(0),-1, self.embedding.embedding_dim)

        img_enc = self.img_encoder_forward(gen_images)
        txt_enc = self.txt_encoder_forward(gumbel_emb)

        _, reversed_idx = torch.sort(sorted_idx)
        txt_enc = txt_enc[reversed_idx]

        img_mu, img_logv, img_z = self.Hidden2Z_img(img_enc)
        txt_mu, txt_logv, txt_z = self.Hidden2Z_txt(txt_enc)

        return img_mu, img_logv, img_z, txt_mu, txt_logv, txt_z
Example 22
    def forward(self, query, key, value, mask=None, tau=1):
        dot_products = (query.unsqueeze(2) * key.unsqueeze(1)).sum(
            -1)  # batch x query_len x key_len

        if self.attend_mode == "only_attend_front":
            assert query.shape[1] == key.shape[1]
            tri = cuda(torch.ones(key.shape[1], key.shape[1]).triu(1),
                       device=query) * 1e9
            dot_products = dot_products - tri.unsqueeze(0)
        elif self.attend_mode == "only_attend_back":
            assert query.shape[1] == key.shape[1]
            tri = cuda(torch.ones(key.shape[1], key.shape[1]).tril(1),
                       device=query) * 1e9
            dot_products = dot_products - tri.unsqueeze(0)
        elif self.attend_mode == "not_attend_self":
            assert query.shape[1] == key.shape[1]
            eye = cuda(torch.eye(key.shape[1]), device=query) * 1e9
            dot_products = dot_products - eye.unsqueeze(0)

        if self.window > 0:
            assert query.shape[1] == key.shape[1]
            window_mask = cuda(torch.ones(key.shape[1], key.shape[1]),
                               device=query)
            window_mask = (window_mask.triu(self.window + 1) +
                           window_mask.tril(self.window + 1)) * 1e9
            dot_products = dot_products - window_mask.unsqueeze(0)

        if mask is not None:
            dot_products -= (1 - mask) * 1e9

        logits = dot_products / self.scale
        if self.gumbel_attend and self.training:
            probs = gumbel_softmax(logits, tau, dim=-1)
        else:
            probs = torch.softmax(logits, dim=-1)

        probs = probs * ((dot_products <= -5e8).sum(-1, keepdim=True) <
                         dot_products.shape[-1]).float()

        return torch.matmul(self.dropout(probs), value)
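
The final masking line in Examples 6 and 22 (and in the relative-attention variant in Example 25 below) zeroes attention rows whose keys were all masked out: softmax would otherwise spread uniform weight over fully-masked positions. A standalone sketch of the same guard:

import torch

dot_products = torch.tensor([[0.2, -1e9, -1e9],
                             [-1e9, -1e9, -1e9]])  # second row: every key masked
probs = torch.softmax(dot_products, dim=-1)
keep = ((dot_products <= -5e8).sum(-1, keepdim=True) <
        dot_products.shape[-1]).float()
probs = probs * keep
# row 0 keeps its attention weights, row 1 is zeroed out instead of uniform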
Example 23
 def select_action(self, state, temperature=None, is_tensor=False, is_target=False):
     self.action_count += 1
     if self.action_count < self.start_steps:
         # return random action:
         # self.action
         return self.random_action()
     # print("select action")
     st = state
     if not is_tensor:
         st = torch.from_numpy(state).view(1, -1).float().to(device)
     if is_target:
         action = self.policy_targ(st)
         # action = self.policy_targ(st)
     else:
         # __import__('ipdb').set_trace()
         
         # print("not target")
         action = self.policy(st)
         noise = (self.act_noise**0.5)*torch.randn(action.shape)
         # __import__('ipdb').set_trace()
         action += noise
     action_with_noise = gumbel_softmax(action, hard=True).detach()
     # __import__('ipdb').set_trace()
     return action_with_noise
Example 24
    def generator(self,
                  input,
                  input_step,
                  input_size,
                  hidden_size,
                  batch_size,
                  reuse=False):
        with tf.variable_scope("generator") as scope:
            # lstm cell and wrap with dropout
            g_lstm_cell = tf.contrib.rnn.BasicLSTMCell(input_size,
                                                       forget_bias=0.0,
                                                       state_is_tuple=True)
            g_lstm_cell_1 = tf.contrib.rnn.BasicLSTMCell(input_size,
                                                         forget_bias=0.0,
                                                         state_is_tuple=True)

            g_lstm_cell_attention = tf.contrib.rnn.AttentionCellWrapper(
                g_lstm_cell, attn_length=10)
            g_lstm_cell_attention_1 = tf.contrib.rnn.AttentionCellWrapper(
                g_lstm_cell_1, attn_length=10)

            if self.attention == 1:
                g_lstm_cell_drop = tf.contrib.rnn.DropoutWrapper(
                    g_lstm_cell_attention, output_keep_prob=0.9)
                g_lstm_cell_drop_1 = tf.contrib.rnn.DropoutWrapper(
                    g_lstm_cell_attention_1, output_keep_prob=0.9)
            else:
                g_lstm_cell_drop = tf.contrib.rnn.DropoutWrapper(
                    g_lstm_cell, output_keep_prob=0.9)
                g_lstm_cell_drop_1 = tf.contrib.rnn.DropoutWrapper(
                    g_lstm_cell_1, output_keep_prob=0.9)

            g_cell = tf.contrib.rnn.MultiRNNCell(
                [g_lstm_cell_drop, g_lstm_cell_drop_1], state_is_tuple=True)
            g_state_ = g_cell.zero_state(batch_size, tf.float32)
            # g_W_o = utils.glorot([hidden_size, input_size])
            # g_b_o = tf.Variable(tf.random_normal([input_size]))

            # neural network
            g_outputs = []
            g_state = g_state_
            for i in range(input_step):
                if i > 0: tf.get_variable_scope().reuse_variables()
                (g_cell_output, g_state) = g_cell(
                    input[:, i, :],
                    g_state)  # cell_out: [batch_size, hidden_size]
                g_outputs.append(
                    g_cell_output
                )  # output: shape[input_step][batch_size, hidden_size]

            if self.gumbel == 0:
                g_output = tf.reshape(tf.concat(g_outputs, axis=1),
                                      [-1, input_size])
                g_y_soft = tf.nn.softmax(g_output)
                self.z_ = tf.reshape(g_y_soft,
                                     [batch_size, input_step, input_size])
            else:
                g_output = tf.reshape(tf.concat(g_outputs, axis=1),
                                      [batch_size, input_step, input_size])
                self.z_ = utils.gumbel_softmax(g_output,
                                               self.temperature,
                                               hard=True)

            # concatenate input and output of rnn
            x = tf.concat([input, self.z_], axis=1)
            return x
Example 25
    def forward(self, query, key, value, mask=None, tau=1):
        dot_products = (query.unsqueeze(2) * key.unsqueeze(1)).sum(
            -1)  # batch x query_len x key_len

        if self.relative_clip:
            dot_relative = torch.einsum(
                "ijk,tk->ijt", query,
                self.key_relative.weight)  # batch * query_len * relative_size

            batch_size, query_len, key_len = dot_products.shape

            diag_dim = max(query_len, key_len)
            if self.diag_id.shape[0] < diag_dim:
                self.diag_id = np.zeros((diag_dim, diag_dim))
                for i in range(diag_dim):
                    for j in range(diag_dim):
                        if i <= j - self.relative_clip:
                            self.diag_id[i, j] = 0
                        elif i >= j + self.relative_clip:
                            self.diag_id[i, j] = self.relative_clip * 2
                        else:
                            self.diag_id[i, j] = i - j + self.relative_clip
            diag_id = LongTensor(self.diag_id[:query_len, :key_len])

            dot_relative = reshape(
                dot_relative, "bld", "bl_d",
                key_len).gather(-1,
                                reshape(diag_id, "lm", "_lm_", batch_size,
                                        -1))[:, :, :,
                                             0]  # batch * query_len * key_len
            dot_products = dot_products + dot_relative

        if self.attend_mode == "only_attend_front":
            assert query.shape[1] == key.shape[1]
            tri = cuda(torch.ones(key.shape[1], key.shape[1]).triu(1),
                       device=query) * 1e9
            dot_products = dot_products - tri.unsqueeze(0)
        elif self.attend_mode == "only_attend_back":
            assert query.shape[1] == key.shape[1]
            tri = cuda(torch.ones(key.shape[1], key.shape[1]).tril(1),
                       device=query) * 1e9
            dot_products = dot_products - tri.unsqueeze(0)
        elif self.attend_mode == "not_attend_self":
            assert query.shape[1] == key.shape[1]
            eye = cuda(torch.eye(key.shape[1]), device=query) * 1e9
            dot_products = dot_products - eye.unsqueeze(0)

        if self.window > 0:
            assert query.shape[1] == key.shape[1]
            window_mask = cuda(torch.ones(key.shape[1], key.shape[1]),
                               device=query)
            window_mask = (window_mask.triu(self.window + 1) +
                           window_mask.tril(self.window + 1)) * 1e9
            dot_products = dot_products - window_mask.unsqueeze(0)

        if mask is not None:
            dot_products -= (1 - mask) * 1e9

        logits = dot_products / self.scale
        if self.gumbel_attend and self.training:
            probs = gumbel_softmax(logits, tau, dim=-1)
        else:
            probs = torch.softmax(logits, dim=-1)

        probs = probs * (
            (dot_products <= -5e8).sum(-1, keepdim=True) <
            dot_products.shape[-1]).float()  # batch_size * query_len * key_len
        probs = self.dropout(probs)

        res = torch.matmul(probs, value)  # batch_size * query_len * d_value

        if self.relative_clip:
            if self.recover_id.shape[0] < query_len:
                self.recover_id = np.zeros((query_len, self.relative_size))
                for i in range(query_len):
                    for j in range(self.relative_size):
                        self.recover_id[i, j] = i + j - self.relative_clip
            recover_id = LongTensor(self.recover_id[:key_len])
            recover_id[recover_id < 0] = key_len
            recover_id[recover_id >= key_len] = key_len

            probs = torch.cat([probs, zeros(batch_size, query_len, 1)], -1)
            relative_probs = probs.gather(
                -1,
                reshape(recover_id, "qr", "_qr",
                        batch_size))  # batch_size * query_len * relative_size
            res = res + torch.einsum(
                "bqr,rd->bqd", relative_probs,
                self.value_relative.weight)  # batch_size * query_len * d_value

        return res
Example 26
    def train_td3(self, batch_size):
        self.n_updates += 1
        batch = self.memory.sample(min(batch_size, len(self.memory)))
        states_i, actions_i, rewards_i, next_states_i, dones_i = batch
        # __import__('ipdb').set_trace()
        if self.use_maddpg:
            states_all = torch.cat(states_i, 1)
            next_states_all = torch.cat(next_states_i, 1)
            actions_all = torch.cat(actions_i, 1)
        for i, agent in enumerate(self.agents):
            # print("training_qnet")
            if not self.use_maddpg:
                states_all = states_i[i]
                next_states_all = next_states_i[i]
                actions_all = actions_i[i]
            if self.use_maddpg:  
                next_actions_all = [ag.policy(next_state)
                                    for ag, next_state in zip(self.agents, next_states_i)]

                [self.batch_add_random_acts(e, i) for i, e in enumerate(next_actions_all)]
                next_actions_all = [onehot_from_logits(e) for e in next_actions_all]
            else:
                actions_and_logits = [onehot_from_logits(agent.policy(next_states_i[i]))]
                next_actions_all = [e[0] for e in actions_and_logits]
            total_obs = torch.cat([next_states_all, torch.cat(next_actions_all, 1)], 1)
            qnet_targs = []
            for qnet in self.agents[i].qnet_targs:
                qnet_targs.append(qnet(total_obs).detach())
            rewards = rewards_i[i].view(-1, 1)
            dones = dones_i[i].view(-1, 1)
            qnet_mins = torch.min(qnet_targs[0], qnet_targs[1])
            target_q = rewards + (1 - dones) * GAMMA * (qnet_mins)
            losses = []
            for j, qnet in enumerate(self.agents[i].qnets):
                input_q = qnet(torch.cat([states_all, actions_all], 1))
                self.agents[i].q_optimizers[j].zero_grad()
                loss = self.criterion(input_q, target_q.detach())
                losses.append(loss.item())
                loss.backward()
                # torch.nn.utils.clip_grad_norm_(qnet.parameters(), 0.5)
                self.agents[i].q_optimizers[j].step()


        if self.args.use_writer:
            self.writer.add_scalar(f"Agent_{i}: q_net_loss: ", np.mean(losses), self.n_updates)
        if self.n_updates % 2 == 0:
            for i in range(self.n_agents):
                # print("training policy")
                actor_loss = 0
                # ACTOR gradient ascent of Q(s, π(s | ø)) with respect to ø
                # use gumbel softmax max temp trick
                policy_out = self.agents[i].policy(states_i[i])
                gumbel_sample = gumbel_softmax(policy_out, hard=True)
                if self.use_maddpg:
                    actions_curr_pols = [onehot_from_logits(agent_.policy(state))
                                         for agent_, state in zip(self.agents, states_i)]

                    for action_batch in actions_curr_pols:
                        action_batch.detach_()
                    actions_curr_pols[i] = gumbel_sample
                    actor_loss = - self.agents[i].qnets[0](torch.cat([states_all.detach(),
                                                       torch.cat(actions_curr_pols, 1)], 1)).mean()
                else:
                    actor_loss = - self.agents[i].qnets[0](torch.cat([states_all.detach(),
                                                       gumbel_sample], 1)).mean()
                self.agents[i].p_optimizer.zero_grad()
                actor_loss.backward()
                torch.nn.utils.clip_grad_norm_(self.agents[i].policy.parameters(), 0.5)
                self.agents[i].p_optimizer.step()
                actions_i[i].detach_()
                if self.args.use_writer:
                    self.writer.add_scalar(f"Agent_{i}: policy_objective: ", actor_loss.item(), self.n_updates)
                self.update_all_targets()
Example 27
    def sample_and_train_sac(self, batch_size):
        # TODO ADD Model saving, optimize code
        batch = self.memory.sample(min(batch_size, len(self.memory)))
        states_i, actions_i, rewards_i, next_states_i, dones_i = batch
        # __import__('ipdb').set_trace()        
        if self.use_maddpg:
            states_all = torch.cat(states_i, 1)
            next_states_all = torch.cat(next_states_i, 1)
            actions_all = torch.cat(actions_i, 1)
        for i, agent in enumerate(self.agents):
            if not self.use_maddpg:
                states_all = states_i[i]
                next_states_all = next_states_i[i]
                actions_all = actions_i[i]
            if self.use_maddpg:  
                actions_and_logits = [onehot_from_logits(ag.policy(next_state), logprobs=True)
                                    for ag, next_state in zip(self.agents, next_states_i)]

                next_actions_all = [e[0] for e in actions_and_logits]
                next_logits_all = [self.sac_alpha*e[1] for e in actions_and_logits]
                # __import__('ipdb').set_trace()
            else:
                actions_and_logits = [onehot_from_logits(agent.policy(next_states_i[i]),
                                                       logprobs=True)]
                next_actions_all = [e[0] for e in actions_and_logits]
                next_logits_all = [self.sac_alpha*e[1] for e in actions_and_logits]
                
            # computing target
            total_obs = torch.cat([next_states_all, torch.cat(next_actions_all, 1)], 1)
            
            # target_q = self.agents[i].qnet_targ(total_obs).detach()
            qnet_targs = []
            for qnet in self.agents[i].qnet_targs:
                qnet_targs.append(qnet(total_obs).detach())
            rewards = rewards_i[i].view(-1, 1)
            dones = dones_i[i].view(-1, 1)
            qnet_mins = torch.min(qnet_targs[0], qnet_targs[1])
            # __import__('ipdb').set_trace()
            logits_idx = i if self.use_maddpg else 0
            logits_agent = next_logits_all[logits_idx]
            # if len(qnet_mins.squeeze(-1)) != len(logits_agent.squeeze(-1)):
            #     __import__('ipdb').set_trace()
            target_q = rewards + (1 - dones) * GAMMA * (qnet_mins -
                                     logits_agent.reshape(qnet_mins.shape))
            # __import__('ipdb').set_trace()
            # computing the inputs
            for j, qnet in enumerate(self.agents[i].qnets):
                input_q = qnet(torch.cat([states_all, actions_all], 1))
                self.agents[i].q_optimizers[j].zero_grad()
                # print("----")
                # __import__('ipdb').set_trace() 
                loss = self.criterion(input_q, target_q.detach())
                # print('after')
                loss.backward()
                torch.nn.utils.clip_grad_norm_(qnet.parameters(), 0.5)
                self.agents[i].q_optimizers[j].step()

            # __import__('ipdb').set_trace()
            actor_loss = 0
            # ACTOR gradient ascent of Q(s, π(s | ø)) with respect to ø
            # use gumbel softmax max temp trick
            policy_out = self.agents[i].policy(states_i[i])
            gumbel_sample, act_logprobs = gumbel_softmax(policy_out, hard=True, logprobs=True)
            act_logprobs = self.sac_alpha*act_logprobs
            # __import__('ipdb').set_trace() 
            if self.use_maddpg:
                with torch.no_grad():
                    actions_curr_pols = [onehot_from_logits(agent_.policy(state))
                                         for agent_, state in zip(self.agents, states_i)]
                actions_curr_pols[i] = gumbel_sample
                total_obs = torch.cat([states_all, torch.cat(actions_curr_pols, 1)], 1)
                qnet_outs = []
                for qnet in self.agents[i].qnets:
                    qnet_outs.append(qnet(total_obs))
                qnet_mins = torch.min(qnet_outs[0], qnet_outs[1])
                actor_loss = - qnet_mins.mean()
                # __import__('ipdb').set_trace()
            else:
                # actor_loss = - self.agents[i].qnet(torch.cat([states_all.detach(),
                #                                    gumbel_sample], 1)).mean()
                # actions_curr_pols[i] = gumbel_sample
                # __import__('ipdb').set_trace()
                total_obs = torch.cat([states_all, gumbel_sample], 1)
                qnet_outs = []
                for qnet in self.agents[i].qnets:
                    qnet_outs.append(qnet(total_obs))
                qnet_mins = torch.min(qnet_outs[0], qnet_outs[1])
                actor_loss = - qnet_mins.mean()
            # actor_loss += (policy_out**2).mean() * 1e-3

            self.agents[i].p_optimizer.zero_grad()
            actor_loss.backward()
            # nn.utils.clip_grad_norm_(self.policy.parameters(), 5)
            # torch.nn.utils.clip_grad_norm_(self.agents[i].policy.parameters(), 0.5)
            self.agents[i].p_optimizer.step()
            # detach the forward propagated action samples
            actions_i[i].detach_()
            # __import__('ipdb').set_trace()
            if self.args.use_writer:
                self.writer.add_scalars("Agent_%i" % i, {
                    "vf_loss": loss,
                    "actor_loss": actor_loss
                }, self.n_updates)
        
        self.update_all_targets()
        self.n_updates += 1
Example 28
    def gumbel_decoder(self, z=None):

        batch_size = self.batch_size

        hidden = self.latent2hidden(z)

        if self.bidirectional or self.num_layers > 1:
            # unflatten hidden state
            hidden = hidden.view(self.hidden_factor, batch_size,
                                 self.hidden_size)

        hidden = hidden.unsqueeze(0)

        # required for dynamic stopping of sentence generation
        sequence_idx = torch.arange(
            0, batch_size, out=self.tensor()).long()  # all idx of batch
        sequence_running = torch.arange(0, batch_size, out=self.tensor()).long(
        )  # all idx of batch which are still generating
        sequence_mask = torch.ones(batch_size, out=self.tensor()).byte()

        running_seqs = torch.arange(0, batch_size, out=self.tensor()).long(
        )  # idx of still generating sequences with respect to current loop

        generations = self.tensor(batch_size, self.max_sequence_length).fill_(
            self.pad_idx).long()

        t = 0
        while (t < self.max_sequence_length and len(running_seqs) > 0):

            if t == 0:
                input_sequence = to_var(
                    torch.Tensor(batch_size).fill_(self.sos_idx).long())

            input_sequence = input_sequence.unsqueeze(1)

            input_embedding = self.embedding(input_sequence)

            output, hidden = self.decoder_rnn(input_embedding, hidden)

            logits = self.outputs2vocab(output)

            # input_sequence = self._sample(logits)
            input_sequence = gumbel_softmax(logits)

            # save next input
            generations = self._save_sample(generations, input_sequence,
                                            sequence_running, t)

            # update global running sequence
            sequence_mask[sequence_running] = (input_sequence !=
                                               self.eos_idx).data
            sequence_running = sequence_idx.masked_select(sequence_mask)

            # update local running sequences
            running_mask = (input_sequence != self.eos_idx).data
            running_seqs = running_seqs.masked_select(running_mask)

            # prune input and hidden state according to local update
            if len(running_seqs) > 0:
                input_sequence = input_sequence[running_seqs]
                hidden = hidden[:, running_seqs]

                running_seqs = torch.arange(0,
                                            len(running_seqs),
                                            out=self.tensor()).long()

            t += 1

        return generations, z
Example 29
    def sample_and_train(self, batch_size):
        # TODO ADD Model saving, optimize code
        batch = self.memory.sample(min(batch_size, len(self.memory)))

        states_i, actions_i, rewards_i, next_states_i, dones_i = batch

        states_all = torch.cat(states_i, 1)
        next_states_all = torch.cat(next_states_i, 1)
        actions_all = torch.cat(actions_i, 1)

        for i, agent in enumerate(self.agents):
            next_actions_all = [
                onehot_from_logits(ag.policy_targ(next_state))
                for ag, next_state in zip(self.agents, next_states_i)
            ]
            # computing target
            total_obs = torch.cat(
                [next_states_all,
                 torch.cat(next_actions_all, 1)], 1)
            target_q = self.agents[i].qnet_targ(total_obs).detach()
            rewards = rewards_i[i].view(-1, 1)
            dones = dones_i[i].view(-1, 1)
            target_q = rewards + (1 - dones) * GAMMA * target_q

            # computing the inputs
            input_q = self.agents[i].qnet(
                torch.cat([states_all, actions_all], 1))
            self.agents[i].q_optimizer.zero_grad()
            loss = self.criterion(input_q, target_q.detach())
            # print("LOSS", loss)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.agents[i].qnet.parameters(),
                                           0.5)
            self.agents[i].q_optimizer.step()
            actor_loss = 0
            # ACTOR gradient ascent of Q(s, π(s | ø)) with respect to ø

            # use gumbel softmax max temp trick
            policy_out = self.agents[i].policy(states_i[i])
            gumbel_sample = gumbel_softmax(policy_out, hard=True)

            actions_curr_pols = [
                onehot_from_logits(agent_.policy(state))
                for agent_, state in zip(self.agents, states_i)
            ]

            for action_batch in actions_curr_pols:
                action_batch.detach_()
            actions_curr_pols[i] = gumbel_sample

            actor_loss = -self.agents[i].qnet(
                torch.cat(
                    [states_all.detach(),
                     torch.cat(actions_curr_pols, 1)], 1)).mean()
            actor_loss += (policy_out**2).mean() * 1e-3

            self.agents[i].p_optimizer.zero_grad()
            actor_loss.backward()
            # nn.utils.clip_grad_norm_(self.policy.parameters(), 5)
            torch.nn.utils.clip_grad_norm_(self.agents[i].policy.parameters(),
                                           0.5)
            self.agents[i].p_optimizer.step()
            # detach the forward propagated action samples
            actions_i[i].detach_()

            if self.args.use_writer:
                self.writer.add_scalars("Agent_%i" % i, {
                    "vf_loss": loss,
                    "actor_loss": actor_loss
                }, self.n_updates)

        self.update_all_targets()
        self.n_updates += 1
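
Several of the MADDPG-style snippets above (Examples 26, 27 and 29) also call an onehot_from_logits helper to turn policy logits into discrete one-hot actions without a gradient path. Its exact behaviour is project-specific (some variants add epsilon-greedy exploration or return log-probabilities, as assumed in Example 27); a minimal sketch of the common greedy version:

import torch

def onehot_from_logits(logits):
    # Non-differentiable one-hot of the argmax, e.g. for target-policy actions
    index = logits.argmax(dim=-1, keepdim=True)
    return torch.zeros_like(logits).scatter_(-1, index, 1.0)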
Example 30
    def forward(self,
                data,
                rel_type,
                pred_steps=1,
                burn_in=False,
                burn_in_steps=1,
                dynamic_graph=False,
                encoder=None,
                temp=None):

        # inputs shape [B, T, K, num_dims]
        inputs = data.transpose(1, 2).contiguous()
        time_steps = inputs.size(1)

        # rel_type [B, 1 or T, K*(K-1), n_edge_types]
        if rel_type.size(1) == 1:
            # static graph case
            rt_shp = rel_type.size
            sizes = [rt_shp(0), inputs.size(1), rt_shp(-2), rt_shp(-1)]
            rel_type = rel_type.expand(sizes)

        zrs = torch.zeros(
            (1, inputs.size(0), inputs.size(2), self.msg_out_shape),
            device=inputs.device)
        hidden = (zrs, zrs)
        pred_all = []
        first_step = True
        rollout_zeros = self.rollout_zeros and \
                        (self.rollout_zeros_in_train or not self.training)
        for step in range(0, inputs.size(1) - 1):

            if burn_in:
                if step <= burn_in_steps or self.input_delta:
                    ins = inputs[:, step, :, :]
                    if self.input_delta and not first_step:
                        ins = ins - pred_all[step - 1]
                else:
                    ins = pred_all[step - 1]
            else:
                assert (pred_steps <= time_steps)
                # Use ground truth trajectory inputs vs. last prediction
                if not step % pred_steps or self.input_delta:
                    ins = inputs[:, step, :, :]
                    if self.input_delta:
                        if not first_step and not rollout_zeros:
                            ins = ins - pred_all[step - 1]
                        elif rollout_zeros:
                            ins = torch.zeros_like(ins, device=ins.device)
                else:
                    ins = pred_all[step - 1]

            if dynamic_graph and step >= burn_in_steps:
                # Note assumes burn_in_steps == args.timesteps
                logits = encoder(
                    data[:, :, step - burn_in_steps:step, :].contiguous(),
                    self.rr_full, self.rs_full)
                curr_rel_type = gumbel_softmax(logits, tau=temp, hard=True)
            else:
                curr_rel_type = rel_type[:, step, :, :]

            pred, hidden = self.single_step_forward(ins, hidden, curr_rel_type)
            pred_all.append(pred)
            first_step = False

        preds = torch.stack(pred_all, dim=1)

        return preds.transpose(1, 2).contiguous()