Ejemplo n.º 1
0
    def get_initial_model_state(self, batch_input):
        """Build the decoder's initial state dictionary for a batch.

        Copies the source-id tensors that decoding needs from ``batch_input``,
        encodes both sources with ``self.encode``, starts the decoder cell state
        at zero, and computes the first attention-weighted context vector over
        each source's encoder output.

        Args:
            batch_input: Dict of batch tensors. Must contain
                "merged_source_global_ids", "merged_source_local_ids",
                "source1_local_words_ids", "source2_local_words_ids" and
                "source1_input_words_ids" (whose first dimension is used as the
                batch size), plus whatever ``self.encode`` reads.

        Returns:
            dict: The model state with the copied id tensors,
            "decoder_hidden_state", "decoder_hidden_cell",
            "source1_encoder_output", "source2_encoder_output",
            "source1_weighted_context" and "source2_weighted_context".
        """
        model_state = {}
        model_state["merged_source_global_ids"] = batch_input["merged_source_global_ids"]
        model_state["merged_source_local_ids"] = batch_input["merged_source_local_ids"]
        model_state["source1_local_words_ids"] = batch_input["source1_local_words_ids"]
        model_state["source2_local_words_ids"] = batch_input["source2_local_words_ids"]

        batch_size = batch_input["source1_input_words_ids"].shape[0]

        source1_encoder_output, source2_encoder_output, initial_decoder_hidden = self.encode(batch_input)
        # Cell state starts at zero; new_zeros keeps the hidden state's device and dtype.
        initial_decoder_cell = initial_decoder_hidden.new_zeros(batch_size, self.decoder_output_dim)

        model_state["decoder_hidden_state"] = initial_decoder_hidden
        model_state["decoder_hidden_cell"] = initial_decoder_cell
        model_state["source1_encoder_output"] = source1_encoder_output
        model_state["source2_encoder_output"] = source2_encoder_output

        # Attend over every encoder position, exactly as decode_step does for
        # later timesteps, so step 0 and step t use the same attention inputs.
        # NOTE(review): no padding mask is applied before the softmax — confirm
        # that attending to padded source positions is intended.
        initial_source1_decoder_attention = self.source1_attention_layer(initial_decoder_hidden, source1_encoder_output)
        initial_source2_decoder_attention = self.source2_attention_layer(initial_decoder_hidden, source2_encoder_output)

        initial_source1_decoder_attention_score = torch.softmax(initial_source1_decoder_attention, -1)
        initial_source2_decoder_attention_score = torch.softmax(initial_source2_decoder_attention, -1)
        initial_source1_weighted_context = weighted_sum(source1_encoder_output, initial_source1_decoder_attention_score)
        initial_source2_weighted_context = weighted_sum(source2_encoder_output, initial_source2_decoder_attention_score)
        model_state["source1_weighted_context"] = initial_source1_weighted_context
        model_state["source2_weighted_context"] = initial_source2_weighted_context

        return model_state
Ejemplo n.º 2
0
 def forward(self, sent_a, sent_a_mask, sent_b, sent_b_mask):
     """Soft-align two sentence batches against each other.

     Args:
         sent_a: [batch_size, seq_a_len, vec_dim]
         sent_a_mask: [batch_size, seq_a_len]
         sent_b: [batch_size, seq_b_len, vec_dim]
         sent_b_mask: [batch_size, seq_b_len]

     Returns:
         sent_a_att: [batch_size, seq_a_len, vec_dim] -- for every position
             of sent_a, sent_b's vectors combined by ``weighted_sum`` using
             the sent_a -> sent_b attention weights.
         sent_b_att: [batch_size, seq_b_len, vec_dim] -- the symmetric
             result: sent_a's vectors combined using the sent_b -> sent_a
             attention weights.
     """
     # similarity matrix: dot product of every sent_a vector with every sent_b vector
     similarity_matrix = torch.matmul(
         sent_a,
         sent_b.transpose(1, 2).contiguous())  # [batch_size, seq_a, seq_b]
     # Normalise over sent_b positions, ignoring sent_b padding.
     sent_a_b_attn = masked_softmax(
         similarity_matrix, sent_b_mask)  # [batch_size, seq_a, seq_b]
     # Normalise over sent_a positions, ignoring sent_a padding.
     sent_b_a_attn = masked_softmax(
         similarity_matrix.transpose(1, 2).contiguous(),
         sent_a_mask)  # [batch_size, seq_b, seq_a]
     # Each sentence is re-expressed in terms of the OTHER sentence's vectors;
     # the trailing mask is passed so weighted_sum can zero padded query rows
     # (presumably — depends on weighted_sum's implementation).
     sent_a_att = weighted_sum(sent_b, sent_a_b_attn,
                               sent_a_mask)  # [batch_size, seq_a, vec_dim]
     sent_b_att = weighted_sum(sent_a, sent_b_a_attn,
                               sent_b_mask)  # [batch_size, seq_b, vec_dim]
     return sent_a_att, sent_b_att
Ejemplo n.º 3
0
 def forward(self, premise_batch, premise_mask, hypothesis_batch, hypothesis_mask):
     """Compute soft-alignment attention between premises and hypotheses.

     Each premise position attends over the hypothesis positions and each
     hypothesis position attends over the premise positions, using
     dot-product similarity scores.

     Args:
         premise_batch: Premise vectors, assumed (batch, premise_len, vector_dim).
         premise_mask: Padding mask for the premise sequences; padded
             positions are excluded from the attention.
         hypothesis_batch: Hypothesis vectors, assumed
             (batch, hypothesis_len, vector_dim).
         hypothesis_mask: Padding mask for the hypothesis sequences.

     Returns:
         A pair ``(attended_premises, attended_hypotheses)``: for every premise
         position, the attention-weighted combination of hypothesis vectors,
         and for every hypothesis position, the attention-weighted combination
         of premise vectors.
     """
     # scores[b, i, j] = <premise_i, hypothesis_j>
     hypotheses_t = hypothesis_batch.transpose(2, 1).contiguous()
     scores = premise_batch.bmm(hypotheses_t)

     # Normalise each direction over the other side's non-padded positions.
     premise_to_hypothesis = masked_softmax(scores, hypothesis_mask)
     scores_t = scores.transpose(1, 2).contiguous()
     hypothesis_to_premise = masked_softmax(scores_t, premise_mask)

     # Premises are rebuilt from hypotheses and hypotheses from premises.
     attended_premises = weighted_sum(hypothesis_batch, premise_to_hypothesis, premise_mask)
     attended_hypotheses = weighted_sum(premise_batch, hypothesis_to_premise, hypothesis_mask)
     return attended_premises, attended_hypotheses
Ejemplo n.º 4
0
    def decode_step(self, previous_token_ids, model_state):
        """Advance the decoder by one timestep, updating ``model_state`` in place.

        The previous token's embedding and last step's two context vectors feed
        ``self.decoder_cell``. The new hidden/cell states optionally go through
        the layer-norm modules. The new hidden state then attends over both
        encoder outputs, and a sigmoid gate score is computed.

        Args:
            previous_token_ids: Ids of the previously emitted tokens, passed to
                ``self.get_target_token_embeddings``.
            model_state: State dict holding at least "source1_weighted_context",
                "source2_weighted_context", "decoder_hidden_state",
                "decoder_hidden_cell", "source1_encoder_output" and
                "source2_encoder_output". It is mutated in place.

        Returns:
            The same ``model_state`` dict. "decoder_hidden_state",
            "decoder_hidden_cell", "source1_decoder_attention_score",
            "source2_decoder_attention_score", "source1_weighted_context",
            "source2_weighted_context" and "gate_score" are (re)written.
        """
        state = model_state
        last_context1 = state["source1_weighted_context"]
        last_context2 = state["source2_weighted_context"]
        last_hidden = state["decoder_hidden_state"]
        last_cell = state["decoder_hidden_cell"]
        token_embedding = self.get_target_token_embeddings(previous_token_ids)

        # 1) Recurrent update from the token embedding plus last step's contexts.
        step_input = torch.cat((token_embedding, last_context1, last_context2), dim=-1)
        hidden, cell = self.decoder_cell(step_input, (last_hidden, last_cell))
        if self.flag_use_layernorm:
            hidden = self.decoder_hidden_layernorm(hidden)
            cell = self.decoder_cell_layernorm(cell)
        state["decoder_hidden_state"] = hidden
        state["decoder_hidden_cell"] = cell

        # 2) Attention of the new hidden state over each encoder output.
        encoded1 = state["source1_encoder_output"]
        encoded2 = state["source2_encoder_output"]
        score1 = torch.softmax(self.source1_attention_layer(hidden, encoded1), -1)
        score2 = torch.softmax(self.source2_attention_layer(hidden, encoded2), -1)
        state["source1_decoder_attention_score"] = score1
        state["source2_decoder_attention_score"] = score2

        # 3) Context vectors: encoder outputs combined by the attention scores.
        context1 = weighted_sum(encoded1, score1)
        context2 = weighted_sum(encoded2, score2)
        state["source1_weighted_context"] = context1
        state["source2_weighted_context"] = context2

        # 4) Gate score from the embedding, both contexts and the hidden state.
        gate_features = torch.cat((token_embedding, context1, context2, hidden), dim=-1)
        gate_logit = self.gate_projection_layer(gate_features).squeeze(-1)
        state["gate_score"] = torch.sigmoid(gate_logit)

        return state
Ejemplo n.º 5
0
    def forward(self, premises, premises_mask, hypotheses, hypotheses_mask):
        """Cross-attend premises and hypotheses given in sequence-first layout.

        params
            premises: (S, N, H)
            hypotheses: (T, N, H)
            premises_mask: (N, S)
            hypotheses_mask: (N, T)
        return
            A 4-tuple ``(new_premises, new_hypotheses, attn_premises, attn_hypotheses)``:
            new_premises: (S, N, H), premises re-expressed via attention over hypotheses
            new_hypotheses: (T, N, H), hypotheses re-expressed via attention over premises
            attn_premises: (N, S, T), masked-softmax premise -> hypothesis weights
            attn_hypotheses: (N, T, S), masked-softmax hypothesis -> premise weights
        """
        # Switch to batch-first so bmm batches over N.
        premises = premises.transpose(0, 1)  # (N, S, H)
        # Lazy %-args: the message string is only built when DEBUG is enabled.
        logging.debug("premises shape: %s", premises.shape)
        hypotheses = hypotheses.transpose(0, 1)  # (N, T, H)
        logging.debug("hypotheses shape: %s", hypotheses.shape)

        # Dot-product scores; the hypothesis-side scores are the same matrix transposed.
        attn_premises = torch.bmm(premises, hypotheses.transpose(1, 2))  # (N, S, T)
        attn_hypotheses = attn_premises.transpose(1, 2)  # (N, T, S)

        attn_premises = masked_softmax(attn_premises, premises_mask,
                                       hypotheses_mask)  # (N, S, T)
        attn_hypotheses = masked_softmax(attn_hypotheses, hypotheses_mask,
                                         premises_mask)  # (N, T, S)

        logging.debug("weight: %s, tensor: %s", attn_premises.shape, hypotheses.shape)
        new_premises = weighted_sum(attn_premises, hypotheses)  # (N, S, H)
        new_hypotheses = weighted_sum(attn_hypotheses, premises)  # (N, T, H)

        # Back to sequence-first layout.
        new_premises = new_premises.transpose(0, 1)  # (S, N, H)
        new_hypotheses = new_hypotheses.transpose(0, 1)  # (T, N, H)

        return new_premises, new_hypotheses, attn_premises, attn_hypotheses
Ejemplo n.º 6
0
def draw_class_map(image, class_map, num_classes):
    """Overlay a colour-coded class map onto an image batch.

    Every class id in ``[0, num_classes]`` gets a colour drawn from a
    RandomState seeded with 42, so the palette is the same on every call.
    Class 0 is colour (0, 0, 0). The coloured map is resized to the image's
    spatial size with nearest-neighbour interpolation. It is then combined with
    the image through ``weighted_sum(image, colours, 0.5)`` (presumably an even
    blend — depends on weighted_sum).

    Args:
        image: Image batch; its dims from index 2 onward are used as the
            target spatial size.
        class_map: Integer class ids; must be 3-D (the permute below needs
            the indexed result to be 4-D).
        num_classes: Number of non-background classes.
    """
    rng = np.random.RandomState(42)
    palette = rng.uniform(1 / 3, 1, size=(num_classes + 1, 3))
    palette[0] = 0.0  # class 0 -> (0, 0, 0)
    palette = torch.tensor(palette, dtype=torch.float, device=class_map.device)

    # Indexing appends a trailing colour axis; move it to position 1.
    colored = palette[class_map].permute(0, 3, 1, 2)
    target_size = image.size()[2:]
    colored = F.interpolate(colored, size=target_size, mode="nearest")

    return weighted_sum(image, colored, 0.5)
Ejemplo n.º 7
0
def mix_up(left, right, a):
    """Mix paired tensors with one per-sample Beta(a, a) coefficient.

    One coefficient ``lam`` is sampled per sample of the first dimension of
    ``left[0]``. It is folded to ``max(lam, 1 - lam)`` so it is at least 0.5,
    and the same coefficients are used for every (left, right) pair.

    Args:
        left: Sequence of tensors.
        right: Sequence of tensors, same length and per-item sizes as ``left``.
        a: Beta distribution concentration (used for both parameters).

    Returns:
        List with ``weighted_sum(l, r, a=lam)`` for each pair, where ``lam``
        is reshaped to broadcast over every non-leading dimension of ``l``.
    """
    assert len(left) == len(right)
    assert all(x.size() == y.size() for x, y in zip(left, right))

    reference = left[0]
    beta = torch.distributions.Beta(a, a)
    lam = beta.sample((reference.size(0), )).to(reference.device)
    lam = torch.max(lam, 1 - lam)

    mixed = []
    for x, y in zip(left, right):
        # (batch,) -> (batch, 1, 1, ...) matching x's rank.
        broadcast_shape = (lam.size(0), ) + (1, ) * (x.dim() - 1)
        mixed.append(weighted_sum(x, y, a=lam.view(*broadcast_shape)))
    return mixed
Ejemplo n.º 8
0
    def forward(self, input):
        """Embed age, sex and site, concatenate them, and apply ``self.output``.

        Age is divided by 100 and given a trailing axis. NaN ages are set to 0
        and both of their endpoints are swapped for ``self.age_nan``. The age
        feature is ``weighted_sum`` of the two endpoints, weighted by the scaled
        age.

        Args:
            input: Mapping with "age", "sex" and "site" entries.

        Returns:
            ``self.output`` applied to the features concatenated along dim 1.
        """
        age = (input["age"] / 100.0).unsqueeze(1)
        missing = torch.isnan(age)
        age = age.masked_fill(missing, 0.0)

        # Missing ages use age_nan for both endpoints.
        low = torch.where(missing, self.age_nan, self.age_0)
        high = torch.where(missing, self.age_nan, self.age_1)
        age_features = weighted_sum(low, high, age)

        sex_features = self.sex(input["sex"])
        site_features = self.site(input["site"])

        features = torch.cat([age_features, sex_features, site_features], 1)
        return self.output(features)
Ejemplo n.º 9
0
    def forward(self, sentences1, sentences2):
        """Run ESIM over a batch of sentence pairs and return classifier logits.

        Args:
            sentences1: Token ids, [batch, max_len]; positions equal to
                ``self.padding_idx`` are treated as padding.
            sentences2: Token ids, [batch, max_len], same convention.

        Returns:
            Logits from ``self.predict_fc``. Its input concatenates the average-
            and max-pooled composition outputs of both sentences.
        """
        # get mask: 1 for real tokens, 0 for padding
        sentences1_mask = (sentences1 != self.padding_idx).long().to(
            self.device)  # [batch_size, max_len]
        sentences2_mask = (sentences2 != self.padding_idx).long().to(
            self.device)  # [batch_size, max_len]
        # input encoding
        sentences1_emb = self.emb(sentences1)  # [batch_size, max_len, dim]
        sentences2_emb = self.emb(sentences2)  # [batch_size, max_len, dim]
        sentences1_len = torch.sum(sentences1_mask,
                                   dim=-1).view(-1)  # [batch_size]
        sentences2_len = torch.sum(sentences2_mask,
                                   dim=-1).view(-1)  # [batch_size]
        # encoder
        s1_encoded = self.encoder_layer(
            sentences1_emb, sentences1_len)  # [batch_size, max_len_q1, dim]
        s2_encoded = self.encoder_layer(
            sentences2_emb, sentences2_len)  # [batch_size, max_len_q2, dim]
        # local inference
        # e_ij = a_i^T b_j  (11)
        similarity_matrix = s1_encoded.bmm(
            s2_encoded.transpose(
                2, 1).contiguous())  # [batch_size, max_len_q1, max_len_q2]
        s1_s2_atten = masked_softmax(
            similarity_matrix,
            sentences2_mask)  # [batch_size, max_len_q1, max_len_q2]
        s2_s1_atten = masked_softmax(
            similarity_matrix.transpose(2, 1).contiguous(),
            sentences1_mask)  # [batch_size, max_len_q2, max_len_q1]

        # Soft alignment (12)-(13): every token is re-expressed through the
        # OTHER sentence's encodings:
        #   a~_i = sum_j softmax_j(e_ij) * b_j,   b~_j = sum_i softmax_i(e_ij) * a_i
        # The last dim of s1_s2_atten ranges over sentence-2 positions, so it
        # must weight s2_encoded; likewise s2_s1_atten weights s1_encoded.
        a_hat = weighted_sum(s2_encoded, s1_s2_atten,
                             sentences1_mask)  # [batch_size, max_len_q1, dim]
        b_hat = weighted_sum(s1_encoded, s2_s1_atten,
                             sentences2_mask)  # [batch_size, max_len_q2, dim]

        # Enhancement of local inference information (14)-(15):
        #   m_a = [a; a~; a - a~; a * a~]
        #   m_b = [b; b~; b - b~; b * b~]
        m_a = torch.cat(
            [s1_encoded, a_hat, s1_encoded - a_hat, s1_encoded * a_hat],
            dim=-1)  # [batch_size, max_len_q1, 4 * dim]
        m_b = torch.cat(
            [s2_encoded, b_hat, s2_encoded - b_hat, s2_encoded * b_hat],
            dim=-1)  # [batch_size, max_len_q2, 4 * dim]

        # 3.3 Inference Composition
        s1_projected = self.projection(m_a)  # [batch_size, max_len_q1, dim]
        s2_projected = self.projection(m_b)  # [batch_size, max_len_q2, dim]
        v_a = self.composition_layer(
            s1_projected, sentences1_len)  # [batch_size, max_len_q1, dim]
        v_b = self.composition_layer(
            s2_projected, sentences2_len)  # [batch_size, max_len_q2, dim]

        # Average pooling over real (non-padding) tokens only; the mask gets a
        # trailing axis ([batch_size, max_len, 1]) to broadcast over dim.
        v_a_avg = torch.sum(v_a * sentences1_mask.unsqueeze(-1), dim=1) \
            / torch.sum(sentences1_mask, dim=1, keepdim=True)  # [batch_size, dim]
        v_b_avg = torch.sum(v_b * sentences2_mask.unsqueeze(-1), dim=1) \
            / torch.sum(sentences2_mask, dim=1, keepdim=True)  # [batch_size, dim]
        # Max pooling; replace_masked presumably fills padded positions with
        # -1e7 so they cannot win the max.
        v_a_max, _ = replace_masked(v_a, sentences1_mask,
                                    -1e7).max(dim=1)  # [batch_size, dim]
        v_b_max, _ = replace_masked(v_b, sentences2_mask,
                                    -1e7).max(dim=1)  # [batch_size, dim]

        v = torch.cat([v_a_avg, v_a_max, v_b_avg, v_b_max],
                      dim=1)  # [batch_size, dim * 4]
        # predict
        logits = self.predict_fc(v)
        return logits