Example #1
File: test_ops.py Project: admariner/thinc
def test_use_ops():
    class_ops = get_current_ops()
    with use_ops("numpy"):
        new_ops = get_current_ops()
        assert new_ops.name == "numpy"
    with use_ops("cupy"):
        new_ops = get_current_ops()
        assert new_ops.name == "cupy"
    new_ops = get_current_ops()
    assert class_ops.name == new_ops.name
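
As a quick orientation, a minimal sketch (assuming thinc is installed) of the backend-switching behaviour this test checks; use_ops only changes the active backend inside the with block:

from thinc.api import get_current_ops, use_ops

ops = get_current_ops()            # the backend that is currently active
print(ops.name)                    # usually "numpy" unless a GPU backend was required
with use_ops("numpy"):             # temporarily switch backends inside the block
    print(get_current_ops().name)  # -> "numpy"
print(get_current_ops().name)      # the original backend is restored
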
Example #2
def test_prefer_gpu():
    current_ops = get_current_ops()
    try:
        import cupy  # noqa: F401

        prefer_gpu()
        assert isinstance(get_current_ops(), CupyOps)
    except ImportError:
        assert not prefer_gpu()
    set_current_ops(current_ops)
Example #3
def test_require_gpu():
    current_ops = get_current_ops()
    try:
        import cupy  # noqa: F401

        require_gpu()
        assert isinstance(get_current_ops(), CupyOps)
    except ImportError:
        with pytest.raises(ValueError):
            require_gpu()
    set_current_ops(current_ops)
Example #4
def test_require_cpu():
    require_cpu()
    assert isinstance(get_current_ops(), NumpyOps)
    try:
        import cupy  # noqa: F401

        require_gpu()
        assert isinstance(get_current_ops(), CupyOps)
    except ImportError:
        pass
    require_cpu()
    assert isinstance(get_current_ops(), NumpyOps)
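
A minimal sketch of the start-up pattern these GPU/CPU tests exercise, assuming thinc is installed; prefer_gpu() returns False when no GPU backend is available, so the code falls back to the CPU:

from thinc.api import get_current_ops, prefer_gpu, require_cpu

if not prefer_gpu():   # returns False when cupy / a GPU is unavailable
    require_cpu()      # fall back to (or stay on) NumpyOps
print(get_current_ops().name)
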
Example #5
def test_tok2vec_listener(with_vectors):
    orig_config = Config().from_str(cfg_string)
    orig_config["components"]["tok2vec"]["model"]["embed"][
        "include_static_vectors"] = with_vectors
    nlp = util.load_model_from_config(orig_config,
                                      auto_fill=True,
                                      validate=True)

    if with_vectors:
        ops = get_current_ops()
        vectors = [
            ("apple", ops.asarray([1, 2, 3])),
            ("orange", ops.asarray([-1, -2, -3])),
            ("and", ops.asarray([-1, -1, -1])),
            ("juice", ops.asarray([5, 5, 10])),
            ("pie", ops.asarray([7, 6.3, 8.9])),
        ]
        add_vecs_to_vocab(nlp.vocab, vectors)

    assert nlp.pipe_names == ["tok2vec", "tagger"]
    tagger = nlp.get_pipe("tagger")
    tok2vec = nlp.get_pipe("tok2vec")
    tagger_tok2vec = tagger.model.get_ref("tok2vec")
    assert isinstance(tok2vec, Tok2Vec)
    assert isinstance(tagger_tok2vec, Tok2VecListener)
    train_examples = []
    for t in TRAIN_DATA:
        train_examples.append(Example.from_dict(nlp.make_doc(t[0]), t[1]))
        for tag in t[1]["tags"]:
            tagger.add_label(tag)

    # Check that the Tok2Vec component finds its listeners
    assert tok2vec.listeners == []
    optimizer = nlp.initialize(lambda: train_examples)
    assert tok2vec.listeners == [tagger_tok2vec]

    for i in range(5):
        losses = {}
        nlp.update(train_examples, sgd=optimizer, losses=losses)

    doc = nlp("Running the pipeline as a whole.")
    doc_tensor = tagger_tok2vec.predict([doc])[0]
    ops = get_current_ops()
    assert_array_equal(ops.to_numpy(doc.tensor), ops.to_numpy(doc_tensor))

    # test with empty doc
    doc = nlp("")

    # TODO: should this warn or error?
    nlp.select_pipes(disable="tok2vec")
    assert nlp.pipe_names == ["tagger"]
    nlp("Running the pipeline with the Tok2Vec component disabled.")
Example #6
def test_issue5082():
    # Ensure the 'merge_entities' pipeline does something sensible for the vectors of the merged tokens
    nlp = English()
    vocab = nlp.vocab
    array1 = numpy.asarray([0.1, 0.5, 0.8], dtype=numpy.float32)
    array2 = numpy.asarray([-0.2, -0.6, -0.9], dtype=numpy.float32)
    array3 = numpy.asarray([0.3, -0.1, 0.7], dtype=numpy.float32)
    array4 = numpy.asarray([0.5, 0, 0.3], dtype=numpy.float32)
    array34 = numpy.asarray([0.4, -0.05, 0.5], dtype=numpy.float32)
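    # array34 is the expected vector of the merged "David Bowie" token: the mean of array3 and array4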
    vocab.set_vector("I", array1)
    vocab.set_vector("like", array2)
    vocab.set_vector("David", array3)
    vocab.set_vector("Bowie", array4)
    text = "I like David Bowie"
    patterns = [
        {"label": "PERSON", "pattern": [{"LOWER": "david"}, {"LOWER": "bowie"}]}
    ]
    ruler = nlp.add_pipe("entity_ruler")
    ruler.add_patterns(patterns)
    parsed_vectors_1 = [t.vector for t in nlp(text)]
    assert len(parsed_vectors_1) == 4
    ops = get_current_ops()
    numpy.testing.assert_array_equal(ops.to_numpy(parsed_vectors_1[0]), array1)
    numpy.testing.assert_array_equal(ops.to_numpy(parsed_vectors_1[1]), array2)
    numpy.testing.assert_array_equal(ops.to_numpy(parsed_vectors_1[2]), array3)
    numpy.testing.assert_array_equal(ops.to_numpy(parsed_vectors_1[3]), array4)
    nlp.add_pipe("merge_entities")
    parsed_vectors_2 = [t.vector for t in nlp(text)]
    assert len(parsed_vectors_2) == 3
    numpy.testing.assert_array_equal(ops.to_numpy(parsed_vectors_2[0]), array1)
    numpy.testing.assert_array_equal(ops.to_numpy(parsed_vectors_2[1]), array2)
    numpy.testing.assert_array_equal(ops.to_numpy(parsed_vectors_2[2]), array34)
Example #7
def test_serialize_transformer_data():
    data = {"x": TransformerData.empty()}
    bytes_data = srsly.msgpack_dumps(data)
    new_data = srsly.msgpack_loads(bytes_data)
    assert isinstance(new_data["x"], TransformerData)

    nlp = Language()
    nlp.add_pipe(
        "transformer",
        config={
            "model": {
                "name": "distilbert-base-uncased",
                "transformer_config": {
                    "output_attentions": True
                },
            }
        },
    )
    nlp.initialize()
    doc = nlp("This is a test.")
    b = doc.to_bytes()
    reloaded_doc = Doc(nlp.vocab)
    reloaded_doc.from_bytes(b)
    assert_docs_equal(doc, reloaded_doc)
    ops = get_current_ops()
    for key in doc._.trf_data.model_output:
        assert_array_equal(
            ops.to_numpy(ops.asarray(doc._.trf_data.model_output[key])),
            ops.to_numpy(ops.asarray(
                reloaded_doc._.trf_data.model_output[key])),
        )
Example #8
File: test_ops.py Project: admariner/thinc
def test_ngrams():
    ops = get_current_ops()
    arr1 = numpy.asarray([1, 2, 3, 4, 5], dtype=numpy.uint64)
    for n in range(1, 10):
        assert len(ops.ngrams(n, arr1)) == max(0, arr1.shape[0] - (n - 1))
    assert len(ops.ngrams(-1, arr1)) == 0
    assert len(ops.ngrams(arr1.shape[0] + 1, arr1)) == 0
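
For reference, a small standalone sketch of the behaviour tested above, assuming thinc is installed; ops.ngrams returns one value per consecutive n-gram of a uint64 key array:

import numpy
from thinc.api import get_current_ops

ops = get_current_ops()
keys = numpy.asarray([1, 2, 3, 4, 5], dtype=numpy.uint64)
bigrams = ops.ngrams(2, keys)
print(len(bigrams))  # 4 - one value per consecutive pair of keys
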
Example #9
def test_tok2vec_listener():
    orig_config = Config().from_str(cfg_string)
    nlp = util.load_model_from_config(orig_config, auto_fill=True, validate=True)
    assert nlp.pipe_names == ["tok2vec", "tagger"]
    tagger = nlp.get_pipe("tagger")
    tok2vec = nlp.get_pipe("tok2vec")
    tagger_tok2vec = tagger.model.get_ref("tok2vec")
    assert isinstance(tok2vec, Tok2Vec)
    assert isinstance(tagger_tok2vec, Tok2VecListener)
    train_examples = []
    for t in TRAIN_DATA:
        train_examples.append(Example.from_dict(nlp.make_doc(t[0]), t[1]))
        for tag in t[1]["tags"]:
            tagger.add_label(tag)

    # Check that the Tok2Vec component finds its listeners
    assert tok2vec.listeners == []
    optimizer = nlp.initialize(lambda: train_examples)
    assert tok2vec.listeners == [tagger_tok2vec]

    for i in range(5):
        losses = {}
        nlp.update(train_examples, sgd=optimizer, losses=losses)

    doc = nlp("Running the pipeline as a whole.")
    doc_tensor = tagger_tok2vec.predict([doc])[0]
    ops = get_current_ops()
    assert_array_equal(ops.to_numpy(doc.tensor), ops.to_numpy(doc_tensor))

    # TODO: should this warn or error?
    nlp.select_pipes(disable="tok2vec")
    assert nlp.pipe_names == ["tagger"]
    nlp("Running the pipeline with the Tok2Vec component disabled.")
Example #10
def huggingface_from_pretrained(source: Union[Path, str], tok_config: Dict,
                                trf_config: Dict) -> HFObjects:
    """Create a Huggingface transformer model from pretrained weights. Will
    download the model if it is not already downloaded.

    source (Union[str, Path]): The name of the model or a path to it, such as
        'bert-base-cased'.
    tok_config (dict): Settings to pass to the tokenizer.
    trf_config (dict): Settings to pass to the transformer.
    """
    if hasattr(source, "absolute"):
        str_path = str(source.absolute())
    else:
        str_path = source
    tokenizer = AutoTokenizer.from_pretrained(str_path, **tok_config)
    vocab_file_contents = None
    if hasattr(tokenizer, "vocab_file"):
        with open(tokenizer.vocab_file, "rb") as fileh:
            vocab_file_contents = fileh.read()
    trf_config["return_dict"] = True
    config = AutoConfig.from_pretrained(str_path, **trf_config)
    transformer = AutoModel.from_pretrained(str_path, config=config)
    ops = get_current_ops()
    if isinstance(ops, CupyOps):
        transformer.cuda()
    return HFObjects(tokenizer, transformer, vocab_file_contents)
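
A hypothetical usage sketch for the helper above; the model name "bert-base-cased" is only an illustration, and it assumes HFObjects exposes tokenizer and transformer attributes:

hf = huggingface_from_pretrained("bert-base-cased", tok_config={}, trf_config={})
print(type(hf.tokenizer).__name__, type(hf.transformer).__name__)
encoded = hf.tokenizer("Hello world", return_tensors="pt")  # tokenize as usual
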
Example #11
def huggingface_from_pretrained(source, config):
    tokenizer = AutoTokenizer.from_pretrained(source, **config)
    transformer = AutoModel.from_pretrained(source)
    ops = get_current_ops()
    if isinstance(ops, CupyOps):
        transformer.cuda()
    return tokenizer, transformer
Example #12
def test_issue5551(textcat_config):
    """Test that after fixing the random seed, the results of the pipeline are truly identical"""
    component = "textcat"

    pipe_cfg = Config().from_str(textcat_config)
    results = []
    for i in range(3):
        fix_random_seed(0)
        nlp = English()
        text = "Once hot, form ping-pong-ball-sized balls of the mixture, each weighing roughly 25 g."
        annots = {"cats": {"Labe1": 1.0, "Label2": 0.0, "Label3": 0.0}}
        pipe = nlp.add_pipe(component, config=pipe_cfg, last=True)
        for label in set(annots["cats"]):
            pipe.add_label(label)
        # Train
        nlp.initialize()
        doc = nlp.make_doc(text)
        nlp.update([Example.from_dict(doc, annots)])
        # Store the result of each iteration
        result = pipe.model.predict([doc])
        results.append(result[0])
    # All results should be the same because of the fixed seed
    assert len(results) == 3
    ops = get_current_ops()
    assert_almost_equal(ops.to_numpy(results[0]), ops.to_numpy(results[1]), decimal=5)
    assert_almost_equal(ops.to_numpy(results[0]), ops.to_numpy(results[2]), decimal=5)
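
The reproducibility pattern this test relies on, as a minimal sketch assuming spaCy and thinc are installed; seeding before the pipeline is built makes the randomly initialized weights, and therefore the predictions, repeatable:

import spacy
from thinc.api import fix_random_seed

fix_random_seed(0)            # seed random/numpy/cupy before any weights are created
nlp = spacy.blank("en")
textcat = nlp.add_pipe("textcat")
textcat.add_label("Label1")
textcat.add_label("Label2")
nlp.initialize()              # weight initialization is now deterministic
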
Example #13
def huggingface_from_pretrained(source: Union[Path, str], config: Dict):
    """Create a Huggingface transformer model from pretrained weights. Will
    download the model if it is not already downloaded.

    source (Union[str, Path]): The name of the model or a path to it, such as
        'bert-base-cased'.
    config (dict): Settings to pass to the tokenizer.
    """
    warnings.warn(
        "spacy_transformers.util.huggingface_from_pretrained has been moved to "
        "spacy_transformers.layers.transformer_model.huggingface_from_pretrained "
        "with an updated API:\n"
        "huggingface_from_pretrained(source, tok_config, trf_config) -> HFObjects",
        DeprecationWarning,
    )
    if hasattr(source, "absolute"):
        str_path = str(source.absolute())
    else:
        str_path = source
    tokenizer = AutoTokenizer.from_pretrained(str_path, **config)
    transformer = AutoModel.from_pretrained(str_path)
    ops = get_current_ops()
    if isinstance(ops, CupyOps):
        transformer.cuda()
    return tokenizer, transformer
Example #14
def Y(answer: int, n_classes: int) -> Array2d:
    ops: Ops = get_current_ops()
    return cast(
        Array2d,
        to_categorical(cast(IntsXd, ops.asarray([answer])),
                       n_classes=n_classes),
    )
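
A small sketch of what this fixture produces, assuming thinc's to_categorical as used elsewhere in these examples; the answer index becomes a one-hot row with n_classes columns:

from thinc.api import get_current_ops, to_categorical

ops = get_current_ops()
onehot = to_categorical(ops.asarray1i([2]), n_classes=4)
print(onehot.shape)  # (1, 4)
print(onehot)        # [[0. 0. 1. 0.]]
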
Example #15
File: spancat.py Project: bpben/spaCy
    def ngram_suggester(docs: Iterable[Doc], *, ops: Optional[Ops] = None) -> Ragged:
        if ops is None:
            ops = get_current_ops()
        spans = []
        lengths = []
        for doc in docs:
            starts = ops.xp.arange(len(doc), dtype="i")
            starts = starts.reshape((-1, 1))
            length = 0
            for size in sizes:
                if size <= len(doc):
                    starts_size = starts[: len(doc) - (size - 1)]
                    spans.append(ops.xp.hstack((starts_size, starts_size + size)))
                    length += spans[-1].shape[0]
                if spans:
                    assert spans[-1].ndim == 2, spans[-1].shape
            lengths.append(length)
        lengths_array = cast(Ints1d, ops.asarray(lengths, dtype="i"))
        if len(spans) > 0:
            output = Ragged(ops.xp.vstack(spans), lengths_array)
        else:
            output = Ragged(ops.xp.zeros((0, 0), dtype="i"), lengths_array)

        assert output.dataXd.ndim == 2
        return output
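
A hypothetical call of the suggester defined above; it is a closure over sizes (assumed here to be something like [1, 2, 3]), and the Ragged result holds all candidate (start, end) token offsets plus a per-doc length array:

import spacy
from thinc.api import get_current_ops

nlp = spacy.blank("en")
docs = [nlp("a short test"), nlp("another doc")]
suggestions = ngram_suggester(docs, ops=get_current_ops())
print(suggestions.lengths)     # number of suggested spans per doc
print(suggestions.dataXd[:5])  # first few (start, end) token offsets
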
Example #16
def test_language_pipe_error_handler_custom(en_vocab, n_process):
    """Test the error handling of a custom component that has no pipe method"""
    Language.component("my_evil_component", func=evil_component)
    ops = get_current_ops()
    if isinstance(ops, NumpyOps) or n_process < 2:
        nlp = English()
        nlp.add_pipe("my_evil_component")
        texts = ["TEXT 111", "TEXT 222", "TEXT 333", "TEXT 342", "TEXT 666"]
        with pytest.raises(ValueError):
            # the evil custom component throws an error
            list(nlp.pipe(texts))

        nlp.set_error_handler(warn_error)
        logger = logging.getLogger("spacy")
        with mock.patch.object(logger, "warning") as mock_warning:
            # the errors by the evil custom component raise a warning for each
            # bad doc
            docs = list(nlp.pipe(texts, n_process=n_process))
            # HACK/TODO? the warnings in child processes don't seem to be
            # detected by the mock logger
            if n_process == 1:
                mock_warning.assert_called()
                assert mock_warning.call_count == 2
                assert len(docs) + mock_warning.call_count == len(texts)
            assert [doc.text for doc in docs] == ["TEXT 111", "TEXT 333", "TEXT 666"]
Example #17
def test_language_pipe_error_handler_input_as_tuples(en_vocab, n_process):
    """Test the error handling of nlp.pipe with input as tuples"""
    Language.component("my_evil_component", func=evil_component)
    ops = get_current_ops()
    if isinstance(ops, NumpyOps) or n_process < 2:
        nlp = English()
        nlp.add_pipe("my_evil_component")
        texts = [
            ("TEXT 111", 111),
            ("TEXT 222", 222),
            ("TEXT 333", 333),
            ("TEXT 342", 342),
            ("TEXT 666", 666),
        ]
        with pytest.raises(ValueError):
            list(nlp.pipe(texts, as_tuples=True))
        nlp.set_error_handler(warn_error)
        logger = logging.getLogger("spacy")
        with mock.patch.object(logger, "warning") as mock_warning:
            tuples = list(nlp.pipe(texts, as_tuples=True, n_process=n_process))
            # HACK/TODO? the warnings in child processes don't seem to be
            # detected by the mock logger
            if n_process == 1:
                mock_warning.assert_called()
                assert mock_warning.call_count == 2
                assert len(tuples) + mock_warning.call_count == len(texts)
            assert (tuples[0][0].text, tuples[0][1]) == ("TEXT 111", 111)
            assert (tuples[1][0].text, tuples[1][1]) == ("TEXT 333", 333)
            assert (tuples[2][0].text, tuples[2][1]) == ("TEXT 666", 666)
Example #18
def test_language_pipe(nlp2, n_process, texts):
    ops = get_current_ops()
    if isinstance(ops, NumpyOps) or n_process < 2:
        texts = texts * 10
        expecteds = [nlp2(text) for text in texts]
        docs = nlp2.pipe(texts, n_process=n_process, batch_size=2)

        for doc, expected_doc in zip(docs, expecteds):
            assert_docs_equal(doc, expected_doc)
Example #19
def test_multiprocessing(simple_nlp, texts):
    ops = get_current_ops()
    if isinstance(ops, NumpyOps):
        texts = texts * 3
        expecteds = [simple_nlp(text) for text in texts]
        docs = simple_nlp.pipe(texts, n_process=2, batch_size=2)

        for doc, expected_doc in zip(docs, expecteds):
            assert_docs_equal(doc, expected_doc)
Example #20
def test_pickle_vocab(strings, lex_attr):
    vocab = Vocab(strings=strings)
    ops = get_current_ops()
    vectors = Vectors(data=ops.xp.zeros((10, 10)), mode="floret", hash_count=1)
    vocab.vectors = vectors
    vocab[strings[0]].norm_ = lex_attr
    vocab_pickled = pickle.dumps(vocab)
    vocab_unpickled = pickle.loads(vocab_pickled)
    assert vocab.to_bytes() == vocab_unpickled.to_bytes()
    assert vocab_unpickled.vectors.mode == "floret"
Example #21
def test_pass_doc_to_pipeline(nlp, n_process):
    texts = ["cats", "dogs", "guinea pigs"]
    docs = [nlp.make_doc(text) for text in texts]
    assert not any(len(doc.cats) for doc in docs)
    doc = nlp(docs[0])
    assert doc.text == texts[0]
    assert len(doc.cats) > 0
    if isinstance(get_current_ops(), NumpyOps) or n_process < 2:
        docs = nlp.pipe(docs, n_process=n_process)
        assert [doc.text for doc in docs] == texts
        assert all(len(doc.cats) for doc in docs)
Example #22
def test_issue4903():
    """Ensure that this runs correctly and doesn't hang or crash on Windows /
    macOS."""
    nlp = English()
    nlp.add_pipe("sentencizer")
    nlp.add_pipe("my_pipe", after="sentencizer")
    text = ["I like bananas.", "Do you like them?", "No, I prefer wasabi."]
    if isinstance(get_current_ops(), NumpyOps):
        docs = list(nlp.pipe(text, n_process=2))
        assert docs[0].text == "I like bananas."
        assert docs[1].text == "Do you like them?"
        assert docs[2].text == "No, I prefer wasabi."
Example #23
def test_language_pipe_stream(nlp2, n_process, texts):
    ops = get_current_ops()
    if isinstance(ops, NumpyOps) or n_process < 2:
        # check if nlp.pipe can handle infinite length iterator properly.
        stream_texts = itertools.cycle(texts)
        texts0, texts1 = itertools.tee(stream_texts)
        expecteds = (nlp2(text) for text in texts0)
        docs = nlp2.pipe(texts1, n_process=n_process, batch_size=2)

        n_fetch = 20
        for doc, expected_doc in itertools.islice(zip(docs, expecteds), n_fetch):
            assert_docs_equal(doc, expected_doc)
Example #24
def test_language_pipe_error_handler_make_doc_preferred(n_process):
    """Test the error handling for make_doc"""

    ops = get_current_ops()
    if isinstance(ops, NumpyOps) or n_process < 2:
        nlp = English()
        nlp.max_length = 10
        texts = ["12345678901234567890", "12345"] * 10
        with pytest.raises(ValueError):
            list(nlp.pipe(texts, n_process=n_process))
        nlp.default_error_handler = ignore_error
        docs = list(nlp.pipe(texts, n_process=n_process))
        assert len(docs) == 0
Example #25
def test_multibatch():
    fix_random_seed(0)
    ops = get_current_ops()
    arr1 = numpy.asarray([1, 2, 3, 4])
    arr2 = numpy.asarray([5, 6, 7, 8])
    batches = list(ops.multibatch(2, arr1, arr2))
    assert numpy.concatenate(batches).tolist() == [[1, 2], [5, 6], [3, 4], [7, 8]]
    batches = list(ops.multibatch(2, arr1, arr2, shuffle=True))
    assert len(batches) == 2
    assert len(batches[0]) == 2
    assert len(batches[1]) == 2
    batches = list(ops.multibatch(2, [1, 2, 3, 4], [5, 6, 7, 8]))
    assert batches == [[[1, 2], [5, 6]], [[3, 4], [7, 8]]]
    with pytest.raises(ValueError):
        ops.multibatch(10, (i for i in range(100)), (i for i in range(100)))
    with pytest.raises(ValueError):
        ops.multibatch(10, arr1, (i for i in range(100)), arr2)
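
For reference, a minimal sketch of ops.multibatch outside the test, assuming thinc is installed; it slices several equally long sequences into aligned batches, which is also why the generators of unknown length above raise ValueError:

from thinc.api import get_current_ops

ops = get_current_ops()
X = [0, 1, 2, 3, 4, 5]
Y = ["a", "b", "c", "d", "e", "f"]
for batch_x, batch_y in ops.multibatch(3, X, Y):
    print(batch_x, batch_y)
# [0, 1, 2] ['a', 'b', 'c']
# [3, 4, 5] ['d', 'e', 'f']
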
Example #26
File: subclasses.py Project: MalteHB/DaCy
def huggingface_classification_from_pretrained(source: Union[Path, str], config: Dict):
    """Create a Huggingface transformer model from pretrained weights. Will
    download the model if it is not already downloaded.
    source (Union[str, Path]): The name of the model or a path to it, such as
        'bert-base-cased'.
    config (dict): Settings to pass to the tokenizer.
    """
    if hasattr(source, "absolute"):
        str_path = str(source.absolute())
    else:
        str_path = source
    tokenizer = AutoTokenizer.from_pretrained(str_path, **config)
    transformer = AutoModelForSequenceClassification.from_pretrained(str_path)
    ops = get_current_ops()
    if isinstance(ops, CupyOps):
        transformer.cuda()
    return tokenizer, transformer
Example #27
def test_language_pipe_error_handler_make_doc_actual(n_process):
    """Test the error handling for make_doc"""
    # TODO: fix so that the following test is the actual behavior

    ops = get_current_ops()
    if isinstance(ops, NumpyOps) or n_process < 2:
        nlp = English()
        nlp.max_length = 10
        texts = ["12345678901234567890", "12345"] * 10
        with pytest.raises(ValueError):
            list(nlp.pipe(texts, n_process=n_process))
        nlp.default_error_handler = ignore_error
        if n_process == 1:
            with pytest.raises(ValueError):
                list(nlp.pipe(texts, n_process=n_process))
        else:
            docs = list(nlp.pipe(texts, n_process=n_process))
            assert len(docs) == 0
Example #28
def test_span_with_vectors(doc):
    ops = get_current_ops()
    prev_vectors = doc.vocab.vectors
    vectors = [
        ("apple", ops.asarray([1, 2, 3])),
        ("orange", ops.asarray([-1, -2, -3])),
        ("And", ops.asarray([-1, -1, -1])),
        ("juice", ops.asarray([5, 5, 10])),
        ("pie", ops.asarray([7, 6.3, 8.9])),
    ]
    add_vecs_to_vocab(doc.vocab, vectors)
    # 0-length span
    assert_array_equal(ops.to_numpy(doc[0:0].vector), numpy.zeros((3, )))
    # longer span with no vector
    assert_array_equal(ops.to_numpy(doc[0:4].vector), numpy.zeros((3, )))
    # single-token span with vector
    assert_array_equal(ops.to_numpy(doc[10:11].vector), [-1, -1, -1])
    doc.vocab.vectors = prev_vectors
Example #29
def test_language_pipe_error_handler_pipe(en_vocab, n_process):
    """Test the error handling of a component's pipe method"""
    Language.component("my_perhaps_sentences", func=perhaps_set_sentences)
    Language.component("assert_sents_error", func=assert_sents_error)
    ops = get_current_ops()
    if isinstance(ops, NumpyOps) or n_process < 2:
        texts = [f"{str(i)} is enough. Done" for i in range(100)]
        nlp = English()
        nlp.add_pipe("my_perhaps_sentences")
        nlp.add_pipe("assert_sents_error")
        nlp.initialize()
        with pytest.raises(ValueError):
            # assert_sents_error requires sentence boundaries, will throw an error otherwise
            docs = list(nlp.pipe(texts, n_process=n_process, batch_size=10))
        nlp.set_error_handler(ignore_error)
        docs = list(nlp.pipe(texts, n_process=n_process, batch_size=10))
        # we lose/ignore the 11 failing docs (4 and 40-49)
        assert len(docs) == 89
Example #30
def test_tensorflow_wrapper_serialize_model_subclass(
    X, Y, input_size, n_classes, answer
):
    import tensorflow as tf

    input_shape = (1, input_size)
    ops = get_current_ops()

    @keras_subclass(
        "foo.v1",
        X=ops.alloc2f(*input_shape),
        Y=to_categorical(ops.asarray1i([1]), n_classes=n_classes),
        input_shape=input_shape,
    )
    class CustomKerasModel(tf.keras.Model):
        def __init__(self, **kwargs):
            super(CustomKerasModel, self).__init__(**kwargs)
            self.in_dense = tf.keras.layers.Dense(
                12, name="in_dense", input_shape=input_shape
            )
            self.out_dense = tf.keras.layers.Dense(
                n_classes, name="out_dense", activation="softmax"
            )

        def call(self, inputs) -> tf.Tensor:
            x = self.in_dense(inputs)
            return self.out_dense(x)

    model = TensorFlowWrapper(CustomKerasModel())
    # Train the model to predict the right single answer
    optimizer = Adam()
    for i in range(50):
        guesses, backprop = model(X, is_train=True)
        d_guesses = (guesses - Y) / guesses.shape[0]
        backprop(d_guesses)
        model.finish_update(optimizer)
    predicted = model.predict(X).argmax()
    assert predicted == answer

    # Save then Load the model from bytes
    model.from_bytes(model.to_bytes())

    # The from_bytes model gets the same answer
    assert model.predict(X).argmax() == answer