def get_initial_model_state(self, batch_input):
    model_state = {}
    model_state["merged_source_global_ids"] = batch_input["merged_source_global_ids"]
    model_state["merged_source_local_ids"] = batch_input["merged_source_local_ids"]
    model_state["source1_local_words_ids"] = batch_input["source1_local_words_ids"]
    model_state["source2_local_words_ids"] = batch_input["source2_local_words_ids"]

    batch_size = batch_input["source1_input_words_ids"].shape[0]
    source1_encoder_output, source2_encoder_output, initial_decoder_hidden = self.encode(batch_input)
    # Initialize the decoder cell state with zeros, on the same device/dtype as the hidden state.
    initial_decoder_cell = initial_decoder_hidden.new_zeros(batch_size, self.decoder_output_dim)
    model_state["decoder_hidden_state"] = initial_decoder_hidden
    model_state["decoder_hidden_cell"] = initial_decoder_cell
    model_state["source1_encoder_output"] = source1_encoder_output
    model_state["source2_encoder_output"] = source2_encoder_output

    # Initial attention of the decoder hidden state over each source's encoder output.
    initial_source1_decoder_attention = self.source1_attention_layer(initial_decoder_hidden, source1_encoder_output)
    initial_source2_decoder_attention = self.source2_attention_layer(initial_decoder_hidden, source2_encoder_output)
    initial_source1_decoder_attention_score = torch.softmax(initial_source1_decoder_attention, -1)
    initial_source2_decoder_attention_score = torch.softmax(initial_source2_decoder_attention, -1)

    # Initial context vectors: attention-weighted sums of the encoder outputs.
    initial_source1_weighted_context = weighted_sum(source1_encoder_output, initial_source1_decoder_attention_score)
    initial_source2_weighted_context = weighted_sum(source2_encoder_output, initial_source2_decoder_attention_score)
    model_state["source1_weighted_context"] = initial_source1_weighted_context
    model_state["source2_weighted_context"] = initial_source2_weighted_context
    return model_state
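# The two-argument weighted_sum used above is a project-local helper not shown
# here. A minimal sketch, assuming AllenNLP-style semantics (attention scores
# over the encoder timesteps are summed out into one context vector per batch
# element); an illustration of the assumed signature, not the original utility.
import torch

def weighted_sum(matrix: torch.Tensor, attention: torch.Tensor) -> torch.Tensor:
    # matrix: (batch, seq_len, dim), attention: (batch, seq_len) -> (batch, dim)
    return attention.unsqueeze(1).bmm(matrix).squeeze(1)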
def forward(self, sent_a, sent_a_mask, sent_b, sent_b_mask):
    """
    Inputs:
        sent_a: [batch_size, seq_a_len, vec_dim]
        sent_a_mask: [batch_size, seq_a_len]
        sent_b: [batch_size, seq_b_len, vec_dim]
        sent_b_mask: [batch_size, seq_b_len]
    Outputs:
        sent_a_att: [batch_size, seq_a_len, vec_dim]
        sent_b_att: [batch_size, seq_b_len, vec_dim]
    """
    # Similarity matrix between every token pair of the two sentences.
    similarity_matrix = torch.matmul(
        sent_a, sent_b.transpose(1, 2).contiguous())  # [batch_size, seq_a, seq_b]

    # Softmax attention weights, ignoring padded positions of the attended sentence.
    sent_a_b_attn = masked_softmax(
        similarity_matrix, sent_b_mask)  # [batch_size, seq_a, seq_b]
    sent_b_a_attn = masked_softmax(
        similarity_matrix.transpose(1, 2).contiguous(),
        sent_a_mask)  # [batch_size, seq_b, seq_a]

    # Attention-weighted representations of the opposite sentence.
    sent_a_att = weighted_sum(sent_b, sent_a_b_attn,
                              sent_a_mask)  # [batch_size, seq_a, vec_dim]
    sent_b_att = weighted_sum(sent_a, sent_b_a_attn,
                              sent_b_mask)  # [batch_size, seq_b, vec_dim]
    return sent_a_att, sent_b_att
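# masked_softmax and the three-argument weighted_sum above are project-local
# helpers. A minimal sketch assuming the common ESIM-style behaviour (softmax
# over the last dimension with padded keys suppressed, then a batched matmul
# whose padded query rows are zeroed); an illustration, not the original code.
import torch

def masked_softmax(tensor: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    # tensor: (batch, len_a, len_b), mask: (batch, len_b) with 1 for real tokens.
    tensor = tensor.masked_fill(mask.unsqueeze(1) == 0, -1e7)
    return torch.softmax(tensor, dim=-1)

def weighted_sum(tensor: torch.Tensor, weights: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    # tensor: (batch, len_b, dim), weights: (batch, len_a, len_b), mask: (batch, len_a).
    weighted = weights.bmm(tensor)                 # (batch, len_a, dim)
    return weighted * mask.unsqueeze(-1).float()   # zero out padded query rows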
def forward(self, premise_batch, premise_mask, hypothesis_batch, hypothesis_mask):
    """
    Args:
        premise_batch: A batch of sequences of vectors representing the premises
            in some NLI task. The batch is assumed to have the size
            (batch, sequences, vector_dim).
        premise_mask: A mask for the sequences in the premise batch, to ignore
            padding data in the sequences during the computation of the attention.
        hypothesis_batch: A batch of sequences of vectors representing the
            hypotheses in some NLI task. The batch is assumed to have the size
            (batch, sequences, vector_dim).
        hypothesis_mask: A mask for the sequences in the hypotheses batch, to
            ignore padding data in the sequences during the computation of the
            attention.

    Returns:
        attended_premises: The sequences of attention vectors for the premises
            in the input batch.
        attended_hypotheses: The sequences of attention vectors for the
            hypotheses in the input batch.
    """
    # Dot product between premises and hypotheses in each sequence of the batch.
    similarity_matrix = premise_batch.bmm(
        hypothesis_batch.transpose(2, 1).contiguous())

    # Softmax attention weights.
    prem_hyp_attn = masked_softmax(similarity_matrix, hypothesis_mask)
    hyp_prem_attn = masked_softmax(similarity_matrix.transpose(1, 2).contiguous(),
                                   premise_mask)

    # Weighted sums of the hypotheses for the premises attention,
    # and vice-versa for the attention of the hypotheses.
    attended_premises = weighted_sum(hypothesis_batch, prem_hyp_attn, premise_mask)
    attended_hypotheses = weighted_sum(premise_batch, hyp_prem_attn, hypothesis_mask)

    return attended_premises, attended_hypotheses
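# A self-contained shape walk-through of the attention computed in the forward
# above, using plain torch ops in place of the project-local masked_softmax /
# weighted_sum helpers (dummy shapes are assumptions, not values from the source).
import torch

batch, prem_len, hyp_len, dim = 2, 5, 7, 4
premises = torch.randn(batch, prem_len, dim)
hypotheses = torch.randn(batch, hyp_len, dim)
hypothesis_mask = torch.ones(batch, hyp_len)

similarity = premises.bmm(hypotheses.transpose(2, 1))                 # (batch, prem_len, hyp_len)
prem_hyp_attn = torch.softmax(
    similarity.masked_fill(hypothesis_mask.unsqueeze(1) == 0, -1e7), dim=-1)
attended_premises = prem_hyp_attn.bmm(hypotheses)                     # (batch, prem_len, dim)
assert attended_premises.shape == (batch, prem_len, dim)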
def decode_step(self, previous_token_ids, model_state):
    # Fetch last timestep values.
    previous_source1_weighted_context = model_state["source1_weighted_context"]
    previous_source2_weighted_context = model_state["source2_weighted_context"]
    previous_decoder_hidden_state = model_state["decoder_hidden_state"]
    previous_decoder_hidden_cell = model_state["decoder_hidden_cell"]
    previous_token_embedding = self.get_target_token_embeddings(previous_token_ids)

    # Update the decoder hidden state for the current timestep.
    current_decoder_input = torch.cat((previous_token_embedding,
                                       previous_source1_weighted_context,
                                       previous_source2_weighted_context), dim=-1)
    decoder_hidden_state, decoder_hidden_cell = self.decoder_cell(
        current_decoder_input,
        (previous_decoder_hidden_state, previous_decoder_hidden_cell))
    if self.flag_use_layernorm:
        decoder_hidden_state = self.decoder_hidden_layernorm(decoder_hidden_state)
        decoder_hidden_cell = self.decoder_cell_layernorm(decoder_hidden_cell)
    model_state["decoder_hidden_state"] = decoder_hidden_state
    model_state["decoder_hidden_cell"] = decoder_hidden_cell

    # Compute the decoder's attention scores over each encoder output.
    source1_encoder_output = model_state["source1_encoder_output"]
    source2_encoder_output = model_state["source2_encoder_output"]
    source1_decoder_attention_output = self.source1_attention_layer(decoder_hidden_state, source1_encoder_output)
    source2_decoder_attention_output = self.source2_attention_layer(decoder_hidden_state, source2_encoder_output)
    source1_decoder_attention_score = torch.softmax(source1_decoder_attention_output, -1)
    source2_decoder_attention_score = torch.softmax(source2_decoder_attention_output, -1)
    model_state["source1_decoder_attention_score"] = source1_decoder_attention_score
    model_state["source2_decoder_attention_score"] = source2_decoder_attention_score

    # Context vectors for source1 and source2: encoder outputs weighted by the
    # decoder attention scores.
    source1_weighted_context = weighted_sum(source1_encoder_output, source1_decoder_attention_score)
    source2_weighted_context = weighted_sum(source2_encoder_output, source2_decoder_attention_score)
    model_state["source1_weighted_context"] = source1_weighted_context
    model_state["source2_weighted_context"] = source2_weighted_context

    # Compute the current gate score.
    gate_input = torch.cat((previous_token_embedding, source1_weighted_context,
                            source2_weighted_context, decoder_hidden_state), dim=-1)
    gate_projected = self.gate_projection_layer(gate_input).squeeze(-1)
    gate_score = torch.sigmoid(gate_projected)
    model_state["gate_score"] = gate_score
    return model_state
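# A minimal, self-contained sketch of the decoder-step wiring above. All
# dimensions are assumptions, and self.decoder_cell is assumed to behave like an
# nn.LSTMCell (it takes an input plus a (hidden, cell) pair and returns both):
# the previous token embedding is concatenated with both source context vectors
# and fed through one recurrent step.
import torch
import torch.nn as nn

batch, emb_dim, enc_dim, dec_dim = 2, 16, 32, 24
decoder_cell = nn.LSTMCell(emb_dim + 2 * enc_dim, dec_dim)

previous_token_embedding = torch.randn(batch, emb_dim)
source1_context = torch.randn(batch, enc_dim)
source2_context = torch.randn(batch, enc_dim)
h_prev, c_prev = torch.zeros(batch, dec_dim), torch.zeros(batch, dec_dim)

decoder_input = torch.cat((previous_token_embedding, source1_context, source2_context), dim=-1)
h, c = decoder_cell(decoder_input, (h_prev, c_prev))  # each (batch, dec_dim)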
def forward(self, premises, premises_mask, hypotheses, hypotheses_mask):
    """
    Params:
        premises: (S, N, H)
        hypotheses: (T, N, H)
        premises_mask: (N, S)
        hypotheses_mask: (N, T)
    Returns:
        new_premises: (S, N, H)
        new_hypotheses: (T, N, H)
        attn_premises: (N, S, T)
        attn_hypotheses: (N, T, S)
    """
    premises = premises.transpose(0, 1)  # (N, S, H)
    logging.debug(f"premises shape: {premises.shape}")
    hypotheses = hypotheses.transpose(0, 1)  # (N, T, H)
    logging.debug(f"hypotheses shape: {hypotheses.shape}")

    attn_premises = torch.bmm(premises, hypotheses.transpose(1, 2))  # (N, S, T)
    attn_hypotheses = attn_premises.transpose(1, 2)  # (N, T, S)
    attn_premises = masked_softmax(attn_premises, premises_mask,
                                   hypotheses_mask)  # (N, S, T)
    attn_hypotheses = masked_softmax(attn_hypotheses, hypotheses_mask,
                                     premises_mask)  # (N, T, S)
    logging.debug(
        f"weight: {attn_premises.shape}, tensor: {hypotheses.shape}")

    new_premises = weighted_sum(attn_premises, hypotheses)  # (N, S, H)
    new_hypotheses = weighted_sum(attn_hypotheses, premises)  # (N, T, H)
    new_premises = new_premises.transpose(0, 1)  # (S, N, H)
    new_hypotheses = new_hypotheses.transpose(0, 1)  # (T, N, H)
    return new_premises, new_hypotheses, attn_premises, attn_hypotheses
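# This variant assumes different helper signatures from the snippets above:
# masked_softmax receives both sequences' masks, and weighted_sum takes the
# attention weights first. A minimal sketch under those assumptions (an
# illustration, not the original utilities).
import torch

def masked_softmax(attn: torch.Tensor, query_mask: torch.Tensor, key_mask: torch.Tensor) -> torch.Tensor:
    # attn: (N, S, T), query_mask: (N, S), key_mask: (N, T); 1 marks real tokens.
    attn = attn.masked_fill(key_mask.unsqueeze(1) == 0, -1e9)  # suppress padded keys
    attn = torch.softmax(attn, dim=-1)
    return attn * query_mask.unsqueeze(-1).float()             # zero rows of padded queries

def weighted_sum(weights: torch.Tensor, tensor: torch.Tensor) -> torch.Tensor:
    # weights: (N, S, T), tensor: (N, T, H) -> (N, S, H)
    return torch.bmm(weights, tensor)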
def draw_class_map(image, class_map, num_classes):
    # Fixed random palette with one color per class; index 0 stays black.
    colors = np.random.RandomState(42).uniform(1 / 3, 1, size=(num_classes + 1, 3))
    colors[0] = 0.0
    colors = torch.tensor(colors, dtype=torch.float, device=class_map.device)

    # Look up each pixel's class color, move channels to NCHW, and upsample to the image size.
    class_map = colors[class_map]
    class_map = class_map.permute(0, 3, 1, 2)
    class_map = F.interpolate(class_map, size=image.size()[2:], mode="nearest")

    # Blend the colorized class map over the image with equal weights.
    return weighted_sum(image, class_map, 0.5)
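# draw_class_map, mix_up, and the age-embedding forward below all assume an
# interpolation-style weighted_sum(left, right, a). One plausible definition,
# consistent with the mixup usage that follows (an assumption, not the original
# helper):
import torch

def weighted_sum(left: torch.Tensor, right: torch.Tensor, a) -> torch.Tensor:
    # Convex combination; `a` may be a scalar or a tensor broadcastable to `left`.
    return a * left + (1 - a) * right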
def mix_up(left, right, a):
    assert len(left) == len(right)
    assert all(l.size() == r.size() for l, r in zip(left, right))

    # Sample one mixing coefficient per example from Beta(a, a) and fold it so
    # that lam >= 0.5, i.e. the "left" example always dominates.
    lam = torch.distributions.Beta(a, a).sample(
        (left[0].size(0), )).to(left[0].device)
    lam = torch.max(lam, 1 - lam)

    # Blend each pair of tensors, reshaping lam to broadcast over the non-batch dims.
    return [
        weighted_sum(l, r, a=lam.view(lam.size(0), *[1 for _ in range(l.dim() - 1)]))
        for l, r in zip(left, right)
    ]
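# Hypothetical call to mix_up (shapes assumed; relies on mix_up above and the
# weighted_sum sketch): mix a batch of images and their one-hot targets with a
# Beta(0.2, 0.2)-sampled coefficient shared across both lists.
import torch

images_a = torch.randn(8, 3, 32, 32)
images_b = torch.randn(8, 3, 32, 32)
targets_a = torch.eye(10)[torch.randint(10, (8,))]
targets_b = torch.eye(10)[torch.randint(10, (8,))]
mixed_images, mixed_targets = mix_up([images_a, targets_a], [images_b, targets_b], a=0.2)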
def forward(self, input):
    # Normalize age to roughly [0, 1] and impute missing values with 0 so the arithmetic stays finite.
    age = (input["age"] / 100.0).unsqueeze(1)
    age_is_nan = torch.isnan(age)
    age[age_is_nan] = 0.0

    # Where age is missing, replace both interpolation endpoints with the learned NaN embedding.
    age_0 = torch.where(age_is_nan, self.age_nan, self.age_0)
    age_1 = torch.where(age_is_nan, self.age_nan, self.age_1)
    age = weighted_sum(age_0, age_1, age)

    sex = self.sex(input["sex"])
    site = self.site(input["site"])

    input = torch.cat([age, sex, site], 1)
    input = self.output(input)
    return input
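# Worked illustration (dimensions and stand-in parameters assumed) of the NaN
# handling above: where age is missing, both interpolation endpoints are swapped
# for the learned age_nan vector, so the imputed zero weight cannot affect the result.
import torch

age = torch.tensor([[0.34], [float("nan")]])        # (batch, 1); second entry missing
age_is_nan = torch.isnan(age)
age[age_is_nan] = 0.0
age_nan = torch.full((1, 4), 0.5)                   # stand-ins for self.age_nan / self.age_0 / self.age_1
age_0, age_1 = torch.zeros(1, 4), torch.ones(1, 4)
endpoint_0 = torch.where(age_is_nan, age_nan, age_0)  # (batch, 4) by broadcasting
endpoint_1 = torch.where(age_is_nan, age_nan, age_1)
blended = age * endpoint_0 + (1.0 - age) * endpoint_1  # same convention as the weighted_sum sketch above
# blended[1] equals age_nan regardless of the zeroed weight.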
def forward(self, sentences1, sentences2):
    """
    sentences1: [batch, max_len]
    sentences2: [batch, max_len]
    """
    # Masks: 1 for real tokens, 0 for padding.
    sentences1_mask = (sentences1 != self.padding_idx).long().to(
        self.device)  # [batch_size, max_len]
    sentences2_mask = (sentences2 != self.padding_idx).long().to(
        self.device)  # [batch_size, max_len]

    # Input encoding.
    sentences1_emb = self.emb(sentences1)  # [batch_size, max_len, dim]
    sentences2_emb = self.emb(sentences2)  # [batch_size, max_len, dim]
    sentences1_len = torch.sum(sentences1_mask, dim=-1).view(-1)  # [batch_size]
    sentences2_len = torch.sum(sentences2_mask, dim=-1).view(-1)  # [batch_size]

    # Encoder.
    s1_encoded = self.encoder_layer(
        sentences1_emb, sentences1_len)  # [batch_size, max_len_q1, dim]
    s2_encoded = self.encoder_layer(
        sentences2_emb, sentences2_len)  # [batch_size, max_len_q2, dim]

    # Local inference: e_ij = a_i^T b_j  (Eq. 11)
    similarity_matrix = s1_encoded.bmm(
        s2_encoded.transpose(2, 1).contiguous())  # [batch_size, max_len_q1, max_len_q2]
    s1_s2_atten = masked_softmax(
        similarity_matrix, sentences2_mask)  # [batch_size, max_len_q1, max_len_q2]
    s2_s1_atten = masked_softmax(
        similarity_matrix.transpose(2, 1).contiguous(),
        sentences1_mask)  # [batch_size, max_len_q2, max_len_q1]

    # a~_i = sum_j softmax(e_ij) * b_j, and symmetrically for b~_j: each sentence
    # attends over the *other* sentence's encoding.
    a_hat = weighted_sum(s2_encoded, s1_s2_atten,
                         sentences1_mask)  # [batch_size, max_len_q1, dim]
    b_hat = weighted_sum(s1_encoded, s2_s1_atten,
                         sentences2_mask)  # [batch_size, max_len_q2, dim]

    # Enhancement of local inference information:
    # ma = [a¯; a~; a¯ − a~; a¯ * a~]; mb = [b¯; b~; b¯ − b~; b¯ * b~]
    m_a = torch.cat(
        [s1_encoded, a_hat, s1_encoded - a_hat, s1_encoded * a_hat],
        dim=-1)  # [batch_size, max_len_q1, 4 * dim]
    m_b = torch.cat(
        [s2_encoded, b_hat, s2_encoded - b_hat, s2_encoded * b_hat],
        dim=-1)  # [batch_size, max_len_q2, 4 * dim]

    # 3.3 Inference composition.
    s1_projected = self.projection(m_a)  # [batch_size, max_len_q1, dim]
    s2_projected = self.projection(m_b)  # [batch_size, max_len_q2, dim]
    v_a = self.composition_layer(
        s1_projected, sentences1_len)  # [batch_size, max_len_q1, dim]
    v_b = self.composition_layer(
        s2_projected, sentences2_len)  # [batch_size, max_len_q2, dim]

    # Masked average pooling and masked max pooling over the sequence dimension.
    v_a_avg = torch.sum(v_a * sentences1_mask.unsqueeze(1).transpose(2, 1), dim=1) \
        / torch.sum(sentences1_mask, dim=1, keepdim=True)
    v_b_avg = torch.sum(v_b * sentences2_mask.unsqueeze(1).transpose(2, 1), dim=1) \
        / torch.sum(sentences2_mask, dim=1, keepdim=True)
    v_a_max, _ = replace_masked(v_a, sentences1_mask, -1e7).max(dim=1)  # [batch_size, dim]
    v_b_max, _ = replace_masked(v_b, sentences2_mask, -1e7).max(dim=1)
    v = torch.cat([v_a_avg, v_a_max, v_b_avg, v_b_max], dim=1)  # [batch_size, dim * 4]

    # Prediction.
    logits = self.predict_fc(v)
    return logits
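# replace_masked above is a project-local helper. A plausible sketch (an
# assumption, not the original utility) of what the max pooling needs: padded
# positions are overwritten with a large negative value so they never win the max.
import torch

def replace_masked(tensor: torch.Tensor, mask: torch.Tensor, value: float) -> torch.Tensor:
    # tensor: (batch, seq_len, dim), mask: (batch, seq_len) with 1 for real tokens.
    return tensor.masked_fill(mask.unsqueeze(-1) == 0, value)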