Example 1
def test_create_signals_partition():
    # check that signals are partitioned based on plan
    sigs = [DummySignal(), DummySignal(),
            DummySignal(), DummySignal()]
    plan = [tuple(DummyOp(reads=[x]) for x in sigs[:2]),
            tuple(DummyOp(reads=[x]) for x in sigs[2:])]
    bases, sig_map = create_signals(sigs, plan, np.float32, 10)
    assert sig_map[sigs[0]].key == sig_map[sigs[1]].key
    assert sig_map[sigs[1]].key != sig_map[sigs[2]].key
    assert sig_map[sigs[2]].key == sig_map[sigs[3]].key

    # check that signals are partitioned for different read blocks
    plan = [tuple(DummyOp(reads=[sigs[i], sigs[2 + i]]) for i in range(2))]
    bases, sig_map = create_signals(sigs, plan, np.float32, 10)
    assert sig_map[sigs[0]].key == sig_map[sigs[1]].key
    assert sig_map[sigs[1]].key != sig_map[sigs[2]].key
    assert sig_map[sigs[2]].key == sig_map[sigs[3]].key

    # check that signals are partitioned for different sig types
    plan = [tuple(DummyOp(reads=[sigs[i]], sets=[sigs[2 + i]])
                  for i in range(2))]
    bases, sig_map = create_signals(sigs, plan, np.float32, 10)
    assert sig_map[sigs[0]].key == sig_map[sigs[1]].key
    assert sig_map[sigs[1]].key != sig_map[sigs[2]].key
    assert sig_map[sigs[2]].key == sig_map[sigs[3]].key

    # check that resets are ignored
    sigs = [DummySignal(), DummySignal(), DummySignal(), DummySignal()]
    plan = [tuple(Reset(x) for x in sigs)]
    bases, sig_map = create_signals(sigs, plan, np.float32, 10)
    assert len(bases) == 4
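
The DummySignal, DummyOp, Reset, and create_signals names are not defined in the snippet itself: Reset is nengo's builder operator, create_signals is the nengo_dl.graph_optimizer function under test (see Example 4), and the Dummy* helpers come from nengo-dl's test suite. A minimal sketch of what those helpers might look like, just enough to convey the interface these tests exercise (hypothetical stand-ins; the real helpers carry more bookkeeping):

import numpy as np

class DummySignal:
    """Hypothetical stand-in for the test suite's signal helper."""
    def __init__(self, shape=(1,), dtype=np.float32, base_shape=None,
                 trainable=False):
        self.shape = shape
        self.dtype = np.dtype(dtype)
        self.trainable = trainable
        # a view signal points at a separate base signal
        self.base = (self if base_shape is None
                     else DummySignal(shape=base_shape, dtype=dtype))

class DummyOp:
    """Hypothetical stand-in for an operator with explicit signal lists."""
    def __init__(self, reads=None, sets=None, incs=None, updates=None):
        self.reads = reads or []
        self.sets = sets or []
        self.incs = incs or []
        self.updates = updates or []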
Example 2
def test_create_signals_views():
    sigs = [DummySignal(shape=(2, 2), base_shape=(4,)),
            DummySignal(shape=(2, 2), base_shape=(4,))]
    sigs += [sigs[0].base, sigs[1].base]
    plan = [tuple(DummyOp(reads=[x]) for x in sigs)]
    bases, sig_map = create_signals(sigs[2:], plan, np.float32, 10)
    assert list(bases.values())[0][0].shape == (8, 10)
    assert sig_map[sigs[0]].key == sig_map[sigs[1]].key
    assert sig_map[sigs[1]].key == sig_map[sigs[2]].key
    assert sig_map[sigs[2]].key == sig_map[sigs[3]].key
    assert np.all(sig_map[sigs[0]].indices == (0, 1, 2, 3))
    assert np.all(sig_map[sigs[1]].indices == (4, 5, 6, 7))
    assert np.all(sig_map[sigs[0]].indices == sig_map[sigs[2]].indices)
    assert np.all(sig_map[sigs[1]].indices == sig_map[sigs[3]].indices)
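
As a worked check of the index bookkeeping above: the two (2, 2) views and their (4,) bases all share one merged base array of shape (8, 10) (8 elements, minibatch size 10 on the trailing axis), and a view's indices are simply the rows it occupies in that array. A plain-NumPy illustration, independent of nengo-dl:

import numpy as np

base = np.zeros((8, 10), dtype=np.float32)   # merged base array
view0 = base[np.array([0, 1, 2, 3])]         # rows of the first base signal
view1 = base[np.array([4, 5, 6, 7])]         # rows of the second base signal
assert view0.shape == view1.shape == (4, 10)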
Example 3
def test_create_signals():
    # check that floats/ints get split into different arrays
    sigs = [DummySignal(dtype=np.float32), DummySignal(dtype=np.float32),
            DummySignal(dtype=np.int32), DummySignal(dtype=np.int32)]
    plan = [tuple(DummyOp(reads=[x]) for x in sigs)]
    bases, sig_map = create_signals(sigs, plan, np.float32, 10)
    assert sig_map[sigs[0]].key == sig_map[sigs[1]].key
    assert sig_map[sigs[1]].key != sig_map[sigs[2]].key
    assert sig_map[sigs[2]].key == sig_map[sigs[3]].key

    # check that floats all get converted to same precision and combined
    sigs = [DummySignal(dtype=np.float32), DummySignal(dtype=np.float32),
            DummySignal(dtype=np.float64), DummySignal(dtype=np.float64)]
    plan = [tuple(DummyOp(reads=[x]) for x in sigs)]
    bases, sig_map = create_signals(sigs, plan, np.float32, 10)
    assert np.all([sig_map[x].dtype == np.float32 for x in sigs])
    assert sig_map[sigs[0]].key == sig_map[sigs[1]].key
    assert sig_map[sigs[1]].key == sig_map[sigs[2]].key
    assert sig_map[sigs[2]].key == sig_map[sigs[3]].key

    # check that ints all get converted to same precision and combined
    sigs = [DummySignal(dtype=np.int32), DummySignal(dtype=np.int32),
            DummySignal(dtype=np.int64), DummySignal(dtype=np.int64)]
    plan = [tuple(DummyOp(reads=[x]) for x in sigs)]
    bases, sig_map = create_signals(sigs, plan, np.float32, 10)
    assert np.all([sig_map[x].dtype == np.int32 for x in sigs])
    assert sig_map[sigs[0]].key == sig_map[sigs[1]].key
    assert sig_map[sigs[1]].key == sig_map[sigs[2]].key
    assert sig_map[sigs[2]].key == sig_map[sigs[3]].key

    # check that different shapes go in different groups
    sigs = [DummySignal(shape=(10,)), DummySignal(shape=(5,)),
            DummySignal(shape=(10, 1)), DummySignal(shape=(5, 1))]
    plan = [tuple(DummyOp(reads=[x]) for x in sigs)]
    bases, sig_map = create_signals(sigs, plan, np.float32, 10)
    assert bases[sig_map[sigs[0]].key][0].shape == (15, 10)
    assert bases[sig_map[sigs[2]].key][0].shape == (15, 1, 10)
    assert sig_map[sigs[0]].key == sig_map[sigs[1]].key
    assert sig_map[sigs[1]].key != sig_map[sigs[2]].key
    assert sig_map[sigs[2]].key == sig_map[sigs[3]].key

    # check trainable
    sigs = [DummySignal(trainable=True), DummySignal(trainable=True),
            DummySignal(trainable=False), DummySignal(trainable=False)]
    plan = [tuple(DummyOp(reads=[x]) for x in sigs)]
    bases, sig_map = create_signals(sigs, plan, np.float32, 10)
    assert bases[sig_map[sigs[0]].key][0].shape == (2,)
    assert bases[sig_map[sigs[2]].key][0].shape == (2, 10)
    assert sig_map[sigs[0]].key == sig_map[sigs[1]].key
    assert sig_map[sigs[1]].key != sig_map[sigs[2]].key
    assert sig_map[sigs[2]].key == sig_map[sigs[3]].key

    # check that scalars get upsized
    sigs = [DummySignal(shape=()), DummySignal(shape=(4,))]
    plan = [tuple(DummyOp(reads=[x]) for x in sigs)]
    bases, sig_map = create_signals(sigs, plan, np.float32, 10)
    assert list(bases.values())[0][0].shape == (5, 10)
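
Taken together, the partitioning rules this test pins down amount to grouping signals by a key of (dtype, per-element shape, trainability), with the leading axis concatenated across group members. A toy sketch of that keying logic (hypothetical; not nengo-dl's actual implementation):

from collections import defaultdict
import numpy as np

def group_signals(sigs):
    """Toy grouping by the attributes the test above exercises."""
    groups = defaultdict(list)
    for sig in sigs:
        # signals sharing dtype, trailing shape, and trainability can be
        # concatenated into one base array along the leading axis;
        # anything that differs gets its own group
        key = (np.dtype(sig.dtype), sig.shape[1:], sig.trainable)
        groups[key].append(sig)
    return groups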
Example 4
    def __init__(self, model, dt, unroll_simulation, dtype, minibatch_size,
                 device, progress):
        self.model = model
        self.dt = dt
        self.unroll = unroll_simulation
        self.dtype = dtype
        self.minibatch_size = minibatch_size
        self.device = device
        self.graph = tf.Graph()

        # find invariant inputs (nodes that don't receive any input other
        # than the simulation time). we'll compute these outside the simulation
        # and feed in the result.
        if self.model.toplevel is None:
            self.invariant_inputs = OrderedDict()
        else:
            self.invariant_inputs = OrderedDict(
                (n, n.output) for n in self.model.toplevel.all_nodes if
                n.size_in == 0 and not isinstance(n, tensor_node.TensorNode))

        # filter unused operators
        # remove TimeUpdate because it is executed as part of the simulation
        # loop, not part of the step plan. remove input nodes because they
        # are executed outside the simulation.
        node_processes = [
            n.output for n in self.invariant_inputs
            if isinstance(n.output, Process)
        ]
        operators = [
            op for op in self.model.operators
            if not (isinstance(op, TimeUpdate) or
                    (isinstance(op, SimPyFunc) and op.x is None) or
                    (isinstance(op, SimProcess) and op.input is None
                     and op.process in node_processes))
        ]

        # mark trainable signals
        self.mark_signals()

        logger.info("Initial plan length: %d", len(operators))

        # apply graph simplification functions
        try:
            simplifications = model.toplevel.config[
                model.toplevel].simplifications
        except (ConfigError, AttributeError):
            simplifications = [
                graph_optimizer.remove_constant_copies,
                graph_optimizer.remove_unmodified_resets,
                graph_optimizer.remove_zero_incs,
                graph_optimizer.remove_identity_muls,
            ]

        with progress.sub("operator simplificaton", max_value=None):
            old_operators = []
            while len(old_operators) != len(operators) or any(
                    x is not y for x, y in zip(operators, old_operators)):
                old_operators = operators
                for simp in simplifications:
                    operators = simp(operators)

        # group mergeable operators
        try:
            planner = model.toplevel.config[model.toplevel].planner
        except (ConfigError, AttributeError):
            planner = graph_optimizer.tree_planner

        with progress.sub("merging operators", max_value=None):
            plan = planner(operators)

        # TODO: we could also merge operators sequentially (e.g., combine
        # a copy and dotinc into one op), as long as the intermediate signal
        # is only written to by one op and read by one op

        # order signals/operators to promote contiguous reads
        try:
            sorter = model.toplevel.config[model.toplevel].sorter
        except (ConfigError, AttributeError):
            sorter = graph_optimizer.order_signals

        with progress.sub("ordering signals", max_value=None):
            sigs, self.plan = sorter(plan, n_passes=10)

        # create base arrays and map Signals to TensorSignals (views on those
        # base arrays)
        with progress.sub("creating signals", max_value=None):
            self.base_arrays_init, self.sig_map = \
                graph_optimizer.create_signals(
                    sigs, self.plan, float_type=dtype.as_numpy_dtype,
                    minibatch_size=self.minibatch_size)

        logger.info("Optimized plan length: %d", len(self.plan))
        logger.info("Number of base arrays: %d", len(self.base_arrays_init))