def __init__(self, embeddings_path, rnn_num_layers, rnn_hidden_size,
             author_embedding_dim, num_author_embeddings, cnn_kernel_size,
             cnn_out_channels, pad_author_id, citations_dim, dual=False,
             global_info=False, weight_scores=True, device='cpu'):
    super().__init__()
    self.device = device
    self.dual = dual
    self.global_info = global_info
    self.weight_scores = weight_scores

    self.emb = EmbeddingLayer(embeddings_path)
    self.word_lstm = LSTMLayer(num_layers=rnn_num_layers,
                               hidden_size=rnn_hidden_size,
                               input_size=self.emb.embedding_dim,
                               device=device)

    # self-attention
    self.context_attention = AttentionLayer(
        query_dim=2 * self.word_lstm.hidden_size,
        value_dim=2 * self.word_lstm.hidden_size,
        device=device)
    # attention over cited article's text
    self.ref_paper_attention = AttentionLayer(
        query_dim=2 * self.word_lstm.hidden_size,
        value_dim=2 * self.word_lstm.hidden_size,
        device=device)
    # attention over citing article's text
    if self.global_info:
        self.citing_paper_attention = AttentionLayer(
            query_dim=2 * self.word_lstm.hidden_size,
            value_dim=2 * self.word_lstm.hidden_size,
            device=device)

    self.semantic_score = torch.nn.CosineSimilarity(dim=-1)

    # dual version needs components for bibliographic score
    if self.dual:
        self.author_embedding = AuthorsCNN(
            num_author_embeddings=num_author_embeddings,
            author_embedding_dim=author_embedding_dim,
            padding_idx=pad_author_id,
            kernel_sizes=cnn_kernel_size,
            out_channels=cnn_out_channels)
        self.meta_score = MLPLayer(
            input_size=self.author_embedding.embedding_dim + citations_dim,
            output_size=1)
        if self.weight_scores:
            self.score_weights = torch.nn.Linear(
                in_features=2 * self.word_lstm.hidden_size,
                out_features=2)

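# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the model above): one plausible way the
# semantic score (cosine similarity) and the bibliographic/meta score could be
# combined through the learned `score_weights` layer, assuming the two weights
# are predicted from the attended context and normalised with a softmax.
# Shapes and names here are hypothetical stand-ins.
import torch
import torch.nn.functional as F

batch, hidden = 4, 256                                         # hidden = 2 * word_lstm.hidden_size
context = torch.randn(batch, hidden)                           # attended context representation
cited = torch.randn(batch, hidden)                             # cited-paper representation
semantic_score = F.cosine_similarity(context, cited, dim=-1)   # [batch]
meta_score = torch.sigmoid(torch.randn(batch))                 # stand-in for the MLPLayer output
score_weights = torch.nn.Linear(hidden, 2)
w = F.softmax(score_weights(context), dim=-1)                  # [batch, 2]
final_score = w[:, 0] * semantic_score + w[:, 1] * meta_score  # weighted combination
print(final_score.shape)                                       # torch.Size([4])
# ---------------------------------------------------------------------------
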
def __init__(self, args, use_attn=False):
    super(MolConvNet, self).__init__()
    self.args = args
    self.use_attn = use_attn
    self.conv_layer = GraphConv(args)
    self.output_size = args.hidden_size
    if self.use_attn:
        self.attn_layer = AttentionLayer(args)
        self.output_size += args.hidden_size

def __init__(self, img_dim, question_dim, **kwargs):
    super(ImageEmbedding, self).__init__()
    self.image_attention_model = AttentionLayer(img_dim, question_dim, **kwargs)
    self.out_dim = self.image_attention_model.out_dim

def __init__(
        self,
        # params
        vocab_size,
        embedding_size=200,
        acous_hidden_size=256,
        acous_att_mode='bahdanau',
        hidden_size_dec=200,
        hidden_size_shared=200,
        num_unilstm_dec=4,
        #
        add_acous=True,
        acous_norm=False,
        spec_aug=False,
        batch_norm=False,
        enc_mode='pyramid',
        use_type='char',
        #
        add_times=False,
        #
        embedding_dropout=0,
        dropout=0.0,
        residual=True,
        batch_first=True,
        max_seq_len=32,
        load_embedding=None,
        word2id=None,
        id2word=None,
        hard_att=False,
        use_gpu=False):
    super(LAS, self).__init__()

    # config device
    if use_gpu and torch.cuda.is_available():
        global device
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')

    # define model
    self.acous_dim = 40
    self.acous_hidden_size = acous_hidden_size
    self.acous_att_mode = acous_att_mode
    self.hidden_size_dec = hidden_size_dec
    self.hidden_size_shared = hidden_size_shared
    self.num_unilstm_dec = num_unilstm_dec

    # define var
    self.hard_att = hard_att
    self.residual = residual
    self.use_type = use_type
    self.max_seq_len = max_seq_len

    # tuning
    self.add_acous = add_acous
    self.acous_norm = acous_norm
    self.spec_aug = spec_aug
    self.batch_norm = batch_norm
    self.enc_mode = enc_mode

    # add time stamps
    self.add_times = add_times

    # use shared embedding + vocab
    self.vocab_size = vocab_size
    self.embedding_size = embedding_size
    self.load_embedding = load_embedding
    self.word2id = word2id
    self.id2word = id2word

    # define operations
    self.embedding_dropout = nn.Dropout(embedding_dropout)
    self.dropout = nn.Dropout(dropout)

    # ------- load embeddings --------
    if self.use_type != 'bpe':
        if self.load_embedding:
            embedding_matrix = np.random.rand(self.vocab_size, self.embedding_size)
            embedding_matrix = load_pretrained_embedding(
                self.word2id, embedding_matrix, self.load_embedding)
            embedding_matrix = torch.FloatTensor(embedding_matrix)
            self.embedder = nn.Embedding.from_pretrained(
                embedding_matrix, freeze=False, sparse=False, padding_idx=PAD)
        else:
            self.embedder = nn.Embedding(self.vocab_size, self.embedding_size,
                                         sparse=False, padding_idx=PAD)
    elif self.use_type == 'bpe':
        # BPE
        embedding_matrix = np.random.rand(self.vocab_size, self.embedding_size)
        embedding_matrix = load_pretrained_embedding_bpe(embedding_matrix)
        embedding_matrix = torch.FloatTensor(embedding_matrix).to(device=device)
        self.embedder = nn.Embedding.from_pretrained(
            embedding_matrix, freeze=False, sparse=False, padding_idx=PAD)

    # ------- las model --------
    if self.add_acous and not self.add_times:
        # ------ define acous enc -------
        if self.enc_mode == 'pyramid':
            self.acous_enc_l1 = torch.nn.LSTM(self.acous_dim, self.acous_hidden_size,
                                              num_layers=1, batch_first=batch_first,
                                              bias=True, dropout=dropout,
                                              bidirectional=True)
            self.acous_enc_l2 = torch.nn.LSTM(self.acous_hidden_size * 4, self.acous_hidden_size,
                                              num_layers=1, batch_first=batch_first,
                                              bias=True, dropout=dropout,
                                              bidirectional=True)
            self.acous_enc_l3 = torch.nn.LSTM(self.acous_hidden_size * 4, self.acous_hidden_size,
                                              num_layers=1, batch_first=batch_first,
                                              bias=True, dropout=dropout,
                                              bidirectional=True)
            self.acous_enc_l4 = torch.nn.LSTM(self.acous_hidden_size * 4, self.acous_hidden_size,
                                              num_layers=1, batch_first=batch_first,
                                              bias=True, dropout=dropout,
                                              bidirectional=True)
            if self.batch_norm:
                self.bn1 = nn.BatchNorm1d(self.acous_hidden_size * 2)
                self.bn2 = nn.BatchNorm1d(self.acous_hidden_size * 2)
                self.bn3 = nn.BatchNorm1d(self.acous_hidden_size * 2)
                self.bn4 = nn.BatchNorm1d(self.acous_hidden_size * 2)
        elif self.enc_mode == 'cnn':
            pass

        # ------ define acous att --------
        dropout_acous_att = dropout
        self.acous_hidden_size_att = 0  # ignored with bilinear
        self.acous_key_size = self.acous_hidden_size * 2    # acous feats
        self.acous_value_size = self.acous_hidden_size * 2  # acous feats
        self.acous_query_size = self.hidden_size_dec        # use dec(words) as query
        self.acous_att = AttentionLayer(self.acous_query_size, self.acous_key_size,
                                        value_size=self.acous_value_size,
                                        mode=self.acous_att_mode,
                                        dropout=dropout_acous_att,
                                        query_transform=False, output_transform=False,
                                        hidden_size=self.acous_hidden_size_att,
                                        use_gpu=use_gpu, hard_att=False)

        # ------ define acous out --------
        self.acous_ffn = nn.Linear(self.acous_hidden_size * 2 + self.hidden_size_dec,
                                   self.hidden_size_shared, bias=False)
        self.acous_out = nn.Linear(self.hidden_size_shared, self.vocab_size, bias=True)

        # ------ define acous dec -------
        # embedding_size_dec + self.hidden_size_shared [200+200] -> hidden_size_dec [200]
        if not self.residual:
            self.dec = torch.nn.LSTM(self.embedding_size + self.hidden_size_shared,
                                     self.hidden_size_dec,
                                     num_layers=self.num_unilstm_dec,
                                     batch_first=batch_first, bias=True,
                                     dropout=dropout, bidirectional=False)
        else:
            self.dec = nn.Module()
            self.dec.add_module(
                'l0',
                torch.nn.LSTM(self.embedding_size + self.hidden_size_shared,
                              self.hidden_size_dec, num_layers=1,
                              batch_first=batch_first, bias=True,
                              dropout=dropout, bidirectional=False))
            for i in range(1, self.num_unilstm_dec):
                self.dec.add_module(
                    'l' + str(i),
                    torch.nn.LSTM(self.hidden_size_dec, self.hidden_size_dec,
                                  num_layers=1, batch_first=batch_first, bias=True,
                                  dropout=dropout, bidirectional=False))

    elif self.add_acous and self.add_times:
        # ------ define acous enc -------
        if self.enc_mode == 'ts-pyramid':
            self.acous_enc_l1 = torch.nn.LSTM(self.acous_dim, self.acous_hidden_size,
                                              num_layers=1, batch_first=batch_first,
                                              bias=True, dropout=dropout,
                                              bidirectional=True)
            self.acous_enc_l2 = torch.nn.LSTM(self.acous_hidden_size * 4, self.acous_hidden_size,
                                              num_layers=1, batch_first=batch_first,
                                              bias=True, dropout=dropout,
                                              bidirectional=True)
            self.acous_enc_l3 = torch.nn.LSTM(self.acous_hidden_size * 4, self.acous_hidden_size,
                                              num_layers=1, batch_first=batch_first,
                                              bias=True, dropout=dropout,
                                              bidirectional=True)
            self.acous_enc_l4 = torch.nn.LSTM(self.acous_hidden_size * 4, self.acous_hidden_size,
                                              num_layers=1, batch_first=batch_first,
                                              bias=True, dropout=dropout,
                                              bidirectional=True)
        else:
            # default
            acous_enc_blstm_depth = 1
            self.acous_enc = torch.nn.LSTM(self.acous_dim, self.acous_hidden_size,
                                           num_layers=acous_enc_blstm_depth,
                                           batch_first=batch_first, bias=True,
                                           dropout=dropout, bidirectional=True)

        # ------ define acous local att --------
        dropout_acous_att = dropout
        self.acous_hidden_size_att = 0  # ignored with bilinear
        self.acous_key_size = self.acous_hidden_size * 2    # acous feats
        self.acous_value_size = self.acous_hidden_size * 2  # acous feats
        self.acous_query_size = self.hidden_size_dec        # use dec(words) as query
        self.acous_att = AttentionLayer(self.acous_query_size, self.acous_key_size,
                                        value_size=self.acous_value_size,
                                        mode=self.acous_att_mode,
                                        dropout=dropout_acous_att,
                                        query_transform=False, output_transform=False,
                                        hidden_size=self.acous_hidden_size_att,
                                        use_gpu=use_gpu, hard_att=False)

        # ------ define dd classifier -------
        self.dd_blstm_size = 300
        self.dd_blstm_depth = 2
        self.dd_blstm = torch.nn.LSTM(self.embedding_size, self.dd_blstm_size,
                                      num_layers=self.dd_blstm_depth,
                                      batch_first=batch_first, bias=True,
                                      dropout=dropout, bidirectional=True)
        if self.add_acous:
            dd_in_dim = self.dd_blstm_size * 2 + self.acous_hidden_size * 2
        else:
            dd_in_dim = self.dd_blstm_size * 2  # might need to change this
        self.dd_classify = nn.Sequential(
            nn.Linear(dd_in_dim, 50, bias=True),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(50, 50),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(50, 1),
            nn.Sigmoid(),
        )

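# ---------------------------------------------------------------------------
# Illustrative sketch (not from the model above): why the second, third and
# fourth pyramid BLSTM layers take `acous_hidden_size * 4` inputs. Each
# bidirectional layer emits 2*H features per frame; a pyramidal encoder then
# concatenates every pair of adjacent frames before the next layer, halving
# the time axis and doubling the feature dimension to 4*H.
import torch

B, T, H = 2, 8, 16
blstm_out = torch.randn(B, T, 2 * H)              # output of one BLSTM layer
pyramid_in = blstm_out.reshape(B, T // 2, 4 * H)  # concat adjacent frames: T -> T/2, 2H -> 4H
print(pyramid_in.shape)                           # torch.Size([2, 4, 64])
# ---------------------------------------------------------------------------
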
class Seq2Seq(nn.Module):
    """ enc-dec model """

    def __init__(
            self,
            # add params
            vocab_size_enc,
            vocab_size_dec,
            embedding_size_enc=200,
            embedding_size_dec=200,
            embedding_dropout=0,
            hidden_size_enc=200,
            num_bilstm_enc=2,
            num_unilstm_enc=0,
            hidden_size_dec=200,
            num_unilstm_dec=2,
            hidden_size_att=10,
            hidden_size_shared=200,
            dropout=0.0,
            residual=False,
            batch_first=True,
            max_seq_len=32,
            batch_size=64,
            load_embedding_src=None,
            load_embedding_tgt=None,
            src_word2id=None,
            tgt_word2id=None,
            src_id2word=None,
            att_mode='bahdanau',
            hard_att=False,
            use_gpu=False,
            additional_key_size=0,
            ptr_net=False,
            use_bpe=False):
        super(Seq2Seq, self).__init__()

        # config device
        if use_gpu and torch.cuda.is_available():
            global device
            device = torch.device('cuda')
        else:
            device = torch.device('cpu')

        # define var
        self.hidden_size_enc = hidden_size_enc
        self.num_bilstm_enc = num_bilstm_enc
        self.num_unilstm_enc = num_unilstm_enc
        self.hidden_size_dec = hidden_size_dec
        self.num_unilstm_dec = num_unilstm_dec
        self.hidden_size_att = hidden_size_att
        self.hidden_size_shared = hidden_size_shared
        self.batch_size = batch_size
        self.max_seq_len = max_seq_len
        self.use_gpu = use_gpu
        self.hard_att = hard_att
        self.additional_key_size = additional_key_size
        self.residual = residual
        self.ptr_net = ptr_net
        self.use_bpe = use_bpe

        # use shared embedding + vocab
        self.vocab_size = vocab_size_enc
        self.embedding_size = embedding_size_enc
        self.load_embedding = load_embedding_src
        self.word2id = src_word2id
        self.id2word = src_id2word

        # define operations
        self.embedding_dropout = nn.Dropout(embedding_dropout)
        self.dropout = nn.Dropout(dropout)
        self.beam_width = 0

        # load embeddings
        if not self.use_bpe:
            if self.load_embedding:
                embedding_matrix = np.random.rand(self.vocab_size, self.embedding_size)
                embedding_matrix = load_pretrained_embedding(
                    self.word2id, embedding_matrix, self.load_embedding)
                embedding_matrix = torch.FloatTensor(embedding_matrix)
                self.embedder = nn.Embedding.from_pretrained(
                    embedding_matrix, freeze=False, sparse=False, padding_idx=PAD)
            else:
                self.embedder = nn.Embedding(self.vocab_size, self.embedding_size,
                                             sparse=False, padding_idx=PAD)
        else:
            # BPE
            embedding_matrix = np.random.rand(self.vocab_size, self.embedding_size)
            embedding_matrix = load_pretrained_embedding_bpe(embedding_matrix)
            embedding_matrix = torch.FloatTensor(embedding_matrix).to(device=device)
            self.embedder = nn.Embedding.from_pretrained(
                embedding_matrix, freeze=False, sparse=False, padding_idx=PAD)
        self.embedder_enc = self.embedder
        self.embedder_dec = self.embedder

        # define enc
        # embedding_size_enc -> hidden_size_enc * 2
        self.enc = torch.nn.LSTM(self.embedding_size, self.hidden_size_enc,
                                 num_layers=self.num_bilstm_enc, batch_first=batch_first,
                                 bias=True, dropout=dropout, bidirectional=True)
        if self.num_unilstm_enc != 0:
            if not self.residual:
                self.enc_uni = torch.nn.LSTM(self.hidden_size_enc * 2,
                                             self.hidden_size_enc * 2,
                                             num_layers=self.num_unilstm_enc,
                                             batch_first=batch_first, bias=True,
                                             dropout=dropout, bidirectional=False)
            else:
                self.enc_uni = nn.Module()
                for i in range(self.num_unilstm_enc):
                    self.enc_uni.add_module(
                        'l' + str(i),
                        torch.nn.LSTM(self.hidden_size_enc * 2, self.hidden_size_enc * 2,
                                      num_layers=1, batch_first=batch_first, bias=True,
                                      dropout=dropout, bidirectional=False))

        # define dec
        # embedding_size_dec + self.hidden_size_shared [200+200] -> hidden_size_dec [200]
        if not self.residual:
            self.dec = torch.nn.LSTM(self.embedding_size + self.hidden_size_shared,
                                     self.hidden_size_dec,
                                     num_layers=self.num_unilstm_dec,
                                     batch_first=batch_first, bias=True,
                                     dropout=dropout, bidirectional=False)
        else:
            lstm_uni_dec_first = torch.nn.LSTM(self.embedding_size + self.hidden_size_shared,
                                               self.hidden_size_dec, num_layers=1,
                                               batch_first=batch_first, bias=True,
                                               dropout=dropout, bidirectional=False)
            self.dec = nn.Module()
            self.dec.add_module('l0', lstm_uni_dec_first)
            for i in range(1, self.num_unilstm_dec):
                self.dec.add_module(
                    'l' + str(i),
                    torch.nn.LSTM(self.hidden_size_dec, self.hidden_size_dec,
                                  num_layers=1, batch_first=batch_first, bias=True,
                                  dropout=dropout, bidirectional=False))

        # define att
        # query:   hidden_size_dec                                       [200]
        # keys:    hidden_size_enc * 2 + (optional) additional_key_size  [400]
        # values:  hidden_size_enc * 2                                   [400]
        # context: weighted sum of values                                [400]
        self.key_size = self.hidden_size_enc * 2 + self.additional_key_size
        self.value_size = self.hidden_size_enc * 2
        self.query_size = self.hidden_size_dec
        self.att = AttentionLayer(self.query_size, self.key_size,
                                  value_size=self.value_size, mode=att_mode,
                                  dropout=dropout, query_transform=False,
                                  output_transform=False,
                                  hidden_size=self.hidden_size_att,
                                  use_gpu=self.use_gpu, hard_att=self.hard_att)

        # define output
        # (hidden_size_enc * 2 + hidden_size_dec) -> hidden_size_shared -> vocab_size_dec
        self.ffn = nn.Linear(self.hidden_size_enc * 2 + self.hidden_size_dec,
                             self.hidden_size_shared, bias=False)
        self.out = nn.Linear(self.hidden_size_shared, self.vocab_size, bias=True)

        # define pointer weight
        if self.ptr_net == 'comb':
            self.ptr_i = nn.Linear(self.embedding_size, 1, bias=False)      # decoder input
            self.ptr_s = nn.Linear(self.hidden_size_dec, 1, bias=False)     # decoder state
            self.ptr_c = nn.Linear(self.hidden_size_enc * 2, 1, bias=True)  # context

    def reset_use_gpu(self, use_gpu):
        self.use_gpu = use_gpu

    def reset_max_seq_len(self, max_seq_len):
        self.max_seq_len = max_seq_len

    def reset_batch_size(self, batch_size):
        self.batch_size = batch_size

    def set_beam_width(self, beam_width):
        self.beam_width = beam_width

    def set_idmap(self, word2id, id2word):
        self.word2id = word2id
        self.id2word = id2word

    def check_var(self, var_name, var_val_set=None):
        """ keep old checkpoints compatible with class variables added in later versions """
        if not hasattr(self, var_name):
            if var_name == 'additional_key_size':
                var_val = var_val_set if type(var_val_set) != type(None) else 0
            elif var_name == 'ptr_net':
                var_val = var_val_set if type(var_val_set) != type(None) else 'null'
            else:
                var_val = var_val_set if type(var_val_set) != type(None) else None
            # set class attribute to default value
            setattr(self, var_name, var_val)

    def forward(self, src, tgt=None, hidden=None, is_training=False,
                teacher_forcing_ratio=1.0, att_key_feats=None, beam_width=1):
        """
        Args:
            src: list of src word_ids [batch_size, max_seq_len, word_ids]
            tgt: list of tgt word_ids
            hidden: initial hidden state
            is_training: whether in eval or train mode
            teacher_forcing_ratio: default at 1 - always teacher forcing
        Returns:
            decoder_outputs: list of step_output
                - log predicted_softmax [batch_size, 1, vocab_size_dec] * (T-1)
            ret_dict
        """
        if self.use_gpu and torch.cuda.is_available():
            global device
            device = torch.device('cuda')
        else:
            device = torch.device('cpu')

        # ******************************************************
        # 0. init var
        ret_dict = dict()
        ret_dict[KEY_ATTN_SCORE] = []
        decoder_outputs = []
        sequence_symbols = []
        batch_size = self.batch_size
        lengths = np.array([self.max_seq_len] * batch_size)
        self.beam_width = beam_width
        if not hasattr(self, 'num_unilstm_enc'):
            self.num_unilstm_enc = 0
        if not hasattr(self, 'residual'):
            self.residual = False
        self.check_var('use_bpe', var_val_set=False)

        # src mask
        mask_src = src.data.eq(PAD)
        # print(mask_src[0])

        # ******************************************************
        # 1. convert id to embedding
        emb_src = self.embedding_dropout(self.embedder_enc(src))
        if type(tgt) == type(None):
            tgt = torch.Tensor([BOS]).repeat(src.size()).type(
                torch.LongTensor).to(device=device)
        emb_tgt = self.embedding_dropout(self.embedder_dec(tgt))

        # ******************************************************
        # 2. run enc
        enc_outputs, enc_hidden = self.enc(emb_src, hidden)
        enc_outputs = self.dropout(enc_outputs)\
            .view(self.batch_size, self.max_seq_len, enc_outputs.size(-1))
        if self.num_unilstm_enc != 0:
            if not self.residual:
                enc_hidden_uni_init = None
                enc_outputs, enc_hidden_uni = self.enc_uni(enc_outputs, enc_hidden_uni_init)
                enc_outputs = self.dropout(enc_outputs)\
                    .view(self.batch_size, self.max_seq_len, enc_outputs.size(-1))
            else:
                enc_hidden_uni_init = None
                enc_hidden_uni_lis = []
                for i in range(self.num_unilstm_enc):
                    enc_inputs = enc_outputs
                    enc_func = getattr(self.enc_uni, 'l' + str(i))
                    enc_outputs, enc_hidden_uni = enc_func(enc_inputs, enc_hidden_uni_init)
                    enc_hidden_uni_lis.append(enc_hidden_uni)
                    if i < self.num_unilstm_enc - 1:  # no residual for last layer
                        enc_outputs = enc_outputs + enc_inputs
                    enc_outputs = self.dropout(enc_outputs)\
                        .view(self.batch_size, self.max_seq_len, enc_outputs.size(-1))

        # ******************************************************
        # 2.5 att inputs: keys and values
        if type(att_key_feats) == type(None):
            att_keys = enc_outputs
        else:
            # att_key_feats: b x max_seq_len x additional_key_size
            assert self.additional_key_size == att_key_feats.size(-1), \
                'Mismatch in attention key dimension!'
            att_keys = torch.cat((enc_outputs, att_key_feats), dim=2)
        att_vals = enc_outputs

        # ******************************************************
        # 3. init hidden states
        dec_hidden = None

        # ======================================================
        # decoder
        def decode(step, step_output, step_attn):
            """
            Greedy decoding
            Note: it should generate EOS, PAD as used in training tgt
            Args:
                step: step idx
                step_output: log predicted_softmax [batch_size, 1, vocab_size_dec]
                step_attn: attention scores
                    (batch_size x tgt_len(query_len) x src_len(key_len))
            Returns:
                symbols: most probable symbol_id [batch_size, 1]
            """
            ret_dict[KEY_ATTN_SCORE].append(step_attn)
            decoder_outputs.append(step_output)
            symbols = decoder_outputs[-1].topk(1)[1]
            sequence_symbols.append(symbols)
            eos_batches = torch.max(symbols.data.eq(EOS), symbols.data.eq(PAD))
            # eos_batches = symbols.data.eq(PAD)
            if eos_batches.dim() > 0:
                eos_batches = eos_batches.cpu().view(-1).numpy()
                update_idx = ((lengths > step) & eos_batches) != 0
                lengths[update_idx] = len(sequence_symbols)
            return symbols

        def decode_dd(step, step_output, step_attn):
            # same as decode - used in eval scripts
            # only decode over the input vocab
            ret_dict[KEY_ATTN_SCORE].append(step_attn)
            decoder_outputs.append(step_output)
            symbols = decoder_outputs[-1].topk(1)[1]
            src_detach = src.clone().detach().contiguous()
            src_detach[src == EOS] = 0
            # eos prob
            eos_prob = step_output[:, EOS].view(-1, 1)
            PROB_BIAS = 5
            output_trim = torch.gather(step_output, 1, src_detach)
            val, ind = output_trim.topk(1)
            choice_eos = (eos_prob > (val + PROB_BIAS)).type('torch.LongTensor')\
                .to(device=device)
            choice_ind = (eos_prob <= (val + PROB_BIAS)).type('torch.LongTensor')\
                .to(device=device)
            symbols_trim = torch.gather(src_detach, 1, ind)
            symbols_choice = symbols_trim * choice_ind + EOS * choice_eos
            sequence_symbols.append(symbols_choice)
            eos_batches = torch.max(symbols_choice.eq(EOS), symbols_choice.data.eq(PAD))
            if eos_batches.dim() > 0:
                eos_batches = eos_batches.cpu().view(-1).numpy()
                update_idx = ((lengths > step) & eos_batches) != 0
                lengths[update_idx] = len(sequence_symbols)
            return symbols_choice
        # ======================================================

        # ******************************************************
        # 4. run dec + att + shared + output
        """
        teacher_forcing_ratio = 1.0 -> always teacher forcing
        E.g.:
            emb_tgt      = <s> w1 w2 w3 </s> <pad> <pad> <pad>  [max_seq_len]
            tgt_chunk in = <s> w1 w2 w3 </s> <pad> <pad>        [max_seq_len - 1]
            predicted    = w1 w2 w3 </s> <pad> <pad> <pad>      [max_seq_len - 1]
            (shift-by-1)
        """
        use_teacher_forcing = True if random.random() < teacher_forcing_ratio else False
        if not is_training:
            use_teacher_forcing = False

        # beam search decoding
        if not is_training and self.beam_width > 1:
            decoder_outputs, decoder_hidden, metadata = \
                self.beam_search_decoding(att_keys, att_vals, dec_hidden, mask_src,
                                          beam_width=self.beam_width)
            return decoder_outputs, decoder_hidden, metadata

        # no beam search decoding
        self.check_var('ptr_net')
        assert self.ptr_net == 'null'  # use without ptr net
        tgt_chunk = emb_tgt[:, 0].unsqueeze(1)  # BOS
        cell_value = torch.FloatTensor([0])\
            .repeat(self.batch_size, 1, self.hidden_size_shared).to(device=device)
        prev_c = torch.FloatTensor([0])\
            .repeat(self.batch_size, 1, self.max_seq_len).to(device=device)
        for idx in range(self.max_seq_len - 1):
            predicted_logsoftmax, dec_hidden, step_attn, c_out, cell_value, p_gen = \
                self.forward_step(att_keys, att_vals, tgt_chunk, cell_value,
                                  dec_hidden, mask_src, prev_c)
            predicted_logsoftmax = predicted_logsoftmax.squeeze(1)  # [b, vocab_size]
            predicted_softmax = torch.exp(predicted_logsoftmax)
            step_output = predicted_logsoftmax
            symbols = decode(idx, step_output, step_attn)
            # symbols = decode_dd(idx, step_output, step_attn)
            prev_c = c_out
            if use_teacher_forcing:
                tgt_chunk = emb_tgt[:, idx + 1].unsqueeze(1)
            else:
                tgt_chunk = self.embedder_dec(symbols)

        ret_dict[KEY_SEQUENCE] = sequence_symbols
        ret_dict[KEY_LENGTH] = lengths.tolist()

        return decoder_outputs, dec_hidden, ret_dict

    def forward_step(self, att_keys, att_vals, tgt_chunk, prev_cell_value,
                     dec_hidden=None, mask_src=None, prev_c=None):
        """
        manual unrolling - can only operate per time step
        Args:
            att_keys: [batch_size, seq_len, hidden_size_enc * 2 + optional key size (key_size)]
            att_vals: [batch_size, seq_len, hidden_size_enc * 2 (val_size)]
            tgt_chunk: tgt word embeddings
                non teacher forcing - [batch_size, 1, embedding_size_dec]
                (lose 1 dim when indexed)
            prev_cell_value: previous cell value before prediction
                [batch_size, 1, self.state_size]
            dec_hidden: initial hidden state for dec layer
            mask_src: mask of PAD for src sequences
            prev_c: used in hybrid attention mechanism
        Returns:
            predicted_softmax: log probabilities [batch_size, vocab_size_dec]
            dec_hidden: a list of hidden states of each dec layer
            attn: attention weights
            cell_value: transformed attention output [batch_size, 1, self.hidden_size_shared]
        """
        # record sizes
        batch_size = tgt_chunk.size(0)
        tgt_chunk_etd = torch.cat([tgt_chunk, prev_cell_value], -1)
        tgt_chunk_etd = tgt_chunk_etd.view(
            -1, 1, self.embedding_size + self.hidden_size_shared)

        # run dec
        # default dec_hidden: [h_0, c_0];
        # with h_0 [num_layers * num_directions(==1), batch, hidden_size]
        if not self.residual:
            # the stacked decoder LSTM expects the embedding concatenated with the cell value
            dec_outputs, dec_hidden = self.dec(tgt_chunk_etd, dec_hidden)
            dec_outputs = self.dropout(dec_outputs)
        else:
            # store states layer by layer
            # num_layers * ([1, batch, hidden_size], [1, batch, hidden_size])
            dec_hidden_lis = []

            # layer0
            dec_func_first = getattr(self.dec, 'l0')
            if type(dec_hidden) == type(None):
                dec_outputs, dec_hidden_out = dec_func_first(tgt_chunk_etd, None)
            else:
                index = torch.tensor([0]).to(device=device)  # choose the 0th layer
                dec_hidden_in = tuple(
                    [h.index_select(dim=0, index=index) for h in dec_hidden])
                dec_outputs, dec_hidden_out = dec_func_first(tgt_chunk_etd, dec_hidden_in)
            dec_hidden_lis.append(dec_hidden_out)
            # no residual for 0th layer
            dec_outputs = self.dropout(dec_outputs)

            # layer1+
            for i in range(1, self.num_unilstm_dec):
                dec_inputs = dec_outputs
                dec_func = getattr(self.dec, 'l' + str(i))
                if type(dec_hidden) == type(None):
                    dec_outputs, dec_hidden_out = dec_func(dec_inputs, None)
                else:
                    index = torch.tensor([i]).to(device=device)
                    dec_hidden_in = tuple(
                        [h.index_select(dim=0, index=index) for h in dec_hidden])
                    dec_outputs, dec_hidden_out = dec_func(dec_inputs, dec_hidden_in)
                dec_hidden_lis.append(dec_hidden_out)
                if i < self.num_unilstm_dec - 1:
                    dec_outputs = dec_outputs + dec_inputs
                dec_outputs = self.dropout(dec_outputs)

            # convert to tuple
            h_0 = torch.cat([h[0] for h in dec_hidden_lis], 0)
            c_0 = torch.cat([h[1] for h in dec_hidden_lis], 0)
            dec_hidden = tuple([h_0, c_0])

        # run att
        self.att.set_mask(mask_src)
        att_outputs, attn, c_out = self.att(dec_outputs, att_keys, att_vals, prev_c=prev_c)
        att_outputs = self.dropout(att_outputs)

        # run ff + softmax
        ff_inputs = torch.cat((att_outputs, dec_outputs), dim=-1)
        ff_inputs_size = self.hidden_size_enc * 2 + self.hidden_size_dec
        cell_value = self.ffn(ff_inputs.view(-1, 1, ff_inputs_size))  # 600 -> 200
        outputs = self.out(cell_value.contiguous().view(-1, self.hidden_size_shared))
        predicted_logsoftmax = F.log_softmax(outputs, dim=1).view(batch_size, 1, -1)

        # ptr net
        assert self.ptr_net == 'null'
        # dummy p_gen
        p_gen = torch.FloatTensor([1]).repeat(self.batch_size, 1, 1).to(device=device)

        return predicted_logsoftmax, dec_hidden, attn, c_out, cell_value, p_gen

    def beam_search_decoding(self, att_keys, att_vals, dec_hidden=None,
                             mask_src=None, prev_c=None, beam_width=10):
        """
        beam search decoding - only used for evaluation
        Modified from
        https://github.com/IBM/pytorch-seq2seq/blob/master/seq2seq/models/TopKDecoder.py

        Shortcuts:
            beam_width: k
            batch_size: b
            vocab_size: v
            max_seq_len: l

        Args:
            att_keys: [b x l x hidden_size_enc * 2 + optional key size (key_size)]
            att_vals: [b x l x hidden_size_enc * 2 (val_size)]
            dec_hidden: initial hidden state for dec layer [b x h_dec]
            mask_src: mask of PAD for src sequences
            beam_width: beam width kept during searching

        Returns:
            decoder_outputs: output probabilities [(batch, 1, vocab_size)] * T
            decoder_hidden (num_layers * num_directions, batch, hidden_size):
                tensor containing the last hidden state of the decoder.
            ret_dict: dictionary containing additional information as follows
                {
                    *length*: list of integers representing lengths of output sequences,
                    *topk_length*: list of integers representing lengths of beam search sequences,
                    *sequence*: list of sequences, where each sequence is a list of predicted token IDs,
                    *topk_sequence*: list of beam search sequences, each beam is a list of token IDs,
                    *outputs*: [(batch, k, vocab_size)] * sequence_length:
                        A list of the output probabilities (p_n)
                }.
        """
        # define var
        self.beam_width = beam_width
        self.pos_index = Variable(
            torch.LongTensor(range(self.batch_size)) * self.beam_width)\
            .view(-1, 1).to(device=device)

        # initialize the input vector; att_c_value
        input_var = Variable(
            torch.transpose(
                torch.LongTensor([[BOS] * self.batch_size * self.beam_width]),
                0, 1)).to(device=device)
        input_var_emb = self.embedder_dec(input_var).to(device=device)
        prev_c = torch.FloatTensor([0]).repeat(
            self.batch_size, 1, self.max_seq_len).to(device=device)
        cell_value = torch.FloatTensor([0]).repeat(
            self.batch_size, 1, self.hidden_size_shared).to(device=device)

        # inflate attention keys and values (derived from encoder outputs)
        # correct ordering
        inflated_att_keys = att_keys.repeat_interleave(self.beam_width, dim=0)
        inflated_att_vals = att_vals.repeat_interleave(self.beam_width, dim=0)
        inflated_mask_src = mask_src.repeat_interleave(self.beam_width, dim=0)
        inflated_prev_c = prev_c.repeat_interleave(self.beam_width, dim=0)
        inflated_cell_value = cell_value.repeat_interleave(self.beam_width, dim=0)

        # inflate hidden states and others
        # note that inflat_hidden_state might be faulty - currently using None so it's fine
        dec_hidden = inflat_hidden_state(dec_hidden, self.beam_width)

        # Initialize the scores; for the first step,
        # ignore the inflated copies to avoid duplicate entries in the top k
        sequence_scores = torch.Tensor(self.batch_size * self.beam_width, 1).to(device=device)
        sequence_scores.fill_(-float('Inf'))
        sequence_scores.index_fill_(
            0,
            torch.LongTensor([i * self.beam_width for i in range(0, self.batch_size)])
            .to(device=device),
            0.0)
        sequence_scores = Variable(sequence_scores)

        # Store decisions for backtracking
        stored_outputs = list()          # raw softmax scores [bk x v] * T
        stored_scores = list()           # topk scores [bk] * T
        stored_predecessors = list()     # preceding beam idx (from 0-bk) [bk] * T
        stored_emitted_symbols = list()  # word ids [bk] * T
        stored_hidden = list()

        for _ in range(self.max_seq_len):
            predicted_softmax, dec_hidden, step_attn, inflated_c_out, inflated_cell_value, _ = \
                self.forward_step(inflated_att_keys, inflated_att_vals, input_var_emb,
                                  inflated_cell_value, dec_hidden, inflated_mask_src,
                                  inflated_prev_c)
            inflated_prev_c = inflated_c_out

            # retain output probs
            stored_outputs.append(predicted_softmax)  # [bk x v]

            # To get the full sequence scores for the new candidates,
            # add the local scores for t_i to the predecessor scores for t_(i-1)
            sequence_scores = _inflate(sequence_scores, self.vocab_size, 1)
            sequence_scores += predicted_softmax.squeeze(1)  # [bk x v]
            scores, candidates = sequence_scores.view(self.batch_size, -1)\
                .topk(self.beam_width, dim=1)  # [b x kv] -> [b x k]

            # Reshape input = (bk, 1) and sequence_scores = (bk, 1)
            input_var = (candidates % self.vocab_size)\
                .view(self.batch_size * self.beam_width, 1).to(device=device)
            input_var_emb = self.embedder_dec(input_var)
            sequence_scores = scores.view(self.batch_size * self.beam_width, 1)  # [bk x 1]

            # Update fields for next timestep
            # integer division recovers the predecessor beam index
            predecessors = (candidates // self.vocab_size
                            + self.pos_index.expand_as(candidates))\
                .view(self.batch_size * self.beam_width, 1)

            # dec_hidden: [h_0, c_0]; with h_0 [num_layers * num_directions, batch, hidden_size]
            if isinstance(dec_hidden, tuple):
                dec_hidden = tuple(
                    [h.index_select(1, predecessors.squeeze()) for h in dec_hidden])
            else:
                dec_hidden = dec_hidden.index_select(1, predecessors.squeeze())

            stored_scores.append(sequence_scores.clone())

            # Cache results for backtracking
            stored_predecessors.append(predecessors)
            stored_emitted_symbols.append(input_var)
            stored_hidden.append(dec_hidden)

        # Do backtracking to return the optimal values
        output, h_t, h_n, s, l, p = self._backtrack(
            stored_outputs, stored_hidden, stored_predecessors, stored_emitted_symbols,
            stored_scores, self.batch_size, self.hidden_size_dec)

        # Build return objects
        decoder_outputs = [step[:, 0, :].squeeze(1) for step in output]
        if isinstance(h_n, tuple):
            decoder_hidden = tuple([h[:, :, 0, :] for h in h_n])
        else:
            decoder_hidden = h_n[:, :, 0, :]
        metadata = {}
        metadata['output'] = output
        metadata['h_t'] = h_t
        metadata['score'] = s
        metadata['topk_length'] = l
        metadata['topk_sequence'] = p  # [b x k x 1] * T
        metadata['length'] = [seq_len[0] for seq_len in l]
        metadata['sequence'] = [seq[:, 0] for seq in p]

        return decoder_outputs, decoder_hidden, metadata

    def _backtrack(self, nw_output, nw_hidden, predecessors, symbols, scores, b, hidden_size):
        """
        Backtracks over batch to generate optimal k-sequences.
        https://github.com/IBM/pytorch-seq2seq/blob/master/seq2seq/models/TopKDecoder.py

        Args:
            nw_output [(batch*k, vocab_size)] * sequence_length: A Tensor of outputs from network
            nw_hidden [(num_layers, batch*k, hidden_size)] * sequence_length:
                A Tensor of hidden states from network
            predecessors [(batch*k)] * sequence_length: A Tensor of predecessors
            symbols [(batch*k)] * sequence_length: A Tensor of predicted tokens
            scores [(batch*k)] * sequence_length: A Tensor containing sequence scores
                for every token t = [0, ... , seq_len - 1]
            b: Size of the batch
            hidden_size: Size of the hidden state

        Returns:
            output [(batch, k, vocab_size)] * sequence_length: A list of the output
                probabilities (p_n) from the last layer of the RNN, for every
                n = [0, ... , seq_len - 1]
            h_t [(batch, k, hidden_size)] * sequence_length: A list containing the output
                features (h_n) from the last layer of the RNN, for every
                n = [0, ... , seq_len - 1]
            h_n (batch, k, hidden_size): A Tensor containing the last hidden state
                for all top-k sequences.
            score [batch, k]: A list containing the final scores for all top-k sequences
            length [batch, k]: A list specifying the length of each sequence
                in the top-k candidates
            p (batch, k, sequence_len): A Tensor containing predicted sequence [b x k x 1] * T
        """
        # initialize return variables given different types
        output = list()
        h_t = list()
        p = list()

        # Placeholder for last hidden state of top-k sequences.
        # If a (top-k) sequence ends early in decoding, `h_n` contains
        # its hidden state when it sees EOS. Otherwise, `h_n` contains
        # the last hidden state of decoding.
        lstm = isinstance(nw_hidden[0], tuple)
        if lstm:
            state_size = nw_hidden[0][0].size()
            h_n = tuple([torch.zeros(state_size).to(device=device),
                         torch.zeros(state_size).to(device=device)])
        else:
            h_n = torch.zeros(nw_hidden[0].size()).to(device=device)

        # Placeholder for lengths of top-k sequences
        # Similar to `h_n`
        l = [[self.max_seq_len] * self.beam_width for _ in range(b)]

        # the last step output of the beams are not sorted
        # thus they are sorted here
        sorted_score, sorted_idx = scores[-1].view(b, self.beam_width).topk(self.beam_width)
        sorted_score = sorted_score.to(device=device)
        sorted_idx = sorted_idx.to(device=device)

        # initialize the sequence scores with the sorted last step beam scores
        s = sorted_score.clone().to(device=device)

        # the number of EOS found in the backward loop below for each batch
        batch_eos_found = [0] * b

        t = self.max_seq_len - 1
        # initialize the back pointer with the sorted order of the last step beams.
        # add self.pos_index for indexing variable with b*k as the first dimension.
        t_predecessors = (sorted_idx + self.pos_index.expand_as(sorted_idx))\
            .view(b * self.beam_width).to(device=device)

        while t >= 0:
            # Re-order the variables with the back pointer
            current_output = nw_output[t].index_select(0, t_predecessors)
            if lstm:
                current_hidden = tuple(
                    [h.index_select(1, t_predecessors) for h in nw_hidden[t]])
            else:
                current_hidden = nw_hidden[t].index_select(1, t_predecessors)
            current_symbol = symbols[t].index_select(0, t_predecessors)

            # Re-order the back pointer of the previous step with the back pointer of
            # the current step
            t_predecessors = predecessors[t].index_select(0, t_predecessors)\
                .squeeze().to(device=device)
            """
            This tricky block handles dropped sequences that see EOS earlier.
            The basic idea is summarized below:

            Terms:
                Ended sequences = sequences that see EOS early and are dropped
                Survived sequences = sequences in the last step of the beams

            Although the ended sequences are dropped during decoding, their generated
            symbols and complete backtracking information are still in the backtracking
            variables. For each batch, every time we see an EOS in the backtracking
            process,
                1. If there are survived sequences in the return variables, replace
                   the one with the lowest survived sequence score with the new ended
                   sequence
                2. Otherwise, replace the ended sequence with the lowest sequence
                   score with the new ended sequence
            """
            eos_indices = symbols[t].data.squeeze(1).eq(EOS).nonzero().to(device=device)
            if eos_indices.dim() > 0:
                for i in range(eos_indices.size(0) - 1, -1, -1):
                    # Indices of the EOS symbol for both variables
                    # with b*k as the first dimension, and b, k for
                    # the first two dimensions
                    idx = eos_indices[i]
                    b_idx = int(idx[0] / self.beam_width)
                    # The indices of the replacing position
                    # according to the replacement strategy noted above
                    res_k_idx = self.beam_width - (batch_eos_found[b_idx] % self.beam_width) - 1
                    batch_eos_found[b_idx] += 1
                    res_idx = b_idx * self.beam_width + res_k_idx

                    # Replace the old information in return variables
                    # with the new ended sequence information
                    t_predecessors[res_idx] = predecessors[t][idx[0]].to(device=device)
                    current_output[res_idx, :] = nw_output[t][idx[0], :].to(device=device)
                    if lstm:
                        current_hidden[0][:, res_idx, :] = \
                            nw_hidden[t][0][:, idx[0], :].to(device=device)
                        current_hidden[1][:, res_idx, :] = \
                            nw_hidden[t][1][:, idx[0], :].to(device=device)
                        h_n[0][:, res_idx, :] = \
                            nw_hidden[t][0][:, idx[0], :].data.to(device=device)
                        h_n[1][:, res_idx, :] = \
                            nw_hidden[t][1][:, idx[0], :].data.to(device=device)
                    else:
                        current_hidden[:, res_idx, :] = \
                            nw_hidden[t][:, idx[0], :].to(device=device)
                        h_n[:, res_idx, :] = \
                            nw_hidden[t][:, idx[0], :].data.to(device=device)
                    current_symbol[res_idx, :] = symbols[t][idx[0]].to(device=device)
                    s[b_idx, res_k_idx] = scores[t][idx[0]].data[0].to(device=device)
                    l[b_idx][res_k_idx] = t + 1

            # record the back tracked results
            output.append(current_output)
            h_t.append(current_hidden)
            p.append(current_symbol)

            t -= 1

        # Sort and re-order again as the added ended sequences may change
        # the order (very unlikely)
        s, re_sorted_idx = s.topk(self.beam_width)
        for b_idx in range(b):
            l[b_idx] = [l[b_idx][k_idx.item()] for k_idx in re_sorted_idx[b_idx, :]]
        re_sorted_idx = (re_sorted_idx + self.pos_index.expand_as(re_sorted_idx))\
            .view(b * self.beam_width).to(device=device)

        # Reverse the sequences and re-order at the same time
        # It is reversed because the backtracking happens in reverse time order
        output = [step.index_select(0, re_sorted_idx).view(b, self.beam_width, -1)
                  for step in reversed(output)]
        p = [step.index_select(0, re_sorted_idx).view(b, self.beam_width, -1)
             for step in reversed(p)]
        if lstm:
            h_t = [tuple([h.index_select(1, re_sorted_idx.to(device=device))
                          .view(-1, b, self.beam_width, hidden_size) for h in step])
                   for step in reversed(h_t)]
            h_n = tuple([h.index_select(1, re_sorted_idx.data.to(device=device))
                         .view(-1, b, self.beam_width, hidden_size) for h in h_n])
        else:
            h_t = [step.index_select(1, re_sorted_idx.to(device=device))
                   .view(-1, b, self.beam_width, hidden_size) for step in reversed(h_t)]
            h_n = h_n.index_select(1, re_sorted_idx.data.to(device=device))\
                .view(-1, b, self.beam_width, hidden_size)
        s = s.data

        return output, h_t, h_n, s, l, p

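# ---------------------------------------------------------------------------
# Illustrative sketch (assumption, not part of the class above): the index
# arithmetic behind the beam expansion in beam_search_decoding. Per batch
# element, the k*v candidate scores are flattened, the top-k are kept, and
# each flat candidate index decodes into (predecessor beam, emitted token).
import torch

b, k, v = 2, 3, 5
step_scores = torch.randn(b * k, v)                  # [beams x vocab] scores at one step
top_scores, candidates = step_scores.view(b, k * v).topk(k, dim=1)
tokens = candidates % v                              # emitted word ids        [b, k]
predecessor_beams = candidates // v                  # beam each token extends [b, k]
print(tokens.shape, predecessor_beams.shape)
# ---------------------------------------------------------------------------
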
def __init__(self,
             vocab_size_dec,
             embedding_size_dec=200,
             embedding_dropout=0,
             hidden_size_enc=200,
             hidden_size_dec=200,
             num_unilstm_dec=2,
             att_mode='bahdanau',
             hidden_size_att=10,
             hidden_size_shared=200,
             dropout=0.0,
             residual=False,
             batch_first=True,
             max_seq_len=32,
             load_embedding_tgt=None,
             tgt_word2id=None,
             tgt_id2word=None):
    super(DecRNN, self).__init__()

    # define embeddings
    self.vocab_size_dec = vocab_size_dec
    self.embedding_size_dec = embedding_size_dec
    self.load_embedding = load_embedding_tgt
    self.word2id = tgt_word2id
    self.id2word = tgt_id2word

    # define model params
    self.hidden_size_enc = hidden_size_enc
    self.hidden_size_dec = hidden_size_dec
    self.num_unilstm_dec = num_unilstm_dec
    self.hidden_size_att = hidden_size_att
    self.hidden_size_shared = hidden_size_shared  # [200]
    self.max_seq_len = max_seq_len
    self.residual = residual

    # define operations
    self.embedding_dropout = nn.Dropout(embedding_dropout)
    self.dropout = nn.Dropout(dropout)

    # load embeddings
    if self.load_embedding:
        embedding_matrix = np.random.rand(self.vocab_size_dec, self.embedding_size_dec)
        embedding_matrix = torch.FloatTensor(load_pretrained_embedding(
            self.word2id, embedding_matrix, self.load_embedding))
        self.embedder_dec = nn.Embedding.from_pretrained(
            embedding_matrix, freeze=False, sparse=False, padding_idx=PAD)
    else:
        self.embedder_dec = nn.Embedding(self.vocab_size_dec, self.embedding_size_dec,
                                         sparse=False, padding_idx=PAD)

    # define dec
    # embedding_size_dec + self.hidden_size_shared [200+200] -> hidden_size_dec [200]
    if not self.residual:
        self.dec = torch.nn.LSTM(self.embedding_size_dec + self.hidden_size_shared,
                                 self.hidden_size_dec,
                                 num_layers=self.num_unilstm_dec,
                                 batch_first=batch_first, bias=True,
                                 dropout=dropout, bidirectional=False)
    else:
        lstm_uni_dec_first = torch.nn.LSTM(self.embedding_size_dec + self.hidden_size_shared,
                                           self.hidden_size_dec, num_layers=1,
                                           batch_first=batch_first, bias=True,
                                           dropout=dropout, bidirectional=False)
        self.dec = nn.Module()
        self.dec.add_module('l0', lstm_uni_dec_first)
        for i in range(1, self.num_unilstm_dec):
            self.dec.add_module(
                'l' + str(i),
                torch.nn.LSTM(self.hidden_size_dec, self.hidden_size_dec,
                              num_layers=1, batch_first=batch_first, bias=True,
                              dropout=dropout, bidirectional=False))

    # define att
    # query:   hidden_size_dec          [200]
    # keys:    hidden_size_enc * 2      [400]
    # values:  hidden_size_enc * 2      [400]
    # context: weighted sum of values   [400]
    self.key_size = self.hidden_size_enc * 2
    self.value_size = self.hidden_size_enc * 2
    self.query_size = self.hidden_size_dec
    self.att = AttentionLayer(self.query_size, self.key_size,
                              value_size=self.value_size, mode=att_mode,
                              dropout=dropout, query_transform=False,
                              output_transform=False,
                              hidden_size=self.hidden_size_att, hard_att=False)

    # define output
    # (hidden_size_enc * 2 + hidden_size_dec) -> self.hidden_size_shared -> vocab_size_dec
    self.ffn = nn.Linear(self.hidden_size_enc * 2 + self.hidden_size_dec,
                         self.hidden_size_shared, bias=False)
    self.out = nn.Linear(self.hidden_size_shared, self.vocab_size_dec, bias=True)

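# ---------------------------------------------------------------------------
# Illustrative sketch (assumption): the embedding-loading pattern shared by
# these constructors. A random matrix is created first, pretrained vectors
# would overwrite the rows of known words (via the repo's
# load_pretrained_embedding, elided here), and the result is wrapped in a
# trainable nn.Embedding. PAD_SKETCH is a hypothetical pad index.
import numpy as np
import torch
import torch.nn as nn

PAD_SKETCH = 0
vocab_size, embedding_size = 10, 4
embedding_matrix = np.random.rand(vocab_size, embedding_size)
# ... rows of known words would be overwritten from the pretrained file here
embedder = nn.Embedding.from_pretrained(torch.FloatTensor(embedding_matrix),
                                        freeze=False, sparse=False,
                                        padding_idx=PAD_SKETCH)
print(embedder(torch.LongTensor([[1, 2, PAD_SKETCH]])).shape)  # torch.Size([1, 3, 4])
# ---------------------------------------------------------------------------
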
def __init__(
        self,
        # params
        vocab_size,
        embedding_size=200,
        acous_hidden_size=256,
        acous_att_mode='bahdanau',
        hidden_size_dec=200,
        hidden_size_shared=200,
        num_unilstm_dec=4,
        use_type='char',
        #
        embedding_dropout=0,
        dropout=0.0,
        residual=True,
        batch_first=True,
        max_seq_len=32,
        load_embedding=None,
        word2id=None,
        id2word=None,
        hard_att=False,
        use_gpu=False):
    super(Dec, self).__init__()

    device = check_device(use_gpu)

    # define model
    self.acous_hidden_size = acous_hidden_size
    self.acous_att_mode = acous_att_mode
    self.hidden_size_dec = hidden_size_dec
    self.hidden_size_shared = hidden_size_shared
    self.num_unilstm_dec = num_unilstm_dec

    # define var
    self.hard_att = hard_att
    self.residual = residual
    self.max_seq_len = max_seq_len
    self.use_type = use_type

    # use shared embedding + vocab
    self.vocab_size = vocab_size
    self.embedding_size = embedding_size
    self.load_embedding = load_embedding
    self.word2id = word2id
    self.id2word = id2word

    # define operations
    self.embedding_dropout = nn.Dropout(embedding_dropout)
    self.dropout = nn.Dropout(dropout)

    # ------- load embeddings --------
    if self.load_embedding:
        embedding_matrix = np.random.rand(self.vocab_size, self.embedding_size)
        embedding_matrix = load_pretrained_embedding(
            self.word2id, embedding_matrix, self.load_embedding)
        embedding_matrix = torch.FloatTensor(embedding_matrix)
        self.embedder = nn.Embedding.from_pretrained(
            embedding_matrix, freeze=False, sparse=False, padding_idx=PAD)
    else:
        self.embedder = nn.Embedding(self.vocab_size, self.embedding_size,
                                     sparse=False, padding_idx=PAD)

    # ------ define acous att --------
    dropout_acous_att = dropout
    self.acous_hidden_size_att = 0  # ignored with bilinear
    self.acous_key_size = self.acous_hidden_size * 2    # acous feats
    self.acous_value_size = self.acous_hidden_size * 2  # acous feats
    self.acous_query_size = self.hidden_size_dec        # use dec(words) as query
    self.acous_att = AttentionLayer(self.acous_query_size, self.acous_key_size,
                                    value_size=self.acous_value_size,
                                    mode=self.acous_att_mode,
                                    dropout=dropout_acous_att,
                                    query_transform=False, output_transform=False,
                                    hidden_size=self.acous_hidden_size_att,
                                    use_gpu=use_gpu, hard_att=False)

    # ------ define acous out --------
    self.acous_ffn = nn.Linear(self.acous_hidden_size * 2 + self.hidden_size_dec,
                               self.hidden_size_shared, bias=False)
    self.acous_out = nn.Linear(self.hidden_size_shared, self.vocab_size, bias=True)

    # ------ define acous dec -------
    # embedding_size_dec + self.hidden_size_shared [200+200] -> hidden_size_dec [200]
    if not self.residual:
        self.dec = torch.nn.LSTM(self.embedding_size + self.hidden_size_shared,
                                 self.hidden_size_dec,
                                 num_layers=self.num_unilstm_dec,
                                 batch_first=batch_first, bias=True,
                                 dropout=dropout, bidirectional=False)
    else:
        self.dec = nn.Module()
        self.dec.add_module(
            'l0',
            torch.nn.LSTM(self.embedding_size + self.hidden_size_shared,
                          self.hidden_size_dec, num_layers=1,
                          batch_first=batch_first, bias=True,
                          dropout=dropout, bidirectional=False))
        for i in range(1, self.num_unilstm_dec):
            self.dec.add_module(
                'l' + str(i),
                torch.nn.LSTM(self.hidden_size_dec, self.hidden_size_dec,
                              num_layers=1, batch_first=batch_first, bias=True,
                              dropout=dropout, bidirectional=False))

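# ---------------------------------------------------------------------------
# Illustrative sketch (assumption, not the repo's AttentionLayer): the
# Bahdanau-style additive attention that mode='bahdanau' usually refers to.
# A decoder query is scored against every encoder key through a small
# feed-forward scorer, the scores are softmax-normalised, and the context is
# a weighted sum of the values. All names and sizes here are hypothetical.
import torch
import torch.nn as nn
import torch.nn.functional as F

class AdditiveAttentionSketch(nn.Module):
    def __init__(self, query_size, key_size, hidden_size):
        super().__init__()
        self.proj_q = nn.Linear(query_size, hidden_size, bias=False)
        self.proj_k = nn.Linear(key_size, hidden_size, bias=False)
        self.score = nn.Linear(hidden_size, 1, bias=False)

    def forward(self, query, keys, values, mask=None):
        # query: [b, 1, q], keys: [b, T, k], values: [b, T, v]
        energy = self.score(torch.tanh(self.proj_q(query) + self.proj_k(keys))).squeeze(-1)
        if mask is not None:
            energy = energy.masked_fill(mask, float('-inf'))  # mask out PAD positions
        attn = F.softmax(energy, dim=-1)                      # [b, T]
        context = torch.bmm(attn.unsqueeze(1), values)        # [b, 1, v]
        return context, attn

att = AdditiveAttentionSketch(query_size=200, key_size=512, hidden_size=128)
context, attn = att(torch.randn(2, 1, 200), torch.randn(2, 50, 512), torch.randn(2, 50, 512))
print(context.shape, attn.shape)  # torch.Size([2, 1, 512]) torch.Size([2, 50])
# ---------------------------------------------------------------------------
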
def __init__(
        self,
        # params
        vocab_size,
        embedding_size=200,
        acous_hidden_size=256,
        acous_att_mode='bahdanau',
        hidden_size_dec=200,
        hidden_size_shared=200,
        num_unilstm_dec=4,
        #
        embedding_dropout=0,
        dropout=0.0,
        residual=True,
        batch_first=True,
        max_seq_len=32,
        embedder=None,
        word2id=None,
        id2word=None,
        hard_att=False,
):
    super(Dec, self).__init__()

    # define model
    self.acous_hidden_size = acous_hidden_size
    self.acous_att_mode = acous_att_mode
    self.hidden_size_dec = hidden_size_dec
    self.hidden_size_shared = hidden_size_shared
    self.num_unilstm_dec = num_unilstm_dec

    # define var
    self.hard_att = hard_att
    self.residual = residual
    self.max_seq_len = max_seq_len

    # use shared embedding + vocab
    self.vocab_size = vocab_size
    self.embedding_size = embedding_size
    self.word2id = word2id
    self.id2word = id2word

    # define operations
    self.embedding_dropout = nn.Dropout(embedding_dropout)
    self.dropout = nn.Dropout(dropout)

    if type(embedder) != type(None):
        self.embedder = embedder
    else:
        self.embedder = nn.Embedding(self.vocab_size, self.embedding_size,
                                     sparse=False, padding_idx=PAD)

    # ------ define acous att --------
    dropout_acous_att = dropout
    self.acous_hidden_size_att = 0  # ignored with bilinear
    self.acous_key_size = self.acous_hidden_size * 2    # acous feats
    self.acous_value_size = self.acous_hidden_size * 2  # acous feats
    self.acous_query_size = self.hidden_size_dec        # use dec(words) as query
    self.acous_att = AttentionLayer(self.acous_query_size, self.acous_key_size,
                                    value_size=self.acous_value_size,
                                    mode=self.acous_att_mode,
                                    dropout=dropout_acous_att,
                                    query_transform=False, output_transform=False,
                                    hidden_size=self.acous_hidden_size_att,
                                    hard_att=False)

    # ------ define acous out --------
    self.acous_ffn = nn.Linear(self.acous_hidden_size * 2 + self.hidden_size_dec,
                               self.hidden_size_shared, bias=False)
    self.acous_out = nn.Linear(self.hidden_size_shared, self.vocab_size, bias=True)

    # ------ define acous dec -------
    # embedding_size_dec + self.hidden_size_shared [200+200] -> hidden_size_dec [200]
    if not self.residual:
        self.dec = torch.nn.LSTM(self.embedding_size + self.hidden_size_shared,
                                 self.hidden_size_dec,
                                 num_layers=self.num_unilstm_dec,
                                 batch_first=batch_first, bias=True,
                                 dropout=dropout, bidirectional=False)
    else:
        self.dec = nn.Module()
        self.dec.add_module(
            'l0',
            torch.nn.LSTM(self.embedding_size + self.hidden_size_shared,
                          self.hidden_size_dec, num_layers=1,
                          batch_first=batch_first, bias=True,
                          dropout=dropout, bidirectional=False))
        for i in range(1, self.num_unilstm_dec):
            self.dec.add_module(
                'l' + str(i),
                torch.nn.LSTM(self.hidden_size_dec, self.hidden_size_dec,
                              num_layers=1, batch_first=batch_first, bias=True,
                              dropout=dropout, bidirectional=False))

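# ---------------------------------------------------------------------------
# Illustrative sketch (assumption, not the repo's forward pass): how the
# per-layer LSTMs registered above as 'l0', 'l1', ... are typically unrolled,
# with a residual connection between consecutive layers (skipped after the
# last layer), mirroring the construction in the residual branch.
import torch
import torch.nn as nn

num_layers, hidden = 3, 8
stack = nn.Module()
for i in range(num_layers):
    stack.add_module('l' + str(i), nn.LSTM(hidden, hidden, num_layers=1, batch_first=True))

x = torch.randn(2, 5, hidden)
out = x
for i in range(num_layers):
    layer = getattr(stack, 'l' + str(i))
    new_out, _ = layer(out)
    out = new_out + out if i < num_layers - 1 else new_out  # residual except last layer
print(out.shape)  # torch.Size([2, 5, 8])
# ---------------------------------------------------------------------------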