def test_signal_dict_gather():
    """Check the values and op types produced by ``SignalDict.gather``.

    A single ``(var_size, minibatch_size)`` base array is registered under
    one key and then read back through several index patterns: a contiguous
    slice, a reshaped slice, a forced copy, repeated indices, the full
    array, and a strided view of the full array. The final section checks
    the static shape with and without the minibatch dimension.
    """
    minibatch_size = 1
    var_size = 19
    signals = SignalDict(tf.float32, minibatch_size)
    sess = tf.InteractiveSession()

    # Close the session even when an assertion fails, so a failing test does
    # not leave an open InteractiveSession behind for later tests.
    try:
        key = object()
        val = np.random.randn(var_size, minibatch_size)
        signals.bases = {key: tf.constant(val, dtype=tf.float32)}

        # sliced read
        x = signals.get_tensor_signal([0, 1, 2, 3], key, tf.float32, (4,), True)
        assert np.allclose(sess.run(signals.gather(x)), val[:4])

        # read with reshape
        x = signals.get_tensor_signal(
            [0, 1, 2, 3], key, tf.float32, (2, 2), True
        )
        assert np.allclose(
            sess.run(signals.gather(x)),
            val[:4].reshape((2, 2, minibatch_size)),
        )

        # gather read (contiguous indices, copy explicitly requested)
        x = signals.get_tensor_signal([0, 1, 2, 3], key, tf.float32, (4,), True)
        y = signals.gather(x, force_copy=True)
        assert "Gather" in y.op.type

        # gather read (repeated indices); bind the result so that the op-type
        # assertion inspects this read rather than the previous one
        x = signals.get_tensor_signal([0, 0, 3, 3], key, tf.float32, (4,), True)
        y = signals.gather(x)
        assert np.allclose(sess.run(y), val[[0, 0, 3, 3]])
        assert "Gather" in y.op.type

        # reading from full array: the base tensor itself is returned
        x = signals.get_tensor_signal(
            np.arange(var_size), key, tf.float32, (var_size,), True
        )
        y = signals.gather(x)
        assert y is signals.bases[key]

        # reading from strided full array
        x = signals.get_tensor_signal(
            np.arange(0, var_size, 2),
            key,
            tf.float32,
            (var_size // 2 + 1,),
            True,
        )
        y = signals.gather(x)
        assert y.op.type == "StridedSlice"
        assert y.op.inputs[0] is signals.bases[key]

        # minibatch dimension
        x = signals.get_tensor_signal([0, 1, 2, 3], key, tf.float32, (4,), True)
        assert signals.gather(x).get_shape() == (4, 1)

        x = signals.get_tensor_signal(
            [0, 1, 2, 3], key, tf.float32, (4,), False
        )
        assert signals.gather(x).get_shape() == (4,)
    finally:
        sess.close()
def test_signal_dict_gather(minibatched):
    """Exercise ``SignalDict.gather`` and the ``read_types`` counters.

    ``minibatched`` selects whether the base array carries a leading
    minibatch axis. After each read the test checks the returned values
    and/or the counter in ``signals.read_types`` that is expected to have
    been incremented ("strided_slice", "gather" or "identity").
    """
    minibatch_size = 3
    var_size = 19
    signals = SignalDict(tf.float32, minibatch_size)

    key = object()
    # Leading dimensions contributed by the minibatch axis, if present.
    batch_dims = (minibatch_size,) if minibatched else ()

    val = np.random.random(batch_dims + (var_size,)).astype(np.float32)
    # Ellipsis indexing selects along the last axis in both layouts.
    first_four = val[..., :4]

    signals.bases = {key: tf.constant(val, dtype=tf.float32)}

    def make_signal(slices, shape):
        # All reads in this test target the same key and dtype.
        return signals.get_tensor_signal(
            slices, key, tf.float32, shape, minibatched
        )

    # sliced read
    sig = make_signal([(0, 4)], (4,))
    assert np.allclose(signals.gather(sig), first_four)
    assert signals.read_types["strided_slice"] == 1

    # read with reshape
    sig = make_signal([(0, 4)], (2, 2))
    result = signals.gather(sig)
    expected_shape = batch_dims + (2, 2)
    assert result.shape == expected_shape
    assert np.allclose(result, first_four.reshape(expected_shape))
    assert signals.read_types["strided_slice"] == 2

    # gather read (contiguous slice, copy explicitly requested)
    sig = make_signal([(0, 4)], (4,))
    signals.gather(sig, force_copy=True)
    assert signals.read_types["gather"] == 1

    # gather read (repeated slices)
    sig = make_signal(((0, 1), (0, 1), (3, 4), (3, 4)), (4,))
    repeated = val[..., [0, 0, 3, 3]]
    assert np.allclose(signals.gather(sig), repeated)
    assert signals.read_types["gather"] == 2

    # reading from full array
    sig = make_signal([(0, var_size)], (var_size,))
    result = signals.gather(sig)
    assert result is signals.bases[key]
    assert signals.read_types["identity"] == 1

    # reading from strided full array (every second element)
    n_strided = var_size // 2 + 1
    every_other = tuple(
        (start, start + 1) for start in range(0, 2 * n_strided, 2)
    )
    sig = make_signal(every_other, (n_strided,))
    result = signals.gather(sig)
    assert result is not signals.bases[key]
    assert signals.read_types["gather"] == 3