def test_get_tensor_signal():
    """Check ``SignalDict.get_tensor_signal`` construction and validation.

    Covers three things:

    1. A ``TensorSignal`` created without a ``signal`` argument has the
       requested indices/key/dtype/shape. Its ``minibatch_size`` matches the
       ``3`` given to the ``SignalDict`` constructor, it shares the dict's
       ``constant`` attribute, and it is *not* stored in the dict.
    2. Passing ``signal=sig`` registers the returned ``TensorSignal`` in the
       dict under ``sig``.
    3. A ``sig`` whose shape, size, or ``minibatched`` flag disagrees with
       the requested indices/shape/minibatch setting raises
       ``AssertionError``.
    """
    signals = SignalDict(tf.float32, 3)

    # check that tensor_signal is created correctly
    key = object()
    tensor_signal = signals.get_tensor_signal(
        (0,), key, np.float64, (3, 4), True)
    assert isinstance(tensor_signal, TensorSignal)
    assert np.array_equal(tensor_signal.indices, (0,))
    assert tensor_signal.key == key
    assert tensor_signal.dtype == np.float64
    assert tensor_signal.shape == (3, 4)
    assert tensor_signal.minibatch_size == 3
    assert tensor_signal.constant == signals.constant
    # no `signal=` given, so nothing should have been registered
    assert len(signals) == 0

    # check adding signal to sig_map
    sig = Signal(np.zeros(4))
    sig.minibatched = True
    tensor_signal = signals.get_tensor_signal(
        np.arange(4), key, np.float64, (2, 2), True, signal=sig)
    assert len(signals) == 1
    assert signals[sig] is tensor_signal
    assert next(iter(signals)) is sig
    assert next(iter(signals.values())) is tensor_signal

    # The mismatched Signals below are built *outside* `pytest.raises` so that
    # only `get_tensor_signal` may produce the expected AssertionError; an
    # assertion raised while constructing the Signal would otherwise make the
    # test pass for the wrong reason.

    # error if sig shape doesn't match indices
    sig = Signal(np.zeros((2, 2)))
    sig.minibatched = True
    with pytest.raises(AssertionError):
        signals.get_tensor_signal(
            np.arange(4), key, np.float64, (2, 2), True, signal=sig)

    # error if sig size doesn't match given shape
    sig = Signal(np.zeros(4))
    sig.minibatched = True
    with pytest.raises(AssertionError):
        signals.get_tensor_signal(
            np.arange(4), key, np.float64, (2, 3), True, signal=sig)

    # error if minibatched doesn't match
    sig = Signal(np.zeros(4))
    sig.minibatched = False
    with pytest.raises(AssertionError):
        signals.get_tensor_signal(
            np.arange(4), key, np.float64, (2, 2), True, signal=sig)