Esempio n. 1
0
def get_matching_stats(predictions: List[Segmentation],
                       ground_truths: List[Segmentation],
                       match_words: bool = False) -> Metrics:
    """Count span (and optionally word) matches between predicted and gold
    segmentations, returning the raw counts as `Metric`s so that
    precision/recall/F1 can be derived downstream (see `get_prf_scores`).

    Args:
        predictions: one predicted `Segmentation` per utterance.
        ground_truths: the corresponding gold `Segmentation`s, aligned
            with `predictions` by position.
        match_words: if True, additionally count word-level matches.

    Returns:
        A `Metrics` collection with `prf_exact_span_matches`,
        `prf_prefix_span_matches`, `prf_total_correct`, `prf_total_pred`
        and, when `match_words` is set, the two word-match counts.
    """
    exact_span_matches = 0
    prefix_span_matches = 0
    # Initialize the word counters unconditionally: the original only
    # defined them when match_words was True, which risks an
    # UnboundLocalError if the guard logic below ever drifts.
    exact_word_matches = 0
    prefix_word_matches = 0
    for pred, gt in zip(predictions, ground_truths):
        # O(|pred| * |gt|) comparison per utterance pair.
        for p in pred:
            for g in gt:
                if p.is_same_span(g):
                    # An exact span match also counts as a prefix match.
                    exact_span_matches += 1
                    prefix_span_matches += 1
                    if match_words and p.is_same_word(g):
                        exact_word_matches += 1
                        prefix_word_matches += 1
                elif p.is_prefix_span_of(g) or g.is_prefix_span_of(p):
                    prefix_span_matches += 1
                    if match_words and (p.is_prefix_word_of(g)
                                        or g.is_prefix_word_of(p)):
                        prefix_word_matches += 1

    total_correct = sum(map(len, ground_truths))
    total_pred = sum(map(len, predictions))
    # Plain string literals: the original used f-strings with no
    # placeholders. Counts are kept in their own variables instead of
    # rebinding the int counters to Metric objects.
    metrics = Metrics(
        Metric('prf_exact_span_matches',
               exact_span_matches,
               1.0,
               report_mean=False),
        Metric('prf_prefix_span_matches',
               prefix_span_matches,
               1.0,
               report_mean=False),
        Metric('prf_total_correct', total_correct, 1.0, report_mean=False),
        Metric('prf_total_pred', total_pred, 1.0, report_mean=False))

    if match_words:
        metrics += Metrics(
            Metric('prf_exact_word_matches',
                   exact_word_matches,
                   1.0,
                   report_mean=False),
            Metric('prf_prefix_word_matches',
                   prefix_word_matches,
                   1.0,
                   report_mean=False))
    return metrics
Esempio n. 2
0
 def train_one_step(self, dl: IpaDataLoader) -> Metrics:
     """Run one optimization step: score the next batch, backprop the
     mean loss, and record the gradient norm.

     Args:
         dl: data loader providing the next training batch.

     Returns:
         The analyzer's metrics for this batch, plus a batch-weighted
         `grad_norm` metric.
     """
     self.model.train()
     self.optimizer.zero_grad()
     batch = dl.get_next_batch()
     ret = self.model.score(batch)
     metrics = self.analyzer.analyze(ret)
     metrics.loss.mean.backward()
     # Measure the gradient norm after backward but before the optimizer
     # consumes the gradients; weight it by batch size for averaging.
     grad_norm = get_grad_norm(self.model)
     grad_norm = Metric('grad_norm', grad_norm * len(batch), len(batch))
     metrics += grad_norm
     self.optimizer.step()
     return metrics
Esempio n. 3
0
def get_prf_scores(metrics: Metrics) -> Metrics:
    """Turn raw match counts into precision/recall/F1 metrics for both
    exact and prefix span matches.

    Args:
        metrics: must carry the `prf_exact_span_matches`,
            `prf_prefix_span_matches`, `prf_total_pred` and
            `prf_total_correct` metrics (as produced by
            `get_matching_stats`).

    Returns:
        A new `Metrics` collection with six `prf_*_span_{precision,recall,f1}`
        metrics.
    """
    prf_scores = Metrics()

    def _get_f1(p, r):
        # Harmonic mean of precision and recall; the epsilon guards
        # against division by zero when both are 0.
        return 2 * p * r / (p + r + 1e-8)

    # Attribute names are static, so plain attribute access replaces the
    # original getattr(..., f'...') calls.
    total_pred = metrics.prf_total_pred.total
    total_correct = metrics.prf_total_correct.total
    match_counts = {
        'exact': metrics.prf_exact_span_matches.total,
        'prefix': metrics.prf_prefix_span_matches.total,
    }
    for kind, matches in match_counts.items():
        precision = matches / (total_pred + 1e-8)
        recall = matches / (total_correct + 1e-8)
        f1 = _get_f1(precision, recall)
        # The original omitted the 1.0 weight for the two precision
        # metrics only; it is passed everywhere here for consistency with
        # the sibling Metric calls.
        prf_scores += Metric(f'prf_{kind}_span_precision',
                             precision,
                             1.0,
                             report_mean=False)
        prf_scores += Metric(f'prf_{kind}_span_recall',
                             recall,
                             1.0,
                             report_mean=False)
        prf_scores += Metric(f'prf_{kind}_span_f1',
                             f1,
                             1.0,
                             report_mean=False)
    return prf_scores
Esempio n. 4
0
    def analyze(self, ret: AdaptLMReturn) -> Metrics:
        """Analyze an `AdaptLMReturn`, producing loss metrics.

        Without mixture-of-experts (`g.use_moe` unset) this simply
        delegates to the parent analyzer on `ret.distr`. With MoE enabled,
        it instead builds a gated loss from the per-name scores plus a
        penalty tying the total gate probability mass to `g.prior_value`.

        Args:
            ret: model output carrying `distr`, `distr_noise` and
                `gate_logits`.

        Returns:
            A `Metrics` collection; in the MoE branch it contains
            `diff_loss` and `loss`.
        """
        metrics, scores = super().analyze(ret.distr, return_scores=True)
        # if g.use_moe:
        #     # prior = get_tensor([g.prior_value, 1.0 - g.prior_value]).squeeze(dim=0)
        #     # lp = ret.gate_log_probs
        #     # kld = lp.exp() * (lp - prior.log())

        #     # kld.
        #     lp = ret.gate_log_probs
        #     _p = lp.exp().sum(0) / lp.exp().sum()
        #     prior = get_tensor([g.prior_value, 1.0 - g.prior_value]).squeeze(dim=0)
        #     kld = _p * (_p.log() - prior.log())

        #     bs = lp.size('batch')
        #     kld = Metric('kld', kld.sum() * bs, bs)
        #     metrics += kld

        #     # sparsity.
        #     _p = lp.exp()
        #     with NoName(_p):
        #         # sparsity = torch.nn.functional.softmin(torch.stack([_p, 1.0 - _p], dim=-1), dim=-1)
        #         sparsity = torch.min(_p, 1.0 - _p)
        #     sparsity = Metric('sparsity', sparsity.sum(), bs)
        #     metrics += sparsity

        #     metrics.rename('loss', 'ce_loss')
        #     metrics += Metric('loss', metrics.ce_loss.total, bs)
        #     # metrics += Metric('loss', metrics.ce_loss.total + kld.total, bs)
        #     # metrics += Metric('loss', metrics.ce_loss.total + kld.total + sparsity.total, bs)

        if g.use_moe:
            # Discard the parent's metrics and rebuild from scratch.
            metrics = Metrics()
            # NOTE(review): metrics_noise/scores_noise are computed but
            # unused in the active code path below — confirm the call is
            # still needed (earlier, commented-out variants used them).
            metrics_noise, scores_noise = super().analyze(ret.distr_noise,
                                                          return_scores=True)
            # NOTE(review): these four accumulators are only used by the
            # commented-out per-name loop further down.
            total_loss = 0.0
            total_weight = 0.0
            cnt = 0
            prob_cnt = 0

            # gate_log_probs = ret.gate_logits.log_softmax(dim=-1)

            # `scores` maps name -> (score, weight); all entries are
            # assumed to share the same weight (only the first is used).
            all_scores = [s for _, (s, _) in scores.items()]
            all_weights = [w for _, (_, w) in scores.items()]
            weight = all_weights[0]

            sum_scores = torch.stack(all_scores,
                                     new_name='stacked').sum(dim='stacked')
            # Probability of expert 0 per batch item, masked by `weight`.
            batch_probs = ret.gate_logits.log_softmax(
                dim=-1).exp()[:, 0] * weight  # + (-999.9) * (1.0 - weight))
            # batch_probs = (ret.gate_logits[:, 0] + (-999.9) * (1.0 - weight)).log_softmax(dim='batch').exp()
            bs = batch_probs.size('batch')
            # Penalize deviation of total gate mass from the prior target.
            total = int(g.prior_value * weight.sum())
            diff_loss = ((batch_probs.sum() - total)**2).sum()
            diff_loss = Metric('diff_loss', diff_loss, bs)
            # Gated score loss plus the prior penalty.
            loss = (sum_scores * batch_probs).sum()
            loss = Metric('loss', loss + diff_loss.total, bs)

            metrics += diff_loss
            metrics += loss

            # for name in scores:
            #     s, w = scores[name]
            #     sn, _ = scores_noise[name]
            #     all_score = torch.stack([s, sn], new_name='expert')
            #     probs = gate_log_probs.exp()
            #     loss = ((all_score * probs) * w.align_as(all_score)).sum()
            #     cnt += ((all_score[:, 0] < all_score[:, 1]) * w).sum()
            #     prob_cnt += ((probs[:, 0] > probs[:, 1]) * w).sum()
            #     weight = w.sum()
            #     total_loss += loss
            #     total_weight += weight
            #     loss = Metric(f'loss_{name.snake}', loss, weight)
            #     metrics += loss

            # # kld.
            # lp = gate_log_probs
            # _p = lp.exp().sum(0) / lp.exp().sum()
            # prior = get_tensor([g.prior_value, 1.0 - g.prior_value]).squeeze(dim=0)
            # kld = _p * (_p.log() - prior.log())

            # bs = lp.size('batch')
            # kld = Metric('kld', kld.sum() * bs, bs)
            # metrics += kld

            # metrics += Metric('loss', total_loss, total_weight)
            # metrics += Metric('loss', total_loss + kld.total, total_weight)

            # print('cnt', cnt / total_weight)
            # print('prob', prob_cnt / total_weight)
            return metrics
        else:
            return metrics
Esempio n. 5
0
    def _evaluate_one_dl(self, stage: str, dl: OnePairDataLoader) -> Metrics:
        """Evaluate the model on one data loader and dump per-source
        predictions to a TSV file under `g.log_dir / 'predictions'`.

        Args:
            stage: tag used in the output file name.
            dl: data loader to evaluate on.

        Returns:
            A `Metrics` collection with `precision@1` / `precision@K`
            and, in edit-distance mode, edit-distance/ppx aggregates.
        """
        records = list()
        K = 5  # Number of top candidates kept per source token.
        for batch in pbar(dl, desc='eval: batch'):
            if g.eval_mode == 'edit_dist':
                batch_records = self._get_batch_records(dl, batch, K)
                records.extend(batch_records)
            else:
                scores = self.model.get_scores(batch, dl.tgt_seqs)
                # Use K rather than the original hard-coded 5 so the
                # pred_target@i columns always match the candidate count.
                top_scores, top_preds = torch.topk(scores, K, dim='tgt_vocab')
                for pss, pis, gi in zip(top_scores, top_preds, batch.indices):
                    gold = dl.get_token_from_index(gi, 'tgt')
                    src = dl.get_token_from_index(gi, 'src')
                    record = {'source': src, 'gold_target': gold}
                    for i, (ps, pi) in enumerate(zip(pss, pis), 1):
                        pred = dl.get_token_from_index(pi, 'tgt')
                        record[f'pred_target@{i}'] = pred
                        record[f'pred_target@{i}_score'] = f'{ps:.3f}'
                    records.append(record)
        out_df = pd.DataFrame.from_records(records)
        # Multiple records can share a source: gold targets get
        # '|'-concatenated while prediction columns keep the last value.
        values = ['gold_target']
        values.extend([f'pred_target@{i}' for i in range(1, K + 1)])
        values.extend([f'pred_target@{i}_score' for i in range(1, K + 1)])
        aggfunc = {'gold_target': '|'.join}
        aggfunc.update({f'pred_target@{i}': 'last' for i in range(1, K + 1)})
        aggfunc.update(
            {f'pred_target@{i}_score': 'last'
             for i in range(1, K + 1)})
        if g.eval_mode == 'edit_dist':
            # Beam columns and distance/ppx columns only exist in
            # edit-distance mode (they come from _get_batch_records).
            values.extend(
                [f'pred_target_beam@{i}' for i in range(1, g.beam_size + 1)])
            values.extend([
                f'pred_target_beam@{i}_score'
                for i in range(1, g.beam_size + 1)
            ])
            values.extend(['edit_dist', 'normalized_edit_dist', 'ppx'])
            aggfunc.update({
                f'pred_target_beam@{i}': 'last'
                for i in range(1, g.beam_size + 1)
            })
            aggfunc.update({
                f'pred_target_beam@{i}_score': 'last'
                for i in range(1, g.beam_size + 1)
            })
            aggfunc.update({
                'edit_dist': min,
                'normalized_edit_dist': min,
                'ppx': min
            })
        out_df = out_df.pivot_table(index='source',
                                    values=values,
                                    aggfunc=aggfunc)

        def is_correct(item):
            # Both cells may hold several '|'-joined tokens; a hit is any
            # overlap between the prediction set and the gold set.
            pred, gold = item
            golds = gold.split('|')
            preds = pred.split('|')
            return bool(set(golds) & set(preds))

        for i in range(1, K + 1):
            correct = out_df[[f'pred_target@{i}',
                              'gold_target']].apply(is_correct, axis=1)
            if i > 1:
                # correct@i is cumulative: a hit at any rank <= i.
                correct = correct | out_df[f'correct@{i - 1}']
            out_df[f'correct@{i}'] = correct
        out_folder = g.log_dir / 'predictions'
        out_folder.mkdir(exist_ok=True)
        setting = dl.setting
        out_path = str(out_folder / f'{setting.name}.{stage}.tsv')
        out_df.to_csv(out_path, sep='\t')
        logging.info(f'Predictions saved to {out_path}.')

        num_pred = len(out_df)
        metrics = Metrics()
        for i in [1, K]:
            num_correct = out_df[f'correct@{i}'].sum()
            correct = Metric(f'precision@{i}', num_correct, weight=num_pred)
            metrics += correct
        if g.eval_mode == 'edit_dist':
            # The original read these columns unconditionally, which
            # raises a KeyError when eval_mode != 'edit_dist' since the
            # columns are never created in that mode.
            metrics += Metric('edit_dist',
                              out_df['edit_dist'].sum(),
                              weight=num_pred)
            metrics += Metric('normalized_edit_dist',
                              out_df['normalized_edit_dist'].sum(),
                              weight=num_pred)
            metrics += Metric('ppx', out_df['ppx'].sum(), weight=num_pred)
        return metrics
Esempio n. 6
0
    def train_one_step(self, dl: OnePairDataLoader):
        """Run one outer training step: collect fresh MCTS episodes, add
        them to the replay buffer, then take `g.num_inner_steps` gradient
        steps on the agent's policy from sampled replay edges.

        Args:
            dl: data loader for this language pair (only referenced by the
                commented-out episode-collection variant below).

        Returns:
            A `Metrics` collection with episode statistics (reward,
            trajectory length, success) and per-inner-step losses and
            gradient norms.
        """
        if g.improved_player_only:
            # Snapshot current weights before training — presumably so the
            # pre-update agent can be restored or compared by the caller;
            # TODO confirm.
            self._old_state = self.agent.state_dict()
        # Collect episodes with the latest agent first.
        new_tr = self.mcts.collect_episodes(self.mcts.env.start, self.tracker)
        # new_tr = self.mcts.collect_episodes(dl.init_state, dl.end_state, self.tracker)
        tr_rew = Metric('reward', sum(tr.rewards.sum() for tr in new_tr),
                        g.num_episodes)
        tr_len = Metric('trajectory_length', sum(map(len, new_tr)),
                        g.num_episodes)
        success = Metric('success', sum(tr.done for tr in new_tr),
                         g.num_episodes)
        metrics = Metrics(tr_rew, tr_len, success)

        # Add these new episodes to the replay buffer.
        for i, tr in enumerate(new_tr, 1):
            # Global step indexes episodes across all outer steps, for
            # tensorboard-style logging.
            global_step = i + self.tracker['step'].value * g.num_episodes
            self.metric_writer.add_scalar('episode_reward',
                                          tr.rewards.sum(),
                                          global_step=global_step)
            self.metric_writer.add_text('trajectory',
                                        str(tr),
                                        global_step=global_step)
            # NOTE(j_luo) Use temperature if it's positive.
            if g.tau > 0.0:
                weight = math.exp(tr.total_reward * 10.0)
            else:
                weight = 1.0

            # Every edge of the trajectory enters the buffer with the
            # trajectory-level weight.
            for tr_edge in tr:
                self.replay_buffer.append(tr_edge, weight)

        # Main loop.
        # NOTE(review): a fresh optimizer is constructed on every outer
        # step, which resets Adam's moment estimates each time — confirm
        # this is intentional.
        from torch.optim import SGD, Adam
        optim_cls = Adam if g.optim_cls == 'adam' else SGD
        optim_kwargs = dict()
        if optim_cls == SGD:
            optim_kwargs['momentum'] = 0.9
        self.set_optimizer(optim_cls,
                           lr=g.learning_rate,
                           weight_decay=g.weight_decay,
                           **optim_kwargs)
        with self.agent.policy_grad(True), self.agent.value_grad(True):
            for _ in range(g.num_inner_steps):
                # Get a batch of training trajectories from the replay buffer.
                edge_batch = self.replay_buffer.sample(g.mcts_batch_size)
                # edge_batch = np.random.choice(self.replay_buffer, size=g.mcts_batch_size)
                agent_inputs = AgentInputs.from_edges(
                    edge_batch)  # , self.mcts.env)#, sparse=True)

                self.agent.train()
                self.optimizer.zero_grad()

                policies = self.agent.get_policy(agent_inputs.id_seqs,
                                                 almts=(agent_inputs.almts1,
                                                        agent_inputs.almts2))
                # print(policies[:, 2, self.mcts.env.abc['ẽ']].exp().mean())
                # values = self.agent.get_values(agent_inputs.id_seqs, steps=agent_inputs.steps)
                # breakpoint()  # BREAKPOINT(j_luo)
                # Replace SENTINEL_ID action slots with index 0 so gather
                # is valid, then mask their logits with a large negative
                # value before renormalizing over permissible actions.
                with NoName(policies, agent_inputs.permissible_actions):
                    mask = agent_inputs.permissible_actions == SENTINEL_ID
                    pa = agent_inputs.permissible_actions
                    pa = torch.where(mask, torch.zeros_like(pa), pa)
                    logits = policies.gather(2, pa)
                    logits = torch.where(mask,
                                         torch.full_like(logits,
                                                         -9999.9), logits)
                    logits = logits.log_softmax(dim=-1)
                # r_max = agent_inputs.rewards.max()
                # r_min = agent_inputs.rewards.min()
                # weights = (agent_inputs.rewards - r_min) / (r_max - r_min + 1e-8)

                # weights = weights.align_as(pi_ce_losses)
                # Cross-entropy between the MCTS visit distribution and the
                # agent policy, with the MCTS distribution's own entropy
                # subtracted (i.e. a KL-style objective).
                entropies = (-agent_inputs.mcts_pis *
                             (1e-8 + agent_inputs.mcts_pis).log()).sum(dim=-1)
                pi_ce_losses = (-agent_inputs.mcts_pis *
                                logits).sum(dim=-1) - entropies
                # NOTE(review): 7 appears to be a fixed per-edge component
                # count (it also appears in the loss weight below as
                # g.mcts_batch_size * 7) — confirm and consider naming it.
                for i in range(7):
                    metrics += Metric(f'entropy_{i}', entropies[:, i].sum(),
                                      g.mcts_batch_size)
                    metrics += Metric(f'pi_ce_los_{i}', pi_ce_losses[:,
                                                                     i].sum(),
                                      g.mcts_batch_size)

                # v_regress_losses = 0.5 * (values - agent_inputs.qs) ** 2

                # pi_ce_loss = Metric('pi_ce_loss', (weights * pi_ce_losses).sum(), g.mcts_batch_size * 7)
                # mini_weights = get_tensor([1.0, 0.1, 1.0, 0.1, 0.1, 0.1, 0.1]).rename('mini').align_as(pi_ce_losses)
                # pi_ce_loss = Metric('pi_ce_loss', (mini_weights * pi_ce_losses).sum(), g.mcts_batch_size)
                pi_ce_loss = Metric('pi_ce_loss', pi_ce_losses.sum(),
                                    g.mcts_batch_size * 7)
                # pi_ce_loss = Metric('pi_ce_loss', pi_ce_losses[:, 0].sum(), g.mcts_batch_size)
                # v_regress_loss = Metric('v_regress_loss', v_regress_losses.sum(), g.mcts_batch_size)
                total_loss = pi_ce_loss.total  # + g.regress_lambda * v_regress_loss.total
                total_loss = Metric('total_loss', total_loss,
                                    g.mcts_batch_size)

                total_loss.mean.backward()

                # Clip gradient norm.
                grad_norm = clip_grad(self.agent.parameters(),
                                      g.mcts_batch_size)
                # metrics += Metrics(total_loss, pi_ce_loss, v_regress_loss, grad_norm)
                metrics += Metrics(total_loss, pi_ce_loss, grad_norm)
                self.optimizer.step()
                self.tracker.update('inner_step')

        return metrics