def __init__(self, network, dt=0.001, seed=None, model=None, context=None,
             n_prealloc_probes=32, profiling=None, if_python_code='none',
             planner=greedy_planner, progress_bar=True):
    """Build ``network`` (or use ``model``) and prepare it to run on OpenCL.

    Parameters
    ----------
    network : nengo.Network or None
        Network to build into the model. When None, nothing is built and
        the model's existing operators are used as-is.
    dt : float
        Simulation time step; only used to create a new Model when
        ``model`` is None.
    seed : int or None
        Seed for ``self.rng``. A random seed is drawn when None.
    model : Model or None
        Pre-existing model to use. A new Model is created when None.
    context : pyopencl.Context or None
        OpenCL context. When None, the class-level
        ``Simulator.some_context`` is used, created on first use with
        ``cl.create_some_context()``.
    n_prealloc_probes : int
        Stored as ``self.n_prealloc_probes``; presumably the number of probe
        samples buffered on the device between transfers (TODO confirm).
    profiling : bool, int or None
        Enable OpenCL command-queue profiling. When None, it is read from
        the ``NENGO_OCL_PROFILING`` environment variable (default 0).
    if_python_code : {'none', 'warn', 'error'}
        Stored as ``self.if_python_code``; presumably selects how to react
        when an operator has to fall back to Python code (TODO confirm).
    planner : callable
        Maps the list of operators to a list of ``(op_type, op_list)``
        groups that are planned together.
    progress_bar : bool
        Stored as ``self.progress_bar``.

    Raises
    ------
    ValueError
        If the installed Nengo version is listed in ``bad_nengo_versions``,
        or if ``if_python_code`` is not one of the accepted values.
    """
    # --- create these first since they are used in __del__. They must be
    # set before any statement below that can raise, so that __del__ still
    # finds them when construction fails part-way (e.g. version check).
    self.closed = False
    self.model = None

    # --- check version
    if nengo.version.version_info in bad_nengo_versions:
        raise ValueError(
            "This simulator does not support Nengo version %s. Upgrade "
            "with 'pip install --upgrade --no-deps nengo'."
            % nengo.__version__)
    elif nengo.version.version_info > latest_nengo_version_info:
        warnings.warn("This version of `nengo_ocl` has not been tested "
                      "with your `nengo` version (%s). The latest fully "
                      "supported version is %s" % (
                          nengo.__version__, latest_nengo_version))

    # --- validate plain arguments before acquiring OpenCL resources;
    # creating a context below may prompt the user interactively.
    if if_python_code not in ['none', 'warn', 'error']:
        raise ValueError("%r not a valid value for `if_python_code`"
                         % if_python_code)

    # --- arguments/attributes
    if context is None and Simulator.some_context is None:
        print('No context argument was provided to nengo_ocl.Simulator')
        print("Calling pyopencl.create_some_context() for you now:")
        Simulator.some_context = cl.create_some_context()
    if profiling is None:
        profiling = int(os.getenv("NENGO_OCL_PROFILING", 0))
    self.context = Simulator.some_context if context is None else context
    self.profiling = profiling
    self.queue = cl.CommandQueue(
        self.context, properties=PROFILING_ENABLE if self.profiling else 0)

    self.if_python_code = if_python_code
    self.n_prealloc_probes = n_prealloc_probes
    self.progress_bar = progress_bar

    # --- Nengo build
    with Timer() as nengo_timer:
        if model is None:
            self.model = Model(dt=float(dt),
                               label="%s, dt=%f" % (network, dt),
                               decoder_cache=get_default_decoder_cache())
        else:
            self.model = model

        if network is not None:
            # Build the network into the model
            self.model.build(network)

    logger.info("Nengo build in %0.3f s", nengo_timer.duration)

    # --- operators
    with Timer() as planner_timer:
        operators = list(self.model.operators)

        # convert DotInc and Copy to MultiDotInc
        operators = list(map(MultiDotInc.convert_to, operators))
        operators = MultiDotInc.compress(operators)

        # plan the order of operations, combining where appropriate
        op_groups = planner(operators)
        # all Reset operators are expected to land in a single group
        assert sum(1 for typ, _ in op_groups if typ is Reset) < 2, (
            "All resets not planned together")

        self.operators = operators
        self.op_groups = op_groups

    logger.info("Planning in %0.3f s", planner_timer.duration)

    with Timer() as signals_timer:
        # Initialize signals
        all_signals = stable_unique(
            sig for op in operators for sig in op.all_signals)
        all_bases = stable_unique(sig.base for sig in all_signals)

        sigdict = SignalDict()  # map from Signal.base -> ndarray
        for op in operators:
            op.init_signals(sigdict)

        # Add built states to the probe dictionary (a copy, so the model's
        # own ``params`` dict is not mutated by this simulator)
        self._probe_outputs = dict(self.model.params)

        # Provide a nicer interface to probe outputs
        self.data = ProbeDict(self._probe_outputs)

        # Create data on host and add views
        self.all_data = RaggedArray(
            [sigdict[sb] for sb in all_bases],
            names=[getattr(sb, 'name', '') for sb in all_bases],
            dtype=np.float32)

        view_builder = ViewBuilder(all_bases, self.all_data)
        view_builder.setup_views(operators)
        for probe in self.model.probes:
            view_builder.append_view(self.model.sig[probe]['in'])
        view_builder.add_views_to(self.all_data)

        self.all_bases = all_bases
        self.sidx = {
            k: np.int32(v) for k, v in iteritems(view_builder.sidx)}
        self._A_views = view_builder._A_views
        self._X_views = view_builder._X_views
        self._YYB_views = view_builder._YYB_views
        del view_builder

        # Copy data to device
        self.all_data = CLRaggedArray(self.queue, self.all_data)

    logger.info("Signals in %0.3f s", signals_timer.duration)

    # --- set seed
    self.seed = np.random.randint(npext.maxint) if seed is None else seed
    self.rng = np.random.RandomState(self.seed)

    # --- create list of plans
    self._raggedarrays_to_reset = {}
    self._cl_rngs = {}
    self._python_rngs = {}

    plans = []
    with Timer() as plans_timer:
        for op_type, op_list in op_groups:
            plans.extend(self.plan_op_group(op_type, op_list))
        plans.extend(self.plan_probes())
    logger.info("Plans in %0.3f s", plans_timer.duration)

    # -- create object to execute list of plans
    self._plans = Plans(plans, self.profiling)

    self.rng = None  # all randomness set, should no longer be used
    self._reset_probes()  # clears probes from previous model builds
def __init__(self, network, dt=0.001, seed=None, model=None,
             planner=greedy_planner):
    """Build ``network`` (or use ``model``) and plan it for execution.

    Parameters
    ----------
    network : nengo.Network or None
        Network to build into the model. When None, nothing is built and
        the model's existing operators are used as-is.
    dt : float
        Simulation time step; only used to create a new Model when
        ``model`` is None.
    seed : int or None
        Seed for ``self.rng``. A random seed is drawn when None.
    model : Model or None
        Pre-existing model to use. A new Model is created when None.
    planner : callable
        Maps the list of operators to a list of ``(op_type, op_list)``
        groups that are planned together.
    """
    with Timer() as nengo_timer:
        if model is None:
            self.model = Model(dt=float(dt),
                               label="%s, dt=%f" % (network, dt),
                               decoder_cache=get_default_decoder_cache())
        else:
            self.model = model

        if network is not None:
            # Build the network into the model
            self.model.build(network)

    logger.info("Nengo build in %0.3f s", nengo_timer.duration)

    # --- set seed
    seed = np.random.randint(npext.maxint) if seed is None else seed
    self.seed = seed
    self.rng = np.random.RandomState(self.seed)

    # step counter and simulation time, updated by the TimeUpdate op below
    self._step = Signal(np.array(0.0, dtype=np.float64), name='step')
    self._time = Signal(np.array(0.0, dtype=np.float64), name='time')

    # --- operators
    with Timer() as planner_timer:
        operators = list(self.model.operators)

        # convert DotInc, Reset, Copy, and ProdUpdate to MultiProdUpdate
        operators = list(map(MultiProdUpdate.convert_to, operators))
        operators = MultiProdUpdate.compress(operators)

        # plan the order of operations, combining where appropriate
        op_groups = planner(operators)
        # all Reset operators are expected to land in a single group
        assert sum(1 for typ, _ in op_groups if typ is Reset) < 2, (
            "All resets not planned together")

        # add time operator after planning, to ensure it goes first
        time_op = TimeUpdate(self._step, self._time)
        operators.insert(0, time_op)
        op_groups.insert(0, (type(time_op), [time_op]))

        self.operators = operators
        self.op_groups = op_groups

    logger.info("Planning in %0.3f s", planner_timer.duration)

    with Timer() as signals_timer:
        # Initialize signals
        all_signals = signals_from_operators(operators)
        all_bases = stable_unique([sig.base for sig in all_signals])

        sigdict = SignalDict()  # map from Signal.base -> ndarray
        for op in operators:
            op.init_signals(sigdict)

        # Add built states to the probe dictionary. Copy rather than alias
        # ``self.model.params`` so entries this simulator adds to its probe
        # dictionary do not mutate the model (which the caller may have
        # passed in via ``model=`` and reuse elsewhere).
        self._probe_outputs = dict(self.model.params)

        # Provide a nicer interface to probe outputs
        self.data = ProbeDict(self._probe_outputs)

        self.all_data = RaggedArray(
            [sigdict[sb] for sb in all_bases],
            [getattr(sb, 'name', '') for sb in all_bases],
            dtype=np.float32)

        builder = ViewBuilder(all_bases, self.all_data)
        self._AX_views = {}
        self._YYB_views = {}
        for op_type, op_list in op_groups:
            self.setup_views(builder, op_type, op_list)
        for probe in self.model.probes:
            builder.append_view(self.model.sig[probe]['in'])
        builder.add_views_to(self.all_data)

        self.all_bases = all_bases
        self.sidx = builder.sidx

        self._prep_all_data()

    logger.info("Signals in %0.3f s", signals_timer.duration)

    # --- create list of plans
    with Timer() as plans_timer:
        self._plan = []
        for op_type, op_list in op_groups:
            self._plan.extend(self.plan_op_group(op_type, op_list))
        self._plan.extend(self.plan_probes())
    logger.info("Plans in %0.3f s", plans_timer.duration)

    self.n_steps = 0