def testInvalidArtifact(self):
  """Checks how Artifact validates `mlmd_artifact_type` against TYPE_NAME.

  - The base class cannot be built without an `mlmd_artifact_type`.
  - A subclass must override TYPE_NAME.
  - A subclass that does override TYPE_NAME must NOT also receive an explicit
    `mlmd_artifact_type`.
  """
  # assertRaisesRegex replaces the deprecated assertRaisesRegexp alias
  # (removed from unittest in Python 3.12).
  with self.assertRaisesRegex(
      ValueError, 'The "mlmd_artifact_type" argument must be passed'):
    artifact.Artifact()

  class MyBadArtifact(artifact.Artifact):
    # No TYPE_NAME
    pass

  with self.assertRaisesRegex(
      ValueError,
      'The Artifact subclass .* must override the TYPE_NAME attribute '):
    MyBadArtifact()

  class MyNewArtifact(artifact.Artifact):
    TYPE_NAME = 'MyType'

  # Okay without an additional mlmd_artifact_type argument.
  MyNewArtifact()

  # Not okay to pass an explicit mlmd_artifact_type on a subclass.
  with self.assertRaisesRegex(
      ValueError,
      'The "mlmd_artifact_type" argument must not be passed for Artifact '
      'subclass'):
    MyNewArtifact(mlmd_artifact_type=metadata_store_pb2.ArtifactType())
def test_invalid_artifact(self):
  """Checks how Artifact validates the `type_name` argument against TYPE_NAME.

  - The base class cannot be built without `type_name`.
  - A subclass must override TYPE_NAME.
  - A subclass that does override TYPE_NAME must NOT also receive `type_name`.
  """
  # assertRaisesRegex replaces the deprecated assertRaisesRegexp alias
  # (removed from unittest in Python 3.12).
  with self.assertRaisesRegex(ValueError,
                              'The "type_name" field must be passed'):
    artifact.Artifact()

  class MyBadArtifact(artifact.Artifact):
    # No TYPE_NAME
    pass

  with self.assertRaisesRegex(
      ValueError,
      'The Artifact subclass .* must override the TYPE_NAME attribute '):
    MyBadArtifact()

  class MyArtifact(artifact.Artifact):
    TYPE_NAME = 'MyType'

  # Okay without additional type_name argument.
  MyArtifact()

  # Not okay to pass type_name on subclass.
  with self.assertRaisesRegex(
      ValueError,
      'The "type_name" field must not be passed for Artifact subclass'):
    MyArtifact(type_name='OtherType')
def testGetSingleUriDeprecated(self):
  """The deprecated `types.get_single_uri` alias still works and warns once.

  `mock.patch.object(...)` used with `as` both installs a MagicMock in place of
  `tf_logging.warning` and restores the real function on exit, so no second
  hand-made mock is needed.
  """
  with mock.patch.object(tf_logging, 'warning') as warn_mock:
    my_artifact = artifact.Artifact('TestType')
    my_artifact.uri = '123'
    self.assertEqual('123', types.get_single_uri([my_artifact]))
    warn_mock.assert_called_once()
    # call_args[0] is the positional-argument tuple of that single call.
    # NOTE(review): index 5 presumably carries the rename message from the
    # deprecation helper -- confirm if that helper's call signature changes.
    self.assertIn('tfx.utils.types.get_single_uri has been renamed to',
                  warn_mock.call_args[0][5])
def testGetSingleInstanceDeprecated(self):
  """The deprecated `types.get_single_instance` alias still works and warns.

  The mock yielded by `mock.patch.object(...)` is used directly; the patch
  restores `tf_logging.warning` when the `with` block exits.
  """
  with mock.patch.object(tf_logging, 'warning') as warn_mock:
    my_artifact = artifact.Artifact('TestType')
    self.assertIs(my_artifact, types.get_single_instance([my_artifact]))
    warn_mock.assert_called_once()
    # call_args[0] is the positional-argument tuple of that single call.
    # NOTE(review): index 5 presumably carries the rename message -- confirm.
    self.assertIn(
        'tfx.utils.types.get_single_instance has been renamed to',
        warn_mock.call_args[0][5])
def testGetFromSingleList(self):
  """Retrieval helpers on a one-element list whose artifact has one split."""
  instance = artifact.Artifact('MyTypeName')
  instance.uri = '/tmp/evaluri'
  # split_names holds an encoded string, not a Python list.
  instance.split_names = '["eval"]'
  single_list = [instance]

  self.assertEqual(instance, artifact_utils.get_single_instance(single_list))
  self.assertEqual('/tmp/evaluri', artifact_utils.get_single_uri(single_list))
  self.assertEqual('/tmp/evaluri/eval',
                   artifact_utils.get_split_uri(single_list, 'eval'))

  # 'train' is not among the artifact's split names.
  with self.assertRaises(ValueError):
    artifact_utils.get_split_uri(single_list, 'train')
def testGetSplitUriDeprecated(self):
  """The deprecated `types.get_split_uri` alias still works and warns once.

  The mock yielded by `mock.patch.object(...)` is used directly; the patch
  restores `tf_logging.warning` when the `with` block exits.
  """
  with mock.patch.object(tf_logging, 'warning') as warn_mock:
    my_artifact = artifact.Artifact('TestType')
    my_artifact.uri = '123'
    my_artifact.split_names = artifact_utils.encode_split_names(['train'])
    self.assertEqual('123/train',
                     types.get_split_uri([my_artifact], 'train'))
    warn_mock.assert_called_once()
    # call_args[0] is the positional-argument tuple of that single call.
    # NOTE(review): index 5 presumably carries the rename message -- confirm.
    self.assertIn('tfx.utils.types.get_split_uri has been renamed to',
                  warn_mock.call_args[0][5])
def test_artifact(self):
  """Exercises Artifact getters/setters, custom properties and JSON round trip."""
  instance = artifact.Artifact('MyTypeName', split='eval')

  # Values observed right after construction.
  initial_values = (
      ('uri', ''),
      ('id', 0),
      ('type_id', 0),
      ('type_name', 'MyTypeName'),
      ('state', ''),
      ('split', 'eval'),
      ('span', 0),
  )
  for attr_name, expected in initial_values:
    self.assertEqual(expected, getattr(instance, attr_name))

  # Every setter must be visible through the matching getter; applied in order.
  updates = (
      ('uri', '/tmp/uri2'),
      ('id', 1),
      ('type_id', 2),
      ('state', artifact.ArtifactState.DELETED),
      ('split', ''),
      ('span', 20190101),
  )
  for attr_name, new_value in updates:
    setattr(instance, attr_name, new_value)
    self.assertEqual(new_value, getattr(instance, attr_name))

  # Custom properties are stored on the wrapped `artifact` message.
  instance.set_int_custom_property('int_key', 20)
  int_property = instance.artifact.custom_properties['int_key']
  self.assertEqual(20, int_property.int_value)

  instance.set_string_custom_property('string_key', 'string_value')
  string_property = instance.artifact.custom_properties['string_key']
  self.assertEqual('string_value', string_property.string_value)

  expected_repr = (
      'Artifact(type_name: MyTypeName, uri: /tmp/uri2, split: , id: 1)')
  self.assertEqual(expected_repr, str(instance))

  # JSON round trip must reproduce both the artifact and its type.
  as_dict = instance.json_dict()
  serialized = json.dumps(as_dict)
  restored = artifact.Artifact.parse_from_json_dict(json.loads(serialized))
  self.assertEqual(instance.artifact, restored.artifact)
  self.assertEqual(instance.artifact_type, restored.artifact_type)
def testArtifact(self):
  """Exercises Artifact getters/setters, custom properties and json_utils."""
  instance = artifact.Artifact('MyTypeName')
  instance.split_names = artifact_utils.encode_split_names(['eval'])

  # Values observed before any setter runs.
  initial_values = (
      ('uri', ''),
      ('id', 0),
      ('type_id', 0),
      ('type_name', 'MyTypeName'),
      ('state', ''),
      ('split_names', '["eval"]'),
      ('span', 0),
  )
  for attr_name, expected in initial_values:
    self.assertEqual(expected, getattr(instance, attr_name))

  # Every setter must be visible through the matching getter; applied in order.
  updates = (
      ('uri', '/tmp/uri2'),
      ('id', 1),
      ('type_id', 2),
      ('state', artifact.ArtifactState.DELETED),
      ('split_names', ''),
      ('span', 20190101),
  )
  for attr_name, new_value in updates:
    setattr(instance, attr_name, new_value)
    self.assertEqual(new_value, getattr(instance, attr_name))

  # Custom properties are stored on the wrapped `mlmd_artifact` message.
  instance.set_int_custom_property('int_key', 20)
  int_property = instance.mlmd_artifact.custom_properties['int_key']
  self.assertEqual(20, int_property.int_value)

  instance.set_string_custom_property('string_key', 'string_value')
  string_property = instance.mlmd_artifact.custom_properties['string_key']
  self.assertEqual('string_value', string_property.string_value)

  self.assertEqual(
      'Artifact(type_name: MyTypeName, uri: /tmp/uri2, id: 1)',
      str(instance))

  # json_utils.dumps yields a serialized string; loads rebuilds an Artifact.
  serialized = json_utils.dumps(instance)
  restored = json_utils.loads(serialized)
  self.assertEqual(instance.mlmd_artifact, restored.mlmd_artifact)
  self.assertEqual(instance.artifact_type, restored.artifact_type)
def testJsonify(self):
  """A component survives a json_utils dumps/loads round trip."""
  input_artifacts = [artifact.Artifact(type_name="InputType")]
  input_channel = types.Channel(type_name="InputType",
                                artifacts=input_artifacts)
  component = _BasicComponent(folds=10, input=input_channel)

  recovered = json_utils.loads(json_utils.dumps(component))

  # Identity of the rebuilt component.
  self.assertEqual(recovered.__class__, component.__class__)
  self.assertEqual(recovered.component_id, "_BasicComponent")

  # The input channel keeps its type and its single artifact.
  recovered_input = recovered.inputs["input"]
  self.assertEqual(input_channel.type_name, recovered_input.type_name)
  self.assertEqual(len(recovered_input.get()), 1)

  # Output channel is reachable both by key and by attribute.
  self.assertIsInstance(recovered.outputs["output"], types.Channel)
  self.assertEqual(recovered.outputs.output.type_name, "OutputType")

  self.assertEqual(recovered.DRIVER_CLASS, component.DRIVER_CLASS)
def test_get_from_single_list(self):
  """Test various retrieval utilities on a single list of Artifact."""
  eval_artifact = artifact.Artifact('MyTypeName', split='eval')
  eval_artifact.uri = '/tmp/evaluri'
  single_list = [eval_artifact]

  self.assertEqual(eval_artifact,
                   artifact_utils.get_single_instance(single_list))
  self.assertEqual('/tmp/evaluri', artifact_utils.get_single_uri(single_list))
  self.assertEqual(eval_artifact,
                   artifact_utils._get_split_instance(single_list, 'eval'))
  self.assertEqual('/tmp/evaluri',
                   artifact_utils.get_split_uri(single_list, 'eval'))

  # Asking for a split the list does not contain must fail in both helpers.
  for lookup in (artifact_utils._get_split_instance,
                 artifact_utils.get_split_uri):
    with self.assertRaises(ValueError):
      lookup(single_list, 'train')
def testGetFromSplits(self):
  """Retrieval helpers on a single artifact that carries several splits."""
  multi_split = artifact.Artifact('MyTypeName')
  multi_split.uri = '/tmp'
  multi_split.split_names = artifact_utils.encode_split_names(
      ['train', 'eval'])
  artifacts = [multi_split]

  self.assertEqual(multi_split.split_names, '["train", "eval"]')
  self.assertIs(artifact_utils.get_single_instance(artifacts), multi_split)
  self.assertEqual('/tmp', artifact_utils.get_single_uri(artifacts))

  # Each split resolves to a sub-path of the artifact URI.
  for split, expected_uri in (('train', '/tmp/train'), ('eval', '/tmp/eval')):
    self.assertEqual(expected_uri,
                     artifact_utils.get_split_uri(artifacts, split))
def test_get_from_split_list(self):
  """Test various retrieval utilities on a list of split Artifact."""

  def _make_split_artifact(split):
    # One artifact per split, each with its own URI.
    instance = artifact.Artifact('MyTypeName', split=split)
    instance.uri = '/tmp/' + split
    return instance

  split_list = [_make_split_artifact(split) for split in ['train', 'eval']]

  # With two artifacts in the list, "single" lookups must refuse to pick one.
  for single_lookup in (artifact_utils.get_single_instance,
                        artifact_utils.get_single_uri):
    with self.assertRaises(ValueError):
      single_lookup(split_list)

  expectations = (('train', '/tmp/train'), ('eval', '/tmp/eval'))
  for index, (split, expected_uri) in enumerate(expectations):
    self.assertEqual(split_list[index],
                     artifact_utils._get_split_instance(split_list, split))
    self.assertEqual(expected_uri,
                     artifact_utils.get_split_uri(split_list, split))
def testStringTypeNameNotAllowed(self):
  """A plain string as the first Artifact argument is rejected.

  The expected error states that `mlmd_artifact_type` must be a proto message
  instance rather than a type-name string.
  """
  # assertRaisesRegex replaces the deprecated assertRaisesRegexp alias
  # (removed from unittest in Python 3.12).
  with self.assertRaisesRegex(
      ValueError,
      'The "mlmd_artifact_type" argument must be an instance of the proto '
      'message'):
    artifact.Artifact('StringTypeName')