def testStartCMLETrainingWithUserContainer(self, mock_discovery):
  self._training_inputs['masterConfig'] = {'imageUri': 'my-custom-image'}
  mock_discovery.build.return_value = self._mock_api_client
  mock_create = mock.Mock()
  self._mock_api_client.projects().jobs().create = mock_create
  mock_get = mock.Mock()
  self._mock_api_client.projects().jobs().get.return_value = mock_get
  mock_get.execute.return_value = {
      'state': 'SUCCEEDED',
  }

  class_path = 'foo.bar.class'

  runner.start_cmle_training(self._inputs, self._outputs,
                             self._exec_properties, class_path,
                             self._training_inputs)

  mock_create.assert_called_with(
      body=mock.ANY, parent='projects/{}'.format(self._project_id))
  (_, kwargs) = mock_create.call_args
  body = kwargs['body']
  self.assertDictContainsSubset(
      {
          'masterConfig': {
              'imageUri': 'my-custom-image',
          },
          'args': [
              '--executor_class_path', class_path, '--inputs', '{}',
              '--outputs', '{}', '--exec-properties', '{"custom_config": {}}'
          ],
      }, body['trainingInput'])
  self.assertStartsWith(body['jobId'], 'tfx_')
  mock_get.execute.assert_called_with()
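
# Hedged note: in the custom-container case above, 'masterConfig.imageUri'
# points the job at a user-provided image, whereas testStartCMLETraining
# below expects an ephemeral TFX package to be built and uploaded instead.
# A caller-side sketch of such training inputs (all values hypothetical):
#
# training_inputs = {
#     'project': 'my-gcp-project',    # hypothetical project id
#     'region': 'us-central1',        # hypothetical region
#     'masterConfig': {'imageUri': 'gcr.io/my-gcp-project/my-image'},
# }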

def testStartCMLETraining(self, mock_discovery, mock_dependency_utils):
  mock_discovery.build.return_value = self._mock_api_client
  mock_create = mock.Mock()
  self._mock_api_client.projects().jobs().create = mock_create
  mock_get = mock.Mock()
  self._mock_api_client.projects().jobs().get.return_value = mock_get
  mock_get.execute.return_value = {
      'state': 'SUCCEEDED',
  }
  mock_dependency_utils.build_ephemeral_package.return_value = (
      self._fake_package)

  class_path = 'foo.bar.class'

  runner.start_cmle_training(self._inputs, self._outputs,
                             self._exec_properties, class_path,
                             self._training_inputs)

  mock_dependency_utils.build_ephemeral_package.assert_called_with()
  mock_create.assert_called_with(
      body=mock.ANY, parent='projects/{}'.format(self._project_id))
  (_, kwargs) = mock_create.call_args
  body = kwargs['body']
  self.assertDictContainsSubset(
      {
          'pythonVersion': runner._get_caip_python_version(),
          'runtimeVersion': '.'.join(tf.__version__.split('.')[0:2]),
          'jobDir': self._job_dir,
          'args': [
              '--executor_class_path', class_path, '--inputs', '{}',
              '--outputs', '{}', '--exec-properties', '{"custom_config": {}}'
          ],
          'pythonModule': 'tfx.scripts.run_executor',
          'packageUris': [os.path.join(self._job_dir, 'fake_package')],
      }, body['trainingInput'])
  self.assertStartsWith(body['jobId'], 'tfx_')
  mock_get.execute.assert_called_with()
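
# For reference, a sketch of the job request body both tests assert against.
# This only summarizes the assertions above; it is illustrative, not the
# exact payload runner.start_cmle_training builds:
#
# {
#     'jobId': 'tfx_<suffix>',  # only asserted to start with 'tfx_'
#     'trainingInput': {
#         'pythonModule': 'tfx.scripts.run_executor',
#         'pythonVersion': runner._get_caip_python_version(),
#         'runtimeVersion': '<tf major>.<tf minor>',
#         'jobDir': <job_dir>,
#         'packageUris': ['<job_dir>/fake_package'],  # ephemeral package
#         'args': ['--executor_class_path', '<class path>', '--inputs', '{}',
#                  '--outputs', '{}', '--exec-properties',
#                  '{"custom_config": {}}'],
#     },
# }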

def Do(self, input_dict: Dict[Text, List[types.Artifact]],
       output_dict: Dict[Text, List[types.Artifact]],
       exec_properties: Dict[Text, Any]):
  """Starts a trainer job on Google Cloud AI Platform.

  Args:
    input_dict: Passthrough input dict for tfx.components.Trainer.executor.
    output_dict: Passthrough output dict for tfx.components.Trainer.executor.
    exec_properties: Mostly a passthrough input dict for
      tfx.components.Trainer.executor. custom_config.ai_platform_training_args
      is consumed by this class. For the full set of parameters supported by
      Google Cloud AI Platform, refer to
      https://cloud.google.com/ml-engine/docs/tensorflow/training-jobs#configuring_the_job

  Returns:
    None

  Raises:
    ValueError: If ai_platform_training_args is not in
      exec_properties.custom_config.
    RuntimeError: If the Google Cloud AI Platform training job failed.
  """
  self._log_startup(input_dict, output_dict, exec_properties)

  if not exec_properties.get('custom_config',
                             {}).get('ai_platform_training_args'):
    err_msg = '\'ai_platform_training_args\' not found in custom_config.'
    tf.logging.error(err_msg)
    raise ValueError(err_msg)

  training_inputs = exec_properties.get('custom_config',
                                        {}).pop('ai_platform_training_args')

  executor_class_path = '%s.%s' % (tfx_trainer_executor.Executor.__module__,
                                   tfx_trainer_executor.Executor.__name__)
  return runner.start_cmle_training(input_dict, output_dict, exec_properties,
                                    executor_class_path, training_inputs)
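
# Usage sketch (hypothetical values): a pipeline opts into this executor by
# placing 'ai_platform_training_args' under custom_config; the nested keys
# follow the Cloud AI Platform TrainingInput schema linked in the docstring.
#
# exec_properties = {
#     'custom_config': {
#         'ai_platform_training_args': {
#             'project': 'my-gcp-project',  # hypothetical
#             'region': 'us-central1',      # hypothetical
#             'scaleTier': 'BASIC',         # hypothetical
#         },
#     },
#     # ... remaining Trainer exec_properties pass through unchanged.
# }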

def Do(self, input_dict: Dict[Text, List[types.Artifact]],
       output_dict: Dict[Text, List[types.Artifact]],
       exec_properties: Dict[Text, Any]) -> None:
  """Uses a user-supplied tf.estimator to train a TensorFlow model locally.

  The Trainer Executor invokes a trainer_fn callback function provided by the
  user via the module_file parameter. From the tf.estimator returned by this
  function, the Trainer Executor then builds and trains the TensorFlow model.

  Args:
    input_dict: Input dict from input key to a list of ML-Metadata Artifacts.
      - examples: Examples used for training, must include 'train' and 'eval'
        splits.
      - transform_output: Optional input transform graph.
      - schema: Schema of the data.
    output_dict: Output dict from output key to a list of Artifacts.
      - output: Exported model.
    exec_properties: A dict of execution properties.
      - train_args: JSON string of trainer_pb2.TrainArgs instance, providing
        args for training.
      - eval_args: JSON string of trainer_pb2.EvalArgs instance, providing
        args for eval.
      - module_file: Python module file containing UDF model definition.
      - warm_starting: Whether or not we need to do warm starting.
      - warm_start_from: Optional. If warm_starting is True, this is the
        directory to find the previous model to warm start on.

  Returns:
    None

  Raises:
    ValueError: When neither or both of 'module_file' and 'trainer_fn' are
      present in 'exec_properties'.
  """
  self._log_startup(input_dict, output_dict, exec_properties)

  # TODO(zhitaoli): Deprecate this in a future version.
  if exec_properties.get('custom_config', None):
    cmle_args = exec_properties.get('custom_config',
                                    {}).get('cmle_training_args')
    if cmle_args:
      executor_class_path = '.'.join([Executor.__module__, Executor.__name__])
      absl.logging.warn(
          'Passing \'cmle_training_args\' to trainer directly is deprecated, '
          'please use extension executor at '
          'tfx.extensions.google_cloud_ai_platform.trainer.executor instead')
      return runner.start_cmle_training(input_dict, output_dict,
                                        exec_properties, executor_class_path,
                                        cmle_args)

  trainer_fn = self._GetTrainerFn(exec_properties)

  # Set up training parameters.
  train_files = [
      _all_files_pattern(
          artifact_utils.get_split_uri(input_dict['examples'], 'train'))
  ]
  transform_output = artifact_utils.get_single_uri(
      input_dict['transform_output']) if input_dict.get(
          'transform_output', None) else None
  eval_files = [
      _all_files_pattern(
          artifact_utils.get_split_uri(input_dict['examples'], 'eval'))
  ]
  schema_file = io_utils.get_only_uri_in_dir(
      artifact_utils.get_single_uri(input_dict['schema']))

  train_args = trainer_pb2.TrainArgs()
  eval_args = trainer_pb2.EvalArgs()
  json_format.Parse(exec_properties['train_args'], train_args)
  json_format.Parse(exec_properties['eval_args'], eval_args)

  # https://github.com/tensorflow/tfx/issues/45: Replace num_steps=0 with
  # num_steps=None. Conversion of the proto to python will set the default
  # value of an int as 0 so modify the value here. Tensorflow will raise an
  # error if num_steps <= 0.
  train_steps = train_args.num_steps or None
  eval_steps = eval_args.num_steps or None

  output_path = artifact_utils.get_single_uri(output_dict['output'])
  serving_model_dir = path_utils.serving_model_dir(output_path)
  eval_model_dir = path_utils.eval_model_dir(output_path)

  # Assemble warm start path if needed.
  warm_start_from = None
  if exec_properties.get('warm_starting') and exec_properties.get(
      'warm_start_from'):
    previous_model_dir = os.path.join(exec_properties['warm_start_from'],
                                      path_utils.SERVING_MODEL_DIR)
    if previous_model_dir and tf.io.gfile.exists(
        os.path.join(previous_model_dir, self._CHECKPOINT_FILE_NAME)):
      warm_start_from = previous_model_dir

  # TODO(b/126242806) Use PipelineInputs when it is available in third_party.
  hparams = _HParamWrapper(
      # A list of uris for train files.
      train_files=train_files,
      # An optional single uri for transform graph produced by TFT. Will be
      # None if not specified.
      transform_output=transform_output,
      # A single uri for the output directory of the serving model.
      serving_model_dir=serving_model_dir,
      # A list of uris for eval files.
      eval_files=eval_files,
      # A single uri for schema file.
      schema_file=schema_file,
      # Number of train steps.
      train_steps=train_steps,
      # Number of eval steps.
      eval_steps=eval_steps,
      # A single uri for the model directory to warm start from.
      warm_start_from=warm_start_from)

  schema = io_utils.parse_pbtxt_file(schema_file, schema_pb2.Schema())

  training_spec = trainer_fn(hparams, schema)

  # Train the model.
  absl.logging.info('Training model.')
  tf.estimator.train_and_evaluate(training_spec['estimator'],
                                  training_spec['train_spec'],
                                  training_spec['eval_spec'])
  absl.logging.info('Training complete. Model written to %s',
                    serving_model_dir)

  # Export an eval savedmodel for TFMA.
  absl.logging.info('Exporting eval_savedmodel for TFMA.')
  tfma.export.export_eval_savedmodel(
      estimator=training_spec['estimator'],
      export_dir_base=eval_model_dir,
      eval_input_receiver_fn=training_spec['eval_input_receiver_fn'])

  absl.logging.info('Exported eval_savedmodel to %s.', eval_model_dir)
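
# Sketch of the user-side contract consumed above (assumed shape, not an
# authoritative template): trainer_fn, loaded from module_file, receives the
# hparams wrapper and the schema, and must return a dict with the keys this
# executor reads: 'estimator', 'train_spec', 'eval_spec', and
# 'eval_input_receiver_fn'.
#
# def trainer_fn(hparams, schema):
#   estimator = ...  # build a tf.estimator.Estimator from hparams and schema
#   train_spec = tf.estimator.TrainSpec(
#       input_fn=...,  # typically reads hparams.train_files
#       max_steps=hparams.train_steps)
#   eval_spec = tf.estimator.EvalSpec(
#       input_fn=...,  # typically reads hparams.eval_files
#       steps=hparams.eval_steps)
#   return {
#       'estimator': estimator,
#       'train_spec': train_spec,
#       'eval_spec': eval_spec,
#       'eval_input_receiver_fn': ...,  # used by tfma.export above
#   }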