def forward(self, incoming):
    # Decode the response with teacher forcing, then score it with
    # token-level cross entropy; perplexity is exp of the mean loss.
    inp = Storage()
    inp.resp_length = incoming.data.resp_length
    inp.embedding = incoming.resp.embedding
    inp.init_h = incoming.conn.init_h
    incoming.gen = gen = Storage()
    self.teacherForcing(inp, gen)

    # Drop padding before the loss; gen.w predicts the gold tokens resp[1:].
    w_o_f = flattenSequence(gen.w, incoming.data.resp_length - 1)
    data_f = flattenSequence(incoming.data.resp[1:], incoming.data.resp_length - 1)
    incoming.result.word_loss = self.lossCE(w_o_f, data_f)
    incoming.result.perplexity = torch.exp(incoming.result.word_loss)
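# Every snippet in this section relies on two helpers that are not shown.
# The following are minimal sketches under the assumption that `Storage` is
# a dict with attribute access and that `flattenSequence` concatenates the
# valid (non-padded) steps of a time-first padded batch; the real utilities
# may differ in detail.
import torch

class Storage(dict):
    """A dict whose keys are also readable/writable as attributes."""
    def __getattr__(self, key):
        try:
            return self[key]
        except KeyError:
            raise AttributeError(key)
    def __setattr__(self, key, value):
        self[key] = value

def flattenSequence(data, length):
    """Concatenate data[:length[i], i] over the batch, dropping padding.

    data:   seq_len * batch * ...   (time-first padded tensor)
    length: per-example valid lengths (iterable of ints)
    """
    return torch.cat([data[:l, i] for i, l in enumerate(length)], dim=0)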
def forward(self, incoming):
    inp = Storage()
    inp.length = incoming.data.sent_length
    inp.embedding = incoming.sent.embedding
    incoming.gen = gen = Storage()
    self.teacherForcing(inp, gen)

    # incoming.data.sent is batch-first here, so transpose to time-first
    # before aligning the gold tokens with the predictions.
    w_o_f = flattenSequence(gen.w, incoming.data.sent_length - 1)
    data_f = flattenSequence(
        incoming.data.sent.transpose(0, 1)[1:], incoming.data.sent_length - 1)
    incoming.result.word_loss = self.lossCE(w_o_f, data_f)
    incoming.result.perplexity = torch.exp(incoming.result.word_loss)
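# `self.teacherForcing` is never shown in these snippets. A minimal sketch
# for the attention-free case, assuming the decoder is a GRU
# (`self.GRULayer`) followed by a linear projection to the vocabulary
# (`self.wLinearLayer`); both names are placeholders, and the attention
# variants below (inp.post / inp.post_length) would additionally attend
# over the encoder states. gen.w ends up with shape
# (seq_len - 1) * batch * vocab_size, aligned with the gold tokens [1:].
def teacherForcing(self, inp, gen):
    embedding = inp.embedding[:-1]     # feed tokens 0..L-2; predict 1..L-1
    init_h = inp.get("init_h")         # None for the language-model case
    if init_h is not None:
        init_h = init_h.unsqueeze(0)   # 1 * batch * dh_size for a 1-layer GRU
    hidden, gen.h_n = self.GRULayer(embedding, init_h)
    gen.w = self.wLinearLayer(hidden)  # logits over the vocabulary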
def teacherForcingForward(self, incoming):  # copied from GenNetwork.forward
    # `self` was missing from the signature even though the body uses
    # self.teacherForcing and self.lossCE.
    inp = Storage()
    inp.resp_length = incoming.data.resp_length
    inp.embedding = incoming.resp.embedding
    inp.post = incoming.hidden.h            # encoder states for attention
    inp.post_length = incoming.data.post_length
    inp.init_h = incoming.conn.init_h
    incoming.gen = gen = Storage()
    self.teacherForcing(inp, gen)

    w_o_f = flattenSequence(gen.w, incoming.data.resp_length - 1)
    data_f = flattenSequence(incoming.data.resp[1:], incoming.data.resp_length - 1)
    incoming.result.word_loss = self.lossCE(w_o_f, data_f)
    incoming.result.perplexity = torch.exp(incoming.result.word_loss)
def forward(self, incoming):
    inp = Storage()
    inp.resp_length = incoming.data.resp_length
    inp.embedding = incoming.resp.embedding
    inp.post = incoming.hidden.h
    inp.post_length = incoming.data.post_length
    inp.init_h = incoming.conn.init_h
    incoming.gen = gen = Storage()
    self.teacherForcing(inp, gen)

    if self.training:
        # RAML: weight each sequence's loss by its sampled target's reward.
        incoming.result.word_loss = raml_loss(
            gen.w, incoming.data.resp[1:], incoming.data.resp_length - 1,
            incoming.data.rewards_ts, self.lossCE)
    else:
        # Standard teacher-forcing cross entropy at evaluation time.
        w_o_f = flattenSequence(gen.w, incoming.data.resp_length - 1)
        data_f = flattenSequence(incoming.data.resp[1:], incoming.data.resp_length - 1)
        incoming.result.word_loss = self.lossCE(w_o_f, data_f)
    incoming.result.perplexity = torch.exp(incoming.result.word_loss)
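# `raml_loss` is not defined in these snippets. A plausible sketch of
# reward-augmented maximum likelihood (Norouzi et al., 2016): per-sequence
# cross entropy is weighted by the reward attached to each sampled target.
# The signature mirrors the call above; the real implementation may
# normalize differently.
def raml_loss(gen_w, gold, length, rewards, lossCE):
    """gen_w: (len-1) * batch * vocab, gold: (len-1) * batch,
    length: per-example lengths minus one, rewards: per-example weights."""
    total, weight_sum = 0., 0.
    for i, l in enumerate(length):
        total += rewards[i] * lossCE(gen_w[:l, i], gold[:l, i])
        weight_sum += rewards[i]
    return total / weight_sum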
def forward(self, incoming):
    inp = Storage()
    inp.embedding = incoming.resp.embedding
    inp.post = incoming.hidden.h
    inp.post_length = incoming.data.post_length
    inp.resp_length = incoming.data.resp_length
    inp.init_h = incoming.conn.init_h
    inp.embLayer = incoming.resp.embLayer
    inp.max_sent_length = self.args.max_sent_length
    inp.sampling_proba = incoming.args.sampling_proba
    inp.dm = self.param.volatile.dm
    inp.batch_size = incoming.data.batch_size
    incoming.gen = gen = Storage()

    # Scheduled sampling: mix gold tokens and model samples while decoding.
    self.scheduledTeacherForcing(inp, gen)

    w_o_f = flattenSequence(gen.w_pro, incoming.data.resp_length - 1)
    data_f = flattenSequence(incoming.data.resp[1:], incoming.data.resp_length - 1)
    incoming.result.word_loss = self.lossCE(w_o_f, data_f)
    incoming.result.perplexity = torch.exp(incoming.result.word_loss)
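# `self.scheduledTeacherForcing` is not shown. A minimal sketch of
# scheduled sampling (Bengio et al., 2015), reusing the placeholder
# `self.GRULayer` / `self.wLinearLayer` names from the earlier sketch: at
# each step the gold embedding is replaced, with probability
# inp.sampling_proba, by the embedding of the model's own greedy
# prediction. For simplicity the coin is flipped once per step for the
# whole batch; per-example sampling is an obvious refinement.
import random
import torch

def scheduledTeacherForcing(self, inp, gen):
    h = inp.init_h.unsqueeze(0)            # 1 * batch * dh_size (1-layer GRU)
    next_emb = inp.embedding[0]            # start from the gold <go> token
    w_pro = []
    for t in range(inp.embedding.shape[0] - 1):
        output, h = self.GRULayer(next_emb.unsqueeze(0), h)
        logits = self.wLinearLayer(output.squeeze(0))   # batch * vocab_size
        w_pro.append(logits)
        if random.random() < inp.sampling_proba:
            # Feed the model's own prediction at the next step.
            next_emb = inp.embLayer(logits.argmax(dim=-1))
        else:
            next_emb = inp.embedding[t + 1]             # gold token
    gen.w_pro = torch.stack(w_pro, dim=0)  # (resp_len - 1) * batch * vocab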
def forward(self, incoming):
    # incoming.data.wiki_sen:    batch * wiki_len * wiki_sen_len
    # incoming.wiki_hidden.h1:   wiki_sen_len * (batch * wiki_len) * (eh_size * 2)
    # incoming.wiki_hidden.h_n1: wiki_len * batch * (eh_size * 2)
    # incoming.wiki_hidden.h2:   wiki_len * batch * (eh_size * 2)
    # incoming.wiki_hidden.h_n2: batch * (eh_size * 2)
    i = incoming.state.num
    valid_sen = incoming.state.valid_sen
    reverse_valid_sen = incoming.state.reverse_valid_sen

    inp = Storage()
    inp.wiki_sen = incoming.conn.selected_wiki_sen
    inp.wiki_hidden = incoming.conn.selected_wiki_h

    raw_resp_length = torch.tensor(incoming.data.resp_length[:, i], dtype=torch.long)
    raw_embedding = incoming.resp.embedding
    # Keep only the examples that still have an i-th turn.
    resp_length = inp.resp_length = torch.index_select(
        raw_resp_length, 0, valid_sen.cpu()).numpy()
    inp.embedding = torch.index_select(
        raw_embedding, 0, valid_sen).transpose(0, 1)  # length * valid_num * embedding_dim
    resp = torch.index_select(incoming.data.resp[:, i], 0, valid_sen).transpose(0, 1)[1:]
    inp.init_h = incoming.conn.init_h
    inp.wiki_cv = incoming.conn.wiki_cv

    incoming.gen = gen = Storage()
    self.teacherForcing(inp, gen)  # gen.h_n: valid_num * dh_dim

    # Restore the original batch order and pad the time axis so every
    # turn's output has the same length before stacking.
    w_slice = torch.index_select(gen.w, 1, reverse_valid_sen)
    if w_slice.shape[0] < self.args.max_sent_length:
        w_slice = torch.cat([
            w_slice,
            zeros(self.args.max_sent_length - w_slice.shape[0],
                  w_slice.shape[1], w_slice.shape[2])
        ], 0)
    if i == 0:
        incoming.state.w_all = w_slice.unsqueeze(0)
    else:
        incoming.state.w_all = torch.cat(
            [incoming.state.w_all, w_slice.unsqueeze(0)], 0)
    # state.w_all: sen_num * sen_length * batch_size * vocab_size

    w_o_f = flattenSequence(torch.log(gen.p), resp_length - 1)
    data_f = flattenSequence(resp, resp_length - 1)
    incoming.statistic.sen_num += incoming.state.valid_num

    # Accumulate a separate cross-entropy value for every sentence.
    now = 0
    for l in resp_length:
        loss = self.lossCE(w_o_f[now:now + l - 1, :], data_f[now:now + l - 1])
        if incoming.result.word_loss is None:
            incoming.result.word_loss = loss.clone()
        else:
            incoming.result.word_loss += loss.clone()
        incoming.statistic.sen_loss.append(loss.item())
        now += l - 1

    if i == incoming.state.last - 1:
        # After the final turn, report the mean per-sentence perplexity.
        incoming.statistic.sen_loss = torch.tensor(incoming.statistic.sen_loss)
        incoming.result.perplexity = torch.mean(torch.exp(incoming.statistic.sen_loss))
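# Note that unlike the earlier snippets, which exponentiate a corpus-mean
# token loss, this model records one cross-entropy value per sentence and
# reports the mean of per-sentence exp(loss) as perplexity. The undefined
# `zeros` helper above is presumably a device-aware wrapper around
# torch.zeros; a minimal sketch under that assumption:
def zeros(*sizes):
    # Placeholder: the real helper likely places the tensor on the model's
    # device (e.g. CUDA) rather than the CPU default.
    return torch.zeros(*sizes)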