コード例 #1
0
def test_truncation(truncation):
    """Check the windows yielded by ``minibatch_generator`` under truncation.

    With ``shuffle=False`` every yielded ``(offset, batch)`` pair must equal
    the slice ``[:, offset:offset + duration]`` of the source arrays, where
    ``duration`` is ``truncation``, or the full 10 steps when ``truncation``
    is ``None``.

    Parameters
    ----------
    truncation : int or None
        Window length forwarded to the generator. It is supplied from outside
        this function (presumably a pytest parametrization -- confirm).
    """
    # Local import: the file's import block is not visible from this test.
    import warnings

    data = {"a": np.random.randn(2, 10), "b": np.random.randn(2, 10)}

    duration = 10 if truncation is None else truncation

    # ``pytest.warns(None)`` was deprecated in pytest 7 and removed in
    # pytest 8.  Recording through the standard library captures warnings
    # the same way without requiring that any warning is emitted.
    with warnings.catch_warnings(record=True) as w:
        warnings.simplefilter("always")
        for i, (o, d) in enumerate(
                utils.minibatch_generator(data,
                                          2,
                                          shuffle=False,
                                          truncation=truncation)):
            assert np.allclose(d["a"], data["a"][:, o:o + duration])
            assert np.allclose(d["b"], data["b"][:, o:o + duration])

    # Index of the last window, i.e. ceil(10 / duration) - 1.
    # pylint: disable=undefined-loop-variable
    assert i == 10 // duration - (10 % duration == 0)

    # Exactly one warning is expected for truncation == 3 and none otherwise
    # (presumably because 3 does not divide the 10 steps evenly -- confirm).
    assert len(w) == (1 if truncation == 3 else 0)

    # check truncation with n_steps input: given a step count instead of
    # arrays, the yielded pairs are expected to be (offset, window length).
    for i, (o,
            d) in enumerate(utils.minibatch_generator(10, None, truncation=3)):
        assert o == i * 3
        # Three full windows of 3 steps, then one remaining step.
        assert d == (3 if i < 3 else 1)
コード例 #2
0
def test_minibatch_generator(shuffle):
    """Exercise ``minibatch_generator`` with separate input/target dicts.

    Two passes are made over 100 items: one with a batch size of 10 and one
    with a batch size of 12, where a ``UserWarning`` is expected and only 96
    items are collected.

    Parameters
    ----------
    shuffle : bool
        Forwarded to the generator; supplied from outside this function.
    """
    inputs = {"a": np.arange(100)[:, None]}
    targets = {"b": np.arange(1, 101)[:, None]}

    # --- batch size 10 ---
    in_batches, tgt_batches = [], []
    batches = utils.minibatch_generator(inputs, targets, 10, shuffle=shuffle)
    for _, x, y in batches:
        in_batches.append(x["a"])
        tgt_batches.append(y["b"])

    seen_in = np.concatenate(in_batches)
    seen_tgt = np.concatenate(tgt_batches)

    if shuffle:
        # Shuffled output must differ from the source order ...
        assert not np.allclose(seen_in, inputs["a"])
        assert not np.allclose(seen_tgt, targets["b"])
        # ... and must contain the same items once sorted along the rows.
        seen_in = np.sort(seen_in, axis=0)
        seen_tgt = np.sort(seen_tgt, axis=0)

    assert np.allclose(seen_in, inputs["a"])
    assert np.allclose(seen_tgt, targets["b"])

    # --- batch size 12: a warning is expected and every batch has 12 rows ---
    in_batches, tgt_batches = [], []
    with pytest.warns(UserWarning):
        batches = utils.minibatch_generator(inputs,
                                            targets,
                                            12,
                                            shuffle=shuffle)
        for _, x, y in batches:
            assert x["a"].shape[0] == 12
            assert y["b"].shape[0] == 12
            in_batches.append(x["a"])
            tgt_batches.append(y["b"])

    # NOTE(review): np.sort defaults to axis=-1; if the batches are (12, 1)
    # columns this leaves the row order unchanged -- confirm that is intended.
    seen_in = np.sort(np.concatenate(in_batches))
    seen_tgt = np.sort(np.concatenate(tgt_batches))

    assert len(seen_in) == 96
    assert len(seen_tgt) == 96

    first_in = np.arange(96)[:, None]
    first_tgt = np.arange(1, 97)[:, None]
    # The collected items match the first 96 source rows only without
    # shuffling.
    in_order = not shuffle
    assert np.allclose(seen_in, first_in) == in_order
    assert np.allclose(seen_tgt, first_tgt) == in_order
コード例 #3
0
def test_minibatch_generator(shuffle):
    """Exercise ``minibatch_generator`` with a single data dict.

    Two passes are made over 100 items: one with a batch size of 10 and one
    with a batch size of 12, where a ``UserWarning`` is expected and only 96
    items are collected.

    Parameters
    ----------
    shuffle : bool
        Forwarded to the generator; supplied from outside this function.
    """
    data = {"a": np.arange(100)[:, None], "b": np.arange(1, 101)[:, None]}

    # --- batch size 10 ---
    a_batches, b_batches = [], []
    for _, batch in utils.minibatch_generator(data, 10, shuffle=shuffle):
        a_batches.append(batch["a"])
        b_batches.append(batch["b"])

    seen_a = np.concatenate(a_batches)
    seen_b = np.concatenate(b_batches)

    if shuffle:
        # Shuffled output must differ from the source order ...
        assert not np.allclose(seen_a, data["a"])
        assert not np.allclose(seen_b, data["b"])
        # ... and must contain the same items once sorted along the rows.
        seen_a = np.sort(seen_a, axis=0)
        seen_b = np.sort(seen_b, axis=0)

    assert np.allclose(seen_a, data["a"])
    assert np.allclose(seen_b, data["b"])

    # --- batch size 12: a warning is expected and every batch has 12 rows ---
    a_batches, b_batches = [], []
    with pytest.warns(UserWarning):
        for _, batch in utils.minibatch_generator(data, 12, shuffle=shuffle):
            assert batch["a"].shape[0] == 12
            assert batch["b"].shape[0] == 12
            a_batches.append(batch["a"])
            b_batches.append(batch["b"])

    # NOTE(review): np.sort defaults to axis=-1; if the batches are (12, 1)
    # columns this leaves the row order unchanged -- confirm that is intended.
    seen_a = np.sort(np.concatenate(a_batches))
    seen_b = np.sort(np.concatenate(b_batches))

    assert len(seen_a) == 96
    assert len(seen_b) == 96

    first_a = np.arange(96)[:, None]
    first_b = np.arange(1, 97)[:, None]
    # The collected items match the first 96 source rows only without
    # shuffling.
    in_order = not shuffle
    assert np.allclose(seen_a, first_a) == in_order
    assert np.allclose(seen_b, first_b) == in_order
コード例 #4
0
def test_minibatch_generator(shuffle):
    """Check that batches from ``minibatch_generator`` cover the data.

    The first pass uses a batch size of 10 over 100 items. The second uses a
    batch size of 12, must emit a ``UserWarning`` and collects 96 items.

    Parameters
    ----------
    shuffle : bool
        Forwarded to the generator; supplied from outside this function.
    """
    source_a = np.arange(100)[:, None]
    source_b = np.arange(1, 101)[:, None]
    data = {"a": source_a, "b": source_b}

    collected_a = []
    collected_b = []
    for _, chunk in utils.minibatch_generator(data, 10, shuffle=shuffle):
        collected_a.append(chunk["a"])
        collected_b.append(chunk["b"])

    stacked_a = np.concatenate(collected_a)
    stacked_b = np.concatenate(collected_b)

    if shuffle:
        # Order must have changed; sorting along the rows restores it.
        assert not np.allclose(stacked_a, data["a"])
        assert not np.allclose(stacked_b, data["b"])
        stacked_a = np.sort(stacked_a, axis=0)
        stacked_b = np.sort(stacked_b, axis=0)

    assert np.allclose(stacked_a, data["a"])
    assert np.allclose(stacked_b, data["b"])

    collected_a = []
    collected_b = []
    with pytest.warns(UserWarning):
        for _, chunk in utils.minibatch_generator(data, 12, shuffle=shuffle):
            # Every yielded batch is expected to hold exactly 12 rows.
            assert chunk["a"].shape[0] == 12
            assert chunk["b"].shape[0] == 12
            collected_a.append(chunk["a"])
            collected_b.append(chunk["b"])

    # NOTE(review): np.sort defaults to axis=-1; if the batches are (12, 1)
    # columns this leaves the row order unchanged -- confirm that is intended.
    stacked_a = np.sort(np.concatenate(collected_a))
    stacked_b = np.sort(np.concatenate(collected_b))

    assert len(stacked_a) == 96
    assert len(stacked_b) == 96

    head_a = np.arange(96)[:, None]
    head_b = np.arange(1, 97)[:, None]
    if not shuffle:
        # Without shuffling the first 96 source rows come back in order.
        assert np.allclose(stacked_a, head_a)
        assert np.allclose(stacked_b, head_b)
    else:
        assert not np.allclose(stacked_a, head_a)
        assert not np.allclose(stacked_b, head_b)
コード例 #5
0
def test_truncation(truncation):
    """Check the windows yielded by ``minibatch_generator`` under truncation.

    With ``shuffle=False`` every yielded ``(offset, batch)`` pair must equal
    the slice ``[:, offset:offset + duration]`` of the source arrays, where
    ``duration`` is ``truncation``, or the full 10 steps when ``truncation``
    is ``None``.

    Parameters
    ----------
    truncation : int or None
        Window length forwarded to the generator. It is supplied from outside
        this function (presumably a pytest parametrization -- confirm).
    """
    # Local import: the file's import block is not visible from this test.
    import warnings

    data = {"a": np.random.randn(2, 10), "b": np.random.randn(2, 10)}

    duration = 10 if truncation is None else truncation

    # ``pytest.warns(None)`` was deprecated in pytest 7 and removed in
    # pytest 8.  Recording through the standard library captures warnings
    # the same way without requiring that any warning is emitted.
    with warnings.catch_warnings(record=True) as w:
        warnings.simplefilter("always")
        for i, (o, d) in enumerate(utils.minibatch_generator(
                data, 2, shuffle=False, truncation=truncation)):
            assert np.allclose(d["a"], data["a"][:, o:o + duration])
            assert np.allclose(d["b"], data["b"][:, o:o + duration])

    # Index of the last window, i.e. ceil(10 / duration) - 1.
    # pylint: disable=undefined-loop-variable
    assert i == 10 // duration - (10 % duration == 0)

    # Exactly one warning is expected for truncation == 3 and none otherwise
    # (presumably because 3 does not divide the 10 steps evenly -- confirm).
    assert len(w) == (1 if truncation == 3 else 0)

    # check truncation with n_steps input: given a step count instead of
    # arrays, the yielded pairs are expected to be (offset, window length).
    for i, (o, d) in enumerate(utils.minibatch_generator(
            10, None, truncation=3)):
        assert o == i * 3
        # Three full windows of 3 steps, then one remaining step.
        assert d == (3 if i < 3 else 1)
コード例 #6
0
def test_truncation(truncation):
    """Check truncated windows for the separate inputs/targets call form.

    With ``shuffle=False`` every yielded ``(offset, inputs, targets)`` triple
    must equal the slice ``[:, offset:offset + duration]`` of the matching
    source array, where ``duration`` is ``truncation``, or the full 10 steps
    when ``truncation`` is ``None``.

    Parameters
    ----------
    truncation : int or None
        Window length forwarded to the generator. It is supplied from outside
        this function (presumably a pytest parametrization -- confirm).
    """
    # Local import: the file's import block is not visible from this test.
    import warnings

    inputs = {"a": np.random.randn(2, 10)}
    targets = {"b": np.random.randn(2, 10)}

    duration = 10 if truncation is None else truncation

    # ``pytest.warns(None)`` was deprecated in pytest 7 and removed in
    # pytest 8.  Recording through the standard library captures warnings
    # the same way without requiring that any warning is emitted.
    with warnings.catch_warnings(record=True) as w:
        warnings.simplefilter("always")
        for i, (o, x, y) in enumerate(
                utils.minibatch_generator(inputs,
                                          targets,
                                          2,
                                          shuffle=False,
                                          truncation=truncation)):
            assert np.allclose(x["a"], inputs["a"][:, o:o + duration])
            assert np.allclose(y["b"], targets["b"][:, o:o + duration])

    # Index of the last window, i.e. ceil(10 / duration) - 1.
    # pylint: disable=undefined-loop-variable
    assert i == 10 // duration - (10 % duration == 0)

    # Exactly one warning is expected for truncation == 3 and none otherwise
    # (presumably because 3 does not divide the 10 steps evenly -- confirm).
    assert len(w) == (1 if truncation == 3 else 0)