Exemplo n.º 1
0
 def testGenericExecutor(self):
   """Runs `GenericExecutor.Do` with a module file, then checks the exports."""
   # Point the run at the module file stored on the test instance.
   self._exec_properties['module_file'] = self._module_file

   generic_executor = executor.GenericExecutor()
   generic_executor.Do(
       input_dict=self._input_dict,
       output_dict=self._output_dict,
       exec_properties=self._exec_properties)

   self._verify_model_exports()
Exemplo n.º 2
0
  def setUp(self):
    """Builds the fixtures stored on the test instance.

    Sets the source/output data directories, the input and output artifact
    dicts, a skeleton of exec properties, the module file path, the dotted
    path of the trainer function, and one instance of each executor.
    """
    super(ExecutorTest, self).setUp()

    # Source data sits in a `testdata` directory next to this file's parent.
    package_dir = os.path.dirname(os.path.dirname(__file__))
    self._source_data_dir = os.path.join(package_dir, 'testdata')

    # Outputs go under a per-test-method directory; the temp dir is used when
    # TEST_UNDECLARED_OUTPUTS_DIR is not set in the environment.
    fallback_dir = self.get_temp_dir()
    output_root = os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', fallback_dir)
    self._output_data_dir = os.path.join(output_root, self._testMethodName)

    # Input artifacts, each pointing into the source data directory.
    source_dir = self._source_data_dir

    examples_artifact = standard_artifacts.Examples()
    examples_artifact.uri = os.path.join(
        source_dir, 'transform/transformed_examples')
    examples_artifact.split_names = artifact_utils.encode_split_names(
        ['train', 'eval'])

    transform_artifact = standard_artifacts.TransformGraph()
    transform_artifact.uri = os.path.join(
        source_dir, 'transform/transform_output')

    schema_artifact = standard_artifacts.Schema()
    schema_artifact.uri = os.path.join(source_dir, 'schema_gen')

    base_model_artifact = standard_artifacts.Model()
    base_model_artifact.uri = os.path.join(source_dir, 'trainer/previous')

    self._input_dict = {
        executor.EXAMPLES_KEY: [examples_artifact],
        executor.TRANSFORM_GRAPH_KEY: [transform_artifact],
        executor.SCHEMA_KEY: [schema_artifact],
        executor.BASE_MODEL_KEY: [base_model_artifact]
    }

    # Output artifact: the exported model, kept on self for later checks.
    self._model_exports = standard_artifacts.Model()
    self._model_exports.uri = os.path.join(
        self._output_data_dir, 'model_export_path')
    self._output_dict = {executor.OUTPUT_MODEL_KEY: [self._model_exports]}

    # Exec properties skeleton; train/eval args are passed as JSON strings.
    train_args_json = json_format.MessageToJson(
        trainer_pb2.TrainArgs(num_steps=1000),
        preserving_proto_field_name=True)
    eval_args_json = json_format.MessageToJson(
        trainer_pb2.EvalArgs(num_steps=500),
        preserving_proto_field_name=True)
    self._exec_properties = {
        'train_args': train_args_json,
        'eval_args': eval_args_json,
        'warm_starting': False,
    }

    self._module_file = os.path.join(
        source_dir, 'module_file', 'trainer_module.py')

    # Dotted "<module>.<function>" path of the trainer function.
    trainer_fn = trainer_module.trainer_fn
    self._trainer_fn = '%s.%s' % (trainer_fn.__module__, trainer_fn.__name__)

    # Executors for test.
    self._trainer_executor = executor.Executor()
    self._generic_trainer_executor = executor.GenericExecutor()
Exemplo n.º 3
0
    def setUp(self):
        """Builds the fixtures stored on the test instance.

        Sets the source/output data directories, single- and multi-artifact
        example lists, the input and output artifact dicts, a skeleton of
        exec properties, the module file path, the dotted path of the
        trainer function, and one instance of each executor.
        """
        super(ExecutorTest, self).setUp()

        # Source data sits in a `testdata` directory next to this file's
        # parent directory.
        package_dir = os.path.dirname(os.path.dirname(__file__))
        self._source_data_dir = os.path.join(package_dir, 'testdata')

        # Outputs go under a per-test-method directory; the temp dir is used
        # when TEST_UNDECLARED_OUTPUTS_DIR is not set in the environment.
        fallback_dir = self.get_temp_dir()
        output_root = os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR',
                                     fallback_dir)
        self._output_data_dir = os.path.join(output_root,
                                             self._testMethodName)

        source_dir = self._source_data_dir

        # Examples artifact plus an independent deep copy of it, so tests can
        # pick either a one-element or a two-element examples list.
        primary_examples = standard_artifacts.Examples()
        primary_examples.uri = os.path.join(
            source_dir, 'transform/transformed_examples')
        primary_examples.split_names = artifact_utils.encode_split_names(
            ['train', 'eval'])
        examples_copy = copy.deepcopy(primary_examples)

        self._single_artifact = [primary_examples]
        self._multiple_artifacts = [primary_examples, examples_copy]

        # Remaining input artifacts, each pointing into the source data dir.
        transform_artifact = standard_artifacts.TransformGraph()
        transform_artifact.uri = os.path.join(
            source_dir, 'transform/transform_graph')

        schema_artifact = standard_artifacts.Schema()
        schema_artifact.uri = os.path.join(source_dir, 'schema_gen')

        base_model_artifact = standard_artifacts.Model()
        base_model_artifact.uri = os.path.join(source_dir, 'trainer/previous')

        # The input dict uses the single-artifact examples list by default.
        self._input_dict = {
            standard_component_specs.EXAMPLES_KEY: self._single_artifact,
            standard_component_specs.TRANSFORM_GRAPH_KEY: [transform_artifact],
            standard_component_specs.SCHEMA_KEY: [schema_artifact],
            standard_component_specs.BASE_MODEL_KEY: [base_model_artifact]
        }

        # Output artifacts: exported model and model run, kept on self.
        self._model_exports = standard_artifacts.Model()
        self._model_exports.uri = os.path.join(
            self._output_data_dir, 'model_export_path')

        self._model_run_exports = standard_artifacts.ModelRun()
        self._model_run_exports.uri = os.path.join(
            self._output_data_dir, 'model_run_path')

        self._output_dict = {
            standard_component_specs.MODEL_KEY: [self._model_exports],
            standard_component_specs.MODEL_RUN_KEY: [self._model_run_exports]
        }

        # Exec properties skeleton; train/eval args are passed as JSON.
        train_args_json = proto_utils.proto_to_json(
            trainer_pb2.TrainArgs(num_steps=1000))
        eval_args_json = proto_utils.proto_to_json(
            trainer_pb2.EvalArgs(num_steps=500))
        self._exec_properties = {
            standard_component_specs.TRAIN_ARGS_KEY: train_args_json,
            standard_component_specs.EVAL_ARGS_KEY: eval_args_json,
            'warm_starting': False,
        }

        # NOTE(review): the MODULE_FILE_KEY constant is used here as a
        # directory name under the source data dir -- confirm this is
        # intended rather than a literal directory name.
        self._module_file = os.path.join(
            source_dir, standard_component_specs.MODULE_FILE_KEY,
            'trainer_module.py')

        # Dotted "<module>.<function>" path of the trainer function.
        trainer_fn = trainer_module.trainer_fn
        self._trainer_fn = '%s.%s' % (trainer_fn.__module__,
                                      trainer_fn.__name__)

        # Executors for test.
        self._trainer_executor = executor.Executor()
        self._generic_trainer_executor = executor.GenericExecutor()