Example #1
0
class TestDataset(unittest.TestCase):
    """Tests for DataModel column validation and feature/target selection."""

    def setUp(self):
        # Fixture: three two-row columns wrapped in a DataModel.
        self._data = {'col1': [1, 2], 'col2': [3, 4], 'col4': [5, 6]}
        self._dataframe = pd.DataFrame(data=self._data)
        self._dataset = DataModel(self._dataframe)

    def test_validate_columns_invalid(self):
        # A column name absent from the frame must be rejected.
        self.assertRaises(
            RuntimeError, self._dataset.validate_columns, ['col3'])

    def test_validate_columns(self):
        # A known column passes validation without raising.
        self._dataset.validate_columns(['col1'])

    def test_feature_columns(self):
        expected = ['col1', 'col2']
        self._dataset.set_feature_columns(expected)

        selected = self._dataset.get_feature_columns()

        self.assertEqual(list(selected.columns.values), expected)

    def test_target_column(self):
        name = 'col1'
        self._dataset.set_target_column(name)

        self.assertEqual(
            self._dataset.get_target_column().tolist(), self._data[name])
Example #2
0
    def assign_fn(self, data_model: DataModel, fn_name: str, kwargs: dict):
        """Resolve ``fn_name`` to an input function built from ``kwargs``.

        Holder-defined functions take precedence; otherwise only the
        pandas input fn is recognized.  NOTE: ``kwargs`` is mutated in
        place when the pandas path is taken, mirroring the original.

        :param data_model: DataModel supplying x/y data and the target name
        :param fn_name: name of the input function to build
        :param kwargs: keyword arguments forwarded to the resolved fn
        :return: the built input fn, or None for unrecognized names
        """
        if hasattr(self.fn_holder, fn_name):
            return self.load_from_holder(data_model, fn_name, kwargs)

        if fn_name != self.PANDAS_FN:
            return None  # same implicit None as the original fall-through

        # Fill in the arguments pandas_input_fn expects from the model.
        kwargs.update(
            x=data_model.get_input_fn_x_data(),
            y=data_model.get_target_column(),
            target_column=data_model.target_column_name,
        )

        pandas_fn = getattr(tf.estimator.inputs, 'pandas_input_fn')
        return pandas_fn(**kwargs)
def base_fn(data_model: Data.DataModel, batch_size=1, epoch=1):
    """Build a one-shot iterator over the model's features and target.

    Shuffles with a fixed buffer of 100, repeats for ``epoch`` passes and
    batches ``batch_size`` examples at a time (TF 1.x iterator API).

    :param data_model: Data.DataModel with feature and target columns set
    :param batch_size: int, examples per batch
    :param epoch: int, number of passes over the data
    :return: the iterator's next-element op
    """
    feature_dict = _dataset_to_dict(features=data_model.get_feature_columns())
    labels = data_model.get_target_column()

    dataset = (tf.data.Dataset
               .from_tensor_slices((feature_dict, labels))
               .shuffle(100)
               .repeat(epoch)
               .batch(batch_size))

    return dataset.make_one_shot_iterator().get_next()