def evaluate(encoder, decoder, sentence, max_length=MAX_LENGTH):
    with torch.no_grad():
        input_tensor = sentence2sequence(input_lang, sentence)
        input_length = len(input_tensor)
        encoder_input = LongTensor(input_tensor).view(1, -1)
        encoder_outputs, encoder_hidden = encoder(encoder_input,
                                                  LongTensor([input_length]))

        decoder_hidden = encoder_hidden
        decoder_input = LongTensor([SOS_token]).repeat(encoder_input.shape[0], 1)

        decoded_words = []
        for di in range(max_length):
            decoder_output, decoder_hidden = decoder(
                decoder_input, decoder_hidden, encoder_outputs,
                LongTensor([input_length]))
            # decoder_output, decoder_hidden = decoder(decoder_input, decoder_hidden)
            topv, topi = decoder_output.data.topk(1)
            if topi.item() == EOS_token:
                decoded_words.append('<EOS>')
                break
            else:
                decoded_words.append(output_lang.index2word[topi.item()])
            decoder_input = topi.squeeze().detach().view(1, 1)

        return decoded_words

def decode_result(self, decoder_inputs, init_states, memories,
                  target2index, index2target, max_length=50):
    start_decode = Variable(LongTensor([[target2index['<s>']]])).transpose(0, 1)
    decodes = []
    embedded = start_decode
    embedd_list = [target2index['<s>']]
    # while decoded.data.tolist()[0] != target2index['</s>'] and max_length > len(decodes):
    for t in range(max_length):
        _, hidden = self.decode(embedded, init_states, memories)
        softmaxed = F.log_softmax(hidden, dim=1)
        decodes.append(softmaxed)
        decoded = softmaxed.max(1)[1]
        embedd_list.append(decoded.data.tolist()[0])
        # feed the whole decoded prefix back in at the next step
        embedded = Variable(LongTensor([embedd_list]))
        if index2target[decoded.data.tolist()[0]] == '</s>' or (
                t != 0 and index2target[decoded.data.tolist()[0]] == '<s>'):
            break
        # context, alpha = self.Attention(hidden, decoder_inputs)
        # attentions.append(alpha.squeeze(1))
    return torch.cat(decodes).max(1)[1]

def decode(self, h, mask):  # Viterbi decoding
    # initialize backpointers and viterbi variables in log space
    bptr = LongTensor()
    score = Tensor(BATCH_SIZE, self.num_tags).fill_(-10000.)
    score[:, SOS_IDX] = 0.

    for t in range(h.size(1)):  # recursion through the sequence
        mask_t = mask[:, t].unsqueeze(1)
        score_t = score.unsqueeze(1) + self.trans  # [B, 1, C] -> [B, C, C]
        score_t, bptr_t = score_t.max(2)  # best previous scores and tags
        score_t += h[:, t]  # plus emission scores
        bptr = torch.cat((bptr, bptr_t.unsqueeze(1)), 1)
        score = score_t * mask_t + score * (1 - mask_t)
    score += self.trans[EOS_IDX]
    best_score, best_tag = torch.max(score, 1)

    # back-tracking
    bptr = bptr.tolist()
    best_path = [[i] for i in best_tag.tolist()]
    for b in range(BATCH_SIZE):
        x = best_tag[b]  # best tag
        y = int(mask[b].sum().item())
        for bptr_t in reversed(bptr[b][:y]):
            x = bptr_t[x]
            best_path[b].append(x)
        best_path[b].pop()
        best_path[b].reverse()
    return best_path

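# A minimal, self-contained sketch of the same Viterbi recursion on random
# scores, assuming 3 tags and no SOS/EOS or batching; names here are
# illustrative and not taken from the class above.
import torch

def viterbi_demo(emissions, transitions):
    # emissions: [seq_len, num_tags]; transitions[i, j]: score of moving from tag i to tag j
    seq_len, num_tags = emissions.shape
    score = emissions[0].clone()
    backpointers = []
    for t in range(1, seq_len):
        # score[i] + transitions[i, j]: best way to reach tag j at step t
        next_score, bptr = (score.unsqueeze(1) + transitions).max(0)
        score = next_score + emissions[t]
        backpointers.append(bptr)
    best_tag = score.argmax().item()
    path = [best_tag]
    for bptr in reversed(backpointers):
        best_tag = bptr[best_tag].item()
        path.append(best_tag)
    return list(reversed(path))

# toy usage:
# print(viterbi_demo(torch.randn(5, 3), torch.randn(3, 3)))
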
def batch_generator(*arrays, batch_size=32, should_shuffle=False):
    input, target = arrays
    if should_shuffle:
        from sklearn.utils import shuffle
        input, target = shuffle(input, target)
    num_instances = len(input)
    batch_count = int(numpy.ceil(num_instances / batch_size))
    progress = tqdm.tqdm(total=num_instances)
    input_length_in_words = numpy.array([len(seq) for seq in input], dtype=numpy.int32)
    target_length_in_words = numpy.array([len(seq) for seq in target], dtype=numpy.int32)
    for idx in range(batch_count):
        startIdx = idx * batch_size
        endIdx = (idx + 1) * batch_size if (idx + 1) * batch_size < num_instances else num_instances
        batch_input_lengths = input_length_in_words[startIdx:endIdx]
        input_maxlength = batch_input_lengths.max()
        input_lengths_argsort = numpy.argsort(batch_input_lengths)[::-1].copy()
        # without the copy torch complains about negative strides
        batch_target_lengths = target_length_in_words[startIdx:endIdx]
        target_maxlength = batch_target_lengths.max()
        batch_input = LongTensor([input_seq + (PAD_IDX,) * (input_maxlength - len(input_seq))
                                  for input_seq in input[startIdx:endIdx]])
        batch_target = LongTensor([target_seq + (PAD_IDX,) * (target_maxlength - len(target_seq))
                                   for target_seq in target[startIdx:endIdx]])
        progress.update(len(batch_input_lengths))
        yield batch_input[input_lengths_argsort], LongTensor(batch_input_lengths)[input_lengths_argsort], \
            batch_target[input_lengths_argsort], LongTensor(batch_target_lengths)[input_lengths_argsort]
    progress.close()

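# Minimal usage sketch for the generator above, assuming PAD_IDX and
# LongTensor are bound as below in the surrounding module; the id tuples
# are illustrative data, not from the original project.
import numpy
import torch
import tqdm

PAD_IDX = 0
LongTensor = torch.LongTensor

src = [(5, 6, 7), (8, 9), (4,)]
tgt = [(1, 2), (3,), (2, 2, 2)]
for batch_in, in_lens, batch_out, out_lens in batch_generator(src, tgt, batch_size=2):
    # each batch is padded to its own max length and sorted by decreasing
    # source length, ready for pack_padded_sequence
    print(batch_in.shape, in_lens, batch_out.shape, out_lens)
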
def batch_generator(*arrays, batch_size=32):
    word_id_lists, tag_id_lists, char_id_lists, seq_length_in_words, word_lengths = arrays
    word_id_lists, tag_id_lists, char_id_lists, seq_length_in_words, word_lengths = \
        shuffle(word_id_lists, tag_id_lists, char_id_lists, seq_length_in_words, word_lengths)
    num_instances = len(word_id_lists)
    batch_count = int(numpy.ceil(num_instances / batch_size))
    from tqdm import tqdm
    prog = tqdm(total=num_instances)
    for idx in range(batch_count):
        startIdx = idx * batch_size
        endIdx = (idx + 1) * batch_size if (idx + 1) * batch_size < num_instances else num_instances
        batch_lengths = seq_length_in_words[startIdx:endIdx]
        batch_maxlen = batch_lengths.max()
        argsort = numpy.argsort(batch_lengths)[::-1].copy()
        # without the copy torch complains about negative strides
        char_batch = numpy.array(char_id_lists[startIdx:endIdx])[argsort]
        # make each sentence in batch contain same number of words
        char_batch = [sentence + ((0,),) * (batch_maxlen - len(sentence))
                      for sentence in char_batch]
        word_lengths = [len(word) for sentence in char_batch for word in sentence]
        max_word_length = max(word_lengths)
        # make each word in batch contain same number of chars
        chars = [word + (0,) * (max_word_length - len(word))
                 for sentence in char_batch for word in sentence]
        chars = LongTensor(chars)
        words = LongTensor([word_ids + (PAD_IDX,) * (batch_maxlen - len(word_ids))
                            for word_ids in word_id_lists[startIdx:endIdx]])
        tags = LongTensor([tag_ids + (PAD_IDX,) * (batch_maxlen - len(tag_ids))
                           for tag_ids in tag_id_lists[startIdx:endIdx]])
        prog.update(len(batch_lengths))
        yield words[argsort], chars, tags[argsort]
    prog.close()

def detail_forward(self, incoming):
    i = incoming.state.num
    incoming.post = Storage()
    incoming.post.embedding = self.embLayer(LongTensor(incoming.data.post))
    incoming.resp = Storage()
    incoming.wiki = Storage()
    incoming.wiki.embedding = self.embLayer(incoming.data.wiki[:, i])
    incoming.resp.embLayer = self.embLayer

def decode(self, h, mask):  # Viterbi decoding
    # initialize backpointers and viterbi variables in log space
    backpointers = LongTensor()
    batch_size = h.shape[0]
    delta = Tensor(batch_size, self.num_tags).fill_(-10000.)
    delta[:, START_TAG_IDX] = 0.
    # TODO: is adding stop tag within loop needed at all???
    # pro argument: yes, backpointers needed at every step - to be checked
    for t in range(h.size(1)):  # iterate through the sequence
        # backpointers and viterbi variables at this timestep
        mask_t = mask[:, t].unsqueeze(1)
        # TODO: maybe unsqueeze transition explicitly for 0 dim for clarity
        next_tag_var = delta.unsqueeze(1) + self.transition  # B x 1 x S + S x S
        delta_t, backpointers_t = next_tag_var.max(2)
        backpointers = torch.cat((backpointers, backpointers_t.unsqueeze(1)), 1)
        delta_next = delta_t + h[:, t]  # plus emission scores
        delta = mask_t * delta_next + (1 - mask_t) * delta
        # TODO: check correctness
        # for those that end here add score for transitioning to stop tag
        if t + 1 < h.size(1):
            # mask_next = mask[:, t + 1].unsqueeze(1)
            # ending = mask_next.eq(0.).float().expand(batch_size, self.num_tags)
            # delta += ending * self.transition[STOP_TAG_IDX].unsqueeze(0)
            # or
            ending_here = (mask[:, t].eq(1.) * mask[:, t + 1].eq(0.)).view(1, -1).float()
            delta += ending_here.transpose(0, 1).mul(self.transition[STOP_TAG_IDX])  # add outer product of two vecs
            # TODO: check equality of these two again
    # TODO: should we add transition values for getting in stop state only for those that end here?
    # TODO: or to all?
    delta += mask[:, -1].view(1, -1).float().transpose(0, 1).mul(self.transition[STOP_TAG_IDX])
    best_score, best_tag = torch.max(delta, 1)

    # back-tracking
    backpointers = backpointers.tolist()
    best_path = [[i] for i in best_tag.tolist()]
    for idx in range(batch_size):
        prev_best_tag = best_tag[idx]  # best tag id for single instance
        length = int(scalar(mask[idx].sum()))  # length of instance
        for backpointers_t in reversed(backpointers[idx][:length]):
            prev_best_tag = backpointers_t[prev_best_tag]
            best_path[idx].append(prev_best_tag)
        best_path[idx].pop()  # remove start tag
        best_path[idx].reverse()
    return best_path

def nextStep(x, flag=None, regroup=None):
    nonlocal step, batch_size, top_k
    # regroup: batch * top_k
    regroup = regroup + LongTensor(list(range(batch_size))).unsqueeze(1) * top_k
    regroup = regroup.reshape(-1)
    x = x.reshape(batch_size * top_k, -1)
    x = step(x, regroup=regroup)
    x = x.reshape(batch_size, top_k, -1)
    return x

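# Standalone sketch of the regroup index arithmetic above: per-batch beam
# indices are turned into flat row indices over a (batch_size * top_k)
# tensor by adding a per-row offset. The values are illustrative.
import torch

batch_size, top_k = 2, 3
regroup = torch.tensor([[0, 0, 2],   # beam parents chosen for batch item 0
                        [1, 2, 2]])  # beam parents chosen for batch item 1
flat = (regroup + torch.arange(batch_size).unsqueeze(1) * top_k).reshape(-1)
print(flat)  # tensor([0, 0, 2, 4, 5, 5]): row-major offsets into the flattened beam
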
def positional_embeddings(seqlen, first_emb, reverse_emb=None, length=None):
    # first_emb: max_length * embedding
    encodings = first_emb.unsqueeze(1).expand(-1, seqlen, -1).\
        gather(0, cuda(torch.arange(seqlen)).unsqueeze(-1).expand(-1, first_emb.shape[1]).unsqueeze(0))[0]
    if length is None:
        assert reverse_emb is None
        return encodings.unsqueeze(0)
    else:
        batch_size = len(length)
        reversed_id = np.zeros((batch_size, seqlen))
        for i, l in enumerate(length):
            reversed_id[i, :l] = np.arange(l - 1, -1, -1)
        reversed_id = LongTensor(reversed_id)
        encodings_reversed = reverse_emb.unsqueeze(0).unsqueeze(2).expand(batch_size, -1, seqlen, -1).\
            gather(1, reversed_id.unsqueeze(1).unsqueeze(-1).expand(-1, -1, -1, reverse_emb.shape[-1]))[:, 0]
        return torch.cat([
            encodings.unsqueeze(0).expand(batch_size, -1, -1),
            encodings_reversed
        ], dim=2)

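# Toy illustration of the reversed position ids built above: within each
# sequence, positions count down from length - 1 to 0, and padding stays 0.
import numpy as np

seqlen, lengths = 5, [3, 5]
reversed_id = np.zeros((len(lengths), seqlen), dtype=np.int64)
for i, l in enumerate(lengths):
    reversed_id[i, :l] = np.arange(l - 1, -1, -1)
print(reversed_id)
# [[2 1 0 0 0]
#  [4 3 2 1 0]]
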
def test(self, key):
    args = self.param.args
    dm = self.param.volatile.dm

    metric1 = dm.get_teacher_forcing_metric()
    batch_num, batches = self.get_batches(dm, key)
    logging.info("eval teacher-forcing")
    for incoming in tqdm.tqdm(batches, total=batch_num):
        incoming.args = Storage()
        incoming.args.sampling_proba = 1.
        with torch.no_grad():
            self.net.forward(incoming)
            gen_log_prob = nn.functional.log_softmax(incoming.gen.w_pro, -1)
        data = incoming.data
        data.resp_allvocabs = LongTensor(incoming.data.resp_allvocabs)
        data.resp_length = incoming.data.resp_length
        data.gen_log_prob = gen_log_prob.transpose(1, 0)
        metric1.forward(data)
    res = metric1.close()

    metric2 = dm.get_inference_metric()
    batch_num, batches = self.get_batches(dm, key)
    logging.info("eval free-run")
    for incoming in tqdm.tqdm(batches, total=batch_num):
        incoming.args = Storage()
        with torch.no_grad():
            self.net.detail_forward(incoming)
        data = incoming.data
        data.gen = incoming.gen.w_o.detach().cpu().numpy().transpose(1, 0)
        metric2.forward(data)
    res.update(metric2.close())

    if not os.path.exists(args.out_dir):
        os.makedirs(args.out_dir)
    filename = args.out_dir + "/%s_%s.txt" % (args.name, key)
    with codecs.open(filename, 'w', encoding='utf8') as f:
        logging.info("%s Test Result:", key)
        for name, value in res.items():
            if isinstance(value, float) or isinstance(value, str):
                logging.info("\t{}:\t{}".format(name, value))
                f.write("{}:\t{}\n".format(name, value))
        for i in range(len(res['post'])):
            f.write("post:\t%s\n" % " ".join(res['post'][i]))
            f.write("resp:\t%s\n" % " ".join(res['resp'][i]))
            f.write("gen:\t%s\n" % " ".join(res['gen'][i]))
        f.flush()
    logging.info("result output to %s.", filename)

    return {name: val for name, val in res.items()
            if isinstance(val, (str, int, float))}

def sample_image(args, generator, n_row, batches_done):
    """Saves a grid of generated digits ranging from 0 to n_classes"""
    # Sample noise
    z = Variable(FloatTensor(np.random.normal(0, 1, (n_row ** 2, args.latent_dim))))
    # Get labels for the n rows
    labels = np.array([num for _ in range(n_row) for num in range(n_row)])
    labels = Variable(LongTensor(labels))
    gen_imgs = generator(z, labels)
    save_image(gen_imgs.data, "images/%d.png" % batches_done, nrow=n_row, normalize=True)

def train(encoder_input, input_lengths, target_tensor, target_lengths,
          encoder, decoder, encoder_optimizer, decoder_optimizer, criterion):
    encoder_optimizer.zero_grad()
    decoder_optimizer.zero_grad()
    loss = 0
    encoder_outputs, encoder_hidden = encoder(encoder_input, input_lengths)
    decoder_hidden = encoder_hidden
    decoder_input = LongTensor([SOS_token]).repeat(encoder_input.shape[0], 1)  # one for each instance in batch
    # use_teacher_forcing = True if random.random() < teacher_forcing_ratio else False
    use_teacher_forcing = False
    if use_teacher_forcing:  # TODO: adapt teacher forcing
        # Teacher forcing: feed the target as the next input
        for idx in range(target_lengths.shape[1]):
            decoder_output, decoder_hidden, decoder_attention = decoder(
                decoder_input, decoder_hidden, encoder_outputs)
            loss += criterion(decoder_output, target_tensor[idx])
            decoder_input = target_tensor[idx]  # teacher forcing
    else:
        # Without teacher forcing: use the decoder's own predictions as the next input
        max_target_length = target_lengths.max().item()
        target_lengths_copy = target_lengths.clone()
        for idx in range(max_target_length):
            decoder_output, decoder_hidden = decoder(decoder_input, decoder_hidden,
                                                     encoder_outputs, input_lengths)
            mask = target_lengths_copy > PAD_IDX
            target_lengths_copy -= 1
            masked_output = decoder_output * mask.unsqueeze(1).float()
            topv, topi = masked_output.topk(1)
            decoder_input = topi.squeeze().detach()  # detach from history as input
            loss += criterion(masked_output[mask], target_tensor[:, idx][mask])
            # or alternative below
            # for instance_idx, target_word in enumerate(target_tensor[:, idx]):
            #     if idx < target_lengths[instance_idx]:
            #         loss += criterion(masked_output[instance_idx].view(1, -1),
            #                           target_word.view(1))
    loss.backward()
    encoder_optimizer.step()
    decoder_optimizer.step()
    return loss.item() / target_lengths.sum().item()

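# Standalone sketch of the length masking used above. Assuming PAD_IDX == 0,
# the decremented-copy test is equivalent to target_lengths > idx: at step
# idx, only instances whose target has at least idx + 1 tokens contribute.
import torch
import torch.nn as nn

criterion = nn.NLLLoss()
target_lengths = torch.tensor([3, 1])
log_probs = torch.log_softmax(torch.randn(2, 5), dim=1)  # batch of 2, vocab of 5
targets = torch.tensor([4, 0])
idx = 1                        # second decoding step
mask = target_lengths > idx    # tensor([True, False])
loss = criterion(log_probs[mask], targets[mask])
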
def forward(self, incoming):
    '''
    inp: data
    output: post
    '''
    i = incoming.state.num
    incoming.post = Storage()
    incoming.post.embedding = self.drop(self.embLayer(LongTensor(incoming.data.post[:, i])))
    incoming.resp = Storage()
    incoming.resp.embedding = self.drop(self.embLayer(incoming.data.resp[:, i]))
    incoming.wiki = Storage()
    incoming.wiki.embedding = self.drop(self.embLayer(incoming.data.wiki[:, i]))
    incoming.resp.embLayer = self.embLayer

def __predict_sentence(self, src_batch):
    """
    Predict a single sentence.
    :param src_batch: the source sentence, as a list of tokens
    :return: the predicted target sentence as a string
    """
    hyp_batch = ''
    inputs = prepare_sequence(['<s>'] + src_batch + ['</s>'],
                              self.data_model.source2index).view(1, -1)
    start_decode = Variable(LongTensor([[self.data_model.target2index['<s>']] * inputs.size(1)]))
    show_preds = self.qrnn(inputs, [inputs.size(1)], start_decode)
    outputs = torch.max(show_preds, dim=1)[1].view(len(inputs), -1)
    for pred in outputs.data.tolist():
        for each_pred in pred:
            hyp_batch += self.data_model.index2target[each_pred]
    hyp_batch = hyp_batch.replace('<s>', '')
    hyp_batch = hyp_batch.replace('</s>', '')
    return hyp_batch

def score(self, h, y, mask):  # calculate the score of a given sequence
    batch_size = h.shape[0]
    score = Tensor(batch_size).fill_(0.)
    # TODO: maybe instead of unsqueezing following two separately do it after sum in line for score calculation
    # TODO: check if unsqueezing needed at all
    h = h.unsqueeze(3)
    transition = self.transition.unsqueeze(2)
    # add start tag to the beginning of every sequence
    y = torch.cat([LongTensor([START_TAG_IDX]).view(1, -1).expand(batch_size, 1), y], 1)
    # TODO: the loop can be vectorized, probably
    for t in range(h.size(1)):  # iterate through the sequence
        mask_t = mask[:, t]
        emission = torch.cat([h[i, t, y[i, t + 1]] for i in range(batch_size)])
        transition_t = torch.cat([transition[seq[t + 1], seq[t]] for seq in y])
        score += (emission + transition_t) * mask_t
    # get transitions from last tags to stop tag: use gather to get the last time step
    lengths = mask.sum(1).long()
    indices = lengths.unsqueeze(1)
    # we can safely use lengths as indices, because we prepended the start tag to y
    last_tags = y.gather(1, indices).squeeze()
    score += self.transition[STOP_TAG_IDX, last_tags]
    return score

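# Standalone sketch of the gold-sequence scoring idea above: the score of a
# tag sequence is the sum of its emission scores plus the sum of transition
# scores between consecutive tags. Start/stop handling and masking are
# omitted; transitions[to, from] follows the convention of the class above.
import torch

emissions = torch.randn(4, 3)    # seq_len x num_tags
transitions = torch.randn(3, 3)  # transitions[to, from]
tags = torch.tensor([0, 2, 1, 1])
emit = emissions[torch.arange(4), tags].sum()
trans = transitions[tags[1:], tags[:-1]].sum()
score = emit + trans
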
def detail_forward(self, incoming):
    incoming.hidden = hidden = Storage()
    # incoming.post.embedding: batch * sen_num * length * vec_dim
    # post_length: batch * sen_num
    raw_post = incoming.post.embedding
    raw_post_length = LongTensor(incoming.data.post_length)
    # indices of the non-empty sentences
    incoming.state.valid_sen = torch.sum(torch.nonzero(raw_post_length), 1)
    raw_reverse = torch.cumsum(torch.gt(raw_post_length, 0), 0) - 1
    incoming.state.reverse_valid_sen = raw_reverse * torch.ge(raw_reverse, 0).to(torch.long)
    valid_sen = incoming.state.valid_sen
    incoming.state.valid_num = valid_sen.shape[0]

    post = torch.index_select(raw_post, 0, valid_sen).transpose(0, 1)  # [length, valid_num, vec_dim]
    post_length = torch.index_select(raw_post_length, 0, valid_sen).cpu().numpy()  # [valid_num]
    hidden.h, hidden.h_n = self.postGRU.forward(post, post_length, need_h=True)
    hidden.length = post_length

def freerun(self, inp, gen, mode='max'):
    batch_size = inp.batch_size
    dm = self.param.volatile.dm

    first_emb = inp.embLayer(LongTensor([dm.go_id])).repeat(batch_size, 1)
    gen.w_pro = []
    gen.w_o = []
    gen.emb = []
    flag = zeros(batch_size).byte()
    EOSmet = []

    next_emb = first_emb
    gru_h = self.GRULayer.getInitialParameter(batch_size)[0]
    for _ in range(self.args.max_sen_length):
        now = next_emb
        gru_h = self.GRULayer.cell_forward(now, gru_h)
        w = self.wLinearLayer(gru_h)
        gen.w_pro.append(w)
        if mode == "max":
            w_o = torch.argmax(w[:, self.start_generate_id:], dim=1) + self.start_generate_id
            next_emb = inp.embLayer(w_o)
        elif mode == "gumbel":
            w_onehot, w_o = gumbel_max(w[:, self.start_generate_id:], 1, 1)
            w_o = w_o + self.start_generate_id
            next_emb = torch.sum(torch.unsqueeze(w_onehot, -1) * inp.embLayer.weight[2:], 1)
        gen.w_o.append(w_o)
        gen.emb.append(next_emb)

        EOSmet.append(flag)
        flag = flag | (w_o == dm.eos_id)
        if torch.sum(flag).detach().cpu().numpy() == batch_size:
            break

    EOSmet = 1 - torch.stack(EOSmet)
    gen.w_o = torch.stack(gen.w_o) * EOSmet.long()
    gen.emb = torch.stack(gen.emb) * EOSmet.float().unsqueeze(-1)
    gen.length = torch.sum(EOSmet, 0).detach().cpu().numpy()

def train(generator, discriminator, dataloader, args, cuda, adversarial_loss, auxiliary_loss):
    optimizer_G = torch.optim.Adam(generator.parameters(), lr=args.lr, betas=(args.b1, args.b2))
    optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=args.lr, betas=(args.b1, args.b2))

    for epoch in range(args.n_epochs):
        for i, (imgs, labels) in enumerate(dataloader):
            batch_size = imgs.shape[0]

            # Adversarial ground truths
            valid = Variable(FloatTensor(batch_size, 1).fill_(1.0), requires_grad=False)
            fake = Variable(FloatTensor(batch_size, 1).fill_(0.0), requires_grad=False)

            # Configure input
            real_imgs = Variable(imgs.type(FloatTensor))
            labels = Variable(labels.type(LongTensor))

            # -----------------
            #  Train Generator
            # -----------------
            optimizer_G.zero_grad()

            # Sample noise and labels as generator input
            z = Variable(FloatTensor(np.random.normal(0, 1, (batch_size, args.latent_dim))))
            gen_labels = Variable(LongTensor(np.random.randint(0, args.n_classes, batch_size)))

            # Generate a batch of images
            gen_imgs = generator(z, gen_labels)

            # Loss measures generator's ability to fool the discriminator
            validity, pred_label = discriminator(gen_imgs)
            g_loss = 0.5 * adversarial_loss(validity, valid) + auxiliary_loss(pred_label, gen_labels)

            g_loss.backward()
            optimizer_G.step()

            # ---------------------
            #  Train Discriminator
            # ---------------------
            optimizer_D.zero_grad()

            # Loss for real images
            real_pred, real_aux = discriminator(real_imgs)
            d_real_loss = (adversarial_loss(real_pred, valid) + auxiliary_loss(real_aux, labels)) / 2

            # Loss for fake images
            fake_pred, fake_aux = discriminator(gen_imgs.detach())
            d_fake_loss = (adversarial_loss(fake_pred, fake) + auxiliary_loss(fake_aux, gen_labels)) / 2

            # Total discriminator loss
            d_loss = (d_real_loss + d_fake_loss) / 2

            # Calculate discriminator accuracy
            pred = np.concatenate([real_aux.data.cpu().numpy(), fake_aux.data.cpu().numpy()], axis=0)
            gt = np.concatenate([labels.data.cpu().numpy(), gen_labels.data.cpu().numpy()], axis=0)
            d_acc = np.mean(np.argmax(pred, axis=1) == gt)

            d_loss.backward()
            optimizer_D.step()

            print("[Epoch %d/%d] [Batch %d/%d] [D loss: %f, acc: %d%%] [G loss: %f]"
                  % (epoch, args.n_epochs, i, len(dataloader),
                     d_loss.item(), 100 * d_acc, g_loss.item()))

            batches_done = epoch * len(dataloader) + i
            if batches_done % args.sample_interval == 0:
                sample_image(args, generator, n_row=10, batches_done=batches_done)

def freerun(self, inp, gen, mode='max'):
    batch_size = inp.batch_size
    dm = self.param.volatile.dm

    first_emb = inp.embLayer(LongTensor([dm.go_id])).repeat(batch_size, 1)
    gen.w_pro = []
    gen.w_o = []
    gen.emb = []
    flag = zeros(batch_size).byte()
    EOSmet = []

    inp.wiki_sen = inp.wiki_sen[:, :inp.wiki_hidden.shape[1]]
    # one-hot lookup table mapping each wiki position to its vocabulary id
    copyHead = zeros(1, inp.wiki_sen.shape[0], inp.wiki_hidden.shape[1],
                     self.param.volatile.dm.vocab_size).scatter_(
        3, torch.unsqueeze(torch.unsqueeze(inp.wiki_sen, 0), 3), 1)
    wikiState = torch.transpose(torch.tanh(self.wCopyLinear(inp.wiki_hidden)), 0, 1)

    next_emb = first_emb
    gru_h = inp.init_h
    gen.p = []
    wiki_cv = inp.wiki_cv  # valid_num * (2 * eh_size)
    for _ in range(self.args.max_sent_length):
        now = torch.cat([next_emb, wiki_cv], dim=-1)
        gru_h = self.GRULayer.cell_forward(now, gru_h)

        w = self.wLinearLayer(gru_h)
        w = torch.clamp(w, max=5.0)
        vocab_p = torch.exp(w)
        copyW = torch.exp(torch.clamp(torch.unsqueeze(
            (torch.sum(torch.unsqueeze(gru_h, 0) * wikiState, -1).transpose_(0, 1)),
            1), max=5.0))  # batch * 1 * wiki_len
        copy_p = torch.matmul(copyW, copyHead).squeeze()

        p = vocab_p + copy_p + 1e-10
        p = p / torch.unsqueeze(torch.sum(p, 1), 1)
        p = torch.clamp(p, 1e-10, 1.0)
        gen.p.append(p)

        if mode == "max":
            w_o = torch.argmax(p[:, self.start_generate_id:], dim=1) + self.start_generate_id
            next_emb = inp.embLayer(w_o)
        elif mode == "gumbel":
            w_onehot, w_o = gumbel_max(p[:, self.start_generate_id:], 1, 1)
            w_o = w_o + self.start_generate_id
            next_emb = torch.sum(torch.unsqueeze(w_onehot, -1) * inp.embLayer.weight[2:], 1)
        gen.w_o.append(w_o)
        gen.emb.append(next_emb)

        EOSmet.append(flag)
        flag = flag | (w_o == dm.eos_id).byte()
        if torch.sum(flag).detach().cpu().numpy() == batch_size:
            break

    EOSmet = 1 - torch.stack(EOSmet)
    gen.w_o = torch.stack(gen.w_o) * EOSmet.long()
    gen.emb = torch.stack(gen.emb) * EOSmet.float().unsqueeze(-1)
    gen.length = torch.sum(EOSmet, 0).detach().cpu().numpy()
    gen.h_n = gru_h

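# Toy sketch of the copy mixture above: the output distribution is the
# unnormalized vocabulary distribution plus probability mass scattered onto
# the ids that appear in the wiki sentence. This uses scatter_add_ as a
# compact stand-in for the one-hot copyHead matmul; the effect is the same.
import torch

vocab_size, wiki_len = 6, 3
vocab_p = torch.rand(1, vocab_size)
wiki_sen = torch.tensor([[2, 4, 4]])  # source token ids (duplicates accumulate)
copy_w = torch.rand(1, wiki_len)      # per-position copy weights
copy_p = torch.zeros(1, vocab_size).scatter_add_(1, wiki_sen, copy_w)
p = vocab_p + copy_p + 1e-10
p = p / p.sum(1, keepdim=True)        # renormalize, as in freerun above
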
def detail_forward_disentangle(self, incoming):
    incoming.conn = conn = Storage()
    index = incoming.state.num
    valid_sen = incoming.state.valid_sen

    valid_wiki_h_n1 = torch.index_select(incoming.wiki_hidden.h_n1, 1, valid_sen)
    # [wiki_sen_num, valid_num, 2 * eh_size]
    valid_wiki_sen = torch.index_select(incoming.wiki_sen, 0, valid_sen)
    valid_wiki_h1 = torch.index_select(incoming.wiki_hidden.h1, 1, valid_sen)
    atten_label = torch.index_select(incoming.data.atten[:, index], 0, valid_sen)  # valid_num
    valid_wiki_num = torch.index_select(
        LongTensor(incoming.data.wiki_num[:, index]), 0, valid_sen)  # valid_num
    reverse_valid_sen = incoming.state.reverse_valid_sen

    self.beta = torch.sum(valid_wiki_h_n1 * incoming.hidden.h_n, dim=2)
    self.beta = torch.t(self.beta)  # [valid_num, wiki_len]

    mask = torch.arange(self.beta.shape[1], device=self.beta.device).long().expand(
        self.beta.shape[0], self.beta.shape[1]).transpose(0, 1)  # [wiki_sen_num, valid_num]
    expand_wiki_num = valid_wiki_num.unsqueeze(0).expand_as(mask)  # [wiki_sen_num, valid_num]
    reverse_mask = (expand_wiki_num <= mask).float()  # [wiki_sen_num, valid_num]

    if index > 0:
        wiki_hidden = incoming.wiki_hidden
        wiki_num = incoming.data.wiki_num[:, index]  # [batch], numpy array
        wiki_hidden.h2, wiki_hidden.h_n2 = self.compareGRU.forward(
            wiki_hidden.h_n1, wiki_num, need_h=True)
        valid_wiki_h2 = torch.index_select(wiki_hidden.h2, 1, valid_sen)
        # wiki_len * valid_num * (2 * eh_size)

        tilde_wiki_list = []
        for i in range(self.last_wiki.size(-1)):
            last_wiki = torch.index_select(self.last_wiki[:, :, i], 0,
                                           valid_sen).unsqueeze(0)  # 1, valid_num, (2 * eh)
            tilde_wiki = torch.tanh(self.tilde_linear(torch.cat(
                [last_wiki - valid_wiki_h2, last_wiki * valid_wiki_h2], dim=-1)))
            tilde_wiki_list.append(tilde_wiki.unsqueeze(-1) * self.hist_weights[i])
        tilde_wiki = torch.cat(tilde_wiki_list, dim=-1).sum(dim=-1)
        # wiki_len * valid_num * (2 * eh_size)

        query = self.attn_query(tilde_wiki)  # [1, valid_num, hidden]
        key = self.attn_key(torch.cat([valid_wiki_h2, tilde_wiki], dim=-1))
        # [wiki_sen_num, valid_num, hidden]
        atten_sum = self.attn_v(torch.tanh(query + key)).squeeze(-1)  # [wiki_sen_num, valid_num]

        self.beta = self.beta[:, :atten_sum.shape[0]] + torch.t(atten_sum)

    if index == 0:
        incoming.result.atten_loss = self.atten_lossCE(
            self.beta,  # self.alpha.t().log(),
            atten_label)
    else:
        incoming.result.atten_loss += self.atten_lossCE(
            self.beta,  # self.alpha.t().log(),
            atten_label)

    self.beta = torch.t(self.beta) - 1e10 * reverse_mask[:self.beta.shape[1]]
    self.alpha = self.wiki_atten(self.beta)  # wiki_len * valid_num
    incoming.acc.prob.append(torch.index_select(
        self.alpha.t(), 0, incoming.state.reverse_valid_sen).cpu().tolist())
    atten_indices = torch.argmax(self.alpha, 0)  # valid_num
    alpha = zeros(self.beta.t().shape).scatter_(1, atten_indices.unsqueeze(1), 1)
    alpha = torch.t(alpha)
    wiki_cv = torch.sum(valid_wiki_h_n1[:alpha.shape[0]] * alpha.unsqueeze(2), dim=0)
    # valid_num * (2 * eh_size)
    conn.wiki_cv = wiki_cv
    conn.init_h = self.initLinearLayer(torch.cat([incoming.hidden.h_n, wiki_cv], 1))

    if index == 0:
        self.last_wiki = torch.index_select(wiki_cv, 0,
                                            reverse_valid_sen).unsqueeze(-1)  # [batch, 2 * eh_size]
    else:
        self.last_wiki = torch.cat([
            torch.index_select(wiki_cv, 0, reverse_valid_sen).unsqueeze(-1),
            self.last_wiki[:, :, :self.hist_len - 1]
        ], dim=-1)

    incoming.acc.label.append(torch.index_select(atten_label, 0, reverse_valid_sen).cpu().tolist())
    incoming.acc.pred.append(torch.index_select(atten_indices, 0, reverse_valid_sen).cpu().tolist())

    atten_indices = atten_indices.unsqueeze(1)
    atten_indices = torch.cat([
        torch.arange(atten_indices.shape[0]).unsqueeze(1),
        atten_indices.cpu()
    ], 1)  # valid_num * 2
    valid_wiki_h1 = torch.transpose(valid_wiki_h1, 0, 1)
    # valid_num * wiki_sen_len * wiki_len * (2 * eh_size)
    valid_wiki_h1 = torch.transpose(valid_wiki_h1, 1, 2)
    # valid_num * wiki_len * wiki_sen_len * (2 * eh_size)
    conn.selected_wiki_h = valid_wiki_h1[atten_indices.chunk(2, 1)].squeeze(1)
    # valid_num * wiki_sen_len * (2 * eh_size)
    conn.selected_wiki_sen = valid_wiki_sen[atten_indices.chunk(2, 1)].squeeze(1)
    # valid_num * wiki_sen_len

def forward(self, incoming):
    incoming.conn = conn = Storage()
    index = incoming.state.num
    valid_sen = incoming.state.valid_sen

    valid_wiki_h_n1 = torch.index_select(incoming.wiki_hidden.h_n1, 1, valid_sen)
    # [wiki_sen_num, valid_num, 2 * eh_size]
    valid_wiki_sen = torch.index_select(incoming.wiki_sen, 0, valid_sen)
    # [valid_num, wiki_sen_num, wiki_sen_len]
    valid_wiki_h1 = torch.index_select(incoming.wiki_hidden.h1, 1, valid_sen)
    # [wiki_sen_len, valid_num, wiki_sen_num, 2 * eh_size]
    atten_label = torch.index_select(incoming.data.atten[:, index], 0, valid_sen)  # valid_num
    valid_wiki_num = torch.index_select(
        LongTensor(incoming.data.wiki_num[:, index]), 0, valid_sen)  # valid_num

    if index == 0:
        tilde_wiki = zeros(1, 1, 2 * self.args.eh_size) * ones(
            valid_wiki_h_n1.shape[0], valid_wiki_h_n1.shape[1], 1)
    else:
        wiki_hidden = incoming.wiki_hidden
        wiki_num = incoming.data.wiki_num[:, index]  # [batch], numpy array
        wiki_hidden.h2, wiki_hidden.h_n2 = self.compareGRU.forward(
            wiki_hidden.h_n1, wiki_num, need_h=True)
        valid_wiki_h2 = torch.index_select(wiki_hidden.h2, 1, valid_sen)
        # wiki_len * valid_num * (2 * eh_size)

        tilde_wiki_list = []
        for i in range(self.last_wiki.size(-1)):
            last_wiki = torch.index_select(self.last_wiki[:, :, i], 0,
                                           valid_sen).unsqueeze(0)  # 1, valid_num, (2 * eh)
            tilde_wiki = torch.tanh(self.tilde_linear(torch.cat(
                [last_wiki - valid_wiki_h2, last_wiki * valid_wiki_h2], dim=-1)))
            tilde_wiki_list.append(tilde_wiki.unsqueeze(-1) * self.hist_weights[i])
        tilde_wiki = torch.cat(tilde_wiki_list, dim=-1).sum(dim=-1)

    query = self.attn_query(incoming.hidden.h_n)  # [valid_num, hidden]
    key = self.attn_key(torch.cat([valid_wiki_h_n1[:tilde_wiki.shape[0]], tilde_wiki], dim=-1))
    # [wiki_sen_num, valid_num, hidden]
    atten_sum = self.attn_v(torch.tanh(query + key)).squeeze(-1)  # [wiki_sen_num, valid_num]
    beta = atten_sum.t()  # [valid_num, wiki_len]

    mask = torch.arange(beta.shape[1], device=beta.device).long().expand(
        beta.shape[0], beta.shape[1]).transpose(0, 1)  # [wiki_sen_num, valid_num]
    expand_wiki_num = valid_wiki_num.unsqueeze(0).expand_as(mask)  # [wiki_sen_num, valid_num]
    reverse_mask = (expand_wiki_num <= mask).float()  # [wiki_sen_num, valid_num]

    if index == 0:
        incoming.result.atten_loss = self.atten_lossCE(beta, atten_label)
    else:
        incoming.result.atten_loss += self.atten_lossCE(beta, atten_label)

    golden_alpha = zeros(beta.shape).scatter_(1, atten_label.unsqueeze(1), 1)
    golden_alpha = torch.t(golden_alpha).unsqueeze(2)
    wiki_cv = torch.sum(valid_wiki_h_n1[:golden_alpha.shape[0]] * golden_alpha, dim=0)
    # valid_num * (2 * eh_size)
    conn.wiki_cv = wiki_cv
    conn.init_h = self.initLinearLayer(torch.cat([incoming.hidden.h_n, wiki_cv], 1))

    reverse_valid_sen = incoming.state.reverse_valid_sen
    if index == 0:
        self.last_wiki = torch.index_select(wiki_cv, 0,
                                            reverse_valid_sen).unsqueeze(-1)  # [batch, 2 * eh_size]
    else:
        self.last_wiki = torch.cat([
            torch.index_select(wiki_cv, 0, reverse_valid_sen).unsqueeze(-1),
            self.last_wiki[:, :, :self.hist_len - 1]
        ], dim=-1)

    atten_indices = atten_label.unsqueeze(1)  # valid_num * 1
    atten_indices = torch.cat([
        torch.arange(atten_indices.shape[0]).unsqueeze(1),
        atten_indices.cpu()
    ], 1)  # valid_num * 2
    valid_wiki_h1 = torch.transpose(valid_wiki_h1, 0, 1)
    # valid_num * wiki_sen_len * wiki_len * (2 * eh_size)
    valid_wiki_h1 = torch.transpose(valid_wiki_h1, 1, 2)
    # valid_num * wiki_len * wiki_sen_len * (2 * eh_size)
    conn.selected_wiki_h = valid_wiki_h1[atten_indices.chunk(2, 1)].squeeze(1)
    conn.selected_wiki_sen = valid_wiki_sen[atten_indices.chunk(2, 1)].squeeze(1)

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("-t", "--train_data", metavar="train_data", type=str,
                        default='../data/processed/source_replay_twitter_data.txt',
                        dest="train_data", help="set the training data")
    parser.add_argument("-e", "--embedding_size", metavar="embedding_size", type=int,
                        default=50, dest="embedding_size", help="set the embedding size")
    parser.add_argument("-H", "--hidden_size", metavar="hidden_size", type=int,
                        default=512, dest="hidden_size", help="set the hidden size")
    parser.add_argument("-f", "--fine_tune_model_name", metavar="fine_tune_model_name", type=str,
                        default='../models/glove_model_40.pth', dest="fine_tune_model_name",
                        help="set the fine-tune model name")
    parser.add_argument("-n", "--num_layers", metavar="num_layers", type=int,
                        default=2, dest="num_layers", help="set the number of layers")
    parser.add_argument("-k", "--kernel_size", metavar="kernel_size", type=int,
                        default=2, dest="kernel_size", help="set the kernel size")
    batch_size = 64
    args = parser.parse_args()

    test_data_loader_attention = DataLoaderAttention(file_name=args.train_data)
    source2index, index2source, target2index, index2target, train_data = \
        test_data_loader_attention.load_data()
    encoder_model_name = '../models/qrnn_encoder_model_285.pth'
    decoder_model_name = '../models/qrnn_decoder_model_285.pth'
    proj_linear_model_name = '../models/qrnn_proj_linear_model_285.pth'

    HIDDEN_SIZE = args.hidden_size
    NUM_LAYERS = args.num_layers
    KERNEL_SIZE = args.kernel_size
    EMBEDDING_SIZE = args.embedding_size
    SOURCE_VOCAB_SIZE = len(source2index)
    TARGET_VOCAB_SIZE = len(target2index)
    ZONE_OUT = 0.0
    TRAINING = False
    DROPOUT = 0.0

    qrnn = QRNNModel(QRNNLayer, NUM_LAYERS, KERNEL_SIZE, HIDDEN_SIZE,
                     EMBEDDING_SIZE, SOURCE_VOCAB_SIZE, TARGET_VOCAB_SIZE,
                     ZONE_OUT, TRAINING, DROPOUT)
    qrnn.encoder = torch.load(encoder_model_name)
    qrnn.decoder = torch.load(decoder_model_name)
    qrnn.proj_linear = torch.load(proj_linear_model_name)

    test = random.choice(train_data)
    inputs = test[0]
    truth = test[1]
    print(inputs)
    print(truth)
    start_decode = Variable(LongTensor([[target2index['<s>']] * truth.size(1)]))
    show_preds = qrnn(inputs, [inputs.size(1)], start_decode)
    outputs = torch.max(show_preds, dim=1)[1].view(len(inputs), -1)
    show_sentence(truth, inputs, outputs.data.tolist(), index2source, index2target)

def train_attention(self, train_data: list=[], source2index: list=[],
                    target2index: list=[], index2source: list=[],
                    index2target: list=[], encoder_model: object=None,
                    decoder_model: object=None):
    encoder_model.init_weight()
    decoder_model.init_weight()
    encoder_model, decoder_model = self.__fine_tune_weight(encoder_model=encoder_model,
                                                           decoder_model=decoder_model)
    if USE_CUDA:
        encoder_model = encoder_model.cuda()
        decoder_model = decoder_model.cuda()

    loss_function = nn.CrossEntropyLoss(ignore_index=0)
    encoder_optimizer = optim.Adam(encoder_model.parameters(), lr=self.lr)
    decoder_optimizer = optim.Adam(decoder_model.parameters(),
                                   lr=self.lr * self.decoder_learning_rate)

    for epoch in range(self.epoch):
        losses = []
        for i, batch in enumerate(get_batch(self.batch_size, train_data)):
            inputs, targets, input_lengths, target_lengths = \
                pad_to_batch(batch, source2index, target2index)
            input_mask = torch.cat([Variable(ByteTensor(
                tuple(map(lambda s: s == 0, t.data)))) for t in inputs]).view(inputs.size(0), -1)
            start_decode = Variable(LongTensor([[target2index['<s>']] *
                                                targets.size(0)])).transpose(0, 1)

            encoder_model.zero_grad()
            decoder_model.zero_grad()
            output, hidden_c = encoder_model(inputs, input_lengths)
            preds = decoder_model(start_decode, hidden_c, targets.size(1),
                                  output, input_mask, True)
            loss = loss_function(preds, targets.view(-1))
            losses.append(loss.data.tolist()[0])
            loss.backward()
            torch.nn.utils.clip_grad_norm(encoder_model.parameters(), 50.0)
            torch.nn.utils.clip_grad_norm(decoder_model.parameters(), 50.0)
            encoder_optimizer.step()
            decoder_optimizer.step()

            if i % 200 == 0:
                test = random.choice(train_data)
                inputs = test[0]
                output_c, hidden = encoder_model(inputs, [inputs.size(1)])
                show_preds, _ = decoder_model.decode(hidden, output_c,
                                                     target2index, index2target)
                show_preds = decoder_model(start_decode, hidden_c, targets.size(1),
                                           output, input_mask, True)
                outputs = torch.max(show_preds, dim=1)[1].view(len(inputs), -1)
                show_sentence(targets, inputs, outputs.data.tolist(),
                              index2source, index2target)
                print("[%02d/%d] [%03d/%d] mean_loss : %0.2f"
                      % (epoch, self.epoch, i, len(train_data) // self.batch_size,
                         np.mean(losses)))
                self.__save_model_info(inputs, epoch, losses)
                torch.save(encoder_model, './../models/encoder_model_{0}.pth'.format(epoch))
                torch.save(decoder_model, './../models/decoder_model_{0}.pth'.format(epoch))
                losses = []

        if self.rescheduled is False and epoch == self.epoch // 2:
            self.lr = self.lr * 0.01
            encoder_optimizer = optim.Adam(encoder_model.parameters(), lr=self.lr)
            decoder_optimizer = optim.Adam(decoder_model.parameters(),
                                           lr=self.lr * self.decoder_learning_rate)
            self.rescheduled = True

    self.writer.export_scalars_to_json("./all_scalars.json")
    self.writer.close()

def train_qrnn(self, train_data: list=[], source2index: list=[],
               target2index: list=[], index2source: list=[],
               index2target: list=[], qrnn_model: object=None):
    # qrnn_model.encoder, qrnn_model.decoder = self.__fine_tune_weight(
    #     encoder_model=qrnn_model.encoder,
    #     decoder_model=qrnn_model.decoder)
    if USE_CUDA:
        qrnn_model = qrnn_model.cuda()
        encoder_model = qrnn_model.encoder.cuda()
        decoder_model = qrnn_model.decoder.cuda()
        # proj_linear_model = qrnn_model.proj_linear.cuda()

    loss_function = nn.CrossEntropyLoss(ignore_index=0)
    # qrnn_optimizer = optim.Adam(qrnn_model.parameters(), lr=self.lr)
    encoder_optimizer = optim.Adam(encoder_model.parameters(), lr=self.lr)
    decoder_optimizer = optim.Adam(decoder_model.parameters(), lr=self.lr)
    # proj_linear_optimizer = optim.Adam(proj_linear_model.parameters(), lr=self.lr)

    for epoch in range(self.epoch):
        losses = []
        for i, batch in enumerate(get_batch(self.batch_size, train_data)):
            inputs, targets, input_lengths, target_lengths = \
                pad_to_batch(batch, source2index, target2index)
            qrnn_model.zero_grad()
            start_decode = Variable(LongTensor([[target2index['<s>']] * targets.size(1)]))
            preds = qrnn_model(inputs, input_lengths, start_decode)
            loss = loss_function(preds, targets.view(-1))
            losses.append(loss.data.tolist()[0])
            loss.backward()
            torch.nn.utils.clip_grad_norm(qrnn_model.parameters(), 50.0)
            # qrnn_optimizer.step()
            encoder_optimizer.step()
            decoder_optimizer.step()
            # proj_linear_optimizer.step()

            if i % 200 == 0:
                test = random.choice(train_data)
                show_inputs = test[0]
                show_targets = test[1]
                show_preds = qrnn_model(inputs, [inputs.size(1)], start_decode)
                outputs = torch.max(show_preds, dim=1)[1].view(len(inputs), -1)
                show_sentence(show_targets, show_inputs, outputs.data.tolist(),
                              index2source, index2target)
                print("[%02d/%d] [%03d/%d] mean_loss : %0.2f"
                      % (epoch, self.epoch, i, len(train_data) // self.batch_size,
                         np.mean(losses)))
                self.__save_model_info(inputs, epoch, losses)
                torch.save(qrnn_model.encoder,
                           './../models/test_qrnn_encoder_model_{0}.pth'.format(epoch))
                torch.save(qrnn_model.decoder,
                           './../models/test_qrnn_decoder_model_{0}.pth'.format(epoch))
                torch.save(qrnn_model.proj_linear,
                           './../models/test_qrnn_proj_linear_model_{0}.pth'.format(epoch))
                losses = []

        if self.rescheduled is False and epoch == self.epoch // 2:
            self.lr = self.lr * 0.01
            # qrnn_optimizer = optim.Adam(qrnn_model.parameters(), lr=self.lr)
            encoder_optimizer = optim.Adam(encoder_model.parameters(), lr=self.lr)
            decoder_optimizer = optim.Adam(decoder_model.parameters(),
                                           lr=self.lr * self.decoder_learning_rate)
            # proj_linear_optimizer = optim.Adam(proj_linear_model.parameters(), lr=self.lr)
            self.rescheduled = True

    self.writer.export_scalars_to_json("./all_scalars.json")
    self.writer.close()

def forward(self, query, key, value, mask=None, tau=1):
    dot_products = (query.unsqueeze(2) * key.unsqueeze(1)).sum(-1)  # batch x query_len x key_len

    if self.relative_clip:
        dot_relative = torch.einsum("ijk,tk->ijt", query, self.key_relative.weight)
        # batch * query_len * relative_size
        batch_size, query_len, key_len = dot_products.shape
        diag_dim = max(query_len, key_len)
        if self.diag_id.shape[0] < diag_dim:
            # cache the clipped relative-distance ids
            self.diag_id = np.zeros((diag_dim, diag_dim))
            for i in range(diag_dim):
                for j in range(diag_dim):
                    if i <= j - self.relative_clip:
                        self.diag_id[i, j] = 0
                    elif i >= j + self.relative_clip:
                        self.diag_id[i, j] = self.relative_clip * 2
                    else:
                        self.diag_id[i, j] = i - j + self.relative_clip
        diag_id = LongTensor(self.diag_id[:query_len, :key_len])
        dot_relative = reshape(dot_relative, "bld", "bl_d", key_len).gather(
            -1, reshape(diag_id, "lm", "_lm_", batch_size, -1))[:, :, :, 0]
        # batch * query_len * key_len
        dot_products = dot_products + dot_relative

    if self.attend_mode == "only_attend_front":
        assert query.shape[1] == key.shape[1]
        tri = cuda(torch.ones(key.shape[1], key.shape[1]).triu(1), device=query) * 1e9
        dot_products = dot_products - tri.unsqueeze(0)
    elif self.attend_mode == "only_attend_back":
        assert query.shape[1] == key.shape[1]
        tri = cuda(torch.ones(key.shape[1], key.shape[1]).tril(-1), device=query) * 1e9
        dot_products = dot_products - tri.unsqueeze(0)
    elif self.attend_mode == "not_attend_self":
        assert query.shape[1] == key.shape[1]
        eye = cuda(torch.eye(key.shape[1]), device=query) * 1e9
        dot_products = dot_products - eye.unsqueeze(0)

    if self.window > 0:
        assert query.shape[1] == key.shape[1]
        # mask everything more than `window` steps away on either side
        window_mask = cuda(torch.ones(key.shape[1], key.shape[1]), device=query)
        window_mask = (window_mask.triu(self.window + 1) +
                       window_mask.tril(-(self.window + 1))) * 1e9
        dot_products = dot_products - window_mask.unsqueeze(0)

    if mask is not None:
        dot_products -= (1 - mask) * 1e9

    logits = dot_products / self.scale
    if self.gumbel_attend and self.training:
        probs = gumbel_softmax(logits, tau, dim=-1)
    else:
        probs = torch.softmax(logits, dim=-1)
    # zero out rows where every position was masked
    probs = probs * ((dot_products <= -5e8).sum(-1, keepdim=True) <
                     dot_products.shape[-1]).float()  # batch_size * query_len * key_len
    probs = self.dropout(probs)
    res = torch.matmul(probs, value)  # batch_size * query_len * d_value

    if self.relative_clip:
        if self.recover_id.shape[0] < query_len:
            self.recover_id = np.zeros((query_len, self.relative_size))
            for i in range(query_len):
                for j in range(self.relative_size):
                    self.recover_id[i, j] = i + j - self.relative_clip
        recover_id = LongTensor(self.recover_id[:key_len])
        # out-of-range positions point at the extra zero column appended below
        recover_id[recover_id < 0] = key_len
        recover_id[recover_id >= key_len] = key_len
        probs = torch.cat([probs, zeros(batch_size, query_len, 1)], -1)
        relative_probs = probs.gather(-1, reshape(recover_id, "qr", "_qr", batch_size))
        # batch_size * query_len * relative_size
        res = res + torch.einsum("bqr,rd->bqd", relative_probs, self.value_relative.weight)
        # batch_size * query_len * d_value
    return res

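# Toy illustration of the clipped relative-distance ids built above
# (relative_clip = 2): entry [i, j] maps the distance i - j into the range
# [0, 2 * relative_clip], clipping everything outside the window.
import numpy as np

relative_clip = 2
n = 5
diag_id = np.zeros((n, n), dtype=np.int64)
for i in range(n):
    for j in range(n):
        if i <= j - relative_clip:
            diag_id[i, j] = 0
        elif i >= j + relative_clip:
            diag_id[i, j] = 2 * relative_clip
        else:
            diag_id[i, j] = i - j + relative_clip
print(diag_id)
# [[2 1 0 0 0]
#  [3 2 1 0 0]
#  [4 3 2 1 0]
#  [4 4 3 2 1]
#  [4 4 4 3 2]]
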
def forward(self, inp, wLinearLayerCallback, h_init=None, mode='max',
            input_callback=None, no_unk=True, top_k=10):
    """
    inp contains: batch_size, dm, embLayer, embedding, sampling_proba,
    max_sent_length, post, post_length, resp_length, [init_h]

    input_callback(i, embedding): override this function if you want to
        change the word embedding at position i
    nextStep(embedding, flag): pass the embedding to the RNN and get gru_h;
        flag[i] == 1 indicates the i-th sentence has ended
    wLinearLayerCallback(gru_h): take gru_h and return a probability
        distribution over the vocabulary

    output: w_o, emb, length
    """
    nextStep, h_now, context = self.init_forward_all(
        inp.batch_size, inp.post, inp.post_length, h_init=inp.get("init_h", None))
    gen = Storage()
    gen.w_pro = []
    batch_size = inp.embedding.shape[1]
    seqlen = inp.embedding.shape[0]
    length = inp.resp_length - 1
    start_id = inp.dm.go_id if no_unk else 0
    attn_weights = []

    first_emb = inp.embLayer(LongTensor([inp.dm.go_id])).repeat(inp.batch_size, 1)
    next_emb = first_emb
    if input_callback:
        inp.embedding = input_callback(inp.embedding)

    for i in range(seqlen):
        proba = random()
        # Sampling
        if proba < inp.sampling_proba:
            now = next_emb
            if input_callback:
                now = input_callback(now)
        # Teacher Forcing
        else:
            now = inp.embedding[i]

        if self.gru_input_attn:
            h_now = self.cell_forward(torch.cat([now, context], dim=-1), h_now) \
                * Tensor((length > np.ones(batch_size) * i).astype(float)).unsqueeze(-1)
        else:
            h_now = self.cell_forward(now, h_now) \
                * Tensor((length > np.ones(batch_size) * i).astype(float)).unsqueeze(-1)

        query = self.attn_query(h_now)
        attn_weight = maskedSoftmax((query.unsqueeze(0) * inp.post).sum(-1), inp.post_length)
        context = (attn_weight.unsqueeze(-1) * inp.post).sum(0)
        gru_h = torch.cat([h_now, context], dim=-1)
        attn_weights.append(attn_weight)

        w = wLinearLayerCallback(gru_h)
        gen.w_pro.append(w)

        # Decoding
        if mode == "max":
            w = torch.argmax(w[:, start_id:], dim=1) + start_id
            next_emb = inp.embLayer(w)
        elif mode == "gumbel" or mode == "sample":
            w_onehot = gumbel_max(w[:, start_id:])
            w = torch.argmax(w_onehot, dim=1) + start_id
            next_emb = torch.sum(torch.unsqueeze(w_onehot, -1) * inp.embLayer.weight[start_id:], 1)
        elif mode == "samplek":
            _, index = w[:, start_id:].topk(top_k, dim=-1, largest=True, sorted=True)  # batch_size, top_k
            mask = torch.zeros_like(w[:, start_id:]).scatter_(-1, index, 1.0)
            w_onehot = gumbel_max_with_mask(w[:, start_id:], mask)
            w = torch.argmax(w_onehot, dim=1) + start_id
            next_emb = torch.sum(torch.unsqueeze(w_onehot, -1) * inp.embLayer.weight[start_id:], 1)
        else:
            raise AttributeError("The given mode {} is not recognized.".format(mode))

    gen.w_pro = torch.stack(gen.w_pro, dim=0)
    return gen

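# Standalone sketch of the "samplek" branch above: restrict sampling to the
# top_k logits by masking, then pick via Gumbel noise. The explicit Gumbel
# noise here is a stand-in for the gumbel_max_with_mask helper, which is
# assumed by the function above and not shown in this section.
import torch

logits = torch.randn(2, 10)  # batch of 2, vocab of 10
top_k = 3
_, index = logits.topk(top_k, dim=-1)
mask = torch.zeros_like(logits).scatter_(-1, index, 1.0)
gumbel = -torch.log(-torch.log(torch.rand_like(logits)))
w = torch.argmax(logits + gumbel - 1e9 * (1 - mask), dim=-1)  # sampled ids within the top-k
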