Example #1
    def forward(self, input, mask=None, state=None):
        if mask is None:
            mask = T.to_cuda(torch.ones(*input.size()[:-1]))
        if self.batch_first:
            input, mask = input.transpose(0, 1), mask.transpose(0, 1)
        batch_size = input.size(1)
        n_cells = 2 * self.n_layers if self.bidirectional else self.n_layers
        if state is None:
            state = (T.to_cuda(
                torch.zeros(n_cells, batch_size, self.dim_hidden)),
                     T.to_cuda(
                         torch.zeros(n_cells, batch_size, self.dim_hidden)))
        i_layer = 0
        output = input
        new_states = []
        while i_layer < n_cells:
            # Slice this layer's (h, c) out of the stacked state: (1, batch, dim_hidden) each.
            state_i = tuple(s[i_layer].unsqueeze(0) for s in state)
            output_fw, state_i = self.layers[i_layer](output, mask, state_i)
            new_states.append(state_i)
            i_layer += 1

            if self.bidirectional:
                state_i = tuple(s[i_layer].unsqueeze(0) for s in state)
                output_bw, state_i = self.layers[i_layer](output, mask,
                                                          state_i)
                new_states.append(state_i)
                i_layer += 1
                output = torch.cat((output_fw, output_bw), -1)
            else:
                output = output_fw
        # Re-stack the per-layer states into (n_cells, batch, dim_hidden) tensors.
        state = tuple(torch.cat(s) for s in zip(*new_states))

        if self.batch_first:
            output = output.transpose(0, 1)
            state = tuple(s.transpose(0, 1) for s in state)
        return output, state
Example #2
    def forward(self, input, topic, mask=None, state=None):
        if mask is None:
            mask = T.to_cuda(torch.ones(*input.size()[:-1]))

        if self.batch_first:
            input, mask = input.transpose(0, 1), mask.transpose(0, 1)

        batch_size = input.size(1)
        if state is None:
            state = T.to_cuda(torch.zeros((1, batch_size, self.dim_hidden)))

        if self.reverse:
            input, mask = T.flip(input), T.flip(mask)

        do_dropout = (self.training and self.drop_method != 'none'
                      and self.drop_prob > 0.0)

        # Pre-compute the input-side terms of the update/reset gates and the
        # candidate state, each modulated by a projection of the topic vector;
        # the hidden-side terms are added step by step inside the loop below.
        ZR_x = self.x2zr(input)
        tp_zr_x, tp_zr_h = self.tp2zr_x(topic), self.tp2zr_h(topic)
        Z_xtp, R_xtp = torch.chunk(ZR_x * tp_zr_x.unsqueeze(0), 2, -1)
        ZR_xtp = torch.cat((self.xtp2z(Z_xtp), self.xtp2r(R_xtp)), -1)

        H_x = self.x2h(input)
        tp_h_x, tp_h_h = self.tp2h_x(topic), self.tp2h_h(topic)
        H_xtp = self.xtp2h(H_x * tp_h_x.unsqueeze(0))

        output = []
        for i, (zr_xtp, h_xtp) in enumerate(zip(ZR_xtp, H_xtp)):
            m = mask[i].unsqueeze(1)
            h_ = state.squeeze(0)

            # Topic-modulated update (z) and reset (r) gates.
            zr_htp = self.h2zr(h_) * tp_zr_h
            z_htp, r_htp = torch.chunk(zr_htp, 2, -1)
            zr_htp = torch.cat((self.htp2z(z_htp), self.htp2r(r_htp)), -1)
            pre_zr = (zr_xtp + zr_htp).sigmoid()
            z, r = torch.chunk(pre_zr, 2, dim=1)
            # Candidate state, then interpolation with the previous state.
            h = (h_xtp + r * self.htp2h(self.h2h(h_) * tp_h_h)).tanh()
            h = (1 - z) * h + z * h_

            if do_dropout and self.drop_method == 'link':
                h = F.dropout(h, p=self.drop_prob, training=self.training)

            # Padded time steps keep the previous hidden state.
            h = m * h + (1. - m) * h_

            if do_dropout and self.drop_method == 'output':
                output.append(
                    F.dropout(h, p=self.drop_prob, training=self.training))
            else:
                output.append(h)
            state = h.unsqueeze(0)

        output = torch.stack(output)

        if self.reverse:
            output = T.flip(output)

        if self.batch_first:
            output, state = output.transpose(0, 1), state.transpose(0, 1)
        return output, state
Example #3
 def ranking_loss(self, rank_matrices, gen_y_plain, slot_temp):
     rank_loss = None
     for islot, slot in enumerate(slot_temp):
         # Index of the gold ontology value for this slot, per instance.
         onty_label = T.to_cuda(
             torch.Tensor(self.ont_vocabs[slot].token2index(
                 [gen_y[islot] for gen_y in gen_y_plain])).long())
         rank_matrix = rank_matrices[islot]
         # Score of the gold value and of the best-scoring wrong value.
         target_rank_score = rank_matrix.gather(1, onty_label.unsqueeze(-1))
         res_matrix = rank_matrix.masked_fill(
             T.to_cuda(torch.arange(rank_matrix.size(-1))).expand_as(
                 rank_matrix) == onty_label.unsqueeze(-1), -1e9)
         contr_score = res_matrix.max(1, keepdim=True)[0]
         # Margin-1 hinge loss between the gold value and its strongest rival.
         contr_loss = torch.relu(1 - target_rank_score + contr_score).mean()
         rank_loss = contr_loss if rank_loss is None else rank_loss + contr_loss
     return rank_loss
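
For reference, a minimal self-contained sketch of the same max-margin idea on toy tensors (margin 1, with the contrast term taken from the best-scoring wrong ontology value); the shapes and names below are illustrative assumptions, not part of the original module.

import torch

torch.manual_seed(0)
batch, n_values = 4, 6
rank_matrix = torch.randn(batch, n_values)    # a score for every ontology value
onty_label = torch.tensor([2, 0, 5, 1])       # gold value index per instance

# Score of the gold value and of the best-scoring wrong value.
target_score = rank_matrix.gather(1, onty_label.unsqueeze(-1))
wrong_mask = torch.arange(n_values).expand_as(rank_matrix) == onty_label.unsqueeze(-1)
contrast_score = rank_matrix.masked_fill(wrong_mask, -1e9).max(1, keepdim=True)[0]

# Margin-1 hinge: the gold value should score at least 1 above its strongest rival.
rank_loss = torch.relu(1 - target_score + contrast_score).mean()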
Example #4
    def forward(self, input, mask=None, state=None):
        if mask is None:
            mask = T.to_cuda(torch.ones(*input.size()[:-1]))

        if self.batch_first:
            input, mask = input.transpose(0, 1), mask.transpose(0, 1)

        batch_size = input.size(1)
        if state is None:
            state = T.to_cuda(torch.zeros(1, batch_size, self.dim_hidden))

        if self.reverse:
            input, mask = T.flip(input), T.flip(mask)

        do_dropout = (self.training and self.drop_method != 'none'
                      and self.drop_prob > 0.0)

        ZR_x, H_x = self.x2zr(input), self.x2h(input)

        output = []
        for i, (zr_x, h_x) in enumerate(zip(ZR_x, H_x)):
            m = mask[i].unsqueeze(1)
            h_ = state.squeeze(0)

            pre_zr = (zr_x + self.h2zr(h_)).sigmoid()
            z, r = torch.chunk(pre_zr, 2, dim=1)
            h = (h_x + r * self.h2h(h_)).tanh()
            h = (1 - z) * h + z * h_

            if do_dropout and self.drop_method == 'link':
                h = F.dropout(h, p=self.drop_prob, training=self.training)

            h = m * h + (1. - m) * h_

            if do_dropout and self.drop_method == 'output':
                output.append(F.dropout(h, p=self.drop_prob, training=self.training))
            else:
                output.append(h)
            state = h.unsqueeze(0)

        output = torch.stack(output)

        if self.reverse:
            output = T.flip(output)

        if self.batch_first:
            output, state = output.transpose(0, 1), state.transpose(0, 1)
        return output, state
Example #5
def get_vocab_map(word_vocab, mem_vocab):
    """Map each memory-vocabulary token to its index in the word vocabulary."""
    map_id = []
    for mw in mem_vocab.get_vocab():
        map_id.append(word_vocab.index_of(mw))
    map_id = T.to_cuda(torch.Tensor(map_id).long())
    return map_id
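
A minimal, self-contained sketch of how the returned index map can be consumed, mirroring the self.src_embedding(self.mem2word_map) lookup in Example #9; ToyVocab and the inline to_cuda fallback are hypothetical stand-ins for the project's vocabulary class and T.to_cuda.

import torch
import torch.nn as nn

def to_cuda(t):
    # Stand-in for T.to_cuda: move to the GPU only when one is available.
    return t.cuda() if torch.cuda.is_available() else t

class ToyVocab:
    # Hypothetical stand-in exposing the two methods get_vocab_map relies on.
    def __init__(self, words):
        self.words = list(words)
        self.w2i = {w: i for i, w in enumerate(self.words)}

    def get_vocab(self):
        return self.words

    def index_of(self, w):
        return self.w2i.get(w, 0)  # 0 = UNK

word_vocab = ToyVocab(["<unk>", "hotel", "cheap", "north", "yes"])
mem_vocab = ToyVocab(["<unk>", "cheap", "north"])

map_id = to_cuda(torch.tensor([word_vocab.index_of(w)
                               for w in mem_vocab.get_vocab()],
                              dtype=torch.long))  # what get_vocab_map returns

# The map lets a memory-vocabulary embedding table be read straight out of the
# word-vocabulary embeddings, as Example #9 does with src_embedding(mem2word_map).
word_emb = to_cuda(nn.Embedding(len(word_vocab.get_vocab()), 8))
mem_embedding = word_emb(map_id)  # shape: (len(mem_vocab), 8)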
Example #6
    def encode_and_decode(self, data, use_teacher_forcing, slot_temp=None):
        slot_temp = data["slot_temp"][0]

        # Build unknown mask for memory to encourage generalization
        if self.args['unk_mask'] and self.decoder.training:
            story_size = data['context'].size()
            rand_mask = np.ones(story_size)
            bi_mask = np.random.binomial(
                [np.ones((story_size[0], story_size[1]))], 1 - self.dropout)[0]
            rand_mask = rand_mask * bi_mask
            rand_mask = torch.Tensor(rand_mask)
            rand_mask = T.to_cuda(rand_mask)
            story = data['context'] * rand_mask.long()
        else:
            story = data['context']

        # Encode dialog history
        encoded_outputs, encoded_hidden = self.encoder(
                            story.transpose(0, 1), data['context_len'])

        # Get the words that can be copied from the memory
        batch_size = len(data['context_len'])
        self.copy_list = data['context_plain']
        max_res_len = data['generate_y'].size(2) if self.encoder.training else 10
        all_point_outputs, all_gate_outputs, words_point_out, words_class_out = self.decoder.forward(batch_size, 
            encoded_hidden, encoded_outputs, data['context_len'], story, max_res_len, data['generate_y'], 
            use_teacher_forcing, slot_temp) 
        return all_point_outputs, all_gate_outputs, words_point_out, words_class_out
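
The unk_mask branch above randomly replaces context tokens with token id 0 during training (the "unknown mask" of the comment), keeping each token with probability 1 - self.dropout; a minimal self-contained sketch of that Bernoulli word dropout on toy token ids, with illustrative names and shapes:

import numpy as np
import torch

dropout = 0.2
context = torch.randint(1, 100, (3, 10))  # (batch, seq_len) token ids; 0 assumed to denote UNK

# Keep each token with probability 1 - dropout, otherwise zero it out.
bi_mask = np.random.binomial(1, 1 - dropout, size=tuple(context.size()))
story = context * torch.from_numpy(bi_mask).long()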
Example #7
 def process_sent(self, tokens):
     """
     Args:
         tokens: batch of tokenized sentences (list of lists of str).
     Returns:
         token ids produced by batch_to_ids, moved to the GPU when available.
     """
     token_ids = batch_to_ids(tokens)
     return T.to_cuda(token_ids)
Example #8
 def __drop_scaled(self, value):
     """Apply a shared Bernoulli dropout mask to `value` in place, scaled by 1 / keep_prob."""
     keep_prob = 1.0 - self.drop_prob
     self.mask = T.to_cuda(torch.bernoulli(
         torch.Tensor(1, self.dim_hidden).fill_(keep_prob)))
     value.data.set_(torch.mul(value, self.mask).data)
     value.data *= 1.0 / keep_prob
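
The same mask-and-rescale arithmetic written functionally, rather than through in-place .data mutation; a minimal sketch with illustrative shapes, not the module's API:

import torch

drop_prob = 0.3
keep_prob = 1.0 - drop_prob
dim_hidden = 8

h = torch.randn(4, dim_hidden)                                  # toy hidden states
mask = torch.bernoulli(torch.full((1, dim_hidden), keep_prob))  # one mask shared across the batch
h_dropped = h * mask / keep_prob                                # inverted-dropout rescaling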
Example #9
    def forward(self,
                batch_size,
                encoded_hidden,
                encoded_outputs,
                encoded_lens,
                ostory,
                story,
                pstory,
                max_res_len,
                target_batches,
                use_teacher_forcing,
                slot_temp,
                postr_info=None):
        all_point_outputs = torch.zeros(len(slot_temp), batch_size,
                                        max_res_len, self.mem_vocab_size)
        all_gate_outputs = torch.zeros(len(slot_temp), batch_size,
                                       self.nb_gate)
        if torch.cuda.is_available():
            all_point_outputs = all_point_outputs.cuda()
            all_gate_outputs = all_gate_outputs.cuda()

        # Get the slot embedding
        slot_emb_dict = {}
        for i, slot in enumerate(slot_temp):
            # Domain embedding
            if slot.split("-")[0] in self.slot_w2i.keys():
                domain_w2idx = [self.slot_w2i[slot.split("-")[0]]]
                domain_w2idx = T.to_cuda(torch.tensor(domain_w2idx))
                domain_emb = self.Slot_emb(domain_w2idx)
            # Slot embedding
            if slot.split("-")[1] in self.slot_w2i.keys():
                slot_w2idx = [self.slot_w2i[slot.split("-")[1]]]
                slot_w2idx = T.to_cuda(torch.tensor(slot_w2idx))
                slot_emb = self.Slot_emb(slot_w2idx)

            # Combine two embeddings as one query
            combined_emb = domain_emb + slot_emb
            slot_emb_dict[slot] = combined_emb

        # Compute pointer-generator output, decoding each (domain, slot) one-by-one
        words_point_out = []
        all_rank_matrices = []
        kld_reg = []

        mem_embedding = self.src_embedding(self.mem2word_map)

        analyze = [{} for _ in range(ostory.size(0))]

        for islot, slot in enumerate(slot_temp):
            hidden = encoded_hidden
            words = []
            slot_emb = slot_emb_dict[slot]
            slot_ctrl = self.dropout_layer(slot_emb).expand(
                batch_size, self.hidden_size)

            # C2F_A
            onty = [w.split() for w in self.ont_vocabs[slot].get_vocab()]
            onty_seqs = self.word_vocab.token2index(onty)
            onty_seqs, onty_lens = pad_and_merge(onty_seqs)
            _, onty_repr = self.encoder.forward_batfirst(
                T.to_cuda(onty_seqs), T.to_cuda(onty_lens))  # VD
            onty_repr = onty_repr.squeeze(0)

            cand_qry = slot_ctrl.unsqueeze(1) + onty_repr.unsqueeze(0)
            cand_att_sc = torch.bmm(cand_qry, encoded_outputs.transpose(1, 2))
            for i, l in enumerate(encoded_lens):
                if l < cand_att_sc.size(-1):
                    cand_att_sc[i, :, l:] = -np.inf
            cand_att_alpha = F.softmax(cand_att_sc, -1)
            cand_ctx = torch.bmm(cand_att_alpha, encoded_outputs)  # BVD

            rank_matrix = cand_ctx.mul(
                onty_repr.unsqueeze(0).expand_as(cand_ctx)).sum(-1)
            all_rank_matrices.append(rank_matrix)

            topk_rsc, topk_idx = torch.topk(rank_matrix, k=3, dim=1)
            topk_ctx = cand_ctx.gather(1,
                                       topk_idx.unsqueeze(2).expand(
                                           topk_idx.size(0), topk_idx.size(1),
                                           cand_ctx.size(2)))  # BKD
            topk_sc = torch.bmm(
                topk_ctx,
                self.W_att_prior_slot(slot_ctrl).unsqueeze(-1)).squeeze(
                    -1)  # BK
            topk_alpha = F.softmax(topk_sc, -1)
            postr_ctx = torch.bmm(topk_alpha.unsqueeze(1), topk_ctx).squeeze(1)

            # ######## old code
            # prior_ctx, prior_scores = self.attend2(
            #     encoded_outputs, encoded_outputs, self.W_att_prior_slot(slot_ctrl), encoded_lens)
            # ########
            prior_ctx = postr_ctx  # switch

            decoder_input, hidden = None, None

            for wi in range(max_res_len):
                if wi == 0:
                    decoder_input = slot_ctrl.contiguous()
                    hidden = prior_ctx.unsqueeze(0).contiguous()

                    dec_state, hidden = self.gru(decoder_input.unsqueeze(0),
                                                 hidden)

                    all_gate_outputs[islot] = self.W_gate(hidden.squeeze(0))
                else:
                    dec_state, hidden = self.gru(
                        decoder_input.expand_as(hidden), hidden)

                context_vec, logits, prob = self.attend(
                    encoded_outputs, hidden.squeeze(0), encoded_lens)

                p_vocab = self.attend_vocab(mem_embedding, hidden.squeeze(0))
                p_gen_vec = torch.cat(
                    [dec_state.squeeze(0), context_vec, decoder_input], -1)
                vocab_pointer_switches = self.sigmoid(self.W_ratio(p_gen_vec))

                p_context_ptr = T.to_cuda(torch.zeros(p_vocab.size()))
                p_context_ptr.scatter_add_(1, pstory, prob)
                final_p_vocab = (1 - vocab_pointer_switches).expand_as(p_context_ptr) * p_context_ptr + \
                                vocab_pointer_switches.expand_as(p_context_ptr) * p_vocab

                pred_word = torch.argmax(final_p_vocab, dim=1)
                all_point_outputs[islot, :, wi, :] = final_p_vocab

                words.append([
                    self.mem_vocab.index2token(w_idx.item())
                    for w_idx in pred_word
                ])

                if self.args["analyze"] and wi == 0:
                    k = 3
                    topk_prior_sc = [[(' '.join(
                        self.word_vocab.index2token(
                            ostory[i, w - 2:w + 3].clone().cpu().numpy())),
                                       v.item())
                                      for v, w in zip(*torch.topk(sc, k))]
                                     for i, sc in enumerate(prior_scores)]
                    topk_final_w = [[
                        (self.word_vocab.token_of(w.item()), v.item())
                        for v, w in zip(tpkv, tpkw)
                    ] for tpkv, tpkw in zip(*torch.topk(final_p_vocab, k, -1))]
                    topk_ptr_w = [[(' '.join(
                        self.word_vocab.index2token(
                            ostory[i, w - 2:w + 3].clone().cpu().numpy())),
                                    v.item())
                                   for v, w in zip(*torch.topk(sc, k))]
                                  for i, sc in enumerate(prob)]
                    topk_gen_w = [[
                        (self.word_vocab.token_of(w.item()), v.item())
                        for v, w in zip(tpkv, tpkw)
                    ] for tpkv, tpkw in zip(*torch.topk(p_vocab, k, -1))]

                    tgt_w = [
                        ' '.join(
                            self.word_vocab.index2token([
                                w.item() for w in inst[islot]
                                if w not in [1, 2]
                            ])) for inst in target_batches
                    ]
                    sws = vocab_pointer_switches.view(-1).tolist()

                    gate_slot = all_gate_outputs[islot].tolist()

                    for i_inst in range(batch_size):
                        analyze[i_inst][slot] = {
                            "ctx": topk_prior_sc[i_inst],
                            "ptr": topk_ptr_w[i_inst],
                            "gen": topk_gen_w[i_inst],
                            "tgt": tgt_w[i_inst],
                            "sw": sws[i_inst],
                            "final": topk_final_w[i_inst],
                            "gate": gate_slot[i_inst]
                        }

                if use_teacher_forcing:
                    decoder_input = mem_embedding[
                        target_batches[:, islot,
                                       wi]]  # Chosen word is next input
                else:
                    decoder_input = mem_embedding[pred_word]
                if torch.cuda.is_available():
                    decoder_input = decoder_input.cuda()

            # if not self.training:
            words_point_out.append(words)

        if self.args["analyze"]:
            return all_point_outputs, all_gate_outputs, words_point_out, [], analyze, all_rank_matrices
        else:
            return all_point_outputs, all_gate_outputs, words_point_out, [], all_rank_matrices
Example #10
 def get_state(self, bsz):
     """Return all-zero initial hidden/cell states of shape (2, bsz, hidden_size)."""
     return T.to_cuda(torch.zeros(2, bsz, self.hidden_size))
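
A minimal sketch of how such an initializer is typically consumed. Reading the leading 2 as num_layers × num_directions of a single-layer bidirectional nn.GRU is an assumption (the docstring could equally describe a stacked hidden/cell pair); ToyEncoder and to_cuda below are hypothetical stand-ins.

import torch
import torch.nn as nn

def to_cuda(t):
    # Stand-in for T.to_cuda.
    return t.cuda() if torch.cuda.is_available() else t

class ToyEncoder(nn.Module):
    # Hypothetical single-layer bidirectional GRU encoder.
    def __init__(self, dim_emb=16, hidden_size=32):
        super().__init__()
        self.hidden_size = hidden_size
        self.gru = nn.GRU(dim_emb, hidden_size, bidirectional=True)

    def get_state(self, bsz):
        return to_cuda(torch.zeros(2, bsz, self.hidden_size))

    def forward(self, emb):               # emb: (seq_len, batch, dim_emb)
        h0 = self.get_state(emb.size(1))  # (2, batch, hidden_size)
        outputs, hidden = self.gru(emb, h0)
        return outputs, hidden            # outputs: (seq_len, batch, 2 * hidden_size)

enc = to_cuda(ToyEncoder())
outputs, hidden = enc(to_cuda(torch.randn(7, 3, 16)))  # seq_len=7, batch=3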
Example #11
    def forward(self, input, mask=None, state=None):
        if mask is None:
            mask = T.to_cuda(torch.ones(*input.size()[:-1]))

        if self.batch_first:
            input, mask = input.transpose(0, 1), mask.transpose(0, 1)

        batch_size = input.size(1)
        if state is None:
            # An LSTM layer needs a (hidden, cell) pair, not a single tensor.
            state = (T.to_cuda(torch.zeros(1, batch_size, self.dim_hidden)),
                     T.to_cuda(torch.zeros(1, batch_size, self.dim_hidden)))

        if self.reverse:
            input, mask = T.flip(input), T.flip(mask)

        do_dropout = (self.training and self.drop_method != 'none'
                      and self.drop_prob > 0.0)

        H_x = self.x2h(input)

        output = []
        for i, h_x in enumerate(H_x):
            m = mask[i].unsqueeze(1)
            h_, c_ = state
            h_, c_ = h_.squeeze(0), c_.squeeze(0)

            preact = h_x + self.h2h(h_)
            # Input, forget and output gates plus the candidate cell state.
            i, f, o, cc = torch.chunk(preact, 4, 1)
            i, f, o = i.sigmoid(), f.sigmoid(), o.sigmoid()
            cc = cc.tanh()

            if do_dropout and self.drop_method == 'semeniuta':
                cc = F.dropout(cc, p=self.drop_prob, training=self.training)

            c = c_ * f + cc * i

            if do_dropout and self.drop_method == 'moon':
                self.__drop_scaled(c)

            h = o * c.tanh()

            if do_dropout:
                if self.drop_method == 'link':
                    h = F.dropout(h, p=self.drop_prob, training=self.training)
                elif self.drop_method == 'gal':
                    self.__drop_scaled(h)

            c = m * c + (1. - m) * c_
            h = m * h + (1. - m) * h_

            if do_dropout and self.drop_method == 'output':
                output.append(
                    F.dropout(h, p=self.drop_prob, training=self.training))
            else:
                output.append(h)
            h, c = h.unsqueeze(0), c.unsqueeze(0)
            state = (h, c)

        output = torch.stack(output)

        if self.reverse:
            output = T.flip(output)

        if self.batch_first:
            output = output.transpose(0, 1)
            state = tuple(s.transpose(0, 1) for s in state)
        return output, state