コード例 #1
0
def test_order_signals_views():
    """order_signals handles plans mixing plain signals and views of a base."""
    base = dummies.Signal(shape=(6, ), label="base")
    sig = dummies.Signal(shape=(7, ), label="sig")
    sig2 = dummies.Signal(shape=(7, ), label="sig2")

    # five single-element views, all re-pointed at the shared ``base``
    views = [dummies.Signal(shape=(1, ), base_shape=(5, ), offset=1 + k,
                            label="view_%d" % k)
             for k in range(5)]
    for view in views:
        view.base = base

    first_group = (
        dummies.Op(reads=[base]),
        dummies.Op(reads=[views[1]]),
        dummies.Op(reads=[views[0]]),
        dummies.Op(reads=[sig2]),
    )
    second_group = (dummies.Op(reads=[base]), dummies.Op(reads=[sig]))
    third_group = tuple(dummies.Op(reads=[views[k]]) for k in (4, 3))
    fourth_group = (dummies.Op(reads=[views[4]]), dummies.Op(reads=[sig]))
    plan = [first_group, second_group, third_group, fourth_group]

    sigs, new_plan = order_signals(plan)
    assert contiguous([base, sig, sig2], sigs)
    for group_idx in range(4):
        assert ordered(new_plan[group_idx], sigs)
コード例 #2
0
def test_planner_unmergeable(planner):
    """Copies into signals of different dtypes must not be merged."""
    src = dummies.Signal()
    float_copy = Copy(src, dummies.Signal(dtype=np.float32))
    int_copy = Copy(src, dummies.Signal(dtype=np.int32))
    plan = planner([float_copy, int_copy])

    # each copy should end up alone in its own group
    assert len(plan) == 2
    for group in plan:
        assert type(group[0]) is Copy
        assert len(group) == 1
コード例 #3
0
def test_planner_mergeable(planner):
    """Two compatible incrementing copies are merged into one group."""
    srcs = [dummies.Signal(), dummies.Signal()]
    dsts = [dummies.Signal(), dummies.Signal()]
    ops = [Copy(src, dst, inc=True) for src, dst in zip(srcs, dsts)]

    plan = planner(ops)
    assert len(plan) == 1
    merged = plan[0]
    assert type(merged[0]) is Copy
    assert len(merged) == 2
コード例 #4
0
def test_noop_order_signals():
    """noop_order_signals leaves the plan unchanged.

    For the signal created with ``base_shape``, the returned signal list holds
    an entry named "c.base".
    """
    a = dummies.Signal(label="a")
    b = dummies.Signal(label="b")
    c = dummies.Signal(label="c", base_shape=(2,))
    plan = [(dummies.Op(reads=[s]),) for s in (a, b, c)]

    sigs, new_plan = noop_order_signals(plan)

    assert all(old == new for old, new in zip(plan, new_plan))
    assert len(sigs) == 3

    # after dropping the two plain signals, only c's base should remain
    for plain in (a, b):
        sigs.remove(plain)
    assert sigs[0].name == "c.base"
コード例 #5
0
def test_planner_size():
    """greedy_planner selects op groups by the number of available ops."""
    src = dummies.Signal()
    ops = []
    for _ in range(2):
        ops.append(Copy(src, dummies.Signal(), inc=True))
    ops.append(Copy(src, dummies.Signal()))
    for _ in range(3):
        ops.append(DotInc(src, dummies.Signal(), dummies.Signal()))

    plan = greedy_planner(ops)
    # the three DotIncs first, then the two inc copies, then the plain copy
    assert [len(group) for group in plan] == [3, 2, 1]
コード例 #6
0
def test_planner_chain(planner):
    """A chain of ops (copies a->b, a SimPyFunc, copies c->d) is split.

    The planner should produce three groups of sizes 3, 1 and 2.
    """
    a, b, c, d = (dummies.Signal(label=name) for name in ("a", "b", "c", "d"))
    ops = [Copy(a, b, inc=True) for _ in range(3)]
    ops.append(SimPyFunc(c, lambda x: x, None, b))
    ops.extend(Copy(c, d, inc=True) for _ in range(2))

    plan = planner(ops)
    assert [len(group) for group in plan] == [3, 1, 2]
コード例 #7
0
def test_planner_cycle(planner):
    """Copies forming a cycle (0 -> 1 -> 2 -> 0) raise a BuildError."""
    sigs = [dummies.Signal() for _ in range(3)]
    # copy each signal into the next one, wrapping around to close the cycle
    cycle = [Copy(sigs[k], sigs[(k + 1) % 3]) for k in range(3)]

    with pytest.raises(BuildError):
        planner(cycle)
コード例 #8
0
def test_noop_planner():
    """noop_planner gives each op its own group.

    The op writing ``s1`` is scheduled before the op that reads it.
    """
    s0, s1, s2 = (dummies.Signal() for _ in range(3))
    downstream = Copy(s1, s2)
    upstream = Copy(s0, s1)
    operators = [downstream, upstream]

    plan = noop_planner(operators)
    assert len(plan) == len(operators)
    assert plan[0] == (upstream,)
    assert plan[1] == (downstream,)
コード例 #9
0
def test_create_signals_views():
    """Two views and their bases share one base array with matching indices.

    NOTE(review): another function named ``test_create_signals_views``
    appears later in this listing. If both live in the same module, the later
    definition shadows this one and pytest will only collect that one.
    """
    view_a = dummies.Signal(shape=(2, 2), base_shape=(4,))
    view_b = dummies.Signal(shape=(2, 2), base_shape=(4,))
    sigs = [view_a, view_b, view_a.base, view_b.base]
    plan = [tuple(dummies.Op(reads=[s]) for s in sigs)]
    graph = dummies.TensorGraph(plan, tf.float32, 10)
    graph.create_signals(sigs[2:])

    assert list(graph.base_arrays_init.values())[0][0].shape == (8, 10)

    # all four signals are assigned the same key
    keys = [graph.signals[s].key for s in sigs]
    assert keys[0] == keys[1] == keys[2] == keys[3]

    a_info = graph.signals[view_a]
    b_info = graph.signals[view_b]
    assert np.all(a_info.indices == (0, 1, 2, 3))
    assert np.all(b_info.indices == (4, 5, 6, 7))
    # each view maps onto exactly the indices of its base
    assert np.all(a_info.indices == graph.signals[sigs[2]].indices)
    assert np.all(b_info.indices == graph.signals[sigs[3]].indices)
コード例 #10
0
def test_order_signals_lowpass():
    """Lowpass process outputs are ordered the same way as their reads."""
    inputs = [dummies.Signal(label=str(i)) for i in range(10)]
    time_sig = dummies.Signal()

    def _lowpass_group(indices):
        # one update-mode Lowpass process per (input, next input) pair
        return tuple(
            SimProcess(Lowpass(0.1), inputs[k], inputs[k + 1], time_sig,
                       mode="update")
            for k in indices)

    plan = [_lowpass_group(range(0, 4, 2)), _lowpass_group(range(5, 9, 2))]
    sigs, new_plan = order_signals(plan)

    assert contiguous(inputs[1:5:2], sigs)
    assert contiguous(inputs[6:10:2], sigs)

    for group in (new_plan[0], new_plan[1]):
        for block in (1, 2):
            assert ordered(group, sigs, block=block)
コード例 #11
0
def test_order_signals_noreads():
    """A group whose ops only set signals doesn't break read ordering."""
    inputs = [dummies.Signal(label=str(i)) for i in range(10)]
    readers = tuple(dummies.Op(reads=[s]) for s in inputs[:5])
    setters = tuple(dummies.Op(sets=[s]) for s in inputs[5:])

    sigs, new_plan = order_signals([readers, setters])
    assert contiguous(inputs[:5], sigs)
    assert ordered(new_plan[0], sigs)
コード例 #12
0
def test_create_signals_views():
    """Two views and their bases share one non-trainable array."""
    first_view = dummies.Signal(shape=(2, 2), base_shape=(4, ))
    second_view = dummies.Signal(shape=(2, 2), base_shape=(4, ))
    sigs = [first_view, second_view, first_view.base, second_view.base]
    plan = [tuple(dummies.Op(reads=[s]) for s in sigs)]
    graph = dummies.TensorGraph(plan, tf.float32, 10)
    graph.create_signals(sigs[2:])

    init_shapes = list(graph.base_arrays_init["non_trainable"].values())[0][1]
    assert init_shapes == [(10, 4), (10, 4)]

    # all four signals are assigned the same key
    keys = [graph.signals[s].key for s in sigs]
    assert keys[0] == keys[1] == keys[2] == keys[3]

    first_slices = graph.signals[first_view].slices
    second_slices = graph.signals[second_view].slices
    assert first_slices == ((0, 4), )
    assert second_slices == ((4, 8), )
    # each view occupies exactly the slice of its base
    assert first_slices == graph.signals[sigs[2]].slices
    assert second_slices == graph.signals[sigs[3]].slices
コード例 #13
0
def test_order_signals_disjoint():
    """Two groups with disjoint reads are each laid out contiguously."""
    inputs = [dummies.Signal(label=str(i)) for i in range(10)]
    halves = (inputs[:5], inputs[5:])
    plan = [tuple(dummies.Op(reads=[s]) for s in half) for half in halves]

    sigs, new_plan = order_signals(plan)
    for half in halves:
        assert contiguous(half, sigs)
    for group_idx in range(2):
        assert ordered(new_plan[group_idx], sigs)
コード例 #14
0
def test_order_signals_partial3():
    """Three read groups overlapping at inputs 0 and 7 can all be ordered."""
    inputs = [dummies.Signal(label=str(i)) for i in range(10)]
    read_sets = ([0, 1, 2, 3], [0, 4, 7], [5, 6, 7])
    plan = [tuple(dummies.Op(reads=[inputs[k]]) for k in idxs)
            for idxs in read_sets]

    sigs, new_plan = order_signals(plan)
    for idxs in read_sets:
        assert contiguous([inputs[k] for k in idxs], sigs)
    for group_idx in range(3):
        assert ordered(new_plan[group_idx], sigs)
コード例 #15
0
def test_order_signals_partial():
    """Two partially overlapping read groups (A, A/B, B) are both ordered."""
    inputs = [dummies.Signal(label=str(i)) for i in range(10)]
    spans = ((0, 4), (2, 6))
    plan = [tuple(dummies.Op(reads=[s]) for s in inputs[lo:hi])
            for lo, hi in spans]

    sigs, new_plan = order_signals(plan)
    for lo, hi in spans:
        assert contiguous(inputs[lo:hi], sigs)
    for group_idx in range(2):
        assert ordered(new_plan[group_idx], sigs)
コード例 #16
0
def test_order_signals_subset():
    """A read block nested inside a larger one (A, A/B) is still ordered."""
    inputs = [dummies.Signal(label=str(i)) for i in range(10)]
    outer = tuple(dummies.Op(reads=[s]) for s in inputs)
    # the inner group reads inputs 4, 3, 2, 1, 0 (reverse order)
    inner = tuple(dummies.Op(reads=[s]) for s in reversed(inputs[:5]))

    sigs, new_plan = order_signals([outer, inner])
    assert contiguous(inputs[:5], sigs)
    assert contiguous(inputs[:10], sigs)
    assert ordered(new_plan[0], sigs)
    assert ordered(new_plan[1], sigs)
コード例 #17
0
def test_create_signals_partition():
    """Signals are given separate keys according to how the plan uses them."""

    def _build(plan, sigs):
        graph = dummies.TensorGraph(plan, tf.float32, 10)
        graph.create_signals(sigs)
        return graph

    def _check_split(graph, sigs):
        # the first pair shares a key, the second pair shares a key, and the
        # two pairs are kept apart
        keys = [graph.signals[s].key for s in sigs]
        assert keys[0] == keys[1]
        assert keys[1] != keys[2]
        assert keys[2] == keys[3]

    sigs = [dummies.Signal() for _ in range(4)]

    # signals read by different op groups are partitioned
    plan = [
        tuple(dummies.Op(reads=[s]) for s in sigs[:2]),
        tuple(dummies.Op(reads=[s]) for s in sigs[2:]),
    ]
    _check_split(_build(plan, sigs), sigs)

    # signals in different read blocks of the same ops are partitioned
    plan = [tuple(dummies.Op(reads=[sigs[k], sigs[2 + k]]) for k in range(2))]
    _check_split(_build(plan, sigs), sigs)

    # read signals and set signals are partitioned
    plan = [tuple(dummies.Op(reads=[sigs[k]], sets=[sigs[2 + k]])
                  for k in range(2))]
    _check_split(_build(plan, sigs), sigs)

    # Reset ops are ignored: four signals yield four non-trainable entries
    fresh = [dummies.Signal() for _ in range(4)]
    graph = _build([tuple(Reset(s) for s in fresh)], fresh)
    assert len(graph.base_arrays_init["non_trainable"]) == 4
コード例 #18
0
def test_sparsedotinc_mergeable():
    """Two SparseDotInc ops built the same way are mergeable."""

    def _make_op():
        return SparseDotInc(dummies.Signal(sparse=True), dummies.Signal(),
                            dummies.Signal())

    first = _make_op()
    second = _make_op()
    assert mergeable(first, [second])
コード例 #19
0
def test_order_signals_duplicate_read_blocks():
    """Read blocks duplicated across op groups are prioritized by ordering."""
    inputs = [dummies.Signal(label=str(i)) for i in range(10)]

    def _paired_group():
        return tuple(dummies.Op(reads=[inputs[k], inputs[5 + k]])
                     for k in range(3))

    plan = [
        _paired_group(),
        _paired_group(),
        tuple(dummies.Op(reads=[inputs[5 + k], inputs[4 - k]])
              for k in range(5)),
    ]

    sigs, new_plan = order_signals(plan)
    assert ordered(new_plan[0], sigs)
    assert ordered(new_plan[1], sigs)
    # in the third group at least one read block is ordered, but not the
    # group as a whole
    assert any(ordered(new_plan[2], sigs, block=b) for b in (0, 1))
    assert not ordered(new_plan[2], sigs)
コード例 #20
0
def test_order_signals_multiread_complex():
    """Multi-read ops overlapping a single-read group: (C, B/C, A).

    A and B are read by the same ops.
    """
    inputs = [dummies.Signal(label=str(i)) for i in range(10)]
    pair_reads = tuple(dummies.Op(reads=[inputs[k], inputs[5 + k]])
                       for k in range(3))
    single_reads = tuple(dummies.Op(reads=[s]) for s in inputs[5:])

    sigs, new_plan = order_signals([pair_reads, single_reads])
    for chunk in (inputs[:3], inputs[5:], inputs[5:8]):
        assert contiguous(chunk, sigs)
    assert ordered(new_plan[0], sigs)
    assert ordered(new_plan[1], sigs)
コード例 #21
0
def test_order_signals_partial2():
    """Chained partial overlaps (A, A/B, B/C, C) can all be ordered."""
    inputs = [dummies.Signal(label=str(i)) for i in range(10)]
    windows = [inputs[0:5], inputs[2:6], inputs[5:8]]
    plan = [tuple(dummies.Op(reads=[s]) for s in win) for win in windows]

    sigs, new_plan = order_signals(plan)
    assert contiguous(inputs[:5], sigs)
    assert contiguous(inputs[5:8], sigs)
    assert contiguous(inputs[2:6], sigs)
    for group_idx in range(3):
        assert ordered(new_plan[group_idx], sigs)
コード例 #22
0
def test_order_signals_neuron_states():
    """Neuron state signals are ordered as if they were reads."""
    inputs = [dummies.Signal(label=str(i)) for i in range(10)]
    even_states = inputs[2::2]
    odd_states = inputs[3::2]

    def _neuron_group(states):
        return tuple(SimNeurons(None, inputs[0], inputs[1], states=[s])
                     for s in states)

    plan = [_neuron_group(even_states), _neuron_group(odd_states)]
    sigs, new_plan = order_signals(plan)

    assert contiguous(even_states, sigs)
    assert contiguous(odd_states, sigs)
    # block 0 is a single signal and is trivially ordered, so only block 1
    # is informative
    for group in (new_plan[0], new_plan[1]):
        assert ordered(group, sigs, block=1)
コード例 #23
0
def test_order_signals_partial_unsatisfiable():
    """Greedy ordering of (A, A/B, A/C, B) leaves B unsatisfied.

    After A, the orderer picks A/B because B is the next largest block.
    Picking A/C instead would satisfy everything, but it is unclear how the
    algorithm could know that in advance.
    """
    inputs = [dummies.Signal(label=str(i)) for i in range(10)]
    big, tail, small = inputs[:7], inputs[5:], inputs[:3]
    plan = [tuple(dummies.Op(reads=[s]) for s in blk)
            for blk in (big, tail, small)]

    sigs, new_plan = order_signals(plan)
    assert contiguous(big, sigs)
    assert not contiguous(tail, sigs)
    assert contiguous(small, sigs)
    assert ordered(new_plan[0], sigs)
    assert ordered(new_plan[2], sigs)
コード例 #24
0
def test_order_signals_duplicates():
    """Read blocks that repeat a signal are contiguous but can't be ordered."""
    s0, s1, s2, s3 = [dummies.Signal(label=str(i)) for i in range(4)]

    def _group(repeated, single):
        # two ops read the same signal; a third reads a different one
        return (dummies.Op(reads=[repeated]), dummies.Op(reads=[repeated]),
                dummies.Op(reads=[single]))

    plan = [_group(s0, s2), _group(s1, s3)]
    sigs, new_plan = order_signals(plan)
    assert contiguous([s0, s2], sigs)
    assert contiguous([s1, s3], sigs)

    # a block containing a duplicated read can never be in increasing order
    assert not ordered(new_plan[0], sigs)
    assert not ordered(new_plan[1], sigs)
コード例 #25
0
def test_order_signals_multiread_complex2():
    """Multi-read ops with a chained overlap: (B, B/A, A, A/C, C).

    A and B are read by the same ops.
    """
    inputs = [dummies.Signal(label=str(i)) for i in range(10)]
    pair_group = tuple(dummies.Op(reads=[inputs[k + 2], inputs[k]])
                       for k in range(4))
    single_group = tuple(dummies.Op(reads=[s]) for s in inputs[5:8])

    sigs, new_plan = order_signals([pair_group, single_group])

    assert contiguous(inputs[5:8], sigs)
    assert ordered(new_plan[1], sigs)

    # TODO: both blocks of the first group could always be ordered, but that
    # requires knowing which of the two equally sized blocks should have
    # priority, and it is unclear how to determine that.
    assert contiguous(inputs[:4], sigs) or contiguous(inputs[2:6], sigs)
    assert any(ordered(new_plan[0], sigs, block=b) for b in (0, 1))
コード例 #26
0
def test_order_signals_multiread_unsatisfiable():
    """(A, A/C, B, B/D): block C conflicts with the prioritized block A.

    A is prioritized because it belongs to the larger op group. As a result,
    the second group can have only one of its two read blocks ordered.
    """
    inputs = [dummies.Signal(label=str(i)) for i in range(10)]
    plan = [
        tuple(dummies.Op(reads=[inputs[k], inputs[5 + k]]) for k in range(5)),
        tuple(dummies.Op(reads=[inputs[1 - k], inputs[5 + k]])
              for k in range(2)),
    ]

    sigs, new_plan = order_signals(plan)
    for chunk in (inputs[:5], inputs[5:], inputs[:2], inputs[5:7]):
        assert contiguous(chunk, sigs)
    assert ordered(new_plan[0], sigs)
    assert any(ordered(new_plan[1], sigs, block=b) for b in (0, 1))
    assert not ordered(new_plan[1], sigs)
コード例 #27
0
def test_build(trainable, rng):
    """Building a TensorGraph creates one weight per op group.

    Each op group below reads signals that share an initialization pattern
    (zeros, ones, random). Each resulting weight is checked against the
    signals' initial values. ``trainable`` selects whether the weights are
    trainable. Non-trainable weights carry a leading axis equal to the
    minibatch size (16 here).
    """
    sigs = [
        dummies.Signal(shape=(2, 1),
                       dtype="float32",
                       initial_value=0,
                       trainable=trainable),
        dummies.Signal(
            shape=(3, 1),
            dtype="float32",
            initial_value=np.zeros((3, 1)),
            trainable=trainable,
        ),
        dummies.Signal(shape=(4, 1),
                       dtype="float32",
                       initial_value=1,
                       trainable=trainable),
        dummies.Signal(
            shape=(5, 1),
            dtype="float32",
            initial_value=np.ones((5, 1)),
            trainable=trainable,
        ),
        dummies.Signal(
            shape=(6, 1),
            dtype="float32",
            initial_value=rng.uniform(size=(6, 1)),
            trainable=trainable,
        ),
        dummies.Signal(
            shape=(7, 1),
            dtype="float32",
            initial_value=rng.uniform(size=(7, 1)),
            trainable=trainable,
        ),
    ]

    plan = [
        tuple(dummies.Op(reads=[x]) for x in sigs[:2]),
        tuple(dummies.Op(reads=[x]) for x in sigs[2:4]),
        tuple(dummies.Op(reads=[x]) for x in sigs[4:]),
    ]

    graph = dummies.TensorGraph(plan=plan, dtype="float32", minibatch_size=16)
    graph.create_signals(sigs)
    graph.build()

    if trainable:
        assert len(graph.trainable_weights) == 3
        assert len(graph.non_trainable_weights) == 0
    else:
        assert len(graph.trainable_weights) == 0
        assert len(graph.non_trainable_weights) == 3

    # The conditional expressions must be parenthesized.
    # ``a == b if c else d`` parses as ``(a == b) if c else d``, so without
    # parentheses the non-trainable branch would assert a (truthy) tuple and
    # never check the shape at all.
    init0 = graph.weights[0].numpy()
    assert init0.shape == ((5, 1) if trainable else (16, 5, 1))
    assert np.allclose(init0, 0)

    init1 = graph.weights[1].numpy()
    assert init1.shape == ((9, 1) if trainable else (16, 9, 1))
    assert np.allclose(init1, 1)

    init2 = graph.weights[2].numpy()
    if trainable:
        assert init2.shape == (13, 1)
        assert np.allclose(init2[:6], sigs[4].initial_value)
        assert np.allclose(init2[6:], sigs[5].initial_value)
    else:
        assert init2.shape == (16, 13, 1)
        assert np.allclose(init2[:, :6], sigs[4].initial_value)
        assert np.allclose(init2[:, 6:], sigs[5].initial_value)
コード例 #28
0
def test_create_signals():
    """create_signals groups signals into base arrays by dtype, shape and
    trainability, normalizing float/int precision."""
    # check that floats/ints get split into different arrays
    sigs = [
        dummies.Signal(dtype=np.float32),
        dummies.Signal(dtype=np.float32),
        dummies.Signal(dtype=np.int32),
        dummies.Signal(dtype=np.int32),
    ]
    plan = [tuple(dummies.Op(reads=[x]) for x in sigs)]
    graph = dummies.TensorGraph(plan, tf.float32, 10)
    graph.create_signals(sigs)

    assert graph.signals[sigs[0]].key == graph.signals[sigs[1]].key
    assert graph.signals[sigs[1]].key != graph.signals[sigs[2]].key
    assert graph.signals[sigs[2]].key == graph.signals[sigs[3]].key

    # check that floats all get converted to same precision and combined
    sigs = [
        dummies.Signal(dtype=np.float32),
        dummies.Signal(dtype=np.float32),
        dummies.Signal(dtype=np.float64),
        dummies.Signal(dtype=np.float64),
    ]
    plan = [tuple(dummies.Op(reads=[x]) for x in sigs)]
    graph = dummies.TensorGraph(plan, tf.float32, 10)
    graph.create_signals(sigs)
    assert np.all([graph.signals[x].dtype == "float32" for x in sigs])
    assert graph.signals[sigs[0]].key == graph.signals[sigs[1]].key
    assert graph.signals[sigs[1]].key == graph.signals[sigs[2]].key
    assert graph.signals[sigs[2]].key == graph.signals[sigs[3]].key

    # check that ints all get converted to same precision and combined
    sigs = [
        dummies.Signal(dtype=np.int32),
        dummies.Signal(dtype=np.int32),
        dummies.Signal(dtype=np.int64),
        dummies.Signal(dtype=np.int64),
    ]
    plan = [tuple(dummies.Op(reads=[x]) for x in sigs)]
    graph = dummies.TensorGraph(plan, tf.float32, 10)
    graph.create_signals(sigs)
    assert np.all([graph.signals[x].dtype == "int32" for x in sigs])
    assert graph.signals[sigs[0]].key == graph.signals[sigs[1]].key
    assert graph.signals[sigs[1]].key == graph.signals[sigs[2]].key
    assert graph.signals[sigs[2]].key == graph.signals[sigs[3]].key

    # check that different shapes go in different groups
    sigs = [
        dummies.Signal(shape=(10, )),
        dummies.Signal(shape=(5, )),
        dummies.Signal(shape=(10, 1)),
        dummies.Signal(shape=(5, 1)),
    ]
    plan = [tuple(dummies.Op(reads=[x]) for x in sigs)]
    graph = dummies.TensorGraph(plan, tf.float32, 10)
    graph.create_signals(sigs)
    assert graph.base_arrays_init["non_trainable"][graph.signals[
        sigs[0]].key][1] == [
            (10, 10),
            (10, 5),
        ]
    assert graph.base_arrays_init["non_trainable"][graph.signals[
        sigs[2]].key][1] == [
            (10, 10, 1),
            (10, 5, 1),
        ]
    assert graph.signals[sigs[0]].key == graph.signals[sigs[1]].key
    assert graph.signals[sigs[1]].key != graph.signals[sigs[2]].key
    assert graph.signals[sigs[2]].key == graph.signals[sigs[3]].key

    # check trainable
    sigs = [
        dummies.Signal(trainable=True),
        dummies.Signal(trainable=True),
        dummies.Signal(trainable=False),
        dummies.Signal(trainable=False),
    ]
    plan = [tuple(dummies.Op(reads=[x]) for x in sigs)]
    graph = dummies.TensorGraph(plan, tf.float32, 10)
    graph.create_signals(sigs)
    # trainable entries have no minibatch axis; non-trainable ones lead with
    # the minibatch size (10)
    assert graph.base_arrays_init["trainable"][graph.signals[
        sigs[0]].key][1] == [
            (1, ),
            (1, ),
        ]
    assert graph.base_arrays_init["non_trainable"][graph.signals[
        sigs[2]].key][1] == [
            (10, 1),
            (10, 1),
        ]
    assert graph.signals[sigs[0]].key == graph.signals[sigs[1]].key
    assert graph.signals[sigs[1]].key != graph.signals[sigs[2]].key
    assert graph.signals[sigs[2]].key == graph.signals[sigs[3]].key

    # check that scalars get upsized
    sigs = [dummies.Signal(shape=()), dummies.Signal(shape=(4, ))]
    plan = [tuple(dummies.Op(reads=[x]) for x in sigs)]
    graph = dummies.TensorGraph(plan, tf.float32, 10)
    graph.create_signals(sigs)
    assert list(graph.base_arrays_init["non_trainable"].values())[0][1] == [
        (10, 1),
        (10, 4),
    ]

    # check that boolean signals are handled correctly.
    # ``np.bool`` was only an alias for the builtin ``bool``; it was
    # deprecated in NumPy 1.20 and removed in 1.24, so use ``bool`` directly.
    sigs = [dummies.Signal(dtype=bool, shape=())]
    plan = [(dummies.Op(reads=sigs), )]
    graph = dummies.TensorGraph(plan, tf.float32, 1)
    graph.create_signals(sigs)
    assert list(
        graph.base_arrays_init["non_trainable"].values())[0][2] == "bool"
コード例 #29
0
def test_remove_identity_muls(Op):
    """remove_identity_muls replaces multiplications by identity with Copies.

    ``Op`` is the multiply operator class under test (the body branches on
    ``DotInc`` and ``ElementwiseInc``). Multiplications by identity are only
    removed when the identity input is constant: not trainable, not a Node
    input, and not written by any other op.
    """
    # identity multipliers: scalar 1.0 plus an identity matrix (DotInc) or a
    # vector of ones (elementwise)
    As = [1.0, np.diag(np.ones(3)) if Op == DotInc else np.ones(3)]

    # check that identity input signals get removed
    for A in As:
        x = dummies.Signal(shape=(1,) if isinstance(A, float) else A.shape[:1])
        y = dummies.Signal(shape=(1,) if isinstance(A, float) else A.shape[:1])
        a = Signal(A)
        a.trainable = False
        operators = [Op(a, x, y)]
        new_operators = remove_identity_muls(operators)
        assert len(new_operators) == 1
        new_op = new_operators[0]
        assert isinstance(new_op, Copy)
        assert new_op.src is x
        assert new_op.dst is y
        assert new_op.inc

    # check that identity x gets removed for elementwiseinc
    if Op == ElementwiseInc:
        a = dummies.Signal()
        x = dummies.Signal(initial_value=1)
        y = dummies.Signal()
        operators = [Op(a, x, y)]
        new_operators = remove_identity_muls(operators)
        # check the *output* list; asserting on ``operators`` (the input)
        # would be trivially true and hide extra/missing ops
        assert len(new_operators) == 1
        new_op = new_operators[0]
        assert isinstance(new_op, Copy)
        assert new_op.src is a
        assert new_op.dst is y
        assert new_op.inc

    # check that reset inputs get removed
    for A in As:
        x = dummies.Signal(shape=(1,) if isinstance(A, float) else A.shape[:1])
        y = dummies.Signal(shape=(1,) if isinstance(A, float) else A.shape[:1])
        a = dummies.Signal(shape=(1,) if isinstance(A, float) else A.shape)
        r = Reset(a)
        r.value = A
        operators = [Op(a, x, y), r]
        new_operators = remove_identity_muls(operators)
        assert len(new_operators) == 2
        assert new_operators[1:] == operators[1:]
        new_op = new_operators[0]
        assert isinstance(new_op, Copy)
        assert new_op.src is x
        assert new_op.dst is y
        assert new_op.inc

    # check that non-identity inputs don't get removed
    a = Signal(np.ones((3, 3)))
    a.trainable = False
    operators = [Op(a, dummies.Signal(shape=(3,)),
                    dummies.Signal(shape=(3,)))]
    new_operators = remove_identity_muls(operators)
    assert new_operators == operators

    # check that node inputs don't get removed
    x = dummies.Signal(label="<Node lorem ipsum")
    operators = [Op(x, dummies.Signal(), dummies.Signal())]
    new_operators = remove_identity_muls(operators)
    assert new_operators == operators

    # check that identity inputs + trainable don't get removed
    x = Signal(1.0)
    x.trainable = True
    operators = [Op(x, dummies.Signal(), dummies.Signal())]
    new_operators = remove_identity_muls(operators)
    assert new_operators == operators

    # check that updated input doesn't get removed
    x = dummies.Signal()
    operators = [Op(x, dummies.Signal(), dummies.Signal()),
                 dummies.Op(updates=[x])]
    new_operators = remove_identity_muls(operators)
    assert new_operators == operators

    # check that inc'd input doesn't get removed
    x = dummies.Signal()
    operators = [Op(x, dummies.Signal(), dummies.Signal()),
                 dummies.Op(incs=[x])]
    new_operators = remove_identity_muls(operators)
    assert new_operators == operators

    # check that set'd input doesn't get removed
    x = dummies.Signal()
    operators = [Op(x, dummies.Signal(), dummies.Signal()),
                 dummies.Op(sets=[x])]
    new_operators = remove_identity_muls(operators)
    assert new_operators == operators
コード例 #30
0
def test_remove_constant_copies():
    """remove_constant_copies rewrites Copies of constant sources as Resets.

    Sources that are Node inputs, trainable, or written by another op
    (updates/incs/sets) must be left untouched.
    """

    def _assert_untouched(ops):
        # the optimization must return these operator lists unchanged
        assert remove_constant_copies(ops) == ops

    # a Copy from a source nothing writes becomes a zero Reset of the target
    target = dummies.Signal()
    result = remove_constant_copies([Copy(dummies.Signal(), target)])
    assert len(result) == 1
    reset = result[0]
    assert isinstance(reset, Reset)
    assert reset.dst is target
    assert reset.value == 0

    # Node inputs are left alone
    node = dummies.Signal(label="<Node lorem ipsum")
    _assert_untouched([Copy(node, dummies.Signal())])

    # trainable inputs are left alone
    src = dummies.Signal()
    src.trainable = True
    _assert_untouched([Copy(src, dummies.Signal())])

    # sources that another op updates, increments or sets are left alone
    for kind in ("updates", "incs", "sets"):
        src = dummies.Signal()
        _assert_untouched([Copy(src, dummies.Signal()),
                           dummies.Op(**{kind: [src]})])

    # reads of the source or target don't prevent the rewrite
    src = dummies.Signal()
    dst = dummies.Signal()
    ops = [Copy(src, dst), dummies.Op(reads=[src]), dummies.Op(reads=[dst])]
    result = remove_constant_copies(ops)
    assert len(result) == 3
    assert result[1:] == ops[1:]
    assert isinstance(result[0], Reset)
    assert result[0].dst is dst
    assert result[0].value == 0

    # a Copy from a Reset signal inherits the Reset's value
    src = dummies.Signal()
    dst = dummies.Signal()
    result = remove_constant_copies([Copy(src, dst), Reset(src, 2)])
    assert len(result) == 1
    assert isinstance(result[0], Reset)
    assert result[0].dst is dst
    assert result[0].value == 2

    # dst_slice is preserved as a view into the original target
    src = dummies.Signal()
    dst = Signal(initial_value=[0, 0])
    result = remove_constant_copies(
        [Copy(src, dst, dst_slice=slice(1, 2)), Reset(src, 2)])
    assert len(result) == 1
    reset = result[0]
    assert isinstance(reset, Reset)
    assert reset.dst.shape == (1,)
    assert reset.dst.is_view
    assert reset.dst.elemoffset == 1
    assert reset.dst.base is dst
    assert reset.value == 2

    # an incrementing Copy becomes a ResetInc
    src = dummies.Signal()
    dst = dummies.Signal()
    result = remove_constant_copies([Copy(src, dst, inc=True), Reset(src, 2)])
    assert len(result) == 1
    reset = result[0]
    assert isinstance(reset, op_builders.ResetInc)
    assert reset.dst is dst
    assert reset.value == 2
    assert len(reset.incs) == 1
    assert len(reset.sets) == 0