def test_truncation(truncation):
    """Check that ``utils.minibatch_generator`` splits data along the time axis.

    ``truncation`` is presumably supplied by a ``pytest.mark.parametrize``
    decorator not visible here (``None`` or an int step count) -- confirm.

    The generator is run over a two-key dict of ``(2, 10)`` arrays with a
    minibatch size of 2 and no shuffling. Each yielded ``(offset, chunk)``
    must match the ``[:, offset:offset + duration]`` slice of the source
    data. Exactly one warning is expected when ``truncation == 3``
    (presumably because 10 steps do not split evenly into chunks of 3) and
    none otherwise.

    The test then passes an int step count instead of data. It checks that
    the yielded offsets advance by 3 and that the yielded lengths are
    3, 3, 3, 1.
    """
    # ``pytest.warns(None)`` was deprecated in pytest 7 and is rejected by
    # pytest 8; record warnings with the standard library instead.
    import warnings

    data = {"a": np.random.randn(2, 10), "b": np.random.randn(2, 10)}
    duration = 10 if truncation is None else truncation

    with warnings.catch_warnings(record=True) as w:
        # "always" so a warning already triggered elsewhere (and therefore
        # suppressed by the default once-per-location filter) is still
        # recorded here.
        warnings.simplefilter("always")
        for i, (o, d) in enumerate(
                utils.minibatch_generator(
                    data, 2, shuffle=False, truncation=truncation)):
            assert np.allclose(d["a"], data["a"][:, o:o + duration])
            assert np.allclose(d["b"], data["b"][:, o:o + duration])

    # ``i`` is the index of the last chunk, i.e. ceil(10 / duration) - 1.
    # pylint: disable=undefined-loop-variable
    assert i == 10 // duration - (10 % duration == 0)
    assert len(w) == (1 if truncation == 3 else 0)

    # check truncation with n_steps input: here ``d`` is expected to be the
    # chunk's step count rather than an array
    for i, (o, d) in enumerate(
            utils.minibatch_generator(10, None, truncation=3)):
        assert o == i * 3
        # three full chunks of 3 steps, then a final 1-step remainder
        if i < 3:
            assert d == 3
        else:
            assert d == 1
    # exactly four chunks (offsets 0, 3, 6, 9) must have been produced
    # pylint: disable=undefined-loop-variable
    assert i == 3
def test_minibatch_generator(shuffle):
    """Check that ``utils.minibatch_generator`` covers inputs/targets correctly.

    ``shuffle`` is presumably supplied by a ``pytest.mark.parametrize``
    decorator not visible here -- confirm.

    **Batch size 10** (divides the 100 rows evenly):
    - The concatenated minibatches must contain every row.
    - When shuffled, they must be out of order but equal to the originals
      after sorting.

    **Batch size 12:**
    - A ``UserWarning`` is expected.
    - Every minibatch must have 12 rows, and 96 rows are yielded in total.
    - Without shuffling, these must be the first 96 rows.
    - With shuffling, the sorted rows are expected to differ from the first
      96.
    """
    inputs = {"a": np.arange(100)[:, None]}
    targets = {"b": np.arange(1, 101)[:, None]}

    x_all = []
    y_all = []
    for _, x, y in utils.minibatch_generator(inputs, targets, 10,
                                             shuffle=shuffle):
        x_all += [x["a"]]
        y_all += [y["b"]]
    x_all = np.concatenate(x_all)
    y_all = np.concatenate(y_all)

    if shuffle:
        assert not np.allclose(x_all, inputs["a"])
        assert not np.allclose(y_all, targets["b"])
        x_all = np.sort(x_all, axis=0)
        y_all = np.sort(y_all, axis=0)

    assert np.allclose(x_all, inputs["a"])
    assert np.allclose(y_all, targets["b"])

    x_all = []
    y_all = []
    with pytest.warns(UserWarning):
        for _, x, y in utils.minibatch_generator(inputs, targets, 12,
                                                 shuffle=shuffle):
            assert x["a"].shape[0] == 12
            assert y["b"].shape[0] == 12
            x_all += [x["a"]]
            y_all += [y["b"]]
    # Sort along axis 0: the arrays are (N, 1) columns, so the default
    # axis=-1 would sort each length-1 row and leave the order unchanged.
    x_all = np.sort(np.concatenate(x_all), axis=0)
    y_all = np.sort(np.concatenate(y_all), axis=0)
    assert len(x_all) == 96
    assert len(y_all) == 96
    if shuffle:
        # a shuffled run should drop a random 4 rows, not the last 4
        assert not np.allclose(x_all, np.arange(96)[:, None])
        assert not np.allclose(y_all, np.arange(1, 97)[:, None])
    else:
        assert np.allclose(x_all, np.arange(96)[:, None])
        assert np.allclose(y_all, np.arange(1, 97)[:, None])
def test_minibatch_generator(shuffle):
    """Check that ``utils.minibatch_generator`` covers a data dict correctly.

    ``shuffle`` is presumably supplied by a ``pytest.mark.parametrize``
    decorator not visible here -- confirm.

    **Batch size 10** (divides the 100 rows evenly):
    - The concatenated minibatches must contain every row.
    - When shuffled, they must be out of order but equal to the originals
      after sorting.

    **Batch size 12:**
    - A ``UserWarning`` is expected.
    - Every minibatch must have 12 rows, and 96 rows are yielded in total.
    - Without shuffling, these must be the first 96 rows.
    - With shuffling, the sorted rows are expected to differ from the first
      96.
    """
    data = {"a": np.arange(100)[:, None], "b": np.arange(1, 101)[:, None]}

    x_all = []
    y_all = []
    for _, d in utils.minibatch_generator(data, 10, shuffle=shuffle):
        x_all += [d["a"]]
        y_all += [d["b"]]
    x_all = np.concatenate(x_all)
    y_all = np.concatenate(y_all)

    if shuffle:
        assert not np.allclose(x_all, data["a"])
        assert not np.allclose(y_all, data["b"])
        x_all = np.sort(x_all, axis=0)
        y_all = np.sort(y_all, axis=0)

    assert np.allclose(x_all, data["a"])
    assert np.allclose(y_all, data["b"])

    x_all = []
    y_all = []
    with pytest.warns(UserWarning):
        for _, d in utils.minibatch_generator(data, 12, shuffle=shuffle):
            assert d["a"].shape[0] == 12
            assert d["b"].shape[0] == 12
            x_all += [d["a"]]
            y_all += [d["b"]]
    # Sort along axis 0: the arrays are (N, 1) columns, so the default
    # axis=-1 would sort each length-1 row and leave the order unchanged.
    x_all = np.sort(np.concatenate(x_all), axis=0)
    y_all = np.sort(np.concatenate(y_all), axis=0)
    assert len(x_all) == 96
    assert len(y_all) == 96
    if shuffle:
        # a shuffled run should drop a random 4 rows, not the last 4
        assert not np.allclose(x_all, np.arange(96)[:, None])
        assert not np.allclose(y_all, np.arange(1, 97)[:, None])
    else:
        assert np.allclose(x_all, np.arange(96)[:, None])
        assert np.allclose(y_all, np.arange(1, 97)[:, None])
def test_truncation(truncation):
    """Check that ``utils.minibatch_generator`` splits data along the time axis.

    ``truncation`` is presumably supplied by a ``pytest.mark.parametrize``
    decorator not visible here (``None`` or an int step count) -- confirm.

    The generator is run over a two-key dict of ``(2, 10)`` arrays with a
    minibatch size of 2 and no shuffling. Each yielded ``(offset, chunk)``
    must match the ``[:, offset:offset + duration]`` slice of the source
    data. Exactly one warning is expected when ``truncation == 3``
    (presumably because 10 steps do not split evenly into chunks of 3) and
    none otherwise.

    The test then passes an int step count instead of data. It checks that
    the yielded offsets advance by 3 and that the yielded lengths are
    3, 3, 3, 1.
    """
    # ``pytest.warns(None)`` was deprecated in pytest 7 and is rejected by
    # pytest 8; record warnings with the standard library instead.
    import warnings

    data = {"a": np.random.randn(2, 10), "b": np.random.randn(2, 10)}
    duration = 10 if truncation is None else truncation

    with warnings.catch_warnings(record=True) as w:
        # "always" so a warning already triggered elsewhere (and therefore
        # suppressed by the default once-per-location filter) is still
        # recorded here.
        warnings.simplefilter("always")
        for i, (o, d) in enumerate(
                utils.minibatch_generator(
                    data, 2, shuffle=False, truncation=truncation)):
            assert np.allclose(d["a"], data["a"][:, o:o + duration])
            assert np.allclose(d["b"], data["b"][:, o:o + duration])

    # ``i`` is the index of the last chunk, i.e. ceil(10 / duration) - 1.
    # pylint: disable=undefined-loop-variable
    assert i == 10 // duration - (10 % duration == 0)
    assert len(w) == (1 if truncation == 3 else 0)

    # check truncation with n_steps input: here ``d`` is expected to be the
    # chunk's step count rather than an array
    for i, (o, d) in enumerate(
            utils.minibatch_generator(10, None, truncation=3)):
        assert o == i * 3
        # three full chunks of 3 steps, then a final 1-step remainder
        if i < 3:
            assert d == 3
        else:
            assert d == 1
    # exactly four chunks (offsets 0, 3, 6, 9) must have been produced
    # pylint: disable=undefined-loop-variable
    assert i == 3
def test_truncation(truncation):
    """Check that ``utils.minibatch_generator`` splits inputs/targets in time.

    ``truncation`` is presumably supplied by a ``pytest.mark.parametrize``
    decorator not visible here (``None`` or an int step count) -- confirm.

    The generator is run over ``(2, 10)`` input and target arrays with a
    minibatch size of 2 and no shuffling. Each yielded ``(offset, x, y)``
    must match the ``[:, offset:offset + duration]`` slices of the sources.
    Exactly one warning is expected when ``truncation == 3`` (presumably
    because 10 steps do not split evenly into chunks of 3) and none
    otherwise.
    """
    # ``pytest.warns(None)`` was deprecated in pytest 7 and is rejected by
    # pytest 8; record warnings with the standard library instead.
    import warnings

    inputs = {"a": np.random.randn(2, 10)}
    targets = {"b": np.random.randn(2, 10)}
    duration = 10 if truncation is None else truncation

    with warnings.catch_warnings(record=True) as w:
        # "always" so a warning already triggered elsewhere (and therefore
        # suppressed by the default once-per-location filter) is still
        # recorded here.
        warnings.simplefilter("always")
        for i, (o, x, y) in enumerate(
                utils.minibatch_generator(
                    inputs, targets, 2, shuffle=False,
                    truncation=truncation)):
            assert np.allclose(x["a"], inputs["a"][:, o:o + duration])
            assert np.allclose(y["b"], targets["b"][:, o:o + duration])

    # ``i`` is the index of the last chunk, i.e. ceil(10 / duration) - 1.
    # pylint: disable=undefined-loop-variable
    assert i == 10 // duration - (10 % duration == 0)
    assert len(w) == (1 if truncation == 3 else 0)