Code Example #1
File: dssm.py Project: blair101/SPO
    def collect_metrics(self, outputs):
        """
        collect_metrics
        """
        pos_logits = outputs.pos_logits
        pos_target = torch.ones_like(pos_logits)

        neg_logits = outputs.neg_logits
        neg_target = torch.zeros_like(neg_logits)

        pos_loss = F.binary_cross_entropy_with_logits(pos_logits,
                                                      pos_target,
                                                      reduction='none')
        neg_loss = F.binary_cross_entropy_with_logits(neg_logits,
                                                      neg_target,
                                                      reduction='none')

        loss = (pos_loss + neg_loss).mean()

        pos_acc = torch.sigmoid(pos_logits).gt(0.5).float().mean()
        neg_acc = torch.sigmoid(neg_logits).lt(0.5).float().mean()
        margin = (torch.sigmoid(pos_logits) - torch.sigmoid(neg_logits)).mean()
        metrics = Pack(loss=loss,
                       pos_acc=pos_acc,
                       neg_acc=neg_acc,
                       margin=margin)

        num_samples = pos_target.size(0)
        metrics.add(num_samples=num_samples)
        return metrics
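Editor's note: every snippet on this page leans on a `Pack` container that the page never shows. As a reading aid, here is a minimal sketch of what `Pack` must support, judging only from the call sites in these examples (attribute access, `add`, `cuda`); it is an assumption, not the projects' actual implementation:

import torch

class Pack(dict):
    """Sketch: a dict with attribute access, for bundling tensors and metrics."""

    def __getattr__(self, name):
        try:
            return self[name]
        except KeyError:
            raise AttributeError(name)

    def add(self, **kwargs):
        self.update(kwargs)

    def cuda(self, device=None):
        pack = Pack()
        for k, v in self.items():
            if isinstance(v, torch.Tensor):
                pack[k] = v.cuda(device)
            elif isinstance(v, tuple):
                pack[k] = tuple(t.cuda(device) for t in v)
            else:
                pack[k] = v
        return pack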
Code Example #2
    def decode(self, input, state, is_training=False):
        """
        decode
        """
        hidden = state.hidden
        rnn_input_list = []
        cue_input_list = []
        out_input_list = []
        output = Pack()

        if self.embedder is not None:
            input = self.embedder(input)

        # shape: (batch_size, 1, input_size)
        input = input.unsqueeze(1)
        rnn_input_list.append(input)
        cue_input_list.append(state.knowledge)

        if self.feature_size is not None:
            feature = state.feature.unsqueeze(1)
            rnn_input_list.append(feature)
            cue_input_list.append(feature)

        if self.attn_mode is not None:
            attn_memory = state.attn_memory
            attn_mask = state.attn_mask
            query = hidden[-1].unsqueeze(1)
            weighted_context, attn = self.attention(query=query,
                                                    memory=attn_memory,
                                                    mask=attn_mask)
            rnn_input_list.append(weighted_context)
            cue_input_list.append(weighted_context)
            out_input_list.append(weighted_context)
            output.add(attn=attn)

        rnn_input = torch.cat(rnn_input_list, dim=-1)
        rnn_output, rnn_hidden = self.rnn(rnn_input, hidden)

        cue_input = torch.cat(cue_input_list, dim=-1)
        cue_output, cue_hidden = self.cue_rnn(cue_input, hidden)

        h_y = self.tanh(self.fc1(rnn_hidden))
        h_cue = self.tanh(self.fc2(cue_hidden))
        if self.concat:
            new_hidden = self.fc3(torch.cat([h_y, h_cue], dim=-1))
        else:
            k = self.sigmoid(self.fc3(torch.cat([h_y, h_cue], dim=-1)))
            new_hidden = k * h_y + (1 - k) * h_cue
        # out_input_list.append(new_hidden.transpose(0, 1))
        # bug fixed: use only the last layer's new hidden state
        out_input_list.append(new_hidden[-1].unsqueeze(1))

        out_input = torch.cat(out_input_list, dim=-1)
        state.hidden = new_hidden

        if is_training:
            return out_input, state, output
        else:
            log_prob = self.output_layer(out_input)
            return log_prob, state, output
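A hedged usage sketch of how such a `decode` step is typically driven at inference time, one token per step (the `decoder`, `state`, and the BOS/EOS ids below are assumptions for illustration):

import torch

def greedy_decode(decoder, state, bos_id, eos_id, max_len=30):
    batch_size = state.hidden.size(1)  # hidden: (num_layers, batch, dim)
    token = torch.full((batch_size,), bos_id, dtype=torch.long)
    preds = []
    for _ in range(max_len):
        log_prob, state, _ = decoder.decode(token, state)  # (batch, 1, vocab)
        token = log_prob.squeeze(1).argmax(dim=-1)
        preds.append(token)
        if (token == eos_id).all():
            break
    return torch.stack(preds, dim=1)  # (batch, steps)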
Code Example #3
        def collate(data_list):
            """
            collate
            """
            batch = Pack()
            # assemble each field by hand
            # num_src
            batch['num_src'] = list2tensor([x['num_src'] for x in data_list])
            # num_tgt_input
            batch['num_tgt_input'] = list2tensor(
                [x['num_tgt_input'] for x in data_list])
            #  tgt_output
            batch['tgt_output'] = list2tensor(
                [x['tgt_output'] for x in data_list])
            batch['tgt_emo'] = list2tensor([x['tgt_emo'] for x in data_list])
            # mask
            batch['mask'] = list2tensor([x['mask'] for x in data_list])
            batch['raw_src'] = [x['raw_src'] for x in data_list]
            batch['raw_tgt'] = [x['raw_tgt'] for x in data_list]

            if 'id' in data_list[0].keys():
                batch['id'] = [x['id'] for x in data_list]

            if device >= 0:
                batch = batch.cuda(device=device)
            return batch
Code Example #4
File: dataset.py Project: yunying24/Dialogue
 def collate(data_list):
     batch = Pack()
     for key in data_list[0].keys():
         batch[key] = list2tensor([x[key] for x in data_list])
     if device >= 0:
         batch = batch.cuda(device=device)
     return batch
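`list2tensor` likewise never appears on this page. Some call sites unpack its result as a `(tensor, lengths)` pair (e.g. `inputs.src[0]` / `inputs.src[1]` in later examples), so a plausible sketch is the following; the exact padding and return convention is an assumption:

import torch

def list2tensor(xs):
    if not isinstance(xs[0], (list, tuple)):   # scalar field, e.g. a label
        return torch.tensor(xs, dtype=torch.long)
    lengths = torch.tensor([len(x) for x in xs], dtype=torch.long)
    batch = torch.zeros(len(xs), lengths.max().item(), dtype=torch.long)
    for i, x in enumerate(xs):                 # left-aligned zero padding
        batch[i, :len(x)] = torch.tensor(x, dtype=torch.long)
    return batch, lengths                      # padded tokens plus lengths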
Code Example #5
    def collect_rl_metrics(self, sample_outputs, greedy_outputs, target, gold_entity, entity_dir):
        """
        collect rl training metrics
        """
        num_samples = target.size(0)
        rl_metrics = Pack(num_samples=num_samples)
        loss = 0

        # log probs for the sampled and greedy generations
        logits = sample_outputs.logits
        sample = sample_outputs.pred_word
        greedy_logits = greedy_outputs.logits
        greedy_sample = greedy_outputs.pred_word

        # compute rewards
        sample_reward, _, _ = self.reward_fn(sample, target, gold_entity, entity_dir)
        greedy_reward, bleu_score, f1_score = self.reward_fn(greedy_sample, target, gold_entity, entity_dir)
        reward = sample_reward - greedy_reward

        # compute the RL loss
        sample_log_prob = self.nll_loss(logits, sample, mask=sample.ne(self.padding_idx),
                                        reduction=False, matrix=False)  # [batch_size, max_len]
        nll = sample_log_prob * reward.to(sample_log_prob.device)
        nll = getattr(torch, self.nll_loss.reduction)(nll.sum(dim=-1))
        loss += nll

        # build the report metrics
        rl_acc = accuracy(greedy_logits, target, padding_idx=self.padding_idx)
        if reward.dim() == 2:
            reward = reward.sum(dim=-1)
        rl_metrics.add(loss=loss, reward=reward.mean(), rl_acc=rl_acc,
                       bleu_score=bleu_score.mean(), f1_score=f1_score.mean())

        return rl_metrics
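The `sample_reward - greedy_reward` difference is the self-critical sequence training (SCST) baseline: the greedy rollout's reward acts as a baseline for the sampled rollout, so the gradient favors samples that beat greedy decoding. A hedged sketch of the `reward_fn` contract implied by the calls above (the helper names are hypothetical):

def reward_fn(self, preds, target, gold_entity, entity_dir):
    # score each generated sequence against its reference; all three
    # returns are per-example tensors of shape (batch,)
    bleu_score = compute_bleu(preds, target)                      # hypothetical helper
    f1_score = compute_entity_f1(preds, gold_entity, entity_dir)  # hypothetical helper
    reward = bleu_score + f1_score
    return reward, bleu_score, f1_score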
Code Example #6
File: hntm_tg.py Project: wizare/CMTE
    def encode(self, inputs, hidden=None, is_training=False):
        """
        encode
        """
        '''
        inputs: src, topic_src, topic_tgt, [tgt]
        '''
        outputs = Pack()
        # strip <BOS>/<EOS> from src and shorten lengths accordingly
        enc_inputs = _, lengths = inputs.src[0][:, 1:-1], inputs.src[1] - 2

        enc_outputs, enc_hidden = self.encoder(enc_inputs, hidden)
        if self.with_bridge:
            enc_hidden = self.bridge(enc_hidden)

        guide_score = enc_hidden[-1]
        decoder_init_state = enc_hidden

        if self.use_ntm:
            bow_src = inputs.bow
            ntm_stat = self.ntm(bow_src)
            outputs.add(ntm_loss=ntm_stat['loss'])

        embedding_weight = self.topic_embedder.weight
        topic_word_logit = self.ntm.topics.get_topic_word_logit()
        EPS = 1e-12
        w = topic_word_logit.transpose(0, 1)  # V x K
        nv = embedding_weight / (
            torch.norm(embedding_weight, dim=1, keepdim=True) + EPS)  # V x dim
        nw = w / (torch.norm(w, dim=0, keepdim=True) + EPS)  # V x K
        # bug fixed: multiply by the normalized nw (the original used the
        # unnormalized w, leaving nw unused)
        t = nv.transpose(0, 1) @ nw  # cosine similarity: dim x K
        t = t.transpose(0, 1)  # K x dim

        topic_feature_input = []
        if 'S' in self.decoder_attention_channels:
            src_labels = F.one_hot(inputs.topic_src_label,
                                   num_classes=self.topic_num)  # B * K
            src_topics = src_labels.float() @ t
            topic_feature_input.append(src_topics)
        if 'T' in self.decoder_attention_channels:
            tgt_labels = F.one_hot(inputs.topic_tgt_label,
                                   num_classes=self.topic_num)
            tgt_topics = tgt_labels.float() @ t
            topic_feature_input.append(tgt_topics)

        topic_feature = self.t_to_feature(
            torch.cat(topic_feature_input, dim=-1))

        dec_init_state = self.decoder.initialize_state(
            hidden=decoder_init_state,
            attn_memory=enc_outputs,
            memory_lengths=lengths,
            guide_score=guide_score,
            topic_feature=topic_feature)
        return outputs, dec_init_state
Code Example #7
 def collate(
         data_list):  # len(data_list) is one batch_size; each element comes from __getitem__
     """
     collate
     """
     batch = Pack()
     for key in data_list[0].keys():  # keys(): src, tgt, cue
         # gather all src, tgt and cue fields together, one batch tensor per key
         batch[key] = list2tensor([x[key] for x in data_list])
     if device >= 0:
         batch = batch.cuda(device=device)
     return batch
Code Example #8
 def collate(data_list):
     """
     collate
     """
     batch = Pack()
     # batch is a dict
     for key in data_list[0].keys():
         # data_list: a list of dict
         # so one sample is one dict
         batch[key] = list2tensor([x[key] for x in data_list])
     if device >= 0:
         batch = batch.cuda(device=device)
     return batch
Code Example #9
File: mem_decoder.py Project: little-women/lic-2019
    def decode(self, input, state, is_training=False):
        last_hidden = state.hidden
        rnn_input_list = []
        output = Pack()

        embed_q = self.C[0](input).unsqueeze(1)  # b * e --> b * 1 * e
        batch_size = input.size(0)

        rnn_input_list.append(embed_q)

        if self.attn_mode is not None:
            attn_memory = state.attn_memory
            attn_mask = state.attn_mask
            query = last_hidden[-1].unsqueeze(1)
            weighted_context, attn = self.attention(query=query,
                                                    memory=attn_memory,
                                                    mask=attn_mask)
            rnn_input_list.append(weighted_context)
            output.add(attn=attn)

        rnn_input = torch.cat(rnn_input_list, dim=-1)
        rnn_output, new_hidden = self.gru(rnn_input, last_hidden)  # bug fixed: 'output' Pack was being overwritten here
        state.hidden = new_hidden

        u = [new_hidden[0].squeeze()]
        for hop in range(self.max_hops):
            m_A = self.m_story[hop]
            m_A = m_A[:batch_size]
            if u[-1].dim() == 1:
                u[-1] = u[-1].unsqueeze(0)  # used for bsz = 1
            u_temp = u[-1].unsqueeze(1).expand_as(m_A)
            prob_lg = torch.sum(m_A * u_temp, 2)
            prob_p = self.log_softmax(prob_lg)
            m_C = self.m_story[hop + 1]
            m_C = m_C[:batch_size]

            prob = prob_p.unsqueeze(2).expand_as(m_C)
            o_k = torch.sum(m_C * prob, 1)
            if hop == 0:
                p_vocab = self.W1(torch.cat((u[0], o_k), 1))
                prob_v = self.log_softmax(p_vocab)
            u_k = u[-1] + o_k
            u.append(u_k)
        p_ptr = prob_lg

        if is_training:
            # p_ptr and p_vocab are pre-softmax values, not probabilities
            return p_vocab, state, output
        else:
            return prob_v, state, output
Code Example #10
File: dssm.py Project: yunying24/Dialogue
    def forward(self, src_inputs, pos_tgt_inputs, neg_tgt_inputs,
                src_hidden=None, tgt_hidden=None):
        outputs = Pack()
        src_hidden = self.src_encoder(src_inputs, src_hidden)[1][-1]
        if self.with_project:
            src_hidden = self.project(src_hidden)

        pos_tgt_hidden = self.tgt_encoder(pos_tgt_inputs, tgt_hidden)[1][-1]
        neg_tgt_hidden = self.tgt_encoder(neg_tgt_inputs, tgt_hidden)[1][-1]
        pos_logits = (src_hidden * pos_tgt_hidden).sum(dim=-1)
        neg_logits = (src_hidden * neg_tgt_hidden).sum(dim=-1)
        outputs.add(pos_logits=pos_logits, neg_logits=neg_logits)

        return outputs
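A hedged usage sketch tying this `forward` to the `collect_metrics` of Code Example #1 (`model` and the three input batches are assumptions):

outputs = model(src_inputs, pos_tgt_inputs, neg_tgt_inputs)
metrics = model.collect_metrics(outputs)
metrics.loss.backward()
print(metrics.pos_acc.item(), metrics.neg_acc.item(), metrics.margin.item())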
Code Example #11
    def encode(self, inputs, hidden=None, is_training=False):
        """
        encode
        """
        '''
        inputs: src, topic_src, topic_tgt, [tgt]
        '''
        outputs = Pack()
        enc_inputs = _, lengths = inputs.src[0][:, 1:-1], inputs.src[1] - 2

        enc_outputs, enc_hidden = self.encoder(enc_inputs, hidden)
        if self.with_bridge:
            enc_hidden = self.bridge(enc_hidden)

        guide_score = enc_hidden[-1]
        decoder_init_state = enc_hidden

        bow_src = inputs.bow
        ntm_stat = self.ntm(bow_src)
        outputs.add(ntm_loss=ntm_stat['loss'])

        # obtain topic words
        _, tw_indices = self.ntm.get_topics().topk(self.topic_k,
                                                   dim=1)  # K * k

        src_labels = F.one_hot(inputs.topic_src_label,
                               num_classes=self.topic_num)  # B * K
        tgt_labels = F.one_hot(inputs.topic_tgt_label,
                               num_classes=self.topic_num)
        src_words = src_labels.float() @ tw_indices.float()  # B * k
        src_words = src_words.detach().long()
        tgt_words = tgt_labels.float() @ tw_indices.float()
        tgt_words = tgt_words.detach().long()

        # only src topic word
        src_outputs = self.topic_layer(src_words)  # b * k * h

        # only tgt topic word
        tgt_outputs = self.topic_layer(tgt_words)  # b * k * h

        dec_init_state = self.decoder.initialize_state(
            hidden=decoder_init_state,
            attn_memory=enc_outputs,
            memory_lengths=lengths,
            guide_score=guide_score,
            src_memory=src_outputs,
            tgt_memory=tgt_outputs,
        )
        return outputs, dec_init_state
Code Example #12
    def encode(self, enc_inputs, hidden=None):
        """
        encode
        """
        outputs = Pack()
        enc_outputs, enc_hidden = self.encoder(enc_inputs, hidden)
        inputs, lengths = enc_inputs
        batch_size = enc_outputs.size(0)
        max_len = enc_outputs.size(1)
        attn_mask = sequence_mask(lengths, max_len).eq(0)

        if self.with_bridge:
            enc_hidden = self.bridge(enc_hidden)

        # insert dialog memory
        if self.dialog_state_memory is None:
            assert self.dialog_history_memory is None
            assert self.history_index is None
            assert self.memory_masks is None
            self.dialog_state_memory = enc_outputs
            self.dialog_history_memory = enc_outputs
            self.history_index = inputs
            self.memory_masks = attn_mask
        else:
            batch_state_memory = self.dialog_state_memory[:batch_size, :, :]
            self.dialog_state_memory = torch.cat([batch_state_memory, enc_outputs], dim=1)
            batch_history_memory = self.dialog_history_memory[:batch_size, :, :]
            self.dialog_history_memory = torch.cat([batch_history_memory, enc_outputs], dim=1)
            batch_history_index = self.history_index[:batch_size, :]
            self.history_index = torch.cat([batch_history_index, inputs], dim=-1)
            batch_memory_masks = self.memory_masks[:batch_size, :]
            self.memory_masks = torch.cat([batch_memory_masks, attn_mask], dim=-1)

        batch_kb_inputs = self.kbs[:batch_size, :, :]
        batch_kb_state_memory = self.kb_state_memory[:batch_size, :, :]
        batch_kb_slot_memory = self.kb_slot_memory[:batch_size, :, :]
        batch_kb_slot_index = self.kb_slot_index[:batch_size, :]
        kb_mask = self.kb_mask[:batch_size, :]
        selector_mask = self.selector_mask[:batch_size, :]

        # create batched KB inputs
        kb_memory, selector = self.decoder.initialize_kb(kb_inputs=batch_kb_inputs, enc_hidden=enc_hidden)

        # initialize decoder state
        dec_init_state = self.decoder.initialize_state(
            hidden=enc_hidden,
            state_memory=self.dialog_state_memory,
            history_memory=self.dialog_history_memory,
            kb_memory=kb_memory,
            kb_state_memory=batch_kb_state_memory,
            kb_slot_memory=batch_kb_slot_memory,
            history_index=self.history_index,
            kb_slot_index=batch_kb_slot_index,
            attn_mask=self.memory_masks,
            attn_kb_mask=kb_mask,
            selector=selector,
            selector_mask=selector_mask
        )

        return outputs, dec_init_state
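The asserts above imply the four memory attributes are reset to None before each new dialog, so that `encode` rebuilds them from the first turn. A hedged sketch of such a reset helper (the method name is an assumption):

def reset_memory(self):
    # assumption: called once per new dialog; encode() then takes the
    # first branch above and seeds the memories from the first utterance
    self.dialog_state_memory = None
    self.dialog_history_memory = None
    self.history_index = None
    self.memory_masks = None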
Code Example #13
    def decode(self, input, state, is_training=False):
        """
        decode
        """
        hidden = state.hidden
        input_feed = state.input_feed
        output = Pack()

        if self.embedder is not None:
            input = self.embedder(input)

        # shape: (batch_size, 1, input_size)
        input = input.unsqueeze(1)
        rnn_input = torch.cat([input, input_feed], dim=-1)
        rnn_output, new_hidden = self.rnn(rnn_input, hidden)

        attn_memory = state.attn_memory
        query = new_hidden[0][-1].unsqueeze(1)
        weighted_context, attn = self.attention(query=query,
                                                memory=attn_memory,
                                                mask=state.mask.eq(0))

        final_output = torch.cat([query, weighted_context], dim=-1)

        # fusion_sigmod = self.fusion(final_output)
        # fusion_hidden = fusion_sigmod * weighted_context + (1 - fusion_sigmod) * query

        final_output = self.fc1(final_output)
        output.add(attn=attn)

        state.hidden = list(new_hidden)
        state.input_feed = final_output
        # state.input_feed = fusion_hidden
        out_input = final_output

        if is_training:
            return out_input, state, output
        else:
            log_prob = self.output_layer(out_input)
            return log_prob, state, output
Code Example #14
    def collect_metrics(self, outputs, target, ptr_index, kb_index):
        """
        collect_metrics
        """
        num_samples = target.size(0)
        metrics = Pack(num_samples=num_samples)
        loss = 0

        # loss for generation
        logits = outputs.logits
        nll = self.nll_loss(logits, target)
        loss += nll
        '''
        # loss for gate
        pad_zeros = torch.zeros([num_samples, 1], dtype=torch.long)
        if self.use_gpu:
            pad_zeros = pad_zeros.cuda()
        ptr_index = torch.cat([ptr_index, pad_zeros], dim=-1).float()
        gate_logits = outputs.gate_logits
        loss_gate = self.bce_loss(gate_logits, ptr_index)
        loss += loss_gate
        '''
        # loss for selector
        # selector_target = kb_index.float()
        # selector_logits = outputs.selector_logits
        # selector_mask = outputs.selector_mask
        #
        # if selector_target.size(-1) < selector_logits.size(-1):
        #     pad_zeros = torch.zeros(size=(num_samples, selector_logits.size(-1)-selector_target.size(-1)),
        #                             dtype=torch.float)
        #     if self.use_gpu:
        #         pad_zeros = pad_zeros.cuda()
        #     selector_target = torch.cat([selector_target, pad_zeros], dim=-1)
        # loss_ptr = self.bce_loss(selector_logits, selector_target, mask=selector_mask)
        loss_ptr = torch.tensor(0.0)
        if self.use_gpu:
            loss_ptr = loss_ptr.cuda()
        loss += loss_ptr

        acc = accuracy(logits, target, padding_idx=self.padding_idx)
        metrics.add(loss=loss,
                    ptr=loss_ptr,
                    acc=acc,
                    logits=logits,
                    prob=outputs.prob)

        return metrics
Code Example #15
    def decode(self, input, state, is_training=False):
        """
        decode
        """
        hidden = state.hidden
        rnn_input_list = []
        out_input_list = []
        output = Pack()

        if self.embedder is not None:
            input = self.embedder(input)

        # shape: (batch_size, 1, input_size)
        input = input.unsqueeze(1)
        rnn_input_list.append(input)

        if self.feature_size is not None:
            feature = state.feature.unsqueeze(1)
            rnn_input_list.append(feature)

        if self.attn_mode is not None:
            attn_memory = state.attn_memory
            attn_mask = state.attn_mask
            query = hidden[-1].unsqueeze(1)
            weighted_context, attn = self.attention(query=query,
                                                    memory=attn_memory,
                                                    mask=attn_mask)
            rnn_input_list.append(weighted_context)
            out_input_list.append(weighted_context)
            output.add(attn=attn)

        rnn_input = torch.cat(rnn_input_list, dim=-1)
        rnn_output, new_hidden = self.rnn(rnn_input, hidden)
        out_input_list.append(rnn_output)

        out_input = torch.cat(out_input_list, dim=-1)
        state.hidden = new_hidden

        if is_training:
            return out_input, state, output
        else:
            log_prob = self.output_layer(out_input)
            return log_prob, state, output
Code Example #16
    def generate(self, batch_iter, num_batches):
        self.model.eval()
        itoemo = ['NORM', 'POS', 'NEG']
        with torch.no_grad():
            results = []
            for inputs in batch_iter:
                enc_inputs = inputs
                dec_inputs = inputs.num_tgt_input
                enc_outputs = Pack()
                outputs = self.model.forward(enc_inputs, dec_inputs, hidden=None)
                outputs = outputs.logits
                # indices of the max logits are the predicted emotion labels
                preds = outputs.max(dim=2)[1].tolist()
                tgt_raw = inputs.raw_tgt
                tgt_emo = inputs.tgt_emo[0].tolist()

                temp_a_1 = []  # predicted tokens
                emo_b_1 = []   # predicted emotions
                temp = []      # target emotions
                for a, b, c in zip(tgt_raw, preds, tgt_emo):
                    a = a[1:]  # drop the leading BOS token
                    temp_a = []
                    emo_b = []
                    emo_c = []
                    for i, entity in enumerate(a):
                        temp_a.append(entity)
                        emo_b.append(itoemo[b[i]])
                        emo_c.append(itoemo[c[i]])
                    assert len(temp_a) == len(emo_b)
                    assert len(emo_c) == len(emo_b)
                    temp_a_1.append([temp_a])
                    emo_b_1.append([emo_b])
                    temp.append(emo_c)

                if hasattr(inputs, 'id') and inputs.id is not None:
                    enc_outputs.add(id=inputs['id'])
                enc_outputs.add(tgt=tgt_raw, preds=temp_a_1, emos=emo_b_1,
                                target_emos=temp)
                results += enc_outputs.flatten()
            return results
Code Example #17
    def collect_metrics(self, outputs, target, epoch=-1):
        """
        collect_metrics
        """
        num_samples = target.size(0)
        metrics = Pack(num_samples=num_samples)
        loss = 0

        # test begin
        # nll = self.nll(torch.log(outputs.posterior_attn+1e-10), outputs.attn_index)
        # loss += nll
        # attn_acc = attn_accuracy(outputs.posterior_attn, outputs.attn_index)
        # metrics.add(attn_acc=attn_acc)
        # metrics.add(loss=loss)
        # return metrics
        # test end

        logits = outputs.logits
        scores = -self.nll_loss(logits, target, reduction=False)
        nll_loss = self.nll_loss(logits, target)
        num_words = target.ne(self.padding_idx).sum().item()
        acc = accuracy(logits, target, padding_idx=self.padding_idx)
        metrics.add(nll=(nll_loss, num_words), acc=acc)
        # persona loss
        if 'attn_index' in outputs:
            attn_acc = attn_accuracy(outputs.cue_attn, outputs.attn_index)
            metrics.add(attn_acc=attn_acc)
            per_logits = torch.log(outputs.cue_attn +
                                   self.eps)  # cue_attn: (batch_size, sent_num)
            per_labels = outputs.attn_index  # (batch_size)
            use_per_loss = self.persona_loss(
                per_logits, per_labels)  # per_labels: (batch_size)
            metrics.add(use_per_loss=use_per_loss)
            loss += 0.7 * use_per_loss
            loss += 0.3 * nll_loss
        else:
            loss += nll_loss

        metrics.add(loss=loss)
        return metrics, scores
Code Example #18
        def collate(data_list):
            """
            collate
            """
            batch = Pack()
            for key in data_list[0].keys():
                if key == 'topic':
                    continue
                batch[key] = list2tensor([x[key] for x in data_list])

            batch_bow = []
            for x in data_list:
                v = torch.zeros(bow_vocab_size, dtype=torch.float)
                x_bow = x['topic']  # (word_id, freq) pairs
                for w, f in x_bow:
                    v[w] += f
                batch_bow.append(v)
            batch['bow'] = torch.stack(batch_bow)

            if device >= 0:
                batch = batch.cuda(device=device)
            return batch
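For illustration, a toy `topic` field and the bag-of-words row this collate would build from it (assuming `bow_vocab_size = 10`; the values are made up):

sample = {'topic': [(3, 2.0), (7, 1.0)]}   # (word_id, freq) pairs
# the corresponding row of batch['bow'] would be:
# [0., 0., 0., 2., 0., 0., 0., 1., 0., 0.]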
Code Example #19
        def collate(data_list):
            """
            collate
            ---
            data_list: List[Dict]
            """
            batch = Pack()
            for key in data_list[0].keys():
                batch[key] = list2tensor([x[key] for x in data_list])
            if device >= 0:
                batch = batch.cuda(device=device)

            # copy mechanism prepare
            raw_src = [x['raw_src'].split() for x in data_list]
            token2idx, idx2token, batch_pos_idx_map, idx2idx_mapping \
                = build_copy_mapping(raw_src, vocab)
            batch['token2idx'] = token2idx
            batch['idx2token'] = idx2token
            batch['batch_pos_idx_map'] = batch_pos_idx_map
            batch['idx2idx_mapping'] = idx2idx_mapping
            batch['output'] = '???'

            return batch
Code Example #20
File: seq2seq.py Project: yunying24/Dialogue
    def encode(self, inputs, hidden=None):
        outputs = Pack()
        enc_inputs = _, lengths = inputs.src[0][:, 1:-1], inputs.src[1] - 2

        enc_outputs, enc_hidden = self.encoder(enc_inputs, hidden)

        if self.with_bridge:
            enc_hidden = self.bridge(enc_hidden)

        dec_init_state = self.decoder.initialize_state(
            hidden=enc_hidden,
            attn_memory=enc_outputs if self.attn_mode else None,
            memory_lengths=lengths if self.attn_mode else None)
        return outputs, dec_init_state
Code Example #21
File: batcher.py Project: leehamw/CRMN
def create_turn_batch(data_list):
    """
    create_turn_batch
    """
    turn_batches = []
    for data_dict in data_list:
        batch = Pack()
        for key in data_dict.keys():
            if key in ['src', 'tgt', 'ptr_index', 'kb_index']:
                batch[key] = list2tensor(list(data_dict[key]))
            else:
                batch[key] = data_dict[key]
        turn_batches.append(batch)

    return turn_batches
Code Example #22
    def collect_metrics(self, outputs, target, emo_target):
        """
        collect_metrics
        """
        num_samples = target[0].size(0)
        num_words = target[1].sum().item()
        metrics = Pack(num_samples=num_samples)
        target_len = target[1]
        mask = sequence_mask(target_len)
        mask = mask.float()
        # logits = outputs.logits
        # nll = self.nll_loss(logits, target)
        out_copy = outputs.out_copy  # batch x max_len x src
        target_loss = out_copy.gather(2, target[0].unsqueeze(-1)).squeeze(-1)
        # bug fixed: mask after the log, so padded positions contribute 0
        # rather than log(1e-15)
        target_loss = (target_loss + 1e-15).log() * mask
        loss = -(target_loss.sum() / num_words)

        out_emo = outputs.logits  # batch x max_len x class_num
        batch_size, max_len, class_num = out_emo.size()
        target_emo_loss = out_emo.gather(
            2, emo_target[0].unsqueeze(-1)).squeeze(-1)
        target_len -= 1
        mask_ = sequence_mask(target_len)
        mask_ = mask_.float()
        new_mask = mask.data.new(batch_size, max_len).zero_()
        new_mask[:, :max_len - 1] = mask_

        # same fix as above: apply new_mask after the log
        target_emo_loss = (target_emo_loss + 1e-15).log() * new_mask
        emo_loss = -(target_emo_loss.sum() / num_words)

        metrics.add(loss=loss)
        metrics.add(emo_loss=emo_loss)
        # here, we only compute the accuracy of the copy distribution
        acc = accuracy(out_copy, target[0], mask=mask)
        metrics.add(acc=acc)
        return metrics
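A tiny numeric check of the masked-NLL pattern used above (toy values, purely illustrative): masking after the log keeps padded steps from contributing log(1e-15) ≈ -34.5 to the sum.

import torch

prob = torch.tensor([[0.5, 0.2], [0.9, 0.3]])   # gathered target probabilities
mask = torch.tensor([[1.0, 1.0], [1.0, 0.0]])   # second sequence has length 1
nll = -((prob + 1e-15).log() * mask).sum() / mask.sum()
# only the three unpadded entries contribute: -(log 0.5 + log 0.2 + log 0.9) / 3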
Code Example #23
    def encode_infe(self, inputs, elmo_embed, hidden=None):
        outputs = Pack()
        enc_inputs = inputs.num_src
        # input_raw = inputs.raw_src
        enc_outputs, enc_hidden = self.encoder.infer(enc_inputs, elmo_embed,
                                                     hidden)

        if self.with_bridge:
            enc_hidden = self.bridge1(enc_hidden)

        layer, batch_size, dim = enc_hidden.size()
        dec_init_state = self.decoder.initialize_state(
            hidden=enc_hidden,
            input_feed=enc_hidden.data.new(batch_size, dim).zero_() \
                .unsqueeze(1),
            attn_memory=enc_outputs if self.attn_mode else None,
            mask=inputs.mask[0])
        return outputs, dec_init_state
Code Example #24
File: memnet.py Project: little-women/lic-2019
    def encode(self, inputs, hidden=None):
        outputs = Pack()
        enc_inputs = _, lengths = inputs.src[0][:, 1:-1], inputs.src[1] - 2
        enc_outputs, enc_hidden = self.rnn_encoder(enc_inputs, hidden)

        # knowledge
        batch_size, sent_num, sent = inputs.cue[0].size()
        tmp_len = inputs.cue[1]
        tmp_len[tmp_len > 0] -= 2
        cue_inputs = inputs.cue[0][:, :, 1:-1], tmp_len

        u = self.mem_encoder(cue_inputs, enc_hidden[-1])

        dec_init_state = self.decoder.initialize_state(
            hidden=u.unsqueeze(0),
            attn_memory=enc_outputs if self.attn_mode else None,
            memory_lengths=lengths if self.attn_mode else None)

        return outputs, dec_init_state
Code Example #25
File: seq2seq.py Project: yunying24/Dialogue
    def collect_metrics(self, outputs, target):
        num_samples = target.size(0)
        metrics = Pack(num_samples=num_samples)
        loss = 0

        logits = outputs.logits
        nll = self.nll_loss(logits, target)
        num_words = target.ne(self.padding_idx).sum().item()
        acc = accuracy(logits, target, padding_idx=self.padding_idx)
        metrics.add(nll=(nll, num_words), acc=acc)
        loss += nll

        metrics.add(loss=loss)
        return metrics
Code Example #26
    def interact(self, src, cue=None):
        if src == "":
            return None

        inputs = Pack()
        src = self.src_field.numericalize([src])
        inputs.add(src=list2tensor(src))

        if cue is not None:
            cue = self.cue_field.numericalize([cue])
            inputs.add(cue=list2tensor(cue))
        if self.use_gpu:
            inputs = inputs.cuda()
        _, preds, _, _ = self.forward(inputs=inputs, num_candidates=1)

        pred = self.tgt_field.denumericalize(preds[0][0])

        return pred
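A hedged usage sketch (the `generator` handle and the whitespace-tokenized strings are assumptions for illustration):

reply = generator.interact("how are you", cue=["i like music"])
print(reply)  # one decoded response string, or None for empty input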
Code Example #27
File: pgnet.py Project: little-women/lic-2019
    def encode(self, inputs, hidden=None):
        outputs = Pack()
        hist_inputs = _, lengths = inputs.src[0][:, 1:-1], inputs.src[1] - 2
        # (batch_size, seq_length, hidden_size*num_directions)
        # (num_layers, batch_size, num_directions * hidden_size)
        hist_outputs, hist_hidden = self.hist_encoder(hist_inputs, hidden)

        if self.with_bridge:
            hist_hidden = self.bridge(hist_hidden)

        # knowledge
        batch_size, sent_num, sent = inputs.cue[0].size()
        tmp_len = inputs.cue[1]
        tmp_len[tmp_len > 0] -= 2
        fact_inputs = inputs.cue[0].view(-1, sent)[:, 1:-1], tmp_len.view(-1)
        fact_enc_outputs, fact_enc_hidden = self.fact_encoder(
            fact_inputs, hidden)
        # print(fact_enc_outputs.size())

        fact_outputs = fact_enc_outputs.view(batch_size, sent_num * (sent - 2),
                                             -1)

        # # (batch_size, sent_num, hidden_size)
        # fact_hidden = fact_enc_hidden[-1].view(batch_size, sent_num, -1)
        # # (batch_size, hidden_size)
        # fact_hidden = torch.sum(fact_hidden, 1).squeeze(1)

        # print(hist_hidden[-1].size(), hist_outputs.size(), fact_outputs.size())
        # print(lengths)
        # print(tmp_len)

        dec_init_state = self.decoder.initialize_state(
            hidden=hist_hidden,
            # fact_hidden=fact_hidden,
            fact=inputs.cue[0][:, :, 1:-1].contiguous().view(batch_size, -1),
            hist=inputs.src[0][:, 1:-1],
            attn_fact=fact_outputs if self.attn_mode else None,
            attn_hist=hist_outputs if self.attn_mode else None,
            fact_lengths=tmp_len if self.attn_mode else None,
            hist_lengths=lengths if self.attn_mode else None)

        return outputs, dec_init_state
Code Example #28
    def encode(self, inputs, hidden=None):
        """
        encode
        """
        outputs = Pack()
        enc_inputs, lengths = inputs.num_src
        pos_inputs = inputs.num_pos[0]
        enc_outputs, enc_hidden = self.encoder(enc_inputs, pos_inputs, hidden)

        if self.with_bridge:
            enc_hidden = self.bridge1(enc_hidden)

        layer, batch_size, dim = enc_hidden.size()
        dec_init_state = self.decoder.initialize_state(
            hidden=enc_hidden,
            input_feed=enc_hidden.data.new(batch_size, dim).zero_() \
                .unsqueeze(1),
            attn_memory=enc_outputs if self.attn_mode else None,
            mask=inputs.mask[0])
        return outputs, dec_init_state
Code Example #29
File: hseq2seq.py Project: little-women/lic-2019
    def encode(self, inputs, hidden=None):
        """
        encode
        """
        outputs = Pack()

        tmp_len = inputs.src[1]
        tmp_len[tmp_len > 0] -= 2
        enc_inputs = _, lengths = inputs.src[0][:, :, 1:-1], tmp_len
        hiera_lengths = lengths.gt(0).long().sum(dim=1)

        enc_outputs, enc_hidden, _ = self.encoder(enc_inputs, hidden)

        if self.with_bridge:
            enc_hidden = self.bridge(enc_hidden)

        dec_init_state = self.decoder.initialize_state(
            hidden=enc_hidden,
            attn_memory=enc_outputs if self.attn_mode else None,
            memory_lengths=hiera_lengths if self.attn_mode else None)
        return outputs, dec_init_state
Code Example #30
 def collate(data_list):
     """
     collate
     """
     data_list1, data_list2 = zip(*data_list)
     batch1 = Pack()
     batch2 = Pack()
     data_list1 = list(data_list1)
     data_list2 = list(data_list2)
     for key in data_list1[0].keys():
         batch1[key] = list2tensor([x[key] for x in data_list1])
     if device >= 0:
         batch1 = batch1.cuda(device=device)
     for key in data_list2[0].keys():
         batch2[key] = list2tensor([x[key] for x in data_list2])
     if device >= 0:
         batch2 = batch2.cuda(device=device)
     return batch1, batch2