コード例 #1
0
    def test_fetch_last_blessed_model(self):
        """Checks `_fetch_last_blessed_model` against mocked blessing records.

        First the mocked metadata returns no artifacts, which must yield
        `(None, None)`. Then it returns ModelBlessingPath artifacts for spans
        4, 3, 2 and 1, where only odd spans carry `blessed=1`. The expected
        result is the span-3 model.
        """
        # The default is evaluated eagerly, exactly as in a plain dict.get call.
        base_dir = os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR',
                                  self.get_temp_dir())
        log_root = os.path.join(base_dir, self._testMethodName, 'log_root')

        metadata_stub = tf.test.mock.Mock()
        validator = driver.Driver(log_root, metadata_stub)

        # With nothing recorded there is nothing to return.
        metadata_stub.get_all_artifacts.return_value = []
        self.assertEqual((None, None), validator._fetch_last_blessed_model())

        def make_blessing_artifact(span):
            # Builds one blessing record; odd spans get blessed=1, even 0.
            blessing = types.TfxType(type_name='ModelBlessingPath')
            blessing.span = span
            blessing.set_string_custom_property(
                'current_model', 'uri-%d' % span)
            blessing.set_int_custom_property('current_model_id', span)
            blessing.set_int_custom_property('blessed', span % 2)
            return blessing.artifact

        metadata_stub.get_all_artifacts.return_value = [
            make_blessing_artifact(span) for span in range(4, 0, -1)
        ]
        # Span 3 is the highest odd (i.e. blessed) span.
        self.assertEqual(('uri-3', 3), validator._fetch_last_blessed_model())
コード例 #2
0
    def testFetchLastBlessedModel(self):
        """Checks last-blessed-model lookup is scoped by pipeline and component.

        The mocked metadata first returns no artifacts, which must yield
        `(None, None)`. It then returns blessing artifacts with ids 4..1 (only
        odd ids blessed) for this pipeline/component, plus newer blessed
        artifacts from a different component and a different pipeline. The
        expected result is the id-3 model of this pipeline/component.
        """
        # Mock metadata.
        mock_metadata = tf.compat.v1.test.mock.Mock()
        model_validator_driver = driver.Driver(mock_metadata)
        component_id = 'test_component'
        pipeline_name = 'test_pipeline'

        # No blessed model.
        mock_metadata.get_artifacts_by_type.return_value = []
        self.assertEqual((None, None),
                         model_validator_driver._fetch_last_blessed_model(
                             pipeline_name, component_id))

        # Mock blessing artifacts: ids 4..1, only odd ids are blessed. The
        # helper is called as (id, blessed, pipeline_name, component_id).
        artifacts = [
            self._create_mock_artifact(aid, aid % 2, pipeline_name,
                                       component_id) for aid in [4, 3, 2, 1]
        ]

        # Blessed artifacts with ids newer than 3, produced by another
        # component and by another pipeline. The id must come first, as in
        # the calls above. Passing (True, 5, ...) would create an artifact
        # with id True (== 1) that could never outrank id 3, leaving the
        # pipeline/component filtering untested.
        artifacts.extend([
            self._create_mock_artifact(5, True, pipeline_name,
                                       'different_component'),
            self._create_mock_artifact(6, True, 'different_pipeline',
                                       component_id)
        ])

        mock_metadata.get_artifacts_by_type.return_value = artifacts
        # Ids 5 and 6 belong to a different component/pipeline, so the
        # newest blessed artifact in scope is id 3.
        self.assertEqual(('uri-3', 3),
                         model_validator_driver._fetch_last_blessed_model(
                             pipeline_name, component_id))
コード例 #3
0
ファイル: driver_test.py プロジェクト: zxlzr/tfx
    def testFetchLastBlessedModel(self):
        """Checks last-blessed-model lookup is scoped to this component.

        The mocked metadata first returns no artifacts, which must yield
        `(None, None)`. It then returns blessing artifacts with ids 4..1 (only
        odd ids blessed) for this component, plus a newer blessed artifact
        from a different component. The expected result is the id-3 model.
        """
        # Mock metadata.
        mock_metadata = tf.compat.v1.test.mock.Mock()
        model_validator_driver = driver.Driver(mock_metadata)
        component_id = 'test_component'

        # No blessed model.
        mock_metadata.get_artifacts_by_type.return_value = []
        self.assertEqual(
            (None, None),
            model_validator_driver._fetch_last_blessed_model(component_id))

        # Mock blessing artifacts: ids 4..1, only odd ids are blessed. The
        # helper is called as (id, blessed, component_id).
        artifacts = []
        for aid in [4, 3, 2, 1]:
            model_blessing = self._create_mock_artifact(
                aid, aid % 2, component_id)
            artifacts.append(model_blessing.artifact)

        # Blessed artifact with a newer id (5) produced by another component.
        # The id must come first, as in the loop above. Passing (True, 5, ...)
        # would create an artifact with id True (== 1) that could never
        # outrank id 3, leaving the component filtering untested.
        model_blessing = self._create_mock_artifact(5, True,
                                                    'different_component')
        artifacts.append(model_blessing.artifact)

        mock_metadata.get_artifacts_by_type.return_value = artifacts
        # Id 5 belongs to another component, so the newest blessed artifact
        # for this component is id 3.
        self.assertEqual(
            ('uri-3', 3),
            model_validator_driver._fetch_last_blessed_model(component_id))
コード例 #4
0
ファイル: driver_test.py プロジェクト: yiching/tfx
  def test_fetch_last_blessed_model(self):
    """Checks last-blessed-model lookup is scoped to one component.

    The mocked metadata first returns no artifacts, which must yield
    `(None, None)`. It then returns blessing artifacts for spans 4..1 (odd
    spans blessed) of this component, plus a blessed span-5 artifact from
    another component. The expected result is the span-3 model.
    """
    metadata_stub = tf.test.mock.Mock()
    validator = driver.Driver(metadata_stub)
    component_unique_name = 'test_component'

    # Nothing recorded yet, so nothing can be returned.
    metadata_stub.get_all_artifacts.return_value = []
    self.assertEqual(
        (None, None),
        validator._fetch_last_blessed_model(component_unique_name))

    # Helper is called here as (blessed flag, span, component name).
    recorded = [
        self._create_mock_artifact(span % 2, span,
                                   component_unique_name).artifact
        for span in (4, 3, 2, 1)
    ]
    # A newer blessed record that belongs to a different component.
    recorded.append(
        self._create_mock_artifact(True, 5, 'different_component').artifact)

    metadata_stub.get_all_artifacts.return_value = recorded
    self.assertEqual(
        ('uri-3', 3),
        validator._fetch_last_blessed_model(component_unique_name))
コード例 #5
0
    def test_fetch_last_blessed_model(self):
        """Checks `_fetch_last_blessed_model` against mocked blessing records.

        First the mocked metadata returns no artifacts, which must yield
        `(None, None)`. Then it returns ModelBlessingPath artifacts for spans
        4, 3, 2 and 1, where only odd spans carry `blessed=1`. The expected
        result is the span-3 model.
        """
        metadata_stub = tf.test.mock.Mock()
        validator = driver.Driver(metadata_stub)

        # With nothing recorded there is nothing to return.
        metadata_stub.get_all_artifacts.return_value = []
        self.assertEqual((None, None), validator._fetch_last_blessed_model())

        def make_blessing_artifact(span):
            # Builds one blessing record; odd spans get blessed=1, even 0.
            blessing = types.TfxArtifact(type_name='ModelBlessingPath')
            blessing.span = span
            blessing.set_string_custom_property(
                'current_model', 'uri-%d' % span)
            blessing.set_int_custom_property('current_model_id', span)
            blessing.set_int_custom_property('blessed', span % 2)
            return blessing.artifact

        metadata_stub.get_all_artifacts.return_value = [
            make_blessing_artifact(span) for span in range(4, 0, -1)
        ]
        # Span 3 is the highest odd (i.e. blessed) span.
        self.assertEqual(('uri-3', 3), validator._fetch_last_blessed_model())