Example #1
0
    def test_batch_generator_fn(self):
        """Check the shapes of batches built by ``BareKerasUtil._batch_generator_fn``.

        A stub reader serves two rows per pass. Ten batches are pulled with
        shuffling enabled and then disabled, and the per-sample shapes of both
        inputs, of the output, and of the sample weights are asserted.
        """
        shuffle_buffer_size = 10
        rows_in_row_group = 100
        batch_size = 32

        def _create_numpy_array(n_rows, shape):
            # Row ``start`` holds the consecutive integers start .. start + shape - 1.
            return np.array([list(range(start, start + shape)) for start in range(n_rows)])

        def _build_row(row_type):
            # Every call creates fresh arrays, so the rows never share storage.
            narrow = _create_numpy_array(rows_in_row_group, 1)
            wide = _create_numpy_array(rows_in_row_group, 10)
            labels = _create_numpy_array(rows_in_row_group, 8)
            weights = np.array([idx / 100. for idx in range(rows_in_row_group)])
            return row_type(col1=narrow, col2=wide, label=labels, sample_weight=weights)

        class DummyReader():
            """Reader stub: one full pass over it yields two rows and then ends."""

            def __init__(self):
                self._in_iter = False

            def __iter__(self):
                # Refuse to start a new pass while a previous one is still open.
                if self._in_iter:
                    raise RuntimeError('Do not support resetting a dummy reader while in the middle of iteration.')

                self._in_iter = True
                Row = collections.namedtuple('row', ['col1', 'col2', 'sample_weight', 'label'])
                rows = (_build_row(Row), _build_row(Row))
                try:
                    for row in rows:
                        yield row
                finally:
                    # Reached on exhaustion and on early close alike, so the
                    # reader can be iterated again afterwards.
                    self._in_iter = False

        metadata = {
            'col1': {
                'dtype': float,
                'intermediate_format': constants.NOCHANGE,
                'max_size': 1,
                'shape': 1
            },
            'col2': {
                'dtype': DenseVector,
                'intermediate_format': constants.ARRAY,
                'max_size': 10,
                'shape': 10
            },
            'label': {
                'dtype': float,
                'intermediate_format': constants.NOCHANGE,
                'max_size': 1,
                'shape': 1
            },
        }

        reader = DummyReader()

        feature_columns = ['col1', 'col2']
        label_columns = ['label']
        sample_weight_col = 'sample_weight'

        # Target shapes handed to the generator factory; -1 is the leading dimension.
        input_shapes = [[-1, 1], [-1, 2, 5]]
        output_shapes = [[-1, 2, 4]]

        batch_generator = BareKerasUtil._batch_generator_fn(
            feature_columns, label_columns, sample_weight_col,
            input_shapes, output_shapes, metadata)

        for shuffle in (True, False):
            batch_gen = batch_generator(reader, batch_size, shuffle_buffer_size, shuffle=shuffle)

            for _ in range(10):
                batch = next(batch_gen)

                features = batch[0]
                assert features[0][0].shape == (1,)
                assert features[1][0].shape == (2, 5)

                targets = batch[1]
                assert targets[0][0].shape == (2, 4)

                # the sample weight has to be a single np array with shape (batch_size,)
                assert batch[2][0].shape == (batch_size,)
Example #2
0
    def test_batch_generator_fn(self):
        """Check the shapes of batches built by ``BareKerasUtil._batch_generator_fn``.

        The batch size is given when the generator function is created. An
        endless stub reader feeds it, and ten batches are inspected with
        shuffling enabled and then disabled.
        """
        shuffle_buffer_size = 10
        rows_in_row_group = 100
        batch_size = 32

        def _create_numpy_array(n_rows, shape):
            # Row ``start`` holds the integers start .. start + shape - 1.
            return np.array(
                [list(range(start, start + shape)) for start in range(n_rows)])

        def _build_row(row_type):
            # Every call creates fresh arrays, so rows never share storage.
            narrow = _create_numpy_array(rows_in_row_group, 1)
            wide = _create_numpy_array(rows_in_row_group, 10)
            labels = _create_numpy_array(rows_in_row_group, 8)
            weights = np.array(
                [idx / 100. for idx in range(rows_in_row_group)])
            return row_type(col1=narrow,
                            col2=wide,
                            label=labels,
                            sample_weight=weights)

        def dummy_reader():
            # Generator that alternates between two prebuilt rows forever.
            Row = collections.namedtuple(
                'row', ['col1', 'col2', 'sample_weight', 'label'])
            first_row = _build_row(Row)
            second_row = _build_row(Row)

            while True:
                yield first_row
                yield second_row

        metadata = {
            'col1': {
                'dtype': float,
                'intermediate_format': constants.NOCHANGE,
                'max_size': 1,
                'shape': 1
            },
            'col2': {
                'dtype': DenseVector,
                'intermediate_format': constants.ARRAY,
                'max_size': 10,
                'shape': 10
            },
            'label': {
                'dtype': float,
                'intermediate_format': constants.NOCHANGE,
                'max_size': 1,
                'shape': 1
            },
        }

        # One generator object, shared by both shuffle settings below.
        reader = dummy_reader()

        feature_columns = ['col1', 'col2']
        label_columns = ['label']
        sample_weight_col = 'sample_weight'

        # Target shapes handed to the factory; -1 is the leading dimension.
        input_shapes = [[-1, 1], [-1, 2, 5]]
        output_shapes = [[-1, 2, 4]]

        batch_generator = BareKerasUtil._batch_generator_fn(
            feature_columns, label_columns, sample_weight_col, input_shapes,
            output_shapes, batch_size, metadata)

        for shuffle in (True, False):
            batch_gen = batch_generator(reader,
                                        shuffle_buffer_size,
                                        shuffle=shuffle)

            for _ in range(10):
                batch = next(batch_gen)

                features = batch[0]
                assert features[0][0].shape == (1, )
                assert features[1][0].shape == (2, 5)

                targets = batch[1]
                assert targets[0][0].shape == (2, 4)

                # the sample weight has to be a single np array with shape
                # (batch_size,)
                assert batch[2][0].shape == (batch_size, )