def __init__(self, model, dt, unroll_simulation, dtype, minibatch_size,
             device, progress):
    self.model = model
    self.dt = dt
    self.unroll = unroll_simulation
    self.dtype = dtype
    self.minibatch_size = minibatch_size
    self.device = device
    self.graph = tf.Graph()
    self.signals = signals.SignalDict(self.dtype, self.minibatch_size)
    self.inference_only = config.get_setting(model, "inference_only", False)

    # find invariant inputs (nodes that don't receive any input other
    # than the simulation time). we'll compute these outside the simulation
    # and feed in the result.
    if self.model.toplevel is None:
        self.invariant_inputs = OrderedDict()
    else:
        self.invariant_inputs = OrderedDict(
            (n, n.output) for n in self.model.toplevel.all_nodes
            if n.size_in == 0 and not isinstance(n, tensor_node.TensorNode))

    # filter unused operators
    # remove TimeUpdate because it is executed as part of the simulation
    # loop, not part of the step plan. remove input nodes because they
    # are executed outside the simulation.
    node_processes = [
        n.output for n in self.invariant_inputs
        if isinstance(n.output, Process)
    ]
    operators = [
        op for op in self.model.operators
        if not (isinstance(op, TimeUpdate) or
                (isinstance(op, SimPyFunc) and op.x is None) or
                (isinstance(op, SimProcess) and op.input is None and
                 op.process in node_processes))
    ]

    # mark trainable signals
    self.mark_signals()

    logger.info("Initial plan length: %d", len(operators))

    # apply graph simplification functions
    simplifications = config.get_setting(model, "simplifications", [
        graph_optimizer.remove_constant_copies,
        graph_optimizer.remove_unmodified_resets,
        graph_optimizer.remove_zero_incs,
        graph_optimizer.remove_identity_muls,
    ])
    with progress.sub("operator simplification", max_value=None):
        old_operators = []
        while len(old_operators) != len(operators) or any(
                x is not y for x, y in zip(operators, old_operators)):
            old_operators = operators
            for simp in simplifications:
                operators = simp(operators)

    # group mergeable operators
    planner = config.get_setting(model, "planner",
                                 graph_optimizer.tree_planner)
    with progress.sub("merging operators", max_value=None):
        plan = planner(operators)
    # TODO: we could also merge operators sequentially (e.g., combine
    # a copy and dotinc into one op), as long as the intermediate signal
    # is only written to by one op and read by one op

    # order signals/operators to promote contiguous reads
    sorter = config.get_setting(model, "sorter",
                                graph_optimizer.order_signals)
    with progress.sub("ordering signals", max_value=None):
        sigs, self.plan = sorter(plan, n_passes=10)

    # create base arrays and map Signals to TensorSignals (views on those
    # base arrays)
    with progress.sub("creating signals", max_value=None):
        self.create_signals(sigs)

    logger.info("Optimized plan length: %d", len(self.plan))
    logger.info("Number of base arrays: %d", len(self.base_arrays_init))

    # initialize op builder
    build_config = builder.BuildConfig(
        inference_only=self.inference_only,
        lif_smoothing=config.get_setting(self.model, "lif_smoothing"),
        cpu_only=self.device == "/cpu:0" or not utils.tf_gpu_installed,
    )
    self.op_builder = builder.Builder(self.plan, self.graph, self.signals,
                                      build_config)
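# Usage sketch (illustrative only): the planner/sorter/simplification hooks
# read above via `config.get_setting` can be overridden on the network's
# config before the simulator is built. This assumes the standard NengoDL
# configuration entry point `nengo_dl.configure_settings`; the particular
# planner and simplification choices below are arbitrary examples:
#
#     import nengo
#     import nengo_dl
#     from nengo_dl import graph_optimizer
#
#     with nengo.Network() as net:
#         nengo_dl.configure_settings(
#             planner=graph_optimizer.greedy_planner,
#             simplifications=[graph_optimizer.remove_zero_incs],
#         )
#         ...  # build the model as usual
#
#     with nengo_dl.Simulator(net) as sim:
#         ...  # the __init__ above picks up those settings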
def call(self, inputs, training=None, progress=None, stateful=False):
    """
    Constructs the graph elements to simulate the model.

    Parameters
    ----------
    inputs : list of ``tf.Tensor``
        Input layers/tensors for the network (must match the structure
        defined in `.build_inputs`).
    training : bool
        Whether the network is being run in training or inference mode.
        If None, uses the symbolic Keras learning phase variable.
    progress : `.utils.ProgressBar`
        Progress bar for construction stage.
    stateful : bool
        Whether or not to build the model to support preserving the
        internal state between executions.

    Returns
    -------
    probe_arrays : list of ``tf.Tensor``
        Tensors representing the output of all the Probes in the network
        (order corresponding to ``self.model.probes``, which is the order
        the Probes were instantiated).
    """

    super().call(inputs, training=training)

    if training == 1 and self.inference_only:
        raise BuildError(
            "TensorGraph was created with inference_only=True; cannot be "
            "built with training=%s" % training)

    tf.random.set_seed(self.seed)

    if progress is None:
        progress = utils.NullProgressBar()

    # reset signaldict
    self.signals.reset()

    # create these constants once here for reuse in different operators
    self.signals.dt = tf.constant(self.dt, self.dtype)
    self.signals.dt_val = self.dt  # store the actual value as well
    self.signals.zero = tf.constant(0, self.dtype)
    self.signals.one = tf.constant(1, self.dtype)

    # set up invariant inputs
    with trackable.no_automatic_dependency_tracking_scope(self):
        self.node_inputs = {}
    for n, inp in zip(self.invariant_inputs, inputs):
        # specify shape of inputs (keras sometimes loses this shape
        # information)
        inp.set_shape([self.minibatch_size, inp.shape[1], n.size_out])
        self.node_inputs[n] = inp

    self.steps_to_run = inputs[-1][0, 0]

    # initialize op builder
    build_config = builder.BuildConfig(
        inference_only=self.inference_only,
        lif_smoothing=config.get_setting(self.model, "lif_smoothing"),
        cpu_only=self.device == "/cpu:0" or not utils.tf_gpu_installed,
        rng=np.random.RandomState(self.seed),
        training=(tf.keras.backend.learning_phase() if training is None
                  else training),
    )
    self.op_builder = builder.Builder(self.plan, self.signals, build_config)

    # pre-build stage
    with progress.sub("pre-build stage", max_value=len(self.plan)) as sub:
        self.op_builder.build_pre(sub)

    # build stage
    with progress.sub("build stage",
                      max_value=len(self.plan) * self.unroll) as sub:
        steps_run, probe_arrays, final_internal_state, final_base_params = (
            self._build_loop(sub) if self.use_loop else
            self._build_no_loop(sub))

    # store these so that they can be accessed after the initial build
    with trackable.no_automatic_dependency_tracking_scope(self):
        self.steps_run = steps_run
        self.probe_arrays = probe_arrays
        self.final_internal_state = final_internal_state
        self.final_base_params = final_base_params

    # logging
    logger.info("Number of reads: %d",
                sum(x for x in self.signals.read_types.values()))
    for x in self.signals.read_types.items():
        logger.info("    %s: %d", *x)
    logger.info("Number of writes: %d",
                sum(x for x in self.signals.write_types.values()))
    for x in self.signals.write_types.items():
        logger.info("    %s: %d", *x)

    # note: always return steps_run so that the simulation will run for the
    # given number of steps, even if there are no output probes
    outputs = list(probe_arrays.values()) + [steps_run]

    updates = []
    if stateful:
        # update saved state
        updates.extend(
            var.assign(val)
            for var, val in zip(self.saved_state.values(),
                                final_internal_state))

    # if any of the base params have changed (due to online learning rules)
    # then we also need to assign those back to the original variable (so
    # that their values will persist). any parameters targeted by online
    # learning rules will be minibatched, so we only need to update the
    # minibatched params.
    for (key, var), val in zip(self.base_params.items(), final_base_params):
        try:
            minibatched = self.base_arrays_init["non_trainable"][key][-1]
        except KeyError:
            minibatched = self.base_arrays_init["trainable"][key][-1]

        if minibatched:
            updates.append(var.assign(val))

    logger.info("Number of variable updates: %d", len(updates))

    if len(updates) > 0:
        with tf.control_dependencies(updates):
            outputs = [tf.identity(x) for x in outputs]

    return outputs
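# The update-forcing idiom at the end of `call` above is worth noting:
# wrapping the outputs in `tf.identity` under
# `tf.control_dependencies(updates)` ties the state/parameter assignments to
# the outputs, so fetching any output also executes the pending updates. A
# minimal standalone sketch of the same pattern (illustrative only; the
# variable and function names are hypothetical):
#
#     import tensorflow as tf
#
#     counter = tf.Variable(0)
#
#     @tf.function
#     def step(x):
#         y = x * 2.0
#         updates = [counter.assign_add(1)]
#         with tf.control_dependencies(updates):
#             # fetching y now also runs the assignment above
#             return tf.identity(y)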
def __init__(
    self, model, dt, unroll_simulation, minibatch_size, device, progress, seed
):
    super().__init__(
        name="TensorGraph",
        dynamic=False,
        trainable=not config.get_setting(model, "inference_only", False),
        dtype=config.get_setting(model, "dtype", "float32"),
        batch_size=minibatch_size,
    )

    self.model = model
    self.dt = dt
    self.unroll = unroll_simulation
    self.use_loop = config.get_setting(model, "use_loop", True)
    self.minibatch_size = minibatch_size
    self.device = device
    self.seed = seed
    self.inference_only = not self.trainable
    self.signals = signals.SignalDict(self.dtype, self.minibatch_size)

    # find invariant inputs (nodes that don't receive any input other
    # than the simulation time). we'll compute these outside the simulation
    # and feed in the result.
    if self.model.toplevel is None:
        self.invariant_inputs = OrderedDict()
    else:
        self.invariant_inputs = OrderedDict(
            (n, n.output)
            for n in self.model.toplevel.all_nodes
            if n.size_in == 0 and not isinstance(n, tensor_node.TensorNode)
        )

    # remove input nodes because they are executed outside the simulation
    node_processes = [
        n.output for n in self.invariant_inputs if isinstance(n.output, Process)
    ]
    operators = [
        op
        for op in self.model.operators
        if not (
            (isinstance(op, SimPyFunc) and op.x is None)
            or (
                isinstance(op, SimProcess)
                and op.input is None
                and op.process in node_processes
            )
        )
    ]

    # mark trainable signals
    self.mark_signals()

    logger.info("Initial plan length: %d", len(operators))

    # apply graph simplification functions
    simplifications = config.get_setting(
        model,
        "simplifications",
        graph_optimizer.default_simplifications,
    )
    with progress.sub("operator simplification", max_value=None):
        old_operators = []
        while len(old_operators) != len(operators) or any(
            x is not y for x, y in zip(operators, old_operators)
        ):
            old_operators = operators
            for simp in simplifications:
                operators = simp(operators)

    # group mergeable operators
    planner = config.get_setting(model, "planner", graph_optimizer.tree_planner)
    with progress.sub("merging operators", max_value=None):
        plan = planner(operators)
    # TODO: we could also merge operators sequentially (e.g., combine
    # a copy and dotinc into one op), as long as the intermediate signal
    # is only written to by one op and read by one op

    # order signals/operators to promote contiguous reads
    sorter = config.get_setting(model, "sorter", graph_optimizer.order_signals)
    with progress.sub("ordering signals", max_value=None):
        sigs, self.plan = sorter(plan, n_passes=10)

    # create base arrays and map Signals to TensorSignals (views on those
    # base arrays)
    with progress.sub("creating signals", max_value=None):
        self.create_signals(sigs)

    # generate unique names for layer inputs/outputs
    # this follows the TensorFlow unique naming scheme, so if multiple
    # objects are created with the same name, they will be named like
    # name, name_1, name_2 (note: uniqueness is case-insensitive)
    self.io_names = {}
    name_count = defaultdict(int)
    for obj in list(self.invariant_inputs.keys()) + self.model.probes:
        name = (
            type(obj).__name__.lower()
            if obj.label is None
            else utils.sanitize_name(obj.label)
        )
        key = name.lower()
        if name_count[key] > 0:
            name += "_%d" % name_count[key]

        self.io_names[obj] = name
        name_count[key] += 1

    # set up op builder
    self.op_builder = builder.Builder(self.plan)

    # logging
    logger.info("Optimized plan length: %d", len(self.plan))
    logger.info(
        "Number of base arrays: (%s, %d), (%s, %d), (%s, %d)",
        *sum(((k, len(x)) for k, x in self.base_arrays_init.items()), ()),
    )
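# The io_names loop above can be exercised in isolation. A minimal sketch of
# the same case-insensitive unique-naming rule (illustrative only; the
# function name and sample labels are hypothetical):
#
#     from collections import defaultdict
#
#     def unique_names(labels):
#         names, counts = [], defaultdict(int)
#         for label in labels:
#             key = label.lower()
#             # first occurrence keeps the bare name; later occurrences get
#             # a numeric suffix based on the case-folded count
#             name = label if counts[key] == 0 else "%s_%d" % (label, counts[key])
#             counts[key] += 1
#             names.append(name)
#         return names
#
#     unique_names(["Probe", "probe", "probe"])
#     # -> ["Probe", "probe_1", "probe_2"]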