import jax.numpy as np  # assumed alias; every np call below also exists in jax.numpy
from jax import random

def make_orimap(hyper_col, X, Y, nn=30, prngKey=4):
    '''
    Makes the orientation map for the grid.
    hyper_col = hypercolumn length for the network in retinotopic degrees
    X, Y = grids of x- and y-distances between neurons in retinotopic degrees

    Outputs
    OMap = orientation preference for each cell in the network
    Nthetas = half the number of cells in the network (equivalently, the number
              of E or I cells)
    '''
    kc = 2 * np.pi / (hyper_col)

    z = np.zeros_like(X)
    key = random.PRNGKey(prngKey)
    subkey = key

    for j in range(nn):
        kj = kc * np.array([np.cos(j * np.pi / nn), np.sin(j * np.pi / nn)])

        # randint args: PRNG key, shape tuple, minval (inclusive), maxval (exclusive);
        # 2 * {1, 2} - 3 gives a random sign, either -1 or +1
        sj = 2 * random.randint(subkey, shape=(), minval=1, maxval=3) - 3

        phij = random.uniform(subkey) * 2 * np.pi
        tmp = (X * kj[0] + Y * kj[1]) * sj + phij
        z = z + np.exp(1j * tmp)

        key, subkey = random.split(key)

    OMap = np.angle(z)
    OMap = (OMap - np.min(OMap)) * 180 / (2 * np.pi)  # rescale phases to orientations in degrees (0-180)
    Nthetas = len(OMap.ravel())

    return OMap, Nthetas
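
A minimal usage sketch; the grid extent, resolution, and hypercolumn length below are illustrative assumptions, not values from the original code.

grid = np.linspace(0, 8, 25)              # 25 x 25 grid spanning 8 retinotopic degrees
X, Y = np.meshgrid(grid, grid)            # coordinate grids in retinotopic degrees
OMap, Nthetas = make_orimap(hyper_col=4.0, X=X, Y=Y)
print(OMap.shape, Nthetas)                # (25, 25) 625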
Example #2
import jax.numpy as jnp
# xop is assumed to be a framing utility (e.g. commplax's xop module) whose
# frame(signal, flen, fstep) call below splits a 1-D signal into length-L frames.

def foe_daxcorr(y, x, L=100):
    N = y.shape[0]
    if N < L:
        raise TypeError('signal length %d is less than xcorr length %d' % (N, L))
    s = y * x.conj() # remove modulated data phase
    sf = xop.frame(s, L, L)     # split into non-overlapping frames of length L
    sf2 = sf[:-1]               # frames 0 .. K-2
    sf1 = sf[1:]                # frames 1 .. K-1
    # mean phase rotation between samples L apart, normalised to a per-sample offset
    return jnp.mean(jnp.angle(sf1 * sf2.conj())) / L
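
A quick sanity-check sketch, assuming jnp and xop are importable as noted above; the tone frequency and signal length are illustrative:

n = jnp.arange(1000)
w = 0.01                          # true per-sample angular frequency offset
y = jnp.exp(1j * w * n)           # received tone
x = jnp.ones_like(y)              # reference symbols carrying no modulation
print(foe_daxcorr(y, x))          # ~ 0.01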
Example #3
import jax.numpy as jnp
from math import pi
# JaxArray is assumed to be the array wrapper defined by the surrounding module;
# it exposes the underlying jax array as .value.

def angle(z, deg=False):
  if isinstance(z, JaxArray): z = z.value
  a = jnp.angle(z)
  if deg:
    a *= 180 / pi
  return JaxArray(a)
Example #4
class JaxBox(qml.math.TensorBox):
    """Implements the :class:`~.TensorBox` API for ``numpy.ndarray``.

    For more details, please refer to the :class:`~.TensorBox` documentation.
    """

    abs = wrap_output(lambda self: jnp.abs(self.data))
    angle = wrap_output(lambda self: jnp.angle(self.data))
    arcsin = wrap_output(lambda self: jnp.arcsin(self.data))
    cast = wrap_output(lambda self, dtype: jnp.array(self.data, dtype=dtype))
    expand_dims = wrap_output(
        lambda self, axis: jnp.expand_dims(self.data, axis=axis))
    ones_like = wrap_output(lambda self: jnp.ones_like(self.data))
    sqrt = wrap_output(lambda self: jnp.sqrt(self.data))
    sum = wrap_output(lambda self, axis=None, keepdims=False: jnp.sum(
        self.data, axis=axis, keepdims=keepdims))
    T = wrap_output(lambda self: self.data.T)
    take = wrap_output(lambda self, indices, axis=None: jnp.take(
        self.data, indices, axis=axis, mode="wrap"))

    def __init__(self, tensor):
        tensor = jnp.asarray(tensor)

        super().__init__(tensor)

    @staticmethod
    def astensor(tensor):
        return jnp.asarray(tensor)

    @staticmethod
    @wrap_output
    def concatenate(values, axis=0):
        return jnp.concatenate(JaxBox.unbox_list(values), axis=axis)

    @staticmethod
    @wrap_output
    def dot(x, y):
        x, y = JaxBox.unbox_list([x, y])
        x = jnp.asarray(x)
        y = jnp.asarray(y)

        if x.ndim == 0 and y.ndim == 0:
            return x * y

        if x.ndim == 2 and y.ndim == 2:
            return x @ y

        return jnp.dot(x, y)

    @property
    def interface(self):
        return "jax"

    def numpy(self):
        return self.data

    @property
    def requires_grad(self):
        return True

    @property
    def shape(self):
        return self.data.shape

    @staticmethod
    @wrap_output
    def stack(values, axis=0):
        return jnp.stack(JaxBox.unbox_list(values), axis=axis)

    @staticmethod
    @wrap_output
    def where(condition, x, y):
        return jnp.where(condition, *JaxBox.unbox_list([x, y]))
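
A brief usage sketch, under the assumption (implied by the pattern above) that wrap_output re-wraps each returned array in a new JaxBox; the input values are illustrative:

import jax.numpy as jnp

box = JaxBox(jnp.array([1.0 + 1.0j, -1.0 + 0.0j]))
phases = box.angle()                # assumed to return a JaxBox wrapping jnp.angle(box.data)
print(phases.numpy())               # ~ [0.785, 3.142]
print(box.shape, box.interface)     # (2,) jax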
Example #5
accumulate_n = utils.copy_docstring(
    tf.math.accumulate_n,
    lambda inputs, shape=None, tensor_dtype=None, name=None: (  # pylint: disable=g-long-lambda
        sum(map(np.array, inputs)).astype(utils.numpy_dtype(tensor_dtype))))

acos = utils.copy_docstring(tf.math.acos, lambda x, name=None: np.arccos(x))

acosh = utils.copy_docstring(tf.math.acosh, lambda x, name=None: np.arccosh(x))

add = utils.copy_docstring(tf.math.add, lambda x, y, name=None: np.add(x, y))

add_n = utils.copy_docstring(
    tf.math.add_n, lambda inputs, name=None: sum(map(np.array, inputs)))

angle = utils.copy_docstring(tf.math.angle,
                             lambda input, name=None: np.angle(input))

argmax = utils.copy_docstring(
    tf.math.argmax,
    lambda input, axis=None, output_type=tf.int64, name=None: (  # pylint: disable=g-long-lambda
        np.argmax(input, axis=0 if axis is None else _astuple(axis)).astype(
            utils.numpy_dtype(output_type))))

argmin = utils.copy_docstring(
    tf.math.argmin,
    lambda input, axis=None, output_type=tf.int64, name=None: (  # pylint: disable=g-long-lambda
        np.argmin(input, axis=0 if axis is None else _astuple(axis)).astype(
            utils.numpy_dtype(output_type))))

asin = utils.copy_docstring(tf.math.asin, lambda x, name=None: np.arcsin(x))
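
Each definition above follows the same pattern: a NumPy-backed lambda stands in for the corresponding tf.math op, and utils.copy_docstring carries the TensorFlow docstring over to it. A minimal, self-contained sketch of that idea (the helper below is hypothetical, not TFP's actual utils.copy_docstring):

import numpy as np

def copy_docstring(original_fn, new_fn):
    # hypothetical stand-in: reuse the original op's docstring on the shim
    new_fn.__doc__ = original_fn.__doc__
    return new_fn

# NumPy shim mirroring the `angle` definition above; here the docstring is copied
# from np.angle, whereas TFP copies it from tf.math.angle.
angle = copy_docstring(np.angle, lambda input, name=None: np.angle(input))

print(angle(np.array([1 + 1j])))            # [0.78539816]
print(angle.__doc__ == np.angle.__doc__)    # True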