Esempio n. 1
0
    def _get_model(self, dl=None):
        """Build the model — an RL agent or a one-pair model — and optionally
        load saved weights and move it to GPU.

        Args:
            dl: unused here; kept for interface compatibility.

        Returns:
            The constructed (and possibly restored) model.
        """
        # Phonological-feature inputs are only supplied when the flag is on.
        phono_feat_mat = None
        special_ids = None
        if g.use_phono_features:
            phono_feat_mat = get_tensor(self.src_abc.pfm)
            special_ids = get_tensor(self.src_abc.special_ids)

        phono_kwargs = dict(phono_feat_mat=phono_feat_mat,
                            special_ids=special_ids)
        if g.use_rl:
            if g.agent == 'vpg':
                agent_cls = VanillaPolicyGradient
            else:
                agent_cls = A2C
            model = agent_cls(len(self.tgt_abc), self.env, self.env.end,
                              **phono_kwargs)
        else:
            model = OnePairModel(len(self.src_abc), len(self.tgt_abc),
                                 **phono_kwargs)
        if g.saved_model_path is not None:
            # Load on CPU first; moved to GPU below if one is available.
            state_dict = torch.load(g.saved_model_path,
                                    map_location=torch.device('cpu'))
            model.load_state_dict(state_dict)
            logging.imp(f'Loaded from {g.saved_model_path}.')
        if has_gpus():
            model.cuda()
        logging.info(model)
        return model
Esempio n. 2
0
    def evaluate(self,
                 states,
                 steps: Optional[Union[int, LT]] = None) -> List[float]:
        """Expand and evaluate the leaf node.

        Terminal states get value 0.0; the remaining states are batched and
        scored by the agent's policy (and optionally value) networks, and
        leaf states are expanded via ``self.env.evaluate``.
        """
        values = [None] * len(states)
        outstanding_idx = list()
        outstanding_states = list()
        # Deal with end states first.
        for i, state in enumerate(states):
            if state.stopped or state.done:
                # NOTE(j_luo) This value is used for backup. If already reaching the end state, the final reward is either accounted for by the step reward, or by the value network. Therefore, we need to set it to 0.0 here.
                values[i] = 0.0
            else:
                outstanding_idx.append(i)
                outstanding_states.append(state)

        # Collect states that need evaluation.
        if outstanding_states:
            almts1 = almts2 = None
            if g.use_alignment:
                # With alignment enabled, the stacker also returns two alignment arrays.
                id_seqs, almts1, almts2 = parallel_stack_ids(
                    outstanding_states, g.num_workers, True,
                    self.env.max_end_length)
                almts1 = get_tensor(almts1).rename('batch', 'word', 'pos')
                almts2 = get_tensor(almts2).rename('batch', 'word', 'pos')
            else:
                id_seqs = parallel_stack_ids(outstanding_states, g.num_workers,
                                             False, self.env.max_end_length)
            id_seqs = get_tensor(id_seqs).rename('batch', 'word', 'pos')
            # Keep only the step counts of the states that still need evaluation.
            if steps is not None and not isinstance(steps, int):
                steps = steps[outstanding_idx]

            # TODO(j_luo) Scoped might be wrong here.
            # with ScopedCache('state_repr'):
            # NOTE(j_luo) Don't forget to call exp().
            priors = self.agent.get_policy(id_seqs,
                                           almts=(almts1, almts2)).exp()
            with NoName(priors):
                # Split the policy output: column 1 is the special action; the
                # remaining listed columns are the meta actions.
                meta_priors = priors[:, [0, 2, 3, 4, 5, 6]].cpu().numpy()
                special_priors = priors[:, 1].cpu().numpy()
            if g.use_value_guidance:
                agent_values = self.agent.get_values(
                    id_seqs, steps=steps).cpu().numpy()
            else:
                # Without value guidance, back up zeros instead of network values.
                agent_values = np.zeros([len(id_seqs)], dtype='float32')

            for i, state, mp, sp, v in zip(outstanding_idx, outstanding_states,
                                           meta_priors, special_priors,
                                           agent_values):
                # NOTE(j_luo) Values should be returned even if states are duplicates or have been visited.
                values[i] = v
                # NOTE(j_luo) Skip duplicate states (due to exploration collapse) or visited states (due to rollout truncation).
                if not state.is_leaf():
                    continue

                # print(mp[1, 111])
                self.env.evaluate(state, mp, sp)
        return values
Esempio n. 3
0
def collate_fn(batch):
    """Collate a list of samples into one batch: sort all fields by length,
    pad id sequences to a dense array, and trim to the longest sequence."""
    field_names = ('word', 'form', 'char_seq', 'id_seq')
    words, forms, char_seqs, id_seqs = (_get_item(name, batch)
                                        for name in field_names)
    lengths, words, forms, char_seqs, id_seqs = sort_all(words, forms, char_seqs, id_seqs)
    lengths = get_tensor(lengths, dtype='l')
    # Trim the id_seqs.
    longest = max(lengths).item()
    dense = pad_to_dense(id_seqs, dtype='l')
    id_seqs = get_tensor(dense[:, :longest])

    return Map(words=words,
               forms=forms,
               char_seqs=char_seqs,
               id_seqs=id_seqs,
               lengths=lengths,
               lang=batch[0].lang)
Esempio n. 4
0
    def analyze(self, log_probs, almt_distr, words, lost_lengths):
        """Score candidate words and compute an alignment-monotonicity loss.

        Args:
            log_probs: tensor of shape (tl, nc, bs) — per-position character
                log-probabilities; ``nc`` must equal the charset size.
            almt_distr: alignment distribution; its last dim is the source length.
            words: the words to restrict the effective weight matrix to.
            lost_lengths: lengths of the lost-language inputs.

        Returns:
            Map with ``valid_log_probs`` (bs x V word scores) and ``reg_loss``
            (scalar penalty on non-monotonic alignments).
        """
        self.clear_cache()
        self._sample(words)

        assert self._eff_max_length == len(log_probs)

        tl, nc, bs = log_probs.shape
        charset = get_charset(self.lang)
        assert nc == len(charset)

        # V x bs, or c_s x c_t -> bs x V
        valid_log_probs = self._eff_weight.matmul(log_probs.view(-1, bs)).t()

        sl = almt_distr.shape[-1]
        # Expected alignment position per target step; prepend -1 as a sentinel
        # "previous position" for the first step.
        pos = get_tensor(torch.arange(sl).float(), requires_grad=False)
        mean_pos = (pos * almt_distr).sum(dim=-1)  # bs x tl
        mean_pos = torch.cat([get_zeros(bs, 1, requires_grad=False).fill_(-1.0), mean_pos],
                             dim=-1)
        # Only weight steps that are still within the source word.
        reg_weight = lost_lengths.float().view(-1, 1) - 1.0 - mean_pos[:, :-1]
        reg_weight.clamp_(0.0, 1.0)
        # Penalize any deviation from advancing by exactly one position per step.
        rel_pos = mean_pos[:, 1:] - mean_pos[:, :-1]  # bs x tl
        rel_pos_diff = rel_pos - 1
        margin = rel_pos_diff != 0
        reg_loss = margin.float() * (rel_pos_diff ** 2)  # bs x tl
        reg_loss = (reg_loss * reg_weight).sum()

        return Map(reg_loss=reg_loss, valid_log_probs=valid_log_probs)
Esempio n. 5
0
 def __init__(self, lost_lang, known_lang, momentum, num_cognates):
     """Initialize the flow matrix between lost-language and known-language words.

     The flow starts at all zeros and is warmed up later.
     NOTE(review): ``momentum`` and ``num_cognates`` are unused in this body —
     presumably consumed elsewhere; verify against the rest of the class.
     """
     super().__init__()
     lost_words = get_words(lost_lang)
     known_words = get_words(known_lang)
     zero_flow = np.zeros([len(lost_words), len(known_words)])
     self.flow = MagicTensor(get_tensor(zero_flow), lost_words, known_words)
     self._warmed_up = False
Esempio n. 6
0
    def _get_batch_records(self, dl: OnePairDataLoader, batch: OnePairBatch,
                           K: int) -> List[Dict[str, Any]]:
        """Build per-example evaluation records for one batch.

        Each record holds the source form, gold target, beam predictions with
        their weights, the top-``K`` vocabulary entries by expected edit
        distance, the (raw and length-normalized) edit distance of the top
        prediction to the gold target, and the perplexity.
        """
        hyps = self.model.predict(batch)
        # NOTE(j_luo) EOT's have been removed from translations since they don't matter in edit distance computation.
        preds, pred_lengths, _ = hyps.translate(self.tgt_abc)
        # HACK(j_luo) Pretty ugly here.
        dists = compute_edit_dist(g.comp_mode,
                                  pred_ids=hyps.tokens,
                                  lengths=dl.tgt_seqs.lengths,
                                  gold_ids=dl.tgt_seqs.ids.t(),
                                  forms=dl.tgt_vocabulary.forms,
                                  units=dl.tgt_seqs.units,
                                  predictions=preds,
                                  pred_lengths=pred_lengths,
                                  pfm=self.tgt_abc.pfm)

        # Expected distance to each vocab entry, under the beam distribution.
        weights = get_beam_probs(hyps.scores)
        w_dists = weights.align_to(..., 'tgt_vocab') * dists
        expected_dists = w_dists.sum(dim='beam')
        # topk on the negated distances yields the K smallest distances.
        top_s, top_i = torch.topk(-expected_dists, K, dim='tgt_vocab')
        top_s = -top_s

        records = list()
        tgt_vocab = dl.tgt_vocabulary
        # In order to record the edit distance between the top prediction and the ground truth,
        # we need to find the index of the ground truth in the vocabulary, not in the dataset.
        tgt_ids = get_tensor(
            [tgt_vocab.get_id_by_form(form) for form in batch.tgt_seqs.forms])
        dists_tx = Tx(dists, ['batch', 'beam', 'tgt_vocab'])
        tgt_ids_tx = Tx(tgt_ids, ['batch'])
        # Beam 0 is the top prediction; select each example's gold-target column.
        top_dists_tx = dists_tx.select('beam', 0)
        top_dists = top_dists_tx.each_select({'tgt_vocab': tgt_ids_tx}).data
        # -1 in the denominator excludes the EOT from the length.
        normalized_top_dists = top_dists.float() / (batch.tgt_seqs.lengths - 1)
        top_dists = top_dists.cpu().numpy()
        normalized_top_dists = normalized_top_dists.cpu().numpy()
        # We also report the perplexity scores.
        log_probs, _ = self.model(batch)
        ppxs = get_ce_loss(log_probs, batch, agg='batch_mean')
        ppxs = ppxs.cpu().numpy()
        for pss, pis, src, gold, pbis, pbss, top_dist, n_top_dist, ppx in zip(
                top_s, top_i, batch.src_seqs.forms, batch.tgt_seqs.forms,
                preds, weights, top_dists, normalized_top_dists, ppxs):
            record = {
                'source': src,
                'gold_target': gold,
                'edit_dist': top_dist,
                'normalized_edit_dist': n_top_dist,
                'ppx': ppx
            }
            # Beam predictions, ranked from 1.
            for i, (pbs, pbi) in enumerate(zip(pbss, pbis), 1):
                record[f'pred_target_beam@{i}'] = pbi
                record[f'pred_target_beam@{i}_score'] = f'{pbs.item():.3f}'
            # Closest vocabulary entries by expected edit distance, ranked from 1.
            for i, (ps, pi) in enumerate(zip(pss, pis), 1):
                pred_closest = tgt_vocab[pi]['form']
                record[f'pred_target@{i}'] = pred_closest
                record[f'pred_target@{i}_score'] = f'{ps.item():.3f}'

            records.append(record)

        return records
Esempio n. 7
0
    def _prepare_weight(self):
        """Build a sparse word-by-(position, character) indicator matrix.

        Row ``w`` has a 1 at column ``len(charset) * i + char_id`` for each
        character position ``i`` of word ``w``. ``_word2rows`` remembers which
        COO entries belong to each word so ``_sample`` can slice them later.
        """
        rows = list()
        cols = list()

        words = get_words(self.lang)
        charset = get_charset(self.lang)

        self._word2rows = defaultdict(list)
        for row, word in enumerate(words):
            for i, c in enumerate(word.char_seq):
                cid = charset.char2id(c)
                # Record this entry's index in the COO arrays before appending.
                self._word2rows[word].append(len(rows))
                rows.append(row)
                cols.append(len(charset) * i + cid)
        data = np.ones(len(rows))
        # NOTE This is ugly, but it avoids this issue in 0.4.1: https://github.com/pytorch/pytorch/issues/8856.
        weight = torch.sparse.FloatTensor(
            get_tensor([rows, cols], dtype='l', use_cuda=False),
            get_tensor(data, dtype='f', use_cuda=False),
            (len(words), self._max_length * len(charset)))
        self._weight = get_tensor(weight)
Esempio n. 8
0
 def _sample(self, words):
     """Restrict the sparse weight matrix to the given subset of words.

     Rebuilds ``_eff_weight``, ``_eff_id2word``, ``_eff_word2id`` and
     ``_eff_max_length`` by slicing out of ``_weight`` only the COO entries
     that belong to ``words`` (looked up via ``_word2rows``).
     """
     all_words = get_words(self.lang)
     word_indices = list()
     # Maps a word's old row id in the full matrix to its new compact row id.
     old_to_new = np.zeros([len(all_words)], dtype='int64')
     self._eff_id2word = list()
     self._eff_word2id = dict()
     self._eff_max_length = max(map(len, words))
     for w in words:
         word_indices.extend(self._word2rows[w])
         old_to_new[w.idx] = len(self._eff_id2word)
         self._eff_id2word.append(w)
         self._eff_word2id[w] = len(self._eff_word2id)
     old_to_new = get_tensor(old_to_new)
     # NOTE(review): `indices` appears unused below — candidate for removal.
     indices = get_tensor(word_indices, dtype='l')
     # Slice the full matrix's COO indices/values and remap the rows.
     old_rows, cols = self._weight._indices()[:, word_indices].unbind(dim=0)
     rows = old_to_new[old_rows]
     data = self._weight._values()[word_indices]
     charset = get_charset(self.lang)
     weight = torch.sparse.FloatTensor(
         torch.stack([rows, cols], dim=0),
         data,
         (len(words), self._eff_max_length * len(charset)))
     self._eff_weight = get_tensor(weight)
Esempio n. 9
0
    def _train_one_step_mrt(self, batch: OnePairBatch,
                            abc: Alphabet) -> Metrics:
        """Train for one step using minimum risk training (MRT)."""
        # Get scores for top predictions and the target sequence.
        assert g.comp_mode == 'str'
        assert g.dropout == 0.0, 'Might have some issue due to the fact that ground truth is fed through the normal forward call whereas hyps are not, resulting in discrepant dropout applications.'
        hyps = self.model.predict(batch)
        log_probs, _ = self.model(batch)
        tgt_scores = -get_ce_loss(log_probs, batch, agg='batch')

        # Mark which ones are duplicates.
        preds, _, _ = hyps.translate(abc)
        duplicates = list()
        # Position 0 is the ground truth itself (prepended below), so it is
        # never flagged; any beam prediction equal to the gold form is.
        for beam_preds, tgt_form in zip(preds, batch.tgt_seqs.forms):
            duplicates.append([False] + [p == tgt_form for p in beam_preds])
            # If no duplicates are found, then we discard the last prediction.
            if not any(duplicates[-1]):
                duplicates[-1][-1] = True
            if sum(duplicates[-1]) != 1:
                raise RuntimeError(f'Should have exactly one duplicate.')
        duplicates = get_tensor(duplicates)

        # Assemble all scores together.
        scores = torch.cat([tgt_scores.align_as(hyps.scores), hyps.scores],
                           dim='beam')
        probs = get_beam_probs(scores, duplicates=duplicates)
        # Prepend the gold form as an extra beam entry for both preds and targets.
        target = np.tile(batch.tgt_seqs.forms.reshape(-1, 1),
                         [1, g.beam_size + 1])
        preds = np.concatenate([target[:, 0:1], preds], axis=-1)
        dists = edit_dist_batch(preds.reshape(-1), target.reshape(-1), 'ed')
        lengths = batch.tgt_seqs.lengths.align_to('batch', 'beam')
        dists = get_tensor(dists.reshape(-1,
                                         g.beam_size + 1)).float()  # / lengths
        # Risk is the expected edit distance under the (renormalized) beam distribution.
        # risk = (probs * (dists ** 2)).sum(dim='beam')
        risk = (probs * dists).sum(dim='beam')
        risk = Metric('risk', risk.sum(), len(batch))
        return Metrics(risk)
Esempio n. 10
0
 def search_by_probs(self, lengths: LT,
                     label_log_probs: FT) -> Tuple[LT, FT]:
     """Enumerate all B/I/O label sequences up to the max length and score them.

     Args:
         lengths: per-example sequence lengths (named dim 'batch').
         label_log_probs: per-position label log-probabilities.

     Returns:
         The enumerated samples ('batch', 'sample', 'length') and each
         sample's total log-probability, with positions beyond an example's
         length masked out of the sum.
     """
     max_length = lengths.max().item()
     # Cartesian product: 3 ** max_length candidate label sequences.
     samples = get_tensor(
         torch.LongTensor(list(product([B, I, O], repeat=max_length))))
     samples.rename_('sample', 'length')
     bs = label_log_probs.size('batch')
     samples = samples.align_to('batch', 'sample',
                                'length').expand(bs, -1, -1)
     sample_log_probs = label_log_probs.gather('label', samples)
     with NoName(lengths):
         length_mask = get_length_mask(lengths, max_length).rename(
             'batch', 'length')
     length_mask = length_mask.align_to(sample_log_probs)
     # Zero out positions past each sequence's length before summing.
     sample_log_probs = (sample_log_probs *
                         length_mask.float()).sum(dim='length')
     return samples, sample_log_probs
        type=str,
        help='Path to save the output. No suffix should be included.')
    args = parser.parse_args()

    if '.' in args.out_name:
        raise ValueError(f'No suffix should be included.')

    initiator = setup()
    initiator.run(saved_g_path=args.saved_g_path)
    _, _, abc, _ = OneToManyManager.prepare_raw_data()
    assert g.share_src_tgt_abc

    sd = torch.load(args.saved_model_path)

    emb_params = get_emb_params(len(abc),
                                phono_feat_mat=get_tensor(abc.pfm),
                                special_ids=get_tensor(abc.special_ids))
    emb = PhonoEmbedding.from_params(emb_params)
    prefix = 'encoder.embedding'
    emb.load_state_dict({
        'weight': sd[f'{prefix}.weight'],
        'special_weight': sd[f'{prefix}.special_weight'],
        'special_mask': sd[f'{prefix}.special_mask'],
        'pfm': sd[f'{prefix}.pfm']
    })
    emb.cuda()

    char_emb = emb.char_embedding.detach().cpu().numpy()
    size = char_emb.shape[-1]
    cols = [f'vec_{i}' for i in range(size)]
    df = pd.DataFrame(char_emb, columns=cols)
Esempio n. 12
0
    def collect_episodes(self,
                         init_state: VocabState,
                         tracker: Optional[Tracker] = None,
                         num_episodes: int = 0,
                         is_eval: bool = False,
                         no_simulation: bool = False) -> List[Trajectory]:
        """Roll out episodes from ``init_state`` and return their trajectories.

        Args:
            init_state: root state every episode starts from.
            tracker: optional progress tracker; updates 'mcts', 'rollout'
                and 'episode' counters.
            num_episodes: number of episodes; 0 falls back to ``g.num_episodes``.
            is_eval: evaluation mode — no exploration noise, deterministic play.
            no_simulation: in eval mode, act directly from the policy without
                running MCTS simulations.
        """
        trajectories = list()
        self.agent.eval()
        if is_eval:
            self.eval()
        else:
            self.train()
        num_episodes = num_episodes or g.num_episodes
        # if no_simulation:
        #     breakpoint()  # BREAKPOINT(j_luo)
        # Gradients are never needed while collecting episodes.
        with self.agent.policy_grad(False), self.agent.value_grad(False):
            for ei in range(num_episodes):
                root = init_state
                self.reset()
                # With a finite horizon, track step counts for the value net.
                steps = 0 if g.use_finite_horizon else None
                self.evaluate([root], steps=steps)

                # Episodes have max rollout length.
                played_path = None
                for ri in range(g.max_rollout_length):
                    if not is_eval:
                        self.add_noise(root)
                    if is_eval and no_simulation:
                        # Act greedily from the policy, skipping MCTS entirely.
                        new_state = self.select_one_pi_step(root)
                        steps = steps + 1 if g.use_finite_horizon else None
                        values = self.evaluate([new_state], steps=steps)
                    else:
                        # Run many simulations before take one action. Simulations take place in batches. Each batch
                        # would be evaluated and expanded after batched selection.
                        num_batches = g.num_mcts_sims // g.expansion_batch_size
                        for _ in range(num_batches):
                            paths, steps = self.select(root,
                                                       g.expansion_batch_size,
                                                       ri,
                                                       g.max_rollout_length,
                                                       played_path)
                            steps = get_tensor(
                                steps) if g.use_finite_horizon else None
                            new_states = [
                                path.get_last_node() for path in paths
                            ]
                            values = self.evaluate(new_states, steps=steps)
                            self.backup(paths, values)
                            if tracker is not None:
                                tracker.update('mcts',
                                               incr=g.expansion_batch_size)
                        # Periodically log root statistics for debugging.
                        if ri == 0 and ei % g.episode_check_interval == 0:
                            k = min(20, root.num_actions)
                            logging.debug(
                                pad_for_log(
                                    str(
                                        get_tensor(
                                            root.action_counts).topk(k))))
                            logging.debug(
                                pad_for_log(str(get_tensor(root.q).topk(k))))
                            logging.debug(
                                pad_for_log(
                                    str(get_tensor(root.max_values).topk(k))))
                    # Evaluation overrides the configured play strategy.
                    ps = self.play_strategy
                    if is_eval:
                        if no_simulation:
                            ps = PyPS_SAMPLE_AC
                        else:
                            ps = PyPS_MAX
                    new_path = self.play(root, ri, ps, g.exponent)
                    # Extend the played path and move the root forward.
                    if played_path is None:
                        played_path = new_path
                    else:
                        played_path.merge(new_path)
                    root = played_path.get_last_node()

                    # print('3')
                    if tracker is not None:
                        tracker.update('rollout')
                    if root.stopped or root.done:
                        break
                    # self.show_stats()
                trajectory = Trajectory(played_path, self.env.max_end_length)
                if ei % g.episode_check_interval == 0:
                    logging.debug(pad_for_log(str(trajectory)))

                trajectories.append(trajectory)
                if tracker is not None:
                    tracker.update('episode')
        # if no_simulation:
        #     breakpoint()  # BREAKPOINT(j_luo)

        return trajectories
Esempio n. 13
0
    def __init__(self):
        """Set up data loaders, model, evaluator and (unless evaluate-only)
        the trainer for the one-source-to-many-targets experiment."""
        all_tgt, self.cog_reg, self.src_abc, self.tgt_abc = self.prepare_raw_data(
        )

        # Get stats for unseen units.
        stats = self.tgt_abc.stats
        _, test_tgt_path = get_paths(g.data_path, g.src_lang, g.tgt_lang)
        # A unit is "unseen" when all of its occurrences come from the test split.
        mask = (stats.sum() == stats.loc[test_tgt_path])
        unseen = mask[mask].index.tolist()
        total = len(stats.loc[test_tgt_path].dropna())
        logging.info(
            f'Unseen units ({len(unseen)}/{total}) for {g.tgt_lang} are: {unseen}.'
        )

        # Get language-to-id mappings. Used only for the targets (i.e., decoder side).
        self.lang2id = lang2id = {tgt: i for i, tgt in enumerate(all_tgt)}

        # Get all data loaders.
        self.dl_reg = DataLoaderRegistry()

        # Small factory that fixes the source language and pair type.
        def create_setting(name: str,
                           tgt_lang: str,
                           split: Split,
                           for_training: bool,
                           keep_ratio: Optional[float] = None,
                           tgt_sot: bool = False) -> Setting:
            return Setting(name,
                           'one_pair',
                           split,
                           g.src_lang,
                           tgt_lang,
                           for_training,
                           keep_ratio=keep_ratio,
                           tgt_sot=tgt_sot)

        test_setting = create_setting(f'test@{g.tgt_lang}',
                                      g.tgt_lang,
                                      Split('all'),
                                      False,
                                      keep_ratio=g.test_keep_ratio)
        settings: List[Setting] = [test_setting]

        # Get the training languages.
        for train_tgt_lang in g.train_tgt_langs:
            if g.input_format == 'ielex':
                train_split = Split(
                    'train',
                    [1, 2, 3, 4])  # Use the first four folds for training.
                dev_split = Split('dev', [5])  # Use the last fold for dev.
            else:
                train_split = Split('train')
                dev_split = Split('dev')
            train_setting = create_setting(f'train@{train_tgt_lang}',
                                           train_tgt_lang,
                                           train_split,
                                           True,
                                           keep_ratio=g.keep_ratio)
            # "_e" is the evaluation-mode copy of the training split.
            train_e_setting = create_setting(f'train@{train_tgt_lang}_e',
                                             train_tgt_lang,
                                             train_split,
                                             False,
                                             keep_ratio=g.keep_ratio)
            dev_setting = create_setting(f'dev@{train_tgt_lang}',
                                         train_tgt_lang, dev_split, False)
            test_setting = create_setting(f'test@{train_tgt_lang}',
                                          train_tgt_lang, Split('test'), False)

            settings.extend(
                [train_setting, train_e_setting, dev_setting, test_setting])
        for setting in settings:
            self.dl_reg.register_data_loader(setting,
                                             self.cog_reg,
                                             lang2id=lang2id)

        # Phonological-feature inputs are only supplied when enabled.
        phono_feat_mat = special_ids = None
        if g.use_phono_features:
            phono_feat_mat = get_tensor(self.src_abc.pfm)
            special_ids = get_tensor(self.src_abc.special_ids)

        # +1 accounts for the held-out target language itself.
        self.model = OneToManyModel(len(self.src_abc),
                                    len(self.tgt_abc),
                                    len(g.train_tgt_langs) + 1,
                                    lang2id[g.tgt_lang],
                                    lang2id=lang2id,
                                    phono_feat_mat=phono_feat_mat,
                                    special_ids=special_ids)

        if g.saved_model_path is not None:
            # Load on CPU first; moved to GPU below if available.
            self.model.load_state_dict(
                torch.load(g.saved_model_path,
                           map_location=torch.device('cpu')))
            logging.imp(f'Loaded from {g.saved_model_path}.')
        if has_gpus():
            self.model.cuda()
        logging.info(self.model)

        metric_writer = MetricWriter(g.log_dir, flush_secs=5)

        # NOTE(j_luo) Evaluate on every loader that is not for training.
        eval_dls = self.dl_reg.get_loaders_by_name(
            lambda name: 'train' not in name or '_e' in name)
        self.evaluator = Evaluator(self.model,
                                   eval_dls,
                                   self.tgt_abc,
                                   metric_writer=metric_writer)

        if not g.evaluate_only:
            train_names = [
                f'train@{train_tgt_lang}'
                for train_tgt_lang in g.train_tgt_langs
            ]
            train_settings = [
                self.dl_reg.get_setting_by_name(name) for name in train_names
            ]
            # All training settings are weighted equally.
            self.trainer = Trainer(self.model,
                                   train_settings, [1.0] * len(train_settings),
                                   'step',
                                   stage_tnames=['step'],
                                   evaluator=self.evaluator,
                                   check_interval=g.check_interval,
                                   eval_interval=g.eval_interval,
                                   save_interval=g.save_interval,
                                   metric_writer=metric_writer)
            if g.saved_model_path is None:
                # self.trainer.init_params('uniform', -0.1, 0.1)
                self.trainer.init_params('xavier_uniform')
            optim_cls = Adam if g.optim_cls == 'adam' else SGD
            self.trainer.set_optimizer(optim_cls, lr=g.learning_rate)
Esempio n. 14
0
def compute_edit_dist(comp_mode: str,
                      pred_ids: Optional[LT] = None,
                      lengths: Optional[LT] = None,
                      gold_ids: Optional[LT] = None,
                      forms: Optional[np.ndarray] = None,
                      predictions=None,
                      pred_lengths: Optional[np.ndarray] = None,
                      units: Optional[np.ndarray] = None,
                      pfm=None) -> FT:
    """Compute edit distances between beam predictions and gold targets.

    ``comp_mode`` picks the implementation:
      - 'ids_gpu': batched DP over id tensors on GPU, optionally with a
        phonological-feature substitution penalty built from ``pfm``;
      - 'units' / 'ids' / 'str': pairwise edit distance via ``edit_dist_all``.

    Returns a tensor of shape (pred_batch, beam, num_targets).
    """
    if comp_mode == 'ids_gpu':
        # Prepare tensorx.
        pred_lengths = Tx(get_tensor(pred_lengths), ['pred_batch', 'beam'])
        pred_tokens = Tx(pred_ids, ['pred_batch', 'beam', 'l'])
        # NOTE(j_luo) -1 for removing EOT's.
        tgt_lengths = Tx(lengths, ['tgt_batch']) - 1
        tgt_tokens = Tx(gold_ids, ['tgt_batch', 'l'])

        # Align them to the same names.
        new_names = ['pred_batch', 'beam', 'tgt_batch']
        pred_lengths = pred_lengths.align_to(*new_names)
        pred_tokens = pred_tokens.align_to(*(new_names + ['l']))
        tgt_lengths = tgt_lengths.align_to(*new_names)
        tgt_tokens = tgt_tokens.align_to(*(new_names + ['l']))

        # Expand them to have the same size.
        pred_bs = pred_tokens.size('pred_batch')
        tgt_bs = tgt_tokens.size('tgt_batch')
        pred_lengths = pred_lengths.expand({'tgt_batch': tgt_bs})
        pred_tokens = pred_tokens.expand({'tgt_batch': tgt_bs})
        tgt_tokens = tgt_tokens.expand({
            'pred_batch': pred_bs,
            'beam': g.beam_size
        })
        tgt_lengths = tgt_lengths.expand({
            'pred_batch': pred_bs,
            'beam': g.beam_size
        })

        # Flatten names, preparing for DP.
        def flatten(tx):
            return tx.flatten(['pred_batch', 'beam', 'tgt_batch'], 'batch')

        pred_lengths = flatten(pred_lengths)
        pred_tokens = flatten(pred_tokens)
        tgt_lengths = flatten(tgt_lengths)
        tgt_tokens = flatten(tgt_tokens)

        penalty = None
        if g.use_phono_edit_dist:
            # Substitution penalty = number of differing phonological features,
            # scaled by g.phono_edit_dist_scale.
            x = pfm.rename('src_unit', 'phono_feat')
            y = pfm.rename('tgt_unit', 'phono_feat')
            names = ('src_unit', 'tgt_unit', 'phono_feat')
            diff = x.align_to(*names) - y.align_to(*names)
            penalty = (diff != 0).sum(
                'phono_feat').cuda().float() / g.phono_edit_dist_scale
        dp = EditDist(pred_tokens,
                      tgt_tokens,
                      pred_lengths,
                      tgt_lengths,
                      penalty=penalty)
        dp.run()
        dists = dp.get_results().data
        dists = dists.view(pred_bs, g.beam_size, tgt_bs)
    else:
        eval_all = lambda seqs_0, seqs_1: edit_dist_all(
            seqs_0, seqs_1, mode='ed')
        flat_preds = predictions.reshape(-1)
        if comp_mode == 'units':
            # NOTE(j_luo) Remove EOT's.
            flat_golds = [u[:-1] for u in units]
        elif comp_mode == 'ids':
            tgt_ids = gold_ids.cpu().numpy()
            # NOTE(j_luo) -1 for EOT's.
            tgt_lengths = lengths - 1
            flat_golds = [ids[:l] for ids, l in zip(tgt_ids, tgt_lengths)]
        elif comp_mode == 'str':
            flat_golds = forms
        dists = get_tensor(eval_all(flat_preds,
                                    flat_golds)).view(-1, g.beam_size,
                                                      len(flat_golds))
    return dists
Esempio n. 15
0
def compute_expected_edits(known_charset,
                           log_probs,
                           wordlist,
                           valid_log_probs,
                           num_samples=10,
                           alpha=1e1,
                           edit=False):
    """Compute (expected) edit-distance scores against every word in ``wordlist``.

    With ``num_samples > 0``, draws character-level samples and weighs the
    edit distances by a softmax over sample/candidate log-probs; otherwise
    takes the argmax token per example. With ``edit=False``, simply returns
    the negated valid log-prob chunks. Processed in chunks of 1000 words to
    bound memory.

    Args:
        known_charset: charset used to decode samples into tokens.
        log_probs: character log-probs; transposed below to (bs, tl, C).
        wordlist: candidate words (indexable numpy-like array).
        valid_log_probs: per-word log-prob scores (bs x V).
        num_samples: number of character-level samples; 0 means argmax mode.
        alpha: sharpening temperature applied before the softmax.
        edit: whether to compute expected edit distances at all.
    """
    logging.debug('Computing expected edits')
    log_probs = log_probs.transpose(0, 2).transpose(1, 2)  # size: bs x tl x C
    log_probs = torch.log_softmax(log_probs * alpha, dim=-1)
    probs = log_probs.exp()
    bs, tl, nc = probs.shape
    # get samples
    if num_samples > 0:
        samples = torch.multinomial(probs.reshape(bs * tl, nc),
                                    num_samples,
                                    replacement=True)
        samples = samples.view(bs, tl, num_samples)
        # get tokens
        tokens = known_charset.get_tokens(samples.transpose(
            1, 2))  # size: bs x num_samples
        # get probs
        sample_log_probs = log_probs[torch.arange(bs).long().view(-1, 1, 1),
                                     torch.arange(tl).long().view(1, -1, 1),
                                     samples]  # bs x tl x ns
        # +1 presumably accounts for an end-of-token symbol — TODO confirm.
        lengths = get_tensor(np.vectorize(len)(tokens) + 1,
                             dtype='f')  # bs x num_samples
        # Mask out positions beyond each sampled token's length.
        mask = get_tensor(torch.arange(tl)).float().view(1, -1, 1).expand(
            bs, tl, num_samples) < lengths.unsqueeze(dim=1)
        sample_log_probs = (mask.float() * sample_log_probs).sum(
            dim=1)  # bs x num_samples
    else:  # This means we are taking the argmax according to token-level probs, not character-level probs.
        # Take argmax
        _, idx = valid_log_probs.max(dim=-1)
        tokens = wordlist[idx.cpu().numpy()].reshape(bs, 1)
        num_samples = 1
        sample_log_probs = get_tensor(np.ones([bs, 1]))
    # use chunks to get all edits
    chunk_size = 1000
    num_chunks = len(wordlist) // chunk_size + (len(wordlist) % chunk_size > 0)
    expected_edits = list()
    for i in range(num_chunks):
        logging.debug('Computing chunk %d/%d' % (i + 1, num_chunks))
        start = i * chunk_size
        end = min(start + chunk_size, len(wordlist))

        valid_log_prob_chunk = valid_log_probs[:, start:end]
        if edit:
            # get dists
            dists = compute_dists(tokens,
                                  wordlist[start:end])  # bs x c_s x (1 + ns)
            dists = get_tensor(dists, 'f')
            # remove accidental hits
            duplicates = compute_duplicates(
                tokens, wordlist[start:end])  # bs x c_s x (1 + ns)
            duplicates = get_tensor(duplicates, 'f')
            edit_chunk = dists * duplicates
            # compute expected edits
            ex_sample_log_probs = sample_log_probs.view(
                bs, 1, num_samples).expand(-1, valid_log_prob_chunk.shape[-1],
                                           -1)
            all_sample_log_probs = torch.cat(
                [valid_log_prob_chunk.unsqueeze(dim=-1), ex_sample_log_probs],
                dim=-1)
            # make it less sharp
            all_sample_log_probs = all_sample_log_probs + (
                1.0 - duplicates) * (-999.)
            logits = all_sample_log_probs  # * alpha
            sm_log_probs = torch.log_softmax(
                logits, dim=-1)  # NOTE sm stands for softmax
            sm_probs = sm_log_probs.exp()
            expected_edits.append((edit_chunk * sm_probs).sum(dim=-1))
            # expected_edits.append(dists[..., 1:].sum(dim=-1))
        else:
            expected_edits.append(-valid_log_prob_chunk.tensor)
    return torch.cat(expected_edits, dim=1)
Esempio n. 16
0
def sample_gumbel(shape: Sequence[int], eps: float = 1e-20) -> FT:
    """Draw Gumbel(0, 1) noise of the given shape via the inverse-CDF trick."""
    uniform = get_tensor(torch.rand(shape))
    # -log(-log(U)) transforms Uniform(0, 1) into Gumbel(0, 1); eps guards log(0).
    inner = torch.log(uniform + eps)
    return -torch.log(-inner + eps)
Esempio n. 17
0
 def tensor(self) -> LT:
     """Return this state's vocabulary as a long tensor named ('word', 'pos')."""
     as_tensor = get_tensor(self.vocab_array)
     return as_tensor.rename('word', 'pos')