def _compute_attention_sum(q, m, length):
    # q : batch_size x lstm_size
    # m : batch_size x max(length) x embedded_dim
    assert torch.max(length) == m.size()[1]
    max_len = m.size()[1]
    if simple:
        if q.size()[-1] != m.size()[-1]:
            q = self.attention(q)  # batch_size x embedded_dim
        weight_logit = torch.bmm(m, q.unsqueeze(-1)).squeeze(2)  # batch_size x n_features
    else:
        linear_m = self.attention[1]
        linear_q = self.attention[0]
        linear_out = self.attention[2]
        packed = pack(m, list(length), batch_first=True)
        proj_m = PackedSequence(linear_m(packed.data), packed.batch_sizes)
        proj_m, _ = pad(proj_m, batch_first=True)  # batch_size x n_features x proj_dim
        proj_q = linear_q(q).unsqueeze(1)  # batch_size x 1 x proj_dim
        packed = pack(F.relu(proj_m + proj_q), list(length), batch_first=True)
        weight_logit = PackedSequence(linear_out(packed.data), packed.batch_sizes)
        weight_logit, _ = pad(weight_logit, batch_first=True)  # batch_size x n_features x 1
        weight_logit = weight_logit.squeeze(2)
    # max_len = weight_logit.size()[1]
    indices = torch.arange(0, max_len, out=torch.LongTensor(max_len).unsqueeze(0)).cuda()  # TODO here.. cuda..
    mask = indices < length.unsqueeze(1)  # .long()
    weight_logit[~mask] = -np.inf  # was `1 - mask`, which breaks for bool masks
    weight = F.softmax(weight_logit, dim=1)  # nonzero x max_len
    weighted = torch.bmm(weight.unsqueeze(1), m)
    # batch_size x 1 x max_len  @  batch_size x max_len x embedded_dim
    # = batch_size x 1 x embedded_dim
    return weighted.squeeze(1), weight  # nonzero x embedded_dim
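# --- Hedged note (not from the original source) ---
# The snippets in this file come from different modules, so `pack`/`pad` are
# assumed aliases rather than a single shared import: in the forward passes
# they behave like pack_padded_sequence / pad_packed_sequence, while in the
# collate functions `pad` is called on a list with a padding_value and so
# behaves like pad_sequence. A plausible import block for the forward-pass
# snippets is sketched below; treat it as an assumption, not the authors'
# actual imports.
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import PackedSequence
from torch.nn.utils.rnn import pack_padded_sequence as pack
from torch.nn.utils.rnn import pad_packed_sequence as pad
from torch.nn.utils.rnn import pad_sequence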
def collate(self, batch):
    ei = [torch.tensor(sent) for sent in batch]
    eo = [torch.tensor([x if x is not None else -100 for x in sent.target]) for sent in batch]
    el = torch.tensor([len(sent) for sent in batch])
    ei = pad(ei, padding_value=self.pad)
    eo = pad(eo, padding_value=-100)
    return Batch(ei, eo, el)
def __call__(self, indices):
    return {
        'src': pad([torch.tensor([0] + self.data[index]) for index in indices]),
        'trg': pad([torch.tensor(self.data[index] + [0]) for index in indices], padding_value=-100),
        'len': self.lengths[indices],
    }
def collate(self, items):
    old = [torch.tensor(i["old"]) for i in items]
    new = [torch.tensor(i["new"]) for i in items]
    len_old = [i["len_old"] for i in items]
    len_new = [i["len_new"] for i in items]
    return {
        "old": pad(old, batch_first=True),
        "new": pad(new, batch_first=True),
        "len_old": torch.tensor(len_old),
        "len_new": torch.tensor(len_new)
    }
def collate(self, batch):
    ei = pad([torch.tensor(sent) for sent in batch], padding_value=self.pad)
    eo = pad([torch.tensor(sent) for sent in batch], padding_value=self.pad)
    el = [len(sent) for sent in batch]
    rand_tensor = torch.rand(ei.shape)
    rand_token = torch.randint(2, len(self.vocab), ei.shape)
    normal_token = ei > 2
    position_to_mask = (rand_tensor < self.mask_th) & normal_token
    position_to_replace = (rand_tensor < self.replace_th) & normal_token
    ei.masked_fill_(position_to_mask, self.msk)
    ei.masked_scatter_(position_to_replace, rand_token)
    eo.masked_fill_(~position_to_mask, self.pad)
    return Batch(ei, eo, el)
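# --- Hedged usage sketch (not from the original source) ---
# Minimal, self-contained illustration of how a masked-LM collate like the one
# above is typically wired to a DataLoader. `MaskedLMCollater`, its thresholds
# and token ids, and the toy sentences are hypothetical stand-ins; `pad` is
# assumed to behave like torch.nn.utils.rnn.pad_sequence here.
import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader


class MaskedLMCollater:
    def __init__(self, vocab_size, pad_id=0, msk_id=1, mask_th=0.15, replace_th=0.015):
        self.vocab_size = vocab_size
        self.pad = pad_id
        self.msk = msk_id
        self.mask_th = mask_th
        self.replace_th = replace_th

    def __call__(self, batch):
        ei = pad_sequence([torch.tensor(s) for s in batch], batch_first=True,
                          padding_value=self.pad)
        eo = ei.clone()
        rand = torch.rand(ei.shape)
        rand_token = torch.randint(3, self.vocab_size, ei.shape)
        normal = ei > 2                            # ids 0..2 are special tokens
        to_mask = (rand < self.mask_th) & normal
        to_replace = (rand < self.replace_th) & normal
        ei = ei.masked_fill(to_mask, self.msk)     # inputs: mostly [MASK]
        ei = ei.masked_scatter(to_replace, rand_token)  # a few random ids
        eo = eo.masked_fill(~to_mask, self.pad)    # targets: only masked positions
        return ei, eo


sentences = [[5, 6, 7, 8], [9, 10, 11], [12, 13]]
loader = DataLoader(sentences, batch_size=2, collate_fn=MaskedLMCollater(vocab_size=50))
for inputs, targets in loader:
    print(inputs.shape, targets.shape)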
def forward(self, input, encoder_outputs, hidden=None):
    batch_size = input.size()[0]
    embedded = self.embedding(input)
    embedded = self.embedding_dropout(embedded)
    dec_outs, hidden = self.rnn(
        pack(embedded, [1] * batch_size, batch_first=True), hidden)
    # batch_size * 1 * hidden_size
    dec_outs, lengths = pad(dec_outs, batch_first=True)
    # sum the bidirectional outputs
    if self.bidirectional:
        dec_outs = dec_outs[:, :, :self.hidden_size] + \
            dec_outs[:, :, self.hidden_size:]
    # calculate the attention context vector
    attn_weights = self.attn(dec_outs, encoder_outputs)
    context = attn_weights.bmm(encoder_outputs)
    # Finally concatenate and pass through a linear layer
    concated_input = T.cat((dec_outs, context), 2)
    concated_out = self.concat(concated_input.squeeze(1)).unsqueeze(1)
    concat_output = None
    try:
        concat_output = self.output(T.tanh(concated_out))  # F.tanh is deprecated
    except Exception as e:
        print(concated_out.size(), concated_out)
    return (concat_output, hidden, attn_weights)
def forward(self, inputs):
    if isinstance(inputs, PackedSequence):
        # unpack output
        inputs, lengths = pad(inputs, batch_first=self.batch_first)
    if self.batch_first:
        batch_size, max_len = inputs.size()[:2]
    else:
        max_len, batch_size = inputs.size()[:2]
        inputs = inputs.permute(1, 0, 2)
    # att = torch.mul(inputs, self.att_weights.expand_as(inputs))
    # att = att.sum(-1)
    weights = torch.bmm(
        inputs,
        self.att_weights              # (1, hidden_size)
        .permute(1, 0)                # (hidden_size, 1)
        .unsqueeze(0)                 # (1, hidden_size, 1)
        .repeat(batch_size, 1, 1))    # (batch_size, hidden_size, 1)
    attentions = F.softmax(F.relu(weights.squeeze()), dim=-1)
    # apply weights
    weighted = torch.mul(inputs, attentions.unsqueeze(-1).expand_as(inputs))
    # get the final fixed vector representations of the sentences
    representations = weighted.sum(1).squeeze()
    return representations, attentions
def forward(self, subword_idxs, subword_masks, token_start, batch_label):
    # self.args.use_crf = True
    bert_outs, _ = self.bert(
        subword_idxs,
        token_type_ids=None,
        attention_mask=subword_masks,
        output_all_encoded_layers=False,
    )
    lens = token_start.sum(dim=1)
    bert_outs = torch.split(bert_outs[token_start], lens.tolist())
    bert_outs = pad_sequence(bert_outs, batch_first=True)
    max_len = bert_outs.size(1)
    mask = torch.arange(max_len).cuda() < lens.unsqueeze(-1)
    # add lstm after bert
    sorted_lens, sorted_idx = torch.sort(lens, dim=0, descending=True)
    reverse_idx = torch.sort(sorted_idx, dim=0)[1]
    bert_outs = bert_outs[sorted_idx]
    bert_outs = pack(bert_outs, sorted_lens, batch_first=True)
    bert_outs, hidden = self.lstm(bert_outs)
    bert_outs, _ = pad(bert_outs, batch_first=True)
    bert_outs = bert_outs[reverse_idx]
    out = self.linear(torch.tanh(bert_outs))
    if self.args.use_crf:
        score, seq = self.crf.viterbi_decode(out, mask)
    else:
        batch_size = out.size(0)
        seq_len = out.size(1)
        out = out.view(-1, out.size(2))
        _, seq = torch.max(out, 1)
        seq = seq.view(batch_size, seq_len)
        seq = mask.long() * seq
    return seq
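# --- Hedged sketch (not from the original source) ---
# Stand-alone illustration of the token_start trick used above: a boolean mask
# over subword positions selects the first subword of each word, torch.split
# regroups the selected vectors per sentence, and pad_sequence re-batches them.
# All tensors below are toy values.
import torch
from torch.nn.utils.rnn import pad_sequence

bert_outs = torch.randn(2, 6, 4)                       # batch x subwords x hidden
token_start = torch.tensor([[1, 0, 1, 1, 0, 0],
                            [1, 1, 0, 0, 0, 0]], dtype=torch.bool)
lens = token_start.sum(dim=1)                          # words per sentence: [3, 2]
word_outs = torch.split(bert_outs[token_start], lens.tolist())
word_outs = pad_sequence(word_outs, batch_first=True)  # batch x max_words x hidden
mask = torch.arange(word_outs.size(1)) < lens.unsqueeze(-1)
print(word_outs.shape, mask)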
def forward(self, input_idxs, input_masks, syntax_ids=None):
    bert_outs, _ = self.bert(
        input_idxs,
        token_type_ids=None,
        attention_mask=input_masks,
        output_all_encoded_layers=False,
    )
    lens = torch.sum(input_idxs.gt(0), dim=1)
    # bert_outs = torch.split(bert_outs[token_start], lens.tolist())
    bert_outs = pad_sequence(bert_outs, batch_first=True)
    lstm_input = bert_outs
    if self.use_syntax:
        syntax_vec = self.syntax_embed(syntax_ids)
        lstm_input = torch.cat((lstm_input, syntax_vec), -1)
    max_len = lstm_input.size(1)
    lstm_input = lstm_input[:, :max_len, :]
    # mask = torch.arange(max_len).cuda() < lens.unsqueeze(-1)
    # add lstm after bert
    sorted_lens, sorted_idx = torch.sort(lens, dim=0, descending=True)
    reverse_idx = torch.sort(sorted_idx, dim=0)[1]
    lstm_input = lstm_input[sorted_idx]
    lstm_input = pack(lstm_input, sorted_lens, batch_first=True)
    lstm_output, (h, _) = self.lstm(lstm_input)
    output, _ = pad(lstm_output, batch_first=True)   # [batch, sequence_length, embedding]
    output = output.permute(0, 2, 1)                 # [batch, embedding, sequence_length]
    output = F.max_pool1d(output, output.size()[2])  # [batch, embedding, 1]
    output = output.squeeze(2)                       # [batch, embedding]
    output = output[reverse_idx]
    out = self.linear(torch.tanh(output))
    return out
def forward(self, hidden, encoder_outputs):
    lengths = None
    if type(encoder_outputs) is PackedSequence:
        encoder_outputs, lengths = pad(encoder_outputs, batch_first=True)
    else:
        lengths = [len(x) for x in encoder_outputs]
    batch_size = encoder_outputs.size()[0]
    attns = cuda(T.zeros(batch_size, max(lengths)), gpu_id=self.gpu_id)
    lengths = cuda(T.zeros(max(lengths), 1), gpu_id=self.gpu_id)
    if self.method == 'dot':
        attns = T.baddbmm(lengths, encoder_outputs, hidden.transpose(2, 1)).squeeze(2)
    elif self.method == 'general':
        attended = self.attn(encoder_outputs)
        attns = T.baddbmm(lengths, attended, hidden.transpose(2, 1)).squeeze(2)
    elif self.method == 'concat':
        concated = T.cat(
            (hidden.expand_as(encoder_outputs), encoder_outputs), 2)
        energy = self.attn(concated)
        expanded = self.other.unsqueeze(0).expand(batch_size, 1, self.hidden_size)
        attns = T.baddbmm(lengths, energy, expanded.transpose(2, 1)).squeeze(2)
    return F.softmax(attns, dim=1).unsqueeze(1)
def forward(self, input_idxs, bert_lens, bert_mask, syntax_embed=None):
    bert_outs = self.bert_embedding(input_idxs, bert_lens, bert_mask)
    lens = torch.sum(bert_lens.gt(0), dim=1)
    # bert_outs = torch.split(bert_outs[token_start], lens.tolist())
    # bert_outs = pad_sequence(bert_outs, batch_first=True)
    lstm_input = bert_outs
    if self.use_syntax:
        syntax_vec = syntax_embed
        lstm_input = torch.cat((lstm_input, syntax_vec), -1)
    # max_len = lstm_input.size(1)
    # lstm_input = lstm_input[:, :max_len, :]
    # mask = torch.arange(max_len).cuda() < lens.unsqueeze(-1)
    # add lstm after bert
    sorted_lens, sorted_idx = torch.sort(lens, dim=0, descending=True)
    reverse_idx = torch.sort(sorted_idx, dim=0)[1]
    lstm_input = lstm_input[sorted_idx]
    lstm_input = pack(lstm_input, sorted_lens, batch_first=True, enforce_sorted=False)
    lstm_output, (h, _) = self.lstm(lstm_input)
    output, _ = pad(lstm_output, batch_first=True)       # [batch, sequence_length, embedding]
    output = output.permute(0, 2, 1)                     # [batch, embedding, sequence_length]
    if self.args.maxpooling:
        output = F.max_pool1d(output, output.size()[2])  # [batch, embedding, 1]
    elif self.args.avepooling:
        output = F.avg_pool1d(output, output.size()[2])  # [batch, embedding, 1]
    output = output.squeeze(2)                           # [batch, embedding]
    output = output[reverse_idx]
    out = self.linear(torch.tanh(output))
    out = F.softmax(out, dim=1)
    return out
def forward(self, input, hx=(None, None, None)):
    # handle packed data
    is_packed = type(input) is PackedSequence
    if is_packed:
        input, lengths = pad(input)
        max_length = lengths[0]
    else:
        max_length = input.size(1) if self.batch_first else input.size(0)
        lengths = [input.size(1)] * max_length if self.batch_first else [
            input.size(0)
        ] * max_length

    batch_size = input.size(0) if self.batch_first else input.size(1)

    # make the data batch-first
    if not self.batch_first:
        input = input.transpose(0, 1)

    controller_hidden, mem_hidden, last_read = self._init_hidden(
        hx, batch_size)

    # batched forward pass per element / word / etc
    outputs = None
    chxs = []
    read_vectors = [last_read] * max_length
    outs = [
        T.cat([input[:, x, :], last_read], 1) for x in range(max_length)
    ]

    for layer in range(self.num_layers):
        # this layer's hidden states
        chx = [x[layer] for x in controller_hidden
               ] if self.mode == 'LSTM' else controller_hidden[layer]
        # pass through controller
        outs, read_vectors, (chx, mem_hidden[layer]) = self._layer_forward(
            outs, layer, (chx, mem_hidden[layer]))
        chxs.append(chx)

        if layer == self.num_layers - 1:
            # final outputs
            outputs = T.stack(outs, 1)
        else:
            # the controller output + read vectors go into next layer
            outs = [T.cat([o, r], 1) for o, r in zip(outs, read_vectors)]

    # final hidden values
    if self.mode == 'LSTM':
        h = T.stack([x[0] for x in chxs], 0)
        c = T.stack([x[1] for x in chxs], 0)
        controller_hidden = (h, c)
    else:
        controller_hidden = T.stack(chxs, 0)

    if not self.batch_first:
        outputs = outputs.transpose(0, 1)
    if is_packed:
        outputs = pack(outputs, lengths)  # was pack(output, ...): undefined name

    return outputs, (controller_hidden, mem_hidden, read_vectors[-1])
def collate(self, xs):
    max_seq_len = max([x['lengths'] for x in xs])
    mask = [[1] * int(x['lengths']) + [0] * int(max_seq_len - x['lengths']) for x in xs]
    mask = torch.tensor(mask, dtype=torch.long)
    return {
        'src': pad([x['src'] for x in xs]),
        'trg': torch.stack([x['trg'] for x in xs], dim=-1),
        'mask': mask,
        'lengths': torch.stack([x['lengths'] for x in xs], dim=-1)
    }
def forward(self, inputs, lengths):
    lengths = lengths.contiguous().data.view(-1).tolist()
    embs = self.dropout(self.emb(inputs))
    packed_embs = pack(embs, lengths, batch_first=True)
    output, _ = self.rnn(packed_embs)
    output = pad(output, batch_first=True)[0]
    output = self.dropout(output)
    output_flat = self.output(output.contiguous().view(
        output.size(0) * output.size(1), output.size(2)))
    return output_flat
def forward(self, source, source_lengths, hidden=None):
    embedded = self.embedding(source)
    packed = pack(embedded, source_lengths, batch_first=True)
    outputs, hidden = self.rnn(packed, hidden)
    outputs, _ = pad(outputs, batch_first=True)
    # sum the bidirectional outputs
    if self.bidirectional:
        outputs = outputs[:, :, :self.hidden_size] + \
            outputs[:, :, self.hidden_size:]
    return outputs, hidden
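# --- Hedged sketch (not from the original source) ---
# Minimal stand-alone version of the pack -> RNN -> pad encoder pattern used
# above, with the two directions of a bidirectional GRU summed into a single
# hidden_size-wide output. Module and tensor names here are illustrative only.
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence


class TinyEncoder(nn.Module):
    def __init__(self, vocab_size=100, emb_dim=16, hidden_size=32):
        super().__init__()
        self.hidden_size = hidden_size
        self.embedding = nn.Embedding(vocab_size, emb_dim, padding_idx=0)
        self.rnn = nn.GRU(emb_dim, hidden_size, batch_first=True, bidirectional=True)

    def forward(self, source, source_lengths):
        packed = pack_padded_sequence(self.embedding(source), source_lengths,
                                      batch_first=True, enforce_sorted=False)
        outputs, hidden = self.rnn(packed)
        outputs, _ = pad_packed_sequence(outputs, batch_first=True)
        # sum forward and backward directions
        outputs = outputs[:, :, :self.hidden_size] + outputs[:, :, self.hidden_size:]
        return outputs, hidden


enc = TinyEncoder()
src = torch.tensor([[4, 5, 6, 0], [7, 8, 0, 0]])   # padded batch
out, h = enc(src, torch.tensor([3, 2]))
print(out.shape)   # (2, 3, 32)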
def forward(self, state, length):
    # length should be sorted
    assert len(state.size()) == 3  # batch x n_features x input_dim
    # input_dim == n_features + 1
    batch_size = state.size()[0]
    self.weight = np.zeros((int(batch_size), self.n_features))
    # state.data.new(int(batch_size), self.n_features).fill_(0.)
    nonzero = torch.sum(length > 0).cpu().numpy()

    # encode only nonzero points
    if nonzero == 0:
        return state.new(int(batch_size), self.lstm_size + self.embedded_dim).fill_(0.)

    length_ = list(length[:nonzero].cpu().numpy())
    packed = pack(state[:nonzero], length_, batch_first=True)
    embedded = self.embedder(packed.data)
    if self.normalize:
        embedded = F.normalize(embedded, dim=1)
    embedded = PackedSequence(embedded, packed.batch_sizes)
    embedded, _ = pad(embedded, batch_first=True)  # nonzero x max(length) x embedded_dim

    # define initial state
    qt = embedded.new(embedded.size()[0], self.lstm_size).fill_(0.)
    ct = embedded.new(embedded.size()[0], self.lstm_size).fill_(0.)

    ###########################
    # shuffling (set encoding)
    ###########################
    for i in range(self.n_shuffle):
        attended, weight = self.attending(qt, embedded, length[:nonzero])
        # attended : nonzero x embedded_dim
        qt, ct = self.lstm(attended, (qt, ct))

    # TODO edit here!
    weight = weight.detach().cpu().numpy()
    tmp = state[:, :, 1:]
    val, acq = torch.max(tmp, 2)  # batch x n_features
    tmp = (val.long() * acq).cpu().numpy()
    # tmp = tmp.cpu().numpy()
    tmp = tmp[:weight.shape[0], :weight.shape[1]]
    self.weight[np.arange(nonzero).reshape(-1, 1), tmp] = weight

    encoded = torch.cat((attended, qt), dim=1)
    if batch_size > nonzero:
        encoded = torch.cat(
            (encoded,
             encoded.new(int(batch_size - nonzero), encoded.size()[1]).fill_(0.)),
            dim=0)
    return encoded
def featurize(self, x):
    wrd_ix = x[0]
    lengths = x[1]
    x = self.word_vectors(wrd_ix)
    x = pack(x, lengths.tolist(), batch_first=True)
    lstm_out, (hidden_state, cell_state) = self.lstm(x)
    if self.att:
        x, self.attentions = self.att_layer(lstm_out)
    else:
        output, _ = pad(lstm_out, batch_first=True)
        # get the last time step for each sequence
        idx = (lengths - 1).view(-1, 1).expand(output.size(0), output.size(2)).unsqueeze(1)
        x = output.gather(1, Variable(idx)).squeeze(1)
    return x
def forward(self, inputs, lengths):
    lengths = lengths.contiguous().data.view(-1).tolist()
    word_vecs = self.dropout(self.word_lut(inputs))
    packed_word_vecs = pack(word_vecs, lengths)
    rnn_output, hidden = self.lstm(packed_word_vecs)
    rnn_output = pad(rnn_output)[0]
    rnn_output = self.dropout(rnn_output)
    output_flat = self.linear_output(
        rnn_output.view(
            rnn_output.size(0) * rnn_output.size(1),
            rnn_output.size(2)
        )
    )
    return output_flat
def neg_log_likehood(self, subword_idxs, subword_masks, token_start, batch_label):
    # self.args.use_crf = False
    bert_outs, _ = self.bert(
        subword_idxs,
        token_type_ids=None,
        attention_mask=subword_masks,
        output_all_encoded_layers=False,
    )
    lens = token_start.sum(dim=1)
    # x = bert_outs[token_start]
    bert_outs = torch.split(bert_outs[token_start], lens.tolist())
    bert_outs = pad_sequence(bert_outs, batch_first=True)
    max_len = bert_outs.size(1)
    mask = torch.arange(max_len).cuda() < lens.unsqueeze(-1)
    # add lstm after bert
    sorted_lens, sorted_idx = torch.sort(lens, dim=0, descending=True)
    reverse_idx = torch.sort(sorted_idx, dim=0)[1]
    bert_outs = bert_outs[sorted_idx]
    bert_outs = pack(bert_outs, sorted_lens, batch_first=True)
    bert_outs, hidden = self.lstm(bert_outs)
    bert_outs, _ = pad(bert_outs, batch_first=True)
    bert_outs = bert_outs[reverse_idx]
    out = self.linear(bert_outs)
    # out = self.dropout(out)
    if self.args.use_crf:
        loss = self.crf(out, mask, batch_label)
        score, seq = self.crf.viterbi_decode(out, mask)
    else:
        batch_size = out.size(0)
        seq_len = out.size(1)
        out = out.view(-1, out.size(2))
        score = torch.nn.functional.log_softmax(out, 1)
        loss_function = nn.NLLLoss(ignore_index=0, reduction="sum")
        loss = loss_function(score, batch_label.view(-1))
        _, seq = torch.max(score, 1)
        seq = seq.view(batch_size, seq_len)
    if self.args.average_loss:
        loss = loss / mask.float().sum()
    return loss, seq
def forward(self, input, source_lengths, hidden=(None, None, None)):
    batch_size = len(source_lengths)
    if not hidden:
        hidden = (None, None, None)
    (encoder_hidden, interface_hidden, mem_hidden) = hidden

    # encode
    encoded, encoder_hidden = self.encoder(input, source_lengths, encoder_hidden)

    # reset working memory
    mem_hidden = self.memory.reset(batch_size, mem_hidden)
    # nothing read in first time step (b*w*r)
    read_vectors = cuda(T.zeros(batch_size, self.w, self.r), gpu_id=self.gpu_id)
    dnc_encoded = cuda(T.zeros(encoded.size()), gpu_id=self.gpu_id)

    # unroll the rnn for each time step
    for x in range(max(source_lengths)):
        # concat the input and stuff read from memory in last time step
        b = encoded[:, x, :].unsqueeze(2)  # b * w * 1
        input = T.cat((b, read_vectors), 2).view(batch_size, -1, self.input_size)  # b * 1 * ((r+1)*w)

        # pass it through an RNN
        input = pack(input, [1] * batch_size, batch_first=True)
        out, interface_hidden = self.rnn(input, interface_hidden)
        out, _ = pad(out, batch_first=True)
        ξ = out.squeeze(1)

        # forward pass through memory
        read_vectors, mem_hidden = self.memory(ξ, mem_hidden)

        # final output, todo: differs from deepmind's implementation
        # where they concat and then pass through a Linear
        read_vecs = read_vectors.view(-1, self.w * self.r)
        mem_encoded = T.cat([ξ, read_vecs], 1)
        dnc_encoded[:, x, :] = self.mem_out(mem_encoded)

    return dnc_encoded, (encoder_hidden, interface_hidden, mem_hidden)
def run_rnn(self, embedded_input, batch, rnn):
    """
    Run embeddings through an RNN and return the output.

    Args:
        embedded_input (torch.FloatTensor): batch x seq x dim
        batch (Batch): batch object containing a .lengths tensor
        rnn (torch.nn.LSTM): LSTM to run the embeddings through

    Returns:
        torch.FloatTensor: hidden states output of the LSTM, batch x seq x dim
    """
    (sorted_input, sorted_lengths, input_unsort_indices, _) = \
        sort_batch_by_length(embedded_input, batch.lengths)
    packed_input = pack(sorted_input, sorted_lengths.data.tolist(), batch_first=True)
    rnn.flatten_parameters()
    packed_sorted_output, _ = rnn(packed_input)
    sorted_output, _ = pad(packed_sorted_output, batch_first=True)
    return sorted_output[input_unsort_indices]
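# --- Hedged sketch (not from the original source) ---
# `sort_batch_by_length` is assumed to behave like the AllenNLP helper of the
# same name: sort a padded batch by decreasing length and also return the
# indices needed to restore the original order. A minimal version might look
# like this; the toy usage at the bottom mirrors how run_rnn consumes it.
import torch


def sort_batch_by_length(tensor, sequence_lengths):
    sorted_lengths, permutation_index = sequence_lengths.sort(0, descending=True)
    sorted_tensor = tensor.index_select(0, permutation_index)
    # indices that undo the sort, so outputs can be returned in input order
    _, restoration_indices = permutation_index.sort(0)
    return sorted_tensor, sorted_lengths, restoration_indices, permutation_index


x = torch.randn(3, 5, 2)
lengths = torch.tensor([2, 5, 3])
sorted_x, sorted_lens, unsort_idx, _ = sort_batch_by_length(x, lengths)
assert torch.equal(sorted_x[unsort_idx], x)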
def forward(self, source, source_lengths, hidden=None):
    '''
    source: nr_batches * max_len
    source_lengths: nr_batches
    hidden: nr_layers * nr_batches * nr_hidden
    '''
    batch_size = source.size()[0]
    embedded = self.embedding(source)
    outputs = T.zeros(batch_size, max(source_lengths), self.nr_hidden)

    for c in range(self.columns):
        # take the c-th λ-wide column of the embeddings
        e = embedded[:, :, self.λ * c:self.λ * (c + 1)]
        packed = pack(e, source_lengths, batch_first=True)
        h = None if not hidden else hidden[:, :, self.λ * c:self.λ * (c + 1)]
        o, h = self.gru(packed, h)
        o, _ = pad(o, batch_first=True)
        hidden[:, :, self.λ * c:self.λ * (c + 1)] = h
        outputs[:, :, self.λ * c:self.λ * (c + 1)] = o

    return outputs, hidden
def forward(self, input):
    B, T = input.size()
    lens = input.ne(self.vocabulary.pad_id).sum(1).long()
    x = self.embedding(input)  # B x T x D
    x = F.dropout(x, self.dropout, self.training)
    x = pack(x, lens, batch_first=True)
    x, _ = self.rnn(x)
    x, _ = pad(x, batch_first=True)  # pad() returns (output, lengths)
    B_, T_, H = x.size()
    assert B_ == B
    assert T_ == T
    x = F.dropout(x, self.dropout, self.training)
    x = self.ffn(x)
    x = F.dropout(x, self.dropout, self.training)
    x = x.sum(1) / lens.unsqueeze(1).float()  # B x H
    logits = self.logits(x)  # B x C, C: num_class
    return logits
def collate(self, xs):
    return {
        'src': pad([x['src'] for x in xs]),
        'trg': torch.stack([x['trg'] for x in xs], dim=-1),
        'lengths': torch.stack([x['lengths'] for x in xs], dim=-1)
    }
def forward(self, input, hx=(None, None, None), reset_experience=False, pass_through_memory=True):
    # handle packed data
    is_packed = type(input) is PackedSequence
    if is_packed:
        input, lengths = pad(input)
        max_length = lengths[0]
    else:
        max_length = input.size(1) if self.batch_first else input.size(0)
        lengths = [input.size(1)] * max_length if self.batch_first else [input.size(0)] * max_length

    batch_size = input.size(0) if self.batch_first else input.size(1)

    if not self.batch_first:
        input = input.transpose(0, 1)
    # make the data time-first
    controller_hidden, mem_hidden, last_read = self._init_hidden(hx, batch_size, reset_experience)

    # concat input with last read (or padding) vectors
    inputs = [T.cat([input[:, x, :], last_read], 1) for x in range(max_length)]

    # batched forward pass per element / word / etc
    if self.debug:
        viz = None

    outs = [None] * max_length
    read_vectors = None

    # pass through time
    for time in range(max_length):
        # pass through layers
        for layer in range(self.num_layers):
            # this layer's hidden states
            chx = controller_hidden[layer]
            m = mem_hidden if self.share_memory else mem_hidden[layer]
            # pass through controller
            outs[time], (chx, m, read_vectors) = \
                self._layer_forward(inputs[time], layer, (chx, m), pass_through_memory)

            # debug memory
            if self.debug:
                viz = self._debug(m, viz)

            # store the memory back (per layer or shared)
            if self.share_memory:
                mem_hidden = m
            else:
                mem_hidden[layer] = m
            controller_hidden[layer] = chx

            if read_vectors is not None:
                # the controller output + read vectors go into next layer
                outs[time] = T.cat([outs[time], read_vectors], 1)
            else:
                outs[time] = T.cat([outs[time], last_read], 1)
            inputs[time] = outs[time]

    if self.debug:
        viz = {k: np.array(v) for k, v in viz.items()}
        viz = {k: v.reshape(v.shape[0], v.shape[1] * v.shape[2]) for k, v in viz.items()}

    # pass through final output layer
    inputs = [self.output(i) for i in inputs]
    outputs = T.stack(inputs, 1 if self.batch_first else 0)

    if is_packed:
        outputs = pack(outputs, lengths)  # was pack(output, ...): undefined name

    if self.debug:
        return outputs, (controller_hidden, mem_hidden, read_vectors), viz
    else:
        return outputs, (controller_hidden, mem_hidden, read_vectors)
def forward(self, input, hx=(None, None, None), reset_experience=False, pass_through_memory=True):
    # handle packed data
    is_packed = type(input) is PackedSequence
    if is_packed:
        input, lengths = pad(input)
        max_length = lengths[0]
    else:
        max_length = input.size(1) if self.batch_first else input.size(0)
        lengths = [input.size(1)] * max_length if self.batch_first else [
            input.size(0)
        ] * max_length

    batch_size = input.size(0) if self.batch_first else input.size(1)

    if not self.batch_first:
        input = input.transpose(0, 1)
    # make the data time-first
    controller_hidden, mem_hidden, last_read = self._init_hidden(
        hx, batch_size, reset_experience)

    # concat input with last read (or padding) vectors
    inputs = [
        T.cat([input[:, x, :], last_read], 1) for x in range(max_length)
    ]

    # batched forward pass per element / word / etc
    if self.debug:
        viz = None

    outs = [None] * max_length
    read_vectors = None

    # pass through time
    for time in range(max_length):
        # pass through layers
        for layer in range(self.num_layers):
            # this layer's hidden states
            chx = controller_hidden[layer]
            m = mem_hidden if self.share_memory else mem_hidden[layer]
            # pass through controller
            outs[time], (chx, m, read_vectors) = \
                self._layer_forward(
                    inputs[time], layer, (chx, m), pass_through_memory)

            # debug memory
            if self.debug:
                viz = self._debug(m, viz)

            # store the memory back (per layer or shared)
            if self.share_memory:
                mem_hidden = m
            else:
                mem_hidden[layer] = m
            controller_hidden[layer] = chx

            if read_vectors is not None:
                # the controller output + read vectors go into next layer
                outs[time] = T.cat([outs[time], read_vectors], 1)
            else:
                outs[time] = T.cat([outs[time], last_read], 1)
            inputs[time] = outs[time]

    if self.debug:
        viz = {k: np.array(v) for k, v in viz.items()}
        reshape_keys = [
            "memory", "link_matrix", "precedence", "read_weights",
            "write_weights", "usage_vector"
        ]
        for key in reshape_keys:
            viz[key] = viz[key].reshape(
                viz[key].shape[0], viz[key].shape[1] * viz[key].shape[2])
        # mean_keys = ["free_gates", "allocation_gate", "write_gate", "read_modes"]
        # for key in mean_keys:
        #     viz[key] = np.mean(viz[key], axis=0)
        # viz = {k: v.reshape(v.shape[0], v.shape[1] * v.shape[2])
        #        for k, v in viz.items()}

    # pass through final output layer
    inputs_final = [self.output(i) for i in inputs]
    outputs = T.stack(inputs_final, 1 if self.batch_first else 0)

    if self.debug:
        self._update_controller_memory_contributions()
        c_contrib = [
            self.controller_contribution(i[:, :self.output_size]) for i in inputs
        ]
        m_contrib = [
            self.memory_contribution(i[:, self.output_size:]) for i in inputs
        ]
        c_contrib = T.stack(c_contrib, 1 if self.batch_first else 0)
        m_contrib = T.stack(m_contrib, 1 if self.batch_first else 0)
        outputs_ = outputs.clone().detach()

        # average over the batch dim
        c_contrib = c_contrib.mean(0)
        m_contrib = m_contrib.mean(0)
        outputs_ = outputs_.mean(0)

        c_contrib = nn.Softmax(1)(c_contrib)
        m_contrib = nn.Softmax(1)(m_contrib)
        outputs_ = nn.Softmax(1)(outputs_)

        m_influence = torch.abs(outputs_ - c_contrib)
        m_influence = torch.mean(m_influence, 1)
        m_influence_scalar = torch.mean(m_influence, 0)

        c_influence = torch.abs(outputs_ - m_contrib)
        c_influence = torch.mean(c_influence, 1)
        c_influence_scalar = torch.mean(c_influence, 0)

        m_inf_vec = m_influence.div(m_influence + c_influence)
        c_inf_vec = c_influence.div(m_influence + c_influence)
        m_inf = m_influence_scalar / (m_influence_scalar + c_influence_scalar)
        c_inf = c_influence_scalar / (m_influence_scalar + c_influence_scalar)

        viz["memory_influence"] = m_inf
        viz["controller_influence"] = c_inf
        viz["memory_influence_vec"] = m_inf_vec.detach().cpu().numpy()
        viz["controller_influence_vec"] = c_inf_vec.detach().cpu().numpy()

    if is_packed:
        outputs = pack(outputs, lengths)  # was pack(output, ...): undefined name

    if self.debug:
        return outputs, (controller_hidden, mem_hidden, read_vectors), viz
    else:
        return outputs, (controller_hidden, mem_hidden, read_vectors)
def forward(self, inputs, lengths=None):
    """
    Parameters
    ----------
    inputs: LongTensor
        The input data. Shape is (seq_len, batch_size) if batch_first=False,
        or (batch_size, seq_len) if batch_first=True.
    lengths: LongTensor or List[int], optional
        List of integers with the sequence length for each element in the batch.

    Returns
    -------
    output_distribution: FloatTensor
        FloatTensor of shape (batch_size, vocab_size), which is a distribution
        over the vocabulary for each batch.
    """
    # Shape: (batch_size, seq_len, embedding_size) if batch_first=True
    # Shape: (seq_len, batch_size, embedding_size) if batch_first=False
    embedded_seq = self.embedding(inputs)
    embedded_seq = self.dropout(embedded_seq)

    if lengths is not None:
        # Pack the sequence
        embedded_seq = pack(embedded_seq, lengths, batch_first=self.batch_first)

    # encoded_seq shape if batch_first=True: (batch_size, seq_len, hidden_size)
    # encoded_seq shape if batch_first=False: (seq_len, batch_size, hidden_size)
    self.rnn.flatten_parameters()
    encoded_seq, _ = self.rnn(embedded_seq)

    if lengths is not None:
        # Pad the packed sequence
        encoded_seq, _ = pad(encoded_seq, batch_first=self.batch_first)

    # Apply dropout.
    encoded_seq = self.dropout(encoded_seq)

    # Get the output after encoding the entire sequence
    if lengths is not None:
        # Shape: (batch_size, hidden_size)
        idx = (torch.tensor(
            lengths, dtype=torch.long, device=encoded_seq.device) - 1).view(
                -1, 1).expand(len(lengths), encoded_seq.size(2))
        time_dimension = 1 if self.batch_first else 0
        idx = idx.unsqueeze(time_dimension)
        last_output = encoded_seq.gather(time_dimension, idx).squeeze(time_dimension)
    else:
        last_output = encoded_seq[:, -1]

    # Run the final RNN output through the decoder to get an output of
    # shape (batch_size, vocab_size)
    output_distribution = self.decoder(last_output)

    # Return decoded, a distribution over the output vocabulary for
    # each sequence in the batch.
    # Shape: (batch_size, vocab_size)
    return output_distribution
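# --- Hedged sketch (not from the original source) ---
# Stand-alone illustration of the gather trick above: pick the hidden state at
# the last valid (non-padded) timestep of each sequence in a batch-first
# tensor. Tensor values below are toy data.
import torch

encoded_seq = torch.arange(24, dtype=torch.float).view(2, 4, 3)  # batch x time x hidden
lengths = [4, 2]
idx = (torch.tensor(lengths) - 1).view(-1, 1).expand(len(lengths), encoded_seq.size(2))
idx = idx.unsqueeze(1)                       # batch x 1 x hidden
last_output = encoded_seq.gather(1, idx).squeeze(1)
print(last_output)   # row 0: timestep 3 of seq 0, row 1: timestep 1 of seq 1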
def forward(self, x, hx=(None, None), reset_experience=False, pass_through_memory=True):
    # handle packed data
    is_packed = type(x) is PackedSequence
    if is_packed:
        x, lengths = pad(x)
        max_length = lengths[0]
    else:
        max_length = x.size(1) if self.batch_first else x.size(0)
        lengths = [x.size(1)] * max_length if self.batch_first else [
            x.size(0)
        ] * max_length

    batch_size = x.size(0) if self.batch_first else x.size(1)

    controller_hidden, mem_hidden = self._init_hidden(
        hx, batch_size, reset_experience)

    # batched forward pass per element / word / etc
    if self.debug:
        viz = None

    outs = [None] * max_length
    read_vectors = None

    # pass through layers
    for layer in range(self.num_layers):
        # this layer's hidden states
        chx = controller_hidden[layer]
        m = mem_hidden if self.share_memory else mem_hidden[layer]
        # pass through controller
        x, (chx, m, read_vectors) = self._layer_forward(x, layer, (chx, m), pass_through_memory)

        # debug memory
        if self.debug:
            viz = self._debug(m, viz)

        # store the memory back (per layer or shared)
        if self.share_memory:
            mem_hidden = m
        else:
            mem_hidden[layer] = m
        controller_hidden[layer] = chx

        if read_vectors is not None:
            # the controller output + read vectors go into next layer
            x = T.cat([x[-1, :, :], read_vectors], 1)
            read_vectors = None

    if self.debug:
        viz = {k: np.array(v) for k, v in viz.items()}
        viz = {
            k: v.reshape(v.shape[0], v.shape[1] * v.shape[2])
            for k, v in viz.items()
        }

    # pass through final output layer
    for layer in range(len(self.output)):
        x = self.output[layer](x)

    if is_packed:
        x = pack(x, lengths)

    if self.debug:
        return x, (controller_hidden, mem_hidden), viz
    else:
        return x, (controller_hidden, mem_hidden)
def forward(self, input, hx=None):
    timesteps = input.size(1) if self.batch_first else input.size(0)
    directions = 2 if self.bidirectional else 1

    is_packed = isinstance(input, PackedSequence)
    if is_packed:
        input, batch_sizes = pad(input)
        max_batch_size = batch_sizes[0]
    else:
        batch_sizes = None
        max_batch_size = input.size(0) if self.batch_first else input.size(1)

    # layer * direction
    if hx is None:
        num_directions = 2 if self.bidirectional else 1
        hx = var(input.data.new(max_batch_size, self.hidden_size).zero_(),
                 requires_grad=False)
        hx = (hx, hx)
        hx = [[hx for x in range(directions)] for d in range(self.num_layers)]

    # make weights indexable with layer -> direction
    ws = self.all_weights
    if directions == 1:
        ws = [[w] for w in ws]
    else:
        ws = [[ws[l * 2], ws[l * 2 + 1]] for l in range(self.num_layers)]

    # make input batch-first, separate into timeslice-wise chunks
    input = input if self.batch_first else input.transpose(0, 1)
    os = [[input[:, i, :] for i in range(timesteps)] for d in range(directions)]
    if directions > 1:
        os[1].reverse()

    for time in range(timesteps):
        for layer in range(self.num_layers):
            for direction in range(directions):
                if self.bias:
                    (w_ih, w_hh, b_ih, b_hh) = ws[layer][direction]
                else:
                    (w_ih, w_hh) = ws[layer][direction]
                    b_ih = None
                    b_hh = None

                hy, cy = SubLSTMCellF(os[direction][time], hx[layer][direction],
                                      w_ih, w_hh, b_ih, b_hh)
                hx[layer][direction] = (hy, cy)
                os[direction][time] = hy

            if directions > 1:
                os[0][time] = T.cat(
                    [os[d][time] for d in range(directions)], -1)
                os[1][time] = os[0][time]

    output = T.stack([T.stack(o, 1) for o in os])
    output = T.cat(output, -1) if self.bidirectional else output[0]
    output = output if self.batch_first else output.transpose(0, 1)

    if is_packed:
        output = pack(output, batch_sizes)

    return output, hx
def forward(self, word_inputs, word_seq_length, char_inputs, char_seq_length, char_recover):
    """
    word_inputs: (batch_size, seq_len)
    word_seq_length: (batch_size,)
    """
    batch_size = word_inputs.size(0)
    seq_len = word_inputs.size(1)
    word_emb = self.word_embedding(word_inputs)
    # word_rep = self.drop(word_emb)
    word_rep = word_emb

    if self.args.use_char:
        size = char_inputs.size(0)
        char_emb = self.char_embedding(char_inputs)
        char_emb = pack(char_emb, char_seq_length.numpy(), batch_first=True)
        char_lstm_out, char_hidden = self.char_feature(char_emb)
        char_lstm_out, _ = pad(char_lstm_out, batch_first=True)
        char_hidden = char_hidden[0].transpose(1, 0).contiguous().view(size, -1)
        char_hidden = char_hidden[char_recover]
        char_hidden = char_hidden.view(batch_size, seq_len, -1)
        if self.args.attention:
            word_rep = torch.tanh(self.attn1(word_emb) + self.attn2(char_hidden))
            z = torch.sigmoid(self.attn3(word_rep))
            x = 1 - z
            word_rep = torch.mul(z, word_emb) + torch.mul(x, char_hidden)
        else:
            word_rep = torch.cat((word_emb, char_hidden), 2)

    word_rep = pack(word_rep, word_seq_length.cpu().numpy(), batch_first=True)
    out, hidden = self.word_feature(word_rep)
    out, _ = pad(out, batch_first=True)

    if self.args.lstm_attention:
        # tanh_out = F.tanh(out)
        # att_vec = self.att1(tanh_out)
        # att_sm_vec = F.softmax(att_vec)
        # out = F.mul(out, att_sm_vec)
        out_list, weight_list = [], []
        for idx in range(seq_len):
            # slice_out = out[:, 0:idx + 1, :]
            if idx + 2 > seq_len:
                slice_out = out
            else:
                slice_out = out[:, 0:idx + 2, :]
            # slice_out = out
            slice_out, weights = self.attention(slice_out)
            # slice_out, weights = SelfAttention(self.args.hidden_dim * 2).forward(slice_out)
            out_list.append(slice_out.unsqueeze(1))
            weight_list.append(weights)
        out = torch.cat(out_list, dim=1)
        # out = F.tanh(self.att1(out))
        # out = self.softmax(out)
        # out = out * out
        # out = self.att2(out)
    # else:
    out = self.hidden2tag(self.drop(out))
    return out