Example #1
File: dssm.py Project: blair101/SPO
    def collect_metrics(self, outputs):
        """
        collect_metrics
        """
        pos_logits = outputs.pos_logits
        pos_target = torch.ones_like(pos_logits)

        neg_logits = outputs.neg_logits
        neg_target = torch.zeros_like(neg_logits)

        pos_loss = F.binary_cross_entropy_with_logits(pos_logits,
                                                      pos_target,
                                                      reduction='none')
        neg_loss = F.binary_cross_entropy_with_logits(neg_logits,
                                                      neg_target,
                                                      reduction='none')

        loss = (pos_loss + neg_loss).mean()

        pos_acc = torch.sigmoid(pos_logits).gt(0.5).float().mean()
        neg_acc = torch.sigmoid(neg_logits).lt(0.5).float().mean()
        margin = (torch.sigmoid(pos_logits) - torch.sigmoid(neg_logits)).mean()
        metrics = Pack(loss=loss,
                       pos_acc=pos_acc,
                       neg_acc=neg_acc,
                       margin=margin)

        num_samples = pos_target.size(0)
        metrics.add(num_samples=num_samples)
        return metrics
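Note: every snippet on this page reads and writes results through a Pack. A minimal sketch of the container these examples assume (the real class in these projects may carry extra helpers such as cuda()):

class Pack(dict):
    """A dict with attribute access and a few conveniences."""
    def __getattr__(self, name):
        return self.get(name)

    def add(self, **kwargs):
        """Insert keyword arguments as entries."""
        for key, value in kwargs.items():
            self[key] = value

    def flatten(self):
        """Turn a Pack of equal-length sequences into a list of Packs."""
        pack_list = []
        for values in zip(*self.values()):
            pack_list.append(Pack(zip(self.keys(), values)))
        return pack_list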
Example #2
    def collect_rl_metrics(self, sample_outputs, greedy_outputs, target, gold_entity, entity_dir):
        """
        collect rl training metrics
        """
        num_samples = target.size(0)
        rl_metrics = Pack(num_samples=num_samples)
        loss = 0

        # log probs for the sampled and the greedy generations
        logits = sample_outputs.logits
        sample = sample_outputs.pred_word
        greedy_logits = greedy_outputs.logits
        greedy_sample = greedy_outputs.pred_word

        # compute rewards
        sample_reward, _, _ = self.reward_fn(sample, target, gold_entity, entity_dir)
        greedy_reward, bleu_score, f1_score = self.reward_fn(greedy_sample, target, gold_entity, entity_dir)
        reward = sample_reward - greedy_reward

        # compute the RL loss
        sample_log_prob = self.nll_loss(logits, sample, mask=sample.ne(self.padding_idx),
                                        reduction=False, matrix=False)  # [batch_size, max_len]
        nll = sample_log_prob * reward.to(sample_log_prob.device)
        nll = getattr(torch, self.nll_loss.reduction)(nll.sum(dim=-1))
        loss += nll

        # report metrics
        rl_acc = accuracy(greedy_logits, target, padding_idx=self.padding_idx)
        if reward.dim() == 2:
            reward = reward.sum(dim=-1)
        rl_metrics.add(loss=loss, reward=reward.mean(), rl_acc=rl_acc,
                       bleu_score=bleu_score.mean(), f1_score=f1_score.mean())

        return rl_metrics
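Example #2 is self-critical policy-gradient training: the sampled sequence's reward is baselined by the greedy decode's reward, so only samples that beat the greedy output get reinforced. A standalone sketch of the core loss with dummy tensors standing in for model outputs (all shapes and values here are assumptions):

import torch

log_probs = torch.randn(4, 6).clamp(max=0)   # per-token log p of the sampled tokens
pad_mask = torch.ones(4, 6)                  # 1 for real tokens, 0 for padding
sample_reward = torch.tensor([0.8, 0.2, 0.5, 0.9])
greedy_reward = torch.tensor([0.6, 0.4, 0.5, 0.7])

advantage = sample_reward - greedy_reward    # greedy decode acts as the baseline
# REINFORCE: minimize -(advantage * log p(sample)); the snippet above folds
# the minus sign into its nll_loss helper.
loss = (-(log_probs * pad_mask).sum(dim=-1) * advantage).mean()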
Example #3
    def decode(self, input, state, is_training=False):
        """
        decode
        """
        hidden = state.hidden
        rnn_input_list = []
        cue_input_list = []
        out_input_list = []
        output = Pack()

        if self.embedder is not None:
            input = self.embedder(input)

        # shape: (batch_size, 1, input_size)
        input = input.unsqueeze(1)
        rnn_input_list.append(input)
        cue_input_list.append(state.knowledge)

        if self.feature_size is not None:
            feature = state.feature.unsqueeze(1)
            rnn_input_list.append(feature)
            cue_input_list.append(feature)

        if self.attn_mode is not None:
            attn_memory = state.attn_memory
            attn_mask = state.attn_mask
            query = hidden[-1].unsqueeze(1)
            weighted_context, attn = self.attention(query=query,
                                                    memory=attn_memory,
                                                    mask=attn_mask)
            rnn_input_list.append(weighted_context)
            cue_input_list.append(weighted_context)
            out_input_list.append(weighted_context)
            output.add(attn=attn)

        rnn_input = torch.cat(rnn_input_list, dim=-1)
        rnn_output, rnn_hidden = self.rnn(rnn_input, hidden)

        cue_input = torch.cat(cue_input_list, dim=-1)
        cue_output, cue_hidden = self.cue_rnn(cue_input, hidden)

        h_y = self.tanh(self.fc1(rnn_hidden))
        h_cue = self.tanh(self.fc2(cue_hidden))
        if self.concat:
            new_hidden = self.fc3(torch.cat([h_y, h_cue], dim=-1))
        else:
            k = self.sigmoid(self.fc3(torch.cat([h_y, h_cue], dim=-1)))
            new_hidden = k * h_y + (1 - k) * h_cue
        # bug fix: take the last layer's hidden state, not the transposed stack
        out_input_list.append(new_hidden[-1].unsqueeze(1))

        out_input = torch.cat(out_input_list, dim=-1)
        state.hidden = new_hidden

        if is_training:
            return out_input, state, output
        else:
            log_prob = self.output_layer(out_input)
            return log_prob, state, output
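The non-concat branch above fuses the language state h_y with the knowledge state h_cue through a learned elementwise gate. The same fusion in isolation (sizes are made up for illustration):

import torch
import torch.nn as nn

hidden_size = 8
fc3 = nn.Linear(2 * hidden_size, hidden_size)

h_y = torch.randn(1, 2, hidden_size)    # (num_layers, batch, hidden) language state
h_cue = torch.randn(1, 2, hidden_size)  # knowledge ("cue") state

k = torch.sigmoid(fc3(torch.cat([h_y, h_cue], dim=-1)))  # gate in (0, 1)
new_hidden = k * h_y + (1 - k) * h_cue                   # convex combination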
Example #4
    def generate(self, batch_iter, num_batches):
        self.model.eval()
        itoemo = ['NORM', 'POS', 'NEG']
        with torch.no_grad():
            results = []
            for inputs in batch_iter:
                enc_inputs = inputs
                dec_inputs = inputs.num_tgt_input
                enc_outputs = Pack()
                outputs = self.model.forward(enc_inputs, dec_inputs, hidden=None)
                outputs = outputs.logits
                preds = outputs.max(dim=2)
                # news_id = inputs.id
                tgt_raw = inputs.raw_tgt
                preds = preds[1].tolist()

                temp_a_1 = []
                emo_b_1 = []
                temp = []
                tgt_emo = inputs.tgt_emo[0].tolist()
                for a, b, c in zip(tgt_raw, preds, tgt_emo):
                    a = a[1:]
                    temp_a = []
                    emo_b = []
                    emo_c = []

                    for i, entity in enumerate(a):
                        temp_a.append(entity)
                        emo_b.append(itoemo[b[i]])
                        emo_c.append(itoemo[c[i]])

                    assert len(temp_a) == len(emo_b)
                    assert len(emo_c) == len(emo_b)
                    temp_a_1.append([temp_a])  # pred1
                    emo_b_1.append([emo_b])    # emo
                    temp.append(emo_c)


                if hasattr(inputs, 'id') and inputs.id is not None:
                    enc_outputs.add(id=inputs['id'])
                enc_outputs.add(tgt=tgt_raw, preds=temp_a_1, emos=emo_b_1, target_emos=temp)
                result_batch = enc_outputs.flatten()
                results += result_batch
            return results
Example #5
File: hntm_tg.py Project: wizare/CMTE
    def encode(self, inputs, hidden=None, is_training=False):
        """
        encode
        """
        '''
        inputs: src, topic_src, topic_tgt, [tgt]
        '''
        outputs = Pack()
        enc_inputs = _, lengths = inputs.src[0][:, 1:-1], inputs.src[1] - 2

        enc_outputs, enc_hidden = self.encoder(enc_inputs, hidden)
        if self.with_bridge:
            enc_hidden = self.bridge(enc_hidden)

        guide_score = enc_hidden[-1]
        decoder_init_state = enc_hidden

        if self.use_ntm:
            bow_src = inputs.bow
            ntm_stat = self.ntm(bow_src)
            outputs.add(ntm_loss=ntm_stat['loss'])

        embedding_weight = self.topic_embedder.weight
        topic_word_logit = self.ntm.topics.get_topic_word_logit()
        EPS = 1e-12
        w = topic_word_logit.transpose(0, 1)  # V x K
        nv = embedding_weight / (
            torch.norm(embedding_weight, dim=1, keepdim=True) + EPS)  # V x dim
        nw = w / (torch.norm(w, dim=0, keepdim=True) + EPS)  # V x K
        t = nv.transpose(0, 1) @ nw  # dim x K (bug fix: use the normalized nw, previously computed but unused)
        t = t.transpose(0, 1)

        topic_feature_input = []
        if 'S' in self.decoder_attention_channels:
            src_labels = F.one_hot(inputs.topic_src_label,
                                   num_classes=self.topic_num)  # B * K
            src_topics = src_labels.float() @ t
            topic_feature_input.append(src_topics)
        if 'T' in self.decoder_attention_channels:
            tgt_labels = F.one_hot(inputs.topic_tgt_label,
                                   num_classes=self.topic_num)
            tgt_topics = tgt_labels.float() @ t
            topic_feature_input.append(tgt_topics)

        topic_feature = self.t_to_feature(
            torch.cat(topic_feature_input, dim=-1))

        dec_init_state = self.decoder.initialize_state(
            hidden=decoder_init_state,
            attn_memory=enc_outputs,
            memory_lengths=lengths,
            guide_score=guide_score,
            topic_feature=topic_feature)
        return outputs, dec_init_state
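The block that builds t computes, for each topic, an embedding as a similarity-weighted mix of word vectors: embedding rows are L2-normalized, topic-word columns are L2-normalized, and their product gives a K x dim topic matrix. The same computation with stand-in sizes (assuming, as in the fix above, that the normalized nw is the intended factor):

import torch

EPS = 1e-12
V, K, dim = 100, 10, 32                  # vocab size, topic count, embedding dim
embedding_weight = torch.randn(V, dim)   # word embeddings
w = torch.randn(V, K)                    # topic-word logits, V x K

nv = embedding_weight / (torch.norm(embedding_weight, dim=1, keepdim=True) + EPS)
nw = w / (torch.norm(w, dim=0, keepdim=True) + EPS)
t = (nv.transpose(0, 1) @ nw).transpose(0, 1)   # K x dim topic embeddings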
Example #6
    def decode(self, input, state, is_training=False):
        last_hidden = state.hidden
        rnn_input_list = []
        output = Pack()

        embed_q = self.C[0](input).unsqueeze(1)  # b * e --> b * 1 * e
        batch_size = input.size(0)

        rnn_input_list.append(embed_q)

        if self.attn_mode is not None:
            attn_memory = state.attn_memory
            attn_mask = state.attn_mask
            query = last_hidden[-1].unsqueeze(1)
            weighted_context, attn = self.attention(query=query,
                                                    memory=attn_memory,
                                                    mask=attn_mask)
            rnn_input_list.append(weighted_context)
            output.add(attn=attn)

        rnn_input = torch.cat(rnn_input_list, dim=-1)
        rnn_output, new_hidden = self.gru(rnn_input, last_hidden)  # bug fix: do not overwrite the output Pack
        state.hidden = new_hidden

        u = [new_hidden[0].squeeze()]
        for hop in range(self.max_hops):
            m_A = self.m_story[hop]
            m_A = m_A[:batch_size]
            if u[-1].dim() == 1:
                u[-1] = u[-1].unsqueeze(0)  # handles batch_size = 1
            u_temp = u[-1].unsqueeze(1).expand_as(m_A)
            prob_lg = torch.sum(m_A * u_temp, 2)
            prob_p = self.log_softmax(prob_lg)
            m_C = self.m_story[hop + 1]
            m_C = m_C[:batch_size]

            prob = prob_p.unsqueeze(2).expand_as(m_C)
            o_k = torch.sum(m_C * prob, 1)
            if hop == 0:
                p_vocab = self.W1(torch.cat((u[0], o_k), 1))
                prob_v = self.log_softmax(p_vocab)
            u_k = u[-1] + o_k
            u.append(u_k)
        p_ptr = prob_lg

        if is_training:
            # p_ptr and p_vocab are pre-softmax values, not probabilities
            return p_vocab, state, output
        else:
            return prob_v, state, output
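Each iteration of the hop loop is a standard end-to-end memory-network read: address the key memory m_A with the query, read from the value memory m_C with the resulting weights, and add the read vector back into the query. One hop in isolation (the snippet above applies log_softmax before the read; the canonical read uses softmax, which this sketch follows):

import torch
import torch.nn.functional as F

batch, mem_size, hidden = 4, 7, 16
m_A = torch.randn(batch, mem_size, hidden)   # key memory for this hop
m_C = torch.randn(batch, mem_size, hidden)   # value memory for this hop
u = torch.randn(batch, hidden)               # current query

prob_lg = torch.sum(m_A * u.unsqueeze(1).expand_as(m_A), 2)  # (batch, mem_size)
prob = F.softmax(prob_lg, dim=-1)                            # addressing weights
o_k = torch.sum(m_C * prob.unsqueeze(2).expand_as(m_C), 1)   # read vector
u_next = u + o_k                                             # query for the next hop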
Example #7
    def collect_metrics(self, outputs, target):
        num_samples = target.size(0)
        metrics = Pack(num_samples=num_samples)
        loss = 0

        logits = outputs.logits
        nll = self.nll_loss(logits, target)
        num_words = target.ne(self.padding_idx).sum().item()
        acc = accuracy(logits, target, padding_idx=self.padding_idx)
        metrics.add(nll=(nll, num_words), acc=acc)
        loss += nll

        metrics.add(loss=loss)
        return metrics
Example #8
File: dssm.py Project: yunying24/Dialogue
    def forward(self, src_inputs, pos_tgt_inputs, neg_tgt_inputs,
                src_hidden=None, tgt_hidden=None):
        outputs = Pack()
        src_hidden = self.src_encoder(src_inputs, src_hidden)[1][-1]
        if self.with_project:
            src_hidden = self.project(src_hidden)

        pos_tgt_hidden = self.tgt_encoder(pos_tgt_inputs, tgt_hidden)[1][-1]
        neg_tgt_hidden = self.tgt_encoder(neg_tgt_inputs, tgt_hidden)[1][-1]
        pos_logits = (src_hidden * pos_tgt_hidden).sum(dim=-1)
        neg_logits = (src_hidden * neg_tgt_hidden).sum(dim=-1)
        outputs.add(pos_logits=pos_logits, neg_logits=neg_logits)

        return outputs
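The forward pass scores a (source, target) pair as the dot product of their final encoder states; pairing each source with one positive and one negative target yields exactly the logits that Example #1 feeds into its two binary cross-entropy terms. The scoring step with dummy encodings:

import torch

src_hidden = torch.randn(4, 16)        # (batch, hidden) source encodings
pos_tgt_hidden = torch.randn(4, 16)    # encodings of the true responses
neg_tgt_hidden = torch.randn(4, 16)    # encodings of sampled negatives

pos_logits = (src_hidden * pos_tgt_hidden).sum(dim=-1)  # (batch,)
neg_logits = (src_hidden * neg_tgt_hidden).sum(dim=-1)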
Example #9
    def encode(self, inputs, hidden=None, is_training=False):
        """
        encode
        """
        '''
        inputs: src, topic_src, topic_tgt, [tgt]
        '''
        outputs = Pack()
        enc_inputs = _, lengths = inputs.src[0][:, 1:-1], inputs.src[1] - 2

        enc_outputs, enc_hidden = self.encoder(enc_inputs, hidden)
        if self.with_bridge:
            enc_hidden = self.bridge(enc_hidden)

        guide_score = enc_hidden[-1]
        decoder_init_state = enc_hidden

        bow_src = inputs.bow
        ntm_stat = self.ntm(bow_src)
        outputs.add(ntm_loss=ntm_stat['loss'])

        # obtain topic words
        _, tw_indices = self.ntm.get_topics().topk(self.topic_k,
                                                   dim=1)  # K * k

        src_labels = F.one_hot(inputs.topic_src_label,
                               num_classes=self.topic_num)  # B * K
        tgt_labels = F.one_hot(inputs.topic_tgt_label,
                               num_classes=self.topic_num)
        src_words = src_labels.float() @ tw_indices.float()  # B * k
        src_words = src_words.detach().long()
        tgt_words = tgt_labels.float() @ tw_indices.float()
        tgt_words = tgt_words.detach().long()

        # only src topic word
        src_outputs = self.topic_layer(src_words)  # b * k * h

        # only tgt topic word
        tgt_outputs = self.topic_layer(tgt_words)  # b * k * h

        dec_init_state = self.decoder.initialize_state(
            hidden=decoder_init_state,
            attn_memory=enc_outputs,
            memory_lengths=lengths,
            guide_score=guide_score,
            src_memory=src_outputs,
            tgt_memory=tgt_outputs,
        )
        return outputs, dec_init_state
Example #10
    def decode(self, input, state, is_training=False):
        """
        decode
        """
        hidden = state.hidden
        input_feed = state.input_feed
        output = Pack()

        if self.embedder is not None:
            input = self.embedder(input)

        # shape: (batch_size, 1, input_size)
        input = input.unsqueeze(1)
        rnn_input = torch.cat([input, input_feed], dim=-1)
        rnn_output, new_hidden = self.rnn(rnn_input, hidden)

        attn_memory = state.attn_memory
        query = new_hidden[0][-1].unsqueeze(1)
        weighted_context, attn = self.attention(query=query,
                                                memory=attn_memory,
                                                mask=state.mask.eq(0))

        final_output = torch.cat([query, weighted_context], dim=-1)
        final_output = self.fc1(final_output)

        output.add(attn=attn)

        state.hidden = list(new_hidden)
        state.input_feed = final_output
        out_input = final_output

        if is_training:
            return out_input, state, output
        else:
            log_prob = self.output_layer(out_input)
            return log_prob, state, output
Example #11
    def collect_metrics(self, outputs, target, ptr_index, kb_index):
        """
        collect_metrics
        """
        num_samples = target.size(0)
        metrics = Pack(num_samples=num_samples)
        loss = 0

        # loss for generation
        logits = outputs.logits
        nll = self.nll_loss(logits, target)
        loss += nll
        '''
        # loss for gate
        pad_zeros = torch.zeros([num_samples, 1], dtype=torch.long)
        if self.use_gpu:
            pad_zeros = pad_zeros.cuda()
        ptr_index = torch.cat([ptr_index, pad_zeros], dim=-1).float()
        gate_logits = outputs.gate_logits
        loss_gate = self.bce_loss(gate_logits, ptr_index)
        loss += loss_gate
        '''
        # loss for selector
        # selector_target = kb_index.float()
        # selector_logits = outputs.selector_logits
        # selector_mask = outputs.selector_mask
        #
        # if selector_target.size(-1) < selector_logits.size(-1):
        #     pad_zeros = torch.zeros(size=(num_samples, selector_logits.size(-1)-selector_target.size(-1)),
        #                             dtype=torch.float)
        #     if self.use_gpu:
        #         pad_zeros = pad_zeros.cuda()
        #     selector_target = torch.cat([selector_target, pad_zeros], dim=-1)
        # loss_ptr = self.bce_loss(selector_logits, selector_target, mask=selector_mask)
        loss_ptr = torch.tensor(0.0)
        if self.use_gpu:
            loss_ptr = loss_ptr.cuda()
        loss += loss_ptr

        acc = accuracy(logits, target, padding_idx=self.padding_idx)
        metrics.add(loss=loss,
                    ptr=loss_ptr,
                    acc=acc,
                    logits=logits,
                    prob=outputs.prob)

        return metrics
Example #12
    def interact(self, src, cue=None):
        if src == "":
            return None

        inputs = Pack()
        src = self.src_field.numericalize([src])
        inputs.add(src=list2tensor(src))

        if cue is not None:
            cue = self.cue_field.numericalize([cue])
            inputs.add(cue=list2tensor(cue))
        if self.use_gpu:
            inputs = inputs.cuda()
        _, preds, _, _ = self.forward(inputs=inputs, num_candidates=1)

        pred = self.tgt_field.denumericalize(preds[0][0])

        return pred
Example #13
    def decode(self, input, state, is_training=False):
        """
        decode
        """
        hidden = state.hidden
        rnn_input_list = []
        out_input_list = []
        output = Pack()

        if self.embedder is not None:
            input = self.embedder(input)

        # shape: (batch_size, 1, input_size)
        input = input.unsqueeze(1)
        rnn_input_list.append(input)

        if self.feature_size is not None:
            feature = state.feature.unsqueeze(1)
            rnn_input_list.append(feature)

        if self.attn_mode is not None:
            attn_memory = state.attn_memory
            attn_mask = state.attn_mask
            query = hidden[-1].unsqueeze(1)
            weighted_context, attn = self.attention(query=query,
                                                    memory=attn_memory,
                                                    mask=attn_mask)
            rnn_input_list.append(weighted_context)
            out_input_list.append(weighted_context)
            output.add(attn=attn)

        rnn_input = torch.cat(rnn_input_list, dim=-1)
        rnn_output, new_hidden = self.rnn(rnn_input, hidden)
        out_input_list.append(rnn_output)

        out_input = torch.cat(out_input_list, dim=-1)
        state.hidden = new_hidden

        if is_training:
            return out_input, state, output
        else:
            log_prob = self.output_layer(out_input)
            return log_prob, state, output
Example #14
    def collect_metrics(self, outputs, target, bridge=None, epoch=-1):
        """
        collect_metrics
        """
        num_samples = target.size(0)
        metrics = Pack(num_samples=num_samples)
        loss = 0

        # response generation
        logits = outputs.logits
        nll = self.nll_loss(logits, target)
        num_words = target.ne(self.padding_idx).sum().item()
        acc = accuracy(logits, target, padding_idx=self.padding_idx)
        metrics.add(nll=(nll, num_words), acc=acc)
        loss += nll

        # neural topic model
        ntm_loss = outputs.ntm_loss.sum()  # bug fix: .item() would detach the NTM loss from the graph
        loss += ntm_loss / self.topic_vocab_size * 0.3

        metrics.add(loss=loss)
        return metrics
Example #15
    def collect_metrics(self, outputs, target, epoch=-1):
        """
        collect_metrics
        """
        num_samples = target.size(0)
        metrics = Pack(num_samples=num_samples)
        loss = 0

        # test begin
        # nll = self.nll(torch.log(outputs.posterior_attn+1e-10), outputs.attn_index)
        # loss += nll
        # attn_acc = attn_accuracy(outputs.posterior_attn, outputs.attn_index)
        # metrics.add(attn_acc=attn_acc)
        # metrics.add(loss=loss)
        # return metrics
        # test end

        logits = outputs.logits
        scores = -self.nll_loss(logits, target, reduction=False)
        nll_loss = self.nll_loss(logits, target)
        num_words = target.ne(self.padding_idx).sum().item()
        acc = accuracy(logits, target, padding_idx=self.padding_idx)
        metrics.add(nll=(nll_loss, num_words), acc=acc)
        # persona loss
        if 'attn_index' in outputs:
            attn_acc = attn_accuracy(outputs.cue_attn, outputs.attn_index)
            metrics.add(attn_acc=attn_acc)
            per_logits = torch.log(outputs.cue_attn +
                                   self.eps)  # cue_attn: (batch_size, sent_num)
            per_labels = outputs.attn_index  # (batch_size)
            use_per_loss = self.persona_loss(
                per_logits, per_labels)  # per_labels(batch_size)
            metrics.add(use_per_loss=use_per_loss)
            loss += 0.7 * use_per_loss
            loss += 0.3 * nll_loss
        else:
            loss += nll_loss

        metrics.add(loss=loss)
        return metrics, scores
Example #16
    def collect_metrics(self, outputs, target, emo_target):
        """
        collect_metrics
        """
        num_samples = target[0].size(0)
        num_words = target[1].sum().item()
        metrics = Pack(num_samples=num_samples)
        target_len = target[1]
        mask = sequence_mask(target_len)
        mask = mask.float()
        # logits = outputs.logits
        # nll = self.nll_loss(logits, target)
        out_copy = outputs.out_copy
        # out_copy: batch x max_len x src_len
        target_loss = out_copy.gather(2, target[0].unsqueeze(-1)).squeeze(-1)
        target_loss = target_loss * mask
        target_loss += 1e-15
        target_loss = target_loss.log()
        loss = -((target_loss.sum()) / num_words)

        out_emo = outputs.logits  # batch x max_len x class_num
        batch_size, max_len, class_num = out_emo.size()
        # out_emo=out_emo.view(batch_size*max_len, class_num)
        # emo_target=emo_target.view(-1)
        target_emo_loss = out_emo.gather(
            2, emo_target[0].unsqueeze(-1)).squeeze(-1)
        target_len -= 1
        mask_ = sequence_mask(target_len)
        mask_ = mask_.float()
        new_mask = mask.data.new(batch_size, max_len).zero_()
        new_mask[:, :max_len - 1] = mask_

        target_emo_loss = target_emo_loss * new_mask
        target_emo_loss += 1e-15
        target_emo_loss = target_emo_loss.log()
        emo_loss = -((target_emo_loss.sum()) / num_words)

        metrics.add(loss=loss)
        metrics.add(emo_loss=emo_loss)
        # here, we only compute accuracy over the copy distribution
        acc = accuracy(out_copy, target[0], mask=mask)
        metrics.add(acc=acc)
        return metrics
Example #17
    def encode(self, inputs, hidden=None, is_training=False):
        """
        encode
        """
        outputs = Pack()
        enc_inputs = _, lengths = inputs.src[0][:, 1:-1], inputs.src[1] - 2
        enc_outputs, enc_hidden = self.encoder(enc_inputs, hidden)

        if self.with_bridge:
            enc_hidden = self.bridge(enc_hidden)

        # knowledge
        batch_size, sent_num, sent = inputs.cue[0].size()
        tmp_len = inputs.cue[1]
        tmp_len[tmp_len > 0] -= 2
        cue_inputs = inputs.cue[0].view(-1, sent)[:, 1:-1], tmp_len.view(-1)
        cue_enc_outputs, cue_enc_hidden = self.knowledge_encoder(
            cue_inputs, hidden)
        cue_outputs = cue_enc_hidden[-1].view(batch_size, sent_num, -1)
        # Attention
        weighted_cue, cue_attn = self.prior_attention(
            query=enc_hidden[-1].unsqueeze(1),
            memory=cue_outputs,
            mask=inputs.cue[1].eq(0))
        cue_attn = cue_attn.squeeze(1)
        outputs.add(prior_attn=cue_attn)
        indexs = cue_attn.max(dim=1)[1]
        # hard attention
        if self.use_gs:
            knowledge = cue_outputs.gather(
                1, indexs.view(-1, 1, 1).repeat(1, 1, cue_outputs.size(-1)))
        else:
            knowledge = weighted_cue
        if self.use_posterior:
            tgt_enc_inputs = inputs.tgt[0][:, 1:-1], inputs.tgt[1] - 2
            _, tgt_enc_hidden = self.knowledge_encoder(tgt_enc_inputs, hidden)
            posterior_weighted_cue, posterior_attn = self.posterior_attention(
                # P(z|u,r)
                # query=torch.cat([dec_init_hidden[-1], tgt_enc_hidden[-1]], dim=-1).unsqueeze(1)
                # P(z|r)
                query=tgt_enc_hidden[-1].unsqueeze(1),
                memory=cue_outputs,
                mask=inputs.cue[1].eq(0))
            posterior_attn = posterior_attn.squeeze(1)
            outputs.add(posterior_attn=posterior_attn)
            # Gumbel Softmax
            if self.use_gs:
                gumbel_attn = F.gumbel_softmax(torch.log(posterior_attn +
                                                         1e-10),
                                               0.1,
                                               hard=True)
                outputs.add(gumbel_attn=gumbel_attn)
                knowledge = torch.bmm(gumbel_attn.unsqueeze(1), cue_outputs)
                indexs = gumbel_attn.max(-1)[1]
            else:
                knowledge = posterior_weighted_cue
                indexs = posterior_attn.max(dim=1)[1]

            if self.use_bow:
                bow_logits = self.bow_output_layer(knowledge)
                outputs.add(bow_logits=bow_logits)
            if self.use_dssm:
                dssm_knowledge = self.dssm_project(knowledge)
                outputs.add(dssm=dssm_knowledge)
                outputs.add(reply_vec=tgt_enc_hidden[-1])
                # neg sample
                neg_idx = torch.arange(enc_inputs[1].size(0)).type_as(
                    enc_inputs[1])
                neg_idx = (neg_idx + 1) % neg_idx.size(0)
                neg_tgt_enc_inputs = tgt_enc_inputs[0][
                    neg_idx], tgt_enc_inputs[1][neg_idx]
                _, neg_tgt_enc_hidden = self.knowledge_encoder(
                    neg_tgt_enc_inputs, hidden)
                pos_logits = (enc_hidden[-1] * tgt_enc_hidden[-1]).sum(dim=-1)
                neg_logits = (enc_hidden[-1] *
                              neg_tgt_enc_hidden[-1]).sum(dim=-1)
                outputs.add(pos_logits=pos_logits, neg_logits=neg_logits)
        elif is_training:
            if self.use_gs:
                gumbel_attn = F.gumbel_softmax(torch.log(cue_attn + 1e-10),
                                               0.1,
                                               hard=True)
                knowledge = torch.bmm(gumbel_attn.unsqueeze(1), cue_outputs)
                indexs = gumbel_attn.max(-1)[1]
            else:
                knowledge = weighted_cue

        outputs.add(indexs=indexs)
        if 'index' in inputs.keys():
            outputs.add(attn_index=inputs.index)

        if self.use_kd:
            knowledge = self.knowledge_dropout(knowledge)

        if self.weight_control:
            weights = (enc_hidden[-1] * knowledge.squeeze(1)).sum(dim=-1)
            weights = self.sigmoid(weights)
            # norm in batch
            # weights = weights / weights.mean().item()
            outputs.add(weights=weights)
            knowledge = knowledge * weights.view(-1, 1, 1).repeat(
                1, 1, knowledge.size(-1))

        dec_init_state = self.decoder.initialize_state(
            hidden=enc_hidden,
            attn_memory=enc_outputs if self.attn_mode else None,
            memory_lengths=lengths if self.attn_mode else None,
            knowledge=knowledge)
        return outputs, dec_init_state
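The use_gs branches select one knowledge sentence discretely while keeping the selection differentiable, via straight-through Gumbel-softmax over the (posterior) attention. The selection step in isolation, with stand-in shapes:

import torch
import torch.nn.functional as F

batch, sent_num, hidden = 4, 5, 16
cue_outputs = torch.randn(batch, sent_num, hidden)          # one vector per knowledge sentence
attn = torch.softmax(torch.rand(batch, sent_num), dim=-1)   # attention over sentences

# hard=True returns one-hot samples whose gradient flows through the soft sample
gumbel_attn = F.gumbel_softmax(torch.log(attn + 1e-10), tau=0.1, hard=True)
knowledge = torch.bmm(gumbel_attn.unsqueeze(1), cue_outputs)  # (batch, 1, hidden)
indexs = gumbel_attn.max(-1)[1]                               # chosen sentence ids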
Example #18
    def encode(self, inputs, hidden=None, is_training=False):
        """
        encode
        """
        '''
	    #inputs: 嵌套形式为{分离src和target和cue->(分离数据和长度->tensor数据值    
	    #{'src':( 数据值-->shape(batch_size , sen_num , max_len), 句子长度值--> shape(batch_size,sen_num) ),
          'tgt':( 数据值-->shape(batch_size , max_len), 句子长度值-->shape(batch_size) ),
          'cue' :( 数据值-->shape(batch_size, max_len), 句子长度值--> shape(batch_size) ),
          'label':( 数据值-->shape(batch_size , max_len), 句子长度值-->shape(batch_size) ),
          'index': ( 数据值-->shape(batch_size , max_len), 句子长度值-->shape(batch_size) )
          }
	    '''
        outputs = Pack()
        # stage two
        if self.task_id == 1:

            enc_inputs = inputs.src[0][:, 1:-1], inputs.src[1] - 2
            lengths = inputs.src[1] - 2  # (batch_size)
            enc_outputs, enc_hidden, enc_embedding = self.encoder(
                enc_inputs, hidden)
            # enc_outputs:(batch_size, max_len-2, 2*rnn_hidden_size)
            # enc_hidden:(num_layer , batch_size , 2*rnn_hidden_size)

            if self.with_bridge:
                enc_hidden = self.bridge(enc_hidden)

            # tem_bth,tem_len,tem_hi_size =enc_outputs.size()# batch_size, max_len-2, 2*rnn_hidden_size)
            key_index, len_key_index = inputs.index[0], inputs.index[
                1]  # key_index(batch_size , idx_max_len)
            max_len = key_index.size(1)
            key_mask = sequence_mask(len_key_index, max_len).eq(
                0)  # key_mask(batch_size , idx_max_len)
            key_hidden = torch.gather(
                enc_embedding, 1,
                key_index.unsqueeze(-1).repeat(1, 1, enc_embedding.size(
                    -1)))  # (batch_size ,idx_max_len, 2*rnn_hidden_size)
            key_global = key_hidden.masked_fill(
                key_mask.unsqueeze(-1),
                0.0).sum(1) / len_key_index.unsqueeze(1).float()
            key_global = self.key_linear(
                key_global)  # (batch_size, 2*rnn_hidden_size)
            # persona_aware = torch.cat([key_global, enc_hidden[-1]], dim=-1)  # (batch_size ,2*rnn_hidden_size)
            persona_aware = key_global + enc_hidden[
                -1]  #(batch_size , 2*rnn_hidden_size)

            # persona
            batch_size, sent_num, sent = inputs.cue[0].size()
            cue_len = inputs.cue[1]  # (batch_size,sen_num)
            cue_len[cue_len > 0] -= 2  # (batch_size, sen_num)
            cue_inputs = inputs.cue[0].view(-1, sent)[:,
                                                      1:-1], cue_len.view(-1)
            # cue_inputs:((batch_size*sent_num , max_len-2),(batch_size*sent_num))
            cue_enc_outputs, cue_enc_hidden, _ = self.persona_encoder(
                cue_inputs, hidden)
            # cue_enc_outputs:(batch_size*sent_num , max_len-2, 2*rnn_hidden_size)
            # cue_enc_hidden: (num_layers, batch_size*sent_num, 2*rnn_hidden_size)
            cue_outputs = cue_enc_hidden[-1].view(batch_size, sent_num, -1)
            cue_enc_outputs = cue_enc_outputs.view(
                batch_size, sent_num, cue_enc_outputs.size(1), -1
            )  # cue_enc_outputs:(batch_size, sent_num , max_len-2, 2*rnn_hidden_size)
            cue_len = cue_len.view(batch_size, sent_num)

            # cue_outputs:(batch_size, sent_num, 2 * rnn_hidden_size)
            # Attention
            weighted_cue1, cue_attn1 = self.persona_attention(
                query=persona_aware.unsqueeze(1),
                memory=cue_outputs,
                mask=inputs.cue[1].eq(0))
            # weighted_cue:(batch_size , 1 , 2 * rnn_hidden_size)
            persona_memory1 = weighted_cue1 + persona_aware.unsqueeze(1)
            weighted_cue2, cue_attn2 = self.persona_attention(
                query=persona_memory1,
                memory=cue_outputs,
                mask=inputs.cue[1].eq(0))
            persona_memory2 = weighted_cue2 + persona_aware.unsqueeze(1)
            weighted_cue3, cue_attn3 = self.persona_attention(
                query=persona_memory2,
                memory=cue_outputs,
                mask=inputs.cue[1].eq(0))

            cue_attn = cue_attn3.squeeze(1)
            # cue_attn:(batch_size, sent_num)
            outputs.add(cue_attn=cue_attn)
            indexs = cue_attn.max(dim=1)[1]  # (batch_size)
            if is_training:
                # gumbel_attn = F.gumbel_softmax(torch.log(cue_attn + 1e-10), 0.1, hard=True)
                # persona = torch.bmm(gumbel_attn.unsqueeze(1), cue_outputs)
                # indexs = gumbel_attn.max(-1)[1]
                # cue_lengths = cue_len.gather(1, indexs.unsqueeze(1)).squeeze(1)  # (batch_size)
                persona = cue_enc_outputs.gather(
                    1,
                    indexs.view(-1, 1, 1, 1).repeat(
                        1, 1, cue_enc_outputs.size(2),
                        cue_enc_outputs.size(3))).squeeze(
                            1)  # (batch_size , max_len-2, 2*rnn_hidden_size)
                cue_lengths = cue_len.gather(1, indexs.unsqueeze(1)).squeeze(
                    1)  # (batch_size)
            else:
                persona = cue_enc_outputs.gather(
                    1,
                    indexs.view(-1, 1, 1, 1).repeat(
                        1, 1, cue_enc_outputs.size(2),
                        cue_enc_outputs.size(3))).squeeze(
                            1)  # (batch_size , max_len-2, 2*rnn_hidden_size)
                cue_lengths = cue_len.gather(1, indexs.unsqueeze(1)).squeeze(
                    1)  # (batch_size)

            outputs.add(indexs=indexs)
            outputs.add(attn_index=inputs.label)  # (batch_size)

            dec_init_state = self.decoder.initialize_state(
                hidden=enc_hidden,
                attn_memory=enc_outputs if self.attn_mode else None,
                memory_lengths=lengths
                if self.attn_mode else None,  # (batch_size)
                cue_enc_outputs=
                persona,  # (batch_size, max_len-2, 2*rnn_hidden_size)
                cue_lengths=cue_lengths,  # (batch_size)
                task_id=self.task_id)

            # if 'index' in inputs.keys():
            #     outputs.add(attn_index=inputs.index)

        elif self.task_id == 0:
            # stage one
            # enc_inputs: ((batch_size, max_len-2), (batch_size)) -- src with <bos>/<eos> stripped
            # hidden:None
            batch_size, sent_num, sent_len = inputs.src[0].size()
            src_lengths = inputs.src[1]  # (batch_size,sent_num)
            src_lengths[src_lengths > 0] -= 2
            # src_lengths(batch_size, sent_num)
            src_inputs = inputs.src[0].view(
                -1, sent_len)[:, 1:-1], src_lengths.view(-1)
            # src_inputs:((batch_size*sent_num , max_len-2),(batch_size*sent_num))
            src_enc_outputs, enc_hidden, _ = self.encoder(src_inputs, hidden)

            if self.with_bridge:
                enc_hidden = self.bridge(enc_hidden)

            # src_enc_outputs: (batch_size*sent_num, max_len-2, 2*rnn_hidden_size)
            # enc_hidden: (num_layers, batch_size*sent_num, 2*rnn_hidden_size)
            src_outputs = torch.mean(
                enc_hidden.view(self.num_layers, batch_size, sent_num, -1),
                2)  # mean-pool over sentences
            # src_outputs: (num_layers, batch_size, 2*rnn_hidden_size)

            # persona: ((batch_size, max_len-2), (batch_size)) -- persona tensor with <bos>/<eos> stripped
            cue_inputs = inputs.cue[0][:, 1:-1], inputs.cue[1] - 2
            cue_lengths = inputs.cue[1] - 2  # (batch_size)
            cue_enc_outputs, cue_enc_hidden, _ = self.persona_encoder(
                cue_inputs, hidden)
            # cue_enc_outputs:(batch_size, max_len-2, 2*rnn_hidden_size)
            # cue_enc_hidden:(num_layer , batch_size , 2*rnn_hidden_size)

            dec_init_state = self.decoder.initialize_state(
                hidden=src_outputs,
                attn_memory=src_enc_outputs.view(
                    batch_size, sent_num, sent_len -
                    2, -1) if self.attn_mode else None,
                # (batch_size, sent_num , max_len-2, 2*rnn_hidden_size)
                memory_lengths=src_lengths
                if self.attn_mode else None,  # (batch_size,sent_num)
                cue_enc_outputs=
                cue_enc_outputs,  # (batch_size, max_len-2, 2*rnn_hidden_size)
                cue_lengths=cue_lengths,
                task_id=self.task_id  # (batch_size)
            )
        return outputs, dec_init_state
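Stage two runs three attention hops over the persona sentences, each time adding the read vector back to the persona-aware query before re-attending. The hop pattern with a stand-in dot-product attention (all names and sizes here are illustrative, not the project's persona_attention):

import torch

def attention(query, memory, mask):
    # plain dot-product attention standing in for self.persona_attention
    scores = torch.bmm(query, memory.transpose(1, 2))      # (B, 1, N)
    scores = scores.masked_fill(mask.unsqueeze(1), -1e9)
    weights = torch.softmax(scores, dim=-1)
    return torch.bmm(weights, memory), weights

B, N, H = 4, 5, 16
persona_aware = torch.randn(B, H)              # query fused from keywords + context
cue_outputs = torch.randn(B, N, H)             # one vector per persona sentence
mask = torch.zeros(B, N, dtype=torch.bool)     # True marks empty sentences

query = persona_aware.unsqueeze(1)
for _ in range(3):                             # three hops, as above
    weighted_cue, cue_attn = attention(query, cue_outputs, mask)
    query = weighted_cue + persona_aware.unsqueeze(1)   # residual query update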
Example #19
    def iterate(self,
                turn_inputs,
                kb_inputs,
                optimizer=None,
                grad_clip=None,
                is_training=True,
                method="GAN",
                mask=False):
        """
        iterate
        note: this function iterate in the whole model (muti-agent) instead of single sub_model
        """

        if isinstance(optimizer, tuple):
            optimizerG, optimizerDB, optimizerDE = optimizer

        # clear all memory before the begin of a new batch computation
        for name, model in self.named_children():
            if name.startswith("model_"):
                model.reset_memory()
                model.load_kb_memory(kb_inputs)

        # store the whole (multi-agent) model's metrics
        metrics_list_S, metrics_list_TB, metrics_list_TE = [], [], []
        metrics_list_G, metrics_list_DB, metrics_list_DE = [], [], []
        mask_list_S, length_list = [], []
        # store the whole (multi-agent) model's losses
        total_loss_DB, total_loss_DE, total_loss_G = 0, 0, 0
        # used to compute the per-turn loss (sum of each agent's losses) that accumulates into the batch total_loss
        loss = Pack()
        # used to store the kd_mask for each of the three sub-models
        kd_masks = Pack()

        # compare evaluation metric (bleu/f1score) among models
        if method in ('1-3', 'GAN'):
            # TODO complete
            bleu_ENS_gt_S, bleu_ENS_gt_TB, f1score_ENS_gt_TE = True, True, True
        else:
            # compute bleu_S_gt_TB per batch: the metric computed here steers the following training batch
            res_bleu = self.compare_metric(generator_1=self.generator_S,
                                           generator_2=self.generator_TB,
                                           turn_inputs=turn_inputs,
                                           kb_inputs=kb_inputs,
                                           type='bleu',
                                           data_name=self.data_name)
            if isinstance(res_bleu, tuple):
                bleu_S_gt_TB, bleu_S_gt_TB_str = res_bleu
            else:
                assert isinstance(res_bleu, bool)
                bleu_S_gt_TB, bleu_S_gt_TB_str = res_bleu, ''
            if self.model_TE is not None:
                res_f1score = self.compare_metric(
                    generator_1=self.generator_S,
                    generator_2=self.generator_TE,
                    turn_inputs=turn_inputs,
                    kb_inputs=kb_inputs,
                    type='f1score',
                    data_name=self.data_name)
                if isinstance(res_f1score, tuple):
                    f1score_S_gt_TE, f1score_S_gt_TE_str = res_f1score
                else:
                    assert isinstance(res_f1score, bool)
                    f1score_S_gt_TE, f1score_S_gt_TE_str = res_f1score, ''
        """ update discriminator """

        # clear all memory again, since it accumulated during the generator computation above
        for name, model in self.named_children():
            if name.startswith("model_"):
                model.reset_memory()
                model.load_kb_memory(kb_inputs)

        # begin iterate (a dialogue batch)
        for i, inputs in enumerate(turn_inputs):

            for name, model in self.named_children():
                if name.startswith("model_"):
                    if model.use_gpu:
                        inputs = inputs.cuda()
                    src, src_lengths = inputs.src
                    tgt, tgt_lengths = inputs.tgt
                    task_label = inputs.task
                    gold_entity = inputs.gold_entity
                    ptr_index, ptr_lengths = inputs.ptr_index
                    kb_index, kb_index_lengths = inputs.kb_index
                    enc_inputs = src[:, 1:
                                     -1], src_lengths - 2  # filter <bos> <eos>
                    dec_inputs = tgt[:, :-1], tgt_lengths - 1  # filter <eos>
                    target = tgt[:, 1:]  # filter <bos>
                    target_mask = sequence_mask(tgt_lengths - 1)
                    kd_mask = sequence_kd_mask(tgt_lengths - 1, target, name,
                                               self.ent_idx, self.nen_idx)

                    outputs = model.forward(enc_inputs, dec_inputs)
                    metrics = model.collect_metrics(outputs, target, ptr_index,
                                                    kb_index)

                    if name == "model_S":
                        metrics_list_S.append(metrics)
                    elif name == "model_TB":
                        metrics_list_TB.append(metrics)
                    else:
                        metrics_list_TE.append(metrics)

                    kd_masks[name] = kd_mask if mask else target_mask
                    loss[name] = metrics

                    model.update_memory(
                        dialog_state_memory=outputs.dialog_state_memory,
                        kb_state_memory=outputs.kb_state_memory)

            # store necessary data for the three sub-models
            if self.model_TE is not None:
                kd_mask_e = kd_masks.model_TE
            kd_mask_s = kd_masks.model_S
            kd_mask_b = kd_masks.model_TB
            mask_list_S.append(kd_mask_s)
            length_list.append(tgt_lengths - 1)

            assert (kd_mask_b == kd_mask_e).all()

            errD_B = self.discriminator_update(netD=self.discriminator_B,
                                               real_data=loss.model_TB.prob,
                                               fake_data=loss.model_S.prob,
                                               lengths=tgt_lengths - 1,
                                               mask=kd_mask_b)
            errD_E = self.discriminator_update(netD=self.discriminator_E,
                                               real_data=loss.model_TE.prob,
                                               fake_data=loss.model_S.prob,
                                               lengths=tgt_lengths - 1,
                                               mask=kd_mask_e)
            # collect the discriminator's total loss
            metrics_DB = Pack(num_samples=metrics.num_samples)
            metrics_DE = Pack(num_samples=metrics.num_samples)
            metrics_DB.add(loss=errD_B, logits=0.0, prob=0.0)
            metrics_DE.add(loss=errD_E, logits=0.0, prob=0.0)
            metrics_list_DB.append(metrics_DB)
            metrics_list_DE.append(metrics_DE)

            # update in a batch
            total_loss_DB = total_loss_DB + errD_B
            total_loss_DE = total_loss_DE + errD_E
            loss.clear()
            kd_masks.clear()

        # check loss
        if torch.isnan(total_loss_DB) or torch.isnan(total_loss_DE):
            raise ValueError("NAN loss encountered!")

        # compute and update gradient
        if is_training:
            assert None not in (optimizerDB, optimizerDE)
            optimizerDB.zero_grad()
            optimizerDE.zero_grad()
            total_loss_DB.backward()
            total_loss_DE.backward()
            if grad_clip is not None and grad_clip > 0:
                torch.nn.utils.clip_grad_norm_(
                    parameters=self.discriminator_B.parameters(),
                    max_norm=grad_clip)
                torch.nn.utils.clip_grad_norm_(
                    parameters=self.discriminator_E.parameters(),
                    max_norm=grad_clip)
            optimizerDB.step()
            optimizerDE.step()
        """ update generator """

        # begin iterate (a dialogue batch)
        n_turn = len(metrics_list_S)
        assert n_turn == len(turn_inputs) == len(mask_list_S)
        for i in range(n_turn):
            errG, errG_B, errG_E, nll = self.generator_update(
                netG=self.model_S,
                netDB=self.discriminator_B,
                netDE=self.discriminator_E,
                fake_data=metrics_list_S[i].prob,
                length=length_list[i],
                mask=mask_list_S[i],
                nll=metrics_list_S[i].loss,
                lambda_g=self.lambda_g)

            # collect the generator's total loss
            metrics_G = Pack(num_samples=metrics_list_S[i].num_samples)
            metrics_G.add(loss=errG,
                          loss_gb=errG_B,
                          loss_ge=errG_E,
                          loss_nll=nll,
                          logits=0.0,
                          prob=0.0)
            metrics_list_G.append(metrics_G)

            # update in a batch
            total_loss_G += errG

        # check loss
        if torch.isnan(total_loss_G):
            raise ValueError("NAN loss encountered!")

        # compute and update gradient
        if is_training:
            assert optimizerG is not None
            optimizerG.zero_grad()
            total_loss_G.backward()
            if grad_clip is not None and grad_clip > 0:
                torch.nn.utils.clip_grad_norm_(
                    parameters=self.model_S.parameters(), max_norm=grad_clip)
            optimizerG.step()

        return metrics_list_S, metrics_list_G, metrics_list_DB, metrics_list_DE
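iterate() follows the usual GAN schedule: the discriminators are updated on real (teacher) versus fake (student) distributions with generator gradients blocked, then the generator is updated to fool them, with an NLL term mixed in via lambda_g. The bare-bones version of that schedule (everything below is a stand-in sketch, not the project's discriminator_update/generator_update):

import torch
import torch.nn.functional as F

netD = torch.nn.Linear(8, 1)                   # stand-in discriminator
netG = torch.nn.Linear(8, 8)                   # stand-in generator
optD = torch.optim.Adam(netD.parameters(), lr=1e-3)
optG = torch.optim.Adam(netG.parameters(), lr=1e-3)
real, noise = torch.randn(16, 8), torch.randn(16, 8)

# 1) discriminator step: score real high, fake low; detach blocks G's gradients
fake = netG(noise).detach()
lossD = F.binary_cross_entropy_with_logits(netD(real), torch.ones(16, 1)) \
      + F.binary_cross_entropy_with_logits(netD(fake), torch.zeros(16, 1))
optD.zero_grad()
lossD.backward()
optD.step()

# 2) generator step: make the discriminator score fakes as real
lossG = F.binary_cross_entropy_with_logits(netD(netG(noise)), torch.ones(16, 1))
optG.zero_grad()
lossG.backward()
optG.step()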
Example #20
    def decode(self, input, state, is_training=False):
        """
        decode
        """
        hidden = state.hidden
        rnn_input_list = []
        cue_input_list = []
        out_input_list = []
        output = Pack()

        if self.embedder is not None:
            input = self.embedder(input)

        # shape: (batch_size, 1, input_size)
        input = input.unsqueeze(1)
        rnn_input_list.append(input)
        cue_input_list.append(state.knowledge)

        if self.feature_size is not None:
            feature = state.feature.unsqueeze(1)
            rnn_input_list.append(feature)
            cue_input_list.append(feature)

        if self.attn_mode is not None:
            weighted_context, attn = self.attention(
                query=hidden[-1].unsqueeze(1),
                memory=state.src_enc_outputs,
                mask=state.src_mask)
            rnn_input_list.append(weighted_context)
            cue_input_list.append(weighted_context)
            out_input_list.append(weighted_context)
            output.add(attn=attn)

        rnn_input = torch.cat(rnn_input_list, dim=-1)
        rnn_output, rnn_hidden = self.rnn(rnn_input, hidden)

        cue_input = torch.cat(cue_input_list, dim=-1)
        cue_output, cue_hidden = self.cue_rnn(cue_input, hidden)

        h_y = self.tanh(self.fc1(rnn_hidden))
        h_cue = self.tanh(self.fc2(cue_hidden))
        if self.concat:
            new_hidden = self.fc3(torch.cat([h_y, h_cue], dim=-1))
        else:
            k = self.sigmoid(self.fc3(torch.cat([h_y, h_cue], dim=-1)))
            new_hidden = k * h_y + (1 - k) * h_cue
        out_input_list.append(new_hidden[-1].unsqueeze(1))
        out_input = torch.cat(out_input_list, dim=-1)
        prob = self.output_layer(out_input)
        if self.copy:
            batch_size, sent_num, sent, _ = state.cue_enc_outputs.size()
            _, knowledge_attn = self.attention(
                query=hidden[-1].unsqueeze(1).repeat(sent_num, 1, 1),
                memory=state.cue_enc_outputs.view(batch_size * sent_num, sent,
                                                  -1),
                mask=state.cue_mask.view(batch_size * sent_num, -1))
            knowledge_attn = state.cue_attn.unsqueeze(
                2) * knowledge_attn.squeeze(1).view(batch_size, sent_num, -1)
            knowledge_attn = knowledge_attn.view(batch_size, 1, -1)
            output.add(knowledge_attn=knowledge_attn)
            p = F.softmax(self.fc4(
                torch.cat([
                    input, new_hidden[-1].unsqueeze(1), weighted_context,
                    state.knowledge
                ],
                          dim=-1)),
                          dim=-1)
            output.add(p=p)
            p = p.split(1, dim=2)
            prob = (p[0] * prob).scatter_add(2, state.src_inputs.unsqueeze(1),
                                             p[1] * attn)
            prob = prob.scatter_add(2,
                                    state.cue_inputs.view(batch_size, 1, -1),
                                    p[2] * knowledge_attn)
        log_prob = torch.log(prob + 1e-10)

        state.hidden = new_hidden

        return log_prob, state, output
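The copy branch mixes three distributions: generate from the vocabulary, copy from the source via the context attention, and copy from the knowledge via the knowledge attention, with mixture weights p from a softmax. The two-way core of that mixing, using scatter_add to project attention weights onto vocabulary ids (shapes are illustrative):

import torch

batch, vocab, src_len = 2, 10, 4
vocab_prob = torch.softmax(torch.rand(batch, 1, vocab), dim=-1)   # generation dist
attn = torch.softmax(torch.rand(batch, 1, src_len), dim=-1)       # copy attention
src_inputs = torch.randint(0, vocab, (batch, src_len))            # source token ids

p = torch.softmax(torch.rand(batch, 1, 2), dim=-1)                # [p_gen, p_copy]
p_gen, p_copy = p.split(1, dim=2)
# scatter_add accumulates copy mass onto each source token's vocabulary id
prob = (p_gen * vocab_prob).scatter_add(2, src_inputs.unsqueeze(1), p_copy * attn)
log_prob = torch.log(prob + 1e-10)   # matches the snippet's numerically safe log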
Example #21
    def decode(self,
               input,
               state,
               is_training=False
               ):  # executed once per time step; batch_size here means the effective batch, i.e. samples with no padding at this step
        """
        decode
        """
        # hidden: src_outputs: (num_layers, batch_size, 2*rnn_hidden_size)
        hidden = state.hidden
        task_id = state.task_id
        rnn_input_list = []
        cue_input_list = []
        out_input_list = []  # prepared for the decoder's output layer
        output = Pack()

        if self.embedder is not None:
            input = self.embedder(input)  # (batch_size,input_size)

        input = input.unsqueeze(1)  # (batch_size , 1 , input_size)
        rnn_input_list.append(input)
        # persona = state.cue_enc_outputs  # persona: (batch_size, 1, 2*rnn_hidden_size), the weighted persona context

        if self.feature_size is not None:
            feature = state.feature.unsqueeze(1)
            rnn_input_list.append(feature)
            cue_input_list.append(feature)

        # attend over enc_hidden
        if self.attn_mode is not None:
            # stage two
            if task_id == 1:
                attn_memory = state.attn_memory  # (batch_size , max_len-2, 2*rnn_hidden_size)
                attn_mask = state.attn_mask
                query = hidden[-1].unsqueeze(
                    1)  # (batch_size, 1, 2*rnn_hidden_size)
                weighted_context, attn = self.attention(
                    query=query, memory=attn_memory, mask=attn_mask
                )  #attn_mask(batch_size, num_enc_inputs)  weighted_context(batch_size,1, 2*rnn_hidden_size)

            # stage one
            elif task_id == 0:
                # attend over each of the 3 similar queries separately
                attn_memory = state.attn_memory  # (batch_size,sent_num , max_len-2, 2*rnn_hidden_size)
                batch_size, sent_num, sent_len = attn_memory.size(
                    0), attn_memory.size(1), attn_memory.size(2)
                attn_memory = attn_memory.view(
                    batch_size * sent_num, sent_len,
                    -1)  # (batch_size*sent_num , max_len-2, 2*rnn_hidden_size)
                attn_mask = state.attn_mask.view(
                    batch_size * sent_num, -1
                )  # attn_mask (batch_size*sent_num, max_len-2): padded positions become 1, the rest 0
                query = hidden[-1].unsqueeze(1).repeat(1, sent_num, 1).view(
                    batch_size * sent_num, 1,
                    -1)  # (batch_size*sent_num , 1, 2*rnn_hidden_size)
                weighted_context, attn = self.attention(
                    query=query, memory=attn_memory, mask=attn_mask
                )  # weighted_context(batch_size*sent_num, 1 , 2*rnn_hidden_size)
                weighted_context = torch.mean(
                    weighted_context.squeeze(1).view(batch_size, sent_num, -1),
                    dim=1).unsqueeze(
                        1)  # weighted_context(batch_size, 1,2*rnn_hidden_size)

            rnn_input_list.append(weighted_context)
            cue_input_list.append(weighted_context)
            out_input_list.append(weighted_context)
            output.add(attn=attn)
            # attend over the persona
            cue_attn_memory = state.cue_enc_outputs  # (batch_size, max_len-2, 2*rnn_hidden_size)
            cue_attn_mask = state.cue_attn_mask  # (batch_size,max_len-2)
            cue_query = hidden[-1].unsqueeze(
                1)  # (batch_size, 1, 2*rnn_hidden_size)
            cue_weighted_context, cue_attn = self.per_word_attention(
                query=cue_query, memory=cue_attn_memory, mask=cue_attn_mask)
            # cue_weighted_context(batch_size, 1, 2*rnn_hidden_size)
            # cue_attn((batch_size, 1, memory_size))
            cue_input_list.append(cue_weighted_context)
            # out_input_list.append(cue_weighted_context)
            output.add(cue_attn=cue_attn)

        rnn_input = torch.cat(
            rnn_input_list, dim=-1
        )  # rnn_input(batch_size, 1 , input_size + 2*rnn_hidden_size + 2*rnn_hidden_size)
        rnn_output, rnn_hidden = self.rnn(
            rnn_input,
            hidden)  # rnn_hidden: (num_layers, batch_size, 2*rnn_hidden_size)

        cue_input = torch.cat(cue_input_list,
                              dim=-1)  #(batch_size, 1 , 4*rnn_hidden_size)
        cue_output, cue_hidden = self.cue_rnn(
            cue_input, hidden)  #cue_hidden(1, batch_size , 2*rnn_hidden_size)

        h_y = self.tanh(self.fc1(rnn_hidden))
        h_cue = self.tanh(self.fc2(cue_hidden))
        if self.concat:
            new_hidden = self.fc3(torch.cat(
                [h_y, h_cue], dim=-1))  #(1, batch_size , 2*rnn_hidden_size)
        else:
            k = self.sigmoid(self.fc3(torch.cat([h_y, h_cue], dim=-1)))
            new_hidden = k * h_y + (1 - k) * h_cue
        state.hidden = new_hidden  # (num_layers, batch_size, 2*rnn_hidden_size); hidden for the next time step

        out_input_list.append(
            new_hidden[-1].unsqueeze(1))  # (batch_size, 1 , 2*rnn_hidden_size)
        out_input = torch.cat(
            out_input_list, dim=-1
        )  # (batch_size, 1, 4*rnn_hidden_size); fed to the decoder's output layer, effectively [c; h]

        if is_training:
            return out_input, state, output  # out_input: fed to the decoder's output layer; state: decoder hidden state; output: a Pack containing the key "attn"
        else:  # one time step; out_input (batch_size, 1, 4*rnn_hidden_size) is fed to the decoder's output layer, effectively [c; h]
            log_prob = self.output_layer(out_input)
            return log_prob, state, output
Example #22
    def collect_metrics(self, outputs, target, epoch=-1):
        """
        collect_metrics
        """
        num_samples = target.size(0)
        metrics = Pack(num_samples=num_samples)
        loss = 0

        # test begin
        # nll = self.nll(torch.log(outputs.posterior_attn+1e-10), outputs.attn_index)
        # loss += nll
        # attn_acc = attn_accuracy(outputs.posterior_attn, outputs.attn_index)
        # metrics.add(attn_acc=attn_acc)
        # metrics.add(loss=loss)
        # return metrics
        # test end

        logits = outputs.logits
        scores = -self.nll_loss(logits, target, reduction=False)
        nll_loss = self.nll_loss(logits, target)
        num_words = target.ne(self.padding_idx).sum().item()
        acc = accuracy(logits, target, padding_idx=self.padding_idx)
        metrics.add(nll=(nll_loss, num_words), acc=acc)

        if self.use_posterior:
            kl_loss = self.kl_loss(torch.log(outputs.prior_attn + 1e-10),
                                   outputs.posterior_attn.detach())
            metrics.add(kl=kl_loss)
            if self.use_bow:
                bow_logits = outputs.bow_logits
                bow_labels = target[:, :-1]
                bow_logits = bow_logits.repeat(1, bow_labels.size(-1), 1)
                bow = self.nll_loss(bow_logits, bow_labels)
                loss += bow
                metrics.add(bow=bow)
            if self.use_dssm:
                mse = self.mse_loss(outputs.dssm, outputs.reply_vec.detach())
                loss += mse
                metrics.add(mse=mse)
                pos_logits = outputs.pos_logits
                pos_target = torch.ones_like(pos_logits)
                neg_logits = outputs.neg_logits
                neg_target = torch.zeros_like(neg_logits)
                pos_loss = F.binary_cross_entropy_with_logits(pos_logits,
                                                              pos_target,
                                                              reduction='none')
                neg_loss = F.binary_cross_entropy_with_logits(neg_logits,
                                                              neg_target,
                                                              reduction='none')
                loss += (pos_loss + neg_loss).mean()
                metrics.add(pos_loss=pos_loss.mean(), neg_loss=neg_loss.mean())

            if epoch == -1 or epoch > self.pretrain_epoch or \
               (self.use_bow is not True and self.use_dssm is not True):
                loss += nll_loss
                loss += kl_loss
                if self.use_pg:
                    posterior_probs = outputs.posterior_attn.gather(
                        1, outputs.indexs.view(-1, 1))
                    reward = -perplexity(logits, target, self.weight,
                                         self.padding_idx) * 100
                    pg_loss = -(reward.detach() -
                                self.baseline) * posterior_probs.view(-1)
                    pg_loss = pg_loss.mean()
                    loss += pg_loss
                    metrics.add(pg_loss=pg_loss, reward=reward.mean())
            if 'attn_index' in outputs:
                attn_acc = attn_accuracy(outputs.posterior_attn,
                                         outputs.attn_index)
                metrics.add(attn_acc=attn_acc)
        else:
            loss += nll_loss

        metrics.add(loss=loss)
        return metrics, scores
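
The KL term above (computed when `use_posterior` is set) is the divergence between the log of the prior attention and the detached posterior attention. Below is a minimal, self-contained sketch of that computation using PyTorch primitives, assuming toy shapes and a `batchmean` reduction; it is illustrative, not code from the repository.

import torch
import torch.nn.functional as F

batch_size, num_knowledge = 4, 8
prior_attn = torch.softmax(torch.randn(batch_size, num_knowledge), dim=-1)      # p(k|x)
posterior_attn = torch.softmax(torch.randn(batch_size, num_knowledge), dim=-1)  # p(k|x,y)

# F.kl_div expects log-probabilities as input and probabilities as target;
# the 1e-10 guards against log(0), mirroring the snippet above.
kl = F.kl_div(torch.log(prior_attn + 1e-10),
              posterior_attn.detach(),  # gradients only flow into the prior
              reduction='batchmean')
print(kl.item())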
Example #23
0
File: rewards.py Project: leehamw/CRMN
def reward_fn1(self, preds, targets, gold_ents, ptr_index, task_label):
    """
    reward_fn1
    General reward
    """
    # parameters
    alpha1 = 1.0
    alpha2 = 0.3

    # acc reward
    '''
    # get the weighted mask
    no_padding_mask = preds.ne(self.padding_idx).float()
    trues = (preds == targets).float()
    if self.padding_idx is not None:
        weights = no_padding_mask
        acc = (weights * trues).sum(dim=1) / weights.sum(dim=1)
    else:
        acc = trues.mean(dim=1)
    '''

    pred_text = self.tgt_field.denumericalize(preds)
    tgt_text = self.tgt_field.denumericalize(targets)
    batch_size = targets.size(0)
    batch_kb_inputs = self.kbs[:batch_size, :, :]
    kb_plain = self.kb_field.denumericalize(batch_kb_inputs)

    result = Pack()
    result.add(pred_text=pred_text, tgt_text=tgt_text, gold_ents=gold_ents, kb_plain=kb_plain)
    result_list = result.flatten()

    # bleu reward
    bleu_score = []
    for res in result_list:
        hyp_toks = res.pred_text.split()
        ref_toks = res.tgt_text.split()
        try:
            bleu_1 = sentence_bleu(references=[ref_toks], hypothesis=hyp_toks,
                                   smoothing_function=SmoothingFunction().method7,
                                   weights=[1, 0, 0, 0])
        except Exception:  # degenerate inputs (e.g. empty hypothesis) can raise
            bleu_1 = 0
        try:
            bleu_2 = sentence_bleu(references=[ref_toks], hypothesis=hyp_toks,
                                   smoothing_function=SmoothingFunction().method7,
                                   weights=[0.5, 0.5, 0, 0])
        except Exception:
            bleu_2 = 0
        bleu = (bleu_1 + bleu_2) / 2
        bleu_score.append(bleu)
    bleu_score = torch.tensor(bleu_score, dtype=torch.float)

    # entity f1 reward
    f1_score = []
    report_f1 = []
    for res in result_list:
        if len(res.gold_ents) == 0:
            f1_pred = 1.0
        else:
            # TODO: change the way
            #gold_entity = ' '.join(res.gold_ents).replace('_', ' ').split()
            #pred_sent = res.pred_text.replace('_', ' ')
            gold_entity = res.gold_ents
            pred_sent = res.pred_text
            f1_pred, _ = compute_prf(gold_entity, pred_sent,
                                     global_entity_list=[], kb_plain=res.kb_plain)
            report_f1.append(f1_pred)
        f1_score.append(f1_pred)
    if len(report_f1) == 0:
        report_f1.append(0.0)
    f1_score = torch.tensor(f1_score, dtype=torch.float)
    report_f1 = torch.tensor(report_f1, dtype=torch.float)

    if self.use_gpu:
        bleu_score = bleu_score.cuda()
        f1_score = f1_score.cuda()
        report_f1 = report_f1.cuda()

    # compound reward
    #reward = alpha1 * bleu_score.unsqueeze(-1) + alpha2 * f1_score.unsqueeze(-1)
    reward = alpha1 * bleu_score.unsqueeze(-1)

    return reward, bleu_score, report_f1
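
The BLEU half of the reward can be exercised on its own. This self-contained sketch reproduces the per-sample average of smoothed BLEU-1 and BLEU-2, assuming NLTK is available; the helper name and toy strings are illustrative, not from the repository.

import torch
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction

def bleu_reward(pred_texts, tgt_texts):
    """Average of smoothed BLEU-1 and BLEU-2 per (hypothesis, reference) pair."""
    smooth = SmoothingFunction().method7
    scores = []
    for hyp, ref in zip(pred_texts, tgt_texts):
        hyp_toks, ref_toks = hyp.split(), ref.split()
        try:
            bleu_1 = sentence_bleu([ref_toks], hyp_toks, weights=[1, 0, 0, 0],
                                   smoothing_function=smooth)
            bleu_2 = sentence_bleu([ref_toks], hyp_toks, weights=[0.5, 0.5, 0, 0],
                                   smoothing_function=smooth)
        except Exception:  # degenerate inputs can raise
            bleu_1 = bleu_2 = 0.0
        scores.append((bleu_1 + bleu_2) / 2)
    return torch.tensor(scores, dtype=torch.float)

print(bleu_reward(["the cat sat"], ["the cat sat down"]))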
Example #24
0
    def encode(self, inputs, hidden=None, is_training=False):
        """
        encode
        """
        # inputs is one batch of data: {'src': batch_size items, 'tgt': batch_size items, 'cue': batch_size items}
        outputs = Pack()
        enc_inputs = _, lengths = inputs.src[0][:, 1:-1], inputs.src[
            1] - 2  # str2num in field.py wraps every sentence in bos/eos, hence [:, 1:-1] and lengths - 2
        enc_outputs, enc_hidden = self.encoder(enc_inputs, hidden)
        # enc_inputs: (batch_size, seq_len)
        # enc_outputs: (batch_size, seq_len, num_directions * hidden_size)
        # enc_hidden: (num_layers * num_directions, batch_size, hidden_size) -> (num_layers, batch_size, num_directions * hidden_size);
        #             here (1, batch_size, num_directions * hidden_size); taking [-1] gives (batch_size, num_directions * hidden_size)
        # because of the rnn_encoder.py implementation, 2 * hidden_size == config.hidden_size

        if self.with_bridge:
            enc_hidden = self.bridge(enc_hidden)

        # knowledge
        batch_size, sent_num, sent = inputs.cue[0].size(
        )  # cue[0] holds the knowledge content, a 3D tensor
        tmp_len = inputs.cue[1]
        tmp_len[tmp_len > 0] -= 2  # strip bos and eos from the lengths
        cue_inputs = inputs.cue[0].view(-1, sent)[:, 1:-1], tmp_len.view(
            -1)  # [:, 1:-1] strips bos and eos
        cue_enc_outputs, cue_enc_hidden = self.knowledge_encoder(
            cue_inputs, hidden)
        cue_outputs = cue_enc_hidden[-1].view(
            batch_size, sent_num,
            -1)  # cue_enc_hidden[-1] is the representation of each knowledge sentence; cue has one more dim than src and tgt
        # Attention
        weighted_cue, cue_attn = self.prior_attention(
            query=enc_hidden[-1].unsqueeze(1),
            memory=cue_outputs,
            mask=inputs.cue[1].eq(0))
        cue_attn = cue_attn.squeeze(1)
        outputs.add(prior_attn=cue_attn)
        indexs = cue_attn.max(dim=1)[1]
        # hard attention: take the argmax index
        if self.use_gs:
            knowledge = cue_outputs.gather(1, \
                indexs.view(-1, 1, 1).repeat(1, 1, cue_outputs.size(-1)))
        else:
            knowledge = weighted_cue

        if self.use_posterior:  # p(k|y) not p(k|x,y)
            tgt_enc_inputs = inputs.tgt[0][:, 1:-1], inputs.tgt[1] - 2
            _, tgt_enc_hidden = self.knowledge_encoder(tgt_enc_inputs, hidden)
            posterior_weighted_cue, posterior_attn = self.posterior_attention(
                # P(z|u,r)
                # query=torch.cat([dec_init_hidden[-1], tgt_enc_hidden[-1]], dim=-1).unsqueeze(1)
                # P(z|r)
                query=tgt_enc_hidden[-1].unsqueeze(1),
                memory=cue_outputs,
                mask=inputs.cue[1].eq(0))
            posterior_attn = posterior_attn.squeeze(1)
            outputs.add(posterior_attn=posterior_attn)
            # Gumbel Softmax
            if self.use_gs:
                gumbel_attn = F.gumbel_softmax(torch.log(posterior_attn +
                                                         1e-10),
                                               0.1,
                                               hard=True)  # 1e-10 avoids log(0)
                outputs.add(gumbel_attn=gumbel_attn)
                knowledge = torch.bmm(gumbel_attn.unsqueeze(1), cue_outputs)
                indexs = gumbel_attn.max(-1)[1]
            else:
                knowledge = posterior_weighted_cue
                indexs = posterior_attn.max(dim=1)[1]

            if self.use_bow:
                bow_logits = self.bow_output_layer(knowledge)
                outputs.add(bow_logits=bow_logits)
            if self.use_dssm:
                dssm_knowledge = self.dssm_project(knowledge)
                outputs.add(dssm=dssm_knowledge)
                outputs.add(reply_vec=tgt_enc_hidden[-1])
                # neg sample
                neg_idx = torch.arange(enc_inputs[1].size(0)).type_as(
                    enc_inputs[1])
                neg_idx = (neg_idx + 1) % neg_idx.size(0)
                neg_tgt_enc_inputs = tgt_enc_inputs[0][
                    neg_idx], tgt_enc_inputs[1][neg_idx]
                _, neg_tgt_enc_hidden = self.knowledge_encoder(
                    neg_tgt_enc_inputs, hidden)
                pos_logits = (enc_hidden[-1] * tgt_enc_hidden[-1]).sum(dim=-1)
                neg_logits = (enc_hidden[-1] *
                              neg_tgt_enc_hidden[-1]).sum(dim=-1)
                outputs.add(pos_logits=pos_logits, neg_logits=neg_logits)
        elif is_training:
            if self.use_gs:
                gumbel_attn = F.gumbel_softmax(torch.log(cue_attn + 1e-10),
                                               0.1,
                                               hard=True)
                knowledge = torch.bmm(gumbel_attn.unsqueeze(1), cue_outputs)
                indexs = gumbel_attn.max(-1)[1]
            else:
                knowledge = weighted_cue

        outputs.add(indexs=indexs)
        if 'index' in inputs.keys():
            outputs.add(attn_index=inputs.index)

        if self.use_kd:
            knowledge = self.knowledge_dropout(knowledge)

        if self.weight_control:  # re-weight the knowledge representation by its similarity to enc_hidden[-1]
            weights = (enc_hidden[-1] * knowledge.squeeze(1)).sum(dim=-1)
            weights = self.sigmoid(weights)
            # norm in batch
            # weights = weights / weights.mean().item()
            outputs.add(weights=weights)
            knowledge = knowledge * weights.view(-1, 1, 1).repeat(
                1, 1, knowledge.size(-1))

        dec_init_state = self.decoder.initialize_state(
            hidden=enc_hidden,
            attn_memory=enc_outputs if self.attn_mode else None,
            memory_lengths=lengths if self.attn_mode else None,
            knowledge=knowledge)
        return outputs, dec_init_state
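
The `use_gs` branch of `encode` selects a single knowledge sentence with a straight-through Gumbel-Softmax over the sentence-level attention, which stays differentiable while producing a hard choice. A minimal sketch with assumed toy shapes:

import torch
import torch.nn.functional as F

batch_size, sent_num, hidden_size = 4, 6, 16
cue_outputs = torch.randn(batch_size, sent_num, hidden_size)  # one vector per knowledge sentence
attn = torch.softmax(torch.randn(batch_size, sent_num), dim=-1)

# hard=True returns a one-hot sample in the forward pass, while the backward
# pass uses the soft Gumbel-Softmax gradients (straight-through estimator)
gumbel_attn = F.gumbel_softmax(torch.log(attn + 1e-10), tau=0.1, hard=True)
knowledge = torch.bmm(gumbel_attn.unsqueeze(1), cue_outputs)  # (batch_size, 1, hidden_size)
indexs = gumbel_attn.max(-1)[1]                               # id of the selected sentence
print(knowledge.shape, indexs)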
Example #25
0
    def decode(self, input, state, is_training=False):
        """
        decode
        """
        hidden = state.hidden
        rnn_input_list = []
        out_input_list = []
        output = Pack()

        if self.embedder is not None:
            input = self.embedder(input)

        # shape: (batch_size, 1, input_size)
        input = input.unsqueeze(1)

        rnn_input_list.append(input)
        out_input_list.append(input)

        if self.attn_mode is not None:
            # (batch_size, 1, hidden_size)
            query = hidden[-1].unsqueeze(1)

            # history attention
            weighted_hist, attn_h = self.hist_attention(query=query,
                                                        memory=state.attn_hist,
                                                        mask=state.hist_mask)
            rnn_input_list.append(weighted_hist)
            out_input_list.append(weighted_hist)
            output.add(attn_h=attn_h)

            # fact attention
            weighted_fact, attn_f = self.fact_attention(query=query,
                                                        memory=state.attn_fact,
                                                        mask=state.fact_mask)
            rnn_input_list.append(weighted_fact)
            out_input_list.append(weighted_fact)
            output.add(attn_f=attn_f)

        rnn_input = torch.cat(rnn_input_list, dim=-1)
        rnn_output, new_hidden = self.rnn(rnn_input, hidden)
        out_input_list.append(rnn_output)

        # cat (fact_hidden, hist_hidden, hidden, x)
        # (batch_size, 1, out_input_size)
        out_input = torch.cat(out_input_list, dim=-1)
        state.hidden = new_hidden

        if is_training:
            return out_input, state, output
        else:
            p_mode = self.ff(out_input)

            # prob_hist = input.new_zeros(
            #     size=(batch_size, 1, self.output_size),
            #     dtype=torch.float)

            # prob_fact = input.new_zeros(
            #     size=(batch_size, 1, self.output_size),
            #     dtype=torch.float)

            prob_vocab = self.output_layer(out_input)

            weighted_prob = prob_vocab * p_mode[:, :, 0].unsqueeze(2)
            weighted_f = output.attn_f * p_mode[:, :, 1].unsqueeze(2)
            weighted_h = output.attn_h * p_mode[:, :, 2].unsqueeze(2)
            weighted_prob = convert_dist(weighted_h, state.hist, weighted_prob)
            weighted_prob = convert_dist(weighted_f, state.fact, weighted_prob)

            # a = torch.cat((prob_vocab, prob_hist, prob_fact), -
            #               1).view(batch_size * 1, self.output_size, -1)
            # b = p_mode.view(batch_size * 1, -1).unsqueeze(2)

            # prob = torch.bmm(a, b).squeeze().view(batch_size, 1, -1)

            log_prob = torch.log(weighted_prob + 1e-10)
            return log_prob, state, output
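
`convert_dist` is defined elsewhere in this project; judging from the calls above, it scatters the weighted copy-attention mass onto the vocabulary ids of the attended tokens, pointer-generator style. The sketch below is an assumption about that behavior for illustration, not the actual implementation.

import torch

def convert_dist_sketch(weighted_attn, src_ids, vocab_prob):
    # weighted_attn: (batch, 1, src_len) attention already scaled by its mode weight
    # src_ids:       (batch, src_len)    vocabulary id of each source token
    # vocab_prob:    (batch, 1, vocab)   weighted vocabulary distribution
    index = src_ids.unsqueeze(1)  # align with weighted_attn: (batch, 1, src_len)
    return vocab_prob.scatter_add(-1, index, weighted_attn)

attn_f = torch.softmax(torch.randn(2, 1, 5), dim=-1)
fact_ids = torch.randint(0, 20, (2, 5))
vocab_prob = torch.softmax(torch.randn(2, 1, 20), dim=-1) * 0.6  # vocab mode weight 0.6
mixed = convert_dist_sketch(attn_f * 0.4, fact_ids, vocab_prob)  # copy mode weight 0.4
print(mixed.sum(-1))  # ~1.0: vocabulary mass plus copied mass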