def test_convert_recursive():
    # convert_recursive should walk nested dicts, lists, tuples and ArgsKwargs,
    # converting every matching element (including dict keys and tuple items).
    is_match = lambda obj: obj == "foo"
    convert_item = lambda obj: obj.upper()
    obj = {
        "a": {
            ("b", "foo"): {
                "c": "foo",
                "d": ["foo", {"e": "foo", "f": (1, "foo")}],
            }
        }
    }
    result = convert_recursive(is_match, convert_item, obj)
    assert result["a"][("b", "FOO")]["c"] == "FOO"
    assert result["a"][("b", "FOO")]["d"] == ["FOO", {"e": "FOO", "f": (1, "FOO")}]

    obj = {"a": ArgsKwargs(("foo", [{"b": "foo"}]), {"a": ["x", "foo"]})}
    result = convert_recursive(is_match, convert_item, obj)
    assert result["a"].args == ("FOO", [{"b": "FOO"}])
    assert result["a"].kwargs == {"a": ["x", "FOO"]}

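# Aside (not from the snippets above): in Thinc this helper is typically used to
# swap tensor types throughout arbitrarily nested containers. A minimal sketch,
# assuming thinc.util exposes convert_recursive, is_torch_array and torch2xp:
import torch
from thinc.api import ArgsKwargs
from thinc.util import convert_recursive, is_torch_array, torch2xp

# Replace every torch.Tensor inside nested dicts/lists/ArgsKwargs with its
# numpy (or cupy) counterpart, leaving everything else untouched.
nested = {"logits": [torch.zeros(2, 3)], "call": ArgsKwargs((torch.ones(1),), {})}
converted = convert_recursive(is_torch_array, torch2xp, nested)
assert not isinstance(converted["logits"][0], torch.Tensor)
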
def convert_transformer_inputs(model, tokens: TokensPlus, is_train):
    kwargs = {
        "input_ids": tokens.input_ids,
        "attention_mask": tokens.attention_mask,
        "token_type_ids": tokens.token_type_ids,
    }
    # No gradient flows back to the integer token IDs, so the input-conversion
    # callback simply returns an empty list.
    return ArgsKwargs(args=(), kwargs=kwargs), lambda dX: []

def backprop(d_tokvecs: List[Floats2d]) -> ArgsKwargs:
    # Restore entries for bos and eos markers.
    row = model.ops.alloc2f(1, d_tokvecs[0].shape[1])
    d_tokvecs = [model.ops.xp.vstack((row, arr, row)) for arr in d_tokvecs]
    return ArgsKwargs(
        args=(torch_tokvecs,),
        kwargs={"grad_tensors": xp2torch(model.ops.pad(d_tokvecs))},
    )

def _convert_transformer_inputs(model, tokens: BatchEncoding, is_train):
    # Adapter for the PyTorchWrapper. See https://thinc.ai/docs/usage-frameworks
    kwargs = {
        "input_ids": tokens["input_ids"],
        "attention_mask": tokens["attention_mask"],
    }
    if "token_type_ids" in tokens:
        kwargs["token_type_ids"] = tokens["token_type_ids"]
    return ArgsKwargs(args=(), kwargs=kwargs), lambda dX: []

def _convert_transformer_inputs(model, wps: WordpieceBatch, is_train):
    # Adapter for the PyTorchWrapper. See https://thinc.ai/docs/usage-frameworks
    kwargs = {
        "input_ids": xp2torch(wps.input_ids),
        "attention_mask": xp2torch(wps.attention_mask),
    }
    if wps.token_type_ids is not None:
        kwargs["token_type_ids"] = xp2torch(wps.token_type_ids)
    return ArgsKwargs(args=(), kwargs=kwargs), lambda dX: []

def _convert_transformer_inputs(model, wps: WordpieceBatch, is_train):
    # Adapter for the HFWrapper. See https://thinc.ai/docs/usage-frameworks
    hf_device = model.shims[0]._hfmodel.transformer.device
    kwargs = {
        # Note: remove conversion to long when PyTorch >= 1.8.0.
        "input_ids": xp2torch(wps.input_ids).long().to(device=hf_device),
        "attention_mask": xp2torch(wps.attention_mask).to(device=hf_device),
    }
    if wps.token_type_ids is not None:
        kwargs["token_type_ids"] = (
            xp2torch(wps.token_type_ids).long().to(device=hf_device)
        )
    return ArgsKwargs(args=(), kwargs=kwargs), lambda dX: []

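# For orientation (a sketch, not taken from the snippets above): an input
# converter of this shape is handed to Thinc's PyTorchWrapper through its
# convert_inputs argument, usually together with a matching convert_outputs.
# "bert-base-uncased" is a placeholder model name, purely illustrative.
from thinc.api import PyTorchWrapper
from transformers import AutoModel

transformer = AutoModel.from_pretrained("bert-base-uncased")
model = PyTorchWrapper(
    transformer,
    convert_inputs=_convert_transformer_inputs,
    convert_outputs=_convert_transformer_outputs,  # sketched after the backprop snippets below
)
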
def backprop(d_tensors: List[torch.Tensor]) -> ArgsKwargs:
    return ArgsKwargs(args=(tensors,), kwargs={"grad_tensors": d_tensors})

def backprop(d_model_output: ModelOutput) -> ArgsKwargs:
    return ArgsKwargs(
        args=(model_output.last_hidden_state,),
        kwargs={"grad_tensors": d_model_output.values()},
    )

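# A sketch of a complete output converter into which a backprop closure like the
# ones above fits, assuming Thinc's convert_outputs calling convention: it
# receives (inputs, torch_outputs) and returns the converted output plus a
# callback that turns the downstream gradient back into an ArgsKwargs, which the
# PyTorch shim then hands to torch.autograd.backward. The function name and the
# use of last_hidden_state mirror the snippets above but are illustrative.
from thinc.api import ArgsKwargs, torch2xp, xp2torch

def _convert_transformer_outputs(model, inputs_outputs, is_train):
    _, model_output = inputs_outputs
    # Expose the last hidden state to the rest of the Thinc network as xp arrays.
    tokvecs = torch2xp(model_output.last_hidden_state)

    def backprop(d_tokvecs) -> ArgsKwargs:
        # grad_tensors pairs the forward tensor with its incoming gradient.
        return ArgsKwargs(
            args=(model_output.last_hidden_state,),
            kwargs={"grad_tensors": xp2torch(d_tokvecs)},
        )

    return tokvecs, backprop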