def convert_to_dense(feat_matrix: LT) -> DenseFeatureMatrix:
    names = feat_matrix.names
    bs = feat_matrix.size('batch')
    ml = feat_matrix.size('length')
    fm = _g2f[feat_matrix.rename(None)].refine_names(*names)
    dfms = dict()
    for cat in Category:
        e = get_enum_by_cat(cat)
        dfm_idx = fm[..., cat.value]
        dfm = get_zeros(bs, ml, len(e), cpu=True)
        dfm = dfm.scatter(2, dfm_idx.rename(None).unsqueeze(dim=-1), 1.0)
        dfms[cat] = dfm.refine_names('batch', 'length', f'{cat.name}_feat')
    if has_gpus():
        dfms = {k: v.cuda() for k, v in dfms.items()}
    return dfms
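# A minimal, self-contained sketch of the scatter trick `convert_to_dense` relies on:
# turning per-category feature indices of shape (batch, length) into a dense one-hot
# tensor of shape (batch, length, num_values). Plain tensors only; the toy sizes and
# the helper name `_one_hot_sketch` are illustrative, not part of the codebase.
import torch


def _one_hot_sketch(feat_idx: torch.Tensor, num_values: int) -> torch.Tensor:
    """`feat_idx` is a long tensor of shape (batch, length) with values in [0, num_values)."""
    bs, ml = feat_idx.shape
    dense = torch.zeros(bs, ml, num_values)
    # Write 1.0 into the last dimension at the positions given by `feat_idx`.
    return dense.scatter(2, feat_idx.unsqueeze(-1), 1.0)


# Usage: 4 sequences of length 7 over a category with 3 possible values.
_idx = torch.randint(0, 3, (4, 7))
assert _one_hot_sketch(_idx, 3).sum(dim=-1).eq(1).all()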
def finish_search(self, lengths: LT):
    last_beam_id = get_zeros(lengths.size('batch'), g.beam_size).long().rename('batch', 'beam')
    start_beam_id = get_named_range(g.beam_size, 'beam').align_as(last_beam_id)
    samples = list()
    # Walk the stored beams backwards, following the backpointers in `self.beam_ids`.
    for i, (hyp, beam_id) in enumerate(zip(reversed(self.hyps), reversed(self.beam_ids))):
        step = len(self.beam_ids) - i
        # Sequences whose length equals `step` start backtracking here, from the identity beam ids.
        start_backtrack = (step == lengths).align_as(beam_id)
        # new_last_beam_id = beam_id.gather('beam', last_beam_id)
        this_beam_id = torch.where(start_backtrack, start_beam_id, last_beam_id)
        samples.append(hyp.gather('beam', this_beam_id))
        last_beam_id = beam_id.gather('beam', this_beam_id)
    self.samples = torch.stack(samples[::-1], new_name='length')
    hyp_log_probs = torch.stack(self.hyp_log_probs, new_name='length')
    self.sample_log_probs = hyp_log_probs.gather(
        'length', lengths.align_as(hyp_log_probs)).squeeze('length')
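# Illustrative, self-contained sketch (not the method above) of backpointer-following
# with `gather`: given per-step tokens `hyps[t]` and parent pointers `beam_ids[t]`,
# both of shape (batch, beam), recover the token sequence each final beam traces back
# to. Unlike `finish_search`, it assumes every sequence uses all steps; all names below
# are made up.
import torch


def backtrack_sketch(hyps, beam_ids):
    batch_size, beam_size = hyps[0].shape
    last = torch.arange(beam_size).expand(batch_size, beam_size)  # start from each final beam
    tokens = []
    for hyp, beam_id in zip(reversed(hyps), reversed(beam_ids)):
        tokens.append(hyp.gather(1, last))  # token chosen at this step for each traced beam
        last = beam_id.gather(1, last)      # follow the parent pointer one step back
    return torch.stack(tokens[::-1], dim=-1)  # batch x beam x length


_T, _B, _K = 5, 2, 3
_hyps = [torch.randint(0, 10, (_B, _K)) for _ in range(_T)]
_beam_ids = [torch.randint(0, _K, (_B, _K)) for _ in range(_T)]
assert backtrack_sketch(_hyps, _beam_ids).shape == (_B, _K, _T)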
def search(self, sot_id: int, src_emb: FT, src_outputs: FT, src_paddings: BT,
           src_lengths: LT, beam_size: int, lang_emb: Optional[FT] = None) -> Hypotheses:
    if beam_size <= 0:
        raise ValueError(f'`beam_size` must be positive, but got {beam_size}.')

    batch_size = src_emb.size('batch')
    tokens = torch.full([batch_size, beam_size], sot_id,
                        dtype=torch.long).to(src_emb.device).rename('batch', 'beam')
    accum_scores = torch.full_like(tokens, -9999.9).float()
    accum_scores[:, 0] = 0.0
    init_att = None
    if g.input_feeding:
        init_att = get_zeros(batch_size, beam_size, g.hidden_size).rename('batch', 'beam', 'hidden')
    lstm_state = LstmStatesByLayers.zero_state(
        self.cell.num_layers,
        batch_size,
        beam_size,
        self.attn.input_tgt_size,
        bidirectional=False,
        names=['batch', 'beam', 'hidden'])

    def expand_beam(orig, collapse: bool = True):
        # Either fold the beam into the batch axis or keep it as a separate axis.
        if collapse:
            return torch.repeat_interleave(orig, beam_size, dim='batch')
        else:
            return duplicate(orig, 'batch', beam_size, 'beam')

    src_emb = expand_beam(src_emb)
    src_outputs = expand_beam(src_outputs)
    src_paddings = expand_beam(src_paddings)

    max_lengths = (src_lengths.float() * 1.5).long()
    max_lengths = expand_beam(max_lengths, collapse=False)
    constants = BeamConstant(src_emb, src_outputs, src_paddings,
                             max_lengths, lang_emb=lang_emb)
    init_beam = Beam(0, accum_scores, tokens, lstm_state, constants, prev_att=init_att)
    hyps = super().search(init_beam)
    return hyps
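# Sketch of the two expansion modes behind `expand_beam`, with plain tensors and
# made-up sizes. Collapsing interleaves the beams into the batch axis, giving
# (batch * beam, ...), so the encoder outputs can be shared by every beam; the
# non-collapsing variant keeps an explicit beam axis instead.
import torch

_src = torch.randn(2, 6, 8)  # batch x length x hidden
_beam = 4
_collapsed = torch.repeat_interleave(_src, _beam, dim=0)  # (2 * 4) x 6 x 8
_kept = _src.unsqueeze(1).expand(-1, _beam, -1, -1)       # 2 x 4 x 6 x 8
assert _collapsed.shape == (8, 6, 8) and _kept.shape == (2, 4, 6, 8)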
def forward(self, curr_ids: LT, end_ids: LT, almts: Optional[Tuple[LT, LT]] = None):
    if g.repr_mode != 'state' and almts is None:
        raise RuntimeError(f'Must pass `almts` if `repr_mode` is not "state".')

    if g.repr_mode != 'state':
        curr_almts, end_almts = almts
        assert curr_almts.shape == curr_ids.shape
        assert end_almts.shape[1:] == end_ids.shape
        # NOTE(j_luo) +1 for 0-index, +1 for storing fake scattered values.
        max_len = max(curr_almts.max(), end_almts.max()) + 2
        new_shape = curr_almts.shape[:-1] + (max_len, )
        aligned_curr_ids = get_zeros(*new_shape).long().fill_(PAD_ID)
        aligned_end_ids = get_zeros(*new_shape).long().fill_(PAD_ID)
        with NoName(curr_almts, curr_ids, end_almts, end_ids):
            curr_mask = curr_almts == -1
            curr_almts[curr_mask] = max_len - 1
            end_mask = end_almts == -1
            end_almts[end_mask] = max_len - 1
            aligned_curr_ids.scatter_(-1, curr_almts, curr_ids)
            aligned_end_ids.scatter_(-1, end_almts, end_ids.expand_as(end_almts))
        aligned_curr_ids = aligned_curr_ids.narrow(-1, 0, max_len - 1).rename('batch', 'word', 'pos')
        aligned_end_ids = aligned_end_ids.narrow(-1, 0, max_len - 1).rename('batch', 'word', 'pos')

        curr_char_emb = self._get_char_embedding(aligned_curr_ids)
        end_char_emb = self._get_char_embedding(aligned_end_ids)
        if g.repr_mode == 'char':
            state_repr = self._get_word_embedding_from_chars(curr_char_emb - end_char_emb).mean(dim='word')
        else:
            curr_word_emb = self._get_word_embedding_from_chars(curr_char_emb)
            end_word_emb = self._get_word_embedding_from_chars(end_char_emb)
            state_repr = (curr_word_emb - end_word_emb).mean(dim='word')
    else:
        word_repr = self._get_word_embedding(curr_ids)
        end_word_repr = self._get_word_embedding(end_ids)
        state_repr = (word_repr - end_word_repr).mean(dim='word')
    return state_repr
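# Sketch of the alignment trick above, with plain tensors and a made-up PAD id:
# unaligned positions are marked -1, so they are redirected to an extra "dump"
# column before `scatter_`, and that column is dropped with `narrow` afterwards.
import torch

_PAD = 0
_ids = torch.tensor([[3, 4, 5, 6]])
_almts = torch.tensor([[1, 0, -1, 2]])           # position 2 is unaligned
_max_len = int(_almts.max()) + 2                 # +1 for 0-indexing, +1 for the dump slot
_aligned = torch.full((1, _max_len), _PAD)
_almts = _almts.clone()
_almts[_almts == -1] = _max_len - 1              # send unaligned ids to the dump slot
_aligned.scatter_(-1, _almts, _ids)
_aligned = _aligned.narrow(-1, 0, _max_len - 1)  # drop the dump slot
assert _aligned.tolist() == [[4, 3, 6]]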
def search_by_probs(self, lengths: LT, label_log_probs: FT) -> Tuple[LT, FT]:
    max_length = lengths.max().item()
    bs = label_log_probs.size('batch')
    label_log_probs = label_log_probs.align_to('length', 'batch', 'label')
    beam = Beam(bs)
    for step in range(max_length):
        __label_log_probs = label_log_probs[step]
        # __lengths = lengths[step]
        within_length = (step < lengths).align_as(__label_log_probs)  # __lengths
        beam.extend(__label_log_probs * within_length.float())
    beam.finish_search(lengths)
    samples = beam.samples.rename(beam='sample')
    sample_log_probs = beam.sample_log_probs.rename(beam='sample')
    return samples, sample_log_probs
def search_by_probs(self, lengths: LT, label_log_probs: FT) -> Tuple[LT, FT]:
    max_length = lengths.max().item()
    samples = get_tensor(torch.LongTensor(list(product([B, I, O], repeat=max_length))))
    samples.rename_('sample', 'length')
    bs = label_log_probs.size('batch')
    samples = samples.align_to('batch', 'sample', 'length').expand(bs, -1, -1)
    sample_log_probs = label_log_probs.gather('label', samples)
    with NoName(lengths):
        length_mask = get_length_mask(lengths, max_length).rename('batch', 'length')
    length_mask = length_mask.align_as(sample_log_probs)
    sample_log_probs = (sample_log_probs * length_mask.float()).sum(dim='length')
    return samples, sample_log_probs
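# Plain-tensor sketch of the exhaustive scoring above: enumerate every B/I/O tag
# sequence of length `max_length` with `itertools.product`, then score each one by
# gathering the per-position label log probs and summing over positions (length
# masking is omitted). The label ids and sizes below are made up.
from itertools import product

import torch

_B_ID, _I_ID, _O_ID = 0, 1, 2
_bs, _max_length, _n_labels = 2, 3, 3
_label_log_probs = torch.randn(_bs, _max_length, _n_labels).log_softmax(dim=-1)

_samples = torch.tensor(list(product([_B_ID, _I_ID, _O_ID], repeat=_max_length)))  # 27 x 3
_samples = _samples.unsqueeze(0).expand(_bs, -1, -1)                                # bs x 27 x 3
# Pick out log P(label = sampled tag) at every position, then sum over the length axis.
_per_pos = _label_log_probs.unsqueeze(1).expand(-1, _samples.size(1), -1, -1).gather(
    3, _samples.unsqueeze(-1)).squeeze(-1)                                           # bs x 27 x 3
_scores = _per_pos.sum(dim=-1)                                                       # bs x 27
assert _scores.shape == (_bs, 3 ** _max_length)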
def search(self, lengths: LT, label_log_probs: FT,
           gold_tag_seqs: Optional[LT] = None) -> Tuple[LT, FT]:
    samples, sample_log_probs = self.search_by_probs(lengths, label_log_probs)
    if gold_tag_seqs is not None:
        gold_tag_seqs = gold_tag_seqs.align_as(samples)

        max_length = lengths.max().item()
        with NoName(lengths):
            length_mask = get_length_mask(lengths, max_length).rename('batch', 'length')

        gold_log_probs = label_log_probs.gather('label', gold_tag_seqs)
        gold_log_probs = (gold_log_probs * length_mask.align_as(gold_log_probs)).sum('length')

        samples = torch.cat([gold_tag_seqs, samples], dim='sample')
        sample_log_probs = torch.cat([gold_log_probs, sample_log_probs], dim='sample')
    return samples, sample_log_probs
def _get_matches(self, extracted_word_repr: FT, unit_repr: FT, viable_lens: LT,
                 extracted_unit_ids: LT, char_log_probs: FT) -> Matches:
    ns = extracted_word_repr.size('viable')
    len_w = extracted_word_repr.size('len_w')
    nt = len(self.vocab_feat_matrix)
    msl = extracted_word_repr.size('len_w')
    mtl = self.vocab_feat_matrix.size('length')

    # Compute the unit-level log probs all at once: each position of every viable span
    # is scored against all units, and the context-based scores are interpolated with
    # the global character log probs.
    ctx_logits = extracted_word_repr @ unit_repr.t()
    ctx_log_probs = ctx_logits.log_softmax(dim='unit').flatten(['viable', 'len_w'], 'viable_X_len_w')
    with NoName(char_log_probs, extracted_unit_ids):
        global_log_probs = char_log_probs[extracted_unit_ids].rename('viable_X_len_w', 'unit')
    weighted_log_probs = g.context_weight * ctx_log_probs + (1.0 - g.context_weight) * global_log_probs
    costs = -weighted_log_probs

    # Name: viable x len_w x unit
    costs = costs.unflatten('viable_X_len_w', [('viable', ns), ('len_w', len_w)])

    # NOTE(j_luo) Use a dictionary to store every DP state.
    fs = dict()
    for i in range(msl + 1):
        fs[(i, 0)] = get_zeros(ns, nt).fill_(i * self.ins_del_cost)
    for j in range(mtl + 1):
        fs[(0, j)] = get_zeros(ns, nt).fill_(j * self.ins_del_cost)

    # ------------------------ Main body: DP ----------------------- #

    # Transitions: insertion, deletion, or substitution, computed within a narrow band
    # around the diagonal.
    with NoName(self.indexed_segments, costs):
        for ls in range(1, msl + 1):
            min_lt = max(ls - 2, 1)
            max_lt = min(ls + 2, mtl + 1)
            for lt in range(min_lt, max_lt):
                transitions = list()
                if (ls - 1, lt) in fs:
                    transitions.append(fs[(ls - 1, lt)] + self.ins_del_cost)
                if (ls, lt - 1) in fs:
                    transitions.append(fs[(ls, lt - 1)] + self.ins_del_cost)
                if (ls - 1, lt - 1) in fs:
                    vocab_inds = self.indexed_segments[:, lt - 1]
                    sub_cost = costs[:, ls - 1, vocab_inds]
                    transitions.append(fs[(ls - 1, lt - 1)] + sub_cost)
                if transitions:
                    all_s = torch.stack(transitions, dim=-1)
                    new_s, _ = all_s.min(dim=-1)
                    fs[(ls, lt)] = new_s

    # Assemble the full table, backfilling unreachable cells with a large constant.
    f_lst = list()
    for i in range(msl + 1):
        for j in range(mtl + 1):
            if (i, j) not in fs:
                fs[(i, j)] = get_zeros(ns, nt).fill_(9999.9)
            f_lst.append(fs[(i, j)])
    f = torch.stack(f_lst, dim=0).view(msl + 1, mtl + 1, -1, len(self.vocab))
    f.rename_('len_w_src', 'len_w_tgt', 'viable', 'vocab')

    # Read off the final cell for each (viable span, vocab word) pair at their
    # respective lengths.
    with NoName(f, viable_lens, self.vocab_length):
        idx_src = viable_lens.unsqueeze(dim=-1)
        idx_tgt = self.vocab_length
        viable_i = get_range(ns, 2, 0)
        vocab_i = get_range(len(self.vocab_length), 2, 1)
        nll = f[idx_src, idx_tgt, viable_i, vocab_i]
        nll.rename_('viable', 'vocab')

    # Get the best spans.
    matches = Matches(-nll, f)
    return matches
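# Scalar, self-contained sketch of the recurrence above: a Levenshtein-style DP where
# insertions and deletions cost `ins_del_cost` and substitutions cost `sub_cost(i, j)`.
# In `_get_matches` the same recurrence runs for every (viable span, vocab word) pair
# at once, with substitution costs taken from the weighted log probs and only a band
# around the diagonal filled in; here everything is made up and kept scalar so the
# transition structure is easy to see.

def edit_distance_sketch(src_len, tgt_len, sub_cost, ins_del_cost=1.0):
    f = {}
    for i in range(src_len + 1):
        f[(i, 0)] = i * ins_del_cost
    for j in range(tgt_len + 1):
        f[(0, j)] = j * ins_del_cost
    for i in range(1, src_len + 1):
        for j in range(1, tgt_len + 1):
            f[(i, j)] = min(
                f[(i - 1, j)] + ins_del_cost,                # delete a source symbol
                f[(i, j - 1)] + ins_del_cost,                # insert a target symbol
                f[(i - 1, j - 1)] + sub_cost(i - 1, j - 1),  # substitute (or match)
            )
    return f[(src_len, tgt_len)]


# Usage: a 0/1 substitution cost recovers plain edit distance ("abc" vs. "adc" -> 1.0).
assert edit_distance_sketch(3, 3, lambda i, j: 0.0 if 'abc'[i] == 'adc'[j] else 1.0) == 1.0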