Esempio n. 1
0
def test_signal_dict_gather():
    """Check that ``SignalDict.gather`` reads the right values via the right op.

    Covers contiguous (sliced) reads, reshaped reads, forced-copy and
    non-contiguous gather reads, reading the whole base (which must return
    the base tensor itself), strided reads (``StridedSlice`` on the base),
    and whether the minibatch dimension is present in the result shape.
    """
    minibatch_size = 1
    var_size = 19
    signals = SignalDict(tf.float32, minibatch_size)

    sess = tf.InteractiveSession()
    # close the session even if an assertion below fails
    try:
        key = object()
        val = np.random.randn(var_size, minibatch_size)
        signals.bases = {key: tf.constant(val, dtype=tf.float32)}

        x = signals.get_tensor_signal([0, 1, 2, 3], key, tf.float32, (4,),
                                      True)

        # sliced read
        assert np.allclose(sess.run(signals.gather(x)), val[:4])

        # read with reshape
        x = signals.get_tensor_signal([0, 1, 2, 3], key, tf.float32, (2, 2),
                                      True)
        assert np.allclose(sess.run(signals.gather(x)),
                           val[:4].reshape((2, 2, minibatch_size)))

        # gather read (contiguous indices, but a copy is forced)
        x = signals.get_tensor_signal([0, 1, 2, 3], key, tf.float32, (4,),
                                      True)
        y = signals.gather(x, force_copy=True)
        assert "Gather" in y.op.type

        # gather read (repeated, non-contiguous indices); inspect the op of
        # *this* read rather than re-checking the previous ``y``
        x = signals.get_tensor_signal([0, 0, 3, 3], key, tf.float32, (4,),
                                      True)
        y = signals.gather(x)
        assert np.allclose(sess.run(y), val[[0, 0, 3, 3]])
        assert "Gather" in y.op.type

        # reading from full array returns the base tensor itself
        x = signals.get_tensor_signal(np.arange(var_size), key, tf.float32,
                                      (var_size,), True)
        y = signals.gather(x)
        assert y is signals.bases[key]

        # reading from strided full array slices the base directly
        x = signals.get_tensor_signal(np.arange(0, var_size, 2), key,
                                      tf.float32, (var_size // 2 + 1,), True)
        y = signals.gather(x)
        assert y.op.type == "StridedSlice"
        assert y.op.inputs[0] is signals.bases[key]

        # minibatch dimension is appended only for minibatched signals
        x = signals.get_tensor_signal([0, 1, 2, 3], key, tf.float32, (4,),
                                      True)
        assert signals.gather(x).get_shape() == (4, 1)

        x = signals.get_tensor_signal([0, 1, 2, 3], key, tf.float32, (4,),
                                      False)
        assert signals.gather(x).get_shape() == (4,)
    finally:
        sess.close()
Esempio n. 2
0
def test_signal_dict_gather(minibatched):
    """Check ``SignalDict.gather`` values and its read-type bookkeeping.

    ``minibatched`` selects whether the base array has a leading minibatch
    axis of size ``minibatch_size``. After each read, ``signals.read_types``
    is checked to confirm which strategy was used (strided slice, gather, or
    identity). Each read's values are also compared against NumPy indexing
    of the same base array.
    """
    minibatch_size = 3
    var_size = 19
    signals = SignalDict(tf.float32, minibatch_size)

    key = object()
    val = np.random.random(
        (minibatch_size, var_size) if minibatched else (var_size,)
    ).astype(np.float32)
    gathered_val = val[:, :4] if minibatched else val[:4]
    signals.bases = {key: tf.constant(val, dtype=tf.float32)}

    x = signals.get_tensor_signal([(0, 4)], key, tf.float32, (4,), minibatched)

    # sliced read
    assert np.allclose(signals.gather(x), gathered_val)
    assert signals.read_types["strided_slice"] == 1

    # read with reshape
    x = signals.get_tensor_signal([(0, 4)], key, tf.float32, (2, 2), minibatched)
    y = signals.gather(x)
    shape = (minibatch_size, 2, 2) if minibatched else (2, 2)
    assert y.shape == shape
    assert np.allclose(y, gathered_val.reshape(shape))
    assert signals.read_types["strided_slice"] == 2

    # gather read (contiguous slice, but a copy is forced); the copy must
    # still hold the same values as the slice
    x = signals.get_tensor_signal([(0, 4)], key, tf.float32, (4,), minibatched)
    y = signals.gather(x, force_copy=True)
    assert np.allclose(y, gathered_val)
    assert signals.read_types["gather"] == 1

    # gather read with repeated, non-contiguous indices
    x = signals.get_tensor_signal(
        ((0, 1), (0, 1), (3, 4), (3, 4)), key, tf.float32, (4,), minibatched
    )
    assert np.allclose(
        signals.gather(x), val[:, [0, 0, 3, 3]] if minibatched else val[[0, 0, 3, 3]]
    )
    assert signals.read_types["gather"] == 2

    # reading from full array returns the base tensor itself
    x = signals.get_tensor_signal(
        [(0, var_size)], key, tf.float32, (var_size,), minibatched
    )
    y = signals.gather(x)
    assert y is signals.bases[key]
    assert signals.read_types["identity"] == 1

    # reading from strided full array (every other element, 0..var_size-1)
    x = signals.get_tensor_signal(
        tuple((i * 2, i * 2 + 1) for i in range(var_size // 2 + 1)),
        key,
        tf.float32,
        (var_size // 2 + 1,),
        minibatched,
    )
    y = signals.gather(x)
    assert y is not signals.bases[key]
    assert np.allclose(y, val[:, ::2] if minibatched else val[::2])
    assert signals.read_types["gather"] == 3