def get_matching_stats(predictions: List[Segmentation], ground_truths: List[Segmentation], match_words: bool = False) -> Metrics:
    """Count span (and optionally word) matches between predicted and gold segmentations.

    Compares every predicted segment against every gold segment for each
    prediction/ground-truth pair, and returns raw `prf_*` counts (as `Metrics`)
    that `get_prf_scores` later turns into precision/recall/F1.

    Args:
        predictions: predicted segmentations, aligned index-wise with `ground_truths`.
        ground_truths: gold segmentations.
        match_words: if True, additionally count exact/prefix *word* matches.

    Returns:
        Metrics holding `prf_exact_span_matches`, `prf_prefix_span_matches`,
        `prf_total_correct`, `prf_total_pred`, plus the word-match counts when
        `match_words` is set.
    """
    exact_span_matches = 0
    prefix_span_matches = 0
    # Always initialize so the counters exist regardless of `match_words`
    # (the original only bound them inside `if match_words`).
    exact_word_matches = 0
    prefix_word_matches = 0
    for pred, gold in zip(predictions, ground_truths):
        for p_seg in pred:
            # NOTE: renamed from `g` to avoid shadowing the global config `g`
            # used throughout this file.
            for g_seg in gold:
                if p_seg.is_same_span(g_seg):
                    # An exact span match also counts as a prefix span match.
                    exact_span_matches += 1
                    prefix_span_matches += 1
                    if match_words and p_seg.is_same_word(g_seg):
                        exact_word_matches += 1
                        prefix_word_matches += 1
                elif p_seg.is_prefix_span_of(g_seg) or g_seg.is_prefix_span_of(p_seg):
                    prefix_span_matches += 1
                    if match_words and (p_seg.is_prefix_word_of(g_seg) or g_seg.is_prefix_word_of(p_seg)):
                        prefix_word_matches += 1
    total_correct = sum(map(len, ground_truths))
    total_pred = sum(map(len, predictions))
    # Plain string literals: the original used f-strings with no placeholders.
    metrics = Metrics(
        Metric('prf_exact_span_matches', exact_span_matches, 1.0, report_mean=False),
        Metric('prf_prefix_span_matches', prefix_span_matches, 1.0, report_mean=False),
        Metric('prf_total_correct', total_correct, 1.0, report_mean=False),
        Metric('prf_total_pred', total_pred, 1.0, report_mean=False))
    if match_words:
        metrics += Metrics(
            Metric('prf_exact_word_matches', exact_word_matches, 1.0, report_mean=False),
            Metric('prf_prefix_word_matches', prefix_word_matches, 1.0, report_mean=False))
    return metrics
def train_one_step(self, dl: IpaDataLoader) -> Metrics:
    """Run one optimization step on the next batch from `dl`.

    Scores a batch, backpropagates the analyzer's mean loss, records the
    gradient norm, and steps the optimizer.

    Args:
        dl: data loader providing the next training batch.

    Returns:
        The analyzer's metrics for this batch, plus a `grad_norm` metric.
    """
    self.model.train()
    self.optimizer.zero_grad()
    batch = dl.get_next_batch()
    ret = self.model.score(batch)
    # (Removed leftover commented-out debug/printing code.)
    metrics = self.analyzer.analyze(ret)
    metrics.loss.mean.backward()
    grad_norm = get_grad_norm(self.model)
    # Weight by batch size so the Metric's mean reports the per-step norm.
    metrics += Metric('grad_norm', grad_norm * len(batch), len(batch))
    self.optimizer.step()
    return metrics
def get_prf_scores(metrics: Metrics) -> Metrics:
    """Derive precision/recall/F1 for exact and prefix span matches.

    Reads the raw `prf_*` counts produced by `get_matching_stats` and returns
    six derived metrics: `prf_{exact,prefix}_span_{precision,recall,f1}`.

    Args:
        metrics: Metrics containing `prf_exact_span_matches`,
            `prf_prefix_span_matches`, `prf_total_pred` and `prf_total_correct`.

    Returns:
        A new `Metrics` with the six derived scores (weight 1.0, no mean).
    """
    total_pred = getattr(metrics, 'prf_total_pred').total
    total_correct = getattr(metrics, 'prf_total_correct').total

    def _get_f1(p, r):
        # Epsilon keeps the division defined when both p and r are zero.
        return 2 * p * r / (p + r + 1e-8)

    prf_scores = Metrics()
    for kind in ('exact', 'prefix'):
        matches = getattr(metrics, f'prf_{kind}_span_matches').total
        precision = matches / (total_pred + 1e-8)
        recall = matches / (total_correct + 1e-8)
        f1 = _get_f1(precision, recall)
        # NOTE(review): the original omitted the explicit 1.0 weight on the
        # precision metrics only; it is passed everywhere here for consistency
        # with the recall/F1 calls (assumes Metric's weight default is 1.0 —
        # confirm against Metric's signature).
        prf_scores += Metric(f'prf_{kind}_span_precision', precision, 1.0, report_mean=False)
        prf_scores += Metric(f'prf_{kind}_span_recall', recall, 1.0, report_mean=False)
        prf_scores += Metric(f'prf_{kind}_span_f1', f1, 1.0, report_mean=False)
    return prf_scores
def analyze(self, ret: AdaptLMReturn) -> Metrics: metrics, scores = super().analyze(ret.distr, return_scores=True) # if g.use_moe: # # prior = get_tensor([g.prior_value, 1.0 - g.prior_value]).squeeze(dim=0) # # lp = ret.gate_log_probs # # kld = lp.exp() * (lp - prior.log()) # # kld. # lp = ret.gate_log_probs # _p = lp.exp().sum(0) / lp.exp().sum() # prior = get_tensor([g.prior_value, 1.0 - g.prior_value]).squeeze(dim=0) # kld = _p * (_p.log() - prior.log()) # bs = lp.size('batch') # kld = Metric('kld', kld.sum() * bs, bs) # metrics += kld # # sparsity. # _p = lp.exp() # with NoName(_p): # # sparsity = torch.nn.functional.softmin(torch.stack([_p, 1.0 - _p], dim=-1), dim=-1) # sparsity = torch.min(_p, 1.0 - _p) # sparsity = Metric('sparsity', sparsity.sum(), bs) # metrics += sparsity # metrics.rename('loss', 'ce_loss') # metrics += Metric('loss', metrics.ce_loss.total, bs) # # metrics += Metric('loss', metrics.ce_loss.total + kld.total, bs) # # metrics += Metric('loss', metrics.ce_loss.total + kld.total + sparsity.total, bs) if g.use_moe: metrics = Metrics() metrics_noise, scores_noise = super().analyze(ret.distr_noise, return_scores=True) total_loss = 0.0 total_weight = 0.0 cnt = 0 prob_cnt = 0 # gate_log_probs = ret.gate_logits.log_softmax(dim=-1) all_scores = [s for _, (s, _) in scores.items()] all_weights = [w for _, (_, w) in scores.items()] weight = all_weights[0] sum_scores = torch.stack(all_scores, new_name='stacked').sum(dim='stacked') batch_probs = ret.gate_logits.log_softmax( dim=-1).exp()[:, 0] * weight # + (-999.9) * (1.0 - weight)) # batch_probs = (ret.gate_logits[:, 0] + (-999.9) * (1.0 - weight)).log_softmax(dim='batch').exp() bs = batch_probs.size('batch') total = int(g.prior_value * weight.sum()) diff_loss = ((batch_probs.sum() - total)**2).sum() diff_loss = Metric('diff_loss', diff_loss, bs) loss = (sum_scores * batch_probs).sum() loss = Metric('loss', loss + diff_loss.total, bs) metrics += diff_loss metrics += loss # for name in scores: # s, w = 
scores[name] # sn, _ = scores_noise[name] # all_score = torch.stack([s, sn], new_name='expert') # probs = gate_log_probs.exp() # loss = ((all_score * probs) * w.align_as(all_score)).sum() # cnt += ((all_score[:, 0] < all_score[:, 1]) * w).sum() # prob_cnt += ((probs[:, 0] > probs[:, 1]) * w).sum() # weight = w.sum() # total_loss += loss # total_weight += weight # loss = Metric(f'loss_{name.snake}', loss, weight) # metrics += loss # # kld. # lp = gate_log_probs # _p = lp.exp().sum(0) / lp.exp().sum() # prior = get_tensor([g.prior_value, 1.0 - g.prior_value]).squeeze(dim=0) # kld = _p * (_p.log() - prior.log()) # bs = lp.size('batch') # kld = Metric('kld', kld.sum() * bs, bs) # metrics += kld # metrics += Metric('loss', total_loss, total_weight) # metrics += Metric('loss', total_loss + kld.total, total_weight) # print('cnt', cnt / total_weight) # print('prob', prob_cnt / total_weight) return metrics else: return metrics
def _evaluate_one_dl(self, stage: str, dl: OnePairDataLoader) -> Metrics:
    """Evaluate the model on one data loader and dump per-source predictions.

    Builds a per-source prediction table (top-K predictions, plus beam/edit
    distance columns when ``g.eval_mode == 'edit_dist'``), writes it to
    ``{g.log_dir}/predictions/{setting.name}.{stage}.tsv``, and returns
    precision@1 / precision@K (and edit-distance metrics in edit-dist mode).

    Args:
        stage: label used in the output file name.
        dl: the data loader to evaluate on.

    Returns:
        Metrics with `precision@1`, `precision@K` and, in edit-dist mode,
        `edit_dist`, `normalized_edit_dist` and `ppx`.
    """
    records = list()
    K = 5
    for batch in pbar(dl, desc='eval: batch'):
        if g.eval_mode == 'edit_dist':
            batch_records = self._get_batch_records(dl, batch, K)
            records.extend(batch_records)
        else:
            scores = self.model.get_scores(batch, dl.tgt_seqs)
            # FIX: use K instead of the hard-coded literal 5 so the top-k
            # size stays in sync with the rest of the method.
            top_scores, top_preds = torch.topk(scores, K, dim='tgt_vocab')
            for pss, pis, gi in zip(top_scores, top_preds, batch.indices):
                gold = dl.get_token_from_index(gi, 'tgt')
                src = dl.get_token_from_index(gi, 'src')
                record = {'source': src, 'gold_target': gold}
                for i, (ps, pi) in enumerate(zip(pss, pis), 1):
                    pred = dl.get_token_from_index(pi, 'tgt')
                    record[f'pred_target@{i}'] = pred
                    record[f'pred_target@{i}_score'] = f'{ps:.3f}'
                records.append(record)
    out_df = pd.DataFrame.from_records(records)
    # One row per source: golds are '|'-joined, predictions keep the last value.
    values = ['gold_target']
    values.extend([f'pred_target@{i}' for i in range(1, K + 1)])
    values.extend([f'pred_target@{i}_score' for i in range(1, K + 1)])
    aggfunc = {'gold_target': '|'.join}
    aggfunc.update({f'pred_target@{i}': 'last' for i in range(1, K + 1)})
    aggfunc.update(
        {f'pred_target@{i}_score': 'last' for i in range(1, K + 1)})
    if g.eval_mode == 'edit_dist':
        values.extend(
            [f'pred_target_beam@{i}' for i in range(1, g.beam_size + 1)])
        values.extend([
            f'pred_target_beam@{i}_score' for i in range(1, g.beam_size + 1)
        ])
        values.extend(['edit_dist', 'normalized_edit_dist', 'ppx'])
        aggfunc.update({
            f'pred_target_beam@{i}': 'last' for i in range(1, g.beam_size + 1)
        })
        aggfunc.update({
            f'pred_target_beam@{i}_score': 'last' for i in range(1, g.beam_size + 1)
        })
        # Keep the best (smallest) distance/perplexity per source.
        aggfunc.update({
            'edit_dist': min, 'normalized_edit_dist': min, 'ppx': min
        })
    out_df = out_df.pivot_table(index='source', values=values, aggfunc=aggfunc)

    def is_correct(item):
        # A prediction is correct if any of its '|'-separated forms matches
        # any of the aggregated gold forms.
        pred, gold = item
        golds = gold.split('|')
        preds = pred.split('|')
        return bool(set(golds) & set(preds))

    for i in range(1, K + 1):
        correct = out_df[[f'pred_target@{i}', 'gold_target']].apply(is_correct, axis=1)
        if i > 1:
            # correct@i is cumulative: a hit anywhere in the top i counts.
            correct = correct | out_df[f'correct@{i - 1}']
        out_df[f'correct@{i}'] = correct
    out_folder = g.log_dir / 'predictions'
    out_folder.mkdir(exist_ok=True)
    setting = dl.setting
    out_path = str(out_folder / f'{setting.name}.{stage}.tsv')
    out_df.to_csv(out_path, sep='\t')
    logging.info(f'Predictions saved to {out_path}.')
    num_pred = len(out_df)
    metrics = Metrics()
    for i in [1, K]:
        num_correct = out_df[f'correct@{i}'].sum()
        correct = Metric(f'precision@{i}', num_correct, weight=num_pred)
        metrics += correct
    if g.eval_mode == 'edit_dist':
        # FIX: these columns only exist in edit-dist mode; the original
        # accessed them unconditionally and raised KeyError in the
        # score-based evaluation mode.
        metrics += Metric('edit_dist', out_df['edit_dist'].sum(), weight=num_pred)
        metrics += Metric('normalized_edit_dist',
                          out_df['normalized_edit_dist'].sum(), weight=num_pred)
        metrics += Metric('ppx', out_df['ppx'].sum(), weight=num_pred)
    return metrics
def train_one_step(self, dl: OnePairDataLoader) -> Metrics:
    """Run one MCTS training step: collect episodes, then train on replayed edges.

    First collects `g.num_episodes` episodes with the current agent, logs them
    to the metric writer and pushes their edges into the replay buffer
    (optionally reward-weighted via `g.tau`).  Then (re)builds the optimizer
    and runs `g.num_inner_steps` inner updates on sampled edge batches,
    minimizing a policy cross-entropy against the MCTS visit distributions.
    """
    if g.improved_player_only:
        # Snapshot the current weights; presumably restored/compared elsewhere
        # when only improved players are kept — confirm against the caller.
        self._old_state = self.agent.state_dict()
    # Collect episodes with the latest agent first.
    new_tr = self.mcts.collect_episodes(self.mcts.env.start, self.tracker)
    # new_tr = self.mcts.collect_episodes(dl.init_state, dl.end_state, self.tracker)
    tr_rew = Metric('reward', sum(tr.rewards.sum() for tr in new_tr), g.num_episodes)
    tr_len = Metric('trajectory_length', sum(map(len, new_tr)), g.num_episodes)
    success = Metric('success', sum(tr.done for tr in new_tr), g.num_episodes)
    metrics = Metrics(tr_rew, tr_len, success)
    # Add these new episodes to the replay buffer.
    for i, tr in enumerate(new_tr, 1):
        global_step = i + self.tracker['step'].value * g.num_episodes
        self.metric_writer.add_scalar('episode_reward', tr.rewards.sum(), global_step=global_step)
        self.metric_writer.add_text('trajectory', str(tr), global_step=global_step)
        # NOTE(j_luo) Use temperature if it's positive.
        if g.tau > 0.0:
            # Exponential reward weighting; the 10.0 scale is a tuned constant.
            weight = math.exp(tr.total_reward * 10.0)
        else:
            weight = 1.0
        for tr_edge in tr:
            self.replay_buffer.append(tr_edge, weight)
    # Main loop.
    # NOTE(review): the optimizer is rebuilt on every training step, which
    # resets optimizer state (e.g. Adam moments) each step — confirm this is
    # intentional.
    from torch.optim import SGD, Adam
    optim_cls = Adam if g.optim_cls == 'adam' else SGD
    optim_kwargs = dict()
    if optim_cls == SGD:
        optim_kwargs['momentum'] = 0.9
    self.set_optimizer(optim_cls, lr=g.learning_rate,
                       weight_decay=g.weight_decay, **optim_kwargs)
    with self.agent.policy_grad(True), self.agent.value_grad(True):
        for _ in range(g.num_inner_steps):
            # Get a batch of training trajectories from the replay buffer.
            edge_batch = self.replay_buffer.sample(g.mcts_batch_size)
            # edge_batch = np.random.choice(self.replay_buffer, size=g.mcts_batch_size)
            agent_inputs = AgentInputs.from_edges(
                edge_batch)  # , self.mcts.env)#, sparse=True)
            self.agent.train()
            self.optimizer.zero_grad()
            policies = self.agent.get_policy(agent_inputs.id_seqs,
                                             almts=(agent_inputs.almts1, agent_inputs.almts2))
            # print(policies[:, 2, self.mcts.env.abc['ẽ']].exp().mean())
            # values = self.agent.get_values(agent_inputs.id_seqs, steps=agent_inputs.steps)
            # breakpoint()  # BREAKPOINT(j_luo)
            # Strip tensor names so plain gather/where indexing works below.
            with NoName(policies, agent_inputs.permissible_actions):
                mask = agent_inputs.permissible_actions == SENTINEL_ID
                pa = agent_inputs.permissible_actions
                # Replace sentinel ids with 0 so gather has valid indices...
                pa = torch.where(mask, torch.zeros_like(pa), pa)
                logits = policies.gather(2, pa)
                # ...then push the masked positions to ~-inf before softmax.
                logits = torch.where(mask, torch.full_like(logits, -9999.9), logits)
                logits = logits.log_softmax(dim=-1)
            # r_max = agent_inputs.rewards.max()
            # r_min = agent_inputs.rewards.min()
            # weights = (agent_inputs.rewards - r_min) / (r_max - r_min + 1e-8)
            # weights = weights.align_as(pi_ce_losses)
            # Entropy of the MCTS visit distribution (epsilon for log(0)).
            entropies = (-agent_inputs.mcts_pis * (1e-8 + agent_inputs.mcts_pis).log()).sum(dim=-1)
            # Cross-entropy minus entropy = KL(mcts_pi || policy) per position.
            pi_ce_losses = (-agent_inputs.mcts_pis * logits).sum(dim=-1) - entropies
            # NOTE(review): 7 appears to be the fixed size of the action
            # dimension here — confirm against AgentInputs/mcts_pis shape.
            for i in range(7):
                metrics += Metric(f'entropy_{i}', entropies[:, i].sum(), g.mcts_batch_size)
                # NOTE(review): 'pi_ce_los_' looks like a typo for
                # 'pi_ce_loss_', but it is a logged metric name, so it is
                # left unchanged to keep dashboards/logs consistent.
                metrics += Metric(f'pi_ce_los_{i}', pi_ce_losses[:, i].sum(), g.mcts_batch_size)
            # v_regress_losses = 0.5 * (values - agent_inputs.qs) ** 2
            # pi_ce_loss = Metric('pi_ce_loss', (weights * pi_ce_losses).sum(), g.mcts_batch_size * 7)
            # mini_weights = get_tensor([1.0, 0.1, 1.0, 0.1, 0.1, 0.1, 0.1]).rename('mini').align_as(pi_ce_losses)
            # pi_ce_loss = Metric('pi_ce_loss', (mini_weights * pi_ce_losses).sum(), g.mcts_batch_size * 7)
            pi_ce_loss = Metric('pi_ce_loss', pi_ce_losses.sum(), g.mcts_batch_size * 7)
            # pi_ce_loss = Metric('pi_ce_loss', pi_ce_losses[:, 0].sum(), g.mcts_batch_size)
            # v_regress_loss = Metric('v_regress_loss', v_regress_losses.sum(), g.mcts_batch_size)
            total_loss = pi_ce_loss.total  # + g.regress_lambda * v_regress_loss.total
            total_loss = Metric('total_loss', total_loss, g.mcts_batch_size)
            total_loss.mean.backward()
            # Clip gradient norm.
            grad_norm = clip_grad(self.agent.parameters(), g.mcts_batch_size)
            # metrics += Metrics(total_loss, pi_ce_loss, v_regress_loss, grad_norm)
            metrics += Metrics(total_loss, pi_ce_loss, grad_norm)
            self.optimizer.step()
            self.tracker.update('inner_step')
    return metrics