예제 #1
0
def test_update_convergence(nlp, labels, docs_golds):
    """Loss must strictly decrease across successive training updates.

    Fixes two defects in the original:
    - removed the dead no-op statement ``nlp.pipeline[-1][1]``;
    - ``prev`` was never updated inside the loop, so every iteration
      only compared against the 1e9 sentinel and convergence was never
      actually verified.
    """
    optim = nlp.resume_training()
    texts, labels = zip(*docs_golds)
    prev = 1e9  # sentinel larger than any realistic initial loss
    for i in range(10):
        docs = [nlp.make_doc(text) for text in texts]
        nlp.update(docs, labels, optim)
        loss = get_loss_from_docs(docs).item()
        assert prev > loss
        prev = loss  # track previous loss so the next step checks real convergence
예제 #2
0
 def _update_params(self,
                    docs: Sequence[Doc],
                    optimizer: Optimizer,
                    verbose: bool = False):
     """Perform one gradient step using the loss stored on *docs*.

     Computes the batch loss, clears stale gradients, backpropagates,
     and applies the optimizer update. Logs the scalar loss when
     *verbose* is set.
     """
     batch_loss = get_loss_from_docs(docs)
     optimizer.zero_grad()
     batch_loss.backward()
     optimizer.step()
     if not verbose:
         return
     logger.info(f"Loss: {batch_loss.detach().item()}")
예제 #3
0
 def evaluate(  # type: ignore
     self,
     docs_golds: Sequence[Tuple[Union[str, Doc], Union[Dict[str, Any],
                                                       GoldParse]]],
     batch_size: int = 16,
     scorer: Optional[Scorer] = None,
 ) -> Dict[str, Any]:
     """Overridden to accumulate the pipeline loss alongside the scores."""
     if scorer is None:
         scorer = Scorer(pipeline=self.pipeline)
     total_loss = 0.0
     for chunk in minibatch(docs_golds, size=batch_size):
         raw_docs, raw_golds = zip(*chunk)
         docs, golds = self._format_docs_and_golds(raw_docs,
                                                   raw_golds)  # type: ignore
         # Run every pipe over the batch before extracting loss or scoring.
         for _, pipe in self.pipeline:
             self._eval_pipe(pipe, docs, golds, batch_size=batch_size)
         batch_loss = get_loss_from_docs(docs).cpu().float().item()
         total_loss += cast(float, batch_loss)
         for doc, gold in zip(docs, golds):
             scorer.score(doc, gold)
     scores = scorer.scores
     scores["loss"] = total_loss
     return scores
예제 #4
0
def test_update(nlp, labels, docs_golds):
    """A single training update should yield a retrievable loss on the docs."""
    optim = nlp.resume_training()
    texts, labels = zip(*docs_golds)
    processed = [nlp.make_doc(t) for t in texts]
    nlp.update(processed, labels, optim)
    assert get_loss_from_docs(processed)
예제 #5
0
def test_update(nlp):
    """Updating with empty gold annotations should still produce a positive loss."""
    docs = [nlp.make_doc(t) for t in TEXTS]
    optim = nlp.resume_training()
    empty_golds = [{} for _ in docs]
    nlp.update(docs, empty_golds, optim)
    assert get_loss_from_docs(docs).item() > 0