Example #1
    def __call__(self, engine: Engine) -> None:
        output = self._output_transform(engine.state.output)

        def raise_error(x: Union[float, torch.Tensor]) -> None:
            # Promote plain Python numbers to tensors so the finiteness check below covers them too.
            if isinstance(x, numbers.Number):
                x = torch.tensor(x)

            if isinstance(x, torch.Tensor) and not bool(torch.isfinite(x).all()):
                raise RuntimeError("Infinite or NaN tensor found.")

        try:
            # Recursively check every number/tensor in the (possibly nested) output structure.
            apply_to_type(output, (numbers.Number, torch.Tensor), raise_error)
        except RuntimeError:
            self.logger.warning(f"{self.__class__.__name__}: Output '{output}' contains NaN or Inf. Stop training")
            engine.terminate()
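
This handler mirrors Ignite's built-in TerminateOnNan. A minimal sketch of attaching such a NaN/Inf guard to a trainer, assuming `trainer` is an existing ignite Engine:

from ignite.engine import Events
from ignite.handlers import TerminateOnNan

# Check every iteration's output and stop training on NaN/Inf (assumes `trainer` exists).
trainer.add_event_handler(Events.ITERATION_COMPLETED, TerminateOnNan())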
Example #2
def _update(engine, batch):
    model.train()
    optimizer.zero_grad()
    # Move every tensor in the (possibly nested) batch to the target device.
    batch = apply_to_type(batch, torch.Tensor, lambda x: x.to(device))
    scores = model(batch)
    loss = -scores.mean()
    loss.backward()
    # Clip gradients only after backward() has populated them.
    if max_grad_norm is not None:
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
    optimizer.step()
    return loss.item()
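
The `_update` step above follows Ignite's process-function signature. A minimal sketch of wrapping it in a trainer, assuming `_update` and a `train_loader` are defined as in the surrounding script:

from ignite.engine import Engine

# Build the trainer from the update step and run it (train_loader and max_epochs are assumptions).
trainer = Engine(_update)
trainer.run(train_loader, max_epochs=10)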
Example #3
def _inference(engine, batch):
    model.eval()
    with torch.no_grad():
        # Move every tensor in the batch to the target device before scoring.
        batch = apply_to_type(batch, torch.Tensor, lambda x: x.to(device))
        scores = model(batch, reduce=False)
        return scores
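
Likewise, `_inference` can be wrapped in an evaluator Engine. A minimal sketch that runs validation after each training epoch, assuming `trainer` and a `val_loader` exist:

from ignite.engine import Engine, Events

evaluator = Engine(_inference)

# Run the evaluator once per epoch (val_loader is an assumption).
@trainer.on(Events.EPOCH_COMPLETED)
def run_validation(engine):
    evaluator.run(val_loader)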
Example #4
model.load_state_dict(torch.load("my_model/nmt_mymodel_28.pth"))
model.to(model.device)
test_iterator = BasicIterator(batch_size=1)
test_iterator.index_with(vocab)

all_hyps = []
all_refs = []

for idx, test_data_sample in tqdm(
        enumerate(test_iterator(instances, shuffle=False, num_epochs=1)),
        total=len(instances),
):
    with torch.no_grad():
        # Move every tensor in the test sample to the model's device before decoding.
        hyps = model.beam_search(
            apply_to_type(test_data_sample, torch.Tensor,
                          lambda t: t.to(model.device)),
            TokenCharactersIndexer(
                "char_trg",
                character_tokenizer=MyCharacterTokenizer(max_length=21),
            ))
    best_hyp = hyps[0].value
    ref = [str(t) for t in instances[idx]["target_sentence"].tokens[1:-1]]
    all_hyps.append(best_hyp)
    all_refs.append([ref])
    if idx % 100 == 0:
        print(" ".join(best_hyp), "->", " ".join(ref))
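
The loop accumulates tokenized hypotheses and per-sentence reference lists, which matches the argument shape of NLTK's corpus BLEU. A hedged sketch of scoring the decoded output (the original script may compute the metric elsewhere):

from nltk.translate.bleu_score import corpus_bleu

# all_refs: list of reference lists per sentence; all_hyps: list of token lists.
bleu = corpus_bleu(all_refs, all_hyps)
print(f"Corpus BLEU: {bleu:.4f}")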