Example #1
 def sampled_segments(self) -> np.ndarray:
     pw = self.numpy()
     ret = []
     for bi, si, wps, wl in zip(pw.batch_indices, pw.sample_indices,
                                pw.word_positions, pw.word_lengths):
         # Keep only the first `wl` word positions; anything beyond the word
         # length is padding. `si` is unpacked here but not needed.
         chars = [self.orig_segments[bi][wp] for wp in wps[:wl]]
         ret.append('-'.join(chars))
     return get_array(ret)
Example #2
 def collate_helper(key, cls, pad=False):
     # `batch` comes from the enclosing collate function's scope: a list of
     # per-example dicts keyed by field name.
     ret = [item[key] for item in batch]
     if cls is np.ndarray:
         return get_array(ret)
     elif cls is torch.Tensor:
         if pad:
             # Pad variable-length tensors into one (batch, max_len, ...) tensor.
             ret = torch.nn.utils.rnn.pad_sequence(ret, batch_first=True)
         else:
             ret = torch.LongTensor(ret)
         return ret
     else:
         raise ValueError(f'Unsupported class "{cls}".')
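
Since collate_helper reads batch from an enclosing scope, it is presumably nested inside a collate function handed to a torch DataLoader. A minimal runnable sketch of that wrapper, assuming hypothetical field names ('length', 'feat_matrix', 'segment') that are not from the source:

 import numpy as np
 import torch

 def collate_fn(batch):
     # A trimmed copy of collate_helper above, nested so it closes over `batch`.
     def collate_helper(key, cls, pad=False):
         ret = [item[key] for item in batch]
         if cls is np.ndarray:
             return np.array(ret, dtype=object)
         if pad:
             return torch.nn.utils.rnn.pad_sequence(ret, batch_first=True)
         return torch.LongTensor(ret)

     return {
         'length': collate_helper('length', torch.Tensor),
         'feat_matrix': collate_helper('feat_matrix', torch.Tensor, pad=True),
         'segment': collate_helper('segment', np.ndarray),
     }

 # Two toy examples of different lengths; padding aligns their matrices.
 batch = [
     {'length': 3, 'feat_matrix': torch.zeros(3, 4, dtype=torch.long), 'segment': 'a-b-c'},
     {'length': 2, 'feat_matrix': torch.zeros(2, 4, dtype=torch.long), 'segment': 'd-e'},
 ]
 out = collate_fn(batch)
 assert out['feat_matrix'].shape == (2, 3, 4)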
Example #3
    def load_data(self, data_path: Path):
        segment_dict = self._get_segment_dict(data_path)
        segment_windows = []
        with data_path.open('r', encoding='utf8') as fin:
            for line in fin:
                tokens = line.strip().split()
                segments = [segment_dict[token] for token in tokens]
                sw = SegmentWindow(segments)
                # Slide a window of size `self.window_size` across the
                # SegmentWindow with stride 1; `break_segment` takes an
                # inclusive end index.
                start = 0
                while True:
                    end = min(start + self.window_size, len(sw))
                    broken_sw = sw.break_segment(start, end - 1)
                    segment_windows.append(broken_sw)
                    if end >= len(sw):
                        break
                    start += 1

        return {
            'segments': get_array(segment_windows),
            'matrices': [sw.feat_matrix for sw in segment_windows]
        }
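
The windowing loop above is easiest to sanity-check in isolation. A minimal sketch of the same loop over a plain list, with slicing standing in for break_segment:

 def sliding_windows(seq, window_size):
     # Same loop shape as load_data: windows of at most `window_size`,
     # stride 1, with the final window clipped at the sequence end.
     windows = []
     start = 0
     while True:
         end = min(start + window_size, len(seq))
         windows.append(seq[start:end])
         if end >= len(seq):
             break
         start += 1
     return windows

 assert sliding_windows([1, 2, 3, 4], 3) == [[1, 2, 3], [2, 3, 4]]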
Example #4
 def load_data(self, data_path: Path):
     segment_dict = self._get_segment_dict(data_path)
     segment_windows = []
     with data_path.open('r', encoding='utf8') as fin:
         for line in fin:
             tokens = line.strip().split()
             segments = [segment_dict[token] for token in tokens]
             # Prefix sums of segment lengths: the total length of
             # segments[start:end] is cum_lengths[end - 1] - ex_cum_lengths[start].
             lengths = np.asarray([len(segment) for segment in segments])
             cum_lengths = np.cumsum(lengths)
             ex_cum_lengths = np.concatenate(
                 [np.zeros([1], dtype=np.int32), cum_lengths[:-1]])
             last_end = -1
             end = 0
             start = 0
             while start < len(tokens):
                 # Greedily extend the window while its total length stays
                 # within g.max_segment_length.
                 while (end < len(tokens) and cum_lengths[end] -
                        ex_cum_lengths[start] <= g.max_segment_length):
                     end += 1
                 if end <= start:
                     # The single segment at `start` is already too long; skip it.
                     end = start + 1
                     start += 1
                     continue
                 if end > last_end:
                     segment_window = segments[start:end]
                     if len(SegmentWindow(segment_window)) >= g.min_word_length:
                         segment_windows.append(segment_window)
                 last_end = end
                 start = last_end
     return {
         'segments': get_array(segment_windows),
         'matrices': [
             torch.cat([segment.feat_matrix for segment in segment_window], dim=0)
             for segment_window in segment_windows
         ]
     }
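
The greedy grouping is easier to follow with plain integers. A minimal sketch of the same prefix-sum logic, with max_len standing in for g.max_segment_length; the end > last_end de-duplication guard is omitted here since start always jumps directly to end:

 import numpy as np

 def greedy_windows(lengths, max_len):
     # Group consecutive items so each group's total length stays <= max_len,
     # skipping items that are individually too long.
     cum = np.cumsum(lengths)
     ex_cum = np.concatenate([np.zeros(1, dtype=np.int64), cum[:-1]])
     windows, start, end = [], 0, 0
     while start < len(lengths):
         while end < len(lengths) and cum[end] - ex_cum[start] <= max_len:
             end += 1
         if end <= start:
             end = start + 1
             start += 1
             continue
         windows.append((start, end))
         start = end
     return windows

 # Items of lengths 2, 3, 9, 1, 1 with a budget of 5: the 9 is skipped.
 assert greedy_windows([2, 3, 9, 1, 1], 5) == [(0, 2), (3, 5)]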
Example #5
 def load_data(self, data_path: Path):
     # No windowing here: each unique token maps to one Segment, and the
     # dataset is just the segments plus their feature matrices.
     segments = self._get_segment_dict(data_path)
     return {
         'segments': get_array(list(segments.keys())),
         'matrices': [segment.feat_matrix for segment in segments.values()]
     }
Example #6
    def __init__(self, lu_size: int):
        super().__init__()

        def _has_proper_length(segment):
            length = len(segment)
            return g.min_word_length <= length <= g.max_word_length

        # Read the vocab file, keep segments within the configured length
        # range, and build padded buffers with named dimensions.
        with open(g.vocab_path, 'r', encoding='utf8') as fin:
            _vocab = set(line.strip() for line in fin)
            segments = [Segment(w) for w in _vocab]
            self.vocab = get_array([
                segment for segment in segments if _has_proper_length(segment)
            ])
            lengths = torch.LongTensor(list(map(len, self.vocab)))
            feat_matrix = [segment.feat_matrix for segment in self.vocab]
            feat_matrix = torch.nn.utils.rnn.pad_sequence(feat_matrix,
                                                          batch_first=True)
            max_len = lengths.max().item()
            source_padding = ~get_length_mask(lengths, max_len)
            self.register_buffer('vocab_feat_matrix', feat_matrix)
            self.register_buffer('vocab_source_padding', source_padding)
            self.register_buffer('vocab_length', lengths)
            self.vocab_feat_matrix.rename_('vocab', 'length', 'feat_group')
            self.vocab_source_padding.rename_('vocab', 'length')
            self.vocab_length.rename_('vocab')

            # Temporarily relabel 'vocab' as 'batch' so convert_to_dense sees
            # the dimension name it expects; the dict comprehension below maps
            # the results back to 'vocab'.
            with Rename(self.vocab_feat_matrix, vocab='batch'):
                vocab_dense_feat_matrix = convert_to_dense(
                    self.vocab_feat_matrix)
            self.vocab_dense_feat_matrix = {
                k: v.rename(batch='vocab')
                for k, v in vocab_dense_feat_matrix.items()
            }

            # Get the entire set of units from vocab.
            units = set()
            for segment in self.vocab:
                units.update(segment.segment_list)
            self.id2unit = sorted(units)
            self.unit2id = {u: i for i, u in enumerate(self.id2unit)}
            # Now indexify the vocab. Gather feature matrices for units as well.
            indexed_segments = np.zeros([len(self.vocab), max_len],
                                        dtype='int64')
            unit_feat_matrix = dict()
            for i, segment in enumerate(self.vocab):
                indexed_segments[i, range(len(segment))] = [
                    self.unit2id[u] for u in segment.segment_list
                ]
                for j, u in enumerate(segment.segment_list):
                    if u not in unit_feat_matrix:
                        unit_feat_matrix[u] = segment.feat_matrix[j]
            unit_feat_matrix = [unit_feat_matrix[u] for u in self.id2unit]
            unit_feat_matrix = torch.nn.utils.rnn.pad_sequence(
                unit_feat_matrix, batch_first=True)
            self.register_buffer('unit_feat_matrix',
                                 unit_feat_matrix.unsqueeze(dim=1))
            self.register_buffer('indexed_segments',
                                 torch.from_numpy(indexed_segments))
            # Use dummy length to avoid the trouble later on.
            # HACK(j_luo) Have to provide 'length'.
            self.unit_feat_matrix.rename_('unit', 'length', 'feat_group')
            self.indexed_segments.rename_('vocab', 'length')
            with Rename(self.unit_feat_matrix, unit='batch'):
                unit_dense_feat_matrix = convert_to_dense(
                    self.unit_feat_matrix)
            self.unit_dense_feat_matrix = {
                k: v.rename(batch='unit')
                for k, v in unit_dense_feat_matrix.items()
            }

        self.adapter = AdaptLayer()

        if g.input_format == 'text':
            self.g2p = G2PLayer(lu_size, len(self.id2unit))
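
Example #6 leans on two PyTorch pieces worth isolating: register_buffer, which stores non-trainable tensors that follow the module across devices and into state_dict, and the prototype named-tensor API (rename_/rename). A minimal sketch with an illustrative module and dimension names, not taken from the source:

 import torch
 import torch.nn as nn

 class VocabHolder(nn.Module):
     # Illustrative module: buffers receive no gradients but move with
     # .to(device) and are saved in state_dict.
     def __init__(self, vocab_size: int, max_len: int):
         super().__init__()
         lengths = torch.randint(1, max_len + 1, (vocab_size,))
         self.register_buffer('vocab_length', lengths)
         # Named tensors (a prototype feature) label each axis in place.
         self.vocab_length.rename_('vocab')

 holder = VocabHolder(vocab_size=5, max_len=8)
 print(holder.vocab_length.names)  # ('vocab',)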
Example #7
 def segments(self) -> np.ndarray:
     # Look up the original segment for each batch element.
     ret = [self.orig_segments[bi] for bi in self.batch_indices.cpu().numpy()]
     return get_array(ret)
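
get_array appears in every example above; it evidently builds a 1-D NumPy object array from a list without coercing sequence-like elements. A stand-in sketch under that assumption (the real helper's body is not shown in these snippets):

 import numpy as np

 def get_array(items):
     # Build a 1-D object array without letting NumPy coerce or broadcast
     # sequence-like elements (strings, lists, custom objects).
     arr = np.empty(len(items), dtype=object)
     for i, item in enumerate(items):
         arr[i] = item
     return arr

 a = get_array([[1, 2], [3]])
 assert a.shape == (2,) and a[1] == [3]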