def test_dispatcher():
    """A Dispatcher routes a call to the handler registered for the arg type."""
    f = Dispatcher('f')

    def inc(x):
        return x + 1

    def dec(x):
        return x - 1

    f.add((int,), inc)
    f.add((float,), dec)

    # resolve() maps a signature tuple back to its handler.
    assert f.resolve((int,)) == inc

    assert f(1.0) == 0.0
    assert f(1) == 2
def test_dispatcher():
    """dispatch() supersedes the deprecated resolve(); both find the handler."""
    f = Dispatcher('f')
    f.add((int,), inc)
    f.add((float,), dec)

    # resolve() still works but emits a DeprecationWarning; suppress it.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", DeprecationWarning)
        assert f.resolve((int,)) == inc

    assert f.dispatch(int) is inc
    assert f(1.0) == 0.0
    assert f(1) == 2
def test_vararg_dispatch_ambiguity_in_variadic():
    """Overlapping variadic signatures must be reported as ambiguous."""
    f = Dispatcher('f')

    @f.register(float, [object])
    def impl_a(a, b, *args):
        return 1

    @f.register(object, [float])
    def impl_b(a, b, *args):
        return 2

    # A call like f(1.0, 2.0) matches both signatures -> ambiguity detected.
    assert ambiguities(f.funcs)
def test_dispatcher_as_decorator():
    """register() works as a decorator attaching one handler per type."""
    f = Dispatcher('f')

    @f.register(int)
    def handle_int(x):
        return x + 1

    @f.register(float)
    def handle_float(x):
        return x - 1

    assert f(1.0) == 0.0
    assert f(1) == 2
# Example 5
def test_function_annotation_register():
    """With no explicit types, register() reads the parameter annotations."""
    f = Dispatcher('f')

    @f.register()
    def handle_int(x: int):
        return x + 1

    @f.register()
    def handle_float(x: float):
        return x - 1

    assert f(1.0) == 0.0
    assert f(1) == 2
def test_register_stacking():
    """Stacked register() decorators attach one function to several types."""
    f = Dispatcher('f')

    @f.register(list)
    @f.register(tuple)
    def rev(x):
        return x[::-1]

    assert f([1, 2, 3]) == [3, 2, 1]
    assert f((1, 2, 3)) == (3, 2, 1)

    # str was never registered, so dispatch fails ...
    assert raises(NotImplementedError, lambda: f('hello'))
    # ... but the undecorated function itself still reverses strings.
    assert rev('hello') == 'olleh'
def test_serializable():
    """A Dispatcher survives a pickle round-trip with all handlers intact."""
    import pickle

    f = Dispatcher('f')
    f.add((int,), inc)
    f.add((float,), dec)
    f.add((object,), identity)

    payload = pickle.dumps(f)
    assert isinstance(payload, (str, bytes))

    g = pickle.loads(payload)
    assert g(1) == 2
    assert g(1.0) == 0.0
    assert g('hello') == 'hello'
def test_on_ambiguity():
    """add() fires its on_ambiguity callback once signatures conflict."""
    f = Dispatcher('f')

    def identity(x):
        return x

    seen = [False]

    def on_ambiguity(dispatcher, amb):
        seen[0] = True

    f.add((object, object), identity, on_ambiguity=on_ambiguity)
    assert not seen[0]
    f.add((object, float), identity, on_ambiguity=on_ambiguity)
    assert not seen[0]
    # (float, object) vs (object, float) is unresolvable -> callback fires.
    f.add((float, object), identity, on_ambiguity=on_ambiguity)
    assert seen[0]
def test_source():
    """_source() returns the source text of the handler that would run."""
    def one(x, y):
        """ Docstring number one """
        return x + y

    def two(x, y):
        """ Docstring number two """
        return x - y

    f = Dispatcher('f', doc='Doc of the multimethod itself')
    f.add((int, int), one)
    f.add((float, float), two)

    assert 'x + y' in f._source(1, 1)
    assert 'x - y' in f._source(1.0, 1.0)
def test_dispatch_method():
    """dispatch() looks up handlers by type, honouring subclasses."""
    f = Dispatcher('f')

    @f.register(list)
    def rev(x):
        return x[::-1]

    @f.register(int, int)
    def add(x, y):
        return x + y

    class MyList(list):
        pass

    assert f.dispatch(int, int) is add
    assert f.dispatch(list) is rev
    # A subclass dispatches to its base-class handler.
    assert f.dispatch(MyList) is rev
def test_not_implemented():
    """Raising MDNotImplementedError falls through to the next handler."""
    f = Dispatcher('f')

    @f.register(object)
    def fallback(x):
        return 'default'

    @f.register(int)
    def evens_only(x):
        if x % 2:
            raise MDNotImplementedError()
        return 'even'

    assert f('hello') == 'default'  # default behavior
    assert f(2) == 'even'           # specialized behavior
    assert f(3) == 'default'        # fall back to default behavior
    assert raises(NotImplementedError, lambda: f(1, 2))
def test_on_ambiguity():
    """reorder() reports ambiguities introduced by earlier add() calls."""
    f = Dispatcher('f')

    identity = lambda x: x

    seen = [False]

    def note_ambiguity(dispatcher, amb):
        seen[0] = True

    f.add((object, object), identity)
    f.add((object, float), identity)
    f.add((float, object), identity)

    # add() alone did not trigger the callback; reorder() does.
    assert not seen[0]
    f.reorder(on_ambiguity=note_ambiguity)
    assert seen[0]
def test_vararg_dispatch_multiple_implementations():
    """Several variadic signatures can coexist on one dispatcher."""
    f = Dispatcher('f')

    @f.register(str, [float])
    def str_then_floats(a, *b):
        return 'mixed_string_floats'

    @f.register([float])
    def only_floats(*b):
        return 'floats'

    @f.register([str])
    def only_strings(*strings):
        return 'strings'

    assert f('a', 'b', 'c') == 'strings'
    assert f(1.0, 2.0, 3.14) == 'floats'
    assert f('a', 1.0, 2.0) == 'mixed_string_floats'
def test_help():
    """_help() returns the docstring of the handler that would be chosen."""
    def one(x, y):
        """ Docstring number one """
        return x + y

    def two(x, y):
        """ Docstring number two """
        return x + y

    def three(x, y):
        """ Docstring number three """
        return x + y

    f = Dispatcher('f', doc='Doc of the multimethod itself')
    f.add((object, object), one)
    f.add((int, int), two)
    f.add((float, float), three)

    assert f._help(1, 1) == two.__doc__
    assert f._help(1.0, 2.0) == three.__doc__
# Example 15
def test_docstring():
    """The dispatcher docstring aggregates handler docstrings."""
    def one(x, y):
        """ Docstring number one """
        return x + y

    def two(x, y):
        """ Docstring number two """
        return x + y

    def three(x, y):
        # Intentionally undocumented handler.
        return x + y

    f = Dispatcher('f')
    f.add((object, object), one)
    f.add((int, int), two)
    f.add((float, float), three)

    doc = f.__doc__
    assert one.__doc__.strip() in doc
    assert two.__doc__.strip() in doc
    # The aggregated docstring lists handlers in a deterministic order:
    # `one` appears before `two`.
    assert doc.find(one.__doc__.strip()) < doc.find(two.__doc__.strip())
    assert 'object, object' in doc
def test_vararg_dispatch_unions():
    """A variadic slot may take a union (tuple) of allowed types."""
    f = Dispatcher('f')

    @f.register(str, [(int, float)])
    def str_then_numbers(a, *b):
        return 'mixed_string_ints_floats'

    @f.register([str])
    def only_strings(*strings):
        return 'strings'

    @f.register([(str, int)])
    def strings_and_ints(*strings_ints):
        return 'mixed_strings_ints'

    @f.register([object])
    def anything(*objects):
        return 'objects'

    assert f('a', 1.0, 7, 2.0, 11) == 'mixed_string_ints_floats'
    assert f('a', 'b', 'c') == 'strings'
    assert f('a', 1, 'b', 2) == 'mixed_strings_ints'
    assert f([], (), {}) == 'objects'
def test_halt_method_resolution():
    g = [0]
    def on_ambiguity(a, b):
        g[0] += 1

    f = Dispatcher('f')

    halt_ordering()

    def func(*args):
        pass

    f.add((int, object), func)
    f.add((object, int), func)

    assert g == [0]

    restart_ordering(on_ambiguity=on_ambiguity)

    assert g == [1]

    print(list(f.ordering))
    assert set(f.ordering) == set([(int, object), (object, int)])
def test_vararg_ordering():
    """Longer, more specific fixed prefixes win over shorter variadic ones."""
    f = Dispatcher('f')

    @f.register(str, int, [object])
    def str_int_rest(string, integer, *objects):
        return 1

    @f.register(str, [object])
    def str_rest(string, *objects):
        return 2

    @f.register([object])
    def rest(*objects):
        return 3

    # Longest matching fixed prefix: (str, int, ...).
    assert f('a', 1) == 1
    assert f('a', 1, ['a']) == 1
    assert f('a', 1, 'a') == 1
    # (str, ...) when the second argument is not an int.
    assert f('a', 'a') == 2
    assert f('a') == 2
    assert f('a', ['a']) == 2
    # Everything else falls through to the fully variadic handler.
    assert f(1) == 3
    assert f() == 3
# Example 19
    def __add__(self, other):
        """Concatenate the parts of two dimension-compatible instances."""
        if not isinstance(other, type(self)):
            return NotImplemented
        if (self.input_dim != other.input_dim
                or self.output_dim != other.output_dim):
            raise ValueError('Incompatible input or output dimensions.')
        # Deep-copy both part tuples so the result shares no mutable state.
        combined = deepcopy(tuple(self.parts)) + deepcopy(tuple(other.parts))
        return self.__class__(combined,
                              input_dim=self.input_dim,
                              output_dim=self.output_dim)

    def __radd__(self, other):
        """Reflected addition: other + self.

        Delegates to the left operand's __add__ when the types are
        compatible; otherwise returns NotImplemented so Python raises a
        proper TypeError. (The original unconditionally called
        other.__add__, which raised AttributeError for operands that have
        no __add__ method.)
        """
        if isinstance(other, type(self)):
            return other.__add__(self)
        return NotImplemented


# Dispatcher converting supported array-like values to torch tensors.
torchify = Dispatcher('torchify')


@torchify.register(Tensor)
def torchify_tensor(tens):
    # Already a torch tensor: returned unchanged.
    return tens


@torchify.register(ndarray)
def torchify_numpy(arr):
    # torch.from_numpy shares the array's memory with the new tensor.
    return torch.from_numpy(arr)


# NOTE(review): declared here, but no implementations are registered in
# this view of the file — confirm registrations exist elsewhere.
torchify32 = Dispatcher('torchify32')

def test_union_types():
    """Registering a tuple of types covers every member of the union."""
    f = Dispatcher('f')
    f.register((int, float))(inc)

    assert f(1.0) == 2.0
    assert f(1) == 2
def test_raise_error_on_non_class():
    """Registering a signature containing a non-type raises TypeError."""
    f = Dispatcher('f')
    assert raises(TypeError, lambda: f.add((1,), inc))
# Example 22
        return '(%s ** %s)' % (self.lhs, self.rhs)


class IntPowerInt(IntegerExpression, PowerBase, FunctionOfInts):
    """Power expression with integer operands (registered for int ** int)."""
    pass


class RealPowerInt(RealNumberExpression, PowerBase, FunctionOfInts):
    """Power expression registered for a real base with an integer exponent."""
    pass


class RealPowerReal(RealNumberExpression, PowerBase, FunctionOfReals):
    """Power expression registered for real base and real exponent."""
    pass


# Power acts as a multimethod constructor: the operand expression types
# select which concrete power-expression class is instantiated.
Power = Dispatcher('Power')
Power.register(IntegerExpression, IntegerExpression)(IntPowerInt)
Power.register(RealNumberExpression, IntegerExpression)(RealPowerInt)
Power.register(RealNumberExpression, RealNumberExpression)(RealPowerReal)


class QuotientBase(NumberExpression, BinaryFunction):
    """Shared string rendering for division expressions."""
    def __str__(self):
        # Renders as "(lhs / rhs)".
        return '(%s / %s)' % (self.lhs, self.rhs)


class QuotientReal(RealNumberExpression, QuotientBase, FunctionOfReals):
    """Quotient expression with real operands."""
    pass


class QuotientInt(RealNumberExpression, QuotientBase, FunctionOfInts):
# Example 23
from multipledispatch.dispatcher import Dispatcher
from torch.tensor import Tensor
from numpy import ndarray
import torch
import numpy as np

# Dispatcher converting supported array-like values to torch tensors.
torchify = Dispatcher('torchify')


@torchify.register(Tensor)
def torchify_tensor(tens):
    # Torch tensors pass through untouched.
    return tens


@torchify.register(ndarray)
def torchify_numpy(arr):
    # torch.from_numpy shares the array's memory with the new tensor.
    return torch.from_numpy(arr)


def torchify32(x):
    """Convert *x* to a float32 torch tensor via numpify then torchify."""
    # TODO: This could obviously be more efficient.  It
    # also might be good for it to handle integers differently.
    as_float32 = numpify(x).astype(np.float32)
    return torchify(as_float32)


# Dispatcher converting supported values to numpy arrays.
numpify = Dispatcher('numpify')


@numpify.register(Tensor)
def numpify_tensor(tens):
    # detach() drops the autograd graph before the numpy conversion.
    return tens.detach().numpy()
def test_source_raises_on_missing_function():
    """source() on a signature with no registered handler raises TypeError."""
    f = Dispatcher('f')

    assert raises(TypeError, lambda: f.source(1))
# Example 25
    def __init__(self, network):
        '''
        network (nn.Module): Must end with a MuSigmaLayer or similar.
        '''
        # nn.Module.__init__ must run first so assigning self.network
        # registers it as a submodule.
        nn.Module.__init__(self)
        self.network = network

    def rv(self, state):
        """Run the network on *state* and wrap its output in a Normal.

        The last axis of the network output is assumed to hold
        (mu, sigma) pairs — TODO confirm against MuSigmaLayer.
        """
        out = self.network(state)
        # Full slices for every axis except the last, which selects mu/sigma.
        lead = (slice(None),) * (len(out.shape) - 1)
        mu = out[lead + (0,)]
        sigma = out[lead + (1,)]
        return Normal(mu, sigma)


# Dispatcher normalising a selection value into a tuple.
tupify = Dispatcher('tupify')


# NOTE(review): this union includes `object`, so the signature matches
# ANY argument, not just ndarrays — verify that is intended.
@tupify.register((np.ndarray, object))
def tupify_ndarray(arr):
    # Wrap a single value in a one-element tuple.
    return (arr, )


@tupify.register(Iterable)
def tupify_iterable(itr):
    # Materialise an iterable of selections as a tuple.
    return tuple(itr)


@curry
def select_first_dims(selection, arr):
    return arr[tupify(selection) + (slice(None, None, None), ) *
# Example 26
from multipledispatch.dispatcher import Dispatcher
from xgboost.sklearn import XGBRegressor

# Multimethod wrapper around estimator.predict so model-specific call
# signatures can be special-cased per estimator type.
predict = Dispatcher(name='predict')

@predict.register(object)
def predict_default(estimator, *args, **kwargs):
    # Generic fallback: forward all arguments to the estimator's predict().
    return estimator.predict(*args, **kwargs)

@predict.register(XGBRegressor)
def predict_xgb_regressor(estimator, X, **kwargs):
    # XGBRegressor takes the feature matrix as the first positional arg.
    return estimator.predict(X, **kwargs)
def test_vararg_has_multiple_elements():
    """A variadic slot must contain exactly one type specification."""
    f = Dispatcher('f')
    assert raises(TypeError, lambda: f.register([float, str])(lambda: None))
def test_vararg_not_last_element_of_signature():
    """A variadic slot is only legal in the final signature position."""
    f = Dispatcher('f')
    assert raises(TypeError, lambda: f.register([float], str)(lambda: None))
# Example 29
 def __init__(self, name):
     # Each instance owns a named Dispatcher for registering implementations.
     self.dispatcher = Dispatcher(name)
def test_no_implementations():
    """Calling a dispatcher with no registered handlers is an error."""
    f = Dispatcher('f')
    assert raises(NotImplementedError, lambda: f('hello'))