def forward(self, query_tensor, value_tensor, value_attention_mask=None):
    """Multi-head attention; the output length matches the query_tensor length.

    Args:
        query_tensor: (batch_size, query_length, dim)
        value_tensor: (batch_size, value_length, dim)
        value_attention_mask: (batch_size, value_length)

    Returns:
        Attention output of shape (batch_size, query_length, dim).
    """
    batch_size, query_length, _ = query_tensor.shape
    _, value_length, _ = value_tensor.shape

    query_tensor = reshape_tensor(query_tensor, (-1, self.dim))
    value_tensor = reshape_tensor(value_tensor, (-1, self.dim))
    query_tensor = self.query_layer(query_tensor)
    key_tensor = self.key_layer(value_tensor)
    value_tensor = self.value_layer(value_tensor)

    query_tensor = self.transpose4score(query_tensor, (batch_size, query_length,
                                                       self.attention_head_num, self.size_per_head))
    key_tensor = self.transpose4score(key_tensor, (batch_size, value_length,
                                                   self.attention_head_num, self.size_per_head))

    # batch_size, attention_head_num, query_length, value_length
    attention_scores = torch.matmul(query_tensor, key_tensor.permute(0, 1, 3, 2))
    attention_scores = attention_scores / math.sqrt(float(self.size_per_head))

    if value_attention_mask is not None:
        # batch_size, 1, 1, value_length: broadcastable over heads and query positions
        value_attention_mask = torch.unsqueeze(value_attention_mask, 1)
        value_attention_mask = torch.unsqueeze(value_attention_mask, 1)
        value_attention_mask = value_attention_mask.expand(batch_size, self.attention_head_num,
                                                           query_length, value_length)
        attention_scores = attention_scores * value_attention_mask

    attention_scores = self.softmax(attention_scores)
    # attention_scores = self.dropout(attention_scores)

    value_tensor = reshape_tensor(value_tensor, (batch_size, value_length,
                                                 self.attention_head_num, self.size_per_head))
    value_tensor = value_tensor.permute(0, 2, 1, 3)

    # batch_size, attention_head_num, query_length, size_per_head
    attention = torch.matmul(attention_scores, value_tensor)
    attention = attention.permute(0, 2, 1, 3)
    attention = reshape_tensor(attention, (batch_size, query_length, self.dim))
    return attention
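# Minimal, self-contained sketch (not part of this repo) of the scaled dot-product
# attention that the forward above implements, with hypothetical sizes; the mask is
# applied multiplicatively to the scores, matching the module above.
import math
import torch

def scaled_dot_product_attention_sketch(query, key, value, mask=None):
    # query: (B, H, Lq, D), key/value: (B, H, Lv, D), mask: (B, 1, 1, Lv)
    scores = torch.matmul(query, key.transpose(-1, -2)) / math.sqrt(query.shape[-1])
    if mask is not None:
        scores = scores * mask
    weights = torch.softmax(scores, dim=-1)
    return torch.matmul(weights, value)  # (B, H, Lq, D)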
def forward(self, input_ids, segment_ids):
    """Look up word and segment embeddings by index and sum them.

    Returns:
        (word_embedding, None) where word_embedding has shape (batch_size, seq_length, dim).
    """
    batch_size, seq_length = input_ids.shape
    input_ids = reshape_tensor(input_ids, [-1])
    segment_ids = reshape_tensor(segment_ids, [-1])
    word_embedding = self.word_embeddings[input_ids]
    segment_embedding = self.segments_embedding[segment_ids]
    word_embedding = word_embedding + segment_embedding
    word_embedding = reshape_tensor(word_embedding, [batch_size, seq_length, -1])
    return word_embedding, None
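# Sketch (hypothetical sizes) of the table-lookup embedding style used above, where
# self.word_embeddings / self.segments_embedding are indexed directly as tensors.
import torch

vocab_size, segment_types, dim = 100, 2, 8
word_table = torch.rand(vocab_size, dim)
segment_table = torch.rand(segment_types, dim)
input_ids = torch.randint(0, vocab_size, (2 * 5,))       # flattened (batch * seq,)
segment_ids = torch.randint(0, segment_types, (2 * 5,))
embedding = word_table[input_ids] + segment_table[segment_ids]  # (batch * seq, dim)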
def calculate_loss_binary_cls(model_output, answer, log_softmax):
    """Binary-classification loss: log-softmax over two classes, then summed NLL."""
    model_output = log_softmax(model_output)
    model_output = reshape_tensor(model_output, (-1, 2))
    answer = reshape_tensor(answer, (-1,))
    loss = F.nll_loss(model_output, answer, reduction="sum")
    return loss
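# Illustrative call pattern (hypothetical shapes): calculate_loss_binary_cls expects a
# log-softmax callable over the class dimension and sums the NLL, equivalent to the
# following plain-torch computation.
import torch
import torch.nn.functional as F

logits = torch.randn(4, 2)            # (batch, 2) binary-classification outputs
labels = torch.randint(0, 2, (4,))    # (batch,) gold labels in {0, 1}
log_probs = F.log_softmax(logits, dim=-1)
loss = F.nll_loss(log_probs, labels, reduction="sum")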
def forward(self, tensor):
    """Apply the convolution and reshape to the configured output shape."""
    batch_size = tensor.shape[0]
    tensor = self.conv(tensor)
    tensor = reshape_tensor(tensor, [batch_size] + self.output_shape)
    return tensor
def query_pointer(self, embeddings, input_mask):
    """Project token embeddings to start/end pointer logits.

    Returns:
        start_embeddings, end_embeddings: each of shape (batch_size, seq_length).
    """
    batch_size, seq_length, dim = embeddings.shape
    embeddings = self.query_pointor_linear(embeddings)
    # size: batch_size, seq_length, 2
    embeddings = reshape_tensor(embeddings, (batch_size, seq_length, 2))
    embeddings = mask(embeddings, input_mask, -2)
    start_embeddings = embeddings[:, :, 0]
    end_embeddings = embeddings[:, :, 1]
    return start_embeddings, end_embeddings
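# Illustrative follow-up (not from the repo): turning the start/end logits returned by
# query_pointer into a predicted answer span via argmax over the sequence dimension.
import torch

start_logits = torch.randn(2, 512)    # (batch_size, seq_length)
end_logits = torch.randn(2, 512)
start_pred = start_logits.argmax(dim=-1)
end_pred = end_logits.argmax(dim=-1)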
def transpose4score(self, tensor, shape):
    """Reshape and transpose a projected tensor for attention-score computation.

    Args:
        tensor: (batch_size * length, attention_head_num * size_per_head)
        shape: target shape (batch_size, length, attention_head_num, size_per_head)

    Returns:
        Tensor of shape (batch_size, attention_head_num, length, size_per_head).
    """
    tensor = reshape_tensor(tensor, shape)
    tensor = tensor.permute(0, 2, 1, 3)
    return tensor
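# Shape walk-through (hypothetical sizes) of the reshape + permute performed by
# transpose4score: the per-head dimension is moved in front of the sequence dimension.
import torch

batch_size, length, heads, size_per_head = 2, 16, 12, 64
flat = torch.rand(batch_size * length, heads * size_per_head)
t = flat.reshape(batch_size, length, heads, size_per_head).permute(0, 2, 1, 3)
assert t.shape == (batch_size, heads, length, size_per_head)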
def forward(self, query_tensor, value_tensor, attention_mask=None):
    """CNN stack over the query/value interaction matrix.

    Args:
        query_tensor: (batch_size, len, dim)
        value_tensor: (batch_size, len, dim)
        attention_mask: (batch_size, len)

    Returns:
        The projected matrix, plus the per-layer CNN feature maps when
        self.config.visual_cnn is set.
    """
    batch_size = query_tensor.shape[0]
    cnn_datas = []

    # size: batch_size, len, len
    attention_matrix = torch.matmul(query_tensor, value_tensor.permute(0, 2, 1))
    # TODO: apply attention_mask to the interaction matrix
    attention_matrix = torch.unsqueeze(attention_matrix, 1)
    origin_matrix = attention_matrix

    for i in range(self.layer_num):
        attention_matrix = self.conv[i * 3](attention_matrix)
        attention_matrix = self.conv[i * 3 + 1](attention_matrix)
        attention_matrix = self.conv[i * 3 + 2](attention_matrix)
        attention_matrix = torch.relu(attention_matrix)
        # TODO: check whether performance is unchanged without this pooling step
        attention_matrix = self.pools[i](attention_matrix)
        attention_matrix = torch.relu(attention_matrix)
        # residual connection from the original interaction matrix through an adaptor conv
        attention_matrix = attention_matrix + self.adaptor[i](origin_matrix)
        attention_matrix = self.layer_normal[i](attention_matrix)
        cnn_datas.append(attention_matrix)

    attention_matrix = reshape_tensor(attention_matrix, (batch_size, 512, -1))
    attention_matrix = self.linear(attention_matrix)  # size: batch_size, length, 4

    if self.config.visual_cnn:
        return attention_matrix, cnn_datas
    return attention_matrix
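# Sketch (hypothetical sizes) of the pairwise interaction matrix that feeds the CNN
# stack above: one (len, len) similarity map per example, with a channel dimension
# added for the 2-D convolutions.
import torch

query = torch.rand(2, 512, 768)
value = torch.rand(2, 512, 768)
matrix = torch.matmul(query, value.permute(0, 2, 1))  # (batch, len, len)
matrix = matrix.unsqueeze(1)                          # (batch, 1, len, len) = conv input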