Beispiel #1
0
def test_frontend_error():
    """Check that a failing frontend loader raises FrontendLoadingError.

    A loader that raises ``ValueError`` is registered in ``_frontends``
    under a temporary name, and ``activate_frontend`` is expected to raise
    ``FrontendLoadingError`` when asked to activate that name.
    """
    from myia.frontends import _frontends
    name = '__testing_name000_'

    def f():
        raise ValueError('test')

    assert name not in _frontends
    _frontends[name] = f

    try:
        with pytest.raises(FrontendLoadingError):
            activate_frontend(name)
    finally:
        # Always remove the temporary entry, even when the expectation
        # above fails, so the fake frontend never leaks into other tests.
        del _frontends[name]
Beispiel #2
0
import numpy as np
import pytest

from myia import grad, myia, value_and_grad
from myia.api import to_device
from myia.frontends import activate_frontend
from myia.frontends.pytorch import tensor_pytorch_aliasable
from myia.utils import MyiaInputTypeError, MyiaTypeError, MyiaValueError

from ..common import MA
from ..multitest import eqtest, run

# Skip every test in this module when torch cannot be imported.
torch = pytest.importorskip("torch")
nn = torch.nn

# Activate the 'pytorch' frontend once, at module import time.
activate_frontend('pytorch')


@eqtest.register
def eqtest(t1: torch.Tensor, t2, rtol=1e-5, atol=1e-8, **kwargs):
    """Return whether two torch tensors are approximately equal.

    Delegates to ``torch.allclose`` with the given relative and absolute
    tolerances and ``equal_nan=True``. Extra keyword arguments are accepted
    but not used.
    """
    tolerances = {"rtol": rtol, "atol": atol}
    return torch.allclose(t1, t2, equal_nan=True, **tolerances)


def test_pytorch_dtype_to_type():
    """A value that is not a torch dtype must be rejected with TypeError."""
    from myia.frontends.pytorch import pytorch_dtype_to_type as convert
    bogus_dtype = "fake_pytorch_type"
    with pytest.raises(TypeError):
        convert(bogus_dtype)


# Uncomment this line to print values at specific precision
# torch.set_printoptions(precision=8)
Beispiel #3
0
def test_load_frontend_unknown():
    """Activating a made-up frontend name must raise UnknownFrontend."""
    missing_frontend = '_made_up_frontend'
    with pytest.raises(UnknownFrontend):
        activate_frontend(missing_frontend)
Beispiel #4
0
from myia.abstract import (
    SHAPE,
    TYPE,
    AbstractArray,
    AbstractFunction,
    TypedPrimitive,
)
from myia.frontends import activate_frontend
from myia.operations import partial, primitives as P
from myia.pipeline import scalar_parse, scalar_pipeline
from myia.validate import ValidationError, validate, validate_abstract

from .common import Point, f64, i64, to_abstract_test

# Activate the "pytorch" frontend, then load its abstract-types module.
# The whole test module is skipped if that module cannot be imported.
activate_frontend("pytorch")
pytorch_abstract_types = pytest.importorskip(
    "myia_frontend_pytorch.pytorch_abstract_types")
PyTorchTensor = pytorch_abstract_types.PyTorchTensor

# Point built from two i64 arguments, shared by the tests in this module.
Point_a = Point(i64, i64)

# Scalar pipeline limited to the named steps, ending with "validate".
pip = scalar_pipeline.select("resources", "parse", "infer", "specialize",
                             "validate")

# Same steps as `pip`, with "simplify_types" inserted before "validate".
pip_ec = scalar_pipeline.select("resources", "parse", "infer", "specialize",
                                "simplify_types", "validate")


def run(pip, fn, types):
    res = pip.run(input=fn, argspec=[to_abstract_test(t) for t in types])