Exemplo n.º 1
0
 def __call__(self, shape, dtype=None):
     """Draw an array of the requested shape using variance scaling.

     The variance is ``self.scale`` divided by a fan count selected by
     ``self.mode`` (``"fan_in"``, ``"fan_out"`` or ``"fan_avg"``); the
     samples come from ``self.rng`` according to ``self.distribution``
     (``"truncated_normal"``, ``"normal"`` or ``"uniform"``).

     Parameters
     ----------
     shape : sequence
         Requested shape; every entry is passed through ``size2len``.
     dtype : optional
         Data type forwarded to the array constructors.

     Returns
     -------
     The sampled array, converted with ``math.asarray(..., dtype=dtype)``.

     Raises
     ------
     ValueError
         If ``self.mode`` or ``self.distribution`` is not recognised.
     """
     dims = [size2len(axis_size) for axis_size in shape]
     fan_in, fan_out = _compute_fans(
         dims, in_axis=self.in_axis, out_axis=self.out_axis)

     # Choose the fan count that normalises the scale.
     if self.mode == "fan_avg":
         fan = (fan_in + fan_out) / 2
     elif self.mode == "fan_out":
         fan = fan_out
     elif self.mode == "fan_in":
         fan = fan_in
     else:
         raise ValueError(
             "invalid mode for variance scaling initializer: {}".format(
                 self.mode))
     variance = math.array(self.scale / fan, dtype=dtype)

     # Each branch draws unit-scale samples and picks the matching factor.
     kind = self.distribution
     if kind == "truncated_normal":
         # constant is stddev of standard normal truncated to (-2, 2)
         truncated_std = math.array(.87962566103423978, dtype)
         factor = math.sqrt(variance) / truncated_std
         samples = self.rng.truncated_normal(-2, 2, dims)
     elif kind == "normal":
         samples = self.rng.normal(size=dims)
         factor = math.sqrt(variance)
     elif kind == "uniform":
         samples = self.rng.uniform(low=-1, high=1, size=dims)
         factor = math.sqrt(3 * variance)
     else:
         raise ValueError(
             "invalid distribution for variance scaling initializer")
     return math.asarray(samples * factor, dtype=dtype)
Exemplo n.º 2
0
  def test_jacfwd_aux1(self):
    """Compare ``bm.jacfwd`` on a callable object with ``jax.jacfwd``.

    The object keeps ``x`` as an attribute and takes ``y`` as an argument;
    its call returns the output vector together with an auxiliary pair.
    """
    def reference(x, y):
      return jnp.asarray([
        x[0] * y[0],
        5 * x[2] * y[1],
        4 * x[1] ** 2 - 2 * x[2],
        x[2] * jnp.sin(x[0]),
      ])

    x_val = bm.array([1., 2., 3.])
    y_val = bm.array([10., 5.])

    class Test(bp.Base):
      def __init__(self):
        super().__init__()
        self.x = bm.array([1., 2., 3.])

      def __call__(self, y):
        c = 4 * self.x[1] ** 2 - 2 * self.x[2]
        d = self.x[2] * jnp.sin(self.x[0])
        out = jnp.asarray([self.x[0] * y[0], 5 * self.x[2] * y[1], c, d])
        return out, (c, d)

    # Jacobian with respect to the stored variable only.
    expected = jax.jacfwd(reference)(x_val, y_val)
    model = Test()
    got = bm.jacfwd(model, grad_vars=model.x)(y_val)
    self.assertTrue((got == expected).all())

    # Jacobians with respect to both the variable and the argument, plus aux.
    model = Test()
    expected = jax.jacfwd(reference, argnums=(0, 1))(x_val, y_val)
    expected_aux = model(y_val)[1]
    jac_fn = bm.jacfwd(model, grad_vars=model.x, argnums=0, has_aux=True)
    (var_grads, arg_grads), aux = jac_fn(y_val)
    print(var_grads, )
    print(arg_grads, )
    self.assertTrue((var_grads == expected[0]).all())
    self.assertTrue((arg_grads == expected[1]).all())
    self.assertTrue(bm.array_equal(aux, expected_aux))
Exemplo n.º 3
0
  def test_jacrev_return_aux1(self):
    """Check ``bm.jacrev`` with ``return_value=True`` and ``has_aux=True``.

    The Jacobians are compared with ``jax.jacrev`` applied to the first
    output only; the returned value and aux are compared with a direct call.
    """
    def with_aux(x, y):
      aux_term = 4 * x[1] ** 2 - 2 * x[2]
      outputs = jnp.asarray([x[0] * y[0], 5 * x[2] * y[1], aux_term, x[2] * jnp.sin(x[0])])
      return outputs, aux_term

    def outputs_only(*args):
      return with_aux(*args)[0]

    x0 = bm.array([1., 2., 3.])
    y0 = bm.array([10., 5.])
    expected_value, expected_aux = with_aux(x0, y0)

    # Reference Jacobians computed by jax.
    jac_first = jax.jacrev(outputs_only)(x0, y0)
    pprint(jac_first)
    jac_both = jax.jacrev(outputs_only, argnums=(0, 1))(x0, y0)
    pprint(jac_both)

    # Default argnums.
    grads, value, aux = bm.jacrev(with_aux, return_value=True, has_aux=True)(x0, y0)
    assert (grads == jac_first).all()
    assert aux == expected_aux
    assert (value == expected_value).all()

    # Both positional arguments.
    grads, value, aux = bm.jacrev(with_aux, return_value=True, argnums=(0, 1), has_aux=True)(x0, y0)
    assert (grads[0] == jac_both[0]).all()
    assert (grads[1] == jac_both[1]).all()
    assert aux == expected_aux
    assert (value == expected_value).all()
Exemplo n.º 4
0
  def test_jacfwd1(self):
    """Compare ``bm.jacfwd`` over object variables with ``jax.jacfwd``.

    The object stores both ``x`` and ``y`` as attributes and takes no call
    arguments; gradients are requested through ``grad_vars``.
    """
    def reference(x, y):
      return jnp.asarray([
        x[0] * y[0],
        5 * x[2] * y[1],
        4 * x[1] ** 2 - 2 * x[2],
        x[2] * jnp.sin(x[0]),
      ])

    x_val = bm.array([1., 2., 3.])
    y_val = bm.array([10., 5.])

    class Test(bp.Base):
      def __init__(self):
        super().__init__()
        self.x = bm.array([1., 2., 3.])
        self.y = bm.array([10., 5.])

      def __call__(self, ):
        return jnp.asarray([
          self.x[0] * self.y[0],
          5 * self.x[2] * self.y[1],
          4 * self.x[1] ** 2 - 2 * self.x[2],
          self.x[2] * jnp.sin(self.x[0]),
        ])

    # Jacobian with respect to a single variable.
    expected = jax.jacfwd(reference)(x_val, y_val)
    model = Test()
    got = bm.jacfwd(model, grad_vars=model.x)()
    self.assertTrue((got == expected).all())

    # Jacobians with respect to both variables.
    expected = jax.jacfwd(reference, argnums=(0, 1))(x_val, y_val)
    model = Test()
    got = bm.jacfwd(model, grad_vars=[model.x, model.y])()
    self.assertTrue((got[0] == expected[0]).all())
    self.assertTrue((got[1] == expected[1]).all())
Exemplo n.º 5
0
  def test_jacrev2(self):
    """Check ``bm.jacrev`` on a function that returns two output arrays.

    Two variants of the function are exercised (outputs built with ``jnp``
    and with ``bm``), each once with the raw ``.value`` inputs and once
    with the ``bm`` arrays themselves.
    """
    print()

    def two_outputs_jnp(x, y):
      first = jnp.asarray([x[0] * y[0], 5 * x[2] * y[1]])
      second = jnp.asarray([4 * x[1] ** 2 - 2 * x[2], x[2] * jnp.sin(x[0])])
      return first, second

    def two_outputs_bm(x, y):
      first = bm.asarray([x[0] * y[0], 5 * x[2] * y[1]])
      second = bm.asarray([4 * x[1] ** 2 - 2 * x[2], x[2] * jnp.sin(x[0])])
      return first, second

    expected = jax.jacrev(two_outputs_jnp)(jnp.array([1., 2., 3.]), jnp.array([10., 5.]))
    pprint(expected)

    def check(fn, unwrap):
      # Fresh inputs for every run, optionally reduced to their raw values.
      x, y = bm.array([1., 2., 3.]), bm.array([10., 5.])
      if unwrap:
        x, y = x.value, y.value
      got = bm.jacrev(fn)(x, y)
      pprint(got)
      assert bm.array_equal(got[0], expected[0])
      assert bm.array_equal(got[1], expected[1])

    for fn in (two_outputs_jnp, two_outputs_bm):
      check(fn, unwrap=True)
      check(fn, unwrap=False)
Exemplo n.º 6
0
    def __init__(self, size, delay, dtype=None, dt=None, **kwargs):
        """Allocate the delay buffer and its read/write cursors.

        Parameters
        ----------
        size : int, tuple of int, list of int
            Shape of the delayed data. An int is wrapped into a 1-tuple and
            the result is stored as the tuple ``self.size``.
        delay : int, float, or 1D array
            A scalar gives one delay shared by all elements. Anything else
            is treated as a per-element delay: it must be 1D and its length
            must equal ``size[0]``.
        dtype : optional
            Data type of the buffer.
        dt : optional
            Time step; ``bm.get_dt()`` is used when omitted.
        **kwargs
            Forwarded to the parent initializer.

        Raises
        ------
        ModelBuildError
            If ``size`` is not an int/tuple/list, or a non-uniform ``delay``
            is not 1D or does not match ``size[0]``.
        NotImplementedError
            If a non-uniform delay is combined with a multi-dimensional size.
        """
        # dt
        self.dt = bm.get_dt() if dt is None else dt

        # data size
        if isinstance(size, int): size = (size, )
        if not isinstance(size, (tuple, list)):
            raise ModelBuildError(
                f'"size" must a tuple/list of int, but we got {type(size)}: {size}'
            )
        self.size = tuple(size)

        # delay time length
        self.delay = delay

        # data and operations
        if isinstance(delay, (int, float)):  # uniform delay
            self.uniform_delay = True
            # one extra slot so that a delay of `k` steps needs `k + 1` rows
            self.num_step = int(pm.ceil(delay / self.dt)) + 1
            self.out_idx = bm.Variable(bm.array([0], dtype=bm.uint32))
            self.in_idx = bm.Variable(
                bm.array([self.num_step - 1], dtype=bm.uint32))
            self.data = bm.Variable(
                bm.zeros((self.num_step, ) + self.size, dtype=dtype))

        else:  # non-uniform delay
            self.uniform_delay = False
            if not len(self.size) == 1:
                raise NotImplementedError(
                    f'Currently, BrainPy only supports 1D heterogeneous '
                    f'delays, while we got the heterogeneous delay with '
                    f'{len(self.size)}-dimensions.')
            # Use the normalised tuple: `size` itself may still be a list.
            self.num = size2len(self.size)
            if bm.ndim(delay) != 1:
                raise ModelBuildError(f'Only support a 1D non-uniform delay. '
                                      f'But we got {delay.ndim}D: {delay}')
            if delay.shape[0] != self.size[0]:
                raise ModelBuildError(
                    f"The first shape of the delay time size must "
                    f"be the same with the delay data size. But "
                    f"we got {delay.shape[0]} != {self.size[0]}")
            delay = bm.around(delay / self.dt)
            self.diag = bm.array(bm.arange(self.num), dtype=bm.int_)
            self.num_step = bm.array(delay, dtype=bm.uint32) + 1
            self.in_idx = bm.Variable(self.num_step - 1)
            self.out_idx = bm.Variable(bm.zeros(self.num, dtype=bm.uint32))
            # `self.size` (a tuple) rather than `size`: concatenating a tuple
            # with a list raises TypeError.
            self.data = bm.Variable(
                bm.zeros((self.num_step.max(), ) + self.size, dtype=dtype))

        super(ConstantDelay, self).__init__(**kwargs)
Exemplo n.º 7
0
  def __init__(self, size, freq, **kwargs):
    """Create the state of the Poisson noise source.

    ``size`` and any extra keyword arguments go to the parent initializer;
    ``freq`` is stored as a one-element variable.
    """
    super(PoissonNoise, self).__init__(size=size, **kwargs)

    freq_values = bm.array([freq])
    self.freq = bm.Variable(freq_values)

    # NOTE(review): presumably converts the global step from ms to s -- confirm.
    self.dt = bm.get_dt() / 1000.

    # One boolean flag per element, initially all False.
    no_spikes = bm.zeros(self.num, dtype=bool)
    self.spike = bm.Variable(no_spikes)

    self.rng = bm.random.RandomState()
Exemplo n.º 8
0
def ramp_input(c_start, c_end, duration, t_start=0, t_end=None, dt=None):
    """Build an input current that changes linearly between two values.

    The trace is zero everywhere except on the step range derived from
    ``t_start`` and ``t_end``, where it is filled with evenly spaced
    values going from ``c_start`` to ``c_end``.

    Parameters
    ----------
    c_start : float
        Current value at the start of the ramp.
    c_end : float
        Current value at the end of the ramp.
    duration : int, float
        The total duration.
    t_start : float
        The ramped current start time-point.
    t_end : float
        The ramped current end time-point. ``duration`` is used when None.
    dt : float, int, optional
        The numerical precision. ``math.get_dt()`` is used when None.

    Returns
    -------
    current : array
        The current trace with ``ceil(duration / dt)`` entries.
    """
    if dt is None:
        dt = math.get_dt()
    if t_end is None:
        t_end = duration

    n_steps = int(np.ceil(duration / dt))
    current = math.zeros(n_steps, dtype=math.float_)

    # Step indices delimiting the ramp.
    first = int(np.ceil(t_start / dt))
    last = int(np.ceil(t_end / dt))

    ramp = math.linspace(c_start, c_end, last - first)
    current[first:last] = math.array(ramp, dtype=math.float_)
    return current
Exemplo n.º 9
0
  def test_jacrev1(self):
    """Compare ``bm.jacrev`` with ``jax.jacrev`` on a four-output function."""
    def target(x, y):
      return jnp.asarray([
        x[0] * y[0],
        5 * x[2] * y[1],
        4 * x[1] ** 2 - 2 * x[2],
        x[2] * jnp.sin(x[0]),
      ])

    def inputs():
      # Fresh arrays for every call.
      return bm.array([1., 2., 3.]), bm.array([10., 5.])

    # Default argnums.
    got = bm.jacrev(target)(*inputs())
    expected = jax.jacrev(target)(*inputs())
    assert (got == expected).all()

    # Both positional arguments.
    got = bm.jacrev(target, argnums=(0, 1))(*inputs())
    expected = jax.jacrev(target, argnums=(0, 1))(*inputs())
    assert (got[0] == expected[0]).all()
    assert (got[1] == expected[1]).all()
Exemplo n.º 10
0
    def test_syn2post_softmax(self):
        """Print ``bm.syn2post_softmax`` output next to a hand-computed softmax.

        Nothing is asserted here; the values are only printed for inspection.
        """
        data = bm.arange(5)
        segment_ids = bm.array([0, 0, 1, 1, 2])
        f_ans = bm.syn2post_softmax(data, segment_ids, 3)

        # Per-segment softmax: segments are {0, 1}, {2, 3} and {4}.
        e = [jnp.exp(data[k]) for k in range(5)]
        true_ans = bm.asarray([
            e[0] / (e[0] + e[1]),
            e[1] / (e[0] + e[1]),
            e[2] / (e[2] + e[3]),
            e[3] / (e[2] + e[3]),
            e[4] / e[4],
        ])

        print()
        print(bm.asarray(f_ans))
        print(true_ans)
        print(f_ans == true_ans)
        # NOTE(review): an array_equal assertion against true_ans was
        # disabled in the original; it stays disabled here.

        # Same data with one more post group than segment ids in use.
        print(bm.syn2post_softmax(bm.arange(5), bm.array([0, 0, 1, 1, 2]), 4))
Exemplo n.º 11
0
  def test_jacrev_aux1(self):
    """Check ``bm.jacrev`` with ``has_aux=True`` against ``jax.jacrev``."""
    x = bm.array([1., 2., 3.])
    y = bm.array([10., 5.])

    def with_aux(x, y):
      aux_term = 4 * x[1] ** 2 - 2 * x[2]
      outputs = jnp.asarray([x[0] * y[0], 5 * x[2] * y[1], aux_term, x[2] * jnp.sin(x[0])])
      return outputs, aux_term

    def outputs_only(*args):
      return with_aux(*args)[0]

    # Default argnums.
    expected = jax.jacrev(outputs_only)(x, y)  # jax jacobian
    pprint(expected)
    grads, aux = bm.jacrev(with_aux, has_aux=True)(x, y)
    assert (grads == expected).all()
    assert aux == (4 * x[1] ** 2 - 2 * x[2])

    # Both positional arguments.
    expected = jax.jacrev(outputs_only, argnums=(0, 1))(x, y)  # jax jacobian
    pprint(expected)
    grads, aux = bm.jacrev(with_aux, argnums=(0, 1), has_aux=True)(x, y)
    assert (grads[0] == expected[0]).all()
    assert (grads[1] == expected[1]).all()
    assert aux == (4 * x[1] ** 2 - 2 * x[2])
Exemplo n.º 12
0
def simulation(duration=5.):
    """Run six Jansen-Rit units with random input and plot three traces each.

    Parameters
    ----------
    duration : float
        Length of the run, passed to ``runner.run``. Plotting starts at
        step ``int(2 / dt)``.
    """
    dt = 0.1 / 1e3
    n_total = int(duration / dt)

    # random input uniformly distributed between 120 and 320 pulses per second
    all_ps = bm.random.uniform(120, 320, size=(n_total, 1))
    coupling = bm.array([68., 128., 135., 270., 675., 1350.])
    jrm = JansenRitModel(num=6, C=coupling)
    runner = bp.StructRunner(
        jrm,
        monitors=['y0', 'y1', 'y2', 'y3', 'y4', 'y5'],
        inputs=['p', all_ps, 'iter', '='],
        dt=dt,
    )
    runner.run(duration)

    start, end = int(2 / dt), int(duration / dt)
    times = runner.mon.ts[start:end]
    fig, gs = bp.visualize.get_figure(6, 3, 2, 3)
    for row in range(6):
        is_first, is_last = row == 0, row == 5
        xlabel = 'time [s]' if is_last else None
        # (column title, monitored trace, extra keyword arguments)
        columns = (
            ('E', runner.mon.y1, {'ylabel': 'Hz'}),
            ('P', runner.mon.y0, {}),
            ('I', runner.mon.y2, {'show': is_last}),
        )
        for col, (label, trace, extra) in enumerate(columns):
            fig.add_subplot(gs[row, col])
            bp.visualize.line_plot(times,
                                   trace[start:end, row],
                                   title=label if is_first else None,
                                   xlabel=xlabel,
                                   **extra)
Exemplo n.º 13
0
def check_and_format_monitors(host, mon):
    """Validate the items of a monitor against ``host`` and format them.

    Every monitor key is either a plain attribute name of ``host`` or a
    dotted path. For a dotted path whose first part is an attribute of
    ``host``, the path is followed with ``getattr`` up to the second-last
    part; otherwise the first part is looked up among the names of the
    nodes of ``host`` and the key must have exactly two parts.

    Parameters
    ----------
    host : DynamicalSystem
        The system the monitor is attached to.
    mon : Monitor
        The monitor providing ``item_names``, ``item_indices`` and
        ``item_intervals``.

    Returns
    -------
    list of tuple
        One ``(key, target, variable, idx, interval)`` tuple per item.

    Raises
    ------
    RunningError
        If a key cannot be resolved, or an interval is given but is not
        a float.
    """
    assert isinstance(host, DynamicalSystem)
    assert isinstance(mon, Monitor)

    formatted_mon_items = []

    # master node:
    #    Check whether the input target node is accessible,
    #    and check whether the target node has the attribute
    name2node = {node.name: node for node in host.nodes().unique().values()}
    for key, idx, interval in zip(mon.item_names, mon.item_indices,
                                  mon.item_intervals):
        # target and variable
        splits = key.split('.')
        if len(splits) == 1:
            if not hasattr(host, splits[0]):
                raise RunningError(f'{host} does not has variable {key}.')
            target = host
            variable = splits[-1]
        else:
            if not hasattr(host, splits[0]):
                if splits[0] not in name2node:
                    raise RunningError(
                        f'Cannot find target {key} in monitor of {host}, please check.'
                    )
                else:
                    target = name2node[splits[0]]
                    assert len(splits) == 2
                    variable = splits[-1]
            else:
                target = host
                for s in splits[:-1]:
                    try:
                        target = getattr(target, s)
                    # getattr signals a missing attribute with AttributeError;
                    # KeyError is kept for objects with dict-backed lookup.
                    except (AttributeError, KeyError) as e:
                        raise RunningError(
                            f'Cannot find {key} in {host}, please check.') from e
                variable = splits[-1]

        # idx
        if isinstance(idx, int): idx = math.array([idx])

        # interval
        if interval is not None:
            if not isinstance(interval, float):
                raise RunningError(
                    f'"interval" must be a float (denotes time), but we got {interval}'
                )

        # append
        formatted_mon_items.append((
            key,
            target,
            variable,
            idx,
            interval,
        ))

    return formatted_mon_items
Exemplo n.º 14
0
 def __init__(self):
   """Run the parent initializer, then attach the three-element array ``x``."""
   super(Test, self).__init__()
   initial_x = [1., 2., 3.]
   self.x = bm.array(initial_x)
Exemplo n.º 15
0
 def test_syn2post_mean(self):
     """Check the per-segment mean of ``0..4`` grouped as {0,1}, {2,3}, {4}."""
     values = bm.arange(5)
     groups = bm.array([0, 0, 1, 1, 2])
     result = bm.syn2post_mean(values, groups, 3)
     expected = bm.asarray([0.5, 2.5, 4.])
     self.assertTrue(bm.array_equal(result, expected))
Exemplo n.º 16
0
 def test_syn2post_prod(self):
     """Check the per-segment product of ``0..4`` grouped as {0,1}, {2,3}, {4}."""
     values = bm.arange(5)
     groups = bm.array([0, 0, 1, 1, 2])
     result = bm.syn2post_prod(values, groups, 3)
     expected = bm.asarray([0, 6, 4])
     self.assertTrue(bm.array_equal(result, expected))