    # Excerpt from an ElasticDL test suite; assumes tempfile, numpy as np,
    # and tensorflow as tf are imported, and that project helpers such as
    # create_iris_csv_file, LocalExecutor, LocalExecutorArgs, _model_zoo_path,
    # and _iris_dnn_def are defined at module scope.
    def test_train_model_local(self):
        with tempfile.TemporaryDirectory() as temp_dir_name:
            num_records = 1000
            columns = [
                "sepal_length",
                "sepal_width",
                "petal_length",
                "petal_width",
                "class",
            ]
            training_data = create_iris_csv_file(
                size=num_records, columns=columns, temp_dir=temp_dir_name
            )
            validation_data = create_iris_csv_file(
                size=num_records, columns=columns, temp_dir=temp_dir_name
            )

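            # Reader configuration is passed as a single string of
            # "key=value" pairs separated by semicolons.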
            data_reader_params = (
                'columns=["sepal_length", "sepal_width", "petal_length",'
                '"petal_width", "class"]; sep=","'
            )
            args = LocalExecutorArgs(
                num_epochs=1,
                minibatch_size=32,
                training_data=training_data,
                validation_data=validation_data,
                evaluation_steps=10,
                model_zoo=_model_zoo_path,
                model_def=_iris_dnn_def,
                dataset_fn="dataset_fn",
                loss="loss",
                optimizer="optimizer",
                eval_metrics_fn="eval_metrics_fn",
                model_params="",
                prediction_outputs_processor="PredictionOutputsProcessor",
                data_reader_params=data_reader_params,
                num_minibatches_per_task=5,
                callbacks="callbacks",
            )
            local_executor = LocalExecutor(args)
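            # Split the training and validation files into per-task
            # record ranges.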
            train_tasks = local_executor._gen_tasks(
                local_executor.training_data
            )
            validation_tasks = local_executor._gen_tasks(
                local_executor.validation_data
            )

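            # A single training step on one minibatch should return a
            # scalar float32 loss.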
            train_dataset = local_executor._get_dataset(train_tasks)
            features, labels = next(iter(train_dataset))
            loss = local_executor._train(features, labels)
            self.assertEqual(type(loss.numpy()), np.float32)

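            # Evaluation should report exactly the metrics declared by
            # eval_metrics_fn.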
            validation_dataset = local_executor._get_dataset(validation_tasks)
            metrics = local_executor._evaluate(validation_dataset)
            self.assertEqual(list(metrics.keys()), ["accuracy"])
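
The test above only refers to the model-zoo entry points by name (dataset_fn, loss, optimizer, eval_metrics_fn). The following is a minimal sketch of the shape such a module takes; the names and bodies are illustrative assumptions, not ElasticDL's actual _iris_dnn_def:

import tensorflow as tf


def custom_model():
    # A small dense classifier: 4 iris features in, 3 classes out.
    return tf.keras.Sequential(
        [
            tf.keras.layers.Dense(10, activation="relu", input_shape=(4,)),
            tf.keras.layers.Dense(3, activation="softmax"),
        ]
    )


def dataset_fn(dataset, mode, metadata):
    def _parse(record):
        # Records arrive as string fields; the last one is the label.
        features = tf.strings.to_number(record[0:-1], tf.float32)
        label = tf.strings.to_number(record[-1], tf.float32)
        return features, label

    return dataset.map(_parse)


def loss(labels, predictions):
    return tf.reduce_mean(
        tf.keras.losses.sparse_categorical_crossentropy(
            tf.cast(labels, tf.int32), predictions
        )
    )


def optimizer(lr=0.1):
    return tf.keras.optimizers.SGD(lr)


def eval_metrics_fn():
    # The dict keys become the metric names, which is why the test
    # expects list(metrics.keys()) == ["accuracy"].
    return {
        "accuracy": lambda labels, predictions: tf.equal(
            tf.argmax(predictions, axis=1, output_type=tf.int32),
            tf.cast(tf.reshape(labels, [-1]), tf.int32),
        )
    }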
Example #2
    def test_csv_data_reader(self):
        with tempfile.TemporaryDirectory() as temp_dir_name:
            num_records = 128
            columns = [
                "sepal_length",
                "sepal_width",
                "petal_length",
                "petal_width",
                "class",
            ]
            iris_file_name = create_iris_csv_file(
                size=num_records, columns=columns, temp_dir=temp_dir_name
            )
            csv_data_reader = CSVDataReader(columns=columns, sep=",")
            task = _MockedTask(
                0, num_records, iris_file_name, elasticdl_pb2.TRAINING
            )

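            # read_records yields each record in the task's range as a
            # list of string fields.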
            def _gen():
                for record in csv_data_reader.read_records(task):
                    yield record

            def _dataset_fn(dataset, mode, metadata):
                def _parse_data(record):
                    features = tf.strings.to_number(record[0:-1], tf.float32)
                    label = tf.strings.to_number(record[-1], tf.float32)
                    return features, label

                dataset = dataset.map(_parse_data)
                dataset = dataset.batch(10)
                return dataset

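            # records_output_types declares the generator's output dtype
            # (strings here) for tf.data.Dataset.from_generator.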
            dataset = tf.data.Dataset.from_generator(
                _gen, csv_data_reader.records_output_types
            )
            dataset = _dataset_fn(dataset, None, None)
            for features, labels in dataset:
                self.assertEqual(features.shape.as_list(), [10, 4])
                self.assertEqual(labels.shape.as_list(), [10])
                break
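
All three examples depend on a create_iris_csv_file test helper. A minimal sketch, under the assumption that it writes size rows of random iris-like data with no header row (so the record count equals size) and returns the file path; the real helper may differ:

import os

import numpy as np
import pandas as pd


def create_iris_csv_file(size, columns, temp_dir):
    # Random float features plus a class label in {0, 1, 2}.
    features = np.random.uniform(size=(size, len(columns) - 1))
    labels = np.random.randint(0, 3, size=(size, 1))
    df = pd.DataFrame(np.hstack([features, labels]), columns=columns)
    csv_path = os.path.join(temp_dir, "iris.csv")
    # No header row, so readers see exactly `size` records.
    df.to_csv(csv_path, index=False, header=False)
    return csv_path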
Example #3
    def test_csv_data_reader(self):
        with tempfile.TemporaryDirectory() as temp_dir_name:
            num_records = 128
            columns = [
                "sepal_length",
                "sepal_width",
                "petal_length",
                "petal_width",
                "class",
            ]
            iris_file_name = create_iris_csv_file(
                size=num_records, columns=columns, temp_dir=temp_dir_name
            )
            # TextDataReader splits a plain text file into fixed-size
            # record ranges rather than parsing columns itself.
            csv_data_reader = TextDataReader(
                filename=iris_file_name, records_per_task=20
            )
            shards = csv_data_reader.create_shards()
            # 128 records at 20 per task -> ceil(128 / 20) == 7 shards.
            self.assertEqual(len(shards), 7)
            # Read back the first shard's [0, 20) record range.
            task = _Task(iris_file_name, 0, 20, elasticai_api_pb2.TRAINING)
            record_count = 0
            for _ in csv_data_reader.read_records(task):
                record_count += 1
            self.assertEqual(csv_data_reader.get_size(), num_records)
            self.assertEqual(record_count, 20)
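
The _MockedTask and _Task objects above are small test doubles standing in for the master's task records. A plausible sketch, assuming a task only carries a shard name (file path), a [start, end) record range, and a task type; note that the two examples pass these fields in different orders, and this sketch follows example #3:

import collections

# Hypothetical test double; the real helpers in the ElasticDL tests may
# carry extra fields.
_Task = collections.namedtuple("_Task", ["shard_name", "start", "end", "type"])

# As used in example #3 to read back records [0, 20) of the first shard:
# task = _Task(iris_file_name, 0, 20, elasticai_api_pb2.TRAINING)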