Example #1: a single training step with gradient accumulation.
    def train_one_step(self, dl: ContinuousTextDataLoader) -> Metrics:
        self.model.train()
        self.optimizer.zero_grad()
        accum_metrics = Metrics()

        for _ in pbar(range(g.accum_gradients), desc='accum_gradients'):
            batch = dl.get_next_batch()
            ret = self.model(batch)
            metrics = self.analyzer.analyze(ret, batch)

            # Minimize negative log-likelihood; include the regularization
            # term when the analyzer produces one (`reg` may be absent).
            loss = -metrics.ll.mean
            try:
                loss = loss + metrics.reg.mean * g.reg_hyper
            except AttributeError:
                pass
            # Divide by the number of accumulation steps so the summed
            # gradients match a single step over the combined batch.
            loss_per_split = loss / g.accum_gradients
            loss_per_split.backward()

            accum_metrics += metrics

        # Clip gradients, take the optimizer step, and log the gradient norm
        # weighted by batch size so Metrics averaging yields a per-example mean.
        grad_norm = clip_grad_norm_(self.model.parameters(), 5.0)
        self.optimizer.step()
        accum_metrics += Metric('grad_norm', grad_norm * batch.batch_size, batch.batch_size)

        return accum_metrics
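The `loss / g.accum_gradients` scaling above is what makes accumulation equivalent to one large-batch update. A minimal, self-contained sketch (illustrative names, not from this codebase) verifying that equivalence:

    import torch
    import torch.nn as nn

    torch.manual_seed(0)
    model = nn.Linear(4, 1)
    data, target = torch.randn(8, 4), torch.randn(8, 1)
    loss_fn = nn.MSELoss()

    # One full-batch backward pass.
    model.zero_grad()
    loss_fn(model(data), target).backward()
    full_grad = model.weight.grad.clone()

    # Two accumulated micro-batches, each loss divided by the step count.
    model.zero_grad()
    for x, y in zip(data.chunk(2), target.chunk(2)):
        (loss_fn(model(x), y) / 2).backward()

    print(torch.allclose(full_grad, model.weight.grad, atol=1e-6))  # True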
Example #2: scoring a batch against a chunked target vocabulary.
    def get_scores(self,
                   batch: OnePairBatch,
                   tgt_vocab_seqs: PaddedUnitSeqs,
                   chunk_size: int = 100) -> FT:
        """Given a batch and a list of target tokens (provided as id sequences), return scores produced by the model."""
        src_emb, (output, state) = self.encoder(batch.src_seqs.ids,
                                                batch.src_seqs.lengths)
        src_emb = src_emb.refine_names('pos', 'batch', 'src_emb')
        output = output.refine_names('pos', 'batch', 'output')
        batch_size = src_emb.size('batch')
        lang_emb = self._prepare_lang_emb(batch)

        def create_chunk(size, base, old_chunk, interleave: bool = True):
            """Expand `base` along the batch dimension so that every source
            sequence is paired with every target candidate in the chunk."""
            if not interleave:
                # Target side: tile the whole chunk, t1, t2, ..., tN, t1, ...
                return base.repeat(1, batch_size)

            # Source side: reuse the cached chunk when it already has the
            # right batch size (all chunks but possibly the last match).
            if old_chunk is not None and old_chunk.size('batch') == batch_size * size:
                return old_chunk

            # Otherwise interleave: s1 repeated `size` times, then s2, ...
            return torch.repeat_interleave(base, size, dim='batch')

        chunk_src_emb = None
        chunk_output = None
        chunk_src_paddings = None
        scores = list()
        for split in pbar(tgt_vocab_seqs.split(chunk_size),
                          desc='Get scores: chunk'):
            split: PaddedUnitSeqs
            bs_split = len(split)
            chunk_src_emb = create_chunk(bs_split, src_emb, chunk_src_emb)
            chunk_output = create_chunk(bs_split, output, chunk_output)
            chunk_src_paddings = create_chunk(bs_split,
                                              batch.src_seqs.paddings,
                                              chunk_src_paddings)
            chunk_target = create_chunk(None,
                                        split.ids,
                                        None,
                                        interleave=False)
            chunk_tgt_paddings = create_chunk(None,
                                              split.paddings,
                                              None,
                                              interleave=False)
            chunk_log_probs, _ = self.decoder(SOT_ID,
                                              chunk_src_emb,
                                              chunk_output,
                                              chunk_src_paddings,
                                              target=chunk_target,
                                              lang_emb=lang_emb)
            chunk_scores = chunk_log_probs.gather('unit', chunk_target)
            chunk_scores = (chunk_scores * chunk_tgt_paddings).sum('pos')
            with NoName(chunk_scores):
                scores.append(
                    chunk_scores.view(batch_size, bs_split).refine_names(
                        'batch', 'tgt_vocab'))
        scores = torch.cat(scores, dim='tgt_vocab')
        return scores
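The source/target pairing performed by `create_chunk` hinges on combining `repeat_interleave` (sources) with `repeat` (targets). A small illustrative sketch with plain tensors (batch on dim 0 here, unlike the named 'batch' dim above):

    import torch

    src = torch.tensor([[1], [2]])          # 2 source sequences
    tgt = torch.tensor([[10], [20], [30]])  # 3 target candidates

    # Sources are interleaved: 1, 1, 1, 2, 2, 2.
    src_chunk = torch.repeat_interleave(src, tgt.size(0), dim=0)
    # Targets are tiled: 10, 20, 30, 10, 20, 30.
    tgt_chunk = tgt.repeat(src.size(0), 1)

    pairs = torch.cat([src_chunk, tgt_chunk], dim=1)
    print(pairs.tolist())
    # [[1, 10], [1, 20], [1, 30], [2, 10], [2, 20], [2, 30]]
    # This layout is why chunk_scores.view(batch_size, bs_split) recovers a
    # (source, target-candidate) score matrix.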
Example #3: evaluating the search solver's segmentations against ground truth.
    def evaluate(self, dl: ContinuousTextDataLoader) -> Metrics:
        segments = list()
        ground_truths = list()
        predictions = list()
        for batch in pbar(dl, desc='eval_batch'):
            for segment in batch.segments:
                segments.append(segment)
                ground_truth = segment.to_segmentation()
                ground_truths.append(ground_truth)

                # Search for the highest-scoring segmentation of this segment.
                _, best_state = self.solver.find_best(segment)
                prediction = Segmentation(best_state.spans)
                predictions.append(prediction)

        df = _get_df(segments, ground_truths, predictions)
        out_path = g.log_dir / 'predictions' / 'search_solver.tsv'
        out_path.parent.mkdir(exist_ok=True, parents=True)
        df.to_csv(out_path, index=False, sep='\t')
        matching_stats = get_matching_stats(predictions, ground_truths)
        prf_scores = get_prf_scores(matching_stats)
        return matching_stats + prf_scores
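A hedged sketch of what `get_prf_scores` presumably derives from the matching counts (the exact field names in the codebase may differ):

    def prf(num_matched: int, num_predicted: int, num_ground_truth: int):
        precision = num_matched / max(num_predicted, 1)
        recall = num_matched / max(num_ground_truth, 1)
        f1 = 2 * precision * recall / max(precision + recall, 1e-8)
        return precision, recall, f1

    print(prf(8, 10, 12))  # (0.8, 0.667, 0.727)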
Example #4: evaluating model segmentations, with an optional cap on the number of evaluated samples.
    def evaluate(self, stage: str) -> Metrics:
        segments = list()
        predictions = list()
        ground_truths = list()
        matched_segments = list()
        total_num_samples = 0
        analyzed_metrics = Metrics()
        for batch in pbar(self.dl, desc='eval_batch'):

            if g.eval_max_num_samples and total_num_samples + batch.batch_size > g.eval_max_num_samples:
                logging.imp(
                    f'Stopping at {total_num_samples} < {g.eval_max_num_samples} evaluated examples.'
                )
                break

            ret = self.model(batch)
            analyzed_metrics += self.analyzer.analyze(ret, batch)

            segments.extend(list(batch.segments))
            segmentations, _matched_segments = self._get_segmentations(
                ret, batch)
            predictions.extend(segmentations)
            matched_segments.extend(_matched_segments)
            ground_truths.extend(
                [segment.to_segmentation() for segment in batch.segments])
            total_num_samples += batch.batch_size

        df = _get_df(segments,
                     ground_truths,
                     predictions,
                     matched_segments,
                     columns=('segment', 'ground_truth', 'prediction',
                              'matched_segment'))
        out_path = g.log_dir / 'predictions' / f'extract.{stage}.tsv'
        out_path.parent.mkdir(exist_ok=True, parents=True)
        df.to_csv(out_path, index=False, sep='\t')
        matching_stats = get_matching_stats(predictions, ground_truths)
        prf_scores = get_prf_scores(matching_stats)
        return analyzed_metrics + matching_stats + prf_scores
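All four examples lean on Metric/Metrics objects that support `+` and expose a `.mean`. A hedged sketch of that accumulation pattern (the real classes may differ): each metric carries a weighted total plus a weight, so means compose correctly across batches of different sizes.

    from dataclasses import dataclass

    @dataclass
    class Metric:
        name: str
        total: float
        weight: float

        @property
        def mean(self) -> float:
            return self.total / self.weight

        def __add__(self, other: 'Metric') -> 'Metric':
            return Metric(self.name, self.total + other.total,
                          self.weight + other.weight)

    m1 = Metric('ll', total=-12.0, weight=4)  # batch of 4
    m2 = Metric('ll', total=-20.0, weight=8)  # batch of 8
    print((m1 + m2).mean)  # -32/12 ≈ -2.67, the per-example mean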