Ejemplo n.º 1
0
def _test_dedupe_column_names(tmpdir,
                              input_column_names: List[str],
                              input_data: List[int],
                              expected_column_names: List[str],
                              expected_data: List[int],
                              dedupe_column_names: bool = True,
                              **kwargs) -> None:
    """Round-trip a one-row CSV file through CsvReader and check the result.

    A ``test.csv`` file is written to ``tmpdir``. Its first line is
    ``input_column_names`` joined by commas, and its second line is
    ``input_data`` joined by commas. The file is read back with
    ``CsvParams(dedupe_column_names=dedupe_column_names, **kwargs)``. The
    schema's column names must equal ``expected_column_names``, and the
    feature values must equal ``expected_data``.
    """
    path = tmpdir.join("test.csv")
    lines = [','.join(input_column_names), ','.join(map(str, input_data))]
    path.write('\n'.join(lines))

    reader = mlio.CsvReader(
        mlio.DataReaderParams(dataset=[mlio.File(str(path))], batch_size=1),
        mlio.CsvParams(dedupe_column_names=dedupe_column_names, **kwargs))

    example = reader.read_example()

    actual_names = [descriptor.name
                    for descriptor in example.schema.descriptors]
    assert actual_names == expected_column_names

    values = np.array([as_numpy(column) for column in example]).squeeze()
    assert np.all(values == np.array(expected_data))
def _get_reader(source, batch_size):
    """Create a ``CsvReader`` over ``source``.

    Parameters
    ----------
    source : str or bytes
        Name of the SageMaker channel, file, or directory to read from, or
        the Python buffer object that holds the data.
    batch_size : int
        Number of rows per batch read from the source.

    Returns
    -------
    mlio.CsvReader
        CsvReader configured with a SageMaker Pipe, File or InMemory buffer.
        It reads every column as ``STRING`` and expects no header row. It
        allows newlines inside quoted fields and does not warn about bad
        instances.
    """
    dataset = _get_data(source)
    reader_params = mlio.DataReaderParams(dataset=dataset,
                                          batch_size=batch_size,
                                          warn_bad_instances=False)
    csv_options = mlio.CsvParams(header_row_index=None,
                                 default_data_type=mlio.DataType.STRING,
                                 allow_quoted_new_lines=True)
    return mlio.CsvReader(csv_params=csv_options,
                          data_reader_params=reader_params)
Ejemplo n.º 3
0
def test_csv_params():
    """Read ``test.csv`` with no header row and verify its first record."""
    csv_path = os.path.join(resources_dir, 'test.csv')
    reader_params = mlio.DataReaderParams(dataset=[mlio.File(csv_path)],
                                          batch_size=1)
    csv_params = mlio.CsvParams(header_row_index=None)
    first_reader = mlio.CsvReader(reader_params, csv_params)

    first = first_reader.read_example()
    values = np.array([as_numpy(column) for column in first]).squeeze()
    assert np.all(values == np.array([1, 0, 0, 0]))

    # A second reader built from the same parameter objects must also
    # have an example available.
    second_reader = mlio.CsvReader(reader_params, csv_params)
    assert second_reader.peek_example()
Ejemplo n.º 4
0
def test_csv_nonutf_encoding_with_encoding_param():
    """Read an ISO-8859-5 encoded CSV file using ``CsvParams(encoding=...)``.

    The file ``test_iso8859_5.csv`` is read with ``encoding='ISO-8859-5'``.
    Converting its ``col_3`` feature to a NumPy array must not raise
    ``SystemError``.
    """
    filename = os.path.join(resources_dir, 'test_iso8859_5.csv')
    dataset = [mlio.File(filename)]
    rdr_prm = mlio.DataReaderParams(dataset=dataset,
                                    batch_size=2)
    csv_params = mlio.CsvParams(encoding='ISO-8859-5')

    reader = mlio.CsvReader(rdr_prm, csv_params)
    example = reader.read_example()
    nonutf_feature = example['col_3']

    try:
        as_numpy(nonutf_feature)
    except SystemError as err:
        # Include the underlying error so a failure is diagnosable from the
        # test report instead of only saying "something was thrown".
        pytest.fail("Unexpected exception thrown: {}".format(err))
def _get_csv_dmatrix_pipe_mode(pipe_path, csv_weights):
    """Get Data Matrix from CSV data in pipe mode.

    The first CSV column is used as the label. When ``csv_weights == 1``,
    the second column is the instance weight and the remaining columns are
    features. Otherwise every column after the first is a feature.

    :param pipe_path: SageMaker pipe path (or list of pipe paths) where CSV
        formatted training data is piped
    :param csv_weights: 1 if instance weights are in second column of CSV
        data; else 0
    :return: xgb.DMatrix, or None if the reader yields no example
    :raises exc.UserError: if reading the data or building the DMatrix
        fails; the original exception is chained as ``__cause__``
    """
    try:
        pipes_path = pipe_path if isinstance(pipe_path, list) else [pipe_path]
        dataset = [mlio.SageMakerPipe(path) for path in pipes_path]
        reader_params = mlio.DataReaderParams(dataset=dataset,
                                              batch_size=BATCH_SIZE)
        csv_params = mlio.CsvParams(header_row_index=None)
        reader = mlio.CsvReader(reader_params, csv_params)

        # Check if data is present in reader
        if reader.peek_example() is None:
            return None

        examples = []
        for example in reader:
            # Write each feature (column) of example into a single numpy
            # array; squeeze() drops singleton dimensions of each feature.
            tmp = np.array([as_numpy(feature).squeeze()
                            for feature in example])
            if tmp.ndim > 1:
                # Columns are written as rows, needs to be transposed
                tmp = tmp.T
            else:
                # If tmp is a 1-D array (one value per column), it needs to
                # be reshaped as a single-row matrix
                tmp = np.reshape(tmp, (1, tmp.shape[0]))
            examples.append(tmp)

        data = np.vstack(examples)
        del examples  # release the per-batch arrays before building DMatrix

        if csv_weights == 1:
            return xgb.DMatrix(data[:, 2:],
                               label=data[:, 0],
                               weight=data[:, 1])
        return xgb.DMatrix(data[:, 1:], label=data[:, 0])

    except Exception as e:
        # Chain the cause so the original traceback is preserved for
        # debugging, not only its message.
        raise exc.UserError(
            "Failed to load csv data with exception:\n{}".format(e)) from e
Ejemplo n.º 6
0
def test_csv_params_members():
    """Check the default values of CsvParams and how its members can be set."""
    params = mlio.CsvParams()

    # Defaults compared by equality.
    expected_equal = {
        'column_names': [],
        'name_prefix': '',
        'use_columns': set(),
        'use_columns_by_index': set(),
        'column_types': {},
        'column_types_by_index': {},
        'header_row_index': 0,
        'delimiter': ',',
        'quote_char': '"',
        'max_field_length_handling': mlio.MaxFieldLengthHandling.ERROR,
    }
    # Defaults that must be these exact singleton objects (None/True/False).
    expected_identical = {
        'default_data_type': None,
        'has_single_header': False,
        'comment_char': None,
        'allow_quoted_new_lines': False,
        'skip_blank_lines': True,
        'encoding': None,
        'max_field_length': None,
        'max_line_length': None,
    }

    for attr, value in expected_equal.items():
        assert getattr(params, attr) == value, attr
    for attr, value in expected_identical.items():
        assert getattr(params, attr) is value, attr

    assert params.parser_params.nan_values == set()
    assert params.parser_params.number_base == 10

    params.header_row_index = None
    assert params.header_row_index is None

    params.parser_params.number_base = 2
    assert params.parser_params.number_base == 2

    # Due to a shortcoming in pybind11, values cannot be added to container
    # types in place; updates must instead be made via assignment.
    params.column_types['foo'] = mlio.DataType.STRING  # Doesn't work
    assert params.column_types == {}

    params.column_types = {'foo': mlio.DataType.STRING}  # OK
    assert params.column_types == {'foo': mlio.DataType.STRING}