Exemplo n.º 1
def test_dispatcher():
    """A Dispatcher routes ints to ``inc`` and floats to ``dec``."""
    dispatcher = Dispatcher('f')

    def inc(x):
        return x + 1

    def dec(x):
        return x - 1

    for signature, implementation in (((int,), inc), ((float,), dec)):
        dispatcher.add(signature, implementation)

    assert dispatcher.resolve((int,)) == inc

    assert dispatcher(1) == 2
    assert dispatcher(1.0) == 0.0
def test_dispatcher():
    """``resolve`` and ``dispatch`` both find the registered implementation."""
    f = Dispatcher('f')
    registrations = [((int,), inc), ((float,), dec)]
    for types, impl in registrations:
        f.add(types, impl)

    # Silence the DeprecationWarning that ``resolve`` may emit while still
    # exercising it.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", DeprecationWarning)
        resolved = f.resolve((int,))
        assert resolved == inc
    assert f.dispatch(int) is inc

    for arg, expected in ((1, 2), (1.0, 0.0)):
        assert f(arg) == expected
def test_vararg_dispatch_ambiguity_in_variadic():
    """Overlapping variadic signatures are reported as ambiguous."""
    f = Dispatcher('f')

    def first(a, b, *args):
        return 1

    def second(a, b, *args):
        return 2

    # Same registration order as the decorator form: (float, [object]) first.
    f.register(float, [object])(first)
    f.register(object, [float])(second)

    found = ambiguities(f.funcs)
    assert found
Exemplo n.º 4
def test_dispatcher_as_decorator():
    """``Dispatcher.register`` works as a decorator, one implementation per type."""
    f = Dispatcher('f')

    # Both implementations are named ``inc``. Only the registration
    # decorators tell them apart; without them nothing is registered and
    # the second definition silently shadows the first.
    @f.register(int)
    def inc(x):
        return x + 1

    @f.register(float)
    def inc(x):
        return x - 1

    assert f(1) == 2
    assert f(1.0) == 0.0
Exemplo n.º 5
def test_function_annotation_register():
    """``register()`` with no types takes the signature from annotations."""
    f = Dispatcher('f')

    # An empty ``register()`` reads the parameter annotations (int / float)
    # as the dispatch signature.
    @f.register()
    def inc(x: int):
        return x + 1

    @f.register()
    def inc(x: float):
        return x - 1

    assert f(1) == 2
    assert f(1.0) == 0.0
def test_register_stacking():
    """Stacked ``register`` decorators register one function for several types."""
    f = Dispatcher('f')

    @f.register(list)
    @f.register(tuple)
    def rev(x):
        return x[::-1]

    assert f((1, 2, 3)) == (3, 2, 1)
    assert f([1, 2, 3]) == [3, 2, 1]

    # str was not registered, so dispatch fails ...
    assert raises(NotImplementedError, lambda: f('hello'))
    # ... but the decorated function itself is still a plain callable.
    assert rev('hello') == 'olleh'
Exemplo n.º 7
def test_serializable():
    """A Dispatcher survives a pickle round-trip with its registrations intact."""
    import pickle

    f = Dispatcher('f')
    for signature, impl in (((int,), inc), ((float,), dec), ((object,), identity)):
        f.add(signature, impl)

    payload = pickle.dumps(f)
    assert isinstance(payload, (str, bytes))

    restored = pickle.loads(payload)

    for arg, expected in ((1, 2), (1.0, 0.0), ('hello', 'hello')):
        assert restored(arg) == expected
Exemplo n.º 8
def test_on_ambiguity():
    """The ``on_ambiguity`` callback fires only once an ambiguous signature is added."""
    f = Dispatcher('f')
    fired = False

    def identity(x):
        return x

    def on_ambiguity(dispatcher, amb):
        nonlocal fired
        fired = True

    f.add((object, object), identity, on_ambiguity=on_ambiguity)
    assert not fired
    f.add((object, float), identity, on_ambiguity=on_ambiguity)
    assert not fired
    # (float, object) and (object, float) both match (float, float), and
    # neither is more specific than the other.
    f.add((float, object), identity, on_ambiguity=on_ambiguity)
    assert fired
def test_source():
    """``_source`` returns the source of the implementation chosen for the args."""
    def one(x, y):
        """ Docstring number one """
        return x + y

    def two(x, y):
        """ Docstring number two """
        return x - y

    f = Dispatcher('f', doc='Doc of the multimethod itself')
    f.add((int, int), one)
    f.add((float, float), two)

    for args, snippet in (((1, 1), 'x + y'), ((1.0, 1.0), 'x - y')):
        assert snippet in f._source(*args)
Exemplo n.º 10
def test_dispatch_method():
    """``dispatch`` returns the registered function, honouring subclasses."""
    f = Dispatcher('f')

    @f.register(list)
    def rev(x):
        return x[::-1]

    @f.register(int, int)
    def add(x, y):
        return x + y

    class MyList(list):
        pass

    assert f.dispatch(list) is rev
    # A subclass of list resolves to the list implementation.
    assert f.dispatch(MyList) is rev
    assert f.dispatch(int, int) is add
Exemplo n.º 11
def test_not_implemented():
    """Raising MDNotImplementedError defers to the next matching implementation."""
    f = Dispatcher('f')

    @f.register(object)
    def _1(x):
        return 'default'

    @f.register(int)
    def _2(x):
        if x % 2 == 0:
            return 'even'
        # Odd ints are not handled here; defer to the (object,) implementation.
        # Previously this raise sat after ``return`` and could never run.
        raise MDNotImplementedError()

    assert f('hello') == 'default'  # default behavior
    assert f(2) == 'even'           # specialized behavior
    assert f(3) == 'default'        # fall back to default behavior
    assert raises(NotImplementedError, lambda: f(1, 2))
Exemplo n.º 12
def test_on_ambiguity():
    """Ambiguities are reported to ``on_ambiguity`` when the dispatcher is reordered."""
    f = Dispatcher('f')

    def identity(x):
        return x

    ambiguities = [False]

    def on_ambiguity(dispatcher, amb):
        ambiguities[0] = True

    f.add((object, object), identity)
    f.add((object, float), identity)
    f.add((float, object), identity)

    # Adding signatures alone does not invoke the callback ...
    assert not ambiguities[0]
    # ... it runs when ordering is computed with it. Without this call the
    # two asserts contradict each other.
    f.reorder(on_ambiguity=on_ambiguity)
    assert ambiguities[0]
Exemplo n.º 13
def test_vararg_dispatch_multiple_implementations():
    """Several variadic implementations coexist and each is chosen correctly."""
    f = Dispatcher('f')

    @f.register(str, [float])
    def _1(a, *b):
        return 'mixed_string_floats'

    @f.register([float])
    def _2(*b):
        return 'floats'

    @f.register([str])
    def _3(*strings):
        return 'strings'

    assert f('a', 1.0, 2.0) == 'mixed_string_floats'
    assert f(1.0, 2.0, 3.14) == 'floats'
    assert f('a', 'b', 'c') == 'strings'
Exemplo n.º 14
def test_help():
    """``_help`` returns the docstring of the most specific matching implementation."""
    def one(x, y):
        """ Docstring number one """
        return x + y

    def two(x, y):
        """ Docstring number two """
        return x + y

    def three(x, y):
        """ Docstring number three """
        return x + y

    f = Dispatcher('f', doc='Doc of the multimethod itself')
    for signature, impl in (((object, object), one),
                            ((int, int), two),
                            ((float, float), three)):
        f.add(signature, impl)

    assert f._help(1, 1) == two.__doc__
    assert f._help(1.0, 2.0) == three.__doc__
Exemplo n.º 15
def test_docstring():
    """The Dispatcher ``__doc__`` includes implementation docstrings, general first."""
    def one(x, y):
        """ Docstring number one """
        return x + y

    def two(x, y):
        """ Docstring number two """
        return x + y

    def three(x, y):
        return x + y

    f = Dispatcher('f')
    f.add((object, object), one)
    f.add((int, int), two)
    f.add((float, float), three)

    assert one.__doc__.strip() in f.__doc__
    assert two.__doc__.strip() in f.__doc__
    # The (object, object) docstring must appear before the (int, int) one.
    # The right-hand operand of this comparison was missing, which left a
    # dangling line continuation.
    assert f.__doc__.find(one.__doc__.strip()) < \
        f.__doc__.find(two.__doc__.strip())
    assert 'object, object' in f.__doc__
Exemplo n.º 16
def test_vararg_dispatch_unions():
    """Variadic signatures accept union (tuple) element types."""
    f = Dispatcher('f')

    @f.register(str, [(int, float)])
    def _1(a, *b):
        return 'mixed_string_ints_floats'

    @f.register([str])
    def _2(*strings):
        return 'strings'

    @f.register([(str, int)])
    def _3(*strings_ints):
        return 'mixed_strings_ints'

    @f.register([object])
    def _4(*objects):
        return 'objects'

    assert f('a', 1.0, 7, 2.0, 11) == 'mixed_string_ints_floats'
    assert f('a', 'b', 'c') == 'strings'
    assert f('a', 1, 'b', 2) == 'mixed_strings_ints'
    assert f([], (), {}) == 'objects'
Exemplo n.º 17
def test_halt_method_resolution():
    """Ordering is deferred between ``halt_ordering()`` and ``restart_ordering()``.

    The ambiguity between (int, object) and (object, int) is reported to
    ``on_ambiguity`` exactly once, when ordering restarts.
    """
    # NOTE(review): newer multipledispatch releases deprecate these helpers in
    # favour of ``Dispatcher.reorder``; confirm against the pinned version.
    from multipledispatch.dispatcher import halt_ordering, restart_ordering

    g = [0]

    def on_ambiguity(a, b):
        g[0] += 1

    f = Dispatcher('f')

    halt_ordering()

    def func(*args):
        pass

    f.add((int, object), func)
    f.add((object, int), func)

    # Nothing is reported while ordering is halted.
    assert g == [0]

    restart_ordering(on_ambiguity=on_ambiguity)

    assert g == [1]

    assert set(f.ordering) == set([(int, object), (object, int)])
Exemplo n.º 18
def test_vararg_ordering():
    """Longer fixed prefixes win over shorter ones among variadic signatures."""
    f = Dispatcher('f')

    @f.register(str, int, [object])
    def _1(string, integer, *objects):
        return 1

    @f.register(str, [object])
    def _2(string, *objects):
        return 2

    # Catch-all variadic implementation (also matches zero arguments).
    @f.register([object])
    def _3(*objects):
        return 3

    assert f('a', 1) == 1
    assert f('a', 1, ['a']) == 1
    assert f('a', 1, 'a') == 1
    assert f('a', 'a') == 2
    assert f('a') == 2
    assert f('a', ['a']) == 2
    assert f(1) == 3
    assert f() == 3
Exemplo n.º 19
    def __add__(self, other):
        """Combine two instances of the same type into a new instance.

        Raises:
            ValueError: if ``input_dim`` or ``output_dim`` differ between the
                two operands.

        NOTE(review): presumably returns ``NotImplemented`` for operands of
        another type, but the visible code is truncated -- confirm.
        """
        if isinstance(other, type(self)):
            # Both operands must agree on input and output dimensionality.
            if self.input_dim != other.input_dim or self.output_dim != other.output_dim:
                raise ValueError('Incompatible input or output dimensions.')
            # NOTE(review): the expression below is truncated in this copy.
            # The right-hand operand of ``+`` (presumably the other operand's
            # parts) and the closing parenthesis are missing, and the
            # ``return NotImplemented`` that follows presumably belonged to an
            # ``else:`` branch that is not shown. Restore from the full file.
            return self.__class__(deepcopy(tuple(self.parts)) +
            return NotImplemented

    def __radd__(self, other):
        """Reflected ``+``: delegate to ``other.__add__(self)``.

        If ``other`` has no ``__add__`` at all (e.g. ``None``), return
        ``NotImplemented`` so Python raises its usual ``TypeError`` instead
        of an ``AttributeError`` escaping from this method.
        """
        add = getattr(other, '__add__', None)
        if add is None:
            return NotImplemented
        return add(self)

# Type-dispatched converter to torch tensors (see torchify_* below).
# NOTE(review): nothing is registered on ``torchify`` in the visible code.
# Presumably torchify_tensor / torchify_numpy are meant to be registered for
# tensors and NumPy arrays respectively -- confirm. Otherwise every call
# raises NotImplementedError.
torchify = Dispatcher('torchify')

def torchify_tensor(tens):
    """Return *tens* unchanged (input is presumably already a tensor)."""
    return tens

def torchify_numpy(arr):
    """Convert *arr* with ``torch.from_numpy``."""
    return torch.from_numpy(arr)

# NOTE(review): no implementations for ``torchify32`` are visible here;
# confirm where they are registered.
torchify32 = Dispatcher('torchify32')

Exemplo n.º 20
def test_union_types():
    """A tuple of types in a signature registers the function for each member."""
    f = Dispatcher('f')
    register_int_or_float = f.register((int, float))
    register_int_or_float(inc)

    for arg, expected in ((1, 2), (1.0, 2.0)):
        assert f(arg) == expected
Exemplo n.º 21
def test_raise_error_on_non_class():
    """Registering a signature that contains a non-type raises TypeError."""
    f = Dispatcher('f')

    def register_bad_signature():
        return f.add((1,), inc)

    assert raises(TypeError, register_bad_signature)
Exemplo n.º 22
        return '(%s ** %s)' % (self.lhs, self.rhs)

class IntPowerInt(IntegerExpression, PowerBase, FunctionOfInts):
    """``Power`` implementation for (IntegerExpression, IntegerExpression)."""

class RealPowerInt(RealNumberExpression, PowerBase, FunctionOfInts):
    """``Power`` implementation for (RealNumberExpression, IntegerExpression)."""

class RealPowerReal(RealNumberExpression, PowerBase, FunctionOfReals):
    """``Power`` implementation for (RealNumberExpression, RealNumberExpression)."""

# Chooses the power-expression class from the operand expression types.
# ``add(types, impl)`` is what ``register(*types)(impl)`` performs.
Power = Dispatcher('Power')
Power.add((IntegerExpression, IntegerExpression), IntPowerInt)
Power.add((RealNumberExpression, IntegerExpression), RealPowerInt)
Power.add((RealNumberExpression, RealNumberExpression), RealPowerReal)

class QuotientBase(NumberExpression, BinaryFunction):
    """Shared string form for quotient expressions: ``(lhs / rhs)``."""

    def __str__(self):
        operands = (self.lhs, self.rhs)
        return '(%s / %s)' % operands

class QuotientReal(RealNumberExpression, QuotientBase, FunctionOfReals):
    """Quotient of real-valued operands, typed as a real-number expression."""

class QuotientInt(RealNumberExpression, QuotientBase, FunctionOfInts):
    """Quotient of integer operands, typed as a real-number expression."""
Exemplo n.º 23
from multipledispatch.dispatcher import Dispatcher
from torch.tensor import Tensor
from numpy import ndarray
import torch
import numpy as np

# Type-dispatched converter to torch tensors (see torchify_* below).
torchify = Dispatcher(name='torchify')

@torchify.register(Tensor)
def torchify_tensor(tens):
    """Return *tens* unchanged: it is already a torch ``Tensor``.

    Registered on ``torchify`` for ``Tensor`` inputs. Without the
    registration, ``torchify`` has no implementation and every call (e.g.
    from ``torchify32``) raises NotImplementedError.
    """
    return tens

@torchify.register(ndarray)
def torchify_numpy(arr):
    """Convert a NumPy ``ndarray`` to a torch tensor via ``torch.from_numpy``.

    Registered on ``torchify`` for ``ndarray`` inputs, which is the path
    ``torchify32`` relies on.
    """
    return torch.from_numpy(arr)

def torchify32(x):
    """Convert *x* to a float32 torch tensor by going through NumPy."""
    # TODO: This could obviously be more efficient.  It
    # also might be good for it to handle integers differently.
    as_array = numpify(x)
    as_float32 = as_array.astype(np.float32)
    return torchify(as_float32)

# Type-dispatched converter to NumPy arrays (see numpify_* below).
numpify = Dispatcher(name='numpify')

@numpify.register(Tensor)
def numpify_tensor(tens):
    """Convert a torch ``Tensor`` to a NumPy array.

    Registered on ``numpify`` for ``Tensor`` inputs. The tensor is detached
    before ``.numpy()``, presumably so tensors that track gradients can be
    converted.
    """
    return tens.detach().numpy()
Exemplo n.º 24
def test_source_raises_on_missing_function():
    """Requesting source when no implementation matches raises TypeError."""
    empty = Dispatcher('f')

    def show_source_for_int():
        return empty.source(1)

    assert raises(TypeError, show_source_for_int)
Exemplo n.º 25
    def __init__(self, network):
        """Store the network whose output parameterises the distribution.

        Args:
            network (nn.Module): Must end with a MuSigmaLayer or similar.
        """
        self.network = network

    def rv(self, state):
        """Build a ``Normal`` from the network output for *state*.

        The last axis of the network output holds the parameters: index 0
        is passed as ``mu`` and index 1 as ``sigma``.
        """
        params = self.network(state)
        # ``...`` spans every leading axis, like a tuple of full slices over
        # all but the last dimension.
        mu = params[..., 0]
        sigma = params[..., 1]
        return Normal(mu, sigma)

# Type-dispatched conversion of a value to a tuple (see tupify_* below).
tupify = Dispatcher(name='tupify')

@tupify.register((np.ndarray, object))
def tupify_ndarray(arr):
    """Wrap *arr* in a one-element tuple.

    The signature is the union (ndarray, object), so this is also the
    catch-all for any type without a more specific registration.
    """
    wrapped = (arr,)
    return wrapped

def tupify_iterable(itr):
    """Collect the items of *itr* into a tuple."""
    # NOTE(review): no ``@tupify.register(...)`` is visible for this function,
    # so ``tupify`` never dispatches here as shown -- confirm intended types.
    collected = tuple(itr)
    return collected

def select_first_dims(selection, arr):
    return arr[tupify(selection) + (slice(None, None, None), ) *
Exemplo n.º 26
from multipledispatch.dispatcher import Dispatcher
from xgboost.sklearn import XGBRegressor

# Type-dispatched ``predict`` (see predict_* below).
predict = Dispatcher('predict')

def predict_default(estimator, *args, **kwargs):
    """Fallback: forward every argument to ``estimator.predict``."""
    predict_method = estimator.predict
    return predict_method(*args, **kwargs)

def predict_xgb_regressor(estimator, X, **kwargs):
    """Call ``estimator.predict`` with the feature matrix *X* and keyword options."""
    # NOTE(review): ``XGBRegressor`` is imported above but not used in the
    # visible code, so a registration on ``predict`` for it may be missing --
    # confirm.
    predict_method = estimator.predict
    return predict_method(X, **kwargs)
Exemplo n.º 27
def test_vararg_has_multiple_elements():
    """A variadic marker may contain only a single type."""
    f = Dispatcher('f')

    def register_two_element_vararg():
        return f.register([float, str])(lambda: None)

    assert raises(TypeError, register_two_element_vararg)
Exemplo n.º 28
def test_vararg_not_last_element_of_signature():
    """A variadic element must be the final entry of a signature."""
    f = Dispatcher('f')

    def register_vararg_first():
        return f.register([float], str)(lambda: None)

    assert raises(TypeError, register_vararg_first)
Exemplo n.º 29
 def __init__(self, name):
     """Create the wrapped ``Dispatcher`` under *name*."""
     self.dispatcher = Dispatcher(name=name)
Exemplo n.º 30
def test_no_implementations():
    """Calling a Dispatcher with nothing registered raises NotImplementedError."""
    f = Dispatcher('f')

    def call_unregistered():
        return f('hello')

    assert raises(NotImplementedError, call_unregistered)