Example #1
def gumbel_softmax(logits: FT,
                   temperature: float,
                   num_samples: Optional[int] = None) -> Tuple[FT, FT, LT, FT]:
    """Sample from the Gumbel-Softmax distribution and optionally discretize."""
    logits = logits.align_to('batch', 'length', 'label')
    y = gumbel_softmax_sample(logits, temperature, num_samples)
    y = y.align_to('batch', 'length', 'label', ...)
    max_values, max_inds = y.max(dim='label')
    y_one_hot = (max_values.align_as(y) == y).float()
    # Straight-through estimator: the forward pass uses the hard one-hot,
    # while gradients flow through the soft sample `y`.
    y_one_hot = (y_one_hot - y).detach() + y
    bi = get_named_range(logits.size('batch'), 'batch').align_as(max_inds)
    li = get_named_range(logits.size('length'), 'length').align_as(max_inds)
    if num_samples is None:
        with NoName(max_inds, y_one_hot, bi, li):
            probs = y_one_hot[bi, li, max_inds]
        probs.rename_('batch', 'length')
    else:
        si = get_named_range(max_inds.size('sample'),
                             'sample').align_as(max_inds)
        with NoName(max_inds, y_one_hot, bi, li, si):
            probs = y_one_hot[bi, li, max_inds, si]
        probs.rename_('batch', 'length', 'sample')
    seq_probs = (1e-8 + probs).log().sum(dim='length').exp()

    return y, y_one_hot, max_inds, seq_probs
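
The core of Example #1 is the straight-through trick on the `y_one_hot` line: the forward pass uses the hard one-hot sample, while gradients flow through the soft Gumbel-Softmax sample. A minimal, self-contained sketch of that trick in plain PyTorch (no named tensors and none of the project helpers such as `FT`, `get_named_range`, or `NoName`; the function name below is illustrative):

import torch
import torch.nn.functional as F

def st_gumbel_softmax(logits: torch.Tensor, temperature: float = 1.0) -> torch.Tensor:
    """Minimal straight-through Gumbel-Softmax on plain (unnamed) tensors."""
    # Sample standard Gumbel noise and form the relaxed (soft) sample.
    gumbels = -torch.empty_like(logits).exponential_().log()
    y_soft = F.softmax((logits + gumbels) / temperature, dim=-1)
    # Discretize to one-hot, but route gradients through the soft sample,
    # mirroring `(y_one_hot - y).detach() + y` above.
    index = y_soft.argmax(dim=-1, keepdim=True)
    y_hard = torch.zeros_like(y_soft).scatter_(-1, index, 1.0)
    return (y_hard - y_soft).detach() + y_soft

# Usage: logits shaped (batch, length, label); gradients reach `logits` via the soft sample.
logits = torch.randn(2, 5, 4, requires_grad=True)
sample = st_gumbel_softmax(logits, temperature=0.5)
sample.sum().backward()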
Example #2
    def _train_one_step_mle(self, batch: OnePairBatch) -> Metrics:
        """Train for one step using maximum likelihood."""
        log_probs, almt_distrs = self.model(batch)

        metrics = Metrics()
        # Cross-entropy loss.
        ce_loss = get_ce_loss(log_probs, batch, agg='all')
        ce_loss = Metric('ce_loss', ce_loss, len(batch))
        metrics += ce_loss

        # Compute alignment regularization loss if needed.
        if g.almt_reg_hyper > 0:
            sl = almt_distrs.size('src_pos')
            pos = get_named_range(sl, 'src_pos').float()
            mean_pos = (pos.align_as(almt_distrs) *
                        almt_distrs).sum(dim='src_pos')
            mean_pos = mean_pos.align_to('batch', 'tgt_pos')
            mean_pos = torch.cat([get_zeros(len(batch), 1), mean_pos], dim=-1)
            src_lengths = batch.src_seqs.lengths.float().rename(None)
            reg_weight = src_lengths.unsqueeze(dim=-1) - 1.0 - mean_pos[:, :-1]
            reg_weight.clamp_(0.0, 1.0)
            rel_pos = mean_pos[:, 1:] - mean_pos[:, :-1]  # bs x tl
            rel_pos_diff = rel_pos - 1
            margin = rel_pos_diff != 0
            almt_reg = margin.float() * (rel_pos_diff**2)  # bs x tl
            almt_reg = (almt_reg * reg_weight).sum()
            almt_reg = Metric('almt_reg', almt_reg, len(batch))
            metrics += almt_reg

            loss = ce_loss.mean + g.almt_reg_hyper * almt_reg.mean
        else:
            loss = ce_loss.mean
        loss = Metric('loss', loss * len(batch), len(batch))
        metrics += loss
        return metrics
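
Example #2's regularizer encourages roughly monotonic attention: the expected source position should advance by about one step per target position. A rough plain-PyTorch sketch of the core computation, assuming `almt` has shape (batch, tgt_len, src_len) with each row summing to 1, and omitting the length-based weighting (`reg_weight`) used in the original:

import torch

def alignment_reg(almt: torch.Tensor) -> torch.Tensor:
    """Penalize attention whose expected source position does not advance by ~1 per target step."""
    bs, _, sl = almt.shape
    pos = torch.arange(sl, dtype=almt.dtype)                 # source positions 0 .. sl-1
    mean_pos = (almt * pos).sum(dim=-1)                      # (batch, tgt_len): expected source position
    # Prepend 0 so the first target step is compared against position 0.
    mean_pos = torch.cat([mean_pos.new_zeros(bs, 1), mean_pos], dim=-1)
    rel_pos = mean_pos[:, 1:] - mean_pos[:, :-1]             # step size between consecutive target positions
    return ((rel_pos - 1.0) ** 2).sum()

almt = torch.softmax(torch.randn(2, 6, 9), dim=-1)           # (batch, tgt_len, src_len)
print(alignment_reg(almt))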
Example #3
    def trace_back(self, *attr_names: str) -> Dict[str, torch.Tensor]:
        """Trace back some attribute by going backwards through the beam search procedure."""
        beam_i = get_named_range(self.beam_size,
                                 'beam').expand_as(self.beam_ids)
        batch_i = get_named_range(self.batch_size, 'batch').expand_as(beam_i)
        beam = self
        ret = defaultdict(list)
        while beam.last_beam is not None:
            with NoName(beam.beam_ids, beam_i, batch_i):
                for attr_name in attr_names:
                    attr = getattr(beam, attr_name)
                    with NoName(attr):
                        ret[attr_name].append(attr[batch_i, beam_i])
                beam_i = beam.beam_ids[batch_i, beam_i]
            beam = beam.last_beam
        for attr_name in attr_names:
            # NOTE(j_luo) Reverse the list since we are going backwards.
            last_name = 'src_pos' if attr_name == 'almt' else None
            ret[attr_name] = _stack_beam(ret[attr_name][::-1],
                                         last_name=last_name)
        return ret
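
The pattern in Example #3 is a standard backpointer traceback: index each step's attribute with the current beam indices, then hop to the parent beams via `beam_ids`. A minimal plain-tensor sketch (the function name and arguments below are illustrative, not from the original code):

import torch
from typing import List

def trace_back(beam_ids: List[torch.Tensor],
               attrs: List[torch.Tensor]) -> torch.Tensor:
    """Follow backpointers to recover a per-step attribute for the surviving beams.

    beam_ids: one (batch, beam) tensor per step, pointing into the previous step's beams.
    attrs:    one (batch, beam) tensor per step (e.g. the emitted tokens).
    Returns a (batch, beam, length) tensor of traced attributes.
    """
    batch_size, beam_size = beam_ids[-1].shape
    batch_i = torch.arange(batch_size).unsqueeze(-1).expand(batch_size, beam_size)
    beam_i = torch.arange(beam_size).expand(batch_size, beam_size)
    traced = []
    for step in range(len(beam_ids) - 1, -1, -1):
        traced.append(attrs[step][batch_i, beam_i])
        beam_i = beam_ids[step][batch_i, beam_i]   # hop to the parent beams
    # We walked backwards, so reverse before stacking along the length axis.
    return torch.stack(traced[::-1], dim=-1)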
Example #4
    def _sample(self,
                label_probs: FT,
                sampling_probs: FT,
                source_padding: FT,
                gold_tag_seqs: Optional[FT] = None) -> Tuple[LT, FT]:
        """Return samples based on `label_probs`."""
        # Align named dimensions; padded positions are masked out at the end.
        label_probs = label_probs.align_to('batch', 'length', 'label')
        sampling_probs = sampling_probs.align_to('batch', 'length', 'label')
        source_padding = source_padding.align_to('batch', 'length')

        # Get packed batches.
        label_distr = Categorical(probs=sampling_probs.rename(None))
        label_samples = label_distr.sample([g.num_samples]).refine_names(
            'sample', 'batch', 'length')
        label_samples = label_samples.align_to('batch', 'sample', 'length')
        # Add the ground truth if needed.
        if gold_tag_seqs is not None:
            gold_tag_seqs = gold_tag_seqs.align_as(label_samples)
            all_other_tag_seqs = torch.full_like(gold_tag_seqs, O)
            label_samples = torch.cat(
                [gold_tag_seqs, all_other_tag_seqs, label_samples],
                dim='sample')
        batch_idx = get_named_range(
            label_samples.size('batch'),
            'batch').align_as(label_samples).rename(None)
        length_idx = get_named_range(
            label_samples.size('length'),
            'length').align_as(label_samples).rename(None)
        label_sample_probs = label_probs.rename(None)[
            batch_idx, length_idx,
            label_samples.rename(None)]
        label_sample_probs = label_sample_probs.refine_names(
            *label_samples.names)
        label_sample_log_probs = (1e-8 + label_sample_probs).log()
        label_sample_log_probs = (
            (~source_padding).align_as(label_sample_log_probs).float() *
            label_sample_log_probs).sum(dim='length')
        return label_samples, label_sample_log_probs
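
Example #4 combines `torch.distributions.Categorical` sampling with advanced indexing to score each sampled label sequence under `label_probs`. A compact plain-tensor sketch of the same pipeline, with illustrative sizes and without the gold-sequence branch:

import torch
from torch.distributions import Categorical

batch, length, n_labels, n_samples = 2, 5, 4, 3
label_probs = torch.softmax(torch.randn(batch, length, n_labels), dim=-1)
sampling_probs = torch.softmax(torch.randn(batch, length, n_labels), dim=-1)
padding = torch.zeros(batch, length, dtype=torch.bool)             # True = padded position

samples = Categorical(probs=sampling_probs).sample([n_samples])    # (sample, batch, length)
samples = samples.permute(1, 0, 2)                                 # (batch, sample, length)

# Advanced indexing: P(label = sampled label) at every (batch, sample, length) cell.
bi = torch.arange(batch)[:, None, None]
li = torch.arange(length)[None, None, :]
sample_probs = label_probs[bi, li, samples]                        # (batch, sample, length)

log_probs = (sample_probs + 1e-8).log()
log_probs = (~padding)[:, None, :].float() * log_probs             # zero out padded positions
seq_log_probs = log_probs.sum(dim=-1)                              # (batch, sample)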
Example #5
    def get_next_beam(self, beam: Beam, cand: Candidates) -> Beam:
        """Select the top-`beam_size` hypotheses for the next step from the candidate scores."""
        nh = NameHelper()

        # Get the new scores. For finished hypotheses, we should keep adding EOT.
        placeholder = torch.full_like(cand.log_probs, -9999.9)
        placeholder[..., EOT_ID] = 0.0
        new_scores = torch.where(beam.finished.align_as(placeholder),
                                 placeholder, cand.log_probs)
        accum = new_scores + beam.accum_scores.align_as(cand.log_probs)
        lp = nh.flatten(accum, ['beam', 'unit'], 'BU')
        top_s, top_i = torch.topk(lp, beam.beam_size, dim='BU')
        num_units = accum.size('unit')
        beam_i = top_i // num_units
        tokens = top_i % num_units

        batch_i = get_named_range(beam.batch_size, 'batch')
        batch_i = batch_i.align_as(top_i)

        def retrieve(tensor, last_name: str = 'hidden') -> torch.Tensor:
            with NoName(tensor, batch_i, beam_i):
                ret = tensor[batch_i, beam_i]
            new_names = ('batch', 'beam')
            if last_name:
                new_names += (last_name, )
            return ret.refine_names(*new_names)

        next_scores = top_s.rename(BU='beam')
        next_tokens = tokens.rename(BU='beam')
        next_beam_ids = beam_i.rename(BU='beam')
        next_state = cand.state.apply(retrieve)
        next_almt = retrieve(cand.almt, last_name='tgt_pos')
        next_att = retrieve(cand.att,
                            last_name='hidden') if g.input_feeding else None
        last_finished = retrieve(beam.finished, last_name=None)
        this_ended = next_tokens == EOT_ID
        reached_max = (beam.step + 1 == beam.constants.max_lengths)
        next_finished = last_finished | this_ended | reached_max
        next_beam = beam.follow(next_finished,
                                next_scores,
                                next_tokens,
                                next_state,
                                next_beam_ids,
                                next_almt,
                                prev_att=next_att)
        return next_beam
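
The beam expansion in Example #5 follows the usual flatten -> topk -> divmod recipe: scores of shape (batch, beam, unit) are flattened over the last two axes, the top `beam_size` entries are taken, and each flat index is split back into a parent-beam index and a token id. A minimal plain-tensor sketch (names are illustrative):

import torch

def expand_beam(accum_scores: torch.Tensor, beam_size: int):
    """Select the next beam from accumulated scores of shape (batch, beam, vocab)."""
    batch, _, vocab = accum_scores.shape
    flat = accum_scores.reshape(batch, -1)                           # (batch, beam * vocab)
    top_scores, top_idx = flat.topk(beam_size, dim=-1)
    beam_idx = torch.div(top_idx, vocab, rounding_mode='floor')      # which parent beam
    token_idx = top_idx % vocab                                      # which token within it
    return top_scores, beam_idx, token_idx

scores = torch.randn(2, 3, 10)                                       # batch=2, beam=3, vocab=10
print(expand_beam(scores, beam_size=3))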
Example #6
    def finish_search(self, lengths: LT):
        """Backtrack through the stored beam ids to assemble the final samples and their log probs."""
        last_beam_id = get_zeros(lengths.size('batch'),
                                 g.beam_size).long().rename('batch', 'beam')
        start_beam_id = get_named_range(g.beam_size,
                                        'beam').align_as(last_beam_id)
        samples = list()
        for i, (hyp, beam_id) in enumerate(
                zip(reversed(self.hyps), reversed(self.beam_ids))):
            step = len(self.beam_ids) - i
            start_backtrack = (step == lengths).align_as(beam_id)
            # new_last_beam_id = beam_id.gather('beam', last_beam_id)
            this_beam_id = torch.where(start_backtrack, start_beam_id,
                                       last_beam_id)
            samples.append(hyp.gather('beam', this_beam_id))
            last_beam_id = beam_id.gather('beam', this_beam_id)
        self.samples = torch.stack(samples[::-1], new_name='length')

        hyp_log_probs = torch.stack(self.hyp_log_probs, new_name='length')
        self.sample_log_probs = hyp_log_probs.gather(
            'length', lengths.align_as(hyp_log_probs)).squeeze('length')
Example #7
    def _extract_one_span(self, batch: ExtractBatch, extracted: Extracted,
                          word_repr: FT, unit_repr: FT,
                          char_log_probs: FT) -> Extracted:
        # Propose all span start/end positions.
        start_candidates = get_named_range(batch.max_length, 'len_s').align_to(
            'batch', 'len_s', 'len_e')
        # Range from `min_word_length` to `max_word_length`.
        len_candidates = get_named_range(
            g.max_word_length + 1 - g.min_word_length,
            'len_e') + g.min_word_length
        len_candidates = len_candidates.align_to('batch', 'len_s', 'len_e')
        # This is inclusive.
        end_candidates = start_candidates + len_candidates - 1

        # Only keep the viable/valid spans around.
        viable = (end_candidates < batch.lengths.align_as(end_candidates))
        start_candidates = start_candidates.expand_as(viable)
        len_candidates = len_candidates.expand_as(viable)
        # NOTE(j_luo) Use `viable` to get the lengths. `len_candidates` has dummy axes.
        # IDEA(j_luo) Any better way of handling this? Perhaps persistent names?
        len_s = viable.size('len_s')
        len_e = viable.size('len_e')
        bi = get_named_range(batch.batch_size, 'batch').expand_as(viable)
        with NoName(start_candidates, end_candidates, len_candidates, bi,
                    viable):
            viable_starts = start_candidates[viable].rename('viable')
            viable_lens = len_candidates[viable].rename('viable')
            viable_bi = bi[viable].rename('viable')

        # Get the word positions to get the corresponding representations.
        viable_starts = viable_starts.align_to('viable', 'len_w')
        word_pos_offsets = get_named_range(g.max_word_length,
                                           'len_w').align_as(viable_starts)
        word_pos = viable_starts + word_pos_offsets
        word_pos = word_pos.clamp(max=batch.max_length - 1)

        # Get the corresponding representations.
        nh = NameHelper()
        viable_bi = viable_bi.expand_as(word_pos)
        word_pos = nh.flatten(word_pos, ['viable', 'len_w'], 'viable_X_len_w')
        viable_bi = nh.flatten(viable_bi, ['viable', 'len_w'],
                               'viable_X_len_w')
        word_repr = word_repr.align_to('batch', 'length', 'char_emb')
        if g.input_format == 'text':
            with NoName(word_repr, viable_bi, word_pos, batch.unit_id_seqs):
                extracted_word_repr = word_repr[viable_bi, word_pos].rename(
                    'viable_X_len_w', 'char_emb')
                extracted_unit_ids = batch.unit_id_seqs[
                    viable_bi, word_pos].rename('viable_X_len_w')
        else:
            with NoName(word_repr, viable_bi, word_pos):
                extracted_word_repr = word_repr[viable_bi, word_pos].rename(
                    'viable_X_len_w', 'char_emb')
            extracted_unit_ids = None
        extracted_word_repr = nh.unflatten(extracted_word_repr,
                                           'viable_X_len_w',
                                           ['viable', 'len_w'])

        # Main body: Run DP to find the best matches.
        matches = self._get_matches(extracted_word_repr, unit_repr,
                                    viable_lens, extracted_unit_ids,
                                    char_log_probs)
        # Revert to the old shape (so that invalid spans are included).
        bi = get_named_range(batch.batch_size, 'batch').expand_as(viable)
        lsi = get_named_range(len_s, 'len_s').expand_as(viable)
        lei = get_named_range(len_e, 'len_e').expand_as(viable)
        vs = matches.ll.size('vocab')
        # IDEA(j_luo) NoName shouldn't make size() calls unavailable. Otherwise size() calls have to be moved outside the context. Also the names should be preserved as well.
        with NoName(bi, lsi, lei, viable, matches.ll):
            v_bi = bi[viable]
            v_lsi = lsi[viable]
            v_lei = lei[viable]
            all_ll = get_zeros(batch.batch_size, len_s, len_e, vs)
            all_ll = all_ll.float().fill_(-9999.9)
            all_ll[v_bi, v_lsi, v_lei] = matches.ll
            matches.ll = all_ll.rename('batch', 'len_s', 'len_e', 'vocab')

        new_extracted = Extracted(batch.batch_size, matches, viable,
                                  len_candidates)
        return new_extracted
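
The span-proposal step at the top of Example #7 enumerates every (start, length) pair, derives the inclusive end position, masks out spans that run past their sequence, and flattens the survivors into a single "viable" axis. A small plain-tensor sketch of just that step, with toy sizes (`min_len`, `max_len`, and `lengths` below are illustrative stand-ins for `g.min_word_length`, `g.max_word_length`, and `batch.lengths`):

import torch

batch, max_length = 2, 7
min_len, max_len = 2, 4
lengths = torch.tensor([7, 5])                        # true sequence lengths

starts = torch.arange(max_length)[None, :, None]                # candidate start positions
span_lens = torch.arange(min_len, max_len + 1)[None, None, :]   # candidate span lengths
ends = starts + span_lens - 1                                   # inclusive end positions

# A span is viable iff it stays inside its own sequence.
viable = ends < lengths[:, None, None]                # (batch, n_starts, n_lens)

# Flatten the viable spans into one axis, keeping their batch indices.
bi = torch.arange(batch)[:, None, None].expand_as(viable)
viable_starts = starts.expand_as(viable)[viable]      # (n_viable,)
viable_lens = span_lens.expand_as(viable)[viable]
viable_bi = bi[viable]
print(viable_starts, viable_lens, viable_bi)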
Example #8
    def forward(
            self, batch: Union[ContinuousIpaBatch,
                               IpaBatch]) -> DecipherModelReturn:
        # Get the samples of label sequences first.
        out = self.emb_for_label(batch.feat_matrix, batch.source_padding)

        positions = get_named_range(batch.feat_matrix.size('length'),
                                    name='length')
        pos_emb = self.positional_embedding(positions).align_as(out)
        out = out + pos_emb
        out = out.align_to('length', 'batch', 'char_emb')
        with NoName(out, batch.source_padding):
            for i, layer in enumerate(self.self_attn_layers):
                out = layer(out, src_key_padding_mask=batch.source_padding)
        state = out.refine_names('length', 'batch', ...)
        logits = self.label_predictor(state)
        label_log_probs = logits.log_softmax(dim='label')
        label_probs = label_log_probs.exp()

        # NOTE(j_luo) O is equivalent to None.
        mask = expand_as(batch.source_padding, label_probs)
        source = expand_as(
            get_tensor([0.0, 0.0, 1.0]).refine_names('label').float(),
            label_probs)
        label_probs = label_probs.rename(None).masked_scatter(
            mask.rename(None), source.rename(None))
        label_probs = label_probs.refine_names('length', 'batch', 'label')

        if not self.training or (g.supervised and not g.train_phi):
            probs = DecipherModelProbReturn(label_log_probs, None)
            return DecipherModelReturn(state, probs, None, None, None, None,
                                       None)

        # ------------------ More info during training ----------------- #

        # Get the lm score.
        gold_tag_seqs = batch.gold_tag_seqs if g.supervised and g.train_phi else None
        samples, sample_log_probs = self.searcher.search(
            batch.lengths, label_log_probs, gold_tag_seqs=gold_tag_seqs)
        probs = DecipherModelProbReturn(label_log_probs, sample_log_probs)

        packed_words, scores = self._get_scores(samples, batch.segments,
                                                batch.lengths,
                                                batch.feat_matrix,
                                                batch.source_padding)

        if g.supervised and g.train_phi:
            return DecipherModelReturn(state, probs, packed_words, None,
                                       scores, None, None)

        # ------------------- Contrastive estimation ------------------- #

        ptb_segments = list()
        duplicates = list()
        for segment in batch.segments:
            _ptb_segments, _duplicates = segment.perturb_n_times(g.n_times)
            # NOTE(j_luo) Ignore the first one.
            ptb_segments.extend(_ptb_segments[1:])
            duplicates.extend(_duplicates[1:])
        # ptb_segments = [segment.perturb_n_times(5) for segment in batch.segments]
        ptb_feat_matrix = [segment.feat_matrix for segment in ptb_segments]
        ptb_feat_matrix = torch.nn.utils.rnn.pad_sequence(ptb_feat_matrix,
                                                          batch_first=True)
        ptb_feat_matrix.rename_('batch', 'length', 'feat_group')
        samples = samples.align_to('batch', ...)
        with NoName(samples, batch.lengths, batch.source_padding):
            ptb_samples = torch.repeat_interleave(samples,
                                                  g.n_times * 2,
                                                  dim=0)
            ptb_lengths = torch.repeat_interleave(batch.lengths,
                                                  g.n_times * 2,
                                                  dim=0)
            ptb_source_padding = torch.repeat_interleave(batch.source_padding,
                                                         g.n_times * 2,
                                                         dim=0)
        ptb_samples.rename_(*samples.names)
        ptb_lengths.rename_('batch')
        ptb_source_padding.rename_('batch', 'length')

        ptb_packed_words, ptb_scores = self._get_scores(
            ptb_samples, ptb_segments, ptb_lengths, ptb_feat_matrix,
            ptb_source_padding)

        ret = DecipherModelReturn(state, probs, packed_words, ptb_packed_words,
                                  scores, ptb_scores, duplicates)
        return ret
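
For the contrastive-estimation branch in Example #8, the perturbed segments are padded into one batch while the original samples, lengths, and padding are tiled with `torch.repeat_interleave`, so that each original row lines up with its `2 * n_times` perturbed copies along the batch axis. A tiny plain-tensor sketch of that bookkeeping (sizes and the noise-based "perturbation" are illustrative):

import torch
from torch.nn.utils.rnn import pad_sequence

n_times = 3                                            # plays the role of g.n_times
samples = torch.randint(0, 5, (2, 4, 6))               # (batch, sample, length)
lengths = torch.tensor([6, 5])

# Perturbed copies, grouped per original segment, padded into one batch.
base_feats = [torch.randn(6, 7), torch.randn(5, 7)]    # one feature sequence per segment
ptb_feats = [f + 0.1 * torch.randn_like(f)
             for f in base_feats for _ in range(n_times * 2)]
ptb_feat_matrix = pad_sequence(ptb_feats, batch_first=True)        # (batch * n_times * 2, max_len, feat)

# Tile the originals so they align row-for-row with the perturbed batch.
ptb_samples = torch.repeat_interleave(samples, n_times * 2, dim=0)
ptb_lengths = torch.repeat_interleave(lengths, n_times * 2, dim=0)
print(ptb_feat_matrix.shape, ptb_samples.shape, ptb_lengths.shape)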