Esempio n. 1
0
def test_tensor_signal_load_indices():
    """Check ``TensorSignal.load_indices`` and the ``as_slice`` shortcut.

    For every index list, the loaded TensorFlow index tensor must evaluate
    to the signal's ``indices``.  Contiguous and evenly strided index lists
    must additionally expose ``as_slice`` as a ``(start, stop, step)``
    triple, while an irregular index list must give ``as_slice is None``.
    """
    sess = tf.InteractiveSession()
    # Close in ``finally`` so that a failing assertion does not leave this
    # InteractiveSession installed as the default session for later tests.
    try:
        # (indices, expected (start, stop, step)); in both cases the expected
        # stop is the last index + 1
        sliceable_cases = [
            ([2, 3, 4, 5], (2, 6, 1)),
            ([2, 4, 6, 8], (2, 9, 2)),
        ]
        for indices, (exp_start, exp_stop, exp_step) in sliceable_cases:
            sig = TensorSignal(indices, object(), None, (4,), None)
            sig.load_indices()
            assert np.all(sig.tf_indices.eval() == sig.indices)
            start, stop, step = sess.run(sig.as_slice)
            assert start == exp_start
            assert stop == exp_stop
            assert step == exp_step

        # repeated indices: no slice representation is expected
        sig = TensorSignal([2, 2, 3, 3], object(), None, (4,), None)
        sig.load_indices()
        assert np.all(sig.tf_indices.eval() == sig.indices)
        assert sig.as_slice is None
    finally:
        sess.close()
Esempio n. 2
0
def test_signal_dict_scatter():
    """Check ``SignalDict.scatter`` validation, writes and op selection.

    Covers the ``BuildError`` cases, a default-mode write and a
    ``mode="inc"`` write (including reshaping of the written value), and the
    TensorFlow op type the base is replaced with when the target signal
    covers the full base array or a strided subset of it.
    """
    minibatch_size = 1
    var_size = 19
    signals = SignalDict(None, tf.float32, minibatch_size)

    sess = tf.InteractiveSession()
    # Close in ``finally`` so that a failing assertion does not leave this
    # InteractiveSession installed as the default session for later tests.
    try:
        key = object()
        val = np.random.randn(var_size, minibatch_size)
        signals.bases = {key: tf.assign(tf.Variable(val, dtype=tf.float32),
                                        val)}

        x = TensorSignal([0, 1, 2, 3], key, tf.float32, (4,), None)
        with pytest.raises(BuildError):
            # assigning to trainable variable
            signals.scatter(x, None)

        x.minibatch_size = minibatch_size
        with pytest.raises(BuildError):
            # indices not loaded
            signals.scatter(x, None)

        x.load_indices()
        with pytest.raises(BuildError):
            # wrong dtype
            signals.scatter(x, tf.ones((4,), dtype=tf.float64))

        # update (default mode): first 4 rows become 1, the rest untouched
        signals.scatter(x, tf.ones((4,)))
        y = sess.run(signals.bases[key])
        assert np.allclose(y[:4], 1)
        assert np.allclose(y[4:], val[4:])

        # increment, and reshaping val: (2, 2) is written into 4 rows
        signals.scatter(x, tf.ones((2, 2)), mode="inc")
        y = sess.run(signals.bases[key])
        assert np.allclose(y[:4], 2)
        assert np.allclose(y[4:], val[4:])

        # recognize assignment to full array
        x = TensorSignal(np.arange(var_size), key, tf.float32, (var_size,),
                         minibatch_size)
        x.load_indices()
        y = tf.ones((var_size, minibatch_size))
        signals.scatter(x, y)
        assert signals.bases[key].op.type == "Assign"

        # recognize assignment to strided full array; the minibatch size is
        # the same integer used for every other signal in this test
        x = TensorSignal(np.arange(0, var_size, 2), key, tf.float32,
                         (var_size // 2 + 1,), minibatch_size)
        x.load_indices()
        y = tf.ones((var_size // 2 + 1, minibatch_size))
        signals.scatter(x, y)
        assert signals.bases[key].op.type == "ScatterUpdate"
    finally:
        sess.close()
Esempio n. 3
0
def test_signal_dict_gather():
    """Check ``SignalDict.gather`` values, op selection and output shapes.

    Covers the unloaded-indices ``BuildError``, sliced and reshaped reads,
    forced and duplicate-index gathers, reads of the full base array and of a
    strided subset of it, and the presence/absence of the trailing minibatch
    dimension in the result shape.
    """
    minibatch_size = 1
    var_size = 19
    signals = SignalDict(None, tf.float32, minibatch_size)

    sess = tf.InteractiveSession()
    # Close in ``finally`` so that a failing assertion does not leave this
    # InteractiveSession installed as the default session for later tests.
    try:
        key = object()
        val = np.random.randn(var_size, minibatch_size)
        signals.bases = {key: tf.constant(val, dtype=tf.float32)}

        x = TensorSignal([0, 1, 2, 3], key, tf.float32, (4,), 1)
        with pytest.raises(BuildError):
            # indices not loaded
            signals.gather(x)

        # sliced read
        x.load_indices()
        assert np.allclose(sess.run(signals.gather(x)), val[:4])

        # read with reshape
        x = TensorSignal([0, 1, 2, 3], key, tf.float32, (2, 2), 1)
        x.load_indices()
        assert np.allclose(sess.run(signals.gather(x)),
                           val[:4].reshape((2, 2, minibatch_size)))

        # gather read (forced, even though the indices are contiguous)
        x = TensorSignal([0, 1, 2, 3], key, tf.float32, (4,), 1)
        x.load_indices()
        y = signals.gather(x, force_copy=True)
        assert "Gather" in y.op.type

        # duplicated indices: check the values AND that this read itself
        # (not the forced copy above) is built from a gather op
        x = TensorSignal([0, 0, 3, 3], key, tf.float32, (4,), 1)
        x.load_indices()
        y = signals.gather(x)
        assert np.allclose(sess.run(y), val[[0, 0, 3, 3]])
        assert "Gather" in y.op.type

        # reading from full array returns the base tensor itself
        x = TensorSignal(np.arange(var_size), key, tf.float32, (var_size,),
                         1)
        x.load_indices()
        y = signals.gather(x)
        assert y is signals.bases[key]

        # reading from strided full array uses a strided slice of the base
        x = TensorSignal(np.arange(0, var_size, 2), key, tf.float32,
                         (var_size // 2 + 1,), 1)
        x.load_indices()
        y = signals.gather(x)
        assert y.op.type == "StridedSlice"
        assert y.op.inputs[0] is signals.bases[key]

        # minibatch dimension is present only when minibatch_size is set
        x = TensorSignal([0, 1, 2, 3], key, tf.float32, (4,), 1)
        x.load_indices()
        assert signals.gather(x).get_shape() == (4, 1)

        x = TensorSignal([0, 1, 2, 3], key, tf.float32, (4,), None)
        x.load_indices()
        assert signals.gather(x).get_shape() == (4,)
    finally:
        sess.close()