class Attention(nn.Module):
    """Step-by-step decoder built around an ``AttentionCell``.

    At every step the previous character's embedding is fed, together with
    the running hidden state and the feature sequence, to the attention cell;
    the resulting hidden state is projected to class scores by ``generator``.

    Args:
        input_size: Size of the last dimension of ``feats``.
        hidden_size: Size of the recurrent hidden state.
        num_classes: Number of output classes produced by ``generator``.
        num_embeddings: Width of each character embedding.
        CUDA: Forwarded to ``AttentionCell`` and kept as ``self.use_cuda``.
            It is deliberately NOT stored as ``self.cuda``: that would shadow
            ``nn.Module.cuda()`` and make ``model.cuda()`` raise ``TypeError``.
            Helper tensors created in ``forward`` follow ``feats.device``.
    """

    def __init__(self, input_size, hidden_size, num_classes,
                 num_embeddings=128, CUDA=True):
        super(Attention, self).__init__()
        self.attention_cell = AttentionCell(input_size,
                                            hidden_size,
                                            num_embeddings,
                                            CUDA=CUDA)
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.generator = nn.Linear(hidden_size, num_classes)
        # One extra row: index 0 is the start token, labels are shifted by +1.
        self.char_embeddings = Parameter(
            torch.randn(num_classes + 1, num_embeddings))
        self.num_embeddings = num_embeddings
        self.num_classes = num_classes
        self.use_cuda = CUDA

    def forward(self, feats, text_length, text, test=False):
        """Decode ``feats`` into per-character class scores.

        Args:
            feats: Tensor indexed as ``(nT, nB, nC)`` with ``nC == input_size``.
            text_length: One label length per batch element (``nB`` values).
            text: Flat concatenation of all label sequences; only read when
                ``test`` is False (teacher forcing).
            test: When True, feed back the arg-max prediction of the previous
                step instead of the ground-truth character.

        Returns:
            Tensor of shape ``(sum(text_length), num_classes)``: the scores of
            every sample's first ``text_length[b]`` steps, concatenated in
            batch order.
        """
        nB = feats.size(1)
        nC = feats.size(2)
        hidden_size = self.hidden_size
        assert self.input_size == nC
        assert nB == text_length.numel()

        # Plain Python ints keep slicing / range() independent of the tensor
        # type and device of ``text_length``.
        lengths = [int(n) for n in text_length.data.reshape(-1).tolist()]
        num_steps = max(lengths, default=0)
        num_labels = sum(lengths)
        device = feats.device

        if not test:
            # Row 0 of the (transposed) targets stays 0 == start token.
            targets = torch.zeros(nB, num_steps + 1,
                                  dtype=torch.long, device=device)
            start_id = 0
            for i, length in enumerate(lengths):
                shifted = text.data[start_id:start_id + length] + 1
                targets[i, 1:1 + length] = shifted.to(device)
                start_id += length
            targets = targets.transpose(0, 1).contiguous()  # nT * nB

            output_hiddens = feats.new_zeros(num_steps, nB, hidden_size)
            hidden = feats.new_zeros(nB, hidden_size)
            for i in range(num_steps):
                cur_embeddings = self.char_embeddings.index_select(
                    0, targets[i])
                hidden, _ = self.attention_cell(hidden, feats,
                                                cur_embeddings, test)
                output_hiddens[i] = hidden

            # Keep only the valid steps of each sample, packed back to back.
            new_hiddens = feats.new_zeros(num_labels, hidden_size)
            start = 0
            for b, length in enumerate(lengths):
                new_hiddens[start:start + length] = \
                    output_hiddens[0:length, b, :]
                start += length
            return self.generator(new_hiddens)

        hidden = feats.new_zeros(nB, hidden_size)
        targets_temp = torch.zeros(nB, dtype=torch.long, device=device)
        probs = feats.new_zeros(nB * num_steps, self.num_classes)
        for i in range(num_steps):
            cur_embeddings = self.char_embeddings.index_select(0, targets_temp)
            hidden, _ = self.attention_cell(hidden, feats,
                                            cur_embeddings, test)
            hidden2class = self.generator(hidden)
            probs[i * nB:(i + 1) * nB] = hidden2class
            # +1 converts a class id into its embedding row (0 is "start").
            targets_temp = hidden2class.max(1)[1] + 1

        # (step, batch) -> (batch, step) so each sample's steps are contiguous.
        probs = probs.view(num_steps, nB,
                           self.num_classes).permute(1, 0, 2).contiguous()
        probs = probs.view(-1, self.num_classes).contiguous()

        probs_res = feats.new_zeros(num_labels, self.num_classes)
        start = 0
        for b, length in enumerate(lengths):
            probs_res[start:start + length] = \
                probs[b * num_steps:b * num_steps + length]
            start += length
        return probs_res
class GCNLayer(nn.Module):
    """Graph convolutional encoder layer over labelled arcs.

    Every token aggregates messages from its incoming arcs, its outgoing arcs
    (each optional) and a self loop. Messages may be scaled by scalar sigmoid
    gates (``use_gates``) and/or passed through a GLU (``use_glus``). The sum
    over neighbours goes through a ReLU and is masked by ``sent_mask``.
    """

    def __init__(self,
                 num_inputs,
                 num_units,
                 num_labels,
                 in_arcs=True,
                 out_arcs=True,
                 batch_first=False,
                 use_gates=True,
                 use_glus=False):
        super(GCNLayer, self).__init__()

        self.in_arcs = in_arcs
        self.out_arcs = out_arcs
        self.num_inputs = num_inputs
        self.num_units = num_units
        self.num_labels = num_labels
        self.batch_first = batch_first

        self.glu = nn.GLU(3)
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

        self.use_gates = use_gates
        self.use_glus = use_glus

        # https://www.cs.toronto.edu/~yujiali/files/talks/iclr16_ggnn_talk.pdf
        # https://arxiv.org/pdf/1612.08083.pdf
        if in_arcs:
            self.V_in = self._xavier(self.num_inputs, self.num_units)
            self.b_in = self._constant(num_labels, self.num_units, 0)
            if self.use_gates:
                self.V_in_gate = self._xavier(self.num_inputs, 1)
                self.b_in_gate = self._constant(num_labels, 1, 1)

        if out_arcs:
            self.V_out = self._xavier(self.num_inputs, self.num_units)
            self.b_out = self._constant(num_labels, self.num_units, 0)
            if self.use_gates:
                self.V_out_gate = self._xavier(self.num_inputs, 1)
                self.b_out_gate = self._constant(num_labels, 1, 1)

        self.W_self_loop = self._xavier(self.num_inputs, self.num_units)
        if self.use_gates:
            self.W_self_loop_gate = self._xavier(self.num_inputs, 1)

    @staticmethod
    def _xavier(rows, cols):
        """Create a ``rows x cols`` parameter with Xavier-normal init."""
        weight = Parameter(torch.Tensor(rows, cols))
        nn.init.xavier_normal_(weight)
        return weight

    @staticmethod
    def _constant(rows, cols, value):
        """Create a ``rows x cols`` parameter filled with ``value``."""
        bias = Parameter(torch.Tensor(rows, cols))
        nn.init.constant_(bias, value)
        return bias

    def _direction(self, flat_input, batch_size, seq_len, weight, label_bias,
                   gate_weight, gate_label_bias, arc_tensor, label_tensor):
        """Messages (and gate logits, if enabled) for one arc direction.

        Returns ``(message, gate, degree)`` where ``message`` is
        ``[b, t, degree, num_units]`` and ``gate`` is ``[b, t, degree]`` or
        ``None`` when gates are disabled.
        """
        # Row of the flattened [b * t, h] input that each arc reads from.
        node_index = arc_tensor[0] * seq_len + arc_tensor[1]
        label_index = label_tensor[0]

        selected = torch.mm(flat_input, weight).index_select(0, node_index)
        degree = int(selected.size()[0] / batch_size / seq_len)
        message = selected + label_bias.index_select(0, label_index)
        message = message.view((batch_size, seq_len, degree, self.num_units))
        if self.use_glus:
            # Gate each neighbour's information with itself.
            message = self.glu(torch.cat((message, message), 3))

        gate = None
        if self.use_gates:
            gate_selected = torch.mm(flat_input, gate_weight).index_select(
                0, node_index)
            gate = gate_selected + gate_label_bias.index_select(0, label_index)
            gate = gate.view((batch_size, seq_len, degree))
        return message, gate, degree

    def forward(
            self,
            src,
            lengths=None,
            arc_tensor_in=None,
            arc_tensor_out=None,
            label_tensor_in=None,
            label_tensor_out=None,
            mask_in=None,
            mask_out=None,  # batch* t, degree
            mask_loop=None,
            sent_mask=None):
        """Run one graph-convolution step.

        ``src`` is ``[t, b, h]`` unless ``batch_first`` is set. The arc tensors
        are read as ``arc[0] * seq_len + arc[1]`` and the label tensors as
        ``label[0]``; the masks are concatenated along dim 1 in the order
        in / out / loop. ``lengths`` is accepted but not used.

        Returns the updated representations as ``[t, b, num_units]``.
        """
        if self.batch_first:
            encoder_outputs = src.contiguous()
        else:
            encoder_outputs = src.permute(1, 0, 2).contiguous()

        batch_size = encoder_outputs.size(0)
        seq_len = encoder_outputs.size(1)
        flat_input = encoder_outputs.view(
            (batch_size * seq_len, self.num_inputs))  # [b * t, h]

        messages = []
        gates = []
        masks = []
        max_degree = 1  # the self loop always contributes one neighbour

        if self.in_arcs:
            message, gate, degree = self._direction(
                flat_input, batch_size, seq_len, self.V_in, self.b_in,
                self.V_in_gate if self.use_gates else None,
                self.b_in_gate if self.use_gates else None,
                arc_tensor_in, label_tensor_in)
            messages.append(message)
            gates.append(gate)
            masks.append(mask_in)
            max_degree += degree

        if self.out_arcs:
            message, gate, degree = self._direction(
                flat_input, batch_size, seq_len, self.V_out, self.b_out,
                self.V_out_gate if self.use_gates else None,
                self.b_out_gate if self.use_gates else None,
                arc_tensor_out, label_tensor_out)
            messages.append(message)
            gates.append(gate)
            masks.append(mask_out)
            max_degree += degree

        # Self loop: every token also sends a message to itself.
        loop_message = torch.mm(flat_input, self.W_self_loop).view(
            batch_size, seq_len, 1, self.W_self_loop.size(1))
        messages.append(loop_message)
        masks.append(mask_loop)
        if self.use_gates:
            loop_gate = torch.mm(flat_input, self.W_self_loop_gate).view(
                batch_size, seq_len, -1)
            gates.append(loop_gate)

        if len(messages) == 1:
            potentials = messages[0]
            mask_soft = masks[0]
        else:
            potentials = torch.cat(messages, dim=2)  # [b, t, mxdeg, h]
            mask_soft = torch.cat(masks, dim=1)  # [b * t, mxdeg]

        potentials = potentials.view(
            (batch_size * seq_len, max_degree, self.num_units))

        if self.use_gates:
            gate_logits = gates[0] if len(gates) == 1 else torch.cat(gates,
                                                                     dim=2)
            gate_logits = gate_logits.view(
                (batch_size * seq_len, max_degree))  # [b * t, mxdeg]
            weights = (self.sigmoid(gate_logits) * mask_soft).unsqueeze(2)
        else:
            weights = mask_soft.unsqueeze(2)

        aggregated = (potentials * weights).sum(dim=1)  # [b * t, h]
        aggregated = self.relu(aggregated)

        result = aggregated.view(
            (batch_size, seq_len, self.num_units))  # [b, t, h]
        result = result * sent_mask.permute(1, 0).contiguous().unsqueeze(2)
        return result.permute(1, 0, 2).contiguous()  # [t, b, h]
class DTD(nn.Module):
    """LSTM/GRU decoder driven by per-step attention maps (LSTM DTD).

    Each attention map pools the feature map into one context vector per
    step; the contexts go through a bidirectional LSTM and are then decoded
    by a GRU cell that also receives the previous character's embedding.

    Args:
        nclass: Number of output classes; class 0 is treated as the
            end-of-sequence symbol during greedy decoding and its embedding
            is used as the initial "previous character".
        nchannel: Channel count of the feature map and GRU hidden size.
        dropout: Dropout probability inside ``generator``.
        feature_dropout: Dropout probability applied to the LSTM output
            (previously hard-coded to 0.3, which is still the default).
    """

    def __init__(self, nclass, nchannel, dropout=0.3, feature_dropout=0.3):
        super(DTD, self).__init__()
        self.nclass = nclass
        self.nchannel = nchannel
        self.feature_dropout = feature_dropout
        self.pre_lstm = nn.LSTM(nchannel, nchannel // 2, bidirectional=True)
        self.rnn = nn.GRUCell(nchannel * 2, nchannel)
        self.generator = nn.Sequential(nn.Dropout(p=dropout),
                                       nn.Linear(nchannel, nclass))
        self.char_embeddings = Parameter(torch.randn(nclass, nchannel))

    def forward(self, feature, A, text, text_length, test=False):
        """Decode a feature map with the given attention maps.

        Args:
            feature: Tensor of shape ``(nB, nC, nH, nW)``.
            A: Attention maps, viewed as ``(nB, nT, nH, nW)``.
            text: Ground-truth labels indexed as ``text[:, step]``; only read
                when ``test`` is False and may be ``None`` otherwise.
            text_length: Label length per sample; only read when ``test`` is
                False.
            test: Select greedy decoding instead of teacher forcing.

        Returns:
            Training: ``(scores, attentions)`` with shapes
            ``(sum(text_length), nclass)`` and ``(sum(text_length), nH, nW)``.
            Test: ``(scores, out_length)`` where ``out_length`` holds the
            decoded length of each sample and ``scores`` has
            ``sum(out_length)`` rows.
        """
        nB, nC, nH, nW = feature.size()
        nT = A.size()[1]

        # Normalize every attention map by its spatial sum.
        # NOTE(review): an all-zero map divides by zero here.
        A = A / A.view(nB, nT, -1).sum(2).view(nB, nT, 1, 1)
        # Attention-weighted sum over the spatial positions -> (nT, nB, nC).
        C = feature.view(nB, 1, nC, nH, nW) * A.view(nB, nT, 1, nH, nW)
        C = C.view(nB, nT, nC, -1).sum(3).transpose(1, 0)
        C, _ = self.pre_lstm(C)
        C = F.dropout(C, p=self.feature_dropout, training=self.training)

        if not test:
            return self._decode_teacher_forced(feature, A, C, text,
                                               text_length)
        return self._decode_greedy(feature, C, nT)

    def _decode_teacher_forced(self, feature, A, C, text, text_length):
        """Decode with ground-truth characters as the previous-step input."""
        nB, _, nH, nW = feature.size()
        lenText = int(text_length.sum())
        nsteps = int(text_length.max())

        gru_res = C.new_zeros(C.size())
        out_res = feature.new_zeros(lenText, self.nclass)
        out_attns = A.new_zeros(lenText, nH, nW)
        hidden = C.new_zeros(nB, self.nchannel)
        prev_emb = self.char_embeddings.index_select(0, text.new_zeros(nB))

        for i in range(nsteps):
            hidden = self.rnn(torch.cat((C[i, :, :], prev_emb), dim=1),
                              hidden)
            gru_res[i, :, :] = hidden
            prev_emb = self.char_embeddings.index_select(0, text[:, i])
        gru_res = self.generator(gru_res)

        # Pack each sample's valid steps back to back.
        start = 0
        for i in range(nB):
            cur_length = int(text_length[i])
            out_res[start:start + cur_length] = gru_res[0:cur_length, i, :]
            out_attns[start:start + cur_length] = A[i, 0:cur_length, :, :]
            start += cur_length
        return out_res, out_attns

    def _decode_greedy(self, feature, C, nsteps):
        """Decode by feeding back the arg-max class until class 0 appears."""
        nB = feature.size(0)
        out_res = feature.new_zeros(nsteps, nB, self.nclass)
        hidden = C.new_zeros(nB, self.nchannel)
        start_ids = torch.zeros(nB, dtype=torch.long, device=C.device)
        prev_emb = self.char_embeddings.index_select(0, start_ids)
        # 0 means "not finished yet"; kept on the CPU like the returned value.
        out_length = torch.zeros(nB)

        now_step = 0
        while 0 in out_length and now_step < nsteps:
            hidden = self.rnn(
                torch.cat((C[now_step, :, :], prev_emb), dim=1), hidden)
            tmp_result = self.generator(hidden)
            out_res[now_step] = tmp_result
            # squeeze(1), not squeeze(): keeps the batch axis when nB == 1.
            tmp_result = tmp_result.topk(1)[1].squeeze(1)
            newly_done = (out_length == 0) & (tmp_result.cpu() == 0)
            out_length[newly_done] = now_step + 1
            prev_emb = self.char_embeddings.index_select(0, tmp_result)
            now_step += 1

        # Samples that never emitted class 0 use every available step.
        out_length[out_length == 0] = nsteps

        output = feature.new_zeros(int(out_length.sum()), self.nclass)
        start = 0
        for i in range(nB):
            cur_length = int(out_length[i])
            output[start:start + cur_length] = out_res[0:cur_length, i, :]
            start += cur_length
        return output, out_length