Exemplo n.º 1
0
    def forward(self, batch):
        """Compute per-token tag logits for a batch of sentences.

        Args:
            batch: mapping that provides
                ``'words'``      - word ids fed to ``self.word_embeds``;
                ``'words_lens'`` - tensor of sentence lengths;
                ``'chars'``      - iterable of ``(chars, char_len)`` pairs,
                                   one pair per sentence, holding the
                                   character ids of each word and the number
                                   of characters in each word;
                ``'caps'``       - capitalisation features, read only when
                                   ``self.config.is_caps`` is set.

        Returns:
            Logits reshaped to ``(b, t_k, -1)`` where ``b`` and ``t_k`` are
            the first two dimensions of the sentence-level LSTM output.
        """
        sentence = batch['words']
        lengths = batch['words_lens']

        # Plain Python int so the padding amount passed to F.pad below is an
        # ordinary integer rather than a 0-dim tensor.
        max_length = int(torch.max(lengths))
        word_embed = self.word_embeds(sentence)

        char_emb = []
        for chars, char_len in batch['chars']:
            seq_embed = self.char_embeds(chars)

            # pack_padded_sequence wants the longest word first; remember how
            # to undo the sort afterwards.
            seq_lengths, sort_idx = torch.sort(char_len, descending=True)
            _, unsort_idx = torch.sort(sort_idx)

            # The lengths tensor must live on the CPU for packing.
            packed = pack_padded_sequence(seq_embed[sort_idx],
                                          seq_lengths.cpu(),
                                          batch_first=True)
            output, _ = self.lstm_char(packed)
            lstm_feats, _ = pad_packed_sequence(output, batch_first=True)
            lstm_feats = lstm_feats.contiguous()
            n_words, n_steps, _ = lstm_feats.size()

            # Split the feature axis per direction: 0 is fwd and 1 is bwd.
            seq_rep = lstm_feats.view(n_words, n_steps, 2, -1)

            # The forward direction has read the whole word only at its last
            # real character, while the backward direction has read the
            # whole word at position 0.  (The previous code read them the
            # other way round, so each half had seen a single character.)
            # NOTE: checkpoints trained with the old indexing may need to be
            # retrained, since the character features change.
            last_idx = char_len - 1
            seq_rep_fwd = seq_rep[unsort_idx, last_idx, 0]
            seq_rep_bwd = seq_rep[unsort_idx, 0, 1]

            seq_out = torch.cat([seq_rep_fwd, seq_rep_bwd], 1)
            # fill up the dummy char embedding for padding
            seq_out = F.pad(seq_out, (0, 0, 0, max_length - seq_out.size(0)))
            char_emb.append(seq_out.unsqueeze(0))

        char_emb = torch.cat(char_emb, 0)  # b x n x c_dim

        features = [char_emb, word_embed]
        if self.config.is_caps:
            features.append(self.caps_embeds(batch['caps']))
        word_embed = self.dropout(torch.cat(features, 2))

        mask = get_mask(lengths, self.config.is_cuda)
        word_embed_proj = self.tx_proj(word_embed)
        lstm_feats = self.lstm(word_embed_proj, mask.unsqueeze(-1))

        b, t_k, d = list(lstm_feats.size())

        # Flatten tokens so the feed-forward layers see one row per token,
        # then restore the (batch, time) layout.
        h = self.hidden_layer(lstm_feats.view(-1, d))
        h = self.tanh_layer(h)
        logits = self.hidden2tag(h)
        return logits.view(b, t_k, -1)
Exemplo n.º 2
0
    def viterbi_decode_batch(self, emissions, lengths):
        """Decode the highest-scoring tag sequence for every batch row.

        Only tags with an index below ``self.start_tag`` take part in the
        search; ``self.start_tag`` and ``self.end_tag`` are used solely for
        the entry and exit transitions.

        Args:
            emissions: per-token tag scores indexed as
                ``emissions[row, position]``; the remaining axis is broadcast
                against ``self.transitions[:self.start_tag, :self.start_tag]``.
            lengths: handed to ``get_mask`` to obtain the validity mask
                (entries equal to 1 mark real tokens).

        Returns:
            ``(sentence_score, all_labels)`` where ``sentence_score`` holds,
            per row, the best end score at the last valid position and
            ``all_labels`` is a ``(rows, seq_len)`` long tensor of decoded
            tags with ``Constants.TAG_PAD_ID`` at masked positions.
        """
        # NOTE(review): other call sites pass ``self.config.is_cuda`` as the
        # second argument of get_mask, this one passes the whole config
        # object - confirm which form get_mask expects.
        mask = get_mask(lengths, self.config)
        seq_len = emissions.shape[1]

        # The transition slices never change inside the loops; take them once.
        from_start = self.transitions[
            self.start_tag, :self.start_tag].unsqueeze(0)
        to_end = self.transitions[:self.start_tag,
                                  self.end_tag].unsqueeze(0)
        tag_to_tag = self.transitions[:self.start_tag, :self.
                                      start_tag].unsqueeze(0)

        log_prob = emissions[:, 0].clone()
        log_prob += from_start

        # end_scores[t]: best score of a path that stops at position t.
        end_scores = [(log_prob + to_end).unsqueeze(1)]
        # backpointers[t]: best previous tag for every tag at position t + 1.
        backpointers = []

        for idx in range(1, seq_len):
            # score[row, prev_tag, tag]
            score = (emissions[:, idx].unsqueeze(1) + tag_to_tag +
                     log_prob.unsqueeze(2))
            log_prob, best_prev = torch.max(score, 1)
            backpointers.append(best_prev)
            end_scores.append((log_prob + to_end).unsqueeze(1))

        best_scores = torch.cat(end_scores, 1).float()
        max_scores, max_indices_from_scores = torch.max(best_scores, 2)

        # Build the helper scalars on the device the decoded indices live on
        # instead of relying on a configuration flag.
        device = max_indices_from_scores.device
        valid_index_tensor = torch.tensor(0, dtype=torch.long, device=device)
        padding_tensor = torch.tensor(Constants.TAG_PAD_ID,
                                      dtype=torch.long,
                                      device=device)

        labels = torch.where(mask[:, seq_len - 1] != 1.0, padding_tensor,
                             max_indices_from_scores[:, seq_len - 1]).long()
        # Labels are collected back-to-front and put in order once at the end.
        labels_back_to_front = [labels]

        for idx in range(seq_len - 2, -1, -1):
            # A padded successor has no usable tag; look up slot 0 instead so
            # gather stays in range (the result is discarded just below).
            indices_for_lookup = torch.where(
                labels == Constants.TAG_PAD_ID, valid_index_tensor, labels)

            indices_from_prev_pos = backpointers[idx].gather(
                1, indices_for_lookup.view(-1, 1)).squeeze(1)
            indices_from_prev_pos = torch.where(mask[:, idx + 1] != 1.0,
                                                padding_tensor,
                                                indices_from_prev_pos)

            # When the next position is padding, this position is where the
            # path ends, so its tag comes from the best end score instead.
            indices_from_max_scores = max_indices_from_scores[:, idx]
            indices_from_max_scores = torch.where(mask[:, idx + 1] == 1.0,
                                                  padding_tensor,
                                                  indices_from_max_scores)

            labels = torch.where(
                indices_from_max_scores == Constants.TAG_PAD_ID,
                indices_from_prev_pos, indices_from_max_scores)

            # Set to ignore_index if present state is not valid.
            labels = torch.where(mask[:, idx] != 1.0, padding_tensor,
                                 labels).long()
            labels_back_to_front.append(labels)

        all_labels = torch.stack(labels_back_to_front[::-1], 1)

        last_tag_indices = mask.sum(1, dtype=torch.long) - 1
        sentence_score = max_scores.gather(
            1, last_tag_indices.view(-1, 1)).squeeze(1)

        return sentence_score, all_labels
Exemplo n.º 3
0
 def predict(self, emissions, lengths):
     """Return the Viterbi-decoded tags for a batch of emissions.

     ``viterbi_decode_batch(emissions, lengths)`` takes the sequence lengths
     and derives the padding mask from them itself, so ``lengths`` is
     forwarded as-is rather than a pre-built mask.

     Args:
         emissions: per-token tag scores handed to the CRF decoder.
         lengths: sequence lengths handed to the CRF decoder.

     Returns:
         The second element of the decoder's result (the predicted tags);
         the accompanying scores are discarded.
     """
     _, pred = self.crf.viterbi_decode_batch(emissions, lengths)
     return pred