def _convert_transformer_inputs(model, wps: WordpieceBatch, is_train):
    """Adapt a wordpiece batch into keyword arguments for the PyTorchWrapper.

    See https://thinc.ai/docs/usage-frameworks

    Returns a pair ``(ArgsKwargs, callback)``. The ArgsKwargs carries no
    positional arguments; its keyword arguments are the batch arrays
    converted with ``xp2torch``. ``token_type_ids`` is only included when
    the batch provides it. The callback returns an empty list regardless of
    what it is given.
    """

    def no_input_gradients(dX):
        # Nothing is passed back for the inputs.
        return []

    tensors = {}
    tensors["input_ids"] = xp2torch(wps.input_ids)
    tensors["attention_mask"] = xp2torch(wps.attention_mask)
    type_ids = wps.token_type_ids
    if type_ids is not None:
        tensors["token_type_ids"] = xp2torch(type_ids)
    return ArgsKwargs(args=(), kwargs=tensors), no_input_gradients
 def to_hf_dict(self) -> Dict:
     """Build a dict resembling the output of the Huggingface tokenizer.

     The array fields are converted to pytorch tensors with ``xp2torch``;
     ``self.strings`` is stored as-is under "input_texts". The
     "token_type_ids" entry is only added when ``self.token_type_ids`` is
     not None.
     """
     hf_dict = {}
     hf_dict["input_ids"] = xp2torch(self.input_ids)
     hf_dict["attention_mask"] = xp2torch(self.attention_mask)
     hf_dict["input_texts"] = self.strings
     type_ids = self.token_type_ids
     if type_ids is not None:
         hf_dict["token_type_ids"] = xp2torch(type_ids)
     return hf_dict
def _convert_transformer_inputs(model, wps: WordpieceBatch, is_train):
    """Adapt a wordpiece batch into keyword arguments for the HFWrapper.

    See https://thinc.ai/docs/usage-frameworks

    Each array is converted with ``xp2torch`` and moved to the device
    reported by the wrapped transformer (``model.shims[0]``).
    ``input_ids`` and ``token_type_ids`` are additionally cast to long.
    ``token_type_ids`` is only included when the batch provides it.

    Returns a pair ``(ArgsKwargs, callback)``; the ArgsKwargs has no
    positional arguments and the callback always returns an empty list.
    """
    target_device = model.shims[0]._hfmodel.transformer.device

    def on_device(tensor):
        # Place the tensor on the same device as the wrapped transformer.
        return tensor.to(device=target_device)

    def no_input_gradients(dX):
        # Nothing is passed back for the inputs.
        return []

    # Note: remove conversion to long when PyTorch >= 1.8.0.
    tensors = {}
    tensors["input_ids"] = on_device(xp2torch(wps.input_ids).long())
    tensors["attention_mask"] = on_device(xp2torch(wps.attention_mask))
    type_ids = wps.token_type_ids
    if type_ids is not None:
        tensors["token_type_ids"] = on_device(xp2torch(type_ids).long())
    return ArgsKwargs(args=(), kwargs=tensors), no_input_gradients
예제 #4
0
 def backprop(d_tokvecs: List[Floats2d]) -> ArgsKwargs:
     """Package the token-vector gradients for the backward call.

     Each gradient array gets one extra row stacked before and after it,
     the list is padded with ``model.ops.pad`` and converted to a torch
     tensor, and the result is passed as the "grad_tensors" keyword next
     to ``torch_tokvecs`` from the enclosing scope.
     """
     ops = model.ops
     width = d_tokvecs[0].shape[1]
     # Restore entries for bos and eos markers.
     marker_row = ops.alloc2f(1, width)
     with_markers = []
     for grad in d_tokvecs:
         with_markers.append(ops.xp.vstack((marker_row, grad, marker_row)))
     grad_tensors = xp2torch(ops.pad(with_markers))
     return ArgsKwargs(
         args=(torch_tokvecs,),
         kwargs={"grad_tensors": grad_tensors},
     )
def test_pytorch_roundtrip_conversion():
    """Check that xp2torch followed by torch2xp gives back an equal array."""
    import torch

    original = numpy.zeros((2, 3), dtype="f")

    # Forward leg: the array must come out as a torch tensor.
    as_torch = xp2torch(original)
    assert isinstance(as_torch, torch.Tensor)

    # Return leg: converting back must yield an element-wise equal array.
    restored = torch2xp(as_torch)
    assert numpy.array_equal(original, restored)
예제 #6
0
 def unsplit_by_doc(self,
                    arrays: List[List[Floats3d]]) -> "FullTransformerBatch":
     """Return a new FullTransformerBatch built from split arrays.

     The arrays are regrouped with ``transpose_list``; each group is
     stacked with ``vstack`` and converted to a torch tensor. The current
     object's spans, tokens and alignment are reused.
     """
     # Use the array module that matches the first incoming array.
     xp = get_array_module(arrays[0][0])
     stacked = []
     for group in transpose_list(arrays):
         stacked.append(xp2torch(xp.vstack(group)))
     return FullTransformerBatch(
         spans=self.spans,
         tokens=self.tokens,
         tensors=stacked,
         align=self.align,
     )
예제 #7
0
    def unsplit_by_doc(self,
                       arrays: List[List[Floats3d]]) -> "FullTransformerBatch":
        """Rebuild a FullTransformerBatch from a split batch of activations.

        The current object's spans, wordpieces and alignment are reused; only
        the tensors are replaced. The arrays are regrouped with
        ``transpose_list``, and each group is stacked with ``vstack`` and
        converted to a torch tensor.

        This is used during the backward pass, in order to construct the
        gradients to pass back into the transformer model.
        """
        # Use the array module that matches the first incoming array.
        xp = get_array_module(arrays[0][0])
        stacked = []
        for group in transpose_list(arrays):
            stacked.append(xp2torch(xp.vstack(group)))
        return FullTransformerBatch(
            spans=self.spans,
            wordpieces=self.wordpieces,
            tensors=stacked,
            align=self.align,
        )
    def unsplit_by_doc(self,
                       arrays: List[List[Floats3d]]) -> "FullTransformerBatch":
        """Rebuild a FullTransformerBatch from a split batch of activations.

        The current object's spans, wordpieces and alignment are reused. The
        arrays are regrouped with ``transpose_list``; each group is stacked
        with ``vstack``, converted to a torch tensor and stored in a fresh
        ModelOutput under the keys "output_0", "output_1", ...

        This is used during the backward pass, in order to construct the
        gradients to pass back into the transformer model.
        """
        # Use the array module that matches the first incoming array.
        xp = get_array_module(arrays[0][0])
        # construct a dummy ModelOutput with the tensor values
        dummy_output = ModelOutput()
        stacked = [xp2torch(xp.vstack(group)) for group in transpose_list(arrays)]
        for idx, tensor in enumerate(stacked):
            dummy_output[f"output_{idx}"] = tensor
        return FullTransformerBatch(
            spans=self.spans,
            wordpieces=self.wordpieces,
            model_output=dummy_output,
            align=self.align,
        )