コード例 #1
0
    def testInvalidArtifact(self):
        """Checks the ValueErrors raised for invalid Artifact constructions."""
        missing_type_pattern = (
            'The "mlmd_artifact_type" argument must be passed')
        missing_type_name_pattern = (
            'The Artifact subclass .* must override the TYPE_NAME attribute ')
        unexpected_type_pattern = (
            'The "mlmd_artifact_type" argument must not be passed for Artifact '
            'subclass')

        # The base class cannot be constructed without a type argument.
        with self.assertRaisesRegexp(ValueError, missing_type_pattern):
            artifact.Artifact()

        class MyBadArtifact(artifact.Artifact):
            # Intentionally does not define TYPE_NAME.
            pass

        with self.assertRaisesRegexp(ValueError, missing_type_name_pattern):
            MyBadArtifact()

        class MyNewArtifact(artifact.Artifact):
            TYPE_NAME = 'MyType'

        # A subclass that declares TYPE_NAME needs no extra argument.
        MyNewArtifact()

        # Passing an explicit type to such a subclass is rejected.
        with self.assertRaisesRegexp(ValueError, unexpected_type_pattern):
            MyNewArtifact(mlmd_artifact_type=metadata_store_pb2.ArtifactType())
コード例 #2
0
    def test_invalid_artifact(self):
        """Checks the ValueErrors raised for invalid Artifact constructions."""
        missing_field_pattern = 'The "type_name" field must be passed'
        missing_type_name_pattern = (
            'The Artifact subclass .* must override the TYPE_NAME attribute ')
        unexpected_field_pattern = (
            'The "type_name" field must not be passed for Artifact subclass')

        # The base class cannot be constructed without a type name.
        with self.assertRaisesRegexp(ValueError, missing_field_pattern):
            artifact.Artifact()

        class MyBadArtifact(artifact.Artifact):
            # Intentionally does not define TYPE_NAME.
            pass

        with self.assertRaisesRegexp(ValueError, missing_type_name_pattern):
            MyBadArtifact()

        class MyArtifact(artifact.Artifact):
            TYPE_NAME = 'MyType'

        # A subclass that declares TYPE_NAME needs no extra argument.
        MyArtifact()

        # Passing type_name to such a subclass is rejected.
        with self.assertRaisesRegexp(ValueError, unexpected_field_pattern):
            MyArtifact(type_name='OtherType')
コード例 #3
0
 def testGetSingleUriDeprecated(self):
     """Checks that types.get_single_uri still works and warns exactly once."""
     # patch.object returns the MagicMock it installs, so bind it with "as"
     # instead of overwriting tf_logging.warning by hand inside the patch.
     with mock.patch.object(tf_logging, 'warning') as warn_mock:
         my_artifact = artifact.Artifact('TestType')
         my_artifact.uri = '123'
         self.assertEqual('123', types.get_single_uri([my_artifact]))
         warn_mock.assert_called_once()
         # The rename notice is looked up in the sixth positional argument
         # of the recorded warning call.
         self.assertIn('tfx.utils.types.get_single_uri has been renamed to',
                       warn_mock.call_args[0][5])
コード例 #4
0
 def testGetSingleInstanceDeprecated(self):
     """Checks that types.get_single_instance still works and warns once."""
     # patch.object returns the MagicMock it installs, so bind it with "as"
     # instead of overwriting tf_logging.warning by hand inside the patch.
     with mock.patch.object(tf_logging, 'warning') as warn_mock:
         my_artifact = artifact.Artifact('TestType')
         self.assertIs(my_artifact,
                       types.get_single_instance([my_artifact]))
         warn_mock.assert_called_once()
         # The rename notice is looked up in the sixth positional argument
         # of the recorded warning call.
         self.assertIn(
             'tfx.utils.types.get_single_instance has been renamed to',
             warn_mock.call_args[0][5])
コード例 #5
0
ファイル: artifact_utils_test.py プロジェクト: dlegor/tfx
 def testGetFromSingleList(self):
     """Exercises the retrieval helpers on a one-element Artifact list."""
     eval_artifact = artifact.Artifact('MyTypeName')
     eval_artifact.uri = '/tmp/evaluri'
     eval_artifact.split_names = '["eval"]'
     artifacts = [eval_artifact]

     single_instance = artifact_utils.get_single_instance(artifacts)
     self.assertEqual(artifacts[0], single_instance)

     single_uri = artifact_utils.get_single_uri(artifacts)
     self.assertEqual('/tmp/evaluri', single_uri)

     eval_uri = artifact_utils.get_split_uri(artifacts, 'eval')
     self.assertEqual('/tmp/evaluri/eval', eval_uri)

     # Only the "eval" split was declared, so "train" must be rejected.
     with self.assertRaises(ValueError):
         artifact_utils.get_split_uri(artifacts, 'train')
コード例 #6
0
 def testGetSplitUriDeprecated(self):
     """Checks that types.get_split_uri still works and warns exactly once."""
     # patch.object returns the MagicMock it installs, so bind it with "as"
     # instead of overwriting tf_logging.warning by hand inside the patch.
     with mock.patch.object(tf_logging, 'warning') as warn_mock:
         my_artifact = artifact.Artifact('TestType')
         my_artifact.uri = '123'
         my_artifact.split_names = artifact_utils.encode_split_names(
             ['train'])
         self.assertEqual('123/train',
                          types.get_split_uri([my_artifact], 'train'))
         warn_mock.assert_called_once()
         # The rename notice is looked up in the sixth positional argument
         # of the recorded warning call.
         self.assertIn('tfx.utils.types.get_split_uri has been renamed to',
                       warn_mock.call_args[0][5])
コード例 #7
0
    def test_artifact(self):
        """Covers Artifact getters, setters, custom properties, str and JSON."""
        subject = artifact.Artifact('MyTypeName', split='eval')

        # Values readable straight after construction.
        self.assertEqual('', subject.uri)
        self.assertEqual(0, subject.id)
        self.assertEqual(0, subject.type_id)
        self.assertEqual('MyTypeName', subject.type_name)
        self.assertEqual('', subject.state)
        self.assertEqual('eval', subject.split)
        self.assertEqual(0, subject.span)

        # Each setter must be reflected by the matching getter.
        subject.uri = '/tmp/uri2'
        self.assertEqual('/tmp/uri2', subject.uri)

        subject.id = 1
        self.assertEqual(1, subject.id)

        subject.type_id = 2
        self.assertEqual(2, subject.type_id)

        deleted_state = artifact.ArtifactState.DELETED
        subject.state = deleted_state
        self.assertEqual(artifact.ArtifactState.DELETED, subject.state)

        subject.split = ''
        self.assertEqual('', subject.split)

        subject.span = 20190101
        self.assertEqual(20190101, subject.span)

        # Custom properties are checked on the wrapped artifact object.
        subject.set_int_custom_property('int_key', 20)
        stored_int = subject.artifact.custom_properties['int_key'].int_value
        self.assertEqual(20, stored_int)

        subject.set_string_custom_property('string_key', 'string_value')
        stored_string = (
            subject.artifact.custom_properties['string_key'].string_value)
        self.assertEqual('string_value', stored_string)

        expected_repr = (
            'Artifact(type_name: MyTypeName, uri: /tmp/uri2, split: , id: 1)')
        self.assertEqual(expected_repr, str(subject))

        # Round trip through a JSON string and compare both wrapped objects.
        serialized = json.dumps(subject.json_dict())
        restored = artifact.Artifact.parse_from_json_dict(
            json.loads(serialized))
        self.assertEqual(subject.artifact, restored.artifact)
        self.assertEqual(subject.artifact_type, restored.artifact_type)
コード例 #8
0
    def testArtifact(self):
        """Covers Artifact getters, setters, custom properties, str and JSON."""
        subject = artifact.Artifact('MyTypeName')
        subject.split_names = artifact_utils.encode_split_names(['eval'])

        # Values readable before any other setter is used.
        self.assertEqual('', subject.uri)
        self.assertEqual(0, subject.id)
        self.assertEqual(0, subject.type_id)
        self.assertEqual('MyTypeName', subject.type_name)
        self.assertEqual('', subject.state)
        self.assertEqual('["eval"]', subject.split_names)
        self.assertEqual(0, subject.span)

        # Each setter must be reflected by the matching getter.
        subject.uri = '/tmp/uri2'
        self.assertEqual('/tmp/uri2', subject.uri)

        subject.id = 1
        self.assertEqual(1, subject.id)

        subject.type_id = 2
        self.assertEqual(2, subject.type_id)

        deleted_state = artifact.ArtifactState.DELETED
        subject.state = deleted_state
        self.assertEqual(artifact.ArtifactState.DELETED, subject.state)

        subject.split_names = ''
        self.assertEqual('', subject.split_names)

        subject.span = 20190101
        self.assertEqual(20190101, subject.span)

        # Custom properties are checked on the wrapped MLMD artifact.
        subject.set_int_custom_property('int_key', 20)
        stored_int = (
            subject.mlmd_artifact.custom_properties['int_key'].int_value)
        self.assertEqual(20, stored_int)

        subject.set_string_custom_property('string_key', 'string_value')
        stored_string = (
            subject.mlmd_artifact.custom_properties['string_key'].string_value)
        self.assertEqual('string_value', stored_string)

        expected_repr = 'Artifact(type_name: MyTypeName, uri: /tmp/uri2, id: 1)'
        self.assertEqual(expected_repr, str(subject))

        # Round trip through json_utils and compare both wrapped objects.
        serialized = json_utils.dumps(subject)
        restored = json_utils.loads(serialized)
        self.assertEqual(subject.mlmd_artifact, restored.mlmd_artifact)
        self.assertEqual(subject.artifact_type, restored.artifact_type)
コード例 #9
0
 def testJsonify(self):
   """Round-trips a component through json_utils and compares both copies."""
   source_artifact = artifact.Artifact(type_name="InputType")
   input_channel = types.Channel(
       type_name="InputType", artifacts=[source_artifact])
   component = _BasicComponent(folds=10, input=input_channel)

   serialized = json_utils.dumps(component)
   recovered_component = json_utils.loads(serialized)

   # Identity of the recovered component.
   self.assertEqual(recovered_component.__class__, component.__class__)
   self.assertEqual(recovered_component.component_id, "_BasicComponent")

   # Input channel survives with its type name and its single artifact.
   recovered_input = recovered_component.inputs["input"]
   self.assertEqual(input_channel.type_name, recovered_input.type_name)
   self.assertEqual(len(recovered_component.inputs["input"].get()), 1)

   # Output channel and driver class survive as well.
   self.assertIsInstance(recovered_component.outputs["output"], types.Channel)
   self.assertEqual(recovered_component.outputs.output.type_name, "OutputType")
   self.assertEqual(recovered_component.DRIVER_CLASS, component.DRIVER_CLASS)
コード例 #10
0
ファイル: artifact_utils_test.py プロジェクト: anitameh/tfx-1
 def test_get_from_single_list(self):
   """Exercises the retrieval helpers on a one-element Artifact list."""
   eval_artifact = artifact.Artifact('MyTypeName', split='eval')
   single_list = [eval_artifact]
   single_list[0].uri = '/tmp/evaluri'

   # Lookups that ignore the split.
   found_instance = artifact_utils.get_single_instance(single_list)
   self.assertEqual(single_list[0], found_instance)
   found_uri = artifact_utils.get_single_uri(single_list)
   self.assertEqual('/tmp/evaluri', found_uri)

   # Lookups for the split the artifact was created with.
   found_split_instance = artifact_utils._get_split_instance(
       single_list, 'eval')
   self.assertEqual(single_list[0], found_split_instance)
   found_split_uri = artifact_utils.get_split_uri(single_list, 'eval')
   self.assertEqual('/tmp/evaluri', found_split_uri)

   # Lookups for a split that is not in the list must fail.
   with self.assertRaises(ValueError):
     artifact_utils._get_split_instance(single_list, 'train')
   with self.assertRaises(ValueError):
     artifact_utils.get_split_uri(single_list, 'train')
コード例 #11
0
ファイル: artifact_utils_test.py プロジェクト: dlegor/tfx
    def testGetFromSplits(self):
        """Exercises the retrieval helpers on an Artifact with two splits."""
        multi_split = artifact.Artifact('MyTypeName')
        artifacts = [multi_split]
        artifacts[0].uri = '/tmp'
        encoded_splits = artifact_utils.encode_split_names(['train', 'eval'])
        artifacts[0].split_names = encoded_splits

        self.assertEqual(artifacts[0].split_names, '["train", "eval"]')

        # The list still holds one artifact, so the "single" helpers apply.
        found_instance = artifact_utils.get_single_instance(artifacts)
        self.assertIs(found_instance, artifacts[0])
        self.assertEqual('/tmp', artifact_utils.get_single_uri(artifacts))

        # Each split resolves to its own path under the artifact URI.
        train_uri = artifact_utils.get_split_uri(artifacts, 'train')
        self.assertEqual('/tmp/train', train_uri)
        eval_uri = artifact_utils.get_split_uri(artifacts, 'eval')
        self.assertEqual('/tmp/eval', eval_uri)
コード例 #12
0
ファイル: artifact_utils_test.py プロジェクト: anitameh/tfx-1
  def test_get_from_split_list(self):
    """Exercises the retrieval helpers on a list with one Artifact per split."""
    split_list = []
    for split_name in ['train', 'eval']:
      split_artifact = artifact.Artifact('MyTypeName', split=split_name)
      split_artifact.uri = '/tmp/' + split_name
      split_list.append(split_artifact)

    # Two artifacts are present, so the "single" helpers must refuse.
    with self.assertRaises(ValueError):
      artifact_utils.get_single_instance(split_list)

    with self.assertRaises(ValueError):
      artifact_utils.get_single_uri(split_list)

    # Per-split lookups return the matching artifact and its URI.
    train_instance = artifact_utils._get_split_instance(split_list, 'train')
    self.assertEqual(split_list[0], train_instance)
    train_uri = artifact_utils.get_split_uri(split_list, 'train')
    self.assertEqual('/tmp/train', train_uri)

    eval_instance = artifact_utils._get_split_instance(split_list, 'eval')
    self.assertEqual(split_list[1], eval_instance)
    eval_uri = artifact_utils.get_split_uri(split_list, 'eval')
    self.assertEqual('/tmp/eval', eval_uri)
コード例 #13
0
 def testStringTypeNameNotAllowed(self):
     """Checks that a plain string type is rejected with a ValueError."""
     expected_pattern = (
         'The "mlmd_artifact_type" argument must be an instance of the proto '
         'message')
     with self.assertRaisesRegexp(ValueError, expected_pattern):
         artifact.Artifact('StringTypeName')