def test_signal_dict_scatter(minibatched):
    """Exercise ``SignalDict.scatter`` for update, increment and full assignment.

    Also checks that a dtype mismatch and a ``tf.Variable`` base are rejected
    with ``BuildError``, and that ``write_types`` counts each kind of write.

    Parameters
    ----------
    minibatched : bool
        Whether the signals carry a leading minibatch axis (supplied
        externally, presumably by a pytest parametrization not shown here).
    """
    minibatch_size = 2
    var_size = 19

    def batched(*dims):
        # prepend the minibatch axis only when the signals are minibatched
        return (minibatch_size,) + dims if minibatched else dims

    # leading index that keeps every minibatch row (empty when unbatched)
    lead = (slice(None),) if minibatched else ()
    head = lead + (slice(None, 4),)
    tail = lead + (slice(4, None),)

    signals = SignalDict(tf.float32, minibatch_size)
    const_key, var_key = object(), object()

    initial = np.random.random(batched(var_size)).astype(np.float32)
    update_shape = batched(4)
    signals.bases = {const_key: tf.constant(initial), var_key: tf.Variable(initial)}

    sig = signals.get_tensor_signal(
        [(0, 4)], const_key, tf.float32, (4,), minibatched
    )

    # an update whose dtype differs from the signal's must be rejected
    with pytest.raises(BuildError, match="wrong dtype"):
        signals.scatter(sig, tf.ones(update_shape, dtype=tf.float64))

    # scattering into a signal whose base is a tf.Variable must be rejected
    var_sig = signals.get_tensor_signal(
        [(0, 4)], var_key, tf.float32, (4,), minibatched
    )
    with pytest.raises(BuildError, match="should not be a Variable"):
        signals.scatter(var_sig, tf.ones(update_shape))

    # plain update: first four elements overwritten, the rest untouched
    signals.scatter(sig, tf.ones(update_shape))
    base = signals.bases[const_key]
    assert np.allclose(base[head], 1)
    assert np.allclose(base[tail], initial[tail])
    assert signals.write_types["scatter_update"] == 1

    # increment with an update whose shape differs from the signal's (2x2 vs 4)
    signals.scatter(sig, tf.ones(batched(2, 2)), mode="inc")
    base = signals.bases[const_key]
    assert np.allclose(base[head], 2)
    assert np.allclose(base[tail], initial[tail])
    assert signals.write_types["scatter_add"] == 1

    # writing the whole base should be treated as a direct assignment
    full_sig = signals.get_tensor_signal(
        [(0, var_size)], const_key, tf.float32, (var_size,), minibatched
    )
    replacement = tf.ones(batched(var_size))
    signals.scatter(full_sig, replacement)
    assert signals.bases[const_key] is replacement
    assert signals.write_types["assign"] == 1
def test_signal_dict_scatter():
    """Graph-mode (TF1 session) check of ``SignalDict.scatter``.

    Covers a dtype-mismatch rejection, a partial update, an increment with a
    reshaped update, and recognition of full and strided-full writes by the
    type of the op left in ``signals.bases``.

    NOTE(review): this file defines ``test_signal_dict_scatter`` twice. This
    later definition replaces the earlier one at module level, so pytest never
    collects the earlier test. Consider renaming one of them.
    """
    minibatch_size = 1
    var_size = 19
    signals = SignalDict(tf.float32, minibatch_size)
    # InteractiveSession installs itself as the default session (TF1 docs).
    # Close it in ``finally`` so a failing assertion cannot leak it into
    # subsequently run tests.
    sess = tf.InteractiveSession()
    try:
        key = object()
        val = np.random.randn(var_size, minibatch_size)
        signals.bases = {key: tf.assign(tf.Variable(val, dtype=tf.float32), val)}

        x = signals.get_tensor_signal([0, 1, 2, 3], key, tf.float32, (4,), True)
        with pytest.raises(BuildError):
            # wrong dtype
            signals.scatter(x, tf.ones((4,), dtype=tf.float64))

        # update
        signals.scatter(x, tf.ones((4,)))
        y = sess.run(signals.bases[key])
        assert np.allclose(y[:4], 1)
        assert np.allclose(y[4:], val[4:])

        # increment, and reshaping val
        signals.scatter(x, tf.ones((2, 2)), mode="inc")
        y = sess.run(signals.bases[key])
        assert np.allclose(y[:4], 2)
        assert np.allclose(y[4:], val[4:])

        # recognize assignment to full array
        x = signals.get_tensor_signal(
            np.arange(var_size), key, tf.float32, (var_size,), True
        )
        y = tf.ones((var_size, 1))
        signals.scatter(x, y)
        assert signals.bases[key].op.type == "Assign"

        # recognize assignment to strided full array; the strided indices
        # cover only every other element, so a scatter op is expected
        x = signals.get_tensor_signal(
            np.arange(0, var_size, 2), key, tf.float32, (var_size // 2 + 1,), True
        )
        y = tf.ones((var_size // 2 + 1, 1))
        signals.scatter(x, y)
        assert signals.bases[key].op.type == "ScatterUpdate"
    finally:
        sess.close()