Example #1
    def score_distr(self, distr: Dict[Cat, FT],
                    batch: IpaBatch) -> Dict[Cat, FT]:
        scores = dict()
        for name, output in distr.items():
            i = get_index(name, new_style=g.new_style)
            target = batch.target_feat[:, i]
            weight = batch.target_weight[:, i]

            if g.weighted_loss == '':
                # log_probs = gather(output, target)
                log_probs = output.gather(name.value, target)
                score = -log_probs
            else:
                e = get_new_style_enum(i)
                mat = get_tensor(e.get_distance_matrix())
                mat = mat[target.rename(None)]
                if g.weighted_loss == 'mr':
                    mat_exp = torch.where(mat > 0, (mat + 1e-8).log(),
                                          get_zeros(mat.shape).fill_(-99.9))
                    logits = mat_exp + output
                    # NOTE(j_luo) For all categories except Ptype, the probabilities do not sum to 1.0
                    # (they are conditioned on certain values of Ptype). As a result, we need to incur
                    # penalties based on the remaining prob mass as well. Specifically, the remaining
                    # prob mass results in a penalty of 1.0, which is e^(0.0).
                    none_probs = (
                        1.0 -
                        output.exp().sum(dim=-1, keepdims=True)).clamp(min=0.0)
                    none_penalty = (1e-8 + none_probs).log().align_as(output)
                    logits = torch.cat([logits, none_penalty], dim=-1)
                    score = torch.logsumexp(logits, dim=-1).exp()
                elif g.weighted_loss == 'ot':
                    if not self.training:
                        raise RuntimeError('OT weighted loss can only be used during training.')

                    probs = output.exp()
                    # We have to incur penalties based on the remaining prob mass as well.
                    none_probs = (1.0 -
                                  probs.sum(dim=-1, keepdims=True)).clamp(
                                      min=0.0)
                    mat = torch.cat([
                        mat,
                        get_tensor(torch.ones_like(none_probs.rename(None)))
                    ],
                                    dim=-1)
                    probs = torch.cat([probs, none_probs], dim=-1)
                    score = (mat * probs).sum(dim=-1)
                else:
                    raise ValueError(f'Cannot recognize {g.weighted_loss}.')
            scores[name] = (score, weight)
        return scores
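
A minimal standalone sketch (toy tensors, not from the codebase) of what the 'mr' branch computes: the expected feature distance under the predicted distribution, evaluated in log space as exp(logsumexp(log d + log p)). The zero-distance masking and the extra 'none' penalty used above are omitted here for brevity.

import torch

log_p = torch.log_softmax(torch.randn(4), dim=-1)   # toy predicted log-probs over 4 feature values
d = torch.tensor([0.0, 0.5, 1.0, 2.0])              # toy distances from the gold value to each value

# The log-space form matches the direct expectation (d * p).sum().
via_logsumexp = torch.logsumexp((d + 1e-8).log() + log_p, dim=-1).exp()
direct = (d * log_p.exp()).sum()
assert torch.allclose(via_logsumexp, direct, atol=1e-5)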
Example #2
    def __init__(self, hidden_size: Optional[int] = None):
        hidden_size = hidden_size or g.hidden_size
        super().__init__()
        self.linear = nn.Linear(hidden_size, hidden_size)
        self.feat_predictors = nn.ModuleDict()
        for e in get_needed_categories(g.feat_groups,
                                       new_style=g.new_style,
                                       breakdown=g.new_style):
            # NOTE(j_luo) ModuleDict can only handle str as keys.
            self.feat_predictors[e.__name__] = nn.Linear(hidden_size, len(e))
        # If new_style, we need to get the necessary indices to convert the breakdown groups into the original feature groups.
        if g.new_style:
            self.conversion_idx = dict()
            for e in get_needed_categories(g.feat_groups,
                                           new_style=True,
                                           breakdown=False):
                if e.num_groups() > 1:
                    cat_idx = list()
                    for feat in e:
                        feat_cat_idx = list()
                        feat = feat.value
                        for basic_feat in feat:
                            auto_index = basic_feat.value
                            feat_cat_idx.append(auto_index.f_idx)
                        cat_idx.append(feat_cat_idx)
                    cat_idx = get_tensor(cat_idx).refine_names(
                        'new_style_idx', 'old_style_idx')
                    self.conversion_idx[e.__name__] = cat_idx
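
A minimal standalone sketch (hypothetical indices and shapes, not the real feature inventory) of one plausible way an index tensor like conversion_idx can be used downstream: regrouping per-basic-feature log-probs into composite-feature scores via advanced indexing.

import torch

# Hypothetical: 3 composite (new-style) features, each built from 2 basic (old-style) features.
conversion_idx = torch.tensor([[0, 1], [2, 3], [4, 5]])        # (new_style_idx, old_style_idx)
basic_log_probs = torch.log_softmax(torch.randn(6), dim=-1)    # log-probs over 6 basic feature values

# Advanced indexing pulls out the basic log-probs for each composite feature;
# summing treats a composite feature as the conjunction of its basic features.
composite_scores = basic_log_probs[conversion_idx].sum(dim=-1)
print(composite_scores.shape)                                  # torch.Size([3])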
Example #3
    def __init__(self,
                 feat_emb_name,
                 group_name,
                 char_emb_name,
                 dim: int = 10):
        super().__init__()
        self.feat_emb_name = feat_emb_name
        self.group_name = group_name
        self.char_emb_name = char_emb_name
        self.dim = dim

        self.embed_layer = self._get_embeddings()
        self.register_buffer(
            'c_idx',
            get_tensor(
                get_effective_c_idx()).refine_names('chosen_feat_group'))
        cat_enum_pairs = get_needed_categories(g.feat_groups,
                                               new_style=g.new_style,
                                               breakdown=g.new_style)
        if g.new_style:
            self.effective_num_feature_groups = sum(
                [e.num_groups() for e in cat_enum_pairs])
            simple_conversions = np.zeros([g.num_features], dtype='int64')
            max_len = max(
                len(new_feat.value) for new_feat in conversions.values()
                if new_feat.value.is_complex())
            complex_conversions = np.zeros([g.num_features, max_len],
                                           dtype='int64')
            for old_feat, new_feat in conversions.items():
                if new_feat.value.is_complex():
                    l = len(new_feat.value)
                    complex_conversions[old_feat.value.g_idx, :l] = [
                        x.value.g_idx for x in new_feat.value
                    ]
                else:
                    simple_conversions[
                        old_feat.value.g_idx] = new_feat.value.g_idx
            self.simple_conversions = get_tensor(simple_conversions)
            self.complex_conversions = get_tensor(complex_conversions)
        else:
            self.effective_num_feature_groups = len(cat_enum_pairs)
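
A minimal standalone sketch (toy indices, not the real feature inventory) of the conversion-table construction above: one-to-one mappings go into a flat array, one-to-many mappings into a zero-padded (num_features, max_len) array.

import numpy as np

# Hypothetical mapping: old feature index -> list of new feature indices.
conversions = {0: [3], 1: [4, 5], 2: [6, 7, 8]}
num_features = len(conversions)
max_len = max(len(v) for v in conversions.values() if len(v) > 1)

simple_conv = np.zeros([num_features], dtype='int64')
complex_conv = np.zeros([num_features, max_len], dtype='int64')
for old_idx, new_idxs in conversions.items():
    if len(new_idxs) > 1:
        complex_conv[old_idx, :len(new_idxs)] = new_idxs
    else:
        simple_conv[old_idx] = new_idxs[0]

print(complex_conv)
# [[0 0 0]
#  [4 5 0]
#  [6 7 8]]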
Example #4
    def pack(self,
             samples: LT,
             lengths: LT,
             feat_matrix: LT,
             segments: np.ndarray,
             segment_list: Optional[List[List[str]]] = None) -> PackedWords:
        with torch.no_grad():
            feat_matrix = feat_matrix.align_to('batch', 'length', 'feat_group')
            samples = samples.align_to('batch', 'sample', 'length').int()
            ns = samples.size('sample')
            lengths = lengths.align_to('batch', 'sample').expand(-1, ns).int()
            batch_indices, sample_indices, word_positions, word_lengths, is_unique = extract_words(
                samples.cpu().numpy(), lengths.cpu().numpy(), num_threads=4)

            in_vocab = np.zeros_like(batch_indices, dtype=bool)
            if self.vocab is not None:
                in_vocab = check_in_vocab(batch_indices,
                                          word_positions,
                                          word_lengths,
                                          segment_list,
                                          self.vocab,
                                          num_threads=4)
                in_vocab = get_tensor(in_vocab).refine_names(
                    'batch_word').bool()

            batch_indices = get_tensor(batch_indices).refine_names(
                'batch_word').long()
            sample_indices = get_tensor(sample_indices).refine_names(
                'batch_word').long()
            word_positions = get_tensor(word_positions).refine_names(
                'batch_word', 'position').long()
            word_lengths = get_tensor(word_lengths).refine_names(
                'batch_word').long()
            is_unique = get_tensor(is_unique).refine_names('batch',
                                                           'sample').bool()

            key = (batch_indices.align_as(word_positions).rename(None),
                   word_positions.rename(None))
            word_feat_matrices = feat_matrix.rename(None)[key]
            word_feat_matrices = word_feat_matrices.refine_names(
                'batch_word', 'position', 'feat_group')
            packed_words = PackedWords(word_feat_matrices,
                                       word_lengths,
                                       batch_indices,
                                       sample_indices,
                                       word_positions,
                                       is_unique,
                                       ns,
                                       segments,
                                       in_vocab=in_vocab)
            return packed_words
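
A minimal standalone sketch (toy sizes) of the indexing trick used to build key above: a pair of (batch_word, position) index tensors gathers, for every extracted word, the feature rows of its positions out of the padded (batch, length, feat_group) matrix.

import torch

batch, length, feat_group = 2, 5, 3
feat_matrix = torch.arange(batch * length * feat_group).view(batch, length, feat_group)

# Two extracted words of length 2: positions [1, 2] of batch item 0 and positions [0, 3] of batch item 1.
batch_indices = torch.tensor([[0, 0], [1, 1]])     # (batch_word, position)
word_positions = torch.tensor([[1, 2], [0, 3]])    # (batch_word, position)

word_feat_matrices = feat_matrix[batch_indices, word_positions]
print(word_feat_matrices.shape)                    # torch.Size([2, 2, 3])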
Example #5
    def forward(
            self, batch: Union[ContinuousIpaBatch,
                               IpaBatch]) -> DecipherModelReturn:
        # Get the samples of label sequences first.
        out = self.emb_for_label(batch.feat_matrix, batch.source_padding)

        positions = get_named_range(batch.feat_matrix.size('length'),
                                    name='length')
        pos_emb = self.positional_embedding(positions).align_as(out)
        out = out + pos_emb
        out = out.align_to('length', 'batch', 'char_emb')
        with NoName(out, batch.source_padding):
            for i, layer in enumerate(self.self_attn_layers):
                out = layer(out, src_key_padding_mask=batch.source_padding)
        state = out.refine_names('length', 'batch', ...)
        logits = self.label_predictor(state)
        label_log_probs = logits.log_softmax(dim='label')
        label_probs = label_log_probs.exp()

        # NOTE(j_luo) O is equivalent to None.
        mask = expand_as(batch.source_padding, label_probs)
        source = expand_as(
            get_tensor([0.0, 0.0, 1.0]).refine_names('label').float(),
            label_probs)
        label_probs = label_probs.rename(None).masked_scatter(
            mask.rename(None), source.rename(None))
        label_probs = label_probs.refine_names('length', 'batch', 'label')

        if not self.training or (g.supervised and not g.train_phi):
            probs = DecipherModelProbReturn(label_log_probs, None)
            return DecipherModelReturn(state, probs, None, None, None, None,
                                       None)

        # ------------------ More info during training ----------------- #

        # Get the lm score.
        gold_tag_seqs = batch.gold_tag_seqs if g.supervised and g.train_phi else None
        samples, sample_log_probs = self.searcher.search(
            batch.lengths, label_log_probs, gold_tag_seqs=gold_tag_seqs)
        probs = DecipherModelProbReturn(label_log_probs, sample_log_probs)

        packed_words, scores = self._get_scores(samples, batch.segments,
                                                batch.lengths,
                                                batch.feat_matrix,
                                                batch.source_padding)

        if g.supervised and g.train_phi:
            return DecipherModelReturn(state, probs, packed_words, None,
                                       scores, None, None)

        # ------------------- Contrastive estimation ------------------- #

        ptb_segments = list()
        duplicates = list()
        for segment in batch.segments:
            _ptb_segments, _duplicates = segment.perturb_n_times(g.n_times)
            # NOTE(j_luo) Ignore the first one.
            ptb_segments.extend(_ptb_segments[1:])
            duplicates.extend(_duplicates[1:])
        # ptb_segments = [segment.perturb_n_times(5) for segment in batch.segments]
        ptb_feat_matrix = [segment.feat_matrix for segment in ptb_segments]
        ptb_feat_matrix = torch.nn.utils.rnn.pad_sequence(ptb_feat_matrix,
                                                          batch_first=True)
        ptb_feat_matrix.rename_('batch', 'length', 'feat_group')
        samples = samples.align_to('batch', ...)
        with NoName(samples, batch.lengths, batch.source_padding):
            ptb_samples = torch.repeat_interleave(samples,
                                                  g.n_times * 2,
                                                  dim=0)
            ptb_lengths = torch.repeat_interleave(batch.lengths,
                                                  g.n_times * 2,
                                                  dim=0)
            ptb_source_padding = torch.repeat_interleave(batch.source_padding,
                                                         g.n_times * 2,
                                                         dim=0)
        ptb_samples.rename_(*samples.names)
        ptb_lengths.rename_('batch')
        ptb_source_padding.rename_('batch', 'length')

        ptb_packed_words, ptb_scores = self._get_scores(
            ptb_samples, ptb_segments, ptb_lengths, ptb_feat_matrix,
            ptb_source_padding)

        ret = DecipherModelReturn(state, probs, packed_words, ptb_packed_words,
                                  scores, ptb_scores, duplicates)
        return ret
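
A minimal standalone sketch (toy values) of the alignment step in the contrastive-estimation block: each sentence gets 2 * n_times perturbed copies, so the per-sentence tensors are duplicated with repeat_interleave to stay aligned with the perturbations.

import torch

n_times = 2
lengths = torch.tensor([5, 7])                                      # two sentences in the batch
ptb_lengths = torch.repeat_interleave(lengths, n_times * 2, dim=0)  # one copy per perturbation
print(ptb_lengths)                                                  # tensor([5, 5, 5, 5, 7, 7, 7, 7])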