Beispiel #1
0
    def _get_word_embedding_from_chars(self, emb: FT) -> FT:
        """Get word embeddings by running ``self.cnn`` over character embeddings.

        ``emb`` is a named tensor with ``'word'``, ``'emb'`` and ``'pos'``
        dimensions and, when it is 4-D, an additional ``'batch'`` dimension.
        It is aligned so that ``'pos'`` is the last axis and fed to
        ``self.cnn``. The CNN output is then max-pooled over its last axis,
        leaving one vector per word.

        Returns:
            A named tensor with names ``('batch', 'word', 'emb')`` for 4-D
            input, or ``('word', 'emb')`` otherwise.
        """
        if emb.ndim == 4:
            emb = emb.align_to('batch', 'word', 'emb', 'pos')
            bs, ws, es, _ = emb.shape
            # Fold batch and word into one leading axis so the CNN receives a
            # 3-D input (presumably a Conv1d-style (N, C, L) layout -- confirm).
            # NOTE(j_luo) embedding size might not match hidden size.
            emb_3d = emb.rename(None).reshape(bs * ws, es, -1)
            ret = self.cnn(emb_3d).max(dim=-1)[0]
            # Use -1 for the feature axis: the CNN output width may differ from `es`.
            return ret.view(bs, ws, -1).rename('batch', 'word', 'emb')

        emb = emb.align_to('word', 'emb', 'pos')
        ret = self.cnn(emb.rename(None)).max(dim=-1)[0]
        return ret.rename('word', 'emb')
Beispiel #2
0
def gumbel_softmax(logits: FT,
                   temperature: float,
                   num_samples: Optional[int] = None) -> Tuple[FT, FT, LT, FT]:
    """Sample from the Gumbel-Softmax distribution and discretize the sample.

    Args:
        logits: named tensor with ``'batch'``, ``'length'`` and ``'label'``
            dimensions.
        temperature: forwarded to ``gumbel_softmax_sample``.
        num_samples: forwarded to ``gumbel_softmax_sample``. When given, the
            sample is expected to carry an extra ``'sample'`` dimension.

    Returns:
        A 4-tuple ``(y, y_one_hot, max_inds, seq_probs)``:

        * ``y``: the soft sample, aligned to ``('batch', 'length', 'label', ...)``.
        * ``y_one_hot``: straight-through discretization. Its forward value is
          1 where ``y`` equals its max over ``'label'`` and 0 elsewhere, while
          gradients flow through ``y``.
        * ``max_inds``: argmax label per position (and per sample, if any).
        * ``seq_probs``: product over ``'length'`` of the ``y_one_hot`` entries
          picked by ``max_inds``. It is computed in log space with a 1e-8
          offset.
    """
    logits = logits.align_to('batch', 'length', 'label')
    y = gumbel_softmax_sample(logits, temperature, num_samples)
    y = y.align_to('batch', 'length', 'label', ...)
    max_values, max_inds = y.max(dim='label')
    y_one_hot = (max_values.align_as(y) == y).float()
    # Straight-through estimator: forward value is the hard one-hot, backward
    # gradient is that of the soft sample `y`.
    y_one_hot = (y_one_hot - y).detach() + y
    bi = get_named_range(logits.size('batch'), 'batch').align_as(max_inds)
    li = get_named_range(logits.size('length'), 'length').align_as(max_inds)
    if num_samples is None:
        with NoName(max_inds, y_one_hot, bi, li):
            probs = y_one_hot[bi, li, max_inds]
        probs.rename_('batch', 'length')
    else:
        si = get_named_range(max_inds.size('sample'),
                             'sample').align_as(max_inds)
        with NoName(max_inds, y_one_hot, bi, li, si):
            probs = y_one_hot[bi, li, max_inds, si]
        probs.rename_('batch', 'length', 'sample')
    # Product over the sequence, done as exp(sum(log)) with an epsilon guard.
    seq_probs = (1e-8 + probs).log().sum(dim='length').exp()

    return y, y_one_hot, max_inds, seq_probs
Beispiel #3
0
 def extend(self, label_log_probs: FT):
     """Grow every hypothesis by one label and keep the ``g.beam_size`` best.

     The scores of the latest hypotheses are added to ``label_log_probs``
     for each (beam, label) pair. A single ``topk`` is then taken over all
     pairs jointly. For each kept candidate this appends:

     * the originating beam index to ``self.beam_ids``,
     * the chosen label to ``self.hyps``,
     * the accumulated score to ``self.hyp_log_probs``.

     Each appended tensor has a ``'beam'`` dimension. ``g`` is a config
     object from outside this block.
     """
     n_labels = label_log_probs.size('label')
     step_scores = label_log_probs.align_to('batch', 'beam', 'label')
     prev_scores = self.hyp_log_probs[-1].align_to('batch', 'beam', 'label')
     # Flatten (beam, label) with beam as the outer axis, so a flat index k
     # decodes to beam k // n_labels and label k % n_labels.
     joint_scores = (prev_scores + step_scores).flatten(['beam', 'label'],
                                                        'beam_X_label')
     best_scores, best_flat = torch.topk(joint_scores, g.beam_size,
                                         'beam_X_label')
     src_beam = best_flat // n_labels
     chosen_label = best_flat % n_labels
     self.beam_ids.append(src_beam.rename(beam_X_label='beam'))
     self.hyps.append(chosen_label.rename(beam_X_label='beam'))
     self.hyp_log_probs.append(best_scores.rename(beam_X_label='beam'))
Beispiel #4
0
 def search_by_probs(self, lengths: LT,
                     label_log_probs: FT) -> Tuple[LT, FT]:
     """Run beam search over per-position label log-probabilities.

     Args:
         lengths: per-sequence lengths, compared against the step index and
             aligned to each step's ``('batch', 'label')`` slice.
         label_log_probs: named tensor with ``'length'``, ``'batch'`` and
             ``'label'`` dimensions.

     Returns:
         ``(samples, sample_log_probs)`` taken from the finished ``Beam``,
         with its ``'beam'`` dimension renamed to ``'sample'``.
     """
     max_length = lengths.max().item()
     bs = label_log_probs.size('batch')
     label_log_probs = label_log_probs.align_to('length', 'batch', 'label')
     beam = Beam(bs)
     for step in range(max_length):
         step_log_probs = label_log_probs[step]
         # Steps past a sequence's end score every label with log-prob 0.
         # `masked_fill` is used instead of multiplying by a 0/1 mask: the
         # product turns any -inf entry into NaN (-inf * 0), which would
         # corrupt the beam scores.
         beyond_length = (step >= lengths).align_as(step_log_probs)
         beam.extend(step_log_probs.masked_fill(beyond_length, 0.0))
     beam.finish_search(lengths)
     samples = beam.samples.rename(beam='sample')
     sample_log_probs = beam.sample_log_probs.rename(beam='sample')
     return samples, sample_log_probs
Beispiel #5
0
    def _extract_one_span(self, batch: ExtractBatch, extracted: Extracted,
                          word_repr: FT, unit_repr: FT,
                          char_log_probs: FT) -> Extracted:
        """Score every viable span in the batch with ``self._get_matches``.

        Candidate spans are built as follows:

        * they start at any position in ``range(batch.max_length)``;
        * their length lies in ``[g.min_word_length, g.max_word_length]``;
        * only spans whose inclusive end is ``< batch.lengths`` are kept
          ("viable").

        For each viable span, ``g.max_word_length`` consecutive positions of
        ``word_repr`` are gathered starting at the span start. Positions are
        clamped to ``batch.max_length - 1``. The gathered positions are passed
        to ``self._get_matches`` with:

        * ``unit_repr``;
        * the span lengths;
        * the unit ids (only when ``g.input_format == 'text'``, else None);
        * ``char_log_probs``.

        ``matches.ll`` is then scattered back into a dense tensor named
        ``('batch', 'len_s', 'len_e', 'vocab')``. Non-viable spans are filled
        with -9999.9. Index ``i`` along ``'len_e'`` stands for a span length
        of ``g.min_word_length + i``, not an end position.

        NOTE(review): ``extracted`` is not read in this body.

        Returns:
            ``Extracted(batch.batch_size, matches, viable, len_candidates)``,
            where ``matches.ll`` has been replaced by the dense tensor above.
        """
        # Propose all span start/end positions.
        start_candidates = get_named_range(batch.max_length, 'len_s').align_to(
            'batch', 'len_s', 'len_e')
        # Range from `min_word_length` to `max_word_length` (both inclusive).
        len_candidates = get_named_range(
            g.max_word_length + 1 - g.min_word_length,
            'len_e') + g.min_word_length
        len_candidates = len_candidates.align_to('batch', 'len_s', 'len_e')
        # This is inclusive.
        end_candidates = start_candidates + len_candidates - 1

        # Only keep the viable/valid spans around.
        # A span is viable when its inclusive end lies before the sequence length.
        viable = (end_candidates < batch.lengths.align_as(end_candidates))
        start_candidates = start_candidates.expand_as(viable)
        len_candidates = len_candidates.expand_as(viable)
        # NOTE(j_luo) Use `viable` to get the lengths. `len_candidates` has dummy axes.
        # IDEA(j_luo) Any better way of handling this? Perhaps persistent names?
        len_s = viable.size('len_s')
        len_e = viable.size('len_e')
        bi = get_named_range(batch.batch_size, 'batch').expand_as(viable)
        # Boolean-mask indexing collects every (batch, len_s, len_e) entry where
        # `viable` is True into a single 1-D axis, which is renamed 'viable'.
        # NoName appears to strip tensor names inside the block -- confirm.
        with NoName(start_candidates, end_candidates, len_candidates, bi,
                    viable):
            viable_starts = start_candidates[viable].rename('viable')
            viable_lens = len_candidates[viable].rename('viable')
            viable_bi = bi[viable].rename('viable')

        # Get the word positions to get the corresponding representations.
        # Each viable span reads `g.max_word_length` consecutive positions from
        # its start, so shorter spans also pick up trailing positions. The
        # real span length travels separately as `viable_lens`. Positions past
        # the padded length are clamped to the last valid index.
        viable_starts = viable_starts.align_to('viable', 'len_w')
        word_pos_offsets = get_named_range(g.max_word_length,
                                           'len_w').align_as(viable_starts)
        word_pos = viable_starts + word_pos_offsets
        word_pos = word_pos.clamp(max=batch.max_length - 1)

        # Get the corresponding representations.
        # ('viable', 'len_w') is flattened so one advanced-indexing call can
        # gather from `word_repr`; `nh` is used to undo the flattening below.
        nh = NameHelper()
        viable_bi = viable_bi.expand_as(word_pos)
        word_pos = nh.flatten(word_pos, ['viable', 'len_w'], 'viable_X_len_w')
        viable_bi = nh.flatten(viable_bi, ['viable', 'len_w'],
                               'viable_X_len_w')
        word_repr = word_repr.align_to('batch', 'length', 'char_emb')
        # Unit ids are gathered only for text input; otherwise they stay None.
        if g.input_format == 'text':
            with NoName(word_repr, viable_bi, word_pos, batch.unit_id_seqs):
                extracted_word_repr = word_repr[viable_bi, word_pos].rename(
                    'viable_X_len_w', 'char_emb')
                extracted_unit_ids = batch.unit_id_seqs[
                    viable_bi, word_pos].rename('viable_X_len_w')
        else:
            with NoName(word_repr, viable_bi, word_pos):
                extracted_word_repr = word_repr[viable_bi, word_pos].rename(
                    'viable_X_len_w', 'char_emb')
            extracted_unit_ids = None
        extracted_word_repr = nh.unflatten(extracted_word_repr,
                                           'viable_X_len_w',
                                           ['viable', 'len_w'])

        # Main body: Run DP to find the best matches.
        matches = self._get_matches(extracted_word_repr, unit_repr,
                                    viable_lens, extracted_unit_ids,
                                    char_log_probs)
        # Revert to the old shape (so that invalid spans are included).
        # `matches.ll` has one row per viable span; it is scattered back into a
        # dense (batch, len_s, len_e, vocab) tensor. Entries for spans that
        # were never scored keep the -9999.9 filler.
        bi = get_named_range(batch.batch_size, 'batch').expand_as(viable)
        lsi = get_named_range(len_s, 'len_s').expand_as(viable)
        lei = get_named_range(len_e, 'len_e').expand_as(viable)
        vs = matches.ll.size('vocab')
        # IDEA(j_luo) NoName shouldn't make size() calls unavailable. Otherwise size() calls have to be moved outside the context. Also the names should be preserved as well.
        with NoName(bi, lsi, lei, viable, matches.ll):
            v_bi = bi[viable]
            v_lsi = lsi[viable]
            v_lei = lei[viable]
            all_ll = get_zeros(batch.batch_size, len_s, len_e, vs)
            all_ll = all_ll.float().fill_(-9999.9)
            all_ll[v_bi, v_lsi, v_lei] = matches.ll
            matches.ll = all_ll.rename('batch', 'len_s', 'len_e', 'vocab')

        new_extracted = Extracted(batch.batch_size, matches, viable,
                                  len_candidates)
        return new_extracted