import math
import os

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Module, Parameter
from torch.nn.functional import dropout
from torch.nn.modules.conv import _ConvNd

# Module-level dropout probabilities referenced by the attention classes
# below. In the original repositories these are set from the run
# configuration; the values here are placeholders.
dropout_p = 0.4
my_dropout_p = 0.4


class AttentionScore(nn.Module):
    """
    correlation_func = 1, sij = x1^T x2
    correlation_func = 2, sij = (Wx1)^T D (Wx2)
    correlation_func = 3, sij = ReLU(Wx1)^T D ReLU(Wx2)
    correlation_func = 4, sij = x1^T W x2
    correlation_func = 5, sij = ReLU(Wx1)^T ReLU(Wx2)
    """

    def __init__(self, input_size, hidden_size, correlation_func=1, do_similarity=False):
        super(AttentionScore, self).__init__()
        self.correlation_func = correlation_func
        self.hidden_size = hidden_size

        if correlation_func == 2 or correlation_func == 3:
            self.linear = nn.Linear(input_size, hidden_size, bias=False)
            if do_similarity:
                # Fixed 1/sqrt(d) scaling, as in scaled dot-product attention.
                self.diagonal = Parameter(torch.ones(1, 1, 1) / (hidden_size ** 0.5),
                                          requires_grad=False)
            else:
                # Learnable diagonal matrix D, stored as a vector.
                self.diagonal = Parameter(torch.ones(1, 1, hidden_size),
                                          requires_grad=True)
        if correlation_func == 4:
            self.linear = nn.Linear(input_size, input_size, bias=False)
        if correlation_func == 5:
            self.linear = nn.Linear(input_size, hidden_size, bias=False)

    def forward(self, x1, x2):
        '''
        Input:
            x1: batch x word_num1 x dim
            x2: batch x word_num2 x dim
        Output:
            scores: batch x word_num1 x word_num2
        '''
        x1 = dropout(x1, p=dropout_p, training=self.training)
        x2 = dropout(x2, p=dropout_p, training=self.training)

        x1_rep = x1
        x2_rep = x2
        batch = x1_rep.size(0)
        word_num1 = x1_rep.size(1)
        word_num2 = x2_rep.size(1)
        dim = x1_rep.size(2)

        if self.correlation_func == 2 or self.correlation_func == 3:
            x1_rep = self.linear(x1_rep.contiguous().view(-1, dim)).view(
                batch, word_num1, self.hidden_size)  # Wx1
            x2_rep = self.linear(x2_rep.contiguous().view(-1, dim)).view(
                batch, word_num2, self.hidden_size)  # Wx2
            if self.correlation_func == 3:
                x1_rep = F.relu(x1_rep)
                x2_rep = F.relu(x2_rep)
            x1_rep = x1_rep * self.diagonal.expand_as(x1_rep)
            # x1_rep is now (Wx1)D or ReLU(Wx1)D,
            # shape: batch x word_num1 x dim (corr=1) or hidden_size (corr=2,3)

        if self.correlation_func == 4:
            x2_rep = self.linear(x2_rep.contiguous().view(-1, dim)).view(
                batch, word_num2, dim)  # Wx2

        if self.correlation_func == 5:
            x1_rep = self.linear(x1_rep.contiguous().view(-1, dim)).view(
                batch, word_num1, self.hidden_size)  # Wx1
            x2_rep = self.linear(x2_rep.contiguous().view(-1, dim)).view(
                batch, word_num2, self.hidden_size)  # Wx2
            x1_rep = F.relu(x1_rep)
            x2_rep = F.relu(x2_rep)

        scores = x1_rep.bmm(x2_rep.transpose(1, 2))
        return scores
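# Usage sketch (not from the original source): smoke-tests AttentionScore
# for every correlation_func with random tensors; the sizes are arbitrary.
def _demo_attention_score():
    x1 = torch.randn(2, 5, 32)  # batch x word_num1 x dim
    x2 = torch.randn(2, 7, 32)  # batch x word_num2 x dim
    for func in (1, 2, 3, 4, 5):
        score = AttentionScore(32, 64, correlation_func=func)
        score.eval()  # disable dropout for a deterministic shape check
        assert score(x1, x2).shape == (2, 5, 7)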
class GraphConvolutionLayer(nn.Module):
    """
    From https://github.com/tkipf/pygcn.
    Simple GCN layer, similar to https://arxiv.org/abs/1609.02907
    """

    def __init__(self, in_features, out_features, edge_dropout, activation,
                 highway, bias=True):
        super(GraphConvolutionLayer, self).__init__()
        self.in_features = in_features
        out_features = int(out_features)
        self.out_features = out_features
        self.edge_dropout = nn.Dropout(edge_dropout)
        self.highway = highway
        self.activation = activation
        # One weight matrix per edge block (3 in total).
        self.weight = Parameter(torch.Tensor(3, in_features, out_features))
        if bias:
            self.bias = Parameter(torch.Tensor(3, 1, out_features))
            if highway != "":
                assert in_features == out_features
                self.weight_highway = Parameter(
                    torch.Tensor(in_features, out_features))
                self.bias_highway = Parameter(torch.Tensor(1, out_features))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        stdv = 1. / math.sqrt(self.weight.size(1))
        self.weight.data.uniform_(-stdv, stdv)
        if self.bias is not None:
            self.bias.data.uniform_(-stdv, stdv)
        if hasattr(self, 'weight_highway'):
            # The highway parameters share the same uniform initialization.
            self.weight_highway.data.uniform_(-stdv, stdv)
            self.bias_highway.data.uniform_(-stdv, stdv)

    def forward(self, inputs, adj):
        features = self.edge_dropout(inputs)
        outputs = []
        for i in range(features.size(1)):
            # support: one projection of the i-th feature slice per weight
            # matrix, shape (3, batch, out_features).
            support = torch.bmm(
                features[:, i, :].unsqueeze(0).expand(
                    self.weight.size(0), *features[:, i, :].size()),
                self.weight)
            if self.bias is not None:
                support += self.bias.expand_as(support)
            # Aggregate the projected messages through the adjacency slice.
            output = torch.mm(
                adj[:, i, :].transpose(1, 2).contiguous().view(
                    support.size(0) * support.size(1), -1).transpose(0, 1),
                support.view(support.size(0) * support.size(1), -1))
            outputs.append(output)

        if self.activation == "leaky_relu":
            output = F.leaky_relu(torch.stack(outputs, 1))
        elif self.activation == "relu":
            output = F.relu(torch.stack(outputs, 1))
        elif self.activation == "tanh":
            output = torch.tanh(torch.stack(outputs, 1))
        elif self.activation == "sigmoid":
            output = torch.sigmoid(torch.stack(outputs, 1))
        else:
            raise ValueError('unknown activation: %s' % self.activation)

        if self.highway != "":
            transform = []
            for i in range(features.size(1)):
                transform_batch = torch.mm(features[:, i, :],
                                           self.weight_highway)
                transform_batch += self.bias_highway.expand_as(transform_batch)
                transform.append(transform_batch)
            if self.highway == "leaky_relu":
                transform = F.leaky_relu(torch.stack(transform, 1))
            elif self.highway == "relu":
                transform = F.relu(torch.stack(transform, 1))
            elif self.highway == "tanh":
                transform = torch.tanh(torch.stack(transform, 1))
            elif self.highway == "sigmoid":
                transform = torch.sigmoid(torch.stack(transform, 1))
            else:
                raise ValueError('unknown highway gate: %s' % self.highway)
            # Highway connection: gate between transformed and raw features.
            carry = 1 - transform
            output = output * transform + features * carry
        return output
class MaskedLinear(nn.Module):
    """
    Creates a masked linear layer for MLP MADE.
    For input (x) to hidden (h) or hidden to hidden layers, choose
    diagonal_zeros = False.
    For hidden to output (y) layers:
    If the output depends on the input through y_i = f(x_{<i}), set
    diagonal_zeros = True.
    Else, if the output depends on the input through y_i = f(x_{<=i}), set
    diagonal_zeros = False.
    """

    def __init__(self, in_features, out_features, diagonal_zeros=False, bias=True):
        super(MaskedLinear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.diagonal_zeros = diagonal_zeros
        self.weight = Parameter(torch.FloatTensor(in_features, out_features))
        if bias:
            self.bias = Parameter(torch.FloatTensor(out_features))
        else:
            self.register_parameter('bias', None)
        # Register the autoregressive mask as a buffer so it follows the
        # module across devices (.cuda()/.to()) without being trained.
        self.register_buffer('mask', torch.from_numpy(self.build_mask()))
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.kaiming_normal_(self.weight)
        if self.bias is not None:
            self.bias.data.zero_()

    def build_mask(self):
        n_in, n_out = self.in_features, self.out_features
        assert n_in % n_out == 0 or n_out % n_in == 0
        mask = np.ones((n_in, n_out), dtype=np.float32)
        if n_out >= n_in:
            k = n_out // n_in
            for i in range(n_in):
                mask[i + 1:, i * k:(i + 1) * k] = 0
                if self.diagonal_zeros:
                    mask[i:i + 1, i * k:(i + 1) * k] = 0
        else:
            k = n_in // n_out
            for i in range(n_out):
                mask[(i + 1) * k:, i:i + 1] = 0
                if self.diagonal_zeros:
                    mask[i * k:(i + 1) * k, i:i + 1] = 0
        return mask

    def forward(self, x):
        output = x.mm(self.mask * self.weight)
        if self.bias is not None:
            return output.add(self.bias.expand_as(output))
        return output

    def __repr__(self):
        bias = self.bias is not None
        return self.__class__.__name__ + ' (' \
            + str(self.in_features) + ' -> ' \
            + str(self.out_features) + ', diagonal_zeros=' \
            + str(self.diagonal_zeros) + ', bias=' \
            + str(bias) + ')'
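# Usage sketch (not from the original source): a 3 -> 6 masked layer.
# Printing the mask shows the block-autoregressive structure.
def _demo_masked_linear():
    layer = MaskedLinear(3, 6, diagonal_zeros=False)
    x = torch.randn(4, 3)
    assert layer(x).shape == (4, 6)
    print(layer.mask)  # input i feeds only output blocks j >= i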
class AttentionScore(nn.Module):
    """
    How the correlation score(x1, x2) is computed:
    correlation_func = 1, sij = x1^T x2
    correlation_func = 2, sij = (Wx1)^T D (Wx2)
    correlation_func = 3, sij = ReLU(Wx1)^T D ReLU(Wx2)
    correlation_func = 4, sij = x1^T W x2
    correlation_func = 5, sij = ReLU(Wx1)^T ReLU(Wx2)
    """

    def __init__(self, input_size, hidden_size, correlation_func=1, do_similarity=False):
        super(AttentionScore, self).__init__()
        self.correlation_func = correlation_func
        self.hidden_size = hidden_size  # hidden dimension, i.e. the number of rows of U

        # Implements score(x1, x2) = ReLU(Ux1)^T D ReLU(Ux2); the projection
        # matrix U is set up below.
        if correlation_func == 2 or correlation_func == 3:
            self.linear = nn.Linear(input_size, hidden_size, bias=False)  # self.linear is the matrix U
            # do_similarity controls whether the scale is fixed to
            # 1/sqrt(hidden_size) (as in Transformer attention) or the
            # diagonal matrix D is learned.
            if do_similarity:
                # requires_grad=False: D stays a constant scaling factor.
                self.diagonal = Parameter(torch.ones(1, 1, 1) / (hidden_size ** 0.5),
                                          requires_grad=False)
            else:
                # Parameter() registers self.diagonal, i.e. the diagonal
                # matrix D, with the module, so it is optimized during
                # training.
                self.diagonal = Parameter(torch.ones(1, 1, hidden_size),
                                          requires_grad=True)
        if correlation_func == 4:
            self.linear = nn.Linear(input_size, input_size, bias=False)  # no matrix U, i.e. no hidden projection
        if correlation_func == 5:
            self.linear = nn.Linear(input_size, hidden_size, bias=False)

    def forward(self, x1, x2):
        """
        Compute attention scores between the vector sets x1 and x2.
        :param x1: batch * word_num1 * dim
        :param x2: batch * word_num2 * dim
        :return: scores: batch * word_num1 * word_num2
        """
        x1 = dropout(x1, p=dropout_p, training=self.training)
        x2 = dropout(x2, p=dropout_p, training=self.training)

        x1_rep = x1
        x2_rep = x2
        batch = x1_rep.size(0)
        word_num1 = x1_rep.size(1)
        word_num2 = x2_rep.size(1)
        dim = x1_rep.size(2)

        # Compute x1_rep and x2_rep.
        if self.correlation_func == 2 or self.correlation_func == 3:
            x1_rep = self.linear(x1_rep.contiguous().view(-1, dim)).view(
                batch, word_num1, self.hidden_size)
            x2_rep = self.linear(x2_rep.contiguous().view(-1, dim)).view(
                batch, word_num2, self.hidden_size)
            if self.correlation_func == 3:
                # ReLU(Wx1) D ReLU(Wx2)
                x1_rep = F.relu(x1_rep)
                x2_rep = F.relu(x2_rep)
            # x1_rep is (Wx1)D or ReLU(Wx1)D, with shape
            # batch * word_num1 * dim (corr=1) or hidden_size (corr=2, 3).
            x1_rep = x1_rep * self.diagonal.expand_as(x1_rep)

        if self.correlation_func == 4:
            # The projection keeps the input dimensionality, so the view
            # uses dim rather than hidden_size.
            x2_rep = self.linear(x2_rep.contiguous().view(-1, dim)).view(
                batch, word_num2, dim)

        if self.correlation_func == 5:
            x1_rep = self.linear(x1_rep.contiguous().view(-1, dim)).view(
                batch, word_num1, self.hidden_size)
            x2_rep = self.linear(x2_rep.contiguous().view(-1, dim)).view(
                batch, word_num2, self.hidden_size)
            x1_rep = F.relu(x1_rep)
            x2_rep = F.relu(x2_rep)

        scores = x1_rep.bmm(x2_rep.transpose(1, 2))
        return scores
class _WeightNormalizedConvNd(_ConvNd):
    # Base class for weight-normalized convolutions. The super().__init__
    # call matches the pre-1.1 _ConvNd signature (no padding_mode argument);
    # groups is fixed to 1 and the built-in bias is disabled in favor of the
    # broadcastable bias parameter below.

    def __init__(self, in_channels, out_channels, kernel_size, stride,
                 padding, dilation, transposed, output_padding, scale, bias,
                 init_factor, init_scale):
        super(_WeightNormalizedConvNd,
              self).__init__(in_channels, out_channels, kernel_size, stride,
                             padding, dilation, transposed, output_padding, 1,
                             False)
        if scale:
            # Per-output-channel scale, broadcast over the spatial dims.
            self.scale = Parameter(
                torch.Tensor(1, self.out_channels,
                             *((1, ) * len(kernel_size))).fill_(init_scale))
        else:
            self.register_parameter('scale', None)
        if bias:
            self.bias = Parameter(
                torch.zeros(1, self.out_channels,
                            *((1, ) * len(kernel_size))))
        else:
            self.register_parameter('bias', None)
        self.weight.data.mul_(init_factor)
        self.weight_norm_factor = 1.0
        if transposed:
            # Transposed convolutions spread each input over stride
            # positions, so the norm is rescaled accordingly.
            for t in self.stride:
                self.weight_norm_factor = self.weight_norm_factor / t

    def weight_norm(self):
        # L2 norm of the weight over all dims except the output channels.
        weight_norm = self.weight.pow(2)
        if self.transposed:
            weight_norm = weight_norm.sum(0, keepdim=True)
        else:
            weight_norm = weight_norm.sum(1, keepdim=True)
        for i in range(len(self.kernel_size)):
            weight_norm = weight_norm.sum(2 + i, keepdim=True)
        weight_norm = weight_norm.mul(self.weight_norm_factor).add(1e-6).sqrt()
        return weight_norm

    def norm_scale_bias(self, input):
        # Divide by the weight norm, then apply the optional scale and bias.
        if self.transposed:
            output = input.div(self.weight_norm().expand_as(input))
        else:
            output = input.div(self.weight_norm().transpose(
                0, 1).expand_as(input))
        if self.scale is not None:
            output = output.mul(self.scale.expand_as(input))
        if self.bias is not None:
            output = output.add(self.bias.expand_as(input))
        return output

    def __repr__(self):
        s = ('{name}({in_channels}, {out_channels}, kernel_size={kernel_size}'
             ', stride={stride}')
        if self.padding != (0, ) * len(self.padding):
            s += ', padding={padding}'
        if self.dilation != (1, ) * len(self.dilation):
            s += ', dilation={dilation}'
        if self.output_padding != (0, ) * len(self.output_padding):
            s += ', output_padding={output_padding}'
        if self.scale is None:
            s += ', scale=False'
        if self.bias is None:
            s += ', bias=False'
        s += ')'
        return s.format(name=self.__class__.__name__, **self.__dict__)
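# Standalone sketch of the weight_norm() reduction above, detached from the
# _ConvNd machinery: for a non-transposed conv weight (out, in, kh, kw), the
# norm reduces over every dim except the output channels.
def _demo_weight_norm_reduction():
    weight = torch.randn(8, 3, 3, 3)
    norm = weight.pow(2).sum(1, keepdim=True)
    for i in range(2):  # two spatial dims
        norm = norm.sum(2 + i, keepdim=True)
    norm = norm.add(1e-6).sqrt()
    assert norm.shape == (8, 1, 1, 1)  # one norm per output channel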
class Filter(nn.Module):
    r"""Applies a complex filter to the incoming data:
    :math:`y = sum(x .* A, dim = -1) + b`

    Args:
        num_bin: number of frequency bins
        num_channel: number of channels
        bias: If set to ``False``, the layer will not learn an additive bias.
            Default: ``True``
    """
    __constants__ = ['bias', 'num_bin', 'num_channel']

    def __init__(self, num_bin, num_channel, init_weight=None, init_bias=None,
                 bias=True, fix=False):
        super(Filter, self).__init__()
        self.num_bin = num_bin
        self.num_channel = num_channel
        self.initialized = False
        self.fix = fix
        # Complex filter coefficients, stored as (real, imag) along dim 1.
        self.weight = Parameter(
            torch.Tensor(1, 2, num_bin, num_channel),
            requires_grad=(not fix))  # (1, 2, num_bin, num_channel)
        if init_weight is not None:
            self.weight.data.copy_(init_weight)
            self.initialized = True
        if bias:
            self.bias = Parameter(torch.Tensor(1, 2, num_bin),
                                  requires_grad=(not fix))  # (1, 2, num_bin)
            if init_bias is not None:
                self.bias.data.copy_(init_bias)
        else:
            self.register_parameter('bias', None)
        if init_weight is None:
            self.reset_parameters()

    def reset_parameters(self):
        torch.nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.bias is not None:
            fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out(
                self.weight)
            bound = 1 / math.sqrt(fan_in)
            torch.nn.init.uniform_(self.bias, -bound, bound)

    def forward(self, input):
        # input  = (num_frame, 2, num_bin, num_channel)
        # weight = (1, 2, num_bin, num_channel)
        num_frame, num_dim, num_bin, num_channel = input.size()
        assert num_frame > 0 and num_bin == self.num_bin and \
            num_channel == self.num_channel and num_dim == 2, \
            "illegal shape for input, the required input shape is " \
            "(%d, 2, %d, %d), but got (%d, %d, %d, %d)" % (
                num_frame, self.num_bin, self.num_channel,
                num_frame, num_dim, num_bin, num_channel)
        # Complex multiplication: (a + bi)(c + di) = (ac - bd) + (ad + bc)i.
        filter_out_r = self.weight[:, 0, :, :] * input[:, 0, :, :] - \
            self.weight[:, 1, :, :] * input[:, 1, :, :]  # (num_frame, num_bin, num_channel)
        filter_out_i = self.weight[:, 0, :, :] * input[:, 1, :, :] + \
            self.weight[:, 1, :, :] * input[:, 0, :, :]  # (num_frame, num_bin, num_channel)
        # Sum over the channel dimension.
        filter_out_r = filter_out_r.sum(dim=2)  # (num_frame, num_bin)
        filter_out_i = filter_out_i.sum(dim=2)  # (num_frame, num_bin)
        filter_out = torch.cat(
            (torch.unsqueeze(filter_out_r, 1),
             torch.unsqueeze(filter_out_i, 1)), 1)  # (num_frame, 2, num_bin)
        if self.bias is not None:
            filter_out = filter_out + self.bias.expand_as(filter_out)
        return filter_out  # (num_frame, 2, num_bin)

    def extra_repr(self):
        return 'num_bin={}, num_channel={}, bias={}'.format(
            self.num_bin, self.num_channel, self.bias is not None)
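# Usage sketch (not from the original source): filters a batch of 10 STFT
# frames with 257 frequency bins over 4 channels down to a single channel.
def _demo_filter():
    filt = Filter(num_bin=257, num_channel=4)
    frames = torch.randn(10, 2, 257, 4)  # (num_frame, real/imag, bin, channel)
    assert filt(frames).shape == (10, 2, 257)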
class BatchReNorm1d(Module):

    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True,
                 rmax=3.0, dmax=5.0):
        super(BatchReNorm1d, self).__init__()
        self.num_features = num_features
        self.eps = eps
        self.momentum = momentum
        self.affine = affine
        self.rmax = rmax
        self.dmax = dmax
        if self.affine:
            self.weight = Parameter(torch.Tensor(num_features))
            self.bias = Parameter(torch.Tensor(num_features))
        else:
            self.register_parameter('weight', None)
            self.register_parameter('bias', None)
        self.register_buffer('running_mean', torch.zeros(num_features))
        self.register_buffer('running_var', torch.ones(num_features))
        # Renormalization correction buffers r and d (Ioffe, 2017).
        self.register_buffer('r', torch.ones(num_features))
        self.register_buffer('d', torch.zeros(num_features))
        self.reset_parameters()

    def reset_parameters(self):
        self.running_mean.zero_()
        self.running_var.fill_(1)
        self.r.fill_(1)
        self.d.zero_()
        if self.affine:
            self.weight.data.uniform_()
            self.bias.data.zero_()

    def _check_input_dim(self, input):
        if input.size(1) != self.running_mean.nelement():
            raise ValueError('got {}-feature tensor, expected {}'.format(
                input.size(1), self.num_features))

    def forward(self, input):
        self._check_input_dim(input)
        n = input.size(0)
        if self.training:
            mean = torch.mean(input, dim=0)
            sq_sum = torch.sum((input - mean.expand_as(input)) ** 2, dim=0)
            # eps > 0 is assumed, so the inverse std is always finite.
            invstd = 1. / torch.sqrt(sq_sum / n + self.eps)
            unbiased_var = sq_sum / (n - 1)
            # r and d are treated as constants in backprop, hence .data.
            self.r = torch.clamp(
                torch.sqrt(unbiased_var).data / torch.sqrt(self.running_var),
                1. / self.rmax, self.rmax)
            self.d = torch.clamp(
                (mean.data - self.running_mean) / torch.sqrt(self.running_var),
                -self.dmax, self.dmax)
            r = self.r.expand_as(input)
            d = self.d.expand_as(input)
            input_normalized = (
                input - mean.expand_as(input)) * invstd.expand_as(input)
            input_normalized = input_normalized * r + d
            self.running_mean += self.momentum * (mean.data -
                                                  self.running_mean)
            self.running_var += self.momentum * (unbiased_var.data -
                                                 self.running_var)
            if not self.affine:
                return input_normalized
            output = input_normalized * self.weight.expand_as(input)
            output += self.bias.unsqueeze(0).expand_as(input)
            return output
        else:
            mean = self.running_mean.expand_as(input)
            invstd = 1. / torch.sqrt(
                self.running_var.expand_as(input) + self.eps)
            input_normalized = (input - mean) * invstd
            if not self.affine:
                return input_normalized
            output = input_normalized * self.weight.expand_as(input)
            output += self.bias.unsqueeze(0).expand_as(input)
            return output

    def __repr__(self):
        return ('{name}({num_features}, eps={eps}, momentum={momentum}, '
                'affine={affine}, rmax={rmax}, dmax={dmax})'.format(
                    name=self.__class__.__name__, **self.__dict__))
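# Usage sketch (not from the original source): one training step updates the
# running statistics and the r/d buffers; eval mode then uses them.
def _demo_batch_renorm():
    bn = BatchReNorm1d(8)
    x = torch.randn(16, 8)
    bn.train()
    assert bn(x).shape == (16, 8)
    bn.eval()
    assert bn(x).shape == (16, 8)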
class FlowAttentionScore(nn.Module):
    """
    sij = ReLU(Wx1)^T D ReLU(Wx2)
    """

    def __init__(self, x1_input_size, x2_input_size, attention_hidden_size,
                 similarity_score=False, tight_weight=True):
        super(FlowAttentionScore, self).__init__()
        self.tight_weight = tight_weight
        if tight_weight:
            # A single projection W shared by both inputs.
            assert x1_input_size == x2_input_size
            self.linear = nn.Linear(x1_input_size, attention_hidden_size,
                                    bias=False)
        else:
            self.linear1 = nn.Linear(x1_input_size, attention_hidden_size,
                                     bias=False)
            self.linear2 = nn.Linear(x2_input_size, attention_hidden_size,
                                     bias=False)
        if similarity_score:
            # Fixed 1/sqrt(d) scaling instead of a learned diagonal.
            self.linear_final = Parameter(torch.ones(1, 1, 1) /
                                          (attention_hidden_size ** 0.5),
                                          requires_grad=False)
        else:
            self.linear_final = Parameter(torch.ones(1, 1,
                                                     attention_hidden_size),
                                          requires_grad=True)

    def forward(self, x1, x2):
        """
        x1: batch * len1 * input_size
        x2: batch * len2 * input_size
        scores: batch * len1 * len2  <the scores are not masked>
        """
        x1 = dropout(x1, p=my_dropout_p, training=self.training)
        x2 = dropout(x2, p=my_dropout_p, training=self.training)
        if self.tight_weight:
            x1_rep = self.linear(x1.contiguous().view(-1, x1.size(-1))).view(
                x1.size(0), x1.size(1), -1)
            x2_rep = self.linear(x2.contiguous().view(-1, x2.size(-1))).view(
                x2.size(0), x2.size(1), -1)
        else:
            x1_rep = self.linear1(x1.contiguous().view(-1, x1.size(-1))).view(
                x1.size(0), x1.size(1), -1)
            x2_rep = self.linear2(x2.contiguous().view(-1, x2.size(-1))).view(
                x2.size(0), x2.size(1), -1)
        x1_rep = F.relu(x1_rep)
        x2_rep = F.relu(x2_rep)
        # Apply the diagonal D to one side, then take the dot product.
        final_v = self.linear_final.expand_as(x2_rep)
        x2_rep_v = final_v * x2_rep
        scores = x1_rep.bmm(x2_rep_v.transpose(1, 2))
        return scores
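# Usage sketch (not from the original source): scores between two random
# sequences; with tight_weight=True the two inputs must share one size.
def _demo_flow_attention_score():
    att = FlowAttentionScore(16, 16, 32)
    att.eval()  # disable dropout
    x1, x2 = torch.randn(2, 5, 16), torch.randn(2, 7, 16)
    assert att(x1, x2).shape == (2, 5, 7)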
class OneEmbed(nn.Module):

    def __init__(self, num_embeddings, embedding_dim, padding_idx=None,
                 one_emb_type='binary', dropout=0.5, std=1.0, codenum=64,
                 codebooknum=8, layernum=1, interdim=0, relu_dropout=0.1,
                 mask_file=''):
        super(OneEmbed, self).__init__()
        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim
        self.one_emb_type = one_emb_type
        self.layernum = layernum
        self.relu_dropout = relu_dropout
        if padding_idx is not None:
            if padding_idx > 0:
                assert padding_idx < self.num_embeddings, \
                    'padding_idx must be within num_embeddings'
            elif padding_idx < 0:
                assert padding_idx >= -self.num_embeddings, \
                    'padding_idx must be within num_embeddings'
                padding_idx = self.num_embeddings + padding_idx
        self.padding_idx = padding_idx
        # A single embedding vector shared by all tokens.
        self.weight = Parameter(torch.Tensor(1, embedding_dim))
        if interdim == 0:
            interdim = embedding_dim
        self.weight_matrices = nn.ParameterList([
            nn.Parameter(torch.Tensor(embedding_dim, interdim))
            if i + 1 == self.layernum else
            (nn.Parameter(torch.Tensor(interdim, embedding_dim)) if i == 0
             else nn.Parameter(torch.Tensor(interdim, interdim)))
            for i in range(self.layernum)
        ])
        # The tensors above are allocated uninitialized; an assumed standard
        # initialization is applied here so the module works stand-alone.
        nn.init.normal_(self.weight, mean=0.0, std=embedding_dim ** -0.5)
        for w in self.weight_matrices:
            nn.init.xavier_uniform_(w)
        if os.path.isfile(mask_file):
            self.mask = torch.load(mask_file)
        else:
            if self.one_emb_type == 'binary':
                prob = torch.Tensor(codenum, embedding_dim)
                nn.init.constant_(prob, (1 - dropout ** (1.0 / codebooknum)))
                self.masklist = [torch.bernoulli(prob)
                                 for _ in range(codebooknum)]
            else:
                mean_m = torch.zeros(codenum, embedding_dim)
                std_m = torch.Tensor(codenum, embedding_dim)
                nn.init.constant_(std_m, std * (codebooknum ** -0.5))
                self.masklist = [torch.normal(mean_m, std_m)
                                 for _ in range(codebooknum)]
            # Each token is hashed to one codeword per codebook.
            self.hash2mask = torch.randint(0, codenum,
                                           (num_embeddings, codebooknum),
                                           dtype=torch.long)
            self.mask = self.construct_mask2each_token()  # mask for each token
            if mask_file:
                # Only persist the mask when a path was actually given.
                dirname = os.path.dirname(mask_file)
                if dirname and not os.path.isdir(dirname):
                    os.makedirs(dirname)
                torch.save(self.mask, mask_file)

    def construct_mask2each_token(self):
        mask = []
        for i in range(self.hash2mask.size(1)):
            token_hash = self.hash2mask[:, i]
            mask.append(nn.functional.embedding(
                token_hash, self.masklist[i], padding_idx=self.padding_idx))
        mask = sum(mask)
        if self.one_emb_type == 'binary':
            mask.clamp_(0, 1)
        return mask

    def construct_matrix_for_output_layer(self):
        vocab_vec = self.mask.new(range(self.num_embeddings)).long()
        matrix = self.forward(vocab_vec, dropout=0)
        return matrix

    def forward(self, input, dropout=None):
        if input.is_cuda and not self.mask.is_cuda:
            self.mask = self.mask.cuda()
        relu_dropout = self.relu_dropout if dropout is None else dropout
        each_token_mask = nn.functional.embedding(
            input, self.mask, padding_idx=self.padding_idx)
        # Every token embedding is the shared vector modulated by its mask.
        embed = each_token_mask * self.weight.expand_as(each_token_mask)
        for i in range(self.layernum):
            embed = nn.functional.linear(embed, self.weight_matrices[i])
            if i + 1 != self.layernum:
                embed = nn.functional.relu(embed)
                embed = nn.functional.dropout(embed, p=relu_dropout,
                                              training=self.training)
        return embed
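# Usage sketch (not from the original source): a small "one embedding"
# table; all 100 token vectors are derived from the single shared vector.
def _demo_one_embed():
    emb = OneEmbed(num_embeddings=100, embedding_dim=16,
                   codenum=8, codebooknum=4)
    tokens = torch.randint(0, 100, (2, 5))
    assert emb(tokens).shape == (2, 5, 16)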
class FullAttention(nn.Module):

    def __init__(self, full_size, hidden_size, num_level):
        super(FullAttention, self).__init__()
        assert hidden_size % num_level == 0
        self.full_size = full_size
        self.hidden_size = hidden_size
        self.attsize_per_lvl = hidden_size // num_level
        self.num_level = num_level
        self.linear = nn.Linear(full_size, hidden_size, bias=False)
        self.linear_final = Parameter(torch.ones(1, hidden_size),
                                      requires_grad=True)
        self.output_size = hidden_size
        print("Full Attention: (atten. {} -> {}, take {}) x {}".format(
            self.full_size, self.attsize_per_lvl, hidden_size // num_level,
            self.num_level))

    def forward(self, x1_att, x2_att, x1, x2, x2_mask):
        """
        x1_att: batch * len1 * full_size
        x2_att: batch * len2 * full_size
        x1: batch * len1 * hidden_size
        x2: batch * len2 * hidden_size
        x2_mask: batch * len2 (currently unused; the masking below is
                 commented out)
        """
        x1_att = dropout(x1_att, p=my_dropout_p, training=self.training)
        x2_att = dropout(x2_att, p=my_dropout_p, training=self.training)

        x1_key = F.relu(self.linear(x1_att.view(-1, self.full_size)))
        x2_key = F.relu(self.linear(x2_att.view(-1, self.full_size)))
        final_v = self.linear_final.expand_as(x2_key)
        x2_key = final_v * x2_key

        # Split the keys into num_level heads of attsize_per_lvl each.
        x1_rep = x1_key.view(-1, x1.size(1), self.num_level,
                             self.attsize_per_lvl).transpose(
                                 1, 2).contiguous().view(
                                     -1, x1.size(1), self.attsize_per_lvl)
        x2_rep = x2_key.view(-1, x2.size(1), self.num_level,
                             self.attsize_per_lvl).transpose(
                                 1, 2).contiguous().view(
                                     -1, x2.size(1), self.attsize_per_lvl)

        scores = x1_rep.bmm(x2_rep.transpose(1, 2)).view(
            -1, self.num_level, x1.size(1),
            x2.size(1))  # batch * num_level * len1 * len2

        # x2_mask = x2_mask.unsqueeze(1).unsqueeze(2).expand_as(scores)
        # scores.data.masked_fill_(x2_mask.data, -float('inf'))

        alpha_flat = F.softmax(scores.view(-1, x2.size(1)), dim=-1)
        alpha = alpha_flat.view(-1, x1.size(1), x2.size(1))

        size_per_level = self.hidden_size // self.num_level
        # Attend to x2 per level, then merge the levels back together.
        atten_seq = alpha.bmm(x2.contiguous().view(
            -1, x2.size(1), self.num_level, size_per_level).transpose(
                1, 2).contiguous().view(-1, x2.size(1), size_per_level))

        return atten_seq.view(-1, self.num_level, x1.size(1),
                              size_per_level).transpose(1, 2).contiguous().view(
                                  -1, x1.size(1), self.hidden_size)
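# Usage sketch (not from the original source): 3-level attention pooling x2
# into the x1 positions; the mask argument is accepted but currently unused.
def _demo_full_attention():
    att = FullAttention(full_size=20, hidden_size=12, num_level=3)
    att.eval()
    x1_att, x2_att = torch.randn(2, 5, 20), torch.randn(2, 7, 20)
    x1, x2 = torch.randn(2, 5, 12), torch.randn(2, 7, 12)
    x2_mask = torch.zeros(2, 7, dtype=torch.bool)
    assert att(x1_att, x2_att, x1, x2, x2_mask).shape == (2, 5, 12)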