def forward(self, inputs, lengths, hidden=None):
    # sort the batch by length, pack, run the RNN, then restore the order
    lens, indices = torch.sort(inputs.data.new(lengths).long(), 0, True)
    inputs = inputs[indices] if self.batch_first else inputs[:, indices]
    outputs, (h, c) = self.rnn(pack(inputs, lens.tolist(),
                                    batch_first=self.batch_first), hidden)
    outputs = unpack(outputs, batch_first=self.batch_first)[0]
    _, _indices = torch.sort(indices, 0)
    outputs = outputs[_indices] if self.batch_first else outputs[:, _indices]
    # note: the cell state must be un-sorted as well, not the hidden state twice
    h, c = h[:, _indices, :], c[:, _indices, :]
    return outputs, (h, c)
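# A minimal, self-contained sketch (not from any of the original repos) of the
# sort -> pack -> run -> unpack -> unsort idiom used above and in several of the
# snippets below. The `pack`/`unpack` aliases assumed throughout these snippets
# are torch.nn.utils.rnn.pack_padded_sequence / pad_packed_sequence; the LSTM
# and data here are illustrative.
import torch
from torch.nn.utils.rnn import pack_padded_sequence as pack, pad_packed_sequence as unpack

rnn = torch.nn.LSTM(input_size=8, hidden_size=16, batch_first=True)
inputs = torch.randn(3, 5, 8)          # (batch, max_len, dim), zero-padded
lengths = torch.tensor([5, 3, 4])      # true lengths, unsorted

lens, indices = lengths.sort(descending=True)        # sort batch by length
packed = pack(inputs[indices], lens.tolist(), batch_first=True)
outputs, (h, c) = rnn(packed)
outputs = unpack(outputs, batch_first=True)[0]       # back to (batch, max_len, hidden)

_, rev = indices.sort()                              # inverse permutation
outputs, h, c = outputs[rev], h[:, rev], c[:, rev]   # h/c are (layers, batch, hidden)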
def forward(self, enc_input, hidden=None):
    if isinstance(enc_input, tuple):
        # Lengths data is wrapped inside a Variable.
        lengths = enc_input[1].data.view(-1).tolist()
        emb = pack(self.embedding(enc_input[0]), lengths)
    else:
        emb = self.embedding(enc_input)
    outputs, hidden_t = self.rnn(emb, hidden)
    if isinstance(enc_input, tuple):
        outputs = unpack(outputs)[0]
    return outputs, hidden_t
def forward(self, src, lengths=None):
    """See :func:`onmt.encoders.encoder.EncoderBase.forward()`"""
    batch_size, _, nfft, t = src.size()
    src = src.transpose(0, 1).transpose(0, 3).contiguous() \
             .view(t, batch_size, nfft)
    orig_lengths = lengths
    lengths = lengths.view(-1).tolist()
    for l in range(self.enc_layers):
        rnn = getattr(self, 'rnn_%d' % l)
        pool = getattr(self, 'pool_%d' % l)
        batchnorm = getattr(self, 'batchnorm_%d' % l)
        stride = self.enc_pooling[l]
        packed_emb = pack(src, lengths)
        memory_bank, tmp = rnn(packed_emb)
        memory_bank = unpack(memory_bank)[0]
        t, _, _ = memory_bank.size()
        memory_bank = memory_bank.transpose(0, 2)
        memory_bank = pool(memory_bank)
        lengths = [int(math.floor((length - stride) / stride + 1))
                   for length in lengths]
        memory_bank = memory_bank.transpose(0, 2)
        src = memory_bank
        t, _, num_feat = src.size()
        src = batchnorm(src.contiguous().view(-1, num_feat))
        src = src.view(t, -1, num_feat)
        if self.dropout and l + 1 != self.enc_layers:
            src = self.dropout(src)
    memory_bank = memory_bank.contiguous().view(-1, memory_bank.size(2))
    memory_bank = self.W(memory_bank).view(-1, batch_size, self.dec_rnn_size)
    state = memory_bank.new_full((self.dec_layers * self.num_directions,
                                  batch_size, self.dec_rnn_size_real), 0)
    if self.rnn_type == 'LSTM':
        # The encoder hidden is (layers*directions) x batch x dim.
        encoder_final = (state, state)
    else:
        encoder_final = state
    return encoder_final, memory_bank, orig_lengths.new_tensor(lengths)
def forward(self, src, lengths=None, encoder_state=None):
    "See :obj:`EncoderBase.forward()`"
    self._check_args(src, lengths, encoder_state)
    emb = self.embeddings(src)
    s_len, batch, emb_dim = emb.size()
    packed_emb = emb
    if lengths is not None and not self.no_pack_padded_seq:
        # Lengths data is wrapped inside a Variable.
        lengths = lengths.view(-1).tolist()
        packed_emb = pack(emb, lengths)
    memory_bank, encoder_final = self.rnn(packed_emb, encoder_state)
    if lengths is not None and not self.no_pack_padded_seq:
        memory_bank = unpack(memory_bank)[0]
    if self.use_bridge:
        encoder_final = self._bridge(encoder_final)
    return encoder_final, memory_bank
def forward(self, docs_input, titles_input, keywords=None, topics=None):
    docs_len, titles_len = get_seq_lenth(docs_input), get_seq_lenth(titles_input)
    docs_mask, titles_mask = creat_mask(docs_len), creat_mask(titles_len)
    s_docs, s_docs_len, reverse_docs_idx = sort_batch(docs_input, docs_len)
    s_titles, s_titles_len, reverse_titles_idx = sort_batch(titles_input, titles_len)

    # embedding
    docs_embedding = pack(self.embedding(s_docs),
                          list(s_docs_len.data), batch_first=True)
    titles_embedding = pack(self.embedding(s_titles),
                            list(s_titles_len.data), batch_first=True)

    # GRU encoder
    docs_outputs, _ = self.gru(docs_embedding, None)
    titles_outputs, _ = self.gru(titles_embedding, None)

    # unpack
    docs_outputs, _ = unpack(docs_outputs, batch_first=True)
    titles_outputs, _ = unpack(titles_outputs, batch_first=True)

    # unsort
    docs_outputs = docs_outputs[reverse_docs_idx]
    titles_outputs = titles_outputs[reverse_titles_idx]

    # calculate attention matrix
    dos = docs_outputs
    doc_mask = docs_mask.unsqueeze(2)
    tos = torch.transpose(titles_outputs, 1, 2)
    title_mask = titles_mask.unsqueeze(2)
    M = torch.bmm(dos, tos)
    M_mask = torch.bmm(doc_mask, title_mask.transpose(1, 2))
    alpha = softmax_mask(M, M_mask, axis=1)
    beta = softmax_mask(M, M_mask, axis=2)
    sum_beta = torch.sum(beta, dim=1, keepdim=True)
    docs_len = docs_len.unsqueeze(1).unsqueeze(2).expand_as(sum_beta)
    average_beta = sum_beta / docs_len.float()

    # doc-level attention
    s = torch.bmm(alpha, average_beta.transpose(1, 2))

    # predict keywords
    kws_probs = None
    if keywords is not None:
        kws_probs = []
        for i, kws in enumerate(keywords):
            document = docs_input[i].squeeze()
            cur_prob = 1.
            for j, kw in enumerate(kws):
                if kw.data[0] == Constants.PAD:
                    continue
                kw = kws[j].squeeze()
                pointer = document == kw.expand_as(document)
                cur_prob *= torch.sum(
                    torch.masked_select(s[i].squeeze(), pointer))
            kws_probs.append(cur_prob + 1e-10)
        kws_probs = torch.cat(kws_probs, 0).squeeze()

    # predict prob of topic
    fc_feature = torch.sum(docs_outputs * s, dim=1)
    topic_probs = self.fc(fc_feature)
    return topic_probs, kws_probs, s
def forward(self, x1, x2, cate, seq_lens, _, attn_mode=False):
    batch_size = x1.size(0)
    max_seq_length = x1.size(1)
    embed_size = x1.size(2)
    outputs = torch.zeros(
        [max_seq_length, batch_size, embed_size],
        device=torch.device('cuda' if torch.cuda.is_available() else 'cpu'))

    # cate = pack(cate, seq_lens, batch_first=True).data
    # x2: [batch_size, seq_len, num_of_temporal, embed_dim]
    x2 = pack(x2, seq_lens, batch_first=True).data
    mha_input = torch.transpose(x2, 0, 1)
    _, x2_score = self._mha(mha_input, mha_input, mha_input)
    x2_score = torch.softmax(torch.mean(x2_score, 2, keepdim=False), dim=1)
    x2_score = torch.unsqueeze(x2_score, dim=1)
    x2 = torch.squeeze(torch.bmm(x2_score, x2), dim=1)
    # x2 = torch.mean(x2, 0, keepdim=False)
    # x2 = torch.mean(x2.data, 1, keepdim=False)
    x2 = self._mlp_mha(x2)

    # sequence embedding
    x1 = pack(x1, seq_lens, batch_first=True)
    x1, _ = self.rnn(x1)
    sequence_lenths = x1.batch_sizes.cpu().numpy()
    cursor = 0
    prev_x1s = []
    if attn_mode:
        attn_score_save = torch.zeros(
            [max_seq_length, batch_size, max_seq_length],
            device=torch.device('cuda' if torch.cuda.is_available() else 'cpu'))
    # walk the PackedSequence data step by step
    for step in range(sequence_lenths.shape[0]):
        sequence_lenth = sequence_lenths[step]
        x1_step = x1.data[cursor:cursor + sequence_lenth]
        x2_step = x2[cursor:cursor + sequence_lenth]
        prev_x1s.append(x1_step)
        prev_x1s = [prev_x1[:sequence_lenth] for prev_x1 in prev_x1s]
        prev_hs = torch.stack(prev_x1s, dim=1)
        attn_score = []
        for prev in range(prev_hs.size(1)):
            attn_inter = torch.sum(prev_hs[:, prev, :] * x2_step,
                                   dim=1, keepdim=True)
            attn_score.append(attn_inter + self._b_attn)
        attn_score = torch.softmax(torch.stack(attn_score, dim=1), dim=1)
        if attn_mode:
            attn_score_save[step, :sequence_lenth, :attn_score.shape[1]] = \
                torch.squeeze(attn_score, dim=2)
        x1_step = torch.squeeze(
            torch.bmm(torch.transpose(attn_score, 1, 2), prev_hs), dim=1)
        x_step = torch.cat((x1_step, x2_step), dim=1)
        x_step = self.mlp(x_step)
        outputs[step][:sequence_lenth] = x_step
        cursor += sequence_lenth

    if attn_mode:
        # collect "good" attention examples for offline inspection
        prev_cates = []
        attn_good_data = []
        attn_score_save = torch.transpose(attn_score_save, 0, 1)
        batch_total = 0
        cursor = 0
        for step in range(sequence_lenths.shape[0]):
            sequence_lenth = sequence_lenths[step]
            cate_step = cate[cursor:cursor + sequence_lenth]
            cursor += sequence_lenth
            prev_cates.append(cate_step)
            prev_cates = [prev[:sequence_lenth] for prev in prev_cates]
            batch_total += 1
            min_len = 14
            min_vari_cate_prev = 6
            min_vari_cate_pred = 4
            target_len = 10
            # min_len = 5
            # min_vari_cate_prev = 0
            # min_vari_cate_pred = 0
            # target_len = 2
            if step < min_len:
                continue
            prev_cates_t = torch.stack(prev_cates, dim=0)
            prev_cates_t = torch.argmax(
                torch.transpose(prev_cates_t, 0, 1), dim=2)
            for batch in range(prev_cates_t.shape[0]):
                prev_cates_c = prev_cates_t[batch, :].cpu().numpy().tolist()
                prev_cate_set = set([])
                pred_cate_set = set([])
                for i, c in enumerate(prev_cates_c):
                    if i < target_len:
                        prev_cate_set.update([c])
                    else:
                        pred_cate_set.update([c])
                if (len(prev_cate_set) < min_vari_cate_prev
                        or len(pred_cate_set) < min_vari_cate_pred):
                    continue
                ok = 0
                for c in pred_cate_set:
                    if c not in prev_cate_set:
                        continue
                    ok += 1
                if ok < 4:
                    continue
                attn_good_data.append(['==================================='])
                attn_good_data.append(prev_cates_c[:target_len])
                attn_good_data.append(prev_cates_c[target_len:])
                for s in range(target_len, step + 1):
                    attn_score_c = attn_score_save[
                        batch, s - 1, :s].cpu().numpy().tolist()
                    attn_good_data.append(attn_score_c[:target_len])

    outputs = torch.transpose(outputs, 0, 1)
    # outputs = self._dropout(outputs)
    if attn_mode:
        with open('attn_data_cate.txt', 'a') as f_att:
            attn_data_str = ''
            for data_line in attn_good_data:
                data_line = [str(d) for d in data_line]
                attn_data_str += ','.join(data_line) + '\n'
            f_att.write(attn_data_str)
    return outputs
def forward(self, x, src_len=None, ivec_feat=None):
    """
    x : (batch x seq x ndim)
    mask : (batch x seq)
    """
    batchsize, seqlen, ndim = x.size()
    if src_len is None:
        src_len = [seqlen] * batchsize

    ### FNN ###
    # apply augmentation first layer #
    # TODO: dynamic layer #
    assert self.ivec_dim >= ivec_feat.size(1)
    ivec_feat = ivec_feat[:, 0:self.ivec_dim]
    if self.ivec_cfg['type'] == 'concat':
        # calculate h from ivec #
        res_ivec = self.aug_layer(ivec_feat)
        res_ivec = res_ivec.unsqueeze(1).expand(batchsize, seqlen, res_ivec.size(1))
        res_ivec = res_ivec.contiguous().view(seqlen * batchsize, -1)
        res_main = x.contiguous().view(seqlen * batchsize, ndim)
        res_main = self.fnn[0](res_main)
        res = res_main + res_ivec
    else:
        _res_list = []
        for ii in range(batchsize):
            _aug_param = self.fn_gen_params(ivec_feat[ii:ii + 1])
            _main_aug_param = self.fn_aug_params(self.fnn[0].weight, _aug_param)
            _main_bias = self.fnn[0].bias
            res_ii = F.linear(x[ii], _main_aug_param, _main_bias)
            _res_list.append(res_ii)
        res = torch.stack(_res_list)

    res = getattr(F, self.fnn_act)(res)
    res = F.dropout(res, self.do_fnn[0], training=self.training)
    prev_size = res.size(1)
    for ii in range(1, len(self.fnn_sizes)):
        res = getattr(F, self.fnn_act)(self.fnn[ii](res))
        res = F.dropout(res, self.do_fnn[ii], training=self.training)

    ### RNN ###
    # convert shape for RNN #
    res = res.view(batchsize, seqlen, -1)
    for ii in range(len(self.rnn_sizes)):
        # AVOID pack, slow !!! #
        if self.use_pack:
            res = pack(res, src_len, batch_first=True)
            res = self.rnn[ii](res)[0]  # get h only #
            res, _ = unpack(res, batch_first=True)
        else:
            res = self.rnn[ii](res)[0]  # get h only #
        if self.downsampling[ii]:
            res = res[:, 1::2]
            src_len = [x // 2 for x in src_len]
        res = F.dropout(res, self.do_rnn[ii], training=self.training)

    ### PRE SOFTMAX ###
    batchsize, seqlen_final, ndim_final = res.size()
    res = res.view(seqlen_final * batchsize, ndim_final)
    res = self.pre_softmax(res)
    res = res.view(batchsize, seqlen_final, -1)
    res = res.transpose(1, 0)
    return res, Variable(torch.IntTensor(src_len))
def forward(self, confnets, scores, src_lengths=None, par_arc_lengths=None):
    """
    Based on the paper NEURAL CONFNET CLASSIFICATION
    (http://150.162.46.34:8080/icassp2018/ICASSP18_USB/pdfs/0006039.pdf)
    """
    emb = self.embeddings(confnets.permute(1, 0, 2, 3))  # (slen, batch, max_par_arc, emb_dim)
    emb_trans = emb.squeeze(3)  # (batch, slen, max_par_arc, emb_dim)
    output_list = torch.tensor([], device=confnets.device)

    #### FEATS NOT SUPPORTED ###
    confnets_ = confnets.squeeze(-1).permute(1, 0, 2)  # (max_sent_len, batch, max_par_arc_len)
    scores_ = scores.squeeze(-1).permute(1, 0, 2)      # (max_sent_len, batch, max_par_arc_len)
    par_arc_lengths_ = par_arc_lengths.permute(1, 0)   # (max_sent_len, batch)

    for em, score, lengths in zip(emb_trans, scores_, par_arc_lengths_):
        sc = score.unsqueeze(-1).expand(em.size())  # (batch, max_par_arc, emb_sz)
        # confnet-score-weighted word embedding
        q = em.float() * sc.float()  # (batch, max_par_arc, emb_sz)
        batch_size, max_par_arcs, emb_sz = q.size()
        v = torch.tanh(self.thetav(q))  # (batch, max_par_arc, emb_sz)
        v_bar = self.v_bar(v).squeeze(-1)

        #### masking: mask the padding ####
        mask = torch.arange(max_par_arcs, device=q.device)[None, :] < lengths[:, None]
        masked_v_bar = torch.where(
            mask, v_bar, torch.tensor(float('-inf'), device=q.device))
        attention = torch.softmax(masked_v_bar, dim=1)
        # rows that are entirely padding softmax to NaN; zero them out
        final_attention = attention.masked_fill(torch.isnan(attention), 0)

        # apply attention weights
        output = q * final_attention.unsqueeze(-1).expand(q.size())
        # most attended arcs and their attention weights
        most_attentive_arc = torch.argmax(final_attention, dim=1)
        most_attentive_arc_weights, _ = torch.max(final_attention, dim=1)
        a = torch.sum(output, dim=1)
        output_list = torch.cat((output_list, a.unsqueeze(0)), dim=0)

    output_confnet = output_list.permute(1, 0, 2)  # (batch, max_sent_len, hid_dim)
    packed_emb_ = output_confnet.permute(1, 0, 2)
    packed_emb = packed_emb_  # fall back to the padded tensor if not packing
    if src_lengths is not None and not self.no_pack_padded_seq:
        # Lengths data is wrapped inside a Tensor.
        lengths_list = src_lengths.view(-1).tolist()
        packed_emb = pack(packed_emb_, lengths_list)
    memory_bank, encoder_final = self.rnn(packed_emb)
    if src_lengths is not None and not self.no_pack_padded_seq:
        memory_bank = unpack(memory_bank)[0]
    if self.use_bridge:
        encoder_final = self._bridge(encoder_final)
    return encoder_final, memory_bank, src_lengths
def beam_search(self,
                src: Dict[str, torch.Tensor],
                key: Dict[str, torch.Tensor],
                lpos: Dict[str, torch.Tensor],
                rpos: Dict[str, torch.Tensor],
                max_decoding_step: int) -> Dict[str, torch.Tensor]:
    beam_size = self.beam_size
    src = src['tokens']
    if self.use_feature:
        keys, lpos, rpos = key['keys'], lpos['lpos'], rpos['rpos']
    lengths = self._get_lengths(src)
    batch_size = src.size(0)
    lengths, indices = lengths.sort(dim=0, descending=True)
    rev_indices = indices.sort()[1]
    src = src.index_select(dim=0, index=indices)
    if self.use_feature:
        keys = keys.index_select(dim=0, index=indices)
        lpos = lpos.index_select(dim=0, index=indices)
        rpos = rpos.index_select(dim=0, index=indices)
        src_embs = torch.cat([
            self.src_embedding(src),
            self.key_embedding(keys),
            self.lpos_embedding(lpos),
            self.rpos_embedding(rpos)
        ], dim=-1)
    else:
        src_embs = self.src_embedding(src)
    src_embs = pack(src_embs, lengths, batch_first=True)
    encode_outputs = self.encoder(src_embs)
    contexts, encState = (encode_outputs['hidden_outputs'],
                          encode_outputs['final_state'])
    contexts = contexts.repeat(beam_size, 1, 1)
    decState = (encState[0].repeat(1, beam_size, 1),
                encState[1].repeat(1, beam_size, 1))
    beam = [
        modules.beam.Beam(beam_size, bos=self._bos, eos=self._eos,
                          n_best=1, minimum_length=self.minimum_length)
        for _ in range(batch_size)
    ]
    for i in range(max_decoding_step):
        if all(b.done() for b in beam):
            break
        inp = torch.stack([b.getCurrentState() for b in beam]) \
                   .t().contiguous().view(-1)
        outputs = self.decoder.decode_step(self.tgt_embedding(inp),
                                           decState, contexts)
        output, decState, attn = (outputs['hidden_output'],
                                  outputs['state'],
                                  outputs['attention_weights'])
        logits = self.generator(output)
        output = torch.nn.functional.log_softmax(logits, dim=-1) \
                      .view(beam_size, batch_size, -1)
        attn = attn.view(beam_size, batch_size, -1)
        for j, b in enumerate(beam):
            b.advance(output.data[:, j], attn.data[:, j])
            b.beam_update(decState, j)

    allHyps, allScores, allAttn = [], [], []
    for j in rev_indices:
        b = beam[j]
        n_best = 1
        scores, ks = b.sortFinished(minimum=n_best)
        hyps, attn = [], []
        for i, (times, k) in enumerate(ks[:n_best]):
            hyp, att = b.getHyp(times, k)
            hyps.append(hyp)
            attn.append(att.max(1)[1])
        allHyps.append(hyps[0])
        allScores.append(scores[0])
        allAttn.append(attn[0])
    outputs = {'output_ids': allHyps, 'alignments': allAttn}
    return outputs
def greedy_search(self,
                  src: Dict[str, torch.Tensor],
                  key: Dict[str, torch.Tensor],
                  lpos: Dict[str, torch.Tensor],
                  rpos: Dict[str, torch.Tensor],
                  max_decoding_step: int) -> Dict[str, torch.Tensor]:
    src = src['tokens']
    if self.use_feature:
        keys, lpos, rpos = key['keys'], lpos['lpos'], rpos['rpos']
    lengths = self._get_lengths(src)
    lengths, indices = lengths.sort(dim=0, descending=True)
    rev_indices = indices.sort()[1]
    src = src.index_select(dim=0, index=indices)
    bos = torch.ones(src.size(0)).long().fill_(self._bos).cuda()
    if self.use_feature:
        keys = keys.index_select(dim=0, index=indices)
        lpos = lpos.index_select(dim=0, index=indices)
        rpos = rpos.index_select(dim=0, index=indices)
        src_embs = torch.cat([
            self.src_embedding(src),
            self.key_embedding(keys),
            self.lpos_embedding(lpos),
            self.rpos_embedding(rpos)
        ], dim=-1)
    else:
        src_embs = self.src_embedding(src)
    src_embs = pack(src_embs, lengths, batch_first=True)
    encode_outputs = self.encoder(src_embs)
    inputs = [bos]
    state = encode_outputs['final_state']
    contexts = encode_outputs['hidden_outputs']
    output_ids, attention_weights = [], []
    for i in range(max_decoding_step):
        outputs = self.decoder.decode_step(self.tgt_embedding(inputs[i]),
                                           state, contexts)
        hidden_output = outputs['hidden_output']
        state = outputs['state']
        attn_weight = outputs['attention_weights']
        logits = self.generator(hidden_output)
        next_id = logits.max(1)[1]
        inputs += [next_id]
        output_ids += [next_id]
        attention_weights += [attn_weight]
    output_ids = torch.stack(output_ids, dim=1)
    attention_weights = torch.stack(attention_weights, dim=1)
    alignments = attention_weights.max(2)[1]
    output_ids = output_ids.index_select(dim=0, index=rev_indices)
    alignments = alignments.index_select(dim=0, index=rev_indices)
    outputs = {
        'output_ids': output_ids.tolist(),
        'alignments': alignments.tolist()
    }
    return outputs
def forward(self, label_seqs, location_seqs, lengths):
    # sort label and location sequences along the batch dimension by length
    batch_idx = sorted(range(len(lengths)), key=lambda k: lengths[k], reverse=True)
    reverse_batch_idx = torch.LongTensor(
        [batch_idx.index(i) for i in range(len(batch_idx))])
    lens_sorted = sorted(lengths, reverse=True)
    label_seqs_sorted = torch.index_select(label_seqs, 0,
                                           torch.LongTensor(batch_idx))
    location_seqs_sorted = torch.index_select(location_seqs, 0,
                                              torch.LongTensor(batch_idx))
    # assert torch.equal(torch.index_select(label_seqs_sorted, 0, reverse_batch_idx), label_seqs)
    # assert torch.equal(torch.index_select(location_seqs_sorted, 0, reverse_batch_idx), location_seqs)
    if torch.cuda.is_available():
        reverse_batch_idx = reverse_batch_idx.cuda()
        label_seqs_sorted = label_seqs_sorted.cuda()
        location_seqs_sorted = location_seqs_sorted.cuda()

    # create Variables
    label_seqs_sorted_var = Variable(label_seqs_sorted, requires_grad=False)
    location_seqs_sorted_var = Variable(location_seqs_sorted, requires_grad=False)

    # encode label sequences
    label_encoding = self.label_encoder(label_seqs_sorted_var)

    # encode location sequences
    location_seqs_sorted_var = location_seqs_sorted_var.view(-1, 4)
    location_encoding = self.location_encoder(location_seqs_sorted_var)
    location_encoding = location_encoding.view(label_encoding.size(0), -1,
                                               location_encoding.size(1))

    # layout encoding - batch_size x max_seq_len x embed_size
    layout_encoding = label_encoding + location_encoding
    packed = pack(layout_encoding, lens_sorted, batch_first=True)
    hiddens, _ = self.lstm(packed)

    # unpack hiddens and gather the last hidden vector of each sequence
    hiddens_unpack = unpack(hiddens, batch_first=True)[0]  # batch_size x max_seq_len x embed_size
    last_hidden_idx = torch.zeros(hiddens_unpack.size(0), 1,
                                  hiddens_unpack.size(2)).long()
    for i in range(hiddens_unpack.size(0)):
        last_hidden_idx[i, 0, :] = lens_sorted[i] - 1
    if torch.cuda.is_available():
        last_hidden_idx = last_hidden_idx.cuda()
    last_hidden = torch.gather(
        hiddens_unpack, 1,
        Variable(last_hidden_idx, requires_grad=False))  # batch_size x 1 x embed_size
    last_hidden = torch.squeeze(last_hidden, 1)  # batch_size x embed_size

    # convert back to original batch order
    last_hidden = torch.index_select(
        last_hidden, 0, Variable(reverse_batch_idx, requires_grad=False))
    return last_hidden
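# The per-row loop above that builds `last_hidden_idx` can be expressed without
# Python-level iteration. A loop-free sketch with made-up shapes (not part of
# the original model), selecting the last valid hidden state per sequence:
import torch

hiddens = torch.randn(3, 5, 8)                   # (batch, max_len, hidden), length-sorted
lens_sorted = torch.tensor([5, 4, 2])
idx = (lens_sorted - 1).view(-1, 1, 1).expand(-1, 1, hiddens.size(2))
last_hidden = hiddens.gather(1, idx).squeeze(1)  # (batch, hidden)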
def forward(self, x1, x2, __, seq_lens, ___, ____):
    """
    forwarding function of the model, called by PyTorch

    :x1: input tensor
    :x2: input tensor for global temporal preferences
    :seq_lens: list containing the length of each sequence
    :return: output tensor
    """
    batch_size = x1.size(0)
    max_seq_length = x1.size(1)
    embed_size = x1.size(2)
    outputs = torch.zeros(
        [max_seq_length, batch_size, embed_size],
        device=torch.device('cuda' if torch.cuda.is_available() else 'cpu'))

    # x2: [batch_size, seq_len, num_of_temporal, embed_dim]
    x2 = pack(x2, seq_lens, batch_first=True).data
    mha_input = torch.transpose(x2, 0, 1)
    _, x2_score = self._mha(mha_input, mha_input, mha_input)
    x2_score = torch.softmax(torch.mean(x2_score, 2, keepdim=False), dim=1)
    x2_score = torch.unsqueeze(x2_score, dim=1)
    x2 = torch.squeeze(torch.bmm(x2_score, x2), dim=1)
    x2 = self._mlp_mha(x2)

    # sequence embedding
    x1 = pack(x1, seq_lens, batch_first=True)
    x1, _ = self.rnn(x1)
    sequence_lenths = x1.batch_sizes.cpu().numpy()
    cursor = 0
    prev_x1s = []
    for step in range(sequence_lenths.shape[0]):
        sequence_lenth = sequence_lenths[step]
        x1_step = x1.data[cursor:cursor + sequence_lenth]
        x2_step = x2[cursor:cursor + sequence_lenth]
        prev_x1s.append(x1_step)
        prev_x1s = [prev_x1[:sequence_lenth] for prev_x1 in prev_x1s]
        prev_hs = torch.stack(prev_x1s, dim=1)
        # mean over all previous hidden states
        # (the attention-weighted variant of this aggregation is disabled here)
        x1_step = torch.mean(prev_hs, dim=1, keepdim=False)
        x_step = torch.cat((x1_step, x2_step), dim=1)
        x_step = self.mlp(x_step)
        outputs[step][:sequence_lenth] = x_step
        cursor += sequence_lenth

    outputs = torch.transpose(outputs, 0, 1)
    # outputs = self._dropout(outputs)
    return outputs
def forward(self, seq, length):
    emb = self.emb(seq)
    packed = pack(emb, length, batch_first=True, enforce_sorted=False)
    out, (h, c) = self.lstm(packed)
    return out, h, c
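# With enforce_sorted=False (available since PyTorch 1.1), pack handles unsorted
# batches internally, so none of the manual sort/unsort bookkeeping from the
# earlier snippets is needed. A hedged usage sketch; `Enc` is an illustrative
# stand-in for the enclosing module, not code from the original repo:
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence as pack, pad_packed_sequence as unpack

class Enc(nn.Module):
    def __init__(self):
        super().__init__()
        self.emb = nn.Embedding(100, 8, padding_idx=0)
        self.lstm = nn.LSTM(8, 16, batch_first=True)

    def forward(self, seq, length):
        packed = pack(self.emb(seq), length, batch_first=True, enforce_sorted=False)
        out, (h, c) = self.lstm(packed)
        return out, h, c

enc = Enc()
seq = torch.randint(1, 100, (4, 7))           # (batch, max_len) token ids, zero-padded
length = torch.tensor([7, 2, 5, 3])           # true lengths, need not be sorted
out, h, c = enc(seq, length)                  # `out` is still a PackedSequence
padded, lens = unpack(out, batch_first=True)  # (batch, max_len, hidden), original order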
def forward(self, emb, lengths=None, init_states=None):
    "See :obj:`EncoderBase.forward()`"
    self._check_args(emb, lengths)
    packed_emb = emb
    if lengths is not None:
        # Lengths data is wrapped inside a Tensor.
        lengths, indices = torch.sort(lengths, 0, True)  # sort by length (keep idx)
        packed_emb = pack(packed_emb[indices], lengths.tolist(), batch_first=True)
        _, _indices = torch.sort(indices, 0)  # un-sort by length

    istates = []
    if init_states:
        if isinstance(init_states, tuple):
            hidden_states, cell_states = init_states
            hidden_states = hidden_states.split(self.nlayers, dim=0)
            cell_states = cell_states.split(self.nlayers, dim=0)
        else:
            hidden_states = init_states.split(self.nlayers, dim=0)
        for i in range(self.nlayers):
            if isinstance(init_states, tuple):
                istates.append((hidden_states[i], cell_states[i]))
            else:
                istates.append(hidden_states[i])

    memory_bank, encoder_final = [], {'h_n': [], 'c_n': []}
    for i in range(self.nlayers):
        if i != 0:
            packed_emb = self.dropout(packed_emb)
            if lengths is not None:
                packed_emb = pack(packed_emb, lengths.tolist(), batch_first=True)
        if init_states:
            packed_emb, states = self.rnns[i](packed_emb, istates[i])
        else:
            packed_emb, states = self.rnns[i](packed_emb)
        if isinstance(states, tuple):
            h_n, c_n = states
            encoder_final['c_n'].append(c_n)
        else:
            h_n = states
        encoder_final['h_n'].append(h_n)
        packed_emb = unpack(packed_emb, batch_first=True)[0] \
            if lengths is not None else packed_emb
        if not self.use_last or i == self.nlayers - 1:
            memory_bank += [packed_emb[_indices]] \
                if lengths is not None else [packed_emb]

    assert len(encoder_final['h_n']) != 0
    if self.use_last:
        memory_bank = memory_bank[-1]
        if len(encoder_final['c_n']) == 0:
            encoder_final = encoder_final['h_n'][-1]
        else:
            encoder_final = encoder_final['h_n'][-1], encoder_final['c_n'][-1]
    else:
        memory_bank = torch.cat(memory_bank, dim=2)
        if len(encoder_final['c_n']) == 0:
            encoder_final = torch.cat(encoder_final['h_n'], dim=0)
        else:
            encoder_final = (torch.cat(encoder_final['h_n'], dim=0),
                             torch.cat(encoder_final['c_n'], dim=0))

    if self.use_bridge:
        encoder_final = self._bridge(encoder_final)

    # TODO: temporary hack to stay compatible with DataParallel
    # reference: https://github.com/pytorch/pytorch/issues/1591
    if memory_bank.size(1) < emb.size(1):
        dummy_tensor = torch.zeros(memory_bank.size(0),
                                   emb.size(1) - memory_bank.size(1),
                                   memory_bank.size(2)).type_as(memory_bank)
        memory_bank = torch.cat([memory_bank, dummy_tensor], 1)
    return encoder_final, memory_bank
def forward(self, src, lengths=None, is_knowledge=False):
    """
    run transformer encoder
    :param src: source input
    :param lengths: sorted lengths
    :return: output and state (if with rnn)
    """
    if self.config.conditioned and not is_knowledge:
        # HACK: recover the original sentence without the condition
        conditions_1 = src[[length - 1 for length in lengths], range(src.shape[1])]
        conditions_2 = src[[length - 2 for length in lengths], range(src.shape[1])]
        src[[length - 1 for length in lengths], range(src.shape[1])] = self.padding_idx
        src[[length - 2 for length in lengths], range(src.shape[1])] = self.padding_idx
        lengths = [length - 2 for length in lengths]
        assert all([length > 0 for length in lengths])
        conditions_1 = conditions_1.unsqueeze(0)  # 1 x batch_size
        conditions_2 = conditions_2.unsqueeze(0)  # 1 x batch_size

    embed = self.embedding(src)
    if self.config.embed_only:
        return embed

    # RNN for positional information
    if self.config.positional:
        emb = self.position_embedding(embed)  # [len, batch, size]
    else:
        emb, state = self.rnn(pack(embed, lengths))
        emb = unpack(emb)[0]  # [len, batch, 2*size]
        # sum the two directions, then add a residual connection
        emb = emb[:, :, :self.config.hidden_size] + \
              emb[:, :, self.config.hidden_size:]  # [len, batch, size]
        emb = emb + embed  # [len, batch, size]
        state = (state[0][0], state[1][0])  # LSTM states

    if self.config.conditioned and not is_knowledge:
        assert self.config.positional
        conditions_1_embed = self.embedding(conditions_1).expand_as(embed)
        conditions_2_embed = self.embedding(conditions_2).expand_as(embed)
        # add the condition embeddings
        # (a concat + projection variant via self.embed_transform was also tried)
        emb = emb + conditions_1_embed + conditions_2_embed

    out = emb.transpose(0, 1).contiguous()  # [batch, len, size]
    src_words = src.transpose(0, 1)  # [batch, len]
    src_batch, src_len = src_words.size()
    padding_idx = self.padding_idx
    mask = src_words.data.eq(padding_idx).unsqueeze(1) \
        .expand(src_batch, src_len, src_len)  # [batch, len, len]
    for i in range(self.num_layers):
        out = self.transformer[i](out, mask)
    out = self.layer_norm(out)  # [batch, len, size]

    assert self.config.positional
    if self.config.positional:
        return out.transpose(0, 1)  # [len, batch, size]
    else:
        return out.transpose(0, 1), state
def forward(self, sents, lengths, fts=[], rel_idxs=[], lidx_start=[],
            lidx_end=[], ridx_start=[], ridx_end=[], pred_ind=True,
            flip=False, causal=False, token_type_ids=None, task='relation'):
    batch_size = sents.size(0)
    # dropout
    out = self.dropout(sents)
    # pack, run the LSTM, then unpack
    out, _ = self.lstm(pack(out, lengths, batch_first=True))
    out, _ = unpack(out, batch_first=True)

    ### entity prediction - predict each input token
    if task == 'entity':
        out_ent = self.linear1_ent(self.dropout(out))
        out_ent = self.act(out_ent)
        out_ent = self.linear2_ent(out_ent)
        prob_ent = self.softmax_ent(out_ent)
        return out_ent, prob_ent

    ### relation prediction - flatten hidden vars into a long vector
    if task == 'relation':
        ltar_f = torch.cat([out[b, lidx_start[b][r], :self.hid_size].unsqueeze(0)
                            for b, r in rel_idxs], dim=0)
        ltar_b = torch.cat([out[b, lidx_end[b][r], self.hid_size:].unsqueeze(0)
                            for b, r in rel_idxs], dim=0)
        rtar_f = torch.cat([out[b, ridx_start[b][r], :self.hid_size].unsqueeze(0)
                            for b, r in rel_idxs], dim=0)
        rtar_b = torch.cat([out[b, ridx_end[b][r], self.hid_size:].unsqueeze(0)
                            for b, r in rel_idxs], dim=0)
        out = self.dropout(torch.cat((ltar_f, ltar_b, rtar_f, rtar_b), dim=1))
        out = torch.cat((out, fts), dim=1)
        # linear prediction
        out = self.linear1(out)
        out = self.act(out)
        out = self.dropout(out)
        out = self.linear2(out)
        prob = self.softmax(out)
        return out, prob
def test(self):
    self.ds.set_split("test", self.args.num_samples)
    thresh = 1. / 50.
    prec = 0.
    reca = 0.
    acc = 0.
    num_batches = len(self.dl)
    num_labels = len(self.ds.labels_dict)
    infer_outputs = []
    counter_array = np.zeros((num_labels, 6))  # tgts, preds, tp, fp, tn, fn

    def _tally(tgts, out):
        # update per-label counters for one batch and return running metrics;
        # `out` is either size (N, C) or (N,); the tp count is deflated for
        # cross entropy
        for tgt, o in zip(tgts, out):
            o_mask = torch.zeros_like(o)
            o_mask[torch.topk(o, tgt.sum().int().item())[1]] = 1.
            o_mask = o_mask.numpy().astype(bool)
            tgt_mask = tgt.numpy() == 1.
            counter_array[tgt_mask, 0] += 1
            counter_array[o_mask, 1] += 1
            tp = np.logical_and(tgt_mask, o_mask)
            fp = np.logical_and(~tgt_mask, o_mask)
            tn = np.logical_and(~tgt_mask, ~o_mask)
            fn = np.logical_and(tgt_mask, ~o_mask)
            counter_array[tp, 2] += 1
            counter_array[fp, 3] += 1
            counter_array[tn, 4] += 1
            counter_array[fn, 5] += 1
        ttp = counter_array[:, 2].sum()
        tfp = counter_array[:, 3].sum()
        ttn = counter_array[:, 4].sum()
        tfn = counter_array[:, 5].sum()
        prec = ttp / (ttp + tfp)
        reca = ttp / (ttp + tfn)
        acc = (ttp + ttn) / (ttp + tfp + ttn + tfn)
        return acc, prec, reca

    if any(x in self.model_name for x in ["resnet", "squeezenet"]):
        m = self.model_list[0]
        m.eval()  # set model(s) into eval mode
        with tqdm(total=num_batches, leave=False, position=1,
                  postfix={"accuracy": acc, "precision": prec}) as t:
            for mb, tgts in self.dl:
                mb = mb.to(self.device)
                tgts = tgts.to(torch.device("cpu"))
                out = m(mb)  # run inference
                out = out.to(torch.device("cpu"))  # move output to cpu for analysis / numpy
                infer_outputs.append((out.numpy().tolist(), tgts.numpy().tolist()))
                if self.loss_criterion == "crossentropy":
                    out = F.softmax(out, dim=1)
                else:
                    out = torch.sigmoid(out)
                acc, prec, reca = _tally(tgts, out)
                t.set_postfix({"accuracy": "{0:.4f}".format(acc * 100.),
                               "precision": "{0:.4f}".format(prec * 100.)})
                t.update()
    elif "attn" in self.model_name:
        encoder, decoder = self.model_list[0], self.model_list[1]
        encoder.eval()  # set model(s) into eval mode
        decoder.eval()
        with tqdm(total=num_batches, leave=True, position=1,
                  postfix={"accuracy": acc, "precision": prec}) as t:
            for i, ((mb, lengths), tgts) in enumerate(self.dl):
                mb = mb.to(self.device)
                tgts = tgts.to(torch.device("cpu"))
                # init hidden before packing
                encoder_hidden = encoder.initHidden(mb)
                mb = pack(mb, lengths, batch_first=True)
                encoder_output, encoder_hidden = encoder(mb, encoder_hidden)
                dec_h = encoder_hidden  # use last (forward) hidden state from encoder
                # run through the decoder in one shot
                mb, _ = unpack(mb, batch_first=True)
                out, dec_h, dec_attn = decoder(mb, dec_h, encoder_output)
                out = out.to(torch.device("cpu"))
                out.squeeze_()
                infer_outputs.append((out.numpy().tolist(), tgts.numpy().tolist()))
                if self.loss_criterion == "crossentropy":
                    out = F.softmax(out, dim=1)
                else:
                    out = torch.sigmoid(out)
                acc, prec, reca = _tally(tgts, out)
                t.set_postfix({"accuracy": "{0:.4f}".format(acc * 100.),
                               "precision": "{0:.4f}".format(prec * 100.)})
                t.update()
    elif "bytenet" in self.model_name:
        encoder, decoder = self.model_list[0], self.model_list[1]
        encoder.eval()  # set model(s) into eval mode
        decoder.eval()
        with tqdm(total=num_batches, leave=True, position=1,
                  postfix={"accuracy": acc, "precision": prec}) as t:
            for i, (mb, tgts) in enumerate(self.dl):
                mb, tgts = mb.to(self.device), tgts.to(torch.device("cpu"))
                mb = encoder(mb)
                out = decoder(mb)
                out = out.to(torch.device("cpu"))
                infer_outputs.append((out.numpy().tolist(), tgts.numpy().tolist()))
                if self.loss_criterion == "crossentropy":
                    out = F.softmax(out, dim=1)
                else:
                    out = torch.sigmoid(out)
                acc, prec, reca = _tally(tgts, out)
                t.set_postfix({"accuracy": "{0:.4f}".format(acc * 100.),
                               "precision": "{0:.4f}".format(prec * 100.)})
                t.update()
    else:
        raise NotImplementedError
    self.infer_stats = counter_array
    self.infer_outputs = infer_outputs
def validate(self, epoch):
    self.ds.set_split("valid", self.args.num_samples)
    running_validation_loss = []
    accuracies = []
    acc = 0
    threshold = 1 - (1. / 3.)
    num_batches = len(self.dl)

    def _eval_batch(out_valid, tgts_valid):
        # compute the validation loss and per-batch accuracy for one batch
        if "margin" in self.loss_criterion:
            out_valid = torch.sigmoid(out_valid)
        if self.loss_criterion == "margin":
            tgts_valid = tgts_valid.long()
        loss_valid = self.criterion(out_valid, tgts_valid)
        running_validation_loss.append(loss_valid.item())
        if "margin" not in self.loss_criterion:
            out_valid = torch.sigmoid(out_valid)
        if self.loss_criterion == "crossentropy":
            out_pred = out_valid.max(1)[1]
            acc = (out_pred == tgts_valid).sum().item() / tgts_valid.size(0)
        else:
            acc = 0.
            num_out = out_valid.size(0)
            for ov, tgt in zip(out_valid, tgts_valid):
                tgt = torch.LongTensor([i for i, x in enumerate(tgt) if x == 1])
                num_tgt = tgt.size(0)
                ov = torch.topk(ov, num_tgt)[1]
                correct = len(np.intersect1d(tgt.numpy(), ov.numpy()))
                acc += (correct / num_tgt) / num_out
        accuracies.append(acc)
        return acc

    if any(x in self.model_name for x in ["resnet", "squeezenet"]):
        m = self.model_list[0]
        m.eval()  # set model(s) into eval mode
        with tqdm(total=num_batches, leave=True, position=2,
                  postfix={"acc": acc, "loss": "{0:.6f}".format(0.)}) as t:
            for mb_valid, tgts_valid in self.dl:
                mb_valid = mb_valid.to(self.device)
                tgts_valid = tgts_valid.to(torch.device("cpu"))
                out_valid = m(mb_valid)
                out_valid = out_valid.to(torch.device("cpu"))
                acc = _eval_batch(out_valid, tgts_valid)
                t.set_postfix({"acc": acc,
                               "loss": "{0:.6f}".format(running_validation_loss[-1])})
                t.update()
    elif "attn" in self.model_name:
        encoder, decoder = self.model_list[0], self.model_list[1]
        encoder.eval()  # set model(s) into eval mode
        decoder.eval()
        with tqdm(total=num_batches, leave=True, position=2,
                  postfix={"acc": acc, "loss": "{0:.6f}".format(0.)}) as t:
            for i, ((mb_valid, lengths), tgts_valid) in enumerate(self.dl):
                mb_valid = mb_valid.to(self.device)
                tgts_valid = tgts_valid.to(torch.device("cpu"))
                # init hidden before packing
                encoder_hidden = encoder.initHidden(mb_valid)
                mb_valid = pack(mb_valid, lengths, batch_first=True)
                encoder_output, encoder_hidden = encoder(mb_valid, encoder_hidden)
                dec_h = encoder_hidden  # use last (forward) hidden state from encoder
                # run through the decoder in one shot
                mb_valid, _ = unpack(mb_valid, batch_first=True)
                out_valid, dec_h, dec_attn = decoder(mb_valid, dec_h, encoder_output)
                out_valid = out_valid.to(torch.device("cpu"))
                out_valid.squeeze_()
                acc = _eval_batch(out_valid, tgts_valid)
                t.set_postfix({"acc": acc,
                               "loss": "{0:.6f}".format(running_validation_loss[-1])})
                t.update()
    elif "bytenet" in self.model_name:
        encoder, decoder = self.model_list[0], self.model_list[1]
        encoder.eval()  # set model(s) into eval mode
        decoder.eval()
        with tqdm(total=num_batches, leave=True, position=2,
                  postfix={"acc": acc, "loss": "{0:.6f}".format(0.)}) as t:
            for i, (mb_valid, tgts_valid) in enumerate(self.dl):
                mb_valid = mb_valid.to(self.device)
                tgts_valid = tgts_valid.to(torch.device("cpu"))
                mb_valid = encoder(mb_valid)
                out_valid = decoder(mb_valid)
                out_valid = out_valid.to(torch.device("cpu"))
                acc = _eval_batch(out_valid, tgts_valid)
                t.set_postfix({"acc": acc,
                               "loss": "{0:.6f}".format(running_validation_loss[-1])})
                t.update()
    self.valid_losses.append((running_validation_loss, accuracies))
def forward(self, word_inputs, feat_inputs, word_seq_length, char_inputs,
            char_seq_length, char_recover, dict_inputs, mask, batch_bert):
    """
    word_inputs: (batch_size, seq_len)
    word_seq_length: (batch_size,)
    """
    batch_size = word_inputs.size(0)
    seq_len = word_inputs.size(1)
    word_emb = self.word_embedding(word_inputs)
    if self.args.use_elmo:
        elmo_emb = self.elmo_embedding(word_inputs)
    word_rep = word_emb
    if self.args.use_char:
        size = char_inputs.size(0)
        char_emb = self.char_embedding(char_inputs)
        char_emb = pack(char_emb, char_seq_length.cpu().numpy(), batch_first=True)
        char_lstm_out, char_hidden = self.char_feature(char_emb)
        char_lstm_out, _ = pad(char_lstm_out, batch_first=True)
        char_hidden = char_hidden[0].transpose(1, 0).contiguous().view(size, -1)
        char_hidden = char_hidden[char_recover]
        char_hidden = char_hidden.view(batch_size, seq_len, -1)
        if self.args.attention:
            # gated combination of word and char representations
            word_rep = torch.tanh(self.attn1(word_emb) + self.attn2(char_hidden))
            z = torch.sigmoid(self.attn3(word_rep))
            x = 1 - z
            word_rep = torch.mul(z, word_emb) + torch.mul(x, char_hidden)
        else:
            word_rep = torch.cat((word_emb, char_hidden), 2)
    word_rep = self.word_drop(word_rep)  # word-representation dropout
    if self.args.feature:
        for idx in range(self.feature_num):
            word_rep = torch.cat(
                (word_rep, self.feature_embeddings[idx](feat_inputs[idx])), 2)
    # average the BERT layers (the softmax dim is left implicit in the original)
    x = F.softmax(torch.mean(batch_bert, dim=2))
    if self.args.use_bert:
        word_rep = torch.cat((word_rep, x), 2)
    word_rep = pack(word_rep, word_seq_length.cpu().numpy(), batch_first=True)
    out, hidden = self.word_feature(word_rep)
    out, _ = pad(out, batch_first=True)
    if self.args.use_elmo:
        out = torch.cat((out, elmo_emb), 2)
    if self.args.out_dict:
        dict_rep = pack(dict_inputs, word_seq_length.cpu().numpy(), batch_first=True)
        dict_out, hidden = self.dict_feature(dict_rep)
        dict_out, _ = pad(dict_out, batch_first=True)
        out = torch.cat((out, dict_out), 2)
    if self.args.lstm_attention:
        out_list, weight_list = [], []
        for idx in range(seq_len):
            if idx + 2 > seq_len:
                slice_out = out
            else:
                slice_out = out[:, 0:idx + 2, :]
            slice_out, weights = self.attention(slice_out)
            out_list.append(slice_out.unsqueeze(1))
            weight_list.append(weights)
        out = torch.cat(out_list, dim=1)
    out = self.drop(out)
    out = self.hidden2tag(out)
    return out
def forward(self, x, lens, k, kx):
    # model takes as input the text, aspect, and location;
    # runs a BLSTM over the text using embedding(location, aspect) as the
    # initial hidden state, as opposed to a different LSTM for every pair???
    # output: sentiment
    words = x
    emb = self.drop(self.lut(x))
    p_emb = pack(emb, lens, True)
    l, a = k
    N = l.shape[0]
    T = x.shape[1]
    # factor this out, for sure. POSSIBLE BUGS
    y_idx = l * len(self.A) + a
    s = (self.lut_la(y_idx)
         .view(N, 2, 2 * self.nlayers, self.rnn_sz)
         .permute(1, 2, 0, 3)
         .contiguous())
    state = (s[0], s[1])
    x, (h, c) = self.rnn(p_emb, state)  # h: L * D x N x H
    x = unpack(x, True)[0]
    # Get the last hidden states for both directions, POSSIBLE BUGS
    phi_s = self.proj_s(x)
    idxs = torch.arange(0, max(lens)).to(lens.device)
    # mask: N x R x 1, True at padding positions
    mask = idxs.repeat(len(lens), 1) >= lens.unsqueeze(-1)
    phi_s[:, :, -1].masked_fill_(~mask, float("-inf"))
    phi_s[:, :, :3].masked_fill_(mask.unsqueeze(-1), float("-inf"))
    # (a learned phi_y = self.proj_y(h) over the final hidden states is disabled)
    phi_y = torch.zeros(N, len(self.S)).to(self.psi_ys.device)
    psi_ys = torch.cat(
        [torch.diag(self.psi_ys),
         torch.zeros(len(self.S), 1).to(self.psi_ys)],
        dim=-1,
    ).expand(T, len(self.S), len(self.S) + 1)
    # Z is really weird here
    Z, hy = ubersum("nts,tys,ny->n,ny", phi_s, psi_ys, phi_y,
                    batch_dims="t", modulo_total=True)

    def stuff(i):
        # debug helper, used interactively under pdb below
        loc = self.L.itos[l[i]]
        asp = self.A.itos[a[i]]
        return self.tostr(words[i]), loc, asp, xp[i], yp[i]

    if self.training:
        self._N += 1
    if self._N > 100 and self.training:
        Zx, hx = ubersum("nts,tys->nt,nts", phi_s, psi_ys,
                         batch_dims="t", modulo_total=True)
        xp = (hx - Zx.unsqueeze(-1)).exp()
        yp = (hy - Z.unsqueeze(-1)).exp()
        import pdb; pdb.set_trace()
    return hy  # - Z.unsqueeze(-1)
def forward(self, inputs, hidden=None):
    emb = pack(self.word_lut(inputs[0]), inputs[1])
    outputs, hidden_t = self.rnn(emb, hidden)
    outputs = unpack(outputs)[0]
    return hidden_t, outputs
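# Unlike most snippets here, this encoder uses the time-major default
# (batch_first=False): pack/unpack operate on (max_len, batch, dim) tensors and
# lengths must be sorted descending. A hedged sketch with illustrative data and
# modules, not code from the original repo:
import torch
from torch.nn.utils.rnn import pack_padded_sequence as pack, pad_packed_sequence as unpack

emb = torch.randn(6, 2, 4)            # (max_len, batch, dim)
lengths = [6, 4]                      # sorted descending, as enforce_sorted=True requires
rnn = torch.nn.GRU(4, 8)
outputs, hidden_t = rnn(pack(emb, lengths))
outputs = unpack(outputs)[0]          # (max_len, batch, hidden)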
def fit(self, epoch, early_stop=None):
    epoch_losses = []
    self.ds.set_split("train")
    self.adjust_opt_params(epoch)
    self.scheduler.step()
    # self.optimizer = self.get_optimizer(epoch)
    num_batches = len(self.dl)

    def _step(i, loss):
        # backprop, log, and periodically validate
        loss.backward()
        self.optimizer.step()
        epoch_losses.append(loss.item())
        if self.tqdmiter:
            self.tqdmiter.set_postfix({"loss": "{0:.6f}".format(epoch_losses[-1])})
            self.tqdmiter.refresh()
        else:
            print(epoch_losses[-1])
        if i % self.log_interval == 0 and self.do_validate and i != 0:
            with torch.no_grad():
                self.validate(epoch)
            self.ds.set_split("train")

    if any(x in self.model_name for x in ["resnet", "squeezenet"]):
        if self.use_precompute:
            pass  # TODO: implement network precomputation
            # self.precompute(self.L["fc_layer"]["precompute"])
        m = self.model_list[0]
        with tqdm(total=num_batches, leave=False, position=1) as t:
            for i, (mb, tgts) in enumerate(self.dl):
                if i == early_stop:
                    break
                m.train()  # set model into train mode and clear gradients
                mb, tgts = mb.to(self.device), tgts.to(self.device)
                m.zero_grad()
                out = m(mb)
                if "margin" in self.loss_criterion:
                    out = torch.sigmoid(out)
                if self.loss_criterion == "margin":
                    tgts = tgts.long()
                loss = self.criterion(out, tgts)
                _step(i, loss)
                t.update()
    elif "attn" in self.model_name:
        encoder, decoder = self.model_list[0], self.model_list[1]
        with tqdm(total=num_batches, leave=False, position=1) as t:
            for i, ((mb, lengths), tgts) in enumerate(self.dl):
                # set models into train mode and clear gradients
                encoder.train()
                decoder.train()
                encoder.zero_grad()
                decoder.zero_grad()
                mb, tgts = mb.to(self.device), tgts.to(self.device)
                # create the initial hidden state before packing the sequence
                encoder_hidden = encoder.initHidden(mb)
                mb = pack(mb, lengths, batch_first=True)
                # encode the packed sequence
                encoder_output, encoder_hidden = encoder(mb, encoder_hidden)
                dec_h = encoder_hidden  # use last (forward) hidden state from encoder
                # run through the decoder in one shot
                mb, _ = unpack(mb, batch_first=True)
                dec_o, dec_h, dec_attn = decoder(mb, dec_h, encoder_output)
                dec_o.squeeze_()
                # calculate loss and backprop
                if "margin" in self.loss_criterion:
                    dec_o = torch.sigmoid(dec_o)
                if self.loss_criterion == "margin":
                    tgts = tgts.long()
                loss = self.criterion(dec_o, tgts)
                _step(i, loss)
                t.update()
    elif "bytenet" in self.model_name:
        encoder, decoder = self.model_list[0], self.model_list[1]
        with tqdm(total=num_batches, leave=False, position=1) as t:
            for i, (mb, tgts) in enumerate(self.dl):
                # set models into train mode and clear gradients
                encoder.train()
                decoder.train()
                encoder.zero_grad()
                decoder.zero_grad()
                mb, tgts = mb.to(self.device), tgts.to(self.device)
                mb = encoder(mb)
                out = decoder(mb)
                if "margin" in self.loss_criterion:
                    out = torch.sigmoid(out)
                if self.loss_criterion == "margin":
                    tgts = tgts.long()
                loss = self.criterion(out, tgts)
                _step(i, loss)
                t.update()
    self.train_losses.append(epoch_losses)
    if epoch % 10 == 0 and epoch != 0 and self.use_cache:
        self.ds.init_cache()