def forward(self, input, hidden, return_h=False):
    batch_size = input.size(1)

    # Embedding lookup with word-level embedding dropout (disabled at eval time),
    # followed by locked (variational) dropout on the embedded input.
    emb = embedded_dropout(self.encoder, input,
                           dropout=self.dropoute if self.training else 0)
    emb = self.lockdrop(emb, self.dropouti)

    raw_output = emb
    new_hidden = []
    raw_outputs = []
    # Run the stacked RNN layers, collecting each layer's raw (pre-dropout) output.
    for l, rnn in enumerate(self.rnns):
        raw_output, new_h = rnn(raw_output, hidden[l])
        new_hidden.append(new_h)
        raw_outputs.append(raw_output)
    hidden = new_hidden

    # Locked dropout on the top layer's output before decoding.
    output = self.lockdrop(raw_output, self.dropout)
    outputs = [output]

    # Project to vocabulary logits and normalize to log-probabilities.
    # (The decoder takes ninp-sized inputs, i.e. the top layer's output size
    # matches the embedding size, as with tied embedding/decoder weights.)
    logit = self.decoder(output.view(-1, self.ninp))
    log_prob = nn.functional.log_softmax(logit, dim=-1)
    model_output = log_prob.view(-1, batch_size, self.ntoken)

    if return_h:
        return model_output, hidden, raw_outputs, outputs
    return model_output, hidden
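# A minimal usage sketch for the forward above. `model` is assumed to be an
# instance of the module defining it, and `init_hidden` its usual companion
# that zero-initializes one state per layer (both names are assumptions, not
# taken from this listing). Since forward already returns log-probabilities,
# the loss is plain NLL rather than cross-entropy on logits.
import torch
import torch.nn as nn

seq_len, batch_size = 35, 20
x = torch.randint(0, model.ntoken, (seq_len, batch_size))        # token ids, (T, B)
targets = torch.randint(0, model.ntoken, (seq_len * batch_size,))

hidden = model.init_hidden(batch_size)                           # one state per layer
log_prob, hidden = model(x, hidden)                              # (T, B, ntoken)
loss = nn.NLLLoss()(log_prob.view(-1, model.ntoken), targets)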
def forward(self, input, hidden, genotype, return_h=False):
    # Note: `hidden` here is a single state for the one recurrent cell,
    # not a per-layer list.
    batch_size = input.size(1)

    # Embedding lookup with embedding dropout (training only), then locked dropout.
    emb = embedded_dropout(self.encoder, input,
                           dropout=self.dropout_e if self.training else 0)
    emb = self.lockdrop(emb, self.dropout_i)

    raw_output = emb
    # The searched cell (defined by `genotype`) is unrolled over time:
    # raw_output holds the hidden state at every time step t,
    # new_h is the last hidden state.
    raw_output, new_h = self.rnn(raw_output, hidden, genotype)
    hidden = new_h

    output = self.lockdrop(raw_output, self.dropout)

    # The decoder is the hidden->output linear projection; flatten time and
    # batch with view so the whole batch's logits are computed in one call.
    logit = self.decoder(output.view(-1, self.n_inp))
    log_prob = nn.functional.log_softmax(logit, dim=-1)
    model_output = log_prob.view(-1, batch_size, self.n_token)

    if return_h:
        return model_output, hidden, raw_output, output
    return model_output, hidden
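# Why return_h matters: exposing the pre-dropout (raw_output) and post-dropout
# (output) hidden sequences supports AWD-LSTM-style activation regularization
# during training. A sketch, assuming `x`, `hidden`, `genotype`, and `targets`
# are set up as in the previous sketch; `alpha` and `beta` are assumed
# hyperparameters, not names from this module.
log_prob, hidden, raw_output, output = model(x, hidden, genotype, return_h=True)
raw_loss = nn.NLLLoss()(log_prob.view(-1, model.n_token), targets)

ar = alpha * output.pow(2).mean()                               # AR: penalize large activations
tar = beta * (raw_output[1:] - raw_output[:-1]).pow(2).mean()   # TAR: penalize fast temporal change
loss = raw_loss + ar + tar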
def forward(self, input, hidden_input_all, BN_start=0):
    seq_len, batch_size = input.size()
    rnnoutputs = {}
    hidden_x = {}

    # Reuse the caller's hidden states if given; otherwise start from zeros.
    if hidden_input_all is not None:
        hidden_x.update(hidden_input_all)
    else:
        for x in range(len(self.RNNs)):
            hidden_x['hidden%d' % x] = torch.zeros(
                1, batch_size, args.hidden_size).cuda()
        hidden_x['hidden_lastindrnn'] = torch.zeros(
            1, batch_size, args.embed_size).cuda()

    # Embedding lookup with embedding dropout (training only), then
    # per-step word dropout over time.
    input = input.view(seq_len * batch_size)
    input = embedded_dropout(
        self.encoder, input,
        dropout=args.dropout_embedding if self.training else 0)
    input = input.view(seq_len, batch_size, args.embed_size)
    if args.dropout_words > 0:
        input = dropout_overtime(input, args.dropout_words, self.training)

    # Stacked IndRNN layers: each layer is a linear (DI) projection followed
    # by an IndRNN, with dropout over time (a separate rate on the last layer).
    # The 'outlayer-1' key makes the x=0 lookup below resolve to the input.
    rnnoutputs['outlayer-1'] = input
    for x in range(len(self.RNNs)):
        rnnoutputs['dilayer%d' % x] = self.DIs[x](
            rnnoutputs['outlayer%d' % (x - 1)])
        rnnoutputs['outlayer%d' % x], hidden_x['hidden%d' % x] = self.RNNs[x](
            rnnoutputs['dilayer%d' % x], BN_start, hidden_x['hidden%d' % x])
        if args.dropout > 0:
            dropout = args.dropout_last if x == len(self.RNNs) - 1 else args.dropout
            rnnoutputs['outlayer%d' % x] = dropout_overtime(
                rnnoutputs['outlayer%d' % x], dropout, self.training)

    # Project back to the embedding size, run one more IndRNN, optionally
    # batch-normalize, apply dropout, then decode to vocabulary logits.
    rnn_out = rnnoutputs['outlayer%d' % (len(self.RNNs) - 1)]
    rnn_out = self.last_fc(rnn_out)
    rnn_out, hidden_x['hidden_lastindrnn'] = self.last_indrnn(
        rnn_out, BN_start, hidden_x['hidden_lastindrnn'])
    if args.bn_location == 'bn_before':
        rnn_out = self.extra_bn(rnn_out, BN_start)
    rnn_out = dropout_overtime(rnn_out, args.dropout_extrafc, self.training)

    rnn_out = rnn_out.view(seq_len * batch_size, -1)
    output = self.decoder(rnn_out)
    return output, hidden_x
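# Between truncated-BPTT segments, the returned hidden_x dict must be detached
# from the previous graph before being fed back in, or the graph grows without
# bound. A minimal training-loop sketch (loop names are assumptions); passing
# None on the first segment triggers the zero initialization inside forward:
hidden_all = None
for x_seg, y_seg in batches:          # assumed iterable of (input, target) segments
    output, hidden_all = model(x_seg, hidden_all, BN_start)
    hidden_all = {k: v.detach() for k, v in hidden_all.items()}
    loss = criterion(output, y_seg.view(-1))   # output is (T*B, ntoken) logits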
def forward(self, input, hidden_input_all, BN_start=0):
    seq_len, batch_size = input.size()
    rnnoutputs = {}
    hidden_x = {}
    hid_ind_layer = 0  # running index into the per-layer hidden states

    # Reuse the caller's hidden states if given; the blocks create any
    # missing entries themselves.
    if hidden_input_all is not None:
        hidden_x.update(hidden_input_all)

    # Embedding lookup with embedding dropout (training only), then
    # per-step word dropout over time.
    input = input.view(seq_len * batch_size)
    input = embedded_dropout(
        self.encoder, input,
        dropout=args.dropout_embedding if self.training else 0)
    input = input.view(seq_len, batch_size, args.embed_size)
    if args.dropout_words > 0:
        input = dropout_overtime(input, args.dropout_words, self.training)

    # Seed the block chain with the embedded input; without this line the
    # first dense block's lookup of 'trans_outlayer-1' raises a KeyError.
    rnnoutputs['trans_outlayer-1'] = input

    # Alternate dense and transition blocks; the last transition block is
    # applied after the loop. Each block also returns the updated hidden
    # dict and the advanced layer index.
    for i in range(self.num_blocks):
        rnnoutputs['dense_outlayer%d' % i], hidden_x, hid_ind_layer = \
            self.DenseBlocks[i](rnnoutputs['trans_outlayer%d' % (i - 1)],
                                BN_start, hidden_x, hid_ind_layer)
        if i != self.num_blocks - 1:
            rnnoutputs['trans_outlayer%d' % i], hidden_x, hid_ind_layer = \
                self.TransBlocks[i](rnnoutputs['dense_outlayer%d' % i],
                                    BN_start, hidden_x, hid_ind_layer)
    rnnoutputs['trans_outlayer%d' % (self.num_blocks - 1)], hidden_x, hid_ind_layer = \
        self.TransBlocks[self.num_blocks - 1](
            rnnoutputs['dense_outlayer%d' % (self.num_blocks - 1)],
            BN_start, hidden_x, hid_ind_layer)

    # Decode the final block's output to vocabulary logits.
    rnn_out = rnnoutputs['trans_outlayer%d' % (self.num_blocks - 1)]
    rnn_out = rnn_out.view(seq_len * batch_size, args.embed_size)
    rnn_out = self.decoder(rnn_out)
    return rnn_out, hidden_x
def forward(self, input, hidden_input_all, BN_start=0):
    seq_len, batch_size = input.size()
    rnnoutputs = {}
    hidden_x = {}

    # Reuse the caller's hidden states if given; otherwise start from zeros.
    # Each residual block carries two hidden states, plus one for the final
    # IndRNN after the blocks and one for the last IndRNN after last_fc.
    if hidden_input_all is not None:
        hidden_x.update(hidden_input_all)
    else:
        for x in range(2 * len(self.resblocks)):
            hidden_x['hidden%d' % x] = torch.zeros(
                1, batch_size, args.hidden_size).cuda()
        hidden_x['hidden_resfinal'] = torch.zeros(
            1, batch_size, args.hidden_size).cuda()
        hidden_x['hidden_lastindrnn'] = torch.zeros(
            1, batch_size, args.embed_size).cuda()

    # Embedding lookup with embedding dropout (training only), then
    # per-step word dropout over time.
    input = input.view(seq_len * batch_size)
    input = embedded_dropout(
        self.encoder, input,
        dropout=args.dropout_embedding if self.training else 0)
    input = input.view(seq_len, batch_size, args.embed_size)
    if args.dropout_words > 0:
        input = dropout_overtime(input, args.dropout_words, self.training)

    # Project the embedding to the hidden size, then run the residual
    # IndRNN blocks; the 'outlayer-1' key makes the x=0 lookup resolve
    # to this projection.
    rnnoutputs['outlayer-1'] = self.fc0(input)
    for x in range(len(self.resblocks)):
        (rnnoutputs['outlayer%d' % x],
         hidden_x['hidden%d' % (2 * x)],
         hidden_x['hidden%d' % (2 * x + 1)]) = self.resblocks[x](
            rnnoutputs['outlayer%d' % (x - 1)], BN_start,
            hidden_x['hidden%d' % (2 * x)], hidden_x['hidden%d' % (2 * x + 1)])

    # Final IndRNN after the residual stack, with dropout over time.
    rnn_out = rnnoutputs['outlayer%d' % (len(self.resblocks) - 1)]
    rnn_out, hidden_x['hidden_resfinal'] = self.IndRNNwithBN_resfinal(
        rnn_out, BN_start, hidden_x['hidden_resfinal'])
    if args.dropout_last > 0:
        rnn_out = dropout_overtime(rnn_out, args.dropout_last, self.training)

    # Project back to the embedding size, run one more IndRNN, optionally
    # batch-normalize, apply dropout, then decode to vocabulary logits.
    rnn_out = self.last_fc(rnn_out)
    rnn_out1, hidden_x['hidden_lastindrnn'] = self.last_indrnn(
        rnn_out, BN_start, hidden_x['hidden_lastindrnn'])
    if args.bn_location == 'bn_before':
        rnn_out1 = self.extra_bn(rnn_out1, BN_start)
    if args.dropout_extrafc > 0:
        rnn_out1 = dropout_overtime(rnn_out1, args.dropout_extrafc, self.training)

    rnn_out1 = rnn_out1.view(seq_len * batch_size, args.embed_size)
    rnn_out1 = self.decoder(rnn_out1)
    return rnn_out1, hidden_x
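# The residual model expects a specific set of keys in hidden_input_all:
# two states per block plus 'hidden_resfinal' and 'hidden_lastindrnn'. A
# sketch of an external initializer mirroring the in-forward zero init, so
# callers can build states up front (the helper name and explicit `device`
# argument are assumptions; `args` is the same global the model uses):
def init_hidden_resnet(model, batch_size, device='cuda'):
    hidden_x = {}
    for x in range(2 * len(model.resblocks)):
        hidden_x['hidden%d' % x] = torch.zeros(
            1, batch_size, args.hidden_size, device=device)
    hidden_x['hidden_resfinal'] = torch.zeros(
        1, batch_size, args.hidden_size, device=device)
    hidden_x['hidden_lastindrnn'] = torch.zeros(
        1, batch_size, args.embed_size, device=device)
    return hidden_x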