Ejemplo n.º 1
0
  def __init__(self, size, times, indices, need_sort=True, name=None):
    """Create a group that fires the given neurons at the given times.

    Parameters
    ----------
    size : int, tuple of int
      Size of the group; forwarded to the parent constructor.
    times : sequence of float
      Spike times. Entry ``k`` pairs with ``indices[k]``.
    indices : sequence of int
      Indices of the neurons that fire, one per entry of ``times``.
    need_sort : bool
      If True, reorder ``times`` and ``indices`` by ascending time. The
      while-loop built below only moves past an entry once ``t`` has reached
      its time, so it relies on the arrays being time-ordered.
    name : str, optional
      Name forwarded to the parent constructor.

    Raises
    ------
    ModelBuildError
      If ``indices`` and ``times`` have different lengths.
    """
    super(SpikeTimeInput, self).__init__(size=size, name=name)

    # parameters
    if len(indices) != len(times):
      raise ModelBuildError(f'The length of "indices" and "times" must be the same. '
                            f'However, we got {len(indices)} != {len(times)}.')
    self.num_times = len(times)

    # data about times and indices
    # ``i`` is a one-element cursor pointing at the next (time, index) entry
    # that has not fired yet.
    self.i = bm.Variable(bm.zeros(1, dtype=bm.int_))
    self.times = bm.Variable(bm.asarray(times, dtype=bm.float_))
    self.indices = bm.Variable(bm.asarray(indices, dtype=bm.int_))
    # One boolean spike flag per neuron of the group (``self.num`` entries).
    self.spike = bm.Variable(bm.zeros(self.num, dtype=bool))
    if need_sort:
      sort_idx = bm.argsort(times)
      self.indices.value = self.indices[sort_idx]
      self.times.value = self.times[sort_idx]

    # functions
    # Keep looping while unfired entries remain and the next one is due (t >= time).
    def cond_fun(t):
      return bm.logical_and(self.i[0] < self.num_times, t >= self.times[self.i[0]])

    # Mark the neuron of the current entry as spiking and advance the cursor.
    def body_fun(t):
      self.spike[self.indices[self.i[0]]] = True
      self.i[0] += 1

    # ``self.vars()`` is collected at this point, so every Variable the loop
    # touches must already be defined above. Clearing ``self.spike`` between
    # steps is not done in this constructor (presumably in the update method;
    # confirm).
    self._run = bm.make_while(cond_fun, body_fun, dyn_vars=self.vars())
Ejemplo n.º 2
0
 def __call__(self, shape, dtype=None):
     """Sample weights whose variance is ``scale / fan``.

     ``fan`` is the fan-in, the fan-out or their average, depending on
     ``self.mode``. Samples are drawn from the distribution named by
     ``self.distribution`` and rescaled to that variance.

     Parameters
     ----------
     shape : sequence
       Weight shape; each entry is converted with ``size2len``.
     dtype : optional
       Data type of the returned array.

     Raises
     ------
     ValueError
       If ``self.mode`` or ``self.distribution`` is not recognised.
     """
     dims = [size2len(dim) for dim in shape]
     fan_in, fan_out = _compute_fans(dims,
                                     in_axis=self.in_axis,
                                     out_axis=self.out_axis)

     mode = self.mode
     if mode not in ("fan_in", "fan_out", "fan_avg"):
         raise ValueError(
             "invalid mode for variance scaling initializer: {}".format(mode))
     denominator = {"fan_in": fan_in,
                    "fan_out": fan_out}.get(mode, (fan_in + fan_out) / 2)

     variance = math.array(self.scale / denominator, dtype=dtype)
     kind = self.distribution
     if kind == "truncated_normal":
         # 0.8796... is the std of a standard normal truncated to (-2, 2);
         # dividing by it restores the requested variance after truncation.
         stddev = math.sqrt(variance) / math.array(.87962566103423978, dtype)
         samples = self.rng.truncated_normal(-2, 2, dims) * stddev
     elif kind == "normal":
         samples = self.rng.normal(size=dims) * math.sqrt(variance)
     elif kind == "uniform":
         # U(-1, 1) has variance 1/3, hence the factor 3.
         samples = self.rng.uniform(low=-1, high=1,
                                    size=dims) * math.sqrt(3 * variance)
     else:
         raise ValueError(
             "invalid distribution for variance scaling initializer")
     return math.asarray(samples, dtype=dtype)
Ejemplo n.º 3
0
    def build_monitors(self, show_code=False):
        """Generate the function that reads every monitored value at one step.

        The monitors are normalized by ``utils.check_and_format_monitors``.
        Source code of the form ``return <expr1>, <expr2>, `` is then compiled
        by ``tools.code_lines_to_func`` into a function named
        ``_mon_func_name`` with arguments ``(_t, _dt)``.

        Parameters
        ----------
        show_code : bool
          If True, print the generated code and the names available to it.

        Returns
        -------
        func : callable
          ``func(_t, _dt)`` returning a tuple with one entry per monitor, or a
          function that returns ``None`` when there is nothing to monitor.

        Raises
        ------
        RunningError
          If a monitored attribute is not a ``math.Variable``.
        ValueError
          If a monitor specifies an ``interval`` (not supported by this runner).
        """
        monitors = utils.check_and_format_monitors(host=self.target,
                                                   mon=self.mon)

        returns = []
        code_lines = []
        host = self.target
        # Names visible to the generated code.
        code_scope = dict(sys=sys)
        for key, target, variable, idx, interval in monitors:
            code_scope[host.name] = host
            code_scope[target.name] = target

            # get data
            # Resolve a dotted path such as "a.b.V" attribute by attribute.
            data = target
            for k in variable.split('.'):
                data = getattr(data, k)

            # get the data key in the host
            if not isinstance(data, math.Variable):
                raise RunningError(
                    f'"{key}" in {target} is not a dynamically changed Variable, '
                    f'its value will not change, we think there is no need to '
                    f'monitor its trajectory.')
            # Variables that are not 1-D are flattened in the generated expression.
            if math.ndim(data) == 1:
                key_in_host = f'{target.name}.{variable}.value'
            else:
                key_in_host = f'{target.name}.{variable}.value.flatten()'

            # format the monitor index
            # A selected index array is exposed to the generated code as
            # ``_<key with dots replaced by underscores>_idx``.
            if idx is None:
                right = key_in_host
            else:
                idx = math.asarray(idx)
                right = f'{key_in_host}[_{key.replace(".", "_")}_idx]'
                code_scope[f'_{key.replace(".", "_")}_idx'] = idx.value

            # format the monitor lines according to the time interval
            returns.append(right)
            if interval is not None:
                raise ValueError(
                    f'Running with "{self.__class__.__name__}" does '
                    f'not support "interval" in the monitor.')

        if len(code_lines) or len(returns):
            # The trailing ", " makes the generated function return a tuple,
            # even when there is a single monitor.
            code_lines.append(f'return {", ".join(returns) + ", "}')
            # function
            # Snapshot the scope before code generation, for ``show_code``.
            # Presumably ``code_lines_to_func`` may add entries to
            # ``code_scope``; confirm.
            code_scope_old = {k: v for k, v in code_scope.items()}
            code, func = tools.code_lines_to_func(lines=code_lines,
                                                  func_name=_mon_func_name,
                                                  func_args=['_t', '_dt'],
                                                  scope=code_scope)
            if show_code:
                print(code)
                print()
                pprint(code_scope_old)
                print()
        else:
            # Nothing to monitor.
            func = lambda _t, _dt: None
        return func
Ejemplo n.º 4
0
 def get_stimulus_by_pos(self, pos):
     """Return a Gaussian-shaped bump centred at ``pos`` on the 2-D grid.

     The grid is ``self.x`` x ``self.x``. Each grid point's offset from
     ``pos`` is passed through ``self.dist`` and reduced to a Euclidean
     norm ``d``. The stimulus is ``self.A * exp(-0.25 * (d / self.a) ** 2)``,
     reshaped to ``(self.length, self.length)``.

     Parameters
     ----------
     pos : array-like
       Two coordinates (checked with an assertion).
     """
     assert bm.size(pos) == 2
     grid_a, grid_b = bm.meshgrid(self.x, self.x)
     coords = bm.stack([grid_a.flatten(), grid_b.flatten()]).T
     offsets = bm.abs(bm.asarray(pos) - coords)
     distances = bm.linalg.norm(self.dist(offsets), axis=1)
     dist_grid = distances.reshape((self.length, self.length))
     return self.A * bm.exp(-0.25 * bm.square(dist_grid / self.a))
Ejemplo n.º 5
0
    def __init__(self,
                 target,
                 monitors=None,
                 inputs=(),
                 dyn_vars=None,
                 jit=False,
                 dt=None,
                 numpy_mon_after_run=True,
                 progress_bar=True):
        """Build a runner whose per-step loop is compiled with ``math.make_loop``.

        Parameters
        ----------
        target
          The model to run; every callable in ``target.steps`` is called at
          each time step.
        monitors, inputs, dyn_vars, dt, numpy_mon_after_run
          Forwarded to the parent runner constructor.
        jit : bool
          If True, JIT-compile the step loop over ``self.dyn_vars``.
        progress_bar : bool
          If True, the step loop advances ``self._pbar`` through ``id_tap``.
        """
        # Default: no iterable input array. Presumably the parent constructor
        # sets this flag while formatting ``inputs``; confirm.
        self._has_iter_array = False

        super(StructRunner,
              self).__init__(target=target,
                             inputs=inputs,
                             monitors=monitors,
                             jit=jit,
                             dt=dt,
                             dyn_vars=dyn_vars,
                             numpy_mon_after_run=numpy_mon_after_run)

        # intrinsic parameters
        self._i = math.Variable(math.asarray([0]))
        self._pbar = None  # progress bar; not created in this constructor
        self.progress_bar = progress_bar

        # JAX does not support iterator in fori_loop, scan, etc.
        #   https://github.com/google/jax/issues/3567
        # We use Variable i to index the current input data.
        if self._has_iter_array:
            self.dyn_vars.update({'_i': self._i})
        else:
            self._i = None

        # setup step function
        def _step(t_and_dt):
            _t, _dt = t_and_dt[0], t_and_dt[1]
            self._input_step(_t=_t, _dt=_dt)
            for step in self.target.steps.values():
                step(_t=_t, _dt=_dt)
            if progress_bar:
                # ``progress_bar`` is a plain Python bool, so this branch is
                # decided once when the loop is traced. The tap calls back
                # into Python to advance the bar by one step.
                id_tap(lambda *args: self._pbar.update(), ())
            return self._monitor_step(_t=_t, _dt=_dt)

        # build the update step
        self._step = math.make_loop(_step,
                                    dyn_vars=self.dyn_vars,
                                    has_return=True)
        if jit:
            # Compile over the same variable collection the loop was built
            # with. The raw ``dyn_vars`` argument may be None (its default)
            # and never contains the ``_i`` entry added above.
            self._step = math.jit(self._step, dyn_vars=self.dyn_vars)
Ejemplo n.º 6
0
  def _return_by_ij(self, structures, ij: tuple, all_data: dict):
    """Fill ``all_data`` with the requested structures derived from index pairs.

    ``ij`` is a ``(pre_ids, post_ids)`` pair of NumPy arrays. The connection
    matrix and the raw pre/post id arrays are built here. Any other requested
    structure is delegated to ``self._return_by_csr`` after a CSR conversion.
    Entries already present in ``all_data`` are never overwritten.
    """
    pre_ids, post_ids = ij
    assert isinstance(pre_ids, np.ndarray)
    assert isinstance(post_ids, np.ndarray)

    def _wanted(name):
      # requested by the caller and not produced by an earlier stage
      return name in structures and name not in all_data

    if _wanted(CONN_MAT):
      all_data[CONN_MAT] = math.asarray(ij2mat(ij, self.pre_num, self.post_num), dtype=MAT_DTYPE)
    if _wanted(PRE_IDS):
      all_data[PRE_IDS] = math.asarray(pre_ids, dtype=IDX_DTYPE)
    if _wanted(POST_IDS):
      all_data[POST_IDS] = math.asarray(post_ids, dtype=IDX_DTYPE)

    handled_here = (CONN_MAT, PRE_IDS, POST_IDS)
    if any(name not in handled_here for name in structures):
      csr = ij2csr(pre_ids, post_ids, self.pre_num)
      self._return_by_csr(structures, csr=csr, all_data=all_data)
Ejemplo n.º 7
0
  def _return_by_mat(self, structures, mat, all_data: dict):
    """Fill ``all_data`` with the requested structures derived from a dense matrix.

    ``mat`` must be a 2-D NumPy array. The matrix itself is stored under
    ``CONN_MAT`` when requested. If anything else is requested, the positive
    entries are turned into ``(pre_ids, post_ids)`` index pairs and handed to
    ``self._return_by_ij``.
    """
    assert isinstance(mat, np.ndarray) and np.ndim(mat) == 2
    if CONN_MAT in structures and CONN_MAT not in all_data:
      all_data[CONN_MAT] = math.asarray(mat, dtype=MAT_DTYPE)

    if any(name != CONN_MAT for name in structures):
      rows, cols = np.where(mat > 0)
      pairs = (np.ascontiguousarray(rows, dtype=IDX_DTYPE),
               np.ascontiguousarray(cols, dtype=IDX_DTYPE))
      self._return_by_ij(structures, ij=pairs, all_data=all_data)
Ejemplo n.º 8
0
 def get_param(param, size):
     """Wrap ``param`` as a trainable variable.

     Parameters
     ----------
     param : None, callable, numpy.ndarray, bm.JaxArray or jax array
       ``None`` disables the parameter. A callable is treated as an
       initializer and called with ``size``. Arrays are wrapped directly.
     size : tuple of int, list of int, int
       Expected shape; only checked for NumPy arrays.

     Returns
     -------
     bm.TrainVar or None

     Raises
     ------
     ValueError
       If ``param`` has an unsupported type.
     """
     if param is None:
         return None
     if callable(param):
         return bm.TrainVar(param(size))
     if isinstance(param, onp.ndarray):
         # Accept ``size`` given as a list or a bare int as well as a tuple;
         # ``ndarray.shape`` is always a tuple.
         if isinstance(size, int):
             expected = (size,)
         elif isinstance(size, list):
             expected = tuple(size)
         else:
             expected = size
         assert param.shape == expected, (
             f'Expected a parameter with shape {expected}, but got {param.shape}.')
         return bm.TrainVar(bm.asarray(param))
     if isinstance(param, (bm.JaxArray, jnp.ndarray)):
         return bm.TrainVar(param)
     raise ValueError(f'Unsupported parameter type {type(param)}. Expected None, '
                      f'a callable initializer, or an array.')
Ejemplo n.º 9
0
    def compute_jacobians(self, points):
        """Compute the Jacobian matrices at the given points.

        Parameters
        ----------
        points: np.ndarray, bm.JaxArray, jax.ndarray
          Points with shape ``(num_point, num_dim)``. A single 1-D point is
          accepted too and treated as a batch of one.

        Returns
        -------
        jacobians : bm.JaxArray
          One Jacobian per point, shape ``(num_point, num_dim, num_dim)``.
        """
        batch = bm.asarray([points]) if bm.ndim(points) == 1 else points
        assert bm.ndim(batch) == 2
        return self.f_jacob_batch(bm.asarray(batch))
Ejemplo n.º 10
0
    def test_syn2post_softmax(self):
        """``syn2post_softmax`` must apply a softmax independently per segment."""
        data = bm.arange(5)
        segment_ids = bm.array([0, 0, 1, 1, 2])
        f_ans = bm.syn2post_softmax(data, segment_ids, 3)
        # Expected: softmax within segments {0, 1}, {2, 3} and {4}; the
        # single-element segment gets the value 1.
        true_ans = bm.asarray([
            jnp.exp(data[0]) / (jnp.exp(data[0]) + jnp.exp(data[1])),
            jnp.exp(data[1]) / (jnp.exp(data[0]) + jnp.exp(data[1])),
            jnp.exp(data[2]) / (jnp.exp(data[2]) + jnp.exp(data[3])),
            jnp.exp(data[3]) / (jnp.exp(data[2]) + jnp.exp(data[3])),
            jnp.exp(data[4]) / jnp.exp(data[4])
        ])
        # The results come from floating-point exp/div, so compare with a
        # tolerance instead of exact equality. The previous version printed
        # the values and asserted nothing.
        self.assertTrue(bm.allclose(f_ans, true_ans))

        # Smoke check: more segments than used ids (segment 3 is empty)
        # must still run.
        bm.syn2post_softmax(data, segment_ids, 4)
Ejemplo n.º 11
0
def check_initials(initials, target_var_names):
  """Validate initial values and convert them to float arrays.

  Parameters
  ----------
  initials : dict
    Must contain an entry for every name in ``target_var_names``.
  target_var_names : iterable of str
    Names of the variables to take from ``initials``.

  Returns
  -------
  dict
    Maps each target name to ``bm.asarray(initials[name], dtype=bm.float_)``.
    Raises AssertionError if a check fails, including when the converted
    values do not all share one length.
  """
  assert isinstance(initials, dict)
  assert all(name in initials for name in target_var_names)
  converted = {name: bm.asarray(initials[name], dtype=bm.float_)
               for name in target_var_names}
  lengths = []
  for value in converted.values():
    assert isinstance(value, (tuple, list, np.ndarray, jnp.ndarray, bm.ndarray))
    lengths.append(len(value))
  # every initial array must have the same length
  assert np.unique(lengths).size == 1
  return converted
Ejemplo n.º 12
0
  def make_returns(self, structures, csr=None, mat=None, ij=None):
    """Make the desired synaptic structures and return them.

    Parameters
    ----------
    structures : str, tuple of str, list of str
      Name(s) of the structures to build.
    csr : tuple of np.ndarray, optional
      A pair of index arrays; stored as-is under ``PRE2POST`` when requested.
    mat : np.ndarray, optional
      A 2-D connection matrix.
    ij : tuple of np.ndarray, optional
      ``(pre_ids, post_ids)`` index arrays.

    Returns
    -------
    The requested structure when a single name is given, otherwise a tuple
    of structures in the order of ``structures``.

    Raises
    ------
    ConnectorError
      If none of ``csr``, ``mat`` and ``ij`` is provided.
    """
    # checking
    all_data = dict()
    if (csr is None) and (mat is None) and (ij is None):
      raise ConnectorError('Must provide one of "csr", "mat" or "ij".')
    structures = (structures,) if isinstance(structures, str) else structures
    assert isinstance(structures, (tuple, list))

    # "csr" structure
    if csr is not None:
      assert isinstance(csr[0], np.ndarray)
      assert isinstance(csr[1], np.ndarray)
      if (PRE2POST in structures) and (PRE2POST not in all_data):
        all_data[PRE2POST] = (math.asarray(csr[0], dtype=IDX_DTYPE),
                              math.asarray(csr[1], dtype=IDX_DTYPE))
      self._return_by_csr(structures, csr=csr, all_data=all_data)
    # "mat" structure
    if mat is not None:
      assert isinstance(mat, np.ndarray) and np.ndim(mat) == 2
      if (CONN_MAT in structures) and (CONN_MAT not in all_data):
        all_data[CONN_MAT] = math.asarray(mat, dtype=MAT_DTYPE)
      self._return_by_mat(structures, mat=mat, all_data=all_data)
    # "ij" structure
    if ij is not None:
      assert isinstance(ij[0], np.ndarray)
      assert isinstance(ij[1], np.ndarray)
      # Fill only requested entries that the "csr"/"mat" branches have not
      # produced already. (These checks previously tested ``structures``
      # twice, which is always False, so the branches were dead code.)
      if (PRE_IDS in structures) and (PRE_IDS not in all_data):
        all_data[PRE_IDS] = math.asarray(ij[0], dtype=IDX_DTYPE)
      if (POST_IDS in structures) and (POST_IDS not in all_data):
        all_data[POST_IDS] = math.asarray(ij[1], dtype=IDX_DTYPE)
      self._return_by_ij(structures, ij=ij, all_data=all_data)

    # return
    if len(structures) == 1:
      return all_data[structures[0]]
    else:
      return tuple([all_data[n] for n in structures])
Ejemplo n.º 13
0
def load_npz(filename, target, verbose=False, check=False):
  """Load variable values from an ``.npz`` archive into ``target``.

  Parameters
  ----------
  filename : str
    Path to the ``.npz`` file.
  target : Base
    Object whose variables, ``target.vars(method='relative')``, are filled
    in place.
  verbose : bool
    Print each key as it is loaded.
  check : bool
    If True, pass the variables not found in the file to ``_check_missing``.

  Notes
  -----
  A key in the file that matches no variable of ``target`` makes
  ``all_vars.pop(key)`` fail (KeyError for a dict-like collection).
  """
  global math, Base
  # Lazy imports; presumably to avoid circular imports at module load.
  if Base is None: from brainpy.base.base import Base
  if math is None: from brainpy import math
  assert isinstance(target, Base)

  all_vars = target.vars(method='relative')
  # ``np.load`` on an ``.npz`` returns an ``NpzFile`` that keeps the file
  # open. The context manager closes it, even if an assignment fails.
  with np.load(filename) as all_data:
    for key in all_data.files:
      if verbose: print(f'Loading {key} ...')
      var = all_vars.pop(key)
      var[:] = math.asarray(all_data[key])
  if check: _check_missing(all_vars, filename=filename)
Ejemplo n.º 14
0
    def __init__(self,
                 num_input,
                 num_hidden,
                 num_output,
                 num_batch,
                 dt=None,
                 e_ratio=0.8,
                 sigma_rec=0.,
                 seed=None,
                 w_ir=bp.init.KaimingUniform(scale=1.),
                 w_rr=bp.init.KaimingUniform(scale=1.),
                 w_ro=bp.init.KaimingUniform(scale=1.)):
        """Build an excitatory/inhibitory rate RNN.

        Parameters
        ----------
        num_input, num_hidden, num_output : int
          Sizes of the input, recurrent and readout layers.
        num_batch : int
          Batch size of the state variables ``h`` and ``o``.
        dt : float, optional
          Integration step. ``alpha = dt / tau`` with ``tau = 100``;
          ``alpha = 1`` when ``dt`` is None.
        e_ratio : float
          Fraction of excitatory hidden units,
          ``e_size = int(num_hidden * e_ratio)``.
        sigma_rec : float
          Recurrent noise level, scaled by ``sqrt(2 * alpha)``.
        seed : int, optional
          Seed of ``self.rng`` (used here for the bias initialization).
        w_ir, w_rr, w_ro
          Input, recurrent and readout weight specifications passed to
          ``self.get_param``. NOTE: the default initializer objects are
          created once, at definition time, and shared between instances.

        Raises
        ------
        ValueError
          If the excitatory or the inhibitory population would be empty.
        """
        super(RNN, self).__init__()

        # parameters
        self.tau = 100
        self.num_batch = num_batch
        self.num_input = num_input
        self.num_hidden = num_hidden
        self.num_output = num_output
        self.e_size = int(num_hidden * e_ratio)
        self.i_size = num_hidden - self.e_size
        # Both populations must be non-empty. The E/I rescaling of ``w_rr``
        # divides by ``i_size`` and the readout bias bound divides by
        # ``sqrt(e_size)``, so fail early with a clear message instead of a
        # ZeroDivisionError halfway through initialization.
        if self.e_size <= 0 or self.i_size <= 0:
            raise ValueError(
                f'Both the excitatory and the inhibitory population must be '
                f'non-empty, but got e_size={self.e_size} and '
                f'i_size={self.i_size} (num_hidden={num_hidden}, '
                f'e_ratio={e_ratio}).')
        if dt is None:
            self.alpha = 1
        else:
            self.alpha = dt / self.tau
        self.sigma_rec = (2 * self.alpha)**0.5 * sigma_rec  # Recurrent noise
        self.rng = bm.random.RandomState(seed=seed)

        # hidden mask: +1 for the first ``e_size`` columns, -1 for the
        # remaining ``i_size`` columns, and 0 on the diagonal
        mask = np.tile([1] * self.e_size + [-1] * self.i_size, (num_hidden, 1))
        np.fill_diagonal(mask, 0)
        self.mask = bm.asarray(mask, dtype=bm.float_)

        # input weight
        self.w_ir = self.get_param(w_ir, (num_input, num_hidden))

        # recurrent weight
        bound = 1 / num_hidden**0.5
        self.w_rr = self.get_param(w_rr, (num_hidden, num_hidden))
        # Shrink the excitatory columns by the E/I size ratio (presumably to
        # balance total excitation against inhibition; confirm).
        self.w_rr[:, :self.e_size] /= (self.e_size / self.i_size)
        self.b_rr = bm.TrainVar(self.rng.uniform(-bound, bound, num_hidden))

        # readout weight: ``w_ro`` has ``e_size`` rows, i.e. presumably only
        # the excitatory units are read out
        bound = 1 / self.e_size**0.5
        self.w_ro = self.get_param(w_ro, (self.e_size, num_output))
        self.b_ro = bm.TrainVar(self.rng.uniform(-bound, bound, num_output))

        # variables
        self.h = bm.Variable(bm.zeros((num_batch, num_hidden)))
        self.o = bm.Variable(bm.zeros((num_batch, num_output)))
Ejemplo n.º 15
0
    def __call__(self, duration, start_t=None):
        """The running function.

        Parameters
        ----------
        duration : float, int
          The running duration. It is added to ``start_t`` and converted with
          ``float`` for display, so it must be a scalar.
        start_t : float, optional
          The start simulation time. Defaults to the time where the previous
          run ended, or ``0.`` for the first run.

        Returns
        -------
        running_time : float
          Wall-clock seconds spent in the step loop (measured with
          ``time.time``).
        """
        # time step
        if start_t is None:
            start_t = 0. if self._start_t is None else self._start_t
        end_t = start_t + duration

        # times
        times = math.arange(start_t, end_t, self.dt)

        # reset the monitor buffers before this run
        for key in self.mon.item_contents.keys():
            self.mon.item_contents[key] = []

        # simulations
        t0 = time.time()
        # Used as a context manager so the bar is closed even if a step raises.
        with tqdm.auto.tqdm(total=times.size) as pbar:
            pbar.set_description(
                f"Running a duration of {round(float(duration), 3)} ({times.size} steps)",
                refresh=True)
            for run_idx in range(times.size):
                self._step((times[run_idx], self.dt))
                pbar.update()
        running_time = time.time() - t0

        # monitor post steps
        self.mon.ts = times
        for key, val in self.mon.item_contents.items():
            self.mon.item_contents[key] = math.asarray(val)
        self._start_t = end_t
        if self.numpy_mon_after_run:
            self.mon.numpy()
        return running_time
Ejemplo n.º 16
0
 def _loop_func(t_and_dt):
     """Run ``self._step`` once per ``(t, dt)`` pair and record the monitors.

     ``t_and_dt`` is a ``(times, dts)`` pair of equal-length sequences. After
     every step, the current value of each monitored item is read from
     ``self.variables``. The result maps each monitor item name to the
     stacked history of its values.
     """
     history = {name: [] for name in self.mon.item_names}
     times, dts = t_and_dt
     for step_idx in range(len(times)):
         self._step([times[step_idx], dts[step_idx]])
         for name in self.mon.item_names:
             history[name].append(math.as_device_array(self.variables[name]))
     return {name: math.asarray(history[name]) for name in self.mon.item_names}
Ejemplo n.º 17
0
def load_h5(filename, target, verbose=False, check=False):
  """Load variable values from an HDF5 file into ``target``.

  Every top-level key of the file is matched against
  ``target.vars(method='relative')``, and the variable is filled in place.

  Parameters
  ----------
  filename : str
    Path to the HDF5 file.
  target : Base
    Object whose variables are filled.
  verbose : bool
    Print each key as it is loaded.
  check : bool
    If True, pass the variables not found in the file to ``_check_missing``.
  """
  global math, Base
  # Lazy imports; presumably to avoid circular imports at module load.
  if Base is None: from brainpy.base.base import Base
  if math is None: from brainpy import math
  assert isinstance(target, Base)
  # NOTE(review): ``os.path.splitext`` returns a ``(root, ext)`` tuple;
  # confirm whether ``_check`` expects the whole tuple or only the extension.
  _check(h5py, module_name='h5py', ext=os.path.splitext(filename))

  all_vars = target.vars(method='relative')
  # The context manager closes the file, even if an assignment below fails.
  with h5py.File(filename, "r") as f:
    for key in f.keys():
      if verbose: print(f'Loading {key} ...')
      var = all_vars.pop(key)
      var[:] = math.asarray(f[key][:])
  if check: _check_missing(all_vars, filename=filename)
Ejemplo n.º 18
0
def load_pkl(filename, target, verbose=False, check=False):
  """Load variable values from a pickle file into ``target``.

  The pickle must hold a mapping from relative variable names (as in
  ``target.vars(method='relative')``) to array-like values.

  Parameters
  ----------
  filename : str
    Path to the pickle file.
  target : Base
    Object whose variables are filled in place.
  verbose : bool
    Print each key as it is loaded.
  check : bool
    If True, pass the variables not found in the file to ``_check_missing``.

  Warnings
  --------
  ``pickle.load`` can execute arbitrary code. Only load files from trusted
  sources.
  """
  global math, Base
  # Lazy imports; presumably to avoid circular imports at module load.
  if Base is None: from brainpy.base.base import Base
  if math is None: from brainpy import math
  assert isinstance(target, Base)
  # SECURITY: unpickling untrusted data can execute arbitrary code.
  # The context manager closes the file even if unpickling fails.
  with open(filename, 'rb') as f:
    all_data = pickle.load(f)

  all_vars = target.vars(method='relative')
  for key, data in all_data.items():
    if verbose: print(f'Loading {key} ...')
    var = all_vars.pop(key)
    var[:] = math.asarray(data)
  if check: _check_missing(all_vars, filename=filename)
    def __init__(self, pre, post, conn_prob=0.1):
        """Connect ``pre`` to ``post`` with fixed-probability random synapses.

        Parameters
        ----------
        pre
          Presynaptic group; it must provide a ``spike`` attribute.
        post
          Postsynaptic group; it must provide an ``I`` attribute.
        conn_prob : float
          Connection probability passed to ``bp.conn.FixedProb``.
        """
        connector = bp.conn.FixedProb(conn_prob)
        super(ThalamusInput, self).__init__(pre=pre, post=post, conn=connector)
        self.check_pre_attrs('spike')
        self.check_post_attrs('I')

        # Synapse structure, plus one normally distributed weight per synapse
        # (the parameters come from ``ExpSyn.exc_weight``). The synapse count
        # is taken from ``pre2post[0]``.
        self.pre2post = self.conn.require('pre2post')
        self.syn_num = self.pre2post[0].size
        raw_weights = bm.random.normal(*ExpSyn.exc_weight, size=self.syn_num)
        # negative draws are clipped to zero
        self.weights = bm.where(raw_weights < 0., 0., raw_weights)

        # variables
        self.turn_on = bm.Variable(bm.asarray([False]))
Ejemplo n.º 20
0
def d4_system():
    """Simulate two gap-junction-coupled FHN units and analyse their fixed points.

    The coupled model (``gjw = 0.01``, ``Iext = [0., 0.6]``) is run for 300
    time units and its ``V`` traces are plotted. The 4-D ``(dV, dw)`` vector
    field is then handed to ``bp.analysis.SlowPointFinder``: fixed points are
    searched from 1000 Gaussian candidates, filtered with
    ``filter_loss(1e-7)`` and de-duplicated. Finally, the Jacobian
    eigenvalues at every remaining fixed point are scattered in the complex
    plane.
    """
    model = GJCoupledFHN(2)
    model.gjw = 0.01
    # Another input setting that was tried: bm.asarray([0., 0.1])
    Iext = bm.asarray([0., 0.6])

    # --- simulation ---
    runner = bp.StructRunner(model, monitors=['V'], inputs=['Iext', Iext])
    runner.run(300.)
    bp.visualize.line_plot(runner.mon.ts,
                           runner.mon.V,
                           legend='V',
                           plot_ids=list(range(model.num)),
                           show=True)

    # --- fixed-point analysis ---
    def step(vw):
        v, w = bm.split(vw, 2)
        return bm.concatenate([model.dV(v, 0., w, Iext), model.dw(w, 0., v)])

    finder = bp.analysis.SlowPointFinder(f_cell=step)
    # A gradient-descent search (``finder.find_fps_with_gd_method``) could be
    # used here instead of the optimizer-based solver.
    candidates = bm.random.normal(0., 2., (1000, model.num * 2))
    finder.find_fps_with_opt_solver(candidates=candidates)
    finder.filter_loss(1e-7)
    finder.keep_unique()

    print('fixed_points: ', finder.fixed_points)
    print('losses:', finder.losses)
    if not len(finder.fixed_points):
        return

    jacobians = finder.compute_jacobians(finder.fixed_points)
    for idx in range(len(finder.fixed_points)):
        eigvals, _ = np.linalg.eig(np.asarray(jacobians[idx]))
        plt.figure()
        plt.scatter(np.real(eigvals), np.imag(eigvals))
        plt.plot([0, 0], [-1, 1], '--')  # imaginary axis as a reference line
        plt.xlabel('Real')
        plt.ylabel('Imaginary')
        plt.title(f'FP {idx}')
        plt.show()
Ejemplo n.º 21
0
 def __call__(self, shape, dtype=None):
     """Sample a (semi-)orthogonal weight array of the given shape.

     The array is viewed as a matrix with ``shape[self.axis]`` rows. The
     orthogonal factor of a QR decomposition of a Gaussian matrix is taken,
     reshaped back and scaled by ``self.scale``.
     """
     dims = [size2len(d) for d in shape]
     n_rows = dims[self.axis]
     n_cols = np.prod(dims) // n_rows
     # QR is always applied to the "tall" orientation and transposed back
     # afterwards if needed.
     tall = n_rows > n_cols
     gaussian = self.rng.normal(size=(n_rows, n_cols) if tall else (n_cols, n_rows))
     q_mat, r_mat = np.linalg.qr(gaussian)
     # Enforce Q is uniformly distributed
     q_mat *= np.sign(np.diag(r_mat))
     if n_rows < n_cols:
         q_mat = q_mat.T
     other_dims = tuple(np.delete(dims, self.axis))
     q_mat = np.moveaxis(np.reshape(q_mat, (n_rows, ) + other_dims), 0, self.axis)
     return self.scale * math.asarray(q_mat, dtype=dtype)
Ejemplo n.º 22
0
 def dist(self, d):
     """Wrap coordinate differences onto a periodic domain of size ``z_range``.

     Each component of ``d`` larger than half of ``self.z_range`` is replaced
     by ``self.z_range - d``, i.e. the shorter way around. The same period is
     used for both components (``d`` is presumably ``(..., 2)``; confirm).
     """
     period = bm.asarray([self.z_range, self.z_range])
     half_period = period / 2
     return bm.where(d > half_period, period - d, d)
Ejemplo n.º 23
0
 def __call__(self, shape, dtype=None):
     """Draw weights uniformly between ``self.min_val`` and ``self.max_val``.

     Each entry of ``shape`` is converted with ``size2len``. The result is
     cast with ``math.asarray(..., dtype=dtype)``.
     """
     dims = [size2len(dim) for dim in shape]
     samples = self.rng.uniform(low=self.min_val, high=self.max_val, size=dims)
     return math.asarray(samples, dtype=dtype)
Ejemplo n.º 24
0
 def __call__(self, shape, dtype=None):
     """Draw normally distributed weights with standard deviation ``self.scale``.

     Each entry of ``shape`` is converted with ``size2len``. The result is
     cast with ``math.asarray(..., dtype=dtype)``.
     """
     dims = [size2len(dim) for dim in shape]
     samples = self.rng.normal(size=dims, scale=self.scale)
     return math.asarray(samples, dtype=dtype)
Ejemplo n.º 25
0
# training function
@bm.jit
@bm.function(nodes=(net, opt))
def train(xs, ys):
    """Run one optimisation step on a batch and return ``(loss, outputs)``.

    ``grad_f`` returns the gradients together with an auxiliary
    ``(loss, outputs)`` pair; the gradients are applied through ``opt``.
    """
    grads, aux = grad_f(xs, ys)
    loss, outputs = aux
    opt.update(grads)
    return loss, outputs


# %%
# Train for ``num_steps`` batches and report the average loss and accuracy
# every ``report_every`` steps. Accuracy only counts time points whose target
# is not fixation (label > 0).
num_steps = 1500
report_every = 100
running_acc = 0.
running_loss = 0.
acc_batches = 0  # batches that actually contained non-fixation targets
for i in range(num_steps):
    inputs, labels_np = dataset()
    inputs = bm.asarray(inputs)
    labels = bm.asarray(labels_np)
    loss, outputs = train(inputs, labels)
    # Accumulate a Python float rather than a device array.
    running_loss += float(loss)
    # Compute performance
    output_np = np.argmax(outputs.numpy(), axis=-1).flatten()
    labels_np = labels_np.flatten()
    ind = labels_np > 0  # Only analyze time points when target is not fixation
    # A batch without any non-fixation target would give mean([]) == nan and
    # poison the running accuracy, so such batches are skipped.
    if ind.any():
        running_acc += np.mean(labels_np[ind] == output_np[ind])
        acc_batches += 1
    if i % report_every == report_every - 1:
        running_loss /= report_every
        running_acc = running_acc / acc_batches if acc_batches else float('nan')
        print('Step {}, Loss {:0.4f}, Acc {:0.3f}'.format(
            i + 1, running_loss, running_acc))
        running_loss = 0.
        running_acc = 0.
        acc_batches = 0
Ejemplo n.º 26
0
    def __call__(self, shape, dtype=None):
        """Build the weights.

        Every node of the network is placed on a regular grid. Along a
        dimension of size ``n`` with value range ``(lo, hi)``, the node values
        are ``np.linspace(lo, hi, n + 1)[:n]``. The weight between two nodes
        is a Gaussian of their (optionally periodic) Euclidean distance,
        ``exp(-(dist / sigma) ** 2 / 2)``. It is optionally divided by its
        maximum and then scaled by ``self.max_w``. The diagonal is zeroed
        unless ``self.include_self``, and entries below ``self.min_w`` are
        set to 0.

        Parameters
        ----------
        shape : tuple of int, list of int, int
          The network shape. Note, this is not the weight shape.
        dtype : optional
          Data type of the returned weights.

        Returns
        -------
        weights
          A ``(num_nodes, num_nodes)`` weight matrix.

        Raises
        ------
        ValueError
          If ``self.encoding_values`` is empty or malformed, or if its number
          of ranges does not match the number of dimensions of ``shape``.
        """
        if isinstance(shape, int):
            shape = (shape, )
        net_size = tools.size2num(shape)

        # value ranges to encode
        if self.encoding_values is None:
            value_ranges = tuple([(0, s) for s in shape])
        elif isinstance(self.encoding_values, (tuple, list)):
            if len(self.encoding_values) == 0:
                raise ValueError(
                    f'The encoded values must not be empty. Error in {str(self)}.')
            elif isinstance(self.encoding_values[0], (int, float)):
                # a single (low, high) range shared by all dimensions
                assert len(self.encoding_values) == 2
                assert self.encoding_values[0] < self.encoding_values[1]
                value_ranges = tuple([self.encoding_values for _ in shape])
            elif isinstance(self.encoding_values[0], (tuple, list)):
                # one (low, high) range per dimension
                if len(self.encoding_values) != len(shape):
                    raise ValueError(
                        f'The network size has {len(shape)} dimensions, while '
                        f'the encoded values provided only has {len(self.encoding_values)}-D. '
                        f'Error in {str(self)}.')
                for v in self.encoding_values:
                    assert isinstance(v[0], (int, float))
                    assert len(v) == 2
                value_ranges = tuple(self.encoding_values)
            else:
                raise ValueError(
                    f'Unsupported encoding values: {self.encoding_values}')
        else:
            raise ValueError(
                f'Unsupported encoding values: {self.encoding_values}')

        # values
        values = [
            np.linspace(vs[0], vs[1], n + 1)[:n]
            for vs, n in zip(value_ranges, shape)
        ]
        # Coordinates of all nodes, one column per node. "ij" indexing with
        # Fortran-order flattening makes the FIRST dimension vary fastest,
        # which matches the index decoding in the loop below. np.meshgrid's
        # default "xy" indexing with C-order flattening gives the same layout
        # only for 1-D and 2-D shapes.
        post_values = np.stack([
            v.flatten(order='F') for v in np.meshgrid(*values, indexing='ij')
        ])
        value_sizes = np.array([v[1] - v[0] for v in value_ranges])
        if value_sizes.ndim < post_values.ndim:
            value_sizes = np.expand_dims(
                value_sizes,
                axis=tuple([k + 1 for k in range(post_values.ndim - 1)]))

        # connectivity matrix
        conn_mat = []
        for node in range(net_size):
            # Decode the flat node index into grid coordinates, with the
            # first dimension varying fastest.
            coordinate = tuple()
            rest = node
            for s in shape[:-1]:
                rest, pos = divmod(rest, s)
                coordinate += (pos, )
            coordinate += (rest, )
            i_value = np.array(
                [values[dim][c] for dim, c in enumerate(coordinate)])
            if i_value.ndim < post_values.ndim:
                i_value = np.expand_dims(
                    i_value,
                    axis=tuple([k + 1 for k in range(post_values.ndim - 1)]))
            # distances
            dists = np.abs(i_value - post_values)
            if self.periodic_boundary:
                # wrap each coordinate difference the shorter way around
                dists = np.where(dists > value_sizes / 2, value_sizes - dists,
                                 dists)
            exp_dists = np.exp(
                -(np.linalg.norm(dists, axis=0) / self.sigma)**2 / 2)
            conn_mat.append(exp_dists)
        conn_mat = np.stack(conn_mat)
        if self.normalize:
            conn_mat /= conn_mat.max()
        if not self.include_self:
            np.fill_diagonal(conn_mat, 0.)

        # connectivity weights
        conn_weights = conn_mat * self.max_w
        conn_weights = np.where(conn_weights < self.min_w, 0., conn_weights)
        return math.asarray(conn_weights, dtype=dtype)
Ejemplo n.º 27
0
 def f3(x, y):
   """Return two 2-vectors built from products and powers of ``x`` and ``y``.

   ``r1 = [x0*y0, 5*x2*y1]`` and ``r2 = [4*x1**2 - 2*x2, x2*sin(x0)]``.
   """
   x0, x1, x2 = x[0], x[1], x[2]
   y0, y1 = y[0], y[1]
   r1 = bm.asarray([x0 * y0, 5 * x2 * y1])
   r2 = bm.asarray([4 * x1 ** 2 - 2 * x2, x2 * jnp.sin(x0)])
   return r1, r2
Ejemplo n.º 28
0
  def test_jacfwd_and_aux_nested(self):
    """Differentiating through a ``has_aux`` jacfwd must match direct differentiation.

    Each case pairs a function that returns the auxiliary output of an inner
    ``_jacfwd(..., has_aux=True)`` call with a function that computes the
    same expression directly. Their outer forward-mode Jacobians must agree
    with and without ``bm.jit`` wrappers, for both a Python float and a
    ``bm`` array input.
    """
    def aux_cube(x):
      _, aux = _jacfwd(lambda x: (x ** 3, [x ** 3]), has_aux=True)(x)
      return aux[0]

    def aux_cube_sin(x):
      _, aux = _jacfwd(lambda x: (x ** 3, [x ** 3]), has_aux=True)(x)
      return aux[0] * bm.sin(x)

    cases = [
      (aux_cube, lambda x: x ** 3),
      (aux_cube_sin, lambda x: x ** 3 * bm.sin(x)),
    ]
    for f, f2 in cases:
      for x in (4., bm.asarray(4.)):
        self.assertEqual(_jacfwd(f)(x), _jacfwd(f2)(x))
        self.assertEqual(bm.jit(_jacfwd(f))(x), _jacfwd(f2)(x))
        self.assertEqual(bm.jit(_jacfwd(bm.jit(f)))(x), _jacfwd(f2)(x))
Ejemplo n.º 29
0
 def test_syn2post_prod(self):
     """Per-segment product: segments [0, 0, 1, 1, 2] over arange(5)."""
     values = bm.arange(5)
     segments = bm.array([0, 0, 1, 1, 2])
     result = bm.syn2post_prod(values, segments, 3)
     expected = bm.asarray([0, 6, 4])  # 0*1, 2*3, 4
     self.assertTrue(bm.array_equal(result, expected))
Ejemplo n.º 30
0
 def test_syn2post_mean(self):
     """Per-segment mean: segments [0, 0, 1, 1, 2] over arange(5)."""
     values = bm.arange(5)
     segments = bm.array([0, 0, 1, 1, 2])
     result = bm.syn2post_mean(values, segments, 3)
     expected = bm.asarray([0.5, 2.5, 4.])  # mean(0, 1), mean(2, 3), 4
     self.assertTrue(bm.array_equal(result, expected))