def test_convert_recursive():
    """Check that convert_recursive rewrites every matching leaf in nested data.

    Matching strings ("foo") are upper-cased wherever they occur: inside dict
    values, tuple dict keys, lists, tuples, and the args/kwargs of an
    ArgsKwargs. Non-matching leaves ("x", 1) must come back unchanged.
    """

    def is_match(obj):
        return obj == "foo"

    def convert_item(obj):
        return obj.upper()

    nested = {
        "a": {
            ("b", "foo"): {
                "c": "foo",
                "d": ["foo", {"e": "foo", "f": (1, "foo")}],
            }
        }
    }
    result = convert_recursive(is_match, convert_item, nested)
    # The tuple key itself must have been converted as well.
    converted = result["a"][("b", "FOO")]
    assert converted["c"] == "FOO"
    assert converted["d"] == ["FOO", {"e": "FOO", "f": (1, "FOO")}]

    wrapped = {"a": ArgsKwargs(("foo", [{"b": "foo"}]), {"a": ["x", "foo"]})}
    result = convert_recursive(is_match, convert_item, wrapped)
    assert result["a"].args == ("FOO", [{"b": "FOO"}])
    assert result["a"].kwargs == {"a": ["x", "FOO"]}
# Example #2
# 0
def convert_transformer_inputs(model, tokens: TokensPlus, is_train):
    """Adapt a TokensPlus batch into keyword arguments for the wrapped model.

    Args:
        model: The calling layer (unused here).
        tokens: Batch whose input_ids, attention_mask and token_type_ids are
            forwarded as-is.
        is_train: Unused.

    Returns:
        A pair ``(ArgsKwargs, backprop)``. The ArgsKwargs has no positional
        args, and ``backprop`` returns an empty list for any gradient, so
        nothing is propagated back to ``tokens``.
    """
    model_kwargs = dict(
        input_ids=tokens.input_ids,
        attention_mask=tokens.attention_mask,
        token_type_ids=tokens.token_type_ids,
    )

    def backprop(dX):
        return []

    return ArgsKwargs(args=(), kwargs=model_kwargs), backprop
# Example #3
# 0
 def backprop(d_tokvecs: List[Floats2d]) -> ArgsKwargs:
     """Turn per-sequence gradients into arguments for a PyTorch backward call.

     One freshly allocated row (``model.ops.alloc2f``) is stacked above and
     below each gradient array to restore the bos/eos marker entries. The
     arrays are then padded into a single batch and converted to torch.
     Relies on ``model`` and ``torch_tokvecs`` from the enclosing scope.
     """
     width = d_tokvecs[0].shape[1]
     marker_row = model.ops.alloc2f(1, width)
     vstack = model.ops.xp.vstack
     restored = [vstack((marker_row, grad, marker_row)) for grad in d_tokvecs]
     padded = model.ops.pad(restored)
     return ArgsKwargs(
         args=(torch_tokvecs,),
         kwargs={"grad_tensors": xp2torch(padded)},
     )
# Example #4
# 0
def _convert_transformer_inputs(model, tokens: BatchEncoding, is_train):
    """Adapt a BatchEncoding into keyword arguments for the PyTorchWrapper.

    See https://thinc.ai/docs/usage-frameworks

    input_ids and attention_mask are always forwarded. token_type_ids is
    forwarded only when the encoding contains it.

    Returns:
        A pair ``(ArgsKwargs, backprop)``. ``backprop`` returns an empty list
        for any gradient.
    """
    field_names = ["input_ids", "attention_mask"]
    if "token_type_ids" in tokens:
        field_names.append("token_type_ids")
    kwargs = {name: tokens[name] for name in field_names}
    return ArgsKwargs(args=(), kwargs=kwargs), lambda dX: []
def _convert_transformer_inputs(model, wps: WordpieceBatch, is_train):
    """Adapt a WordpieceBatch into torch keyword arguments for the PyTorchWrapper.

    See https://thinc.ai/docs/usage-frameworks

    Each forwarded array is converted with ``xp2torch``. token_type_ids is
    included only when the batch carries it (i.e. it is not None).

    Returns:
        A pair ``(ArgsKwargs, backprop)``. ``backprop`` returns an empty list
        for any gradient.
    """
    field_names = ["input_ids", "attention_mask"]
    if wps.token_type_ids is not None:
        field_names.append("token_type_ids")
    kwargs = {name: xp2torch(getattr(wps, name)) for name in field_names}
    return ArgsKwargs(args=(), kwargs=kwargs), lambda dX: []
def _convert_transformer_inputs(model, wps: WordpieceBatch, is_train):
    """Adapt a WordpieceBatch into torch keyword arguments for the HFWrapper.

    See https://thinc.ai/docs/usage-frameworks

    All tensors are moved to the device of the wrapped HF model's
    ``transformer``. input_ids and token_type_ids are also cast to long;
    attention_mask keeps the dtype produced by ``xp2torch``.
    token_type_ids is included only when it is not None.

    Returns:
        A pair ``(ArgsKwargs, backprop)``. ``backprop`` returns an empty list
        for any gradient.
    """
    target_device = model.shims[0]._hfmodel.transformer.device

    def as_long_on_device(array):
        # Note: remove conversion to long when PyTorch >= 1.8.0.
        return xp2torch(array).long().to(device=target_device)

    kwargs = {
        "input_ids": as_long_on_device(wps.input_ids),
        "attention_mask": xp2torch(wps.attention_mask).to(device=target_device),
    }
    if wps.token_type_ids is not None:
        kwargs["token_type_ids"] = as_long_on_device(wps.token_type_ids)
    return ArgsKwargs(args=(), kwargs=kwargs), lambda dX: []
# Example #7
# 0
 def backprop(d_tensors: List[torch.Tensor]) -> ArgsKwargs:
     """Pair the forward output tensors with their incoming gradients.

     The positional argument is ``tensors`` from the enclosing scope.
     ``grad_tensors`` carries ``d_tensors`` unchanged, presumably for a
     ``torch.autograd.backward``-style call (TODO confirm with caller).
     """
     backward_kwargs = {"grad_tensors": d_tensors}
     return ArgsKwargs(args=(tensors,), kwargs=backward_kwargs)
# Example #8
# 0
 def backprop(d_model_output: ModelOutput) -> ArgsKwargs:
     """Build backward-call arguments for the transformer's last hidden state.

     The positional argument is ``model_output.last_hidden_state`` from the
     enclosing scope. ``grad_tensors`` is every value stored in
     ``d_model_output``.

     NOTE(review): ``args`` holds exactly one tensor, so this presumably
     assumes ``d_model_output`` only populates ``last_hidden_state`` — confirm
     with the caller.
     """
     forward_outputs = (model_output.last_hidden_state,)
     output_grads = d_model_output.values()
     return ArgsKwargs(args=forward_outputs, kwargs={"grad_tensors": output_grads})