示例#1
0
 def test_merge_zero_subregions(self):
     """A bare country record (no subregion columns) merges to its country key."""
     metadata_table = TEST_AUX_DATA.copy()
     source = DataSource()
     merged_key = source.merge(
         {"country_code": "AA"},
         {"metadata": metadata_table},
         keys=TEST_METADATA_KEYS,
     )
     self.assertEqual(merged_key, "AA")
示例#2
0
 def test_merge_by_key(self):
     """A record that already carries a "key" merges to that same key."""
     metadata_table = TEST_AUX_DATA.copy()
     source = DataSource()
     row = {"key": "AE_1_2"}
     merged_key = source.merge(
         row,
         {"metadata": metadata_table},
         keys=TEST_METADATA_KEYS,
     )
     self.assertEqual(merged_key, row["key"])
示例#3
0
 def test_merge_no_match(self):
     """A country code absent from the metadata yields no merge key."""
     aux = TEST_AUX_DATA.copy()
     data_source = DataSource()
     record = {"country_code": "__"}
     key = data_source.merge(record, {"metadata": aux},
                             keys=TEST_METADATA_KEYS)
     # assertIsNone reports the unexpected value on failure, unlike
     # assertTrue(key is None) which only says "False is not true".
     self.assertIsNone(key)
示例#4
0
    def test_merge_one_subregion(self):
        """Country "AB" needs an explicit subregion1_code to merge uniquely.

        Without the column no key is expected; an explicit None selects the
        country-level row and "1" selects the first-level subregion.
        """
        aux = TEST_AUX_DATA.copy()
        pipeline = DataSource()

        # No subregion1_code column: presumably several AB rows match, so
        # the merge must not pick one.
        record = {"country_code": "AB"}
        key = pipeline.merge(record, {"metadata": aux})
        self.assertIsNone(key)

        # Null subregion1_code selects the country-level row.
        record = {"country_code": "AB", "subregion1_code": None}
        key = pipeline.merge(record, {"metadata": aux})
        self.assertEqual(key, "AB")

        # Explicit subregion1_code selects the subregion row.
        record = {"country_code": "AB", "subregion1_code": "1"}
        key = pipeline.merge(record, {"metadata": aux})
        self.assertEqual(key, "AB_1")
示例#5
0
    def test_merge_one_subregion(self):
        """Country "AB" needs an explicit subregion1_code to merge uniquely.

        Without the column no key is expected; an explicit None selects the
        country-level row and "1" selects the first-level subregion.
        """
        aux = TEST_AUX_DATA.copy()
        data_source = DataSource()

        # No subregion1_code column: presumably several AB rows match, so
        # the merge must not pick one.
        record = {"country_code": "AB"}
        key = data_source.merge(record, {"metadata": aux},
                                keys=TEST_METADATA_KEYS)
        self.assertIsNone(key)

        # Null subregion1_code selects the country-level row.
        record = {"country_code": "AB", "subregion1_code": None}
        key = data_source.merge(record, {"metadata": aux},
                                keys=TEST_METADATA_KEYS)
        self.assertEqual(key, "AB")

        # Explicit subregion1_code selects the subregion row.
        record = {"country_code": "AB", "subregion1_code": "1"}
        key = data_source.merge(record, {"metadata": aux},
                                keys=TEST_METADATA_KEYS)
        self.assertEqual(key, "AB_1")
示例#6
0
    def test_merge_null_vs_empty(self):
        """A None subregion code and an empty-string one mean different things.

        None must match only rows whose code is null, whereas "" means
        "do not compare this column".
        """
        aux = TEST_AUX_DATA.copy()
        data_source = DataSource()

        def merge(record):
            # Same call shape for every case below; a fresh wrapper dict is
            # built per call, exactly as the inline calls would.
            return data_source.merge(record, {"metadata": aux},
                                     keys=TEST_METADATA_KEYS)

        # Only one record has a null subregion1_code
        record = {"country_code": "AD", "subregion1_code": None}
        self.assertEqual(merge(record), "AD")

        # Empty means "do not compare" rather than "filter non-null"
        record = {"country_code": "AD", "subregion1_code": ""}
        self.assertIsNone(merge(record))

        # There are multiple records that fit this merge, so it's ambiguous
        record = {"country_code": "AD", "subregion1_code": "1"}
        self.assertIsNone(merge(record))

        # Match fails: the row with subregion2_code "1" has a non-null
        # subregion1_code, which conflicts with the requested null value
        record = {
            "country_code": "AD",
            "subregion1_code": None,
            "subregion2_code": "1"
        }
        self.assertIsNone(merge(record))

        # Match is exact so the merge is unambiguous
        record = {
            "country_code": "AD",
            "subregion1_code": "1",
            "subregion2_code": "1"
        }
        self.assertEqual(merge(record), "AD_1_1")

        # Even though we don't have subregion1_code, there's only one record that matches
        record = {
            "country_code": "AD",
            "subregion1_code": "",
            "subregion2_code": "1"
        }
        self.assertEqual(merge(record), "AD_1_1")
示例#7
0
    model_dir = parsed_args.job_dir

# Log this prediction job under the "TEST_PREDICT" tag in model_dir.
# NOTE(review): model_dir is assumed to be bound earlier in the script — confirm.
FileLogger("TEST_PREDICT", model_dir)

data_utils = DataUtils()
# The task configuration is the config.json shipped next to task_mnist.
conf_parser = TrainDevConfigParser(
    data_utils.get_file_contents(
        os.path.join(os.path.dirname(task_mnist.__file__), "config.json")
    )
)

tf_feature_cols = TfFeatureColumns(conf_parser)

# Prediction dataset, built with force_test=True.
# NOTE(review): pred_data is presumably the dataset class name chosen
# earlier in the script — confirm.
ds_test = DataSource(
    _tf_feature_cols=tf_feature_cols,
    _ds_class=pred_data,
    _dsutil=data_utils,
    force_test=True,
)
if ds_test.is_empty:
    # Fall back to the training data (still with force_test=True) so the
    # prediction run can proceed.
    ds_test = DataSource(
        _tf_feature_cols=tf_feature_cols,
        _ds_class="train",
        _dsutil=data_utils,
        force_test=True,
    )
    # `warning` is the canonical spelling; `warn` is only a legacy alias.
    tf.logging.warning(
        "prediction dataset is empty; will use training dataset for prediction instead!"
    )

_handle = Tester(
    model_fn=resolve_model_type(conf_parser.model_type),
示例#8
0
 def test_merge_zero_subregions(self):
     """A bare country record (no subregion columns) merges to its country key."""
     metadata_table = TEST_AUX_DATA.copy()
     source = DataSource()
     merged_key = source.merge({"country_code": "AA"},
                               {"metadata": metadata_table})
     self.assertEqual(merged_key, "AA")
示例#9
0
 def test_merge_by_key(self):
     """A record that already carries a "key" merges to that same key."""
     metadata_table = TEST_AUX_DATA.copy()
     source = DataSource()
     row = {"key": "AE_1_2"}
     merged_key = source.merge(row, {"metadata": metadata_table})
     self.assertEqual(merged_key, row["key"])
示例#10
0
 def test_merge_no_match(self):
     """A country code absent from the metadata yields no merge key."""
     aux = TEST_AUX_DATA.copy()
     pipeline = DataSource()
     record = {"country_code": "__"}
     key = pipeline.merge(record, {"metadata": aux})
     # assertIsNone shows the unexpected value on failure.
     self.assertIsNone(key)
示例#11
0
# Use the job directory as the model directory when one was given.
# NOTE(review): when job_dir is empty, model_dir is assumed to be bound
# earlier in the script — confirm.
if parsed_args.job_dir:
    model_dir = parsed_args.job_dir

# Log this training/evaluation job under the "TRAIN_EVAL" tag in model_dir.
FileLogger("TRAIN_EVAL", model_dir)

data_utils = DataUtils()
# The task configuration is the config.json shipped next to task_inertial_har.
conf_parser = TrainDevConfigParser(
    data_utils.get_file_contents(
        os.path.join(os.path.dirname(task_inertial_har.__file__), "config.json")
    )
)

# Feature-column definitions derived from the parsed configuration.
tf_feature_cols = TfFeatureColumns(conf_parser)

ds_train = DataSource(
    _tf_feature_cols=tf_feature_cols, _ds_class="train", _dsutil=data_utils
)
# The training dataset must not be empty: abort before building anything else.
if ds_train.is_empty:
    raise RuntimeError("training dataset can not be empty!")

# The evaluation dataset may be empty; the check that follows falls back to
# the training data (force_test=True), and a warning is intended to be given.
ds_eval = DataSource(
    _tf_feature_cols=tf_feature_cols, _ds_class="eval", _dsutil=data_utils
)
if ds_eval.is_empty:
    ds_eval = DataSource(
        _tf_feature_cols=tf_feature_cols,
        _ds_class="train",
        _dsutil=data_utils,
        force_test=True,
示例#12
0
文件: trainer.py 项目: tmaone/s3vdc
    def __init__(
        self,
        model_fn: Callable[[
            dict,
            dict,
            tf.contrib.training.HParams,
            str,
            list,
            list,
            list,
            tf.estimator.RunConfig,
        ], Tuple[dict, tf.Tensor, tf.Operation, dict, list, list], ],
        train_data_source: DataSource,
        eval_data_source: DataSource,
        hyper_params: Union[tf.contrib.training.HParams, dict] = None,
        hooks: list = None,
        model_dir: str = "model",
        date_time_str: str = None,
        export_models_as_text: bool = False,
    ) -> None:
        """Initialize a Trainer

        Merges the given hyper parameters over the general defaults and
        builds the run config. It then builds the estimator (via
        ``model_fn``), the train spec, and an eval spec carrying a final and
        a best SavedModel exporter.

        Arguments:
            model_fn {Callable[[dict,dict,tf.contrib.training.HParams,str,list,list,list,tf.estimator.RunConfig,],Tuple[dict, tf.Tensor, tf.Operation, dict, list, list],]} -- A model function.
                NOTE(review): it is called below as
                ``model_fn(only_feature_columns, run_config, hyper_params, tf_feature_cols)``
                and its result is used as the estimator, which does not match
                the annotated signature — confirm which is intended.
            train_data_source {DataSource} -- The DataSource object of training dataset.
            eval_data_source {DataSource} -- The DataSource object of evaluation dataset.

        Keyword Arguments:
            hyper_params {Union[tf.contrib.training.HParams, dict]} -- Hyper parameters, as a dict or an HParams object; merged over the general defaults. (default: {None})
            hooks {list} -- A list of custom hooks. (default: {None})
            model_dir {str} -- The model directory. (default: {"model"})
            date_time_str {str} -- A date and time string in the format of "yyyymmdd_hhmmss"; the current local time is used when omitted. (default: {None})
            export_models_as_text {bool} -- Export readable model if True. (default: {False})
        """

        # hyper_params: accept both a dict and an HParams object. An HParams
        # is not a mapping, so `HParams(**hparams_obj)` would raise TypeError;
        # copy its values() instead (this also avoids aliasing the caller's object).
        if hyper_params is None:
            hyper_params = {}
        elif isinstance(hyper_params, tf.contrib.training.HParams):
            hyper_params = hyper_params.values()
        hyper_params = tf.contrib.training.HParams(**hyper_params)
        self.hyper_params = join_hparams(
            default_general_hparams(),
            hyper_params)  # type: tf.contrib.training.HParams

        # run_config
        # TODO: make run_config configurable
        self.model_dir = model_dir
        self.run_config = default_run_config(
            model_dir=model_dir,
            save_summary_steps=self.hyper_params.get(
                "summaryFrequency", default=100),  # unit: steps
            save_checkpoints_mins=self.hyper_params.get(
                "checkpointFrequency", default=5),  # unit: minutes
            keep_checkpoint_max=self.hyper_params.get("keepMaxCheckpoint",
                                                      default=5),
        )

        # date and time string
        self.date_time_str = date_time_str

        # eval metrics output path: suffixed with the caller's timestamp, or
        # with the current local time when none was supplied
        if date_time_str is None:
            tf.logging.warning(
                "missing date and time info as input params for training job")
            date_time_suffix = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
        else:
            date_time_suffix = date_time_str
        self.eval_metric_output_path = os.path.join(
            model_dir, "eval_metric_{}.json".format(date_time_suffix))

        # TODO: extend the functions via different hooks
        self.hooks = hooks

        # set train and eval data sources
        self.train_data_source = train_data_source
        self.eval_data_source = eval_data_source

        # model_fn
        self.model_fn = model_fn

        # feature columns
        self.only_feature_columns = (
            train_data_source.tf_feature_cols.only_feature_columns)

        # trainspec; the lambda defers input pipeline construction until the
        # estimator calls it
        self.train_spec = tf.estimator.TrainSpec(
            input_fn=lambda: train_data_source.input_fn(
                batch_size=self.hyper_params.batchSize),
            max_steps=self.hyper_params.maxSteps,
        )

        # Serving input function: parse tf.Example protos for the export columns
        feature_spec = tf.feature_column.make_parse_example_spec(
            self.train_data_source.tf_feature_cols.export_feature_columns)
        serving_input_fn = tf.estimator.export.build_parsing_serving_input_receiver_fn(
            feature_spec)
        self.export_models_as_text = export_models_as_text
        best_exporter = tf.estimator.BestExporter(
            name="best_exporter",
            serving_input_receiver_fn=serving_input_fn,
            as_text=self.export_models_as_text,
        )
        final_exporter = tf.estimator.FinalExporter(
            name="final_exporter",
            serving_input_receiver_fn=serving_input_fn,
            as_text=self.export_models_as_text,
        )

        # estimator
        self.estimator = self.model_fn(
            self.only_feature_columns,
            self.run_config,
            self.hyper_params,
            train_data_source.tf_feature_cols,
        )

        # evalspec
        throttle_secs = self.hyper_params.get("minEvalFrequency",
                                              default=10)  # unit: minutes
        self.eval_spec = tf.estimator.EvalSpec(
            input_fn=eval_data_source.input_fn,
            steps=None,
            exporters=[final_exporter, best_exporter],
            start_delay_secs=0,
            throttle_secs=throttle_secs * 60,  # EvalSpec expects seconds
        )
示例#13
0
        def model_fn(
            features: dict,
            labels: dict,
            mode: str,
            params: tf.contrib.training.HParams,
            config: tf.estimator.RunConfig,
        ) -> tf.estimator.EstimatorSpec:
            """Model Function

            Adapter around ``user_model_fn``. It removes the ID feature
            columns from ``features`` and calls ``user_model_fn`` on the
            remaining features. The result is wrapped in an EstimatorSpec,
            with the ID values appended to the predictions. ``user_model_fn``,
            ``tf_feature_col_obj`` and ``only_feature_columns`` are free
            variables taken from the enclosing scope.

            Arguments:
                features {dict} -- A dict of features with feature names (str) as keys.
                    Note: ID columns are deleted from this dict in place.
                labels {dict} -- A dict of labels with feature names (str) as keys.
                mode {str} -- one of tf.estimator.ModeKeys.(TRAIN|EVAL|PREDICT).
                params {tf.contrib.training.HParams} -- hyper parameters.
                config {tf.estimator.RunConfig} -- run config

            Returns:
                tf.estimator.EstimatorSpec -- Estimator specification
            """

            tf.logging.info("model_fn is called with mode={}".format(
                str(mode)))

            # hold Ids: move every ID column out of `features` (mutating the
            # dict) so the user model never sees it, keeping a copy to echo
            # back in the predictions.
            id_holder = {}
            id_feature_columns = tf_feature_col_obj.id_feature_columns
            if id_feature_columns is not None:
                with tf.variable_scope("ID_holder"):
                    for id_fc in id_feature_columns:
                        _id_fc_key = TfFeatureColumns.feature_column_key(id_fc)
                        # presumably densifies a SparseTensor and passes
                        # dense tensors through — confirm in DataSource
                        id_holder[
                            _id_fc_key] = DataSource.try_sparse2dense_tensor(
                                features[_id_fc_key])
                        del features[_id_fc_key]

            # Call the actual model function!
            (
                predictions,
                loss,
                train_op,
                eval_metric_ops,
                training_hooks,
                evaluation_hooks,
            ) = user_model_fn(
                only_features=features,
                labels=labels,
                hparams=params,
                mode=mode,
                only_feature_columns=only_feature_columns,
                label_feature_columns=tf_feature_col_obj.label_feature_columns,
                no_export_columns=tf_feature_col_obj.no_export_feature_columns,
                config=config,
            )

            # wrap exports
            # NOTE(review): `exports` is built BEFORE the IDs are appended to
            # `predictions` below. Whether the IDs also appear in the exported
            # "predictions" signature depends on whether PredictOutput keeps a
            # reference to this dict or copies it — confirm for the TF
            # version in use before reordering these statements.
            exports = {
                "predictions": tf.estimator.export.PredictOutput(predictions)
            }

            # append Ids to prediction result (mutates the user's dict)
            for key in id_holder.keys():
                predictions[key] = id_holder[key]

            return tf.estimator.EstimatorSpec(
                mode=mode,
                predictions=predictions,
                loss=loss,
                train_op=train_op,
                eval_metric_ops=eval_metric_ops,
                export_outputs=exports,
                training_hooks=training_hooks,
                evaluation_hooks=evaluation_hooks,
            )