示例#1
0
def id_from_inputs(inputs):
    """Build a compact string id describing a collection of named, shaped inputs.

    Each ``(name, value)`` pair is rendered as ``name`` followed by the digits
    of ``value.shape`` concatenated with no separator, and the rendered pairs
    are joined with commas. Because shape entries are not delimited, different
    shapes can produce the same id (e.g. ``(2, 13)`` and ``(21, 3)``).

    :param inputs: A ``dict``/``OrderedDict`` mapping names to objects with a
        ``.shape`` attribute, or an iterable of ``(name, value)`` pairs.
        Names are joined with ``+``, so they must be ``str``.
    :return: The id string, or ``'()'`` when ``inputs`` is empty.
    """
    # OrderedDict subclasses dict, so one isinstance check covers both.
    pairs = inputs.items() if isinstance(inputs, dict) else inputs
    if not pairs:
        return '()'
    pieces = []
    for name, value in pairs:
        dims = ''.join(str(size) for size in value.shape)
        pieces.append(name + dims)
    return ','.join(pieces)


@dispatch(object, object, Variadic[float])
def allclose(a, b, rtol=1e-05, atol=1e-08):
    """Fallback closeness test for values with no more specific overload.

    Uses the same tolerance rule as ``numpy.allclose`` (which is registered
    for ``np.ndarray`` pairs in this module)::

        |a - b| <= atol + rtol * |b|

    :param a: Value to test.
    :param b: Reference value; the relative tolerance is scaled by ``|b|``.
    :param float rtol: Relative tolerance.
    :param float atol: Absolute tolerance.
    :return: ``False`` if ``a`` and ``b`` have different types. Otherwise the
        result of the ``<=`` comparison above: a ``bool`` for Python scalars,
        or possibly an elementwise result for other types, depending on how
        ``ops.abs`` and ``<=`` behave for them.
    """
    if type(a) is not type(b):
        return False
    # atol is the absolute term and rtol scales with |b|, matching numpy.
    # Use <= (not <) so identical values are close even with zero tolerances.
    return ops.abs(a - b) <= atol + rtol * ops.abs(b)


# Register NumPy's own ``np.allclose`` as the ``allclose`` overload for a pair
# of ``np.ndarray`` arguments (optionally followed by float tolerances).
# NOTE(review): np.allclose reduces to a single bool, whereas the object
# fallback may return an elementwise result for non-scalar types.
dispatch(np.ndarray, np.ndarray, Variadic[float])(np.allclose)


@dispatch(Tensor, Tensor, Variadic[float])
def allclose(a, b, rtol=1e-05, atol=1e-08):
    """Compare two Tensors for approximate equality.

    The tensors must have equal ``inputs`` and equal ``output``. If they do,
    the comparison is delegated to ``allclose`` on their ``.data``, which
    dispatches on the data's type.

    :param a: First tensor.
    :param b: Reference tensor.
    :param float rtol: Relative tolerance, forwarded to the data comparison.
    :param float atol: Absolute tolerance, forwarded to the data comparison.
    :return: ``False`` on an inputs/output mismatch, otherwise the result of
        comparing the underlying data.
    """
    mismatched = a.inputs != b.inputs or a.output != b.output
    return False if mismatched else allclose(a.data, b.data, rtol=rtol, atol=atol)


def is_array(x):
    """Return whether ``x`` is a NumPy array/scalar or a JAX-style device array.

    :param x: Any object.
    :return: ``True`` for ``np.ndarray`` and ``np.generic`` instances, and for
        any object whose type name contains ``"DeviceArray"``; otherwise
        ``False``.
    """
    if isinstance(x, (np.ndarray, np.generic)):
        return True
    # XXX: depending on the JAX version, the device-array class is named
    # ``DeviceArray`` or ``_DeviceArray``, so match on a substring of the
    # type name.
    return "DeviceArray" in type(x).__name__
示例#2
0
        elif isinstance(self.exo_grid, EmptyGrid):
            out = self.eval_s(s)
        elif isinstance(self.exo_grid, CartesianGrid):
            m = self.dprocess.inode(i, j)[None, :].repeat(s.shape[0], axis=0)
            out = self.eval_ms(m, s)
        else:
            raise Exception("Not Implemented.")

        return out


# this is *not* meant to be used by users

from multipledispatch import dispatch
# Dedicated registry passed to multipledispatch's ``dispatch`` so overloads
# declared with ``@multimethod`` are stored here rather than in the default
# namespace. ``multimethod`` takes no explicit types; overloads are selected
# from each function's annotations.
namespace = {}
multimethod = dispatch(namespace=namespace)

# Cartesian x Cartesian x Linear


@multimethod
def get_coefficients(itp: object, exo_grid: CartesianGrid,
                     endo_grid: CartesianGrid, interp_type: Linear, x: object):
    """Coefficients for linear interpolation on a Cartesian x Cartesian grid.

    The annotations select this overload: ``itp`` and ``interp_type`` are used
    only for dispatch and are not read in the body. The values ``x`` are
    reshaped to the combined grid's ``n`` (presumably the number of points per
    dimension), plus one trailing axis sized from the remaining data, and a
    copy is returned so the result never aliases ``x``.

    :param exo_grid: Exogenous Cartesian grid.
    :param endo_grid: Endogenous Cartesian grid.
    :param x: Array-like values supporting ``.reshape`` and ``.copy``.
    :return: A reshaped copy of ``x``.
    """
    full_grid = exo_grid + endo_grid
    target_shape = (*full_grid.n, -1)
    return x.reshape(target_shape).copy()


@multimethod
def eval_ms(itp: object, exo_grid: CartesianGrid, endo_grid: CartesianGrid,
            interp_type: Linear, m: object, s: object):