Example #1
    def _forward_step(self,
                      input_: LT,
                      src_emb: FT,
                      state: LstmStatesByLayers,
                      src_states: FT,
                      mask_src: BT,
                      lang_emb: Optional[FT] = None,
                      prev_att: Optional[FT] = None) -> Tuple[LstmStatesByLayers, FT, FT, FT]:
        emb = self.char_emb(input_)
        if lang_emb is not None:
            emb = emb + lang_emb
        inp = torch.cat([emb, prev_att], dim=-1) if g.input_feeding else emb
        hid_rnn, next_state = self.cell(
            inp, state)  # hid_rnn has gone through dropout already.
        almt, ctx = self.attn.forward(hid_rnn, src_states,
                                      mask_src)  # So has src_states.
        with NoName(hid_rnn, ctx):
            cat = torch.cat([hid_rnn, ctx], dim=-1)
        hid_cat = self.hidden(cat)
        hid_cat = self.drop(hid_cat)

        with NoName(src_emb, hid_cat, almt):
            ctx_emb = (src_emb * almt.t().unsqueeze(dim=-1)).sum(dim=0)
            hid_res = self.nc_residual(ctx_emb,
                                       hid_cat).rename('batch', 'hidden')

        logit = self.char_emb.project(hid_res)
        log_prob = logit.log_softmax(dim=-1).refine_names('batch', 'unit')

        return next_state, log_prob, almt, hid_res
Example #2
def gumbel_softmax(logits: FT,
                   temperature: float,
                   num_samples: Optional[int] = None) -> Tuple[FT, FT, LT, FT]:
    """Sample from the Gumbel-Softmax distribution and optionally discretize."""
    logits = logits.align_to('batch', 'length', 'label')
    y = gumbel_softmax_sample(logits, temperature, num_samples)
    y = y.align_to('batch', 'length', 'label', ...)
    max_values, max_inds = y.max(dim='label')
    y_one_hot = (max_values.align_as(y) == y).float()
    y_one_hot = (y_one_hot - y).detach() + y
    bi = get_named_range(logits.size('batch'), 'batch').align_as(max_inds)
    li = get_named_range(logits.size('length'), 'length').align_as(max_inds)
    if num_samples is None:
        with NoName(max_inds, y_one_hot, bi, li):
            probs = y_one_hot[bi, li, max_inds]
        probs.rename_('batch', 'length')
    else:
        si = get_named_range(max_inds.size('sample'),
                             'sample').align_as(max_inds)
        with NoName(max_inds, y_one_hot, bi, li, si):
            probs = y_one_hot[bi, li, max_inds, si]
        probs.rename_('batch', 'length', 'sample')
    seq_probs = (1e-8 + probs).log().sum(dim='length').exp()

    return y, y_one_hot, max_inds, seq_probs
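# Usage sketch (hypothetical; assumes `gumbel_softmax_sample` and the
# named-tensor helpers defined elsewhere in this repo):
import torch

logits = torch.randn(8, 10, 3).refine_names('batch', 'length', 'label')
y, y_one_hot, max_inds, seq_probs = gumbel_softmax(logits, temperature=0.1)
# The straight-through line `(y_one_hot - y).detach() + y` above keeps the
# forward pass discrete while gradients flow through the soft sample `y`.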
Example #3
    def forward(self, h_t: FT, h_s: FT, mask_src: BT) -> Tuple[FT, FT]:
        dt = h_t.shape[-1]
        Wh_s = self._get_Wh_s(h_s)

        with NoName(h_t):
            scores = (Wh_s * h_t).sum(dim=-1)

        scores = torch.where(mask_src, scores,
                             torch.full_like(scores, -9999.9))
        almt_distr = nn.functional.log_softmax(scores, dim=0).exp()  # sl x bs
        with NoName(almt_distr):
            ctx = (almt_distr.unsqueeze(dim=-1) * h_s).sum(dim=0)  # bs x d
        almt_distr = almt_distr.t()
        return almt_distr, ctx
Example #4
    def forward(self, input_: LT, lengths: LT) -> Tuple[FT, LstmOutputTuple]:
        # input_: seq_length x batch_size
        # note that input_size == hidden_size
        # define n_conv as the number of parallel convolutional layers
        emb = self.embedding(input_)  # seq_length x batch_size x input_size

        with NoName(emb, lengths):
            reshaped_emb = emb.permute(
                1, 2, 0
            )  # reshape to batch_size x input_size x seq_length for CNN input
            conv_outputs = [
                self.dropout(F.relu(conv(reshaped_emb)))
                for conv in self.conv_layers
            ]
            # each conv layer's output is batch_size x hidden_size x seq_length

            # stack the CNN outputs on the hidden_size dimension
            x = torch.cat(
                conv_outputs,
                dim=1)  # batch_size x n_conv*hidden_size x seq_length
            x = x.permute(2, 0,
                          1)  # seq_length x batch_size x n_conv*hidden_size

            # Project the concatenated convolutional outputs into 2*hidden_size
            # dimensions so that `output` looks like the states of a bidirectional LSTM.
            output = self.W_output(
                x)  # seq_length x batch_size x 2*hidden_size
            # we don't try to reconstruct the state, so we just pass (None, None)
        return emb, (output, (None, None))
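# Shape contract (sketch): this CNN encoder appears to be a drop-in replacement
# for the LSTM encoder in Example #17, so the output shapes match:
#   input_:  seq_length x batch_size                  (LongTensor ids)
#   emb:     seq_length x batch_size x input_size
#   output:  seq_length x batch_size x 2*hidden_size, with a (None, None) state.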
Example #5
 def forward(self,
             curr_ids: LT,
             end_ids: LT,
             steps: Optional[LT] = None,
             done: Optional[BT] = None) -> FT:
     """Get policy evaluation. if `done` is provided, we get values for s1 instead of s0.
     In that case, end states should have values set to 0.
     `step` should start with 0.
     """
     state_repr = self.enc(curr_ids, end_ids)
     # NOTE(j_luo) If s1 is being evaluated, we should increment `steps`.
     if done is not None and g.use_finite_horizon:
         steps = steps + 1
     with NoName(state_repr, steps):
         if g.use_finite_horizon:
             rel_step = steps.float() / g.max_rollout_length
             state_repr = torch.cat(
                 [state_repr, rel_step.unsqueeze(dim=-1)], dim=-1)
         values = self.regressor(state_repr).squeeze(dim=-1)
     # Deal with special cases. We start with final step case, and then overwrite it if done.
     if g.use_finite_horizon:
         final_step = steps == g.max_rollout_length
         values = torch.where(final_step, torch.zeros_like(values), values)
     if done is not None:
         # NOTE(j_luo) Use final reward for the value of the end state.
         values = torch.where(done, torch.full_like(values, g.final_reward),
                              values)
     return values
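# Usage sketch (hypothetical name `critic` for an instance of this module):
#   v_s0 = critic(curr_ids, end_ids, steps=steps)
#   v_s1 = critic(curr_ids, end_ids, steps=steps, done=done)
# For s1, `steps` is incremented internally under a finite horizon; states at
# g.max_rollout_length are zeroed first, and `done` states are then overwritten
# with g.final_reward, so the `done` case takes precedence.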
Example #6
def _restore_shape(tensor,
                   bi,
                   lsi,
                   lei,
                   viable,
                   value: Optional[float] = None):
    bs = bi.size('batch')
    len_s = lsi.size('len_s')
    len_e = lei.size('len_e')

    shape = (bs, len_s, len_e)
    names = ('batch', 'len_s', 'len_e')
    if tensor.ndim > 1:
        shape += tensor.shape[1:]
        names += tensor.names[1:]

    with NoName(bi, lsi, lei, viable, tensor):
        v_bi = bi[viable]
        v_lsi = lsi[viable]
        v_lei = lei[viable]
        ret = get_zeros(*shape).to(tensor.dtype)
        if value is not None:
            ret.fill_(value)
        ret[v_bi, v_lsi, v_lei] = tensor

    ret.rename_(*names)
    return ret
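# Usage sketch (hypothetical; mirrors the caller pattern in Example #27, with
# `get_named_range`, `get_zeros` and `NoName` assumed from this repo):
import torch

viable = torch.tensor([[[True, False], [False, True]]]).rename('batch', 'len_s', 'len_e')
bi = get_named_range(1, 'batch').expand_as(viable)
lsi = get_named_range(2, 'len_s').expand_as(viable)
lei = get_named_range(2, 'len_e').expand_as(viable)
dense = _restore_shape(torch.tensor([1.0, 2.0]), bi, lsi, lei, viable, value=-9999.9)
# `dense` is 1 x 2 x 2 named ('batch', 'len_s', 'len_e'); the two viable cells
# hold 1.0 and 2.0 and every other cell is filled with -9999.9.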
Example #7
    def _analyze_unsupervised(self, model_ret: DecipherModelReturn,
                              batch: ContinuousIpaBatch) -> Metrics:
        metrics = Metrics()
        # TODO(j_luo) Check the sample scores for hyps that are dummies (i.e., the length of the segment is too small to get beam_size hyps).
        is_unique = model_ret.packed_words.is_unique
        modified_logits = model_ret.probs.sample_log_probs * g.concentration + (
            ~is_unique).float() * (-999.9)
        sample_scores = model_ret.scores.phi_score
        ptb_sample_scores = model_ret.ptb_scores.phi_score
        duplicates = model_ret.duplicates
        with NoName(ptb_sample_scores):
            ptb_sample_scores[duplicates] = -999.9
        bs = sample_scores.size('batch')
        ptb_sample_scores = ptb_sample_scores.unflatten(
            'batch', [('batch', bs), ('contrast', g.n_times * 2)])
        sample_scores = sample_scores.align_as(ptb_sample_scores)
        all_scores = torch.cat([sample_scores, ptb_sample_scores],
                               dim='contrast')
        all_probs = all_scores.log_softmax(dim='contrast').exp()
        sample_probs = all_probs.align_to(..., 'contrast')[..., 0]
        utility = _compute_utility(modified_logits, sample_probs)
        total_loss = Metric('total_loss', -utility, batch.batch_size)
        metrics += total_loss

        return metrics
Example #8
 def retrieve(tensor, last_name: str = 'hidden') -> torch.Tensor:
     with NoName(tensor, batch_i, beam_i):
         ret = tensor[batch_i, beam_i]
     new_names = ('batch', 'beam')
     if last_name:
         new_names += (last_name, )
     return ret.refine_names(*new_names)
Example #9
    def forward(self,
                dense_feat_matrices: Dict[Category, FT],
                padding: Optional[BT] = None,
                masked_positions: Optional[LT] = None) -> FT:
        if padding is not None:
            padding = padding.align_to('batch', 'length')

        embs = list()
        for cat in Category:
            if cat.name in self.embed_layer and cat in dense_feat_matrices:
                sfm = dense_feat_matrices[cat]
                emb_param = self.embed_layer[cat.name]
                sfm = sfm.align_to('batch', 'length', ...)
                emb = sfm @ emb_param
                if padding is not None:
                    emb.rename(None)[padding.rename(None)] = 0.0
                embs.append(emb)
        feat_emb = torch.cat(embs,
                             dim=-1).refine_names('batch', 'length',
                                                  self.char_emb_name)

        if masked_positions is not None:
            # `padding` may be None here, so take the batch size from `feat_emb`.
            batch_i = get_range(feat_emb.size('batch'), 1, 0)
            feat_emb = feat_emb.align_to('batch', 'char_emb', 'length')
            # feat_emb = self.feat_embeddings(feat_matrix).view(bs, l, -1).transpose(1, 2)  # size: bs x D x l
            with NoName(feat_emb, masked_positions):
                feat_emb[batch_i, :, masked_positions] = 0.0
        return feat_emb
Example #10
    def forward(self,
                sot_id: int,
                src_emb: FT,
                src_outputs: FT,
                mask_src: BT,
                max_length: Optional[int] = None,
                target: Optional[LT] = None,
                lang_emb: Optional[FT] = None) -> Tuple[FT, FT]:
        # Prepare inputs.
        max_length = self._get_max_length(max_length, target)
        batch_size = mask_src.size('batch')
        input_ = self._prepare_first_input(sot_id, batch_size, mask_src.device)
        prev_att = get_zeros(batch_size,
                             g.hidden_size) if g.input_feeding else None
        state = LstmStatesByLayers.zero_state(self.cell.num_layers,
                                              batch_size,
                                              self.attn.input_tgt_size,
                                              bidirectional=False)

        # Main loop.
        log_probs = list()
        almt_distrs = list()
        with ScopedCache('Wh_s'):
            for l in range(max_length):
                state, log_prob, almt_distr, prev_att = self._forward_step(
                    input_,
                    src_emb,
                    state,
                    src_outputs,
                    mask_src,
                    lang_emb=lang_emb,
                    prev_att=prev_att)
                if target is None:
                    input_ = log_prob.max(dim=-1)[1].rename('batch')
                else:
                    input_ = target[l]

                log_probs.append(log_prob)
                almt_distrs.append(almt_distr)

        # Prepare outputs.
        with NoName(*log_probs), NoName(*almt_distrs):
            log_probs = torch.stack(log_probs).rename('pos', 'batch', 'unit')
            almt_distrs = torch.stack(almt_distrs).rename(
                'tgt_pos', 'batch', 'src_pos')
        return log_probs, almt_distrs
Example #11
def _stack_beam(lst: List[torch.Tensor], last_name=None):
    new_names = ('batch', 'beam', 'pos')
    if last_name:
        new_names += (last_name, )
    with NoName(*lst):
        # NOTE(j_luo) Set dim = 2 instead of -1 since some tensors might have an extra dimension.
        ret = torch.stack(lst, dim=2).refine_names(*new_names)
    return ret
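# Usage sketch (hypothetical sizes): stack three decoding steps with batch=2,
# beam=4 (`NoName` is assumed from this repo).
import torch

steps = [torch.zeros(2, 4).refine_names('batch', 'beam') for _ in range(3)]
stacked = _stack_beam(steps)  # 2 x 4 x 3, named ('batch', 'beam', 'pos')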
Example #12
    def forward(self, ku_id_seqs: LT, lu_repr: FT) -> Tuple[FT, FT]:
        """Returns lu x ku representation and bs x l x ku representation."""
        ku_char_weight = self.unit_aligner.weight
        ku_char_repr = ku_char_weight @ lu_repr

        ku_char_repr = ku_char_repr.refine_names('ku_char_emb', 'char_emb')
        with NoName(ku_char_repr, ku_id_seqs):
            _ku_repr = ku_char_repr[ku_id_seqs].rename('batch', 'length',
                                                       'char_emb')
        _ku_repr = _ku_repr.align_to('batch', 'char_emb', ...)
        with NoName(_ku_repr):
            ku_ctx_repr = self.conv(_ku_repr).rename('batch', 'char_emb',
                                                     'length')
        ku_ctx_repr = ku_ctx_repr.align_to(..., 'char_emb')
        ku_ctx_repr = self.dropout(ku_ctx_repr)

        return ku_char_repr, ku_ctx_repr
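# Shape sketch (sizes inferred from the matmuls above): with `lu_repr` of size
# num_lu x char_emb and `unit_aligner.weight` of size num_ku x num_lu,
# `ku_char_repr` comes out num_ku x char_emb, and `ku_ctx_repr` is
# batch x length x char_emb after the convolution over the length dimension.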
Example #13
    def evaluate(self,
                 states,
                 steps: Optional[Union[int, LT]] = None) -> List[float]:
        """Expand and evaluate the leaf node."""
        values = [None] * len(states)
        outstanding_idx = list()
        outstanding_states = list()
        # Deal with end states first.
        for i, state in enumerate(states):
            if state.stopped or state.done:
                # NOTE(j_luo) This value is used for backup. If already reaching the end state, the final reward is either accounted for by the step reward, or by the value network. Therefore, we need to set it to 0.0 here.
                values[i] = 0.0
            else:
                outstanding_idx.append(i)
                outstanding_states.append(state)

        # Collect states that need evaluation.
        if outstanding_states:
            almts1 = almts2 = None
            if g.use_alignment:
                id_seqs, almts1, almts2 = parallel_stack_ids(
                    outstanding_states, g.num_workers, True,
                    self.env.max_end_length)
                almts1 = get_tensor(almts1).rename('batch', 'word', 'pos')
                almts2 = get_tensor(almts2).rename('batch', 'word', 'pos')
            else:
                id_seqs = parallel_stack_ids(outstanding_states, g.num_workers,
                                             False, self.env.max_end_length)
            id_seqs = get_tensor(id_seqs).rename('batch', 'word', 'pos')
            if steps is not None and not isinstance(steps, int):
                steps = steps[outstanding_idx]

            # TODO(j_luo) Scoped might be wrong here.
            # with ScopedCache('state_repr'):
            # NOTE(j_luo) Don't forget to call exp().
            priors = self.agent.get_policy(id_seqs,
                                           almts=(almts1, almts2)).exp()
            with NoName(priors):
                meta_priors = priors[:, [0, 2, 3, 4, 5, 6]].cpu().numpy()
                special_priors = priors[:, 1].cpu().numpy()
            if g.use_value_guidance:
                agent_values = self.agent.get_values(
                    id_seqs, steps=steps).cpu().numpy()
            else:
                agent_values = np.zeros([len(id_seqs)], dtype='float32')

            for i, state, mp, sp, v in zip(outstanding_idx, outstanding_states,
                                           meta_priors, special_priors,
                                           agent_values):
                # NOTE(j_luo) Values should be returned even if states are duplicates or have been visited.
                values[i] = v
                # NOTE(j_luo) Skip duplicate states (due to exploration collapse) or visited states (due to rollout truncation).
                if not state.is_leaf():
                    continue

                self.env.evaluate(state, mp, sp)
        return values
Example #14
    def get_scores(self,
                   batch: OnePairBatch,
                   tgt_vocab_seqs: PaddedUnitSeqs,
                   chunk_size: int = 100) -> FT:
        """Given a batch and a list of target tokens (provided as id sequences), return scores produced by the model."""
        src_emb, (output, state) = self.encoder(batch.src_seqs.ids,
                                                batch.src_seqs.lengths)
        src_emb = src_emb.refine_names('pos', 'batch', 'src_emb')
        output = output.refine_names('pos', 'batch', 'output')
        batch_size = src_emb.size('batch')
        lang_emb = self._prepare_lang_emb(batch)

        def create_chunk(size, base, old_chunk, interleave: bool = True):
            if not interleave:
                return base.repeat(1, batch_size)

            if old_chunk is not None and old_chunk.size(
                    'batch') == batch_size * size:
                return old_chunk

            new_chunk = torch.repeat_interleave(base, size, dim='batch')
            return new_chunk

        chunk_src_emb = None
        chunk_output = None
        chunk_src_paddings = None
        scores = list()
        for split in pbar(tgt_vocab_seqs.split(chunk_size),
                          desc='Get scores: chunk'):
            split: PaddedUnitSeqs
            bs_split = len(split)
            chunk_src_emb = create_chunk(bs_split, src_emb, chunk_src_emb)
            chunk_output = create_chunk(bs_split, output, chunk_output)
            chunk_src_paddings = create_chunk(bs_split,
                                              batch.src_seqs.paddings,
                                              chunk_src_paddings)
            chunk_target = create_chunk(None,
                                        split.ids,
                                        None,
                                        interleave=False)
            chunk_tgt_paddings = create_chunk(None,
                                              split.paddings,
                                              None,
                                              interleave=False)
            chunk_log_probs, _ = self.decoder(SOT_ID,
                                              chunk_src_emb,
                                              chunk_output,
                                              chunk_src_paddings,
                                              target=chunk_target,
                                              lang_emb=lang_emb)
            chunk_scores = chunk_log_probs.gather('unit', chunk_target)
            chunk_scores = (chunk_scores * chunk_tgt_paddings).sum('pos')
            with NoName(chunk_scores):
                scores.append(
                    chunk_scores.view(batch_size, bs_split).refine_names(
                        'batch', 'tgt_vocab'))
        scores = torch.cat(scores, dim='tgt_vocab')
        return scores
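# Usage sketch (hypothetical name `model`): score every target-vocab candidate
# for a batch of source sequences.
#   scores = model.get_scores(batch, tgt_vocab_seqs, chunk_size=100)
# `scores` is batch x tgt_vocab; chunking caps peak memory because the decoder
# only processes batch_size * chunk_size sequences at a time.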
Example #15
 def trace_back(self, *attr_names: str) -> Dict[str, torch.Tensor]:
     """Trace back some attribute by going backwards through the beam search procedure."""
     beam_i = get_named_range(self.beam_size,
                              'beam').expand_as(self.beam_ids)
     batch_i = get_named_range(self.batch_size, 'batch').expand_as(beam_i)
     beam = self
     ret = defaultdict(list)
     while beam.last_beam is not None:
         with NoName(beam.beam_ids, beam_i, batch_i):
             for attr_name in attr_names:
                 attr = getattr(beam, attr_name)
                 with NoName(attr):
                     ret[attr_name].append(attr[batch_i, beam_i])
             beam_i = beam.beam_ids[batch_i, beam_i]
         beam = beam.last_beam
     for attr_name in attr_names:
         # NOTE(j_luo) Reverse the list since we are going backwards.
         last_name = 'src_pos' if attr_name == 'almt' else None
         ret[attr_name] = _stack_beam(ret[attr_name][::-1],
                                      last_name=last_name)
     return ret
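# Usage sketch (hypothetical attribute names): once search has finished,
#   traced = final_beam.trace_back('tokens', 'almt')
# follows the `beam_ids` backpointers one step at a time, so each collected
# list is reversed before `_stack_beam` (Example #11) stacks it into
# batch x beam x pos (x src_pos for 'almt').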
Example #16
    def forward(self,
                inp: FT,
                sparse: bool = False,
                indices: Optional[NDA] = None) -> FT:
        is_2d = inp.ndim == 2
        if g.use_conditional and not is_2d:
            raise RuntimeError('Unexpected non-2D input in conditional mode.')
        assert is_2d

        potentials = self.potential_block(inp)
        with NoName(potentials):
            potentials = potentials.view(-1, 7, len(self.env.abc))
            return potentials.rename('batch', 'phase', 'action')
Example #17
 def forward(self, input_: LT, lengths: LT) -> Tuple[FT, LstmOutputTuple]:
     emb = self.embedding(input_)
     with NoName(emb, lengths):
         packed_emb = pack_padded_sequence(emb,
                                           lengths,
                                           enforce_sorted=False)
         output, state = self.lstm(packed_emb)
         output = pad_packed_sequence(output)[0]
         output = self.drop(
             output
         )  # Dropout after last output, different from the behavior for nn.LSTM.
     return emb, (output,
                  LstmStateTuple(state,
                                 bidirectional=self.lstm.bidirectional))
Example #18
    def forward(self,
                curr_ids: LT,
                end_ids: LT,
                almts: Optional[Tuple[LT, LT]] = None):
        if g.repr_mode != 'state' and almts is None:
            raise RuntimeError(
                f'Must pass `almts` if `repr_mode` is not "state".')

        if g.repr_mode != 'state':
            curr_almts, end_almts = almts
            assert curr_almts.shape == curr_ids.shape
            assert end_almts.shape[1:] == end_ids.shape
            # NOTE(j_luo) +1 for 0-index, +1 for storing fake scattered values.
            max_len = max(curr_almts.max(), end_almts.max()) + 2
            new_shape = curr_almts.shape[:-1] + (max_len, )
            aligned_curr_ids = get_zeros(*new_shape).long().fill_(PAD_ID)
            aligned_end_ids = get_zeros(*new_shape).long().fill_(PAD_ID)

            with NoName(curr_almts, curr_ids, end_almts, end_ids):
                curr_mask = curr_almts == -1
                curr_almts[curr_mask] = max_len - 1

                end_mask = end_almts == -1
                end_almts[end_mask] = max_len - 1

                aligned_curr_ids.scatter_(-1, curr_almts, curr_ids)
                aligned_end_ids.scatter_(-1, end_almts,
                                         end_ids.expand_as(end_almts))

            aligned_curr_ids = aligned_curr_ids.narrow(
                -1, 0, max_len - 1).rename('batch', 'word', 'pos')
            aligned_end_ids = aligned_end_ids.narrow(
                -1, 0, max_len - 1).rename('batch', 'word', 'pos')
            curr_char_emb = self._get_char_embedding(aligned_curr_ids)
            end_char_emb = self._get_char_embedding(aligned_end_ids)
            if g.repr_mode == 'char':
                state_repr = self._get_word_embedding_from_chars(
                    curr_char_emb - end_char_emb).mean(dim='word')
            else:
                curr_word_emb = self._get_word_embedding_from_chars(
                    curr_char_emb)
                end_word_emb = self._get_word_embedding_from_chars(
                    end_char_emb)
                state_repr = (curr_word_emb - end_word_emb).mean(dim='word')
        else:
            word_repr = self._get_word_embedding(curr_ids)
            end_word_repr = self._get_word_embedding(end_ids)
            state_repr = (word_repr - end_word_repr).mean(dim='word')
        return state_repr
Example #19
 def split(self, size: int) -> List[PaddedUnitSeqs]:
     with NoName(self.ids, self.paddings):
         ids_lst = self.ids.split(size, dim=-1)
         paddings_lst = self.paddings.split(size, dim=-1)
     start = 0
     ret = list()
     for ids, paddings in zip(ids_lst, paddings_lst):
         length = ids.size(1)
         units = self.units[start: start + length]
         forms = self.forms[start: start + length]
         split = PaddedUnitSeqs(self.lang, forms, units, ids, paddings,
                                lang_id=self.lang_id)
         ret.append(split)
         start += length
     assert start == self.ids.size('batch')
     return ret
Example #20
 def search_by_probs(self, lengths: LT,
                     label_log_probs: FT) -> Tuple[LT, FT]:
     max_length = lengths.max().item()
     samples = get_tensor(
         torch.LongTensor(list(product([B, I, O], repeat=max_length))))
     samples.rename_('sample', 'length')
     bs = label_log_probs.size('batch')
     samples = samples.align_to('batch', 'sample',
                                'length').expand(bs, -1, -1)
     sample_log_probs = label_log_probs.gather('label', samples)
     with NoName(lengths):
         length_mask = get_length_mask(lengths, max_length).rename(
             'batch', 'length')
      length_mask = length_mask.align_as(sample_log_probs)
     sample_log_probs = (sample_log_probs *
                         length_mask.float()).sum(dim='length')
     return samples, sample_log_probs
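# Note: `samples` enumerates every {B, I, O} labeling of the longest segment,
# i.e. 3 ** max_length candidates, so this exhaustive search is only practical
# for short segments.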
Example #21
    def forward(self,
                input_: FT,
                state: LstmStatesByLayers,
                state_direction: Optional[str] = None) -> LstmOutputsByLayers:
        assert state.num_layers == self.num_layers

        new_states = list()
        for i in range(self.num_layers):
            h, c = state.get_layer(i, state_direction)
            with NoName(input_, h, c):
                new_h, new_c = self.cells[i](input_, (h, c))
            new_h.rename_(*h.names)
            new_c.rename_(*c.names)
            new_states.append((new_h, new_c))
            input_ = new_h.refine_names('batch', ...)
            # Note that the last layer also uses dropout, which is different from nn.LSTM.
            input_ = self.drop(input_)
        return input_, LstmStatesByLayers(new_states)
Example #22
 def forward(self, feat_matrix: LT, pos_to_predict: LT,
             source_padding: BT) -> FT:
     bs = source_padding.size('batch')
     l = source_padding.size('length')
     batch_i = get_range(bs, 1, 0)
     feat_emb = self.feat_embedding(feat_matrix,
                                    source_padding,
                                    masked_positions=pos_to_predict)
     feat_emb = feat_emb.align_to('batch', 'char_emb', 'length')
     output = self.conv_layers(feat_emb.rename(None))
     output = output.refine_names('batch', 'char_conv_repr',
                                  'length')  # size: bs x D x l
     output = self.linear(output.align_to(
         ..., 'char_conv_repr'))  # size: bs x l x n_hid
     output = output.refine_names('batch', 'length', 'hidden_repr')
     output = nn.functional.leaky_relu(output, negative_slope=0.1)
     # NOTE(j_luo) This is actually quite wasteful because we are discarding all the irrelevant information, which is computed anyway. This is equivalent to training on ngrams.
     with NoName(output, pos_to_predict):
         h = output[batch_i, pos_to_predict]
     h = h.refine_names('batch', 'hidden_repr')  # size: bs x n_hid
     return h
Example #23
    def forward(self,
                feat_matrix: LT,
                padding: Optional[BT] = None,
                masked_positions: Optional[LT] = None) -> FT:
        feat_matrix = adv_index(feat_matrix, 'feat_group', self.c_idx)
        # Convert old style to new style ipa features.
        if g.new_style:
            new_feat_matrix = list()
            for c_idx, one_feat_group in zip(
                    self.c_idx.unbind(dim=self.group_name),
                    feat_matrix.unbind(dim=self.group_name)):
                one_feat_group = one_feat_group.rename(None)
                new_enum = get_new_style_enum(c_idx.item())
                l = new_enum.num_groups()
                if l > 1:
                    new_feat_matrix.append(
                        self.complex_conversions[one_feat_group][..., :l])
                else:
                    new_feat_matrix.append(
                        self.simple_conversions[one_feat_group].unsqueeze(
                            dim=-1))
            new_feat_matrix = torch.cat(
                new_feat_matrix, dim=-1).refine_names(*feat_matrix.names)
            feat_matrix = new_feat_matrix
        feat_emb = embed(self.embed_layer, feat_matrix, self.feat_emb_name)
        feat_emb = feat_emb.flatten([self.group_name, self.feat_emb_name],
                                    self.char_emb_name)
        feat_emb = feat_emb.align_to('batch', 'length', self.char_emb_name)
        if padding is not None:
            padding = padding.align_to('batch', 'length')
            feat_emb.rename(None)[padding.rename(None)] = 0.0

        if masked_positions is not None:
            # `padding` may be None here, so take the batch size from `feat_emb`.
            batch_i = get_range(feat_emb.size('batch'), 1, 0)
            feat_emb = feat_emb.align_to('batch', 'char_emb', 'length')
            # feat_emb = self.feat_embeddings(feat_matrix).view(bs, l, -1).transpose(1, 2)  # size: bs x D x l
            with NoName(feat_emb, masked_positions):
                feat_emb[batch_i, :, masked_positions] = 0.0
        return feat_emb
Example #24
    def search(self,
               lengths: LT,
               label_log_probs: FT,
               gold_tag_seqs: Optional[LT] = None) -> Tuple[LT, FT]:
        samples, sample_log_probs = self.search_by_probs(
            lengths, label_log_probs)
        if gold_tag_seqs is not None:
            gold_tag_seqs = gold_tag_seqs.align_as(samples)

            max_length = lengths.max().item()
            with NoName(lengths):
                length_mask = get_length_mask(lengths, max_length).rename(
                    'batch', 'length')
            gold_log_probs = label_log_probs.gather('label', gold_tag_seqs)
            gold_log_probs = (
                gold_log_probs *
                length_mask.align_as(gold_log_probs)).sum('length')

            samples = torch.cat([gold_tag_seqs, samples], dim='sample')
            sample_log_probs = torch.cat([gold_log_probs, sample_log_probs],
                                         dim='sample')
        return samples, sample_log_probs
Example #25
 def forward(self, positions: LT):
     with NoName(self.embeddings, positions):
         ret = self.embeddings[positions]
     new_names = positions.names + ('char_emb', )
     return ret.refine_names(*new_names)
Example #26
    def forward(self, batch: ExtractBatch) -> ExtractModelReturn:
        """
        The generating story is:
            v
            |
            w
            |
            x -- ww -- theta

        Pr(x) = sum_w Pr(w) Pr(ww)
              = sum_w Pr(w) theta^|ww|
              = sum_{w, v} Pr(w | v) Pr(v) theta^|ww|

        Terminology:
        matched_: the prefix after selecting v
        score: after multiplication with |w|
        best_: the prefix after selecting w
        """
        # Prepare representations.
        alignment = None
        char_log_probs = None  # Set only on the dense text-input path below.
        if g.dense_input:
            # IDEA(j_luo) NoName shouldn't use reveal_name. Just keep the name in the context manager.
            with NoName(*self.unit_dense_feat_matrix.values()):
                unit_repr = torch.cat([
                    self.unit_dense_feat_matrix[cat]
                    for cat in self.effective_categories
                ],
                                      dim=-1)
            unit_repr = unit_repr.rename('batch', 'length',
                                         'char_emb').squeeze(dim='length')

            if g.input_format == 'text':
                ku_char_repr, word_repr = self.g2p(batch.unit_id_seqs,
                                                   unit_repr)
                char_log_probs = (ku_char_repr @ unit_repr.t()).log_softmax(
                    dim=-1)
                alignment = char_log_probs.exp()
            else:
                dfm = batch.dense_feat_matrix
                with Rename(*self.unit_dense_feat_matrix.values(),
                            unit='batch'):
                    adapted_dfm = self.adapter(dfm)
                with NoName(*adapted_dfm.values()):
                    word_repr = torch.cat([
                        adapted_dfm[cat] for cat in self.effective_categories
                    ],
                                          dim=-1)
                word_repr.rename_('batch', 'length', 'char_emb')
        else:
            with Rename(self.unit_feat_matrix, unit='batch'):
                word_repr = self.embedding(batch.feat_matrix,
                                           batch.source_padding)
                unit_repr = self.embedding(self.unit_feat_matrix)
            unit_repr = unit_repr.squeeze('length')
        unit_repr.rename_(batch='unit')

        # Main body: extract one span.
        extracted = Extracted(batch.batch_size)
        new_extracted = self._extract_one_span(batch, extracted, word_repr,
                                               unit_repr, char_log_probs)
        matches = new_extracted.matches
        len_e = matches.ll.size('len_e')
        vs = len(self.vocab)

        # Get the best score and span.
        # NOTE(j_luo) Some segments don't have any viable spans.
        flat_ll = matches.ll.flatten(['len_s', 'len_e', 'vocab'], 'cand')
        flat_viable = new_extracted.viable.expand_as(matches.ll).flatten(
            ['len_s', 'len_e', 'vocab'], 'cand')
        flat_viable_ll = (~flat_viable) * (-9999.9) + flat_ll
        # Add probs for unextracted characters.
        unextracted = batch.lengths.align_as(
            new_extracted.len_candidates) - new_extracted.len_candidates
        unextracted = unextracted.expand_as(matches.ll)
        flat_unextracted = unextracted.flatten(['len_s', 'len_e', 'vocab'],
                                               'cand')
        flat_unextracted_ll = flat_unextracted * math.log(g.unextracted_prob)
        flat_total_ll = flat_viable_ll + flat_unextracted_ll
        # Get the top candidates based on total scores.
        best_matched_ll, best_span_ind = flat_total_ll.max(dim='cand')
        start = best_span_ind // (len_e * vs)
        # NOTE(j_luo) Don't forget the length is off by g.min_word_length - 1.
        end = best_span_ind % (len_e *
                               vs) // vs + start + g.min_word_length - 1
        best_matched_vocab = best_span_ind % vs

        if self.training:
            any_viable = new_extracted.viable.any('len_s').any('len_e')
            best_matched_ll = flat_total_ll.logsumexp(dim='cand')
            best_matched_ll = best_matched_ll * any_viable

        ret = ExtractModelReturn(start, end, best_matched_ll,
                                 best_matched_vocab, new_extracted, alignment)

        return ret
Example #27
    def _extract_one_span(self, batch: ExtractBatch, extracted: Extracted,
                          word_repr: FT, unit_repr: FT,
                          char_log_probs: FT) -> Extracted:
        # Propose all span start/end positions.
        start_candidates = get_named_range(batch.max_length, 'len_s').align_to(
            'batch', 'len_s', 'len_e')
        # Range from `min_word_length` to `max_word_length`.
        len_candidates = get_named_range(
            g.max_word_length + 1 - g.min_word_length,
            'len_e') + g.min_word_length
        len_candidates = len_candidates.align_to('batch', 'len_s', 'len_e')
        # This is inclusive.
        end_candidates = start_candidates + len_candidates - 1

        # Only keep the viable/valid spans around.
        viable = (end_candidates < batch.lengths.align_as(end_candidates))
        start_candidates = start_candidates.expand_as(viable)
        len_candidates = len_candidates.expand_as(viable)
        # NOTE(j_luo) Use `viable` to get the lengths. `len_candidates` has dummy axes.
        # IDEA(j_luo) Any better way of handling this? Perhaps persistent names?
        len_s = viable.size('len_s')
        len_e = viable.size('len_e')
        bi = get_named_range(batch.batch_size, 'batch').expand_as(viable)
        with NoName(start_candidates, end_candidates, len_candidates, bi,
                    viable):
            viable_starts = start_candidates[viable].rename('viable')
            viable_lens = len_candidates[viable].rename('viable')
            viable_bi = bi[viable].rename('viable')

        # Get the word positions to get the corresponding representations.
        viable_starts = viable_starts.align_to('viable', 'len_w')
        word_pos_offsets = get_named_range(g.max_word_length,
                                           'len_w').align_as(viable_starts)
        word_pos = viable_starts + word_pos_offsets
        word_pos = word_pos.clamp(max=batch.max_length - 1)

        # Get the corresponding representations.
        nh = NameHelper()
        viable_bi = viable_bi.expand_as(word_pos)
        word_pos = nh.flatten(word_pos, ['viable', 'len_w'], 'viable_X_len_w')
        viable_bi = nh.flatten(viable_bi, ['viable', 'len_w'],
                               'viable_X_len_w')
        word_repr = word_repr.align_to('batch', 'length', 'char_emb')
        if g.input_format == 'text':
            with NoName(word_repr, viable_bi, word_pos, batch.unit_id_seqs):
                extracted_word_repr = word_repr[viable_bi, word_pos].rename(
                    'viable_X_len_w', 'char_emb')
                extracted_unit_ids = batch.unit_id_seqs[
                    viable_bi, word_pos].rename('viable_X_len_w')
        else:
            with NoName(word_repr, viable_bi, word_pos):
                extracted_word_repr = word_repr[viable_bi, word_pos].rename(
                    'viable_X_len_w', 'char_emb')
            extracted_unit_ids = None
        extracted_word_repr = nh.unflatten(extracted_word_repr,
                                           'viable_X_len_w',
                                           ['viable', 'len_w'])

        # Main body: Run DP to find the best matches.
        matches = self._get_matches(extracted_word_repr, unit_repr,
                                    viable_lens, extracted_unit_ids,
                                    char_log_probs)
        # Revert to the old shape (so that invalid spans are included).
        bi = get_named_range(batch.batch_size, 'batch').expand_as(viable)
        lsi = get_named_range(len_s, 'len_s').expand_as(viable)
        lei = get_named_range(len_e, 'len_e').expand_as(viable)
        vs = matches.ll.size('vocab')
        # IDEA(j_luo) NoName shouldn't make size() calls unavailable. Otherwise size() calls have to be moved outside the context. Also the names should be preserved as well.
        with NoName(bi, lsi, lei, viable, matches.ll):
            v_bi = bi[viable]
            v_lsi = lsi[viable]
            v_lei = lei[viable]
            all_ll = get_zeros(batch.batch_size, len_s, len_e, vs)
            all_ll = all_ll.float().fill_(-9999.9)
            all_ll[v_bi, v_lsi, v_lei] = matches.ll
            matches.ll = all_ll.rename('batch', 'len_s', 'len_e', 'vocab')

        new_extracted = Extracted(batch.batch_size, matches, viable,
                                  len_candidates)
        return new_extracted
Example #28
    def _get_matches(self, extracted_word_repr: FT, unit_repr: FT,
                     viable_lens: LT, extracted_unit_ids: LT,
                     char_log_probs: FT) -> Matches:
        ns = extracted_word_repr.size('viable')
        len_w = extracted_word_repr.size('len_w')
        nt = len(self.vocab_feat_matrix)
        msl = len_w  # Maximum source (span) length.
        mtl = self.vocab_feat_matrix.size('length')  # Maximum target (vocab) length.

        # Compute similarity scores all at once: for each viable span, compare it against all units.
        ctx_logits = extracted_word_repr @ unit_repr.t()
        ctx_log_probs = ctx_logits.log_softmax(dim='unit').flatten(
            ['viable', 'len_w'], 'viable_X_len_w')
        with NoName(char_log_probs, extracted_unit_ids):
            global_log_probs = char_log_probs[extracted_unit_ids].rename(
                'viable_X_len_w', 'unit')
        weighted_log_probs = g.context_weight * ctx_log_probs + (
            1.0 - g.context_weight) * global_log_probs
        costs = -weighted_log_probs

        # Name: viable x len_w x unit
        costs = costs.unflatten('viable_X_len_w', [('viable', ns),
                                                   ('len_w', len_w)])

        # NOTE(j_luo) Use dictionary to save every state.
        fs = dict()
        for i in range(msl + 1):
            fs[(i, 0)] = get_zeros(ns, nt).fill_(i * self.ins_del_cost)
        for j in range(mtl + 1):
            fs[(0, j)] = get_zeros(ns, nt).fill_(j * self.ins_del_cost)

        # ------------------------ Main body: DP ----------------------- #

        # Transition.
        with NoName(self.indexed_segments, costs):
            for ls in range(1, msl + 1):
                min_lt = max(ls - 2, 1)
                max_lt = min(ls + 2, mtl + 1)
                for lt in range(min_lt, max_lt):
                    transitions = list()
                    if (ls - 1, lt) in fs:
                        transitions.append(fs[(ls - 1, lt)] +
                                           self.ins_del_cost)
                    if (ls, lt - 1) in fs:
                        transitions.append(fs[(ls, lt - 1)] +
                                           self.ins_del_cost)
                    if (ls - 1, lt - 1) in fs:
                        vocab_inds = self.indexed_segments[:, lt - 1]
                        sub_cost = costs[:, ls - 1, vocab_inds]
                        transitions.append(fs[(ls - 1, lt - 1)] + sub_cost)
                    if transitions:
                        all_s = torch.stack(transitions, dim=-1)
                        new_s, _ = all_s.min(dim=-1)
                        fs[(ls, lt)] = new_s

        f_lst = list()
        for i in range(msl + 1):
            for j in range(mtl + 1):
                if (i, j) not in fs:
                    fs[(i, j)] = get_zeros(ns, nt).fill_(9999.9)
                f_lst.append(fs[(i, j)])
        f = torch.stack(f_lst, dim=0).view(msl + 1, mtl + 1, -1,
                                           len(self.vocab))
        f.rename_('len_w_src', 'len_w_tgt', 'viable', 'vocab')

        # Get the values wanted.
        with NoName(f, viable_lens, self.vocab_length):
            idx_src = viable_lens.unsqueeze(dim=-1)
            idx_tgt = self.vocab_length
            viable_i = get_range(ns, 2, 0)
            vocab_i = get_range(len(self.vocab_length), 2, 1)
            nll = f[idx_src, idx_tgt, viable_i, vocab_i]
            nll.rename_('viable', 'vocab')

        # Get the best spans.
        matches = Matches(-nll, f)
        return matches
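# Reference sketch (hypothetical scalar version of the banded DP above):
# f[i, j] is the minimum cost of aligning the first i span characters against
# the first j vocab characters, exploring the same band of j values around i
# as `min_lt`/`max_lt`.
def _edit_distance_sketch(sub_cost, msl: int, mtl: int, ins_del: float = 1.0):
    # Boundary cells: pure insertions or deletions.
    f = {(i, 0): i * ins_del for i in range(msl + 1)}
    f.update({(0, j): j * ins_del for j in range(mtl + 1)})
    for i in range(1, msl + 1):
        for j in range(max(i - 2, 1), min(i + 2, mtl + 1)):
            cands = []
            if (i - 1, j) in f:
                cands.append(f[i - 1, j] + ins_del)
            if (i, j - 1) in f:
                cands.append(f[i, j - 1] + ins_del)
            if (i - 1, j - 1) in f:
                cands.append(f[i - 1, j - 1] + sub_cost(i - 1, j - 1))
            if cands:
                f[i, j] = min(cands)
    return f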
Example #29
 def forward(self, input_: LT) -> FT:
     with NoName(self.char_embedding, input_):
         return self.char_embedding[input_]
Example #30
 def _get_Wh_s(self, h_s: FT) -> FT:
     sl, bs, ds = h_s.size()
     with NoName(h_s):
         Wh_s = h_s.reshape(sl * bs, -1).mm(self.Wa).view(sl, bs, -1)
     return Wh_s