Example #1
class PostEncoder(nn.Module):
    def __init__(self, param):
        super().__init__()
        self.args = args = param.args
        self.param = param

        self.postGRU = MyGRU(args.embedding_size,
                             args.eh_size,
                             bidirectional=True)

    def forward(self, incoming):
        incoming.hidden = hidden = Storage()
        i = incoming.state.num
        # incoming.post.embedding : batch * sen_num * length * vec_dim
        # post_length : batch * sen_num
        raw_post = incoming.post.embedding
        raw_post_length = LongTensor(incoming.data.post_length[:, i])
        incoming.state.valid_sen = torch.sum(torch.nonzero(raw_post_length), 1)
        raw_reverse = torch.cumsum(torch.gt(raw_post_length, 0), 0) - 1
        incoming.state.reverse_valid_sen = raw_reverse * torch.ge(
            raw_reverse, 0).to(torch.long)
        valid_sen = incoming.state.valid_sen
        incoming.state.valid_num = valid_sen.shape[0]

        post = torch.index_select(raw_post, 0, valid_sen).transpose(
            0, 1)  # [length, valid_num, vec_dim]
        post_length = torch.index_select(raw_post_length, 0,
                                         valid_sen).cpu().numpy()

        hidden.h, hidden.h_n = self.postGRU.forward(post,
                                                    post_length,
                                                    need_h=True)
        hidden.length = post_length

    def detail_forward(self, incoming):
        incoming.hidden = hidden = Storage()
        # incoming.post.embedding : batch * sen_num * length * vec_dim
        # post_length : batch * sen_num
        raw_post = incoming.post.embedding
        raw_post_length = LongTensor(incoming.data.post_length)
        incoming.state.valid_sen = torch.sum(torch.nonzero(raw_post_length), 1)
        raw_reverse = torch.cumsum(torch.gt(raw_post_length, 0), 0) - 1
        incoming.state.reverse_valid_sen = raw_reverse * torch.ge(
            raw_reverse, 0).to(torch.long)
        valid_sen = incoming.state.valid_sen
        incoming.state.valid_num = valid_sen.shape[0]

        post = torch.index_select(raw_post, 0, valid_sen).transpose(
            0, 1)  # [length, valid_num, vec_dim]
        post_length = torch.index_select(
            raw_post_length, 0, valid_sen).cpu().numpy()  # [valid_num]

        hidden.h, hidden.h_n = self.postGRU.forward(post,
                                                    post_length,
                                                    need_h=True)
        hidden.length = post_length
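
The compacting done above, keeping only the samples whose post is non-empty and remembering how to map results back to the full batch, can be reproduced with plain PyTorch. A minimal sketch with made-up lengths; for a 1-D input, torch.nonzero(...).squeeze(1) gives the same result as the torch.sum(torch.nonzero(...), 1) used in the snippet:

import torch

raw_post_length = torch.tensor([3, 0, 5, 2])              # per-sample post length, 0 = no post this turn
valid_sen = torch.nonzero(raw_post_length).squeeze(1)     # indices of non-empty posts -> tensor([0, 2, 3])
raw_reverse = torch.cumsum(torch.gt(raw_post_length, 0).long(), 0) - 1
reverse_valid_sen = raw_reverse * torch.ge(raw_reverse, 0).long()
# reverse_valid_sen maps every original position to its row in the compacted
# batch (empty samples point at row 0): tensor([0, 0, 1, 2])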
Example #2
class WikiEncoder(nn.Module):
    def __init__(self, param):
        super().__init__()
        self.args = args = param.args
        self.param = param

        self.sentenceGRU = MyGRU(args.embedding_size,
                                 args.eh_size,
                                 bidirectional=True)

    def forward(self, incoming):
        i = incoming.state.num
        batch = incoming.wiki.embedding.shape[0]

        incoming.wiki_hidden = wiki_hidden = Storage()
        incoming.wiki_sen = incoming.data.wiki[:, i]  # [batch, wiki_sen_num, wiki_sen_len]
        wiki_length = incoming.data.wiki_length[:, i].reshape(
            -1)  # (batch * wiki_sen_num)
        embed = incoming.wiki.embedding.reshape(
            (-1, incoming.wiki.embedding.shape[2], self.args.embedding_size))
        # (batch * wiki_sen_num) * wiki_sen_len * embedding_size
        embed = embed.transpose(
            0, 1)  # wiki_sen_len * (batch * wiki_sen_num) * embedding_size

        wiki_hidden.h1, wiki_hidden.h_n1 = self.sentenceGRU.forward(
            embed, wiki_length, need_h=True)
        # [wiki_sen_len, batch * wiki_sen_num, 2 * eh_size], [batch * wiki_sen_num, 2 * eh_size]
        wiki_hidden.h1 = wiki_hidden.h1.reshape(
            (wiki_hidden.h1.shape[0], batch, -1, wiki_hidden.h1.shape[-1]))
        # [wiki_sen_len, batch,  wiki_sen_num, 2 * eh_size]
        wiki_hidden.h_n1 = wiki_hidden.h_n1.reshape(
            (batch, -1, 2 * self.args.eh_size)).transpose(0, 1)
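
The fold/unfold pattern above encodes every knowledge sentence of every dialogue in one GRU pass. A shape-only sketch, with torch.nn.GRU standing in for the repo's MyGRU wrapper (no length masking or packing, random data):

import torch
import torch.nn as nn

batch, sen_num, sen_len, emb, eh = 2, 3, 7, 16, 8
wiki_embedding = torch.randn(batch, sen_num, sen_len, emb)

gru = nn.GRU(emb, eh, bidirectional=True)                 # sentence-level encoder

x = wiki_embedding.reshape(-1, sen_len, emb)              # (batch * sen_num) x sen_len x emb
x = x.transpose(0, 1)                                     # sen_len x (batch * sen_num) x emb
h, h_n = gru(x)                                           # h: sen_len x (batch * sen_num) x 2*eh

h = h.reshape(sen_len, batch, sen_num, 2 * eh)            # unfold back to the per-dialogue layout
h_n = h_n.transpose(0, 1).reshape(batch, sen_num, 2 * eh).transpose(0, 1)
print(h.shape, h_n.shape)                                 # sen_len x batch x sen_num x 2*eh, sen_num x batch x 2*eh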
Example #3
class PostEncoder(nn.Module):
	def __init__(self, param):
		super().__init__()
		self.args = args = param.args
		self.param = param

		self.postGRU = MyGRU(args.embedding_size, args.eh_size, bidirectional=True)

	def forward(self, incoming):
		incoming.hidden = hidden = Storage()
		hidden.h_n = self.postGRU.forward(incoming.post.embedding, incoming.data.post_length)
Example #4
class PostEncoder(nn.Module):
	def __init__(self, param):
		super().__init__()
		self.args = args = param.args
		self.param = param

		self.postGRU = MyGRU(args.embedding_size, args.eh_size, bidirectional=True)
		self.drop = nn.Dropout(args.droprate)
		if self.args.batchnorm:
			self.seqnorm = SequenceBatchNorm(args.eh_size * 2)
			self.batchnorm = nn.BatchNorm1d(args.eh_size * 2)

	def forward(self, incoming):
		incoming.hidden = hidden = Storage()
		hidden.h_n, hidden.h = self.postGRU.forward(incoming.post.embedding, incoming.data.post_length, need_h=True)
		if self.args.batchnorm:
			hidden.h = self.seqnorm(hidden.h, incoming.data.post_length)
			hidden.h_n = self.batchnorm(hidden.h_n)
		hidden.h = self.drop(hidden.h)
		hidden.h_n = self.drop(hidden.h_n)
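
A minimal sketch of the regularisation used in this variant: the final state is batch-normalised and then both the per-step outputs and the final state go through dropout. SequenceBatchNorm is a repo-specific layer for padded sequences and is not reproduced here; nn.BatchNorm1d is shown only for the final state:

import torch
import torch.nn as nn

batch, seq_len, eh = 4, 6, 8
h = torch.randn(seq_len, batch, 2 * eh)     # per-step bidirectional GRU outputs
h_n = torch.randn(batch, 2 * eh)            # concatenated final forward/backward states

drop = nn.Dropout(0.2)
bn = nn.BatchNorm1d(2 * eh)

h_n = drop(bn(h_n))                         # normalise, then drop, the sentence summary
h = drop(h)                                 # dropout only on the per-step outputs here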
Example #5
class GenNetwork(nn.Module):
    def __init__(self, param):
        super().__init__()
        self.args = args = param.args
        self.param = param

        self.GRULayer = MyGRU(args.embedding_size + 2 * args.eh_size,
                              args.dh_size,
                              initpara=False)
        self.wLinearLayer = nn.Linear(args.dh_size,
                                      param.volatile.dm.vocab_size)
        self.lossCE = nn.NLLLoss(ignore_index=param.volatile.dm.unk_id)
        self.wCopyLinear = nn.Linear(args.eh_size * 2, args.dh_size)
        self.drop = nn.Dropout(args.droprate)
        self.start_generate_id = 2

    def teacherForcing(self, inp, gen):
        embedding = inp.embedding  # length * valid_num * embedding_dim
        length = inp.resp_length  # valid_num
        wiki_cv = inp.wiki_cv  # valid_num * (2 * eh_size)
        wiki_cv = wiki_cv.unsqueeze(0).repeat(embedding.shape[0], 1, 1)

        gen.h, gen.h_n = self.GRULayer.forward(torch.cat([embedding, wiki_cv],
                                                         dim=-1),
                                               length - 1,
                                               h_init=inp.init_h,
                                               need_h=True)

        gen.w = self.wLinearLayer(self.drop(gen.h))
        gen.w = torch.clamp(gen.w, max=5.0)
        gen.vocab_p = torch.exp(gen.w)
        wikiState = torch.transpose(
            torch.tanh(self.wCopyLinear(inp.wiki_hidden)), 0, 1)
        copyW = torch.exp(
            torch.clamp(torch.unsqueeze(
                torch.transpose(
                    torch.sum(
                        torch.unsqueeze(gen.h, 1) *
                        torch.unsqueeze(wikiState, 0), -1), 1, 2), 2),
                        max=5.0))

        inp.wiki_sen = inp.wiki_sen[:, :inp.wiki_hidden.shape[1]]
        copyHead = zeros(1, inp.wiki_sen.shape[0], inp.wiki_hidden.shape[1],
                         self.param.volatile.dm.vocab_size).scatter_(
                             3,
                             torch.unsqueeze(torch.unsqueeze(inp.wiki_sen, 0),
                                             3), 1)
        gen.copy_p = torch.matmul(copyW, copyHead).squeeze(2)
        gen.p = gen.vocab_p + gen.copy_p + 1e-10
        gen.p = gen.p / torch.unsqueeze(torch.sum(gen.p, 2), 2)
        gen.p = torch.clamp(gen.p, 1e-10, 1.0)

    def freerun(self, inp, gen, mode='max'):
        batch_size = inp.batch_size
        dm = self.param.volatile.dm

        first_emb = inp.embLayer(LongTensor([dm.go_id])).repeat(batch_size, 1)
        gen.w_pro = []
        gen.w_o = []
        gen.emb = []
        flag = zeros(batch_size).byte()
        EOSmet = []

        inp.wiki_sen = inp.wiki_sen[:, :inp.wiki_hidden.shape[1]]
        copyHead = zeros(1, inp.wiki_sen.shape[0], inp.wiki_hidden.shape[1],
                         self.param.volatile.dm.vocab_size).scatter_(
                             3,
                             torch.unsqueeze(torch.unsqueeze(inp.wiki_sen, 0),
                                             3), 1)
        wikiState = torch.transpose(
            torch.tanh(self.wCopyLinear(inp.wiki_hidden)), 0, 1)

        next_emb = first_emb
        gru_h = inp.init_h
        gen.p = []

        wiki_cv = inp.wiki_cv  # valid_num * (2 * eh_size)

        for _ in range(self.args.max_sent_length):
            now = torch.cat([next_emb, wiki_cv], dim=-1)

            gru_h = self.GRULayer.cell_forward(now, gru_h)
            w = self.wLinearLayer(gru_h)
            w = torch.clamp(w, max=5.0)
            vocab_p = torch.exp(w)
            copyW = torch.exp(
                torch.clamp(torch.unsqueeze(
                    (torch.sum(torch.unsqueeze(gru_h, 0) * wikiState,
                               -1).transpose_(0, 1)), 1),
                            max=5.0))  # batch * 1 * wiki_len
            copy_p = torch.matmul(copyW, copyHead).squeeze()

            p = vocab_p + copy_p + 1e-10
            p = p / torch.unsqueeze(torch.sum(p, 1), 1)
            p = torch.clamp(p, 1e-10, 1.0)
            gen.p.append(p)

            if mode == "max":
                w_o = torch.argmax(p[:, self.start_generate_id:],
                                   dim=1) + self.start_generate_id
                next_emb = inp.embLayer(w_o)
            elif mode == "gumbel":
                w_onehot, w_o = gumbel_max(p[:, self.start_generate_id:], 1, 1)
                w_o = w_o + self.start_generate_id
                next_emb = torch.sum(
                    torch.unsqueeze(w_onehot, -1) * inp.embLayer.weight[2:], 1)
            gen.w_o.append(w_o)
            gen.emb.append(next_emb)

            EOSmet.append(flag)
            flag = flag | (w_o == dm.eos_id).byte()
            if torch.sum(flag).detach().cpu().numpy() == batch_size:
                break

        EOSmet = 1 - torch.stack(EOSmet)
        gen.w_o = torch.stack(gen.w_o) * EOSmet.long()
        gen.emb = torch.stack(gen.emb) * EOSmet.float().unsqueeze(-1)
        gen.length = torch.sum(EOSmet, 0).detach().cpu().numpy()
        gen.h_n = gru_h

    def forward(self, incoming):
        # incoming.data.wiki_sen: batch * wiki_len * wiki_sen_len
        # incoming.wiki_hidden.h1: wiki_sen_len * (batch * wiki_len) * (eh_size * 2)
        # incoming.wiki_hidden.h_n1: wiki_len * batch * (eh_size * 2)
        # incoming.wiki_hidden.h2: wiki_len * batch * (eh_size * 2)
        # incoming.wiki_hidden.h_n2: batch * (eh_size * 2)

        i = incoming.state.num
        valid_sen = incoming.state.valid_sen
        reverse_valid_sen = incoming.state.reverse_valid_sen

        inp = Storage()
        inp.wiki_sen = incoming.conn.selected_wiki_sen
        inp.wiki_hidden = incoming.conn.selected_wiki_h
        raw_resp_length = torch.tensor(incoming.data.resp_length[:, i],
                                       dtype=torch.long)
        raw_embedding = incoming.resp.embedding

        resp_length = inp.resp_length = torch.index_select(
            raw_resp_length, 0, valid_sen.cpu()).numpy()
        inp.embedding = torch.index_select(
            raw_embedding, 0,
            valid_sen).transpose(0, 1)  # length * valid_num * embedding_dim
        resp = torch.index_select(incoming.data.resp[:, i], 0,
                                  valid_sen).transpose(0, 1)[1:]
        inp.init_h = incoming.conn.init_h
        inp.wiki_cv = incoming.conn.wiki_cv

        incoming.gen = gen = Storage()
        self.teacherForcing(inp, gen)
        # gen.h_n: valid_num * dh_dim

        w_slice = torch.index_select(gen.w, 1, reverse_valid_sen)
        if w_slice.shape[0] < self.args.max_sent_length:
            w_slice = torch.cat([
                w_slice,
                zeros(self.args.max_sent_length - w_slice.shape[0],
                      w_slice.shape[1], w_slice.shape[2])
            ], 0)
        if i == 0:
            incoming.state.w_all = w_slice.unsqueeze(0)
        else:
            incoming.state.w_all = torch.cat([
                incoming.state.w_all,
                w_slice.unsqueeze(0)
            ], 0)  #state.w_all: sen_num * sen_length * batch_size * vocab_size

        w_o_f = flattenSequence(torch.log(gen.p), resp_length - 1)
        data_f = flattenSequence(resp, resp_length - 1)
        incoming.statistic.sen_num += incoming.state.valid_num
        now = 0
        for l in resp_length:
            loss = self.lossCE(w_o_f[now:now + l - 1, :],
                               data_f[now:now + l - 1])
            if incoming.result.word_loss is None:
                incoming.result.word_loss = loss.clone()
            else:
                incoming.result.word_loss += loss.clone()
            incoming.statistic.sen_loss.append(loss.item())
            now += l - 1

        if i == incoming.state.last - 1:
            incoming.statistic.sen_loss = torch.tensor(
                incoming.statistic.sen_loss)
            incoming.result.perplexity = torch.mean(
                torch.exp(incoming.statistic.sen_loss))

    def detail_forward(self, incoming):

        index = i = incoming.state.num
        valid_sen = incoming.state.valid_sen
        reverse_valid_sen = incoming.state.reverse_valid_sen

        inp = Storage()
        inp.wiki_sen = incoming.conn.selected_wiki_sen
        inp.wiki_hidden = incoming.conn.selected_wiki_h
        inp.init_h = incoming.conn.init_h
        inp.wiki_cv = incoming.conn.wiki_cv

        batch_size = inp.batch_size = incoming.state.valid_num
        inp.embLayer = incoming.resp.embLayer

        incoming.gen = gen = Storage()
        self.freerun(inp, gen)

        dm = self.param.volatile.dm
        w_o = gen.w_o.detach().cpu().numpy()

        w_o_slice = torch.index_select(gen.w_o, 1, reverse_valid_sen)
        if w_o_slice.shape[0] < self.args.max_sent_length:
            w_o_slice = torch.cat([
                w_o_slice,
                zeros(self.args.max_sent_length - w_o_slice.shape[0],
                      w_o_slice.shape[1]).to(torch.long)
            ], 0)

        if index == 0:
            incoming.state.w_o_all = w_o_slice.unsqueeze(0)
        else:
            incoming.state.w_o_all = torch.cat(
                [incoming.state.w_o_all,
                 w_o_slice.unsqueeze(0)],
                0)  # state.w_o_all: sen_num * sen_length * batch_size
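
The copy mechanism in teacherForcing/freerun mixes a vocabulary distribution with scores over the tokens of the selected knowledge sentence. A standalone sketch with illustrative tensors (the original keeps an extra length dimension, omitted here): scatter_ builds a one-hot copyHead over the vocabulary, so the matmul routes each position's copy score onto that position's token id:

import torch

batch, wiki_len, vocab_size = 2, 4, 10
wiki_sen = torch.randint(0, vocab_size, (batch, wiki_len))   # token ids of the chosen wiki sentence
copy_scores = torch.rand(batch, 1, wiki_len)                 # attention-style score per wiki position

copy_head = torch.zeros(batch, wiki_len, vocab_size).scatter_(
    2, wiki_sen.unsqueeze(2), 1)                             # one-hot over the vocabulary
copy_p = torch.matmul(copy_scores, copy_head).squeeze(1)     # batch x vocab_size, mass only on wiki tokens

vocab_p = torch.rand(batch, vocab_size)                      # stand-in for exp(wLinearLayer(h))
p = vocab_p + copy_p + 1e-10
p = p / p.sum(dim=1, keepdim=True)                           # mixed generate/copy distribution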
Example #6
class ConnectLayer(nn.Module):
    def __init__(self, param):
        super().__init__()
        self.args = args = param.args
        self.param = param
        self.initLinearLayer = nn.Linear(args.eh_size * 4, args.dh_size)

        self.wiki_atten = nn.Softmax(dim=0)
        self.atten_lossCE = nn.CrossEntropyLoss(ignore_index=0)

        self.last_wiki = None
        self.hist_len = args.hist_len
        self.hist_weights = args.hist_weights

        self.compareGRU = MyGRU(2 * args.eh_size,
                                args.eh_size,
                                bidirectional=True)

        self.tilde_linear = nn.Linear(4 * args.eh_size, 2 * args.eh_size)
        self.attn_query = nn.Linear(2 * args.eh_size,
                                    2 * args.eh_size,
                                    bias=False)
        self.attn_key = nn.Linear(4 * args.eh_size,
                                  2 * args.eh_size,
                                  bias=False)
        self.attn_v = nn.Linear(2 * args.eh_size, 1, bias=False)

    def forward(self, incoming):
        incoming.conn = conn = Storage()
        index = incoming.state.num
        valid_sen = incoming.state.valid_sen
        valid_wiki_h_n1 = torch.index_select(
            incoming.wiki_hidden.h_n1, 1,
            valid_sen)  # [wiki_sen_num, valid_num, 2 * eh_size]
        valid_wiki_sen = torch.index_select(
            incoming.wiki_sen, 0,
            valid_sen)  # [valid_num, wiki_sen_num, wiki_sen_len]
        valid_wiki_h1 = torch.index_select(
            incoming.wiki_hidden.h1, 1,
            valid_sen)  # [wiki_sen_len, valid_num, wiki_sen_num, 2 * eh_size]
        atten_label = torch.index_select(incoming.data.atten[:, index], 0,
                                         valid_sen)  # valid_num
        valid_wiki_num = torch.index_select(
            LongTensor(incoming.data.wiki_num[:, index]), 0,
            valid_sen)  # valid_num

        if index == 0:
            tilde_wiki = zeros(1, 1, 2 * self.args.eh_size) * ones(
                valid_wiki_h_n1.shape[0], valid_wiki_h_n1.shape[1], 1)
        else:
            wiki_hidden = incoming.wiki_hidden
            wiki_num = incoming.data.wiki_num[:, index]  # [batch], numpy array
            wiki_hidden.h2, wiki_hidden.h_n2 = self.compareGRU.forward(
                wiki_hidden.h_n1, wiki_num, need_h=True)
            valid_wiki_h2 = torch.index_select(
                wiki_hidden.h2, 1,
                valid_sen)  # wiki_len * valid_num * (2 * eh_size)

            tilde_wiki_list = []
            for i in range(self.last_wiki.size(-1)):
                last_wiki = torch.index_select(self.last_wiki[:, :, i], 0,
                                               valid_sen).unsqueeze(
                                                   0)  # 1, valid_num, (2 * eh)
                tilde_wiki = torch.tanh(
                    self.tilde_linear(
                        torch.cat([
                            last_wiki - valid_wiki_h2,
                            last_wiki * valid_wiki_h2
                        ],
                                  dim=-1)))
                tilde_wiki_list.append(
                    tilde_wiki.unsqueeze(-1) * self.hist_weights[i])
            tilde_wiki = torch.cat(tilde_wiki_list, dim=-1).sum(dim=-1)

        query = self.attn_query(incoming.hidden.h_n)  # [valid_num, hidden]
        key = self.attn_key(
            torch.cat([valid_wiki_h_n1[:tilde_wiki.shape[0]], tilde_wiki],
                      dim=-1))  # [wiki_sen_num, valid_num, hidden]
        atten_sum = self.attn_v(torch.tanh(query + key)).squeeze(
            -1)  # [wiki_sen_num, valid_num]
        beta = atten_sum.t()  # [valid_num, wiki_len]

        mask = torch.arange(beta.shape[1], device=beta.device).long().expand(
            beta.shape[0],
            beta.shape[1]).transpose(0, 1)  # [wiki_sen_num, valid_num]
        expand_wiki_num = valid_wiki_num.unsqueeze(0).expand_as(
            mask)  # [wiki_sen_num, valid_num]
        reverse_mask = (expand_wiki_num <=
                        mask).float()  # [wiki_sen_num, valid_num]

        if index == 0:
            incoming.result.atten_loss = self.atten_lossCE(beta, atten_label)
        else:
            incoming.result.atten_loss += self.atten_lossCE(beta, atten_label)

        golden_alpha = zeros(beta.shape).scatter_(1, atten_label.unsqueeze(1),
                                                  1)
        golden_alpha = torch.t(golden_alpha).unsqueeze(2)
        wiki_cv = torch.sum(valid_wiki_h_n1[:golden_alpha.shape[0]] *
                            golden_alpha,
                            dim=0)  # valid_num * (2 * eh_size)
        conn.wiki_cv = wiki_cv
        conn.init_h = self.initLinearLayer(
            torch.cat([incoming.hidden.h_n, wiki_cv], 1))

        reverse_valid_sen = incoming.state.reverse_valid_sen
        if index == 0:
            self.last_wiki = torch.index_select(wiki_cv, 0,
                                                reverse_valid_sen).unsqueeze(
                                                    -1)  # [batch, 2 * eh_size]
        else:
            self.last_wiki = torch.cat([
                torch.index_select(wiki_cv, 0,
                                   reverse_valid_sen).unsqueeze(-1),
                self.last_wiki[:, :, :self.hist_len - 1]
            ],
                                       dim=-1)

        atten_indices = atten_label.unsqueeze(1)  # valid_num * 1
        atten_indices = torch.cat([
            torch.arange(atten_indices.shape[0]).unsqueeze(1),
            atten_indices.cpu()
        ], 1)  # valid_num * 2
        valid_wiki_h1 = torch.transpose(
            valid_wiki_h1, 0,
            1)  # valid_num * wiki_sen_len * wiki_len * (2 * eh_size)
        valid_wiki_h1 = torch.transpose(
            valid_wiki_h1, 1,
            2)  # valid_num * wiki_len * wiki_sen_len * (2 * eh_size)
        conn.selected_wiki_h = valid_wiki_h1[atten_indices.chunk(2,
                                                                 1)].squeeze(1)
        conn.selected_wiki_sen = valid_wiki_sen[atten_indices.chunk(
            2, 1)].squeeze(1)

    def detail_forward(self, incoming):
        incoming.conn = conn = Storage()
        index = incoming.state.num
        valid_sen = incoming.state.valid_sen
        valid_wiki_h_n1 = torch.index_select(
            incoming.wiki_hidden.h_n1, 1,
            valid_sen)  # [wiki_sen_num, valid_num, 2 * eh_size]
        valid_wiki_sen = torch.index_select(
            incoming.wiki_sen, 0,
            valid_sen)  # [valid_num, wiki_sen_num, wiki_sen_len]
        valid_wiki_h1 = torch.index_select(
            incoming.wiki_hidden.h1, 1,
            valid_sen)  # [wiki_sen_len, valid_num, wiki_sen_num, 2 * eh_size]
        atten_label = torch.index_select(incoming.data.atten[:, index], 0,
                                         valid_sen)  # valid_num
        valid_wiki_num = torch.index_select(
            LongTensor(incoming.data.wiki_num[:, index]), 0,
            valid_sen)  # valid_num

        if index == 0:
            tilde_wiki = zeros(1, 1, 2 * self.args.eh_size) * ones(
                valid_wiki_h_n1.shape[0], valid_wiki_h_n1.shape[1], 1)
        else:
            wiki_hidden = incoming.wiki_hidden
            wiki_num = incoming.data.wiki_num[:, index]  # [batch], numpy array
            wiki_hidden.h2, wiki_hidden.h_n2 = self.compareGRU.forward(
                wiki_hidden.h_n1, wiki_num, need_h=True)
            valid_wiki_h2 = torch.index_select(
                wiki_hidden.h2, 1,
                valid_sen)  # wiki_len * valid_num * (2 * eh_size)

            tilde_wiki_list = []
            for i in range(self.last_wiki.size(-1)):
                last_wiki = torch.index_select(self.last_wiki[:, :, i], 0,
                                               valid_sen).unsqueeze(
                                                   0)  # 1, valid_num, (2 * eh)
                tilde_wiki = torch.tanh(
                    self.tilde_linear(
                        torch.cat([
                            last_wiki - valid_wiki_h2,
                            last_wiki * valid_wiki_h2
                        ],
                                  dim=-1)))
                tilde_wiki_list.append(
                    tilde_wiki.unsqueeze(-1) * self.hist_weights[i])
            tilde_wiki = torch.cat(tilde_wiki_list, dim=-1).sum(dim=-1)

        query = self.attn_query(incoming.hidden.h_n)  # [valid_num, hidden]
        key = self.attn_key(
            torch.cat([valid_wiki_h_n1[:tilde_wiki.shape[0]], tilde_wiki],
                      dim=-1))  # [wiki_sen_num, valid_num, hidden]
        atten_sum = self.attn_v(torch.tanh(query + key)).squeeze(
            -1)  # [wiki_sen_num, valid_num]
        beta = atten_sum.t()  # [valid_num, wiki_len]

        mask = torch.arange(beta.shape[1], device=beta.device).long().expand(
            beta.shape[0],
            beta.shape[1]).transpose(0, 1)  # [wiki_sen_num, valid_num]
        expand_wiki_num = valid_wiki_num.unsqueeze(0).expand_as(
            mask)  # [wiki_sen_num, valid_num]
        reverse_mask = (expand_wiki_num <=
                        mask).float()  # [wiki_sen_num, valid_num]

        if index == 0:
            incoming.result.atten_loss = self.atten_lossCE(beta, atten_label)
        else:
            incoming.result.atten_loss += self.atten_lossCE(beta, atten_label)

        beta = torch.t(beta) - 1e10 * reverse_mask
        alpha = self.wiki_atten(beta)  # wiki_len * valid_num
        incoming.acc.prob.append(
            torch.index_select(
                alpha.t(), 0, incoming.state.reverse_valid_sen).cpu().tolist())
        atten_indices = torch.argmax(alpha, 0)
        alpha = zeros(beta.t().shape).scatter_(1, atten_indices.unsqueeze(1),
                                               1)
        alpha = torch.t(alpha)
        wiki_cv = torch.sum(valid_wiki_h_n1[:alpha.shape[0]] *
                            alpha.unsqueeze(2),
                            dim=0)  # valid_num * (2 * eh_size)
        conn.wiki_cv = wiki_cv
        conn.init_h = self.initLinearLayer(
            torch.cat([incoming.hidden.h_n, wiki_cv], 1))

        reverse_valid_sen = incoming.state.reverse_valid_sen
        if index == 0:
            self.last_wiki = torch.index_select(wiki_cv, 0,
                                                reverse_valid_sen).unsqueeze(
                                                    -1)  # [batch, 2 * eh_size]
        else:
            self.last_wiki = torch.cat([
                torch.index_select(wiki_cv, 0,
                                   reverse_valid_sen).unsqueeze(-1),
                self.last_wiki[:, :, :self.hist_len - 1]
            ],
                                       dim=-1)

        incoming.acc.label.append(
            torch.index_select(atten_label, 0,
                               reverse_valid_sen).cpu().tolist())
        incoming.acc.pred.append(
            torch.index_select(atten_indices, 0,
                               reverse_valid_sen).cpu().tolist())

        atten_indices = atten_indices.unsqueeze(1)
        atten_indices = torch.cat([
            torch.arange(atten_indices.shape[0]).unsqueeze(1),
            atten_indices.cpu()
        ], 1)  # valid_num * 2
        valid_wiki_h1 = torch.transpose(
            valid_wiki_h1, 0,
            1)  # valid_num * wiki_sen_len * wiki_len * (2 * eh_size)
        valid_wiki_h1 = torch.transpose(
            valid_wiki_h1, 1,
            2)  # valid_num * wiki_len * wiki_sen_len * (2 * eh_size)
        conn.selected_wiki_h = valid_wiki_h1[atten_indices.chunk(
            2, 1)].squeeze(1)  # valid_num * wiki_sen_len * (2 * eh_size)
        conn.selected_wiki_sen = valid_wiki_sen[atten_indices.chunk(
            2, 1)].squeeze(1)  # valid_num * wiki_sen_len

    def forward_disentangle(self, incoming):
        incoming.conn = conn = Storage()
        index = incoming.state.num
        valid_sen = incoming.state.valid_sen
        valid_wiki_h_n1 = torch.index_select(
            incoming.wiki_hidden.h_n1, 1,
            valid_sen)  # [wiki_sen_num, valid_num, 2 * eh_size]
        valid_wiki_sen = torch.index_select(
            incoming.wiki_sen, 0,
            valid_sen)  # [valid_num, wiki_sen_num, wiki_sen_len]
        valid_wiki_h1 = torch.index_select(
            incoming.wiki_hidden.h1, 1,
            valid_sen)  # [wiki_sen_len, valid_num, wiki_sen_num, 2 * eh_size]
        atten_label = torch.index_select(incoming.data.atten[:, index], 0,
                                         valid_sen)  # valid_num
        valid_wiki_num = torch.index_select(
            LongTensor(incoming.data.wiki_num[:, index]), 0,
            valid_sen)  # valid_num

        reverse_valid_sen = incoming.state.reverse_valid_sen
        self.beta = torch.sum(valid_wiki_h_n1 * incoming.hidden.h_n,
                              dim=2)  # wiki_len * valid_num
        self.beta = torch.t(self.beta)  # [valid_num, wiki_len]

        mask = torch.arange(
            self.beta.shape[1], device=self.beta.device).long().expand(
                self.beta.shape[0],
                self.beta.shape[1]).transpose(0,
                                              1)  # [wiki_sen_num, valid_num]
        expand_wiki_num = valid_wiki_num.unsqueeze(0).expand_as(
            mask)  # [wiki_sen_num, valid_num]
        reverse_mask = (expand_wiki_num <=
                        mask).float()  # [wiki_sen_num, valid_num]

        if index > 0:
            wiki_hidden = incoming.wiki_hidden
            wiki_num = incoming.data.wiki_num[:, index]  # [batch], numpy array
            wiki_hidden.h2, wiki_hidden.h_n2 = self.compareGRU.forward(
                wiki_hidden.h_n1, wiki_num, need_h=True)
            valid_wiki_h2 = torch.index_select(
                wiki_hidden.h2, 1,
                valid_sen)  # wiki_len * valid_num * (2 * eh_size)

            tilde_wiki_list = []
            for i in range(self.last_wiki.size(-1)):
                last_wiki = torch.index_select(self.last_wiki[:, :, i], 0,
                                               valid_sen).unsqueeze(
                                                   0)  # 1, valid_num, (2 * eh)
                tilde_wiki = torch.tanh(
                    self.tilde_linear(
                        torch.cat([
                            last_wiki - valid_wiki_h2,
                            last_wiki * valid_wiki_h2
                        ],
                                  dim=-1)))
                tilde_wiki_list.append(
                    tilde_wiki.unsqueeze(-1) * self.hist_weights[i])
            tilde_wiki = torch.cat(tilde_wiki_list, dim=-1).sum(dim=-1)
            # wiki_len * valid_num * (2 * eh_size)

            query = self.attn_query(tilde_wiki)  # [1, valid_num, hidden]
            key = self.attn_key(
                torch.cat([valid_wiki_h2, tilde_wiki],
                          dim=-1))  # [wiki_sen_num, valid_num, hidden]
            atten_sum = self.attn_v(torch.tanh(query + key)).squeeze(
                -1)  # [wiki_sen_num, valid_num]

            self.beta = self.beta[:, :atten_sum.shape[0]] + torch.t(atten_sum)

        if index == 0:
            incoming.result.atten_loss = self.atten_lossCE(
                self.beta,  #self.alpha.t().log(),
                atten_label)
        else:
            incoming.result.atten_loss += self.atten_lossCE(
                self.beta,  #self.alpha.t().log(),
                atten_label)

        golden_alpha = zeros(self.beta.shape).scatter_(
            1, atten_label.unsqueeze(1), 1)
        golden_alpha = torch.t(golden_alpha).unsqueeze(2)
        wiki_cv = torch.sum(valid_wiki_h_n1[:golden_alpha.shape[0]] *
                            golden_alpha,
                            dim=0)  # valid_num * (2 * eh_size)
        conn.wiki_cv = wiki_cv
        conn.init_h = self.initLinearLayer(
            torch.cat([incoming.hidden.h_n, wiki_cv], 1))

        if index == 0:
            self.last_wiki = torch.index_select(wiki_cv, 0,
                                                reverse_valid_sen).unsqueeze(
                                                    -1)  # [batch, 2 * eh_size]
        else:
            self.last_wiki = torch.cat([
                torch.index_select(wiki_cv, 0,
                                   reverse_valid_sen).unsqueeze(-1),
                self.last_wiki[:, :, :self.hist_len - 1]
            ],
                                       dim=-1)

        atten_indices = atten_label.unsqueeze(1)  # valid_num * 1
        atten_indices = torch.cat([
            torch.arange(atten_indices.shape[0]).unsqueeze(1),
            atten_indices.cpu()
        ], 1)  # valid_num * 2
        valid_wiki_h1 = torch.transpose(
            valid_wiki_h1, 0,
            1)  # valid_num * wiki_sen_len * wiki_len * (2 * eh_size)
        valid_wiki_h1 = torch.transpose(
            valid_wiki_h1, 1,
            2)  # valid_num * wiki_len * wiki_sen_len * (2 * eh_size)
        conn.selected_wiki_h = valid_wiki_h1[atten_indices.chunk(2,
                                                                 1)].squeeze(1)
        conn.selected_wiki_sen = valid_wiki_sen[atten_indices.chunk(
            2, 1)].squeeze(1)

    def detail_forward_disentangle(self, incoming):
        incoming.conn = conn = Storage()
        index = incoming.state.num
        valid_sen = incoming.state.valid_sen
        valid_wiki_h_n1 = torch.index_select(
            incoming.wiki_hidden.h_n1, 1,
            valid_sen)  # [wiki_sen_num, valid_num, 2 * eh_size]
        valid_wiki_sen = torch.index_select(incoming.wiki_sen, 0, valid_sen)
        valid_wiki_h1 = torch.index_select(incoming.wiki_hidden.h1, 1,
                                           valid_sen)
        atten_label = torch.index_select(incoming.data.atten[:, index], 0,
                                         valid_sen)  # valid_num
        valid_wiki_num = torch.index_select(
            LongTensor(incoming.data.wiki_num[:, index]), 0,
            valid_sen)  # valid_num

        reverse_valid_sen = incoming.state.reverse_valid_sen
        self.beta = torch.sum(valid_wiki_h_n1 * incoming.hidden.h_n, dim=2)
        self.beta = torch.t(self.beta)  # [valid_num, wiki_len]

        mask = torch.arange(
            self.beta.shape[1], device=self.beta.device).long().expand(
                self.beta.shape[0],
                self.beta.shape[1]).transpose(0,
                                              1)  # [wiki_sen_num, valid_num]
        expand_wiki_num = valid_wiki_num.unsqueeze(0).expand_as(
            mask)  # [wiki_sen_num, valid_num]
        reverse_mask = (expand_wiki_num <=
                        mask).float()  # [wiki_sen_num, valid_num]

        if index > 0:
            wiki_hidden = incoming.wiki_hidden
            wiki_num = incoming.data.wiki_num[:, index]  # [batch], numpy array
            wiki_hidden.h2, wiki_hidden.h_n2 = self.compareGRU.forward(
                wiki_hidden.h_n1, wiki_num, need_h=True)
            valid_wiki_h2 = torch.index_select(
                wiki_hidden.h2, 1,
                valid_sen)  # wiki_len * valid_num * (2 * eh_size)

            tilde_wiki_list = []
            for i in range(self.last_wiki.size(-1)):
                last_wiki = torch.index_select(self.last_wiki[:, :, i], 0,
                                               valid_sen).unsqueeze(
                                                   0)  # 1, valid_num, (2 * eh)
                tilde_wiki = torch.tanh(
                    self.tilde_linear(
                        torch.cat([
                            last_wiki - valid_wiki_h2,
                            last_wiki * valid_wiki_h2
                        ],
                                  dim=-1)))
                tilde_wiki_list.append(
                    tilde_wiki.unsqueeze(-1) * self.hist_weights[i])
            tilde_wiki = torch.cat(tilde_wiki_list, dim=-1).sum(
                dim=-1)  # wiki_len * valid_num * (2 * eh_size)

            query = self.attn_query(tilde_wiki)  # [1, valid_num, hidden]
            key = self.attn_key(
                torch.cat([valid_wiki_h2, tilde_wiki],
                          dim=-1))  # [wiki_sen_num, valid_num, hidden]
            atten_sum = self.attn_v(torch.tanh(query + key)).squeeze(
                -1)  # [wiki_sen_num, valid_num]

            self.beta = self.beta[:, :atten_sum.shape[0]] + torch.t(atten_sum)

        if index == 0:
            incoming.result.atten_loss = self.atten_lossCE(
                self.beta,  #self.alpha.t().log(),
                atten_label)
        else:
            incoming.result.atten_loss += self.atten_lossCE(
                self.beta,  #self.alpha.t().log(),
                atten_label)

        self.beta = torch.t(
            self.beta) - 1e10 * reverse_mask[:self.beta.shape[1]]
        self.alpha = self.wiki_atten(self.beta)  # wiki_len * valid_num
        incoming.acc.prob.append(
            torch.index_select(
                self.alpha.t(), 0,
                incoming.state.reverse_valid_sen).cpu().tolist())
        atten_indices = torch.argmax(self.alpha, 0)  # valid_num
        alpha = zeros(self.beta.t().shape).scatter_(1,
                                                    atten_indices.unsqueeze(1),
                                                    1)
        alpha = torch.t(alpha)
        wiki_cv = torch.sum(valid_wiki_h_n1[:alpha.shape[0]] *
                            alpha.unsqueeze(2),
                            dim=0)  # valid_num * (2 * eh_size)
        conn.wiki_cv = wiki_cv
        conn.init_h = self.initLinearLayer(
            torch.cat([incoming.hidden.h_n, wiki_cv], 1))

        if index == 0:
            self.last_wiki = torch.index_select(wiki_cv, 0,
                                                reverse_valid_sen).unsqueeze(
                                                    -1)  # [batch, 2 * eh_size]
        else:
            self.last_wiki = torch.cat([
                torch.index_select(wiki_cv, 0,
                                   reverse_valid_sen).unsqueeze(-1),
                self.last_wiki[:, :, :self.hist_len - 1]
            ],
                                       dim=-1)

        incoming.acc.label.append(
            torch.index_select(atten_label, 0,
                               reverse_valid_sen).cpu().tolist())
        incoming.acc.pred.append(
            torch.index_select(atten_indices, 0,
                               reverse_valid_sen).cpu().tolist())

        atten_indices = atten_indices.unsqueeze(1)
        atten_indices = torch.cat([
            torch.arange(atten_indices.shape[0]).unsqueeze(1),
            atten_indices.cpu()
        ], 1)  # valid_num * 2
        valid_wiki_h1 = torch.transpose(
            valid_wiki_h1, 0,
            1)  # valid_num * wiki_sen_len * wiki_len * (2 * eh_size)
        valid_wiki_h1 = torch.transpose(
            valid_wiki_h1, 1,
            2)  # valid_num * wiki_len * wiki_sen_len * (2 * eh_size)
        conn.selected_wiki_h = valid_wiki_h1[atten_indices.chunk(
            2, 1)].squeeze(1)  # valid_num * wiki_sen_len * (2 * eh_size)
        conn.selected_wiki_sen = valid_wiki_sen[atten_indices.chunk(
            2, 1)].squeeze(1)  # valid_num * wiki_sen_len
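
The sentence selection above is additive attention with padding masked out: a query from the dialogue summary, keys from the knowledge sentences, scores v^T tanh(q + k), and positions beyond each sample's real sentence count pushed to -1e10 before the softmax. A simplified standalone sketch (the original key also concatenates the history comparison vector tilde_wiki, which is dropped here):

import torch
import torch.nn as nn

wiki_sen_num, valid_num, hidden = 5, 3, 8
attn_query = nn.Linear(hidden, hidden, bias=False)
attn_key = nn.Linear(hidden, hidden, bias=False)
attn_v = nn.Linear(hidden, 1, bias=False)

h_n = torch.randn(valid_num, hidden)                       # dialogue-side summary (query source)
wiki_h_n = torch.randn(wiki_sen_num, valid_num, hidden)    # one vector per knowledge sentence
wiki_num = torch.tensor([5, 2, 4])                         # real sentence count per sample

query = attn_query(h_n)                                    # broadcasts over wiki_sen_num
key = attn_key(wiki_h_n)
score = attn_v(torch.tanh(query + key)).squeeze(-1)        # wiki_sen_num x valid_num

positions = torch.arange(wiki_sen_num).unsqueeze(1).expand(wiki_sen_num, valid_num)
pad_mask = (positions >= wiki_num.unsqueeze(0)).float()    # 1 where the sentence index is padding
alpha = torch.softmax(score - 1e10 * pad_mask, dim=0)      # attention over knowledge sentences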
Example #7
class GenNetwork(nn.Module):
	def __init__(self, param):
		super().__init__()
		self.args = args = param.args
		self.param = param

		self.GRULayer = MyGRU(args.embedding_size, args.dh_size, initpara=False)
		self.wLinearLayer = nn.Linear(args.dh_size, param.volatile.dm.vocab_size)
		self.lossCE = nn.CrossEntropyLoss(ignore_index=param.volatile.dm.unk_id)
		self.start_generate_id = 2

	def teacherForcing(self, inp, gen):
		embedding = inp.embedding
		length = inp.resp_length
		gen.h, _ = self.GRULayer.forward(embedding, length-1, h_init=inp.init_h, need_h=True)
		gen.w = self.wLinearLayer(gen.h)

	def freerun(self, inp, gen, mode='max'):
		batch_size = inp.batch_size
		dm = self.param.volatile.dm

		first_emb = inp.embLayer(LongTensor([dm.go_id])).repeat(batch_size, 1)
		gen.w_pro = []
		gen.w_o = []
		gen.emb = []
		flag = zeros(batch_size).byte()
		EOSmet = []

		next_emb = first_emb
		gru_h = inp.init_h
		for _ in range(self.args.max_sen_length):
			now = next_emb
			gru_h = self.GRULayer.cell_forward(now, gru_h)
			w = self.wLinearLayer(gru_h)
			gen.w_pro.append(w)
			if mode == "max":
				w_o = torch.argmax(w[:, self.start_generate_id:], dim=1) + self.start_generate_id
				next_emb = inp.embLayer(w_o)
			elif mode == "gumbel":
				w_onehot, w_o = gumbel_max(w[:, self.start_generate_id:], 1, 1)
				w_o = w_o + self.start_generate_id
				next_emb = torch.sum(torch.unsqueeze(w_onehot, -1) * inp.embLayer.weight[2:], 1)
			gen.w_o.append(w_o)
			gen.emb.append(next_emb)

			EOSmet.append(flag)
			flag = flag | (w_o == dm.eos_id)
			if torch.sum(flag).detach().cpu().numpy() == batch_size:
				break

		EOSmet = 1-torch.stack(EOSmet)
		gen.w_o = torch.stack(gen.w_o) * EOSmet.long()
		gen.emb = torch.stack(gen.emb) * EOSmet.float().unsqueeze(-1)
		gen.length = torch.sum(EOSmet, 0).detach().cpu().numpy()

	def forward(self, incoming):
		inp = Storage()
		inp.resp_length = incoming.data.resp_length
		inp.embedding = incoming.resp.embedding
		inp.init_h = incoming.conn.init_h

		incoming.gen = gen = Storage()
		self.teacherForcing(inp, gen)

		w_o_f = flattenSequence(gen.w, incoming.data.resp_length-1)
		data_f = flattenSequence(incoming.data.resp[1:], incoming.data.resp_length-1)
		incoming.result.word_loss = self.lossCE(w_o_f, data_f)
		incoming.result.perplexity = torch.exp(incoming.result.word_loss)

	def detail_forward(self, incoming):
		inp = Storage()
		batch_size = inp.batch_size = incoming.data.batch_size
		inp.init_h = incoming.conn.init_h
		inp.embLayer = incoming.resp.embLayer

		incoming.gen = gen = Storage()
		self.freerun(inp, gen)

		dm = self.param.volatile.dm
		w_o = gen.w_o.detach().cpu().numpy()
		incoming.result.resp_str = resp_str = \
				[" ".join(dm.index_to_sen(w_o[:, i].tolist())) for i in range(batch_size)]
		incoming.result.golden_str = golden_str = \
				[" ".join(dm.index_to_sen(incoming.data.resp[:, i].detach().cpu().numpy().tolist()))\
				for i in range(batch_size)]
		incoming.result.post_str = post_str = \
				[" ".join(dm.index_to_sen(incoming.data.post[:, i].detach().cpu().numpy().tolist()))\
				for i in range(batch_size)]
		incoming.result.show_str = "\n".join(["post: " + a + "\n" + "resp: " + b + "\n" + \
				"golden: " + c + "\n" \
				for a, b, c in zip(post_str, resp_str, golden_str)])
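
The freerun loops in Examples #5 and #7 track finished sequences with an EOS flag that is recorded before each step, so 1 - EOSmet keeps the <eos> token itself and zeroes everything emitted after it. A toy standalone illustration with made-up token ids:

import torch

eos_id = 3
steps = [torch.tensor([5, 3]),    # step 1: sample 1 emits <eos>
         torch.tensor([3, 7]),    # step 2: sample 0 emits <eos>, sample 1 is already done
         torch.tensor([9, 8])]    # step 3: both are already done

flag = torch.zeros(2, dtype=torch.bool)
EOSmet, w_o = [], []
for w in steps:
    EOSmet.append(flag.clone())   # flag *before* this step, so <eos> itself stays unmasked
    w_o.append(w)
    flag = flag | (w == eos_id)

EOSmet = 1 - torch.stack(EOSmet).long()   # 1 while still generating, 0 afterwards
w_o = torch.stack(w_o) * EOSmet           # tokens after <eos> become 0 (padding)
length = EOSmet.sum(0)                    # generated length per sample, <eos> included
print(w_o.tolist(), length.tolist())      # [[5, 3], [3, 0], [0, 0]] and [2, 1]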