def test_read_to_fit_memory_dangling_element(tmpdir_factory, output_dtype):
    """Check a CSV whose row count leaves one leftover row is read back intact.

    A 10x10 string array (all "0.0", with the diagonal set to "1".."10") is
    written to a CSV file and read back through a reader constructed with the
    value 3 (the batch size, so 10 rows leave a single dangling row). Column 0
    is used as the target; the remaining columns are the features.
    """
    csv_dir = tmpdir_factory.mktemp("ten_line_csv")
    csv_path = csv_dir.join("ten_lines.csv")

    expected = np.zeros((10, 10)).astype(str)
    # Put "1".."10" on the diagonal so that every row is distinguishable.
    diagonal = np.arange(expected.shape[0])
    expected[diagonal, diagonal] = (diagonal + 1).astype(str)
    np.savetxt(csv_path.strpath, expected, delimiter=",", newline="\n", fmt="%s")

    reader = _get_reader(csv_dir.strpath, 3)
    total_memory = psutil.virtual_memory().total
    features, target = _read_to_fit_memory(
        reader,
        total_memory,
        output_dtype=output_dtype,
        target_column_index=0,
    )

    expected_features = expected[:, 1:]
    expected_target = expected[:, 0]
    assert np.array_equal(expected_features, features)
    assert np.array_equal(expected_target, target)
# Example #2
# 0
def test_get_reader_incorrect_path():
    """Check that `_get_reader` raises RuntimeError for a source that doesn't exist."""
    missing_source = "incorrect"
    expect_failure = pytest.raises(RuntimeError)
    with expect_failure:
        _get_reader(source=missing_source, batch_size=100)
# Example #3
# 0
def test_get_reader_error_malformed_channel_cfg(cfg, expected_error):
    """Check that `_get_reader` raises `expected_error` for a malformed channel config.

    `cfg` is applied through `managed_env_var` for the duration of the call
    (presumably as environment variables -- confirm against that helper).
    """
    # The raises-context is entered first, so an error raised while the
    # config is being applied is also accepted, as with nested `with` blocks.
    with pytest.raises(expected_error), managed_env_var(cfg):
        _get_reader(source="abc", batch_size=1000)