def test_tensor_signal_load_indices():
    """Check ``TensorSignal.load_indices`` and the ``as_slice`` property.

    For each index pattern, the loaded ``tf_indices`` must evaluate to the
    same values as ``indices``. Evenly spaced increasing indices are
    expected to expose ``as_slice`` as a ``(start, stop, step)`` triple;
    the repeated-index pattern is expected to give ``as_slice is None``.
    """
    # (indices, expected (start, stop, step) from as_slice, or None when the
    # indices are not expected to be expressible as a slice)
    cases = [
        ([2, 3, 4, 5], (2, 6, 1)),
        ([2, 4, 6, 8], (2, 9, 2)),
        ([2, 2, 3, 3], None),
    ]

    sess = tf.InteractiveSession()
    try:
        for indices, expected in cases:
            sig = TensorSignal(indices, object(), None, (4,), None)
            sig.load_indices()
            # tf_indices.eval() runs in the default (interactive) session
            assert np.all(sig.tf_indices.eval() == sig.indices)

            if expected is None:
                assert sig.as_slice is None
            else:
                start, stop, step = sess.run(sig.as_slice)
                assert (start, stop, step) == expected
    finally:
        # close even if an assertion fails, so the session does not leak
        # into later tests
        sess.close()
def test_signal_dict_scatter():
    """Check ``SignalDict.scatter`` error handling and the ops it builds.

    Covers: rejection of invalid writes with ``BuildError``, update and
    increment of a sub-range of the base array, and the op type produced
    when writing the whole base array (``Assign``) versus a strided subset
    (``ScatterUpdate``).
    """
    minibatch_size = 1
    var_size = 19
    signals = SignalDict(None, tf.float32, minibatch_size)

    sess = tf.InteractiveSession()
    try:
        key = object()
        val = np.random.randn(var_size, minibatch_size)
        signals.bases = {
            key: tf.assign(tf.Variable(val, dtype=tf.float32), val)}

        x = TensorSignal([0, 1, 2, 3], key, tf.float32, (4,), None)
        with pytest.raises(BuildError):
            # assigning to trainable variable (minibatch_size is None)
            signals.scatter(x, None)

        x.minibatch_size = minibatch_size
        with pytest.raises(BuildError):
            # indices not loaded
            signals.scatter(x, None)

        x.load_indices()
        with pytest.raises(BuildError):
            # wrong dtype
            signals.scatter(x, tf.ones((4,), dtype=tf.float64))

        # update
        signals.scatter(x, tf.ones((4,)))
        result = sess.run(signals.bases[key])
        assert np.allclose(result[:4], 1)
        assert np.allclose(result[4:], val[4:])

        # increment, and reshaping val
        signals.scatter(x, tf.ones((2, 2)), mode="inc")
        result = sess.run(signals.bases[key])
        assert np.allclose(result[:4], 2)
        assert np.allclose(result[4:], val[4:])

        # recognize assignment to full array
        x = TensorSignal(np.arange(var_size), key, tf.float32, (var_size,),
                         minibatch_size)
        x.load_indices()
        signals.scatter(x, tf.ones((var_size, minibatch_size)))
        assert signals.bases[key].op.type == "Assign"

        # recognize assignment to strided full array
        n_strided = var_size // 2 + 1
        x = TensorSignal(np.arange(0, var_size, 2), key, tf.float32,
                         (n_strided,), minibatch_size)
        x.load_indices()
        signals.scatter(x, tf.ones((n_strided, minibatch_size)))
        assert signals.bases[key].op.type == "ScatterUpdate"
    finally:
        # close even if an assertion fails, so the session does not leak
        # into later tests
        sess.close()
def test_signal_dict_gather():
    """Check ``SignalDict.gather`` errors, values, and the ops it builds.

    Covers: ``BuildError`` when indices are not loaded, sliced and reshaped
    reads, reads that must use a ``Gather`` op (forced copy and
    non-contiguous indices), full-array reads returning the base tensor
    itself, strided reads via ``StridedSlice``, and the presence/absence of
    the minibatch dimension in the result shape.
    """
    minibatch_size = 1
    var_size = 19
    signals = SignalDict(None, tf.float32, minibatch_size)

    sess = tf.InteractiveSession()
    try:
        key = object()
        val = np.random.randn(var_size, minibatch_size)
        signals.bases = {key: tf.constant(val, dtype=tf.float32)}

        x = TensorSignal([0, 1, 2, 3], key, tf.float32, (4,), minibatch_size)
        with pytest.raises(BuildError):
            # indices not loaded
            signals.gather(x)

        # sliced read
        x.load_indices()
        assert np.allclose(sess.run(signals.gather(x)), val[:4])

        # read with reshape
        x = TensorSignal([0, 1, 2, 3], key, tf.float32, (2, 2),
                         minibatch_size)
        x.load_indices()
        assert np.allclose(sess.run(signals.gather(x)),
                           val[:4].reshape((2, 2, minibatch_size)))

        # gather read (forced copy of contiguous indices)
        x = TensorSignal([0, 1, 2, 3], key, tf.float32, (4,), minibatch_size)
        x.load_indices()
        y = signals.gather(x, force_copy=True)
        assert "Gather" in y.op.type

        # gather read (non-contiguous indices); check the op of *this*
        # gather, not the forced-copy one above
        x = TensorSignal([0, 0, 3, 3], key, tf.float32, (4,), minibatch_size)
        x.load_indices()
        y = signals.gather(x)
        assert np.allclose(sess.run(y), val[[0, 0, 3, 3]])
        assert "Gather" in y.op.type

        # reading from full array
        x = TensorSignal(np.arange(var_size), key, tf.float32, (var_size,),
                         minibatch_size)
        x.load_indices()
        y = signals.gather(x)
        assert y is signals.bases[key]

        # reading from strided full array
        x = TensorSignal(np.arange(0, var_size, 2), key, tf.float32,
                         (var_size // 2 + 1,), minibatch_size)
        x.load_indices()
        y = signals.gather(x)
        assert y.op.type == "StridedSlice"
        assert y.op.inputs[0] is signals.bases[key]

        # minibatch dimension
        x = TensorSignal([0, 1, 2, 3], key, tf.float32, (4,), minibatch_size)
        x.load_indices()
        assert signals.gather(x).get_shape() == (4, minibatch_size)

        x = TensorSignal([0, 1, 2, 3], key, tf.float32, (4,), None)
        x.load_indices()
        assert signals.gather(x).get_shape() == (4,)
    finally:
        # close even if an assertion fails, so the session does not leak
        # into later tests
        sess.close()