def test_create_signals_partition():
    """Check that create_signals splits signals into the expected groups."""

    # signals read by ops in different plan groups get different base arrays
    signals = [DummySignal() for _ in range(4)]
    plan = [
        tuple(DummyOp(reads=[s]) for s in signals[:2]),
        tuple(DummyOp(reads=[s]) for s in signals[2:]),
    ]
    bases, sig_map = create_signals(signals, plan, np.float32, 10)
    assert sig_map[signals[0]].key == sig_map[signals[1]].key
    assert sig_map[signals[1]].key != sig_map[signals[2]].key
    assert sig_map[signals[2]].key == sig_map[signals[3]].key

    # signals in different read blocks of the same ops are partitioned
    plan = [tuple(DummyOp(reads=[signals[i], signals[i + 2]])
                  for i in range(2))]
    bases, sig_map = create_signals(signals, plan, np.float32, 10)
    assert sig_map[signals[0]].key == sig_map[signals[1]].key
    assert sig_map[signals[1]].key != sig_map[signals[2]].key
    assert sig_map[signals[2]].key == sig_map[signals[3]].key

    # signals accessed in different ways (read vs. set) are partitioned
    plan = [tuple(DummyOp(reads=[signals[i]], sets=[signals[i + 2]])
                  for i in range(2))]
    bases, sig_map = create_signals(signals, plan, np.float32, 10)
    assert sig_map[signals[0]].key == sig_map[signals[1]].key
    assert sig_map[signals[1]].key != sig_map[signals[2]].key
    assert sig_map[signals[2]].key == sig_map[signals[3]].key

    # Reset ops do not cause signals to be merged
    signals = [DummySignal() for _ in range(4)]
    plan = [tuple(Reset(s) for s in signals)]
    bases, sig_map = create_signals(signals, plan, np.float32, 10)
    assert len(bases) == 4
def test_create_signals_views():
    """Check TensorSignal mapping for view signals onto shared bases."""

    views = [DummySignal(shape=(2, 2), base_shape=(4,)),
             DummySignal(shape=(2, 2), base_shape=(4,))]
    all_sigs = views + [views[0].base, views[1].base]
    plan = [tuple(DummyOp(reads=[s]) for s in all_sigs)]
    bases, sig_map = create_signals(all_sigs[2:], plan, np.float32, 10)

    # the two (4,) bases are merged into a single (8, 10) array
    assert list(bases.values())[0][0].shape == (8, 10)

    # views and their bases all share the same base-array key
    for first, second in zip(all_sigs[:-1], all_sigs[1:]):
        assert sig_map[first].key == sig_map[second].key

    # each view indexes into its base's half of the merged array
    assert np.all(sig_map[all_sigs[0]].indices == (0, 1, 2, 3))
    assert np.all(sig_map[all_sigs[1]].indices == (4, 5, 6, 7))
    assert np.all(sig_map[all_sigs[0]].indices == sig_map[all_sigs[2]].indices)
    assert np.all(sig_map[all_sigs[1]].indices == sig_map[all_sigs[3]].indices)
def test_create_signals():
    """Check how create_signals groups signals into base arrays."""

    def check_split(sig_map, signals):
        # first two signals share one base, last two share a different base
        assert sig_map[signals[0]].key == sig_map[signals[1]].key
        assert sig_map[signals[1]].key != sig_map[signals[2]].key
        assert sig_map[signals[2]].key == sig_map[signals[3]].key

    # floats and ints are kept in separate base arrays
    signals = [DummySignal(dtype=np.float32), DummySignal(dtype=np.float32),
               DummySignal(dtype=np.int32), DummySignal(dtype=np.int32)]
    plan = [tuple(DummyOp(reads=[s]) for s in signals)]
    bases, sig_map = create_signals(signals, plan, np.float32, 10)
    check_split(sig_map, signals)

    # all floats are cast to the simulation precision and merged
    signals = [DummySignal(dtype=np.float32), DummySignal(dtype=np.float32),
               DummySignal(dtype=np.float64), DummySignal(dtype=np.float64)]
    plan = [tuple(DummyOp(reads=[s]) for s in signals)]
    bases, sig_map = create_signals(signals, plan, np.float32, 10)
    assert all(sig_map[s].dtype == np.float32 for s in signals)
    assert len({sig_map[s].key for s in signals}) == 1

    # all ints are cast to a common precision and merged
    signals = [DummySignal(dtype=np.int32), DummySignal(dtype=np.int32),
               DummySignal(dtype=np.int64), DummySignal(dtype=np.int64)]
    plan = [tuple(DummyOp(reads=[s]) for s in signals)]
    bases, sig_map = create_signals(signals, plan, np.float32, 10)
    assert all(sig_map[s].dtype == np.int32 for s in signals)
    assert len({sig_map[s].key for s in signals}) == 1

    # signals with different numbers of dimensions go in different groups
    signals = [DummySignal(shape=(10,)), DummySignal(shape=(5,)),
               DummySignal(shape=(10, 1)), DummySignal(shape=(5, 1))]
    plan = [tuple(DummyOp(reads=[s]) for s in signals)]
    bases, sig_map = create_signals(signals, plan, np.float32, 10)
    assert bases[sig_map[signals[0]].key][0].shape == (15, 10)
    assert bases[sig_map[signals[2]].key][0].shape == (15, 1, 10)
    check_split(sig_map, signals)

    # trainable and non-trainable signals are separated (trainable base
    # arrays have no minibatch dimension)
    signals = [DummySignal(trainable=True), DummySignal(trainable=True),
               DummySignal(trainable=False), DummySignal(trainable=False)]
    plan = [tuple(DummyOp(reads=[s]) for s in signals)]
    bases, sig_map = create_signals(signals, plan, np.float32, 10)
    assert bases[sig_map[signals[0]].key][0].shape == (2,)
    assert bases[sig_map[signals[2]].key][0].shape == (2, 10)
    check_split(sig_map, signals)

    # scalar signals are upsized so they can be merged with 1-D signals
    signals = [DummySignal(shape=()), DummySignal(shape=(4,))]
    plan = [tuple(DummyOp(reads=[s]) for s in signals)]
    bases, sig_map = create_signals(signals, plan, np.float32, 10)
    assert list(bases.values())[0][0].shape == (5, 10)
def __init__(self, model, dt, unroll_simulation, dtype, minibatch_size,
             device, progress):
    """Analyze ``model`` and build the optimized operator plan and the
    mapping from Signals to base arrays.

    Parameters
    ----------
    model : built nengo ``Model`` to be simulated
    dt : float
        Simulation timestep.
    unroll_simulation : int
        Number of simulation steps to unroll inside the simulation loop.
    dtype : ``tf.DType``
        Floating point precision for simulation values.
    device : str or None
        Device on which to place the TensorFlow graph.
    minibatch_size : int
        Number of simulations to run simultaneously.
    progress
        Progress object; ``progress.sub(...)`` is used as a context
        manager around each build stage.
    """
    self.model = model
    self.dt = dt
    self.unroll = unroll_simulation
    self.dtype = dtype
    self.minibatch_size = minibatch_size
    self.device = device
    self.graph = tf.Graph()

    # find invariant inputs (nodes that don't receive any input other
    # than the simulation time). we'll compute these outside the simulation
    # and feed in the result.
    if self.model.toplevel is None:
        self.invariant_inputs = OrderedDict()
    else:
        self.invariant_inputs = OrderedDict(
            (n, n.output) for n in self.model.toplevel.all_nodes
            if n.size_in == 0 and
            not isinstance(n, tensor_node.TensorNode))

    # filter unused operators
    # remove TimeUpdate because it is executed as part of the simulation
    # loop, not part of the step plan. remove input nodes because they
    # are executed outside the simulation.
    node_processes = [
        n.output for n in self.invariant_inputs
        if isinstance(n.output, Process)
    ]
    operators = [
        op for op in self.model.operators
        if not (isinstance(op, TimeUpdate) or
                (isinstance(op, SimPyFunc) and op.x is None) or
                (isinstance(op, SimProcess) and op.input is None and
                 op.process in node_processes))
    ]

    # mark trainable signals
    self.mark_signals()

    logger.info("Initial plan length: %d", len(operators))

    # apply graph simplification functions (user-configurable; fall back
    # to the default set if no config is present)
    try:
        simplifications = model.toplevel.config[
            model.toplevel].simplifications
    except (ConfigError, AttributeError):
        simplifications = [
            graph_optimizer.remove_constant_copies,
            graph_optimizer.remove_unmodified_resets,
            graph_optimizer.remove_zero_incs,
            graph_optimizer.remove_identity_muls,
        ]
    # fixed typo in progress label ("simplificaton" -> "simplification")
    with progress.sub("operator simplification", max_value=None):
        # repeat the simplification passes until the operator list
        # reaches a fixed point (no pass changes anything)
        old_operators = []
        while len(old_operators) != len(operators) or any(
                x is not y for x, y in zip(operators, old_operators)):
            old_operators = operators
            for simp in simplifications:
                operators = simp(operators)

    # group mergeable operators
    try:
        planner = model.toplevel.config[model.toplevel].planner
    except (ConfigError, AttributeError):
        planner = graph_optimizer.tree_planner
    with progress.sub("merging operators", max_value=None):
        plan = planner(operators)

    # TODO: we could also merge operators sequentially (e.g., combine
    # a copy and dotinc into one op), as long as the intermediate signal
    # is only written to by one op and read by one op

    # order signals/operators to promote contiguous reads
    try:
        sorter = model.toplevel.config[model.toplevel].sorter
    except (ConfigError, AttributeError):
        sorter = graph_optimizer.order_signals
    with progress.sub("ordering signals", max_value=None):
        sigs, self.plan = sorter(plan, n_passes=10)

    # create base arrays and map Signals to TensorSignals (views on those
    # base arrays)
    with progress.sub("creating signals", max_value=None):
        self.base_arrays_init, self.sig_map = \
            graph_optimizer.create_signals(
                sigs, self.plan, float_type=dtype.as_numpy_dtype,
                minibatch_size=self.minibatch_size)

    logger.info("Optimized plan length: %d", len(self.plan))
    logger.info("Number of base arrays: %d", len(self.base_arrays_init))