Example #1
def gumbel_softmax(logits: FT,
                   temperature: float,
                   num_samples: Optional[int] = None) -> Tuple[FT, FT, LT]:
    """Sample from the Gumbel-Softmax distribution and optionally discretize."""
    logits = logits.align_to('batch', 'length', 'label')
    y = gumbel_softmax_sample(logits, temperature, num_samples)
    y = y.align_to('batch', 'length', 'label', ...)
    max_values, max_inds = y.max(dim='label')
    y_one_hot = (max_values.align_as(y) == y).float()
    y_one_hot = (y_one_hot - y).detach() + y
    bi = get_named_range(logits.size('batch'), 'batch').align_as(max_inds)
    li = get_named_range(logits.size('length'), 'length').align_as(max_inds)
    if num_samples is None:
        with NoName(max_inds, y_one_hot, bi, li):
            probs = y_one_hot[bi, li, max_inds]
        probs.rename_('batch', 'length')
    else:
        si = get_named_range(max_inds.size('sample'),
                             'sample').align_as(max_inds)
        with NoName(max_inds, y_one_hot, bi, li, si):
            probs = y_one_hot[bi, li, max_inds, si]
        probs.rename_('batch', 'length', 'sample')
    seq_probs = (1e-8 + probs).log().sum(dim='length').exp()

    return y, y_one_hot, max_inds, seq_probs
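
The key step above is the straight-through estimator on the `(y_one_hot - y).detach() + y` line: the forward pass uses the hard one-hot sample while gradients flow through the soft Gumbel-Softmax relaxation. A minimal, self-contained sketch of that trick in plain PyTorch (no named tensors or project helpers; all names below are illustrative):

import torch
import torch.nn.functional as F

def straight_through_gumbel_softmax(logits: torch.Tensor, temperature: float) -> torch.Tensor:
    """Hard one-hot samples in the forward pass, soft Gumbel-Softmax gradients in the backward pass."""
    # Gumbel(0, 1) noise via inverse transform sampling: -log(-log(U)), U ~ Uniform(0, 1).
    gumbel = -torch.log(-torch.log(torch.rand_like(logits) + 1e-20) + 1e-20)
    y_soft = F.softmax((logits + gumbel) / temperature, dim=-1)
    # Discretize, then re-attach the soft gradients.
    index = y_soft.argmax(dim=-1, keepdim=True)
    y_hard = torch.zeros_like(y_soft).scatter_(-1, index, 1.0)
    return (y_hard - y_soft).detach() + y_soft

logits = torch.randn(2, 5, requires_grad=True)   # (batch, label)
sample = straight_through_gumbel_softmax(logits, temperature=0.5)
sample.sum().backward()                          # gradients flow through y_soft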
Example #2
    def extend(self, label_log_probs: FT):
        num_labels = label_log_probs.size('label')
        label_log_probs = label_log_probs.align_to('batch', 'beam', 'label')
        new_hyp_log_probs = self.hyp_log_probs[-1].align_to(
            'batch', 'beam', 'label') + label_log_probs
        new_hyp_log_probs = new_hyp_log_probs.flatten(['beam', 'label'],
                                                      'beam_X_label')
        top_values, top_inds = torch.topk(new_hyp_log_probs, g.beam_size,
                                          'beam_X_label')
        # Recover which beam each survivor extends and which label it appends.
        beam_ids = top_inds // num_labels
        label_ids = top_inds % num_labels
        self.beam_ids.append(beam_ids.rename(beam_X_label='beam'))
        self.hyps.append(label_ids.rename(beam_X_label='beam'))
        self.hyp_log_probs.append(top_values.rename(beam_X_label='beam'))
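
The index arithmetic here flattens the `beam` and `label` axes into one `beam_X_label` axis, takes a single top-k over the product space, and then recovers the originating beam and label with integer division and modulo. A minimal sketch of that pattern with plain, unnamed tensors (shapes and sizes are made up):

import torch

batch, beam, num_labels, beam_size = 2, 3, 7, 3
hyp_log_probs = torch.randn(batch, beam)              # score of each current hypothesis
label_log_probs = torch.randn(batch, beam, num_labels)

# Every (hypothesis, label) extension competes in one flat top-k.
new_scores = (hyp_log_probs.unsqueeze(-1) + label_log_probs).reshape(batch, beam * num_labels)
top_values, top_inds = torch.topk(new_scores, beam_size, dim=-1)

beam_ids = top_inds // num_labels   # which hypothesis each survivor extends
label_ids = top_inds % num_labels   # which label was appended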
Example #3
    def _get_word_embedding_from_chars(self, emb: FT) -> FT:
        """Get word embeddings based on character embeddings."""
        if emb.ndim == 4:
            emb = emb.align_to('batch', 'word', 'emb', 'pos')
            bs, ws, es, l = emb.shape
            # NOTE(j_luo) embedding size might not match hidden size.
            emb_3d = emb.rename(None).reshape(bs * ws, es, -1)
            ret = self.cnn(emb_3d).max(dim=-1)[0]
            return ret.view(bs, ws, -1).rename('batch', 'word', 'emb')
        else:
            emb = emb.align_to('word', 'emb', 'pos')
            ret = self.cnn(emb.rename(None)).max(dim=-1)[0]
            return ret.rename('word', 'emb')
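
The 4-dimensional branch folds `batch` and `word` together so one shared `nn.Conv1d` can run over character positions, then max-pools over positions to get a fixed-size word embedding. A rough, self-contained sketch of the same fold-convolve-pool pattern (the convolution below is only a stand-in for `self.cnn`):

import torch
import torch.nn as nn

emb_size, hidden_size = 16, 32
cnn = nn.Conv1d(emb_size, hidden_size, kernel_size=3, padding=1)

emb = torch.randn(4, 5, emb_size, 10)        # (batch, word, emb, pos)
bs, ws, es, _ = emb.shape
emb_3d = emb.reshape(bs * ws, es, -1)        # fold batch and word for the shared conv
word_emb = cnn(emb_3d).max(dim=-1)[0]        # max-pool over character positions
word_emb = word_emb.view(bs, ws, -1)         # (batch, word, hidden)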
Example #4
    def search_by_probs(self, lengths: LT,
                        label_log_probs: FT) -> Tuple[LT, FT]:
        max_length = lengths.max().item()
        bs = label_log_probs.size('batch')
        label_log_probs = label_log_probs.align_to('length', 'batch', 'label')
        beam = Beam(bs)
        for step in range(max_length):
            __label_log_probs = label_log_probs[step]
            # Steps beyond a sequence's length contribute nothing to its score.
            within_length = (step < lengths).align_as(__label_log_probs)
            beam.extend(__label_log_probs * within_length.float())
        beam.finish_search(lengths)
        samples = beam.samples.rename(beam='sample')
        sample_log_probs = beam.sample_log_probs.rename(beam='sample')
        return samples, sample_log_probs
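
Each step multiplies the label log-probabilities by a `(step < lengths)` mask, so sequences that have already ended contribute zero to every further extension. A tiny illustration of that mask with plain tensors (values are made up):

import torch

lengths = torch.tensor([3, 1, 2])        # per-sequence lengths in a batch of 3
step = 1
within_length = (step < lengths)         # tensor([ True, False,  True])
label_log_probs = torch.randn(3, 4)      # (batch, label) at this step
# Finished sequences contribute 0 to every extension, so their scores stop changing.
masked = label_log_probs * within_length.unsqueeze(-1).float()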
Example #5
    def search_by_probs(self, lengths: LT,
                        label_log_probs: FT) -> Tuple[LT, FT]:
        max_length = lengths.max().item()
        # Enumerate every possible B/I/O tag sequence of length `max_length`.
        samples = get_tensor(
            torch.LongTensor(list(product([B, I, O], repeat=max_length))))
        samples.rename_('sample', 'length')
        bs = label_log_probs.size('batch')
        samples = samples.align_to('batch', 'sample',
                                   'length').expand(bs, -1, -1)
        sample_log_probs = label_log_probs.gather('label', samples)
        with NoName(lengths):
            length_mask = get_length_mask(lengths, max_length).rename(
                'batch', 'length')
        length_mask = length_mask.align_as(sample_log_probs)
        sample_log_probs = (sample_log_probs *
                            length_mask.float()).sum(dim='length')
        return samples, sample_log_probs
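
This exhaustive variant enumerates every B/I/O tag sequence up to `max_length` with `itertools.product`, gathers the per-position log-probabilities of each sequence, and masks out padded positions before summing. A rough equivalent with plain, unnamed tensors (label ids and sizes are arbitrary):

import torch
from itertools import product

B, I, O = 0, 1, 2
max_length, bs, num_labels = 3, 2, 3

samples = torch.LongTensor(list(product([B, I, O], repeat=max_length)))  # (sample, length)
samples = samples.unsqueeze(0).expand(bs, -1, -1)                        # (batch, sample, length)

label_log_probs = torch.randn(bs, max_length, num_labels).log_softmax(dim=-1)
# Per-position log-probs of each enumerated sequence: (batch, sample, length).
per_pos = label_log_probs.unsqueeze(1).expand(-1, samples.size(1), -1, -1).gather(
    -1, samples.unsqueeze(-1)).squeeze(-1)

lengths = torch.tensor([3, 2])
length_mask = torch.arange(max_length).unsqueeze(0) < lengths.unsqueeze(1)   # (batch, length)
sample_log_probs = (per_pos * length_mask.unsqueeze(1).float()).sum(dim=-1)  # (batch, sample)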
Example #6
    def search(self,
               sot_id: int,
               src_emb: FT,
               src_outputs: FT,
               src_paddings: BT,
               src_lengths: LT,
               beam_size: int,
               lang_emb: Optional[FT] = None) -> Hypotheses:
        if beam_size <= 0:
            raise ValueError(f'`beam_size` must be positive, but got {beam_size}.')

        batch_size = src_emb.size('batch')
        tokens = torch.full([batch_size, beam_size], sot_id,
                            dtype=torch.long).to(src_emb.device).rename(
                                'batch', 'beam')
        accum_scores = torch.full_like(tokens, -9999.9).float()
        accum_scores[:, 0] = 0.0
        init_att = None
        if g.input_feeding:
            init_att = get_zeros(batch_size, beam_size,
                                 g.hidden_size).rename('batch', 'beam',
                                                       'hidden')
        lstm_state = LstmStatesByLayers.zero_state(
            self.cell.num_layers,
            batch_size,
            beam_size,
            self.attn.input_tgt_size,
            bidirectional=False,
            names=['batch', 'beam', 'hidden'])

        def expand_beam(orig, collapse: bool = True):
            if collapse:
                return torch.repeat_interleave(orig, beam_size, dim='batch')
            else:
                return duplicate(orig, 'batch', beam_size, 'beam')

        src_emb = expand_beam(src_emb)
        src_outputs = expand_beam(src_outputs)
        src_paddings = expand_beam(src_paddings)
        max_lengths = (src_lengths.float() * 1.5).long()
        max_lengths = expand_beam(max_lengths, collapse=False)
        constants = BeamConstant(src_emb,
                                 src_outputs,
                                 src_paddings,
                                 max_lengths,
                                 lang_emb=lang_emb)
        init_beam = Beam(0,
                         accum_scores,
                         tokens,
                         lstm_state,
                         constants,
                         prev_att=init_att)
        hyps = super().search(init_beam)
        return hyps
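
`expand_beam` gives every beam hypothesis its own copy of the source-side tensors, either by repeating rows along `batch` (collapsing batch and beam into one axis) or by keeping a separate `beam` axis. A minimal sketch of both variants with plain tensors (`duplicate` in the original is a project helper; the `expand` call below is only a stand-in for it):

import torch

batch_size, beam_size, hidden = 2, 3, 4
src_outputs = torch.randn(batch_size, 7, hidden)   # (batch, src_len, hidden)

# "Collapse": every beam gets its own row, giving shape (batch * beam, src_len, hidden).
collapsed = torch.repeat_interleave(src_outputs, beam_size, dim=0)

# "Duplicate": keep a separate beam axis, giving shape (batch, beam, src_len, hidden).
duplicated = src_outputs.unsqueeze(1).expand(-1, beam_size, -1, -1)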
Example #7
def gumbel_softmax_sample(logits: FT,
                          temperature: float,
                          num_samples: Optional[int] = None) -> FT:
    """Draw a sample from the Gumbel-Softmax distribution"""
    new_names = logits.names
    shape = tuple(logits.shape)
    if num_samples is not None:
        new_names += ('sample', )
        shape += (num_samples, )
    noise = sample_gumbel(shape).rename(*new_names)
    y = logits.align_as(noise) + noise
    return (y / temperature).log_softmax(dim='label').exp()
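
`sample_gumbel` is a project helper; a common way to draw Gumbel(0, 1) noise is inverse transform sampling, `-log(-log(U))` with uniform `U`, clamped away from zero for stability. A minimal stand-in, plus how the optional trailing `sample` axis broadcasts against the logits (all shapes below are made up):

import torch

def sample_gumbel(shape, eps: float = 1e-20) -> torch.Tensor:
    """Gumbel(0, 1) noise via inverse transform sampling: -log(-log(U))."""
    u = torch.rand(shape)
    return -torch.log(-torch.log(u + eps) + eps)

logits = torch.randn(2, 4, 5)                          # (batch, length, label)
num_samples = 8
noise = sample_gumbel(tuple(logits.shape) + (num_samples,))  # trailing sample axis
y = logits.unsqueeze(-1) + noise                       # broadcast logits over samples
probs = (y / 0.5).log_softmax(dim=2).exp()             # softmax over the label axis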
Example #8
def get_ce_loss(log_probs: FT, batch: OnePairBatch, agg='all') -> FT:
    ce_losses = -log_probs.gather('unit', batch.tgt_seqs.ids)
    weights = batch.tgt_seqs.paddings.float()
    ce_losses = ce_losses * weights
    if agg == 'batch':
        return ce_losses.sum(dim='pos')
    elif agg == 'batch_mean':
        return ce_losses.sum(dim='pos') / weights.sum(dim='pos')
    elif agg == 'all':
        return ce_losses.sum()
    elif agg == 'char':
        return ce_losses
    elif agg == 'char_mean':
        return ce_losses.sum() / weights.sum()
    else:
        raise ValueError(f'Unrecognized value "{agg}" for agg.')
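The loss gathers each target token's log-probability, zeroes out padding, and then aggregates according to `agg`. A rough sketch of the same gather-mask-aggregate pattern with plain tensors (shapes, lengths, and the padding convention below are assumptions):

import torch

vocab, max_len, bs = 10, 5, 3
log_probs = torch.randn(max_len, bs, vocab).log_softmax(dim=-1)   # (pos, batch, unit)
tgt_ids = torch.randint(vocab, (max_len, bs))                     # (pos, batch)
tgt_lengths = torch.tensor([5, 3, 4])
non_pad = (torch.arange(max_len).unsqueeze(1) < tgt_lengths).float()  # 1 for real tokens, 0 for padding

ce = -log_probs.gather(-1, tgt_ids.unsqueeze(-1)).squeeze(-1)     # per-token negative log-likelihood
ce = ce * non_pad
loss_all = ce.sum()                                               # agg == 'all'
loss_batch = ce.sum(dim=0)                                        # agg == 'batch': one loss per sequence
loss_batch_mean = ce.sum(dim=0) / non_pad.sum(dim=0)              # agg == 'batch_mean'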
Example #9
    def search(self,
               lengths: LT,
               label_log_probs: FT,
               gold_tag_seqs: Optional[LT] = None) -> Tuple[LT, FT]:
        samples, sample_log_probs = self.search_by_probs(
            lengths, label_log_probs)
        if gold_tag_seqs is not None:
            gold_tag_seqs = gold_tag_seqs.align_as(samples)

            max_length = lengths.max().item()
            with NoName(lengths):
                length_mask = get_length_mask(lengths, max_length).rename(
                    'batch', 'length')
            gold_log_probs = label_log_probs.gather('label', gold_tag_seqs)
            gold_log_probs = (
                gold_log_probs *
                length_mask.align_as(gold_log_probs)).sum('length')

            samples = torch.cat([gold_tag_seqs, samples], dim='sample')
            sample_log_probs = torch.cat([gold_log_probs, sample_log_probs],
                                         dim='sample')
        return samples, sample_log_probs
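
When gold tag sequences are given, they are scored the same way as sampled sequences and prepended along the `sample` axis, so the gold path is treated as just another sample downstream. A tiny illustration of that concatenation with dummy tensors:

import torch

samples = torch.randint(3, (2, 5, 4))              # (batch, sample, length) from search_by_probs
sample_log_probs = torch.randn(2, 5)               # (batch, sample)
gold = torch.zeros(2, 1, 4, dtype=torch.long)      # one gold tag sequence per batch element
gold_log_probs = torch.randn(2, 1)

samples = torch.cat([gold, samples], dim=1)                             # gold path becomes sample 0
sample_log_probs = torch.cat([gold_log_probs, sample_log_probs], dim=1)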
Example #10
def get_beam_probs(scores: FT, duplicates: Optional[BT] = None):
    """Return normalized scores (approximated probabilities) for the entire beam."""
    if duplicates is not None:
        scores = scores.masked_fill(duplicates, float('-inf'))
    return (scores / g.concentration_scale).log_softmax(dim='beam').exp()
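
Scores are divided by `g.concentration_scale` (a temperature-like constant from the project config) and softmaxed over the beam; duplicate hypotheses are first masked to `-inf` so they receive zero probability. A small worked sketch with made-up values:

import torch

scores = torch.tensor([[-1.0, -2.0, -2.0, -5.0]])          # (batch, beam)
duplicates = torch.tensor([[False, False, True, False]])   # second copy of a hypothesis
concentration_scale = 0.5                                  # assumed temperature-like constant

scores = scores.masked_fill(duplicates, float('-inf'))
beam_probs = (scores / concentration_scale).softmax(dim=-1)  # duplicates get probability 0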
Example #11
    def _get_Wh_s(self, h_s: FT) -> FT:
        sl, bs, ds = h_s.size()
        # Fold (length, batch) into one axis so a single mm applies Wa to every step.
        with NoName(h_s):
            Wh_s = h_s.reshape(sl * bs, -1).mm(self.Wa).view(sl, bs, -1)
        return Wh_s
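
Folding `(length, batch)` into one axis lets a single `mm` apply the weight matrix to every time step; it is equivalent to a batched matrix product. A quick sketch checking that equivalence with plain tensors (sizes are arbitrary):

import torch

sl, bs, ds, da = 7, 3, 16, 8
h_s = torch.randn(sl, bs, ds)
Wa = torch.randn(ds, da)

Wh_s = h_s.reshape(sl * bs, -1).mm(Wa).view(sl, bs, -1)   # fold, multiply, unfold
assert torch.allclose(Wh_s, torch.einsum('sbd,da->sba', h_s, Wa), atol=1e-6)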
Example #12
def _compute_utility(logits: FT, sample_scores: FT) -> FT:
    sample_log_probs = logits.log_softmax(dim='sample')
    utility = (sample_log_probs.exp() * sample_scores).sum()
    return utility
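
The utility is the expected sample score under the distribution induced by the logits: softmax over the `sample` axis, weight each score by its probability, and sum. A tiny numeric sketch with made-up values:

import torch

logits = torch.tensor([[0.0, 1.0, 2.0]])          # (batch, sample)
sample_scores = torch.tensor([[0.1, 0.5, 0.9]])   # e.g. task scores per sample

sample_probs = logits.log_softmax(dim=-1).exp()   # softmax over samples
utility = (sample_probs * sample_scores).sum()    # expected score, differentiable w.r.t. logits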
Example #13
    def _get_matches(self, extracted_word_repr: FT, unit_repr: FT,
                     viable_lens: LT, extracted_unit_ids: LT,
                     char_log_probs: FT) -> Matches:
        ns = extracted_word_repr.size('viable')
        len_w = extracted_word_repr.size('len_w')
        nt = len(self.vocab_feat_matrix)
        msl = extracted_word_repr.size('len_w')
        mtl = self.vocab_feat_matrix.size('length')

        # Compute cosine distances all at once: for each viable span, compare it against all units.
        ctx_logits = extracted_word_repr @ unit_repr.t()
        ctx_log_probs = ctx_logits.log_softmax(dim='unit').flatten(
            ['viable', 'len_w'], 'viable_X_len_w')
        with NoName(char_log_probs, extracted_unit_ids):
            global_log_probs = char_log_probs[extracted_unit_ids].rename(
                'viable_X_len_w', 'unit')
        weighted_log_probs = g.context_weight * ctx_log_probs + (
            1.0 - g.context_weight) * global_log_probs
        costs = -weighted_log_probs

        # Name: viable x len_w x unit
        costs = costs.unflatten('viable_X_len_w', [('viable', ns),
                                                   ('len_w', len_w)])

        # NOTE(j_luo) Use dictionary to save every state.
        fs = dict()
        for i in range(msl + 1):
            fs[(i, 0)] = get_zeros(ns, nt).fill_(i * self.ins_del_cost)
        for j in range(mtl + 1):
            fs[(0, j)] = get_zeros(ns, nt).fill_(j * self.ins_del_cost)

        # ------------------------ Main body: DP ----------------------- #

        # Transition.
        with NoName(self.indexed_segments, costs):
            for ls in range(1, msl + 1):
                min_lt = max(ls - 2, 1)
                max_lt = min(ls + 2, mtl + 1)
                for lt in range(min_lt, max_lt):
                    transitions = list()
                    if (ls - 1, lt) in fs:
                        transitions.append(fs[(ls - 1, lt)] +
                                           self.ins_del_cost)
                    if (ls, lt - 1) in fs:
                        transitions.append(fs[(ls, lt - 1)] +
                                           self.ins_del_cost)
                    if (ls - 1, lt - 1) in fs:
                        vocab_inds = self.indexed_segments[:, lt - 1]
                        sub_cost = costs[:, ls - 1, vocab_inds]
                        transitions.append(fs[(ls - 1, lt - 1)] + sub_cost)
                    if transitions:
                        all_s = torch.stack(transitions, dim=-1)
                        new_s, _ = all_s.min(dim=-1)
                        fs[(ls, lt)] = new_s

        f_lst = list()
        for i in range(msl + 1):
            for j in range(mtl + 1):
                if (i, j) not in fs:
                    fs[(i, j)] = get_zeros(ns, nt).fill_(9999.9)
                f_lst.append(fs[(i, j)])
        f = torch.stack(f_lst, dim=0).view(msl + 1, mtl + 1, -1,
                                           len(self.vocab))
        f.rename_('len_w_src', 'len_w_tgt', 'viable', 'vocab')

        # Get the values wanted.
        with NoName(f, viable_lens, self.vocab_length):
            idx_src = viable_lens.unsqueeze(dim=-1)
            idx_tgt = self.vocab_length
            viable_i = get_range(ns, 2, 0)
            vocab_i = get_range(len(self.vocab_length), 2, 1)
            nll = f[idx_src, idx_tgt, viable_i, vocab_i]
            nll.rename_('viable', 'vocab')

        # Get the best spans.
        matches = Matches(-nll, f)
        return matches
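
The dictionary `fs` implements a Levenshtein-style dynamic program, run in parallel for every (viable span, vocabulary word) pair and restricted to a narrow diagonal band: `fs[(i, j)]` holds the minimum cost of aligning the first i span characters with the first j word characters, with insertions/deletions costing `ins_del_cost` and substitutions costing the soft `costs` above. A scalar version of the same recurrence, for a single string pair and 0/1 substitution costs, may make the structure clearer:

def edit_distance(src, tgt, ins_del_cost=1.0, sub_cost_fn=lambda a, b: 0.0 if a == b else 1.0):
    """Plain Levenshtein-style DP; the batched version keeps one such table per (span, word) pair."""
    msl, mtl = len(src), len(tgt)
    fs = {(i, 0): i * ins_del_cost for i in range(msl + 1)}
    fs.update({(0, j): j * ins_del_cost for j in range(mtl + 1)})
    for i in range(1, msl + 1):
        for j in range(1, mtl + 1):
            fs[(i, j)] = min(fs[(i - 1, j)] + ins_del_cost,          # delete from src
                             fs[(i, j - 1)] + ins_del_cost,          # insert into src
                             fs[(i - 1, j - 1)] + sub_cost_fn(src[i - 1], tgt[j - 1]))  # substitute
    return fs[(msl, mtl)]

assert edit_distance('kitten', 'sitting') == 3.0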
Example #14
    def _extract_one_span(self, batch: ExtractBatch, extracted: Extracted,
                          word_repr: FT, unit_repr: FT,
                          char_log_probs: FT) -> Extracted:
        # Propose all span start/end positions.
        start_candidates = get_named_range(batch.max_length, 'len_s').align_to(
            'batch', 'len_s', 'len_e')
        # Range from `min_word_length` to `max_word_length`.
        len_candidates = get_named_range(
            g.max_word_length + 1 - g.min_word_length,
            'len_e') + g.min_word_length
        len_candidates = len_candidates.align_to('batch', 'len_s', 'len_e')
        # This is inclusive.
        end_candidates = start_candidates + len_candidates - 1

        # Only keep the viable/valid spans around.
        viable = (end_candidates < batch.lengths.align_as(end_candidates))
        start_candidates = start_candidates.expand_as(viable)
        len_candidates = len_candidates.expand_as(viable)
        # NOTE(j_luo) Use `viable` to get the lengths. `len_candidates` has dummy axes.
        # IDEA(j_luo) Any better way of handling this? Perhaps persistent names?
        len_s = viable.size('len_s')
        len_e = viable.size('len_e')
        bi = get_named_range(batch.batch_size, 'batch').expand_as(viable)
        with NoName(start_candidates, end_candidates, len_candidates, bi,
                    viable):
            viable_starts = start_candidates[viable].rename('viable')
            viable_lens = len_candidates[viable].rename('viable')
            viable_bi = bi[viable].rename('viable')

        # Get the word positions to get the corresponding representations.
        viable_starts = viable_starts.align_to('viable', 'len_w')
        word_pos_offsets = get_named_range(g.max_word_length,
                                           'len_w').align_as(viable_starts)
        word_pos = viable_starts + word_pos_offsets
        word_pos = word_pos.clamp(max=batch.max_length - 1)

        # Get the corresponding representations.
        nh = NameHelper()
        viable_bi = viable_bi.expand_as(word_pos)
        word_pos = nh.flatten(word_pos, ['viable', 'len_w'], 'viable_X_len_w')
        viable_bi = nh.flatten(viable_bi, ['viable', 'len_w'],
                               'viable_X_len_w')
        word_repr = word_repr.align_to('batch', 'length', 'char_emb')
        if g.input_format == 'text':
            with NoName(word_repr, viable_bi, word_pos, batch.unit_id_seqs):
                extracted_word_repr = word_repr[viable_bi, word_pos].rename(
                    'viable_X_len_w', 'char_emb')
                extracted_unit_ids = batch.unit_id_seqs[
                    viable_bi, word_pos].rename('viable_X_len_w')
        else:
            with NoName(word_repr, viable_bi, word_pos):
                extracted_word_repr = word_repr[viable_bi, word_pos].rename(
                    'viable_X_len_w', 'char_emb')
            extracted_unit_ids = None
        extracted_word_repr = nh.unflatten(extracted_word_repr,
                                           'viable_X_len_w',
                                           ['viable', 'len_w'])

        # Main body: Run DP to find the best matches.
        matches = self._get_matches(extracted_word_repr, unit_repr,
                                    viable_lens, extracted_unit_ids,
                                    char_log_probs)
        # Revert to the old shape (so that invalid spans are included).
        bi = get_named_range(batch.batch_size, 'batch').expand_as(viable)
        lsi = get_named_range(len_s, 'len_s').expand_as(viable)
        lei = get_named_range(len_e, 'len_e').expand_as(viable)
        vs = matches.ll.size('vocab')
        # IDEA(j_luo) NoName shouldn't make size() calls unavailable. Otherwise size() calls have to be moved outside the context. Also the names should be preserved as well.
        with NoName(bi, lsi, lei, viable, matches.ll):
            v_bi = bi[viable]
            v_lsi = lsi[viable]
            v_lei = lei[viable]
            all_ll = get_zeros(batch.batch_size, len_s, len_e, vs)
            all_ll = all_ll.float().fill_(-9999.9)
            all_ll[v_bi, v_lsi, v_lei] = matches.ll
            matches.ll = all_ll.rename('batch', 'len_s', 'len_e', 'vocab')

        new_extracted = Extracted(batch.batch_size, matches, viable,
                                  len_candidates)
        return new_extracted
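
Span proposal crosses every start position with every candidate length (`min_word_length` to `max_word_length`), computes inclusive end positions, and keeps only spans that fit inside each sequence; the survivors are then flattened onto a single `viable` axis. A rough sketch of that enumeration with plain tensors (the config values below are made up):

import torch

max_length = 6
min_word_length, max_word_length = 2, 4
lengths = torch.tensor([6, 4])                                    # true lengths of a batch of 2 sequences

starts = torch.arange(max_length).view(1, -1, 1)                  # broadcast as (batch, len_s, len_e)
lens = (torch.arange(max_word_length + 1 - min_word_length)
        + min_word_length).view(1, 1, -1)
ends = starts + lens - 1                                          # inclusive end positions

viable = ends < lengths.view(-1, 1, 1)                            # (batch, len_s, len_e) validity mask
viable_starts = starts.expand_as(viable)[viable]                  # flat "viable" axis, as in the original
viable_lens = lens.expand_as(viable)[viable]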