def train_one_step(self, dl: ContinuousTextDataLoader) -> Metrics:
    """Run one optimization step with gradient accumulation.

    Draws ``g.accum_gradients`` batches from ``dl``, accumulates the scaled
    gradients, clips the global gradient norm at 5.0, and applies a single
    optimizer step.

    Args:
        dl: data loader that yields training batches via ``get_next_batch()``.

    Returns:
        Metrics accumulated over all sub-batches, plus a ``grad_norm`` metric
        weighted by the last sub-batch's size.
    """
    self.model.train()
    self.optimizer.zero_grad()
    accum_metrics = Metrics()
    for _ in pbar(range(g.accum_gradients), desc='accum_gradients'):
        batch = dl.get_next_batch()
        ret = self.model(batch)
        metrics = self.analyzer.analyze(ret, batch)
        loss = -metrics.ll.mean
        # Add the regularization term only when the analyzer produced one.
        # NOTE(review): the old code wrapped this in `try/except
        # AttributeError: pass`, which also swallowed attribute errors raised
        # *inside* the expression (e.g. a typo on `g.reg_hyper`). An explicit
        # getattr check only tolerates a genuinely missing `reg` metric.
        reg = getattr(metrics, 'reg', None)
        if reg is not None:
            loss = loss + reg.mean * g.reg_hyper
        # Scale each split so the accumulated gradient equals the mean loss
        # gradient over all accumulation splits.
        loss_per_split = loss / g.accum_gradients
        loss_per_split.backward()
        accum_metrics += metrics
    grad_norm = clip_grad_norm_(self.model.parameters(), 5.0)
    self.optimizer.step()
    # Weight grad_norm by batch size so later averaging over examples is
    # well-defined (matches how other metrics carry their sample counts).
    accum_metrics += Metric('grad_norm', grad_norm * batch.batch_size, batch.batch_size)
    return accum_metrics
def get_scores(self, batch: OnePairBatch, tgt_vocab_seqs: PaddedUnitSeqs, chunk_size: int = 100) -> FT:
    """Given a batch and a list of target tokens (provided as id sequences),
    return scores produced by the model.

    The target vocabulary is processed in chunks of ``chunk_size`` sequences
    to bound peak memory. The source is encoded once and its tensors are
    repeated to line up with each chunk. The result is a named tensor with
    dims ('batch', 'tgt_vocab'): one score per (source, candidate) pair.
    """
    # Encode the source once; every candidate chunk reuses these tensors.
    src_emb, (output, state) = self.encoder(batch.src_seqs.ids, batch.src_seqs.lengths)
    src_emb = src_emb.refine_names('pos', 'batch', 'src_emb')
    output = output.refine_names('pos', 'batch', 'output')
    batch_size = src_emb.size('batch')
    lang_emb = self._prepare_lang_emb(batch)

    def create_chunk(size, base, old_chunk, interleave: bool = True):
        # interleave=True: repeat each batch element `size` times
        # (repeat_interleave) so source-side tensors match the flattened
        # (batch x chunk) layout. interleave=False: tile the target-side
        # tensor `batch_size` times instead; `size`/`old_chunk` are unused
        # on that path.
        if not interleave:
            return base.repeat(1, batch_size)
        # Reuse the previous repeated tensor when the chunk size is
        # unchanged — only the final, smaller chunk triggers a fresh repeat.
        if old_chunk is not None and old_chunk.size('batch') == batch_size * size:
            return old_chunk
        new_chunk = torch.repeat_interleave(base, size, dim='batch')
        return new_chunk

    chunk_src_emb = None
    chunk_output = None
    chunk_src_paddings = None
    scores = list()
    for split in pbar(tgt_vocab_seqs.split(chunk_size), desc='Get scores: chunk'):
        split: PaddedUnitSeqs
        bs_split = len(split)
        chunk_src_emb = create_chunk(bs_split, src_emb, chunk_src_emb)
        chunk_output = create_chunk(bs_split, output, chunk_output)
        chunk_src_paddings = create_chunk(bs_split, batch.src_seqs.paddings, chunk_src_paddings)
        chunk_target = create_chunk(None, split.ids, None, interleave=False)
        chunk_tgt_paddings = create_chunk(None, split.paddings, None, interleave=False)
        chunk_log_probs, _ = self.decoder(SOT_ID, chunk_src_emb, chunk_output, chunk_src_paddings,
                                          target=chunk_target, lang_emb=lang_emb)
        # Pick the log-prob of each gold unit, mask out padding positions,
        # and sum over 'pos' to get one score per flattened pair.
        chunk_scores = chunk_log_probs.gather('unit', chunk_target)
        chunk_scores = (chunk_scores * chunk_tgt_paddings).sum('pos')
        # Named tensors can't be view()ed; drop names, unflatten to
        # (batch, chunk), then restore names.
        with NoName(chunk_scores):
            scores.append(
                chunk_scores.view(batch_size, bs_split).refine_names(
                    'batch', 'tgt_vocab'))
    scores = torch.cat(scores, dim='tgt_vocab')
    return scores
def evaluate(self, dl: ContinuousTextDataLoader) -> Metrics:
    """Evaluate the search solver over ``dl``.

    For each segment, the solver's best state is wrapped into a predicted
    ``Segmentation`` and compared against the gold segmentation derived from
    the segment itself. Per-segment rows are also dumped to
    ``<g.log_dir>/predictions/search_solver.tsv``.

    Returns:
        Matching statistics combined with precision/recall/F1 scores.
    """
    segments = list()
    ground_truths = list()
    predictions = list()
    for batch in pbar(dl, desc='eval_batch'):
        for segment in batch.segments:
            segments.append(segment)
            ground_truths.append(segment.to_segmentation())
            # Run the solver and keep only the spans of its best state.
            best_value, best_state = self.solver.find_best(segment)
            predictions.append(Segmentation(best_state.spans))
    df = _get_df(segments, ground_truths, predictions)
    out_path = g.log_dir / 'predictions' / 'search_solver.tsv'
    out_path.parent.mkdir(exist_ok=True, parents=True)
    # `index=False` is the documented way to suppress the index column
    # (the previous `index=None` only worked by falsiness).
    df.to_csv(out_path, index=False, sep='\t')
    matching_stats = get_matching_stats(predictions, ground_truths)
    prf_scores = get_prf_scores(matching_stats)
    return matching_stats + prf_scores
def evaluate(self, stage: str) -> Metrics:
    """Evaluate the extraction model on ``self.dl`` for the given stage.

    Runs the model over the eval loader (optionally capped at
    ``g.eval_max_num_samples`` examples), collects predicted segmentations
    and matched segments, writes a per-segment TSV to
    ``<g.log_dir>/predictions/extract.<stage>.tsv``, and returns the
    analyzer metrics combined with matching stats and P/R/F scores.
    """
    segments = list()
    predictions = list()
    ground_truths = list()
    matched_segments = list()
    total_num_samples = 0
    analyzed_metrics = Metrics()
    for batch in pbar(self.dl, desc='eval_batch'):
        # Stop before processing a batch that would push us past the cap.
        if g.eval_max_num_samples and total_num_samples + batch.batch_size > g.eval_max_num_samples:
            logging.imp(
                f'Stopping at {total_num_samples} < {g.eval_max_num_samples} evaluated examples.'
            )
            break
        ret = self.model(batch)
        analyzed_metrics += self.analyzer.analyze(ret, batch)
        segments.extend(batch.segments)
        segmentations, _matched_segments = self._get_segmentations(ret, batch)
        predictions.extend(segmentations)
        matched_segments.extend(_matched_segments)
        # Gold segmentations come straight from the segments themselves.
        ground_truths.extend(segment.to_segmentation() for segment in batch.segments)
        total_num_samples += batch.batch_size
    df = _get_df(segments, ground_truths, predictions, matched_segments,
                 columns=('segment', 'ground_truth', 'prediction', 'matched_segment'))
    out_path = g.log_dir / 'predictions' / f'extract.{stage}.tsv'
    out_path.parent.mkdir(exist_ok=True, parents=True)
    # `index=False` is the documented way to suppress the index column
    # (the previous `index=None` only worked by falsiness).
    df.to_csv(out_path, index=False, sep='\t')
    matching_stats = get_matching_stats(predictions, ground_truths)
    prf_scores = get_prf_scores(matching_stats)
    return analyzed_metrics + matching_stats + prf_scores