Пример #1
0
def test_mean_agg_constructor():
    """Check a MeanAggregator built from ``output_dim`` alone.

    The layer must expose the requested output size and report no bias.
    Its serialised config must echo those settings. The default activation
    is expected to be reported as ``"relu"``.
    """
    aggregator = MeanAggregator(2)
    assert aggregator.output_dim == 2
    assert not aggregator.has_bias

    # The serialised config should reflect the constructor defaults.
    cfg = aggregator.get_config()
    assert cfg["output_dim"] == 2
    assert cfg["bias"] is False
    assert cfg["act"] == "relu"
Пример #2
0
def test_mean_agg_apply():
    """Apply a MeanAggregator to one head node that has two neighbours.

    The expected values assume the following (this is not shown in this file):
    - The aggregator projects the self features and the neighbour mean with
      all-ones kernels.
    - It concatenates the two parts in the split given by ``weight_dims``.
    - Its bias starts at zero.
    """
    aggregator = MeanAggregator(
        5, bias=True, act=lambda t: t, kernel_initializer="ones"
    )
    self_in = keras.Input(shape=(1, 2))
    neigh_in = keras.Input(shape=(1, 2, 2))
    output = aggregator([self_in, neigh_in])

    # The 5 output units are split 3 (self part) / 2 (neighbour part).
    assert aggregator.weight_dims == [3, 2]

    model = keras.Model(inputs=[self_in, neigh_in], outputs=output)
    self_feats = np.array([[[1, 1]]])
    neigh_feats = np.array([[[[2, 2], [3, 3]]]])
    predicted = model.predict([self_feats, neigh_feats])

    # Self part: 1 + 1 = 2.
    # Neighbour part: mean([2, 2], [3, 3]) = [2.5, 2.5], which sums to 5.
    want = np.array([[[2] * 3 + [5] * 2]])
    assert want == pytest.approx(predicted)
Пример #3
0
def test_mean_agg_zero_neighbours():
    """A neighbour input with zero samples must still produce a valid output.

    The expected value implies that, with no neighbours, all 4 output units
    come from the self features alone. Under an all-ones kernel and identity
    activation, each unit is 1 + 1 = 2.
    """
    aggregator = MeanAggregator(
        4, bias=False, act=lambda t: t, kernel_initializer="ones"
    )

    self_in = keras.Input(shape=(1, 2))
    empty_neigh_in = keras.Input(shape=(1, 0, 2))  # zero neighbours per head

    model = keras.Model(
        inputs=[self_in, empty_neigh_in],
        outputs=aggregator([self_in, empty_neigh_in]),
    )

    self_feats = np.array([[[1, 1]]])
    neigh_feats = np.zeros((1, 1, 0, 2))

    predicted = model.predict([self_feats, neigh_feats])
    want = np.full((1, 1, 4), 2)
    assert want == pytest.approx(predicted)
Пример #4
0
def test_mean_agg_apply_groups():
    """Apply a MeanAggregator to one head node with two neighbour groups.

    ``weight_dims`` splits the 11 output units 5/3/3 between the self
    features and the two neighbour groups. The expected values assume
    all-ones kernels, identity activation and a zero-initialised bias:

    - self: 1 + 1 = 2
    - group 1: mean([2, 2], [3, 3]) = [2.5, 2.5], which sums to 5
    - group 2: mean([5, 5], [4, 4]) = [4.5, 4.5], which sums to 9
    """
    agg = MeanAggregator(11, bias=True, act=lambda x: x, kernel_initializer="ones")
    inp1 = keras.Input(shape=(1, 2))
    inp2 = keras.Input(shape=(1, 2, 2))
    inp3 = keras.Input(shape=(1, 2, 2))
    out = agg([inp1, inp2, inp3])

    assert agg.weight_dims == [5, 3, 3]

    model = keras.Model(inputs=[inp1, inp2, inp3], outputs=out)
    x1 = np.array([[[1, 1]]])
    x2 = np.array([[[[2, 2], [3, 3]]]])
    x3 = np.array([[[[5, 5], [4, 4]]]])

    # The leftover debug print(actual) was removed; the assertion below
    # already reports the value on failure.
    actual = model.predict([x1, x2, x3])

    expected = np.array([[[2] * 5 + [5] * 3 + [9] * 3]])
    assert expected == pytest.approx(actual)
Пример #5
0
def test_mean_agg_constructor_1():
    """Check that explicit keyword arguments (output_dim, bias, act) are honoured."""
    aggregator = MeanAggregator(output_dim=4, bias=True, act=lambda v: v + 1)
    assert aggregator.output_dim == 4
    assert aggregator.has_bias
    # The stored activation behaves like the callable passed in.
    assert aggregator.act(2) == 3