Esempio n. 1
0
def test_get_tensor_signal():
    """Exercise ``SignalDict.get_tensor_signal``.

    Covers three things:

    * the returned ``TensorSignal`` carries the indices, key, dtype and shape
      that were passed in, a minibatch size of 3 (the second argument given
      to ``SignalDict`` here), and the same ``constant`` attribute as the
      ``SignalDict`` that produced it;
    * passing ``signal=`` registers the result in the ``SignalDict`` under
      that signal, whereas omitting it leaves the mapping empty;
    * a signal that is inconsistent with the requested indices, shape or
      minibatched flag is rejected with ``AssertionError``.
    """
    signals = SignalDict(tf.float32, 3)

    # check that tensor_signal is created correctly
    key = object()
    tensor_signal = signals.get_tensor_signal(
        (0,), key, np.float64, (3, 4), True)

    assert isinstance(tensor_signal, TensorSignal)
    assert np.array_equal(tensor_signal.indices, (0,))
    assert tensor_signal.key == key
    assert tensor_signal.dtype == np.float64
    assert tensor_signal.shape == (3, 4)
    assert tensor_signal.minibatch_size == 3
    assert tensor_signal.constant == signals.constant
    # no ``signal`` argument was given, so nothing should have been registered
    assert len(signals) == 0

    # check adding signal to sig_map
    sig = Signal(np.zeros(4))
    sig.minibatched = True
    tensor_signal = signals.get_tensor_signal(
        np.arange(4), key, np.float64, (2, 2), True, signal=sig)
    assert len(signals) == 1
    assert signals[sig] is tensor_signal
    assert next(iter(signals)) is sig
    assert next(iter(signals.values())) is tensor_signal

    # In the negative cases below the signal setup is kept outside of the
    # ``pytest.raises`` block, so that only the ``get_tensor_signal`` call
    # itself can satisfy the expected AssertionError (an error raised while
    # building the fixture must fail the test, not pass it).

    # error if sig shape doesn't match indices
    sig = Signal(np.zeros((2, 2)))
    sig.minibatched = True
    with pytest.raises(AssertionError):
        signals.get_tensor_signal(
            np.arange(4), key, np.float64, (2, 2), True, signal=sig)

    # error if sig size doesn't match given shape
    sig = Signal(np.zeros(4))
    sig.minibatched = True
    with pytest.raises(AssertionError):
        signals.get_tensor_signal(
            np.arange(4), key, np.float64, (2, 3), True, signal=sig)

    # error if minibatched doesn't match
    sig = Signal(np.zeros(4))
    sig.minibatched = False
    with pytest.raises(AssertionError):
        signals.get_tensor_signal(
            np.arange(4), key, np.float64, (2, 2), True, signal=sig)