def _post_init_helper(self):
    bs, ml, nfg = self.feat_matrix.shape
    self.target_weight = get_zeros(bs, cpu=True).float().unsqueeze(dim=-1).repeat(1, nfg).fill_(1.0)
    # self.target_weight = self.target_weight.unsqueeze(dim=-1).repeat(1, 1, nfg).float()
    self.pos_to_predict = get_zeros(bs, cpu=True).long().fill_(g.window_size // 2)

    # NOTE(j_luo) This is the global index.
    target_feat = self.feat_matrix[:, g.window_size // 2]

    # Get conversion matrix.
    if self._g2f is None:
        total = Index.total_indices()
        self._g2f = torch.LongTensor(total)
        indices = [Index.get_feature(i).value for i in range(total)]
        for index in indices:
            self._g2f[index.g_idx] = index.f_idx
    # NOTE(j_luo) This is the feature index.
    self.target_feat = self._g2f[target_feat]

    # NOTE(j_luo) If the condition is not satisfied, the target weight should be set to 0.
    mask_out_target_weight(self.target_weight, self.target_feat)

    # NOTE(j_luo) Refine names.
    self.pos_to_predict = self.pos_to_predict.refine_names(self.batch_name)
    self.target_feat = self.target_feat.refine_names(self.batch_name, 'feat_group')
    self.target_weight = self.target_weight.refine_names(self.batch_name, 'feat_group')

    BaseBatch._post_init_helper(self)
def _get_word_score(self, packed_words: PackedWords, batch_size: int) -> FT:
    with torch.no_grad():
        num_words = get_zeros(batch_size * packed_words.num_samples)
        bi = packed_words.batch_indices
        si = packed_words.sample_indices
        idx = (bi * packed_words.num_samples + si).rename(None)
        inc = get_zeros(packed_words.batch_indices.size('batch_word')).fill_(1.0)
        # TODO(j_luo) add scatter_add_ to named_tensor module
        num_words.scatter_add_(0, idx, inc)
        num_words = num_words.view(batch_size, packed_words.num_samples).refine_names('batch', 'sample')
    return num_words
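# Illustrative sketch (not part of the model): the scatter_add_ pattern used by
# `_get_word_score` above and `_unpack` below, written with plain (unnamed) tensors.
# Each word carries a (batch_index, sample_index) pair; flattening the pair into a
# single index lets one scatter_add_ call accumulate a count per (batch, sample) cell.
import torch


def count_words_per_sample(batch_indices: torch.Tensor,
                           sample_indices: torch.Tensor,
                           batch_size: int,
                           num_samples: int) -> torch.Tensor:
    """Return a (batch_size, num_samples) tensor of word counts."""
    flat_idx = batch_indices * num_samples + sample_indices
    counts = torch.zeros(batch_size * num_samples)
    counts.scatter_add_(0, flat_idx, torch.ones_like(flat_idx, dtype=torch.float))
    return counts.view(batch_size, num_samples)


# Example: 2 batches x 2 samples, 5 words in total.
# bi = torch.tensor([0, 0, 1, 1, 1]); si = torch.tensor([0, 1, 0, 0, 1])
# count_words_per_sample(bi, si, 2, 2) -> [[1., 1.], [2., 1.]]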
def _unpack(self, nlls: FT, packed_words: PackedWords, batch_size: int) -> Tuple[FT, FT]:
    with torch.no_grad():
        lm_loss = get_zeros(batch_size * packed_words.num_samples)
        bi = packed_words.batch_indices
        si = packed_words.sample_indices
        idx = (bi * packed_words.num_samples + si).rename(None)
        # TODO(j_luo) ugly
        lm_loss.scatter_add_(0, idx, nlls.rename(None))
        lm_loss = lm_loss.view(batch_size, packed_words.num_samples).refine_names('batch', 'sample')

        in_vocab_score = get_zeros(batch_size * packed_words.num_samples)
        if self.vocab is not None:
            in_vocab_score.scatter_add_(0, idx, packed_words.in_vocab.float().rename(None))
        in_vocab_score = in_vocab_score.view(batch_size, packed_words.num_samples).refine_names('batch', 'sample')

    return -lm_loss, in_vocab_score  # NOTE(j_luo) NLL are losses, not scores.
def _post_init_helper(self):
    super()._post_init_helper()
    names = self.feat_matrix.names
    bs = self.feat_matrix.size('batch')
    ml = self.feat_matrix.size('length')
    fm = self._g2f[self.feat_matrix.rename(None)].refine_names(*names)
    sfms = dict()
    for cat in Category:
        e = get_enum_by_cat(cat)
        sfm_idx = fm[..., cat.value]
        sfm = get_zeros(bs, ml, len(e), cpu=True)
        sfm = sfm.scatter(2, sfm_idx.rename(None).unsqueeze(dim=-1), 1.0)
        sfms[cat] = sfm.refine_names('batch', 'length', f'{cat.name}_feat')
    self.dense_feat_matrix = {k: v.cuda() for k, v in sfms.items()}
def score_distr(self, distr: Dict[Cat, FT], batch: IpaBatch) -> Dict[Cat, FT]:
    scores = dict()
    for name, output in distr.items():
        i = get_index(name, new_style=g.new_style)
        target = batch.target_feat[:, i]
        weight = batch.target_weight[:, i]

        if g.weighted_loss == '':
            # log_probs = gather(output, target)
            log_probs = output.gather(name.value, target)
            score = -log_probs
        else:
            e = get_new_style_enum(i)
            mat = get_tensor(e.get_distance_matrix())
            mat = mat[target.rename(None)]
            if g.weighted_loss == 'mr':
                mat_exp = torch.where(mat > 0, (mat + 1e-8).log(),
                                      get_zeros(mat.shape).fill_(-99.9))
                logits = mat_exp + output
                # NOTE(j_luo) For categories other than Ptype, the probs do not sum to 1.0
                # (they are conditioned on certain values of Ptype).
                # As a result, we need to incur penalties based on the remaining prob mass as well.
                # Specifically, the remaining prob mass will result in a penalty of 1.0, which is e^(0.0).
                none_probs = (1.0 - output.exp().sum(dim=-1, keepdim=True)).clamp(min=0.0)
                none_penalty = (1e-8 + none_probs).log().align_as(output)
                logits = torch.cat([logits, none_penalty], dim=-1)
                score = torch.logsumexp(logits, dim=-1).exp()
            elif g.weighted_loss == 'ot':
                if not self.training:
                    raise RuntimeError('Cannot use OT for training.')
                probs = output.exp()
                # We have to incur penalties based on the remaining prob mass as well.
                none_probs = (1.0 - probs.sum(dim=-1, keepdim=True)).clamp(min=0.0)
                mat = torch.cat([mat, get_tensor(torch.ones_like(none_probs.rename(None)))], dim=-1)
                probs = torch.cat([probs, none_probs], dim=-1)
                score = (mat * probs).sum(dim=-1)
            else:
                raise ValueError(f'Cannot recognize {g.weighted_loss}.')
        scores[name] = (score, weight)
    return scores
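# Illustrative sketch (toy sizes, made-up distance matrix): the two scoring modes in
# `score_distr` above. With no weighting the score is the NLL of the gold feature value;
# with the 'ot'-style weighting it is the expected distance between the predicted and
# the gold feature value under the model's distribution, score = sum_j D[target, j] * p_j.
import torch

log_probs = torch.log_softmax(torch.tensor([[2.0, 0.5, 0.1]]), dim=-1)  # one position, 3 feature values
target = torch.tensor([0])                                              # gold feature value

# Plain NLL score.
nll = -log_probs.gather(-1, target.unsqueeze(-1)).squeeze(-1)

# Distance-weighted score with a toy symmetric distance matrix.
D = torch.tensor([[0.0, 1.0, 2.0],
                  [1.0, 0.0, 1.0],
                  [2.0, 1.0, 0.0]])
probs = log_probs.exp()
expected_distance = (D[target] * probs).sum(dim=-1)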
def convert_to_dense(feat_matrix: LT) -> DenseFeatureMatrix:
    names = feat_matrix.names
    bs = feat_matrix.size('batch')
    ml = feat_matrix.size('length')
    fm = _g2f[feat_matrix.rename(None)].refine_names(*names)
    dfms = dict()
    for cat in Category:
        e = get_enum_by_cat(cat)
        dfm_idx = fm[..., cat.value]
        dfm = get_zeros(bs, ml, len(e), cpu=True)
        dfm = dfm.scatter(2, dfm_idx.rename(None).unsqueeze(dim=-1), 1.0)
        dfms[cat] = dfm.refine_names('batch', 'length', f'{cat.name}_feat')
    if has_gpus():
        dfms = {k: v.cuda() for k, v in dfms.items()}
    return dfms
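# Illustrative sketch (plain tensors, made-up sizes): the scatter-based one-hot
# encoding used by `convert_to_dense` and the dense `_post_init_helper` above.
# Each (batch, length) position holds a feature index; scattering 1.0 along the
# last dimension turns it into a dense one-hot matrix for that category.
import torch

batch_size, max_length, num_feature_values = 2, 3, 5
feature_indices = torch.randint(num_feature_values, (batch_size, max_length))

dense = torch.zeros(batch_size, max_length, num_feature_values)
dense = dense.scatter(2, feature_indices.unsqueeze(-1), 1.0)
# dense[b, l] is now a one-hot vector selecting feature_indices[b, l].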
def _get_scores(self, samples: LT, segments: Sequence[SegmentWindow], lengths: LT,
                feat_matrix: LT, source_padding: BT) -> Tuple[PackedWords, DecipherModelScoreReturn]:
    bs = len(segments)

    segment_list = None
    if self.vocab is not None:
        segment_list = [segment.segment_list for segment in segments]
    packed_words = self.pack(samples, lengths, feat_matrix, segments,
                             segment_list=segment_list)
    packed_words.word_feat_matrices = self._adapt(packed_words.word_feat_matrices)

    try:
        lm_batch = self._prepare_batch(packed_words)  # TODO(j_luo) This is actually continuous batching.
        scores = self._get_lm_scores(lm_batch)
        nlls = list()
        for cat, (nll, weight) in scores.items():
            if should_include(g.feat_groups, cat):
                nlls.append(nll * weight)
        # nlls = sum(nlls)
        nlls = sum(nlls) / lm_batch.lengths
        bw = packed_words.word_lengths.size('batch_word')
        p = packed_words.word_positions.size('position')
        nlls = nlls.unflatten('batch', [('batch_word', bw), ('position', p)])
        nlls = nlls.sum(dim='position')
        lm_score, in_vocab_score = self._unpack(nlls, packed_words, bs)
    except EmptyPackedWords:
        lm_score = get_zeros(bs, packed_words.num_samples)
        in_vocab_score = get_zeros(bs, packed_words.num_samples)

    word_score = self._get_word_score(packed_words, bs)
    readable_score, unreadable_score = self._get_readable_scores(source_padding, samples)

    scores = [lm_score, word_score, in_vocab_score, readable_score, unreadable_score]
    features = torch.stack(scores, new_name='feature')
    phi_score = self.phi_scorer(features).squeeze('score')

    # if g.search:
    #     samples = samples.align_to('length', 'batch', 'sample')
    #     flat_samples = samples.flatten(['batch', 'sample'], 'batch_X_sample')
    #     flat_sample_embeddings = self.tag_embedding(flat_samples)
    #     bxs = flat_samples.size('batch_X_sample')
    #     h0 = get_zeros([1, bxs, 100])
    #     c0 = get_zeros([1, bxs, 100])
    #     with NoName(flat_sample_embeddings):
    #         output, (hn, _) = self.tag_lstm(flat_sample_embeddings, (h0, c0))
    #     tag_score = self.tag_scorer(hn).squeeze(dim=0).squeeze(dim=-1)
    #     tag_score = tag_score.view(samples.size('batch'), samples.size('sample'))
    #     ret['tag_score'] = tag_score.rename('batch', 'sample')

    scores = DecipherModelScoreReturn(lm_score, word_score, in_vocab_score,
                                      readable_score, unreadable_score, phi_score)
    return packed_words, scores
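# Illustrative sketch (assumption: a single linear layer; the actual `phi_scorer`
# is defined elsewhere in the model). `_get_scores` stacks five per-(batch, sample)
# scores -- lm, word count, in-vocab, readable, unreadable -- along a feature axis
# and maps them to one phi score per sample.
import torch
import torch.nn as nn

phi_scorer = nn.Linear(5, 1)                   # 5 feature scores -> 1 phi score
features = torch.randn(4, 8, 5)                # (batch, sample, feature)
phi_score = phi_scorer(features).squeeze(-1)   # (batch, sample)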