예제 #1
0
def test_create_signals_partition():
    """Signals should be split into separate base arrays according to the plan.

    In each scenario below, four signals are expected to form two groups:
    signals 0 and 1 share a base-array key, signals 2 and 3 share another,
    and the two keys differ. A final scenario checks that ``Reset`` ops do
    not merge signals into a common base array.
    """

    def make_graph(ops_plan, signals):
        # build a small graph (float32, size-10 third argument) and allocate
        # base arrays for ``signals``
        tensor_graph = dummies.TensorGraph(ops_plan, tf.float32, 10)
        tensor_graph.create_signals(signals)
        return tensor_graph

    def check_two_groups(tensor_graph, signals):
        # keys are looked up lazily so the comparisons happen in the same
        # order as a flat sequence of asserts would perform them
        def key_of(idx):
            return tensor_graph.signals[signals[idx]].key

        assert key_of(0) == key_of(1)
        assert key_of(1) != key_of(2)
        assert key_of(2) == key_of(3)

    sigs = [dummies.Signal() for _ in range(4)]
    low, high = sigs[:2], sigs[2:]

    # check that signals are partitioned based on plan
    check_two_groups(
        make_graph(
            [tuple(dummies.Op(reads=[s]) for s in half) for half in (low, high)],
            sigs,
        ),
        sigs,
    )

    # check that signals are partitioned for different read blocks
    check_two_groups(
        make_graph(
            [tuple(dummies.Op(reads=[a, b]) for a, b in zip(low, high))],
            sigs,
        ),
        sigs,
    )

    # check that signals are partitioned for different sig types
    check_two_groups(
        make_graph(
            [tuple(dummies.Op(reads=[a], sets=[b]) for a, b in zip(low, high))],
            sigs,
        ),
        sigs,
    )

    # check that resets are ignored: each reset signal gets its own
    # non-trainable base array
    reset_sigs = [dummies.Signal() for _ in range(4)]
    reset_graph = make_graph([tuple(Reset(s) for s in reset_sigs)], reset_sigs)
    assert len(reset_graph.base_arrays_init["non_trainable"]) == 4
예제 #2
0
def test_create_signals_views():
    """Views and their bases should resolve into one shared base array.

    Two (2, 2) views over length-4 bases are read by the plan, but only the
    two bases are handed to ``create_signals``. The two bases are expected to
    be concatenated into a single (8, 10) initial array. Each view is
    expected to cover the same indices as the base it views.
    """
    views = [
        dummies.Signal(shape=(2, 2), base_shape=(4,)),
        dummies.Signal(shape=(2, 2), base_shape=(4,)),
    ]
    bases = [v.base for v in views]
    sigs = views + bases

    graph = dummies.TensorGraph(
        [tuple(dummies.Op(reads=[s]) for s in sigs)], tf.float32, 10)
    graph.create_signals(bases)

    first_init = list(graph.base_arrays_init.values())[0]
    assert first_init[0].shape == (8, 10)

    # every signal (views and bases alike) lands in the same base array
    for left, right in zip(sigs, sigs[1:]):
        assert graph.signals[left].key == graph.signals[right].key

    # the first view/base pair occupies rows 0-3, the second rows 4-7
    assert np.all(graph.signals[sigs[0]].indices == (0, 1, 2, 3))
    assert np.all(graph.signals[sigs[1]].indices == (4, 5, 6, 7))
    for view, base in zip(views, bases):
        assert np.all(graph.signals[view].indices ==
                      graph.signals[base].indices)
예제 #3
0
def test_create_signals_views():
    """Views and their bases should resolve into one shared base array.

    Two (2, 2) views over length-4 bases are read by the plan, but only the
    bases are passed to ``create_signals``. The non-trainable base array is
    expected to hold two (10, 4) entries. Each view is expected to map to
    the same slice as the base it views.
    """
    views = [
        dummies.Signal(shape=(2, 2), base_shape=(4, )),
        dummies.Signal(shape=(2, 2), base_shape=(4, )),
    ]
    bases = [view.base for view in views]
    sigs = views + bases

    graph = dummies.TensorGraph(
        [tuple(dummies.Op(reads=[s]) for s in sigs)], tf.float32, 10)
    graph.create_signals(bases)

    # one entry per base, each with a leading dimension of 10 (the value
    # passed as the third TensorGraph argument)
    first_init = list(graph.base_arrays_init["non_trainable"].values())[0]
    assert first_init[1] == [(10, 4), (10, 4)]

    # every signal (views and bases alike) lands in the same base array
    for left, right in zip(sigs, sigs[1:]):
        assert graph.signals[left].key == graph.signals[right].key

    # first view/base pair covers [0, 4), the second covers [4, 8)
    assert graph.signals[sigs[0]].slices == ((0, 4), )
    assert graph.signals[sigs[1]].slices == ((4, 8), )
    for view, base in zip(views, bases):
        assert graph.signals[view].slices == graph.signals[base].slices
예제 #4
0
def test_build(trainable, rng):
    """Check that ``build`` creates one weight per planned group of signals.

    Six float32 signals are split by the plan into three groups of two:
    zeros (sizes 2 + 3), ones (sizes 4 + 5) and random values (sizes 6 + 7).
    Each group is expected to become one weight whose initial value is the
    concatenation of its signals' initial values.

    Parameters
    ----------
    trainable : bool
        Whether every signal is marked trainable. Per the assertions below,
        trainable weights have no minibatch dimension, while non-trainable
        weights carry a leading axis of size ``minibatch_size``.
    rng
        Random generator fixture; only its ``uniform(size=...)`` method is
        used, to create the random initial values.
    """
    minibatch_size = 16

    sigs = [
        dummies.Signal(shape=(2, 1),
                       dtype="float32",
                       initial_value=0,
                       trainable=trainable),
        dummies.Signal(
            shape=(3, 1),
            dtype="float32",
            initial_value=np.zeros((3, 1)),
            trainable=trainable,
        ),
        dummies.Signal(shape=(4, 1),
                       dtype="float32",
                       initial_value=1,
                       trainable=trainable),
        dummies.Signal(
            shape=(5, 1),
            dtype="float32",
            initial_value=np.ones((5, 1)),
            trainable=trainable,
        ),
        dummies.Signal(
            shape=(6, 1),
            dtype="float32",
            initial_value=rng.uniform(size=(6, 1)),
            trainable=trainable,
        ),
        dummies.Signal(
            shape=(7, 1),
            dtype="float32",
            initial_value=rng.uniform(size=(7, 1)),
            trainable=trainable,
        ),
    ]

    plan = [
        tuple(dummies.Op(reads=[x]) for x in sigs[:2]),
        tuple(dummies.Op(reads=[x]) for x in sigs[2:4]),
        tuple(dummies.Op(reads=[x]) for x in sigs[4:]),
    ]

    graph = dummies.TensorGraph(plan=plan,
                                dtype="float32",
                                minibatch_size=minibatch_size)
    graph.create_signals(sigs)
    graph.build()

    if trainable:
        assert len(graph.trainable_weights) == 3
        assert len(graph.non_trainable_weights) == 0
    else:
        assert len(graph.trainable_weights) == 0
        assert len(graph.non_trainable_weights) == 3

    def expected_shape(base_shape):
        # only non-trainable weights carry the leading minibatch axis
        if trainable:
            return base_shape
        return (minibatch_size, ) + base_shape

    # NOTE: the expected shape must be computed *before* the comparison.
    # ``assert a == b if c else d`` parses as ``assert (a == b) if c else d``,
    # which in the non-trainable case asserts a truthy tuple and never
    # checks the shape at all.
    init0 = graph.weights[0].numpy()
    assert init0.shape == expected_shape((5, 1))
    assert np.allclose(init0, 0)

    init1 = graph.weights[1].numpy()
    assert init1.shape == expected_shape((9, 1))
    assert np.allclose(init1, 1)

    init2 = graph.weights[2].numpy()
    assert init2.shape == expected_shape((13, 1))
    if trainable:
        assert np.allclose(init2[:6], sigs[4].initial_value)
        assert np.allclose(init2[6:], sigs[5].initial_value)
    else:
        assert np.allclose(init2[:, :6], sigs[4].initial_value)
        assert np.allclose(init2[:, 6:], sigs[5].initial_value)
예제 #5
0
def test_create_signals():
    """Check how ``create_signals`` groups signals into base arrays.

    Covers the following cases:

    - float vs int signals are split into separate arrays;
    - mixed float precisions are merged at the graph dtype (float32);
    - mixed int precisions are merged as int32;
    - different trailing shapes get different groups;
    - trainable and non-trainable signals are kept apart;
    - scalars are upsized to shape (1,);
    - boolean signals keep a "bool" dtype.
    """
    # check that floats/ints get split into different arrays

    sigs = [
        dummies.Signal(dtype=np.float32),
        dummies.Signal(dtype=np.float32),
        dummies.Signal(dtype=np.int32),
        dummies.Signal(dtype=np.int32),
    ]
    plan = [tuple(dummies.Op(reads=[x]) for x in sigs)]
    graph = dummies.TensorGraph(plan, tf.float32, 10)
    graph.create_signals(sigs)

    assert graph.signals[sigs[0]].key == graph.signals[sigs[1]].key
    assert graph.signals[sigs[1]].key != graph.signals[sigs[2]].key
    assert graph.signals[sigs[2]].key == graph.signals[sigs[3]].key

    # check that floats all get converted to same precision and combined
    sigs = [
        dummies.Signal(dtype=np.float32),
        dummies.Signal(dtype=np.float32),
        dummies.Signal(dtype=np.float64),
        dummies.Signal(dtype=np.float64),
    ]
    plan = [tuple(dummies.Op(reads=[x]) for x in sigs)]
    graph = dummies.TensorGraph(plan, tf.float32, 10)
    graph.create_signals(sigs)
    assert np.all([graph.signals[x].dtype == "float32" for x in sigs])
    assert graph.signals[sigs[0]].key == graph.signals[sigs[1]].key
    assert graph.signals[sigs[1]].key == graph.signals[sigs[2]].key
    assert graph.signals[sigs[2]].key == graph.signals[sigs[3]].key

    # check that ints all get converted to same precision and combined
    sigs = [
        dummies.Signal(dtype=np.int32),
        dummies.Signal(dtype=np.int32),
        dummies.Signal(dtype=np.int64),
        dummies.Signal(dtype=np.int64),
    ]
    plan = [tuple(dummies.Op(reads=[x]) for x in sigs)]
    graph = dummies.TensorGraph(plan, tf.float32, 10)
    graph.create_signals(sigs)
    assert np.all([graph.signals[x].dtype == "int32" for x in sigs])
    assert graph.signals[sigs[0]].key == graph.signals[sigs[1]].key
    assert graph.signals[sigs[1]].key == graph.signals[sigs[2]].key
    assert graph.signals[sigs[2]].key == graph.signals[sigs[3]].key

    # check that different shapes go in different groups
    sigs = [
        dummies.Signal(shape=(10, )),
        dummies.Signal(shape=(5, )),
        dummies.Signal(shape=(10, 1)),
        dummies.Signal(shape=(5, 1)),
    ]
    plan = [tuple(dummies.Op(reads=[x]) for x in sigs)]
    graph = dummies.TensorGraph(plan, tf.float32, 10)
    graph.create_signals(sigs)
    # non-trainable entries carry a leading dimension of 10 (the third
    # TensorGraph argument)
    assert graph.base_arrays_init["non_trainable"][graph.signals[
        sigs[0]].key][1] == [
            (10, 10),
            (10, 5),
        ]
    assert graph.base_arrays_init["non_trainable"][graph.signals[
        sigs[2]].key][1] == [
            (10, 10, 1),
            (10, 5, 1),
        ]
    assert graph.signals[sigs[0]].key == graph.signals[sigs[1]].key
    assert graph.signals[sigs[1]].key != graph.signals[sigs[2]].key
    assert graph.signals[sigs[2]].key == graph.signals[sigs[3]].key

    # check trainable
    sigs = [
        dummies.Signal(trainable=True),
        dummies.Signal(trainable=True),
        dummies.Signal(trainable=False),
        dummies.Signal(trainable=False),
    ]
    plan = [tuple(dummies.Op(reads=[x]) for x in sigs)]
    graph = dummies.TensorGraph(plan, tf.float32, 10)
    graph.create_signals(sigs)
    # trainable entries have no leading dimension; non-trainable ones do
    assert graph.base_arrays_init["trainable"][graph.signals[
        sigs[0]].key][1] == [
            (1, ),
            (1, ),
        ]
    assert graph.base_arrays_init["non_trainable"][graph.signals[
        sigs[2]].key][1] == [
            (10, 1),
            (10, 1),
        ]
    assert graph.signals[sigs[0]].key == graph.signals[sigs[1]].key
    assert graph.signals[sigs[1]].key != graph.signals[sigs[2]].key
    assert graph.signals[sigs[2]].key == graph.signals[sigs[3]].key

    # check that scalars get upsized
    sigs = [dummies.Signal(shape=()), dummies.Signal(shape=(4, ))]
    plan = [tuple(dummies.Op(reads=[x]) for x in sigs)]
    graph = dummies.TensorGraph(plan, tf.float32, 10)
    graph.create_signals(sigs)
    assert list(graph.base_arrays_init["non_trainable"].values())[0][1] == [
        (10, 1),
        (10, 4),
    ]

    # check that boolean signals are handled correctly.
    # ``np.bool`` was only a deprecated alias of the builtin ``bool`` and was
    # removed in NumPy 1.24, so the builtin is used directly.
    sigs = [dummies.Signal(dtype=bool, shape=())]
    plan = [(dummies.Op(reads=sigs), )]
    graph = dummies.TensorGraph(plan, tf.float32, 1)
    graph.create_signals(sigs)
    assert list(
        graph.base_arrays_init["non_trainable"].values())[0][2] == "bool"