def score_distr(self, distr: Dict[Cat, FT], batch: IpaBatch) -> Dict[Cat, FT]:
    """Score each predicted category distribution against the batch targets.

    For every ``(name, output)`` pair in ``distr``, the matching target and
    weight columns are selected from ``batch.target_feat`` and
    ``batch.target_weight`` using the index returned by ``get_index``.
    ``output`` is treated as log-probabilities (it is exponentiated whenever
    probabilities are needed). The scoring rule depends on ``g.weighted_loss``:

    * ``''``   -- the negated log-probability gathered at the target.
    * ``'mr'`` -- ``exp(logsumexp(log(distance) + output))`` where the
      distance row is picked by the target, with one extra term for the
      probability mass not covered by ``output``.
    * ``'ot'`` -- ``sum(distance * prob)`` with the uncovered probability mass
      appended at a distance of 1.0.

    Args:
        distr: Mapping from category to its predicted log-probabilities.
        batch: Batch providing ``target_feat`` and ``target_weight``.

    Returns:
        A dict mapping every category in ``distr`` to a ``(score, weight)``
        tuple.

    Raises:
        RuntimeError: If ``g.weighted_loss == 'ot'`` while ``self.training``
            is false.
        ValueError: If ``g.weighted_loss`` is not one of ``''``, ``'mr'`` or
            ``'ot'``.
    """
    scores = dict()
    for name, output in distr.items():
        i = get_index(name, new_style=g.new_style)
        target = batch.target_feat[:, i]
        weight = batch.target_weight[:, i]

        if g.weighted_loss == '':
            log_probs = output.gather(name.value, target)
            score = -log_probs
        else:
            e = get_new_style_enum(i)
            mat = get_tensor(e.get_distance_matrix())
            # Select, for every position, the distance row of its target.
            mat = mat[target.rename(None)]
            if g.weighted_loss == 'mr':
                # Work in log space; non-positive distances get a very small
                # log value so that they contribute (almost) nothing.
                mat_exp = torch.where(mat > 0, (mat + 1e-8).log(),
                                      get_zeros(mat.shape).fill_(-99.9))
                logits = mat_exp + output
                # NOTE(j_luo) For the categories except Ptype, the sums of probs are not 1.0
                # (they are conditioned on certain values of Ptype). As a result, we need to
                # incur penalties based on the remaining prob mass as well. Specifically, the
                # remaining prob mass will result in a penalty of 1.0, which is e^(0.0).
                none_probs = (1.0 - output.exp().sum(dim=-1, keepdims=True)).clamp(min=0.0)
                none_penalty = (1e-8 + none_probs).log().align_as(output)
                logits = torch.cat([logits, none_penalty], dim=-1)
                score = torch.logsumexp(logits, dim=-1).exp()
            elif g.weighted_loss == 'ot':
                # NOTE(review): this guard fires when the module is NOT in
                # training mode, yet the message talks about training --
                # confirm which of the two is intended.
                if not self.training:
                    raise RuntimeError('Cannot use OT for training.')

                probs = output.exp()
                # We have to incur penalties based on the remaining prob mass as well.
                none_probs = (1.0 - probs.sum(dim=-1, keepdims=True)).clamp(min=0.0)
                mat = torch.cat([mat, get_tensor(torch.ones_like(none_probs.rename(None)))],
                                dim=-1)
                probs = torch.cat([probs, none_probs], dim=-1)
                score = (mat * probs).sum(dim=-1)
            else:
                # Report the value that was actually tested above.
                raise ValueError(f'Cannot recognize {g.weighted_loss}.')
        scores[name] = (score, weight)
    return scores
def __init__(self, hidden_size: Optional[int] = None):
    """Build the shared linear layer and one linear predictor per category.

    Args:
        hidden_size: Width of the hidden representation. A falsy value
            (``None`` or ``0``) falls back to ``g.hidden_size``.

    When ``g.new_style`` is set, ``self.conversion_idx`` is additionally
    populated: for every non-breakdown category with more than one group it
    holds a tensor (names ``'new_style_idx'``, ``'old_style_idx'``) whose rows
    list the ``f_idx`` of the basic features making up each feature.
    """
    hidden_size = hidden_size or g.hidden_size
    super().__init__()
    self.linear = nn.Linear(hidden_size, hidden_size)

    self.feat_predictors = nn.ModuleDict()
    predictor_cats = get_needed_categories(g.feat_groups,
                                           new_style=g.new_style,
                                           breakdown=g.new_style)
    for cat in predictor_cats:
        # NOTE(j_luo) ModuleDict can only handle str as keys.
        self.feat_predictors[cat.__name__] = nn.Linear(hidden_size, len(cat))

    # The conversion indices are only needed to map the breakdown groups
    # back into the original feature groups, i.e. for new-style features.
    if not g.new_style:
        return

    self.conversion_idx = dict()
    original_cats = get_needed_categories(g.feat_groups,
                                          new_style=True,
                                          breakdown=False)
    for cat in original_cats:
        if cat.num_groups() <= 1:
            continue
        index_rows = [[basic_feat.value.f_idx for basic_feat in feat.value]
                      for feat in cat]
        index_tensor = get_tensor(index_rows)
        self.conversion_idx[cat.__name__] = index_tensor.refine_names(
            'new_style_idx', 'old_style_idx')
def __init__(self, feat_emb_name, group_name, char_emb_name, dim: int = 10):
    """Store the embedding configuration and build the conversion tables.

    Args:
        feat_emb_name: Stored as ``self.feat_emb_name``.
        group_name: Stored as ``self.group_name``.
        char_emb_name: Stored as ``self.char_emb_name``.
        dim: Stored as ``self.dim``.

    The embedding layer comes from ``self._get_embeddings()`` and the result
    of ``get_effective_c_idx()`` is registered as the ``c_idx`` buffer. With
    ``g.new_style`` set, two lookup tables indexed by the old feature's
    ``g_idx`` are created from ``conversions``: ``simple_conversions`` (one
    new ``g_idx`` per old feature) and ``complex_conversions`` (a zero-padded
    row of new ``g_idx`` values per old feature).
    """
    super().__init__()
    self.feat_emb_name = feat_emb_name
    self.group_name = group_name
    self.char_emb_name = char_emb_name
    self.dim = dim
    self.embed_layer = self._get_embeddings()

    chosen_idx = get_tensor(get_effective_c_idx())
    self.register_buffer('c_idx', chosen_idx.refine_names('chosen_feat_group'))

    cat_enum_pairs = get_needed_categories(g.feat_groups,
                                           new_style=g.new_style,
                                           breakdown=g.new_style)
    if not g.new_style:
        self.effective_num_feature_groups = len(cat_enum_pairs)
        return

    self.effective_num_feature_groups = sum(
        cat.num_groups() for cat in cat_enum_pairs)

    # Width of the complex table: the longest complex conversion target.
    widest = max(len(converted.value)
                 for converted in conversions.values()
                 if converted.value.is_complex())
    simple_table = np.zeros([g.num_features], dtype='int64')
    complex_table = np.zeros([g.num_features, widest], dtype='int64')
    for old_feat, new_feat in conversions.items():
        converted = new_feat.value
        if converted.is_complex():
            width = len(converted)
            complex_table[old_feat.value.g_idx, :width] = [
                part.value.g_idx for part in converted
            ]
        else:
            simple_table[old_feat.value.g_idx] = converted.g_idx

    self.simple_conversions = get_tensor(simple_table)
    self.complex_conversions = get_tensor(complex_table)
def pack(self,
         samples: LT,
         lengths: LT,
         feat_matrix: LT,
         segments: np.ndarray,
         segment_list: Optional[List[List[str]]] = None) -> PackedWords:
    """Extract the words from the sampled sequences and pack them together.

    Everything runs under ``torch.no_grad()``.

    Args:
        samples: Named tensor aligned here to ``('batch', 'sample', 'length')``.
        lengths: Named tensor with a ``'batch'`` dimension; it is expanded
            over the ``'sample'`` dimension.
        feat_matrix: Named tensor aligned here to
            ``('batch', 'length', 'feat_group')``.
        segments: Passed through unchanged to ``PackedWords``.
        segment_list: Only forwarded to ``check_in_vocab``, i.e. only used
            when ``self.vocab`` is not ``None``.

    Returns:
        A ``PackedWords`` holding the per-word feature matrices (names
        ``'batch_word'``, ``'position'``, ``'feat_group'``) together with the
        word lengths, batch/sample indices, positions, uniqueness flags, the
        number of samples, ``segments`` and the ``in_vocab`` flags.
    """
    with torch.no_grad():
        feat_matrix = feat_matrix.align_to('batch', 'length', 'feat_group')
        samples = samples.align_to('batch', 'sample', 'length').int()
        ns = samples.size('sample')
        lengths = lengths.align_to('batch', 'sample').expand(-1, ns).int()
        batch_indices, sample_indices, word_positions, word_lengths, is_unique = extract_words(
            samples.cpu().numpy(), lengths.cpu().numpy(), num_threads=4)

        # Use the builtin ``bool`` as dtype: the ``np.bool`` alias was
        # deprecated in NumPy 1.20 and removed in 1.24.
        # NOTE(review): without a vocab, ``in_vocab`` stays a NumPy array
        # whereas it is a named tensor otherwise -- confirm that
        # ``PackedWords`` accepts both.
        in_vocab = np.zeros_like(batch_indices, dtype=bool)
        if self.vocab is not None:
            in_vocab = check_in_vocab(batch_indices,
                                      word_positions,
                                      word_lengths,
                                      segment_list,
                                      self.vocab,
                                      num_threads=4)
            in_vocab = get_tensor(in_vocab).refine_names('batch_word').bool()

        batch_indices = get_tensor(batch_indices).refine_names('batch_word').long()
        sample_indices = get_tensor(sample_indices).refine_names('batch_word').long()
        word_positions = get_tensor(word_positions).refine_names('batch_word', 'position').long()
        word_lengths = get_tensor(word_lengths).refine_names('batch_word').long()
        is_unique = get_tensor(is_unique).refine_names('batch', 'sample').bool()

        # Advanced indexing: for each word, pick the rows of its own batch
        # entry at every one of its positions.
        key = (batch_indices.align_as(word_positions).rename(None),
               word_positions.rename(None))
        word_feat_matrices = feat_matrix.rename(None)[key]
        word_feat_matrices = word_feat_matrices.refine_names('batch_word', 'position', 'feat_group')
        packed_words = PackedWords(word_feat_matrices,
                                   word_lengths,
                                   batch_indices,
                                   sample_indices,
                                   word_positions,
                                   is_unique,
                                   ns,
                                   segments,
                                   in_vocab=in_vocab)
        return packed_words
def forward(self, batch: Union[ContinuousIpaBatch, IpaBatch]) -> DecipherModelReturn:
    """Predict label distributions and, during training, score sampled words.

    Steps:
        1. Embed ``batch.feat_matrix``, add positional embeddings, run the
           self-attention layers and predict label log-probabilities.
        2. Outside of training, or when ``g.supervised`` is set without
           ``g.train_phi``, return right away with only the state and the
           label log-probabilities.
        3. Otherwise search for samples and score them; in the supervised
           ``g.train_phi`` setting the gold tag sequences are handed to the
           searcher and the result is returned without perturbation.
        4. For contrastive estimation, perturb every segment, repeat the
           samples, lengths and padding to match, and score the perturbed
           inputs as well.

    Args:
        batch: Batch providing ``feat_matrix``, ``source_padding``,
            ``lengths`` and ``segments`` (plus ``gold_tag_seqs`` in the
            supervised ``g.train_phi`` setting).

    Returns:
        A ``DecipherModelReturn``; fields not computed on the taken path are
        ``None``.
    """
    # Get the samples of label sequences first.
    hidden = self.emb_for_label(batch.feat_matrix, batch.source_padding)
    positions = get_named_range(batch.feat_matrix.size('length'), name='length')
    pos_emb = self.positional_embedding(positions).align_as(hidden)
    hidden = (hidden + pos_emb).align_to('length', 'batch', 'char_emb')
    with NoName(hidden, batch.source_padding):
        for layer in self.self_attn_layers:
            hidden = layer(hidden, src_key_padding_mask=batch.source_padding)
    state = hidden.refine_names('length', 'batch', ...)
    label_log_probs = self.label_predictor(state).log_softmax(dim='label')

    # NOTE(j_luo) O is equivalent to None.
    # NOTE(review): ``label_probs`` is not read again after this paragraph
    # in this method; the computation is kept as is.
    label_probs = label_log_probs.exp()
    padding_mask = expand_as(batch.source_padding, label_probs)
    o_label = get_tensor([0.0, 0.0, 1.0]).refine_names('label').float()
    fill_values = expand_as(o_label, label_probs)
    label_probs = label_probs.rename(None).masked_scatter(
        padding_mask.rename(None), fill_values.rename(None))
    label_probs = label_probs.refine_names('length', 'batch', 'label')

    if not self.training or (g.supervised and not g.train_phi):
        probs = DecipherModelProbReturn(label_log_probs, None)
        return DecipherModelReturn(state, probs, None, None, None, None, None)

    # ------------------ More info during training ----------------- #

    # Get the lm score.
    supervised_phi = g.supervised and g.train_phi
    gold_tag_seqs = batch.gold_tag_seqs if supervised_phi else None
    samples, sample_log_probs = self.searcher.search(batch.lengths,
                                                     label_log_probs,
                                                     gold_tag_seqs=gold_tag_seqs)
    probs = DecipherModelProbReturn(label_log_probs, sample_log_probs)
    packed_words, scores = self._get_scores(samples, batch.segments,
                                            batch.lengths, batch.feat_matrix,
                                            batch.source_padding)
    if supervised_phi:
        return DecipherModelReturn(state, probs, packed_words, None, scores,
                                   None, None)

    # ------------------- Contrastive estimation ------------------- #

    ptb_segments = list()
    duplicates = list()
    for segment in batch.segments:
        perturbed, duplicated = segment.perturb_n_times(g.n_times)
        # NOTE(j_luo) Ignore the first one.
        ptb_segments.extend(perturbed[1:])
        duplicates.extend(duplicated[1:])

    ptb_feat_matrix = torch.nn.utils.rnn.pad_sequence(
        [segment.feat_matrix for segment in ptb_segments], batch_first=True)
    ptb_feat_matrix.rename_('batch', 'length', 'feat_group')

    samples = samples.align_to('batch', ...)
    # NOTE(review): presumably ``perturb_n_times`` yields ``2 * n_times``
    # items per segment after the first one is dropped -- confirm.
    num_repeats = g.n_times * 2
    with NoName(samples, batch.lengths, batch.source_padding):
        ptb_samples = torch.repeat_interleave(samples, num_repeats, dim=0)
        ptb_lengths = torch.repeat_interleave(batch.lengths, num_repeats, dim=0)
        ptb_source_padding = torch.repeat_interleave(batch.source_padding,
                                                     num_repeats,
                                                     dim=0)
    # Put the names back on the repeated tensors.
    ptb_samples.rename_(*samples.names)
    ptb_lengths.rename_('batch')
    ptb_source_padding.rename_('batch', 'length')

    ptb_packed_words, ptb_scores = self._get_scores(ptb_samples, ptb_segments,
                                                    ptb_lengths,
                                                    ptb_feat_matrix,
                                                    ptb_source_padding)

    return DecipherModelReturn(state, probs, packed_words, ptb_packed_words,
                               scores, ptb_scores, duplicates)