Esempio n. 1
0
def test_signal_dict_scatter(minibatched):
    """Check ``SignalDict.scatter`` validation, update/increment modes and
    full-array assignment.

    Parameters
    ----------
    minibatched : bool
        Whether the signals carry a leading minibatch axis. The value is
        supplied by pytest; the parametrization is not visible here.
    """
    minibatch_size = 2
    var_size = 19
    signals = SignalDict(tf.float32, minibatch_size)

    def with_batch(*dims):
        # Prepend the minibatch axis only when the signals are minibatched.
        return (minibatch_size,) + dims if minibatched else dims

    batch_axis = (slice(None),) if minibatched else ()
    head = batch_axis + (slice(None, 4),)  # the first four elements
    tail = batch_axis + (slice(4, None),)  # everything after them

    const_key = object()
    variable_key = object()
    initial = np.random.random(with_batch(var_size)).astype(np.float32)
    signals.bases = {
        const_key: tf.constant(initial),
        variable_key: tf.Variable(initial),
    }

    # A float64 update into a float32 signal must be rejected.
    head_signal = signals.get_tensor_signal(
        [(0, 4)], const_key, tf.float32, (4,), minibatched
    )
    with pytest.raises(BuildError, match="wrong dtype"):
        signals.scatter(head_signal, tf.ones(with_batch(4), dtype=tf.float64))

    # Scattering into a base that is a tf.Variable must be rejected.
    variable_signal = signals.get_tensor_signal(
        [(0, 4)], variable_key, tf.float32, (4,), minibatched
    )
    with pytest.raises(BuildError, match="should not be a Variable"):
        signals.scatter(variable_signal, tf.ones(with_batch(4)))

    # Plain update: the first four entries become 1, the rest are untouched.
    signals.scatter(head_signal, tf.ones(with_batch(4)))
    result = signals.bases[const_key]
    assert np.allclose(result[head], 1)
    assert np.allclose(result[tail], initial[tail])
    assert signals.write_types["scatter_update"] == 1

    # Increment using a (2, 2)-shaped value for the (4,)-shaped signal,
    # which scatter has to reshape.
    signals.scatter(head_signal, tf.ones(with_batch(2, 2)), mode="inc")
    result = signals.bases[const_key]
    assert np.allclose(result[head], 2)
    assert np.allclose(result[tail], initial[tail])
    assert signals.write_types["scatter_add"] == 1

    # A signal spanning the whole base is recognised as a direct assignment:
    # the base is replaced by the update tensor itself.
    full_signal = signals.get_tensor_signal(
        [(0, var_size)], const_key, tf.float32, (var_size,), minibatched
    )
    replacement = tf.ones(with_batch(var_size))
    signals.scatter(full_signal, replacement)
    assert signals.bases[const_key] is replacement
    assert signals.write_types["assign"] == 1
Esempio n. 2
0
def test_signal_dict_scatter():
    """Check ``SignalDict.scatter`` using a graph-mode (TF1) session.

    Covers:

    * dtype validation;
    * the update and increment modes;
    * recognition of a contiguous full-array write (``Assign`` op);
    * recognition of a strided write (``ScatterUpdate`` op).

    The ``InteractiveSession`` is closed in a ``finally`` clause, so it is
    released even when an assertion fails.
    """
    minibatch_size = 1
    var_size = 19
    signals = SignalDict(tf.float32, minibatch_size)

    sess = tf.InteractiveSession()
    try:
        key = object()
        val = np.random.randn(var_size, minibatch_size)
        # NOTE(review): presumably wrapped in tf.assign so that running the
        # base sets the variable's value without a separate initializer.
        signals.bases = {key: tf.assign(tf.Variable(val, dtype=tf.float32),
                                        val)}

        x = signals.get_tensor_signal([0, 1, 2, 3], key, tf.float32, (4,),
                                      True)
        with pytest.raises(BuildError):
            # wrong dtype: float64 update into a float32 signal
            signals.scatter(x, tf.ones((4,), dtype=tf.float64))

        # update
        signals.scatter(x, tf.ones((4,)))
        y = sess.run(signals.bases[key])
        assert np.allclose(y[:4], 1)
        assert np.allclose(y[4:], val[4:])

        # increment, and reshaping val ((2, 2) value for a (4,) signal)
        signals.scatter(x, tf.ones((2, 2)), mode="inc")
        y = sess.run(signals.bases[key])
        assert np.allclose(y[:4], 2)
        assert np.allclose(y[4:], val[4:])

        # recognize assignment to full array
        x = signals.get_tensor_signal(np.arange(var_size), key, tf.float32,
                                      (var_size,), True)
        y = tf.ones((var_size, 1))
        signals.scatter(x, y)
        assert signals.bases[key].op.type == "Assign"

        # a strided index set is not a full-array write, so it must stay a
        # scatter (arange(0, 19, 2) has var_size // 2 + 1 == 10 elements)
        x = signals.get_tensor_signal(np.arange(0, var_size, 2), key,
                                      tf.float32, (var_size // 2 + 1,), True)
        y = tf.ones((var_size // 2 + 1, 1))
        signals.scatter(x, y)
        assert signals.bases[key].op.type == "ScatterUpdate"
    finally:
        # Always release the session, even if an assertion above failed.
        sess.close()